diff --git a/.ci/aarch64_linux/aarch64_ci_build.sh b/.ci/aarch64_linux/aarch64_ci_build.sh index ff3337e3f6d8c7..b914b28f755676 100644 --- a/.ci/aarch64_linux/aarch64_ci_build.sh +++ b/.ci/aarch64_linux/aarch64_ci_build.sh @@ -3,9 +3,7 @@ set -eux -o pipefail GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-} -if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then - export TORCH_CUDA_ARCH_LIST="9.0" -elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then +if [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then export TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0" fi @@ -27,6 +25,7 @@ if [ "$DESIRED_CUDA" = "cpu" ]; then USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn else echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA" + export USE_SYSTEM_NCCL=1 #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda fi diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py index 23097a4c483a79..d7bbdebc677aba 100755 --- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py +++ b/.ci/aarch64_linux/aarch64_wheel_ci_build.py @@ -79,6 +79,7 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None: os.system(f"unzip {wheel_path} -d {folder}/tmp") libs_to_copy = [ "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12", + "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so", "/usr/local/cuda/lib64/libcudnn.so.9", "/usr/local/cuda/lib64/libcublas.so.12", "/usr/local/cuda/lib64/libcublasLt.so.12", @@ -88,7 +89,7 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None: "/usr/local/cuda/lib64/libcusparseLt.so.0", "/usr/local/cuda/lib64/libcusolver.so.11", "/usr/local/cuda/lib64/libcurand.so.10", - "/usr/local/cuda/lib64/libnvToolsExt.so.1", + "/usr/local/cuda/lib64/libnccl.so.2", "/usr/local/cuda/lib64/libnvJitLink.so.12", "/usr/local/cuda/lib64/libnvrtc.so.12", "/usr/local/cuda/lib64/libcudnn_adv.so.9", @@ -108,9 +109,9 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None: "/usr/local/lib/libnvpl_blas_core.so.0", ] - if "128" in desired_cuda: + if "129" in desired_cuda: libs_to_copy += [ - "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8", + "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.9", "/usr/local/cuda/lib64/libcufile.so.0", "/usr/local/cuda/lib64/libcufile_rdma.so.1", ] diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index e0e84fdf000043..418a76ceac2345 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -1,4 +1,4 @@ -ARG CUDA_VERSION=12.4 +ARG CUDA_VERSION=12.6 ARG BASE_TARGET=cuda${CUDA_VERSION} ARG ROCM_IMAGE=rocm/dev-almalinux-8:6.3-complete FROM amd64/almalinux:8.10-20250519 as base @@ -52,10 +52,6 @@ ENV CUDA_VERSION=${CUDA_VERSION} # Make things in our path by default ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH -FROM cuda as cuda11.8 -RUN bash ./install_cuda.sh 11.8 -ENV DESIRED_CUDA=11.8 - FROM cuda as cuda12.6 RUN bash ./install_cuda.sh 12.6 ENV DESIRED_CUDA=12.6 @@ -64,6 +60,10 @@ FROM cuda as cuda12.8 RUN bash ./install_cuda.sh 12.8 ENV DESIRED_CUDA=12.8 +FROM cuda as cuda12.9 +RUN bash ./install_cuda.sh 12.9 +ENV DESIRED_CUDA=12.9 + FROM ${ROCM_IMAGE} as rocm ENV PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" ADD ./common/install_mkl.sh install_mkl.sh @@ -78,7 +78,8 @@ RUN bash ./install_mnist.sh FROM base as all_cuda COPY --from=cuda11.8 /usr/local/cuda-11.8 /usr/local/cuda-11.8 COPY --from=cuda12.6 /usr/local/cuda-12.6 /usr/local/cuda-12.6 -COPY --from=cuda12.4 /usr/local/cuda-12.8 /usr/local/cuda-12.8 +COPY --from=cuda12.8 /usr/local/cuda-12.8 /usr/local/cuda-12.8 +COPY --from=cuda12.9 /usr/local/cuda-12.9 /usr/local/cuda-12.9 # Final step FROM ${BASE_TARGET} as final diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 484eac1fadd53a..2f7a85939e01d3 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -50,30 +50,21 @@ if [[ "$image" == *xla* ]]; then exit 0 fi -if [[ "$image" == *-focal* ]]; then - UBUNTU_VERSION=20.04 -elif [[ "$image" == *-jammy* ]]; then +if [[ "$image" == *-jammy* ]]; then UBUNTU_VERSION=22.04 elif [[ "$image" == *ubuntu* ]]; then extract_version_from_image_name ubuntu UBUNTU_VERSION -elif [[ "$image" == *centos* ]]; then - extract_version_from_image_name centos CENTOS_VERSION fi if [ -n "${UBUNTU_VERSION}" ]; then OS="ubuntu" -elif [ -n "${CENTOS_VERSION}" ]; then - OS="centos" else echo "Unable to derive operating system base..." exit 1 fi DOCKERFILE="${OS}/Dockerfile" -# When using ubuntu - 22.04, start from Ubuntu docker image, instead of nvidia/cuda docker image. -if [[ "$image" == *cuda* && "$UBUNTU_VERSION" != "22.04" ]]; then - DOCKERFILE="${OS}-cuda/Dockerfile" -elif [[ "$image" == *rocm* ]]; then +if [[ "$image" == *rocm* ]]; then DOCKERFILE="${OS}-rocm/Dockerfile" elif [[ "$image" == *xpu* ]]; then DOCKERFILE="${OS}-xpu/Dockerfile" @@ -98,8 +89,8 @@ tag=$(echo $image | awk -F':' '{print $2}') # configuration, so we hardcode everything here rather than do it # from scratch case "$tag" in - pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11) - CUDA_VERSION=12.6.3 + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11) + CUDA_VERSION=12.8.1 CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 @@ -110,7 +101,7 @@ case "$tag" in TRITON=yes ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks) - CUDA_VERSION=12.8 + CUDA_VERSION=12.8.1 CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 @@ -121,7 +112,31 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc9) + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.13 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9) CUDA_VERSION=12.6.3 CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 @@ -168,8 +183,8 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9) - CUDA_VERSION=11.8.0 + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) + CUDA_VERSION=12.8.1 CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 @@ -179,25 +194,25 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-focal-py3-clang10-onnx) + pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.9 - CLANG_VERSION=10 + CLANG_VERSION=12 VISION=yes ONNX=yes ;; - pytorch-linux-focal-py3.9-clang10) + pytorch-linux-jammy-py3.9-clang12) ANACONDA_PYTHON_VERSION=3.9 - CLANG_VERSION=10 + CLANG_VERSION=12 VISION=yes TRITON=yes ;; - pytorch-linux-focal-py3.11-clang10) + pytorch-linux-jammy-py3.11-clang12) ANACONDA_PYTHON_VERSION=3.11 - CLANG_VERSION=10 + CLANG_VERSION=12 VISION=yes TRITON=yes ;; - pytorch-linux-focal-py3.9-gcc9) + pytorch-linux-jammy-py3.9-gcc9) ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=9 VISION=yes @@ -252,25 +267,14 @@ case "$tag" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12) + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12) ANACONDA_PYTHON_VERSION=3.9 - CUDA_VERSION=11.8 + CUDA_VERSION=12.8.1 CUDNN_VERSION=9 CLANG_VERSION=12 VISION=yes TRITON=yes ;; - pytorch-linux-jammy-py3-clang12-asan) - ANACONDA_PYTHON_VERSION=3.9 - CLANG_VERSION=12 - VISION=yes - TRITON=yes - ;; - pytorch-linux-jammy-py3-clang15-asan) - ANACONDA_PYTHON_VERSION=3.10 - CLANG_VERSION=15 - VISION=yes - ;; pytorch-linux-jammy-py3-clang18-asan) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=18 @@ -303,15 +307,15 @@ case "$tag" in GCC_VERSION=11 TRITON_CPU=yes ;; - pytorch-linux-focal-linter) + pytorch-linux-jammy-linter) # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. # We will need to update mypy version eventually, but that's for another day. The task # would be to upgrade mypy to 1.0.0 with Python 3.11 PYTHON_VERSION=3.9 ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter) + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter) PYTHON_VERSION=3.9 - CUDA_VERSION=11.8 + CUDA_VERSION=12.8.1 ;; pytorch-linux-jammy-aarch64-py3.10-gcc11) ANACONDA_PYTHON_VERSION=3.10 @@ -370,14 +374,6 @@ esac tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') -#when using cudnn version 8 install it separately from cuda -if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then - IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - if [[ ${CUDNN_VERSION} == 9 ]]; then - IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - fi -fi - no_cache_flag="" progress_flag="" # Do not use cache and progress=plain when in CI @@ -394,7 +390,6 @@ docker build \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ --build-arg "VISION=${VISION:-}" \ --build-arg "UBUNTU_VERSION=${UBUNTU_VERSION}" \ - --build-arg "CENTOS_VERSION=${CENTOS_VERSION}" \ --build-arg "DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ --build-arg "CLANG_VERSION=${CLANG_VERSION}" \ diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 88fee98625713c..8d1e7f5972b1df 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -39,6 +39,7 @@ RUN bash ./install_user.sh && rm install_user.sh # Install conda and other packages (e.g., numpy, pytest) ARG ANACONDA_PYTHON_VERSION +ARG BUILD_ENVIRONMENT ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION ENV PATH /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/bin:/opt/conda/bin:$PATH COPY requirements-ci.txt /opt/conda/requirements-ci.txt diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index ce68b499a8f43e..0e527f4682297f 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -b173722085b3f555d6ba4533d6bbaddfd7c71144 +56392aa978594cc155fa8af48cd949f5b5f1823a diff --git a/.ci/docker/ci_commit_pins/nccl-cu12.txt b/.ci/docker/ci_commit_pins/nccl-cu12.txt index 95406d215faa8e..c002780c5de38a 100644 --- a/.ci/docker/ci_commit_pins/nccl-cu12.txt +++ b/.ci/docker/ci_commit_pins/nccl-cu12.txt @@ -1 +1 @@ -v2.26.5-1 +v2.27.3-1 diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 261082fb3d387a..80d7d7ed18af95 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -b0e26b7359c147b8aa0af686c20510fb9b15990a +ae324eeac8e102a2b40370e341460f3791353398 diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index c84147fc914918..64304fec6ed9da 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -30,18 +30,6 @@ install_ubuntu() { maybe_libomp_dev="" fi - # HACK: UCC testing relies on libnccl library from NVIDIA repo, and version 2.16 crashes - # See https://github.com/pytorch/pytorch/pull/105260#issuecomment-1673399729 - # TODO: Eliminate this hack, we should not relay on apt-get installation - # See https://github.com/pytorch/pytorch/issues/144768 - if [[ "$UBUNTU_VERSION" == "20.04"* && "$CUDA_VERSION" == "11.8"* ]]; then - maybe_libnccl_dev="libnccl2=2.15.5-1+cuda11.8 libnccl-dev=2.15.5-1+cuda11.8 --allow-downgrades --allow-change-held-packages" - elif [[ "$UBUNTU_VERSION" == "20.04"* && "$CUDA_VERSION" == "12.4"* ]]; then - maybe_libnccl_dev="libnccl2=2.26.2-1+cuda12.4 libnccl-dev=2.26.2-1+cuda12.4 --allow-downgrades --allow-change-held-packages" - else - maybe_libnccl_dev="" - fi - # Install common dependencies apt-get update # TODO: Some of these may not be necessary @@ -70,7 +58,6 @@ install_ubuntu() { libasound2-dev \ libsndfile-dev \ ${maybe_libomp_dev} \ - ${maybe_libnccl_dev} \ software-properties-common \ wget \ sudo \ diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index eec5f1dd235e6f..11c51cac0bf835 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -6,7 +6,7 @@ set -ex if [ -n "$ANACONDA_PYTHON_VERSION" ]; then BASE_URL="https://repo.anaconda.com/miniconda" CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then + if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]] || [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" fi @@ -64,6 +64,11 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # which is provided in libstdcxx 12 and up. conda_install libstdcxx-ng=12.3.0 --update-deps -c conda-forge + # Miniforge installer doesn't install sqlite by default + if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then + conda_install sqlite + fi + # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README if [[ $(uname -m) == "aarch64" ]]; then conda_install "openblas==0.3.29=*openmp*" diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index ca6d59d6d9c2be..14d52f15214c36 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -3,11 +3,10 @@ set -uex -o pipefail PYTHON_DOWNLOAD_URL=https://www.python.org/ftp/python -PYTHON_DOWNLOAD_GITHUB_BRANCH=https://github.com/python/cpython/archive/refs/heads # @lint-ignore GET_PIP_URL=https://bootstrap.pypa.io/get-pip.py # Python versions to be installed in /opt/$VERSION_NO -CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t"} +CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t 3.14.0 3.14.0t"} function check_var { if [ -z "$1" ]; then @@ -24,9 +23,8 @@ function do_cpython_build { tar -xzf Python-$py_ver.tgz local additional_flags="" - if [ "$py_ver" == "3.13.0t" ]; then + if [[ "$py_ver" == *"t" ]]; then additional_flags=" --disable-gil" - mv cpython-3.13/ cpython-3.13t/ fi pushd $py_folder @@ -76,24 +74,20 @@ function do_cpython_build { function build_cpython { local py_ver=$1 check_var $py_ver - check_var $PYTHON_DOWNLOAD_URL - local py_ver_folder=$py_ver - - if [ "$py_ver" = "3.13.0t" ]; then - PY_VER_SHORT="3.13" - PYT_VER_SHORT="3.13t" - check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH - wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz - do_cpython_build $py_ver cpython-$PYT_VER_SHORT - elif [ "$py_ver" = "3.13.0" ]; then - PY_VER_SHORT="3.13" - check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH - wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz - do_cpython_build $py_ver cpython-$PY_VER_SHORT - else - wget -q $PYTHON_DOWNLOAD_URL/$py_ver_folder/Python-$py_ver.tgz - do_cpython_build $py_ver Python-$py_ver + local py_suffix=$py_ver + local py_folder=$py_ver + + # Special handling for nogil + if [[ "${py_ver}" == *"t" ]]; then + py_suffix=${py_ver::-1} + py_folder=$py_suffix + fi + # Only b3 is available now + if [ "$py_suffix" == "3.14.0" ]; then + py_suffix="3.14.0b3" fi + wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz + do_cpython_build $py_ver Python-$py_suffix rm -f Python-$py_ver.tgz } diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index a21d6503f86aa6..cd9701e7590b5d 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -10,6 +10,8 @@ else arch_path='sbsa' fi +NVSHMEM_VERSION=3.3.9 + function install_cuda { version=$1 runfile=$2 @@ -40,41 +42,52 @@ function install_cudnn { rm -rf tmp_cudnn } -function install_118 { - CUDNN_VERSION=9.1.0.70 - echo "Installing CUDA 11.8 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.4.0" - install_cuda 11.8.0 cuda_11.8.0_520.61.05_linux - - install_cudnn 11 $CUDNN_VERSION - - CUDA_VERSION=11.8 bash install_nccl.sh - - CUDA_VERSION=11.8 bash install_cusparselt.sh +function install_nvshmem { + cuda_major_version=$1 # e.g. "12" + nvshmem_version=$2 # e.g. "3.3.9" + + case "${arch_path}" in + sbsa) + dl_arch="aarch64" + ;; + x86_64) + dl_arch="x64" + ;; + *) + dl_arch="${arch}" + ;; + esac + + tmpdir="tmp_nvshmem" + mkdir -p "${tmpdir}" && cd "${tmpdir}" + + # nvSHMEM license: https://docs.nvidia.com/nvshmem/api/sla.html + filename="libnvshmem_cuda${cuda_major_version}-linux-${arch_path}-${nvshmem_version}" + url="https://developer.download.nvidia.com/compute/redist/nvshmem/${nvshmem_version}/builds/cuda${cuda_major_version}/txz/agnostic/${dl_arch}/${filename}.tar.gz" + + # download, unpack, install + wget -q "${url}" + tar xf "${filename}.tar.gz" + cp -a "libnvshmem/include/"* /usr/local/include/ + cp -a "libnvshmem/lib/"* /usr/local/lib/ + + # cleanup + cd .. + rm -rf "${tmpdir}" - ldconfig + echo "nvSHMEM ${nvshmem_version} for CUDA ${cuda_major_version} (${arch_path}) installed." } -function install_124 { - CUDNN_VERSION=9.1.0.70 - echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" - install_cuda 12.4.1 cuda_12.4.1_550.54.15_linux - - install_cudnn 12 $CUDNN_VERSION - - CUDA_VERSION=12.4 bash install_nccl.sh - - CUDA_VERSION=12.4 bash install_cusparselt.sh - - ldconfig -} function install_126 { - CUDNN_VERSION=9.5.1.17 - echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" + CUDNN_VERSION=9.10.2.21 + echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" install_cuda 12.6.3 cuda_12.6.3_560.35.05_linux install_cudnn 12 $CUDNN_VERSION + install_nvshmem 12 $NVSHMEM_VERSION + CUDA_VERSION=12.6 bash install_nccl.sh CUDA_VERSION=12.6 bash install_cusparselt.sh @@ -82,69 +95,22 @@ function install_126 { ldconfig } -function prune_118 { - echo "Pruning CUDA 11.8 and cuDNN" - ##################################################################################### - # CUDA 11.8 prune static libs - ##################################################################################### - export NVPRUNE="/usr/local/cuda-11.8/bin/nvprune" - export CUDA_LIB_DIR="/usr/local/cuda-11.8/lib64" - - export GENCODE="-gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" - export GENCODE_CUDNN="-gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" - - if [[ -n "$OVERRIDE_GENCODE" ]]; then - export GENCODE=$OVERRIDE_GENCODE - fi - - # all CUDA libs except CuDNN and CuBLAS (cudnn and cublas need arch 3.7 included) - ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ - | xargs -I {} bash -c \ - "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" - - # prune CuDNN and CuBLAS - $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a - $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a - - ##################################################################################### - # CUDA 11.8 prune visual tools - ##################################################################################### - export CUDA_BASE="/usr/local/cuda-11.8/" - rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2022.3.0 $CUDA_BASE/nsight-systems-2022.4.2/ -} - -function prune_124 { - echo "Pruning CUDA 12.4" - ##################################################################################### - # CUDA 12.4 prune static libs - ##################################################################################### - export NVPRUNE="/usr/local/cuda-12.4/bin/nvprune" - export CUDA_LIB_DIR="/usr/local/cuda-12.4/lib64" +function install_129 { + CUDNN_VERSION=9.10.2.21 + echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" + # install CUDA 12.9.1 in the same container + install_cuda 12.9.1 cuda_12.9.1_575.57.08_linux - export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" - export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement + install_cudnn 12 $CUDNN_VERSION - if [[ -n "$OVERRIDE_GENCODE" ]]; then - export GENCODE=$OVERRIDE_GENCODE - fi - if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then - export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN - fi + install_nvshmem 12 $NVSHMEM_VERSION - # all CUDA libs except CuDNN and CuBLAS - ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ - | xargs -I {} bash -c \ - "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" + CUDA_VERSION=12.9 bash install_nccl.sh - # prune CuDNN and CuBLAS - $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a - $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a + CUDA_VERSION=12.9 bash install_cusparselt.sh - ##################################################################################### - # CUDA 12.4 prune visual tools - ##################################################################################### - export CUDA_BASE="/usr/local/cuda-12.4/" - rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.1.0 $CUDA_BASE/nsight-systems-2023.4.4/ + ldconfig } function prune_126 { @@ -183,13 +149,15 @@ function prune_126 { function install_128 { CUDNN_VERSION=9.8.0.87 - echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" + echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement install_cudnn 12 $CUDNN_VERSION + install_nvshmem 12 $NVSHMEM_VERSION + CUDA_VERSION=12.8 bash install_nccl.sh CUDA_VERSION=12.8 bash install_cusparselt.sh @@ -201,13 +169,11 @@ function install_128 { while test $# -gt 0 do case "$1" in - 11.8) install_118; prune_118 - ;; - 12.4) install_124; prune_124 + 12.6|12.6.*) install_126; prune_126 ;; - 12.6) install_126; prune_126 + 12.8|12.8.*) install_128; ;; - 12.8) install_128; + 12.9|12.9.*) install_129; ;; *) echo "bad argument $1"; exit 1 ;; diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index eb824d880de37a..7ee5e73226cb60 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -4,12 +4,10 @@ if [[ -n "${CUDNN_VERSION}" ]]; then # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement mkdir tmp_cudnn pushd tmp_cudnn - if [[ ${CUDA_VERSION:0:4} == "12.8" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-9.8.0.87_cuda12-archive" + if [[ ${CUDA_VERSION:0:4} == "12.9" || ${CUDA_VERSION:0:4} == "12.8" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-9.5.1.17_cuda12-archive" - elif [[ ${CUDA_VERSION:0:2} == "12" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda12-archive" + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index 0603739fb041fc..ca29a94e58fc9c 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -5,25 +5,14 @@ set -ex # cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html mkdir tmp_cusparselt && cd tmp_cusparselt -if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-8]$ ]]; then +if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then arch_path='x86_64' fi - CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.3.2-archive" + CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz -elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then - arch_path='sbsa' - export TARGETARCH=${TARGETARCH:-$(uname -m)} - if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then - arch_path='x86_64' - fi - CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz -elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then - CUSPARSELT_NAME="libcusparse_lt-linux-x86_64-0.4.0.7-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-x86_64/${CUSPARSELT_NAME}.tar.xz else echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}" fi diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 7e229662212516..d07ec320016359 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -8,16 +8,6 @@ retry () { "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@") } -# A bunch of custom pip dependencies for ONNX -pip_install \ - beartype==0.15.0 \ - filelock==3.9.0 \ - flatbuffers==2.0 \ - mock==5.0.1 \ - ninja==1.10.2 \ - networkx==2.5 \ - numpy==1.24.2 - # ONNXRuntime should be installed before installing # onnx-weekly. Otherwise, onnx-weekly could be # overwritten by onnx. @@ -29,11 +19,8 @@ pip_install \ transformers==4.36.2 pip_install coloredlogs packaging - pip_install onnxruntime==1.18.1 -pip_install onnxscript==0.2.6 --no-deps -# required by onnxscript -pip_install ml_dtypes +pip_install onnxscript==0.3.1 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/.ci/docker/common/install_openblas.sh b/.ci/docker/common/install_openblas.sh index 7f0b3620bdc11a..e932ecd1cdc1a2 100644 --- a/.ci/docker/common/install_openblas.sh +++ b/.ci/docker/common/install_openblas.sh @@ -4,8 +4,7 @@ set -ex cd / -git clone https://github.com/OpenMathLib/OpenBLAS.git -b v0.3.29 --depth 1 --shallow-submodules - +git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION:-v0.3.29}" --depth 1 --shallow-submodules OPENBLAS_BUILD_FLAGS=" NUM_THREADS=128 diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 3ac2a1bc2f4fa6..2a8d5b30e74e3c 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -26,6 +26,11 @@ Pin: release o=repo.radeon.com Pin-Priority: 600 EOF + # we want the patch version of 6.4 instead + if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then + ROCM_VERSION="${ROCM_VERSION}.1" + fi + # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list @@ -67,19 +72,23 @@ EOF # ROCm 6.3 had a regression where initializing static code objects had significant overhead # ROCm 6.4 did not yet fix the regression, also HIP branch names are different - if [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]] || [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then - if [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then - HIP_BRANCH=rocm-6.3.x - VER_STR=6.3 + if [[ $(ver $ROCM_VERSION) -ge $(ver 6.3) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then + if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4.1) ]]; then + HIP_BRANCH=release/rocm-rel-6.4 + VER_STR=6.4 + VER_PATCH=.1 elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then HIP_BRANCH=release/rocm-rel-6.4 VER_STR=6.4 + elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then + HIP_BRANCH=rocm-6.3.x + VER_STR=6.3 fi # clr build needs CppHeaderParser but can only find it using conda's python /opt/conda/bin/python -m pip install CppHeaderParser git clone https://github.com/ROCm/HIP -b $HIP_BRANCH HIP_COMMON_DIR=$(readlink -f HIP) - git clone https://github.com/jeffdaily/clr -b release/rocm-rel-${VER_STR}-statco-hotfix + git clone https://github.com/jeffdaily/clr -b release/rocm-rel-${VER_STR}${VER_PATCH}-statco-hotfix mkdir -p clr/build pushd clr/build cmake .. -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR diff --git a/.ci/docker/common/install_rocm_magma.sh b/.ci/docker/common/install_rocm_magma.sh index 364ee23b97e57a..a8d8ba00b35b85 100644 --- a/.ci/docker/common/install_rocm_magma.sh +++ b/.ci/docker/common/install_rocm_magma.sh @@ -5,7 +5,12 @@ set -eou pipefail function do_install() { rocm_version=$1 - rocm_version_nodot=${1//./} + if [[ ${rocm_version} =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + # chop off any patch version + rocm_version="${rocm_version%.*}" + fi + + rocm_version_nodot=${rocm_version//./} # Version 2.7.2 + ROCm related updates MAGMA_VERSION=a1625ff4d9bc362906bd01f805dbbe12612953f6 diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index 90ef587a0566a3..a965f0f743d4ea 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -51,7 +51,12 @@ as_jenkins git clone --recursive ${TRITON_REPO} triton cd triton as_jenkins git checkout ${TRITON_PINNED_COMMIT} as_jenkins git submodule update --init --recursive -cd python + +# Old versions of python have setup.py in ./python; newer versions have it in ./ +if [ ! -f setup.py ]; then + cd python +fi + pip_install pybind11==2.13.6 # TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 @@ -93,3 +98,10 @@ fi if [ -n "${NUMPY_VERSION}" ]; then pip_install "numpy==${NUMPY_VERSION}" fi + +# IMPORTANT: helion needs to be installed without dependencies. +# It depends on torch and triton. We don't want to install +# triton and torch from production on Docker CI images +if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then + pip_install helion --no-deps +fi diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile index fb459a41b7d5eb..776053a5d8750a 100644 --- a/.ci/docker/libtorch/Dockerfile +++ b/.ci/docker/libtorch/Dockerfile @@ -54,16 +54,6 @@ COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ COPY ./common/install_cusparselt.sh install_cusparselt.sh ENV CUDA_HOME /usr/local/cuda -FROM cuda as cuda11.8 -RUN bash ./install_cuda.sh 11.8 -RUN bash ./install_magma.sh 11.8 -RUN ln -sf /usr/local/cuda-11.8 /usr/local/cuda - -FROM cuda as cuda12.4 -RUN bash ./install_cuda.sh 12.4 -RUN bash ./install_magma.sh 12.4 -RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda - FROM cuda as cuda12.6 RUN bash ./install_cuda.sh 12.6 RUN bash ./install_magma.sh 12.6 @@ -74,6 +64,11 @@ RUN bash ./install_cuda.sh 12.8 RUN bash ./install_magma.sh 12.8 RUN ln -sf /usr/local/cuda-12.8 /usr/local/cuda +FROM cuda as cuda12.9 +RUN bash ./install_cuda.sh 12.9 +RUN bash ./install_magma.sh 12.9 +RUN ln -sf /usr/local/cuda-12.9 /usr/local/cuda + FROM cpu as rocm ARG ROCM_VERSION ARG PYTORCH_ROCM_ARCH diff --git a/.ci/docker/libtorch/build.sh b/.ci/docker/libtorch/build.sh index b8e5af5ce33292..a2c67f0aa6411e 100755 --- a/.ci/docker/libtorch/build.sh +++ b/.ci/docker/libtorch/build.sh @@ -39,6 +39,10 @@ case ${DOCKER_TAG_PREFIX} in DOCKER_GPU_BUILD_ARG="" ;; rocm*) + # we want the patch version of 6.4 instead + if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then + GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" + fi BASE_TARGET=rocm GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 5280f34740c691..b150423e99544a 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -26,7 +26,7 @@ ADD ./common/install_openssl.sh install_openssl.sh RUN bash ./install_openssl.sh && rm install_openssl.sh -# remove unncessary python versions +# remove unnecessary python versions RUN rm -rf /opt/python/cp26-cp26m /opt/_internal/cpython-2.6.9-ucs2 RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4 RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6 @@ -103,6 +103,7 @@ ENV SSL_CERT_FILE=/opt/_internal/certs.pem # Install LLVM version COPY --from=openssl /opt/openssl /opt/openssl COPY --from=base /opt/python /opt/python +COPY --from=base /usr/local/lib/ /usr/local/lib/ COPY --from=base /opt/_internal /opt/_internal COPY --from=base /usr/local/bin/auditwheel /usr/local/bin/auditwheel COPY --from=intel /opt/intel /opt/intel diff --git a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 index 99947aceaa97be..da7ab4d3fd1543 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 @@ -2,7 +2,7 @@ FROM quay.io/pypa/manylinux_2_28_aarch64 as base ARG GCCTOOLSET_VERSION=13 -# Language variabes +# Language variables ENV LC_ALL=en_US.UTF-8 ENV LANG=en_US.UTF-8 ENV LANGUAGE=en_US.UTF-8 @@ -58,12 +58,13 @@ RUN git config --global --add safe.directory "*" FROM base as openblas # Install openblas +ARG OPENBLAS_VERSION ADD ./common/install_openblas.sh install_openblas.sh RUN bash ./install_openblas.sh && rm install_openblas.sh FROM base as final -# remove unncessary python versions +# remove unnecessary python versions RUN rm -rf /opt/python/cp26-cp26m /opt/_internal/cpython-2.6.9-ucs2 RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4 RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6 diff --git a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 index 1e3ebc128c2de0..3697060557375d 100644 --- a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 @@ -60,7 +60,7 @@ RUN bash ./install_openssl.sh && rm install_openssl.sh ENV SSL_CERT_FILE=/opt/_internal/certs.pem FROM openssl as final -# remove unncessary python versions +# remove unnecessary python versions RUN rm -rf /opt/python/cp26-cp26m /opt/_internal/cpython-2.6.9-ucs2 RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4 RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6 diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index af19fb68a25b64..46ec7f77ae8ba8 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -120,15 +120,19 @@ RUN python3 -mpip install cmake==3.28.0 # so just build it from upstream repository. # h5py is dependency of onnxruntime_training. # h5py==3.11.0 builds with hdf5-devel 1.10.5 from repository. +# h5py 3.11.0 doesn't build with numpy >= 2.3.0. # install newest flatbuffers version first: # for some reason old version is getting pulled in otherwise. # packaging package is required for onnxruntime wheel build. RUN pip3 install flatbuffers && \ - pip3 install h5py==3.11.0 && \ + pip3 install cython 'pkgconfig>=1.5.5' 'setuptools>=77' 'numpy<2.3.0' && \ + pip3 install --no-build-isolation h5py==3.11.0 && \ pip3 install packaging && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ + wget https://github.com/microsoft/onnxruntime/commit/f57db79743c4d1a3553aa05cf95bcd10966030e6.patch && \ + patch -p1 < f57db79743c4d1a3553aa05cf95bcd10966030e6.patch && \ ./build.sh --config Release --parallel 0 --enable_pybind \ --build_wheel --enable_training --enable_training_apis \ --enable_training_ops --skip_tests --allow_running_as_root \ diff --git a/.ci/docker/manywheel/build.sh b/.ci/docker/manywheel/build.sh index 62230430542869..a4942a65bf5775 100755 --- a/.ci/docker/manywheel/build.sh +++ b/.ci/docker/manywheel/build.sh @@ -27,6 +27,7 @@ fi MANY_LINUX_VERSION=${MANY_LINUX_VERSION:-} DOCKERFILE_SUFFIX=${DOCKERFILE_SUFFIX:-} +OPENBLAS_VERSION=${OPENBLAS_VERSION:-} case ${image} in manylinux2_28-builder:cpu) @@ -40,6 +41,7 @@ case ${image} in GPU_IMAGE=arm64v8/almalinux:8 DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13 --build-arg NINJA_VERSION=1.12.1" MANY_LINUX_VERSION="2_28_aarch64" + OPENBLAS_VERSION="v0.3.29" ;; manylinuxcxx11-abi-builder:cpu-cxx11-abi) TARGET=final @@ -73,6 +75,10 @@ case ${image} in DOCKERFILE_SUFFIX="_cuda_aarch64" ;; manylinux2_28-builder:rocm*) + # we want the patch version of 6.4 instead + if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then + GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" + fi TARGET=rocm_final MANY_LINUX_VERSION="2_28" DEVTOOLSET_VERSION="11" @@ -109,6 +115,7 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') DOCKER_BUILDKIT=1 docker build \ ${DOCKER_GPU_BUILD_ARG} \ --build-arg "GPU_IMAGE=${GPU_IMAGE}" \ + --build-arg "OPENBLAS_VERSION=${OPENBLAS_VERSION}" \ --target "${TARGET}" \ -t "${tmp_tag}" \ $@ \ diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 57215d51e9a057..4ecdde62408deb 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -41,14 +41,11 @@ fbscribelogger==0.1.7 #Pinned versions: 0.1.6 #test that import: -flatbuffers==2.0 ; platform_machine != "s390x" +flatbuffers==24.12.23 #Description: cross platform serialization library -#Pinned versions: 2.0 +#Pinned versions: 24.12.23 #test that import: -flatbuffers ; platform_machine == "s390x" -#Description: cross platform serialization library; Newer version is required on s390x for new python version - hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests @@ -93,10 +90,10 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.15.0 +mypy==1.16.0 # Pin MyPy version because new errors are likely to appear with each release #Description: linter -#Pinned versions: 1.14.0 +#Pinned versions: 1.16.0 #test that import: test_typing.py, test_type_hints.py networkx==2.8.8 @@ -342,7 +339,7 @@ onnx==1.18.0 #Pinned versions: #test that import: -onnxscript==0.2.6 +onnxscript==0.3.1 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: @@ -382,3 +379,10 @@ dataclasses_json==0.6.7 cmake==4.0.0 #Description: required for building + +tlparse==0.3.30 +#Description: required for log parsing + +cuda-bindings>=12.0,<13.0 +#Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits. +#test that import: test_cuda.py diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 15e8075e617f4e..0d80c4d4937254 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -19,9 +19,10 @@ sphinx_sitemap==2.6.0 #Description: This is used to generate sitemap for PyTorch docs #Pinned versions: 2.6.0 -matplotlib==3.5.3 +matplotlib==3.5.3 ; python_version < "3.13" +matplotlib==3.6.3 ; python_version >= "3.13" #Description: This is used to generate PyTorch docs -#Pinned versions: 3.5.3 +#Pinned versions: 3.6.3 if python > 3.12. Otherwise 3.5.3. tensorboard==2.13.0 ; python_version < "3.13" tensorboard==2.18.0 ; python_version >= "3.13" diff --git a/.ci/docker/triton_xpu_version.txt b/.ci/docker/triton_xpu_version.txt new file mode 100644 index 00000000000000..18091983f59ddd --- /dev/null +++ b/.ci/docker/triton_xpu_version.txt @@ -0,0 +1 @@ +3.4.0 diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile deleted file mode 100644 index e2d8d4f618429f..00000000000000 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ /dev/null @@ -1,170 +0,0 @@ -ARG UBUNTU_VERSION -ARG CUDA_VERSION -ARG IMAGE_NAME - -FROM ${IMAGE_NAME} as base - -ARG UBUNTU_VERSION -ARG CUDA_VERSION - -ENV DEBIAN_FRONTEND noninteractive - -# Install common dependencies (so that this step can be cached separately) -COPY ./common/install_base.sh install_base.sh -RUN bash ./install_base.sh && rm install_base.sh - -# Install user -COPY ./common/install_user.sh install_user.sh -RUN bash ./install_user.sh && rm install_user.sh - -# Install katex -ARG KATEX -COPY ./common/install_docs_reqs.sh install_docs_reqs.sh -RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh - -# Install conda and other packages (e.g., numpy, pytest) -ARG ANACONDA_PYTHON_VERSION -ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION -ENV PATH /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/bin:/opt/conda/bin:$PATH -COPY requirements-ci.txt /opt/conda/requirements-ci.txt -COPY ./common/install_conda.sh install_conda.sh -COPY ./common/common_utils.sh common_utils.sh -COPY ./common/install_magma_conda.sh install_magma_conda.sh -RUN bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt - -# Install gcc -ARG GCC_VERSION -COPY ./common/install_gcc.sh install_gcc.sh -RUN bash ./install_gcc.sh && rm install_gcc.sh - -# Install clang -ARG CLANG_VERSION -COPY ./common/install_clang.sh install_clang.sh -RUN bash ./install_clang.sh && rm install_clang.sh - -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - -# (optional) Install UCC -ARG UCX_COMMIT -ARG UCC_COMMIT -ENV UCX_COMMIT $UCX_COMMIT -ENV UCC_COMMIT $UCC_COMMIT -ENV UCX_HOME /usr -ENV UCC_HOME /usr -ADD ./common/install_ucc.sh install_ucc.sh -RUN if [ -n "${UCX_COMMIT}" ] && [ -n "${UCC_COMMIT}" ]; then bash ./install_ucc.sh; fi -RUN rm install_ucc.sh - -COPY ./common/install_openssl.sh install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl -RUN bash ./install_openssl.sh -ENV OPENSSL_DIR /opt/openssl - -ARG INDUCTOR_BENCHMARKS -ARG ANACONDA_PYTHON_VERSION -ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION -COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh -COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/huggingface.txt huggingface.txt -COPY ci_commit_pins/timm.txt timm.txt -RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt - -ARG TRITON - -FROM base as triton-builder -# Install triton, this needs to be done before sccache because the latter will -# try to reach out to S3, which docker build runners don't have access -COPY ./common/install_triton.sh install_triton.sh -COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton.txt triton.txt -COPY triton_version.txt triton_version.txt -RUN bash ./install_triton.sh - -FROM base as final -COPY --from=triton-builder /opt/triton /opt/triton -RUN if [ -n "${TRITON}" ]; then pip install /opt/triton/*.whl; chown -R jenkins:jenkins /opt/conda; fi -RUN rm -rf /opt/triton - -ARG HALIDE -# Build and install halide -COPY ./common/install_halide.sh install_halide.sh -COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/halide.txt halide.txt -RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi -RUN rm install_halide.sh common_utils.sh halide.txt - -# Install ccache/sccache (do this last, so we get priority in PATH) -COPY ./common/install_cache.sh install_cache.sh -ENV PATH /opt/cache/bin:$PATH -# See https://github.com/pytorch/pytorch/issues/82174 -# TODO(sdym@fb.com): -# check if this is needed after full off Xenial migration -ENV CARGO_NET_GIT_FETCH_WITH_CLI true -RUN bash ./install_cache.sh && rm install_cache.sh -ENV CMAKE_CUDA_COMPILER_LAUNCHER=/opt/cache/bin/sccache - -# Add jni.h for java host build -COPY ./common/install_jni.sh install_jni.sh -COPY ./java/jni.h jni.h -RUN bash ./install_jni.sh && rm install_jni.sh - -# Install Open MPI for CUDA -COPY ./common/install_openmpi.sh install_openmpi.sh -RUN if [ -n "${CUDA_VERSION}" ]; then bash install_openmpi.sh; fi -RUN rm install_openmpi.sh - -# Include BUILD_ENVIRONMENT environment variable in image -ARG BUILD_ENVIRONMENT -ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} - -# AWS specific CUDA build guidance -ENV TORCH_CUDA_ARCH_LIST Maxwell -ENV TORCH_NVCC_FLAGS "-Xfatbin -compress-all" -ENV CUDA_PATH /usr/local/cuda - -# Install LLVM dev version (Defined in the pytorch/builder github repository) -COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm - -# Install CUDNN -ARG CUDNN_VERSION -ARG CUDA_VERSION -COPY ./common/install_cudnn.sh install_cudnn.sh -RUN if [ -n "${CUDNN_VERSION}" ]; then bash install_cudnn.sh; fi -RUN rm install_cudnn.sh - -# Install CUSPARSELT -ARG CUDA_VERSION -COPY ./common/install_cusparselt.sh install_cusparselt.sh -RUN bash install_cusparselt.sh -RUN rm install_cusparselt.sh - -# Install NCCL -ARG CUDA_VERSION -COPY ./common/install_nccl.sh install_nccl.sh -COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ -RUN bash install_nccl.sh -RUN rm install_nccl.sh /ci_commit_pins/nccl-cu* -ENV USE_SYSTEM_NCCL=1 -ENV NCCL_INCLUDE_DIR="/usr/local/cuda/include/" -ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" - -# Install CUDSS -ARG CUDA_VERSION -COPY ./common/install_cudss.sh install_cudss.sh -RUN bash install_cudss.sh -RUN rm install_cudss.sh - -# Delete /usr/local/cuda-11.X/cuda-11.X symlinks -RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi -RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi -RUN if [ -h /usr/local/cuda-12.1/cuda-12.1 ]; then rm /usr/local/cuda-12.1/cuda-12.1; fi -RUN if [ -h /usr/local/cuda-12.4/cuda-12.4 ]; then rm /usr/local/cuda-12.4/cuda-12.4; fi - -USER jenkins -CMD ["bash"] diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 009766a0ccd521..2528da07c69e38 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -25,6 +25,7 @@ RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh # Install conda and other packages (e.g., numpy, pytest) ARG ANACONDA_PYTHON_VERSION +ARG BUILD_ENVIRONMENT ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION ENV PATH /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/bin:/opt/conda/bin:$PATH COPY requirements-ci.txt /opt/conda/requirements-ci.txt diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index e4426129b1d613..a0e7dce3df4d55 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -72,7 +72,7 @@ ARG TRITON COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/triton-xpu.txt triton-xpu.txt -COPY triton_version.txt triton_version.txt +COPY triton_xpu_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-xpu.txt triton_version.txt diff --git a/.ci/magma/Makefile b/.ci/magma/Makefile index 12a37d35090e49..5035e1ee3b2c61 100644 --- a/.ci/magma/Makefile +++ b/.ci/magma/Makefile @@ -1,7 +1,7 @@ SHELL=/usr/bin/env bash DOCKER_CMD ?= docker -DESIRED_CUDA ?= 11.8 +DESIRED_CUDA ?= 12.8 DESIRED_CUDA_SHORT = $(subst .,,$(DESIRED_CUDA)) PACKAGE_NAME = magma-cuda CUDA_ARCH_LIST ?= -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 @@ -16,15 +16,21 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \ magma/build_magma.sh .PHONY: all +all: magma-cuda129 all: magma-cuda128 all: magma-cuda126 -all: magma-cuda118 .PHONY: clean: $(RM) -r magma-* $(RM) -r output +.PHONY: magma-cuda129 +magma-cuda129: DESIRED_CUDA := 12.9 +magma-cuda129: CUDA_ARCH_LIST += -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 +magma-cuda129: + $(DOCKER_RUN) + .PHONY: magma-cuda128 magma-cuda128: DESIRED_CUDA := 12.8 magma-cuda128: CUDA_ARCH_LIST += -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 @@ -35,9 +41,3 @@ magma-cuda128: magma-cuda126: DESIRED_CUDA := 12.6 magma-cuda126: $(DOCKER_RUN) - -.PHONY: magma-cuda118 -magma-cuda118: DESIRED_CUDA := 11.8 -magma-cuda118: CUDA_ARCH_LIST += -gencode arch=compute_37,code=sm_37 -magma-cuda118: - $(DOCKER_RUN) diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index ec822b0cd4afbc..74ce994207d486 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -31,7 +31,6 @@ elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then # Comment out nvidia repositories to prevent them from getting apt-get updated, see https://github.com/pytorch/pytorch/issues/74968 # shellcheck disable=SC2046 sed -i 's/.*nvidia.*/# &/' $(find /etc/apt/ -type f -name "*.list") - retry apt-get update retry apt-get -y install zip openssl else @@ -98,6 +97,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" +retry pip install -q cmake python setup.py clean retry pip install -qr requirements.txt case ${DESIRED_PYTHON} in @@ -151,7 +151,7 @@ if [[ "$USE_SPLIT_BUILD" == "true" ]]; then BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \ BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ - python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR --cmake + CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" else time CMAKE_ARGS=${CMAKE_ARGS[@]} \ diff --git a/.ci/manywheel/build_cuda.sh b/.ci/manywheel/build_cuda.sh index d3536148a77df6..a2c02d656c72db 100644 --- a/.ci/manywheel/build_cuda.sh +++ b/.ci/manywheel/build_cuda.sh @@ -15,6 +15,9 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages export USE_CUPTI_SO=0 export USE_CUSPARSELT=${USE_CUSPARSELT:-1} # Enable if not disabled by libtorch build export USE_CUFILE=${USE_CUFILE:-1} +export USE_SYSTEM_NCCL=1 +export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" +export NCCL_LIB_DIR="/usr/local/cuda/lib64/" # Keep an array of cmake variables to add to if [[ -z "$CMAKE_ARGS" ]]; then @@ -48,20 +51,22 @@ else fi cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.') +EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") -TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6" case ${CUDA_VERSION} in + #removing sm_50-sm_70 as these architectures are deprecated in CUDA 12.8/9 and will be removed in future releases 12.8) - TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX" #removing sm_50-sm_70 as these architectures are deprecated in CUDA 12.8 and will be removed in future releases - EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0" ;; - 12.6) - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" - EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + 12.9) + TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX" + # WAR to resolve the ld error in libtorch build with CUDA 12.9 + if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then + TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX" + fi ;; - 11.8) - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};3.7;9.0" - EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + 12.6) + TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0" ;; *) echo "unknown cuda version $CUDA_VERSION" @@ -104,12 +109,11 @@ DEPS_SONAME=( ) -# CUDA_VERSION 12.6, 12.8 +# CUDA_VERSION 12.6, 12.8, 12.9 if [[ $CUDA_VERSION == 12* ]]; then export USE_STATIC_CUDNN=0 # Try parallelizing nvcc as well export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" - if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then echo "Bundling with cudnn and cublas." DEPS_LIST+=( @@ -125,11 +129,12 @@ if [[ $CUDA_VERSION == 12* ]]; then "/usr/local/cuda/lib64/libcublasLt.so.12" "/usr/local/cuda/lib64/libcusparseLt.so.0" "/usr/local/cuda/lib64/libcudart.so.12" - "/usr/local/cuda/lib64/libnvToolsExt.so.1" "/usr/local/cuda/lib64/libnvrtc.so.12" "/usr/local/cuda/lib64/libnvrtc-builtins.so" "/usr/local/cuda/lib64/libcufile.so.0" "/usr/local/cuda/lib64/libcufile_rdma.so.1" + "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12" + "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so" ) DEPS_SONAME+=( "libcudnn_adv.so.9" @@ -144,12 +149,18 @@ if [[ $CUDA_VERSION == 12* ]]; then "libcublasLt.so.12" "libcusparseLt.so.0" "libcudart.so.12" - "libnvToolsExt.so.1" "libnvrtc.so.12" "libnvrtc-builtins.so" "libcufile.so.0" "libcufile_rdma.so.1" + "libcupti.so.12" + "libnvperf_host.so" ) + # Add libnvToolsExt only if CUDA version is not 12.9 + if [[ $CUDA_VERSION != 12.9* ]]; then + DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1") + DEPS_SONAME+=("libnvToolsExt.so.1") + fi else echo "Using nvidia libs from pypi." CUDA_RPATHS=( @@ -162,8 +173,10 @@ if [[ $CUDA_VERSION == 12* ]]; then '$ORIGIN/../../nvidia/curand/lib' '$ORIGIN/../../nvidia/cusolver/lib' '$ORIGIN/../../nvidia/cusparse/lib' + '$ORIGIN/../../nvidia/cusparselt/lib' '$ORIGIN/../../cusparselt/lib' '$ORIGIN/../../nvidia/nccl/lib' + '$ORIGIN/../../nvidia/nvshmem/lib' '$ORIGIN/../../nvidia/nvtx/lib' '$ORIGIN/../../nvidia/cufile/lib' ) @@ -172,94 +185,9 @@ if [[ $CUDA_VERSION == 12* ]]; then export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' export FORCE_RPATH="--force-rpath" export USE_STATIC_NCCL=0 - export USE_SYSTEM_NCCL=1 - export ATEN_STATIC_CUDA=0 - export USE_CUDA_STATIC_LINK=0 - export USE_CUPTI_SO=1 - export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" - export NCCL_LIB_DIR="/usr/local/cuda/lib64/" - fi -elif [[ $CUDA_VERSION == "11.8" ]]; then - export USE_STATIC_CUDNN=0 - # Turn USE_CUFILE off for CUDA 11.8 since nvidia-cufile-cu11 and 1.9.0.20 are - # not available in PYPI - export USE_CUFILE=0 - # Try parallelizing nvcc as well - export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" - # Bundle ptxas into the wheel, see https://github.com/pytorch/pytorch/pull/119750 - export BUILD_BUNDLE_PTXAS=1 - - # CUDA 11.8 have to ship the libcusparseLt.so.0 with the binary - # since nvidia-cusparselt-cu11 is not available in PYPI - if [[ $USE_CUSPARSELT == "1" ]]; then - DEPS_SONAME+=( - "libcusparseLt.so.0" - ) - DEPS_LIST+=( - "/usr/local/cuda/lib64/libcusparseLt.so.0" - ) - fi - - if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then - echo "Bundling with cudnn and cublas." - DEPS_LIST+=( - "/usr/local/cuda/lib64/libcudnn_adv.so.9" - "/usr/local/cuda/lib64/libcudnn_cnn.so.9" - "/usr/local/cuda/lib64/libcudnn_graph.so.9" - "/usr/local/cuda/lib64/libcudnn_ops.so.9" - "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9" - "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9" - "/usr/local/cuda/lib64/libcudnn_heuristic.so.9" - "/usr/local/cuda/lib64/libcudnn.so.9" - "/usr/local/cuda/lib64/libcublas.so.11" - "/usr/local/cuda/lib64/libcublasLt.so.11" - "/usr/local/cuda/lib64/libcudart.so.11.0" - "/usr/local/cuda/lib64/libnvToolsExt.so.1" - "/usr/local/cuda/lib64/libnvrtc.so.11.2" # this is not a mistake, it links to more specific cuda version - "/usr/local/cuda/lib64/libnvrtc-builtins.so.11.8" - ) - DEPS_SONAME+=( - "libcudnn_adv.so.9" - "libcudnn_cnn.so.9" - "libcudnn_graph.so.9" - "libcudnn_ops.so.9" - "libcudnn_engines_runtime_compiled.so.9" - "libcudnn_engines_precompiled.so.9" - "libcudnn_heuristic.so.9" - "libcudnn.so.9" - "libcublas.so.11" - "libcublasLt.so.11" - "libcudart.so.11.0" - "libnvToolsExt.so.1" - "libnvrtc.so.11.2" - "libnvrtc-builtins.so.11.8" - ) - else - echo "Using nvidia libs from pypi." - CUDA_RPATHS=( - '$ORIGIN/../../nvidia/cublas/lib' - '$ORIGIN/../../nvidia/cuda_cupti/lib' - '$ORIGIN/../../nvidia/cuda_nvrtc/lib' - '$ORIGIN/../../nvidia/cuda_runtime/lib' - '$ORIGIN/../../nvidia/cudnn/lib' - '$ORIGIN/../../nvidia/cufft/lib' - '$ORIGIN/../../nvidia/curand/lib' - '$ORIGIN/../../nvidia/cusolver/lib' - '$ORIGIN/../../nvidia/cusparse/lib' - '$ORIGIN/../../nvidia/nccl/lib' - '$ORIGIN/../../nvidia/nvtx/lib' - ) - CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") - export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' - export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' - export FORCE_RPATH="--force-rpath" - export USE_STATIC_NCCL=0 - export USE_SYSTEM_NCCL=1 export ATEN_STATIC_CUDA=0 export USE_CUDA_STATIC_LINK=0 export USE_CUPTI_SO=1 - export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" - export NCCL_LIB_DIR="/usr/local/cuda/lib64/" fi else echo "Unknown cuda version $CUDA_VERSION" diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index 8f70210dd5b6ce..e9ab620475d10f 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -92,6 +92,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" +retry pip install -q cmake python setup.py clean retry pip install -qr requirements.txt retry pip install -q numpy==2.0.1 diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index 703248d44aa919..690600efdb375a 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -95,6 +95,7 @@ ROCM_SO_FILES=( "libroctracer64.so" "libroctx64.so" "libhipblaslt.so" + "libhipsparselt.so" "libhiprtc.so" ) @@ -186,20 +187,28 @@ do OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array done +ARCH=$(echo $PYTORCH_ROCM_ARCH | sed 's/;/|/g') # Replace ; separated arch list to bar for grep + # rocBLAS library files ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library ROCBLAS_LIB_DST=lib/rocblas/library -ARCH=$(echo $PYTORCH_ROCM_ARCH | sed 's/;/|/g') # Replace ; seperated arch list to bar for grep -ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) -OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) -ROCBLAS_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) +ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) +ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) +ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES) # hipblaslt library files HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library HIPBLASLT_LIB_DST=lib/hipblaslt/library -ARCH_SPECIFIC_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -E $ARCH) -OTHER_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -v gfx) -HIPBLASLT_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) +HIPBLASLT_ARCH_SPECIFIC_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -E $ARCH) +HIPBLASLT_OTHER_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -v gfx) +HIPBLASLT_LIB_FILES=($HIPBLASLT_ARCH_SPECIFIC_FILES $HIPBLASLT_OTHER_FILES) + +# hipsparselt library files +HIPSPARSELT_LIB_SRC=$ROCM_HOME/lib/hipsparselt/library +HIPSPARSELT_LIB_DST=lib/hipsparselt/library +HIPSPARSELT_ARCH_SPECIFIC_FILES=$(ls $HIPSPARSELT_LIB_SRC | grep -E $ARCH) +#HIPSPARSELT_OTHER_FILES=$(ls $HIPSPARSELT_LIB_SRC | grep -v gfx) +HIPSPARSELT_LIB_FILES=($HIPSPARSELT_ARCH_SPECIFIC_FILES $HIPSPARSELT_OTHER_FILES) # ROCm library files ROCM_SO_PATHS=() @@ -234,12 +243,14 @@ DEPS_SONAME=( DEPS_AUX_SRCLIST=( "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_SRC/}" "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_SRC/}" + "${HIPSPARSELT_LIB_FILES[@]/#/$HIPSPARSELT_LIB_SRC/}" "/opt/amdgpu/share/libdrm/amdgpu.ids" ) DEPS_AUX_DSTLIST=( "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_DST/}" "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_DST/}" + "${HIPSPARSELT_LIB_FILES[@]/#/$HIPSPARSELT_LIB_DST/}" "share/libdrm/amdgpu.ids" ) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 2bcfe9f02b02b1..994bd179e4649f 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -27,6 +27,12 @@ cmake --version echo "Environment variables:" env +# The sccache wrapped version of nvcc gets put in /opt/cache/lib in docker since +# there are some issues if it is always wrapped, so we need to add it to PATH +# during CI builds. +# https://github.com/pytorch/pytorch/blob/0b6c0898e6c352c8ea93daec854e704b41485375/.ci/docker/common/install_cache.sh#L97 +export PATH="/opt/cache/lib:$PATH" + if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then # Use jemalloc during compilation to mitigate https://github.com/pytorch/pytorch/issues/116289 export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 @@ -52,12 +58,6 @@ fi export USE_LLVM=/opt/llvm export LLVM_DIR=/opt/llvm/lib/cmake/llvm -if [[ "$BUILD_ENVIRONMENT" == *executorch* ]]; then - # To build test_edge_op_registration - export BUILD_EXECUTORCH=ON - export USE_CUDA=0 -fi - if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with # intel cpu and later run tests on machines with amd cpu. @@ -198,10 +198,8 @@ fi # We only build FlashAttention files for CUDA 8.0+, and they require large amounts of # memory to build and will OOM -if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]] && [ -z "$MAX_JOBS_OVERRIDE" ]; then - echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM" - echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage" - export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))" +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then + export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2" fi if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then @@ -257,6 +255,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then set -e -o pipefail get_bazel + python3 tools/optional_submodules.py checkout_eigen # Leave 1 CPU free and use only up to 80% of memory to reduce the change of crashing # the runner @@ -394,10 +393,8 @@ else # This is an attempt to mitigate flaky libtorch build OOM error. By default, the build parallelization # is set to be the number of CPU minus 2. So, let's try a more conservative value here. A 4xlarge has # 16 CPUs - if [ -z "$MAX_JOBS_OVERRIDE" ]; then - MAX_JOBS=$(nproc --ignore=4) - export MAX_JOBS - fi + MAX_JOBS=$(nproc --ignore=4) + export MAX_JOBS # NB: Install outside of source directory (at the same level as the root # pytorch folder) so that it doesn't get cleaned away prior to docker push. diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index 7a7d9f30256e58..78baf6a0761d72 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -313,7 +313,7 @@ if [[ "$(uname)" == 'Linux' && "$PACKAGE_TYPE" == 'manywheel' ]]; then # Please see issue for reference: https://github.com/pytorch/pytorch/issues/152426 if [[ "$(uname -m)" == "s390x" ]]; then cxx_abi="19" - elif [[ "$DESIRED_CUDA" != 'cu118' && "$DESIRED_CUDA" != 'xpu' && "$DESIRED_CUDA" != 'rocm'* ]]; then + elif [[ "$DESIRED_CUDA" != 'xpu' && "$DESIRED_CUDA" != 'rocm'* ]]; then cxx_abi="18" else cxx_abi="16" diff --git a/.ci/pytorch/common-build.sh b/.ci/pytorch/common-build.sh index 88acd09d660845..8ca9fdb34c77a9 100644 --- a/.ci/pytorch/common-build.sh +++ b/.ci/pytorch/common-build.sh @@ -13,6 +13,13 @@ if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then fi if which sccache > /dev/null; then + # Clear SCCACHE_BUCKET and SCCACHE_REGION if they are empty, otherwise + # sccache will complain about invalid bucket configuration + if [[ -z "${SCCACHE_BUCKET:-}" ]]; then + unset SCCACHE_BUCKET + unset SCCACHE_REGION + fi + # Save sccache logs to file sccache --stop-server > /dev/null 2>&1 || true rm -f ~/sccache_error.log || true diff --git a/.ci/pytorch/common.sh b/.ci/pytorch/common.sh index b81d0e464203a2..1dc0615a99ddb6 100644 --- a/.ci/pytorch/common.sh +++ b/.ci/pytorch/common.sh @@ -15,6 +15,6 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then export PYTORCH_TEST_WITH_ROCM=1 fi -# TODO: Renable libtorch testing for MacOS, see https://github.com/pytorch/pytorch/issues/62598 +# TODO: Reenable libtorch testing for MacOS, see https://github.com/pytorch/pytorch/issues/62598 # shellcheck disable=SC2034 BUILD_TEST_LIBTORCH=0 diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index de2449d47c83c8..9c0e5242f433ca 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -159,11 +159,6 @@ function install_torchvision() { fi } -function install_tlparse() { - pip_install --user "tlparse==0.3.30" - PATH="$(python -m site --user-base)/bin:$PATH" -} - function install_torchrec_and_fbgemm() { local torchrec_commit torchrec_commit=$(get_pinned_commit torchrec) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 05711a1ce4d063..57549a7d63e1f8 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -5,11 +5,6 @@ set -x # shellcheck source=./macos-common.sh source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" -if [[ -n "$CONDA_ENV" ]]; then - # Use binaries under conda environment - export PATH="$CONDA_ENV/bin":$PATH -fi - # Test that OpenMP is enabled pushd test if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") == "1" ]]; then @@ -233,53 +228,52 @@ test_torchbench_smoketest() { mkdir -p "$TEST_REPORTS_DIR" local device=mps - local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor) - local hf_models=(GoogleFnet YituTechConvBert Speech2Text2ForCausalLM) + local dtypes=(undefined float16 bfloat16 notset) + local dtype=${dtypes[$1]} + local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16) for backend in eager inductor; do - for dtype in notset float16 bfloat16; do - echo "Launching torchbench inference performance run for backend ${backend} and dtype ${dtype}" - local dtype_arg="--${dtype}" - if [ "$dtype" == notset ]; then - dtype_arg="--float32" - fi - touch "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" - for model in "${models[@]}"; do + echo "Launching torchbench inference performance run for backend ${backend} and dtype ${dtype}" + local dtype_arg="--${dtype}" + if [ "$dtype" == notset ]; then + dtype_arg="--float32" + fi + touch "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" + for model in "${models[@]}"; do + PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/torchbench.py \ + --performance --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ + --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" || true + if [ "$backend" == "inductor" ]; then PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/torchbench.py \ - --performance --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ - --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" || true - if [ "$backend" == "inductor" ]; then - PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/torchbench.py \ - --accuracy --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ - --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true - fi - done - for model in "${hf_models[@]}"; do - if [ "$backend" == "inductor" ]; then - PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \ - --performance --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ - --output "$TEST_REPORTS_DIR/inductor_${backend}_huggingface_${dtype}_inference_${device}_performance.csv" || true - PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \ - --accuracy --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ - --output "$TEST_REPORTS_DIR/inductor_${backend}_huggingface_${dtype}_inference_${device}_accuracy.csv" || true - fi - done + --accuracy --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \ + --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true + fi done - - for dtype in notset amp; do - echo "Launching torchbench training performance run for backend ${backend} and dtype ${dtype}" - touch "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_training_${device}_performance.csv" - local dtype_arg="--${dtype}" - if [ "$dtype" == notset ]; then + if [ "$backend" == "inductor" ]; then + PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \ + --performance --backend "$backend" --inference --devices "$device" "$dtype_arg" \ + --output "$TEST_REPORTS_DIR/inductor_${backend}_huggingface_${dtype}_inference_${device}_performance.csv" || true + PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \ + --accuracy --backend "$backend" --inference --devices "$device" "$dtype_arg" \ + --output "$TEST_REPORTS_DIR/inductor_${backend}_huggingface_${dtype}_inference_${device}_accuracy.csv" || true + fi + + if [ "$dtype" == notset ]; then + for dtype_ in notset amp; do + echo "Launching torchbench training performance run for backend ${backend} and dtype ${dtype_}" + touch "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype_}_training_${device}_performance.csv" + local dtype_arg="--${dtype_}" + if [ "$dtype_" == notset ]; then dtype_arg="--float32" - fi - for model in "${models[@]}"; do - PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/torchbench.py \ - --performance --only "$model" --backend "$backend" --training --devices "$device" "$dtype_arg" \ - --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_training_${device}_performance.csv" || true + fi + for model in "${models[@]}"; do + PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/torchbench.py \ + --performance --only "$model" --backend "$backend" --training --devices "$device" "$dtype_arg" \ + --output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype_}_training_${device}_performance.csv" || true + done done - done + fi done @@ -318,8 +312,6 @@ test_timm_perf() { echo "timm benchmark on mps device completed" } -install_tlparse - if [[ $TEST_CONFIG == *"perf_all"* ]]; then test_torchbench_perf test_hf_perf @@ -331,7 +323,7 @@ elif [[ $TEST_CONFIG == *"perf_hf"* ]]; then elif [[ $TEST_CONFIG == *"perf_timm"* ]]; then test_timm_perf elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then - test_torchbench_smoketest + test_torchbench_smoketest "${SHARD_NUMBER}" elif [[ $TEST_CONFIG == *"mps"* ]]; then test_python_mps elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index b5c3bbab358145..3e88ffe4ffd77e 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -93,7 +93,7 @@ def check_lib_symbols_for_abi_correctness(lib: str) -> None: f"Found pre-cxx11 symbols, but there shouldn't be any, see: {pre_cxx11_symbols[:100]}" ) if num_cxx11_symbols < 100: - raise RuntimeError("Didn't find enought cxx11 symbols") + raise RuntimeError("Didn't find enough cxx11 symbols") def main() -> None: diff --git a/.ci/pytorch/smoke_test/check_gomp.py b/.ci/pytorch/smoke_test/check_gomp.py index 93430ff39906fd..225574dcffa033 100644 --- a/.ci/pytorch/smoke_test/check_gomp.py +++ b/.ci/pytorch/smoke_test/check_gomp.py @@ -46,6 +46,9 @@ def get_gomp_thread(): # use the default gomp path of AlmaLinux OS libgomp_path = "/usr/lib64/libgomp.so.1" + # if it does not exist, try Ubuntu path + if not os.path.exists(libgomp_path): + libgomp_path = f"/usr/lib/{os.uname().machine}-linux-gnu/libgomp.so.1" os.environ["GOMP_CPU_AFFINITY"] = "0-3" diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 24d1d64dd2056d..a5f9100266d2b9 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -276,7 +276,7 @@ def smoke_test_cuda( torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") - # Pypi dependencies are installed on linux ony and nccl is availbale only on Linux. + # Pypi dependencies are installed on linux only and nccl is available only on Linux. if pypi_pkg_check == "enabled" and sys.platform in ["linux", "linux2"]: compare_pypi_to_torch_versions( "cudnn", find_pypi_package_version("nvidia-cudnn"), torch_cudnn_version diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index e068351f897fd5..4994aa25c8f7f5 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -11,6 +11,8 @@ export TERM=vt100 # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" # Do not change workspace permissions for ROCm and s390x CI jobs # as it can leave workspace with bad permissions for cancelled jobs @@ -196,7 +198,7 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then # shellcheck disable=SC1091 source /opt/intel/oneapi/mpi/latest/env/vars.sh # Check XPU status before testing - xpu-smi discovery + timeout 30 xpu-smi discovery || true fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then @@ -212,8 +214,6 @@ if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then export VALGRIND=OFF fi -install_tlparse - # DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems # if you're not careful. Check this if you made some changes and the # ASAN test is not working @@ -226,7 +226,7 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then export PYTORCH_TEST_WITH_ASAN=1 export PYTORCH_TEST_WITH_UBSAN=1 # TODO: Figure out how to avoid hard-coding these paths - export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-15/bin/llvm-symbolizer + export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-18/bin/llvm-symbolizer export TORCH_USE_RTLD_GLOBAL=1 # NB: We load libtorch.so with RTLD_GLOBAL for UBSAN, unlike our # default behavior. @@ -324,6 +324,23 @@ test_python_smoke() { assert_git_not_dirty } +test_h100_distributed() { + # Distributed tests at H100 + time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + # This test requires multicast support + time python test/run_test.py --include distributed/_composable/fsdp/test_fully_shard_comm.py -k TestFullyShardAllocFromPG $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + assert_git_not_dirty +} + +test_h100_symm_mem() { + # symmetric memory test + time python test/run_test.py --include distributed/test_symmetric_memory.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include distributed/test_nvshmem.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include distributed/test_nvshmem_triton.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include distributed/test_nccl.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + assert_git_not_dirty +} + test_lazy_tensor_meta_reference_disabled() { export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1 echo "Testing lazy tensor operations without meta reference" @@ -338,6 +355,7 @@ test_dynamo_wrapped_shard() { exit 1 fi python tools/dynamo/verify_dynamo.py + python tools/dynamo/gb_id_mapping.py verify # PLEASE DO NOT ADD ADDITIONAL EXCLUDES HERE. # Instead, use @skipIfTorchDynamo on your tests. time python test/run_test.py --dynamo \ @@ -352,6 +370,17 @@ test_dynamo_wrapped_shard() { assert_git_not_dirty } +test_einops() { + pip install einops==0.6.1 + time python test/run_test.py --einops --verbose --upload-artifacts-while-running + pip install einops==0.7.0 + time python test/run_test.py --einops --verbose --upload-artifacts-while-running + pip install einops==0.8.1 + time python test/run_test.py --einops --verbose --upload-artifacts-while-running + assert_git_not_dirty +} + + test_inductor_distributed() { # Smuggle a few multi-gpu tests here so that we don't have to request another large node echo "Testing multi_gpu tests in test_torchinductor" @@ -409,14 +438,21 @@ test_inductor_aoti() { python3 tools/amd_build/build_amd.py fi if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then - BUILD_AOT_INDUCTOR_TEST=1 TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python setup.py develop + BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python setup.py develop) # TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB - LD_LIBRARY_PATH=/opt/conda/envs/py_3.10/lib/:${TORCH_LIB_DIR}:$LD_LIBRARY_PATH - CPP_TESTS_DIR="${BUILD_BIN_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile + TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}") else - BUILD_AOT_INDUCTOR_TEST=1 python setup.py develop - CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile + BUILD_COMMAND=(python setup.py develop) + TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}") fi + + # aoti cmake custom command requires `torch` to be installed + # initialize the cmake build cache and install torch + /usr/bin/env "${BUILD_COMMAND[@]}" + # rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled + /usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}" + + /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile } test_inductor_cpp_wrapper_shard() { @@ -429,47 +465,26 @@ test_inductor_cpp_wrapper_shard() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" - if [[ "$1" -eq "2" ]]; then - # For now, manually put the opinfo tests in shard 2, and all other tests in - # shard 1. Run all CPU tests, as well as specific GPU tests triggering past - # bugs, for now. - python test/run_test.py \ - --include inductor/test_torchinductor_opinfo \ - -k 'linalg or to_sparse or TestInductorOpInfoCPU' \ - --verbose - exit - fi - # Run certain inductor unit tests with cpp wrapper. In the end state, we # should be able to run all the inductor unit tests with cpp_wrapper. + # + # TODO: I'm pretty sure that "TestInductorOpInfoCPU" is not a valid filter, + # but change that in another PR to more accurately monitor the increased CI + # usage. + python test/run_test.py \ + --include inductor/test_torchinductor_opinfo \ + -k 'linalg or to_sparse or TestInductorOpInfoCPU' \ + --shard "$1" "$NUM_TEST_SHARDS" \ + --verbose python test/run_test.py \ --include inductor/test_torchinductor inductor/test_max_autotune inductor/test_cpu_repro \ + --shard "$1" "$NUM_TEST_SHARDS" \ + --verbose + python test/run_test.py --inductor \ + --include test_torch \ + -k 'take' \ + --shard "$1" "$NUM_TEST_SHARDS" \ --verbose - python test/run_test.py --inductor --include test_torch -k 'take' --verbose - - # Run inductor benchmark tests with cpp wrapper. - # Skip benchmark tests if it's in rerun-disabled-mode. - if [[ "${PYTORCH_TEST_RERUN_DISABLED_TESTS}" == "1" ]]; then - echo "skip dynamo benchmark tests for rerun-disabled-test" - else - echo "run dynamo benchmark tests with cpp wrapper" - python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ - --training --inductor --disable-cudagraphs --only vit_base_patch16_224 \ - --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" - python benchmarks/dynamo/check_accuracy.py \ - --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${MAYBE_ROCM}inductor_timm_training.csv" - - python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - python benchmarks/dynamo/check_accuracy.py \ - --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${MAYBE_ROCM}inductor_torchbench_inference.csv" - fi } # "Global" flags for inductor benchmarking controlled by TEST_CONFIG @@ -590,11 +605,14 @@ test_perf_for_dashboard() { local device=cuda if [[ "${TEST_CONFIG}" == *cpu* ]]; then - if [[ "${TEST_CONFIG}" == *cpu_x86* ]]; then + if [[ "${TEST_CONFIG}" == *cpu_x86_zen* ]]; then + device=cpu_x86_zen + elif [[ "${TEST_CONFIG}" == *cpu_x86* ]]; then device=cpu_x86 elif [[ "${TEST_CONFIG}" == *cpu_aarch64* ]]; then device=cpu_aarch64 fi + test_inductor_set_cpu_affinity elif [[ "${TEST_CONFIG}" == *cuda_a10g* ]]; then device=cuda_a10g elif [[ "${TEST_CONFIG}" == *h100* ]]; then @@ -603,12 +621,13 @@ test_perf_for_dashboard() { device=rocm fi - # Always set CPU affinity because metrics like compilation time requires CPU - test_inductor_set_cpu_affinity - for mode in "${modes[@]}"; do if [[ "$mode" == "inference" ]]; then - dtype=bfloat16 + if [[ "$device" == "cpu_x86" ]]; then + dtype=amp + else + dtype=bfloat16 + fi elif [[ "$mode" == "training" ]]; then dtype=amp fi @@ -620,6 +639,10 @@ test_perf_for_dashboard() { target_flag+=( --no-translation-validation) fi + if [[ "$DASHBOARD_TAG" == *freezing-true* ]]; then + target_flag+=( --freezing) + fi + if [[ "$DASHBOARD_TAG" == *default-true* ]]; then $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \ @@ -1133,6 +1156,12 @@ test_custom_backend() { test_custom_script_ops() { echo "Testing custom script operators" + + if [[ "$BUILD_ENVIRONMENT" == *s390x* ]]; then + echo "Skipping custom script operators until it's fixed" + return 0 + fi + CUSTOM_OP_BUILD="${CUSTOM_TEST_ARTIFACT_BUILD_DIR}/custom-op-build" pushd test/custom_operator cp -a "$CUSTOM_OP_BUILD" build @@ -1522,7 +1551,7 @@ test_executorch() { test_linux_aarch64() { python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \ - test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops \ + test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops test_cpp_extensions_open_device_registration \ --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Dynamo tests @@ -1556,7 +1585,8 @@ test_operator_benchmark() { cd "${TEST_DIR}"/benchmarks/operator_benchmark $TASKSET python -m benchmark_all_test --device "$1" --tag-filter "$2" \ - --output-dir "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" + --output-csv "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \ + --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.json" \ pip_install pandas python check_perf_csv.py \ @@ -1641,7 +1671,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio cuda fi install_torchvision - TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install git+https://github.com/pytorch/ao.git + TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 @@ -1666,11 +1696,11 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then - install_torchaudio cuda install_torchvision - checkout_install_torchbench hf_T5 llama moco PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" - test_inductor_aoti + if [[ "$SHARD_NUMBER" -eq "1" ]]; then + test_inductor_aoti + fi elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" @@ -1679,6 +1709,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then test_inductor_distributed fi fi +elif [[ "${TEST_CONFIG}" == *einops* ]]; then + test_einops elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then install_torchvision test_dynamo_wrapped_shard "${SHARD_NUMBER}" @@ -1726,6 +1758,10 @@ elif [[ "${BUILD_ENVIRONMENT}" == *xpu* ]]; then test_xpu_bin elif [[ "${TEST_CONFIG}" == smoke ]]; then test_python_smoke +elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then + test_h100_distributed +elif [[ "${TEST_CONFIG}" == "h100-symm-mem" ]]; then + test_h100_symm_mem else install_torchvision install_monkeytype diff --git a/.ci/pytorch/win-build.sh b/.ci/pytorch/win-build.sh index 7966e56695c2e1..d08fa87a5a6fae 100755 --- a/.ci/pytorch/win-build.sh +++ b/.ci/pytorch/win-build.sh @@ -31,7 +31,7 @@ PYLONG_API_CHECK=$? if [[ $PYLONG_API_CHECK == 0 ]]; then echo "Usage of PyLong_{From,As}{Unsigned}Long API may lead to overflow errors on Windows" echo "because \`sizeof(long) == 4\` and \`sizeof(unsigned long) == 4\`." - echo "Please include \"torch/csrc/utils/python_numbers.h\" and use the correspoding APIs instead." + echo "Please include \"torch/csrc/utils/python_numbers.h\" and use the corresponding APIs instead." echo "PyLong_FromLong -> THPUtils_packInt32 / THPUtils_packInt64" echo "PyLong_AsLong -> THPUtils_unpackInt (32-bit) / THPUtils_unpackLong (64-bit)" echo "PyLong_FromUnsignedLong -> THPUtils_packUInt32 / THPUtils_packUInt64" diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 2e1c07ece7f047..6a475dd89d32bf 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -10,7 +10,7 @@ set PATH=C:\Program Files\CMake\bin;C:\Program Files\7-Zip;C:\ProgramData\chocol :: able to see what our cl.exe commands are (since you can actually :: just copy-paste them into a local Windows setup to just rebuild a :: single file.) -:: log sizes are too long, but leaving this here incase someone wants to use it locally +:: log sizes are too long, but leaving this here in case someone wants to use it locally :: set CMAKE_VERBOSE_MAKEFILE=1 diff --git a/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py b/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py index 6df547d4a3ebcc..6b19c79218504e 100755 --- a/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py +++ b/.ci/pytorch/win-test-helpers/run_python_nn_smoketests.py @@ -52,7 +52,7 @@ if os.path.exists(debugger): command_args = [debugger, "-o", "-c", "~*g; q"] + command_args command_string = " ".join(command_args) - print("Reruning with traceback enabled") + print("Rerunning with traceback enabled") print("Command:", command_string) subprocess.run(command_args, check=False) sys.exit(e.returncode) diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 077470a0588da3..b61dd06ef562cf 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -52,6 +52,9 @@ python -m pip install parameterized==0.8.1 # Install pulp for testing ilps under torch\distributed\_tools python -m pip install pulp==2.9.0 +# Install expecttest to merge https://github.com/pytorch/pytorch/pull/155308 +python -m pip install expecttest==0.3.0 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do diff --git a/.ci/pytorch/windows/cuda118.bat b/.ci/pytorch/windows/cuda118.bat deleted file mode 100644 index b2773fe3863223..00000000000000 --- a/.ci/pytorch/windows/cuda118.bat +++ /dev/null @@ -1,59 +0,0 @@ -@echo off - -set MODULE_NAME=pytorch - -IF NOT EXIST "setup.py" IF NOT EXIST "%MODULE_NAME%" ( - call internal\clone.bat - cd %~dp0 -) ELSE ( - call internal\clean.bat -) -IF ERRORLEVEL 1 goto :eof - -call internal\check_deps.bat -IF ERRORLEVEL 1 goto :eof - -REM Check for optional components - -set USE_CUDA= -set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 - -IF "%NVTOOLSEXT_PATH%"=="" ( - IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( - set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt - ) ELSE ( - echo NVTX ^(Visual Studio Extension ^for CUDA^) ^not installed, failing - exit /b 1 - ) -) - -IF "%CUDA_PATH_V118%"=="" ( - IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc.exe" ( - set "CUDA_PATH_V118=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" - ) ELSE ( - echo CUDA 11.8 not found, failing - exit /b 1 - ) -) - -IF "%BUILD_VISION%" == "" ( - set TORCH_CUDA_ARCH_LIST=3.7+PTX;5.0;6.0;6.1;7.0;7.5;8.0;8.6;9.0 - set TORCH_NVCC_FLAGS=-Xfatbin -compress-all -) ELSE ( - set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -) - -set "CUDA_PATH=%CUDA_PATH_V118%" -set "PATH=%CUDA_PATH_V118%\bin;%PATH%" - -:optcheck - -call internal\check_opts.bat -IF ERRORLEVEL 1 goto :eof - -if exist "%NIGHTLIES_PYTORCH_ROOT%" cd %NIGHTLIES_PYTORCH_ROOT%\.. -call %~dp0\internal\copy.bat -IF ERRORLEVEL 1 goto :eof - -call %~dp0\internal\setup.bat -IF ERRORLEVEL 1 goto :eof diff --git a/.ci/pytorch/windows/cuda124.bat b/.ci/pytorch/windows/cuda124.bat deleted file mode 100644 index 36eff665ccf7f5..00000000000000 --- a/.ci/pytorch/windows/cuda124.bat +++ /dev/null @@ -1,59 +0,0 @@ -@echo off - -set MODULE_NAME=pytorch - -IF NOT EXIST "setup.py" IF NOT EXIST "%MODULE_NAME%" ( - call internal\clone.bat - cd %~dp0 -) ELSE ( - call internal\clean.bat -) -IF ERRORLEVEL 1 goto :eof - -call internal\check_deps.bat -IF ERRORLEVEL 1 goto :eof - -REM Check for optional components - -set USE_CUDA= -set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 - -IF "%NVTOOLSEXT_PATH%"=="" ( - IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( - set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt - ) ELSE ( - echo NVTX ^(Visual Studio Extension ^for CUDA^) ^not installed, failing - exit /b 1 - ) -) - -IF "%CUDA_PATH_V124%"=="" ( - IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\nvcc.exe" ( - set "CUDA_PATH_V124=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" - ) ELSE ( - echo CUDA 12.4 not found, failing - exit /b 1 - ) -) - -IF "%BUILD_VISION%" == "" ( - set TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5;8.0;8.6;9.0 - set TORCH_NVCC_FLAGS=-Xfatbin -compress-all -) ELSE ( - set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -) - -set "CUDA_PATH=%CUDA_PATH_V124%" -set "PATH=%CUDA_PATH_V124%\bin;%PATH%" - -:optcheck - -call internal\check_opts.bat -IF ERRORLEVEL 1 goto :eof - -if exist "%NIGHTLIES_PYTORCH_ROOT%" cd %NIGHTLIES_PYTORCH_ROOT%\.. -call %~dp0\internal\copy.bat -IF ERRORLEVEL 1 goto :eof - -call %~dp0\internal\setup.bat -IF ERRORLEVEL 1 goto :eof diff --git a/.ci/pytorch/windows/cuda129.bat b/.ci/pytorch/windows/cuda129.bat new file mode 100644 index 00000000000000..77ef14921aa63e --- /dev/null +++ b/.ci/pytorch/windows/cuda129.bat @@ -0,0 +1,59 @@ +@echo off + +set MODULE_NAME=pytorch + +IF NOT EXIST "setup.py" IF NOT EXIST "%MODULE_NAME%" ( + call internal\clone.bat + cd %~dp0 +) ELSE ( + call internal\clean.bat +) +IF ERRORLEVEL 1 goto :eof + +call internal\check_deps.bat +IF ERRORLEVEL 1 goto :eof + +REM Check for optional components + +set USE_CUDA= +set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 + +IF "%NVTOOLSEXT_PATH%"=="" ( + IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( + set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt + ) ELSE ( + echo NVTX ^(Visual Studio Extension ^for CUDA^) ^not installed, failing + exit /b 1 + ) +) + +IF "%CUDA_PATH_V129%"=="" ( + IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin\nvcc.exe" ( + set "CUDA_PATH_V129=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9" + ) ELSE ( + echo CUDA 12.9 not found, failing + exit /b 1 + ) +) + +IF "%BUILD_VISION%" == "" ( + set TORCH_CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0 + set TORCH_NVCC_FLAGS=-Xfatbin -compress-all +) ELSE ( + set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 +) + +set "CUDA_PATH=%CUDA_PATH_V129%" +set "PATH=%CUDA_PATH_V129%\bin;%PATH%" + +:optcheck + +call internal\check_opts.bat +IF ERRORLEVEL 1 goto :eof + +if exist "%NIGHTLIES_PYTORCH_ROOT%" cd %NIGHTLIES_PYTORCH_ROOT%\.. +call %~dp0\internal\copy.bat +IF ERRORLEVEL 1 goto :eof + +call %~dp0\internal\setup.bat +IF ERRORLEVEL 1 goto :eof diff --git a/.ci/pytorch/windows/internal/check_deps.bat b/.ci/pytorch/windows/internal/check_deps.bat index 46f438615774ce..35e6877188d235 100644 --- a/.ci/pytorch/windows/internal/check_deps.bat +++ b/.ci/pytorch/windows/internal/check_deps.bat @@ -65,7 +65,7 @@ for /F "usebackq delims=" %%i in (`python -c "import sys; print('{0[0]}{0[1]}'.f if %PYVER% LSS 35 ( echo Warning: PyTorch for Python 2 under Windows is experimental. echo Python x64 3.5 or up is recommended to compile PyTorch on Windows - echo Maybe you can create a virual environment if you have conda installed: + echo Maybe you can create a virtual environment if you have conda installed: echo ^> conda create -n test python=3.6 pyyaml numpy echo ^> activate test ) diff --git a/.ci/pytorch/windows/internal/copy.bat b/.ci/pytorch/windows/internal/copy.bat index 8042db09f462a4..40f2bd7acdbb91 100644 --- a/.ci/pytorch/windows/internal/copy.bat +++ b/.ci/pytorch/windows/internal/copy.bat @@ -8,6 +8,7 @@ copy "%CUDA_PATH%\bin\cusolver*64_*.dll*" pytorch\torch\lib copy "%CUDA_PATH%\bin\cudnn*64_*.dll*" pytorch\torch\lib copy "%CUDA_PATH%\bin\nvrtc*64_*.dll*" pytorch\torch\lib copy "%CUDA_PATH%\extras\CUPTI\lib64\cupti64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\extras\CUPTI\lib64\nvperf_host*.dll*" pytorch\torch\lib copy "C:\Program Files\NVIDIA Corporation\NvToolsExt\bin\x64\nvToolsExt64_1.dll*" pytorch\torch\lib copy "%PYTHON_LIB_PATH%\libiomp*5md.dll" pytorch\torch\lib diff --git a/.ci/pytorch/windows/internal/cuda_install.bat b/.ci/pytorch/windows/internal/cuda_install.bat index e968bd16813d39..a0eb650f8506a4 100644 --- a/.ci/pytorch/windows/internal/cuda_install.bat +++ b/.ci/pytorch/windows/internal/cuda_install.bat @@ -23,49 +23,23 @@ set CUDNN_LIB_FOLDER="lib\x64" :: Skip all of this if we already have cuda installed if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" goto set_cuda_env_vars -if %CUDA_VER% EQU 118 goto cuda118 -if %CUDA_VER% EQU 124 goto cuda124 if %CUDA_VER% EQU 126 goto cuda126 if %CUDA_VER% EQU 128 goto cuda128 +if %CUDA_VER% EQU 129 goto cuda129 echo CUDA %CUDA_VERSION_STR% is not supported exit /b 1 -:cuda118 - -set CUDA_INSTALL_EXE=cuda_11.8.0_522.06_windows.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" & REM @lint-ignore - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=cuda_profiler_api_11.8 thrust_11.8 nvcc_11.8 cuobjdump_11.8 nvprune_11.8 nvprof_11.8 cupti_11.8 cublas_11.8 cublas_dev_11.8 cudart_11.8 cufft_11.8 cufft_dev_11.8 curand_11.8 curand_dev_11.8 cusolver_11.8 cusolver_dev_11.8 cusparse_11.8 cusparse_dev_11.8 npp_11.8 npp_dev_11.8 nvrtc_11.8 nvrtc_dev_11.8 nvml_dev_11.8 nvtx_11.8" -) - -set CUDNN_FOLDER=cudnn-windows-x86_64-9.5.0.50_cuda11-archive -set CUDNN_LIB_FOLDER="lib" -set "CUDNN_INSTALL_ZIP=%CUDNN_FOLDER%.zip" -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" & REM @lint-ignore - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -@REM cuDNN 8.3+ required zlib to be installed on the path -echo Installing ZLIB dlls -curl -k -L "http://s3.amazonaws.com/ossci-windows/zlib123dllx64.zip" --output "%SRC_DIR%\temp_build\zlib123dllx64.zip" -7z x "%SRC_DIR%\temp_build\zlib123dllx64.zip" -o"%SRC_DIR%\temp_build\zlib" -xcopy /Y "%SRC_DIR%\temp_build\zlib\dll_x64\*.dll" "C:\Windows\System32" - goto cuda_common -:cuda124 +:cuda126 -set CUDA_INSTALL_EXE=cuda_12.4.0_551.61_windows.exe +set CUDA_INSTALL_EXE=cuda_12.6.2_560.94_windows.exe if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" & REM @lint-ignore if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=cuda_profiler_api_12.4 thrust_12.4 nvcc_12.4 cuobjdump_12.4 nvprune_12.4 nvprof_12.4 cupti_12.4 cublas_12.4 cublas_dev_12.4 cudart_12.4 cufft_12.4 cufft_dev_12.4 curand_12.4 curand_dev_12.4 cusolver_12.4 cusolver_dev_12.4 cusparse_12.4 cusparse_dev_12.4 npp_12.4 npp_dev_12.4 nvrtc_12.4 nvrtc_dev_12.4 nvml_dev_12.4 nvjitlink_12.4 nvtx_12.4" + set "ARGS=cuda_profiler_api_12.6 thrust_12.6 nvcc_12.6 cuobjdump_12.6 nvprune_12.6 nvprof_12.6 cupti_12.6 cublas_12.6 cublas_dev_12.6 cudart_12.6 cufft_12.6 cufft_dev_12.6 curand_12.6 curand_dev_12.6 cusolver_12.6 cusolver_dev_12.6 cusparse_12.6 cusparse_dev_12.6 npp_12.6 npp_dev_12.6 nvrtc_12.6 nvrtc_dev_12.6 nvml_dev_12.6 nvjitlink_12.6 nvtx_12.6" ) set CUDNN_FOLDER=cudnn-windows-x86_64-9.5.0.50_cuda12-archive @@ -85,17 +59,17 @@ xcopy /Y "%SRC_DIR%\temp_build\zlib\dll_x64\*.dll" "C:\Windows\System32" goto cuda_common -:cuda126 +:cuda128 -set CUDA_INSTALL_EXE=cuda_12.6.2_560.94_windows.exe +set CUDA_INSTALL_EXE=cuda_12.8.0_571.96_windows.exe if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" & REM @lint-ignore if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=cuda_profiler_api_12.6 thrust_12.6 nvcc_12.6 cuobjdump_12.6 nvprune_12.6 nvprof_12.6 cupti_12.6 cublas_12.6 cublas_dev_12.6 cudart_12.6 cufft_12.6 cufft_dev_12.6 curand_12.6 curand_dev_12.6 cusolver_12.6 cusolver_dev_12.6 cusparse_12.6 cusparse_dev_12.6 npp_12.6 npp_dev_12.6 nvrtc_12.6 nvrtc_dev_12.6 nvml_dev_12.6 nvjitlink_12.6 nvtx_12.6" + set "ARGS=cuda_profiler_api_12.8 thrust_12.8 nvcc_12.8 cuobjdump_12.8 nvprune_12.8 nvprof_12.8 cupti_12.8 cublas_12.8 cublas_dev_12.8 cudart_12.8 cufft_12.8 cufft_dev_12.8 curand_12.8 curand_dev_12.8 cusolver_12.8 cusolver_dev_12.8 cusparse_12.8 cusparse_dev_12.8 npp_12.8 npp_dev_12.8 nvrtc_12.8 nvrtc_dev_12.8 nvml_dev_12.8 nvjitlink_12.8 nvtx_12.8" ) -set CUDNN_FOLDER=cudnn-windows-x86_64-9.5.0.50_cuda12-archive +set CUDNN_FOLDER=cudnn-windows-x86_64-9.7.0.66_cuda12-archive set CUDNN_LIB_FOLDER="lib" set "CUDNN_INSTALL_ZIP=%CUDNN_FOLDER%.zip" if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( @@ -112,17 +86,17 @@ xcopy /Y "%SRC_DIR%\temp_build\zlib\dll_x64\*.dll" "C:\Windows\System32" goto cuda_common -:cuda128 +:cuda129 -set CUDA_INSTALL_EXE=cuda_12.8.0_571.96_windows.exe +set CUDA_INSTALL_EXE=cuda_12.9.1_576.57_windows.exe if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" & REM @lint-ignore if errorlevel 1 exit /b 1 set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=cuda_profiler_api_12.8 thrust_12.8 nvcc_12.8 cuobjdump_12.8 nvprune_12.8 nvprof_12.8 cupti_12.8 cublas_12.8 cublas_dev_12.8 cudart_12.8 cufft_12.8 cufft_dev_12.8 curand_12.8 curand_dev_12.8 cusolver_12.8 cusolver_dev_12.8 cusparse_12.8 cusparse_dev_12.8 npp_12.8 npp_dev_12.8 nvrtc_12.8 nvrtc_dev_12.8 nvml_dev_12.8 nvjitlink_12.8 nvtx_12.8" + set "ARGS=cuda_profiler_api_12.9 thrust_12.9 nvcc_12.9 cuobjdump_12.9 nvprune_12.9 nvprof_12.9 cupti_12.9 cublas_12.9 cublas_dev_12.9 cudart_12.9 cufft_12.9 cufft_dev_12.9 curand_12.9 curand_dev_12.9 cusolver_12.9 cusolver_dev_12.9 cusparse_12.9 cusparse_dev_12.9 npp_12.9 npp_dev_12.9 nvrtc_12.9 nvrtc_dev_12.9 nvml_dev_12.9 nvjitlink_12.9 nvtx_12.9" ) -set CUDNN_FOLDER=cudnn-windows-x86_64-9.7.0.66_cuda12-archive +set CUDNN_FOLDER=cudnn-windows-x86_64-9.10.2.21_cuda12-archive set CUDNN_LIB_FOLDER="lib" set "CUDNN_INSTALL_ZIP=%CUDNN_FOLDER%.zip" if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 642acdb3981b43..73622bd736edd6 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -18,3 +18,5 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" +%PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel +if errorlevel 1 exit /b 1 diff --git a/.ci/pytorch/windows/internal/smoke_test.bat b/.ci/pytorch/windows/internal/smoke_test.bat index d0fbaa7d479f3d..b7463f855428fb 100644 --- a/.ci/pytorch/windows/internal/smoke_test.bat +++ b/.ci/pytorch/windows/internal/smoke_test.bat @@ -99,7 +99,6 @@ goto end :libtorch echo "install and test libtorch" -if "%VC_YEAR%" == "2019" powershell internal\vs2019_install.ps1 if "%VC_YEAR%" == "2022" powershell internal\vs2022_install.ps1 if ERRORLEVEL 1 exit /b 1 @@ -111,10 +110,6 @@ pushd tmp\libtorch set VC_VERSION_LOWER=17 set VC_VERSION_UPPER=18 -IF "%VC_YEAR%" == "2019" ( - set VC_VERSION_LOWER=16 - set VC_VERSION_UPPER=17 -) for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( diff --git a/.ci/pytorch/windows/internal/vc_install_helper.bat b/.ci/pytorch/windows/internal/vc_install_helper.bat index b7044c4ab185b6..299134990f2987 100644 --- a/.ci/pytorch/windows/internal/vc_install_helper.bat +++ b/.ci/pytorch/windows/internal/vc_install_helper.bat @@ -1,14 +1,7 @@ -if "%VC_YEAR%" == "2019" powershell windows/internal/vs2019_install.ps1 if "%VC_YEAR%" == "2022" powershell windows/internal/vs2022_install.ps1 set VC_VERSION_LOWER=17 set VC_VERSION_UPPER=18 -:: Please don't delete VS2019 as an alternative, in case some Windows compiler issue. -:: Reference: https://github.com/pytorch/pytorch/issues/145702#issuecomment-2858693930 -if "%VC_YEAR%" == "2019" ( - set VC_VERSION_LOWER=16 - set VC_VERSION_UPPER=17 -) for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -products Microsoft.VisualStudio.Product.BuildTools -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( diff --git a/.ci/pytorch/windows/internal/vs2019_install.ps1 b/.ci/pytorch/windows/internal/vs2019_install.ps1 deleted file mode 100644 index 5574f82ebe24e3..00000000000000 --- a/.ci/pytorch/windows/internal/vs2019_install.ps1 +++ /dev/null @@ -1,48 +0,0 @@ -# https://developercommunity.visualstudio.com/t/install-specific-version-of-vs-component/1142479 -# https://docs.microsoft.com/en-us/visualstudio/releases/2019/history#release-dates-and-build-numbers - -# 16.8.6 BuildTools -$VS_DOWNLOAD_LINK = "https://ossci-windows.s3.us-east-1.amazonaws.com/vs16.8.6_BuildTools.exe" -$COLLECT_DOWNLOAD_LINK = "https://aka.ms/vscollect.exe" -$VS_INSTALL_ARGS = @("--nocache","--quiet","--wait", "--add Microsoft.VisualStudio.Workload.VCTools", - "--add Microsoft.Component.MSBuild", - "--add Microsoft.VisualStudio.Component.Roslyn.Compiler", - "--add Microsoft.VisualStudio.Component.TextTemplating", - "--add Microsoft.VisualStudio.Component.VC.CoreIde", - "--add Microsoft.VisualStudio.Component.VC.Redist.14.Latest", - "--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core", - "--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64", - "--add Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Win81") - -curl.exe --retry 3 -kL $VS_DOWNLOAD_LINK --output vs_installer.exe -if ($LASTEXITCODE -ne 0) { - echo "Download of the VS 2019 Version 16.8.5 installer failed" - exit 1 -} - -if (Test-Path "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe") { - $existingPath = & "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -products "Microsoft.VisualStudio.Product.BuildTools" -version "[16, 17)" -property installationPath - if ($existingPath -ne $null) { - if (!${env:CIRCLECI}) { - echo "Found correctly versioned existing BuildTools installation in $existingPath" - exit 0 - } - echo "Found existing BuildTools installation in $existingPath, keeping it" - } -} - -$process = Start-Process "${PWD}\vs_installer.exe" -ArgumentList $VS_INSTALL_ARGS -NoNewWindow -Wait -PassThru -Remove-Item -Path vs_installer.exe -Force -$exitCode = $process.ExitCode -if (($exitCode -ne 0) -and ($exitCode -ne 3010)) { - echo "VS 2019 installer exited with code $exitCode, which should be one of [0, 3010]." - curl.exe --retry 3 -kL $COLLECT_DOWNLOAD_LINK --output Collect.exe - if ($LASTEXITCODE -ne 0) { - echo "Download of the VS Collect tool failed." - exit 1 - } - Start-Process "${PWD}\Collect.exe" -NoNewWindow -Wait -PassThru - New-Item -Path "C:\w\build-results" -ItemType "directory" -Force - Copy-Item -Path "C:\Users\${env:USERNAME}\AppData\Local\Temp\vslogs.zip" -Destination "C:\w\build-results\" - exit 1 -} diff --git a/.ci/pytorch/windows/internal/xpu_install.bat b/.ci/pytorch/windows/internal/xpu_install.bat index ee78afb70e91a0..2296adf4dfe662 100644 --- a/.ci/pytorch/windows/internal/xpu_install.bat +++ b/.ci/pytorch/windows/internal/xpu_install.bat @@ -25,8 +25,8 @@ set XPU_EXTRA_INSTALLED=0 set XPU_EXTRA_UNINSTALL=0 if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.1] ( - set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/1a9fff3d-04c2-4d77-8861-3d86c774b66f/intel-deep-learning-essentials-2025.1.1.26_offline.exe - set XPU_BUNDLE_VERSION=2025.1.1+23 + set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/75d4eb97-914a-4a95-852c-7b9733d80f74/intel-deep-learning-essentials-2025.1.3.8_offline.exe + set XPU_BUNDLE_VERSION=2025.1.3+5 ) :: Check if XPU bundle is target version or already installed diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index b6b0d978cc2331..76206d9937ef7c 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -206,7 +206,7 @@ if [[ "$USE_SPLIT_BUILD" == "true" ]]; then BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir" echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" - BUILD_PYTHON_ONLY=1 BUILD_LIBTORCH_WHL=0 python setup.py bdist_wheel -d "$whl_tmp_dir" --cmake + BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir" echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" else python setup.py bdist_wheel -d "$whl_tmp_dir" diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 3f67d2ec1e6d7e..7f89c5c2dd8e6d 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -75,8 +75,8 @@ TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) # Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'" -# CUDA 12.8 builds have triton for Linux and Linux aarch64 binaries. -if [[ "$DESIRED_CUDA" == cu128 ]]; then +# CUDA 12.9 builds have triton for Linux and Linux aarch64 binaries. +if [[ "$DESIRED_CUDA" == "cu129" ]]; then TRITON_CONSTRAINT="platform_system == 'Linux'" fi @@ -105,6 +105,7 @@ fi # Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton xpu package if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then + TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_xpu_version.txt) TRITON_REQUIREMENT="pytorch-triton-xpu==${TRITON_VERSION}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-xpu.txt) diff --git a/.circleci/scripts/trigger_azure_pipeline.py b/.circleci/scripts/trigger_azure_pipeline.py deleted file mode 100644 index c0ac8bbd3ad0b6..00000000000000 --- a/.circleci/scripts/trigger_azure_pipeline.py +++ /dev/null @@ -1,157 +0,0 @@ -# Documentation: https://docs.microsoft.com/en-us/rest/api/azure/devops/build/?view=azure-devops-rest-6.0 - -import json -import os -import re -import sys -import time - -import requests - - -AZURE_PIPELINE_BASE_URL = "https://aiinfra.visualstudio.com/PyTorch/" -AZURE_DEVOPS_PAT_BASE64 = os.environ.get("AZURE_DEVOPS_PAT_BASE64_SECRET", "") -PIPELINE_ID = "911" -PROJECT_ID = "0628bce4-2d33-499e-bac5-530e12db160f" -TARGET_BRANCH = os.environ.get("CIRCLE_BRANCH", "main") -TARGET_COMMIT = os.environ.get("CIRCLE_SHA1", "") - -build_base_url = AZURE_PIPELINE_BASE_URL + "_apis/build/builds?api-version=6.0" - -s = requests.Session() -s.headers.update({"Authorization": "Basic " + AZURE_DEVOPS_PAT_BASE64}) - - -def submit_build(pipeline_id, project_id, source_branch, source_version): - print("Submitting build for branch: " + source_branch) - print("Commit SHA1: ", source_version) - - run_build_raw = s.post( - build_base_url, - json={ - "definition": {"id": pipeline_id}, - "project": {"id": project_id}, - "sourceBranch": source_branch, - "sourceVersion": source_version, - }, - ) - - try: - run_build_json = run_build_raw.json() - except json.decoder.JSONDecodeError as e: - print(e) - print( - "Failed to parse the response. Check if the Azure DevOps PAT is incorrect or expired." - ) - sys.exit(-1) - - build_id = run_build_json["id"] - - print("Submitted bulid: " + str(build_id)) - print("Bulid URL: " + run_build_json["url"]) - return build_id - - -def get_build(_id): - get_build_url = ( - AZURE_PIPELINE_BASE_URL + f"/_apis/build/builds/{_id}?api-version=6.0" - ) - get_build_raw = s.get(get_build_url) - return get_build_raw.json() - - -def get_build_logs(_id): - get_build_logs_url = ( - AZURE_PIPELINE_BASE_URL + f"/_apis/build/builds/{_id}/logs?api-version=6.0" - ) - get_build_logs_raw = s.get(get_build_logs_url) - return get_build_logs_raw.json() - - -def get_log_content(url): - resp = s.get(url) - return resp.text - - -def wait_for_build(_id): - build_detail = get_build(_id) - build_status = build_detail["status"] - - while build_status == "notStarted": - print("Waiting for run to start: " + str(_id)) - sys.stdout.flush() - try: - build_detail = get_build(_id) - build_status = build_detail["status"] - except Exception as e: - print("Error getting build") - print(e) - - time.sleep(30) - - print("Bulid started: ", str(_id)) - - handled_logs = set() - while build_status == "inProgress": - try: - print("Waiting for log: " + str(_id)) - logs = get_build_logs(_id) - except Exception as e: - print("Error fetching logs") - print(e) - time.sleep(30) - continue - - for log in logs["value"]: - log_id = log["id"] - if log_id in handled_logs: - continue - handled_logs.add(log_id) - print("Fetching log: \n" + log["url"]) - try: - log_content = get_log_content(log["url"]) - print(log_content) - except Exception as e: - print("Error getting log content") - print(e) - sys.stdout.flush() - build_detail = get_build(_id) - build_status = build_detail["status"] - time.sleep(30) - - build_result = build_detail["result"] - - print("Bulid status: " + build_status) - print("Bulid result: " + build_result) - - return build_status, build_result - - -if __name__ == "__main__": - # Convert the branch name for Azure DevOps - match = re.search(r"pull/(\d+)", TARGET_BRANCH) - if match is not None: - pr_num = match.group(1) - SOURCE_BRANCH = f"refs/pull/{pr_num}/head" - else: - SOURCE_BRANCH = f"refs/heads/{TARGET_BRANCH}" - - MAX_RETRY = 2 - retry = MAX_RETRY - - while retry > 0: - build_id = submit_build(PIPELINE_ID, PROJECT_ID, SOURCE_BRANCH, TARGET_COMMIT) - build_status, build_result = wait_for_build(build_id) - - if build_result != "succeeded": - retry = retry - 1 - if retry > 0: - print("Retrying... remaining attempt: " + str(retry)) - # Wait a bit before retrying - time.sleep((MAX_RETRY - retry) * 120) - continue - else: - print("No more chance to retry. Giving up.") - sys.exit(-1) - else: - break diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 458f283507fcf0..885d6393213a66 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -12,7 +12,9 @@ body: description: | Please provide a clear and concise description of what the bug is. - If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: + If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. + Your example should be fully self-contained and not rely on any artifact that should be downloaded. + For example: ```python # All necessary imports at the beginning @@ -26,6 +28,7 @@ body: If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + If your issue is related to numerical accuracy or reproducibility, please read the [numerical accuracy](https://docs.pytorch.org/docs/stable/notes/numerical_accuracy.html) and [reproducibility](https://docs.pytorch.org/docs/stable/notes/randomness.html) notes. If the difference is not expected as described in these documents, please provide appropriate justification on why one result is wrong and the other is correct. placeholder: | A clear and concise description of what the bug is. diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index a7fbd571558717..6687b58008d920 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -14,6 +14,7 @@ self-hosted-runner: - linux.12xlarge - linux.24xlarge - linux.24xlarge.ephemeral + - linux.24xlarge.amd - linux.arm64.2xlarge - linux.arm64.2xlarge.ephemeral - linux.arm64.m7g.4xlarge @@ -49,6 +50,7 @@ self-hosted-runner: # Organization-wide AMD-hosted runners # MI2xx runners - linux.rocm.gpu + - linux.rocm.gpu.mi250 - linux.rocm.gpu.2 - linux.rocm.gpu.4 # MI300 runners diff --git a/.github/actions/build-android/action.yml b/.github/actions/build-android/action.yml index 1d4d71fd9d3673..bccd42aa42f2c5 100644 --- a/.github/actions/build-android/action.yml +++ b/.github/actions/build-android/action.yml @@ -9,7 +9,7 @@ inputs: arch-for-build-env: description: | arch to pass to build environment. - This is currently different than the arch name we use elswhere, which + This is currently different than the arch name we use elsewhere, which should be fixed. required: true github-secret: diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 7da1ce3fe07144..ca6643f9e2fc1b 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -125,7 +125,7 @@ runs: TAG: ${{ steps.parse-ref.outputs.tag }} EVENT_NAME: ${{ github.event_name }} SCHEDULE: ${{ github.event.schedule }} - HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} + HEAD_BRANCH: ${{ steps.parse-ref.outputs.branch }} id: filter run: | echo "Workflow: ${GITHUB_WORKFLOW}" @@ -157,4 +157,4 @@ runs: echo "Is keep-going label set? ${{ steps.filter.outputs.keep-going }}" echo - echo "Renabled issues? ${{ steps.filter.outputs.reenabled-issues }}" + echo "Reenabled issues? ${{ steps.filter.outputs.reenabled-issues }}" diff --git a/.github/actions/linux-test/action.yml b/.github/actions/linux-test/action.yml index 81923683e79068..fb46709d9b0db8 100644 --- a/.github/actions/linux-test/action.yml +++ b/.github/actions/linux-test/action.yml @@ -153,7 +153,7 @@ runs: github-token: ${{ inputs.GITHUB_TOKEN }} - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going diff --git a/.github/actions/reuse-old-whl/action.yml b/.github/actions/reuse-old-whl/action.yml index d8e99a17652285..1976a30828edd1 100644 --- a/.github/actions/reuse-old-whl/action.yml +++ b/.github/actions/reuse-old-whl/action.yml @@ -13,6 +13,12 @@ inputs: github-token: description: GitHub token required: true + job-id: + description: Job ID + required: true + job-name: + description: Job name + required: true outputs: reuse: @@ -30,8 +36,11 @@ runs: continue-on-error: true env: GITHUB_TOKEN: ${{ inputs.github-token }} + JOB_ID: ${{ inputs.job-id }} + JOB_NAME: ${{ inputs.job-name }} run: | set -x + python3 -m pip install boto3==1.35.42 python3 ${GITHUB_ACTION_PATH}/reuse_old_whl.py \ --build-environment "${{ inputs.build-environment }}" \ --run-id "${{ inputs.run-id }}" \ diff --git a/.github/actions/reuse-old-whl/reuse_old_whl.py b/.github/actions/reuse-old-whl/reuse_old_whl.py index 45c470cca3d71d..def0276a9c8a35 100644 --- a/.github/actions/reuse-old-whl/reuse_old_whl.py +++ b/.github/actions/reuse-old-whl/reuse_old_whl.py @@ -1,13 +1,22 @@ import argparse import os import subprocess +import sys from functools import lru_cache from pathlib import Path -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Union import requests +REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent +sys.path.insert(0, str(REPO_ROOT)) +from tools.stats.upload_metrics import emit_metric + + +sys.path.remove(str(REPO_ROOT)) # Clean up sys.path after import + + FORCE_REBUILD_LABEL = "ci-force-rebuild" @@ -114,15 +123,43 @@ def ok_changed_file(file: str) -> bool: return True if file.startswith("test/") and file.endswith(".py"): return True + if file.startswith("docs/") and file.endswith((".md", ".rst")): + return True return False def check_changed_files(sha: str) -> bool: # Return true if all the changed files are in the list of allowed files to # be changed to reuse the old whl + + # Removing files in the torch folder is not allowed since rsync will not + # remove files + removed_files = ( + subprocess.check_output( + [ + "git", + "diff", + "--name-only", + sha, + "HEAD", + "--diff-filter=D", + "--no-renames", + ], + text=True, + stderr=subprocess.DEVNULL, + ) + .strip() + .split() + ) + if any(file.startswith("torch/") for file in removed_files): + print( + f"Removed files between {sha} and HEAD: {removed_files}, cannot reuse old whl" + ) + return False + changed_files = ( subprocess.check_output( - ["git", "diff", "--name-only", sha, "HEAD"], + ["git", "diff", "--name-only", sha, "HEAD", "--no-renames"], text=True, stderr=subprocess.DEVNULL, ) @@ -179,38 +216,83 @@ def unzip_artifact_and_replace_files() -> None: ) os.remove("artifacts.zip") + head_sha = get_head_sha() + # Rename wheel into zip wheel_path = Path("artifacts/dist").glob("*.whl") for path in wheel_path: - new_path = path.with_suffix(".zip") - os.rename(path, new_path) - print(f"Renamed {path} to {new_path}") - print(new_path.stem) + # Should be of the form torch-2.0.0+git1234567-cp37-etc.whl + # Should usually be the merge base sha but for the ones that didn't do + # the replacement, it won't be. Can probably change it to just be merge + # base later + old_version = f"+git{path.stem.split('+')[1].split('-')[0][3:]}" + new_version = f"+git{head_sha[:7]}" + + def rename_to_new_version(file: Union[str, Path]) -> None: + # Rename file with old_version to new_version + subprocess.check_output( + ["mv", file, str(file).replace(old_version, new_version)] + ) + + def change_content_to_new_version(file: Union[str, Path]) -> None: + # Check if is a file + if os.path.isdir(file): + return + # Replace the old version in the file with the new version + with open(file) as f: + content = f.read() + content = content.replace(old_version, new_version) + with open(file, "w") as f: + f.write(content) + + zip_path = path.with_suffix(".zip") + os.rename(path, zip_path) + old_stem = zip_path.stem # Unzip the wheel subprocess.check_output( - ["unzip", "-o", new_path, "-d", f"artifacts/dist/{new_path.stem}"], + ["unzip", "-o", zip_path, "-d", f"artifacts/dist/{old_stem}"], ) + + # Remove the old wheel (which is now a zip file) + os.remove(zip_path) + # Copy python files into the artifact subprocess.check_output( - ["rsync", "-avz", "torch", f"artifacts/dist/{new_path.stem}"], + ["rsync", "-avz", "torch", f"artifacts/dist/{old_stem}"], ) + change_content_to_new_version(f"artifacts/dist/{old_stem}/torch/version.py") + + for file in Path(f"artifacts/dist/{old_stem}").glob( + "*.dist-info/**", + ): + change_content_to_new_version(file) + + rename_to_new_version(f"artifacts/dist/{old_stem}") + new_stem = old_stem.replace(old_version, new_version) + + for file in Path(f"artifacts/dist/{new_stem}").glob( + "*.dist-info", + ): + rename_to_new_version(file) + # Zip the wheel back subprocess.check_output( - ["zip", "-r", f"{new_path.stem}.zip", "."], - cwd=f"artifacts/dist/{new_path.stem}", + ["zip", "-r", f"{new_stem}.zip", "."], + cwd=f"artifacts/dist/{new_stem}", ) + subprocess.check_output( [ "mv", - f"artifacts/dist/{new_path.stem}/{new_path.stem}.zip", - f"artifacts/dist/{new_path.stem}.whl", + f"artifacts/dist/{new_stem}/{new_stem}.zip", + f"artifacts/dist/{new_stem}.whl", ], ) # Remove the extracted folder subprocess.check_output( - ["rm", "-rf", f"artifacts/dist/{new_path.stem}"], + ["rm", "-rf", f"artifacts/dist/{new_stem}"], ) # Rezip the artifact @@ -222,8 +304,7 @@ def unzip_artifact_and_replace_files() -> None: def set_output() -> None: - # Disable for now so we can monitor first - # pass + print("Setting output reuse=true") if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print("reuse=true", file=env) @@ -244,46 +325,60 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def can_reuse_whl(args: argparse.Namespace) -> bool: - # if is_main_branch() or ( - # args.github_ref - # and any( - # args.github_ref.startswith(x) - # for x in ["refs/heads/release", "refs/tags/v", "refs/heads/main"] - # ) - # ): - # print("On main branch or release branch, rebuild whl") - # return False +def can_reuse_whl(args: argparse.Namespace) -> tuple[bool, str]: + if args.github_ref and any( + args.github_ref.startswith(x) + for x in [ + "refs/heads/release", + "refs/tags/v", + "refs/heads/nightly", + ] + ): + print("Release branch, rebuild whl") + return (False, "Release branch") + + if not check_changed_files(get_merge_base()): + print("Cannot use old whl due to the changed files, rebuild whl") + return (False, "Changed files not allowed") if check_labels_for_pr(): print(f"Found {FORCE_REBUILD_LABEL} label on PR, rebuild whl") - return False + return (False, "Found FORCE_REBUILD_LABEL on PR") if check_issue_open(): print("Issue #153759 is open, rebuild whl") - return False - - if not check_changed_files(get_merge_base()): - print("Cannot use old whl due to the changed files, rebuild whl") - return False + return (False, "Issue #153759 is open") workflow_id = get_workflow_id(args.run_id) if workflow_id is None: print("No workflow ID found, rebuild whl") - return False + return (False, "No workflow ID found") if not find_old_whl(workflow_id, args.build_environment, get_merge_base()): print("No old whl found, rebuild whl") + return (False, "No old whl found") # TODO: go backwards from merge base to find more runs - return False - return True + return (True, "Found old whl") if __name__ == "__main__": args = parse_args() - if can_reuse_whl(args): + reuse_whl, reason = can_reuse_whl(args) + + if reuse_whl: print("Reusing old whl") unzip_artifact_and_replace_files() set_output() + + emit_metric( + "reuse_old_whl", + { + "reuse_whl": reuse_whl, + "reason": reason, + "build_environment": args.build_environment, + "merge_base": get_merge_base(), + "head_sha": get_head_sha(), + }, + ) diff --git a/.github/actions/setup-linux/action.yml b/.github/actions/setup-linux/action.yml index da514c04a69f06..5af32ac0349728 100644 --- a/.github/actions/setup-linux/action.yml +++ b/.github/actions/setup-linux/action.yml @@ -33,14 +33,14 @@ runs: id: check_container_runner run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT" - - name: Start docker if docker deamon is not running + - name: Start docker if docker daemon is not running shell: bash if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} run: | if systemctl is-active --quiet docker; then echo "Docker daemon is running..."; else - echo "Starting docker deamon..." && sudo systemctl start docker; + echo "Starting docker daemon..." && sudo systemctl start docker; fi - name: Log in to ECR diff --git a/.github/actions/setup-xpu/action.yml b/.github/actions/setup-xpu/action.yml index 50411e4bdf3381..740492475d6e24 100644 --- a/.github/actions/setup-xpu/action.yml +++ b/.github/actions/setup-xpu/action.yml @@ -29,13 +29,13 @@ runs: if: always() shell: bash run: | - xpu-smi discovery + timeout 30 xpu-smi discovery || true - name: Runner health check GPU count if: always() shell: bash run: | - ngpu=$(xpu-smi discovery | grep -c -E 'Device Name') + ngpu=$(timeout 30 xpu-smi discovery | grep -c -E 'Device Name' || true) msg="Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" if [[ $ngpu -eq 0 ]]; then echo "Error: Failed to detect any GPUs on the runner" diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index e8c89c91623fff..2a6c3b66d45cfe 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -1a8f6213b0b61efc6a4862bc45b853551a93dbb6 +70caf76066ef2c1054d6128b11769dc816a779e7 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index de4e76454a9f4e..dfbc78d8884c8c 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -edc1a882d872dd7f1362e4312fd045a1d81b3355 +926700d7832caa552ba2e1fc8302f6a2f4d2f6d8 diff --git a/.github/label_to_label.yml b/.github/label_to_label.yml index 28bad93f808dba..0cd56143535fe8 100644 --- a/.github/label_to_label.yml +++ b/.github/label_to_label.yml @@ -48,3 +48,12 @@ - "module: dynamic shapes" then: - "oncall: pt2" +- any: + - "release notes: distributed (c10d)" + - "release notes: distributed (symm_mem)" + - "release notes: distributed (pipeline)" + - "release notes: distributed (fsdp)" + - "release notes: distributed (dtensor)" + - "oncall: distributed" + then: + - "ciflow/h100-distributed" diff --git a/.github/labeler.yml b/.github/labeler.yml index bbdf14bc29182d..8b1acc77c267f8 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -116,7 +116,6 @@ "release notes: inductor (aoti)": - torch/_C/_aoti.pyi - torch/_dynamo/repro/aoti.py -- torch/_export/serde/aoti_schema.py - torch/_higher_order_ops/aoti_call_delegate.py - torch/_inductor/codegen/aoti_runtime/** - torch/_inductor/codegen/aoti_hipify_utils.py diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index cc5817ed7f391d..5786c2aa1652c0 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -123,6 +123,8 @@ - torch/*docs.py approved_by: - svekars + - sekyondaMeta + - AlannaBurke mandatory_checks_name: - EasyCLA - Lint @@ -382,6 +384,7 @@ - leslie-fang-intel - jgong5 - EikanWang + - CaoE mandatory_checks_name: - EasyCLA - Lint @@ -433,6 +436,7 @@ approved_by: - leslie-fang-intel - jgong5 + - CaoE mandatory_checks_name: - EasyCLA - Lint diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 614b7c4a7a1088..ac8cb3df0ffcd5 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -4,6 +4,7 @@ ciflow_push_tags: - ciflow/binaries - ciflow/binaries_libtorch - ciflow/binaries_wheel +- ciflow/triton_binaries - ciflow/inductor - ciflow/inductor-periodic - ciflow/inductor-rocm @@ -11,6 +12,7 @@ ciflow_push_tags: - ciflow/inductor-perf-compare - ciflow/inductor-micro-benchmark - ciflow/inductor-micro-benchmark-cpu-x86 +- ciflow/inductor-perf-test-nightly-x86-zen - ciflow/inductor-cu126 - ciflow/linux-aarch64 - ciflow/mps @@ -28,6 +30,8 @@ ciflow_push_tags: - ciflow/op-benchmark - ciflow/pull - ciflow/h100 +- ciflow/h100-distributed +- ciflow/h100-symm-mem retryable_workflows: - pull - trunk diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index caabd1edf200b5..5e2819c8a8362a 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -10,5 +10,5 @@ lintrunner==0.10.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 pyyaml==6.0 -requests==2.32.2 +requests==2.32.4 rich==10.9.0 diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 index 5af413ca7b45a5..b6e9a6ce9f3e57 100644 --- a/.github/requirements/conda-env-macOS-ARM64 +++ b/.github/requirements/conda-env-macOS-ARM64 @@ -2,5 +2,4 @@ certifi pip=23.2.1 pkg-config=0.29.2 -setuptools=72.1.0 wheel=0.37.1 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index f32eb1784d524f..e8464f0a55ff53 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -1,5 +1,5 @@ boto3==1.35.42 -cmake==3.25.* +cmake==3.27.* expecttest==0.3.0 fbscribelogger==0.1.7 filelock==3.6.0 @@ -14,7 +14,7 @@ opt-einsum>=3.3 optree==0.13.0 packaging==23.1 parameterized==0.8.1 -pillow==10.0.1 +pillow==10.3.0 protobuf==5.29.4 psutil==5.9.1 pygments==2.15.0 @@ -26,7 +26,9 @@ pytest-xdist==3.3.1 pytest==7.3.2 pyyaml==6.0.2 scipy==1.12.0 +setuptools==72.1.0 sympy==1.13.3 +tlparse==0.3.30 tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 diff --git a/.github/scripts/amd/patch_triton_wheel.sh b/.github/scripts/amd/patch_triton_wheel.sh index 667fcb645587ca..36691346315467 100755 --- a/.github/scripts/amd/patch_triton_wheel.sh +++ b/.github/scripts/amd/patch_triton_wheel.sh @@ -78,7 +78,7 @@ for pkg in /$WHEELHOUSE_DIR/*triton*.whl; do echo "Copied $filepath to $patchedpath" done - # Go through all required shared objects and see if any of our other objects are dependants. If so, replace so.ver wth so + # Go through all required shared objects and see if any of our other objects are dependants. If so, replace so.ver with so for ((i=0;i<${#deps[@]};++i)); do echo "replacing "${deps_soname[i]} ${patched[i]} replace_needed_sofiles $PREFIX/$ROCM_LIB ${deps_soname[i]} ${patched[i]} diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 5caccd04152ca8..beec9f96aba216 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -21,8 +21,11 @@ def read_triton_pin(device: str = "cuda") -> str: return f.read().strip() -def read_triton_version() -> str: - with open(REPO_DIR / ".ci" / "docker" / "triton_version.txt") as f: +def read_triton_version(device: str = "cuda") -> str: + triton_version_file = "triton_version.txt" + if device == "xpu": + triton_version_file = "triton_xpu_version.txt" + with open(REPO_DIR / ".ci" / "docker" / triton_version_file) as f: return f.read().strip() @@ -65,6 +68,7 @@ def build_triton( with TemporaryDirectory() as tmpdir: triton_basedir = Path(tmpdir) / "triton" triton_pythondir = triton_basedir / "python" + triton_repo = "https://github.com/openai/triton" if device == "rocm": triton_pkg_name = "pytorch-triton-rocm" @@ -90,7 +94,7 @@ def build_triton( patch_init_py( triton_pythondir / "triton" / "__init__.py", version=f"{version}", - expected_version=None, + expected_version=read_triton_version(device), ) if device == "rocm": @@ -101,11 +105,19 @@ def build_triton( ) print("ROCm libraries setup for triton installation...") + # old triton versions have setup.py in the python/ dir, + # new versions have it in the root dir. + triton_setupdir = ( + triton_basedir + if (triton_basedir / "setup.py").exists() + else triton_pythondir + ) + check_call( - [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env + [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_setupdir, env=env ) - whl_path = next(iter((triton_pythondir / "dist").glob("*.whl"))) + whl_path = next(iter((triton_setupdir / "dist").glob("*.whl"))) shutil.copy(whl_path, Path.cwd()) if device == "rocm": @@ -128,15 +140,19 @@ def main() -> None: parser.add_argument("--py-version", type=str) parser.add_argument("--commit-hash", type=str) parser.add_argument("--with-clang-ldd", action="store_true") - parser.add_argument("--triton-version", type=str, default=read_triton_version()) + parser.add_argument("--triton-version", type=str, default=None) args = parser.parse_args() + triton_version = read_triton_version(args.device) + if args.triton_version: + triton_version = args.triton_version + build_triton( device=args.device, commit_hash=( args.commit_hash if args.commit_hash else read_triton_pin(args.device) ), - version=args.triton_version, + version=triton_version, py_version=args.py_version, release=args.release, with_clang_ldd=args.with_clang_ldd, diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index b96c3956856fb9..8032008edf1221 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -275,7 +275,7 @@ def delete_branches() -> None: delete_branch(git_repo, branch) -def delete_old_ciflow_tags() -> None: +def delete_old_tags() -> None: # Deletes ciflow tags if they are associated with a closed PR or a specific # commit. Lightweight tags don't have information about the date they were # created, so we can't check how old they are. The script just assumes that @@ -288,23 +288,29 @@ def delete_tag(tag: str) -> None: delete_branch(git_repo, f"refs/tags/{tag}") tags = git_repo._run_git("tag").splitlines() - open_pr_numbers = [x["number"] for x in get_open_prs()] + CIFLOW_TAG_REGEX = re.compile(r"^ciflow\/.*\/(\d{5,6}|[0-9a-f]{40})$") + AUTO_REVERT_TAG_REGEX = re.compile(r"^trunk\/[0-9a-f]{40}$") for tag in tags: try: if ESTIMATED_TOKENS[0] > 400: print("Estimated tokens exceeded, exiting") break - if not tag.startswith("ciflow/"): + + if not CIFLOW_TAG_REGEX.match(tag) and not AUTO_REVERT_TAG_REGEX.match(tag): continue - re_match_pr = re.match(r"^ciflow\/.*\/(\d{5,6})$", tag) - re_match_sha = re.match(r"^ciflow\/.*\/([0-9a-f]{40})$", tag) - if re_match_pr: - pr_number = int(re_match_pr.group(1)) - if pr_number in open_pr_numbers: - continue - delete_tag(tag) - elif re_match_sha: + + # This checks the date of the commit associated with the tag instead + # of the tag itself since lightweight tags don't have this + # information. I think it should be ok since this only runs once a + # day + tag_info = git_repo._run_git("show", "-s", "--format=%ct", tag) + tag_timestamp = int(tag_info.strip()) + # Maybe some timezone issues, but a few hours shouldn't matter + tag_age_days = (datetime.now().timestamp() - tag_timestamp) / SEC_IN_DAY + + if tag_age_days > 7: + print(f"[{tag}] Tag is older than 7 days, deleting") delete_tag(tag) except Exception as e: print(f"Failed to check tag {tag}: {e}") @@ -312,4 +318,4 @@ def delete_tag(tag: str) -> None: if __name__ == "__main__": delete_branches() - delete_old_ciflow_tags() + delete_old_tags() diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index dcaf08989f9ea3..9ba210a5ed2b58 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -18,6 +18,7 @@ REENABLE_TEST_REGEX = "(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)" +MAIN_BRANCH = "main" PREFIX = "test-config/" @@ -80,7 +81,7 @@ def parse_args() -> Any: parser.add_argument( "--job-name", type=str, - help="the name of the current job, i.e. linux-focal-py3.8-gcc7 / build", + help="the name of the current job, i.e. linux-jammy-py3.8-gcc7 / build", ) parser.add_argument("--pr-number", type=str, help="the pull request number") parser.add_argument("--tag", type=str, help="the associated tag if it exists") @@ -97,7 +98,7 @@ def parse_args() -> Any: parser.add_argument( "--branch", type=str, - default="main", + default=MAIN_BRANCH, help="the branch name", ) return parser.parse_args() @@ -456,6 +457,7 @@ def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> An def set_output(name: str, val: Any) -> None: + print(f"Setting output {name}={val}") if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) @@ -495,13 +497,20 @@ def check_for_setting(labels: set[str], body: str, setting: str) -> bool: def perform_misc_tasks( - labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str + labels: set[str], + test_matrix: dict[str, list[Any]], + job_name: str, + pr_body: str, + branch: Optional[str] = None, ) -> None: """ In addition to apply the filter logic, the script also does the following misc tasks to set keep-going and is-unstable variables """ - set_output("keep-going", check_for_setting(labels, pr_body, "keep-going")) + set_output( + "keep-going", + branch == MAIN_BRANCH or check_for_setting(labels, pr_body, "keep-going"), + ) set_output( "ci-verbose-test-logs", check_for_setting(labels, pr_body, "ci-verbose-test-logs"), @@ -624,6 +633,7 @@ def main() -> None: test_matrix=filtered_test_matrix, job_name=args.job_name, pr_body=pr_body if pr_body else "", + branch=args.branch, ) # Set the filtered test matrix as the output diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index e26f9053eaf498..b71440aaa6ab3a 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -15,21 +15,21 @@ from typing import Optional -# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list -CUDA_ARCHES = ["11.8", "12.6", "12.8"] -CUDA_STABLE = "12.6" +# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this +CUDA_ARCHES = ["12.6", "12.8", "12.9"] +CUDA_STABLE = "12.8" CUDA_ARCHES_FULL_VERSION = { - "11.8": "11.8.0", "12.6": "12.6.3", "12.8": "12.8.1", + "12.9": "12.9.1", } CUDA_ARCHES_CUDNN_VERSION = { - "11.8": "9", "12.6": "9", "12.8": "9", + "12.9": "9", } -# NOTE: Also update the ROCm sources in tools/nightly.py when changing this list +# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this ROCM_ARCHES = ["6.3", "6.4"] XPU_ARCHES = ["xpu"] @@ -38,35 +38,23 @@ CPU_S390X_ARCH = ["cpu-s390x"] -CUDA_AARCH64_ARCHES = ["12.8-aarch64"] +CUDA_AARCH64_ARCHES = ["12.9-aarch64"] PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { - "11.8": ( - "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 - "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'" - ), "12.6": ( "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64'" @@ -75,25 +63,43 @@ "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64'" ), + "12.9": ( + "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" + ), "xpu": ( "intel-cmplr-lib-rt==2025.1.1 | " "intel-cmplr-lib-ur==2025.1.1 | " "intel-cmplr-lic-rt==2025.1.1 | " "intel-sycl-rt==2025.1.1 | " - "oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " "impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | " "onemkl-sycl-blas==2025.1.0 | " "onemkl-sycl-dft==2025.1.0 | " @@ -107,7 +113,7 @@ "tbb==2022.1.0 | " "tcmlib==1.3.0 | " "umf==0.10.0 | " - "intel-pti==0.12.0" + "intel-pti==0.12.3" ), } @@ -311,10 +317,10 @@ def generate_wheels_matrix( continue if use_split_build and ( - arch_version not in ["12.6", "12.8", "11.8", "cpu"] or os != "linux" + arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux" ): raise RuntimeError( - "Split build is only supported on linux with cuda 12*, 11.8, and cpu.\n" + "Split build is only supported on linux with cuda 12* and cpu.\n" f"Currently attempting to build on arch version {arch_version} and os {os}.\n" "Please modify the matrix generation to exclude this combination." ) @@ -322,7 +328,7 @@ def generate_wheels_matrix( # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install if ( - arch_version in ["12.8", "12.6", "11.8"] + arch_version in ["12.9", "12.8", "12.6"] and os == "linux" or arch_version in CUDA_AARCH64_ARCHES ): @@ -411,6 +417,6 @@ def generate_wheels_matrix( return ret +validate_nccl_dep_consistency("12.9") validate_nccl_dep_consistency("12.8") validate_nccl_dep_consistency("12.6") -validate_nccl_dep_consistency("11.8") diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 751e34a10853d0..55cb02504ea452 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -152,7 +152,7 @@ class OperatingSystem: package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, - arches=["11.8", "12.6", "12.8"], + arches=["12.6", "12.8", "12.9", "6.4"], python_versions=["3.9"], ), branches="main", diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index cfbfe315bf69ca..b04cbed76e9559 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -64,7 +64,7 @@ def fetch_url( ) exception_message = ( "Is github alright?", - f"Recieved status code '{err.code}' when attempting to retrieve {url}:\n", + f"Received status code '{err.code}' when attempting to retrieve {url}:\n", f"{err.reason}\n\nheaders={err.headers}", ) raise RuntimeError(exception_message) from err @@ -136,10 +136,10 @@ def find_job_id_name(args: Any) -> tuple[str, str]: def set_output(name: str, val: Any) -> None: + print(f"Setting output {name}={val}") if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) - print(f"setting {name}={val}") else: print(f"::set-output name={name}::{val}") diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py index 43ee063bd63492..3a90ddb5f4c6b6 100644 --- a/.github/scripts/gitutils.py +++ b/.github/scripts/gitutils.py @@ -211,7 +211,7 @@ def compute_branch_diffs( self, from_branch: str, to_branch: str ) -> tuple[list[str], list[str]]: """ - Returns list of commmits that are missing in each other branch since their merge base + Returns list of commits that are missing in each other branch since their merge base Might be slow if merge base is between two branches is pretty far off """ from_ref = self.rev_parse(from_branch) diff --git a/.github/scripts/gql_mocks.json.gz b/.github/scripts/gql_mocks.json.gz index 4445e8c9041ff0..07628227a18a8c 100644 Binary files a/.github/scripts/gql_mocks.json.gz and b/.github/scripts/gql_mocks.json.gz differ diff --git a/.github/scripts/parse_ref.py b/.github/scripts/parse_ref.py index 17c7c0c82189d7..e821750a49e103 100755 --- a/.github/scripts/parse_ref.py +++ b/.github/scripts/parse_ref.py @@ -5,6 +5,7 @@ def set_output(name: str, val: str) -> None: + print(f"Setting output {name}={val}") if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) diff --git a/.github/scripts/pr-sanity-check.sh b/.github/scripts/pr-sanity-check.sh index 2b33dd91f770de..4c7ac868ab1b47 100644 --- a/.github/scripts/pr-sanity-check.sh +++ b/.github/scripts/pr-sanity-check.sh @@ -12,7 +12,7 @@ BASE=${BASE:-HEAD~1} HEAD=${HEAD:-HEAD} ancestor=$(git merge-base "${BASE}" "${HEAD}") -echo "INFO: Checking aginst the following stats" +echo "INFO: Checking against the following stats" ( set -x git diff --stat=10000 "$ancestor" "${HEAD}" | sed '$d' > "${TMPFILE}" diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py deleted file mode 100644 index b2bf474575f6f5..00000000000000 --- a/.github/scripts/tag_docker_images_for_release.py +++ /dev/null @@ -1,64 +0,0 @@ -import argparse -import subprocess - -import generate_binary_build_matrix - - -def tag_image( - image: str, - default_tag: str, - release_version: str, - dry_run: str, - tagged_images: dict[str, bool], -) -> None: - if image in tagged_images: - return - release_image = image.replace(f"-{default_tag}", f"-{release_version}") - print(f"Tagging {image} to {release_image} , dry_run: {dry_run}") - - if dry_run == "disabled": - subprocess.check_call(["docker", "pull", image]) - subprocess.check_call(["docker", "tag", image, release_image]) - subprocess.check_call(["docker", "push", release_image]) - tagged_images[image] = True - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--version", - help="Version to tag", - type=str, - default="2.2", - ) - parser.add_argument( - "--dry-run", - help="No Runtime Error check", - type=str, - choices=["enabled", "disabled"], - default="enabled", - ) - - options = parser.parse_args() - tagged_images: dict[str, bool] = {} - platform_images = [ - generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES, - generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES, - ] - default_tag = generate_binary_build_matrix.DEFAULT_TAG - - for platform_image in platform_images: # type: ignore[attr-defined] - for arch in platform_image.keys(): # type: ignore[attr-defined] - if arch == "cpu-s390x": - continue - tag_image( - platform_image[arch], # type: ignore[index] - default_tag, - options.version, - options.dry_run, - tagged_images, - ) - - -if __name__ == "__main__": - main() diff --git a/.github/scripts/test_delete_old_branches.py b/.github/scripts/test_delete_old_branches.py new file mode 100644 index 00000000000000..08ccd84f7d2954 --- /dev/null +++ b/.github/scripts/test_delete_old_branches.py @@ -0,0 +1,56 @@ +import os +import unittest +from datetime import datetime +from unittest.mock import MagicMock, patch + + +os.environ["GITHUB_TOKEN"] = "test_token" + +from delete_old_branches import delete_old_tags + + +@patch("delete_old_branches.delete_branch") +@patch("gitutils.GitRepo._run_git") +class TestDeleteTag(unittest.TestCase): + def test_delete_tag( + self, mock_run_git: "MagicMock", mock_delete_tag: "MagicMock" + ) -> None: + for tag in [ + "ciflow/branch/12345", + "ciflow/commitsha/1234567890abcdef1234567890abcdef12345678", + "trunk/1234567890abcdef1234567890abcdef12345678", + ]: + mock_run_git.side_effect = [ + tag, + str(int(datetime.now().timestamp() - 8 * 24 * 60 * 60)), # 8 days ago + ] + delete_old_tags() + mock_delete_tag.assert_called_once() + mock_delete_tag.reset_mock() + + # Don't delete if the tag is not old enough + mock_run_git.side_effect = [ + tag, + str(int(datetime.now().timestamp() - 6 * 24 * 60 * 60)), # 6 days ago + ] + delete_old_tags() + mock_delete_tag.assert_not_called() + + def test_do_not_delete_tag( + self, mock_run_git: "MagicMock", mock_delete_tag: "MagicMock" + ) -> None: + for tag in [ + "ciflow/doesntseemtomatch", + "trunk/doesntseemtomatch", + "doesntseemtomatch", + ]: + mock_run_git.side_effect = [ + tag, + str(int(datetime.now().timestamp() - 8 * 24 * 60 * 60)), # 8 days ago + ] + delete_old_tags() + mock_delete_tag.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index d7d2a80f1d6f4b..26e38828b7865e 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -347,26 +347,26 @@ def test_set_periodic_modes(self) -> None: { "job_name": "a-ci-job", "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', - "descripion": "Replicate each periodic mode in a different config", + "description": "Replicate each periodic mode in a different config", }, { "job_name": "a-ci-cuda11.8-job", "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', - "descripion": "Replicate each periodic mode in a different config for a CUDA job", + "description": "Replicate each periodic mode in a different config for a CUDA job", }, { "job_name": "a-ci-rocm-job", "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', - "descripion": "Replicate each periodic mode in a different config for a ROCm job", + "description": "Replicate each periodic mode in a different config for a ROCm job", }, { "job_name": "", "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', - "descripion": "Empty job name", + "description": "Empty job name", }, { "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', - "descripion": "Missing job name", + "description": "Missing job name", }, ] diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index 1a152dc95945e4..e4a8cb2bc8df1f 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -19,6 +19,7 @@ from github_utils import gh_graphql from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo from trymerge import ( + _revlist_to_prs, categorize_checks, DRCI_CHECKRUN_NAME, find_matching_merge_rule, @@ -264,7 +265,7 @@ def commits_resolving_gh_pr(self, pr_num: int) -> list[str]: return ["FakeCommitSha"] def commit_message(self, ref: str) -> str: - return "super awsome commit message" + return "super awesome commit message" @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql) @@ -432,7 +433,7 @@ def test_get_checkruns_many_runs(self, *args: Any) -> None: ) def test_cancelled_gets_ignored(self, *args: Any) -> None: - """Tests that cancelled workflow does not override existing successfull status""" + """Tests that cancelled workflow does not override existing successful status""" pr = GitHubPR("pytorch", "pytorch", 110367) conclusions = pr.get_checkrun_conclusions() lint_checks = [name for name in conclusions.keys() if "Lint" in name] @@ -1088,5 +1089,51 @@ def test_merge_ghstack_into( ) +@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql) +@mock.patch("trymerge.gh_fetch_merge_base", return_value="") +@mock.patch( + "trymerge.get_drci_classifications", side_effect=mocked_drci_classifications +) +@mock.patch.object(DummyGitRepo, "commit_message") +class TestRevListToPR(TestCase): + # Tests for _revlist_to_prs function + def test__revlist_to_prs_zero_matches( + self, mock_commit_message: mock.MagicMock, *args: Any + ) -> None: + # If zero PRs are mentioned in the commit message, it should raise an error + pr_num = 154098 + pr = GitHubPR("pytorch", "pytorch", pr_num) + repo = DummyGitRepo() + mock_commit_message.return_value = "no PRs" + self.assertRaisesRegex( + RuntimeError, + "PRs mentioned in commit dummy: 0.", + lambda: _revlist_to_prs(repo, pr, ["dummy"]), + ) + + def test__revlist_to_prs_two_prs( + self, mock_commit_message: mock.MagicMock, *args: Any + ) -> None: + # If two PRs are mentioned in the commit message, it should raise an error + pr_num = 154394 + pr = GitHubPR("pytorch", "pytorch", pr_num) + repo = DummyGitRepo() + # https://github.com/pytorch/pytorch/commit/343c56e7650f55fd030aca0b9275d6d73501d3f4 + + commit_message = """add sticky cache pgo + +ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81 +Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/154098 + +ghstack-source-id: 9bc6dee0b427819f978bfabccb72727ba8be2f81 +Pull Request resolved: https://github.com/pytorch/pytorch/pull/154394""" + mock_commit_message.return_value = commit_message + self.assertRaisesRegex( + RuntimeError, + "PRs mentioned in commit dummy: 2.", + lambda: _revlist_to_prs(repo, pr, ["dummy"]), + ) + + if __name__ == "__main__": main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 54285e7f5da82d..1b6b96cff7dd01 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -628,11 +628,17 @@ def _revlist_to_prs( rc: list[tuple[GitHubPR, str]] = [] for idx, rev in enumerate(rev_list): msg = repo.commit_message(rev) - m = RE_PULL_REQUEST_RESOLVED.search(msg) - if m is None: + # findall doesn't return named captures, so we need to use finditer + all_matches = list(RE_PULL_REQUEST_RESOLVED.finditer(msg)) + if len(all_matches) != 1: raise RuntimeError( - f"Could not find PR-resolved string in {msg} of ghstacked PR {pr.pr_num}" + f"Found an unexpected number of PRs mentioned in commit {rev}: " + f"{len(all_matches)}. This is probably because you are using an " + "old version of ghstack. Please update ghstack and resubmit " + "your PRs" ) + + m = all_matches[0] if pr.org != m.group("owner") or pr.project != m.group("repo"): raise RuntimeError( f"PR {m.group('number')} resolved to wrong owner/repo pair" @@ -666,6 +672,9 @@ def skip_func(idx: int, candidate: "GitHubPR") -> bool: assert pr.is_ghstack_pr() entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func) + print( + f"Found {len(entire_stack)} PRs in the stack for {pr.pr_num}: {[x[0].pr_num for x in entire_stack]}" + ) for stacked_pr, rev in entire_stack: if stacked_pr.is_closed(): diff --git a/.github/scripts/windows/build_magma.bat b/.github/scripts/windows/build_magma.bat index b8701ddde3fcc9..0f11fe34068eb2 100644 --- a/.github/scripts/windows/build_magma.bat +++ b/.github/scripts/windows/build_magma.bat @@ -35,15 +35,15 @@ cd magma mkdir build && cd build set GPU_TARGET=All +if "%CUVER_NODOT%" == "129" ( + set CUDA_ARCH_LIST=-gencode=arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 +) if "%CUVER_NODOT%" == "128" ( set CUDA_ARCH_LIST=-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 ) -if "%CUVER_NODOT:~0,2%" == "12" if NOT "%CUVER_NODOT%" == "128" ( +if "%CUVER_NODOT%" == "126" ( set CUDA_ARCH_LIST=-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 ) -if "%CUVER_NODOT%" == "118" ( - set CUDA_ARCH_LIST= -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -) set CC=cl.exe set CXX=cl.exe diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index f95cc63f8d536c..b14a13f3f90c25 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -114,12 +114,12 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" {%- elif config["gpu_arch_type"] == "rocm" %} runs_on: linux.rocm.gpu - {%- elif config["gpu_arch_type"] == "cuda" and config["gpu_arch_version"] == "12.8" %} + {%- elif config["gpu_arch_type"] == "cuda" and config["gpu_arch_version"] in ["12.8", "12.9"] %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner - {%- elif config["gpu_arch_type"] == "cuda" and config["gpu_arch_version"] != "12.8"%} + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + {%- elif config["gpu_arch_type"] == "cuda" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner {%- else %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge @@ -171,7 +171,7 @@ jobs: - name: Teardown XPU uses: ./.github/actions/teardown-xpu {%- else %} - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config) }} steps: diff --git a/.github/workflows/_link_check.yml b/.github/workflows/_link_check.yml index 7219a868580721..efe92ca627bbaa 100644 --- a/.github/workflows/_link_check.yml +++ b/.github/workflows/_link_check.yml @@ -15,7 +15,7 @@ jobs: with: timeout: 120 runner: ${{ inputs.runner }}linux.2xlarge - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: 0 submodules: false ref: ${{ inputs.ref }} @@ -40,7 +40,7 @@ jobs: with: timeout: 60 runner: ${{ inputs.runner }}linux.2xlarge - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: 0 submodules: false ref: ${{ inputs.ref }} diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index e02db8ca43e9ae..8cb4fbb9404fcc 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -69,11 +69,6 @@ on: required: false type: string default: "" - max-jobs: - description: | - Overwrite the number of jobs to use for the build - required: false - type: string disable-monitor: description: | Disable utilization monitoring for build job @@ -99,7 +94,7 @@ on: commit with no cpp changes from this commit required: false type: boolean - default: false + default: true secrets: HUGGING_FACE_HUB_TOKEN: @@ -158,14 +153,23 @@ jobs: role-session-name: gha-linux-build aws-region: us-east-1 + - name: Get workflow job id + id: get-job-id + uses: ./.github/actions/get-workflow-job-id + if: always() + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Check if can use old whl build id: use-old-whl uses: ./.github/actions/reuse-old-whl - if: ${{ inputs.allow-reuse-old-whl && github.event_name == 'push' }} + if: ${{ inputs.allow-reuse-old-whl }} with: build-environment: ${{ inputs.build-environment }} run-id: ${{ github.run_id }} github-token: ${{ secrets.GITHUB_TOKEN }} + job-id: ${{ steps.get-job-id.outputs.job-id }} + job-name: ${{ steps.get-job-id.outputs.job-name }} - name: Calculate docker image id: calculate-docker-image @@ -181,7 +185,7 @@ jobs: ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} shell: bash run: | - tag=${ECR_DOCKER_IMAGE##*/} + tag=${ECR_DOCKER_IMAGE##*:} echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image @@ -194,13 +198,6 @@ jobs: id: parse-ref run: .github/scripts/parse_ref.py - - name: Get workflow job id - id: get-job-id - uses: ./.github/actions/get-workflow-job-id - if: always() - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - # Apply the filter logic to the build step too if the test-config label is already there - name: Select all requested test configurations (if the test matrix is available) id: filter @@ -264,7 +261,6 @@ jobs: OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} - MAX_JOBS_OVERRIDE: ${{ inputs.max-jobs }} run: | START_TIME=$(date +%s) if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then @@ -284,12 +280,6 @@ jobs: DOCKER_SHELL_CMD= fi - if [[ ${MAX_JOBS_OVERRIDE} == "" ]]; then - MAX_JOBS="$(nproc --ignore=2)" - else - MAX_JOBS="${MAX_JOBS_OVERRIDE}" - fi - # Leaving 1GB for the runner and other things TOTAL_AVAILABLE_MEMORY_IN_GB=$(awk '/MemTotal/ { printf "%.3f \n", $2/1024/1024 - 1 }' /proc/meminfo) # https://docs.docker.com/engine/containers/resource_constraints/#--memory-swap-details, the 3GB swap @@ -301,8 +291,7 @@ jobs: # shellcheck disable=SC2086 container_name=$(docker run \ -e BUILD_ENVIRONMENT \ - -e MAX_JOBS=${MAX_JOBS} \ - -e MAX_JOBS_OVERRIDE \ + -e MAX_JOBS="$(nproc --ignore=2)" \ -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ -e SHA1 \ diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 0e4459907f1ddb..469367d4d6841f 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -90,10 +90,13 @@ jobs: environment: ${{ github.ref == 'refs/heads/main' && 'scribe-protected' || startsWith(github.ref, 'refs/heads/release/') && 'scribe-protected' || contains(github.event.pull_request.labels.*.name, 'ci-scribe') && 'scribe-pr' || '' }} runs-on: ${{ matrix.runner }} timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} + permissions: + id-token: write + contents: read steps: - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@main - if: ${{ !contains(matrix.runner, 'gcp.a100') && inputs.build-environment != 'linux-s390x-binary-manywheel' }} + if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -105,18 +108,31 @@ jobs: with: no-sudo: true + - name: Setup Python + if: matrix.runner == 'B200' + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.12' + cache: pip + - name: Setup Linux uses: ./.github/actions/setup-linux - if: inputs.build-environment != 'linux-s390x-binary-manywheel' + if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200' - name: configure aws credentials - if : ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} + if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 with: role-to-assume: ${{ inputs.aws-role-to-assume }} role-session-name: gha-linux-test aws-region: us-east-1 + - name: Login to Amazon ECR + if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }} + id: login-ecr + continue-on-error: true + uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 + - name: Calculate docker image id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main @@ -131,7 +147,7 @@ jobs: ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} shell: bash run: | - tag=${ECR_DOCKER_IMAGE##*/} + tag=${ECR_DOCKER_IMAGE##*:} echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image @@ -148,17 +164,17 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver uses: pytorch/test-infra/.github/actions/setup-nvidia@main - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }} - name: Setup GPU_FLAG for docker run id: setup-gpu-flag run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }} - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container id: setup-sscache-port-flag run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}" - if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} + if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }} - name: Lock NVIDIA A100 40GB Frequency run: | @@ -207,7 +223,7 @@ jobs: run: .github/scripts/parse_ref.py - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going @@ -225,6 +241,12 @@ jobs: run: | echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}" + - name: Preserve github env variables for use in docker + shell: bash + run: | + env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" + env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Test id: test timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }} @@ -253,8 +275,8 @@ jobs: NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs - SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 - SCCACHE_REGION: us-east-1 + SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }} + SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }} SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} @@ -264,7 +286,6 @@ jobs: DASHBOARD_TAG: ${{ inputs.dashboard-tag }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} - IS_A100_RUNNER: ${{ contains(matrix.runner, 'a100') && '1' || '0' }} ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} run: | set -x @@ -290,10 +311,6 @@ jobs: # if for some reason cleanup action doesn't stop container # when job is cancelled DOCKER_SHELL_CMD="sleep 12h" - - # since some steps are skipped on s390x, if they are necessary, run them here - env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" - env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" else SHM_OPTS="--shm-size=${SHM_SIZE}" JENKINS_USER="--user jenkins" @@ -345,7 +362,6 @@ jobs: -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e DASHBOARD_TAG \ - -e IS_A100_RUNNER \ -e ARTIFACTS_FILE_SUFFIX \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ @@ -384,8 +400,18 @@ jobs: test_config: ${{ matrix.config }} job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} + - name: Authenticate with AWS + if: ${{ matrix.runner == 'B200' }} + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results + # The max duration enforced by the server side + role-duration-seconds: 18000 + aws-region: us-east-1 + - name: Upload the benchmark results uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: benchmark-results-dir: test/test-reports dry-run: false diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 118022785e039c..78d675d76af82b 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -81,13 +81,11 @@ jobs: echo "DEVELOPER_DIR=/Applications/Xcode_${XCODE_VERSION}.app/Contents/Developer" >> "${GITHUB_ENV}" fi - - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main + - name: Setup Python + uses: pytorch/test-infra/.github/actions/setup-python@main with: python-version: ${{ inputs.python-version }} - environment-file: .github/requirements/conda-env-macOS-ARM64 pip-requirements-file: .github/requirements/pip-requirements-macOS.txt - default-packages: "" - name: Install sccache (only for non-forked PRs, and pushes to trunk) uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0 @@ -125,7 +123,7 @@ jobs: else # The runner has access to the S3 bucket via IAM profile without the need # for any credential - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}"0 + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" fi diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 265c3a629f45a0..d7ef208da50bb4 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -60,8 +60,6 @@ jobs: test: # Don't run on forked repos or empty test matrix if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' - # For setup-miniconda, see https://github.com/conda-incubator/setup-miniconda/issues/179 - # Also ensure that we always run with the right architecture defaults: run: shell: bash -e -l {0} @@ -90,6 +88,10 @@ jobs: pkill "${PROCESS}" || true done + - name: Clean up leftover miniconda installation + continue-on-error: true + run: brew uninstall miniconda || true + - name: Clean up leftover local python3 site-packages on MacOS pet runner continue-on-error: true run: | @@ -124,8 +126,8 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - ${CONDA_RUN} python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 - ${CONDA_RUN} python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" - name: Download build artifacts @@ -140,11 +142,10 @@ jobs: with: use-gha: true - - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main + - name: Setup Python + uses: pytorch/test-infra/.github/actions/setup-python@main with: python-version: ${{ inputs.python-version }} - environment-file: .github/requirements/conda-env-macOS-ARM64 pip-requirements-file: .github/requirements/pip-requirements-macOS.txt default-packages: "" @@ -153,7 +154,7 @@ jobs: run: .github/scripts/parse_ref.py - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going @@ -197,37 +198,32 @@ jobs: # shellcheck disable=SC1090 set -ex - arch - - if [[ -n "$CONDA_ENV" ]]; then - # Use binaries under conda environment - export PATH="$CONDA_ENV/bin":$PATH - fi + # TODO: Remove me later, and properly activate venv + PATH="$(dirname "$(which python)"):$PATH" + export PATH # Print out some information about the test environment - which conda - conda --version - ${CONDA_RUN} which python3 - ${CONDA_RUN} python3 --version - ${CONDA_RUN} which python - ${CONDA_RUN} python --version + for tool in python3 python; do + which $tool + $tool --version + done - ${CONDA_RUN} python3 -mpip install --no-index --no-deps dist/*.whl + python3 -mpip install --no-index --no-deps dist/*.whl set +e pushd "${RUNNER_TEMP}" # Install pip dependencies if they are not found. This is to mitigate a peculiar # flaky missing dependencies on MacOS - ${CONDA_RUN} python3 -c "import torch" + python3 -c "import torch" RC=$? popd if [ "${RC}" -ne 0 ]; then - ${CONDA_RUN} python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}" + python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}" fi set -e - ${CONDA_RUN} .ci/pytorch/macos-test.sh + .ci/pytorch/macos-test.sh - name: Print remaining test logs shell: bash @@ -239,11 +235,7 @@ jobs: shell: bash if: ${{ contains(steps.get-job-id.outputs.job-name, 'mps') }} run: | - if [[ -n "$CONDA_ENV" ]]; then - # Use binaries under conda environment - export PATH="$CONDA_ENV/bin":$PATH - fi - ${CONDA_RUN} python3 test/bench_mps_ops.py + python3 test/bench_mps_ops.py - name: Stop monitoring script diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index be83312455e21c..006ab43da29d6b 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -150,7 +150,7 @@ jobs: run: .github/scripts/parse_ref.py - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index f3ae11ff27c51d..0d674f044ec42c 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -7,7 +7,7 @@ on: required: false type: string description: | - List of experiments for this workfow. If not defined, all default experiments are included. + List of experiments for this workflow. If not defined, all default experiments are included. opt_out_experiments: required: false type: string diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 799b833047bc64..ebfb4001e4379d 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -23,7 +23,7 @@ on: vc-year: required: false type: string - default: "2019" + default: "2022" description: The Visual Studio year to use for building. build-with-debug: required: false @@ -98,7 +98,7 @@ jobs: To start build locally, change working folder to \actions-runner\_work\pytorch\pytorch, Activate miniconda and Visual Studio environment, by running: call C:\Jenkins\Miniconda3\Scripts\activate.bat C:\Jenkins\Miniconda3 - call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 + call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 # [see note: pytorch repo ref] - name: Checkout PyTorch diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index f9071d3bb7134d..36b4e5cd753f68 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -91,7 +91,7 @@ jobs: To start tests locally, change working folder to \actions-runner\_work\pytorch\pytorch\test, Activate miniconda and Visual Studio environment and set PYTHON_PATH, by running: call C:\Jenkins\Miniconda3\Scripts\activate.bat C:\Jenkins\Miniconda3 - call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 + call "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 set PYTHONPATH=C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build # [see note: pytorch repo ref] @@ -158,7 +158,7 @@ jobs: uses: ./.github/actions/download-td-artifacts - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going @@ -191,8 +191,8 @@ jobs: NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} VC_PRODUCT: "BuildTools" VC_VERSION: "" - VS_VERSION: "16.8.6" - VC_YEAR: "2019" + VS_VERSION: "17.4.1" + VC_YEAR: "2022" AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index a05519b543f4f5..7f78280aada6f8 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -105,7 +105,7 @@ jobs: ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} shell: bash run: | - tag=${ECR_DOCKER_IMAGE##*/} + tag=${ECR_DOCKER_IMAGE##*:} echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image @@ -147,7 +147,7 @@ jobs: run: .github/scripts/parse_ref.py - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conviniently + # This uses the filter-test-configs action because it conveniently # checks for labels and re-enabled test issues. It does not actually do # any filtering. All filtering is done in the build step. id: keep-going diff --git a/.github/workflows/build-almalinux-images.yml b/.github/workflows/build-almalinux-images.yml index 0a4a619808c4f5..aaf85d7fc8067e 100644 --- a/.github/workflows/build-almalinux-images.yml +++ b/.github/workflows/build-almalinux-images.yml @@ -23,7 +23,7 @@ on: env: DOCKER_REGISTRY: "docker.io" DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release')) }} + WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -32,11 +32,11 @@ concurrency: jobs: build-docker: if: github.repository_owner == 'pytorch' - environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} runs-on: linux.9xlarge.ephemeral strategy: matrix: - tag: ["cuda11.8", "cuda12.6", "cuda12.8", "rocm6.3", "rocm6.4", "cpu"] + tag: ["cuda12.6", "cuda12.8", "cuda12.9", "rocm6.3", "rocm6.4", "cpu"] steps: - name: Build docker image uses: pytorch/pytorch/.github/actions/binary-docker-build@main diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index 90a879022890ab..b2d50efd7d96ce 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -22,7 +22,7 @@ on: env: DOCKER_REGISTRY: "docker.io" DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release')) }} + WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -40,7 +40,7 @@ jobs: curr_ref_type: ${{ github.ref_type }} build: - environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} needs: get-label-type runs-on: ${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral name: libtorch-cxx11-builder:${{ matrix.tag }} @@ -48,10 +48,9 @@ jobs: fail-fast: false matrix: include: [ + { tag: "cuda12.9" }, { tag: "cuda12.8" }, { tag: "cuda12.6" }, - { tag: "cuda12.4" }, - { tag: "cuda11.8" }, { tag: "rocm6.3" }, { tag: "rocm6.4" }, { tag: "cpu" }, diff --git a/.github/workflows/build-magma-linux.yml b/.github/workflows/build-magma-linux.yml index ee30f7fb3b1652..e13de48b2408a3 100644 --- a/.github/workflows/build-magma-linux.yml +++ b/.github/workflows/build-magma-linux.yml @@ -34,7 +34,7 @@ jobs: id-token: write strategy: matrix: - cuda_version: ["128", "126", "118"] + cuda_version: ["129", "128", "126"] steps: - name: Checkout PyTorch uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build-magma-rocm-linux.yml b/.github/workflows/build-magma-rocm-linux.yml index ef34ca12fb46fe..b6eb09188fd486 100644 --- a/.github/workflows/build-magma-rocm-linux.yml +++ b/.github/workflows/build-magma-rocm-linux.yml @@ -29,7 +29,7 @@ concurrency: jobs: build-linux-magma-rocm: if: github.repository_owner == 'pytorch' - runs-on: linux.12xlarge + runs-on: linux.2xlarge permissions: id-token: write strategy: diff --git a/.github/workflows/build-magma-windows.yml b/.github/workflows/build-magma-windows.yml index 1fdaf0ff98afc4..80d870f419e427 100644 --- a/.github/workflows/build-magma-windows.yml +++ b/.github/workflows/build-magma-windows.yml @@ -19,14 +19,15 @@ concurrency: jobs: build-windows-magma: if: github.repository_owner == 'pytorch' - runs-on: windows-2019 + runs-on: windows-2022 strategy: matrix: - cuda_version: ["128", "126", "118"] + cuda_version: ["129", "128", "126"] config: ["Release", "Debug"] env: CUDA_VERSION: ${{ matrix.cuda_version }} CONFIG: ${{ matrix.config }} + VC_YEAR: "2022" steps: - name: Checkout pytorch/pytorch uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build-manywheel-images-s390x.yml b/.github/workflows/build-manywheel-images-s390x.yml index ac44f41e803267..c498e169f1aa58 100644 --- a/.github/workflows/build-manywheel-images-s390x.yml +++ b/.github/workflows/build-manywheel-images-s390x.yml @@ -3,26 +3,16 @@ name: Build manywheel docker images for s390x on: workflow_dispatch: push: - branches: - - main - - release/* tags: - # NOTE: Binary build pipelines should only get triggered on release candidate or nightly builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - ciflow/s390/* paths: - - .ci/docker/** - - .github/workflows/build-manywheel-images-s390x.yml - pull_request: - paths: - - .ci/docker/** - .github/workflows/build-manywheel-images-s390x.yml env: DOCKER_REGISTRY: "docker.io" DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release')) }} + WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -31,7 +21,7 @@ concurrency: jobs: build-docker-cpu-s390x: if: github.repository_owner == 'pytorch' - environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} runs-on: linux.s390x steps: - name: Checkout PyTorch @@ -63,7 +53,7 @@ jobs: docker tag "${CREATED_FULL_DOCKER_IMAGE_NAME}" "${DOCKER_IMAGE_NAME_PREFIX}-${GIT_COMMIT_SHA}" docker tag "${CREATED_FULL_DOCKER_IMAGE_NAME}" "${DOCKER_IMAGE_NAME_PREFIX}-${CI_FOLDER_SHA}" - # Prety sure Github will mask tokens and I'm not sure if it will even be + # Pretty sure Github will mask tokens and I'm not sure if it will even be # printed due to pipe, but just in case set +x if [[ "${WITH_PUSH:-false}" == "true" ]]; then diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index af3af643275810..e84b84f6158ba7 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -23,8 +23,7 @@ on: env: DOCKER_REGISTRY: "docker.io" DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release')) }} - + WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true @@ -41,16 +40,16 @@ jobs: curr_ref_type: ${{ github.ref_type }} build: - environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} needs: get-label-type strategy: fail-fast: false matrix: include: [ + { name: "manylinux2_28-builder", tag: "cuda12.9", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.8", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.6", runner: "linux.9xlarge.ephemeral" }, - { name: "manylinux2_28-builder", tag: "cuda12.4", runner: "linux.9xlarge.ephemeral" }, - { name: "manylinux2_28-builder", tag: "cuda11.8", runner: "linux.9xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm6.3", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" }, diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 63d2ae7aa10993..83bfda55c5cdc1 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -8,6 +8,7 @@ on: # NOTE: Binary build pipelines should only get triggered on release candidate builds # Release candidate tags look like: v1.11.0-rc1 - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - 'ciflow/triton_binaries/*' paths: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py @@ -139,6 +140,15 @@ jobs: docker exec -t "${container_name}" yum install -y zlib-devel zip docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel + set +e + docker exec -t "${container_name}" command -v pip + has_pip=$? + set -e + if [ $has_pip -eq 0 ] ; then + docker exec -t "${container_name}" pip install -U cmake --force-reinstall + else + docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U cmake --force-reinstall + fi if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "rocm" || "${{ matrix.device }}" == "aarch64" ) ]]; then # With this install, it gets clang 16.0.6. diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index 027fa784d9e51a..db8fbcb4bdc7d6 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -55,6 +55,8 @@ jobs: tag_or_branch="${tag_or_branch//\//_}" echo "PT_RELEASE_NAME=pytorch-$tag_or_branch" >> "$GITHUB_ENV" echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz" >> "$GITHUB_ENV" + - name: Checkout optional submodules + run: python3 tools/optional_submodules.py - name: Create source distribution run: | # Create new folder with specified name so extracting the archive yields that @@ -80,7 +82,7 @@ jobs: path: ${{ env.PT_RELEASE_FILE }} - name: Set output id: release_name - run: echo "name=pt_release_name::${{ env.PT_RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}" + run: echo "pt_release_name=${{ env.PT_RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}" upload_source_code_to_s3: if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index a648ea90350d07..caf73275332b32 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -49,29 +49,30 @@ jobs: matrix: runner: [linux.12xlarge] docker-image-name: [ - pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, - pytorch-linux-focal-py3.9-clang10, - pytorch-linux-focal-py3.11-clang10, - pytorch-linux-focal-py3.12-clang10, - pytorch-linux-focal-py3.13-clang10, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-py3.9-clang12, + pytorch-linux-jammy-py3.11-clang12, + pytorch-linux-jammy-py3.12-clang12, + pytorch-linux-jammy-py3.13-clang12, pytorch-linux-jammy-rocm-n-1-py3, pytorch-linux-jammy-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, pytorch-linux-jammy-py3.9-gcc11, pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-xpu-2025.0-py3, pytorch-linux-jammy-xpu-2025.1-py3, - pytorch-linux-jammy-py3-clang15-asan, pytorch-linux-jammy-py3-clang18-asan, - pytorch-linux-focal-py3-clang10-onnx, - pytorch-linux-focal-linter, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter, + pytorch-linux-jammy-py3-clang12-onnx, + pytorch-linux-jammy-linter, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter, pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu ] diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 3555cc3cffa974..25dca2b73a0268 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -156,7 +156,7 @@ jobs: docker push ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}${CUDA_SUFFIX}" - # Please note, here we ned to pin specific verison of CUDA as with latest label + # Please note, here we need to pin specific version of CUDA as with latest label if [[ ${CUDA_VERSION_SHORT} == "${STABLE_CUDA_VERSION}" ]]; then docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}${CUDA_SUFFIX}" \ ghcr.io/pytorch/pytorch-nightly:latest diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 349fbecf18b245..be0c28892eb2ff 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -115,7 +115,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda-aarch64-12_8-build: + manywheel-py3_9-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -124,41 +124,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_9-cuda-aarch64-12_8 + build_name: manywheel-py3_9-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_9-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_9-cuda-aarch64-12_8-build + needs: manywheel-py3_9-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda-aarch64-12_8 + build_name: manywheel-py3_9-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -231,7 +231,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda-aarch64-12_8-build: + manywheel-py3_10-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -240,41 +240,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_10-cuda-aarch64-12_8 + build_name: manywheel-py3_10-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_10-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda-aarch64-12_8-build + needs: manywheel-py3_10-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-12_8 + build_name: manywheel-py3_10-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -347,7 +347,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda-aarch64-12_8-build: + manywheel-py3_11-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -356,41 +356,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_11-cuda-aarch64-12_8 + build_name: manywheel-py3_11-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_11-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda-aarch64-12_8-build + needs: manywheel-py3_11-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-12_8 + build_name: manywheel-py3_11-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -463,7 +463,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda-aarch64-12_8-build: + manywheel-py3_12-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -472,41 +472,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_12-cuda-aarch64-12_8 + build_name: manywheel-py3_12-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_12-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda-aarch64-12_8-build + needs: manywheel-py3_12-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-12_8 + build_name: manywheel-py3_12-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -579,7 +579,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda-aarch64-12_8-build: + manywheel-py3_13-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -588,41 +588,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13-cuda-aarch64-12_8 + build_name: manywheel-py3_13-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_13-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda-aarch64-12_8-build + needs: manywheel-py3_13-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-12_8 + build_name: manywheel-py3_13-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -695,7 +695,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda-aarch64-12_8-build: + manywheel-py3_13t-cuda-aarch64-12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -704,41 +704,41 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_name: manywheel-py3_13t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_13t-cuda-aarch64-12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda-aarch64-12_8-build + needs: manywheel-py3_13t-cuda-aarch64-12_9-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8-aarch64 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_name: manywheel-py3_13t-cuda-aarch64-12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-binary-libtorch-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-nightly.yml index 230901f8d46442..9f4a8194d2874c 100644 --- a/.github/workflows/generated-linux-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-nightly.yml @@ -112,7 +112,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda11_8-shared-with-deps-release-build: + libtorch-cuda12_6-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -121,22 +121,22 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: libtorch-cuda11_8-shared-with-deps-release + build_name: libtorch-cuda12_6-shared-with-deps-release build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda11_8-shared-with-deps-release-test: # Testing + libtorch-cuda12_6-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda11_8-shared-with-deps-release-build + - libtorch-cuda12_6-shared-with-deps-release-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -144,43 +144,43 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda11_8-shared-with-deps-release + build_name: libtorch-cuda12_6-shared-with-deps-release build_environment: linux-binary-libtorch runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda11_8-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_6-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda11_8-shared-with-deps-release-test + needs: libtorch-cuda12_6-shared-with-deps-release-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda11_8-shared-with-deps-release + build_name: libtorch-cuda12_6-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_6-shared-with-deps-release-build: + libtorch-cuda12_8-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -189,22 +189,22 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: libtorch-cuda12_6-shared-with-deps-release + build_name: libtorch-cuda12_8-shared-with-deps-release build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda12_6-shared-with-deps-release-test: # Testing + libtorch-cuda12_8-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_6-shared-with-deps-release-build + - libtorch-cuda12_8-shared-with-deps-release-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -212,43 +212,43 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_6-shared-with-deps-release + build_name: libtorch-cuda12_8-shared-with-deps-release build_environment: linux-binary-libtorch runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda12_6-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_6-shared-with-deps-release-test + needs: libtorch-cuda12_8-shared-with-deps-release-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_6-shared-with-deps-release + build_name: libtorch-cuda12_8-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-release-build: + libtorch-cuda12_9-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -257,22 +257,22 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: libtorch-cuda12_8-shared-with-deps-release + build_name: libtorch-cuda12_9-shared-with-deps-release build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda12_8-shared-with-deps-release-test: # Testing + libtorch-cuda12_9-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_8-shared-with-deps-release-build + - libtorch-cuda12_9-shared-with-deps-release-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -280,38 +280,38 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_8-shared-with-deps-release + build_name: libtorch-cuda12_9-shared-with-deps-release build_environment: linux-binary-libtorch runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_8-shared-with-deps-release-test + needs: libtorch-cuda12_9-shared-with-deps-release-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_8-shared-with-deps-release + build_name: libtorch-cuda12_9-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -342,7 +342,7 @@ jobs: needs: - libtorch-rocm6_3-shared-with-deps-release-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -456,7 +456,7 @@ jobs: needs: - libtorch-rocm6_4-shared-with-deps-release-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 151bcb852554e4..064156936defd7 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -42,7 +42,7 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - manywheel-py3_9-cuda11_8-build: + manywheel-py3_9-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -51,23 +51,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8 + build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-test: # Testing + manywheel-py3_9-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda11_8-build + - manywheel-py3_9-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -75,21 +75,21 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 + build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-build: + manywheel-py3_9-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -98,23 +98,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_6 + build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-test: # Testing + manywheel-py3_9-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_6-build + - manywheel-py3_9-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -122,21 +122,21 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_6 + build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-build: + manywheel-py3_9-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -145,23 +145,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-test: # Testing + manywheel-py3_9-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_8-build + - manywheel-py3_9-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -169,16 +169,108 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_4 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 4322ac1ccd3ccf..5a530a39d6ca3d 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -112,7 +112,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda11_8-build: + manywheel-py3_9-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -121,23 +121,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8 + build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-test: # Testing + manywheel-py3_9-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda11_8-build + - manywheel-py3_9-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -145,43 +145,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 + build_name: manywheel-py3_9-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-upload: # Uploading + manywheel-py3_9-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_9-cuda11_8-test + needs: manywheel-py3_9-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 + build_name: manywheel-py3_9-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_6-build: + manywheel-py3_9-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -190,23 +190,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_6 + build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-test: # Testing + manywheel-py3_9-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_6-build + - manywheel-py3_9-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -214,43 +214,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_6 + build_name: manywheel-py3_9-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_6-upload: # Uploading + manywheel-py3_9-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_9-cuda12_6-test + needs: manywheel-py3_9-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_6 + build_name: manywheel-py3_9-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_8-build: + manywheel-py3_9-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -259,23 +259,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-test: # Testing + manywheel-py3_9-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_9-cuda12_8-build + - manywheel-py3_9-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -283,38 +283,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_9-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_8-upload: # Uploading + manywheel-py3_9-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_9-cuda12_8-test + needs: manywheel-py3_9-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_8 + build_name: manywheel-py3_9-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -345,7 +345,7 @@ jobs: needs: - manywheel-py3_9-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -459,7 +459,7 @@ jobs: needs: - manywheel-py3_9-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -565,7 +565,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-xpu-test: # Testing @@ -725,7 +725,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda11_8-build: + manywheel-py3_10-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -734,23 +734,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda11_8 + build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-test: # Testing + manywheel-py3_10-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda11_8-build + - manywheel-py3_10-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -758,43 +758,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8 + build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-upload: # Uploading + manywheel-py3_10-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda11_8-test + needs: manywheel-py3_10-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8 + build_name: manywheel-py3_10-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_6-build: + manywheel-py3_10-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -803,23 +803,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_6 + build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_6-test: # Testing + manywheel-py3_10-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_6-build + - manywheel-py3_10-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -827,43 +827,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_6 + build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_6-upload: # Uploading + manywheel-py3_10-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda12_6-test + needs: manywheel-py3_10-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_6 + build_name: manywheel-py3_10-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_8-build: + manywheel-py3_10-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -872,23 +872,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_8 + build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_8-test: # Testing + manywheel-py3_10-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_8-build + - manywheel-py3_10-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -896,38 +896,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_8 + build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_8-upload: # Uploading + manywheel-py3_10-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda12_8-test + needs: manywheel-py3_10-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_8 + build_name: manywheel-py3_10-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -958,7 +958,7 @@ jobs: needs: - manywheel-py3_10-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -1072,7 +1072,7 @@ jobs: needs: - manywheel-py3_10-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -1178,7 +1178,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-xpu-test: # Testing @@ -1338,7 +1338,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda11_8-build: + manywheel-py3_11-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1347,23 +1347,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda11_8 + build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-test: # Testing + manywheel-py3_11-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda11_8-build + - manywheel-py3_11-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1371,43 +1371,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8 + build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-upload: # Uploading + manywheel-py3_11-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda11_8-test + needs: manywheel-py3_11-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8 + build_name: manywheel-py3_11-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_6-build: + manywheel-py3_11-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1416,23 +1416,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_6 + build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_6-test: # Testing + manywheel-py3_11-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_6-build + - manywheel-py3_11-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1440,43 +1440,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_6 + build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_6-upload: # Uploading + manywheel-py3_11-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_6-test + needs: manywheel-py3_11-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_6 + build_name: manywheel-py3_11-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_6-full-build: + manywheel-py3_11-cuda12_8-full-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1485,22 +1485,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_6-full + build_name: manywheel-py3_11-cuda12_8-full build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_6-full-test: # Testing + manywheel-py3_11-cuda12_8-full-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_6-full-build + - manywheel-py3_11-cuda12_8-full-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1508,43 +1508,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_6-full + build_name: manywheel-py3_11-cuda12_8-full build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_6-full-upload: # Uploading + manywheel-py3_11-cuda12_8-full-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_6-full-test + needs: manywheel-py3_11-cuda12_8-full-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_6-full + build_name: manywheel-py3_11-cuda12_8-full secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_8-build: + manywheel-py3_11-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1553,23 +1553,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_8 + build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_8-test: # Testing + manywheel-py3_11-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_8-build + - manywheel-py3_11-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1577,38 +1577,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_8 + build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_8-upload: # Uploading + manywheel-py3_11-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_8-test + needs: manywheel-py3_11-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_8 + build_name: manywheel-py3_11-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -1639,7 +1639,7 @@ jobs: needs: - manywheel-py3_11-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -1753,7 +1753,7 @@ jobs: needs: - manywheel-py3_11-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -1859,7 +1859,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-xpu-test: # Testing @@ -2019,7 +2019,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda11_8-build: + manywheel-py3_12-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2028,23 +2028,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda11_8 + build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-test: # Testing + manywheel-py3_12-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda11_8-build + - manywheel-py3_12-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2052,43 +2052,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8 + build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-upload: # Uploading + manywheel-py3_12-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda11_8-test + needs: manywheel-py3_12-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8 + build_name: manywheel-py3_12-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_6-build: + manywheel-py3_12-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2097,23 +2097,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_6 + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_6-test: # Testing + manywheel-py3_12-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_6-build + - manywheel-py3_12-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2121,43 +2121,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_6 + build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_6-upload: # Uploading + manywheel-py3_12-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_6-test + needs: manywheel-py3_12-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_6 + build_name: manywheel-py3_12-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_8-build: + manywheel-py3_12-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2166,23 +2166,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_8 + build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_8-test: # Testing + manywheel-py3_12-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_8-build + - manywheel-py3_12-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2190,38 +2190,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_8 + build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_8-upload: # Uploading + manywheel-py3_12-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_8-test + needs: manywheel-py3_12-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_8 + build_name: manywheel-py3_12-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2252,7 +2252,7 @@ jobs: needs: - manywheel-py3_12-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -2366,7 +2366,7 @@ jobs: needs: - manywheel-py3_12-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -2472,7 +2472,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-xpu-test: # Testing @@ -2632,7 +2632,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda11_8-build: + manywheel-py3_13-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2641,23 +2641,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda11_8 + build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-test: # Testing + manywheel-py3_13-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda11_8-build + - manywheel-py3_13-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2665,43 +2665,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8 + build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-upload: # Uploading + manywheel-py3_13-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda11_8-test + needs: manywheel-py3_13-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8 + build_name: manywheel-py3_13-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_6-build: + manywheel-py3_13-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2710,23 +2710,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_6 + build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_6-test: # Testing + manywheel-py3_13-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda12_6-build + - manywheel-py3_13-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2734,43 +2734,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_6 + build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_6-upload: # Uploading + manywheel-py3_13-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda12_6-test + needs: manywheel-py3_13-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_6 + build_name: manywheel-py3_13-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_8-build: + manywheel-py3_13-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2779,23 +2779,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_8 + build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_8-test: # Testing + manywheel-py3_13-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda12_8-build + - manywheel-py3_13-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2803,38 +2803,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_8 + build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_8-upload: # Uploading + manywheel-py3_13-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda12_8-test + needs: manywheel-py3_13-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_8 + build_name: manywheel-py3_13-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2865,7 +2865,7 @@ jobs: needs: - manywheel-py3_13-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -2979,7 +2979,7 @@ jobs: needs: - manywheel-py3_13-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -3085,7 +3085,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-xpu-test: # Testing @@ -3245,7 +3245,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda11_8-build: + manywheel-py3_13t-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3254,23 +3254,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda11_8 + build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda11_8-test: # Testing + manywheel-py3_13t-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda11_8-build + - manywheel-py3_13t-cuda12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3278,43 +3278,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda11_8 + build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda11_8-upload: # Uploading + manywheel-py3_13t-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda11_8-test + needs: manywheel-py3_13t-cuda12_6-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda11.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda11_8 + build_name: manywheel-py3_13t-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda12_6-build: + manywheel-py3_13t-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3323,23 +3323,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_6 + build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.5.1.17; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_6-test: # Testing + manywheel-py3_13t-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda12_6-build + - manywheel-py3_13t-cuda12_8-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3347,43 +3347,43 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_6 + build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_6-upload: # Uploading + manywheel-py3_13t-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda12_6-test + needs: manywheel-py3_13t-cuda12_8-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_6 + build_name: manywheel-py3_13t-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda12_8-build: + manywheel-py3_13t-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3392,23 +3392,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_8 + build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.8.0.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.26.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_8-test: # Testing + manywheel-py3_13t-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda12_8-build + - manywheel-py3_13t-cuda12_9-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3416,38 +3416,38 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_8 + build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 build needs sm_70+ runner + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_8-upload: # Uploading + manywheel-py3_13t-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda12_8-test + needs: manywheel-py3_13t-cuda12_9-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 use_split_build: False DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_8 + build_name: manywheel-py3_13t-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -3478,7 +3478,7 @@ jobs: needs: - manywheel-py3_13t-rocm6_3-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -3592,7 +3592,7 @@ jobs: needs: - manywheel-py3_13t-rocm6_4-build - get-label-type - runs-on: linux.rocm.gpu + runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 env: PYTORCH_ROOT: /pytorch @@ -3698,7 +3698,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-xpu build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-xpu-test: # Testing diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index bd800a1b7c5c72..75c393b46e59bf 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -288,7 +288,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda11_8-shared-with-deps-debug-build: + libtorch-cuda12_6-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -298,8 +298,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -384,7 +384,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda11_8-shared-with-deps-debug + name: libtorch-cuda12_6-shared-with-deps-debug retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -402,10 +402,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda11_8-shared-with-deps-debug-test: # Testing + libtorch-cuda12_6-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda11_8-shared-with-deps-debug-build + - libtorch-cuda12_6-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -414,8 +414,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -492,7 +492,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda11_8-shared-with-deps-debug + name: libtorch-cuda12_6-shared-with-deps-debug path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -515,30 +515,30 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda11_8-shared-with-deps-debug-upload: # Uploading + libtorch-cuda12_6-shared-with-deps-debug-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda11_8-shared-with-deps-debug-test + needs: libtorch-cuda12_6-shared-with-deps-debug-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda11_8-shared-with-deps-debug + build_name: libtorch-cuda12_6-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_6-shared-with-deps-debug-build: + libtorch-cuda12_8-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -548,8 +548,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -634,7 +634,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_6-shared-with-deps-debug + name: libtorch-cuda12_8-shared-with-deps-debug retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -652,10 +652,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_6-shared-with-deps-debug-test: # Testing + libtorch-cuda12_8-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_6-shared-with-deps-debug-build + - libtorch-cuda12_8-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -664,8 +664,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -742,7 +742,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda12_6-shared-with-deps-debug + name: libtorch-cuda12_8-shared-with-deps-debug path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -765,30 +765,30 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_6-shared-with-deps-debug-upload: # Uploading + libtorch-cuda12_8-shared-with-deps-debug-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_6-shared-with-deps-debug-test + needs: libtorch-cuda12_8-shared-with-deps-debug-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda12_6-shared-with-deps-debug + build_name: libtorch-cuda12_8-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-debug-build: + libtorch-cuda12_9-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -798,8 +798,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -884,7 +884,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_8-shared-with-deps-debug + name: libtorch-cuda12_9-shared-with-deps-debug retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -902,10 +902,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-debug-test: # Testing + libtorch-cuda12_9-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_8-shared-with-deps-debug-build + - libtorch-cuda12_9-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -914,8 +914,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -992,7 +992,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda12_8-shared-with-deps-debug + name: libtorch-cuda12_9-shared-with-deps-debug path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1015,26 +1015,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-debug-upload: # Uploading + libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_8-shared-with-deps-debug-test + needs: libtorch-cuda12_9-shared-with-deps-debug-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda12_8-shared-with-deps-debug + build_name: libtorch-cuda12_9-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 29b334ec49289e..eccd332c74a1fd 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -288,7 +288,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda11_8-shared-with-deps-release-build: + libtorch-cuda12_6-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -298,8 +298,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -384,7 +384,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda11_8-shared-with-deps-release + name: libtorch-cuda12_6-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -402,10 +402,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda11_8-shared-with-deps-release-test: # Testing + libtorch-cuda12_6-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda11_8-shared-with-deps-release-build + - libtorch-cuda12_6-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -414,8 +414,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -492,7 +492,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda11_8-shared-with-deps-release + name: libtorch-cuda12_6-shared-with-deps-release path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -515,30 +515,30 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda11_8-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_6-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda11_8-shared-with-deps-release-test + needs: libtorch-cuda12_6-shared-with-deps-release-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda11_8-shared-with-deps-release + build_name: libtorch-cuda12_6-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_6-shared-with-deps-release-build: + libtorch-cuda12_8-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -548,8 +548,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -634,7 +634,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_6-shared-with-deps-release + name: libtorch-cuda12_8-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -652,10 +652,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_6-shared-with-deps-release-test: # Testing + libtorch-cuda12_8-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_6-shared-with-deps-release-build + - libtorch-cuda12_8-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -664,8 +664,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -742,7 +742,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda12_6-shared-with-deps-release + name: libtorch-cuda12_8-shared-with-deps-release path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -765,30 +765,30 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_6-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_6-shared-with-deps-release-test + needs: libtorch-cuda12_8-shared-with-deps-release-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda12_6-shared-with-deps-release + build_name: libtorch-cuda12_8-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-release-build: + libtorch-cuda12_9-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -798,8 +798,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -884,7 +884,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_8-shared-with-deps-release + name: libtorch-cuda12_9-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -902,10 +902,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-release-test: # Testing + libtorch-cuda12_9-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_8-shared-with-deps-release-build + - libtorch-cuda12_9-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -914,8 +914,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release @@ -992,7 +992,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda12_8-shared-with-deps-release + name: libtorch-cuda12_9-shared-with-deps-release path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1015,26 +1015,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading + libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_8-shared-with-deps-release-test + needs: libtorch-cuda12_9-shared-with-deps-release-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.9" - build_name: libtorch-cuda12_8-shared-with-deps-release + build_name: libtorch-cuda12_9-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 9bd79a16742fc1..22ebe8db70eace 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -276,7 +276,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_9-cuda11_8-build: + wheel-py3_9-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -286,8 +286,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -368,7 +368,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_9-cuda11_8 + name: wheel-py3_9-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -386,10 +386,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda11_8-test: # Testing + wheel-py3_9-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_9-cuda11_8-build + - wheel-py3_9-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -398,8 +398,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -472,7 +472,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_9-cuda11_8 + name: wheel-py3_9-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -495,26 +495,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda11_8-upload: # Uploading + wheel-py3_9-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_9-cuda11_8-test + needs: wheel-py3_9-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.9" - build_name: wheel-py3_9-cuda11_8 + build_name: wheel-py3_9-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_9-cuda12_6-build: + wheel-py3_9-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -524,8 +524,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -606,7 +606,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_9-cuda12_6 + name: wheel-py3_9-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -624,10 +624,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda12_6-test: # Testing + wheel-py3_9-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_9-cuda12_6-build + - wheel-py3_9-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -636,8 +636,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -710,7 +710,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_9-cuda12_6 + name: wheel-py3_9-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -733,26 +733,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda12_6-upload: # Uploading + wheel-py3_9-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_9-cuda12_6-test + needs: wheel-py3_9-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.9" - build_name: wheel-py3_9-cuda12_6 + build_name: wheel-py3_9-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_9-cuda12_8-build: + wheel-py3_9-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -762,8 +762,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -844,7 +844,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_9-cuda12_8 + name: wheel-py3_9-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -862,10 +862,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda12_8-test: # Testing + wheel-py3_9-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_9-cuda12_8-build + - wheel-py3_9-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -874,8 +874,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" @@ -948,7 +948,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_9-cuda12_8 + name: wheel-py3_9-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -971,22 +971,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_9-cuda12_8-upload: # Uploading + wheel-py3_9-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_9-cuda12_8-test + needs: wheel-py3_9-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.9" - build_name: wheel-py3_9-cuda12_8 + build_name: wheel-py3_9-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -1004,7 +1004,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -1461,7 +1461,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda11_8-build: + wheel-py3_10-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -1471,8 +1471,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1553,7 +1553,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda11_8 + name: wheel-py3_10-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1571,10 +1571,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda11_8-test: # Testing + wheel-py3_10-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda11_8-build + - wheel-py3_10-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -1583,8 +1583,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1657,7 +1657,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda11_8 + name: wheel-py3_10-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1680,26 +1680,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda11_8-upload: # Uploading + wheel-py3_10-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda11_8-test + needs: wheel-py3_10-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda11_8 + build_name: wheel-py3_10-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda12_6-build: + wheel-py3_10-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -1709,8 +1709,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1791,7 +1791,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda12_6 + name: wheel-py3_10-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1809,10 +1809,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_6-test: # Testing + wheel-py3_10-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda12_6-build + - wheel-py3_10-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -1821,8 +1821,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1895,7 +1895,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_6 + name: wheel-py3_10-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1918,26 +1918,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_6-upload: # Uploading + wheel-py3_10-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_6-test + needs: wheel-py3_10-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_6 + build_name: wheel-py3_10-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda12_8-build: + wheel-py3_10-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -1947,8 +1947,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -2029,7 +2029,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda12_8 + name: wheel-py3_10-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2047,10 +2047,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_8-test: # Testing + wheel-py3_10-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda12_8-build + - wheel-py3_10-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -2059,8 +2059,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -2133,7 +2133,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_8 + name: wheel-py3_10-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -2156,22 +2156,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_8-upload: # Uploading + wheel-py3_10-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_8-test + needs: wheel-py3_10-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_8 + build_name: wheel-py3_10-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2189,7 +2189,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -2646,7 +2646,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda11_8-build: + wheel-py3_11-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2656,8 +2656,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2738,7 +2738,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_11-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2756,10 +2756,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-test: # Testing + wheel-py3_11-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda11_8-build + - wheel-py3_11-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -2768,8 +2768,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2842,7 +2842,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_11-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -2865,26 +2865,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-upload: # Uploading + wheel-py3_11-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda11_8-test + needs: wheel-py3_11-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda11_8 + build_name: wheel-py3_11-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_6-build: + wheel-py3_11-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -2894,8 +2894,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2976,7 +2976,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda12_6 + name: wheel-py3_11-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2994,10 +2994,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_6-test: # Testing + wheel-py3_11-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda12_6-build + - wheel-py3_11-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -3006,8 +3006,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -3080,7 +3080,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_6 + name: wheel-py3_11-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -3103,26 +3103,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_6-upload: # Uploading + wheel-py3_11-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_6-test + needs: wheel-py3_11-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_6 + build_name: wheel-py3_11-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_8-build: + wheel-py3_11-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3132,8 +3132,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -3214,7 +3214,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda12_8 + name: wheel-py3_11-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3232,10 +3232,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_8-test: # Testing + wheel-py3_11-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda12_8-build + - wheel-py3_11-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -3244,8 +3244,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -3318,7 +3318,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_8 + name: wheel-py3_11-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -3341,22 +3341,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_8-upload: # Uploading + wheel-py3_11-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_8-test + needs: wheel-py3_11-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_8 + build_name: wheel-py3_11-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -3374,7 +3374,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -3831,7 +3831,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda11_8-build: + wheel-py3_12-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -3841,8 +3841,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3923,7 +3923,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cuda11_8 + name: wheel-py3_12-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3941,10 +3941,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda11_8-test: # Testing + wheel-py3_12-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda11_8-build + - wheel-py3_12-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -3953,8 +3953,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -4027,7 +4027,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda11_8 + name: wheel-py3_12-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4050,26 +4050,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda11_8-upload: # Uploading + wheel-py3_12-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda11_8-test + needs: wheel-py3_12-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda11_8 + build_name: wheel-py3_12-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda12_6-build: + wheel-py3_12-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -4079,8 +4079,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -4161,7 +4161,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cuda12_6 + name: wheel-py3_12-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -4179,10 +4179,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_6-test: # Testing + wheel-py3_12-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda12_6-build + - wheel-py3_12-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -4191,8 +4191,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -4265,7 +4265,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda12_6 + name: wheel-py3_12-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4288,26 +4288,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_6-upload: # Uploading + wheel-py3_12-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda12_6-test + needs: wheel-py3_12-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda12_6 + build_name: wheel-py3_12-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda12_8-build: + wheel-py3_12-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -4317,8 +4317,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -4399,7 +4399,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cuda12_8 + name: wheel-py3_12-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -4417,10 +4417,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_8-test: # Testing + wheel-py3_12-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda12_8-build + - wheel-py3_12-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -4429,8 +4429,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -4503,7 +4503,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda12_8 + name: wheel-py3_12-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4526,22 +4526,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_8-upload: # Uploading + wheel-py3_12-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda12_8-test + needs: wheel-py3_12-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda12_8 + build_name: wheel-py3_12-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -4559,7 +4559,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -5016,7 +5016,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13-cuda11_8-build: + wheel-py3_13-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -5026,8 +5026,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5108,7 +5108,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13-cuda11_8 + name: wheel-py3_13-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5126,10 +5126,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda11_8-test: # Testing + wheel-py3_13-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13-cuda11_8-build + - wheel-py3_13-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -5138,8 +5138,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5212,7 +5212,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13-cuda11_8 + name: wheel-py3_13-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5235,26 +5235,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda11_8-upload: # Uploading + wheel-py3_13-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13-cuda11_8-test + needs: wheel-py3_13-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" - build_name: wheel-py3_13-cuda11_8 + build_name: wheel-py3_13-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13-cuda12_6-build: + wheel-py3_13-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -5264,8 +5264,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5346,7 +5346,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13-cuda12_6 + name: wheel-py3_13-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5364,10 +5364,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_6-test: # Testing + wheel-py3_13-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13-cuda12_6-build + - wheel-py3_13-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -5376,8 +5376,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5450,7 +5450,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13-cuda12_6 + name: wheel-py3_13-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5473,26 +5473,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_6-upload: # Uploading + wheel-py3_13-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13-cuda12_6-test + needs: wheel-py3_13-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" - build_name: wheel-py3_13-cuda12_6 + build_name: wheel-py3_13-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13-cuda12_8-build: + wheel-py3_13-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -5502,8 +5502,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5584,7 +5584,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13-cuda12_8 + name: wheel-py3_13-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5602,10 +5602,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_8-test: # Testing + wheel-py3_13-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13-cuda12_8-build + - wheel-py3_13-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -5614,8 +5614,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -5688,7 +5688,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13-cuda12_8 + name: wheel-py3_13-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5711,22 +5711,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_8-upload: # Uploading + wheel-py3_13-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13-cuda12_8-test + needs: wheel-py3_13-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" - build_name: wheel-py3_13-cuda12_8 + build_name: wheel-py3_13-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -5744,7 +5744,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -6201,7 +6201,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13t-cuda11_8-build: + wheel-py3_13t-cuda12_6-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -6211,8 +6211,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6293,7 +6293,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13t-cuda11_8 + name: wheel-py3_13t-cuda12_6 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -6311,10 +6311,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda11_8-test: # Testing + wheel-py3_13t-cuda12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13t-cuda11_8-build + - wheel-py3_13t-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -6323,8 +6323,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6397,7 +6397,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13t-cuda11_8 + name: wheel-py3_13t-cuda12_6 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -6420,26 +6420,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda11_8-upload: # Uploading + wheel-py3_13t-cuda12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13t-cuda11_8-test + needs: wheel-py3_13t-cuda12_6-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" - build_name: wheel-py3_13t-cuda11_8 + build_name: wheel-py3_13t-cuda12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13t-cuda12_6-build: + wheel-py3_13t-cuda12_8-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -6449,8 +6449,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6531,7 +6531,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13t-cuda12_6 + name: wheel-py3_13t-cuda12_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -6549,10 +6549,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_6-test: # Testing + wheel-py3_13t-cuda12_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13t-cuda12_6-build + - wheel-py3_13t-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -6561,8 +6561,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6635,7 +6635,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13t-cuda12_6 + name: wheel-py3_13t-cuda12_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -6658,26 +6658,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_6-upload: # Uploading + wheel-py3_13t-cuda12_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13t-cuda12_6-test + needs: wheel-py3_13t-cuda12_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: 12.6 + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" - build_name: wheel-py3_13t-cuda12_6 + build_name: wheel-py3_13t-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13t-cuda12_8-build: + wheel-py3_13t-cuda12_9-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" @@ -6687,8 +6687,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6769,7 +6769,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13t-cuda12_8 + name: wheel-py3_13t-cuda12_9 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -6787,10 +6787,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_8-test: # Testing + wheel-py3_13t-cuda12_9-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13t-cuda12_8-build + - wheel-py3_13t-cuda12_9-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 300 @@ -6799,8 +6799,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -6873,7 +6873,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13t-cuda12_8 + name: wheel-py3_13t-cuda12_9 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -6896,22 +6896,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_8-upload: # Uploading + wheel-py3_13t-cuda12_9-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13t-cuda12_8-test + needs: wheel-py3_13t-cuda12_9-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: 12.8 + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" - build_name: wheel-py3_13t-cuda12_8 + build_name: wheel-py3_13t-cuda12_9 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -6929,7 +6929,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the diff --git a/.github/workflows/h100-distributed.yml b/.github/workflows/h100-distributed.yml new file mode 100644 index 00000000000000..45579672f23205 --- /dev/null +++ b/.github/workflows/h100-distributed.yml @@ -0,0 +1,55 @@ +name: Limited CI for distributed tests on H100 + +on: + pull_request: + paths: + - .github/workflows/h100-distributed.yml + workflow_dispatch: + push: + tags: + - ciflow/h100-distributed/* + schedule: + - cron: 46 8 * * * # about 1:46am PDT + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runner: "linux.12xlarge" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '9.0' + test-matrix: | + { include: [ + { config: "h100_distributed", shard: 1, num_shards: 1, runner: "linux.aws.h100.8" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-dist.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/h100-symm-mem.yml b/.github/workflows/h100-symm-mem.yml new file mode 100644 index 00000000000000..0f08ecf9de09cf --- /dev/null +++ b/.github/workflows/h100-symm-mem.yml @@ -0,0 +1,54 @@ +name: Limited CI for symmetric memory tests on H100 + +on: + pull_request: + paths: + - .github/workflows/h100-symm-mem.yml + workflow_dispatch: + push: + tags: + - ciflow/h100-symm-mem/* + schedule: + - cron: 22 8 * * * # about 1:22am PDT + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-symm + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-symm + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '9.0' + test-matrix: | + { include: [ + { config: "h100-symm-mem", shard: 1, num_shards: 1, runner: "linux.aws.h100.4" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-symm + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-symm + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-symm.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml index 07115c41cddbb5..117183428abc1c 100644 --- a/.github/workflows/inductor-micro-benchmark-x86.yml +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -13,7 +13,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: linux-jammy-cpu-py3_9-gcc11-inductor-build: diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index cdec17fab974da..a0ae234ab5669f 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -13,7 +13,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 3e393dcec8c882..d8dc7146fda13f 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -16,7 +16,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index c607e04981b1ed..25191643b3599b 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -10,7 +10,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index e83675f2d48c16..ed04d88eb1277a 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -48,7 +48,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-label-type: diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 6c261a80453a7d..c94996f58002be 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -2,7 +2,7 @@ name: inductor-perf-nightly-h100 on: schedule: - - cron: 0 7 * * 1-6 + - cron: 15 0,4,8,12,16,20 * * 1-6 - cron: 0 7 * * 0 # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs @@ -63,7 +63,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-label-type: @@ -94,27 +96,31 @@ jobs: { config: "inductor_huggingface_perf_cuda_h100", shard: 3, num_shards: 5, runner: "linux.aws.h100" }, { config: "inductor_huggingface_perf_cuda_h100", shard: 4, num_shards: 5, runner: "linux.aws.h100" }, { config: "inductor_huggingface_perf_cuda_h100", shard: 5, num_shards: 5, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 1, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 2, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 3, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 4, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 5, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_timm_perf_cuda_h100", shard: 6, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 1, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 2, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 3, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 4, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 5, num_shards: 6, runner: "linux.aws.h100" }, - { config: "inductor_torchbench_perf_cuda_h100", shard: 6, num_shards: 6, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 1, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 2, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 3, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 4, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 5, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 6, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_timm_perf_cuda_h100", shard: 7, num_shards: 7, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 1, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 2, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 3, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 4, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 5, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 6, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 7, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 8, num_shards: 9, runner: "linux.aws.h100" }, + { config: "inductor_torchbench_perf_cuda_h100", shard: 9, num_shards: 9, runner: "linux.aws.h100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} secrets: inherit - test-nightly: + test-periodically: name: cuda12.8-py3.10-gcc9-sm90 uses: ./.github/workflows/_linux-test.yml needs: build - if: github.event.schedule == '0 7 * * 1-6' + if: github.event.schedule == '15 0,4,8,12,16,20 * * 1-6' with: build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true diff --git a/.github/workflows/inductor-perf-test-nightly-macos.yml b/.github/workflows/inductor-perf-test-nightly-macos.yml index 40d87c15f8a35c..0d92455a8f3c75 100644 --- a/.github/workflows/inductor-perf-test-nightly-macos.yml +++ b/.github/workflows/inductor-perf-test-nightly-macos.yml @@ -45,7 +45,9 @@ jobs: python-version: 3.12.7 test-matrix: | { include: [ - { config: "perf_smoketest", shard: 1, num_shards: 1, runner: "macos-m2-15" }, + { config: "perf_smoketest", shard: 1, num_shards: 3, runner: "macos-m2-15" }, + { config: "perf_smoketest", shard: 2, num_shards: 3, runner: "macos-m2-15" }, + { config: "perf_smoketest", shard: 3, num_shards: 3, runner: "macos-m2-15" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index e45e596d5af061..389a1c0fc07c2e 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -5,7 +5,7 @@ on: tags: - ciflow/inductor-perf-test-nightly-rocm/* schedule: - - cron: 0 7 * * 0 + - cron: 0 7 * * 0,3 # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs workflow_dispatch: @@ -88,18 +88,23 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ - { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, - { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" }, ]} secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml new file mode 100644 index 00000000000000..6e19130a192465 --- /dev/null +++ b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml @@ -0,0 +1,131 @@ +name: inductor-perf-nightly-x86-zen + +on: + push: + tags: + - ciflow/inductor-perf-test-nightly-x86-zen/* + schedule: + # - cron: 0 7 * * 1-6 + # - cron: 0 7 * * 0 + # Does not perform max_autotune on CPU, so skip the weekly run setup + - cron: 0 7 * * * + # NB: GitHub has an upper limit of 10 inputs here + workflow_dispatch: + inputs: + training: + # CPU for training is not typical, but leave the option open here + description: Run training (off by default)? + required: false + type: boolean + default: false + inference: + description: Run inference (on by default)? + required: false + type: boolean + default: true + default: + description: Run inductor_default? + required: false + type: boolean + default: true + dynamic: + description: Run inductor_dynamic_shapes? + required: false + type: boolean + default: false + cppwrapper: + description: Run inductor_cpp_wrapper? + required: false + type: boolean + default: false + aotinductor: + description: Run aot_inductor for inference? + required: false + type: boolean + default: false + benchmark_configs: + description: The list of configs used the benchmark + required: false + type: string + default: inductor_huggingface_perf_cpu_x86_zen,inductor_timm_perf_cpu_x86_zen,inductor_torchbench_perf_cpu_x86_zen + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + opt_out_experiments: lf + + linux-jammy-zen-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_cpu_x86_zen", shard: 1, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_huggingface_perf_cpu_x86_zen", shard: 2, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_huggingface_perf_cpu_x86_zen", shard: 3, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_cpu_x86_zen", shard: 1, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_cpu_x86_zen", shard: 2, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_cpu_x86_zen", shard: 3, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_cpu_x86_zen", shard: 4, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_cpu_x86_zen", shard: 5, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 1, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 2, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 3, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 4, num_shards: 4, runner: "linux.24xlarge.amd" }, + ]} + selected-test-configs: ${{ inputs.benchmark_configs }} + secrets: inherit + + linux-jammy-zen-cpu-py3_9-gcc11-inductor-test-nightly: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-zen-cpu-py3_9-gcc11-inductor-build + if: github.event.schedule == '0 7 * * *' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true + docker-image: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} + timeout-minutes: 720 + # disable monitor in perf tests + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit + + + linux-jammy-zen-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-zen-cpu-py3_9-gcc11-inductor-build + if: github.event_name == 'workflow_dispatch' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} + docker-image: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} + timeout-minutes: 720 + # disable monitor in perf tests + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index f9ea78427e36b2..0466576658d45c 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -1,6 +1,9 @@ name: inductor-perf-nightly-x86 on: + pull_request: + paths: + - .github/workflows/inductor-perf-test-nightly-x86.yml schedule: # - cron: 0 7 * * 1-6 # - cron: 0 7 * * 0 @@ -40,6 +43,11 @@ on: required: false type: boolean default: false + freezing: + description: Run freezing? + required: false + type: boolean + default: true benchmark_configs: description: The list of configs used the benchmark required: false @@ -47,10 +55,12 @@ on: default: inductor_huggingface_perf_cpu_x86,inductor_timm_perf_cpu_x86,inductor_torchbench_perf_cpu_x86 concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-label-type: @@ -90,15 +100,14 @@ jobs: selected-test-configs: ${{ inputs.benchmark_configs }} secrets: inherit - - linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly: + linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly-freezing: name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-cpu-py3_9-gcc11-inductor-build if: github.event.schedule == '0 7 * * *' with: build-environment: linux-jammy-py3.9-gcc11-build - dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true + dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} timeout-minutes: 720 @@ -108,7 +117,6 @@ jobs: monitor-data-collect-interval: 4 secrets: inherit - linux-jammy-cpu-py3_9-gcc11-inductor-test: name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml @@ -116,7 +124,7 @@ jobs: if: github.event_name == 'workflow_dispatch' with: build-environment: linux-jammy-py3.9-gcc11-build - dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-freezing-${{ inputs.freezing }} docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 61a98be23d780d..015204473339d1 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -63,7 +63,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-label-type: diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 1dbcd17927abc2..2e16c2e403fb05 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -15,7 +15,9 @@ concurrency: cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 5661575d5dbb61..4241854aa3278d 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -43,7 +43,6 @@ jobs: { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2" }, { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-rocm-py3_10-inductor-test: diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index afff2ee3f3806e..df918c329dd777 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -1,6 +1,6 @@ # Workflow: Inductor Unit Test # 1. runs unit tests for inductor. -# 2. perfoms daily memory leak checks and reruns of disabled tests, scheduled at `29 8 * * *`. +# 2. performs daily memory leak checks and reruns of disabled tests, scheduled at `29 8 * * *`. name: inductor-unittest on: @@ -12,7 +12,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-unittest cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: get-label-type: @@ -26,13 +28,13 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf - linux-jammy-cuda12_6-py3_10-gcc9-inductor-build: - name: cuda12.6-py3.10-gcc9-sm86 + linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: + name: cuda12.8-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.6-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -43,26 +45,25 @@ jobs: { config: "inductor_cpp_wrapper", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_cpp_wrapper", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit - linux-jammy-cuda12_6-py3_10-gcc9-inductor-test: - name: cuda12.6-py3.10-gcc9-sm86 + linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: + name: cuda12.8-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_6-py3_10-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build with: - build-environment: linux-jammy-cuda12.6-py3.10-gcc9-sm86 - docker-image: ${{ needs.linux-jammy-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_6-py3_12-gcc9-inductor-build: - name: cuda12.6-py3.12-gcc9-sm86 + linux-jammy-cuda12_8-py3_12-gcc9-inductor-build: + name: cuda12.8-py3.12-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.6-py3.12-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -70,17 +71,16 @@ jobs: { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit - linux-jammy-cuda12_6-py3_12-gcc9-inductor-test: - name: cuda12.6-py3.12-gcc9-sm86 + linux-jammy-cuda12_8-py3_12-gcc9-inductor-test: + name: cuda12.8-py3.12-gcc9-sm86 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_6-py3_12-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_12-gcc9-inductor-build with: - build-environment: linux-jammy-cuda12.6-py3.12-gcc9-sm86 - docker-image: ${{ needs.linux-jammy-cuda12_6-py3_12-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_6-py3_12-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_12-gcc9-inductor-build.outputs.test-matrix }} secrets: inherit linux-jammy-cpu-py3_12-inductor-halide-build: @@ -95,7 +95,6 @@ jobs: { include: [ { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cpu-py3_12-inductor-halide-test: @@ -120,7 +119,6 @@ jobs: { include: [ { config: "inductor-triton-cpu", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cpu-py3_12-inductor-triton-cpu-test: @@ -148,7 +146,6 @@ jobs: { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: @@ -161,28 +158,27 @@ jobs: test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_6-py3_13-gcc9-inductor-build: - name: cuda12.6-py3.13-gcc9-sm86 + linux-jammy-cuda12_8-py3_13-gcc9-inductor-build: + name: cuda12.8-py3.13-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.6-py3.13-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.13-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit - linux-jammy-cuda12_6-py3_13-gcc9-inductor-test: - name: cuda12.6-py3.13-gcc9-sm86 + linux-jammy-cuda12_8-py3_13-gcc9-inductor-test: + name: cuda12.8-py3.13-gcc9-sm86 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_6-py3_13-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_13-gcc9-inductor-build with: - build-environment: linux-jammy-cuda12.6-py3.13-gcc9-sm86 - docker-image: ${{ needs.linux-jammy-cuda12_6-py3_13-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_6-py3_13-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.13-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_13-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_13-gcc9-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 24a209ed94daab..e6fc7aa65431a8 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -22,7 +22,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: unit-test: @@ -60,7 +62,6 @@ jobs: { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: @@ -93,7 +94,6 @@ jobs: { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9408365025d282..d0a2fda509ef3f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,7 +33,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter + docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -50,7 +50,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter + docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -66,7 +66,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | @@ -119,7 +119,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: -1 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -153,7 +153,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | @@ -189,7 +189,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-focal-linter + docker-image: ci-image:pytorch-linux-jammy-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index 173cabe232a1ad..2b840a39a5c210 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -34,16 +34,15 @@ jobs: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-aarch64-py3.10 docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 - runner: linux.arm64.2xlarge + runner: linux.arm64.m7g.4xlarge test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge" }, { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, ]} secrets: inherit diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 810602b9c57b45..16cb1600b8d6bb 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -19,7 +19,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: linux-jammy-cpu-py3_9-gcc11-opbenchmark-build: diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 3060636dc5125e..0882019d51151d 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -20,7 +20,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: llm-td: @@ -49,14 +51,14 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-cuda12_6-py3_10-gcc11-build: - name: linux-focal-cuda12.6-py3.10-gcc11 + linux-jammy-cuda12_8-py3_10-gcc11-build: + name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -68,26 +70,26 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-test: - name: linux-focal-cuda12.6-py3.10-gcc11 + linux-jammy-cuda12_8-py3_10-gcc11-test: + name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-build + - linux-jammy-cuda12_8-py3_10-gcc11-build - target-determination with: - build-environment: linux-focal-cuda12.6-py3.10-gcc11 - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-focal-cuda11_8-py3_9-gcc9-build: - name: linux-focal-cuda11.8-py3.9-gcc9 + linux-jammy-cuda12_8-py3_9-gcc9-build: + name: linux-jammy-cuda12.8-py3.9-gcc9 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image-name: ci-image:pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + build-environment: linux-jammy-cuda12.8-py3.9-gcc9 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -97,24 +99,24 @@ jobs: build-with-debug: false secrets: inherit - linux-focal-cuda11_8-py3_9-gcc9-test: - name: linux-focal-cuda11.8-py3.9-gcc9 + linux-jammy-cuda12_8-py3_9-gcc9-test: + name: linux-jammy-cuda12.8-py3.9-gcc9 uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda11_8-py3_9-gcc9-build + needs: linux-jammy-cuda12_8-py3_9-gcc9-build with: - build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image: ${{ needs.linux-focal-cuda11_8-py3_9-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda11_8-py3_9-gcc9-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.9-gcc9 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_9-gcc9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_9-gcc9-build.outputs.test-matrix }} secrets: inherit - linux-focal-cuda11_8-py3_10-gcc9-debug-build: - name: linux-focal-cuda11.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc9-debug-build: + name: linux-jammy-cuda12.8-py3.10-gcc9-debug uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image-name: ci-image:pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 build-with-debug: true test-matrix: | { include: [ @@ -128,16 +130,16 @@ jobs: ]} secrets: inherit - linux-focal-cuda11_8-py3_10-gcc9-debug-test: - name: linux-focal-cuda11.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc9-debug-test: + name: linux-jammy-cuda12.8-py3.10-gcc9-debug uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda11_8-py3_10-gcc9-debug-build + - linux-jammy-cuda12_8-py3_10-gcc9-debug-build - target-determination with: - build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }} secrets: inherit linux-jammy-rocm-py3_10-build: @@ -171,14 +173,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit - linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build: - name: linux-focal-cuda12.6-py3-gcc11-slow-gradcheck + linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build: + name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3-gcc11-slow-gradcheck - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -193,15 +195,15 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-test: - name: linux-focal-cuda12.6-py3-gcc11-slow-gradcheck + linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-test: + name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build + - linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build - target-determination with: - build-environment: linux-focal-cuda12.6-py3-gcc11-slow-gradcheck - docker-image: ${{ needs.linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck + docker-image: ${{ needs.linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build.outputs.test-matrix }} timeout-minutes: 300 secrets: inherit diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 424f1fb0e605be..53a4f6357e5c2b 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -19,7 +19,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: llm-td: @@ -120,14 +122,14 @@ jobs: ]} secrets: inherit - linux-jammy-py3_10-clang15-asan-build: - name: linux-jammy-py3.10-clang15-asan + linux-jammy-py3_10-clang18-asan-build: + name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.10-clang15-asan - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan + build-environment: linux-jammy-py3.10-clang18-asan + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -138,31 +140,30 @@ jobs: { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build - allow-reuse-old-whl: true secrets: inherit - linux-jammy-py3_10-clang15-asan-test: - name: linux-jammy-py3.10-clang15-asan + linux-jammy-py3_10-clang18-asan-test: + name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_10-clang15-asan-build + - linux-jammy-py3_10-clang18-asan-build - target-determination with: - build-environment: linux-jammy-py3.10-clang15-asan - docker-image: ${{ needs.linux-jammy-py3_10-clang15-asan-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_10-clang15-asan-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.10-clang18-asan + docker-image: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.test-matrix }} sync-tag: asan-test secrets: inherit - linux-focal-py3_9-clang10-onnx-build: - name: linux-focal-py3.9-clang10-onnx + linux-jammy-py3_9-clang12-onnx-build: + name: linux-jammy-py3.9-clang12-onnx uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.9-clang10-onnx - docker-image-name: ci-image:pytorch-linux-focal-py3-clang10-onnx + build-environment: linux-jammy-py3.9-clang12-onnx + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -170,26 +171,26 @@ jobs: ]} secrets: inherit - linux-focal-py3_9-clang10-onnx-test: - name: linux-focal-py3.9-clang10-onnx + linux-jammy-py3_9-clang12-onnx-test: + name: linux-jammy-py3.9-clang12-onnx uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_9-clang10-onnx-build + - linux-jammy-py3_9-clang12-onnx-build - target-determination with: - build-environment: linux-focal-py3.9-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-clang12-onnx + docker-image: ${{ needs.linux-jammy-py3_9-clang12-onnx-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-onnx-build.outputs.test-matrix }} secrets: inherit - linux-focal-py3_9-clang10-build: - name: linux-focal-py3.9-clang10 + linux-jammy-py3_9-clang12-build: + name: linux-jammy-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.9-clang10 - docker-image-name: ci-image:pytorch-linux-focal-py3.9-clang10 + build-environment: linux-jammy-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -202,30 +203,30 @@ jobs: { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" } ]} - allow-reuse-old-whl: true secrets: inherit - linux-focal-py3_9-clang10-test: - name: linux-focal-py3.9-clang10 + linux-jammy-py3_9-clang12-test: + name: linux-jammy-py3.9-clang12 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_9-clang10-build + - linux-jammy-py3_9-clang12-build - target-determination with: - build-environment: linux-focal-py3.9-clang10 - docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-clang12 + docker-image: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.test-matrix }} secrets: inherit - linux-focal-py3_13-clang10-build: - name: linux-focal-py3.13-clang10 + linux-jammy-py3_13-clang12-build: + name: linux-jammy-py3.13-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.13-clang10 - docker-image-name: ci-image:pytorch-linux-focal-py3.13-clang10 + build-environment: linux-jammy-py3.13-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.13-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -238,29 +239,29 @@ jobs: { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" } ]} - allow-reuse-old-whl: true secrets: inherit - linux-focal-py3_13-clang10-test: - name: linux-focal-py3.13-clang10 + linux-jammy-py3_13-clang12-test: + name: linux-jammy-py3.13-clang12 uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_13-clang10-build + needs: linux-jammy-py3_13-clang12-build with: - build-environment: linux-focal-py3.13-clang10 - docker-image: ${{ needs.linux-focal-py3_13-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_13-clang10-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.13-clang12 + docker-image: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.test-matrix }} timeout-minutes: 600 secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-build-distributed: - name: linux-focal-cuda12.6-py3.10-gcc11-build-distributed + linux-jammy-cuda12_8-py3_10-gcc11-build-distributed: + name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-distributed - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '7.5' test-matrix: | { include: [ @@ -270,27 +271,27 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-test-distributed: - name: linux-focal-cuda12.6-py3.10-gcc11-test + linux-jammy-cuda12_8-py3_10-gcc11-test-distributed: + name: linux-jammy-cuda12.8-py3.10-gcc11-test uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-build-distributed + - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed - target-determination with: timeout-minutes: 360 - build-environment: linux-focal-cuda12.6-py3.10-gcc11-distributed - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build-distributed.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build-distributed.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-build: - name: linux-focal-cuda12.6-py3.10-gcc11 + linux-jammy-cuda12_8-py3_10-gcc11-build: + name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, @@ -299,20 +300,19 @@ jobs: { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-test: - name: linux-focal-cuda12.6-py3.10-gcc11 + linux-jammy-cuda12_8-py3_10-gcc11-test: + name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-build + - linux-jammy-cuda12_8-py3_10-gcc11-build - target-determination with: timeout-minutes: 360 - build-environment: linux-focal-cuda12.6-py3.10-gcc11 - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit linux-jammy-py3-clang12-mobile-build: @@ -330,27 +330,27 @@ jobs: ]} secrets: inherit - linux-jammy-cuda-11_8-cudnn9-py3_9-clang12-build: - name: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 + linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: + name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 - docker-image-name: ci-image:pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12 + build-environment: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} secrets: inherit - linux-focal-py3_9-clang9-xla-build: - name: linux-focal-py3_9-clang9-xla + linux-jammy-py3_9-clang9-xla-build: + name: linux-jammy-py3_9-clang9-xla uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.9-clang9-xla + build-environment: linux-jammy-py3.9-clang9-xla docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite test-matrix: | { include: [ @@ -358,24 +358,24 @@ jobs: ]} secrets: inherit - linux-focal-py3_9-clang9-xla-test: - name: linux-focal-py3_9-clang9-xla + linux-jammy-py3_9-clang9-xla-test: + name: linux-jammy-py3_9-clang9-xla uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_9-clang9-xla-build + needs: linux-jammy-py3_9-clang9-xla-build with: - build-environment: linux-focal-py3.9-clang9-xla - docker-image: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-clang9-xla + docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} secrets: inherit - linux-focal-cpu-py3_10-gcc11-bazel-test: - name: linux-focal-cpu-py3.10-gcc11-bazel-test + linux-jammy-cpu-py3_10-gcc11-bazel-test: + name: linux-jammy-cpu-py3.10-gcc11-bazel-test uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-bazel-test - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-bazel-test + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-version: cpu test-matrix: | { include: [ @@ -417,14 +417,14 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-sm89-build: - name: linux-focal-cuda12.6-py3.10-gcc11-sm89 + linux-jammy-cuda12_8-py3_10-gcc11-sm89-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.9 test-matrix: | { include: [ @@ -434,40 +434,18 @@ jobs: { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, ]} - allow-reuse-old-whl: true - secrets: inherit - - unstable-linux-focal-cuda12_6-py3_10-gcc11-sm89-build-xfail: - # A version of the build that sets a larger number of jobs for a build. May - # OOM - name: unstable-linux-focal-cuda12.6-py3.10-gcc11-sm89-xfail - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm89 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 - cuda-arch-list: 8.9 - max-jobs: 4 - # Doesn't actually run tests, but need this in order to prevent the build - # from being skipped - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, - ]} - allow-reuse-old-whl: true secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-sm89-test: - name: linux-focal-cuda12.6-py3.10-gcc11-sm89 + linux-jammy-cuda12_8-py3_10-gcc11-sm89-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-sm89-build + - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build - target-determination with: - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm89 - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm89-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm89-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }} secrets: inherit linux-jammy-py3-clang12-executorch-build: @@ -488,6 +466,7 @@ jobs: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-py3-clang12-executorch-build + if: false # Has been broken for a while with: build-environment: linux-jammy-py3-clang12-executorch docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} @@ -507,7 +486,6 @@ jobs: { include: [ { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, ]} - allow-reuse-old-whl: true secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: diff --git a/.github/workflows/s390.yml b/.github/workflows/s390.yml index 1fe4638bff60b0..a01c62f22f8210 100644 --- a/.github/workflows/s390.yml +++ b/.github/workflows/s390.yml @@ -2,8 +2,6 @@ name: s390 on: push: - branches: - - main tags: - ciflow/s390/* workflow_dispatch: diff --git a/.github/workflows/s390x-periodic.yml b/.github/workflows/s390x-periodic.yml index 93e28ee257fae2..405e3e1a581ccd 100644 --- a/.github/workflows/s390x-periodic.yml +++ b/.github/workflows/s390x-periodic.yml @@ -9,15 +9,15 @@ on: tags: - ciflow/periodic/* - ciflow/s390/* - branches: - - release/* workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: llm-td: diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 9e84ca613c70a3..2a7b1d184330bb 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -18,7 +18,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: llm-td: @@ -47,14 +49,14 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-cuda12_6-py3_10-gcc11-sm86-build: - name: linux-focal-cuda12.6-py3.10-gcc11-sm86 + linux-jammy-cuda12_8-py3_10-gcc11-sm86-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm86 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm86 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -64,26 +66,26 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-sm86-test: - name: linux-focal-cuda12.6-py3.10-gcc11-sm86 + linux-jammy-cuda12_8-py3_10-gcc11-sm86-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm86 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-sm86-build + - linux-jammy-cuda12_8-py3_10-gcc11-sm86-build - target-determination with: - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm86 - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm86-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm86-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.test-matrix }} secrets: inherit - linux-focal-py3_9-clang10-build: - name: linux-focal-py3.9-clang10 + linux-jammy-py3_9-clang12-build: + name: linux-jammy-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-py3.9-clang10 - docker-image-name: ci-image:pytorch-linux-focal-py3.9-clang10 + build-environment: linux-jammy-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-clang12 test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, @@ -91,16 +93,16 @@ jobs: ]} secrets: inherit - linux-focal-py3_9-clang10-test: - name: linux-focal-py3.9-clang10 + linux-jammy-py3_9-clang12-test: + name: linux-jammy-py3.9-clang12 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_9-clang10-build + - linux-jammy-py3_9-clang12-build - target-determination with: - build-environment: linux-focal-py3.9-clang10 - docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-clang12 + docker-image: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.test-matrix }} secrets: inherit linux-jammy-rocm-py3_10-build: @@ -133,14 +135,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3_10-clang15-asan-build: - name: linux-jammy-py3.10-clang15-asan + linux-jammy-py3_10-clang18-asan-build: + name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.10-clang15-asan - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan + build-environment: linux-jammy-py3.10-clang18-asan + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -148,18 +150,17 @@ jobs: { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build - allow-reuse-old-whl: true secrets: inherit - linux-jammy-py3_10-clang15-asan-test: - name: linux-jammy-py3.10-clang15-asan + linux-jammy-py3_10-clang18-asan-test: + name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_10-clang15-asan-build + - linux-jammy-py3_10-clang18-asan-build - target-determination with: - build-environment: linux-jammy-py3.10-clang15-asan - docker-image: ${{ needs.linux-jammy-py3_10-clang15-asan-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_10-clang15-asan-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.10-clang18-asan + docker-image: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.test-matrix }} sync-tag: asan-test secrets: inherit diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index eaaa1b93b967c2..ec579fda8da947 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -37,7 +37,7 @@ jobs: id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: - docker-image-name: ci-image:pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 working-directory: pytorch - name: Use following to pull public copy of the image @@ -46,7 +46,7 @@ jobs: ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} shell: bash run: | - tag=${ECR_DOCKER_IMAGE##*/} + tag=${ECR_DOCKER_IMAGE##*:} echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image @@ -100,8 +100,10 @@ jobs: AWS_DEFAULT_REGION: us-east-1 run: | # detached container should get cleaned up by teardown_ec2_linux + # Disable shellcheck warning for GPU_FLAG + # shellcheck disable=SC2086 container_name=$(docker run \ - "${GPU_FLAG:-}" \ + ${GPU_FLAG:-} \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e AWS_DEFAULT_REGION \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ diff --git a/.github/workflows/test-check-binary.yml b/.github/workflows/test-check-binary.yml index ab00386c3e5415..0d31948f196a10 100644 --- a/.github/workflows/test-check-binary.yml +++ b/.github/workflows/test-check-binary.yml @@ -34,7 +34,9 @@ jobs: docker-image: python:3.11 docker-build-dir: "skip-docker-build" script: | + STABLE_CUDA_VERSION=$(python3 .github/scripts/get_ci_variable.py --cuda-stable-version) + CUDA_VERSION_NODOT=$(echo ${STABLE_CUDA_VERSION} | tr -d '.') pushd .ci/pytorch/ - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 - DESIRED_PYTHON=3.11 DESIRED_CUDA=cu124 PACKAGE_TYPE=manywheel ./check_binary.sh + pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_NODOT} + DESIRED_PYTHON=3.11 DESIRED_CUDA=cu${CUDA_VERSION_NODOT} PACKAGE_TYPE=manywheel ./check_binary.sh popd diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index 9f61c53580c305..7fc878a5ad9d54 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -27,15 +27,15 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-cuda12_6-py3_10-gcc11-sm90-build: - name: linux-focal-cuda12.6-py3.10-gcc11-sm90 + linux-jammy-cuda12_8-py3_10-gcc11-sm90-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.12xlarge" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' test-matrix: | { include: [ @@ -43,13 +43,13 @@ jobs: ]} secrets: inherit - linux-focal-cuda12_6-py3_10-gcc11-sm90-test: - name: linux-focal-cuda12.6-py3.10-gcc11-sm90 + linux-jammy-cuda12_8-py3_10-gcc11-sm90-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_6-py3_10-gcc11-sm90-build + - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build with: - build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90 - docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk-tagging.yml b/.github/workflows/trunk-tagging.yml new file mode 100644 index 00000000000000..b460195c37e6b6 --- /dev/null +++ b/.github/workflows/trunk-tagging.yml @@ -0,0 +1,224 @@ +name: trunk-tagging + +on: + push: + branches: + - main + workflow_dispatch: + inputs: + commit_sha: + description: 'Commit SHA to tag (leave empty for current HEAD)' + required: false + type: string + +concurrency: + group: trunk-tagging-${{ github.event.inputs.commit_sha || github.sha }} + cancel-in-progress: false + +permissions: + contents: write + +jobs: + tag-trunk-commit: + name: Tag trunk commit + runs-on: ubuntu-latest + if: github.repository_owner == 'pytorch' + + steps: + - name: Pre-checkout validation + run: | + # For workflow_dispatch, validate SHA format before checkout + if [ -n "${{ github.event.inputs.commit_sha }}" ]; then + COMMIT_SHA="${{ github.event.inputs.commit_sha }}" + + # Verify it's a well-formed SHA (40 hex characters) + if ! echo "${COMMIT_SHA}" | grep -qE '^[a-f0-9]{40}$'; then + echo "Error: Invalid commit SHA format. Expected 40 hexadecimal characters, got: ${COMMIT_SHA}" + exit 1 + fi + + echo "✅ Pre-checkout validation passed for: ${COMMIT_SHA}" + else + echo "✅ Using current commit SHA - no pre-checkout validation needed" + fi + + - name: Checkout repository + uses: actions/checkout@v4 + with: + # Fetch full history to ensure we have all commits + fetch-depth: 0 + # For workflow_dispatch, checkout the specified commit + ref: ${{ github.event.inputs.commit_sha || github.sha }} + + - name: Set commit SHA + id: commit + run: | + if [ -n "${{ github.event.inputs.commit_sha }}" ]; then + COMMIT_SHA="${{ github.event.inputs.commit_sha }}" + else + COMMIT_SHA="${{ github.sha }}" + fi + echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" + echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" + + - name: Validate commit SHA + run: | + COMMIT_SHA="${{ steps.commit.outputs.sha }}" + + # Verify the commit exists and is valid + if ! git cat-file -e "${COMMIT_SHA}"; then + echo "Error: Commit SHA ${COMMIT_SHA} does not exist in repository" + exit 1 + fi + + # For workflow_dispatch, verify the commit exists on main branch + if [ -n "${{ github.event.inputs.commit_sha }}" ]; then + echo "Manual dispatch detected - validating commit is on main branch..." + + # Get all commits reachable from main branch + if ! git merge-base --is-ancestor "${COMMIT_SHA}" origin/main; then + echo "Error: Commit ${COMMIT_SHA} is not reachable from main branch" + echo "Only commits that exist on the main branch can be tagged" + exit 1 + fi + + echo "✅ Commit ${COMMIT_SHA} is valid and exists on main branch" + else + echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" + fi + + - name: Create and push tag with retry + id: check_tag + env: + TAG_NAME: ${{ steps.commit.outputs.tag_name }} + COMMIT_SHA: ${{ steps.commit.outputs.sha }} + run: | + set -e + + # Check if tag already exists + check_tag_exists() { + # Check if tag exists locally + if git tag -l "${TAG_NAME}" | grep -q "${TAG_NAME}"; then + echo "Tag ${TAG_NAME} already exists locally" + return 0 + fi + + # Check if tag exists on remote + if git ls-remote --tags origin "${TAG_NAME}" | grep -q "${TAG_NAME}"; then + echo "Tag ${TAG_NAME} already exists on remote" + return 0 + fi + + return 1 + } + + # Exit early if tag already exists + if check_tag_exists; then + echo "✅ Tag already exists - no action needed" + echo "exists=true" >> "${GITHUB_OUTPUT}" + exit 0 + fi + + echo "Tag ${TAG_NAME} does not exist, proceeding with creation" + + # Retry configuration + MAX_RETRIES=5 + BASE_DELAY=2 + BACKOFF_MULTIPLIER=4 + MAX_DELAY=3600 + + # Common retry function with exponential backoff + retry_with_backoff() { + local command="${1}" + local description="${2}" + local retry_count=0 + + while [ "${retry_count}" -le "${MAX_RETRIES}" ]; do + echo "Attempt $((retry_count + 1))/$((MAX_RETRIES + 1)): ${description}" + + if eval "${command}"; then + echo "Success on attempt $((retry_count + 1))" + return 0 + fi + + retry_count=$((retry_count + 1)) + + if [ "${retry_count}" -le "${MAX_RETRIES}" ]; then + # Calculate delay with exponential backoff + local delay=$((BASE_DELAY * (BACKOFF_MULTIPLIER ** retry_count))) + if [ "${delay}" -gt "${MAX_DELAY}" ]; then + delay="${MAX_DELAY}" + fi + + echo "Failed. Retrying in ${delay} seconds..." + sleep "${delay}" + fi + done + + echo "All retry attempts exhausted" + return 1 + } + + # Function to create and push tag + create_and_push_tag() { + # Create the tag + if ! git tag "${TAG_NAME}" "${COMMIT_SHA}"; then + echo "Failed to create local tag" + return 1 + fi + + # Push the tag + if git push origin "${TAG_NAME}"; then + echo "Successfully created and pushed tag ${TAG_NAME}" + return 0 + else + echo "Failed to push tag to remote" + # Clean up local tag for retry + git tag -d "${TAG_NAME}" 2>/dev/null || true + return 1 + fi + } + + # Function to handle retries with race condition checks + tag_with_retry() { + # Check if tag exists before attempting creation + if check_tag_exists; then + echo "Tag ${TAG_NAME} was created by another process, exiting successfully" + return 0 + fi + + create_and_push_tag || { + # Fetch latest state for next retry + git fetch origin --tags + return 1 + } + } + + # Execute with retry + if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then + echo "exists=false" >> "${GITHUB_OUTPUT}" + exit 0 + else + echo "Tag creation failed after all retry attempts" + exit 1 + fi + + - name: Tag creation summary + if: always() + run: | + if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then + echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" + elif [ "${{ job.status }}" = "success" ]; then + echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" + else + echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" + fi + + echo "" + echo "Tag details:" + echo " Name: ${{ steps.commit.outputs.tag_name }}" + echo " Commit: ${{ steps.commit.outputs.sha }}" + echo " Trigger: ${{ github.event_name }}" + if [ -n "${{ github.event.inputs.commit_sha }}" ]; then + echo " Manual commit: ${{ github.event.inputs.commit_sha }}" + fi diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 9e4aa9237f0764..261c95c507bbbf 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -16,7 +16,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read jobs: llm-td: @@ -45,13 +47,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - libtorch-linux-focal-cuda12_6-py3_10-gcc11-debug-build: - name: libtorch-linux-focal-cuda12.6-py3.10-gcc11-debug + libtorch-linux-jammy-cuda12_8-py3_10-gcc11-debug-build: + name: libtorch-linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: libtorch-linux-focal-cuda12.6-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: libtorch-linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 build-generates-artifacts: false runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.4xlarge" @@ -62,14 +64,14 @@ jobs: secrets: inherit # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated - linux-focal-cuda12_6-py3_10-gcc11-no-ops-build: - name: linux-focal-cuda12.6-py3.10-gcc11-no-ops + linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-cuda12.6-py3.10-gcc11-no-ops - docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-no-ops + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 03e5611226d826..19e169bd973b32 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -102,7 +102,7 @@ jobs: s3-prefix: merges/${{ github.repository }}/${{ github.event.client_payload.pr_num }}/${{ github.event.client_payload.comment_id }}/${{ github.run_id }} path: merge_record.json -# We want newer merge commands to supercede old ones +# We want newer merge commands to supersede old ones concurrency: group: try-merge-${{ github.event.client_payload.pr_num }} cancel-in-progress: true diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index a830a5bc724f8e..59d3590665fda2 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -7,7 +7,7 @@ on: concurrency: group: ${{ github.workflow }} - cancel-in-progress: false + cancel-in-progress: true jobs: do_update_viablestrict: diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 67fa4081041517..c62918b4af210d 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -5,6 +5,10 @@ on: tags: - ciflow/xpu/* workflow_dispatch: + schedule: + # Run 3 times on weekdays and less frequently on weekends. + - cron: 45 0,8,16 * * 1-5 + - cron: 45 4 * * 0,6 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} diff --git a/.gitignore b/.gitignore index 1e82712e718d02..b4e78e642b2450 100644 --- a/.gitignore +++ b/.gitignore @@ -382,3 +382,6 @@ android/pytorch_android_torchvision/.cxx .arcconfig .stable_pyre_client .pyre_client + +# Claude Code local configuration +CLAUDE.local.md diff --git a/.gitmodules b/.gitmodules index bf1fc38c6797d2..4eb6e511127d0b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,10 +2,6 @@ ignore = dirty path = third_party/pybind11 url = https://github.com/pybind/pybind11.git -[submodule "third_party/eigen"] - ignore = dirty - path = third_party/eigen - url = https://gitlab.com/libeigen/eigen.git [submodule "third_party/googletest"] ignore = dirty path = third_party/googletest @@ -133,3 +129,6 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +[submodule "third_party/aiter"] + path = third_party/aiter + url = https://github.com/ROCm/aiter.git diff --git a/.lintrunner.toml b/.lintrunner.toml index 0e13cfc9629c0e..567d8f3fd36f94 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -64,6 +64,7 @@ include_patterns = [ 'aten/src/ATen/xpu/**/*.cpp', 'aten/src/ATen/core/boxing/**/*.h', 'aten/src/ATen/core/dispatch/**/*.h', + 'aten/src/ATen/core/Formatting.cpp', 'aten/src/ATen/native/mps/**/*.metal', 'aten/src/ATen/native/mps/**/*.mm', 'aten/src/ATen/native/mps/**/*.h', @@ -86,6 +87,7 @@ include_patterns = [ 'torch/csrc/**/*.cpp', 'torch/nativert/**/*.h', 'torch/nativert/**/*.cpp', + 'torch/headeronly/**/*.h', 'test/cpp/**/*.h', 'test/cpp/**/*.cpp', ] @@ -120,6 +122,7 @@ is_formatter = true [[linter]] code = 'MYPY' include_patterns = [ + 'setup.py', 'torch/**/*.py', 'torch/**/*.pyi', 'caffe2/**/*.py', @@ -152,21 +155,21 @@ init_command = [ 'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"', 'numpy==2.1.0 ; python_version >= "3.12"', 'expecttest==0.3.0', - 'mypy==1.15.0', + 'mypy==1.16.0', 'sympy==1.13.3', 'types-requests==2.27.25', - 'types-PyYAML==6.0.7', + 'types-pyyaml==6.0.1', 'types-tabulate==0.8.8', 'types-protobuf==5.29.1.20250403', - 'types-pkg-resources==0.1.3', - 'types-Jinja2==2.11.9', + 'types-setuptools==79.0.0.20250422', + 'types-jinja2==2.11.9', 'types-colorama==0.4.6', 'filelock==3.13.1', 'junitparser==2.1.1', 'rich==10.9.0', 'pyyaml==6.0.1', 'optree==0.13.0', - 'dataclasses_json==0.6.7', + 'dataclasses-json==0.6.7', 'pandas==2.2.3', ] @@ -239,6 +242,7 @@ include_patterns = [ 'torch/nativert/*.cpp', 'torch/nativert/**/*.h', 'torch/nativert/**/*.cpp', + 'torch/headeronly/**/*.h', ] exclude_patterns = [ # The negative filters below are to exclude files that include onnx_pb.h or @@ -1131,6 +1135,53 @@ init_command = [ 'PyYAML==6.0.1', ] +[[linter]] +code = 'CODESPELL' +command = [ + 'python3', + 'tools/linter/adapters/codespell_linter.py', + '--', + '@{{PATHSFILE}}' +] +include_patterns = [ + '**', +] +exclude_patterns = [ + # We don't care too much about files in this directory, don't enforce + # spelling on them + 'caffe2/**', + 'fb/**', + '**/fb/**', + 'third_party/**', + 'test/dynamo/cpython/**', + 'torch/_vendor/**', + 'torch/_inductor/fx_passes/serialized_patterns/**', + 'torch/_inductor/autoheuristic/artifacts/**', + 'torch/utils/model_dump/preact.mjs', + # These files are all grandfathered in, feel free to remove from this list + # as necessary + # NOTE: remove the patterns in the order they are listed + 'aten/**', + 'aten/src/ATen/native/**', + 'aten/src/ATen/native/q*/**', + 'aten/src/ATen/native/[a-pA-P]*/**', + 'aten/src/ATen/[a-mA-M]*/**', + 'test/**', + 'test/test_*', + 'test/[a-hA-h]*/**', + 'test/distributed/**', + 'torch/**', + 'torch/_*/**', + 'torch/distributed/tensor/**', +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'codespell[toml]==2.4.1', +] +is_formatter = true + # usort + ruff-format [[linter]] code = 'PYFMT' @@ -1264,10 +1315,6 @@ exclude_patterns = [ 'test/test_unary_ufuncs.py', 'test/test_vulkan.py', 'torch/_awaits/__init__.py', - 'torch/_custom_op/__init__.py', - 'torch/_custom_op/autograd.py', - 'torch/_custom_op/functional.py', - 'torch/_custom_op/impl.py', 'torch/_export/__init__.py', 'torch/_export/constraints.py', 'torch/_export/db/__init__.py', @@ -1305,17 +1352,6 @@ exclude_patterns = [ 'torch/_export/db/examples/type_reflection_method.py', 'torch/_export/db/gen_example.py', 'torch/_export/db/logging.py', - 'torch/_export/error.py', - 'torch/_export/exported_program.py', - 'torch/_export/pass_base.py', - 'torch/_export/pass_infra/__init__.py', - 'torch/_export/pass_infra/node_metadata.py', - 'torch/_export/pass_infra/proxy_value.py', - 'torch/_export/passes/__init__.py', - 'torch/_export/passes/add_runtime_assertions_for_constraints_pass.py', - 'torch/_export/passes/const_prop_pass.py', - 'torch/_export/passes/functionalize_side_effectful_ops_pass.py', - 'torch/_export/passes/replace_sym_size_ops_pass.py', 'torch/testing/_internal/__init__.py', 'torch/testing/_internal/autocast_test_lists.py', 'torch/testing/_internal/autograd_function_db.py', @@ -1355,19 +1391,6 @@ exclude_patterns = [ 'torch/testing/_internal/test_module/__init__.py', 'torch/testing/_internal/test_module/future_div.py', 'torch/testing/_internal/test_module/no_future_div.py', - 'torch/utils/_contextlib.py', - 'torch/utils/_cpp_extension_versioner.py', - 'torch/utils/_crash_handler.py', - 'torch/utils/_device.py', - 'torch/utils/_foreach_utils.py', - 'torch/utils/_freeze.py', - 'torch/utils/_mode_utils.py', - 'torch/utils/_python_dispatch.py', - 'torch/utils/_stats.py', - 'torch/utils/_traceback.py', - 'torch/utils/_zip.py', - 'torch/utils/backcompat/__init__.py', - 'torch/utils/backend_registration.py', 'torch/utils/benchmark/__init__.py', 'torch/utils/benchmark/examples/__init__.py', 'torch/utils/benchmark/examples/compare.py', @@ -1440,8 +1463,8 @@ init_command = [ '--no-black-binary', 'black==23.12.1', 'usort==1.0.8.post1', - 'isort==5.13.2', - 'ruff==0.11.10', # sync with RUFF + 'isort==6.0.1', + 'ruff==0.11.13', # sync with RUFF ] is_formatter = true @@ -1532,7 +1555,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.11.10', # sync with PYFMT + 'ruff==0.11.13', # sync with PYFMT ] is_formatter = true diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 9b22ad8d65e5e4..e6d0ebc6afc1ef 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -1,5 +1,11 @@ { - "recommendations": [ - "ms-python.python", - ] + "recommendations": [ + "ms-python.python", + "charliermarsh.ruff", + "ms-python.flake8", + "ms-python.mypy-type-checker", + "ms-vscode.cmake-tools", + "EditorConfig.EditorConfig", + "streetsidesoftware.code-spell-checker", + ] } diff --git a/.vscode/settings_recommended.json b/.vscode/settings_recommended.json index 551a3ec2a5a389..ca06859b80d389 100644 --- a/.vscode/settings_recommended.json +++ b/.vscode/settings_recommended.json @@ -1,15 +1,53 @@ { - "[python]": { - "editor.tabSize": 4 - }, "files.associations": { + ".clang-format": "yaml", + ".clang-tidy": "yaml", + ".flake8": "ini", + ".coveragerc": "ini", "*.py.in": "python", - "*.pyi.in": "python" + "*.pyi.in": "python", + "*requirements*.txt": "pip-requirements", + "*requirements*.in": "pip-requirements", + "*.cpp.in": "cpp", + "*.h.in": "cpp", + "*.cmake.in": "cmake", + "Makefile.*": "makefile", + "*.Makefile": "makefile", + "BUCK": "starlark", + "BUCK.*": "starlark" }, "files.eol": "\n", "files.insertFinalNewline": true, "files.trimFinalNewlines": true, "files.trimTrailingWhitespace": true, - "python.linting.enabled": true, - "python.linting.flake8Enabled": true + "cmake.preferredGenerators": [ + "Ninja", + "Unix Makefiles" + ], + "cmake.configureEnvironment": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" + }, + "cmake.sourceDirectory": "${workspaceFolder}", + "cmake.buildDirectory": "${workspaceFolder}/build", + "cmake.configureArgs": [ + "-DPython_EXECUTABLE=${workspaceFolder}/venv/bin/python", + "-DPython_ROOT_DIR=${workspaceFolder}/venv" + ], + "[python]": { + "editor.tabSize": 4, + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", + "python.analysis.inlayHints.functionReturnTypes": true, + "flake8.importStrategy": "fromEnvironment", + "flake8.args": [ + "--append-config=${workspaceFolder}/.flake8" + ], + "ruff.importStrategy": "fromEnvironment", + "ruff.lineLength": 88, + "ruff.organizeImports": false, + "ruff.configurationPreference": "filesystemFirst", + "mypy-type-checker.importStrategy": "fromEnvironment", + "mypy-type-checker.preferDaemon": true, + "mypy-type-checker.reportingScope": "workspace" } diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000000..daf0f491702ba7 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +- This is the only AGENTS.md, there are no recursive AGENTS.md diff --git a/BUILD.bazel b/BUILD.bazel index 2d3e1d7cdf72fd..5a31eb6558aa87 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -290,6 +290,7 @@ header_template_rule( substitutions = { "@AT_CUDNN_ENABLED@": "1", "@AT_CUSPARSELT_ENABLED@": "0", + "@AT_HIPSPARSELT_ENABLED@": "0", "@AT_ROCM_ENABLED@": "0", "@AT_MAGMA_ENABLED@": "0", "@NVCC_FLAGS_EXTRA@": "", @@ -499,7 +500,7 @@ filegroup( # To achieve finer granularity and make debug easier, caffe2 is split into three libraries: # ATen, caffe2 and caffe2_for_aten_headers. ATen lib group up source codes under # aten/ directory and caffe2 contains most files under `caffe2/` directory. Since the -# ATen lib and the caffe2 lib would depend on each other, `caffe2_for_aten_headers` is splitted +# ATen lib and the caffe2 lib would depend on each other, `caffe2_for_aten_headers` is split # out from `caffe2` to avoid dependency cycle. cc_library( name = "caffe2_for_aten_headers", @@ -581,9 +582,9 @@ cc_library( cu_library( name = "torch_cuda", srcs = [ - "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", + "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", ], copts = torch_cuda_half_options, visibility = ["//visibility:public"], @@ -736,15 +737,15 @@ cc_library( srcs = if_cuda(glob( libtorch_cuda_sources, exclude = [ - "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", - "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp", - "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", + "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/distributed/c10d/NanCheck.cu", + "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp", + "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", ], )) + torch_sources, copts = TORCH_COPTS, diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b479073f3a6d4..99c0b9e0ea0c99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) # cmake_policy(SET CMP0022 NEW) cmake_policy(SET CMP0023 NEW) # Use compiler ID "AppleClang" instead of "Clang" for XCode. Not setting this @@ -6,6 +6,7 @@ cmake_minimum_required(VERSION 3.18 FATAL_ERROR) # one is detected as "AppleClang". cmake_policy(SET CMP0010 NEW) cmake_policy(SET CMP0025 NEW) +cmake_policy(SET CMP0126 OLD) # Enables CMake to set LTO on compilers other than Intel. cmake_policy(SET CMP0069 NEW) @@ -16,6 +17,8 @@ cmake_policy(SET CMP0069 NEW) # we do this (and we don't if cmake is old), but it's nice when it's possible, # and it's possible on our Windows configs. cmake_policy(SET CMP0092 NEW) +# Don't remove the FindCUDA module +cmake_policy(SET CMP0146 OLD) # Prohibit in-source builds if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR}) @@ -259,12 +262,14 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) -cmake_dependent_option(USE_XCCL "Use XCCL" OFF +cmake_dependent_option(USE_XCCL "Use XCCL" ON "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" OFF) +cmake_dependent_option(USE_NVSHMEM "Use NVSHMEM" ON + "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) option(USE_NNAPI "Use NNAPI" OFF) option(USE_NNPACK "Use NNPACK" ON) cmake_dependent_option(USE_NUMA "Use NUMA. Only available on Linux." ON "LINUX" @@ -526,7 +531,6 @@ if(USE_LIGHTWEIGHT_DISPATCH AND NOT STATIC_DISPATCH_BACKEND) endif() option(TRACING_BASED "Master flag to build Lite Interpreter with tracing build option" OFF) -option(BUILD_EXECUTORCH "Master flag to build Executorch" ON) # This is a fix for a rare build issue on Ubuntu: symbol lookup error: # miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: # mkl_blas_dsyrk @@ -558,6 +562,10 @@ if(MSVC) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") set(CMAKE_NINJA_CMCLDEPS_RC OFF) + if(MSVC_Z7_OVERRIDE) + # CMake set debug flags to use /Z7 + set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT Embedded) + endif() foreach( flag_var CMAKE_C_FLAGS @@ -570,12 +578,6 @@ if(MSVC) CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) - # Replace /Zi and /ZI with /Z7 - if(MSVC_Z7_OVERRIDE) - if(${flag_var} MATCHES "/Z[iI]") - string(REGEX REPLACE "/Z[iI]" "/Z7" ${flag_var} "${${flag_var}}") - endif(${flag_var} MATCHES "/Z[iI]") - endif(MSVC_Z7_OVERRIDE) if(${CAFFE2_USE_MSVC_STATIC_RUNTIME}) if(${flag_var} MATCHES "/MD") @@ -698,7 +700,7 @@ endif() if(USE_KLEIDIAI AND CMAKE_C_COMPILER_VERSION) if(CMAKE_C_COMPILER_VERSION VERSION_LESS 11) set(USE_KLEIDIAI OFF) - message(WARNING "Disabling KleidiAI: Requires atleast GCC 11 or Clang 11") + message(WARNING "Disabling KleidiAI: Requires at least GCC 11 or Clang 11") endif() endif() @@ -983,6 +985,9 @@ endif() # ---[ Build flags Re-include to override append_cxx_flag_if_supported from # third_party/FBGEMM include(cmake/public/utils.cmake) +if(USE_COLORIZE_OUTPUT) + set(CMAKE_COLOR_DIAGNOSTICS ON) +endif() if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -O2 -fPIC") @@ -1058,19 +1063,6 @@ if(NOT MSVC) CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Qunused-arguments" CMAKE_CXX_FLAGS) - if(${USE_COLORIZE_OUTPUT}) - # Why compiler checks are necessary even when `try_compile` is used Because - # of the bug in ccache that can incorrectly identify `-fcolor-diagnostics` - # As supported by GCC, see https://github.com/ccache/ccache/issues/740 (for - # older ccache) and https://github.com/ccache/ccache/issues/1275 (for newer - # ones) - if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - append_cxx_flag_if_supported("-fdiagnostics-color=always" CMAKE_CXX_FLAGS) - else() - append_cxx_flag_if_supported("-fcolor-diagnostics" CMAKE_CXX_FLAGS) - endif() - endif() - append_cxx_flag_if_supported("-faligned-new" CMAKE_CXX_FLAGS) if(WERROR) @@ -1256,7 +1248,7 @@ endif() add_subdirectory(c10) add_subdirectory(caffe2) -# ---[ CMake related files Uninistall option. +# ---[ CMake related files Uninstall option. if(NOT TARGET caffe2_uninstall) configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/cmake_uninstall.cmake.in diff --git a/CODEOWNERS b/CODEOWNERS index 9eb3a858272c68..2982b405c3df44 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,7 +7,7 @@ # Each line is a file pattern followed by one or more owners. # For module labels => owners mapping, please see https://github.com/pytorch/pytorch/issues/24422. -/torch/utils/cpp_extension.py @fmassa @soumith @ezyang +/torch/utils/cpp_extension.py @fmassa @ezyang @malfet # Not there to strictly require the approval, but to be tagged as a reviewer # on the PRs to push them into a high priority inbox. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6361a5ddc2128e..9d677901170c6d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -228,6 +228,8 @@ dependencies as well as the nightly binaries into the repo directory. details. * [cuda](aten/src/ATen/native/cuda) - CUDA implementations of operators. + * [mps](aten/src/ATen/native/mps) - MPS implementations of + operators for Apple's Metal GPU family. * [sparse](aten/src/ATen/native/sparse) - CPU and CUDA implementations of COO sparse tensor operations * [mkl](aten/src/ATen/native/mkl) [mkldnn](aten/src/ATen/native/mkldnn) @@ -343,13 +345,7 @@ command runs tests such as `TestNN.test_BCELoss` and ### Local linting -Install all prerequisites by running - -```bash -make setup-lint -``` - -You can now run the same linting steps that are used in CI locally via `make`: +You can run the same linting steps that are used in CI locally via `make`: ```bash make lint @@ -473,7 +469,7 @@ In addition to the standard Google Style docstring formatting rules, the followi ### Building documentation -To build the documentation: +Note that the docs will only build with Python versions <3.13. To build the documentation: 1. Build and install PyTorch @@ -583,9 +579,8 @@ rsync -az me@my_machine:/path/to/pytorch/docs/cpp/build/html cpp/build ### Previewing documentation on PRs -PyTorch will host documentation previews at `https://docs-preview.pytorch.org/pytorch/pytorch//index.html` once the -`pytorch_python_doc_build` GitHub Actions job has completed on your PR. You can visit that page directly -or find its link in the automated Dr. CI comment on your PR. +PyTorch will host documentation previews at `https://docs-preview.pytorch.org/pytorch/pytorch//index.html` once the docs GitHub Actions job has completed on your PR. You can find its link in the automated pytorchbot comment on your PR or go to the URL +directly. ### Adding documentation tests diff --git a/Dockerfile b/Dockerfile index 5cec2173063be1..48d014ea1b9e77 100644 --- a/Dockerfile +++ b/Dockerfile @@ -70,7 +70,7 @@ RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} ARG TARGETPLATFORM -# INSTALL_CHANNEL whl - release, whl/nightly - nightly, whle/test - test channels +# INSTALL_CHANNEL whl - release, whl/nightly - nightly, whl/test - test channels RUN case ${TARGETPLATFORM} in \ "linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio ;; \ *) pip install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${CUDA_PATH#.}/ torch torchvision torchaudio ;; \ diff --git a/MANIFEST.in b/MANIFEST.in index f6ffb4e02a8afb..ec00f251160b7b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,31 +1,50 @@ -include MANIFEST.in +# Reference: https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html + +# Include source files in SDist include CMakeLists.txt +include *.bzl *.bazel .bazel* BUILD *.BUILD BUILD.* WORKSPACE +include BUCK BUCK.* +include requirements*.txt +include version.txt +include [Mm]akefile *.[Mm]akefile [Mm]akefile.* +include [Dd]ockerfile *.[Dd]ockerfile [Dd]ockerfile.* .dockerignore +graft android +graft aten +graft binaries +graft c10 +graft caffe2 +graft cmake +graft functorch +graft third_party +graft tools +graft torch +graft torchgen +# FIXME: torch-xla build during codegen will fail if include this file in wheel +exclude torchgen/BUILD.bazel + +# Misc files and directories in SDist +include *.md include CITATION.cff -include LICENSE -include NOTICE +include LICENSE NOTICE +include mypy*.ini +graft benchmarks +graft docs +graft mypy_plugins +graft scripts + +# Misc files needed for custom setuptools command +include .gitignore include .gitmodules -include build_variables.bzl -include mypy.ini -include requirements.txt -include ufunc_defs.bzl -include version.txt -recursive-include android *.* -recursive-include aten *.* -recursive-include binaries *.* -recursive-include c10 *.* -recursive-include caffe2 *.* -recursive-include cmake *.* -recursive-include torch *.* -recursive-include tools *.* -recursive-include test *.* -recursive-include docs *.* -recursive-include ios *.* -recursive-include third_party * -recursive-include test *.* -recursive-include benchmarks *.* -recursive-include scripts *.* -recursive-include mypy_plugins *.* -recursive-include modules *.* -recursive-include functorch *.* + +# Include test suites in SDist +graft test +include pytest.ini +include .coveragerc + +# Prune generated/compiled files +prune torchgen/packaged prune */__pycache__ -global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp +global-exclude *.o *.obj *.so *.a *.dylib *.pxd *.dll *.lib *.py[cod] + +prune */.git +global-exclude .git *~ *.swp diff --git a/Makefile b/Makefile index e5b4386b5dd225..3db2b7aa44e76f 100644 --- a/Makefile +++ b/Makefile @@ -1,59 +1,92 @@ # This makefile does nothing but delegating the actual building to cmake. -PYTHON = python3 -PIP = $(PYTHON) -m pip + +SHELL = /bin/bash +.SHELLFLAGS := -eu -o pipefail -c +PYTHON ?= $(shell command -v python3 || command -v python) +PIP = $(PYTHON) -m pip NIGHTLY_TOOL_OPTS := pull +.PHONY: all all: - @mkdir -p build && cd build && cmake .. $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && $(MAKE) + @cmake -S . -B build $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && \ + cmake --build build --parallel -- +.PHONY: local local: @./scripts/build_local.sh +.PHONY: android android: @./scripts/build_android.sh +.PHONY: ios ios: @./scripts/build_ios.sh +.PHONY: triton +triton: + $(PIP) uninstall -y triton + @./scripts/install_triton_wheel.sh + +.PHONY: clean clean: # This will remove ALL build folders. - @rm -r build*/ + @rm -r build*/ || true +.PHONY: linecount linecount: @cloc --read-lang-def=caffe.cloc caffe2 || \ echo "Cloc is not available on the machine. You can install cloc with " && \ echo " sudo apt-get install cloc" +.PHONY: ensure-branch-clean ensure-branch-clean: @if [ -n "$(shell git status --porcelain)" ]; then \ echo "Please commit or stash all changes before running this script"; \ exit 1; \ fi +.PHONY: setup-env setup-env: ensure-branch-clean $(PYTHON) tools/nightly.py $(NIGHTLY_TOOL_OPTS) +.PHONY: setup-env-cuda setup-env-cuda: $(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --cuda" +.PHONY: setup-env-rocm setup-env-rocm: $(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm" -setup_env: setup-env -setup_env_cuda: setup-env-cuda -setup_env_rocm: setup-env-rocm - -setup-lint: +.PHONY: setup-lint +setup-lint .lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml + @echo "Setting up lintrunner..." $(PIP) install lintrunner lintrunner init + @echo "Generating .lintrunner.sha256..." + @mkdir -p .lintbin + @sha256sum requirements.txt pyproject.toml .lintrunner.toml > .lintbin/.lintrunner.sha256 + +.PHONY: lazy-setup-lint +lazy-setup-lint: .lintbin/.lintrunner.sha256 + @if [ ! -x "$(shell command -v lintrunner)" ]; then \ + $(MAKE) setup-lint; \ + fi -setup_lint: setup-lint +.PHONY: lint +lint: lazy-setup-lint + lintrunner --all-files -lint: +.PHONY: quicklint +quicklint: lazy-setup-lint lintrunner -quicklint: - lintrunner +.PHONY: quickfix +quickfix: lazy-setup-lint + lintrunner --apply-patches -triton: - $(PIP) uninstall -y triton - @./scripts/install_triton_wheel.sh +# Deprecated target aliases +.PHONY: setup_env setup_env_cuda setup_env_rocm setup_lint +setup_env: setup-env +setup_env_cuda: setup-env-cuda +setup_env_rocm: setup-env-rocm +setup_lint: setup-lint diff --git a/README.md b/README.md index 15bf60977cce33..561495692feb56 100644 --- a/README.md +++ b/README.md @@ -189,16 +189,23 @@ $ conda activate $ call "C:\Program Files\Microsoft Visual Studio\\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 ``` +A conda environment is not required. You can also do a PyTorch build in a +standard virtual environment, e.g., created with tools like `uv`, provided +your system has installed all the necessary dependencies unavailable as pip +packages (e.g., CUDA, MKL.) + ##### NVIDIA CUDA Support If you want to compile with CUDA support, [select a supported version of CUDA from our support matrix](https://pytorch.org/get-started/locally/), then install the following: - [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) - [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v8.5 or above - [Compiler](https://gist.github.com/ax3l/9489132) compatible with CUDA -Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html) for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardware +Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html) for cuDNN versions with the various supported CUDA, CUDA driver, and NVIDIA hardware. If you want to disable CUDA support, export the environment variable `USE_CUDA=0`. -Other potentially useful environment variables may be found in `setup.py`. +Other potentially useful environment variables may be found in `setup.py`. If +CUDA is installed in a non-standard location, set PATH so that the nvcc you +want to use can be found (e.g., `export PATH=/usr/local/cuda-12.8/bin:$PATH`). If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), Instructions to install PyTorch for Jetson Nano are [available here](https://devtalk.nvidia.com/default/topic/1049071/jetson-nano/pytorch-for-jetson-nano/) @@ -377,14 +384,14 @@ with such a step. On Linux ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" -python setup.py build --cmake-only +CMAKE_ONLY=1 python setup.py build ccmake build # or cmake-gui build ``` On macOS ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" -MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ CMAKE_ONLY=1 python setup.py build ccmake build # or cmake-gui build ``` @@ -550,7 +557,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. -A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jamesb93), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). +A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. diff --git a/RELEASE.md b/RELEASE.md index f1f83190e001cd..047bb10161f71c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -50,6 +50,7 @@ Following is the Release Compatibility Matrix for PyTorch releases: | PyTorch version | Python | C++ | Stable CUDA | Experimental CUDA | Stable ROCm | | --- | --- | --- | --- | --- | --- | +| 2.8 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 12.6 (CUDNN 9.10.2.21), CUDA 12.8 (CUDNN 9.10.2.21) | CUDA 12.9 (CUDNN 9.10.2.21) | ROCm 6.4 | | 2.7 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 11.8 (CUDNN 9.1.0.70), CUDA 12.6 (CUDNN 9.5.1.17) | CUDA 12.8 (CUDNN 9.7.1.26) | ROCm 6.3 | | 2.6 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 11.8, CUDA 12.4 (CUDNN 9.1.0.70) | CUDA 12.6 (CUDNN 9.5.1.17) | ROCm 6.2.4 | | 2.5 | >=3.9, <=3.12, (3.13 experimental) | C++17 | CUDA 11.8, CUDA 12.1, CUDA 12.4, CUDNN 9.1.0.70 | None | ROCm 6.2 | @@ -73,9 +74,9 @@ Following is the release cadence. All future dates below are tentative. For late | 2.4 | Jun 2024 | Jul 2024 | Sept 2024 | Not planned | | 2.5 | Sep 2024 | Oct 2024 | Nov 2024 | Not planned | | 2.6 | Dec 2024 | Jan 2025 | Not planned | Not planned | -| 2.7 | Mar 2025 | Apr 2025 | (May 2025) | (Jun 2025) | +| 2.7 | Mar 2025 | Apr 2025 | Jun 2025 | Not planned | | 2.8 | Jun 2025 | Jul 2025 | (Aug 2025) | (Sep 2025) | -| 2.9 | Aug 2025 | Oct 2025 | (Nov 2025) | (Dec 2025) | +| 2.9 | Sept 2025 | Oct 2025 | (Nov 2025) | (Dec 2025) | | 2.10 | Dec 2025 | Jan 2026 | (Feb 2026) | (Mar 2026) | | 2.11 | Mar 2026 | Apr 2026 | (Jun 2026) | (Jul 2026) | diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index bda6aea327062f..d787d0850ab67b 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -18,6 +18,7 @@ cmake_policy(SET CMP0012 NEW) ############################################# set(ATen_CPU_SRCS) +set(ATen_MTIA_SRCS) set(ATen_XPU_SRCS) set(ATen_XPU_INCLUDE) set(ATen_CPU_TEST_SRCS) @@ -101,6 +102,13 @@ else() set(AT_CUSPARSELT_ENABLED 1) endif() +# Add hipSPARSELt support flag if the package is available. +if(USE_ROCM AND hipsparselt_FOUND) + set(AT_HIPSPARSELT_ENABLED 1) +else() + set(AT_HIPSPARSELT_ENABLED 0) +endif() + list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/src) add_subdirectory(src/ATen) @@ -108,6 +116,7 @@ add_subdirectory(src/ATen) # Pass source, includes, and libs to parent set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) +set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index bab044d3c01812..af8fea2529477f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) if(NOT MSVC) @@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA) set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA) set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN) set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT) +set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT) configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h") # TODO: Do not generate CUDAConfig.h for ROCm BUILDS @@ -65,6 +66,12 @@ file(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh") file(GLOB cudnn_cpp "cudnn/*.cpp") file(GLOB ops_h "ops/*.h") +# MTIA +file(GLOB mtia_h "mtia/*.h" "mtia/detail/*.h") +file(GLOB mtia_cpp "mtia/*.cpp" "mtia/detail/*.cpp") +file(GLOB_RECURSE native_mtia_cpp "native/mtia/*.cpp") +file(GLOB_RECURSE native_mtia_h "native/mtia/*.h") + file(GLOB xpu_h "xpu/*.h" "xpu/detail/*.h") file(GLOB xpu_cpp "xpu/*.cpp" "xpu/detail/*.cpp") @@ -162,14 +169,10 @@ file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip") file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp") file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") - -# flash_attention sources file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu) -# Flash attention C++ sources -file(GLOB flash_attention_cuda_cpp - "${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp" - "native/transformers/cuda/flash_attn/flash_api.cpp" -) +file(GLOB flash_attention_cuda_cpp ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp) +file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_api.cpp") + # flash_attention hip sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") @@ -190,6 +193,10 @@ if(USE_FLASH_ATTENTION) add_subdirectory(native/transformers/hip/flash_attn/ck) file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + # FAv3 Generation + add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) + file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) endif() endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") @@ -201,10 +208,29 @@ file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/ file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu") file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention/*.cpp") +if(USE_CUDA AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + add_library(flash_attention OBJECT EXCLUDE_FROM_ALL ${flash_attention_cuda_kernels_cu} ${flash_attention_cuda_cpp}) + + target_include_directories(flash_attention PUBLIC + ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc + ${PROJECT_SOURCE_DIR}/third_party/flash-attention/include + ${PROJECT_SOURCE_DIR}/third_party/cutlass/include + ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src + ) + + target_compile_definitions(flash_attention PRIVATE + # Copied from https://github.com/pytorch/pytorch/blob/a10024d7dea47c52469059a47efe376eb20adca0/caffe2/CMakeLists.txt#L1431 + FLASH_NAMESPACE=pytorch_flash + FLASHATTENTION_DISABLE_ALIBI + FLASHATTENTION_DISABLE_SOFTCAP + UNFUSE_FMA + ) + + set_target_properties(flash_attention PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() + if(USE_FLASH_ATTENTION) - list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu}) - list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_kernels_cu}) - list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp}) + list(APPEND native_transformers_cuda_cpp ${native_flash_attn_api_cpp}) list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu}) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu}) @@ -269,6 +295,10 @@ else() set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp}) endif() +if(USE_MTIA) + set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} ${mtia_cpp} ${mtia_h} ${native_mtia_cpp} ${native_mtia_h}) +endif() + if(USE_XPU) list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp}) list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn) @@ -366,6 +396,7 @@ if(USE_ROCM) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) _pytorch_rocm_generate_ck_conf() # Next two lines are needed because TunableOp uses third-party/fmt @@ -677,7 +708,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${native_mtia_h} ${cudnn_h} ${hip_h} ${mtia_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) @@ -748,6 +779,7 @@ set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SC set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE) +set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index fd346b2d9af005..cac0e31eaad46c 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -19,9 +19,69 @@ #if defined(__aarch64__) && !defined(C10_MOBILE) #include #endif - namespace at { +namespace { + +/* + These const variables defined the fp32 precisions for different backend + We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 + prevision from "ieee", "tf32", "bf16" and "none". The "ieee" precision means + IEEE standard floating point format "tf32" and "bf16" means we are allowed to + use "tf32" or "bf16" as internal computation data types for fp32 computations. + And "none" means it is override-able by parent's node + + generic->mkldnn->matmul + ->conv + ->rnn + ->cuda ->matmul + ->conv + ->rnn +*/ +const std::map> _fp32_precisions = { + {"generic", {{"ieee", "tf32", "bf16", "none"}}}, + {"mkldnn", {{"ieee", "bf16", "none"}}}, + {"cuda", {{"ieee", "tf32", "none"}}}}; + +// Check whether the backend and op are legal +void check_fp32_prec_backend_and_op( + const std::string& backend, + const std::string& op) { + static std::vector backends = {"generic", "mkldnn", "cuda"}; + static std::vector operators = {"conv", "matmul", "rnn", "all"}; + TORCH_CHECK( + std::find(backends.begin(), backends.end(), backend) != backends.end(), + "Invalid backend: ", + backend); + TORCH_CHECK( + std::find(operators.begin(), operators.end(), op) != operators.end(), + "Invalid operator: ", + op); + if (backend == "generic") { + TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op); + } + } + + // Return whether the precision is supported by backends + bool validate_fp32_prec( + const std::string& backend, + const std::string& precision) { + auto iterp = _fp32_precisions.find(backend); + TORCH_CHECK(iterp != _fp32_precisions.end()); + auto precisions = iterp->second; + bool valid = std::find(precisions.begin(), precisions.end(), precision) != + precisions.end(); + return valid; + } + + C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ + TORCH_WARN_ONCE( + "This API is going to be deprecated, please see " + "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" + ); + } +} // namespace + Context::Context() = default; // TODO: This could be bad juju if someone calls globalContext() in the @@ -115,12 +175,29 @@ void Context::setUserEnabledNNPACK(bool e) { enabled_nnpack = e; } -bool Context::allowTF32CuDNN() const { +bool Context::allowTF32CuDNN(const std::string& op) const { + if (op.size() == 0){ + bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32"; + bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32"; + TORCH_CHECK( + allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn, + "PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,", + "but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags.", + "This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ", + "We suggest only using the new API to set the TF32 flag(s). See also: ", + "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); + } else { + return float32Precision("cuda", op) == "tf32"; + } + warn_deprecated_fp32_precision_api(); return allow_tf32_cudnn; } void Context::setAllowTF32CuDNN(bool b) { + setFloat32Precision("cuda", "rnn", b ? "tf32" : "none"); + setFloat32Precision("cuda", "conv", b ? "tf32" : "none"); allow_tf32_cudnn = b; + warn_deprecated_fp32_precision_api(); } void Context::setSDPPriorityOrder(const std::vector& order) { @@ -141,12 +218,13 @@ bool Context::allowTF32OneDNN() const { return allow_tf32_onednn; } -void Context::setAllowTF32OneDNN(bool b){ -#ifdef USE_XPU + // NOLINTNEXTLINE(clang-diagnostic-unused-parameter) + void Context::setAllowTF32OneDNN(bool b){ + #ifdef USE_XPU allow_tf32_onednn = b; -#else + #else TORCH_WARN("TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support."); -#endif + #endif } bool Context::userEnabledFlashSDP() const { @@ -259,7 +337,16 @@ bool Context::allowTF32CuBLAS() const { return false; } #endif - return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; + bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; + bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32"; + TORCH_CHECK( + legacy_allow_tf32 == allow_tf32_new, + "PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,", + "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ", + "We suggest only using the new API to set the TF32 flag. See also: ", + "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); + warn_deprecated_fp32_precision_api(); + return allow_tf32_new; } void Context::setAllowTF32CuBLAS(bool b) { @@ -272,27 +359,54 @@ void Context::setAllowTF32CuBLAS(bool b) { } #endif float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST; + setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee"); } Float32MatmulPrecision Context::float32MatmulPrecision() const { + bool invalid = float32Precision("cuda", "matmul") == "tf32" && + float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST; + invalid = invalid || + (float32Precision("mkldnn", "matmul") == "bf16" && + float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM); + TORCH_CHECK( + !invalid, + "PyTorch is checking the matmul precision without a specific backend name,", + "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ", + "We suggest only using the new API for matmul precision. See also: ", + "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); + warn_deprecated_fp32_precision_api(); return float32_matmul_precision; } -void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) { - float32_matmul_precision = p; +std::string Context::float32Precision(const std::string& backend, const std::string& op) const { + check_fp32_prec_backend_and_op(backend, op); + auto precision = fp32_precision.find(backend)->second.find(op)->second; + if (precision == "none") + precision = fp32_precision.find(backend)->second.find("all")->second; + if (precision == "none") + precision = fp32_precision.find("generic")->second.find("all")->second; + bool valid_prec = validate_fp32_prec(backend, precision); + return valid_prec ? precision : "none"; } void Context::setFloat32MatmulPrecision(const std::string &s) { auto match = [this](const std::string & s_) { + warn_deprecated_fp32_precision_api(); // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention if (s_ == "highest") { float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; + setFloat32Precision("cuda", "matmul", "ieee"); + setFloat32Precision("mkldnn", "matmul", "ieee"); return true; } else if (s_ == "high") { float32_matmul_precision = at::Float32MatmulPrecision::HIGH; + setFloat32Precision("cuda", "matmul", "tf32"); + setFloat32Precision("mkldnn", "matmul", "ieee"); return true; } else if (s_ == "medium") { float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; + setFloat32Precision("cuda", "matmul", "tf32"); + setFloat32Precision("mkldnn", "matmul", "bf16"); return true; } return false; @@ -306,6 +420,27 @@ void Context::setFloat32MatmulPrecision(const std::string &s) { "setFloat32MatmulPrecision call has no effect."); } +void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) { + check_fp32_prec_backend_and_op(backend, op); + if (validate_fp32_prec(backend, p)) { + fp32_precision[backend][op] = p; + } else { + std::string msg; + auto iterp = _fp32_precisions.find(backend); + TORCH_CHECK(iterp != _fp32_precisions.end()); + for (auto p : iterp->second) { + msg += p; + msg += " "; + } + TORCH_WARN( + "you have set wrong precision for backend:", + backend, + " setFloat32Precision call has no effect.", + "Please choose precision from: ", + msg); + } +} + at::LinalgBackend Context::linalgPreferredBackend() const { return linalg_preferred_backend; } @@ -535,13 +670,14 @@ at::QEngine Context::qEngine() const { #endif return qengine; }(); - return quantized_engine.value_or(_quantized_engine); + auto qt_engine = quantized_engine.load(); + return qt_engine == at::QEngine::NoQEngine ? _quantized_engine : qt_engine; } void Context::setQEngine(at::QEngine e) { const auto& qengines = supportedQEngines(); if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) { - quantized_engine = e; + quantized_engine.store(e); return; } TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported"); @@ -553,17 +689,9 @@ const std::vector& Context::supportedQEngines() { // Engines are listed in priority order: later one wins // By default we prefer FBGEMM if we're running on server side // QNNPACK on server side has some issue, so we disable it by default. -#ifdef C10_MOBILE - engines.push_back(at::kNoQEngine); -#ifdef USE_PYTORCH_QNNPACK - engines.push_back(at::kQNNPACK); -#endif -#else // C10_MOBILE #ifdef USE_PYTORCH_QNNPACK engines.push_back(at::kQNNPACK); #endif - engines.push_back(at::kNoQEngine); -#endif // C10_MOBILE #if AT_MKLDNN_ENABLED() engines.push_back(at::kONEDNN); @@ -695,6 +823,7 @@ void Context::setAllowFP16ReductionCPU(bool b) { #if defined(__aarch64__) && !defined(C10_MOBILE) if (!cpuinfo_initialize() || !cpuinfo_has_arm_fp16_arith()) #else + // NOLINTNEXTLINE(facebook-hte-MissingBraces) if (true) #endif TORCH_CHECK(false, "Float16 arithmetic is not supported by the CPU!"); diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 6de119c2c63bf6..04a35e06630a79 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -28,6 +28,7 @@ #include #include +#include #include namespace at { @@ -336,14 +337,20 @@ class TORCH_API Context { void alertCuBLASConfigNotDeterministic() const; void setFloat32MatmulPrecision(const std::string& s); - bool allowTF32CuDNN() const; + void setFloat32Precision( + const std::string& backend, + const std::string& op, + const std::string& s); + bool allowTF32CuDNN(const std::string& op = std::string()) const; void setAllowTF32CuDNN(bool); bool allowTF32OneDNN() const; void setAllowTF32OneDNN(bool); bool allowTF32CuBLAS() const; void setAllowTF32CuBLAS(bool); Float32MatmulPrecision float32MatmulPrecision() const; - void setFloat32MatmulPrecision(Float32MatmulPrecision p); + std::string float32Precision( + const std::string& backend, + const std::string& op) const; bool allowFP16ReductionCuBLAS() const; void setAllowFP16ReductionCuBLAS(bool); bool allowBF16ReductionCuBLAS() const; @@ -465,10 +472,27 @@ class TORCH_API Context { bool release_original_weights = false; #endif bool display_vmap_fallback_warnings_ = false; - std::optional quantized_engine = std::nullopt; + std::atomic quantized_engine = at::QEngine::NoQEngine; bool enable_sparse_tensor_invariant_checks = false; bool allow_fp16_reduction_cpu = false; + std::map> fp32_precision = { + {"generic", {{"all", "none"}}}, + {"mkldnn", + {{"matmul", "none"}, + {"conv", "none"}, + {"rnn", "none"}, + {"all", "none"}}}, + {"cuda", + {{"matmul", + float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST + ? "none" + : "tf32"}, + {"conv", "tf32"}, + {"rnn", "tf32"}, + {"all", "none"}}}, + }; + Allocator* prev_allocator_ptr_{nullptr}; }; diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 29f7e7351b667d..f25e68001ff4dd 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -266,19 +266,38 @@ ScalarType toScalarType(const DLDataType& dtype) { } namespace { + +// The templated classes below are needed for supporting both: +// - DLManagedTensor +// - DLManagedTensorVersioned +template struct ATenDLMTensor { Tensor handle; - DLManagedTensor tensor{}; + T tensor{}; }; -} // namespace -static void deleter(DLManagedTensor* arg) { - delete static_cast(arg->manager_ctx); +template +void deleter(T* arg) { + delete static_cast*>(arg->manager_ctx); +} + +// Adds version information for DLManagedTensorVersioned. +// This is a no-op for the other types. +template +void fillVersion(T* tensor) {} + +template <> +void fillVersion( + DLManagedTensorVersioned* tensor) { + tensor->flags = 0; + tensor->version.major = DLPACK_MAJOR_VERSION; + tensor->version.minor = DLPACK_MINOR_VERSION; } // This function returns a shared_ptr to memory managed DLpack tensor // constructed out of ATen tensor -DLManagedTensor* toDLPack(const Tensor& src) { +template +T* toDLPackImpl(const Tensor& src) { // create a new tensor with possibly normalized strides // gh-83069 auto shape = src.sizes(); @@ -290,10 +309,10 @@ DLManagedTensor* toDLPack(const Tensor& src) { } auto view = src.as_strided(shape, strides, src.storage_offset()); - ATenDLMTensor* atDLMTensor(new ATenDLMTensor); + ATenDLMTensor* atDLMTensor(new ATenDLMTensor); atDLMTensor->handle = view; atDLMTensor->tensor.manager_ctx = atDLMTensor; - atDLMTensor->tensor.deleter = &deleter; + atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); c10::DeviceIndex device_id = 0; if (src.is_cuda() || src.is_privateuseone()) { @@ -305,35 +324,68 @@ DLManagedTensor* toDLPack(const Tensor& src) { atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); atDLMTensor->tensor.dl_tensor.strides = view.strides().data(); atDLMTensor->tensor.dl_tensor.byte_offset = 0; + fillVersion(&atDLMTensor->tensor); + return &(atDLMTensor->tensor); } -Tensor fromDLPack(DLManagedTensor* src) { - auto deleter = [src](void* self [[maybe_unused]]) { - if (src->deleter) { - src->deleter(src); - } - }; - return fromDLPack(src, std::move(deleter)); -} +// Explicitly instantiate the template above for both classes. +template DLManagedTensor* toDLPackImpl(const Tensor&); +template DLManagedTensorVersioned* toDLPackImpl(const Tensor&); -Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { - Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data); - ScalarType stype = toScalarType(src->dl_tensor.dtype); - if (!src->dl_tensor.strides) { +// This function constructs a Tensor from a memory managed DLPack which +// may be represented as either: DLManagedTensor and DLManagedTensorVersioned. +template +at::Tensor fromDLPackImpl(T* src, std::function deleter) { + if (!deleter) { + deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + } + + DLTensor& dl_tensor = src->dl_tensor; + Device device = getATenDevice(dl_tensor.device, dl_tensor.data); + ScalarType stype = toScalarType(dl_tensor.dtype); + + if (!dl_tensor.strides) { return at::from_blob( - src->dl_tensor.data, - IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), std::move(deleter), at::device(device).dtype(stype), {device}); } return at::from_blob( - src->dl_tensor.data, - IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), - IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim), + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + IntArrayRef(dl_tensor.strides, dl_tensor.ndim), deleter, at::device(device).dtype(stype), {device}); } + +// Explicitly instantiate the template above for both classes. +template at::Tensor fromDLPackImpl(DLManagedTensor* src, std::function deleter); +template at::Tensor fromDLPackImpl(DLManagedTensorVersioned* src, std::function deleter); + +} // namespace + +DLManagedTensor* toDLPack(const Tensor& src) { + return toDLPackImpl(src); +} + +DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src) { + return toDLPackImpl(src); +} + +Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { + return fromDLPackImpl(src, std::move(deleter)); +} + +Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function deleter) { + return fromDLPackImpl(src, std::move(deleter)); +} + } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index d43d189002a3f8..abc996db5ab462 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -12,10 +12,48 @@ namespace at { TORCH_API ScalarType toScalarType(const DLDataType& dtype); TORCH_API DLManagedTensor* toDLPack(const Tensor& src); -TORCH_API Tensor fromDLPack(DLManagedTensor* src); +TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src); TORCH_API Tensor -fromDLPack(DLManagedTensor* src, std::function deleter); +fromDLPack(DLManagedTensor* src, std::function deleter = {}); +TORCH_API Tensor fromDLPackVersioned( + DLManagedTensorVersioned* src, + std::function deleter = {}); TORCH_API DLDataType getDLDataType(const Tensor& t); TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); +// This trait class is used for retrieving different attributes, such as the +// PyCapsule names and conversion functions for both DLPack tensor classes: +// `DLManagedTensor` and `DLManagedTensorVersioned`. +// +// Each specialization should contain the following 2 traits: +// - `capsule`: actual name of the capsule +// - `used`: name of the capsule after using it +// - `toDLPack`: function for converting a tensor into a DLPack capsule +// - `fromDLPack`: function for creating a tensor from a DLPack capsule +// +// While `toDLPack` is the directly exposed to Python, `fromDLPack` is not. +// Although it contains the core implementation, it lacks the required book +// keeping logic contained in its caller `tensor_fromDLPack`. +// +// That said, `fromDLPack` is used directly in a few DLPack tests that live +// inside ATen (no Python available). +template +struct DLPackTraits {}; + +template <> +struct DLPackTraits { + inline static const char* capsule = "dltensor"; + inline static const char* used = "used_dltensor"; + inline static auto toDLPack = at::toDLPack; + inline static auto fromDLPack = at::fromDLPack; +}; + +template <> +struct DLPackTraits { + inline static const char* capsule = "dltensor_versioned"; + inline static const char* used = "used_dltensor_versioned"; + inline static auto toDLPack = at::toDLPackVersioned; + inline static auto fromDLPack = at::fromDLPackVersioned; +}; + } // namespace at diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 5c7b39c6427a86..15a862274f003a 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -200,7 +200,7 @@ inline at::ScalarType scalar_type(at::ScalarType s) { switch (_st) { \ __VA_ARGS__ \ default: \ - TORCH_CHECK( \ + TORCH_CHECK_NOT_IMPLEMENTED( \ false, \ '"', \ at_dispatch_name, \ diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index b124ac39001474..5634733325a2ef 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { opt_device_type = at::getAccelerator(false); } if (opt_device_type.has_value()) { - return at::globalContext().getPinnedMemoryAllocator( - opt_device_type.value()); + return at::globalContext().getPinnedMemoryAllocator(opt_device_type); } else { TORCH_CHECK( false, "Need to provide pin_memory allocator to use pin memory.") diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index e9abc85b59c307..090699339ccffc 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -461,9 +461,17 @@ inline Tensor _sum_to( reduce_dims.push_back(i); } for (int64_t i = leading_dims; i < static_cast(sizes.size()); ++i) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) && - TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) { + if (TORCH_GUARD_OR_FALSE(sym_eq(shape[i - leading_dims], 1)) && + TORCH_GUARD_OR_TRUE(sym_ne(sizes[i], 1))) { reduce_dims.push_back(i); + } else { + // if we assume no reduction due to unbacked we ensure that at runtime. + TORCH_MAYBE_SYM_CHECK( + sym_eq(shape[i - leading_dims], sizes[i]), + "non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:", + shape[i - leading_dims], + ", ", + sizes[i]) } } diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 3f80171196fbc7..8cd1cb7434aa3f 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -122,6 +122,9 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { ~FunctionalStorageImpl() override = default; + uint64_t mutation_counter() { + return mutation_counter_; + } void mark_mutation() { mutation_counter_++; } @@ -150,12 +153,17 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { void mark_inductor_storage_resize(c10::SymInt new_size) { inductor_storage_resized_ = true; curr_storage_size_ = std::move(new_size); + inductor_storage_resized_counter_++; } bool was_inductor_storage_resized() { return inductor_storage_resized_; } + uint64_t inductor_storage_resized_counter() { + return inductor_storage_resized_counter_; + } + private: // NB: base_ should always point to a tensor BELOW the current // functionalization layer. This is mainly to avoid reference cycles. e.g. @@ -201,6 +209,7 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // (1) There were any storage resizes on a graph input // (2) The original/curr storage size tell us if these resizes result in a nop bool inductor_storage_resized_ = false; + uint64_t inductor_storage_resized_counter_ = 0; c10::SymInt original_storage_size_; c10::SymInt curr_storage_size_; }; diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index a634dea4557c12..ff4e2b562278b2 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -178,7 +178,7 @@ bool FunctionalTensorWrapper::is_up_to_date() const { // See Note [Functionalization Pass - Inplace View Ops] void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { view_metas_.push_back(meta); - // Manually track the fact that this tensor recieved a metadata mutation! + // Manually track the fact that this tensor received a metadata mutation! has_metadata_mutation_ = true; // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. maybe_mark_symbolic(meta); @@ -273,7 +273,7 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) { // (We could check if the updated value has a new storage than the original value, // but this won't also let us uniquely determine if the tensor **also** // experienced a data mutation). - was_storage_changed_ = true; + mark_storage_changed(); auto sizes_ = value_.sym_sizes(); auto strides_ = value_.sym_strides(); @@ -499,8 +499,8 @@ int64_t FunctionalTensorWrapper::dim_custom() const { int64_t FunctionalTensorWrapper::numel_custom() const { return value_.unsafeGetTensorImpl()->numel(); } -bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const { - return value_.unsafeGetTensorImpl()->is_contiguous(memory_format); +c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { + return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format); } c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { return value_.unsafeGetTensorImpl()->sym_sizes(); @@ -579,7 +579,7 @@ std::vector from_functional_tensor(ITensorListRef t_list) { for (const auto& tensor : t_list) { // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call // it on a non-functional input, - // but from_functional_tensor(TensorList) can recieve a list containing both + // but from_functional_tensor(TensorList) can receive a list containing both // functional and non-functional tensors. // Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor). // When that happens, we're okay with only unwrapping the functional tensors. diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index c418ef39427c05..bec2d463196cc0 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -74,7 +74,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool has_metadata_mutation() const { return has_metadata_mutation_; } - + uint64_t mutation_counter() const { + return functional_storage_impl()->mutation_counter(); + } void mark_mutation() { functional_storage_impl()->mark_mutation(); } @@ -161,8 +163,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { return was_storage_changed_; } - void set_storage_changed() { + void mark_storage_changed() { was_storage_changed_ = true; + storage_changed_counter_++; + } + + uint64_t storage_changed_counter() { + return storage_changed_counter_; } // A FunctionalTensor is considered a base if its not a view of another @@ -181,6 +188,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { return functional_storage_impl()->was_inductor_storage_resized(); } + bool inductor_storage_resized_counter() { + return functional_storage_impl()->inductor_storage_resized_counter(); + } // The functionalization pass can be used to remove mutations. // It does so by replacing any mutation op with it's corresponding // out-of-place op, followed by a call to replace_(). e.g: @@ -226,7 +236,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { at::IntArrayRef strides_custom() const override; int64_t dim_custom() const override; int64_t numel_custom() const override; - bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymInt sym_size_custom(int64_t d) const override; c10::SymIntArrayRef sym_strides_custom() const override; @@ -269,6 +280,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool is_multi_output_view_ = false; // Did the tensor experience a set_() call. bool was_storage_changed_ = false; + uint64_t storage_changed_counter_ = 0; // Did the tensor experience any view operation with symbolic int. bool is_symbolic_ = false; diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index bc2170b7ba0929..97094c9f125a0c 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -320,11 +320,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size); if (!stride.has_value()) { - // With unbacked symints, computeStride could fail even on contiguous - // tensors. In this case, we can use the strides of an empty tensor of - // inferred_size. - TORCH_CHECK( - self.is_contiguous(), + + TORCH_SYM_CHECK( + self.sym_is_contiguous(), "View is not valid from size:", self.sym_sizes(), " stride: ", @@ -333,6 +331,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt inferred_size, " in case of unbacked symbols consider adding torch.check to guide computing strides."); + // With unbacked symints, computeStride could fail even on contiguous + // tensors. In this case, we can use the strides of an empty tensor of + // inferred_size. stride = at::detail::empty_symint_meta( inferred_size, std::nullopt, diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.cpp b/aten/src/ATen/LegacyBatchedTensorImpl.cpp index 12c562f5d8e13b..cceefe985a7e20 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.cpp +++ b/aten/src/ATen/LegacyBatchedTensorImpl.cpp @@ -84,7 +84,7 @@ IntArrayRef BatchedTensorImpl::strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default -bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.h b/aten/src/ATen/LegacyBatchedTensorImpl.h index fa6c472e1fa0ed..798e3535af3fbd 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.h +++ b/aten/src/ATen/LegacyBatchedTensorImpl.h @@ -82,7 +82,8 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error // messages. - bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 61336037d71b1d..1bc8c30158aec7 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -24,7 +24,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) { } } - if (t->is_non_overlapping_and_dense()) { + if (t->is_non_overlapping_and_dense_or_false()) { return MemOverlap::No; } @@ -35,7 +35,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) { // SymInts. Thus, if I have u0 size, we should assume that this has > 1 // elements (first expression), but if I have a u0 stride, I should NOT // assume that it is not zero (second expression) - if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) { + if (TORCH_GUARD_OR_FALSE(sizes[i].sym_gt(1)) && strides[i] == 0) { return MemOverlap::Yes; } } @@ -63,7 +63,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) { if (a->numel() == 0 || b->numel() == 0) { return MemOverlapStatus::No; } - if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { + if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) { return MemOverlapStatus::TooHard; } // Test for storage equality, rather than pointer equality. diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index b35a7ef9401f21..647b2f1685d17c 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const { return NestedTensorImpl::numel_custom(); } -bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const { +c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { return nested_tensor_impl_is_contiguous(this); } IntArrayRef NestedTensorImpl::sizes_custom() const { diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index 697969edbbd44f..f40684ce0ba264 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -115,7 +115,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { // with real implementations int64_t numel_custom() const override; c10::SymInt sym_numel_custom() const override; - bool is_contiguous_custom(MemoryFormat) const override; + c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override; int64_t size_custom(int64_t d) const override { return this->size(d); } diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index 699c47e36725df..805df08a55300a 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -222,8 +222,7 @@ void set_num_threads(int nthreads) { int stored_nthreads = num_intraop_threads.load(); if (stored_nthreads <= 0) { // plus one because of master thread - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - stored_nthreads = _get_intraop_pool().size() + 1; + stored_nthreads = static_cast(_get_intraop_pool().size() + 1); } if (stored_nthreads != nthreads) { TORCH_WARN( @@ -251,8 +250,7 @@ int get_num_threads() { return intraop_default_num_threads(); } else { TORCH_INTERNAL_ASSERT(nthreads == CONSUMED); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - return _get_intraop_pool().size() + 1; + return static_cast(_get_intraop_pool().size() + 1); } #else caffe2::PThreadPool* const pool = caffe2::pthreadpool(); diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index 0ec3c97a2dac3c..f73d75ab53ad9c 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -252,8 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) { void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset."); } -bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const { +c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); } - } // namespace at diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h index 94ac1e1c393448..14688163a374ff 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.h +++ b/aten/src/ATen/SparseCsrTensorImpl.h @@ -86,7 +86,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl { protected: IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; - bool is_contiguous_custom(MemoryFormat) const override; + SymBool sym_is_contiguous_custom(MemoryFormat) const override; public: void set_size(int64_t dim, int64_t new_size) override; diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 2a3b9481255f5c..2b2f286ea50d33 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -108,7 +108,7 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons AT_ASSERT(device() == values_.device()); AT_ASSERT(values_.device() == indices_.device()); - coalesced_ = TORCH_GUARD_SIZE_OBLIVIOUS(sym_nnz().sym_lt(2)); + coalesced_ = TORCH_GUARD_OR_FALSE(sym_nnz().sym_lt(2)); } diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 28c5fd6012f055..32f0f1e2defeb3 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -1388,7 +1388,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) { case FastSetupType::NON_OVERLAPPING_DENSE: { // find the index of a defined tensor in operands_ start from input tensor - int i_defined; // NOLINT(cppcoreguidelines-init-variables) + int i_defined = -1; for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) { if (tensor(i_defined).defined()) break; } diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 92adcc820276bd..1636bbcb6f75b7 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -343,7 +343,7 @@ inline static std::optional computeStride_impl( // This could perhaps be combined with the below code, but the complexity // didn't seem worth it. const Numel numel = c10::multiply_integers(oldshape); - bool zero_numel = TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0)); + bool zero_numel = TORCH_GUARD_OR_FALSE(sym_eq(numel, 0)); if (zero_numel && oldshape.equals(newshape)) { return toResult(oldstride); } diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index c1b2b952930bd4..eddd5e4b4d6cf1 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -45,14 +45,14 @@ std::string toString(const Scalar& s) { namespace at { -std::ostream& operator<<(std::ostream & out, const DeprecatedTypeProperties& t) { +std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t) { return out << t.toString(); } enum class FormatType { - Default, // 'g' format (defaultfloat equivalent) - Scientific, // 'e' format with precision 4 - Fixed // 'f' format with precision 4 + Default, // 'g' format (defaultfloat equivalent) + Scientific, // 'e' format with precision 4 + Fixed // 'f' format with precision 4 }; struct PrintFormat { @@ -61,12 +61,12 @@ struct PrintFormat { FormatType type; PrintFormat(double s, int w, FormatType t = FormatType::Default) - : scale(s), width(w), type(t) {} + : scale(s), width(w), type(t) {} }; static PrintFormat __printFormat(const Tensor& self) { auto size = self.numel(); - if(size == 0) { + if (size == 0) { return PrintFormat(1., 0); } @@ -74,8 +74,8 @@ static PrintFormat __printFormat(const Tensor& self) { auto self_p = self.const_data_ptr(); for (const auto i : c10::irange(size)) { auto z = self_p[i]; - if(std::isfinite(z)) { - if(z != std::ceil(z)) { + if (std::isfinite(z)) { + if (z != std::ceil(z)) { intMode = false; break; } @@ -83,28 +83,28 @@ static PrintFormat __printFormat(const Tensor& self) { } int64_t offset = 0; - while(offset < size && !std::isfinite(self_p[offset])) { + while (offset < size && !std::isfinite(self_p[offset])) { offset = offset + 1; } double expMin = 1; double expMax = 1; - if(offset != size) { + if (offset != size) { expMin = std::fabs(self_p[offset]); expMax = std::fabs(self_p[offset]); for (const auto i : c10::irange(offset, size)) { double z = std::fabs(self_p[i]); - if(std::isfinite(z)) { + if (std::isfinite(z)) { expMin = std::min(expMin, z); expMax = std::max(expMax, z); } } - if(expMin != 0) { + if (expMin != 0) { expMin = std::floor(std::log10(expMin)) + 1; } else { expMin = 1; } - if(expMax != 0) { + if (expMax != 0) { expMax = std::floor(std::log10(expMax)) + 1; } else { expMax = 1; @@ -114,8 +114,8 @@ static PrintFormat __printFormat(const Tensor& self) { double scale = 1; int sz = 11; - if(intMode) { - if(expMax > 9) { + if (intMode) { + if (expMax > 9) { sz = 11; return PrintFormat(scale, sz, FormatType::Scientific); } else { @@ -123,19 +123,19 @@ static PrintFormat __printFormat(const Tensor& self) { return PrintFormat(scale, sz, FormatType::Default); } } else { - if(expMax-expMin > 4) { + if (expMax - expMin > 4) { sz = 11; - if(std::fabs(expMax) > 99 || std::fabs(expMin) > 99) { + if (std::fabs(expMax) > 99 || std::fabs(expMin) > 99) { sz = sz + 1; } return PrintFormat(scale, sz, FormatType::Scientific); } else { - if(expMax > 5 || expMax < 0) { + if (expMax > 5 || expMax < 0) { sz = 7; - scale = std::pow(10, expMax-1); + scale = std::pow(10, expMax - 1); return PrintFormat(scale, sz, FormatType::Fixed); } else { - if(expMax == 0) { + if (expMax == 0) { sz = 7; } else { sz = static_cast(expMax) + 6; @@ -147,9 +147,9 @@ static PrintFormat __printFormat(const Tensor& self) { } // Precompiled format specs -static constexpr auto FMT_G = FMT_COMPILE("{:>{}g}"); -static constexpr auto FMT_E4 = FMT_COMPILE("{:>{}.4e}"); -static constexpr auto FMT_F4 = FMT_COMPILE("{:>{}.4f}"); +static constexpr auto FMT_G = FMT_COMPILE("{:>{}g}"); +static constexpr auto FMT_E4 = FMT_COMPILE("{:>{}.4e}"); +static constexpr auto FMT_F4 = FMT_COMPILE("{:>{}.4f}"); // Print a single value directly into the stream buffer with no temporaries static void printValue(std::ostream& stream, double v, const PrintFormat& pf) { @@ -157,7 +157,7 @@ static void printValue(std::ostream& stream, double v, const PrintFormat& pf) { double val = v / pf.scale; switch (pf.type) { case FormatType::Default: - fmt::format_to(out_it, FMT_G, val, pf.width); + fmt::format_to(out_it, FMT_G, val, pf.width); break; case FormatType::Scientific: fmt::format_to(out_it, FMT_E4, val, pf.width); @@ -168,57 +168,60 @@ static void printValue(std::ostream& stream, double v, const PrintFormat& pf) { } } -static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t linesize, int64_t indent) { +static void __printMatrix( + std::ostream& stream, + const Tensor& self, + int64_t linesize, + int64_t indent) { auto printFmt = __printFormat(self); int64_t nColumnPerLine = (linesize - indent) / (printFmt.width + 1); int64_t firstColumn = 0; int64_t lastColumn = -1; - while(firstColumn < self.size(1)) { - if(firstColumn + nColumnPerLine <= self.size(1)) { + while (firstColumn < self.size(1)) { + if (firstColumn + nColumnPerLine <= self.size(1)) { lastColumn = firstColumn + nColumnPerLine - 1; } else { lastColumn = self.size(1) - 1; } - if(nColumnPerLine < self.size(1)) { - if(firstColumn != 0) { - stream.put('\n'); + if (nColumnPerLine < self.size(1)) { + if (firstColumn != 0) { + stream.put('\n'); } fmt::print( stream, "Columns {} to {}{:>{}s}", firstColumn + 1, lastColumn + 1, - "", // empty string to pad - indent // width to pad to + "", // empty string to pad + indent // width to pad to ); } - if(printFmt.scale != 1) { - fmt::print(stream, "{} *\n{:>{}s}", - printFmt.scale, "", indent); + if (printFmt.scale != 1) { + fmt::print(stream, "{} *\n{:>{}s}", printFmt.scale, "", indent); } for (const auto l : c10::irange(self.size(0))) { Tensor row = self.select(0, l); - const double *row_ptr = row.const_data_ptr(); + const double* row_ptr = row.const_data_ptr(); - for (const auto c : c10::irange(firstColumn, lastColumn+1)) { + for (const auto c : c10::irange(firstColumn, lastColumn + 1)) { printValue(stream, row_ptr[c], printFmt); - if(c == lastColumn) { + if (c == lastColumn) { stream.put('\n'); - if(l != self.size(0)-1) { - if(printFmt.scale != 1) { + if (l != self.size(0) - 1) { + if (printFmt.scale != 1) { fmt::print(stream, "{:>{}s} ", "", indent); } else { fmt::print(stream, "{:>{}s}", "", indent); } } } else { - stream.put(' '); + stream.put(' '); } } } @@ -226,18 +229,21 @@ static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t line } } -static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize) { - std::vector counter(self.ndimension()-2, 0); +static void __printTensor( + std::ostream& stream, + Tensor& self, + int64_t linesize) { + std::vector counter(self.ndimension() - 2, 0); counter[0] = -1; bool start = true; bool finished = false; - while(true) { - for(int64_t i = 0; self.ndimension()-2; i++) { + while (true) { + for (int64_t i = 0; self.ndimension() - 2; i++) { counter[i] = counter[i] + 1; - if(counter[i] >= self.size(i)) { - if(i == self.ndimension()-3) { + if (counter[i] >= self.size(i)) { + if (i == self.ndimension() - 3) { finished = true; break; } @@ -246,10 +252,10 @@ static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize) break; } } - if(finished) { + if (finished) { break; } - if(start) { + if (start) { start = false; } else { stream.put('\n'); @@ -257,21 +263,24 @@ static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize) stream.put('('); Tensor tensor = self; - for (const auto i : c10::irange(self.ndimension()-2)) { + for (const auto i : c10::irange(self.ndimension() - 2)) { tensor = tensor.select(0, counter[i]); - fmt::print(stream, "{},", counter[i]+1); + fmt::print(stream, "{},", counter[i] + 1); } fmt::print(stream, ".,.) = \n"); __printMatrix(stream, tensor, linesize, 1); } } -void print(const Tensor & t, int64_t linesize) { +void print(const Tensor& t, int64_t linesize) { print(std::cout, t, linesize); } -std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) { - if(!tensor_.defined()) { +std::ostream& print( + std::ostream& stream, + const Tensor& tensor_, + int64_t linesize) { + if (!tensor_.defined()) { fmt::print(stream, "[ Tensor (undefined) ]"); return stream; } @@ -299,15 +308,16 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi tensor = tensor_.to(kCPU, kDouble).contiguous(); } - if(tensor.ndimension() == 0) { - fmt::print(stream, - "{}\n[ {}{{}}", - tensor.const_data_ptr()[0], - tensor_.toString()); - } else if(tensor.ndimension() == 1) { + if (tensor.ndimension() == 0) { + fmt::print( + stream, + "{}\n[ {}{{}}", + tensor.const_data_ptr()[0], + tensor_.toString()); + } else if (tensor.ndimension() == 1) { if (tensor.numel() > 0) { auto printFmt = __printFormat(tensor); - if(printFmt.scale != 1) { + if (printFmt.scale != 1) { fmt::print(stream, "{} *\n", printFmt.scale); } const double* tensor_p = tensor.const_data_ptr(); @@ -317,12 +327,16 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi } } fmt::print(stream, "[ {}{{{}}}", tensor_.toString(), tensor.size(0)); - } else if(tensor.ndimension() == 2) { + } else if (tensor.ndimension() == 2) { if (tensor.numel() > 0) { __printMatrix(stream, tensor, linesize, 0); } - fmt::print(stream, "[ {}{{{},{}}}", - tensor_.toString(), tensor.size(0), tensor.size(1)); + fmt::print( + stream, + "[ {}{{{},{}}}", + tensor_.toString(), + tensor.size(0), + tensor.size(1)); } else { if (tensor.numel() > 0) { __printTensor(stream, tensor, linesize); @@ -338,10 +352,14 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi if (tensor_.is_quantized()) { fmt::print(stream, ", qscheme: {}", toString(tensor_.qscheme())); if (tensor_.qscheme() == c10::kPerTensorAffine) { - fmt::print(stream, ", scale: {}, zero_point: {}", - tensor_.q_scale(), tensor_.q_zero_point()); - } else if (tensor_.qscheme() == c10::kPerChannelAffine || - tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) { + fmt::print( + stream, + ", scale: {}, zero_point: {}", + tensor_.q_scale(), + tensor_.q_zero_point()); + } else if ( + tensor_.qscheme() == c10::kPerChannelAffine || + tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) { fmt::print(stream, ", scales: "); print(stream, tensor_.q_per_channel_scales(), linesize); fmt::print(stream, ", zero_points: "); @@ -363,4 +381,4 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi return stream; } -} +} // namespace at diff --git a/aten/src/ATen/core/IListRef_inl.h b/aten/src/ATen/core/IListRef_inl.h index a21bd22cf16c94..df320c13d9c238 100644 --- a/aten/src/ATen/core/IListRef_inl.h +++ b/aten/src/ATen/core/IListRef_inl.h @@ -168,7 +168,9 @@ class IListRefTagImpl */ static IListRefConstRef iterator_get( const typename list_type::const_iterator& it) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference") const auto& ivalue = (*it).get(); + C10_DIAGNOSTIC_POP() if (!ivalue.isNone()) { const auto& tensor = ivalue.toTensor(); return (tensor.defined()) ? tensor : at::OptionalTensorRef{}; diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 8d300debebe3de..8463379149e273 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -57,16 +57,16 @@ inline bool variable_excluded_from_dispatch() { // NOTE: [Tensor vs. TensorBase] // // Tensor, being the central data structure in PyTorch, gets used and -// it's header included almost everywhere. Unfortunately this means +// its header included almost everywhere. Unfortunately this means // every time an operator signature is updated or changed in // native_functions.yaml, you (and every other PyTorch developer) need -// to recompile all of ATen and it's dependencies. +// to recompile all of ATen and its dependencies. // // TensorBase aims to break up these header dependencies, and improve // incremental build times for all PyTorch developers. TensorBase // represents a reference counted handle to TensorImpl, exactly the // same as Tensor. However, TensorBase doesn't have code generated -// methods in it's API and thus no dependence on native_functions.yaml. +// methods in its API and thus no dependence on native_functions.yaml. // // Usage tips // ---------- @@ -75,9 +75,9 @@ inline bool variable_excluded_from_dispatch() { // native_functions.yaml (direct or indirect). // - Tensor inherits from TensorBase, so functions taking // `const TensorBase &` are callable with Tensor as well. -// - TensorBase can be converted to tensor with `Tensor(tensor_base)`, -// but this requires a reference-count bump. OptionalTensorRef on -// the other hand can materialize a `const Tensor &` without +// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`, +// but this requires a reference-count bump. OptionalTensorRef, on +// the other hand, can materialize a `const Tensor &` without // touching the reference-count. class TORCH_API TensorBase { public: @@ -124,7 +124,7 @@ class TORCH_API TensorBase { } TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { - if (is_contiguous(memory_format)) { + if (is_contiguous_or_false(memory_format)) { return *this; } else { return __dispatch_contiguous(memory_format); @@ -265,6 +265,25 @@ class TORCH_API TensorBase { return impl_->is_contiguous(memory_format); } + // Like is_contiguous, but more dynamic shape-friendly. May return a symbolic representation of + // contiguity instead of SymTrue SymFalse, when results are data-dependent. + c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + if (impl_->has_symbolic_sizes_strides()) { + return impl_->sym_is_contiguous(memory_format); + } + return impl_->is_contiguous(memory_format); + } + + // Like is_contiguous, but more dynamic shape-friendly. Can returns + // false instead of throwing data-dependent errors for tensors with unbacked + // sizes or strides. + bool is_contiguous_or_false(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + if (impl_->has_symbolic_sizes_strides()) { + return impl_->sym_is_contiguous(memory_format).guard_or_false(__FILE__, __LINE__); + } + return impl_->is_contiguous(memory_format); + } + bool is_non_overlapping_and_dense() const { return impl_->is_non_overlapping_and_dense(); } diff --git a/aten/src/ATen/core/adaption.cpp b/aten/src/ATen/core/adaption.cpp index ef06b9606ba7ee..abb21d31e048b4 100644 --- a/aten/src/ATen/core/adaption.cpp +++ b/aten/src/ATen/core/adaption.cpp @@ -5,9 +5,8 @@ namespace c10::impl { void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { TORCH_CHECK(false, - "Expected all tensors to be on the same device, but " - "found at least two devices, ", common_device, " and ", tensor.device(), "! " - "(when checking argument for argument ", argName, " in method ", methodName, ")"); + "Expected all tensors to be on the same device, but got ", argName, " is on ", tensor.device(), + ", different from other tensors on ", common_device, " (when checking argument in method ", methodName, ")"); } } // namespace c10::impl diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index e67d1badc9a465..20dfde846e648c 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -105,7 +105,7 @@ using supported_primitive_arg_types = guts::typelist::typelist< // So a valid input type is one that our boxed functor wrapper can // unbox from an IValue into a C++ value. // -// Whereas a valid output type is one that our wrapper can recieve +// Whereas a valid output type is one that our wrapper can receive // as a C++ value from the unboxed functor, and box into an IValue. // diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index c3e4a82f11947f..ecc4bc7b5d893c 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -152,8 +152,11 @@ struct TORCH_API DispatchKeyExtractor final { // no safe toTensorRef method, alas) ks = ks | ivalue.unsafeToTensorImpl()->key_set(); } else if (C10_UNLIKELY(ivalue.isTensorList())) { - for (const at::Tensor& tensor : ivalue.toTensorList()) { - ks = ks | tensor.key_set(); + // NB: use toListRef as it doesn't induce refcount bumps + // (toTensorListRef is not a thing) + for (const auto& nv : ivalue.toListRef()) { + auto* tensor = nv.unsafeToTensorImpl(); + ks = ks | tensor->key_set(); } } // Tensor?[] translates to a c10::List so we need to peek inside diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 055a10af777c9d..b16ddd78edb859 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -179,6 +179,18 @@ const std::vector Dispatcher::getAllOpNames() { }); } +const std::vector Dispatcher::getAllOpNamesForDispatchKey(DispatchKey k) { + return operatorLookupTable_.read([&] (const ska::flat_hash_map& operatorLookupTable) -> std::vector { + std::vector allOpNames; + for (const auto& op : operatorLookupTable) { + if (op.second.hasKernelForDispatchKey(k)) { + allOpNames.push_back(op.first); + } + } + return allOpNames; + }); +} + // Postcondition: caller is responsible for disposing of registration when they // are done OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index dbc501afe7ce5c..590c9ba8d11d82 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -165,6 +165,10 @@ class TORCH_API Dispatcher final { // Returns a list of all operator names present in the operatorLookupTable_ const std::vector getAllOpNames(); + // Returns a list of all operator names present in the operatorLookupTable_ + // for a given dispatch key + const std::vector getAllOpNamesForDispatchKey(DispatchKey k); + // ------------------------------------------------------------------------ // // Invoking operators diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 19c08359b78dae..b4063fb720be0a 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -213,7 +213,8 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel( #endif // Suppress the warning for Meta key as we are overriding C++ meta functions with python meta functions // for some ops - if (dispatch_key != DispatchKey::Meta) { + // Also suppress the warning for MTIA, as MTIA achieves CPU fallback by overriding registration. + if (dispatch_key != DispatchKey::Meta && dispatch_key != DispatchKey::MTIA) { TORCH_WARN_ONCE("Warning only once for all operators, other operators may also be overridden.\n", " Overriding a previously registered kernel for the same operator and the same dispatch key\n", " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", @@ -353,7 +354,7 @@ std::pair OperatorEntry::computeDispatchTab // CompositExplicitAutogradNonFunctional > CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd // Note [CompositeExplicitAutograd and CompositeImplicitAutograd] // When there're registrations to both CompositeExplicitAutograd & CompositeImplicitAutograd & Autograd, from (2.2) we know CompositeExplicitAutograd - // and Autograd kernels will be picked up and CompositeImplicitAutograd is overriden. + // and Autograd kernels will be picked up and CompositeImplicitAutograd is overridden. // This is fine and in practice CompositeExplicitAutograd and CompositeImplicitAutograd shouldn't co-exist for an op. // TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA . diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 0ffc061870f1d0..5c92d07ff699fe 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -1787,8 +1787,7 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) { } TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool fpga_called, math_called = false; + bool fpga_called = false, math_called = false; auto m = MAKE_TORCH_LIBRARY(test); m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; })); m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }); diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h index 4c6e21b28c2c53..c7fcf3178a27e1 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -202,18 +202,14 @@ class Vectorized { store(tmp); return tmp[idx]; } - // For boolean version where we want to if any 1/all zero - // etc. can be done faster in a different way. int zero_mask() const { - __at_align__ float tmp[size()]; - store(tmp); - int mask = 0; - for (int i = 0; i < size(); ++i) { - if (tmp[i] == 0.f) { - mask |= (1 << i); - } - } - return mask; + uint32x4_t is_zero_vec = vceqzq_f32(values); + const int32x4_t shift = vcombine_s32( + vcreate_s32(0x0 | (int64_t(0x1) << 32)), + vcreate_s32(0x2 | (int64_t(0x3) << 32))); + uint32x4_t bits_vec = + vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift); + return vaddvq_u32(bits_vec); } Vectorized isnan() const { return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values))); diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h index e468244968c08d..90e9cb65fe1067 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -220,8 +220,32 @@ class Vectorized : public Vectorized16< std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); } } - // For boolean version where we want to if any 1/all zero - // etc. can be done faster in a different way. + int zero_mask() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + uint16x8_t is_zero_vec = vceqzq_f16(values); + const int16x8_t shift = vcombine_s16( + vcreate_s16( + 0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) | + (int64_t(0x3) << 48)), + vcreate_s16( + 0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) | + (int64_t(0x7) << 48))); + uint16x8_t bits_vec = + vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift); + return vaddvq_u16(bits_vec); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // use known working implmentation. + __at_align__ value_type tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0) { + mask |= (1 << i); + } + } + return mask; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } Vectorized isnan() const { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values))); diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 43c14d3420bc6e..50c3cc31a6c488 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -205,7 +205,7 @@ std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: - // a = {a0, a1, a3, a3} + // a = {a0, a1, a2, a3} // b = {b0, b1, b2, b3} // swap lanes: diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h index 705d8436edf922..a6a883e53b39b6 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -478,7 +478,7 @@ class Vectorized { this->store(tmp1); b.store(tmp2); - for (const auto i : c10::irange(Vectorized>::size())) { + for (const auto i : c10::irange(Vectorized>::size())) { out[i] = tmp1[i] / tmp2[i]; } return loadu(out); diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index fe25c979906c40..2c2a199da80dca 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -348,26 +348,6 @@ class Vectorized { DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) }; -template <> -Vectorized inline operator<<( - const Vectorized& a, - const Vectorized& b) { - vuint16 shift_vec0 = reinterpret_cast(b.vec0()); - vuint16 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; -} - -template <> -Vectorized inline operator>>( - const Vectorized& a, - const Vectorized& b) { - vuint16 shift_vec0 = reinterpret_cast(b.vec0()); - vuint16 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; -} - template <> Vectorized inline maximum( const Vectorized& a, @@ -382,6 +362,8 @@ Vectorized inline minimum( return a.minimum(b); } +DEFINE_SHIFT_FUNCS(int16_t) + template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index c47b7ea66d74d2..ea22e8dde2df23 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -278,26 +278,6 @@ class Vectorized { DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) }; -template <> -Vectorized inline operator<<( - const Vectorized& a, - const Vectorized& b) { - vuint32 shift_vec0 = reinterpret_cast(b.vec0()); - vuint32 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; -} - -template <> -Vectorized inline operator>>( - const Vectorized& a, - const Vectorized& b) { - vuint32 shift_vec0 = reinterpret_cast(b.vec0()); - vuint32 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; -} - template <> Vectorized inline maximum( const Vectorized& a, @@ -312,6 +292,8 @@ Vectorized inline minimum( return a.minimum(b); } +DEFINE_SHIFT_FUNCS(int32_t) + template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index 2238f41aef300b..8d0bd52c90103f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -231,26 +231,6 @@ class Vectorized { DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) }; -template <> -Vectorized inline operator<<( - const Vectorized& a, - const Vectorized& b) { - vuint64 shift_vec0 = reinterpret_cast(b.vec0()); - vuint64 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; -} - -template <> -Vectorized inline operator>>( - const Vectorized& a, - const Vectorized& b) { - vuint64 shift_vec0 = reinterpret_cast(b.vec0()); - vuint64 shift_vec1 = reinterpret_cast(b.vec1()); - return Vectorized{ - vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; -} - template <> Vectorized inline maximum( const Vectorized& a, @@ -265,6 +245,8 @@ Vectorized inline minimum( return a.minimum(b); } +DEFINE_SHIFT_FUNCS(int64_t) + template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h index 136b68911061f5..7ca603c0b91dfa 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include @@ -39,6 +40,19 @@ using vfloat32 = __attribute__((altivec(vector__))) float; using vfloat64 = __attribute__((altivec(vector__))) double; #endif +inline auto make_vuint(vint8 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint16 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint32 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint64 v) { + return reinterpret_cast(v); +} + #if !defined(vec_float) C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { vfloat32 vec_out; @@ -521,6 +535,42 @@ const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; +template +Vectorized VsxShiftRightArith( + const Vectorized& a, + const Vectorized& b) { + const Vectorized max_shift(sizeof(T) * CHAR_BIT - std::is_signed_v); + const auto mask = (b < Vectorized(0)) | (b >= max_shift); + const auto shift = Vectorized::blendv(b, max_shift, mask); + return Vectorized{ + vec_sra(a.vec0(), make_vuint(shift.vec0())), + vec_sra(a.vec1(), make_vuint(shift.vec1()))}; +} + +template +Vectorized VsxShiftLeftArith( + const Vectorized& a, + const Vectorized& b) { + const Vectorized max_shift(sizeof(T) * CHAR_BIT); + const auto mask = (b < Vectorized(0)) | (b >= max_shift); + Vectorized ret( + vec_sl(a.vec0(), make_vuint(b.vec0())), + vec_sl(a.vec1(), make_vuint(b.vec1()))); + return Vectorized::blendv(ret, Vectorized(0), mask); +} + +#define DEFINE_SHIFT_FUNCS(operand_type) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator>>( \ + const Vectorized& a, const Vectorized& b) { \ + return VsxShiftRightArith(a, b); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator<<( \ + const Vectorized& a, const Vectorized& b) { \ + return VsxShiftLeftArith(a, b); \ + } + } // namespace CPU_CAPABILITY } // namespace vec } // namespace at diff --git a/aten/src/ATen/cpu/vec/vec_quant.h b/aten/src/ATen/cpu/vec/vec_quant.h new file mode 100644 index 00000000000000..36602c4a760f09 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec_quant.h @@ -0,0 +1,153 @@ +#pragma once + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4) +template > +static inline void transpose_pad_4x64_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int krem = 4, + int nrem = 64) { +#if defined(CPU_CAPABILITY_AVX512) + __m512i r[4]; + // Load with mask if partial + if (nrem < 64) { + __mmask64 mask = (1ULL << nrem) - 1; + for (int i = 0; i < krem; ++i) { + r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src); + } + for (int i = krem; i < 4; ++i) { + r[i] = _mm512_setzero_si512(); + } + } else { + for (int i = 0; i < krem; ++i) { + r[i] = _mm512_loadu_si512( + reinterpret_cast(src + i * ld_src)); + } + for (int i = krem; i < 4; ++i) { + r[i] = _mm512_setzero_si512(); + } + } + + // Transpose 4x64 bytes using unpack and shuffle + __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]); + __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]); + __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]); + __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]); + + __m512i u0 = _mm512_unpacklo_epi16(t0, t2); + __m512i u1 = _mm512_unpackhi_epi16(t0, t2); + __m512i u2 = _mm512_unpacklo_epi16(t1, t3); + __m512i u3 = _mm512_unpackhi_epi16(t1, t3); + + __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88); + __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd); + __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88); + __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd); + + __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88); + __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88); + __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd); + __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd); + + // Store output + if (nrem < 16) { + __mmask64 mask = (1ULL << (nrem * 4)) - 1; + _mm512_mask_storeu_epi8(dst, mask, r0); + } else if (nrem == 16) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + } else if (nrem < 32) { + int n_bytes1 = 64; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1); + } else if (nrem == 32) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + } else if (nrem < 48) { + int n_bytes1 = 64 * 2; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2); + } else if (nrem == 48) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + } else if (nrem < 64) { + int n_bytes1 = 64 * 3; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3); + } else { + // normal case, nrem == 64 + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3); + } +#else + TORCH_CHECK( + false, + "transpose_pad_4x64_block is only supported when AVX-512 is supported") +#endif +} + +// Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for bit8) +template > +static inline void pack_vnni4( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + int64_t bk = 0; + int64_t _K = K / 4 * 4; + int64_t _N = N / 64 * 64; + for (; bk < _K; bk += 4) { + int64_t bn = 0; + for (; bn < _N; bn += 64) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem); + } + } + + // Handle leftover K rows (< 4) + if (K % 4 != 0) { + int krem = K - bk; + int64_t bn = 0; + for (; bn < _N; bn += 64) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem); + } + } +#else + TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported") +#endif +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 4a0816fd305800..d009520d05ab83 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -17,6 +17,7 @@ #include #ifdef USE_ROCM +#include #include // until hipblas has an API to accept flags, we must use rocblas here #include @@ -111,12 +112,15 @@ static cublasOperation_t _cublasOpFromChar(char op) { // NOLINTNEXTLINE(bugprone-switch-missing-default-case) switch (op) { case 'n': + [[fallthrough]]; case 'N': return CUBLAS_OP_N; case 't': + [[fallthrough]]; case 'T': return CUBLAS_OP_T; case 'c': + [[fallthrough]]; case 'C': return CUBLAS_OP_C; } @@ -185,82 +189,65 @@ uint32_t _getAlignment(uintptr_t address) { } #endif -static size_t _parseChosenWorkspaceSize() { - auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); #ifdef USE_ROCM - if (!val.has_value()) { - // accept either env var - val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); +static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) { + static int32_t last_value = 0; + static hipStream_t stream; + if (last_value == 0) { + // first request, do nothing for this case + } + else if (last_value == value) { + // stream was created previously and value hasn't changed + return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device()); + } + else { + // need a new stream and a previous stream exists, delete it + AT_CUDA_CHECK(hipStreamDestroy(stream)); } - size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ -#else - size_t workspace_size = 1024; /* default size in KiB according to #73328 */ -#endif - if (val.has_value()) { - try { - workspace_size = std::stoi(val.value()); - } catch (std::invalid_argument const&) { - TORCH_WARN( - "invalid CUBLASLT_WORKSPACE_SIZE,", - " using default workspace size of ", - workspace_size, - " KiB."); - } catch (std::out_of_range const&) { - TORCH_WARN( - "CUBLASLT_WORKSPACE_SIZE out of range,", - " using default workspace size of ", - workspace_size, - " KiB."); - } + // if we got here, we need to create a new stream + int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + // how many uint32_t do we need to cover all CUs, fill bitmask with 1 + uint32_t mask_size = static_cast((CUs + 32 - 1) / 32); + std::vector mask(mask_size, uint32_t{0xffffffff}); + // starting from lowest order bits, in 32-bit chunks + // set bits to 0 based on how many CUs to carve out + int32_t full_shifts = value / 32; + int32_t remainder = value % 32; + for (int32_t i = 0; i < full_shifts; i++) { + mask[i] = uint32_t{0x00000000}; } - return workspace_size * 1024; -} + mask[full_shifts] = uint32_t{0xffffffff} << remainder; -static size_t _getWorkspaceSize() { - static size_t workspace_size = _parseChosenWorkspaceSize(); - return workspace_size; + // finally, create masked stream + AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0])); + + last_value = value; + return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device()); } -void* _getUnifiedWorkspaceWithoutHandle() { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - auto stream = c10::cuda::getCurrentCUDAStream(); - cudaStream_t _stream = stream; - auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); - TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); - return workspace_it->second.mutable_get(); +static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) { + hipEvent_t event; + AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming)); + + auto current_stream = at::cuda::getCurrentCUDAStream(); + + if (presync) { + AT_CUDA_CHECK(hipEventRecord(event, current_stream)); + AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0)); + } + else { + AT_CUDA_CHECK(hipEventRecord(event, stream)); + AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0)); + } } +#endif struct CublasLtWorkspace { CublasLtWorkspace() { - size = _getWorkspaceSize(); -#ifndef USE_ROCM - static bool unified = c10::utils::check_env("TORCH_CUBLASLT_UNIFIED_WORKSPACE") == true; - if (unified) { - auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize(); - if (cublasWorkspaceSize < size) { - TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", size, - " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize, - " bytes. Please increase CUBLAS workspace size", - " via CUBLAS_WORKSPACE_CONFIG or decrease requested" - " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace" - " size will be limited to the CUBLAS workspace size."); - size = cublasWorkspaceSize; - } - ptr = _getUnifiedWorkspaceWithoutHandle(); - } else { - auto allocator = c10::cuda::CUDACachingAllocator::get(); - stashed_ptr_ = allocator->allocate(size); - ptr = stashed_ptr_.mutable_get(); - } -#else - auto allocator = c10::cuda::CUDACachingAllocator::get(); - stashed_ptr_ = allocator->allocate(size); - ptr = stashed_ptr_.mutable_get(); -#endif + size = at::cuda::getCUDABlasLtWorkspaceSize(); + ptr = at::cuda::getCUDABlasLtWorkspace(); } - at::DataPtr stashed_ptr_; void * ptr; size_t size; }; @@ -404,7 +391,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D computeType = CUBLAS_COMPUTE_64F; scaleType = CUDA_R_64F; } else if constexpr (std::is_same_v) { - if (at::globalContext().allowTF32CuBLAS()) { + if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { computeType = CUBLAS_COMPUTE_32F_FAST_TF32; } } else if constexpr (std::is_same_v>) { @@ -458,6 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb); + auto stream = at::cuda::getCurrentCUDAStream(); #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -465,6 +453,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +#else + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + stream = _getCarveoutStream( + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); + _syncCurrentWithCarveoutStream(stream, true); + } #endif CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T); CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T); @@ -527,7 +521,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D &heuristicResult.algo, ltworkspace.ptr, ltworkspace.size, - at::cuda::getCurrentCUDAStream()); + stream); +#ifdef USE_ROCM + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + _syncCurrentWithCarveoutStream(stream, false); + } +#endif } if (cublasStatus != CUBLAS_STATUS_SUCCESS) { TORCH_WARN( @@ -1586,7 +1585,7 @@ bool gemm_and_bias( computeType = CUBLAS_COMPUTE_64F; scaleType = CUDA_R_64F; } else if constexpr (std::is_same_v) { - if (at::globalContext().allowTF32CuBLAS()) { + if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { computeType = CUBLAS_COMPUTE_32F_FAST_TF32; } } else if constexpr (std::is_same_v) { @@ -1625,6 +1624,7 @@ bool gemm_and_bias( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); + auto stream = at::cuda::getCurrentCUDAStream(); #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -1632,6 +1632,12 @@ bool gemm_and_bias( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +#else + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + stream = _getCarveoutStream( + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); + _syncCurrentWithCarveoutStream(stream, true); + } #endif cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; if (activation == GEMMAndBiasActivationEpilogue::RELU) { @@ -1700,7 +1706,12 @@ bool gemm_and_bias( &heuristicResult.algo, ltworkspace.ptr, ltworkspace.size, - at::cuda::getCurrentCUDAStream()); + stream); +#ifdef USE_ROCM + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + _syncCurrentWithCarveoutStream(stream, false); + } +#endif } if (cublasStatus != CUBLAS_STATUS_SUCCESS) { TORCH_WARN( @@ -1868,20 +1879,25 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) +#if defined(USE_ROCM) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } #else - // rowwise isn't supported using cublaslt or older hipblaslt - TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); -#endif // if defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // rowwise isn't supported using older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); +#endif +#endif // defined(USE_ROCM) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } + auto stream = at::cuda::getCurrentCUDAStream(); #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -1889,6 +1905,12 @@ void scaled_gemm( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +#else + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + stream = _getCarveoutStream( + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); + _syncCurrentWithCarveoutStream(stream, true); + } #endif // ifndef USE_ROCM #ifndef USE_ROCM const int8_t fastAccuMode = use_fast_accum ? 1 : 0; @@ -1923,9 +1945,17 @@ void scaled_gemm( #else TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); #endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 } - auto stream = c10::cuda::getCurrentCUDAStream(); CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size); @@ -2012,6 +2042,11 @@ void scaled_gemm( ltworkspace.ptr, ltworkspace.size, stream); +#ifdef USE_ROCM + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + _syncCurrentWithCarveoutStream(stream, false); + } +#endif TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -2065,6 +2100,7 @@ void int8_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); + auto stream = at::cuda::getCurrentCUDAStream(); #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -2072,6 +2108,12 @@ void int8_gemm( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +#else + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + stream = _getCarveoutStream( + at::globalContext()._SMCarveout_EXPERIMENTAL().value()); + _syncCurrentWithCarveoutStream(stream, true); + } #endif CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1); @@ -2085,10 +2127,8 @@ void int8_gemm( #ifdef USE_ROCM CuBlasLtMatmulPreference preference; - size_t workspaceSize = _getWorkspaceSize(); - preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto workspace = allocator.allocate(workspaceSize); + auto ltworkspace = CublasLtWorkspace(); + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size); cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -2126,16 +2166,16 @@ void int8_gemm( nullptr, // Heuristics don't seem to work for int8 #endif #ifdef USE_ROCM - workspace.mutable_get(), + ltworkspace.ptr, #else nullptr, // Non-zero workspace doesn't seem to work. #endif #ifdef USE_ROCM - workspaceSize, + ltworkspace.size, #else 0, #endif - at::cuda::getCurrentCUDAStream()); + stream); TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -2164,6 +2204,11 @@ void int8_gemm( computeType, " scaleType ", scaleType); +#ifdef USE_ROCM + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + _syncCurrentWithCarveoutStream(stream, false); + } +#endif } template <> diff --git a/aten/src/ATen/cuda/CUDAConfig.h.in b/aten/src/ATen/cuda/CUDAConfig.h.in index 7c7f2cc7470a42..6263e8455eaf81 100644 --- a/aten/src/ATen/cuda/CUDAConfig.h.in +++ b/aten/src/ATen/cuda/CUDAConfig.h.in @@ -8,6 +8,7 @@ // only be included from C++ files. #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@ #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@ +#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@ #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@ #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@ diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 65019bb6097c9b..86e960cc1ab4ad 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -89,7 +89,10 @@ TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); TORCH_CUDA_CPP_API void clearCublasWorkspaces(); TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); +TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublaslt_handle_stream_to_workspace(); TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); +TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize(); +TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace(); #if defined(CUDART_VERSION) || defined(USE_ROCM) TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 63b41343f9c054..dc8e2c4f70ff63 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -13,6 +13,17 @@ #include #include +/* +* `cudaEventExternal` is a torch-specific flag that is used to +* indicate that the CUDAEvent will be used only for synchronization +* with work outside of the cuda graph, rather than creation of +* cross-stream dependencies within a cuda graph. Resources: +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e +*/ +#define cudaEventExternal 0x08 + namespace at::cuda { /* @@ -118,7 +129,14 @@ struct TORCH_CUDA_CPP_API CUDAEvent { TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, " does not match recording stream's device ", stream.device_index(), "."); CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault; + AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else AT_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); if (C10_UNLIKELY(interp)) { (*interp)->trace_gpu_event_record(at::kCUDA, @@ -134,7 +152,13 @@ struct TORCH_CUDA_CPP_API CUDAEvent { void block(const CUDAStream& stream) { if (is_created_) { CUDAGuard guard(stream.device_index()); - AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0)); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); if (C10_UNLIKELY(interp)) { (*interp)->trace_gpu_event_wait(at::kCUDA, @@ -193,10 +217,16 @@ struct TORCH_CUDA_CPP_API CUDAEvent { unsigned int flags_ = cudaEventDisableTiming; bool is_created_ = false; bool was_recorded_ = false; + bool external_ = false; DeviceIndex device_index_ = -1; cudaEvent_t event_{}; void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; device_index_ = device_index; CUDAGuard guard(device_index_); AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 42f199c4d909ef..7fba7c4c7424c3 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -38,9 +38,10 @@ MempoolId_t graph_pool_handle() { * describes memory management for captures. */ -CUDAGraph::CUDAGraph() +CUDAGraph::CUDAGraph(bool keep_graph) // CUDAStreams may not be default-constructed. - : capture_stream_(at::cuda::getCurrentCUDAStream()) { + : capture_stream_(at::cuda::getCurrentCUDAStream()), + keep_graph_(keep_graph) { } void CUDAGraph::register_generator_state( @@ -126,8 +127,37 @@ void CUDAGraph::capture_end() { c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); TORCH_CHECK(graph_ != nullptr, "Invalid capture."); + + for (auto& [generator_state, wholegraph_increments] : + captured_generator_states_) { + wholegraph_increments = generator_state->capture_epilogue(); + } + + size_t numCUDAGraphNodes = 0; + AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes)); + if (numCUDAGraphNodes == 0) { + TORCH_WARN("The CUDA Graph is empty. This usually means that the graph was ", + "attempted to be captured on wrong device or stream."); + } + + capture_ended_ = true; has_graph_ = true; + if (!keep_graph_) { + instantiate(); + if (!_cuda_graphs_debug) { + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + } + has_graph_ = false; + } +} + +void CUDAGraph::instantiate() { + TORCH_CHECK(capture_ended_, "capture_end() must have been called before calling instantiate"); + if (has_graph_exec_) { + TORCH_CHECK(keep_graph_, "instantiate() is intended to be called by the user only when keep_graph=true"); + AT_CUDA_CHECK(cudaGraphExecDestroy(graph_exec_)); + } // In typical graph usage some tensors (e.g. the tensors used for graph IO) are not freed // between replays. // If Pytorch compiles and runs with a CUDA 11.4+ toolkit, there's a chance the allocator backend @@ -161,36 +191,18 @@ void CUDAGraph::capture_end() { cudaGraphInstantiateFlagAutoFreeOnLaunch)); } #endif - has_graph_exec_ = true; - - for (auto& [generator_state, wholegraph_increments] : - captured_generator_states_) { - wholegraph_increments = generator_state->capture_epilogue(); - } - - size_t numCUDAGraphNodes = 0; - AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes)); - if (numCUDAGraphNodes == 0) { - TORCH_WARN("The CUDA Graph is empty. This usually means that the graph was ", - "attempted to be captured on wrong device or stream."); - } - - // check if debug path is set - if (!_cuda_graphs_debug) { - // Now that we've instantiated graph_ into graph_exec_, - // we don't need graph_ anymore. - AT_CUDA_CHECK(cudaGraphDestroy(graph_)); - has_graph_ = false; - } else { - TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called."); - } } void CUDAGraph::replay() { - TORCH_CHECK(has_graph_exec_, + TORCH_CHECK(capture_ended_, "Called CUDAGraph::replay without a preceding successful capture."); + if (!has_graph_exec_) { + TORCH_INTERNAL_ASSERT(keep_graph_); + instantiate(); + } + c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; for (auto& [generator_state, wholegraph_increments] : @@ -217,13 +229,15 @@ void CUDAGraph::enable_debug_mode() { void CUDAGraph::debug_dump(const std::string& debug_path) { #if defined(CUDA_VERSION) || defined(USE_ROCM) - if (_cuda_graphs_debug) { + if (_cuda_graphs_debug || keep_graph_) { TORCH_WARN("DEBUG: calling debug_dump()"); if (has_graph_) { TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path); C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output - AT_CUDA_CHECK(cudaGraphDestroy(graph_)); - has_graph_ = false; + if (!keep_graph_) { + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + has_graph_ = false; + } } } else { TORCH_WARN("CUDA Graphs debug not enabled, set with [graph].enable_debug_mode()"); @@ -233,6 +247,12 @@ void CUDAGraph::debug_dump(const std::string& debug_path) { #endif } +cudaGraph_t CUDAGraph::raw_cuda_graph() { + TORCH_CHECK(keep_graph_, "You cannot access the raw cudaGraph_t instance unless CUDAGraph was initialized with keep_graph=true"); + TORCH_CHECK(has_graph_, "You cannot access the raw cudaGraph_t instance until capture_end() has been called"); + return graph_; +} + void CUDAGraph::reset() { // I'd prefer these checks throw exceptions, not print warnings, // but the destructor calls reset(), and at least one CI build @@ -253,9 +273,10 @@ void CUDAGraph::reset() { // and the allocator could end up in all kinds of weird states depending where failure occurred. // If the user catches the failure exception in a script, or is running in REPL or (god forbid) // a Jupyter notebook, I don't see an easy way for reset() to gracefully fix all such possible error states. - if (has_graph_ || has_graph_exec_) { + if (capture_ended_) { // notifyCaptureDestroy may throw. How should we handle this? c10::cuda::CUDACachingAllocator::releasePool(capture_dev_, mempool_id_); + capture_ended_ = false; } if (has_graph_) { C10_CUDA_CHECK_WARN(cudaGraphDestroy(graph_)); @@ -269,7 +290,7 @@ void CUDAGraph::reset() { // Returns an id another graph's capture_begin can use to share the same memory pool as this graph. MempoolId_t CUDAGraph::pool() { -TORCH_CHECK(has_graph_exec_, +TORCH_CHECK(capture_ended_, "Called CUDAGraph::pool() without a preceding successful capture."); return mempool_id_; } diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index 76a090579d1dfa..c8cae16b624fe9 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -19,7 +19,7 @@ namespace cuda { TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle(); struct TORCH_CUDA_CPP_API CUDAGraph { - CUDAGraph(); + CUDAGraph(bool keep_graph=false); ~CUDAGraph(); // See Note [Explicit Registration of Generators to the CUDA Graph] @@ -29,21 +29,26 @@ struct TORCH_CUDA_CPP_API CUDAGraph { MempoolId_t pool = {0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal); void capture_end(); + void instantiate(); void replay(); void reset(); MempoolId_t pool(); void enable_debug_mode(); void debug_dump(const std::string& debug_path); + cudaGraph_t raw_cuda_graph(); protected: cudaGraph_t graph_ = nullptr; cudaGraphExec_t graph_exec_ = nullptr; // internal states so reset() can do its best cleaning up + // Set to true in capture_end if cudaStreamEndCapture succeeded - // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate - // to create graph_exec_, then graph_ is deleted + // Set back to false after instantiate() unless keep_graph=True or + // enable_debug_mode() was called on any CUDAGraph instance. bool has_graph_ = false; + // Set to true in capture_end if cudaStreamEndCapture succeeded + bool capture_ended_ = false; // Set to true in capture_end if cudaGraphInstantiate succeeded bool has_graph_exec_ = false; @@ -80,6 +85,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph { // init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1; c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE}; + + bool keep_graph_; }; } // namespace cuda diff --git a/aten/src/ATen/cuda/CachingHostAllocator.h b/aten/src/ATen/cuda/CachingHostAllocator.h index 1379879c3ea4da..b9486314b1c21b 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.h +++ b/aten/src/ATen/cuda/CachingHostAllocator.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace at::cuda { diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 720304ad198e82..32985134144380 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -23,6 +23,9 @@ * To work around this difference in behavior, a separate handle pool is available for ROCm builds. * For CUDA builds, getCurrentCUDABlasLtHandle will alias for getCurrentCUDABlasHandle, * whereas for ROCm builds, it is a distinct function. + * + * The workspace pools are separate for ROCm. On CUDA, the env var + * TORCH_CUBLASLT_UNIFIED_WORKSPACE can be used to opt-in to unifying the workspace pools. */ namespace at::cuda { @@ -109,8 +112,14 @@ std::map, at::DataPtr>& cublas_handle_stream_to_works return instance; } +std::map, at::DataPtr>& cublaslt_handle_stream_to_workspace() { + static auto& instance = *new std::map, at::DataPtr>; + return instance; +} + void clearCublasWorkspaces() { cublas_handle_stream_to_workspace().clear(); + cublaslt_handle_stream_to_workspace().clear(); } size_t parseChosenWorkspaceSize() { @@ -157,15 +166,97 @@ size_t parseChosenWorkspaceSize() { } } +size_t parseCUDABlasLtWorkspaceSize() { + auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); +#ifdef USE_ROCM + if (!val.has_value()) { + // accept either env var + val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); + } + size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ +#else + size_t workspace_size = 1024; /* default size in KiB according to #73328 */ +#endif + + if (val.has_value()) { + try { + workspace_size = std::stoi(val.value()); + } catch (std::invalid_argument const&) { + TORCH_WARN( + "invalid CUBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", + workspace_size, + " KiB."); + } catch (std::out_of_range const&) { + TORCH_WARN( + "CUBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", + workspace_size, + " KiB."); + } + } + return workspace_size * 1024; +} + size_t getChosenWorkspaceSize() { size_t pool_size = parseChosenWorkspaceSize(); return pool_size; } +#define TORCH_CUBLASLT_UNIFIED_WORKSPACE "TORCH_CUBLASLT_UNIFIED_WORKSPACE" + +size_t getCUDABlasLtWorkspaceSize() { + size_t pool_size = parseCUDABlasLtWorkspaceSize(); +#ifndef USE_ROCM + static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true; + if (unified) { + auto cublasWorkspaceSize = getChosenWorkspaceSize(); + if (cublasWorkspaceSize < pool_size) { + TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", pool_size, + " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize, + " bytes. Please increase CUBLAS workspace size", + " via CUBLAS_WORKSPACE_CONFIG or decrease requested" + " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace" + " size will be limited to the CUBLAS workspace size."); + pool_size = cublasWorkspaceSize; + } + } +#endif + return pool_size; +} + at::DataPtr getNewWorkspace() { return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize()); } +at::DataPtr getNewCUDABlasLtWorkspace() { + return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize()); +} + +void* getCUDABlasLtWorkspace() { +#ifndef USE_ROCM + static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true; + if (unified) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + auto stream = c10::cuda::getCurrentCUDAStream(); + cudaStream_t _stream = stream; + auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); + auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); + TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); + return workspace_it->second.mutable_get(); + } +#endif + cublasLtHandle_t handle = getCurrentCUDABlasLtHandle(); + auto stream = c10::cuda::getCurrentCUDAStream(); + cudaStream_t _stream = stream; + auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); + auto workspace_it = cublaslt_handle_stream_to_workspace().find(key); + if (workspace_it == cublaslt_handle_stream_to_workspace().end()) { + workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()}); + } + return workspace_it->second.mutable_get(); +} + cublasHandle_t getCurrentCUDABlasHandle() { c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); @@ -218,7 +309,8 @@ cublasHandle_t getCurrentCUDABlasHandle() { // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup // FP32 data type calculations based on the value of the allow_tf32 flag. // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH. - if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) { + if (!NoTF32Guard::should_disable_tf32() && + at::globalContext().float32Precision("cuda", "matmul") == "tf32") { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); } else { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); diff --git a/aten/src/ATen/cuda/cub.cu b/aten/src/ATen/cuda/cub.cu index 839652f581a516..bc863b8880da7f 100644 --- a/aten/src/ATen/cuda/cub.cu +++ b/aten/src/ATen/cuda/cub.cu @@ -15,8 +15,7 @@ struct SumOp { template void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) { - using NO_ROCM(at_cuda_detail)::cub::Sum; - inclusive_scan(input, output, Sum{}, num_items); + inclusive_scan(input, output, NO_ROCM(::cuda)::std::plus<>{}, num_items); } template void inclusive_sum_truncating(const int32_t *input, int32_t *output, int64_t num_items); @@ -42,8 +41,7 @@ struct CountMaskOp { void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) { CountMaskOp op{}; - auto iter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator< - bool, decltype(op), decltype(mask)>(mask, op); + auto iter = ATEN_CUB_TRANSFORM_ITERATOR(bool, decltype(op), decltype(mask))(mask, op); exclusive_scan(iter, output_idx, SumOp{}, int64_t{0}, n); } diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index e1d452bac4a603..23a3ff8c8958c4 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -6,6 +6,10 @@ #include #include +#ifndef USE_ROCM +#include +#endif + #include #include @@ -51,6 +55,21 @@ #define ROCM_HIPCUB(x) x #endif +#if CUB_V3_PLUS() +#include +#include +#include +#define ATEN_CUB_TRANSFORM_ITERATOR(ValueType, ...) ::thrust::transform_iterator<__VA_ARGS__> +#define ATEN_CUB_COUNTING_ITERATOR(...) ::thrust::counting_iterator<__VA_ARGS__> +#define ATEN_CUB_CONSTANT_ITERATOR(...) ::thrust::constant_iterator<__VA_ARGS__> +#define ATEN_CUB_MAXIMUM() ::cuda::maximum<>() +#else +#define ATEN_CUB_TRANSFORM_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::TransformInputIterator<__VA_ARGS__> +#define ATEN_CUB_COUNTING_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::CountingInputIterator<__VA_ARGS__> +#define ATEN_CUB_CONSTANT_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<__VA_ARGS__> +#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() +#endif + #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) #if !defined(USE_ROCM) @@ -270,7 +289,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT return x.value; } }; - auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( + auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( ArgIndexInputIterator(input + i), input_iter_transform); CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, input_, @@ -425,7 +444,7 @@ __global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int i aggT data[ITEMS_PER_THREAD]; aggT agg_val = 0; TransformFunctor transform_functor; - auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>(d_in, transform_functor); + auto iter_in = ATEN_CUB_TRANSFORM_ITERATOR(aggT, TransformFunctor, const T*)(d_in, transform_functor); for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { BlockLoadT(temp_storage.load).Load(iter_in, data); @@ -568,7 +587,7 @@ inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT i "cub InclusiveSumByKey does not support more than INT_MAX elements"); #if !defined(USE_ROCM) CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey, - keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); + keys, input, output, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream()); #else CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); @@ -581,7 +600,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT "cub InclusiveSumByKey does not support more than INT_MAX elements"); #if !defined(USE_ROCM) CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey, - keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); + keys, input, output, scan_op, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream()); #else CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); diff --git a/aten/src/ATen/cuda/cub_definitions.cuh b/aten/src/ATen/cuda/cub_definitions.cuh index db7cc9120a099e..aad19c6771ed76 100644 --- a/aten/src/ATen/cuda/cub_definitions.cuh +++ b/aten/src/ATen/cuda/cub_definitions.cuh @@ -51,3 +51,11 @@ #else #define CUB_SUPPORTS_FUTURE_VALUE() false #endif + +// There were many bc-breaking changes in major version release of CCCL v3.0.0 +// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html +#if CUB_VERSION >= 300000 +#define CUB_V3_PLUS() true +#else +#define CUB_V3_PLUS() false +#endif diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index c4a425fe359ed2..247fdb2537cb49 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -331,6 +331,16 @@ long CUDAHooks::versionCuDNN() const { #endif } +long CUDAHooks::versionMIOpen() const { +#if AT_ROCM_ENABLED() + return MIOPEN_VERSION_MAJOR * 10000 + + MIOPEN_VERSION_MINOR * 100 + + MIOPEN_VERSION_PATCH; +#else + TORCH_CHECK(false, "Cannot query MIOpen version if ATen_cuda is not built with ROCm"); +#endif +} + long CUDAHooks::versionCUDART() const { #ifdef CUDART_VERSION return CUDART_VERSION; diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 2b4c11136321de..b0dac7a71e809d 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -46,6 +46,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCUDART() const override; long versionCUDART() const override; long versionCuDNN() const override; + long versionMIOpen() const override; std::string showConfig() const override; double batchnormMinEpsilonCuDNN() const override; int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override; diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index ba5672352da115..6f896f1a22bfc5 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -160,7 +160,7 @@ inline std::string ComputeTypeFor() { // ROCBLAS and hipBLASLt. template <> inline std::string ComputeTypeFor() { - if (!at::globalContext().allowTF32CuBLAS()) { + if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") { return "f32_r"; } else { return "xf32_r"; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 4808bb63346891..32fb7c2774fff3 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -381,28 +381,6 @@ static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { return HIPBLAS_OP_T; } -static size_t GetHipblasltWorkspaceSize() { - static const auto env = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); - // 256MB is max workspace size allowed for hipblaslt - // hipblaslt-bench uses 32MB - // recommendation from hipblaslt author was 76MB - // TunableOp hipBLASLt workspace size is aligned with - // PyTorch's default in CUDABlas.cpp (_parseChosenWorkspaceSize) - size_t workspace_size = 76*1024; - if (env) { - try { - workspace_size = std::stoi(env.value()); - } catch(std::invalid_argument const& e) { - TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", - " using default workspace size of ", workspace_size, " KiB."); - } catch(std::out_of_range const& e) { - TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", - " using default workspace size of ", workspace_size, " KiB."); - } - } - return workspace_size * 1024; -} - template struct HipBlasLtDeleter { void operator()(T* x) { @@ -499,7 +477,7 @@ class HipblasltGemmOp : public Callable { } hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; - if (at::globalContext().allowTF32CuBLAS()) { + if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; } HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); @@ -522,6 +500,12 @@ class HipblasltGemmOp : public Callable { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); } +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); @@ -544,7 +528,7 @@ class HipblasltGemmOp : public Callable { } } - size_t workspace_size = GetHipblasltWorkspaceSize(); + size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize(); auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); @@ -569,10 +553,7 @@ class HipblasltGemmOp : public Callable { return FAIL; } - void* workspace_buffer = nullptr; - if (workspace_size > 0) { - workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size); - } + void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace(); TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, matmul.descriptor(), @@ -595,9 +576,6 @@ class HipblasltGemmOp : public Callable { TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); - if (workspace_size > 0) { - c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer); - } return OK; } diff --git a/aten/src/ATen/cuda/tunable/GemmRocblas.h b/aten/src/ATen/cuda/tunable/GemmRocblas.h index 857eddd85d10aa..d7c45dc91c2120 100644 --- a/aten/src/ATen/cuda/tunable/GemmRocblas.h +++ b/aten/src/ATen/cuda/tunable/GemmRocblas.h @@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable> { TuningStatus Call(const GemmParams* params) override { auto input_output_type = RocBlasDataTypeFor(); - if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r) return FAIL; // no support for TF32 in rocBLAS auto compute_type = RocBlasComputeTypeFor(); auto h_a = DoCastForHalfOrBfloat16(params->alpha); @@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable> TuningStatus Call(const GemmStridedBatchedParams* params) override { auto input_output_type = RocBlasDataTypeFor(); - if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r) return FAIL; // no support for TF32 in rocBLAS auto compute_type = RocBlasComputeTypeFor(); auto h_a = DoCastForHalfOrBfloat16(params->alpha); diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md index 6328403360e5c8..b30040b7e28421 100644 --- a/aten/src/ATen/cuda/tunable/README.md +++ b/aten/src/ATen/cuda/tunable/README.md @@ -154,7 +154,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins | PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS | Default is 0, meaning it is not used. | | PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED | Default is 1. Set to 0 to disable. | | PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE | Default (or < 0) is to query L2 cache size. Set to 0 to disable. Otherwise, set to the number of MiB to use for the pool of operator parameters. For example, setting this to the size of your device's memory cache will guarantee that every tuning iteration will use a cold cache. | -| PYTORCH_TUNABLEOP_BLAS_LOG | Default is 0. Set to 1 to enable. Write BLAS paramters to tuning CSV file. | +| PYTORCH_TUNABLEOP_BLAS_LOG | Default is 0. Set to 1 to enable. Write BLAS parameters to tuning CSV file. | ### Python Interface All python APIs exist in the `torch.cuda.tunable` module. diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index d7c32ac2cf3340..2fc1867d276d06 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -156,8 +156,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo default: TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); } - // NOLINTNEXTLINE(*narrowing-conversions) - set(getDataType(t), static_cast(dim), size, filter_format); + set(getDataType(t), static_cast(dim), size, filter_format); } std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) { diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index c356ff57aa55a4..f99e03d156c9b5 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -162,6 +162,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); } + virtual long versionMIOpen() const { + TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP); + } + virtual long versionCUDART() const { TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 01d6281e8afe02..1a16072fd95d5e 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -78,6 +78,9 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { virtual uint32_t acquireEvent(bool enable_timing) const { FAIL_MPSHOOKS_FUNC(__func__); } + Device getDeviceFromPtr(void* data) const override { + TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. "); + } virtual void releaseEvent(uint32_t event_id) const { FAIL_MPSHOOKS_FUNC(__func__); } diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index 6f8e03dd570422..5d0234b5653e71 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -15,11 +15,11 @@ #define DLPACK_EXTERN_C #endif -/*! \brief The current version of dlpack */ -#define DLPACK_VERSION 80 +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 -/*! \brief The current ABI version of dlpack */ -#define DLPACK_ABI_VERSION 1 +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 0 /*! \brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 @@ -40,6 +40,33 @@ #ifdef __cplusplus extern "C" { #endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + /*! * \brief The device type in DLDevice. */ @@ -91,7 +118,7 @@ typedef enum { kDLWebGPU = 15, /*! \brief Qualcomm Hexagon DSP */ kDLHexagon = 16, - /*! \brief Microsoft AI Accelerator */ + /*! \brief Microsoft MAIA devices */ kDLMAIA = 17, } DLDeviceType; @@ -190,6 +217,9 @@ typedef struct { * return size; * } * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. */ void* data; /*! \brief The device of the tensor */ @@ -215,6 +245,13 @@ typedef struct { * not meant to transfer the tensor. When the borrowing framework doesn't need * the tensor, it should call the deleter to notify the host that the resource * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned */ typedef struct DLManagedTensor { /*! \brief DLTensor which is being memory managed */ @@ -223,13 +260,74 @@ typedef struct DLManagedTensor { * which DLManagedTensor is used in the framework. It can also be NULL. */ void * manager_ctx; - /*! \brief Destructor signature void (*)(void*) - this should be called - * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - * if there is no way for the caller to provide a reasonable destructor. - * The destructors deletes the argument self as well. + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. */ void (*deleter)(struct DLManagedTensor * self); } DLManagedTensor; + +// bit masks used in in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +}; + #ifdef __cplusplus } // DLPACK_EXTERN_C #endif diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index cca20e9e553e5c..4b66b30b62e7fb 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -193,6 +193,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(_lu_with_info); OP_DECOMPOSE(matmul); OP_DECOMPOSE(matrix_H); + OP_DECOMPOSE(matrix_exp); OP_DECOMPOSE(matrix_power); OP_DECOMPOSE2(max, other ); OP_DECOMPOSE(max_pool1d); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index c3c85144565988..ee222b4e61a52b 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -126,7 +126,7 @@ SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default -bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index e42f8dd87b5011..3eccc94d3ea60a 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -69,7 +69,7 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error messages. - bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override; void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; c10::intrusive_ptr shallow_copy_and_detach( @@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ DispatchKey::XLA, DispatchKey::CUDA, DispatchKey::CPU, + DispatchKey::PrivateUse1, }); inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { diff --git a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp index 7bc3a3cbfe44ae..ecedc729ccd738 100644 --- a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp +++ b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp @@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) { } static bool is_fused_kernel_acceptable(const Tensor& input, double p) { - return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0; + return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0; } // NB: sure, we could have used different overloads here, but I would feel insecure diff --git a/aten/src/ATen/functorch/TensorWrapper.cpp b/aten/src/ATen/functorch/TensorWrapper.cpp index 4f50a1fe2b4017..65de9268927f09 100644 --- a/aten/src/ATen/functorch/TensorWrapper.cpp +++ b/aten/src/ATen/functorch/TensorWrapper.cpp @@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) { static c10::intrusive_ptr makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr& life_handle) { auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ - DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradPrivateUse1}); auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); key_set = key_set.add(DispatchKey::FuncTorchGradWrapper); return c10::make_intrusive(key_set, tensor, level, life_handle); @@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper( } auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ - DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradPrivateUse1}); auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); key_set = key_set.add(DispatchKey::FuncTorchGradWrapper); auto result = at::detail::make_tensor( diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h index a0ad4a4e1098a0..2eee837cd533d2 100644 --- a/aten/src/ATen/miopen/Descriptors.h +++ b/aten/src/ATen/miopen/Descriptors.h @@ -39,7 +39,7 @@ struct DescriptorDeleter { // function. template // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_CUDA_CPP_API Descriptor { +class TORCH_HIP_CPP_API Descriptor { public: // Use desc() to access the underlying descriptor pointer in // a read-only fashion. Most client code should use this. @@ -65,7 +65,7 @@ class TORCH_CUDA_CPP_API Descriptor { std::unique_ptr> desc_; }; -class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< +class TORCH_HIP_CPP_API TensorDescriptor : public Descriptor< miopenTensorDescriptor, &miopenCreateTensorDescriptor, &miopenDestroyTensorDescriptor> { @@ -88,7 +88,7 @@ class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); -class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< +class TORCH_HIP_CPP_API FilterDescriptor : public Descriptor< miopenTensorDescriptor, &miopenCreateTensorDescriptor, &miopenDestroyTensorDescriptor> { @@ -105,7 +105,7 @@ class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< } }; -struct TORCH_CUDA_CPP_API ConvolutionDescriptor +struct TORCH_HIP_CPP_API ConvolutionDescriptor : public Descriptor< miopenConvolutionDescriptor, &miopenCreateConvolutionDescriptor, @@ -121,7 +121,7 @@ struct TORCH_CUDA_CPP_API ConvolutionDescriptor }; // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_CUDA_CPP_API DropoutDescriptor +struct TORCH_HIP_CPP_API DropoutDescriptor : public Descriptor< miopenDropoutDescriptor, &miopenCreateDropoutDescriptor, @@ -137,7 +137,7 @@ struct TORCH_CUDA_CPP_API DropoutDescriptor } }; -struct TORCH_CUDA_CPP_API RNNDescriptor +struct TORCH_HIP_CPP_API RNNDescriptor : public Descriptor diff --git a/aten/src/ATen/miopen/Handle.h b/aten/src/ATen/miopen/Handle.h index 4c80c3aea65bf6..b1637fca0a5823 100644 --- a/aten/src/ATen/miopen/Handle.h +++ b/aten/src/ATen/miopen/Handle.h @@ -5,5 +5,5 @@ namespace at::native { -TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle(); +TORCH_HIP_CPP_API miopenHandle_t getMiopenHandle(); } // namespace at::native diff --git a/aten/src/ATen/miopen/Types.h b/aten/src/ATen/miopen/Types.h index 0a8a1a952e2e28..fdc0f6a607b71d 100644 --- a/aten/src/ATen/miopen/Types.h +++ b/aten/src/ATen/miopen/Types.h @@ -6,7 +6,7 @@ namespace at::native { -TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); +TORCH_HIP_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); int64_t miopen_version(); diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 03637e7ca65f53..a70ce25108201a 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -1,6 +1,7 @@ // Copyright © 2022 Apple Inc. #pragma once +#include #include #include #include @@ -70,4 +71,8 @@ TORCH_API bool is_available(); TORCH_API bool is_macos_13_or_newer(MacOSVersion version); TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); +inline Device getDeviceFromPtr(void* ptr) { + return {c10::DeviceType::MPS, 0}; +} + } // namespace at::mps diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm index ea807bd6624295..10a7c9191afabd 100644 --- a/aten/src/ATen/mps/MPSFallback.mm +++ b/aten/src/ATen/mps/MPSFallback.mm @@ -93,7 +93,6 @@ static Tensor slow_conv2d_forward_mps(const Tensor& self, m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps); - m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); } } // namespace at diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 17a3d3a68cec72..da5760c9256100 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -18,6 +18,8 @@ struct MPSHooks : public at::MPSHooksInterface { bool hasMPS() const override; bool isOnMacOSorNewer(unsigned major, unsigned minor) const override; + Device getDeviceFromPtr(void* data) const override; + // MPSGeneratorImpl interface const Generator& getDefaultGenerator( DeviceIndex device_index = -1) const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 03c39c957368be..f6133e8877222d 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -129,6 +129,10 @@ at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true); } +Device MPSHooks::getDeviceFromPtr(void* data) const { + return at::mps::getDeviceFromPtr(data); +} + void MPSHooks::waitForEvent(uint32_t event_id) const { at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true); } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 775e1cb04e8482..cfeb67bef3bd9f 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -697,7 +697,7 @@ TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A, auto ndim = A_shape.size(); // L - auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS); + auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true); set_output_strided(0, A_shape, L_strides, A.options(), {}); // info diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index 2ca484f808b26d..674ccf11cfb9b6 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -296,7 +296,7 @@ _scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, std::optional out_dtype, bool use_fast_accum, Tensor& out) { -#if AT_MKLDNN_ENABLED() +#if AT_MKLDNN_ENABLED() && !defined(__powerpc__) if (at::globalContext().userEnabledMkldnn()) { bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type(); if ((!mixed_dtype && cpuinfo_has_x86_amx_int8()) || diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp index fd850846ba619e..d51a119804d282 100644 --- a/aten/src/ATen/native/CPUFallback.cpp +++ b/aten/src/ATen/native/CPUFallback.cpp @@ -98,13 +98,13 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool const auto arguments_begin = stack->size() - num_arguments; std::vector tensor_args; - std::vector tensor_args_indices; + std::vector tensor_args_indices; std::vector> tensorlist_args; - std::vector tensorlist_args_indices; + std::vector tensorlist_args_indices; std::vector>> optional_tensorlist_args; - std::vector optional_tensorlist_args_indices; + std::vector optional_tensorlist_args_indices; std::optional tgt_device = std::nullopt; // save converted cpu tensor for TensorList and optional TensorList diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index ffe3f56e555056..7932e32b428b69 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -446,11 +446,6 @@ struct ConvParams { } } if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { - // bypass dilation checks for channels_last convolution - if (deterministic && is_dilated()) { - // cudnn doesn't support deterministic dilated convolution fully yet - return false; - } if (is_dilated()) { return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big(); } @@ -1179,7 +1174,7 @@ at::Tensor convolution( bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); return at::_convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, - ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN()); + ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv")); } at::Tensor convolution_overrideable( @@ -1324,7 +1319,7 @@ ConvBackend select_conv_backend( params.benchmark = ctx.benchmarkCuDNN(); params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN("conv"); auto input = input_r; auto weight = weight_r; @@ -1710,7 +1705,7 @@ at::Tensor _convolution( c10::MaybeOwned bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt); const Tensor& bias_r = *bias_r_maybe_owned; - return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN()); + return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv")); } std::tuple convolution_backward_overrideable( @@ -2008,7 +2003,7 @@ std::tuple convolution_backward( params.benchmark = ctx.benchmarkCuDNN(); params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN("conv"); // Validate inputs. check_shape_backward(input, weight.sizes(), params); diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp index 7297aaed80d384..820f9fddd78d52 100644 --- a/aten/src/ATen/native/Cross.cpp +++ b/aten/src/ATen/native/Cross.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS @@ -77,6 +78,9 @@ Tensor & cross_out(const Tensor & input, const Tensor & other, const std::option TORCH_IMPL_FUNC(linalg_cross_out) (const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) { + at::assert_no_internal_overlap(out); + at::assert_no_overlap(out, input); + at::assert_no_overlap(out, other); dim = maybe_wrap_dim(dim, input.dim()); auto out_size = out.sizes(); Tensor input_broadcasted = input.expand(out_size); diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h index c442b2232a967f..948a6b8320a4e6 100644 --- a/aten/src/ATen/native/IndexingUtils.h +++ b/aten/src/ATen/native/IndexingUtils.h @@ -5,6 +5,13 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + namespace at::native { [[noreturn]] @@ -15,7 +22,8 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, [[maybe_unused]] static std::vector expandTensors( const Tensor& self, - IOptTensorListRef indices) { + IOptTensorListRef indices, + bool ensure_same_device = false) { // If indices come in as ByteTensor or BoolTensor (masks), expand them into // the equivalent indexing by LongTensors std::vector result; @@ -38,10 +46,19 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, } } // Replace with nonzeros - auto nonzero = index.nonzero(); + at::Tensor nonzero; + if (ensure_same_device && index.device() != self.device()) { + bool non_blocking = index.is_cpu() && self.device().is_cuda(); + auto out = at::empty({0}, index.options().dtype(kLong).pinned_memory(non_blocking)); + nonzero = at::nonzero_out(out, index).to(self.device(), non_blocking); + } else { + nonzero = index.nonzero(); + } for (const auto j : c10::irange(index.dim())) { result.emplace_back(nonzero.select(1, j)); } + } else if (ensure_same_device && index.device() != self.device()) { + result.emplace_back(index.to(self.device())); } else { result.emplace_back(index); } diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index adb538da6cc3c8..5d3a84ea39f6de 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -92,9 +93,10 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { // Also hit the fused path for contiguous 3D input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. - if (input.is_contiguous() && input_dim == 3) { + bool is_contiguous = input.is_contiguous_or_false(); + if (is_contiguous && input_dim == 3) { return _flatten_nd_linear(input, weight, *bias); - } else if (input.is_contiguous() && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { + } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { return _flatten_nd_linear(input, weight, *bias); } else if (parseLinearFlatten3d() && input_dim == 3) { // If user forces flattening via env var @@ -152,8 +154,8 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra Tensor left = left_; Tensor right = right_; for (const auto i : c10::irange(dim)) { - auto sl = TORCH_GUARD_SIZE_OBLIVIOUS(left.sym_size(i).sym_ne(1)); - auto sr = TORCH_GUARD_SIZE_OBLIVIOUS(right.sym_size(i).sym_ne(1)); + auto sl = TORCH_GUARD_OR_TRUE(left.sym_size(i).sym_ne(1)); + auto sr = TORCH_GUARD_OR_TRUE(right.sym_size(i).sym_ne(1)); if (sum_dims[i]) { // first dimensions that will be summed over after multiplication if (sl && sr) { // dimensions nontrivially in both left and right must be of the same size TORCH_SYM_CHECK(left.sym_size(i).sym_eq(right.sym_size(i)), "non-broadcast dimensions must match"); @@ -475,7 +477,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr // Iterate over each dimension covered by ellipsis const auto ndim = operands[i].ndimension() - (static_cast(op_labels[i].size()) - 1); for (auto j = ell_num_dim - ndim; j < ell_num_dim; ++j) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(op.sym_size(dim).sym_ne(1))) { + if (TORCH_GUARD_OR_TRUE(op.sym_size(dim).sym_ne(1))) { // Update ellipsis size TORCH_SYM_CHECK( ell_sizes[j].sym_eq(1).sym_or(ell_sizes[j].sym_eq(op.sym_size(dim))), @@ -494,7 +496,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr permutation[ell_index + j] = dim++; } } else if (permutation[label_perm_index[s]] == -1) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(op.sym_size(dim).sym_ne(1))) { + if (TORCH_GUARD_OR_TRUE(op.sym_size(dim).sym_ne(1))) { // Update subscript TORCH_SYM_CHECK( label_size[s].sym_eq(1).sym_or(label_size[s].sym_eq(op.sym_size(dim))), @@ -572,17 +574,22 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr SmallVector a_dims_to_sum; SmallVector b_dims_to_sum; for (auto dim = out_num_dim; dim < perm_index; ++dim) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(a.sym_size(dim).sym_ne(1)) - && TORCH_GUARD_SIZE_OBLIVIOUS(b.sym_size(dim).sym_ne(1))) { + auto sa = TORCH_GUARD_OR_TRUE(a.sym_size(dim).sym_ne(1)); + auto sb = TORCH_GUARD_OR_TRUE(b.sym_size(dim).sym_ne(1)); + + if (sa && sb) { + // if both a and b are equal, or we can't tell that its a broadcast for sure, + // we assume non-broadcast. + TORCH_SYM_CHECK(a.sym_size(dim).sym_eq(b.sym_size(dim)), "non-broadcast dimensions must match"); if (--dim_counts[dim] == 1) { sum_dims.push_back(dim); dim_counts[dim] = 0; } } else if (dim_counts[dim] == 1) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(a.sym_size(dim).sym_ne(1))) { + if (sa) { a_dims_to_sum.push_back(dim); dim_counts[dim] = 0; - } else if (TORCH_GUARD_SIZE_OBLIVIOUS(b.sym_size(dim).sym_ne(1))) { + } else if (sb) { b_dims_to_sum.push_back(dim); dim_counts[dim] = 0; } diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index c9e3ab9e8bc273..265bc112adcc2e 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -127,6 +127,9 @@ TORCH_IMPL_FUNC(smooth_l1_loss_out) TORCH_IMPL_FUNC(mse_loss_out) (const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& result) { + TORCH_CHECK(input.device() == target.device(), + "Expected all tensors to be on the same device, but found at least two devices, ", + input.device(), " and ", target.device(), "!"); if (reduction != Reduction::None) { Tensor loss; auto iter = TensorIterator::borrowing_binary_op(loss, input, target); diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index b6d44fca5901c9..a372c5f0c7e549 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -2862,7 +2862,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { T q = x; T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -2910,7 +2910,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { T q = x + x; T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -2966,7 +2966,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { T q = x + x - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -3026,7 +3026,7 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { T q = x + x + T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -3150,7 +3150,7 @@ inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { T q = T(1.0) - x; T r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); p = q; q = r; @@ -3190,7 +3190,7 @@ inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { T q = x; T r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { r = ((k + k + 1) * x * q - k * p) / (k + 1); p = q; q = r; @@ -3733,7 +3733,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) T q = x + x - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3785,7 +3785,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3841,7 +3841,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3897,7 +3897,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index fb4ce917bf1654..ecad7d7f341970 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -521,17 +521,17 @@ BatchNormBackend _select_batch_norm_backend( } if ( - input.is_cuda() + detail::getCUDAHooks().compiledWithMIOpen() + && cudnn_enabled + && input.is_cuda() && input.dim() <= MIOPEN_DIM_MAX + && input.dim() >= 3 && input.scalar_type() != at::kDouble - && input.scalar_type() != at::kBFloat16 - && (weight.scalar_type() != at::kHalf) + && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16) + && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input && weight.defined() && bias.defined() && ((running_mean.defined() && running_var.defined()) || (!running_mean.defined() && !running_var.defined() && training)) - && (input.dim() >= 3) - && detail::getCUDAHooks().compiledWithMIOpen() - && cudnn_enabled && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d ) { diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 5ecc0f1593315d..24b745b1a68b00 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -157,12 +157,8 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te auto xend = end.to(); auto xstep = step.to(); - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and lower bound inconsistent with step sign"); + arange_check_bounds(start, end, step); + int64_t size = static_cast(((xend - xstart) / xstep) + 1); if (result.numel() != size) { result.resize_({size}); diff --git a/aten/src/ATen/native/RangeUtils.h b/aten/src/ATen/native/RangeUtils.h index d3ad1c6ab7df77..dcab86ca9a42c4 100644 --- a/aten/src/ATen/native/RangeUtils.h +++ b/aten/src/ATen/native/RangeUtils.h @@ -6,19 +6,30 @@ namespace at::native { +inline void arange_check_bounds( + const c10::Scalar& start, + const c10::Scalar& end, + const c10::Scalar& step) { + // use double precision for validation to avoid precision issues + double dstart = start.to(); + double dend = end.to(); + double dstep = step.to(); + + TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero"); + TORCH_CHECK( + std::isfinite(dstart) && std::isfinite(dend), + "unsupported range: ", + dstart, + " -> ", + dend); + TORCH_CHECK( + ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)), + "upper bound and lower bound inconsistent with step sign"); +} + template int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { - using accscalar_t = at::acc_type; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); - - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); // we use double precision for (start - end) / step // to compute size_d for consistency across devices. @@ -29,6 +40,10 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar // the corner-case we do want to take into account is int64_t, which has higher precision than double double size_d; if constexpr (std::is_same_v) { + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); int64_t sgn = (xstep > 0) - (xstep < 0); size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); } else { diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index b5470727a9b483..d5c661c158d3f6 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1451,7 +1451,7 @@ Tensor& nanmean_out( "nanmean(): expected input to have floating point or complex dtype but got ", self.scalar_type()); const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim); - at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor); + at::nansum_out(result, self, dim, keepdim, opt_dtype).div_(factor); return result; } diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp index fe1db473ea4cba..11b528b445ed42 100644 --- a/aten/src/ATen/native/Repeat.cpp +++ b/aten/src/ATen/native/Repeat.cpp @@ -74,7 +74,7 @@ Tensor repeat_interleave_symint( } Tensor repeats_ = repeats; - if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) { + if (repeats.dim() == 0 || (repeats.dim() == 1 && TORCH_GUARD_OR_FALSE(repeats.sym_size(0).sym_eq(1)))) { repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())}); } else if (repeats.dim() == 1) { TORCH_CHECK( diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 9111e4a0800737..3346cd2cb220e4 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -101,7 +101,7 @@ inline void checkInBoundsForStorage( // It's ok to always evaluate to False for this early return for SymInts because // (1) maybe_convert_symint below only installs guard for int64_t case // (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(storage_size_bytes, 0))) { + if (TORCH_GUARD_OR_FALSE(sym_eq(storage_size_bytes, 0))) { // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel. return; } @@ -138,7 +138,7 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, // storageOffset TORCH_CHECK( - storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset); + TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset); // set_storage_{device} (except set_storage_meta__symint) // will (unsafely) set the storage offset and then call resize_impl that diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 8b71533b370b65..67c0af9212bc77 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -2153,81 +2153,53 @@ static void _scatter_via_index_put( const Tensor& src, const Tensor& mut_out, bool accumulate) { - if (self.dim() == 1) { - torch::List> indices; - indices.reserve(1); - indices.push_back(index); - mut_out.index_put_(indices, src, accumulate); - } else { - Tensor mut_out_contig = mut_out.contiguous(); - - auto index_coords_sizes = index.sizes().vec(); - index_coords_sizes.push_back(self.dim()); - auto index_coords = at::empty( - index_coords_sizes, - at::TensorOptions().dtype(at::ScalarType::Long).device(self.device())); + // If index is expanded with zero strides across non-scatter dimensions, + // advanced indexing with the index tensor alone achieves the desired + // semantics and avoids creating large intermediate tensors. + bool broadcast_index = true; + for (const auto i : c10::irange(index.dim())) { + if (i == dim) { + continue; + } + if (index.stride(i) != 0) { + broadcast_index = false; + break; + } + } - for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) { - if (dim_other == dim) { - continue; - } - auto dim_coord_vals = at::arange( - index.size(dim_other), at::TensorOptions().device(self.device())); + auto src_view = at::as_strided(src, index.sizes(), src.strides()); + torch::List> indices; + indices.reserve(self.dim()); - for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; - dim_unsqueeze++) { - dim_coord_vals = - dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0); + if (self.dim() == 1 || broadcast_index) { + Tensor squeezed = index; + if (broadcast_index && index.dim() > 1) { + for (const auto d : c10::irange(index.dim())) { + if (d == dim) { + continue; + } + squeezed = squeezed.select(d, 0); } - - auto view_sizes = index.sizes().vec(); - view_sizes.push_back(1); - auto view_strides = index_coords.strides().vec(); - view_strides[self.dim()] = self.dim(); - - at::as_strided(index_coords, view_sizes, view_strides, dim_other) - .copy_(dim_coord_vals.unsqueeze(-1)); } + for ([[maybe_unused]] const auto d : c10::irange(dim)) { + indices.push_back(Tensor()); + } + indices.push_back(squeezed); + mut_out.index_put_(indices, src_view, accumulate); + return; + } - auto view_sizes = index.sizes().vec(); - view_sizes.push_back(1); - auto view_strides = index_coords.strides().vec(); - view_strides[self.dim()] = self.dim(); - - at::as_strided(index_coords, view_sizes, view_strides, dim) - .copy_(index.unsqueeze(-1)); - - Tensor index_coords_flat = index_coords.flatten(0, -2); - - // Copy mut_out_contig's strides into a tensor - // TODO: Is there a utility function that already does this? - IntArrayRef mut_out_contig_strides = mut_out_contig.strides(); - Tensor coord_strides = at::empty( - {mut_out_contig.dim()}, - TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU)); - std::memcpy( - coord_strides.mutable_data_ptr(), - mut_out_contig_strides.data(), - coord_strides.nbytes()); - coord_strides = coord_strides.to(mut_out_contig.device()); - - // `index_flat` contains the 1-D indices corresponding with the - // flattened `mut_out` - Tensor index_flat = (index_coords_flat * coord_strides).sum({-1}); - Tensor mut_out_flat = mut_out_contig.flatten(); - Tensor src_flat = - at::as_strided(src, index.sizes(), src.strides()).flatten(); - - torch::List> indices; - indices.reserve(1); - indices.push_back(index_flat); - - mut_out_flat.index_put_(indices, src_flat, accumulate); - - if (!mut_out.is_contiguous()) { - mut_out.copy_(mut_out_flat.reshape(mut_out.sizes())); + for (const auto d : c10::irange(self.dim())) { + if (d == dim) { + indices.push_back(index); + } else { + auto arange = at::arange(index.size(d), index.options()); + std::vector shape(index.dim(), 1); + shape[d] = index.size(d); + indices.push_back(arange.view(shape).expand(index.sizes())); } } + mut_out.index_put_(indices, src_view, accumulate); } template < diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index 05009e96a7c4a9..0a200f157d5112 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -71,7 +71,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { checkIndexTensorTypes(orig, /*allow_int*/ true); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more // LongTensors - auto indices = expandTensors(self, orig); + auto indices = expandTensors(self, orig, /*ensure_same_device=*/true); // next broadcast all index tensors together try { indices = expand_outplace(indices); @@ -91,12 +91,6 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { if (!hasContiguousSubspace(indices)) { std::tie(self, indices) = transposeToFront(self, indices); } - // Ensure indices are on the same device as self - for (auto& indice : indices) { - if (indice.defined() && indice.device() != self.device()) { - indice = indice.to(self.device()); - } - } for (auto& indice : indices) { if (indice.defined() && indice.dtype() == at::kInt) { indice = indice.to(at::kLong); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 37b46b5a384648..d9a42da482c02e 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -89,6 +89,16 @@ static inline void check_for_unsupported_isin_dtype(const ScalarType type) { type); } +static inline void check_for_unsupported_clamp_dtypes(ScalarType dtype) { + TORCH_CHECK_NOT_IMPLEMENTED( + !isComplexType(dtype), "clamp is not supported for complex types"); +} + +static inline void check_for_unsupported_clamp_dtypes(const Scalar& s) { + TORCH_CHECK_NOT_IMPLEMENTED( + !s.isComplex(), "clamp is not supported for complex types"); +} + TORCH_META_FUNC(clamp) (const Tensor& self, const OptionalScalarRef min, const OptionalScalarRef max) { if (!min && !max) { @@ -96,9 +106,8 @@ TORCH_META_FUNC(clamp) false, "torch.clamp: At least one of 'min' or 'max' must not be None"); } // Manual type promotion, since scalars have to participate in it - ScalarType result_type = self.scalar_type(); - TORCH_CHECK( - !isComplexType(result_type), "clamp is not supported for complex types"); + auto result_type = self.scalar_type(); + check_for_unsupported_clamp_dtypes(result_type); // Floating is the highest supported if (!isFloatingType(result_type)) { at::native::ResultTypeState state = {}; @@ -122,8 +131,7 @@ TORCH_META_FUNC(clamp) self.dtype()); } // make sure scalars weren't complex - TORCH_CHECK( - !isComplexType(result_type), "clamp is not supported for complex types"); + check_for_unsupported_clamp_dtypes(result_type); build_unary_op(maybe_get_output(), self.to(result_type)); } @@ -132,9 +140,7 @@ TORCH_META_FUNC2(clamp, Tensor) TORCH_CHECK( min || max, "torch.clamp: At least one of 'min' or 'max' must not be None"); - TORCH_CHECK( - !isComplexType(self.scalar_type()), - "clamp is not supported for complex types"); + check_for_unsupported_clamp_dtypes(self.scalar_type()); #define CLAMP_CONFIG() \ TensorIteratorConfig() \ .set_check_mem_overlap(true) \ @@ -157,10 +163,9 @@ TORCH_META_FUNC(clamp_max)(const Tensor& self, const Scalar& max) { // we could wrap max into tensor and send to tensor overload, // but relu is implemented via clamp_min, so for perf an uniformity reasons // do a faster but correct thing - ScalarType result_type = self.scalar_type(); - TORCH_CHECK( - !isComplexType(result_type), "clamp is not supported for complex types"); - TORCH_CHECK(!max.isComplex(), "clamp is not supported for complex types"); + auto result_type = self.scalar_type(); + check_for_unsupported_clamp_dtypes(result_type); + check_for_unsupported_clamp_dtypes(max); // Floating is the highest supported if (!isFloatingType(result_type)) { auto result_type = at::native::result_type(self, max); @@ -183,10 +188,9 @@ TORCH_META_FUNC2(clamp_max, Tensor)(const Tensor& self, const Tensor& max) { } TORCH_META_FUNC(clamp_min)(const Tensor& self, const Scalar& min) { - ScalarType result_type = self.scalar_type(); - TORCH_CHECK( - !isComplexType(result_type), "clamp is not supported for complex types"); - TORCH_CHECK(!min.isComplex(), "clamp is not supported for complex types"); + auto result_type = self.scalar_type(); + check_for_unsupported_clamp_dtypes(result_type); + check_for_unsupported_clamp_dtypes(min); // Floating is the highest supported if (!isFloatingType(result_type)) { auto result_type = at::native::result_type(self, min); diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 006e54c2495a0d..1aab4b11c9634e 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1212,6 +1212,28 @@ Tensor randint_like( return result.random_(0, high, std::nullopt); } +Tensor randint_like( + const Tensor& self, + const Tensor& high, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + TORCH_CHECK( + high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(), + "high must be a scalar tensor and on CPU"); + int64_t high_scalar = high.item(); + return at::native::randint_like( + self, + high_scalar, + dtype, + layout, + device, + pin_memory, + optional_memory_format); +} + Tensor randint_like( const Tensor& self, int64_t low, diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 5a4d55e0e3cb42..77acfe47363e4f 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -113,7 +113,7 @@ Tensor& detach_(Tensor& self) { } Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { - if (self.is_contiguous(memory_format)) { + if (self.is_contiguous_or_false(memory_format)) { return self; } TORCH_CHECK( diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index f04e6cac631fda..958c80f2c7f0b8 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -431,7 +431,7 @@ Tensor& set_storage_meta__symint( size, stride, storage_offset); // Matches maybe_resize_storage_cpu no-numel behavior - if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) { + if (TORCH_GUARD_OR_TRUE(result.sym_numel().sym_ne(0))) { // maybe_resize_storage_cpu can handle no storage exists at all but // that should never be the case here TORCH_INTERNAL_ASSERT(storage); @@ -440,12 +440,7 @@ Tensor& set_storage_meta__symint( // All meta data pointers are the same, so we don't have to "re" allocate // it. TODO: Actually this might not quite be correct if we use special // pointers to track whether or not fake cuda tensors are pinned or not - const auto itemsize = result.dtype().itemsize(); - c10::SymInt new_size_bytes = result.is_contiguous() - ? at::detail::computeStorageNbytesContiguous( - size, itemsize, std::move(storage_offset)) - : at::detail::computeStorageNbytes( - size, stride, itemsize, std::move(storage_offset)); + // TODO: When there are unbacked SymInts, we unconditionally skip the // setter. This is technically wrong, but we cannot conveniently test // the real condition in many cases, because a lot of people are using @@ -454,10 +449,19 @@ Tensor& set_storage_meta__symint( // // The old behavior was to unconditionally set_nbytes, but I think not // setting it is more safe. - if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && - TORCH_GUARD_SIZE_OBLIVIOUS( - new_size_bytes.sym_gt(storage.sym_nbytes()))) { - storage.set_nbytes(std::move(new_size_bytes)); + if (result.sym_numel().has_hint()) { + const auto itemsize = result.dtype().itemsize(); + + c10::SymInt new_size_bytes = result.is_contiguous() + ? at::detail::computeStorageNbytesContiguous( + size, itemsize, std::move(storage_offset)) + : at::detail::computeStorageNbytes( + size, stride, itemsize, std::move(storage_offset)); + + if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && + (new_size_bytes > storage.sym_nbytes())) { + storage.set_nbytes(std::move(new_size_bytes)); + } } } return result; @@ -758,22 +762,22 @@ TORCH_IMPL_FUNC(cat_out_cpu) } Tensor& cat_out(TensorList tensors, Dimname dim, Tensor& result) { - TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors"); + TORCH_CHECK_VALUE(!tensors.empty(), "expected a non-empty list of Tensors"); return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim)); } Tensor cat(TensorList tensors, Dimname dim) { - TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors"); + TORCH_CHECK_VALUE(!tensors.empty(), "expected a non-empty list of Tensors"); return at::cat(tensors, dimname_to_position(tensors[0], dim)); } // torch.concat, alias for torch.cat Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) { - return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim)); + return cat_out(tensors, dim, result); } Tensor concat(TensorList tensors, Dimname dim) { - return at::cat(tensors, dimname_to_position(tensors[0], dim)); + return at::cat(tensors, dim); } Tensor& concat_out(TensorList tensors, int64_t dim, Tensor& result) { @@ -786,11 +790,11 @@ Tensor concat(TensorList tensors, int64_t dim) { // torch.concatenate, alias for torch.cat Tensor& concatenate_out(TensorList tensors, Dimname dim, Tensor& result) { - return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim)); + return cat_out(tensors, dim, result); } Tensor concatenate(TensorList tensors, Dimname dim) { - return at::cat(tensors, dimname_to_position(tensors[0], dim)); + return at::cat(tensors, dim); } Tensor& concatenate_out(TensorList tensors, int64_t dim, Tensor& result) { @@ -1994,19 +1998,18 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } - auto sym_sizes = self.sym_sizes(); - auto sym_strides = self.sym_strides(); - auto sym_numel = self.sym_numel(); - if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) && - !self.is_mkldnn()) { + if (self.is_contiguous_or_false() && !self.is_mkldnn()) { return self.view_symint(proposed_shape); } + auto sym_numel = self.sym_numel(); c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); if (self.is_mkldnn()) { return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); } + auto sym_sizes = self.sym_sizes(); + auto sym_strides = self.sym_strides(); // `computeStride` returns the proper strides to use if this // `reshape` can be just a view. diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 420a81767fbadc..f849283043d376 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -258,26 +258,26 @@ TORCH_META_FUNC(neg)(const Tensor& self) { TORCH_META_FUNC(trunc) (const Tensor& self) { // Note: this is consistent with NumPy - TORCH_CHECK(!self.is_complex(), + TORCH_CHECK_NOT_IMPLEMENTED(!self.is_complex(), "trunc is not supported for complex inputs"); build_borrowing_unary_op(maybe_get_output(), self); } TORCH_META_FUNC(floor) (const Tensor& self) { // Note: this is consistent with NumPy - TORCH_CHECK(!self.is_complex(), + TORCH_CHECK_NOT_IMPLEMENTED(!self.is_complex(), "floor is not supported for complex inputs"); build_borrowing_unary_op(maybe_get_output(), self); } TORCH_META_FUNC(sign) (const Tensor& self) { - TORCH_CHECK(!self.is_complex(), + TORCH_CHECK_NOT_IMPLEMENTED(!self.is_complex(), "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead."); build_borrowing_unary_op(maybe_get_output(), self); } TORCH_META_FUNC(signbit) (const Tensor& self) { - TORCH_CHECK(!self.is_complex(), "signbit is not implemented for complex tensors."); + TORCH_CHECK_NOT_IMPLEMENTED(!self.is_complex(), "signbit is not implemented for complex tensors."); TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true, "signbit does not support non-boolean outputs."); build_borrowing_unary_force_boolean_op(maybe_get_output(), self); @@ -285,7 +285,7 @@ TORCH_META_FUNC(signbit) (const Tensor& self) { TORCH_META_FUNC(ceil) (const Tensor& self) { // Note: this is consistent with NumPy - TORCH_CHECK(!self.is_complex(), + TORCH_CHECK_NOT_IMPLEMENTED(!self.is_complex(), "ceil is not supported for complex inputs"); build_borrowing_unary_op(maybe_get_output(), self); } diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index ffa0b6c4f2b41e..47e2ef12b3b49e 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -106,7 +106,6 @@ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional), random_ DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub) DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub) -DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub) DECLARE_DISPATCH( void (*)(Tensor&, const Tensor&, int64_t, std::optional), multinomial_with_replacement_stub) diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 82e7dfd213f58f..ab3b16c395a3cc 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -100,6 +100,7 @@ auto sum(int64_t N, Func f) { } template +__ubsan_ignore_signed_int_overflow__ std::enable_if_t, void> gemm_notrans_( int64_t m, @@ -117,18 +118,19 @@ gemm_notrans_( scale_(m, n, beta, c, ldc); // c += alpha * (a @ b) - for (const auto l : c10::irange(k)) { - for (const auto j : c10::irange(n)) { + const uint64_t unsigned_m = static_cast(m); + const uint64_t i_m = unsigned_m / 4; + for (const uint64_t l : c10::irange(k)) { + for (const uint64_t j : c10::irange(n)) { opmath_t val = b[l + j * ldb] * alpha; - int64_t i_m = m / 4; for (const auto i_i : c10::irange(i_m)) { c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val; c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val; c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val; c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val; } - int64_t i = i_m * 4; - for (; i < m; i++) + uint64_t i = i_m * 4; + for (; i < unsigned_m; i++) c[j * ldc + i] += a[i + l * lda] * val; } } diff --git a/aten/src/ATen/native/cpu/Elu.h b/aten/src/ATen/native/cpu/Elu.h index acce2756f85ab9..d438d4303709b1 100644 --- a/aten/src/ATen/native/cpu/Elu.h +++ b/aten/src/ATen/native/cpu/Elu.h @@ -11,6 +11,7 @@ #include // For c10::is_reduced_floating_point_v. namespace at::native { +inline namespace CPU_CAPABILITY { /** * Return a function object that calculates ELU with the given * parameters on its input element. ParamT is the type of the input @@ -24,7 +25,7 @@ auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale const auto poscoef = scale; const auto negiptcoef = input_scale; return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT { - return MathT(a) <= MathT(0) + return MathT(a) < MathT(0) ? std::expm1(MathT(a) * negiptcoef) * negcoef : MathT(a) * poscoef; }; @@ -42,7 +43,7 @@ auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) { const vec::Vectorized negiptcoef_vec(input_scale); const vec::Vectorized zero_vec(static_cast(0)); return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized a) -> vec::Vectorized { - const auto cmp = a > zero_vec; + const auto cmp = a >= zero_vec; if (!cmp.zero_mask()) { return a * poscoef_vec; } else { @@ -69,4 +70,5 @@ auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_s return vec::convert_from_float(res0, res1); }; } +} // namespace CPU_CAPABILITY } // namespace at::native diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index 3d323ba2838331..ba03472da3351b 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -201,6 +201,9 @@ void reshape_attn_mask_to_4d( attn_mask = attn_mask .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); + if (attn_mask.sym_stride(-1) != 1 && attn_mask.sym_stride(-1) != 0) { + attn_mask = attn_mask.contiguous(); + } } template diff --git a/aten/src/ATen/native/cpu/Gelu.h b/aten/src/ATen/native/cpu/Gelu.h index 637d7e3d8e1dda..613c69e225864e 100644 --- a/aten/src/ATen/native/cpu/Gelu.h +++ b/aten/src/ATen/native/cpu/Gelu.h @@ -12,6 +12,7 @@ #include // For c10::is_reduced_floating_point_v. namespace at::native { +inline namespace CPU_CAPABILITY { constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5; constexpr double kGeluKappa = 0.044715; @@ -78,5 +79,5 @@ vec::Vectorized vectorized_gelu(vec::Vectorized x) { return at::vec::convert_from_float(vectorized_gelu(x0), vectorized_gelu(x1)); } - -} // namespace +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h b/aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h new file mode 100644 index 00000000000000..b8af353e8866cf --- /dev/null +++ b/aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h @@ -0,0 +1,337 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { +template +int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) { + // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the + // size of L1D cache on many processors. Some processors have 48 KB L1D cache + // nowadays, so maybe in the future, we can leverage the knowledge of a + // machine's L1D cache size. + int64_t MAX_CHUNK_SIZE = std::max( + 1, + grain_size / (sizeof(scalar_t) * dim_size)); + return std::min(MAX_CHUNK_SIZE, outer_size); +} + +template +void serial_vec_log_softmax_lastdim_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t dim_size, + int64_t chunk_size, + int64_t begin, + int64_t end) { + if (end <= begin) { + return; + } + using Vec = vec::Vectorized>; + // MSVC requires such a declaration of dynamic arrays + // Source: https://stackoverflow.com/a/33423538 + auto tmp_sum_scalar = std::make_unique(chunk_size); + auto max_input_arr = std::make_unique(chunk_size); + for (int64_t ii = begin; ii < end; ii += chunk_size) { + int64_t loop_end = chunk_size; + if (ii + chunk_size > end) { + loop_end = end - ii; + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + max_input_arr[j] = vec::reduce_all( + [](Vec& x, Vec& y) { return vec::maximum(x, y); }, + input_data, + dim_size); + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t max_input = max_input_arr[j]; + tmp_sum_scalar[j] = vec::map_reduce_all( + [max_input](Vec x) { return (x - Vec(max_input)).exp(); }, + [](Vec x, Vec y) { return x + y; }, + input_data, + dim_size); + } + // See [Note AVX-SSE transitions] for why this should call the + // vectorized version (aside from perf improvements). + vec::map( + [](Vec x) { return x.log(); }, + tmp_sum_scalar.get(), + tmp_sum_scalar.get(), + loop_end); + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t* output_data = output_data_base + i * dim_size; + scalar_t tmp_sum = tmp_sum_scalar[j]; + scalar_t max_input = max_input_arr[j]; + + // It's necessary to keep the order of the operations below. + // In some cases that input is large digits and the difference + // is small, if we compute `max_input` plus `tmp_sum` before, + // there would be a numerical problem. See an example in + // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379 + vec::map( + [tmp_sum, max_input](Vec x) { + return x - Vec(max_input) - Vec(tmp_sum); + }, + output_data, + input_data, + dim_size); + } + } +} + +// Can't include ATen/Parallel.h. +// TODO: find a way to have only one copy of divup. +inline int64_t divup(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +template +std::pair vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) { + using Vec = vec::Vectorized; + int64_t MAX_CHUNK_SIZE = std::max(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size()); + MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size(); + int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, inner_size); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + return {CHUNK_SIZE, num_chunks}; +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + // thread local temp buffer which holds vertical reduction result: max and sum. + auto buffer = std::make_unique(chunk_size * 2); + scalar_t* input_max_data = buffer.get(); + scalar_t* tmp_sum_data = buffer.get() + chunk_size; + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + // init + Vec zero_vec = Vec(scalar_t(0)); + Vec min_vec = Vec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_vec.store(input_max_data + d0); + zero_vec.store(tmp_sum_data + d0); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = scalar_t(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + Vec max_vec = Vec::loadu(input_max_data + d1); + max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec); + max_vec.store(input_max_data + d1); + } + for (; d1 < size; d1++) { + scalar_t data_val = input_ptr[d1]; + scalar_t max_val = input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? data_val : max_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d2); + Vec sum_vec = Vec::loadu(tmp_sum_data + d2); + Vec max_vec = Vec::loadu(input_max_data + d2); + sum_vec += (data_vec - max_vec).exp(); + sum_vec.store(tmp_sum_data + d2); + } + for (; d2 < size; d2++) { + scalar_t data_val = input_ptr[d2]; + scalar_t max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin; + const scalar_t* input_ptr = input_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d3); + Vec max_vec = Vec::loadu(input_max_data + d3); + Vec sum_vec = Vec::loadu(tmp_sum_data + d3); + Vec out_vec = data_vec - max_vec - sum_vec; + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]; + } + } + } +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized; + auto buffer = std::make_unique(chunk_size * 2); + float* input_max_data = buffer.get(); + float* tmp_sum_data = buffer.get() + chunk_size; + + // thread local buffer that holds input data in float32 to save next 2 dtype conversion + auto input_buffer = std::make_unique(dim_size * chunk_size); + float* input_buffer_data = input_buffer.get(); + + // init + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + fVec zero_fvec = fVec(float(0)); + fVec min_fvec = fVec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_fvec.store(input_max_data + d0); + min_fvec.store(input_max_data + d0 + fVec::size()); + zero_fvec.store(tmp_sum_data + d0); + zero_fvec.store(tmp_sum_data + d0 + fVec::size()); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = float(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + auto [data_fvec0, data_fvec1] = vec::convert_to_float(data_vec); + fVec max_fvec0 = fVec::loadu(input_max_data + d1); + fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size()); + max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0); + max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1); + max_fvec0.store(input_max_data + d1); + max_fvec1.store(input_max_data + d1 + fVec::size()); + + // cache the 'converted' float input + data_fvec0.store(input_buffer_ptr + d1); + data_fvec1.store(input_buffer_ptr + d1 + fVec::size()); + } + for (; d1 < size; d1++) { + float data_val = float(input_ptr[d1]); + float max_val = input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? data_val : max_val; + input_buffer_ptr[d1] = data_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d2); + fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size()); + sum_fvec0 += (data_fvec0 - max_fvec0).exp(); + sum_fvec1 += (data_fvec1 - max_fvec1).exp(); + sum_fvec0.store(tmp_sum_data + d2); + sum_fvec1.store(tmp_sum_data + d2 + fVec::size()); + } + for (; d2 < size; d2++) { + float data_val = input_buffer_ptr[d2]; + float max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d3); + fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size()); + fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0; + fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1; + Vec out_vec = vec::convert_from_float(out_fvec0, out_fvec1); + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]); + } + } + } +} // namespace CPU_CAPABILITY +}} // namespace at::native diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 7855191d6c0674..317647123d4c03 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -2,6 +2,8 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include + #include #include #include @@ -28,7 +30,6 @@ // We use a chunk size such that it'd fit in L1D. namespace at::native { - namespace { template inline void _vec_log_softmax_lastdim( @@ -36,15 +37,10 @@ inline void _vec_log_softmax_lastdim( scalar_t* output_data_base, int64_t outer_size, int64_t dim_size) { - using Vec = vec::Vectorized>; - // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the - // size of L1D cache on many processors. Some processors have 48 KB L1D cache - // nowadays, so maybe in the future, we can leverage the knowledge of a - // machine's L1D cache size. - int64_t MAX_CHUNK_SIZE = std::max( - 1, - at::internal::GRAIN_SIZE / (sizeof(scalar_t) * dim_size)); - int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, outer_size); + const auto chunk_size = vec_log_softmax_lastdim_chunk_size( + at::internal::GRAIN_SIZE, + outer_size, + dim_size); // Note: grain_size value of 0 // We don't change the number of OpenMP threads in the OpenMP thread-pool, // so some threads do useful work, while others don't. @@ -52,60 +48,13 @@ inline void _vec_log_softmax_lastdim( // work among threads in an equitable manner. We compute CHUNK_SIZE to ensure // each thread's computations would be efficient. parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) { - // MSVC requires such a declaration of dynamic arrays - // Source: https://stackoverflow.com/a/33423538 - auto tmp_sum_scalar = std::make_unique(CHUNK_SIZE); - auto max_input_arr = std::make_unique(CHUNK_SIZE); - for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) { - int64_t loop_end = CHUNK_SIZE; - if (ii + CHUNK_SIZE > end) - loop_end = end - ii; - for (const auto j : c10::irange(loop_end)) { - int64_t i = ii + j; - const scalar_t* input_data = input_data_base + i * dim_size; - max_input_arr[j] = vec::reduce_all( - [](Vec& x, Vec& y) { return vec::maximum(x, y); }, - input_data, - dim_size); - } - for (const auto j : c10::irange(loop_end)) { - int64_t i = ii + j; - const scalar_t* input_data = input_data_base + i * dim_size; - scalar_t max_input = max_input_arr[j]; - tmp_sum_scalar[j] = vec::map_reduce_all( - [max_input](Vec x) { return (x - Vec(max_input)).exp(); }, - [](Vec x, Vec y) { return x + y; }, - input_data, - dim_size); - } - // See [Note AVX-SSE transitions] for why this should call the - // vectorized version (aside from perf improvements). - vec::map( - [](Vec x) { return x.log(); }, - tmp_sum_scalar.get(), - tmp_sum_scalar.get(), - loop_end); - for (const auto j : c10::irange(loop_end)) { - int64_t i = ii + j; - const scalar_t* input_data = input_data_base + i * dim_size; - scalar_t* output_data = output_data_base + i * dim_size; - scalar_t tmp_sum = tmp_sum_scalar[j]; - scalar_t max_input = max_input_arr[j]; - - // It's necessary to keep the order of the operations below. - // In some cases that input is large digits and the difference - // is small, if we compute `max_input` plus `tmp_sum` before, - // there would be a numerical problem. See an example in - // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379 - vec::map( - [tmp_sum, max_input](Vec x) { - return x - Vec(max_input) - Vec(tmp_sum); - }, - output_data, - input_data, - dim_size); - } - } + serial_vec_log_softmax_lastdim_range( + input_data_base, + output_data_base, + dim_size, + chunk_size, + begin, + end); }); } @@ -891,100 +840,23 @@ _vec_logsoftmax( int64_t outer_size, int64_t inner_size, int64_t dim_size) { - using Vec = vec::Vectorized; - int64_t BLOCK_SIZE = 128 * 1024; - int64_t MAX_CHUNK_SIZE = std::max(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size()); - MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size(); - int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, inner_size); - int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks( + inner_size, dim_size); + // Work around "capturing a structured binding is not yet supported in OpenMP". + const auto CHUNK_SIZE = CHUNK_SIZE_binding; + const auto num_chunks = num_chunks_binding; // See Note: grain_size value of 0 at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) { - // thread local temp buffer which holds vertical reduction result: max and sum. - auto buffer = std::make_unique(CHUNK_SIZE * 2); - scalar_t* input_max_data = buffer.get(); - scalar_t* tmp_sum_data = buffer.get() + CHUNK_SIZE; - - for (int64_t i = begin; i < end; i++) { - int64_t outer_idx = i / num_chunks; - int64_t k = i % num_chunks; - int64_t inner_idx_begin = k * CHUNK_SIZE; - int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); - - // init - Vec zero_vec = Vec(scalar_t(0)); - Vec min_vec = Vec(-std::numeric_limits::infinity()); - int64_t d0 = 0; - for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { - min_vec.store(input_max_data + d0); - zero_vec.store(tmp_sum_data + d0); - } - for (; d0 < size; d0++) { - input_max_data[d0] = -std::numeric_limits::infinity(); - tmp_sum_data[d0] = scalar_t(0); - } - - // compute max - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size - + dim_idx * inner_size + inner_idx_begin; - - int64_t d1 = 0; - for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { - Vec data_vec = Vec::loadu(input_ptr + d1); - Vec max_vec = Vec::loadu(input_max_data + d1); - max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec); - max_vec.store(input_max_data + d1); - } - for (; d1 < size; d1++) { - scalar_t data_val = input_ptr[d1]; - scalar_t max_val = input_max_data[d1]; - input_max_data[d1] = data_val > max_val ? data_val : max_val; - } - } - - // compute sum of (x - max).exp() - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size - + dim_idx * inner_size + inner_idx_begin; - - int64_t d2 = 0; - for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { - Vec data_vec = Vec::loadu(input_ptr + d2); - Vec sum_vec = Vec::loadu(tmp_sum_data + d2); - Vec max_vec = Vec::loadu(input_max_data + d2); - sum_vec += (data_vec - max_vec).exp(); - sum_vec.store(tmp_sum_data + d2); - } - for (; d2 < size; d2++) { - scalar_t data_val = input_ptr[d2]; - scalar_t max_val = input_max_data[d2]; - tmp_sum_data[d2] += std::exp(data_val - max_val); - } - } - - // apply log - vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); - - // compute x - max - sum - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin; - const scalar_t* input_ptr = input_data_base + offset; - scalar_t* output_ptr = output_data_base + offset; - - int64_t d3 = 0; - for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { - Vec data_vec = Vec::loadu(input_ptr + d3); - Vec max_vec = Vec::loadu(input_max_data + d3); - Vec sum_vec = Vec::loadu(tmp_sum_data + d3); - Vec out_vec = data_vec - max_vec - sum_vec; - out_vec.store(output_ptr + d3); - } - for (; d3 < size; d3++) { - output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]; - } - } - } + serial_vec_logsoftmax_range( + input_data_base, + output_data_base, + inner_size, + CHUNK_SIZE, + num_chunks, + dim_size, + begin, + end); }); } @@ -996,125 +868,23 @@ _vec_logsoftmax( int64_t outer_size, int64_t inner_size, int64_t dim_size) { - using Vec = vec::Vectorized; - using fVec = vec::Vectorized; - int64_t BLOCK_SIZE = 128 * 1024; - int64_t MAX_CHUNK_SIZE = std::max(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size()); - MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size(); - int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, inner_size); - int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks( + inner_size, dim_size); + // Work around "capturing a structured binding is not yet supported in OpenMP". + const auto CHUNK_SIZE = CHUNK_SIZE_binding; + const auto num_chunks = num_chunks_binding; // See Note: grain_size value of 0 at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) { - auto buffer = std::make_unique(CHUNK_SIZE * 2); - float* input_max_data = buffer.get(); - float* tmp_sum_data = buffer.get() + CHUNK_SIZE; - - // thread local buffer that holds input data in float32 to save next 2 dtype conversion - auto input_buffer = std::make_unique(dim_size * CHUNK_SIZE); - float* input_buffer_data = input_buffer.get(); - - // init - for (int64_t i = begin; i < end; i++) { - int64_t outer_idx = i / num_chunks; - int64_t k = i % num_chunks; - int64_t inner_idx_begin = k * CHUNK_SIZE; - int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin); - - fVec zero_fvec = fVec(float(0)); - fVec min_fvec = fVec(-std::numeric_limits::infinity()); - int64_t d0 = 0; - for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { - min_fvec.store(input_max_data + d0); - min_fvec.store(input_max_data + d0 + fVec::size()); - zero_fvec.store(tmp_sum_data + d0); - zero_fvec.store(tmp_sum_data + d0 + fVec::size()); - } - for (; d0 < size; d0++) { - input_max_data[d0] = -std::numeric_limits::infinity(); - tmp_sum_data[d0] = float(0); - } - - // compute max - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size - + dim_idx * inner_size + inner_idx_begin; - float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE; - - int64_t d1 = 0; - for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { - Vec data_vec = Vec::loadu(input_ptr + d1); - auto [data_fvec0, data_fvec1] = vec::convert_to_float(data_vec); - fVec max_fvec0 = fVec::loadu(input_max_data + d1); - fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size()); - max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0); - max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1); - max_fvec0.store(input_max_data + d1); - max_fvec0.store(input_max_data + d1 + fVec::size()); - - // cache the 'converted' float input - data_fvec0.store(input_buffer_ptr + d1); - data_fvec1.store(input_buffer_ptr + d1 + fVec::size()); - } - for (; d1 < size; d1++) { - float data_val = float(input_ptr[d1]); - float max_val = input_max_data[d1]; - input_max_data[d1] = data_val > max_val ? data_val : max_val; - input_buffer_ptr[d1] = data_val; - } - } - - // compute sum of (x - max).exp() - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE; - - int64_t d2 = 0; - for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { - fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2); - fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size()); - fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); - fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); - fVec max_fvec0 = fVec::loadu(input_max_data + d2); - fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size()); - sum_fvec0 += (data_fvec0 - max_fvec0).exp(); - sum_fvec1 += (data_fvec1 - max_fvec1).exp(); - sum_fvec0.store(tmp_sum_data + d2); - sum_fvec1.store(tmp_sum_data + d2 + fVec::size()); - } - for (; d2 < size; d2++) { - float data_val = input_buffer_ptr[d2]; - float max_val = input_max_data[d2]; - tmp_sum_data[d2] += std::exp(data_val - max_val); - } - } - - // apply log - vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); - - // compute x - max - sum - for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { - float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE; - scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size - + dim_idx * inner_size + inner_idx_begin; - - int64_t d3 = 0; - for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { - fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3); - fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size()); - fVec max_fvec0 = fVec::loadu(input_max_data + d3); - fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size()); - fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3); - fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size()); - fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0; - fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1; - Vec out_vec = vec::convert_from_float(out_fvec0, out_fvec1); - out_vec.store(output_ptr + d3); - } - for (; d3 < size; d3++) { - output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]); - } - } - } + serial_vec_logsoftmax_range( + input_data_base, + output_data_base, + inner_size, + CHUNK_SIZE, + num_chunks, + dim_size, + begin, + end); }); } diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h index 6f403d60ea7c09..8aba425e896377 100644 --- a/aten/src/ATen/native/cpu/moments_utils.h +++ b/aten/src/ATen/native/cpu/moments_utils.h @@ -8,7 +8,6 @@ #include #include #include -#include #include namespace at::native { @@ -118,9 +117,11 @@ std::pair, opmath_t> RowwiseMomentsImpl(const T* X, int64_t N, in using Vec = vec::Vectorized; const Vec kZeroVec(math_t(0)); - c10::SmallVector m0_stk(depth, 0); - c10::SmallVector m1_stk(depth, kZeroVec); - c10::SmallVector m2_stk(depth, kZeroVec); + std::array m0_stk = {{0}}; + std::array m1_stk; + m1_stk.fill(kZeroVec); + std::array m2_stk; + m2_stk.fill(kZeroVec); for (const auto i : c10::irange(m)) { const T* X_ptr = X + i * kChunkSize * kVecSize; diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index 3d4c39afe3c75c..e1c7e5c6074773 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -165,6 +165,12 @@ inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64 TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); } + +template <> +inline void transpose(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} #endif template diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu index 728431ada66aba..d9a0b0059917f5 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu @@ -53,7 +53,7 @@ __global__ void adaptiveaveragepool( const scalar_t *input, scalar_t *output, int isizeT, int isizeH, int isizeW, int osizeT, int osizeH, int osizeW, - int64_t istrideD, + int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW, int64_t offsetZ) { // iterates on output pixels @@ -70,15 +70,17 @@ __global__ void adaptiveaveragepool( // select output plane int64_t o_plane = blockIdx.x + offsetZ; ot = o_plane % osizeT; // output frame/time - int d = o_plane / osizeT; // slice/feature + int d = o_plane / osizeT; // flattened (batch, channel) index + + // Decompose d into batch and channel indices + int batch_idx = d / sizeD; + int channel_idx = d % sizeD; // input frame/time range is fixed. int istartT = start_index(ot, osizeT, isizeT); int iendT = end_index(ot, osizeT, isizeT); int kT = iendT - istartT; - // input offset by slice/feature and earliest relevant frame/time - const scalar_t *input_dt = input + d*istrideD + istartT*istrideT; // output offset by slice/feature and frame/time scalar_t *output_dt = output + o_plane*osizeH*osizeW; @@ -93,8 +95,6 @@ __global__ void adaptiveaveragepool( int iendW = end_index(ow, osizeW, isizeW); int kW = iendW - istartW; - // Compute the average pooling from corresponding input pixels - const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW; scalar_t *ptr_output = output_dt + oh*osizeW + ow; accscalar_t sum = static_cast(0); @@ -102,11 +102,13 @@ __global__ void adaptiveaveragepool( for (it = 0; it < kT; ++it) { for (ih = 0; ih < kH; ++ih) { for (iw = 0; iw < kW; ++iw) { - scalar_t val = ptr_input[ih*istrideH + iw*istrideW]; + int64_t input_offset = batch_idx * istrideB + channel_idx * istrideD + + (istartT + it) * istrideT + + (istartH + ih) * istrideH + (istartW + iw) * istrideW; + scalar_t val = input[input_offset]; sum += static_cast(val); } } - ptr_input += istrideT; // next input frame } // Update output const accscalar_t divide_factor = static_cast(kT * kH * kW); @@ -121,7 +123,7 @@ void adaptiveaveragepool_loop( int64_t totalZ, int isizeT, int isizeH, int isizeW, int osizeT, int osizeH, int osizeW, - int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { + int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { int64_t offsetZ = 0; dim3 threads(32, 8); // each H*W plane is processed by blocksH thread blocks @@ -133,7 +135,7 @@ void adaptiveaveragepool_loop( input_data, output_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, - istrideD, + sizeD, istrideB, istrideD, istrideT, istrideH, istrideW, offsetZ); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -364,7 +366,7 @@ void adaptive_avg_pool3d_out_cuda_template( int64_t osizeW = output_size[2]; int64_t sizeD, isizeT, isizeH, isizeW; - int64_t istrideD, istrideT, istrideH, istrideW; + int64_t istrideB, istrideD, istrideT, istrideH, istrideW; int64_t totalZ; const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous(); @@ -375,6 +377,7 @@ void adaptive_avg_pool3d_out_cuda_template( isizeH = input.size(2); isizeW = input.size(3); + istrideB = 0; istrideD = input.stride(0); istrideT = input.stride(1); istrideH = input.stride(2); @@ -390,6 +393,7 @@ void adaptive_avg_pool3d_out_cuda_template( isizeH = input.size(3); isizeW = input.size(4); + istrideB = input.stride(0); istrideD = input.stride(1); istrideT = input.stride(2); istrideH = input.stride(3); @@ -415,7 +419,7 @@ void adaptive_avg_pool3d_out_cuda_template( totalZ, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, - istrideD, istrideT, istrideH, istrideW); + sizeD, istrideB, istrideD, istrideT, istrideH, istrideW); }); } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index b3b56eae764eb7..e15777082864f6 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -1043,7 +1044,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { return _int_mm_out_cuda(self, mat2, result); } -static bool _scaled_mm_allowed_device(bool sm90_only=false) { +static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { #ifdef USE_ROCM static const std::vector archs = { "gfx942", @@ -1057,8 +1058,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false) { return at::detail::getCUDAHooks().isGPUArch(archs); #else auto dprops = at::cuda::getCurrentDeviceProperties(); - if (sm90_only) { - return dprops->major == 9; + + if (sm90_only || sm100_only) { + return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10); } else { return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); } @@ -1168,7 +1170,7 @@ ScalingType get_scaling_type( if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { #if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ - (defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)) + (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) TORCH_CHECK( scale_a.is_contiguous() && scale_b.is_contiguous(), "Both scale_a and scale_b must be contiguous for RowWise scaling."); @@ -1481,29 +1483,49 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } namespace { - c10::SmallVector compute_grouped_gemm_output_size(const Tensor& mat_a, + at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a, const Tensor& mat_b, - const std::optional& offs + const std::optional& offs, + std::optional out_dtype ) { + c10::SmallVector out_size; const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; if (a_is_2d) { if (b_is_2d) { - return {offs->size(0), mat_a.size(0), mat_b.size(1)}; + out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)}; } else { TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match"); - return {mat_a.size(0), mat_b.size(-1)}; + out_size = {mat_a.size(0), mat_b.size(-1)}; } } else { if (b_is_2d) { // this case is not actually encountered for MoE gemms TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match"); - return {mat_a.size(1), mat_b.size(1)}; + out_size = {mat_a.size(1), mat_b.size(1)}; } else { // regular bmm TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match"); - return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; + out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; } } + + const auto out_dtype_ = out_dtype.value_or(kBFloat16); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); + + // For TMA transfers, strides of output tensor have to be either + // 1, or aligned to 16 bytes. + const auto last_dim = out_size.size() - 1; + const auto alignment = 16 / c10::elementSize(out_dtype_); + const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment; + std::vector out_stride; + if (a_is_2d != b_is_2d) { + out_stride = {size_padded, 1}; + } else { + out_stride = {out_size[1] * size_padded, size_padded, 1}; + } + auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_)); + + return out; } bool check_valid_strides_and_return_transposed(const Tensor& mat) { @@ -1519,7 +1541,7 @@ namespace { TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); return false; } else { - TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); + TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); } } @@ -1532,7 +1554,7 @@ namespace { "D, arg ", arg_idx); TORCH_CHECK( - scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx); + scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx); TORCH_CHECK( scale.size(0) == mat.size(dim) * scale_multiplier, "scale must have the same length as mat for arg ", @@ -1545,8 +1567,8 @@ namespace { "D for arg ", arg_idx); TORCH_CHECK( - scale.stride(1), - "scale_a must be contiguous in the last dimension for arg ", + scale.stride(1) == 1, + "scale must be contiguous in the last dimension for arg ", arg_idx); TORCH_CHECK( scale.size(0) == mat.size(0), @@ -1610,6 +1632,7 @@ bool use_fast_accum) { TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); + TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet"); TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); if (offs.has_value()) { @@ -1626,11 +1649,7 @@ bool use_fast_accum) { check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); check_scale(mat_b, scale_b, 1, 1, scale_multiplier); - const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); - TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); - const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); - Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); - + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); at::cuda::detail::f8f8bf16_grouped_mm( mat_a, @@ -1657,8 +1676,8 @@ const std::optional& offs, const std::optional& bias, std::optional out_dtype) { #ifndef USE_ROCM - bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); - TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); + TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0, 10.0"); TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type()); TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type()); @@ -1666,6 +1685,7 @@ std::optional out_dtype) { TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; + // check that the strides are valid, the fn will throw an error if not check_valid_strides_and_return_transposed(mat_a); check_valid_strides_and_return_transposed(mat_b); @@ -1675,12 +1695,10 @@ std::optional out_dtype) { TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); } - const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); - TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm"); TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); - const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); - Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); + at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); return out; #else diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 5d19b95b32f9b6..5120d3a58ece32 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -317,7 +317,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice auto count_data = count.mutable_data_ptr(); cuda::cub::inclusive_sum_by_key( sorted_data, - NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator(1), + ATEN_CUB_CONSTANT_ITERATOR(index_t)(1), count_data, num_indices ); @@ -329,7 +329,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice thrust::make_reverse_iterator(sorted_data + num_indices), thrust::make_reverse_iterator(static_cast(count_data) + num_indices), thrust::make_reverse_iterator(count_data + num_indices), - NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(), + ATEN_CUB_MAXIMUM(), num_indices ); }); @@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && - num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads, + num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(), "BlockReduceSum requires all warps be active"); const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); dim3 grid = unique_indices.numel(); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 5a6961187a2a04..fb92c7488a1524 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -210,7 +210,7 @@ Tensor embedding_bag_backward_cuda_sum_avg( auto count_data = count.mutable_data_ptr(); cuda::cub::inclusive_sum_by_key( sorted_data, - NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator(1), + ATEN_CUB_CONSTANT_ITERATOR(index_t)(1), count_data, num_indices ); @@ -222,7 +222,7 @@ Tensor embedding_bag_backward_cuda_sum_avg( thrust::make_reverse_iterator(sorted_data + num_indices), thrust::make_reverse_iterator(count_data + num_indices), thrust::make_reverse_iterator(count_data + num_indices), - NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(), + ATEN_CUB_MAXIMUM(), num_indices ); }); diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index 11d44b9d4cd0f5..7ee02b02b41f1e 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -356,7 +356,7 @@ struct CopyFunctor { static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1); template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -441,7 +441,6 @@ void foreach_tensor_copy_list_kernel_cuda_( self[0].scalar_type(), "foreach_tensor_copy", [&]() { - using opmath_t = at::opmath_type; AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] { if constexpr (std::is_same_v) { multi_tensor_apply<2>( @@ -451,7 +450,7 @@ void foreach_tensor_copy_list_kernel_cuda_( /* depth */ 2, /* r_args_depth */ 1, /* res_arg_index */ 1>(), - Copy()); + Copy()); } else { // Ref: // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301 diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index 645b095c5a6e50..c121d971cd7bef 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -208,7 +208,7 @@ struct BinaryOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op, opmath_t scalar) { @@ -232,7 +232,7 @@ struct BinaryOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -256,7 +256,7 @@ struct BinaryOpListAlphaFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op, opmath_t alpha) { @@ -308,7 +308,7 @@ struct BinaryOpScalarTensorFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op, T* scalar, @@ -364,7 +364,7 @@ struct BinaryOpScalarTensorFunctor { template struct ZeroFunctor { __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata<1>& tl) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -406,7 +406,7 @@ struct UnaryOpFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -458,7 +458,7 @@ struct PointwiseOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op, opmath_t scalar) { @@ -482,7 +482,7 @@ struct PointwiseOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -506,7 +506,7 @@ struct PointwiseOpListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -557,7 +557,7 @@ struct TernaryOpListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op) { static_assert(depth == 3 || depth == 4, ""); @@ -611,7 +611,7 @@ struct TernaryOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, Op op, opmath_t alpha) { @@ -668,7 +668,7 @@ struct TernaryOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListScalarListMetadata& tl, Op op) { static_assert(depth == 2 || depth == 3, ""); diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 61793fd5f9e08c..2da8e634981f9e 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -53,7 +53,7 @@ template < int res_arg_index = 0> struct LpMaxFunctor { __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, T* output_per_tensor_ptr, const int max_chunks_per_tensor) { @@ -243,7 +243,7 @@ template < struct LpNormFunctor { using out_opmath_t = typename at::opmath_type; __device__ __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, TensorListMetadata& tl, out_opmath_t* output_per_tensor_ptr, const int max_chunks_per_tensor) { diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index 7d791613c8e1c5..d0cf7e06c86887 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -62,7 +62,7 @@ struct FusedSgdMathFunctor { depth == 2 || depth == 3, "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0"); C10_DEVICE __forceinline__ void operator()( - const int chunk_size, + const int64_t chunk_size, TensorListMetadata& tl, const double weight_decay, const double momentum, diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index d43875e3c8a6d5..68acf79f6894da 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -8,9 +8,10 @@ #include -// Two warninngs in Cutlass included header files +// Three warninngs in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") // Determine if the architecture supports rowwise scaled mm // Currently failing on windows with: @@ -43,11 +44,14 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") #include #include +#include + namespace { using Strides = at::cuda::detail::Strides; // std::array; -template +template struct Schedule { + // SM90 using CooperativeSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -55,10 +59,19 @@ struct Schedule { cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using PongEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using KernelSchedule = - cute::conditional_t; - using EpilogueSchedule = cute:: - conditional_t; + // SM100 + using MMA1SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using MMA1SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MMA2SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using MMA2SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + + using KernelSchedule = cute::conditional_t, + cute::conditional_t, + cute::conditional_t>; + using EpilogueSchedule = cute::conditional_t, + cute::conditional_t, + cute::conditional_t>; + }; int ceildiv(int a, int b) { @@ -70,13 +83,14 @@ int round_up_to_nearest_multiple(int a, int b) { } template < + typename ArchTag, bool a_row_major, bool b_row_major, - bool Pong, + bool PONGOr2SM, typename TB_M, typename TB_N, typename TB_K> -void bf16bf16_grouped_gemm_impl_sm90( +void bf16bf16_grouped_gemm_impl_sm90_sm100( at::Tensor mat_a, // bf16 at::Tensor mat_b, // bf16 std::optional offs, @@ -99,14 +113,13 @@ void bf16bf16_grouped_gemm_impl_sm90( constexpr int AlignmentB = 16 / sizeof(DtypeB); using LayoutOutput = cutlass::layout::RowMajor; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); - using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = - typename Schedule::KernelSchedule; + typename Schedule::KernelSchedule; using EpilogueSchedule = - typename Schedule::EpilogueSchedule; + typename Schedule::EpilogueSchedule; using ProblemShape = cutlass::gemm::GroupProblemShape< cute::Shape>; // per // group @@ -146,8 +159,16 @@ void bf16bf16_grouped_gemm_impl_sm90( cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel:: - GemmUniversal; + + using GemmKernelBase = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue>; + + using GemmKernel = std::conditional_t< + std::is_same_v, + at::cuda::detail::enable_3x_kernel_for_sm10, + at::cuda::detail::enable_3x_kernel_for_sm9x>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using StrideA = typename Gemm::GemmKernel::InternalStrideA; @@ -319,22 +340,49 @@ void dispatch_bf16_grouped_kernel_on_tile_size( // ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || // (K >= 2048 && N >= 2048)); bool small = (M <= 128 || N <= 128); - if (small) { - bf16bf16_grouped_gemm_impl_sm90< + cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); + const bool sm10x = properties != nullptr && properties->major == 10; + + if (sm10x) { + if (small){ + bf16bf16_grouped_gemm_impl_sm90_sm100< + cutlass::arch::Sm100, + a_row_major, + b_row_major, + /*PONGOr2SM*/ false, + cute::_128, + cute::_256, + cute::_64>(mat_a, mat_b, offs, bias, out); // Tile shape taken from CUTLASS examples, 64 = 128/sizeof(bfloat16) + } else { + bf16bf16_grouped_gemm_impl_sm90_sm100< + cutlass::arch::Sm100, + a_row_major, + b_row_major, + /*PONGOr2SM*/ true, + cute::_256, + cute::_256, + cute::_64>(mat_a, mat_b, offs, bias, out); // Same as above ^ + } + } else { + if(small) { + bf16bf16_grouped_gemm_impl_sm90_sm100< + cutlass::arch::Sm90, a_row_major, b_row_major, - /*Pong*/ true, + /*PONGOr2SM*/ true, cute::_64, cute::_128, cute::_128>(mat_a, mat_b, offs, bias, out); - } else { - bf16bf16_grouped_gemm_impl_sm90< + } else { + bf16bf16_grouped_gemm_impl_sm90_sm100< + cutlass::arch::Sm90, a_row_major, b_row_major, - /*Pong*/ false, + /*PONGOr2SM*/ false, cute::_128, cute::_256, cute::_64>(mat_a, mat_b, offs, bias, out); + } } } diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh index a0474b7ad1799a..a4d8a97b6fd890 100644 --- a/aten/src/ATen/native/cuda/GroupMMCommon.cuh +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -47,10 +47,44 @@ __global__ void prepare_grouped_gemm_data( if (offs != nullptr) { int32_t start = tid == 0 ? 0 : offs[tid - 1]; delta = offs[tid] - start; - int align = 16 / sizeof(DtypeA); - CUDA_KERNEL_ASSERT( - delta >=0 && delta % align == 0 && - "expected dynamic dimension byte size to be non-negative multiple of 16 \n"); + if (K < 0) { + if (!a_row_major && b_row_major) { + CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); + } else { + // CUTLASS cannot handle delta=0 here. + CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); + } + } + + // TMA transfers require global memory tensor addresses to be + // aligned to 16 bytes. + if (tid < blockDim.x - 1) { + // Check this requirement for input tensors, in case group + // addresses are increased along the dynamic dimension. + if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension + (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension + (N < 0 && b_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + + // Check the same requirement for output tensor (that is always + // contiguous, and in row-major layout). + if (N < 0) { + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + } } int64_t lda, ldb, ldoutput; if (M < 0) { @@ -81,7 +115,6 @@ __global__ void prepare_grouped_gemm_data( } else if (K < 0) { // A, B is 2d, output is 3d K = delta; - CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0"); lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; ldoutput = tensor_StrideOutput[1]; diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 89308177bfea89..1d603132e6893f 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -1946,7 +1946,7 @@ const auto chebyshev_polynomial_t_string = jiterator_stringify( T q = x; T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -1996,7 +1996,7 @@ const auto chebyshev_polynomial_u_string = jiterator_stringify( T q = x + x; T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -2054,7 +2054,7 @@ const auto chebyshev_polynomial_v_string = jiterator_stringify( T q = x + x - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -2116,7 +2116,7 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify( T q = x + x + T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -2252,7 +2252,7 @@ const auto laguerre_polynomial_l_string = jiterator_stringify( T q = T(1.0) - x; T r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !isnan(q); k++) { r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); p = q; q = r; @@ -2294,7 +2294,7 @@ const auto legendre_polynomial_p_string = jiterator_stringify( T q = x; T r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !isnan(q); k++) { r = ((k + k + 1) * x * q - k * p) / (k + 1); p = q; q = r; @@ -2851,7 +2851,7 @@ const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( T q = x + x - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -2905,7 +2905,7 @@ const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -2963,7 +2963,7 @@ const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3021,7 +3021,7 @@ const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); T r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !isnan(q); k++) { r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 65770e40a8b2be..8132e7df57b51d 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -86,7 +86,7 @@ void renormRows(Tensor& t) { TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; const int64_t maxThreads = std::min( - props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); + props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads()); int warp_size = at::cuda::warp_size(); dim3 grid(rows < numSM * 4 ? rows : numSM * 4); diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index d61b99fb5a3761..2d0e32d4e8c05b 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -94,7 +94,7 @@ __global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg, // Specialize BlockScan type for our thread block using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; - using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>; + using TransformInputIteratorT = ATEN_CUB_TRANSFORM_ITERATOR(int, NonZeroOp, const T*); using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange; // Shared memory @@ -184,7 +184,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks); for (int64_t idx = 0; idx < num_chunks; idx++) { int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size); - cub::TransformInputIterator, const scalar_t*> itr( + ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp, const scalar_t*) itr( self_.const_data_ptr() + idx * chunk_size, NonZeroOp()); AT_CUDA_CHECK(cub::DeviceReduce::Sum( @@ -243,8 +243,8 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { for (int64_t idx = 0; idx < num_chunks; idx++) { int remaining = std::min(chunk_size, self.numel() - idx * chunk_size); - cub::CountingInputIterator counting_itr(idx * chunk_size); - cub::TransformInputIterator, const scalar_t*> + ATEN_CUB_COUNTING_ITERATOR(int64_t) counting_itr(idx * chunk_size); + ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp, const scalar_t*) itr(self_.const_data_ptr() + idx * chunk_size, NonZeroOp()); temp_storage_bytes = 0; diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu index 438f833616056a..2698207c45ef54 100644 --- a/aten/src/ATen/native/cuda/PowKernel.cu +++ b/aten/src/ATen/native/cuda/PowKernel.cu @@ -185,6 +185,12 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar return; } AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() { + if (exp_scalar.equal(2.0)) { + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { + return base * base; + }); + return; + } const auto exp = exp_scalar.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { return pow_(base, exp); diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index e471ce9f9d7737..9d7ead7e49892e 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -1,10 +1,11 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include #include -#include +#include +#include #include +#include #include +#include #include #include @@ -181,12 +182,8 @@ Tensor& range_cuda_out(const Scalar& start, const Scalar& end, const Scalar& ste auto xend = end.to(); auto xstep = step.to(); - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); + int64_t size = static_cast(((xend - xstart) / xstep) + 1); if (result.numel() != size) { @@ -217,12 +214,7 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st auto xend = end.to(); auto xstep = step.to(); - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); // we use double precision for (start - end) / step // to compute size_d for consistency across devices. diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index ad2588e181ed97..15a572804af5fc 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1118,13 +1118,19 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ int max_threads_per_mp = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; #ifdef USE_ROCM - // Control the number of threadblocks by adjusting the maximum number of - // threads per multi-processor. These numbers better reflect the maximum - // theoretical achievable threads per MP for the reduction operation. - if (iter.ndim() == 1 || iter.ndim() == 3) - max_threads_per_mp = 512; - if (iter.ndim() == 2) - max_threads_per_mp = 256; + // If the grid consists of a single threadblock, do not change the max threads per + // MP value. This will increase the parallelism across the y dimension of the grid. + bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1; + + if (!uses_a_single_block) { + // Control the number of threadblocks by adjusting the maximum number of + // threads per multi-processor. These numbers better reflect the maximum + // theoretical achievable threads per MP for the reduction operation. + if (iter.ndim() == 1 || iter.ndim() == 3) + max_threads_per_mp = 512; + else if (iter.ndim() == 2) + max_threads_per_mp = 256; + } #endif const int blocks_per_sm = max_threads_per_mp / config.num_threads; const int target_grid_size = num_mp * blocks_per_sm; diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 19e9f4881c3dfb..b39f07de9d516a 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -302,8 +302,9 @@ void f8f8bf16_rowwise_impl( } -// Cutlass rowwise kernel for SM100 +// Cutlass rowwise kernel for SM100/SM120 template < + typename ArchTag, typename TileShape, typename ClusterShape, typename Transposed, @@ -311,7 +312,7 @@ template < typename DtypeA, typename DtypeB, typename DtypeBias> -void f8f8bf16_rowwise_impl_sm100( +void f8f8bf16_rowwise_impl_sm100_sm120( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 at::Tensor x_scale, @@ -344,8 +345,6 @@ void f8f8bf16_rowwise_impl_sm100( cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); - // Tag indicating the minimum SM that supports the intended feature - using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; // Implement rowwise scaling epilogue. @@ -380,7 +379,7 @@ void f8f8bf16_rowwise_impl_sm100( using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, OperatorClass, + ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, @@ -389,7 +388,13 @@ void f8f8bf16_rowwise_impl_sm100( EpilogueScheduleType, EpilogueEVT>::CollectiveOp; - using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto; + // as of CUTLASS 3.9.2, on sm120, KernelScheduleAuto resolves to + // KernelTmaWarpSpecializedCooperativeSm120<2>>, + // which does not support TileShape.M < 128 + using MainloopScheduleType = std::conditional_t< + std::is_same_v && cute::size<0>(TileShape{}) < 128, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + cutlass::gemm::collective::KernelScheduleAuto>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, @@ -698,7 +703,7 @@ void f8f8bf16_rowwise_impl_sm89( C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template void dispatch_fp8_rowwise_kernel_on_tile_size( at::Tensor XQ, at::Tensor WQ, @@ -715,9 +720,6 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( smTarget -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } - cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); - const bool sm10x = properties != nullptr && properties->major == 10; - // We prefer to use smaller tiles (less wasted compute in case of padding), // but if this causes us to have more CUDA blocks than there are SMs on the // GPU then we'll hit wave quantization, hence we'll switch to larger tiles. @@ -726,33 +728,38 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( smTarget / cute::size(ClusterShape{}); if (use_smaller_tiles) { - if (sm10x) { - return f8f8bf16_rowwise_impl_sm100< + if constexpr (std::is_same_v) { + return f8f8bf16_rowwise_impl< + /*TileShape=*/cute::Shape, + ClusterShape, + Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); + } else { + return f8f8bf16_rowwise_impl_sm100_sm120< + ArchTag, /*TileShape=*/cute::Shape, ClusterShape, Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } - return f8f8bf16_rowwise_impl< - /*TileShape=*/cute::Shape, - ClusterShape, - Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } else { - if (sm10x) { - return f8f8bf16_rowwise_impl_sm100< + if constexpr (std::is_same_v) { + return f8f8bf16_rowwise_impl< /*TileShape=*/cute::Shape, ClusterShape, Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); - } - return f8f8bf16_rowwise_impl< + } else { + return f8f8bf16_rowwise_impl_sm100_sm120< + ArchTag, /*TileShape=*/cute::Shape, ClusterShape, Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); + } } } template < typename ClusterShape, typename Transposed, + typename ArchTag, typename FastAccum, typename DtypeA, typename DtypeB, @@ -768,6 +775,7 @@ void handle_transposition( if constexpr (!Transposed::value) { dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, + ArchTag, Transposed, FastAccum, DtypeA, @@ -776,6 +784,7 @@ void handle_transposition( } else { dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, + ArchTag, Transposed, FastAccum, DtypeB, @@ -938,13 +947,27 @@ void dispatch_fp8_rowwise_kernel_on_sm( const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; const bool sm9x = properties != nullptr && properties->major == 9; const bool sm10x = properties != nullptr && properties->major == 10; - if (!(sm89 || sm9x || sm10x)) { + const bool sm12x = properties != nullptr && properties->major == 12; + if (!(sm89 || sm9x || sm10x || sm12x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } - if (sm9x || sm10x) { - dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(XQ, WQ, x_scale, w_scale, bias, out); + if (sm9x) { + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + /*ArchTag=*/cutlass::arch::Sm90, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else if (sm10x) { + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + /*ArchTag=*/cutlass::arch::Sm100, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else if (sm12x) { + // sm12x doesn't have multicast feature + handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + /*ArchTag=*/cutlass::arch::Sm120, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); } else { dispatch_fp8_rowwise_kernel_sm89(XQ, WQ, x_scale, w_scale, bias, out); } diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu index 04bec043b725b8..3acb359342f135 100644 --- a/aten/src/ATen/native/cuda/SegmentReduce.cu +++ b/aten/src/ATen/native/cuda/SegmentReduce.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include @@ -17,6 +18,10 @@ #include #endif +// SegmentReduce compilation with CUDA-12.9 causes NVCC crash on Windows +// See https://github.com/pytorch/pytorch/issues/156181 +#if !defined(_WIN32) || CUDART_VERSION < 12090 + namespace at::native { namespace { @@ -600,3 +605,5 @@ REGISTER_DISPATCH( &_segment_reduce_offsets_backward_cuda_kernel); } // namespace at::native + +#endif diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index ffaa9e13f141c1..f27d76256cdb3f 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -183,15 +183,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) { uint64_t block_size = 1; uint64_t max_block_size = std::min(dim_size, static_cast(max_threads)); - // We need a block size that is a multiple of C10_WARP_SIZE in order + // We need a block size that is a multiple of at::cuda::warp_size() in order // to perform block size reductions using warp shuffle instructions. - // Since max_threads is also a multiple of C10_WARPS_SIZE we do not + // Since max_threads is also a multiple of at::cuda::warp_size() we do not // risk creating a block size larger than the limit. - if (max_block_size % C10_WARP_SIZE == 0) { + int warp_size = at::cuda::warp_size(); + if (max_block_size % warp_size == 0) { block_size = max_block_size; } else { - block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE; + block_size = (max_block_size / warp_size + 1) * warp_size; } return dim3(block_size); @@ -611,7 +612,7 @@ WriteBpropResultsVectorized( if (threadIdx.x >= shift) { gradInput[offset] = epilogue(gradOutput[offset], output[offset]); } - size -= blockDim.x; + size -= blockDim.x > size ? size : blockDim.x; gradInput += blockDim.x; output += blockDim.x; gradOutput += blockDim.x; @@ -1107,7 +1108,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t constexpr int ILP = sizeof(float4) / sizeof(scalar_t); if constexpr (use_fast_softmax) { dim3 block(512); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); if (dim_size % ILP == 0) { cunn_SoftMaxForwardGmem <<>>(output_ptr, input_ptr, dim_size); @@ -1117,7 +1118,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } } else { dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); @@ -1198,7 +1199,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t constexpr int ILP = sizeof(float4) / sizeof(scalar_t); if constexpr (use_fast_softmax) { dim3 block(512); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); if (dim_size % ILP == 0) { cunn_SoftMaxForwardGmem <<>>(output_ptr, input_ptr, dim_size); @@ -1208,7 +1209,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } } else { dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); @@ -1274,7 +1275,7 @@ void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, T constexpr int ILP = sizeof(float4) / sizeof(output_t); dim3 block = SoftMax_getBlockSize(ILP, dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(output_t); bool can_use_smem = static_cast(dim_size) < max_elements_per_smem; diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index 4764b078c050b8..0c97ab742103ff 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -207,7 +207,7 @@ void handle_fused_mode( constexpr int num_threads = size / 2; int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 && - num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, ""); + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), ""); const auto memsize = (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); compute_mode diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu index 103b360bcb8689..584c1c49a03caa 100644 --- a/aten/src/ATen/native/cuda/TensorTopK.cu +++ b/aten/src/ATen/native/cuda/TensorTopK.cu @@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts( warp_counts[warp] = count; } __syncthreads(); +#ifdef USE_ROCM + CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE); +#else static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE, "Assuming only 1 warp is needed for final reduction"); +#endif if (warp != 0) { return; } @@ -721,8 +725,8 @@ void launch( desired, counts, num_blocks, blocks_per_slice, kthCounts); C10_CUDA_KERNEL_LAUNCH_CHECK(); // Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block - using counting_iter_t = cub::CountingInputIterator; - using slice_idx_iter_t = cub::TransformInputIterator; + using counting_iter_t = ATEN_CUB_COUNTING_ITERATOR(uint32_t, uint32_t); + using slice_idx_iter_t = ATEN_CUB_TRANSFORM_ITERATOR(uint32_t, BlockIdxToKey, counting_iter_t); slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice)); at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks); at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks); diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index 1bda65815d6d49..0a1f3408e783df 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -54,7 +54,7 @@ struct LoadBoolOp { auto wrap_input_iterator(const bool *data) { // See NOTE [Loading boolean values] LoadBoolOp op; - return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( + return ATEN_CUB_TRANSFORM_ITERATOR(bool, LoadBoolOp, const uint8_t*, int)( reinterpret_cast(data), op); } @@ -259,10 +259,10 @@ struct UniqueCub { const bool* self_data = self.const_data_ptr(); MapNumberOfTrueValues op; - NO_ROCM(at_cuda_detail)::cub::TransformInputIterator + ATEN_CUB_TRANSFORM_ITERATOR(int, MapNumberOfTrueValues, const uint8_t*, int) data_iter(reinterpret_cast(self_data), op); at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp, - NO_ROCM(at_cuda_detail)::cub::Sum{}, 0); + NO_ROCM(::cuda)::std::plus<>{}, 0); auto options = self.options(); output = at::empty({2}, self.options()); diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index 2a272d22c0c60e..1818987c6a5880 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -12,7 +12,17 @@ constexpr int kCUDABlockReduceNumThreads = 512; // of which reduces C10_WARP_SIZE elements. So, at most // C10_WARP_SIZE**2 elements can be reduced at a time. // NOTE: This is >= the max block size on current hardware anyway (1024). -constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; +// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions, +// and kCUDABlockReduceMaxThreads is a host-side variable. +#ifdef USE_ROCM +static int kCUDABlockReduceMaxThreads() { + return at::cuda::warp_size() * at::cuda::warp_size(); +} +#else +constexpr int kCUDABlockReduceMaxThreads() { + return C10_WARP_SIZE * C10_WARP_SIZE; +} +#endif // Sums `val` across all threads in a warp. // diff --git a/aten/src/ATen/native/cuda/cutlass_common.cuh b/aten/src/ATen/native/cuda/cutlass_common.cuh index 0bf4da7a7be86c..8f5143713aa993 100644 --- a/aten/src/ATen/native/cuda/cutlass_common.cuh +++ b/aten/src/ATen/native/cuda/cutlass_common.cuh @@ -25,6 +25,16 @@ struct enable_3x_kernel_for_sm9x : Kernel { } }; +template +struct enable_3x_kernel_for_sm10 : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + template struct enable_3x_kernel_for_sm10_or_later : Kernel { template diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh index f6949e69f2ca80..7a8f4a0d0e7e2b 100644 --- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh @@ -108,7 +108,7 @@ struct FusedAdamMathFunctor { "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); using opmath_t = at::opmath_type; C10_DEVICE __forceinline__ void operator()( - int chunk_size, + int64_t chunk_size, FusedOptimizerTensorListMetadata& tl, const float* lr_ptr, const double& lr, diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index ee573e2e566f65..bdb169e26b142a 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -884,8 +884,12 @@ void LaunchGammaBetaBackwardCUDAKernel( LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); - *dgamma = dgamma_blocks.sum(0); - *dbeta = dbeta_blocks.sum(0); + if (dgamma_blocks.defined()) { + *dgamma = dgamma_blocks.sum(0); + } + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); + } } else { // We are in the normal case where M is not that large. // We can change the tile shape (which is the last template parameter) in accordance with M. diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 3cf47804e9148b..71cbe361a0373d 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -1433,7 +1433,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) // This function calculates the inverse matrix in-place // result should be in column major order and contain matrices to invert // the content of result is overwritten by 'apply_cholesky_inverse' -#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) +#if defined(USE_LINALG_SOLVER) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { case at::LinalgBackend::Cusolver: diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp index 7885583a0d59fc..888ab64db61f9c 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp @@ -28,6 +28,18 @@ #include #endif +#if defined(USE_ROCM) +#include +#include +#define PYTORCH_ROCSOLVER_VERSION \ + (ROCSOLVER_VERSION_MAJOR * 10000 + ROCSOLVER_VERSION_MINOR * 100 + ROCSOLVER_VERSION_PATCH) +#if (PYTORCH_ROCSOLVER_VERSION >= 32600) +#define ROCSOLVER_SYEVD_BATCHED_ENABLED 1 +#else +#define ROCSOLVER_SYEVD_BATCHED_ENABLED 0 +#endif +#endif // defined(USE_ROCM) + namespace at::native { static cublasOperation_t to_cublas(TransposeType trans) { @@ -1204,6 +1216,115 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) { return result; } +#if defined(USE_ROCM) && ROCSOLVER_SYEVD_BATCHED_ENABLED +template +rocblas_status _rocsolver_syevd_strided_batched( + rocblas_handle handle, + const rocblas_evect evect, + const rocblas_fill uplo, + const rocblas_int n, + scalar_t* A, + const rocblas_int lda, + const rocblas_stride strideA, + scalar_t* D, + const rocblas_stride strideD, + scalar_t* E, + const rocblas_stride strideE, + rocblas_int* info, + const rocblas_int batch_count +); + +template <> +rocblas_status _rocsolver_syevd_strided_batched( + rocblas_handle handle, + const rocblas_evect evect, + const rocblas_fill uplo, + const rocblas_int n, + float* A, + const rocblas_int lda, + const rocblas_stride strideA, + float* D, + const rocblas_stride strideD, + float* E, + const rocblas_stride strideE, + rocblas_int* info, + const rocblas_int batch_count +){ + return rocsolver_ssyevd_strided_batched( + handle, evect, uplo, n, A, lda, strideA, D, strideD, E, strideE, info, batch_count + ); +} + +template <> +rocblas_status _rocsolver_syevd_strided_batched( + rocblas_handle handle, + const rocblas_evect evect, + const rocblas_fill uplo, + const rocblas_int n, + double* A, + const rocblas_int lda, + const rocblas_stride strideA, + double* D, + const rocblas_stride strideD, + double* E, + const rocblas_stride strideE, + rocblas_int* info, + const rocblas_int batch_count +){ + return rocsolver_dsyevd_strided_batched( + handle, evect, uplo, n, A, lda, strideA, D, strideD, E, strideE, info, batch_count + ); +} + +template +static void apply_syevd_batched_rocsolver(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { + + using value_t = typename c10::scalar_value_type::type; + + auto uplo = upper ? rocblas_fill::rocblas_fill_upper : rocblas_fill::rocblas_fill_lower; + auto evect = compute_eigenvectors ? rocblas_evect::rocblas_evect_original : rocblas_evect::rocblas_evect_none; + + int64_t n = vectors.size(-1); + int64_t lda = std::max(1, n); + int64_t batch_size = batchCount(vectors); + + auto vectors_stride = matrixStride(vectors); + auto values_stride = n; + + auto vectors_data = vectors.data_ptr(); + auto values_data = values.data_ptr(); + auto infos_data = infos.data_ptr(); + + auto work_stride = n; + auto work_size = work_stride * batch_size; + // allocate workspace storage on device + auto& allocator = *at::cuda::getCUDADeviceAllocator(); + auto work_data = allocator.allocate(sizeof(scalar_t) * work_size); + + rocblas_handle handle = static_cast(at::cuda::getCurrentCUDASolverDnHandle()); + + // rocsolver will manage the workspace size automatically + if(!rocblas_is_managing_device_memory(handle)) + TORCH_ROCBLAS_CHECK(rocblas_set_workspace(handle, nullptr, 0)); + + TORCH_ROCBLAS_CHECK(_rocsolver_syevd_strided_batched( + handle, + evect, + uplo, + n, + vectors_data, + lda, + vectors_stride, + values_data, + values_stride, + static_cast(work_data.get()), + work_stride, + infos_data, + batch_size + )); +} +#endif // USE_ROCM && ROCSOLVER_SYEVD_BATCHED_ENABLED + template static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { using value_t = typename c10::scalar_value_type::type; @@ -1363,6 +1484,7 @@ static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, con auto values_data = values.data_ptr(); auto infos_data = infos.data_ptr(); +#ifndef USE_CUSOLVER_64_BIT_XSYEV_BATCHED // syevj_params controls the numerical accuracy of syevj // by default the tolerance is set to machine accuracy // the maximum number of iteration of Jacobi method by default is 100 @@ -1406,6 +1528,54 @@ static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, con syevj_params, batch_size); TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params)); + +#else + + cusolverDnParams_t syev_params; + TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&syev_params)); + + auto handle = at::cuda::getCurrentCUDASolverDnHandle(); + + // get the optimal work size and allocate workspace tensor + size_t worksize_device; + size_t worksize_host; + + at::cuda::solver::xsyevBatched_bufferSize( + handle, + syev_params, + jobz, + uplo, + n, + vectors_data, + lda, + values_data, + &worksize_device, + &worksize_host, + batch_size); + + // allocate workspace storage on device and host + auto& device_allocator = *at::cuda::getCUDADeviceAllocator(); + auto work_device_data = device_allocator.allocate(worksize_device); + auto& host_allocator = *at::getCPUAllocator(); + auto work_host_data = host_allocator.allocate(worksize_host); + at::cuda::solver::xsyevBatched( + handle, + syev_params, + jobz, + uplo, + n, + vectors_data, + lda, + values_data, + work_device_data.get(), + worksize_device, + work_host_data.get(), + worksize_host, + infos_data, + batch_size); + TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(syev_params)); + +#endif // USE_CUSOLVER_64_BIT_XSYEV_BATCHED } static void linalg_eigh_cusolver_syevd(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { @@ -1426,11 +1596,22 @@ static void linalg_eigh_cusolver_syevj_batched(const Tensor& eigenvalues, const }); } +#if defined(USE_ROCM) && ROCSOLVER_SYEVD_BATCHED_ENABLED +static void linalg_eigh_rocsolver_syevd_batched(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { + AT_DISPATCH_FLOATING_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&]() { + apply_syevd_batched_rocsolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);}); +} +#endif // USE_ROCM && ROCSOLVER_SYEVD_BATCHED_ENABLED + void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { -#ifdef USE_ROCM - // syevj has larger numerical errors than syevd - linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); -#else +#if defined(USE_ROCM) +#if ROCSOLVER_SYEVD_BATCHED_ENABLED + if (batchCount(eigenvectors) > 1 && (eigenvectors.scalar_type() == at::kFloat || eigenvectors.scalar_type() == at::kDouble)) + linalg_eigh_rocsolver_syevd_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); + else // not ROCSOLVER_SYEVD_BATCHED_ENABLED or batch==1 or complex input +#endif // ROCSOLVER_SYEVD_BATCHED_ENABLED + linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); +#else // not USE_ROCM if (use_cusolver_syevj_batched_ && batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) { // Use syevjBatched for batched matrix operation when matrix size <= 32 // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724 diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp index 7b068b5f5aec10..99c38077611d66 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp @@ -1956,6 +1956,274 @@ void xsyevd, double>( } #endif // USE_CUSOLVER_64_BIT +#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED + +template <> +void xsyevBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + const float *A, + int64_t lda, + const float *W, + size_t *workspaceInBytesOnDevice, + size_t *workspaceInBytesOnHost, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched_bufferSize( + handle, + params, + jobz, + uplo, + n, + CUDA_R_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_R_32F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost, + batchSize)); +} + +template <> +void xsyevBatched_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + const double *A, + int64_t lda, + const double *W, + size_t *workspaceInBytesOnDevice, + size_t *workspaceInBytesOnHost, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched_bufferSize( + handle, + params, + jobz, + uplo, + n, + CUDA_R_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_R_64F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost, + batchSize)); +} + +template <> +void xsyevBatched_bufferSize, float>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + const c10::complex *A, + int64_t lda, + const float *W, + size_t *workspaceInBytesOnDevice, + size_t *workspaceInBytesOnHost, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched_bufferSize( + handle, + params, + jobz, + uplo, + n, + CUDA_C_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_C_32F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost, + batchSize)); +} + +template <> +void xsyevBatched_bufferSize, double>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + const c10::complex *A, + int64_t lda, + const double *W, + size_t *workspaceInBytesOnDevice, + size_t *workspaceInBytesOnHost, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched_bufferSize( + handle, + params, + jobz, + uplo, + n, + CUDA_C_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_C_64F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost, + batchSize)); +} + +template <> +void xsyevBatched( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + float *A, + int64_t lda, + float *W, + void *bufferOnDevice, + size_t workspaceInBytesOnDevice, + void *bufferOnHost, + size_t workspaceInBytesOnHost, + int *info, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched( + handle, + params, + jobz, + uplo, + n, + CUDA_R_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_R_32F, + bufferOnDevice, + workspaceInBytesOnDevice, + bufferOnHost, + workspaceInBytesOnHost, + info, + batchSize)); +} + +template <> +void xsyevBatched( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + double *A, + int64_t lda, + double *W, + void *bufferOnDevice, + size_t workspaceInBytesOnDevice, + void *bufferOnHost, + size_t workspaceInBytesOnHost, + int *info, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched( + handle, + params, + jobz, + uplo, + n, + CUDA_R_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_R_64F, + bufferOnDevice, + workspaceInBytesOnDevice, + bufferOnHost, + workspaceInBytesOnHost, + info, + batchSize)); +} + +template <> +void xsyevBatched, float>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + c10::complex *A, + int64_t lda, + float *W, + void *bufferOnDevice, + size_t workspaceInBytesOnDevice, + void *bufferOnHost, + size_t workspaceInBytesOnHost, + int *info, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched( + handle, + params, + jobz, + uplo, + n, + CUDA_C_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_C_32F, + bufferOnDevice, + workspaceInBytesOnDevice, + bufferOnHost, + workspaceInBytesOnHost, + info, + batchSize)); +} + +template <> +void xsyevBatched, double>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, + int64_t n, + c10::complex *A, + int64_t lda, + double *W, + void *bufferOnDevice, + size_t workspaceInBytesOnDevice, + void *bufferOnHost, + size_t workspaceInBytesOnHost, + int *info, + int64_t batchSize) { + TORCH_CUSOLVER_CHECK(cusolverDnXsyevBatched( + handle, + params, + jobz, + uplo, + n, + CUDA_C_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_C_64F, + bufferOnDevice, + workspaceInBytesOnDevice, + bufferOnHost, + workspaceInBytesOnHost, + info, + batchSize)); +} + +#endif // USE_CUSOLVER_64_BIT_XSYEV_BATCHED + } // namespace at::cuda::solver #endif // CUDART_VERSION diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h index 9b17086646d87e..cb46608c50b54f 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h @@ -7,6 +7,11 @@ #define USE_CUSOLVER_64_BIT #endif +#if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && CUSOLVER_VERSION >= 11701 +// cuSOLVER version >= 11701 includes 64-bit API for batched syev +#define USE_CUSOLVER_64_BIT_XSYEV_BATCHED +#endif + #if defined(CUDART_VERSION) || defined(USE_ROCM) namespace at { @@ -671,6 +676,84 @@ void xsyevd, double>( #endif // USE_CUSOLVER_64_BIT +#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED + +#define CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \ + cusolverDnHandle_t handle, \ + cusolverDnParams_t params, \ + cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, \ + int64_t n, \ + const scalar_t *A, \ + int64_t lda, \ + const value_t *W, \ + size_t *workspaceInBytesOnDevice, \ + size_t *workspaceInBytesOnHost, \ + int64_t batchSize + +template +void xsyevBatched_bufferSize( + CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevBatched_bufferSize: not implemented"); +} + +template <> +void xsyevBatched_bufferSize( + CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); + +template <> +void xsyevBatched_bufferSize( + CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(double, double)); + +template <> +void xsyevBatched_bufferSize, float>( + CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, float)); + +template <> +void xsyevBatched_bufferSize, double>( + CUDASOLVER_XSYEV_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex, double)); + +#define CUDASOLVER_XSYEV_BATCHED_ARGTYPES(scalar_t, value_t) \ + cusolverDnHandle_t handle, \ + cusolverDnParams_t params, \ + cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, \ + int64_t n, \ + scalar_t *A, \ + int64_t lda, \ + value_t *W, \ + void *bufferOnDevice, \ + size_t workspaceInBytesOnDevice, \ + void *bufferOnHost, \ + size_t workspaceInBytesOnHost, \ + int *info, \ + int64_t batchSize + +template +void xsyevBatched(CUDASOLVER_XSYEV_BATCHED_ARGTYPES(scalar_t, value_t)) { + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevBatched: not implemented"); +} + +template <> +void xsyevBatched( + CUDASOLVER_XSYEV_BATCHED_ARGTYPES(float, float)); + +template <> +void xsyevBatched( + CUDASOLVER_XSYEV_BATCHED_ARGTYPES(double, double)); + +template <> +void xsyevBatched, float>( + CUDASOLVER_XSYEV_BATCHED_ARGTYPES(c10::complex, float)); + +template <> +void xsyevBatched, double>( + CUDASOLVER_XSYEV_BATCHED_ARGTYPES(c10::complex, double)); + +#endif // USE_CUSOLVER_64_BIT_XSYEV_BATCHED + } // namespace solver } // namespace cuda } // namespace at diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index 4b5cb372bee6b7..9b32f05482d5cf 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -169,7 +169,8 @@ std::string repro_from_args(const ConvolutionParams& params) { ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n"; ss << "import torch\n"; ss << "torch.backends.cuda.matmul.allow_tf32 = " - << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; + << pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32") + << "\n"; ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic) @@ -725,7 +726,7 @@ Tensor cudnn_convolution_relu( auto& ctx = at::globalContext(); bool benchmark = ctx.benchmarkCuDNN(); - bool allow_tf32 = ctx.allowTF32CuDNN(); + bool allow_tf32 = ctx.allowTF32CuDNN("conv"); auto _bias = bias_t.has_value() ? bias_t.value() : at::zeros( @@ -783,7 +784,7 @@ Tensor cudnn_convolution_add_relu( } auto& ctx = at::globalContext(); - bool allow_tf32 = ctx.allowTF32CuDNN(); + bool allow_tf32 = ctx.allowTF32CuDNN("conv"); bool benchmark = ctx.benchmarkCuDNN(); auto _alpha = alpha.has_value() ? alpha.value().to() : 1.0; auto _bias = bias_t.has_value() diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 740b54d6772fb9..f9837ccc79a2d4 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -1183,6 +1183,9 @@ void raw_cudnn_convolution_forward_out( if (output.numel() == 0) { return; } + for (auto it = dilation.begin(); it != dilation.end(); it++) { + TORCH_CHECK_VALUE(*it > 0, "Expected positive dilation in convolution."); + } if (at::native::cudnnv8_enabled_check_debug()) { run_single_conv( CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 5d146edb90b069..48119a6a3b4c37 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -92,6 +92,7 @@ void run_cudnn_SDP_bprop( #include #include #include +#include #include #include @@ -319,88 +320,6 @@ auto fixSizeOneDimStrideSDPA( } return strides; } - -void alloc_with_matching_layout( - const Tensor& q, - Tensor& output, - const std::vector& shape) { - TORCH_INTERNAL_ASSERT( - shape.size() == q.sizes().size(), - "cuDNN SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); - - if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) { - output = at::empty_like(q); - return; - } - - // get the "fill order," which is just an argsort on the strides - std::vector fill_order(shape.size()); - std::iota(fill_order.begin(), fill_order.end(), 0); - const auto q_strides = q.strides(); - std::stable_sort( - fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) { - return q_strides[idx1] < q_strides[idx2]; - }); - std::vector ordered_strides(shape.size()); - int64_t current_stride = 1; - for (const int dim_idx : fill_order) { - ordered_strides[dim_idx] = current_stride; - current_stride *= shape[dim_idx]; - } - output = at::empty(at::IntArrayRef(shape), q.options()) - .as_strided( - at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0); -} - -void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { - const int dims = output.sizes().size(); - std::vector outer_to_inner(dims); - std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0); - const auto o_strides = output.strides(); - std::stable_sort( - outer_to_inner.begin(), - outer_to_inner.end(), - [&o_strides](int idx1, int idx2) { - return o_strides[idx1] > o_strides[idx2]; - }); - std::vector inverse(dims); - for (int d = 0; d < dims; d++) { - inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) - - outer_to_inner.begin(); - } - grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner)) - .contiguous() - .permute(at::IntArrayRef(inverse)); -} - -bool same_strides(const Tensor& t1, const Tensor& t2) { - std::vector t1_strides_no_ones; - std::vector t2_strides_no_ones; - const auto t1strides = t1.strides(); - const auto t2strides = t2.strides(); - const int dim = t1strides.size(); - if (dim != (int)t2strides.size()) { - return false; - } - const auto t1sizes = t1.sizes(); - const auto t2sizes = t2.sizes(); - - // we are going through strides backward here, but if both are backward it's - // comparable - for (int i = 0; i < dim; i++) { - if (t1sizes[i] > 1) { - t1_strides_no_ones.push_back(t1strides[i]); - } - if (t2sizes[i] > 1) { - t2_strides_no_ones.push_back(t2strides[i]); - } - } - return std::equal( - t1_strides_no_ones.begin(), - t1_strides_no_ones.end(), - t2_strides_no_ones.begin(), - t2_strides_no_ones.end()); -} } // namespace auto build_graph_and_tensors( diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index c2c73273041003..7d73ed5305108c 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { datatype, input_datatype, algo, - at::globalContext().allowTF32CuDNN()); + at::globalContext().allowTF32CuDNN("rnn")); #else rnn_desc.set( handle, @@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { datatype, input_datatype, algo, - at::globalContext().allowTF32CuDNN()); + at::globalContext().allowTF32CuDNN("rnn")); #endif return rnn_desc; } diff --git a/aten/src/ATen/native/metal/MetalTensorImpl.h b/aten/src/ATen/native/metal/MetalTensorImpl.h index 2fb87b2f4f8979..44152dd3c6d03a 100644 --- a/aten/src/ATen/native/metal/MetalTensorImpl.h +++ b/aten/src/ATen/native/metal/MetalTensorImpl.h @@ -35,7 +35,7 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } - bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { + c10::SymBool sym_is_contiguous_custom(c10::MemoryFormat memory_format) const override { return true; } diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 6375f49386b84b..af69dfc76e571f 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -79,7 +79,9 @@ std::tuple miopen_batch_norm( checkAllDefined(c, {running_mean, running_var}); } checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); - if (input->scalar_type() != ScalarType::Half) { + if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) { + checkScalarType(c, weight, ScalarType::Float); + } else { checkAllSameType(c, {input, weight}); } checkAllSameType(c, {weight, bias, running_mean, running_var}); @@ -186,7 +188,7 @@ std::tuple miopen_batch_norm_backward( checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var}); - if (input->scalar_type() == ScalarType::Half) { + if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) { checkScalarType(c, weight, ScalarType::Float); } else { checkAllSameType(c, {input, weight}); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 8deefaade89cbf..636e94e20f6bfb 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -279,7 +279,7 @@ Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { - TORCH_CHECK(self.is_floating_point()); + TORCH_CHECK(self.is_floating_point(), "Only supports floating-point dtypes, but found: ", self.scalar_type()); auto input_sizes = self.sizes(); DimVector out_sizes(input_sizes.begin(), input_sizes.end()); auto last_dim = dim.back(); @@ -307,7 +307,7 @@ Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, } Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { - TORCH_CHECK(self.is_complex()); + TORCH_CHECK(self.is_complex(), "Only supports complex dtypes, but found: ", self.scalar_type()); if (dim.empty()) { return self.clone(); } @@ -516,7 +516,7 @@ static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_la // n-dimensional complex to real IFFT Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { - TORCH_CHECK(self.is_complex()); + TORCH_CHECK(self.is_complex(), "Only supports complex dtypes, but found: ", self.scalar_type()); // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases // where the input isn't strictly Hermitian-symmetric. Instead, we use a // multi-dim C2C transform followed by a 1D C2R transform. @@ -539,7 +539,7 @@ Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, // n-dimensional real to complex FFT Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { - TORCH_CHECK(self.is_floating_point()); + TORCH_CHECK(self.is_floating_point(), "Only supports floating-point dtypes, but found: ", self.scalar_type()); auto input_sizes = self.sizes(); DimVector out_sizes(input_sizes.begin(), input_sizes.end()); auto last_dim = dim.back(); @@ -560,7 +560,7 @@ Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, // n-dimensional complex to complex FFT/IFFT Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { - TORCH_CHECK(self.is_complex()); + TORCH_CHECK(self.is_complex(), "Only supports complex dtypes, but found: ", self.scalar_type()); if (dim.empty()) { return self.clone(); } diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index d13fe6b23286c5..1e2993e79f4d75 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -155,6 +155,12 @@ static void check_shape_forward(const Tensor& input, // but weight/bias and grad_weight/grad_bias are always CPU tensor. // +static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ + return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" && + mkldnn_bf16_device_check(); +} + + static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { auto memory_format = at::MemoryFormat::Contiguous; if (is_channels_last) { @@ -163,7 +169,7 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo return memory_format; } -static void _mkldnn_convolution_out ( +static void _mkldnn_convolution_out( const Tensor& input_t, const Tensor& weight_t, const Tensor& bias, @@ -261,6 +267,10 @@ static Tensor _mkldnn_convolution( output.resize_(output_sizes, memory_format); y = itensor_from_tensor(output); } + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } _mkldnn_convolution_out( input_t, weight_t, @@ -442,6 +452,10 @@ Tensor mkldnn_convolution_pointwise_binary( op_attr.set_post_ops(po); auto aprop_kind = ideep::prop_kind::forward_inference; + if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_binary( @@ -579,6 +593,10 @@ Tensor& mkldnn_convolution_pointwise_binary_( op_attr = ideep::attr_t::fuse_sum(); } auto aprop_kind = ideep::prop_kind::forward_inference; + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } _mkldnn_convolution_out( input_t, weight_t, @@ -697,6 +715,10 @@ Tensor _mkldnn_convolution_transpose( y = itensor_from_tensor(output); } + if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true); ideep::convolution_transpose_forward::compute_v3( @@ -781,6 +803,11 @@ Tensor mkldnn_convolution_backward_input( grad_input.resize_(input_size, memory_format); grad_x = itensor_from_tensor(grad_input); } + ideep::attr_t op_attr = ideep::attr_t(); + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } ideep::convolution_backward_data::compute_v2( grad_y, w, @@ -791,7 +818,17 @@ Tensor mkldnn_convolution_backward_input( padding.vec(), padding.vec(), groups, +#if IDEEP_PREREQ(3, 4, 1, 3) + is_channels_last, + op_attr); +#else is_channels_last); + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + weight.scalar_type() == at::kFloat) { + TORCH_WARN_ONCE( + "Unexpected ideep version to support fpmath_mode_bf16, please update ideep version to align with pytorch main branch"); + } +#endif if (grad_output.is_mkldnn()) { return MKLDNNTensor(grad_x, grad_output.options()); @@ -816,6 +853,11 @@ std::tuple mkldnn_convolution_backward_weights( const ideep::tensor x = itensor_from_tensor(input, /*from_const_data_ptr*/true); ideep::tensor grad_w, grad_b; + ideep::attr_t op_attr = ideep::attr_t(); + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } if (bias_defined) { ideep::convolution_backward_weights::compute_v2( x, @@ -828,7 +870,8 @@ std::tuple mkldnn_convolution_backward_weights( padding.vec(), padding.vec(), groups, - is_channels_last); + is_channels_last, + op_attr); } else { ideep::convolution_backward_weights::compute_v2( x, @@ -840,7 +883,8 @@ std::tuple mkldnn_convolution_backward_weights( padding.vec(), padding.vec(), groups, - is_channels_last); + is_channels_last, + op_attr); } if (!is_channels_last) { @@ -962,6 +1006,11 @@ Tensor mkldnn_convolution_transpose_backward_input( grad_input.resize_(input_size, memory_format); grad_x = itensor_from_tensor(grad_input); } + ideep::attr_t op_attr = ideep::attr_t(); + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } ideep::convolution_transpose_backward_data::compute_v3( grad_y, w, @@ -972,7 +1021,8 @@ Tensor mkldnn_convolution_transpose_backward_input( padding_r(padding, output_padding), dilation.vec(), groups, - is_channels_last); + is_channels_last, + op_attr); if (grad_output.is_mkldnn()) { return MKLDNNTensor(grad_x, grad_output.options()); @@ -998,6 +1048,11 @@ std::tuple mkldnn_convolution_transpose_backward_weights( auto x = itensor_from_tensor(input, /*from_const_data_ptr*/true); ideep::tensor grad_w, grad_b; + ideep::attr_t op_attr = ideep::attr_t(); + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } if (bias_defined) { ideep::convolution_transpose_backward_weights::compute_v3( x, @@ -1010,7 +1065,8 @@ std::tuple mkldnn_convolution_transpose_backward_weights( padding_r(padding, output_padding), dilation.vec(), groups, - is_channels_last); + is_channels_last, + op_attr); } else { ideep::convolution_transpose_backward_weights::compute_v3( x, @@ -1022,7 +1078,8 @@ std::tuple mkldnn_convolution_transpose_backward_weights( padding_r(padding, output_padding), dilation.vec(), groups, - is_channels_last); + is_channels_last, + op_attr); } if (!is_channels_last) { diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index d7dbbcc8c9869d..8dbb29bb3e01b7 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -68,6 +68,11 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, namespace at::native { +static bool use_mkldnn_bf32_linear() { + return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" && + mkldnn_bf16_device_check(); +} + Tensor mkldnn_linear( const Tensor& self, const Tensor& weight_t, const std::optional& bias_opt) { @@ -251,7 +256,9 @@ Tensor mkldnn_linear_pointwise( it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); op_attr = it->second(scalars, algorithm); } - + if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute( mkldnn_input, @@ -341,6 +348,10 @@ Tensor mkldnn_linear_pointwise_binary( auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc); auto aprop_kind = ideep::prop_kind::forward_inference; + if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } + if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute_binary( mkldnn_input, diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index d896ee27dd04e7..a9c094d85989a8 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -104,7 +104,7 @@ static bool use_mkldnn_fp16_matmul() { } static bool use_mkldnn_bf32_matmul() { - return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM; + return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16"; } // returns an ideep::tensor diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index b35fb768677dd4..724b353415b972 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -162,8 +162,7 @@ std::tuple mkldnn_batch_norm( ideep::tensor saved_mean; ideep::tensor saved_var; ideep::batch_normalization_forward_training::compute( - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - x, w, b, y, saved_mean, saved_var, momentum, eps); + x, w, b, y, saved_mean, saved_var, static_cast(momentum), static_cast(eps)); if (use_running_stat) { auto len = x.get_nelems() / w.get_nelems(); // n*h*w ideep::tensor m = itensor_from_tensor(running_mean); @@ -171,8 +170,7 @@ std::tuple mkldnn_batch_norm( const std::vector scales_mean{static_cast(1 - momentum), static_cast(momentum)}; const std::vector scales_var{static_cast(1 - momentum), - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - static_cast(momentum * len / (len - 1))}; + static_cast(momentum * static_cast(len) / (static_cast(len) - 1))}; ideep::sum::compute(scales_mean, {m, saved_mean}, m); ideep::sum::compute(scales_var, {v, saved_var}, v); } diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 89b69ec70ddb71..0e5c7fcf3a0aac 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -9,27 +10,28 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) { const auto query_size_last = params.query.sym_size(-1); const auto key_size_last = params.key.sym_size(-1); const auto value_size_last = params.value.sym_size(-1); - if ((query_size_last != key_size_last) || - (query_size_last != value_size_last)) { + if (query_size_last != key_size_last) { if (debug) { TORCH_WARN( - "OneDNN attention requires q,k,v to have the same last dimension.", + "OneDNN attention requires q,k to have the same last dimension.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", key_size_last, - ", Value.size(-1): ", - value_size_last, " instead."); } return false; } - if (query_size_last > 256) { + + constexpr int MAX_HEAD_DIM = 576; + const auto max_size_last = query_size_last.max(value_size_last); + if (max_size_last > MAX_HEAD_DIM) { if (debug) { TORCH_WARN( - "OneDNN attention requires q,k,v to have head dimension less than 256.", - " Got ", - query_size_last, + "OneDNN attention requires q,k,v to have head dimension less than ", + MAX_HEAD_DIM, + ". Got ", + max_size_last, " instead."); } return false; @@ -173,6 +175,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu( TORCH_INTERNAL_ASSERT( query.size(3) == key.size(3), "scaled_dot_product_fused_attention_overrideable_xpu: Q/K should have the same head_dim"); + TORCH_INTERNAL_ASSERT( + query.size(1) % key.size(1) == 0, + "scaled_dot_product_fused_attention_overrideable_xpu: number of heads in K/V must divide number of heads in Q"); TORCH_INTERNAL_ASSERT( dropout_p == 0.0, "scaled_dot_product_fused_attention_overrideable_xpu: Currently do not support dropout > 0"); @@ -181,31 +186,33 @@ _scaled_dot_product_fused_attention_overrideable_xpu( "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal"); const int64_t batch_size = query.size(0); - const int64_t num_head = query.size(1); + const int64_t num_head_q = query.size(1); const int64_t num_head_kv = key.size(1); - const int64_t head_dim = query.size(3); + const int64_t head_dim_qk = query.size(3); const int64_t head_dim_v = value.size(3); const int64_t seq_len_q = query.size(2); const int64_t seq_len_kv = key.size(2); - auto opts = query.options(); - auto output = at::empty({batch_size, num_head, seq_len_q, head_dim}, opts); + at::Tensor output; + std::vector output_shape = { + batch_size, num_head_q, seq_len_q, head_dim_v}; + alloc_with_matching_layout(query, output, output_shape); at::Tensor logsumexp, debug_attn_mask; // not supported at::native::onednn::gpu_float_sdpa( batch_size, seq_len_q, seq_len_kv, - num_head, + num_head_q, num_head_kv, - head_dim, + head_dim_qk, head_dim_v, query, key, value, attn_bias, is_causal, - scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim)), + scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)), output); // rng not used diff --git a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp index 81a4191171875b..67558aeebbb83a 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include diff --git a/aten/src/ATen/native/mkldnn/xpu/Conv.h b/aten/src/ATen/native/mkldnn/xpu/Conv.h new file mode 100644 index 00000000000000..31ddf4e89aa51c --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/Conv.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at::native::xpu { +C10_API Tensor convolution_pointwise( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + +C10_API Tensor convolution_pointwise_binary( + const Tensor& input_t, + const Tensor& other_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +C10_API Tensor& convolution_pointwise_binary_( + Tensor& other_t, + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +} // namespace at::native::xpu + +#endif // AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mkldnn/xpu/FusionUtils.cpp b/aten/src/ATen/native/mkldnn/xpu/FusionUtils.cpp index ffb638e553b79e..74922bdb0c0680 100644 --- a/aten/src/ATen/native/mkldnn/xpu/FusionUtils.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/FusionUtils.cpp @@ -122,7 +122,7 @@ onednn::Attr& construct_unary_attr( // If further unary operations required, they can be added to these sets or // add new sets according to their new categories. static const std::set argument_less = { - "relu", "sigmoid", "tanh", "hardswish", "swish", "hardsigmoid"}; + "none", "relu", "sigmoid", "tanh", "hardswish", "swish", "hardsigmoid"}; static const std::set need_scalars = { "leaky_relu", "hardtanh"}; static const std::set need_algorithm = {"gelu"}; diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 0699d0a5604ebd..1d90711f6e3829 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -1,9 +1,11 @@ +#include #include #include #include - #include +namespace { + using namespace at::native::onednn; using logical_tensor = dnnl::graph::logical_tensor; using data_type = logical_tensor::data_type; @@ -11,7 +13,13 @@ using dims = logical_tensor::dims; using op = dnnl::graph::op; using partition = dnnl::graph::partition; -namespace { +inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) { + return scalar_type == c10::ScalarType::Float ? data_type::f32 + : scalar_type == c10::ScalarType::Half ? data_type::f16 + : scalar_type == c10::ScalarType::BFloat16 ? data_type::bf16 + : data_type::undef; +} + struct SDPALogicalParams { enum class TensorID { query, @@ -38,12 +46,15 @@ struct SDPALogicalParams { const at::Tensor& value_, const std::optional& attn_mask_, const at::Tensor& output_, + int batch_size, + int seq_len_q, + int seq_len_kv, + int num_head_q, + int num_head_kv, + int head_dim_qk, + int head_dim_v, bool is_causal) { - const data_type dtype = // to logical_tensor data type - query_.scalar_type() == c10::ScalarType::Float ? data_type::f32 - : query_.scalar_type() == c10::ScalarType::Half ? data_type::f16 - : query_.scalar_type() == c10::ScalarType::BFloat16 ? data_type::bf16 - : data_type::undef; + const data_type dtype = to_logical_tensor_data_type(query_.scalar_type()); TORCH_INTERNAL_ASSERT( (dtype != data_type::undef), "Only FP16/BF16/FP32 datatypes are currently supported"); @@ -72,6 +83,25 @@ struct SDPALogicalParams { at::native::onednn::undo_broadcast(reshaped_attn_mask); } + if (num_head_q != num_head_kv) { // Check whether the attention is a + // Grouped-Query Attention (GQA) + int group_num = num_head_kv; + int group_size = num_head_q / num_head_kv; + // oneDNN requires the shape of the query tensor to be represented as + // [batch_size, num_head_q / num_head_kv, num_head_kv, seq_len_q, + // head_dim_qk]. Please refer to + // https://uxlfoundation.github.io/oneDNN/dev_guide_graph_gqa.html#gqa-pattern + reshaped_query = query_.view( + {batch_size, group_num, group_size, seq_len_q, head_dim_qk}); + reshaped_key = key_.unsqueeze(2); + reshaped_value = value_.unsqueeze(2); + reshaped_output = output_.view( + {batch_size, group_num, group_size, seq_len_q, head_dim_v}); + if (attn_mask_.has_value() && attn_mask_.value().dim() == 4) { + reshaped_attn_mask = attn_mask_.value().unsqueeze(2); + } + } + query = { static_cast(TensorID::query), dtype, @@ -84,22 +114,27 @@ struct SDPALogicalParams { reshaped_key.strides().vec()}; scale = { static_cast(TensorID::scale), - dtype, + to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), scalar_shape, logical_tensor::layout_type::strided, logical_tensor::property_type::constant}; if (is_causal) { neg_inf = { static_cast(TensorID::neg_inf), - dtype, + to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), scalar_shape, logical_tensor::layout_type::strided, logical_tensor::property_type::constant}; } if (attn_mask_.has_value()) { + const data_type mask_dtype = + to_logical_tensor_data_type(attn_mask_->scalar_type()); + TORCH_INTERNAL_ASSERT( + (mask_dtype != data_type::undef), + "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask"); attn_mask = { static_cast(TensorID::attn_mask), - dtype, + mask_dtype, reshaped_attn_mask.sizes().vec(), reshaped_attn_mask.strides().vec()}; } @@ -131,23 +166,21 @@ struct SDPALogicalParams { }; partition create_sdpa_graph_partition( - int batch_size, - int seq_len_q, - int seq_len_k, - int num_head, - int head_dim, bool is_causal, data_type dtype, const SDPALogicalParams& params) { // graph building and partitioning // currently, we assume that Q and K have same sequence length - dims qk_output_shape = {batch_size, num_head, seq_len_q, seq_len_k}; - dims scale_shape = {1}; size_t lt_id = static_cast(SDPALogicalParams::TensorID::end); size_t op_id = 0; - logical_tensor matmul_qk_out{lt_id++, dtype}; + // OneDNN graph has optimized implementation for `f16` or `bf16` SDPA with + // `f32` intermediate data type on Intel Graphics Products with Intel(R) Xe + // Matrix Extensions (Intel(R) XMX) support, which means the + // Q/K/V tensors have bf16 or f16 data type while the output of the first + // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. + logical_tensor matmul_qk_out{lt_id++, data_type::f32}; op matmul_qk{ op_id++, op::kind::MatMul, @@ -156,7 +189,7 @@ partition create_sdpa_graph_partition( "matmul_qk"}; matmul_qk.set_attr(op::attr::transpose_b, true); - logical_tensor scaled_qk_out{lt_id++, dtype}; + logical_tensor scaled_qk_out{lt_id++, data_type::f32}; op scale_mul{ op_id++, op::kind::Multiply, @@ -181,7 +214,7 @@ partition create_sdpa_graph_partition( if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( !is_causal, "Additive mask cannot use with is_causal."); - masked_qk_out = {lt_id++, dtype}; + masked_qk_out = {lt_id++, data_type::f32}; mask_add = { op_id++, op::kind::Add, @@ -216,7 +249,7 @@ partition create_sdpa_graph_partition( {mask_gt_out.value()}, "mask_gt"}; - masked_qk_out = {lt_id++, dtype}; + masked_qk_out = {lt_id++, data_type::f32}; mask_select = { op_id++, op::kind::Select, @@ -232,6 +265,7 @@ partition create_sdpa_graph_partition( op softmax{op_id++, op::kind::SoftMax, "softmax"}; softmax.set_attr(op::attr::axis, -1); + softmax.set_attr(op::attr::mode, "inf_as_zero"); logical_tensor softmax_out{lt_id++, dtype}; softmax.add_input(masked_qk_out.value_or(scaled_qk_out)); @@ -269,11 +303,6 @@ partition create_sdpa_graph_partition( } partition& find_or_create_graph_partition( - int batch_size, - int seq_len_q, - int seq_len_k, - int num_head, - int head_dim, bool is_causal, const SDPALogicalParams& params) { thread_local static PartitionCache cache; @@ -303,15 +332,8 @@ partition& find_or_create_graph_partition( if (!partition_.has_value()) { // partition cache no hit // graph building and partitioning - partition sdp_partition = create_sdpa_graph_partition( - batch_size, - seq_len_q, - seq_len_k, - num_head, - head_dim, - is_causal, - dtype, - params); + partition sdp_partition = + create_sdpa_graph_partition(is_causal, dtype, params); partition_ = cache.insert_partition_cache(patternID, sdp_partition); } return *partition_; @@ -322,10 +344,10 @@ namespace at::native::onednn { void gpu_float_sdpa( int batch_size, int seq_len_q, - int seq_len_k, - int num_head, + int seq_len_kv, + int num_head_q, int num_head_kv, - int head_dim, + int head_dim_qk, int head_dim_v, const Tensor& query, const Tensor& key, @@ -340,46 +362,42 @@ void gpu_float_sdpa( const auto get_tril_mask = [&]() { auto opts = query.options(); auto bool_tril = - at::ones_symint( - {query.sym_size(-2), key.sym_size(-2)}, opts.dtype(at::kBool)) - .tril(); + at::ones_symint({seq_len_q, seq_len_kv}, opts.dtype(at::kBool)).tril(); return at::where( bool_tril, 0.f, at::scalar_tensor(-std::numeric_limits::infinity(), opts)); }; - static bool driver_support_implict_causal = true; - if (attn_mask.has_value()) { - TORCH_INTERNAL_ASSERT( - !is_causal, - "scaled_dot_product_fused_attention_overrideable_xpu: " - "attn_mask cannot present with is_causal"); - } else { - // Currenetly implict mask only supports square fp16 cases - const bool support_implict_causal = driver_support_implict_causal && - (query.dtype() == at::kHalf || query.dtype() == at::kBFloat16) && - seq_len_q == seq_len_k; - if (is_causal && !support_implict_causal) { - attn_mask = get_tril_mask(); - is_causal = false; - } + // OneDNN doesn't support fp32 ukernel for implicit causal mask, + // and the reference implementation is worse than aten math + explict causal + // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 + // ukernel for implicit causal mask. + if (is_causal && query.dtype() == at::kFloat) { + attn_mask = get_tril_mask(); + is_causal = false; } - std::vector l_inputs, l_outputs; + std::vector l_inputs, l_outputs; std::optional compiled_partition; auto get_compiled_partition = [&]() { const SDPALogicalParams logical_params( - query, key, value, attn_mask, output, is_causal); - auto& partition_ = find_or_create_graph_partition( + query, + key, + value, + attn_mask, + output, batch_size, seq_len_q, - seq_len_k, - num_head, - head_dim, - is_causal, - logical_params); + seq_len_kv, + num_head_q, + num_head_kv, + head_dim_qk, + head_dim_v, + is_causal); + auto& partition_ = + find_or_create_graph_partition(is_causal, logical_params); auto i = logical_params.get_input(); auto o = logical_params.get_output(); auto compiled_partition = partition_.compile(i, o, eng); @@ -388,24 +406,18 @@ void gpu_float_sdpa( return compiled_partition; }; - // maybe retry without causal mask - try { - compiled_partition = get_compiled_partition(); - } catch (std::exception& e) { - if (is_causal) { - attn_mask = get_tril_mask(); - is_causal = false; - compiled_partition = get_compiled_partition(); - driver_support_implict_causal = false; - } else { - throw e; - } - } + compiled_partition = get_compiled_partition(); - Tensor softmax_scale1 = at::full({}, softmax_scale, query.options()); + Tensor softmax_scale1 = at::full( + {}, + softmax_scale, + query.options().dtype(at::toOpMathType(query.scalar_type()))); std::optional neg_inf; if (is_causal) { - neg_inf = at::full({}, -INFINITY, query.options()); + neg_inf = at::full( + {}, + -INFINITY, + query.options().dtype(at::toOpMathType(query.scalar_type()))); } std::vector outputs = { diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h b/aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h new file mode 100644 index 00000000000000..d4fc77ed516ade --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h @@ -0,0 +1,594 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace std { + +template <> +struct hash { + size_t operator()(dnnl::memory::dims const& vec) const { + size_t seed = vec.size(); + for (auto& i : vec) { + seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +} // namespace std + +using namespace dnnl; + +namespace at::native::onednn { + +class primitive_ext : public primitive { + static constexpr int max_args = 12; + + public: + primitive_ext(const primitive& base) : primitive(base) {} + primitive_ext(primitive&& base) : primitive(std::move(base)) {} + + /// Returns a memory descriptor. + /// + /// @note + /// There are also convenience methods + /// #dnnl::primitive_desc_base::src_desc(), + /// #dnnl::primitive_desc_base::dst_desc(), and others. + /// + /// @param what The kind of parameter to query; can be + /// #dnnl::query::src_md, #dnnl::query::dst_md, etc. + /// @param idx Index of the parameter. For example, convolution bias can + /// be queried with what = #dnnl::query::weights_md and idx = 1. + /// @returns The requested memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// parameter of the specified kind or index. + const_dnnl_memory_desc_t query_md(query what, int idx = 0) const { + std::vector valid_q{ + query::src_md, + query::diff_src_md, + query::weights_md, + query::diff_weights_md, + query::dst_md, + query::diff_dst_md, + query::workspace_md, + query::scratchpad_md, + query::exec_arg_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) { + return what == q; + })) + DNNL_THROW_ERROR( + dnnl_invalid_arguments, "memory descriptor query is invalid"); + + const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( + this->get_primitive_desc(), dnnl::convert_to_c(what), idx); + + return cdesc ? cdesc : nullptr; + } + + /// Returns a source memory descriptor. + /// @param idx Source index. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter with index @p idx. + const_dnnl_memory_desc_t src_desc(int idx) const { + return query_md(query::src_md, idx); + } + + /// Returns a destination memory descriptor. + /// @param idx Destination index. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter with index @p idx. + const_dnnl_memory_desc_t dst_desc(int idx) const { + return query_md(query::dst_md, idx); + } + + /// Returns a weights memory descriptor. + /// @param idx Weights index. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter with index @p idx. + const_dnnl_memory_desc_t weights_desc(int idx) const { + return query_md(query::weights_md, idx); + } + + /// Returns a diff source memory descriptor. + /// @param idx Diff source index. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source parameter with index @p idx. + const_dnnl_memory_desc_t diff_src_desc(int idx) const { + return query_md(query::diff_src_md, idx); + } + + /// Returns a diff destination memory descriptor. + /// @param idx Diff destination index. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter with index @p idx. + const_dnnl_memory_desc_t diff_dst_desc(int idx) const { + return query_md(query::diff_dst_md, idx); + } + + /// Returns a diff weights memory descriptor. + /// @param idx Diff weights index. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter with index @p idx. + const_dnnl_memory_desc_t diff_weights_desc(int idx) const { + return query_md(query::diff_weights_md, idx); + } + + const_dnnl_memory_desc_t exec_arg_desc(int idx) const { + return query_md(query::exec_arg_md, idx); + } + + // Separate versions without the index argument for documentation + // purposes. + + /// Returns a source memory descriptor. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter. + const_dnnl_memory_desc_t src_desc() const { + return src_desc(0); + } + + /// Returns a destination memory descriptor. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter. + const_dnnl_memory_desc_t dst_desc() const { + return dst_desc(0); + } + + /// Returns a weights memory descriptor. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter. + const_dnnl_memory_desc_t weights_desc() const { + return weights_desc(0); + } + + /// Returns a diff source memory descriptor. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source memory with. + const_dnnl_memory_desc_t diff_src_desc() const { + return diff_src_desc(0); + } + + /// Returns a diff destination memory descriptor. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter. + const_dnnl_memory_desc_t diff_dst_desc() const { + return diff_dst_desc(0); + } + + /// Returns a diff weights memory descriptor. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter. + const_dnnl_memory_desc_t diff_weights_desc() const { + return diff_weights_desc(0); + } + + /// Returns the workspace memory descriptor. + /// @returns Workspace memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// workspace parameter. + const_dnnl_memory_desc_t workspace_desc() const { + return query_md(query::workspace_md, 0); + } + + /// Returns the scratchpad memory descriptor. + /// @returns scratchpad memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// scratchpad parameter. + /// @sa @ref dev_guide_attributes_scratchpad + const_dnnl_memory_desc_t scratchpad_desc() const { + return query_md(query::scratchpad_md, 0); + } + + inline memory make_memory( + const_dnnl_memory_desc_t md_t, + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm; + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_sycl_interop_memory_create( + &c_memory, md_t, aengine.get(), convert_to_c(kind), handle), + "could not create a memory"); + return memory(c_memory); + } + + memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(src_desc(), aengine, handle); + } + + memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(), aengine, handle); + } + + memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(1), aengine, handle); + } + + memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(dst_desc(), aengine, handle); + } + + memory make_scratchpad( + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + return make_memory(scratchpad_desc(), aengine, handle); + } + + size_t get_scratchpad_size() const { + return dnnl_memory_desc_get_size(scratchpad_desc()); + } + + memory make_args(int arg_class, const engine& aengine, void* handle) const { + switch (arg_class) { + case DNNL_ARG_SRC: + return make_src(aengine, handle); + case DNNL_ARG_WEIGHTS: + return make_weight(aengine, handle); + case DNNL_ARG_SCRATCHPAD: + return make_scratchpad(aengine, handle); + case DNNL_ARG_DST: + return make_dst(aengine, handle); + case DNNL_ARG_BIAS: + return make_bias(aengine, handle); + default: + TORCH_INTERNAL_ASSERT( + false, "unsupported argument class for primitive_ext"); + } + } + + template + void set_attribute(int slot, int arg_class, void* handle, M constructor) { + if (mem_arg_cache[slot]) + mem_arg_cache[slot].set_data_handle(handle); + else { + mem_arg_cache[slot] = constructor(); + c_args[slot].arg = arg_class; + c_args[slot].memory = mem_arg_cache[slot].get(); + } + } + + sycl::event execute( + const stream& astream, + const engine& aengine, + std::vector>&& handles, + int slot_off = 2) { + auto off = slot_off; + for (const auto& p : handles) { + auto& m_arg = mem_arg_cache[off]; + if (m_arg) + m_arg.set_data_handle(p.second); + else { + m_arg = make_args(p.first, aengine, p.second); + c_args[off].arg = p.first; + c_args[off].memory = m_arg.get(); + } + ++off; + } + + sycl::event return_event; + std::vector deps{}; + error::wrap_c_api( + dnnl_sycl_interop_primitive_execute( + this->get(), astream.get(), off, c_args, &deps, &return_event), + "could not execute a primitive"); + return return_event; + } + + private: + memory mem_arg_cache[max_args]; + dnnl_exec_arg_t c_args[max_args]; +}; + +// Specifies the combined data types of input and weight tensors. +// For example, f32 means both input and weight are FP32, +// bf16_int4 means input is BF16 and weight is INT4. +enum class joint_dtypes_t { f32 = 0, f16, bf16, int8, f16_int4, bf16_int4 }; + +// Specifies the transposition state of input and weight tensors. +// Convention: first letter = input, second letter = weight. +// 'n' = not transposed, 't' = transposed. +// For example, 'nt' means input is not transposed, weight is transposed. +enum class trans_type_t { nn = 0, nt, tn, tt }; + +// Specifies the type and placement of bias in the computation. +// 'none' = no bias, +// 'scalar' = a single scalar bias applied to all elements, +// 'm' = per-row bias (typically matched to input rows), +// 'n' = per-column bias (typically matched to output channels), +// 'mn' = full bias matrix matching the output dimensions. +enum class bias_type_t { none = 0, scalar, m, n, mn }; + +template +T concat(const T& t1, at::ScalarType d) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back((int64_t)d); + + return t; +} + +template +T concat(const T& t1, bool b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, int b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, const T& t2) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.insert(t.end(), t2.begin(), t2.end()); + + return t; +} + +template +T1 concat(const T1& t1, const T2& t2, const Ts&... ts) { + return concat(concat(t1, t2), ts...); +} + +template +struct onednn_types_mapper; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::f16, dnnl::memory::data_type::u4); + } +}; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::bf16, dnnl::memory::data_type::u4); + } +}; + +// TODO: bias types maybe not right +static inline dnnl::memory::dims get_bias_type( + bias_type_t b_dims, + const int m, + const int n) { + switch (b_dims) { + case bias_type_t::none: + return {0}; + case bias_type_t::scalar: + return {1, 1}; + case bias_type_t::m: + return {m, 1}; + case bias_type_t::n: + return {1, n}; + case bias_type_t::mn: + return {m, n}; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported bias type ..."); + } +} + +// TODO: use template specialization on struct +template +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const int64_t ldc) {} + +template <> +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const int64_t ldc) { + src_strides = {lda, 1}; + wei_strides = {1, ldb}; + dst_strides = {ldc, 1}; +} + +using primitive_cache = + at::native::onednn::lru_cache; + +template +struct matmul_primitive_cache_t { + static inline primitive_ext& get( + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const bias_type_t + b_dims, // for shapeless bias, not put it into template parameter + const int device_id, + F f_attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + auto& cached = get_cache(device_id); + memory::dims src_strides, wei_strides, dst_strides; + get_strides(src_strides, wei_strides, dst_strides, lda, ldb, ldc); + auto pri_key = at::native::onednn::concat( + src_strides, + wei_strides, + m, + n, + k, + int(b_dims), + int(scale_group_size), + int(zp_group_size)); + auto iter = cached.find(pri_key); + if (iter == cached.end()) { + auto [src_dt, wei_dt] = onednn_types_mapper::get(); + auto bias_dims = get_bias_type(b_dims, m, n); + + auto src_md = memory::desc({m, k}, src_dt, src_strides); + auto wei_md = memory::desc({k, n}, wei_dt, wei_strides); + auto dst_md = memory::desc({m, n}, src_dt, dst_strides); + auto bias_format = b_dims == bias_type_t::none + ? dnnl::memory::format_tag::undef + : dnnl::memory::format_tag::ab; + auto bias_md = + memory::desc(bias_dims, src_dt, bias_format); // {m, n} or {1, n} + + primitive_attr pattr; + f_attr(pattr); + + dnnl::matmul::primitive_desc matmul_pd; + auto aengine = + at::native::onednn::GpuEngineManager::Instance().get_engine( + device_id); + if (b_dims == bias_type_t::none) { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, dst_md, pattr); + } else { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, bias_md, dst_md, pattr); + } + + return cached.insert({pri_key, primitive_ext(dnnl::matmul(matmul_pd))}) + .first->second; + } else { + return iter->second; + } + } + + private: + static constexpr int max_cache_capacity = 512; + // if default constructor of primitive cache could read the environment + // variable then it'll save a lot of trouble + static inline thread_local std::array mappings; + + // this won't be needed if primitive_cache have good default constructor + static inline primitive_cache& get_cache(const int device_id) { + auto& mapping = mappings[device_id]; + if (mapping.max_size() == 0) { + mapping.resize(max_cache_capacity); + } + return mapping; + } +}; + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + switch (Tt) { + case trans_type_t::nt: + return matmul_primitive_cache_t::get( + m, + n, + k, + lda, + ldb, + ldc, + b_dims, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "unsupported trans type ..."); + } +} + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const joint_dtypes_t Ts, + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, // is weight ldb necessary? + const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size = 0, + const int64_t zp_group_size = 0) { + switch (Ts) { + case joint_dtypes_t::f16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + case joint_dtypes_t::bf16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "Only support int4 ..."); + } +} + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/LRUCache.h b/aten/src/ATen/native/mkldnn/xpu/detail/LRUCache.h new file mode 100644 index 00000000000000..9229c10bc57a36 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/LRUCache.h @@ -0,0 +1,110 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::onednn { + +template < + class key_t, + class value_t, + template class map_t = std::unordered_map> +class lru_cache { + public: + using value_type = std::pair; + using list_type = std::list; + using list_iter = typename list_type::iterator; + using map_type = map_t; + using const_list_iter = typename list_type::const_iterator; + using size_type = typename list_type::size_type; + + explicit lru_cache(size_type capacity) : capacity_(capacity) {} + lru_cache() : capacity_(0) {} + + [[nodiscard]] size_type size() const noexcept { + return map_.size(); + } + [[nodiscard]] size_type max_size() const noexcept { + return capacity_; + } + [[nodiscard]] bool empty() const noexcept { + return vlist_.empty(); + } + + void resize(size_type new_capacity) { + capacity_ = new_capacity; + trim(); + } + + list_iter begin() noexcept { + return vlist_.begin(); + } + const_list_iter begin() const noexcept { + return vlist_.begin(); + } + list_iter end() noexcept { + return vlist_.end(); + } + const_list_iter end() const noexcept { + return vlist_.end(); + } + + void clear() noexcept { + map_.clear(); + vlist_.clear(); + } + + void swap(lru_cache& other) noexcept { + using std::swap; + swap(vlist_, other.vlist_); + swap(map_, other.map_); + swap(capacity_, other.capacity_); + } + + list_iter find(const key_t& key) { + auto it = map_.find(key); + if (it == map_.end()) + return end(); + vlist_.splice(vlist_.begin(), vlist_, it->second); + return it->second; + } + + std::pair insert(const value_type& value) { + auto it = map_.find(value.first); + if (it != map_.end()) { + // Move existing to front + vlist_.splice(vlist_.begin(), vlist_, it->second); + return {it->second, false}; + } + + // Insert new at front + vlist_.emplace_front(value); + map_[value.first] = vlist_.begin(); + + trim(); + + return {vlist_.begin(), true}; + } + + list_iter erase(list_iter pos) { + map_.erase(pos->first); + return vlist_.erase(pos); + } + + private: + void trim() { + while (map_.size() > capacity_) { + auto last = std::prev(vlist_.end()); + map_.erase(last->first); + vlist_.pop_back(); + } + } + + list_type vlist_; + map_type map_; + size_type capacity_; +}; + +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp index 96c3405665cbfc..282f42f37a3646 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp @@ -85,8 +85,9 @@ at::Tensor quantized_convolution( std::optional unary_attr, torch::List> unary_scalars, std::optional unary_algorithm) { - Attr attr = - Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point); + Attr attr = Attr( + /*q_scale=*/static_cast(1.0 / inv_output_scale), + /*zp=*/output_zero_point); auto ndim = act.ndimension(); construct_attr_by_post_op( diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp index 180b54e2da53c2..41da31c7eb6b2c 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp @@ -112,7 +112,7 @@ void quantized_matmul( // config we support: // activation: s8&u8; per tensor calibrated; symmetric&asymmetric // weight: s8; per_tensor/per_channel calibrated; symmetric - auto attr = Attr(1.0 / output_scale, output_zero_point); + auto attr = Attr(static_cast(1.0 / output_scale), output_zero_point); construct_attr_by_post_op( binary_post_op, binary_alpha, diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp index 3f2e8097e377f2..3c8603e44833e4 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp @@ -294,6 +294,13 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) { if (tensor.is_contiguous()) return true; + if (tensor.storage_offset() > 0) { + // currently onednn asks 64 byte alignment + constexpr int alignment_byte = 64; + if (reinterpret_cast(tensor.data_ptr()) % alignment_byte > 0) + return false; + } + // the overlaped cases are not supported dnnl::memory::dims strides = get_onednn_strides(tensor); int64_t storage_size = 1; diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp index cb9fac96a886ba..6ef371424eed84 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -8,22 +9,13 @@ namespace at::native::onednn { -void woq_matmul_int4( - Tensor& result, // torchao: [M, K], dtype: fp16,bf16 - const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 - const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 - const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 - const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 +void woq_matmul_int4_impl( + Tensor& result, + const Tensor& mat1_, + const Tensor& mat2_, + const Tensor& scale, + const Tensor& zp, int64_t group_size) { - size_t dims = result.dim(); - TORCH_CHECK( - dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); - TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); - - at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device()); - TORCH_CHECK( - cur_device == mat1_.device(), - "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); auto& engine = GpuEngineManager::Instance().get_engine(); auto& stream = GpuStreamManager::Instance().get_stream(); @@ -176,4 +168,162 @@ void woq_matmul_int4( args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m}); dnnl::sycl_interop::execute(matmul_p, stream, args); } + +static inline void set_quant_primitive_attr( + primitive_attr& pattr, + const Tensor& scale, + const Tensor& zp, + const int64_t group_size) { + // set scale and zero point for matmul args + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 0) + (1 << 1), + {group_size, 1}, + get_onednn_dtype(scale)); + pattr.set_zero_points( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 0) + (1 << 1), + {group_size, 1}, + memory::data_type::s8); +} + +void woq_matmul_int4_impl_cache( + Tensor& result, + const Tensor& mat1, + const Tensor& mat2, + const Tensor& scale, + const Tensor& zp, + int64_t group_size) { + auto a_sz = mat1.sizes(); + auto c_sz = result.sizes(); + + const int m = + std::reduce(a_sz.begin(), a_sz.end() - 1, 1, std::multiplies()); + const int n = *(c_sz.end() - 1); + const int k = *(a_sz.end() - 1); + + const int64_t ldb = mat2.strides()[mat2.dim() - 2] * 8; // for int4 matmul + const int64_t lda = mat1.strides()[mat1.dim() - 2]; + const int64_t ldc = result.strides()[result.dim() - 2]; + + bias_type_t b_type = bias_type_t::none; + trans_type_t tt = trans_type_t::nt; // only support nt for int4 matmul + + joint_dtypes_t jd; + if (mat1.scalar_type() == at::ScalarType::Half) { + jd = joint_dtypes_t::f16_int4; + } else if (mat1.scalar_type() == at::ScalarType::BFloat16) { + jd = joint_dtypes_t::bf16_int4; + } else { + TORCH_INTERNAL_ASSERT( + false, "Unsupported data type for int4 matmul: ", mat1.scalar_type()); + } + + auto f_attr = [&](primitive_attr& pattr) { + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + if (jd == joint_dtypes_t::f16_int4) { + pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); + } else if (jd == joint_dtypes_t::bf16_int4) { + pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true); + } + + set_quant_primitive_attr(pattr, scale, zp, group_size); + +#if ONEDNN_SUPPORT_DETERMINISTIC + if (at::globalContext().deterministicAlgorithms() || + at::globalContext().deterministicMkldnn()) { + pattr.set_deterministic(true); + } +#endif + }; + + int64_t zp_group_size = group_size; + auto device_id = c10::xpu::current_device(); + auto& matmul_ext = matmul_primitive_create_and_cache( + jd, + tt, + b_type, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + f_attr, + group_size, + zp_group_size); + + auto& engine = GpuEngineManager::Instance().get_engine(); + + int arg_off = 0; + // set scale and zero point for matmul args + matmul_ext.set_attribute( + arg_off++, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, + scale.data_ptr(), + [&]() { + return make_onednn_memory( + get_onednn_md(scale), engine, scale.data_ptr()); + }); + + // set zp_md for asymmetric quantization + matmul_ext.set_attribute( + arg_off++, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, + zp.data_ptr(), + [&]() { + int num_groups = k / group_size; + memory zp_usr_m( + {{num_groups, n}, memory::data_type::s8, {n, 1}}, + engine, + zp.data_ptr()); + return zp_usr_m; + }); + + // set general args + std::vector> arg_handles; + arg_handles.reserve(8); + + arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr()); + arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr()); + arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr()); + + int scratchpad_size = matmul_ext.get_scratchpad_size(); + Tensor scratchpad_tensor = at::empty( + {scratchpad_size}, mat1.options().dtype(at::kByte), std::nullopt); + arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr()); + + auto& strm = GpuStreamManager::Instance().get_stream(); + auto qint4_matmul_event = + matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off); +} + +void woq_matmul_int4( + Tensor& result, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 + const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 + const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 + int64_t group_size, + bool pri_cache) { + size_t dims = result.dim(); + TORCH_CHECK( + dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); + TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); + + const int device_id = c10::xpu::current_device(); + at::Device cur_device = at::Device(at::kXPU, device_id); + TORCH_CHECK( + cur_device == mat1_.device(), + "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); + + if (pri_cache) { + woq_matmul_int4_impl_cache(result, mat1_, mat2_, scale, zp, group_size); + } else { + woq_matmul_int4_impl(result, mat1_, mat2_, scale, zp, group_size); + } +} + } // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index fb2ced9549b717..e73cb73e8b1e7b 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -95,7 +95,8 @@ TORCH_API void woq_matmul_int4( const at::Tensor& mat2_, // quantized weight, [K/8, N] const at::Tensor& scale, // [K/group_size, N] const at::Tensor& zp, // [k/group_size, N] - int64_t group_size); + int64_t group_size, + bool pri_cache = true); dnnl::memory::dims conv_dst_size( int64_t ndim, @@ -166,10 +167,10 @@ void quantized_matmul( void gpu_float_sdpa( int batch_size, int seq_len_q, - int seq_len_k, - int num_head, + int seq_len_kv, + int num_head_q, int num_head_kv, - int head_dim, + int head_dim_qk, int head_dim_v, const Tensor& query, const Tensor& key, diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index 594355c88c5b1d..6d35a5e9b2a31d 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -137,11 +137,13 @@ class MetalShaderLibrary { void exec_unary_kernel( TensorIteratorBase& iter, const std::string& name, - std::optional extra = std::nullopt); + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); void exec_binary_kernel( TensorIteratorBase& iter, const std::string& name, - const std::optional alpha = std::nullopt); + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); protected: virtual MTLLibrary_t getLibrary(); @@ -154,6 +156,7 @@ class MetalShaderLibrary { MTLLibrary_t lib, const std::string& fname); MTLLibrary_t compileLibrary(const std::string& src); + void bind_tensors(MTLComputeCommandEncoder_t, TensorIteratorBase&); std::string shaderSource; unsigned nparams; MTLCompileOptions* compile_options; diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index c64c1a011a8edc..6474faac43ab83 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -473,6 +473,11 @@ static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigne [encoder setBytes:&val length:sizeof(val) atIndex:idx]; return; } + if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) { + auto val = static_cast>(*reinterpret_cast*>(t.const_data_ptr())); + [encoder setBytes:&val length:sizeof(val) atIndex:idx]; + return; + } [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx]; } else { TORCH_CHECK(false, "Passed CPU tensor to MPS op"); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index bdc75b571abdad..583eb410345082 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -971,24 +971,51 @@ static dispatch_data_t getSectionData(const std::string& name) { } }; +void MetalShaderLibrary::bind_tensors(id encoder, TensorIteratorBase& iter) { + for (auto idx : c10::irange(iter.ntensors())) { + auto& t = iter.tensor_base(idx); + // Handle CPU scalars + if (C10_UNLIKELY(t.device().type() == kCPU)) { + mtl_setBuffer(encoder, t, idx); + continue; + } + // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id object + // But TensorIterator constructs data_ptr as if base was just a raw pointer + // Workaround this problem by computing an offset from the start of the tensor, which works for both + // tensor vies and sliced 64-bit iterators + auto offs = reinterpret_cast(iter.data_ptr(idx)) - reinterpret_cast(t.storage().data()); + [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx]; + } +} + void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter, const std::string& name, - std::optional extra) { + std::optional alpha, + std::optional scalar_arg_type) { + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_unary_kernel(sub_iter, name, alpha, scalar_arg_type); + } + return; + } + auto inputTensor = iter.input(0); auto outputTensor = iter.output(0); - bool is_storage_dense = is_dense_in_storage(inputTensor) && inputTensor.strides().equals(outputTensor.strides()); uint32_t length = iter.numel(); if (length == 0) { return; } using namespace mps; + const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype(); + auto kernel_name = fmt::format("{}_{}_{}_{}{}", + name, + iter.is_contiguous() ? "dense" : "strided", + scalarToMetalTypeString(outputTensor), + scalarToMetalTypeString(inputTensor), + alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : ""); @autoreleasepool { - id cplState = nil; - cplState = getPipelineStateForFunc(fmt::format("{}_{}_{}_{}", - name, - is_storage_dense ? "dense" : "strided", - scalarToMetalTypeString(outputTensor), - scalarToMetalTypeString(inputTensor))); + auto cplState = getPipelineStateForFunc(kernel_name); MPSStream* mpsStream = getCurrentMPSStream(); dispatch_sync(mpsStream->queue(), ^() { @@ -997,22 +1024,16 @@ static dispatch_data_t getSectionData(const std::string& name) { getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor}); [computeEncoder setComputePipelineState:cplState]; - if (is_storage_dense) { - mtl_setArgs(computeEncoder, outputTensor, inputTensor); - if (extra) { - mtl_setBytes(computeEncoder, *extra, 2); - } - } else { - mtl_setArgs(computeEncoder, - outputTensor, - inputTensor, - outputTensor.sizes(), - inputTensor.strides(), - outputTensor.strides(), - inputTensor.ndimension()); - if (extra) { - mtl_setBytes(computeEncoder, *extra, 6); - } + bind_tensors(computeEncoder, iter); + if (!iter.is_contiguous()) { + mtl_setArgs<2>(computeEncoder, + outputTensor.sizes(), + inputTensor.strides(), + outputTensor.strides(), + inputTensor.ndimension()); + } + if (alpha) { + mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), iter.is_contiguous() ? 2 : 6); } mtl_dispatch1DJob(computeEncoder, cplState, length); @@ -1023,18 +1044,26 @@ static dispatch_data_t getSectionData(const std::string& name) { void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std::string& name, - std::optional alpha) { + std::optional alpha, + std::optional scalar_arg_type) { // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?) // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with // double as common dtype (because Python floating point are always 64-bit values) TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS"); - TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator"); // Skip for empty iterators if (iter.numel() == 0) { return; } + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_binary_kernel(sub_iter, name, alpha, scalar_arg_type); + } + return; + } + auto convert_double_scalar = [](Tensor& t) { if (t.dim() != 0) { return; @@ -1053,17 +1082,16 @@ static dispatch_data_t getSectionData(const std::string& name) { convert_double_scalar(input); convert_double_scalar(other); - id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); - const uint32_t nDim = iter.ndim(); - constexpr uint32_t nOffsets = 3; - const uint32_t numThreads = iter.numel(); const auto cast_needed = input.scalar_type() != other.scalar_type(); const auto suffix = iter.is_contiguous() ? "dense" : "strided"; + const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype(); + const auto alpha_suffix = alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : ""; // TODO: Implicitly pass both input and output types to non-cast kernels const auto kernel_name = cast_needed - ? fmt::format("{}_{}_cast_{}", name, suffix, scalarToMetalTypeString(out)) - : fmt::format("{}_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input)); + ? fmt::format("{}_{}_cast_{}{}", name, suffix, scalarToMetalTypeString(out), alpha_suffix) + : fmt::format( + "{}_{}_{}_{}{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input), alpha_suffix); dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { auto computeEncoder = mpsStream->commandEncoder(); @@ -1071,12 +1099,13 @@ static dispatch_data_t getSectionData(const std::string& name) { // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other}); [computeEncoder setComputePipelineState:binaryPSO]; + // Set input and output tensors + bind_tensors(computeEncoder, iter); // Iterator is contiguous if all of its elements are dense in storage, // i.e. it's true for both row-first and column-first tensors if (iter.is_contiguous()) { - mtl_setArgs(computeEncoder, out, input, other); if (alpha) { - mtl_setBytes(computeEncoder, getMPSScalar(*alpha, iter.common_dtype()), 3); + mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), 3); } if (cast_needed) { std::array size_and_types = {static_cast(c10::elementSize(input.scalar_type())), @@ -1089,32 +1118,24 @@ static dispatch_data_t getSectionData(const std::string& name) { // Please note that shapes and strides of the iterator might be // different than that of its operands, for example binary op // between 4x4 tensor and scalar will result in 1D 16 element iterator - std::array ndim_and_types = { - iter.ndim(), static_cast(input.scalar_type()), static_cast(other.scalar_type())}; + std::array ndim_and_types = {iter.ndim(), + static_cast(input.scalar_type()), + static_cast(other.scalar_type()), + static_cast(out.scalar_type())}; if (alpha) { - mtl_setArgs(computeEncoder, - out, - input, - other, - getMPSScalar(*alpha, iter.common_dtype()), - iter.shape(), - iter.strides(0), - iter.strides(1), - iter.strides(2), - ndim_and_types); + mtl_setArgs<3>(computeEncoder, + getMPSScalar(*alpha, alpha_type), + iter.shape(), + iter.strides(0), + iter.strides(1), + iter.strides(2), + ndim_and_types); } else { - mtl_setArgs(computeEncoder, - out, - input, - other, - iter.shape(), - iter.strides(0), - iter.strides(1), - iter.strides(2), - ndim_and_types); + mtl_setArgs<3>( + computeEncoder, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types); } } - mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads); + mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel()); getMPSProfiler().endProfileKernel(binaryPSO); } }); diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal new file mode 100644 index 00000000000000..f7335d150d4063 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -0,0 +1,146 @@ +#include +#include +#include +using namespace metal; +using namespace c10::metal; + +struct hardshrink_functor { + template + inline T operator()(const T x, const T lambda) { + return abs(float(x)) <= float(lambda) ? T(0) : x; + } +}; + +struct softshrink_functor { + template + inline T operator()(const T x, const T lambda) { + if (x > lambda) { + return x - lambda; + } else if (x < -lambda) { + return x + lambda; + } else { + return T(0); + } + } +}; + +struct shrink_backward_functor { + template + inline T operator()(const T grad_output, const T x, const T lambda) { + return abs(float(x)) <= float(lambda) ? T(0) : grad_output; + } +}; + +REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float); +REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat); +#endif + +REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float); +REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat); +#endif + +REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float); +REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat); +#endif + +struct hardsigmoid_functor { + template + inline T operator()(const T x) { + return static_cast(min(max(x + 3.0f, .0f), 6.f) / 6.f); + } +}; + +struct hardsigmoid_backward_functor { + template + inline T operator()(const T grad_output, const T self) { + constexpr auto one_sixth = 1.0f / 6.0f; + return static_cast( + abs(float(self)) < 3.0f ? float(grad_output) * one_sixth : 0.0f); + } +}; + +REGISTER_UNARY_OP(hardsigmoid, float, float); +REGISTER_UNARY_OP(hardsigmoid, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat); +#endif + +REGISTER_BINARY_OP(hardsigmoid_backward, float, float); +REGISTER_BINARY_OP(hardsigmoid_backward, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat); +#endif + +struct hardswish_functor { + template + inline T operator()(const T x) { + return static_cast(float(x) * min(max(float(x) + 3.0f, .0f), 6.f) / 6.f); + } +}; + +struct hardswish_backward_functor { + template + inline T operator()(const T grad_output, const T self) { + constexpr T zero(0); + constexpr T three(3); + constexpr T neg_three(-3); + + if (self <= neg_three) { + return zero; + } else if (self >= three) { + return grad_output; + } else { + return static_cast(float(grad_output) * (float(self) / 3.0f + 0.5f)); + } + } +}; + +REGISTER_UNARY_OP(hardswish, float, float); +REGISTER_UNARY_OP(hardswish, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_OP(hardswish, bfloat, bfloat); +#endif + +REGISTER_BINARY_OP(hardswish_backward, float, float); +REGISTER_BINARY_OP(hardswish_backward, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +#endif + +struct leaky_relu_functor { + template + inline T operator()(const T x, const T negative_slope) { + return float(x) > 0.0f ? x + : static_cast(float(x) * float(negative_slope)); + } +}; + +struct leaky_relu_backward_functor { + template + inline T operator()( + const T self, + const T grad_output, + const T negative_slope) { + return float(self) > 0.0f + ? grad_output + : static_cast(float(grad_output) * float(negative_slope)); + } +}; + +REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float); +REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat); +#endif + +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float); +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat); +#endif diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 774a8283100d05..a178ba25712312 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -69,59 +69,143 @@ struct copysign_functor { }; struct zeta_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::zeta(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::zeta(float(a), float(b)); + } }; struct xlog1py_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::xlog1py(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::xlog1py(float(a), float(b)); + } }; struct chebyshev_polynomial_t_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::chebyshev_polynomial_t_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::chebyshev_polynomial_t_forward(float(a), float(b)); + } }; struct chebyshev_polynomial_u_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::chebyshev_polynomial_u_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::chebyshev_polynomial_u_forward(float(a), float(b)); + } }; struct chebyshev_polynomial_v_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::chebyshev_polynomial_v_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::chebyshev_polynomial_v_forward(float(a), float(b)); + } }; struct chebyshev_polynomial_w_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::chebyshev_polynomial_w_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::chebyshev_polynomial_w_forward(float(a), float(b)); + } +}; + +struct shifted_chebyshev_polynomial_t_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast( + c10::metal::shifted_chebyshev_polynomial_t_forward(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::shifted_chebyshev_polynomial_t_forward( + float(a), float(b)); + } +}; + +struct shifted_chebyshev_polynomial_u_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast( + c10::metal::shifted_chebyshev_polynomial_u_forward(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::shifted_chebyshev_polynomial_u_forward( + float(a), float(b)); + } +}; + +struct shifted_chebyshev_polynomial_v_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast( + c10::metal::shifted_chebyshev_polynomial_v_forward(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::shifted_chebyshev_polynomial_v_forward( + float(a), float(b)); + } +}; + +struct shifted_chebyshev_polynomial_w_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast( + c10::metal::shifted_chebyshev_polynomial_w_forward(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::shifted_chebyshev_polynomial_w_forward( + float(a), float(b)); + } }; struct hermite_polynomial_h_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::hermite_polynomial_h_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::hermite_polynomial_h_forward(float(a), float(b)); + } }; struct hermite_polynomial_he_functor { - template + template , bool> = true> inline T operator()(const T a, const T b) { return static_cast(c10::metal::hermite_polynomial_he_forward(a, b)); } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::hermite_polynomial_he_forward(float(a), float(b)); + } }; struct nextafter_functor { @@ -249,7 +333,7 @@ struct div_trunc_functor { struct remainder_functor { template inline T operator()(const T a, const T b) { - return T(a - b * c10::metal::floor_divide(a, b)); + return T(c10::metal::remainder(a, b)); } }; @@ -299,13 +383,29 @@ REGISTER_FLOAT_BINARY_OP(fmax); REGISTER_FLOAT_BINARY_OP(fmin); REGISTER_FLOAT_BINARY_OP(nextafter); REGISTER_FLOAT_BINARY_OP(zeta); +REGISTER_INT2FLOAT_BINARY_OP(zeta); REGISTER_FLOAT_BINARY_OP(xlog1py); +REGISTER_INT2FLOAT_BINARY_OP(xlog1py); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t); +REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_t); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_u); +REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_u); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_v); +REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_w); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_w); +REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_v); +REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_t); +REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_t); +REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_u); +REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_u); +REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_v); +REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_v); +REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_w); +REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_w); REGISTER_FLOAT_BINARY_OP(hermite_polynomial_h); +REGISTER_INT2FLOAT_BINARY_OP(hermite_polynomial_h); REGISTER_FLOAT_BINARY_OP(hermite_polynomial_he); +REGISTER_INT2FLOAT_BINARY_OP(hermite_polynomial_he); REGISTER_FLOAT_BINARY_OP(add); REGISTER_INTEGER_BINARY_OP(add); REGISTER_OPMATH_FLOAT_BINARY_OP(mul); @@ -322,35 +422,35 @@ REGISTER_OPMATH_FLOAT_BINARY_OP(remainder); REGISTER_INTEGER_BINARY_OP(remainder); REGISTER_OPMATH_FLOAT_BINARY_OP(fmod); REGISTER_INTEGER_BINARY_OP(fmod); -REGISTER_BINARY_ALPHA_OP(add_alpha, long, long); -REGISTER_BINARY_ALPHA_OP(add_alpha, int, int); -REGISTER_BINARY_ALPHA_OP(add_alpha, float, float); -REGISTER_BINARY_ALPHA_OP(add_alpha, half, half); -REGISTER_BINARY_ALPHA_OP(add_alpha, short, short); -REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar); -REGISTER_BINARY_ALPHA_OP(add_alpha, char, char); -REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool); -REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long); -REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int); -REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float); -REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half); -REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short); -REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar); -REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char); -REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool); +REGISTER_BINARY_ALPHA_OP(add_alpha, long, long, long); +REGISTER_BINARY_ALPHA_OP(add_alpha, int, int, int); +REGISTER_BINARY_ALPHA_OP(add_alpha, float, float, float); +REGISTER_BINARY_ALPHA_OP(add_alpha, half, half, half); +REGISTER_BINARY_ALPHA_OP(add_alpha, short, short, short); +REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar, uchar); +REGISTER_BINARY_ALPHA_OP(add_alpha, char, char, char); +REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool, bool); +REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long, long); +REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int, int); +REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float, float); +REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half, half); +REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short, short); +REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar, uchar); +REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char, char); +REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool, bool); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long, long); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int, int); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float, float); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half, half); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short, short); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool); #if __METAL_VERSION__ >= 310 -REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat); -REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat); #endif // Complex binary functions @@ -366,9 +466,9 @@ REGISTER_BINARY_OP(add, float2, float2); REGISTER_BINARY_OP(add, half2, half2); REGISTER_BINARY_OP(sub, float2, float2); REGISTER_BINARY_OP(sub, half2, half2); -REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2); -REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2); +REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2); +REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2); +REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2); +REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2, half2); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2, float2); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2, half2); diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index c98cc6950f2f5a..15d46d8c8d8e18 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -7,36 +7,31 @@ using namespace metal; constant uint TILE_DIM = 16; template -kernel void matmul( - constant T* mat1Data [[buffer(0)]], - constant T* mat2Data [[buffer(1)]], - device T* outputData [[buffer(2)]], - constant array& strides [[buffer(3)]], - constant uint3& sizes [[buffer(4)]], - uint2 tid [[thread_position_in_threadgroup]], - uint2 group_id [[threadgroup_position_in_grid]]) { - uint col = group_id.x * TILE_DIM + tid.x; - uint row = group_id.y * TILE_DIM + tid.y; - +inline c10::metal::opmath_t matmul_inner( + constant T* mat1Data, + constant T* mat2Data, + constant array& strides, + constant uint3& sizes, + threadgroup T A_tile[TILE_DIM][TILE_DIM], + threadgroup T B_tile[TILE_DIM][TILE_DIM], + uint2 tid, + uint2 thread_id) { c10::metal::opmath_t sum = 0; - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; - uint numTiles = (sizes.y + TILE_DIM - 1) / TILE_DIM; for (uint t = 0; t < numTiles; t++) { uint tiledCol = t * TILE_DIM + tid.x; - if (row < sizes.x && tiledCol < sizes.y) { + if (thread_id.y < sizes.x && tiledCol < sizes.y) { A_tile[tid.y][tid.x] = - mat1Data[row * strides[0].x + tiledCol * strides[0].y]; + mat1Data[thread_id.y * strides[0].x + tiledCol * strides[0].y]; } else { A_tile[tid.y][tid.x] = 0; } uint tiledRow = t * TILE_DIM + tid.y; - if (tiledRow < sizes.y && col < sizes.z) { + if (tiledRow < sizes.y && thread_id.x < sizes.z) { B_tile[tid.y][tid.x] = - mat2Data[tiledRow * strides[1].x + col * strides[1].y]; + mat2Data[tiledRow * strides[1].x + thread_id.x * strides[1].y]; } else { B_tile[tid.y][tid.x] = 0; } @@ -50,8 +45,26 @@ kernel void matmul( threadgroup_barrier(mem_flags::mem_threadgroup); } - if (row < sizes.x && col < sizes.z) { - outputData[row * strides[2].x + col * strides[2].y] = static_cast(sum); + return sum; +} + +template +kernel void matmul( + constant T* mat1Data [[buffer(0)]], + constant T* mat2Data [[buffer(1)]], + device T* outputData [[buffer(2)]], + constant array& strides [[buffer(3)]], + constant uint3& sizes [[buffer(4)]], + uint2 tid [[thread_position_in_threadgroup]], + uint2 thread_id [[thread_position_in_grid]]) { + threadgroup T A_tile[TILE_DIM][TILE_DIM]; + threadgroup T B_tile[TILE_DIM][TILE_DIM]; + + auto sum = matmul_inner( + mat1Data, mat2Data, strides, sizes, A_tile, B_tile, tid, thread_id); + if (thread_id.y < sizes.x && thread_id.x < sizes.z) { + outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] = + static_cast(sum); } } @@ -132,6 +145,28 @@ inline float blockReduceSum( return sharedScratch[0]; } +template +inline device float& get_ref(device float* A, uint row, uint col, uint N); + +template <> +inline device float& get_ref( + device float* A, + uint row, + uint col, + uint N) { + return A[row * N + col]; +} + +template <> +inline device float& get_ref( + device float* A, + uint row, + uint col, + uint N) { + return A[row + col * N]; +} + +template kernel void factorDiagonalBlock( device float* A [[buffer(0)]], device int* info [[buffer(1)]], @@ -158,7 +193,7 @@ kernel void factorDiagonalBlock( for (uint i = linear_tid; i < tileSize; i += group_size) { uint r = i / actSize; uint c = i % actSize; - tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)]; + tile[r][c] = get_ref(A + batch_offset, row0 + r, col0 + c, N); } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -231,10 +266,33 @@ kernel void factorDiagonalBlock( for (uint i = linear_tid; i < tileSize; i += group_size) { uint r = i / actSize; uint c = i % actSize; - A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c]; + get_ref(A + batch_offset, row0 + r, col0 + c, N) = tile[r][c]; } } +template [[host_name("factorDiagonalBlockU")]] +kernel void factorDiagonalBlock( + device float* A [[buffer(0)]], + device int* info [[buffer(1)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 bid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]]); + +template [[host_name("factorDiagonalBlockL")]] +kernel void factorDiagonalBlock( + device float* A [[buffer(0)]], + device int* info [[buffer(1)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 bid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]]); + +template kernel void applyTRSM( device float* A [[buffer(0)]], constant uint& N [[buffer(2)]], @@ -270,12 +328,12 @@ kernel void applyTRSM( for (uint i = linear_tid; i < actSize_k * actSize_k; i += group_size) { uint r = i / actSize_k; uint c = i % actSize_k; - diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)]; + diag[i] = get_ref(A + batch_offset, k * NB + r, k * NB + c, N); } for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) { uint r = i / actSize_k; uint c = i % actSize_k; - target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)]; + target[i] = get_ref(A + batch_offset, row0 + r, col0 + c, N); } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -319,10 +377,31 @@ kernel void applyTRSM( for (uint i = linear_tid; i < actSize_j * actSize_k; i += group_size) { uint r = i / actSize_k; uint c = i % actSize_k; - A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i]; + get_ref(A + batch_offset, row0 + r, col0 + c, N) = target[i]; } } +template [[host_name("applyTRSMU")]] +kernel void applyTRSM( + device float* A [[buffer(0)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]]); + +template [[host_name("applyTRSML")]] +kernel void applyTRSM( + device float* A [[buffer(0)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]]); + +template kernel void applySYRK( device float* A [[buffer(0)]], constant uint& N [[buffer(2)]], @@ -390,17 +469,25 @@ kernel void applySYRK( // Same logic to load/store Cfrag, Afrag, Bfrag... simdgroup_matrix Cfrag; simdgroup_load( - Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N); + Cfrag, + &get_ref(A + batch_offset, row0 + sb_y, col0 + sb_x, N), + N, + 0, + !upper); for (uint kk = 0; kk < actSize_k; kk += 8) { simdgroup_load( - Afrag, &A[batch_offset + (row0 + sb_y) * N + (k * NB + kk)], N); + Afrag, + &get_ref(A + batch_offset, row0 + sb_y, k * NB + kk, N), + N, + 0, + !upper); simdgroup_load( Bfrag, - &A[batch_offset + (col0 + sb_x) * N + (k * NB + kk)], + &get_ref(A + batch_offset, col0 + sb_x, k * NB + kk, N), N, /* matrix_origin = */ 0, - /* transpose = */ true); + /* transpose = */ upper); simdgroup_multiply(Prod, Afrag, Bfrag); simdgroup_multiply(Prod, Prod, negative_identity); @@ -408,7 +495,11 @@ kernel void applySYRK( } simdgroup_store( - Cfrag, &A[batch_offset + (row0 + sb_y) * N + (col0 + sb_x)], N); + Cfrag, + &get_ref(A + batch_offset, row0 + sb_y, col0 + sb_x, N), + N, + 0, + !upper); } } else { // Fallback for non-multiple-of-8 dimensions @@ -429,8 +520,10 @@ kernel void applySYRK( float sum = 0.0f; for (uint i = 0; i < actSize_k; i++) { - float a_val = A[batch_offset + (row0 + y) * N + k * NB + i]; - float b_val = A[batch_offset + (col0 + x) * N + k * NB + i]; + float a_val = + get_ref(A + batch_offset, row0 + y, k * NB + i, N); + float b_val = + get_ref(A + batch_offset, col0 + x, k * NB + i, N); sum = fma(a_val, b_val, sum); } sum_accumulator[y * tpg.x + x] += sum; @@ -439,13 +532,35 @@ kernel void applySYRK( threadgroup_barrier(mem_flags::mem_threadgroup); for (uint y = ty; y < actSize_j; y += tpg.y) { for (uint x = tx; x < actSize_h; x += tpg.x) { - A[batch_offset + (row0 + y) * N + col0 + x] -= + get_ref(A + batch_offset, row0 + y, col0 + x, N) -= sum_accumulator[y * tpg.x + x]; } } } } +template [[host_name("applySYRKU")]] +kernel void applySYRK( + device float* A [[buffer(0)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]]); + +template [[host_name("applySYRKL")]] +kernel void applySYRK( + device float* A [[buffer(0)]], + constant uint& N [[buffer(2)]], + constant uint& NB [[buffer(3)]], + constant uint& k [[buffer(4)]], + uint3 tid [[thread_position_in_threadgroup]], + uint3 tgid [[threadgroup_position_in_grid]], + uint3 tpg [[threads_per_threadgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]]); + kernel void applyPivots( device float* P [[buffer(0)]], device const int* pivots [[buffer(1)]], diff --git a/aten/src/ATen/native/mps/kernels/Pooling.h b/aten/src/ATen/native/mps/kernels/Pooling.h new file mode 100644 index 00000000000000..9838b14a6bfcf3 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Pooling.h @@ -0,0 +1,29 @@ +#pragma once + +#ifndef __METAL__ +#include +#define _ARRAY_NS std +#else +#include +#define _ARRAY_NS metal +#endif + +// N is the maximum allowed number of dimensions in the input and outputs. The +// maximum allowed pooling dimensions is N-2, because the input may have up to 2 +// leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is +// the default. +template +struct PoolingParams { + int32_t dims; + int32_t pooling_dims; + _ARRAY_NS::array input_sizes; + _ARRAY_NS::array input_strides; + _ARRAY_NS::array output_sizes; + _ARRAY_NS::array output_strides; + _ARRAY_NS::array indices_sizes; + _ARRAY_NS::array indices_strides; + _ARRAY_NS::array kernel_size; + _ARRAY_NS::array stride; + _ARRAY_NS::array padding; + _ARRAY_NS::array dilation; +}; diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal new file mode 100644 index 00000000000000..967cc3915960db --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -0,0 +1,175 @@ +#include +#include +#include +using namespace metal; + +// Iterates through all the input elements that this kernel needs to +// apply max to. Specialized for 3 pooling dimensions. +// TODO: Support any number of pooling dims +template +void max_pool_3d_input_iter( + constant T* input, + device T* output, + device int64_t* indices, + constant int64_t* input_sizes, + constant int64_t* input_strides, + device int64_t* work_pooling_dim_indices, + constant int64_t* kernel_size, + constant int64_t* stride, + constant int64_t* padding, + constant int64_t* dilation) { + int64_t o0 = work_pooling_dim_indices[0]; + int64_t o1 = work_pooling_dim_indices[1]; + int64_t o2 = work_pooling_dim_indices[2]; + + int64_t k0 = kernel_size[0]; + int64_t k1 = kernel_size[1]; + int64_t k2 = kernel_size[2]; + + int64_t s0 = stride[0]; + int64_t s1 = stride[1]; + int64_t s2 = stride[2]; + + int64_t d0 = dilation[0]; + int64_t d1 = dilation[1]; + int64_t d2 = dilation[2]; + + T max_value = 0; + int64_t max_index = -1; + + int64_t size12 = input_sizes[1] * input_sizes[2]; + + for (int64_t i0 = (s0 * o0) - padding[0]; + i0 < (s0 * o0 - padding[0] + k0 * d0) && i0 < input_sizes[0]; + i0 += d0) { + if (i0 < 0) { + continue; + } + int64_t offset0 = input_strides[0] * i0; + + for (int64_t i1 = (s1 * o1) - padding[1]; + i1 < (s1 * o1 - padding[1] + k1 * d1) && i1 < input_sizes[1]; + i1 += d1) { + if (i1 < 0) { + continue; + } + int64_t offset1 = input_strides[1] * i1; + + for (int64_t i2 = (s2 * o2) - padding[2]; + i2 < (s2 * o2 - padding[2] + k2 * d2) && i2 < input_sizes[2]; + i2 += d2) { + if (i2 < 0) { + continue; + } + int64_t offset2 = input_strides[2] * i2; + + const T input_value = input[offset0 + offset1 + offset2]; + int64_t input_index = i0 * size12 + i1 * input_sizes[2] + i2; + + T new_max_value = (max_index == -1 || input_value > max_value) + ? input_value + : max_value; + int64_t new_max_index = (max_index == -1 || input_value > max_value) + ? input_index + : max_index; + + max_value = new_max_value; + max_index = new_max_index; + } + } + } + + *output = max_value; + *indices = max_index; +} + +// Kernel computes one element of the output per kernel call. +template +kernel void max_pool( + constant void* input_ [[buffer(0)]], + device void* output_ [[buffer(1)]], + device void* indices_ [[buffer(2)]], + device int64_t* work_pooling_dim_indices_ [[buffer(3)]], + constant PoolingParams<5>& params [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + int32_t pooling_dims = params.pooling_dims; + int32_t dims = params.dims; + constant int64_t* input_sizes = params.input_sizes.data(); + constant int64_t* input_strides = params.input_strides.data(); + constant int64_t* output_sizes = params.output_sizes.data(); + constant int64_t* output_strides = params.output_strides.data(); + constant int64_t* indices_strides = params.indices_strides.data(); + constant int64_t* kernel_size = params.kernel_size.data(); + constant int64_t* stride = params.stride.data(); + constant int64_t* padding = params.padding.data(); + constant int64_t* dilation = params.dilation.data(); + + int32_t leading_dims = dims - pooling_dims; + constant T* input = reinterpret_cast(input_); + device T* output = reinterpret_cast(output_); + device int64_t* indices = reinterpret_cast(indices_); + + // This buffer keeps track of the pooling dimension indices of this thread's + // element of the output. We need to fill it with the proper values below. + device int64_t* work_pooling_dim_indices = + work_pooling_dim_indices_ + tid * pooling_dims; + int64_t output_idx = static_cast(tid); + int64_t output_offset = 0; + int64_t indices_offset = 0; + int64_t input_leading_offset = 0; + + // First, find the offset of the output element this thread will calculate, + // `output[N, C, d, h, w]`. Also, find the offset of the input for the leading + // dim indices, `input[N, C]` and keep track of the pooling dimension indices, + // `[d, h , w]`. + for (int64_t dim = dims - 1; dim >= 0; dim--) { + int64_t dim_idx = output_idx % (output_sizes[dim]); + output_offset += output_strides[dim] * dim_idx; + indices_offset += indices_strides[dim] * dim_idx; + + if (dim < leading_dims) { + input_leading_offset += input_strides[dim] * dim_idx; + } else { + // Keep track of pooling dimension indices of the output element, so we + // can use them in the input iteration later on. + work_pooling_dim_indices[dim - leading_dims] = dim_idx; + } + output_idx = output_idx / output_sizes[dim]; + } + output += output_offset; + indices += indices_offset; + input += input_leading_offset; + + max_pool_3d_input_iter( + input, + output, + indices, + input_sizes + leading_dims, + input_strides + leading_dims, + work_pooling_dim_indices, + kernel_size, + stride, + padding, + dilation); +} + +#define REGISTER_MAX_POOL_OP(DTYPE) \ + template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool( \ + constant void* input_ [[buffer(0)]], \ + device void* output_ [[buffer(1)]], \ + device void* indices_ [[buffer(2)]], \ + device int64_t* work_pooling_dim_indices_ [[buffer(3)]], \ + constant PoolingParams<5>& params [[buffer(4)]], \ + uint tid [[thread_position_in_grid]]); + +REGISTER_MAX_POOL_OP(float); +REGISTER_MAX_POOL_OP(half); +REGISTER_MAX_POOL_OP(int); +REGISTER_MAX_POOL_OP(long); +REGISTER_MAX_POOL_OP(short); +REGISTER_MAX_POOL_OP(char); +REGISTER_MAX_POOL_OP(uchar); +REGISTER_MAX_POOL_OP(bool); +#if __METAL_VERSION__ >= 310 +REGISTER_MAX_POOL_OP(bfloat); +#endif diff --git a/aten/src/ATen/native/mps/kernels/ScanKernel.metal b/aten/src/ATen/native/mps/kernels/ScanKernel.metal new file mode 100644 index 00000000000000..c12fdb33cd7014 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/ScanKernel.metal @@ -0,0 +1,1165 @@ +#include +#include +using namespace metal; + +#include +#include + +using c10::metal::accum_t; + +struct LogAddExp { + template + T operator()(T x, T y) { + // Reference: + // https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + T min_val = c10::metal::min(x, y); + T max_val = c10::metal::max(x, y); + + if (min_val != max_val || metal::isfinite(min_val)) { + // nan will be propagated here + return c10::metal::log1p(metal::exp(min_val - max_val)) + max_val; + } else { + // special case to correctly handle infinite cases + return x; + } + }; +}; + +#if __METAL_VERSION__ < 310 +template > +struct CumMinOp { + static acc_t apply(acc_t a, acc_t b) { + return metal::min(a, b); + } + static acc_t identity() { + return static_cast( + metal::is_floating_point_v ? metal::numeric_limits::infinity() + : metal::numeric_limits::max()); + } +}; + +template > +struct CumMaxOp { + static acc_t apply(acc_t a, acc_t b) { + return metal::max(a, b); + } + static acc_t identity() { + return static_cast( + metal::is_floating_point_v ? -metal::numeric_limits::infinity() + : metal::numeric_limits::lowest()); + } +}; + +template > +struct LogCumSumExpOp { + static acc_t apply(acc_t x, acc_t y) { + return LogAddExp{}(x, y); + } + static acc_t identity() { + return -metal::numeric_limits::infinity(); + } +}; + +// Inclusive scan along innermost dimension for contiguous tensors +template > +kernel void scan_contiguous_innermost_dim( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant uint& num_rows [[buffer(2)]], + constant uint& row_size [[buffer(3)]], + uint row [[thread_position_in_grid]]) { + if (row >= num_rows) + return; + + const uint offset = row * row_size; + + acc_t accumulator = Op::identity(); + + for (uint col = 0; col < row_size; col++) { + T val = input[offset + col]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[offset + col] = static_cast(accumulator); + } +} + +// Inclusive scan along outer dimension for contiguous tensors +template > +kernel void scan_contiguous_outer_dim( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant uint& num_orows [[buffer(2)]], + constant uint& num_irows [[buffer(3)]], + constant uint& row_size [[buffer(4)]], + uint thread_index [[thread_position_in_grid]]) { + const uint orow = thread_index / num_irows; + const uint irow = thread_index % num_irows; + + if (orow >= num_orows) + return; + + acc_t accumulator = Op::identity(); + + const uint idx_base = orow * row_size * num_irows + irow; + for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { + T val = input[idx]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[idx] = static_cast(accumulator); + } +} + +// Inclusive scan with indices along innermost dimension for contiguous tensors +template > +kernel void scan_with_indices_contiguous_innermost_dim( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant uint& num_rows [[buffer(3)]], + constant uint& row_size [[buffer(4)]], + uint row [[thread_position_in_grid]]) { + if (row >= num_rows) + return; + + const uint offset = row * row_size; + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + + for (uint col = 0; col < row_size; col++) { + T val = input[offset + col]; + acc_t accum_val = static_cast(val); + if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = col; + } + values[offset + col] = static_cast(accumulator); + indices[offset + col] = best_idx; + } +} + +// Inclusive scan with indices along outer dimension for contiguous tensors +template > +kernel void scan_with_indices_contiguous_outer_dim( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant uint& num_orows [[buffer(3)]], + constant uint& num_irows [[buffer(4)]], + constant uint& row_size [[buffer(5)]], + uint thread_index [[thread_position_in_grid]]) { + const uint orow = thread_index / num_irows; + const uint irow = thread_index % num_irows; + + if (orow >= num_orows) + return; + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + + const uint idx_base = orow * row_size * num_irows + irow; + for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { + T val = input[idx]; + acc_t accum_val = static_cast(val); + if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = col; + } + values[idx] = static_cast(accumulator); + indices[idx] = best_idx; + } +} + +// Shared utility functions for strided kernels +inline long calculate_non_scan_elements( + constant long* sizes, + uint ndim, + uint scan_dim) { + long total = 1; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + total *= sizes[i]; + } + } + return total; +} + +inline void thread_index_to_coordinates( + uint index, + int pos[c10::metal::max_ndim], + constant long* sizes, + uint ndim, + uint scan_dim) { + long remaining_index = index; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + pos[i] = remaining_index % sizes[i]; + remaining_index /= sizes[i]; + } else { + pos[i] = 0; + } + } +} + +inline long calculate_base_offset( + int pos[c10::metal::max_ndim], + constant long* strides, + uint ndim, + uint scan_dim) { + long offset = 0; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + offset += pos[i] * strides[i]; + } + } + return offset; +} + +// Generic strided scan kernel +template > +kernel void scan_strided( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* input_strides [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant uint& ndim [[buffer(5)]], + constant uint& scan_dim [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + const long total_non_scan_elements = + calculate_non_scan_elements(sizes, ndim, scan_dim); + if (thread_index >= total_non_scan_elements) { + return; + } + + int pos[c10::metal::max_ndim]; + thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); + + const long input_base_offset = + calculate_base_offset(pos, input_strides, ndim, scan_dim); + const long output_base_offset = + calculate_base_offset(pos, output_strides, ndim, scan_dim); + + acc_t accumulator = Op::identity(); + const long scan_size = sizes[scan_dim]; + const long input_scan_stride = input_strides[scan_dim]; + const long output_scan_stride = output_strides[scan_dim]; + + for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { + const long input_offset = input_base_offset + scan_idx * input_scan_stride; + const long output_offset = + output_base_offset + scan_idx * output_scan_stride; + + T val = input[input_offset]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[output_offset] = static_cast(accumulator); + } +} + +// Generic strided scan with indices kernel +template > +kernel void scan_with_indices_strided( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant long* sizes [[buffer(3)]], + constant long* input_strides [[buffer(4)]], + constant long* values_strides [[buffer(5)]], + constant long* indices_strides [[buffer(6)]], + constant uint& ndim [[buffer(7)]], + constant uint& scan_dim [[buffer(8)]], + uint thread_index [[thread_position_in_grid]]) { + const long total_non_scan_elements = + calculate_non_scan_elements(sizes, ndim, scan_dim); + if (thread_index >= total_non_scan_elements) { + return; + } + + int pos[c10::metal::max_ndim]; + thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); + + const long input_base_offset = + calculate_base_offset(pos, input_strides, ndim, scan_dim); + const long values_base_offset = + calculate_base_offset(pos, values_strides, ndim, scan_dim); + const long indices_base_offset = + calculate_base_offset(pos, indices_strides, ndim, scan_dim); + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + const long scan_size = sizes[scan_dim]; + const long input_scan_stride = input_strides[scan_dim]; + const long values_scan_stride = values_strides[scan_dim]; + const long indices_scan_stride = indices_strides[scan_dim]; + + for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { + const long input_offset = input_base_offset + scan_idx * input_scan_stride; + const long values_offset = + values_base_offset + scan_idx * values_scan_stride; + const long indices_offset = + indices_base_offset + scan_idx * indices_scan_stride; + + T val = input[input_offset]; + acc_t accum_val = static_cast(val); + if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = scan_idx; + } + values[values_offset] = static_cast(accumulator); + indices[indices_offset] = best_idx; + } +} + +#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \ + template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ + scan_contiguous_innermost_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant uint & num_rows [[buffer(2)]], \ + constant uint & row_size [[buffer(3)]], \ + uint row [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ + scan_contiguous_outer_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant uint & num_orows [[buffer(2)]], \ + constant uint & num_irows [[buffer(3)]], \ + constant uint & row_size [[buffer(4)]], \ + uint thread_index [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ + scan_strided>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant long* sizes [[buffer(2)]], \ + constant long* input_strides [[buffer(3)]], \ + constant long* output_strides [[buffer(4)]], \ + constant uint& ndim [[buffer(5)]], \ + constant uint& scan_dim [[buffer(6)]], \ + uint thread_index [[thread_position_in_grid]]); + +#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \ + template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ + scan_with_indices_contiguous_innermost_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant uint& num_rows [[buffer(3)]], \ + constant uint& row_size [[buffer(4)]], \ + uint row [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ + scan_with_indices_contiguous_outer_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant uint& num_orows [[buffer(3)]], \ + constant uint& num_irows [[buffer(4)]], \ + constant uint& row_size [[buffer(5)]], \ + uint thread_index [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ + scan_with_indices_strided>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant long* sizes [[buffer(3)]], \ + constant long* input_strides [[buffer(4)]], \ + constant long* values_strides [[buffer(5)]], \ + constant long* indices_strides [[buffer(6)]], \ + constant uint& ndim [[buffer(7)]], \ + constant uint& scan_dim [[buffer(8)]], \ + uint thread_index [[thread_position_in_grid]]); + +// Simple scan operations +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float); +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half); + +// Scan operations with indices +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool); + +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool); + +#else // __METAL_VERSION__ >= 310 + +// The reminder of this file contains cummin and cummax implementations adapted +// from MLX: +// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scan.h +// +// The original MLX kernels have been modified to integrate with PyTorch's MPS +// backend. Most notably: +// - Keeping track and returning indices, which MLX kernels don't do. +// - Perform computations on half/bfloat tensors at higher precision (float) +// via c10::metal::accum_t +// +// Original work is licensed under MIT License: +// https://github.com/ml-explore/mlx/blob/main/LICENSE + +inline uint64_t simd_shuffle_and_fill_up( + uint64_t data, + uint64_t filling, + uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline int64_t simd_shuffle_and_fill_up( + int64_t data, + int64_t filling, + uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { + return simd_shuffle_and_fill_up( + static_cast(data), static_cast(filling), delta); +} + +inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline int64_t simd_shuffle(int64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline bool simd_shuffle(bool data, uint16_t lane) { + return simd_shuffle(static_cast(data), lane); +} + +#define DEFINE_SIMD_SCAN() \ + template = true> \ + U simd_scan(U val) { \ + return simd_scan_impl(val); \ + } \ + \ + template = true> \ + U simd_scan(U val) { \ + for (int i = 1; i <= 16; i *= 2) { \ + val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \ + } \ + return val; \ + } + +#define DEFINE_SIMD_EXCLUSIVE_SCAN() \ + template = true> \ + U simd_exclusive_scan(U val) { \ + return simd_exclusive_scan_impl(val); \ + } \ + \ + template = true> \ + U simd_exclusive_scan(U val) { \ + val = simd_scan(val); \ + return simd_shuffle_and_fill_up(val, init, 1); \ + } + +template > +struct LogCumSumExpOp { + static constexpr constant acc_t init = static_cast( + metal::is_floating_point_v ? -metal::numeric_limits::infinity() + : metal::numeric_limits::lowest()); + + acc_t operator()(acc_t a, acc_t b) { + return LogAddExp{}(a, b); + } + + acc_t simd_scan(acc_t x) { + for (int i = 1; i <= 16; i *= 2) { + acc_t other = simd_shuffle_and_fill_up(x, init, i); + x = LogAddExp{}(x, other); + } + return x; + } + + acc_t simd_exclusive_scan(acc_t x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +// Pair structure to hold value and index for cummin/cummax operations +template > +struct ValueIndexPair { + acc_t value; + int64_t index; +}; + +// Helper function to create ValueIndexPair +template > +inline ValueIndexPair make_pair(acc_t v, int64_t i) { + ValueIndexPair result; + result.value = v; + result.index = i; + return result; +} + +// Helper function for shuffling pairs in SIMD operations +template > +inline ValueIndexPair simd_shuffle_pair( + ValueIndexPair data, + uint16_t lane) { + return make_pair( + simd_shuffle(data.value, lane), simd_shuffle(data.index, lane)); +} + +template > +struct CumMinOp { + using pair_t = ValueIndexPair; + + static constexpr constant acc_t init_val = static_cast( + metal::is_floating_point_v ? metal::numeric_limits::infinity() + : metal::numeric_limits::max()); + + static pair_t get_init() { + return make_pair(init_val, 0); + } + + pair_t operator()(pair_t a, pair_t b) { + if (::metal::isnan(static_cast(a.value)) && + ::metal::isnan(static_cast(b.value))) { + return (a.index >= b.index) ? a : b; + } else if (::metal::isnan(static_cast(a.value))) { + return a; + } else if (::metal::isnan(static_cast(b.value))) { + return b; + } else if (a.value < b.value) { + return a; + } else if (a.value > b.value) { + return b; + } else { + return (a.index >= b.index) ? a : b; + } + } + + // For SIMD operations, we need to handle pairs differently + pair_t simd_scan(pair_t val) { + // For pairs, we need to implement scan manually since SIMD doesn't support + // pairs directly + pair_t init_val = get_init(); + for (int i = 1; i <= 16; i *= 2) { + pair_t shuffled = make_pair( + simd_shuffle_and_fill_up(val.value, init_val.value, i), + simd_shuffle_and_fill_up(val.index, init_val.index, i)); + val = operator()(val, shuffled); + } + return val; + } + + pair_t simd_exclusive_scan(pair_t val) { + val = simd_scan(val); + pair_t init_val = get_init(); + return simd_shuffle_and_fill_up_pair(val, init_val, 1); + } + + private: + pair_t simd_shuffle_pair(pair_t data, uint16_t delta) { + pair_t init_val = get_init(); + return make_pair( + simd_shuffle_and_fill_up(data.value, init_val.value, delta), + simd_shuffle_and_fill_up(data.index, init_val.index, delta)); + } + + pair_t simd_shuffle_and_fill_up_pair( + pair_t data, + pair_t filling, + uint16_t delta) { + return make_pair( + simd_shuffle_and_fill_up(data.value, filling.value, delta), + simd_shuffle_and_fill_up(data.index, filling.index, delta)); + } +}; + +template > +struct CumMaxOp { + using pair_t = ValueIndexPair; + + static constexpr constant acc_t init_val = static_cast( + metal::is_floating_point_v ? -metal::numeric_limits::infinity() + : metal::numeric_limits::lowest()); + + static pair_t get_init() { + return make_pair(init_val, 0); + } + + pair_t operator()(pair_t a, pair_t b) { + if (::metal::isnan(static_cast(a.value)) && + ::metal::isnan(static_cast(b.value))) { + return (a.index >= b.index) ? a : b; + } else if (::metal::isnan(static_cast(a.value))) { + return a; + } else if (::metal::isnan(static_cast(b.value))) { + return b; + } else if (a.value > b.value) { + return a; + } else if (a.value < b.value) { + return b; + } else { + return (a.index >= b.index) ? a : b; + } + } + + // For SIMD operations, we need to handle pairs differently + pair_t simd_scan(pair_t val) { + // For pairs, we need to implement scan manually since SIMD doesn't support + // pairs directly + pair_t init_val = get_init(); + for (int i = 1; i <= 16; i *= 2) { + pair_t shuffled = make_pair( + simd_shuffle_and_fill_up(val.value, init_val.value, i), + simd_shuffle_and_fill_up(val.index, init_val.index, i)); + val = operator()(val, shuffled); + } + return val; + } + + pair_t simd_exclusive_scan(pair_t val) { + val = simd_scan(val); + pair_t init_val = get_init(); + return simd_shuffle_and_fill_up_pair(val, init_val, 1); + } + + private: + pair_t simd_shuffle_pair(pair_t data, uint16_t delta) { + pair_t init_val = get_init(); + return make_pair( + simd_shuffle_and_fill_up(data.value, init_val.value, delta), + simd_shuffle_and_fill_up(data.index, init_val.index, delta)); + } + + pair_t simd_shuffle_and_fill_up_pair( + pair_t data, + pair_t filling, + uint16_t delta) { + return make_pair( + simd_shuffle_and_fill_up(data.value, filling.value, delta), + simd_shuffle_and_fill_up(data.index, filling.index, delta)); + } +}; + +template > +inline void load_unsafe(acc_t values[N_READS], const device T* input) { + for (int i = 0; i < N_READS; i++) { + values[i] = static_cast(input[i]); + } +} + +template > +inline void load_safe( + acc_t values[N_READS], + const device T* input, + int start, + int total, + acc_t init) { + for (int i = 0; i < N_READS; i++) { + values[i] = (start + i < total) ? static_cast(input[i]) : init; + } +} + +template > +inline void write_unsafe(acc_t values[N_READS], device T* out) { + for (int i = 0; i < N_READS; i++) { + out[i] = static_cast(values[i]); + } +} + +template > +inline void write_safe( + acc_t values[N_READS], + device T* out, + int start, + int total) { + for (int i = 0; i < N_READS; i++) { + if (start + i < total) { + out[i] = static_cast(values[i]); + } + } +} + +// Utility function for ceiling division +template +inline T ceildiv(T N, U M) { + return (N + M - 1) / M; +} + +// Inclusive scan along innermost dimension for contiguous tensors +template > +kernel void scan_innermost_dim( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + Op op; + + // Position the pointers + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out += offset; + + // Compute the number of simd_groups + uint simd_groups = lsize.x / simd_size; + + // Allocate memory + acc_t prefix = Op::init; + acc_t values[N_READS]; + threadgroup acc_t simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, + // N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, + // value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { + // Compute the block offset + uint offset = r * lsize.x * N_READS + lid.x * N_READS; + + // Read the values + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe(values, in + offset, offset, axis_size, Op::init); + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + acc_t prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + acc_t prev_simdgroup = + op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe(values, out + offset, offset, axis_size); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +// Inclusive scan along outer dimension for contiguous tensors +template > +kernel void scan_outer_dim( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + const constant size_t& stride_blocks [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(T); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; + Op op; + + threadgroup acc_t read_buffer[BM * BN_pad]; + acc_t values[n_scans]; + acc_t prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { + prefix[i] = Op::init; + } + + // Compute offsets + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = (lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + threadgroup acc_t* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup acc_t* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + + // Read into shared memory with type conversion + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + read_into[i] = static_cast(in[index_y * stride + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = static_cast(in[index_y * stride + i]); + } else { + read_into[i] = Op::init; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < n_scans; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = simd_shuffle(values[i], simd_size - 1); + } + + // Write to shared memory + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory with type conversion + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + out[index_y * stride + i] = static_cast(read_into[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = static_cast(read_into[i]); + } + } + } + } +} + +template > +kernel void scan_with_indices_innermost_dim( + const device T* in [[buffer(0)]], + device T* out_values [[buffer(1)]], + device int64_t* out_indices [[buffer(2)]], + const constant size_t& axis_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + Op op; + using pair_t = typename Op::pair_t; + + // Position the pointers + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out_values += offset; + out_indices += offset; + + // Compute the number of simd_groups + uint simd_groups = lsize.x / simd_size; + + // Allocate memory + pair_t prefix = op.get_init(); + pair_t values[N_READS]; + threadgroup pair_t simdgroup_sums[32]; + + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { + // Compute the block offset + uint offset_idx = r * lsize.x * N_READS + lid.x * N_READS; + + // Read the values as pairs + for (int i = 0; i < N_READS; i++) { + if ((offset_idx + i) < axis_size) { + values[i] = make_pair( + static_cast(in[offset_idx + i]), offset_idx + i); + } else { + values[i] = op.get_init(); + } + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + pair_t prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + pair_t prev_simdgroup = + op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + for (int i = 0; i < N_READS; i++) { + if ((offset_idx + i) < axis_size) { + out_values[offset_idx + i] = static_cast(values[i].value); + out_indices[offset_idx + i] = values[i].index; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template > +kernel void scan_with_indices_outer_dim( + const device T* in [[buffer(0)]], + device T* out_values [[buffer(1)]], + device int64_t* out_indices [[buffer(2)]], + const constant size_t& axis_size [[buffer(3)]], + const constant size_t& stride [[buffer(4)]], + const constant size_t& stride_blocks [[buffer(5)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(T); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; + Op op; + using pair_t = typename Op::pair_t; + + threadgroup pair_t read_buffer[BM * BN_pad]; + pair_t values[n_scans]; + pair_t prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { + prefix[i] = op.get_init(); + } + + // Compute offsets + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = (lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out_values += offset + global_index_x + read_offset_x; + out_indices += offset + global_index_x + read_offset_x; + threadgroup pair_t* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup pair_t* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + + // Read into shared memory as pairs + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + // For cummin/cummax, the index should represent the position along the + // scan axis + read_into[i] = make_pair( + static_cast(in[index_y * stride + i]), index_y); + } else { + read_into[i] = op.get_init(); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < n_scans; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = make_pair( + simd_shuffle(values[i].value, simd_size - 1), + simd_shuffle(values[i].index, simd_size - 1)); + } + + // Write to shared memory + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out_values[index_y * stride + i] = static_cast(read_into[i].value); + out_indices[index_y * stride + i] = read_into[i].index; + } + } + } +} + +#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE, NREADS) \ + template [[host_name(#OP_NAME "_innermost_" #DTYPE)]] [[kernel]] void \ + scan_innermost_dim, NREADS>( \ + const device DTYPE* in [[buffer(0)]], \ + device DTYPE* out [[buffer(1)]], \ + const constant size_t& axis_size [[buffer(2)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name(#OP_NAME "_outer_" #DTYPE)]] [[kernel]] void \ + scan_outer_dim, NREADS>( \ + const device DTYPE* in [[buffer(0)]], \ + device DTYPE* out [[buffer(1)]], \ + const constant size_t& axis_size [[buffer(2)]], \ + const constant size_t& stride [[buffer(3)]], \ + const constant size_t& stride_blocks [[buffer(4)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]) + +#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE, NREADS) \ + template [[host_name(#OP_NAME "_innermost_" #DTYPE)]] [[kernel]] void \ + scan_with_indices_innermost_dim, NREADS>( \ + const device DTYPE* in [[buffer(0)]], \ + device DTYPE* out_values [[buffer(1)]], \ + device int64_t* out_indices [[buffer(2)]], \ + const constant size_t& axis_size [[buffer(3)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name(#OP_NAME "_outer_" #DTYPE)]] [[kernel]] void \ + scan_with_indices_outer_dim, NREADS>( \ + const device DTYPE* in [[buffer(0)]], \ + device DTYPE* out_values [[buffer(1)]], \ + device int64_t* out_indices [[buffer(2)]], \ + const constant size_t& axis_size [[buffer(3)]], \ + const constant size_t& stride [[buffer(4)]], \ + const constant size_t& stride_blocks [[buffer(5)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]) + +// Simple scan operations +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float, 4); +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half, 4); +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, bfloat, 4); + +// Scan with indices operations for cummin/cummax +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bfloat, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long, 2); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool, 4); + +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bfloat, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long, 2); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4); + +#endif diff --git a/aten/src/ATen/native/mps/kernels/TriangularOps.metal b/aten/src/ATen/native/mps/kernels/TriangularOps.metal index aa1093ec34d43b..27ad506028488c 100644 --- a/aten/src/ATen/native/mps/kernels/TriangularOps.metal +++ b/aten/src/ATen/native/mps/kernels/TriangularOps.metal @@ -1,5 +1,119 @@ #include + using namespace metal; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +template +inline bool triul_mask(int row, int col, int k); +template <> +inline bool triul_mask(int row, int col, int k) { + return col - row >= k; +} +template <> +inline bool triul_mask(int row, int col, int k) { + return col - row <= k; +} + +template +inline IndexType compute_offs( + constant IndexType* strides, + constant uint* sizes, + uint3 pos, + int ndim) { + auto offs = pos.x * strides[0] + pos.y * strides[1]; + if (ndim < 4) { + return ndim == 3 ? offs + pos.z * strides[2] : offs; + } + auto idx = pos.z; + for (int i = 2; i < ndim; ++i) { + offs += strides[i] * (idx % sizes[i]); + idx /= sizes[i]; + } + return offs; +} + +template +kernel void triul_inplace( + device T* self, + constant IndexType* strides, + constant uint* sizes, + constant int2& k_ndim, + uint3 pos [[thread_position_in_grid]]) { + if (triul_mask(pos.y, pos.x, k_ndim.x)) { + return; + } + auto offs = compute_offs(strides, sizes, pos, k_ndim.y); + self[offs] = 0; +} + +template +kernel void triul( + device T* out, + device T* inp, + constant IndexType* out_strides, + constant IndexType* inp_strides, + constant uint* sizes, + constant int2& k_ndim, + uint3 pos [[thread_position_in_grid]]) { + auto out_offs = compute_offs(out_strides, sizes, pos, k_ndim.y); + if (!triul_mask(pos.y, pos.x, k_ndim.x)) { + out[out_offs] = 0; + return; + } + auto inp_offs = compute_offs(inp_strides, sizes, pos, k_ndim.y); + out[out_offs] = inp[inp_offs]; +} + +#define INSTANTIATE_TRIUL_KERNELS(DTYPE, IDX_TYPE) \ + template [[host_name("triu_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void \ + triul_inplace( \ + device DTYPE * self, \ + constant IDX_TYPE * strides, \ + constant uint * sizes, \ + constant int2 & k_ndim, \ + uint3 pos [[thread_position_in_grid]]); \ + template [[host_name("tril_inplace_" #IDX_TYPE "_" #DTYPE)]] kernel void \ + triul_inplace( \ + device DTYPE * self, \ + constant IDX_TYPE * strides, \ + constant uint * sizes, \ + constant int2 & k_ndim, \ + uint3 pos [[thread_position_in_grid]]); \ + template [[host_name("triu_" #IDX_TYPE "_" #DTYPE)]] kernel void \ + triul( \ + device DTYPE * out, \ + device DTYPE * inp, \ + constant IDX_TYPE * out_strides, \ + constant IDX_TYPE * inp_strides, \ + constant uint * sizes, \ + constant int2 & k_ndim, \ + uint3 pos [[thread_position_in_grid]]); \ + template [[host_name("tril_" #IDX_TYPE "_" #DTYPE)]] kernel void \ + triul( \ + device DTYPE * out, \ + device DTYPE * inp, \ + constant IDX_TYPE * out_strides, \ + constant IDX_TYPE * inp_strides, \ + constant uint * sizes, \ + constant int2 & k_ndim, \ + uint3 pos [[thread_position_in_grid]]) + +INSTANTIATE_TRIUL_KERNELS(float, int); +INSTANTIATE_TRIUL_KERNELS(half, int); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_TRIUL_KERNELS(bfloat, int); +#endif + +INSTANTIATE_TRIUL_KERNELS(float2, int); +INSTANTIATE_TRIUL_KERNELS(half2, int); + +INSTANTIATE_TRIUL_KERNELS(long, int); +INSTANTIATE_TRIUL_KERNELS(int, int); +INSTANTIATE_TRIUL_KERNELS(short, int); +INSTANTIATE_TRIUL_KERNELS(char, int); +INSTANTIATE_TRIUL_KERNELS(uchar, int); +INSTANTIATE_TRIUL_KERNELS(bool, int); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // To find the max integer that does not exceed the root of an int64_t variable, diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal index f7c83181a7eeec..37a61397467f10 100644 --- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal @@ -1,23 +1,82 @@ #include #include +#include #include using namespace metal; using namespace c10::metal; +// Implement exp wrapper for both real and complex types +template , bool> = true> +inline T exp_(const T x) { + return T(precise::exp(x)); +} + +template , bool> = true> +inline T exp_(const T x) { + return T( + precise::exp(x.x) * precise::cos(x.y), + precise::exp(x.x) * precise::sin(x.y)); +} + struct exp_functor { - template - inline enable_if_t, T> operator()(const T x) { - return T(precise::exp(x)); + template , bool> = true> + inline T operator()(const T x) { + return exp_(x); } - template - inline enable_if_t, float> operator()(const T x) { - return precise::exp(static_cast(x)); + template , bool> = true> + inline float operator()(const T x) { + return exp_(static_cast(x)); } - template - inline enable_if_t, T> operator()(const T x) { - return T( - precise::exp(x.x) * precise::cos(x.y), - precise::exp(x.x) * precise::sin(x.y)); +}; + +struct expm1_functor { + template , bool> = true> + inline T operator()(const T x) { + if (::metal::fabs(x) < 1e-5f) { + return static_cast(c10::metal::expm1f(static_cast(x))); + } else { + return static_cast(exp_(static_cast(x)) - 1.0f); + } + } + template , bool> = true> + inline float operator()(const T x) { + return exp_(static_cast(x)) - 1; + } + template , bool> = true> + inline T operator()(const T x) { + if (::precise::sqrt(dot(x, x)) < 1e-2) { + return T( + c10::metal::expm1f(x.x + ::precise::log(precise::cos(x.y))), + exp_(x.x) * precise::sin(x.y)); + } else { + return exp_(x) - T(1.0f, 0.0f); + } + } +}; + +struct sigmoid_functor { + template , bool> = true> + inline T operator()(const T x) { + return T(1.0f / (1.0f + exp_(-static_cast(x)))); + } + template , bool> = true> + inline T operator()(const T x) { + return c10::metal::div(T(1, 0), (T(1, 0) + exp_(-x))); + } + template , bool> = true> + inline float operator()(const T x) { + return 1.0f / (1.0f + exp_(-static_cast(x))); + } +}; + +struct abs_functor { + template , bool> = true> + inline T operator()(const T x) { + return static_cast(precise::abs(x)); + } + template , bool> = true> + inline T operator()(const T x) { + return T(::precise::sqrt(dot(x, x)), 0); } }; @@ -79,6 +138,50 @@ struct tan_functor { } }; +struct sinh_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(precise::sinh(x)); + } + template + inline enable_if_t, float> operator()(const T x) { + return precise::sinh(static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // sinh(x) = (e^x - e^(-x)) / 2 + auto exp_1 = + T(precise::exp(x.x) * precise::cos(x.y), + precise::exp(x.x) * precise::sin(x.y)); + auto exp_2 = + T(precise::exp(-x.x) * precise::cos(-x.y), + precise::exp(-x.x) * precise::sin(-x.y)); + return div(exp_1 - exp_2, T(2, 0)); + } +}; + +struct cosh_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(precise::cosh(x)); + } + template + inline enable_if_t, float> operator()(const T x) { + return precise::cosh(static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // cosh(x+iy)=(e^x + e^(-x)) / 2 + auto exp_1 = + T(precise::exp(x.x) * precise::cos(x.y), + precise::exp(x.x) * precise::sin(x.y)); + auto exp_2 = + T(precise::exp(-x.x) * precise::cos(-x.y), + precise::exp(-x.x) * precise::sin(-x.y)); + return div(exp_1 + exp_2, T(2, 0)); + } +}; + struct tanh_functor { template inline enable_if_t, T> operator()(const T x) { @@ -97,6 +200,119 @@ struct tanh_functor { } }; +struct asin_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(precise::asin(x)); + } + template + inline enable_if_t, float> operator()(const T x) { + return precise::asin(static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // asin(z) = atan(z/sqrt(1-z^2)) if z != ±1 + if (x.x == 1 && x.y == 0) + return T(M_PI_F / 2, 0); + else if (x.x == -1 && x.y == 0) + return T(M_PI_F / -2, 0); + auto sqrt_val = T(1, 0) - c10::metal::mul(x, x); + // calculate sqrt + // modulus + auto m = precise::sqrt(sqrt_val.x * sqrt_val.x + sqrt_val.y * sqrt_val.y); + // real part: sqrt((m + a)/2) + auto real_part = precise::sqrt((m + sqrt_val.x) * .5); + // imaginary part: sign(b) * sqrt((m - a)/2) + auto imag_part = copysign( + static_cast(precise::sqrt((m - sqrt_val.x) * .5)), + sqrt_val.y); + auto atan_val = div(x, T(real_part, imag_part)); + // calculate atan (see atan_functor) + auto coef = div(T(1, 0), T(0, 2)); + auto log_arg = + div(T(-1 * atan_val.x, 1 - atan_val.y), T(atan_val.x, 1 + atan_val.y)); + // Calculate log using method from log_functor + auto magnitude = + ::precise::sqrt(log_arg.x * log_arg.x + log_arg.y * log_arg.y); + auto real = ::precise::log(magnitude); + auto imag = (log_arg.x == 0 && log_arg.y == 0) + ? 0 + : ::precise::atan2(log_arg.y, log_arg.x); + // return coefficient * log value + return c10::metal::mul(coef, T(real, imag)); + } +}; + +struct acos_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(precise::acos(x)); + } + template + inline enable_if_t, float> operator()(const T x) { + return precise::acos(static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // acos(z) = pi/2 - asin(z) if z != ±1 + // calculate asin + if (x.x == 1 && x.y == 0) + return T(M_PI_F, 0); + else if (x.x == -1 && x.y == 0) + return T(-M_PI_F, 0); + auto sqrt_val = T(1, 0) - c10::metal::mul(x, x); + // calculate sqrt + // modulus + auto m = precise::sqrt(sqrt_val.x * sqrt_val.x + sqrt_val.y * sqrt_val.y); + // real part: sqrt((m + a)/2) + auto real_part = precise::sqrt((m + sqrt_val.x) * .5); + // imaginary part: sign(b) * sqrt((m - a)/2) + auto imag_part = copysign( + static_cast(precise::sqrt((m - sqrt_val.x) * .5)), + sqrt_val.y); + auto atan_val = div(x, T(real_part, imag_part)); + // calculate atan (see atan_functor) + auto coef = div(T(1, 0), T(0, 2)); + auto log_arg = + div(T(-1 * atan_val.x, 1 - atan_val.y), T(atan_val.x, 1 + atan_val.y)); + // Calculate log using method from log_functor + auto magnitude = + ::precise::sqrt(log_arg.x * log_arg.x + log_arg.y * log_arg.y); + auto real = ::precise::log(magnitude); + auto imag = (log_arg.x == 0 && log_arg.y == 0) + ? 0 + : ::precise::atan2(log_arg.y, log_arg.x); + // return coefficient * log value + return T(M_PI_F / 2, 0) - c10::metal::mul(coef, T(real, imag)); + } +}; + +struct atan_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(precise::atan(x)); + } + template + inline enable_if_t, float> operator()(const T x) { + return precise::atan(static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // atan(z) = (1/2i)ln((i-z)/(i+z)) + auto coef = div(T(1, 0), T(0, 2)); + auto log_arg = div(T(-1 * x.x, 1 - x.y), T(x.x, 1 + x.y)); + // Calculate log using method from log_functor + auto magnitude = + ::precise::sqrt(log_arg.x * log_arg.x + log_arg.y * log_arg.y); + auto real = ::precise::log(magnitude); + auto imag = (log_arg.x == 0 && log_arg.y == 0) + ? 0 + : ::precise::atan2(log_arg.y, log_arg.x); + // return coefficient * log value + return c10::metal::mul(coef, T(real, imag)); + } +}; + // Bool specialization is need to workaround compiler crashes on MacOS-13 // Otherwise attempts to invoke will fail to create state object with error // Error Domain=AGXMetal13_3 Code=3 "Compiler encountered an internal error" @@ -145,6 +361,28 @@ struct log10_functor { } }; +struct log1p_functor { + template + inline enable_if_t, T> operator()(const T x) { + return T(::c10::metal::log1p(float(x))); + } + template + inline enable_if_t, float> operator()(const T x) { + return ::precise::log(1.0f + static_cast(x)); + } + template + inline enable_if_t, T> operator()(const T x) { + // TODO: Implement proper log1p algoirthm + auto magnitude = ::precise::sqrt((1.0f + x.x) * (1.0f + x.x) + x.y * x.y); + auto real = ::precise::log(magnitude); + auto imag = (x.x == -1 && x.y == 0) ? 0 : ::precise::atan2(x.y, 1.0 + x.x); + return T(real, imag); + } + inline float operator()(const bool x) { + return x ? ::precise::log(2.0) : 0; + } +}; + struct log2_functor { template inline enable_if_t, T> operator()(const T x) { @@ -252,6 +490,21 @@ struct bitwise_not_functor { } }; +template +float erfc(T x) { + return 1.0 - erf(x); +} + +struct round_decimals_functor { + template + inline T operator()(const T x, const long ndigits) { + return static_cast( + rint(exp10(float(ndigits)) * x) * exp10(float(-ndigits))); + } +}; + +DEFINE_UNARY_FLOATING_FUNCTOR(erf); +DEFINE_UNARY_FLOATING_FUNCTOR(erfc); DEFINE_UNARY_FLOATING_FUNCTOR(erfinv); DEFINE_UNARY_FLOATING_FUNCTOR(sinc); @@ -270,24 +523,43 @@ REGISTER_UNARY_OP(bitwise_not, char, char); REGISTER_UNARY_OP(bitwise_not, uchar, uchar); REGISTER_UNARY_OP(bitwise_not, bool, bool); +REGISTER_UNARY_OP(abs, int, int); +REGISTER_UNARY_OP(abs, long, long); +REGISTER_UNARY_OP(abs, short, short); +REGISTER_UNARY_OP(abs, char, char); +REGISTER_UNARY_OP(abs, uchar, uchar); +REGISTER_UNARY_OP(abs, float, float); +REGISTER_UNARY_OP(abs, half, half); + #define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \ + REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(exp, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(expm1, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(sigmoid, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(exp2, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(log, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(log10, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(log1p, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(log2, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(sinc, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(sqrt, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(rsqrt, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(sinh, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(cosh, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(tanh, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(sin, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(cos, DTYPE1, DTYPE0); \ - REGISTER_UNARY_OP(tan, DTYPE1, DTYPE0) + REGISTER_UNARY_OP(tan, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(asin, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \ + REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0) #if __METAL_VERSION__ >= 310 INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat); REGISTER_UNARY_OP(neg, bfloat, bfloat); +REGISTER_UNARY_OP(abs, bfloat, bfloat); #endif INSTANTIATE_UNARY_KERNELS2(half, half); INSTANTIATE_UNARY_KERNELS2(float, float); @@ -298,75 +570,36 @@ INSTANTIATE_UNARY_KERNELS2(float, short); INSTANTIATE_UNARY_KERNELS2(float, int); INSTANTIATE_UNARY_KERNELS2(float, long); -#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \ - REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(exp2, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(log, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(log10, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(log2, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(tanh, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(sqrt, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(rsqrt, DTYPE##2, DTYPE##2); \ - \ - REGISTER_UNARY_OP(sinc, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(sin, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(cos, DTYPE##2, DTYPE##2); \ - REGISTER_UNARY_OP(tan, DTYPE##2, DTYPE##2) +#define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE) \ + REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(sigmoid, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(abs, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(exp2, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(log, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(log10, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(log1p, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(log2, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(sinh, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(cosh, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(tanh, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(sqrt, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(rsqrt, DTYPE##2, DTYPE##2); \ + \ + REGISTER_UNARY_OP(sinc, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(sin, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(cos, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(tan, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(asin, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(acos, DTYPE##2, DTYPE##2); \ + REGISTER_UNARY_OP(atan, DTYPE##2, DTYPE##2) INSTANTIATE_UNARY_KERNELS_VEC2(half); INSTANTIATE_UNARY_KERNELS_VEC2(float); -template -kernel void round_decimals_dense( - device T* output [[buffer(0)]], - constant T* input [[buffer(1)]], - constant long& ndigits [[buffer(2)]], - uint index [[thread_position_in_grid]]) { - output[index] = static_cast( - rint(exp10(float(ndigits)) * input[index]) * exp10(float(-ndigits))); -} - -template -kernel void round_decimals_strided( - device T* output [[buffer(0)]], - constant T* input [[buffer(1)]], - constant long* sizes [[buffer(2)]], - constant long* input_strides [[buffer(3)]], - constant long* output_strides [[buffer(4)]], - constant uint& ndim [[buffer(5)]], - constant long& ndigits [[buffer(6)]], - uint index [[thread_position_in_grid]]) { - int pos[max_ndim]; - pos_from_thread_index(int(index), pos, sizes, ndim); - const auto input_offs = offset_from_coord(pos, input_strides, ndim); - const auto output_offs = offset_from_coord(pos, output_strides, ndim); - output[output_offs] = static_cast( - rint(exp10(float(ndigits)) * input[input_offs]) * exp10(float(-ndigits))); -} - -#define INSTANTIATE_ROUND_DECIMALS(DTYPE) \ - template \ - [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE)]] kernel void \ - round_decimals_dense( \ - device DTYPE* output [[buffer(0)]], \ - constant DTYPE* input [[buffer(1)]], \ - constant long& ndigits [[buffer(2)]], \ - uint index [[thread_position_in_grid]]); \ - template \ - [[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE)]] kernel void \ - round_decimals_strided( \ - device DTYPE* output [[buffer(0)]], \ - constant DTYPE* input [[buffer(1)]], \ - constant long* sizes, \ - constant long* input_strides, \ - constant long* output_strides, \ - constant uint& ndim, \ - constant long& ndigits [[buffer(6)]], \ - uint index) - -INSTANTIATE_ROUND_DECIMALS(float); -INSTANTIATE_ROUND_DECIMALS(half); +REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float); +REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half); #if __METAL_VERSION__ >= 310 -INSTANTIATE_ROUND_DECIMALS(bfloat); +REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat); #endif diff --git a/aten/src/ATen/native/mps/kernels/UpSample.h b/aten/src/ATen/native/mps/kernels/UpSample.h new file mode 100644 index 00000000000000..c2c3c1d5d4587a --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/UpSample.h @@ -0,0 +1,20 @@ +#pragma once + +#ifndef __METAL__ +#include +using ulong = unsigned long; +#define _ARRAY_NS std +#else +#include +#define _ARRAY_NS metal +#endif + +template +struct UpsampleParams { + _ARRAY_NS::array input_strides; + _ARRAY_NS::array input_sizes; + _ARRAY_NS::array output_strides; + _ARRAY_NS::array output_sizes; + _ARRAY_NS::array scales; + bool align_corners; +}; diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index b32faf59a51923..7181f8b13e2b08 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -1,3 +1,4 @@ +#include #include #include @@ -61,6 +62,24 @@ accscalar_t area_pixel_compute_source_index( } } +template +scalar_t upsample_get_value_bounded( + constant scalar_t* data, + uint3 dim, + array strides, + uint n, + uint c, + uint z, + uint y, + uint x) { + auto access_z = max(min(z, dim.z - 1), 0U); + auto access_y = max(min(y, dim.y - 1), 0U); + auto access_x = max(min(x, dim.x - 1), 0U); + return data + [n * strides[0] + c * strides[1] + access_z * strides[2] + + access_y * strides[3] + access_x * strides[4]]; +} + template scalar_t upsample_get_value_bounded( constant scalar_t* data, @@ -108,6 +127,27 @@ void upsample_increment_value_bounded( static_cast(value)); } +template +void upsample_increment_value_bounded( + device AtomicType_t* data, + uint3 dim, + array strides, + uint n, + uint c, + uint z, + uint y, + uint x, + float value) { + auto access_z = max(min(z, dim.z - 1), 0U); + auto access_y = max(min(y, dim.y - 1), 0U); + auto access_x = max(min(x, dim.x - 1), 0U); + AtomicType::atomic_add( + data, + n * strides[0] + c * strides[1] + access_z * strides[2] + + access_y * strides[3] + access_x * strides[4], + static_cast(value)); +} + template struct linear_return_type { typedef float type; @@ -124,6 +164,288 @@ inline linear_return_t linear_interp(T v0, T v1, float x) { return x * v1 + (1 - x) * v0; } +/* 3D interpolation kernels and helper functions */ +inline uint3 coords_from_threadidx( + constant UpsampleParams<5>& params, + uint thread_index) { + const auto size_x = static_cast(params.output_sizes[4]); + const auto size_xy = static_cast(params.output_sizes[3]) * size_x; + auto output_xy = thread_index % size_xy; + return uint3(output_xy % size_x, output_xy / size_x, thread_index / size_xy); +} + +inline float3 coords_to_real_coords( + constant UpsampleParams<5>& params, + uint3 output, + bool align_corners) { + auto real_x = area_pixel_compute_source_index( + params.scales[0], output.x, align_corners, /*cubic=*/false); + auto real_y = area_pixel_compute_source_index( + params.scales[1], output.y, align_corners, /*cubic=*/false); + auto real_z = area_pixel_compute_source_index( + params.scales[2], output.z, align_corners, /*cubic=*/false); + return float3(real_x, real_y, real_z); +} + +template +kernel void upsample_nearest_exact_3d( + constant T* inputData [[buffer(0)]], + device T* outputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, false); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto res = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z + .5, + real.y + .5, + real.x + .5); + outputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]] = static_cast(res); + } + } +} + +template +kernel void upsample_nearest_exact_3d_backward( + device AtomicType_t* gradInputData [[buffer(0)]], + constant T* gradOutputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, false); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto res = gradOutputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]]; + upsample_increment_value_bounded( + gradInputData, + input_sizes, + params.input_strides, + n, + c, + real.z + .5, + real.y + .5, + real.x + .5, + res); + } + } +} + +template +kernel void upsample_nearest_3d( + constant T* inputData [[buffer(0)]], + device T* outputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, true); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto res = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y, + real.x); + outputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]] = static_cast(res); + } + } +} + +template +kernel void upsample_nearest_3d_backward( + device AtomicType_t* gradInputData [[buffer(0)]], + constant T* gradOutputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, true); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto res = gradOutputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]]; + upsample_increment_value_bounded( + gradInputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y, + real.x, + res); + } + } +} + +template +kernel void upsample_trilinear( + constant T* inputData [[buffer(0)]], + device T* outputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, params.align_corners); + auto t = fract(real); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto i000 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y, + real.x); + auto i001 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y, + real.x + 1); + auto i010 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y + 1, + real.x); + auto i011 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z, + real.y + 1, + real.x + 1); + auto i100 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z + 1, + real.y, + real.x); + auto i101 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z + 1, + real.y, + real.x + 1); + auto i110 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z + 1, + real.y + 1, + real.x); + auto i111 = upsample_get_value_bounded( + inputData, + input_sizes, + params.input_strides, + n, + c, + real.z + 1, + real.y + 1, + real.x + 1); + auto i00_l = linear_interp(i000, i001, t.x); + auto i01_l = linear_interp(i010, i011, t.x); + auto i10_l = linear_interp(i100, i101, t.x); + auto i11_l = linear_interp(i110, i111, t.x); + auto i0_l = linear_interp(i00_l, i01_l, t.y); + auto i1_l = linear_interp(i10_l, i11_l, t.y); + auto res = linear_interp(i0_l, i1_l, t.z); + outputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]] = static_cast(res); + } + } +} + +template +kernel void upsample_trilinear_backward( + device AtomicType_t* gradInputData [[buffer(0)]], + constant T* gradOutputData [[buffer(1)]], + constant UpsampleParams<5>& params [[buffer(2)]], + uint thread_index [[thread_position_in_grid]]) { + const auto input_sizes = uint3( + params.input_sizes[4], params.input_sizes[3], params.input_sizes[2]); + const auto output = coords_from_threadidx(params, thread_index); + const auto real = coords_to_real_coords(params, output, params.align_corners); + auto t = fract(real); + for (uint n = 0; n < params.output_sizes[0]; n++) { + for (uint c = 0; c < params.output_sizes[1]; c++) { + auto res = gradOutputData + [n * params.output_strides[0] + c * params.output_strides[1] + + output.z * params.output_strides[2] + + output.y * params.output_strides[3] + + output.x * params.output_strides[4]]; + for (int d = 0; d < 8; d++) { + const auto w = (d & 1 ? t.x : 1.0 - t.x) * (d & 2 ? t.y : 1.0 - t.y) * + (d & 4 ? t.z : 1.0 - t.z); + upsample_increment_value_bounded( + gradInputData, + input_sizes, + params.input_strides, + n, + c, + real.z + ((d & 4) >> 2), + real.y + ((d & 2) >> 1), + real.x + (d & 1), + res * w); + } + } + } +} + // See Note [ Weights computation for uint8_t and multiplication trick ] // Essentially fall back to fixed floating point arithmetic during uint8 // interpolation, which is not necesserily more accurate (see example below), @@ -475,15 +797,59 @@ kernel void upsample_bicubic2d_backward( constant bool& align_corners [[buffer(7)]], \ uint thread_index [[thread_position_in_grid]]) +#define INSTANTIATE_UPSAMPLE_3D(DTYPE) \ + template [[host_name("upsample_nearest_3d_" #DTYPE)]] kernel void \ + upsample_nearest_3d( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]); \ + template [[host_name("upsample_nearest_exact_3d_" #DTYPE)]] kernel void \ + upsample_nearest_exact_3d( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]); \ + template [[host_name("upsample_trilinear_" #DTYPE)]] kernel void \ + upsample_trilinear( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]) + +#define INSTANTIATE_UPSAMPLE_3D_BACKWARD(DTYPE) \ + template [[host_name("upsample_nearest_3d_backward_" #DTYPE)]] kernel void \ + upsample_nearest_3d_backward( \ + device AtomicType_t * gradInputData [[buffer(0)]], \ + constant DTYPE * gradOutputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]); \ + template \ + [[host_name("upsample_nearest_exact_3d_backward_" #DTYPE)]] kernel void \ + upsample_nearest_exact_3d_backward( \ + device AtomicType_t * gradInputData [[buffer(0)]], \ + constant DTYPE * gradOutputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]); \ + template [[host_name("upsample_trilinear_backward_" #DTYPE)]] kernel void \ + upsample_trilinear_backward( \ + device AtomicType_t * gradInputData [[buffer(0)]], \ + constant DTYPE * gradOutputData [[buffer(1)]], \ + constant UpsampleParams<5> & params [[buffer(2)]], \ + uint thread_index [[thread_position_in_grid]]); + #define INSTANTIATE_UPSAMPLE_ALL(DTYPE) \ INSTANTIATE_UPSAMPLE_2D(bicubic2d, DTYPE); \ INSTANTIATE_UPSAMPLE_2D_AA(bicubic2d_aa, BicubicFunctor, DTYPE); \ INSTANTIATE_UPSAMPLE_2D_BACKWARD(bicubic2d, DTYPE); \ INSTANTIATE_UPSAMPLE_2D(bilinear2d, DTYPE); \ INSTANTIATE_UPSAMPLE_2D_AA(bilinear2d_aa, BilinearFunctor, DTYPE); \ - INSTANTIATE_UPSAMPLE_LINEAR(DTYPE); + INSTANTIATE_UPSAMPLE_LINEAR(DTYPE); \ + INSTANTIATE_UPSAMPLE_3D_BACKWARD(DTYPE); \ + INSTANTIATE_UPSAMPLE_3D(DTYPE) INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar); +INSTANTIATE_UPSAMPLE_3D(uchar); INSTANTIATE_UPSAMPLE_ALL(float); INSTANTIATE_UPSAMPLE_ALL(half); #if __METAL_VERSION__ >= 310 diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index c526366ae99c89..dec200d7e5bc91 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -17,13 +17,7 @@ #include #include #include -#include -#include -#include -#include #include -#include -#include #include #include #include @@ -34,8 +28,6 @@ #include #include #include -#include -#include #include #include #include @@ -125,110 +117,6 @@ Tensor relu_mps(const Tensor& self) { return output; } -TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_slope, const Tensor& output) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - TORCH_CHECK(output.is_mps()); - - if (self.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - bool executeGatherOp = - !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); - Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); - - @autoreleasepool { - std::string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to()); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - MPSGraphTensor* negSlopeMulXTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:negSlopeTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph maximumWithPrimaryTensor:negSlopeMulXTensor - secondaryTensor:inputTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = - Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - if (executeGatherOp) { - output.copy_(output_); - } -} - -TORCH_IMPL_FUNC(leaky_relu_backward_out_mps) -(const Tensor& grad_output, - const Tensor& self, - const Scalar& negative_slope, - bool self_is_result, - const Tensor& output) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - TORCH_CHECK(output.is_mps()); - - if (self.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - Tensor output_ = at::empty_like(self, self.suggest_memory_format()); - - @autoreleasepool { - std::string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + - std::to_string(negative_slope.to()); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - - MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to() - shape:@[ @1 ] - dataType:getMPSScalarType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSScalarType(self)]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* gradientsMulNegSlopeTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor - secondaryTensor:negSlopeTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:gradOutputTensor - falsePredicateTensor:gradientsMulNegSlopeTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - output.copy_(output_); -} - TORCH_IMPL_FUNC(log_softmax_mps_out) (const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { using namespace mps; @@ -1381,150 +1269,6 @@ Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) { } } -TORCH_IMPL_FUNC(softshrink_out_mps) -(const Tensor& self, const Scalar& lambd, const Tensor& result) { - using namespace mps; - TORCH_CHECK(self.is_mps()); - - if (result.numel() == 0) - return; - - MPSScalar lambd_scalar = getMPSScalar(lambd, self.scalar_type()); - - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - MPSGraphTensor* lambdTensor_ = nil; - }; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "softshrink_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType); - - MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; - MPSGraphTensor* positiveLambdPredicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:lambdTensor - name:nil]; - MPSGraphTensor* negativeLambdPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:negativeLambdTensor - name:nil]; - MPSGraphTensor* outputTensor = - [mpsGraph selectWithPredicateTensor:positiveLambdPredicateTensor - truePredicateTensor:[mpsGraph subtractionWithPrimaryTensor:inputTensor - secondaryTensor:lambdTensor - name:nil] - falsePredicateTensor:zeroTensor - name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor:negativeLambdPredicateTensor - truePredicateTensor:[mpsGraph additionWithPrimaryTensor:inputTensor - secondaryTensor:lambdTensor - name:nil] - falsePredicateTensor:outputTensor - name:nil]; - MPSGraphTensor* isNanTensor = [mpsGraph isNaNWithTensor:inputTensor name:nil]; - - outputTensor = [mpsGraph selectWithPredicateTensor:isNanTensor - truePredicateTensor:inputTensor - falsePredicateTensor:outputTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - newCachedGraph->lambdTensor_ = lambdTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->lambdTensor_ : getMPSGraphTensorFromScalar(stream, lambd_scalar), - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } -} - -static void shrink_backward_out_mps(const Tensor& grad_output, - const Tensor& self, - const Scalar& lambd, - const Tensor& grad_input, - std::string op_name) { - using namespace mps; - TORCH_CHECK(self.is_mps()); - - if (grad_input.numel() == 0) - return; - - MPSScalar lambd_scalar = getMPSScalar(lambd, self.scalar_type()); - - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* gradOutputTensor_ = nil; - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* gradInputTensor_ = nil; - MPSGraphTensor* lambdTensor_ = nil; - }; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = op_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType); - - MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; - MPSGraphTensor* positiveLambdPredicateTensor = [mpsGraph greaterThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:lambdTensor - name:nil]; - MPSGraphTensor* negativeLambdPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:negativeLambdTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph selectWithPredicateTensor:positiveLambdPredicateTensor - truePredicateTensor:gradOutputTensor - falsePredicateTensor:zeroTensor - name:nil]; - gradInputTensor = [mpsGraph selectWithPredicateTensor:negativeLambdPredicateTensor - truePredicateTensor:gradOutputTensor - falsePredicateTensor:gradInputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - newCachedGraph->lambdTensor_ = lambdTensor; - }); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->lambdTensor_ : getMPSGraphTensorFromScalar(stream, lambd_scalar), - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder); - return; - } -} - -TORCH_IMPL_FUNC(softshrink_backward_out_mps) -(const Tensor& grad_output, const Tensor& self, const Scalar& lambd, const Tensor& grad_input) { - return shrink_backward_out_mps(grad_output, self, lambd, grad_input, "softshrink_backward_out_mps"); -} - Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { using namespace mps; @@ -1752,105 +1496,6 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { } } -TORCH_IMPL_FUNC(hardsigmoid_out_mps)(const Tensor& self, const Tensor& result) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - TORCH_CHECK(self.is_mps()); - - // Empty output - if (result.numel() == 0) - return; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self}); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* inputPlusThreeTensor = [mpsGraph additionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* outputTensor = [mpsGraph clampWithTensor:inputPlusThreeTensor - minValueTensor:zeroTensor - maxValueTensor:sixTensor - name:nil]; - outputTensor = [mpsGraph divisionWithPrimaryTensor:outputTensor secondaryTensor:sixTensor name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } -} - -TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps) -(const Tensor& grad_output, const Tensor& self, const Tensor& grad_input) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - TORCH_CHECK(self.is_mps()); - - // Empty output - if (grad_input.numel() == 0) - return; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self}); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* highTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:-3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* oneSixTensor = [mpsGraph constantWithScalar:1.0 / 6.0 - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - MPSGraphTensor* inputLessThanHighPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:highTensor - name:nil]; - MPSGraphTensor* inputGreaterThanLowPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:lowTensor - name:nil]; - MPSGraphTensor* inIntervalTensor = [mpsGraph logicalANDWithPrimaryTensor:inputLessThanHighPredicateTensor - secondaryTensor:inputGreaterThanLowPredicateTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor - secondaryTensor:oneSixTensor - name:nil]; - - outputTensor = [mpsGraph selectWithPredicateTensor:inIntervalTensor - truePredicateTensor:outputTensor - falsePredicateTensor:zeroTensor - name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->gradInputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(selfPlaceholder, gradOutputPlaceholder); - auto results = dictionaryFromPlaceholders(gradInputPlaceholder); - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } -} - // ------------------------------------------------- // Hardtanh backward @@ -1937,176 +1582,4 @@ Tensor hardtanh_backward_mps(const Tensor& grad_output, const Tensor& self, cons return grad_input; } -Tensor& hardswish_out_mps(const Tensor& self, Tensor& output) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - if (output.numel() == 0) { - return output; - } - - auto resultMemFormat = output.suggest_memory_format(); - bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && output.is_contiguous(resultMemFormat)); - Tensor out; - if (executeGatherOp && !output.is_contiguous(MemoryFormat::Contiguous)) { - out = at::empty_like(output, MemoryFormat::Contiguous); - } - - MPSStream* stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "hardswish_out_mps" + getTensorsStringKey({self}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* negativeThreeTensor = [mpsGraph constantWithScalar:-3.0f - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph lessThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:negativeThreeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* inputPlusThreeTensor = [mpsGraph additionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* inputDivSixTensor = [mpsGraph divisionWithPrimaryTensor:inputPlusThreeTensor - secondaryTensor:sixTensor - name:nil]; - - MPSGraphTensor* weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:inputDivSixTensor - name:nil]; - - MPSGraphTensor* tempTensor = [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor - truePredicateTensor:weightedTensor - falsePredicateTensor:inputTensor - name:nil]; - - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:tempTensor - name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = - Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : output, nil, false); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - auto results = dictionaryFromPlaceholders(outputPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - if (out.has_storage()) { - output.copy_(out); - } - } - return output; -} - -Tensor hardswish_mps(const Tensor& self) { - using namespace mps; - Tensor output = at::empty_like(self, self.suggest_memory_format()); - - return hardswish_out_mps(self, output); -} - -Tensor& hardswish_mps_(Tensor& self) { - using namespace mps; - Tensor& output = self; - - return hardswish_out_mps(self, output); -} - -Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - - Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); - if (grad_input.numel() == 0) { - return grad_input; - } - - @autoreleasepool { - std::string key = "hardswish_backward_mps" + getTensorsStringKey({self}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* negativeThreeTensor = [mpsGraph constantWithScalar:-3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:0.5f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* tempTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* weightedTensor = [mpsGraph additionWithPrimaryTensor:tempTensor - secondaryTensor:halfTensor - name:nil]; - - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph lessThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:negativeThreeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxGradTensor = [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor - truePredicateTensor:weightedTensor - falsePredicateTensor:unitTensor - name:nil]; - - MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:lessThanMaxGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder); - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, gradInputPlaceholder); - } - return grad_input; -} } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm new file mode 100644 index 00000000000000..cec8bfa2312e47 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm @@ -0,0 +1,62 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include + +namespace at::native { + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + +static void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) { + lib.exec_unary_kernel(iter, "hardshrink", lambda); +} + +static void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) { + lib.exec_unary_kernel(iter, "softshrink", lambda); +} + +static void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) { + lib.exec_binary_kernel(iter, "shrink_backward", lambda); +} + +static void hardsigmoid_kernel(TensorIteratorBase& iter) { + lib.exec_unary_kernel(iter, "hardsigmoid"); +} + +static void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "hardsigmoid_backward"); +} + +static void hardswish_kernel(at::TensorIterator& iter) { + lib.exec_unary_kernel(iter, "hardswish"); +} + +static void hardswish_backward_kernel(at::TensorIterator& iter) { + lib.exec_binary_kernel(iter, "hardswish_backward"); +} + +static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { + lib.exec_unary_kernel(iter, "leaky_relu", negative_slope); +} + +static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { + lib.exec_binary_kernel(iter, "leaky_relu_backward", negative_slope); +} + +REGISTER_DISPATCH(hardshrink_stub, hardshrink_kernel); +REGISTER_DISPATCH(softshrink_stub, softshrink_kernel); +REGISTER_DISPATCH(shrink_backward_stub, shrink_backward_kernel); +REGISTER_DISPATCH(hardsigmoid_stub, hardsigmoid_kernel); +REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel); +REGISTER_DISPATCH(hardswish_stub, hardswish_kernel); +REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel); +REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel); +REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 0cbdf7132c70c7..806eeb82e1d17d 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -119,6 +119,30 @@ static void chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "chebyshev_polynomial_w"); } +static void shifted_chebyshev_polynomial_t_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "shifted_chebyshev_polynomial_t_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "shifted_chebyshev_polynomial_t"); +} + +static void shifted_chebyshev_polynomial_u_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "shifted_chebyshev_polynomial_u_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "shifted_chebyshev_polynomial_u"); +} + +static void shifted_chebyshev_polynomial_v_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "shifted_chebyshev_polynomial_v_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "shifted_chebyshev_polynomial_v"); +} + +static void shifted_chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "shifted_chebyshev_polynomial_w_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "shifted_chebyshev_polynomial_w"); +} + static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "hermite_polynomial_h_mps not implemented for non-floating types"); @@ -177,6 +201,10 @@ static void fmod_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel) +REGISTER_DISPATCH(shifted_chebyshev_polynomial_t_stub, &shifted_chebyshev_polynomial_t_mps_kernel) +REGISTER_DISPATCH(shifted_chebyshev_polynomial_u_stub, &shifted_chebyshev_polynomial_u_mps_kernel) +REGISTER_DISPATCH(shifted_chebyshev_polynomial_v_stub, &shifted_chebyshev_polynomial_v_mps_kernel) +REGISTER_DISPATCH(shifted_chebyshev_polynomial_w_stub, &shifted_chebyshev_polynomial_w_mps_kernel) REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel) REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_mps_kernel) REGISTER_DISPATCH(polar_stub, &polar_mps_kernel); diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 72826c12873021..97d562730dd8a2 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -700,7 +700,9 @@ Tensor _mps_convolution_transpose(const Tensor& input_t, IntArrayRef stride, IntArrayRef dilation, int64_t groups) { - TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS"); + bool is_unsupported_3d_dtype = + (input_t.dim() == 5 && (input_t.scalar_type() == kHalf || input_t.scalar_type() == kBFloat16)); + TORCH_CHECK(!is_unsupported_3d_dtype, "ConvTranspose 3D with BF16 or FP16 types is not supported on MPS"); auto output_t = mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups); diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index ec9fd53ca08143..a226a7327b8420 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -67,7 +67,9 @@ std::vector iterShapeData(iterShape.size()); std::vector> strides(nDim); TORCH_INTERNAL_ASSERT(iter.ntensors() >= nOffsets); - TORCH_CHECK(use_64bit_index || iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator"); + TORCH_CHECK(use_64bit_index || iter.can_use_32bit_indexing(), + "kernel data offsets can't be computed using 32-bit iterator of shape ", + iterShape); for (const auto i : c10::irange(iterShape.size())) { iterShapeData[i] = static_cast(iterShape[i]); diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 00d622c7cdf821..42769c13f1e1b6 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -17,7 +17,7 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const MPSStream* mpsStream = getCurrentMPSStream(); id device = MPSDevice::getInstance()->device(); - const string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true, true); + const std::string key = "mps_linear" + getTensorsStringKey({input, weight, bias}, true, true); dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { mpsStream->endKernelCoalescing(); @@ -35,14 +35,15 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const shape:getMPSShape(weight.sizes())]; weightDesc.preferPackedRows = YES; [weightDesc transposeDimension:0 withDimension:1]; - MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf - offset:weight.storage_offset() * weight.element_size() - descriptor:weightDesc]; + MPSNDArray* weightNDArray = [[[MPSNDArray alloc] initWithBuffer:weightBuf + offset:weight.storage_offset() * weight.element_size() + descriptor:weightDesc] autorelease]; if (is_bias_defined) { auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides()); - auto cachedKernel = LookUpOrCreateCachedKernel( - key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3]; }); + auto cachedKernel = LookUpOrCreateCachedKernel(key, [&]() { + return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3] autorelease]; + }); auto kernel = cachedKernel->kernel(); getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias}); @@ -52,8 +53,9 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const destinationArray:outNDArray]; getMPSProfiler().endProfileKernel(kernel); } else { - auto cachedKernel = LookUpOrCreateCachedKernel( - key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; }); + auto cachedKernel = LookUpOrCreateCachedKernel(key, [&]() { + return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2] autorelease]; + }); auto kernel = cachedKernel->kernel(); getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias}); [kernel encodeToCommandEncoder:computeEncoder diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 9be8ca1cc6513f..3cdf0021e987fc 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -2,6 +2,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include @@ -22,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -1097,25 +1097,8 @@ static void lu_unpack_mps_impl(const Tensor& LU_data, } } -static void linalg_cholesky_mps_impl(const Tensor& input, - bool upper, - bool check_errors, - const Tensor& out, - const Tensor& info) { - using namespace mps; - - TORCH_CHECK(out.is_mps()); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32"); - TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D"); - TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square"); - auto input_sizes = input.sizes(); - resize_output(out, input_sizes); - resize_output(info, {input_sizes.begin(), input_sizes.end() - 2}); - if (input.numel() == 0) { - info.zero_(); - return; - } - out.copy_(input); +static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper) { + auto input_sizes = out.sizes(); int64_t ndim = out.dim(); int64_t N = out.size(-1); @@ -1124,9 +1107,9 @@ static void linalg_cholesky_mps_impl(const Tensor& input, auto stream = getCurrentMPSStream(); auto device = MPSDevice::getInstance()->device(); - auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock"); - auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM"); - auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK"); + auto factorDiagonalPSO = lib.getPipelineStateForFunc(upper ? "factorDiagonalBlockU" : "factorDiagonalBlockL"); + auto applyTRSMPSO = lib.getPipelineStateForFunc(upper ? "applyTRSMU" : "applyTRSML"); + auto applySYRKPSO = lib.getPipelineStateForFunc(upper ? "applySYRKU" : "applySYRKL"); int64_t NB = std::min(32, N); int64_t numBlocks = (N + NB - 1) / NB; @@ -1168,33 +1151,8 @@ static void linalg_cholesky_mps_impl(const Tensor& input, } }); } - int status; - if (check_errors) { - if (info_.dim() > 0) { - // batch case - for (const auto i : c10::irange(B)) { - status = info_[i].item(); - TORCH_CHECK( - status == 0, - "linalg.cholesky(): (Batch element ", - i, - "): The factorization could not be completed because the input is not positive-definite (the leading minor of order ", - status, - " is not positive-definite)."); - } - } else { - // single matrix case(no batch size) - status = info.item(); - TORCH_CHECK( - status == 0, - "linalg.cholesky(): The factorization could not be completed because the input is not positive-definite (the leading minor of order ", - status, - " is not positive-definite)."); - } - } - out.tril_(); - upper ? out.transpose_(ndim - 2, ndim - 1) : out; } + } // namespace mps Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) { @@ -1355,23 +1313,6 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons return result; } -Tensor cholesky_mps(const Tensor& self, bool upper) { - auto out = at::empty_like(self, MemoryFormat::Contiguous); - cholesky_mps_out(self, upper, out); - return out; -} - -Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) { - auto info = at::empty({}, self.options().dtype(kInt)); - mps::linalg_cholesky_mps_impl(self, upper, true, out, info); - return out; -} - -TORCH_IMPL_FUNC(linalg_cholesky_ex_out_mps) -(const Tensor& self, bool upper, bool check_errors, const Tensor& L, const Tensor& info) { - mps::linalg_cholesky_mps_impl(self, upper, check_errors, L, info); -} - Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, @@ -1460,4 +1401,6 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info); } + +REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 57656075b68000..f5264cf32d9f24 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -597,7 +597,10 @@ Check if running mean exists (maybe do this check before making graph) const bool has_weight = (weight_opt.has_value() && weight_opt->defined()); - if (grad_input.numel() == 0) { + bool any_grad_needed = (grad_input_mask[0] && grad_input.numel() > 0) || + (grad_input_mask[1] && grad_weight.numel() > 0) || (grad_input_mask[2] && grad_bias.numel() > 0); + + if (!any_grad_needed) { return std::make_tuple(grad_input, grad_weight, grad_bias); } diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index ce467efd7b5173..55a48c76627960 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -1,7 +1,10 @@ // Copyright © 2022 Apple Inc. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -15,9 +18,17 @@ #include #include #include +#include #endif namespace at::native { + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + namespace mps { struct PoolingCachedGraph : public MPSCachedGraph { @@ -240,6 +251,191 @@ static void pool2d_template(const Tensor& input, } } +static Tensor intarrayref_to_tensor(IntArrayRef arrayref) { + at::Tensor tensor = + at::empty({static_cast(arrayref.size())}, + TensorOptions().device(c10::kCPU).dtype(at::kLong).memory_format(at::MemoryFormat::Contiguous)); + std::memcpy(tensor.data_ptr(), arrayref.data(), arrayref.size() * sizeof(int64_t)); + return tensor; +} + +// NOTE: output is only valid as long as the tensor stays alive and its shape +// doesn't change. +static IntArrayRef tensor_to_intarrayref(const Tensor& tensor) { + TORCH_INTERNAL_ASSERT(tensor.dim() == 1); + TORCH_INTERNAL_ASSERT(tensor.scalar_type() == at::kLong); + TORCH_INTERNAL_ASSERT(tensor.device().type() == at::kCPU); + auto data_ptr = tensor.data_ptr(); + auto length = tensor.size(0); + return IntArrayRef(data_ptr, length); +} + +static void max_pool_with_indices_out_mps_template(const Tensor& output, + const Tensor& indices, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + const int32_t pooling_dims, + const std::string& op_name) { + TORCH_INTERNAL_ASSERT(pooling_dims == 1 || pooling_dims == 2 || pooling_dims == 3); + + const int32_t dims = input.dim(); + + TORCH_CHECK(dims == pooling_dims + 1 || dims == pooling_dims + 2, + op_name, + ": non-empty ", + pooling_dims + 1, + "D or ", + pooling_dims + 2, + "D (batch mode) tensor expected for input"); + + TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == pooling_dims, + op_name, + ": kernel_size must either be a single int, or a tuple of ", + pooling_dims, + " ints"); + + TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3, + op_name, + ": stride must either be omitted, a single int, or a tuple of ", + pooling_dims, + " ints"); + + TORCH_CHECK(padding.size() == 1 || padding.size() == 3, + op_name, + ": padding must either be a single int, or a tuple of ", + pooling_dims, + " ints"); + + TORCH_CHECK(dilation.size() == 1 || dilation.size() == pooling_dims, + op_name, + ": dilation must be either a single int, or a tuple of ", + pooling_dims, + " ints"); + + int32_t leading_dims = input.dim() - pooling_dims; + + at::Tensor t_input_size = intarrayref_to_tensor(input.sizes()); + at::Tensor t_input_pooling_size = t_input_size.slice(/*dim=*/0, /*start=*/leading_dims); + + at::Tensor t_kernel_size = intarrayref_to_tensor(kernel_size); + if (kernel_size.size() == 1) { + t_kernel_size.repeat(pooling_dims); + } + + at::Tensor t_stride = stride.empty() ? t_kernel_size.clone() : intarrayref_to_tensor(stride); + if (!stride.empty() && stride.size() == 1) { + t_stride.repeat(pooling_dims); + } + + at::Tensor t_padding = intarrayref_to_tensor(padding); + if (padding.size() == 1) { + t_padding.repeat(pooling_dims); + } + + TORCH_CHECK((t_padding.ge(0)).all().item(), op_name, ": pad must be non-negative"); + + TORCH_CHECK((t_padding.mul(2).le(t_kernel_size).all().item()), + op_name, + ": pad should be at most half of effective kernel size"); + + TORCH_CHECK(t_input_size.slice(0, leading_dims - 1).gt(0).all().item(), + op_name, + ": Expected input's non-batch dimensions to have positive length"); + + at::Tensor t_dilation = intarrayref_to_tensor(dilation); + if (dilation.size() == 1) { + t_dilation.repeat(pooling_dims); + } + + at::Tensor t_output_size = t_input_size.clone(); + + auto divide = [](const Tensor& a, const Tensor& b, bool ceil_mode) { + Tensor res = a.div(b); + + if (ceil_mode) { + Tensor res_ceil = res.ceil(); + return res_ceil.to(a.scalar_type()); + } else { + Tensor res_floor = res.floor(); + return res_floor.to(a.scalar_type()); + } + }; + + // According to the documentation, the output size of each pooling dimension + // follows this basic formula: + // (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + + at::Tensor t_output_pooling_size = + t_input_pooling_size.add(t_padding.mul(2)).sub(t_dilation.mul(t_kernel_size.sub(1))).sub(1); + + if (ceil_mode) { + t_output_pooling_size = t_output_pooling_size.add(t_stride).sub(1); + } + + t_output_pooling_size = t_output_pooling_size.floor_divide(t_stride).add(1); + + if (ceil_mode) { + t_output_pooling_size = t_output_pooling_size.sub(t_output_pooling_size.sub(1) + .mul(t_stride) + .ge(t_input_pooling_size.add(t_padding)) + .to(t_output_pooling_size.scalar_type())); + } + + t_output_size.slice(0, leading_dims) = t_output_pooling_size; + + IntArrayRef output_size = tensor_to_intarrayref(t_output_size); + output.resize_(output_size); + indices.resize_(output_size); + + auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build(); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + const auto numThreads = iter.numel(); + TORCH_INTERNAL_ASSERT(numThreads == output.numel()); + + PoolingParams<5> params; + + params.dims = dims; + params.pooling_dims = pooling_dims; + memcpy(params.input_sizes.data(), input.sizes().data(), dims * sizeof(int64_t)); + memcpy(params.input_strides.data(), input.strides().data(), dims * sizeof(int64_t)); + memcpy(params.output_strides.data(), output.strides().data(), dims * sizeof(int64_t)); + memcpy(params.output_sizes.data(), output.sizes().data(), dims * sizeof(int64_t)); + memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t)); + memcpy(params.indices_sizes.data(), indices.sizes().data(), dims * sizeof(int64_t)); + memcpy(params.kernel_size.data(), t_kernel_size.data_ptr(), pooling_dims * sizeof(int64_t)); + memcpy(params.stride.data(), t_stride.data_ptr(), pooling_dims * sizeof(int64_t)); + memcpy(params.padding.data(), t_padding.data_ptr(), pooling_dims * sizeof(int64_t)); + memcpy(params.dilation.data(), t_dilation.data_ptr(), pooling_dims * sizeof(int64_t)); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_" + scalarToMetalTypeString(input)); + + // Each thread needs to keep track of the indices into the pooling + // dimensions for the element of the output that it calculates. In other + // words, if the thread calculates `output[N, C, d, h, w]` for a 3D pool, + // the kernel needs to keep track of the indices `[d, h, w]`. So we create + // a device-side buffer for the threads to store these indices. + id work_pooling_dim_indices = [[device newBufferWithLength:numThreads * pooling_dims * sizeof(int64_t) + options:0] autorelease]; + + getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input}); + [computeEncoder setComputePipelineState:maxPoolPSO]; + mtl_setArgs(computeEncoder, input, output, indices, work_pooling_dim_indices, params); + + mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads); + getMPSProfiler().endProfileKernel(maxPoolPSO); + } + }); +} + static void avg_pool2d_template(const Tensor& input, const Tensor& output, const std::optional& grad_output_opt, @@ -493,6 +689,55 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, "max_pool2d_indices_backward"); } +std::tuple max_pool3d_with_indices_out_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output, + Tensor& indices) { + mps::max_pool_with_indices_out_mps_template(output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/3, + "max_pool3d"); + return std::tuple(output, indices); +} + +std::tuple max_pool3d_with_indices_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { + NoNamesGuard guard; + + Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); + Tensor indices = at::empty({0}, input.options().dtype(kLong), MemoryFormat::Contiguous); + mps::max_pool_with_indices_out_mps_template(output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + /*pooling_dims=*/3, + "max_pool3d"); + + guard.reset(); + namedinference::propagate_names(output, input); + namedinference::propagate_names(indices, input); + + return std::tuple(output, indices); +} + TORCH_IMPL_FUNC(avg_pool2d_out_mps) (const Tensor& input, int64_t kH, diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm index 613db5c5f489a4..4c1631f0f11072 100644 --- a/aten/src/ATen/native/mps/operations/RangeFactories.mm +++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -65,14 +66,7 @@ size_d = std::ceil(static_cast(end.to() - start.to()) / step.to()); } - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && std::isfinite(static_cast(xend)), - "unsupported range: ", - xstart, - " -> ", - xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), "invalid size, possible overflow?"); @@ -147,14 +141,7 @@ size_d = static_cast(end.to() - start.to()) / step.to() + 1; } - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && std::isfinite(static_cast(xend)), - "unsupported range: ", - xstart, - " -> ", - xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), "invalid size, possible overflow?"); diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.mm b/aten/src/ATen/native/mps/operations/ScanKernel.mm new file mode 100644 index 00000000000000..9e3269d970143c --- /dev/null +++ b/aten/src/ATen/native/mps/operations/ScanKernel.mm @@ -0,0 +1,418 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif +#include + +namespace at::native { +namespace mps { + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif + +// Generic scan implementation that handles both simple scans and scans with indices +static void scan_mps_impl(const Tensor& self, + const std::vector& outputs, + int64_t dim, + const std::string& op_name) { + if (outputs[0].numel() == 0) { + return; + } + + const int64_t ndim = self.dim(); + const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); + + // Calculate dimensions for scan operation + int64_t row_size = self.size(wrapped_dim); + auto sizes = self.sizes(); + + bool is_innermost = (wrapped_dim == ndim - 1); + + // Check if all tensors are contiguous + bool is_contiguous = self.is_contiguous(); + for (const auto& output : outputs) { + is_contiguous = is_contiguous && output.is_contiguous(); + } + + uint32_t num_rows, num_orows, num_irows, num_threads; + + if (is_innermost) { + // Treat all outer dimensions as a single dimension + num_rows = self.numel() / row_size; + num_threads = num_rows; + } else { + // Treat all outer dimensions (i.e. dim_ < dim) as one + num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim); + // Treat all inner dimensions (i.e. dim > dimension) as one + num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end()); + num_threads = num_orows * num_irows; + } + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + // Choose kernel based on contiguity and dimension + std::string kernel_name; + if (is_contiguous) { + kernel_name = + op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self); + } else { + kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self); + } + + id scanPSO = lib.getPipelineStateForFunc(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() { + std::vector all_tensors = {self}; + all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end()); + return all_tensors; + }()); + + [computeEncoder setComputePipelineState:scanPSO]; + + // Set input tensor + mtl_setBuffer(computeEncoder, self, 0); + + // Set output tensors + for (size_t i = 0; i < outputs.size(); ++i) { + mtl_setBuffer(computeEncoder, outputs[i], i + 1); + } + + if (is_contiguous) { + // Contiguous kernels + if (is_innermost) { + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, num_rows, static_cast(row_size)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, num_rows, static_cast(row_size)); + } + } else { + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast(row_size)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast(row_size)); + } + } + } else { + // Strided kernels - pass full tensor information + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, + self.sizes(), + self.strides(), + outputs[0].strides(), + static_cast(self.ndimension()), + static_cast(wrapped_dim)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, + self.sizes(), + self.strides(), + outputs[0].strides(), + outputs[1].strides(), + static_cast(self.ndimension()), + static_cast(wrapped_dim)); + } + } + + mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads); + + getMPSProfiler().endProfileKernel(scanPSO); + } + }); +} + +// Utility function to get 2D grid dimensions for dispatch +static std::pair get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) { + size_t grid_x = 1; + size_t grid_y = 1; + + for (const auto i : c10::irange(dim)) { + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + + TORCH_CHECK(grid_y <= UINT32_MAX && grid_x <= UINT32_MAX, "Unable to safely factor shape for grid dimensions."); + + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + + return {static_cast(grid_x), static_cast(grid_y)}; +} + +static void scan_simple_mps_impl(const Tensor& self, const Tensor& output, int64_t dim, const std::string& op_name) { + if (output.numel() == 0) { + return; + } + + const int64_t ndim = self.dim(); + const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); + const int64_t axis_size = self.size(wrapped_dim); + + // Preprocess input tensor - ensure it's contiguous for Metal shaders + Tensor input_tensor = self.contiguous(); + + // Preprocess output tensor - ensure it's contiguous for Metal shaders + Tensor output_tensor = output; + bool output_needs_copy = !output.is_contiguous(); + Tensor temp_output; + + if (output_needs_copy) { + // Create a temporary contiguous tensor with the same shape and type + temp_output = at::empty_like(output, output.options()).contiguous(); + output_tensor = temp_output; + } + + // Determine which kernel to use based on scan dimension position + bool is_innermost_scan = (wrapped_dim == ndim - 1); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + // Build kernel name based on scan dimension position + const auto type_str = scalarToMetalTypeString(input_tensor); + const auto kernel_name = fmt::format("{}_{}_{}", op_name, is_innermost_scan ? "innermost" : "outer", type_str); + + id scanPSO = lib.getPipelineStateForFunc(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() { + std::vector all_tensors = {input_tensor, output_tensor}; + return all_tensors; + }()); + + [computeEncoder setComputePipelineState:scanPSO]; + + // Set input and output buffers (both guaranteed contiguous) + mtl_setBuffer(computeEncoder, input_tensor, 0); + mtl_setBuffer(computeEncoder, output_tensor, 1); + + if (is_innermost_scan) { + // Contiguous scan dispatch (scanning innermost dimension) + mtl_setBytes(computeEncoder, axis_size, 2); + + int n_reads = (input_tensor.element_size() <= 4) ? 4 : 2; + constexpr int simd_size = 32; + int elements_per_simd = n_reads * simd_size; + int thread_group_size = static_cast(scanPSO.maxTotalThreadsPerThreadgroup); + + if (axis_size <= n_reads * 1024) { + thread_group_size = ((axis_size + elements_per_simd - 1) / elements_per_simd) * simd_size; + } else if (axis_size <= n_reads * 2048) { + thread_group_size = ((axis_size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size; + } + thread_group_size = std::min(thread_group_size, static_cast(scanPSO.maxTotalThreadsPerThreadgroup)); + + auto tmp_grid_dims = get_2d_grid_dims(input_tensor.sizes(), wrapped_dim); + + [computeEncoder dispatchThreads:MTLSizeMake(thread_group_size, tmp_grid_dims.first, tmp_grid_dims.second) + threadsPerThreadgroup:MTLSizeMake(thread_group_size, 1, 1)]; + } else { + // Strided scan dispatch (scanning non-innermost dimension) + size_t stride = input_tensor.strides()[wrapped_dim]; + constexpr int bn = 32; + size_t stride_blocks = (stride + bn - 1) / bn; + + mtl_setBytes(computeEncoder, axis_size, 2); + mtl_setBytes(computeEncoder, stride, 3); + mtl_setBytes(computeEncoder, stride_blocks, 4); + + int n_reads = (input_tensor.element_size() <= 4) ? 4 : 2; + int n_simdgroups = bn / n_reads; + constexpr int simd_size = 32; + int thread_group_size = n_simdgroups * simd_size; + + auto tmp_grid_dims = get_2d_grid_dims(input_tensor.sizes(), wrapped_dim); + if (tmp_grid_dims.first * stride_blocks <= UINT_MAX) { + tmp_grid_dims.first *= stride_blocks; + } else { + tmp_grid_dims.second *= stride_blocks; + } + + [computeEncoder dispatchThreads:MTLSizeMake(thread_group_size, tmp_grid_dims.first, tmp_grid_dims.second) + threadsPerThreadgroup:MTLSizeMake(thread_group_size, 1, 1)]; + } + + getMPSProfiler().endProfileKernel(scanPSO); + } + }); + + // Post-process: copy result back to original output tensor if needed + if (output_needs_copy) { + output.copy_(output_tensor); + } +} + +// Specialized implementation for cummin/cummax that returns both values and indices +static void scan_with_indices_mps_impl(const Tensor& self, + const Tensor& values_output, + const Tensor& indices_output, + int64_t dim, + const std::string& op_name) { + if (values_output.numel() == 0) { + return; + } + + const int64_t ndim = self.dim(); + const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); + const int64_t axis_size = self.size(wrapped_dim); + + // Preprocess input tensor - ensure it's contiguous for Metal shaders + auto input_tensor = self.contiguous(); + + // Preprocess output tensors - ensure they're contiguous for Metal shaders + auto values_tensor = values_output.contiguous(); + auto indices_tensor = indices_output.contiguous(); + const bool values_needs_copy = !values_output.is_contiguous(); + const bool indices_needs_copy = !indices_output.is_contiguous(); + + // Determine which kernel to use based on scan dimension position + bool is_innermost_scan = (wrapped_dim == ndim - 1); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + // Build kernel name based on scan type + const auto type_str = scalarToMetalTypeString(input_tensor); + const auto kernel_name = fmt::format("{}_{}_{}", op_name, is_innermost_scan ? "innermost" : "outer", type_str); + + id scanPSO = lib.getPipelineStateForFunc(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(scanPSO, op_name, {input_tensor, values_tensor, indices_tensor}); + + [computeEncoder setComputePipelineState:scanPSO]; + + // Set input and output buffers (all guaranteed contiguous) + mtl_setArgs(computeEncoder, input_tensor, values_tensor, indices_tensor); + + constexpr int simd_size = 32; + + if (is_innermost_scan) { + // Contiguous scan dispatch (scanning innermost dimension) + mtl_setArgs<3>(computeEncoder, axis_size); + + int n_reads = (input_tensor.element_size() <= 4) ? 4 : 2; + + int elements_per_simd = n_reads * simd_size; + int thread_group_size = static_cast(scanPSO.maxTotalThreadsPerThreadgroup); + + if (axis_size <= n_reads * 1024) { + thread_group_size = ((axis_size + elements_per_simd - 1) / elements_per_simd) * simd_size; + } else if (axis_size <= n_reads * 2048) { + thread_group_size = ((axis_size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size; + } + thread_group_size = std::min(thread_group_size, static_cast(scanPSO.maxTotalThreadsPerThreadgroup)); + + auto tmp_grid_dims = get_2d_grid_dims(input_tensor.sizes(), wrapped_dim); + + [computeEncoder dispatchThreads:MTLSizeMake(thread_group_size, tmp_grid_dims.first, tmp_grid_dims.second) + threadsPerThreadgroup:MTLSizeMake(thread_group_size, 1, 1)]; + } else { + // Strided scan dispatch (scanning non-innermost dimension) + size_t stride = input_tensor.strides()[wrapped_dim]; + constexpr int bn = 32; + size_t stride_blocks = (stride + bn - 1) / bn; + + mtl_setArgs<3>(computeEncoder, axis_size, stride, stride_blocks); + + int n_reads = (input_tensor.element_size() <= 4) ? 4 : 2; + int n_simdgroups = bn / n_reads; + int thread_group_size = n_simdgroups * simd_size; + + auto tmp_grid_dims = get_2d_grid_dims(input_tensor.sizes(), wrapped_dim); + if (tmp_grid_dims.first * stride_blocks <= UINT_MAX) { + tmp_grid_dims.first *= stride_blocks; + } else { + tmp_grid_dims.second *= stride_blocks; + } + + [computeEncoder dispatchThreads:MTLSizeMake(thread_group_size, tmp_grid_dims.first, tmp_grid_dims.second) + threadsPerThreadgroup:MTLSizeMake(thread_group_size, 1, 1)]; + } + + getMPSProfiler().endProfileKernel(scanPSO); + } + }); + + // Post-process: copy results back to original output tensors if needed + if (values_needs_copy) { + values_output.copy_(values_tensor); + } + if (indices_needs_copy) { + indices_output.copy_(indices_tensor); + } +} + +} // namespace mps + +void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); + } else { + mps::scan_mps_impl(self, {values, indices}, dim, "cummax"); + } +} + +void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { + mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin"); + } else { + mps::scan_mps_impl(self, {values, indices}, dim, "cummin"); + } +} + +Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) { + const auto wrap_dim = maybe_wrap_dim(dim, self.dim()); + result.resize_(self.sizes()); + if (self.dim() == 0) { + result.fill_(self); + return result; + } + if (self.numel() == 0) { + result.zero_(); + return result; + } + + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { + mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp"); + } else { + mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp"); + } + return result; +} + +Tensor _logcumsumexp_mps(const Tensor& self, int64_t dim) { + Tensor result = at::empty_like(self, MemoryFormat::Contiguous); + return _logcumsumexp_out_mps(self, dim, result); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index e209464e6371b6..19f26023b31793 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -87,6 +87,10 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in return; } + // issue #154890, raising error to prevent crash within MPSGraph until + // workaround is implemented. + TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890"); + MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index 6867bafc562eb7..647bac958ecae2 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -5,6 +5,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -26,101 +27,53 @@ #include #endif -TORCH_IMPL_FUNC(triu_mps_out) -(const Tensor& self, int64_t k, const Tensor& output) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - if (self.numel() == 0) { - return; - } - auto stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* outputTensor = nil; - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; - - if (k > 0) { - auto diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; - auto onesTensor = [mpsGraph constantWithScalar:1 shape:inputTensor.shape dataType:MPSDataTypeInt32]; - auto maskTensor = [mpsGraph bandPartWithTensor:onesTensor - numLowerTensor:minusOneTensor - numUpperTensor:diagMinusOneTensor - name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor - truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType] - falsePredicateTensor:inputTensor - name:nil]; - } else { - auto minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; - outputTensor = [mpsGraph bandPartWithTensor:inputTensor - numLowerTensor:minusDiagTensor - numUpperTensor:minusOneTensor - name:nil]; - } - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); +template +static std::vector reverse_array(const IntArrayRef& arr) { + std::vector rc(arr.size()); + for (const auto& i : c10::irange(arr.size())) { + rc[i] = arr[arr.size() - 1 - i]; } + return rc; } -TORCH_IMPL_FUNC(tril_mps_out) -(const Tensor& self, int64_t k, const Tensor& output) { +static void triu_tril_impl(const Tensor& self, int64_t k, const Tensor& out, const std::string& name) { using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - if (self.numel() == 0) { return; } - + auto sizes = reverse_array(self.sizes()); + auto inp_strides = reverse_array(self.strides()); + auto out_strides = reverse_array(out.strides()); + std::array k_ndim = {int(k), int(self.ndimension())}; + const bool inplace = self.is_same(out); + const auto kernel_name = + fmt::format("{}{}_{}_{}", name, inplace ? "_inplace" : "", "int", scalarToMetalTypeString(self)); + auto triuPSO = lib.getPipelineStateForFunc(kernel_name); + uint32_t max_threads_per_group = [triuPSO maxTotalThreadsPerThreadgroup]; auto stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* outputTensor = nil; - - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - auto minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; - - if (k >= 0) { - auto diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32]; - outputTensor = [mpsGraph bandPartWithTensor:inputTensor - numLowerTensor:minusOneTensor - numUpperTensor:diagTensor - name:nil]; + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:triuPSO]; + if (inplace) { + mtl_setArgs(computeEncoder, self, inp_strides, sizes, k_ndim); } else { - auto negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32]; - auto complementTensor = [mpsGraph bandPartWithTensor:inputTensor - numLowerTensor:negDiagMinusOneTensor - numUpperTensor:minusOneTensor - name:nil]; - auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:getMPSDataType(self)]; - auto mask = [mpsGraph equalWithPrimaryTensor:complementTensor secondaryTensor:zeroTensor name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor:mask - truePredicateTensor:inputTensor - falsePredicateTensor:zeroTensor - name:nil]; + mtl_setArgs(computeEncoder, out, self, out_strides, inp_strides, sizes, k_ndim); } + [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], self.numel() / (sizes[0] * sizes[1])) + threadsPerThreadgroup:MTLSizeMake(std::min(max_threads_per_group, sizes[0]), 1, 1)]; + } + }); +} - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); +TORCH_IMPL_FUNC(triu_mps_out) +(const Tensor& self, int64_t k, const Tensor& output) { + triu_tril_impl(self, k, output, "triu"); +} - runMPSGraph(stream, cachedGraph->graph(), dictionaryFromPlaceholders(selfPlaceholder), outputPlaceholder); - } +TORCH_IMPL_FUNC(tril_mps_out) +(const Tensor& self, int64_t k, const Tensor& output) { + triu_tril_impl(self, k, output, "tril"); } Tensor tril_indices_mps(int64_t row, diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm index 48965c6d75ff0f..b560739ed40c30 100644 --- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +// #include #include #include #include @@ -13,84 +14,42 @@ #include #endif -static void erfinv_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "erfinv"); -} - -static void exp_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "exp"); -} - -static void sinc_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "sinc"); -} - -static void tanh_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "tanh"); -} - -static void sin_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "sin"); -} - -static void cos_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "cos"); -} - -static void tan_kernel(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "tan"); -} +// KURT: call site of `exec_unary_kernel` +#define REGISTER_UNARY_TI_DISPATCH(NAME) \ + static void NAME##_kernel_mps(TensorIteratorBase& iter) { \ + lib.exec_unary_kernel(iter, #NAME); \ + } \ + REGISTER_DISPATCH(NAME##_stub, NAME##_kernel_mps) static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) { - lib.exec_unary_kernel(iter, "round_decimals", decimals); -} - -static void exp2_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "exp2"); -} - -static void sqrt_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "sqrt"); -} - -static void rsqrt_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "rsqrt"); -} - -static void neg_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "neg"); -} - -static void bitwise_not_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "bitwise_not"); -} - -static void log10_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "log10"); -} - -static void log2_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "log2"); -} - -static void log_kernel_mps(TensorIteratorBase& iter) { - lib.exec_unary_kernel(iter, "log"); -} - -REGISTER_DISPATCH(exp_stub, exp_kernel); -REGISTER_DISPATCH(erfinv_stub, erfinv_kernel); -REGISTER_DISPATCH(sinc_stub, sinc_kernel); -REGISTER_DISPATCH(tanh_stub, tanh_kernel); -REGISTER_DISPATCH(sin_stub, sin_kernel); -REGISTER_DISPATCH(cos_stub, cos_kernel); -REGISTER_DISPATCH(tan_stub, tan_kernel); + lib.exec_unary_kernel(iter, "round_decimals", Scalar(decimals), ScalarType::Long); +} + +REGISTER_UNARY_TI_DISPATCH(exp); +REGISTER_UNARY_TI_DISPATCH(expm1); +REGISTER_UNARY_TI_DISPATCH(erf); +REGISTER_UNARY_TI_DISPATCH(erfc); +REGISTER_UNARY_TI_DISPATCH(erfinv); +REGISTER_UNARY_TI_DISPATCH(sinc); +REGISTER_UNARY_TI_DISPATCH(sinh); +REGISTER_UNARY_TI_DISPATCH(cosh); +REGISTER_UNARY_TI_DISPATCH(tanh); +REGISTER_UNARY_TI_DISPATCH(abs); +REGISTER_UNARY_TI_DISPATCH(sin); +REGISTER_UNARY_TI_DISPATCH(cos); +REGISTER_UNARY_TI_DISPATCH(tan); +REGISTER_UNARY_TI_DISPATCH(asin); +REGISTER_UNARY_TI_DISPATCH(acos); +REGISTER_UNARY_TI_DISPATCH(atan); +REGISTER_UNARY_TI_DISPATCH(sqrt); +REGISTER_UNARY_TI_DISPATCH(rsqrt); +REGISTER_UNARY_TI_DISPATCH(neg); +REGISTER_UNARY_TI_DISPATCH(exp2); +REGISTER_UNARY_TI_DISPATCH(log10); +REGISTER_UNARY_TI_DISPATCH(log2); +REGISTER_UNARY_TI_DISPATCH(log); +REGISTER_UNARY_TI_DISPATCH(log1p); +REGISTER_UNARY_TI_DISPATCH(bitwise_not); +REGISTER_UNARY_TI_DISPATCH(sigmoid); REGISTER_DISPATCH(round_decimals_stub, round_decimals_kernel); -REGISTER_DISPATCH(sqrt_stub, sqrt_kernel_mps); -REGISTER_DISPATCH(rsqrt_stub, rsqrt_kernel_mps); -REGISTER_DISPATCH(exp2_stub, exp2_kernel_mps); -REGISTER_DISPATCH(neg_stub, neg_kernel_mps); -REGISTER_DISPATCH(bitwise_not_stub, bitwise_not_kernel_mps); -REGISTER_DISPATCH(log10_stub, log10_kernel_mps); -REGISTER_DISPATCH(log2_stub, log2_kernel_mps); -REGISTER_DISPATCH(log_stub, log_kernel_mps); } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index 7a90b96b535faf..edf45a5ff80d03 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -12,7 +12,6 @@ #include #else #include -#include #include #include #include @@ -27,10 +26,8 @@ #include #include #include -#include #include #include -#include #include #include #include @@ -41,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -199,46 +195,10 @@ static void unary_op(const Tensor& self, } CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(erf_out_mps, erf) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(asin_out_mps, asin) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acos_out_mps, acos) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atan_out_mps, atan) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sinh_out_mps, sinh) -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(cosh_out_mps, cosh) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(asinh_out_mps, asinh) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh) -Tensor& abs_out_mps(const Tensor& self, Tensor& output) { - using namespace mps; - - if (!output.is_same_size(self)) { - output.resize_(self.sizes()); - } - - if (self.numel() == 0) { - return output; - } - - if (supportsComplex() || !self.is_complex()) { - unary_op_noresize(self, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto rc = [mpsGraph absoluteWithTensor:inputTensor name:nil]; - if (self.is_complex()) { - rc = [mpsGraph realPartOfTensor:rc name:nil]; - } - return rc; - }); - } else { - Tensor realInput = at::view_as_real(self); - unary_op_noresize( - realInput, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto rc = lengthOfComplexAsReal(mpsGraph, inputTensor); - return [mpsGraph reshapeTensor:rc withShape:getMPSShape(output) name:nil]; - }); - } - return output; -} - Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) { auto bool_self = self.to(ScalarType::Bool); mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { @@ -280,18 +240,6 @@ Tensor angle_mps(const Tensor& self) { return angle_out_mps(self, result); } -TORCH_IMPL_FUNC(sigmoid_out_mps)(const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "sigmoid_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return [mpsGraph sigmoidWithTensor:inputTensor name:nil]; - }); -} - -TORCH_IMPL_FUNC(log1p_out_mps)(const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "log1p_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return mps::log1p(mpsGraph, inputTensor); - }); -} - TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) { TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types"); mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { @@ -305,14 +253,6 @@ Tensor angle_mps(const Tensor& self) { }); } -TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "expm1_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; - MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:ePowTensor secondaryTensor:oneTensor name:nil]; - }); -} - static void logit_mps_impl(const Tensor& self, std::optional eps, Tensor& output, const std::string& op_name) { std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]"; diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index 18945ceaf24a67..addc70cf4334d1 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include #include @@ -38,7 +40,14 @@ #include #include #include +#include +#include +#include +#include #endif + +#include + namespace at::native { namespace mps { @@ -290,6 +299,76 @@ static void upsample_kernel_out_template(const Tensor& input, }); } +static void upsample_kernel_out_template(const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scale_d_opt, + std::optional scale_h_opt, + std::optional scale_w_opt, + const Tensor& output, + const std::string& name) { + if (output.numel() == 0) { + return; + } + UpsampleParams<5> params; + memcpy(params.input_sizes.data(), input.sizes().data(), 5 * sizeof(long)); + memcpy(params.input_strides.data(), input.strides().data(), 5 * sizeof(long)); + memcpy(params.output_strides.data(), output.strides().data(), 5 * sizeof(long)); + memcpy(params.output_sizes.data(), output.sizes().data(), 5 * sizeof(long)); + params.scales[0] = area_pixel_compute_scale(input.size(4), output.size(4), align_corners, scale_w_opt); + params.scales[1] = area_pixel_compute_scale(input.size(3), output.size(3), align_corners, scale_h_opt); + params.scales[2] = area_pixel_compute_scale(input.size(2), output.size(2), align_corners, scale_d_opt); + params.align_corners = align_corners; + auto upsamplePSO = lib.getPipelineStateForFunc(fmt::format("upsample_{}_{}", name, scalarToMetalTypeString(input))); + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:upsamplePSO]; + mtl_setArgs(computeEncoder, input, output, params); + mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0] * output_size[1] * output_size[2]); + } + }); +} + +static void upsample_kernel_backward_out_template(const Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scale_d_opt, + std::optional scale_h_opt, + std::optional scale_w_opt, + const std::string& name) { + grad_input.zero_(); + if (grad_output.numel() == 0) { + return; + } + auto upsamplePSO = + lib.getPipelineStateForFunc(fmt::format("upsample_{}_backward_{}", name, scalarToMetalTypeString(grad_input))); + UpsampleParams<5> params; + memcpy(params.input_sizes.data(), grad_input.sizes().data(), 5 * sizeof(long)); + memcpy(params.input_strides.data(), grad_input.strides().data(), 5 * sizeof(long)); + memcpy(params.output_strides.data(), grad_output.strides().data(), 5 * sizeof(long)); + memcpy(params.output_sizes.data(), grad_output.sizes().data(), 5 * sizeof(long)); + params.scales[0] = + area_pixel_compute_scale(grad_input.size(4), grad_output.size(4), align_corners, scale_w_opt); + params.scales[1] = + area_pixel_compute_scale(grad_input.size(3), grad_output.size(3), align_corners, scale_h_opt); + params.scales[2] = + area_pixel_compute_scale(grad_input.size(2), grad_output.size(2), align_corners, scale_d_opt); + params.align_corners = align_corners; + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:upsamplePSO]; + mtl_setArgs(computeEncoder, grad_input, grad_output, params); + mtl_dispatch1DJob(computeEncoder, upsamplePSO, output_size[0] * output_size[1] * output_size[2]); + } + }); +} + static void upsample_kernel_backward_out_template(const Tensor& grad_input, const Tensor& grad_output, IntArrayRef output_size, @@ -480,4 +559,67 @@ static void upsample_kernel_backward_out_template(const Tensor& grad_input, mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bicubic2d_aa"); } +TORCH_IMPL_FUNC(upsample_nearest3d_out_mps)(const Tensor& input, + IntArrayRef output_size, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& output) { + mps::upsample_kernel_out_template(input, output_size, false, scales_d, scales_h, scales_w, output, "nearest_3d"); +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact3d_out_mps)(const Tensor& input, + IntArrayRef output_size, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& output) { + mps::upsample_kernel_out_template( + input, output_size, false, scales_d, scales_h, scales_w, output, "nearest_exact_3d"); +} + +TORCH_IMPL_FUNC(upsample_nearest3d_backward_out_mps)(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input) { + mps::upsample_kernel_backward_out_template( + grad_input, grad_output, output_size, input_size, false, scales_d, scales_h, scales_w, "nearest_3d"); +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_mps)(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input) { + mps::upsample_kernel_backward_out_template( + grad_input, grad_output, output_size, input_size, false, scales_d, scales_h, scales_w, "nearest_exact_3d"); +} + +TORCH_IMPL_FUNC(upsample_trilinear3d_out_mps)(const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& output) { + mps::upsample_kernel_out_template( + input, output_size, align_corners, scales_d, scales_h, scales_w, output, "trilinear"); +} +TORCH_IMPL_FUNC(upsample_trilinear3d_backward_out_mps)(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_d, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input) { + mps::upsample_kernel_backward_out_template( + grad_input, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, "trilinear"); +} + } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d421f9c5b5b065..c65cdd8ec081ef 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -357,8 +357,7 @@ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: abs_out - MPS: abs_out_mps + CPU, CUDA, MPS, MTIA: abs_out SparseCPU, SparseCUDA: abs_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out tags: pointwise @@ -527,8 +526,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: acos_out - MPS: acos_out_mps + CPU, CUDA, MPS: acos_out tags: pointwise # arccos, alias of acos @@ -588,6 +586,7 @@ SparseCsrCUDA: add_out_sparse_compressed_cuda MkldnnCPU: mkldnn_add_out MPS: add_out_mps + MTIA: add_out_mtia tags: pointwise - func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor @@ -982,8 +981,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: asin_out - MPS: asin_out_mps + CPU, CUDA, MPS: asin_out SparseCPU, SparseCUDA: asin_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_out tags: pointwise @@ -1020,8 +1018,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: atan_out - MPS: atan_out_mps + CPU, CUDA, MPS: atan_out SparseCPU, SparseCUDA: atan_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_out tags: pointwise @@ -1221,7 +1218,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: bitwise_not_out + CPU, CUDA, MPS, MTIA: bitwise_not_out tags: pointwise - func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -1285,7 +1282,7 @@ - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_not_out + CPU, CUDA, MTIA: logical_not_out MPS: logical_not_out_mps tags: pointwise @@ -1327,7 +1324,7 @@ - func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_and_out + CPU, CUDA, MTIA: logical_and_out MPS: logical_and_out_mps tags: pointwise @@ -1348,7 +1345,7 @@ - func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_or_out + CPU, CUDA, MTIA: logical_or_out MPS: logical_or_out_mps tags: pointwise @@ -1832,7 +1829,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: cos_out + CPU, CUDA, MPS, MTIA: cos_out tags: pointwise - func: cosh(Tensor self) -> Tensor @@ -1852,8 +1849,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: cosh_out - MPS: cosh_out_mps + CPU, CUDA, MPS: cosh_out tags: pointwise - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor @@ -1967,6 +1963,7 @@ dispatch: CPU: cummax_helper_cpu CUDA: cummax_helper_cuda + MPS: cummax_helper_mps - func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) device_check: NoCheck # TensorIterator @@ -1991,6 +1988,7 @@ dispatch: CPU: cummin_helper_cpu CUDA: cummin_helper_cuda + MPS: cummin_helper_mps - func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor variants: function @@ -2171,7 +2169,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: div_out + CPU, CUDA, MPS, MTIA: div_out SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2548,8 +2546,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: erf_out - MPS: erf_out_mps + CPU, CUDA, MPS, MTIA: erf_out SparseCPU, SparseCUDA: erf_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_out tags: pointwise @@ -2571,7 +2568,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: erfc_out + CPU, CUDA, MPS: erfc_out tags: pointwise - func: exp(Tensor self) -> Tensor @@ -2591,7 +2588,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: exp_out + CPU, CUDA, MPS, MTIA: exp_out tags: pointwise - func: exp2(Tensor self) -> Tensor @@ -2634,8 +2631,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: expm1_out - MPS: expm1_out_mps + CPU, CUDA, MPS: expm1_out SparseCPU, SparseCUDA: expm1_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_out tags: pointwise @@ -3203,7 +3199,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, MPS: isnan + CPU, CUDA, MPS, MTIA: isnan NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isnan SparseCPU, SparseCUDA: isnan_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isnan_sparse_csr @@ -3339,7 +3335,7 @@ - func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: nan_to_num_out + CPU, CUDA, MTIA: nan_to_num_out MPS: nan_to_num_out_mps SparseCPU, SparseCUDA: nan_to_num_sparse_out tags: pointwise @@ -3512,7 +3508,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: log_out + CPU, CUDA, MPS, MTIA: log_out tags: pointwise - func: log10(Tensor self) -> Tensor @@ -3558,8 +3554,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: log1p_out - MPS: log1p_out_mps + CPU, CUDA, MPS: log1p_out SparseCPU, SparseCUDA: log1p_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_out tags: pointwise @@ -3581,7 +3576,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: log2_out + CPU, CUDA, MPS, MTIA: log2_out tags: pointwise - func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -3728,6 +3723,7 @@ dispatch: CPU: log_softmax_cpu_out CUDA: log_softmax_cuda_out + MTIA: log_softmax_mtia_out MPS: log_softmax_mps_out - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor @@ -3738,17 +3734,20 @@ dispatch: CPU: log_softmax_backward_cpu_out CUDA: log_softmax_backward_cuda_out + MTIA: log_softmax_backward_mtia_out MPS: log_softmax_backward_mps_out - func: _logcumsumexp(Tensor self, int dim) -> Tensor dispatch: CPU: _logcumsumexp_cpu CUDA: _logcumsumexp_cuda + MPS: _logcumsumexp_mps - func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _logcumsumexp_out_cpu CUDA: _logcumsumexp_out_cuda + MPS: _logcumsumexp_out_mps - func: logcumsumexp(Tensor self, int dim) -> Tensor variants: function, method @@ -3861,7 +3860,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: max_out + CPU, CUDA, MTIA: max_out MPS: max_out_mps - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -4050,7 +4049,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: min_out + CPU, CUDA, MTIA: min_out MPS: min_out_mps - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -4158,6 +4157,7 @@ dispatch: CPU: mm_out_cpu CUDA: mm_out_cuda + MTIA: mm_out_mtia MPS: mm_out_mps XPU: mm_out_xpu SparseCPU, SparseCUDA: _sparse_mm_out @@ -4272,7 +4272,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: mul_out + CPU, CUDA, MPS, MTIA: mul_out SparseCPU: mul_out_sparse_cpu SparseCUDA: mul_out_sparse_cuda SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr @@ -4782,6 +4782,14 @@ CompositeExplicitAutograd: randint_like autogen: randint_like.out +- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.Tensor_out + - func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor tags: nondeterministic_seeded dispatch: @@ -4891,7 +4899,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: reciprocal_out + CPU, CUDA, MTIA: reciprocal_out MPS: reciprocal_out_mps tags: pointwise @@ -4920,7 +4928,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: neg_out + CPU, CUDA, MPS, MTIA: neg_out SparseCPU, SparseCUDA: neg_out_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_out tags: pointwise @@ -5060,7 +5068,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu + CPU, CUDA, MTIA: relu MPS: relu_mps MkldnnCPU: mkldnn_relu QuantizedCPU: relu_quantized_cpu @@ -5074,7 +5082,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu_ + CPU, CUDA, MTIA: relu_ MPS: relu_mps_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ @@ -5166,7 +5174,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: hardshrink_out + CPU, CUDA, MPS: hardshrink_out - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor structured_delegate: hardshrink.out @@ -5178,7 +5186,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hardshrink_backward_out + CPU, CUDA, MPS: hardshrink_backward_out - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor structured_delegate: hardshrink_backward.grad_input @@ -5201,7 +5209,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: rsqrt_out + CPU, CUDA, MPS, MTIA: rsqrt_out tags: pointwise - func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) @@ -5272,7 +5280,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_out + CPU, CUDA, MTIA: silu_out MPS: silu_out_mps tags: pointwise @@ -5339,14 +5347,13 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sigmoid_out - MPS: sigmoid_out_mps + CPU, CUDA, MPS: sigmoid_out tags: pointwise - func: logit(Tensor self, float? eps=None) -> Tensor variants: function, method dispatch: - CPU, CUDA: logit + CPU, CUDA, MTIA: logit MPS: logit_mps tags: pointwise @@ -5386,7 +5393,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: sin_out + CPU, CUDA, MPS, MTIA: sin_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_out SparseCPU, SparseCUDA: sin_sparse_out tags: pointwise @@ -5431,8 +5438,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sinh_out - MPS: sinh_out_mps + CPU, CUDA, MPS: sinh_out SparseCPU, SparseCUDA: sinh_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_out @@ -5885,7 +5891,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: sqrt_out + CPU, CUDA, MPS, MTIA: sqrt_out SparseCPU, SparseCUDA: sqrt_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_out tags: pointwise @@ -6075,7 +6081,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: tanh_out + CPU, CUDA, MPS, MTIA: tanh_out SparseCPU, SparseCUDA: tanh_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_out tags: pointwise @@ -6539,14 +6545,14 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA, MPS: where + CPU, CUDA, MPS, MTIA: where NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_where tags: [core, pointwise] - func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: where_self_out + CPU, CUDA, MPS, MTIA: where_self_out NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_where_out - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor @@ -6926,6 +6932,7 @@ dispatch: CPU, CUDA: sub_out MPS: sub_out_mps + MTIA: sub_out_mtia SparseCPU, SparseCUDA: sub_out_sparse tags: pointwise @@ -6983,7 +6990,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA, MPS: rsub + CPU, CUDA, MPS, MTIA: rsub autogen: rsub.Tensor_out - func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) @@ -7270,36 +7277,36 @@ dispatch: CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint -- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> () +- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> () -- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> () -- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () -- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () -- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () -- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () +- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> () +- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> () +- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () +- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> () +- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor dispatch: - SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse + SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_sparse autogen: _sparse_coo_tensor_with_dims.out - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor dispatch: - SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint + SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_and_tensor_sparse_symint autogen: _sparse_coo_tensor_with_dims_and_tensors.out - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: sparse_resize_ + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_ autogen: sparse_resize, sparse_resize.out - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_ + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_and_clear_ autogen: sparse_resize_and_clear, sparse_resize_and_clear.out - func: sparse_mask(Tensor self, Tensor mask) -> Tensor @@ -7335,8 +7342,8 @@ - func: sparse_dim(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse - SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_dim_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_dim_sparse_csr CompositeExplicitAutograd: sparse_dim_default device_check: NoCheck device_guard: False @@ -7369,8 +7376,8 @@ - func: _nnz(Tensor self) -> int variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse - SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _nnz_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: _nnz_sparse_csr device_check: NoCheck device_guard: False @@ -7391,7 +7398,7 @@ - func: is_coalesced(Tensor self) -> bool variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: is_coalesced_sparse CompositeExplicitAutograd: is_coalesced_default device_check: NoCheck device_guard: False @@ -7399,14 +7406,14 @@ - func: _indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: _indices_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _indices_sparse device_check: NoCheck device_guard: False - func: _values(Tensor(a) self) -> Tensor(a) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: _values_sparse + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _values_sparse device_check: NoCheck device_guard: False @@ -7416,7 +7423,7 @@ - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) variants: method dispatch: - SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_ + SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _coalesced_sparse_ device_check: NoCheck device_guard: False autogen: _coalesced, _coalesced.out @@ -7505,9 +7512,9 @@ - func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor variants: method dispatch: - CPU, CUDA: dense_to_sparse - SparseCPU, SparseCUDA: sparse_coo_to_sparse - SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse + CPU, CUDA, MPS: dense_to_sparse + SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, SparseCsrMPS: sparse_compressed_to_sparse autogen: _to_sparse.sparse_dim_out - func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor @@ -7517,8 +7524,8 @@ - func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor variants: method dispatch: - CPU, CUDA: dense_to_sparse - SparseCPU, SparseCUDA: sparse_coo_to_sparse + CPU, CUDA, MPS: dense_to_sparse + SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse autogen: _to_sparse.out @@ -8352,7 +8359,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_and_out + CPU, CUDA, MTIA: bitwise_and_out MPS: bitwise_and_out_mps tags: pointwise @@ -8419,7 +8426,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_or_out + CPU, CUDA, MTIA: bitwise_or_out MPS: bitwise_or_out_mps tags: pointwise @@ -8891,7 +8898,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Scalar_out + CPU, CUDA, MTIA: ne_Scalar_out MPS: ne_scalar_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8909,7 +8916,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Tensor_out + CPU, CUDA, MTIA: ne_Tensor_out MPS: ne_tensor_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8992,7 +8999,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Scalar_out + CPU, CUDA, MTIA: ge_Scalar_out MPS: ge_scalar_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -9011,7 +9018,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Tensor_out + CPU, CUDA, MTIA: ge_Tensor_out MPS: ge_tensor_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -9056,7 +9063,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Scalar_out + CPU, CUDA, MTIA: le_Scalar_out MPS: le_scalar_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -9074,7 +9081,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Tensor_out + CPU, CUDA, MTIA: le_Tensor_out MPS: le_tensor_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -9119,7 +9126,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Scalar_out + CPU, CUDA,MTIA: gt_Scalar_out MPS: gt_scalar_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9138,7 +9145,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Tensor_out + CPU, CUDA, MTIA: gt_Tensor_out MPS: gt_tensor_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9183,7 +9190,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Scalar_out + CPU, CUDA, MTIA: lt_Scalar_out MPS: lt_scalar_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9201,7 +9208,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Tensor_out + CPU, CUDA, MTIA: lt_Tensor_out MPS: lt_tensor_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9473,14 +9480,12 @@ - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: cholesky_out - MPS: cholesky_mps_out + CPU, CUDA, MPS: cholesky_out - func: cholesky(Tensor self, bool upper=False) -> Tensor variants: method, function dispatch: - CPU, CUDA: cholesky - MPS: cholesky_mps + CPU, CUDA, MPS: cholesky - func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -9863,7 +9868,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: fmod_out + CPU, CUDA, MPS, MTIA: fmod_out tags: pointwise - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor @@ -9969,7 +9974,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: remainder_out + CPU, CUDA, MPS, MTIA: remainder_out tags: pointwise - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor @@ -10053,7 +10058,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: maximum_out + CPU, CUDA, MTIA: maximum_out MPS: maximum_out_mps tags: pointwise @@ -10085,7 +10090,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: minimum_out + CPU, CUDA, MTIA: minimum_out MPS: minimum_out_mps tags: pointwise @@ -11901,8 +11906,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardsigmoid_out - MPS: hardsigmoid_out_mps + CPU, CUDA, MPS: hardsigmoid_out QuantizedCPU: hardsigmoid_out_quantized_cpu - func: hardsigmoid(Tensor self) -> Tensor @@ -11923,8 +11927,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: hardsigmoid_backward_out - MPS: hardsigmoid_backward_out_mps + CPU, CUDA, MPS: hardsigmoid_backward_out - func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor structured_delegate: hardsigmoid_backward.grad_input @@ -11968,28 +11971,24 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_out - MPS: hardswish_out_mps + CPU, CUDA, MPS: hardswish_out - func: hardswish(Tensor self) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish - MPS: hardswish_mps + CPU, CUDA, MPS: hardswish - func: hardswish_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_ - MPS: hardswish_mps_ + CPU, CUDA, MPS: hardswish_ - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: - CPU, CUDA: hardswish_backward - MPS: hardswish_backward_mps + CPU, CUDA, MPS: hardswish_backward autogen: hardswish_backward.out - func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) @@ -11998,8 +11997,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: leaky_relu_out - MPS: leaky_relu_out_mps + CPU, CUDA, MPS: leaky_relu_out QuantizedCPU: leaky_relu_out_quantized_cpu - func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor @@ -12015,8 +12013,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: leaky_relu_backward_out - MPS: leaky_relu_backward_out_mps + CPU, CUDA, MPS: leaky_relu_backward_out - func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor structured_delegate: leaky_relu_backward.grad_input @@ -12128,8 +12125,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: softshrink_out - MPS: softshrink_out_mps + CPU, CUDA, MPS: softshrink_out - func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor structured_delegate: softshrink.out @@ -12142,8 +12138,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: softshrink_backward_out - MPS: softshrink_backward_out_mps + CPU, CUDA, MPS: softshrink_backward_out - func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor structured_delegate: softshrink_backward.grad_input @@ -12431,6 +12426,7 @@ dispatch: CPU: max_pool3d_with_indices_out_cpu CUDA: max_pool3d_with_indices_out_cuda + MPS: max_pool3d_with_indices_out_mps # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -12438,6 +12434,7 @@ dispatch: CPU: max_pool3d_with_indices_cpu CUDA: max_pool3d_with_indices_cuda + MPS: max_pool3d_with_indices_mps tags: core - func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -12828,6 +12825,7 @@ dispatch: CPU: upsample_trilinear3d_out_cpu CUDA: upsample_trilinear3d_out_cuda + MPS: upsample_trilinear3d_out_mps - func: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12839,6 +12837,7 @@ dispatch: CPU: upsample_trilinear3d_backward_out_cpu CUDA: upsample_trilinear3d_backward_out_cuda + MPS: upsample_trilinear3d_backward_out_mps - func: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12950,6 +12949,7 @@ dispatch: CPU: upsample_nearest3d_out_cpu CUDA: upsample_nearest3d_out_cuda + MPS: upsample_nearest3d_out_mps - func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -12957,6 +12957,7 @@ dispatch: CPU: _upsample_nearest_exact3d_out_cpu CUDA: _upsample_nearest_exact3d_out_cuda + MPS: _upsample_nearest_exact3d_out_mps - func: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12976,6 +12977,7 @@ dispatch: CPU: upsample_nearest3d_backward_out_cpu CUDA: upsample_nearest3d_backward_out_cuda + MPS: upsample_nearest3d_backward_out_mps - func: _upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -12983,6 +12985,7 @@ dispatch: CPU: _upsample_nearest_exact3d_backward_out_cpu CUDA: _upsample_nearest_exact3d_backward_out_cuda + MPS: _upsample_nearest_exact3d_backward_out_mps - func: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -13025,7 +13028,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: tanh_backward_out + CPU, CUDA, MTIA: tanh_backward_out MPS: tanh_backward_out_mps tags: pointwise @@ -13936,8 +13939,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA: linalg_cholesky_ex_out - MPS: linalg_cholesky_ex_out_mps + CPU, CUDA, MPS: linalg_cholesky_ex_out - func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor python_module: linalg @@ -15589,7 +15591,7 @@ - func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_t_out + CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_t_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15638,7 +15640,7 @@ - func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_u_out + CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_u_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15687,7 +15689,7 @@ - func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_v_out + CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_v_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15736,7 +15738,7 @@ - func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_w_out + CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_w_out python_module: special structured_inherits: TensorIteratorBase structured: True diff --git a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp index 15d4eeeddcccb5..1997ea1648352b 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp @@ -104,21 +104,6 @@ inline float dequantize_val(double scale, int64_t zero_point, T value) { } #else // USE_FBGEMM -#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) -template -inline float Round(const float x) { - return ::nearbyintf(x); -} -inline double Round(const double x) { - return ::nearbyint(x); -} -#else -template -inline T Round(const T x) { - return std::nearbyint(x); -} -#endif - template T quantize_val(double scale, int64_t zero_point, float value) { // std::nearbyint results in nearest integer value according to the current @@ -132,7 +117,7 @@ T quantize_val(double scale, int64_t zero_point, float value) { constexpr int64_t qmin = std::numeric_limits::min(); constexpr int64_t qmax = std::numeric_limits::max(); float inv_scale = 1.0f / static_cast(scale); - qvalue = static_cast(zero_point + Round(value * inv_scale)); + qvalue = static_cast(zero_point + std::nearbyint(value * inv_scale)); qvalue = std::max(qvalue, qmin); qvalue = std::min(qvalue, qmax); return static_cast(qvalue); @@ -147,7 +132,7 @@ T quantize_val_arm( constexpr int32_t qmax = std::numeric_limits::max(); float inv_scale = 1.0f / scale; #ifndef _MSC_VER - auto r = static_cast(Round(value * inv_scale)); + auto r = static_cast(std::nearbyint(value * inv_scale)); // builtin_add_overflow() returns true in case of overflow if (__builtin_add_overflow(zero_point, r, &r)) { // zero_point must be a non-negative value between qmin and qmax, @@ -155,7 +140,7 @@ T quantize_val_arm( r = qmax; } #else - auto r = zero_point + static_cast(Round(value * inv_scale)); + auto r = zero_point + static_cast(std::nearbyint(value * inv_scale)); #endif r = std::max(r, qmin); r = std::min(r, qmax); @@ -191,7 +176,7 @@ TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) { /* * Quantize value based on the following equation -* Xq = Round(Xf * inv_scale + zero_point) +* Xq = std::nearbyint(Xf * inv_scale + zero_point) * where zero_point is in float. * * Note: For the case of embedding quantization we will set zero_point diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index ab9f3f46330f1b..4ca777be9cd44e 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -81,8 +81,8 @@ std::vector quantize_per_tensor_list_cpu( for (const auto i : c10::irange(tensors.size())) { quantized_tensors.push_back(at::quantize_per_tensor( tensors[i], - scales[i].item(), - zero_points[i].item(), + scales[static_cast(i)].item(), + zero_points[static_cast(i)].item(), dtype)); } return quantized_tensors; @@ -293,18 +293,16 @@ std::tuple _choose_qparams_per_tensor( static float calculate_quant_loss( const float* input, - int numel, + int64_t numel, float xmin, float xmax, float* q_input, - int bit_width) { + int64_t bit_width) { xmin = static_cast(xmin); float data_range = xmax - xmin; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - float qmax = (1 << bit_width) - 1; + float qmax = static_cast((1 << bit_width) - 1); float scale = data_range == 0 - ? 1.0 - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? 1.0f : static_cast(static_cast(data_range / qmax)); float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale; @@ -347,10 +345,10 @@ std::tuple choose_qparams_optimized( const float* input_row = input_tensor.const_data_ptr(); float xmin = *std::min_element(input_row, input_row + numel); float xmax = *std::max_element(input_row, input_row + numel); + float n_bins_float = static_cast(n_bins); - float stepsize = (xmax - xmin) / n_bins; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - int min_bins = n_bins * (1.0 - (float) ratio); + float stepsize = (xmax - xmin) / n_bins_float; + float min_bins = static_cast(n_bins_float* (1.0 - ratio)); Tensor input_tensor_contig = input_tensor.contiguous(); const float* input = input_tensor_contig.const_data_ptr(); std::vector q_input(numel); @@ -363,7 +361,6 @@ std::tuple choose_qparams_optimized( float cur_max = xmax; float cur_loss = loss; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) float thr = min_bins * stepsize; while (cur_min + thr < cur_max) { // move left diff --git a/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp index f62db511b367c6..9233ea7003d45a 100644 --- a/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cpu/LinearUnpackImpl.cpp @@ -84,7 +84,7 @@ std::tuple> PackedLinearWeightsQnnp:: at::device(c10::kCPU).dtype(c10::kFloat)); at::Tensor zero_points = at::empty( - w_zero_points.size() - kPaddingChannels, at::device(c10::kCPU).dtype(c10::kLong)); + static_cast(w_zero_points.size() - kPaddingChannels), at::device(c10::kCPU).dtype(c10::kLong)); for (const auto i : c10::irange(zero_points.numel())) { zero_points[i] = ((int64_t)w_zero_points[i] - 128); } diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index 2fb60fd88b3c00..36f6140953f6a9 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -382,26 +382,11 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase { enum class Activation : uint8_t { NONE = 0, RELU = 1 }; -#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) -template -inline float Round(const float x) { - return ::nearbyintf(x); -} -inline double Round(const double x) { - return ::nearbyint(x); -} -#else -template -inline T Round(const T x) { - return std::nearbyint(x); -} -#endif - template inline T QuantizeValue(float scale, int32_t zero_point, float value) { const int32_t qmin = std::numeric_limits::min(); const int32_t qmax = std::numeric_limits::max(); - auto r = zero_point + static_cast(Round(value / scale)); + auto r = zero_point + static_cast(std::nearbyint(value / scale)); r = std::max(r, qmin); r = std::min(r, qmax); return static_cast(r); diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 1ab272e86c1533..b5b887b98bb08e 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -108,8 +108,7 @@ Tensor qcat_nhwc_kernel( const int64_t N = qx0.size(0); const int64_t H = qx0.size(2); const int64_t W = qx0.size(3); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float inv_scale = 1.0 / scale; + float inv_scale = static_cast(1.0 / scale); auto output = at::_empty_affine_quantized( {N, C_out, H, W}, @@ -1282,12 +1281,10 @@ void qelu_kernel( template void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) { int64_t zero_point = out.q_zero_point(); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float scale = out.q_scale(); - float inv_scale = 1.0f / scale; + float scale = static_cast(out.q_scale()); + float inv_scale = static_cast(1.0f / scale); int64_t self_zero_point = self.q_zero_point(); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float self_scale = self.q_scale(); + float self_scale = static_cast(self.q_scale()); float multiplier = self_scale * inv_scale; @@ -2699,10 +2696,11 @@ void _fake_quantize_tensor_helper( bool* mask_val = (bool*)(data[1] + i * strides[1]); scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]); - const auto qval = static_cast(z_point + std::nearbyint(*input_val * inv_scale)); if (fake_quant_on) { - *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc; - *mask_val = ((quant_min <= qval) && (qval <= quant_max)); + auto qval_f = z_point + std::nearbyint(*input_val * inv_scale); + const auto qval = static_cast(std::fmin(std::fmax(qval_f, quant_min), quant_max)); + *output_val = (qval - z_point) * sc; + *mask_val = ((quant_min <= qval_f) && (qval_f <= quant_max)); } else { *output_val = *input_val; *mask_val = 1; @@ -2718,10 +2716,11 @@ void _fake_quantize_tensor_helper( bool* mask_val = (bool*)(data[1] + i * strides[1]); scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]); - const auto qval = static_cast(z_point + std::nearbyint(*input_val * inv_scale)); if (fake_quant_on) { - *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc; - *mask_val = ((quant_min <= qval) && (qval <= quant_max)); + auto qval_f = z_point + std::nearbyint(*input_val * inv_scale); + const auto qval = static_cast(std::fmin(std::fmax(qval_f, quant_min), quant_max)); + *output_val = (qval - z_point) * sc; + *mask_val = ((quant_min <= qval_f) && (qval_f <= quant_max)); } else { *output_val = *input_val; *mask_val = 1; diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 10f34a685f3d6b..d2049a93672fe9 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,14 @@ #include // for quantize_per_te... #include #include +#include +#include +#include +#include +#include +#include +#include +#include #endif #include @@ -918,6 +927,118 @@ at::Tensor PackedLinearWeightsOnednn:: apply_tanh( std::move(input), output_scale, output_zero_point); } +static at::Tensor fp8_qlinear_onednn_ref( + at::Tensor input, + double input_scale, + at::Tensor weight, // expect plain weight + at::Tensor weight_scales, + std::optional bias, // plain tensor + double output_scale, + std::optional output_dtype, + std::optional other, // extra input for binary post-op + double other_scale, + const std::string_view& binary_post_op, // e.g. "none", "sum", "add" + double binary_alpha, + const std::string_view& unary_post_op, // e.g. "none", "relu" + torch::List>& unary_post_op_args, + std::string_view& unary_post_op_algorithm) { + TORCH_CHECK( + input.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn, + "FP8 qlinear: Unexpected dtype of input and weight:", input.scalar_type(), ", ", weight.scalar_type()); + const int64_t dim = input.dim(); + auto input_contig = + dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous(); + auto N = weight.size(0); + auto output_size = input.sizes().vec(); + output_size[dim - 1] = N; + auto dqx = input_contig.to(at::kFloat) * input_scale; + std::vector w_scales_new_shape(weight.dim(), 1); + w_scales_new_shape[0] = -1; + auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape); + auto y_f32 = at::linear(dqx, dqw, bias); + if (binary_post_op == "none") { + if (unary_post_op == "relu") { + at::relu_(y_f32); + } else if (unary_post_op == "leaky_relu") { + TORCH_CHECK( + unary_post_op_args.size() == 1, + "onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args"); + auto element = unary_post_op_args.get(0); + auto alpha = element.value().to(); + at::leaky_relu_(y_f32, alpha); + } else if (unary_post_op == "tanh") { + at::tanh_(y_f32); + } else if (unary_post_op == "gelu") { + TORCH_CHECK( + unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh", + "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm); + at::gelu_(y_f32, unary_post_op_algorithm); + } else if (unary_post_op == "hardtanh") { + TORCH_CHECK( + unary_post_op_args.size() == 2 && + unary_post_op_args.get(0).has_value() && + unary_post_op_args.get(1).has_value(), + "hardtanh is expected to have two scalar input: min_val and max_val"); + auto lower_bound_value = + unary_post_op_args.get(0).value().to(); + auto upper_bound_value = + unary_post_op_args.get(1).value().to(); + at::hardtanh_(y_f32, lower_bound_value, upper_bound_value); + } else if (unary_post_op == "hardswish") { + at::hardswish_(y_f32); + } else if (unary_post_op == "swish") { + // return ideep::attr_t::fuse_swish(); + y_f32 = y_f32 * at::sigmoid(y_f32); + } else { + TORCH_CHECK( + unary_post_op == "none", + "onednn qlinear: unsupported unary post op ", unary_post_op); + } + } else if (binary_post_op == "sum") { + TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum"); + auto x1 = other.value(); + TORCH_CHECK(x1.sizes().vec() == output_size); + auto x1_f32 = x1.to(at::kFloat) * other_scale; + x1_f32 = x1_f32.view(y_f32.sizes()); + if (unary_post_op == "none") { + y_f32.add_(x1_f32); + } else if (unary_post_op == "relu") { + y_f32.add_(x1_f32).relu_(); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum"); + } + y_f32.div_(output_scale); + x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes())); + return x1; + } else if (binary_post_op == "add") { + TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum"); + auto x1 = other.value(); + TORCH_CHECK(x1.sizes().vec() == output_size); + auto x1_f32 = x1.to(at::kFloat) * other_scale; + x1_f32 = x1_f32.view(y_f32.sizes()); + if (unary_post_op == "none") { + y_f32.add_(x1_f32); + } else if (unary_post_op == "relu") { + y_f32.add_(x1_f32).relu_(); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add"); + } + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported binary post op ", binary_post_op); + } + + y_f32.div_(output_scale); + y_f32 = y_f32.view(output_size); + auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn; + return y_f32.to(out_dtype); +} + static at::Tensor linear_int8_with_onednn_weight( at::Tensor input, // int8 CPU Tensor, not QTensor double input_scale, @@ -939,10 +1060,18 @@ static at::Tensor linear_int8_with_onednn_weight( std::string_view& unary_post_op_algorithm) { using ideep::tensor; const int64_t dim = input.dim(); - TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char, - "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char)."); - TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char, - "qlinear with mkldnn tensor: data type of weight should be int8 (char)."); + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char || input.scalar_type() == c10::ScalarType::Float8_e4m3fn, + "qlinear with mkldnn tensor: data type of input should be uint8, int8 or float8_e4m3fn."); + TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn, + "qlinear with mkldnn tensor: data type of weight should be int8 or float8_e4m3fn."); + bool is_fp8 = false; + if (input.scalar_type() == c10::ScalarType::Float8_e4m3fn || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn) { + TORCH_CHECK( + input.scalar_type() == c10::ScalarType::Float8_e4m3fn && onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn, + "qlinear with mkldnn tensor: data type of input and weight should be the same for fp8, but got ", + input.scalar_type(), " and ", onednn_weight.scalar_type()); + is_fp8 = true; + } TORCH_CHECK( weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float."); TORCH_CHECK( @@ -976,7 +1105,7 @@ static at::Tensor linear_int8_with_onednn_weight( ); } if (binary_post_op == "sum") { - auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte; + auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type(); TORCH_CHECK( other.value().scalar_type() == expected_dtype, "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype, @@ -984,6 +1113,14 @@ static at::Tensor linear_int8_with_onednn_weight( ); } } + if (is_fp8 && !cpuinfo_has_x86_amx_int8()) { + // Fall back to ref impl on old platforms because not supported + return fp8_qlinear_onednn_ref( + input, input_scale, onednn_weight, weight_scales, bias, + output_scale, output_dtype, other, other_scale, + binary_post_op, binary_alpha, unary_post_op, + unary_post_op_args, unary_post_op_algorithm); + } // If the input has more than two dimensions, we will reshape it to a 2-dimensional form // for calculation and subsequently reshape the output back. @@ -1016,7 +1153,7 @@ static at::Tensor linear_int8_with_onednn_weight( at::empty( dst_dims, at::device(c10::kCPU) - .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte)) + .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : input.scalar_type())) ); if (output.numel() == 0) { return output; @@ -1029,7 +1166,7 @@ static at::Tensor linear_int8_with_onednn_weight( empty_tensor; // Create onednn primitive - auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8; + auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type()); auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any); auto weights_desc = packed_weight.get_desc(); auto dst_dtype = dst.get_data_type(); @@ -1463,5 +1600,16 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor)); } +TORCH_LIBRARY_IMPL(onednn, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"), + TORCH_FN(QLinearOnednn::run_pointwise)); + m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"), + TORCH_FN(at::native::QLinearOnednn::run_pointwise_tensor)); + m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"), + TORCH_FN(QLinearOnednn::run_pointwise_binary)); + m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"), + TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor)); +} + } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 50af0862aef92e..d99a336bf3731f 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -297,14 +297,32 @@ c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( static inline at::Tensor pack_weight_to_onednn_tensor( const at::Tensor& weight, std::optional>& input_shape) { + at::ScalarType weigh_dtype = weight.scalar_type(); + TORCH_CHECK( + weigh_dtype == at::kChar || weigh_dtype == at::kFloat8_e4m3fn, + "Weight should be of type int8 or float8_e4m3fn"); + bool is_fp8 = weigh_dtype == at::kFloat8_e4m3fn; + if (is_fp8 && !cpuinfo_has_x86_amx_int8()) { + // oneDNN's fp8 requires AMX support + // If AMX is not available, fall back to reference implementation + return weight; + } std::vector w_dims = weight.sizes().vec(); - ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr()); + auto w_data_type = is_fp8 + ? dnnl::memory::data_type::f8_e4m3 + : dnnl::memory::data_type::s8; + ideep::tensor wei = ideep::tensor({w_dims, w_data_type}, weight.data_ptr()); wei.transpose_(0, 1); // oneDNN requires transposed weight ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims(); ideep::attr_t op_attr; - op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + if (!is_fp8) { + op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + } + auto x_data_type = is_fp8 + ? dnnl::memory::data_type::f8_e4m3 + : dnnl::memory::data_type::u8; auto w_desc = ideep::matmul_forward::expected_weights_desc( - wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr); + wei.get_dims(), input_dims, w_data_type, x_data_type, op_attr); ideep::tensor expected_weight(w_desc); expected_weight.feed_from(wei); auto packed_weight = at::native::new_with_itensor_mkldnn( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt index af31056c443015..979af29f723dbb 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) include(GNUInstallDirs) diff --git a/aten/src/ATen/native/quantized/cuda/FusedObsFakeQuant.cu b/aten/src/ATen/native/quantized/cuda/FusedObsFakeQuant.cu index 5e5a2458cfca8b..293fd600eaee33 100644 --- a/aten/src/ATen/native/quantized/cuda/FusedObsFakeQuant.cu +++ b/aten/src/ATen/native/quantized/cuda/FusedObsFakeQuant.cu @@ -273,7 +273,7 @@ std::tuple fused_moving_avg_obs_fake_quant_cuda( } _calculate_moving_average( y, - observer_on, + observer_on.to(at::kLong), running_min, running_max, averaging_const, @@ -282,7 +282,7 @@ std::tuple fused_moving_avg_obs_fake_quant_cuda( } else { _calculate_moving_average( x_contig, - observer_on, + observer_on.to(at::kLong), running_min, running_max, averaging_const, @@ -295,7 +295,7 @@ std::tuple fused_moving_avg_obs_fake_quant_cuda( _calc_moving_avg_qparams_helper( x_contig, - fake_quant_on, + fake_quant_on.to(at::kLong), running_min, running_max, scale_ptr, @@ -316,7 +316,7 @@ std::tuple fused_moving_avg_obs_fake_quant_cuda( } } else { return at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams( - x, scale, zero_point, fake_quant_on, qmin, qmax); + x, scale, zero_point, fake_quant_on.to(at::kLong), qmin, qmax); } } } // namespace at::native diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index eda5d7c2696b4b..e359da0af3eef5 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -125,7 +125,7 @@ bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& st formats with support to batched and dense dimensions. */ -static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) { +static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout, std::optional check_pinning_) { // Layout must be Sparse Compressed, 2.4 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{}); @@ -134,6 +134,7 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres const std::string plain_indices_name = plainIndicesName(layout); const std::string compressed_dim_name = compressedDimName(layout); const std::string plain_dim_name = plainDimName(layout); + const bool check_pinning = check_pinning_.value_or(true); // Layout Invariants @@ -295,20 +296,22 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres ") must match device of ", plain_indices_name, " (=", plain_indices.device(), ")"); - TORCH_CHECK( + if (check_pinning) { + TORCH_CHECK( compressed_indices.is_pinned() == values.is_pinned(), "memory pinning of ", compressed_indices_name, " (=", compressed_indices.is_pinned(), ") must match memory pinning of values (=", values.is_pinned(), ")"); - TORCH_CHECK( + TORCH_CHECK( compressed_indices.is_pinned() == plain_indices.is_pinned(), "memory pinning of ", compressed_indices_name, " (=", compressed_indices.is_pinned(), ") must match memory pinning of ", plain_indices_name, " (=", plain_indices.is_pinned(), ")"); + } // Autograd Invariants // @@ -319,24 +322,24 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres TORCH_INTERNAL_ASSERT(!plain_indices.requires_grad()); } -void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) { - _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout); +void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout, std::optional check_pinning) { + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout, check_pinning); } -void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) { - _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr); +void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional check_pinning) { + _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr, check_pinning); } -void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) { - _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc); +void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional check_pinning) { + _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc, check_pinning); } -void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) { - _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr); +void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional check_pinning) { + _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr, check_pinning); } -void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) { - _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc); +void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional check_pinning) { + _validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc, check_pinning); } // Construction of CSR, CSC, BSR, and BSC tensors. @@ -467,7 +470,7 @@ Tensor _sparse_compressed_tensor_unsafe_symint( Layout layout_ = layout.value(); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{}); if (at::globalContext().checkSparseTensorInvariants()) { - _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_); + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_, true); } TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); SparseCsrTensor self = new_compressed_tensor(options); @@ -491,7 +494,7 @@ static Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed Layout layout_ = layout.value_or(required_layout); TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_); if (at::globalContext().checkSparseTensorInvariants()) { - _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_); + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_, true); } TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); SparseCsrTensor self = new_compressed_tensor(options); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index db553314f5122b..b63d8ae80e50b8 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -371,9 +371,11 @@ void _validate_sparse_coo_tensor_args( const Tensor& indices, const Tensor& values_, ArrayRef size, - std::optional is_coalesced_) { + std::optional is_coalesced_, + std::optional check_pinning_) { Tensor values = expand_values_if_needed(values_); bool is_coalesced = is_coalesced_.value_or(false); + const bool check_pinning = check_pinning_.value_or(true); // the following checks are redundant because they are also checked in // SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them @@ -397,13 +399,15 @@ void _validate_sparse_coo_tensor_args( "), but got ", size.size()); - TORCH_CHECK( - indices.is_pinned() == values.is_pinned(), - "memory pinning of indices (=", - indices.is_pinned(), - ") must match memory pinning of values (=", - values.is_pinned(), - ")"); + if (check_pinning) { + TORCH_CHECK( + indices.is_pinned() == values.is_pinned(), + "memory pinning of indices (=", + indices.is_pinned(), + ") must match memory pinning of values (=", + values.is_pinned(), + ")"); + } // Check to make sure all indices are within the boundaries of `size` if (indices.numel() > 0) { diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index 133a73505dcf06..582778fdc299d5 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -7,67 +7,7 @@ #include -// LIMITATION (cusparseSpMM): -// The generic APIs are available on all platforms on CUDA 11.0 -// For CUDA 10.1+ it is available for all platforms except Windows. -// Using these APIs in any other systems will result in compile-time or run-time failures. -// Their support will be extended in the next releases. - -#if defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301)) -#define IS_SPMM_AVAILABLE() 1 -#else -#define IS_SPMM_AVAILABLE() 0 -#endif - -#if defined(USE_ROCM) -#define IS_SPMM_HIP_AVAILABLE() 1 -#else -#define IS_SPMM_HIP_AVAILABLE() 0 -#endif - -#if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE() #include -#endif - -#if !defined(CUSPARSE_VERSION) || (CUSPARSE_VERSION < 10100) -const char* cusparseGetErrorString(cusparseStatus_t status) { - switch(status) - { - case CUSPARSE_STATUS_SUCCESS: - return "success"; - - case CUSPARSE_STATUS_NOT_INITIALIZED: - return "library not initialized"; - - case CUSPARSE_STATUS_ALLOC_FAILED: - return "resource allocation failed"; - - case CUSPARSE_STATUS_INVALID_VALUE: - return "an invalid numeric value was used as an argument"; - - case CUSPARSE_STATUS_ARCH_MISMATCH: - return "an absent device architectural feature is required"; - - case CUSPARSE_STATUS_MAPPING_ERROR: - return "an access to GPU memory space failed"; - - case CUSPARSE_STATUS_EXECUTION_FAILED: - return "the GPU program failed to execute"; - - case CUSPARSE_STATUS_INTERNAL_ERROR: - return "an internal operation failed"; - - case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: - return "the matrix type is not supported by this function"; - - case CUSPARSE_STATUS_ZERO_PIVOT: - return "an entry of the matrix is either structural zero or numerical zero (singular block)"; - - default: - return "unknown error"; - } -} -#endif namespace at::native::sparse::cuda { @@ -92,8 +32,6 @@ cusparseOperation_t convertTransToCusparseOperation(char trans) { } } -#if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE() - namespace { template void _csrmm2( @@ -259,211 +197,6 @@ template<> void csrmm2>( reinterpret_cast(c), ldc, CUDA_C_64F); } -#else - -void adjustLd(char transb, int64_t m, int64_t n, int64_t k, int64_t *ldb, int64_t *ldc) -{ - int transb_ = ((transb == 't') || (transb == 'T')); - - if(n == 1) - *ldc = m; - - if(transb_) - { - if(k == 1) - *ldb = n; - } - else - { - if(n == 1) - *ldb = k; - } -} - -void Scsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const float *alpha, const float *csrvala, int *csrrowptra, int *csrcolinda, const float *b, int64_t ldb, const float *beta, float *c, int64_t ldc) -{ - adjustLd(transb, m, n, k, &ldb, &ldc); - cusparseOperation_t opa = convertTransToCusparseOperation(transa); - cusparseOperation_t opb = convertTransToCusparseOperation(transb); - - TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX), - "cusparseScsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_nnz = (int)nnz; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; - - auto handle = at::cuda::getCurrentCUDASparseHandle(); - cusparseMatDescr_t desc; - cusparseCreateMatDescr(&desc); - TORCH_CUDASPARSE_CHECK(cusparseScsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc)); - TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc)); -} - -void Dcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const double *alpha, const double *csrvala, int *csrrowptra, int *csrcolinda, const double *b, int64_t ldb, const double *beta, double *c, int64_t ldc) -{ - adjustLd(transb, m, n, k, &ldb, &ldc); - cusparseOperation_t opa = convertTransToCusparseOperation(transa); - cusparseOperation_t opb = convertTransToCusparseOperation(transb); - - TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX), - "cusparseDcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_nnz = (int)nnz; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; - - - auto handle = at::cuda::getCurrentCUDASparseHandle(); - cusparseMatDescr_t desc; - cusparseCreateMatDescr(&desc); - TORCH_CUDASPARSE_CHECK(cusparseDcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc)); - TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc)); - // TODO: Proper fix is to create real descriptor classes -} - -template -void Ccsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc) -{ - adjustLd(transb, m, n, k, &ldb, &ldc); - cusparseOperation_t opa = convertTransToCusparseOperation(transa); - cusparseOperation_t opb = convertTransToCusparseOperation(transb); - - TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX), - "cusparseCcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_nnz = (int)nnz; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; - - auto handle = at::cuda::getCurrentCUDASparseHandle(); - cusparseMatDescr_t desc; - cusparseCreateMatDescr(&desc); - TORCH_CUDASPARSE_CHECK(cusparseCcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc)); - TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc)); -} - -template -void Zcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc) -{ - adjustLd(transb, m, n, k, &ldb, &ldc); - cusparseOperation_t opa = convertTransToCusparseOperation(transa); - cusparseOperation_t opb = convertTransToCusparseOperation(transb); - - TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX), - "cusparseZcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_nnz = (int)nnz; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; - - - auto handle = at::cuda::getCurrentCUDASparseHandle(); - cusparseMatDescr_t desc; - cusparseCreateMatDescr(&desc); - TORCH_CUDASPARSE_CHECK(cusparseZcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc)); - TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc)); -} - -// T can only be float or double -template -void csrmm2( - char transa, char transb, - int64_t m, int64_t n, int64_t k, int64_t nnz, - T alpha, T *csrvala, int *csrrowptra, int *csrcolinda, - T *b, int64_t ldb, T beta, T *c, int64_t ldc) -{ - static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); -} - -template<> void csrmm2( - char transa, char transb, - int64_t m, int64_t n, int64_t k, int64_t nnz, - float alpha, float *csrvala, int *csrrowptra, int *csrcolinda, - float *b, int64_t ldb, float beta, float *c, int64_t ldc) -{ - Scsrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc); -} - -template<> void csrmm2( - char transa, char transb, - int64_t m, int64_t n, int64_t k, int64_t nnz, - double alpha, double *csrvala, int *csrrowptra, int *csrcolinda, - double *b, int64_t ldb, double beta, double *c, int64_t ldc) -{ - Dcsrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc); -} - -template<> void csrmm2>( - char transa, char transb, - int64_t m, int64_t n, int64_t k, int64_t nnz, - c10::complex alpha, c10::complex *csrvala, int *csrrowptra, int *csrcolinda, - c10::complex *b, int64_t ldb, c10::complex beta, c10::complex *c, int64_t ldc) -{ - - #ifdef USE_ROCM - Ccsrmm2(transa, transb, m, n, k, nnz, - reinterpret_cast(&alpha), - reinterpret_cast(csrvala), - csrrowptra, - csrcolinda, - reinterpret_cast(b), - ldb, - reinterpret_cast(&beta), - reinterpret_cast(c), ldc); - #else - Ccsrmm2(transa, transb, m, n, k, nnz, - reinterpret_cast(&alpha), - reinterpret_cast(csrvala), - csrrowptra, - csrcolinda, - reinterpret_cast(b), - ldb, - reinterpret_cast(&beta), - reinterpret_cast(c), ldc); - #endif -} - -template<> void csrmm2>( - char transa, char transb, - int64_t m, int64_t n, int64_t k, int64_t nnz, - c10::complex alpha, c10::complex *csrvala, int *csrrowptra, int *csrcolinda, - c10::complex *b, int64_t ldb, c10::complex beta, c10::complex *c, int64_t ldc) -{ - #ifdef USE_ROCM - Zcsrmm2(transa, transb, m, n, k, nnz, - reinterpret_cast(&alpha), - reinterpret_cast(csrvala), - csrrowptra, - csrcolinda, - reinterpret_cast(b), - ldb, - reinterpret_cast(&beta), - reinterpret_cast(c), ldc); - #else - Zcsrmm2(transa, transb, m, n, k, nnz, - reinterpret_cast(&alpha), - reinterpret_cast(csrvala), - csrrowptra, - csrcolinda, - reinterpret_cast(b), - ldb, - reinterpret_cast(&beta), - reinterpret_cast(c), ldc); - #endif -} - - -#endif - /* format conversion */ void CreateIdentityPermutation(int64_t nnz, int *P) { TORCH_CHECK((nnz <= INT_MAX), diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index 5f6633593ed71a..de73ce612f10a7 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -1,5 +1,7 @@ #include - +#include +#include +#include #if AT_CUSPARSELT_ENABLED() namespace at::native { @@ -15,6 +17,45 @@ namespace at::native { thread_local cusparseLtHandle_t handle; thread_local bool handle_initialized = false; +#ifdef USE_ROCM +// Single global flag for platform-wide hipSparseLt support +c10::once_flag g_hipSparseLtSupportInitFlag; +static bool g_hipSparseLtSupported = false; + +// Initialize the hipSparseLt support status once for the platform +static void initHipSparseLtSupport() { + // Default to not supported + g_hipSparseLtSupported = false; + + // Check only the first available device + try { + if (at::cuda::device_count() > 0) { + g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); + } + } catch (const std::exception&) { + // If an exception occurs during device property check, we assume hipSparseLt is not supported + // This could happen due to driver issues, device access problems, or other runtime errors + g_hipSparseLtSupported = false; + TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported."); + } +} + +static bool isHipSparseLtSupported() { + // Initialize support check only once + c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport); + + // Return cached result (platform-wide) + if (!g_hipSparseLtSupported) { + TORCH_CHECK( + false, + "hipSparseLt not supported on this device, supported architectures: " + "gfx950, gfx942. " + "required ROCM version: 6.4.0 or later."); + } + return g_hipSparseLtSupported; +} +#endif + at::Tensor _cslt_compress(const Tensor& sparse_input) { if (!handle_initialized) { TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); @@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { cudaDataType type; auto compression_factor = 9; + #ifdef USE_ROCM + TORCH_CHECK(isHipSparseLtSupported()); + #endif + switch (sparse_input.scalar_type()) { case at::ScalarType::Char: type = CUDA_R_8I; @@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { case at::ScalarType::BFloat16: type = CUDA_R_16BF; break; +#ifndef USE_ROCM case at::ScalarType::Float: type = CUDA_R_32F; break; -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#endif +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: type = CUDA_R_8F_E4M3; compression_factor = 10; break; #endif default: - TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); + TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix"); break; } @@ -120,6 +167,10 @@ std::tuple _cslt_sparse_mm_impl( cusparseComputeType compute_type; auto compression_factor = 9; + #ifdef USE_ROCM + TORCH_CHECK(isHipSparseLtSupported()); + #endif + switch (compressed_A.scalar_type()) { case at::ScalarType::Char: input_type = CUDA_R_8I; @@ -131,7 +182,7 @@ std::tuple _cslt_sparse_mm_impl( // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F // to CUSPARSE_COMPUTE_32F -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM) case at::ScalarType::Half: input_type = CUDA_R_16F; output_type = CUDA_R_16F; @@ -144,14 +195,16 @@ std::tuple _cslt_sparse_mm_impl( C_type = CUDA_R_16BF; compute_type = CUSPARSE_COMPUTE_32F; break; +#ifndef USE_ROCM case at::ScalarType::Float: input_type = CUDA_R_32F; output_type = CUDA_R_32F; C_type = CUDA_R_32F; compute_type = CUSPARSE_COMPUTE_32F; break; +#endif // if cuSPARSELt >= 6.2.3, we can add Float8 support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: input_type = CUDA_R_8F_E4M3; output_type = CUDA_R_8F_E4M3; @@ -214,7 +267,7 @@ std::tuple _cslt_sparse_mm_impl( } } // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) else if (input_type == CUDA_R_8F_E4M3) { switch (out_dtype) { case at::ScalarType::Float8_e4m3fn: diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 332d4a2ebfe97d..8647a199ad8e93 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -776,7 +776,7 @@ Tensor scaled_dot_product_attention( #ifdef USE_MPS const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); - const auto all_contiguous = query_.is_contiguous() && key.is_contiguous() && value.is_contiguous(); + const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false(); if (query_device_type == DeviceType::MPS && dropout_p == 0.0 && !(GradMode::is_enabled() && any_inputs_require_grad) && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index affca278ad1a00..80049aa9a832f3 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -968,8 +968,8 @@ std::tuple _scaled_dot_product_efficient_attenti int64_t batch_size = query.size(0); if (batch_size > MAX_BATCH_SIZE) { - TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0), - "Efficient attention cannot produce valid seed, logsumexp and offset outputs when " + TORCH_CHECK(dropout_p == 0.0, + "Efficient attention cannot produce valid seed and offset outputs when " "the batch size exceeds (", MAX_BATCH_SIZE, ")."); } auto process_chunk = [&](const Tensor& q_chunk, @@ -1030,6 +1030,17 @@ std::tuple _scaled_dot_product_efficient_attenti } Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options()); final_attention.slice(0, start, end).copy_(attn); + Tensor final_log_sumexp; + if (compute_log_sumexp && log_sumexp.numel() > 0) { + std::vector lse_sizes; + lse_sizes.reserve(log_sumexp.dim()); + lse_sizes.push_back(batch_size); + for (int i = 1; i < log_sumexp.dim(); i++) { + lse_sizes.push_back(log_sumexp.size(i)); + } + final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options()); + final_log_sumexp.slice(0, start, end).copy_(log_sumexp); + } for (start = end; start < batch_size; start += MAX_BATCH_SIZE) { end = std::min(start + MAX_BATCH_SIZE, batch_size); @@ -1045,10 +1056,13 @@ std::tuple _scaled_dot_product_efficient_attenti auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] = process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk); final_attention.slice(0, start, end).copy_(chunk_attn); + if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) { + final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp); + } } return std::make_tuple(std::move(final_attention), - std::move(log_sumexp), + std::move(final_log_sumexp), std::move(seed), std::move(offset)); } @@ -1099,8 +1113,10 @@ _flash_attention_forward( std::optional alibi_slopes = _alibi_slopes; const float softcap = 0.0; - const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; - const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; +#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly. + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); +#endif // We are going to have two paths: // 1. The standard MHA path for dense tensors @@ -1137,8 +1153,13 @@ _flash_attention_forward( softmax_scale, false /*zero_tensors*/, is_causal, +#ifdef USE_ROCM + window_size_left, + window_size_right, +#else non_null_window_left, non_null_window_right, +#endif softcap, return_debug_mask, std::nullopt /*gen_*/); @@ -1161,8 +1182,13 @@ _flash_attention_forward( dropout_p, softmax_scale, is_causal, +#ifdef USE_ROCM + window_size_left, + window_size_right, +#else non_null_window_left, non_null_window_right, +#endif softcap, return_debug_mask, /*return_softmax (this is used for testing)*/ std::nullopt); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index a5d976447cae5e..8940bea9a27fcc 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -24,6 +24,8 @@ #include #include #else +#include +#include #include #include #include @@ -85,8 +87,10 @@ std::tuple _flash_attention_backward( auto contiguous_grad_out = grad_out.contiguous(); auto contiguous_out = out.contiguous(); +#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly. const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1; const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1; +#endif std::optional dq{std::nullopt}; std::optional dk{std::nullopt}; @@ -134,8 +138,13 @@ std::tuple _flash_attention_backward( softmax_scale, false /*zero_tensors*/, is_causal, +#ifdef USE_ROCM + window_size_left, + window_size_right, +#else non_null_window_left, non_null_window_right, +#endif softcap, determinisitic, philox_seed, @@ -157,8 +166,13 @@ std::tuple _flash_attention_backward( dropout_p, softmax_scale, is_causal, +#ifdef USE_ROCM + window_size_left, + window_size_right, +#else non_null_window_left, non_null_window_right, +#endif softcap, determinisitic, philox_seed, @@ -748,7 +762,7 @@ _efficient_attention_backward( // when we need a staging area for gK/gV. let's avoid that if (Kernel::kNeedsAccumGradK || Kernel::kNeedsAccumGradV) { p.num_splits_key = std::min( - int(p.num_splits_key), 200 / (p.num_batches * p.num_heads)); + int32_t(p.num_splits_key), 200 / ((int32_t)(p.num_batches * p.num_heads))); } } if (!Kernel::kEnableSplitKeys || p.num_splits_key < 1) { @@ -905,40 +919,56 @@ std::tuple _scaled_dot_product_e if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); } - auto grad_out = grad_out_.transpose(1, 2); + constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1; + int64_t batch_size = query.size(0); + + if (batch_size > MAX_BATCH_SIZE) { + TORCH_CHECK(dropout_p == 0.0, + "Efficient attention backward cannot handle dropout when " + "the batch size exceeds (", MAX_BATCH_SIZE, ")."); + } + auto grad_out_t = grad_out_.transpose(1, 2); + auto query_t = query.transpose(1, 2); + auto key_t = key.transpose(1, 2); + auto value_t = value.transpose(1, 2); auto out_t = out.transpose(1, 2); - auto q_t = query.transpose(1, 2); - auto k_t = key.transpose(1, 2); - auto v_t = value.transpose(1, 2); + auto process_chunk = [&](const Tensor& grad_out_chunk, + const Tensor& query_chunk, + const Tensor& key_chunk, + const Tensor& value_chunk, + const std::optional& attn_bias_chunk, + const Tensor& out_chunk, + const Tensor& logsumexp_chunk) + -> std::tuple { // This is needed because SaveVariable automatically converts // std::optional to undefined tensor std::optional kernel_bias; - if (attn_bias.defined()) { - kernel_bias = attn_bias; + if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) { + kernel_bias = attn_bias_chunk.value(); } // Will add with signauter changes for dropout and bias // We are only handling Dense inputs, but this should be passed // from forward to backward - int64_t max_seqlen_q = q_t.size(1); - int64_t max_seqlen_k = k_t.size(1); + int64_t max_seqlen_q = query_chunk.size(2); + int64_t max_seqlen_k = key_chunk.size(2); sdp::CustomMaskType custom_mask_type = causal ? sdp::CustomMaskType::CausalFromTopLeft : sdp::CustomMaskType::NoCustomMask; auto [grad_q, grad_k, grad_v, grad_bias] = at::_efficient_attention_backward( - grad_out, - q_t, - k_t, - v_t, + grad_out_chunk, + query_chunk, + key_chunk, + value_chunk, kernel_bias, - out_t, + out_chunk, std::nullopt, std::nullopt, max_seqlen_q, max_seqlen_k, - logsumexp, + logsumexp_chunk, dropout_p, philox_seed, philox_offset, @@ -947,7 +977,90 @@ std::tuple _scaled_dot_product_e scale, std::nullopt); // num_split_keys return std::make_tuple( - grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); + grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias)); + }; + + // process in chunks if batch size exceeds maximum + if (batch_size > MAX_BATCH_SIZE) { + Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias; + + auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor { + if (!tensor.defined()) { + return Tensor{}; + } + int dim = tensor.dim(); + std::vector sizes; + sizes.reserve(dim); + sizes.push_back(batch_size); + for (int i = 1; i < dim; i++) { + sizes.push_back(tensor.size(i)); + } + return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options()); + }; + + if (grad_input_mask[0]) { + final_grad_q = create_strided_output(query); + } + + if (grad_input_mask[1]) { + final_grad_k = create_strided_output(key); + } + + if (grad_input_mask[2]) { + final_grad_v = create_strided_output(value); + } + if (grad_input_mask[3] && attn_bias.defined()) { + final_grad_bias = at::zeros_like(attn_bias); + } + + for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) { + int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size); + + Tensor grad_out_chunk = grad_out_t.slice(0, start, end); + Tensor query_chunk = query_t.slice(0, start, end); + Tensor key_chunk = key_t.slice(0, start, end); + Tensor value_chunk = value_t.slice(0, start, end); + Tensor attn_bias_chunk; + if (attn_bias.defined()) { + attn_bias_chunk = attn_bias.slice(0, start, end); + } else { + attn_bias_chunk.reset(); + } + Tensor out_chunk = out_t.slice(0, start, end); + Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp; + + auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] = + process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk, + attn_bias_chunk, out_chunk, logsumexp_chunk); + + if (grad_input_mask[0] && chunk_grad_q.defined()) { + final_grad_q.slice(0, start, end).copy_(chunk_grad_q); + } + if (grad_input_mask[1] && chunk_grad_k.defined()) { + final_grad_k.slice(0, start, end).copy_(chunk_grad_k); + } + if (grad_input_mask[2] && chunk_grad_v.defined()) { + final_grad_v.slice(0, start, end).copy_(chunk_grad_v); + } + if (grad_input_mask[3] && chunk_grad_bias.defined()) { + final_grad_bias.add_(chunk_grad_bias); + } + } + + return std::make_tuple( + std::move(final_grad_q), + std::move(final_grad_k), + std::move(final_grad_v), + std::move(final_grad_bias)); + } + // when batch size is within allowed size, no chunking needed + else { + std::optional attn_bias_opt; + if (attn_bias.defined()) { + attn_bias_opt = attn_bias; + } + return process_chunk(grad_out_t, query_t, key_t, value_t, attn_bias_opt, out_t, logsumexp); + } } } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index 4f6267f3f3c723..ae649e99c4cd80 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -666,12 +666,12 @@ struct AttentionBackwardKernel { int32_t num_heads = -1; uint8_t custom_mask_type = NoCustomMask; - int32_t q_strideM = -1; - int32_t k_strideM = -1; - int32_t v_strideM = -1; - int32_t bias_strideM = 0; - int32_t gO_strideM = -1; - int32_t gB_strideM = -1; + int64_t q_strideM = -1; + int64_t k_strideM = -1; + int64_t v_strideM = -1; + int64_t bias_strideM = 0; + int64_t gO_strideM = -1; + int64_t gB_strideM = -1; int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise at::PhiloxCudaState rng_engine_inputs = {0, 0}; @@ -680,16 +680,16 @@ struct AttentionBackwardKernel { unsigned long long dropout_batch_head_rng_offset = 0; float dropout_prob = 0.0f; - CUTLASS_HOST_DEVICE int32_t o_strideM() const { + CUTLASS_HOST_DEVICE int64_t o_strideM() const { return head_dim_value * num_heads; } - CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + CUTLASS_HOST_DEVICE int64_t gQ_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim; } - CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + CUTLASS_HOST_DEVICE int64_t gK_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim; } - CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + CUTLASS_HOST_DEVICE int64_t gV_strideM() const { return gQKV_strideM_multiplier * num_heads * head_dim_value; } @@ -858,14 +858,14 @@ struct AttentionBackwardKernel { return 0; } return num_splits_key * kBlockSizeJ * - align_up(head_dim, (int32_t)kBlockSizeI); + align_up(head_dim, kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { if (!kNeedsAccumGradV) { return 0; } return num_splits_key * kBlockSizeJ * - align_up(head_dim_value, (int32_t)kBlockSizeI); + align_up(head_dim_value, kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { if (!kNeedsAccumGradQ) { @@ -1307,12 +1307,12 @@ struct AttentionBackwardKernel { uint8_t warp_id = warp_uniform(thread_id / 32); uint8_t lane_id = thread_id % 32; - int32_t key_start = p.split_key_device() * kBlockSizeJ; + int64_t key_start = p.split_key_device() * kBlockSizeJ; if (key_start >= p.num_keys) { return; } if (kPrologueQK) { - int32_t query_start = getQueryStart(p, key_start); + int64_t query_start = getQueryStart(p, key_start); prologueQkNextIteration( shared_storage, p, query_start, key_start, warp_id, lane_id); } @@ -1362,8 +1362,8 @@ struct AttentionBackwardKernel { key_start += p.num_splits_key_device() * kBlockSizeJ) { output_frags.clear(); - int32_t next_key = key_start; - int32_t query_start = getQueryStart(p, key_start); + int64_t next_key = key_start; + int64_t query_start = getQueryStart(p, key_start); while (next_key == key_start && query_start < p.num_queries) { // This line here // vvvvvvvvvvvvvv @@ -1385,7 +1385,7 @@ struct AttentionBackwardKernel { warp_id, lane_id); - int32_t next_query; + int64_t next_query; incrIteration(p, query_start, key_start, next_query, next_key); query_start = next_query; } @@ -1439,8 +1439,8 @@ struct AttentionBackwardKernel { SharedStorage& shared_storage, OutputFragments& output_frags, Params& p, - int32_t query_start, - int32_t key_start, + int64_t query_start, + int64_t key_start, const curandStatePhilox4_32_10_t& curand_state_init, uint8_t warp_id, uint8_t lane_id) { @@ -1463,7 +1463,7 @@ struct AttentionBackwardKernel { }; bool isFirstQuery = (query_start == getQueryStart(p, key_start)); - int32_t next_query, next_key; + int64_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); bool isLastQuery = next_key != key_start; @@ -1478,17 +1478,17 @@ struct AttentionBackwardKernel { int32_t num_queries_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kN : warp_uniform(cutlass::fast_min( - (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); + MatmulQK::Mma::Shape::kN, (int32_t)(p.num_queries - query_start))); int32_t num_keys_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kM : warp_uniform(cutlass::fast_min( - (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); + MatmulQK::Mma::Shape::kM, (int32_t)(p.num_keys - key_start))); - auto prologueGradV = [&](int col) { + auto prologueGradV = [&](int64_t col) { typename MatmulGradV::Mma::IteratorB iterator_dO( {int32_t(p.gO_strideM)}, const_cast(p.grad_output_ptr + query_start * p.gO_strideM + col), - {num_queries_in_block, p.head_dim_value - col}, + {num_queries_in_block, (int32_t)(p.head_dim_value - col)}, thread_id, no_offset); MatmulGradV::Mma::prologue( @@ -1709,7 +1709,7 @@ struct AttentionBackwardKernel { // on the K-dimension) otherwise we can get NaNs during the GEMM const int kQueriesPerBlock = kBlockSizeI; const int threads_per_row = cutlass::fast_min( - int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); + kNumThreads / kQueriesPerBlock, (int64_t)num_keys_in_block); const int elts_per_thread = cutlass::round_nearest( cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); @@ -1779,7 +1779,7 @@ struct AttentionBackwardKernel { ///////////////////////////////////////////////////////////////////////////////////////////////// constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN; - for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); + for (int32_t col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); col += MatmulGradV::ThreadblockShape::kN) { using Mma = typename MatmulGradV::Mma; using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; @@ -2232,7 +2232,7 @@ struct AttentionBackwardKernel { } if (kPrologueQK && isLastColumn) { - int32_t next_query, next_key; + int64_t next_query, next_key; incrIteration(p, query_start, key_start, next_query, next_key); DISPATCH_BOOL( next_key != key_start, kForceReloadK, ([&]() { @@ -2260,7 +2260,7 @@ struct AttentionBackwardKernel { } } - static CUTLASS_HOST_DEVICE int32_t getQueryStartShift(Params const& p) { + static CUTLASS_HOST_DEVICE int64_t getQueryStartShift(Params const& p) { if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); } @@ -2268,27 +2268,27 @@ struct AttentionBackwardKernel { } // Iteration order logic - static CUTLASS_HOST_DEVICE int32_t - getQueryStart(Params const& p, int32_t key_start) { + static CUTLASS_HOST_DEVICE int64_t + getQueryStart(Params const& p, int64_t key_start) { return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); }; - static CUTLASS_HOST_DEVICE int32_t getQueryEnd(Params const& p) { + static CUTLASS_HOST_DEVICE int64_t getQueryEnd(Params const& p) { return align_up(p.num_queries, kBlockSizeI); }; - static CUTLASS_HOST_DEVICE int32_t - getSmallestQueryForKey(Params const& p, int32_t key_start) { + static CUTLASS_HOST_DEVICE int64_t + getSmallestQueryForKey(Params const& p, int64_t key_start) { if (p.custom_mask_type == NoCustomMask) { return 0; } - int32_t shift = p.custom_mask_type == CausalFromBottomRight + int64_t shift = p.custom_mask_type == CausalFromBottomRight ? p.num_keys - p.num_queries : 0; - int32_t window_size = + int64_t window_size = p.window_size == 0 ? p.num_queries + p.num_keys : p.window_size; auto last_key_for_block = - cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; int first_query = key_start - shift; int last_query = last_key_for_block - shift + window_size - 1; if (last_query < 0 || first_query >= p.num_queries) { @@ -2333,10 +2333,10 @@ struct AttentionBackwardKernel { // Returns the next block to process static CUTLASS_HOST_DEVICE void incrIteration( Params const& p, - int32_t query_start, - int32_t key_start, - int32_t& next_query, - int32_t& next_key) { + int64_t query_start, + int64_t key_start, + int64_t& next_query, + int64_t& next_key) { next_query = query_start + kBlockSizeI; next_key = key_start; auto query_shift = getQueryStartShift(p); @@ -2357,7 +2357,7 @@ struct AttentionBackwardKernel { : 0; // last key that is not masked out int last_key_for_block = - cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; int last_query = last_key_for_block - shift + p.window_size - 1; if (next_query <= last_query && next_query < p.num_queries) { return; @@ -2368,7 +2368,7 @@ struct AttentionBackwardKernel { // jump to next key } // Next key - next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + next_key = key_start + p.num_splits_key_device() * (int64_t)kBlockSizeJ; next_query = getQueryStart(p, next_key); } @@ -2422,7 +2422,7 @@ struct AttentionBackwardKernel { int32_t num_keys_in_block = skipBoundsChecks ? MatmulQK::Mma::Shape::kM : cutlass::fast_min( - (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, p.grad_value_ptr + key_start * p.gV_strideM(), diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 92a3dabbd7d5bd..1908096e2f6fad 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -64,8 +64,14 @@ #include #include -#if AOTRITON_VERSION_MINOR != 9 -#error "This adaptor code is only tested with AOTriton 0.9.x" +#if AOTRITON_VERSION_MINOR < 9 +#error "This adaptor code is only tested with AOTriton >= 0.9" +#endif + +#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10 +#define V3_API 1 +#else +#define V3_API 0 #endif namespace pytorch_flash { @@ -81,6 +87,38 @@ void check_gpu_arch(hipStream_t stream) { } } +std::tuple +calculate_swa(std::optional window_size_left, + std::optional window_size_right, + int max_seqlen_q, + int max_seqlen_k, + bool is_causal) { +#if V3_API // SWA is exposed through V3 API + bool needs_swa = false; + using aotriton::v3::flash::WindowValue; + // Default values when std::optional window_size_left/right have no value + int window_left = max_seqlen_q; + int window_right = max_seqlen_k; + if (is_causal) { + window_left = WindowValue::TopLeftAligned; + window_right = WindowValue::TopLeftAligned; + } + if (window_size_left.has_value() || window_size_right.has_value()) { + needs_swa = true; + window_left = window_size_left.value_or(window_left); + window_right = window_size_right.value_or(window_right); + } + return std::make_tuple(needs_swa, window_left, window_right); +#else + if (window_size_left.has_value() || window_size_right.has_value()) { + TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)." + " Both window_size_left and window_size_right will be ignored." + " Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support."); + } + return std::make_tuple(false, 0, 0); +#endif +} + // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function @@ -127,8 +165,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool return_softmax, const std::optional& gen_) { auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); @@ -161,7 +199,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - if (is_causal) { window_size_right = 0; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); @@ -212,6 +249,19 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); } + auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, + window_size_right, + seqlen_q, + seqlen_k, + is_causal); +#if V3_API + const bool uses_swa = needs_swa; +#else + // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // optimized out (hopefully). + constexpr bool uses_swa = false; +#endif + hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; @@ -226,23 +276,54 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); + if (uses_swa) { +#if V3_API + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.L = mk_aotensor<2>(M, "M"); + params.Out = mk_aotensor(output_t, "Out"); + params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = CausalType::WindowedAttention; + params.varlen_type = VarlenType::None; + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); +#endif + } else { + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(output_t, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; } @@ -263,8 +344,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot const float softmax_scale, const bool zero_tensors, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool return_softmax, const std::optional& gen_) { TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); @@ -312,13 +393,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (window_size_left >= max_seqlen_k) { - window_size_left = -1; - } - if (window_size_right >= max_seqlen_k) { - window_size_right = -1; - } - CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); @@ -368,6 +442,19 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot } } + auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, + window_size_right, + max_seqlen_q, + max_seqlen_k, + is_causal); +#if V3_API + const bool uses_swa = needs_swa; +#else + // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // optimized out (hopefully). + constexpr bool uses_swa = false; +#endif + auto [seed_t, offset_t, philox_state, use_philox_state] = prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); @@ -390,27 +477,60 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; - err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), - mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - empty_bias, - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(out_padded, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); + if (uses_swa) { +#if V3_API + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_padded, "q"); + params.K = mk_aotensor(k_padded, "k"); + params.V = mk_aotensor(v_padded, "v"); + params.Sm_scale = softmax_scale; + params.L = mk_aotensor<2>(M, "M"); + params.Out = mk_aotensor(out_padded, "Out"); + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = CausalType::WindowedAttention; + params.varlen_type = VarlenType::CompactVarlen; + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); +#endif + } else { + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + empty_bias, + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. out.zero_(); @@ -434,8 +554,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset) { @@ -524,6 +644,19 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea dv = at::empty_like(k); } + auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, + window_size_right, + seqlen_q, + seqlen_k, + is_causal); +#if V3_API + const bool uses_swa = needs_swa; +#else + // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // optimized out (hopefully). + constexpr bool uses_swa = false; +#endif + auto opts = q.options(); auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -541,10 +674,42 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea int d_head = head_size_og; bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512; hipError_t err; // TODO: Error handling - if (use_fused_bwd) { + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + if (uses_swa) { +#if V3_API + // Fused BWD does not support SWA + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dq_t, "dq"); + params.DV = mk_aotensor(dk_t, "dk"); + params.DQ = mk_aotensor(dv_t, "dv"); + params.L = mk_aotensor<2>(softmax_lse_cont, "L"); + params.D = mk_aotensor<2>(delta, "delta"); + params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + params.causal_type = CausalType::WindowedAttention; + params.varlen_type = VarlenType::None; + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream); +#endif + } else if (use_fused_bwd) { using aotriton::v2::flash::attn_bwd_fused; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd_fused(mk_aotensor(q_t, "q"), @@ -568,8 +733,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea } else { at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); using aotriton::v2::flash::attn_bwd; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -615,17 +778,14 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset) { TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); - if (is_causal) { - window_size_right = 0; - } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; @@ -669,9 +829,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (window_size_left >= max_seqlen_k) { window_size_left = -1; } - if (window_size_right >= max_seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); @@ -734,6 +891,19 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.zero_(); } + auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, + window_size_right, + max_seqlen_q, + max_seqlen_k, + is_causal); +#if V3_API + const bool uses_swa = needs_swa; +#else + // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // optimized out (hopefully). + constexpr bool uses_swa = false; +#endif + at::PhiloxCudaState philox_args; if (is_dropout) { if (at::cuda::currentStreamCaptureStatus() == @@ -747,34 +917,68 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size } if (max_seqlen_q > 0) { hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), - mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_padded, "dq"), - mk_aotensor(dk_padded, "dk"), - mk_aotensor(dv_padded, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); + if (uses_swa) { +#if V3_API + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_padded, "q"); + params.K = mk_aotensor(k_padded, "k"); + params.V = mk_aotensor(v_padded, "v"); + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dq_padded, "dq"); + params.DV = mk_aotensor(dk_padded, "dk"); + params.DQ = mk_aotensor(dv_padded, "dv"); + params.L = mk_aotensor<2>(softmax_lse_cont, "L"); + params.D = mk_aotensor<2>(delta, "delta"); + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + params.causal_type = CausalType::WindowedAttention; + params.varlen_type = VarlenType::CompactVarlen; + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream); +#endif + } else { + using aotriton::v2::flash::attn_bwd_compact_varlen; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. dq.zero_(); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt index a72911cd510eb6..b30c3934003606 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt @@ -1,7 +1,7 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) @@ -10,8 +10,8 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) @@ -20,14 +20,14 @@ if(ret AND NOT ret EQUAL 0) endif() # Generate the files for both fwd and bwd -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret ) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt new file mode 100644 index 00000000000000..cccf026690dc0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt @@ -0,0 +1,20 @@ +include(CMakePrintHelpers) + +# Generate AITER/CK Asm code +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py --receipt 1 --output_dir ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Failed to generate FAv3 CK Kernels") +endif() + +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret +) + + +# Change file extensions to .hip +execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done") diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp index 38ec2ef20c5cca..affa40619b598f 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp @@ -453,4 +453,5 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; +template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp index 4fb6e95bd33648..400da17426f1d4 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp @@ -14,7 +14,7 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) #endif __global__ void kentry_pt(Args... args) { -#if (defined(__gfx90a__) || defined(__gfx942__)) +#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) Kernel{}(args...); #else CUDA_KERNEL_ASSERT(false && "Fatal! Attempting to call a CK SDPA kernel on unsupported hardware"); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index 0a99d5a81568dc..854ac950a867d1 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -3,6 +3,7 @@ ******************************************************************************/ #include +#include #include #include @@ -28,6 +29,26 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, deterministic}; } + + +aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask) +{ + return aiter::mha_bwd_traits(t.hdim_q, + t.hdim_v, + t.data_type, + t.is_group_mode, + mask.type, + t.bias_type, + t.has_dbias, + t.has_dropout, + t.is_store_randval, + t.is_deterministic, + true, // use_ext_asm + true, // is_v3_atomic_fp32, + 1); // how_v3_bf16_cvt + +} + fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, // sizes const int b, @@ -101,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, ck_tile::index_t stride_dv = dv.stride(1); ck_tile::index_t nhead_stride_dv = dv.stride(2); - // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + // dq_acc: (split, batch_size, nheads, seqlen_q, hdim) ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); - ck_tile::index_t stride_dq_acc = dq_acc.stride(2); - ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); + ck_tile::index_t stride_dq_acc = dq_acc.stride(3); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); // bias: (batch_size, nheads, seqlen_q, seqlen_k) void *attn_bias_ptr = nullptr; @@ -351,11 +372,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x at::Tensor dq_accum; if (!deterministic) { - dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); } else { const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); - dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; @@ -376,14 +397,6 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x if (seqlen_q > 0) { ck_tile::stream_config stream_config{stream}; dq.zero_(); // ck use atomic operation on dq - auto traits = - get_ck_fmha_bwd_traits(mask, - q_dtype_str, - head_size_8x, - is_dropout, - attn_bias_.has_value(), - deterministic, - bias_requires_grad); auto args = get_ck_fmha_bwd_args( @@ -411,7 +424,23 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x softmax_scale, p_dropout, drop_seed_offset); - float t = fmha_bwd(traits, args, stream_config); + + float t = aiter::mha_bwd(args, + stream_config, + q_dtype_str, + false, // is_group_mode + mask.type, + attn_bias_.has_value() ? bias_enum::elementwise_bias : bias_enum::no_bias, + bias_requires_grad, + false, // is_store_randval + deterministic, + true, // use_ext_asm + true, // is_v3_atomic_fp32 + 1); // how_v3_bf16_cvt + + + + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index ead742a1efd67d..17298aae9485da 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -51,8 +51,8 @@ mha_fwd_aot( const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool return_softmax, const std::optional& gen_); @@ -87,8 +87,8 @@ mha_varlen_fwd_aot( const float softmax_scale, const bool zero_tensors, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool return_softmax, const std::optional& gen_); @@ -110,8 +110,8 @@ std::tuple mha_bwd_aot( const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset); @@ -141,8 +141,8 @@ std::tuple mha_varlen_bwd_aot( const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const bool deterministic, const at::Tensor& philox_seed, const at::Tensor& philox_offset); @@ -290,14 +290,16 @@ mha_fwd( const float p_dropout, const float softmax_scale, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const float softcap, const bool return_softmax, std::optional gen_) { #if defined(USE_CK_FLASH_ATTENTION) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); std::optional dummy_attn_bias = std::nullopt; return mha_fwd_ck( q, @@ -307,27 +309,13 @@ mha_fwd( p_dropout, softmax_scale, is_causal, - window_size_left, - window_size_right, + non_null_window_left, + non_null_window_right, return_softmax, gen_, dummy_attn_bias); // Not used in flash attention - } else { - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); } -#else +#endif return mha_fwd_aot( q, k, @@ -341,7 +329,6 @@ mha_fwd( window_size_right, return_softmax, gen_); -#endif } inline std::tuple< @@ -376,8 +363,8 @@ mha_varlen_fwd( const float softmax_scale, const bool zero_tensors, bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const float softcap, const bool return_softmax, std::optional gen_) { @@ -385,6 +372,8 @@ mha_varlen_fwd( if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional dummy_attn_bias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); return mha_varlen_fwd_ck( q, k, @@ -399,34 +388,13 @@ mha_varlen_fwd( softmax_scale, zero_tensors, is_causal, - window_size_left, - window_size_right, + non_null_window_left, + non_null_window_right, return_softmax, gen_, dummy_attn_bias); // Not used in flash attention - } else { - return mha_varlen_fwd_aot( - q, - k, - v, - out_, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - block_table_, - alibi_slopes_, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); } -#else +#endif return mha_varlen_fwd_aot( q, k, @@ -447,7 +415,6 @@ mha_varlen_fwd( window_size_right, return_softmax, gen_); -#endif } inline std::tuple mha_bwd( @@ -468,16 +435,18 @@ inline std::tuple mha_bwd( const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const float softcap, const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { -#if defined(USE_CK_FLASH_ATTENTION) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { +#if defined(USE_CK_FLASH_ATTENTION) std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); auto[dQuery, dKey, dValue, @@ -498,38 +467,16 @@ inline std::tuple mha_bwd( p_dropout, softmax_scale, is_causal, - window_size_left, - window_size_right, + non_null_window_left, + non_null_window_right, deterministic, philox_seed, philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); - } else { - return mha_bwd_aot( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); - } #else - if(at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); +#endif } return mha_bwd_aot( dout, @@ -550,7 +497,6 @@ inline std::tuple mha_bwd( deterministic, philox_seed, philox_offset); -#endif } inline std::tuple mha_varlen_bwd( @@ -578,8 +524,8 @@ inline std::tuple mha_varlen_bwd const float softmax_scale, const bool zero_tensors, const bool is_causal, - int window_size_left, - int window_size_right, + std::optional window_size_left, + std::optional window_size_right, const float softcap, const bool deterministic, const at::Tensor philox_seed, @@ -588,6 +534,8 @@ inline std::tuple mha_varlen_bwd if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); auto[dQuery, dKey, dValue, @@ -613,40 +561,15 @@ inline std::tuple mha_varlen_bwd softmax_scale, zero_tensors, is_causal, - window_size_left, - window_size_right, + non_null_window_left, + non_null_window_right, deterministic, philox_seed, philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); - } else { - return mha_varlen_bwd_aot( - dout, - q, - k, - v, - out, - softmax_lse, - dq_, - dk_, - dv_, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes_, - max_seqlen_q, - max_seqlen_k, - p_dropout, - softmax_scale, - zero_tensors, - is_causal, - window_size_left, - window_size_right, - deterministic, - philox_seed, - philox_offset); } -#else +#endif return mha_varlen_bwd_aot( dout, q, @@ -671,7 +594,6 @@ inline std::tuple mha_varlen_bwd deterministic, philox_seed, philox_offset); -#endif } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/sdp_utils.h b/aten/src/ATen/native/transformers/sdp_utils.h new file mode 100644 index 00000000000000..809abe50178ec7 --- /dev/null +++ b/aten/src/ATen/native/transformers/sdp_utils.h @@ -0,0 +1,88 @@ +#pragma once +#include +#include + +namespace at::native { + +void alloc_with_matching_layout( + const Tensor& q, + Tensor& output, + const std::vector& shape) { + TORCH_INTERNAL_ASSERT( + shape.size() == q.sizes().size(), + "SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); + + if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) { + output = at::empty_like(q); + return; + } + + // get the "fill order," which is just an argsort on the strides + std::vector fill_order(shape.size()); + std::iota(fill_order.begin(), fill_order.end(), 0); + const auto q_strides = q.strides(); + std::stable_sort( + fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) { + return q_strides[idx1] ? q_strides[idx1] : 1 < q_strides[idx2] ? q_strides[idx2] : 1; + }); + std::vector ordered_strides(shape.size()); + int64_t current_stride = 1; + for (const int dim_idx : fill_order) { + ordered_strides[dim_idx] = current_stride; + current_stride *= shape[dim_idx]; + } + output = at::empty(at::IntArrayRef(shape), q.options()) + .as_strided( + at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0); +} + +void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { + const int dims = output.sizes().size(); + std::vector outer_to_inner(dims); + std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0); + const auto o_strides = output.strides(); + std::stable_sort( + outer_to_inner.begin(), + outer_to_inner.end(), + [&o_strides](int idx1, int idx2) { + return o_strides[idx1] > o_strides[idx2]; + }); + std::vector inverse(dims); + for (int d = 0; d < dims; d++) { + inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) - + outer_to_inner.begin(); + } + grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner)) + .contiguous() + .permute(at::IntArrayRef(inverse)); +} + +bool same_strides(const Tensor& t1, const Tensor& t2) { + std::vector t1_strides_no_ones; + std::vector t2_strides_no_ones; + const auto t1strides = t1.strides(); + const auto t2strides = t2.strides(); + const int dim = t1strides.size(); + if (dim != (int)t2strides.size()) { + return false; + } + const auto t1sizes = t1.sizes(); + const auto t2sizes = t2.sizes(); + + // we are going through strides backward here, but if both are backward it's + // comparable + for (int i = 0; i < dim; i++) { + if (t1sizes[i] > 1) { + t1_strides_no_ones.push_back(t1strides[i]); + } + if (t2sizes[i] > 1) { + t2_strides_no_ones.push_back(t2strides[i]); + } + } + return std::equal( + t1_strides_no_ones.begin(), + t1_strides_no_ones.end(), + t2_strides_no_ones.begin(), + t2_strides_no_ones.end()); +} +} // namespace at::native diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 4591fa253824d0..aa5c2b6cdd6416 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -503,17 +503,8 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool if (ignore_singleton_dim){ qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1; } - bool mask_stride_equal_1 = params.attn_mask.has_value() - ? params.attn_mask.value().sym_stride(-1) == 1 - : true; - if (!(qkv_strides_equal_1 && mask_stride_equal_1)) { + if (!qkv_strides_equal_1) { if (debug) { - std::ostringstream epilogue_message; - if (params.attn_mask.has_value()) { - epilogue_message << ", Attn_mask.stride(-1): " - << params.attn_mask.value().sym_stride(-1); - } - epilogue_message << " instead."; TORCH_WARN( "All fused kernels require the last dimension of the input to have stride 1. ", "Got Query.stride(-1): ", @@ -522,7 +513,7 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool params.key.sym_stride(-1), ", Value.stride(-1): ", params.value.sym_stride(-1), - epilogue_message.str()); + " instead."); } return false; diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 04823c592ccb75..532caa62687a8e 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -33,7 +33,8 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } - bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { + c10::SymBool sym_is_contiguous_custom( + c10::MemoryFormat memory_format) const override { (void)memory_format; return true; } diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp index c84d8d4e4c8e93..cf8402e40a0b8d 100644 --- a/aten/src/ATen/native/vulkan/api/Runtime.cpp +++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp @@ -78,6 +78,9 @@ VkInstance create_instance(const RuntimeConfiguration& config) { #ifdef VK_EXT_debug_report VK_EXT_DEBUG_REPORT_EXTENSION_NAME, #endif /* VK_EXT_debug_report */ +#ifdef __APPLE__ + VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME, +#endif // __APPLE__ }; find_requested_layers_and_extensions( @@ -90,7 +93,11 @@ VkInstance create_instance(const RuntimeConfiguration& config) { const VkInstanceCreateInfo instance_create_info{ VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType nullptr, // pNext +#ifdef __APPLE__ + VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR, // flags +#else // __APPLE__ 0u, // flags +#endif // __APPLE__ &application_info, // pApplicationInfo static_cast(enabled_layers.size()), // enabledLayerCount enabled_layers.data(), // ppEnabledLayerNames diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp index 38c7925a1d54f0..120c62cd4ab93c 100644 --- a/aten/src/ATen/nnapi/nnapi_bind.cpp +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -133,7 +133,7 @@ void NnapiCompilation::run( t.nbytes()); } - for (const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(static_cast(outputs.size()))) { auto& t = outputs[i]; // TODO: Check contiguous and dtype. check_nnapi->Execution_setOutput( @@ -147,7 +147,7 @@ void NnapiCompilation::run( check_nnapi->Execution_compute(execution); // TODO: Maybe skip this for fixed-size outputs? - for (const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(static_cast(outputs.size()))) { auto& t = outputs[i]; uint32_t rank = 0; check_nnapi->Execution_getOutputOperandRank(execution, i, &rank); @@ -177,9 +177,8 @@ void NnapiCompilation::get_operand_type(const at::Tensor& t, ANeuralNetworksOper if (t.scalar_type() == c10::kQUInt8) { TORCH_CHECK(t.is_quantized()); operand->type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - operand->scale = t.q_scale(); - operand->zeroPoint = t.q_zero_point(); + operand->scale = static_cast(t.q_scale()); + operand->zeroPoint = static_cast(t.q_zero_point()); return; } if (t.scalar_type() == c10::kInt) { @@ -194,7 +193,6 @@ void NnapiCompilation::get_operand_type(const at::Tensor& t, ANeuralNetworksOper "testing with fixed scale, zero_point. Please change your ", "inputs if you see this in production"); operand->type = ANEURALNETWORKS_TENSOR_QUANT16_ASYMM; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) operand->scale = 0.125; operand->zeroPoint = 0; return; diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 8826e81bd0c3f9..94a2bf56f8d7bf 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -724,36 +724,17 @@ uint64_t RecordFunction::currentThreadId() { return current_thread_id_; } -void RecordFunction::before(const char* name, int64_t sequence_nr) { - fn_ = name; - sequence_nr_ = sequence_nr; - is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0); - -#ifndef NDEBUG - inputs_valid_ = true; -#endif - runStartCallbacks(); - invalidateInputs(); -} - -void RecordFunction::before(std::string name, int64_t sequence_nr) { - is_nccl_meta_ = (name == kParamCommsCallName); - fn_ = std::move(name); - sequence_nr_ = sequence_nr; - -#ifndef NDEBUG - inputs_valid_ = true; -#endif - runStartCallbacks(); - invalidateInputs(); -} - -void RecordFunction::before( - RecordFunction::schema_ref_t schema, - int64_t sequence_nr) { +void RecordFunction::before(RecordFunction::FunctionDescriptor fn, int64_t sequence_nr) { + std::visit([this](auto&& fn) { + if constexpr (std::is_same_v, std::string_view>) { + is_nccl_meta_ = (fn == kParamCommsCallName); + fn_ = std::string(fn); + } else { + is_nccl_meta_ = (fn.get().name() == kParamCommsCallName); + fn_ = fn; + } + }, fn); sequence_nr_ = sequence_nr; - fn_ = schema; - is_nccl_meta_ = (schema.get().name() == kParamCommsCallName); #ifndef NDEBUG inputs_valid_ = true; diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 7c82879cde5935..29fbc8270a451f 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace c10 { @@ -287,9 +288,11 @@ struct TORCH_API RecordFunction { explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); explicit RecordFunction(StepCallbacks&& step_callbacks); - template + using schema_ref_t = std::reference_wrapper; + using FunctionDescriptor = std::variant; + void before( - F fn, + FunctionDescriptor fn, c10::ArrayRef args, int64_t current_sequence_nr = -1) { if (!isActive()) { @@ -299,9 +302,8 @@ struct TORCH_API RecordFunction { before(fn, current_sequence_nr); } - template void before( - F fn, + FunctionDescriptor fn, c10::ArrayRef args, const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { @@ -309,12 +311,11 @@ struct TORCH_API RecordFunction { return; } kwinputs_ = *kwargs; - before(std::move(fn), args, current_sequence_nr); + before(fn, args, current_sequence_nr); } - template void before( - F fn, + FunctionDescriptor fn, const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { if (!isActive()) { @@ -324,20 +325,18 @@ struct TORCH_API RecordFunction { before(fn, current_sequence_nr); } - template void before( - F fn, + FunctionDescriptor fn, const std::vector* args, int64_t current_sequence_nr = -1) { before( - std::move(fn), + fn, c10::ArrayRef(args->data(), args->size()), current_sequence_nr); } - template void before( - F fn, + FunctionDescriptor fn, const std::vector* args, const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { @@ -426,10 +425,7 @@ struct TORCH_API RecordFunction { // before functions initialize RecordFunction members and call // start callbacks - using schema_ref_t = std::reference_wrapper; - void before(const char* name, int64_t sequence_nr = -1); - void before(std::string name, int64_t sequence_nr = -1); - void before(schema_ref_t schema, int64_t sequence_nr = -1); + void before(FunctionDescriptor schema, int64_t sequence_nr = -1); // Sets node ID for distributed profiling static void setDefaultNodeId(int64_t defaultNodeId); @@ -553,10 +549,10 @@ TORCH_API std::optional getStepCallbacksUnlessEmpty( RecordScope scope); namespace detail { -template +template void record_function_with_scope( RecordFunction& guard, - F fn, + RecordFunction::FunctionDescriptor fn, const Inputs& inputs, Args&&... args) { if (guard.needsInputs()) { @@ -569,10 +565,10 @@ void record_function_with_scope( } } -template +template void record_function_with_scope_and_debug_handle( RecordFunction& guard, - F fn, + RecordFunction::FunctionDescriptor fn, int64_t debug_handle, const Inputs& inputs, Args&&... args) { @@ -587,30 +583,26 @@ void record_function_with_scope_and_debug_handle( } } -template +template void record_function_with_scope( RecordFunction& guard, - F fn, + RecordFunction::FunctionDescriptor fn, c10::ArrayRef inputs, Args&&... args) { - return record_function_with_scope< - c10::ArrayRef, - F, - Args...>(guard, std::move(fn), inputs, std::forward(args)...); + return record_function_with_scope, Args...>( + guard, fn, inputs, std::forward(args)...); } -template +template void record_function_with_scope_and_debug_handle( RecordFunction& guard, - F fn, + RecordFunction::FunctionDescriptor fn, int64_t debug_handle, c10::ArrayRef inputs, Args&&... args) { return record_function_with_scope_and_debug_handle< c10::ArrayRef, - F, - Args...>( - guard, std::move(fn), debug_handle, inputs, std::forward(args)...); + Args...>(guard, fn, debug_handle, inputs, std::forward(args)...); } } // namespace detail diff --git a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp index 73f69d7f8529f0..c390305e2051cf 100644 --- a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp @@ -63,8 +63,7 @@ TEST(CPUAllocationPlanTest, with_control_flow) { } bool success{true}; for (uint64_t i = 0; i < 10; ++i) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool validation_success; + bool validation_success = false; { c10::WithValidateAllocationPlanGuard validation_guard(&plan, &validation_success); diff --git a/aten/src/ATen/test/cuda_complex_test.cu b/aten/src/ATen/test/cuda_complex_test.cu index ba3a23ce3e2f89..5736f73330760a 100644 --- a/aten/src/ATen/test/cuda_complex_test.cu +++ b/aten/src/ATen/test/cuda_complex_test.cu @@ -5,14 +5,14 @@ __global__ void test_thrust_kernel() { // thrust conversion { - constexpr float num1 = float(1.23); - constexpr float num2 = float(4.56); + [[maybe_unused]] constexpr float num1 = float(1.23); + [[maybe_unused]] constexpr float num2 = float(4.56); assert(c10::complex(thrust::complex(num1, num2)).real() == num1); assert(c10::complex(thrust::complex(num1, num2)).imag() == num2); } { - constexpr double num1 = double(1.23); - constexpr double num2 = double(4.56); + [[maybe_unused]] constexpr double num1 = double(1.23); + [[maybe_unused]] constexpr double num2 = double(4.56); assert(c10::complex(thrust::complex(num1, num2)).real() == num1); assert(c10::complex(thrust::complex(num1, num2)).imag() == num2); } @@ -46,11 +46,11 @@ __global__ void test_reinterpret_cast() { assert(zzzz.real() == double(1)); assert(zzzz.imag() == double(2)); - cuComplex cuComplex_zz = *reinterpret_cast(&zz); + [[maybe_unused]] cuComplex cuComplex_zz = *reinterpret_cast(&zz); assert(cuComplex_zz.x == float(1)); assert(cuComplex_zz.y == float(2)); - cuDoubleComplex cuDoubleComplex_zzzz = *reinterpret_cast(&zzzz); + [[maybe_unused]] cuDoubleComplex cuDoubleComplex_zzzz = *reinterpret_cast(&zzzz); assert(cuDoubleComplex_zzzz.x == double(1)); assert(cuDoubleComplex_zzzz.y == double(2)); } diff --git a/aten/src/ATen/test/cuda_cub_test.cu b/aten/src/ATen/test/cuda_cub_test.cu index 5e5e25d2a8c90a..6865984102b4bd 100644 --- a/aten/src/ATen/test/cuda_cub_test.cu +++ b/aten/src/ATen/test/cuda_cub_test.cu @@ -146,8 +146,8 @@ TEST(InclusiveScanSplit, CubTest) { cudaMallocManaged(&output1, sizeof(int) * 10); cudaDeviceSynchronize(); - at::cuda::cub::inclusive_scan( - input, output1, ::at_cuda_detail::cub::Sum(), 10); + at::cuda::cub::inclusive_scan, /*max_cub_size=*/2>( + input, output1, NO_ROCM(::cuda)::std::plus<>(), 10); cudaDeviceSynchronize(); ASSERT_EQ(output1[0], 1); @@ -172,8 +172,8 @@ TEST(ExclusiveScanSplit, CubTest) { cudaMallocManaged(&output2, sizeof(int) * 10); cudaDeviceSynchronize(); - at::cuda::cub::exclusive_scan( - input, output2, ::at_cuda_detail::cub::Sum(), 0, 10); + at::cuda::cub::exclusive_scan, int, /*max_cub_size=*/2>( + input, output2, NO_ROCM(::cuda)::std::plus<>(), 0, 10); cudaDeviceSynchronize(); ASSERT_EQ(output2[0], 0); diff --git a/aten/src/ATen/test/cuda_dlconvertor_test.cpp b/aten/src/ATen/test/cuda_dlconvertor_test.cpp index 697a6c8b7112f6..34f8589391d5e6 100644 --- a/aten/src/ATen/test/cuda_dlconvertor_test.cpp +++ b/aten/src/ATen/test/cuda_dlconvertor_test.cpp @@ -9,6 +9,7 @@ #include using namespace at; + TEST(TestDlconvertor, TestDlconvertorCUDA) { manual_seed(123); @@ -50,3 +51,45 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { ASSERT_TRUE(a.equal(b)); } + +TEST(TestDlconvertorVersioned, TestDlconvertorCUDA) { + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); + + Tensor b = fromDLPackVersioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorVersioned, TestDlconvertorNoStridesCUDA) { + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); + dlMTensor->dl_tensor.strides = nullptr; + + Tensor b = fromDLPackVersioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorVersioned, TestDlconvertorCUDAHIP) { + if (!at::cuda::is_available()) + return; + manual_seed(123); + + Tensor a = rand({3, 4}, at::kCUDA); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); + +#if AT_ROCM_ENABLED() + ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLROCM); +#else + ASSERT_TRUE(dlMTensor->dl_tensor.device.device_type == DLDeviceType::kDLCUDA); +#endif + + Tensor b = fromDLPackVersioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu index e5013951a70698..6f45acc30f9ea2 100644 --- a/aten/src/ATen/test/cuda_half_test.cu +++ b/aten/src/ATen/test/cuda_half_test.cu @@ -33,7 +33,7 @@ __device__ void test(){ // use the std namespace, but just "::" so that the function // gets resolved from nvcc math_functions.hpp - float threshold = 0.00001; + [[maybe_unused]] float threshold = 0.00001; assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold); assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold); assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold); diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index 2bf9e8dc232960..dca9126c7cde39 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -3,12 +3,8 @@ #include #include -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -#include - using namespace at; + TEST(TestDlconvertor, TestDlconvertor) { manual_seed(123); @@ -31,3 +27,26 @@ TEST(TestDlconvertor, TestDlconvertorNoStrides) { ASSERT_TRUE(a.equal(b)); } + +TEST(TestDlconvertorUnversioned, TestDlconvertor) { + manual_seed(123); + + Tensor a = rand({3, 4}); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); + + Tensor b = fromDLPackVersioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} + +TEST(TestDlconvertorUnversioned, TestDlconvertorNoStrides) { + manual_seed(123); + + Tensor a = rand({3, 4}); + DLManagedTensorVersioned* dlMTensor = toDLPackVersioned(a); + dlMTensor->dl_tensor.strides = nullptr; + + Tensor b = fromDLPackVersioned(dlMTensor); + + ASSERT_TRUE(a.equal(b)); +} diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index ee8a923668fe4d..a9b5a70f1de917 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -61,6 +61,8 @@ namespace { template class QuantizationTests : public ::testing::Test {}; template + class Quantization8BitTests : public ::testing::Test {}; + template class Quantization8BitWithTailTests : public ::testing::Test {}; template class FunctionalTests : public ::testing::Test {}; @@ -79,6 +81,7 @@ namespace { using FloatTestedTypes = ::testing::Types; using ALLTestedTypes = ::testing::Types; using QuantTestedTypes = ::testing::Types; + using Quantization8BitTestedTypes = ::testing::Types; #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) using Quantization8BitWithTailTestedTypes = ::testing::Types; @@ -116,6 +119,7 @@ namespace { TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatReducedFloatTestedTypes); TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes); TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes); + TYPED_TEST_SUITE(Quantization8BitTests, Quantization8BitTestedTypes); TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes); #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) TYPED_TEST_SUITE( @@ -1496,6 +1500,68 @@ namespace { }, test_case); } +#ifndef _WIN32 + TYPED_TEST(Quantization8BitTests, Transpose) { + using VT = ValueType; + constexpr auto M = 4; + constexpr auto N = 64; + constexpr auto L = M * N; + constexpr auto ld_src = N; + constexpr auto ld_dst = M; + CACHE_ALIGN VT x[L]; + CACHE_ALIGN VT y[L]; + CACHE_ALIGN VT ref[L]; + auto seed = TestSeed(); + ValueGen generator(VT(-100), VT(100), seed); + for (const auto i : c10::irange(L)) { + x[i] = generator.get(); + } + at::native::utils::transpose( + M, N, + reinterpret_cast(x), ld_src, + reinterpret_cast(y), ld_dst); + for (int64_t j = 0; j < N; j++) { + for (int64_t i = 0; i < M; i++) { + ref[j * ld_dst + i] = c10::load(&(x[i * ld_src + j])); + } + } + for (const auto i : c10::irange(L)) { + ASSERT_EQ(y[i], ref[i]) + << "Failure Details:\nTest Seed to reproduce: " << seed; + } + } +#endif +#if defined(CPU_CAPABILITY_AVX512) + TYPED_TEST(Quantization8BitTests, PackVNNI4) { + using VT = ValueType; + constexpr auto K = 8; + constexpr auto N = 128; + constexpr auto L = K * N; + constexpr auto ld_src = N; + CACHE_ALIGN VT x[L]; + CACHE_ALIGN VT y[L]; + CACHE_ALIGN VT ref[L]; + auto seed = TestSeed(); + ValueGen generator(VT(-100), VT(100), seed); + for (const auto i : c10::irange(L)) { + x[i] = generator.get(); + } + at::vec::pack_vnni4(x, y, ld_src, K, N); + int64_t _K = K / 4; + for (int64_t k = 0; k < _K; k++) { + for(int64_t n = 0; n < N; n++) { + for(int64_t l = 0; l < 4; l++) { + ref[k * N * 4 + n * 4 + l] = + c10::load(&(x[k * ld_src * 4 + l * ld_src + n])); + } + } + } + for (const auto i : c10::irange(L)) { + ASSERT_EQ(y[i], ref[i]) + << "Failure Details:\nTest Seed to reproduce: " << seed; + } + } +#endif TYPED_TEST(FunctionalTests, Map) { using vec = TypeParam; using VT = ValueType; diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 462f90b4e9272a..f7062a3048dfc1 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -1,6 +1,7 @@ #pragma once -#include #include +#include +#include #include #include #include @@ -21,7 +22,9 @@ #else #define CACHE_LINE 32 #endif - +#ifndef _WIN32 +#include +#endif #if defined(__GNUC__) #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) #define not_inline __attribute__((noinline)) diff --git a/aten/src/ATen/test/vitals.cpp b/aten/src/ATen/test/vitals.cpp index 9bf22d81e45f73..cc93775bb5383e 100644 --- a/aten/src/ATen/test/vitals.cpp +++ b/aten/src/ATen/test/vitals.cpp @@ -80,8 +80,7 @@ TEST(Vitals, OnAndOff) { TEST(Vitals, APIVitals) { std::stringstream buffer; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool rvalue; + bool rvalue = false; std::streambuf* sbuf = std::cout.rdbuf(); std::cout.rdbuf(buffer.rdbuf()); { diff --git a/aten/src/ATen/xpu/CachingHostAllocator.h b/aten/src/ATen/xpu/CachingHostAllocator.h index 8e0bd3f953a6d1..ac99e6eef4f44b 100644 --- a/aten/src/ATen/xpu/CachingHostAllocator.h +++ b/aten/src/ATen/xpu/CachingHostAllocator.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace at::xpu { diff --git a/benchmarks/README.md b/benchmarks/README.md index a0292bbc633f01..6f0a5efc615ab4 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -31,3 +31,4 @@ Please refer to each subfolder to discover each benchmark suite. Links are provi * [Overrides](overrides_benchmark/README.md) * [Sparse](sparse/README.md) * [Tensor expression](tensorexpr/HowToRun.md) +* [Data](data/README.md) diff --git a/benchmarks/data/README.md b/benchmarks/data/README.md new file mode 100644 index 00000000000000..259a51741cf760 --- /dev/null +++ b/benchmarks/data/README.md @@ -0,0 +1,62 @@ +# PyTorch Data Benchmarks + +This directory contains benchmarks for the `torch.utils.data` module components, focusing on the performance of samplers. + +## Dependencies + +The benchmarks require the following dependencies: +``` +numpy +tabulate +``` + +You can install them using pip: +```bash +pip install numpy tabulate +``` + +## Running the benchmarks + +To run the BatchSampler benchmark: +```bash +python samplers_benchmark.py +``` + +## Sampler Benchmark + +The `samplers_benchmark.py` script benchmarks the performance of PyTorch's BatchSampler against an alternative implementation as an example. It tests with the following parameters: + +- Batch sizes: 4, 8, 64, 640, 6400, 64000 +- Drop last options: True, False +- Each configuration is run 10 times and averaged +- Results include speedup percentage calculations + +### Output + +The benchmark outputs a table with the following columns: +- Batch Size +- Drop Last +- Original (s): Time taken by the original implementation +- New (s): Time taken by the alternative implementation +- Speedup: Percentage improvement of the new implementation over the original + +Example output: +``` ++------------+-----------+---------------+----------+---------+ +| Batch Size | Drop Last | Original (s) | New (s) | Speedup | ++============+===========+===============+==========+=========+ +| 4 | True | 0.1234 | 0.1000 | 18.96% | ++------------+-----------+---------------+----------+---------+ +| 4 | False | 0.1345 | 0.1100 | 18.22% | ++------------+-----------+---------------+----------+---------+ +... +``` + +### Extending the Benchmark + +To benchmark a different implementation: + +On local: +1. Modify the `NewBatchSampler` class in `samplers_benchmark.py` with your implementation. Similarly replace `BatchSampler` with the corresponding PyTorch implementation. + * Ensure to include all inputs like `replacement` for `RandomSampler` and its variations +2. Run the benchmark to compare its performance against the original diff --git a/benchmarks/data/samplers_benchmark.py b/benchmarks/data/samplers_benchmark.py new file mode 100644 index 00000000000000..6cdd0c77c65194 --- /dev/null +++ b/benchmarks/data/samplers_benchmark.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import time +from collections.abc import Iterable, Iterator +from typing import Union + +import numpy as np +from tabulate import tabulate + +from torch.utils.data import BatchSampler, Sampler, SequentialSampler + + +class NewBatchSampler(Sampler[list[int]]): + """Alternative implementation of BatchSampler for benchmarking purposes.""" + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + ) -> None: + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + f"batch_size should be a positive integer value, but got batch_size={batch_size}" + ) + if not isinstance(drop_last, bool): + raise ValueError( + f"drop_last should be a boolean value, but got drop_last={drop_last}" + ) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self) -> Iterator[list[int]]: + if self.drop_last: + sampler_iter = iter(self.sampler) + while True: + try: + batch = [next(sampler_iter) for _ in range(self.batch_size)] + yield batch + except StopIteration: + break + else: + batch = [0] * self.batch_size + idx_in_batch = 0 + for idx in self.sampler: + batch[idx_in_batch] = idx + idx_in_batch += 1 + if idx_in_batch == self.batch_size: + yield batch + idx_in_batch = 0 + batch = [0] * self.batch_size + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self) -> int: + # Can only be called if self.sampler has __len__ implemented + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore[arg-type] + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] + + +def main(): + """Run benchmark with specified parameters.""" + DATA_SIZE = 99999 + AVG_TIMES = 10 + BATCH_SIZES = [4, 8, 64, 640, 6400, 64000] + DROP_LAST_OPTIONS = [True, False] + + results = [] + + # Set up samplers here, ensure right args are passed in + baselineSampler = BatchSampler + testSampler = NewBatchSampler + + for batch_size in BATCH_SIZES: + for drop_last in DROP_LAST_OPTIONS: + print(f"Benchmarking with batch_size={batch_size}, drop_last={drop_last}") + + # Benchmark baselineSampler + original_times = [] + for _ in range(AVG_TIMES): + start = time.perf_counter() + for _ in baselineSampler( + sampler=SequentialSampler(range(DATA_SIZE)), + batch_size=batch_size, + drop_last=drop_last, + ): + pass + end = time.perf_counter() + original_times.append(end - start) + time.sleep(0.1) + + original_avg = float(np.mean(original_times)) + + # Benchmark testSampler + new_times = [] + for _ in range(AVG_TIMES): + start = time.perf_counter() + for _ in testSampler( + sampler=SequentialSampler(range(DATA_SIZE)), + batch_size=batch_size, + drop_last=drop_last, + ): + pass + end = time.perf_counter() + new_times.append(end - start) + time.sleep(0.1) # Small delay to reduce system load + + new_avg = float(np.mean(new_times)) + + # Calculate speedup + if original_avg > 0 and new_avg > 0: + speedup = (original_avg - new_avg) / original_avg * 100 + speedup_str = f"{speedup:.2f}%" + else: + speedup_str = "N/A" + + print(f"Speedup: {speedup_str}\n") + + results.append( + [ + batch_size, + drop_last, + f"{original_avg:.4f}", + f"{new_avg:.4f}", + speedup_str, + ] + ) + + # Print results in a table + headers = ["Batch Size", "Drop Last", "Original (s)", "New (s)", "Speedup"] + print("\nBenchmark Results:") + print(tabulate(results, headers=headers, tablefmt="grid")) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile index d27e64e82fd77c..c62773280bc282 100644 --- a/benchmarks/dynamo/Makefile +++ b/benchmarks/dynamo/Makefile @@ -1,3 +1,9 @@ +# Usage: +# make build-deps TORCHBENCH_MODELS= +# Support install a single torchbench model (e.g., "alexnet"), +# or multiple torchbench model names (e.g., "alexnet basic_gnn_gcn BERT_pytorch"), +# or empty (i.e., "") for installing all torchbench models. + clone-deps: (cd ../../.. \ && (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \ @@ -21,17 +27,13 @@ pull-deps: clone-deps (cd ../../../torchbenchmark && git fetch && git checkout "$$(cat ../pytorch/.github/ci_commit_pins/torchbench.txt)" && git submodule update --init --recursive) build-deps: clone-deps - # Install uv with - curl -LsSf https://astral.sh/uv/install.sh | sh - # and create the virtual env - uv venv pt-benchmark-py3.12 --python 3.12 && source pt-benchmark-py3.12/bin/activate uv pip install astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \ - typing-extensions requests protobuf numba cython scikit-learn torch librosa - (cd ../../../torchvision && uv pip install -e .) + typing-extensions requests protobuf numba cython scikit-learn librosa + (cd ../../../torchvision && uv pip install -e . --no-build-isolation) (cd ../../../torchdata && uv pip install -e .) (cd ../../../torchaudio && uv pip install -e . --no-build-isolation) (cd ../../../FBGEMM/fbgemm_gpu && uv pip install -r requirements.txt && uv pip install -e . --no-build-isolation) (cd ../../../torchrec && uv pip install -e .) (cd ../../../detectron2 && uv pip install -e . --no-build-isolation) - (cd ../../../torchbenchmark && python install.py --continue_on_fail) + (cd ../../../torchbenchmark && python install.py --continue_on_fail $(if $(TORCHBENCH_MODELS),models $(TORCHBENCH_MODELS))) uv pip uninstall torchrec-nightly fbgemm-gpu-nightly diff --git a/benchmarks/dynamo/benchmarks.py b/benchmarks/dynamo/benchmarks.py index 981cfffe512902..55b03429bcc5b0 100755 --- a/benchmarks/dynamo/benchmarks.py +++ b/benchmarks/dynamo/benchmarks.py @@ -6,7 +6,7 @@ # Note - hf and timm have their own version of this, torchbench does not -# TOOD(voz): Someday, consolidate all the files into one runner instead of a shim like this... +# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this... def model_names(filename: str) -> set[str]: names = set() with open(filename) as fh: diff --git a/benchmarks/dynamo/cachebench.py b/benchmarks/dynamo/cachebench.py index e32939add372eb..c5cbb1eef4d0a2 100644 --- a/benchmarks/dynamo/cachebench.py +++ b/benchmarks/dynamo/cachebench.py @@ -8,7 +8,7 @@ import tempfile from typing import Callable -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache logger: logging.Logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ def _run_torchbench_from_args( warm_compile_time: list[float] = [] for _ in range(cmd_args.repeat): - with fresh_inductor_cache(): + with fresh_cache(): env = os.environ.copy() with tempfile.NamedTemporaryFile(suffix=".csv") as file: args.append("--output=" + file.name) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv index 1d1ee09d9e571a..c889ba0e8d2f75 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv @@ -42,7 +42,7 @@ cspdarknet53,pass,0 -deit_base_distilled_patch16_224,fail_to_run,0 +deit_base_distilled_patch16_224,pass,0 @@ -114,7 +114,7 @@ lcnet_050,pass,0 -levit_128,fail_to_run,0 +levit_128,pass,0 @@ -238,7 +238,7 @@ vit_base_patch16_224,pass,0 -volo_d1_224,fail_to_run,0 +volo_d1_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv index dd6f6264f90cf0..7e100f9787cfce 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv @@ -34,7 +34,7 @@ basic_gnn_gin,pass,0 -basic_gnn_sage,fail_to_run,0 +basic_gnn_sage,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 5c90665ce32d09..9dfc96c49cfa94 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -138,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,18 +hf_BigBird,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index ae70016cd9a79a..5f21dd6aba6848 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -122,7 +122,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,18 +hf_BigBird,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv index 00802842d5d047..d2300bdac05b81 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv @@ -154,6 +154,10 @@ maml_omniglot,pass,7 +microbench_unbacked_tolist_sum,pass,8 + + + mnasnet1_0,pass,7 @@ -266,11 +270,11 @@ timm_nfnet,pass,0 -timm_regnet,pass,6 +timm_regnet,pass,7 -timm_resnest,pass,7 +timm_resnest,pass,6 @@ -294,7 +298,7 @@ tts_angular,pass,9 -vgg16,pass,6 +vgg16,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv index b3be1dd48c9540..1605a26b7ce5f5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv @@ -154,6 +154,10 @@ maml_omniglot,pass,7 +microbench_unbacked_tolist_sum,pass,8 + + + mnasnet1_0,pass,7 @@ -262,11 +266,11 @@ timm_nfnet,pass,0 -timm_regnet,pass,6 +timm_regnet,pass,7 -timm_resnest,pass,7 +timm_resnest,pass,6 @@ -290,7 +294,7 @@ tts_angular,pass,9 -vgg16,pass,6 +vgg16,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv index 1ab1fd98f213db..b43e38b7d822a8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv @@ -142,6 +142,10 @@ maml_omniglot,pass,7 +microbench_unbacked_tolist_sum,pass,8 + + + mnasnet1_0,pass,7 @@ -242,11 +246,11 @@ timm_efficientnet,pass,7 -timm_regnet,pass,6 +timm_regnet,pass,7 -timm_resnest,pass,7 +timm_resnest,pass,6 @@ -270,7 +274,7 @@ tts_angular,pass,9 -vgg16,pass,6 +vgg16,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv index 6e55cc616fa685..754f5f718e436c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv @@ -154,6 +154,10 @@ maml_omniglot,pass,7 +microbench_unbacked_tolist_sum,pass,8 + + + mnasnet1_0,pass,7 @@ -266,11 +270,11 @@ timm_nfnet,pass,0 -timm_regnet,pass,6 +timm_regnet,pass,7 -timm_resnest,pass,7 +timm_resnest,pass,6 @@ -294,7 +298,7 @@ tts_angular,pass,9 -vgg16,pass,6 +vgg16,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv index 3991407f1b529c..86ad955b5a2cb0 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv @@ -142,6 +142,10 @@ maml_omniglot,pass,7 +microbench_unbacked_tolist_sum,pass,8 + + + mnasnet1_0,pass,7 @@ -246,11 +250,11 @@ timm_efficientnet,pass,7 -timm_regnet,pass,6 +timm_regnet,pass,7 -timm_resnest,pass,7 +timm_resnest,pass,6 @@ -274,7 +278,7 @@ tts_angular,pass,9 -vgg16,pass,6 +vgg16,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py index b5606b89c5b744..366ede56ed7aca 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py +++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py @@ -1,5 +1,5 @@ """ -Update commited CSV files used as reference points by dynamo/inductor CI. +Update committed CSV files used as reference points by dynamo/inductor CI. Currently only cares about graph breaks, so only saves those columns. diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index c30df0cd3f2783..1088634ce911e8 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -14,6 +14,7 @@ import json import logging import os +import random import shutil import signal import subprocess @@ -50,7 +51,7 @@ try: from torch._dynamo.utils import clone_inputs, graph_break_reasons - from torch._inductor.utils import fresh_inductor_cache + from torch._inductor.utils import fresh_cache except ImportError: from _dynamo.utils import clone_inputs, graph_break_reasons @@ -66,7 +67,7 @@ import torch_xla import torch_xla.core.xla_model as xm - # This is to woraround the backward issue https://github.com/pytorch/xla/issues/4174 + # This is to workaround the backward issue https://github.com/pytorch/xla/issues/4174 torch_xla._XLAC._init_computation_client() except ImportError: # ignore the error if torch_xla is not installed @@ -269,7 +270,7 @@ class CI(NamedTuple): # Maps a benchmark model name to a list of status codes. For any listed entry, we'll -# capture TORCH_COMPILE_DEBUG logs in CI runs and preseve them (i.e., for upload) if +# capture TORCH_COMPILE_DEBUG logs in CI runs and preserve them (i.e., for upload) if # the result status matches one listed. CI_PRESERVE_COMPILE_DEBUG = { # For example: @@ -559,7 +560,7 @@ def nothing(f): return f -@functools.lru_cache(None) +@functools.cache def patch_torch_manual_seed(): """Make torch manual seed deterministic. Helps with accuracy testing.""" @@ -690,17 +691,52 @@ def timed( times=1, return_result=False, collect_outputs=False, + batch_size=None, ): use_xla = tensor_is_on_xla(example_inputs) synchronize() + if batch_size: + patch_torch_manual_seed() + if use_xla: xm.mark_step() xm.wait_device_ops() + def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor: + for i, s in enumerate(t.size()): + if s == batch_size: + # If new batch is smaller, we truncate + if new_batch_size < batch_size: + indexer = [slice(None)] * t.ndim + indexer[i] = slice(0, new_batch_size) + t = t[tuple(indexer)] + # If new batch is greater, we just duplicate the last row + # over and over until we hit the desired batch size + elif new_batch_size > batch_size: + indexer = [slice(None)] * t.ndim + indexer[i] = -1 + last_slice = t[tuple(indexer)].unsqueeze(i) + repeat_shape = list(t.shape) + repeat_shape[i] = new_batch_size - batch_size + padding = last_slice.expand(*repeat_shape) + t = torch.cat([t, padding], dim=i) + break + return t + time_total = 0 # Dont collect outputs to correctly measure timing for _ in range(times): + # If batch_size is 1, it too often collides with other non batch size + # dimensions resulting in errors. + if batch_size and batch_size > 1: + # Calculate new batch size by varying the original batch size by up to 20% + # Ensure it's at least greater than 1 + variation = random.uniform(0.8, 1.2) + new_batch_size = max(2, int(batch_size * variation)) + example_inputs = tree_map_only( + torch.Tensor, lambda x: vary_batch(x, new_batch_size), example_inputs + ) # Put this call inside the loop to reset the seed for each iteration. # Don't include reset_rng_state() to correctly measure timing reset_rng_state(use_xla) @@ -1018,9 +1054,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): Writes to ./speedups.csv """ - # if args.dynamic_shapes: - # return speedup_experiment_ds(args, model_iter_fn, model, example_inputs) - timings = np.zeros((args.repeat, 2), np.float64) # if we randomize the input, we should also check the result is correct should_randomize_input = args.randomize_input @@ -1041,7 +1074,7 @@ def maybe_mark_profile(*args, **kwargs): times = args.iterations_per_run - # Use higher tolerance for XLA since XLA cause numerical unstability when + # Use higher tolerance for XLA since XLA cause numerical instability when # graph size changes tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4 torch._dynamo.config.repro_tolerance = tolerance @@ -1074,6 +1107,7 @@ def maybe_mark_profile(*args, **kwargs): return_result=True, times=times, collect_outputs=args.collect_outputs, + batch_size=kwargs.get("batch_size"), ) # call mark_step between the 2 calls to make the comparison fair. @@ -1179,82 +1213,6 @@ def maybe_mark_profile(*args, **kwargs): return msg -# WARNING: This code is currently dead -def speedup_experiment_ds(args, model_iter_fn, model, example_inputs): - """ - Run dynamic shapes benchmarks. - - Requires dynamic shape compatible models, which provide a list of example inputs. - - Warms up using the first input example and then iterates the inputs, - measuring (and expecting minimal) variance between the runtime for different examples. - - """ - timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64) - - if args.repeat > 5: - print( - f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n" - ) - - nwarmup = 4 - for rep in range(args.repeat): - # Start each rep fresh, e.g. only warmup on example 0 - torch._dynamo.reset() - optimized_model_iter_fn = optimize_ctx(model_iter_fn) - for _ in range(nwarmup): - optimized_model_iter_fn(model, example_inputs[0]) - - for input_idx, inputs in enumerate(example_inputs): - # interleave the runs to handle frequency scaling and load changes - timings[rep, input_idx, 0] = timed( - model, model_iter_fn, inputs, return_result=False - ) - # different from regular speedup_experiment, we _DO_ want to allow recompilation - timings[rep, input_idx, 1] = timed( - model, optimized_model_iter_fn, inputs, return_result=False - ) - medians = np.median(timings, axis=0) - speedups = list(medians[:, 0] / medians[:, 1]) - speedups_mean = np.mean(speedups) - speedups_median = np.median(speedups) - speedups_var = np.var(speedups) - - # TODO this x[0] is not going to work in general but bert only has 1 input - shapes = [x[0].shape for x in example_inputs] - shape_keys = sorted(set(shapes)) - shape_speedups = { - shape: [ - it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups)) - ] - for shape in shape_keys - } - output_str = ( - f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}" - + "\nSpeedups by shape: " - + "\n".join( - [ - f"{shape}: " - + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]]) - for shape in shape_keys - ] - ) - ) - write_outputs( - output_filename, - ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"), - [ - current_device, - current_name, - current_batch_size, - speedups_mean, - speedups_median, - speedups_var, - ], - ) - return output_str - - def overhead_experiment(*args, model_iter_fn): """ Measure overheads of TorchDynamo by running with no backend (only @@ -1722,7 +1680,7 @@ def setup_amp(self, current_device=None): devices = [current_device] if current_device else self.args.devices if self.args.amp: - # AMP training can lead to small loss values which can undeflow + # AMP training can lead to small loss values which can underflow # gradient values returning in zero gradients. To solve this # problem, PyTorch introduces GradScaler. GradScaler is a stateful # structure, that scales the loss values to prevent underflow. Loss @@ -1760,7 +1718,7 @@ def init_optimizer(self, name, device, params): self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True) # Disable multi_tensor_sgd for benchmarking, there isn't a large performance benefit (~1%) to compiling # this optimizer because it is a single foreach add, and increases compile time. - # After autotuning and fake tensor caching lands, we can enable, becuase the compile time impact will be lower. + # After autotuning and fake tensor caching lands, we can enable, because the compile time impact will be lower. # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873 # Autotuning: https://github.com/pytorch/pytorch/issues/117447 self.optimizer.step = torch._dynamo.disable(self.optimizer.step) @@ -2557,7 +2515,14 @@ def warmup(fn, model, example_inputs, mode, niters=10): return " ".join(map(str, results)) def run_performance_test( - self, name, model, example_inputs, optimize_ctx, experiment, tag=None + self, + name, + model, + example_inputs, + optimize_ctx, + experiment, + tag=None, + batch_size=None, ): if self.args.xla: with self.pick_grad(name, self.args.training): @@ -2615,6 +2580,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): with self.pick_grad(name, self.args.training), ctx: ok, total = Stats.reset_counters() experiment_kwargs = {} + experiment_kwargs["batch_size"] = batch_size if tag is not None: experiment_kwargs["tag"] = tag results = [] @@ -2778,6 +2744,7 @@ def run_one_model( experiment, explain=False, tag=None, + batch_size=None, ): mode = "train" if self.args.training else "eval" msg = f"{current_device:4} {mode:5} {current_name:34} " @@ -2806,7 +2773,13 @@ def run_one_model( ) else: status = self.run_performance_test( - name, model, example_inputs, optimize_ctx, experiment, tag + name, + model, + example_inputs, + optimize_ctx, + experiment, + tag, + batch_size=batch_size, ) print(status) empty_gpu_cache(current_device) @@ -2850,7 +2823,7 @@ def add_double_quotes(x): ) # NB: Don't upload them to the benchmark database as they are debugging - # infomation. There are also around a million records a day which is + # information. There are also around a million records a day which is # wasteful to store write_outputs( filename, @@ -2908,7 +2881,7 @@ def parse_args(args=None): iterations_per_run_help = """ Run this may iterations for each time measurement. This is mainly used for XLA training. We want to run multiple iterations per measurement so the - tracing and computation for different iteartions can overlap with each + tracing and computation for different iterations can overlap with each other. This makes sure we have an accurate xla baseline. """ parser.add_argument( @@ -3067,7 +3040,7 @@ def get_example_inputs(self): parser.add_argument( "--generate-aot-autograd-stats", action="store_true", - help="Generates AOT Autograd stats like how mnay graphs are sent to AOT", + help="Generates AOT Autograd stats like how many graphs are sent to AOT", ) parser.add_argument( "--inductor-settings", @@ -3288,7 +3261,7 @@ def get_example_inputs(self): "--warm-start-latency", "--warm_start_latency", action="store_true", - help="Run model(s) twice and preseve caches in between to enable a 'warm start' on the 2nd run", + help="Run model(s) twice and preserve caches in between to enable a 'warm start' on the 2nd run", ) group_fuser = parser.add_mutually_exclusive_group() @@ -3443,7 +3416,7 @@ def maybe_fresh_cache(args): if not cache_dir_assigned and ( args.cold_start_latency or args.warm_start_latency or args.ci ): - return fresh_inductor_cache() + return fresh_cache() else: return contextlib.nullcontext() @@ -3637,7 +3610,7 @@ def run(runner, args, original_dir=None): torch.backends.mkldnn.deterministic = True - # Remove randomeness when torch manual seed is called + # Remove randomness when torch manual seed is called patch_torch_manual_seed() # Some models e.g. yolov3 assert batch size on n_gpus @@ -4143,6 +4116,7 @@ def detect_and_mark_batch(t): experiment, explain=args.explain, tag=args.tag, + batch_size=batch_size if args.dynamic_batch_only else None, ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" diff --git a/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py b/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py index 53879f5e8c0eee..e575102bd1a401 100644 --- a/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py +++ b/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py @@ -3,7 +3,7 @@ import torch.fx from torch._dynamo.utils import counters -from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache +from torch._inductor.utils import clear_caches, fresh_cache N = 10000 @@ -20,7 +20,7 @@ def main(): torch._inductor.config.fx_graph_cache = True torch._inductor.config.fx_graph_remote_cache = False - with fresh_inductor_cache(): + with fresh_cache(): a = torch.randn(4).cuda() compiled_fn = torch.compile(huge_graph, backend="inductor") @@ -30,7 +30,7 @@ def main(): def setup(): torch._dynamo.reset() - clear_inductor_caches() + clear_caches() for m in torch._inductor.codecache.PyCodeCache.cache.values(): os.remove(m.__file__) counters.clear() diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 36a212625f177d..f1f9ea9b30bab2 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -135,7 +135,7 @@ def contains_tensor_types(type): ) -@functools.lru_cache(None) +@functools.cache def non_compute_operator(op): schema = op._schema @@ -274,7 +274,7 @@ def get_inputs_for_operator( yield return - # line[1] represents number of times these inputs occured, ignored for now + # line[1] represents number of times these inputs occurred, ignored for now for line in self.operator_db[str(operator)].items(): inps = line[0] diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py index e805b7ff6b380e..67c39fc615d156 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py @@ -3,7 +3,7 @@ from benchmark_base import BenchmarkBase import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class Benchmark(BenchmarkBase): @@ -50,7 +50,7 @@ def f(a, b): result = result.sin() return result - with fresh_inductor_cache(): + with fresh_cache(): f(self.a, self.b) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py index b03d62924c7724..7d76597810a62a 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class ListOfLinears(nn.Module): @@ -55,7 +55,7 @@ def _prepare(self): def _work(self): with ( - fresh_inductor_cache(), + fresh_cache(), torch._inductor.config.patch(force_shape_pad=self._force_shape_pad), ): opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())( diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/dynamo_inline.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/dynamo_inline.py index 8d2ce45cc02ce7..13efd583efeee4 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/dynamo_inline.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/dynamo_inline.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache # Create a chain of artificial nesting @@ -94,7 +94,7 @@ def _work(self): # enable_cpp_symbolic_shape_guards has impact on this benchmark # Keep using False value for consistency. with ( - fresh_inductor_cache(), + fresh_cache(), ): opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())( self.m.cuda() if self._is_gpu else self.m diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/float_args.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/float_args.py index 640557e6f11d58..73d96ff34a6d57 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/float_args.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/float_args.py @@ -3,7 +3,7 @@ from benchmark_base import BenchmarkBase import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class Benchmark(BenchmarkBase): @@ -31,7 +31,7 @@ def _work(self): def f(x, y): return x + y - with fresh_inductor_cache(): + with fresh_cache(): for i in range(8): f(torch.arange(3), i * 2.5) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py index 1939a2cd3f177c..51c2ecf034706b 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/mm_loop.py @@ -3,7 +3,7 @@ from benchmark_base import BenchmarkBase import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class Benchmark(BenchmarkBase): @@ -45,7 +45,7 @@ def f(a, b): z = torch.mm(z, b) return z - with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True): + with fresh_cache(), torch._inductor.config.patch(max_autotune=True): f(self.a, self.b) diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/nested_module.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/nested_module.py index e28f5cd256c801..bf60b418266b76 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/nested_module.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/nested_module.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class NestedModule(nn.Module): @@ -67,7 +67,7 @@ def _work(self): # enable_cpp_symbolic_shape_guards has impact on this benchmark # Keep using False value for consistency. with ( - fresh_inductor_cache(), + fresh_cache(), ): opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())( self.m.cuda() if self._is_gpu else self.m diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 2088dcf6d50f25..edc9d0f73d1618 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,77 +1,89 @@ -add_loop_eager,compile_time_instruction_count,2953000000,0.015 +add_loop_eager,compile_time_instruction_count,3017000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 -add_loop_inductor,compile_time_instruction_count,29370000000,0.015 +add_loop_inductor,compile_time_instruction_count,29490000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000000,0.2 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 -update_hint_regression,compile_time_instruction_count,1700000000,0.02 +update_hint_regression,compile_time_instruction_count,1673000000,0.02 -float_args,compile_time_instruction_count,452500000,0.015 +sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 -sum_floordiv_regression,compile_time_instruction_count,998600000,0.015 +symint_sum,compile_time_instruction_count,3166000000,0.015 -symint_sum,compile_time_instruction_count,3252000000,0.015 +symint_sum_loop,compile_time_instruction_count,4202000000,0.015 -symint_sum_loop,compile_time_instruction_count,4262000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 + + + +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 + + + +basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 + + + +basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index b27425b1283fc8..714cf9901b61e2 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -269,7 +269,7 @@ def parse_args(): "--no-graphs", action="store_true", default=False, - help="Do not genenerate and upload metric graphs", + help="Do not generate and upload metric graphs", ) parser.add_argument( "--no-update-archive", @@ -368,7 +368,7 @@ def get_mode(args): def get_skip_tests(suite, device, is_training: bool): """ - Generate -x seperated string to skip the unusual setup training tests + Generate -x separated string to skip the unusual setup training tests """ skip_tests = set() original_dir = abspath(os.getcwd()) @@ -550,7 +550,7 @@ def env_var(name): gh_fh.write(comment) -@functools.lru_cache(None) +@functools.cache def archive_data(archive_name): if archive_name is not None: prefix_match = re.search(r"\w+(?=_performance)", archive_name) @@ -570,7 +570,7 @@ def archive_data(archive_name): return day, prefix -@functools.lru_cache(None) +@functools.cache def default_archive_name(dtype): _, prefix = archive_data(None) return f"{prefix}_performance_{dtype}_{randint(100, 999)}" @@ -1359,7 +1359,7 @@ def update_lookup_file(self): dtype = self.args.dtypes[0] day, _ = archive_data(self.args.archive_name) target_dir = get_archive_name(self.args, dtype) - # Update lookup csv the folder to arhived logs + # Update lookup csv the folder to archived logs subprocess.check_call( f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}', shell=True, @@ -1418,7 +1418,7 @@ def gen_comment(self): def comment_on_gh(self, comment): """ - Send a commment to dashboard + Send a comment to dashboard """ with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(comment) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index 51d5800cdadf2e..f0d5bc103f47b7 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -81,6 +81,7 @@ def pip_install(package): "sebotnet33ts_256", "selecsls42b", "convnext_base", + "cait_m36_384", } REQUIRE_HIGHER_TOLERANCE_AMP = { @@ -129,6 +130,7 @@ def pip_install(package): "mobilenetv3_large_100", "cspdarknet53", "gluon_inception_v3", + "cait_m36_384", } diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 796bb8529f1b2c..304bcdec4f4814 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -229,7 +229,7 @@ skip: - doctr_det_predictor - doctr_reco_predictor - moondream - # doesnt fit in memory + # doesn't fit in memory - phi_1_5 - detectron2_fcos_r_50_fpn diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py index b17a475b631ba7..fa19c10fd8a5a6 100644 --- a/benchmarks/fastrnns/factory.py +++ b/benchmarks/fastrnns/factory.py @@ -225,7 +225,7 @@ def varlen_lstm_inputs( return x, lengths, (hx, cx), lstm.all_weights, lstm else: # NB: lstm.all_weights format: - # wih, whh, bih, bhh = lstm.all_weights[layer] + # w_ih, w_hh, b_ih, b_hh = lstm.all_weights[layer] return x, lengths, (hx, cx), lstm.all_weights, None @@ -266,10 +266,10 @@ def varlen_lstm_factory(cell, script): def dynamic_rnn( sequences: list[Tensor], hiddens: tuple[Tensor, Tensor], - wih: Tensor, - whh: Tensor, - bih: Tensor, - bhh: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, ) -> tuple[list[Tensor], tuple[list[Tensor], list[Tensor]]]: hx, cx = hiddens hxs = hx.unbind(1) @@ -286,7 +286,7 @@ def dynamic_rnn( for seq_idx in range(len(inputs)): hy, cy = cell( - inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh + inputs[seq_idx].unsqueeze(0), (hy, cy), w_ih, w_hh, b_ih, b_hh ) output += [hy] outputs += [torch.stack(output)] @@ -315,7 +315,7 @@ def varlen_lstm_creator(script=False, **kwargs): # cudnn_layernorm_lstm: since cudnn does not have Layernorm LSTM, we cannot benchmark -# the lowerbound directly. Instead, we only benchmark the forward pass by mimicing the +# the lowerbound directly. Instead, we only benchmark the forward pass by mimicking the # computation of a cudnn lstm + seq_len * 3 layernorm computation. This should serve # as a perf lowerbound for the Layernorm LSTM forward pass(given that Layernorm itself # is invariant), the lowerbound of backward pass is hard to get since we lose the @@ -352,12 +352,12 @@ def forward(input, hidden): ) -# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer]) +# input: lstm.all_weights format (w_ih, w_hh, b_ih, b_hh = lstm.all_weights[layer]) # output: packed_weights with format -# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize) -# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize) -# packed_weights[2] is bih with size (layer, 4*hiddenSize) -# packed_weights[3] is bhh with size (layer, 4*hiddenSize) +# packed_weights[0] is w_ih with size (layer, 4*hiddenSize, inputSize) +# packed_weights[1] is w_hh with size (layer, 4*hiddenSize, hiddenSize) +# packed_weights[2] is b_ih with size (layer, 4*hiddenSize) +# packed_weights[3] is b_hh with size (layer, 4*hiddenSize) def stack_weights(weights): def unzip_columns(mat): assert isinstance(mat, list) @@ -398,7 +398,7 @@ def lstm_inputs( return x, (hx, cx), lstm.all_weights, lstm else: # NB: lstm.all_weights format: - # wih, whh, bih, bhh = lstm.all_weights[layer] + # w_ih, w_hh, b_ih, b_hh = lstm.all_weights[layer] return x, (hx, cx), lstm.all_weights, None @@ -406,17 +406,17 @@ def lstm_factory(cell, script): def dynamic_rnn( input: Tensor, hidden: tuple[Tensor, Tensor], - wih: Tensor, - whh: Tensor, - bih: Tensor, - bhh: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] inputs = input.unbind(0) hy, cy = hx[0], cx[0] for seq_idx in range(len(inputs)): - hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh) + hy, cy = cell(inputs[seq_idx], (hy, cy), w_ih, w_hh, b_ih, b_hh) outputs += [hy] return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0)) @@ -432,17 +432,17 @@ def lstm_factory_premul(premul_cell, script): def dynamic_rnn( input: Tensor, hidden: tuple[Tensor, Tensor], - wih: Tensor, - whh: Tensor, - bih: Tensor, - bhh: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] - inputs = torch.matmul(input, wih.t()).unbind(0) + inputs = torch.matmul(input, w_ih.t()).unbind(0) hy, cy = hx[0], cx[0] for seq_idx in range(len(inputs)): - hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh) + hy, cy = premul_cell(inputs[seq_idx], (hy, cy), w_hh, b_ih, b_hh) outputs += [hy] return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0)) @@ -458,10 +458,10 @@ def lstm_factory_premul_bias(premul_cell, script): def dynamic_rnn( input: Tensor, hidden: tuple[Tensor, Tensor], - wih: Tensor, - whh: Tensor, - bih: Tensor, - bhh: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, ) -> tuple[Tensor, tuple[Tensor, Tensor]]: hx, cx = hidden outputs = [] @@ -470,11 +470,11 @@ def dynamic_rnn( # FIXME matmul(x,y) + bias currently goes through jit AD, and backward formula in AD is not optimized for this # case. Workaround with mm and views. inpSize = input.size() - inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih + inputs = torch.mm(input.view(-1, inpSize[2]), w_ih.t()) + b_ih inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0) hy, cy = hx[0], cx[0] for seq_idx in range(len(inputs)): - hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh) + hy, cy = premul_cell(inputs[seq_idx], (hy, cy), w_hh, b_hh) outputs += [hy] return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0)) @@ -488,12 +488,12 @@ def dynamic_rnn( # simple: flat inputs (no tuples), no list to accumulate outputs # useful mostly for benchmarking older JIT versions def lstm_factory_simple(cell, script): - def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh): + def dynamic_rnn(input, hx, cx, w_ih, w_hh, b_ih, b_hh): hy = hx # for scoping cy = cx # for scoping inputs = input.unbind(0) for seq_idx in range(len(inputs)): - hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh) + hy, cy = cell(inputs[seq_idx], hy, cy, w_ih, w_hh, b_ih, b_hh) return hy, cy if script: @@ -515,12 +515,12 @@ def dynamic_rnn( hy = hx[layer] cy = cx[layer] base_idx = layer * params_stride - wih = params[base_idx] - whh = params[base_idx + 1] - bih = params[base_idx + 2] - bhh = params[base_idx + 3] + w_ih = params[base_idx] + w_hh = params[base_idx + 1] + b_ih = params[base_idx + 2] + b_hh = params[base_idx + 3] for seq_idx in range(len(inputs)): - hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh) + hy, cy = cell(inputs[seq_idx], (hy, cy), w_ih, w_hh, b_ih, b_hh) outputs += [hy] inputs, outputs = outputs, [] return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0)) diff --git a/benchmarks/fastrnns/test.py b/benchmarks/fastrnns/test.py index 36a5db23c1b4af..bf76b934ad1570 100644 --- a/benchmarks/fastrnns/test.py +++ b/benchmarks/fastrnns/test.py @@ -51,34 +51,34 @@ def test_rnns( print("Setting up...") control = control_creator(**creator_args) - experim = experim_creator(**creator_args) + experiment = experim_creator(**creator_args) # Precondition - assertEqual(experim.inputs, control.inputs) - assertEqual(experim.params, control.params) + assertEqual(experiment.inputs, control.inputs) + assertEqual(experiment.params, control.params) print("Checking outputs...") control_outputs = control.forward(*control.inputs) - experim_outputs = experim.forward(*experim.inputs) + experim_outputs = experiment.forward(*experiment.inputs) assertEqual(experim_outputs, control_outputs) print("Checking grads...") assert control.backward_setup is not None - assert experim.backward_setup is not None + assert experiment.backward_setup is not None assert control.backward is not None - assert experim.backward is not None + assert experiment.backward is not None control_backward_inputs = control.backward_setup(control_outputs, seed) - experim_backward_inputs = experim.backward_setup(experim_outputs, seed) + experim_backward_inputs = experiment.backward_setup(experim_outputs, seed) control.backward(*control_backward_inputs) - experim.backward(*experim_backward_inputs) + experiment.backward(*experim_backward_inputs) control_grads = [p.grad for p in control.params] - experim_grads = [p.grad for p in experim.params] + experim_grads = [p.grad for p in experiment.params] assertEqual(experim_grads, control_grads) if verbose: - print(experim.forward.graph_for(*experim.inputs)) + print(experiment.forward.graph_for(*experiment.inputs)) print() @@ -103,16 +103,16 @@ def test_vl_py(**test_args): print("Setting up...") control = control_creator(**creator_args) - experim = experim_creator(**creator_args) + experiment = experim_creator(**creator_args) # Precondition - assertEqual(experim.inputs, control.inputs[:2]) - assertEqual(experim.params, control.params) + assertEqual(experiment.inputs, control.inputs[:2]) + assertEqual(experiment.params, control.params) print("Checking outputs...") control_out, control_hiddens = control.forward(*control.inputs) control_hx, control_cx = control_hiddens - experim_out, experim_hiddens = experim.forward(*experim.inputs) + experim_out, experim_hiddens = experiment.forward(*experiment.inputs) experim_hx, experim_cx = experim_hiddens experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2) @@ -122,25 +122,25 @@ def test_vl_py(**test_args): print("Checking grads...") assert control.backward_setup is not None - assert experim.backward_setup is not None + assert experiment.backward_setup is not None assert control.backward is not None - assert experim.backward is not None + assert experiment.backward is not None control_backward_inputs = control.backward_setup( (control_out, control_hiddens), test_args["seed"] ) - experim_backward_inputs = experim.backward_setup( + experim_backward_inputs = experiment.backward_setup( (experim_out, experim_hiddens), test_args["seed"] ) control.backward(*control_backward_inputs) - experim.backward(*experim_backward_inputs) + experiment.backward(*experim_backward_inputs) control_grads = [p.grad for p in control.params] - experim_grads = [p.grad for p in experim.params] + experim_grads = [p.grad for p in experiment.params] assertEqual(experim_grads, control_grads) if test_args["verbose"]: - print(experim.forward.graph_for(*experim.inputs)) + print(experiment.forward.graph_for(*experiment.inputs)) print() diff --git a/benchmarks/functional_autograd_benchmark/torchvision_models.py b/benchmarks/functional_autograd_benchmark/torchvision_models.py index 25dd91c02d6ae7..ed236078830e7a 100644 --- a/benchmarks/functional_autograd_benchmark/torchvision_models.py +++ b/benchmarks/functional_autograd_benchmark/torchvision_models.py @@ -885,7 +885,7 @@ def __init__( self.cost_bbox = cost_bbox self.cost_giou = cost_giou assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( - "all costs cant be 0" + "all costs can't be 0" ) @torch.no_grad() @@ -920,13 +920,13 @@ def forward(self, outputs, targets): # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. - # The 1 is a constant that doesn't change the matching, it can be ommitted. + # The 1 is a constant that doesn't change the matching, it can be omitted. cost_class = -out_prob[:, tgt_ids] # Compute the L1 cost between boxes cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) - # Compute the giou cost betwen boxes + # Compute the giou cost between boxes cost_giou = -generalized_box_iou( box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) ) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 8b4e4a550b9926..9b4fa22452c05b 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -44,7 +44,7 @@ def device_sync(device): elif "cpu" in device: pass else: - print(f"device={device} is not yet suppported") + print(f"device={device} is not yet supported") def get_arch_name() -> str: diff --git a/benchmarks/inductor_backends/cutlass.py b/benchmarks/inductor_backends/cutlass.py index b90620373adc2e..7141872ec3c43f 100644 --- a/benchmarks/inductor_backends/cutlass.py +++ b/benchmarks/inductor_backends/cutlass.py @@ -1,4 +1,5 @@ import os +import sys os.environ["TORCH_LOGS"] = "inductor" @@ -17,6 +18,7 @@ import torch from torch._inductor import config as inductor_config +from torch.testing._internal.inductor_utils import _quantize_rowwise log: logging.Logger = logging.getLogger(__name__) @@ -28,18 +30,21 @@ # uncomment for better debugging # inductor_config.force_disable_caches = True +USE_FAST_ACCUM = True UNITS = { "name": "", "forward_time": " (us)", + "teraflops": " (TFLOPS)", "compilation_time": " (s)", } PERF_OVER_ATEN_STR: str = "perf_over_aten (%)" OP_NAMES = [ "mm", - "addmm", - "bmm", + # "addmm", + # "bmm", + # "_scaled_mm", ] SHAPES = [ @@ -57,6 +62,7 @@ DTYPES = [ torch.float16, torch.bfloat16, + # torch.float8_e4m3fn, ] # triton knobs @@ -70,12 +76,13 @@ "0", # "1111", # "2222", - "3333", + "3332", + # "9992", ] def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float: - return do_bench(lambda: func(*args, **kwargs)) * 1e3 + return do_bench(lambda: func(*args, **kwargs), warmup=100, rep=10000) * 1e3 @dataclass(frozen=True, kw_only=True) @@ -162,6 +169,7 @@ def name(self) -> str: class ExperimentResults: name: str forward_time: float + teraflops: float compilation_time: float def asdict(self): @@ -196,6 +204,34 @@ def get_inputs( A = torch.randn(batch_size, M, K, dtype=dtype, device=device) B = torch.randn(batch_size, N, K, dtype=dtype, device=device).permute(0, 2, 1) return A, B + elif op_name == "_scaled_mm": + # For _scaled_mm, we only support fp8e4m3 with rowwise scaling + if dtype != torch.float8_e4m3fn: + raise ValueError(f"_scaled_mm only supports fp8e4m3, got {dtype}") + + # Create input tensors in bfloat16 first, then quantize to fp8 + input_dtype = torch.bfloat16 + x = torch.randn(M, K, dtype=input_dtype, device=device) + w = torch.randn(N, K, dtype=input_dtype, device=device) + + # Quantize using rowwise scaling + w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype) + w_t_fp8 = w_fp8.t() + w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) + + x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype) + + # Return inputs for _scaled_mm: (input, weight_t, scale_a, scale_b, bias, out, out_dtype, use_fast_accum) + return ( + x_fp8, + w_t_fp8, + x_inverse_scale, + w_inverse_scale, + None, + None, + torch.bfloat16, + USE_FAST_ACCUM, + ) else: raise ValueError(f"Unknown op {op_name}") @@ -210,8 +246,11 @@ def run_single_experiment_group( for config in group_config.experiments: torch._dynamo.reset() - torch._inductor.utils.clear_inductor_caches() - compiled_op = torch.compile(op, fullgraph=True, options=config.to_options()) + torch._inductor.utils.clear_caches() + compiled_op = torch.compile( + op, + options=config.to_options(), + ) start_time = time.perf_counter() try: @@ -227,6 +266,7 @@ def run_single_experiment_group( ExperimentResults( name=config.name(), forward_time=float("inf"), + teraflops=0.0, compilation_time=float("inf"), ) ) @@ -238,10 +278,18 @@ def run_single_experiment_group( *inputs, ) + flops = calculate_flops( + group_config.op_name, + group_config.shape, + group_config.batch_size, + ) + teraflops = flops / (forward_time * 1e-6) / 1e12 + results.append( ExperimentResults( name=config.name(), forward_time=forward_time, + teraflops=teraflops, compilation_time=compilation_time, ) ) @@ -336,6 +384,22 @@ def calculate_table_data(results: list[ExperimentResults]) -> dict: return table_data +def calculate_flops(op_name: str, shape: tuple[int, int, int], batch_size: int) -> int: + """ + Calculate the number of floating point operations based on operation type and shape. + """ + M, N, K = shape + + if op_name == "bmm": + return 2 * batch_size * M * N * K + elif op_name == "addmm": + return 2 * M * N * K + M * N + elif op_name == "_scaled_mm": + return 2 * M * N * K + else: + return 2 * M * N * K + + def get_printable_results(experiment_groups: list[ExperimentGroup]) -> list[str]: edge_over_aten = defaultdict(list) output = [] @@ -390,8 +454,10 @@ def main(): results.append( ExperimentGroup(config=group_config, results=group_results), ) - log.info(f"\nINTERMEDIATE results: {i}/{len(configs)}") # noqa: G004 - log.info(get_printable_results(results)) + sys.stderr.write( + f"\nINTERMEDIATE results: {i + 1}/{len(configs)} \n" + + get_printable_results(results) + ) print("\nFINAL results...") print(get_printable_results(results)) diff --git a/benchmarks/inference/README.md b/benchmarks/inference/README.md index fe707799c098e4..0517313cbb770a 100644 --- a/benchmarks/inference/README.md +++ b/benchmarks/inference/README.md @@ -20,7 +20,7 @@ For now we omit data preprocessing as well as result post-processing. ### Running a single benchmark -The togglable commmand line arguments to the script are as follows: +The togglable command line arguments to the script are as follows: - `num_iters` (default: 100): how many requests to send to the backend excluding the first warmup request - `batch_size` (default: 32): the batch size of the requests. diff --git a/benchmarks/inference/server.py b/benchmarks/inference/server.py index 706811b134e458..7cf7b940687fcc 100644 --- a/benchmarks/inference/server.py +++ b/benchmarks/inference/server.py @@ -45,7 +45,7 @@ def _run_metrics(self, metrics_lock): """ This function will poll the response queue until it has received all responses. It records the startup latency, the average, max, min latency - as well as througput of requests. + as well as throughput of requests. """ warmup_response_time = None response_times = [] diff --git a/benchmarks/instruction_counts/applications/ci.py b/benchmarks/instruction_counts/applications/ci.py index 4c9517b0f89727..7a4f39b093d98b 100644 --- a/benchmarks/instruction_counts/applications/ci.py +++ b/benchmarks/instruction_counts/applications/ci.py @@ -55,7 +55,7 @@ def main(argv: list[str]) -> None: results = Runner(work_orders, cadence=30.0).run() - # TODO: Annotate with TypedDict when 3.8 is the minimum supported verson. + # TODO: Annotate with TypedDict when 3.8 is the minimum supported version. grouped_results: dict[str, dict[str, list[Union[float, int]]]] = { key: {"times": [], "counts": []} for key in keys } diff --git a/benchmarks/instruction_counts/main.py b/benchmarks/instruction_counts/main.py index 09869bf6710d57..16b0ba4397cb08 100644 --- a/benchmarks/instruction_counts/main.py +++ b/benchmarks/instruction_counts/main.py @@ -2,7 +2,7 @@ The contents of this file are placeholders, and will be replaced by more expressive and robust components (e.g. better runner and result display -components) in future iterations. However this allows us to excercise the +components) in future iterations. However this allows us to exercise the underlying benchmark generation infrastructure in the mean time. """ diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 47f346f2933cd3..cb836bb5eaa4bf 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -6,6 +6,8 @@ import os import timeit from collections import namedtuple +from dataclasses import asdict, dataclass +from typing import Any, Optional import benchmark_utils @@ -191,9 +193,8 @@ def __init__(self, args): self.use_jit = args.use_jit self.num_runs = args.num_runs self.print_per_iter = False - self.output_dir = args.output_dir + self.output_csv = args.output_csv self.operator_range = benchmark_utils.get_operator_range(args.operator_range) - self.disable_output = args.disable_output # 100 is the default warmup iterations if self.args.warmup_iterations == -1: self.args.warmup_iterations = 100 @@ -457,8 +458,6 @@ def _print_test_case_info(self, test_case): return False def _output_csv(self, filename, headers, row): - if self.args.disable_output is True: - return if os.path.exists(filename): with open(filename) as fd: lines = list(csv.reader(fd)) or [[]] @@ -475,19 +474,101 @@ def _output_csv(self, filename, headers, row): for line in lines: writer.writerow(list(line) + ["0"] * (len(headers) - len(line))) + def _output_json( + self, + perf_list, + output_file, + ): + """ + Write the result into JSON format, so that it can be uploaded to the benchmark database + to be displayed on OSS dashboard. The JSON format is defined at + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + if not perf_list: + return + + # Prepare headers and records for JSON output + records = [] + for perf_item in perf_list: + # Extract data from perf_item + test_name = perf_item.get("test_name", "unknown") + input_config = perf_item.get("input_config", "") + run_type = perf_item.get("run") + latency = perf_item.get("latency", 0) + + dtype = "float32" # default + + # Extract mode based on run_type + mode = None + if run_type == "Forward": + mode = "inference" + elif run_type == "Backward": + mode = "training" + + # Create the record + @dataclass + class BenchmarkInfo: + name: str + mode: Optional[str] + dtype: str + extra_info: dict[str, Any] + + @dataclass + class ModelInfo: + name: str + type: str + origins: list[str] + + @dataclass + class MetricInfo: + name: str + unit: str + benchmark_values: list[float] + target_value: Optional[float] + + @dataclass + class BenchmarkRecord: + benchmark: BenchmarkInfo + model: ModelInfo + metric: MetricInfo + + record = BenchmarkRecord( + benchmark=BenchmarkInfo( + name="PyTorch operator benchmark", + mode=mode, + dtype=dtype, + extra_info={"input_config": input_config}, + ), + model=ModelInfo( + name=test_name, type="micro-benchmark", origins=["pytorch"] + ), + metric=MetricInfo( + name="latency", + unit="us", + benchmark_values=[latency], + target_value=None, + ), + ) + + records.append(asdict(record)) + + # Write all records to the output file + with open(output_file, "w", encoding="utf-8") as f: + json.dump(records, f, indent=2) + def run(self): self._print_header() - output_filename = self.args.output_dir + output_csv_filename = self.args.output_csv headers = [ "Benchmarking Framework", - "Benchamrking Module Name", + "Benchmarking Module Name", "Case Name", "tag", "run_backward", "Execution Time", ] - if self.args.output_json: + if self.args.output_json or self.args.output_json_for_dashboard: perf_list = [] for test_metainfo in BENCHMARK_TESTER: @@ -532,7 +613,7 @@ def run(self): # output results to csv self._output_csv( - output_filename, + output_csv_filename, headers, [ test_case.framework, @@ -547,11 +628,14 @@ def run(self): reported_time[0], ], ) - if self.args.output_json: + if self.args.output_json or self.args.output_json_for_dashboard: perf_list.append( self._perf_result_to_dict(reported_time, test_case) ) + if self.args.output_json_for_dashboard: + self._output_json(perf_list, self.args.output_json_for_dashboard) + if self.args.output_json: with open(self.args.output_json, "w") as f: json.dump(perf_list, f) diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 6d6e0c5cbf8f64..9dfab781498eaf 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -152,13 +152,16 @@ def parse_args(): ) parser.add_argument( - "--output-dir", - help="Choose the output directory to save the logs", + "--output-csv", + "--output_csv", + help="CSV file path to store the results", default="benchmark_logs", ) + parser.add_argument( - "--disable-output", - help="Disable log output to csv file", + "--output-json-for-dashboard", + "--output_json_for_dashboard", + help="Save results in JSON format for display on the OSS dashboard", default="False", ) diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py index a4ff524c9863b1..62ebb76ac58d18 100644 --- a/benchmarks/operator_benchmark/benchmark_utils.py +++ b/benchmarks/operator_benchmark/benchmark_utils.py @@ -134,14 +134,14 @@ def _validate(configs): def config_list(**configs): """Generate configs based on the list of input shapes. This function will take input shapes specified in a list from user. Besides - that, all other parameters will be cross producted first and each of the + that, all other parameters will be cross produced first and each of the generated list will be merged with the input shapes list. Reserved Args: attr_names(reserved): a list of names for input shapes. attrs(reserved): a list of values for each input shape. corss_product: a dictionary of attributes which will be - cross producted with the input shapes. + cross produced with the input shapes. tags(reserved): a tag used to filter inputs. Here is an example: diff --git a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv index 6289dbda597ebb..873f14d20127ca 100644 --- a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv +++ b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv @@ -1,4 +1,4 @@ -Benchmarking Framework,Benchamrking Module Name,Case Name,tag,run_backward,Execution Time +Benchmarking Framework,Benchmarking Module Name,Case Name,tag,run_backward,Execution Time PyTorch,add,add_M1_N1_K1_cpu,short,FALSE,3.9497 PyTorch,add,add_M64_N64_K64_cpu,short,FALSE,14.3181 PyTorch,add,add_M64_N64_K128_cpu,short,FALSE,14.6826 diff --git a/benchmarks/sparse/test_csr.sh b/benchmarks/sparse/test_csr.sh index 7bd43d08d90178..e22c2df6ee54bb 100644 --- a/benchmarks/sparse/test_csr.sh +++ b/benchmarks/sparse/test_csr.sh @@ -8,7 +8,7 @@ echo "----- USE_MKL=1 -----" >> $OUTFILE rm -rf build export USE_MKL=1 -python setup.py build --cmake-only +CMAKE_ONLY=1 python setup.py build ccmake build # or cmake-gui build python setup.py install diff --git a/benchmarks/tensorexpr/benchmark.py b/benchmarks/tensorexpr/benchmark.py index 38951ac5091f26..63b03ea33ffd41 100644 --- a/benchmarks/tensorexpr/benchmark.py +++ b/benchmarks/tensorexpr/benchmark.py @@ -34,7 +34,7 @@ def __init__(self, mode, device, dtype): for method in dir(self.engine): if not callable(getattr(self.engine, method)): continue - # don't forward if this function is overriden here + # don't forward if this function is overridden here if hasattr(self, method): continue # don't forward if it is a internal function @@ -89,7 +89,7 @@ def dtype_to_bytes(self): @staticmethod def default_configs(): - """return a list of defualt configs for this benchmark""" + """return a list of default configs for this benchmark""" raise ValueError("this method should be reimplemented by subclass") def is_supported(self): diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index a2de7538898692..4be4a1e7c46c0a 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -271,9 +271,9 @@ def run_single_backend_sdpa( if config.calculate_bwd_time: # TODO: debug backward pass for njt if eager_sdpa and not config.attn_type == "document_mask": - dOut = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2) + d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2) backward_eager_time = benchmark_torch_function_in_microseconds( - out_eager.backward, dOut, retain_graph=True + out_eager.backward, d_out, retain_graph=True ) else: backward_eager_time = float("nan") @@ -340,9 +340,9 @@ def run_single_backend_FA( if config.calculate_bwd_time: if FA: - dOut = torch.randn_like(out_FA) + d_out = torch.randn_like(out_FA) backward_FA_time = benchmark_torch_function_in_microseconds( - out_FA.backward, dOut, retain_graph=True + out_FA.backward, d_out, retain_graph=True ) else: backward_FA_time = float("nan") @@ -432,9 +432,9 @@ def run_single_experiment( ) if config.calculate_bwd_time: - dOut = torch.randn_like(out_compile) + d_out = torch.randn_like(out_compile) backward_compile_time = benchmark_torch_function_in_microseconds( - out_compile.backward, dOut, retain_graph=True + out_compile.backward, d_out, retain_graph=True ) sparsity = block_mask.sparsity() / 100.0 if block_mask is not None else 0.0 sparsity = sparsity if config.attn_type != "document_mask" else 0.5 diff --git a/benchmarks/transformer/sdpa.py b/benchmarks/transformer/sdpa.py index 8d286561ae0e75..2eca4bf06b444e 100644 --- a/benchmarks/transformer/sdpa.py +++ b/benchmarks/transformer/sdpa.py @@ -172,9 +172,9 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults: out_torch = scaled_dot_product_attention( q, k, v, is_causal=is_causal, attn_mask=None ) - dOut = torch.randn_like(out_torch) + d_out = torch.randn_like(out_torch) backward_time = benchmark_cuda_function_in_microseconds( - out_torch.backward, dOut, retain_graph=True + out_torch.backward, d_out, retain_graph=True ) # Calculate TFLOPS for forward and backward passes diff --git a/buckbuild.bzl b/buckbuild.bzl index 135325dc39ae9d..4eb92674ceec60 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -402,7 +402,7 @@ def get_aten_generated_files(enabled_backends): # This is tiresome. A better strategy would be to unconditionally # generate these files, and then only actually COMPILE them depended - # on the generated set. C'est la vie... + # on the generated set. C'est la vie... # codespell:ignore vie if "CPU" in enabled_backends: src_files.extend(aten_ufunc_generated_cpu_sources()) src_files.extend(aten_ufunc_generated_cpu_kernel_sources()) @@ -525,7 +525,7 @@ def copy_template_registration_files(name, apple_sdks = None): # Ideally, we would run one copy command for a single source directory along # with all its child directories, but it's somewhat hard to know if a directory - # is a child of another just bu looking at the metadata (directory relative + # is a child of another just by looking at the metadata (directory relative # path) that we currently have since 1 directory could look like a parent of # another and yet come from a different filegroup() rule. # @@ -738,7 +738,6 @@ def vulkan_spv_shader_library(name, spv_filegroup): }, cmd = " ".join(genrule_cmd), default_outs = ["."], - labels = ["uses_dotslash"], ) fb_xplat_cxx_library( @@ -777,7 +776,7 @@ def copy_metal(name, apple_sdks = None): # Metal custom ops currently have to be brought into selective build because they directly reference metal ops instead of # going through the dispatcher. There is some weird issues with the genrule and these files locations on windows though, so - # for now we simply skip building them for windows where they very likely arent needed anyway. + # for now we simply skip building them for windows where they very likely aren't needed anyway. # Metal MaskRCNN custom op for full_path in METAL_MASKRCNN_SOURCE_LIST: path_prefix = paths.dirname(full_path) @@ -793,7 +792,7 @@ def copy_metal(name, apple_sdks = None): name = name, cmd = " && ".join(cmd), cmd_exe = "@powershell -Command " + ("; ".join(cmd_exe)), - # due to an obscure bug certain custom ops werent being copied correctly on windows. ARVR also sometimes builds android targets on windows, + # due to an obscure bug certain custom ops weren't being copied correctly on windows. ARVR also sometimes builds android targets on windows, # so we just exclude those targets from being copied for those platforms (They end up uncompiled anyway). outs = select({ "DEFAULT": get_metal_registration_files_outs(), @@ -945,6 +944,7 @@ def define_buck_targets( [ ("torch/csrc/api/include", "torch/**/*.h"), ("", "torch/csrc/**/*.h"), + ("", "torch/headeronly/**/*.h"), ("", "torch/script.h"), ("", "torch/library.h"), ("", "torch/custom_class.h"), @@ -1244,6 +1244,7 @@ def define_buck_targets( "torch/csrc/jit/mobile/parse_operators.cpp", "torch/csrc/jit/mobile/upgrader_mobile.cpp", "torch/csrc/jit/serialization/import_read.cpp", + "torch/csrc/jit/serialization/pickler_helper.cpp", "torch/csrc/jit/serialization/unpickler.cpp", ], header_namespace = "", @@ -1256,11 +1257,11 @@ def define_buck_targets( extra_flags = { "fbandroid_compiler_flags": ["-frtti"], }, - # torch_mobile_deserialize brings in sources neccessary to read a module + # torch_mobile_deserialize brings in sources necessary to read a module # which depends on mobile module definition - # link_whole is enable so that all symbols neccessary for mobile module are compiled + # link_whole is enable so that all symbols necessary for mobile module are compiled # instead of only symbols used while loading; this prevents symbol - # found definied in runtime + # found defined in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, linker_flags = get_no_as_needed_linker_flag(), @@ -1376,11 +1377,11 @@ def define_buck_targets( "torch/csrc/jit/mobile/import.h", "torch/csrc/jit/mobile/flatbuffer_loader.h", ], - # torch_mobile_deserialize brings in sources neccessary to read a module + # torch_mobile_deserialize brings in sources necessary to read a module # which depends on mobile module definition - # link_whole is enable so that all symbols neccessary for mobile module are compiled + # link_whole is enable so that all symbols necessary for mobile module are compiled # instead of only symbols used while loading; this prevents symbol - # found definied in runtime + # found defined in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, linker_flags = get_no_as_needed_linker_flag(), @@ -1407,9 +1408,9 @@ def define_buck_targets( exported_headers = [], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), - # torch_mobile_core brings in sources neccessary to read and run a module + # torch_mobile_core brings in sources necessary to read and run a module # link_whole is enabled so that all symbols linked - # operators, registerations and other few symbols are need in runtime + # operators, registrations and other few symbols are need in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, linker_flags = get_no_as_needed_linker_flag(), @@ -1523,10 +1524,10 @@ def define_buck_targets( ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], - # torch_mobile_train brings in sources neccessary to read and run a mobile + # torch_mobile_train brings in sources necessary to read and run a mobile # and save and load mobile params along with autograd # link_whole is enabled so that all symbols linked - # operators, registerations and autograd related symbols are need in runtime + # operators, registrations and autograd related symbols are need in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, visibility = ["PUBLIC"], @@ -1548,9 +1549,9 @@ def define_buck_targets( ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags(), - # torch brings in all sources neccessary to read and run a mobile module/jit module + # torch brings in all sources necessary to read and run a mobile module/jit module # link_whole is enabled so that all symbols linked - # operators, registerations and other few symbols are need in runtime + # operators, registrations and other few symbols are need in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, visibility = ["PUBLIC"], @@ -1575,7 +1576,7 @@ def define_buck_targets( ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"], - # torch_mobile_train_import_data brings in sources neccessary to read a mobile module + # torch_mobile_train_import_data brings in sources necessary to read a mobile module # link_whole is enabled so that all symbols linked # operators other few symbols are need in runtime # @lint-ignore BUCKLINT link_whole @@ -1654,10 +1655,10 @@ def define_buck_targets( ], compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []), - # torch_mobile_model_tracer brings in sources neccessary to read and run a jit module + # torch_mobile_model_tracer brings in sources necessary to read and run a jit module # and trace the ops # link_whole is enabled so that all symbols linked - # operators, registerations and other few symbols are need in runtime + # operators, registrations and other few symbols are need in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, linker_flags = get_no_as_needed_linker_flag(), @@ -1842,11 +1843,11 @@ def define_buck_targets( extra_flags = { "fbandroid_compiler_flags": ["-frtti"], }, - # torch_mobile_deserialize brings in sources neccessary to read a module + # torch_mobile_deserialize brings in sources necessary to read a module # which depends on mobile module definition - # link_whole is enable so that all symbols neccessary for mobile module are compiled + # link_whole is enable so that all symbols necessary for mobile module are compiled # instead of only symbols used while loading; this prevents symbol - # found definied in runtime + # found defined in runtime # @lint-ignore BUCKLINT link_whole link_whole = True, linker_flags = get_no_as_needed_linker_flag(), diff --git a/build_variables.bzl b/build_variables.bzl index 2b03902dda6198..51854e7c900029 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -89,6 +89,7 @@ core_sources_common = [ torch_unpickler_common = [ "torch/csrc/jit/serialization/import_read.cpp", + "torch/csrc/jit/serialization/pickler_helper.cpp", "torch/csrc/jit/serialization/unpickler.cpp", ] @@ -493,11 +494,10 @@ libtorch_core_sources = sorted( # These files are the only ones that are supported on Windows. libtorch_distributed_base_sources = [ - "torch/csrc/distributed/c10d/Backend.cpp", "torch/csrc/distributed/c10d/Backoff.cpp", - "torch/csrc/distributed/c10d/DMAConnectivity.cpp", - "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", + "torch/csrc/distributed/c10d/Backend.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", + "torch/csrc/distributed/c10d/FlightRecorder.cpp", "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", "torch/csrc/distributed/c10d/GroupRegistry.cpp", @@ -509,12 +509,15 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", - "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", "torch/csrc/distributed/c10d/Utils.cpp", + "torch/csrc/distributed/c10d/Work.cpp", "torch/csrc/distributed/c10d/comm.cpp", + "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", + "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", "torch/csrc/distributed/c10d/debug.cpp", "torch/csrc/distributed/c10d/default_comm_hooks.cpp", "torch/csrc/distributed/c10d/logger.cpp", @@ -523,9 +526,8 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/reducer.cpp", "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/socket.cpp", - "torch/csrc/distributed/c10d/Work.cpp", - "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", - "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", + "torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.cpp", + "torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp", ] # These files are only supported on Linux (and others) but not on Windows. @@ -590,11 +592,38 @@ libtorch_core_jit_sources = sorted(jit_sources_full) libtorch_nativert_sources = [ + "torch/nativert/graph/Graph.cpp", + "torch/nativert/graph/GraphPasses.cpp", "torch/nativert/graph/GraphSignature.cpp", + "torch/nativert/graph/Serialization.cpp", "torch/nativert/graph/TensorMeta.cpp", + "torch/nativert/executor/DelegateExecutor.cpp", "torch/nativert/executor/Placement.cpp", + "torch/nativert/executor/ExecutionPlanner.cpp", + "torch/nativert/executor/ExecutionFrame.cpp", + "torch/nativert/executor/Executor.cpp", + "torch/nativert/executor/GraphExecutorBase.cpp", + "torch/nativert/executor/ConstantFolder.cpp", + "torch/nativert/executor/OpKernel.cpp", "torch/nativert/executor/PlacementUtils.cpp", + "torch/nativert/executor/SerialGraphExecutor.cpp", + "torch/nativert/executor/Weights.cpp", + "torch/nativert/executor/memory/FunctionSchema.cpp", "torch/nativert/common/FileUtil.cpp", + "torch/nativert/detail/ITree.cpp", + "torch/nativert/kernels/C10Kernel.cpp", + "torch/nativert/kernels/AutoFunctionalizeKernel.cpp", + "torch/nativert/kernels/HigherOrderKernel.cpp", + "torch/nativert/executor/memory/GreedyBySize.cpp", + "torch/nativert/executor/memory/Bump.cpp", + "torch/nativert/executor/ParallelGraphExecutor.cpp", + "torch/nativert/kernels/CallTorchBindKernel.cpp", + "torch/nativert/kernels/KernelFactory.cpp", + "torch/nativert/kernels/PrimKernelRegistry.cpp", + "torch/nativert/executor/memory/DisjointStorageGroups.cpp", + "torch/nativert/executor/memory/AliasAnalyzer.cpp", + "torch/nativert/executor/memory/LayoutPlanner.cpp", + "torch/nativert/executor/memory/LayoutManager.cpp", ] torch_mobile_tracer_sources = [ @@ -617,6 +646,7 @@ libtorch_lite_eager_symbolication = [ # Later we can split serialization and deserialization logic # to have better separation within build and only build relevant parts. "torch/csrc/jit/serialization/pickle.cpp", + "torch/csrc/jit/serialization/pickler_helper.cpp", "torch/csrc/jit/serialization/pickler.cpp", "torch/csrc/jit/serialization/unpickler.cpp", ] @@ -694,24 +724,25 @@ libtorch_cuda_distributed_base_sources = [ # These files are only supported on Linux (and others) but not on Windows. libtorch_cuda_distributed_extra_sources = [ - "torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp", + "torch/csrc/distributed/c10d/FlightRecorderCuda.cpp", "torch/csrc/distributed/c10d/NCCLUtils.cpp", - "torch/csrc/distributed/c10d/FlightRecorder.cpp", - "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp", + "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp", "torch/csrc/distributed/c10d/UCCTracing.cpp", "torch/csrc/distributed/c10d/UCCUtils.cpp", - "torch/csrc/distributed/c10d/intra_node_comm.cpp", - "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp", "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", "torch/csrc/distributed/c10d/cuda/utils.cpp", - "torch/csrc/distributed/c10d/NanCheck.cu", - "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu", + "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp", + "torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp", + "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu", + "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp", + "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", + "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", ] libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + libtorch_cuda_distributed_extra_sources @@ -865,6 +896,8 @@ libtorch_python_core_sources = [ "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", "torch/csrc/export/pybind.cpp", + "torch/csrc/export/upgrader.cpp", + "torch/csrc/export/example_upgraders.cpp", "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", @@ -1510,7 +1543,7 @@ aten_cuda_cu_with_sort_by_key_source_list = [ "aten/src/ATen/native/cuda/Unique.cu", ] -# Followings are source code for xnnpack delegate +# Following are source code for xnnpack delegate xnnpack_delegate_serializer_header = [ "torch/csrc/jit/backends/xnnpack/serialization/serializer.h", diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index f00c662e70e04e..8e9d267352dd21 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) project(c10 CXX) set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") diff --git a/c10/build.bzl b/c10/build.bzl index 6ecae511223879..42a831deb66970 100644 --- a/c10/build.bzl +++ b/c10/build.bzl @@ -30,7 +30,6 @@ def define_targets(rules): "//c10/macros", "//c10/util:base_headers", "//c10/util:bit_cast", - "//c10/util:ssize", ], visibility = ["//visibility:public"], ) diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 409c837c5908a0..67c9276313bba2 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -38,6 +38,8 @@ enum class Backend { SparseCUDA, SparseCsrCPU, SparseCsrCUDA, + SparseCsrMPS, + SparseMPS, SparseHIP, SparseVE, SparseXPU, @@ -94,6 +96,10 @@ inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::SparseCPU; } else if (t == DispatchKey::SparseCUDA) { return Backend::SparseCUDA; + } else if (t == DispatchKey::SparseMPS) { + return Backend::SparseMPS; + } else if (t == DispatchKey::SparseCsrMPS) { + return Backend::SparseCsrMPS; } else if (t == DispatchKey::SparseHIP) { return Backend::SparseHIP; } else if (t == DispatchKey::SparseVE) { @@ -172,6 +178,10 @@ inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::SparseCPU; case Backend::SparseCUDA: return DispatchKey::SparseCUDA; + case Backend::SparseMPS: + return DispatchKey::SparseMPS; + case Backend::SparseCsrMPS: + return DispatchKey::SparseCsrMPS; case Backend::SparseHIP: return DispatchKey::SparseHIP; case Backend::SparseVE: @@ -227,6 +237,8 @@ inline DeviceType backendToDeviceType(Backend b) { return DeviceType::CPU; case Backend::CUDA: case Backend::SparseCUDA: + case Backend::SparseMPS: + case Backend::SparseCsrMPS: case Backend::QuantizedCUDA: case Backend::SparseCsrCUDA: return DeviceType::CUDA; @@ -309,6 +321,10 @@ inline const char* toString(Backend b) { return "SparseCPU"; case Backend::SparseCUDA: return "SparseCUDA"; + case Backend::SparseMPS: + return "SparseMPS"; + case Backend::SparseCsrMPS: + return "SparseCsrMPS"; case Backend::SparseHIP: return "SparseHIP"; case Backend::SparseVE: @@ -361,6 +377,7 @@ inline bool isSparse(Backend b) { case Backend::SparseXPU: case Backend::SparseCPU: case Backend::SparseCUDA: + case Backend::SparseMPS: case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparsePrivateUse1: diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 7653a8b7253afe..b23490de693a8f 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -18,7 +18,7 @@ struct DeviceStats { // released via device memory deallocation) StatArray inactive_split; - // SUM: bytes allocated by this memory alocator + // SUM: bytes allocated by this memory allocator StatArray allocated_bytes; // SUM: bytes reserved by this memory allocator (both free and used) StatArray reserved_bytes; diff --git a/c10/core/Contiguity.h b/c10/core/Contiguity.h index 276d2ce07b5777..279a795583b12c 100644 --- a/c10/core/Contiguity.h +++ b/c10/core/Contiguity.h @@ -12,7 +12,7 @@ namespace c10 { template bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { + if (numel == 0) { return true; } @@ -20,11 +20,11 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { // NB: make sure we do signed arithmetic for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) { + if (size_d == 1) { continue; } - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) { + if (strides[d] != expected_stride) { return false; } expected_stride *= size_d; @@ -32,29 +32,66 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { return true; } -// This function will return True if the tensor is contiguous, and False if the -// its not or if we can't determine if it is contiguous due to unbacked symbols -// (it could be either in that case based on the actual runtime data). -template -bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { - if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { +// Return a SymBool with underlying symbolic expression that represents +// contiguity. Guaranteed not to add guards. +inline static c10::SymBool _compute_contiguous_sym( + ArrayRef sizes, + ArrayRef strides, + const c10::SymInt& numel) { + // If this return true, the tensor is contiguous indeed. Otherwise it could be + // either. + auto is_contiguous_or_false = [&]() { + if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { + return true; + } + + // When calculating the expected stride, we can choose to multiply + // with max(1, size[d]) or size[d]. Regardless, this is ok for this + // function. Why? + // (1) If size[d] == 0, then the tensor is contiguous and if + // we return true or false it won't break this function. + // (2) If size[d] is not 0, then max(1,size[d]) and size[d] are equal. + // Therefore, if we choose to use max(1, size[d]) or size[d] to + // calculate the expected stride, the result is the same. + // + // We symbolically check both paths to maximize the cases where this + // function returns true. This is because make_contiguous_strides_for adds + // the max symbolically, and in some other situations the max might not be + // there. And we want to ensure we return true in both cases. + c10::SymInt expected_stride = 1; + c10::SymInt expected_stride_max = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { + continue; + } + + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride)) && + TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride_max))) { + return false; + } + expected_stride_max *= sizes[d].max(1); + expected_stride *= sizes[d]; + } return true; + }; + + if (is_contiguous_or_false()) { + return c10::SymBool(true); } - T expected_stride = 1; - // NB: make sure we do signed arithmetic + // Build a single expression that represents contiguity and return it. + c10::SymBool is_empty = sym_eq(numel, 0); + c10::SymBool is_contiguous_cond = true; + + c10::SymInt expected_stride = 1; for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; - if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) { - continue; - } - - if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) { - return false; - } - expected_stride *= size_d; + is_contiguous_cond = is_contiguous_cond.sym_and( + size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); + expected_stride = expected_stride * size_d; } - return true; + return is_contiguous_cond.sym_or(is_empty); } template diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index daf94245c1e7eb..32fcbc17717900 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -119,7 +119,7 @@ std::ostream& operator<<(std::ostream& stream, DeviceType type) { // Whenever a user prints a privateuse1 device name, they need to read this // variable. Although unlikely, we'll data race if someone else is trying to // set this variable at the same time that another thread is print the -// device name. We could re-use the same mutex, but reading the atomic will +// device name. We could reuse the same mutex, but reading the atomic will // be much faster. static std::atomic privateuse1_backend_name_set; static std::string privateuse1_backend_name; diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 4d803e639989f1..7c239ecddede26 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -354,6 +354,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"SparseCPU", c10::DispatchKey::SparseCPU}, {"SparseCUDA", c10::DispatchKey::SparseCUDA}, + {"SparseMPS", c10::DispatchKey::SparseMPS}, + {"SparseCsrMPS", c10::DispatchKey::SparseCsrMPS}, {"SparseHIP", c10::DispatchKey::SparseHIP}, {"SparseXPU", c10::DispatchKey::SparseXPU}, {"SparseVE", c10::DispatchKey::SparseVE}, diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 4cbd0cea8571eb..96ef6b3522ba70 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -53,7 +53,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | // explicit kernels therefore we manually add the key to the // math_dispatch_keyset DispatchKeySet{DispatchKey::NestedTensor} | - // Functionalize should always re-use CompositeImplicit decomps. + // Functionalize should always reuse CompositeImplicit decomps. DispatchKeySet{DispatchKey::Functionalize}; constexpr DispatchKeySet nested_dispatch_keyset = diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 49dafe1e3cb0e4..4de19c9ce5bfcc 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -115,7 +115,7 @@ C10_ALWAYS_INLINE static const std:: // Not every backend and not every functionality counts as a "building block // key". This is mostly to give us more levers to pull in the design space. // Backend keys and functionality keys that count as "building blocks" will -// contribute to a full cross product of functionality that can be overriden. +// contribute to a full cross product of functionality that can be overridden. // // For example, right now we have at least 12 "backend" building // blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality" diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 82a9129501d9d2..0daa129bb5a4ff 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -32,6 +32,8 @@ inline Layout layout_from_backend(Backend backend) { switch (backend) { case Backend::SparseCPU: case Backend::SparseCUDA: + case Backend::SparseMPS: + case Backend::SparseCsrMPS: case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparseXPU: @@ -46,7 +48,7 @@ inline Layout layout_from_backend(Backend backend) { case Backend::SparseCsrXPU: TORCH_CHECK( false, - "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout."); + "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout."); default: return Layout::Strided; } diff --git a/c10/core/RefcountedDeleter.h b/c10/core/RefcountedDeleter.h index ce988864720a19..060a44187d7d90 100644 --- a/c10/core/RefcountedDeleter.h +++ b/c10/core/RefcountedDeleter.h @@ -17,12 +17,12 @@ namespace c10 { // data when the refcount reaches 0. // // This shared DataPtr feature is only used when storages are shared between -// multiple Python interpreters in MultiPy. Before storages had PyObject -// preservation, interpreters could just share the same StorageImpl instance. -// But now a StorageImpl can only be associated with one interpreter in order -// to properly manage a zombie PyObject. So we share storages across Python -// interpreters by creating a different StorageImpl instance for each one, but -// they all point to the same data. +// multiple Python interpreters in MultiPy. // codespell:ignore multipy +// Before storages had PyObject preservation, interpreters could just share the +// same StorageImpl instance. But now a StorageImpl can only be associated with +// one interpreter in order to properly manage a zombie PyObject. So we share +// storages across Python interpreters by creating a different StorageImpl +// instance for each one, but they all point to the same data. struct C10_API RefcountedDeleterContext { RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter) : other_ctx(other_ctx, other_deleter), refcount(1) {} diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 2a40114573cc98..3b483c86bc88ff 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -187,9 +186,9 @@ class C10_API Scalar { return Tag::HAS_d == tag || Tag::HAS_sd == tag; } - C10_DEPRECATED_MESSAGE( - "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") - bool isIntegral() const { + [[deprecated( + "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")]] bool + isIntegral() const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; } bool isIntegral(bool includeBool) const { diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 05b7249f7336ce..3d8a2b0074e9ea 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -376,9 +375,9 @@ inline bool isIntegralType(ScalarType t, bool includeBool) { return isIntegral || (includeBool && t == ScalarType::Bool); } -C10_DEPRECATED_MESSAGE( - "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.") -inline bool isIntegralType(ScalarType t) { +[[deprecated( + "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")]] inline bool +isIntegralType(ScalarType t) { return isIntegralType(t, /*includeBool=*/false); } diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index 70dfc0b74af89e..c6c2743d8358a3 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -136,7 +136,7 @@ SymInt operator-(const SymInt& s) { const auto val = *ma; // Note: Result of `-std::numeric_limits::min()` is undefined // But on many platforms it equals to self + setting Carry/Overflow flags - // Which in opimized code affects results of `check_range` condition + // Which in optimized code affects results of `check_range` condition // Workaround by using ternary that avoids alterning the flags #if C10_HAS_BUILTIN_OVERFLOW() std::decay_t out = 0; diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index 3becf927cd5200..6fa2ab0ed4f1db 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -79,18 +79,51 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); - return _compute_contiguous(sizes, strides, numel()); + + auto result = _compute_contiguous_sym(sizes, strides, numel()); + + // If the result is already determined without guarding, just return it. + auto maybe_as_bool = result.maybe_as_bool(); + if (maybe_as_bool.has_value()) { + return maybe_as_bool.value(); + } + + auto all_hinted = true; + for (const auto& s : sizes) { + if (!s.has_hint()) { + all_hinted = false; + break; + } + } + + if (all_hinted) { + for (const auto& s : strides) { + if (!s.has_hint()) { + all_hinted = false; + break; + } + } + } + + if (all_hinted) { + // We avoid going through the slow path if everything is hinted, + // because evaluating a large SymPy expression can be expensive. + // TODO exclude backed_size_oblivious from this path. + return _compute_contiguous(sizes_, strides_, numel()); + } + + return result; } // The rest of them -#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ - SymBool SymbolicShapeMeta::name() const { \ - if (!strides_valid_) { \ - return false; \ - } \ - c10::SymIntArrayRef sizes(sizes_); \ - c10::SymIntArrayRef strides(strides_); \ - return fallback(sizes, strides); \ +#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \ + SymBool SymbolicShapeMeta::name() const { \ + if (!strides_valid_) { \ + return false; \ + } \ + c10::SymIntArrayRef sizes(sizes_); \ + c10::SymIntArrayRef strides(strides_); \ + return fallback(sizes, strides); \ } #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ @@ -110,11 +143,13 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { } // clang-format off -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d) + DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) + // clang-format on #undef DEFINE_SYMBOOL_COMPUTE @@ -192,6 +227,7 @@ void SymbolicShapeMeta::set_numel(SymInt val) const { numel_ = std::move(val); available_.fetch_or(numel_avail); } + void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_contiguous()) { @@ -200,6 +236,7 @@ void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { is_contiguous_ = std::move(val); available_.fetch_or(is_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_contiguous()) { @@ -208,6 +245,7 @@ void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { is_channels_last_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d_contiguous()) { @@ -216,6 +254,7 @@ void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { is_channels_last_3d_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_3d_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last()) { @@ -224,6 +263,7 @@ void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { is_channels_last_ = std::move(val); available_.fetch_or(is_channels_last_avail); } + void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d()) { diff --git a/c10/core/SymbolicShapeMeta.h b/c10/core/SymbolicShapeMeta.h index ce0769a8074f7d..0820038968a8e7 100644 --- a/c10/core/SymbolicShapeMeta.h +++ b/c10/core/SymbolicShapeMeta.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -82,6 +83,15 @@ class C10_API SymbolicShapeMeta { return numel_; } + const SymBool& is_contiguous(at::MemoryFormat memory_format) const { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return this->is_channels_last_contiguous(); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return this->is_channels_last_3d_contiguous(); + } + return this->is_contiguous(); + } + const SymBool& is_contiguous() const { if (C10_UNLIKELY(!has_is_contiguous())) { init_is_contiguous(); @@ -194,6 +204,7 @@ class C10_API SymbolicShapeMeta { // Lazily initialized variables, with the corresponding available_ flag // indicating whether the value has been initialized mutable std::atomic available_{0}; + enum avail { numel_avail = 1 << 0, is_contiguous_avail = 1 << 1, diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index c3f110b35c081c..f3ec2f2d46ea21 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -160,7 +160,7 @@ TensorImpl::TensorImpl( if (inference_mode) { // See Note [Expected TLS state in InferenceMode] for why we exclude // Autograd & ADInplaceOrView keys. Normally key_set only contains backend - // keys but we do the substraction here to make sure. + // keys but we do the subtraction here to make sure. key_set_ = key_set - c10::autograd_dispatch_keyset_with_ADInplaceOrView; } else { // TODO: Ideally we only add AutogradBackend key when the tensor requires @@ -218,7 +218,7 @@ void TensorImpl::HandleResize() { } } -bool TensorImpl::compute_contiguous(identity) const { +bool TensorImpl::compute_contiguous() const { if (is_sparse()) { return false; } @@ -228,7 +228,7 @@ bool TensorImpl::compute_contiguous(identity) const { numel_); } -bool TensorImpl::compute_channels_last_contiguous_2d(identity) const { +bool TensorImpl::compute_channels_last_contiguous_2d() const { if (is_sparse()) { return false; } @@ -237,7 +237,7 @@ bool TensorImpl::compute_channels_last_contiguous_2d(identity) const { sizes_and_strides_.strides_arrayref()); } -bool TensorImpl::compute_channels_last_contiguous_3d(identity) const { +bool TensorImpl::compute_channels_last_contiguous_3d() const { if (is_sparse()) { return false; } @@ -246,7 +246,7 @@ bool TensorImpl::compute_channels_last_contiguous_3d(identity) const { sizes_and_strides_.strides_arrayref()); } -bool TensorImpl::compute_strides_like_channels_last_2d(identity) const { +bool TensorImpl::compute_strides_like_channels_last_2d() const { if (is_sparse()) { return false; } @@ -255,7 +255,7 @@ bool TensorImpl::compute_strides_like_channels_last_2d(identity) const { sizes_and_strides_.strides_arrayref()); } -bool TensorImpl::compute_strides_like_channels_last_3d(identity) const { +bool TensorImpl::compute_strides_like_channels_last_3d() const { if (is_sparse()) { return false; } @@ -264,7 +264,7 @@ bool TensorImpl::compute_strides_like_channels_last_3d(identity) const { sizes_and_strides_.strides_arrayref()); } -bool TensorImpl::compute_non_overlapping_and_dense(identity) const { +bool TensorImpl::compute_non_overlapping_and_dense() const { if (is_sparse()) { return false; } @@ -310,12 +310,14 @@ void TensorImpl::throw_data_ptr_access_error() const { false, "Cannot access data pointer of Tensor that doesn't have storage"); } -bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool TensorImpl::sym_is_contiguous_custom( + at::MemoryFormat memory_format) const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( this, memory_format); } - return is_contiguous_default(memory_format); + + return sym_is_contiguous_default(memory_format); } bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { @@ -326,12 +328,12 @@ bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { return is_strides_like_default(memory_format); } -bool TensorImpl::is_non_overlapping_and_dense_custom() const { +c10::SymBool TensorImpl::sym_is_non_overlapping_and_dense_custom() const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( this); } - return is_non_overlapping_and_dense_default(); + return sym_is_non_overlapping_and_dense_default(); } IntArrayRef TensorImpl::sizes_custom() const { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index eb47e08d84f960..381bc65b27fbdb 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -812,6 +812,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + c10::SymBool sym_is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_contiguous_custom(memory_format); + } + return sym_is_contiguous_default(memory_format); + } + + template + T is_contiguous_default_impl(at::MemoryFormat memory_format) const { + if (!has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_contiguous_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_contiguous_; + } + return is_contiguous_; + } + + // Handle dynamic shapes. + const auto& symbolic = symbolic_shape_meta().is_contiguous(memory_format); + + if constexpr (std::is_same_v) { + return symbolic.guard_bool(__FILE__, __LINE__); + } else { + return symbolic; + } + } + + bool is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + + c10::SymBool sym_is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + /** * Whether or not a tensor is laid out in contiguous memory. * @@ -827,30 +864,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_contiguous_default(memory_format); } - // These are factored into separate functions in case subclasses - // want to use them - bool is_contiguous_default(at::MemoryFormat memory_format) const { - if (has_symbolic_sizes_strides_) { - if (memory_format == at::MemoryFormat::ChannelsLast) { - return symbolic_shape_meta().is_channels_last_contiguous().guard_bool( - __FILE__, __LINE__); - } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { - return symbolic_shape_meta() - .is_channels_last_3d_contiguous() - .guard_bool(__FILE__, __LINE__); - } - return symbolic_shape_meta().is_contiguous().guard_bool( - __FILE__, __LINE__); - } - - if (memory_format == at::MemoryFormat::ChannelsLast) { - return is_channels_last_contiguous_; - } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { - return is_channels_last_3d_contiguous_; - } - return is_contiguous_; - } - bool is_strides_like_default(at::MemoryFormat memory_format) const { if (has_symbolic_sizes_strides_) { if (memory_format == at::MemoryFormat::ChannelsLast) { @@ -873,9 +886,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + SymBool sym_is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().is_non_overlapping_and_dense(); + } else { + return is_non_overlapping_and_dense_; + } + } + bool is_non_overlapping_and_dense_default() const { if (has_symbolic_sizes_strides_) { - return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool( + return sym_is_non_overlapping_and_dense_default().guard_bool( __FILE__, __LINE__); } else { return is_non_overlapping_and_dense_; @@ -964,13 +985,28 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Customization points for the functions above. sizes_strides_policy_ * must be set to enable these. * - * NB: dim is overrideable separately from sizes because it is possible + * NB: dim is overridable separately from sizes because it is possible * for a tensor to have rank, but not well defined sizes. */ // sizes_strides_policy_ >= CustomStrides - virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; + virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; - virtual bool is_non_overlapping_and_dense_custom() const; + + virtual c10::SymBool sym_is_non_overlapping_and_dense_custom() const; + + bool is_non_overlapping_and_dense_custom() const { + return sym_is_non_overlapping_and_dense_custom().guard_bool( + __FILE__, __LINE__); + } + + virtual c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const; + + bool is_contiguous_custom(at::MemoryFormat memory_format) const { + return sym_is_contiguous_custom(memory_format) + .guard_bool(__FILE__, __LINE__); + } + // sizes_strides_policy_ >= CustomSizes // Currently this method only exists to be overwritten by subclasses such as // NestedTensorImpl. @@ -1004,9 +1040,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual c10::SymInt sym_storage_offset_custom() const; public: - /** - * True if this tensor has storage. See storage() for details. - */ +/** + * True if this tensor has storage. See storage() for details. + */ #ifdef DEBUG // Allow subclasses to check that their storage_ is never getting set in debug // builds. @@ -1016,11 +1052,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { #endif bool has_storage() const - // NOTE: we devirtualize this because it arguably shouldn't be an - // error just to ask subclasses if they have storage. - // This used to throw for most subclasses, but OpaqueTensorImpl - // wanted it to successfully return false, so we went ahead and made - // it a non-error. +// NOTE: we devirtualize this because it arguably shouldn't be an +// error just to ask subclasses if they have storage. +// This used to throw for most subclasses, but OpaqueTensorImpl +// wanted it to successfully return false, so we went ahead and made +// it a non-error. #ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY { return storage_; @@ -2447,6 +2483,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_strides_like(at::MemoryFormat::ChannelsLast3d); } + bool is_non_overlapping_and_dense_or_false() const { + return sym_is_non_overlapping_and_dense().guard_or_false( + __FILE__, __LINE__); + } + bool is_non_overlapping_and_dense() const { if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { return is_non_overlapping_and_dense_custom(); @@ -2454,6 +2495,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_non_overlapping_and_dense_default(); } + SymBool sym_is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_non_overlapping_and_dense_custom(); + } + return sym_is_non_overlapping_and_dense_default(); + } + // if this returns true, then it is guaranteed that this tensor has symbolic // sizes/strides bool has_symbolic_sizes_strides() const { @@ -2567,17 +2615,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Compute whether or not a tensor is contiguous based on the sizes and * strides of a tensor. */ - bool compute_contiguous(identity) const; + bool compute_contiguous() const; - bool compute_channels_last_contiguous_2d(identity) const; + bool compute_channels_last_contiguous_2d() const; - bool compute_channels_last_contiguous_3d(identity) const; + bool compute_channels_last_contiguous_3d() const; - bool compute_strides_like_channels_last_2d(identity) const; + bool compute_strides_like_channels_last_2d() const; - bool compute_strides_like_channels_last_3d(identity) const; + bool compute_strides_like_channels_last_3d() const; - bool compute_non_overlapping_and_dense(identity) const; + bool compute_non_overlapping_and_dense() const; protected: /** @@ -2620,68 +2668,62 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } private: - // NB: the TypeId argument prevents confusion where you pass a true/false - // literal and pick the wrong overload - - void _set_is_contiguous(identity, bool b) { + void _set_is_contiguous(bool b) { is_contiguous_ = b; } - void _set_is_channels_last_contiguous(identity, bool b) { + void _set_is_channels_last_contiguous(bool b) { is_channels_last_contiguous_ = b; } - void _set_is_channels_last_3d_contiguous(identity, bool b) { + void _set_is_channels_last_3d_contiguous(bool b) { is_channels_last_3d_contiguous_ = b; } - void _set_is_channels_last(identity, bool b) { + void _set_is_channels_last(bool b) { is_channels_last_ = b; } - void _set_is_channels_last_3d(identity, bool b) { + void _set_is_channels_last_3d(bool b) { is_channels_last_3d_ = b; } - void _set_is_non_overlapping_and_dense(identity, bool b) { + void _set_is_non_overlapping_and_dense(bool b) { is_non_overlapping_and_dense_ = b; } // These are little wrappers over the real compute_ functions that // can make use of other contiguity fields to short circuit. - bool compute_is_non_overlapping_and_dense_dim4(identity type_id) { + bool compute_is_non_overlapping_and_dense_dim4() { return is_contiguous_ || is_channels_last_contiguous_ || - compute_non_overlapping_and_dense(type_id); + compute_non_overlapping_and_dense(); } - bool compute_channels_last_contiguous_3d_dim5(identity type_id) { + bool compute_channels_last_contiguous_3d_dim5() { return !is_channels_last_contiguous_ && - compute_channels_last_contiguous_3d(type_id); + compute_channels_last_contiguous_3d(); } - bool compute_channels_last_2d_dim5(identity type_id) { + bool compute_channels_last_2d_dim5() { return !is_channels_last_3d_contiguous_ && - compute_strides_like_channels_last_2d(type_id); + compute_strides_like_channels_last_2d(); } - bool compute_channels_last_3d_dim5(identity type_id) { - return !is_channels_last_ && compute_strides_like_channels_last_3d(type_id); + bool compute_channels_last_3d_dim5() { + return !is_channels_last_ && compute_strides_like_channels_last_3d(); } - bool compute_is_non_overlapping_and_dense_dim5(identity type_id) { + bool compute_is_non_overlapping_and_dense_dim5() { return is_contiguous_ || is_channels_last_contiguous_ || - is_channels_last_3d_contiguous_ || - compute_non_overlapping_and_dense(type_id); + is_channels_last_3d_contiguous_ || compute_non_overlapping_and_dense(); } - bool compute_is_non_overlapping_and_dense_anydim(identity type_id) { - return is_contiguous_ || compute_non_overlapping_and_dense(type_id); + bool compute_is_non_overlapping_and_dense_anydim() { + return is_contiguous_ || compute_non_overlapping_and_dense(); } - template void _refresh_contiguous() { - auto type_id = identity(); // Note: // Dim 0, 1, 2 will never be a channels last 2d/3d format // Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this @@ -2689,28 +2731,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // this point) switch (dim()) { case 4: { - _set_is_contiguous(type_id, compute_contiguous(type_id)); - _set_is_channels_last_contiguous( - type_id, compute_channels_last_contiguous_2d(type_id)); - _set_is_channels_last_3d_contiguous(type_id, false); - _set_is_channels_last( - type_id, compute_strides_like_channels_last_2d(type_id)); - _set_is_channels_last_3d(type_id, false); + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(compute_channels_last_contiguous_2d()); + _set_is_channels_last_3d_contiguous(false); + _set_is_channels_last(compute_strides_like_channels_last_2d()); + _set_is_channels_last_3d(false); _set_is_non_overlapping_and_dense( - type_id, compute_is_non_overlapping_and_dense_dim4(type_id)); + compute_is_non_overlapping_and_dense_dim4()); break; } case 5: { - _set_is_contiguous(type_id, compute_contiguous(type_id)); - _set_is_channels_last_contiguous( - type_id, compute_channels_last_contiguous_2d(type_id)); + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(compute_channels_last_contiguous_2d()); _set_is_channels_last_3d_contiguous( - type_id, compute_channels_last_contiguous_3d_dim5(type_id)); - _set_is_channels_last(type_id, compute_channels_last_2d_dim5(type_id)); - _set_is_channels_last_3d( - type_id, compute_channels_last_3d_dim5(type_id)); + compute_channels_last_contiguous_3d_dim5()); + _set_is_channels_last(compute_channels_last_2d_dim5()); + _set_is_channels_last_3d(compute_channels_last_3d_dim5()); _set_is_non_overlapping_and_dense( - type_id, compute_is_non_overlapping_and_dense_dim5(type_id)); + compute_is_non_overlapping_and_dense_dim5()); break; } default: @@ -2719,13 +2757,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // mean the tensor is strided like channels_last: for strides on channel // dimension could suggest desired memory_layout, but it doesn't affect // memory storage - _set_is_contiguous(type_id, compute_contiguous(type_id)); - _set_is_channels_last_contiguous(type_id, false); - _set_is_channels_last_3d_contiguous(type_id, false); - _set_is_channels_last(type_id, false); - _set_is_channels_last_3d(type_id, false); + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(false); + _set_is_channels_last_3d_contiguous(false); + _set_is_channels_last(false); + _set_is_channels_last_3d(false); _set_is_non_overlapping_and_dense( - type_id, compute_is_non_overlapping_and_dense_anydim(type_id)); + compute_is_non_overlapping_and_dense_anydim()); break; } } @@ -2739,7 +2777,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { if (has_symbolic_sizes_strides_) { symbolic_shape_meta().refresh_contiguous(); } else { - _refresh_contiguous(); + _refresh_contiguous(); } } diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index d781ddf9e971a0..b42d3a92545f04 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -12,7 +12,8 @@ UndefinedTensorImpl::UndefinedTensorImpl() set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); } -bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const { +c10::SymBool UndefinedTensorImpl::sym_is_contiguous_custom( + MemoryFormat format) const { return is_contiguous_default(format); } IntArrayRef UndefinedTensorImpl::strides_custom() const { diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 33ac4e7f868a45..6b7573a69388aa 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -32,7 +32,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { void set_storage_offset(int64_t offset) override; protected: - bool is_contiguous_custom(MemoryFormat format) const override; + c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override; IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; diff --git a/c10/core/impl/HermeticPyObjectTLS.h b/c10/core/impl/HermeticPyObjectTLS.h index 741132b9f967c1..a973a5d2cef8f0 100644 --- a/c10/core/impl/HermeticPyObjectTLS.h +++ b/c10/core/impl/HermeticPyObjectTLS.h @@ -13,7 +13,8 @@ namespace c10::impl { struct C10_API HermeticPyObjectTLS { static void set_state(bool state); static bool get_state() { - // Hypothetical fastpath if torchdeploy/multipy isn't used. Per + // Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy + // isn't used. Per // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf // this qualifies relaxed access because it is a single-location data // structure (only the boolean here). @@ -46,12 +47,14 @@ struct C10_API HermeticPyObjectTLS { return false; return get_tls_state(); } - // Call this from the multipy/torchdeploy top level + // Call this from the multipy/torchdeploy // codespell:ignore multipy + // top level static void init_state(); private: - // This only flipped once from false to true during torchdeploy/multipy - // initialization, and never again. + // This only flipped once from false to true during + // torchdeploy/multipy initialization, // codespell:ignore multipy + // and never again. static std::atomic haveState_; static bool get_tls_state(); }; diff --git a/c10/core/impl/InlineEvent.h b/c10/core/impl/InlineEvent.h index a731621a5bfdee..f09dbe63632719 100644 --- a/c10/core/impl/InlineEvent.h +++ b/c10/core/impl/InlineEvent.h @@ -121,7 +121,7 @@ struct InlineEvent final { was_marked_for_recording() && other.was_marked_for_recording(), "Both events must be recorded before calculating elapsed time."); // elapsedTime in MPS can wait event to be completed if event is not ready, - // which is a little differenct from CUDA + // which is a little different from CUDA TORCH_CHECK( (query() && other.query()) || device_type_ == DeviceType::MPS, "Both events must be completed before calculating elapsed time."); diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 568de4491cfbb1..43492443c530c6 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -143,9 +143,9 @@ struct C10_API PyInterpreterVTable { virtual void reportErrorCallback(PyObject* callback, DispatchKey key) const = 0; - // This is only invoked in the multipy/torchdeploy situation from - // pythonOpRegistrationTrampoline; this lets us get to the Python - // interpreter to actually find the appropriate Python op registration + // This is only invoked in the multipy/torchdeploy // codespell:ignore multipy + // situation from pythonOpRegistrationTrampoline; this lets us get to the + // Python interpreter to actually find the appropriate Python op registration // entry to call. virtual void python_op_registration_trampoline( const c10::OperatorHandle& op, diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index ecf0c976686c0b..d2efb8c593e44e 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -16,6 +16,13 @@ CUDAAllocatorConfig::CUDAAllocatorConfig() m_garbage_collection_threshold(0), m_pinned_num_register_threads(1), m_expandable_segments(false), +#if CUDA_VERSION >= 12030 + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::UNSPECIFIED), +#else + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::POSIX_FD), +#endif m_release_lock_on_cudamalloc(false), m_pinned_use_cuda_host_register(false), m_pinned_use_background_threads(false) { @@ -156,7 +163,7 @@ size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( } TORCH_CHECK( val2 == 0 || llvm::isPowerOf2_64(val2), - "For roundups, the divisons has to be power of 2 or 0 to disable roundup ", + "For roundups, the divisions has to be power of 2 or 0 to disable roundup ", ""); if (std::string_view(val1) == ">") { @@ -202,7 +209,7 @@ size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( size_t val1 = stoi(config[i]); TORCH_CHECK( llvm::isPowerOf2_64(val1), - "For roundups, the divisons has to be power of 2 ", + "For roundups, the divisions has to be power of 2 ", ""); std::fill( m_roundup_power2_divisions.begin(), diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 5fac52cd42bea1..fda3cc02e5d0ae 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -13,6 +13,12 @@ namespace c10::cuda::CUDACachingAllocator { +enum class Expandable_Segments_Handle_Type : int { + UNSPECIFIED = 0, + POSIX_FD = 1, + FABRIC_HANDLE = 2, +}; + // Environment config parser class C10_CUDA_API CUDAAllocatorConfig { public: @@ -34,6 +40,15 @@ class C10_CUDA_API CUDAAllocatorConfig { #endif } + static Expandable_Segments_Handle_Type expandable_segments_handle_type() { + return instance().m_expandable_segments_handle_type; + } + + static void set_expandable_segments_handle_type( + Expandable_Segments_Handle_Type handle_type) { + instance().m_expandable_segments_handle_type = handle_type; + } + static bool release_lock_on_cudamalloc() { return instance().m_release_lock_on_cudamalloc; } @@ -60,7 +75,7 @@ class C10_CUDA_API CUDAAllocatorConfig { // This is used to round-up allocation size to nearest power of 2 divisions. // More description below in function roundup_power2_next_division - // As ane example, if we want 4 divisions between 2's power, this can be done + // As an example, if we want 4 divisions between 2's power, this can be done // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 static size_t roundup_power2_divisions(size_t size); @@ -134,6 +149,8 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_garbage_collection_threshold; std::atomic m_pinned_num_register_threads; std::atomic m_expandable_segments; + std::atomic + m_expandable_segments_handle_type; std::atomic m_release_lock_on_cudamalloc; std::atomic m_pinned_use_cuda_host_register; std::atomic m_pinned_use_background_threads; diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index b945efea785fcb..e152feba9ccc45 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,18 @@ TORCH_SDT_DEFINE_SEMAPHORE(malloc) TORCH_SDT_DEFINE_SEMAPHORE(free) +// add these definitions so that we can compile with CUDA < 12.3 +// borrowed from +// https://github.com/NVIDIA/nccl/blob/3ea7eedf3b9b94f1d9f99f4e55536dfcbd23c1ca/src/include/p2p.h#L20 +#if CUDA_VERSION < 12030 +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + namespace c10 { // NOLINTNEXTLINE(misc-use-internal-linkage) @@ -131,7 +144,7 @@ constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB -static char SHAREABLE_HANDLE_VERSION = 1; +static char SHAREABLE_HANDLE_VERSION = 2; enum ShareableHandleType : char { SHAREABLE_CUDA_MALLOC = 'c', SHAREABLE_CUDA_EXPANDABLE_SEGMENT = 'e' @@ -387,6 +400,7 @@ struct ExpandableSegment { // returns the actual range mapped, which may be // greater than requested if size is not aligned to segment_size_. // return size of 0 indicates OOM + // return nullptr indicates the handle type is not supported. SegmentRange map(SegmentRange range) { auto begin = segmentLeft(range.ptr); auto end = segmentRight(range.ptr + range.size); @@ -394,6 +408,23 @@ struct ExpandableSegment { if (begin == end) { return rangeFromHandles(begin, end); } + + // if the handle type is not specified, try to use fabric handle first. + // if it fails, use posix file handle + if (CUDAAllocatorConfig::expandable_segments_handle_type() == + Expandable_Segments_Handle_Type::UNSPECIFIED) { + CUDAAllocatorConfig::set_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::FABRIC_HANDLE); + auto output = map(range); + if (output.ptr != nullptr) { + return output; + } + // if fabric handle is not supported, use posix file handle. + CUDAAllocatorConfig::set_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::POSIX_FD); + return map(range); + } + while (end > handles_.size()) { handles_.emplace_back(std::nullopt); } @@ -403,7 +434,12 @@ struct ExpandableSegment { CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; #ifndef FBCODE_CAFFE2 - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + if (CUDAAllocatorConfig::expandable_segments_handle_type() != + Expandable_Segments_Handle_Type::FABRIC_HANDLE) { + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + } else { + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + } #endif int flag = 0; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuDeviceGetAttribute_( @@ -417,17 +453,32 @@ struct ExpandableSegment { prop.location.id = static_cast(device_); auto status = DriverAPI::get()->cuMemCreate_(&handle, segment_size_, &prop, 0); - if (status == CUDA_ERROR_OUT_OF_MEMORY) { - for (auto j : c10::irange(begin, i)) { - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - auto h = handles_.at(j).value(); - handles_.at(j) = std::nullopt; - C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle)); + if (status != CUDA_SUCCESS) { + if (status == CUDA_ERROR_OUT_OF_MEMORY) { + for (auto j : c10::irange(begin, i)) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto h = handles_.at(j).value(); + handles_.at(j) = std::nullopt; + C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle)); + } + trimHandles(); + return rangeFromHandles(begin, begin); + } else if ( + CUDAAllocatorConfig::expandable_segments_handle_type() == + Expandable_Segments_Handle_Type::FABRIC_HANDLE) { + // we are testing if we can use fabric handle. + // if we can, we will use it. + // if we can't, we will use posix file handle. + // so we should not return an error here. + // in practice, we can get CUDA_ERROR_NOT_SUPPORTED or + // CUDA_ERROR_NOT_PERMITTED to be safe, any non out-of-memory error is + // considered as the handle type is not supported. if the handle type + // is not supported, return a null range to indicate it. + return SegmentRange(nullptr, 0); + } else { + C10_CUDA_DRIVER_CHECK(status); } - trimHandles(); - return rangeFromHandles(begin, begin); } - C10_CUDA_DRIVER_CHECK(status); handles_.at(i) = Handle{handle, std::nullopt}; } mapAndSetAccess(begin, end); @@ -460,14 +511,33 @@ struct ExpandableSegment { for (auto i : c10::irange(begin, end)) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) auto& handle = handles_.at(i).value(); - if (!handle.fd) { - int fd = 0; - C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_( - &fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); - handle.fd = fd; + if (CUDAAllocatorConfig::expandable_segments_handle_type() != + Expandable_Segments_Handle_Type::FABRIC_HANDLE) { + if (!handle.shareable_handle) { + int fd = 0; + C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_( + &fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + handle.shareable_handle = fd; + LOG(INFO) << "use posix fd to share expandable segments."; + } + TORCH_CHECK( + handle.shareable_handle != std::nullopt, + "shareable_handle is null"); + buf.write((const char*)&*handle.shareable_handle, sizeof(int)); + } else { + if (!handle.shareable_handle) { + CUmemFabricHandle fabric_handle; + C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_( + &fabric_handle, handle.handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + handle.shareable_handle = fabric_handle; + LOG(INFO) << "use fabric handle to share expandable segments."; + } + TORCH_CHECK( + handle.shareable_handle != std::nullopt, + "shareable_handle is null"); + buf.write( + (const char*)&*handle.shareable_handle, sizeof(CUmemFabricHandle)); } - int fd = *handle.fd; - buf.write((const char*)&fd, sizeof(int)); } return rangeFromHandles(begin, end); } @@ -492,42 +562,60 @@ struct ExpandableSegment { #ifndef SYS_pidfd_getfd #define SYS_pidfd_getfd 438 #endif - auto pidfd = syscall(SYS_pidfd_open, header.pid, 0); - TORCH_CHECK( - pidfd != -1 || errno != ENOSYS, - "The kernel on this machine does not support the pidfd_open syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. " - "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation."); - TORCH_CHECK(pidfd != -1, "pidfd_open:", c10::utils::str_error(errno)); - for (auto i : c10::irange(header.num_handles)) { - (void)i; - int fd = 0; - buf.read((char*)&fd, sizeof(int)); - auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0); - if (myfd == -1) { - auto err = errno; - close((int)pidfd); - for (auto& h : segment->handles_) { - C10_CUDA_DRIVER_CHECK( - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - DriverAPI::get()->cuMemRelease_(h.value().handle)); - h = std::nullopt; + if (CUDAAllocatorConfig::expandable_segments_handle_type() != + Expandable_Segments_Handle_Type::FABRIC_HANDLE) { + auto pidfd = syscall(SYS_pidfd_open, header.pid, 0); + TORCH_CHECK( + pidfd != -1 || errno != ENOSYS, + "The kernel on this machine does not support the pidfd_open syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. " + "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation."); + TORCH_CHECK(pidfd != -1, "pidfd_open:", c10::utils::str_error(errno)); + for (auto i : c10::irange(header.num_handles)) { + (void)i; + int fd = 0; + buf.read((char*)&fd, sizeof(int)); + auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0); + if (myfd == -1) { + auto err = errno; + close((int)pidfd); + for (auto& h : segment->handles_) { + C10_CUDA_DRIVER_CHECK( + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + DriverAPI::get()->cuMemRelease_(h.value().handle)); + h = std::nullopt; + } + TORCH_CHECK( + err != ENOSYS, + "The kernel on this machine does not support the pidfd_getfd syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. " + "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation."); + TORCH_CHECK(false, "pidfd_getfd: ", c10::utils::str_error(err)); } - TORCH_CHECK( - err != ENOSYS, - "The kernel on this machine does not support the pidfd_getfd syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. " - "Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation."); - TORCH_CHECK(false, "pidfd_getfd: ", c10::utils::str_error(err)); + CUmemGenericAllocationHandle handle = 0; + C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( + &handle, + // NOLINTNEXTLINE(performance-no-int-to-ptr) + (void*)(uintptr_t)myfd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + LOG(INFO) << "use posix fd to import expandable segments."; + close((int)myfd); + segment->handles_.emplace_back(Handle{handle, std::nullopt}); } - CUmemGenericAllocationHandle handle = 0; - C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( - &handle, - // NOLINTNEXTLINE(performance-no-int-to-ptr) - (void*)(uintptr_t)myfd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - close((int)myfd); - segment->handles_.emplace_back(Handle{handle, std::nullopt}); - } - close((int)pidfd); + close((int)pidfd); + } else { + for (auto i : c10::irange(header.num_handles)) { + (void)i; + CUmemFabricHandle fabric_handle; + buf.read((char*)&fabric_handle, sizeof(CUmemFabricHandle)); + CUmemGenericAllocationHandle handle = 0; + C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( + &handle, + // NOLINTNEXTLINE(performance-no-int-to-ptr) + (void*)&fabric_handle, + CU_MEM_HANDLE_TYPE_FABRIC)); + LOG(INFO) << "use fabric handle to import expandable segments."; + segment->handles_.emplace_back(Handle{handle, std::nullopt}); + } + } segment->mapAndSetAccess(0, header.num_handles); return segment; } @@ -601,8 +689,8 @@ struct ExpandableSegment { handles_.at(i) = std::nullopt; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_( ptr_ + segment_size_ * i, segment_size_)); - if (h.fd) { - close(*h.fd); + if (h.shareable_handle) { + close(std::get(*h.shareable_handle)); } C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle)); } @@ -646,7 +734,7 @@ struct ExpandableSegment { size_t max_handles_; struct Handle { CUmemGenericAllocationHandle handle; - std::optional fd; + std::optional> shareable_handle; }; struct ShareHeader { pid_t pid; @@ -1546,7 +1634,7 @@ class DeviceCachingAllocator { block->allocated = false; - // following logic might modifying underlaying Block, causing the size + // following logic might modifying underlying Block, causing the size // changed. We store ahead for reporting auto orig_block_ptr = block->ptr; auto orig_block_size = block->size; @@ -2096,7 +2184,7 @@ class DeviceCachingAllocator { // For example, if we need to round-up 1200 and number of divisions is 4, // the size 1200 lies between 1024 and 2048 and if we do 4 divisions between // them, the values are 1024, 1280, 1536, and 1792. So the function will - // return 1280 as the nearest ceiling of power-2 divison. + // return 1280 as the nearest ceiling of power-2 division. static size_t roundup_power2_next_division(size_t size, size_t divisions) { if (llvm::isPowerOf2_64(size)) { return size; @@ -2754,7 +2842,7 @@ class DeviceCachingAllocator { } } - // This function assumes that global lock has been taken whle calling into + // This function assumes that global lock has been taken while calling into // this function. We do cudaMalloc sync call in this function which // can be expensive while holding the lock. Hence, we pass-in the lock to the // function to temporarily release the lock before cudaMalloc call and acquire @@ -3485,16 +3573,16 @@ class NativeCachingAllocator : public CUDAAllocator { } void setMemoryFraction(double fraction, c10::DeviceIndex device) override { - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( 0 <= device && static_cast(device) < device_allocator.size(), "Allocator not initialized for device ", device, ": did you call init?"); - TORCH_INTERNAL_ASSERT( + TORCH_CHECK( 0 <= fraction && fraction <= 1, "invalid fraction:", fraction, - ". Please set within (0, 1)."); + ". Please set within [0, 1]."); C10_CUDA_CHECK(c10::cuda::SetDevice(device)); device_allocator[device]->setMemoryFraction(fraction); } @@ -4117,8 +4205,11 @@ std::atomic MemPool::uuid_{1}; MemPool::MemPool( CUDACachingAllocator::CUDAAllocator* allocator, bool is_user_created, - bool use_on_oom) - : allocator_(allocator), is_user_created_(is_user_created) { + bool use_on_oom, + bool symmetric) + : allocator_(allocator), + is_user_created_(is_user_created), + symmetric_(symmetric) { if (is_user_created_) { id_ = {0, uid_++}; } else { @@ -4141,6 +4232,10 @@ MempoolId_t MemPool::id() { return id_; } +bool MemPool::is_symmetric() { + return symmetric_; +} + CUDACachingAllocator::CUDAAllocator* MemPool::allocator() { return allocator_; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 6d86a2178d58d8..a6fa61110d6750 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -366,7 +366,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { } inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { - return get()->emptyCache(); + return get()->emptyCache(mempool_id); } inline void enable(bool value) { @@ -535,7 +535,8 @@ struct C10_CUDA_API MemPool { MemPool( CUDACachingAllocator::CUDAAllocator* allocator = nullptr, bool is_user_created = true, - bool use_on_oom = false); + bool use_on_oom = false, + bool symmetric = false); MemPool(const MemPool&) = delete; MemPool(MemPool&&) = default; MemPool& operator=(const MemPool&) = delete; @@ -543,6 +544,7 @@ struct C10_CUDA_API MemPool { ~MemPool(); MempoolId_t id(); + bool is_symmetric(); CUDACachingAllocator::CUDAAllocator* allocator(); int use_count(); c10::DeviceIndex device(); @@ -554,6 +556,7 @@ struct C10_CUDA_API MemPool { CUDACachingAllocator::CUDAAllocator* allocator_; bool is_user_created_; MempoolId_t id_; + bool symmetric_; c10::DeviceIndex device_; }; diff --git a/c10/cuda/CUDADeviceAssertionHost.h b/c10/cuda/CUDADeviceAssertionHost.h index 370bb7f291f2ef..8d1711a2b48c58 100644 --- a/c10/cuda/CUDADeviceAssertionHost.h +++ b/c10/cuda/CUDADeviceAssertionHost.h @@ -47,7 +47,7 @@ struct DeviceAssertionData { /// Used to hold assertions generated by the device /// Held in managed memory and access by both the CPU and the GPU. struct DeviceAssertionsData { - /// Total number of assertions found; a subset of thse will be recorded + /// Total number of assertions found; a subset of these will be recorded /// in `assertions` int32_t assertion_count{}; /// An array of assertions that will be written to in a race-free manner diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 7d54f098d4ff17..05f00e43a2a7c9 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -38,8 +38,8 @@ void c10_cuda_check_implementation( "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers."); } #endif - - TORCH_CHECK(false, check_message); + throw c10::AcceleratorError( + {__func__, __FILE__, int32_t(__LINE__)}, err, check_message); } } // namespace c10::cuda diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 43f825dc9b93e5..0e8cabf6185934 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -130,8 +130,8 @@ DeviceIndex current_device() { return cur_device; } -void set_device(DeviceIndex device) { - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); +void set_device(DeviceIndex device, const bool force) { + C10_CUDA_CHECK(c10::cuda::SetDevice(device, force)); } void device_synchronize() { @@ -231,9 +231,12 @@ cudaError_t GetDevice(DeviceIndex* device) { return err; } -cudaError_t SetDevice(DeviceIndex device) { - TORCH_CHECK(device >= 0, "device id must be positive!", device); +cudaError_t SetDevice(DeviceIndex device, const bool force) { + TORCH_CHECK(device >= 0, "device id must be non-negative!", device); targetDeviceIndex = -1; + if (force) { + return cudaSetDevice(device); + } int cur_device = -1; C10_CUDA_CHECK(cudaGetDevice(&cur_device)); if (device == cur_device) { @@ -309,8 +312,11 @@ cudaError_t GetDevice(DeviceIndex* device) { return err; } -cudaError_t SetDevice(DeviceIndex device) { - TORCH_CHECK(device >= 0, "device id must be positive!", device); +cudaError_t SetDevice(DeviceIndex device, const bool force) { + TORCH_CHECK(device >= 0, "device id must be non-negative!", device); + if (force) { + return cudaSetDevice(device); + } int cur_device = -1; C10_CUDA_CHECK(cudaGetDevice(&cur_device)); if (device == cur_device) { diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 192fafbad10f42..2c7aa99feeb3a0 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -27,7 +27,7 @@ C10_CUDA_API DeviceIndex device_count_ensure_non_zero(); C10_CUDA_API DeviceIndex current_device(); -C10_CUDA_API void set_device(DeviceIndex device); +C10_CUDA_API void set_device(DeviceIndex device, const bool force = false); C10_CUDA_API void device_synchronize(); @@ -38,7 +38,8 @@ C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count); C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device); -C10_CUDA_API cudaError_t SetDevice(DeviceIndex device); +C10_CUDA_API cudaError_t +SetDevice(DeviceIndex device, const bool force = false); C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device); diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 161487148652b9..0cde2d9de01cf2 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -245,7 +245,13 @@ static void initCUDAStreamsOnce() { // Helper to verify the GPU index is valid static inline void check_gpu(DeviceIndex device_index) { - TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus); + TORCH_CHECK( + device_index >= 0 && device_index < num_gpus, + "Device index value ", + static_cast(device_index), + " is out of index range [0, ", + static_cast(num_gpus), + ")"); } // Helper to determine the index of the stream to return diff --git a/c10/cuda/test/CMakeLists.txt b/c10/cuda/test/CMakeLists.txt index 9c0606c2d54f49..1d912cfedb9038 100644 --- a/c10/cuda/test/CMakeLists.txt +++ b/c10/cuda/test/CMakeLists.txt @@ -17,7 +17,8 @@ if(BUILD_TEST) if(WIN32 AND test_src MATCHES "^.*\.hip$") set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) hip_add_executable(${test_name} "${test_src}") - set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH}) + list(JOIN PYTORCH_ROCM_ARCH " " ROCM_PROPERTY_ARCH_LIST) + set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX HIP_ARCHITECTURES ${ROCM_PROPERTY_ARCH_LIST}) else() add_executable(${test_name} "${test_src}") endif() diff --git a/c10/macros/Export.h b/c10/macros/Export.h index 21808de77a31e0..b013910902b262 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -1,95 +1,11 @@ #ifndef C10_MACROS_EXPORT_H_ #define C10_MACROS_EXPORT_H_ -/* Header file to define the common scaffolding for exported symbols. - * - * Export is by itself a quite tricky situation to deal with, and if you are - * hitting this file, make sure you start with the background here: - * - Linux: https://gcc.gnu.org/wiki/Visibility - * - Windows: - * https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017 - * - * Do NOT include this file directly. Instead, use c10/macros/Macros.h - */ - -// You do not need to edit this part of file unless you are changing the core -// pytorch export abstractions. -// -// This part defines the C10 core export and import macros. This is controlled -// by whether we are building shared libraries or not, which is determined -// during build time and codified in c10/core/cmake_macros.h. -// When the library is built as a shared lib, EXPORT and IMPORT will contain -// visibility attributes. If it is being built as a static lib, then EXPORT -// and IMPORT basically have no effect. - -// As a rule of thumb, you should almost NEVER mix static and shared builds for -// libraries that depend on c10. AKA, if c10 is built as a static library, we -// recommend everything dependent on c10 to be built statically. If c10 is built -// as a shared library, everything dependent on it should be built as shared. In -// the PyTorch project, all native libraries shall use the macro -// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static -// libraries. - -// For build systems that do not directly depend on CMake and directly build -// from the source directory (such as Buck), one may not have a cmake_macros.h -// file at all. In this case, the build system is responsible for providing -// correct macro definitions corresponding to the cmake_macros.h.in file. -// -// In such scenarios, one should define the macro -// C10_USING_CUSTOM_GENERATED_MACROS -// to inform this header that it does not need to include the cmake_macros.h -// file. - #ifndef C10_USING_CUSTOM_GENERATED_MACROS #include #endif // C10_USING_CUSTOM_GENERATED_MACROS -#ifdef _WIN32 -#define C10_HIDDEN -#if defined(C10_BUILD_SHARED_LIBS) -#define C10_EXPORT __declspec(dllexport) -#define C10_IMPORT __declspec(dllimport) -#else -#define C10_EXPORT -#define C10_IMPORT -#endif -#else // _WIN32 -#if defined(__GNUC__) -#define C10_EXPORT __attribute__((__visibility__("default"))) -#define C10_HIDDEN __attribute__((__visibility__("hidden"))) -#else // defined(__GNUC__) -#define C10_EXPORT -#define C10_HIDDEN -#endif // defined(__GNUC__) -#define C10_IMPORT C10_EXPORT -#endif // _WIN32 - -#ifdef NO_EXPORT -#undef C10_EXPORT -#define C10_EXPORT -#endif - -// Definition of an adaptive XX_API macro, that depends on whether you are -// building the library itself or not, routes to XX_EXPORT and XX_IMPORT. -// Basically, you will need to do this for each shared library that you are -// building, and the instruction is as follows: assuming that you are building -// a library called libawesome.so. You should: -// (1) for your cmake target (usually done by "add_library(awesome, ...)"), -// define a macro called AWESOME_BUILD_MAIN_LIB using -// target_compile_options. -// (2) define the AWESOME_API macro similar to the one below. -// And in the source file of your awesome library, use AWESOME_API to -// annotate public symbols. - -// Here, for the C10 library, we will define the macro C10_API for both import -// and export. - -// This one is being used by libc10.so -#ifdef C10_BUILD_MAIN_LIB -#define C10_API C10_EXPORT -#else -#define C10_API C10_IMPORT -#endif +#include // This one is being used by libtorch.so #ifdef CAFFE2_BUILD_MAIN_LIB @@ -159,4 +75,4 @@ #define C10_API_ENUM #endif -#endif // C10_MACROS_MACROS_H_ +#endif // C10_MACROS_EXPORT_H_ diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 0947be6c0d0fe6..6b51a39f2a9438 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -312,7 +312,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #endif #if defined(USE_ROCM) -#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h) +// C10_WARP_SIZE is only allowed for device code. +// Host code _must_ use at::cuda::warp_size() +// HIP header used to define warpSize as a constexpr that was either 32 or 64 +// depending on the target device, and then always set it to 64 for host code. +// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we +// set it to something unreasonable to trigger obvious host code errors. +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(__GFX9__) +static constexpr int C10_WARP_SIZE = 64; +#else // __GFX9__ +static constexpr int C10_WARP_SIZE = 32; +#endif // __GFX9__ +#else +static constexpr int C10_WARP_SIZE = 1; +#endif // __HIP_DEVICE_COMPILE__ #else #define C10_WARP_SIZE 32 #endif @@ -403,11 +417,24 @@ __host__ __device__ #endif // __SYCL_DEVICE_ONLY__ } #endif // NDEBUG -// ROCm disable kernel assert by default +// ROCm disables kernel assert by default for performance considerations. +// Though ROCm supports __assert_fail, it uses kernel printf which has +// a non-negligible performance impact even if the assert condition is +// never triggered. We choose to use abort() instead which will still +// terminate the application but without a more useful error message. #if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) -#define CUDA_KERNEL_ASSERT(cond) -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) -#define SYCL_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } #else #define CUDA_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index 73646c7cbe2f53..129b2b1e057026 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -12,6 +12,9 @@ def define_targets(rules): linkstatic = True, local_defines = ["C10_BUILD_MAIN_LIB"], visibility = ["//visibility:public"], + deps = [ + "//torch/headeronly:torch_headeronly", + ], ) rules.cmake_configure_file( diff --git a/c10/metal/expm1f.h b/c10/metal/expm1f.h new file mode 100644 index 00000000000000..3bc1517a2db383 --- /dev/null +++ b/c10/metal/expm1f.h @@ -0,0 +1,97 @@ +// Copy-and-pasted from: +// https://github.com/ml-explore/mlx/blob/99c33d011d63174f50cea37c3eede002958be6d3/mlx/backend/metal/kernels/expm1f.h + +#pragma once + +#include + +// Original license copied below: +// Copyright (c) 2015-2023 Norbert Juffa +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +namespace c10 { +namespace metal { + +/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 + + i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. + Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). + With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, + when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. + + NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) +*/ +inline float expm1f_scaled_unchecked(float a, float b) { + float f, j, r, s, t, u, v, x, y; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = ::metal::fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 + j = j - 12582912.0f; // 0x1.8p23 + i = (int)j; + f = ::metal::fma(j, -6.93145752e-1f, a); + + // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] + s = f * f; + if (a == 0.0f) + s = a; // ensure -0 is passed through + // err = 0.997458 ulp1 = 11081805 + r = 1.97350979e-4f; // 0x1.9de000p-13 + r = ::metal::fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 + r = ::metal::fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 + r = ::metal::fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 + r = ::metal::fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 + r = ::metal::fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 + u = (j == 1) ? (f + 0.5f) : f; + v = ::metal::fma(r, s, u); + s = 0.5f * b; + t = ::metal::ldexp(s, i); + y = t - s; + x = (t - y) - s; // double-float canonicalization of difference + r = ::metal::fma(v, t, x) + y; + r = r + r; + if (j == 0) + r = v; + if (j == 1) + r = v + v; + return r; +} + +/* Compute exponential base e minus 1. max ulp err = 0.99746 */ +inline float expm1f(float a) { + float r; + + r = expm1f_scaled_unchecked(a, 1.0f); + /* handle severe overflow and underflow */ + if (::metal::abs(a - 1.0f) > 88.0f) { + r = ::metal::pow(2, a); + r = ::metal::fma(r, r, -1.0f); + } + return r; +} + +} // namespace metal +} // namespace c10 diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index 7f144a764a0460..cd7de5b54766b4 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -104,6 +104,61 @@ kernel void unary_strided( } \ } +template +kernel void unary_alpha_dense( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T2& alpha [[buffer(2)]], + uint index [[thread_position_in_grid]]) { + F f; + output[index] = f(input[index], alpha); +} + +template +kernel void unary_alpha_strided( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* input_strides [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant uint& ndim [[buffer(5)]], + constant T2& alpha [[buffer(6)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + output[output_offs] = f(input[input_offs], alpha); +} + +#define REGISTER_UNARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for unary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + unary_alpha_dense( \ + device ::c10::metal::result_of * \ + output, \ + constant DTYPEI * input, \ + constant DTYPEA & alpha, \ + uint index); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + unary_alpha_strided( \ + device ::c10::metal::result_of * \ + output, \ + constant DTYPEI * input, \ + constant long* sizes, \ + constant long* input_strides, \ + constant long* output_strides, \ + constant uint& ndim, \ + constant DTYPEA& alpha, \ + uint index) + template inline T val_at_offs(constant void* ptr, long offs) { return *reinterpret_cast( @@ -156,10 +211,10 @@ inline device T& ref_at_offs(device void* ptr, long offs) { // strided // - binary_dense_cast - inputs are dense, but of different dtypes // - binary_strided_cast - inputs or output are strided and of different dtypes -// TODO: Looke like binary_dense_scalar are frequently used specialization that -// should be added Pluse 4 variants of the same, but that accept optional +// TODO: Look like binary_dense_scalar are frequently used specialization that +// should be added Pulse 4 variants of the same, but that accept optional // `alpha` parameter -// (currnetly only used add/sub/lerp.Scalar) +// (currently only used add/sub/lerp.Scalar) // Note about accuracy (for more info see // https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is // invoked to produce `half` output, but one of the arguments is float arguments @@ -191,12 +246,12 @@ kernel void binary_strided( static_cast(f(om_t(a), om_t(b))); } -template -kernel void alpha_binary_strided( +template +kernel void binary_alpha_strided( device void* output [[buffer(0)]], constant void* input [[buffer(1)]], constant void* other [[buffer(2)]], - constant T& alpha [[buffer(3)]], + constant T2& alpha [[buffer(3)]], constant long* sizes [[buffer(4)]], constant long* output_strides [[buffer(5)]], constant long* input_strides [[buffer(6)]], @@ -211,7 +266,7 @@ kernel void alpha_binary_strided( const auto output_offs = offset_from_coord(pos, output_strides, ndim.x); const auto a = val_at_offs(input, input_offs); const auto b = val_at_offs(other, other_offs); - ref_at_offs>(output, output_offs) = f(a, b, alpha); + ref_at_offs>(output, output_offs) = f(a, b, alpha); } template > @@ -223,7 +278,7 @@ kernel void binary_strided_cast( constant long* output_strides [[buffer(4)]], constant long* input_strides [[buffer(5)]], constant long* other_strides [[buffer(6)]], - constant uint3& ndim_types [[buffer(7)]], + constant uint4& ndim_types [[buffer(7)]], uint index [[thread_position_in_grid]]) { F f; using res_t = result_of; @@ -239,17 +294,17 @@ kernel void binary_strided_cast( ref_at_offs(output, output_offs) = static_cast(f(a, b)); } -template -kernel void alpha_binary_strided_cast( +template +kernel void binary_alpha_strided_cast( device void* output [[buffer(0)]], constant void* input [[buffer(1)]], constant void* other [[buffer(2)]], - constant T& alpha [[buffer(3)]], + constant T2& alpha [[buffer(3)]], constant long* sizes [[buffer(4)]], constant long* output_strides [[buffer(5)]], constant long* input_strides [[buffer(6)]], constant long* other_strides [[buffer(7)]], - constant uint3& ndim_types [[buffer(8)]], + constant uint4& ndim_types [[buffer(8)]], uint index [[thread_position_in_grid]]) { F f; int pos[max_ndim]; @@ -261,7 +316,7 @@ kernel void alpha_binary_strided_cast( val_at_offs(input, input_offs, static_cast(ndim_types.y)); const auto b = val_at_offs(other, other_offs, static_cast(ndim_types.z)); - ref_at_offs>(output, output_offs) = f(a, b, alpha); + ref_at_offs>(output, output_offs) = f(a, b, alpha); } template > @@ -275,12 +330,12 @@ kernel void binary_dense( out[tid] = static_cast(f(om_t(input[tid]), om_t(other[tid]))); } -template -kernel void alpha_binary_dense( - device result_of* out [[buffer(0)]], +template +kernel void binary_alpha_dense( + device result_of* out [[buffer(0)]], constant T* input [[buffer(1)]], constant T* other [[buffer(2)]], - constant T& alpha [[buffer(3)]], + constant T2& alpha [[buffer(3)]], uint tid [[thread_position_in_grid]]) { F f; out[tid] = f(input[tid], other[tid], alpha); @@ -302,12 +357,12 @@ kernel void binary_dense_cast( out[tid] = static_cast(f(a, b)); } -template -kernel void alpha_binary_dense_cast( - device result_of* out [[buffer(0)]], +template +kernel void binary_alpha_dense_cast( + device result_of* out [[buffer(0)]], constant void* input [[buffer(1)]], constant void* other [[buffer(2)]], - constant T& alpha [[buffer(3)]], + constant T2& alpha [[buffer(3)]], constant uint4& sizes_types [[buffer(4)]], uint tid [[thread_position_in_grid]]) { F f; @@ -344,7 +399,7 @@ kernel void alpha_binary_dense_cast( constant long* output_strides, \ constant long* input_strides, \ constant long* other_strides, \ - constant uint3& ndim_types, \ + constant uint4& ndim_types, \ uint tid); \ template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ c10::metal::binary_dense( \ @@ -369,54 +424,58 @@ kernel void alpha_binary_dense_cast( #define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \ REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI) -#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO) \ +#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \ static_assert( \ ::metal::is_same_v< \ DTYPEO, \ - ::c10::metal::result_of>, \ + ::c10::metal::result_of>, \ "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \ - template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ - c10::metal::alpha_binary_strided( \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_strided( \ device void* out, \ constant void* input, \ constant void* other, \ - constant DTYPEI& alpha, \ + constant DTYPEA& alpha, \ constant long* sizes, \ constant long* output_strides, \ constant long* input_strides, \ constant long* other_strides, \ constant uint3& ndim, \ uint tid); \ - template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \ - metal::alpha_binary_strided_cast( \ + template [[host_name(#NAME "_strided_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_strided_cast( \ device void* out, \ constant void* input, \ constant void* other, \ - constant DTYPEI& alpha, \ + constant DTYPEA& alpha, \ constant long* sizes, \ constant long* output_strides, \ constant long* input_strides, \ constant long* other_strides, \ - constant uint3& ndim_types, \ + constant uint4& ndim_types, \ uint tid); \ - template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ - c10::metal::alpha_binary_dense( \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense( \ device ::c10::metal:: \ - result_of * \ + result_of * \ out_, \ constant DTYPEI * input_, \ constant DTYPEI * other_, \ - constant DTYPEI & alpha, \ + constant DTYPEA & alpha, \ uint tid); \ - template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \ - metal::alpha_binary_dense_cast( \ - device ::c10::metal:: \ - result_of * \ - out_, \ - constant void* input, \ - constant void* other, \ - constant DTYPEI& alpha, \ - constant uint4& sizes_types, \ - uint tid) + template \ + [[host_name(#NAME "_dense_cast_" #DTYPEI "_" #DTYPEA)]] kernel void :: \ + c10::metal::binary_alpha_dense_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input, \ + constant void* other, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid) } // namespace metal } // namespace c10 diff --git a/c10/metal/random.h b/c10/metal/random.h index 29c9c58368057b..c03d9b8a3149c0 100644 --- a/c10/metal/random.h +++ b/c10/metal/random.h @@ -1,4 +1,4 @@ -// Philox Counter based RNG implemntation for Metal +// Philox Counter based RNG implementation for Metal // Borrowed from aten/src/ATen/core/PhiloxRNGEngine.h // Which in turn borrowed from // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 2e1350d391a0f6..34f6ab6d1d09e8 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -1,35 +1,50 @@ // Implementation of specal math functions for Metal #pragma once +#include #include #include namespace c10 { namespace metal { -// Translated to metal from https://www.johndcook.com/cpp_erf.html - +/* + * Approximation to the error function. + * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + * Copy-n-pasted from + * https://github.com/ml-explore/mlx/blob/2e8cf0b4506c200a5c2d199ecbbf655fdf4c2ce2/mlx/backend/metal/kernels/erf.h#L11 + */ template -inline T erf(T x) { - T a1 = 0.254829592; - T a2 = -0.284496736; - T a3 = 1.421413741; - T a4 = -1.453152027; - T a5 = 1.061405429; - T p = 0.3275911; - - // Save the sign of x - int sign = 1; - if (x < 0) - sign = -1; - x = ::metal::fabs(x); - - // A&S formula 7.1.26 - T t = 1.0 / (1.0 + p * x); - T y = 1.0 - - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * - ::metal::exp(-x * x); - - return sign * y; +inline float erf(T x) { + const auto a = static_cast(x); + const auto t = ::metal::abs(a); + const auto s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + auto r = ::metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + const auto u = ::metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = ::metal::fma(r, s, u); + r = ::metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = ::metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = ::metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = ::metal::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - ::metal::exp(r); + r = ::metal::copysign(r, a); + return r; + } + + // maximum error 0.98929 ulp + auto r = -5.96761703e-4f; // -0x1.38e000p-11 + r = ::metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = ::metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = ::metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = ::metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = ::metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = ::metal::fma(r, a, a); + return r; } template @@ -603,20 +618,6 @@ inline T spherical_bessel_j0(T x) { return static_cast(::metal::sin(x) / x); } -// Compute log(1+x) without losing precision for small values of x -// Adapted from https://www.johndcook.com/blog/cpp_log_one_plus_x/ -template -inline float log1p(T x) { - // x is large enough that the obvious evaluation is OK - if (::metal::fabs(x) > 1E-4) { - return ::metal::log(1. + x); - } - - // Use Taylor approx. log(1 + x) = x - x^2/2 with error roughly x^3/3 - // Since |x| < 10^-4, |x|^3 < 10^-12, relative error less than 10^-8 - return (-0.5 * x + 1.0) * x; -} - template inline float xlog1py(T x, T y) { if (::metal::isnan(y)) { @@ -627,7 +628,7 @@ inline float xlog1py(T x, T y) { return x; } - return x * log1p(y); + return x * ::c10::metal::log1p(y); } template @@ -1558,7 +1559,7 @@ float chebyshev_polynomial_t_forward(T x, int64_t n) { float q = x; float r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { r = (x + x) * q - p; p = q; q = r; @@ -1602,7 +1603,7 @@ float chebyshev_polynomial_u_forward(T x, int64_t n) { auto p = 1.0; float r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { r = 2 * x * q - p; p = q; q = r; @@ -1655,7 +1656,7 @@ float chebyshev_polynomial_v_forward(T x, int64_t n) { auto p = 1.0; float r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { r = 2 * x * q - p; p = q; q = r; @@ -1712,7 +1713,7 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { auto p = 1.0; float r; - for (int64_t k = 2; k <= n; k++) { + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { r = 2.0 * x * q - p; p = q; q = r; @@ -1721,6 +1722,207 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { return r; } // chebyshev_polynomial_w_forward(T x, int64_t n) +template +float shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == T(1.0)) { + return 1.0; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + return ::metal::precise::cos(n * ::metal::precise::acos(xpxm1)); + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1; + } + + float p = 1.0; + float q = xpxm1; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return n + 1; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + const float divisor = ::metal::precise::sin(acos_2xm1); + if (divisor != 0.0) { + return ::metal::precise::sin((n + 1) * acos_2xm1) / divisor; + } + + return (n + 1) * ::metal::precise::cos((n + 1) * acos_2xm1) / xpxm1; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1 + xpxm1; + } + + float p = 1.0; + float q = xpxm1 + xpxm1; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return 1.0; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + if (::metal::precise::sin(acos_2xm1 / 2.0) != 1.0) { + return ::metal::precise::cos((n + 0.5) * acos_2xm1) / + ::metal::precise::cos(acos_2xm1 / 2.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return xpxm1 + xpxm1 - 1.0; + } + + float p = 1.0; + float q = xpxm1 + xpxm1 - 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return n + n + 1; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + const float xpxm1 = x + x - 1.0; + if ((n > 4) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + if (::metal::precise::cos(acos_2xm1 / 2.0) != 1.0) { + return ::metal::precise::sin((n + 0.5) * acos_2xm1) / + ::metal::precise::sin(acos_2xm1 / 2.0); + } + + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1 + xpxm1 + 1.0; + } + + float p = 1.0; + float q = xpxm1 + xpxm1 + 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + template // TODO: Add 512 if/when double will be supported in Metal inline constexpr int getHermitianLimit() { diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 92fc87c4240a2b..d3a0c8ba96ad1f 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -85,6 +85,27 @@ struct OpMathType { using type = float; }; #endif + +// Type promotion structure for higher precision accumulation +template +struct AccumulationType { + using type = T; +}; + +// Specialization for half - promote to float for accumulation +template <> +struct AccumulationType { + using type = float; +}; + +#if __METAL_VERSION__ >= 310 +// Specialization for bfloat - promote to float for accumulation +template <> +struct AccumulationType { + using type = float; +}; +#endif + } // namespace detail template @@ -132,6 +153,9 @@ using vec4type_t = typename detail::vectypes::type4; template using opmath_t = typename detail::OpMathType::type; +template +using accum_t = typename detail::AccumulationType::type; + // TODO: Move it to type_traits header may be template using result_of = decltype(::metal::declval()(::metal::declval()...)); @@ -265,5 +289,46 @@ inline common_dtype div(const T x, const U y) { return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y); } +// Remainder operator +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_floating_point_v || is_scalar_floating_point_v, + bool> = true> +inline float remainder(const T x, const U y) { + const auto x_f = static_cast(x); + const auto y_f = static_cast(y); + return x_f - y_f * floor_divide(x_f, y_f); +} + +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_integral_v && is_scalar_integral_v, + bool> = true> +inline common_dtype remainder(const T x, const U y) { + auto rc = x % y; + return rc == 0 || (x ^ y) > 0 ? rc : rc + y; +} + +// Based on algorithm described in +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + const auto xp1 = 1.0f + x; + // First two elements of Taylor series for log(1+x) in Horner's form are: + // log(1+x) = x * (1 - x * (.5 ...)), but if 1 + x == x, then it's just x + if (xp1 == 1.0f) { + return x; + } + auto rc = ::metal::precise::log(xp1); + if (x > -.5 && x < .5) { + // Order of operations is important here for higher precision + rc *= x / (xp1 - 1.0f); + } + return rc; +} + } // namespace metal } // namespace c10 diff --git a/c10/ovrsource_defs.bzl b/c10/ovrsource_defs.bzl index 804e7bceb96852..3f682fb9c2030d 100644 --- a/c10/ovrsource_defs.bzl +++ b/c10/ovrsource_defs.bzl @@ -74,6 +74,7 @@ def define_c10_ovrsource(name, is_mobile): ], }), exported_deps = [ + "//xplat/caffe2/torch/headeronly:torch_headeronly", ":ovrsource_c10_cmake_macros.h", "//arvr/third-party/gflags:gflags", "//third-party/cpuinfo:cpuinfo", diff --git a/c10/test/util/Half_test.cpp b/c10/test/util/Half_test.cpp index fc2a002f3a94a0..a76814615101b3 100644 --- a/c10/test/util/Half_test.cpp +++ b/c10/test/util/Half_test.cpp @@ -19,8 +19,7 @@ float halfbits2float(unsigned short h) { exponent = 0xff; } else if (!exponent) { /* Denorm or Zero */ if (mantissa) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - unsigned int msb; + unsigned int msb = 0; exponent = 0x71; do { msb = (mantissa & 0x400000); diff --git a/c10/test/util/generic_math_test.cpp b/c10/test/util/generic_math_test.cpp index 8ccdf8171dc769..461d55819c65c2 100644 --- a/c10/test/util/generic_math_test.cpp +++ b/c10/test/util/generic_math_test.cpp @@ -14,4 +14,6 @@ TEST(GenericMathTest, div_floor_test) { EXPECT_DOUBLE_EQ(c10::div_floor_floating(5., -2.), -3.); EXPECT_EQ(c10::div_floor_integer(5, 2), 2); EXPECT_EQ(c10::div_floor_integer(5, -2), -3); + EXPECT_EQ(c10::div_mod(-9, -3), 0); + EXPECT_EQ(c10::div_mod(-9., -3.), 0.); } diff --git a/c10/test/util/lazy_test.cpp b/c10/test/util/lazy_test.cpp index d3e208ecae5e03..09545c7a3a0ecf 100644 --- a/c10/test/util/lazy_test.cpp +++ b/c10/test/util/lazy_test.cpp @@ -53,8 +53,8 @@ TEST(LazyTest, OptimisticLazy) { EXPECT_EQ(sCopy.ensure(factory), kLongString); EXPECT_EQ(invocations.load(), 0); - auto sMove = std::move(s); - EXPECT_EQ(sMove.ensure(factory), kLongString); + auto sMove = std::move(s); // codespell:ignore smove + EXPECT_EQ(sMove.ensure(factory), kLongString); // codespell:ignore smove EXPECT_EQ(invocations.load(), 0); // NOLINTNEXTLINE(bugprone-use-after-move) EXPECT_EQ(s.ensure(factory), kLongString); diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 10c83998c42022..64605f51535959 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -16,7 +16,6 @@ #pragma once #include -#include #include #include @@ -377,8 +376,8 @@ bool operator!=(c10::ArrayRef a1, const std::vector& a2) { using IntArrayRef = ArrayRef; -// This alias is deprecated because it doesn't make ownership -// semantics obvious. Use IntArrayRef instead! -C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef) +using IntList [[deprecated( + "This alias is deprecated because it doesn't make ownership semantics obvious. Use IntArrayRef instead!")]] = + ArrayRef; } // namespace c10 diff --git a/c10/util/Exception.h b/c10/util/Exception.h index ff7130a01db01a..3ff3396f5f1b84 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -295,6 +295,19 @@ class C10_API SyntaxError : public Error { using Error::Error; }; +// Raised when accelerator API call hits an error. +// These turn into AcceleratorError when the cross into Python +class C10_API AcceleratorError : public Error { + int32_t error_code; + + public: + AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg) + : Error(loc, msg), error_code(code) {} + int32_t get_error_code() const { + return error_code; + } +}; + // Base error type for all distributed errors. // These turn into DistError when they cross into Python. class C10_API DistError : public Error { @@ -705,28 +718,28 @@ namespace c10::detail { /* // Deprecation disabled until we fix sites in our codebase -C10_DEPRECATED_MESSAGE("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg) -instead.") +[[deprecated("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg) +instead.")]] */ inline void deprecated_AT_ERROR() {} /* // Deprecation disabled until we fix sites in our codebase -C10_DEPRECATED_MESSAGE("AT_ASSERT is deprecated, if you mean to indicate an +[[deprecated("AT_ASSERT is deprecated, if you mean to indicate an internal invariant failure, use " \ "TORCH_INTERNAL_ASSERT instead; if you mean to do user error checking, use " \ "TORCH_CHECK. See -https://github.com/pytorch/pytorch/issues/20287 for more details.") +https://github.com/pytorch/pytorch/issues/20287 for more details.")]] */ inline void deprecated_AT_ASSERT() {} /* // Deprecation disabled until we fix sites in our codebase -C10_DEPRECATED_MESSAGE("AT_ASSERTM is deprecated, if you mean to indicate an +[[deprecated("AT_ASSERTM is deprecated, if you mean to indicate an internal invariant failure, use " \ "TORCH_INTERNAL_ASSERT instead; if you mean to do user error checking, use " \ "TORCH_CHECK. See -https://github.com/pytorch/pytorch/issues/20287 for more details.") +https://github.com/pytorch/pytorch/issues/20287 for more details.")]] */ inline void deprecated_AT_ASSERTM() {} diff --git a/c10/util/Logging.h b/c10/util/Logging.h index fac615d836fca8..0d6a1b32632230 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -57,7 +57,9 @@ C10_DECLARE_bool(caffe2_use_fatal_for_enforce); namespace c10 { +#if !defined(C10_NODEPRECATED) using std::string; +#endif // Functions that we use for initialization. C10_API bool InitCaffeLogging(int* argc, char** argv); diff --git a/c10/util/NetworkFlow.cpp b/c10/util/NetworkFlow.cpp index d64599a547a378..29e0ae54e74c5c 100644 --- a/c10/util/NetworkFlow.cpp +++ b/c10/util/NetworkFlow.cpp @@ -90,7 +90,7 @@ struct DinicFlowGraph { // The residual level graph is constructed by: // 1. doing a BFS on the residual graph, assigning levels // to each vertex. - // 2. only include edges u->v where level[v] == leve[u] + 1 + // 2. only include edges u->v where level[v] == level[u] + 1 std::queue q; // let level[u] = 0 if it has not been visited yet. std::vector level(graph_size, 0); diff --git a/c10/util/build.bzl b/c10/util/build.bzl index e56864f4eb6bb2..d1881d2bae6000 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -34,6 +34,7 @@ def define_targets(rules): visibility = ["//visibility:public"], deps = [ ":bit_cast", + "//torch/headeronly:torch_headeronly", "//c10/macros", "@fmt", "@moodycamel//:moodycamel", @@ -90,6 +91,9 @@ def define_targets(rules): "ssize.h", ], ), + deps = [ + "//torch/headeronly:torch_headeronly", + ], visibility = ["//visibility:public"], ) diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index a3b318725fe3cb..493c03cb42e648 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -93,7 +93,7 @@ template < std::enable_if_t, int> = 0> inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { auto mod = a % b; - if ((b < 0) != (mod < 0)) { + if (mod != 0 && (b < 0) != (mod < 0)) { mod += b; } return mod; diff --git a/c10/util/irange.h b/c10/util/irange.h index f5310510099962..cc52d443ee5f3a 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -24,7 +24,7 @@ struct integer_iterator { using pointer = I*; using reference = I&; - explicit constexpr integer_iterator(I value) : value(value) {} + explicit constexpr integer_iterator(I val) : value(val) {} constexpr I operator*() const { return value; diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 82651f7858c773..556699be04b1ab 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -610,8 +610,7 @@ inline uint64_t GreatestCommonDivisor64(uint64_t A, uint64_t B) { /// This function takes a 64-bit integer and returns the bit equivalent double. inline double BitsToDouble(uint64_t Bits) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double D; + double D = 0; static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); memcpy(&D, &Bits, sizeof(Bits)); return D; diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index e242cfaf8449a3..32ffca52e48641 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include // GCC has __builtin_mul_overflow from before it supported __has_builtin @@ -31,30 +32,38 @@ C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { #endif } -C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { +template +C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() return __builtin_mul_overflow(a, b, out); #else - *out = a * b; - // This test isnt exact, but avoids doing integer division - return ( - (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64); -#endif -} + static_assert( + std::is_integral_v, "mul_overflows only supports integral types"); -C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) { -#if C10_HAS_BUILTIN_OVERFLOW() - return __builtin_mul_overflow(a, b, out); -#else - volatile int64_t tmp = a * b; - *out = tmp; - if (a == 0 || b == 0) { - return false; + if constexpr (std::is_signed_v) { + // For signed types, use the division-based check + volatile T tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); + } else { + // For unsigned types, use leading zeros approach + // This test isn't exact, but avoids doing integer division + *out = a * b; + constexpr int bits = sizeof(T) * 8; + return ( + (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < + bits); } - return !(a == tmp / b); #endif } +C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return mul_overflows(a, b, out); +} + template bool safe_multiplies_u64(It first, It last, uint64_t* out) { #if C10_HAS_BUILTIN_OVERFLOW() @@ -77,7 +86,7 @@ bool safe_multiplies_u64(It first, It last, uint64_t* out) { prod_log2 += c10::llvm::Log2_64_Ceil(x); } *out = prod; - // This test isnt exact, but avoids doing integer division + // This test isn't exact, but avoids doing integer division return !is_zero && (prod_log2 >= 64); #endif } diff --git a/c10/util/string_view.h b/c10/util/string_view.h index 4858716a75b9d3..0cc5da4309f6b1 100644 --- a/c10/util/string_view.h +++ b/c10/util/string_view.h @@ -632,7 +632,7 @@ struct hash<::c10::basic_string_view> { size_t operator()(::c10::basic_string_view x) const { // The standard says that std::string_view hashing must do the same as // std::string hashing but leaves the details of std::string hashing - // up to the implementer. So, to be conformant, we need to re-use and + // up to the implementer. So, to be conformant, we need to reuse and // existing STL type's hash function. The std::string fallback is probably // slow but the only way to be conformant. diff --git a/c10/util/typeid.h b/c10/util/typeid.h index a52dbd62520a58..76087842635f5c 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -393,7 +393,7 @@ class C10_API TypeMeta final { return data().placementNew_; } /** - * Returns the typed copy function pointer for individual iterms. + * Returns the typed copy function pointer for individual items. */ Copy* copy() const noexcept { return data().copy_; diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 5a5e953bef4a5c..543b48f081135d 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -251,9 +251,12 @@ class DeviceCachingAllocator { return true; } - bool alloc_block(AllocParams& p) { + bool alloc_block(AllocParams& p, bool isRetry) { auto size = p.alloc_size; auto device = p.device(); + if (isRetry) { + stats.num_alloc_retries += 1; + } void* ptr = sycl::aligned_alloc_device( kDeviceAlignment, size, @@ -425,8 +428,8 @@ class DeviceCachingAllocator { bool block_found = get_free_block(params); // Can't reuse an existing block, try to get a new one. if (!block_found) { - block_found = alloc_block(params) || - (release_cached_blocks() && alloc_block(params)); + block_found = alloc_block(params, false) || + (release_cached_blocks() && alloc_block(params, true)); } if (!block_found) { c10::xpu::DeviceProp device_prop; @@ -519,6 +522,7 @@ class DeviceCachingAllocator { stats.active_bytes[statType].reset_accumulated(); stats.requested_bytes[statType].reset_accumulated(); } + stats.num_alloc_retries = 0; } void resetPeakStats() { diff --git a/c10/xpu/XPUDeviceProp.h b/c10/xpu/XPUDeviceProp.h index 00b7969a73d496..591a14f4ad91af 100644 --- a/c10/xpu/XPUDeviceProp.h +++ b/c10/xpu/XPUDeviceProp.h @@ -113,18 +113,21 @@ namespace c10::xpu { _(native_vector_width_double) \ _(native_vector_width_half) -#define AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(_) \ - /* the number of EUs associated with the Intel GPU. */ \ - _(gpu_eu_count, 512) \ - \ - /* the number of EUs in a subslice. */ \ - _(gpu_eu_count_per_subslice, 8) \ - \ - /* the simd width of EU of GPU. */ \ - _(gpu_eu_simd_width, 8) \ - \ - /* the number of hardware threads per EU of GPU. */ \ - _(gpu_hw_threads_per_eu, 8) +#define AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(_) \ + /* the number of EUs associated with the Intel GPU. */ \ + _(gpu_eu_count, 512) \ + \ + /* the number of EUs in a subslice. */ \ + _(gpu_eu_count_per_subslice, 8) \ + \ + /* the simd width of EU of GPU. */ \ + _(gpu_eu_simd_width, 8) \ + \ + /* the number of hardware threads per EU of GPU. */ \ + _(gpu_hw_threads_per_eu, 8) \ + \ + /* the device identifier of the Intel GPU, also known as the product ID. */ \ + _(device_id, 0) #define AT_FORALL_XPU_DEVICE_ASPECT(_) \ /* sycl::half is supported on device. */ \ diff --git a/c10/xpu/test/impl/XPUCachingAllocatorTest.cpp b/c10/xpu/test/impl/XPUCachingAllocatorTest.cpp index b4ffe2c141f3a4..d5d3d850c1056c 100644 --- a/c10/xpu/test/impl/XPUCachingAllocatorTest.cpp +++ b/c10/xpu/test/impl/XPUCachingAllocatorTest.cpp @@ -98,7 +98,7 @@ TEST(XPUCachingAllocatorTest, DeviceCachingAllocateByExternalStream) { void* tmp = sycl::aligned_alloc_device( 512, _10mb, c10::xpu::get_raw_device(0), c10::xpu::get_device_context()); void* ptr1 = c10::xpu::XPUCachingAllocator::raw_alloc(_10mb); - // We have reserved 500M of memory for resue. When allocating `ptr0` and + // We have reserved 500M of memory for reuse. When allocating `ptr0` and // `ptr1` through the device caching allocator, they should be allocated from // the same block. Specifically, `ptr1` should follow immediately after `ptr0` // in the block, forming a sequence like [ptr0, ptr1]. This behavior occurs diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 2cf0dac3f93526..60a701af06c4d8 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -572,13 +572,14 @@ if(USE_CUDA) if(NOT WIN32) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( - ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() @@ -709,10 +710,11 @@ list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS}) if(USE_MPS) list(APPEND Caffe2_CPU_SRCS ${Caffe2_MPS_SRCS}) list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.cpp) + list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_mps.mm) list(APPEND Caffe2_CPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_runner/model_container_runner_mps.cpp) if(CAN_COMPILE_METAL) - file(TOUCH ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp) - list(APPEND Caffe2_CPU_SRCS ${CMAKE_BINARY_DIR}/aten/src/ATen/metallib_dummy.cpp) + file(TOUCH ${CMAKE_BINARY_DIR}/caffe2/aten/src/ATen/metallib_dummy.cpp) + list(APPEND Caffe2_CPU_SRCS ${CMAKE_BINARY_DIR}/caffe2/aten/src/ATen/metallib_dummy.cpp) endif() endif() @@ -985,35 +987,42 @@ elseif(USE_CUDA) target_compile_definitions(torch_cuda PRIVATE USE_NCCL) endif() - # Use env var for these for now for prototyping purposes - set(USE_NVSHMEM $ENV{USE_NVSHMEM} CACHE BOOL "Whether to build with NVSHMEM support") - # If user has specified NVSHMEM_HOME, we use it; - # Otherwise, NVSHMEM_HOME is auto detected in tools/setup_helpers/cmake.py - if($ENV{NVSHMEM_HOME}) - set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir") - endif() - - if(USE_NVSHMEM AND NOT DEFINED NVSHMEM_HOME) - message(WARNING "USE_NVSHMEM set to 1 but NVSHMEM_HOME not found. Please run `pip install nvidia-nvshmem-`, or set NVSHMEM_HOME to the NVSHMEM build dir") - # Disable nvshmem if NVSHMEM_HOME is not found - set(USE_NVSHMEM FALSE CACHE BOOL "Whether to build with NVSHMEM support") - endif() - + # Compile with NVSHMEM + # Default value of `USE_NVSHMEM` is set in CMakeLists.txt under root, to ON. if(USE_NVSHMEM) - message("Building with NVSHMEM support: '${NVSHMEM_HOME}'") - set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include") - set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib") - + message(STATUS "NVSHMEM_HOME set to: '$ENV{NVSHMEM_HOME}'") + message(STATUS "NVSHMEM wheel installed at: '${NVSHMEM_PY_DIR}'") + # Search order: + # 1. If user has specified `NVSHMEM_HOME`, we use it; + # 2. If NVSHMEM wheel has been installed, we use it, see + # tools/setup_helpers/cmake.py, where we set `NVSHMEM_PY_DIR` to the wheel + # location, e.g. + # `/path/to/conda/lib/python3.10/site-packages/nvidia/nvshmem`, + # 3. Let CMake find it in the default system paths, e.g. /usr/local. + find_path(NVSHMEM_LIB_DIR + # In pip install case, the lib suffix is `.so.3` instead of `.so` + NAMES libnvshmem_host.so libnvshmem_host.so.3 + PATHS $ENV{NVSHMEM_HOME}/lib ${NVSHMEM_PY_DIR}/lib + DOC "The location of NVSHMEM library.") + find_path(NVSHMEM_INCLUDE_DIR + NAMES nvshmem.h + PATHS $ENV{NVSHMEM_HOME}/include ${NVSHMEM_PY_DIR}/include + DOC "The location of NVSHMEM headers.") + endif() + + # If NVSHMEM_LIBRARY is found, we build torch_cuda with NVSHMEM support. + if(NVSHMEM_LIB_DIR AND NVSHMEM_INCLUDE_DIR) + message(STATUS "Building with NVSHMEM support: '${NVSHMEM_LIB_DIR}'") include_directories(${NVSHMEM_INCLUDE_DIR}) # Linking with nvshmem requires the source binary to be built with -rdc # which is not viable for libtorch_cuda. So we isolate the linking of # nvshmem in nvshmem_extension. add_library(nvshmem_extension SHARED - "${TORCH_SRC_DIR}/csrc/distributed/c10d/nvshmem_extension.cu" - "${TORCH_SRC_DIR}/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu" - "${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp" "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp" ) set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON) target_compile_options(nvshmem_extension PRIVATE $<$:-rdc=true>) @@ -1025,8 +1034,12 @@ elseif(USE_CUDA) ${NVSHMEM_LIB_DIR}/libnvshmem_host.so.3 nvshmem_device ) + target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM) + target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM) target_link_libraries(torch_cuda PRIVATE nvshmem_extension) install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib) + else() + message(STATUS "NVSHMEM not found, not building with NVSHMEM support.") endif() if(USE_UCC) @@ -1041,8 +1054,13 @@ elseif(USE_CUDA) FLASH_NAMESPACE=pytorch_flash UNFUSE_FMA # Addressing issue #121558 ) - target_include_directories(torch_cuda PRIVATE - ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/ + target_sources(torch_cuda PRIVATE $) + target_include_directories(torch_cuda PUBLIC + $ + $ + $ + $ + $ ) endif() if(USE_MEM_EFF_ATTENTION) @@ -1289,7 +1307,8 @@ endif() target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) target_include_directories(torch_cpu PRIVATE - ${TORCH_SRC_DIR}/csrc) + ${TORCH_SRC_DIR}/csrc + ${TORCH_SRC_DIR}/headeronly) target_include_directories(torch_cpu PRIVATE ${TORCH_ROOT}/third_party/miniz-3.0.2) @@ -1308,9 +1327,12 @@ target_include_directories(torch_cpu PRIVATE target_include_directories(torch_cpu PRIVATE ${TORCH_ROOT}/third_party/nlohmann/include) -install(DIRECTORY "${TORCH_SRC_DIR}/csrc" +install(DIRECTORY + "${TORCH_SRC_DIR}/csrc" + "${TORCH_SRC_DIR}/headeronly" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") + install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" @@ -1319,12 +1341,6 @@ install(FILES "${TORCH_SRC_DIR}/custom_class_detail.h" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) if(BUILD_TEST) - if(BUILD_EXECUTORCH) - add_subdirectory( - ${TORCH_ROOT}/test/edge - ${CMAKE_BINARY_DIR}/test_edge_op_registration - ) - endif() if(BUILD_LITE_INTERPRETER) add_subdirectory( ${TORCH_ROOT}/test/cpp/lite_interpreter_runtime @@ -1950,7 +1966,8 @@ if(BUILD_TEST) set(HIP_HIPCC_FLAGS ${BASE_HIPCC_FLAGS}) set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) hip_add_executable(${test_name} "${test_src}") - set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH}) + list(JOIN PYTORCH_ROCM_ARCH " " ROCM_PROPERTY_ARCH_LIST) + set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${ROCM_PROPERTY_ARCH_LIST}) else() add_executable(${test_name} "${test_src}") endif() diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc index 3ed48a1c523226..936b686e816ea2 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc @@ -3359,12 +3359,10 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -3539,12 +3537,10 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -3650,12 +3646,10 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -3727,12 +3721,10 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -3794,12 +3786,10 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -3946,12 +3936,10 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -4126,12 +4114,10 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -4237,12 +4223,10 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -4314,12 +4298,10 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); @@ -4381,12 +4363,10 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; + float bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index ca6da8ef4ffced..91f6ac238c0f33 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -113,8 +113,6 @@ def compute(regid, InType, use_weights, isa, prefetch): if InType == "uint8_t": code.append(" " + OutType + " wgt = 1.f;") - code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)") - code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" @@ -125,7 +123,7 @@ def compute(regid, InType, use_weights, isa, prefetch): " const float* scale_bias = reinterpret_cast(\n" " &input[idx * fused_block_size + block_size]);" ) - code.append(" bio = wgt * scale_bias[1];") + code.append(" " + OutType + " bio = wgt * scale_bias[1];") code.append(" wgt = wgt * scale_bias[0];") else: code.append(" bio = wgt * scale_bias[2 * idx + 1];") @@ -316,8 +314,6 @@ def compute(InType, use_weights, isa): if InType == "uint8_t": code.append(" " + OutType + " wgt = 1.f;") - code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)") - code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];" @@ -328,10 +324,10 @@ def compute(InType, use_weights, isa): " const float* scale_bias = reinterpret_cast(\n" " &input[idx * fused_block_size + block_size]);" ) - code.append(" bio = wgt * scale_bias[1];") + code.append(" " + OutType + " bio = wgt * scale_bias[1];") code.append(" wgt = wgt * scale_bias[0];") else: - code.append(" bio = wgt * scale_bias[2 * idx + 1];") + code.append(" " + OutType + " bio = wgt * scale_bias[2 * idx + 1];") code.append(" wgt = wgt * scale_bias[2 * idx];") code.append(" __m256 vbio = _mm256_set1_ps(bio);") else: diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 05bba51f4da9f1..e39a78c62dd540 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -196,8 +196,7 @@ void PyTorchStreamReader::init() { // version check at::DataPtr version_ptr; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t version_size; + size_t version_size = 0; if (hasRecord(".data/version")) { std::tie(version_ptr, version_size) = getRecord(".data/version"); } else { diff --git a/cmake/BLAS_ABI.cmake b/cmake/BLAS_ABI.cmake new file mode 100644 index 00000000000000..bb0b5949d73d20 --- /dev/null +++ b/cmake/BLAS_ABI.cmake @@ -0,0 +1,76 @@ +# Push host architecture when cross-compiling otherwise check would fail +# when cross-compiling for arm64 on x86_64 +cmake_push_check_state(RESET) +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") + list(APPEND CMAKE_REQUIRED_FLAGS "-arch ${CMAKE_HOST_SYSTEM_PROCESSOR}") +endif() + +# Set values through env variables if cross compiling +if(CMAKE_CROSSCOMPILING) + if("$ENV{PYTORCH_BLAS_F2C}" STREQUAL "ON") + SET(BLAS_F2C TRUE) + else() + SET(BLAS_F2C FALSE) + endif() + + if("$ENV{PYTORCH_BLAS_USE_CBLAS_DOT}" STREQUAL "ON") + SET(BLAS_USE_CBLAS_DOT TRUE) + else() + SET(BLAS_USE_CBLAS_DOT FALSE) + endif() +else() + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + CHECK_C_SOURCE_RUNS(" +#include +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +int four = 4; +int one = 1; +extern double sdot_(); +int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); +}" BLAS_F2C_DOUBLE_WORKS ) + CHECK_C_SOURCE_RUNS(" +#include +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +int four = 4; +int one = 1; +extern float sdot_(); +int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); +}" BLAS_F2C_FLOAT_WORKS ) + + if(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + MESSAGE(STATUS "This BLAS uses the F2C return conventions") + SET(BLAS_F2C TRUE) + else(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + SET(BLAS_F2C FALSE) + endif(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + CHECK_C_SOURCE_RUNS(" +#include +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +extern float cblas_sdot(); +int main() { + int i; + double r = cblas_sdot(4, x, 1, y, 1); + exit((float)r != (float).1234); +}" BLAS_USE_CBLAS_DOT ) + if(BLAS_USE_CBLAS_DOT) + SET(BLAS_USE_CBLAS_DOT TRUE) + else(BLAS_USE_CBLAS_DOT) + SET(BLAS_USE_CBLAS_DOT FALSE) + endif(BLAS_USE_CBLAS_DOT) + SET(CMAKE_REQUIRED_LIBRARIES) +endif(CMAKE_CROSSCOMPILING) +MESSAGE(STATUS "BLAS_USE_CBLAS_DOT: ${BLAS_USE_CBLAS_DOT}") +MESSAGE(STATUS "BLAS_F2C: ${BLAS_F2C}") +cmake_pop_check_state() diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 90803d830163e8..16ee19a91d487d 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -35,7 +35,7 @@ endfunction() ################################################################################ -# -- [ Deterine commit hash +# -- [ Determine commit hash execute_process( COMMAND "${Python_EXECUTABLE}" -c "from tools.generate_torch_version import get_sha;print(get_sha('.'), end='')" OUTPUT_VARIABLE COMMIT_SHA @@ -81,7 +81,7 @@ if(INTERN_BUILD_ATEN_OPS) if(USE_CUDA) # The stable/nightly builds do not enable some SM architectures, # like 89/90a/100a. Still, some files need to be built for these - # architecturs specifically. This function makes it possible to + # architectures specifically. This function makes it possible to # enable building given file for a specific such architecture, in # case if PyTorch is built for corresponding other architecture; # for example, it will enable building for SM 90a in case PyTorch @@ -108,6 +108,11 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") endif() endif() + if("${_arch}" STREQUAL "120a") + if(_existing_arch_flags MATCHES ".*compute_120.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") + endif() + endif() endforeach() endif() list(JOIN _file_compile_flags " " _file_compile_flags) @@ -117,13 +122,13 @@ if(INTERN_BUILD_ATEN_OPS) _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" - "89;90a;100a") + "89;90a;100a;120a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" "90a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" - "90a") + "90a;100a") endif() @@ -403,7 +408,7 @@ if(INTERN_BUILD_ATEN_OPS) list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1") - # The sources list might get reordered later based on the capabilites. + # The sources list might get reordered later based on the capabilities. # See NOTE [ Linking AVX and non-AVX files ] foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) function(process_vec NAME) diff --git a/cmake/DebugHelper.cmake b/cmake/DebugHelper.cmake index 3797a472336fc4..f69603fb72002c 100644 --- a/cmake/DebugHelper.cmake +++ b/cmake/DebugHelper.cmake @@ -1,5 +1,5 @@ function(print_target_properties tgt) - # Get all propreties that cmake supports + # Get all properties that cmake supports execute_process(COMMAND cmake --help-property-list OUTPUT_VARIABLE CMAKE_PROPERTY_LIST) # Convert command output into a CMake list diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 15bbfaa82ddd17..6208ab77286b81 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -163,6 +163,7 @@ else() endif() set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib;APL") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) +set(BLAS_CHECK_F2C 0) if(BLAS STREQUAL "Eigen") # Eigen is header-only and we do not have any dependent libraries @@ -175,6 +176,7 @@ elseif(BLAS STREQUAL "ATLAS") set(BLAS_INFO "atlas") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${ATLAS_LIBRARIES} cblas) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "OpenBLAS") find_package(OpenBLAS REQUIRED) include_directories(SYSTEM ${OpenBLAS_INCLUDE_DIR}) @@ -182,10 +184,12 @@ elseif(BLAS STREQUAL "OpenBLAS") set(BLAS_INFO "open") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${OpenBLAS_LIB}) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "BLIS") find_package(BLIS REQUIRED) include_directories(SYSTEM ${BLIS_INCLUDE_DIR}) list(APPEND Caffe2_DEPENDENCY_LIBS ${BLIS_LIB}) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "MKL") if(BLAS_SET_BY_USER) find_package(MKL REQUIRED) @@ -215,6 +219,7 @@ elseif(BLAS STREQUAL "NVPL") set(BLAS_INFO "nvpl") set(BLAS_FOUND 1) set(BLAS_USE_CBLAS_DOT TRUE) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "vecLib") find_package(vecLib REQUIRED) include_directories(SYSTEM ${vecLib_INCLUDE_DIR}) @@ -226,12 +231,14 @@ elseif(BLAS STREQUAL "FlexiBLAS") find_package(FlexiBLAS REQUIRED) include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR}) list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB}) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "APL") find_package(APL REQUIRED) include_directories(SYSTEM ${APL_INCLUDE_DIR}) set(BLAS_INFO "apl") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${APL_LIBRARIES}) + set(BLAS_CHECK_F2C 1) elseif(BLAS STREQUAL "Generic") # On Debian family, the CBLAS ABIs have been merged into libblas.so if(ENV{GENERIC_BLAS_LIBRARIES} STREQUAL "") @@ -245,10 +252,16 @@ elseif(BLAS STREQUAL "Generic") set(GENERIC_BLAS_FOUND TRUE) set(BLAS_INFO "generic") set(BLAS_FOUND 1) + set(BLAS_CHECK_F2C 1) else() message(FATAL_ERROR "Unrecognized BLAS option: " ${BLAS}) endif() +# Determine if blas was compiled with the f2c conventions +if(BLAS_LIBRARIES AND BLAS_CHECK_F2C) + include(cmake/BLAS_ABI.cmake) +endif(BLAS_LIBRARIES) + if(NOT INTERN_BUILD_MOBILE) set(AT_MKL_SEQUENTIAL 0) set(USE_BLAS 1) @@ -778,6 +791,20 @@ elseif(NOT TARGET fp16 AND USE_SYSTEM_FP16) endif() list(APPEND Caffe2_DEPENDENCY_LIBS fp16) +# ---[ Python Interpreter +# If not given a Python installation, then use the current active Python +if(NOT Python_EXECUTABLE) + execute_process( + COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) + if(${_exitcode} EQUAL 0) + if(NOT MSVC) + string(STRIP ${_py_exe} Python_EXECUTABLE) + endif() + message(STATUS "Setting Python to ${Python_EXECUTABLE}") + endif() +endif() + + # ---[ EIGEN # Due to license considerations, we will only use the MPL2 parts of Eigen. set(EIGEN_MPL2_ONLY 1) @@ -787,6 +814,9 @@ if(USE_SYSTEM_EIGEN_INSTALL) message(STATUS "Found system Eigen at " ${EIGEN3_INCLUDE_DIR}) else() message(STATUS "Did not find system Eigen. Using third party subdirectory.") + execute_process(COMMAND ${Python_EXECUTABLE} ../tools/optional_modules.py checkout_eigen + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) + set(EIGEN3_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/eigen) caffe2_update_option(USE_SYSTEM_EIGEN_INSTALL OFF) endif() @@ -797,19 +827,6 @@ endif() include_directories(SYSTEM ${EIGEN3_INCLUDE_DIR}) -# ---[ Python Interpreter -# If not given a Python installation, then use the current active Python -if(NOT Python_EXECUTABLE) - execute_process( - COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) - if(${_exitcode} EQUAL 0) - if(NOT MSVC) - string(STRIP ${_py_exe} Python_EXECUTABLE) - endif() - message(STATUS "Setting Python to ${Python_EXECUTABLE}") - endif() -endif() - if(BUILD_PYTHON) set(PYTHON_COMPONENTS Development.Module) if(USE_NUMPY) @@ -1022,6 +1039,9 @@ if(USE_ROCM) list(APPEND HIP_CXX_FLAGS -std=c++17) list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2) list(APPEND HIP_CXX_FLAGS -DHIP_ENABLE_WARP_SYNC_BUILTINS) + if(HIPBLASLT_OUTER_VEC) + list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_OUTER_VEC) + endif() if(HIPBLASLT_VEC_EXT) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT) endif() @@ -1063,7 +1083,13 @@ if(USE_ROCM) # Math libraries list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS - roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt) + roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt roc::rocsolver) + # hipsparselt is an optional component that will eventually be enabled by default. + if(hipsparselt_FOUND) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS + roc::hipsparselt + ) + endif() # ---[ Kernel asserts # Kernel asserts is disabled for ROCm by default. @@ -1145,7 +1171,7 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE) set(CMAKE_POLICY_VERSION_MINIMUM 3.5) endif() add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe) - # Suppress warning to unblock libnop comiplation by clang-17 + # Suppress warning to unblock libnop compilation by clang-17 # See https://github.com/pytorch/pytorch/issues/151316 target_compile_options_if_supported(tensorpipe -Wno-missing-template-arg-list-after-template-kw) if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") @@ -1197,7 +1223,7 @@ if(USE_GLOO) endif() set(GLOO_USE_CUDA_TOOLKIT ON CACHE BOOL "" FORCE) - # Disable NCCL/RCCL since we don't use Gloo+NCCL, make sure to reenable it! + # Disable NCCL/RCCL since we don't use Gloo+NCCL, make sure to re-enable it! set(USE_NCCL_SAVED ${USE_NCCL}) set(USE_RCCL_SAVED ${USE_RCCL}) set(USE_NCCL OFF) @@ -1208,7 +1234,7 @@ if(USE_GLOO) # Here is a little bit hacky. We have to put PROJECT_BINARY_DIR in front # of PROJECT_SOURCE_DIR with/without conda system. The reason is that - # gloo generates a new config.h in the binary diretory. + # gloo generates a new config.h in the binary directory. include_directories(BEFORE SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) include_directories(BEFORE SYSTEM ${PROJECT_BINARY_DIR}/third_party/gloo) else() @@ -1664,7 +1690,10 @@ if(USE_KINETO) }" EXCEPTIONS_WORK) set(CMAKE_REQUIRED_LINK_OPTIONS "") if(NOT EXCEPTIONS_WORK) - message(FATAL_ERROR "Detected that statically linking against CUPTI causes exceptions to stop working. See https://github.com/pytorch/pytorch/issues/57744 for more details. Perhaps try: USE_CUPTI_SO=1 python setup.py develop --cmake") + message(FATAL_ERROR + "Detected that statically linking against CUPTI causes exceptions to stop working. " + "See https://github.com/pytorch/pytorch/issues/57744 for more details. " + "Perhaps try: USE_CUPTI_SO=1 CMAKE_FRESH=1 python setup.py develop") endif() endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 9c1862f6b4446d..8004b0f400a8d5 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -1,16 +1,3 @@ -macro(get_target_gpus_from_pytorch target_gpus) - set(gfx90a_key MI200) - set(gfx942_key MI300X) - set(gfx1100_key Navi31) - - foreach(X IN LISTS PYTORCH_ROCM_ARCH) - set(key ${X}) - string(APPEND key "_key") - string(APPEND target_gpus ${${key}}) - string(APPEND target_gpus "|") - endforeach() -endmacro() - if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INCLUDED TRUE) @@ -22,22 +9,22 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.9.2b") + set(__AOTRITON_VER "0.10b") set(__AOTRITON_MANYLINUX_LIST - "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 + "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST - "rocm6.2" "rocm6.3" "rocm6.4" + "rocm7.0" ) - set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525") + set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") set(__AOTRITON_SHA256_LIST - "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2 - "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3 - "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4 + "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 + "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 + "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0 ) set(__AOTRITON_Z "gz") @@ -50,17 +37,13 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) - set(target_gpus "") - get_target_gpus_from_pytorch(target_gpus) ExternalProject_Add(aotriton_external GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_TAG ${__AOTRITON_CI_COMMIT} PREFIX ${__AOTRITON_EXTERN_PREFIX} INSTALL_DIR ${__AOTRITON_INSTALL_DIR} - LIST_SEPARATOR | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DTARGET_GPUS:STRING=${target_gpus} - -DAOTRITON_COMPRESS_KERNEL=ON + -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=OFF diff --git a/cmake/Metal.cmake b/cmake/Metal.cmake index cc8e1932f1a159..c9565e2fc0e9e5 100644 --- a/cmake/Metal.cmake +++ b/cmake/Metal.cmake @@ -8,7 +8,7 @@ if(WERROR) endif() function(metal_to_air SRC TARGET FLAGS) - add_custom_command(COMMAND xcrun metal -c ${SRC} -I ${CMAKE_SOURCE_DIR} -o ${TARGET} ${FLAGS} ${METAL_CFLAGS} + add_custom_command(COMMAND xcrun metal -c ${SRC} -I ${CMAKE_SOURCE_DIR} -I ${CMAKE_SOURCE_DIR}/aten/src -o ${TARGET} ${FLAGS} ${METAL_CFLAGS} DEPENDS ${SRC} OUTPUT ${TARGET} COMMENT "Compiling ${SRC} to ${TARGET}" diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index 79fc772eed11a9..903025c5c2cfc5 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -89,7 +89,7 @@ if(NOT CORTEXA9_FOUND) endif(NOT CORTEXA9_FOUND) mark_as_advanced(NEON_FOUND) -#SVE support is availale is only for Linux OS. +#SVE support is available is only for Linux OS. IF(CMAKE_SYSTEM_NAME MATCHES "Linux") # Include necessary modules for checking C and C++ source compilations INCLUDE(CheckCSourceCompiles) diff --git a/cmake/Modules/FindBLAS.cmake b/cmake/Modules/FindBLAS.cmake index 8e54eedb2aa8f5..b4b158fc4965c0 100644 --- a/cmake/Modules/FindBLAS.cmake +++ b/cmake/Modules/FindBLAS.cmake @@ -311,80 +311,8 @@ endif() # Determine if blas was compiled with the f2c conventions IF (BLAS_LIBRARIES) - # Push host architecture when cross-compiling otherwise check would fail - # when cross-compiling for arm64 on x86_64 - cmake_push_check_state(RESET) - if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") - list(APPEND CMAKE_REQUIRED_FLAGS "-arch ${CMAKE_HOST_SYSTEM_PROCESSOR}") - endif() - -# Set values through env variables if cross compiling - IF (CMAKE_CROSSCOMPILING) - IF("$ENV{PYTORCH_BLAS_F2C}" STREQUAL "ON") - SET(BLAS_F2C TRUE) - ELSE() - SET(BLAS_F2C FALSE) - ENDIF() - - IF("$ENV{PYTORCH_BLAS_USE_CBLAS_DOT}" STREQUAL "ON") - SET(BLAS_USE_CBLAS_DOT TRUE) - ELSE() - SET(BLAS_USE_CBLAS_DOT FALSE) - ENDIF() - ELSE () - SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) - CHECK_C_SOURCE_RUNS(" - #include - #include - float x[4] = { 1, 2, 3, 4 }; - float y[4] = { .1, .01, .001, .0001 }; - int four = 4; - int one = 1; - extern double sdot_(); - int main() { - int i; - double r = sdot_(&four, x, &one, y, &one); - exit((float)r != (float).1234); - }" BLAS_F2C_DOUBLE_WORKS ) - CHECK_C_SOURCE_RUNS(" - #include - #include - float x[4] = { 1, 2, 3, 4 }; - float y[4] = { .1, .01, .001, .0001 }; - int four = 4; - int one = 1; - extern float sdot_(); - int main() { - int i; - double r = sdot_(&four, x, &one, y, &one); - exit((float)r != (float).1234); - }" BLAS_F2C_FLOAT_WORKS ) - IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) - MESSAGE(STATUS "This BLAS uses the F2C return conventions") - SET(BLAS_F2C TRUE) - ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) - SET(BLAS_F2C FALSE) - ENDIF(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) - CHECK_C_SOURCE_RUNS(" - #include - #include - float x[4] = { 1, 2, 3, 4 }; - float y[4] = { .1, .01, .001, .0001 }; - extern float cblas_sdot(); - int main() { - int i; - double r = cblas_sdot(4, x, 1, y, 1); - exit((float)r != (float).1234); - }" BLAS_USE_CBLAS_DOT ) - IF (BLAS_USE_CBLAS_DOT) - SET(BLAS_USE_CBLAS_DOT TRUE) - ELSE (BLAS_USE_CBLAS_DOT) - SET(BLAS_USE_CBLAS_DOT FALSE) - ENDIF(BLAS_USE_CBLAS_DOT) - SET(CMAKE_REQUIRED_LIBRARIES) - ENDIF(CMAKE_CROSSCOMPILING) - cmake_pop_check_state() -ENDIF(BLAS_LIBRARIES) + include(cmake/BLAS_ABI.cmake) +endif(BLAS_LIBRARIES) # epilogue diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 71a07d06f0b0f1..00fd0130d83442 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -45,19 +45,13 @@ IF(NOT MKLDNN_FOUND) list(APPEND DNNL_MAKE_COMMAND "--" "-l" "$ENV{MAX_JOBS}") endif() endif() - if(XPU_DEVICE_CXX_FLAGS) - set(DNNL_CXX_FLAGS "-DCMAKE_CXX_FLAGS=${XPU_DEVICE_CXX_FLAGS}") - else() - set(DNNL_CXX_FLAGS "") - endif() ExternalProject_Add(xpu_mkldnn_proj GIT_REPOSITORY https://github.com/oneapi-src/oneDNN - GIT_TAG v3.7.1 + GIT_TAG v3.8.1 PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=${SYCL_CXX_DRIVER} - ${DNNL_CXX_FLAGS} -DDNNL_GPU_RUNTIME=SYCL -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF @@ -91,8 +85,12 @@ IF(NOT MKLDNN_FOUND) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") - MESSAGE("-- Will build oneDNN UKERNEL") - SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + IF(CPU_POWER) + SET(DNNL_EXPERIMENTAL_UKERNEL OFF CACHE BOOL "" FORCE) + ELSE() + MESSAGE("-- Will build oneDNN UKERNEL") + SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + ENDIF() ENDIF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") FIND_PACKAGE(BLAS) diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake index b386517900d5ac..cef8002f817060 100644 --- a/cmake/Modules/FindNCCL.cmake +++ b/cmake/Modules/FindNCCL.cmake @@ -57,7 +57,8 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks include(CheckCXXSymbolExists) check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) - if (NCCL_VERSION_DEFINED) + # this condition check only works for non static NCCL linking + if (NCCL_VERSION_DEFINED AND NOT USE_STATIC_NCCL) set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") file(WRITE ${file} " #include @@ -65,7 +66,6 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks int main() { std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; - int x; ncclGetVersion(&x); return x == NCCL_VERSION_CODE; @@ -80,11 +80,9 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") endif() message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") - else() - message(STATUS "NCCL version < 2.3.5-5") endif () - set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) endif() diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index dce4f4b0313bea..1dac15bb676aff 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -99,12 +99,10 @@ set(PYTORCH_2_5_SYCL_TOOLKIT_VERSION 20249999) # By default, we use libsycl.so on Linux and sycl.lib on Windows as the SYCL library name. if (SYCL_COMPILER_VERSION VERSION_LESS_EQUAL PYTORCH_2_5_SYCL_TOOLKIT_VERSION) - # Don't use if(LINUX) here since this requires cmake>=3.25 and file is installed + # Don't use if(WIN32) here since this requires cmake>=3.25 and file is installed # and used by other projects. # See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - set(sycl_lib_suffix "-preview") - elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") + if(CMAKE_SYSTEM_NAME MATCHES "Windows") # On Windows, the SYCL library is named sycl7.lib until PYTORCH_2_5_SYCL_TOOLKIT_VERSION. # sycl.lib is supported in the later version. set(sycl_lib_suffix "7") diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake index f642072bdc51c8..9daa571955d4a3 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -909,7 +909,7 @@ mark_as_advanced(CUDA_cupti_LIBRARY) # Set the CUDA_LIBRARIES variable. This is the set of stuff to link against if you are # using the CUDA runtime. For the dynamic version of the runtime, most of the -# dependencies are brough in, but for the static version there are additional libraries +# dependencies are brought in, but for the static version there are additional libraries # and linker commands needed. # Initialize to empty set(CUDA_LIBRARIES) @@ -1202,7 +1202,7 @@ function(CUDA_COMPUTE_BUILD_PATH path build_path) # Only deal with CMake style paths from here on out file(TO_CMAKE_PATH "${path}" bpath) if (IS_ABSOLUTE "${bpath}") - # Absolute paths are generally unnessary, especially if something like + # Absolute paths are generally unnecessary, especially if something like # file(GLOB_RECURSE) is used to pick up the files. string(FIND "${bpath}" "${CMAKE_CURRENT_BINARY_DIR}" _binary_dir_pos) @@ -1225,7 +1225,7 @@ function(CUDA_COMPUTE_BUILD_PATH path build_path) # Avoid spaces string(REPLACE " " "_" bpath "${bpath}") - # Strip off the filename. I wait until here to do it, since removin the + # Strip off the filename. I wait until here to do it, since removing the # basename can make a path that looked like path/../basename turn into # path/.. (notice the trailing slash). get_filename_component(bpath "${bpath}" PATH) @@ -1725,7 +1725,7 @@ function(CUDA_LINK_SEPARABLE_COMPILATION_OBJECTS output_file cuda_target options list(APPEND flags -Xcompiler ${f}) endforeach() - # Add our general CUDA_NVCC_FLAGS with the configuration specifig flags + # Add our general CUDA_NVCC_FLAGS with the configuration specific flags set(nvcc_flags ${CUDA_NVCC_FLAGS} ${config_specific_flags} ${nvcc_flags}) file(RELATIVE_PATH output_file_relative_path "${CMAKE_BINARY_DIR}" "${output_file}") diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake index 9293df3aafbdef..59c5c11a1091f3 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake @@ -156,7 +156,7 @@ macro(cuda_execute_process status command) # copy and paste a runnable command line. set(cuda_execute_process_string) foreach(arg ${ARGN}) - # If there are quotes, excape them, so they come through. + # If there are quotes, escape them, so they come through. string(REPLACE "\"" "\\\"" arg ${arg}) # Args with spaces need quotes around them to get them to be parsed as a single argument. if(arg MATCHES " ") diff --git a/cmake/ProtoBuf.cmake b/cmake/ProtoBuf.cmake index 4c436dcd6451d1..fe3113205df127 100644 --- a/cmake/ProtoBuf.cmake +++ b/cmake/ProtoBuf.cmake @@ -33,25 +33,6 @@ macro(custom_protobuf_find) set(__caffe2_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE}) set(CMAKE_POSITION_INDEPENDENT_CODE ON) - if(MSVC) - foreach(flag_var - CMAKE_C_FLAGS CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL - CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL) - if(${flag_var} MATCHES "/Z[iI7]") - string(REGEX REPLACE "/Z[iI7]" "" ${flag_var} "${${flag_var}}") - endif() - endforeach(flag_var) - if(MSVC_Z7_OVERRIDE) - foreach(flag_var - CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELWITHDEBINFO - CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELWITHDEBINFO) - if(${flag_var} MATCHES "/Z[iI]") - string(REGEX REPLACE "/Z[iI]" "/Z7" ${flag_var} "${${flag_var}}") - endif() - endforeach(flag_var) - endif(MSVC_Z7_OVERRIDE) - endif(MSVC) - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") message(WARNING "Ancient protobuf forces CMake compatibility") set(CMAKE_POLICY_VERSION_MINIMUM 3.5) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index c33269c0a28cb7..efb4de47b6df59 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -172,6 +172,7 @@ function(caffe2_print_configuration_summary) if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") endif() + message(STATUS " NVSHMEM_LIB_DIR : ${NVSHMEM_LIB_DIR}") message(STATUS " USE_NNPACK : ${USE_NNPACK}") message(STATUS " USE_NUMPY : ${USE_NUMPY}") message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}") diff --git a/cmake/iOS.cmake b/cmake/iOS.cmake index 9ca91304d5fdb2..ebf30504feab58 100644 --- a/cmake/iOS.cmake +++ b/cmake/iOS.cmake @@ -100,7 +100,7 @@ if(IOS_DEPLOYMENT_TARGET) set(XCODE_IOS_PLATFORM_VERSION_FLAGS "-m${XCODE_IOS_PLATFORM}-version-min=${IOS_DEPLOYMENT_TARGET}") endif() -# Hidden visibilty is required for cxx on iOS +# Hidden visibility is required for cxx on iOS set(CMAKE_C_FLAGS_INIT "${XCODE_IOS_PLATFORM_VERSION_FLAGS}") set(CMAKE_CXX_FLAGS_INIT "${XCODE_IOS_PLATFORM_VERSION_FLAGS} -fvisibility-inlines-hidden") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1080b7bc25251e..cae0ca62f2361e 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -65,8 +65,14 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) macro(find_package_and_print_version PACKAGE_NAME) find_package("${PACKAGE_NAME}" ${ARGN}) - message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") - list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR}) + if(NOT ${PACKAGE_NAME}_FOUND) + message("Optional package ${PACKAGE_NAME} not found") + else() + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + if(${PACKAGE_NAME}_INCLUDE_DIR) + list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR}) + endif() + endif() endmacro() # Find the HIP Package @@ -76,16 +82,32 @@ find_package_and_print_version(HIP 1.0 MODULE) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) - find_package_and_print_version(hip REQUIRED CONFIG) - # Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h - if(UNIX) - find_package_and_print_version(rocm-core REQUIRED CONFIG) - find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h - HINTS ${rocm_core_INCLUDE_DIR}/rocm-core /usr/include) - else() # Win32 - find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h - HINTS ${hip_INCLUDE_DIR}/hip) + + # The rocm-core package was only introduced in ROCm 6.4, so we make it optional. + find_package(rocm-core CONFIG) + + # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow + # falling back to the hip version, which everyone should have. + # rocm_version.h lives in the rocm-core package and hip_version.h lives in the + # hip (lower-case) package. Both are probed above and will be in + # ROCM_INCLUDE_DIRS if available. + find_file(ROCM_VERSION_HEADER_PATH + NAMES rocm-core/rocm_version.h + NO_DEFAULT_PATH + PATHS ${ROCM_INCLUDE_DIRS} + ) + set(ROCM_LIB_NAME "ROCM") + if(NOT ROCM_VERSION_HEADER_PATH) + find_file(ROCM_VERSION_HEADER_PATH + NAMES hip/hip_version.h + NO_DEFAULT_PATH + PATHS ${ROCM_INCLUDE_DIRS} + ) + set(ROCM_LIB_NAME "HIP") + endif() + if(NOT ROCM_VERSION_HEADER_PATH) + message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}") endif() get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME) @@ -96,15 +118,10 @@ if(HIP_FOUND) endif() # Read the ROCM headerfile into a variable - file(READ ${ROCM_HEADER_FILE} ROCM_HEADER_CONTENT) + message(STATUS "Reading ROCM version from: ${ROCM_HEADER_FILE}") + message(STATUS "Content: ${ROCM_HEADER_CONTENT}") + file(READ "${ROCM_HEADER_FILE}" ROCM_HEADER_CONTENT) - # Since Windows currently supports only a part of ROCm and names it HIP-SDK, - # we need to refer to the HIP-SDK equivalents of entities existing in ROCm lib. - if(UNIX) - set(ROCM_LIB_NAME "ROCM") - else() # Win32 - set(ROCM_LIB_NAME "HIP") - endif() # Below we use a RegEx to find ROCM version numbers. # Note that CMake does not support \s for blank space. That is # why in the regular expressions below we have a blank space in @@ -155,6 +172,7 @@ if(HIP_FOUND) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(hipsolver REQUIRED) + find_package_and_print_version(rocsolver REQUIRED) # workaround cmake 4 build issue if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") message(WARNING "Work around hiprtc cmake failure for cmake >= 4") @@ -171,6 +189,9 @@ if(HIP_FOUND) find_package_and_print_version(hsa-runtime64 REQUIRED) endif() + # Optional components. + find_package_and_print_version(hipsparselt) # Will be required when ready. + list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS) if(UNIX) @@ -180,6 +201,21 @@ if(HIP_FOUND) set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0") + # check whether hipblaslt provides HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F + set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_outer_vec.cc") + file(WRITE ${file} "" + "#define LEGACY_HIPBLAS_DIRECT\n" + "#include \n" + "int main() {\n" + " hipblasLtMatmulMatrixScale_t attr = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;\n" + " return 0;\n" + "}\n" + ) + try_compile(hipblaslt_compile_result_outer_vec ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hipblaslt_compile_output_outer_vec) + # check whether hipblaslt provides HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_vec_ext.cc") file(WRITE ${file} "" @@ -193,15 +229,21 @@ if(HIP_FOUND) try_compile(hipblaslt_compile_result_vec_ext ${PROJECT_RANDOM_BINARY_DIR} ${file} CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ - OUTPUT_VARIABLE hipblaslt_compile_output) - if(hipblaslt_compile_result_vec_ext) + OUTPUT_VARIABLE hipblaslt_compile_output_vec_ext) + + if(hipblaslt_compile_result_outer_vec) + set(HIPBLASLT_OUTER_VEC ON) + set(HIPBLASLT_VEC_EXT OFF) + message("hipblaslt is using scale pointer outer vec") + elseif(hipblaslt_compile_result_vec_ext) + set(HIPBLASLT_OUTER_VEC OFF) set(HIPBLASLT_VEC_EXT ON) - #message("hipblaslt is using scale pointer vec ext: ${hipblaslt_compile_output}") message("hipblaslt is using scale pointer vec ext") else() + set(HIPBLASLT_OUTER_VEC OFF) set(HIPBLASLT_VEC_EXT OFF) - message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output}") - #message("hipblaslt is NOT using scale pointer vec ext") + message("hipblaslt is NOT using scale pointer outer vec: ${hipblaslt_compile_output_outer_vec}") + message("hipblaslt is NOT using scale pointer vec ext: ${hipblaslt_compile_output_vec_ext}") endif() endif() endif() diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 95a12605ffcafb..4fc9e814097f0b 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -297,7 +297,7 @@ set_property( TARGET caffe2::nvrtc PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvrtc caffe2::cuda) -# Add onnx namepsace definition to nvcc +# Add onnx namespace definition to nvcc if(ONNX_NAMESPACE) list(APPEND CUDA_NVCC_FLAGS "-DONNX_NAMESPACE=${ONNX_NAMESPACE}") else() @@ -313,7 +313,13 @@ endif() # setting nvcc arch flags torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA) # CMake 3.18 adds integrated support for architecture selection, but we can't rely on it -set(CMAKE_CUDA_ARCHITECTURES OFF) +if(DEFINED CMAKE_CUDA_ARCHITECTURES) + message(WARNING + "pytorch is not compatible with `CMAKE_CUDA_ARCHITECTURES` and will ignore its value. " + "Please configure `TORCH_CUDA_ARCH_LIST` instead.") + set(CMAKE_CUDA_ARCHITECTURES OFF) +endif() + list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}") diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 9cda24d8e62dfd..d56dd74d6c028d 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -324,12 +324,20 @@ endmacro() # macro(torch_cuda_get_nvcc_gencode_flag store_var) # setting nvcc arch flags + # We need to support the explicitly and conveniently defined TORCH_CUDA_ARCH_LIST + if((NOT DEFINED TORCH_CUDA_ARCH_LIST) AND (DEFINED ENV{TORCH_CUDA_ARCH_LIST})) + set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) + endif() if(DEFINED CUDA_ARCH_NAME) message(WARNING "CUDA_ARCH_NAME is no longer used. Use TORCH_CUDA_ARCH_LIST instead. " "Right now, CUDA_ARCH_NAME is ${CUDA_ARCH_NAME} and " "TORCH_CUDA_ARCH_LIST is ${TORCH_CUDA_ARCH_LIST}.") - set(TORCH_CUDA_ARCH_LIST TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME}) + if(NOT TORCH_CUDA_ARCH_LIST) + set(TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME}) + else() + list(APPEND TORCH_CUDA_ARCH_LIST ${CUDA_ARCH_NAME}) + endif() endif() # Invoke cuda_select_nvcc_arch_flags from proper cmake FindCUDA. diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index 4400497f1f1bd7..be083cb93af102 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -6,7 +6,6 @@ if(TARGET torch::xpurt) endif() set(XPU_HOST_CXX_FLAGS) -set(XPU_DEVICE_CXX_FLAGS) # Find SYCL library. find_package(SYCLToolkit REQUIRED) @@ -37,12 +36,6 @@ torch_xpu_get_arch_list(XPU_ARCH_FLAGS) # propagate to torch-xpu-ops set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) -if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND SYCL_COMPILER_VERSION VERSION_LESS_EQUAL PYTORCH_2_5_SYCL_TOOLKIT_VERSION) - # for ABI compatibility on Linux - string(APPEND XPU_HOST_CXX_FLAGS " -D__INTEL_PREVIEW_BREAKING_CHANGES") - string(APPEND XPU_DEVICE_CXX_FLAGS " -fpreview-breaking-changes") -endif() - string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") if(DEFINED ENV{XPU_ENABLE_KINETO}) diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md new file mode 100644 index 00000000000000..c6f2fb10804001 --- /dev/null +++ b/docs/source/accelerator.md @@ -0,0 +1,27 @@ +# torch.accelerator + +```{eval-rst} +.. automodule:: torch.accelerator +``` + +```{eval-rst} +.. currentmodule:: torch.accelerator +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + device_count + is_available + current_accelerator + set_device_index + set_device_idx + current_device_index + current_device_idx + set_stream + current_stream + synchronize + device_index +``` diff --git a/docs/source/accelerator.rst b/docs/source/accelerator.rst deleted file mode 100644 index fd5b95716a47c8..00000000000000 --- a/docs/source/accelerator.rst +++ /dev/null @@ -1,20 +0,0 @@ -torch.accelerator -=================================== -.. automodule:: torch.accelerator -.. currentmodule:: torch.accelerator - -.. autosummary:: - :toctree: generated - :nosignatures: - - device_count - is_available - current_accelerator - set_device_index - set_device_idx - current_device_index - current_device_idx - set_stream - current_stream - synchronize - device_index diff --git a/docs/source/amp.md b/docs/source/amp.md new file mode 100644 index 00000000000000..023f927c6f63c2 --- /dev/null +++ b/docs/source/amp.md @@ -0,0 +1,582 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Automatic Mixed Precision package - torch.amp + +% Both modules below are missing doc entry. Adding them here for now. + +% This does not add anything to the rendered page + +```{eval-rst} +.. py:module:: torch.cpu.amp +``` + +```{eval-rst} +.. py:module:: torch.cuda.amp +``` + +```{eval-rst} +.. automodule:: torch.amp +``` + +```{eval-rst} +.. currentmodule:: torch.amp +``` + +{class}`torch.amp` provides convenience methods for mixed precision, +where some operations use the `torch.float32` (`float`) datatype and other operations +use lower precision floating point datatype (`lower_precision_fp`): `torch.float16` (`half`) or `torch.bfloat16`. Some ops, like linear layers and convolutions, +are much faster in `lower_precision_fp`. Other ops, like reductions, often require the dynamic +range of `float32`. Mixed precision tries to match each op to its appropriate datatype. + +Ordinarily, "automatic mixed precision training" with datatype of `torch.float16` uses {class}`torch.autocast` and +{class}`torch.amp.GradScaler` together, as shown in the {ref}`Automatic Mixed Precision examples` +and [Automatic Mixed Precision recipe](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html). +However, {class}`torch.autocast` and {class}`torch.GradScaler` are modular, and may be used separately if desired. +As shown in the CPU example section of {class}`torch.autocast`, "automatic mixed precision training/inference" on CPU with +datatype of `torch.bfloat16` only uses {class}`torch.autocast`. + +:::{warning} +`torch.cuda.amp.autocast(args...)` and `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast("cuda", args...)` or `torch.amp.autocast("cpu", args...)` instead. +`torch.cuda.amp.GradScaler(args...)` and `torch.cpu.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler("cuda", args...)` or `torch.amp.GradScaler("cpu", args...)` instead. +::: + +{class}`torch.autocast` and {class}`torch.cpu.amp.autocast` are new in version `1.10`. + +```{contents} +:local: true +``` + +(autocasting)= + +## Autocasting + +```{eval-rst} +.. currentmodule:: torch.amp.autocast_mode +``` + +```{eval-rst} +.. autofunction:: is_autocast_available +``` + +```{eval-rst} +.. currentmodule:: torch +``` + +```{eval-rst} +.. autoclass:: autocast + :members: +``` + +```{eval-rst} +.. currentmodule:: torch.amp +``` + +```{eval-rst} +.. autofunction:: custom_fwd +``` + +```{eval-rst} +.. autofunction:: custom_bwd +``` + +```{eval-rst} +.. currentmodule:: torch.cuda.amp +``` + +```{eval-rst} +.. autoclass:: autocast + :members: +``` + +```{eval-rst} +.. autofunction:: custom_fwd +``` + +```{eval-rst} +.. autofunction:: custom_bwd +``` + +```{eval-rst} +.. currentmodule:: torch.cpu.amp +``` + +```{eval-rst} +.. autoclass:: autocast + :members: +``` + +(gradient-scaling)= + +## Gradient Scaling + +If the forward pass for a particular op has `float16` inputs, the backward pass for +that op will produce `float16` gradients. +Gradient values with small magnitudes may not be representable in `float16`. +These values will flush to zero ("underflow"), so the update for the corresponding parameters will be lost. + +To prevent underflow, "gradient scaling" multiplies the network's loss(es) by a scale factor and +invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are +then scaled by the same factor. In other words, gradient values have a larger magnitude, +so they don't flush to zero. + +Each parameter's gradient (`.grad` attribute) should be unscaled before the optimizer +updates the parameters, so the scale factor does not interfere with the learning rate. + +:::{note} +AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in +the fp16 numerical range of max 65504 and will cause gradients to overflow instead of underflow. In +this case, the scale factor may decrease under 1 as an attempt to bring gradients to a number +representable in the fp16 dynamic range. While one may expect the scale to always be above 1, our +GradScaler does NOT make this guarantee to maintain performance. If you encounter NaNs in your loss +or gradients when running with AMP/fp16, verify your model is compatible. +::: + +```{eval-rst} +.. currentmodule:: torch.cuda.amp +``` + +```{eval-rst} +.. autoclass:: GradScaler + :members: +``` + +```{eval-rst} +.. currentmodule:: torch.cpu.amp +``` + +```{eval-rst} +.. autoclass:: GradScaler + :members: +``` + +(autocast-op-reference)= + +## Autocast Op Reference + +(autocast-eligibility)= + +### Op Eligibility + +Ops that run in `float64` or non-floating-point dtypes are not eligible, and will +run in these types whether or not autocast is enabled. + +Only out-of-place ops and Tensor methods are eligible. +In-place variants and calls that explicitly supply an `out=...` Tensor +are allowed in autocast-enabled regions, but won't go through autocasting. +For example, in an autocast-enabled region `a.addmm(b, c)` can autocast, +but `a.addmm_(b, c)` and `a.addmm(b, c, out=d)` cannot. +For best performance and stability, prefer out-of-place ops in autocast-enabled +regions. + +Ops called with an explicit `dtype=...` argument are not eligible, +and will produce output that respects the `dtype` argument. + +(autocast-cuda-op-reference)= + +### CUDA Op-Specific Behavior + +The following lists describe the behavior of eligible ops in autocast-enabled regions. +These ops always go through autocasting whether they are invoked as part of a {class}`torch.nn.Module`, +as a function, or as a {class}`torch.Tensor` method. If functions are exposed in multiple namespaces, +they go through autocasting regardless of the namespace. + +Ops not listed below do not go through autocasting. They run in the type +defined by their inputs. However, autocasting may still change the type +in which unlisted ops run if they're downstream from autocasted ops. + +If an op is unlisted, we assume it's numerically stable in `float16`. +If you believe an unlisted op is numerically unstable in `float16`, +please file an issue. + +#### CUDA Ops that can autocast to `float16` + +`__matmul__`, +`addbmm`, +`addmm`, +`addmv`, +`addr`, +`baddbmm`, +`bmm`, +`chain_matmul`, +`multi_dot`, +`conv1d`, +`conv2d`, +`conv3d`, +`conv_transpose1d`, +`conv_transpose2d`, +`conv_transpose3d`, +`GRUCell`, +`linear`, +`LSTMCell`, +`matmul`, +`mm`, +`mv`, +`prelu`, +`RNNCell` + +#### CUDA Ops that can autocast to `float32` + +`__pow__`, +`__rdiv__`, +`__rpow__`, +`__rtruediv__`, +`acos`, +`asin`, +`binary_cross_entropy_with_logits`, +`cosh`, +`cosine_embedding_loss`, +`cdist`, +`cosine_similarity`, +`cross_entropy`, +`cumprod`, +`cumsum`, +`dist`, +`erfinv`, +`exp`, +`expm1`, +`group_norm`, +`hinge_embedding_loss`, +`kl_div`, +`l1_loss`, +`layer_norm`, +`log`, +`log_softmax`, +`log10`, +`log1p`, +`log2`, +`margin_ranking_loss`, +`mse_loss`, +`multilabel_margin_loss`, +`multi_margin_loss`, +`nll_loss`, +`norm`, +`normalize`, +`pdist`, +`poisson_nll_loss`, +`pow`, +`prod`, +`reciprocal`, +`rsqrt`, +`sinh`, +`smooth_l1_loss`, +`soft_margin_loss`, +`softmax`, +`softmin`, +`softplus`, +`sum`, +`renorm`, +`tan`, +`triplet_margin_loss` + +#### CUDA Ops that promote to the widest input type + +These ops don't require a particular dtype for stability, but take multiple inputs +and require that the inputs' dtypes match. If all of the inputs are +`float16`, the op runs in `float16`. If any of the inputs is `float32`, +autocast casts all inputs to `float32` and runs the op in `float32`. + +`addcdiv`, +`addcmul`, +`atan2`, +`bilinear`, +`cross`, +`dot`, +`grid_sample`, +`index_put`, +`scatter_add`, +`tensordot` + +Some ops not listed here (e.g., binary ops like `add`) natively promote +inputs without autocasting's intervention. If inputs are a mixture of `float16` +and `float32`, these ops run in `float32` and produce `float32` output, +regardless of whether autocast is enabled. + +#### Prefer `binary_cross_entropy_with_logits` over `binary_cross_entropy` + +The backward passes of {func}`torch.nn.functional.binary_cross_entropy` (and {mod}`torch.nn.BCELoss`, which wraps it) +can produce gradients that aren't representable in `float16`. In autocast-enabled regions, the forward input +may be `float16`, which means the backward gradient must be representable in `float16` (autocasting `float16` +forward inputs to `float32` doesn't help, because that cast must be reversed in backward). +Therefore, `binary_cross_entropy` and `BCELoss` raise an error in autocast-enabled regions. + +Many models use a sigmoid layer right before the binary cross entropy layer. +In this case, combine the two layers using {func}`torch.nn.functional.binary_cross_entropy_with_logits` +or {mod}`torch.nn.BCEWithLogitsLoss`. `binary_cross_entropy_with_logits` and `BCEWithLogits` +are safe to autocast. + +(autocast-xpu-op-reference)= + +### XPU Op-Specific Behavior (Experimental) + +The following lists describe the behavior of eligible ops in autocast-enabled regions. +These ops always go through autocasting whether they are invoked as part of a {class}`torch.nn.Module`, +as a function, or as a {class}`torch.Tensor` method. If functions are exposed in multiple namespaces, +they go through autocasting regardless of the namespace. + +Ops not listed below do not go through autocasting. They run in the type +defined by their inputs. However, autocasting may still change the type +in which unlisted ops run if they're downstream from autocasted ops. + +If an op is unlisted, we assume it's numerically stable in `float16`. +If you believe an unlisted op is numerically unstable in `float16`, +please file an issue. + +#### XPU Ops that can autocast to `float16` + +`addbmm`, +`addmm`, +`addmv`, +`addr`, +`baddbmm`, +`bmm`, +`chain_matmul`, +`multi_dot`, +`conv1d`, +`conv2d`, +`conv3d`, +`conv_transpose1d`, +`conv_transpose2d`, +`conv_transpose3d`, +`GRUCell`, +`linear`, +`LSTMCell`, +`matmul`, +`mm`, +`mv`, +`RNNCell` + +#### XPU Ops that can autocast to `float32` + +`__pow__`, +`__rdiv__`, +`__rpow__`, +`__rtruediv__`, +`binary_cross_entropy_with_logits`, +`cosine_embedding_loss`, +`cosine_similarity`, +`cumsum`, +`dist`, +`exp`, +`group_norm`, +`hinge_embedding_loss`, +`kl_div`, +`l1_loss`, +`layer_norm`, +`log`, +`log_softmax`, +`margin_ranking_loss`, +`nll_loss`, +`normalize`, +`poisson_nll_loss`, +`pow`, +`reciprocal`, +`rsqrt`, +`soft_margin_loss`, +`softmax`, +`softmin`, +`sum`, +`triplet_margin_loss` + +#### XPU Ops that promote to the widest input type + +These ops don't require a particular dtype for stability, but take multiple inputs +and require that the inputs' dtypes match. If all of the inputs are +`float16`, the op runs in `float16`. If any of the inputs is `float32`, +autocast casts all inputs to `float32` and runs the op in `float32`. + +`bilinear`, +`cross`, +`grid_sample`, +`index_put`, +`scatter_add`, +`tensordot` + +Some ops not listed here (e.g., binary ops like `add`) natively promote +inputs without autocasting's intervention. If inputs are a mixture of `float16` +and `float32`, these ops run in `float32` and produce `float32` output, +regardless of whether autocast is enabled. + +(autocast-cpu-op-reference)= + +### CPU Op-Specific Behavior + +The following lists describe the behavior of eligible ops in autocast-enabled regions. +These ops always go through autocasting whether they are invoked as part of a {class}`torch.nn.Module`, +as a function, or as a {class}`torch.Tensor` method. If functions are exposed in multiple namespaces, +they go through autocasting regardless of the namespace. + +Ops not listed below do not go through autocasting. They run in the type +defined by their inputs. However, autocasting may still change the type +in which unlisted ops run if they're downstream from autocasted ops. + +If an op is unlisted, we assume it's numerically stable in `bfloat16`. +If you believe an unlisted op is numerically unstable in `bfloat16`, +please file an issue. `float16` shares the lists of `bfloat16`. + +#### CPU Ops that can autocast to `bfloat16` + +`conv1d`, +`conv2d`, +`conv3d`, +`bmm`, +`mm`, +`linalg_vecdot`, +`baddbmm`, +`addmm`, +`addbmm`, +`linear`, +`matmul`, +`_convolution`, +`conv_tbc`, +`mkldnn_rnn_layer`, +`conv_transpose1d`, +`conv_transpose2d`, +`conv_transpose3d`, +`prelu`, +`scaled_dot_product_attention`, +`_native_multi_head_attention` + +#### CPU Ops that can autocast to `float32` + +`avg_pool3d`, +`binary_cross_entropy`, +`grid_sampler`, +`grid_sampler_2d`, +`_grid_sampler_2d_cpu_fallback`, +`grid_sampler_3d`, +`polar`, +`prod`, +`quantile`, +`nanquantile`, +`stft`, +`cdist`, +`trace`, +`view_as_complex`, +`cholesky`, +`cholesky_inverse`, +`cholesky_solve`, +`inverse`, +`lu_solve`, +`orgqr`, +`inverse`, +`ormqr`, +`pinverse`, +`max_pool3d`, +`max_unpool2d`, +`max_unpool3d`, +`adaptive_avg_pool3d`, +`reflection_pad1d`, +`reflection_pad2d`, +`replication_pad1d`, +`replication_pad2d`, +`replication_pad3d`, +`mse_loss`, +`cosine_embedding_loss`, +`nll_loss`, +`nll_loss2d`, +`hinge_embedding_loss`, +`poisson_nll_loss`, +`cross_entropy_loss`, +`l1_loss`, +`huber_loss`, +`margin_ranking_loss`, +`soft_margin_loss`, +`triplet_margin_loss`, +`multi_margin_loss`, +`ctc_loss`, +`kl_div`, +`multilabel_margin_loss`, +`binary_cross_entropy_with_logits`, +`fft_fft`, +`fft_ifft`, +`fft_fft2`, +`fft_ifft2`, +`fft_fftn`, +`fft_ifftn`, +`fft_rfft`, +`fft_irfft`, +`fft_rfft2`, +`fft_irfft2`, +`fft_rfftn`, +`fft_irfftn`, +`fft_hfft`, +`fft_ihfft`, +`linalg_cond`, +`linalg_matrix_rank`, +`linalg_solve`, +`linalg_cholesky`, +`linalg_svdvals`, +`linalg_eigvals`, +`linalg_eigvalsh`, +`linalg_inv`, +`linalg_householder_product`, +`linalg_tensorinv`, +`linalg_tensorsolve`, +`fake_quantize_per_tensor_affine`, +`geqrf`, +`_lu_with_info`, +`qr`, +`svd`, +`triangular_solve`, +`fractional_max_pool2d`, +`fractional_max_pool3d`, +`adaptive_max_pool3d`, +`multilabel_margin_loss_forward`, +`linalg_qr`, +`linalg_cholesky_ex`, +`linalg_svd`, +`linalg_eig`, +`linalg_eigh`, +`linalg_lstsq`, +`linalg_inv_ex` + +#### CPU Ops that promote to the widest input type + +These ops don't require a particular dtype for stability, but take multiple inputs +and require that the inputs' dtypes match. If all of the inputs are +`bfloat16`, the op runs in `bfloat16`. If any of the inputs is `float32`, +autocast casts all inputs to `float32` and runs the op in `float32`. + +`cat`, +`stack`, +`index_copy` + +Some ops not listed here (e.g., binary ops like `add`) natively promote +inputs without autocasting's intervention. If inputs are a mixture of `bfloat16` +and `float32`, these ops run in `float32` and produce `float32` output, +regardless of whether autocast is enabled. + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.amp.autocast_mode +``` + +```{eval-rst} +.. py:module:: torch.cpu.amp.autocast_mode +``` + +```{eval-rst} +.. py:module:: torch.cuda.amp.autocast_mode +``` + +```{eval-rst} +.. py:module:: torch.cuda.amp.common +``` + +```{eval-rst} +.. py:module:: torch.amp.grad_scaler +``` + +```{eval-rst} +.. py:module:: torch.cpu.amp.grad_scaler +``` + +```{eval-rst} +.. py:module:: torch.cuda.amp.grad_scaler +``` diff --git a/docs/source/amp.rst b/docs/source/amp.rst deleted file mode 100644 index 214843c1c0647d..00000000000000 --- a/docs/source/amp.rst +++ /dev/null @@ -1,519 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Automatic Mixed Precision package - torch.amp -============================================= - -.. Both modules below are missing doc entry. Adding them here for now. -.. This does not add anything to the rendered page -.. py:module:: torch.cpu.amp -.. py:module:: torch.cuda.amp - -.. automodule:: torch.amp -.. currentmodule:: torch.amp - -:class:`torch.amp` provides convenience methods for mixed precision, -where some operations use the ``torch.float32`` (``float``) datatype and other operations -use lower precision floating point datatype (``lower_precision_fp``): ``torch.float16`` (``half``) or ``torch.bfloat16``. Some ops, like linear layers and convolutions, -are much faster in ``lower_precision_fp``. Other ops, like reductions, often require the dynamic -range of ``float32``. Mixed precision tries to match each op to its appropriate datatype. - -Ordinarily, "automatic mixed precision training" with datatype of ``torch.float16`` uses :class:`torch.autocast` and -:class:`torch.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples` -and `Automatic Mixed Precision recipe `_. -However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and may be used separately if desired. -As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with -datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`. - -.. warning:: - ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` or ``torch.amp.autocast("cpu", args...)`` instead. - ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` or ``torch.amp.GradScaler("cpu", args...)`` instead. - -:class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`. - -.. contents:: :local: - -.. _autocasting: - -Autocasting -^^^^^^^^^^^ -.. currentmodule:: torch.amp.autocast_mode - -.. autofunction:: is_autocast_available - -.. currentmodule:: torch - -.. autoclass:: autocast - :members: - -.. currentmodule:: torch.amp - -.. autofunction:: custom_fwd - -.. autofunction:: custom_bwd - -.. currentmodule:: torch.cuda.amp - -.. autoclass:: autocast - :members: - -.. autofunction:: custom_fwd - -.. autofunction:: custom_bwd - -.. currentmodule:: torch.cpu.amp - -.. autoclass:: autocast - :members: - -.. _gradient-scaling: - -Gradient Scaling -^^^^^^^^^^^^^^^^ - -If the forward pass for a particular op has ``float16`` inputs, the backward pass for -that op will produce ``float16`` gradients. -Gradient values with small magnitudes may not be representable in ``float16``. -These values will flush to zero ("underflow"), so the update for the corresponding parameters will be lost. - -To prevent underflow, "gradient scaling" multiplies the network's loss(es) by a scale factor and -invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are -then scaled by the same factor. In other words, gradient values have a larger magnitude, -so they don't flush to zero. - -Each parameter's gradient (``.grad`` attribute) should be unscaled before the optimizer -updates the parameters, so the scale factor does not interfere with the learning rate. - -.. note:: - - AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in - the fp16 numerical range of max 65504 and will cause gradients to overflow instead of underflow. In - this case, the scale factor may decrease under 1 as an attempt to bring gradients to a number - representable in the fp16 dynamic range. While one may expect the scale to always be above 1, our - GradScaler does NOT make this guarantee to maintain performance. If you encounter NaNs in your loss - or gradients when running with AMP/fp16, verify your model is compatible. - -.. currentmodule:: torch.cuda.amp - -.. autoclass:: GradScaler - :members: - -.. currentmodule:: torch.cpu.amp - -.. autoclass:: GradScaler - :members: - -.. _autocast-op-reference: - -Autocast Op Reference -^^^^^^^^^^^^^^^^^^^^^ - -.. _autocast-eligibility: - -Op Eligibility --------------- -Ops that run in ``float64`` or non-floating-point dtypes are not eligible, and will -run in these types whether or not autocast is enabled. - -Only out-of-place ops and Tensor methods are eligible. -In-place variants and calls that explicitly supply an ``out=...`` Tensor -are allowed in autocast-enabled regions, but won't go through autocasting. -For example, in an autocast-enabled region ``a.addmm(b, c)`` can autocast, -but ``a.addmm_(b, c)`` and ``a.addmm(b, c, out=d)`` cannot. -For best performance and stability, prefer out-of-place ops in autocast-enabled -regions. - -Ops called with an explicit ``dtype=...`` argument are not eligible, -and will produce output that respects the ``dtype`` argument. - -.. _autocast-cuda-op-reference: - -CUDA Op-Specific Behavior -------------------------- -The following lists describe the behavior of eligible ops in autocast-enabled regions. -These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`, -as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces, -they go through autocasting regardless of the namespace. - -Ops not listed below do not go through autocasting. They run in the type -defined by their inputs. However, autocasting may still change the type -in which unlisted ops run if they're downstream from autocasted ops. - -If an op is unlisted, we assume it's numerically stable in ``float16``. -If you believe an unlisted op is numerically unstable in ``float16``, -please file an issue. - -CUDA Ops that can autocast to ``float16`` -""""""""""""""""""""""""""""""""""""""""" - -``__matmul__``, -``addbmm``, -``addmm``, -``addmv``, -``addr``, -``baddbmm``, -``bmm``, -``chain_matmul``, -``multi_dot``, -``conv1d``, -``conv2d``, -``conv3d``, -``conv_transpose1d``, -``conv_transpose2d``, -``conv_transpose3d``, -``GRUCell``, -``linear``, -``LSTMCell``, -``matmul``, -``mm``, -``mv``, -``prelu``, -``RNNCell`` - -CUDA Ops that can autocast to ``float32`` -""""""""""""""""""""""""""""""""""""""""" - -``__pow__``, -``__rdiv__``, -``__rpow__``, -``__rtruediv__``, -``acos``, -``asin``, -``binary_cross_entropy_with_logits``, -``cosh``, -``cosine_embedding_loss``, -``cdist``, -``cosine_similarity``, -``cross_entropy``, -``cumprod``, -``cumsum``, -``dist``, -``erfinv``, -``exp``, -``expm1``, -``group_norm``, -``hinge_embedding_loss``, -``kl_div``, -``l1_loss``, -``layer_norm``, -``log``, -``log_softmax``, -``log10``, -``log1p``, -``log2``, -``margin_ranking_loss``, -``mse_loss``, -``multilabel_margin_loss``, -``multi_margin_loss``, -``nll_loss``, -``norm``, -``normalize``, -``pdist``, -``poisson_nll_loss``, -``pow``, -``prod``, -``reciprocal``, -``rsqrt``, -``sinh``, -``smooth_l1_loss``, -``soft_margin_loss``, -``softmax``, -``softmin``, -``softplus``, -``sum``, -``renorm``, -``tan``, -``triplet_margin_loss`` - -CUDA Ops that promote to the widest input type -"""""""""""""""""""""""""""""""""""""""""""""" -These ops don't require a particular dtype for stability, but take multiple inputs -and require that the inputs' dtypes match. If all of the inputs are -``float16``, the op runs in ``float16``. If any of the inputs is ``float32``, -autocast casts all inputs to ``float32`` and runs the op in ``float32``. - -``addcdiv``, -``addcmul``, -``atan2``, -``bilinear``, -``cross``, -``dot``, -``grid_sample``, -``index_put``, -``scatter_add``, -``tensordot`` - -Some ops not listed here (e.g., binary ops like ``add``) natively promote -inputs without autocasting's intervention. If inputs are a mixture of ``float16`` -and ``float32``, these ops run in ``float32`` and produce ``float32`` output, -regardless of whether autocast is enabled. - -Prefer ``binary_cross_entropy_with_logits`` over ``binary_cross_entropy`` -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" -The backward passes of :func:`torch.nn.functional.binary_cross_entropy` (and :mod:`torch.nn.BCELoss`, which wraps it) -can produce gradients that aren't representable in ``float16``. In autocast-enabled regions, the forward input -may be ``float16``, which means the backward gradient must be representable in ``float16`` (autocasting ``float16`` -forward inputs to ``float32`` doesn't help, because that cast must be reversed in backward). -Therefore, ``binary_cross_entropy`` and ``BCELoss`` raise an error in autocast-enabled regions. - -Many models use a sigmoid layer right before the binary cross entropy layer. -In this case, combine the two layers using :func:`torch.nn.functional.binary_cross_entropy_with_logits` -or :mod:`torch.nn.BCEWithLogitsLoss`. ``binary_cross_entropy_with_logits`` and ``BCEWithLogits`` -are safe to autocast. - -.. _autocast-xpu-op-reference: - -XPU Op-Specific Behavior (Experimental) ---------------------------------------- -The following lists describe the behavior of eligible ops in autocast-enabled regions. -These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`, -as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces, -they go through autocasting regardless of the namespace. - -Ops not listed below do not go through autocasting. They run in the type -defined by their inputs. However, autocasting may still change the type -in which unlisted ops run if they're downstream from autocasted ops. - -If an op is unlisted, we assume it's numerically stable in ``float16``. -If you believe an unlisted op is numerically unstable in ``float16``, -please file an issue. - -XPU Ops that can autocast to ``float16`` -"""""""""""""""""""""""""""""""""""""""" - -``addbmm``, -``addmm``, -``addmv``, -``addr``, -``baddbmm``, -``bmm``, -``chain_matmul``, -``multi_dot``, -``conv1d``, -``conv2d``, -``conv3d``, -``conv_transpose1d``, -``conv_transpose2d``, -``conv_transpose3d``, -``GRUCell``, -``linear``, -``LSTMCell``, -``matmul``, -``mm``, -``mv``, -``RNNCell`` - -XPU Ops that can autocast to ``float32`` -"""""""""""""""""""""""""""""""""""""""" - -``__pow__``, -``__rdiv__``, -``__rpow__``, -``__rtruediv__``, -``binary_cross_entropy_with_logits``, -``cosine_embedding_loss``, -``cosine_similarity``, -``cumsum``, -``dist``, -``exp``, -``group_norm``, -``hinge_embedding_loss``, -``kl_div``, -``l1_loss``, -``layer_norm``, -``log``, -``log_softmax``, -``margin_ranking_loss``, -``nll_loss``, -``normalize``, -``poisson_nll_loss``, -``pow``, -``reciprocal``, -``rsqrt``, -``soft_margin_loss``, -``softmax``, -``softmin``, -``sum``, -``triplet_margin_loss`` - -XPU Ops that promote to the widest input type -""""""""""""""""""""""""""""""""""""""""""""" -These ops don't require a particular dtype for stability, but take multiple inputs -and require that the inputs' dtypes match. If all of the inputs are -``float16``, the op runs in ``float16``. If any of the inputs is ``float32``, -autocast casts all inputs to ``float32`` and runs the op in ``float32``. - -``bilinear``, -``cross``, -``grid_sample``, -``index_put``, -``scatter_add``, -``tensordot`` - -Some ops not listed here (e.g., binary ops like ``add``) natively promote -inputs without autocasting's intervention. If inputs are a mixture of ``float16`` -and ``float32``, these ops run in ``float32`` and produce ``float32`` output, -regardless of whether autocast is enabled. - -.. _autocast-cpu-op-reference: - -CPU Op-Specific Behavior ------------------------- -The following lists describe the behavior of eligible ops in autocast-enabled regions. -These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`, -as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces, -they go through autocasting regardless of the namespace. - -Ops not listed below do not go through autocasting. They run in the type -defined by their inputs. However, autocasting may still change the type -in which unlisted ops run if they're downstream from autocasted ops. - -If an op is unlisted, we assume it's numerically stable in ``bfloat16``. -If you believe an unlisted op is numerically unstable in ``bfloat16``, -please file an issue. ``float16`` shares the lists of ``bfloat16``. - -CPU Ops that can autocast to ``bfloat16`` -""""""""""""""""""""""""""""""""""""""""" - -``conv1d``, -``conv2d``, -``conv3d``, -``bmm``, -``mm``, -``linalg_vecdot``, -``baddbmm``, -``addmm``, -``addbmm``, -``linear``, -``matmul``, -``_convolution``, -``conv_tbc``, -``mkldnn_rnn_layer``, -``conv_transpose1d``, -``conv_transpose2d``, -``conv_transpose3d``, -``prelu``, -``scaled_dot_product_attention``, -``_native_multi_head_attention`` - -CPU Ops that can autocast to ``float32`` -"""""""""""""""""""""""""""""""""""""""" - -``avg_pool3d``, -``binary_cross_entropy``, -``grid_sampler``, -``grid_sampler_2d``, -``_grid_sampler_2d_cpu_fallback``, -``grid_sampler_3d``, -``polar``, -``prod``, -``quantile``, -``nanquantile``, -``stft``, -``cdist``, -``trace``, -``view_as_complex``, -``cholesky``, -``cholesky_inverse``, -``cholesky_solve``, -``inverse``, -``lu_solve``, -``orgqr``, -``inverse``, -``ormqr``, -``pinverse``, -``max_pool3d``, -``max_unpool2d``, -``max_unpool3d``, -``adaptive_avg_pool3d``, -``reflection_pad1d``, -``reflection_pad2d``, -``replication_pad1d``, -``replication_pad2d``, -``replication_pad3d``, -``mse_loss``, -``cosine_embedding_loss``, -``nll_loss``, -``nll_loss2d``, -``hinge_embedding_loss``, -``poisson_nll_loss``, -``cross_entropy_loss``, -``l1_loss``, -``huber_loss``, -``margin_ranking_loss``, -``soft_margin_loss``, -``triplet_margin_loss``, -``multi_margin_loss``, -``ctc_loss``, -``kl_div``, -``multilabel_margin_loss``, -``binary_cross_entropy_with_logits``, -``fft_fft``, -``fft_ifft``, -``fft_fft2``, -``fft_ifft2``, -``fft_fftn``, -``fft_ifftn``, -``fft_rfft``, -``fft_irfft``, -``fft_rfft2``, -``fft_irfft2``, -``fft_rfftn``, -``fft_irfftn``, -``fft_hfft``, -``fft_ihfft``, -``linalg_cond``, -``linalg_matrix_rank``, -``linalg_solve``, -``linalg_cholesky``, -``linalg_svdvals``, -``linalg_eigvals``, -``linalg_eigvalsh``, -``linalg_inv``, -``linalg_householder_product``, -``linalg_tensorinv``, -``linalg_tensorsolve``, -``fake_quantize_per_tensor_affine``, -``geqrf``, -``_lu_with_info``, -``qr``, -``svd``, -``triangular_solve``, -``fractional_max_pool2d``, -``fractional_max_pool3d``, -``adaptive_max_pool3d``, -``multilabel_margin_loss_forward``, -``linalg_qr``, -``linalg_cholesky_ex``, -``linalg_svd``, -``linalg_eig``, -``linalg_eigh``, -``linalg_lstsq``, -``linalg_inv_ex`` - -CPU Ops that promote to the widest input type -""""""""""""""""""""""""""""""""""""""""""""" -These ops don't require a particular dtype for stability, but take multiple inputs -and require that the inputs' dtypes match. If all of the inputs are -``bfloat16``, the op runs in ``bfloat16``. If any of the inputs is ``float32``, -autocast casts all inputs to ``float32`` and runs the op in ``float32``. - -``cat``, -``stack``, -``index_copy`` - -Some ops not listed here (e.g., binary ops like ``add``) natively promote -inputs without autocasting's intervention. If inputs are a mixture of ``bfloat16`` -and ``float32``, these ops run in ``float32`` and produce ``float32`` output, -regardless of whether autocast is enabled. - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.amp.autocast_mode -.. py:module:: torch.cpu.amp.autocast_mode -.. py:module:: torch.cuda.amp.autocast_mode -.. py:module:: torch.cuda.amp.common -.. py:module:: torch.amp.grad_scaler -.. py:module:: torch.cpu.amp.grad_scaler -.. py:module:: torch.cuda.amp.grad_scaler diff --git a/docs/source/autograd.md b/docs/source/autograd.md new file mode 100644 index 00000000000000..4218eac05d79d3 --- /dev/null +++ b/docs/source/autograd.md @@ -0,0 +1,472 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Automatic differentiation package - torch.autograd + +```{eval-rst} +.. automodule:: torch.autograd +``` + +```{eval-rst} +.. currentmodule:: torch.autograd +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + backward + grad +``` + +(forward-mode-ad)= + +## Forward-mode Automatic Differentiation + +:::{warning} +This API is in beta. Even though the function signatures are very unlikely to change, improved +operator coverage is planned before we consider this stable. +::: + +Please see the [forward-mode AD tutorial](https://pytorch.org/tutorials/intermediate/forward_ad_usage.html) +for detailed steps on how to use this API. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + forward_ad.dual_level + forward_ad.make_dual + forward_ad.unpack_dual + forward_ad.enter_dual_level + forward_ad.exit_dual_level + forward_ad.UnpackedDualTensor +``` + +(functional-api)= + +## Functional higher level API + +:::{warning} +This API is in beta. Even though the function signatures are very unlikely to change, major +improvements to performances are planned before we consider this stable. +::: + +This section contains the higher level API for the autograd that builds on the basic API above +and allows you to compute jacobians, hessians, etc. + +This API works with user-provided functions that take only Tensors as input and return +only Tensors. +If your function takes other arguments that are not Tensors or Tensors that don't have requires_grad set, +you can use a lambda to capture them. +For example, for a function `f` that takes three inputs, a Tensor for which we want the jacobian, another +tensor that should be considered constant and a boolean flag as `f(input, constant, flag=flag)` +you can use it as `functional.jacobian(lambda x: f(x, constant, flag=flag), input)`. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + functional.jacobian + functional.hessian + functional.vjp + functional.jvp + functional.vhp + functional.hvp +``` + +(locally-disable-grad)= + +## Locally disabling gradient computation + +See {ref}`locally-disable-grad-doc` for more information on the differences +between no-grad and inference mode as well as other related mechanisms that +may be confused with the two. Also see {ref}`torch-rst-local-disable-grad` +for a list of functions that can be used to locally disable gradients. + +(default-grad-layouts)= + +## Default gradient layouts + +When a non-sparse `param` receives a non-sparse gradient during +{func}`torch.autograd.backward` or {func}`torch.Tensor.backward` +`param.grad` is accumulated as follows. + +If `param.grad` is initially `None`: + +1. If `param`'s memory is non-overlapping and dense, `.grad` is + created with strides matching `param` (thus matching `param`'s + layout). +2. Otherwise, `.grad` is created with rowmajor-contiguous strides. + +If `param` already has a non-sparse `.grad` attribute: + +3. If `create_graph=False`, `backward()` accumulates into `.grad` + in-place, which preserves its strides. +4. If `create_graph=True`, `backward()` replaces `.grad` with a + new tensor `.grad + new grad`, which attempts (but does not guarantee) + matching the preexisting `.grad`'s strides. + +The default behavior (letting `.grad`s be `None` before the first +`backward()`, such that their layout is created according to 1 or 2, +and retained over time according to 3 or 4) is recommended for best performance. +Calls to `model.zero_grad()` or `optimizer.zero_grad()` will not affect `.grad` +layouts. + +In fact, resetting all `.grad`s to `None` before each +accumulation phase, e.g.: + +``` +for iterations... + ... + for param in model.parameters(): + param.grad = None + loss.backward() +``` + +such that they're recreated according to 1 or 2 every time, +is a valid alternative to `model.zero_grad()` or `optimizer.zero_grad()` +that may improve performance for some networks. + +### Manual gradient layouts + +If you need manual control over `.grad`'s strides, +assign `param.grad =` a zeroed tensor with desired strides +before the first `backward()`, and never reset it to `None`. +3 guarantees your layout is preserved as long as `create_graph=False`. +4 indicates your layout is *likely* preserved even if `create_graph=True`. + +## In-place operations on Tensors + +Supporting in-place operations in autograd is a hard matter, and we discourage +their use in most cases. Autograd's aggressive buffer freeing and reuse makes +it very efficient and there are very few occasions when in-place operations +actually lower memory usage by any significant amount. Unless you're operating +under heavy memory pressure, you might never need to use them. + +### In-place correctness checks + +All {class}`Tensor` s keep track of in-place operations applied to them, and +if the implementation detects that a tensor was saved for backward in one of +the functions, but it was modified in-place afterwards, an error will be raised +once backward pass is started. This ensures that if you're using in-place +functions and not seeing any errors, you can be sure that the computed +gradients are correct. + +## Variable (deprecated) + +:::{warning} +The Variable API has been deprecated: Variables are no longer necessary to +use autograd with tensors. Autograd automatically supports Tensors with +`requires_grad` set to `True`. Below please find a quick guide on what +has changed: + +- `Variable(tensor)` and `Variable(tensor, requires_grad)` still work as expected, + but they return Tensors instead of Variables. +- `var.data` is the same thing as `tensor.data`. +- Methods such as `var.backward(), var.detach(), var.register_hook()` now work on tensors + with the same method names. + +In addition, one can now create tensors with `requires_grad=True` using factory +methods such as {func}`torch.randn`, {func}`torch.zeros`, {func}`torch.ones`, and others +like the following: + +`autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)` +::: + +## Tensor autograd functions + +```{eval-rst} +.. autosummary:: + :nosignatures: + + torch.Tensor.grad + torch.Tensor.requires_grad + torch.Tensor.is_leaf + torch.Tensor.backward + torch.Tensor.detach + torch.Tensor.detach_ + torch.Tensor.register_hook + torch.Tensor.register_post_accumulate_grad_hook + torch.Tensor.retain_grad +``` + +## {hidden}`Function` + +```{eval-rst} +.. autoclass:: Function +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Function.forward + Function.backward + Function.jvp + Function.vmap +``` + +(context-method-mixins)= + +## Context method mixins + +When creating a new {class}`Function`, the following methods are available to `ctx`. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + function.FunctionCtx.mark_dirty + function.FunctionCtx.mark_non_differentiable + function.FunctionCtx.save_for_backward + function.FunctionCtx.set_materialize_grads +``` + +## Custom Function utilities + +Decorator for backward method. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + function.once_differentiable +``` + +Base custom {class}`Function` used to build PyTorch utilities + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + function.BackwardCFunction + function.InplaceFunction + function.NestedIOFunction + +``` + +(grad-check)= + +## Numerical gradient checking + +```{eval-rst} +.. automodule:: torch.autograd.gradcheck +``` + +```{eval-rst} +.. currentmodule:: torch.autograd.gradcheck +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + gradcheck + gradgradcheck + GradcheckError +``` + +% Just to reset the base path for the rest of this file + +```{eval-rst} +.. currentmodule:: torch.autograd +``` + +## Profiler + +Autograd includes a profiler that lets you inspect the cost of different +operators inside your model - both on the CPU and GPU. There are three modes +implemented at the moment - CPU-only using {class}`~torch.autograd.profiler.profile`. +nvprof based (registers both CPU and GPU activity) using +{class}`~torch.autograd.profiler.emit_nvtx`. +and vtune profiler based using +{class}`~torch.autograd.profiler.emit_itt`. + +```{eval-rst} +.. autoclass:: torch.autograd.profiler.profile +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + profiler.profile.export_chrome_trace + profiler.profile.key_averages + profiler.profile.self_cpu_time_total + profiler.profile.total_average + profiler.parse_nvprof_trace + profiler.EnforceUnique + profiler.KinetoStepTracker + profiler.record_function + profiler_util.Interval + profiler_util.Kernel + profiler_util.MemRecordsAcc + profiler_util.StringTable +``` + +```{eval-rst} +.. autoclass:: torch.autograd.profiler.emit_nvtx +``` + +```{eval-rst} +.. autoclass:: torch.autograd.profiler.emit_itt + +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + profiler.load_nvprof +``` + +## Debugging and anomaly detection + +```{eval-rst} +.. autoclass:: detect_anomaly +``` + +```{eval-rst} +.. autoclass:: set_detect_anomaly +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + grad_mode.set_multithreading_enabled + + +``` + +## Autograd graph + +Autograd exposes methods that allow one to inspect the graph and interpose behavior during +the backward pass. + +The `grad_fn` attribute of a {class}`torch.Tensor` holds a {class}`torch.autograd.graph.Node` +if the tensor is the output of a operation that was recorded by autograd (i.e., grad_mode is +enabled and at least one of the inputs required gradients), or `None` otherwise. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + graph.Node.name + graph.Node.metadata + graph.Node.next_functions + graph.Node.register_hook + graph.Node.register_prehook + graph.increment_version +``` + +Some operations need intermediary results to be saved during the forward pass +in order to execute the backward pass. +These intermediary results are saved as attributes on the `grad_fn` and can be accessed. +For example: + +``` +>>> a = torch.tensor([0., 0., 0.], requires_grad=True) +>>> b = a.exp() +>>> print(isinstance(b.grad_fn, torch.autograd.graph.Node)) +True +>>> print(dir(b.grad_fn)) +['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_result', '_register_hook_dict', '_saved_result', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad'] +>>> print(torch.allclose(b.grad_fn._saved_result, b)) +True +``` + +You can also define how these saved tensors should be packed / unpacked using hooks. +A common application is to trade compute for memory by saving those intermediary results +to disk or to CPU instead of leaving them on the GPU. This is especially useful if you +notice your model fits on GPU during evaluation, but not training. +Also see {ref}`saved-tensors-hooks-doc`. + +```{eval-rst} +.. autoclass:: torch.autograd.graph.saved_tensors_hooks +``` + +```{eval-rst} +.. autoclass:: torch.autograd.graph.save_on_cpu +``` + +```{eval-rst} +.. autoclass:: torch.autograd.graph.disable_saved_tensors_hooks +``` + +```{eval-rst} +.. autoclass:: torch.autograd.graph.register_multi_grad_hook +``` + +```{eval-rst} +.. autoclass:: torch.autograd.graph.allow_mutation_on_saved_tensors +``` + +```{eval-rst} +.. autoclass:: torch.autograd.graph.GradientEdge +``` + +```{eval-rst} +.. autofunction:: torch.autograd.graph.get_gradient_edge + + +``` + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.autograd.anomaly_mode +``` + +```{eval-rst} +.. py:module:: torch.autograd.forward_ad +``` + +```{eval-rst} +.. py:module:: torch.autograd.function +``` + +```{eval-rst} +.. py:module:: torch.autograd.functional +``` + +```{eval-rst} +.. py:module:: torch.autograd.grad_mode +``` + +```{eval-rst} +.. py:module:: torch.autograd.graph +``` + +```{eval-rst} +.. py:module:: torch.autograd.profiler +``` + +```{eval-rst} +.. py:module:: torch.autograd.profiler_legacy +``` + +```{eval-rst} +.. py:module:: torch.autograd.profiler_util +``` + +```{eval-rst} +.. py:module:: torch.autograd.variable +``` diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst deleted file mode 100644 index 195e96cd390168..00000000000000 --- a/docs/source/autograd.rst +++ /dev/null @@ -1,380 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Automatic differentiation package - torch.autograd -================================================== - -.. automodule:: torch.autograd -.. currentmodule:: torch.autograd - -.. autosummary:: - :toctree: generated - :nosignatures: - - backward - grad - -.. _forward-mode-ad: - -Forward-mode Automatic Differentiation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. warning:: - This API is in beta. Even though the function signatures are very unlikely to change, improved - operator coverage is planned before we consider this stable. - -Please see the `forward-mode AD tutorial `__ -for detailed steps on how to use this API. - -.. autosummary:: - :toctree: generated - :nosignatures: - - forward_ad.dual_level - forward_ad.make_dual - forward_ad.unpack_dual - forward_ad.enter_dual_level - forward_ad.exit_dual_level - forward_ad.UnpackedDualTensor - -.. _functional-api: - -Functional higher level API -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. warning:: - This API is in beta. Even though the function signatures are very unlikely to change, major - improvements to performances are planned before we consider this stable. - -This section contains the higher level API for the autograd that builds on the basic API above -and allows you to compute jacobians, hessians, etc. - -This API works with user-provided functions that take only Tensors as input and return -only Tensors. -If your function takes other arguments that are not Tensors or Tensors that don't have requires_grad set, -you can use a lambda to capture them. -For example, for a function ``f`` that takes three inputs, a Tensor for which we want the jacobian, another -tensor that should be considered constant and a boolean flag as ``f(input, constant, flag=flag)`` -you can use it as ``functional.jacobian(lambda x: f(x, constant, flag=flag), input)``. - -.. autosummary:: - :toctree: generated - :nosignatures: - - functional.jacobian - functional.hessian - functional.vjp - functional.jvp - functional.vhp - functional.hvp - -.. _locally-disable-grad: - -Locally disabling gradient computation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -See :ref:`locally-disable-grad-doc` for more information on the differences -between no-grad and inference mode as well as other related mechanisms that -may be confused with the two. Also see :ref:`torch-rst-local-disable-grad` -for a list of functions that can be used to locally disable gradients. - -.. _default-grad-layouts: - -Default gradient layouts -^^^^^^^^^^^^^^^^^^^^^^^^ - -When a non-sparse ``param`` receives a non-sparse gradient during -:func:`torch.autograd.backward` or :func:`torch.Tensor.backward` -``param.grad`` is accumulated as follows. - -If ``param.grad`` is initially ``None``: - -1. If ``param``'s memory is non-overlapping and dense, ``.grad`` is - created with strides matching ``param`` (thus matching ``param``'s - layout). -2. Otherwise, ``.grad`` is created with rowmajor-contiguous strides. - -If ``param`` already has a non-sparse ``.grad`` attribute: - -3. If ``create_graph=False``, ``backward()`` accumulates into ``.grad`` - in-place, which preserves its strides. -4. If ``create_graph=True``, ``backward()`` replaces ``.grad`` with a - new tensor ``.grad + new grad``, which attempts (but does not guarantee) - matching the preexisting ``.grad``'s strides. - -The default behavior (letting ``.grad``\ s be ``None`` before the first -``backward()``, such that their layout is created according to 1 or 2, -and retained over time according to 3 or 4) is recommended for best performance. -Calls to ``model.zero_grad()`` or ``optimizer.zero_grad()`` will not affect ``.grad`` -layouts. - -In fact, resetting all ``.grad``\ s to ``None`` before each -accumulation phase, e.g.:: - - for iterations... - ... - for param in model.parameters(): - param.grad = None - loss.backward() - -such that they're recreated according to 1 or 2 every time, -is a valid alternative to ``model.zero_grad()`` or ``optimizer.zero_grad()`` -that may improve performance for some networks. - -Manual gradient layouts ------------------------ - -If you need manual control over ``.grad``'s strides, -assign ``param.grad =`` a zeroed tensor with desired strides -before the first ``backward()``, and never reset it to ``None``. -3 guarantees your layout is preserved as long as ``create_graph=False``. -4 indicates your layout is *likely* preserved even if ``create_graph=True``. - -In-place operations on Tensors -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Supporting in-place operations in autograd is a hard matter, and we discourage -their use in most cases. Autograd's aggressive buffer freeing and reuse makes -it very efficient and there are very few occasions when in-place operations -actually lower memory usage by any significant amount. Unless you're operating -under heavy memory pressure, you might never need to use them. - -In-place correctness checks ---------------------------- - -All :class:`Tensor` s keep track of in-place operations applied to them, and -if the implementation detects that a tensor was saved for backward in one of -the functions, but it was modified in-place afterwards, an error will be raised -once backward pass is started. This ensures that if you're using in-place -functions and not seeing any errors, you can be sure that the computed -gradients are correct. - -Variable (deprecated) -^^^^^^^^^^^^^^^^^^^^^ - -.. warning:: - The Variable API has been deprecated: Variables are no longer necessary to - use autograd with tensors. Autograd automatically supports Tensors with - ``requires_grad`` set to ``True``. Below please find a quick guide on what - has changed: - - - ``Variable(tensor)`` and ``Variable(tensor, requires_grad)`` still work as expected, - but they return Tensors instead of Variables. - - ``var.data`` is the same thing as ``tensor.data``. - - Methods such as ``var.backward(), var.detach(), var.register_hook()`` now work on tensors - with the same method names. - - In addition, one can now create tensors with ``requires_grad=True`` using factory - methods such as :func:`torch.randn`, :func:`torch.zeros`, :func:`torch.ones`, and others - like the following: - - ``autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)`` - -Tensor autograd functions -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autosummary:: - :nosignatures: - - torch.Tensor.grad - torch.Tensor.requires_grad - torch.Tensor.is_leaf - torch.Tensor.backward - torch.Tensor.detach - torch.Tensor.detach_ - torch.Tensor.register_hook - torch.Tensor.register_post_accumulate_grad_hook - torch.Tensor.retain_grad - -:hidden:`Function` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: Function - -.. autosummary:: - :toctree: generated - :nosignatures: - - Function.forward - Function.backward - Function.jvp - Function.vmap - -.. _context_method_mixins: - -Context method mixins -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When creating a new :class:`Function`, the following methods are available to `ctx`. - -.. autosummary:: - :toctree: generated - :nosignatures: - - function.FunctionCtx.mark_dirty - function.FunctionCtx.mark_non_differentiable - function.FunctionCtx.save_for_backward - function.FunctionCtx.set_materialize_grads - -Custom Function utilities -^^^^^^^^^^^^^^^^^^^^^^^^^ -Decorator for backward method. - -.. autosummary:: - :toctree: generated - :nosignatures: - - function.once_differentiable - -Base custom :class:`Function` used to build PyTorch utilities - -.. autosummary:: - :toctree: generated - :nosignatures: - - function.BackwardCFunction - function.InplaceFunction - function.NestedIOFunction - - -.. _grad-check: - -Numerical gradient checking -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - -.. automodule:: torch.autograd.gradcheck -.. currentmodule:: torch.autograd.gradcheck - -.. autosummary:: - :toctree: generated - :nosignatures: - - gradcheck - gradgradcheck - GradcheckError - -.. Just to reset the base path for the rest of this file -.. currentmodule:: torch.autograd - -Profiler -^^^^^^^^ - -Autograd includes a profiler that lets you inspect the cost of different -operators inside your model - both on the CPU and GPU. There are three modes -implemented at the moment - CPU-only using :class:`~torch.autograd.profiler.profile`. -nvprof based (registers both CPU and GPU activity) using -:class:`~torch.autograd.profiler.emit_nvtx`. -and vtune profiler based using -:class:`~torch.autograd.profiler.emit_itt`. - -.. autoclass:: torch.autograd.profiler.profile - -.. autosummary:: - :toctree: generated - :nosignatures: - - profiler.profile.export_chrome_trace - profiler.profile.key_averages - profiler.profile.self_cpu_time_total - profiler.profile.total_average - profiler.parse_nvprof_trace - profiler.EnforceUnique - profiler.KinetoStepTracker - profiler.record_function - profiler_util.Interval - profiler_util.Kernel - profiler_util.MemRecordsAcc - profiler_util.StringTable - -.. autoclass:: torch.autograd.profiler.emit_nvtx -.. autoclass:: torch.autograd.profiler.emit_itt - - -.. autosummary:: - :toctree: generated - :nosignatures: - - profiler.load_nvprof - -Debugging and anomaly detection -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: detect_anomaly - -.. autoclass:: set_detect_anomaly - -.. autosummary:: - :toctree: generated - :nosignatures: - - grad_mode.set_multithreading_enabled - - - -Autograd graph -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Autograd exposes methods that allow one to inspect the graph and interpose behavior during -the backward pass. - -The ``grad_fn`` attribute of a :class:`torch.Tensor` holds a :class:`torch.autograd.graph.Node` -if the tensor is the output of a operation that was recorded by autograd (i.e., grad_mode is -enabled and at least one of the inputs required gradients), or ``None`` otherwise. - -.. autosummary:: - :toctree: generated - :nosignatures: - - graph.Node.name - graph.Node.metadata - graph.Node.next_functions - graph.Node.register_hook - graph.Node.register_prehook - graph.increment_version - -Some operations need intermediary results to be saved during the forward pass -in order to execute the backward pass. -These intermediary results are saved as attributes on the ``grad_fn`` and can be accessed. -For example:: - - >>> a = torch.tensor([0., 0., 0.], requires_grad=True) - >>> b = a.exp() - >>> print(isinstance(b.grad_fn, torch.autograd.graph.Node)) - True - >>> print(dir(b.grad_fn)) - ['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_result', '_register_hook_dict', '_saved_result', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad'] - >>> print(torch.allclose(b.grad_fn._saved_result, b)) - True - -You can also define how these saved tensors should be packed / unpacked using hooks. -A common application is to trade compute for memory by saving those intermediary results -to disk or to CPU instead of leaving them on the GPU. This is especially useful if you -notice your model fits on GPU during evaluation, but not training. -Also see :ref:`saved-tensors-hooks-doc`. - -.. autoclass:: torch.autograd.graph.saved_tensors_hooks - -.. autoclass:: torch.autograd.graph.save_on_cpu - -.. autoclass:: torch.autograd.graph.disable_saved_tensors_hooks - -.. autoclass:: torch.autograd.graph.register_multi_grad_hook - -.. autoclass:: torch.autograd.graph.allow_mutation_on_saved_tensors - -.. autoclass:: torch.autograd.graph.GradientEdge - -.. autofunction:: torch.autograd.graph.get_gradient_edge - - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.autograd.anomaly_mode -.. py:module:: torch.autograd.forward_ad -.. py:module:: torch.autograd.function -.. py:module:: torch.autograd.functional -.. py:module:: torch.autograd.grad_mode -.. py:module:: torch.autograd.graph -.. py:module:: torch.autograd.profiler -.. py:module:: torch.autograd.profiler_legacy -.. py:module:: torch.autograd.profiler_util -.. py:module:: torch.autograd.variable diff --git a/docs/source/backends.md b/docs/source/backends.md new file mode 100644 index 00000000000000..41869ba9b77b5c --- /dev/null +++ b/docs/source/backends.md @@ -0,0 +1,389 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch.backends + +```{eval-rst} +.. automodule:: torch.backends +``` + +`torch.backends` controls the behavior of various backends that PyTorch supports. + +These backends include: + +- `torch.backends.cpu` +- `torch.backends.cuda` +- `torch.backends.cudnn` +- `torch.backends.cusparselt` +- `torch.backends.mha` +- `torch.backends.mps` +- `torch.backends.mkl` +- `torch.backends.mkldnn` +- `torch.backends.nnpack` +- `torch.backends.openmp` +- `torch.backends.opt_einsum` +- `torch.backends.xeon` + +## torch.backends.cpu + +```{eval-rst} +.. automodule:: torch.backends.cpu +``` + +```{eval-rst} +.. autofunction:: torch.backends.cpu.get_cpu_capability +``` + +## torch.backends.cuda + +```{eval-rst} +.. automodule:: torch.backends.cuda +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.is_built +``` + +```{eval-rst} +.. currentmodule:: torch.backends.cuda.matmul +``` + +```{eval-rst} +.. attribute:: allow_tf32 + + A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix + multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. +``` + +```{eval-rst} +.. attribute:: allow_fp16_reduced_precision_reduction + + A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs. +``` + +```{eval-rst} +.. attribute:: allow_bf16_reduced_precision_reduction + + A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs. +``` + +```{eval-rst} +.. currentmodule:: torch.backends.cuda +``` + +```{eval-rst} +.. attribute:: cufft_plan_cache + + ``cufft_plan_cache`` contains the cuFFT plan caches for each CUDA device. + Query a specific device `i`'s cache via `torch.backends.cuda.cufft_plan_cache[i]`. + + .. currentmodule:: torch.backends.cuda.cufft_plan_cache + .. attribute:: size + + A readonly :class:`int` that shows the number of plans currently in a cuFFT plan cache. + + .. attribute:: max_size + + A :class:`int` that controls the capacity of a cuFFT plan cache. + + .. method:: clear() + + Clears a cuFFT plan cache. +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.preferred_blas_library +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.preferred_rocm_fa_library +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.preferred_linalg_library +``` + +```{eval-rst} +.. autoclass:: torch.backends.cuda.SDPAParams +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.flash_sdp_enabled +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.enable_mem_efficient_sdp +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.mem_efficient_sdp_enabled +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.enable_flash_sdp +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.math_sdp_enabled +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.enable_math_sdp +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.cudnn_sdp_enabled +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.enable_cudnn_sdp +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.is_flash_attention_available +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.can_use_flash_attention +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.can_use_efficient_attention +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.can_use_cudnn_attention +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.sdp_kernel +``` + +## torch.backends.cudnn + +```{eval-rst} +.. automodule:: torch.backends.cudnn +``` + +```{eval-rst} +.. autofunction:: torch.backends.cudnn.version +``` + +```{eval-rst} +.. autofunction:: torch.backends.cudnn.is_available +``` + +```{eval-rst} +.. attribute:: enabled + + A :class:`bool` that controls whether cuDNN is enabled. +``` + +```{eval-rst} +.. attribute:: allow_tf32 + + A :class:`bool` that controls where TensorFloat-32 tensor cores may be used in cuDNN + convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. +``` + +```{eval-rst} +.. attribute:: deterministic + + A :class:`bool` that, if True, causes cuDNN to only use deterministic convolution algorithms. + See also :func:`torch.are_deterministic_algorithms_enabled` and + :func:`torch.use_deterministic_algorithms`. +``` + +```{eval-rst} +.. attribute:: benchmark + + A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms + and select the fastest. +``` + +```{eval-rst} +.. attribute:: benchmark_limit + + A :class:`int` that specifies the maximum number of cuDNN convolution algorithms to try when + `torch.backends.cudnn.benchmark` is True. Set `benchmark_limit` to zero to try every + available algorithm. Note that this setting only affects convolutions dispatched via the + cuDNN v8 API. +``` + +```{eval-rst} +.. py:module:: torch.backends.cudnn.rnn +``` + +## torch.backends.cusparselt + +```{eval-rst} +.. automodule:: torch.backends.cusparselt +``` + +```{eval-rst} +.. autofunction:: torch.backends.cusparselt.version +``` + +```{eval-rst} +.. autofunction:: torch.backends.cusparselt.is_available +``` + +## torch.backends.mha + +```{eval-rst} +.. automodule:: torch.backends.mha +``` + +```{eval-rst} +.. autofunction:: torch.backends.mha.get_fastpath_enabled +``` + +```{eval-rst} +.. autofunction:: torch.backends.mha.set_fastpath_enabled + +``` + +## torch.backends.mps + +```{eval-rst} +.. automodule:: torch.backends.mps +``` + +```{eval-rst} +.. autofunction:: torch.backends.mps.is_available +``` + +```{eval-rst} +.. autofunction:: torch.backends.mps.is_built + +``` + +## torch.backends.mkl + +```{eval-rst} +.. automodule:: torch.backends.mkl +``` + +```{eval-rst} +.. autofunction:: torch.backends.mkl.is_available +``` + +```{eval-rst} +.. autoclass:: torch.backends.mkl.verbose + +``` + +## torch.backends.mkldnn + +```{eval-rst} +.. automodule:: torch.backends.mkldnn +``` + +```{eval-rst} +.. autofunction:: torch.backends.mkldnn.is_available +``` + +```{eval-rst} +.. autoclass:: torch.backends.mkldnn.verbose +``` + +## torch.backends.nnpack + +```{eval-rst} +.. automodule:: torch.backends.nnpack +``` + +```{eval-rst} +.. autofunction:: torch.backends.nnpack.is_available +``` + +```{eval-rst} +.. autofunction:: torch.backends.nnpack.flags +``` + +```{eval-rst} +.. autofunction:: torch.backends.nnpack.set_flags +``` + +## torch.backends.openmp + +```{eval-rst} +.. automodule:: torch.backends.openmp +``` + +```{eval-rst} +.. autofunction:: torch.backends.openmp.is_available +``` + +% Docs for other backends need to be added here. +% Automodules are just here to ensure checks run but they don't actually +% add anything to the rendered page for now. + +```{eval-rst} +.. py:module:: torch.backends.quantized +``` + +```{eval-rst} +.. py:module:: torch.backends.xnnpack +``` + +```{eval-rst} +.. py:module:: torch.backends.kleidiai + +``` + +## torch.backends.opt_einsum + +```{eval-rst} +.. automodule:: torch.backends.opt_einsum +``` + +```{eval-rst} +.. autofunction:: torch.backends.opt_einsum.is_available +``` + +```{eval-rst} +.. autofunction:: torch.backends.opt_einsum.get_opt_einsum +``` + +```{eval-rst} +.. attribute:: enabled + + A :class:`bool` that controls whether opt_einsum is enabled (``True`` by default). If so, + torch.einsum will use opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html) + if available to calculate an optimal path of contraction for faster performance. + + If opt_einsum is not available, torch.einsum will fall back to the default contraction path + of left to right. +``` + +```{eval-rst} +.. attribute:: strategy + + A :class:`str` that specifies which strategies to try when ``torch.backends.opt_einsum.enabled`` + is ``True``. By default, torch.einsum will try the "auto" strategy, but the "greedy" and "optimal" + strategies are also supported. Note that the "optimal" strategy is factorial on the number of + inputs as it tries all possible paths. See more details in opt_einsum's docs + (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). + +``` + +## torch.backends.xeon + +```{eval-rst} +.. automodule:: torch.backends.xeon +``` + +```{eval-rst} +.. py:module:: torch.backends.xeon.run_cpu +``` diff --git a/docs/source/backends.rst b/docs/source/backends.rst deleted file mode 100644 index de11a3c957481d..00000000000000 --- a/docs/source/backends.rst +++ /dev/null @@ -1,240 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.backends -============== -.. automodule:: torch.backends - -`torch.backends` controls the behavior of various backends that PyTorch supports. - -These backends include: - -- ``torch.backends.cpu`` -- ``torch.backends.cuda`` -- ``torch.backends.cudnn`` -- ``torch.backends.cusparselt`` -- ``torch.backends.mha`` -- ``torch.backends.mps`` -- ``torch.backends.mkl`` -- ``torch.backends.mkldnn`` -- ``torch.backends.nnpack`` -- ``torch.backends.openmp`` -- ``torch.backends.opt_einsum`` -- ``torch.backends.xeon`` - -torch.backends.cpu -^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.cpu - -.. autofunction:: torch.backends.cpu.get_cpu_capability - -torch.backends.cuda -^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.cuda - -.. autofunction:: torch.backends.cuda.is_built - -.. currentmodule:: torch.backends.cuda.matmul -.. attribute:: allow_tf32 - - A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix - multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. - -.. attribute:: allow_fp16_reduced_precision_reduction - - A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs. - -.. attribute:: allow_bf16_reduced_precision_reduction - - A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs. - -.. currentmodule:: torch.backends.cuda -.. attribute:: cufft_plan_cache - - ``cufft_plan_cache`` contains the cuFFT plan caches for each CUDA device. - Query a specific device `i`'s cache via `torch.backends.cuda.cufft_plan_cache[i]`. - - .. currentmodule:: torch.backends.cuda.cufft_plan_cache - .. attribute:: size - - A readonly :class:`int` that shows the number of plans currently in a cuFFT plan cache. - - .. attribute:: max_size - - A :class:`int` that controls the capacity of a cuFFT plan cache. - - .. method:: clear() - - Clears a cuFFT plan cache. - -.. autofunction:: torch.backends.cuda.preferred_blas_library - -.. autofunction:: torch.backends.cuda.preferred_rocm_fa_library - -.. autofunction:: torch.backends.cuda.preferred_linalg_library - -.. autoclass:: torch.backends.cuda.SDPAParams - -.. autofunction:: torch.backends.cuda.flash_sdp_enabled - -.. autofunction:: torch.backends.cuda.enable_mem_efficient_sdp - -.. autofunction:: torch.backends.cuda.mem_efficient_sdp_enabled - -.. autofunction:: torch.backends.cuda.enable_flash_sdp - -.. autofunction:: torch.backends.cuda.math_sdp_enabled - -.. autofunction:: torch.backends.cuda.enable_math_sdp - -.. autofunction:: torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed - -.. autofunction:: torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp - -.. autofunction:: torch.backends.cuda.cudnn_sdp_enabled - -.. autofunction:: torch.backends.cuda.enable_cudnn_sdp - -.. autofunction:: torch.backends.cuda.is_flash_attention_available - -.. autofunction:: torch.backends.cuda.can_use_flash_attention - -.. autofunction:: torch.backends.cuda.can_use_efficient_attention - -.. autofunction:: torch.backends.cuda.can_use_cudnn_attention - -.. autofunction:: torch.backends.cuda.sdp_kernel - -torch.backends.cudnn -^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.cudnn - -.. autofunction:: torch.backends.cudnn.version - -.. autofunction:: torch.backends.cudnn.is_available - -.. attribute:: enabled - - A :class:`bool` that controls whether cuDNN is enabled. - -.. attribute:: allow_tf32 - - A :class:`bool` that controls where TensorFloat-32 tensor cores may be used in cuDNN - convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. - -.. attribute:: deterministic - - A :class:`bool` that, if True, causes cuDNN to only use deterministic convolution algorithms. - See also :func:`torch.are_deterministic_algorithms_enabled` and - :func:`torch.use_deterministic_algorithms`. - -.. attribute:: benchmark - - A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms - and select the fastest. - -.. attribute:: benchmark_limit - - A :class:`int` that specifies the maximum number of cuDNN convolution algorithms to try when - `torch.backends.cudnn.benchmark` is True. Set `benchmark_limit` to zero to try every - available algorithm. Note that this setting only affects convolutions dispatched via the - cuDNN v8 API. - -.. py:module:: torch.backends.cudnn.rnn - -torch.backends.cusparselt -^^^^^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.cusparselt - -.. autofunction:: torch.backends.cusparselt.version - -.. autofunction:: torch.backends.cusparselt.is_available - -torch.backends.mha -^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.mha - -.. autofunction:: torch.backends.mha.get_fastpath_enabled -.. autofunction:: torch.backends.mha.set_fastpath_enabled - - -torch.backends.mps -^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.mps - -.. autofunction:: torch.backends.mps.is_available - -.. autofunction:: torch.backends.mps.is_built - - -torch.backends.mkl -^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.mkl - -.. autofunction:: torch.backends.mkl.is_available - -.. autoclass:: torch.backends.mkl.verbose - - -torch.backends.mkldnn -^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.mkldnn - -.. autofunction:: torch.backends.mkldnn.is_available - -.. autoclass:: torch.backends.mkldnn.verbose - -torch.backends.nnpack -^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.nnpack - -.. autofunction:: torch.backends.nnpack.is_available - -.. autofunction:: torch.backends.nnpack.flags - -.. autofunction:: torch.backends.nnpack.set_flags - -torch.backends.openmp -^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.openmp - -.. autofunction:: torch.backends.openmp.is_available - -.. Docs for other backends need to be added here. -.. Automodules are just here to ensure checks run but they don't actually -.. add anything to the rendered page for now. -.. py:module:: torch.backends.quantized -.. py:module:: torch.backends.xnnpack -.. py:module:: torch.backends.kleidiai - - -torch.backends.opt_einsum -^^^^^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.opt_einsum - -.. autofunction:: torch.backends.opt_einsum.is_available - -.. autofunction:: torch.backends.opt_einsum.get_opt_einsum - -.. attribute:: enabled - - A :class:`bool` that controls whether opt_einsum is enabled (``True`` by default). If so, - torch.einsum will use opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html) - if available to calculate an optimal path of contraction for faster performance. - - If opt_einsum is not available, torch.einsum will fall back to the default contraction path - of left to right. - -.. attribute:: strategy - - A :class:`str` that specifies which strategies to try when ``torch.backends.opt_einsum.enabled`` - is ``True``. By default, torch.einsum will try the "auto" strategy, but the "greedy" and "optimal" - strategies are also supported. Note that the "optimal" strategy is factorial on the number of - inputs as it tries all possible paths. See more details in opt_einsum's docs - (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). - - -torch.backends.xeon -^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.backends.xeon -.. py:module:: torch.backends.xeon.run_cpu diff --git a/docs/source/benchmark_utils.md b/docs/source/benchmark_utils.md new file mode 100644 index 00000000000000..8f58b60b034263 --- /dev/null +++ b/docs/source/benchmark_utils.md @@ -0,0 +1,58 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Benchmark Utils - torch.utils.benchmark + +```{eval-rst} +.. automodule:: torch.utils.benchmark +``` + +```{eval-rst} +.. currentmodule:: torch.utils.benchmark +``` + +```{eval-rst} +.. autoclass:: Timer + :members: +``` + +```{eval-rst} +.. autoclass:: Measurement + :members: +``` + +```{eval-rst} +.. autoclass:: CallgrindStats + :members: +``` + +```{eval-rst} +.. autoclass:: FunctionCounts + :members: +``` + +```{eval-rst} +.. autoclass:: Compare + :members: +``` + +% These are missing documentation. Adding them here until a better place +% is made in this file. + +```{eval-rst} +.. py:module:: torch.utils.benchmark.examples +``` + +```{eval-rst} +.. py:module:: torch.utils.benchmark.op_fuzzers +``` + +```{eval-rst} +.. py:module:: torch.utils.benchmark.utils +``` + +```{eval-rst} +.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper +``` diff --git a/docs/source/benchmark_utils.rst b/docs/source/benchmark_utils.rst deleted file mode 100644 index 7546179c503fd3..00000000000000 --- a/docs/source/benchmark_utils.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Benchmark Utils - torch.utils.benchmark -================================================== - -.. automodule:: torch.utils.benchmark -.. currentmodule:: torch.utils.benchmark - -.. autoclass:: Timer - :members: - -.. autoclass:: Measurement - :members: - -.. autoclass:: CallgrindStats - :members: - -.. autoclass:: FunctionCounts - :members: - -.. autoclass:: Compare - :members: - -.. These are missing documentation. Adding them here until a better place -.. is made in this file. -.. py:module:: torch.utils.benchmark.examples -.. py:module:: torch.utils.benchmark.op_fuzzers -.. py:module:: torch.utils.benchmark.utils -.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper diff --git a/docs/source/checkpoint.md b/docs/source/checkpoint.md new file mode 100644 index 00000000000000..d27d0a44021f41 --- /dev/null +++ b/docs/source/checkpoint.md @@ -0,0 +1,42 @@ +# torch.utils.checkpoint + +```{note} +Checkpointing is implemented by rerunning a forward-pass segment for +each checkpointed segment during backward propagation. This can cause persistent +states like the RNG state to be more advanced than they would without +checkpointing. By default, checkpointing includes logic to juggle +the RNG state such that checkpointed passes making use of RNG +(through dropout for example) have deterministic output as +compared to non-checkpointed passes. The logic to stash and restore +RNG states can incur a moderate performance hit depending on the runtime +of checkpointed operations. If deterministic output compared to +non-checkpointed passes is not required, supply `preserve_rng_state=False` +to `checkpoint` or `checkpoint_sequential` to omit stashing and +restoring the RNG state during each checkpoint. + +The stashing logic saves and restores the RNG state for CPU and another +device type (infer the device type from Tensor arguments excluding CPU +tensors by `_infer_device_type`) to the `run_fn`. If there are multiple +device, device state will only be saved for devices of a single device type, +and the remaining devices will be ignored. Consequently, if any checkpointed +functions involve randomness, this may result in incorrect gradients. (Note +that if CUDA devices are among the devices detected, it will be prioritized; +otherwise, the first device encountered will be selected.) If there are no +CPU-tensors, the default device type state (default value is `cuda`, and it +could be set to other device by `DefaultDeviceType`) will be saved and restored. +However, the logic has no way to anticipate if the user will move +Tensors to a new device within the `run_fn` itself. Therefore, if you move +Tensors to a new device ("new" meaning not belonging to the set of +[current device + devices of Tensor arguments]) within `run_fn`, deterministic +output compared to non-checkpointed passes is never guaranteed. +``` + +```{eval-rst} +.. currentmodule:: torch.utils.checkpoint +.. autofunction:: checkpoint +.. autofunction:: checkpoint_sequential +.. autofunction:: set_checkpoint_debug_enabled +.. autoclass:: CheckpointPolicy +.. autoclass:: SelectiveCheckpointContext +.. autofunction:: create_selective_checkpoint_contexts +``` diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst deleted file mode 100644 index 8559d8bd73663c..00000000000000 --- a/docs/source/checkpoint.rst +++ /dev/null @@ -1,40 +0,0 @@ -torch.utils.checkpoint -====================== - -.. note:: - Checkpointing is implemented by rerunning a forward-pass segment for - each checkpointed segment during backward propagation. This can cause persistent - states like the RNG state to be more advanced than they would without - checkpointing. By default, checkpointing includes logic to juggle - the RNG state such that checkpointed passes making use of RNG - (through dropout for example) have deterministic output as - compared to non-checkpointed passes. The logic to stash and restore - RNG states can incur a moderate performance hit depending on the runtime - of checkpointed operations. If deterministic output compared to - non-checkpointed passes is not required, supply ``preserve_rng_state=False`` - to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and - restoring the RNG state during each checkpoint. - - The stashing logic saves and restores the RNG state for CPU and another - device type (infer the device type from Tensor arguments excluding CPU - tensors by ``_infer_device_type``) to the ``run_fn``. If there are multiple - device, device state will only be saved for devices of a single device type, - and the remaining devices will be ignored. Consequently, if any checkpointed - functions involve randomness, this may result in incorrect gradients. (Note - that if CUDA devices are among the devices detected, it will be prioritized; - otherwise, the first device encountered will be selected.) If there are no - CPU-tensors, the default device type state (default value is `cuda`, and it - could be set to other device by ``DefaultDeviceType``) will be saved and restored. - However, the logic has no way to anticipate if the user will move - Tensors to a new device within the ``run_fn`` itself. Therefore, if you move - Tensors to a new device ("new" meaning not belonging to the set of - [current device + devices of Tensor arguments]) within ``run_fn``, deterministic - output compared to non-checkpointed passes is never guaranteed. - -.. currentmodule:: torch.utils.checkpoint -.. autofunction:: checkpoint -.. autofunction:: checkpoint_sequential -.. autofunction:: set_checkpoint_debug_enabled -.. autoclass:: CheckpointPolicy -.. autoclass:: SelectiveCheckpointContext -.. autofunction:: create_selective_checkpoint_contexts diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index f4301fe50322ee..5e5ef631ceaa74 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -349,9 +349,9 @@ XLA TorchServe ~~~~~~~~~~ -- Li Ning (`lxning `__) -- Ankith Gunapal (`agunapal `__) -- Hamid Shojanazeri (`HamidShojanazeri `__) +- (emeritus) Li Ning (`lxning `__) +- (emeritus) Ankith Gunapal (`agunapal `__) +- (emeritus) Hamid Shojanazeri (`HamidShojanazeri `__) - (emeritus) Mark Saroufim (`msaroufIm `__) - (emeritus) Manoj Rao (`mycpuorg `__) - (emeritus) Vamshi Dantu (`vdantu `__) diff --git a/docs/source/complex_numbers.md b/docs/source/complex_numbers.md new file mode 100644 index 00000000000000..610f9a06615a1c --- /dev/null +++ b/docs/source/complex_numbers.md @@ -0,0 +1,161 @@ +(complex_numbers-doc)= + +# Complex Numbers + +Complex numbers are numbers that can be expressed in the form {math}`a + bj`, where a and b are real numbers, +and *j* is called the imaginary unit, which satisfies the equation {math}`j^2 = -1`. Complex numbers frequently occur in mathematics and +engineering, especially in topics like signal processing. Traditionally many users and libraries (e.g., TorchAudio) have +handled complex numbers by representing the data in float tensors with shape {math}`(..., 2)` where the last +dimension contains the real and imaginary values. + +Tensors of complex dtypes provide a more natural user experience while working with complex numbers. Operations on +complex tensors (e.g., {func}`torch.mv`, {func}`torch.matmul`) are likely to be faster and more memory efficient +than operations on float tensors mimicking them. Operations involving complex numbers in PyTorch are optimized +to use vectorized assembly instructions and specialized kernels (e.g. LAPACK, cuBlas). + +```{note} +Spectral operations in the [torch.fft module](https://pytorch.org/docs/stable/fft.html#torch-fft) support +native complex tensors. +``` + +```{warning} +Complex tensors is a beta feature and subject to change. +``` + +## Creating Complex Tensors + +We support two complex dtypes: `torch.cfloat` and `torch.cdouble` + +```python +>>> x = torch.randn(2,2, dtype=torch.cfloat) +>>> x +tensor([[-0.4621-0.0303j, -0.2438-0.5874j], + [ 0.7706+0.1421j, 1.2110+0.1918j]]) +``` + +```{note} +The default dtype for complex tensors is determined by the default floating point dtype. +If the default floating point dtype is `torch.float64` then complex numbers are inferred to +have a dtype of `torch.complex128`, otherwise they are assumed to have a dtype of `torch.complex64`. +``` + +All factory functions apart from {func}`torch.linspace`, {func}`torch.logspace`, and {func}`torch.arange` are +supported for complex tensors. + +## Transition from the old representation + +Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)` +can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex` +and {func}`torch.view_as_real`. Note that these functions don’t perform any copy and return a +view of the input tensor. + +```python +>>> x = torch.randn(3, 2) +>>> x +tensor([[ 0.6125, -0.1681], + [-0.3773, 1.3487], + [-0.0861, -0.7981]]) +>>> y = torch.view_as_complex(x) +>>> y +tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) +>>> torch.view_as_real(y) +tensor([[ 0.6125, -0.1681], + [-0.3773, 1.3487], + [-0.0861, -0.7981]]) +``` + +## Accessing real and imag + +The real and imaginary values of a complex tensor can be accessed using the {attr}`real` and +{attr}`imag`. + +```{note} +Accessing `real` and `imag` attributes doesn't allocate any memory, and in-place updates on the +`real` and `imag` tensors will update the original complex tensor. Also, the +returned `real` and `imag` tensors are not contiguous. +``` + +```python +>>> y.real +tensor([ 0.6125, -0.3773, -0.0861]) +>>> y.imag +tensor([-0.1681, 1.3487, -0.7981]) + +>>> y.real.mul_(2) +tensor([ 1.2250, -0.7546, -0.1722]) +>>> y +tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j]) +>>> y.real.stride() +(2,) +``` + +## Angle and abs + +The angle and absolute values of a complex tensor can be computed using {func}`torch.angle` and +{func}`torch.abs`. + +```python +>>> x1=torch.tensor([3j, 4+4j]) +>>> x1.abs() +tensor([3.0000, 5.6569]) +>>> x1.angle() +tensor([1.5708, 0.7854]) +``` + +## Linear Algebra + +Many linear algebra operations, like {func}`torch.matmul`, {func}`torch.linalg.svd`, {func}`torch.linalg.solve` etc., support complex numbers. +If you'd like to request an operation we don't currently support, please [search](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex) +if an issue has already been filed and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose). + +## Serialization + +Complex tensors can be serialized, allowing data to be saved as complex values. + +```python +>>> torch.save(y, 'complex_tensor.pt') +>>> torch.load('complex_tensor.pt') +tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) +``` + +## Autograd + +PyTorch supports autograd for complex tensors. The gradient computed is the Conjugate Wirtinger derivative, +the negative of which is precisely the direction of steepest descent used in Gradient Descent algorithm. Thus, +all the existing optimizers can be implemented to work out of the box with complex parameters. For more details, +check out the note {ref}`complex_autograd-doc`. + +## Optimizers + +Semantically, we define stepping through a PyTorch optimizer with complex parameters as being equivalent to stepping +through the same optimizer on the {func}`torch.view_as_real` equivalent of the complex params. More concretely: + +```python +>>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)] +>>> real_params = [torch.view_as_real(p) for p in params] + +>>> complex_optim = torch.optim.AdamW(params) +>>> real_optim = torch.optim.AdamW(real_params) +``` + +`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical +discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers +and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). + +Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their +`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the +{func}`torch.view_as_real` equivalent will convert a complex tensor to a real tensor with shape {math}`(..., 2)`, +whereas splitting a complex tensor into two tensors is 2 tensors of size {math}`(...)`. This distinction has no impact on +pointwise optimizers (like AdamW) but will cause slight discrepancy in optimizers that do global reductions (like LBFGS). +We currently do not have optimizers that do per-Tensor reductions and thus do not yet define this behavior. Open an issue +if you have a use case that requires precisely defining this behavior. + +We do not fully support the following subsystems: + +* Quantization +* JIT +* Sparse Tensors +* Distributed + +If any of these would help your use case, please [search](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+complex) +if an issue has already been filed and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose). \ No newline at end of file diff --git a/docs/source/complex_numbers.rst b/docs/source/complex_numbers.rst deleted file mode 100644 index 87de9f8a4088ed..00000000000000 --- a/docs/source/complex_numbers.rst +++ /dev/null @@ -1,175 +0,0 @@ -.. _complex_numbers-doc: - -Complex Numbers -=============== - -Complex numbers are numbers that can be expressed in the form :math:`a + bj`, where a and b are real numbers, -and *j* is called the imaginary unit, which satisfies the equation :math:`j^2 = -1`. Complex numbers frequently occur in mathematics and -engineering, especially in topics like signal processing. Traditionally many users and libraries (e.g., TorchAudio) have -handled complex numbers by representing the data in float tensors with shape :math:`(..., 2)` where the last -dimension contains the real and imaginary values. - -Tensors of complex dtypes provide a more natural user experience while working with complex numbers. Operations on -complex tensors (e.g., :func:`torch.mv`, :func:`torch.matmul`) are likely to be faster and more memory efficient -than operations on float tensors mimicking them. Operations involving complex numbers in PyTorch are optimized -to use vectorized assembly instructions and specialized kernels (e.g. LAPACK, cuBlas). - -.. note:: - Spectral operations in the `torch.fft module `_ support - native complex tensors. - -.. warning :: - Complex tensors is a beta feature and subject to change. - -Creating Complex Tensors ------------------------- - -We support two complex dtypes: `torch.cfloat` and `torch.cdouble` - -:: - - >>> x = torch.randn(2,2, dtype=torch.cfloat) - >>> x - tensor([[-0.4621-0.0303j, -0.2438-0.5874j], - [ 0.7706+0.1421j, 1.2110+0.1918j]]) - -.. note:: - - The default dtype for complex tensors is determined by the default floating point dtype. - If the default floating point dtype is `torch.float64` then complex numbers are inferred to - have a dtype of `torch.complex128`, otherwise they are assumed to have a dtype of `torch.complex64`. - -All factory functions apart from :func:`torch.linspace`, :func:`torch.logspace`, and :func:`torch.arange` are -supported for complex tensors. - -Transition from the old representation --------------------------------------- - -Users who currently worked around the lack of complex tensors with real tensors of shape :math:`(..., 2)` -can easily to switch using the complex tensors in their code using :func:`torch.view_as_complex` -and :func:`torch.view_as_real`. Note that these functions don’t perform any copy and return a -view of the input tensor. - -:: - - >>> x = torch.randn(3, 2) - >>> x - tensor([[ 0.6125, -0.1681], - [-0.3773, 1.3487], - [-0.0861, -0.7981]]) - >>> y = torch.view_as_complex(x) - >>> y - tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) - >>> torch.view_as_real(y) - tensor([[ 0.6125, -0.1681], - [-0.3773, 1.3487], - [-0.0861, -0.7981]]) - -Accessing real and imag ------------------------ - -The real and imaginary values of a complex tensor can be accessed using the :attr:`real` and -:attr:`imag`. - -.. note:: - Accessing `real` and `imag` attributes doesn't allocate any memory, and in-place updates on the - `real` and `imag` tensors will update the original complex tensor. Also, the - returned `real` and `imag` tensors are not contiguous. - -:: - - >>> y.real - tensor([ 0.6125, -0.3773, -0.0861]) - >>> y.imag - tensor([-0.1681, 1.3487, -0.7981]) - - >>> y.real.mul_(2) - tensor([ 1.2250, -0.7546, -0.1722]) - >>> y - tensor([ 1.2250-0.1681j, -0.7546+1.3487j, -0.1722-0.7981j]) - >>> y.real.stride() - (2,) - -Angle and abs -------------- - -The angle and absolute values of a complex tensor can be computed using :func:`torch.angle` and -:func:`torch.abs`. - -:: - - >>> x1=torch.tensor([3j, 4+4j]) - >>> x1.abs() - tensor([3.0000, 5.6569]) - >>> x1.angle() - tensor([1.5708, 0.7854]) - -Linear Algebra --------------- - -Many linear algebra operations, like :func:`torch.matmul`, :func:`torch.linalg.svd`, :func:`torch.linalg.solve` etc., support complex numbers. -If you'd like to request an operation we don't currently support, please `search `_ -if an issue has already been filed and if not, `file one `_. - - -Serialization -------------- - -Complex tensors can be serialized, allowing data to be saved as complex values. - -:: - - >>> torch.save(y, 'complex_tensor.pt') - >>> torch.load('complex_tensor.pt') - tensor([ 0.6125-0.1681j, -0.3773+1.3487j, -0.0861-0.7981j]) - - -Autograd --------- - -PyTorch supports autograd for complex tensors. The gradient computed is the Conjugate Wirtinger derivative, -the negative of which is precisely the direction of steepest descent used in Gradient Descent algorithm. Thus, -all the existing optimizers can be implemented to work out of the box with complex parameters. For more details, -check out the note :ref:`complex_autograd-doc`. - - -Optimizers ----------- - -Semantically, we define stepping through a PyTorch optimizer with complex parameters as being equivalent to stepping -through the same optimizer on the :func:`torch.view_as_real` equivalent of the complex params. More concretely: - -:: - - >>> params = [torch.rand(2, 3, dtype=torch.complex64) for _ in range(5)] - >>> real_params = [torch.view_as_real(p) for p in params] - - >>> complex_optim = torch.optim.AdamW(params) - >>> real_optim = torch.optim.AdamW(real_params) - - -`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical -discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers -and capturable vs default optimizers. For more details, see https://pytorch.org/docs/stable/notes/numerical_accuracy.html. - -Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their -`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the -:func:`torch.view_as_real` equivalent will convert a complex tensor to a real tensor with shape :math:`(..., 2)`, -whereas splitting a complex tensor into two tensors is 2 tensors of size :math:`(...)`. This distinction has no impact on -pointwise optimizers (like AdamW) but will cause slight discrepancy in optimizers that do global reductions (like LBFGS). -We currently do not have optimizers that do per-Tensor reductions and thus do not yet define this behavior. Open an issue -if you have a use case that requires precisely defining this behavior. - - -We do not fully support the following subsystems: - -* Quantization - -* JIT - -* Sparse Tensors - -* Distributed - -If any of these would help your use case, please `search `_ -if an issue has already been filed and if not, `file one `_. diff --git a/docs/source/cond.md b/docs/source/cond.md new file mode 100644 index 00000000000000..0765d59dae7fdd --- /dev/null +++ b/docs/source/cond.md @@ -0,0 +1,174 @@ +(cond)= + +# Control Flow - Cond + +`torch.cond` is a structured control flow operator. It can be used to specify if-else like control flow +and can logically be seen as implemented as follows. + +```python +def cond( + pred: Union[bool, torch.Tensor], + true_fn: Callable, + false_fn: Callable, + operands: Tuple[torch.Tensor] +): + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) +``` + +Its unique power lies in its ability of expressing **data-dependent control flow**: it lowers to a conditional +operator (`torch.ops.higher_order.cond`), which preserves predicate, true function and false functions. +This unlocks great flexibility in writing and deploying models that change model architecture based on +the **value** or **shape** of inputs or intermediate outputs of tensor operations. + +```{warning} +`torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and +doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. +Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype +``` + +## Examples + +Below is an example that uses cond to branch based on input shape: + +```python + import torch + + def true_fn(x: torch.Tensor): + return x.cos() + x.sin() + + def false_fn(x: torch.Tensor): + return x.sin() + + class DynamicShapeCondPredicate(torch.nn.Module): + """ + A basic usage of cond based on dynamic shape predicate. + """ + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_fn(x: torch.Tensor): + return x.cos() + + def false_fn(x: torch.Tensor): + return x.sin() + + return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,)) + + dyn_shape_mod = DynamicShapeCondPredicate() +``` + +We can eagerly run the model and expect the results vary based on input shape: + +```python + inp = torch.randn(3) + inp2 = torch.randn(5) + assert torch.equal(dyn_shape_mod(inp), false_fn(inp)) + assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2)) +``` + +We can export the model for further transformations and deployment: + +```python + inp = torch.randn(4, 3) + dim_batch = torch.export.Dim("batch", min=2) + ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}}) + print(ep) +``` + +This gives us an exported program as shown below: + +``` + class GraphModule(torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0) + gt: Sym(s0 > 4) = sym_size > 4; sym_size = None + true_graph_0 = self.true_graph_0 + false_graph_0 = self.false_graph_0 + conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None + return (conditional,) + + class (torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) + sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None + add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None + return add + + class (torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None + return sin +``` + +Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a Symbolic expression over the shape of input, +and branch functions becomes two sub-graph attributes of the top level graph module. + +Here is another example that showcases how to express a data-dependent control flow: + +```python + class DataDependentCondPredicate(torch.nn.Module): + """ + A basic usage of cond based on data dependent predicate. + """ + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,)) +``` + +The exported program we get after export: + +``` + class GraphModule(torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + sum_1: f32[] = torch.ops.aten.sum.default(arg0_1) + gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None + + true_graph_0 = self.true_graph_0 + false_graph_0 = self.false_graph_0 + conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None + return (conditional,) + + class (torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) + sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None + add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None + return add + + class (torch.nn.Module): + def forward(self, arg0_1: f32[s0, 3]): + sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None + return sin +``` + +## Invariants of torch.ops.higher_order.cond + +There are several useful invariants for `torch.ops.higher_order.cond`: + +- For predicate: + - Dynamicness of predicate is preserved (e.g. `gt` shown in the above example) + - If the predicate in user-program is constant (e.g. a python bool constant), the `pred` of the operator will be a constant. + +- For branches: + - The input and output signature will be a flattened tuple. + - They are `torch.fx.GraphModule`. + - Closures in original function becomes explicit inputs. No closures. + - No mutations on inputs or globals are allowed. + +- For operands: + - It will also be a flat tuple. + +- Nesting of `torch.cond` in user program becomes nested graph modules. + +## API Reference + +```{eval-rst} +.. autofunction:: torch._higher_order_ops.cond.cond +``` diff --git a/docs/source/cond.rst b/docs/source/cond.rst deleted file mode 100644 index c43ce4fd6d9a2c..00000000000000 --- a/docs/source/cond.rst +++ /dev/null @@ -1,176 +0,0 @@ -.. _cond: - -Control Flow - Cond -==================== - -`torch.cond` is a structured control flow operator. It can be used to specify if-else like control flow -and can logically be seen as implemented as follows. - -.. code-block:: python - - def cond( - pred: Union[bool, torch.Tensor], - true_fn: Callable, - false_fn: Callable, - operands: Tuple[torch.Tensor] - ): - if pred: - return true_fn(*operands) - else: - return false_fn(*operands) - -Its unique power lies in its ability of expressing **data-dependent control flow**: it lowers to a conditional -operator (`torch.ops.higher_order.cond`), which preserves predicate, true function and false functions. -This unlocks great flexibility in writing and deploying models that change model architecture based on -the **value** or **shape** of inputs or intermediate outputs of tensor operations. - -.. warning:: - `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and - doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. - Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype - -Examples -~~~~~~~~ - -Below is an example that uses cond to branch based on input shape: - -.. code-block:: python - - import torch - - def true_fn(x: torch.Tensor): - return x.cos() + x.sin() - - def false_fn(x: torch.Tensor): - return x.sin() - - class DynamicShapeCondPredicate(torch.nn.Module): - """ - A basic usage of cond based on dynamic shape predicate. - """ - - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - def true_fn(x: torch.Tensor): - return x.cos() - - def false_fn(x: torch.Tensor): - return x.sin() - - return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,)) - - dyn_shape_mod = DynamicShapeCondPredicate() - -We can eagerly run the model and expect the results vary based on input shape: - -.. code-block:: python - - inp = torch.randn(3) - inp2 = torch.randn(5) - assert torch.equal(dyn_shape_mod(inp), false_fn(inp)) - assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2)) - -We can export the model for further transformations and deployment: - -.. code-block:: python - - inp = torch.randn(4, 3) - dim_batch = torch.export.Dim("batch", min=2) - ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}}) - print(ep) - -This gives us an exported program as shown below: - -.. code-block:: - - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0) - gt: Sym(s0 > 4) = sym_size > 4; sym_size = None - true_graph_0 = self.true_graph_0 - false_graph_0 = self.false_graph_0 - conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None - return (conditional,) - - class (torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) - sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None - add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None - return add - - class (torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None - return sin - -Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a Symbolic expression over the shape of input, -and branch functions becomes two sub-graph attributes of the top level graph module. - -Here is another example that showcases how to express a data-dependent control flow: - -.. code-block:: python - - class DataDependentCondPredicate(torch.nn.Module): - """ - A basic usage of cond based on data dependent predicate. - """ - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,)) - -The exported program we get after export: - -.. code-block:: - - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - sum_1: f32[] = torch.ops.aten.sum.default(arg0_1) - gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None - - true_graph_0 = self.true_graph_0 - false_graph_0 = self.false_graph_0 - conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None - return (conditional,) - - class (torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) - sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None - add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None - return add - - class (torch.nn.Module): - def forward(self, arg0_1: f32[s0, 3]): - sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None - return sin - - -Invariants of torch.ops.higher_order.cond -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -There are several useful invariants for `torch.ops.higher_order.cond`: - -- For predicate: - - Dynamicness of predicate is preserved (e.g. `gt` shown in the above example) - - If the predicate in user-program is constant (e.g. a python bool constant), the `pred` of the operator will be a constant. - -- For branches: - - The input and output signature will be a flattened tuple. - - They are `torch.fx.GraphModule`. - - Closures in original function becomes explicit inputs. No closures. - - No mutations on inputs or globals are allowed. - -- For operands: - - It will also be a flat tuple. - -- Nesting of `torch.cond` in user program becomes nested graph modules. - - -API Reference -------------- -.. autofunction:: torch._higher_order_ops.cond.cond diff --git a/docs/source/conf.py b/docs/source/conf.py index ed943929cfa501..acb2b088af7278 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -134,6 +134,7 @@ "json_url": "https://docs.pytorch.org/docs/pytorch-versions.json", "version_match": switcher_version, }, + "show_toc_level": 2, "navigation_with_keys": False, "external_links": [ { @@ -515,34 +516,8 @@ "graph_pool_handle", "is_current_stream_capturing", "make_graphed_callables", - # torch.cuda.memory - "caching_allocator_alloc", - "caching_allocator_delete", - "change_current_allocator", - "empty_cache", - "get_allocator_backend", - "get_per_process_memory_fraction", - "list_gpu_processes", - "max_memory_allocated", - "max_memory_cached", - "max_memory_reserved", - "mem_get_info", - "memory_allocated", - "memory_cached", - "memory_reserved", - "memory_snapshot", - "memory_stats", - "memory_stats_as_nested_dict", - "host_memory_stats", - "host_memory_stats_as_nested_dict", - "memory_summary", - "reset_accumulated_memory_stats", - "reset_accumulated_host_memory_stats", - "reset_max_memory_allocated", - "reset_max_memory_cached", + # torch.mtia.memory "reset_peak_memory_stats", - "reset_peak_host_memory_stats", - "set_per_process_memory_fraction", # torch.cuda.nccl "all_gather", "all_reduce", @@ -1323,10 +1298,6 @@ "scatter_kwargs", # torch.nn.parameter "is_lazy", - # torch.nn.utils.clip_grad - "clip_grad_norm", - "clip_grad_norm_", - "clip_grad_value_", # torch.nn.utils.convert_parameters "parameters_to_vector", "vector_to_parameters", @@ -2615,6 +2586,9 @@ # torch.distributed.checkpoint.filesystem "FileSystemReader", "FileSystemWriter", + # torch.distributed.checkpoint.hf_storage + "HuggingFaceStorageReader", + "HuggingFaceStorageWriter", # torch.distributed.checkpoint.metadata "BytesStorageMetadata", "ChunkStorageMetadata", diff --git a/docs/source/config_mod.md b/docs/source/config_mod.md new file mode 100644 index 00000000000000..eab05f297ee80b --- /dev/null +++ b/docs/source/config_mod.md @@ -0,0 +1,11 @@ +# torch.__config__ + +```{eval-rst} +.. automodule:: torch.__config__ +.. currentmodule:: torch.__config__ +``` + +```{eval-rst} +.. autofunction:: show +.. autofunction:: parallel_info +``` diff --git a/docs/source/config_mod.rst b/docs/source/config_mod.rst deleted file mode 100644 index adbfebe560a72c..00000000000000 --- a/docs/source/config_mod.rst +++ /dev/null @@ -1,8 +0,0 @@ -torch.__config__ -=================================== - -.. automodule:: torch.__config__ -.. currentmodule:: torch.__config__ - -.. autofunction:: show -.. autofunction:: parallel_info diff --git a/docs/source/cuda.md b/docs/source/cuda.md new file mode 100644 index 00000000000000..e72610fa81e72c --- /dev/null +++ b/docs/source/cuda.md @@ -0,0 +1,309 @@ +# torch.cuda + +```{eval-rst} +.. automodule:: torch.cuda +``` + +```{eval-rst} +.. currentmodule:: torch.cuda +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + StreamContext + can_device_access_peer + current_blas_handle + current_device + current_stream + cudart + default_stream + device + device_count + device_memory_used + device_of + get_arch_list + get_device_capability + get_device_name + get_device_properties + get_gencode_flags + get_stream_from_external + get_sync_debug_mode + init + ipc_collect + is_available + is_initialized + is_tf32_supported + memory_usage + set_device + set_stream + set_sync_debug_mode + stream + synchronize + utilization + temperature + power_draw + clock_rate + AcceleratorError + OutOfMemoryError +``` + +## Random Number Generator + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + get_rng_state + get_rng_state_all + set_rng_state + set_rng_state_all + manual_seed + manual_seed_all + seed + seed_all + initial_seed + +``` + +## Communication collectives + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + comm.broadcast + comm.broadcast_coalesced + comm.reduce_add + comm.reduce_add_coalesced + comm.scatter + comm.gather +``` + +## Streams and events + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Stream + ExternalStream + Event +``` + +## Graphs (beta) + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + is_current_stream_capturing + graph_pool_handle + CUDAGraph + graph + make_graphed_callables +``` + +(cuda-memory-management-api)= + +```{eval-rst} +.. automodule:: torch.cuda.memory +``` + +```{eval-rst} +.. currentmodule:: torch.cuda.memory +``` + +## Memory management + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + get_per_process_memory_fraction + list_gpu_processes + mem_get_info + memory_stats + memory_stats_as_nested_dict + reset_accumulated_memory_stats + host_memory_stats + host_memory_stats_as_nested_dict + reset_accumulated_host_memory_stats + memory_summary + memory_snapshot + memory_allocated + max_memory_allocated + reset_max_memory_allocated + memory_reserved + max_memory_reserved + set_per_process_memory_fraction + memory_cached + max_memory_cached + reset_max_memory_cached + reset_peak_memory_stats + reset_peak_host_memory_stats + caching_allocator_alloc + caching_allocator_delete + get_allocator_backend + CUDAPluggableAllocator + change_current_allocator + MemPool +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + caching_allocator_enable +``` + +```{eval-rst} +.. currentmodule:: torch.cuda +``` + +```{eval-rst} +.. autoclass:: torch.cuda.use_mem_pool +``` + +% FIXME The following doesn't seem to exist. Is it supposed to? +% https://github.com/pytorch/pytorch/issues/27785 +% .. autofunction:: reset_max_memory_reserved + +## NVIDIA Tools Extension (NVTX) + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + nvtx.mark + nvtx.range_push + nvtx.range_pop + nvtx.range +``` + +## Jiterator (beta) + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + jiterator._create_jit_fn + jiterator._create_multi_output_jit_fn +``` + +## TunableOp + +Some operations could be implemented using more than one library or more than +one technique. For example, a GEMM could be implemented for CUDA or ROCm using +either the cublas/cublasLt libraries or hipblas/hipblasLt libraries, +respectively. How does one know which implementation is the fastest and should +be chosen? That's what TunableOp provides. Certain operators have been +implemented using multiple strategies as Tunable Operators. At runtime, all +strategies are profiled and the fastest is selected for all subsequent +operations. + +See the {doc}`documentation ` for information on how to use it. + +```{toctree} +:hidden: true + +cuda.tunable +``` + +## Stream Sanitizer (prototype) + +CUDA Sanitizer is a prototype tool for detecting synchronization errors between streams in PyTorch. +See the {doc}`documentation ` for information on how to use it. + +```{toctree} +:hidden: true + +cuda._sanitizer +``` + +## GPUDirect Storage (prototype) + +The APIs in `torch.cuda.gds` provide thin wrappers around certain cuFile APIs that allow +direct memory access transfers between GPU memory and storage, avoiding a bounce buffer in the CPU. See the +[cufile api documentation](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api) +for more details. + +These APIs can be used in versions greater than or equal to CUDA 12.6. In order to use these APIs, one must +ensure that their system is appropriately configured to use GPUDirect Storage per the +[GPUDirect Storage documentation](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html). + +See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use these. + +```{eval-rst} +.. currentmodule:: torch.cuda.gds +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + gds_register_buffer + gds_deregister_buffer + GdsFile + +``` + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.cuda.comm +``` + +```{eval-rst} +.. py:module:: torch.cuda.error +``` + +```{eval-rst} +.. py:module:: torch.cuda.gds +``` + +```{eval-rst} +.. py:module:: torch.cuda.graphs +``` + +```{eval-rst} +.. py:module:: torch.cuda.jiterator +``` + +```{eval-rst} +.. py:module:: torch.cuda.nccl +``` + +```{eval-rst} +.. py:module:: torch.cuda.nvtx +``` + +```{eval-rst} +.. py:module:: torch.cuda.profiler +``` + +```{eval-rst} +.. py:module:: torch.cuda.random +``` + +```{eval-rst} +.. py:module:: torch.cuda.sparse +``` + +```{eval-rst} +.. py:module:: torch.cuda.streams +``` diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst deleted file mode 100644 index 431758fc2ef36d..00000000000000 --- a/docs/source/cuda.rst +++ /dev/null @@ -1,237 +0,0 @@ -torch.cuda -=================================== -.. automodule:: torch.cuda -.. currentmodule:: torch.cuda - -.. autosummary:: - :toctree: generated - :nosignatures: - - StreamContext - can_device_access_peer - current_blas_handle - current_device - current_stream - cudart - default_stream - device - device_count - device_memory_used - device_of - get_arch_list - get_device_capability - get_device_name - get_device_properties - get_gencode_flags - get_stream_from_external - get_sync_debug_mode - init - ipc_collect - is_available - is_initialized - is_tf32_supported - memory_usage - set_device - set_stream - set_sync_debug_mode - stream - synchronize - utilization - temperature - power_draw - clock_rate - OutOfMemoryError - -Random Number Generator -------------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - get_rng_state - get_rng_state_all - set_rng_state - set_rng_state_all - manual_seed - manual_seed_all - seed - seed_all - initial_seed - - -Communication collectives -------------------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - comm.broadcast - comm.broadcast_coalesced - comm.reduce_add - comm.reduce_add_coalesced - comm.scatter - comm.gather - -Streams and events ------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - Stream - ExternalStream - Event - -Graphs (beta) -------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - is_current_stream_capturing - graph_pool_handle - CUDAGraph - graph - make_graphed_callables - -.. _cuda-memory-management-api: - -Memory management ------------------ -.. autosummary:: - :toctree: generated - :nosignatures: - - empty_cache - get_per_process_memory_fraction - list_gpu_processes - mem_get_info - memory_stats - host_memory_stats - memory_summary - memory_snapshot - memory_allocated - max_memory_allocated - reset_max_memory_allocated - memory_reserved - max_memory_reserved - set_per_process_memory_fraction - memory_cached - max_memory_cached - reset_max_memory_cached - reset_peak_memory_stats - reset_peak_host_memory_stats - caching_allocator_alloc - caching_allocator_delete - get_allocator_backend - CUDAPluggableAllocator - change_current_allocator - MemPool - -.. currentmodule:: torch.cuda.memory - -.. autosummary:: - :toctree: generated - :nosignatures: - - caching_allocator_enable - -.. currentmodule:: torch.cuda -.. autoclass:: torch.cuda.use_mem_pool - -.. FIXME The following doesn't seem to exist. Is it supposed to? - https://github.com/pytorch/pytorch/issues/27785 - .. autofunction:: reset_max_memory_reserved - -NVIDIA Tools Extension (NVTX) ------------------------------ - -.. autosummary:: - :toctree: generated - :nosignatures: - - nvtx.mark - nvtx.range_push - nvtx.range_pop - nvtx.range - -Jiterator (beta) ------------------------------ -.. autosummary:: - :toctree: generated - :nosignatures: - - jiterator._create_jit_fn - jiterator._create_multi_output_jit_fn - -TunableOp ---------- - -Some operations could be implemented using more than one library or more than -one technique. For example, a GEMM could be implemented for CUDA or ROCm using -either the cublas/cublasLt libraries or hipblas/hipblasLt libraries, -respectively. How does one know which implementation is the fastest and should -be chosen? That's what TunableOp provides. Certain operators have been -implemented using multiple strategies as Tunable Operators. At runtime, all -strategies are profiled and the fastest is selected for all subsequent -operations. - -See the :doc:`documentation ` for information on how to use it. - -.. toctree:: - :hidden: - - cuda.tunable - - -Stream Sanitizer (prototype) ----------------------------- - -CUDA Sanitizer is a prototype tool for detecting synchronization errors between streams in PyTorch. -See the :doc:`documentation ` for information on how to use it. - -.. toctree:: - :hidden: - - cuda._sanitizer - - -GPUDirect Storage (prototype) ------------------------------ - -The APIs in ``torch.cuda.gds`` provide thin wrappers around certain cuFile APIs that allow -direct memory access transfers between GPU memory and storage, avoiding a bounce buffer in the CPU. See the -`cufile api documentation `_ -for more details. - -These APIs can be used in versions greater than or equal to CUDA 12.6. In order to use these APIs, one must -ensure that their system is appropriately configured to use GPUDirect Storage per the -`GPUDirect Storage documentation `_. - -See the docs for :class:`~torch.cuda.gds.GdsFile` for an example of how to use these. - -.. currentmodule:: torch.cuda.gds -.. autosummary:: - :toctree: generated - :nosignatures: - - gds_register_buffer - gds_deregister_buffer - GdsFile - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.cuda.comm -.. py:module:: torch.cuda.error -.. py:module:: torch.cuda.gds -.. py:module:: torch.cuda.graphs -.. py:module:: torch.cuda.jiterator -.. py:module:: torch.cuda.memory -.. py:module:: torch.cuda.nccl -.. py:module:: torch.cuda.nvtx -.. py:module:: torch.cuda.profiler -.. py:module:: torch.cuda.random -.. py:module:: torch.cuda.sparse -.. py:module:: torch.cuda.streams diff --git a/docs/source/cuda.tunable.md b/docs/source/cuda.tunable.md new file mode 100644 index 00000000000000..565633fe18814d --- /dev/null +++ b/docs/source/cuda.tunable.md @@ -0,0 +1,97 @@ +```{eval-rst} +.. currentmodule:: torch.cuda.tunable +``` + +# TunableOp + +## Overview + +```{eval-rst} +.. automodule:: torch.cuda.tunable +``` + +## API Reference + +```{eval-rst} +.. autofunction:: enable +``` + +```{eval-rst} +.. autofunction:: is_enabled +``` + +```{eval-rst} +.. autofunction:: tuning_enable +``` + +```{eval-rst} +.. autofunction:: tuning_is_enabled +``` + +```{eval-rst} +.. autofunction:: record_untuned_enable +``` + +```{eval-rst} +.. autofunction:: record_untuned_is_enabled +``` + +```{eval-rst} +.. autofunction:: set_max_tuning_duration +``` + +```{eval-rst} +.. autofunction:: get_max_tuning_duration +``` + +```{eval-rst} +.. autofunction:: set_max_tuning_iterations +``` + +```{eval-rst} +.. autofunction:: get_max_tuning_iterations +``` + +```{eval-rst} +.. autofunction:: set_filename +``` + +```{eval-rst} +.. autofunction:: get_filename +``` + +```{eval-rst} +.. autofunction:: get_results +``` + +```{eval-rst} +.. autofunction:: get_validators +``` + +```{eval-rst} +.. autofunction:: write_file_on_exit +``` + +```{eval-rst} +.. autofunction:: write_file +``` + +```{eval-rst} +.. autofunction:: read_file +``` + +```{eval-rst} +.. autofunction:: tune_gemm_in_file +``` + +```{eval-rst} +.. autofunction:: mgpu_tune_gemm_in_file +``` + +```{eval-rst} +.. autofunction:: set_rotating_buffer_size +``` + +```{eval-rst} +.. autofunction:: get_rotating_buffer_size +``` diff --git a/docs/source/cuda.tunable.rst b/docs/source/cuda.tunable.rst deleted file mode 100644 index 406871e9b273cf..00000000000000 --- a/docs/source/cuda.tunable.rst +++ /dev/null @@ -1,35 +0,0 @@ -.. currentmodule:: torch.cuda.tunable - -TunableOp -========= - - -Overview --------- - -.. automodule:: torch.cuda.tunable - -API Reference -------------- - -.. autofunction:: enable -.. autofunction:: is_enabled -.. autofunction:: tuning_enable -.. autofunction:: tuning_is_enabled -.. autofunction:: record_untuned_enable -.. autofunction:: record_untuned_is_enabled -.. autofunction:: set_max_tuning_duration -.. autofunction:: get_max_tuning_duration -.. autofunction:: set_max_tuning_iterations -.. autofunction:: get_max_tuning_iterations -.. autofunction:: set_filename -.. autofunction:: get_filename -.. autofunction:: get_results -.. autofunction:: get_validators -.. autofunction:: write_file_on_exit -.. autofunction:: write_file -.. autofunction:: read_file -.. autofunction:: tune_gemm_in_file -.. autofunction:: mgpu_tune_gemm_in_file -.. autofunction:: set_rotating_buffer_size -.. autofunction:: get_rotating_buffer_size diff --git a/docs/source/data.md b/docs/source/data.md new file mode 100644 index 00000000000000..77c9869dd87e0b --- /dev/null +++ b/docs/source/data.md @@ -0,0 +1,532 @@ +# torch.utils.data + +```{eval-rst} +.. automodule:: torch.utils.data +``` + +At the heart of PyTorch data loading utility is the {class}`torch.utils.data.DataLoader` +class. It represents a Python iterable over a dataset, with support for + +- {ref}`map-style and iterable-style datasets `, +- {ref}`customizing data loading order `, +- {ref}`automatic batching `, +- {ref}`single- and multi-process data loading `, +- {ref}`automatic memory pinning `. + +These options are configured by the constructor arguments of a +{class}`~torch.utils.data.DataLoader`, which has signature: + +```python +DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, collate_fn=None, + pin_memory=False, drop_last=False, timeout=0, + worker_init_fn=None, *, prefetch_factor=2, + persistent_workers=False) +``` + +The sections below describe in details the effects and usages of these options. + +(dataset-types)= +## Dataset Types + +The most important argument of {class}`~torch.utils.data.DataLoader` +constructor is {attr}`dataset`, which indicates a dataset object to load data +from. PyTorch supports two different types of datasets: + +- {ref}`map-style-datasets`, +- {ref}`iterable-style-datasets`. + +(map-style-datasets)= +### Map-style datasets + +A map-style dataset is one that implements the {meth}`__getitem__` and +{meth}`__len__` protocols, and represents a map from (possibly non-integral) +indices/keys to data samples. + +For example, such a dataset, when accessed with `dataset[idx]`, could read +the `idx`-th image and its corresponding label from a folder on the disk. + +See {class}`~torch.utils.data.Dataset` for more details. + +(iterable-style-datasets)= +### Iterable-style datasets + +An iterable-style dataset is an instance of a subclass of {class}`~torch.utils.data.IterableDataset` +that implements the {meth}`__iter__` protocol, and represents an iterable over +data samples. This type of datasets is particularly suitable for cases where +random reads are expensive or even improbable, and where the batch size depends +on the fetched data. + +For example, such a dataset, when called `iter(dataset)`, could return a +stream of data reading from a database, a remote server, or even logs generated +in real time. + +See {class}`~torch.utils.data.IterableDataset` for more details. + +:::{note} +When using a {class}`~torch.utils.data.IterableDataset` with +{ref}`multi-process data loading `. The same +dataset object is replicated on each worker process, and thus the +replicas must be configured differently to avoid duplicated data. See +{class}`~torch.utils.data.IterableDataset` documentations for how to +achieve this. +::: + + +(data-loading-order-and-sampler)= +## Data Loading Order and {class}`~torch.utils.data.Sampler` + +For {ref}`iterable-style datasets `, data loading order +is entirely controlled by the user-defined iterable. This allows easier +implementations of chunk-reading and dynamic batch size (e.g., by yielding a +batched sample at each time). + +The rest of this section concerns the case with +{ref}`map-style datasets `. {class}`torch.utils.data.Sampler` +classes are used to specify the sequence of indices/keys used in data loading. +They represent iterable objects over the indices to datasets. E.g., in the +common case with stochastic gradient decent (SGD), a +{class}`~torch.utils.data.Sampler` could randomly permute a list of indices +and yield each one at a time, or yield a small number of them for mini-batch +SGD. + +A sequential or shuffled sampler will be automatically constructed based on the {attr}`shuffle` argument to a {class}`~torch.utils.data.DataLoader`. +Alternatively, users may use the {attr}`sampler` argument to specify a +custom {class}`~torch.utils.data.Sampler` object that at each time yields +the next index/key to fetch. + +A custom {class}`~torch.utils.data.Sampler` that yields a list of batch +indices at a time can be passed as the {attr}`batch_sampler` argument. +Automatic batching can also be enabled via {attr}`batch_size` and +{attr}`drop_last` arguments. See +{ref}`the next section ` for more details +on this. + +:::{note} +Neither {attr}`sampler` nor {attr}`batch_sampler` is compatible with +iterable-style datasets, since such datasets have no notion of a key or an +index. +::: + +(loading-batched-and-non-batched-data)= +## Loading Batched and Non-Batched Data + +{class}`~torch.utils.data.DataLoader` supports automatically collating +individual fetched data samples into batches via arguments +{attr}`batch_size`, {attr}`drop_last`, {attr}`batch_sampler`, and +{attr}`collate_fn` (which has a default function). + +(automatic-batching-default)= +### Automatic batching (default) + +This is the most common case, and corresponds to fetching a minibatch of +data and collating them into batched samples, i.e., containing Tensors with +one dimension being the batch dimension (usually the first). + +When {attr}`batch_size` (default `1`) is not `None`, the data loader yields +batched samples instead of individual samples. {attr}`batch_size` and +{attr}`drop_last` arguments are used to specify how the data loader obtains +batches of dataset keys. For map-style datasets, users can alternatively +specify {attr}`batch_sampler`, which yields a list of keys at a time. + +:::{note} +The {attr}`batch_size` and {attr}`drop_last` arguments essentially are used +to construct a {attr}`batch_sampler` from {attr}`sampler`. For map-style +datasets, the {attr}`sampler` is either provided by user or constructed +based on the {attr}`shuffle` argument. For iterable-style datasets, the +{attr}`sampler` is a dummy infinite one. See +{ref}`this section ` on more details on +samplers. +::: + +:::{note} +When fetching from +{ref}`iterable-style datasets ` with +{ref}`multi-processing ` the {attr}`drop_last` +argument drops the last non-full batch of each worker's dataset replica. +::: + +After fetching a list of samples using the indices from sampler, the function +passed as the {attr}`collate_fn` argument is used to collate lists of samples +into batches. + +In this case, loading from a map-style dataset is roughly equivalent with: + +```python +for indices in batch_sampler: + yield collate_fn([dataset[i] for i in indices]) +``` + +and loading from an iterable-style dataset is roughly equivalent with: + +```python +dataset_iter = iter(dataset) +for indices in batch_sampler: + yield collate_fn([next(dataset_iter) for _ in indices]) +``` + +A custom {attr}`collate_fn` can be used to customize collation, e.g., padding +sequential data to max length of a batch. See +{ref}`this section ` on more about {attr}`collate_fn`. + +(disable-automatic-batching)= +### Disable automatic batching + +In certain cases, users may want to handle batching manually in dataset code, +or simply load individual samples. For example, it could be cheaper to directly +load batched data (e.g., bulk reads from a database or reading continuous +chunks of memory), or the batch size is data dependent, or the program is +designed to work on individual samples. Under these scenarios, it's likely +better to not use automatic batching (where {attr}`collate_fn` is used to +collate the samples), but let the data loader directly return each member of +the {attr}`dataset` object. + +When both {attr}`batch_size` and {attr}`batch_sampler` are `None` (default +value for {attr}`batch_sampler` is already `None`), automatic batching is +disabled. Each sample obtained from the {attr}`dataset` is processed with the +function passed as the {attr}`collate_fn` argument. + +**When automatic batching is disabled**, the default {attr}`collate_fn` simply +converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched. + +In this case, loading from a map-style dataset is roughly equivalent with: + +```python +for index in sampler: + yield collate_fn(dataset[index]) +``` + +and loading from an iterable-style dataset is roughly equivalent with: + +```python +for data in iter(dataset): + yield collate_fn(data) +``` + +See {ref}`this section ` on more about {attr}`collate_fn`. + +(dataloader-collate_fn)= +### Working with {attr}`collate_fn` + +The use of {attr}`collate_fn` is slightly different when automatic batching is +enabled or disabled. + +**When automatic batching is disabled**, {attr}`collate_fn` is called with +each individual data sample, and the output is yielded from the data loader +iterator. In this case, the default {attr}`collate_fn` simply converts NumPy +arrays in PyTorch tensors. + +**When automatic batching is enabled**, {attr}`collate_fn` is called with a list +of data samples at each time. It is expected to collate the input samples into +a batch for yielding from the data loader iterator. The rest of this section +describes the behavior of the default {attr}`collate_fn` +({func}`~torch.utils.data.default_collate`). + +For instance, if each data sample consists of a 3-channel image and an integral +class label, i.e., each element of the dataset returns a tuple +`(image, class_index)`, the default {attr}`collate_fn` collates a list of +such tuples into a single tuple of a batched image tensor and a batched class +label Tensor. In particular, the default {attr}`collate_fn` has the following +properties: + +- It always prepends a new dimension as the batch dimension. +- It automatically converts NumPy arrays and Python numerical values into + PyTorch Tensors. +- It preserves the data structure, e.g., if each sample is a dictionary, it + outputs a dictionary with the same set of keys but batched Tensors as values + (or lists if the values can not be converted into Tensors). Same + for `list` s, `tuple` s, `namedtuple` s, etc. + +Users may use customized {attr}`collate_fn` to achieve custom batching, e.g., +collating along a dimension other than the first, padding sequences of +various lengths, or adding support for custom data types. + +If you run into a situation where the outputs of {class}`~torch.utils.data.DataLoader` +have dimensions or type that is different from your expectation, you may +want to check your {attr}`collate_fn`. + +(single-and-multi-process-data-loading)= +## Single- and Multi-process Data Loading + +A {class}`~torch.utils.data.DataLoader` uses single-process data loading by +default. + +Within a Python process, the +[Global Interpreter Lock (GIL)](https://wiki.python.org/moin/GlobalInterpreterLock) +prevents true fully parallelizing Python code across threads. To avoid blocking +computation code with data loading, PyTorch provides an easy switch to perform +multi-process data loading by simply setting the argument {attr}`num_workers` +to a positive integer. + +(single-process-data-loading-default)= +### Single-process data loading (default) + +In this mode, data fetching is done in the same process a +{class}`~torch.utils.data.DataLoader` is initialized. Therefore, data loading +may block computing. However, this mode may be preferred when resource(s) used +for sharing data among processes (e.g., shared memory, file descriptors) is +limited, or when the entire dataset is small and can be loaded entirely in +memory. Additionally, single-process loading often shows more readable error +traces and thus is useful for debugging. + +(multi-process-data-loading)= +### Multi-process data loading + +Setting the argument {attr}`num_workers` as a positive integer will +turn on multi-process data loading with the specified number of loader worker +processes. + +:::{warning} +After several iterations, the loader worker processes will consume +the same amount of CPU memory as the parent process for all Python +objects in the parent process which are accessed from the worker +processes. This can be problematic if the Dataset contains a lot of +data (e.g., you are loading a very large list of filenames at Dataset +construction time) and/or you are using a lot of workers (overall +memory usage is `number of workers * size of parent process`). The +simplest workaround is to replace Python objects with non-refcounted +representations such as Pandas, Numpy or PyArrow objects. Check out +[issue #13246](https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662) +for more details on why this occurs and example code for how to +workaround these problems. +::: + +In this mode, each time an iterator of a {class}`~torch.utils.data.DataLoader` +is created (e.g., when you call `enumerate(dataloader)`), {attr}`num_workers` +worker processes are created. At this point, the {attr}`dataset`, +{attr}`collate_fn`, and {attr}`worker_init_fn` are passed to each +worker, where they are used to initialize, and fetch data. This means that +dataset access together with its internal IO, transforms +(including {attr}`collate_fn`) runs in the worker process. + +{func}`torch.utils.data.get_worker_info()` returns various useful information +in a worker process (including the worker id, dataset replica, initial seed, +etc.), and returns `None` in main process. Users may use this function in +dataset code and/or {attr}`worker_init_fn` to individually configure each +dataset replica, and to determine whether the code is running in a worker +process. For example, this can be particularly helpful in sharding the dataset. + +For map-style datasets, the main process generates the indices using +{attr}`sampler` and sends them to the workers. So any shuffle randomization is +done in the main process which guides loading by assigning indices to load. + +For iterable-style datasets, since each worker process gets a replica of the +{attr}`dataset` object, naive multi-process loading will often result in +duplicated data. Using {func}`torch.utils.data.get_worker_info()` and/or +{attr}`worker_init_fn`, users may configure each replica independently. (See +{class}`~torch.utils.data.IterableDataset` documentations for how to achieve +this. ) For similar reasons, in multi-process loading, the {attr}`drop_last` +argument drops the last non-full batch of each worker's iterable-style dataset +replica. + +Workers are shut down once the end of the iteration is reached, or when the +iterator becomes garbage collected. + +:::{warning} +It is generally not recommended to return CUDA tensors in multi-process +loading because of many subtleties in using CUDA and sharing CUDA tensors in +multiprocessing (see {ref}`multiprocessing-cuda-note`). Instead, we recommend +using {ref}`automatic memory pinning ` (i.e., setting +{attr}`pin_memory=True`), which enables fast data transfer to CUDA-enabled +GPUs. +::: + +(platform-specific-behaviors)= +#### Platform-specific behaviors + +Since workers rely on Python {py:mod}`multiprocessing`, worker launch behavior is +different on Windows compared to Unix. + +- On Unix, {func}`fork()` is the default {py:mod}`multiprocessing` start method. + Using {func}`fork`, child workers typically can access the {attr}`dataset` and + Python argument functions directly through the cloned address space. +- On Windows or MacOS, {func}`spawn()` is the default {py:mod}`multiprocessing` start method. + Using {func}`spawn()`, another interpreter is launched which runs your main script, + followed by the internal worker function that receives the {attr}`dataset`, + {attr}`collate_fn` and other arguments through {py:mod}`pickle` serialization. + +This separate serialization means that you should take two steps to ensure you +are compatible with Windows while using multi-process data loading: + +- Wrap most of you main script's code within `if __name__ == '__main__':` block, + to make sure it doesn't run again (most likely generating error) when each worker + process is launched. You can place your dataset and {class}`~torch.utils.data.DataLoader` + instance creation logic here, as it doesn't need to be re-executed in workers. +- Make sure that any custom {attr}`collate_fn`, {attr}`worker_init_fn` + or {attr}`dataset` code is declared as top level definitions, outside of the + `__main__` check. This ensures that they are available in worker processes. + (this is needed since functions are pickled as references only, not `bytecode`.) + + +(data-loading-randomness)= +#### Randomness in multi-process data loading + +By default, each worker will have its PyTorch seed set to `base_seed + worker_id`, +where `base_seed` is a long generated by main process using its RNG (thereby, +consuming a RNG state mandatorily) or a specified {attr}`generator`. However, seeds for other +libraries may be duplicated upon initializing workers, causing each worker to return +identical random numbers. (See {ref}`this section ` in FAQ.). + +In {attr}`worker_init_fn`, you may access the PyTorch seed set for each worker +with either {func}`torch.utils.data.get_worker_info().seed ` +or {func}`torch.initial_seed()`, and use it to seed other libraries before data +loading. + + +(memory-pinning)= +## Memory Pinning + +Host to GPU copies are much faster when they originate from pinned (page-locked) +memory. See {ref}`cuda-memory-pinning` for more details on when and how to use +pinned memory generally. + +For data loading, passing {attr}`pin_memory=True` to a +{class}`~torch.utils.data.DataLoader` will automatically put the fetched data +Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled +GPUs. + +The default memory pinning logic only recognizes Tensors and maps and iterables +containing Tensors. By default, if the pinning logic sees a batch that is a +custom type (which will occur if you have a {attr}`collate_fn` that returns a +custom batch type), or if each element of your batch is a custom type, the +pinning logic will not recognize them, and it will return that batch (or those +elements) without pinning the memory. To enable memory pinning for custom +batch or data type(s), define a {meth}`pin_memory` method on your custom +type(s). + +See the example below. + +Example: + +```python +class SimpleCustomBatch: + def __init__(self, data): + transposed_data = list(zip(*data)) + self.inp = torch.stack(transposed_data[0], 0) + self.tgt = torch.stack(transposed_data[1], 0) + + # custom memory pinning method on custom type + def pin_memory(self): + self.inp = self.inp.pin_memory() + self.tgt = self.tgt.pin_memory() + return self + +def collate_wrapper(batch): + return SimpleCustomBatch(batch) + +inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) +tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) +dataset = TensorDataset(inps, tgts) + +loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, + pin_memory=True) + +for batch_ndx, sample in enumerate(loader): + print(sample.inp.is_pinned()) + print(sample.tgt.is_pinned()) +``` + +```{eval-rst} +.. autoclass:: DataLoader +``` + +```{eval-rst} +.. autoclass:: Dataset +``` + +```{eval-rst} +.. autoclass:: IterableDataset +``` + +```{eval-rst} +.. autoclass:: TensorDataset +``` + +```{eval-rst} +.. autoclass:: StackDataset +``` + +```{eval-rst} +.. autoclass:: ConcatDataset +``` + +```{eval-rst} +.. autoclass:: ChainDataset +``` + +```{eval-rst} +.. autoclass:: Subset +``` + +```{eval-rst} +.. autofunction:: torch.utils.data._utils.collate.collate +``` + +```{eval-rst} +.. autofunction:: torch.utils.data.default_collate +``` + +```{eval-rst} +.. autofunction:: torch.utils.data.default_convert +``` + +```{eval-rst} +.. autofunction:: torch.utils.data.get_worker_info +``` + +```{eval-rst} +.. autofunction:: torch.utils.data.random_split +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.Sampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.SequentialSampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.RandomSampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.SubsetRandomSampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.WeightedRandomSampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.BatchSampler +``` + +```{eval-rst} +.. autoclass:: torch.utils.data.distributed.DistributedSampler + +``` + +% These modules are documented as part of torch/data listing them here for + +% now until we have a clearer fix + +```{eval-rst} +.. py:module:: torch.utils.data.datapipes +``` + +```{eval-rst} +.. py:module:: torch.utils.data.datapipes.dataframe +``` + +```{eval-rst} +.. py:module:: torch.utils.data.datapipes.iter +``` + +```{eval-rst} +.. py:module:: torch.utils.data.datapipes.map +``` + +```{eval-rst} +.. py:module:: torch.utils.data.datapipes.utils +``` diff --git a/docs/source/data.rst b/docs/source/data.rst deleted file mode 100644 index 148f02bfc7f986..00000000000000 --- a/docs/source/data.rst +++ /dev/null @@ -1,451 +0,0 @@ -torch.utils.data -=================================== - -.. automodule:: torch.utils.data - -At the heart of PyTorch data loading utility is the :class:`torch.utils.data.DataLoader` -class. It represents a Python iterable over a dataset, with support for - -* `map-style and iterable-style datasets `_, - -* `customizing data loading order `_, - -* `automatic batching `_, - -* `single- and multi-process data loading `_, - -* `automatic memory pinning `_. - -These options are configured by the constructor arguments of a -:class:`~torch.utils.data.DataLoader`, which has signature:: - - DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, - batch_sampler=None, num_workers=0, collate_fn=None, - pin_memory=False, drop_last=False, timeout=0, - worker_init_fn=None, *, prefetch_factor=2, - persistent_workers=False) - -The sections below describe in details the effects and usages of these options. - -Dataset Types -------------- - -The most important argument of :class:`~torch.utils.data.DataLoader` -constructor is :attr:`dataset`, which indicates a dataset object to load data -from. PyTorch supports two different types of datasets: - -* `map-style datasets `_, - -* `iterable-style datasets `_. - -Map-style datasets -^^^^^^^^^^^^^^^^^^ - -A map-style dataset is one that implements the :meth:`__getitem__` and -:meth:`__len__` protocols, and represents a map from (possibly non-integral) -indices/keys to data samples. - -For example, such a dataset, when accessed with ``dataset[idx]``, could read -the ``idx``-th image and its corresponding label from a folder on the disk. - -See :class:`~torch.utils.data.Dataset` for more details. - -Iterable-style datasets -^^^^^^^^^^^^^^^^^^^^^^^ - -An iterable-style dataset is an instance of a subclass of :class:`~torch.utils.data.IterableDataset` -that implements the :meth:`__iter__` protocol, and represents an iterable over -data samples. This type of datasets is particularly suitable for cases where -random reads are expensive or even improbable, and where the batch size depends -on the fetched data. - -For example, such a dataset, when called ``iter(dataset)``, could return a -stream of data reading from a database, a remote server, or even logs generated -in real time. - -See :class:`~torch.utils.data.IterableDataset` for more details. - -.. note:: When using a :class:`~torch.utils.data.IterableDataset` with - `multi-process data loading `_. The same - dataset object is replicated on each worker process, and thus the - replicas must be configured differently to avoid duplicated data. See - :class:`~torch.utils.data.IterableDataset` documentations for how to - achieve this. - -Data Loading Order and :class:`~torch.utils.data.Sampler` ---------------------------------------------------------- - -For `iterable-style datasets `_, data loading order -is entirely controlled by the user-defined iterable. This allows easier -implementations of chunk-reading and dynamic batch size (e.g., by yielding a -batched sample at each time). - -The rest of this section concerns the case with -`map-style datasets `_. :class:`torch.utils.data.Sampler` -classes are used to specify the sequence of indices/keys used in data loading. -They represent iterable objects over the indices to datasets. E.g., in the -common case with stochastic gradient decent (SGD), a -:class:`~torch.utils.data.Sampler` could randomly permute a list of indices -and yield each one at a time, or yield a small number of them for mini-batch -SGD. - -A sequential or shuffled sampler will be automatically constructed based on the :attr:`shuffle` argument to a :class:`~torch.utils.data.DataLoader`. -Alternatively, users may use the :attr:`sampler` argument to specify a -custom :class:`~torch.utils.data.Sampler` object that at each time yields -the next index/key to fetch. - -A custom :class:`~torch.utils.data.Sampler` that yields a list of batch -indices at a time can be passed as the :attr:`batch_sampler` argument. -Automatic batching can also be enabled via :attr:`batch_size` and -:attr:`drop_last` arguments. See -`the next section `_ for more details -on this. - -.. note:: - Neither :attr:`sampler` nor :attr:`batch_sampler` is compatible with - iterable-style datasets, since such datasets have no notion of a key or an - index. - -Loading Batched and Non-Batched Data ------------------------------------- - -:class:`~torch.utils.data.DataLoader` supports automatically collating -individual fetched data samples into batches via arguments -:attr:`batch_size`, :attr:`drop_last`, :attr:`batch_sampler`, and -:attr:`collate_fn` (which has a default function). - - -Automatic batching (default) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This is the most common case, and corresponds to fetching a minibatch of -data and collating them into batched samples, i.e., containing Tensors with -one dimension being the batch dimension (usually the first). - -When :attr:`batch_size` (default ``1``) is not ``None``, the data loader yields -batched samples instead of individual samples. :attr:`batch_size` and -:attr:`drop_last` arguments are used to specify how the data loader obtains -batches of dataset keys. For map-style datasets, users can alternatively -specify :attr:`batch_sampler`, which yields a list of keys at a time. - -.. note:: - The :attr:`batch_size` and :attr:`drop_last` arguments essentially are used - to construct a :attr:`batch_sampler` from :attr:`sampler`. For map-style - datasets, the :attr:`sampler` is either provided by user or constructed - based on the :attr:`shuffle` argument. For iterable-style datasets, the - :attr:`sampler` is a dummy infinite one. See - `this section `_ on more details on - samplers. - -.. note:: - When fetching from - `iterable-style datasets `_ with - `multi-processing `_, the :attr:`drop_last` - argument drops the last non-full batch of each worker's dataset replica. - -After fetching a list of samples using the indices from sampler, the function -passed as the :attr:`collate_fn` argument is used to collate lists of samples -into batches. - -In this case, loading from a map-style dataset is roughly equivalent with:: - - for indices in batch_sampler: - yield collate_fn([dataset[i] for i in indices]) - -and loading from an iterable-style dataset is roughly equivalent with:: - - dataset_iter = iter(dataset) - for indices in batch_sampler: - yield collate_fn([next(dataset_iter) for _ in indices]) - -A custom :attr:`collate_fn` can be used to customize collation, e.g., padding -sequential data to max length of a batch. See -`this section `_ on more about :attr:`collate_fn`. - -Disable automatic batching -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In certain cases, users may want to handle batching manually in dataset code, -or simply load individual samples. For example, it could be cheaper to directly -load batched data (e.g., bulk reads from a database or reading continuous -chunks of memory), or the batch size is data dependent, or the program is -designed to work on individual samples. Under these scenarios, it's likely -better to not use automatic batching (where :attr:`collate_fn` is used to -collate the samples), but let the data loader directly return each member of -the :attr:`dataset` object. - -When both :attr:`batch_size` and :attr:`batch_sampler` are ``None`` (default -value for :attr:`batch_sampler` is already ``None``), automatic batching is -disabled. Each sample obtained from the :attr:`dataset` is processed with the -function passed as the :attr:`collate_fn` argument. - -**When automatic batching is disabled**, the default :attr:`collate_fn` simply -converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched. - -In this case, loading from a map-style dataset is roughly equivalent with:: - - for index in sampler: - yield collate_fn(dataset[index]) - -and loading from an iterable-style dataset is roughly equivalent with:: - - for data in iter(dataset): - yield collate_fn(data) - -See `this section `_ on more about :attr:`collate_fn`. - -.. _dataloader-collate_fn: - -Working with :attr:`collate_fn` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The use of :attr:`collate_fn` is slightly different when automatic batching is -enabled or disabled. - -**When automatic batching is disabled**, :attr:`collate_fn` is called with -each individual data sample, and the output is yielded from the data loader -iterator. In this case, the default :attr:`collate_fn` simply converts NumPy -arrays in PyTorch tensors. - -**When automatic batching is enabled**, :attr:`collate_fn` is called with a list -of data samples at each time. It is expected to collate the input samples into -a batch for yielding from the data loader iterator. The rest of this section -describes the behavior of the default :attr:`collate_fn` -(:func:`~torch.utils.data.default_collate`). - -For instance, if each data sample consists of a 3-channel image and an integral -class label, i.e., each element of the dataset returns a tuple -``(image, class_index)``, the default :attr:`collate_fn` collates a list of -such tuples into a single tuple of a batched image tensor and a batched class -label Tensor. In particular, the default :attr:`collate_fn` has the following -properties: - -* It always prepends a new dimension as the batch dimension. - -* It automatically converts NumPy arrays and Python numerical values into - PyTorch Tensors. - -* It preserves the data structure, e.g., if each sample is a dictionary, it - outputs a dictionary with the same set of keys but batched Tensors as values - (or lists if the values can not be converted into Tensors). Same - for ``list`` s, ``tuple`` s, ``namedtuple`` s, etc. - -Users may use customized :attr:`collate_fn` to achieve custom batching, e.g., -collating along a dimension other than the first, padding sequences of -various lengths, or adding support for custom data types. - -If you run into a situation where the outputs of :class:`~torch.utils.data.DataLoader` -have dimensions or type that is different from your expectation, you may -want to check your :attr:`collate_fn`. - -Single- and Multi-process Data Loading --------------------------------------- - -A :class:`~torch.utils.data.DataLoader` uses single-process data loading by -default. - -Within a Python process, the -`Global Interpreter Lock (GIL) `_ -prevents true fully parallelizing Python code across threads. To avoid blocking -computation code with data loading, PyTorch provides an easy switch to perform -multi-process data loading by simply setting the argument :attr:`num_workers` -to a positive integer. - -Single-process data loading (default) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In this mode, data fetching is done in the same process a -:class:`~torch.utils.data.DataLoader` is initialized. Therefore, data loading -may block computing. However, this mode may be preferred when resource(s) used -for sharing data among processes (e.g., shared memory, file descriptors) is -limited, or when the entire dataset is small and can be loaded entirely in -memory. Additionally, single-process loading often shows more readable error -traces and thus is useful for debugging. - - -Multi-process data loading -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Setting the argument :attr:`num_workers` as a positive integer will -turn on multi-process data loading with the specified number of loader worker -processes. - -.. warning:: - After several iterations, the loader worker processes will consume - the same amount of CPU memory as the parent process for all Python - objects in the parent process which are accessed from the worker - processes. This can be problematic if the Dataset contains a lot of - data (e.g., you are loading a very large list of filenames at Dataset - construction time) and/or you are using a lot of workers (overall - memory usage is ``number of workers * size of parent process``). The - simplest workaround is to replace Python objects with non-refcounted - representations such as Pandas, Numpy or PyArrow objects. Check out - `issue #13246 - `_ - for more details on why this occurs and example code for how to - workaround these problems. - -In this mode, each time an iterator of a :class:`~torch.utils.data.DataLoader` -is created (e.g., when you call ``enumerate(dataloader)``), :attr:`num_workers` -worker processes are created. At this point, the :attr:`dataset`, -:attr:`collate_fn`, and :attr:`worker_init_fn` are passed to each -worker, where they are used to initialize, and fetch data. This means that -dataset access together with its internal IO, transforms -(including :attr:`collate_fn`) runs in the worker process. - -:func:`torch.utils.data.get_worker_info()` returns various useful information -in a worker process (including the worker id, dataset replica, initial seed, -etc.), and returns ``None`` in main process. Users may use this function in -dataset code and/or :attr:`worker_init_fn` to individually configure each -dataset replica, and to determine whether the code is running in a worker -process. For example, this can be particularly helpful in sharding the dataset. - -For map-style datasets, the main process generates the indices using -:attr:`sampler` and sends them to the workers. So any shuffle randomization is -done in the main process which guides loading by assigning indices to load. - -For iterable-style datasets, since each worker process gets a replica of the -:attr:`dataset` object, naive multi-process loading will often result in -duplicated data. Using :func:`torch.utils.data.get_worker_info()` and/or -:attr:`worker_init_fn`, users may configure each replica independently. (See -:class:`~torch.utils.data.IterableDataset` documentations for how to achieve -this. ) For similar reasons, in multi-process loading, the :attr:`drop_last` -argument drops the last non-full batch of each worker's iterable-style dataset -replica. - -Workers are shut down once the end of the iteration is reached, or when the -iterator becomes garbage collected. - -.. warning:: - It is generally not recommended to return CUDA tensors in multi-process - loading because of many subtleties in using CUDA and sharing CUDA tensors in - multiprocessing (see :ref:`multiprocessing-cuda-note`). Instead, we recommend - using `automatic memory pinning `_ (i.e., setting - :attr:`pin_memory=True`), which enables fast data transfer to CUDA-enabled - GPUs. - -Platform-specific behaviors -""""""""""""""""""""""""""" - -Since workers rely on Python :py:mod:`multiprocessing`, worker launch behavior is -different on Windows compared to Unix. - -* On Unix, :func:`fork()` is the default :py:mod:`multiprocessing` start method. - Using :func:`fork`, child workers typically can access the :attr:`dataset` and - Python argument functions directly through the cloned address space. - -* On Windows or MacOS, :func:`spawn()` is the default :py:mod:`multiprocessing` start method. - Using :func:`spawn()`, another interpreter is launched which runs your main script, - followed by the internal worker function that receives the :attr:`dataset`, - :attr:`collate_fn` and other arguments through :py:mod:`pickle` serialization. - -This separate serialization means that you should take two steps to ensure you -are compatible with Windows while using multi-process data loading: - -- Wrap most of you main script's code within ``if __name__ == '__main__':`` block, - to make sure it doesn't run again (most likely generating error) when each worker - process is launched. You can place your dataset and :class:`~torch.utils.data.DataLoader` - instance creation logic here, as it doesn't need to be re-executed in workers. - -- Make sure that any custom :attr:`collate_fn`, :attr:`worker_init_fn` - or :attr:`dataset` code is declared as top level definitions, outside of the - ``__main__`` check. This ensures that they are available in worker processes. - (this is needed since functions are pickled as references only, not ``bytecode``.) - -.. _data-loading-randomness: - -Randomness in multi-process data loading -"""""""""""""""""""""""""""""""""""""""""" - -By default, each worker will have its PyTorch seed set to ``base_seed + worker_id``, -where ``base_seed`` is a long generated by main process using its RNG (thereby, -consuming a RNG state mandatorily) or a specified :attr:`generator`. However, seeds for other -libraries may be duplicated upon initializing workers, causing each worker to return -identical random numbers. (See :ref:`this section ` in FAQ.). - -In :attr:`worker_init_fn`, you may access the PyTorch seed set for each worker -with either :func:`torch.utils.data.get_worker_info().seed ` -or :func:`torch.initial_seed()`, and use it to seed other libraries before data -loading. - -Memory Pinning --------------- - -Host to GPU copies are much faster when they originate from pinned (page-locked) -memory. See :ref:`cuda-memory-pinning` for more details on when and how to use -pinned memory generally. - -For data loading, passing :attr:`pin_memory=True` to a -:class:`~torch.utils.data.DataLoader` will automatically put the fetched data -Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled -GPUs. - -The default memory pinning logic only recognizes Tensors and maps and iterables -containing Tensors. By default, if the pinning logic sees a batch that is a -custom type (which will occur if you have a :attr:`collate_fn` that returns a -custom batch type), or if each element of your batch is a custom type, the -pinning logic will not recognize them, and it will return that batch (or those -elements) without pinning the memory. To enable memory pinning for custom -batch or data type(s), define a :meth:`pin_memory` method on your custom -type(s). - -See the example below. - -Example:: - - class SimpleCustomBatch: - def __init__(self, data): - transposed_data = list(zip(*data)) - self.inp = torch.stack(transposed_data[0], 0) - self.tgt = torch.stack(transposed_data[1], 0) - - # custom memory pinning method on custom type - def pin_memory(self): - self.inp = self.inp.pin_memory() - self.tgt = self.tgt.pin_memory() - return self - - def collate_wrapper(batch): - return SimpleCustomBatch(batch) - - inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) - tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) - dataset = TensorDataset(inps, tgts) - - loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, - pin_memory=True) - - for batch_ndx, sample in enumerate(loader): - print(sample.inp.is_pinned()) - print(sample.tgt.is_pinned()) - - -.. autoclass:: DataLoader -.. autoclass:: Dataset -.. autoclass:: IterableDataset -.. autoclass:: TensorDataset -.. autoclass:: StackDataset -.. autoclass:: ConcatDataset -.. autoclass:: ChainDataset -.. autoclass:: Subset -.. autofunction:: torch.utils.data._utils.collate.collate -.. autofunction:: torch.utils.data.default_collate -.. autofunction:: torch.utils.data.default_convert -.. autofunction:: torch.utils.data.get_worker_info -.. autofunction:: torch.utils.data.random_split -.. autoclass:: torch.utils.data.Sampler -.. autoclass:: torch.utils.data.SequentialSampler -.. autoclass:: torch.utils.data.RandomSampler -.. autoclass:: torch.utils.data.SubsetRandomSampler -.. autoclass:: torch.utils.data.WeightedRandomSampler -.. autoclass:: torch.utils.data.BatchSampler -.. autoclass:: torch.utils.data.distributed.DistributedSampler - - -.. These modules are documented as part of torch/data listing them here for -.. now until we have a clearer fix -.. py:module:: torch.utils.data.datapipes -.. py:module:: torch.utils.data.datapipes.dataframe -.. py:module:: torch.utils.data.datapipes.iter -.. py:module:: torch.utils.data.datapipes.map -.. py:module:: torch.utils.data.datapipes.utils diff --git a/docs/source/ddp_comm_hooks.md b/docs/source/ddp_comm_hooks.md new file mode 100644 index 00000000000000..059c388cd003a0 --- /dev/null +++ b/docs/source/ddp_comm_hooks.md @@ -0,0 +1,218 @@ +# DDP Communication Hooks + +DDP communication hook is a generic interface to control how to communicate +gradients across workers by overriding the vanilla allreduce in +[DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.). +A few built-in communication hooks are provided, +and users can easily apply any of these hooks to optimize communication. +Besides, the hook interface can also support user-defined communication +strategies for more advanced use cases. + +## How to Use a Communication Hook? + +To use a communication hook, the user just needs to let the DDP model register +the hook before the training loop as below. + +{func}`torch.nn.parallel.DistributedDataParallel.register_comm_hook` + +## What Does a Communication Hook Operate On? + +A communication hook provides a flexible way to allreduce gradients. +Therefore, it mainly operates on the gradients on each replica before allreduce, +which are bucketized to increase the overlap between communication and computation. +Particularly, {class}`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced. + +```{eval-rst} +.. autoclass:: torch.distributed.GradBucket + +.. autofunction:: torch.distributed.GradBucket.index +.. autofunction:: torch.distributed.GradBucket.buffer +.. autofunction:: torch.distributed.GradBucket.gradients +.. autofunction:: torch.distributed.GradBucket.is_last +.. autofunction:: torch.distributed.GradBucket.set_buffer +.. autofunction:: torch.distributed.GradBucket.parameters +``` + +## Default Communication Hooks + +Default communication hooks are simple **stateless** hooks, so the input state +in `register_comm_hook` is either a process group or `None`. +The input `bucket` is a {class}`torch.distributed.GradBucket` object. + +```{eval-rst} +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks +.. autofunction:: allreduce_hook +.. autofunction:: fp16_compress_hook +.. autofunction:: bf16_compress_hook +``` + +Additionally, a communication hook wrapper is provided to support {meth}`~fp16_compress_hook` or {meth}`~bf16_compress_hook` as a wrapper, +which can be combined with other communication hooks. + +```{eval-rst} +.. autofunction:: fp16_compress_wrapper +.. autofunction:: bf16_compress_wrapper +``` +## PowerSGD Communication Hook + +PowerSGD ([Vogels et al., NeurIPS 2019](https://arxiv.org/abs/1905.13727)) +is a gradient compression algorithm, which can provide very high compression +rates and accelerate bandwidth-bound distributed training. +This algorithm needs to maintain both some hyperparameters and the internal +state. Therefore, PowerSGD communication hook is a **stateful** hook, +and the user needs to provide a state object defined as below. + +### PowerSGD State + +```{eval-rst} +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook +.. autoclass:: PowerSGDState +``` +### PowerSGD Hooks + +```{warning} +PowerSGD typically requires extra memory of the same size as the model's +gradients to enable error feedback, which can compensate for biased +compressed communication and improve accuracy. +``` + +```{warning} +PowerSGD hooks may conflict with [Apex automatic mixed precision package](https://github.com/NVIDIA/apex). +Please use PyTorch [native automatic mixed precision package](https://pytorch.org/docs/stable/amp.html) +instead. +``` + +```{eval-rst} +.. autofunction:: powerSGD_hook +.. autofunction:: batched_powerSGD_hook +``` +## Debugging Communication Hooks + +As the name implies, debugging communication hooks are **only** used for debugging and performance optimization purpose. +```{eval-rst} +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks +``` +```{warning} +Debugging communication hooks do not necessarily output the correct results. +``` +```{eval-rst} +.. autofunction:: noop_hook +``` +## Checkpointing of Communication Hooks + +```{eval-rst} +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook +``` +A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. +To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined. + +```{warning} +`__getstate__` should exclude non-serializable attributes from a returned dictionary. +``` +```{warning} +`__setstate__` should properly initialize non-serializable attributes, excluded from a provided `state`. +``` +{class}`PowerSGDState` has `__setstate__` and `__getstate__` implemented and can be used as a reference. + +```{eval-rst} +.. class:: PowerSGDState + :noindex: + + .. automethod:: PowerSGDState.__getstate__ + .. automethod:: PowerSGDState.__setstate__ +``` +Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook. + +```python + +import os +import sys +import tempfile +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import torch.multiprocessing as mp + +from torch.nn.parallel import DistributedDataParallel +from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD + +class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(24,24) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(24,12) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def run_demo(demo_fn, world_size): + mp.spawn( + demo_fn, + args=(world_size,), + nprocs=world_size, + join=True) + +def demo_serialization(rank, world_size): + setup(rank, world_size) + + CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt" + + model = SimpleModel().to(rank) + ddp_model = DistributedDataParallel(model, device_ids=[rank]) + + powersgd_hook = powerSGD.powerSGD_hook + powersgd_state = powerSGD.PowerSGDState(process_group=None) + + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + ddp_model.register_comm_hook(powersgd_state, powersgd_hook) + + state = { + 'state_dict': ddp_model.state_dict(), + 'comm_hook': powersgd_hook, + 'comm_hook_state': powersgd_state} + + if rank == 0: + torch.save(state, CHECKPOINT) + + dist.barrier() + map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} + checkpoint = torch.load(CHECKPOINT, map_location=map_location) + + new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank]) + new_ddp_model.load_state_dict(checkpoint['state_dict']) + powersgd_hook = checkpoint['comm_hook'] + powersgd_state = checkpoint['comm_hook_state'] + + new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook) + + if rank == 0: + os.remove(CHECKPOINT) + + cleanup() + +if __name__ == "__main__": + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + run_demo(demo_serialization, world_size) +``` + +## Acknowledgements + +Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on +PowerSGD communication hook, as well as the +[comparison experiments](https://observablehq.com/@tvogels/powersgd-benchmark), +which show that the performance of PowerSGD communication hook is on par with +the implementation in the original [paper](https://arxiv.org/abs/1905.13727). diff --git a/docs/source/ddp_comm_hooks.rst b/docs/source/ddp_comm_hooks.rst deleted file mode 100644 index 204918361dc657..00000000000000 --- a/docs/source/ddp_comm_hooks.rst +++ /dev/null @@ -1,215 +0,0 @@ -DDP Communication Hooks -======================= - -DDP communication hook is a generic interface to control how to communicate -gradients across workers by overriding the vanilla allreduce in -`DistributedDataParallel `_. -A few built-in communication hooks are provided, -and users can easily apply any of these hooks to optimize communication. -Besides, the hook interface can also support user-defined communication -strategies for more advanced use cases. - -How to Use a Communication Hook? --------------------------------- - -To use a communication hook, the user just needs to let the DDP model register -the hook before the training loop as below. - -:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook` - -What Does a Communication Hook Operate On? ------------------------------------------- - -A communication hook provides a flexible way to allreduce gradients. -Therefore, it mainly operates on the gradients on each replica before allreduce, -which are bucketized to increase the overlap between communication and computation. -Particularly, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced. - -.. autoclass:: torch.distributed.GradBucket - -.. autofunction:: torch.distributed.GradBucket.index -.. autofunction:: torch.distributed.GradBucket.buffer -.. autofunction:: torch.distributed.GradBucket.gradients -.. autofunction:: torch.distributed.GradBucket.is_last -.. autofunction:: torch.distributed.GradBucket.set_buffer -.. autofunction:: torch.distributed.GradBucket.parameters - -Default Communication Hooks ---------------------------- - -Default communication hooks are simple **stateless** hooks, so the input state -in ``register_comm_hook`` is either a process group or ``None``. -The input ``bucket`` is a :class:`torch.distributed.GradBucket` object. - -.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks -.. autofunction:: allreduce_hook -.. autofunction:: fp16_compress_hook -.. autofunction:: bf16_compress_hook - -Additionally, a communication hook wrapper is provided to support :meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` as a wrapper, -which can be combined with other communication hooks. - -.. autofunction:: fp16_compress_wrapper -.. autofunction:: bf16_compress_wrapper - -PowerSGD Communication Hook ---------------------------- - -PowerSGD (`Vogels et al., NeurIPS 2019 `_) -is a gradient compression algorithm, which can provide very high compression -rates and accelerate bandwidth-bound distributed training. -This algorithm needs to maintain both some hyperparameters and the internal -state. Therefore, PowerSGD communication hook is a **stateful** hook, -and the user needs to provide a state object defined as below. - -PowerSGD State -^^^^^^^^^^^^^^^^ - -.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook -.. autoclass:: PowerSGDState - -PowerSGD Hooks -^^^^^^^^^^^^^^^^ - -.. warning :: - PowerSGD typically requires extra memory of the same size as the model's - gradients to enable error feedback, which can compensate for biased - compressed communication and improve accuracy. - -.. warning :: - PowerSGD hooks may conflict with `Apex automatic mixed precision package `_. - Please use PyTorch `native automatic mixed precision package `_ - instead. - -.. autofunction:: powerSGD_hook -.. autofunction:: batched_powerSGD_hook - -Debugging Communication Hooks ------------------------------ - -As the name implies, debugging communication hooks are **only** used for debugging and performance optimization purpose. - -.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks - -.. warning :: - Debugging communication hooks do not necessarily output the correct results. - -.. autofunction:: noop_hook - -Checkpointing of Communication Hooks ------------------------------------- - -.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook - -A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts. -To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined. - -.. warning :: - ``__getstate__`` should exclude non-serializable attributes from a returned dictionary. - -.. warning :: - ``__setstate__`` should properly initialize non-serializable attributes, excluded from a provided ``state``. - -:class:`PowerSGDState` has ``__setstate__`` and ``__getstate__`` implemented and can be used as a reference. - -.. class:: PowerSGDState - :noindex: - - .. automethod:: PowerSGDState.__getstate__ - .. automethod:: PowerSGDState.__setstate__ - -Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook. - -:: - - import os - import sys - import tempfile - import torch - import torch.distributed as dist - import torch.nn as nn - import torch.optim as optim - import torch.multiprocessing as mp - - from torch.nn.parallel import DistributedDataParallel - from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD - - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(24,24) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(24,12) - - def forward(self, x): - return self.fc2(self.relu(self.fc1(x))) - - def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - - def cleanup(): - dist.destroy_process_group() - - def run_demo(demo_fn, world_size): - mp.spawn( - demo_fn, - args=(world_size,), - nprocs=world_size, - join=True) - - def demo_serialization(rank, world_size): - setup(rank, world_size) - - CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt" - - model = SimpleModel().to(rank) - ddp_model = DistributedDataParallel(model, device_ids=[rank]) - - powersgd_hook = powerSGD.powerSGD_hook - powersgd_state = powerSGD.PowerSGDState(process_group=None) - - optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) - ddp_model.register_comm_hook(powersgd_state, powersgd_hook) - - state = { - 'state_dict': ddp_model.state_dict(), - 'comm_hook': powersgd_hook, - 'comm_hook_state': powersgd_state} - - if rank == 0: - torch.save(state, CHECKPOINT) - - dist.barrier() - map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} - checkpoint = torch.load(CHECKPOINT, map_location=map_location) - - new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank]) - new_ddp_model.load_state_dict(checkpoint['state_dict']) - powersgd_hook = checkpoint['comm_hook'] - powersgd_state = checkpoint['comm_hook_state'] - - new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook) - - if rank == 0: - os.remove(CHECKPOINT) - - cleanup() - - if __name__ == "__main__": - n_gpus = torch.cuda.device_count() - assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" - world_size = n_gpus - run_demo(demo_serialization, world_size) - -Acknowledgements ----------------- - -Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on -PowerSGD communication hook, as well as the -`comparison experiments `_, -which show that the performance of PowerSGD communication hook is on par with -the implementation in the original `paper `_. diff --git a/docs/source/debugging_environment_variables.md b/docs/source/debugging_environment_variables.md new file mode 100644 index 00000000000000..5e9f8b55359b6a --- /dev/null +++ b/docs/source/debugging_environment_variables.md @@ -0,0 +1,14 @@ +(debugging_environment_variables)= +# Debugging Environment Variables + +:::{list-table} + :header-rows: 1 + + * - Variable + - Description + * - ``TORCH_SHOW_CPP_STACKTRACES`` + - If set to ``1``, makes PyTorch print out a stack trace when it detects a C++ error. + * - ``TORCH_CPP_LOG_LEVEL`` + - Set the log level of c10 logging facility (supports both GLOG and c10 loggers). Valid values are ``INFO``, ``WARNING``, ``ERROR``, and ``FATAL`` or their numerical equivalents ``0``, ``1``, ``2``, and ``3``. + * - ``TORCH_LOGS`` + - For a more in depth explanation of this environment variable, see {doc}`/logging`. \ No newline at end of file diff --git a/docs/source/debugging_environment_variables.rst b/docs/source/debugging_environment_variables.rst deleted file mode 100644 index ebc38d454ebd52..00000000000000 --- a/docs/source/debugging_environment_variables.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. _debugging_environment_variables: - -Debugging Environment Variables -=============================== -.. list-table:: - :header-rows: 1 - - * - Variable - - Description - * - ``TORCH_SHOW_CPP_STACKTRACES`` - - If set to ``1``, makes PyTorch print out a stack trace when it detects a C++ error. - * - ``TORCH_CPP_LOG_LEVEL`` - - Set the log level of c10 logging facility (supports both GLOG and c10 loggers). Valid values are ``INFO``, ``WARNING``, ``ERROR``, and ``FATAL`` or their numerical equivalents ``0``, ``1``, ``2``, and ``3``. - * - ``TORCH_LOGS`` - - For a more in depth explanation of this environment variable, see :doc:`/logging`. \ No newline at end of file diff --git a/docs/source/deploy.md b/docs/source/deploy.md new file mode 100644 index 00000000000000..ef5131717bf7b9 --- /dev/null +++ b/docs/source/deploy.md @@ -0,0 +1,8 @@ +--- +orphan: true +--- + +# torch::deploy has been moved to pytorch/multipy + + +``torch::deploy`` has been moved to its new home at [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy). diff --git a/docs/source/deploy.rst b/docs/source/deploy.rst deleted file mode 100644 index 8f076d8c57ac10..00000000000000 --- a/docs/source/deploy.rst +++ /dev/null @@ -1,6 +0,0 @@ -:orphan: - -torch::deploy has been moved to pytorch/multipy -=============================================== - -``torch::deploy`` has been moved to its new home at `https://github.com/pytorch/multipy `_. diff --git a/docs/source/deterministic.md b/docs/source/deterministic.md new file mode 100644 index 00000000000000..c979cd3a1ca6d4 --- /dev/null +++ b/docs/source/deterministic.md @@ -0,0 +1,30 @@ +# torch.utils.deterministic + +```{eval-rst} +.. py:module:: torch.utils.deterministic +.. currentmodule:: torch.utils.deterministic + +.. attribute:: fill_uninitialized_memory + + A :class:`bool` that, if True, causes uninitialized memory to be filled with + a known value when :meth:`torch.use_deterministic_algorithms()` is set to + ``True``. Floating point and complex values are set to NaN, and integer + values are set to the maximum value. + + Default: ``True`` + + Filling uninitialized memory is detrimental to performance. So if your + program is valid and does not use uninitialized memory as the input to an + operation, then this setting can be turned off for better performance and + still be deterministic. + + The following operations will fill uninitialized memory when this setting is + turned on: + + * :func:`torch.Tensor.resize_` when called with a tensor that is not + quantized + * :func:`torch.empty` + * :func:`torch.empty_strided` + * :func:`torch.empty_permuted` + * :func:`torch.empty_like` +``` \ No newline at end of file diff --git a/docs/source/deterministic.rst b/docs/source/deterministic.rst deleted file mode 100644 index 50390ddbdaed22..00000000000000 --- a/docs/source/deterministic.rst +++ /dev/null @@ -1,28 +0,0 @@ -torch.utils.deterministic -========================= -.. py:module:: torch.utils.deterministic -.. currentmodule:: torch.utils.deterministic - -.. attribute:: fill_uninitialized_memory - - A :class:`bool` that, if True, causes uninitialized memory to be filled with - a known value when :meth:`torch.use_deterministic_algorithms()` is set to - ``True``. Floating point and complex values are set to NaN, and integer - values are set to the maximum value. - - Default: ``True`` - - Filling uninitialized memory is detrimental to performance. So if your - program is valid and does not use uninitialized memory as the input to an - operation, then this setting can be turned off for better performance and - still be deterministic. - - The following operations will fill uninitialized memory when this setting is - turned on: - - * :func:`torch.Tensor.resize_` when called with a tensor that is not - quantized - * :func:`torch.empty` - * :func:`torch.empty_strided` - * :func:`torch.empty_permuted` - * :func:`torch.empty_like` \ No newline at end of file diff --git a/docs/source/distributed.algorithms.join.md b/docs/source/distributed.algorithms.join.md new file mode 100644 index 00000000000000..c8c661557dcedd --- /dev/null +++ b/docs/source/distributed.algorithms.join.md @@ -0,0 +1,24 @@ +```{role} hidden +--- +class: hidden-section +--- + +``` +# Generic Join Context Manager + +The generic join context manager facilitates distributed training on uneven +inputs. This page outlines the API of the relevant classes: {class}`Join`, +{class}`Joinable`, and {class}`JoinHook`. For a tutorial, see +[Distributed Training with Uneven Inputs Using the Join Context Manager](https://pytorch.org/tutorials/advanced/generic_join.html). + +```{eval-rst} +.. autoclass:: torch.distributed.algorithms.Join + :members: + +.. autoclass:: torch.distributed.algorithms.Joinable + :members: + +.. autoclass:: torch.distributed.algorithms.JoinHook + :members: + +``` \ No newline at end of file diff --git a/docs/source/distributed.algorithms.join.rst b/docs/source/distributed.algorithms.join.rst deleted file mode 100644 index 9ef1f6cf58763f..00000000000000 --- a/docs/source/distributed.algorithms.join.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Generic Join Context Manager -============================ -The generic join context manager facilitates distributed training on uneven -inputs. This page outlines the API of the relevant classes: :class:`Join`, -:class:`Joinable`, and :class:`JoinHook`. For a tutorial, see -`Distributed Training with Uneven Inputs Using the Join Context Manager`_. - -.. autoclass:: torch.distributed.algorithms.Join - :members: - -.. autoclass:: torch.distributed.algorithms.Joinable - :members: - -.. autoclass:: torch.distributed.algorithms.JoinHook - :members: - -.. _Distributed Training with Uneven Inputs Using the Join Context Manager: https://pytorch.org/tutorials/advanced/generic_join.html diff --git a/docs/source/distributed.checkpoint.md b/docs/source/distributed.checkpoint.md new file mode 100644 index 00000000000000..694dfef1098a19 --- /dev/null +++ b/docs/source/distributed.checkpoint.md @@ -0,0 +1,269 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Distributed Checkpoint - torch.distributed.checkpoint + +Distributed Checkpoint (DCP) support loading and saving models from multiple ranks in parallel. +It handles load-time resharding which enables saving in one cluster topology and loading into another. + +DCP is different than `torch.save` and `torch.load` in a few significant ways: + +- It produces multiple files per checkpoint, with at least one per rank. +- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead. + +The entrypoints to load and save a checkpoint are the following: + +## Additional resources: + +- [Getting Started with Distributed Checkpoint (DCP)](https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html) +- [Asynchronous Saving with Distributed Checkpoint (DCP)](https://pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html) +- [TorchTitan Checkpointing Docs](https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md) +- [TorchTitan DCP Implementation](https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py) + +```{eval-rst} +.. automodule:: torch.distributed.checkpoint +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.checkpoint.state_dict_saver +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse + :members: +``` + +```{eval-rst} +.. autofunction:: save +``` + +```{eval-rst} +.. autofunction:: async_save +``` + +```{eval-rst} +.. autofunction:: save_state_dict +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.checkpoint.state_dict_loader +``` + +```{eval-rst} +.. autofunction:: load +``` + +```{eval-rst} +.. autofunction:: load_state_dict +``` + +The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`): + +```{eval-rst} +.. automodule:: torch.distributed.checkpoint.staging +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.staging.AsyncStager + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.staging.DefaultStager + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.staging.StagingOptions + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager + :members: +``` + +In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading + +```{eval-rst} +.. automodule:: torch.distributed.checkpoint.stateful + :noindex: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.stateful.Stateful + :members: +``` + +This [example](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py) shows how to use Pytorch Distributed Checkpoint to save a FSDP model. + +The following types define the IO interface used during checkpoint: + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.StorageReader + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.StorageWriter + :members: +``` + +The following types define the planner interface used during checkpoint: + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.LoadPlanner + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.LoadPlan + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.ReadItem + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.SavePlanner + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.SavePlan + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.planner.WriteItem + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.planner.BytesIOWriteData + :members: +``` + +We provide a filesystem based storage layer: + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.FileSystemReader + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.FileSystemWriter + :members: +``` + +We also provide other storage layers, including ones to interact with HuggingFace safetensors: + +.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageReader + :members: + +.. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter + :members: + +We provide default implementations of `LoadPlanner` and `SavePlanner` that +can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor. + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.DefaultSavePlanner + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.DefaultLoadPlanner + :members: + +``` + +Due to legacy design decisions, the state dictionaries of `FSDP` and `DDP` may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, `FSDP` offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism). + +To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. `get_model_state_dict()` returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, `get_optimizer_state_dict()` provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, `get_optimizer_state_dict()` converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary. + +Note that results returned by these APIs can be used directly with the `torch.distributed.checkpoint.save()` and `torch.distributed.checkpoint.load()` methods without requiring any additional conversions. + +`set_model_state_dict()` and `set_optimizer_state_dict()` are provided to load the model and optimizer state_dict generated by by their respective getter APIs. + +Note that `set_optimizer_state_dict()` can only be called before `backward()` or after `step()` is called on optimizers. + +Note that this feature is experimental, and API signatures might change in the future. + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.get_state_dict +``` + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.get_model_state_dict +``` + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.get_optimizer_state_dict +``` + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.set_state_dict +``` + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.set_model_state_dict +``` + +```{eval-rst} +.. autofunction:: torch.distributed.checkpoint.state_dict.set_optimizer_state_dict +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions + :members: +``` + +For users which are used to using and sharing models in the `torch.save` format, the following methods are provided which provide offline utilities for converting betweeing formats. + +```{eval-rst} +.. automodule:: torch.distributed.checkpoint.format_utils +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.checkpoint.format_utils +``` + +```{eval-rst} +.. autofunction:: dcp_to_torch_save +``` + +```{eval-rst} +.. autofunction:: torch_save_to_dcp +``` + +The following classes can also be utilized for online loading and resharding of models from the torch.save format. + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner + :members: +``` + +The following experimental interfaces are provided for improved observability in production environments: + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.logger +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.logging_handlers +``` diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst deleted file mode 100644 index 909e836e57478a..00000000000000 --- a/docs/source/distributed.checkpoint.rst +++ /dev/null @@ -1,157 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Distributed Checkpoint - torch.distributed.checkpoint -===================================================== - - -Distributed Checkpoint (DCP) support loading and saving models from multiple ranks in parallel. -It handles load-time resharding which enables saving in one cluster topology and loading into another. - -DCP is different than `torch.save` and `torch.load` in a few significant ways: - -* It produces multiple files per checkpoint, with at least one per rank. -* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead. - -The entrypoints to load and save a checkpoint are the following: - -Additional resources: ---------------------- - -* `Getting Started with Distributed Checkpoint (DCP) `__ -* `Asynchronous Saving with Distributed Checkpoint (DCP) `__ -* `TorchTitan Checkpointing Docs `__ -* `TorchTitan DCP Implementation `__ - -.. automodule:: torch.distributed.checkpoint - -.. currentmodule:: torch.distributed.checkpoint.state_dict_saver - -.. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType - :members: - -.. autofunction:: save -.. autofunction:: async_save -.. autofunction:: save_state_dict - -.. currentmodule:: torch.distributed.checkpoint.state_dict_loader - -.. autofunction:: load -.. autofunction:: load_state_dict - -The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`): - -.. automodule:: torch.distributed.checkpoint.staging - -.. autoclass:: torch.distributed.checkpoint.staging.AsyncStager - :members: - -.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager - :members: - -In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading -.. automodule:: torch.distributed.checkpoint.stateful - -.. autoclass:: torch.distributed.checkpoint.stateful.Stateful - :members: - -This `example `_ shows how to use Pytorch Distributed Checkpoint to save a FSDP model. - -The following types define the IO interface used during checkpoint: - -.. autoclass:: torch.distributed.checkpoint.StorageReader - :members: - -.. autoclass:: torch.distributed.checkpoint.StorageWriter - :members: - -The following types define the planner interface used during checkpoint: - -.. autoclass:: torch.distributed.checkpoint.LoadPlanner - :members: - -.. autoclass:: torch.distributed.checkpoint.LoadPlan - :members: - -.. autoclass:: torch.distributed.checkpoint.ReadItem - :members: - -.. autoclass:: torch.distributed.checkpoint.SavePlanner - :members: - -.. autoclass:: torch.distributed.checkpoint.SavePlan - :members: - -.. autoclass:: torch.distributed.checkpoint.planner.WriteItem - :members: - -.. autoclass:: torch.distributed.checkpoint.planner.BytesIOWriteData - :members: - -We provide a filesystem based storage layer: - -.. autoclass:: torch.distributed.checkpoint.FileSystemReader - :members: - -.. autoclass:: torch.distributed.checkpoint.FileSystemWriter - :members: - -We provide default implementations of `LoadPlanner` and `SavePlanner` that -can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor. - -.. autoclass:: torch.distributed.checkpoint.DefaultSavePlanner - :members: - -.. autoclass:: torch.distributed.checkpoint.DefaultLoadPlanner - :members: - - -Due to legacy design decisions, the state dictionaries of `FSDP` and `DDP` may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, `FSDP` offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism). - -To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. `get_model_state_dict()` returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, `get_optimizer_state_dict()` provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, `get_optimizer_state_dict()` converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary. - -Note that results returned by these APIs can be used directly with the `torch.distributed.checkpoint.save()` and `torch.distributed.checkpoint.load()` methods without requiring any additional conversions. - -`set_model_state_dict()` and `set_optimizer_state_dict()` are provided to load the model and optimizer state_dict generated by by their respective getter APIs. - -Note that `set_optimizer_state_dict()` can only be called before `backward()` or after `step()` is called on optimizers. - -Note that this feature is experimental, and API signatures might change in the future. - - -.. autofunction:: torch.distributed.checkpoint.state_dict.get_state_dict - -.. autofunction:: torch.distributed.checkpoint.state_dict.get_model_state_dict - -.. autofunction:: torch.distributed.checkpoint.state_dict.get_optimizer_state_dict - -.. autofunction:: torch.distributed.checkpoint.state_dict.set_state_dict - -.. autofunction:: torch.distributed.checkpoint.state_dict.set_model_state_dict - -.. autofunction:: torch.distributed.checkpoint.state_dict.set_optimizer_state_dict - -.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions - :members: - -For users which are used to using and sharing models in the `torch.save` format, the following methods are provided which provide offline utilities for converting betweeing formats. - -.. automodule:: torch.distributed.checkpoint.format_utils - -.. currentmodule:: torch.distributed.checkpoint.format_utils - -.. autofunction:: dcp_to_torch_save -.. autofunction:: torch_save_to_dcp - -The following classes can also be utilized for online loading and resharding of models from the torch.save format. - -.. autoclass:: torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader - :members: - -.. autoclass:: torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner - :members: - -The following experimental interfaces are provided for improved observability in production environments: - -.. py:module:: torch.distributed.checkpoint.logger -.. py:module:: torch.distributed.checkpoint.logging_handlers diff --git a/docs/source/distributed.elastic.md b/docs/source/distributed.elastic.md new file mode 100644 index 00000000000000..f23ca39d724aee --- /dev/null +++ b/docs/source/distributed.elastic.md @@ -0,0 +1,46 @@ +# Torch Distributed Elastic + +Makes distributed PyTorch fault-tolerant and elastic. + +## Get Started + +```{toctree} +:caption: Usage +:maxdepth: 1 + +elastic/quickstart +elastic/train_script +elastic/examples +``` + +## Documentation + +```{toctree} +:caption: API +:maxdepth: 1 + +elastic/run +elastic/agent +elastic/multiprocessing +elastic/errors +elastic/rendezvous +elastic/timer +elastic/metrics +elastic/events +elastic/subprocess_handler +elastic/control_plane +``` + +```{toctree} +:caption: Advanced +:maxdepth: 1 + +elastic/customization +``` + +```{toctree} +:caption: Plugins +:maxdepth: 1 + +elastic/kubernetes +``` diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst deleted file mode 100644 index 0aabb560c9c80b..00000000000000 --- a/docs/source/distributed.elastic.rst +++ /dev/null @@ -1,44 +0,0 @@ -Torch Distributed Elastic -============================ - -Makes distributed PyTorch fault-tolerant and elastic. - -Get Started ---------------- -.. toctree:: - :maxdepth: 1 - :caption: Usage - - elastic/quickstart - elastic/train_script - elastic/examples - -Documentation ---------------- - -.. toctree:: - :maxdepth: 1 - :caption: API - - elastic/run - elastic/agent - elastic/multiprocessing - elastic/errors - elastic/rendezvous - elastic/timer - elastic/metrics - elastic/events - elastic/subprocess_handler - elastic/control_plane - -.. toctree:: - :maxdepth: 1 - :caption: Advanced - - elastic/customization - -.. toctree:: - :maxdepth: 1 - :caption: Plugins - - elastic/kubernetes diff --git a/docs/source/distributed.fsdp.fully_shard.md b/docs/source/distributed.fsdp.fully_shard.md new file mode 100644 index 00000000000000..4a54a41cefdbad --- /dev/null +++ b/docs/source/distributed.fsdp.fully_shard.md @@ -0,0 +1,125 @@ +# torch.distributed.fsdp.fully_shard + +## PyTorch FSDP2 (`fully_shard`) + +PyTorch FSDP2 ([RFC]()) provides +a fully sharded data parallelism (FSDP) implementation targeting performant +eager-mode while using per-parameter sharding for improved usability + +- See the [Getting Started with FSDP2](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) + tutorial for more information. + +- If you are currently using FSDP1, consider migrating to FSDP2 using our + [migration guide](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html#fsdp1-to-fsdp2-migration-guide). + + +The user contract for ``fully_shard(model)`` is as follows + +- For model initialization, fully_shard converts model.parameters() from + plain torch.Tensor to DTensor in-place. The parameters are moved to the + appropriate device according to the device mesh. + +- Before forward and backward passes, pre-forward/backward hooks are + responsible for all-gathering the parameters and converting model.parameters() + from DTensor to plain torch.Tensor. + +- After forward and backward passes, post-forward/backward hooks free + the unsharded parameters (no communication needed) and convert + model.parameters() from plain torch.Tensor back to DTensor. + +- For the optimizer, it must be initialized with the DTensor model.parameters(), + and the optimizer step should be performed on DTensor parameters. + +- Call ``model(input)`` instead of ``model.forward(input)`` to trigger pre-forward + hooks to all-gather parameters. To make model.forward(input) work, users must + either call ``model.unshard()`` explicitly or use ``register_fsdp_forward_method(model, "forward")`` + to register the forward method for hooking. + +- fully_shard groups parameters together for a single all-gather. User should apply + fully_shard in a bottom-up manner. For example, in a Transformer model, fully_shard + should be applied to each layer before applying it to the root model. When applied + to the root model, fully_shard excludes model.parameters() from each layer and groups + the remaining parameters (e.g., embeddings, output projection) into a single + all-gather group. + +- ``type(model)`` is "unioned" with ``FSDPModule`` in-place. For example, if model + is originally of type nn.Linear, then fully_shard changes ``type(model)`` from + nn.Linear to ``FSDPLinear`` in-place. ``FSDPLinear`` is an instance of both + nn.Linear and ``FSDPModule``. It retains all methods of nn.Linear while also + exposing FSDP2-specific APIs under FSDPModule, such as ``reshard()`` and + ``unshard()``. + +- Fully Qualified Names (FQNs) for parameters remain unchanged. If we call + ``model.state_dict()``, the FQNs are the same before and after applying + fully_shard. This is because fully_shard does not wrap the module but only + registers hooks to the original module. + + +Compared to PyTorch FSDP1 (`FullyShardedDataParallel`): + +- FSDP2 uses `DTensor`-based dim-0 per-parameter sharding for a simpler + sharding representation compared to FSDP1's flat-parameter sharding, while + preserving similar throughput performance. More specifically, FSDP2 chunks + each parameter on dim-0 across the data parallel workers (using + `torch.chunk(dim=0)`), whereas FSDP1 flattens, concatenates, and chunks a + group of tensors together, making reasoning about what data is present on + each worker and resharding to different parallelisms complex. Per-parameter + sharding provides a more intuitive user experience, relaxes constraints + around frozen parameters, and allows for communication-free (sharded) state + dicts, which otherwise require all-gathers in FSDP1. +- FSDP2 implements a different memory management approach to handle the + multi-stream usages that avoids `torch.Tensor.record_stream`. This ensures + deterministic and expected memory usage and does not require blocking the CPU + like in FSDP1's `limit_all_gathers=True`. +- FSDP2 exposes APIs for manual control over prefetching and collective + scheduling, allowing power users more customization. See the methods on + `FSDPModule` below for details. +- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly + support full state dicts. Instead, users can reshard the sharded state dicts + containing `DTensor` s to full state dicts themselves using `DTensor` + APIs like `DTensor.full_tensor()` or by using higher-level APIs like + [PyTorch Distributed Checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) 's + distributed state dict APIs. Also, some other args have been removed; see + [here](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md) for + details. + + +```{eval-rst} +.. currentmodule:: torch.distributed.fsdp +``` + +The frontend API is `fully_shard` that can be called on a `module`: + +```{eval-rst} +.. autofunction:: fully_shard +``` + +```{eval-rst} +.. autoclass:: FSDPModule + :members: + :member-order: bysource +``` + +```{eval-rst} +.. autoclass:: UnshardHandle + :members: +``` + +```{eval-rst} +.. autofunction:: register_fsdp_forward_method +``` + +```{eval-rst} +.. autoclass:: MixedPrecisionPolicy + :members: +``` + +```{eval-rst} +.. autoclass:: OffloadPolicy + :members: +``` + +```{eval-rst} +.. autoclass:: CPUOffloadPolicy + :members: +``` diff --git a/docs/source/distributed.fsdp.fully_shard.rst b/docs/source/distributed.fsdp.fully_shard.rst deleted file mode 100644 index edad3e4e989a55..00000000000000 --- a/docs/source/distributed.fsdp.fully_shard.rst +++ /dev/null @@ -1,85 +0,0 @@ -torch.distributed.fsdp.fully_shard -================================== - -PyTorch FSDP2 (``fully_shard``) -------------------------------- - -PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation -targeting performant eager-mode while using per-parameter sharding for improved -usability. - -- If you are new to FSDP, we recommend that you start with FSDP2 due to improved - usability. See `TorchTitan `_ for code examples. -- If you are currently using FSDP1, consider evaluating the following - differences to see if you should switch to FSDP2: - -Compared to PyTorch FSDP1 (``FullyShardedDataParallel``): - -- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler - sharding representation compared to FSDP1's flat-parameter sharding, while - preserving similar throughput performance. More specifically, FSDP2 chunks - each parameter on dim-0 across the data parallel workers (using - ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a - group of tensors together, making reasoning about what data is present on - each worker and resharding to different parallelisms complex. Per-parameter - sharding provides a more intuitive user experience, relaxes constraints - around frozen parameters, and allows for communication-free (sharded) state - dicts, which otherwise require all-gathers in FSDP1. -- FSDP2 implements a different memory management approach to handle the - multi-stream usages that avoids ``torch.Tensor.record_stream``. This ensures - deterministic and expected memory usage and does not require blocking the CPU - like in FSDP1's ``limit_all_gathers=True``. -- FSDP2 exposes APIs for manual control over prefetching and collective - scheduling, allowing power users more customization. See the methods on - ``FSDPModule`` below for details. -- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly - support full state dicts. Instead, users can reshard the sharded state dicts - containing ``DTensor`` s to full state dicts themselves using ``DTensor`` - APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like - `PyTorch Distributed Checkpoint `_ 's - distributed state dict APIs. Also, some other args have been removed; see - `here `_ for - details. - -If you are onboarding FSDP for the first time or if any of the above appeals to -your use case, we recommend that you consider using FSDP2. - -See `this RFC `_ for details -on system design and implementation. - -.. note:: - ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and - under development. The core API will likely not change, but we may make some - API changes if necessary. - -.. currentmodule:: torch.distributed.fsdp - -The frontend API is ``fully_shard`` that can be called on a ``module``: - -.. autofunction:: fully_shard - -Calling ``fully_shard(module)`` dynamically constructs a new class that -subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if -we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP -constructs a new class ``FSDPLinear`` and changes ``linear`` 's type to this. -Otherwise, ``fully_shard`` does not change the module structure and parameter -fully-qualified names. The class ``FSDPModule`` allows providing some -FSDP-specific methods on the module. - -.. autoclass:: FSDPModule - :members: - :member-order: bysource - -.. autoclass:: UnshardHandle - :members: - -.. autofunction:: register_fsdp_forward_method - -.. autoclass:: MixedPrecisionPolicy - :members: - -.. autoclass:: OffloadPolicy - :members: - -.. autoclass:: CPUOffloadPolicy - :members: diff --git a/docs/source/distributed.md b/docs/source/distributed.md new file mode 100644 index 00000000000000..6100e32452d857 --- /dev/null +++ b/docs/source/distributed.md @@ -0,0 +1,1473 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Distributed communication package - torch.distributed + +:::{note} +Please refer to [PyTorch Distributed Overview](https://pytorch.org/tutorials/beginner/dist_overview.html) +for a brief introduction to all features related to distributed training. +::: + +```{eval-rst} +.. automodule:: torch.distributed +``` + +```{eval-rst} +.. currentmodule:: torch.distributed +``` + +## Backends + +`torch.distributed` supports three built-in backends, each with +different capabilities. The table below shows which functions are available +for use with CPU / CUDA tensors. +MPI supports CUDA only if the implementation used to build PyTorch supports it. + +```{eval-rst} ++----------------+-----------+-----------+-----------+ +| Backend | ``gloo`` | ``mpi`` | ``nccl`` | ++----------------+-----+-----+-----+-----+-----+-----+ +| Device | CPU | GPU | CPU | GPU | CPU | GPU | ++================+=====+=====+=====+=====+=====+=====+ +| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+ +``` + +### Backends that come with PyTorch + +PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). +By default for Linux, the Gloo and NCCL backends are built and included in PyTorch +distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be +included if you build PyTorch from source. (e.g. building PyTorch on a host that has MPI +installed.) + +:::{note} +As of PyTorch v1.8, Windows supports all collective communications backend but NCCL, +If the `init_method` argument of {func}`init_process_group` points to a file it must adhere +to the following schema: + +- Local file system, `init_method="file:///d:/tmp/some_file"` +- Shared file system, `init_method="file://////{machine_name}/{share_folder_name}/some_file"` + +Same as on Linux platform, you can enable TcpStore by setting environment variables, +MASTER_ADDR and MASTER_PORT. +::: + +### Which backend to use? + +In the past, we were often asked: "which backend should I use?". + +- Rule of thumb + + - Use the NCCL backend for distributed **GPU** training + - Use the Gloo backend for distributed **CPU** training. + +- GPU hosts with InfiniBand interconnect + + - Use NCCL, since it's the only backend that currently supports + InfiniBand and GPUDirect. + +- GPU hosts with Ethernet interconnect + + - Use NCCL, since it currently provides the best distributed GPU + training performance, especially for multiprocess single-node or + multi-node distributed training. If you encounter any problem with + NCCL, use Gloo as the fallback option. (Note that Gloo currently + runs slower than NCCL for GPUs.) + +- CPU hosts with InfiniBand interconnect + + - If your InfiniBand has enabled IP over IB, use Gloo, otherwise, + use MPI instead. We are planning on adding InfiniBand support for + Gloo in the upcoming releases. + +- CPU hosts with Ethernet interconnect + + - Use Gloo, unless you have specific reasons to use MPI. + +### Common environment variables + +#### Choosing the network interface to use + +By default, both the NCCL and Gloo backends will try to find the right network interface to use. +If the automatically detected interface is not correct, you can override it using the following +environment variables (applicable to the respective backend): + +- **NCCL_SOCKET_IFNAME**, for example `export NCCL_SOCKET_IFNAME=eth0` +- **GLOO_SOCKET_IFNAME**, for example `export GLOO_SOCKET_IFNAME=eth0` + +If you're using the Gloo backend, you can specify multiple interfaces by separating +them by a comma, like this: `export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3`. +The backend will dispatch operations in a round-robin fashion across these interfaces. +It is imperative that all processes specify the same number of interfaces in this variable. + +#### Other NCCL environment variables + +**Debugging** - in case of NCCL failure, you can set `NCCL_DEBUG=INFO` to print an explicit +warning message as well as basic NCCL initialization information. + +You may also use `NCCL_DEBUG_SUBSYS` to get more details about a specific +aspect of NCCL. For example, `NCCL_DEBUG_SUBSYS=COLL` would print logs of +collective calls, which may be helpful when debugging hangs, especially those +caused by collective type or message size mismatch. In case of topology +detection failure, it would be helpful to set `NCCL_DEBUG_SUBSYS=GRAPH` +to inspect the detailed detection result and save as reference if further help +from NCCL team is needed. + +**Performance tuning** - NCCL performs automatic tuning based on its topology detection to save users' +tuning effort. On some socket-based systems, users may still try tuning +`NCCL_SOCKET_NTHREADS` and `NCCL_NSOCKS_PERTHREAD` to increase socket +network bandwidth. These two environment variables have been pre-tuned by NCCL +for some cloud providers, such as AWS or GCP. + +For a full list of NCCL environment variables, please refer to +[NVIDIA NCCL's official documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) + +You can tune NCCL communicators even further using `torch.distributed.ProcessGroupNCCL.NCCLConfig` +and `torch.distributed.ProcessGroupNCCL.Options`. Learn more about them using `help` +(e.g. `help(torch.distributed.ProcessGroupNCCL.NCCLConfig)`) in the interpreter. + +(distributed-basics)= + +## Basics + +The `torch.distributed` package provides PyTorch support and communication primitives +for multiprocess parallelism across several computation nodes running on one or more +machines. The class {func}`torch.nn.parallel.DistributedDataParallel` builds on this +functionality to provide synchronous distributed training as a wrapper around any +PyTorch model. This differs from the kinds of parallelism provided by +{doc}`multiprocessing` and {func}`torch.nn.DataParallel` in that it supports +multiple network-connected machines and in that the user must explicitly launch a separate +copy of the main training script for each process. + +In the single-machine synchronous case, `torch.distributed` or the +{func}`torch.nn.parallel.DistributedDataParallel` wrapper may still have advantages over other +approaches to data-parallelism, including {func}`torch.nn.DataParallel`: + +- Each process maintains its own optimizer and performs a complete optimization step with each + iteration. While this may appear redundant, since the gradients have already been gathered + together and averaged across processes and are thus the same for every process, this means + that no parameter broadcast step is needed, reducing time spent transferring tensors between + nodes. +- Each process contains an independent Python interpreter, eliminating the extra interpreter + overhead and "GIL-thrashing" that comes from driving several execution threads, model + replicas, or GPUs from a single Python process. This is especially important for models that + make heavy use of the Python runtime, including models with recurrent layers or many small + components. + +## Initialization + +The package needs to be initialized using the {func}`torch.distributed.init_process_group` +or {func}`torch.distributed.device_mesh.init_device_mesh` function before calling any other methods. +Both block until all processes have joined. + +:::{warning} +Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent +inconsistent 'UUID' assignment across ranks, and to prevent races during initialization that can lead to hangs. +::: + +```{eval-rst} +.. autofunction:: is_available +``` + +```{eval-rst} +.. autofunction:: init_process_group +``` + +```{eval-rst} +.. autofunction:: torch.distributed.device_mesh.init_device_mesh +``` + +```{eval-rst} +.. autofunction:: is_initialized +``` + +```{eval-rst} +.. autofunction:: is_mpi_available +``` + +```{eval-rst} +.. autofunction:: is_nccl_available +``` + +```{eval-rst} +.. autofunction:: is_gloo_available +``` + +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.is_xccl_available +``` + +```{eval-rst} +.. autofunction:: is_torchelastic_launched +``` + +______________________________________________________________________ + +Currently three initialization methods are supported: + +### TCP initialization + +There are two ways to initialize using TCP, both requiring a network address +reachable from all processes and a desired `world_size`. The first way +requires specifying an address that belongs to the rank 0 process. This +initialization method requires that all processes have manually specified ranks. + +Note that multicast address is not supported anymore in the latest distributed +package. `group_name` is deprecated as well. + +``` +import torch.distributed as dist + +# Use address of one of the machines +dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', + rank=args.rank, world_size=4) +``` + +### Shared file-system initialization + +Another initialization method makes use of a file system that is shared and +visible from all machines in a group, along with a desired `world_size`. The URL should start +with `file://` and contain a path to a non-existent file (in an existing +directory) on a shared file system. File-system initialization will automatically +create that file if it doesn't exist, but will not delete the file. Therefore, it +is your responsibility to make sure that the file is cleaned up before the next +{func}`init_process_group` call on the same file path/name. + +Note that automatic rank assignment is not supported anymore in the latest +distributed package and `group_name` is deprecated as well. + +:::{warning} +This method assumes that the file system supports locking using `fcntl` - most +local systems and NFS support it. +::: + +:::{warning} +This method will always create the file and try its best to clean up and remove +the file at the end of the program. In other words, each initialization with +the file init method will need a brand new empty file in order for the initialization +to succeed. If the same file used by the previous initialization (which happens not +to get cleaned up) is used again, this is unexpected behavior and can often cause +deadlocks and failures. Therefore, even though this method will try its best to clean up +the file, if the auto-delete happens to be unsuccessful, it is your responsibility +to ensure that the file is removed at the end of the training to prevent the same +file to be reused again during the next time. This is especially important +if you plan to call {func}`init_process_group` multiple times on the same file name. +In other words, if the file is not removed/cleaned up and you call +{func}`init_process_group` again on that file, failures are expected. +The rule of thumb here is that, make sure that the file is non-existent or +empty every time {func}`init_process_group` is called. +::: + +``` +import torch.distributed as dist + +# rank should always be specified +dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile', + world_size=4, rank=args.rank) +``` + +### Environment variable initialization + +This method will read the configuration from environment variables, allowing +one to fully customize how the information is obtained. The variables to be set +are: + +- `MASTER_PORT` - required; has to be a free port on machine with rank 0 +- `MASTER_ADDR` - required (except for rank 0); address of rank 0 node +- `WORLD_SIZE` - required; can be set either here, or in a call to init function +- `RANK` - required; can be set either here, or in a call to init function + +The machine with rank 0 will be used to set up all connections. + +This is the default method, meaning that `init_method` does not have to be specified (or +can be `env://`). + +### Improving initialization time + +- `TORCH_GLOO_LAZY_INIT` - establishes connections on demand rather than + using a full mesh which can greatly improve initialization time for non all2all + operations. + +## Post-Initialization + +Once {func}`torch.distributed.init_process_group` was run, the following functions can be used. To +check whether the process group has already been initialized use {func}`torch.distributed.is_initialized`. + +```{eval-rst} +.. autoclass:: Backend + :members: +``` + +```{eval-rst} +.. autofunction:: get_backend +``` + +```{eval-rst} +.. autofunction:: get_rank +``` + +```{eval-rst} +.. autofunction:: get_world_size +``` + +## Shutdown + +It is important to clean up resources on exit by calling {func}`destroy_process_group`. + +The simplest pattern to follow is to destroy every process group and backend by calling +{func}`destroy_process_group()` with the default value of None for the `group` argument, at a +point in the training script where communications are no longer needed, usually near the +end of main(). The call should be made once per trainer-process, not at the outer +process-launcher level. + +if {func}`destroy_process_group` is not called by all ranks in a pg within the timeout duration, +especially when there are multiple process-groups in the application e.g. for N-D parallelism, +hangs on exit are possible. This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, +which must be called collectively, but the order of calling ProcessGroupNCCL's destructor if called +by python's GC is not deterministic. Calling {func}`destroy_process_group` helps by ensuring +ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort +during ProcessGroupNCCL's destructor. + +### Reinitialization + +`destroy_process_group` can also be used to destroy individual process groups. One use +case could be fault tolerant training, where a process group may be destroyed and then +a new one initialized during runtime. In this case, it's critical to synchronize the trainer +processes using some means other than torch.distributed primitives \_after\_ calling destroy and +before subsequently initializing. This behavior is currently unsupported/untested, due to +the difficulty of achieving this synchronization, and is considered a known issue. Please file +a github issue or RFC if this is a use case that's blocking you. + +______________________________________________________________________ + +## Groups + +By default collectives operate on the default group (also called the world) and +require all processes to enter the distributed function call. However, some workloads can benefit +from more fine-grained communication. This is where distributed groups come +into play. {func}`~torch.distributed.new_group` function can be +used to create new groups, with arbitrary subsets of all processes. It returns +an opaque group handle that can be given as a `group` argument to all collectives +(collectives are distributed functions to exchange information in certain well-known programming patterns). + +```{eval-rst} +.. autofunction:: new_group +``` + +```{eval-rst} +.. autofunction:: get_group_rank +``` + +```{eval-rst} +.. autofunction:: get_global_rank +``` + +```{eval-rst} +.. autofunction:: get_process_group_ranks + +``` + +## DeviceMesh + +DeviceMesh is a higher level abstraction that manages process groups (or NCCL communicators). +It allows user to easily create inter node and intra node process groups without worrying about +how to set up the ranks correctly for different sub process groups, and it helps manage those +distributed process group easily. {func}`~torch.distributed.device_mesh.init_device_mesh` function can be +used to create new DeviceMesh, with a mesh shape describing the device topology. + +```{eval-rst} +.. autoclass:: torch.distributed.device_mesh.DeviceMesh + :members: +``` + +## Point-to-point communication + +```{eval-rst} +.. autofunction:: send +``` + +```{eval-rst} +.. autofunction:: recv +``` + +{func}`~torch.distributed.isend` and {func}`~torch.distributed.irecv` +return distributed request objects when used. In general, the type of this object is unspecified +as they should never be created manually, but they are guaranteed to support two methods: + +- `is_completed()` - returns True if the operation has finished +- `wait()` - will block the process until the operation is finished. + `is_completed()` is guaranteed to return True once it returns. + +```{eval-rst} +.. autofunction:: isend +``` + +```{eval-rst} +.. autofunction:: irecv +``` + +```{eval-rst} +.. autofunction:: send_object_list +``` + +```{eval-rst} +.. autofunction:: recv_object_list +``` + +```{eval-rst} +.. autofunction:: batch_isend_irecv +``` + +```{eval-rst} +.. autoclass:: P2POp +``` + +## Synchronous and asynchronous collective operations + +Every collective operation function supports the following two kinds of operations, +depending on the setting of the `async_op` flag passed into the collective: + +**Synchronous operation** - the default mode, when `async_op` is set to `False`. +When the function returns, it is guaranteed that +the collective operation is performed. In the case of CUDA operations, it is not guaranteed +that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any +further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, +function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of +synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream +synchronization, see [CUDA Semantics](https://pytorch.org/docs/stable/notes/cuda.html). +See the below script to see examples of differences in these semantics for CPU and CUDA operations. + +**Asynchronous operation** - when `async_op` is set to True. The collective operation function +returns a distributed request object. In general, you don't need to create it manually and it +is guaranteed to support two methods: + +- `is_completed()` - in the case of CPU collectives, returns `True` if completed. In the case of CUDA operations, + returns `True` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the + default stream without further synchronization. +- `wait()` - in the case of CPU collectives, will block the process until the operation is completed. In the case + of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU). +- `get_future()` - returns `torch._C.Future` object. Supported for NCCL, also supported for most operations on GLOO + and MPI, except for peer to peer operations. + Note: as we continue adopting Futures and merging APIs, `get_future()` call might become redundant. + +**Example** + +The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. +It shows the explicit need to synchronize when using collective outputs on different CUDA streams: + +``` +# Code runs on each rank. +dist.init_process_group("nccl", rank=rank, world_size=2) +output = torch.tensor([rank]).cuda(rank) +s = torch.cuda.Stream() +handle = dist.all_reduce(output, async_op=True) +# Wait ensures the operation is enqueued, but not necessarily complete. +handle.wait() +# Using result on non-default stream. +with torch.cuda.stream(s): + s.wait_stream(torch.cuda.default_stream()) + output.add_(100) +if rank == 0: + # if the explicit call to wait_stream was omitted, the output below will be + # non-deterministically 1 or 101, depending on whether the allreduce overwrote + # the value after the add completed. + print(output) +``` + +## Collective functions + +```{eval-rst} +.. autofunction:: broadcast +``` + +```{eval-rst} +.. autofunction:: broadcast_object_list +``` + +```{eval-rst} +.. autofunction:: all_reduce +``` + +```{eval-rst} +.. autofunction:: reduce +``` + +```{eval-rst} +.. autofunction:: all_gather +``` + +```{eval-rst} +.. autofunction:: all_gather_into_tensor +``` + +```{eval-rst} +.. autofunction:: all_gather_object +``` + +```{eval-rst} +.. autofunction:: gather +``` + +```{eval-rst} +.. autofunction:: gather_object +``` + +```{eval-rst} +.. autofunction:: scatter +``` + +```{eval-rst} +.. autofunction:: scatter_object_list +``` + +```{eval-rst} +.. autofunction:: reduce_scatter +``` + +```{eval-rst} +.. autofunction:: reduce_scatter_tensor +``` + +```{eval-rst} +.. autofunction:: all_to_all_single +``` + +```{eval-rst} +.. autofunction:: all_to_all +``` + +```{eval-rst} +.. autofunction:: barrier +``` + +```{eval-rst} +.. autofunction:: monitored_barrier +``` + +```{eval-rst} +.. autoclass:: Work + :members: +``` + +```{eval-rst} +.. autoclass:: ReduceOp +``` + +```{eval-rst} +.. class:: reduce_op + + Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``, + ``MIN``, and ``MAX``. + + :class:`~torch.distributed.ReduceOp` is recommended to use instead. + +``` + +## Distributed Key-Value Store + +The distributed package comes with a distributed key-value store, which can be +used to share information between processes in the group as well as to +initialize the distributed package in +{func}`torch.distributed.init_process_group` (by explicitly creating the store +as an alternative to specifying `init_method`.) There are 3 choices for +Key-Value Stores: {class}`~torch.distributed.TCPStore`, +{class}`~torch.distributed.FileStore`, and {class}`~torch.distributed.HashStore`. + +```{eval-rst} +.. autoclass:: Store + :members: + :special-members: +``` + +```{eval-rst} +.. autoclass:: TCPStore + :members: + :special-members: __init__ +``` + +```{eval-rst} +.. autoclass:: HashStore + :members: + :special-members: __init__ +``` + +```{eval-rst} +.. autoclass:: FileStore + :members: + :special-members: __init__ +``` + +```{eval-rst} +.. autoclass:: PrefixStore + :members: + :special-members: __init__ + +``` + +## Profiling Collective Communication + +Note that you can use `torch.profiler` (recommended, only available after 1.8.1) or `torch.autograd.profiler` to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (`gloo`, +`nccl`, `mpi`) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator: + +``` +import torch +import torch.distributed as dist +with torch.profiler(): + tensor = torch.randn(20, 10) + dist.all_reduce(tensor) +``` + +Please refer to the [profiler documentation](https://pytorch.org/docs/main/profiler.html) for a full overview of profiler features. + +## Multi-GPU collective functions + +:::{warning} +The multi-GPU functions (which stand for multiple GPUs per CPU thread) are +deprecated. As of today, PyTorch Distributed's preferred programming model +is one device per thread, as exemplified by the APIs in this document. If +you are a backend developer and want to support multiple devices per thread, +please contact PyTorch Distributed's maintainers. +::: + +(object_collectives)= + +## Object collectives + +:::{warning} +Object collectives have a number of serious limitations. Read further to determine +if they are safe to use for your use case. +::: + +Object collectives are a set of collective-like operations that work on arbitrary +Python objects, as long as they can be pickled. There are various collective patterns +implemented (e.g. broadcast, all_gather, ...) but they each roughly follow this pattern: + +1. convert the input object into a pickle (raw bytes), then shove it into a byte tensor +2. communicate the size of this byte tensor to peers (first collective operation) +3. allocate appropriately sized tensor to perform the real collective +4. communicate the object data (second collective operation) +5. convert raw data back into Python (unpickle) + +Object collectives sometimes have surprising performance or memory characteristics that lead to +long runtimes or OOMs, and thus they should be used with caution. Here are some common issues. + +**Asymmetric pickle/unpickle time** - Pickling objects can be slow, depending on the number, type and size of the objects. +When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than +the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective. + +**Inefficient tensor communication** - Tensors should be sent via regular collective APIs, not object collective APIs. +It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a +CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or +troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead. + +**Unexpected tensor devices** - If you still want to send tensors via object collectives, there is another aspect +specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on `cuda:3`, and +then unpickle it, you will get another tensor on `cuda:3` *regardless of which process you are on, or which CUDA device +is the 'default' device for that process*. With regular tensor collective APIs, 'output tensors' will always be on the +same, local device, which is generally what you'd expect. + +Unpickling a tensor will implicitly activate a CUDA context if it is the first +time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by +moving tensors to CPU before passing them as inputs to an object collective. + +## Third-party backends + +Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports +third-party backends through a run-time register mechanism. +For references on how to develop a third-party backend through C++ Extension, +please refer to [Tutorials - Custom C++ and CUDA Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) and +`test/cpp_extensions/cpp_c10d_extension.cpp`. The capability of third-party +backends are decided by their own implementations. + +The new backend derives from `c10d::ProcessGroup` and registers the backend +name and the instantiating interface through {func}`torch.distributed.Backend.register_backend` +when imported. + +When manually importing this backend and invoking {func}`torch.distributed.init_process_group` +with the corresponding backend name, the `torch.distributed` package runs on +the new backend. + +:::{warning} +The support of third-party backend is experimental and subject to change. +::: + +(distributed-launch)= + +## Launch utility + +The `torch.distributed` package also provides a launch utility in +`torch.distributed.launch`. This helper utility can be used to launch +multiple processes per node for distributed training. + +```{eval-rst} +.. automodule:: torch.distributed.launch + +``` + +## Spawn utility + +The {ref}`multiprocessing-doc` package also provides a `spawn` +function in {func}`torch.multiprocessing.spawn`. This helper function +can be used to spawn multiple processes. It works by passing in the +function that you want to run and spawns N processes to run it. This +can be used for multiprocess distributed training as well. + +For references on how to use it, please refer to [PyTorch example - ImageNet +implementation](https://github.com/pytorch/examples/tree/master/imagenet) + +Note that this function requires Python 3.4 or higher. + +## Debugging `torch.distributed` applications + +Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. `torch.distributed` provides +a suite of tools to help debug training applications in a self-serve fashion: + +### Python Breakpoint + +It is extremely convenient to use python's debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. +PyTorch offers a customized wrapper around pdb that streamlines the process. + +`torch.distributed.breakpoint` makes this process easy. Internally, it customizes `pdb`'s breakpoint behavior in two ways but otherwise behaves as normal `pdb`. +1. Attaches the debugger only on one rank (specified by the user). +2. Ensures all other ranks stop, by using a `torch.distributed.barrier()` that will release once the debugged rank issues a `continue` +3. Reroutes stdin from the child process such that it connects to your terminal. + +To use it, simply issue `torch.distributed.breakpoint(rank)` on all ranks, using the same value for `rank` in each case. + +### Monitored Barrier + +As of v1.10, {func}`torch.distributed.monitored_barrier` exists as an alternative to {func}`torch.distributed.barrier` which fails with helpful information about which rank may be faulty +when crashing, i.e. not all ranks calling into {func}`torch.distributed.monitored_barrier` within the provided timeout. {func}`torch.distributed.monitored_barrier` implements a host-side +barrier using `send`/`recv` communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge +the barrier in time. As an example, consider the following function where rank 1 fails to call into {func}`torch.distributed.monitored_barrier` (in practice this could be due +to an application bug or hang in a previous collective): + +``` +import os +from datetime import timedelta + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def worker(rank): + dist.init_process_group("nccl", rank=rank, world_size=2) + # monitored barrier requires gloo process group to perform host-side sync. + group_gloo = dist.new_group(backend="gloo") + if rank not in [1]: + dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2)) + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + mp.spawn(worker, nprocs=2, args=()) +``` + +The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further: + +``` +RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms + Original exception: +[gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594 +``` + +### `TORCH_DISTRIBUTED_DEBUG` + +With `TORCH_CPP_LOG_LEVEL=INFO`, the environment variable `TORCH_DISTRIBUTED_DEBUG` can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks +are synchronized appropriately. `TORCH_DISTRIBUTED_DEBUG` can be set to either `OFF` (default), `INFO`, or `DETAIL` depending on the debugging level +required. Please note that the most verbose option, `DETAIL` may impact the application performance and thus should only be used when debugging issues. + +Setting `TORCH_DISTRIBUTED_DEBUG=INFO` will result in additional debug logging when models trained with {func}`torch.nn.parallel.DistributedDataParallel` are initialized, and +`TORCH_DISTRIBUTED_DEBUG=DETAIL` will additionally log runtime performance statistics a select number of iterations. These runtime statistics +include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application: + +``` +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +class TwoLinLayerNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Linear(10, 10, bias=False) + self.b = torch.nn.Linear(10, 1, bias=False) + + def forward(self, x): + a = self.a(x) + b = self.b(x) + return (a, b) + + +def worker(rank): + dist.init_process_group("nccl", rank=rank, world_size=2) + torch.cuda.set_device(rank) + print("init model") + model = TwoLinLayerNet().cuda() + print("init ddp") + ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + + inp = torch.randn(10, 10).cuda() + print("train") + + for _ in range(20): + output = ddp_model(inp) + loss = output[0] + output[1] + loss.sum().backward() + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" + os.environ[ + "TORCH_DISTRIBUTED_DEBUG" + ] = "DETAIL" # set to DETAIL for runtime logging. + mp.spawn(worker, nprocs=2, args=()) +``` + +The following logs are rendered at initialization time: + +``` +I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with: +broadcast_buffers: 1 +bucket_cap_bytes: 26214400 +find_unused_parameters: 0 +gradient_as_bucket_view: 0 +is_multi_device_module: 0 +iteration: 0 +num_parameter_tensors: 2 +output_device: 0 +rank: 0 +total_parameter_size_bytes: 440 +world_size: 2 +backend_name: nccl +bucket_sizes: 440 +cuda_visible_devices: N/A +device_ids: 0 +dtypes: float +master_addr: localhost +master_port: 29501 +module_name: TwoLinLayerNet +nccl_async_error_handling: N/A +nccl_blocking_wait: N/A +nccl_debug: WARN +nccl_ib_timeout: N/A +nccl_nthreads: N/A +nccl_socket_ifname: N/A +torch_distributed_debug: INFO +``` + +The following logs are rendered during runtime (when `TORCH_DISTRIBUTED_DEBUG=DETAIL` is set): + +``` +I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0 + Avg forward compute time: 40838608 + Avg backward compute time: 5983335 +Avg backward comm. time: 4326421 + Avg backward comm/comp overlap time: 4207652 +I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0 + Avg forward compute time: 42850427 + Avg backward compute time: 3885553 +Avg backward comm. time: 2357981 + Avg backward comm/comp overlap time: 2234674 +``` + +In addition, `TORCH_DISTRIBUTED_DEBUG=INFO` enhances crash logging in {func}`torch.nn.parallel.DistributedDataParallel` due to unused parameters in the model. Currently, `find_unused_parameters=True` +must be passed into {func}`torch.nn.parallel.DistributedDataParallel` initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required +to be used in loss computation as {func}`torch.nn.parallel.DistributedDataParallel` does not support unused parameters in the backwards pass. These constraints are challenging especially for larger +models, thus when crashing with an error, {func}`torch.nn.parallel.DistributedDataParallel` will log the fully qualified name of all parameters that went unused. For example, in the above application, +if we modify `loss` to be instead computed as `loss = output[1]`, then `TwoLinLayerNet.a` does not receive a gradient in the backwards pass, and +thus results in `DDP` failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models: + +``` +RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing + the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by +making sure all `forward` function outputs participate in calculating loss. +If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return va +lue of `forward` of your module when reporting this issue (e.g. list, dict, iterable). +Parameters which did not receive grad for rank 0: a.weight +Parameter indices which did not receive grad for rank 0: 0 +``` + +Setting `TORCH_DISTRIBUTED_DEBUG=DETAIL` will trigger additional consistency and synchronization checks on every collective call issued by the user +either directly or indirectly (such as DDP `allreduce`). This is done by creating a wrapper process group that wraps all process groups returned by +{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process +group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a {func}`torch.distributed.monitored_barrier`, +which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by +ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the +application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into +{func}`torch.distributed.all_reduce`: + +``` +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def worker(rank): + dist.init_process_group("nccl", rank=rank, world_size=2) + torch.cuda.set_device(rank) + tensor = torch.randn(10 if rank == 0 else 20).cuda() + dist.all_reduce(tensor) + torch.cuda.synchronize(device=rank) + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" + os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + mp.spawn(worker, nprocs=2, args=()) +``` + +With the `NCCL` backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables +`TORCH_DISTRIBUTED_DEBUG=DETAIL` and reruns the application, the following error message reveals the root cause: + +``` +work = default_pg.allreduce([tensor], opts) +RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0. This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: 10 +[ torch.LongTensor{1} ] +``` + +:::{note} +For fine-grained control of the debug level during runtime the functions {func}`torch.distributed.set_debug_level`, {func}`torch.distributed.set_debug_level_from_env`, and +{func}`torch.distributed.get_debug_level` can also be used. +::: + +In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `TORCH_SHOW_CPP_STACKTRACES=1` to log the entire callstack when a collective desynchronization is detected. These +collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the +{func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs. + +## Logging + +In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log +messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The +following matrix shows how the log level can be adjusted via the combination of `TORCH_CPP_LOG_LEVEL` and `TORCH_DISTRIBUTED_DEBUG` environment variables. + +| `TORCH_CPP_LOG_LEVEL` | `TORCH_DISTRIBUTED_DEBUG` | Effective Log Level | +| --------------------- | ------------------------- | ------------------- | +| `ERROR` | ignored | Error | +| `WARNING` | ignored | Warning | +| `INFO` | ignored | Info | +| `INFO` | `INFO` | Debug | +| `INFO` | `DETAIL` | Trace (a.k.a. All) | + +Distributed components raise custom Exception types derived from `RuntimeError`: + +- `torch.distributed.DistError`: This is the base type of all distributed exceptions. +- `torch.distributed.DistBackendError`: This exception is thrown when a backend-specific error occurs. For example, if + the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library. +- `torch.distributed.DistNetworkError`: This exception is thrown when networking + libraries encounter errors (ex: Connection reset by peer) +- `torch.distributed.DistStoreError`: This exception is thrown when the Store encounters + an error (ex: TCPStore timeout) + +```{eval-rst} +.. autoclass:: torch.distributed.DistError +``` + +```{eval-rst} +.. autoclass:: torch.distributed.DistBackendError +``` + +```{eval-rst} +.. autoclass:: torch.distributed.DistNetworkError +``` + +```{eval-rst} +.. autoclass:: torch.distributed.DistStoreError +``` + +If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank: + +```{eval-rst} +.. autofunction:: torch.distributed.breakpoint +``` + +% Distributed modules that are missing specific entries. + +% Adding them here for tracking purposes until they are more permanently fixed. + +```{eval-rst} +.. py:module:: torch.distributed.algorithms +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.model_averaging +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.data +``` + +```{eval-rst} +.. py:module:: torch.distributed.launcher +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.jit +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.jit.templates +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.join +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.model_averaging.averagers +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.model_averaging.hierarchical_model_averager +``` + +```{eval-rst} +.. py:module:: torch.distributed.algorithms.model_averaging.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.argparse_util +``` + +```{eval-rst} +.. py:module:: torch.distributed.c10d_logger +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.default_planner +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.filesystem +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.hf_storage +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.metadata +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.optimizer +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.planner +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.planner_helpers +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.resharding +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.state_dict_loader +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.state_dict_saver +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.stateful +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.storage +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.collective_utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.constants +``` + +```{eval-rst} +.. py:module:: torch.distributed.device_mesh +``` + +```{eval-rst} +.. py:module:: torch.distributed.distributed_c10d +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.agent.server.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.agent.server.local_elastic_agent +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.events.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.events.handlers +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.metrics.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.multiprocessing.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.multiprocessing.errors.error_handler +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.multiprocessing.errors.handlers +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.multiprocessing.redirects +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.multiprocessing.tail_log +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.c10d_rendezvous_backend +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.dynamic_rendezvous +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.etcd_rendezvous +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.etcd_rendezvous_backend +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.etcd_server +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.etcd_store +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.static_tcp_rendezvous +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.rendezvous.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.timer.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.timer.file_based_local_timer +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.timer.local_timer +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.data.cycling_iterator +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.data.elastic_distributed_sampler +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.distributed +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.log_level +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.logging +``` + +```{eval-rst} +.. py:module:: torch.distributed.elastic.utils.store +``` + +```{eval-rst} +.. py:module:: torch.distributed.fsdp.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.fsdp.fully_sharded_data_parallel +``` + +```{eval-rst} +.. py:module:: torch.distributed.fsdp.sharded_grad_scaler +``` + +```{eval-rst} +.. py:module:: torch.distributed.fsdp.wrap +``` + +```{eval-rst} +.. py:module:: torch.distributed.launcher.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.logging_handlers +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.api.remote_module +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.functional +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.jit.instantiator +``` + +```{eval-rst} +.. py:module:: torch.distributed.nn.jit.templates.remote_module_template +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.apply_optimizer_in_backward +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_adadelta +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_adagrad +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_adam +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_adamax +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_adamw +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_rmsprop +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_rprop +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.functional_sgd +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.named_optimizer +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.optimizer +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.post_localSGD_optimizer +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.optim.zero_redundancy_optimizer +``` + +```{eval-rst} +.. py:module:: torch.distributed.remote_device +``` + +```{eval-rst} +.. py:module:: torch.distributed.rendezvous +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.backend_registry +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.constants +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.functions +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.internal +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.options +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.rref_proxy +``` + +```{eval-rst} +.. py:module:: torch.distributed.rpc.server_process_global_profiler +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.api +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.ddp +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.fsdp +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.input_reshard +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.loss +``` + +```{eval-rst} +.. py:module:: torch.distributed.tensor.parallel.style +``` + +```{eval-rst} +.. py:module:: torch.distributed.utils +``` + +```{eval-rst} +.. py:module:: torch.distributed.checkpoint.state_dict +``` diff --git a/docs/source/distributed.optim.md b/docs/source/distributed.optim.md new file mode 100644 index 00000000000000..57914e1241fd6c --- /dev/null +++ b/docs/source/distributed.optim.md @@ -0,0 +1,15 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Distributed Optimizers + +:::{warning} +Distributed optimizer is not currently supported when using CUDA tensors +::: + +```{eval-rst} +.. automodule:: torch.distributed.optim + :members: DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer +``` diff --git a/docs/source/distributed.optim.rst b/docs/source/distributed.optim.rst deleted file mode 100644 index 0ad989261c3b8d..00000000000000 --- a/docs/source/distributed.optim.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Distributed Optimizers -====================== - -.. warning :: - Distributed optimizer is not currently supported when using CUDA tensors - -.. automodule:: torch.distributed.optim - :members: DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer diff --git a/docs/source/distributed.pipelining.md b/docs/source/distributed.pipelining.md new file mode 100644 index 00000000000000..2b6dbf186ff489 --- /dev/null +++ b/docs/source/distributed.pipelining.md @@ -0,0 +1,515 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Pipeline Parallelism + +:::{note} +`torch.distributed.pipelining` is currently in alpha state and under +development. API changes may be possible. It was migrated from the [PiPPy](https://github.com/pytorch/PiPPy) project. +::: + +## Why Pipeline Parallel? + +Pipeline Parallelism is one of the **primitive** parallelism for deep learning. +It allows the **execution** of a model to be partitioned such that multiple +**micro-batches** can execute different parts of the model code concurrently. +Pipeline parallelism can be an effective technique for: + +- large-scale training +- bandwidth-limited clusters +- large model inference + +The above scenarios share a commonality that the computation per device cannot +hide the communication of conventional parallelism, for example, the weight +all-gather of FSDP. + +## What is `torch.distributed.pipelining`? + +While promising for scaling, pipelining is often difficult to implement because +it needs to **partition the execution** of a model in addition to model weights. +The partitioning of execution often requires intrusive code changes to your +model. Another aspect of complexity comes from **scheduling micro-batches in a +distributed environment**, with **data flow dependency** considered. + +The `pipelining` package provides a toolkit that does said things +**automatically** which allows easy implementation of pipeline parallelism +on **general** models. + +It consists of two parts: a +**splitting frontend** and a **distributed runtime**. +The splitting frontend takes your model code as-is, splits it up into "model +partitions", and captures the data-flow relationship. The distributed runtime +executes the pipeline stages on different devices in parallel, handling things +like micro-batch splitting, scheduling, communication, and gradient propagation, +etc. + +Overall, the `pipelining` package provides the following features: + +- Splitting of model code based on simple specification. +- Rich support for pipeline schedules, including GPipe, 1F1B, + Interleaved 1F1B and Looped BFS, and providing the infrastructure for writing + customized schedules. +- First-class support for cross-host pipeline parallelism, as this is where PP + is typically used (over slower interconnects). +- Composability with other PyTorch parallel techniques such as data parallel + (DDP, FSDP) or tensor parallel. The [TorchTitan](https://github.com/pytorch/torchtitan) project demonstrates a "3D parallel" + application on the Llama model. + +## Step 1: build `PipelineStage` + +Before we can use a `PipelineSchedule`, we need to create `PipelineStage` +objects that wrap the part of the model running in that stage. The +`PipelineStage` is responsible for allocating communication buffers and +creating send/recv ops to communicate with its peers. It manages intermediate +buffers e.g. for the outputs of forward that have not been consumed yet, and it +provides a utility for running the backwards for the stage model. + +A `PipelineStage` needs to know the input and output shapes for the stage +model, so that it can correctly allocate communication buffers. The shapes must +be static, e.g. at runtime the shapes can not change from step to step. A class +`PipeliningShapeError` will be raised if runtime shapes do not match the +expected shapes. When composing with other paralleisms or applying mixed +precision, these techniques must be taken into account so the `PipelineStage` +knows the correct shape (and dtype) for the output of the stage module at +runtime. + +Users may construct a `PipelineStage` instance directly, by passing in an +`nn.Module` representing the portion of the model that should run on the +stage. This may require changes to the original model code. See the example +in {ref}`option_1_manual`. + +Alternatively, the splitting frontend can use graph partitioning to split your +model into a series of `nn.Module` automatically. This technique requires the +model is traceable with `torch.Export`. Composability of the resulting +`nn.Module` with other parallelism techniques is experimental, and may require +some workarounds. Usage of this frontend may be more appealing if the user +cannot easily change the model code. See {ref}`option_2_tracer` for more +information. + +## Step 2: use `PipelineSchedule` for execution + +We can now attach the `PipelineStage` to a pipeline schedule, and run the +schedule with input data. Here is a GPipe example: + +```python +from torch.distributed.pipelining import ScheduleGPipe + +# Create a schedule +schedule = ScheduleGPipe(stage, n_microbatches) + +# Input data (whole batch) +x = torch.randn(batch_size, in_dim, device=device) + +# Run the pipeline with input `x` +# `x` will be divided into microbatches automatically +if rank == 0: + schedule.step(x) +else: + output = schedule.step() +``` + +Note that the above code needs to be launched for each worker, thus we use a +launcher service to launch multiple processes: + +```bash +torchrun --nproc_per_node=2 example.py +``` + +## Options for Splitting a Model + +(option_1_manual)= + +### Option 1: splitting a model manually + +To directly construct a `PipelineStage`, the user is responsible for providing +a single `nn.Module` instance that owns the relevant `nn.Parameters` and +`nn.Buffers`, and defines a `forward()` method that executes the operations +relevant for that stage. For example, a condensed version of the Transformer +class defined in Torchtitan shows a pattern of building an easily partitionable +model. + +```python +class Transformer(nn.Module): + def __init__(self, model_args: ModelArgs): + super().__init__() + + self.tok_embeddings = nn.Embedding(...) + + # Using a ModuleDict lets us delete layers without affecting names, + # ensuring checkpoints will correctly save and load. + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(...) + + self.output = nn.Linear(...) + + def forward(self, tokens: torch.Tensor): + # Handling layers being 'None' at runtime enables easy pipeline splitting + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output +``` + +A model defined in this manner can be easily configured per stage by first +initializing the whole model (using meta-device to avoid OOM errors), deleting +undesired layers for that stage, and then creating a PipelineStage that wraps +the model. For example: + +```python +with torch.device("meta"): + assert num_stages == 2, "This is a simple 2-stage example" + + # we construct the entire model, then delete the parts we do not need for this stage + # in practice, this can be done using a helper function that automatically divides up layers across stages. + model = Transformer() + + if stage_index == 0: + # prepare the first stage model + del model.layers["1"] + model.norm = None + model.output = None + + elif stage_index == 1: + # prepare the second stage model + model.tok_embeddings = None + del model.layers["0"] + + from torch.distributed.pipelining import PipelineStage + stage = PipelineStage( + model, + stage_index, + num_stages, + device, + ) +``` + +When composing with other Data or Model parallelism techniques, `output_args` +may also be required, if the output shape/dtype of the model chunk will be +affected. + +(option_2_tracer)= + +### Option 2: splitting a model automatically + +If you have a full model and do not want to spend time on modifying it into a +sequence of "model partitions", the `pipeline` API is here to help. +Here is a brief example: + +```python +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(10, 3) + self.layers = torch.nn.ModuleList( + Layer() for _ in range(2) + ) + self.lm = LMHead() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.emb(x) + for layer in self.layers: + x = layer(x) + x = self.lm(x) + return x +``` + +If we print the model, we can see multiple hierarchies, which makes it hard to split by hand: + +```python +Model( + (emb): Embedding(10, 3) + (layers): ModuleList( + (0-1): 2 x Layer( + (lin): Linear(in_features=3, out_features=3, bias=True) + ) + ) + (lm): LMHead( + (proj): Linear(in_features=3, out_features=3, bias=True) + ) +) +``` + +Let us see how the `pipeline` API works: + +```python +from torch.distributed.pipelining import pipeline, SplitPoint + +# An example micro-batch input +x = torch.LongTensor([1, 2, 4, 5]) + +pipe = pipeline( + module=mod, + mb_args=(x,), + split_spec={ + "layers.1": SplitPoint.BEGINNING, + } +) +``` + +The `pipeline` API splits your model given a `split_spec`, where +`SplitPoint.BEGINNING` stands for adding a split point +*before* execution of certain submodule in the `forward` function, and +similarly, `SplitPoint.END` for split point *after* such. + +If we `print(pipe)`, we can see: + +```python +GraphModule( + (submod_0): GraphModule( + (emb): InterpreterModule() + (layers): Module( + (0): InterpreterModule( + (lin): InterpreterModule() + ) + ) + ) + (submod_1): GraphModule( + (layers): Module( + (1): InterpreterModule( + (lin): InterpreterModule() + ) + ) + (lm): InterpreterModule( + (proj): InterpreterModule() + ) + ) +) + +def forward(self, x): + submod_0 = self.submod_0(x); x = None + submod_1 = self.submod_1(submod_0); submod_0 = None + return (submod_1,) +``` + +The "model partitions" are represented by submodules (`submod_0`, +`submod_1`), each of which is reconstructed with original model operations, weights +and hierarchies. In addition, a "root-level" `forward` function is +reconstructed to capture the data flow between those partitions. Such data flow +will be replayed by the pipeline runtime later, in a distributed fashion. + +The `Pipe` object provides a method for retrieving the "model partitions": + +```python +stage_mod : nn.Module = pipe.get_stage_module(stage_idx) +``` + +The returned `stage_mod` is a `nn.Module`, with which you can create an +optimizer, save or load checkpoints, or apply other parallelisms. + +`Pipe` also allows you to create a distributed stage runtime on a device given +a `ProcessGroup`: + +```python +stage = pipe.build_stage(stage_idx, device, group) +``` + +Alternatively, if you would like to build the stage runtime later after some +modification to the `stage_mod`, you can use a functional version of the +`build_stage` API. For example: + +```python +from torch.distributed.pipelining import build_stage +from torch.nn.parallel import DistributedDataParallel + +dp_mod = DistributedDataParallel(stage_mod) +info = pipe.info() +stage = build_stage(dp_mod, stage_idx, info, device, group) +``` + +:::{note} +The `pipeline` frontend uses a tracer (`torch.export`) to capture your +model into a single graph. If your model is not full-graph'able, you can use +our manual frontend below. +::: + +## Hugging Face Examples + +In the [PiPPy](https://github.com/pytorch/PiPPy) repo where this package was +original created, we kept examples based on unmodified Hugging Face models. +See the [examples/huggingface](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface) directory. + +Examples include: + +- [GPT2](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface/pippy_gpt2.py) +- [Llama](https://github.com/pytorch/PiPPy/tree/main/examples/llama) + +## Technical Deep Dive + +### How does the `pipeline` API split a model? + +First, the `pipeline` API turns our model into a directed acyclic graph (DAG) +by tracing the model. It traces the model using `torch.export` -- a PyTorch 2 +full-graph capturing tool. + +Then, it groups together the **operations and parameters** needed by a stage +into a reconstructed submodule: `submod_0`, `submod_1`, ... + +Different from conventional submodule access methods like `Module.children()`, +the `pipeline` API does not only cut the module structure of your model, but +also the **forward** function of your model. + +This is necessary because model structure like `Module.children()` merely +captures information during `Module.__init__()`, and does not capture any +information about `Module.forward()`. Said differently, `Module.children()` +lacks information about the following aspects key to pipelininig: + +- Execution order of child modules in `forward` +- Activation flows between child modules +- Whether there are any functional operators between child modules (for example, + `relu` or `add` operations will not be captured by `Module.children()`). + +The `pipeline` API, on the contrary, makes sure that the `forward` behavior +is truly preserved. It also captures the activation flow between the partitions, +helping the distributed runtime to make correct send/receive calls without human +intervention. + +Another flexibility of the `pipeline` API is that split points can be at +arbitrary levels within your model hierarchy. In the split partitions, the original model +hierarchy related to that partition will be reconstructed at no cost to you. +At a result, fully-qualified names (FQNs) pointing to a submodule or parameter +would be still valid, and services that relies on FQNs (such as FSDP, TP or +checkpointing) can still run with your partitioned modules with almost zero code +change. + +## Implementing Your Own Schedule + +You can implement your own pipeline schedule by extending one of the following two class: + +- `PipelineScheduleSingle` +- `PipelineScheduleMulti` + +`PipelineScheduleSingle` is for schedules that assigns *only one* stage per rank. +`PipelineScheduleMulti` is for schedules that assigns multiple stages per rank. + +For example, `ScheduleGPipe` and `Schedule1F1B` are subclasses of `PipelineScheduleSingle`. +Whereas, `ScheduleInterleaved1F1B`, `ScheduleLoopedBFS`, `ScheduleInterleavedZeroBubble`, and `ScheduleZBVZeroBubble` +are subclasses of `PipelineScheduleMulti`. + +## Logging + +You can turn on additional logging using the `TORCH_LOGS` environment variable from [torch.\_logging](https://pytorch.org/docs/main/logging.html#module-torch._logging): + +- `TORCH_LOGS=+pp` will display `logging.DEBUG` messages and all levels above it. +- `TORCH_LOGS=pp` will display `logging.INFO` messages and above. +- `TORCH_LOGS=-pp` will display `logging.WARNING` messages and above. + +## API Reference + +```{eval-rst} +.. automodule:: torch.distributed.pipelining +``` + +### Model Split APIs + +The following set of APIs transform your model into a pipeline representation. + +```{eval-rst} +.. currentmodule:: torch.distributed.pipelining +``` + +```{eval-rst} +.. autoclass:: SplitPoint +``` + +```{eval-rst} +.. autofunction:: pipeline +``` + +```{eval-rst} +.. autoclass:: Pipe +``` + +```{eval-rst} +.. autofunction:: pipe_split +``` + +### Microbatch Utilities + +```{eval-rst} +.. automodule:: torch.distributed.pipelining.microbatch +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.pipelining.microbatch +``` + +```{eval-rst} +.. autoclass:: TensorChunkSpec +``` + +```{eval-rst} +.. autofunction:: split_args_kwargs_into_chunks +``` + +```{eval-rst} +.. autofunction:: merge_chunks +``` + +### Pipeline Stages + +```{eval-rst} +.. automodule:: torch.distributed.pipelining.stage +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.pipelining.stage +``` + +```{eval-rst} +.. autoclass:: PipelineStage +``` + +```{eval-rst} +.. autofunction:: build_stage +``` + +### Pipeline Schedules + +```{eval-rst} +.. automodule:: torch.distributed.pipelining.schedules +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.pipelining.schedules +``` + +```{eval-rst} +.. autoclass:: ScheduleGPipe +``` + +```{eval-rst} +.. autoclass:: Schedule1F1B +``` + +```{eval-rst} +.. autoclass:: ScheduleInterleaved1F1B +``` + +```{eval-rst} +.. autoclass:: ScheduleLoopedBFS +``` + +```{eval-rst} +.. autoclass:: ScheduleInterleavedZeroBubble +``` + +```{eval-rst} +.. autoclass:: ScheduleZBVZeroBubble +``` + +```{eval-rst} +.. autoclass:: PipelineScheduleSingle + :members: +``` + +```{eval-rst} +.. autoclass:: PipelineScheduleMulti + :members: +``` diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst deleted file mode 100644 index 77aa8da7784a4a..00000000000000 --- a/docs/source/distributed.pipelining.rst +++ /dev/null @@ -1,491 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Pipeline Parallelism -#################### - -.. note:: - ``torch.distributed.pipelining`` is currently in alpha state and under - development. API changes may be possible. It was migrated from the `PiPPy - `_ project. - - -Why Pipeline Parallel? -********************** - -Pipeline Parallelism is one of the **primitive** parallelism for deep learning. -It allows the **execution** of a model to be partitioned such that multiple -**micro-batches** can execute different parts of the model code concurrently. -Pipeline parallelism can be an effective technique for: - -* large-scale training -* bandwidth-limited clusters -* large model inference - -The above scenarios share a commonality that the computation per device cannot -hide the communication of conventional parallelism, for example, the weight -all-gather of FSDP. - - -What is ``torch.distributed.pipelining``? -***************************************** - -While promising for scaling, pipelining is often difficult to implement because -it needs to **partition the execution** of a model in addition to model weights. -The partitioning of execution often requires intrusive code changes to your -model. Another aspect of complexity comes from **scheduling micro-batches in a -distributed environment**, with **data flow dependency** considered. - -The ``pipelining`` package provides a toolkit that does said things -**automatically** which allows easy implementation of pipeline parallelism -on **general** models. - -It consists of two parts: a -**splitting frontend** and a **distributed runtime**. -The splitting frontend takes your model code as-is, splits it up into "model -partitions", and captures the data-flow relationship. The distributed runtime -executes the pipeline stages on different devices in parallel, handling things -like micro-batch splitting, scheduling, communication, and gradient propagation, -etc. - -Overall, the ``pipelining`` package provides the following features: - -* Splitting of model code based on simple specification. -* Rich support for pipeline schedules, including GPipe, 1F1B, - Interleaved 1F1B and Looped BFS, and providing the infrastructure for writing - customized schedules. -* First-class support for cross-host pipeline parallelism, as this is where PP - is typically used (over slower interconnects). -* Composability with other PyTorch parallel techniques such as data parallel - (DDP, FSDP) or tensor parallel. The `TorchTitan - `_ project demonstrates a "3D parallel" - application on the Llama model. - - -Step 1: build ``PipelineStage`` -******************************* - -Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage`` -objects that wrap the part of the model running in that stage. The -``PipelineStage`` is responsible for allocating communication buffers and -creating send/recv ops to communicate with its peers. It manages intermediate -buffers e.g. for the outputs of forward that have not been consumed yet, and it -provides a utility for running the backwards for the stage model. - -A ``PipelineStage`` needs to know the input and output shapes for the stage -model, so that it can correctly allocate communication buffers. The shapes must -be static, e.g. at runtime the shapes can not change from step to step. A class -``PipeliningShapeError`` will be raised if runtime shapes do not match the -expected shapes. When composing with other paralleisms or applying mixed -precision, these techniques must be taken into account so the ``PipelineStage`` -knows the correct shape (and dtype) for the output of the stage module at -runtime. - -Users may construct a ``PipelineStage`` instance directly, by passing in an -``nn.Module`` representing the portion of the model that should run on the -stage. This may require changes to the original model code. See the example -in :ref:`option_1_manual`. - -Alternatively, the splitting frontend can use graph partitioning to split your -model into a series of ``nn.Module`` automatically. This technique requires the -model is traceable with ``torch.Export``. Composability of the resulting -``nn.Module`` with other parallelism techniques is experimental, and may require -some workarounds. Usage of this frontend may be more appealing if the user -cannot easily change the model code. See :ref:`option_2_tracer` for more -information. - - -Step 2: use ``PipelineSchedule`` for execution -********************************************** - -We can now attach the ``PipelineStage`` to a pipeline schedule, and run the -schedule with input data. Here is a GPipe example: - -.. code-block:: python - - from torch.distributed.pipelining import ScheduleGPipe - - # Create a schedule - schedule = ScheduleGPipe(stage, n_microbatches) - - # Input data (whole batch) - x = torch.randn(batch_size, in_dim, device=device) - - # Run the pipeline with input `x` - # `x` will be divided into microbatches automatically - if rank == 0: - schedule.step(x) - else: - output = schedule.step() - -Note that the above code needs to be launched for each worker, thus we use a -launcher service to launch multiple processes: - -.. code-block:: bash - - torchrun --nproc_per_node=2 example.py - - -Options for Splitting a Model -***************************** - -.. _option_1_manual: - -Option 1: splitting a model manually -==================================== - -To directly construct a ``PipelineStage``, the user is responsible for providing -a single ``nn.Module`` instance that owns the relevant ``nn.Parameters`` and -``nn.Buffers``, and defines a ``forward()`` method that executes the operations -relevant for that stage. For example, a condensed version of the Transformer -class defined in Torchtitan shows a pattern of building an easily partitionable -model. - -.. code-block:: python - - class Transformer(nn.Module): - def __init__(self, model_args: ModelArgs): - super().__init__() - - self.tok_embeddings = nn.Embedding(...) - - # Using a ModuleDict lets us delete layers without affecting names, - # ensuring checkpoints will correctly save and load. - self.layers = torch.nn.ModuleDict() - for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(...) - - self.output = nn.Linear(...) - - def forward(self, tokens: torch.Tensor): - # Handling layers being 'None' at runtime enables easy pipeline splitting - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - output = self.output(h).float() if self.output else h - return output - -A model defined in this manner can be easily configured per stage by first -initializing the whole model (using meta-device to avoid OOM errors), deleting -undesired layers for that stage, and then creating a PipelineStage that wraps -the model. For example: - -.. code-block:: python - - with torch.device("meta"): - assert num_stages == 2, "This is a simple 2-stage example" - - # we construct the entire model, then delete the parts we do not need for this stage - # in practice, this can be done using a helper function that automatically divides up layers across stages. - model = Transformer() - - if stage_index == 0: - # prepare the first stage model - del model.layers["1"] - model.norm = None - model.output = None - - elif stage_index == 1: - # prepare the second stage model - model.tok_embeddings = None - del model.layers["0"] - - from torch.distributed.pipelining import PipelineStage - stage = PipelineStage( - model, - stage_index, - num_stages, - device, - ) - -When composing with other Data or Model parallelism techniques, ``output_args`` -may also be required, if the output shape/dtype of the model chunk will be -affected. - - -.. _option_2_tracer: - -Option 2: splitting a model automatically -========================================= - -If you have a full model and do not want to spend time on modifying it into a -sequence of "model partitions", the ``pipeline`` API is here to help. -Here is a brief example: - -.. code-block:: python - - class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.emb = torch.nn.Embedding(10, 3) - self.layers = torch.nn.ModuleList( - Layer() for _ in range(2) - ) - self.lm = LMHead() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.emb(x) - for layer in self.layers: - x = layer(x) - x = self.lm(x) - return x - - -If we print the model, we can see multiple hierarchies, which makes it hard to split by hand:: - - Model( - (emb): Embedding(10, 3) - (layers): ModuleList( - (0-1): 2 x Layer( - (lin): Linear(in_features=3, out_features=3, bias=True) - ) - ) - (lm): LMHead( - (proj): Linear(in_features=3, out_features=3, bias=True) - ) - ) - -Let us see how the ``pipeline`` API works: - -.. code-block:: python - - from torch.distributed.pipelining import pipeline, SplitPoint - - # An example micro-batch input - x = torch.LongTensor([1, 2, 4, 5]) - - pipe = pipeline( - module=mod, - mb_args=(x,), - split_spec={ - "layers.1": SplitPoint.BEGINNING, - } - ) - -The ``pipeline`` API splits your model given a ``split_spec``, where -``SplitPoint.BEGINNING`` stands for adding a split point -*before* execution of certain submodule in the ``forward`` function, and -similarly, ``SplitPoint.END`` for split point *after* such. - -If we ``print(pipe)``, we can see:: - - GraphModule( - (submod_0): GraphModule( - (emb): InterpreterModule() - (layers): Module( - (0): InterpreterModule( - (lin): InterpreterModule() - ) - ) - ) - (submod_1): GraphModule( - (layers): Module( - (1): InterpreterModule( - (lin): InterpreterModule() - ) - ) - (lm): InterpreterModule( - (proj): InterpreterModule() - ) - ) - ) - - def forward(self, x): - submod_0 = self.submod_0(x); x = None - submod_1 = self.submod_1(submod_0); submod_0 = None - return (submod_1,) - - -The "model partitions" are represented by submodules (``submod_0``, -``submod_1``), each of which is reconstructed with original model operations, weights -and hierarchies. In addition, a "root-level" ``forward`` function is -reconstructed to capture the data flow between those partitions. Such data flow -will be replayed by the pipeline runtime later, in a distributed fashion. - -The ``Pipe`` object provides a method for retrieving the "model partitions": - -.. code-block:: python - - stage_mod : nn.Module = pipe.get_stage_module(stage_idx) - -The returned ``stage_mod`` is a ``nn.Module``, with which you can create an -optimizer, save or load checkpoints, or apply other parallelisms. - -``Pipe`` also allows you to create a distributed stage runtime on a device given -a ``ProcessGroup``: - -.. code-block:: python - - stage = pipe.build_stage(stage_idx, device, group) - -Alternatively, if you would like to build the stage runtime later after some -modification to the ``stage_mod``, you can use a functional version of the -``build_stage`` API. For example: - -.. code-block:: python - - from torch.distributed.pipelining import build_stage - from torch.nn.parallel import DistributedDataParallel - - dp_mod = DistributedDataParallel(stage_mod) - info = pipe.info() - stage = build_stage(dp_mod, stage_idx, info, device, group) - -.. note:: - The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your - model into a single graph. If your model is not full-graph'able, you can use - our manual frontend below. - - -Hugging Face Examples -********************* - -In the `PiPPy `_ repo where this package was -original created, we kept examples based on unmodified Hugging Face models. -See the `examples/huggingface -`_ directory. - -Examples include: - -* `GPT2 `_ -* `Llama `_ - - -Technical Deep Dive -******************* - -How does the ``pipeline`` API split a model? -============================================ - -First, the ``pipeline`` API turns our model into a directed acyclic graph (DAG) -by tracing the model. It traces the model using ``torch.export`` -- a PyTorch 2 -full-graph capturing tool. - -Then, it groups together the **operations and parameters** needed by a stage -into a reconstructed submodule: ``submod_0``, ``submod_1``, ... - -Different from conventional submodule access methods like ``Module.children()``, -the ``pipeline`` API does not only cut the module structure of your model, but -also the **forward** function of your model. - -This is necessary because model structure like ``Module.children()`` merely -captures information during ``Module.__init__()``, and does not capture any -information about ``Module.forward()``. Said differently, ``Module.children()`` -lacks information about the following aspects key to pipelininig: - -* Execution order of child modules in ``forward`` -* Activation flows between child modules -* Whether there are any functional operators between child modules (for example, - ``relu`` or ``add`` operations will not be captured by ``Module.children()``). - -The ``pipeline`` API, on the contrary, makes sure that the ``forward`` behavior -is truly preserved. It also captures the activation flow between the partitions, -helping the distributed runtime to make correct send/receive calls without human -intervention. - -Another flexibility of the ``pipeline`` API is that split points can be at -arbitrary levels within your model hierarchy. In the split partitions, the original model -hierarchy related to that partition will be reconstructed at no cost to you. -At a result, fully-qualified names (FQNs) pointing to a submodule or parameter -would be still valid, and services that relies on FQNs (such as FSDP, TP or -checkpointing) can still run with your partitioned modules with almost zero code -change. - - -Implementing Your Own Schedule -****************************** - -You can implement your own pipeline schedule by extending one of the following two class: - -* ``PipelineScheduleSingle`` -* ``PipelineScheduleMulti`` - -``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. -``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. - -For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. -Whereas, ``ScheduleInterleaved1F1B``, ``ScheduleLoopedBFS``, ``ScheduleInterleavedZeroBubble``, and ``ScheduleZBVZeroBubble`` -are subclasses of ``PipelineScheduleMulti``. - - -Logging -******* - -You can turn on additional logging using the `TORCH_LOGS` environment variable from `torch._logging `_: - -* `TORCH_LOGS=+pp` will display `logging.DEBUG` messages and all levels above it. -* `TORCH_LOGS=pp` will display `logging.INFO` messages and above. -* `TORCH_LOGS=-pp` will display `logging.WARNING` messages and above. - - -API Reference -************* - -.. automodule:: torch.distributed.pipelining - -Model Split APIs -============================ - -The following set of APIs transform your model into a pipeline representation. - -.. currentmodule:: torch.distributed.pipelining - -.. autoclass:: SplitPoint - -.. autofunction:: pipeline - -.. autoclass:: Pipe - -.. autofunction:: pipe_split - -Microbatch Utilities -==================== - -.. automodule:: torch.distributed.pipelining.microbatch - -.. currentmodule:: torch.distributed.pipelining.microbatch - -.. autoclass:: TensorChunkSpec - -.. autofunction:: split_args_kwargs_into_chunks - -.. autofunction:: merge_chunks - -Pipeline Stages -=============== - -.. automodule:: torch.distributed.pipelining.stage - -.. currentmodule:: torch.distributed.pipelining.stage - -.. autoclass:: PipelineStage - -.. autofunction:: build_stage - -Pipeline Schedules -================== - -.. automodule:: torch.distributed.pipelining.schedules - -.. currentmodule:: torch.distributed.pipelining.schedules - -.. autoclass:: ScheduleGPipe - -.. autoclass:: Schedule1F1B - -.. autoclass:: ScheduleInterleaved1F1B - -.. autoclass:: ScheduleLoopedBFS - -.. autoclass:: ScheduleInterleavedZeroBubble - -.. autoclass:: ScheduleZBVZeroBubble - -.. autoclass:: PipelineScheduleSingle - :members: - -.. autoclass:: PipelineScheduleMulti - :members: diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst deleted file mode 100644 index f36f6218dac0a1..00000000000000 --- a/docs/source/distributed.rst +++ /dev/null @@ -1,1045 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Distributed communication package - torch.distributed -===================================================== - -.. note :: - Please refer to `PyTorch Distributed Overview `__ - for a brief introduction to all features related to distributed training. - -.. automodule:: torch.distributed -.. currentmodule:: torch.distributed - -Backends --------- - -``torch.distributed`` supports three built-in backends, each with -different capabilities. The table below shows which functions are available -for use with CPU / CUDA tensors. -MPI supports CUDA only if the implementation used to build PyTorch supports it. - - -+----------------+-----------+-----------+-----------+ -| Backend | ``gloo`` | ``mpi`` | ``nccl`` | -+----------------+-----+-----+-----+-----+-----+-----+ -| Device | CPU | GPU | CPU | GPU | CPU | GPU | -+================+=====+=====+=====+=====+=====+=====+ -| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ - - -Backends that come with PyTorch -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). -By default for Linux, the Gloo and NCCL backends are built and included in PyTorch -distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be -included if you build PyTorch from source. (e.g. building PyTorch on a host that has MPI -installed.) - -.. note :: - As of PyTorch v1.8, Windows supports all collective communications backend but NCCL, - If the `init_method` argument of :func:`init_process_group` points to a file it must adhere - to the following schema: - - - Local file system, ``init_method="file:///d:/tmp/some_file"`` - - Shared file system, ``init_method="file://////{machine_name}/{share_folder_name}/some_file"`` - - Same as on Linux platform, you can enable TcpStore by setting environment variables, - MASTER_ADDR and MASTER_PORT. - -Which backend to use? -^^^^^^^^^^^^^^^^^^^^^ - -In the past, we were often asked: "which backend should I use?". - -- Rule of thumb - - - Use the NCCL backend for distributed **GPU** training - - Use the Gloo backend for distributed **CPU** training. - -- GPU hosts with InfiniBand interconnect - - - Use NCCL, since it's the only backend that currently supports - InfiniBand and GPUDirect. - -- GPU hosts with Ethernet interconnect - - - Use NCCL, since it currently provides the best distributed GPU - training performance, especially for multiprocess single-node or - multi-node distributed training. If you encounter any problem with - NCCL, use Gloo as the fallback option. (Note that Gloo currently - runs slower than NCCL for GPUs.) - -- CPU hosts with InfiniBand interconnect - - - If your InfiniBand has enabled IP over IB, use Gloo, otherwise, - use MPI instead. We are planning on adding InfiniBand support for - Gloo in the upcoming releases. - -- CPU hosts with Ethernet interconnect - - - Use Gloo, unless you have specific reasons to use MPI. - -Common environment variables -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Choosing the network interface to use -""""""""""""""""""""""""""""""""""""" - -By default, both the NCCL and Gloo backends will try to find the right network interface to use. -If the automatically detected interface is not correct, you can override it using the following -environment variables (applicable to the respective backend): - -* **NCCL_SOCKET_IFNAME**, for example ``export NCCL_SOCKET_IFNAME=eth0`` -* **GLOO_SOCKET_IFNAME**, for example ``export GLOO_SOCKET_IFNAME=eth0`` - -If you're using the Gloo backend, you can specify multiple interfaces by separating -them by a comma, like this: ``export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3``. -The backend will dispatch operations in a round-robin fashion across these interfaces. -It is imperative that all processes specify the same number of interfaces in this variable. - -Other NCCL environment variables -"""""""""""""""""""""""""""""""" - -**Debugging** - in case of NCCL failure, you can set ``NCCL_DEBUG=INFO`` to print an explicit -warning message as well as basic NCCL initialization information. - -You may also use ``NCCL_DEBUG_SUBSYS`` to get more details about a specific -aspect of NCCL. For example, ``NCCL_DEBUG_SUBSYS=COLL`` would print logs of -collective calls, which may be helpful when debugging hangs, especially those -caused by collective type or message size mismatch. In case of topology -detection failure, it would be helpful to set ``NCCL_DEBUG_SUBSYS=GRAPH`` -to inspect the detailed detection result and save as reference if further help -from NCCL team is needed. - -**Performance tuning** - NCCL performs automatic tuning based on its topology detection to save users' -tuning effort. On some socket-based systems, users may still try tuning -``NCCL_SOCKET_NTHREADS`` and ``NCCL_NSOCKS_PERTHREAD`` to increase socket -network bandwidth. These two environment variables have been pre-tuned by NCCL -for some cloud providers, such as AWS or GCP. - -For a full list of NCCL environment variables, please refer to -`NVIDIA NCCL's official documentation `_ - -You can tune NCCL communicators even further using `torch.distributed.ProcessGroupNCCL.NCCLConfig` -and `torch.distributed.ProcessGroupNCCL.Options`. Learn more about them using `help` -(e.g. `help(torch.distributed.ProcessGroupNCCL.NCCLConfig)`) in the interpreter. - -.. _distributed-basics: - -Basics ------- - -The `torch.distributed` package provides PyTorch support and communication primitives -for multiprocess parallelism across several computation nodes running on one or more -machines. The class :func:`torch.nn.parallel.DistributedDataParallel` builds on this -functionality to provide synchronous distributed training as a wrapper around any -PyTorch model. This differs from the kinds of parallelism provided by -:doc:`multiprocessing` and :func:`torch.nn.DataParallel` in that it supports -multiple network-connected machines and in that the user must explicitly launch a separate -copy of the main training script for each process. - -In the single-machine synchronous case, `torch.distributed` or the -:func:`torch.nn.parallel.DistributedDataParallel` wrapper may still have advantages over other -approaches to data-parallelism, including :func:`torch.nn.DataParallel`: - -* Each process maintains its own optimizer and performs a complete optimization step with each - iteration. While this may appear redundant, since the gradients have already been gathered - together and averaged across processes and are thus the same for every process, this means - that no parameter broadcast step is needed, reducing time spent transferring tensors between - nodes. -* Each process contains an independent Python interpreter, eliminating the extra interpreter - overhead and "GIL-thrashing" that comes from driving several execution threads, model - replicas, or GPUs from a single Python process. This is especially important for models that - make heavy use of the Python runtime, including models with recurrent layers or many small - components. - -Initialization --------------- - -The package needs to be initialized using the :func:`torch.distributed.init_process_group` -or :func:`torch.distributed.device_mesh.init_device_mesh` function before calling any other methods. -Both block until all processes have joined. - -.. warning:: - Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent - inconsistent 'UUID' assignment across ranks, and to prevent races during initialization that can lead to hangs. - -.. autofunction:: is_available - -.. autofunction:: init_process_group - -.. autofunction:: torch.distributed.device_mesh.init_device_mesh - -.. autofunction:: is_initialized - -.. autofunction:: is_mpi_available - -.. autofunction:: is_nccl_available - -.. autofunction:: is_gloo_available - -.. autofunction:: torch.distributed.distributed_c10d.is_xccl_available - -.. autofunction:: is_torchelastic_launched - --------------------------------------------------------------------------------- - -Currently three initialization methods are supported: - -TCP initialization -^^^^^^^^^^^^^^^^^^ - -There are two ways to initialize using TCP, both requiring a network address -reachable from all processes and a desired ``world_size``. The first way -requires specifying an address that belongs to the rank 0 process. This -initialization method requires that all processes have manually specified ranks. - -Note that multicast address is not supported anymore in the latest distributed -package. ``group_name`` is deprecated as well. - -:: - - import torch.distributed as dist - - # Use address of one of the machines - dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', - rank=args.rank, world_size=4) - -Shared file-system initialization -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Another initialization method makes use of a file system that is shared and -visible from all machines in a group, along with a desired ``world_size``. The URL should start -with ``file://`` and contain a path to a non-existent file (in an existing -directory) on a shared file system. File-system initialization will automatically -create that file if it doesn't exist, but will not delete the file. Therefore, it -is your responsibility to make sure that the file is cleaned up before the next -:func:`init_process_group` call on the same file path/name. - -Note that automatic rank assignment is not supported anymore in the latest -distributed package and ``group_name`` is deprecated as well. - -.. warning:: - This method assumes that the file system supports locking using ``fcntl`` - most - local systems and NFS support it. - -.. warning:: - This method will always create the file and try its best to clean up and remove - the file at the end of the program. In other words, each initialization with - the file init method will need a brand new empty file in order for the initialization - to succeed. If the same file used by the previous initialization (which happens not - to get cleaned up) is used again, this is unexpected behavior and can often cause - deadlocks and failures. Therefore, even though this method will try its best to clean up - the file, if the auto-delete happens to be unsuccessful, it is your responsibility - to ensure that the file is removed at the end of the training to prevent the same - file to be reused again during the next time. This is especially important - if you plan to call :func:`init_process_group` multiple times on the same file name. - In other words, if the file is not removed/cleaned up and you call - :func:`init_process_group` again on that file, failures are expected. - The rule of thumb here is that, make sure that the file is non-existent or - empty every time :func:`init_process_group` is called. - -:: - - import torch.distributed as dist - - # rank should always be specified - dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile', - world_size=4, rank=args.rank) - -Environment variable initialization -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This method will read the configuration from environment variables, allowing -one to fully customize how the information is obtained. The variables to be set -are: - -* ``MASTER_PORT`` - required; has to be a free port on machine with rank 0 -* ``MASTER_ADDR`` - required (except for rank 0); address of rank 0 node -* ``WORLD_SIZE`` - required; can be set either here, or in a call to init function -* ``RANK`` - required; can be set either here, or in a call to init function - -The machine with rank 0 will be used to set up all connections. - -This is the default method, meaning that ``init_method`` does not have to be specified (or -can be ``env://``). - -Improving initialization time -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -* ``TORCH_GLOO_LAZY_INIT`` - establishes connections on demand rather than - using a full mesh which can greatly improve initialization time for non all2all - operations. - -Post-Initialization -------------------- - -Once :func:`torch.distributed.init_process_group` was run, the following functions can be used. To -check whether the process group has already been initialized use :func:`torch.distributed.is_initialized`. - -.. autoclass:: Backend - :members: - -.. autofunction:: get_backend - -.. autofunction:: get_rank - -.. autofunction:: get_world_size - -Shutdown --------- - -It is important to clean up resources on exit by calling :func:`destroy_process_group`. - -The simplest pattern to follow is to destroy every process group and backend by calling -:func:`destroy_process_group()` with the default value of None for the `group` argument, at a -point in the training script where communications are no longer needed, usually near the -end of main(). The call should be made once per trainer-process, not at the outer -process-launcher level. - -if :func:`destroy_process_group` is not called by all ranks in a pg within the timeout duration, -especially when there are multiple process-groups in the application e.g. for N-D parallelism, -hangs on exit are possible. This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, -which must be called collectively, but the order of calling ProcessGroupNCCL's destructor if called -by python's GC is not deterministic. Calling :func:`destroy_process_group` helps by ensuring -ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort -during ProcessGroupNCCL's destructor. - -Reinitialization -^^^^^^^^^^^^^^^^ - -`destroy_process_group` can also be used to destroy individual process groups. One use -case could be fault tolerant training, where a process group may be destroyed and then -a new one initialized during runtime. In this case, it's critical to synchronize the trainer -processes using some means other than torch.distributed primitives _after_ calling destroy and -before subsequently initializing. This behavior is currently unsupported/untested, due to -the difficulty of achieving this synchronization, and is considered a known issue. Please file -a github issue or RFC if this is a use case that's blocking you. - --------------------------------------------------------------------------------- - -Groups ------- - -By default collectives operate on the default group (also called the world) and -require all processes to enter the distributed function call. However, some workloads can benefit -from more fine-grained communication. This is where distributed groups come -into play. :func:`~torch.distributed.new_group` function can be -used to create new groups, with arbitrary subsets of all processes. It returns -an opaque group handle that can be given as a ``group`` argument to all collectives -(collectives are distributed functions to exchange information in certain well-known programming patterns). - - -.. autofunction:: new_group - -.. autofunction:: get_group_rank - -.. autofunction:: get_global_rank - -.. autofunction:: get_process_group_ranks - - -DeviceMesh ----------- - -DeviceMesh is a higher level abstraction that manages process groups (or NCCL communicators). -It allows user to easily create inter node and intra node process groups without worrying about -how to set up the ranks correctly for different sub process groups, and it helps manage those -distributed process group easily. :func:`~torch.distributed.device_mesh.init_device_mesh` function can be -used to create new DeviceMesh, with a mesh shape describing the device topology. - -.. autoclass:: torch.distributed.device_mesh.DeviceMesh - :members: - -Point-to-point communication ----------------------------- - -.. autofunction:: send - -.. autofunction:: recv - -:func:`~torch.distributed.isend` and :func:`~torch.distributed.irecv` -return distributed request objects when used. In general, the type of this object is unspecified -as they should never be created manually, but they are guaranteed to support two methods: - -* ``is_completed()`` - returns True if the operation has finished -* ``wait()`` - will block the process until the operation is finished. - ``is_completed()`` is guaranteed to return True once it returns. - -.. autofunction:: isend - -.. autofunction:: irecv - -.. autofunction:: send_object_list - -.. autofunction:: recv_object_list - -.. autofunction:: batch_isend_irecv - -.. autoclass:: P2POp - -Synchronous and asynchronous collective operations --------------------------------------------------- -Every collective operation function supports the following two kinds of operations, -depending on the setting of the ``async_op`` flag passed into the collective: - -**Synchronous operation** - the default mode, when ``async_op`` is set to ``False``. -When the function returns, it is guaranteed that -the collective operation is performed. In the case of CUDA operations, it is not guaranteed -that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any -further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, -function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of -synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream -synchronization, see `CUDA Semantics `__. -See the below script to see examples of differences in these semantics for CPU and CUDA operations. - -**Asynchronous operation** - when ``async_op`` is set to True. The collective operation function -returns a distributed request object. In general, you don't need to create it manually and it -is guaranteed to support two methods: - -* ``is_completed()`` - in the case of CPU collectives, returns ``True`` if completed. In the case of CUDA operations, - returns ``True`` if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the - default stream without further synchronization. -* ``wait()`` - in the case of CPU collectives, will block the process until the operation is completed. In the case - of CUDA collectives, will block the currently active CUDA stream until the operation is completed (but will not block the CPU). -* ``get_future()`` - returns ``torch._C.Future`` object. Supported for NCCL, also supported for most operations on GLOO - and MPI, except for peer to peer operations. - Note: as we continue adopting Futures and merging APIs, ``get_future()`` call might become redundant. - -**Example** - -The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. -It shows the explicit need to synchronize when using collective outputs on different CUDA streams: - -:: - - # Code runs on each rank. - dist.init_process_group("nccl", rank=rank, world_size=2) - output = torch.tensor([rank]).cuda(rank) - s = torch.cuda.Stream() - handle = dist.all_reduce(output, async_op=True) - # Wait ensures the operation is enqueued, but not necessarily complete. - handle.wait() - # Using result on non-default stream. - with torch.cuda.stream(s): - s.wait_stream(torch.cuda.default_stream()) - output.add_(100) - if rank == 0: - # if the explicit call to wait_stream was omitted, the output below will be - # non-deterministically 1 or 101, depending on whether the allreduce overwrote - # the value after the add completed. - print(output) - - -Collective functions --------------------- - -.. autofunction:: broadcast - -.. autofunction:: broadcast_object_list - -.. autofunction:: all_reduce - -.. autofunction:: reduce - -.. autofunction:: all_gather - -.. autofunction:: all_gather_into_tensor - -.. autofunction:: all_gather_object - -.. autofunction:: gather - -.. autofunction:: gather_object - -.. autofunction:: scatter - -.. autofunction:: scatter_object_list - -.. autofunction:: reduce_scatter - -.. autofunction:: reduce_scatter_tensor - -.. autofunction:: all_to_all_single - -.. autofunction:: all_to_all - -.. autofunction:: barrier - -.. autofunction:: monitored_barrier - -.. autoclass:: Work - :members: - -.. autoclass:: ReduceOp - -.. class:: reduce_op - - Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``, - ``MIN``, and ``MAX``. - - :class:`~torch.distributed.ReduceOp` is recommended to use instead. - - -Distributed Key-Value Store ---------------------------- - -The distributed package comes with a distributed key-value store, which can be -used to share information between processes in the group as well as to -initialize the distributed package in -:func:`torch.distributed.init_process_group` (by explicitly creating the store -as an alternative to specifying ``init_method``.) There are 3 choices for -Key-Value Stores: :class:`~torch.distributed.TCPStore`, -:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`. - -.. autoclass:: Store - :members: - :special-members: - -.. autoclass:: TCPStore - :members: - :special-members: __init__ - -.. autoclass:: HashStore - :members: - :special-members: __init__ - -.. autoclass:: FileStore - :members: - :special-members: __init__ - -.. autoclass:: PrefixStore - :members: - :special-members: __init__ - - -Profiling Collective Communication ------------------------------------------ - -Note that you can use ``torch.profiler`` (recommended, only available after 1.8.1) or ``torch.autograd.profiler`` to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (``gloo``, -``nccl``, ``mpi``) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator: - -:: - - import torch - import torch.distributed as dist - with torch.profiler(): - tensor = torch.randn(20, 10) - dist.all_reduce(tensor) - -Please refer to the `profiler documentation `__ for a full overview of profiler features. - - -Multi-GPU collective functions ------------------------------- - -.. warning:: - The multi-GPU functions (which stand for multiple GPUs per CPU thread) are - deprecated. As of today, PyTorch Distributed's preferred programming model - is one device per thread, as exemplified by the APIs in this document. If - you are a backend developer and want to support multiple devices per thread, - please contact PyTorch Distributed's maintainers. - - -.. _object_collectives: - -Object collectives ------------------- - -.. warning:: - Object collectives have a number of serious limitations. Read further to determine - if they are safe to use for your use case. - -Object collectives are a set of collective-like operations that work on arbitrary -Python objects, as long as they can be pickled. There are various collective patterns -implemented (e.g. broadcast, all_gather, ...) but they each roughly follow this pattern: - -1. convert the input object into a pickle (raw bytes), then shove it into a byte tensor -2. communicate the size of this byte tensor to peers (first collective operation) -3. allocate appropriately sized tensor to perform the real collective -4. communicate the object data (second collective operation) -5. convert raw data back into Python (unpickle) - -Object collectives sometimes have surprising performance or memory characteristics that lead to -long runtimes or OOMs, and thus they should be used with caution. Here are some common issues. - -**Asymmetric pickle/unpickle time** - Pickling objects can be slow, depending on the number, type and size of the objects. -When the collective has a fan-in (e.g. gather_object), the receiving rank(s) must unpickle N times more objects than -the sending rank(s) had to pickle, which can cause other ranks to time out on their next collective. - -**Inefficient tensor communication** - Tensors should be sent via regular collective APIs, not object collective APIs. -It is possible to send Tensors via object collective APIs, but they will be serialized and deserialized (including a -CPU-sync and device-to-host copy in the case of non-CPU tensors), and in almost every case other than debugging or -troubleshooting code, it would be worth the trouble to refactor the code to use non-object collectives instead. - -**Unexpected tensor devices** - If you still want to send tensors via object collectives, there is another aspect -specific to cuda (and possibly other accelerators) tensors. If you pickle a tensor that is currently on `cuda:3`, and -then unpickle it, you will get another tensor on `cuda:3` *regardless of which process you are on, or which CUDA device -is the 'default' device for that process*. With regular tensor collective APIs, 'output tensors' will always be on the -same, local device, which is generally what you'd expect. - -Unpickling a tensor will implicitly activate a CUDA context if it is the first -time a GPU is used by the process, which can waste significant amounts of GPU memory. This issue can be avoided by -moving tensors to CPU before passing them as inputs to an object collective. - -Third-party backends --------------------- - -Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports -third-party backends through a run-time register mechanism. -For references on how to develop a third-party backend through C++ Extension, -please refer to `Tutorials - Custom C++ and CUDA Extensions `_ and -``test/cpp_extensions/cpp_c10d_extension.cpp``. The capability of third-party -backends are decided by their own implementations. - -The new backend derives from ``c10d::ProcessGroup`` and registers the backend -name and the instantiating interface through :func:`torch.distributed.Backend.register_backend` -when imported. - -When manually importing this backend and invoking :func:`torch.distributed.init_process_group` -with the corresponding backend name, the ``torch.distributed`` package runs on -the new backend. - -.. warning:: - The support of third-party backend is experimental and subject to change. - -.. _distributed-launch: - -Launch utility --------------- - -The `torch.distributed` package also provides a launch utility in -`torch.distributed.launch`. This helper utility can be used to launch -multiple processes per node for distributed training. - - -.. automodule:: torch.distributed.launch - - -Spawn utility -------------- - -The :ref:`multiprocessing-doc` package also provides a ``spawn`` -function in :func:`torch.multiprocessing.spawn`. This helper function -can be used to spawn multiple processes. It works by passing in the -function that you want to run and spawns N processes to run it. This -can be used for multiprocess distributed training as well. - -For references on how to use it, please refer to `PyTorch example - ImageNet -implementation `_ - -Note that this function requires Python 3.4 or higher. - -Debugging ``torch.distributed`` applications ------------------------------------------------------- - -Debugging distributed applications can be challenging due to hard to understand hangs, crashes, or inconsistent behavior across ranks. ``torch.distributed`` provides -a suite of tools to help debug training applications in a self-serve fashion: - -Python Breakpoint -^^^^^^^^^^^^^^^^^ - -It is extremely convenient to use python's debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. -PyTorch offers a customized wrapper around pdb that streamlines the process. - -`torch.distributed.breakpoint` makes this process easy. Internally, it customizes `pdb`'s breakpoint behavior in two ways but otherwise behaves as normal `pdb`. -1. Attaches the debugger only on one rank (specified by the user). -2. Ensures all other ranks stop, by using a `torch.distributed.barrier()` that will release once the debugged rank issues a `continue` -3. Reroutes stdin from the child process such that it connects to your terminal. - -To use it, simply issue `torch.distributed.breakpoint(rank)` on all ranks, using the same value for `rank` in each case. - -Monitored Barrier -^^^^^^^^^^^^^^^^^ - -As of v1.10, :func:`torch.distributed.monitored_barrier` exists as an alternative to :func:`torch.distributed.barrier` which fails with helpful information about which rank may be faulty -when crashing, i.e. not all ranks calling into :func:`torch.distributed.monitored_barrier` within the provided timeout. :func:`torch.distributed.monitored_barrier` implements a host-side -barrier using ``send``/``recv`` communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge -the barrier in time. As an example, consider the following function where rank 1 fails to call into :func:`torch.distributed.monitored_barrier` (in practice this could be due -to an application bug or hang in a previous collective): - -:: - - import os - from datetime import timedelta - - import torch - import torch.distributed as dist - import torch.multiprocessing as mp - - - def worker(rank): - dist.init_process_group("nccl", rank=rank, world_size=2) - # monitored barrier requires gloo process group to perform host-side sync. - group_gloo = dist.new_group(backend="gloo") - if rank not in [1]: - dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2)) - - - if __name__ == "__main__": - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" - mp.spawn(worker, nprocs=2, args=()) - -The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further: - -:: - - RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms - Original exception: - [gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594 - - -``TORCH_DISTRIBUTED_DEBUG`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -With ``TORCH_CPP_LOG_LEVEL=INFO``, the environment variable ``TORCH_DISTRIBUTED_DEBUG`` can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks -are synchronized appropriately. ``TORCH_DISTRIBUTED_DEBUG`` can be set to either ``OFF`` (default), ``INFO``, or ``DETAIL`` depending on the debugging level -required. Please note that the most verbose option, ``DETAIL`` may impact the application performance and thus should only be used when debugging issues. - -Setting ``TORCH_DISTRIBUTED_DEBUG=INFO`` will result in additional debug logging when models trained with :func:`torch.nn.parallel.DistributedDataParallel` are initialized, and -``TORCH_DISTRIBUTED_DEBUG=DETAIL`` will additionally log runtime performance statistics a select number of iterations. These runtime statistics -include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application: - -:: - - import os - - import torch - import torch.distributed as dist - import torch.multiprocessing as mp - - - class TwoLinLayerNet(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(10, 10, bias=False) - self.b = torch.nn.Linear(10, 1, bias=False) - - def forward(self, x): - a = self.a(x) - b = self.b(x) - return (a, b) - - - def worker(rank): - dist.init_process_group("nccl", rank=rank, world_size=2) - torch.cuda.set_device(rank) - print("init model") - model = TwoLinLayerNet().cuda() - print("init ddp") - ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) - - inp = torch.randn(10, 10).cuda() - print("train") - - for _ in range(20): - output = ddp_model(inp) - loss = output[0] + output[1] - loss.sum().backward() - - - if __name__ == "__main__": - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" - os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" - os.environ[ - "TORCH_DISTRIBUTED_DEBUG" - ] = "DETAIL" # set to DETAIL for runtime logging. - mp.spawn(worker, nprocs=2, args=()) - -The following logs are rendered at initialization time: - -:: - - I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with: - broadcast_buffers: 1 - bucket_cap_bytes: 26214400 - find_unused_parameters: 0 - gradient_as_bucket_view: 0 - is_multi_device_module: 0 - iteration: 0 - num_parameter_tensors: 2 - output_device: 0 - rank: 0 - total_parameter_size_bytes: 440 - world_size: 2 - backend_name: nccl - bucket_sizes: 440 - cuda_visible_devices: N/A - device_ids: 0 - dtypes: float - master_addr: localhost - master_port: 29501 - module_name: TwoLinLayerNet - nccl_async_error_handling: N/A - nccl_blocking_wait: N/A - nccl_debug: WARN - nccl_ib_timeout: N/A - nccl_nthreads: N/A - nccl_socket_ifname: N/A - torch_distributed_debug: INFO - - -The following logs are rendered during runtime (when ``TORCH_DISTRIBUTED_DEBUG=DETAIL`` is set): - -:: - - I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0 - Avg forward compute time: 40838608 - Avg backward compute time: 5983335 - Avg backward comm. time: 4326421 - Avg backward comm/comp overlap time: 4207652 - I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0 - Avg forward compute time: 42850427 - Avg backward compute time: 3885553 - Avg backward comm. time: 2357981 - Avg backward comm/comp overlap time: 2234674 - - -In addition, ``TORCH_DISTRIBUTED_DEBUG=INFO`` enhances crash logging in :func:`torch.nn.parallel.DistributedDataParallel` due to unused parameters in the model. Currently, ``find_unused_parameters=True`` -must be passed into :func:`torch.nn.parallel.DistributedDataParallel` initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required -to be used in loss computation as :func:`torch.nn.parallel.DistributedDataParallel` does not support unused parameters in the backwards pass. These constraints are challenging especially for larger -models, thus when crashing with an error, :func:`torch.nn.parallel.DistributedDataParallel` will log the fully qualified name of all parameters that went unused. For example, in the above application, -if we modify ``loss`` to be instead computed as ``loss = output[1]``, then ``TwoLinLayerNet.a`` does not receive a gradient in the backwards pass, and -thus results in ``DDP`` failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models: - - -:: - - RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing - the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by - making sure all `forward` function outputs participate in calculating loss. - If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return va - lue of `forward` of your module when reporting this issue (e.g. list, dict, iterable). - Parameters which did not receive grad for rank 0: a.weight - Parameter indices which did not receive grad for rank 0: 0 - - -Setting ``TORCH_DISTRIBUTED_DEBUG=DETAIL`` will trigger additional consistency and synchronization checks on every collective call issued by the user -either directly or indirectly (such as DDP ``allreduce``). This is done by creating a wrapper process group that wraps all process groups returned by -:func:`torch.distributed.init_process_group` and :func:`torch.distributed.new_group` APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process -group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a :func:`torch.distributed.monitored_barrier`, -which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by -ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the -application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into -:func:`torch.distributed.all_reduce`: - -:: - - import torch - import torch.distributed as dist - import torch.multiprocessing as mp - - - def worker(rank): - dist.init_process_group("nccl", rank=rank, world_size=2) - torch.cuda.set_device(rank) - tensor = torch.randn(10 if rank == 0 else 20).cuda() - dist.all_reduce(tensor) - torch.cuda.synchronize(device=rank) - - - if __name__ == "__main__": - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29501" - os.environ["TORCH_CPP_LOG_LEVEL"]="INFO" - os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" - mp.spawn(worker, nprocs=2, args=()) - -With the ``NCCL`` backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables -``TORCH_DISTRIBUTED_DEBUG=DETAIL`` and reruns the application, the following error message reveals the root cause: - -:: - - work = default_pg.allreduce([tensor], opts) - RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0. This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: 10 - [ torch.LongTensor{1} ] - -.. note:: - For fine-grained control of the debug level during runtime the functions :func:`torch.distributed.set_debug_level`, :func:`torch.distributed.set_debug_level_from_env`, and - :func:`torch.distributed.get_debug_level` can also be used. - -In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `TORCH_SHOW_CPP_STACKTRACES=1` to log the entire callstack when a collective desynchronization is detected. These -collective desynchronization checks will work for all applications that use ``c10d`` collective calls backed by process groups created with the -:func:`torch.distributed.init_process_group` and :func:`torch.distributed.new_group` APIs. - -Logging -------- - -In addition to explicit debugging support via :func:`torch.distributed.monitored_barrier` and ``TORCH_DISTRIBUTED_DEBUG``, the underlying C++ library of ``torch.distributed`` also outputs log -messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The -following matrix shows how the log level can be adjusted via the combination of ``TORCH_CPP_LOG_LEVEL`` and ``TORCH_DISTRIBUTED_DEBUG`` environment variables. - -+-------------------------+-----------------------------+------------------------+ -| ``TORCH_CPP_LOG_LEVEL`` | ``TORCH_DISTRIBUTED_DEBUG`` | Effective Log Level | -+=========================+=============================+========================+ -| ``ERROR`` | ignored | Error | -+-------------------------+-----------------------------+------------------------+ -| ``WARNING`` | ignored | Warning | -+-------------------------+-----------------------------+------------------------+ -| ``INFO`` | ignored | Info | -+-------------------------+-----------------------------+------------------------+ -| ``INFO`` | ``INFO`` | Debug | -+-------------------------+-----------------------------+------------------------+ -| ``INFO`` | ``DETAIL`` | Trace (a.k.a. All) | -+-------------------------+-----------------------------+------------------------+ - -Distributed components raise custom Exception types derived from `RuntimeError`: - -- `torch.distributed.DistError`: This is the base type of all distributed exceptions. -- `torch.distributed.DistBackendError`: This exception is thrown when a backend-specific error occurs. For example, if - the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library. -- `torch.distributed.DistNetworkError`: This exception is thrown when networking - libraries encounter errors (ex: Connection reset by peer) -- `torch.distributed.DistStoreError`: This exception is thrown when the Store encounters - an error (ex: TCPStore timeout) - -.. autoclass:: torch.distributed.DistError -.. autoclass:: torch.distributed.DistBackendError -.. autoclass:: torch.distributed.DistNetworkError -.. autoclass:: torch.distributed.DistStoreError - -If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank: - -.. autofunction:: torch.distributed.breakpoint - -.. Distributed modules that are missing specific entries. -.. Adding them here for tracking purposes until they are more permanently fixed. -.. py:module:: torch.distributed.algorithms -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks -.. py:module:: torch.distributed.algorithms.model_averaging -.. py:module:: torch.distributed.elastic -.. py:module:: torch.distributed.elastic.utils -.. py:module:: torch.distributed.elastic.utils.data -.. py:module:: torch.distributed.launcher -.. py:module:: torch.distributed.nn -.. py:module:: torch.distributed.nn.api -.. py:module:: torch.distributed.nn.jit -.. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook -.. py:module:: torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks -.. py:module:: torch.distributed.algorithms.join -.. py:module:: torch.distributed.algorithms.model_averaging.averagers -.. py:module:: torch.distributed.algorithms.model_averaging.hierarchical_model_averager -.. py:module:: torch.distributed.algorithms.model_averaging.utils -.. py:module:: torch.distributed.argparse_util -.. py:module:: torch.distributed.c10d_logger -.. py:module:: torch.distributed.checkpoint.api -.. py:module:: torch.distributed.checkpoint.default_planner -.. py:module:: torch.distributed.checkpoint.filesystem -.. py:module:: torch.distributed.checkpoint.metadata -.. py:module:: torch.distributed.checkpoint.optimizer -.. py:module:: torch.distributed.checkpoint.planner -.. py:module:: torch.distributed.checkpoint.planner_helpers -.. py:module:: torch.distributed.checkpoint.resharding -.. py:module:: torch.distributed.checkpoint.state_dict_loader -.. py:module:: torch.distributed.checkpoint.state_dict_saver -.. py:module:: torch.distributed.checkpoint.stateful -.. py:module:: torch.distributed.checkpoint.storage -.. py:module:: torch.distributed.checkpoint.utils -.. py:module:: torch.distributed.collective_utils -.. py:module:: torch.distributed.constants -.. py:module:: torch.distributed.device_mesh -.. py:module:: torch.distributed.distributed_c10d -.. py:module:: torch.distributed.elastic.agent.server.api -.. py:module:: torch.distributed.elastic.agent.server.local_elastic_agent -.. py:module:: torch.distributed.elastic.events.api -.. py:module:: torch.distributed.elastic.events.handlers -.. py:module:: torch.distributed.elastic.metrics.api -.. py:module:: torch.distributed.elastic.multiprocessing.api -.. py:module:: torch.distributed.elastic.multiprocessing.errors.error_handler -.. py:module:: torch.distributed.elastic.multiprocessing.errors.handlers -.. py:module:: torch.distributed.elastic.multiprocessing.redirects -.. py:module:: torch.distributed.elastic.multiprocessing.tail_log -.. py:module:: torch.distributed.elastic.rendezvous.api -.. py:module:: torch.distributed.elastic.rendezvous.c10d_rendezvous_backend -.. py:module:: torch.distributed.elastic.rendezvous.dynamic_rendezvous -.. py:module:: torch.distributed.elastic.rendezvous.etcd_rendezvous -.. py:module:: torch.distributed.elastic.rendezvous.etcd_rendezvous_backend -.. py:module:: torch.distributed.elastic.rendezvous.etcd_server -.. py:module:: torch.distributed.elastic.rendezvous.etcd_store -.. py:module:: torch.distributed.elastic.rendezvous.static_tcp_rendezvous -.. py:module:: torch.distributed.elastic.rendezvous.utils -.. py:module:: torch.distributed.elastic.timer.api -.. py:module:: torch.distributed.elastic.timer.file_based_local_timer -.. py:module:: torch.distributed.elastic.timer.local_timer -.. py:module:: torch.distributed.elastic.utils.api -.. py:module:: torch.distributed.elastic.utils.data.cycling_iterator -.. py:module:: torch.distributed.elastic.utils.data.elastic_distributed_sampler -.. py:module:: torch.distributed.elastic.utils.distributed -.. py:module:: torch.distributed.elastic.utils.log_level -.. py:module:: torch.distributed.elastic.utils.logging -.. py:module:: torch.distributed.elastic.utils.store -.. py:module:: torch.distributed.fsdp.api -.. py:module:: torch.distributed.fsdp.fully_sharded_data_parallel -.. py:module:: torch.distributed.fsdp.sharded_grad_scaler -.. py:module:: torch.distributed.fsdp.wrap -.. py:module:: torch.distributed.launcher.api -.. py:module:: torch.distributed.logging_handlers -.. py:module:: torch.distributed.nn.api.remote_module -.. py:module:: torch.distributed.nn.functional -.. py:module:: torch.distributed.nn.jit.instantiator -.. py:module:: torch.distributed.nn.jit.templates.remote_module_template -.. py:module:: torch.distributed.optim.apply_optimizer_in_backward -.. py:module:: torch.distributed.optim.functional_adadelta -.. py:module:: torch.distributed.optim.functional_adagrad -.. py:module:: torch.distributed.optim.functional_adam -.. py:module:: torch.distributed.optim.functional_adamax -.. py:module:: torch.distributed.optim.functional_adamw -.. py:module:: torch.distributed.optim.functional_rmsprop -.. py:module:: torch.distributed.optim.functional_rprop -.. py:module:: torch.distributed.optim.functional_sgd -.. py:module:: torch.distributed.optim.named_optimizer -.. py:module:: torch.distributed.optim.optimizer -.. py:module:: torch.distributed.optim.post_localSGD_optimizer -.. py:module:: torch.distributed.optim.utils -.. py:module:: torch.distributed.optim.zero_redundancy_optimizer -.. py:module:: torch.distributed.remote_device -.. py:module:: torch.distributed.rendezvous -.. py:module:: torch.distributed.rpc.api -.. py:module:: torch.distributed.rpc.backend_registry -.. py:module:: torch.distributed.rpc.constants -.. py:module:: torch.distributed.rpc.functions -.. py:module:: torch.distributed.rpc.internal -.. py:module:: torch.distributed.rpc.options -.. py:module:: torch.distributed.rpc.rref_proxy -.. py:module:: torch.distributed.rpc.server_process_global_profiler -.. py:module:: torch.distributed.tensor.parallel.api -.. py:module:: torch.distributed.tensor.parallel.ddp -.. py:module:: torch.distributed.tensor.parallel.fsdp -.. py:module:: torch.distributed.tensor.parallel.input_reshard -.. py:module:: torch.distributed.tensor.parallel.loss -.. py:module:: torch.distributed.tensor.parallel.style -.. py:module:: torch.distributed.utils -.. py:module:: torch.distributed.checkpoint.state_dict diff --git a/docs/source/distributed.tensor.md b/docs/source/distributed.tensor.md new file mode 100644 index 00000000000000..64f2f02c81077e --- /dev/null +++ b/docs/source/distributed.tensor.md @@ -0,0 +1,250 @@ + +:::{currentmodule} torch.distributed.tensor +::: + + +# torch.distributed.tensor + +:::{note} +`torch.distributed.tensor` is currently in alpha state and under +development, we are committing backward compatibility for the most APIs listed +in the doc, but there might be API changes if necessary. +::: + +## PyTorch DTensor (Distributed Tensor) + +PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handles distributed +logic, including sharded storage, operator computation and collective communications across devices/hosts. +`DTensor` could be used to build different parallelism solutions and support sharded state_dict representation +when working with multi-dimensional sharding. + +Please see examples from the PyTorch native parallelism solutions that are built on top of `DTensor`: + +- [Tensor Parallel](https://pytorch.org/docs/main/distributed.tensor.parallel.html) +- [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md) + +```{eval-rst} +.. automodule:: torch.distributed.tensor +``` + +{class}`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to +write distributed program as if it's a **single-device program with the same convergence property**. It +provides a uniform tensor sharding layout (DTensor Layout) through specifying the {class}`DeviceMesh` +and {class}`Placement`: + +- {class}`DeviceMesh` represents the device topology and the communicators of the cluster using + an n-dimensional array. +- {class}`Placement` describes the sharding layout of the logical tensor on the {class}`DeviceMesh`. + DTensor supports three types of placements: {class}`Shard`, {class}`Replicate` and {class}`Partial`. + +### DTensor Class APIs + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor +``` + +{class}`DTensor` is a `torch.Tensor` subclass. This means once a {class}`DTensor` is created, it could be +used in very similar way to `torch.Tensor`, including running different types of PyTorch operators as if +running them in a single device, allowing proper distributed computation for PyTorch operators. + +In addition to existing `torch.Tensor` methods, it also offers a set of additional methods to interact with +`torch.Tensor`, `redistribute` the DTensor Layout to a new DTensor, get the full tensor content +on all devices, etc. + +```{eval-rst} +.. autoclass:: DTensor + :members: from_local, to_local, full_tensor, redistribute, device_mesh, placements + :member-order: groupwise + :special-members: __create_chunk_list__ + +``` + +### DeviceMesh as the distributed communicator + +```{eval-rst} +.. currentmodule:: torch.distributed.device_mesh +``` + +{class}`DeviceMesh` was built from DTensor as the abstraction to describe cluster's device topology and represent +multi-dimensional communicators (on top of `ProcessGroup`). To see the details of how to create/use a DeviceMesh, +please refer to the [DeviceMesh recipe](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html). + +### DTensor Placement Types + +```{eval-rst} +.. automodule:: torch.distributed.tensor.placement_types +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor.placement_types +``` + +DTensor supports the following types of {class}`Placement` on each {class}`DeviceMesh` dimension: + +```{eval-rst} +.. autoclass:: Shard + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: Replicate + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: Partial + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: Placement + :members: + :undoc-members: +``` + +(create_dtensor)= + +## Different ways to create a DTensor + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor +``` + +There're three ways to construct a {class}`DTensor`: +: - {meth}`distribute_tensor` creates a {class}`DTensor` from a logical or "global" `torch.Tensor` on + each rank. This could be used to shard the leaf `torch.Tensor` s (i.e. model parameters/buffers + and inputs). + - {meth}`DTensor.from_local` creates a {class}`DTensor` from a local `torch.Tensor` on each rank, which can + be used to create {class}`DTensor` from a non-leaf `torch.Tensor` s (i.e. intermediate activation + tensors during forward/backward). + - DTensor provides dedicated tensor factory functions (e.g. {meth}`empty`, {meth}`ones`, {meth}`randn`, etc.) + to allow different {class}`DTensor` creations by directly specifying the {class}`DeviceMesh` and + {class}`Placement`. Compare to {meth}`distribute_tensor`, this could directly materializing the sharded memory + on device, instead of performing sharding after initializing the logical Tensor memory. + +### Create DTensor from a logical torch.Tensor + +The SPMD (single program, multiple data) programming model in `torch.distributed` launches multiple processes +(i.e. via `torchrun`) to execute the same program, this means that the model inside the program would be +initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly +on GPU if enough memory). + +`DTensor` offers a {meth}`distribute_tensor` API that could shard the model weights or Tensors to `DTensor` s, +where it would create a DTensor from the "logical" Tensor on each process. This would empower the created +`DTensor` s to comply with the single device semantic, which is critical for **numerical correctness**. + +```{eval-rst} +.. autofunction:: distribute_tensor +``` + +Along with {meth}`distribute_tensor`, DTensor also offers a {meth}`distribute_module` API to allow easier +sharding on the {class}`nn.Module` level + +```{eval-rst} +.. autofunction:: distribute_module + +``` + +### DTensor Factory Functions + +DTensor also provides dedicated tensor factory functions to allow creating {class}`DTensor` directly +using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally +specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTensor` created: + +```{eval-rst} +.. autofunction:: zeros +``` + +```{eval-rst} +.. autofunction:: ones +``` + +```{eval-rst} +.. autofunction:: empty +``` + +```{eval-rst} +.. autofunction:: full +``` + +```{eval-rst} +.. autofunction:: rand +``` + +```{eval-rst} +.. autofunction:: randn + +``` + +## Debugging + +```{eval-rst} +.. automodule:: torch.distributed.tensor.debug +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor.debug +``` + +### Logging + +When launching the program, you can turn on additional logging using the `TORCH_LOGS` environment variable from +[torch._logging](https://pytorch.org/docs/main/logging.html#module-torch._logging) : + +- `TORCH_LOGS=+dtensor` will display `logging.DEBUG` messages and all levels above it. +- `TORCH_LOGS=dtensor` will display `logging.INFO` messages and above. +- `TORCH_LOGS=-dtensor` will display `logging.WARNING` messages and above. + +### Debugging Tools + +To debug the program that applied DTensor, and understand more details about what collectives happened under the +hood, DTensor provides a {class}`CommDebugMode`: + +```{eval-rst} +.. autoclass:: CommDebugMode + :members: + :undoc-members: +``` + +To visualize the sharding of a DTensor that have less than 3 dimensions, DTensor provides {meth}`visualize_sharding`: + +```{eval-rst} +.. autofunction:: visualize_sharding + +``` + +## Experimental Features + +`DTensor` also provides a set of experimental features. These features are either in prototyping stage, or the basic +functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to +these features. + +```{eval-rst} +.. automodule:: torch.distributed.tensor.experimental +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor.experimental +``` + +```{eval-rst} +.. autofunction:: context_parallel +``` + +```{eval-rst} +.. autofunction:: local_map +``` + +```{eval-rst} +.. autofunction:: register_sharding + +``` + +% modules that are missing docs, add the doc later when necessary + +```{eval-rst} +.. py:module:: torch.distributed.tensor.device_mesh +``` diff --git a/docs/source/distributed.tensor.parallel.md b/docs/source/distributed.tensor.parallel.md new file mode 100644 index 00000000000000..6083699493ff00 --- /dev/null +++ b/docs/source/distributed.tensor.parallel.md @@ -0,0 +1,92 @@ +:::{role} hidden + :class: hidden-section +::: + +# Tensor Parallelism - torch.distributed.tensor.parallel + +Tensor Parallelism(TP) is built on top of the PyTorch DistributedTensor +(DTensor)[https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md] +and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism. + +:::{warning} +Tensor Parallelism APIs are experimental and subject to change. +::: + +The entrypoint to parallelize your `nn.Module` using Tensor Parallelism is: + +```{eval-rst} +.. automodule:: torch.distributed.tensor.parallel +``` + +```{eval-rst} +.. currentmodule:: torch.distributed.tensor.parallel +``` + +```{eval-rst} +.. autofunction:: parallelize_module +``` + +Tensor Parallelism supports the following parallel styles: + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel + :members: + :undoc-members: +``` + +To simply configure the nn.Module's inputs and outputs with DTensor layouts +and perform necessary layout redistributions, without distribute the module +parameters to DTensors, the following `ParallelStyle` s can be used in +the `parallelize_plan` when calling `parallelize_module`: + + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput + :members: + :undoc-members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInputOutput + :members: + :undoc-members: +``` + +:::{note} +when using the `Shard(dim)` as the input/output layouts for the above +`ParallelStyle` s, we assume the input/output activation tensors are evenly sharded on +the tensor dimension `dim` on the `DeviceMesh` that TP operates on. For instance, +since `RowwiseParallel` accepts input that is sharded on the last dimension, it assumes +the input tensor has already been evenly sharded on the last dimension. For the case of uneven sharded activation tensors, one could pass in DTensor directly to the partitioned modules, and use `use_local_output=False` to return DTensor after each `ParallelStyle`, where DTensor could track the uneven sharding information. +::: + +For models like Transformer, we recommend users to use `ColwiseParallel` +and `RowwiseParallel` together in the parallelize_plan for achieve the desired +sharding for the entire model (i.e. Attention and MLP). + +Parallelized cross-entropy loss computation (loss parallelism), is supported via the following context manager: + +```{eval-rst} +.. autofunction:: torch.distributed.tensor.parallel.loss_parallel +``` +:::{warning} + The loss_parallel API is experimental and subject to change. +::: \ No newline at end of file diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst deleted file mode 100644 index 75cedd809fdc89..00000000000000 --- a/docs/source/distributed.tensor.parallel.rst +++ /dev/null @@ -1,71 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Tensor Parallelism - torch.distributed.tensor.parallel -====================================================== - -Tensor Parallelism(TP) is built on top of the PyTorch DistributedTensor -(`DTensor `__) -and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism. - -.. warning :: - Tensor Parallelism APIs are experimental and subject to change. - -The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is: - -.. automodule:: torch.distributed.tensor.parallel - -.. currentmodule:: torch.distributed.tensor.parallel - -.. autofunction:: parallelize_module - -Tensor Parallelism supports the following parallel styles: - -.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel - :members: - :undoc-members: - -.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel - :members: - :undoc-members: - -.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel - :members: - :undoc-members: - -To simply configure the nn.Module's inputs and outputs with DTensor layouts -and perform necessary layout redistributions, without distribute the module -parameters to DTensors, the following ``ParallelStyle`` s can be used in -the ``parallelize_plan`` when calling ``parallelize_module``: - -.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput - :members: - :undoc-members: - -.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput - :members: - :undoc-members: - -.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInputOutput - :members: - :undoc-members: - -.. note:: when using the ``Shard(dim)`` as the input/output layouts for the above - ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on - the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance, - since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes - the input tensor has already been evenly sharded on the last dimension. For the case of uneven - sharded activation tensors, one could pass in DTensor directly to the partitioned modules, - and use ``use_local_output=False`` to return DTensor after each ``ParallelStyle``, where - DTensor could track the uneven sharding information. - -For models like Transformer, we recommend users to use ``ColwiseParallel`` -and ``RowwiseParallel`` together in the parallelize_plan for achieve the desired -sharding for the entire model (i.e. Attention and MLP). - -Parallelized cross-entropy loss computation (loss parallelism), is supported via the following context manager: - -.. autofunction:: torch.distributed.tensor.parallel.loss_parallel - -.. warning :: - The loss_parallel API is experimental and subject to change. diff --git a/docs/source/distributed.tensor.rst b/docs/source/distributed.tensor.rst deleted file mode 100644 index 559014674b1619..00000000000000 --- a/docs/source/distributed.tensor.rst +++ /dev/null @@ -1,195 +0,0 @@ -.. currentmodule:: torch.distributed.tensor - -torch.distributed.tensor -=========================== - -.. note:: - ``torch.distributed.tensor`` is currently in alpha state and under - development, we are committing backward compatibility for the most APIs listed - in the doc, but there might be API changes if necessary. - - -PyTorch DTensor (Distributed Tensor) ---------------------------------------- - -PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handles distributed -logic, including sharded storage, operator computation and collective communications across devices/hosts. -``DTensor`` could be used to build different paralleism solutions and support sharded state_dict representation -when working with multi-dimensional sharding. - -Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``: - -* `Tensor Parallel `__ -* `FSDP2 `__ - -.. automodule:: torch.distributed.tensor - -:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to -write distributed program as if it's a **single-device program with the same convergence property**. It -provides a uniform tensor sharding layout (DTensor Layout) through specifying the :class:`DeviceMesh` -and :class:`Placement`: - -- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using - an n-dimensional array. - -- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`. - DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`. - - -DTensor Class APIs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. currentmodule:: torch.distributed.tensor - -:class:`DTensor` is a ``torch.Tensor`` subclass. This means once a :class:`DTensor` is created, it could be -used in very similar way to ``torch.Tensor``, including running different types of PyTorch operators as if -running them in a single device, allowing proper distributed computation for PyTorch operators. - -In addition to existing ``torch.Tensor`` methods, it also offers a set of additional methods to interact with -``torch.Tensor``, ``redistribute`` the DTensor Layout to a new DTensor, get the full tensor content -on all devices, etc. - -.. autoclass:: DTensor - :members: from_local, to_local, full_tensor, redistribute, device_mesh, placements - :member-order: groupwise - :special-members: __create_chunk_list__ - - -DeviceMesh as the distributed communicator -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. currentmodule:: torch.distributed.device_mesh - -:class:`DeviceMesh` was built from DTensor as the abstraction to describe cluster's device topology and represent -multi-dimensional communicators (on top of ``ProcessGroup``). To see the details of how to create/use a DeviceMesh, -please refer to the `DeviceMesh recipe `__. - - -DTensor Placement Types -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. automodule:: torch.distributed.tensor.placement_types -.. currentmodule:: torch.distributed.tensor.placement_types - -DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension: - -.. autoclass:: Shard - :members: - :undoc-members: - -.. autoclass:: Replicate - :members: - :undoc-members: - -.. autoclass:: Partial - :members: - :undoc-members: - -.. autoclass:: Placement - :members: - :undoc-members: - - -.. _create_dtensor: - -Different ways to create a DTensor ---------------------------------------- - -.. currentmodule:: torch.distributed.tensor - -There're three ways to construct a :class:`DTensor`: - * :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on - each rank. This could be used to shard the leaf ``torch.Tensor`` s (i.e. model parameters/buffers - and inputs). - * :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can - be used to create :class:`DTensor` from a non-leaf ``torch.Tensor`` s (i.e. intermediate activation - tensors during forward/backward). - * DTensor provides dedicated tensor factory functions (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.) - to allow different :class:`DTensor` creations by directly specifying the :class:`DeviceMesh` and - :class:`Placement`. Compare to :meth:`distribute_tensor`, this could directly materializing the sharded memory - on device, instead of performing sharding after initializing the logical Tensor memory. - -Create DTensor from a logical torch.Tensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The SPMD (single program, multiple data) programming model in ``torch.distributed`` launches multiple processes -(i.e. via ``torchrun``) to execute the same program, this means that the model inside the program would be -initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly -on GPU if enough memory). - -``DTensor`` offers a :meth:`distribute_tensor` API that could shard the model weights or Tensors to ``DTensor`` s, -where it would create a DTensor from the "logical" Tensor on each process. This would empower the created -``DTensor`` s to comply with the single device semantic, which is critical for **numerical correctness**. - -.. autofunction:: distribute_tensor - -Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier -sharding on the :class:`nn.Module` level - -.. autofunction:: distribute_module - - -DTensor Factory Functions -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -DTensor also provides dedicated tensor factory functions to allow creating :class:`DTensor` directly -using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally -specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` created: - -.. autofunction:: zeros - -.. autofunction:: ones - -.. autofunction:: empty - -.. autofunction:: full - -.. autofunction:: rand - -.. autofunction:: randn - - -Debugging ---------------------------------------- - -.. automodule:: torch.distributed.tensor.debug -.. currentmodule:: torch.distributed.tensor.debug - -Logging -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When launching the program, you can turn on additional logging using the `TORCH_LOGS` environment variable from -`torch._logging `__ : - -* `TORCH_LOGS=+dtensor` will display `logging.DEBUG` messages and all levels above it. -* `TORCH_LOGS=dtensor` will display `logging.INFO` messages and above. -* `TORCH_LOGS=-dtensor` will display `logging.WARNING` messages and above. - -Debugging Tools -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To debug the program that applied DTensor, and understand more details about what collectives happened under the -hood, DTensor provides a :class:`CommDebugMode`: - -.. autoclass:: CommDebugMode - :members: - :undoc-members: - -To visualize the sharding of a DTensor that have less than 3 dimensions, DTensor provides :meth:`visualize_sharding`: - -.. autofunction:: visualize_sharding - - -Experimental Features ---------------------------------------- - -``DTensor`` also provides a set of experimental features. These features are either in prototyping stage, or the basic -functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to -these features. - -.. automodule:: torch.distributed.tensor.experimental -.. currentmodule:: torch.distributed.tensor.experimental - -.. autofunction:: context_parallel -.. autofunction:: local_map -.. autofunction:: register_sharding - - -.. modules that are missing docs, add the doc later when necessary -.. py:module:: torch.distributed.tensor.device_mesh diff --git a/docs/source/distributions.md b/docs/source/distributions.md new file mode 100644 index 00000000000000..71c37c386cd128 --- /dev/null +++ b/docs/source/distributions.md @@ -0,0 +1,692 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# Probability distributions - torch.distributions + +```{eval-rst} +.. automodule:: torch.distributions +``` + +```{eval-rst} +.. currentmodule:: torch.distributions +``` + +## {hidden}`Distribution` + +```{eval-rst} +.. currentmodule:: torch.distributions.distribution +``` + +```{eval-rst} +.. autoclass:: Distribution + :members: + :show-inheritance: +``` + +## {hidden}`ExponentialFamily` + +```{eval-rst} +.. currentmodule:: torch.distributions.exp_family +``` + +```{eval-rst} +.. autoclass:: ExponentialFamily + :members: + :show-inheritance: +``` + +## {hidden}`Bernoulli` + +```{eval-rst} +.. currentmodule:: torch.distributions.bernoulli +``` + +```{eval-rst} +.. autoclass:: Bernoulli + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Beta` + +```{eval-rst} +.. currentmodule:: torch.distributions.beta +``` + +```{eval-rst} +.. autoclass:: Beta + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Binomial` + +```{eval-rst} +.. currentmodule:: torch.distributions.binomial +``` + +```{eval-rst} +.. autoclass:: Binomial + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Categorical` + +```{eval-rst} +.. currentmodule:: torch.distributions.categorical +``` + +```{eval-rst} +.. autoclass:: Categorical + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Cauchy` + +```{eval-rst} +.. currentmodule:: torch.distributions.cauchy +``` + +```{eval-rst} +.. autoclass:: Cauchy + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Chi2` + +```{eval-rst} +.. currentmodule:: torch.distributions.chi2 +``` + +```{eval-rst} +.. autoclass:: Chi2 + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`ContinuousBernoulli` + +```{eval-rst} +.. currentmodule:: torch.distributions.continuous_bernoulli +``` + +```{eval-rst} +.. autoclass:: ContinuousBernoulli + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Dirichlet` + +```{eval-rst} +.. currentmodule:: torch.distributions.dirichlet +``` + +```{eval-rst} +.. autoclass:: Dirichlet + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Exponential` + +```{eval-rst} +.. currentmodule:: torch.distributions.exponential +``` + +```{eval-rst} +.. autoclass:: Exponential + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`FisherSnedecor` + +```{eval-rst} +.. currentmodule:: torch.distributions.fishersnedecor +``` + +```{eval-rst} +.. autoclass:: FisherSnedecor + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Gamma` + +```{eval-rst} +.. currentmodule:: torch.distributions.gamma +``` + +```{eval-rst} +.. autoclass:: Gamma + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`GeneralizedPareto` + +```{eval-rst} +.. currentmodule:: torch.distributions.generalized_pareto +``` + +```{eval-rst} +.. autoclass:: GeneralizedPareto + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Geometric` + +```{eval-rst} +.. currentmodule:: torch.distributions.geometric +``` + +```{eval-rst} +.. autoclass:: Geometric + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Gumbel` + +```{eval-rst} +.. currentmodule:: torch.distributions.gumbel +``` + +```{eval-rst} +.. autoclass:: Gumbel + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`HalfCauchy` + +```{eval-rst} +.. currentmodule:: torch.distributions.half_cauchy +``` + +```{eval-rst} +.. autoclass:: HalfCauchy + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`HalfNormal` + +```{eval-rst} +.. currentmodule:: torch.distributions.half_normal +``` + +```{eval-rst} +.. autoclass:: HalfNormal + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Independent` + +```{eval-rst} +.. currentmodule:: torch.distributions.independent +``` + +```{eval-rst} +.. autoclass:: Independent + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`InverseGamma` + +```{eval-rst} +.. currentmodule:: torch.distributions.inverse_gamma +``` + +```{eval-rst} +.. autoclass:: InverseGamma + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Kumaraswamy` + +```{eval-rst} +.. currentmodule:: torch.distributions.kumaraswamy +``` + +```{eval-rst} +.. autoclass:: Kumaraswamy + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`LKJCholesky` + +```{eval-rst} +.. currentmodule:: torch.distributions.lkj_cholesky +``` + +```{eval-rst} +.. autoclass:: LKJCholesky + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Laplace` + +```{eval-rst} +.. currentmodule:: torch.distributions.laplace +``` + +```{eval-rst} +.. autoclass:: Laplace + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`LogNormal` + +```{eval-rst} +.. currentmodule:: torch.distributions.log_normal +``` + +```{eval-rst} +.. autoclass:: LogNormal + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`LowRankMultivariateNormal` + +```{eval-rst} +.. currentmodule:: torch.distributions.lowrank_multivariate_normal +``` + +```{eval-rst} +.. autoclass:: LowRankMultivariateNormal + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`MixtureSameFamily` + +```{eval-rst} +.. currentmodule:: torch.distributions.mixture_same_family +``` + +```{eval-rst} +.. autoclass:: MixtureSameFamily + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Multinomial` + +```{eval-rst} +.. currentmodule:: torch.distributions.multinomial +``` + +```{eval-rst} +.. autoclass:: Multinomial + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`MultivariateNormal` + +```{eval-rst} +.. currentmodule:: torch.distributions.multivariate_normal +``` + +```{eval-rst} +.. autoclass:: MultivariateNormal + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`NegativeBinomial` + +```{eval-rst} +.. currentmodule:: torch.distributions.negative_binomial +``` + +```{eval-rst} +.. autoclass:: NegativeBinomial + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Normal` + +```{eval-rst} +.. currentmodule:: torch.distributions.normal +``` + +```{eval-rst} +.. autoclass:: Normal + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`OneHotCategorical` + +```{eval-rst} +.. currentmodule:: torch.distributions.one_hot_categorical +``` + +```{eval-rst} +.. autoclass:: OneHotCategorical + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Pareto` + +```{eval-rst} +.. currentmodule:: torch.distributions.pareto +``` + +```{eval-rst} +.. autoclass:: Pareto + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Poisson` + +```{eval-rst} +.. currentmodule:: torch.distributions.poisson +``` + +```{eval-rst} +.. autoclass:: Poisson + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`RelaxedBernoulli` + +```{eval-rst} +.. currentmodule:: torch.distributions.relaxed_bernoulli +``` + +```{eval-rst} +.. autoclass:: RelaxedBernoulli + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`LogitRelaxedBernoulli` + +```{eval-rst} +.. currentmodule:: torch.distributions.relaxed_bernoulli +``` + +```{eval-rst} +.. autoclass:: LogitRelaxedBernoulli + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`RelaxedOneHotCategorical` + +```{eval-rst} +.. currentmodule:: torch.distributions.relaxed_categorical +``` + +```{eval-rst} +.. autoclass:: RelaxedOneHotCategorical + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`StudentT` + +```{eval-rst} +.. currentmodule:: torch.distributions.studentT +``` + +```{eval-rst} +.. autoclass:: StudentT + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`TransformedDistribution` + +```{eval-rst} +.. currentmodule:: torch.distributions.transformed_distribution +``` + +```{eval-rst} +.. autoclass:: TransformedDistribution + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Uniform` + +```{eval-rst} +.. currentmodule:: torch.distributions.uniform +``` + +```{eval-rst} +.. autoclass:: Uniform + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`VonMises` + +```{eval-rst} +.. currentmodule:: torch.distributions.von_mises +``` + +```{eval-rst} +.. autoclass:: VonMises + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Weibull` + +```{eval-rst} +.. currentmodule:: torch.distributions.weibull +``` + +```{eval-rst} +.. autoclass:: Weibull + :members: + :undoc-members: + :show-inheritance: +``` + +## {hidden}`Wishart` + +```{eval-rst} +.. currentmodule:: torch.distributions.wishart +``` + +```{eval-rst} +.. autoclass:: Wishart + :members: + :undoc-members: + :show-inheritance: +``` + +## `KL Divergence` + +```{eval-rst} +.. automodule:: torch.distributions.kl +``` + +```{eval-rst} +.. currentmodule:: torch.distributions.kl +``` + +```{eval-rst} +.. autofunction:: kl_divergence +``` + +```{eval-rst} +.. autofunction:: register_kl +``` + +## `Transforms` + +```{eval-rst} +.. automodule:: torch.distributions.transforms + :members: + :member-order: bysource +``` + +## `Constraints` + +```{eval-rst} +.. automodule:: torch.distributions.constraints + :members: + :member-order: bysource +``` + +## `Constraint Registry` + +```{eval-rst} +.. automodule:: torch.distributions.constraint_registry + :members: + :member-order: bysource +``` + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.distributions.bernoulli + +.. py:module:: torch.distributions.beta + +.. py:module:: torch.distributions.binomial + +.. py:module:: torch.distributions.categorical + +.. py:module:: torch.distributions.cauchy + +.. py:module:: torch.distributions.chi2 + +.. py:module:: torch.distributions.continuous_bernoulli + +.. py:module:: torch.distributions.dirichlet + +.. py:module:: torch.distributions.distribution + +.. py:module:: torch.distributions.exp_family + +.. py:module:: torch.distributions.exponential + +.. py:module:: torch.distributions.fishersnedecor + +.. py:module:: torch.distributions.gamma + +.. py:module:: torch.distributions.generalized_pareto + +.. py:module:: torch.distributions.geometric + +.. py:module:: torch.distributions.gumbel + +.. py:module:: torch.distributions.half_cauchy + +.. py:module:: torch.distributions.half_normal + +.. py:module:: torch.distributions.independent + +.. py:module:: torch.distributions.inverse_gamma + +.. py:module:: torch.distributions.kumaraswamy + +.. py:module:: torch.distributions.laplace + +.. py:module:: torch.distributions.lkj_cholesky + +.. py:module:: torch.distributions.log_normal + +.. py:module:: torch.distributions.logistic_normal + +.. py:module:: torch.distributions.lowrank_multivariate_normal + +.. py:module:: torch.distributions.mixture_same_family + +.. py:module:: torch.distributions.multinomial + +.. py:module:: torch.distributions.multivariate_normal + +.. py:module:: torch.distributions.negative_binomial + +.. py:module:: torch.distributions.normal + +.. py:module:: torch.distributions.one_hot_categorical + +.. py:module:: torch.distributions.pareto + +.. py:module:: torch.distributions.poisson + +.. py:module:: torch.distributions.relaxed_bernoulli + +.. py:module:: torch.distributions.relaxed_categorical + +.. py:module:: torch.distributions.studentT + +.. py:module:: torch.distributions.transformed_distribution + +.. py:module:: torch.distributions.uniform + +.. py:module:: torch.distributions.utils + +.. py:module:: torch.distributions.von_mises + +.. py:module:: torch.distributions.weibull + +.. py:module:: torch.distributions.wishart +``` diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst deleted file mode 100644 index ac2afeafb53cf5..00000000000000 --- a/docs/source/distributions.rst +++ /dev/null @@ -1,460 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Probability distributions - torch.distributions -================================================== - -.. automodule:: torch.distributions -.. currentmodule:: torch.distributions - -:hidden:`Distribution` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.distribution -.. autoclass:: Distribution - :members: - :show-inheritance: - -:hidden:`ExponentialFamily` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.exp_family -.. autoclass:: ExponentialFamily - :members: - :show-inheritance: - -:hidden:`Bernoulli` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.bernoulli -.. autoclass:: Bernoulli - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Beta` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.beta -.. autoclass:: Beta - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Binomial` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.binomial -.. autoclass:: Binomial - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Categorical` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.categorical -.. autoclass:: Categorical - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Cauchy` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.cauchy -.. autoclass:: Cauchy - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Chi2` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.chi2 -.. autoclass:: Chi2 - :members: - :undoc-members: - :show-inheritance: - -:hidden:`ContinuousBernoulli` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.continuous_bernoulli -.. autoclass:: ContinuousBernoulli - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Dirichlet` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.dirichlet -.. autoclass:: Dirichlet - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Exponential` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.exponential -.. autoclass:: Exponential - :members: - :undoc-members: - :show-inheritance: - -:hidden:`FisherSnedecor` -~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.fishersnedecor -.. autoclass:: FisherSnedecor - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Gamma` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.gamma -.. autoclass:: Gamma - :members: - :undoc-members: - :show-inheritance: - -:hidden:`GeneralizedPareto` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.generalized_pareto -.. autoclass:: GeneralizedPareto - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Geometric` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.geometric -.. autoclass:: Geometric - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Gumbel` -~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.gumbel -.. autoclass:: Gumbel - :members: - :undoc-members: - :show-inheritance: - -:hidden:`HalfCauchy` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.half_cauchy -.. autoclass:: HalfCauchy - :members: - :undoc-members: - :show-inheritance: - -:hidden:`HalfNormal` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.half_normal -.. autoclass:: HalfNormal - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Independent` -~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.independent -.. autoclass:: Independent - :members: - :undoc-members: - :show-inheritance: - -:hidden:`InverseGamma` -~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.inverse_gamma -.. autoclass:: InverseGamma - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Kumaraswamy` -~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.kumaraswamy -.. autoclass:: Kumaraswamy - :members: - :undoc-members: - :show-inheritance: - -:hidden:`LKJCholesky` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.lkj_cholesky -.. autoclass:: LKJCholesky - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Laplace` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.laplace -.. autoclass:: Laplace - :members: - :undoc-members: - :show-inheritance: - -:hidden:`LogNormal` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.log_normal -.. autoclass:: LogNormal - :members: - :undoc-members: - :show-inheritance: - -:hidden:`LowRankMultivariateNormal` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.lowrank_multivariate_normal -.. autoclass:: LowRankMultivariateNormal - :members: - :undoc-members: - :show-inheritance: - -:hidden:`MixtureSameFamily` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.mixture_same_family -.. autoclass:: MixtureSameFamily - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Multinomial` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.multinomial -.. autoclass:: Multinomial - :members: - :undoc-members: - :show-inheritance: - -:hidden:`MultivariateNormal` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.multivariate_normal -.. autoclass:: MultivariateNormal - :members: - :undoc-members: - :show-inheritance: - -:hidden:`NegativeBinomial` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.negative_binomial -.. autoclass:: NegativeBinomial - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Normal` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.normal -.. autoclass:: Normal - :members: - :undoc-members: - :show-inheritance: - -:hidden:`OneHotCategorical` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.one_hot_categorical -.. autoclass:: OneHotCategorical - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Pareto` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.pareto -.. autoclass:: Pareto - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Poisson` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.poisson -.. autoclass:: Poisson - :members: - :undoc-members: - :show-inheritance: - -:hidden:`RelaxedBernoulli` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.relaxed_bernoulli -.. autoclass:: RelaxedBernoulli - :members: - :undoc-members: - :show-inheritance: - -:hidden:`LogitRelaxedBernoulli` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.relaxed_bernoulli -.. autoclass:: LogitRelaxedBernoulli - :members: - :undoc-members: - :show-inheritance: - -:hidden:`RelaxedOneHotCategorical` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.relaxed_categorical -.. autoclass:: RelaxedOneHotCategorical - :members: - :undoc-members: - :show-inheritance: - -:hidden:`StudentT` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.studentT -.. autoclass:: StudentT - :members: - :undoc-members: - :show-inheritance: - -:hidden:`TransformedDistribution` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.transformed_distribution -.. autoclass:: TransformedDistribution - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Uniform` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.uniform -.. autoclass:: Uniform - :members: - :undoc-members: - :show-inheritance: - -:hidden:`VonMises` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.von_mises -.. autoclass:: VonMises - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Weibull` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.weibull -.. autoclass:: Weibull - :members: - :undoc-members: - :show-inheritance: - -:hidden:`Wishart` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.distributions.wishart -.. autoclass:: Wishart - :members: - :undoc-members: - :show-inheritance: - -`KL Divergence` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.distributions.kl -.. currentmodule:: torch.distributions.kl - -.. autofunction:: kl_divergence -.. autofunction:: register_kl - -`Transforms` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.distributions.transforms - :members: - :member-order: bysource - -`Constraints` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.distributions.constraints - :members: - :member-order: bysource - -`Constraint Registry` -~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.distributions.constraint_registry - :members: - :member-order: bysource - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.distributions.bernoulli -.. py:module:: torch.distributions.beta -.. py:module:: torch.distributions.binomial -.. py:module:: torch.distributions.categorical -.. py:module:: torch.distributions.cauchy -.. py:module:: torch.distributions.chi2 -.. py:module:: torch.distributions.continuous_bernoulli -.. py:module:: torch.distributions.dirichlet -.. py:module:: torch.distributions.distribution -.. py:module:: torch.distributions.exp_family -.. py:module:: torch.distributions.exponential -.. py:module:: torch.distributions.fishersnedecor -.. py:module:: torch.distributions.gamma -.. py:module:: torch.distributions.generalized_pareto -.. py:module:: torch.distributions.geometric -.. py:module:: torch.distributions.gumbel -.. py:module:: torch.distributions.half_cauchy -.. py:module:: torch.distributions.half_normal -.. py:module:: torch.distributions.independent -.. py:module:: torch.distributions.inverse_gamma -.. py:module:: torch.distributions.kumaraswamy -.. py:module:: torch.distributions.laplace -.. py:module:: torch.distributions.lkj_cholesky -.. py:module:: torch.distributions.log_normal -.. py:module:: torch.distributions.logistic_normal -.. py:module:: torch.distributions.lowrank_multivariate_normal -.. py:module:: torch.distributions.mixture_same_family -.. py:module:: torch.distributions.multinomial -.. py:module:: torch.distributions.multivariate_normal -.. py:module:: torch.distributions.negative_binomial -.. py:module:: torch.distributions.normal -.. py:module:: torch.distributions.one_hot_categorical -.. py:module:: torch.distributions.pareto -.. py:module:: torch.distributions.poisson -.. py:module:: torch.distributions.relaxed_bernoulli -.. py:module:: torch.distributions.relaxed_categorical -.. py:module:: torch.distributions.studentT -.. py:module:: torch.distributions.transformed_distribution -.. py:module:: torch.distributions.uniform -.. py:module:: torch.distributions.utils -.. py:module:: torch.distributions.von_mises -.. py:module:: torch.distributions.weibull -.. py:module:: torch.distributions.wishart diff --git a/docs/source/dlpack.md b/docs/source/dlpack.md new file mode 100644 index 00000000000000..d66594bfa0895a --- /dev/null +++ b/docs/source/dlpack.md @@ -0,0 +1,13 @@ +# torch.utils.dlpack + +```{eval-rst} +.. currentmodule:: torch.utils.dlpack +``` + +```{eval-rst} +.. autofunction:: from_dlpack +``` + +```{eval-rst} +.. autofunction:: to_dlpack +``` \ No newline at end of file diff --git a/docs/source/dlpack.rst b/docs/source/dlpack.rst deleted file mode 100644 index 838c75c0f93997..00000000000000 --- a/docs/source/dlpack.rst +++ /dev/null @@ -1,7 +0,0 @@ -torch.utils.dlpack -================== - -.. currentmodule:: torch.utils.dlpack - -.. autofunction:: from_dlpack -.. autofunction:: to_dlpack diff --git a/docs/source/draft_export.md b/docs/source/draft_export.md new file mode 100644 index 00000000000000..cc7247d3b526d1 --- /dev/null +++ b/docs/source/draft_export.md @@ -0,0 +1,262 @@ +(draft-export)= + +# Draft Export + +:::{warning} +This feature is not meant to be used in production and is designed to be +used as a tool for debugging torch.export tracing errors. +::: + +Draft-export is a new version of export, which is designed to consistently +produce a graph, even if there are potential soundness issues, and to generate a +report listing out all of the issues export encountered during +tracing and providing additional debugging information. For custom operators that +don't have fake kernels, it will also generate a profile which you can register +to automatically generate a fake kernel. + +Have you ever tried to export a model using {func}`torch.export.export`, only to +encounter a data-dependent issue? You fix it, but then run into a missing fake +kernel problem. And after resolving that, you get hit with another +data-dependent issue. You wonder to yourself, I wish there was a way I could +just get a graph to play around with, and be able to view all the issues in one +place so that I can fix them later… + +`draft_export` to the rescue! + +`draft_export` is a version of export which will always successfully export a +graph, even if there are potential soundness issues. These issues will then be +compiled into a report for clearer visualization, which can be fixed later on. + +## What sort of errors does it catch? + +Draft-export helps to catch and debug the following errors: + +- Guard on data-dependent errors +- Constraint violation errors +- Missing fake kernels +- Incorrectly written fake kernels + +## How does it work? + +In normal export, we will convert the sample inputs into FakeTensors and use +them to record operations and trace the program into a graph. Input tensor +shapes that can change (which are marked through `dynamic_shapes`), or values +within tensors (typically from an `.item()` call) will be represented as a symbolic +shape (`SymInt`) instead of a concrete integer. However some issues may occur +while tracing - we may run into guards that we cannot evaluate, like if we want +to check if some item in a tensor is greater than 0 (`u0 >= 0`). Since the tracer +doesn't know anything about the value of `u0`, it will throw a data-dependent +error. If the model uses a custom operator but a fake kernel hasn't been +defined for it, then we will error with `fake_tensor.UnsupportedOperatorException` +because export doesn't know how to apply this on `FakeTensors`. If a custom +operator has a fake kernel implemented incorrectly, export will silently produce +an incorrect graph that doesn't match the eager behavior. + +To fix the above errors, draft-export uses *real tensor tracing* to guide us on +how to proceed when tracing. As we trace the model with fake tensors, for every +operation that happens on a fake tensor, draft-export will also run the operator +on stored real tensors which come from the example inputs passed to export. This +allows us to address the above errors: When we reach a guard that we cannot +evaluate, like `u0 >= 0`, we will use the stored real tensor values to +evaluate this guard. Runtime asserts will be added into the graph to ensure that +the graph asserts the same guard that we assumed while tracing. If we run into +a custom operator without a fake kernel, we will run the operator's normal +kernel with the stored real tensors, and return a fake tensor with the same rank +but unbacked shapes. Since we have the real tensor output for every operation, +we will compare this with the fake tensor output from the fake kernel. If the +fake kernel is implemented incorrectly, we will then catch this behavior and +generate a more correct fake kernel. + +## How can I use draft export? + +Let's say you're trying to export this piece of code: + +```python +class M(torch.nn.Module): + def forward(self, x, y, z): + res = torch.ops.mylib.foo2(x, y) + + a = res.item() + a = -a + a = a // 3 + a = a + 5 + + z = torch.cat([z, z]) + + torch._check_is_size(a) + torch._check(a < z.shape[0]) + + return z[:a] + +inp = (torch.tensor(3), torch.tensor(4), torch.ones(3, 3)) + +ep = torch.export.export(M(), inp) +``` + +This runs into a “missing fake kernel” error for `mylib.foo2` and then a +`GuardOnDataDependentExpression` because of the slicing of `z` with `a`, +an unbacked symint. + +To call `draft-export`, we can replace the `torch.export` line with the following: + +```python +ep = torch.export.draft_export(M(), inp) +``` + +`ep` is a valid ExportedProgram which can now be passed through further environments! + +## Debugging with draft-export + +In the terminal output from draft-export, you should see the following message: + +``` +######################################################################################### +WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph. +To view the report of failures in an html page, please run the command: + `tlparse /tmp/export_angelayi/dedicated_log_torch_trace_axpofwe2.log --export` +Or, you can view the errors in python by inspecting `print(ep._report)`. +######################################################################################## +``` + +Draft-export automatically dumps logs for `tlparse`. You can view the tracing +errors by using `print(ep._report)`, or you can pass the logs into `tlparse` +to generate an html report. + +Running the `tlparse` command in the terminal will generate a +[tlparse](https://github.com/pytorch/tlparse) +HTML report. Here is an example of the `tlparse` report: + +```{image} _static/img/export/draft_export_report.png +``` + +Clicking into the Data Dependent Error, we will see the following page which +contains information to help debug this error. Specifically, it contains: + +- The stacktrace at which this error occurs +- A list of local variables and their shapes +- Information for how this guard was created + +```{image} _static/img/export/draft_export_report_dde.png +``` + +## The returned Exported Program + +Because draft-export specializes on code paths based on the example inputs, the +exported program resulting from draft-export is guaranteed to be runnable and +return correct results for **at least** the given example inputs. Other inputs can +work, as long as they match the same guards that were taken when we were +draft-exporting. + +For example, if we have a graph branching on if a value is greater than 5, if in +draft-export our example inputs were greater than 5, then the returned +`ExportedProgram` will specialize on that branch, and will assert that the value +is greater than 5. This means that the program will succeed if you pass in +another value greater than 5, but will fail if you pass in a value less than 5. +This is more sound than `torch.jit.trace`, which will silently specialize on the +branch. The proper way for `torch.export` to support both branches would be to +rewrite the code using `torch.cond`, which will then capture both branches. + +Because of the runtime assertions in the graph, the returned exported-program is +also retraceable with `torch.export` or `torch.compile`, with a minor addition in +the case where a custom operator is missing a fake kernel. + +## Generating Fake Kernels + +If a custom operator does not contain a fake implementation, currently +draft-export will use the real-tensor propagation to get an output for the +operator and continue tracing. However, if we run the exported program with fake +tensors or retrace the exported model, we will still fail because there is still +no fake kernel implementation. + +To address this, after draft-export, we will generate an operator profile for +each custom operator call that we encounter, and store this on the report +attached to the exported program: `ep._report.op_profiles`. Users can then use the +context manager `torch._library.fake_profile.unsafe_generate_fake_kernels` to +generate and register a fake implementation based on these operator profiles. +This way future fake tensor retracing will work. + +The workflow would look something like: + +```python +class M(torch.nn.Module): + def forward(self, a, b): + res = torch.ops.mylib.foo(a, b) # no fake impl + return res + +ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4))) + +with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + decomp = ep.run_decompositions() + +new_inp = ( + torch.ones(2, 3, 4), + torch.ones(2, 3, 4), +) + +# Save the profile to a yaml and check it into a codebase +save_op_profiles(ep._report.op_profiles, "op_profile.yaml") +# Load the yaml +loaded_op_profile = load_op_profiles("op_profile.yaml") +``` + +The operator profile is a dictionary mapping operator name to a set of profiles +which describe the input and outputs of the operator, and could be manually +written, saved into a yaml file, and checked into a codebase. Here's an example +of a profile for `mylib.foo.default`: + +```python +"mylib.foo.default": { + OpProfile( + args_profile=( + TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + ), + out_profile=TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + ) +} +``` + +`mylib.foo.default`'s profile contains only one profile, which says that for 2 +input tensors of rank 2, dtype `torch.float32`, device `cpu`, we will return +one tensor of rank 2, dtype `torch.float32`, and device `cpu`. Using the +context manager, will then generate a fake kernel where given 2 input tensors of +rank 2 (and the other tensor metadata), we will output one tensor of rank 2 (and +the other tensor metadata). + +If the operator also supports other input ranks, then we can add the profile to +this list of profiles, either by manually adding it into the existing profile or +rerunning draft-export with new inputs to get new profiles, so that the +generated fake kernel will support more input types. Otherwise it will error. + +## Where to go from here? + +Now that we have successfully created an `ExportedProgram` using draft-export, +we can use further compilers such as `AOTInductor` to optimize its performance +and produce a runnable artifact. This optimized version can then be used for +deployment. In parallel, we can utilize the report generated by draft-export to +identify and fix `torch.export` errors that were encountered so that the +original model can be directly traceable with `torch.export`. + +```{toctree} +:caption: Additional Links +:maxdepth: 1 + +torch.compiler_fake_tensor +torch.compiler_dynamic_shapes +torch.compiler_aot_inductor +``` diff --git a/docs/source/draft_export.rst b/docs/source/draft_export.rst deleted file mode 100644 index e085bece727f22..00000000000000 --- a/docs/source/draft_export.rst +++ /dev/null @@ -1,269 +0,0 @@ -.. _draft_export: - -Draft Export -============ - -.. warning:: - - This feature is not meant to be used in production and is designed to be - used as a tool for debugging torch.export tracing errors. - -Draft-export is a new version of export, which is designed to consistently -produce a graph, even if there are potential soundness issues, and to generate a -report listing out all of the issues export encountered during -tracing and providing additional debugging information. For custom operators that -don't have fake kernels, it will also generate a profile which you can register -to automatically generate a fake kernel. - -Have you ever tried to export a model using :func:`torch.export.export`, only to -encounter a data-dependent issue? You fix it, but then run into a missing fake -kernel problem. And after resolving that, you get hit with another -data-dependent issue. You wonder to yourself, I wish there was a way I could -just get a graph to play around with, and be able to view all the issues in one -place so that I can fix them later… - -``draft_export`` to the rescue! - -``draft_export`` is a version of export which will always successfully export a -graph, even if there are potential soundness issues. These issues will then be -compiled into a report for clearer visualization, which can be fixed later on. - -What sort of errors does it catch? ----------------------------------- - -Draft-export helps to catch and debug the following errors: - -* Guard on data-dependent errors -* Constraint violation errors -* Missing fake kernels -* Incorrectly written fake kernels - -How does it work? ------------------ - -In normal export, we will convert the sample inputs into FakeTensors and use -them to record operations and trace the program into a graph. Input tensor -shapes that can change (which are marked through ``dynamic_shapes``), or values -within tensors (typically from an ``.item()`` call) will be represented as a symbolic -shape (``SymInt``) instead of a concrete integer. However some issues may occur -while tracing - we may run into guards that we cannot evaluate, like if we want -to check if some item in a tensor is greater than 0 (``u0 >= 0``). Since the tracer -doesn't know anything about the value of ``u0``, it will throw a data-dependent -error. If the model uses a custom operator but a fake kernel hasn't been -defined for it, then we will error with ``fake_tensor.UnsupportedOperatorException`` -because export doesn't know how to apply this on ``FakeTensors``. If a custom -operator has a fake kernel implemented incorrectly, export will silently produce -an incorrect graph that doesn't match the eager behavior. - -To fix the above errors, draft-export uses *real tensor tracing* to guide us on -how to proceed when tracing. As we trace the model with fake tensors, for every -operation that happens on a fake tensor, draft-export will also run the operator -on stored real tensors which come from the example inputs passed to export. This -allows us to address the above errors: When we reach a guard that we cannot -evaluate, like ``u0 >= 0``, we will use the stored real tensor values to -evaluate this guard. Runtime asserts will be added into the graph to ensure that -the graph asserts the same guard that we assumed while tracing. If we run into -a custom operator without a fake kernel, we will run the operator's normal -kernel with the stored real tensors, and return a fake tensor with the same rank -but unbacked shapes. Since we have the real tensor output for every operation, -we will compare this with the fake tensor output from the fake kernel. If the -fake kernel is implemented incorrectly, we will then catch this behavior and -generate a more correct fake kernel. - -How can I use draft export? ---------------------------- - -Let's say you're trying to export this piece of code: - -:: - - class M(torch.nn.Module): - def forward(self, x, y, z): - res = torch.ops.mylib.foo2(x, y) - - a = res.item() - a = -a - a = a // 3 - a = a + 5 - - z = torch.cat([z, z]) - - torch._check_is_size(a) - torch._check(a < z.shape[0]) - - return z[:a] - - inp = (torch.tensor(3), torch.tensor(4), torch.ones(3, 3)) - - ep = torch.export.export(M(), inp) - -This runs into a “missing fake kernel” error for ``mylib.foo2`` and then a -``GuardOnDataDependentExpression`` because of the slicing of ``z`` with ``a``, -an unbacked symint. - -To call ``draft-export``, we can replace the ``torch.export`` line with the following: - -:: - - ep = torch.export.draft_export(M(), inp) - -``ep`` is a valid ExportedProgram which can now be passed through further environments! - -Debugging with draft-export ---------------------------- - -In the terminal output from draft-export, you should see the following message: - -.. code-block:: - - ######################################################################################### - WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph. - To view the report of failures in an html page, please run the command: - `tlparse /tmp/export_angelayi/dedicated_log_torch_trace_axpofwe2.log --export` - Or, you can view the errors in python by inspecting `print(ep._report)`. - ######################################################################################## - -Draft-export automatically dumps logs for ``tlparse``. You can view the tracing -errors by using ``print(ep._report)``, or you can pass the logs into ``tlparse`` -to generate an html report. - -Running the ``tlparse`` command in the terminal will generate a -`tlparse `_ -HTML report. Here is an example of the ``tlparse`` report: - -.. image:: _static/img/export/draft_export_report.png - -Clicking into the Data Dependent Error, we will see the following page which -contains information to help debug this error. Specifically, it contains: - -* The stacktrace at which this error occurs -* A list of local variables and their shapes -* Information for how this guard was created - -.. image:: _static/img/export/draft_export_report_dde.png - - -The returned Exported Program ------------------------------ - -Because draft-export specializes on code paths based on the example inputs, the -exported program resulting from draft-export is guaranteed to be runnable and -return correct results for **at least** the given example inputs. Other inputs can -work, as long as they match the same guards that were taken when we were -draft-exporting. - -For example, if we have a graph branching on if a value is greater than 5, if in -draft-export our example inputs were greater than 5, then the returned -``ExportedProgram`` will specialize on that branch, and will assert that the value -is greater than 5. This means that the program will succeed if you pass in -another value greater than 5, but will fail if you pass in a value less than 5. -This is more sound than ``torch.jit.trace``, which will silently specialize on the -branch. The proper way for ``torch.export`` to support both branches would be to -rewrite the code using ``torch.cond``, which will then capture both branches. - -Because of the runtime assertions in the graph, the returned exported-program is -also retraceable with ``torch.export`` or ``torch.compile``, with a minor addition in -the case where a custom operator is missing a fake kernel. - -Generating Fake Kernels ------------------------ - -If a custom operator does not contain a fake implementation, currently -draft-export will use the real-tensor propagation to get an output for the -operator and continue tracing. However, if we run the exported program with fake -tensors or retrace the exported model, we will still fail because there is still -no fake kernel implementation. - -To address this, after draft-export, we will generate an operator profile for -each custom operator call that we encounter, and store this on the report -attached to the exported program: ``ep._report.op_profiles``. Users can then use the -context manager ``torch._library.fake_profile.unsafe_generate_fake_kernels`` to -generate and register a fake implementation based on these operator profiles. -This way future fake tensor retracing will work. - -The workflow would look something like: - -:: - - class M(torch.nn.Module): - def forward(self, a, b): - res = torch.ops.mylib.foo(a, b) # no fake impl - return res - - ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4))) - - with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): - decomp = ep.run_decompositions() - - new_inp = ( - torch.ones(2, 3, 4), - torch.ones(2, 3, 4), - ) - - # Save the profile to a yaml and check it into a codebase - save_op_profiles(ep._report.op_profiles, "op_profile.yaml") - # Load the yaml - loaded_op_profile = load_op_profiles("op_profile.yaml") - -The operator profile is a dictionary mapping operator name to a set of profiles -which describe the input and outputs of the operator, and could be manually -written, saved into a yaml file, and checked into a codebase. Here's an example -of a profile for ``mylib.foo.default``: - -:: - - "mylib.foo.default": { - OpProfile( - args_profile=( - TensorMetadata( - rank=2, - dtype=torch.float32, - device=torch.device("cpu"), - layout=torch.strided, - ), - TensorMetadata( - rank=2, - dtype=torch.float32, - device=torch.device("cpu"), - layout=torch.strided, - ), - ), - out_profile=TensorMetadata( - rank=2, - dtype=torch.float32, - device=torch.device("cpu"), - layout=torch.strided, - ), - ) - } - -``mylib.foo.default``'s profile contains only one profile, which says that for 2 -input tensors of rank 2, dtype ``torch.float32``, device ``cpu``, we will return -one tensor of rank 2, dtype ``torch.float32``, and device ``cpu``. Using the -context manager, will then generate a fake kernel where given 2 input tensors of -rank 2 (and the other tensor metadata), we will output one tensor of rank 2 (and -the other tensor metadata). - -If the operator also supports other input ranks, then we can add the profile to -this list of profiles, either by manually adding it into the existing profile or -rerunning draft-export with new inputs to get new profiles, so that the -generated fake kernel will support more input types. Otherwise it will error. - - -Where to go from here? ----------------------- - -Now that we have successfully created an ``ExportedProgram`` using draft-export, -we can use further compilers such as ``AOTInductor`` to optimize its performance -and produce a runnable artifact. This optimized version can then be used for -deployment. In parallel, we can utilize the report generated by draft-export to -identify and fix ``torch.export`` errors that were encountered so that the -original model can be directly traceable with ``torch.export``. - -.. toctree:: - :caption: Additional Links - :maxdepth: 1 - - torch.compiler_fake_tensor - torch.compiler_dynamic_shapes - torch.compiler_aot_inductor diff --git a/docs/source/elastic/agent.rst b/docs/source/elastic/agent.rst index ac42403761f39a..38aed193a2cf2f 100644 --- a/docs/source/elastic/agent.rst +++ b/docs/source/elastic/agent.rst @@ -48,7 +48,7 @@ Below are the agent implementations provided by torchelastic. Extending the Agent --------------------- -To extend the agent you can implement ```ElasticAgent`` directly, however +To extend the agent you can implement ``ElasticAgent`` directly, however we recommend you extend ``SimpleElasticAgent`` instead, which provides most of the scaffolding and leaves you with a few specific abstract methods to implement. @@ -64,24 +64,24 @@ to implement. Watchdog in the Agent --------------------- -A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an +A named pipe based watchdog can be enabled in ``LocalElasticAgent`` if an environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has -been defined in the ```LocalElasticAgent``` process. -Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` +been defined in the ``LocalElasticAgent`` process. +Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE`` can be set with a unique file name for the named pipe. If the environment -variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` +variable ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent`` will internally create a unique file name and set it to the environment -variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will +variable ``TORCHELASTIC_TIMER_FILE``, and this environment variable will be propagated to the worker processes to allow them to connect to the same -named pipe that ```LocalElasticAgent``` uses. +named pipe that ``LocalElasticAgent`` uses. Health Check Server ------------------- -A health check monitoring server can be enabled in ```LocalElasticAgent``` +A health check monitoring server can be enabled in ``LocalElasticAgent`` if an environment variable ``TORCHELASTIC_HEALTH_CHECK_PORT`` has been defined -in the ```LocalElasticAgent``` process. +in the ``LocalElasticAgent`` process. Adding interface for health check server which can be extended by starting tcp/http server on the specified port number. Additionally, health check server will have callback to check watchdog is alive. diff --git a/docs/source/export.ir_spec.md b/docs/source/export.ir_spec.md new file mode 100644 index 00000000000000..562cae1e337fa6 --- /dev/null +++ b/docs/source/export.ir_spec.md @@ -0,0 +1,487 @@ +(export.ir_spec)= + +# torch.export IR Specification + +Export IR is an intermediate representation (IR) for compilers, which bears +similarities to [MLIR](https://mlir.llvm.org/) and TorchScript. It is specifically designed to express the +semantics of PyTorch programs. Export IR primarily represents computation in a +streamlined list of operations, with limited support for dynamism such as +control flows. + +To create an Export IR graph, a frontend can be used that soundly captures a +PyTorch program via a trace-specializing mechanism. The resulting Export IR can +then be optimized and executed by a backend. This can be done today through +{func}`torch.export.export`. + +The key concepts that will be covered in this document include: + +- ExportedProgram: the data structure containing the Export IR program +- Graph: which consists of a list of nodes. +- Nodes: which represents operations, control flow, and metadata stored on this node. +- Values are produced and consumed by nodes. +- Types are associated with values and nodes. +- The size and memory layout of values are also defined. + +## Assumptions + +This doc assumes that the audience is sufficiently familiar with PyTorch, +specifically with {class}`torch.fx` and its related toolings. Thus it will stop +describing contents present in {class}`torch.fx` documentation and paper. + +## What is Export IR + +Export IR is a graph-based intermediate representation IR of PyTorch programs. +Export IR is realized on top of {class}`torch.fx.Graph`. In other words, **all +Export IR graphs are also valid FX graphs**, and if interpreted using standard +FX semantics, Export IR can be interpreted soundly. One implication is that an +exported graph can be converted to a valid Python program via standard FX +codegen. + +This documentation will primarily focus on highlighting areas where Export IR +differs from FX in terms of its strictness, while skipping parts where it shares +similarities with FX. + +## ExportedProgram + +The top-level Export IR construct is an {class}`torch.export.ExportedProgram` +class. It bundles the computational graph of a PyTorch model (which is usually a +{class}`torch.nn.Module`) with the parameters or weights that this model +consumes. + +Some notable attributes of the {class}`torch.export.ExportedProgram` class are: + +- `graph_module` ({class}`torch.fx.GraphModule`): Data structure containing + the flattened computational graph of the PyTorch model. The graph can be + directly accessed through `ExportedProgram.graph`. +- `graph_signature` ({class}`torch.export.ExportGraphSignature`): The graph + signature, which specifies the parameters and buffer names used and mutated + within the graph. Instead of storing parameters and buffers as attributes of + the graph, they are lifted as inputs to the graph. The graph_signature is + utilized to keep track of additional information on these parameters and + buffers. +- `state_dict` (`Dict[str, Union[torch.Tensor, torch.nn.Parameter]]`): Data + structure containing the parameters and buffers. +- `range_constraints` (`Dict[sympy.Symbol, RangeConstraint]`): For programs + that are exported with data dependent behavior, the metadata on each node will + contain symbolic shapes (which look like `s0`, `i0`). This attribute maps + the symbolic shapes to their lower/upper ranges. + +## Graph + +An Export IR Graph is a PyTorch program represented in the form of a DAG +(directed acyclic graph). Each node in this graph represents a particular +computation or operation, and edges of this graph consist of references between +nodes. + +We can view Graph having this schema: + +```python +class Graph: + nodes: List[Node] +``` + +In practice, Export IR's graph is realized as {class}`torch.fx.Graph` Python class. + +An Export IR graph contains the following nodes (Nodes will be described in more +details in the next section): + +- 0 or more nodes of op type `placeholder` +- 0 or more nodes of op type `call_function` +- exactly 1 node of op type `output` + +**Collorary:** The smallest valid Graph will be of one node. i.e. nodes is never empty. + +**Definition:** +The set of `placeholder` nodes of a Graph represents the **inputs** of the +Graph of GraphModule. The `output` node of a Graph represents the **outputs** +of the Graph of GraphModule. + +Example: + +```python +import torch +from torch import nn + +class MyModule(nn.Module): + + def forward(self, x, y): + return x + y + +example_args = (torch.randn(1), torch.randn(1)) +mod = torch.export.export(MyModule(), example_args) +print(mod.graph) +``` + +```python +graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {}) + return (add,) +``` + +The above is the textual representation of a Graph, with each line being a node. + +## Node + +A Node represents a particular computation or operation and is represented in +Python using the {class}`torch.fx.Node` class. Edges between nodes are +represented as direct references to other nodes via the `args` property of the +Node class. Using the same FX machinery, we can represent the following +operations that a computational graph typically needs, such as operator calls, +placeholders (aka inputs), conditionals, and loops. + +The Node has the following schema: + +```python +class Node: + name: str # name of node + op_name: str # type of operation + + # interpretation of the fields below depends on op_name + target: [str|Callable] + args: List[object] + kwargs: Dict[str, object] + meta: Dict[str, object] +``` + +**FX Text Format** + +As in the example above, notice that each line has this format: + +``` +%:[...] = [target=](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5}) +``` + +This format captures everything present in the Node class, with the exception of +`meta`, in a compact format. + +Concretely: + +- **** is the name of the node as it would appear in `node.name`. +- **** is the `node.op` field, which must be one of these: + ``, ``, + ``, or ``. +- **** is the target of the node as `node.target`. The meaning of this + field depends on `op_name`. +- **args1, … args 4…** are what is listed in the `node.args` tuple. If a + value in the list is an {class}`torch.fx.Node`, then it will be especially + indicated with a leading **%.** + +For example, a call to the add operator would appear as: + +``` +%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {}) +``` + +Where `%x`, `%y` are two other Nodes that have names x and y. Worth noting +that the string `torch.op.aten.add.Tensor` represents the callable object that +is actually stored in the target field, not merely its string name. + +The final line of this text format is: + +``` +return [add] +``` + +which is a Node with `op_name = output`, indicating that we are returning this +one element. + +### call_function + +A `call_function` node represents a call to an operator. + +**Definitions** + +- **Functional:** We say a callable is “functional” if it satisfies all the + following requirements: + + - Non-mutating: The operator does not mutate the value of its input (for + tensors, this includes both metadata and data). + - No side effects: The operator does not mutate states that are visible + from outside, like changing values of module parameters. + +- **Operator:** is a functional callable with a predefined schema. Examples of + such operators include functional ATen operators. + +**Representation in FX** + +``` +%name = call_function[target = operator](args = (%x, %y, …), kwargs = {}) +``` + +**Differences from vanilla FX call_function** + +1. In FX graph, a call_function can refer to any callable, in Export IR, we + restrict it to only a select subset of ATen operators, custom operators, and + control flow operators. +2. In Export IR, constant arguments will be embedded within the graph. +3. In FX graph, a get_attr node can represent reading any attribute stored in + the graph module. However, in Export IR this is restricted to reading only + submodules as all parameters/buffers will be passed in as inputs to the graph + module. + +#### Metadata + +`Node.meta` is a dict attached to every FX node. However, the FX spec does not +specify what metadata can or will be there. Export IR provides a stronger +contract, specifically all `call_function` nodes will guarantee having and +only having the following metadata fields: + +- `node.meta["stack_trace"]` is a string containing the Python stack trace + referencing the original Python source code. An example stack trace looks + like: + + ``` + File "my_module.py", line 19, in forward + return x + dummy_helper(y) + File "helper_utility.py", line 89, in dummy_helper + return y + 1 + ``` + +- `node.meta["val"]` describes the output of running the operation. It can be + of type ``, ``, a + `List[Union[FakeTensor, SymInt]]`, or `None`. + +- `node.meta["nn_module_stack"]` describes the "stacktrace" of the + {class}`torch.nn.Module` from which the node came, if it was from a + {class}`torch.nn.Module` call. For example, if a node containing the `addmm` + op called from a {class}`torch.nn.Linear` module inside of a + {class}`torch.nn.Sequential` module, the `nn_module_stack` would look + something like: + + ``` + {'self_linear': ('self.linear', ), 'self_sequential': ('self.sequential', )} + ``` + +- `node.meta["source_fn_stack"]` contains the torch function or the leaf + {class}`torch.nn.Module` class this node was called from before decomposition. + For example, a node containing the `addmm` op from a + {class}`torch.nn.Linear` module call would contain {class}`torch.nn.Linear` in + their `source_fn`, and a node containing the `addmm` op from a + {class}`torch.nn.functional.Linear` module call would contain + {class}`torch.nn.functional.Linear` in their `source_fn`. + +### placeholder + +Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. +Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero. + +**Representation in FX** + +```python +%name = placeholder[target = name](args = ()) +``` + +The target field is a string which is the name of input. + +`args`, if non-empty, should be of size 1 representing the default value of this input. + +**Metadata** + +Placeholder nodes also have `meta[‘val’]`, like `call_function` nodes. The +`val` field in this case represents the input shape/dtype that the graph is +expected to receive for this input parameter. + +### output + +An output call represents a return statement in a function; it thus terminates the +current graph. There is one and only one output node, and it will always be the +last node of the graph. + +**Representation in FX** + +``` +output[](args = (%something, …)) +``` + +This has the exact semantics as in {class}`torch.fx`. `args` represents the node +to be returned. + +**Metadata** + +Output node has the same metadata as `call_function` nodes. + +### get_attr + +`get_attr` nodes represent reading a submodule from the encapsulating +{class}`torch.fx.GraphModule`. Unlike a vanilla FX graph from +{func}`torch.fx.symbolic_trace` in which `get_attr` nodes are used to read +attributes such as parameters and buffers from the top-level +{class}`torch.fx.GraphModule`, parameters and buffers are passed in as +inputs to the graph module, and stored in the top-level +{class}`torch.export.ExportedProgram`. + +**Representation in FX** + +```python +%name = get_attr[target = name](args = ()) +``` + +**Example** + +Consider the following model: + +```python +from functorch.experimental.control_flow import cond + +def true_fn(x): + return x.sin() + +def false_fn(x): + return x.cos() + +def f(x, y): + return cond(y, true_fn, false_fn, [x]) +``` + +Graph: + +``` +graph(): + %x_1 : [num_users=1] = placeholder[target=x_1] + %y_1 : [num_users=1] = placeholder[target=y_1] + %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] + %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0] + %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {}) + return conditional +``` + +The line, `%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]`, +reads the submodule `true_graph_0` which contains the `sin` operator. + +## References + +### SymInt + +A SymInt is an object that can either be a literal integer or a symbol that represents +an Integer (represented in Python by `sympy.Symbol` class). When SymInt is a +symbol, it describes a variable of type integer that is unknown to the graph at +compile time, that is, its value is only known at runtime. + +### FakeTensor + +A FakeTensor is an object that contains the metadata of a tensor. It can be +viewed as having the following metadata. + +```python +class FakeTensor: + size: List[SymInt] + dtype: torch.dtype + device: torch.device + dim_order: List[int] # This doesn't exist yet +``` + +The size field of FakeTensor is a list of integers or SymInts. If SymInts are +present, this means this tensor has a dynamic shape. If integers are present, it +is assumed that the tensor will have that exact static shape. The rank of the +TensorMeta is never dynamic. The dtype field represents the dtype of the +output of that node. There are no implicit type promotions in Edge IR. There +are no strides in FakeTensor. + +In other words: + +- If the operator in node.target returns a Tensor, then `node.meta['val']` is a + FakeTensor describing that tensor. +- If the operator in node.target returns an n-tuple of Tensors, then + `node.meta['val']` is an n-tuple of FakeTensors describing each tensor. +- If the operator in node.target returns an int/float/scalar that is known at + compile time, then `node.meta['val']` is None. +- If the operator in node.target returns an int/float/scalar that is not known + at compile time, then `node.meta['val']` is of type SymInt. + +For example: + +- `aten::add` returns a Tensor; so its spec will be a FakeTensor with dtype + and size of the tensor returned by this operator. +- `aten::sym_size` returns an integer; so its val will be a SymInt because its + value is only available at runtime. +- `max_pool2d_with_indexes` returns a tuple of (Tensor, Tensor); so the spec + will also be a 2-tuple of FakeTensor objects, the first TensorMeta describes + the first element of the return value etc. + +Python code: + +```python +def add_one(x): + return torch.ops.aten(x, 1) +``` + +Graph: + +``` +graph(): + %ph_0 : [#users=1] = placeholder[target=ph_0] + %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {}) + return [add_tensor] +``` + +FakeTensor: + +```python +FakeTensor(dtype=torch.int, size=[2,], device=CPU) +``` + +### Pytree-able Types + +We define a type “Pytree-able”, if it is either a leaf type or a container type +that contains other Pytree-able types. + +Note: + +> The concept of pytree is the same as the one documented +> [here](https://jax.readthedocs.io/en/latest/pytrees.html) for JAX: + +The following types are defined as **leaf type**: + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Type + - Definition + * - Tensor + - :class:`torch.Tensor` + * - Scalar + - Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors. + * - int + - Python int (bound as int64_t in C++) + * - float + - Python float (bound as double in C++) + * - bool + - Python bool + * - str + - Python string + * - ScalarType + - :class:`torch.dtype` + * - Layout + - :class:`torch.layout` + * - MemoryFormat + - :class:`torch.memory_format` + * - Device + - :class:`torch.device` +``` + +The following types are defined as **container type**: + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Type + - Definition + * - Tuple + - Python tuple + * - List + - Python list + * - Dict + - Python dict with Scalar keys + * - NamedTuple + - Python namedtuple + * - Dataclass + - Must be registered through `register_dataclass `__ + * - Custom class + - Any custom class defined with `_register_pytree_node `__ +``` diff --git a/docs/source/export.ir_spec.rst b/docs/source/export.ir_spec.rst deleted file mode 100644 index dadbd8d0c6ea1b..00000000000000 --- a/docs/source/export.ir_spec.rst +++ /dev/null @@ -1,483 +0,0 @@ -.. _export.ir_spec: - -torch.export IR Specification -============================= - -Export IR is an intermediate representation (IR) for compilers, which bears -similarities to MLIR and TorchScript. It is specifically designed to express the -semantics of PyTorch programs. Export IR primarily represents computation in a -streamlined list of operations, with limited support for dynamism such as -control flows. - -To create an Export IR graph, a frontend can be used that soundly captures a -PyTorch program via a trace-specializing mechanism. The resulting Export IR can -then be optimized and executed by a backend. This can be done today through -:func:`torch.export.export`. - -The key concepts that will be covered in this document include: - -- ExportedProgram: the data structure containing the Export IR program -- Graph: which consists of a list of nodes. -- Nodes: which represents operations, control flow, and metadata stored on this node. -- Values are produced and consumed by nodes. -- Types are associated with values and nodes. -- The size and memory layout of values are also defined. - -Assumptions ------------- - -This doc assumes that the audience is sufficiently familiar with PyTorch, -specifically with :class:`torch.fx` and its related toolings. Thus it will stop -describing contents present in :class:`torch.fx` documentation and paper. - -What is Export IR ------------------ - -Export IR is a graph-based intermediate representation IR of PyTorch programs. -Export IR is realized on top of :class:`torch.fx.Graph`. In other words, **all -Export IR graphs are also valid FX graphs**, and if interpreted using standard -FX semantics, Export IR can be interpreted soundly. One implication is that an -exported graph can be converted to a valid Python program via standard FX -codegen. - -This documentation will primarily focus on highlighting areas where Export IR -differs from FX in terms of its strictness, while skipping parts where it shares -similarities with FX. - -ExportedProgram ---------------- - -The top-level Export IR construct is an :class:`torch.export.ExportedProgram` -class. It bundles the computational graph of a PyTorch model (which is usually a -:class:`torch.nn.Module`) with the parameters or weights that this model -consumes. - -Some notable attributes of the :class:`torch.export.ExportedProgram` class are: - -- ``graph_module`` (:class:`torch.fx.GraphModule`): Data structure containing - the flattened computational graph of the PyTorch model. The graph can be - directly accessed through `ExportedProgram.graph`. -- ``graph_signature`` (:class:`torch.export.ExportGraphSignature`): The graph - signature, which specifies the parameters and buffer names used and mutated - within the graph. Instead of storing parameters and buffers as attributes of - the graph, they are lifted as inputs to the graph. The graph_signature is - utilized to keep track of additional information on these parameters and - buffers. -- ``state_dict`` (``Dict[str, Union[torch.Tensor, torch.nn.Parameter]]``): Data - structure containing the parameters and buffers. -- ``range_constraints`` (``Dict[sympy.Symbol, RangeConstraint]``): For programs - that are exported with data dependent behavior, the metadata on each node will - contain symbolic shapes (which look like ``s0``, ``i0``). This attribute maps - the symbolic shapes to their lower/upper ranges. - -Graph ------ - -An Export IR Graph is a PyTorch program represented in the form of a DAG -(directed acyclic graph). Each node in this graph represents a particular -computation or operation, and edges of this graph consist of references between -nodes. - -We can view Graph having this schema: - -.. code-block:: python - - class Graph: - nodes: List[Node] - -In practice, Export IR's graph is realized as :class:`torch.fx.Graph` Python class. - -An Export IR graph contains the following nodes (Nodes will be described in more -details in the next section): - -- 0 or more nodes of op type ``placeholder`` -- 0 or more nodes of op type ``call_function`` -- exactly 1 node of op type ``output`` - -**Collorary:** The smallest valid Graph will be of one node. i.e. nodes is never empty. - -**Definition:** -The set of ``placeholder`` nodes of a Graph represents the **inputs** of the -Graph of GraphModule. The `output` node of a Graph represents the **outputs** -of the Graph of GraphModule. - -Example:: - - import torch - from torch import nn - - class MyModule(nn.Module): - - def forward(self, x, y): - return x + y - - example_args = (torch.randn(1), torch.randn(1)) - mod = torch.export.export(MyModule(), example_args) - print(mod.graph) - -.. code-block:: python - - graph(): - %x : [num_users=1] = placeholder[target=x] - %y : [num_users=1] = placeholder[target=y] - %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {}) - return (add,) - -The above is the textual representation of a Graph, with each line being a node. - -Node ----- - -A Node represents a particular computation or operation and is represented in -Python using the :class:`torch.fx.Node` class. Edges between nodes are -represented as direct references to other nodes via the ``args`` property of the -Node class. Using the same FX machinery, we can represent the following -operations that a computational graph typically needs, such as operator calls, -placeholders (aka inputs), conditionals, and loops. - -The Node has the following schema: - -.. code-block:: python - - class Node: - name: str # name of node - op_name: str # type of operation - - # interpretation of the fields below depends on op_name - target: [str|Callable] - args: List[object] - kwargs: Dict[str, object] - meta: Dict[str, object] - -**FX Text Format** - -As in the example above, notice that each line has this format:: - - %:[...] = [target=](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5}) - -This format captures everything present in the Node class, with the exception of -``meta``, in a compact format. - -Concretely: - -- **** is the name of the node as it would appear in ``node.name``. - -- **** is the ``node.op`` field, which must be one of these: - ``, ``, - ``, or ``. - -- **** is the target of the node as ``node.target``. The meaning of this - field depends on ``op_name``. - -- **args1, … args 4…** are what is listed in the ``node.args`` tuple. If a - value in the list is an :class:`torch.fx.Node`, then it will be especially - indicated with a leading **%.** - -For example, a call to the add operator would appear as:: - - %add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {}) - -Where ``%x``, ``%y`` are two other Nodes that have names x and y. Worth noting -that the string ``torch.op.aten.add.Tensor`` represents the callable object that -is actually stored in the target field, not merely its string name. - -The final line of this text format is:: - - return [add] - -which is a Node with ``op_name = output``, indicating that we are returning this -one element. - -call_function -^^^^^^^^^^^^^ - -A ``call_function`` node represents a call to an operator. - -**Definitions** - -- **Functional:** We say a callable is “functional” if it satisfies all the - following requirements: - - - Non-mutating: The operator does not mutate the value of its input (for - tensors, this includes both metadata and data). - - No side effects: The operator does not mutate states that are visible - from outside, like changing values of module parameters. - -- **Operator:** is a functional callable with a predefined schema. Examples of - such operators include functional ATen operators. - -**Representation in FX** - -.. code-block:: - - %name = call_function[target = operator](args = (%x, %y, …), kwargs = {}) - - -**Differences from vanilla FX call_function** - -1. In FX graph, a call_function can refer to any callable, in Export IR, we - restrict it to only a select subset of ATen operators, custom operators, and - control flow operators. - -2. In Export IR, constant arguments will be embedded within the graph. - -3. In FX graph, a get_attr node can represent reading any attribute stored in - the graph module. However, in Export IR this is restricted to reading only - submodules as all parameters/buffers will be passed in as inputs to the graph - module. - -Metadata -~~~~~~~~ - -``Node.meta`` is a dict attached to every FX node. However, the FX spec does not -specify what metadata can or will be there. Export IR provides a stronger -contract, specifically all ``call_function`` nodes will guarantee having and -only having the following metadata fields: - -- ``node.meta["stack_trace"]`` is a string containing the Python stack trace - referencing the original Python source code. An example stack trace looks - like:: - - File "my_module.py", line 19, in forward - return x + dummy_helper(y) - File "helper_utility.py", line 89, in dummy_helper - return y + 1 - -- ``node.meta["val"]`` describes the output of running the operation. It can be - of type ``, ``, a - ``List[Union[FakeTensor, SymInt]]``, or ``None``. - -- ``node.meta["nn_module_stack"]`` describes the "stacktrace" of the - :class:`torch.nn.Module` from which the node came, if it was from a - :class:`torch.nn.Module` call. For example, if a node containing the ``addmm`` - op called from a :class:`torch.nn.Linear` module inside of a - :class:`torch.nn.Sequential` module, the ``nn_module_stack`` would look - something like:: - - {'self_linear': ('self.linear', ), 'self_sequential': ('self.sequential', )} - -- ``node.meta["source_fn_stack"]`` contains the torch function or the leaf - :class:`torch.nn.Module` class this node was called from before decomposition. - For example, a node containing the ``addmm`` op from a - :class:`torch.nn.Linear` module call would contain :class:`torch.nn.Linear` in - their ``source_fn``, and a node containing the ``addmm`` op from a - :class:`torch.nn.functional.Linear` module call would contain - :class:`torch.nn.functional.Linear` in their ``source_fn``. - -placeholder -^^^^^^^^^^^ - -Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. -Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero. - -**Representation in FX** - -.. code-block:: python - - %name = placeholder[target = name](args = ()) - -The target field is a string which is the name of input. - -``args``, if non-empty, should be of size 1 representing the default value of this input. - -**Metadata** - -Placeholder nodes also have ``meta[‘val’]``, like ``call_function`` nodes. The -``val`` field in this case represents the input shape/dtype that the graph is -expected to receive for this input parameter. - -output -^^^^^^ - -An output call represents a return statement in a function; it thus terminates the -current graph. There is one and only one output node, and it will always be the -last node of the graph. - -**Representation in FX** - -.. code-block:: - - output[](args = (%something, …)) - -This has the exact semantics as in :class:`torch.fx`. ``args`` represents the node -to be returned. - -**Metadata** - -Output node has the same metadata as ``call_function`` nodes. - -get_attr -^^^^^^^^ - -``get_attr`` nodes represent reading a submodule from the encapsulating -:class:`torch.fx.GraphModule`. Unlike a vanilla FX graph from -:func:`torch.fx.symbolic_trace` in which ``get_attr`` nodes are used to read -attributes such as parameters and buffers from the top-level -:class:`torch.fx.GraphModule`, parameters and buffers are passed in as -inputs to the graph module, and stored in the top-level -:class:`torch.export.ExportedProgram`. - -**Representation in FX** - -.. code-block:: python - - %name = get_attr[target = name](args = ()) - -**Example** - -Consider the following model:: - - from functorch.experimental.control_flow import cond - - def true_fn(x): - return x.sin() - - def false_fn(x): - return x.cos() - - def f(x, y): - return cond(y, true_fn, false_fn, [x]) - -Graph:: - - graph(): - %x_1 : [num_users=1] = placeholder[target=x_1] - %y_1 : [num_users=1] = placeholder[target=y_1] - %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] - %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0] - %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {}) - return conditional - -The line, ``%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]``, -reads the submodule ``true_graph_0`` which contains the ``sin`` operator. - -References ----------- - -SymInt -^^^^^^ - -A SymInt is an object that can either be a literal integer or a symbol that represents -an Integer (represented in Python by ``sympy.Symbol`` class). When SymInt is a -symbol, it describes a variable of type integer that is unknown to the graph at -compile time, that is, its value is only known at runtime. - -FakeTensor -^^^^^^^^^^ - -A FakeTensor is an object that contains the metadata of a tensor. It can be -viewed as having the following metadata. - -.. code-block:: python - - class FakeTensor: - size: List[SymInt] - dtype: torch.dtype - device: torch.device - dim_order: List[int] # This doesn't exist yet - -The size field of FakeTensor is a list of integers or SymInts. If SymInts are -present, this means this tensor has a dynamic shape. If integers are present, it -is assumed that the tensor will have that exact static shape. The rank of the -TensorMeta is never dynamic. The dtype field represents the dtype of the -output of that node. There are no implicit type promotions in Edge IR. There -are no strides in FakeTensor. - -In other words: - -- If the operator in node.target returns a Tensor, then ``node.meta['val']`` is a - FakeTensor describing that tensor. -- If the operator in node.target returns an n-tuple of Tensors, then - ``node.meta['val']`` is an n-tuple of FakeTensors describing each tensor. -- If the operator in node.target returns an int/float/scalar that is known at - compile time, then ``node.meta['val']`` is None. -- If the operator in node.target returns an int/float/scalar that is not known - at compile time, then ``node.meta['val']`` is of type SymInt. - -For example: - -- ``aten::add`` returns a Tensor; so its spec will be a FakeTensor with dtype - and size of the tensor returned by this operator. -- ``aten::sym_size`` returns an integer; so its val will be a SymInt because its - value is only available at runtime. -- ``max_pool2d_with_indexes`` returns a tuple of (Tensor, Tensor); so the spec - will also be a 2-tuple of FakeTensor objects, the first TensorMeta describes - the first element of the return value etc. - -Python code:: - - def add_one(x): - return torch.ops.aten(x, 1) - -Graph:: - - graph(): - %ph_0 : [#users=1] = placeholder[target=ph_0] - %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {}) - return [add_tensor] - -FakeTensor:: - - FakeTensor(dtype=torch.int, size=[2,], device=CPU) - -Pytree-able Types -^^^^^^^^^^^^^^^^^ - -We define a type “Pytree-able”, if it is either a leaf type or a container type -that contains other Pytree-able types. - -Note: - - The concept of pytree is the same as the one documented - `here `__ for JAX: - - -The following types are defined as **leaf type**: - -.. list-table:: - :widths: 50 50 - :header-rows: 1 - - * - Type - - Definition - * - Tensor - - :class:`torch.Tensor` - * - Scalar - - Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors. - * - int - - Python int (bound as int64_t in C++) - * - float - - Python float (bound as double in C++) - * - bool - - Python bool - * - str - - Python string - * - ScalarType - - :class:`torch.dtype` - * - Layout - - :class:`torch.layout` - * - MemoryFormat - - :class:`torch.memory_format` - * - Device - - :class:`torch.device` - -The following types are defined as **container type**: - -.. list-table:: - :widths: 50 50 - :header-rows: 1 - - * - Type - - Definition - * - Tuple - - Python tuple - * - List - - Python list - * - Dict - - Python dict with Scalar keys - * - NamedTuple - - Python namedtuple - * - Dataclass - - Must be registered through `register_dataclass `__ - * - Custom class - - Any custom class defined with `_register_pytree_node `__ diff --git a/docs/source/export.md b/docs/source/export.md new file mode 100644 index 00000000000000..9d57614a14adcd --- /dev/null +++ b/docs/source/export.md @@ -0,0 +1,923 @@ +(torch.export)= + +# torch.export + +:::{warning} +This feature is a prototype under active development and there WILL BE +BREAKING CHANGES in the future. +::: + +## Overview + +{func}`torch.export.export` takes a {class}`torch.nn.Module` and produces a traced graph +representing only the Tensor computation of the function in an Ahead-of-Time +(AOT) fashion, which can subsequently be executed with different outputs or +serialized. + +```python +import torch +from torch.export import export + +class Mod(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + a = torch.sin(x) + b = torch.cos(y) + return a + b + +example_args = (torch.randn(10, 10), torch.randn(10, 10)) + +exported_program: torch.export.ExportedProgram = export( + Mod(), args=example_args +) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): + # code: a = torch.sin(x) + sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) + + # code: b = torch.cos(y) + cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) + + # code: return a + b + add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos) + return (add,) + + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='y'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target=None + ) + ] + ) + Range constraints: {} +``` + +`torch.export` produces a clean intermediate representation (IR) with the +following invariants. More specifications about the IR can be found +{ref}`here `. + +- **Soundness**: It is guaranteed to be a sound representation of the original + program, and maintains the same calling conventions of the original program. +- **Normalized**: There are no Python semantics within the graph. Submodules + from the original programs are inlined to form one fully flattened + computational graph. +- **Graph properties**: The graph is purely functional, meaning it does not + contain operations with side effects such as mutations or aliasing. It does + not mutate any intermediate values, parameters, or buffers. +- **Metadata**: The graph contains metadata captured during tracing, such as a + stacktrace from user's code. + +Under the hood, `torch.export` leverages the following latest technologies: + +- **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython feature + called the Frame Evaluation API to safely trace PyTorch graphs. This + provides a massively improved graph capturing experience, with much fewer + rewrites needed in order to fully trace the PyTorch code. +- **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph + is decomposed/lowered to the ATen operator set. +- **Torch FX (torch.fx)** is the underlying representation of the graph, + allowing flexible Python-based transformations. + +### Existing frameworks + +{func}`torch.compile` also utilizes the same PT2 stack as `torch.export`, but +is slightly different: + +- **JIT vs. AOT**: {func}`torch.compile` is a JIT compiler whereas + which is not intended to be used to produce compiled artifacts outside of + deployment. +- **Partial vs. Full Graph Capture**: When {func}`torch.compile` runs into an + untraceable part of a model, it will "graph break" and fall back to running + the program in the eager Python runtime. In comparison, `torch.export` aims + to get a full graph representation of a PyTorch model, so it will error out + when something untraceable is reached. Since `torch.export` produces a full + graph disjoint from any Python features or runtime, this graph can then be + saved, loaded, and run in different environments and languages. +- **Usability tradeoff**: Since {func}`torch.compile` is able to fallback to the + Python runtime whenever it reaches something untraceable, it is a lot more + flexible. `torch.export` will instead require users to provide more + information or rewrite their code to make it traceable. + +Compared to {func}`torch.fx.symbolic_trace`, `torch.export` traces using +TorchDynamo which operates at the Python bytecode level, giving it the ability +to trace arbitrary Python constructs not limited by what Python operator +overloading supports. Additionally, `torch.export` keeps fine-grained track of +tensor metadata, so that conditionals on things like tensor shapes do not +fail tracing. In general, `torch.export` is expected to work on more user +programs, and produce lower-level graphs (at the `torch.ops.aten` operator +level). Note that users can still use {func}`torch.fx.symbolic_trace` as a +preprocessing step before `torch.export`. + +Compared to {func}`torch.jit.script`, `torch.export` does not capture Python +control flow or data structures, but it supports more Python language +features due to its comprehensive coverage over Python bytecodes. +The resulting graphs are simpler and only have straight line control +flow, except for explicit control flow operators. + +Compared to {func}`torch.jit.trace`, `torch.export` is sound: +it can trace code that performs integer computation on sizes and records +all of the side-conditions necessary to ensure that a particular +trace is valid for other inputs. + +## Exporting a PyTorch Model + +### An Example + +The main entrypoint is through {func}`torch.export.export`, which takes a +callable ({class}`torch.nn.Module`, function, or method) and sample inputs, and +captures the computation graph into an {class}`torch.export.ExportedProgram`. An +example: + +```python +import torch +from torch.export import export + +# Simple module for demonstration +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, padding=1 + ) + self.relu = torch.nn.ReLU() + self.maxpool = torch.nn.MaxPool2d(kernel_size=3) + + def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor: + a = self.conv(x) + a.add_(constant) + return self.maxpool(self.relu(a)) + +example_args = (torch.randn(1, 3, 256, 256),) +example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} + +exported_program: torch.export.ExportedProgram = export( + M(), args=example_args, kwargs=example_kwargs +) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): + # code: a = self.conv(x) + conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) + + # code: a.add_(constant) + add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) + + # code: return self.maxpool(self.relu(a)) + relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) + max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) + return (max_pool2d,) + +Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='constant'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='max_pool2d'), + target=None + ) + ] + ) +Range constraints: {} +``` + +Inspecting the `ExportedProgram`, we can note the following: + +- The {class}`torch.fx.Graph` contains the computation graph of the original + program, along with records of the original code for easy debugging. +- The graph contains only `torch.ops.aten` operators found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) + and custom operators, and is fully functional, without any inplace operators + such as `torch.add_`. +- The parameters (weight and bias to conv) are lifted as inputs to the graph, + resulting in no `get_attr` nodes in the graph, which previously existed in + the result of {func}`torch.fx.symbolic_trace`. +- The {class}`torch.export.ExportGraphSignature` models the input and output + signature, along with specifying which inputs are parameters. +- The resulting shape and dtype of tensors produced by each node in the graph is + noted. For example, the `convolution` node will result in a tensor of dtype + `torch.float32` and shape (1, 16, 256, 256). + +(non-strict-export)= + +### Non-Strict Export + +In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In *non-strict mode*, we trace through the program using the Python interpreter. +Your code will execute exactly as it would in eager mode; the only difference is +that all Tensor objects will be replaced by ProxyTensors, which will record all +their operations into a graph. + +In *strict* mode, which is currently the default, we first trace through the +program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not +actually execute your Python code. Instead, it symbolically analyzes it and +builds a graph based on the results. This analysis allows torch.export to +provide stronger guarantees about safety, but not all Python code is supported. + +An example of a case where one might want to use non-strict mode is if you run +into a unsupported TorchDynamo feature that might not be easily solved, and you +know the python code is not exactly needed for computation. For example: + +```python +import contextlib +import torch + +class ContextManager(): + def __init__(self): + self.count = 0 + def __enter__(self): + self.count += 1 + def __exit__(self, exc_type, exc_value, traceback): + self.count -= 1 + +class M(torch.nn.Module): + def forward(self, x): + with ContextManager(): + return x.sin() + x.cos() + +export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully +export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager +``` + +In this example, the first call using non-strict mode (through the +`strict=False` flag) traces successfully whereas the second call using strict +mode (default) results with a failure, where TorchDynamo is unable to support +context managers. One option is to rewrite the code (see {ref}`Limitations of torch.export `), +but seeing as the context manager does not affect the tensor +computations in the model, we can go with the non-strict mode's result. + +(training-export)= + +### Export for Training and Inference + +In PyTorch 2.5, we introduced a new API called {func}`export_for_training`. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In this API, we produce the most generic IR that contains all ATen operators +(including both functional and non-functional) which can be used to train in +eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization +and will soon be the default IR of torch.export.export. To read further about +the motivation behind this change, please refer to + + +When this API is combined with {func}`run_decompositions()`, you should be able to get inference IR with +any desired decomposition behavior. + +To show some examples: + +```python +class ConvBatchnorm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + +mod = ConvBatchnorm() +inp = torch.randn(1, 1, 3, 3) + +ep_for_training = torch.export.export_for_training(mod, (inp,)) +print(ep_for_training) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) + batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) + return (batch_norm,) +``` + +From the above output, you can see that {func}`export_for_training` produces pretty much the same ExportedProgram +as {func}`export` except for the operators in the graph. You can see that we captured batch_norm in the most general +form. This op is non-functional and will be lowered to different ops when running inference. + +You can also go from this IR to an inference IR via {func}`run_decompositions` with arbitrary customizations. + +```python +# Lower to core aten inference IR, but keep conv2d +decomp_table = torch.export.default_decompositions() +del decomp_table[torch.ops.aten.conv2d.default] +ep_for_inference = ep_for_training.run_decompositions(decomp_table) + +print(ep_for_inference) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] + return (getitem_3, getitem_4, add, getitem) +``` + +Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR +containing core aten operators except for `conv2d`. + +You can do even more customization by directly registering your chosen decomposition behaviors. + +You can do even more customizations by directly registering custom decomp behaviour + +```python +# Lower to core aten inference IR, but customize conv2d +decomp_table = torch.export.default_decompositions() + +def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + +decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function +ep_for_inference = ep_for_training.run_decompositions(decomp_table) + +print(ep_for_inference) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) + mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; + return (getitem_3, getitem_4, add, getitem) +``` + +### Expressing Dynamism + +By default `torch.export` will trace the program assuming all input shapes are +**static**, and specializing the exported program to those dimensions. However, +some dimensions, such as a batch dimension, can be dynamic and vary from run to +run. Such dimensions must be specified by using the +{func}`torch.export.Dim` API to create them and by passing them into +{func}`torch.export.export` through the `dynamic_shapes` argument. An example: + +```python +import torch +from torch.export import Dim, export + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential( + torch.nn.Linear(64, 32), torch.nn.ReLU() + ) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) + +example_args = (torch.randn(32, 64), torch.randn(32, 128)) + +# Create a dynamic batch size +batch = Dim("batch") +# Specify that the first dimension of each input is that batch size +dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + +exported_program: torch.export.ExportedProgram = export( + M(), args=example_args, dynamic_shapes=dynamic_shapes +) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): + + # code: out1 = self.branch1(x1) + linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) + relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) + + # code: out2 = self.branch2(x2) + linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) + relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) + + # code: return (out1 + self.buffer, out2) + add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) + return (add, relu_1) + +Range constraints: {s0: VR[0, int_oo]} +``` + +Some additional things to note: + +- Through the {func}`torch.export.Dim` API and the `dynamic_shapes` argument, we specified the first + dimension of each input to be dynamic. Looking at the inputs `x1` and + `x2`, they have a symbolic shape of (s0, 64) and (s0, 128), instead of + the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. + `s0` is a symbol representing that this dimension can be a range + of values. +- `exported_program.range_constraints` describes the ranges of each symbol + appearing in the graph. In this case, we see that `s0` has the range + [0, int_oo]. For technical reasons that are difficult to explain here, they are + assumed to be not 0 or 1. This is not a bug, and does not necessarily mean + that the exported program will not work for dimensions 0 or 1. See + [The 0/1 Specialization Problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk) + for an in-depth discussion of this topic. + +We can also specify more expressive relationships between input shapes, such as +where a pair of shapes might differ by one, a shape might be double of +another, or a shape is even. An example: + +```python +class M(torch.nn.Module): + def forward(self, x, y): + return x + y[1:] + +x, y = torch.randn(5), torch.randn(6) +dimx = torch.export.Dim("dimx", min=3, max=6) +dimy = dimx + 1 + +exported_program = torch.export.export( + M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), +) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): + # code: return x + y[1:] + slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) + add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) + return (add,) + +Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} +``` + +Some things to note: + +- By specifying `{0: dimx}` for the first input, we see that the resulting + shape of the first input is now dynamic, being `[s0]`. And now by specifying + `{0: dimy}` for the second input, we see that the resulting shape of the + second input is also dynamic. However, because we expressed `dimy = dimx + 1`, + instead of `y`'s shape containing a new symbol, we see that it is + now being represented with the same symbol used in `x`, `s0`. We can + see that relationship of `dimy = dimx + 1` is being shown through `s0 + 1`. +- Looking at the range constraints, we see that `s0` has the range [3, 6], + which is specified initially, and we can see that `s0 + 1` has the solved + range of [4, 7]. + +### Serialization + +To save the `ExportedProgram`, users can use the {func}`torch.export.save` and +{func}`torch.export.load` APIs. A convention is to save the `ExportedProgram` +using a `.pt2` file extension. + +An example: + +```python +import torch +import io + +class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + +exported_program = torch.export.export(MyModule(), torch.randn(5)) + +torch.export.save(exported_program, 'exported_program.pt2') +saved_exported_program = torch.export.load('exported_program.pt2') +``` + +### Specializations + +A key concept in understanding the behavior of `torch.export` is the +difference between *static* and *dynamic* values. + +A *dynamic* value is one that can change from run to run. These behave like +normal arguments to a Python function—you can pass different values for an +argument and expect your function to do the right thing. Tensor *data* is +treated as dynamic. + +A *static* value is a value that is fixed at export time and cannot change +between executions of the exported program. When the value is encountered during +tracing, the exporter will treat it as a constant and hard-code it into the +graph. + +When an operation is performed (e.g. `x + y`) and all inputs are static, then +the output of the operation will be directly hard-coded into the graph, and the +operation won’t show up (i.e. it will get constant-folded). + +When a value has been hard-coded into the graph, we say that the graph has been +*specialized* to that value. + +The following values are static: + +#### Input Tensor Shapes + +By default, `torch.export` will trace the program specializing on the input +tensors' shapes, unless a dimension is specified as dynamic via the +`dynamic_shapes` argument to `torch.export`. This means that if there exists +shape-dependent control flow, `torch.export` will specialize on the branch +that is being taken with the given sample inputs. For example: + +```python +import torch +from torch.export import export + +class Mod(torch.nn.Module): + def forward(self, x): + if x.shape[0] > 5: + return x + 1 + else: + return x - 1 + +example_inputs = (torch.rand(10, 2),) +exported_program = export(Mod(), example_inputs) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[10, 2]"): + # code: return x + 1 + add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) + return (add,) +``` + +The conditional of (`x.shape[0] > 5`) does not appear in the +`ExportedProgram` because the example inputs have the static +shape of (10, 2). Since `torch.export` specializes on the inputs' static +shapes, the else branch (`x - 1`) will never be reached. To preserve the dynamic +branching behavior based on the shape of a tensor in the traced graph, +{func}`torch.export.Dim` will need to be used to specify the dimension +of the input tensor (`x.shape[0]`) to be dynamic, and the source code will +need to be {ref}`rewritten `. + +Note that tensors that are part of the module state (e.g. parameters and +buffers) always have static shapes. + +#### Python Primitives + +`torch.export` also specializes on Python primitives, +such as `int`, `float`, `bool`, and `str`. However they do have dynamic +variants such as `SymInt`, `SymFloat`, and `SymBool`. + +For example: + +```python +import torch +from torch.export import export + +class Mod(torch.nn.Module): + def forward(self, x: torch.Tensor, const: int, times: int): + for i in range(times): + x = x + const + return x + +example_inputs = (torch.rand(2, 2), 1, 3) +exported_program = export(Mod(), example_inputs) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[2, 2]", const, times): + # code: x = x + const + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) + return (add_2,) +``` + +Because integers are specialized, the `torch.ops.aten.add.Tensor` operations +are all computed with the hard-coded constant `1`, rather than `const`. If +a user passes a different value for `const` at runtime, like 2, than the one used +during export time, 1, this will result in an error. +Additionally, the `times` iterator used in the `for` loop is also "inlined" +in the graph through the 3 repeated `torch.ops.aten.add.Tensor` calls, and the +input `times` is never used. + +#### Python Containers + +Python containers (`List`, `Dict`, `NamedTuple`, etc.) are considered to +have static structure. + +(limitations-of-torch-export)= + +## Limitations of torch.export + +### Graph Breaks + +As `torch.export` is a one-shot process for capturing a computation graph from +a PyTorch program, it might ultimately run into untraceable parts of programs as +it is nearly impossible to support tracing all PyTorch and Python features. In +the case of `torch.compile`, an unsupported operation will cause a "graph +break" and the unsupported operation will be run with default Python evaluation. +In contrast, `torch.export` will require users to provide additional +information or rewrite parts of their code to make it traceable. As the +tracing is based on TorchDynamo, which evaluates at the Python +bytecode level, there will be significantly fewer rewrites required compared to +previous tracing frameworks. + +When a graph break is encountered, {ref}`ExportDB ` is a great +resource for learning about the kinds of programs that are supported and +unsupported, along with ways to rewrite programs to make them traceable. + +An option to get past dealing with this graph breaks is by using +{ref}`non-strict export ` + +(data-shape-dependent-control-flow)= + +### Data/Shape-Dependent Control Flow + +Graph breaks can also be encountered on data-dependent control flow (`if +x.shape[0] > 2`) when shapes are not being specialized, as a tracing compiler cannot +possibly deal with without generating code for a combinatorially exploding +number of paths. In such cases, users will need to rewrite their code using +special control flow operators. Currently, we support {ref}`torch.cond ` +to express if-else like control flow (more coming soon!). + +### Missing Fake/Meta/Abstract Kernels for Operators + +When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is +required for all operators. This is used to reason about the input/output shapes +for this operator. + +Please see {func}`torch.library.register_fake` for more details. + +In the unfortunate case where your model uses an ATen operator that is does not +have a FakeTensor kernel implementation yet, please file an issue. + +## Read More + +```{toctree} +:caption: Additional Links for Export Users +:maxdepth: 1 + +export.programming_model +export.ir_spec +draft_export +torch.compiler_transformations +torch.compiler_ir +generated/exportdb/index +cond +``` + +```{toctree} +:caption: Deep Dive for PyTorch Developers +:maxdepth: 1 + +torch.compiler_dynamo_overview +torch.compiler_dynamo_deepdive +torch.compiler_dynamic_shapes +torch.compiler_fake_tensor +``` + +## API Reference + +```{eval-rst} +.. automodule:: torch.export +``` + +```{eval-rst} +.. autofunction:: export +``` + +```{eval-rst} +.. autofunction:: save +``` + +```{eval-rst} +.. autofunction:: load +``` + +```{eval-rst} +.. autofunction:: draft_export +``` + +```{eval-rst} +.. autofunction:: register_dataclass +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.Dim +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.ShapesCollection + + .. automethod:: dynamic_shapes +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs + + .. automethod:: add + .. automethod:: dynamic_shapes + .. automethod:: verify +``` + +```{eval-rst} +.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes +``` + +```{eval-rst} +.. autoclass:: ExportedProgram + + .. attribute:: graph + .. attribute:: graph_signature + .. attribute:: state_dict + .. attribute:: constants + .. attribute:: range_constraints + .. attribute:: module_call_graph + .. attribute:: example_inputs + .. automethod:: module + .. automethod:: run_decompositions +``` + +```{eval-rst} +.. autoclass:: ExportGraphSignature +``` + +```{eval-rst} +.. autoclass:: ModuleCallSignature +``` + +```{eval-rst} +.. autoclass:: ModuleCallEntry +``` + +```{eval-rst} +.. automodule:: torch.export.decomp_utils +``` + +```{eval-rst} +.. autoclass:: CustomDecompTable + + .. automethod:: copy + .. automethod:: items + .. automethod:: keys + .. automethod:: materialize + .. automethod:: pop + .. automethod:: update +``` + +```{eval-rst} +.. autofunction:: torch.export.exported_program.default_decompositions +``` + +```{eval-rst} +.. automodule:: torch.export.exported_program +``` + +```{eval-rst} +.. automodule:: torch.export.graph_signature +``` + +```{eval-rst} +.. autoclass:: ExportGraphSignature + + .. automethod:: replace_all_uses + .. automethod:: get_replace_hook +``` + +```{eval-rst} +.. autoclass:: ExportBackwardSignature +``` + +```{eval-rst} +.. autoclass:: InputKind +``` + +```{eval-rst} +.. autoclass:: InputSpec +``` + +```{eval-rst} +.. autoclass:: OutputKind +``` + +```{eval-rst} +.. autoclass:: OutputSpec +``` + +```{eval-rst} +.. autoclass:: SymIntArgument +``` + +```{eval-rst} +.. autoclass:: SymBoolArgument +``` + +```{eval-rst} +.. autoclass:: SymFloatArgument +``` + +```{eval-rst} +.. autoclass:: CustomObjArgument +``` + +```{eval-rst} +.. py:module:: torch.export.dynamic_shapes +``` + +```{eval-rst} +.. py:module:: torch.export.custom_ops +``` + +```{eval-rst} +.. automodule:: torch.export.unflatten + :members: +``` + +```{eval-rst} +.. automodule:: torch.export.custom_obj +``` + +```{eval-rst} +.. automodule:: torch.export.experimental +``` + +```{eval-rst} +.. automodule:: torch.export.passes +``` + +```{eval-rst} +.. autofunction:: torch.export.passes.move_to_device_pass +``` + +```{eval-rst} +.. automodule:: torch.export.pt2_archive +``` + +```{eval-rst} +.. automodule:: torch.export.pt2_archive.constants +``` diff --git a/docs/source/export.programming_model.md b/docs/source/export.programming_model.md new file mode 100644 index 00000000000000..9a21db78464aa8 --- /dev/null +++ b/docs/source/export.programming_model.md @@ -0,0 +1,523 @@ +(export-programming-model)= + +# torch.export Programming Model + +This document aims to explain the behaviors and capabilities of +{func}`torch.export.export`. It is intended to help build your intuition +for how {func}`torch.export.export` handles code. + +## Basics of Tracing + +{func}`torch.export.export` captures a graph representing your model by +tracing its execution on "example" inputs and recording the PyTorch operations +and conditions observed along the traced path. This graph can then be run +on different inputs as long as they satisfy the same conditions. + +The basic output of {func}`torch.export.export` is a single graph of PyTorch +operations, with associated metadata. The exact format of this output is +covered in the {ref}`export.ir_spec`. + +### Strict vs. Non-Strict Tracing + +{func}`torch.export.export` provides two modes of tracing. + +In *non-strict mode*, we trace through the program using the normal Python +interpreter. Your code executes exactly as it would in eager mode; the only +difference is that all Tensors are replaced by +[fake Tensors](https://pytorch.org/docs/main/torch.compiler_fake_tensor.html), +**which have shapes and other forms of metadata but no data**, wrapped in +[Proxy objects](https://pytorch.org/docs/main/fx.html) that record all +operations on them into a graph. We also capture +[conditions on Tensor shapes](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#the-guard-model) +**that guard the correctness of the generated code**. + +In *strict mode*, we first trace through the program using +{ref}`TorchDynamo `, a Python bytecode +analysis engine. TorchDynamo does not actually execute your Python code. +Instead, it symbolically analyzes it and builds a graph based on the results. +On the one hand, this analysis allows {func}`torch.export.export` to provide +additional guarantees on Python-level safety (beyond capturing conditions on +Tensor shapes, as in non-strict mode). On the other hand, not all Python +features are supported by this analysis. + +Although currently the default mode of tracing is strict, **we strongly +recommend using non-strict**, which will soon become the default. +For most models, conditions on Tensor shapes are enough for soundness, and +the additional guarantees on Python-level safety have no impact; at the same +time, the possibility of hitting unsupported Python features in TorchDynamo +presents an unnecessary risk. + +In the rest of this document we assume we are tracing in +[non-strict mode](https://pytorch.org/docs/main/export.html#non-strict-export); +in particular, we assume that **all Python features are supported**. + +## Values: Static vs. Dynamic + +A key concept in understanding the behavior of {func}`torch.export.export` is +the difference between *static* and *dynamic* values. + +### Static Values + +A *static* value is a value that is **fixed at export time and cannot change +between executions of the exported program**. When the value is encountered +during tracing, we treat it as a constant and hard-code it into the graph. + +When an operation is performed (e.g. `x + y`) and all inputs are static, +the output of the operation is directly hard-coded into the graph and the +operation does not show up (i.e. it gets "constant-folded"). + +When a value has been hard-coded into the graph, we say that the graph has +been *specialized* to that value. For example: + +```python +import torch + +class MyMod(torch.nn.Module): + def forward(self, x, y): + z = y + 7 + return x + z + +m = torch.export.export(MyMod(), (torch.randn(1), 3)) +print(m.graph_module.code) + +""" +def forward(self, arg0_1, arg1_1): + add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None + return (add,) + +""" +``` + +Here, we provide `3` as the traced value for `y`; it is treated as a static +value and added to `7`, burning in the static value `10` in the graph. + +### Dynamic Values + +A *dynamic* value is one that **can change from run to run**. It behaves just +like a "normal" function argument: you can pass different inputs and expect +your function to do the right thing. + +### Which values are static vs. dynamic? + +Whether a value is static or dynamic depends on its type: + +- For Tensor: + + - Tensor *data* is treated as dynamic. + + - Tensor *shapes* can be treated by the system as static or dynamic. + + - By default, shapes of all input Tensors are considered static. + The user can override this behavior for any input Tensor by specifying + a [dynamic shape](https://pytorch.org/docs/main/export.html#expressing-dynamism) + for it. + - Tensors that are part of module state, i.e., parameters and buffers, + always have static shapes. + + - Other forms of Tensor *metadata* (e.g. `device`, `dtype`) are static. + +- Python *primitives* (`int`, `float`, `bool`, `str`, `None`) are static. + + - There are dynamic variants for some primitive types (`SymInt`, + `SymFloat`, `SymBool`). Typically users do not have to deal with them. + +- For Python *standard containers* (`list`, `tuple`, `dict`, `namedtuple`): + + - The structure (i.e., length for `list` and `tuple` values, and key + sequence for `dict` and `namedtuple` values) is static. + - The contained elements have these rules applied to them recursively + (basically the + [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) scheme) + with leaves that are either Tensor or primitive types. + +- Other *classes* (including data classes) can be registered with PyTree + (see below), and follow the same rules as the standard containers. + +## Input types + +Inputs will be treated as either static or dynamic, based on their type +(as explained above). + +- A static input will get hard-coded into the graph, and passing a different + value at run time will result in an error. Recall that these are mostly + values of primitive types. +- A dynamic input behaves like a "normal" function input. Recall that these + are mostly values of Tensor types. + +By default, the types of inputs you can use for your program are: + +- Tensor +- Python primitives (`int`, `float`, `bool`, `str`, `None`) +- Python standard containers (`list`, `tuple`, `dict`, `namedtuple`) + +### Custom Input Types + +In addition, you can also define your own (custom) class and use it as an +input type, but you will need to register such a class as a PyTree. + +Here's an example of using an utility to register a dataclass that is used as +an input type. + +```python +@dataclass +class Input: + f: torch.Tensor + p: torch.Tensor + +torch.export.register_dataclass(Input) + +class M(torch.nn.Module): + def forward(self, x: Input): + return x.f + 1 + +torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)) +``` + +### Optional input types + +For optional inputs to the program that are not passed in, +{func}`torch.export.export` will specialize to their default values. As a +result, the exported program will require users to explicitly pass in all +arguments, and will lose the defaulting behavior. For example: + +```python +class M(torch.nn.Module): + def forward(self, x, y=None): + if y is not None: + return y * x + return x + x + +# Optional input is passed in +ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3))) +print(ep) +""" +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"): + # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x + mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None + return (mul,) +""" + +# Optional input is not passed in +ep = torch.export.export(M(), (torch.randn(3, 3),)) +print(ep) +""" +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[3, 3]", y): + # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x + add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None + return (add,) +""" +``` + +## Control Flow: Static vs. Dynamic + +Control flow is supported by {func}`torch.export.export`. The behavior of +control flow depends on whether the value you are branching on is static or +dynamic. + +### Static Control Flow + +**Python control flow over static values is supported transparently**. (Recall +that static values include static shapes, so control flow over static shapes +is also covered by this case.) + +As mentioned above, we "burn in" static values, so the exported graph will +never see any control flow over static values. + +In the case of an `if` statement, we will continue tracing the branch taken +at export time. In the case of a `for` or `while` statement, we will continue +tracing by unrolling the loop. + +### Dynamic Control Flow: Shape-Dependent vs. Data-Dependent + +When the value involved in a control flow is dynamic, it could depend on +dynamic shapes or dynamic data. Given that the compiler traces with +information on shapes rather than data, the implications on the programming +model are different in these cases. + +#### Dynamic Shape-Dependent Control Flow + +When the value involved in a control flow is a +[dynamic shape](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html), +in most cases **we will also know the concrete value of the dynamic shape +during tracing**: see the following section for more details on how the +compiler tracks this information. + +In these cases we say that the control flow is shape-dependent. **We use the +concrete value of the dynamic shape to evaluate the condition** to either +`True` or `False` and continue tracing (as discussed above), additionally +emitting a guard corresponding to the condition just evaluated. + +Otherwise the control flow is considered data-dependent. We cannot evaluate +the condition to either `True` or `False`, so cannot continue tracing and have to +raise an error at export time. See next section. + +#### Dynamic Data-Dependent Control Flow + +**Data-dependent control flow over dynamic values is supported, but you must +use one of PyTorch's explicit operators** to continue tracing. Using Python +control flow statements over dynamic values is not permitted, because the +compiler cannot evaluate the conditions necessary to continue tracing and +thus an error must be raised at export time. + +We provide **operators to express general conditionals and loops over dynamic +values**, e.g., `torch.cond`, `torch.map`. Note that you only need to use these +if you truly want *data-dependent control flow*. + +Here's an example of an `if` statement on a data-dependent condition, +`x.sum() > 0`, where `x` is an input Tensor, rewritten using `torch.cond`. +Instead of having to decide which branch to trace, now both branches are +traced. + +```python +class M_old(torch.nn.Module): + def forward(self, x): + if x.sum() > 0: + return x.sin() + else: + return x.cos() + +class M_new(torch.nn.Module): + def forward(self, x): + return torch.cond( + pred=x.sum() > 0, + true_fn=lambda x: x.sin(), + false_fn=lambda x: x.cos(), + operands=(x,), + ) +``` + +A special case of data-dependent control flow is where it involves a +[data-dependent dynamic shape](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#unbacked-symints): +typically, the shape of some intermediate Tensor that depends on input data +rather than on input shapes (thus not shape-dependent). Instead of using a +control flow operator, in this case you can provide an assertion that decides +whether the condition is `True` or `False`. Given such an assertion, we can +continue tracing, emitting a guard as above. + +We provide **operators to express assertions on dynamic shapes**, e.g., +`torch._check`. Note that you only need to use this when there is control +flow on data-dependent dynamic shapes. + +Here's an example of an `if` statement on a condition involving a +data-dependent dynamic shape, `nz.shape[0] > 0`, where `nz` is the result of +calling {func}`torch.nonzero`, an operator whose output shape depends on input +data. Instead of rewriting it, you can add an assertion using `torch._check` +to effectively decide which branch to trace. + +```python +class M_old(torch.nn.Module): + def forward(self, x): + nz = x.nonzero() + if nz.shape[0] > 0: + return x.sin() + else: + return x.cos() + +class M_new(torch.nn.Module): + def forward(self, x): + nz = x.nonzero() + torch._check(nz.shape[0] > 0) + if nz.shape[0] > 0: + return x.sin() + else: + return x.cos() +``` + +## Basics of Symbolic Shapes + +During tracing, dynamic Tensor shapes and conditions over them are encoded as +"symbolic expressions." (In contrast, static Tensor shapes and conditions +over them are simply `int` and `bool` values.) + +A *symbol* is like a variable; it describes a dynamic Tensor shape. + +As tracing proceeds, shapes of intermediate Tensors may be described by more +general expressions, typically involving integer arithmetic operators. This +is because **for most PyTorch operators, shapes of output Tensors can be +described as functions of shapes of input Tensors**. For example, the shape of +the output of {func}`torch.cat` is the sum of the shapes of its inputs. + +Moreover, as we encounter control flow in the program, we create boolean +expressions, typically involving relational operators, describing conditions +along the traced path. These **expressions are evaluated to decide which path +to trace through the program**, and recorded in a +[shape environment](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#overall-architecture) +to guard the correctness of the traced path and to evaluate subsequently +created expressions. + +We briefly introduce these subsystems next. + +### Fake Implementations of PyTorch Operators + +Recall that during tracing, we are executing the program with +[fake Tensors](https://pytorch.org/docs/main/torch.compiler_fake_tensor.html), +which have no data. In general we cannot call the actual implementations of +PyTorch operators with fake Tensors. Thus each operator needs to have an +additional fake (a.k.a. "meta") implementation, which inputs and outputs fake +Tensors, that matches the behavior of the actual implementation in terms of +shapes and other forms of metadata carried by fake Tensors. + +For example, note how the fake implementation of {func}`torch.index_select` +computes the shape of the output using the shape of the input (while ignoring +input data and returning empty output data). + +```python +def meta_index_select(self, dim, index): + result_size = list(self.size()) + if self.dim() > 0: + result_size[dim] = index.numel() + return self.new_empty(result_size) +``` + +#### Shape Propagation: Backed vs. Unbacked Dynamic Shapes + +Shapes are propagated using fake implementations of PyTorch operators. + +A key concept to understand the propagation of dynamic shapes in particular +is the difference between *backed* and *unbacked* dynamic shapes: we know the +concrete values of the former but not the latter. + +Propagation of shapes, including tracking backed and unbacked dynamic shapes, +proceeds as follows: + +- The shapes of Tensors representing inputs can be static or dynamic. When + dynamic, they are described by symbols; moreover, **such symbols are backed + since we also know their concrete values given the "real" example inputs + provided by the user at export time**. + +- The output shape of an operator is computed by its fake implementation, and + is either static or dynamic. When dynamic, in general it is described by a + symbolic expression. Moreover: + + - If the output shape depends only on input shapes, it is either static or + backed dynamic whenever the input shapes are all static or backed dynamic. + - On the other hand, **if the output shape depends on input data**, it is + necessarily dynamic, and moreover, **because we cannot know its concrete + value it is unbacked**. + +### Control Flow: Guards and Assertions + +When a condition on shapes is encountered, it either involves only static +shapes, in which case it is a `bool`, or it involves dynamic shapes, in which +case it is a symbolic boolean expression. For the latter: + +- When the condition involves only backed dynamic shapes, we can use the + concrete values of those dynamic shapes to evaluate the condition to `True` + or `False`. We can then add a guard to the shape environment that states + that the corresponding symbolic boolean expression is `True` or `False`, + and continue tracing. +- Otherwise the condition involves unbacked dynamic shapes. In general we + cannot evaluate such a condition without additional information; thus we + cannot continue tracing, and we must raise an error at export time. The + user is expected to use an explicit PyTorch operator for tracing to + continue. This information is added as a guard in the shape environment, + and can also possibly help evaluate other subsequently encountered + conditions to `True` or `False`. + +Once the model is exported, **any guards on backed dynamic shapes can be +understood as conditions on input dynamic shapes**. These are verified against +a dynamic shape specification that must have been provided to export, +describing conditions on dynamic shapes that not only example inputs but also +all future inputs are expected to satisfy for the generated code to be +correct. More precisely, the dynamic shape specification must logically imply +the generated guards, otherwise an error is raised at export time (along with +suggested fixes to the dynamic shape specification). On the other hand, when +there are no generated guards on backed dynamic shapes (in particular, when +all shapes are static) no dynamic shape specification needs to be provided to +export. In general, the dynamic shape specification is converted to runtime +assertions on the inputs of the generated code. + +Finally, **any guards on unbacked dynamic shapes are converted to "inline" +runtime assertions**. These are added in the generated code at the locations +where those unbacked dynamic shapes were created: typically, right after +data-dependent operator calls. + +## Allowed PyTorch operators + +All PyTorch operators are permitted. + +### Custom operators + +In addition, you can define and use +[custom operators](https://pytorch.org/tutorials/advanced/python_custom_ops#python-custom-ops-tutorial). +Defining a custom operator includes defining a fake implementation for it, +just like any other PyTorch operator (see previous section). + +Here's an example of a custom `sin` operator that wraps NumPy, and its +registered (trivial) fake implementation. + +```python +@torch.library.custom_op("mylib::sin", mutates_args=()) +def sin(x: Tensor) -> Tensor: + x_np = x.numpy() + y_np = np.sin(x_np) + return torch.from_numpy(y_np) + +@torch.library.register_fake("mylib::sin") +def _(x: Tensor) -> Tensor: + return torch.empty_like(x) +``` + +**Sometimes your custom operator's fake implementation will involve +data-dependent shapes**. Here's how a fake implementation for a custom +`nonzero` might look like. + +```python +... + +@torch.library.register_fake("mylib::custom_nonzero") +def _(x): + nnz = torch.library.get_ctx().new_dynamic_size() + shape = [nnz, x.dim()] + return x.new_empty(shape, dtype=torch.int64) +``` + +## Module State: Reads vs. Updates + +Module states include parameters, buffers, and regular attributes. + +- A regular attribute can be of any type. +- On the other hand, parameters and buffers are always Tensors. + +Module states can be dynamic or static, based on their types as outlined +above. For example, `self.training` is a `bool`, which means it is static; on +the other hand, any parameter or buffer is dynamic. + +The *shapes* of any Tensors contained in module states cannot be dynamic, i.e., +those shapes are fixed at export time, and cannot change between executions +of the exported program. + +### Access rules + +**All module states must be initialized**. Accessing a module state that is +not already initialized causes an error to be raised at export time. + +**Reading module states is always permitted**. + +Updating module states is possible, but must follow the rules below: + +- **A static regular attribute** (e.g., of primitive type) **can be updated**. + Reads and updates can be freely interleaved, and as expected, any reads + will always see the values of the latest updates. Because these attributes + are static, we will also burn the values in, so the generated code will not + have any instructions to actually "get" or "set" such attributes. +- **A dynamic regular attribute** (e.g., of Tensor type) **cannot be updated**. + To do so, it must be registered as a buffer during module initialization. +- **A buffer can be updated**, where the updating can be in-place (e.g., + `self.buffer[:] = ...`) or not (e.g., `self.buffer = ...`). +- **A parameter cannot be updated**. Typically parameters are updated only + during training, not during inference. We recommend exporting with + {func}`torch.no_grad` to avoid parameter updates at export time. + +### Effects of functionalization + +Any dynamic module state that is read and/or updated is "lifted" +(respectively) as an input and/or output of the generated code. + +The exported program stores, along with the generated code, the initial +values of parameters and buffers and the constant values of other Tensor +attributes. diff --git a/docs/source/export.programming_model.rst b/docs/source/export.programming_model.rst deleted file mode 100644 index b82309b423489e..00000000000000 --- a/docs/source/export.programming_model.rst +++ /dev/null @@ -1,562 +0,0 @@ -.. _export.programming_model: - -torch.export Programming Model -============================== - -This document aims to explain the behaviors and capabilities of -:func:`torch.export.export`. It is intended to help build your intuition -for how :func:`torch.export.export` handles code. - -Basics of Tracing ------------------ - -:func:`torch.export.export` captures a graph representing your model by -tracing its execution on "example" inputs and recording the PyTorch operations -and conditions observed along the traced path. This graph can then be run -on different inputs as long as they satisfy the same conditions. - -The basic output of :func:`torch.export.export` is a single graph of PyTorch -operations, with associated metadata. The exact format of this output is -covered in the :ref:`export.ir_spec`. - -Strict vs. Non-Strict Tracing -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:func:`torch.export.export` provides two modes of tracing. - -In *non-strict mode*, we trace through the program using the normal Python -interpreter. Your code executes exactly as it would in eager mode; the only -difference is that all Tensors are replaced by -`fake Tensors `__, -**which have shapes and other forms of metadata but no data**, wrapped in -`Proxy objects `__ that record all -operations on them into a graph. We also capture -`conditions on Tensor shapes `__ -**that guard the correctness of the generated code**. - -In *strict mode*, we first trace through the program using -:ref:`TorchDynamo `, a Python bytecode -analysis engine. TorchDynamo does not actually execute your Python code. -Instead, it symbolically analyzes it and builds a graph based on the results. -On the one hand, this analysis allows :func:`torch.export.export` to provide -additional guarantees on Python-level safety (beyond capturing conditions on -Tensor shapes, as in non-strict mode). On the other hand, not all Python -features are supported by this analysis. - -Although currently the default mode of tracing is strict, **we strongly -recommend using non-strict**, which will soon become the default. -For most models, conditions on Tensor shapes are enough for soundness, and -the additional guarantees on Python-level safety have no impact; at the same -time, the possibility of hitting unsupported Python features in TorchDynamo -presents an unnecessary risk. - -In the rest of this document we assume we are tracing in -`non-strict mode `__; -in particular, we assume that **all Python features are supported**. - -Values: Static vs. Dynamic --------------------------- - -A key concept in understanding the behavior of :func:`torch.export.export` is -the difference between *static* and *dynamic* values. - -Static Values -^^^^^^^^^^^^^ - -A *static* value is a value that is **fixed at export time and cannot change -between executions of the exported program**. When the value is encountered -during tracing, we treat it as a constant and hard-code it into the graph. - -When an operation is performed (e.g. ``x + y``) and all inputs are static, -the output of the operation is directly hard-coded into the graph and the -operation does not show up (i.e. it gets "constant-folded"). - -When a value has been hard-coded into the graph, we say that the graph has -been *specialized* to that value. For example: - -.. code-block:: python - - import torch - - class MyMod(torch.nn.Module): - def forward(self, x, y): - z = y + 7 - return x + z - - m = torch.export.export(MyMod(), (torch.randn(1), 3)) - print(m.graph_module.code) - - """ - def forward(self, arg0_1, arg1_1): - add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None - return (add,) - - """ - -Here, we provide ``3`` as the traced value for ``y``; it is treated as a static -value and added to ``7``, burning in the static value ``10`` in the graph. - -Dynamic Values -^^^^^^^^^^^^^^ - -A *dynamic* value is one that **can change from run to run**. It behaves just -like a "normal" function argument: you can pass different inputs and expect -your function to do the right thing. - -Which values are static vs. dynamic? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Whether a value is static or dynamic depends on its type: - -- For Tensor: - - - Tensor *data* is treated as dynamic. - - - Tensor *shapes* can be treated by the system as static or dynamic. - - - By default, shapes of all input Tensors are considered static. - The user can override this behavior for any input Tensor by specifying - a `dynamic shape `__ - for it. - - - Tensors that are part of module state, i.e., parameters and buffers, - always have static shapes. - - - Other forms of Tensor *metadata* (e.g. ``device``, ``dtype``) are static. - -- Python *primitives* (``int``, ``float``, ``bool``, ``str``, ``None``) are static. - - - There are dynamic variants for some primitive types (``SymInt``, - ``SymFloat``, ``SymBool``). Typically users do not have to deal with them. - -- For Python *standard containers* (``list``, ``tuple``, ``dict``, ``namedtuple``): - - - The structure (i.e., length for ``list`` and ``tuple`` values, and key - sequence for ``dict`` and ``namedtuple`` values) is static. - - - The contained elements have these rules applied to them recursively - (basically the - `PyTree `__ scheme) - with leaves that are either Tensor or primitive types. - -- Other *classes* (including data classes) can be registered with PyTree - (see below), and follow the same rules as the standard containers. - - -Input types ------------ - -Inputs will be treated as either static or dynamic, based on their type -(as explained above). - -- A static input will get hard-coded into the graph, and passing a different - value at run time will result in an error. Recall that these are mostly - values of primitive types. - -- A dynamic input behaves like a "normal" function input. Recall that these - are mostly values of Tensor types. - -By default, the types of inputs you can use for your program are: - -- Tensor - -- Python primitives (``int``, ``float``, ``bool``, ``str``, ``None``) - -- Python standard containers (``list``, ``tuple``, ``dict``, ``namedtuple``) - -Custom Input Types -^^^^^^^^^^^^^^^^^^ - -In addition, you can also define your own (custom) class and use it as an -input type, but you will need to register such a class as a PyTree. - -Here's an example of using an utility to register a dataclass that is used as -an input type. - -.. code-block:: python - - @dataclass - class Input: - f: torch.Tensor - p: torch.Tensor - - torch.export.register_dataclass(Input) - - class M(torch.nn.Module): - def forward(self, x: Input): - return x.f + 1 - - torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)) - -Optional input types -^^^^^^^^^^^^^^^^^^^^ - -For optional inputs to the program that are not passed in, -:func:`torch.export.export` will specialize to their default values. As a -result, the exported program will require users to explicitly pass in all -arguments, and will lose the defaulting behavior. For example: - -.. code-block:: python - - class M(torch.nn.Module): - def forward(self, x, y=None): - if y is not None: - return y * x - return x + x - - # Optional input is passed in - ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3))) - print(ep) - """ - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"): - # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x - mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None - return (mul,) - """ - - # Optional input is not passed in - ep = torch.export.export(M(), (torch.randn(3, 3),)) - print(ep) - """ - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[3, 3]", y): - # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x - add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None - return (add,) - """ - -Control Flow: Static vs. Dynamic --------------------------------- - -Control flow is supported by :func:`torch.export.export`. The behavior of -control flow depends on whether the value you are branching on is static or -dynamic. - -Static Control Flow -^^^^^^^^^^^^^^^^^^^ - -**Python control flow over static values is supported transparently**. (Recall -that static values include static shapes, so control flow over static shapes -is also covered by this case.) - -As mentioned above, we "burn in" static values, so the exported graph will -never see any control flow over static values. - -In the case of an ``if`` statement, we will continue tracing the branch taken -at export time. In the case of a ``for`` or ``while`` statement, we will continue -tracing by unrolling the loop. - -Dynamic Control Flow: Shape-Dependent vs. Data-Dependent -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When the value involved in a control flow is dynamic, it could depend on -dynamic shapes or dynamic data. Given that the compiler traces with -information on shapes rather than data, the implications on the programming -model are different in these cases. - -Dynamic Shape-Dependent Control Flow -"""""""""""""""""""""""""""""""""""" - -When the value involved in a control flow is a -`dynamic shape `__, -in most cases **we will also know the concrete value of the dynamic shape -during tracing**: see the following section for more details on how the -compiler tracks this information. - -In these cases we say that the control flow is shape-dependent. **We use the -concrete value of the dynamic shape to evaluate the condition** to either -``True`` or ``False`` and continue tracing (as discussed above), additionally -emitting a guard corresponding to the condition just evaluated. - -Otherwise the control flow is considered data-dependent. We cannot evaluate -the condition to either ``True`` or ``False``, so cannot continue tracing and have to -raise an error at export time. See next section. - -Dynamic Data-Dependent Control Flow -""""""""""""""""""""""""""""""""""" - -**Data-dependent control flow over dynamic values is supported, but you must -use one of PyTorch's explicit operators** to continue tracing. Using Python -control flow statements over dynamic values is not permitted, because the -compiler cannot evaluate the conditions necessary to continue tracing and -thus an error must be raised at export time. - -We provide **operators to express general conditionals and loops over dynamic -values**, e.g., `torch.cond`, `torch.map`. Note that you only need to use these -if you truly want *data-dependent control flow*. - -Here's an example of an ``if`` statement on a data-dependent condition, -``x.sum() > 0``, where ``x`` is an input Tensor, rewritten using `torch.cond`. -Instead of having to decide which branch to trace, now both branches are -traced. - -.. code-block:: python - - class M_old(torch.nn.Module): - def forward(self, x): - if x.sum() > 0: - return x.sin() - else: - return x.cos() - - class M_new(torch.nn.Module): - def forward(self, x): - return torch.cond( - pred=x.sum() > 0, - true_fn=lambda x: x.sin(), - false_fn=lambda x: x.cos(), - operands=(x,), - ) - -A special case of data-dependent control flow is where it involves a -`data-dependent dynamic shape `__: -typically, the shape of some intermediate Tensor that depends on input data -rather than on input shapes (thus not shape-dependent). Instead of using a -control flow operator, in this case you can provide an assertion that decides -whether the condition is ``True`` or ``False``. Given such an assertion, we can -continue tracing, emitting a guard as above. - -We provide **operators to express assertions on dynamic shapes**, e.g., -`torch._check`. Note that you only need to use this when there is control -flow on data-dependent dynamic shapes. - -Here's an example of an ``if`` statement on a condition involving a -data-dependent dynamic shape, ``nz.shape[0] > 0``, where ``nz`` is the result of -calling :func:`torch.nonzero`, an operator whose output shape depends on input -data. Instead of rewriting it, you can add an assertion using `torch._check` -to effectively decide which branch to trace. - -.. code-block:: python - - class M_old(torch.nn.Module): - def forward(self, x): - nz = x.nonzero() - if nz.shape[0] > 0: - return x.sin() - else: - return x.cos() - - class M_new(torch.nn.Module): - def forward(self, x): - nz = x.nonzero() - torch._check(nz.shape[0] > 0) - if nz.shape[0] > 0: - return x.sin() - else: - return x.cos() - - -Basics of Symbolic Shapes -------------------------- - -During tracing, dynamic Tensor shapes and conditions over them are encoded as -"symbolic expressions." (In contrast, static Tensor shapes and conditions -over them are simply ``int`` and ``bool`` values.) - -A *symbol* is like a variable; it describes a dynamic Tensor shape. - -As tracing proceeds, shapes of intermediate Tensors may be described by more -general expressions, typically involving integer arithmetic operators. This -is because **for most PyTorch operators, shapes of output Tensors can be -described as functions of shapes of input Tensors**. For example, the shape of -the output of :func:`torch.cat` is the sum of the shapes of its inputs. - -Moreover, as we encounter control flow in the program, we create boolean -expressions, typically involving relational operators, describing conditions -along the traced path. These **expressions are evaluated to decide which path -to trace through the program**, and recorded in a -`shape environment `__ -to guard the correctness of the traced path and to evaluate subsequently -created expressions. - -We briefly introduce these subsystems next. - -Fake Implementations of PyTorch Operators -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Recall that during tracing, we are executing the program with -`fake Tensors `__, -which have no data. In general we cannot call the actual implementations of -PyTorch operators with fake Tensors. Thus each operator needs to have an -additional fake (a.k.a. "meta") implementation, which inputs and outputs fake -Tensors, that matches the behavior of the actual implementation in terms of -shapes and other forms of metadata carried by fake Tensors. - -For example, note how the fake implementation of :func:`torch.index_select` -computes the shape of the output using the shape of the input (while ignoring -input data and returning empty output data). - -.. code-block:: python - - def meta_index_select(self, dim, index): - result_size = list(self.size()) - if self.dim() > 0: - result_size[dim] = index.numel() - return self.new_empty(result_size) - -Shape Propagation: Backed vs. Unbacked Dynamic Shapes -""""""""""""""""""""""""""""""""""""""""""""""""""""" - -Shapes are propagated using fake implementations of PyTorch operators. - -A key concept to understand the propagation of dynamic shapes in particular -is the difference between *backed* and *unbacked* dynamic shapes: we know the -concrete values of the former but not the latter. - -Propagation of shapes, including tracking backed and unbacked dynamic shapes, -proceeds as follows: - -- The shapes of Tensors representing inputs can be static or dynamic. When - dynamic, they are described by symbols; moreover, **such symbols are backed - since we also know their concrete values given the "real" example inputs - provided by the user at export time**. - -- The output shape of an operator is computed by its fake implementation, and - is either static or dynamic. When dynamic, in general it is described by a - symbolic expression. Moreover: - - - If the output shape depends only on input shapes, it is either static or - backed dynamic whenever the input shapes are all static or backed dynamic. - - - On the other hand, **if the output shape depends on input data**, it is - necessarily dynamic, and moreover, **because we cannot know its concrete - value it is unbacked**. - -Control Flow: Guards and Assertions -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When a condition on shapes is encountered, it either involves only static -shapes, in which case it is a ``bool``, or it involves dynamic shapes, in which -case it is a symbolic boolean expression. For the latter: - -- When the condition involves only backed dynamic shapes, we can use the - concrete values of those dynamic shapes to evaluate the condition to ``True`` - or ``False``. We can then add a guard to the shape environment that states - that the corresponding symbolic boolean expression is ``True`` or ``False``, - and continue tracing. - -- Otherwise the condition involves unbacked dynamic shapes. In general we - cannot evaluate such a condition without additional information; thus we - cannot continue tracing, and we must raise an error at export time. The - user is expected to use an explicit PyTorch operator for tracing to - continue. This information is added as a guard in the shape environment, - and can also possibly help evaluate other subsequently encountered - conditions to ``True`` or ``False``. - -Once the model is exported, **any guards on backed dynamic shapes can be -understood as conditions on input dynamic shapes**. These are verified against -a dynamic shape specification that must have been provided to export, -describing conditions on dynamic shapes that not only example inputs but also -all future inputs are expected to satisfy for the generated code to be -correct. More precisely, the dynamic shape specification must logically imply -the generated guards, otherwise an error is raised at export time (along with -suggested fixes to the dynamic shape specification). On the other hand, when -there are no generated guards on backed dynamic shapes (in particular, when -all shapes are static) no dynamic shape specification needs to be provided to -export. In general, the dynamic shape specification is converted to runtime -assertions on the inputs of the generated code. - -Finally, **any guards on unbacked dynamic shapes are converted to "inline" -runtime assertions**. These are added in the generated code at the locations -where those unbacked dynamic shapes were created: typically, right after -data-dependent operator calls. - - -Allowed PyTorch operators -------------------------- - -All PyTorch operators are permitted. - -Custom operators -^^^^^^^^^^^^^^^^ - -In addition, you can define and use -`custom operators `__. -Defining a custom operator includes defining a fake implementation for it, -just like any other PyTorch operator (see previous section). - -Here's an example of a custom ``sin`` operator that wraps NumPy, and its -registered (trivial) fake implementation. - -.. code-block:: python - - @torch.library.custom_op("mylib::sin", mutates_args=()) - def sin(x: Tensor) -> Tensor: - x_np = x.numpy() - y_np = np.sin(x_np) - return torch.from_numpy(y_np) - - @torch.library.register_fake("mylib::sin") - def _(x: Tensor) -> Tensor: - return torch.empty_like(x) - -**Sometimes your custom operator's fake implementation will involve -data-dependent shapes**. Here's how a fake implementation for a custom -``nonzero`` might look like. - -.. code-block:: python - - ... - - @torch.library.register_fake("mylib::custom_nonzero") - def _(x): - nnz = torch.library.get_ctx().new_dynamic_size() - shape = [nnz, x.dim()] - return x.new_empty(shape, dtype=torch.int64) - - -Module State: Reads vs. Updates -------------------------------- - -Module states include parameters, buffers, and regular attributes. - -- A regular attribute can be of any type. - -- On the other hand, parameters and buffers are always Tensors. - -Module states can be dynamic or static, based on their types as outlined -above. For example, ``self.training`` is a ``bool``, which means it is static; on -the other hand, any parameter or buffer is dynamic. - -The *shapes* of any Tensors contained in module states cannot be dynamic, i.e., -those shapes are fixed at export time, and cannot change between executions -of the exported program. - -Access rules -^^^^^^^^^^^^ - -**All module states must be initialized**. Accessing a module state that is -not already initialized causes an error to be raised at export time. - -**Reading module states is always permitted**. - -Updating module states is possible, but must follow the rules below: - -- **A static regular attribute** (e.g., of primitive type) **can be updated**. - Reads and updates can be freely interleaved, and as expected, any reads - will always see the values of the latest updates. Because these attributes - are static, we will also burn the values in, so the generated code will not - have any instructions to actually "get" or "set" such attributes. - -- **A dynamic regular attribute** (e.g., of Tensor type) **cannot be updated**. - To do so, it must be registered as a buffer during module initialization. - -- **A buffer can be updated**, where the updating can be in-place (e.g., - ``self.buffer[:] = ...``) or not (e.g., ``self.buffer = ...``). - -- **A parameter cannot be updated**. Typically parameters are updated only - during training, not during inference. We recommend exporting with - :func:`torch.no_grad` to avoid parameter updates at export time. - -Effects of functionalization -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Any dynamic module state that is read and/or updated is "lifted" -(respectively) as an input and/or output of the generated code. - -The exported program stores, along with the generated code, the initial -values of parameters and buffers and the constant values of other Tensor -attributes. diff --git a/docs/source/export.rst b/docs/source/export.rst deleted file mode 100644 index 4ec5e16c360c84..00000000000000 --- a/docs/source/export.rst +++ /dev/null @@ -1,863 +0,0 @@ -.. _torch.export: - -torch.export -===================== - -.. warning:: - This feature is a prototype under active development and there WILL BE - BREAKING CHANGES in the future. - - -Overview --------- - -:func:`torch.export.export` takes a :class:`torch.nn.Module` and produces a traced graph -representing only the Tensor computation of the function in an Ahead-of-Time -(AOT) fashion, which can subsequently be executed with different outputs or -serialized. - -:: - - import torch - from torch.export import export - - class Mod(torch.nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - a = torch.sin(x) - b = torch.cos(y) - return a + b - - example_args = (torch.randn(10, 10), torch.randn(10, 10)) - - exported_program: torch.export.ExportedProgram = export( - Mod(), args=example_args - ) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): - # code: a = torch.sin(x) - sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) - - # code: b = torch.cos(y) - cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) - - # code: return a + b - add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos) - return (add,) - - Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='y'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='add'), - target=None - ) - ] - ) - Range constraints: {} - -``torch.export`` produces a clean intermediate representation (IR) with the -following invariants. More specifications about the IR can be found -:ref:`here `. - -* **Soundness**: It is guaranteed to be a sound representation of the original - program, and maintains the same calling conventions of the original program. - -* **Normalized**: There are no Python semantics within the graph. Submodules - from the original programs are inlined to form one fully flattened - computational graph. - -* **Graph properties**: The graph is purely functional, meaning it does not - contain operations with side effects such as mutations or aliasing. It does - not mutate any intermediate values, parameters, or buffers. - -* **Metadata**: The graph contains metadata captured during tracing, such as a - stacktrace from user's code. - -Under the hood, ``torch.export`` leverages the following latest technologies: - -* **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython feature - called the Frame Evaluation API to safely trace PyTorch graphs. This - provides a massively improved graph capturing experience, with much fewer - rewrites needed in order to fully trace the PyTorch code. - -* **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph - is decomposed/lowered to the ATen operator set. - -* **Torch FX (torch.fx)** is the underlying representation of the graph, - allowing flexible Python-based transformations. - - -Existing frameworks -^^^^^^^^^^^^^^^^^^^ - -:func:`torch.compile` also utilizes the same PT2 stack as ``torch.export``, but -is slightly different: - -* **JIT vs. AOT**: :func:`torch.compile` is a JIT compiler whereas - which is not intended to be used to produce compiled artifacts outside of - deployment. - -* **Partial vs. Full Graph Capture**: When :func:`torch.compile` runs into an - untraceable part of a model, it will "graph break" and fall back to running - the program in the eager Python runtime. In comparison, ``torch.export`` aims - to get a full graph representation of a PyTorch model, so it will error out - when something untraceable is reached. Since ``torch.export`` produces a full - graph disjoint from any Python features or runtime, this graph can then be - saved, loaded, and run in different environments and languages. - -* **Usability tradeoff**: Since :func:`torch.compile` is able to fallback to the - Python runtime whenever it reaches something untraceable, it is a lot more - flexible. ``torch.export`` will instead require users to provide more - information or rewrite their code to make it traceable. - -Compared to :func:`torch.fx.symbolic_trace`, ``torch.export`` traces using -TorchDynamo which operates at the Python bytecode level, giving it the ability -to trace arbitrary Python constructs not limited by what Python operator -overloading supports. Additionally, ``torch.export`` keeps fine-grained track of -tensor metadata, so that conditionals on things like tensor shapes do not -fail tracing. In general, ``torch.export`` is expected to work on more user -programs, and produce lower-level graphs (at the ``torch.ops.aten`` operator -level). Note that users can still use :func:`torch.fx.symbolic_trace` as a -preprocessing step before ``torch.export``. - -Compared to :func:`torch.jit.script`, ``torch.export`` does not capture Python -control flow or data structures, but it supports more Python language features -than TorchScript (as it is easier to have comprehensive coverage over Python -bytecodes). The resulting graphs are simpler and only have straight line control -flow (except for explicit control flow operators). - -Compared to :func:`torch.jit.trace`, ``torch.export`` is sound: it is able to -trace code that performs integer computation on sizes and records all of the -side-conditions necessary to show that a particular trace is valid for other -inputs. - - -Exporting a PyTorch Model -------------------------- - -An Example -^^^^^^^^^^ - -The main entrypoint is through :func:`torch.export.export`, which takes a -callable (:class:`torch.nn.Module`, function, or method) and sample inputs, and -captures the computation graph into an :class:`torch.export.ExportedProgram`. An -example: - -:: - - import torch - from torch.export import export - - # Simple module for demonstration - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv = torch.nn.Conv2d( - in_channels=3, out_channels=16, kernel_size=3, padding=1 - ) - self.relu = torch.nn.ReLU() - self.maxpool = torch.nn.MaxPool2d(kernel_size=3) - - def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor: - a = self.conv(x) - a.add_(constant) - return self.maxpool(self.relu(a)) - - example_args = (torch.randn(1, 3, 256, 256),) - example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} - - exported_program: torch.export.ExportedProgram = export( - M(), args=example_args, kwargs=example_kwargs - ) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): - # code: a = self.conv(x) - conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) - - # code: a.add_(constant) - add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) - - # code: return self.maxpool(self.relu(a)) - relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) - max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) - return (max_pool2d,) - - Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_weight'), - target='conv.weight', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_bias'), - target='conv.bias', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='constant'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='max_pool2d'), - target=None - ) - ] - ) - Range constraints: {} - -Inspecting the ``ExportedProgram``, we can note the following: - -* The :class:`torch.fx.Graph` contains the computation graph of the original - program, along with records of the original code for easy debugging. - -* The graph contains only ``torch.ops.aten`` operators found `here `__ - and custom operators, and is fully functional, without any inplace operators - such as ``torch.add_``. - -* The parameters (weight and bias to conv) are lifted as inputs to the graph, - resulting in no ``get_attr`` nodes in the graph, which previously existed in - the result of :func:`torch.fx.symbolic_trace`. - -* The :class:`torch.export.ExportGraphSignature` models the input and output - signature, along with specifying which inputs are parameters. - -* The resulting shape and dtype of tensors produced by each node in the graph is - noted. For example, the ``convolution`` node will result in a tensor of dtype - ``torch.float32`` and shape (1, 16, 256, 256). - - -.. _Non-Strict Export: - -Non-Strict Export -^^^^^^^^^^^^^^^^^ - -In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. - -In *non-strict mode*, we trace through the program using the Python interpreter. -Your code will execute exactly as it would in eager mode; the only difference is -that all Tensor objects will be replaced by ProxyTensors, which will record all -their operations into a graph. - -In *strict* mode, which is currently the default, we first trace through the -program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not -actually execute your Python code. Instead, it symbolically analyzes it and -builds a graph based on the results. This analysis allows torch.export to -provide stronger guarantees about safety, but not all Python code is supported. - -An example of a case where one might want to use non-strict mode is if you run -into a unsupported TorchDynamo feature that might not be easily solved, and you -know the python code is not exactly needed for computation. For example: - -:: - - import contextlib - import torch - - class ContextManager(): - def __init__(self): - self.count = 0 - def __enter__(self): - self.count += 1 - def __exit__(self, exc_type, exc_value, traceback): - self.count -= 1 - - class M(torch.nn.Module): - def forward(self, x): - with ContextManager(): - return x.sin() + x.cos() - - export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully - export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager - -In this example, the first call using non-strict mode (through the -``strict=False`` flag) traces successfully whereas the second call using strict -mode (default) results with a failure, where TorchDynamo is unable to support -context managers. One option is to rewrite the code (see :ref:`Limitations of torch.export `), but seeing as the context manager does not affect the tensor -computations in the model, we can go with the non-strict mode's result. - - -.. _Training Export: - -Export for Training and Inference -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In PyTorch 2.5, we introduced a new API called :func:`export_for_training`. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. - -In this API, we produce the most generic IR that contains all ATen operators -(including both functional and non-functional) which can be used to train in -eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization -and will soon be the default IR of torch.export.export. To read further about -the motivation behind this change, please refer to -https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 - -When this API is combined with :func:`run_decompositions()`, you should be able to get inference IR with -any desired decomposition behavior. - -To show some examples: - -:: - - class ConvBatchnorm(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv = torch.nn.Conv2d(1, 3, 1, 1) - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return (x,) - - mod = ConvBatchnorm() - inp = torch.randn(1, 1, 3, 3) - - ep_for_training = torch.export.export_for_training(mod, (inp,)) - print(ep_for_training) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) - batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) - return (batch_norm,) - - -From the above output, you can see that :func:`export_for_training` produces pretty much the same ExportedProgram -as :func:`export` except for the operators in the graph. You can see that we captured batch_norm in the most general -form. This op is non-functional and will be lowered to different ops when running inference. - -You can also go from this IR to an inference IR via :func:`run_decompositions` with arbitrary customizations. - -:: - - # Lower to core aten inference IR, but keep conv2d - decomp_table = torch.export.default_decompositions() - del decomp_table[torch.ops.aten.conv2d.default] - ep_for_inference = ep_for_training.run_decompositions(decomp_table) - - print(ep_for_inference) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] - return (getitem_3, getitem_4, add, getitem) - -Here you can see that we kept ``conv2d`` op in the IR while decomposing the rest. Now the IR is a functional IR -containing core aten operators except for ``conv2d``. - -You can do even more customization by directly registering your chosen decomposition behaviors. - -You can do even more customizations by directly registering custom decomp behaviour - -:: - - # Lower to core aten inference IR, but customize conv2d - decomp_table = torch.export.default_decompositions() - - def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): - return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) - - decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function - ep_for_inference = ep_for_training.run_decompositions(decomp_table) - - print(ep_for_inference) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) - mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; - return (getitem_3, getitem_4, add, getitem) - - -Expressing Dynamism -^^^^^^^^^^^^^^^^^^^ - -By default ``torch.export`` will trace the program assuming all input shapes are -**static**, and specializing the exported program to those dimensions. However, -some dimensions, such as a batch dimension, can be dynamic and vary from run to -run. Such dimensions must be specified by using the -:func:`torch.export.Dim` API to create them and by passing them into -:func:`torch.export.export` through the ``dynamic_shapes`` argument. An example: - -:: - - import torch - from torch.export import Dim, export - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - self.branch1 = torch.nn.Sequential( - torch.nn.Linear(64, 32), torch.nn.ReLU() - ) - self.branch2 = torch.nn.Sequential( - torch.nn.Linear(128, 64), torch.nn.ReLU() - ) - self.buffer = torch.ones(32) - - def forward(self, x1, x2): - out1 = self.branch1(x1) - out2 = self.branch2(x2) - return (out1 + self.buffer, out2) - - example_args = (torch.randn(32, 64), torch.randn(32, 128)) - - # Create a dynamic batch size - batch = Dim("batch") - # Specify that the first dimension of each input is that batch size - dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} - - exported_program: torch.export.ExportedProgram = export( - M(), args=example_args, dynamic_shapes=dynamic_shapes - ) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): - - # code: out1 = self.branch1(x1) - linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) - relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) - - # code: out2 = self.branch2(x2) - linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) - relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) - - # code: return (out1 + self.buffer, out2) - add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) - return (add, relu_1) - - Range constraints: {s0: VR[0, int_oo]} - -Some additional things to note: - -* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we specified the first - dimension of each input to be dynamic. Looking at the inputs ``x1`` and - ``x2``, they have a symbolic shape of (s0, 64) and (s0, 128), instead of - the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. - ``s0`` is a symbol representing that this dimension can be a range - of values. - -* ``exported_program.range_constraints`` describes the ranges of each symbol - appearing in the graph. In this case, we see that ``s0`` has the range - [0, int_oo]. For technical reasons that are difficult to explain here, they are - assumed to be not 0 or 1. This is not a bug, and does not necessarily mean - that the exported program will not work for dimensions 0 or 1. See - `The 0/1 Specialization Problem `_ - for an in-depth discussion of this topic. - - -We can also specify more expressive relationships between input shapes, such as -where a pair of shapes might differ by one, a shape might be double of -another, or a shape is even. An example: - -:: - - class M(torch.nn.Module): - def forward(self, x, y): - return x + y[1:] - - x, y = torch.randn(5), torch.randn(6) - dimx = torch.export.Dim("dimx", min=3, max=6) - dimy = dimx + 1 - - exported_program = torch.export.export( - M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), - ) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): - # code: return x + y[1:] - slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) - add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) - return (add,) - - Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} - -Some things to note: - -* By specifying ``{0: dimx}`` for the first input, we see that the resulting - shape of the first input is now dynamic, being ``[s0]``. And now by specifying - ``{0: dimy}`` for the second input, we see that the resulting shape of the - second input is also dynamic. However, because we expressed ``dimy = dimx + 1``, - instead of ``y``'s shape containing a new symbol, we see that it is - now being represented with the same symbol used in ``x``, ``s0``. We can - see that relationship of ``dimy = dimx + 1`` is being shown through ``s0 + 1``. - -* Looking at the range constraints, we see that ``s0`` has the range [3, 6], - which is specified initially, and we can see that ``s0 + 1`` has the solved - range of [4, 7]. - - -Serialization -^^^^^^^^^^^^^ - -To save the ``ExportedProgram``, users can use the :func:`torch.export.save` and -:func:`torch.export.load` APIs. A convention is to save the ``ExportedProgram`` -using a ``.pt2`` file extension. - -An example: - -:: - - import torch - import io - - class MyModule(torch.nn.Module): - def forward(self, x): - return x + 10 - - exported_program = torch.export.export(MyModule(), torch.randn(5)) - - torch.export.save(exported_program, 'exported_program.pt2') - saved_exported_program = torch.export.load('exported_program.pt2') - - -Specializations -^^^^^^^^^^^^^^^ - -A key concept in understanding the behavior of ``torch.export`` is the -difference between *static* and *dynamic* values. - -A *dynamic* value is one that can change from run to run. These behave like -normal arguments to a Python function—you can pass different values for an -argument and expect your function to do the right thing. Tensor *data* is -treated as dynamic. - - -A *static* value is a value that is fixed at export time and cannot change -between executions of the exported program. When the value is encountered during -tracing, the exporter will treat it as a constant and hard-code it into the -graph. - -When an operation is performed (e.g. ``x + y``) and all inputs are static, then -the output of the operation will be directly hard-coded into the graph, and the -operation won’t show up (i.e. it will get constant-folded). - -When a value has been hard-coded into the graph, we say that the graph has been -*specialized* to that value. - -The following values are static: - -Input Tensor Shapes -~~~~~~~~~~~~~~~~~~~ - -By default, ``torch.export`` will trace the program specializing on the input -tensors' shapes, unless a dimension is specified as dynamic via the -``dynamic_shapes`` argument to ``torch.export``. This means that if there exists -shape-dependent control flow, ``torch.export`` will specialize on the branch -that is being taken with the given sample inputs. For example: - -:: - - import torch - from torch.export import export - - class Mod(torch.nn.Module): - def forward(self, x): - if x.shape[0] > 5: - return x + 1 - else: - return x - 1 - - example_inputs = (torch.rand(10, 2),) - exported_program = export(Mod(), example_inputs) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 2]"): - # code: return x + 1 - add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) - return (add,) - -The conditional of (``x.shape[0] > 5``) does not appear in the -``ExportedProgram`` because the example inputs have the static -shape of (10, 2). Since ``torch.export`` specializes on the inputs' static -shapes, the else branch (``x - 1``) will never be reached. To preserve the dynamic -branching behavior based on the shape of a tensor in the traced graph, -:func:`torch.export.Dim` will need to be used to specify the dimension -of the input tensor (``x.shape[0]``) to be dynamic, and the source code will -need to be :ref:`rewritten `. - -Note that tensors that are part of the module state (e.g. parameters and -buffers) always have static shapes. - -Python Primitives -~~~~~~~~~~~~~~~~~ - -``torch.export`` also specializes on Python primtivies, -such as ``int``, ``float``, ``bool``, and ``str``. However they do have dynamic -variants such as ``SymInt``, ``SymFloat``, and ``SymBool``. - -For example: - -:: - - import torch - from torch.export import export - - class Mod(torch.nn.Module): - def forward(self, x: torch.Tensor, const: int, times: int): - for i in range(times): - x = x + const - return x - - example_inputs = (torch.rand(2, 2), 1, 3) - exported_program = export(Mod(), example_inputs) - print(exported_program) - -.. code-block:: - - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[2, 2]", const, times): - # code: x = x + const - add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) - add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) - add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) - return (add_2,) - -Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations -are all computed with the hard-coded constant ``1``, rather than ``const``. If -a user passes a different value for ``const`` at runtime, like 2, than the one used -during export time, 1, this will result in an error. -Additionally, the ``times`` iterator used in the ``for`` loop is also "inlined" -in the graph through the 3 repeated ``torch.ops.aten.add.Tensor`` calls, and the -input ``times`` is never used. - -Python Containers -~~~~~~~~~~~~~~~~~ - -Python containers (``List``, ``Dict``, ``NamedTuple``, etc.) are considered to -have static structure. - - -.. _Limitations of torch.export: - -Limitations of torch.export ---------------------------- - -Graph Breaks -^^^^^^^^^^^^ - -As ``torch.export`` is a one-shot process for capturing a computation graph from -a PyTorch program, it might ultimately run into untraceable parts of programs as -it is nearly impossible to support tracing all PyTorch and Python features. In -the case of ``torch.compile``, an unsupported operation will cause a "graph -break" and the unsupported operation will be run with default Python evaluation. -In contrast, ``torch.export`` will require users to provide additional -information or rewrite parts of their code to make it traceable. As the -tracing is based on TorchDynamo, which evaluates at the Python -bytecode level, there will be significantly fewer rewrites required compared to -previous tracing frameworks. - -When a graph break is encountered, :ref:`ExportDB ` is a great -resource for learning about the kinds of programs that are supported and -unsupported, along with ways to rewrite programs to make them traceable. - -An option to get past dealing with this graph breaks is by using -:ref:`non-strict export ` - -.. _Data/Shape-Dependent Control Flow: - -Data/Shape-Dependent Control Flow -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Graph breaks can also be encountered on data-dependent control flow (``if -x.shape[0] > 2``) when shapes are not being specialized, as a tracing compiler cannot -possibly deal with without generating code for a combinatorially exploding -number of paths. In such cases, users will need to rewrite their code using -special control flow operators. Currently, we support :ref:`torch.cond ` -to express if-else like control flow (more coming soon!). - -Missing Fake/Meta/Abstract Kernels for Operators -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is -required for all operators. This is used to reason about the input/output shapes -for this operator. - -Please see :func:`torch.library.register_fake` for more details. - -In the unfortunate case where your model uses an ATen operator that is does not -have a FakeTensor kernel implementation yet, please file an issue. - - -Read More ---------- - -.. toctree:: - :caption: Additional Links for Export Users - :maxdepth: 1 - - export.programming_model - export.ir_spec - draft_export - torch.compiler_transformations - torch.compiler_ir - generated/exportdb/index - cond - -.. toctree:: - :caption: Deep Dive for PyTorch Developers - :maxdepth: 1 - - torch.compiler_dynamo_overview - torch.compiler_dynamo_deepdive - torch.compiler_dynamic_shapes - torch.compiler_fake_tensor - - -API Reference -------------- - -.. automodule:: torch.export -.. autofunction:: export -.. autofunction:: save -.. autofunction:: load -.. autofunction:: draft_export -.. autofunction:: register_dataclass -.. autoclass:: torch.export.dynamic_shapes.Dim -.. autoclass:: torch.export.dynamic_shapes.ShapesCollection - - .. automethod:: dynamic_shapes - -.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs - - .. automethod:: add - .. automethod:: dynamic_shapes - .. automethod:: verify - -.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes -.. autoclass:: ExportedProgram - - .. attribute:: graph - .. attribute:: graph_signature - .. attribute:: state_dict - .. attribute:: constants - .. attribute:: range_constraints - .. attribute:: module_call_graph - .. attribute:: example_inputs - .. automethod:: module - .. automethod:: run_decompositions - -.. autoclass:: ExportGraphSignature -.. autoclass:: ModuleCallSignature -.. autoclass:: ModuleCallEntry -.. automodule:: torch.export.decomp_utils -.. autoclass:: CustomDecompTable - - .. automethod:: copy - .. automethod:: items - .. automethod:: keys - .. automethod:: materialize - .. automethod:: pop - .. automethod:: update -.. autofunction:: torch.export.exported_program.default_decompositions - -.. automodule:: torch.export.exported_program -.. automodule:: torch.export.graph_signature -.. autoclass:: ExportGraphSignature - - .. automethod:: replace_all_uses - .. automethod:: get_replace_hook - -.. autoclass:: ExportBackwardSignature -.. autoclass:: InputKind -.. autoclass:: InputSpec -.. autoclass:: OutputKind -.. autoclass:: OutputSpec -.. autoclass:: SymIntArgument -.. autoclass:: SymBoolArgument -.. autoclass:: SymFloatArgument - -.. autoclass:: CustomObjArgument - -.. py:module:: torch.export.dynamic_shapes -.. py:module:: torch.export.custom_ops - -.. automodule:: torch.export.unflatten - :members: - -.. automodule:: torch.export.custom_obj - -.. automodule:: torch.export.experimental -.. automodule:: torch.export.passes -.. autofunction:: torch.export.passes.move_to_device_pass -.. automodule:: torch.export.pt2_archive -.. automodule:: torch.export.pt2_archive.constants diff --git a/docs/source/fft.md b/docs/source/fft.md new file mode 100644 index 00000000000000..b345dbbdda4633 --- /dev/null +++ b/docs/source/fft.md @@ -0,0 +1,56 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch.fft + +Discrete Fourier transforms and related functions. + +```{eval-rst} +.. automodule:: torch.fft +``` + +```{eval-rst} +.. currentmodule:: torch.fft +``` + +## Fast Fourier Transforms + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + fft + ifft + fft2 + ifft2 + fftn + ifftn + rfft + irfft + rfft2 + irfft2 + rfftn + irfftn + hfft + ihfft + hfft2 + ihfft2 + hfftn + ihfftn +``` + +## Helper Functions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + fftfreq + rfftfreq + fftshift + ifftshift +``` diff --git a/docs/source/fft.rst b/docs/source/fft.rst deleted file mode 100644 index 5406b6610a602b..00000000000000 --- a/docs/source/fft.rst +++ /dev/null @@ -1,48 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.fft -========= - -Discrete Fourier transforms and related functions. - -.. automodule:: torch.fft -.. currentmodule:: torch.fft - -Fast Fourier Transforms ------------------------ - -.. autosummary:: - :toctree: generated - :nosignatures: - - fft - ifft - fft2 - ifft2 - fftn - ifftn - rfft - irfft - rfft2 - irfft2 - rfftn - irfftn - hfft - ihfft - hfft2 - ihfft2 - hfftn - ihfftn - -Helper Functions ----------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - fftfreq - rfftfreq - fftshift - ifftshift diff --git a/docs/source/fsdp.md b/docs/source/fsdp.md new file mode 100644 index 00000000000000..6163e56bbe640e --- /dev/null +++ b/docs/source/fsdp.md @@ -0,0 +1,75 @@ +# FullyShardedDataParallel + +```{eval-rst} +.. automodule:: torch.distributed.fsdp +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.BackwardPrefetch + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.ShardingStrategy + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.MixedPrecision + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.CPUOffload + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.StateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.FullStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig + :members: +``` + +```{eval-rst} +.. autoclass:: torch.distributed.fsdp.StateDictSettings + :members: +``` diff --git a/docs/source/fsdp.rst b/docs/source/fsdp.rst deleted file mode 100644 index 41883e3c6ed230..00000000000000 --- a/docs/source/fsdp.rst +++ /dev/null @@ -1,46 +0,0 @@ -FullyShardedDataParallel -======================== - -.. automodule:: torch.distributed.fsdp - -.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel - :members: - -.. autoclass:: torch.distributed.fsdp.BackwardPrefetch - :members: - -.. autoclass:: torch.distributed.fsdp.ShardingStrategy - :members: - -.. autoclass:: torch.distributed.fsdp.MixedPrecision - :members: - -.. autoclass:: torch.distributed.fsdp.CPUOffload - :members: - -.. autoclass:: torch.distributed.fsdp.StateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.FullStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig - :members: - -.. autoclass:: torch.distributed.fsdp.StateDictSettings - :members: diff --git a/docs/source/func.api.md b/docs/source/func.api.md new file mode 100644 index 00000000000000..111eb4dc743e4d --- /dev/null +++ b/docs/source/func.api.md @@ -0,0 +1,88 @@ +# torch.func API Reference + +```{eval-rst} +.. currentmodule:: torch.func +``` + +```{eval-rst} +.. automodule:: torch.func +``` + +## Function Transforms +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + vmap + grad + grad_and_value + vjp + jvp + linearize + jacrev + jacfwd + hessian + functionalize +``` + +## Utilities for working with torch.nn.Modules + +In general, you can transform over a function that calls a ``torch.nn.Module``. +For example, the following is an example of computing a jacobian of a function +that takes three values and returns three values: + +```python +model = torch.nn.Linear(3, 3) + +def f(x): + return model(x) + +x = torch.randn(3) +jacobian = jacrev(f)(x) +assert jacobian.shape == (3, 3) +``` + +However, if you want to do something like compute a jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what {func}`functional_call` is for: it accepts an nn.Module, the transformed `parameters`, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters. + +Here's how we would compute the Jacobian over the parameters + +```python +model = torch.nn.Linear(3, 3) + +def f(params, x): + return torch.func.functional_call(model, params, x) + +x = torch.randn(3) +jacobian = jacrev(f)(dict(model.named_parameters()), x) +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + functional_call + stack_module_state + replace_all_batch_norm_modules_ +``` + +If you're looking for information on fixing Batch Norm modules, please follow the +guidance here + +```{eval-rst} +.. toctree:: + :maxdepth: 1 + + func.batch_norm +``` + +## Debug utilities + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + debug_unwrap +``` diff --git a/docs/source/func.api.rst b/docs/source/func.api.rst deleted file mode 100644 index 362954f731af45..00000000000000 --- a/docs/source/func.api.rst +++ /dev/null @@ -1,87 +0,0 @@ -torch.func API Reference -======================== - -.. currentmodule:: torch.func - -.. automodule:: torch.func - -Function Transforms -------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - vmap - grad - grad_and_value - vjp - jvp - linearize - jacrev - jacfwd - hessian - functionalize - -Utilities for working with torch.nn.Modules -------------------------------------------- - -In general, you can transform over a function that calls a ``torch.nn.Module``. -For example, the following is an example of computing a jacobian of a function -that takes three values and returns three values: - -.. code-block:: python - - model = torch.nn.Linear(3, 3) - - def f(x): - return model(x) - - x = torch.randn(3) - jacobian = jacrev(f)(x) - assert jacobian.shape == (3, 3) - -However, if you want to do something like compute a jacobian over the parameters -of the model, then there needs to be a way to construct a function where the -parameters are the inputs to the function. -That's what :func:`functional_call` is for: -it accepts an nn.Module, the transformed ``parameters``, and the inputs to the -Module's forward pass. It returns the value of running the Module's forward pass -with the replaced parameters. - -Here's how we would compute the Jacobian over the parameters - -.. code-block:: python - - model = torch.nn.Linear(3, 3) - - def f(params, x): - return torch.func.functional_call(model, params, x) - - x = torch.randn(3) - jacobian = jacrev(f)(dict(model.named_parameters()), x) - - -.. autosummary:: - :toctree: generated - :nosignatures: - - functional_call - stack_module_state - replace_all_batch_norm_modules_ - -If you're looking for information on fixing Batch Norm modules, please follow the -guidance here - -.. toctree:: - :maxdepth: 1 - - func.batch_norm - -Debug utilities ---------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - debug_unwrap diff --git a/docs/source/func.batch_norm.md b/docs/source/func.batch_norm.md new file mode 100644 index 00000000000000..86907f11a7ca02 --- /dev/null +++ b/docs/source/func.batch_norm.md @@ -0,0 +1,75 @@ +# Patching Batch Norm + +## What's happening? +Batch Norm requires in-place updates to running_mean and running_var of the same size as the input. +Functorch does not support inplace update to a regular tensor that takes in a batched tensor (i.e. +`regular.add_(batched)` is not allowed). So when vmapping over a batch of inputs to a single module, +we end up with this error + +## How to fix +One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this + +All of these options assume that you don't need running stats. If you're using a module this means +that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves +running batch norm with vmap in evaluation mode, please file an issue + +### Option 1: Change the BatchNorm +If you want to change for GroupNorm, anywhere that you have BatchNorm, replace it with: + +```python +BatchNorm2d(C, G, track_running_stats=False) +``` + +Here `C` is the same `C` as in the original BatchNorm. `G` is the number of groups to +break `C` into. As such, `C % G == 0` and as a fallback, you can set `C == G`, meaning +each channel will be treated separately. + +If you must use BatchNorm and you've built the module yourself, you can change the module to +not use running stats. In other words, anywhere that there's a BatchNorm module, set the +`track_running_stats` flag to be False + +```python +BatchNorm2d(64, track_running_stats=False) +``` + +### Option 2: torchvision parameter +Some torchvision models, like resnet and regnet, can take in a `norm_layer` parameter. These are +often defaulted to be BatchNorm2d if they've been defaulted. + +Instead you can set it to be GroupNorm. + +```python +import torchvision +from functools import partial +torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c)) +``` + +Here, once again, `c % g == 0` so as a fallback, set `g = c`. + +If you are attached to BatchNorm, be sure to use a version that doesn't use running stats + +```python +import torchvision +from functools import partial +torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False)) +``` + +### Option 3: functorch's patching +functorch has added some functionality to allow for quick, in-place patching of the module to not +use running stats. Changing the norm layer is more fragile, so we have not offered that. If you +have a net where you want the BatchNorm to not use running stats, you can run +`replace_all_batch_norm_modules_` to update the module in-place to not use running stats + +```python +from torch.func import replace_all_batch_norm_modules_ +replace_all_batch_norm_modules_(net) +``` + +### Option 4: eval mode +When run under eval mode, the running_mean and running_var will not be updated. Therefore, vmap can support this mode + +```python +model.eval() +vmap(model)(x) +model.train() +``` diff --git a/docs/source/func.batch_norm.rst b/docs/source/func.batch_norm.rst deleted file mode 100644 index 1843c4c6db566e..00000000000000 --- a/docs/source/func.batch_norm.rst +++ /dev/null @@ -1,83 +0,0 @@ -Patching Batch Norm -=================== - -What's happening? ------------------ -Batch Norm requires in-place updates to running_mean and running_var of the same size as the input. -Functorch does not support inplace update to a regular tensor that takes in a batched tensor (i.e. -``regular.add_(batched)`` is not allowed). So when vmapping over a batch of inputs to a single module, -we end up with this error - -How to fix ----------- -One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this - -All of these options assume that you don't need running stats. If you're using a module this means -that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves -running batch norm with vmap in evaluation mode, please file an issue - -Option 1: Change the BatchNorm -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you want to change for GroupNorm, anywhere that you have BatchNorm, replace it with: - -.. code-block:: python - - BatchNorm2d(C, G, track_running_stats=False) - -Here ``C`` is the same ``C`` as in the original BatchNorm. ``G`` is the number of groups to -break ``C`` into. As such, ``C % G == 0`` and as a fallback, you can set ``C == G``, meaning -each channel will be treated separately. - -If you must use BatchNorm and you've built the module yourself, you can change the module to -not use running stats. In other words, anywhere that there's a BatchNorm module, set the -``track_running_stats`` flag to be False - -.. code-block:: python - - BatchNorm2d(64, track_running_stats=False) - - -Option 2: torchvision parameter -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter. These are -often defaulted to be BatchNorm2d if they've been defaulted. - -Instead you can set it to be GroupNorm. - -.. code-block:: python - - import torchvision - from functools import partial - torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c)) - -Here, once again, ``c % g == 0`` so as a fallback, set ``g = c``. - -If you are attached to BatchNorm, be sure to use a version that doesn't use running stats - -.. code-block:: python - - import torchvision - from functools import partial - torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False)) - -Option 3: functorch's patching -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -functorch has added some functionality to allow for quick, in-place patching of the module to not -use running stats. Changing the norm layer is more fragile, so we have not offered that. If you -have a net where you want the BatchNorm to not use running stats, you can run -``replace_all_batch_norm_modules_`` to update the module in-place to not use running stats - -.. code-block:: python - - from torch.func import replace_all_batch_norm_modules_ - replace_all_batch_norm_modules_(net) - -Option 4: eval mode -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When run under eval mode, the running_mean and running_var will not be updated. Therefore, vmap can support this mode - -.. code-block:: python - - model.eval() - vmap(model)(x) - model.train() diff --git a/docs/source/func.md b/docs/source/func.md new file mode 100644 index 00000000000000..d1b81a00fa5332 --- /dev/null +++ b/docs/source/func.md @@ -0,0 +1,56 @@ +# torch.func + +```{eval-rst} +.. currentmodule:: torch.func +``` + +torch.func, previously known as "functorch", is +[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch. + +```{note} +This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta). +What this means is that the features generally work (unless otherwise documented) +and we (the PyTorch team) are committed to bringing this library forward. However, the APIs +may change under user feedback and we don't have full coverage over PyTorch operations. + +If you have suggestions on the API or use-cases you'd like to be covered, please +open a GitHub issue or reach out. We'd love to hear about how you're using the library. +``` + +## What are composable function transforms? + +- A "function transform" is a higher-order function that accepts a numerical function + and returns a new function that computes a different quantity. + +- {mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that + computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` + returns a function that computes `f` over batches of inputs), and others. + +- These function transforms can compose with each other arbitrarily. For example, + composing `vmap(grad(f))` computes a quantity called per-sample-gradients that + stock PyTorch cannot efficiently compute today. + +## Why composable function transforms? + +There are a number of use cases that are tricky to do in PyTorch today: + +- computing per-sample-gradients (or other per-sample quantities) +- running ensembles of models on a single machine +- efficiently batching together tasks in the inner-loop of MAML +- efficiently computing Jacobians and Hessians +- efficiently computing batched Jacobians and Hessians + +Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express the above without designing a separate subsystem for each. +This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax). + +## Read More + +```{eval-rst} +.. toctree:: + :maxdepth: 2 + + func.whirlwind_tour + func.api + func.ux_limitations + func.migrating +``` diff --git a/docs/source/func.migrating.md b/docs/source/func.migrating.md new file mode 100644 index 00000000000000..b87a590f804a18 --- /dev/null +++ b/docs/source/func.migrating.md @@ -0,0 +1,201 @@ +# Migrating from functorch to torch.func + +torch.func, previously known as "functorch", is +[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch. + +functorch started as an out-of-tree library over at +the [pytorch/functorch](https://github.com/pytorch/functorch) repository. +Our goal has always been to upstream functorch directly into PyTorch and provide +it as a core PyTorch library. + +As the final step of the upstream, we've decided to migrate from being a top level package +(`functorch`) to being a part of PyTorch to reflect how the function transforms are +integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating +`import functorch` and ask that users migrate to the newest APIs, which we +will maintain going forward. `import functorch` will be kept around to maintain +backwards compatibility for a couple of releases. + +## function transforms + +The following APIs are a drop-in replacement for the following +[functorch APIs](https://pytorch.org/functorch/1.13/functorch.html). +They are fully backwards compatible. + +| functorch API | PyTorch API (as of PyTorch 2.0) | +| ----------------------------------- | ---------------------------------------------- | +| functorch.vmap | {func}`torch.vmap` or {func}`torch.func.vmap` | +| functorch.grad | {func}`torch.func.grad` | +| functorch.vjp | {func}`torch.func.vjp` | +| functorch.jvp | {func}`torch.func.jvp` | +| functorch.jacrev | {func}`torch.func.jacrev` | +| functorch.jacfwd | {func}`torch.func.jacfwd` | +| functorch.hessian | {func}`torch.func.hessian` | +| functorch.functionalize | {func}`torch.func.functionalize` | + +Furthermore, if you are using torch.autograd.functional APIs, please try out +the {mod}`torch.func` equivalents instead. {mod}`torch.func` function +transforms are more composable and more performant in many cases. + +| torch.autograd.functional API | torch.func API (as of PyTorch 2.0) | +| ------------------------------------------- | ---------------------------------------------- | +| {func}`torch.autograd.functional.vjp` | {func}`torch.func.grad` or {func}`torch.func.vjp` | +| {func}`torch.autograd.functional.jvp` | {func}`torch.func.jvp` | +| {func}`torch.autograd.functional.jacobian` | {func}`torch.func.jacrev` or {func}`torch.func.jacfwd` | +| {func}`torch.autograd.functional.hessian` | {func}`torch.func.hessian` | + +## NN module utilities + +We've changed the APIs to apply function transforms over NN modules to make them +fit better into the PyTorch design philosophy. The new API is different, so +please read this section carefully. + +### functorch.make_functional + +{func}`torch.func.functional_call` is the replacement for +[functorch.make_functional](https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional) +and +[functorch.make_functional_with_buffers](https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers). +However, it is not a drop-in replacement. + +If you're in a hurry, you can use +[helper functions in this gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf) +that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers. +We recommend using {func}`torch.func.functional_call` directly because it is a more explicit +and flexible API. + +Concretely, functorch.make_functional returns a functional module and parameters. +The functional module accepts parameters and inputs to the model as arguments. +{func}`torch.func.functional_call` allows one to call the forward pass of an existing +module using new parameters and buffers and inputs. + +Here's an example of how to compute gradients of parameters of a model using functorch +vs {mod}`torch.func`: + +```python +# --------------- +# using functorch +# --------------- +import torch +import functorch +inputs = torch.randn(64, 3) +targets = torch.randn(64, 3) +model = torch.nn.Linear(3, 3) + +fmodel, params = functorch.make_functional(model) + +def compute_loss(params, inputs, targets): + prediction = fmodel(params, inputs) + return torch.nn.functional.mse_loss(prediction, targets) + +grads = functorch.grad(compute_loss)(params, inputs, targets) + +# ------------------------------------ +# using torch.func (as of PyTorch 2.0) +# ------------------------------------ +import torch +inputs = torch.randn(64, 3) +targets = torch.randn(64, 3) +model = torch.nn.Linear(3, 3) + +params = dict(model.named_parameters()) + +def compute_loss(params, inputs, targets): + prediction = torch.func.functional_call(model, params, (inputs,)) + return torch.nn.functional.mse_loss(prediction, targets) + +grads = torch.func.grad(compute_loss)(params, inputs, targets) +``` + +And here's an example of how to compute jacobians of model parameters: + +```python +# --------------- +# using functorch +# --------------- +import torch +import functorch +inputs = torch.randn(64, 3) +model = torch.nn.Linear(3, 3) + +fmodel, params = functorch.make_functional(model) +jacobians = functorch.jacrev(fmodel)(params, inputs) + +# ------------------------------------ +# using torch.func (as of PyTorch 2.0) +# ------------------------------------ +import torch +from torch.func import jacrev, functional_call +inputs = torch.randn(64, 3) +model = torch.nn.Linear(3, 3) + +params = dict(model.named_parameters()) +# jacrev computes jacobians of argnums=0 by default. +# We set it to 1 to compute jacobians of params +jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,)) +``` + +Note that it is important for memory consumption that you should only carry +around a single copy of your parameters. `model.named_parameters()` does not copy +the parameters. If in your model training you update the parameters of the model +in-place, then the `nn.Module` that is your model has the single copy of the +parameters and everything is OK. + +However, if you want to carry your parameters around in a dictionary and update +them out-of-place, then there are two copies of parameters: the one in the +dictionary and the one in the `model`. In this case, you should change +`model` to not hold memory by converting it to the meta device via +`model.to('meta')`. + +### functorch.combine_state_for_ensemble + +Please use {func}`torch.func.stack_module_state` instead of +[functorch.combine_state_for_ensemble](https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html) +{func}`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters, and +one of stacked buffers, that can then be used with {func}`torch.vmap` and {func}`torch.func.functional_call` +for ensembling. + +For example, here is an example of how to ensemble over a very simple model: + +```python +import torch +num_models = 5 +batch_size = 64 +in_features, out_features = 3, 3 +models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] +data = torch.randn(batch_size, 3) + +# --------------- +# using functorch +# --------------- +import functorch +fmodel, params, buffers = functorch.combine_state_for_ensemble(models) +output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data) +assert output.shape == (num_models, batch_size, out_features) + +# ------------------------------------ +# using torch.func (as of PyTorch 2.0) +# ------------------------------------ +import copy + +# Construct a version of the model with no memory by putting the Tensors on +# the meta device. +base_model = copy.deepcopy(models[0]) +base_model.to('meta') + +params, buffers = torch.func.stack_module_state(models) + +# It is possible to vmap directly over torch.func.functional_call, +# but wrapping it in a function makes it clearer what is going on. +def call_single_model(params, buffers, data): + return torch.func.functional_call(base_model, (params, buffers), (data,)) + +output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data) +assert output.shape == (num_models, batch_size, out_features) +``` + +## functorch.compile + +We are no longer supporting functorch.compile (also known as AOTAutograd) +as a frontend for compilation in PyTorch; we have integrated AOTAutograd +into PyTorch's compilation story. If you are a user, please use +{func}`torch.compile` instead. diff --git a/docs/source/func.migrating.rst b/docs/source/func.migrating.rst deleted file mode 100644 index c12596e4607cf7..00000000000000 --- a/docs/source/func.migrating.rst +++ /dev/null @@ -1,207 +0,0 @@ -Migrating from functorch to torch.func -====================================== - -torch.func, previously known as "functorch", is -`JAX-like `_ composable function transforms for PyTorch. - -functorch started as an out-of-tree library over at -the `pytorch/functorch `_ repository. -Our goal has always been to upstream functorch directly into PyTorch and provide -it as a core PyTorch library. - -As the final step of the upstream, we've decided to migrate from being a top level package -(``functorch``) to being a part of PyTorch to reflect how the function transforms are -integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating -``import functorch`` and ask that users migrate to the newest APIs, which we -will maintain going forward. ``import functorch`` will be kept around to maintain -backwards compatibility for a couple of releases. - -function transforms -------------------- - -The following APIs are a drop-in replacement for the following -`functorch APIs `_. -They are fully backwards compatible. - - -============================== ======================================= -functorch API PyTorch API (as of PyTorch 2.0) -============================== ======================================= -functorch.vmap :func:`torch.vmap` or :func:`torch.func.vmap` -functorch.grad :func:`torch.func.grad` -functorch.vjp :func:`torch.func.vjp` -functorch.jvp :func:`torch.func.jvp` -functorch.jacrev :func:`torch.func.jacrev` -functorch.jacfwd :func:`torch.func.jacfwd` -functorch.hessian :func:`torch.func.hessian` -functorch.functionalize :func:`torch.func.functionalize` -============================== ======================================= - -Furthermore, if you are using torch.autograd.functional APIs, please try out -the :mod:`torch.func` equivalents instead. :mod:`torch.func` function -transforms are more composable and more performant in many cases. - -=========================================== ======================================= -torch.autograd.functional API torch.func API (as of PyTorch 2.0) -=========================================== ======================================= -:func:`torch.autograd.functional.vjp` :func:`torch.func.grad` or :func:`torch.func.vjp` -:func:`torch.autograd.functional.jvp` :func:`torch.func.jvp` -:func:`torch.autograd.functional.jacobian` :func:`torch.func.jacrev` or :func:`torch.func.jacfwd` -:func:`torch.autograd.functional.hessian` :func:`torch.func.hessian` -=========================================== ======================================= - -NN module utilities -------------------- - -We've changed the APIs to apply function transforms over NN modules to make them -fit better into the PyTorch design philosophy. The new API is different, so -please read this section carefully. - -functorch.make_functional -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:func:`torch.func.functional_call` is the replacement for -`functorch.make_functional `_ -and -`functorch.make_functional_with_buffers `_. -However, it is not a drop-in replacement. - -If you're in a hurry, you can use -`helper functions in this gist `_ -that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers. -We recommend using :func:`torch.func.functional_call` directly because it is a more explicit -and flexible API. - -Concretely, functorch.make_functional returns a functional module and parameters. -The functional module accepts parameters and inputs to the model as arguments. -:func:`torch.func.functional_call` allows one to call the forward pass of an existing -module using new parameters and buffers and inputs. - -Here's an example of how to compute gradients of parameters of a model using functorch -vs :mod:`torch.func`:: - - # --------------- - # using functorch - # --------------- - import torch - import functorch - inputs = torch.randn(64, 3) - targets = torch.randn(64, 3) - model = torch.nn.Linear(3, 3) - - fmodel, params = functorch.make_functional(model) - - def compute_loss(params, inputs, targets): - prediction = fmodel(params, inputs) - return torch.nn.functional.mse_loss(prediction, targets) - - grads = functorch.grad(compute_loss)(params, inputs, targets) - - # ------------------------------------ - # using torch.func (as of PyTorch 2.0) - # ------------------------------------ - import torch - inputs = torch.randn(64, 3) - targets = torch.randn(64, 3) - model = torch.nn.Linear(3, 3) - - params = dict(model.named_parameters()) - - def compute_loss(params, inputs, targets): - prediction = torch.func.functional_call(model, params, (inputs,)) - return torch.nn.functional.mse_loss(prediction, targets) - - grads = torch.func.grad(compute_loss)(params, inputs, targets) - -And here's an example of how to compute jacobians of model parameters:: - - # --------------- - # using functorch - # --------------- - import torch - import functorch - inputs = torch.randn(64, 3) - model = torch.nn.Linear(3, 3) - - fmodel, params = functorch.make_functional(model) - jacobians = functorch.jacrev(fmodel)(params, inputs) - - # ------------------------------------ - # using torch.func (as of PyTorch 2.0) - # ------------------------------------ - import torch - from torch.func import jacrev, functional_call - inputs = torch.randn(64, 3) - model = torch.nn.Linear(3, 3) - - params = dict(model.named_parameters()) - # jacrev computes jacobians of argnums=0 by default. - # We set it to 1 to compute jacobians of params - jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,)) - -Note that it is important for memory consumption that you should only carry -around a single copy of your parameters. ``model.named_parameters()`` does not copy -the parameters. If in your model training you update the parameters of the model -in-place, then the ``nn.Module`` that is your model has the single copy of the -parameters and everything is OK. - -However, if you want to carry your parameters around in a dictionary and update -them out-of-place, then there are two copies of parameters: the one in the -dictionary and the one in the ``model``. In this case, you should change -``model`` to not hold memory by converting it to the meta device via -``model.to('meta')``. - -functorch.combine_state_for_ensemble -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Please use :func:`torch.func.stack_module_state` instead of -`functorch.combine_state_for_ensemble `_ -:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters, and -one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call` -for ensembling. - -For example, here is an example of how to ensemble over a very simple model:: - - import torch - num_models = 5 - batch_size = 64 - in_features, out_features = 3, 3 - models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] - data = torch.randn(batch_size, 3) - - # --------------- - # using functorch - # --------------- - import functorch - fmodel, params, buffers = functorch.combine_state_for_ensemble(models) - output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data) - assert output.shape == (num_models, batch_size, out_features) - - # ------------------------------------ - # using torch.func (as of PyTorch 2.0) - # ------------------------------------ - import copy - - # Construct a version of the model with no memory by putting the Tensors on - # the meta device. - base_model = copy.deepcopy(models[0]) - base_model.to('meta') - - params, buffers = torch.func.stack_module_state(models) - - # It is possible to vmap directly over torch.func.functional_call, - # but wrapping it in a function makes it clearer what is going on. - def call_single_model(params, buffers, data): - return torch.func.functional_call(base_model, (params, buffers), (data,)) - - output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data) - assert output.shape == (num_models, batch_size, out_features) - - -functorch.compile ------------------ - -We are no longer supporting functorch.compile (also known as AOTAutograd) -as a frontend for compilation in PyTorch; we have integrated AOTAutograd -into PyTorch's compilation story. If you are a user, please use -:func:`torch.compile` instead. diff --git a/docs/source/func.rst b/docs/source/func.rst deleted file mode 100644 index 4a14019c3d23ae..00000000000000 --- a/docs/source/func.rst +++ /dev/null @@ -1,55 +0,0 @@ -torch.func -========== - -.. currentmodule:: torch.func - -torch.func, previously known as "functorch", is -`JAX-like `_ composable function transforms for PyTorch. - -.. note:: - This library is currently in `beta `_. - What this means is that the features generally work (unless otherwise documented) - and we (the PyTorch team) are committed to bringing this library forward. However, the APIs - may change under user feedback and we don't have full coverage over PyTorch operations. - - If you have suggestions on the API or use-cases you'd like to be covered, please - open an GitHub issue or reach out. We'd love to hear about how you're using the library. - -What are composable function transforms? ----------------------------------------- - -- A "function transform" is a higher-order function that accepts a numerical function - and returns a new function that computes a different quantity. - -- :mod:`torch.func` has auto-differentiation transforms (``grad(f)`` returns a function that - computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)`` - returns a function that computes ``f`` over batches of inputs), and others. - -- These function transforms can compose with each other arbitrarily. For example, - composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that - stock PyTorch cannot efficiently compute today. - -Why composable function transforms? ------------------------------------ - -There are a number of use cases that are tricky to do in PyTorch today: - -- computing per-sample-gradients (or other per-sample quantities) -- running ensembles of models on a single machine -- efficiently batching together tasks in the inner-loop of MAML -- efficiently computing Jacobians and Hessians -- efficiently computing batched Jacobians and Hessians - -Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express the above without designing a separate subsystem for each. -This idea of composable function transforms comes from the `JAX framework `_. - -Read More ---------- - -.. toctree:: - :maxdepth: 2 - - func.whirlwind_tour - func.api - func.ux_limitations - func.migrating diff --git a/docs/source/func.ux_limitations.md b/docs/source/func.ux_limitations.md new file mode 100644 index 00000000000000..3997f806a5290d --- /dev/null +++ b/docs/source/func.ux_limitations.md @@ -0,0 +1,334 @@ +```{eval-rst} +.. currentmodule:: torch.func +``` + +(ux-limitations)= + +# UX Limitations + +torch.func, like [JAX](https://github.com/google/jax), has restrictions around +what can be transformed. In general, JAX’s limitations are that transforms +only work with pure functions: that is, functions where the output is completely +determined by the input and that do not involve side effects (like mutation). + +We have a similar guarantee: our transforms work well with pure functions. +However, we do support certain in-place operations. On one hand, writing code +compatible with function transforms may involve changing how you write PyTorch +code, on the other hand, you may find that our transforms let you express things +that were previously difficult to express in PyTorch. + +## General limitations + +All torch.func transforms share a limitation in that a function should not +assign to global variables. Instead, all outputs to a function must be returned +from the function. This restriction comes from how torch.func is implemented: +each transform wraps Tensor inputs in special torch.func Tensor subclasses +that facilitate the transform. + +So, instead of the following: + +```python +import torch +from torch.func import grad + +# Don't do this +intermediate = None + +def f(x): + global intermediate + intermediate = x.sin() + z = intermediate.sin() + return z + +x = torch.randn([]) +grad_x = grad(f)(x) +``` + +Please rewrite `f` to return `intermediate`: + +```python +def f(x): + intermediate = x.sin() + z = intermediate.sin() + return z, intermediate + +grad_x, intermediate = grad(f, has_aux=True)(x) +``` + +## torch.autograd APIs + +If you are trying to use a `torch.autograd` API like `torch.autograd.grad` +or `torch.autograd.backward` inside of a function being transformed by +{func}`vmap` or one of torch.func's AD transforms ({func}`vjp`, {func}`jvp`, +{func}`jacrev`, {func}`jacfwd`), the transform may not be able to transform over it. +If it is unable to do so, you'll receive an error message. + +This is a fundamental design limitation in how PyTorch's AD support is implemented +and the reason why we designed the torch.func library. Please instead use the torch.func +equivalents of the `torch.autograd` APIs: +- `torch.autograd.grad`, `Tensor.backward` -> `torch.func.vjp` or `torch.func.grad` +- `torch.autograd.functional.jvp` -> `torch.func.jvp` +- `torch.autograd.functional.jacobian` -> `torch.func.jacrev` or `torch.func.jacfwd` +- `torch.autograd.functional.hessian` -> `torch.func.hessian` + +## vmap limitations + +:::{note} +{func}`vmap` is our most restrictive transform. +The grad-related transforms ({func}`grad`, {func}`vjp`, {func}`jvp`) do not +have these limitations. {func}`jacfwd` (and {func}`hessian`, which is +implemented with {func}`jacfwd`) is a composition of {func}`vmap` and +{func}`jvp` so it also has these limitations. +::: + +`vmap(func)` is a transform that returns a function that maps `func` over +some new dimension of each input Tensor. The mental model for vmap is that it is +like running a for-loop: for pure functions (i.e. in the absence of side +effects), `vmap(f)(x)` is equivalent to: + +```python +torch.stack([f(x_i) for x_i in x.unbind(0)]) +``` + +### Mutation: Arbitrary mutation of Python data structures + +In the presence of side effects, {func}`vmap` no longer acts like it is running +a for-loop. For example, the following function: + +```python +def f(x, list): + list.pop() + print("hello!") + return x.sum(0) + +x = torch.randn(3, 1) +lst = [0, 1, 2, 3] + +result = vmap(f, in_dims=(0, None))(x, lst) +``` + +will print "hello!" once and pop only one element from `lst`. + +{func}`vmap` executes `f` a single time, so all side effects only happen once. + +This is a consequence of how vmap is implemented. torch.func has a special, +internal BatchedTensor class. `vmap(f)(*inputs)` takes all Tensor inputs, +turns them into BatchedTensors, and calls `f(*batched_tensor_inputs)`. +BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) +behavior for each PyTorch operator. + +### Mutation: in-place PyTorch Operations + +You might be here due to receiving an error about vmap-incompatible in-place +operations. {func}`vmap` will raise an error if it encounters an unsupported PyTorch +in-place operation and it will succeed otherwise. Unsupported operations +are those that would cause a Tensor with more elements to be written to a +Tensor with fewer elements. Here's an example of how this can occur: + +```python +def f(x, y): + x.add_(y) + return x + +x = torch.randn(1) +y = torch.randn(3, 1) # When vmapped over, looks like it has shape [1] + +# Raises an error because `x` has fewer elements than `y`. +vmap(f, in_dims=(None, 0))(x, y) +``` + +`x` is a Tensor with one element, `y` is a Tensor with three elements. +`x + y` has three elements (due to broadcasting), but attempting to write +three elements back into `x`, which only has one element, raises an error +due to attempting to write three elements into a Tensor with a single element. + +There is no problem if the Tensor being written to is batched under +{func}`~torch.vmap` (i.e. it is being vmapped over). + +```python +def f(x, y): + x.add_(y) + return x + +x = torch.randn(3, 1) +y = torch.randn(3, 1) +expected = x + y + +# Does not raise an error because x is being vmapped over. +vmap(f, in_dims=(0, 0))(x, y) +assert torch.allclose(x, expected) +``` + +One common fix for this is to replace calls to factory functions with +their "new\_\*" equivalent. For example: + +- Replace {func}`torch.zeros` with {meth}`Tensor.new_zeros` +- Replace {func}`torch.empty` with {meth}`Tensor.new_empty` + +To see why this helps, consider the following. + +```python +def diag_embed(vec): + assert vec.dim() == 1 + result = torch.zeros(vec.shape[0], vec.shape[0]) + result.diagonal().copy_(vec) + return result + +vecs = torch.tensor([[0., 1, 2], [3., 4, 5]]) + +# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ... +vmap(diag_embed)(vecs) +``` + +Inside of {func}`~torch.vmap`, `result` is a Tensor of shape [3, 3]. +However, although `vec` looks like it has shape [3], `vec` actually has +underlying shape [2, 3]. +It is not possible to copy `vec` into `result.diagonal()`, which has +shape [3], because it has too many elements. + +```python +def diag_embed(vec): + assert vec.dim() == 1 + result = vec.new_zeros(vec.shape[0], vec.shape[0]) + result.diagonal().copy_(vec) + return result + +vecs = torch.tensor([[0., 1, 2], [3., 4, 5]]) +vmap(diag_embed)(vecs) +``` + +Replacing {func}`torch.zeros` with {meth}`Tensor.new_zeros` makes it so that +`result` has an underlying Tensor of shape [2, 3, 3], so it is now possible +to copy `vec`, which has underlying shape [2, 3], into `result.diagonal()`. + +### Mutation: out= PyTorch Operations + +{func}`vmap` doesn't support the `out=` keyword argument in PyTorch operations. +It will error out gracefully if it encounters that in your code. + +This is not a fundamental limitation; we could theoretically support this in the +future but we have chosen not to for now. + +### Data-dependent Python control flow + +We don't yet support `vmap` over data-dependent control flow. Data-dependent +control flow is when the condition of an if-statement, while-loop, or +for-loop is a Tensor that is being `vmap`'ed over. For example, the +following will raise an error message: + +```python +def relu(x): + if x > 0: + return x + return 0 + +x = torch.randn(3) +vmap(relu)(x) +``` + +However, any control flow that is not dependent on the values in `vmap`'ed +tensors will work: + +```python +def custom_dot(x): + if x.dim() == 1: + return torch.dot(x, x) + return (x * x).sum() + +x = torch.randn(3) +vmap(custom_dot)(x) +``` + +JAX supports transforming over +[data-dependent control flow](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) +using special control flow operators (e.g. `jax.lax.cond`, `jax.lax.while_loop`). +We're investigating adding equivalents of those to PyTorch. + +### Data-dependent operations (.item()) + +We do not (and will not) support vmap over a user-defined function that calls +`.item()` on a Tensor. For example, the following will raise an error message: + +```python +def f(x): + return x.item() + +x = torch.randn(3) +vmap(f)(x) +``` + +Please try to rewrite your code to not use `.item()` calls. + +You may also encounter an error message about using `.item()` but you might +not have used it. In those cases, it is possible that PyTorch internally is +calling `.item()` -- please file an issue on GitHub and we'll fix +PyTorch internals. + +### Dynamic shape operations (nonzero and friends) + +`vmap(f)` requires that `f` applied to every "example" in your input +returns a Tensor with the same shape. Operations such as `torch.nonzero`, +`torch.is_nonzero` are not supported and will error as a result. + +To see why, consider the following example: + +```python +xs = torch.tensor([[0, 1, 2], [0, 0, 3]]) +vmap(torch.nonzero)(xs) +``` + +`torch.nonzero(xs[0])` returns a Tensor of shape 2; +but `torch.nonzero(xs[1])` returns a Tensor of shape 1. +We are unable to construct a single Tensor as an output; +the output would need to be a ragged Tensor (and PyTorch does not yet have +the concept of a ragged Tensor). + +## Randomness + +The user's intention when calling a random operation can be unclear. Specifically, some users may want +the random behavior to be the same across batches while others may want it to differ across batches. +To address this, `vmap` takes a randomness flag. + +The flag can only be passed to vmap and can take on 3 values, "error," "different," or "same," defaulting +to error. Under "error" mode, any call to a random function will produce an error asking the user to use +one of the other two flags based on their use case. + +Under "different" randomness, elements in a batch produce different random values. For instance, + +```python +def add_noise(x): + y = torch.randn(()) # y will be different across the batch + return x + y + +x = torch.ones(3) +result = vmap(add_noise, randomness="different")(x) # we get 3 different values +``` + +Under "same" randomness, elements in a batch produce same random values. For instance, + +```python +def add_noise(x): + y = torch.randn(()) # y will be the same across the batch + return x + y + +x = torch.ones(3) +result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times +``` + +:::{warning} +Our system only determine the randomness behavior of PyTorch operators and cannot control the +behavior of other libraries, like numpy. This is similar to JAX's limitations with their solutions +::: + +:::{note} +Multiple vmap calls using either type of supported randomness will not produce +the same results. Like with standard PyTorch, a user can get randomness reproducibility through +either using `torch.manual_seed()` outside of vmap or by using generators. +::: + +:::{note} +Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch +doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the +most common forms of randomness that we see. If your use case does not fit these forms of randomness, please +file an issue. +::: diff --git a/docs/source/func.ux_limitations.rst b/docs/source/func.ux_limitations.rst deleted file mode 100644 index 803196b70d15b7..00000000000000 --- a/docs/source/func.ux_limitations.rst +++ /dev/null @@ -1,339 +0,0 @@ -.. currentmodule:: torch.func - -.. _ux-limitations: - -UX Limitations -============== - -torch.func, like `JAX `_, has restrictions around -what can be transformed. In general, JAX’s limitations are that transforms -only work with pure functions: that is, functions where the output is completely -determined by the input and that do not involve side effects (like mutation). - -We have a similar guarantee: our transforms work well with pure functions. -However, we do support certain in-place operations. On one hand, writing code -compatible with function transforms may involve changing how you write PyTorch -code, on the other hand, you may find that our transforms let you express things -that were previously difficult to express in PyTorch. - -General limitations -------------------- - -All torch.func transforms share a limitation in that a function should not -assign to global variables. Instead, all outputs to a function must be returned -from the function. This restriction comes from how torch.func is implemented: -each transform wraps Tensor inputs in special torch.func Tensor subclasses -that facilitate the transform. - -So, instead of the following: - -:: - - import torch - from torch.func import grad - - # Don't do this - intermediate = None - - def f(x): - global intermediate - intermediate = x.sin() - z = intermediate.sin() - return z - - x = torch.randn([]) - grad_x = grad(f)(x) - -Please rewrite ``f`` to return ``intermediate``: - -:: - - def f(x): - intermediate = x.sin() - z = intermediate.sin() - return z, intermediate - - grad_x, intermediate = grad(f, has_aux=True)(x) - -torch.autograd APIs -------------------- - -If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad`` -or ``torch.autograd.backward`` inside of a function being transformed by -:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`, -:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it. -If it is unable to do so, you'll receive an error message. - -This is a fundamental design limitation in how PyTorch's AD support is implemented -and the reason why we designed the torch.func library. Please instead use the torch.func -equivalents of the ``torch.autograd`` APIs: -- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad`` -- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp`` -- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd`` -- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian`` - -vmap limitations ----------------- - -.. note:: - :func:`vmap` is our most restrictive transform. - The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not - have these limitations. :func:`jacfwd` (and :func:`hessian`, which is - implemented with :func:`jacfwd`) is a composition of :func:`vmap` and - :func:`jvp` so it also has these limitations. - -``vmap(func)`` is a transform that returns a function that maps ``func`` over -some new dimension of each input Tensor. The mental model for vmap is that it is -like running a for-loop: for pure functions (i.e. in the absence of side -effects), ``vmap(f)(x)`` is equivalent to: - -:: - - torch.stack([f(x_i) for x_i in x.unbind(0)]) - -Mutation: Arbitrary mutation of Python data structures -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the presence of side effects, :func:`vmap` no longer acts like it is running -a for-loop. For example, the following function: - -:: - - def f(x, list): - list.pop() - print("hello!") - return x.sum(0) - - x = torch.randn(3, 1) - lst = [0, 1, 2, 3] - - result = vmap(f, in_dims=(0, None))(x, lst) - -will print "hello!" once and pop only one element from ``lst``. - - -:func:`vmap` executes ``f`` a single time, so all side effects only happen once. - -This is a consequence of how vmap is implemented. torch.func has a special, -internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs, -turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``. -BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized) -behavior for each PyTorch operator. - - -Mutation: in-place PyTorch Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You might be here due to receiving an error about vmap-incompatible in-place -operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch -in-place operation and it will succeed otherwise. Unsupported operations -are those that would cause a Tensor with more elements to be written to a -Tensor with fewer elements. Here's an example of how this can occur: - -:: - - def f(x, y): - x.add_(y) - return x - - x = torch.randn(1) - y = torch.randn(3, 1) # When vmapped over, looks like it has shape [1] - - # Raises an error because `x` has fewer elements than `y`. - vmap(f, in_dims=(None, 0))(x, y) - -``x`` is a Tensor with one element, ``y`` is a Tensor with three elements. -``x + y`` has three elements (due to broadcasting), but attempting to write -three elements back into ``x``, which only has one element, raises an error -due to attempting to write three elements into a Tensor with a single element. - -There is no problem if the Tensor being written to is batched under -:func:`~torch.vmap` (i.e. it is being vmapped over). - -:: - - def f(x, y): - x.add_(y) - return x - - x = torch.randn(3, 1) - y = torch.randn(3, 1) - expected = x + y - - # Does not raise an error because x is being vmapped over. - vmap(f, in_dims=(0, 0))(x, y) - assert torch.allclose(x, expected) - -One common fix for this is to replace calls to factory functions with -their "new_*" equivalent. For example: - -- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros` -- Replace :func:`torch.empty` with :meth:`Tensor.new_empty` - -To see why this helps, consider the following. - -:: - - def diag_embed(vec): - assert vec.dim() == 1 - result = torch.zeros(vec.shape[0], vec.shape[0]) - result.diagonal().copy_(vec) - return result - - vecs = torch.tensor([[0., 1, 2], [3., 4, 5]]) - - # RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ... - vmap(diag_embed)(vecs) - -Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3]. -However, although ``vec`` looks like it has shape [3], ``vec`` actually has -underlying shape [2, 3]. -It is not possible to copy ``vec`` into ``result.diagonal()``, which has -shape [3], because it has too many elements. - -:: - - def diag_embed(vec): - assert vec.dim() == 1 - result = vec.new_zeros(vec.shape[0], vec.shape[0]) - result.diagonal().copy_(vec) - return result - - vecs = torch.tensor([[0., 1, 2], [3., 4, 5]]) - vmap(diag_embed)(vecs) - -Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that -``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible -to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``. - - -Mutation: out= PyTorch Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations. -It will error out gracefully if it encounters that in your code. - -This is not a fundamental limitation; we could theoretically support this in the -future but we have chosen not to for now. - -Data-dependent Python control flow -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -We don't yet support ``vmap`` over data-dependent control flow. Data-dependent -control flow is when the condition of an if-statement, while-loop, or -for-loop is a Tensor that is being ``vmap``'ed over. For example, the -following will raise an error message: - -:: - - def relu(x): - if x > 0: - return x - return 0 - - x = torch.randn(3) - vmap(relu)(x) - -However, any control flow that is not dependent on the values in ``vmap``'ed -tensors will work: - -:: - - def custom_dot(x): - if x.dim() == 1: - return torch.dot(x, x) - return (x * x).sum() - - x = torch.randn(3) - vmap(custom_dot)(x) - -JAX supports transforming over -`data-dependent control flow `_ -using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``). -We're investigating adding equivalents of those to PyTorch. - -Data-dependent operations (.item()) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -We do not (and will not) support vmap over a user-defined function that calls -``.item()`` on a Tensor. For example, the following will raise an error message: - -:: - - def f(x): - return x.item() - - x = torch.randn(3) - vmap(f)(x) - -Please try to rewrite your code to not use ``.item()`` calls. - -You may also encounter an error message about using ``.item()`` but you might -not have used it. In those cases, it is possible that PyTorch internally is -calling ``.item()`` -- please file an issue on GitHub and we'll fix -PyTorch internals. - -Dynamic shape operations (nonzero and friends) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``vmap(f)`` requires that ``f`` applied to every "example" in your input -returns a Tensor with the same shape. Operations such as ``torch.nonzero``, -``torch.is_nonzero`` are not supported and will error as a result. - -To see why, consider the following example: - -:: - - xs = torch.tensor([[0, 1, 2], [0, 0, 3]]) - vmap(torch.nonzero)(xs) - -``torch.nonzero(xs[0])`` returns a Tensor of shape 2; -but ``torch.nonzero(xs[1])`` returns a Tensor of shape 1. -We are unable to construct a single Tensor as an output; -the output would need to be a ragged Tensor (and PyTorch does not yet have -the concept of a ragged Tensor). - - -Randomness ----------- -The user's intention when calling a random operation can be unclear. Specifically, some users may want -the random behavior to be the same across batches while others may want it to differ across batches. -To address this, ``vmap`` takes a randomness flag. - -The flag can only be passed to vmap and can take on 3 values, "error," "different," or "same," defaulting -to error. Under "error" mode, any call to a random function will produce an error asking the user to use -one of the other two flags based on their use case. - -Under "different" randomness, elements in a batch produce different random values. For instance, - -:: - - def add_noise(x): - y = torch.randn(()) # y will be different across the batch - return x + y - - x = torch.ones(3) - result = vmap(add_noise, randomness="different")(x) # we get 3 different values - -Under "same" randomness, elements in a batch produce same random values. For instance, - -:: - - def add_noise(x): - y = torch.randn(()) # y will be the same across the batch - return x + y - - x = torch.ones(3) - result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times - - -.. warning:: - Our system only determine the randomness behavior of PyTorch operators and cannot control the - behavior of other libraries, like numpy. This is similar to JAX's limitations with their solutions - -.. note:: - Multiple vmap calls using either type of supported randomness will not produce - the same results. Like with standard PyTorch, a user can get randomness reproducibility through - either using ``torch.manual_seed()`` outside of vmap or by using generators. - -.. note:: - Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch - doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the - most common forms of randomness that we see. If your use case does not fit these forms of randomness, please - file an issue. diff --git a/docs/source/func.whirlwind_tour.md b/docs/source/func.whirlwind_tour.md new file mode 100644 index 00000000000000..e17172281e84ea --- /dev/null +++ b/docs/source/func.whirlwind_tour.md @@ -0,0 +1,190 @@ +# torch.func Whirlwind Tour + +## What is torch.func? + +```{eval-rst} +.. currentmodule:: torch.func +``` + +torch.func, previously known as functorch, is a library for +[JAX](https://github.com/google/jax)-like composable function transforms in +PyTorch. + +- A "function transform" is a higher-order function that accepts a numerical + function and returns a new function that computes a different quantity. +- torch.func has auto-differentiation transforms (`grad(f)` returns a function + that computes the gradient of `f`), a vectorization/batching transform + (`vmap(f)` returns a function that computes `f` over batches of inputs), + and others. +- These function transforms can compose with each other arbitrarily. For + example, composing `vmap(grad(f))` computes a quantity called + per-sample-gradients that stock PyTorch cannot efficiently compute today. + +## Why composable function transforms? + +There are a number of use cases that are tricky to do in PyTorch today: + +- computing per-sample-gradients (or other per-sample quantities) +- running ensembles of models on a single machine +- efficiently batching together tasks in the inner-loop of MAML +- efficiently computing Jacobians and Hessians +- efficiently computing batched Jacobians and Hessians + +Composing {func}`vmap`, {func}`grad`, {func}`vjp`, and {func}`jvp` transforms +allows us to express the above without designing a separate subsystem for each. + +## What are the transforms? + +### {func}`grad` (gradient computation) + +`grad(func)` is our gradient computation transform. It returns a new function +that computes the gradients of `func`. It assumes `func` returns a single-element +Tensor and by default it computes the gradients of the output of `func` w.r.t. +to the first input. + +```python +import torch +from torch.func import grad +x = torch.randn([]) +cos_x = grad(lambda x: torch.sin(x))(x) +assert torch.allclose(cos_x, x.cos()) + +# Second-order gradients +neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) +assert torch.allclose(neg_sin_x, -x.sin()) +``` + +### {func}`vmap` (auto-vectorization) + +Note: {func}`vmap` imposes restrictions on the code that it can be used on. For more +details, please see {ref}`ux-limitations`. + +`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor +operations in `func`. `vmap(func)` returns a new function that maps `func` +over some dimension (default: 0) of each Tensor in inputs. + +vmap is useful for hiding batch dimensions: one can write a function func that +runs on examples and then lift it to a function that can take batches of +examples with `vmap(func)`, leading to a simpler modeling experience: + +```python +import torch +from torch.func import vmap +batch_size, feature_size = 3, 5 +weights = torch.randn(feature_size, requires_grad=True) + +def model(feature_vec): + # Very simple linear model with activation + assert feature_vec.dim() == 1 + return feature_vec.dot(weights).relu() + +examples = torch.randn(batch_size, feature_size) +result = vmap(model)(examples) +``` + +When composed with {func}`grad`, {func}`vmap` can be used to compute per-sample-gradients: + +```python +from torch.func import vmap +batch_size, feature_size = 3, 5 + +def model(weights,feature_vec): + # Very simple linear model with activation + assert feature_vec.dim() == 1 + return feature_vec.dot(weights).relu() + +def compute_loss(weights, example, target): + y = model(weights, example) + return ((y - target) ** 2).mean() # MSELoss + +weights = torch.randn(feature_size, requires_grad=True) +examples = torch.randn(batch_size, feature_size) +targets = torch.randn(batch_size) +inputs = (weights,examples, targets) +grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) +``` + +### {func}`vjp` (vector-Jacobian product) + +The {func}`vjp` transform applies `func` to `inputs` and returns a new function +that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors. + +```python +from torch.func import vjp + +inputs = torch.randn(3) +func = torch.sin +cotangents = (torch.randn(3),) + +outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents) +``` + +### {func}`jvp` (Jacobian-vector product) + +The {func}`jvp` transforms computes Jacobian-vector-products and is also known as +"forward-mode AD". It is not a higher-order function unlike most other transforms, +but it returns the outputs of `func(inputs)` as well as the jvps. + +```python +from torch.func import jvp +x = torch.randn(5) +y = torch.randn(5) +f = lambda x, y: (x * y) +_, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) +assert torch.allclose(out_tangent, x + y) +``` + +### {func}`jacrev`, {func}`jacfwd`, and {func}`hessian` + +The {func}`jacrev` transform returns a new function that takes in `x` and returns +the Jacobian of the function with respect to `x` using reverse-mode AD. + +```python +from torch.func import jacrev +x = torch.randn(5) +jacobian = jacrev(torch.sin)(x) +expected = torch.diag(torch.cos(x)) +assert torch.allclose(jacobian, expected) +``` + +{func}`jacrev` can be composed with {func}`vmap` to produce batched jacobians: + +```python +x = torch.randn(64, 5) +jacobian = vmap(jacrev(torch.sin))(x) +assert jacobian.shape == (64, 5, 5) +``` + +{func}`jacfwd` is a drop-in replacement for jacrev that computes Jacobians using +forward-mode AD: + +```python +from torch.func import jacfwd +x = torch.randn(5) +jacobian = jacfwd(torch.sin)(x) +expected = torch.diag(torch.cos(x)) +assert torch.allclose(jacobian, expected) +``` + +Composing {func}`jacrev` with itself or {func}`jacfwd` can produce hessians: + +```python +def f(x): + return x.sin().sum() + +x = torch.randn(5) +hessian0 = jacrev(jacrev(f))(x) +hessian1 = jacfwd(jacrev(f))(x) +``` + +{func}`hessian` is a convenience function that combines jacfwd and jacrev: + +```python +from torch.func import hessian + +def f(x): + return x.sin().sum() + +x = torch.randn(5) +hess = hessian(f)(x) +``` diff --git a/docs/source/func.whirlwind_tour.rst b/docs/source/func.whirlwind_tour.rst deleted file mode 100644 index ecb4197827c2b1..00000000000000 --- a/docs/source/func.whirlwind_tour.rst +++ /dev/null @@ -1,196 +0,0 @@ -torch.func Whirlwind Tour -========================= - -What is torch.func? -------------------- - -.. currentmodule:: torch.func - -torch.func, previously known as functorch, is a library for -`JAX `_-like composable function transforms in -PyTorch. - -- A "function transform" is a higher-order function that accepts a numerical - function and returns a new function that computes a different quantity. -- torch.func has auto-differentiation transforms (``grad(f)`` returns a function - that computes the gradient of ``f``), a vectorization/batching transform - (``vmap(f)`` returns a function that computes ``f`` over batches of inputs), - and others. -- These function transforms can compose with each other arbitrarily. For - example, composing ``vmap(grad(f))`` computes a quantity called - per-sample-gradients that stock PyTorch cannot efficiently compute today. - -Why composable function transforms? ------------------------------------ -There are a number of use cases that are tricky to do in PyTorch today: -- computing per-sample-gradients (or other per-sample quantities) - -- running ensembles of models on a single machine -- efficiently batching together tasks in the inner-loop of MAML -- efficiently computing Jacobians and Hessians -- efficiently computing batched Jacobians and Hessians - -Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms -allows us to express the above without designing a separate subsystem for each. - -What are the transforms? ------------------------- - -:func:`grad` (gradient computation) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -``grad(func)`` is our gradient computation transform. It returns a new function -that computes the gradients of ``func``. It assumes ``func`` returns a single-element -Tensor and by default it computes the gradients of the output of ``func`` w.r.t. -to the first input. - -.. code-block:: python - - import torch - from torch.func import grad - x = torch.randn([]) - cos_x = grad(lambda x: torch.sin(x))(x) - assert torch.allclose(cos_x, x.cos()) - - # Second-order gradients - neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) - assert torch.allclose(neg_sin_x, -x.sin()) - -:func:`vmap` (auto-vectorization) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Note: :func:`vmap` imposes restrictions on the code that it can be used on. For more -details, please see :ref:`ux-limitations`. - -``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor -operations in ``func``. ``vmap(func)`` returns a new function that maps ``func`` -over some dimension (default: 0) of each Tensor in inputs. - -vmap is useful for hiding batch dimensions: one can write a function func that -runs on examples and then lift it to a function that can take batches of -examples with ``vmap(func)``, leading to a simpler modeling experience: - -.. code-block:: python - - import torch - from torch.func import vmap - batch_size, feature_size = 3, 5 - weights = torch.randn(feature_size, requires_grad=True) - - def model(feature_vec): - # Very simple linear model with activation - assert feature_vec.dim() == 1 - return feature_vec.dot(weights).relu() - - examples = torch.randn(batch_size, feature_size) - result = vmap(model)(examples) - -When composed with :func:`grad`, :func:`vmap` can be used to compute per-sample-gradients: - -.. code-block:: python - - from torch.func import vmap - batch_size, feature_size = 3, 5 - - def model(weights,feature_vec): - # Very simple linear model with activation - assert feature_vec.dim() == 1 - return feature_vec.dot(weights).relu() - - def compute_loss(weights, example, target): - y = model(weights, example) - return ((y - target) ** 2).mean() # MSELoss - - weights = torch.randn(feature_size, requires_grad=True) - examples = torch.randn(batch_size, feature_size) - targets = torch.randn(batch_size) - inputs = (weights,examples, targets) - grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) - -:func:`vjp` (vector-Jacobian product) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :func:`vjp` transform applies ``func`` to ``inputs`` and returns a new function -that computes the vector-Jacobian product (vjp) given some ``cotangents`` Tensors. - -.. code-block:: python - - from torch.func import vjp - - inputs = torch.randn(3) - func = torch.sin - cotangents = (torch.randn(3),) - - outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents) - -:func:`jvp` (Jacobian-vector product) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :func:`jvp` transforms computes Jacobian-vector-products and is also known as -"forward-mode AD". It is not a higher-order function unlike most other transforms, -but it returns the outputs of ``func(inputs)`` as well as the jvps. - -.. code-block:: python - - from torch.func import jvp - x = torch.randn(5) - y = torch.randn(5) - f = lambda x, y: (x * y) - _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) - assert torch.allclose(out_tangent, x + y) - -:func:`jacrev`, :func:`jacfwd`, and :func:`hessian` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :func:`jacrev` transform returns a new function that takes in ``x`` and returns -the Jacobian of the function with respect to ``x`` using reverse-mode AD. - -.. code-block:: python - - from torch.func import jacrev - x = torch.randn(5) - jacobian = jacrev(torch.sin)(x) - expected = torch.diag(torch.cos(x)) - assert torch.allclose(jacobian, expected) - -:func:`jacrev` can be composed with :func:`vmap` to produce batched jacobians: - -.. code-block:: python - - x = torch.randn(64, 5) - jacobian = vmap(jacrev(torch.sin))(x) - assert jacobian.shape == (64, 5, 5) - -:func:`jacfwd` is a drop-in replacement for jacrev that computes Jacobians using -forward-mode AD: - -.. code-block:: python - - from torch.func import jacfwd - x = torch.randn(5) - jacobian = jacfwd(torch.sin)(x) - expected = torch.diag(torch.cos(x)) - assert torch.allclose(jacobian, expected) - -Composing :func:`jacrev` with itself or :func:`jacfwd` can produce hessians: - -.. code-block:: python - - def f(x): - return x.sin().sum() - - x = torch.randn(5) - hessian0 = jacrev(jacrev(f))(x) - hessian1 = jacfwd(jacrev(f))(x) - -:func:`hessian` is a convenience function that combines jacfwd and jacrev: - -.. code-block:: python - - from torch.func import hessian - - def f(x): - return x.sin().sum() - - x = torch.randn(5) - hess = hessian(f)(x) diff --git a/docs/source/future_mod.md b/docs/source/future_mod.md new file mode 100644 index 00000000000000..3e72f21c3e776d --- /dev/null +++ b/docs/source/future_mod.md @@ -0,0 +1,25 @@ +# torch.\_\_future\_\_ + +```{eval-rst} +.. automodule:: torch.__future__ +``` + +```{eval-rst} +.. currentmodule:: torch.__future__ +``` + +```{eval-rst} +.. autofunction:: set_overwrite_module_params_on_conversion +``` + +```{eval-rst} +.. autofunction:: get_overwrite_module_params_on_conversion +``` + +```{eval-rst} +.. autofunction:: set_swap_module_params_on_conversion +``` + +```{eval-rst} +.. autofunction:: get_swap_module_params_on_conversion +``` diff --git a/docs/source/future_mod.rst b/docs/source/future_mod.rst deleted file mode 100644 index 1ef2a25330ea3d..00000000000000 --- a/docs/source/future_mod.rst +++ /dev/null @@ -1,10 +0,0 @@ -torch.__future__ -=================================== - -.. automodule:: torch.__future__ -.. currentmodule:: torch.__future__ - -.. autofunction:: set_overwrite_module_params_on_conversion -.. autofunction:: get_overwrite_module_params_on_conversion -.. autofunction:: set_swap_module_params_on_conversion -.. autofunction:: get_swap_module_params_on_conversion diff --git a/docs/source/futures.md b/docs/source/futures.md new file mode 100644 index 00000000000000..ec3ffd59067364 --- /dev/null +++ b/docs/source/futures.md @@ -0,0 +1,30 @@ +```{eval-rst} +.. currentmodule:: torch.futures +``` + +(futures-docs)= + +# torch.futures + +This package provides a {class}`~torch.futures.Future` type that encapsulates +an asynchronous execution and a set of utility functions to simplify operations +on {class}`~torch.futures.Future` objects. Currently, the +{class}`~torch.futures.Future` type is primarily used by the +{ref}`distributed-rpc-framework`. + +```{eval-rst} +.. automodule:: torch.futures +``` + +```{eval-rst} +.. autoclass:: Future + :inherited-members: +``` + +```{eval-rst} +.. autofunction:: collect_all +``` + +```{eval-rst} +.. autofunction:: wait_all +``` diff --git a/docs/source/futures.rst b/docs/source/futures.rst deleted file mode 100644 index 82925138934537..00000000000000 --- a/docs/source/futures.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. currentmodule:: torch.futures - -.. _futures-docs: - -torch.futures -============= - -This package provides a :class:`~torch.futures.Future` type that encapsulates -an asynchronous execution and a set of utility functions to simplify operations -on :class:`~torch.futures.Future` objects. Currently, the -:class:`~torch.futures.Future` type is primarily used by the -:ref:`distributed-rpc-framework`. - -.. automodule:: torch.futures - -.. autoclass:: Future - :inherited-members: - -.. autofunction:: collect_all -.. autofunction:: wait_all diff --git a/docs/source/fx.experimental.md b/docs/source/fx.experimental.md new file mode 100644 index 00000000000000..24125cd310bc4c --- /dev/null +++ b/docs/source/fx.experimental.md @@ -0,0 +1,90 @@ +```{eval-rst} +.. currentmodule:: torch.fx.experimental +``` + +# torch.fx.experimental + +:::{warning} +These APIs are experimental and subject to change without notice. +::: + +## torch.fx.experimental.symbolic_shapes + +```{eval-rst} +.. currentmodule:: torch.fx.experimental.symbolic_shapes +``` + +```{eval-rst} +.. automodule:: torch.fx.experimental.symbolic_shapes +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + ShapeEnv + DimDynamic + StrictMinMaxConstraint + RelaxedUnspecConstraint + EqualityConstraint + SymbolicContext + StatelessSymbolicContext + StatefulSymbolicContext + SubclassSymbolicContext + DimConstraints + ShapeEnvSettings + ConvertIntKey + CallMethodKey + PropagateUnbackedSymInts + DivideByKey + InnerTensorKey + Specialization + + hint_int + is_concrete_int + is_concrete_bool + is_concrete_float + has_free_symbols + has_free_unbacked_symbols + guard_or_true + guard_or_false + guard_size_oblivious + sym_and + sym_eq + sym_or + constrain_range + constrain_unify + canonicalize_bool_expr + statically_known_true + statically_known_false + has_static_value + lru_cache + check_consistent + compute_unbacked_bindings + rebind_unbacked + resolve_unbacked_bindings + is_accessor_node +``` + +## torch.fx.experimental.proxy_tensor + +```{eval-rst} +.. currentmodule:: torch.fx.experimental.proxy_tensor +``` + +```{eval-rst} +.. automodule:: torch.fx.experimental.proxy_tensor +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + make_fx + handle_sym_dispatch + get_proxy_mode + maybe_enable_thunkify + maybe_disable_thunkify +``` diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst deleted file mode 100644 index 6a7d8de3585aae..00000000000000 --- a/docs/source/fx.experimental.rst +++ /dev/null @@ -1,74 +0,0 @@ -.. currentmodule:: torch.fx.experimental - -torch.fx.experimental -===================== - -.. warning:: - These APIs are experimental and subject to change without notice. - -torch.fx.experimental.symbolic_shapes -------------------------------------- -.. currentmodule:: torch.fx.experimental.symbolic_shapes -.. automodule:: torch.fx.experimental.symbolic_shapes - -.. autosummary:: - :toctree: generated - :nosignatures: - - ShapeEnv - DimDynamic - StrictMinMaxConstraint - RelaxedUnspecConstraint - EqualityConstraint - SymbolicContext - StatelessSymbolicContext - StatefulSymbolicContext - SubclassSymbolicContext - DimConstraints - ShapeEnvSettings - ConvertIntKey - CallMethodKey - PropagateUnbackedSymInts - DivideByKey - InnerTensorKey - - hint_int - is_concrete_int - is_concrete_bool - is_concrete_float - has_free_symbols - has_free_unbacked_symbols - guard_or_true - guard_or_false - guard_size_oblivious - sym_and - sym_eq - sym_or - constrain_range - constrain_unify - canonicalize_bool_expr - statically_known_true - statically_known_false - has_static_value - lru_cache - check_consistent - compute_unbacked_bindings - rebind_unbacked - resolve_unbacked_bindings - is_accessor_node - -torch.fx.experimental.proxy_tensor -------------------------------------- - -.. currentmodule:: torch.fx.experimental.proxy_tensor -.. automodule:: torch.fx.experimental.proxy_tensor - -.. autosummary:: - :toctree: generated - :nosignatures: - - make_fx - handle_sym_dispatch - get_proxy_mode - maybe_enable_thunkify - maybe_disable_thunkify diff --git a/docs/source/fx.md b/docs/source/fx.md new file mode 100644 index 00000000000000..8b60c80649661a --- /dev/null +++ b/docs/source/fx.md @@ -0,0 +1,1187 @@ +```{eval-rst} +.. currentmodule:: torch.fx +``` + + +# torch.fx + +## Overview +```{eval-rst} +.. automodule:: torch.fx +``` + + +(Writing Transformations)= + + +## Writing Transformations + +What is an FX transform? Essentially, it's a function that looks like this. + +```python + +import torch +import torch.fx + +def transform(m: nn.Module, + tracer_class : type = torch.fx.Tracer) -> torch.nn.Module: + # Step 1: Acquire a Graph representing the code in `m` + + # NOTE: torch.fx.symbolic_trace is a wrapper around a call to + # fx.Tracer.trace and constructing a GraphModule. We'll + # split that out in our transform to allow the caller to + # customize tracing behavior. + graph : torch.fx.Graph = tracer_class().trace(m) + + # Step 2: Modify this Graph or create a new one + graph = ... + + # Step 3: Construct a Module to return + return torch.fx.GraphModule(m, graph) +``` + +Your transform will take in a {class}`torch.nn.Module`, acquire a {class}`Graph` +from it, do some modifications, and return a new +{class}`torch.nn.Module`. You should think of the {class}`torch.nn.Module` that your FX +transform returns as identical to a regular {class}`torch.nn.Module` -- you can pass it to another +FX transform, you can pass it to TorchScript, or you can +run it. Ensuring that the inputs and outputs of your FX transform are a +{class}`torch.nn.Module` will allow for composability. + +```{note} + +It is also possible to modify an existing {class}`GraphModule` instead of +creating a new one, like so: + +```python +import torch +import torch.fx + +def transform(m : nn.Module) -> nn.Module: + gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m) + + # Modify gm.graph + # <...> + + # Recompile the forward() method of `gm` from its Graph + gm.recompile() + + return gm +``` + +Note that you MUST call {meth}`GraphModule.recompile` to bring the generated +`forward()` method on the `GraphModule` in sync with the modified {class}`Graph`. + +Given that you’ve passed in a {class}`torch.nn.Module` that has been traced into a +{class}`Graph`, there are now two primary approaches you can take to building a new +{class}`Graph`. + +### A Quick Primer on Graphs + +Full treatment of the semantics of graphs can be found in the {class}`Graph` +documentation, but we are going to cover the basics here. A {class}`Graph` is +a data structure that represents a method on a {class}`GraphModule`. The +information that this requires is: + +- What are the inputs to the method? +- What are the operations that run inside the method? +- What is the output (i.e. return) value from the method? + +All three of these concepts are represented with {class}`Node` instances. +Let's see what we mean by that with a short example: + +```python + +import torch +import torch.fx + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum( + self.linear(x + self.linear.weight).relu(), dim=-1), 3) + +m = MyModule() +gm = torch.fx.symbolic_trace(m) + +gm.graph.print_tabular() +``` + +Here we define a module `MyModule` for demonstration purposes, instantiate it, +symbolically trace it, then call the {meth}`Graph.print_tabular` method to print +out a table showing the nodes of this {class}`Graph`: + +| opcode | name | target | args | kwargs | +|--------|------|--------|------|--------| +| placeholder | x | x | () | {} | +| get_attr | linear_weight | linear.weight | () | {} | +| call_function | add_1 | | (x, linear_weight) | {} | +| call_module | linear_1 | linear | (add_1,) | {} | +| call_method | relu_1 | relu | (linear_1,) | {} | +| call_function | sum_1 | | (relu_1,) | {'dim': -1} | +| call_function | topk_1 | | (sum_1, 3) | {} | +| output | output | output | (topk_1,) | {} | + +We can use this information to answer the questions we posed above. + +- What are the inputs to the method? In FX, method inputs are specified + via special `placeholder` nodes. In this case, we have a single + `placeholder` node with a `target` of `x`, meaning we have + a single (non-self) argument named x. +- What are the operations within the method? The `get_attr`, + `call_function`, `call_module`, and `call_method` nodes + represent the operations in the method. A full treatment of + the semantics of all of these can be found in the {class}`Node` + documentation. +- What is the return value of the method? The return value in a + {class}`Graph` is specified by a special `output` node. + +Given that we now know the basics of how code is represented in +FX, we can now explore how we would edit a {class}`Graph`. + +### Graph Manipulation + +#### Direct Graph Manipulation + +One approach to building this new {class}`Graph` is to directly manipulate your old +one. To aid in this, we can simply take the {class}`Graph` we obtain from symbolic +tracing and modify it. For example, let’s say we desire to replace +{func}`torch.add` calls with {func}`torch.mul` calls. + +```python + +import torch +import torch.fx + +# Sample module +class M(torch.nn.Module): + def forward(self, x, y): + return torch.add(x, y) + +def transform(m: torch.nn.Module, + tracer_class : type = fx.Tracer) -> torch.nn.Module: + graph : fx.Graph = tracer_class().trace(m) + # FX represents its Graph as an ordered list of + # nodes, so we can iterate through them. + for node in graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == 'call_function': + # The target attribute is the function + # that call_function calls. + if node.target == torch.add: + node.target = torch.mul + + graph.lint() # Does some checks to make sure the + # Graph is well-formed. + + return fx.GraphModule(m, graph) +``` + +We can also do more involved {class}`Graph` rewrites, such as +deleting or appending nodes. To aid in these transformations, +FX has utility functions for transforming the graph that can +be found in the {class}`Graph` documentation. An +example of using these APIs to append a {func}`torch.relu` call +can be found below. + +```python + +# Specifies the insertion point. Any nodes added to the +# Graph within this scope will be inserted after `node` +with traced.graph.inserting_after(node): + # Insert a new `call_function` node calling `torch.relu` + new_node = traced.graph.call_function( + torch.relu, args=(node,)) + + # We want all places that used the value of `node` to + # now use that value after the `relu` call we've added. + # We use the `replace_all_uses_with` API to do this. + node.replace_all_uses_with(new_node) +``` + +For simple transformations that only consist of substitutions, you can also +make use of the [subgraph rewriter.](https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py) + +#### Subgraph Rewriting With replace_pattern() + +FX also provides another level of automation on top of direct graph manipulation. +The {func}`replace_pattern` API is essentially a "find/replace" tool for editing +{class}`Graph`\s. It allows you to specify a `pattern` and `replacement` function +and it will trace through those functions, find instances of the group of operations +in the `pattern` graph, and replace those instances with copies of the `replacement` +graph. This can help to greatly automate tedious graph manipulation code, which can +get unwieldy as the transformations get more complex. + +#### Graph Manipulation Examples + +- [Replace one + op](https://github.com/pytorch/examples/blob/master/fx/replace_op.py) +- [Conv/Batch Norm + fusion](https://github.com/pytorch/pytorch/blob/40cbf342d3c000712da92cfafeaca651b3e0bd3e/torch/fx/experimental/optimization.py#L50) +- [replace_pattern: Basic usage](https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py) +- [Quantization](https://pytorch.org/docs/main/quantization.html#prototype-fx-graph-mode-quantization) +- [Invert Transformation](https://github.com/pytorch/examples/blob/master/fx/invert.py) + +### Proxy/Retracing + +Another way of manipulating {class}`Graph`\s is by reusing the {class}`Proxy` +machinery used in symbolic tracing. For example, let’s +imagine that we wanted to write a transformation that decomposed +PyTorch functions into smaller operations. It would transform every +`F.relu(x)` call into `(x > 0) * x`. One possibility would be to +perform the requisite graph rewriting to insert the comparison and +multiplication after the `F.relu`, and then clean up the original +`F.relu`. However, we can automate this process by using {class}`Proxy` +objects to automatically record operations into the {class}`Graph`. + +To use this method, we write the operations that we want inserted as regular +PyTorch code and invoke that code with {class}`Proxy` objects as arguments. +These {class}`Proxy` objects will capture the operations that are performed +on them and append them to the {class}`Graph`. + +```python + +# Note that this decomposition rule can be read as regular Python +def relu_decomposition(x): + return (x > 0) * x + +decomposition_rules = {} +decomposition_rules[F.relu] = relu_decomposition + +def decompose(model: torch.nn.Module, + tracer_class : type = fx.Tracer) -> torch.nn.Module: + """ + Decompose `model` into smaller constituent operations. + Currently,this only supports decomposing ReLU into its + mathematical definition: (x > 0) * x + """ + graph : fx.Graph = tracer_class().trace(model) + new_graph = fx.Graph() + env = {} + tracer = torch.fx.proxy.GraphAppendingTracer(new_graph) + for node in graph.nodes: + if node.op == 'call_function' and node.target in decomposition_rules: + # By wrapping the arguments with proxies, + # we can dispatch to the appropriate + # decomposition rule and implicitly add it + # to the Graph by symbolically tracing it. + proxy_args = [ + fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args] + output_proxy = decomposition_rules[node.target](*proxy_args) + + # Operations on `Proxy` always yield new `Proxy`s, and the + # return value of our decomposition rule is no exception. + # We need to extract the underlying `Node` from the `Proxy` + # to use it in subsequent iterations of this transform. + new_node = output_proxy.node + env[node.name] = new_node + else: + # Default case: we don't have a decomposition rule for this + # node, so just copy the node over into the new graph. + new_node = new_graph.node_copy(node, lambda x: env[x.name]) + env[node.name] = new_node + return fx.GraphModule(model, new_graph) +``` + +In addition to avoiding explicit graph manipulation, using {class}`Proxy`\s +also allows you to specify your rewrite rules as native Python code. +For transformations that require a large amount of rewrite rules +(such as vmap or grad), this can often improve readability and +maintainability of the rules. Note that while calling {class}`Proxy` we also +passed a tracer pointing to the underlying variable `graph`. This is done so +if in case the operations in graph are n-ary (e.g. add is a binary operator) +the call to {class}`Proxy` does not create multiple instances of a graph +tracer which can lead to unexpected runtime errors. We recommend this method +of using {class}`Proxy` especially when the underlying operators can not be +safely assumed to be unary. + +A worked example of using {class}`Proxy`\s for {class}`Graph` manipulation +can be found +[here](https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py). + +### The Interpreter Pattern + +A useful code organizational pattern in FX is to loop over all the {class}`Node`\s +in a {class}`Graph` and execute them. This can be used for several things including +runtime analysis of values flowing through the graph or transformation of the code +via retracing with {class}`Proxy`\s. For example, suppose we want to run a +{class}`GraphModule` and record the {class}`torch.Tensor` shape and dtype +properties on the nodes as we see them at runtime. That might look like: + +```python + +import torch +import torch.fx +from torch.fx.node import Node + +from typing import Dict + +class ShapeProp: + """ + Shape propagation. This class takes a `GraphModule`. + Then, its `propagate` method executes the `GraphModule` + node-by-node with the given arguments. As each operation + executes, the ShapeProp class stores away the shape and + element type for the output values of each operation on + the `shape` and `dtype` attributes of the operation's + `Node`. + """ + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env : Dict[str, Node] = {} + + def load_arg(a): + return torch.fx.graph.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target : str): + target_atoms = target.split('.') + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == 'placeholder': + result = next(args_iter) + elif node.op == 'get_attr': + + result = fetch_attr(node.target) + elif node.op == 'call_function': + result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == 'call_method': + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == 'call_module': + result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) + + # This is the only code specific to shape propagation. + # you can delete this `if` branch and this becomes + # a generic GraphModule interpreter. + if isinstance(result, torch.Tensor): + node.shape = result.shape + node.dtype = result.dtype + + env[node.name] = result + + return load_arg(self.graph.result) +``` + +As you can see, a full interpreter for FX is not that complicated +but it can be very useful. To ease using this pattern, we provide +the {class}`Interpreter` class, which encompasses the above logic +in a way that certain aspects of the interpreter's execution can +be overridden via method overrides. + +In addition to executing operations, we can also generate a new +`Graph` by feeding {class}`Proxy` values through an interpreter. +Similarly, we provide the {class}`Transformer` class to encompass +this pattern. {class}`Transformer` behaves similarly to +{class}`Interpreter`, but instead of calling the `run` method to +get a concrete output value from the Module, you would call the +{meth}`Transformer.transform` method to return a new +{class}`GraphModule` which was subject to any transformation rules +you installed as overridden methods. + +#### Examples of the Interpreter Pattern + +- [ShapePropagation](https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py) +- [Performance Profiler](https://github.com/pytorch/tutorials/pull/1319) + + +## Debugging + +### Introduction + +Often in the course of authoring transformations, our code will not be quite right. +In this case, we may need to do some debugging. The key is to work +backwards: first, check the results of invoking the generated module to prove or +disprove correctness. Then, inspect and debug the generated code. Then, debug the +process of transformations that led to the generated code. + +If you’re not familiar with debuggers, please see the auxiliary section +{ref}`Available-Debuggers`. + + +### Common Pitfalls in Transform Authoring + +* Nondeterministic `set` iteration order. In Python, the `set` datatype is + unordered. Using `set` to contain collections of objects like `Node`\ s, + for example, can cause unexpected nondeterminism. An example is iterating + over a set of `Node` s to insert them into a `Graph`. Because the + `set` data type is unordered, the ordering of the operations in the output + program will be nondeterministic and can change across program invocations. + The recommended alternative is to use a `dict` data type, which is + [insertion ordered](https://mail.python.org/pipermail/python-dev/2017-December/151283.html) + as of Python 3.7 (and as of cPython 3.6). A `dict` can be used equivalently + to a set by storing values to be deduplicated in the keys of the `dict`. + +### Checking Correctness of Modules + +Because the output of most deep learning modules consists of floating +point {class}`torch.Tensor` instances, checking for equivalence between +the results of two {class}`torch.nn.Module` is not as straightforward +as doing a simple equality check. To motivate this, let's use an +example: + +```python + +import torch +import torch.fx +import torchvision.models as models + +def transform(m : torch.nn.Module) -> torch.nn.Module: + gm = torch.fx.symbolic_trace(m) + + # Imagine we're doing some transforms here + # <...> + + gm.recompile() + + return gm + +resnet18 = models.resnet18() +transformed_resnet18 = transform(resnet18) + +input_image = torch.randn(5, 3, 224, 224) + +assert resnet18(input_image) == transformed_resnet18(input_image) +""" +RuntimeError: Boolean value of Tensor with more than one value is ambiguous +""" +``` + +Here, we've tried to check equality of the values of two deep learning +models with the `==` equality operator. However, this is not well-\ +defined both due to the issue of that operator returning a tensor +and not a bool, but also because comparison of floating point values +should use a margin of error (or epsilon) to account for the +non-commutativity of floating point operations (see +[here](https://floating-point-gui.de/errors/comparison/) for more +details). We can use {func}`torch.allclose` instead, which will give +us an approximate comparison taking into account a relative and +absolute tolerance threshold: + +```python +assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image)) +``` +This is the first tool in our toolbox to check if transformed modules are +behaving as we expect compared to a reference implementation. + +### Debugging the Generated Code + +Because FX generates the `forward()` function on {class}`GraphModule`\s, using +traditional debugging techniques like `print` statements or `pdb` is +not as straightforward. Luckily, we have several techniques we can use +for debugging the generated code. + +#### Use `pdb` +Invoke `pdb` to step into the running program. Although the code that +represents the {class}`Graph` is not in any source file, we can still step +into it manually using `pdb` when the forward pass is invoked. + +```python + +import torch +import torch.fx +import torchvision.models as models + +def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: + graph = tracer_class().trace(inp) + # Transformation logic here + # <...> + + # Return new Module + return fx.GraphModule(inp, graph) + +my_module = models.resnet18() +my_module_transformed = my_pass(my_module) + +input_value = torch.randn(5, 3, 224, 224) + +# When this line is executed at runtime, we will be dropped into an +# interactive `pdb` prompt. We can use the `step` or `s` command to +# step into the execution of the next line +import pdb; pdb.set_trace() + +my_module_transformed(input_value) +``` +(Print the Generated Code)= + +#### Print the Generated Code +If you’d like to run the same code multiple times, then it can be +a bit tedious to step to the right code with `pdb`. In that case, one +approach is to simply copy-paste the generated `forward` pass into +your code and examine it from there. + +```python + +# Assume that `traced` is a GraphModule that has undergone some +# number of transforms + +# Copy this code for later +print(traced) +# Print the code generated from symbolic tracing. This outputs: +""" +def forward(self, y): + x = self.x + add_1 = x + y; x = y = None + return add_1 +""" + +# Subclass the original Module +class SubclassM(M): + def __init__(self): + super().__init__() + + # Paste the generated `forward` function (the one we printed and + # copied above) here + def forward(self, y): + x = self.x + add_1 = x + y; x = y = None + return add_1 + +# Create an instance of the original, untraced Module. Then, create an +# instance of the Module with the copied `forward` function. We can +# now compare the output of both the original and the traced version. +pre_trace = M() +post_trace = SubclassM() +``` +#### Use the `to_folder` Function From `GraphModule` +{meth}`GraphModule.to_folder` is a method in `GraphModule` that allows +you to dump out the generated FX code to a folder. Although copying the +forward pass into the code often suffices as in {ref}`Print the Generated Code`, +it may be easier to examine modules and parameters using `to_folder`. + +```python + +m = symbolic_trace(M()) +m.to_folder("foo", "Bar") +from foo import Bar +y = Bar() +``` +After running the above example, we can then look at the code within +`foo/module.py` and modify it as desired (e.g. adding `print` +statements or using `pdb`) to debug the generated code. + +### Debugging the Transformation + +Now that we've identified that a transformation is creating incorrect +code, it's time to debug the transformation itself. First, we'll check +the {ref}`Limitations of Symbolic Tracing` section in the documentation. +Once we verify that tracing is working as expected, the goal +becomes figuring out what went wrong during our `GraphModule` +transformation. There may be a quick answer in +{ref}`Writing Transformations`, but, if not, there are several ways to +examine our traced module: + +```python + +# Sample Module +class M(torch.nn.Module): + def forward(self, x, y): + return x + y + +# Create an instance of `M` +m = M() + +# Symbolically trace an instance of `M` (returns a GraphModule). In +# this example, we'll only be discussing how to inspect a +# GraphModule, so we aren't showing any sample transforms for the +# sake of brevity. +traced = symbolic_trace(m) + +# Print the code produced by tracing the module. +print(traced) +# The generated `forward` function is: +""" +def forward(self, x, y): + add = x + y; x = y = None + return add +""" + +# Print the internal Graph. +print(traced.graph) +# This print-out returns: +""" +graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {}) + return add +""" + +# Print a tabular representation of the internal Graph. +traced.graph.print_tabular() +# This gives us: +""" +opcode name target args kwargs +------------- ------ ----------------------- ------ -------- +placeholder x x () {} +placeholder y y () {} +call_function add (x, y) {} +output output output (add,) {} +""" +``` +Using the utility functions above, we can compare our traced Module +before and after we've applied our transformations. Sometimes, a +simple visual comparison is enough to trace down a bug. If it's still +not clear what's going wrong, a debugger like `pdb` can be a good +next step. + +Going off of the example above, consider the following code: + +```python + +# Sample user-defined function +def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: + # Get the Graph from our traced Module + g = tracer_class().trace(module) + + """ + Transformations on `g` go here + """ + + return fx.GraphModule(module, g) + +# Transform the Graph +transformed = transform_graph(traced) + +# Print the new code after our transforms. Check to see if it was +# what we expected +print(transformed) +``` +Using the above example, let’s say that the call to `print(traced)` +showed us that there was an error in our transforms. We want to find +what goes wrong using a debugger. We start a `pdb` session. We can see +what’s happening during the transform by breaking on +`transform_graph(traced)`, then pressing `s` to “step into” the call +to `transform_graph(traced)`. + +We may also have good luck by editing the `print_tabular` method to print +different attributes of the Nodes in the Graph. (For example, we might +want to see the Node’s `input_nodes` and `users`.) + +(Available-Debuggers)= + +### Available Debuggers + +The most common Python debugger is +[pdb](https://docs.python.org/3/library/pdb.html). You can start +your program in “debug mode” with `pdb` by typing +`python -m pdb FILENAME.py` into the command line, where `FILENAME` +is the name of the file you want to debug. After that, you can use the +`pdb` [debugger commands](https://docs.python.org/3/library/pdb.html#debugger-commands) +to move through your running program stepwise. It’s common to set a +breakpoint (`b LINE-NUMBER`) when you start `pdb`, then call `c` to +run the program until that point. This prevents you from having to step +through each line of execution (using `s` or `n`) to get to the part +of the code you want to examine. Alternatively, you can write +`import pdb; pdb.set_trace()` before the line you want to break at. +If you add `pdb.set_trace()`, your program will automatically start +in debug mode when you run it. (In other words, you can just type +`python FILENAME.py` into the command line instead of +`python -m pdb FILENAME.py`.) Once you're running your file in +debug mode, you can step through the code and examine your program's +internal state using certain commands. There are many excellent +tutorials on `pdb` online, including RealPython’s +[“Python Debugging With Pdb”](https://realpython.com/python-debugging-pdb/). + +IDEs like PyCharm or VSCode usually have a debugger built in. In your +IDE, you can choose to either a) use `pdb` by pulling up a terminal +window in your IDE (e.g. View → Terminal in VSCode), or b) use the +built-in debugger (usually a graphical wrapper around `pdb`). + +(Limitations of Symbolic Tracing)= + +## Limitations of Symbolic Tracing + +FX uses a system of **symbolic tracing** (a.k.a [symbolic +execution](https://en.wikipedia.org/wiki/Symbolic_execution)) +to capture the semantics of programs in a transformable/analyzable form. +The system is **tracing** in that it executes the program (really a +{class}`torch.nn.Module` or function) to record operations. It is +**symbolic** in that the data flowing through the program during this +execution is not real data, but rather symbols ({class}`Proxy` in FX parlance). + +Although symbolic tracing works for most neural net code, it has some +limitations. + +### Dynamic Control Flow + +The main limitation of symbolic tracing is it does not currently support +*dynamic control flow*. That is, loops or `if` statements where the +condition may depend on the input values of the program. + +For example, let’s examine the following program: + +```python + +def func_to_trace(x): + if x.sum() > 0: + return torch.relu(x) + else: + + return torch.neg(x) + +traced = torch.fx.symbolic_trace(func_to_trace) +""" + <...> + File "dyn.py", line 6, in func_to_trace + if x.sum() > 0: + File "pytorch/torch/fx/proxy.py", line 155, in __bool__ + return self.tracer.to_bool(self) + File "pytorch/torch/fx/proxy.py", line 85, in to_bool + raise TraceError('symbolically traced variables cannot be used as inputs to control flow') +torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow +""" +``` +The condition to the `if` statement relies on the value of `x.sum()`, +which relies on the value of `x`, a function input. Since +`x` can change (i.e. if you pass a new input tensor to the traced +function), this is *dynamic control flow*. The traceback walks back up +through your code to show you where this situation happens. + +### Static Control Flow + +On the other hand, so-called *static control flow* is supported. Static +control flow is loops or `if` statements whose value cannot change +across invocations. Typically, in PyTorch programs, this control flow +arises for code making decisions about a model’s architecture based on +hyper-parameters. As a concrete example: + +```python + +import torch +import torch.fx + +class MyModule(torch.nn.Module): + def __init__(self, do_activation : bool = False): + super().__init__() + self.do_activation = do_activation + self.linear = torch.nn.Linear(512, 512) + + def forward(self, x): + x = self.linear(x) + # This if-statement is so-called static control flow. + # Its condition does not depend on any input values + if self.do_activation: + x = torch.relu(x) + return x + +without_activation = MyModule(do_activation=False) +with_activation = MyModule(do_activation=True) + +traced_without_activation = torch.fx.symbolic_trace(without_activation) +print(traced_without_activation.code) +""" +def forward(self, x): + linear_1 = self.linear(x); x = None + return linear_1 +""" + +traced_with_activation = torch.fx.symbolic_trace(with_activation) +print(traced_with_activation.code) +""" +import torch +def forward(self, x): + linear_1 = self.linear(x); x = None + relu_1 = torch.relu(linear_1); linear_1 = None + return relu_1 +""" +``` +The if-statement `if self.do_activation` does not depend on any +function inputs, thus it is static. `do_activation` can be considered +to be a hyper-parameter, and the traces of different instances of +`MyModule` with different values for that parameter have different +code. This is a valid pattern that is supported by symbolic tracing. + +Many instances of dynamic control flow are semantically static control +flow. These instances can be made to support symbolic tracing by +removing the data dependencies on input values, for example by moving +values to `Module` attributes or by binding concrete values to arguments +during symbolic tracing: + +```python + +def f(x, flag): + if flag: return x + else: return x*2 + +fx.symbolic_trace(f) # Fails! + +fx.symbolic_trace(f, concrete_args={'flag': True}) +``` +In the case of truly dynamic control flow, the sections of the program +that contain this code can be traced as calls to the Method (see +{ref}`Customizing Tracing`) or function (see +{func}`wrap`) rather than tracing through them. + +### Non- `torch` Functions + +FX uses `__torch_function__` as the mechanism by which it intercepts +calls (see the [technical +overview](https://github.com/pytorch/pytorch/blob/main/torch/fx/README.md#technical-details) +for more information about this). Some functions, such as builtin Python +functions or those in the `math` module, are not covered by +`__torch_function__`, but we would still like to capture them in +symbolic tracing. For example: + +```python + +import torch +import torch.fx +from math import sqrt + +def normalize(x): + """ + Normalize `x` by the size of the batch dimension + """ + return x / sqrt(len(x)) + +# It's valid Python code +normalize(torch.rand(3, 4)) + +traced = torch.fx.symbolic_trace(normalize) +""" + <...> + File "sqrt.py", line 9, in normalize + return x / sqrt(len(x)) + File "pytorch/torch/fx/proxy.py", line 161, in __len__ + raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " +RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope +""" +``` +The error tells us that the built-in function `len` is not supported. +We can make it so that functions like this are recorded in the trace as +direct calls using the {func}`wrap` API: + +```python + +torch.fx.wrap('len') +torch.fx.wrap('sqrt') + +traced = torch.fx.symbolic_trace(normalize) + +print(traced.code) +""" +import math +def forward(self, x): + len_1 = len(x) + sqrt_1 = math.sqrt(len_1); len_1 = None + truediv = x / sqrt_1; x = sqrt_1 = None + return truediv +""" +``` +(Customizing Tracing)= + +### Customizing Tracing with the `Tracer` class + +The {class}`Tracer` class is the class that underlies the +implementation of `symbolic_trace`. The behavior of tracing can be +customized by subclassing Tracer, like so: + +```python + +class MyCustomTracer(torch.fx.Tracer): + # Inside here you can override various methods + # to customize tracing. See the `Tracer` API + # reference + pass + + +# Let's use this custom tracer to trace through this module +class MyModule(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + torch.ones(3, 4) + +mod = MyModule() + +traced_graph = MyCustomTracer().trace(mod) +# trace() returns a Graph. Let's wrap it up in a +# GraphModule to make it runnable +traced = torch.fx.GraphModule(mod, traced_graph) +``` +## Leaf Modules + +Leaf Modules are the modules that appear as calls in the symbolic trace +rather than being traced through. The default set of leaf modules is the +set of standard `torch.nn` module instances. For example: + +```python + +class MySpecialSubmodule(torch.nn.Module): + def forward(self, x): + return torch.neg(x) + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 4) + self.submod = MySpecialSubmodule() + + def forward(self, x): + return self.submod(self.linear(x)) + +traced = torch.fx.symbolic_trace(MyModule()) +print(traced.code) +# `linear` is preserved as a call, yet `submod` is traced though. +# This is because the default set of "Leaf Modules" includes all +# standard `torch.nn` modules. +""" +import torch +def forward(self, x): + linear_1 = self.linear(x); x = None + neg_1 = torch.neg(linear_1); linear_1 = None + return neg_1 +""" +``` +The set of leaf modules can be customized by overriding +{meth}`Tracer.is_leaf_module`. + +### Miscellanea + +- Tensor constructors (e.g. `torch.zeros`, `torch.ones`, + `torch.rand`, `torch.randn`, `torch.sparse_coo_tensor`) + are currently not traceable. + + - The deterministic constructors (`zeros`, `ones`) can be used + and the value they produce will be embedded in the trace as a + constant. This is only problematic if the arguments to these + constructors refers to dynamic input sizes. In this case, + `ones_like` or `zeros_like` may be a viable substitute. + - Nondeterministic constructors (`rand`, `randn`) will have a + single random value embedded in the trace. This is likely not the + intended behavior. One workaround is to wrap `torch.randn` in a `torch.fx.wrap` function and call that instead. + + ```python + + @torch.fx.wrap + def torch_randn(x, shape): + return torch.randn(shape) + + def f(x): + return x + torch_randn(x, 5) + fx.symbolic_trace(f) + ``` + - This behavior may be fixed in a future release. + +- Type annotations + + - Python 3-style type annotations (e.g. + `func(x : torch.Tensor, y : int) -> torch.Tensor`) are supported + and will be preserved by symbolic tracing. + - Python 2-style comment type annotations + `# type: (torch.Tensor, int) -> torch.Tensor` are not currently + supported. + - Annotations on local names within a function are not currently + supported. + + +- Gotcha around `training` flag and submodules + + - When using functionals like `torch.nn.functional.dropout`, it will be common for the training argument to be passed in as `self.training`. During FX tracing, this will likely be baked in as a constant value. + + ```python + + import torch + import torch.fx + + class DropoutRepro(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.dropout(x, training=self.training) + + + traced = torch.fx.symbolic_trace(DropoutRepro()) + print(traced.code) + """ + def forward(self, x): + dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None + return dropout + """ + + traced.eval() + + x = torch.randn(5, 3) + torch.testing.assert_close(traced(x), x) + """ + AssertionError: Tensor-likes are not close! + + Mismatched elements: 15 / 15 (100.0%) + Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) + Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) + """ + ``` + - However, when the standard `nn.Dropout()` submodule is used, the training flag is encapsulated and--because of the preservation of the `nn.Module` object model--can be changed. + + ```python + + class DropoutRepro2(torch.nn.Module): + def __init__(self): + super().__init__() + self.drop = torch.nn.Dropout() + + def forward(self, x): + return self.drop(x) + + traced = torch.fx.symbolic_trace(DropoutRepro2()) + print(traced.code) + """ + def forward(self, x): + drop = self.drop(x); x = None + return drop + """ + + traced.eval() + + x = torch.randn(5, 3) + torch.testing.assert_close(traced(x), x) + ``` + - Because of this difference, consider marking modules that interact with the `training` flag dynamically as leaf modules. + + +## API Reference +```{eval-rst} +.. autofunction:: torch.fx.symbolic_trace +``` +```{eval-rst} +.. autofunction:: torch.fx.wrap +``` +```{eval-rst} +.. autoclass:: torch.fx.GraphModule + :members: + + .. automethod:: __init__ +``` +```{eval-rst} +.. autoclass:: torch.fx.Graph + :members: + + .. automethod:: __init__ +``` +```{eval-rst} +.. autoclass:: torch.fx.Node + :members: +``` +```{eval-rst} +.. autoclass:: torch.fx.Tracer + :members: + :inherited-members: +``` +```{eval-rst} +.. autoclass:: torch.fx.Proxy +``` +```{eval-rst} +.. autoclass:: torch.fx.Interpreter + :members: +``` +```{eval-rst} +.. autoclass:: torch.fx.Transformer + :members: +``` +```{eval-rst} +.. autofunction:: torch.fx.replace_pattern +``` + + + + +```{eval-rst} +.. py:module:: torch.fx.passes +.. py:module:: torch.fx.passes.infra +.. py:module:: torch.fx.passes.backends +.. py:module:: torch.fx.passes.utils +.. py:module:: torch.fx.passes.tests +.. py:module:: torch.fx.experimental +.. py:module:: torch.fx.experimental.unification +.. py:module:: torch.fx.experimental.unification.multipledispatch +.. py:module:: torch.fx.experimental.migrate_gradual_types +.. py:module:: torch.fx.passes.dialect +.. py:module:: torch.fx.passes.dialect.common +.. py:module:: torch.fx.annotate +.. py:module:: torch.fx.config +.. py:module:: torch.fx.experimental.accelerator_partitioner +.. py:module:: torch.fx.experimental.const_fold +.. py:module:: torch.fx.experimental.debug +.. py:module:: torch.fx.experimental.graph_gradual_typechecker +.. py:module:: torch.fx.experimental.merge_matmul +.. py:module:: torch.fx.experimental.meta_tracer +.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint +.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_generator +.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_transformation +.. py:module:: torch.fx.experimental.migrate_gradual_types.operation +.. py:module:: torch.fx.experimental.migrate_gradual_types.transform_to_z3 +.. py:module:: torch.fx.experimental.migrate_gradual_types.util +.. py:module:: torch.fx.experimental.migrate_gradual_types.z3_types +.. py:module:: torch.fx.experimental.normalize +.. py:module:: torch.fx.experimental.optimization +.. py:module:: torch.fx.experimental.partitioner_utils +.. py:module:: torch.fx.experimental.recording +.. py:module:: torch.fx.experimental.refinement_types +.. py:module:: torch.fx.experimental.rewriter +.. py:module:: torch.fx.experimental.schema_type_annotation +.. py:module:: torch.fx.experimental.sym_node +.. py:module:: torch.fx.experimental.unification.core +.. py:module:: torch.fx.experimental.unification.dispatch +.. py:module:: torch.fx.experimental.unification.match +.. py:module:: torch.fx.experimental.unification.more +.. py:module:: torch.fx.experimental.unification.multipledispatch.conflict +.. py:module:: torch.fx.experimental.unification.multipledispatch.core +.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher +.. py:module:: torch.fx.experimental.unification.multipledispatch.utils +.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic +.. py:module:: torch.fx.experimental.unification.unification_tools +.. py:module:: torch.fx.experimental.unification.utils +.. py:module:: torch.fx.experimental.unification.variable +.. py:module:: torch.fx.experimental.unify_refinements +.. py:module:: torch.fx.experimental.validator +.. py:module:: torch.fx.graph +.. py:module:: torch.fx.graph_module +.. py:module:: torch.fx.immutable_collections +.. py:module:: torch.fx.interpreter +.. py:module:: torch.fx.node +.. py:module:: torch.fx.operator_schemas +.. py:module:: torch.fx.passes.annotate_getitem_nodes +.. py:module:: torch.fx.passes.backends.cudagraphs +.. py:module:: torch.fx.passes.dialect.common.cse_pass +.. py:module:: torch.fx.passes.fake_tensor_prop +.. py:module:: torch.fx.passes.graph_drawer +.. py:module:: torch.fx.passes.graph_manipulation +.. py:module:: torch.fx.passes.graph_transform_observer +.. py:module:: torch.fx.passes.infra.partitioner +.. py:module:: torch.fx.passes.infra.pass_base +.. py:module:: torch.fx.passes.infra.pass_manager +.. py:module:: torch.fx.passes.net_min_base +.. py:module:: torch.fx.passes.operator_support +.. py:module:: torch.fx.passes.param_fetch +.. py:module:: torch.fx.passes.pass_manager +.. py:module:: torch.fx.passes.reinplace +.. py:module:: torch.fx.passes.runtime_assert +.. py:module:: torch.fx.passes.shape_prop +.. py:module:: torch.fx.passes.split_module +.. py:module:: torch.fx.passes.split_utils +.. py:module:: torch.fx.passes.splitter_base +.. py:module:: torch.fx.passes.tests.test_pass_manager +.. py:module:: torch.fx.passes.tools_common +.. py:module:: torch.fx.passes.utils.common +.. py:module:: torch.fx.passes.utils.fuser_utils +.. py:module:: torch.fx.passes.utils.matcher_utils +.. py:module:: torch.fx.passes.utils.matcher_with_name_node_map_utils +.. py:module:: torch.fx.passes.utils.source_matcher_utils +.. py:module:: torch.fx.proxy +.. py:module:: torch.fx.subgraph_rewriter +.. py:module:: torch.fx.tensor_type +.. py:module:: torch.fx.traceback +``` diff --git a/docs/source/fx.rst b/docs/source/fx.rst deleted file mode 100644 index 442cf8864a64d3..00000000000000 --- a/docs/source/fx.rst +++ /dev/null @@ -1,1201 +0,0 @@ -.. currentmodule:: torch.fx - -torch.fx -============= - -Overview --------- -.. automodule:: torch.fx - -.. _Writing Transformations: - - -Writing Transformations ------------------------ - -What is an FX transform? Essentially, it's a function that looks like this. - -:: - - import torch - import torch.fx - - def transform(m: nn.Module, - tracer_class : type = torch.fx.Tracer) -> torch.nn.Module: - # Step 1: Acquire a Graph representing the code in `m` - - # NOTE: torch.fx.symbolic_trace is a wrapper around a call to - # fx.Tracer.trace and constructing a GraphModule. We'll - # split that out in our transform to allow the caller to - # customize tracing behavior. - graph : torch.fx.Graph = tracer_class().trace(m) - - # Step 2: Modify this Graph or create a new one - graph = ... - - # Step 3: Construct a Module to return - return torch.fx.GraphModule(m, graph) - -Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph` -from it, do some modifications, and return a new -:class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX -transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another -FX transform, you can pass it to TorchScript, or you can -run it. Ensuring that the inputs and outputs of your FX transform are a -:class:`torch.nn.Module` will allow for composability. - -.. note:: - - It is also possible to modify an existing :class:`GraphModule` instead of - creating a new one, like so:: - - import torch - import torch.fx - - def transform(m : nn.Module) -> nn.Module: - gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m) - - # Modify gm.graph - # <...> - - # Recompile the forward() method of `gm` from its Graph - gm.recompile() - - return gm - - Note that you MUST call :meth:`GraphModule.recompile` to bring the generated - ``forward()`` method on the ``GraphModule`` in sync with the modified :class:`Graph`. - -Given that you’ve passed in a :class:`torch.nn.Module` that has been traced into a -:class:`Graph`, there are now two primary approaches you can take to building a new -:class:`Graph`. - -A Quick Primer on Graphs -^^^^^^^^^^^^^^^^^^^^^^^^ - -Full treatment of the semantics of graphs can be found in the :class:`Graph` -documentation, but we are going to cover the basics here. A :class:`Graph` is -a data structure that represents a method on a :class:`GraphModule`. The -information that this requires is: - -- What are the inputs to the method? -- What are the operations that run inside the method? -- What is the output (i.e. return) value from the method? - -All three of these concepts are represented with :class:`Node` instances. -Let's see what we mean by that with a short example: - -:: - - import torch - import torch.fx - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return torch.topk(torch.sum( - self.linear(x + self.linear.weight).relu(), dim=-1), 3) - - m = MyModule() - gm = torch.fx.symbolic_trace(m) - - gm.graph.print_tabular() - -Here we define a module ``MyModule`` for demonstration purposes, instantiate it, -symbolically trace it, then call the :meth:`Graph.print_tabular` method to print -out a table showing the nodes of this :class:`Graph`: - - +---------------+---------------+----------------------------+--------------------+-------------+ - | opcode | name | target | args | kwargs | - +===============+===============+============================+====================+=============+ - | placeholder | x | x | () | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | get_attr | linear_weight | linear.weight | () | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | add_1 | | (x, linear_weight) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_module | linear_1 | linear | (add_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_method | relu_1 | relu | (linear_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | sum_1 | | (relu_1,) | {'dim': -1} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | call_function | topk_1 | | (sum_1, 3) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - | output | output | output | (topk_1,) | {} | - +---------------+---------------+----------------------------+--------------------+-------------+ - -We can use this information to answer the questions we posed above. - -- What are the inputs to the method? In FX, method inputs are specified - via special ``placeholder`` nodes. In this case, we have a single - ``placeholder`` node with a ``target`` of ``x``, meaning we have - a single (non-self) argument named x. -- What are the operations within the method? The ``get_attr``, - ``call_function``, ``call_module``, and ``call_method`` nodes - represent the operations in the method. A full treatment of - the semantics of all of these can be found in the :class:`Node` - documentation. -- What is the return value of the method? The return value in a - :class:`Graph` is specified by a special ``output`` node. - -Given that we now know the basics of how code is represented in -FX, we can now explore how we would edit a :class:`Graph`. - -Graph Manipulation -^^^^^^^^^^^^^^^^^^ - -Direct Graph Manipulation -~~~~~~~~~~~~~~~~~~~~~~~~~ - -One approach to building this new :class:`Graph` is to directly manipulate your old -one. To aid in this, we can simply take the :class:`Graph` we obtain from symbolic -tracing and modify it. For example, let’s say we desire to replace -:func:`torch.add` calls with :func:`torch.mul` calls. - -:: - - import torch - import torch.fx - - # Sample module - class M(torch.nn.Module): - def forward(self, x, y): - return torch.add(x, y) - - def transform(m: torch.nn.Module, - tracer_class : type = fx.Tracer) -> torch.nn.Module: - graph : fx.Graph = tracer_class().trace(m) - # FX represents its Graph as an ordered list of - # nodes, so we can iterate through them. - for node in graph.nodes: - # Checks if we're calling a function (i.e: - # torch.add) - if node.op == 'call_function': - # The target attribute is the function - # that call_function calls. - if node.target == torch.add: - node.target = torch.mul - - graph.lint() # Does some checks to make sure the - # Graph is well-formed. - - return fx.GraphModule(m, graph) - - -We can also do more involved :class:`Graph` rewrites, such as -deleting or appending nodes. To aid in these transformations, -FX has utility functions for transforming the graph that can -be found in the :class:`Graph` documentation. An -example of using these APIs to append a :func:`torch.relu` call -can be found below. - -:: - - # Specifies the insertion point. Any nodes added to the - # Graph within this scope will be inserted after `node` - with traced.graph.inserting_after(node): - # Insert a new `call_function` node calling `torch.relu` - new_node = traced.graph.call_function( - torch.relu, args=(node,)) - - # We want all places that used the value of `node` to - # now use that value after the `relu` call we've added. - # We use the `replace_all_uses_with` API to do this. - node.replace_all_uses_with(new_node) - -For simple transformations that only consist of substitutions, you can also -make use of the `subgraph rewriter. `__ - -Subgraph Rewriting With replace_pattern() -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -FX also provides another level of automation on top of direct graph manipulation. -The :func:`replace_pattern` API is essentially a "find/replace" tool for editing -:class:`Graph`\s. It allows you to specify a ``pattern`` and ``replacement`` function -and it will trace through those functions, find instances of the group of operations -in the ``pattern`` graph, and replace those instances with copies of the ``replacement`` -graph. This can help to greatly automate tedious graph manipulation code, which can -get unwieldy as the transformations get more complex. - -Graph Manipulation Examples -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -- `Replace one - op `__ -- `Conv/Batch Norm - fusion `__ -- `replace_pattern: Basic usage `__ -- `Quantization `__ -- `Invert Transformation `__ - -Proxy/Retracing -^^^^^^^^^^^^^^^ - -Another way of manipulating :class:`Graph`\s is by reusing the :class:`Proxy` -machinery used in symbolic tracing. For example, let’s -imagine that we wanted to write a transformation that decomposed -PyTorch functions into smaller operations. It would transform every -``F.relu(x)`` call into ``(x > 0) * x``. One possibility would be to -perform the requisite graph rewriting to insert the comparison and -multiplication after the ``F.relu``, and then clean up the original -``F.relu``. However, we can automate this process by using :class:`Proxy` -objects to automatically record operations into the :class:`Graph`. - -To use this method, we write the operations that we want inserted as regular -PyTorch code and invoke that code with :class:`Proxy` objects as arguments. -These :class:`Proxy` objects will capture the operations that are performed -on them and append them to the :class:`Graph`. - -:: - - # Note that this decomposition rule can be read as regular Python - def relu_decomposition(x): - return (x > 0) * x - - decomposition_rules = {} - decomposition_rules[F.relu] = relu_decomposition - - def decompose(model: torch.nn.Module, - tracer_class : type = fx.Tracer) -> torch.nn.Module: - """ - Decompose `model` into smaller constituent operations. - Currently,this only supports decomposing ReLU into its - mathematical definition: (x > 0) * x - """ - graph : fx.Graph = tracer_class().trace(model) - new_graph = fx.Graph() - env = {} - tracer = torch.fx.proxy.GraphAppendingTracer(new_graph) - for node in graph.nodes: - if node.op == 'call_function' and node.target in decomposition_rules: - # By wrapping the arguments with proxies, - # we can dispatch to the appropriate - # decomposition rule and implicitly add it - # to the Graph by symbolically tracing it. - proxy_args = [ - fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args] - output_proxy = decomposition_rules[node.target](*proxy_args) - - # Operations on `Proxy` always yield new `Proxy`s, and the - # return value of our decomposition rule is no exception. - # We need to extract the underlying `Node` from the `Proxy` - # to use it in subsequent iterations of this transform. - new_node = output_proxy.node - env[node.name] = new_node - else: - # Default case: we don't have a decomposition rule for this - # node, so just copy the node over into the new graph. - new_node = new_graph.node_copy(node, lambda x: env[x.name]) - env[node.name] = new_node - return fx.GraphModule(model, new_graph) - -In addition to avoiding explicit graph manipulation, using :class:`Proxy`\s -also allows you to specify your rewrite rules as native Python code. -For transformations that require a large amount of rewrite rules -(such as vmap or grad), this can often improve readability and -maintainability of the rules. Note that while calling :class:`Proxy` we also -passed a tracer pointing to the underlying variable `graph`. This is done so -if in case the operations in graph are n-ary (e.g. add is a binary operator) -the call to :class:`Proxy` does not create multiple instances of a graph -tracer which can lead to unexpected runtime errors. We recommend this method -of using :class:`Proxy` especially when the underlying operators can not be -safely assumed to be unary. - -A worked example of using :class:`Proxy`\s for :class:`Graph` manipulation -can be found -`here `__. - -The Interpreter Pattern -^^^^^^^^^^^^^^^^^^^^^^^ - -A useful code organizational pattern in FX is to loop over all the :class:`Node`\s -in a :class:`Graph` and execute them. This can be used for several things including -runtime analysis of values flowing through the graph or transformation of the code -via retracing with :class:`Proxy`\s. For example, suppose we want to run a -:class:`GraphModule` and record the :class:`torch.Tensor` shape and dtype -properties on the nodes as we see them at runtime. That might look like: - -:: - - import torch - import torch.fx - from torch.fx.node import Node - - from typing import Dict - - class ShapeProp: - """ - Shape propagation. This class takes a `GraphModule`. - Then, its `propagate` method executes the `GraphModule` - node-by-node with the given arguments. As each operation - executes, the ShapeProp class stores away the shape and - element type for the output values of each operation on - the `shape` and `dtype` attributes of the operation's - `Node`. - """ - def __init__(self, mod): - self.mod = mod - self.graph = mod.graph - self.modules = dict(self.mod.named_modules()) - - def propagate(self, *args): - args_iter = iter(args) - env : Dict[str, Node] = {} - - def load_arg(a): - return torch.fx.graph.map_arg(a, lambda n: env[n.name]) - - def fetch_attr(target : str): - target_atoms = target.split('.') - attr_itr = self.mod - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") - attr_itr = getattr(attr_itr, atom) - return attr_itr - - for node in self.graph.nodes: - if node.op == 'placeholder': - result = next(args_iter) - elif node.op == 'get_attr': - result = fetch_attr(node.target) - elif node.op == 'call_function': - result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) - elif node.op == 'call_method': - self_obj, *args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = getattr(self_obj, node.target)(*args, **kwargs) - elif node.op == 'call_module': - result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) - - # This is the only code specific to shape propagation. - # you can delete this `if` branch and this becomes - # a generic GraphModule interpreter. - if isinstance(result, torch.Tensor): - node.shape = result.shape - node.dtype = result.dtype - - env[node.name] = result - - return load_arg(self.graph.result) - -As you can see, a full interpreter for FX is not that complicated -but it can be very useful. To ease using this pattern, we provide -the :class:`Interpreter` class, which encompasses the above logic -in a way that certain aspects of the interpreter's execution can -be overridden via method overrides. - -In addition to executing operations, we can also generate a new -`Graph` by feeding :class:`Proxy` values through an interpreter. -Similarly, we provide the :class:`Transformer` class to encompass -this pattern. :class:`Transformer` behaves similarly to -:class:`Interpreter`, but instead of calling the ``run`` method to -get a concrete output value from the Module, you would call the -:meth:`Transformer.transform` method to return a new -:class:`GraphModule` which was subject to any transformation rules -you installed as overridden methods. - -Examples of the Interpreter Pattern -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -- `Shape - Propagation `__ -- `Performance Profiler `__ - - -Debugging ------------ - -Introduction -^^^^^^^^^^^^^^^^ - -Often in the course of authoring transformations, our code will not be quite right. -In this case, we may need to do some debugging. The key is to work -backwards: first, check the results of invoking the generated module to prove or -disprove correctness. Then, inspect and debug the generated code. Then, debug the -process of transformations that led to the generated code. - -If you’re not familiar with debuggers, please see the auxiliary section -:ref:`Available Debuggers`. - - -Common Pitfalls in Transform Authoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -* Nondeterministic ``set`` iteration order. In Python, the ``set`` datatype is - unordered. Using ``set`` to contain collections of objects like ``Node``\ s, - for example, can cause unexpected nondeterminism. An example is iterating - over a set of ``Node``\ s to insert them into a ``Graph``. Because the - ``set`` data type is unordered, the ordering of the operations in the output - program will be nondeterministic and can change across program invocations. - The recommended alternative is to use a ``dict`` data type, which is - `insertion ordered `_ - as of Python 3.7 (and as of cPython 3.6). A ``dict`` can be used equivalently - to a set by storing values to be deduplicated in the keys of the ``dict``. - -Checking Correctness of Modules -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Because the output of most deep learning modules consists of floating -point :class:`torch.Tensor` instances, checking for equivalence between -the results of two :class:`torch.nn.Module` is not as straightforward -as doing a simple equality check. To motivate this, let's use an -example: - -:: - - import torch - import torch.fx - import torchvision.models as models - - def transform(m : torch.nn.Module) -> torch.nn.Module: - gm = torch.fx.symbolic_trace(m) - - # Imagine we're doing some transforms here - # <...> - - gm.recompile() - - return gm - - resnet18 = models.resnet18() - transformed_resnet18 = transform(resnet18) - - input_image = torch.randn(5, 3, 224, 224) - - assert resnet18(input_image) == transformed_resnet18(input_image) - """ - RuntimeError: Boolean value of Tensor with more than one value is ambiguous - """ - -Here, we've tried to check equality of the values of two deep learning -models with the ``==`` equality operator. However, this is not well- -defined both due to the issue of that operator returning a tensor -and not a bool, but also because comparison of floating point values -should use a margin of error (or epsilon) to account for the -non-commutativity of floating point operations (see -`here `__ for more -details). We can use :func:`torch.allclose` instead, which will give -us an approximate comparison taking into account a relative and -absolute tolerance threshold: - -:: - - assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image)) - -This is the first tool in our toolbox to check if transformed modules are -behaving as we expect compared to a reference implementation. - -Debugging the Generated Code -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Because FX generates the ``forward()`` function on :class:`GraphModule`\s, using -traditional debugging techniques like ``print`` statements or ``pdb`` is -not as straightforward. Luckily, we have several techniques we can use -for debugging the generated code. - -Use ``pdb`` -~~~~~~~~~~~~~ -Invoke ``pdb`` to step into the running program. Although the code that -represents the :class:`Graph` is not in any source file, we can still step -into it manually using ``pdb`` when the forward pass is invoked. - -:: - - import torch - import torch.fx - import torchvision.models as models - - def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: - graph = tracer_class().trace(inp) - # Transformation logic here - # <...> - - # Return new Module - return fx.GraphModule(inp, graph) - - my_module = models.resnet18() - my_module_transformed = my_pass(my_module) - - input_value = torch.randn(5, 3, 224, 224) - - # When this line is executed at runtime, we will be dropped into an - # interactive `pdb` prompt. We can use the `step` or `s` command to - # step into the execution of the next line - import pdb; pdb.set_trace() - - my_module_transformed(input_value) - -.. _Print the Generated Code: - -Print the Generated Code -~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you’d like to run the same code multiple times, then it can be -a bit tedious to step to the right code with ``pdb``. In that case, one -approach is to simply copy-paste the generated ``forward`` pass into -your code and examine it from there. - -:: - - # Assume that `traced` is a GraphModule that has undergone some - # number of transforms - - # Copy this code for later - print(traced) - # Print the code generated from symbolic tracing. This outputs: - """ - def forward(self, y): - x = self.x - add_1 = x + y; x = y = None - return add_1 - """ - - # Subclass the original Module - class SubclassM(M): - def __init__(self): - super().__init__() - - # Paste the generated `forward` function (the one we printed and - # copied above) here - def forward(self, y): - x = self.x - add_1 = x + y; x = y = None - return add_1 - - # Create an instance of the original, untraced Module. Then, create an - # instance of the Module with the copied `forward` function. We can - # now compare the output of both the original and the traced version. - pre_trace = M() - post_trace = SubclassM() - -Use the ``to_folder`` Function From ``GraphModule`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -:meth:`GraphModule.to_folder` is a method in ``GraphModule`` that allows -you to dump out the generated FX code to a folder. Although copying the -forward pass into the code often suffices as in :ref:`Print the Generated Code`, -it may be easier to examine modules and parameters using ``to_folder``. - -:: - - m = symbolic_trace(M()) - m.to_folder("foo", "Bar") - from foo import Bar - y = Bar() - -After running the above example, we can then look at the code within -``foo/module.py`` and modify it as desired (e.g. adding ``print`` -statements or using ``pdb``) to debug the generated code. - -Debugging the Transformation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Now that we've identified that a transformation is creating incorrect -code, it's time to debug the transformation itself. First, we'll check -the :ref:`Limitations of Symbolic Tracing` section in the documentation. -Once we verify that tracing is working as expected, the goal -becomes figuring out what went wrong during our ``GraphModule`` -transformation. There may be a quick answer in -:ref:`Writing Transformations`, but, if not, there are several ways to -examine our traced module: - -:: - - # Sample Module - class M(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Create an instance of `M` - m = M() - - # Symbolically trace an instance of `M` (returns a GraphModule). In - # this example, we'll only be discussing how to inspect a - # GraphModule, so we aren't showing any sample transforms for the - # sake of brevity. - traced = symbolic_trace(m) - - # Print the code produced by tracing the module. - print(traced) - # The generated `forward` function is: - """ - def forward(self, x, y): - add = x + y; x = y = None - return add - """ - - # Print the internal Graph. - print(traced.graph) - # This print-out returns: - """ - graph(): - %x : [num_users=1] = placeholder[target=x] - %y : [num_users=1] = placeholder[target=y] - %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {}) - return add - """ - - # Print a tabular representation of the internal Graph. - traced.graph.print_tabular() - # This gives us: - """ - opcode name target args kwargs - ------------- ------ ----------------------- ------ -------- - placeholder x x () {} - placeholder y y () {} - call_function add (x, y) {} - output output output (add,) {} - """ - -Using the utility functions above, we can compare our traced Module -before and after we've applied our transformations. Sometimes, a -simple visual comparison is enough to trace down a bug. If it's still -not clear what's going wrong, a debugger like ``pdb`` can be a good -next step. - -Going off of the example above, consider the following code: - -:: - - # Sample user-defined function - def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: - # Get the Graph from our traced Module - g = tracer_class().trace(module) - - """ - Transformations on `g` go here - """ - - return fx.GraphModule(module, g) - - # Transform the Graph - transformed = transform_graph(traced) - - # Print the new code after our transforms. Check to see if it was - # what we expected - print(transformed) - -Using the above example, let’s say that the call to ``print(traced)`` -showed us that there was an error in our transforms. We want to find -what goes wrong using a debugger. We start a ``pdb`` session. We can see -what’s happening during the transform by breaking on -``transform_graph(traced)``, then pressing ``s`` to “step into” the call -to ``transform_graph(traced)``. - -We may also have good luck by editing the ``print_tabular`` method to print -different attributes of the Nodes in the Graph. (For example, we might -want to see the Node’s ``input_nodes`` and ``users``.) - -.. _Available Debuggers: - -Available Debuggers -^^^^^^^^^^^^^^^^^^^^^^ - -The most common Python debugger is -`pdb `__. You can start -your program in “debug mode” with ``pdb`` by typing -``python -m pdb FILENAME.py`` into the command line, where ``FILENAME`` -is the name of the file you want to debug. After that, you can use the -``pdb`` `debugger commands -`__ -to move through your running program stepwise. It’s common to set a -breakpoint (``b LINE-NUMBER``) when you start ``pdb``, then call ``c`` to -run the program until that point. This prevents you from having to step -through each line of execution (using ``s`` or ``n``) to get to the part -of the code you want to examine. Alternatively, you can write -``import pdb; pdb.set_trace()`` before the line you want to break at. -If you add ``pdb.set_trace()``, your program will automatically start -in debug mode when you run it. (In other words, you can just type -``python FILENAME.py`` into the command line instead of -``python -m pdb FILENAME.py``.) Once you're running your file in -debug mode, you can step through the code and examine your program's -internal state using certain commands. There are many excellent -tutorials on ``pdb`` online, including RealPython’s -`“Python Debugging With Pdb” `__. - -IDEs like PyCharm or VSCode usually have a debugger built in. In your -IDE, you can choose to either a) use ``pdb`` by pulling up a terminal -window in your IDE (e.g. View → Terminal in VSCode), or b) use the -built-in debugger (usually a graphical wrapper around ``pdb``). - -.. _Limitations of Symbolic Tracing: - -Limitations of Symbolic Tracing -------------------------------- - -FX uses a system of **symbolic tracing** (a.k.a `symbolic -execution `__) -to capture the semantics of programs in a transformable/analyzable form. -The system is **tracing** in that it executes the program (really a -:class:`torch.nn.Module` or function) to record operations. It is -**symbolic** in that the data flowing through the program during this -execution is not real data, but rather symbols (:class:`Proxy` in FX parlance). - -Although symbolic tracing works for most neural net code, it has some -limitations. - -Dynamic Control Flow -^^^^^^^^^^^^^^^^^^^^ - -The main limitation of symbolic tracing is it does not currently support -*dynamic control flow*. That is, loops or ``if`` statements where the -condition may depend on the input values of the program. - -For example, let’s examine the following program: - -:: - - def func_to_trace(x): - if x.sum() > 0: - return torch.relu(x) - else: - return torch.neg(x) - - traced = torch.fx.symbolic_trace(func_to_trace) - """ - <...> - File "dyn.py", line 6, in func_to_trace - if x.sum() > 0: - File "pytorch/torch/fx/proxy.py", line 155, in __bool__ - return self.tracer.to_bool(self) - File "pytorch/torch/fx/proxy.py", line 85, in to_bool - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') - torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow - """ - -The condition to the ``if`` statement relies on the value of ``x.sum()``, -which relies on the value of ``x``, a function input. Since -``x`` can change (i.e. if you pass a new input tensor to the traced -function), this is *dynamic control flow*. The traceback walks back up -through your code to show you where this situation happens. - -Static Control Flow -~~~~~~~~~~~~~~~~~~~ - -On the other hand, so-called *static control flow* is supported. Static -control flow is loops or ``if`` statements whose value cannot change -across invocations. Typically, in PyTorch programs, this control flow -arises for code making decisions about a model’s architecture based on -hyper-parameters. As a concrete example: - -:: - - import torch - import torch.fx - - class MyModule(torch.nn.Module): - def __init__(self, do_activation : bool = False): - super().__init__() - self.do_activation = do_activation - self.linear = torch.nn.Linear(512, 512) - - def forward(self, x): - x = self.linear(x) - # This if-statement is so-called static control flow. - # Its condition does not depend on any input values - if self.do_activation: - x = torch.relu(x) - return x - - without_activation = MyModule(do_activation=False) - with_activation = MyModule(do_activation=True) - - traced_without_activation = torch.fx.symbolic_trace(without_activation) - print(traced_without_activation.code) - """ - def forward(self, x): - linear_1 = self.linear(x); x = None - return linear_1 - """ - - traced_with_activation = torch.fx.symbolic_trace(with_activation) - print(traced_with_activation.code) - """ - import torch - def forward(self, x): - linear_1 = self.linear(x); x = None - relu_1 = torch.relu(linear_1); linear_1 = None - return relu_1 - """ - -The if-statement ``if self.do_activation`` does not depend on any -function inputs, thus it is static. ``do_activation`` can be considered -to be a hyper-parameter, and the traces of different instances of -``MyModule`` with different values for that parameter have different -code. This is a valid pattern that is supported by symbolic tracing. - -Many instances of dynamic control flow are semantically static control -flow. These instances can be made to support symbolic tracing by -removing the data dependencies on input values, for example by moving -values to ``Module`` attributes or by binding concrete values to arguments -during symbolic tracing: - -:: - - def f(x, flag): - if flag: return x - else: return x*2 - - fx.symbolic_trace(f) # Fails! - - fx.symbolic_trace(f, concrete_args={'flag': True}) - -In the case of truly dynamic control flow, the sections of the program -that contain this code can be traced as calls to the Method (see -:ref:`Customizing Tracing`) or function (see -:func:`wrap`) rather than tracing through them. - -Non-\ ``torch`` Functions -^^^^^^^^^^^^^^^^^^^^^^^^^ - -FX uses ``__torch_function__`` as the mechanism by which it intercepts -calls (see the `technical -overview `__ -for more information about this). Some functions, such as builtin Python -functions or those in the ``math`` module, are not covered by -``__torch_function__``, but we would still like to capture them in -symbolic tracing. For example: - -:: - - import torch - import torch.fx - from math import sqrt - - def normalize(x): - """ - Normalize `x` by the size of the batch dimension - """ - return x / sqrt(len(x)) - - # It's valid Python code - normalize(torch.rand(3, 4)) - - traced = torch.fx.symbolic_trace(normalize) - """ - <...> - File "sqrt.py", line 9, in normalize - return x / sqrt(len(x)) - File "pytorch/torch/fx/proxy.py", line 161, in __len__ - raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " - RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope - """ - -The error tells us that the built-in function ``len`` is not supported. -We can make it so that functions like this are recorded in the trace as -direct calls using the :func:`wrap` API: - -:: - - torch.fx.wrap('len') - torch.fx.wrap('sqrt') - - traced = torch.fx.symbolic_trace(normalize) - - print(traced.code) - """ - import math - def forward(self, x): - len_1 = len(x) - sqrt_1 = math.sqrt(len_1); len_1 = None - truediv = x / sqrt_1; x = sqrt_1 = None - return truediv - """ - -.. _Customizing Tracing: - -Customizing Tracing with the ``Tracer`` class -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :class:`Tracer` class is the class that underlies the -implementation of ``symbolic_trace``. The behavior of tracing can be -customized by subclassing Tracer, like so: - -:: - - class MyCustomTracer(torch.fx.Tracer): - # Inside here you can override various methods - # to customize tracing. See the `Tracer` API - # reference - pass - - - # Let's use this custom tracer to trace through this module - class MyModule(torch.nn.Module): - def forward(self, x): - return torch.relu(x) + torch.ones(3, 4) - - mod = MyModule() - - traced_graph = MyCustomTracer().trace(mod) - # trace() returns a Graph. Let's wrap it up in a - # GraphModule to make it runnable - traced = torch.fx.GraphModule(mod, traced_graph) - -Leaf Modules -~~~~~~~~~~~~ - -Leaf Modules are the modules that appear as calls in the symbolic trace -rather than being traced through. The default set of leaf modules is the -set of standard ``torch.nn`` module instances. For example: - -:: - - class MySpecialSubmodule(torch.nn.Module): - def forward(self, x): - return torch.neg(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 4) - self.submod = MySpecialSubmodule() - - def forward(self, x): - return self.submod(self.linear(x)) - - traced = torch.fx.symbolic_trace(MyModule()) - print(traced.code) - # `linear` is preserved as a call, yet `submod` is traced though. - # This is because the default set of "Leaf Modules" includes all - # standard `torch.nn` modules. - """ - import torch - def forward(self, x): - linear_1 = self.linear(x); x = None - neg_1 = torch.neg(linear_1); linear_1 = None - return neg_1 - """ - -The set of leaf modules can be customized by overriding -:meth:`Tracer.is_leaf_module`. - -Miscellanea -^^^^^^^^^^^ - -- Tensor constructors (e.g. ``torch.zeros``, ``torch.ones``, - ``torch.rand``, ``torch.randn``, ``torch.sparse_coo_tensor``) - are currently not traceable. - - - The deterministic constructors (``zeros``, ``ones``) can be used - and the value they produce will be embedded in the trace as a - constant. This is only problematic if the arguments to these - constructors refers to dynamic input sizes. In this case, - ``ones_like`` or ``zeros_like`` may be a viable substitute. - - Nondeterministic constructors (``rand``, ``randn``) will have a - single random value embedded in the trace. This is likely not the - intended behavior. One workaround is to wrap ``torch.randn`` in a ``torch.fx.wrap`` function and call that instead. - - :: - - @torch.fx.wrap - def torch_randn(x, shape): - return torch.randn(shape) - - def f(x): - return x + torch_randn(x, 5) - fx.symbolic_trace(f) - - - This behavior may be fixed in a future release. - -- Type annotations - - - Python 3-style type annotations (e.g. - ``func(x : torch.Tensor, y : int) -> torch.Tensor``) are supported - and will be preserved by symbolic tracing. - - Python 2-style comment type annotations - ``# type: (torch.Tensor, int) -> torch.Tensor`` are not currently - supported. - - Annotations on local names within a function are not currently - supported. - - -- Gotcha around ``training`` flag and submodules - - - When using functionals like ``torch.nn.functional.dropout``, it will be common for the training argument to be passed in as ``self.training``. During FX tracing, this will likely be baked in as a constant value. - - :: - - import torch - import torch.fx - - class DropoutRepro(torch.nn.Module): - def forward(self, x): - return torch.nn.functional.dropout(x, training=self.training) - - - traced = torch.fx.symbolic_trace(DropoutRepro()) - print(traced.code) - """ - def forward(self, x): - dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None - return dropout - """ - - traced.eval() - - x = torch.randn(5, 3) - torch.testing.assert_close(traced(x), x) - """ - AssertionError: Tensor-likes are not close! - - Mismatched elements: 15 / 15 (100.0%) - Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) - Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) - """ - - - However, when the standard ``nn.Dropout()`` submodule is used, the training flag is encapsulated and--because of the preservation of the ``nn.Module`` object model--can be changed. - - :: - - class DropoutRepro2(torch.nn.Module): - def __init__(self): - super().__init__() - self.drop = torch.nn.Dropout() - - def forward(self, x): - return self.drop(x) - - traced = torch.fx.symbolic_trace(DropoutRepro2()) - print(traced.code) - """ - def forward(self, x): - drop = self.drop(x); x = None - return drop - """ - - traced.eval() - - x = torch.randn(5, 3) - torch.testing.assert_close(traced(x), x) - - - Because of this difference, consider marking modules that interact with the ``training`` flag dynamically as leaf modules. - - -API Reference -------------- - -.. autofunction:: torch.fx.symbolic_trace - -.. autofunction:: torch.fx.wrap - -.. autoclass:: torch.fx.GraphModule - :members: - - .. automethod:: __init__ - -.. autoclass:: torch.fx.Graph - :members: - - .. automethod:: __init__ - -.. autoclass:: torch.fx.Node - :members: - -.. autoclass:: torch.fx.Tracer - :members: - :inherited-members: - -.. autoclass:: torch.fx.Proxy - -.. autoclass:: torch.fx.Interpreter - :members: - -.. autoclass:: torch.fx.Transformer - :members: - -.. autofunction:: torch.fx.replace_pattern - - -.. The experimental and passes submodules are missing docs. -.. Adding it here for coverage but this doesn't add anything to the -.. rendered doc. -.. py:module:: torch.fx.passes -.. py:module:: torch.fx.passes.infra -.. py:module:: torch.fx.passes.backends -.. py:module:: torch.fx.passes.utils -.. py:module:: torch.fx.passes.tests -.. py:module:: torch.fx.experimental -.. py:module:: torch.fx.experimental.unification -.. py:module:: torch.fx.experimental.unification.multipledispatch -.. py:module:: torch.fx.experimental.migrate_gradual_types -.. py:module:: torch.fx.passes.dialect -.. py:module:: torch.fx.passes.dialect.common -.. py:module:: torch.fx.annotate -.. py:module:: torch.fx.config -.. py:module:: torch.fx.experimental.accelerator_partitioner -.. py:module:: torch.fx.experimental.const_fold -.. py:module:: torch.fx.experimental.debug -.. py:module:: torch.fx.experimental.graph_gradual_typechecker -.. py:module:: torch.fx.experimental.merge_matmul -.. py:module:: torch.fx.experimental.meta_tracer -.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint -.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_generator -.. py:module:: torch.fx.experimental.migrate_gradual_types.constraint_transformation -.. py:module:: torch.fx.experimental.migrate_gradual_types.operation -.. py:module:: torch.fx.experimental.migrate_gradual_types.transform_to_z3 -.. py:module:: torch.fx.experimental.migrate_gradual_types.util -.. py:module:: torch.fx.experimental.migrate_gradual_types.z3_types -.. py:module:: torch.fx.experimental.normalize -.. py:module:: torch.fx.experimental.optimization -.. py:module:: torch.fx.experimental.partitioner_utils -.. py:module:: torch.fx.experimental.recording -.. py:module:: torch.fx.experimental.refinement_types -.. py:module:: torch.fx.experimental.rewriter -.. py:module:: torch.fx.experimental.schema_type_annotation -.. py:module:: torch.fx.experimental.sym_node -.. py:module:: torch.fx.experimental.unification.core -.. py:module:: torch.fx.experimental.unification.dispatch -.. py:module:: torch.fx.experimental.unification.match -.. py:module:: torch.fx.experimental.unification.more -.. py:module:: torch.fx.experimental.unification.multipledispatch.conflict -.. py:module:: torch.fx.experimental.unification.multipledispatch.core -.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher -.. py:module:: torch.fx.experimental.unification.multipledispatch.utils -.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic -.. py:module:: torch.fx.experimental.unification.unification_tools -.. py:module:: torch.fx.experimental.unification.utils -.. py:module:: torch.fx.experimental.unification.variable -.. py:module:: torch.fx.experimental.unify_refinements -.. py:module:: torch.fx.experimental.validator -.. py:module:: torch.fx.graph -.. py:module:: torch.fx.graph_module -.. py:module:: torch.fx.immutable_collections -.. py:module:: torch.fx.interpreter -.. py:module:: torch.fx.node -.. py:module:: torch.fx.operator_schemas -.. py:module:: torch.fx.passes.annotate_getitem_nodes -.. py:module:: torch.fx.passes.backends.cudagraphs -.. py:module:: torch.fx.passes.dialect.common.cse_pass -.. py:module:: torch.fx.passes.fake_tensor_prop -.. py:module:: torch.fx.passes.graph_drawer -.. py:module:: torch.fx.passes.graph_manipulation -.. py:module:: torch.fx.passes.graph_transform_observer -.. py:module:: torch.fx.passes.infra.partitioner -.. py:module:: torch.fx.passes.infra.pass_base -.. py:module:: torch.fx.passes.infra.pass_manager -.. py:module:: torch.fx.passes.net_min_base -.. py:module:: torch.fx.passes.operator_support -.. py:module:: torch.fx.passes.param_fetch -.. py:module:: torch.fx.passes.pass_manager -.. py:module:: torch.fx.passes.reinplace -.. py:module:: torch.fx.passes.runtime_assert -.. py:module:: torch.fx.passes.shape_prop -.. py:module:: torch.fx.passes.split_module -.. py:module:: torch.fx.passes.split_utils -.. py:module:: torch.fx.passes.splitter_base -.. py:module:: torch.fx.passes.tests.test_pass_manager -.. py:module:: torch.fx.passes.tools_common -.. py:module:: torch.fx.passes.utils.common -.. py:module:: torch.fx.passes.utils.fuser_utils -.. py:module:: torch.fx.passes.utils.matcher_utils -.. py:module:: torch.fx.passes.utils.matcher_with_name_node_map_utils -.. py:module:: torch.fx.passes.utils.source_matcher_utils -.. py:module:: torch.fx.proxy -.. py:module:: torch.fx.subgraph_rewriter -.. py:module:: torch.fx.tensor_type -.. py:module:: torch.fx.traceback diff --git a/docs/source/hub.md b/docs/source/hub.md new file mode 100644 index 00000000000000..afb312098866a5 --- /dev/null +++ b/docs/source/hub.md @@ -0,0 +1,157 @@ +# torch.hub + +Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility. + +## Publishing models + +Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights) +to a GitHub repository by adding a simple `hubconf.py` file; + +`hubconf.py` can have multiple entrypoints. Each entrypoint is defined as a python function +(example: a pre-trained model you want to publish). + +```python + def entrypoint_name(*args, **kwargs): + # args & kwargs are optional, for models which take positional/keyword arguments. + ... +``` + +### How to implement an entrypoint? + +Here is a code snippet specifies an entrypoint for `resnet18` model if we expand +the implementation in `pytorch/vision/hubconf.py`. +In most case importing the right function in `hubconf.py` is sufficient. Here we +just want to use the expanded version as an example to show how it works. +You can see the full script in +[pytorch/vision repo](https://github.com/pytorch/vision/blob/master/hubconf.py) + +```python + dependencies = ['torch'] + from torchvision.models.resnet import resnet18 as _resnet18 + + # resnet18 is the name of entrypoint + def resnet18(pretrained=False, **kwargs): + """ # This docstring shows up in hub.help() + Resnet18 model + pretrained (bool): kwargs, load pretrained weights into the model + """ + # Call the model, load pretrained weights + model = _resnet18(pretrained=pretrained, **kwargs) + return model +``` + +- `dependencies` variable is a **list** of package names required to **load** the model. Note this might + be slightly different from dependencies required for training a model. +- `args` and `kwargs` are passed along to the real callable function. +- Docstring of the function works as a help message. It explains what does the model do and what + are the allowed positional/keyword arguments. It's highly recommended to add a few examples here. +- Entrypoint function can either return a model(nn.module), or auxiliary tools to make the user workflow smoother, e.g. tokenizers. +- Callables prefixed with underscore are considered as helper functions which won't show up in {func}`torch.hub.list()`. +- Pretrained weights can either be stored locally in the GitHub repo, or loadable by + {func}`torch.hub.load_state_dict_from_url()`. If less than 2GB, it's recommended to attach it to a [project release](https://help.github.com/en/articles/distributing-large-binaries) + and use the url from the release. + In the example above `torchvision.models.resnet.resnet18` handles `pretrained`, alternatively you can put the following logic in the entrypoint definition. + +```python + if pretrained: + # For checkpoint saved in local GitHub repo, e.g. =weights/save.pth + dirname = os.path.dirname(__file__) + checkpoint = os.path.join(dirname, ) + state_dict = torch.load(checkpoint) + model.load_state_dict(state_dict) + + # For checkpoint saved elsewhere + checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False)) +``` + +### Important Notice + +- The published models should be at least in a branch/tag. It can't be a random commit. + +## Loading models from Hub + +Pytorch Hub provides convenient APIs to explore all available models in hub +through {func}`torch.hub.list()`, show docstring and examples through +{func}`torch.hub.help()` and load the pre-trained models using +{func}`torch.hub.load()`. + +```{eval-rst} +.. automodule:: torch.hub +``` + +```{eval-rst} +.. autofunction:: list +``` + +```{eval-rst} +.. autofunction:: help +``` + +```{eval-rst} +.. autofunction:: load +``` + +```{eval-rst} +.. autofunction:: download_url_to_file +``` + +```{eval-rst} +.. autofunction:: load_state_dict_from_url +``` + +### Running a loaded model: + +Note that `*args` and `**kwargs` in {func}`torch.hub.load()` are used to +**instantiate** a model. After you have loaded a model, how can you find out +what you can do with the model? +A suggested workflow is + +- `dir(model)` to see all available methods of the model. +- `help(model.foo)` to check what arguments `model.foo` takes to run + +To help users explore without referring to documentation back and forth, we strongly +recommend repo owners make function help messages clear and succinct. It's also helpful +to include a minimal working example. + +### Where are my downloaded models saved? + +The locations are used in the order of + +- Calling `hub.set_dir()` +- `$TORCH_HOME/hub`, if environment variable `TORCH_HOME` is set. +- `$XDG_CACHE_HOME/torch/hub`, if environment variable `XDG_CACHE_HOME` is set. +- `~/.cache/torch/hub` + +```{eval-rst} +.. autofunction:: get_dir +``` + +```{eval-rst} +.. autofunction:: set_dir +``` + +### Caching logic + +By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in the +directory returned by {func}`~torch.hub.get_dir()`. + +Users can force a reload by calling `hub.load(..., force_reload=True)`. This will delete +the existing GitHub folder and downloaded weights, reinitialize a fresh download. This is useful +when updates are published to the same branch, users can keep up with the latest release. + +### Known limitations: + +Torch hub works by importing the package as if it was installed. There are some side effects +introduced by importing in Python. For example, you can see new items in Python caches +`sys.modules` and `sys.path_importer_cache` which is normal Python behavior. +This also means that you may have import errors when importing different models +from different repos, if the repos have the same sub-package names (typically, a +`model` subpackage). A workaround for these kinds of import errors is to +remove the offending sub-package from the `sys.modules` dict; more details can +be found in [this GitHub issue](https://github.com/pytorch/hub/issues/243#issuecomment-942403391). + +A known limitation that is worth mentioning here: users **CANNOT** load two different branches of +the same repo in the **same python process**. It's just like installing two packages with the +same name in Python, which is not good. Cache might join the party and give you surprises if you +actually try that. Of course it's totally fine to load them in separate processes. diff --git a/docs/source/hub.rst b/docs/source/hub.rst deleted file mode 100644 index c48b9d6be2fadb..00000000000000 --- a/docs/source/hub.rst +++ /dev/null @@ -1,153 +0,0 @@ -torch.hub -=================================== -Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility. - -Publishing models ------------------ - -Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights) -to a GitHub repository by adding a simple ``hubconf.py`` file; - -``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function -(example: a pre-trained model you want to publish). - -:: - - def entrypoint_name(*args, **kwargs): - # args & kwargs are optional, for models which take positional/keyword arguments. - ... - -How to implement an entrypoint? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Here is a code snippet specifies an entrypoint for ``resnet18`` model if we expand -the implementation in ``pytorch/vision/hubconf.py``. -In most case importing the right function in ``hubconf.py`` is sufficient. Here we -just want to use the expanded version as an example to show how it works. -You can see the full script in -`pytorch/vision repo `_ - -:: - - dependencies = ['torch'] - from torchvision.models.resnet import resnet18 as _resnet18 - - # resnet18 is the name of entrypoint - def resnet18(pretrained=False, **kwargs): - """ # This docstring shows up in hub.help() - Resnet18 model - pretrained (bool): kwargs, load pretrained weights into the model - """ - # Call the model, load pretrained weights - model = _resnet18(pretrained=pretrained, **kwargs) - return model - - -- ``dependencies`` variable is a **list** of package names required to **load** the model. Note this might - be slightly different from dependencies required for training a model. -- ``args`` and ``kwargs`` are passed along to the real callable function. -- Docstring of the function works as a help message. It explains what does the model do and what - are the allowed positional/keyword arguments. It's highly recommended to add a few examples here. -- Entrypoint function can either return a model(nn.module), or auxiliary tools to make the user workflow smoother, e.g. tokenizers. -- Callables prefixed with underscore are considered as helper functions which won't show up in :func:`torch.hub.list()`. -- Pretrained weights can either be stored locally in the GitHub repo, or loadable by - :func:`torch.hub.load_state_dict_from_url()`. If less than 2GB, it's recommended to attach it to a `project release `_ - and use the url from the release. - In the example above ``torchvision.models.resnet.resnet18`` handles ``pretrained``, alternatively you can put the following logic in the entrypoint definition. - -:: - - if pretrained: - # For checkpoint saved in local GitHub repo, e.g. =weights/save.pth - dirname = os.path.dirname(__file__) - checkpoint = os.path.join(dirname, ) - state_dict = torch.load(checkpoint) - model.load_state_dict(state_dict) - - # For checkpoint saved elsewhere - checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' - model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False)) - - -Important Notice -^^^^^^^^^^^^^^^^ - -- The published models should be at least in a branch/tag. It can't be a random commit. - - -Loading models from Hub ------------------------ - -Pytorch Hub provides convenient APIs to explore all available models in hub -through :func:`torch.hub.list()`, show docstring and examples through -:func:`torch.hub.help()` and load the pre-trained models using -:func:`torch.hub.load()`. - - -.. automodule:: torch.hub - -.. autofunction:: list - -.. autofunction:: help - -.. autofunction:: load - -.. autofunction:: download_url_to_file - -.. autofunction:: load_state_dict_from_url - -Running a loaded model: -^^^^^^^^^^^^^^^^^^^^^^^ - -Note that ``*args`` and ``**kwargs`` in :func:`torch.hub.load()` are used to -**instantiate** a model. After you have loaded a model, how can you find out -what you can do with the model? -A suggested workflow is - -- ``dir(model)`` to see all available methods of the model. -- ``help(model.foo)`` to check what arguments ``model.foo`` takes to run - -To help users explore without referring to documentation back and forth, we strongly -recommend repo owners make function help messages clear and succinct. It's also helpful -to include a minimal working example. - -Where are my downloaded models saved? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The locations are used in the order of - -- Calling ``hub.set_dir()`` -- ``$TORCH_HOME/hub``, if environment variable ``TORCH_HOME`` is set. -- ``$XDG_CACHE_HOME/torch/hub``, if environment variable ``XDG_CACHE_HOME`` is set. -- ``~/.cache/torch/hub`` - -.. autofunction:: get_dir - -.. autofunction:: set_dir - -Caching logic -^^^^^^^^^^^^^ - -By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in the -directory returned by :func:`~torch.hub.get_dir()`. - -Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete -the existing GitHub folder and downloaded weights, reinitialize a fresh download. This is useful -when updates are published to the same branch, users can keep up with the latest release. - - -Known limitations: -^^^^^^^^^^^^^^^^^^ -Torch hub works by importing the package as if it was installed. There are some side effects -introduced by importing in Python. For example, you can see new items in Python caches -``sys.modules`` and ``sys.path_importer_cache`` which is normal Python behavior. -This also means that you may have import errors when importing different models -from different repos, if the repos have the same sub-package names (typically, a -``model`` subpackage). A workaround for these kinds of import errors is to -remove the offending sub-package from the ``sys.modules`` dict; more details can -be found in `this GitHub issue -`_. - -A known limitation that is worth mentioning here: users **CANNOT** load two different branches of -the same repo in the **same python process**. It's just like installing two packages with the -same name in Python, which is not good. Cache might join the party and give you surprises if you -actually try that. Of course it's totally fine to load them in separate processes. diff --git a/docs/source/index.md b/docs/source/index.md index f74d5d3940c702..e1e8ce5c0f2e59 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -12,21 +12,12 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. Features described in this documentation are classified by release status: - *Stable:* These features will be maintained long-term and there should generally - be no major performance limitations or gaps in documentation. - We also expect to maintain backwards compatibility (although - breaking changes can happen and notice will be given one release ahead - of time). - - *Beta:* These features are tagged as Beta because the API may change based on - user feedback, because the performance needs to improve, or because - coverage across operators is not yet complete. For Beta features, we are - committing to seeing the feature through to the Stable classification. - We are not, however, committing to backwards compatibility. - - *Prototype:* These features are typically not available as part of - binary distributions like PyPI or Conda, except sometimes behind run-time - flags, and are at an early stage for feedback and testing. +**Stable (API-Stable):** +These features will be maintained long-term and there should generally be no major performance limitations or gaps in documentation. We also expect to maintain backwards compatibility (although breaking changes can happen and notice will be given one release ahead of time). + +**Unstable (API-Unstable):** +Encompasses all features that are under active development where APIs may change based on user feedback, requisite performance improvements or because coverage across operators is not yet complete. +The APIs and performance characteristics of these features may change. ```{toctree} :glob: diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 9d37c2a7d3305b..c5ba9063a50c8b 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -310,7 +310,7 @@ your model code correctly. Interpreting Graphs ~~~~~~~~~~~~~~~~~~~ -TorchScript also has a representation at a lower level than the code pretty- +TorchScript also has a representation at a lower level than the code pretty-\ printer, in the form of IR graphs. TorchScript uses a static single assignment (SSA) intermediate representation diff --git a/docs/source/jit_language_reference.md b/docs/source/jit_language_reference.md new file mode 100644 index 00000000000000..97373094820807 --- /dev/null +++ b/docs/source/jit_language_reference.md @@ -0,0 +1,952 @@ +```{contents} +:depth: 2 +:local: true +``` + +```{eval-rst} +.. testsetup:: + + # These are hidden from the docs, but these are necessary for `doctest` + # since the `inspect` module doesn't play nicely with the execution + # environment for `doctest` + import torch + + original_script = torch.jit.script + def script_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_script(obj, *args, **kwargs) + + torch.jit.script = script_wrapper + + original_trace = torch.jit.trace + def trace_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_trace(obj, *args, **kwargs) + + torch.jit.trace = trace_wrapper +``` + +(language-reference)= + +# TorchScript Language Reference + +TorchScript is a statically typed subset of Python that can either be written directly (using +the {func}`@torch.jit.script ` decorator) or generated automatically from Python code via +tracing. When using tracing, code is automatically converted into this subset of +Python by recording only the actual operators on tensors and simply executing and +discarding the other surrounding Python code. + +When writing TorchScript directly using `@torch.jit.script` decorator, the programmer must +only use the subset of Python supported in TorchScript. This section documents +what is supported in TorchScript as if it were a language reference for a stand +alone language. Any features of Python not mentioned in this reference are not +part of TorchScript. See `Builtin Functions` for a complete reference of available +PyTorch tensor methods, modules, and functions. + +As a subset of Python, any valid TorchScript function is also a valid Python +function. This makes it possible to `disable TorchScript` and debug the +function using standard Python tools like `pdb`. The reverse is not true: there +are many valid Python programs that are not valid TorchScript programs. +Instead, TorchScript focuses specifically on the features of Python that are +needed to represent neural network models in PyTorch. + +(types)= + +(supported-type)= + +## Types + +The largest difference between TorchScript and the full Python language is that +TorchScript only supports a small set of types that are needed to express neural +net models. In particular, TorchScript supports: + +```{eval-rst} +.. csv-table:: + :header: "Type", "Description" + + "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" + "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" + "``bool``", "A boolean value" + "``int``", "A scalar integer" + "``float``", "A scalar floating point number" + "``str``", "A string" + "``List[T]``", "A list of which all members are type ``T``" + "``Optional[T]``", "A value which is either None or type ``T``" + "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." + "``T``", "A {ref}`TorchScript Class`" + "``E``", "A {ref}`TorchScript Enum`" + "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" + "``Union[T0, T1, ...]``", "One of the subtypes ``T0``, ``T1``, etc." +``` + +Unlike Python, each variable in TorchScript function must have a single static type. +This makes it easier to optimize TorchScript functions. + +Example (a type mismatch) + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def an_error(x): + if x: + r = torch.rand(1) + else: + r = 4 + return r + +``` + +```{eval-rst} +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: + @torch.jit.script + def an_error(x): + if x: + ~~~~~ + r = torch.rand(1) + ~~~~~~~~~~~~~~~~~ + else: + ~~~~~ + r = 4 + ~~~~~ <--- HERE + return r + and was used here: + else: + r = 4 + return r + ~ <--- HERE... +``` + +### Unsupported Typing Constructs + +TorchScript does not support all features and types of the {mod}`typing` module. Some of these +are more fundamental things that are unlikely to be added in the future while others +may be added if there is enough user demand to make it a priority. + +These types and features from the {mod}`typing` module are unavailable in TorchScript. + +```{eval-rst} +.. csv-table:: + :header: "Item", "Description" + + ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" + ":any:`typing.NoReturn`", "Not implemented" + ":any:`typing.Sequence`", "Not implemented" + ":any:`typing.Callable`", "Not implemented" + ":any:`typing.Literal`", "Not implemented" + ":any:`typing.ClassVar`", "Not implemented" + ":any:`typing.Final`", "This is supported for :any:`module attributes ` class attribute annotations but not for functions" + ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used" + ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released" + "Type aliases", "Not implemented" + "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not" + "NewType", "Unlikely to be implemented" + "Generics", "Unlikely to be implemented" +``` + +Any other functionality from the {any}`typing` module not explicitly listed in this documentation is unsupported. + +### Default Types + +By default, all parameters to a TorchScript function are assumed to be Tensor. +To specify that an argument to a TorchScript function is another type, it is possible to use +MyPy-style type annotations using the types listed above. + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def foo(x, tup): + # type: (int, Tuple[Tensor, Tensor]) -> Tensor + t0, t1 = tup + return t0 + t1 + x + + print(foo(3, (torch.rand(3), torch.rand(3)))) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... +``` + +:::{note} +It is also possible to annotate types with Python 3 type hints from the +`typing` module. + +```{eval-rst} +.. testcode:: + + import torch + from typing import Tuple + + @torch.jit.script + def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + t0, t1 = tup + return t0 + t1 + x + + print(foo(3, (torch.rand(3), torch.rand(3)))) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... +``` +::: + +An empty list is assumed to be `List[Tensor]` and empty dicts +`Dict[str, Tensor]`. To instantiate an empty list or dict of other types, +use `Python 3 type hints`. + +Example (type annotations for Python 3): + +```{eval-rst} +.. testcode:: + + import torch + import torch.nn as nn + from typing import Dict, List, Tuple + + class EmptyDataStructures(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: + # This annotates the list to be a `List[Tuple[int, float]]` + my_list: List[Tuple[int, float]] = [] + for i in range(10): + my_list.append((i, x.item())) + + my_dict: Dict[str, int] = {} + return my_list, my_dict + + x = torch.jit.script(EmptyDataStructures()) + + + +``` + +### Optional Type Refinement + +TorchScript will refine the type of a variable of type `Optional[T]` when +a comparison to `None` is made inside the conditional of an if-statement or checked in an `assert`. +The compiler can reason about multiple `None` checks that are combined with +`and`, `or`, and `not`. Refinement will also occur for else blocks of if-statements +that are not explicitly written. + +The `None` check must be within the if-statement's condition; assigning +a `None` check to a variable and using it in the if-statement's condition will +not refine the types of variables in the check. +Only local variables will be refined, an attribute like `self.x` will not and must assigned to +a local variable to be refined. + +Example (refining types on parameters and locals): + +```{eval-rst} +.. testcode:: + + import torch + import torch.nn as nn + from typing import Optional + + class M(nn.Module): + z: Optional[int] + + def __init__(self, z): + super().__init__() + # If `z` is None, its type cannot be inferred, so it must + # be specified (above) + self.z = z + + def forward(self, x, y, z): + # type: (Optional[int], Optional[int], Optional[int]) -> int + if x is None: + x = 1 + x = x + 1 + + # Refinement for an attribute by assigning it to a local + z = self.z + if y is not None and z is not None: + x = y + z + + # Refinement via an `assert` + assert z is not None + x += z + return x + + module = torch.jit.script(M(2)) + module = torch.jit.script(M(None)) + +``` + +(TorchScript Class)= + +(TorchScript Classes)= + +(torchscript-classes)= + +### TorchScript Classes + +:::{warning} +TorchScript class support is experimental. Currently it is best suited +for simple record-like types (think a `NamedTuple` with methods +attached). +::: + +Python classes can be used in TorchScript if they are annotated with {func}`@torch.jit.script `, +similar to how you would declare a TorchScript function: + +```{eval-rst} +.. testcode:: + :skipif: True # TODO: fix the source file resolving so this can be tested + + @torch.jit.script + class Foo: + def __init__(self, x, y): + self.x = x + + def aug_add_x(self, inc): + self.x += inc + +``` + +This subset is restricted: + +- All functions must be valid TorchScript functions (including `__init__()`). + +- Classes must be new-style classes, as we use `__new__()` to construct them with pybind11. + +- TorchScript classes are statically typed. Members can only be declared by assigning to + self in the `__init__()` method. + + > For example, assigning to `self` outside of the `__init__()` method: + > + > ``` + > @torch.jit.script + > class Foo: + > def assign_x(self): + > self.x = torch.rand(2, 3) + > ``` + > + > Will result in: + > + > ``` + > RuntimeError: + > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: + > def assign_x(self): + > self.x = torch.rand(2, 3) + > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + > ``` + +- No expressions except method definitions are allowed in the body of the class. + +- No support for inheritance or any other polymorphism strategy, except for inheriting + from `object` to specify a new-style class. + +After a class is defined, it can be used in both TorchScript and Python interchangeably +like any other TorchScript type: + +``` +# Declare a TorchScript class +@torch.jit.script +class Pair: + def __init__(self, first, second): + self.first = first + self.second = second + +@torch.jit.script +def sum_pair(p): + # type: (Pair) -> Tensor + return p.first + p.second + +p = Pair(torch.rand(2, 3), torch.rand(2, 3)) +print(sum_pair(p)) +``` + +(TorchScript Enum)= + +(TorchScript Enums)= + +(torchscript-enums)= + +### TorchScript Enums + +Python enums can be used in TorchScript without any extra annotation or code: + +``` +from enum import Enum + + +class Color(Enum): + RED = 1 + GREEN = 2 + +@torch.jit.script +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + + return x == y +``` + +After an enum is defined, it can be used in both TorchScript and Python interchangeably +like any other TorchScript type. The type of the values of an enum must be `int`, +`float`, or `str`. All values must be of the same type; heterogeneous types for enum +values are not supported. + +### Named Tuples + +Types produced by {func}`collections.namedtuple ` can be used in TorchScript. + +```{eval-rst} +.. testcode:: + + import torch + import collections + + Point = collections.namedtuple('Point', ['x', 'y']) + + @torch.jit.script + def total(point): + # type: (Point) -> Tensor + return point.x + point.y + + p = Point(x=torch.rand(3), y=torch.rand(3)) + print(total(p)) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... + +``` + +(jit_iterables)= + +### Iterables + +Some functions (for example, {any}`zip` and {any}`enumerate`) can only operate on iterable types. +Iterable types in TorchScript include `Tensor`s, lists, tuples, dictionaries, strings, +{any}`torch.nn.ModuleList` and {any}`torch.nn.ModuleDict`. + +## Expressions + +The following Python Expressions are supported. + +### Literals + +``` +True +False +None +'string literals' +"string literals" +3 # interpreted as int +3.4 # interpreted as a float +``` + +#### List Construction + +An empty list is assumed have type `List[Tensor]`. +The types of other list literals are derived from the type of the members. +See [Default Types] for more details. + +``` +[3, 4] +[] +[torch.rand(3), torch.rand(4)] +``` + +#### Tuple Construction + +``` +(3, 4) +(3,) +``` + +#### Dict Construction + +An empty dict is assumed have type `Dict[str, Tensor]`. +The types of other dict literals are derived from the type of the members. +See [Default Types] for more details. + +``` +{'hello': 3} +{} +{'a': torch.rand(3), 'b': torch.rand(4)} +``` + +### Variables + +See [Variable Resolution] for how variables are resolved. + +``` +my_variable_name +``` + +### Arithmetic Operators + +``` +a + b +a - b +a * b +a / b +a ^ b +a @ b +``` + +### Comparison Operators + +``` +a == b +a != b +a < b +a > b +a <= b +a >= b +``` + +### Logical Operators + +``` +a and b +a or b +not b +``` + +### Subscripts and Slicing + +``` +t[0] +t[-1] +t[0:2] +t[1:] +t[:1] +t[:] +t[0, 1] +t[0, 1:2] +t[0, :1] +t[-1, 1:, 0] +t[1:, -1, 0] +t[i:j, i] +``` + +### Function Calls + +Calls to `builtin functions` + +``` +torch.rand(3, dtype=torch.int) +``` + +Calls to other script functions: + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def foo(x): + return x + 1 + + @torch.jit.script + def bar(x): + return foo(x) +``` + +### Method Calls + +Calls to methods of builtin types like tensor: `x.mm(y)` + +On modules, methods must be compiled before they can be called. The TorchScript +compiler recursively compiles methods it sees when compiling other methods. By default, +compilation starts on the `forward` method. Any methods called by `forward` will +be compiled, and any methods called by those methods, and so on. To start compilation at +a method other than `forward`, use the {func}`@torch.jit.export ` decorator +(`forward` implicitly is marked `@torch.jit.export`). + +Calling a submodule directly (e.g. `self.resnet(input)`) is equivalent to +calling its `forward` method (e.g. `self.resnet.forward(input)`). + +```{eval-rst} +.. testcode:: + :skipif: torchvision is None + + import torch + import torch.nn as nn + import torchvision + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + means = torch.tensor([103.939, 116.779, 123.68]) + self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) + resnet = torchvision.models.resnet18() + self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) + + def helper(self, input): + return self.resnet(input - self.means) + + def forward(self, input): + return self.helper(input) + + # Since nothing in the model calls `top_level_method`, the compiler + # must be explicitly told to compile this method + @torch.jit.export + def top_level_method(self, input): + return self.other_helper(input) + + def other_helper(self, input): + return input + 10 + + # `my_script_module` will have the compiled methods `forward`, `helper`, + # `top_level_method`, and `other_helper` + my_script_module = torch.jit.script(MyModule()) + +``` + +### Ternary Expressions + +``` +x if x > y else y +``` + +### Casts + +``` +float(ten) +int(3.5) +bool(ten) +str(2)`` +``` + +### Accessing Module Parameters + +``` +self.my_parameter +self.my_submodule.my_parameter +``` + +## Statements + +TorchScript supports the following types of statements: + +### Simple Assignments + +``` +a = b +a += b # short-hand for a = a + b, does not operate in-place on a +a -= b +``` + +### Pattern Matching Assignments + +``` +a, b = tuple_or_list +a, b, *c = a_tuple +``` + +Multiple Assignments + +``` +a = b, c = tup +``` + +### Print Statements + +``` +print("the result of an add:", a + b) +``` + +### If Statements + +``` +if a < 4: + r = -a +elif a < 3: + r = a + a +else: + r = 3 * a +``` + +In addition to bools, floats, ints, and Tensors can be used in a conditional +and will be implicitly casted to a boolean. + +### While Loops + +``` +a = 0 +while a < 4: + print(a) + a += 1 +``` + +### For loops with range + +``` +x = 0 +for i in range(10): + x *= i +``` + +### For loops over tuples + +These unroll the loop, generating a body for +each member of the tuple. The body must type-check correctly for each member. + +``` +tup = (3, torch.rand(4)) +for x in tup: + print(x) +``` + +### For loops over constant nn.ModuleList + +To use a `nn.ModuleList` inside a compiled method, it must be marked +constant by adding the name of the attribute to the `__constants__` +list for the type. For loops over a `nn.ModuleList` will unroll the body of the +loop at compile time, with each member of the constant module list. + +```{eval-rst} +.. testcode:: + + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2)) + + def forward(self, input): + return self.weight + input + + class MyModule(torch.nn.Module): + __constants__ = ['mods'] + + def __init__(self): + super().__init__() + self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) + + def forward(self, v): + for module in self.mods: + v = module(v) + return v + + + m = torch.jit.script(MyModule()) + + +``` + +### Break and Continue + +``` +for i in range(5): + if i == 1: + continue + if i == 3: + break + print(i) +``` + +### Return + +``` +return a, b +``` + +## Variable Resolution + +TorchScript supports a subset of Python's variable resolution (i.e. scoping) +rules. Local variables behave the same as in Python, except for the restriction +that a variable must have the same type along all paths through a function. +If a variable has a different type on different branches of an if statement, it +is an error to use it after the end of the if statement. + +Similarly, a variable is not allowed to be used if it is only *defined* along some +paths through the function. + +Example: + +```{eval-rst} +.. testcode:: + + @torch.jit.script + def foo(x): + if x < 0: + y = 4 + print(y) +``` + +```{eval-rst} +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + y is not defined in the false branch... + @torch.jit.script... + def foo(x): + if x < 0: + ~~~~~~~~~ + y = 4 + ~~~~~ <--- HERE + print(y) + and was used here: + if x < 0: + y = 4 + print(y) + ~ <--- HERE... +``` + +Non-local variables are resolved to Python values at compile time when the +function is defined. These values are then converted into TorchScript values using +the rules described in [Use of Python Values]. + +## Use of Python Values + +To make writing TorchScript more convenient, we allow script code to refer +to Python values in the surrounding scope. For instance, any time there is a +reference to `torch`, the TorchScript compiler is actually resolving it to the +`torch` Python module when the function is declared. These Python values are +not a first class part of TorchScript. Instead they are de-sugared at compile-time +into the primitive types that TorchScript supports. This depends +on the dynamic type of the Python valued referenced when compilation occurs. +This section describes the rules that are used when accessing Python values in TorchScript. + +### Functions + +TorchScript can call Python functions. This functionality is very useful when +incrementally converting a model to TorchScript. The model can be moved function-by-function +to TorchScript, leaving calls to Python functions in place. This way you can incrementally +check the correctness of the model as you go. + +```{eval-rst} +.. autofunction:: torch.jit.is_scripting +``` + +```{eval-rst} +.. autofunction:: torch.jit.is_tracing + +``` + +### Attribute Lookup On Python Modules + +TorchScript can lookup attributes on modules. `Builtin functions` like `torch.add` +are accessed this way. This allows TorchScript to call functions defined in +other modules. + +(constant)= + +### Python-defined Constants + +TorchScript also provides a way to use constants that are defined in Python. +These can be used to hard-code hyper-parameters into the function, or to +define universal constants. There are two ways of specifying that a Python +value should be treated as a constant. + +1. Values looked up as attributes of a module are assumed to be constant: + +```{eval-rst} +.. testcode:: + + import math + import torch + + @torch.jit.script + def fn(): + return math.pi +``` + +2. Attributes of a ScriptModule can be marked constant by annotating them with `Final[T]` + +``` +import torch +import torch.nn as nn + +class Foo(nn.Module): + # `Final` from the `typing_extensions` module can also be used + a : torch.jit.Final[int] + + def __init__(self): + super().__init__() + self.a = 1 + 4 + + def forward(self, input): + return self.a + input + +f = torch.jit.script(Foo()) +``` + +Supported constant Python types are + +- `int` +- `float` +- `bool` +- `torch.device` +- `torch.layout` +- `torch.dtype` +- tuples containing supported types +- `torch.nn.ModuleList` which can be used in a TorchScript for loop + +(module-attributes)= +(Module Attributes)= + +### Module Attributes + +The `torch.nn.Parameter` wrapper and `register_buffer` can be used to assign +tensors to a module. Other values assigned to a module that is compiled +will be added to the compiled module if their types can be inferred. All [types] +available in TorchScript can be used as module attributes. Tensor attributes are +semantically the same as buffers. The type of empty lists and dictionaries and `None` +values cannot be inferred and must be specified via +[PEP 526-style](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations) class annotations. +If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute +to the resulting {class}`ScriptModule`. + +Example: + +```{eval-rst} +.. testcode:: + + from typing import List, Dict + + class Foo(nn.Module): + # `words` is initialized as an empty list, so its type must be specified + words: List[str] + + # The type could potentially be inferred if `a_dict` (below) was not + # empty, but this annotation ensures `some_dict` will be made into the + # proper type + some_dict: Dict[str, int] + + def __init__(self, a_dict): + super().__init__() + self.words = [] + self.some_dict = a_dict + + # `int`s can be inferred + self.my_int = 10 + + def forward(self, input): + # type: (str) -> int + self.words.append(input) + return self.some_dict[input] + self.my_int + + f = torch.jit.script(Foo({'hi': 2})) +``` diff --git a/docs/source/jit_language_reference.rst b/docs/source/jit_language_reference.rst deleted file mode 100644 index ccd3cf873d7203..00000000000000 --- a/docs/source/jit_language_reference.rst +++ /dev/null @@ -1,921 +0,0 @@ -.. contents:: - :local: - :depth: 2 - - -.. testsetup:: - - # These are hidden from the docs, but these are necessary for `doctest` - # since the `inspect` module doesn't play nicely with the execution - # environment for `doctest` - import torch - - original_script = torch.jit.script - def script_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_script(obj, *args, **kwargs) - - torch.jit.script = script_wrapper - - original_trace = torch.jit.trace - def trace_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_trace(obj, *args, **kwargs) - - torch.jit.trace = trace_wrapper - -.. _language-reference: - -TorchScript Language Reference -============================== - -TorchScript is a statically typed subset of Python that can either be written directly (using -the :func:`@torch.jit.script ` decorator) or generated automatically from Python code via -tracing. When using tracing, code is automatically converted into this subset of -Python by recording only the actual operators on tensors and simply executing and -discarding the other surrounding Python code. - -When writing TorchScript directly using ``@torch.jit.script`` decorator, the programmer must -only use the subset of Python supported in TorchScript. This section documents -what is supported in TorchScript as if it were a language reference for a stand -alone language. Any features of Python not mentioned in this reference are not -part of TorchScript. See `Builtin Functions` for a complete reference of available -PyTorch tensor methods, modules, and functions. - -As a subset of Python, any valid TorchScript function is also a valid Python -function. This makes it possible to `disable TorchScript` and debug the -function using standard Python tools like ``pdb``. The reverse is not true: there -are many valid Python programs that are not valid TorchScript programs. -Instead, TorchScript focuses specifically on the features of Python that are -needed to represent neural network models in PyTorch. - -.. _types: -.. _supported type: - -Types -~~~~~ - -The largest difference between TorchScript and the full Python language is that -TorchScript only supports a small set of types that are needed to express neural -net models. In particular, TorchScript supports: - -.. csv-table:: - :header: "Type", "Description" - - "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" - "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" - "``bool``", "A boolean value" - "``int``", "A scalar integer" - "``float``", "A scalar floating point number" - "``str``", "A string" - "``List[T]``", "A list of which all members are type ``T``" - "``Optional[T]``", "A value which is either None or type ``T``" - "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." - "``T``", "A `TorchScript Class`_" - "``E``", "A `TorchScript Enum`_" - "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" - "``Union[T0, T1, ...]``", "One of the subtypes ``T0``, ``T1``, etc." - -Unlike Python, each variable in TorchScript function must have a single static type. -This makes it easier to optimize TorchScript functions. - -Example (a type mismatch) - -.. testcode:: - - import torch - - @torch.jit.script - def an_error(x): - if x: - r = torch.rand(1) - else: - r = 4 - return r - - -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: - @torch.jit.script - def an_error(x): - if x: - ~~~~~ - r = torch.rand(1) - ~~~~~~~~~~~~~~~~~ - else: - ~~~~~ - r = 4 - ~~~~~ <--- HERE - return r - and was used here: - else: - r = 4 - return r - ~ <--- HERE... - -Unsupported Typing Constructs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -TorchScript does not support all features and types of the :mod:`typing` module. Some of these -are more fundamental things that are unlikely to be added in the future while others -may be added if there is enough user demand to make it a priority. - -These types and features from the :mod:`typing` module are unavailable in TorchScript. - -.. csv-table:: - :header: "Item", "Description" - - ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" - ":any:`typing.NoReturn`", "Not implemented" - ":any:`typing.Sequence`", "Not implemented" - ":any:`typing.Callable`", "Not implemented" - ":any:`typing.Literal`", "Not implemented" - ":any:`typing.ClassVar`", "Not implemented" - ":any:`typing.Final`", "This is supported for :any:`module attributes ` class attribute annotations but not for functions" - ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used" - ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released" - "Type aliases", "Not implemented" - "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not" - "NewType", "Unlikely to be implemented" - "Generics", "Unlikely to be implemented" - -Any other functionality from the :any:`typing` module not explicitly listed in this documentation is unsupported. - -Default Types -^^^^^^^^^^^^^ - -By default, all parameters to a TorchScript function are assumed to be Tensor. -To specify that an argument to a TorchScript function is another type, it is possible to use -MyPy-style type annotations using the types listed above. - -.. testcode:: - - import torch - - @torch.jit.script - def foo(x, tup): - # type: (int, Tuple[Tensor, Tensor]) -> Tensor - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) - -.. testoutput:: - :hide: - - ... - -.. note:: - It is also possible to annotate types with Python 3 type hints from the - ``typing`` module. - - .. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) - - .. testoutput:: - :hide: - - ... - - -An empty list is assumed to be ``List[Tensor]`` and empty dicts -``Dict[str, Tensor]``. To instantiate an empty list or dict of other types, -use `Python 3 type hints`. - -Example (type annotations for Python 3): - -.. testcode:: - - import torch - import torch.nn as nn - from typing import Dict, List, Tuple - - class EmptyDataStructures(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: - # This annotates the list to be a `List[Tuple[int, float]]` - my_list: List[Tuple[int, float]] = [] - for i in range(10): - my_list.append((i, x.item())) - - my_dict: Dict[str, int] = {} - return my_list, my_dict - - x = torch.jit.script(EmptyDataStructures()) - - - - -Optional Type Refinement -^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript will refine the type of a variable of type ``Optional[T]`` when -a comparison to ``None`` is made inside the conditional of an if-statement or checked in an ``assert``. -The compiler can reason about multiple ``None`` checks that are combined with -``and``, ``or``, and ``not``. Refinement will also occur for else blocks of if-statements -that are not explicitly written. - -The ``None`` check must be within the if-statement's condition; assigning -a ``None`` check to a variable and using it in the if-statement's condition will -not refine the types of variables in the check. -Only local variables will be refined, an attribute like ``self.x`` will not and must assigned to -a local variable to be refined. - - -Example (refining types on parameters and locals): - -.. testcode:: - - import torch - import torch.nn as nn - from typing import Optional - - class M(nn.Module): - z: Optional[int] - - def __init__(self, z): - super().__init__() - # If `z` is None, its type cannot be inferred, so it must - # be specified (above) - self.z = z - - def forward(self, x, y, z): - # type: (Optional[int], Optional[int], Optional[int]) -> int - if x is None: - x = 1 - x = x + 1 - - # Refinement for an attribute by assigning it to a local - z = self.z - if y is not None and z is not None: - x = y + z - - # Refinement via an `assert` - assert z is not None - x += z - return x - - module = torch.jit.script(M(2)) - module = torch.jit.script(M(None)) - - -.. _TorchScript Class: -.. _TorchScript Classes: -.. _torchscript-classes: - -TorchScript Classes -^^^^^^^^^^^^^^^^^^^ - -.. warning:: - - TorchScript class support is experimental. Currently it is best suited - for simple record-like types (think a ``NamedTuple`` with methods - attached). - -Python classes can be used in TorchScript if they are annotated with :func:`@torch.jit.script `, -similar to how you would declare a TorchScript function: - -.. testcode:: - :skipif: True # TODO: fix the source file resolving so this can be tested - - @torch.jit.script - class Foo: - def __init__(self, x, y): - self.x = x - - def aug_add_x(self, inc): - self.x += inc - - -This subset is restricted: - -* All functions must be valid TorchScript functions (including ``__init__()``). -* Classes must be new-style classes, as we use ``__new__()`` to construct them with pybind11. -* TorchScript classes are statically typed. Members can only be declared by assigning to - self in the ``__init__()`` method. - - For example, assigning to ``self`` outside of the ``__init__()`` method: :: - - @torch.jit.script - class Foo: - def assign_x(self): - self.x = torch.rand(2, 3) - - Will result in: :: - - RuntimeError: - Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: - def assign_x(self): - self.x = torch.rand(2, 3) - ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE - -* No expressions except method definitions are allowed in the body of the class. -* No support for inheritance or any other polymorphism strategy, except for inheriting - from ``object`` to specify a new-style class. - -After a class is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type: - -:: - - # Declare a TorchScript class - @torch.jit.script - class Pair: - def __init__(self, first, second): - self.first = first - self.second = second - - @torch.jit.script - def sum_pair(p): - # type: (Pair) -> Tensor - return p.first + p.second - - p = Pair(torch.rand(2, 3), torch.rand(2, 3)) - print(sum_pair(p)) - - -.. _TorchScript Enum: -.. _TorchScript Enums: -.. _torchscript-enums: - -TorchScript Enums -^^^^^^^^^^^^^^^^^^^ - -Python enums can be used in TorchScript without any extra annotation or code: - -:: - - from enum import Enum - - - class Color(Enum): - RED = 1 - GREEN = 2 - - @torch.jit.script - def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - - return x == y - -After an enum is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type. The type of the values of an enum must be ``int``, -``float``, or ``str``. All values must be of the same type; heterogeneous types for enum -values are not supported. - - -Named Tuples -^^^^^^^^^^^^ -Types produced by :func:`collections.namedtuple ` can be used in TorchScript. - -.. testcode:: - - import torch - import collections - - Point = collections.namedtuple('Point', ['x', 'y']) - - @torch.jit.script - def total(point): - # type: (Point) -> Tensor - return point.x + point.y - - p = Point(x=torch.rand(3), y=torch.rand(3)) - print(total(p)) - -.. testoutput:: - :hide: - - ... - - -.. _jit_iterables: - -Iterables -^^^^^^^^^ - -Some functions (for example, :any:`zip` and :any:`enumerate`) can only operate on iterable types. -Iterable types in TorchScript include ``Tensor``\s, lists, tuples, dictionaries, strings, -:any:`torch.nn.ModuleList` and :any:`torch.nn.ModuleDict`. - - -Expressions -~~~~~~~~~~~ - -The following Python Expressions are supported. - -Literals -^^^^^^^^ -:: - - True - False - None - 'string literals' - "string literals" - 3 # interpreted as int - 3.4 # interpreted as a float - -List Construction -""""""""""""""""" -An empty list is assumed have type ``List[Tensor]``. -The types of other list literals are derived from the type of the members. -See `Default Types`_ for more details. - -:: - - [3, 4] - [] - [torch.rand(3), torch.rand(4)] - - - -Tuple Construction -"""""""""""""""""" -:: - - (3, 4) - (3,) - - -Dict Construction -""""""""""""""""" -An empty dict is assumed have type ``Dict[str, Tensor]``. -The types of other dict literals are derived from the type of the members. -See `Default Types`_ for more details. - -:: - - {'hello': 3} - {} - {'a': torch.rand(3), 'b': torch.rand(4)} - - -Variables -^^^^^^^^^ -See `Variable Resolution`_ for how variables are resolved. - -:: - - my_variable_name - -Arithmetic Operators -^^^^^^^^^^^^^^^^^^^^ -:: - - a + b - a - b - a * b - a / b - a ^ b - a @ b - -Comparison Operators -^^^^^^^^^^^^^^^^^^^^ -:: - - a == b - a != b - a < b - a > b - a <= b - a >= b - -Logical Operators -^^^^^^^^^^^^^^^^^ -:: - - a and b - a or b - not b - -Subscripts and Slicing -^^^^^^^^^^^^^^^^^^^^^^ -:: - - t[0] - t[-1] - t[0:2] - t[1:] - t[:1] - t[:] - t[0, 1] - t[0, 1:2] - t[0, :1] - t[-1, 1:, 0] - t[1:, -1, 0] - t[i:j, i] - -Function Calls -^^^^^^^^^^^^^^ -Calls to `builtin functions` - -:: - - torch.rand(3, dtype=torch.int) - -Calls to other script functions: - -.. testcode:: - - import torch - - @torch.jit.script - def foo(x): - return x + 1 - - @torch.jit.script - def bar(x): - return foo(x) - -Method Calls -^^^^^^^^^^^^ -Calls to methods of builtin types like tensor: ``x.mm(y)`` - -On modules, methods must be compiled before they can be called. The TorchScript -compiler recursively compiles methods it sees when compiling other methods. By default, -compilation starts on the ``forward`` method. Any methods called by ``forward`` will -be compiled, and any methods called by those methods, and so on. To start compilation at -a method other than ``forward``, use the :func:`@torch.jit.export ` decorator -(``forward`` implicitly is marked ``@torch.jit.export``). - -Calling a submodule directly (e.g. ``self.resnet(input)``) is equivalent to -calling its ``forward`` method (e.g. ``self.resnet.forward(input)``). - -.. testcode:: - :skipif: torchvision is None - - import torch - import torch.nn as nn - import torchvision - - class MyModule(nn.Module): - def __init__(self): - super().__init__() - means = torch.tensor([103.939, 116.779, 123.68]) - self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) - resnet = torchvision.models.resnet18() - self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) - - def helper(self, input): - return self.resnet(input - self.means) - - def forward(self, input): - return self.helper(input) - - # Since nothing in the model calls `top_level_method`, the compiler - # must be explicitly told to compile this method - @torch.jit.export - def top_level_method(self, input): - return self.other_helper(input) - - def other_helper(self, input): - return input + 10 - - # `my_script_module` will have the compiled methods `forward`, `helper`, - # `top_level_method`, and `other_helper` - my_script_module = torch.jit.script(MyModule()) - - -Ternary Expressions -^^^^^^^^^^^^^^^^^^^ -:: - - x if x > y else y - -Casts -^^^^^ -:: - - float(ten) - int(3.5) - bool(ten) - str(2)`` - -Accessing Module Parameters -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:: - - self.my_parameter - self.my_submodule.my_parameter - - -Statements -~~~~~~~~~~ - -TorchScript supports the following types of statements: - -Simple Assignments -^^^^^^^^^^^^^^^^^^ -:: - - a = b - a += b # short-hand for a = a + b, does not operate in-place on a - a -= b - -Pattern Matching Assignments -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:: - - a, b = tuple_or_list - a, b, *c = a_tuple - -Multiple Assignments -:: - - a = b, c = tup - -Print Statements -^^^^^^^^^^^^^^^^ -:: - - print("the result of an add:", a + b) - -If Statements -^^^^^^^^^^^^^ -:: - - if a < 4: - r = -a - elif a < 3: - r = a + a - else: - r = 3 * a - -In addition to bools, floats, ints, and Tensors can be used in a conditional -and will be implicitly casted to a boolean. - -While Loops -^^^^^^^^^^^ -:: - - a = 0 - while a < 4: - print(a) - a += 1 - - -For loops with range -^^^^^^^^^^^^^^^^^^^^ -:: - - x = 0 - for i in range(10): - x *= i - -For loops over tuples -^^^^^^^^^^^^^^^^^^^^^ -These unroll the loop, generating a body for -each member of the tuple. The body must type-check correctly for each member. - -:: - - tup = (3, torch.rand(4)) - for x in tup: - print(x) - - -For loops over constant nn.ModuleList -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To use a ``nn.ModuleList`` inside a compiled method, it must be marked -constant by adding the name of the attribute to the ``__constants__`` -list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the -loop at compile time, with each member of the constant module list. - -.. testcode:: - - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - - class MyModule(torch.nn.Module): - __constants__ = ['mods'] - - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - - - m = torch.jit.script(MyModule()) - - - -Break and Continue -^^^^^^^^^^^^^^^^^^ -:: - - for i in range(5): - if i == 1: - continue - if i == 3: - break - print(i) - -Return -^^^^^^ -:: - - return a, b - -Variable Resolution -~~~~~~~~~~~~~~~~~~~ - -TorchScript supports a subset of Python's variable resolution (i.e. scoping) -rules. Local variables behave the same as in Python, except for the restriction -that a variable must have the same type along all paths through a function. -If a variable has a different type on different branches of an if statement, it -is an error to use it after the end of the if statement. - -Similarly, a variable is not allowed to be used if it is only *defined* along some -paths through the function. - -Example: - -.. testcode:: - - @torch.jit.script - def foo(x): - if x < 0: - y = 4 - print(y) - -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - y is not defined in the false branch... - @torch.jit.script... - def foo(x): - if x < 0: - ~~~~~~~~~ - y = 4 - ~~~~~ <--- HERE - print(y) - and was used here: - if x < 0: - y = 4 - print(y) - ~ <--- HERE... - -Non-local variables are resolved to Python values at compile time when the -function is defined. These values are then converted into TorchScript values using -the rules described in `Use of Python Values`_. - -Use of Python Values -~~~~~~~~~~~~~~~~~~~~ - -To make writing TorchScript more convenient, we allow script code to refer -to Python values in the surrounding scope. For instance, any time there is a -reference to ``torch``, the TorchScript compiler is actually resolving it to the -``torch`` Python module when the function is declared. These Python values are -not a first class part of TorchScript. Instead they are de-sugared at compile-time -into the primitive types that TorchScript supports. This depends -on the dynamic type of the Python valued referenced when compilation occurs. -This section describes the rules that are used when accessing Python values in TorchScript. - -Functions -^^^^^^^^^ - -TorchScript can call Python functions. This functionality is very useful when -incrementally converting a model to TorchScript. The model can be moved function-by-function -to TorchScript, leaving calls to Python functions in place. This way you can incrementally -check the correctness of the model as you go. - - -.. autofunction:: torch.jit.is_scripting - -.. autofunction:: torch.jit.is_tracing - - -Attribute Lookup On Python Modules -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -TorchScript can lookup attributes on modules. `Builtin functions` like ``torch.add`` -are accessed this way. This allows TorchScript to call functions defined in -other modules. - -.. _constant: - -Python-defined Constants -^^^^^^^^^^^^^^^^^^^^^^^^ -TorchScript also provides a way to use constants that are defined in Python. -These can be used to hard-code hyper-parameters into the function, or to -define universal constants. There are two ways of specifying that a Python -value should be treated as a constant. - -1. Values looked up as attributes of a module are assumed to be constant: - -.. testcode:: - - import math - import torch - - @torch.jit.script - def fn(): - return math.pi - -2. Attributes of a ScriptModule can be marked constant by annotating them with ``Final[T]`` - -:: - - import torch - import torch.nn as nn - - class Foo(nn.Module): - # `Final` from the `typing_extensions` module can also be used - a : torch.jit.Final[int] - - def __init__(self): - super().__init__() - self.a = 1 + 4 - - def forward(self, input): - return self.a + input - - f = torch.jit.script(Foo()) - -Supported constant Python types are - -* ``int`` -* ``float`` -* ``bool`` -* ``torch.device`` -* ``torch.layout`` -* ``torch.dtype`` -* tuples containing supported types -* ``torch.nn.ModuleList`` which can be used in a TorchScript for loop - - - - -.. _module attributes: - -Module Attributes -^^^^^^^^^^^^^^^^^ - -The ``torch.nn.Parameter`` wrapper and ``register_buffer`` can be used to assign -tensors to a module. Other values assigned to a module that is compiled -will be added to the compiled module if their types can be inferred. All `types`_ -available in TorchScript can be used as module attributes. Tensor attributes are -semantically the same as buffers. The type of empty lists and dictionaries and ``None`` -values cannot be inferred and must be specified via -`PEP 526-style `_ class annotations. -If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute -to the resulting :class:`ScriptModule`. - -Example: - -.. testcode:: - - from typing import List, Dict - - class Foo(nn.Module): - # `words` is initialized as an empty list, so its type must be specified - words: List[str] - - # The type could potentially be inferred if `a_dict` (below) was not - # empty, but this annotation ensures `some_dict` will be made into the - # proper type - some_dict: Dict[str, int] - - def __init__(self, a_dict): - super().__init__() - self.words = [] - self.some_dict = a_dict - - # `int`s can be inferred - self.my_int = 10 - - def forward(self, input): - # type: (str) -> int - self.words.append(input) - return self.some_dict[input] + self.my_int - - f = torch.jit.script(Foo({'hi': 2})) diff --git a/docs/source/jit_language_reference_v2.md b/docs/source/jit_language_reference_v2.md new file mode 100644 index 00000000000000..12bd2a18a201c3 --- /dev/null +++ b/docs/source/jit_language_reference_v2.md @@ -0,0 +1,1854 @@ +```{eval-rst} +.. testsetup:: + + # These are hidden from the docs, but these are necessary for `doctest` + # since the `inspect` module doesn't play nicely with the execution + # environment for `doctest` + import torch + + original_script = torch.jit.script + def script_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_script(obj, *args, **kwargs) + + torch.jit.script = script_wrapper + + original_trace = torch.jit.trace + def trace_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_trace(obj, *args, **kwargs) + + torch.jit.trace = trace_wrapper +``` + +(language-reference-v2)= + +# TorchScript Language Reference + +This reference manual describes the syntax and core semantics of the TorchScript language. +TorchScript is a statically typed subset of the Python language. This document explains the supported features of +Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in +this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to +represent neural network models in PyTorch. + +```{contents} +:depth: 1 +:local: true +``` + +(type-system)= + +## Terminology + +This document uses the following terminologies: + +```{eval-rst} +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Pattern + - Notes + * - ``::=`` + - Indicates that the given symbol is defined as. + * - ``" "`` + - Represents real keywords and delimiters that are part of the syntax. + * - ``A | B`` + - Indicates either A or B. + * - ``( )`` + - Indicates grouping. + * - ``[]`` + - Indicates optional. + * - ``A+`` + - Indicates a regular expression where term A is repeated at least once. + * - ``A*`` + - Indicates a regular expression where term A is repeated zero or more times. +``` + +## Type System + +TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express +neural net models. + +### TorchScript Types + +The TorchScript type system consists of `TSType` and `TSModuleType` as defined below. + +``` +TSAllType ::= TSType | TSModuleType +TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType +``` + +`TSType` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. +`TSType` refers to any of the following: + +- Meta Types, e.g., `Any` +- Primitive Types, e.g., `int`, `float`, and `str` +- Structural Types, e.g., `Optional[int]` or `List[MyClass]` +- Nominal Types (Python classes), e.g., `MyClass` (user-defined), `torch.tensor` (built-in) + +`TSModuleType` represents `torch.nn.Module` and its subclasses. It is treated differently from `TSType` because its type schema is inferred partly from the object instance and partly from the class definition. +As such, instances of a `TSModuleType` may not follow the same static type schema. `TSModuleType` cannot be used as a TorchScript type annotation or be composed with `TSType` for type safety considerations. + +### Meta Types + +Meta types are so abstract that they are more like type constraints than concrete types. +Currently TorchScript defines one meta-type, `Any`, that represents any TorchScript type. + +#### `Any` Type + +The `Any` type represents any TorchScript type. `Any` specifies no type constraints, thus there is no type-checking on `Any`. +As such it can be bound to any Python or TorchScript data types (e.g., `int`, TorchScript `tuple`, or an arbitrary Python class that is not scripted). + +``` +TSMetaType ::= "Any" +``` + +Where: + +- `Any` is the Python class name from the typing module. Therefore, to use the `Any` type, you must import it from `typing` (e.g., `from typing import Any`). +- Since `Any` can represent any TorchScript type, the set of operators that are allowed to operate on values of this type on `Any` is limited. + +#### Operators Supported for `Any` Type + +- Assignment to data of `Any` type. +- Binding to parameter or return of `Any` type. +- `x is`, `x is not` where `x` is of `Any` type. +- `isinstance(x, Type)` where `x` is of `Any` type. +- Data of `Any` type is printable. +- Data of `List[Any]` type may be sortable if the data is a list of values of the same type `T` and that `T` supports comparison operators. + +**Compared to Python** + +`Any` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the +`Object` class in Python. However, `Any` only supports a subset of the operators and methods that are supported by `Object`. + +#### Design Notes + +When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described +by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary +scripting failures. `Any` is introduced to describe the type of the data where precise static types are not necessary for compilation. + +**Example 1** + +This example illustrates how `Any` can be used to allow the second element of the tuple parameter to be of any type. This is possible +because `x[1]` is not involved in any computation that requires knowing its precise type. + +```{eval-rst} +.. testcode:: + + import torch + + from typing import Tuple + from typing import Any + + @torch.jit.export + def inc_first_element(x: Tuple[int, Any]): + return (x[0]+1, x[1]) + + m = torch.jit.script(inc_first_element) + print(m((1,2.0))) + print(m((1,(100,200)))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + (2, 2.0) + (2, (100, 200)) +``` + +The second element of the tuple is of `Any` type, thus can bind to multiple types. +For example, `(1, 2.0)` binds a float type to `Any` as in `Tuple[int, Any]`, +whereas `(1, (100, 200))` binds a tuple to `Any` in the second invocation. + +**Example 2** + +This example illustrates how we can use `isinstance` to dynamically check the type of the data that is annotated as `Any` type: + +```{eval-rst} +.. testcode:: + + import torch + from typing import Any + + def f(a:Any): + print(a) + return (isinstance(a, torch.Tensor)) + + ones = torch.ones([2]) + m = torch.jit.script(f) + print(m(ones)) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + 1 + 1 + [ CPUFloatType{2} ] + True +``` + +### Primitive Types + +Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined +type name. + +``` +TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" +``` + +### Structural Types + +Structural types are types that are structurally defined without a user-defined name (unlike nominal types), +such as `Future[int]`. Structural types are composable with any `TSType`. + +``` +TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | + TSOptional | TSUnion | TSFuture | TSRRef | TSAwait + +TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" +TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" +TSList ::= "List" "[" TSType "]" +TSOptional ::= "Optional" "[" TSType "]" +TSUnion ::= "Union" "[" (TSType ",")* TSType "]" +TSFuture ::= "Future" "[" TSType "]" +TSRRef ::= "RRef" "[" TSType "]" +TSAwait ::= "Await" "[" TSType "]" +TSDict ::= "Dict" "[" KeyType "," TSType "]" +KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" +``` + +Where: + +- `Tuple`, `List`, `Optional`, `Union`, `Future`, `Dict` represent Python type class names that are defined in the module `typing`. To use these type names, you must import them from `typing` (e.g., `from typing import Tuple`). +- `namedtuple` represents the Python class `collections.namedtuple` or `typing.NamedTuple`. +- `Future` and `RRef` represent the Python classes `torch.futures` and `torch.distributed.rpc`. +- `Await` represent the Python class `torch._awaits._Await` + +**Compared to Python** + +Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. + +**Example 1** + +This example uses `typing.NamedTuple` syntax to define a tuple: + +```{eval-rst} +.. testcode:: + + import torch + from typing import NamedTuple + from typing import Tuple + + class MyTuple(NamedTuple): + first: int + second: int + + def inc(x: MyTuple) -> Tuple[int, int]: + return (x.first+1, x.second+1) + + t = MyTuple(first=1, second=2) + scripted_inc = torch.jit.script(inc) + print("TorchScript:", scripted_inc(t)) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + TorchScript: (2, 3) +``` + +**Example 2** + +This example uses `collections.namedtuple` syntax to define a tuple: + +```{eval-rst} +.. testcode:: + + import torch + from typing import NamedTuple + from typing import Tuple + from collections import namedtuple + + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) + + def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: + return (x.first+1, x.second+1) + + m = torch.jit.script(inc) + print(inc(_UnannotatedNamedTuple(1,2))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + (2, 3) +``` + +**Example 3** + +This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type +classes from the `typing` module: + +```python +import torch + +# ERROR: Tuple not recognized because not imported from typing +@torch.jit.export +def inc(x: Tuple[int, int]): + return (x[0]+1, x[1]+1) + +m = torch.jit.script(inc) +print(m((1,2))) +``` + +Running the above code yields the following scripting error: + +```python +File "test-tuple.py", line 5, in + def inc(x: Tuple[int, int]): +NameError: name 'Tuple' is not defined +``` + +The remedy is to add the line `from typing import Tuple` to the beginning of the code. + +### Nominal Types + +Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom +name and are compared using class names. Nominal classes are further classified into the following categories: + +``` +TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum +``` + +Among them, `TSCustomClass` and `TSEnum` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. + +### Built-in Class + +Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). +TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or +attributes of its Python class definition. + +``` +TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | + "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... +TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | + "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor +``` + +#### Special Note on torch.nn.ModuleList and torch.nn.ModuleDict + +Although `torch.nn.ModuleList` and `torch.nn.ModuleDict` are defined as a list and dictionary in Python, +they behave more like tuples in TorchScript: + +- In TorchScript, instances of `torch.nn.ModuleList` or `torch.nn.ModuleDict` are immutable. +- Code that iterates over `torch.nn.ModuleList` or `torch.nn.ModuleDict` is completely unrolled so that elements of `torch.nn.ModuleList` or keys of `torch.nn.ModuleDict` can be of different subclasses of `torch.nn.Module`. + +**Example** + +The following example highlights the use of a few built-in Torchscript classes (`torch.*`): + +```python +import torch + +@torch.jit.script +class A: + def __init__(self): + self.x = torch.rand(3) + + def f(self, y: torch.device): + return self.x.to(device=y) + +def g(): + a = A() + return a.f(torch.device("cpu")) + +script_g = torch.jit.script(g) +print(script_g.graph) +``` + +### Custom Class + +Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. + +``` +TSClassDef ::= [ "@torch.jit.script" ] + "class" ClassName [ "(object)" ] ":" + MethodDefinition | + [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] + MethodDefinition +``` + +Where: + +- Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from the object. +- Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the `__init__()` method. +- Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). +- `MethodDefinition` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules, (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). +- `torch.jit.ignore` and `torch.jit.unused` can be used to ignore the method or function that is not fully torchscriptable or should be ignored by the compiler. + +**Compared to Python** + +TorchScript custom classes are quite limited compared to their Python counterpart. Torchscript custom classes: + +- Do not support class attributes. +- Do not support subclassing except for subclassing an interface type or object. +- Do not support method overloading. +- Must initialize all its instance attributes in `__init__()`; this is because TorchScript constructs a static schema of the class by inferring attribute types in `__init__()`. +- Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. + +**Example 1** + +Python classes can be used in TorchScript if they are annotated with `@torch.jit.script`, similar to how a TorchScript function would be declared: + +```python +@torch.jit.script +class MyClass: + def __init__(self, x: int): + self.x = x + + def inc(self, val: int): + self.x += val +``` + +**Example 2** + +A TorchScript custom class type must "declare" all its instance attributes by assignments in `__init__()`. If an instance attribute is not defined in `__init__()` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: + +```python +import torch + +@torch.jit.script +class foo: + def __init__(self): + self.y = 1 + +# ERROR: self.x is not defined in __init__ +def assign_x(self): + self.x = torch.rand(2, 3) +``` + +The class will fail to compile and issue the following error: + +``` +RuntimeError: +Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: +def assign_x(self): + self.x = torch.rand(2, 3) + ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE +``` + +**Example 3** + +In this example, a TorchScript custom class defines a class variable name, which is not allowed: + +```python +import torch + +@torch.jit.script +class MyClass(object): + name = "MyClass" + def __init__(self, x: int): + self.x = x + +def fn(a: MyClass): + return a.name +``` + +It leads to the following compile-time error: + +``` +RuntimeError: +'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: + File "test-class2.py", line 10 +def fn(a: MyClass): + return a.name + ~~~~~~ <--- HERE +``` + +### Enum Type + +Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. + +``` +TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" + ( MemberIdentifier "=" Value )+ + ( MethodDefinition )* +``` + +Where: + +- Value must be a TorchScript literal of type `int`, `float`, or `str`, and must be of the same TorchScript type. +- `TSEnumType` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted `Enum` subclassing, that is, subclassing an enumerated is allowed only if it does not define any members. + +**Compared to Python** + +- TorchScript supports only `enum.Enum`. It does not support other variations such as `enum.IntEnum`, `enum.Flag`, `enum.IntFlag`, and `enum.auto`. +- Values of TorchScript enum members must be of the same type and can only be `int`, `float`, or `str` types, whereas Python enum members can be of any type. +- Enums containing methods are ignored in TorchScript. + +**Example 1** + +The following example defines the class `Color` as an `Enum` type: + +```python +import torch +from enum import Enum + +class Color(Enum): + RED = 1 + GREEN = 2 + +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + return x == y + +m = torch.jit.script(enum_fn) + +print("Eager: ", enum_fn(Color.RED, Color.GREEN)) +print("TorchScript: ", m(Color.RED, Color.GREEN)) +``` + +**Example 2** + +The following example shows the case of restricted enum subclassing, where `BaseColor` does not define any member, thus can be subclassed by `Color`: + +```python +import torch +from enum import Enum + +class BaseColor(Enum): + def foo(self): + pass + +class Color(BaseColor): + RED = 1 + GREEN = 2 + +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + return x == y + +m = torch.jit.script(enum_fn) + +print("TorchScript: ", m(Color.RED, Color.GREEN)) +print("Eager: ", enum_fn(Color.RED, Color.GREEN)) +``` + +### TorchScript Module Class + +`TSModuleType` is a special class type that is inferred from object instances that are created outside TorchScript. `TSModuleType` is named by the Python class of the object instance. The `__init__()` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. + +The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from `__init__()` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. + +In this sense, `TSModuleType` is not really a static type. Therefore, for type safety considerations, `TSModuleType` cannot be used in a TorchScript type annotation or be composed with `TSType`. + +### Module Instance Class + +TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to `forward`). The Python module class is treated as a module instance class, so the `__init__()` method of the Python module class is not subject to the type-checking rules of TorchScript. + +``` +TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" + ClassBodyDefinition +``` + +Where: + +- `forward()` and other methods decorated with `@torch.jit.export` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. + +Unlike custom classes, only the forward method and other methods decorated with `@torch.jit.export` of the module type need to be compilable. Most notably, `__init__()` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into `torch.jit.script(ModuleObj)`. + +**Example 1** + +This example illustrates a few features of module types: + +- The `TestModule` instance is created outside the scope of TorchScript (i.e., before invoking `torch.jit.script`). +- `__init__()` is not considered a TorchScript method, therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the `__init__()` method of an instance class cannot be invoked in TorchScript code. Because `TestModule` instances are instantiated in Python, in this example, `TestModule(2.0)` and `TestModule(2)` create two instances with different types for its data attributes. `self.x` is of type `float` for `TestModule(2.0)`, whereas `self.y` is of type `int` for `TestModule(2.0)`. +- TorchScript automatically compiles other methods (e.g., `mul()`) invoked by methods annotated via `@torch.jit.export` or `forward()` methods. +- Entry-points to a TorchScript program are either `forward()` of a module type, functions annotated as `torch.jit.script`, or methods annotated as `torch.jit.export`. + +```{eval-rst} +.. testcode:: + + import torch + + class TestModule(torch.nn.Module): + def __init__(self, v): + super().__init__() + self.x = v + + def forward(self, inc: int): + return self.x + inc + + m = torch.jit.script(TestModule(1)) + print(f"First instance: {m(3)}") + + m = torch.jit.script(TestModule(torch.ones([5]))) + print(f"Second instance: {m(3)}") +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + First instance: 4 + Second instance: tensor([4., 4., 4., 4., 4.]) +``` + +**Example 2** + +The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of `TestModule` inside the scope of TorchScript: + +```{eval-rst} +.. testcode:: + + import torch + + class TestModule(torch.nn.Module): + def __init__(self, v): + super().__init__() + self.x = v + + def forward(self, x: int): + return self.x + x + + class MyModel: + def __init__(self, v: int): + self.val = v + + @torch.jit.export + def doSomething(self, val: int) -> int: + # error: should not invoke the constructor of module type + myModel = TestModule(self.val) + return myModel(val) + + # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError + # RuntimeError: Could not get name of python class object +``` + +(type-annotation)= + +## Type Annotation + +Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or +instance data attribute has a static type, and every function and method has a statically typed signature. + +### When to Annotate Types + +In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to +methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type +may be too restrictive, e.g., `x` being inferred as `NoneType` through assignment `x = None`, whereas `x` is actually used as an `Optional`. In such +cases, type annotations may be needed to overwrite auto inference, e.g., `x: Optional[int] = None`. Note that it is always safe to type annotate a local variable +or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. + +When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a +default type of `TensorType`, `List[TensorType]`, or `Dict[str, TensorType]`. + +### Annotate Function Signature + +Since a parameter may not be automatically inferred from the body of the function (including both functions and methods), they need to be type annotated. Otherwise, they assume the default type `TensorType`. + +TorchScript supports two styles for method and function signature type annotation: + +- **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of `TensorType`), or allows the return type to be left unannotated (whose type will be automatically inferred). + +``` +Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" + FuncOrMethodBody +ParamAnnot ::= Identifier [ ":" TSType ] "," +ReturnAnnot ::= "->" TSType +``` + +Note that when using Python3 style, the type `self` is automatically inferred and should not be annotated. + +- **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. + +``` +MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] +ParamAnnot ::= TSType "," +ReturnAnnot ::= "->" TSType +``` + +**Example 1** + +In this example: + +- `a` is not annotated and assumes the default type of `TensorType`. +- `b` is annotated as type `int`. +- The return type is not annotated and is automatically inferred as type `TensorType` (based on the type of the value being returned). + +```python +import torch + +def f(a, b: int): + return a+b + +m = torch.jit.script(f) +print("TorchScript:", m(torch.ones([6]), 100)) +``` + +**Example 2** + +The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of +them assume the default type. + +```python +import torch + +def f(a, b): + # type: (torch.Tensor, int) → torch.Tensor + return a+b + +m = torch.jit.script(f) +print("TorchScript:", m(torch.ones([6]), 100)) +``` + +### Annotate Variables and Data Attributes + +In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. +Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as `None` or `TensorType`), then they may need to be explicitly +type annotated as a *wider* type such as `Optional[int]` or `Any`. + +#### Local Variables + +Local variables can be annotated according to Python3 typing module annotation rules, i.e., + +``` +LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr +``` + +In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables +that may be associated with different concrete types. Typical multi-types include `Optional[T]` and `Any`. + +**Example** + +```python +import torch + +def f(a, setVal: bool): + value: Optional[torch.Tensor] = None + if setVal: + value = a + return value + +ones = torch.ones([6]) +m = torch.jit.script(f) +print("TorchScript:", m(ones, True), m(ones, False)) +``` + +#### Instance Data Attributes + +For `ModuleType` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final +via `Final`. + +``` +"class" ClassIdentifier "(torch.nn.Module):" +InstanceAttrIdentifier ":" ["Final("] TSType [")"] +... +``` + +Where: + +- `InstanceAttrIdentifier` is the name of an instance attribute. +- `Final` indicates that the attribute cannot be re-assigned outside of `__init__` or overridden in subclasses. + +**Example** + +```python +import torch + +class MyModule(torch.nn.Module): + offset_: int + +def __init__(self, offset): + self.offset_ = offset + +... +``` + +### Type Annotation APIs + +#### `torch.jit.annotate(T, expr)` + +This API annotates type `T` to an expression `expr`. This is often used when the default type of an expression is not the type intended by the programmer. +For instance, an empty list (dictionary) has the default type of `List[TensorType]` (`Dict[TensorType, TensorType]`), but sometimes it may be used to initialize +a list of some other types. Another common use case is for annotating the return type of `tensor.tolist()`. Note, however, that it cannot be used to annotate +the type of a module attribute in `__init__`; `torch.jit.Attribute` should be used for this instead. + +**Example** + +In this example, `[]` is declared as a list of integers via `torch.jit.annotate` (instead of assuming `[]` to be the default type of `List[TensorType]`): + +```python +import torch +from typing import List + +def g(l: List[int], val: int): + l.append(val) + return l + +def f(val: int): + l = g(torch.jit.annotate(List[int], []), val) + return l + +m = torch.jit.script(f) +print("Eager:", f(3)) +print("TorchScript:", m(3)) +``` + +See {meth}`torch.jit.annotate` for more information. + +### Type Annotation Appendix + +#### TorchScript Type System Definition + +``` +TSAllType ::= TSType | TSModuleType +TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType + +TSMetaType ::= "Any" +TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" + +TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | + TSUnion | TSFuture | TSRRef | TSAwait +TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" +TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" +TSList ::= "List" "[" TSType "]" +TSOptional ::= "Optional" "[" TSType "]" +TSUnion ::= "Union" "[" (TSType ",")* TSType "]" +TSFuture ::= "Future" "[" TSType "]" +TSRRef ::= "RRef" "[" TSType "]" +TSAwait ::= "Await" "[" TSType "]" +TSDict ::= "Dict" "[" KeyType "," TSType "]" +KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" + +TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum +TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| + "torch.dtype" | "torch.nn.ModuleList" | + "torch.nn.ModuleDict" | ... +TSTensor ::= "torch.tensor" and subclasses +``` + +#### Unsupported Typing Constructs + +TorchScript does not support all features and types of the Python3 [typing](https://docs.python.org/3/library/typing.html#module-typing) module. +Any functionality from the [typing](https://docs.python.org/3/library/typing.html#module-typing) module that is not explicitly specified in this +documentation is unsupported. The following table summarizes `typing` constructs that are either unsupported or supported with restrictions in TorchScript. + +```{eval-rst} +============================= ================ + Item Description +----------------------------- ---------------- +``typing.Any`` In development +``typing.NoReturn`` Not supported +``typing.Callable`` Not supported +``typing.Literal`` Not supported +``typing.ClassVar`` Not supported +``typing.Final`` Supported for module attributes, class attribute, and annotations, but not for functions. +``typing.AnyStr`` Not supported +``typing.overload`` In development +Type aliases Not supported +Nominal typing In development +Structural typing Not supported +NewType Not supported +Generics Not supported +============================= ================ +``` + +(expressions)= + +## Expressions + +The following section describes the grammar of expressions that are supported in TorchScript. +It is modeled after [the expressions chapter of the Python language reference](https://docs.python.org/3/reference/expressions.html). + +### Arithmetic Conversions + +There are a number of implicit type conversions that are performed in TorchScript: + +- A `Tensor` with a `float` or `int` data type can be implicitly converted to an instance of `FloatType` or `IntType` provided that it has a size of 0, does not have `require_grad` set to `True`, and will not require narrowing. +- Instances of `StringType` can be implicitly converted to `DeviceType`. +- The implicit conversion rules from the two bullet points above can be applied to instances of `TupleType` to produce instances of `ListType` with the appropriate contained type. + +Explicit conversions can be invoked using the `float`, `int`, `bool`, and `str` built-in functions +that accept primitive data types as arguments and can accept user-defined types if they implement +`__bool__`, `__str__`, etc. + +### Atoms + +Atoms are the most basic elements of expressions. + +``` +atom ::= identifier | literal | enclosure +enclosure ::= parenth_form | list_display | dict_display +``` + +#### Identifiers + +The rules that dictate what is a legal identifier in TorchScript are the same as +their [Python counterparts](https://docs.python.org/3/reference/lexical_analysis.html#identifiers). + +#### Literals + +``` +literal ::= stringliteral | integer | floatnumber +``` + +Evaluation of a literal yields an object of the appropriate type with the specific value +(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations +of identical literals may obtain the same object or distinct objects with the same value. +[stringliteral](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals), +[integer](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals), and +[floatnumber](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) +are defined in the same way as their Python counterparts. + +#### Parenthesized Forms + +``` +parenth_form ::= '(' [expression_list] ')' +``` + +A parenthesized expression list yields whatever the expression list yields. If the list contains at least one +comma, it yields a `Tuple`; otherwise, it yields the single expression inside the expression list. An empty +pair of parentheses yields an empty `Tuple` object (`Tuple[]`). + +#### List and Dictionary Displays + +``` +list_comprehension ::= expression comp_for +comp_for ::= 'for' target_list 'in' or_expr +list_display ::= '[' [expression_list | list_comprehension] ']' +dict_display ::= '{' [key_datum_list | dict_comprehension] '}' +key_datum_list ::= key_datum (',' key_datum)* +key_datum ::= expression ':' expression +dict_comprehension ::= key_datum comp_for +``` + +Lists and dicts can be constructed by either listing the container contents explicitly or by providing +instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension +is semantically equivalent to using a for loop and appending to an ongoing list. +Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the +enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list +are evaluated left-to-right. If a key is repeated in a `dict_display` that has a `key_datum_list`, the +resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. + +### Primaries + +``` +primary ::= atom | attributeref | subscription | slicing | call +``` + +#### Attribute References + +``` +attributeref ::= primary '.' identifier +``` + +The `primary` must evaluate to an object of a type that supports attribute references that have an attribute named +`identifier`. + +#### Subscriptions + +``` +subscription ::= primary '[' expression_list ']' +``` + +The `primary` must evaluate to an object that supports subscription. + +- If the primary is a `List`, `Tuple`, or `str`, the expression list must evaluate to an integer or slice. +- If the primary is a `Dict`, the expression list must evaluate to an object of the same type as the key type of the `Dict`. +- If the primary is a `ModuleList`, the expression list must be an `integer` literal. +- If the primary is a `ModuleDict`, the expression must be a `stringliteral`. + +#### Slicings + +A slicing selects a range of items in a `str`, `Tuple`, `List`, or `Tensor`. Slicings may be used as +expressions or targets in assignment or `del` statements. + +``` +slicing ::= primary '[' slice_list ']' +slice_list ::= slice_item (',' slice_item)* [','] +slice_item ::= expression | proper_slice +proper_slice ::= [expression] ':' [expression] [':' [expression] ] +``` + +Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an +object of type `Tensor`. + +#### Calls + +``` +call ::= primary '(' argument_list ')' +argument_list ::= args [',' kwargs] | kwargs +args ::= [arg (',' arg)*] +kwargs ::= [kwarg (',' kwarg)*] +kwarg ::= arg '=' expression +arg ::= identifier +``` + +The `primary` must desugar or evaluate to a callable object. All argument expressions are evaluated +before the call is attempted. + +### Power Operator + +``` +power ::= primary ['**' u_expr] +``` + +The power operator has the same semantics as the built-in pow function (not supported); it computes its +left argument raised to the power of its right argument. It binds more tightly than unary operators on the +left, but less tightly than unary operators on the right; i.e. `-2 ** -3 == -(2 ** (-3))`. The left and right +operands can be `int`, `float` or `Tensor`. Scalars are broadcast in the case of scalar-tensor/tensor-scalar +exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting. + +### Unary and Arithmetic Bitwise Operations + +``` +u_expr ::= power | '-' power | '~' power +``` + +The unary `-` operator yields the negation of its argument. The unary `~` operator yields the bitwise inversion +of its argument. `-` can be used with `int`, `float`, and `Tensor` of `int` and `float`. +`~` can only be used with `int` and `Tensor` of `int`. + +### Binary Arithmetic Operations + +``` +m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr +a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr +``` + +The binary arithmetic operators can operate on `Tensor`, `int`, and `float`. For tensor-tensor ops, both arguments must +have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the +tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. +The `@` operator is for matrix multiplication and only operates on `Tensor` arguments. The multiplication operator +(`*`) can be used with a list and integer in order to get a result that is the original list repeated a certain +number of times. + +### Shifting Operations + +``` +shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr +``` + +These operators accept two `int` arguments, two `Tensor` arguments, or a `Tensor` argument and an `int` or +`float` argument. In all cases, a right shift by `n` is defined as floor division by `pow(2, n)`, and a left shift +by `n` is defined as multiplication by `pow(2, n)`. When both arguments are `Tensors`, they must have the same +shape. When one is a scalar and the other is a `Tensor`, the scalar is logically broadcast to match the size of +the `Tensor`. + +### Binary Bitwise Operations + +``` +and_expr ::= shift_expr | and_expr '&' shift_expr +xor_expr ::= and_expr | xor_expr '^' and_expr +or_expr ::= xor_expr | or_expr '|' xor_expr +``` + +The `&` operator computes the bitwise AND of its arguments, the `^` the bitwise XOR, and the `|` the bitwise OR. +Both operands must be `int` or `Tensor`, or the left operand must be `Tensor` and the right operand must be +`int`. When both operands are `Tensor`, they must have the same shape. When the right operand is `int`, and +the left operand is `Tensor`, the right operand is logically broadcast to match the shape of the `Tensor`. + +### Comparisons + +``` +comparison ::= or_expr (comp_operator or_expr)* +comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' +``` + +A comparison yields a boolean value (`True` or `False`), or if one of the operands is a `Tensor`, a boolean +`Tensor`. Comparisons can be chained arbitrarily as long as they do not yield boolean `Tensors` that have more +than one element. `a op1 b op2 c ...` is equivalent to `a op1 b and b op2 c and ...`. + +#### Value Comparisons + +The operators `<`, `>`, `==`, `>=`, `<=`, and `!=` compare the values of two objects. The two objects generally need to be of +the same type, unless there is an implicit type conversion available between the objects. User-defined types can +be compared if rich comparison methods (e.g., `__lt__`) are defined on them. Built-in type comparison works like +Python: + +- Numbers are compared mathematically. +- Strings are compared lexicographically. +- `lists`, `tuples`, and `dicts` can be compared only to other `lists`, `tuples`, and `dicts` of the same type and are compared using the comparison operator of corresponding elements. + +#### Membership Test Operations + +The operators `in` and `not in` test for membership. `x in s` evaluates to `True` if `x` is a member of `s` and `False` otherwise. +`x not in s` is equivalent to `not x in s`. This operator is supported for `lists`, `dicts`, and `tuples`, and can be used with +user-defined types if they implement the `__contains__` method. + +#### Identity Comparisons + +For all types except `int`, `double`, `bool`, and `torch.device`, operators `is` and `is not` test for the object’s identity; +`x is y` is `True` if and only if `x` and `y` are the same object. For all other types, `is` is equivalent to +comparing them using `==`. `x is not y` yields the inverse of `x is y`. + +### Boolean Operations + +``` +or_test ::= and_test | or_test 'or' and_test +and_test ::= not_test | and_test 'and' not_test +not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test +``` + +User-defined objects can customize their conversion to `bool` by implementing a `__bool__` method. The operator `not` +yields `True` if its operand is false, `False` otherwise. The expression `x` and `y` first evaluates `x`; if it is `False`, its +value (`False`) is returned; otherwise, `y` is evaluated and its value is returned (`False` or `True`). The expression `x` or `y` +first evaluates `x`; if it is `True`, its value (`True`) is returned; otherwise, `y` is evaluated and its value is returned +(`False` or `True`). + +### Conditional Expressions + +``` +conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] +expression ::= conditional_expression +``` + +The expression `x if c else y` first evaluates the condition `c` rather than x. If `c` is `True`, `x` is +evaluated and its value is returned; otherwise, `y` is evaluated and its value is returned. As with if-statements, +`x` and `y` must evaluate to a value of the same type. + +### Expression Lists + +``` +expression_list ::= expression (',' expression)* [','] +starred_item ::= '*' primary +``` + +A starred item can only appear on the left-hand side of an assignment statement, e.g., `a, *b, c = ...`. + +% statements: + +## Simple Statements + +The following section describes the syntax of simple statements that are supported in TorchScript. +It is modeled after [the simple statements chapter of the Python language reference](https://docs.python.org/3/reference/simple_stmts.html). + +### Expression Statements + +``` +expression_stmt ::= starred_expression +starred_expression ::= expression | (starred_item ",")* [starred_item] +starred_item ::= assignment_expression | "*" or_expr +``` + +### Assignment Statements + +``` +assignment_stmt ::= (target_list "=")+ (starred_expression) +target_list ::= target ("," target)* [","] +target ::= identifier + | "(" [target_list] ")" + | "[" [target_list] "]" + | attributeref + | subscription + | slicing + | "*" target +``` + +### Augmented Assignment Statements + +``` +augmented_assignment_stmt ::= augtarget augop (expression_list) +augtarget ::= identifier | attributeref | subscription +augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | + "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" +``` + +### Annotated Assignment Statements + +``` +annotated_assignment_stmt ::= augtarget ":" expression + ["=" (starred_expression)] +``` + +### The `raise` Statement + +``` +raise_stmt ::= "raise" [expression ["from" expression]] +``` + +Raise statements in TorchScript do not support `try\except\finally`. + +### The `assert` Statement + +``` +assert_stmt ::= "assert" expression ["," expression] +``` + +Assert statements in TorchScript do not support `try\except\finally`. + +### The `return` Statement + +``` +return_stmt ::= "return" [expression_list] +``` + +Return statements in TorchScript do not support `try\except\finally`. + +### The `del` Statement + +``` +del_stmt ::= "del" target_list +``` + +### The `pass` Statement + +``` +pass_stmt ::= "pass" +``` + +### The `print` Statement + +``` +print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")" +``` + +### The `break` Statement + +``` +break_stmt ::= "break" +``` + +### The `continue` Statement: + +``` +continue_stmt ::= "continue" +``` + +## Compound Statements + +The following section describes the syntax of compound statements that are supported in TorchScript. +The section also highlights how Torchscript differs from regular Python statements. +It is modeled after [the compound statements chapter of the Python language reference](https://docs.python.org/3/reference/compound_stmts.html). + +### The `if` Statement + +Torchscript supports both basic `if/else` and ternary `if/else`. + +#### Basic `if/else` Statement + +``` +if_stmt ::= "if" assignment_expression ":" suite + ("elif" assignment_expression ":" suite) + ["else" ":" suite] +``` + +`elif` statements can repeat for an arbitrary number of times, but it needs to be before `else` statement. + +#### Ternary `if/else` Statement + +``` +if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list] +``` + +**Example 1** + +A `tensor` with 1 dimension is promoted to `bool`: + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def fn(x: torch.Tensor): + if x: # The tensor gets promoted to bool + return True + return False + print(fn(torch.rand(1))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + True +``` + +**Example 2** + +A `tensor` with multi dimensions are not promoted to `bool`: + +```python +import torch + +# Multi dimensional Tensors error out. + +@torch.jit.script +def fn(): + if torch.rand(2): + print("Tensor is available") + + if torch.rand(4,5,6): + print("Tensor is available") + +print(fn()) +``` + +Running the above code yields the following `RuntimeError`. + +``` +RuntimeError: The following operation failed in the TorchScript interpreter. +Traceback of TorchScript (most recent call last): +@torch.jit.script +def fn(): + if torch.rand(2): + ~~~~~~~~~~~~ <--- HERE + print("Tensor is available") +RuntimeError: Boolean value of Tensor with more than one value is ambiguous +``` + +If a conditional variable is annotated as `final`, either the true or false branch is evaluated depending on the evaluation of the conditional variable. + +**Example 3** + +In this example, only the True branch is evaluated, since `a` is annotated as `final` and set to `True`: + +```python +import torch + +a : torch.jit.final[Bool] = True + +if a: + return torch.empty(2,3) +else: + return [] +``` + +### The `while` Statement + +``` +while_stmt ::= "while" assignment_expression ":" suite +``` + +`while...else` statements are not supported in Torchscript. It results in a `RuntimeError`. + +### The `for-in` Statement + +``` +for_stmt ::= "for" target_list "in" expression_list ":" suite + ["else" ":" suite] +``` + +`for...else` statements are not supported in Torchscript. It results in a `RuntimeError`. + +**Example 1** + +For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. + +```{eval-rst} +.. testcode:: + + import torch + from typing import Tuple + + @torch.jit.script + def fn(): + tup = (3, torch.ones(4)) + for x in tup: + print(x) + + fn() +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + 3 + 1 + 1 + 1 + 1 + [ CPUFloatType{4} ] + +``` + +**Example 2** + +For loops on lists: for loops over a `nn.ModuleList` will unroll the body of the loop at compile time, with each member of the module list. + +```python +class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2)) + + def forward(self, input): + return self.weight + input + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) + + def forward(self, v): + for module in self.mods: + v = module(v) + return v + +model = torch.jit.script(MyModule()) +``` + +### The `with` Statement + +The `with` statement is used to wrap the execution of a block with methods defined by a context manager. + +``` +with_stmt ::= "with" with_item ("," with_item) ":" suite +with_item ::= expression ["as" target] +``` + +- If a target was included in the `with` statement, the return value from the context manager’s `__enter__()` is assigned to it. Unlike python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to `__exit__()`. Three `None` arguments are supplied. +- `try`, `except`, and `finally` statements are not supported inside `with` blocks. +- Exceptions raised within `with` block cannot be suppressed. + +### The `tuple` Statement + +``` +tuple_stmt ::= tuple([iterables]) +``` + +- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. +- You cannot convert a List to Tuple by using this built-in function. + +Unpacking all outputs into a tuple is covered by: + +``` +abc = func() # Function that returns a tuple +a,b = func() +``` + +### The `getattr` Statement + +``` +getattr_stmt ::= getattr(object, name[, default]) +``` + +- Attribute name must be a literal string. +- Module type object is not supported (e.g., torch.\_C). +- Custom class object is not supported (e.g., torch.classes.\*). + +### The `hasattr` Statement + +``` +hasattr_stmt ::= hasattr(object, name) +``` + +- Attribute name must be a literal string. +- Module type object is not supported (e.g., torch.\_C). +- Custom class object is not supported (e.g., torch.classes.\*). + +### The `zip` Statement + +``` +zip_stmt ::= zip(iterable1, iterable2) +``` + +- Arguments must be iterables. +- Two iterables of same outer container type but different length are supported. + +**Example 1** + +Both the iterables must be of the same container type: + +```{eval-rst} +.. testcode:: + + a = [1, 2] # List + b = [2, 3, 4] # List + zip(a, b) # works +``` + +**Example 2** + +This example fails because the iterables are of different container types: + +``` +a = (1, 2) # Tuple +b = [2, 3, 4] # List +zip(a, b) # Runtime error +``` + +Running the above code yields the following `RuntimeError`. + +``` +RuntimeError: Can not iterate over a module list or + tuple with a value that does not have a statically determinable length. +``` + +**Example 3** + +Two iterables of the same container Type but different data type is supported: + +```{eval-rst} +.. testcode:: + + a = [1.3, 2.4] + b = [2, 3, 4] + zip(a, b) # Works +``` + +Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. + +### The `enumerate` Statement + +``` +enumerate_stmt ::= enumerate([iterable]) +``` + +- Arguments must be iterables. +- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList` and `torch.nn.ModuleDict`. + +(python-values-torch-script)= + +## Python Values + +(python-builtin-functions-values-resolution)= + +### Resolution Rules + +When given a Python value, TorchScript attempts to resolve it in the following five different ways: + +- Compilable Python Implementation: + : - When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. + - Example: `torch.jit.Attribute` +- Op Python Wrapper: + : - When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. + - Example: `torch.jit._logging.add_stat_value` +- Python Object Identity Match: + : - For a limited set of `torch.*` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. + - When matched, TorchScript generates a corresponding `SugaredValue` instance that contains lowering logic for these values. + - Example: `torch.jit.isinstance()` +- Name Match: + : - For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding `SugaredValue` instance that implements their functionality. + - Example: `all()` +- Value Snapshot: + : - For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. + - Example: `math.pi` + +(python-builtin-functions-support)= + +### Python Built-in Functions Support + +```{eval-rst} +.. list-table:: TorchScript Support for Python Built-in Functions + :widths: 25 25 50 + :header-rows: 1 + + * - Built-in Function + - Support Level + - Notes + * - ``abs()`` + - Partial + - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. + * - ``all()`` + - Full + - + * - ``any()`` + - Full + - + * - ``ascii()`` + - None + - + * - ``bin()`` + - Partial + - Only supports ``Int`` type input. + * - ``bool()`` + - Partial + - Only supports ``Tensor``/``Int``/``Float`` type inputs. + * - ``breakpoint()`` + - None + - + * - ``bytearray()`` + - None + - + * - ``bytes()`` + - None + - + * - ``callable()`` + - None + - + * - ``chr()`` + - Partial + - Only ASCII character set is supported. + * - ``classmethod()`` + - Full + - + * - ``compile()`` + - None + - + * - ``complex()`` + - None + - + * - ``delattr()`` + - None + - + * - ``dict()`` + - Full + - + * - ``dir()`` + - None + - + * - ``divmod()`` + - Full + - + * - ``enumerate()`` + - Full + - + * - ``eval()`` + - None + - + * - ``exec()`` + - None + - + * - ``filter()`` + - None + - + * - ``float()`` + - Partial + - Doesn't honor ``__index__`` override. + * - ``format()`` + - Partial + - Manual index specification not supported. | Format type modifier not supported. + * - ``frozenset()`` + - None + - + * - ``getattr()`` + - Partial + - Attribute name must be string literal. + * - ``globals()`` + - None + - + * - ``hasattr()`` + - Partial + - Attribute name must be string literal. + * - ``hash()`` + - Full + - ``Tensor``'s hash is based on identity not numeric value. + * - ``hex()`` + - Partial + - Only supports ``Int`` type input. + * - ``id()`` + - Full + - Only supports ``Int`` type input. + * - ``input()`` + - None + - + * - ``int()`` + - Partial + - ``base`` argument not supported. | Doesn't honor ``__index__`` override. + * - ``isinstance()`` + - Full + - ``torch.jit.isintance`` provides better support when checking against container types like ``Dict[str, int]``. + * - ``issubclass()`` + - None + - + * - ``iter()`` + - None + - + * - ``len()`` + - Full + - + * - ``list()`` + - Full + - + * - ``ord()`` + - Partial + - Only ASCII character set is supported. + * - ``pow()`` + - Full + - + * - ``print()`` + - Partial + - ``separate``, ``end`` and ``file`` arguments are not supported. + * - ``property()`` + - None + - + * - ``range()`` + - Full + - + * - ``repr()`` + - None + - + * - ``reversed()`` + - None + - + * - ``round()`` + - Partial + - ``ndigits`` argument is not supported. + * - ``set()`` + - None + - + * - ``setattr()`` + - None + - + * - ``slice()`` + - Full + - + * - ``sorted()`` + - Partial + - ``key`` argument is not supported. + * - ``staticmethod()`` + - Full + - + * - ``str()`` + - Partial + - ``encoding`` and ``errors`` arguments are not supported. + * - ``sum()`` + - Full + - + * - ``super()`` + - Partial + - It can only be used in ``nn.Module``'s ``__init__`` method. + * - ``type()`` + - None + - + * - ``vars()`` + - None + - + * - ``zip()`` + - Full + - + * - ``__import__()`` + - None + - +``` + +(python-builtin-values-support)= + +### Python Built-in Values Support + +```{eval-rst} +.. list-table:: TorchScript Support for Python Built-in Values + :widths: 25 25 50 + :header-rows: 1 + + * - Built-in Value + - Support Level + - Notes + * - ``False`` + - Full + - + * - ``True`` + - Full + - + * - ``None`` + - Full + - + * - ``NotImplemented`` + - None + - + * - ``Ellipsis`` + - Full + - + +``` + +(torch-apis-in-torchscript)= + +## torch.\* APIs + +(torch-apis-in-torchscript-rpc)= + +### Remote Procedure Calls + +TorchScript supports a subset of RPC APIs that supports running a function on +a specified remote worker instead of locally. + +Specifically, following APIs are fully supported: + +- `torch.distributed.rpc.rpc_sync()` + : - `rpc_sync()` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_sync`. +- `torch.distributed.rpc.rpc_async()` + : - `rpc_async()` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_async`. +- `torch.distributed.rpc.remote()` + : - `remote.()` executes a remote call on a worker and gets a Remote Reference `RRef` as the return value. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.remote`. + +(torch-apis-in-torchscript-async)= + +### Asynchronous Execution + +TorchScript enables you to create asynchronous computation tasks to make better use +of computation resources. This is done via supporting a list of APIs that are +only usable within TorchScript: + +- `torch.jit.fork()` + : - Creates an asynchronous task executing func and a reference to the value of the result of this execution. Fork will return immediately. + - Synonymous to `torch.jit._fork()`, which is only kept for backward compatibility reasons. + - More details about its usage and examples can be found in {meth}`~torch.jit.fork`. +- `torch.jit.wait()` + : - Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. + - Synonymous to `torch.jit._wait()`, which is only kept for backward compatibility reasons. + - More details about its usage and examples can be found in {meth}`~torch.jit.wait`. + +(torch-apis-in-torchscript-annotation)= + +### Type Annotations + +TorchScript is statically-typed. It provides and supports a set of utilities to help annotate variables and attributes: + +- `torch.jit.annotate()` + : - Provides a type hint to TorchScript where Python 3 style type hints do not work well. + - One common example is to annotate type for expressions like `[]`. `[]` is treated as `List[torch.Tensor]` by default. When a different type is needed, you can use this code to hint TorchScript: `torch.jit.annotate(List[int], [])`. + - More details can be found in {meth}`~torch.jit.annotate` +- `torch.jit.Attribute` + : - Common use cases include providing type hint for `torch.nn.Module` attributes. Because their `__init__` methods are not parsed by TorchScript, `torch.jit.Attribute` should be used instead of `torch.jit.annotate` in the module's `__init__` methods. + - More details can be found in {meth}`~torch.jit.Attribute` +- `torch.jit.Final` + : - An alias for Python's `typing.Final`. `torch.jit.Final` is kept only for backward compatibility reasons. + +(torch-apis-in-torchscript-meta-programming)= + +### Meta Programming + +TorchScript provides a set of utilities to facilitate meta programming: + +- `torch.jit.is_scripting()` + : - Returns a boolean value indicating whether the current program is compiled by `torch.jit.script` or not. + - When used in an `assert` or an `if` statement, the scope or branch where `torch.jit.is_scripting()` evaluates to `False` is not compiled. + - Its value can be evaluated statically at compile time, thus commonly used in `if` statements to stop TorchScript from compiling one of the branches. + - More details and examples can be found in {meth}`~torch.jit.is_scripting` +- `torch.jit.is_tracing()` + : - Returns a boolean value indicating whether the current program is traced by `torch.jit.trace` / `torch.jit.trace_module` or not. + - More details can be found in {meth}`~torch.jit.is_tracing` +- `@torch.jit.ignore` + : - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. + - This allows you to leave code in your model that is not yet TorchScript compatible. + - If a function decorated by `@torch.jit.ignore` is called from TorchScript, ignored functions will dispatch the call to the Python interpreter. + - Models with ignored functions cannot be exported. + - More details and examples can be found in {meth}`~torch.jit.ignore` +- `@torch.jit.unused` + : - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. + - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. + - If a function decorated by `@torch.jit.unused` is called from TorchScript, a runtime error will be raised. + - More details and examples can be found in {meth}`~torch.jit.unused` + +(torch-apis-in-torchscript-type-refinement)= + +### Type Refinement + +- `torch.jit.isinstance()` + : - Returns a boolean indicating whether a variable is of the specified type. + - More details about its usage and examples can be found in {meth}`~torch.jit.isinstance`. \ No newline at end of file diff --git a/docs/source/jit_language_reference_v2.rst b/docs/source/jit_language_reference_v2.rst deleted file mode 100644 index 0863f9bbd2b7d8..00000000000000 --- a/docs/source/jit_language_reference_v2.rst +++ /dev/null @@ -1,1911 +0,0 @@ -.. testsetup:: - - # These are hidden from the docs, but these are necessary for `doctest` - # since the `inspect` module doesn't play nicely with the execution - # environment for `doctest` - import torch - - original_script = torch.jit.script - def script_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_script(obj, *args, **kwargs) - - torch.jit.script = script_wrapper - - original_trace = torch.jit.trace - def trace_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_trace(obj, *args, **kwargs) - - torch.jit.trace = trace_wrapper - -.. _language-reference-v2: - -TorchScript Language Reference -============================== - -This reference manual describes the syntax and core semantics of the TorchScript language. -TorchScript is a statically typed subset of the Python language. This document explains the supported features of -Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in -this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to -represent neural network models in PyTorch. - -.. contents:: - :local: - :depth: 1 - -.. _type_system: - -Terminology -~~~~~~~~~~~ - -This document uses the following terminologies: - -.. list-table:: - :widths: 25 25 - :header-rows: 1 - - * - Pattern - - Notes - * - ``::=`` - - Indicates that the given symbol is defined as. - * - ``" "`` - - Represents real keywords and delimiters that are part of the syntax. - * - ``A | B`` - - Indicates either A or B. - * - ``( )`` - - Indicates grouping. - * - ``[]`` - - Indicates optional. - * - ``A+`` - - Indicates a regular expression where term A is repeated at least once. - * - ``A*`` - - Indicates a regular expression where term A is repeated zero or more times. - -Type System -~~~~~~~~~~~ -TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express -neural net models. - -TorchScript Types -^^^^^^^^^^^^^^^^^ - -The TorchScript type system consists of ``TSType`` and ``TSModuleType`` as defined below. - -:: - - TSAllType ::= TSType | TSModuleType - TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType - -``TSType`` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. -``TSType`` refers to any of the following: - -* Meta Types, e.g., ``Any`` -* Primitive Types, e.g., ``int``, ``float``, and ``str`` -* Structural Types, e.g., ``Optional[int]`` or ``List[MyClass]`` -* Nominal Types (Python classes), e.g., ``MyClass`` (user-defined), ``torch.tensor`` (built-in) - -``TSModuleType`` represents ``torch.nn.Module`` and its subclasses. It is treated differently from ``TSType`` because its type schema is inferred partly from the object instance and partly from the class definition. -As such, instances of a ``TSModuleType`` may not follow the same static type schema. ``TSModuleType`` cannot be used as a TorchScript type annotation or be composed with ``TSType`` for type safety considerations. - -Meta Types -^^^^^^^^^^ - -Meta types are so abstract that they are more like type constraints than concrete types. -Currently TorchScript defines one meta-type, ``Any``, that represents any TorchScript type. - -``Any`` Type -"""""""""""" - -The ``Any`` type represents any TorchScript type. ``Any`` specifies no type constraints, thus there is no type-checking on ``Any``. -As such it can be bound to any Python or TorchScript data types (e.g., ``int``, TorchScript ``tuple``, or an arbitrary Python class that is not scripted). - -:: - - TSMetaType ::= "Any" - -Where: - -* ``Any`` is the Python class name from the typing module. Therefore, to use the ``Any`` type, you must import it from ``typing`` (e.g., ``from typing import Any``). -* Since ``Any`` can represent any TorchScript type, the set of operators that are allowed to operate on values of this type on ``Any`` is limited. - -Operators Supported for ``Any`` Type -"""""""""""""""""""""""""""""""""""" - -* Assignment to data of ``Any`` type. -* Binding to parameter or return of ``Any`` type. -* ``x is``, ``x is not`` where ``x`` is of ``Any`` type. -* ``isinstance(x, Type)`` where ``x`` is of ``Any`` type. -* Data of ``Any`` type is printable. -* Data of ``List[Any]`` type may be sortable if the data is a list of values of the same type ``T`` and that ``T`` supports comparison operators. - -**Compared to Python** - - -``Any`` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the -``Object`` class in Python. However, ``Any`` only supports a subset of the operators and methods that are supported by ``Object``. - -Design Notes -"""""""""""" - -When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described -by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary -scripting failures. ``Any`` is introduced to describe the type of the data where precise static types are not necessary for compilation. - -**Example 1** - -This example illustrates how ``Any`` can be used to allow the second element of the tuple parameter to be of any type. This is possible -because ``x[1]`` is not involved in any computation that requires knowing its precise type. - -.. testcode:: - - import torch - - from typing import Tuple - from typing import Any - - @torch.jit.export - def inc_first_element(x: Tuple[int, Any]): - return (x[0]+1, x[1]) - - m = torch.jit.script(inc_first_element) - print(m((1,2.0))) - print(m((1,(100,200)))) - -The example above produces the following output: - -.. testoutput:: - - (2, 2.0) - (2, (100, 200)) - -The second element of the tuple is of ``Any`` type, thus can bind to multiple types. -For example, ``(1, 2.0)`` binds a float type to ``Any`` as in ``Tuple[int, Any]``, -whereas ``(1, (100, 200))`` binds a tuple to ``Any`` in the second invocation. - - -**Example 2** - -This example illustrates how we can use ``isinstance`` to dynamically check the type of the data that is annotated as ``Any`` type: - -.. testcode:: - - import torch - from typing import Any - - def f(a:Any): - print(a) - return (isinstance(a, torch.Tensor)) - - ones = torch.ones([2]) - m = torch.jit.script(f) - print(m(ones)) - -The example above produces the following output: - -.. testoutput:: - - 1 - 1 - [ CPUFloatType{2} ] - True - -Primitive Types -^^^^^^^^^^^^^^^ - -Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined -type name. - -:: - - TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" - -Structural Types -^^^^^^^^^^^^^^^^ - -Structural types are types that are structurally defined without a user-defined name (unlike nominal types), -such as ``Future[int]``. Structural types are composable with any ``TSType``. - -:: - - TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | - TSOptional | TSUnion | TSFuture | TSRRef | TSAwait - - TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" - TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" - TSList ::= "List" "[" TSType "]" - TSOptional ::= "Optional" "[" TSType "]" - TSUnion ::= "Union" "[" (TSType ",")* TSType "]" - TSFuture ::= "Future" "[" TSType "]" - TSRRef ::= "RRef" "[" TSType "]" - TSAwait ::= "Await" "[" TSType "]" - TSDict ::= "Dict" "[" KeyType "," TSType "]" - KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" - -Where: - -* ``Tuple``, ``List``, ``Optional``, ``Union``, ``Future``, ``Dict`` represent Python type class names that are defined in the module ``typing``. To use these type names, you must import them from ``typing`` (e.g., ``from typing import Tuple``). -* ``namedtuple`` represents the Python class ``collections.namedtuple`` or ``typing.NamedTuple``. -* ``Future`` and ``RRef`` represent the Python classes ``torch.futures`` and ``torch.distributed.rpc``. -* ``Await`` represent the Python class ``torch._awaits._Await`` - -**Compared to Python** - -Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. - -**Example 1** - -This example uses ``typing.NamedTuple`` syntax to define a tuple: - -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - - class MyTuple(NamedTuple): - first: int - second: int - - def inc(x: MyTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - t = MyTuple(first=1, second=2) - scripted_inc = torch.jit.script(inc) - print("TorchScript:", scripted_inc(t)) - -The example above produces the following output: - -.. testoutput:: - - TorchScript: (2, 3) - -**Example 2** - -This example uses ``collections.namedtuple`` syntax to define a tuple: - -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - from collections import namedtuple - - _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) - _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) - - def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - m = torch.jit.script(inc) - print(inc(_UnannotatedNamedTuple(1,2))) - -The example above produces the following output: - -.. testoutput:: - - (2, 3) - -**Example 3** - -This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type -classes from the ``typing`` module: - -:: - - import torch - - # ERROR: Tuple not recognized because not imported from typing - @torch.jit.export - def inc(x: Tuple[int, int]): - return (x[0]+1, x[1]+1) - - m = torch.jit.script(inc) - print(m((1,2))) - -Running the above code yields the following scripting error: - -:: - - File "test-tuple.py", line 5, in - def inc(x: Tuple[int, int]): - NameError: name 'Tuple' is not defined - -The remedy is to add the line ``from typing import Tuple`` to the beginning of the code. - -Nominal Types -^^^^^^^^^^^^^ - -Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom -name and are compared using class names. Nominal classes are further classified into the following categories: - -:: - - TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum - -Among them, ``TSCustomClass`` and ``TSEnum`` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. - -Built-in Class -^^^^^^^^^^^^^^ - -Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). -TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or -attributes of its Python class definition. - -:: - - TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | - "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... - TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | - "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor - - -Special Note on torch.nn.ModuleList and torch.nn.ModuleDict -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" - -Although ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict`` are defined as a list and dictionary in Python, -they behave more like tuples in TorchScript: - -* In TorchScript, instances of ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` are immutable. -* Code that iterates over ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` is completely unrolled so that elements of ``torch.nn.ModuleList`` or keys of ``torch.nn.ModuleDict`` can be of different subclasses of ``torch.nn.Module``. - -**Example** - -The following example highlights the use of a few built-in Torchscript classes (``torch.*``): - -:: - - import torch - - @torch.jit.script - class A: - def __init__(self): - self.x = torch.rand(3) - - def f(self, y: torch.device): - return self.x.to(device=y) - - def g(): - a = A() - return a.f(torch.device("cpu")) - - script_g = torch.jit.script(g) - print(script_g.graph) - -Custom Class -^^^^^^^^^^^^ - -Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. - -:: - - TSClassDef ::= [ "@torch.jit.script" ] - "class" ClassName [ "(object)" ] ":" - MethodDefinition | - [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] - MethodDefinition - -Where: - -* Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from the object. -* Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the ``__init__()`` method. -* Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). -* ``MethodDefinition`` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules, (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). -* ``torch.jit.ignore`` and ``torch.jit.unused`` can be used to ignore the method or function that is not fully torchscriptable or should be ignored by the compiler. - -**Compared to Python** - - -TorchScript custom classes are quite limited compared to their Python counterpart. Torchscript custom classes: - -* Do not support class attributes. -* Do not support subclassing except for subclassing an interface type or object. -* Do not support method overloading. -* Must initialize all its instance attributes in ``__init__()``; this is because TorchScript constructs a static schema of the class by inferring attribute types in ``__init__()``. -* Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. - -**Example 1** - -Python classes can be used in TorchScript if they are annotated with ``@torch.jit.script``, similar to how a TorchScript function would be declared: - -:: - - @torch.jit.script - class MyClass: - def __init__(self, x: int): - self.x = x - - def inc(self, val: int): - self.x += val - - -**Example 2** - -A TorchScript custom class type must "declare" all its instance attributes by assignments in ``__init__()``. If an instance attribute is not defined in ``__init__()`` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: - -:: - - import torch - - @torch.jit.script - class foo: - def __init__(self): - self.y = 1 - - # ERROR: self.x is not defined in __init__ - def assign_x(self): - self.x = torch.rand(2, 3) - -The class will fail to compile and issue the following error: - -:: - - RuntimeError: - Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: - def assign_x(self): - self.x = torch.rand(2, 3) - ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE - -**Example 3** - -In this example, a TorchScript custom class defines a class variable name, which is not allowed: - -:: - - import torch - - @torch.jit.script - class MyClass(object): - name = "MyClass" - def __init__(self, x: int): - self.x = x - - def fn(a: MyClass): - return a.name - -It leads to the following compile-time error: - -:: - - RuntimeError: - '__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: - File "test-class2.py", line 10 - def fn(a: MyClass): - return a.name - ~~~~~~ <--- HERE - -Enum Type -^^^^^^^^^ - -Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. - -:: - - TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" - ( MemberIdentifier "=" Value )+ - ( MethodDefinition )* - -Where: - -* Value must be a TorchScript literal of type ``int``, ``float``, or ``str``, and must be of the same TorchScript type. -* ``TSEnumType`` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted ``Enum`` subclassing, that is, subclassing an enumerated is allowed only if it does not define any members. - -**Compared to Python** - - -* TorchScript supports only ``enum.Enum``. It does not support other variations such as ``enum.IntEnum``, ``enum.Flag``, ``enum.IntFlag``, and ``enum.auto``. -* Values of TorchScript enum members must be of the same type and can only be ``int``, ``float``, or ``str`` types, whereas Python enum members can be of any type. -* Enums containing methods are ignored in TorchScript. - -**Example 1** - -The following example defines the class ``Color`` as an ``Enum`` type: - -:: - - import torch - from enum import Enum - - class Color(Enum): - RED = 1 - GREEN = 2 - - def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - - m = torch.jit.script(enum_fn) - - print("Eager: ", enum_fn(Color.RED, Color.GREEN)) - print("TorchScript: ", m(Color.RED, Color.GREEN)) - -**Example 2** - -The following example shows the case of restricted enum subclassing, where ``BaseColor`` does not define any member, thus can be subclassed by ``Color``: - -:: - - import torch - from enum import Enum - - class BaseColor(Enum): - def foo(self): - pass - - class Color(BaseColor): - RED = 1 - GREEN = 2 - - def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - - m = torch.jit.script(enum_fn) - - print("TorchScript: ", m(Color.RED, Color.GREEN)) - print("Eager: ", enum_fn(Color.RED, Color.GREEN)) - -TorchScript Module Class -^^^^^^^^^^^^^^^^^^^^^^^^ - -``TSModuleType`` is a special class type that is inferred from object instances that are created outside TorchScript. ``TSModuleType`` is named by the Python class of the object instance. The ``__init__()`` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. - -The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from ``__init__()`` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. - -In this sense, ``TSModuleType`` is not really a static type. Therefore, for type safety considerations, ``TSModuleType`` cannot be used in a TorchScript type annotation or be composed with ``TSType``. - -Module Instance Class -^^^^^^^^^^^^^^^^^^^^^ - -TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to ``forward``). The Python module class is treated as a module instance class, so the ``__init__()`` method of the Python module class is not subject to the type-checking rules of TorchScript. - -:: - - TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" - ClassBodyDefinition - -Where: - -* ``forward()`` and other methods decorated with ``@torch.jit.export`` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. - -Unlike custom classes, only the forward method and other methods decorated with ``@torch.jit.export`` of the module type need to be compilable. Most notably, ``__init__()`` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into ``torch.jit.script(ModuleObj)``. - -**Example 1** - -This example illustrates a few features of module types: - -* The ``TestModule`` instance is created outside the scope of TorchScript (i.e., before invoking ``torch.jit.script``). -* ``__init__()`` is not considered a TorchScript method, therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the ``__init__()`` method of an instance class cannot be invoked in TorchScript code. Because ``TestModule`` instances are instantiated in Python, in this example, ``TestModule(2.0)`` and ``TestModule(2)`` create two instances with different types for its data attributes. ``self.x`` is of type ``float`` for ``TestModule(2.0)``, whereas ``self.y`` is of type ``int`` for ``TestModule(2.0)``. -* TorchScript automatically compiles other methods (e.g., ``mul()``) invoked by methods annotated via ``@torch.jit.export`` or ``forward()`` methods. -* Entry-points to a TorchScript program are either ``forward()`` of a module type, functions annotated as ``torch.jit.script``, or methods annotated as ``torch.jit.export``. - -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, inc: int): - return self.x + inc - - m = torch.jit.script(TestModule(1)) - print(f"First instance: {m(3)}") - - m = torch.jit.script(TestModule(torch.ones([5]))) - print(f"Second instance: {m(3)}") - -The example above produces the following output: - -.. testoutput:: - - First instance: 4 - Second instance: tensor([4., 4., 4., 4., 4.]) - -**Example 2** - -The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of ``TestModule`` inside the scope of TorchScript: - -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, x: int): - return self.x + x - - class MyModel: - def __init__(self, v: int): - self.val = v - - @torch.jit.export - def doSomething(self, val: int) -> int: - # error: should not invoke the constructor of module type - myModel = TestModule(self.val) - return myModel(val) - - # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError - # RuntimeError: Could not get name of python class object - -.. _type_annotation: - - -Type Annotation -~~~~~~~~~~~~~~~ -Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or -instance data attribute has a static type, and every function and method has a statically typed signature. - -When to Annotate Types -^^^^^^^^^^^^^^^^^^^^^^ -In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to -methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type -may be too restrictive, e.g., ``x`` being inferred as ``NoneType`` through assignment ``x = None``, whereas ``x`` is actually used as an ``Optional``. In such -cases, type annotations may be needed to overwrite auto inference, e.g., ``x: Optional[int] = None``. Note that it is always safe to type annotate a local variable -or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. - -When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a -default type of ``TensorType``, ``List[TensorType]``, or ``Dict[str, TensorType]``. - -Annotate Function Signature -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Since a parameter may not be automatically inferred from the body of the function (including both functions and methods), they need to be type annotated. Otherwise, they assume the default type ``TensorType``. - -TorchScript supports two styles for method and function signature type annotation: - -* **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of ``TensorType``), or allows the return type to be left unannotated (whose type will be automatically inferred). - - -:: - - Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" - FuncOrMethodBody - ParamAnnot ::= Identifier [ ":" TSType ] "," - ReturnAnnot ::= "->" TSType - -Note that when using Python3 style, the type ``self`` is automatically inferred and should not be annotated. - -* **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. - - -:: - - MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] - ParamAnnot ::= TSType "," - ReturnAnnot ::= "->" TSType - -**Example 1** - -In this example: - -* ``a`` is not annotated and assumes the default type of ``TensorType``. -* ``b`` is annotated as type ``int``. -* The return type is not annotated and is automatically inferred as type ``TensorType`` (based on the type of the value being returned). - -:: - - import torch - - def f(a, b: int): - return a+b - - m = torch.jit.script(f) - print("TorchScript:", m(torch.ones([6]), 100)) - -**Example 2** - -The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of -them assume the default type. - -:: - - import torch - - def f(a, b): - # type: (torch.Tensor, int) → torch.Tensor - return a+b - - m = torch.jit.script(f) - print("TorchScript:", m(torch.ones([6]), 100)) - - -Annotate Variables and Data Attributes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. -Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as ``None`` or ``TensorType``), then they may need to be explicitly -type annotated as a *wider* type such as ``Optional[int]`` or ``Any``. - -Local Variables -""""""""""""""" -Local variables can be annotated according to Python3 typing module annotation rules, i.e., - -:: - - LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr - -In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables -that may be associated with different concrete types. Typical multi-types include ``Optional[T]`` and ``Any``. - -**Example** - -:: - - import torch - - def f(a, setVal: bool): - value: Optional[torch.Tensor] = None - if setVal: - value = a - return value - - ones = torch.ones([6]) - m = torch.jit.script(f) - print("TorchScript:", m(ones, True), m(ones, False)) - -Instance Data Attributes -"""""""""""""""""""""""" -For ``ModuleType`` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final -via ``Final``. - -:: - - "class" ClassIdentifier "(torch.nn.Module):" - InstanceAttrIdentifier ":" ["Final("] TSType [")"] - ... - -Where: - -* ``InstanceAttrIdentifier`` is the name of an instance attribute. -* ``Final`` indicates that the attribute cannot be re-assigned outside of ``__init__`` or overridden in subclasses. - -**Example** - -:: - - import torch - - class MyModule(torch.nn.Module): - offset_: int - - def __init__(self, offset): - self.offset_ = offset - - ... - - - -Type Annotation APIs -^^^^^^^^^^^^^^^^^^^^ - -``torch.jit.annotate(T, expr)`` -""""""""""""""""""""""""""""""" -This API annotates type ``T`` to an expression ``expr``. This is often used when the default type of an expression is not the type intended by the programmer. -For instance, an empty list (dictionary) has the default type of ``List[TensorType]`` (``Dict[TensorType, TensorType]``), but sometimes it may be used to initialize -a list of some other types. Another common use case is for annotating the return type of ``tensor.tolist()``. Note, however, that it cannot be used to annotate -the type of a module attribute in `__init__`; ``torch.jit.Attribute`` should be used for this instead. - -**Example** - -In this example, ``[]`` is declared as a list of integers via ``torch.jit.annotate`` (instead of assuming ``[]`` to be the default type of ``List[TensorType]``): - -:: - - import torch - from typing import List - - def g(l: List[int], val: int): - l.append(val) - return l - - def f(val: int): - l = g(torch.jit.annotate(List[int], []), val) - return l - - m = torch.jit.script(f) - print("Eager:", f(3)) - print("TorchScript:", m(3)) - - -See :meth:`torch.jit.annotate` for more information. - - -Type Annotation Appendix -^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript Type System Definition -"""""""""""""""""""""""""""""""""" - -:: - - TSAllType ::= TSType | TSModuleType - TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType - - TSMetaType ::= "Any" - TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" - - TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | - TSUnion | TSFuture | TSRRef | TSAwait - TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" - TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" - TSList ::= "List" "[" TSType "]" - TSOptional ::= "Optional" "[" TSType "]" - TSUnion ::= "Union" "[" (TSType ",")* TSType "]" - TSFuture ::= "Future" "[" TSType "]" - TSRRef ::= "RRef" "[" TSType "]" - TSAwait ::= "Await" "[" TSType "]" - TSDict ::= "Dict" "[" KeyType "," TSType "]" - KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" - - TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum - TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| - "torch.dtype" | "torch.nn.ModuleList" | - "torch.nn.ModuleDict" | ... - TSTensor ::= "torch.tensor" and subclasses - -Unsupported Typing Constructs -""""""""""""""""""""""""""""" -TorchScript does not support all features and types of the Python3 `typing `_ module. -Any functionality from the `typing `_ module that is not explicitly specified in this -documentation is unsupported. The following table summarizes ``typing`` constructs that are either unsupported or supported with restrictions in TorchScript. - -============================= ================ - Item Description ------------------------------ ---------------- -``typing.Any`` In development -``typing.NoReturn`` Not supported -``typing.Callable`` Not supported -``typing.Literal`` Not supported -``typing.ClassVar`` Not supported -``typing.Final`` Supported for module attributes, class attribute, and annotations, but not for functions. -``typing.AnyStr`` Not supported -``typing.overload`` In development -Type aliases Not supported -Nominal typing In development -Structural typing Not supported -NewType Not supported -Generics Not supported -============================= ================ - - -.. _expressions: - - -Expressions -~~~~~~~~~~~ - -The following section describes the grammar of expressions that are supported in TorchScript. -It is modeled after `the expressions chapter of the Python language reference `_. - -Arithmetic Conversions -^^^^^^^^^^^^^^^^^^^^^^ -There are a number of implicit type conversions that are performed in TorchScript: - - -* A ``Tensor`` with a ``float`` or ``int`` data type can be implicitly converted to an instance of ``FloatType`` or ``IntType`` provided that it has a size of 0, does not have ``require_grad`` set to ``True``, and will not require narrowing. -* Instances of ``StringType`` can be implicitly converted to ``DeviceType``. -* The implicit conversion rules from the two bullet points above can be applied to instances of ``TupleType`` to produce instances of ``ListType`` with the appropriate contained type. - - -Explicit conversions can be invoked using the ``float``, ``int``, ``bool``, and ``str`` built-in functions -that accept primitive data types as arguments and can accept user-defined types if they implement -``__bool__``, ``__str__``, etc. - - -Atoms -^^^^^ -Atoms are the most basic elements of expressions. - -:: - - atom ::= identifier | literal | enclosure - enclosure ::= parenth_form | list_display | dict_display - -Identifiers -""""""""""" -The rules that dictate what is a legal identifier in TorchScript are the same as -their `Python counterparts `_. - -Literals -"""""""" - -:: - - literal ::= stringliteral | integer | floatnumber - -Evaluation of a literal yields an object of the appropriate type with the specific value -(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations -of identical literals may obtain the same object or distinct objects with the same value. -`stringliteral `_, -`integer `_, and -`floatnumber `_ -are defined in the same way as their Python counterparts. - -Parenthesized Forms -""""""""""""""""""" - -:: - - parenth_form ::= '(' [expression_list] ')' - -A parenthesized expression list yields whatever the expression list yields. If the list contains at least one -comma, it yields a ``Tuple``; otherwise, it yields the single expression inside the expression list. An empty -pair of parentheses yields an empty ``Tuple`` object (``Tuple[]``). - -List and Dictionary Displays -"""""""""""""""""""""""""""" - -:: - - list_comprehension ::= expression comp_for - comp_for ::= 'for' target_list 'in' or_expr - list_display ::= '[' [expression_list | list_comprehension] ']' - dict_display ::= '{' [key_datum_list | dict_comprehension] '}' - key_datum_list ::= key_datum (',' key_datum)* - key_datum ::= expression ':' expression - dict_comprehension ::= key_datum comp_for - -Lists and dicts can be constructed by either listing the container contents explicitly or by providing -instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension -is semantically equivalent to using a for loop and appending to an ongoing list. -Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the -enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list -are evaluated left-to-right. If a key is repeated in a ``dict_display`` that has a ``key_datum_list``, the -resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. - -Primaries -^^^^^^^^^ - -:: - - primary ::= atom | attributeref | subscription | slicing | call - - -Attribute References -"""""""""""""""""""" - -:: - - attributeref ::= primary '.' identifier - - -The ``primary`` must evaluate to an object of a type that supports attribute references that have an attribute named -``identifier``. - -Subscriptions -""""""""""""" - -:: - - subscription ::= primary '[' expression_list ']' - - -The ``primary`` must evaluate to an object that supports subscription. - -* If the primary is a ``List``, ``Tuple``, or ``str``, the expression list must evaluate to an integer or slice. -* If the primary is a ``Dict``, the expression list must evaluate to an object of the same type as the key type of the ``Dict``. -* If the primary is a ``ModuleList``, the expression list must be an ``integer`` literal. -* If the primary is a ``ModuleDict``, the expression must be a ``stringliteral``. - - -Slicings -"""""""" -A slicing selects a range of items in a ``str``, ``Tuple``, ``List``, or ``Tensor``. Slicings may be used as -expressions or targets in assignment or ``del`` statements. - -:: - - slicing ::= primary '[' slice_list ']' - slice_list ::= slice_item (',' slice_item)* [','] - slice_item ::= expression | proper_slice - proper_slice ::= [expression] ':' [expression] [':' [expression] ] - -Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an -object of type ``Tensor``. - - -Calls -""""" - -:: - - call ::= primary '(' argument_list ')' - argument_list ::= args [',' kwargs] | kwargs - args ::= [arg (',' arg)*] - kwargs ::= [kwarg (',' kwarg)*] - kwarg ::= arg '=' expression - arg ::= identifier - - -The ``primary`` must desugar or evaluate to a callable object. All argument expressions are evaluated -before the call is attempted. - -Power Operator -^^^^^^^^^^^^^^ - -:: - - power ::= primary ['**' u_expr] - - -The power operator has the same semantics as the built-in pow function (not supported); it computes its -left argument raised to the power of its right argument. It binds more tightly than unary operators on the -left, but less tightly than unary operators on the right; i.e. ``-2 ** -3 == -(2 ** (-3))``. The left and right -operands can be ``int``, ``float`` or ``Tensor``. Scalars are broadcast in the case of scalar-tensor/tensor-scalar -exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting. - -Unary and Arithmetic Bitwise Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - u_expr ::= power | '-' power | '~' power - -The unary ``-`` operator yields the negation of its argument. The unary ``~`` operator yields the bitwise inversion -of its argument. ``-`` can be used with ``int``, ``float``, and ``Tensor`` of ``int`` and ``float``. -``~`` can only be used with ``int`` and ``Tensor`` of ``int``. - -Binary Arithmetic Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr - a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr - -The binary arithmetic operators can operate on ``Tensor``, ``int``, and ``float``. For tensor-tensor ops, both arguments must -have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the -tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. -The ``@`` operator is for matrix multiplication and only operates on ``Tensor`` arguments. The multiplication operator -(``*``) can be used with a list and integer in order to get a result that is the original list repeated a certain -number of times. - -Shifting Operations -^^^^^^^^^^^^^^^^^^^ - -:: - - shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr - - -These operators accept two ``int`` arguments, two ``Tensor`` arguments, or a ``Tensor`` argument and an ``int`` or -``float`` argument. In all cases, a right shift by ``n`` is defined as floor division by ``pow(2, n)``, and a left shift -by ``n`` is defined as multiplication by ``pow(2, n)``. When both arguments are ``Tensors``, they must have the same -shape. When one is a scalar and the other is a ``Tensor``, the scalar is logically broadcast to match the size of -the ``Tensor``. - -Binary Bitwise Operations -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - and_expr ::= shift_expr | and_expr '&' shift_expr - xor_expr ::= and_expr | xor_expr '^' and_expr - or_expr ::= xor_expr | or_expr '|' xor_expr - - -The ``&`` operator computes the bitwise AND of its arguments, the ``^`` the bitwise XOR, and the ``|`` the bitwise OR. -Both operands must be ``int`` or ``Tensor``, or the left operand must be ``Tensor`` and the right operand must be -``int``. When both operands are ``Tensor``, they must have the same shape. When the right operand is ``int``, and -the left operand is ``Tensor``, the right operand is logically broadcast to match the shape of the ``Tensor``. - -Comparisons -^^^^^^^^^^^ - -:: - - comparison ::= or_expr (comp_operator or_expr)* - comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' - -A comparison yields a boolean value (``True`` or ``False``), or if one of the operands is a ``Tensor``, a boolean -``Tensor``. Comparisons can be chained arbitrarily as long as they do not yield boolean ``Tensors`` that have more -than one element. ``a op1 b op2 c ...`` is equivalent to ``a op1 b and b op2 c and ...``. - -Value Comparisons -""""""""""""""""" -The operators ``<``, ``>``, ``==``, ``>=``, ``<=``, and ``!=`` compare the values of two objects. The two objects generally need to be of -the same type, unless there is an implicit type conversion available between the objects. User-defined types can -be compared if rich comparison methods (e.g., ``__lt__``) are defined on them. Built-in type comparison works like -Python: - -* Numbers are compared mathematically. -* Strings are compared lexicographically. -* ``lists``, ``tuples``, and ``dicts`` can be compared only to other ``lists``, ``tuples``, and ``dicts`` of the same type and are compared using the comparison operator of corresponding elements. - -Membership Test Operations -"""""""""""""""""""""""""" -The operators ``in`` and ``not in`` test for membership. ``x in s`` evaluates to ``True`` if ``x`` is a member of ``s`` and ``False`` otherwise. -``x not in s`` is equivalent to ``not x in s``. This operator is supported for ``lists``, ``dicts``, and ``tuples``, and can be used with -user-defined types if they implement the ``__contains__`` method. - -Identity Comparisons -"""""""""""""""""""" -For all types except ``int``, ``double``, ``bool``, and ``torch.device``, operators ``is`` and ``is not`` test for the object’s identity; -``x is y`` is ``True`` if and only if ``x`` and ``y`` are the same object. For all other types, ``is`` is equivalent to -comparing them using ``==``. ``x is not y`` yields the inverse of ``x is y``. - -Boolean Operations -^^^^^^^^^^^^^^^^^^ - -:: - - or_test ::= and_test | or_test 'or' and_test - and_test ::= not_test | and_test 'and' not_test - not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test - -User-defined objects can customize their conversion to ``bool`` by implementing a ``__bool__`` method. The operator ``not`` -yields ``True`` if its operand is false, ``False`` otherwise. The expression ``x`` and ``y`` first evaluates ``x``; if it is ``False``, its -value (``False``) is returned; otherwise, ``y`` is evaluated and its value is returned (``False`` or ``True``). The expression ``x`` or ``y`` -first evaluates ``x``; if it is ``True``, its value (``True``) is returned; otherwise, ``y`` is evaluated and its value is returned -(``False`` or ``True``). - -Conditional Expressions -^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] - expression ::= conditional_expression - -The expression ``x if c else y`` first evaluates the condition ``c`` rather than x. If ``c`` is ``True``, ``x`` is -evaluated and its value is returned; otherwise, ``y`` is evaluated and its value is returned. As with if-statements, -``x`` and ``y`` must evaluate to a value of the same type. - -Expression Lists -^^^^^^^^^^^^^^^^ - -:: - - expression_list ::= expression (',' expression)* [','] - starred_item ::= '*' primary - -A starred item can only appear on the left-hand side of an assignment statement, e.g., ``a, *b, c = ...``. - -.. statements: - -Simple Statements -~~~~~~~~~~~~~~~~~ - -The following section describes the syntax of simple statements that are supported in TorchScript. -It is modeled after `the simple statements chapter of the Python language reference `_. - -Expression Statements -^^^^^^^^^^^^^^^^^^^^^^ - -:: - - expression_stmt ::= starred_expression - starred_expression ::= expression | (starred_item ",")* [starred_item] - starred_item ::= assignment_expression | "*" or_expr - -Assignment Statements -^^^^^^^^^^^^^^^^^^^^^^ - -:: - - assignment_stmt ::= (target_list "=")+ (starred_expression) - target_list ::= target ("," target)* [","] - target ::= identifier - | "(" [target_list] ")" - | "[" [target_list] "]" - | attributeref - | subscription - | slicing - | "*" target - -Augmented Assignment Statements -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - augmented_assignment_stmt ::= augtarget augop (expression_list) - augtarget ::= identifier | attributeref | subscription - augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | - "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" - - -Annotated Assignment Statements -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:: - - annotated_assignment_stmt ::= augtarget ":" expression - ["=" (starred_expression)] - -The ``raise`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - raise_stmt ::= "raise" [expression ["from" expression]] - -Raise statements in TorchScript do not support ``try\except\finally``. - -The ``assert`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - assert_stmt ::= "assert" expression ["," expression] - -Assert statements in TorchScript do not support ``try\except\finally``. - -The ``return`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - return_stmt ::= "return" [expression_list] - -Return statements in TorchScript do not support ``try\except\finally``. - -The ``del`` Statement -^^^^^^^^^^^^^^^^^^^^^^ - -:: - - del_stmt ::= "del" target_list - -The ``pass`` Statement -^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - pass_stmt ::= "pass" - -The ``print`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")" - -The ``break`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - break_stmt ::= "break" - -The ``continue`` Statement: -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - continue_stmt ::= "continue" - -Compound Statements -~~~~~~~~~~~~~~~~~~~ - -The following section describes the syntax of compound statements that are supported in TorchScript. -The section also highlights how Torchscript differs from regular Python statements. -It is modeled after `the compound statements chapter of the Python language reference `_. - -The ``if`` Statement -^^^^^^^^^^^^^^^^^^^^^ - -Torchscript supports both basic ``if/else`` and ternary ``if/else``. - -Basic ``if/else`` Statement -"""""""""""""""""""""""""""" - -:: - - if_stmt ::= "if" assignment_expression ":" suite - ("elif" assignment_expression ":" suite) - ["else" ":" suite] - -``elif`` statements can repeat for an arbitrary number of times, but it needs to be before ``else`` statement. - -Ternary ``if/else`` Statement -"""""""""""""""""""""""""""""" - -:: - - if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list] - -**Example 1** - -A ``tensor`` with 1 dimension is promoted to ``bool``: - -.. testcode:: - - import torch - - @torch.jit.script - def fn(x: torch.Tensor): - if x: # The tensor gets promoted to bool - return True - return False - print(fn(torch.rand(1))) - -The example above produces the following output: - -.. testoutput:: - - True - -**Example 2** - -A ``tensor`` with multi dimensions are not promoted to ``bool``: - -:: - - import torch - - # Multi dimensional Tensors error out. - - @torch.jit.script - def fn(): - if torch.rand(2): - print("Tensor is available") - - if torch.rand(4,5,6): - print("Tensor is available") - - print(fn()) - -Running the above code yields the following ``RuntimeError``. - -:: - - RuntimeError: The following operation failed in the TorchScript interpreter. - Traceback of TorchScript (most recent call last): - @torch.jit.script - def fn(): - if torch.rand(2): - ~~~~~~~~~~~~ <--- HERE - print("Tensor is available") - RuntimeError: Boolean value of Tensor with more than one value is ambiguous - -If a conditional variable is annotated as ``final``, either the true or false branch is evaluated depending on the evaluation of the conditional variable. - -**Example 3** - -In this example, only the True branch is evaluated, since ``a`` is annotated as ``final`` and set to ``True``: - -:: - - import torch - - a : torch.jit.final[Bool] = True - - if a: - return torch.empty(2,3) - else: - return [] - - -The ``while`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - while_stmt ::= "while" assignment_expression ":" suite - -`while...else` statements are not supported in Torchscript. It results in a ``RuntimeError``. - -The ``for-in`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - for_stmt ::= "for" target_list "in" expression_list ":" suite - ["else" ":" suite] - -``for...else`` statements are not supported in Torchscript. It results in a ``RuntimeError``. - -**Example 1** - -For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. - -.. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def fn(): - tup = (3, torch.ones(4)) - for x in tup: - print(x) - - fn() - -The example above produces the following output: - -.. testoutput:: - - 3 - 1 - 1 - 1 - 1 - [ CPUFloatType{4} ] - - -**Example 2** - -For loops on lists: for loops over a ``nn.ModuleList`` will unroll the body of the loop at compile time, with each member of the module list. - -:: - - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - - model = torch.jit.script(MyModule()) - -The ``with`` Statement -^^^^^^^^^^^^^^^^^^^^^^^ -The ``with`` statement is used to wrap the execution of a block with methods defined by a context manager. - -:: - - with_stmt ::= "with" with_item ("," with_item) ":" suite - with_item ::= expression ["as" target] - -* If a target was included in the ``with`` statement, the return value from the context manager’s ``__enter__()`` is assigned to it. Unlike python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to ``__exit__()``. Three ``None`` arguments are supplied. -* ``try``, ``except``, and ``finally`` statements are not supported inside ``with`` blocks. -* Exceptions raised within ``with`` block cannot be suppressed. - -The ``tuple`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - tuple_stmt ::= tuple([iterables]) - -* Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``. -* You cannot convert a List to Tuple by using this built-in function. - -Unpacking all outputs into a tuple is covered by: - -:: - - abc = func() # Function that returns a tuple - a,b = func() - -The ``getattr`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - getattr_stmt ::= getattr(object, name[, default]) - -* Attribute name must be a literal string. -* Module type object is not supported (e.g., torch._C). -* Custom class object is not supported (e.g., torch.classes.*). - -The ``hasattr`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - hasattr_stmt ::= hasattr(object, name) - -* Attribute name must be a literal string. -* Module type object is not supported (e.g., torch._C). -* Custom class object is not supported (e.g., torch.classes.*). - -The ``zip`` Statement -^^^^^^^^^^^^^^^^^^^^^^ - -:: - - zip_stmt ::= zip(iterable1, iterable2) - -* Arguments must be iterables. -* Two iterables of same outer container type but different length are supported. - -**Example 1** - -Both the iterables must be of the same container type: - -.. testcode:: - - a = [1, 2] # List - b = [2, 3, 4] # List - zip(a, b) # works - -**Example 2** - -This example fails because the iterables are of different container types: - -:: - - a = (1, 2) # Tuple - b = [2, 3, 4] # List - zip(a, b) # Runtime error - -Running the above code yields the following ``RuntimeError``. - -:: - - RuntimeError: Can not iterate over a module list or - tuple with a value that does not have a statically determinable length. - -**Example 3** - -Two iterables of the same container Type but different data type is supported: - -.. testcode:: - - a = [1.3, 2.4] - b = [2, 3, 4] - zip(a, b) # Works - -Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``. - -The ``enumerate`` Statement -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - enumerate_stmt ::= enumerate([iterable]) - -* Arguments must be iterables. -* Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict``. - - -.. _python-values-torch-script: - -Python Values -~~~~~~~~~~~~~ - -.. _python-builtin-functions-values-resolution: - -Resolution Rules -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When given a Python value, TorchScript attempts to resolve it in the following five different ways: - -* Compilable Python Implementation: - * When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. - * Example: ``torch.jit.Attribute`` -* Op Python Wrapper: - * When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. - * Example: ``torch.jit._logging.add_stat_value`` -* Python Object Identity Match: - * For a limited set of ``torch.*`` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. - * When matched, TorchScript generates a corresponding ``SugaredValue`` instance that contains lowering logic for these values. - * Example: ``torch.jit.isinstance()`` -* Name Match: - * For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding ``SugaredValue`` instance that implements their functionality. - * Example: ``all()`` -* Value Snapshot: - * For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. - * Example: ``math.pi`` - - - -.. _python-builtin-functions-support: - -Python Built-in Functions Support -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. list-table:: TorchScript Support for Python Built-in Functions - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Function - - Support Level - - Notes - * - ``abs()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. - * - ``all()`` - - Full - - - * - ``any()`` - - Full - - - * - ``ascii()`` - - None - - - * - ``bin()`` - - Partial - - Only supports ``Int`` type input. - * - ``bool()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. - * - ``breakpoint()`` - - None - - - * - ``bytearray()`` - - None - - - * - ``bytes()`` - - None - - - * - ``callable()`` - - None - - - * - ``chr()`` - - Partial - - Only ASCII character set is supported. - * - ``classmethod()`` - - Full - - - * - ``compile()`` - - None - - - * - ``complex()`` - - None - - - * - ``delattr()`` - - None - - - * - ``dict()`` - - Full - - - * - ``dir()`` - - None - - - * - ``divmod()`` - - Full - - - * - ``enumerate()`` - - Full - - - * - ``eval()`` - - None - - - * - ``exec()`` - - None - - - * - ``filter()`` - - None - - - * - ``float()`` - - Partial - - Doesn't honor ``__index__`` override. - * - ``format()`` - - Partial - - Manual index specification not supported. | Format type modifier not supported. - * - ``frozenset()`` - - None - - - * - ``getattr()`` - - Partial - - Attribute name must be string literal. - * - ``globals()`` - - None - - - * - ``hasattr()`` - - Partial - - Attribute name must be string literal. - * - ``hash()`` - - Full - - ``Tensor``'s hash is based on identity not numeric value. - * - ``hex()`` - - Partial - - Only supports ``Int`` type input. - * - ``id()`` - - Full - - Only supports ``Int`` type input. - * - ``input()`` - - None - - - * - ``int()`` - - Partial - - ``base`` argument not supported. | Doesn't honor ``__index__`` override. - * - ``isinstance()`` - - Full - - ``torch.jit.isintance`` provides better support when checking against container types like ``Dict[str, int]``. - * - ``issubclass()`` - - None - - - * - ``iter()`` - - None - - - * - ``len()`` - - Full - - - * - ``list()`` - - Full - - - * - ``ord()`` - - Partial - - Only ASCII character set is supported. - * - ``pow()`` - - Full - - - * - ``print()`` - - Partial - - ``separate``, ``end`` and ``file`` arguments are not supported. - * - ``property()`` - - None - - - * - ``range()`` - - Full - - - * - ``repr()`` - - None - - - * - ``reversed()`` - - None - - - * - ``round()`` - - Partial - - ``ndigits`` argument is not supported. - * - ``set()`` - - None - - - * - ``setattr()`` - - None - - - * - ``slice()`` - - Full - - - * - ``sorted()`` - - Partial - - ``key`` argument is not supported. - * - ``staticmethod()`` - - Full - - - * - ``str()`` - - Partial - - ``encoding`` and ``errors`` arguments are not supported. - * - ``sum()`` - - Full - - - * - ``super()`` - - Partial - - It can only be used in ``nn.Module``'s ``__init__`` method. - * - ``type()`` - - None - - - * - ``vars()`` - - None - - - * - ``zip()`` - - Full - - - * - ``__import__()`` - - None - - - -.. _python-builtin-values-support: - -Python Built-in Values Support -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. list-table:: TorchScript Support for Python Built-in Values - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Value - - Support Level - - Notes - * - ``False`` - - Full - - - * - ``True`` - - Full - - - * - ``None`` - - Full - - - * - ``NotImplemented`` - - None - - - * - ``Ellipsis`` - - Full - - - - -.. _torch_apis_in_torchscript: - -torch.* APIs -~~~~~~~~~~~~ - -.. _torch_apis_in_torchscript_rpc: - -Remote Procedure Calls -^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript supports a subset of RPC APIs that supports running a function on -a specified remote worker instead of locally. - -Specifically, following APIs are fully supported: - -- ``torch.distributed.rpc.rpc_sync()`` - - ``rpc_sync()`` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_sync`. - -- ``torch.distributed.rpc.rpc_async()`` - - ``rpc_async()`` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_async`. -- ``torch.distributed.rpc.remote()`` - - ``remote.()`` executes a remote call on a worker and gets a Remote Reference ``RRef`` as the return value. - - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.remote`. - -.. _torch_apis_in_torchscript_async: - -Asynchronous Execution -^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript enables you to create asynchronous computation tasks to make better use -of computation resources. This is done via supporting a list of APIs that are -only usable within TorchScript: - -- ``torch.jit.fork()`` - - Creates an asynchronous task executing func and a reference to the value of the result of this execution. Fork will return immediately. - - Synonymous to ``torch.jit._fork()``, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in :meth:`~torch.jit.fork`. -- ``torch.jit.wait()`` - - Forces completion of a ``torch.jit.Future[T]`` asynchronous task, returning the result of the task. - - Synonymous to ``torch.jit._wait()``, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in :meth:`~torch.jit.wait`. - - -.. _torch_apis_in_torchscript_annotation: - -Type Annotations -^^^^^^^^^^^^^^^^ - -TorchScript is statically-typed. It provides and supports a set of utilities to help annotate variables and attributes: - -- ``torch.jit.annotate()`` - - Provides a type hint to TorchScript where Python 3 style type hints do not work well. - - One common example is to annotate type for expressions like ``[]``. ``[]`` is treated as ``List[torch.Tensor]`` by default. When a different type is needed, you can use this code to hint TorchScript: ``torch.jit.annotate(List[int], [])``. - - More details can be found in :meth:`~torch.jit.annotate` -- ``torch.jit.Attribute`` - - Common use cases include providing type hint for ``torch.nn.Module`` attributes. Because their ``__init__`` methods are not parsed by TorchScript, ``torch.jit.Attribute`` should be used instead of ``torch.jit.annotate`` in the module's ``__init__`` methods. - - More details can be found in :meth:`~torch.jit.Attribute` -- ``torch.jit.Final`` - - An alias for Python's ``typing.Final``. ``torch.jit.Final`` is kept only for backward compatibility reasons. - - -.. _torch_apis_in_torchscript_meta_programming: - -Meta Programming -^^^^^^^^^^^^^^^^ - -TorchScript provides a set of utilities to facilitate meta programming: - -- ``torch.jit.is_scripting()`` - - Returns a boolean value indicating whether the current program is compiled by ``torch.jit.script`` or not. - - When used in an ``assert`` or an ``if`` statement, the scope or branch where ``torch.jit.is_scripting()`` evaluates to ``False`` is not compiled. - - Its value can be evaluated statically at compile time, thus commonly used in ``if`` statements to stop TorchScript from compiling one of the branches. - - More details and examples can be found in :meth:`~torch.jit.is_scripting` -- ``torch.jit.is_tracing()`` - - Returns a boolean value indicating whether the current program is traced by ``torch.jit.trace`` / ``torch.jit.trace_module`` or not. - - More details can be found in :meth:`~torch.jit.is_tracing` -- ``@torch.jit.ignore`` - - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. - - This allows you to leave code in your model that is not yet TorchScript compatible. - - If a function decorated by ``@torch.jit.ignore`` is called from TorchScript, ignored functions will dispatch the call to the Python interpreter. - - Models with ignored functions cannot be exported. - - More details and examples can be found in :meth:`~torch.jit.ignore` -- ``@torch.jit.unused`` - - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. - - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. - - If a function decorated by ``@torch.jit.unused`` is called from TorchScript, a runtime error will be raised. - - More details and examples can be found in :meth:`~torch.jit.unused` - -.. _torch_apis_in_torchscript_type_refinement: - -Type Refinement -^^^^^^^^^^^^^^^ - -- ``torch.jit.isinstance()`` - - Returns a boolean indicating whether a variable is of the specified type. - - More details about its usage and examples can be found in :meth:`~torch.jit.isinstance`. diff --git a/docs/source/jit_python_reference.md b/docs/source/jit_python_reference.md new file mode 100644 index 00000000000000..1d2b5c78a894f3 --- /dev/null +++ b/docs/source/jit_python_reference.md @@ -0,0 +1,432 @@ +(python-language-reference)= + +# Python Language Reference Coverage + +This is a 1:1 mapping of the features listed in https://docs.python.org/3/reference/ and their +support in TorchScript. The categorizations are as follows: + +```{list-table} +:widths: 40 40 20 +:header-rows: 1 + +* - Section + - Status + - Note +* - [1. Introduction](https://docs.python.org/3/reference/introduction.html) + - Not Relevant + - +* - [1.1. Alternate Implementations](https://docs.python.org/3/reference/introduction.html#alternate-implementations) + - Not Relevant + - +* - [1.2. Notation](https://docs.python.org/3/reference/introduction.html#notation) + - Not Relevant + - +* - [2. Lexical analysis](https://docs.python.org/3/reference/lexical_analysis.html#) + - Not Relevant + - +* - [2.1. Line structure](https://docs.python.org/3/reference/lexical_analysis.html#line-structure) + - Not Relevant + - +* - [2.1.1. Logical lines](https://docs.python.org/3/reference/lexical_analysis.html#logical-lines) + - Not Relevant + - +* - [2.1.2. Physical lines](https://docs.python.org/3/reference/lexical_analysis.html#physical-lines) + - Supported + - +* - [2.1.3. Comments](https://docs.python.org/3/reference/lexical_analysis.html#comments) + - Supported + - +* - [2.1.4. Encoding declarations](https://docs.python.org/3/reference/lexical_analysis.html#encoding-declarations) + - Not Supported + - TorchScript explicitly don't support unicode +* - [2.1.5. Explicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#explicit-line-joining) + - Supported + - +* - [2.1.6. Implicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#implicit-line-joining) + - Supported + - +* - [2.1.7. Blank lines](https://docs.python.org/3/reference/lexical_analysis.html#blank-lines) + - Supported + - +* - [2.1.8. Indentation](https://docs.python.org/3/reference/lexical_analysis.html#indentation) + - Supported + - +* - [2.1.9. Whitespace between tokens](https://docs.python.org/3/reference/lexical_analysis.html#whitespace-between-tokens) + - Not Relevant + - +* - [2.2. Other tokens](https://docs.python.org/3/reference/lexical_analysis.html#other-tokens) + - Not Relevant + - +* - [2.3. Identifiers and keywords](https://docs.python.org/3/reference/lexical_analysis.html#identifiers) + - Supported + - +* - [2.3.1. Keywords](https://docs.python.org/3/reference/lexical_analysis.html#keywords) + - Supported + - +* - [2.3.2. Reserved classes of identifiers](https://docs.python.org/3/reference/lexical_analysis.html#reserved-classes-of-identifiers) + - Supported + - +* - [2.4. Literals](https://docs.python.org/3/reference/lexical_analysis.html#literals) + - Not Relevant + - +* - [2.4.1. String and Bytes literals](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals) + - Supported + - +* - [2.4.2. String literal concatenation](https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation) + - Supported + - +* - [2.4.3. Formatted string literals](https://docs.python.org/3/reference/lexical_analysis.html#formatted-string-literals) + - Partially Supported + - +* - [2.4.4. Numeric literals](https://docs.python.org/3/reference/lexical_analysis.html#numeric-literals) + - Supported + - +* - [2.4.5. Integer literals](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals) + - Supported + - +* - [2.4.6. Floating point literals](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) + - Supported + - +* - [2.4.7. Imaginary literals](https://docs.python.org/3/reference/lexical_analysis.html#imaginary-literals) + - Not Supported + - +* - [2.5. Operators](https://docs.python.org/3/reference/lexical_analysis.html#operators) + - Partially Supported + - Not supported: ``<<``, ``>>``, ``:=`` +* - [2.6. Delimiters](https://docs.python.org/3/reference/lexical_analysis.html#delimiters) + - Partially Supported + - Not supported: ``**=``, ``<<=``, ``>>=``, ``%=``, ``^=``, ``@=``, ``&=``, ``//=``, ``%`` operator for some types (e.g. ``str``\ ) +* - [3. Data model](https://docs.python.org/3/reference/datamodel.html#) + - Not Relevant + - +* - [3.1. Objects, values and types](https://docs.python.org/3/reference/datamodel.html#objects-values-and-types) + - Not Relevant + - +* - [3.2. The standard type hierarchy](https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy) + - Partially Supported + - Not supported: NotImplemented, Ellipsis, numbers.Complex, bytes, byte arrays, sets, frozen sets, generators, coroutines, async generators, modules, I/O objects, internal objects, slice objects ( though slicing is supported), classmethod +* - [3.3. Special method names](https://docs.python.org/3/reference/datamodel.html#special-method-names) + - Supported + - +* - [3.3.1. Basic customization](https://docs.python.org/3/reference/datamodel.html#basic-customization) + - Partially Supported + - Not supported: ``__new__`` , ``__del__`` , ``__bytes__`` , ``__format__`` , ``__hash__`` , +* - [3.3.2. Customizing attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-attribute-access) + - Not Supported + - +* - [3.3.2.1. Customizing module attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-module-attribute-access) + - Not Supported + - +* - [3.3.2.2. Implementing Descriptors](https://docs.python.org/3/reference/datamodel.html#implementing-descriptors) + - Not Supported + - +* - [3.3.2.3. Invoking Descriptors](https://docs.python.org/3/reference/datamodel.html#invoking-descriptors) + - Not Supported + - +* - [3.3.2.4. __slots__](https://docs.python.org/3/reference/datamodel.html#slots) + - Not Supported + - +* - [3.3.2.4.1. Notes on using __slots__](https://docs.python.org/3/reference/datamodel.html#notes-on-using-slots) + - Not Supported + - +* - [3.3.3. Customizing class creation](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation) + - Not Supported + - +* - [3.3.3.1. Metaclasses](https://docs.python.org/3/reference/datamodel.html#metaclasses) + - Not Supported + - +* - [3.3.3.2. Resolving MRO entries](https://docs.python.org/3/reference/datamodel.html#resolving-mro-entries) + - Not Supported + - [`super()`` is not supported +* - [3.3.3.3. Determining the appropriate metaclass](https://docs.python.org/3/reference/datamodel.html#determining-the-appropriate-metaclass) + - Not relevant + - +* - [3.3.3.4. Preparing the class namespace](https://docs.python.org/3/reference/datamodel.html#preparing-the-class-namespace) + - Not relevant + - +* - [3.3.3.5. Executing the class body](https://docs.python.org/3/reference/datamodel.html#executing-the-class-body) + - Not relevant + - +* - [3.3.3.6. Creating the class object](https://docs.python.org/3/reference/datamodel.html#creating-the-class-object) + - Not relevant + - +* - [3.3.3.7. Uses for metaclasses](https://docs.python.org/3/reference/datamodel.html#uses-for-metaclasses) + - Not relevant + - +* - [3.3.4. Customizing instance and subclass checks](https://docs.python.org/3/reference/datamodel.html#customizing-instance-and-subclass-checks) + - Not Supported + - +* - [3.3.5. Emulating generic types](https://docs.python.org/3/reference/datamodel.html#emulating-generic-types) + - Not Supported + - +* - [3.3.6. Emulating callable objects](https://docs.python.org/3/reference/datamodel.html#emulating-callable-objects) + - Supported + - +* - [3.3.7. Emulating container types](https://docs.python.org/3/reference/datamodel.html#emulating-container-types) + - Partially Supported + - Some magic methods not supported (e.g. ``__iter__`` ) +* - [3.3.8. Emulating numeric types](https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types) + - Partially Supported + - Magic methods with swapped operands not supported (``__r*__``) +* - [3.3.9. With Statement Context Managers](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers) + - Not Supported + - +* - [3.3.10. Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-method-lookup) + - Not relevant + - +* - [3.4. Coroutines](https://docs.python.org/3/reference/datamodel.html#coroutines) + - Not Supported + - +* - [3.4.1. Awaitable Objects](https://docs.python.org/3/reference/datamodel.html#awaitable-objects) + - Not Supported + - +* - [3.4.2. Coroutine Objects](https://docs.python.org/3/reference/datamodel.html#coroutine-objects) + - Not Supported + - +* - [3.4.3. Asynchronous Iterators](https://docs.python.org/3/reference/datamodel.html#asynchronous-iterators) + - Not Supported + - +* - [3.4.4. Asynchronous Context Managers](https://docs.python.org/3/reference/datamodel.html#asynchronous-context-managers) + - Not Supported + - +* - [4. Execution model](https://docs.python.org/3/reference/executionmodel.html#) + - Not Relevant + - +* - [4.1. Structure of a program](https://docs.python.org/3/reference/executionmodel.html#structure-of-a-program) + - Not Relevant + - +* - [4.2. Naming and binding](https://docs.python.org/3/reference/executionmodel.html#naming-and-binding) + - Not Relevant + - Names are bound at compile time in TorchScript +* - [4.2.1. Binding of names](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) + - Not Relevant + - See ``global`` and ``nonlocal`` statements section +* - [4.2.2. Resolution of names](https://docs.python.org/3/reference/executionmodel.html#resolution-of-names) + - Not Relevant + - See ``global`` and ``nonlocal`` statements section +* - [4.2.3. Builtins and restricted execution](https://docs.python.org/3/reference/executionmodel.html#builtins-and-restricted-execution) + - Not Relevant + - +* - [4.2.4. Interaction with dynamic features](https://docs.python.org/3/reference/executionmodel.html#interaction-with-dynamic-features) + - Not Supported + - Python values cannot be captured +* - [4.3. Exceptions](https://docs.python.org/3/reference/executionmodel.html#exceptions) + - Partially Supported + - See ``try`` and ``raise`` statement section +* - [5. The import system](https://docs.python.org/3/reference/import.html) + - Not Relevant + - +* - [6. Expressions](https://docs.python.org/3/reference/expressions.html#) + - Not Relevant + - See expressions section +* - [6.1. Arithmetic conversions](https://docs.python.org/3/reference/expressions.html#arithmetic-conversions) + - Supported + - +* - [6.2. Atoms](https://docs.python.org/3/reference/expressions.html#atoms) + - Not Relevant + - +* - [6.2.1. Identifiers (Names)](https://docs.python.org/3/reference/expressions.html#atom-identifiers) + - Supported + - +* - [6.2.2. Literals](https://docs.python.org/3/reference/expressions.html#literals) + - Partially Supported + - [`bytesliteral``\ , ``imagnumber`` not supported +* - [6.2.3. Parenthesized forms](https://docs.python.org/3/reference/expressions.html#parenthesized-forms) + - Supported + - +* - [6.2.4. Displays for lists, sets and dictionaries](https://docs.python.org/3/reference/expressions.html#displays-for-lists-sets-and-dictionaries) + - Partially Supported + - Not supported: comprehension ifs, async iterators +* - [6.2.5. List displays](https://docs.python.org/3/reference/expressions.html#list-displays) + - Supported + - +* - [6.2.6. Set displays](https://docs.python.org/3/reference/expressions.html#set-displays) + - Not Supported + - +* - [6.2.7. Dictionary displays](https://docs.python.org/3/reference/expressions.html#dictionary-displays) + - Supported + - dict() constructor with kwargs doesn't work, dict comprehensions, dictionary unpacking +* - [6.2.8. Generator expressions](https://docs.python.org/3/reference/expressions.html#generator-expressions) + - Not Supported + - +* - [6.2.9. Yield expressions](https://docs.python.org/3/reference/expressions.html#yield-expressions) + - Not Supported + - +* - [6.2.9.1. Generator-iterator methods](https://docs.python.org/3/reference/expressions.html#generator-iterator-methods) + - Not Supported + - +* - [6.2.9.2. Examples](https://docs.python.org/3/reference/expressions.html#examples) + - Not Supported + - +* - [6.2.9.3. Asynchronous generator functions](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-functions) + - Not Supported + - +* - [6.2.9.4. Asynchronous generator-iterator methods](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-iterator-methods) + - Not Supported + - +* - [6.3. Primaries](https://docs.python.org/3/reference/expressions.html#primaries) + - Supported + - +* - [6.3.1. Attribute references](https://docs.python.org/3/reference/expressions.html#attribute-references) + - Supported + - +* - [6.3.2. Subscriptions](https://docs.python.org/3/reference/expressions.html#subscriptions) + - Supported + - +* - [6.3.3. Slicings](https://docs.python.org/3/reference/expressions.html#slicings) + - Partially Supported + - Tuple slicing with stride is not supported +* - [6.3.4. Calls](https://docs.python.org/3/reference/expressions.html#calls) + - Partially Supported + - Args unpack / kwargs unpack is not supported +* - [6.4. Await expression](https://docs.python.org/3/reference/expressions.html#await-expression) + - Not Supported + - +* - [6.5. The power operator](https://docs.python.org/3/reference/expressions.html#the-power-operator) + - Supported + - +* - [6.6. Unary arithmetic and bitwise operations](https://docs.python.org/3/reference/expressions.html#unary-arithmetic-and-bitwise-operations) + - Partially Supported + - Some bitwise operators are not implemented for primitive types (e.g. ``~x`` where ``x`` is an ``int`` is not currently supported) +* - [6.7. Binary arithmetic operations](https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations) + - Partially Supported + - See delimiters section +* - [6.8. Shifting operations](https://docs.python.org/3/reference/expressions.html#shifting-operations) + - Not Supported + - +* - [6.9. Binary bitwise operations](https://docs.python.org/3/reference/expressions.html#binary-bitwise-operations) + - Supported + - +* - [6.10. Comparisons](https://docs.python.org/3/reference/expressions.html#comparisons) + - Supported + - +* - [6.10.1. Value comparisons](https://docs.python.org/3/reference/expressions.html#value-comparisons) + - Partially Supported + - Dictionary equality checks are not currently supported +* - [6.10.2. Membership test operations](https://docs.python.org/3/reference/expressions.html#membership-test-operations) + - Partially Supported + - Not supported for TorchScript classes +* - [6.10.3. Identity comparisons](https://docs.python.org/3/reference/expressions.html#is-not) + - Supported + - +* - [6.11. Boolean operations](https://docs.python.org/3/reference/expressions.html#boolean-operations) + - Supported + - +* - [6.12. Conditional expressions](https://docs.python.org/3/reference/expressions.html#conditional-expressions) + - Supported + - +* - [6.13. Lambdas](https://docs.python.org/3/reference/expressions.html#lambda) + - Not Supported + - +* - [6.14. Expression lists](https://docs.python.org/3/reference/expressions.html#expression-lists) + - Partially Supported + - Iterable unpacking not supported +* - [6.15. Evaluation order](https://docs.python.org/3/reference/expressions.html#evaluation-order) + - Supported + - +* - [6.16. Operator precedence](https://docs.python.org/3/reference/expressions.html#operator-precedence) + - Supported + - +* - [7. Simple statements](https://docs.python.org/3/reference/simple_stmts.html#) + - Supported + - +* - [7.1. Expression statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements) + - Supported + - +* - [7.2. Assignment statements](https://docs.python.org/3/reference/simple_stmts.html#assignment-statements) + - Supported + - +* - [7.2.1. Augmented assignment statements](https://docs.python.org/3/reference/simple_stmts.html#augmented-assignment-statements) + - Partially Supported + - See delimiters section +* - [7.2.2. Annotated assignment statements](https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements) + - Supported + - +* - [7.3. The assert statement](https://docs.python.org/3/reference/simple_stmts.html#the-assert-statement) + - Partially Supported + - Exception message is not customizable +* - [7.4. The pass statement](https://docs.python.org/3/reference/simple_stmts.html#the-pass-statement) + - Supported + - +* - [7.5. The del statement](https://docs.python.org/3/reference/simple_stmts.html#the-del-statement) + - Not Supported + - +* - [7.6. The return statement](https://docs.python.org/3/reference/simple_stmts.html#the-return-statement) + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported +* - [7.7. The yield statement](https://docs.python.org/3/reference/simple_stmts.html#the-yield-statement) + - Not Supported + - +* - [7.8. The raise statement](https://docs.python.org/3/reference/simple_stmts.html#the-raise-statement) + - Partially Supported + - Exception message is not customizable +* - [7.9. The break statement](https://docs.python.org/3/reference/simple_stmts.html#the-break-statement) + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported +* - [7.10. The continue statement](https://docs.python.org/3/reference/simple_stmts.html#the-continue-statement) + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported +* - [7.11. The import statement](https://docs.python.org/3/reference/simple_stmts.html#the-import-statement) + - Not Supported + - +* - [7.11.1. Future statements](https://docs.python.org/3/reference/simple_stmts.html#future-statements) + - Not Supported + - +* - [7.12. The global statement](https://docs.python.org/3/reference/simple_stmts.html#the-global-statement) + - Not Supported + - +* - [7.13. The nonlocal statement](https://docs.python.org/3/reference/simple_stmts.html#the-nonlocal-statement) + - Not Supported + - +* - [8. Compound statements](https://docs.python.org/3/reference/compound_stmts.html#) + - Irrelevant + - +* - [8.1. The if statement](https://docs.python.org/3/reference/compound_stmts.html#the-if-statement) + - Supported + - +* - [8.2. The while statement](https://docs.python.org/3/reference/compound_stmts.html#the-while-statement) + - Partially Supported + - while..else is not supported +* - [8.3. The for statement](https://docs.python.org/3/reference/compound_stmts.html#the-for-statement) + - Partially Supported + - for..else is not supported +* - [8.4. The try statement](https://docs.python.org/3/reference/compound_stmts.html#the-try-statement) + - Not Supported + - +* - [8.5. The with statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement) + - Partially Supported + - [`__exit__`` is always called with ``exc_type``, ``exc_value``, and ``traceback`` set to None, even if an exception was raised, and ``__exit__``'s return value is ignored. +* - [8.6. Function definitions](https://docs.python.org/3/reference/compound_stmts.html#function-definitions) + - Not Supported + - +* - [8.7. Class definitions](https://docs.python.org/3/reference/compound_stmts.html#class-definitions) + - Not Supported + - +* - [8.8. Coroutines](https://docs.python.org/3/reference/compound_stmts.html#coroutines) + - Not Supported + - +* - [8.8.1. Coroutine function definition](https://docs.python.org/3/reference/compound_stmts.html#coroutine-function-definition) + - Not Supported + - +* - [8.8.2. The async for statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-for-statement) + - Not Supported + - +* - [8.8.3. The async with statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-with-statement) + - Not Supported + - +* - [9. Top-level components](https://docs.python.org/3/reference/toplevel_components.html#) + - Not Relevant + - +* - [9.1. Complete Python programs](https://docs.python.org/3/reference/toplevel_components.html#complete-python-programs) + - Not Relevant + - +* - [9.2. File input](https://docs.python.org/3/reference/toplevel_components.html#file-input) + - Not Relevant + - +* - [9.3. Interactive input](https://docs.python.org/3/reference/toplevel_components.html#interactive-input) + - Not Relevant + - +* - [9.4. Expression input](https://docs.python.org/3/reference/toplevel_components.html#expression-input) + - Not Relevant + - +``` diff --git a/docs/source/jit_python_reference.rst b/docs/source/jit_python_reference.rst deleted file mode 100644 index 96e0fe13037c17..00000000000000 --- a/docs/source/jit_python_reference.rst +++ /dev/null @@ -1,432 +0,0 @@ -.. _python-language-reference: - -Python Language Reference Coverage -================================== - -This is a 1:1 mapping of the features listed in https://docs.python.org/3/reference/ and their -support in TorchScript. The categorizations are as follows: - - -.. list-table:: - :header-rows: 1 - - * - Section - - Status - - Note - * - `1. Introduction `_ - - Not Relevant - - - * - `1.1. Alternate Implementations `_ - - Not Relevant - - - * - `1.2. Notation `_ - - Not Relevant - - - * - `2. Lexical analysis `_ - - Not Relevant - - - * - `2.1. Line structure `_ - - Not Relevant - - - * - `2.1.1. Logical lines `_ - - Not Relevant - - - * - `2.1.2. Physical lines `_ - - Supported - - - * - `2.1.3. Comments `_ - - Supported - - - * - `2.1.4. Encoding declarations `_ - - Not Supported - - TorchScript explicitly don't support unicode - * - `2.1.5. Explicit line joining `_ - - Supported - - - * - `2.1.6. Implicit line joining `_ - - Supported - - - * - `2.1.7. Blank lines `_ - - Supported - - - * - `2.1.8. Indentation `_ - - Supported - - - * - `2.1.9. Whitespace between tokens `_ - - Not Relevant - - - * - `2.2. Other tokens `_ - - Not Relevant - - - * - `2.3. Identifiers and keywords `_ - - Supported - - - * - `2.3.1. Keywords `_ - - Supported - - - * - `2.3.2. Reserved classes of identifiers `_ - - Supported - - - * - `2.4. Literals `_ - - Not Relevant - - - * - `2.4.1. String and Bytes literals `_ - - Supported - - - * - `2.4.2. String literal concatenation `_ - - Supported - - - * - `2.4.3. Formatted string literals `_ - - Partially Supported - - - * - `2.4.4. Numeric literals `_ - - Supported - - - * - `2.4.5. Integer literals `_ - - Supported - - - * - `2.4.6. Floating point literals `_ - - Supported - - - * - `2.4.7. Imaginary literals `_ - - Not Supported - - - * - `2.5. Operators `_ - - Partially Supported - - Not supported: ``<<``, ``>>``, ``:=`` - * - `2.6. Delimiters `_ - - Partially Supported - - Not supported: ``**=``, ``<<=``, ``>>=``, ``%=``, ``^=``, ``@=``, ``&=``, ``//=``, ``%`` operator for some types (e.g. ``str``\ ) - * - `3. Data model `_ - - Not Relevant - - - * - `3.1. Objects, values and types `_ - - Not Relevant - - - * - `3.2. The standard type hierarchy `_ - - Partially Supported - - Not supported: NotImplemented, Ellipsis, numbers.Complex, bytes, byte arrays, sets, frozen sets, generators, coroutines, async generators, modules, I/O objects, internal objects, slice objects ( though slicing is supported), classmethod - * - `3.3. Special method names `_ - - Supported - - - * - `3.3.1. Basic customization `_ - - Partially Supported - - Not supported: ``__new__`` , ``__del__`` , ``__bytes__`` , ``__format__`` , ``__hash__`` , - * - `3.3.2. Customizing attribute access `_ - - Not Supported - - - * - `3.3.2.1. Customizing module attribute access `_ - - Not Supported - - - * - `3.3.2.2. Implementing Descriptors `_ - - Not Supported - - - * - `3.3.2.3. Invoking Descriptors `_ - - Not Supported - - - * - `3.3.2.4. __slots__ `_ - - Not Supported - - - * - `3.3.2.4.1. Notes on using __slots__ `_ - - Not Supported - - - * - `3.3.3. Customizing class creation `_ - - Not Supported - - - * - `3.3.3.1. Metaclasses `_ - - Not Supported - - - * - `3.3.3.2. Resolving MRO entries `_ - - Not Supported - - ``super()`` is not supported - * - `3.3.3.3. Determining the appropriate metaclass `_ - - Not relevant - - - * - `3.3.3.4. Preparing the class namespace `_ - - Not relevant - - - * - `3.3.3.5. Executing the class body `_ - - Not relevant - - - * - `3.3.3.6. Creating the class object `_ - - Not relevant - - - * - `3.3.3.7. Uses for metaclasses `_ - - Not relevant - - - * - `3.3.4. Customizing instance and subclass checks `_ - - Not Supported - - - * - `3.3.5. Emulating generic types `_ - - Not Supported - - - * - `3.3.6. Emulating callable objects `_ - - Supported - - - * - `3.3.7. Emulating container types `_ - - Partially Supported - - Some magic methods not supported (e.g. ``__iter__`` ) - * - `3.3.8. Emulating numeric types `_ - - Partially Supported - - Magic methods with swapped operands not supported (``__r*__``) - * - `3.3.9. With Statement Context Managers `_ - - Not Supported - - - * - `3.3.10. Special method lookup `_ - - Not relevant - - - * - `3.4. Coroutines `_ - - Not Supported - - - * - `3.4.1. Awaitable Objects `_ - - Not Supported - - - * - `3.4.2. Coroutine Objects `_ - - Not Supported - - - * - `3.4.3. Asynchronous Iterators `_ - - Not Supported - - - * - `3.4.4. Asynchronous Context Managers `_ - - Not Supported - - - * - `4. Execution model `_ - - Not Relevant - - - * - `4.1. Structure of a program `_ - - Not Relevant - - - * - `4.2. Naming and binding `_ - - Not Relevant - - Names are bound at compile time in TorchScript - * - `4.2.1. Binding of names `_ - - Not Relevant - - See ``global`` and ``nonlocal`` statements section - * - `4.2.2. Resolution of names `_ - - Not Relevant - - See ``global`` and ``nonlocal`` statements section - * - `4.2.3. Builtins and restricted execution `_ - - Not Relevant - - - * - `4.2.4. Interaction with dynamic features `_ - - Not Supported - - Python values cannot be captured - * - `4.3. Exceptions `_ - - Partially Supported - - See ``try`` and ``raise`` statement section - * - `5. The import system `_ - - Not Relevant - - - * - `6. Expressions `_ - - Not Relevant - - See expressions section - * - `6.1. Arithmetic conversions `_ - - Supported - - - * - `6.2. Atoms `_ - - Not Relevant - - - * - `6.2.1. Identifiers (Names) `_ - - Supported - - - * - `6.2.2. Literals `_ - - Partially Supported - - ``bytesliteral``\ , ``imagnumber`` not supported - * - `6.2.3. Parenthesized forms `_ - - Supported - - - * - `6.2.4. Displays for lists, sets and dictionaries `_ - - Partially Supported - - Not supported: comprehension ifs, async iterators - * - `6.2.5. List displays `_ - - Supported - - - * - `6.2.6. Set displays `_ - - Not Supported - - - * - `6.2.7. Dictionary displays `_ - - Supported - - dict() constructor with kwargs doesn't work, dict comprehensions, dictionary unpacking - * - `6.2.8. Generator expressions `_ - - Not Supported - - - * - `6.2.9. Yield expressions `_ - - Not Supported - - - * - `6.2.9.1. Generator-iterator methods `_ - - Not Supported - - - * - `6.2.9.2. Examples `_ - - Not Supported - - - * - `6.2.9.3. Asynchronous generator functions `_ - - Not Supported - - - * - `6.2.9.4. Asynchronous generator-iterator methods `_ - - Not Supported - - - * - `6.3. Primaries `_ - - Supported - - - * - `6.3.1. Attribute references `_ - - Supported - - - * - `6.3.2. Subscriptions `_ - - Supported - - - * - `6.3.3. Slicings `_ - - Partially Supported - - Tuple slicing with stride is not supported - * - `6.3.4. Calls `_ - - Partially Supported - - Args unpack / kwargs unpack is not supported - * - `6.4. Await expression `_ - - Not Supported - - - * - `6.5. The power operator `_ - - Supported - - - * - `6.6. Unary arithmetic and bitwise operations `_ - - Partially Supported - - Some bitwise operators are not implemented for primitive types (e.g. ``~x`` where ``x`` is an ``int`` is not currently supported) - * - `6.7. Binary arithmetic operations `_ - - Partially Supported - - See delimiters section - * - `6.8. Shifting operations `_ - - Not Supported - - - * - `6.9. Binary bitwise operations `_ - - Supported - - - * - `6.10. Comparisons `_ - - Supported - - - * - `6.10.1. Value comparisons `_ - - Partially Supported - - Dictionary equality checks are not currently supported - * - `6.10.2. Membership test operations `_ - - Partially Supported - - Not supported for TorchScript classes - * - `6.10.3. Identity comparisons `_ - - Supported - - - * - `6.11. Boolean operations `_ - - Supported - - - * - `6.12. Conditional expressions `_ - - Supported - - - * - `6.13. Lambdas `_ - - Not Supported - - - * - `6.14. Expression lists `_ - - Partially Supported - - Iterable unpacking not supported - * - `6.15. Evaluation order `_ - - Supported - - - * - `6.16. Operator precedence `_ - - Supported - - - * - `7. Simple statements `_ - - Supported - - - * - `7.1. Expression statements `_ - - Supported - - - * - `7.2. Assignment statements `_ - - Supported - - - * - `7.2.1. Augmented assignment statements `_ - - Partially Supported - - See delimiters section - * - `7.2.2. Annotated assignment statements `_ - - Supported - - - * - `7.3. The assert statement `_ - - Partially Supported - - Exception message is not customizable - * - `7.4. The pass statement `_ - - Supported - - - * - `7.5. The del statement `_ - - Not Supported - - - * - `7.6. The return statement `_ - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported - * - `7.7. The yield statement `_ - - Not Supported - - - * - `7.8. The raise statement `_ - - Partially Supported - - Exception message is not customizable - * - `7.9. The break statement `_ - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported - * - `7.10. The continue statement `_ - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported - * - `7.11. The import statement `_ - - Not Supported - - - * - `7.11.1. Future statements `_ - - Not Supported - - - * - `7.12. The global statement `_ - - Not Supported - - - * - `7.13. The nonlocal statement `_ - - Not Supported - - - * - `8. Compound statements `_ - - Irrelevant - - - * - `8.1. The if statement `_ - - Supported - - - * - `8.2. The while statement `_ - - Partially Supported - - while..else is not supported - * - `8.3. The for statement `_ - - Partially Supported - - for..else is not supported - * - `8.4. The try statement `_ - - Not Supported - - - * - `8.5. The with statement `_ - - Partially Supported - - ``__exit__`` is always called with ``exc_type``, ``exc_value``, and ``traceback`` set to None, even if an exception was raised, and ``__exit__``'s return value is ignored. - * - `8.6. Function definitions `_ - - Not Supported - - - * - `8.7. Class definitions `_ - - Not Supported - - - * - `8.8. Coroutines `_ - - Not Supported - - - * - `8.8.1. Coroutine function definition `_ - - Not Supported - - - * - `8.8.2. The async for statement `_ - - Not Supported - - - * - `8.8.3. The async with statement `_ - - Not Supported - - - * - `9. Top-level components `_ - - Not Relevant - - - * - `9.1. Complete Python programs `_ - - Not Relevant - - - * - `9.2. File input `_ - - Not Relevant - - - * - `9.3. Interactive input `_ - - Not Relevant - - - * - `9.4. Expression input `_ - - Not Relevant - - diff --git a/docs/source/jit_unsupported.md b/docs/source/jit_unsupported.md new file mode 100644 index 00000000000000..79a51c1651f346 --- /dev/null +++ b/docs/source/jit_unsupported.md @@ -0,0 +1,81 @@ +(jit_unsupported)= + +# TorchScript Unsupported PyTorch Constructs + +## Torch and Tensor Unsupported Attributes + +TorchScript supports most methods defined on `torch` and `torch.Tensor`, but we do not have full coverage. +Here are specific known ops and categories of ops which have diverging behavior between +Python and TorchScript. If you encounter something else that is not supported please +file a GitHub issue. Deprecated ops are not listed below. + +```{eval-rst} +.. automodule:: torch.jit.unsupported_tensor_ops +``` + +### Functions Not Correctly Bound on Torch + +The following functions will fail if used in TorchScript, either because they +are not bound on `torch` or because Python expects a different schema than +TorchScript. + +- {func}`torch.tensordot` +- {func}`torch.nn.init.calculate_gain` +- {func}`torch.nn.init.eye_` +- {func}`torch.nn.init.dirac_` +- {func}`torch.nn.init.kaiming_normal_` +- {func}`torch.nn.init.orthogonal_` +- {func}`torch.nn.init.sparse` + +### Ops With Divergent Schemas Between Torch & Python + +The following categories of ops have divergent schemas: + +Functions which construct tensors from non-tensor inputs do not support the `requires_grad` +argument, except for `torch.tensor`. This covers the following ops: + +- {func}`torch.norm` +- {func}`torch.bartlett_window` +- {func}`torch.blackman_window` +- {func}`torch.empty` +- {func}`torch.empty_like` +- {func}`torch.empty_strided` +- {func}`torch.eye` +- {func}`torch.full` +- {func}`torch.full_like` +- {func}`torch.hamming_window` +- {func}`torch.hann_window` +- {func}`torch.linspace` +- {func}`torch.logspace` +- {func}`torch.normal` +- {func}`torch.ones` +- {func}`torch.rand` +- {func}`torch.rand_like` +- {func}`torch.randint_like` +- {func}`torch.randn` +- {func}`torch.randn_like` +- {func}`torch.randperm` +- {func}`torch.tril_indices` +- {func}`torch.triu_indices` +- {func}`torch.vander` +- {func}`torch.zeros` +- {func}`torch.zeros_like` + +The following functions require `dtype`, `layout`, `device` as parameters in TorchScript, +but these parameters are optional in Python. + +- {func}`torch.randint` +- {func}`torch.sparse_coo_tensor` +- {func}`torch.Tensor.to` + +## PyTorch Unsupported Modules and Classes + +TorchScript cannot currently compile a number of other commonly used PyTorch +constructs. Below are listed the modules that TorchScript does not support, and +an incomplete list of PyTorch classes that are not supported. For unsupported modules +we suggest using {meth}`torch.jit.trace`. + +- {class}`torch.nn.RNN` +- {class}`torch.nn.AdaptiveLogSoftmaxWithLoss` +- {class}`torch.autograd.Function` +- {class}`torch.autograd.enable_grad` diff --git a/docs/source/jit_unsupported.rst b/docs/source/jit_unsupported.rst deleted file mode 100644 index 60bca7d6d92c64..00000000000000 --- a/docs/source/jit_unsupported.rst +++ /dev/null @@ -1,90 +0,0 @@ -.. _jit_unsupported: - -TorchScript Unsupported PyTorch Constructs -============================================ - -Torch and Tensor Unsupported Attributes ------------------------------------------- - - -TorchScript supports most methods defined on ``torch`` and ``torch.Tensor``, but we do not have full coverage. -Here are specific known ops and categories of ops which have diverging behavior between -Python and TorchScript. If you encounter something else that is not supported please -file a GitHub issue. Deprecated ops are not listed below. - - - -.. automodule:: torch.jit.unsupported_tensor_ops - - -Functions Not Correctly Bound on Torch -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The following functions will fail if used in TorchScript, either because they -are not bound on `torch` or because Python expects a different schema than -TorchScript. - - * :func:`torch.tensordot` - * :func:`torch.nn.init.calculate_gain` - * :func:`torch.nn.init.eye_` - * :func:`torch.nn.init.dirac_` - * :func:`torch.nn.init.kaiming_normal_` - * :func:`torch.nn.init.orthogonal_` - * :func:`torch.nn.init.sparse` - - -Ops With Divergent Schemas Between Torch & Python -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The following categories of ops have divergent schemas: - -Functions which construct tensors from non-tensor inputs do not support the `requires_grad` -argument, except for `torch.tensor`. This covers the following ops: - - * :func:`torch.norm` - * :func:`torch.bartlett_window` - * :func:`torch.blackman_window` - * :func:`torch.empty` - * :func:`torch.empty_like` - * :func:`torch.empty_strided` - * :func:`torch.eye` - * :func:`torch.full` - * :func:`torch.full_like` - * :func:`torch.hamming_window` - * :func:`torch.hann_window` - * :func:`torch.linspace` - * :func:`torch.logspace` - * :func:`torch.normal` - * :func:`torch.ones` - * :func:`torch.rand` - * :func:`torch.rand_like` - * :func:`torch.randint_like` - * :func:`torch.randn` - * :func:`torch.randn_like` - * :func:`torch.randperm` - * :func:`torch.tril_indices` - * :func:`torch.triu_indices` - * :func:`torch.vander` - * :func:`torch.zeros` - * :func:`torch.zeros_like` - -The following functions require `dtype`, `layout`, `device` as parameters in TorchScript, -but these parameters are optional in Python. - - * :func:`torch.randint` - * :func:`torch.sparse_coo_tensor` - * :meth:`~torch.Tensor.to` - - -PyTorch Unsupported Modules and Classes ------------------------------------------- - -TorchScript cannot currently compile a number of other commonly used PyTorch -constructs. Below are listed the modules that TorchScript does not support, and -an incomplete list of PyTorch classes that are not supported. For unsupported modules -we suggest using :meth:`torch.jit.trace`. - - * :class:`torch.nn.RNN` - * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss` - * :class:`torch.autograd.Function` - * :class:`torch.autograd.enable_grad` diff --git a/docs/source/jit_utils.md b/docs/source/jit_utils.md new file mode 100644 index 00000000000000..de05fbf35c513f --- /dev/null +++ b/docs/source/jit_utils.md @@ -0,0 +1,5 @@ +# JIT Utils - torch.utils.jit + +```{eval-rst} +.. automodule:: torch.utils.jit +``` diff --git a/docs/source/jit_utils.rst b/docs/source/jit_utils.rst deleted file mode 100644 index abc4235912321e..00000000000000 --- a/docs/source/jit_utils.rst +++ /dev/null @@ -1,4 +0,0 @@ -JIT Utils - torch.utils.jit -================================================== - -.. automodule:: torch.utils.jit diff --git a/docs/source/library.md b/docs/source/library.md new file mode 100644 index 00000000000000..9d706e2e1080ed --- /dev/null +++ b/docs/source/library.md @@ -0,0 +1,81 @@ +(torch-library-docs)= + +# torch.library + +```{eval-rst} +.. py:module:: torch.library +.. currentmodule:: torch.library +``` + +torch.library is a collection of APIs for extending PyTorch's core library +of operators. It contains utilities for testing custom operators, creating new +custom operators, and extending operators defined with PyTorch's C++ operator +registration APIs (e.g. aten operators). + +For a detailed guide on effectively using these APIs, please see +[PyTorch Custom Operators Landing Page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) +for more details on how to effectively use these APIs. + +## Testing custom ops + +Use {func}`torch.library.opcheck` to test custom ops for incorrect usage of the +Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports +training, use {func}`torch.autograd.gradcheck` to test that the gradients are +mathematically correct. + +```{eval-rst} +.. autofunction:: opcheck +``` + +## Creating new custom ops in Python + +Use {func}`torch.library.custom_op` to create new custom ops. + +```{eval-rst} +.. autofunction:: custom_op +.. autofunction:: triton_op +.. autofunction:: wrap_triton +``` + +## Extending custom ops (created from Python or C++) + +Use the `register.*` methods, such as {func}`torch.library.register_kernel` and +{func}`torch.library.register_fake`, to add implementations +for any operators (they may have been created using {func}`torch.library.custom_op` or +via PyTorch's C++ operator registration APIs). + +```{eval-rst} +.. autofunction:: register_kernel +.. autofunction:: register_autocast +.. autofunction:: register_autograd +.. autofunction:: register_fake +.. autofunction:: register_vmap +.. autofunction:: impl_abstract +.. autofunction:: get_ctx +.. autofunction:: register_torch_dispatch +.. autofunction:: infer_schema +.. autoclass:: torch._library.custom_ops.CustomOpDef + :members: set_kernel_enabled +``` + +## Low-level APIs + +The following APIs are direct bindings to PyTorch's C++ low-level +operator registration APIs. + +```{eval-rst} +.. warning:: The low-level operator registration APIs and the PyTorch Dispatcher are a complicated PyTorch concept. We recommend you use the higher level APIs above (that do not require a torch.library.Library object) when possible. `This blog post `_ is a good starting point to learn about the PyTorch Dispatcher. +``` + +A tutorial that walks you through some examples on how to use this API is available on [Google Colab](https://colab.research.google.com/drive/1RRhSfk7So3Cn02itzLWE9K4Fam-8U011?usp=sharing). + +```{eval-rst} +.. autoclass:: torch.library.Library + :members: + +.. autofunction:: fallthrough_kernel + +.. autofunction:: define + +.. autofunction:: impl +``` diff --git a/docs/source/library.rst b/docs/source/library.rst deleted file mode 100644 index e54211ccab6f91..00000000000000 --- a/docs/source/library.rst +++ /dev/null @@ -1,80 +0,0 @@ -.. _torch-library-docs: - -torch.library -=================================== -.. py:module:: torch.library -.. currentmodule:: torch.library - -torch.library is a collection of APIs for extending PyTorch's core library -of operators. It contains utilities for testing custom operators, creating new -custom operators, and extending operators defined with PyTorch's C++ operator -registration APIs (e.g. aten operators). - -For a detailed guide on effectively using these APIs, please see -`PyTorch Custom Operators Landing Page `_ -for more details on how to effectively use these APIs. - -Testing custom ops ------------------- - -Use :func:`torch.library.opcheck` to test custom ops for incorrect usage of the -Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports -training, use :func:`torch.autograd.gradcheck` to test that the gradients are -mathematically correct. - -.. autofunction:: opcheck - -Creating new custom ops in Python ---------------------------------- - -Use :func:`torch.library.custom_op` to create new custom ops. - -.. autofunction:: custom_op -.. autofunction:: triton_op -.. autofunction:: wrap_triton - -Extending custom ops (created from Python or C++) -------------------------------------------------- - -Use the register.* methods, such as :func:`torch.library.register_kernel` and -:func:`torch.library.register_fake`, to add implementations -for any operators (they may have been created using :func:`torch.library.custom_op` or -via PyTorch's C++ operator registration APIs). - -.. autofunction:: register_kernel -.. autofunction:: register_autocast -.. autofunction:: register_autograd -.. autofunction:: register_fake -.. autofunction:: register_vmap -.. autofunction:: impl_abstract -.. autofunction:: get_ctx -.. autofunction:: register_torch_dispatch -.. autofunction:: infer_schema -.. autoclass:: torch._library.custom_ops.CustomOpDef - - .. automethod:: set_kernel_enabled - - -Low-level APIs --------------- - -The following APIs are direct bindings to PyTorch's C++ low-level -operator registration APIs. - -.. warning:: - The low-level operator registration APIs and the PyTorch Dispatcher are a - complicated PyTorch concept. We recommend you use the higher level APIs above - (that do not require a torch.library.Library object) when possible. - This blog post `_ - is a good starting point to learn about the PyTorch Dispatcher. - -A tutorial that walks you through some examples on how to use this API is available on `Google Colab `_. - -.. autoclass:: torch.library.Library - :members: - -.. autofunction:: fallthrough_kernel - -.. autofunction:: define - -.. autofunction:: impl diff --git a/docs/source/linalg.md b/docs/source/linalg.md new file mode 100644 index 00000000000000..0c3b46b90a5a21 --- /dev/null +++ b/docs/source/linalg.md @@ -0,0 +1,141 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch.linalg + +Common linear algebra operations. + +See {ref}`Linear Algebra Stability` for some common numerical edge-cases. + +```{eval-rst} +.. automodule:: torch.linalg +.. currentmodule:: torch.linalg +``` + +## Matrix Properties + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + norm + vector_norm + matrix_norm + diagonal + det + slogdet + cond + matrix_rank +``` + +## Decompositions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + cholesky + qr + lu + lu_factor + eig + eigvals + eigh + eigvalsh + svd + svdvals +``` + +(linalg solvers)= + +## Solvers + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + solve + solve_triangular + lu_solve + lstsq +``` + +(linalg inverses)= + +## Inverses + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + inv + pinv +``` + +## Matrix Functions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + matrix_exp + matrix_power +``` + +## Matrix Products + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + cross + matmul + vecdot + multi_dot + householder_product +``` + +## Tensor Operations + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + tensorinv + tensorsolve +``` + +## Misc + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + vander +``` + +## Experimental Functions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + cholesky_ex + inv_ex + solve_ex + lu_factor_ex + ldl_factor + ldl_factor_ex + ldl_solve +``` diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst deleted file mode 100644 index aec7031e2248ef..00000000000000 --- a/docs/source/linalg.rst +++ /dev/null @@ -1,128 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.linalg -============ - -Common linear algebra operations. - -See :ref:`Linear Algebra Stability` for some common numerical edge-cases. - -.. automodule:: torch.linalg -.. currentmodule:: torch.linalg - -Matrix Properties ------------------ - -.. autosummary:: - :toctree: generated - :nosignatures: - - norm - vector_norm - matrix_norm - diagonal - det - slogdet - cond - matrix_rank - -Decompositions --------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - cholesky - qr - lu - lu_factor - eig - eigvals - eigh - eigvalsh - svd - svdvals - -.. _linalg solvers: - -Solvers -------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - solve - solve_triangular - lu_solve - lstsq - -.. _linalg inverses: - -Inverses --------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - inv - pinv - -Matrix Functions ----------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - matrix_exp - matrix_power - -Matrix Products ---------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - cross - matmul - vecdot - multi_dot - householder_product - -Tensor Operations ------------------ - -.. autosummary:: - :toctree: generated - :nosignatures: - - tensorinv - tensorsolve - -Misc ----- - -.. autosummary:: - :toctree: generated - :nosignatures: - - vander - -Experimental Functions ----------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - cholesky_ex - inv_ex - solve_ex - lu_factor_ex - ldl_factor - ldl_factor_ex - ldl_solve diff --git a/docs/source/logging.md b/docs/source/logging.md new file mode 100644 index 00000000000000..8a9b8b2b13068c --- /dev/null +++ b/docs/source/logging.md @@ -0,0 +1,124 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch._logging + +PyTorch has a configurable logging system, where different components can be +given different log level settings. For instance, one component's log messages +can be completely disabled, while another component's log messages can be +set to maximum verbosity. + +:::{warning} +This feature is in beta and may have compatibility breaking +changes in the future. +::: + +:::{warning} +This feature has not been expanded to control the log messages of +all components in PyTorch yet. +::: + +There are two ways to configure the logging system: through the environment variable `TORCH_LOGS` +or the python API torch._logging.set_logs. + +```{eval-rst} +.. automodule:: torch._logging +.. currentmodule:: torch._logging +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + set_logs +``` + +The environment variable `TORCH_LOGS` is a comma-separated list of +`[+-]` pairs, where `` is a component specified below. The `+` prefix +will decrease the log level of the component, displaying more log messages while the `-` prefix +will increase the log level of the component and display fewer log messages. The default setting +is the behavior when a component is not specified in `TORCH_LOGS`. In addition to components, there are +also artifacts. Artifacts are specific pieces of debug information associated with a component that are either displayed or not displayed, +so prefixing an artifact with `+` or `-` will be a no-op. Since they are associated with a component, enabling that component will typically also enable that artifact, +unless that artifact was specified to be `off_by_default`. This option is specified in _registrations.py for artifacts that are so spammy they should only be displayed when explicitly enabled. +The following components and artifacts are configurable through the `TORCH_LOGS` environment +variable (see torch._logging.set_logs for the python API): + +```{eval-rst} +Components: + ``all`` + Special component which configures the default log level of all components. Default: ``logging.WARN`` + + ``dynamo`` + The log level for the TorchDynamo component. Default: ``logging.WARN`` + + ``aot`` + The log level for the AOTAutograd component. Default: ``logging.WARN`` + + ``inductor`` + The log level for the TorchInductor component. Default: ``logging.WARN`` + + ``your.custom.module`` + The log level for an arbitrary unregistered module. Provide the fully qualified name and the module will be enabled. Default: ``logging.WARN`` +``` + +```{eval-rst} +Artifacts: + ``bytecode`` + Whether to emit the original and generated bytecode from TorchDynamo. + Default: ``False`` + + ``aot_graphs`` + Whether to emit the graphs generated by AOTAutograd. Default: ``False`` + + ``aot_joint_graph`` + Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` + + ``compiled_autograd`` + Whether to emit logs from compiled_autograd. Defaults: ``False`` + + ``ddp_graphs`` + Whether to emit graphs generated by DDPOptimizer. Default: ``False`` + + ``graph`` + Whether to emit the graph captured by TorchDynamo in tabular format. + Default: ``False`` + + ``graph_code`` + Whether to emit the python source of the graph captured by TorchDynamo. + Default: ``False`` + + ``graph_breaks`` + Whether to emit a message when a unique graph break is encountered during + TorchDynamo tracing. Default: ``False`` + + ``guards`` + Whether to emit the guards generated by TorchDynamo for each compiled + function. Default: ``False`` + + ``recompiles`` + Whether to emit a guard failure reason and message every time + TorchDynamo recompiles a function. Default: ``False`` + + ``output_code`` + Whether to emit the TorchInductor output code. Default: ``False`` + + ``schedule`` + Whether to emit the TorchInductor schedule. Default: ``False`` +``` + +```{eval-rst} +Examples: + ``TORCH_LOGS="+dynamo,aot"`` will set the log level of TorchDynamo to ``logging.DEBUG`` and AOT to ``logging.INFO`` + + ``TORCH_LOGS="-dynamo,+inductor"`` will set the log level of TorchDynamo to ``logging.ERROR`` and TorchInductor to ``logging.DEBUG`` + + ``TORCH_LOGS="aot_graphs"`` will enable the ``aot_graphs`` artifact + + ``TORCH_LOGS="+dynamo,schedule"`` will enable set the log level of TorchDynamo to ``logging.DEBUG`` and enable the ``schedule`` artifact + + ``TORCH_LOGS="+some.random.module,schedule"`` will set the log level of some.random.module to ``logging.DEBUG`` and enable the ``schedule`` artifact +``` diff --git a/docs/source/logging.rst b/docs/source/logging.rst deleted file mode 100644 index 457ebd9dbce41e..00000000000000 --- a/docs/source/logging.rst +++ /dev/null @@ -1,109 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch._logging -============== - -PyTorch has a configurable logging system, where different components can be -given different log level settings. For instance, one component's log messages -can be completely disabled, while another component's log messages can be -set to maximum verbosity. - -.. warning:: This feature is in beta and may have compatibility breaking - changes in the future. - -.. warning:: This feature has not been expanded to control the log messages of - all components in PyTorch yet. - -There are two ways to configure the logging system: through the environment variable ``TORCH_LOGS`` -or the python API torch._logging.set_logs. - -.. automodule:: torch._logging -.. currentmodule:: torch._logging - -.. autosummary:: - :toctree: generated - :nosignatures: - - set_logs - -The environment variable ``TORCH_LOGS`` is a comma-separated list of -``[+-]`` pairs, where ```` is a component specified below. The ``+`` prefix -will decrease the log level of the component, displaying more log messages while the ``-`` prefix -will increase the log level of the component and display fewer log messages. The default setting -is the behavior when a component is not specified in ``TORCH_LOGS``. In addition to components, there are -also artifacts. Artifacts are specific pieces of debug information associated with a component that are either displayed or not displayed, -so prefixing an artifact with ``+`` or ``-`` will be a no-op. Since they are associated with a component, enabling that component will typically also enable that artifact, -unless that artifact was specified to be `off_by_default`. This option is specified in _registrations.py for artifacts that are so spammy they should only be displayed when explicitly enabled. -The following components and artifacts are configurable through the ``TORCH_LOGS`` environment -variable (see torch._logging.set_logs for the python API): - -Components: - ``all`` - Special component which configures the default log level of all components. Default: ``logging.WARN`` - - ``dynamo`` - The log level for the TorchDynamo component. Default: ``logging.WARN`` - - ``aot`` - The log level for the AOTAutograd component. Default: ``logging.WARN`` - - ``inductor`` - The log level for the TorchInductor component. Default: ``logging.WARN`` - - ``your.custom.module`` - The log level for an arbitrary unregistered module. Provide the fully qualified name and the module will be enabled. Default: ``logging.WARN`` - -Artifacts: - ``bytecode`` - Whether to emit the original and generated bytecode from TorchDynamo. - Default: ``False`` - - ``aot_graphs`` - Whether to emit the graphs generated by AOTAutograd. Default: ``False`` - - ``aot_joint_graph`` - Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` - - ``compiled_autograd`` - Whether to emit logs from compiled_autograd. Defaults: ``False`` - - ``ddp_graphs`` - Whether to emit graphs generated by DDPOptimizer. Default: ``False`` - - ``graph`` - Whether to emit the graph captured by TorchDynamo in tabular format. - Default: ``False`` - - ``graph_code`` - Whether to emit the python source of the graph captured by TorchDynamo. - Default: ``False`` - - ``graph_breaks`` - Whether to emit a message when a unique graph break is encountered during - TorchDynamo tracing. Default: ``False`` - - ``guards`` - Whether to emit the guards generated by TorchDynamo for each compiled - function. Default: ``False`` - - ``recompiles`` - Whether to emit a guard failure reason and message every time - TorchDynamo recompiles a function. Default: ``False`` - - ``output_code`` - Whether to emit the TorchInductor output code. Default: ``False`` - - ``schedule`` - Whether to emit the TorchInductor schedule. Default: ``False`` - -Examples: - ``TORCH_LOGS="+dynamo,aot"`` will set the log level of TorchDynamo to ``logging.DEBUG`` and AOT to ``logging.INFO`` - - ``TORCH_LOGS="-dynamo,+inductor"`` will set the log level of TorchDynamo to ``logging.ERROR`` and TorchInductor to ``logging.DEBUG`` - - ``TORCH_LOGS="aot_graphs"`` will enable the ``aot_graphs`` artifact - - ``TORCH_LOGS="+dynamo,schedule"`` will enable set the log level of TorchDynamo to ``logging.DEBUG`` and enable the ``schedule`` artifact - - ``TORCH_LOGS="+some.random.module,schedule"`` will set the log level of some.random.module to ``logging.DEBUG`` and enable the ``schedule`` artifact diff --git a/docs/source/masked.md b/docs/source/masked.md new file mode 100644 index 00000000000000..d193578d6937d6 --- /dev/null +++ b/docs/source/masked.md @@ -0,0 +1,316 @@ +```{eval-rst} +.. automodule:: torch.masked +.. automodule:: torch.masked.maskedtensor +``` + +```{eval-rst} +.. currentmodule:: torch +``` + +(masked-docs)= + +# torch.masked + +## Introduction + +### Motivation + +:::{warning} +The PyTorch API of masked tensors is in the prototype stage and may or may not change in the future. +::: + +MaskedTensor serves as an extension to {class}`torch.Tensor` that provides the user with the ability to: + +* use any masked semantics (e.g. variable length tensors, nan* operators, etc.) +* differentiate between 0 and NaN gradients +* various sparse applications (see tutorial below) + +"Specified" and "unspecified" have a long history in PyTorch without formal semantics and certainly without +consistency; indeed, MaskedTensor was born out of a build up of issues that the vanilla {class}`torch.Tensor` +class could not properly address. Thus, a primary goal of MaskedTensor is to become the source of truth for +said "specified" and "unspecified" values in PyTorch where they are a first class citizen instead of an afterthought. +In turn, this should further unlock [sparsity's](https://pytorch.org/docs/stable/sparse.html) potential, +enable safer and more consistent operators, and provide a smoother and more intuitive experience +for users and developers alike. + +### What is a MaskedTensor? + +A MaskedTensor is a tensor subclass that consists of 1) an input (data), and 2) a mask. The mask tells us +which entries from the input should be included or ignored. + +By way of example, suppose that we wanted to mask out all values that are equal to 0 (represented by the gray) +and take the max: + +```{eval-rst} +.. image:: _static/img/masked/tensor_comparison.jpg + :scale: 50% +``` + +On top is the vanilla tensor example while the bottom is MaskedTensor where all the 0's are masked out. +This clearly yields a different result depending on whether we have the mask, but this flexible structure +allows the user to systematically ignore any elements they'd like during computation. + +There are already a number of existing tutorials that we've written to help users onboard, such as: + +- [Overview – the place to start for new users, discusses how to use MaskedTensors and why they're useful](https://pytorch.org/tutorials/prototype/maskedtensor_overview) +- [Sparsity – MaskedTensor supports sparse COO and CSR data and mask Tensors](https://pytorch.org/tutorials/prototype/maskedtensor_sparsity) +- [Adagrad sparse semantics – a practical example of how MaskedTensor can simplify sparse semantics and implementations](https://pytorch.org/tutorials/prototype/maskedtensor_adagrad) +- [Advanced semantics – discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), differences with NumPy's MaskedArray, and reduction semantics](https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics) + +## Supported Operators + +### Unary Operators + +Unary operators are operators that only contain only a single input. +Applying them to MaskedTensors is relatively straightforward: if the data is masked out at a given index, +we apply the operator, otherwise we'll continue to mask out the data. + +The available unary operators are: + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + abs + absolute + acos + arccos + acosh + arccosh + angle + asin + arcsin + asinh + arcsinh + atan + arctan + atanh + arctanh + bitwise_not + ceil + clamp + clip + conj_physical + cos + cosh + deg2rad + digamma + erf + erfc + erfinv + exp + exp2 + expm1 + fix + floor + frac + lgamma + log + log10 + log1p + log2 + logit + i0 + isnan + nan_to_num + neg + negative + positive + pow + rad2deg + reciprocal + round + rsqrt + sigmoid + sign + sgn + signbit + sin + sinc + sinh + sqrt + square + tan + tanh + trunc +``` + +The available inplace unary operators are all of the above **except**: + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + angle + positive + signbit + isnan +``` + +### Binary Operators + +As you may have seen in the tutorial, {class}`MaskedTensor` also has binary operations implemented with the caveat +that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you +need support for a particular operator or have proposed semantics for how they should behave instead, please open +an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users +know exactly what is going on and are being intentional about their decisions with masked semantics. + +The available binary operators are: + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + add + atan2 + arctan2 + bitwise_and + bitwise_or + bitwise_xor + bitwise_left_shift + bitwise_right_shift + div + divide + floor_divide + fmod + logaddexp + logaddexp2 + mul + multiply + nextafter + remainder + sub + subtract + true_divide + eq + ne + le + ge + greater + greater_equal + gt + less_equal + lt + less + maximum + minimum + fmax + fmin + not_equal +``` + +The available inplace binary operators are all of the above **except**: + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + logaddexp + logaddexp2 + equal + fmin + minimum + fmax +``` + +### Reductions + +The following reductions are available (with autograd support). For more information, the +[Overview](https://pytorch.org/tutorials/prototype/maskedtensor_overview.html) tutorial +details some examples of reductions, while the +[Advanced semantics](https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics.html) tutorial +has some further in-depth discussions about how we decided on certain reduction semantics. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + sum + mean + amin + amax + argmin + argmax + prod + all + norm + var + std +``` + +### View and select functions + +We've included a number of view and select functions as well; intuitively, these operators will apply to +both the data and the mask and then wrap the result in a {class}`MaskedTensor`. For a quick example, +consider {func}`select`: + +```python + >>> data = torch.arange(12, dtype=torch.float).reshape(3, 4) + >>> data + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]) + >>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]]) + >>> mt = masked_tensor(data, mask) + >>> data.select(0, 1) + tensor([4., 5., 6., 7.]) + >>> mask.select(0, 1) + tensor([False, True, False, False]) + >>> mt.select(0, 1) + MaskedTensor( + [ --, 5.0000, --, --] + ) +``` + +The following ops are currently supported: + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + atleast_1d + broadcast_tensors + broadcast_to + cat + chunk + column_stack + dsplit + flatten + hsplit + hstack + kron + meshgrid + narrow + nn.functional.unfold + ravel + select + split + stack + t + transpose + vsplit + vstack + Tensor.expand + Tensor.expand_as + Tensor.reshape + Tensor.reshape_as + Tensor.unfold + Tensor.view +``` + +```{eval-rst} +.. This module needs to be documented. Adding here in the meantime +.. for tracking purposes +.. py:module:: torch.masked.maskedtensor.binary +.. py:module:: torch.masked.maskedtensor.core +.. py:module:: torch.masked.maskedtensor.creation +.. py:module:: torch.masked.maskedtensor.passthrough +.. py:module:: torch.masked.maskedtensor.reductions +.. py:module:: torch.masked.maskedtensor.unary +``` diff --git a/docs/source/masked.rst b/docs/source/masked.rst deleted file mode 100644 index 8177b91a9c15c6..00000000000000 --- a/docs/source/masked.rst +++ /dev/null @@ -1,309 +0,0 @@ -.. automodule:: torch.masked -.. automodule:: torch.masked.maskedtensor - -.. currentmodule:: torch - -.. _masked-docs: - -torch.masked -============ - -Introduction -++++++++++++ - -Motivation ----------- - -.. warning:: - - The PyTorch API of masked tensors is in the prototype stage and may or may not change in the future. - -MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to: - -* use any masked semantics (e.g. variable length tensors, nan* operators, etc.) -* differentiate between 0 and NaN gradients -* various sparse applications (see tutorial below) - -"Specified" and "unspecified" have a long history in PyTorch without formal semantics and certainly without -consistency; indeed, MaskedTensor was born out of a build up of issues that the vanilla :class:`torch.Tensor` -class could not properly address. Thus, a primary goal of MaskedTensor is to become the source of truth for -said "specified" and "unspecified" values in PyTorch where they are a first class citizen instead of an afterthought. -In turn, this should further unlock `sparsity's `_ potential, -enable safer and more consistent operators, and provide a smoother and more intuitive experience -for users and developers alike. - -What is a MaskedTensor? ------------------------ - -A MaskedTensor is a tensor subclass that consists of 1) an input (data), and 2) a mask. The mask tells us -which entries from the input should be included or ignored. - -By way of example, suppose that we wanted to mask out all values that are equal to 0 (represented by the gray) -and take the max: - -.. image:: _static/img/masked/tensor_comparison.jpg - :scale: 50% - -On top is the vanilla tensor example while the bottom is MaskedTensor where all the 0's are masked out. -This clearly yields a different result depending on whether we have the mask, but this flexible structure -allows the user to systematically ignore any elements they'd like during computation. - -There are already a number of existing tutorials that we've written to help users onboard, such as: - -- `Overview - the place to start for new users, discusses how to use MaskedTensors and why they're useful`_ -- `Sparsity - MaskedTensor supports sparse COO and CSR data and mask Tensors`_ -- `Adagrad sparse semantics - a practical example of how MaskedTensor can simplify sparse semantics and implementations`_ -- `Advanced semantics - discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), - differences with NumPy's MaskedArray, and reduction semantics`_ - -.. _Overview - the place to start for new users, discusses how to use MaskedTensors and why they're useful: https://pytorch.org/tutorials/prototype/maskedtensor_overview -.. _Sparsity - MaskedTensor supports sparse COO and CSR data and mask Tensors: https://pytorch.org/tutorials/prototype/maskedtensor_sparsity -.. _Adagrad sparse semantics - a practical example of how MaskedTensor can simplify sparse semantics and implementations: https://pytorch.org/tutorials/prototype/maskedtensor_adagrad -.. _Advanced semantics - discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), differences with NumPy's MaskedArray, and reduction semantics: https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics - -Supported Operators -+++++++++++++++++++ - -Unary Operators ---------------- - -Unary operators are operators that only contain only a single input. -Applying them to MaskedTensors is relatively straightforward: if the data is masked out at a given index, -we apply the operator, otherwise we'll continue to mask out the data. - -The available unary operators are: - -.. autosummary:: - :toctree: generated - :nosignatures: - - abs - absolute - acos - arccos - acosh - arccosh - angle - asin - arcsin - asinh - arcsinh - atan - arctan - atanh - arctanh - bitwise_not - ceil - clamp - clip - conj_physical - cos - cosh - deg2rad - digamma - erf - erfc - erfinv - exp - exp2 - expm1 - fix - floor - frac - lgamma - log - log10 - log1p - log2 - logit - i0 - isnan - nan_to_num - neg - negative - positive - pow - rad2deg - reciprocal - round - rsqrt - sigmoid - sign - sgn - signbit - sin - sinc - sinh - sqrt - square - tan - tanh - trunc - -The available inplace unary operators are all of the above **except**: - -.. autosummary:: - :toctree: generated - :nosignatures: - - angle - positive - signbit - isnan - -Binary Operators ----------------- - -As you may have seen in the tutorial, :class:`MaskedTensor` also has binary operations implemented with the caveat -that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you -need support for a particular operator or have proposed semantics for how they should behave instead, please open -an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users -know exactly what is going on and are being intentional about their decisions with masked semantics. - -The available binary operators are: - -.. autosummary:: - :toctree: generated - :nosignatures: - - add - atan2 - arctan2 - bitwise_and - bitwise_or - bitwise_xor - bitwise_left_shift - bitwise_right_shift - div - divide - floor_divide - fmod - logaddexp - logaddexp2 - mul - multiply - nextafter - remainder - sub - subtract - true_divide - eq - ne - le - ge - greater - greater_equal - gt - less_equal - lt - less - maximum - minimum - fmax - fmin - not_equal - -The available inplace binary operators are all of the above **except**: - -.. autosummary:: - :toctree: generated - :nosignatures: - - logaddexp - logaddexp2 - equal - fmin - minimum - fmax - -Reductions ----------- - -The following reductions are available (with autograd support). For more information, the -`Overview `_ tutorial -details some examples of reductions, while the -`Advanced semantics `_ tutorial -has some further in-depth discussions about how we decided on certain reduction semantics. - -.. autosummary:: - :toctree: generated - :nosignatures: - - sum - mean - amin - amax - argmin - argmax - prod - all - norm - var - std - -View and select functions -------------------------- - -We've included a number of view and select functions as well; intuitively, these operators will apply to -both the data and the mask and then wrap the result in a :class:`MaskedTensor`. For a quick example, -consider :func:`select`: - - >>> data = torch.arange(12, dtype=torch.float).reshape(3, 4) - >>> data - tensor([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.]]) - >>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]]) - >>> mt = masked_tensor(data, mask) - >>> data.select(0, 1) - tensor([4., 5., 6., 7.]) - >>> mask.select(0, 1) - tensor([False, True, False, False]) - >>> mt.select(0, 1) - MaskedTensor( - [ --, 5.0000, --, --] - ) - -The following ops are currently supported: - -.. autosummary:: - :toctree: generated - :nosignatures: - - atleast_1d - broadcast_tensors - broadcast_to - cat - chunk - column_stack - dsplit - flatten - hsplit - hstack - kron - meshgrid - narrow - nn.functional.unfold - ravel - select - split - stack - t - transpose - vsplit - vstack - Tensor.expand - Tensor.expand_as - Tensor.reshape - Tensor.reshape_as - Tensor.unfold - Tensor.view - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.masked.maskedtensor.binary -.. py:module:: torch.masked.maskedtensor.core -.. py:module:: torch.masked.maskedtensor.creation -.. py:module:: torch.masked.maskedtensor.passthrough -.. py:module:: torch.masked.maskedtensor.reductions -.. py:module:: torch.masked.maskedtensor.unary \ No newline at end of file diff --git a/docs/source/meta.md b/docs/source/meta.md new file mode 100644 index 00000000000000..a5cca7c86cde87 --- /dev/null +++ b/docs/source/meta.md @@ -0,0 +1,92 @@ +# Meta device + +The "meta" device is an abstract device which denotes a tensor which records +only metadata, but no actual data. Meta tensors have two primary use cases: + +* Models can be loaded on the meta device, allowing you to load a + representation of the model without actually loading the actual parameters + into memory. This can be helpful if you need to make transformations on + the model before you load the actual data. + +* Most operations can be performed on meta tensors, producing new meta + tensors that describe what the result would have been if you performed + the operation on a real tensor. You can use this to perform abstract + analysis without needing to spend time on compute or space to represent + the actual tensors. Because meta tensors do not have real data, you cannot + perform data-dependent operations like {func}`torch.nonzero` or + {meth}`~torch.Tensor.item`. In some cases, not all device types (e.g., CPU + and CUDA) have exactly the same output metadata for an operation; we + typically prefer representing the CUDA behavior faithfully in this + situation. + +```{warning} +Although in principle meta tensor computation should always be faster than +an equivalent CPU/CUDA computation, many meta tensor implementations are +implemented in Python and have not been ported to C++ for speed, so you +may find that you get lower absolute framework latency with small CPU tensors. +``` + +## Idioms for working with meta tensors + +An object can be loaded with {func}`torch.load` onto meta device by specifying +`map_location='meta'`: + +```python +>>> torch.save(torch.randn(2), 'foo.pt') +>>> torch.load('foo.pt', map_location='meta') +tensor(..., device='meta', size=(2,)) +``` + +If you have some arbitrary code which performs some tensor construction without +explicitly specifying a device, you can override it to instead construct on meta device by using +the {func}`torch.device` context manager: + +```python +>>> with torch.device('meta'): +... print(torch.randn(30, 30)) +... +tensor(..., device='meta', size=(30, 30)) +``` + +This is especially helpful NN module construction, where you often are not +able to explicitly pass in a device for initialization: + +```python +>>> from torch.nn.modules import Linear +>>> with torch.device('meta'): +... print(Linear(20, 30)) +... +Linear(in_features=20, out_features=30, bias=True) +``` + +You cannot convert a meta tensor directly to a CPU/CUDA tensor, because the +meta tensor stores no data and we do not know what the correct data values for +your new tensor are: + +```python +>>> torch.ones(5, device='meta').to("cpu") +Traceback (most recent call last): + File "", line 1, in +NotImplementedError: Cannot copy out of meta tensor; no data! +``` + +Use a factory function like {func}`torch.empty_like` to explicitly specify how +you would like the missing data to be filled in. + +NN modules have a convenience method {meth}`torch.nn.Module.to_empty` that +allows you to move the module to another device, leaving all parameters +uninitialized. You are expected to explicitly reinitialize the parameters +manually: + +```python +>>> from torch.nn.modules import Linear +>>> with torch.device('meta'): +... m = Linear(20, 30) +>>> m.to_empty(device="cpu") +Linear(in_features=20, out_features=30, bias=True) +``` + +{mod}`torch._subclasses.meta_utils` contains undocumented utilities for taking +an arbitrary Tensor and constructing an equivalent meta Tensor with high +fidelity. These APIs are experimental and may be changed in a BC breaking way +at any time. diff --git a/docs/source/meta.rst b/docs/source/meta.rst deleted file mode 100644 index 47efc205fe8d73..00000000000000 --- a/docs/source/meta.rst +++ /dev/null @@ -1,84 +0,0 @@ -Meta device -============ - -The "meta" device is an abstract device which denotes a tensor which records -only metadata, but no actual data. Meta tensors have two primary use cases: - -* Models can be loaded on the meta device, allowing you to load a - representation of the model without actually loading the actual parameters - into memory. This can be helpful if you need to make transformations on - the model before you load the actual data. - -* Most operations can be performed on meta tensors, producing new meta - tensors that describe what the result would have been if you performed - the operation on a real tensor. You can use this to perform abstract - analysis without needing to spend time on compute or space to represent - the actual tensors. Because meta tensors do not have real data, you cannot - perform data-dependent operations like :func:`torch.nonzero` or - :meth:`~torch.Tensor.item`. In some cases, not all device types (e.g., CPU - and CUDA) have exactly the same output metadata for an operation; we - typically prefer representing the CUDA behavior faithfully in this - situation. - -.. warning:: - - Although in principle meta tensor computation should always be faster than - an equivalent CPU/CUDA computation, many meta tensor implementations are - implemented in Python and have not been ported to C++ for speed, so you - may find that you get lower absolute framework latency with small CPU tensors. - -Idioms for working with meta tensors ------------------------------------- - -An object can be loaded with :func:`torch.load` onto meta device by specifying -``map_location='meta'``:: - - >>> torch.save(torch.randn(2), 'foo.pt') - >>> torch.load('foo.pt', map_location='meta') - tensor(..., device='meta', size=(2,)) - -If you have some arbitrary code which performs some tensor construction without -explicitly specifying a device, you can override it to instead construct on meta device by using -the :func:`torch.device` context manager:: - - >>> with torch.device('meta'): - ... print(torch.randn(30, 30)) - ... - tensor(..., device='meta', size=(30, 30)) - -This is especially helpful NN module construction, where you often are not -able to explicitly pass in a device for initialization:: - - >>> from torch.nn.modules import Linear - >>> with torch.device('meta'): - ... print(Linear(20, 30)) - ... - Linear(in_features=20, out_features=30, bias=True) - -You cannot convert a meta tensor directly to a CPU/CUDA tensor, because the -meta tensor stores no data and we do not know what the correct data values for -your new tensor are:: - - >>> torch.ones(5, device='meta').to("cpu") - Traceback (most recent call last): - File "", line 1, in - NotImplementedError: Cannot copy out of meta tensor; no data! - -Use a factory function like :func:`torch.empty_like` to explicitly specify how -you would like the missing data to be filled in. - -NN modules have a convenience method :meth:`torch.nn.Module.to_empty` that -allows you to move the module to another device, leaving all parameters -uninitialized. You are expected to explicitly reinitialize the parameters -manually:: - - >>> from torch.nn.modules import Linear - >>> with torch.device('meta'): - ... m = Linear(20, 30) - >>> m.to_empty(device="cpu") - Linear(in_features=20, out_features=30, bias=True) - -:mod:`torch._subclasses.meta_utils` contains undocumented utilities for taking -an arbitrary Tensor and constructing an equivalent meta Tensor with high -fidelity. These APIs are experimental and may be changed in a BC breaking way -at any time. diff --git a/docs/source/miscellaneous_environment_variables.md b/docs/source/miscellaneous_environment_variables.md new file mode 100644 index 00000000000000..6046d787ef5330 --- /dev/null +++ b/docs/source/miscellaneous_environment_variables.md @@ -0,0 +1,10 @@ +(miscellaneous_environment_variables)= + +# Miscellaneous Environment Variables + +| Variable | Description | +|---------------------------------------|-------------| +| `TORCH_FORCE_WEIGHTS_ONLY_LOAD` | If set to [`1`, `y`, `yes`, `true`], the `torch.load` will use `weights_only=True`. This will happen even if `weights_only=False` was passed at the callsite. For more documentation on this, see [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html). | +| `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` | If set to [`1`, `y`, `yes`, `true`], the `torch.load` will use `weights_only=False` if the `weights_only` variable was not passed at the callsite. For more documentation on this, see [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html). | +| `TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT` | Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on a timeout that is by default set to `10` seconds. This environment variable can be used to set the timeout in seconds. | +| `TORCH_DEVICE_BACKEND_AUTOLOAD` | If set to `1`, out-of-tree backend extensions will be automatically imported when running `import torch`. | diff --git a/docs/source/miscellaneous_environment_variables.rst b/docs/source/miscellaneous_environment_variables.rst deleted file mode 100644 index 14494241af9de0..00000000000000 --- a/docs/source/miscellaneous_environment_variables.rst +++ /dev/null @@ -1,19 +0,0 @@ -.. _miscellaneous_environment_variables: - -Miscellaneous Environment Variables -=================================== -.. list-table:: - :header-rows: 1 - - * - Variable - - Description - * - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` - - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=True``. This will happen even if - ``weights_only=False`` was passed at the callsite. For more documentation on this, see :func:`torch.load`. - * - ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD`` - - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=False`` if the ``weights_only`` variable was not - passed at the callsite. For more documentation on this, see :func:`torch.load`. - * - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT`` - - Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds. - * - ``TORCH_DEVICE_BACKEND_AUTOLOAD`` - - If set to ``1``, out-of-tree backend extensions will be automatically imported when running ``import torch``. diff --git a/docs/source/mobile_optimizer.md b/docs/source/mobile_optimizer.md new file mode 100644 index 00000000000000..55b3c32c9fff28 --- /dev/null +++ b/docs/source/mobile_optimizer.md @@ -0,0 +1,24 @@ +--- +robots: noindex +--- +# torch.utils.mobile_optimizer + +PyTorch Mobile is no longer actively supported. Redirecting to [ExecuTorch documentation](https://docs.pytorch.org/executorch). + +```{raw} html + +``` + +```{warning} +PyTorch Mobile is no longer actively supported. Please check out +[ExecuTorch](https://pytorch.org/executorch-overview), PyTorch's +all-new on-device inference library. You can also review +documentation on [XNNPACK](https://pytorch.org/executorch/stable/native-delegates-executorch-xnnpack-delegate.html) +and [Vulkan](https://pytorch.org/executorch/stable/native-delegates-executorch-vulkan-delegate.html) delegates. +``` +```{eval-rst} +.. currentmodule:: torch.utils.mobile_optimizer +``` +```{eval-rst} +.. autofunction:: optimize_for_mobile +``` diff --git a/docs/source/mobile_optimizer.rst b/docs/source/mobile_optimizer.rst deleted file mode 100644 index 2a4bab76ec69c1..00000000000000 --- a/docs/source/mobile_optimizer.rst +++ /dev/null @@ -1,21 +0,0 @@ -.. meta:: - :robots: noindex - -torch.utils.mobile_optimizer -=================================== - -PyTorch Mobile is no longer actively supported. Redirecting to `ExecuTorch documentation `_. - -.. raw:: html - - - -.. warning:: - PyTorch Mobile is no longer actively supported. Please check out - `ExecuTorch `__, PyTorch's - all-new on-device inference library. You can also review - documentation on `XNNPACK `__ - and `Vulkan `__ delegates. - -.. currentmodule:: torch.utils.mobile_optimizer -.. autofunction:: optimize_for_mobile diff --git a/docs/source/model_zoo.md b/docs/source/model_zoo.md new file mode 100644 index 00000000000000..5caf7ac40c898c --- /dev/null +++ b/docs/source/model_zoo.md @@ -0,0 +1,10 @@ +# torch.utils.model_zoo + +Moved to `torch.hub`. + +```{eval-rst} +.. automodule:: torch.utils.model_zoo +``` +```{eval-rst} +.. autofunction:: load_url +``` diff --git a/docs/source/model_zoo.rst b/docs/source/model_zoo.rst deleted file mode 100644 index a2a8dec4351985..00000000000000 --- a/docs/source/model_zoo.rst +++ /dev/null @@ -1,7 +0,0 @@ -torch.utils.model_zoo -=================================== - -Moved to `torch.hub`. - -.. automodule:: torch.utils.model_zoo -.. autofunction:: load_url diff --git a/docs/source/module_tracker.md b/docs/source/module_tracker.md new file mode 100644 index 00000000000000..d0d1d55e64ae9d --- /dev/null +++ b/docs/source/module_tracker.md @@ -0,0 +1,11 @@ +# torch.utils.module_tracker +```{eval-rst} +.. automodule:: torch.utils.module_tracker +``` + +This utility can be used to track the current position inside an {class}`torch.nn.Module` hierarchy. +It can be used within other tracking tools to be able to easily associate measured quantities to user-friendly names. This is used in particular in the FlopCounterMode today. + +```{eval-rst} +.. autoclass:: torch.utils.module_tracker.ModuleTracker +``` diff --git a/docs/source/module_tracker.rst b/docs/source/module_tracker.rst deleted file mode 100644 index ecb84b22f32e55..00000000000000 --- a/docs/source/module_tracker.rst +++ /dev/null @@ -1,8 +0,0 @@ -torch.utils.module_tracker -=================================== -.. automodule:: torch.utils.module_tracker - -This utility can be used to track the current position inside an :class:`torch.nn.Module` hierarchy. -It can be used within other tracking tools to be able to easily associate measured quantities to user-friendly names. This is used in particular in the FlopCounterMode today. - -.. autoclass:: torch.utils.module_tracker.ModuleTracker diff --git a/docs/source/monitor.md b/docs/source/monitor.md new file mode 100644 index 00000000000000..20d310a20cd73a --- /dev/null +++ b/docs/source/monitor.md @@ -0,0 +1,70 @@ +# torch.monitor + +```{warning} +This module is a prototype release, and its interfaces and functionality may +change without warning in future PyTorch releases. +``` + +``torch.monitor`` provides an interface for logging events and counters from +PyTorch. + +The stat interfaces are designed to be used for tracking high level metrics that +are periodically logged out to be used for monitoring system performance. Since +the stats aggregate with a specific window size you can log to them from +critical loops with minimal performance impact. + +For more infrequent events or values such as loss, accuracy, usage tracking the +event interface can be directly used. + +Event handlers can be registered to handle the events and pass them to an +external event sink. + +## API Reference +```{eval-rst} +.. automodule:: torch.monitor +``` + +```{eval-rst} +.. autoclass:: torch.monitor.Aggregation + :members: +``` + +```{eval-rst} +.. autoclass:: torch.monitor.Stat + :members: + :special-members: __init__ +``` + +```{eval-rst} +.. autoclass:: torch.monitor.data_value_t + :members: +``` + +```{eval-rst} +.. autoclass:: torch.monitor.Event + :members: + :special-members: __init__ +``` + +```{eval-rst} +.. autoclass:: torch.monitor.EventHandlerHandle + :members: +``` + +```{eval-rst} +.. autofunction:: torch.monitor.log_event +``` + +```{eval-rst} +.. autofunction:: torch.monitor.register_event_handler +``` + +```{eval-rst} +.. autofunction:: torch.monitor.unregister_event_handler +``` + +```{eval-rst} +.. autoclass:: torch.monitor.TensorboardEventHandler + :members: + :special-members: __init__ +``` diff --git a/docs/source/monitor.rst b/docs/source/monitor.rst deleted file mode 100644 index 7952586da9c12e..00000000000000 --- a/docs/source/monitor.rst +++ /dev/null @@ -1,53 +0,0 @@ -torch.monitor -============= - -.. warning:: - - This module is a prototype release, and its interfaces and functionality may - change without warning in future PyTorch releases. - -``torch.monitor`` provides an interface for logging events and counters from -PyTorch. - -The stat interfaces are designed to be used for tracking high level metrics that -are periodically logged out to be used for monitoring system performance. Since -the stats aggregate with a specific window size you can log to them from -critical loops with minimal performance impact. - -For more infrequent events or values such as loss, accuracy, usage tracking the -event interface can be directly used. - -Event handlers can be registered to handle the events and pass them to an -external event sink. - -API Reference -------------- - -.. automodule:: torch.monitor - -.. autoclass:: torch.monitor.Aggregation - :members: - -.. autoclass:: torch.monitor.Stat - :members: - :special-members: __init__ - -.. autoclass:: torch.monitor.data_value_t - :members: - -.. autoclass:: torch.monitor.Event - :members: - :special-members: __init__ - -.. autoclass:: torch.monitor.EventHandlerHandle - :members: - -.. autofunction:: torch.monitor.log_event - -.. autofunction:: torch.monitor.register_event_handler - -.. autofunction:: torch.monitor.unregister_event_handler - -.. autoclass:: torch.monitor.TensorboardEventHandler - :members: - :special-members: __init__ diff --git a/docs/source/mps.md b/docs/source/mps.md new file mode 100644 index 00000000000000..7336e71c03c058 --- /dev/null +++ b/docs/source/mps.md @@ -0,0 +1,67 @@ +# torch.mps + +```{eval-rst} +.. automodule:: torch.mps +``` + +```{eval-rst} +.. currentmodule:: torch.mps +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + device_count + synchronize + get_rng_state + set_rng_state + manual_seed + seed + empty_cache + set_per_process_memory_fraction + current_allocated_memory + driver_allocated_memory + recommended_max_memory + compile_shader +``` + +## MPS Profiler + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + profiler.start + profiler.stop + profiler.profile + + profiler.is_capturing_metal + profiler.is_metal_capture_enabled + profiler.metal_capture +``` + +## MPS Event + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + event.Event + +``` + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.mps.event +``` + +```{eval-rst} +.. py:module:: torch.mps.profiler +``` diff --git a/docs/source/mps.rst b/docs/source/mps.rst deleted file mode 100644 index 623915d3af1d8d..00000000000000 --- a/docs/source/mps.rst +++ /dev/null @@ -1,49 +0,0 @@ -torch.mps -=================================== -.. automodule:: torch.mps -.. currentmodule:: torch.mps - -.. autosummary:: - :toctree: generated - :nosignatures: - - device_count - synchronize - get_rng_state - set_rng_state - manual_seed - seed - empty_cache - set_per_process_memory_fraction - current_allocated_memory - driver_allocated_memory - recommended_max_memory - compile_shader - -MPS Profiler ------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - profiler.start - profiler.stop - profiler.profile - - profiler.is_capturing_metal - profiler.is_metal_capture_enabled - profiler.metal_capture - -MPS Event ------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - event.Event - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.mps.event -.. py:module:: torch.mps.profiler diff --git a/docs/source/mps_environment_variables.md b/docs/source/mps_environment_variables.md new file mode 100644 index 00000000000000..93c3b94879c036 --- /dev/null +++ b/docs/source/mps_environment_variables.md @@ -0,0 +1,33 @@ +(mps_environment_variables)= +# MPS Environment Variables + +**PyTorch Environment Variables** + + +| Variable | Description | +|----------------------------------|-------------| +| `PYTORCH_DEBUG_MPS_ALLOCATOR` | If set to `1`, set allocator logging level to verbose. | +| `PYTORCH_MPS_LOG_PROFILE_INFO` | Set log options bitmask to `MPSProfiler`. See `LogOptions` enum in `aten/src/ATen/mps/MPSProfiler.h`. | +| `PYTORCH_MPS_TRACE_SIGNPOSTS` | Set profile and signpost bitmasks to `MPSProfiler`. See `ProfileOptions` and `SignpostTypes`. | +| `PYTORCH_MPS_HIGH_WATERMARK_RATIO` | High watermark ratio for MPS allocator. Default is 1.7. | +| `PYTORCH_MPS_LOW_WATERMARK_RATIO` | Low watermark ratio for MPS allocator. Default is 1.4 (unified) or 1.0 (discrete). | +| `PYTORCH_MPS_FAST_MATH` | If `1`, enables fast math for MPS kernels. See section 1.6.3 in the [Metal Shading Language Spec](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf). | +| `PYTORCH_MPS_PREFER_METAL` | If `1`, uses metal kernels instead of MPS Graph APIs. Used for matmul. | +| `PYTORCH_ENABLE_MPS_FALLBACK` | If `1`, falls back to CPU when MPS ops aren't supported. | + +```{note} +**high watermark ratio** is a hard limit for the total allowed allocations + +- `0.0` : disables high watermark limit (may cause system failure if system-wide OOM occurs) +- `1.0` : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize) +- `>1.0`: allows limits beyond the device.recommendedMaxWorkingSetSize + +e.g., value 0.95 means we allocate up to 95% of recommended maximum +allocation size; beyond that, the allocations would fail with OOM error. + +**low watermark ratio** is a soft limit to attempt limiting memory allocations up to the lower watermark +level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit). +Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection) +e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum +allocation size. +``` diff --git a/docs/source/mps_environment_variables.rst b/docs/source/mps_environment_variables.rst deleted file mode 100644 index 34c3ab75ce8db3..00000000000000 --- a/docs/source/mps_environment_variables.rst +++ /dev/null @@ -1,45 +0,0 @@ -.. _mps_environment_variables: - -MPS Environment Variables -========================== - -**PyTorch Environment Variables** - -.. list-table:: - :header-rows: 1 - - * - Variable - - Description - * - ``PYTORCH_DEBUG_MPS_ALLOCATOR`` - - If set to ``1``, set allocator logging level to verbose. - * - ``PYTORCH_MPS_LOG_PROFILE_INFO`` - - Set log options bitmask to ``MPSProfiler``. See ``LogOptions`` enum in `aten/src/ATen/mps/MPSProfiler.h` for available options. - * - ``PYTORCH_MPS_TRACE_SIGNPOSTS`` - - Set profile and signpost bitmasks to ``MPSProfiler``. See ``ProfileOptions`` and ``SignpostTypes`` enums in `aten/src/ATen/mps/MPSProfiler.h` for available options. - * - ``PYTORCH_MPS_HIGH_WATERMARK_RATIO`` - - High watermark ratio for MPS allocator. By default, it is set to 1.7. - * - ``PYTORCH_MPS_LOW_WATERMARK_RATIO`` - - Low watermark ratio for MPS allocator. By default, it is set to 1.4 if the memory is unified and set to 1.0 if the memory is discrete. - * - ``PYTORCH_MPS_FAST_MATH`` - - If set to ``1``, enable fast math for MPS metal kernels. See section 1.6.3 in https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf for precision implications. - * - ``PYTORCH_MPS_PREFER_METAL`` - - If set to ``1``, force using metal kernels instead of using MPS Graph APIs. For now this is only used for matmul op. - * - ``PYTORCH_ENABLE_MPS_FALLBACK`` - - If set to ``1``, full back operations to CPU when MPS does not support them. - -.. note:: - - **high watermark ratio** is a hard limit for the total allowed allocations - - - `0.0` : disables high watermark limit (may cause system failure if system-wide OOM occurs) - - `1.0` : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize) - - `>1.0`: allows limits beyond the device.recommendedMaxWorkingSetSize - - e.g., value 0.95 means we allocate up to 95% of recommended maximum - allocation size; beyond that, the allocations would fail with OOM error. - - **low watermark ratio** is a soft limit to attempt limiting memory allocations up to the lower watermark - level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit). - Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection) - e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum - allocation size. \ No newline at end of file diff --git a/docs/source/mtia.md b/docs/source/mtia.md new file mode 100644 index 00000000000000..3229b80c3d91b2 --- /dev/null +++ b/docs/source/mtia.md @@ -0,0 +1,51 @@ +# torch.mtia + +The MTIA backend is implemented out of the tree, only interfaces are be defined here. + +```{eval-rst} +.. automodule:: torch.mtia +``` + +```{eval-rst} +.. currentmodule:: torch.mtia +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + StreamContext + current_device + current_stream + default_stream + device_count + init + is_available + is_initialized + memory_stats + get_device_capability + empty_cache + record_memory_history + snapshot + attach_out_of_memory_observer + set_device + set_stream + stream + synchronize + device + set_rng_state + get_rng_state + DeferredMtiaCallError +``` + +## Streams and events + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Event + Stream +``` diff --git a/docs/source/mtia.memory.md b/docs/source/mtia.memory.md new file mode 100644 index 00000000000000..4da8c098840627 --- /dev/null +++ b/docs/source/mtia.memory.md @@ -0,0 +1,19 @@ +# torch.mtia.memory + +The MTIA backend is implemented out of the tree, only interfaces are be defined here. + +```{eval-rst} +.. automodule:: torch.mtia.memory +``` + +```{eval-rst} +.. currentmodule:: torch.mtia.memory +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + memory_stats +``` diff --git a/docs/source/mtia.memory.rst b/docs/source/mtia.memory.rst deleted file mode 100644 index 90856dc5a40b68..00000000000000 --- a/docs/source/mtia.memory.rst +++ /dev/null @@ -1,13 +0,0 @@ -torch.mtia.memory -=================================== - -The MTIA backend is implemented out of the tree, only interfaces are be defined here. - -.. automodule:: torch.mtia.memory -.. currentmodule:: torch.mtia.memory - -.. autosummary:: - :toctree: generated - :nosignatures: - - memory_stats diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst deleted file mode 100644 index 7572e6bf56b310..00000000000000 --- a/docs/source/mtia.rst +++ /dev/null @@ -1,43 +0,0 @@ -torch.mtia -=================================== - -The MTIA backend is implemented out of the tree, only interfaces are be defined here. - -.. automodule:: torch.mtia -.. currentmodule:: torch.mtia - -.. autosummary:: - :toctree: generated - :nosignatures: - - StreamContext - current_device - current_stream - default_stream - device_count - init - is_available - is_initialized - memory_stats - get_device_capability - empty_cache - record_memory_history - snapshot - attach_out_of_memory_observer - set_device - set_stream - stream - synchronize - device - set_rng_state - get_rng_state - DeferredMtiaCallError - -Streams and events ------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - Event - Stream diff --git a/docs/source/multiprocessing.md b/docs/source/multiprocessing.md new file mode 100644 index 00000000000000..6669fcaa24b30e --- /dev/null +++ b/docs/source/multiprocessing.md @@ -0,0 +1,219 @@ +--- +orphan: true +--- + +(multiprocessing-doc)= + +# Multiprocessing package - torch.multiprocessing + +```{eval-rst} +.. automodule:: torch.multiprocessing +``` + +```{eval-rst} +.. currentmodule:: torch.multiprocessing +``` + +:::{warning} +If the main process exits abruptly (e.g. because of an incoming signal), +Python's `multiprocessing` sometimes fails to clean up its children. +It's a known caveat, so if you're seeing any resource leaks after +interrupting the interpreter, it probably means that this has just happened +to you. +::: + +## Strategy management + +```{eval-rst} +.. autofunction:: get_all_sharing_strategies +``` + +```{eval-rst} +.. autofunction:: get_sharing_strategy +``` + +```{eval-rst} +.. autofunction:: set_sharing_strategy + +``` + +(multiprocessing-cuda-sharing-details)= + +## Sharing CUDA tensors + +Sharing CUDA tensors between processes is supported only in Python 3, using +a `spawn` or `forkserver` start methods. + +Unlike CPU tensors, the sending process is required to keep the original tensor +as long as the receiving process retains a copy of the tensor. The refcounting is +implemented under the hood but requires users to follow the next best practices. + +:::{warning} +If the consumer process dies abnormally to a fatal signal, the shared tensor +could be forever kept in memory as long as the sending process is running. +::: + +1. Release memory ASAP in the consumer. + +``` +## Good +x = queue.get() +# do somethings with x +del x +``` + +``` +## Bad +x = queue.get() +# do somethings with x +# do everything else (producer have to keep x in memory) +``` + +2. Keep producer process running until all consumers exits. This will prevent +the situation when the producer process releasing memory which is still in use +by the consumer. + +``` +## producer +# send tensors, do something +event.wait() +``` + +``` +## consumer +# receive tensors and use them +event.set() +``` + +3. Don't pass received tensors. + +``` +# not going to work +x = queue.get() +queue_2.put(x) +``` + +``` +# you need to create a process-local copy +x = queue.get() +x_clone = x.clone() +queue_2.put(x_clone) +``` + +``` +# putting and getting from the same queue in the same process will likely end up with segfault +queue.put(tensor) +x = queue.get() +``` + +## Sharing strategies + +This section provides a brief overview into how different sharing strategies +work. Note that it applies only to CPU tensor - CUDA tensors will always use +the CUDA API, as that's the only way they can be shared. + +### File descriptor - `file_descriptor` + +:::{note} +This is the default strategy (except for macOS and OS X where it's not +supported). +::: + +This strategy will use file descriptors as shared memory handles. Whenever a +storage is moved to shared memory, a file descriptor obtained from `shm_open` +is cached with the object, and when it's going to be sent to other processes, +the file descriptor will be transferred (e.g. via UNIX sockets) to it. The +receiver will also cache the file descriptor and `mmap` it, to obtain a shared +view onto the storage data. + +Note that if there will be a lot of tensors shared, this strategy will keep a +large number of file descriptors open most of the time. If your system has low +limits for the number of open file descriptors, and you can't raise them, you +should use the `file_system` strategy. + +### File system - `file_system` + +This strategy will use file names given to `shm_open` to identify the shared +memory regions. This has a benefit of not requiring the implementation to cache +the file descriptors obtained from it, but at the same time is prone to shared +memory leaks. The file can't be deleted right after its creation, because other +processes need to access it to open their views. If the processes fatally +crash, or are killed, and don't call the storage destructors, the files will +remain in the system. This is very serious, because they keep using up the +memory until the system is restarted, or they're freed manually. + +To counter the problem of shared memory file leaks, {mod}`torch.multiprocessing` +will spawn a daemon named `torch_shm_manager` that will isolate itself from +the current process group, and will keep track of all shared memory allocations. +Once all processes connected to it exit, it will wait a moment to ensure there +will be no new connections, and will iterate over all shared memory files +allocated by the group. If it finds that any of them still exist, they will be +deallocated. We've tested this method and it proved to be robust to various +failures. Still, if your system has high enough limits, and `file_descriptor` +is a supported strategy, we do not recommend switching to this one. + +## Spawning subprocesses + +:::{note} +Available for Python >= 3.4. + +This depends on the `spawn` start method in Python's +`multiprocessing` package. +::: + +Spawning a number of subprocesses to perform some function can be done +by creating `Process` instances and calling `join` to wait for +their completion. This approach works fine when dealing with a single +subprocess but presents potential issues when dealing with multiple +processes. + +Namely, joining processes sequentially implies they will terminate +sequentially. If they don't, and the first process does not terminate, +the process termination will go unnoticed. Also, there are no native +facilities for error propagation. + +The `spawn` function below addresses these concerns and takes care +of error propagation, out of order termination, and will actively +terminate processes upon detecting an error in one of them. + +```{eval-rst} +.. automodule:: torch.multiprocessing.spawn +``` + +```{eval-rst} +.. currentmodule:: torch.multiprocessing.spawn +``` + +```{eval-rst} +.. autofunction:: spawn +``` + +```{eval-rst} +.. currentmodule:: torch.multiprocessing + +``` + +```{eval-rst} +.. class:: SpawnContext + + Returned by :func:`~spawn` when called with ``join=False``. + + .. automethod:: join + +``` + +% This module needs to be documented. Adding here in the meantime + +% for tracking purposes + +```{eval-rst} +.. py:module:: torch.multiprocessing.pool +``` + +```{eval-rst} +.. py:module:: torch.multiprocessing.queue +``` + +```{eval-rst} +.. py:module:: torch.multiprocessing.reductions +``` diff --git a/docs/source/multiprocessing.rst b/docs/source/multiprocessing.rst deleted file mode 100644 index 78218c2f7e9a46..00000000000000 --- a/docs/source/multiprocessing.rst +++ /dev/null @@ -1,196 +0,0 @@ -:orphan: - -.. _multiprocessing-doc: - -Multiprocessing package - torch.multiprocessing -=============================================== - -.. automodule:: torch.multiprocessing -.. currentmodule:: torch.multiprocessing - -.. warning:: - - If the main process exits abruptly (e.g. because of an incoming signal), - Python's ``multiprocessing`` sometimes fails to clean up its children. - It's a known caveat, so if you're seeing any resource leaks after - interrupting the interpreter, it probably means that this has just happened - to you. - -Strategy management -------------------- - -.. autofunction:: get_all_sharing_strategies -.. autofunction:: get_sharing_strategy -.. autofunction:: set_sharing_strategy - - -.. _multiprocessing-cuda-sharing-details: - -Sharing CUDA tensors --------------------- - -Sharing CUDA tensors between processes is supported only in Python 3, using -a ``spawn`` or ``forkserver`` start methods. - - -Unlike CPU tensors, the sending process is required to keep the original tensor -as long as the receiving process retains a copy of the tensor. The refcounting is -implemented under the hood but requires users to follow the next best practices. - -.. warning:: - If the consumer process dies abnormally to a fatal signal, the shared tensor - could be forever kept in memory as long as the sending process is running. - - -1. Release memory ASAP in the consumer. - -:: - - ## Good - x = queue.get() - # do somethings with x - del x - -:: - - ## Bad - x = queue.get() - # do somethings with x - # do everything else (producer have to keep x in memory) - -2. Keep producer process running until all consumers exits. This will prevent -the situation when the producer process releasing memory which is still in use -by the consumer. - -:: - - ## producer - # send tensors, do something - event.wait() - - -:: - - ## consumer - # receive tensors and use them - event.set() - -3. Don't pass received tensors. - -:: - - # not going to work - x = queue.get() - queue_2.put(x) - - -:: - - # you need to create a process-local copy - x = queue.get() - x_clone = x.clone() - queue_2.put(x_clone) - - -:: - - # putting and getting from the same queue in the same process will likely end up with segfault - queue.put(tensor) - x = queue.get() - - -Sharing strategies ------------------- - -This section provides a brief overview into how different sharing strategies -work. Note that it applies only to CPU tensor - CUDA tensors will always use -the CUDA API, as that's the only way they can be shared. - -File descriptor - ``file_descriptor`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - - -.. note:: - - This is the default strategy (except for macOS and OS X where it's not - supported). - -This strategy will use file descriptors as shared memory handles. Whenever a -storage is moved to shared memory, a file descriptor obtained from ``shm_open`` -is cached with the object, and when it's going to be sent to other processes, -the file descriptor will be transferred (e.g. via UNIX sockets) to it. The -receiver will also cache the file descriptor and ``mmap`` it, to obtain a shared -view onto the storage data. - -Note that if there will be a lot of tensors shared, this strategy will keep a -large number of file descriptors open most of the time. If your system has low -limits for the number of open file descriptors, and you can't raise them, you -should use the ``file_system`` strategy. - -File system - ``file_system`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This strategy will use file names given to ``shm_open`` to identify the shared -memory regions. This has a benefit of not requiring the implementation to cache -the file descriptors obtained from it, but at the same time is prone to shared -memory leaks. The file can't be deleted right after its creation, because other -processes need to access it to open their views. If the processes fatally -crash, or are killed, and don't call the storage destructors, the files will -remain in the system. This is very serious, because they keep using up the -memory until the system is restarted, or they're freed manually. - -To counter the problem of shared memory file leaks, :mod:`torch.multiprocessing` -will spawn a daemon named ``torch_shm_manager`` that will isolate itself from -the current process group, and will keep track of all shared memory allocations. -Once all processes connected to it exit, it will wait a moment to ensure there -will be no new connections, and will iterate over all shared memory files -allocated by the group. If it finds that any of them still exist, they will be -deallocated. We've tested this method and it proved to be robust to various -failures. Still, if your system has high enough limits, and ``file_descriptor`` -is a supported strategy, we do not recommend switching to this one. - -Spawning subprocesses ---------------------- - -.. note:: - - Available for Python >= 3.4. - - This depends on the ``spawn`` start method in Python's - ``multiprocessing`` package. - -Spawning a number of subprocesses to perform some function can be done -by creating ``Process`` instances and calling ``join`` to wait for -their completion. This approach works fine when dealing with a single -subprocess but presents potential issues when dealing with multiple -processes. - -Namely, joining processes sequentially implies they will terminate -sequentially. If they don't, and the first process does not terminate, -the process termination will go unnoticed. Also, there are no native -facilities for error propagation. - -The ``spawn`` function below addresses these concerns and takes care -of error propagation, out of order termination, and will actively -terminate processes upon detecting an error in one of them. - -.. automodule:: torch.multiprocessing.spawn -.. currentmodule:: torch.multiprocessing.spawn - -.. autofunction:: spawn - -.. currentmodule:: torch.multiprocessing - - -.. class:: SpawnContext - - Returned by :func:`~spawn` when called with ``join=False``. - - .. automethod:: join - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.multiprocessing.pool -.. py:module:: torch.multiprocessing.queue -.. py:module:: torch.multiprocessing.reductions diff --git a/docs/source/name_inference.md b/docs/source/name_inference.md new file mode 100644 index 00000000000000..ab705fa509b184 --- /dev/null +++ b/docs/source/name_inference.md @@ -0,0 +1,483 @@ +```{eval-rst} +.. currentmodule:: torch +``` + +(name_inference_reference-doc)= + +# Named Tensors operator coverage + +Please read {ref}`named_tensors-doc` first for an introduction to named tensors. + +This document is a reference for *name inference*, a process that defines how +named tensors: + +1. use names to provide additional automatic runtime correctness checks +2. propagate names from input tensors to output tensors + +Below is a list of all operations that are supported with named tensors +and their associated name inference rules. + +If you don't see an operation listed here, but it would help your use case, please +[search if an issue has already been filed](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22) and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose). + +:::{warning} +The named tensor API is experimental and subject to change. +::: + +```{eval-rst} +.. csv-table:: Supported Operations + :header: API, Name inference rule + :widths: 20, 20 + + ":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc` + :meth:`Tensor.abs_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc` + :meth:`Tensor.acos_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc` + ":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc` + :meth:`Tensor.align_as`,See documentation + :meth:`Tensor.align_to`,See documentation + ":meth:`Tensor.all`, :func:`torch.all`",None + ":meth:`Tensor.any`, :func:`torch.any`",None + ":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc` + :meth:`Tensor.asin_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.atan_`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc` + :meth:`Tensor.bernoulli_`,None + :meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc` + :meth:`Tensor.bitwise_not_`,None + ":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc` + :meth:`Tensor.bool`,:ref:`keeps_input_names-doc` + :meth:`Tensor.byte`,:ref:`keeps_input_names-doc` + :func:`torch.cat`,:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.cauchy_`,None + ":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc` + :meth:`Tensor.ceil_`,None + :meth:`Tensor.char`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc` + :meth:`Tensor.clamp_`,None + :meth:`Tensor.copy_`,:ref:`out_function_semantics-doc` + ":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc` + :meth:`Tensor.cos_`,None + ":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.cosh_`,None + ":meth:`Tensor.acosh`, :func:`torch.acosh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.acosh_`,None + :meth:`Tensor.cpu`,:ref:`keeps_input_names-doc` + :meth:`Tensor.cuda`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`keeps_input_names-doc` + :meth:`Tensor.data_ptr`,None + ":meth:`Tensor.deg2rad`, :func:`torch.deg2rad`",:ref:`keeps_input_names-doc` + :meth:`Tensor.deg2rad_`,None + ":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc` + :meth:`Tensor.detach_`,None + ":attr:`Tensor.device`, :func:`torch.device`",None + ":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc` + :meth:`Tensor.digamma_`,None + :meth:`Tensor.dim`,None + ":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.dot`, :func:`torch.dot`",None + :meth:`Tensor.double`,:ref:`keeps_input_names-doc` + :meth:`Tensor.element_size`,None + :func:`torch.empty`,:ref:`factory-doc` + :func:`torch.empty_like`,:ref:`factory-doc` + ":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erf_`,None + ":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erfc_`,None + ":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc` + :meth:`Tensor.erfinv_`,None + ":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc` + :meth:`Tensor.exp_`,None + :meth:`Tensor.expand`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc` + :meth:`Tensor.expm1_`,None + :meth:`Tensor.exponential_`,None + :meth:`Tensor.fill_`,None + ":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation + :meth:`Tensor.float`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc` + :meth:`Tensor.floor_`,None + ":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc` + :meth:`Tensor.frac_`,None + ":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.get_device`, :func:`torch.get_device`",None + :attr:`Tensor.grad`,None + ":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.half`,:ref:`keeps_input_names-doc` + :meth:`Tensor.has_names`,See documentation + ":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc` + :meth:`Tensor.index_fill_`,None + :meth:`Tensor.int`,:ref:`keeps_input_names-doc` + :meth:`Tensor.is_contiguous`,None + :attr:`Tensor.is_cuda`,None + ":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None + :attr:`Tensor.is_leaf`,None + :meth:`Tensor.is_pinned`,None + :meth:`Tensor.is_shared`,None + ":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None + :attr:`Tensor.is_sparse`,None + :attr:`Tensor.is_sparse_csr`,None + :func:`torch.is_tensor`,None + :meth:`Tensor.item`,None + :attr:`Tensor.itemsize`,None + ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log10_`,None + ":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log1p_`,None + ":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc` + :meth:`Tensor.log2_`,None + :meth:`Tensor.log_`,None + :meth:`Tensor.log_normal_`,None + ":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc` + :meth:`Tensor.logical_not_`,None + ":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc` + :meth:`Tensor.long`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc` + :func:`torch.manual_seed`,None + ":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc` + :meth:`Tensor.masked_fill_`,None + ":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors + ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` + ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.nanmedian`, :func:`torch.nanmedian`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` + ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc` + :attr:`Tensor.names`,See documentation + ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc` + :attr:`Tensor.nbytes`,None + :attr:`Tensor.ndim`,None + :meth:`Tensor.ndimension`,None + ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc` + :meth:`Tensor.neg_`,None + :func:`torch.normal`,:ref:`keeps_input_names-doc` + :meth:`Tensor.normal_`,None + ":meth:`Tensor.numel`, :func:`torch.numel`",None + :func:`torch.ones`,:ref:`factory-doc` + ":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.pow_`,None + ":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.rad2deg`, :func:`torch.rad2deg`",:ref:`keeps_input_names-doc` + :meth:`Tensor.rad2deg_`,None + :func:`torch.rand`,:ref:`factory-doc` + :func:`torch.rand`,:ref:`factory-doc` + :func:`torch.randn`,:ref:`factory-doc` + :func:`torch.randn`,:ref:`factory-doc` + :meth:`Tensor.random_`,None + ":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc` + :meth:`Tensor.reciprocal_`,None + :meth:`Tensor.refine_names`,See documentation + :meth:`Tensor.register_hook`,None + :meth:`Tensor.register_post_accumulate_grad_hook`,None + :meth:`Tensor.rename`,See documentation + :meth:`Tensor.rename_`,See documentation + :attr:`Tensor.requires_grad`,None + :meth:`Tensor.requires_grad_`,None + :meth:`Tensor.resize_`,Only allow resizes that do not change shape + :meth:`Tensor.resize_as_`,Only allow resizes that do not change shape + ":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc` + :meth:`Tensor.round_`,None + ":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc` + :meth:`Tensor.rsqrt_`,None + ":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc` + :meth:`Tensor.short`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sigmoid_`,None + ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sign_`,None + ":meth:`Tensor.sgn`, :func:`torch.sgn`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sgn_`,None + ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sin_`,None + ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sinh_`,None + ":meth:`Tensor.asinh`, :func:`torch.asinh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.asinh_`,None + :meth:`Tensor.size`,None + ":meth:`Tensor.softmax`, :func:`torch.softmax`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc` + ":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc` + :meth:`Tensor.sqrt_`,None + ":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc` + :func:`torch.std_mean`,:ref:`removes_dimensions-doc` + :meth:`Tensor.stride`,None + ":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc` + :meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc` + ":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc` + :meth:`Tensor.tan_`,None + ":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.tanh_`,None + ":meth:`Tensor.atanh`, :func:`torch.atanh`",:ref:`keeps_input_names-doc` + :meth:`Tensor.atanh_`,None + :func:`torch.tensor`,:ref:`factory-doc` + :meth:`Tensor.to`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc` + ":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc` + ":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc` + :meth:`Tensor.trunc_`,None + :meth:`Tensor.type`,None + :meth:`Tensor.type_as`,:ref:`keeps_input_names-doc` + ":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc` + :meth:`Tensor.unflatten`,See documentation + :meth:`Tensor.uniform_`,None + ":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc` + :func:`torch.var_mean`,:ref:`removes_dimensions-doc` + :meth:`Tensor.zero_`,None + :func:`torch.zeros`,:ref:`factory-doc` + +``` + +(keeps_input_names-doc)= + +## Keeps input names + +All pointwise unary functions follow this rule as well as some other unary functions. + +- Check names: None +- Propagate names: input tensor's names are propagated to the output. + +``` +>>> x = torch.randn(3, 3, names=('N', 'C')) +>>> x.abs().names +('N', 'C') +``` + +(removes_dimensions-doc)= + +## Removes dimensions + +All reduction ops like {meth}`~Tensor.sum` remove dimensions by reducing +over the desired dimensions. Other operations like {meth}`~Tensor.select` and +{meth}`~Tensor.squeeze` remove dimensions. + +Wherever one can pass an integer dimension index to an operator, one can also pass +a dimension name. Functions that take lists of dimension indices can also take in a +list of dimension names. + +- Check names: If {attr}`dim` or {attr}`dims` is passed in as a list of names, + check that those names exist in {attr}`self`. +- Propagate names: If the dimensions of the input tensor specified by {attr}`dim` + or {attr}`dims` are not present in the output tensor, then the corresponding names + of those dimensions do not appear in `output.names`. + +``` +>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W')) +>>> x.squeeze('N').names +('C', 'H', 'W') + +>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) +>>> x.sum(['N', 'C']).names +('H', 'W') + +# Reduction ops with keepdim=True don't actually remove dimensions. +>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) +>>> x.sum(['N', 'C'], keepdim=True).names +('N', 'C', 'H', 'W') +``` + +(unifies_names_from_inputs-doc)= + +## Unifies names from inputs + +All binary arithmetic ops follow this rule. Operations that broadcast still +broadcast positionally from the right to preserve compatibility with unnamed +tensors. To perform explicit broadcasting by names, use {meth}`Tensor.align_as`. + +- Check names: All names must match positionally from the right. i.e., in + `tensor + other`, `match(tensor.names[i], other.names[i])` must be true for all + `i` in `(-min(tensor.dim(), other.dim()) + 1, -1]`. +- Check names: Furthermore, all named dimensions must be aligned from the right. + During matching, if we match a named dimension `A` with an unnamed dimension + `None`, then `A` must not appear in the tensor with the unnamed dimension. +- Propagate names: unify pairs of names from the right from both tensors to + produce output names. + +For example, + +``` +# tensor: Tensor[ N, None] +# other: Tensor[None, C] +>>> tensor = torch.randn(3, 3, names=('N', None)) +>>> other = torch.randn(3, 3, names=(None, 'C')) +>>> (tensor + other).names +('N', 'C') +``` + +Check names: + +- `match(tensor.names[-1], other.names[-1])` is `True` +- `match(tensor.names[-2], tensor.names[-2])` is `True` +- Because we matched `None` in {attr}`tensor` with `'C'`, + check to make sure `'C'` doesn't exist in {attr}`tensor` (it does not). +- Check to make sure `'N'` doesn't exists in {attr}`other` (it does not). + +Finally, the output names are computed with +`[unify('N', None), unify(None, 'C')] = ['N', 'C']` + +More examples: + +``` +# Dimensions don't match from the right: +# tensor: Tensor[N, C] +# other: Tensor[ N] +>>> tensor = torch.randn(3, 3, names=('N', 'C')) +>>> other = torch.randn(3, names=('N',)) +>>> (tensor + other).names +RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims +['N']: dim 'C' and dim 'N' are at the same position from the right but do +not match. + +# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]: +# tensor: Tensor[N, None] +# other: Tensor[ N] +>>> tensor = torch.randn(3, 3, names=('N', None)) +>>> other = torch.randn(3, names=('N',)) +>>> (tensor + other).names +RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and +dims ['N', None]: dim 'N' appears in a different position from the right +across both lists. +``` + +:::{note} +In both of the last examples, it is possible to align the tensors by names +and then perform the addition. Use {meth}`Tensor.align_as` to align +tensors by name or {meth}`Tensor.align_to` to align tensors to a custom +dimension ordering. +::: + +(permutes_dimensions-doc)= + +## Permutes dimensions + +Some operations, like {meth}`Tensor.t()`, permute the order of dimensions. Dimension names +are attached to individual dimensions so they get permuted as well. + +If the operator takes in positional index {attr}`dim`, it is also able to take a dimension +name as {attr}`dim`. + +- Check names: If {attr}`dim` is passed as a name, check that it exists in the tensor. +- Propagate names: Permute dimension names in the same way as the dimensions that are + being permuted. + +``` +>>> x = torch.randn(3, 3, names=('N', 'C')) +>>> x.transpose('N', 'C').names +('C', 'N') +``` + +(contracts_away_dims-doc)= + +## Contracts away dims + +Matrix multiply functions follow some variant of this. Let's go through +{func}`torch.mm` first and then generalize the rule for batch matrix multiplication. + +For `torch.mm(tensor, other)`: + +- Check names: None +- Propagate names: result names are `(tensor.names[-2], other.names[-1])`. + +``` +>>> x = torch.randn(3, 3, names=('N', 'D')) +>>> y = torch.randn(3, 3, names=('in', 'out')) +>>> x.mm(y).names +('N', 'out') +``` + +Inherently, a matrix multiplication performs a dot product over two dimensions, +collapsing them. When two tensors are matrix-multiplied, the contracted dimensions +disappear and do not show up in the output tensor. + +{func}`torch.mv`, {func}`torch.dot` work in a similar way: name inference does not +check input names and removes the dimensions that are involved in the dot product: + +``` +>>> x = torch.randn(3, 3, names=('N', 'D')) +>>> y = torch.randn(3, names=('something',)) +>>> x.mv(y).names +('N',) +``` + +Now, let's take a look at `torch.matmul(tensor, other)`. Assume that `tensor.dim() >= 2` +and `other.dim() >= 2`. + +- Check names: Check that the batch dimensions of the inputs are aligned and broadcastable. + See {ref}`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned. +- Propagate names: result names are obtained by unifying the batch dimensions and removing + the contracted dimensions: + `unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])`. + +Examples: + +``` +# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F']. +# 'A', 'B' are batch dimensions. +>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D')) +>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F')) +>>> torch.matmul(x, y).names +('A', 'B', 'C', 'F') +``` + +Finally, there are fused `add` versions of many matmul functions. i.e., {func}`addmm` +and {func}`addmv`. These are treated as composing name inference for i.e. {func}`mm` and +name inference for {func}`add`. + +(factory-doc)= + +## Factory functions + +Factory functions now take a new {attr}`names` argument that associates a name +with each dimension. + +``` +>>> torch.zeros(2, 3, names=('N', 'C')) +tensor([[0., 0., 0.], + [0., 0., 0.]], names=('N', 'C')) +``` + +(out_function_semantics-doc)= + +## out function and in-place variants + +A tensor specified as an `out=` tensor has the following behavior: + +- If it has no named dimensions, then the names computed from the operation + get propagated to it. +- If it has any named dimensions, then the names computed from the operation + must be exactly equal to the existing names. Otherwise, the operation errors. + +All in-place methods modify inputs to have names equal to the computed names +from name inference. For example: + +``` +>>> x = torch.randn(3, 3) +>>> y = torch.randn(3, 3, names=('N', 'C')) +>>> x.names +(None, None) + +>>> x += y +>>> x.names +('N', 'C') +``` diff --git a/docs/source/name_inference.rst b/docs/source/name_inference.rst deleted file mode 100644 index db6189f06ac519..00000000000000 --- a/docs/source/name_inference.rst +++ /dev/null @@ -1,485 +0,0 @@ -.. currentmodule:: torch - -.. _name_inference_reference-doc: - -Named Tensors operator coverage -=============================== - -Please read :ref:`named_tensors-doc` first for an introduction to named tensors. - -This document is a reference for *name inference*, a process that defines how -named tensors: - -1. use names to provide additional automatic runtime correctness checks -2. propagate names from input tensors to output tensors - -Below is a list of all operations that are supported with named tensors -and their associated name inference rules. - -If you don't see an operation listed here, but it would help your use case, please -`search if an issue has already been filed `_ and if not, `file one `_. - -.. warning:: - The named tensor API is experimental and subject to change. - -.. csv-table:: Supported Operations - :header: API, Name inference rule - :widths: 20, 20 - - ":meth:`Tensor.abs`, :func:`torch.abs`",:ref:`keeps_input_names-doc` - :meth:`Tensor.abs_`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.acos`, :func:`torch.acos`",:ref:`keeps_input_names-doc` - :meth:`Tensor.acos_`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.add`, :func:`torch.add`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.add_`,:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.addmm`, :func:`torch.addmm`",:ref:`contracts_away_dims-doc` - :meth:`Tensor.addmm_`,:ref:`contracts_away_dims-doc` - ":meth:`Tensor.addmv`, :func:`torch.addmv`",:ref:`contracts_away_dims-doc` - :meth:`Tensor.addmv_`,:ref:`contracts_away_dims-doc` - :meth:`Tensor.align_as`,See documentation - :meth:`Tensor.align_to`,See documentation - ":meth:`Tensor.all`, :func:`torch.all`",None - ":meth:`Tensor.any`, :func:`torch.any`",None - ":meth:`Tensor.asin`, :func:`torch.asin`",:ref:`keeps_input_names-doc` - :meth:`Tensor.asin_`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.atan`, :func:`torch.atan`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.atan2`, :func:`torch.atan2`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.atan2_`,:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.atan_`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.bernoulli`, :func:`torch.bernoulli`",:ref:`keeps_input_names-doc` - :meth:`Tensor.bernoulli_`,None - :meth:`Tensor.bfloat16`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.bitwise_not`, :func:`torch.bitwise_not`",:ref:`keeps_input_names-doc` - :meth:`Tensor.bitwise_not_`,None - ":meth:`Tensor.bmm`, :func:`torch.bmm`",:ref:`contracts_away_dims-doc` - :meth:`Tensor.bool`,:ref:`keeps_input_names-doc` - :meth:`Tensor.byte`,:ref:`keeps_input_names-doc` - :func:`torch.cat`,:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.cauchy_`,None - ":meth:`Tensor.ceil`, :func:`torch.ceil`",:ref:`keeps_input_names-doc` - :meth:`Tensor.ceil_`,None - :meth:`Tensor.char`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.chunk`, :func:`torch.chunk`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.clamp`, :func:`torch.clamp`",:ref:`keeps_input_names-doc` - :meth:`Tensor.clamp_`,None - :meth:`Tensor.copy_`,:ref:`out_function_semantics-doc` - ":meth:`Tensor.cos`, :func:`torch.cos`",:ref:`keeps_input_names-doc` - :meth:`Tensor.cos_`,None - ":meth:`Tensor.cosh`, :func:`torch.cosh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.cosh_`,None - ":meth:`Tensor.acosh`, :func:`torch.acosh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.acosh_`,None - :meth:`Tensor.cpu`,:ref:`keeps_input_names-doc` - :meth:`Tensor.cuda`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.cumprod`, :func:`torch.cumprod`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.cumsum`, :func:`torch.cumsum`",:ref:`keeps_input_names-doc` - :meth:`Tensor.data_ptr`,None - ":meth:`Tensor.deg2rad`, :func:`torch.deg2rad`",:ref:`keeps_input_names-doc` - :meth:`Tensor.deg2rad_`,None - ":meth:`Tensor.detach`, :func:`torch.detach`",:ref:`keeps_input_names-doc` - :meth:`Tensor.detach_`,None - ":attr:`Tensor.device`, :func:`torch.device`",None - ":meth:`Tensor.digamma`, :func:`torch.digamma`",:ref:`keeps_input_names-doc` - :meth:`Tensor.digamma_`,None - :meth:`Tensor.dim`,None - ":meth:`Tensor.div`, :func:`torch.div`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.div_`,:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.dot`, :func:`torch.dot`",None - :meth:`Tensor.double`,:ref:`keeps_input_names-doc` - :meth:`Tensor.element_size`,None - :func:`torch.empty`,:ref:`factory-doc` - :func:`torch.empty_like`,:ref:`factory-doc` - ":meth:`Tensor.eq`, :func:`torch.eq`",:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.erf`, :func:`torch.erf`",:ref:`keeps_input_names-doc` - :meth:`Tensor.erf_`,None - ":meth:`Tensor.erfc`, :func:`torch.erfc`",:ref:`keeps_input_names-doc` - :meth:`Tensor.erfc_`,None - ":meth:`Tensor.erfinv`, :func:`torch.erfinv`",:ref:`keeps_input_names-doc` - :meth:`Tensor.erfinv_`,None - ":meth:`Tensor.exp`, :func:`torch.exp`",:ref:`keeps_input_names-doc` - :meth:`Tensor.exp_`,None - :meth:`Tensor.expand`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.expm1`, :func:`torch.expm1`",:ref:`keeps_input_names-doc` - :meth:`Tensor.expm1_`,None - :meth:`Tensor.exponential_`,None - :meth:`Tensor.fill_`,None - ":meth:`Tensor.flatten`, :func:`torch.flatten`",See documentation - :meth:`Tensor.float`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.floor`, :func:`torch.floor`",:ref:`keeps_input_names-doc` - :meth:`Tensor.floor_`,None - ":meth:`Tensor.frac`, :func:`torch.frac`",:ref:`keeps_input_names-doc` - :meth:`Tensor.frac_`,None - ":meth:`Tensor.ge`, :func:`torch.ge`",:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.get_device`, :func:`torch.get_device`",None - :attr:`Tensor.grad`,None - ":meth:`Tensor.gt`, :func:`torch.gt`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.half`,:ref:`keeps_input_names-doc` - :meth:`Tensor.has_names`,See documentation - ":meth:`Tensor.index_fill`, :func:`torch.index_fill`",:ref:`keeps_input_names-doc` - :meth:`Tensor.index_fill_`,None - :meth:`Tensor.int`,:ref:`keeps_input_names-doc` - :meth:`Tensor.is_contiguous`,None - :attr:`Tensor.is_cuda`,None - ":meth:`Tensor.is_floating_point`, :func:`torch.is_floating_point`",None - :attr:`Tensor.is_leaf`,None - :meth:`Tensor.is_pinned`,None - :meth:`Tensor.is_shared`,None - ":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None - :attr:`Tensor.is_sparse`,None - :attr:`Tensor.is_sparse_csr`,None - :func:`torch.is_tensor`,None - :meth:`Tensor.item`,None - :attr:`Tensor.itemsize`,None - ":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.le`, :func:`torch.le`",:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.log`, :func:`torch.log`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.log10`, :func:`torch.log10`",:ref:`keeps_input_names-doc` - :meth:`Tensor.log10_`,None - ":meth:`Tensor.log1p`, :func:`torch.log1p`",:ref:`keeps_input_names-doc` - :meth:`Tensor.log1p_`,None - ":meth:`Tensor.log2`, :func:`torch.log2`",:ref:`keeps_input_names-doc` - :meth:`Tensor.log2_`,None - :meth:`Tensor.log_`,None - :meth:`Tensor.log_normal_`,None - ":meth:`Tensor.logical_not`, :func:`torch.logical_not`",:ref:`keeps_input_names-doc` - :meth:`Tensor.logical_not_`,None - ":meth:`Tensor.logsumexp`, :func:`torch.logsumexp`",:ref:`removes_dimensions-doc` - :meth:`Tensor.long`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.lt`, :func:`torch.lt`",:ref:`unifies_names_from_inputs-doc` - :func:`torch.manual_seed`,None - ":meth:`Tensor.masked_fill`, :func:`torch.masked_fill`",:ref:`keeps_input_names-doc` - :meth:`Tensor.masked_fill_`,None - ":meth:`Tensor.masked_select`, :func:`torch.masked_select`",Aligns mask up to input and then unifies_names_from_input_tensors - ":meth:`Tensor.matmul`, :func:`torch.matmul`",:ref:`contracts_away_dims-doc` - ":meth:`Tensor.mean`, :func:`torch.mean`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.median`, :func:`torch.median`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.nanmedian`, :func:`torch.nanmedian`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.mm`, :func:`torch.mm`",:ref:`contracts_away_dims-doc` - ":meth:`Tensor.mode`, :func:`torch.mode`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.mul`, :func:`torch.mul`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.mul_`,:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.mv`, :func:`torch.mv`",:ref:`contracts_away_dims-doc` - :attr:`Tensor.names`,See documentation - ":meth:`Tensor.narrow`, :func:`torch.narrow`",:ref:`keeps_input_names-doc` - :attr:`Tensor.nbytes`,None - :attr:`Tensor.ndim`,None - :meth:`Tensor.ndimension`,None - ":meth:`Tensor.ne`, :func:`torch.ne`",:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.neg`, :func:`torch.neg`",:ref:`keeps_input_names-doc` - :meth:`Tensor.neg_`,None - :func:`torch.normal`,:ref:`keeps_input_names-doc` - :meth:`Tensor.normal_`,None - ":meth:`Tensor.numel`, :func:`torch.numel`",None - :func:`torch.ones`,:ref:`factory-doc` - ":meth:`Tensor.pow`, :func:`torch.pow`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.pow_`,None - ":meth:`Tensor.prod`, :func:`torch.prod`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.rad2deg`, :func:`torch.rad2deg`",:ref:`keeps_input_names-doc` - :meth:`Tensor.rad2deg_`,None - :func:`torch.rand`,:ref:`factory-doc` - :func:`torch.rand`,:ref:`factory-doc` - :func:`torch.randn`,:ref:`factory-doc` - :func:`torch.randn`,:ref:`factory-doc` - :meth:`Tensor.random_`,None - ":meth:`Tensor.reciprocal`, :func:`torch.reciprocal`",:ref:`keeps_input_names-doc` - :meth:`Tensor.reciprocal_`,None - :meth:`Tensor.refine_names`,See documentation - :meth:`Tensor.register_hook`,None - :meth:`Tensor.register_post_accumulate_grad_hook`,None - :meth:`Tensor.rename`,See documentation - :meth:`Tensor.rename_`,See documentation - :attr:`Tensor.requires_grad`,None - :meth:`Tensor.requires_grad_`,None - :meth:`Tensor.resize_`,Only allow resizes that do not change shape - :meth:`Tensor.resize_as_`,Only allow resizes that do not change shape - ":meth:`Tensor.round`, :func:`torch.round`",:ref:`keeps_input_names-doc` - :meth:`Tensor.round_`,None - ":meth:`Tensor.rsqrt`, :func:`torch.rsqrt`",:ref:`keeps_input_names-doc` - :meth:`Tensor.rsqrt_`,None - ":meth:`Tensor.select`, :func:`torch.select`",:ref:`removes_dimensions-doc` - :meth:`Tensor.short`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.sigmoid`, :func:`torch.sigmoid`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sigmoid_`,None - ":meth:`Tensor.sign`, :func:`torch.sign`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sign_`,None - ":meth:`Tensor.sgn`, :func:`torch.sgn`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sgn_`,None - ":meth:`Tensor.sin`, :func:`torch.sin`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sin_`,None - ":meth:`Tensor.sinh`, :func:`torch.sinh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sinh_`,None - ":meth:`Tensor.asinh`, :func:`torch.asinh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.asinh_`,None - :meth:`Tensor.size`,None - ":meth:`Tensor.softmax`, :func:`torch.softmax`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.split`, :func:`torch.split`",:ref:`keeps_input_names-doc` - ":meth:`Tensor.sqrt`, :func:`torch.sqrt`",:ref:`keeps_input_names-doc` - :meth:`Tensor.sqrt_`,None - ":meth:`Tensor.squeeze`, :func:`torch.squeeze`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.std`, :func:`torch.std`",:ref:`removes_dimensions-doc` - :func:`torch.std_mean`,:ref:`removes_dimensions-doc` - :meth:`Tensor.stride`,None - ":meth:`Tensor.sub`, :func:`torch.sub`",:ref:`unifies_names_from_inputs-doc` - :meth:`Tensor.sub_`,:ref:`unifies_names_from_inputs-doc` - ":meth:`Tensor.sum`, :func:`torch.sum`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.tan`, :func:`torch.tan`",:ref:`keeps_input_names-doc` - :meth:`Tensor.tan_`,None - ":meth:`Tensor.tanh`, :func:`torch.tanh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.tanh_`,None - ":meth:`Tensor.atanh`, :func:`torch.atanh`",:ref:`keeps_input_names-doc` - :meth:`Tensor.atanh_`,None - :func:`torch.tensor`,:ref:`factory-doc` - :meth:`Tensor.to`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.topk`, :func:`torch.topk`",:ref:`removes_dimensions-doc` - ":meth:`Tensor.transpose`, :func:`torch.transpose`",:ref:`permutes_dimensions-doc` - ":meth:`Tensor.trunc`, :func:`torch.trunc`",:ref:`keeps_input_names-doc` - :meth:`Tensor.trunc_`,None - :meth:`Tensor.type`,None - :meth:`Tensor.type_as`,:ref:`keeps_input_names-doc` - ":meth:`Tensor.unbind`, :func:`torch.unbind`",:ref:`removes_dimensions-doc` - :meth:`Tensor.unflatten`,See documentation - :meth:`Tensor.uniform_`,None - ":meth:`Tensor.var`, :func:`torch.var`",:ref:`removes_dimensions-doc` - :func:`torch.var_mean`,:ref:`removes_dimensions-doc` - :meth:`Tensor.zero_`,None - :func:`torch.zeros`,:ref:`factory-doc` - - -.. _keeps_input_names-doc: - -Keeps input names -^^^^^^^^^^^^^^^^^ - -All pointwise unary functions follow this rule as well as some other unary functions. - -- Check names: None -- Propagate names: input tensor's names are propagated to the output. - -:: - - >>> x = torch.randn(3, 3, names=('N', 'C')) - >>> x.abs().names - ('N', 'C') - -.. _removes_dimensions-doc: - -Removes dimensions -^^^^^^^^^^^^^^^^^^ - -All reduction ops like :meth:`~Tensor.sum` remove dimensions by reducing -over the desired dimensions. Other operations like :meth:`~Tensor.select` and -:meth:`~Tensor.squeeze` remove dimensions. - -Wherever one can pass an integer dimension index to an operator, one can also pass -a dimension name. Functions that take lists of dimension indices can also take in a -list of dimension names. - -- Check names: If :attr:`dim` or :attr:`dims` is passed in as a list of names, - check that those names exist in :attr:`self`. -- Propagate names: If the dimensions of the input tensor specified by :attr:`dim` - or :attr:`dims` are not present in the output tensor, then the corresponding names - of those dimensions do not appear in ``output.names``. - -:: - - >>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W')) - >>> x.squeeze('N').names - ('C', 'H', 'W') - - >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) - >>> x.sum(['N', 'C']).names - ('H', 'W') - - # Reduction ops with keepdim=True don't actually remove dimensions. - >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) - >>> x.sum(['N', 'C'], keepdim=True).names - ('N', 'C', 'H', 'W') - - -.. _unifies_names_from_inputs-doc: - -Unifies names from inputs -^^^^^^^^^^^^^^^^^^^^^^^^^ - -All binary arithmetic ops follow this rule. Operations that broadcast still -broadcast positionally from the right to preserve compatibility with unnamed -tensors. To perform explicit broadcasting by names, use :meth:`Tensor.align_as`. - -- Check names: All names must match positionally from the right. i.e., in - ``tensor + other``, ``match(tensor.names[i], other.names[i])`` must be true for all - ``i`` in ``(-min(tensor.dim(), other.dim()) + 1, -1]``. -- Check names: Furthermore, all named dimensions must be aligned from the right. - During matching, if we match a named dimension ``A`` with an unnamed dimension - ``None``, then ``A`` must not appear in the tensor with the unnamed dimension. -- Propagate names: unify pairs of names from the right from both tensors to - produce output names. - -For example, - -:: - - # tensor: Tensor[ N, None] - # other: Tensor[None, C] - >>> tensor = torch.randn(3, 3, names=('N', None)) - >>> other = torch.randn(3, 3, names=(None, 'C')) - >>> (tensor + other).names - ('N', 'C') - -Check names: - -- ``match(tensor.names[-1], other.names[-1])`` is ``True`` -- ``match(tensor.names[-2], tensor.names[-2])`` is ``True`` -- Because we matched ``None`` in :attr:`tensor` with ``'C'``, - check to make sure ``'C'`` doesn't exist in :attr:`tensor` (it does not). -- Check to make sure ``'N'`` doesn't exists in :attr:`other` (it does not). - -Finally, the output names are computed with -``[unify('N', None), unify(None, 'C')] = ['N', 'C']`` - -More examples:: - - # Dimensions don't match from the right: - # tensor: Tensor[N, C] - # other: Tensor[ N] - >>> tensor = torch.randn(3, 3, names=('N', 'C')) - >>> other = torch.randn(3, names=('N',)) - >>> (tensor + other).names - RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims - ['N']: dim 'C' and dim 'N' are at the same position from the right but do - not match. - - # Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]: - # tensor: Tensor[N, None] - # other: Tensor[ N] - >>> tensor = torch.randn(3, 3, names=('N', None)) - >>> other = torch.randn(3, names=('N',)) - >>> (tensor + other).names - RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and - dims ['N', None]: dim 'N' appears in a different position from the right - across both lists. - -.. note:: - - In both of the last examples, it is possible to align the tensors by names - and then perform the addition. Use :meth:`Tensor.align_as` to align - tensors by name or :meth:`Tensor.align_to` to align tensors to a custom - dimension ordering. - -.. _permutes_dimensions-doc: - -Permutes dimensions -^^^^^^^^^^^^^^^^^^^ - -Some operations, like :meth:`Tensor.t()`, permute the order of dimensions. Dimension names -are attached to individual dimensions so they get permuted as well. - -If the operator takes in positional index :attr:`dim`, it is also able to take a dimension -name as :attr:`dim`. - -- Check names: If :attr:`dim` is passed as a name, check that it exists in the tensor. -- Propagate names: Permute dimension names in the same way as the dimensions that are - being permuted. - -:: - - >>> x = torch.randn(3, 3, names=('N', 'C')) - >>> x.transpose('N', 'C').names - ('C', 'N') - -.. _contracts_away_dims-doc: - -Contracts away dims -^^^^^^^^^^^^^^^^^^^ - -Matrix multiply functions follow some variant of this. Let's go through -:func:`torch.mm` first and then generalize the rule for batch matrix multiplication. - -For ``torch.mm(tensor, other)``: - -- Check names: None -- Propagate names: result names are ``(tensor.names[-2], other.names[-1])``. - -:: - - >>> x = torch.randn(3, 3, names=('N', 'D')) - >>> y = torch.randn(3, 3, names=('in', 'out')) - >>> x.mm(y).names - ('N', 'out') - -Inherently, a matrix multiplication performs a dot product over two dimensions, -collapsing them. When two tensors are matrix-multiplied, the contracted dimensions -disappear and do not show up in the output tensor. - -:func:`torch.mv`, :func:`torch.dot` work in a similar way: name inference does not -check input names and removes the dimensions that are involved in the dot product: - -:: - - >>> x = torch.randn(3, 3, names=('N', 'D')) - >>> y = torch.randn(3, names=('something',)) - >>> x.mv(y).names - ('N',) - -Now, let's take a look at ``torch.matmul(tensor, other)``. Assume that ``tensor.dim() >= 2`` -and ``other.dim() >= 2``. - -- Check names: Check that the batch dimensions of the inputs are aligned and broadcastable. - See :ref:`unifies_names_from_inputs-doc` for what it means for the inputs to be aligned. -- Propagate names: result names are obtained by unifying the batch dimensions and removing - the contracted dimensions: - ``unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])``. - -Examples:: - - # Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F']. - # 'A', 'B' are batch dimensions. - >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D')) - >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F')) - >>> torch.matmul(x, y).names - ('A', 'B', 'C', 'F') - - -Finally, there are fused ``add`` versions of many matmul functions. i.e., :func:`addmm` -and :func:`addmv`. These are treated as composing name inference for i.e. :func:`mm` and -name inference for :func:`add`. - -.. _factory-doc: - -Factory functions -^^^^^^^^^^^^^^^^^ - - -Factory functions now take a new :attr:`names` argument that associates a name -with each dimension. - -:: - - >>> torch.zeros(2, 3, names=('N', 'C')) - tensor([[0., 0., 0.], - [0., 0., 0.]], names=('N', 'C')) - -.. _out_function_semantics-doc: - -out function and in-place variants -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -A tensor specified as an ``out=`` tensor has the following behavior: - -- If it has no named dimensions, then the names computed from the operation - get propagated to it. -- If it has any named dimensions, then the names computed from the operation - must be exactly equal to the existing names. Otherwise, the operation errors. - -All in-place methods modify inputs to have names equal to the computed names -from name inference. For example: - -:: - - >>> x = torch.randn(3, 3) - >>> y = torch.randn(3, 3, names=('N', 'C')) - >>> x.names - (None, None) - - >>> x += y - >>> x.names - ('N', 'C') diff --git a/docs/source/named_tensor.md b/docs/source/named_tensor.md new file mode 100644 index 00000000000000..72e895882f17c2 --- /dev/null +++ b/docs/source/named_tensor.md @@ -0,0 +1,316 @@ +```{eval-rst} +.. currentmodule:: torch +``` + +(named_tensors-doc)= + +# Named Tensors + +Named Tensors allow users to give explicit names to tensor dimensions. +In most cases, operations that take dimension parameters will accept +dimension names, avoiding the need to track dimensions by position. +In addition, named tensors use names to automatically check that APIs +are being used correctly at runtime, providing extra safety. Names can +also be used to rearrange dimensions, for example, to support +"broadcasting by name" rather than "broadcasting by position". + + +```{warning} + The named tensor API is a prototype feature and subject to change. +``` + +## Creating named tensors + + +Factory functions now take a new {attr}`names` argument that associates a name +with each dimension. + +``` + >>> torch.zeros(2, 3, names=('N', 'C')) + tensor([[0., 0., 0.], + [0., 0., 0.]], names=('N', 'C')) +``` + +Named dimensions, like regular Tensor dimensions, are ordered. +``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``. + +The following factory functions support named tensors: + +- {func}`torch.empty` +- {func}`torch.rand` +- {func}`torch.randn` +- {func}`torch.ones` +- {func}`torch.tensor` +- {func}`torch.zeros` + +## Named dimensions + +See {attr}`~Tensor.names` for restrictions on tensor names. + +Use {attr}`~Tensor.names` to access the dimension names of a tensor and +{meth}`~Tensor.rename` to rename named dimensions. + +``` + >>> imgs = torch.randn(1, 2, 2, 3 , names=('N', 'C', 'H', 'W')) + >>> imgs.names + ('N', 'C', 'H', 'W') + + >>> renamed_imgs = imgs.rename(H='height', W='width') + >>> renamed_imgs.names + ('N', 'C', 'height', 'width) +``` + +Named tensors can coexist with unnamed tensors; named tensors are instances of +{class}`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named +tensors do not require all dimensions to be named. + +``` + >>> imgs = torch.randn(1, 2, 2, 3 , names=(None, 'C', 'H', 'W')) + >>> imgs.names + (None, 'C', 'H', 'W') +``` + +## Name propagation semantics + +Named tensors use names to automatically check that APIs are being called +correctly at runtime. This occurs in a process called *name inference*. +More formally, name inference consists of the following two steps: + +- **Check names**: an operator may perform automatic checks at runtime that + check that certain dimension names must match. +- **Propagate names**: name inference propagates names to output tensors. + +All operations that support named tensors propagate names. + +``` + >>> x = torch.randn(3, 3, names=('N', 'C')) + >>> x.abs().names + ('N', 'C') +``` + + +(match_semantics-doc)= +### match semantics + + +Two names *match* if they are equal (string equality) or if at least one is ``None``. +Nones are essentially a special "wildcard" name. + +``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs. +It returns the more *specific* of the two names, if they match. If the names do not match, +then it errors. + +```{note} +In practice, when working with named tensors, one should avoid having unnamed +dimensions because their handling can be complicated. It is recommended to lift +all unnamed dimensions to be named dimensions by using {meth}`~Tensor.refine_names`. +``` + +### Basic name inference rules + +Let's see how ``match`` and ``unify`` are used in name inference in the case of +adding two one-dim tensors with no broadcasting. + +``` + x = torch.randn(3, names=('X',)) + y = torch.randn(3) + z = torch.randn(3, names=('Z',)) +``` + +**Check names**: check that the names of the two tensors *match*. + +For the following examples: + +``` + >>> # x + y # match('X', None) is True + >>> # x + z # match('X', 'Z') is False + >>> # x + x # match('X', 'X') is True + + >>> x + z + Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match. +``` + +**Propagate names**: *unify* the names to select which one to propagate. +In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more +specific than ``None``. + +``` + >>> (x + y).names + ('X',) + >>> (x + x).names + ('X',) +``` + +For a comprehensive list of name inference rules, see {ref}`name_inference_reference-doc`. +Here are two common operations that may be useful to go over: + +- Binary arithmetic ops: {ref}`unifies_names_from_inputs-doc` +- Matrix multiplication ops: {ref}`contracts_away_dims-doc` + +## Explicit alignment by names + +Use {meth}`~Tensor.align_as` or {meth}`~Tensor.align_to` to align tensor dimensions +by name to a specified ordering. This is useful for performing "broadcasting by names". + +``` + # This function is agnostic to the dimension ordering of `input`, + # as long as it has a `C` dimension somewhere. + def scale_channels(input, scale): + scale = scale.refine_names('C') + return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names=('C',)) + >>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D') + + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) +``` + +## Manipulating dimensions + +Use {meth}`~Tensor.align_to` to permute large amounts of dimensions without +mentioning all of them as in required by {meth}`~Tensor.permute`. + +``` + >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) + >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') + + # Move the F (dim 5) and E dimension (dim 4) to the front while keeping + # the rest in the same order + >>> tensor.permute(5, 4, 0, 1, 2, 3) + >>> named_tensor.align_to('F', 'E', ...) +``` + +Use {meth}`~Tensor.flatten` and {meth}`~Tensor.unflatten` to flatten and unflatten +dimensions, respectively. These methods are more verbose than {meth}`~Tensor.view` +and {meth}`~Tensor.reshape`, but have more semantic meaning to someone reading the code. + + +``` + >>> imgs = torch.randn(32, 3, 128, 128) + >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') + + >>> flat_imgs = imgs.view(32, -1) + >>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features') + >>> named_flat_imgs.names + ('N', 'features') + + >>> unflattened_named_imgs = named_flat_imgs.unflatten('features', [('C', 3), ('H', 128), ('W', 128)]) + >>> unflattened_named_imgs.names + ('N', 'C', 'H', 'W') +``` + +(named_tensors_autograd-doc)= +## Autograd support + +Autograd currently supports named tensors in a limited manner: autograd ignores +names on all tensors. Gradient computation is still correct but we lose the +safety that names give us. + +``` + >>> x = torch.randn(3, names=('D',)) + >>> weight = torch.randn(3, names=('D',), requires_grad=True) + >>> loss = (x - weight).abs() + >>> grad_loss = torch.randn(3) + >>> loss.backward(grad_loss) + >>> weight.grad # Unnamed for now. Will be named in the future + tensor([-1.8107, -0.6357, 0.0783]) + + >>> weight.grad.zero_() + >>> grad_loss = grad_loss.refine_names('C') + >>> loss = (x - weight).abs() + # Ideally we'd check that the names of loss and grad_loss match but we don't yet. + >>> loss.backward(grad_loss) + >>> weight.grad + tensor([-1.8107, -0.6357, 0.0783]) +``` + +## Currently supported operations and subsystems + +### Operators + +See {ref}`name_inference_reference-doc` for a full list of the supported torch and +tensor operations. We do not yet support the following that is not covered by the link: + +- indexing, advanced indexing. + +For ``torch.nn.functional`` operators, we support the following: + +- {func}`torch.nn.functional.relu` +- {func}`torch.nn.functional.softmax` +- {func}`torch.nn.functional.log_softmax` +- {func}`torch.nn.functional.tanh` +- {func}`torch.nn.functional.sigmoid` +- {func}`torch.nn.functional.dropout` + +### Subsystems + + +Autograd is supported, see {ref}`named_tensors_autograd-doc`. +Because gradients are currently unnamed, optimizers may work but are untested. + +NN modules are currently unsupported. This can lead to the following when calling +modules with named tensor inputs: + +- NN module parameters are unnamed, so outputs may be partially named. +- NN module forward passes have code that don't support named tensors and will + error out appropriately. + +We also do not support the following subsystems, though some may work out +of the box: + +- distributions +- serialization ({func}`torch.load`, {func}`torch.save`) +- multiprocessing +- JIT +- distributed +- ONNX + +If any of these would help your use case, please +[search if an issue has already been filed](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22) +and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose). + +## Named tensor API reference + +In this section please find the documentation for named tensor specific APIs. +For a comprehensive reference for how names are propagated through other PyTorch +operators, see {ref}`name_inference_reference-doc`. + +```{eval-rst} +.. class:: Tensor() + :noindex: + + .. autoattribute:: names + + .. automethod:: rename + + .. automethod:: rename_ + + .. automethod:: refine_names + + .. automethod:: align_as + + .. automethod:: align_to + + .. py:method:: flatten(dims, out_dim) -> Tensor + :noindex: + + Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`. + + All of `dims` must be consecutive in order in the :attr:`self` tensor, + but not necessary contiguous in memory. + + Examples:: + + >>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W')) + >>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features') + >>> flat_imgs.names, flat_imgs.shape + (('N', 'features'), torch.Size([32, 49152])) + + .. warning:: + The named tensor API is experimental and subject to change. +``` \ No newline at end of file diff --git a/docs/source/named_tensor.rst b/docs/source/named_tensor.rst deleted file mode 100644 index 112682e7b26a10..00000000000000 --- a/docs/source/named_tensor.rst +++ /dev/null @@ -1,319 +0,0 @@ -.. currentmodule:: torch - -.. _named_tensors-doc: - -Named Tensors -============= - -Named Tensors allow users to give explicit names to tensor dimensions. -In most cases, operations that take dimension parameters will accept -dimension names, avoiding the need to track dimensions by position. -In addition, named tensors use names to automatically check that APIs -are being used correctly at runtime, providing extra safety. Names can -also be used to rearrange dimensions, for example, to support -"broadcasting by name" rather than "broadcasting by position". - - -.. warning:: - The named tensor API is a prototype feature and subject to change. - -Creating named tensors ----------------------- - -Factory functions now take a new :attr:`names` argument that associates a name -with each dimension. - -:: - - >>> torch.zeros(2, 3, names=('N', 'C')) - tensor([[0., 0., 0.], - [0., 0., 0.]], names=('N', 'C')) - -Named dimensions, like regular Tensor dimensions, are ordered. -``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``. - -The following factory functions support named tensors: - -- :func:`torch.empty` -- :func:`torch.rand` -- :func:`torch.randn` -- :func:`torch.ones` -- :func:`torch.tensor` -- :func:`torch.zeros` - -Named dimensions ----------------- - -See :attr:`~Tensor.names` for restrictions on tensor names. - -Use :attr:`~Tensor.names` to access the dimension names of a tensor and -:meth:`~Tensor.rename` to rename named dimensions. - -:: - - >>> imgs = torch.randn(1, 2, 2, 3 , names=('N', 'C', 'H', 'W')) - >>> imgs.names - ('N', 'C', 'H', 'W') - - >>> renamed_imgs = imgs.rename(H='height', W='width') - >>> renamed_imgs.names - ('N', 'C', 'height', 'width) - - -Named tensors can coexist with unnamed tensors; named tensors are instances of -:class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named -tensors do not require all dimensions to be named. - -:: - - >>> imgs = torch.randn(1, 2, 2, 3 , names=(None, 'C', 'H', 'W')) - >>> imgs.names - (None, 'C', 'H', 'W') - -Name propagation semantics --------------------------- - -Named tensors use names to automatically check that APIs are being called -correctly at runtime. This occurs in a process called *name inference*. -More formally, name inference consists of the following two steps: - -- **Check names**: an operator may perform automatic checks at runtime that - check that certain dimension names must match. -- **Propagate names**: name inference propagates names to output tensors. - -All operations that support named tensors propagate names. - -:: - - >>> x = torch.randn(3, 3, names=('N', 'C')) - >>> x.abs().names - ('N', 'C') - - -.. _match_semantics-doc: - -match semantics -^^^^^^^^^^^^^^^ - -Two names *match* if they are equal (string equality) or if at least one is ``None``. -Nones are essentially a special "wildcard" name. - -``unify(A, B)`` determines which of the names ``A`` and ``B`` to propagate to the outputs. -It returns the more *specific* of the two names, if they match. If the names do not match, -then it errors. - -.. note:: - In practice, when working with named tensors, one should avoid having unnamed - dimensions because their handling can be complicated. It is recommended to lift - all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`. - - -Basic name inference rules -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Let's see how ``match`` and ``unify`` are used in name inference in the case of -adding two one-dim tensors with no broadcasting. - -:: - - x = torch.randn(3, names=('X',)) - y = torch.randn(3) - z = torch.randn(3, names=('Z',)) - -**Check names**: check that the names of the two tensors *match*. - -For the following examples: - -:: - - >>> # x + y # match('X', None) is True - >>> # x + z # match('X', 'Z') is False - >>> # x + x # match('X', 'X') is True - - >>> x + z - Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match. - -**Propagate names**: *unify* the names to select which one to propagate. -In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more -specific than ``None``. - -:: - - >>> (x + y).names - ('X',) - >>> (x + x).names - ('X',) - -For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`. -Here are two common operations that may be useful to go over: - -- Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc` -- Matrix multiplication ops: :ref:`contracts_away_dims-doc` - -Explicit alignment by names ---------------------------- - -Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions -by name to a specified ordering. This is useful for performing "broadcasting by names". - -:: - - # This function is agnostic to the dimension ordering of `input`, - # as long as it has a `C` dimension somewhere. - def scale_channels(input, scale): - scale = scale.refine_names('C') - return input * scale.align_as(input) - - >>> num_channels = 3 - >>> scale = torch.randn(num_channels, names=('C',)) - >>> imgs = torch.rand(3, 3, 3, num_channels, names=('N', 'H', 'W', 'C')) - >>> more_imgs = torch.rand(3, num_channels, 3, 3, names=('N', 'C', 'H', 'W')) - >>> videos = torch.randn(3, num_channels, 3, 3, 3, names=('N', 'C', 'H', 'W', 'D') - - >>> scale_channels(imgs, scale) - >>> scale_channels(more_imgs, scale) - >>> scale_channels(videos, scale) - -Manipulating dimensions ------------------------ - -Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without -mentioning all of them as in required by :meth:`~Tensor.permute`. - -:: - - >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) - >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') - - # Move the F (dim 5) and E dimension (dim 4) to the front while keeping - # the rest in the same order - >>> tensor.permute(5, 4, 0, 1, 2, 3) - >>> named_tensor.align_to('F', 'E', ...) - -Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten -dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view` -and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code. - -:: - - >>> imgs = torch.randn(32, 3, 128, 128) - >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') - - >>> flat_imgs = imgs.view(32, -1) - >>> named_flat_imgs = named_imgs.flatten(['C', 'H', 'W'], 'features') - >>> named_flat_imgs.names - ('N', 'features') - - >>> unflattened_named_imgs = named_flat_imgs.unflatten('features', [('C', 3), ('H', 128), ('W', 128)]) - >>> unflattened_named_imgs.names - ('N', 'C', 'H', 'W') - -.. _named_tensors_autograd-doc: - -Autograd support ----------------- - -Autograd currently supports named tensors in a limited manner: autograd ignores -names on all tensors. Gradient computation is still correct but we lose the -safety that names give us. - -:: - - >>> x = torch.randn(3, names=('D',)) - >>> weight = torch.randn(3, names=('D',), requires_grad=True) - >>> loss = (x - weight).abs() - >>> grad_loss = torch.randn(3) - >>> loss.backward(grad_loss) - >>> weight.grad # Unnamed for now. Will be named in the future - tensor([-1.8107, -0.6357, 0.0783]) - - >>> weight.grad.zero_() - >>> grad_loss = grad_loss.refine_names('C') - >>> loss = (x - weight).abs() - # Ideally we'd check that the names of loss and grad_loss match but we don't yet. - >>> loss.backward(grad_loss) - >>> weight.grad - tensor([-1.8107, -0.6357, 0.0783]) - -Currently supported operations and subsystems ---------------------------------------------- - -Operators -^^^^^^^^^ - -See :ref:`name_inference_reference-doc` for a full list of the supported torch and -tensor operations. We do not yet support the following that is not covered by the link: - -- indexing, advanced indexing. - -For ``torch.nn.functional`` operators, we support the following: - -- :func:`torch.nn.functional.relu` -- :func:`torch.nn.functional.softmax` -- :func:`torch.nn.functional.log_softmax` -- :func:`torch.nn.functional.tanh` -- :func:`torch.nn.functional.sigmoid` -- :func:`torch.nn.functional.dropout` - -Subsystems -^^^^^^^^^^ - -Autograd is supported, see :ref:`named_tensors_autograd-doc`. -Because gradients are currently unnamed, optimizers may work but are untested. - -NN modules are currently unsupported. This can lead to the following when calling -modules with named tensor inputs: - -- NN module parameters are unnamed, so outputs may be partially named. -- NN module forward passes have code that don't support named tensors and will - error out appropriately. - -We also do not support the following subsystems, though some may work out -of the box: - -- distributions -- serialization (:func:`torch.load`, :func:`torch.save`) -- multiprocessing -- JIT -- distributed -- ONNX - -If any of these would help your use case, please -`search if an issue has already been filed `_ -and if not, `file one `_. - -Named tensor API reference --------------------------- - -In this section please find the documentation for named tensor specific APIs. -For a comprehensive reference for how names are propagated through other PyTorch -operators, see :ref:`name_inference_reference-doc`. - -.. class:: Tensor() - :noindex: - - .. autoattribute:: names - .. automethod:: rename - .. automethod:: rename_ - .. automethod:: refine_names - - .. automethod:: align_as - .. automethod:: align_to - - .. py:method:: flatten(dims, out_dim) -> Tensor - :noindex: - - Flattens :attr:`dims` into a single dimension with name :attr:`out_dim`. - - All of `dims` must be consecutive in order in the :attr:`self` tensor, - but not necessary contiguous in memory. - - Examples:: - - >>> imgs = torch.randn(32, 3, 128, 128, names=('N', 'C', 'H', 'W')) - >>> flat_imgs = imgs.flatten(['C', 'H', 'W'], 'features') - >>> flat_imgs.names, flat_imgs.shape - (('N', 'features'), torch.Size([32, 49152])) - - .. warning:: - The named tensor API is experimental and subject to change. diff --git a/docs/source/nested.md b/docs/source/nested.md new file mode 100644 index 00000000000000..99bb2ad67056ee --- /dev/null +++ b/docs/source/nested.md @@ -0,0 +1,523 @@ +# torch.nested + +```{eval-rst} +.. automodule:: torch.nested +``` + +## Introduction + + +```{warning} + The PyTorch API of nested tensors is in prototype stage and will change in the near future. +``` + +Nested tensors allow for ragged-shaped data to be contained within and operated upon as a +single tensor. Such data is stored underneath in an efficient packed representation, while exposing +a standard PyTorch tensor interface for applying operations. + +A common application of nested tensors is for expressing batches of variable-length sequential data +present in various domains, such as varying sentence lengths, image sizes, and audio / video clip +lengths. Traditionally, such data has been handled by padding sequences to that of the max length +within a batch, performing computation on the padded form, and subsequently masking to remove +padding. This is inefficient and error-prone, and nested tensors exist to address these problems. + +The API for calling operations on a nested tensor is no different from that of a regular +``torch.Tensor``, allowing for seamless integration with existing models, with the main +difference being {ref}`construction of the inputs `. + +As this is a prototype feature, the set of {ref}`operations supported ` is +limited, but growing. We welcome issues, feature requests, and contributions. +More information on contributing can be found +[in this Readme](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md). + +(construction)= +## Construction + +```{note} + + There are two forms of nested tensors present within PyTorch, distinguished by layout as + specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``. + We recommend utilizing the ``torch.jagged`` layout whenever possible. While it currently only + supports a single ragged dimension, it has better op coverage, receives active development, and + integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to + nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout. +``` + +Construction is straightforward and involves passing a list of tensors to the +``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout +(AKA an "NJT") supports a single ragged dimension. This constructor will copy the input tensors +into a packed, contiguous block of memory according to the layout described in the `data_layout`_ +section below. + +``` +>>> a, b = torch.arange(3), torch.arange(5) + 3 +>>> a +tensor([0, 1, 2]) +>>> b +tensor([3, 4, 5, 6, 7]) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> print([component for component in nt]) +[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])] +``` + +Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary +along a single dimension. If the dimensionalities of the input components don't match, the +constructor throws an error. +``` +>>> a = torch.randn(50, 128) # 2D tensor +>>> b = torch.randn(2, 50, 128) # 3D tensor +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +... +RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim +``` + +During construction, dtype, device, and whether gradients are required can be chosen via the +usual keyword arguments. + +``` +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True) +>>> print([component for component in nt]) +[tensor([0., 1., 2.], device='cuda:0', + grad_fn=), tensor([3., 4., 5., 6., 7.], device='cuda:0', + grad_fn=)] +``` + +``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed +to the constructor. When this constructor is utilized, gradients will flow through the nested tensor +back into the original components. Note that this constructor still copies the input components into +a packed, contiguous block of memory. + +``` +>>> a = torch.randn(12, 512, requires_grad=True) +>>> b = torch.randn(23, 512, requires_grad=True) +>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt.sum().backward() +>>> a.grad +tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]]) +>>> b.grad +tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]]) +``` + +The above functions all create contiguous NJTs, where a chunk of memory is allocated to store +a packed form of the underlying components (see the `data_layout`_ section below for more +details). + +It is also possible to create a non-contiguous NJT view over a pre-existing dense tensor +with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool +for accomplishing this. + +``` +>>> padded = torch.randn(3, 5, 4) +>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64) +>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged) +>>> nt.shape +torch.Size([3, j1, 4]) +>>> nt.is_contiguous() +False +``` + +Note that the nested tensor acts as a view over the original padded dense tensor, referencing the +same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more +limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT +using ``contiguous()``. + +(data_layout)= +## Data Layout and Shape + +For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of +memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged`` +layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets`` +component delineating batch item boundaries for the ragged dimension. + +![image](_static/img/nested/njt_visual.png) + +It's possible to directly access the underlying NJT components when necessary. + +``` +>>> a = torch.randn(50, 128) # text 1 +>>> b = torch.randn(32, 128) # text 2 +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt.values().shape # note the "packing" of the ragged dimension; no padding needed +torch.Size([82, 128]) +>>> nt.offsets() +tensor([ 0, 50, 82]) +``` + +It can also be useful to construct an NJT from the jagged ``values`` and ``offsets`` +constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves +this purpose. + +``` +>>> values = torch.randn(82, 128) +>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64) +>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets) +``` + +An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The +underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the +example below). + +``` +>>> a = torch.randn(50, 128) +>>> b = torch.randn(32, 128) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt.dim() +3 +>>> nt.shape +torch.Size([2, j1, 128]) +``` + +NJTs must have the same ragged structure to be compatible with each other. For example, to run a +binary operation involving two NJTs, the ragged structures must match (i.e. they must have the +same ragged shape symbol in their shapes). In the details, each symbol corresponds with an exact +``offsets`` tensor, so both NJTs must have the same ``offsets`` tensor to be compatible with +each other. + +``` +>>> a = torch.randn(50, 128) +>>> b = torch.randn(32, 128) +>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt1.offsets() is nt2.offsets() +False +>>> nt3 = nt1 + nt2 +RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128) +``` + +In the above example, even though the conceptual shapes of the two NJTs are the same, they don't +share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not +compatible. We recognize that this behavior is unintuitive and are working hard to relax this +restriction for the beta release of nested tensors. For a workaround, see the +{ref}`Troubleshooting ` section of this document. + +In addition to the ``offsets`` metadata, NJTs can also compute and cache the minimum and maximum +sequence lengths for its components, which can be useful for invoking particular kernels (e.g. SDPA). +There are currently no public APIs for accessing these, but this will change for the beta release. + +(supported operations)= +## Supported Operations + +This section contains a list of common operations over nested tensors that you may find useful. +It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While +a sizeable subset of these are supported for nested tensors today, full support is a large task. +The ideal state for nested tensors is full support of all PyTorch operations that are available +for non-nested tensors. To help us accomplish this, please consider: + +* Requesting particular ops needed for your use case + [here](https://github.com/pytorch/pytorch/issues/118107) to help us prioritize. +* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see + the [Contributions](contributions) section below for details. + +### Viewing nested tensor constituents + +``unbind()`` allows you to retrieve a view of the nested tensor's constituents. + +``` +>>> import torch +>>> a = torch.randn(2, 3) +>>> b = torch.randn(3, 3) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> nt.unbind() +(tensor([[-0.9916, -0.3363, -0.2799], + [-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104, 1.4841], + [ 2.0952, 0.2973, 0.2516], + [ 0.9035, 1.3623, 0.2026]])) +>>> nt.unbind()[0] is not a +True +>>> nt.unbind()[0].mul_(3) +tensor([[ 3.6858, -3.7030, -4.4525], + [-2.3481, 2.0236, 0.1975]]) +>>> nt.unbind() +(tensor([[-2.9747, -1.0089, -0.8396], + [-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841], + [ 2.0952, 0.2973, 0.2516], + [ 0.9035, 1.3623, 0.2026]])) +``` + +Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which +represents the first entry or constituent of the nested tensor. + +#### Conversions to / from padded + +``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified +padding value. The ragged dimension will be padded out to the size of the maximum sequence length. + +``` +>>> import torch +>>> a = torch.randn(2, 3) +>>> b = torch.randn(6, 3) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2) +>>> padded +tensor([[[ 1.6107, 0.5723, 0.3913], + [ 0.0700, -0.4954, 1.8663], + [ 4.2000, 4.2000, 4.2000], + [ 4.2000, 4.2000, 4.2000], + [ 4.2000, 4.2000, 4.2000], + [ 4.2000, 4.2000, 4.2000]], + [[-0.0479, -0.7610, -0.3484], + [ 1.1345, 1.0556, 0.3634], + [-1.7122, -0.5921, 0.0540], + [-0.5506, 0.7608, 2.0606], + [ 1.5658, -1.1934, 0.3041], + [ 0.1483, -1.1284, 0.6957]]]) +``` + +This can be useful as an escape hatch to work around NJT support gaps, but ideally such +conversions should be avoided when possible for optimal memory usage and performance, as the +more efficient nested tensor layout does not materialize padding. + +The reverse conversion can be accomplished using ``torch.nested.narrow()``, which applies +ragged structure to a given dense tensor to produce an NJT. Note that by default, this operation +does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be +useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired. + +``` +>>> padded = torch.randn(3, 5, 4) +>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64) +>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged) +>>> nt.shape +torch.Size([3, j1, 4]) +>>> nt = nt.contiguous() +>>> nt.shape +torch.Size([3, j2, 4]) +``` + +### Shape manipulations + +Nested tensors support a wide array of operations for shape manipulation, including views. + +``` +>>> a = torch.randn(2, 6) +>>> b = torch.randn(4, 6) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> nt.shape +torch.Size([2, j1, 6]) +>>> nt.unsqueeze(-1).shape +torch.Size([2, j1, 6, 1]) +>>> nt.unflatten(-1, [2, 3]).shape +torch.Size([2, j1, 2, 3]) +>>> torch.cat([nt, nt], dim=2).shape +torch.Size([2, j1, 12]) +>>> torch.stack([nt, nt], dim=2).shape +torch.Size([2, j1, 2, 6]) +>>> nt.transpose(-1, -2).shape +torch.Size([2, 6, j1]) +``` + +### Attention mechanisms + +As variable-length sequences are common inputs to attention mechanisms, nested tensors support +important attention operators +[Scaled Dot Product Attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and +[FlexAttention](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention). +See +[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention) +for usage examples of NJT with SDPA and +[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt) +for usage examples of NJT with FlexAttention. + +(usage_with_torch_compile)= +## Usage with torch.compile + + +NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always +recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and +graph-break-free both when passed as inputs to a compiled function or module OR when +instantiated in-line within the function. + +```{note} + If you're not able to utilize ``torch.compile()`` for your use case, performance and memory + usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be + the case. It is important that the tensors being operated on are large enough so the + performance gains are not outweighed by the overhead of python tensor subclasses. +``` + +``` +>>> import torch +>>> a = torch.randn(2, 3) +>>> b = torch.randn(4, 3) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> def f(x): return x.sin() + 1 +... +>>> compiled_f = torch.compile(f, fullgraph=True) +>>> output = compiled_f(nt) +>>> output.shape +torch.Size([2, j1, 3]) +>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2. +... +>>> compiled_g = torch.compile(g, fullgraph=True) +>>> output2 = compiled_g(nt.values(), nt.offsets()) +>>> output2.shape +torch.Size([2, j1, 3]) +``` + +Note that NJTs support +[Dynamic Shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html) +to avoid unnecessary recompiles with changing ragged structure. + +``` +>>> a = torch.randn(2, 3) +>>> b = torch.randn(4, 3) +>>> c = torch.randn(5, 3) +>>> d = torch.randn(6, 3) +>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged) +>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged) +>>> def f(x): return x.sin() + 1 +... +>>> compiled_f = torch.compile(f, fullgraph=True) +>>> output1 = compiled_f(nt1) +>>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs +``` + +If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a +PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may +be some rough edges at this time. + +(troubleshooting)= +## Troubleshooting + +This section contains common errors that you may run into when utilizing nested tensors, alongside +the reason for these errors and suggestions for how to address them. + +(unimplemented_op)= +### Unimplemented ops + +This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it +today given that there are a couple thousand ops within PyTorch. + +``` + NotImplementedError: aten.view_as_real.default +``` + +The error is straightforward; we haven't gotten around to adding op support for this particular op +yet. If you'd like, you can [contribute](contributions) an implementation yourself OR simply +[request](https://github.com/pytorch/pytorch/issues/118107) that we add support for this op +in a future PyTorch release. + +(ragged_structure_incompatibility)= +### Ragged structure incompatibility + +``` + RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128) +``` + +This error occurs when calling an op that operates over multiple NJTs with incompatible ragged +structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent +in order to have the same symbolic ragged structure symbol (e.g. ``j1``). + +As a workaround for this situation, it is possible to construct NJTs from the ``values`` and +``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they +are considered to have the same ragged structure and are thus compatible. + +``` +>>> a = torch.randn(50, 128) +>>> b = torch.randn(32, 128) +>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets()) +>>> nt3 = nt1 + nt2 +>>> nt3.shape +torch.Size([2, j1, 128]) +``` + +### Data dependent operation within torch.compile + +``` + torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True +``` + +This error occurs when calling an op that does data-dependent operation within torch.compile; this +commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the +output shape. For example: + +``` +>>> a = torch.randn(50, 128) +>>> b = torch.randn(32, 128) +>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) +>>> def f(nt): return nt.chunk(2, dim=0)[0] +... +>>> compiled_f = torch.compile(f, fullgraph=True) +>>> output = compiled_f(nt) +``` + +In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the +NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a +workaround, there are a couple torch.compile flags that can be set: + +``` +>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True +>>> torch._dynamo.config.capture_scalar_outputs = True +``` + +If, after setting these, you still see data-dependent operator errors, please file an issue with +PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of +NJT support may be incomplete. + +(contributions)= +## Contributions + +If you'd like to contribute to nested tensor development, one of the most impactful ways to do +so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally +consists of a couple simple steps: + +1. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``. + The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``. +2. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern + established there for other ops. Use the signature from ``native_functions.yaml`` for schema + validation. + +The most common way to implement an op is to unwrap the NJT into its constituents, redispatch the +op on the underlying ``values`` buffer, and propagate the relevant NJT metadata (including +``offsets``) to a new output NJT. If the output of the op is expected to have a different shape +from the input, new ``offsets``, etc. metadata must be computed. + +When an op is applied over the batch or ragged dimension, these tricks can help quickly get a +working implementation: + +* For *non-batchwise* operation, an ``unbind()``-based fallback should work. +* For operation on the ragged dimension, consider converting to padded dense with a properly-selected + padding value that won't negatively bias the output, running the op, and converting back to NJT. + Within ``torch.compile``, these conversions can be fused to avoid materializing the padded + intermediate. + +(construction_and_conversion)= + +## Detailed Docs for Construction and Conversion Functions +```{eval-rst} +.. currentmodule:: torch.nested +``` +```{eval-rst} +.. autofunction:: nested_tensor +``` +```{eval-rst} +.. autofunction:: nested_tensor_from_jagged +``` +```{eval-rst} +.. autofunction:: as_nested_tensor +``` +```{eval-rst} +.. autofunction:: to_padded_tensor +``` +```{eval-rst} +.. autofunction:: masked_select +``` +```{eval-rst} +.. autofunction:: narrow +``` +```{eval-rst} +.. seealso:: + + `Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile `_ +``` \ No newline at end of file diff --git a/docs/source/nested.rst b/docs/source/nested.rst deleted file mode 100644 index da0ef129efd8a6..00000000000000 --- a/docs/source/nested.rst +++ /dev/null @@ -1,493 +0,0 @@ -torch.nested -============ - -.. automodule:: torch.nested - -Introduction -++++++++++++ - -.. warning:: - - The PyTorch API of nested tensors is in prototype stage and will change in the near future. - -Nested tensors allow for ragged-shaped data to be contained within and operated upon as a -single tensor. Such data is stored underneath in an efficient packed representation, while exposing -a standard PyTorch tensor interface for applying operations. - -A common application of nested tensors is for expressing batches of variable-length sequential data -present in various domains, such as varying sentence lengths, image sizes, and audio / video clip -lengths. Traditionally, such data has been handled by padding sequences to that of the max length -within a batch, performing computation on the padded form, and subsequently masking to remove -padding. This is inefficient and error-prone, and nested tensors exist to address these problems. - -The API for calling operations on a nested tensor is no different from that of a regular -``torch.Tensor``, allowing for seamless integration with existing models, with the main -difference being :ref:`construction of the inputs `. - -As this is a prototype feature, the set of :ref:`operations supported ` is -limited, but growing. We welcome issues, feature requests, and contributions. -More information on contributing can be found -`in this Readme `_. - -.. _construction: - -Construction -++++++++++++ - -.. note:: - - There are two forms of nested tensors present within PyTorch, distinguished by layout as - specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``. - We recommend utilizing the ``torch.jagged`` layout whenever possible. While it currently only - supports a single ragged dimension, it has better op coverage, receives active development, and - integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to - nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout. - -Construction is straightforward and involves passing a list of tensors to the -``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout -(AKA an "NJT") supports a single ragged dimension. This constructor will copy the input tensors -into a packed, contiguous block of memory according to the layout described in the `data_layout`_ -section below. - ->>> a, b = torch.arange(3), torch.arange(5) + 3 ->>> a -tensor([0, 1, 2]) ->>> b -tensor([3, 4, 5, 6, 7]) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> print([component for component in nt]) -[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])] - -Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary -along a single dimension. If the dimensionalities of the input components don't match, the -constructor throws an error. - ->>> a = torch.randn(50, 128) # 2D tensor ->>> b = torch.randn(2, 50, 128) # 3D tensor ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) -... -RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim - -During construction, dtype, device, and whether gradients are required can be chosen via the -usual keyword arguments. - ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True) ->>> print([component for component in nt]) -[tensor([0., 1., 2.], device='cuda:0', - grad_fn=), tensor([3., 4., 5., 6., 7.], device='cuda:0', - grad_fn=)] - -``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed -to the constructor. When this constructor is utilized, gradients will flow through the nested tensor -back into the original components. Note that this constructor still copies the input components into -a packed, contiguous block of memory. - ->>> a = torch.randn(12, 512, requires_grad=True) ->>> b = torch.randn(23, 512, requires_grad=True) ->>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt.sum().backward() ->>> a.grad -tensor([[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - ..., - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]]) ->>> b.grad -tensor([[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - ..., - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]]) - -The above functions all create contiguous NJTs, where a chunk of memory is allocated to store -a packed form of the underlying components (see the `data_layout`_ section below for more -details). - -It is also possible to create a non-contiguous NJT view over a pre-existing dense tensor -with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool -for accomplishing this. - ->>> padded = torch.randn(3, 5, 4) ->>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64) ->>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged) ->>> nt.shape -torch.Size([3, j1, 4]) ->>> nt.is_contiguous() -False - -Note that the nested tensor acts as a view over the original padded dense tensor, referencing the -same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more -limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT -using ``contiguous()``. - -.. _data_layout: - -Data Layout and Shape -+++++++++++++++++++++ - -For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of -memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged`` -layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets`` -component delineating batch item boundaries for the ragged dimension. - -.. image:: _static/img/nested/njt_visual.png - -It's possible to directly access the underlying NJT components when necessary. - ->>> a = torch.randn(50, 128) # text 1 ->>> b = torch.randn(32, 128) # text 2 ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt.values().shape # note the "packing" of the ragged dimension; no padding needed -torch.Size([82, 128]) ->>> nt.offsets() -tensor([ 0, 50, 82]) - -It can also be useful to construct an NJT from the jagged ``values`` and ``offsets`` -constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves -this purpose. - ->>> values = torch.randn(82, 128) ->>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64) ->>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets) - -An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The -underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the -example below). - ->>> a = torch.randn(50, 128) ->>> b = torch.randn(32, 128) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt.dim() -3 ->>> nt.shape -torch.Size([2, j1, 128]) - -NJTs must have the same ragged structure to be compatible with each other. For example, to run a -binary operation involving two NJTs, the ragged structures must match (i.e. they must have the -same ragged shape symbol in their shapes). In the details, each symbol corresponds with an exact -``offsets`` tensor, so both NJTs must have the same ``offsets`` tensor to be compatible with -each other. - ->>> a = torch.randn(50, 128) ->>> b = torch.randn(32, 128) ->>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt1.offsets() is nt2.offsets() -False ->>> nt3 = nt1 + nt2 -RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128) - -In the above example, even though the conceptual shapes of the two NJTs are the same, they don't -share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not -compatible. We recognize that this behavior is unintuitive and are working hard to relax this -restriction for the beta release of nested tensors. For a workaround, see the -:ref:`Troubleshooting ` section of this document. - -In addition to the ``offsets`` metadata, NJTs can also compute and cache the minimum and maximum -sequence lengths for its components, which can be useful for invoking particular kernels (e.g. SDPA). -There are currently no public APIs for accessing these, but this will change for the beta release. - -.. _supported operations: - -Supported Operations -++++++++++++++++++++ - -This section contains a list of common operations over nested tensors that you may find useful. -It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While -a sizeable subset of these are supported for nested tensors today, full support is a large task. -The ideal state for nested tensors is full support of all PyTorch operations that are available -for non-nested tensors. To help us accomplish this, please consider: - -* Requesting particular ops needed for your use case - `here `__ to help us prioritize. -* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see - the `Contributions `__ section below for details. - -Viewing nested tensor constituents -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -``unbind()`` allows you to retrieve a view of the nested tensor's constituents. - ->>> import torch ->>> a = torch.randn(2, 3) ->>> b = torch.randn(3, 3) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> nt.unbind() -(tensor([[-0.9916, -0.3363, -0.2799], - [-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104, 1.4841], - [ 2.0952, 0.2973, 0.2516], - [ 0.9035, 1.3623, 0.2026]])) ->>> nt.unbind()[0] is not a -True ->>> nt.unbind()[0].mul_(3) -tensor([[ 3.6858, -3.7030, -4.4525], - [-2.3481, 2.0236, 0.1975]]) ->>> nt.unbind() -(tensor([[-2.9747, -1.0089, -0.8396], - [-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841], - [ 2.0952, 0.2973, 0.2516], - [ 0.9035, 1.3623, 0.2026]])) - -Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which -represents the first entry or constituent of the nested tensor. - -Conversions to / from padded -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified -padding value. The ragged dimension will be padded out to the size of the maximum sequence length. - ->>> import torch ->>> a = torch.randn(2, 3) ->>> b = torch.randn(6, 3) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> padded = torch.nested.to_padded_tensor(nt, padding=4.2) ->>> padded -tensor([[[ 1.6107, 0.5723, 0.3913], - [ 0.0700, -0.4954, 1.8663], - [ 4.2000, 4.2000, 4.2000], - [ 4.2000, 4.2000, 4.2000], - [ 4.2000, 4.2000, 4.2000], - [ 4.2000, 4.2000, 4.2000]], - [[-0.0479, -0.7610, -0.3484], - [ 1.1345, 1.0556, 0.3634], - [-1.7122, -0.5921, 0.0540], - [-0.5506, 0.7608, 2.0606], - [ 1.5658, -1.1934, 0.3041], - [ 0.1483, -1.1284, 0.6957]]]) - -This can be useful as an escape hatch to work around NJT support gaps, but ideally such -conversions should be avoided when possible for optimal memory usage and performance, as the -more efficient nested tensor layout does not materialize padding. - -The reverse conversion can be accomplished using ``torch.nested.narrow()``, which applies -ragged structure to a given dense tensor to produce an NJT. Note that by default, this operation -does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be -useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired. - ->>> padded = torch.randn(3, 5, 4) ->>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64) ->>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged) ->>> nt.shape -torch.Size([3, j1, 4]) ->>> nt = nt.contiguous() ->>> nt.shape -torch.Size([3, j2, 4]) - -Shape manipulations -^^^^^^^^^^^^^^^^^^^ - -Nested tensors support a wide array of operations for shape manipulation, including views. - ->>> a = torch.randn(2, 6) ->>> b = torch.randn(4, 6) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> nt.shape -torch.Size([2, j1, 6]) ->>> nt.unsqueeze(-1).shape -torch.Size([2, j1, 6, 1]) ->>> nt.unflatten(-1, [2, 3]).shape -torch.Size([2, j1, 2, 3]) ->>> torch.cat([nt, nt], dim=2).shape -torch.Size([2, j1, 12]) ->>> torch.stack([nt, nt], dim=2).shape -torch.Size([2, j1, 2, 6]) ->>> nt.transpose(-1, -2).shape -torch.Size([2, 6, j1]) - -Attention mechanisms -^^^^^^^^^^^^^^^^^^^^ - -As variable-length sequences are common inputs to attention mechanisms, nested tensors support -important attention operators -`Scaled Dot Product Attention (SDPA) `_ and -`FlexAttention `_. -See -`here `__ -for usage examples of NJT with SDPA and -`here `__ -for usage examples of NJT with FlexAttention. - -.. _usage_with_torch_compile: - -Usage with torch.compile -++++++++++++++++++++++++ - -NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always -recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and -graph-break-free both when passed as inputs to a compiled function or module OR when -instantiated in-line within the function. - -.. note:: - If you're not able to utilize ``torch.compile()`` for your use case, performance and memory - usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be - the case. It is important that the tensors being operated on are large enough so the - performance gains are not outweighed by the overhead of python tensor subclasses. - ->>> import torch ->>> a = torch.randn(2, 3) ->>> b = torch.randn(4, 3) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> def f(x): return x.sin() + 1 -... ->>> compiled_f = torch.compile(f, fullgraph=True) ->>> output = compiled_f(nt) ->>> output.shape -torch.Size([2, j1, 3]) ->>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2. -... ->>> compiled_g = torch.compile(g, fullgraph=True) ->>> output2 = compiled_g(nt.values(), nt.offsets()) ->>> output2.shape -torch.Size([2, j1, 3]) - -Note that NJTs support -`Dynamic Shapes `_ -to avoid unnecessary recompiles with changing ragged structure. - ->>> a = torch.randn(2, 3) ->>> b = torch.randn(4, 3) ->>> c = torch.randn(5, 3) ->>> d = torch.randn(6, 3) ->>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged) ->>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged) ->>> def f(x): return x.sin() + 1 -... ->>> compiled_f = torch.compile(f, fullgraph=True) ->>> output1 = compiled_f(nt1) ->>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs - -If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a -PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may -be some rough edges at this time. - -.. _troubleshooting: - -Troubleshooting -+++++++++++++++ - -This section contains common errors that you may run into when utilizing nested tensors, alongside -the reason for these errors and suggestions for how to address them. - -.. _unimplemented_op: - -Unimplemented ops -^^^^^^^^^^^^^^^^^ - -This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it -today given that there are a couple thousand ops within PyTorch. - -:: - - NotImplementedError: aten.view_as_real.default - -The error is straightforward; we haven't gotten around to adding op support for this particular op -yet. If you'd like, you can `contribute `__ an implementation yourself OR simply -`request `_ that we add support for this op -in a future PyTorch release. - -.. _ragged_structure_incompatibility: - -Ragged structure incompatibility -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128) - -This error occurs when calling an op that operates over multiple NJTs with incompatible ragged -structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent -in order to have the same symbolic ragged structure symbol (e.g. ``j1``). - -As a workaround for this situation, it is possible to construct NJTs from the ``values`` and -``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they -are considered to have the same ragged structure and are thus compatible. - ->>> a = torch.randn(50, 128) ->>> b = torch.randn(32, 128) ->>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets()) ->>> nt3 = nt1 + nt2 ->>> nt3.shape -torch.Size([2, j1, 128]) - -Data dependent operation within torch.compile -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:: - - torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True - -This error occurs when calling an op that does data-dependent operation within torch.compile; this -commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the -output shape. For example: - ->>> a = torch.randn(50, 128) ->>> b = torch.randn(32, 128) ->>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32) ->>> def f(nt): return nt.chunk(2, dim=0)[0] -... ->>> compiled_f = torch.compile(f, fullgraph=True) ->>> output = compiled_f(nt) - -In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the -NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a -workaround, there are a couple torch.compile flags that can be set: - ->>> torch._dynamo.config.capture_dynamic_output_shape_ops = True ->>> torch._dynamo.config.capture_scalar_outputs = True - -If, after setting these, you still see data-dependent operator errors, please file an issue with -PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of -NJT support may be incomplete. - -.. _contributions: - -Contributions -+++++++++++++ - -If you'd like to contribute to nested tensor development, one of the most impactful ways to do -so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally -consists of a couple simple steps: - -#. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``. - The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``. -#. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern - established there for other ops. Use the signature from ``native_functions.yaml`` for schema - validation. - -The most common way to implement an op is to unwrap the NJT into its constituents, redispatch the -op on the underlying ``values`` buffer, and propagate the relevant NJT metadata (including -``offsets``) to a new output NJT. If the output of the op is expected to have a different shape -from the input, new ``offsets``, etc. metadata must be computed. - -When an op is applied over the batch or ragged dimension, these tricks can help quickly get a -working implementation: - -* For *non-batchwise* operation, an ``unbind()``-based fallback should work. -* For operation on the ragged dimension, consider converting to padded dense with a properly-selected - padding value that won't negatively bias the output, running the op, and converting back to NJT. - Within ``torch.compile``, these conversions can be fused to avoid materializing the padded - intermediate. - -.. _construction_and_conversion: - -Detailed Docs for Construction and Conversion Functions -+++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -.. currentmodule:: torch.nested - -.. autofunction:: nested_tensor -.. autofunction:: nested_tensor_from_jagged -.. autofunction:: as_nested_tensor -.. autofunction:: to_padded_tensor -.. autofunction:: masked_select -.. autofunction:: narrow - -.. seealso:: - - `Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile `_ diff --git a/docs/source/nn.attention.bias.md b/docs/source/nn.attention.bias.md new file mode 100644 index 00000000000000..2a373e429b9429 --- /dev/null +++ b/docs/source/nn.attention.bias.md @@ -0,0 +1,30 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` +# torch.nn.attention.bias + +```{eval-rst} +.. automodule:: torch.nn.attention.bias +.. currentmodule:: torch.nn.attention.bias +``` + +## CausalBias + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classnoinheritance.rst + + CausalBias +``` +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + causal_lower_right + causal_upper_left + CausalVariant +``` \ No newline at end of file diff --git a/docs/source/nn.attention.bias.rst b/docs/source/nn.attention.bias.rst deleted file mode 100644 index 200f0e09e43b5a..00000000000000 --- a/docs/source/nn.attention.bias.rst +++ /dev/null @@ -1,27 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.nn.attention.bias -======================== - -.. automodule:: torch.nn.attention.bias -.. currentmodule:: torch.nn.attention.bias - -CausalBias -========== - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classnoinheritance.rst - - CausalBias - - -.. autosummary:: - :toctree: generated - :nosignatures: - - causal_lower_right - causal_upper_left - CausalVariant diff --git a/docs/source/nn.attention.experimental.md b/docs/source/nn.attention.experimental.md new file mode 100644 index 00000000000000..0ecf3b1097cd9f --- /dev/null +++ b/docs/source/nn.attention.experimental.md @@ -0,0 +1,12 @@ +# torch.nn.attention.experimental + +```{eval-rst} +.. currentmodule:: torch.nn.attention.experimental +``` +```{eval-rst} +.. py:module:: torch.nn.attention.experimental +``` + +```{warning} + These APIs are experimental and subject to change without notice. +``` \ No newline at end of file diff --git a/docs/source/nn.attention.experimental.rst b/docs/source/nn.attention.experimental.rst deleted file mode 100644 index d09f12a6735a85..00000000000000 --- a/docs/source/nn.attention.experimental.rst +++ /dev/null @@ -1,7 +0,0 @@ -torch.nn.attention.experimental -=============================== -.. currentmodule:: torch.nn.attention.experimental -.. py:module:: torch.nn.attention.experimental - -.. warning:: - These APIs are experimental and subject to change without notice. diff --git a/docs/source/nn.attention.flex_attention.md b/docs/source/nn.attention.flex_attention.md new file mode 100644 index 00000000000000..fdbaff6c7b3ef5 --- /dev/null +++ b/docs/source/nn.attention.flex_attention.md @@ -0,0 +1,45 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch.nn.attention.flex_attention + +```{eval-rst} +.. currentmodule:: torch.nn.attention.flex_attention +``` +```{eval-rst} +.. py:module:: torch.nn.attention.flex_attention +``` +```{eval-rst} +.. autofunction:: flex_attention +``` + +## BlockMask Utilities + +```{eval-rst} +.. autofunction:: create_block_mask +``` +```{eval-rst} +.. autofunction:: create_mask +``` +```{eval-rst} +.. autofunction:: create_nested_block_mask +``` +```{eval-rst} +.. autofunction:: and_masks +``` +```{eval-rst} +.. autofunction:: or_masks +``` +```{eval-rst} +.. autofunction:: noop_mask +``` + +## BlockMask + +```{eval-rst} +.. autoclass:: BlockMask + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/source/nn.attention.flex_attention.rst b/docs/source/nn.attention.flex_attention.rst deleted file mode 100644 index 93220ec1f213e9..00000000000000 --- a/docs/source/nn.attention.flex_attention.rst +++ /dev/null @@ -1,27 +0,0 @@ -.. role:: hidden - :class: hidden-section - -====================================== -torch.nn.attention.flex_attention -====================================== - -.. currentmodule:: torch.nn.attention.flex_attention -.. py:module:: torch.nn.attention.flex_attention -.. autofunction:: flex_attention - -BlockMask Utilities -------------------- - -.. autofunction:: create_block_mask -.. autofunction:: create_mask -.. autofunction:: create_nested_block_mask -.. autofunction:: and_masks -.. autofunction:: or_masks -.. autofunction:: noop_mask - -BlockMask ---------- - -.. autoclass:: BlockMask - :members: - :undoc-members: diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index f5a17e21cf7485..a77cf6708bb649 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -105,6 +105,59 @@ To try and reduce the impact of functions that are non-differentiable, we define .. _mathematical function: https://en.wikipedia.org/wiki/Function_%28mathematics%29 +Division by Zero in Autograd +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When performing division by zero in PyTorch (e.g., ``x / 0``), the forward pass will produce ``inf`` values following IEEE-754 floating point arithmetic. While these ``inf`` values can be masked out before computing the final loss (e.g., via indexing or masking), the autograd system still tracks and differentiates through the full computation graph, including the division by zero operation. + +During backpropagation, this can lead to problematic gradient expressions. For example: + +.. code:: + + x = torch.tensor([1., 1.], requires_grad=True) + div = torch.tensor([0., 1.]) + + y = x / div # Results in [inf, 1] + mask = div != 0 # [False, True] + loss = y[mask].sum() + loss.backward() + print(x.grad) # [nan, 1], not [0, 1] + +In this example, even though we only use the masked output (which excludes the division by zero), autograd still computes gradients through the full computation graph, including the division by zero operation. This results in ``nan`` gradients for the masked elements, which can cause training instability. + +To avoid this issue, there are several recommended approaches: + +1. Mask before division: + +.. code:: + + x = torch.tensor([1., 1.], requires_grad=True) + div = torch.tensor([0., 1.]) + + mask = div != 0 + safe = torch.zeros_like(x) + safe[mask] = x[mask] / div[mask] + loss = safe.sum() + loss.backward() # Produces safe gradients [0, 1] + +2. Use MaskedTensor (experimental API): + +.. code:: + + from torch.masked import as_masked_tensor + + x = torch.tensor([1., 1.], requires_grad=True) + div = torch.tensor([0., 1.]) + + y = x / div + mask = div != 0 + loss = as_masked_tensor(y, mask).sum() + loss.backward() # Cleanly handles "undefined" vs "zero" gradients + +The key principle is to prevent the division by zero operation from being recorded in the computation graph, rather than masking its results after the fact. This ensures that autograd only computes gradients through valid operations. + +This behavior is important to keep in mind when working with operations that might produce ``inf`` or ``nan`` values, as masking the outputs does not prevent the problematic gradients from being computed. + .. _locally-disable-grad-doc: Locally disabling gradient computation @@ -214,14 +267,14 @@ In other words, computations in no-grad mode are never recorded in the backward even if there are inputs that have ``require_grad=True``. Enable no-grad mode when you need to perform operations that should not be -recorded by autograd, but you’d still like to use the outputs of these +recorded by autograd, but you'd still like to use the outputs of these computations in grad mode later. This context manager makes it convenient to disable gradients for a block of code or function without having to temporarily set tensors to have ``requires_grad=False``, and then back to ``True``. For example, no-grad mode might be useful when writing an optimizer: when -performing the training update you’d like to update parameters +performing the training update you'd like to update parameters in-place without the update being recorded by autograd. You also intend to use the updated parameters for computations in grad mode in the next forward pass. @@ -241,13 +294,13 @@ will not be able to be used in computations to be recorded by autograd after exiting inference mode. Enable inference mode when you are performing computations that do not have -interactions with autograd, AND you don’t plan on using the tensors created +interactions with autograd, AND you don't plan on using the tensors created in inference mode in any computation that is to be recorded by autograd later. It is recommended that you try out inference mode in the parts of your code that do not require autograd tracking (e.g., data processing and model evaluation). If it works out of the box -for your use case it’s a free performance win. If you run into errors after +for your use case it's a free performance win. If you run into errors after enabling inference mode, check that you are not using tensors created in inference mode in computations that are recorded by autograd after exiting inference mode. If you cannot avoid such use in your case, you can always switch back @@ -278,7 +331,7 @@ BatchNorm running statistics on validation data. It is recommended that you always use ``model.train()`` when training and ``model.eval()`` when evaluating your model (validation/testing) even -if you aren’t sure your model has training-mode specific behavior, because a +if you aren't sure your model has training-mode specific behavior, because a module you are using might be updated to behave differently in training and eval modes. @@ -470,10 +523,10 @@ Wirtinger Calculus comes into the picture ... ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ So, we have this great theory of complex differentiability and -holomorphic functions, and we can’t use any of it at all, because many -of the commonly used functions are not holomorphic. What’s a poor +holomorphic functions, and we can't use any of it at all, because many +of the commonly used functions are not holomorphic. What's a poor mathematician to do? Well, Wirtinger observed that even if :math:`f(z)` -isn’t holomorphic, one could rewrite it as a two variable function +isn't holomorphic, one could rewrite it as a two variable function :math:`f(z, z*)` which is always holomorphic. This is because real and imaginary of the components of :math:`z` can be expressed in terms of :math:`z` and :math:`z^*` as: @@ -516,7 +569,7 @@ There are a lot of beautiful consequences of this change. - For one, the Cauchy-Riemann equations translate into simply saying that :math:`\frac{\partial f}{\partial z^*} = 0` (that is to say, the function :math:`f` can be written entirely in terms of :math:`z`, without making reference to :math:`z^*`). -- Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should +- Another important (and somewhat counterintuitive) result, as we'll see later, is that when we do optimization on a real-valued loss, the step we should take while making variable update is given by :math:`\frac{\partial Loss}{\partial z^*}` (not :math:`\frac{\partial Loss}{\partial z}`). For more reading, check out: https://arxiv.org/pdf/0906.4835.pdf @@ -557,7 +610,7 @@ How does PyTorch compute the conjugate Wirtinger derivative? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Typically, our derivative formulas take in `grad_output` as an input, -representing the incoming Vector-Jacobian product that we’ve already +representing the incoming Vector-Jacobian product that we've already computed, aka, :math:`\frac{\partial L}{\partial s^*}`, where :math:`L` is the loss of the entire computation (producing a real loss) and :math:`s` is the output of our function. The goal here is to compute @@ -569,10 +622,10 @@ have access to :math:`\frac{\partial L}{\partial s}`. If you want to skip this derivation, look at the last equation in this section and then skip to the next section. -Let’s continue working with :math:`f: ℂ → ℂ` defined as +Let's continue working with :math:`f: ℂ → ℂ` defined as :math:`f(z) = f(x+yj) = u(x, y) + v(x, y)j`. As discussed above, -autograd’s gradient convention is centered around optimization for real -valued loss functions, so let’s assume :math:`f` is a part of larger +autograd's gradient convention is centered around optimization for real +valued loss functions, so let's assume :math:`f` is a part of larger real valued loss function :math:`g`. Using chain rule, we can write: .. math:: diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index e8197ab93c69ba..ba2c0f47f26c54 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -133,6 +133,44 @@ To toggle the TF32 flags off in C++, you can do at::globalContext().setAllowTF32CuBLAS(false); at::globalContext().setAllowTF32CuDNN(false); +After Pytorch 2.7, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way. +We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. + +.. code:: python + + torch.backends.fp32_precision = "ieee" + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "tf32" + torch.backends.cudnn.rnn.fp32_precision = "tf32" + +The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`. +`ieee` fp32_precision indicate that we will use `FP32` as internal computation precision. +`tf32` fp32_precision indicate that we will allow to use `TF32` as internal computation precision. + +We can override a generic setting for a specific operator if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.cudnn.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +We can also override a generic setting for a specific backend if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +For above 2 cases, both `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision` +is overridden to `ieee`. + +Old settings are still supported. But we suggest to use the new settings for better control. And we do not support +to use mix of old and new settings. + For more information about TF32, see: - `TensorFloat-32`_ @@ -481,7 +519,7 @@ Available options: the native CUDACachingAllocator, the sizes are rounded up in multiple of blocks size of 512, so this works fine for smaller sizes. However, this can be inefficient for large near-by allocations as each will go to different - size of blocks and re-use of those blocks are minimized. This might create + size of blocks and reuse of those blocks are minimized. This might create lots of unused blocks and will waste GPU memory capacity. This option enables the rounding of allocation size to nearest power-2 division. For example, if we need to round-up size of 1200 and if number of divisions is 4, @@ -497,10 +535,10 @@ Available options: ``roundup_power2_divisions`` is only meaningful with ``backend:native``. With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored. * ``max_non_split_rounding_mb`` will allow non-split blocks for better reuse, eg, - a 1024MB cached block can be re-used for a 512MB allocation request. In the default + a 1024MB cached block can be reused for a 512MB allocation request. In the default case, we only allow up to 20MB of rounding of non-split blocks, so a 512MB block can only be served with between 512-532 MB size block. If we set the value of this - option to 1024, it will alow 512-1536 MB size blocks to be used for a 512MB block + option to 1024, it will allow 512-1536 MB size blocks to be used for a 512MB block which increases reuse of larger blocks. This will also help in reducing the stalls in avoiding expensive cudaMalloc calls. * ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to @@ -511,6 +549,8 @@ Available options: 80% of the total memory allocated to the GPU application). The algorithm prefers to free old & unused blocks first to avoid freeing blocks that are actively being reused. The threshold value should be between greater than 0.0 and less than 1.0. + The default value is set at 1.0. + ``garbage_collection_threshold`` is only meaningful with ``backend:native``. With ``backend:cudaMallocAsync``, ``garbage_collection_threshold`` is ignored. * ``expandable_segments`` (experimental, default: `False`) If set to `True`, this setting instructs @@ -546,20 +586,20 @@ Available options: appended to the end of the segment. This process does not create as many slivers of unusable memory, so it is more likely to succeed at finding this memory. - `pinned_use_cuda_host_register` option is a boolean flag that determines whether to +* `pinned_use_cuda_host_register` option is a boolean flag that determines whether to use the CUDA API's cudaHostRegister function for allocating pinned memory instead of the default cudaHostAlloc. When set to True, the memory is allocated using regular malloc and then pages are mapped to the memory before calling cudaHostRegister. This pre-mapping of pages helps reduce the lock time during the execution of cudaHostRegister. - `pinned_num_register_threads` option is only valid when pinned_use_cuda_host_register +* `pinned_num_register_threads` option is only valid when pinned_use_cuda_host_register is set to True. By default, one thread is used to map the pages. This option allows using more threads to parallelize the page mapping operations to reduce the overall allocation time of pinned memory. A good value for this option is 8 based on benchmarking results. - `pinned_use_background_threads` option is a boolean flag to enable background thread +* `pinned_use_background_threads` option is a boolean flag to enable background thread for processing events. This avoids any slow path associated with querying/processing of events in the fast allocation path. This feature is disabled by default. @@ -823,7 +863,7 @@ APIs can be used for debugging purposes: out_2 = torch.randn(nelem_1mb, device="cuda") # pool now should have 2 segments since the CUDACachingAllocator had - # to make a new 2 MB buffer to accomodate out_2 + # to make a new 2 MB buffer to accommodate out_2 assert len(pool.snapshot()) == 2 diff --git a/docs/source/notes/fsdp.rst b/docs/source/notes/fsdp.rst index 5a3c1ab377ab5e..ce713fc1697f77 100644 --- a/docs/source/notes/fsdp.rst +++ b/docs/source/notes/fsdp.rst @@ -96,7 +96,7 @@ First, let's cover the buffers allocated for communications: ``forward`` currently requires 2x all-gather buffer size. Here is why: As explained in :ref:`fsdp_prefetch` in the case of explicit ``forward`` prefetching -(``forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 +(``forward_prefetch=True``) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward`` while the other is used to do the prefetching. While the implicit ``forward`` prefetching (``forward_prefetch=False``, default) case of the same sequence in theory should need only 1 buffer, in reality it's still 2x all-gather-sized buffers. The reason is that in the flat-parameter FSDP design, we do not copy-out of the all-gather buffer. The parameters used for compute are directly viewed into the all-gather buffer (in fact, the main benefit of the "flat parameter" is exactly this reason). In that case, while 'layer 1 all-gather' is overlapping with 'layer 0 forward compute', the 'layer 0 forward compute' is using the parameters viewed into the 'layer 0 all-gather' buffer. diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index d5f140a3db0b34..5ca51833f02567 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -21,20 +21,20 @@ For Intel Data Center GPU For Intel Client GPU -+-------------------------------------+----------------------------------------------------------------------------------------------+ -| Supported OS | Validated Hardware | -+=====================================+==============================================================================================+ -|| Windows 10/11 & Ubuntu 24.10 || Intel® Arc A-Series Graphics (CodeName: Alchemist) | -|| || Intel® Arc B-Series Graphics (CodeName: Battlemage) | -|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | -|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | -|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | -+-------------------------------------+----------------------------------------------------------------------------------------------+ -|| Ubuntu 24.04 & WSL2 (Ubuntu 24.04) || Intel® Arc A-Series Graphics (CodeName: Alchemist) | -|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | -|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | -|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | -+-------------------------------------+----------------------------------------------------------------------------------------------+ ++-------------------------------------+----------------------------------------------------------------------------------------------------+ +| Supported OS | Validated Hardware | ++=====================================+====================================================================================================+ +|| Windows 11 & Ubuntu 24.10 || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Arc B-Series Graphics (CodeName: Battlemage) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake-H) | +|| || Intel® Core™ Ultra Desktop Processors (Series 2) with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Mobile Processors (Series 2) with Intel® Arc™ Graphics (CodeName: Arrow Lake-H)| ++-------------------------------------+----------------------------------------------------------------------------------------------------+ +|| Ubuntu 24.04 & WSL2 (Ubuntu 24.04) || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake-H) | +|| || Intel® Core™ Ultra Desktop Processors (Series 2) with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Mobile Processors (Series 2) with Intel® Arc™ Graphics (CodeName: Arrow Lake-H)| ++-------------------------------------+----------------------------------------------------------------------------------------------------+ Intel GPUs support (Prototype) is ready from PyTorch* 2.5 for Intel® Client GPUs and Intel® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. @@ -52,7 +52,7 @@ Installation Binaries ^^^^^^^^ -Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. +Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio``. For release wheels @@ -77,7 +77,7 @@ Build from source for ``torch`` refer to `PyTorch Installation Build from source Build from source for ``torchvision`` refer to `Torchvision Installation Build from source `_. -Build from source for ``torchaudio`` refert to `Torchaudio Installation Build from source `_. +Build from source for ``torchaudio`` refer to `Torchaudio Installation Build from source `_. Check availability for Intel GPU -------------------------------- @@ -87,7 +87,7 @@ To check if your Intel GPU is available, you would typically use the following c .. code-block:: import torch - torch.xpu.is_available() # torch.xpu is the API for Intel GPU support + print(torch.xpu.is_available()) # torch.xpu is the API for Intel GPU support If the output is ``False``, double check driver installation for Intel GPUs. @@ -107,7 +107,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples diff --git a/docs/source/notes/gradcheck.rst b/docs/source/notes/gradcheck.rst index 813d2be40998f9..b0c8f9155d9408 100644 --- a/docs/source/notes/gradcheck.rst +++ b/docs/source/notes/gradcheck.rst @@ -67,7 +67,7 @@ If we consider the elementary case of a one-dimensional function (:math:`N = M = \frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps} This formula easily generalizes for multiple outputs (:math:`M \gt 1`) by having :math:`\frac{\partial y}{\partial x}` be a column vector of size :math:`M \times 1` like :math:`f(x + eps)`. -In that case, the above formula can be re-used as-is and approximates the full Jacobian matrix with only two evaluations of the user function (namely :math:`f(x + eps)` and :math:`f(x - eps)`). +In that case, the above formula can be reused as-is and approximates the full Jacobian matrix with only two evaluations of the user function (namely :math:`f(x + eps)` and :math:`f(x - eps)`). It is more computationally expensive to handle the case with multiple inputs (:math:`N \gt 1`). In this scenario, we loop over all the inputs one after the other and apply the :math:`eps` perturbation for each element of :math:`x` one after the other. This allows us to reconstruct the :math:`J_f` matrix column by column. diff --git a/docs/source/notes/libtorch_stable_abi.md b/docs/source/notes/libtorch_stable_abi.md index 188cdf96dcbd65..73b83c72595974 100644 --- a/docs/source/notes/libtorch_stable_abi.md +++ b/docs/source/notes/libtorch_stable_abi.md @@ -30,4 +30,21 @@ This note will eventually contain more details on how to use the APIs in torch/c | ? | ? | c10::SymBool | SymBool | | ? | ? | at::QScheme | QScheme | -Our confidently supported types are the ones in the table that have completed rows. For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. You can work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with aoti_torch_call_dispatcher. +Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability. + +For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only. + +You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`. + + +## How to use stack-based APIs + +`aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues, which correlates with a `torch::jit::stack` of IValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants: + +1. The stack is populated left to right. + a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2. + b. Returns are also populated left to right, e.g., `ret0` will be at index 0 and `ret1` will be at index 1, and so on. + +2. The stack always has ownership of the objects it holds. + a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack. + b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references. diff --git a/docs/source/notes/mkldnn.rst b/docs/source/notes/mkldnn.rst new file mode 100644 index 00000000000000..48ee9ce84c35a3 --- /dev/null +++ b/docs/source/notes/mkldnn.rst @@ -0,0 +1,102 @@ +.. meta:: + :description: A guide to torch.backends.mkldnn, a PyTorch backend to run MKLDNN operations + :keywords: optimize PyTorch, MKLDNN + +.. _mkldnn_backend: + +MKLDNN backend +--------------------------------------------------- + +MKLDNN is an open-source cross-platform performance library of basic building blocks +for deep learning applications. + +.. code:: python + + # The flag below controls whether enable MKLDNN backend in Pytorch. + torch.backends.mkldnn.enabled = True + +Users can disable MKLDNN backend by: + +.. code:: python + + torch.backends.mkldnn.enabled = False + +.. _bf16_on_mkldnn: + +Bfloat16 (BF16) on MKLDNN backend +--------------------------------------------------- + +Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision +for `float32` operators. + +.. code:: python + + # The flag below controls the internal computation precision for mkldnn matmul. Default ieee is float32. + torch.backends.mkldnn.matmul.fp32_precision = "ieee" + + # The flag below controls the internal computation precision for mkldnn conv. Default ieee is float32. + torch.backends.mkldnn.conv.fp32_precision = "ieee" + + # The flag below controls the internal computation precision for mkldnn rnn. Default ieee is float32. + torch.backends.mkldnn.rnn.fp32_precision = "ieee" + +Note that besides matmuls and convolutions themselves, functions and nn modules that internally uses +matmuls or convolutions are also affected. These include :class:`torch.nn.Linear`, :class:`torch.nn._ConvNd`, :func:`torch.cdist`, +:func:`torch.tensordot`, :func:`torch.nn.functional.affine_grid` and :func:`torch.nn.functional.grid_sample`, +:class:`torch.nn.AdaptiveLogSoftmaxWithLoss`, :class:`torch.nn.GRU` and :class:`torch.nn.LSTM`. + +To get an idea of the precision and speed, see the example code and benchmark data (on SPR) below: + +.. code:: python + + torch.manual_seed(0) + a_full = torch.randn(10240, 10240, dtype=torch.double) + b_full = torch.randn(10240, 10240, dtype=torch.double) + ab_full = a_full @ b_full + mean = ab_full.abs().mean() # 80.7451 + + a = a_full.float() + b = b_full.float() + + # Do matmul at BF16 mode. + torch.backends.mkldnn.matmul.fp32_precision = 'bf16' + ab_bf16 = a @ b # expected speedup with BF16 dot-product acceleration + error = (ab_bf16 - ab_full).abs().max() # 1.3704 + relative_error = error / mean # 0.0170 + print(error, relative_error) + + # Do matmul FP32 mode. + torch.backends.mkldnn.matmul.fp32_precision = 'ieee' + ab_fp32 = a @ b + error = (ab_fp32 - ab_full).abs().max() # 0.0003 + relative_error = error / mean # 0.00000317 + print(error, relative_error) + +From the above example, we can see that with BF16, the speed is ~7x faster on SPR, and that +relative error compared to double precision is approximately 2 orders of magnitude larger. +If full FP32 precision is needed, users can disable BF16 by: + +.. code:: python + + torch.backends.mkldnn.matmul.fp32_precision = 'ieee' + torch.backends.mkldnn.conv.fp32_precision = 'ieee' + torch.backends.mkldnn.rnn.fp32_precision = 'ieee' + +To toggle the BF16 flags off in C++, you can do + +.. code:: C++ + + at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul"); + at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv"); + at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn"); + +We can override a generic setting for a specific operator or backend if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.fp32_precision = "bf16" + torch.backends.mkldnn.fp32_precision = "ieee" + torch.backends.mkldnn.matmul.fp32_precision = "ieee" + +For such case, both `torch.backends.mkldnn.fp32_precision` and `torch.backends.mkldnn.matmul.fp32_precision` +is overridden to bf16. diff --git a/docs/source/notes/multiprocessing.rst b/docs/source/notes/multiprocessing.rst index 41b7caa8f4fd47..bfd7a8c0879e80 100644 --- a/docs/source/notes/multiprocessing.rst +++ b/docs/source/notes/multiprocessing.rst @@ -32,7 +32,7 @@ This happens when the accelerator's runtime is not fork safe and is initialized runtime errors in child processes. To prevent such errors: - - Avoid initializing the accelerator in the main process beofre forking child processes. + - Avoid initializing the accelerator in the main process before forking child processes. - Use an alternative process start methods, such as ``spawn`` or ``forkserver``, which ensures a clean initialization of each process. .. _multiprocessing-cuda-note: diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 7834563f26e432..42997694f762be 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -339,172 +339,6 @@ if one does not have access to the ``torch.load`` callsites. if ``weights_only`` was not passed as an argument. -.. _serializing-python-modules: - -Serializing torch.nn.Modules and loading them in C++ ----------------------------------------------------- - -See also: `Tutorial: Loading a TorchScript Model in C++ `_ - -ScriptModules can be serialized as a TorchScript program and loaded -using :func:`torch.jit.load`. -This serialization encodes all the modules’ methods, submodules, parameters, -and attributes, and it allows the serialized program to be loaded in C++ -(i.e. without Python). - -The distinction between :func:`torch.jit.save` and :func:`torch.save` may not -be immediately clear. :func:`torch.save` saves Python objects with pickle. -This is especially useful for prototyping, researching, and training. -:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format -that can be loaded in Python or C++. This is useful when saving and loading C++ -modules or for running modules trained in Python with C++, a common practice -when deploying PyTorch models. - -To script, serialize and load a module in Python: - -:: - - >>> scripted_module = torch.jit.script(MyModule()) - >>> torch.jit.save(scripted_module, 'mymodule.pt') - >>> torch.jit.load('mymodule.pt') - RecursiveScriptModule( original_name=MyModule - (l0): RecursiveScriptModule(original_name=Linear) - (l1): RecursiveScriptModule(original_name=Linear) ) - - -Traced modules can also be saved with :func:`torch.jit.save`, with the caveat -that only the traced code path is serialized. The following example demonstrates -this: - -:: - - # A module with control flow - >>> class ControlFlowModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.l0 = torch.nn.Linear(4, 2) - self.l1 = torch.nn.Linear(2, 1) - - def forward(self, input): - if input.dim() > 1: - return torch.tensor(0) - - out0 = self.l0(input) - out0_relu = torch.nn.functional.relu(out0) - return self.l1(out0_relu) - - >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4)) - >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt') - >>> loaded = torch.jit.load('controlflowmodule_traced.pt') - >>> loaded(torch.randn(2, 4))) - tensor([[-0.1571], [-0.3793]], grad_fn=) - - >>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4)) - >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt') - >>> loaded = torch.jit.load('controlflowmodule_scripted.pt') - >> loaded(torch.randn(2, 4)) - tensor(0) - -The above module has an if statement that is not triggered by the traced inputs, -and so is not part of the traced module and not serialized with it. -The scripted module, however, contains the if statement and is serialized with it. -See the `TorchScript documentation `_ -for more on scripting and tracing. - -Finally, to load the module in C++: - -:: - - >>> torch::jit::script::Module module; - >>> module = torch::jit::load('controlflowmodule_scripted.pt'); - -See the `PyTorch C++ API documentation `_ -for details about how to use PyTorch modules in C++. - -.. _saving-loading-across-versions: - -Saving and loading ScriptModules across PyTorch versions ------------------------------------------------------------ - -The PyTorch Team recommends saving and loading modules with the same version of -PyTorch. Older versions of PyTorch may not support newer modules, and newer -versions may have removed or modified older behavior. These changes are -explicitly described in -PyTorch’s `release notes `_, -and modules relying on functionality that has changed may need to be updated -to continue working properly. In limited cases, detailed below, PyTorch will -preserve the historic behavior of serialized ScriptModules so they do not require -an update. - -torch.div performing integer division -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when -given two integer inputs: - -:: - - # PyTorch 1.5 (and earlier) - >>> a = torch.tensor(5) - >>> b = torch.tensor(3) - >>> a / b - tensor(1) - -In PyTorch 1.7, however, :func:`torch.div` will always perform a true division -of its inputs, just like division in Python 3: - -:: - - # PyTorch 1.7 - >>> a = torch.tensor(5) - >>> b = torch.tensor(3) - >>> a / b - tensor(1.6667) - -The behavior of :func:`torch.div` is preserved in serialized ScriptModules. -That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue -to see :func:`torch.div` perform floor division when given two integer inputs -even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div` -and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of -PyTorch, however, since those earlier versions do not understand the new behavior. - -torch.full always inferring a float dtype -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor, -regardless of the fill value it’s given: - -:: - - # PyTorch 1.5 and earlier - >>> torch.full((3,), 1) # Note the integer fill value... - tensor([1., 1., 1.]) # ...but float tensor! - -In PyTorch 1.7, however, :func:`torch.full` will infer the returned tensor’s -dtype from the fill value: - -:: - - # PyTorch 1.7 - >>> torch.full((3,), 1) - tensor([1, 1, 1]) - - >>> torch.full((3,), True) - tensor([True, True, True]) - - >>> torch.full((3,), 1.) - tensor([1., 1., 1.]) - - >>> torch.full((3,), 1 + 1j) - tensor([1.+1.j, 1.+1.j, 1.+1.j]) - -The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is, -ScriptModules serialized with versions of PyTorch before 1.6 will continue to see -torch.full return float tensors by default, even when given bool or -integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6 -and later cannot be loaded in earlier versions of PyTorch, however, since those -earlier versions do not understand the new behavior. - .. _utility functions: Utility functions diff --git a/docs/source/onnx.md b/docs/source/onnx.md new file mode 100644 index 00000000000000..ad436748022be5 --- /dev/null +++ b/docs/source/onnx.md @@ -0,0 +1,115 @@ +# torch.onnx + + +## Overview + +[Open Neural Network eXchange (ONNX)](https://onnx.ai/) is an open standard +format for representing machine learning models. The `torch.onnx` module captures the computation graph from a +native PyTorch {class}`torch.nn.Module` model and converts it into an +[ONNX graph](https://github.com/onnx/onnx/blob/main/docs/IR.md). + +The exported model can be consumed by any of the many +[runtimes that support ONNX](https://onnx.ai/supported-tools.html#deployModel), including +Microsoft's [ONNX Runtime](https://www.onnxruntime.ai). + +**There are two flavors of ONNX exporter API that you can use, as listed below.** +Both can be called through function {func}`torch.onnx.export`. +Next example shows how to export a simple model. + +```python +import torch + +class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 128, 5) + + def forward(self, x): + return torch.relu(self.conv1(x)) + +input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32) + +model = MyModel() + +torch.onnx.export( + model, # model to export + (input_tensor,), # inputs of the model, + "my_model.onnx", # filename of the ONNX model + input_names=["input"], # Rename inputs for the ONNX model + dynamo=True # True or False to select the exporter to use +) +``` + +Next sections introduce the two versions of the exporter. + +## TorchDynamo-based ONNX Exporter + +*The TorchDynamo-based ONNX exporter is the newest (and Beta) exporter for PyTorch 2.1 and newer* + +TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its +bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an +ONNX graph. + +The main advantage of this approach is that the [FX graph](https://pytorch.org/docs/stable/fx.html) is captured using +bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. + +{doc}`Learn more about the TorchDynamo-based ONNX Exporter ` + +## TorchScript-based ONNX Exporter + +*The TorchScript-based ONNX exporter is available since PyTorch 1.2.0* + +[TorchScript](https://pytorch.org/docs/stable/jit.html) is leveraged to trace (through {func}`torch.jit.trace`) +the model and capture a static computation graph. + +As a consequence, the resulting graph has a couple limitations: + +* It does not record any control-flow, like if-statements or loops; +* Does not handle nuances between `training` and `eval` mode; +* Does not truly handle dynamic inputs + +As an attempt to support the static tracing limitations, the exporter also supports TorchScript scripting +(through {func}`torch.jit.script`), which adds support for data-dependent control-flow, for example. However, TorchScript +itself is a subset of the Python language, so not all features in Python are supported, such as in-place operations. + +{doc}`Learn more about the TorchScript-based ONNX Exporter ` + +## Contributing / Developing + +The ONNX exporter is a community project and we welcome contributions. We follow the +[PyTorch guidelines for contributions](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md), but you might +also be interested in reading our [development wiki](https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter). + +```{eval-rst} +.. toctree:: + :hidden: + + onnx_dynamo + onnx_ops + onnx_verification + onnx_dynamo_onnxruntime_backend + onnx_torchscript +``` + + +```{eval-rst} +.. py:module:: torch.onnx.errors +.. py:module:: torch.onnx.operators +.. py:module:: torch.onnx.symbolic_helper +.. py:module:: torch.onnx.symbolic_opset10 +.. py:module:: torch.onnx.symbolic_opset11 +.. py:module:: torch.onnx.symbolic_opset12 +.. py:module:: torch.onnx.symbolic_opset13 +.. py:module:: torch.onnx.symbolic_opset14 +.. py:module:: torch.onnx.symbolic_opset15 +.. py:module:: torch.onnx.symbolic_opset16 +.. py:module:: torch.onnx.symbolic_opset17 +.. py:module:: torch.onnx.symbolic_opset18 +.. py:module:: torch.onnx.symbolic_opset19 +.. py:module:: torch.onnx.symbolic_opset20 +.. py:module:: torch.onnx.symbolic_opset7 +.. py:module:: torch.onnx.symbolic_opset8 +.. py:module:: torch.onnx.symbolic_opset9 +.. py:module:: torch.onnx.utils +``` diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst deleted file mode 100644 index db569fe7a1c78d..00000000000000 --- a/docs/source/onnx.rst +++ /dev/null @@ -1,116 +0,0 @@ -torch.onnx -========== - -Overview --------- - -`Open Neural Network eXchange (ONNX) `_ is an open standard -format for representing machine learning models. The ``torch.onnx`` module captures the computation graph from a -native PyTorch :class:`torch.nn.Module` model and converts it into an -`ONNX graph `_. - -The exported model can be consumed by any of the many -`runtimes that support ONNX `_, including -Microsoft's `ONNX Runtime `_. - -**There are two flavors of ONNX exporter API that you can use, as listed below.** -Both can be called through function :func:`torch.onnx.export`. -Next example shows how to export a simple model. - -.. code-block:: python - - import torch - - class MyModel(torch.nn.Module): - def __init__(self): - super(MyModel, self).__init__() - self.conv1 = torch.nn.Conv2d(1, 128, 5) - - def forward(self, x): - return torch.relu(self.conv1(x)) - - input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32) - - model = MyModel() - - torch.onnx.export( - model, # model to export - (input_tensor,), # inputs of the model, - "my_model.onnx", # filename of the ONNX model - input_names=["input"], # Rename inputs for the ONNX model - dynamo=True # True or False to select the exporter to use - ) - -Next sections introduces the two versions of the exporter. - -TorchDynamo-based ONNX Exporter -------------------------------- - -*The TorchDynamo-based ONNX exporter is the newest (and Beta) exporter for PyTorch 2.1 and newer* - -TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its -bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an -ONNX graph. - -The main advantage of this approach is that the `FX graph `_ is captured using -bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. - -:doc:`Learn more about the TorchDynamo-based ONNX Exporter ` - -TorchScript-based ONNX Exporter -------------------------------- - -*The TorchScript-based ONNX exporter is available since PyTorch 1.2.0* - -`TorchScript `_ is leveraged to trace (through :func:`torch.jit.trace`) -the model and capture a static computation graph. - -As a consequence, the resulting graph has a couple limitations: - -* It does not record any control-flow, like if-statements or loops; -* Does not handle nuances between ``training`` and ``eval`` mode; -* Does not truly handle dynamic inputs - -As an attempt to support the static tracing limitations, the exporter also supports TorchScript scripting -(through :func:`torch.jit.script`), which adds support for data-dependent control-flow, for example. However, TorchScript -itself is a subset of the Python language, so not all features in Python are supported, such as in-place operations. - -:doc:`Learn more about the TorchScript-based ONNX Exporter ` - -Contributing / Developing -------------------------- - -The ONNX exporter is a community project and we welcome contributions. We follow the -`PyTorch guidelines for contributions `_, but you might -also be interested in reading our `development wiki `_. - -.. toctree:: - :hidden: - - onnx_dynamo - onnx_ops - onnx_verification - onnx_dynamo_onnxruntime_backend - onnx_torchscript - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.onnx.errors -.. py:module:: torch.onnx.operators -.. py:module:: torch.onnx.symbolic_caffe2 -.. py:module:: torch.onnx.symbolic_helper -.. py:module:: torch.onnx.symbolic_opset10 -.. py:module:: torch.onnx.symbolic_opset11 -.. py:module:: torch.onnx.symbolic_opset12 -.. py:module:: torch.onnx.symbolic_opset13 -.. py:module:: torch.onnx.symbolic_opset14 -.. py:module:: torch.onnx.symbolic_opset15 -.. py:module:: torch.onnx.symbolic_opset16 -.. py:module:: torch.onnx.symbolic_opset17 -.. py:module:: torch.onnx.symbolic_opset18 -.. py:module:: torch.onnx.symbolic_opset19 -.. py:module:: torch.onnx.symbolic_opset20 -.. py:module:: torch.onnx.symbolic_opset7 -.. py:module:: torch.onnx.symbolic_opset8 -.. py:module:: torch.onnx.symbolic_opset9 -.. py:module:: torch.onnx.utils diff --git a/docs/source/onnx_dynamo.md b/docs/source/onnx_dynamo.md new file mode 100644 index 00000000000000..c5077ef360a5ee --- /dev/null +++ b/docs/source/onnx_dynamo.md @@ -0,0 +1,274 @@ +# TorchDynamo-based ONNX Exporter + +```{eval-rst} +.. automodule:: torch.onnx + :noindex: +``` + +```{contents} +:local: +:depth: 1 +``` + +## Overview + +The ONNX exporter leverages TorchDynamo engine to hook into Python's frame evaluation API +and dynamically rewrite its bytecode into an FX Graph. +The resulting FX Graph is then polished before it is finally translated into an ONNX graph. + +The main advantage of this approach is that the [FX graph](https://pytorch.org/docs/stable/fx.html) is captured using +bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. + +In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter. +See the {doc}`memory usage documentation ` for more information. + + +## Dependencies + +The ONNX exporter depends on extra Python packages: + + - [ONNX](https://onnx.ai) + - [ONNX Script](https://microsoft.github.io/onnxscript) + +They can be installed through [pip](https://pypi.org/project/pip/): + +```{code-block} bash + + pip install --upgrade onnx onnxscript +``` + +[onnxruntime](https://onnxruntime.ai) can then be used to execute the model +on a large variety of processors. + +## A simple example + +See below a demonstration of exporter API in action with a simple Multilayer Perceptron (MLP) as example: + +```{code-block} python +import torch +import torch.nn as nn + +class MLPModel(nn.Module): + def __init__(self): + super().__init__() + self.fc0 = nn.Linear(8, 8, bias=True) + self.fc1 = nn.Linear(8, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + self.fc3 = nn.Linear(2, 2, bias=True) + self.fc_combined = nn.Linear(8 + 8 + 8, 8, bias=True) # Combine all inputs + + def forward(self, tensor_x: torch.Tensor, input_dict: dict, input_list: list): + """ + Forward method that requires all inputs: + - tensor_x: A direct tensor input. + - input_dict: A dictionary containing the tensor under the key 'tensor_x'. + - input_list: A list where the first element is the tensor. + """ + # Extract tensors from inputs + dict_tensor = input_dict['tensor_x'] + list_tensor = input_list[0] + + # Combine all inputs into a single tensor + combined_tensor = torch.cat([tensor_x, dict_tensor, list_tensor], dim=1) + + # Process the combined tensor through the layers + combined_tensor = self.fc_combined(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc0(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc1(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc2(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + output = self.fc3(combined_tensor) + return output + +model = MLPModel() + +# Example inputs +tensor_input = torch.rand((97, 8), dtype=torch.float32) +dict_input = {'tensor_x': torch.rand((97, 8), dtype=torch.float32)} +list_input = [torch.rand((97, 8), dtype=torch.float32)] + +# The input_names and output_names are used to identify the inputs and outputs of the ONNX model +input_names = ['tensor_input', 'tensor_x', 'list_input_index_0'] +output_names = ['output'] + +# Exporting the model with all required inputs +onnx_program = torch.onnx.export(model,(tensor_input, dict_input, list_input), dynamic_shapes=({0: "batch_size"},{"tensor_x": {0: "batch_size"}},[{0: "batch_size"}]), input_names=input_names, output_names=output_names, dynamo=True,) + +# Check the exported ONNX model is dynamic +assert onnx_program.model.graph.inputs[0].shape == ("batch_size", 8) +assert onnx_program.model.graph.inputs[1].shape == ("batch_size", 8) +assert onnx_program.model.graph.inputs[2].shape == ("batch_size", 8) +``` + +As the code above shows, all you need is to provide {func}`torch.onnx.export` with an instance of the model and its input. +The exporter will then return an instance of {class}`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. + +The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the [ONNX IR spec](https://github.com/onnx/onnx/blob/main/docs/IR.md). +The ONNX model may then be serialized into a [Protobuf file](https://protobuf.dev/) using the {meth}`torch.onnx.ONNXProgram.save` API. + +```{code-block} python + onnx_program.save("mlp.onnx") +``` + +## Use the same model to compare with the TorchScript-enabled exporter + +The biggest difference between the TorchScript-enabled exporter and the TorchDynamo-based exporter is that the latter +requires dynamic_shapes to be the same tree structure as the input, while the former +requires the dynamic_shapes to be a single and flatten dictionary. + +```{code-block} python + torch.onnx.export(model,(tensor_input, dict_input, list_input), "mlp.onnx", dynamic_axes={"tensor_input":{0: "batch_size"}, "tensor_x": {0: "batch_size"}, "list_input_index_0": {0: "batch_size"}}, input_names=input_names, output_names=output_names) +``` + +## Inspecting the ONNX model using GUI + +You can view the exported model using [Netron](https://netron.app/). + +```{image} _static/img/onnx/onnx_dynamo_mlp_model.png +:alt: MLP model as viewed using Netron +:width: 30% +:align: center +``` + +## When the conversion fails + +Function {func}`torch.onnx.export` should be called a second time with +parameter ``report=True``. A markdown report is generated to help the user +to resolve the issue. + +```{toctree} +:hidden: +onnx_dynamo_memory_usage +``` +## Metadata + +During ONNX export, each ONNX node is annotated with metadata that helps trace its origin and context from the original PyTorch model. This metadata is useful for debugging, model inspection, and understanding the mapping between PyTorch and ONNX graphs. + +The following metadata fields are added to each ONNX node: + +- **namespace** + + A string representing the hierarchical namespace of the node, consisting of a stack trace of modules/methods. + + *Example:* + `__main__.SimpleAddModel/add: aten.add.Tensor` + +- **pkg.torch.onnx.class_hierarchy** + + A list of class names representing the hierarchy of modules leading to this node. + + *Example:* + `['__main__.SimpleAddModel', 'aten.add.Tensor']` + +- **pkg.torch.onnx.fx_node** + + The string representation of the original FX node, including its name, number of consumers, the targeted torch op, arguments, and keyword arguments. + + *Example:* + `%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%tensor_x, %input_dict_tensor_x, %input_list_0], 1), kwargs = {})` + +- **pkg.torch.onnx.name_scopes** + + A list of name scopes (methods) representing the path to this node in the PyTorch model. + + *Example:* + `['', 'add']` + +- **pkg.torch.onnx.stack_trace** + + The stack trace from the original code where this node was created, if available. + + *Example:* + ``` + File "simpleadd.py", line 7, in forward + return torch.add(x, y) + ``` + +These metadata fields are stored in the metadata_props attribute of each ONNX node and can be inspected using Netron or programmatically. + +The overall ONNX graph has the following `metadata_props`: + +- **pkg.torch.export.ExportedProgram.graph_signature** + + This property contains a string representation of the graph_signature from the original PyTorch ExportedProgram. The graph signature describes the structure of the model's inputs and outputs and how they map to the ONNX graph. The inputs are defined as `InputSpec` objects, which include the kind of input (e.g., `InputKind.PARAMETER` for parameters, `InputKind.USER_INPUT` for user-defined inputs), the argument name, the target (which can be a specific node in the model), and whether the input is persistent. The outputs are defined as `OutputSpec` objects, which specify the kind of output (e.g., `OutputKind.USER_OUTPUT`) and the argument name. + + To read more about the graph signature, please see the {doc}`torch.export ` for more information. + +- **pkg.torch.export.ExportedProgram.range_constraints** + + This property contains a string representation of any range constraints that were present in the original PyTorch ExportedProgram. Range constraints specify valid ranges for symbolic shapes or values in the model, which can be important for models that use dynamic shapes or symbolic dimensions. + + *Example:* + `s0: VR[2, int_oo]`, which indicates that the size of the input tensor must be at least 2. + + To read more about range constraints, please see the {doc}`torch.export ` for more information. + +Each input value in the ONNX graph may have the following metadata property: + +- **pkg.torch.export.graph_signature.InputSpec.kind** + + The kind of input, as defined by PyTorch's InputKind enum. + + *Example values:* + - "USER_INPUT": A user-provided input to the model. + - "PARAMETER": A model parameter (e.g., weight). + - "BUFFER": A model buffer (e.g., running mean in BatchNorm). + - "CONSTANT_TENSOR": A constant tensor argument. + - "CUSTOM_OBJ": A custom object input. + - "TOKEN": A token input. + +- **pkg.torch.export.graph_signature.InputSpec.persistent** + + Indicates whether the input is persistent (i.e., should be saved as part of the model's state). + + *Example values:* + - "True" + - "False" + +Each output value in the ONNX graph may have the following metadata property: + +- **pkg.torch.export.graph_signature.OutputSpec.kind** + + The kind of input, as defined by PyTorch's OutputKind enum. + + *Example values:* + - "USER_OUTPUT": A user-visible output. + - "LOSS_OUTPUT": A loss value output. + - "BUFFER_MUTATION": Indicates a buffer was mutated. + - "GRADIENT_TO_PARAMETER": Gradient output for a parameter. + - "GRADIENT_TO_USER_INPUT": Gradient output for a user input. + - "USER_INPUT_MUTATION": Indicates a user input was mutated. + - "TOKEN": A token output. + +Each initialized value, input, output has the following metadata: + +- **pkg.torch.onnx.original_node_name** + + The original name of the node in the PyTorch FX graph that produced this value in the case where the value was renamed. This helps trace initializers back to their source in the original model. + + *Example:* + `fc1.weight` + +## API Reference + +```{eval-rst} +.. autofunction:: torch.onnx.export +.. autoclass:: torch.onnx.ONNXProgram + :members: +.. autofunction:: is_in_onnx_export +.. autoclass:: torch.onnx.OnnxExporterError + :members: +.. autofunction:: torch.onnx.enable_fake_mode +``` + +## Deprecated + +The following classes and functions are deprecated and will be removed. + +```{eval-rst} +.. autofunction:: torch.onnx.dynamo_export +.. autoclass:: torch.onnx.ExportOptions +``` diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst deleted file mode 100644 index 5bd57f8c1602aa..00000000000000 --- a/docs/source/onnx_dynamo.rst +++ /dev/null @@ -1,161 +0,0 @@ -TorchDynamo-based ONNX Exporter -=============================== - -.. automodule:: torch.onnx - :noindex: - -.. contents:: :local: - :depth: 1 - -Overview --------- - -The ONNX exporter leverages TorchDynamo engine to hook into Python's frame evaluation API -and dynamically rewrite its bytecode into an FX Graph. -The resulting FX Graph is then polished before it is finally translated into an ONNX graph. - -The main advantage of this approach is that the `FX graph `_ is captured using -bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. - -In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter. -See the :doc:`memory usage documentation ` for more information. - - -Dependencies ------------- - -The ONNX exporter depends on extra Python packages: - - - `ONNX `_ - - `ONNX Script `_ - -They can be installed through `pip `_: - -.. code-block:: bash - - pip install --upgrade onnx onnxscript - -`onnxruntime `_ can then be used to execute the model -on a large variety of processors. - -A simple example ----------------- - -See below a demonstration of exporter API in action with a simple Multilayer Perceptron (MLP) as example: - -.. code-block:: python - - class MLPModel(nn.Module): - def __init__(self): - super().__init__() - self.fc0 = nn.Linear(8, 8, bias=True) - self.fc1 = nn.Linear(8, 4, bias=True) - self.fc2 = nn.Linear(4, 2, bias=True) - self.fc3 = nn.Linear(2, 2, bias=True) - self.fc_combined = nn.Linear(8 + 8 + 8, 8, bias=True) # Combine all inputs - - def forward(self, tensor_x: torch.Tensor, input_dict: dict, input_list: list): - """ - Forward method that requires all inputs: - - tensor_x: A direct tensor input. - - input_dict: A dictionary containing the tensor under the key 'tensor_x'. - - input_list: A list where the first element is the tensor. - """ - # Extract tensors from inputs - dict_tensor = input_dict['tensor_x'] - list_tensor = input_list[0] - - # Combine all inputs into a single tensor - combined_tensor = torch.cat([tensor_x, dict_tensor, list_tensor], dim=1) - - # Process the combined tensor through the layers - combined_tensor = self.fc_combined(combined_tensor) - combined_tensor = torch.sigmoid(combined_tensor) - combined_tensor = self.fc0(combined_tensor) - combined_tensor = torch.sigmoid(combined_tensor) - combined_tensor = self.fc1(combined_tensor) - combined_tensor = torch.sigmoid(combined_tensor) - combined_tensor = self.fc2(combined_tensor) - combined_tensor = torch.sigmoid(combined_tensor) - output = self.fc3(combined_tensor) - return output - - model = MLPModel() - - # Example inputs - tensor_input = torch.rand((97, 8), dtype=torch.float32) - dict_input = {'tensor_x': torch.rand((97, 8), dtype=torch.float32)} - list_input = [torch.rand((97, 8), dtype=torch.float32)] - - # The input_names and output_names are used to identify the inputs and outputs of the ONNX model - input_names = ['tensor_input', 'tensor_x', 'list_input_index_0'] - output_names = ['output'] - - # Exporting the model with all required inputs - onnx_program = torch.onnx.export(model,(tensor_input, dict_input, list_input), dynamic_shapes=({0: "batch_size"},{"tensor_x": {0: "batch_size"}},[{0: "batch_size"}]), input_names=input_names, output_names=output_names, dynamo=True,) - - # Check the exported ONNX model is dynamic - assert onnx_program.model.graph.inputs[0].shape == ("batch_size", 8) - assert onnx_program.model.graph.inputs[1].shape == ("batch_size", 8) - assert onnx_program.model.graph.inputs[2].shape == ("batch_size", 8) - -As the code above shows, all you need is to provide :func:`torch.onnx.export` with an instance of the model and its input. -The exporter will then return an instance of :class:`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. - -The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec `_. -The ONNX model may then be serialized into a `Protobuf file `_ using the :meth:`torch.onnx.ONNXProgram.save` API. - -.. code-block:: python - - onnx_program.save("mlp.onnx") - -Use the same model to compare with the TorchScript-enabled exporter -------------------------------------------------------------------- - -The biggest difference between the TorchScript-enabled exporter and the TorchDynamo-based exporter is that the latter -requires dynamic_shapes to be the same tree structure as the input, while the former -requires the dynamic_shapes to be a single and flatten dictionary. - -.. code-block:: python - - torch.onnx.export(model,(tensor_input, dict_input, list_input), "mlp.onnx", dynamic_axes={"tensor_input":{0: "batch_size"}, "tensor_x": {0: "batch_size"}, "list_input_index_0": {0: "batch_size"}}, input_names=input_names, output_names=output_names) - -Inspecting the ONNX model using GUI ------------------------------------ - -You can view the exported model using `Netron `__. - -.. image:: _static/img/onnx/onnx_dynamo_mlp_model.png - :width: 40% - :alt: MLP model as viewed using Netron - -When the conversion fails -------------------------- - -Function :func:`torch.onnx.export` should called a second time with -parameter ``report=True``. A markdown report is generated to help the user -to resolve the issue. - -.. toctree:: - :hidden: - - onnx_dynamo_memory_usage - -API Reference -------------- - -.. autofunction:: torch.onnx.export -.. autoclass:: torch.onnx.ONNXProgram - :members: -.. autofunction:: is_in_onnx_export -.. autoclass:: torch.onnx.OnnxExporterError - :members: -.. autofunction:: torch.onnx.enable_fake_mode - -Deprecated ----------- - -The following classes and functions are deprecated and will be removed. - -.. autofunction:: torch.onnx.dynamo_export -.. autoclass:: torch.onnx.ExportOptions diff --git a/docs/source/onnx_dynamo_memory_usage.rst b/docs/source/onnx_dynamo_memory_usage.rst index b339d20f0ba6c4..ba1213c6ee085a 100644 --- a/docs/source/onnx_dynamo_memory_usage.rst +++ b/docs/source/onnx_dynamo_memory_usage.rst @@ -103,7 +103,7 @@ The code below could be run to generate a snapshot file which records the state print(f"Export is done.") Open `pytorch.org/memory_viz `_ and drag/drop the generated pickled snapshot file into the visualizer. -The memeory usage is described as below: +The memory usage is described as below: .. image:: _static/img/onnx/torch_dynamo_exporter_memory_usage.png diff --git a/docs/source/onnx_dynamo_onnxruntime_backend.md b/docs/source/onnx_dynamo_onnxruntime_backend.md new file mode 100644 index 00000000000000..a59cd4ab919cd0 --- /dev/null +++ b/docs/source/onnx_dynamo_onnxruntime_backend.md @@ -0,0 +1,11 @@ +# ONNX Backend for TorchDynamo + +For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`. + +```{warning} + The ONNX backend for torch.compile is a rapidly evolving beta technology. +``` + +```{eval-rst} +.. autofunction:: torch.onnx.is_onnxrt_backend_supported +``` \ No newline at end of file diff --git a/docs/source/onnx_dynamo_onnxruntime_backend.rst b/docs/source/onnx_dynamo_onnxruntime_backend.rst deleted file mode 100644 index a8d0c21746eea4..00000000000000 --- a/docs/source/onnx_dynamo_onnxruntime_backend.rst +++ /dev/null @@ -1,9 +0,0 @@ -ONNX Backend for TorchDynamo -============================ - -For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`. - -.. warning:: - The ONNX backend for torch.compile is a rapidly evolving beta technology. - -.. autofunction:: torch.onnx.is_onnxrt_backend_supported \ No newline at end of file diff --git a/docs/source/onnx_ops.md b/docs/source/onnx_ops.md new file mode 100644 index 00000000000000..51ea2afe5eff31 --- /dev/null +++ b/docs/source/onnx_ops.md @@ -0,0 +1,128 @@ +# torch.onnx.ops + +```{eval-rst} +.. automodule:: torch.onnx.ops +``` + +## Symbolic Operators + +Operators that can be used to create any ONNX ops in the FX graph symbolically. +These operators do not do actual computation. It's recommended that you used them +inside an ``if torch.onnx.is_in_onnx_export`` block. + +```{eval-rst} +.. autofunction:: torch.onnx.ops.symbolic +.. autofunction:: torch.onnx.ops.symbolic_multi_out +``` + +## ONNX Operators + +The following operators are implemented as native PyTorch ops and can be exported as +ONNX operators. They can be used natively in an ``nn.Module``. + +For example, you can define a module: + +```py +class Model(torch.nn.Module): + def forward( + self, input_data, cos_cache_data, sin_cache_data, position_ids_data + ): + return torch.onnx.ops.rotary_embedding( + input_data, + cos_cache_data, + sin_cache_data, + position_ids_data, + ) +``` + +and export it to ONNX using: + +```py +input_data = torch.rand(2, 3, 4, 8) +position_ids_data = torch.randint(0, 50, (2, 3)).long() +sin_cache_data = torch.rand(50, 4) +cos_cache_data = torch.rand(50, 4) +dynamic_shapes = { + "input_data": {0: torch.export.Dim.DYNAMIC}, + "cos_cache_data": None, + "sin_cache_data": None, + "position_ids_data": {0: torch.export.Dim.DYNAMIC}, +} +onnx_program = torch.onnx.export( + model, + (input_data, cos_cache_data, sin_cache_data, position_ids_data), + dynamic_shapes=dynamic_shapes, + dynamo=True, + opset_version=23, +) +``` + +Printing the ONNX program will show the ONNX operators used in the graph: + +``` +<...> + +graph( + name=main_graph, + inputs=( + %"input_data", + %"cos_cache_data", + %"sin_cache_data", + %"position_ids_data" + ), + outputs=( + %"rotary_embedding" + ), +) { + 0 | # rotary_embedding + %"rotary_embedding" ⬅️ ::RotaryEmbedding(%"input_data", %"cos_cache_data", %"sin_cache_data", %"position_ids_data") + return %"rotary_embedding" +} +``` + +with the corresponding ``ExportedProgram``: + +ExportedProgram: + +```py +class GraphModule(torch.nn.Module): + def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"): + rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data); input_data = cos_cache_data = sin_cache_data = position_ids_data = None + return (rotary_embedding,) +``` + +```{eval-rst} +.. autofunction:: torch.onnx.ops.rotary_embedding +.. autofunction:: torch.onnx.ops.attention +``` + +## ONNX to ATen Decomposition Table + +You can use {func}`torch.onnx.ops.aten_decompositions` to obtain a decomposition table +to decompose ONNX operators defined above to ATen operators. + +```py +class Model(torch.nn.Module): + def forward( + self, input_data, cos_cache_data, sin_cache_data, position_ids_data + ): + return torch.onnx.ops.rotary_embedding( + input_data, + cos_cache_data, + sin_cache_data, + position_ids_data, + ) + +model = Model() + +ep = torch.export.export( + model, + (input_data, cos_cache_data, sin_cache_data, position_ids_data), +) +# The program can be decomposed into aten ops +ep_decomposed = ep.run_decompositions(torch.onnx.ops.aten_decompositions()) +``` + +```{eval-rst} +.. autofunction:: torch.onnx.ops.aten_decompositions +``` diff --git a/docs/source/onnx_ops.rst b/docs/source/onnx_ops.rst deleted file mode 100644 index 26628f797429dd..00000000000000 --- a/docs/source/onnx_ops.rst +++ /dev/null @@ -1,11 +0,0 @@ -torch.onnx.ops -============== - -.. automodule:: torch.onnx.ops - -Operators ---------- - -.. autofunction:: torch.onnx.ops.symbolic - -.. autofunction:: torch.onnx.ops.symbolic_multi_out diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst index 2fa02cf78f0551..b8c43cc28495e5 100644 --- a/docs/source/onnx_torchscript.rst +++ b/docs/source/onnx_torchscript.rst @@ -102,8 +102,6 @@ load and run the model:: ) print(outputs[0]) -Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime `_. - .. _tracing-vs-scripting: Tracing vs Scripting @@ -712,4 +710,4 @@ Classes :nosignatures: :template: classtemplate.rst - JitScalarType + JitScalarType \ No newline at end of file diff --git a/docs/source/onnx_verification.md b/docs/source/onnx_verification.md new file mode 100644 index 00000000000000..cbaad021e960ce --- /dev/null +++ b/docs/source/onnx_verification.md @@ -0,0 +1,33 @@ +# torch.onnx.verification +```{eval-rst} +.. automodule:: torch.onnx.verification +``` + +```{eval-rst} +.. autofunction:: verify_onnx_program +``` + +```{eval-rst} +.. autoclass:: VerificationInfo + :members: +``` + +```{eval-rst} +.. autofunction:: verify +``` + +## Deprecated + +The following classes and functions are deprecated. + + +```{eval-rst} +.. py:class:: check_export_model_diff +.. py:class:: GraphInfo +.. py:class:: GraphInfoPrettyPrinter +.. py:class:: OnnxBackend +.. py:class:: OnnxTestCaseRepro +.. py:class:: VerificationOptions +.. py:function:: find_mismatch +.. py:function:: verify_aten_graph +``` diff --git a/docs/source/onnx_verification.rst b/docs/source/onnx_verification.rst deleted file mode 100644 index 1e197427f8c823..00000000000000 --- a/docs/source/onnx_verification.rst +++ /dev/null @@ -1,26 +0,0 @@ -torch.onnx.verification -======================= - -.. automodule:: torch.onnx.verification - -.. autofunction:: verify_onnx_program - -.. autoclass:: VerificationInfo - :members: - -.. autofunction:: verify - -Deprecated ----------- - -The following classes and functions are deprecated. - -.. Some deprecated members are not publicly shown -.. py:class:: check_export_model_diff -.. py:class:: GraphInfo -.. py:class:: GraphInfoPrettyPrinter -.. py:class:: OnnxBackend -.. py:class:: OnnxTestCaseRepro -.. py:class:: VerificationOptions -.. py:function:: find_mismatch -.. py:function:: verify_aten_graph diff --git a/docs/source/optim.md b/docs/source/optim.md new file mode 100644 index 00000000000000..8a3f03468810d2 --- /dev/null +++ b/docs/source/optim.md @@ -0,0 +1,707 @@ +# torch.optim + +```{eval-rst} +.. automodule:: torch.optim +``` + +## How to use an optimizer + +To use {mod}`torch.optim` you have to construct an optimizer object that will hold +the current state and will update the parameters based on the computed gradients. + +### Constructing it + +To construct an {class}`Optimizer` you have to give it an iterable containing the +parameters (all should be {class}`~torch.nn.Parameter` s) or named parameters +(tuples of (str, {class}`~torch.nn.Parameter`)) to optimize. Then, +you can specify optimizer-specific options such as the learning rate, weight decay, etc. + +Example: +```python +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) +optimizer = optim.Adam([var1, var2], lr=0.0001) +``` + +Named parameters example: + +```python +optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) +optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001) +``` + +### Per-parameter options + +{class}`Optimizer` s also support specifying per-parameter options. To do this, instead +of passing an iterable of {class}`~torch.autograd.Variable` s, pass in an iterable of +{class}`dict` s. Each of them will define a separate parameter group, and should contain +a `params` key, containing a list of parameters belonging to it. Other keys +should match the keyword arguments accepted by the optimizers, and will be used +as optimization options for this group. + +For example, this is very useful when one wants to specify per-layer learning rates: + +```python +optim.SGD([ + {'params': model.base.parameters(), 'lr': 1e-2}, + {'params': model.classifier.parameters()} +], lr=1e-3, momentum=0.9) + +optim.SGD([ + {'params': model.base.named_parameters(), 'lr': 1e-2}, + {'params': model.classifier.named_parameters()} +], lr=1e-3, momentum=0.9) +``` + +This means that `model.base`'s parameters will use a learning rate of `1e-2`, whereas +`model.classifier`'s parameters will stick to the default learning rate of `1e-3`. +Finally a momentum of `0.9` will be used for all parameters. + +```{note} +You can still pass options as keyword arguments. They will be used as +defaults, in the groups that didn't override them. This is useful when you +only want to vary a single option, while keeping all others consistent +between parameter groups. +``` + +Also consider the following example related to the distinct penalization of parameters. +Remember that {func}`~torch.nn.Module.parameters` returns an iterable that +contains all learnable parameters, including biases and other +parameters that may prefer distinct penalization. To address this, one can specify +individual penalization weights for each parameter group: + +```python +bias_params = [p for name, p in self.named_parameters() if 'bias' in name] +others = [p for name, p in self.named_parameters() if 'bias' not in name] + +optim.SGD([ + {'params': others}, + {'params': bias_params, 'weight_decay': 0} +], weight_decay=1e-2, lr=1e-2) +``` + +In this manner, bias terms are isolated from non-bias terms, and a `weight_decay` +of `0` is set specifically for the bias terms, as to avoid any penalization for +this group. + + +### Taking an optimization step + +All optimizers implement a {func}`~Optimizer.step` method, that updates the +parameters. It can be used in two ways: + +#### `optimizer.step()` + +This is a simplified version supported by most optimizers. The function can be +called once the gradients are computed using e.g. +{func}`~torch.autograd.Variable.backward`. + +Example: + +```python +for input, target in dataset: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() +``` + +#### `optimizer.step(closure)` + +Some optimization algorithms such as Conjugate Gradient and LBFGS need to +reevaluate the function multiple times, so you have to pass in a closure that +allows them to recompute your model. The closure should clear the gradients, +compute the loss, and return it. + +Example: +```python +for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + return loss + optimizer.step(closure) +``` + +(optimizer-algorithms)= + +## Base class + +```{eval-rst} +.. autoclass:: Optimizer + +.. autosummary:: + :toctree: generated + :nosignatures: + + Optimizer.add_param_group + Optimizer.load_state_dict + Optimizer.register_load_state_dict_pre_hook + Optimizer.register_load_state_dict_post_hook + Optimizer.state_dict + Optimizer.register_state_dict_pre_hook + Optimizer.register_state_dict_post_hook + Optimizer.step + Optimizer.register_step_pre_hook + Optimizer.register_step_post_hook + Optimizer.zero_grad +``` + +## Algorithms + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Adadelta + Adafactor + Adagrad + Adam + AdamW + SparseAdam + Adamax + ASGD + LBFGS + NAdam + RAdam + RMSprop + Rprop + SGD +``` +Many of our algorithms have various implementations optimized for performance, +readability and/or generality, so we attempt to default to the generally fastest +implementation for the current device if no particular implementation has been +specified by the user. + +We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and +fused. The most straightforward implementations are for-loops over the parameters with +big chunks of computation. For-looping is usually slower than our foreach +implementations, which combine parameters into a multi-tensor and run the big chunks +of computation all at once, thereby saving many sequential kernel calls. A few of our +optimizers have even faster fused implementations, which fuse the big chunks of +computation into one kernel. We can think of foreach implementations as fusing +horizontally and fused implementations as fusing vertically on top of that. + +In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. +So when applicable, we default to foreach over for-loop. Applicable means the foreach +implementation is available, the user has not specified any implementation-specific kwargs +(e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused +should be even faster than foreach, the implementations are newer and we would like to give +them more bake-in time before flipping the switch everywhere. We summarize the stability status +for each implementation on the second table below, you are welcome to try them out though! + +Below is a table showing the available and default implementations of each algorithm: + +```{eval-rst} +.. csv-table:: + :header: "Algorithm", "Default", "Has foreach?", "Has fused?" + :widths: 25, 25, 25, 25 + :delim: ; + + :class:`Adadelta`;foreach;yes;no + :class:`Adafactor`;for-loop;no;no + :class:`Adagrad`;foreach;yes;yes (cpu only) + :class:`Adam`;foreach;yes;yes + :class:`AdamW`;foreach;yes;yes + :class:`SparseAdam`;for-loop;no;no + :class:`Adamax`;foreach;yes;no + :class:`ASGD`;foreach;yes;no + :class:`LBFGS`;for-loop;no;no + :class:`NAdam`;foreach;yes;no + :class:`RAdam`;foreach;yes;no + :class:`RMSprop`;foreach;yes;no + :class:`Rprop`;foreach;yes;no + :class:`SGD`;foreach;yes;yes +``` +Below table is showing the stability status for fused implementations: + +```{eval-rst} +.. csv-table:: + :header: "Algorithm", "CPU", "CUDA", "MPS" + :widths: 25, 25, 25, 25 + :delim: ; + + :class:`Adadelta`;unsupported;unsupported;unsupported + :class:`Adafactor`;unsupported;unsupported;unsupported + :class:`Adagrad`;beta;unsupported;unsupported + :class:`Adam`;beta;stable;beta + :class:`AdamW`;beta;stable;beta + :class:`SparseAdam`;unsupported;unsupported;unsupported + :class:`Adamax`;unsupported;unsupported;unsupported + :class:`ASGD`;unsupported;unsupported;unsupported + :class:`LBFGS`;unsupported;unsupported;unsupported + :class:`NAdam`;unsupported;unsupported;unsupported + :class:`RAdam`;unsupported;unsupported;unsupported + :class:`RMSprop`;unsupported;unsupported;unsupported + :class:`Rprop`;unsupported;unsupported;unsupported + :class:`SGD`;beta;beta;beta +``` + +## How to adjust learning rate + +{class}`torch.optim.lr_scheduler.LRScheduler` provides several methods to adjust the learning +rate based on the number of epochs. {class}`torch.optim.lr_scheduler.ReduceLROnPlateau` +allows dynamic learning rate reducing based on some validation measurements. + +Learning rate scheduling should be applied after optimizer's update; e.g., you +should write your code this way: + +Example: +```python +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) +scheduler = ExponentialLR(optimizer, gamma=0.9) + +for epoch in range(20): + for input, target in dataset: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + scheduler.step() +``` + +Most learning rate schedulers can be called back-to-back (also referred to as +chaining schedulers). The result is that each scheduler is applied one after the +other on the learning rate obtained by the one preceding it. + +Example: +```python +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) +scheduler1 = ExponentialLR(optimizer, gamma=0.9) +scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + +for epoch in range(20): + for input, target in dataset: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + scheduler1.step() + scheduler2.step() +``` + +In many places in the documentation, we will use the following template to refer to schedulers +algorithms. + +```python +>>> scheduler = ... +>>> for epoch in range(100): +>>> train(...) +>>> validate(...) +>>> scheduler.step() +``` + +```{warning} +Prior to PyTorch 1.1.0, the learning rate scheduler was expected to be called before +the optimizer's update; 1.1.0 changed this behavior in a BC-breaking way. If you use +the learning rate scheduler (calling `scheduler.step()`) before the optimizer's update +(calling `optimizer.step()`), this will skip the first value of the learning rate schedule. +If you are unable to reproduce results after upgrading to PyTorch 1.1.0, please check +if you are calling `scheduler.step()` at the wrong time. +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + lr_scheduler.LRScheduler + lr_scheduler.LambdaLR + lr_scheduler.MultiplicativeLR + lr_scheduler.StepLR + lr_scheduler.MultiStepLR + lr_scheduler.ConstantLR + lr_scheduler.LinearLR + lr_scheduler.ExponentialLR + lr_scheduler.PolynomialLR + lr_scheduler.CosineAnnealingLR + lr_scheduler.ChainedScheduler + lr_scheduler.SequentialLR + lr_scheduler.ReduceLROnPlateau + lr_scheduler.CyclicLR + lr_scheduler.OneCycleLR + lr_scheduler.CosineAnnealingWarmRestarts +``` + +## How to utilize named parameters to load optimizer state dict + +The function {func}`~Optimizer.load_state_dict` stores the optional `param_names` content from the +loaded state dict if present. However, the process of loading the optimizer state is not affected, +as the order of the parameters matters to maintain compatibility (in case of different ordering). +To utilize the loaded parameters names from the loaded state dict, a custom `register_load_state_dict_pre_hook` +needs to be implemented according to the desired behavior. + +This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to +remain unchanged. The following example demonstrates how to implement this customization. + +Example: +```python +class OneLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(3, 4) + + def forward(self, x): + return self.fc(x) + +model = OneLayerModel() +optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) +# training.. +torch.save(optimizer.state_dict(), PATH) +``` + +Let's say that `model` implements an expert (MoE), and we want to duplicate it and resume training +for two experts, both initialized the same way as the `fc` layer. For the following `model2` we create two layers identical to `fc` and resume training by loading the model weights and optimizer states from `model` into both `fc1` and `fc2` of `model2` (and adjust them accordingly): + +```python +class TwoLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(3, 4) + self.fc2 = nn.Linear(3, 4) + + def forward(self, x): + return (self.fc1(x) + self.fc2(x)) / 2 + +model2 = TwoLayerModel() +# adapt and load model weights.. +optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) +``` + +To load the state dict for `optimizer2` with the state dict of the previous optimizer such that both +`fc1` and `fc2` will be initialized with a copy of `fc` optimizer states +(to resume training for each layer from `fc`), we can use the following hook: + +```python +def adapt_state_dict_ids(optimizer, state_dict): + adapted_state_dict = deepcopy(optimizer.state_dict()) + # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. + for k, v in state_dict['param_groups'][0].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][0][k] = v + + lookup_dict = { + 'fc1.weight': 'fc.weight', + 'fc1.bias': 'fc.bias', + 'fc2.weight': 'fc.weight', + 'fc2.bias': 'fc.bias' + } + clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][0]['params'], + optimizer.state_dict()['param_groups'][0]['param_names']): + name_in_loaded = lookup_dict[param_name] + index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) + id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict + +optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) +optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict +``` + +This ensures that the adapted state_dict with the correct states for the layers of `model2` will be used +during model loading. +Note that this code is designed specifically for this example (e.g., assuming a single parameter group), +and other cases might require different adaptations. + +The following example shows how to handle missing parameters in a loaded +`state dict` when the model structure changes. +The `Model_bypass` adds a new `bypass` layer, which is not present in the original `Model1`. +To resume training, a custom `adapt_state_dict_missing_param` hook is used to adapt the optimizer's `state_dict`, +ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged +(as initialized in this example). +This approach enables smooth loading and resuming of the optimizer state despite model changes. +The new bypass layer will be trained from scratch: + +```python +class Model1(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 5) + + def forward(self, x): + return self.fc(x) + x + + +model = Model1() +optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) +# training.. +torch.save(optimizer.state_dict(), PATH) + +class Model_bypass(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 5) + self.bypass = nn.Linear(5, 5, bias=False) + torch.nn.init.eye_(self.bypass.weight) + + def forward(self, x): + return self.fc(x) + self.bypass(x) + +model2 = Model_bypass() +optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) + +def adapt_state_dict_missing_param(optimizer, state_dict): + adapted_state_dict = deepcopy(optimizer.state_dict()) + # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. + for k, v in state_dict['param_groups'][0].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][0][k] = v + + lookup_dict = { + 'fc.weight': 'fc.weight', + 'fc.bias': 'fc.bias', + 'bypass.weight': None, + } + + clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][0]['params'], + optimizer.state_dict()['param_groups'][0]['param_names']): + name_in_loaded = lookup_dict[param_name] + if name_in_loaded in state_dict['param_groups'][0]['param_names']: + index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) + id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict + +optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) +optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict +``` + + +As a third example, instead of loading a state according to the order of parameters (the default approach), +this hook can be used to load according to the parameters' names: + +```python +def names_matching(optimizer, state_dict): + assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups']) + adapted_state_dict = deepcopy(optimizer.state_dict()) + for g_ind in range(len(state_dict['param_groups'])): + assert len(state_dict['param_groups'][g_ind]['params']) == len( + optimizer.state_dict()['param_groups'][g_ind]['params']) + + for k, v in state_dict['param_groups'][g_ind].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][g_ind][k] = v + + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][g_ind]['params'], + optimizer.state_dict()['param_groups'][g_ind]['param_names']): + index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name) + id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict +``` + + +## Weight Averaging (SWA and EMA) + +{class}`torch.optim.swa_utils.AveragedModel` implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), +{class}`torch.optim.swa_utils.SWALR` implements the SWA learning rate scheduler and +{func}`torch.optim.swa_utils.update_bn` is a utility function used to update SWA/EMA batch +normalization statistics at the end of training. + +SWA has been proposed in [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407). + +EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. +It is a variation of [Polyak averaging](https://paperswithcode.com/method/polyak-averaging), but using exponential weights instead of equal weights across iterations. + +### Constructing averaged models + +The `AveragedModel` class serves to compute the weights of the SWA or EMA model. + +You can create an SWA averaged model by running: + +```python +>>> averaged_model = AveragedModel(model) +``` + +EMA models are constructed by specifying the `multi_avg_fn` argument as follows: + +```python +>>> decay = 0.999 +>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay)) +``` + +Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to {func}`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues. + +{func}`torch.optim.swa_utils.get_ema_multi_avg_fn` returns a function that applies the following EMA equation to the weights: + +```{math} +W^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t +``` + +where alpha is the EMA decay. + +Here the model `model` can be an arbitrary {class}`torch.nn.Module` object. `averaged_model` +will keep track of the running averages of the parameters of the `model`. To update these +averages, you should use the {func}`update_parameters` function after the `optimizer.step()`: + +```python +>>> averaged_model.update_parameters(model) +``` + +For SWA and EMA, this call is usually done right after the optimizer `step()`. In the case of SWA, this is usually skipped for some numbers of steps at the beginning of the training. + +### Custom averaging strategies + +By default, {class}`torch.optim.swa_utils.AveragedModel` computes a running equal average of +the parameters that you provide, but you can also use custom averaging functions with the +`avg_fn` or `multi_avg_fn` parameters: + +- `avg_fn` allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter. +- `multi_avg_fn` allows defining more efficient operations acting on a tuple of parameter lists, (averaged parameter list, model parameter list), at the same time, for example using the `torch._foreach*` functions. This function must update the averaged parameters in-place. + +In the following example `ema_model` computes an exponential moving average using the `avg_fn` parameter: + +```python +>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\ +>>> 0.9 * averaged_model_parameter + 0.1 * model_parameter +>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg) +``` + +In the following example `ema_model` computes an exponential moving average using the more efficient `multi_avg_fn` parameter: + +```python +>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9)) +``` + +### SWA learning rate schedules + +Typically, in SWA the learning rate is set to a high constant value. {class}`SWALR` is a +learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it +constant. For example, the following code creates a scheduler that linearly anneals the +learning rate from its initial value to 0.05 in 5 epochs within each parameter group: + +```python +>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \ +>>> anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) +``` + +You can also use cosine annealing to a fixed value instead of linear annealing by setting +`anneal_strategy="cos"`. + + +### Taking care of batch normalization + +{func}`update_bn` is a utility function that allows to compute the batchnorm statistics for the SWA model +on a given dataloader `loader` at the end of training: + +```python +>>> torch.optim.swa_utils.update_bn(loader, swa_model) +``` + +{func}`update_bn` applies the `swa_model` to every element in the dataloader and computes the activation +statistics for each batch normalization layer in the model. + +```{warning} +{func}`update_bn` assumes that each batch in the dataloader `loader` is either a tensors or a list of +tensors where the first element is the tensor that the network `swa_model` should be applied to. +If your dataloader has a different structure, you can update the batch normalization statistics of the +`swa_model` by doing a forward pass with the `swa_model` on each element of the dataset. +``` + + + +### Putting it all together: SWA + +In the example below, `swa_model` is the SWA model that accumulates the averages of the weights. +We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule +and start to collect SWA averages of the parameters at epoch 160: + +```python +>>> loader, optimizer, model, loss_fn = ... +>>> swa_model = torch.optim.swa_utils.AveragedModel(model) +>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) +>>> swa_start = 160 +>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) +>>> +>>> for epoch in range(300): +>>> for input, target in loader: +>>> optimizer.zero_grad() +>>> loss_fn(model(input), target).backward() +>>> optimizer.step() +>>> if epoch > swa_start: +>>> swa_model.update_parameters(model) +>>> swa_scheduler.step() +>>> else: +>>> scheduler.step() +>>> +>>> # Update bn statistics for the swa_model at the end +>>> torch.optim.swa_utils.update_bn(loader, swa_model) +>>> # Use swa_model to make predictions on test data +>>> preds = swa_model(test_input) +``` + +### Putting it all together: EMA + +In the example below, `ema_model` is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999. +We train the model for a total of 300 epochs and start to collect EMA averages immediately. + +```python +>>> loader, optimizer, model, loss_fn = ... +>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \ +>>> multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)) +>>> +>>> for epoch in range(300): +>>> for input, target in loader: +>>> optimizer.zero_grad() +>>> loss_fn(model(input), target).backward() +>>> optimizer.step() +>>> ema_model.update_parameters(model) +>>> +>>> # Update bn statistics for the ema_model at the end +>>> torch.optim.swa_utils.update_bn(loader, ema_model) +>>> # Use ema_model to make predictions on test data +>>> preds = ema_model(test_input) +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + swa_utils.AveragedModel + swa_utils.SWALR + + +.. autofunction:: torch.optim.swa_utils.get_ema_multi_avg_fn +.. autofunction:: torch.optim.swa_utils.update_bn +``` + + +```{eval-rst} +.. py:module:: torch.optim.adadelta +.. py:module:: torch.optim.adagrad +.. py:module:: torch.optim.adam +.. py:module:: torch.optim.adamax +.. py:module:: torch.optim.adamw +.. py:module:: torch.optim.asgd +.. py:module:: torch.optim.lbfgs +.. py:module:: torch.optim.lr_scheduler +.. py:module:: torch.optim.nadam +.. py:module:: torch.optim.optimizer +.. py:module:: torch.optim.radam +.. py:module:: torch.optim.rmsprop +.. py:module:: torch.optim.rprop +.. py:module:: torch.optim.sgd +.. py:module:: torch.optim.sparse_adam +.. py:module:: torch.optim.swa_utils +``` diff --git a/docs/source/optim.rst b/docs/source/optim.rst deleted file mode 100644 index a5ae21b83580ce..00000000000000 --- a/docs/source/optim.rst +++ /dev/null @@ -1,677 +0,0 @@ -torch.optim -=================================== - -.. automodule:: torch.optim - -How to use an optimizer ------------------------ - -To use :mod:`torch.optim` you have to construct an optimizer object that will hold -the current state and will update the parameters based on the computed gradients. - -Constructing it -^^^^^^^^^^^^^^^ - -To construct an :class:`Optimizer` you have to give it an iterable containing the -parameters (all should be :class:`~torch.nn.Parameter` s) or named parameters -(tuples of (str, :class:`~torch.nn.Parameter`)) to optimize. Then, -you can specify optimizer-specific options such as the learning rate, weight decay, etc. - -Example:: - - optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - optimizer = optim.Adam([var1, var2], lr=0.0001) - -Named parameters example:: - - optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) - optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001) - -Per-parameter options -^^^^^^^^^^^^^^^^^^^^^ - -:class:`Optimizer` s also support specifying per-parameter options. To do this, instead -of passing an iterable of :class:`~torch.autograd.Variable` s, pass in an iterable of -:class:`dict` s. Each of them will define a separate parameter group, and should contain -a ``params`` key, containing a list of parameters belonging to it. Other keys -should match the keyword arguments accepted by the optimizers, and will be used -as optimization options for this group. - -For example, this is very useful when one wants to specify per-layer learning rates:: - - optim.SGD([ - {'params': model.base.parameters(), 'lr': 1e-2}, - {'params': model.classifier.parameters()} - ], lr=1e-3, momentum=0.9) - - optim.SGD([ - {'params': model.base.named_parameters(), 'lr': 1e-2}, - {'params': model.classifier.named_parameters()} - ], lr=1e-3, momentum=0.9) - -This means that ``model.base``'s parameters will use a learning rate of ``1e-2``, whereas -``model.classifier``'s parameters will stick to the default learning rate of ``1e-3``. -Finally a momentum of ``0.9`` will be used for all parameters. - -.. note:: - - You can still pass options as keyword arguments. They will be used as - defaults, in the groups that didn't override them. This is useful when you - only want to vary a single option, while keeping all others consistent - between parameter groups. - -Also consider the following example related to the distinct penalization of parameters. -Remember that :func:`~torch.nn.Module.parameters` returns an iterable that -contains all learnable parameters, including biases and other -parameters that may prefer distinct penalization. To address this, one can specify -individual penalization weights for each parameter group:: - - bias_params = [p for name, p in self.named_parameters() if 'bias' in name] - others = [p for name, p in self.named_parameters() if 'bias' not in name] - - optim.SGD([ - {'params': others}, - {'params': bias_params, 'weight_decay': 0} - ], weight_decay=1e-2, lr=1e-2) - -In this manner, bias terms are isolated from non-bias terms, and a ``weight_decay`` -of ``0`` is set specifically for the bias terms, as to avoid any penalization for -this group. - - -Taking an optimization step -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -All optimizers implement a :func:`~Optimizer.step` method, that updates the -parameters. It can be used in two ways: - -``optimizer.step()`` -~~~~~~~~~~~~~~~~~~~~ - -This is a simplified version supported by most optimizers. The function can be -called once the gradients are computed using e.g. -:func:`~torch.autograd.Variable.backward`. - -Example:: - - for input, target in dataset: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - loss.backward() - optimizer.step() - -``optimizer.step(closure)`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Some optimization algorithms such as Conjugate Gradient and LBFGS need to -reevaluate the function multiple times, so you have to pass in a closure that -allows them to recompute your model. The closure should clear the gradients, -compute the loss, and return it. - -Example:: - - for input, target in dataset: - def closure(): - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - loss.backward() - return loss - optimizer.step(closure) - -.. _optimizer-algorithms: - -Base class ----------- - -.. autoclass:: Optimizer - -.. autosummary:: - :toctree: generated - :nosignatures: - - Optimizer.add_param_group - Optimizer.load_state_dict - Optimizer.register_load_state_dict_pre_hook - Optimizer.register_load_state_dict_post_hook - Optimizer.state_dict - Optimizer.register_state_dict_pre_hook - Optimizer.register_state_dict_post_hook - Optimizer.step - Optimizer.register_step_pre_hook - Optimizer.register_step_post_hook - Optimizer.zero_grad - -Algorithms ----------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - Adadelta - Adafactor - Adagrad - Adam - AdamW - SparseAdam - Adamax - ASGD - LBFGS - NAdam - RAdam - RMSprop - Rprop - SGD - -Many of our algorithms have various implementations optimized for performance, -readability and/or generality, so we attempt to default to the generally fastest -implementation for the current device if no particular implementation has been -specified by the user. - -We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and -fused. The most straightforward implementations are for-loops over the parameters with -big chunks of computation. For-looping is usually slower than our foreach -implementations, which combine parameters into a multi-tensor and run the big chunks -of computation all at once, thereby saving many sequential kernel calls. A few of our -optimizers have even faster fused implementations, which fuse the big chunks of -computation into one kernel. We can think of foreach implementations as fusing -horizontally and fused implementations as fusing vertically on top of that. - -In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. -So when applicable, we default to foreach over for-loop. Applicable means the foreach -implementation is available, the user has not specified any implementation-specific kwargs -(e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused -should be even faster than foreach, the implementations are newer and we would like to give -them more bake-in time before flipping the switch everywhere. We summarize the stability status -for each implementation on the second table below, you are welcome to try them out though! - -Below is a table showing the available and default implementations of each algorithm: - -.. csv-table:: - :header: "Algorithm", "Default", "Has foreach?", "Has fused?" - :widths: 25, 25, 25, 25 - :delim: ; - - :class:`Adadelta`;foreach;yes;no - :class:`Adafactor`;for-loop;no;no - :class:`Adagrad`;foreach;yes;yes (cpu only) - :class:`Adam`;foreach;yes;yes - :class:`AdamW`;foreach;yes;yes - :class:`SparseAdam`;for-loop;no;no - :class:`Adamax`;foreach;yes;no - :class:`ASGD`;foreach;yes;no - :class:`LBFGS`;for-loop;no;no - :class:`NAdam`;foreach;yes;no - :class:`RAdam`;foreach;yes;no - :class:`RMSprop`;foreach;yes;no - :class:`Rprop`;foreach;yes;no - :class:`SGD`;foreach;yes;yes - -Below table is showing the stability status for fused implementations: - -.. csv-table:: - :header: "Algorithm", "CPU", "CUDA", "MPS" - :widths: 25, 25, 25, 25 - :delim: ; - - :class:`Adadelta`;unsupported;unsupported;unsupported - :class:`Adafactor`;unsupported;unsupported;unsupported - :class:`Adagrad`;beta;unsupported;unsupported - :class:`Adam`;beta;stable;beta - :class:`AdamW`;beta;stable;beta - :class:`SparseAdam`;unsupported;unsupported;unsupported - :class:`Adamax`;unsupported;unsupported;unsupported - :class:`ASGD`;unsupported;unsupported;unsupported - :class:`LBFGS`;unsupported;unsupported;unsupported - :class:`NAdam`;unsupported;unsupported;unsupported - :class:`RAdam`;unsupported;unsupported;unsupported - :class:`RMSprop`;unsupported;unsupported;unsupported - :class:`Rprop`;unsupported;unsupported;unsupported - :class:`SGD`;beta;beta;beta - -How to adjust learning rate ---------------------------- - -:class:`torch.optim.lr_scheduler.LRScheduler` provides several methods to adjust the learning -rate based on the number of epochs. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` -allows dynamic learning rate reducing based on some validation measurements. - -Learning rate scheduling should be applied after optimizer's update; e.g., you -should write your code this way: - -Example:: - - optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - scheduler = ExponentialLR(optimizer, gamma=0.9) - - for epoch in range(20): - for input, target in dataset: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - loss.backward() - optimizer.step() - scheduler.step() - -Most learning rate schedulers can be called back-to-back (also referred to as -chaining schedulers). The result is that each scheduler is applied one after the -other on the learning rate obtained by the one preceding it. - -Example:: - - optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - scheduler1 = ExponentialLR(optimizer, gamma=0.9) - scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) - - for epoch in range(20): - for input, target in dataset: - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - loss.backward() - optimizer.step() - scheduler1.step() - scheduler2.step() - -In many places in the documentation, we will use the following template to refer to schedulers -algorithms. - - >>> scheduler = ... - >>> for epoch in range(100): - >>> train(...) - >>> validate(...) - >>> scheduler.step() - -.. warning:: - Prior to PyTorch 1.1.0, the learning rate scheduler was expected to be called before - the optimizer's update; 1.1.0 changed this behavior in a BC-breaking way. If you use - the learning rate scheduler (calling ``scheduler.step()``) before the optimizer's update - (calling ``optimizer.step()``), this will skip the first value of the learning rate schedule. - If you are unable to reproduce results after upgrading to PyTorch 1.1.0, please check - if you are calling ``scheduler.step()`` at the wrong time. - - -.. autosummary:: - :toctree: generated - :nosignatures: - - lr_scheduler.LRScheduler - lr_scheduler.LambdaLR - lr_scheduler.MultiplicativeLR - lr_scheduler.StepLR - lr_scheduler.MultiStepLR - lr_scheduler.ConstantLR - lr_scheduler.LinearLR - lr_scheduler.ExponentialLR - lr_scheduler.PolynomialLR - lr_scheduler.CosineAnnealingLR - lr_scheduler.ChainedScheduler - lr_scheduler.SequentialLR - lr_scheduler.ReduceLROnPlateau - lr_scheduler.CyclicLR - lr_scheduler.OneCycleLR - lr_scheduler.CosineAnnealingWarmRestarts - -How to utilize named parameters to load optimizer state dict ------------------------------------------------------------- - -The function :func:`~Optimizer.load_state_dict` stores the optional ``param_names`` content from the -loaded state dict if present. However, the process of loading the optimizer state is not affected, -as the order of the parameters matters to maintain compatibility (in case of different ordering). -To utilize the loaded parameters names from the loaded state dict, a custom ``register_load_state_dict_pre_hook`` -needs to be implemented according to the desired behavior. - -This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to -remain unchanged. The following example demonstrates how to implement this customization. - -Example:: - - class OneLayerModel(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(3, 4) - - def forward(self, x): - return self.fc(x) - - model = OneLayerModel() - optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) - # training.. - torch.save(optimizer.state_dict(), PATH) - -Let's say that ``model`` implements an expert (MoE), and we want to duplicate it and resume training -for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` we create two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly):: - - class TwoLayerModel(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(3, 4) - self.fc2 = nn.Linear(3, 4) - - def forward(self, x): - return (self.fc1(x) + self.fc2(x)) / 2 - - model2 = TwoLayerModel() - # adapt and load model weights.. - optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) - -To load the state dict for ``optimizer2`` with the state dict of the previous optimizer such that both -``fc1`` and ``fc2`` will be initialized with a copy of ``fc`` optimizer states -(to resume training for each layer from ``fc``), we can use the following hook:: - - def adapt_state_dict_ids(optimizer, state_dict): - adapted_state_dict = deepcopy(optimizer.state_dict()) - # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. - for k, v in state_dict['param_groups'][0].items(): - if k not in ['params', 'param_names']: - adapted_state_dict['param_groups'][0][k] = v - - lookup_dict = { - 'fc1.weight': 'fc.weight', - 'fc1.bias': 'fc.bias', - 'fc2.weight': 'fc.weight', - 'fc2.bias': 'fc.bias' - } - clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} - for param_id, param_name in zip( - optimizer.state_dict()['param_groups'][0]['params'], - optimizer.state_dict()['param_groups'][0]['param_names']): - name_in_loaded = lookup_dict[param_name] - index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) - id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] - # Copy the state of the corresponding parameter - if id_in_loaded in state_dict['state']: - adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) - - return adapted_state_dict - - optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) - optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict - -This ensures that the adapted state_dict with the correct states for the layers of ``model2`` will be used -during model loading. -Note that this code is designed specifically for this example (e.g., assuming a single parameter group), -and other cases might require different adaptations. - -The following example shows how to handle missing parameters in a loaded -``state dict`` when the model structure changes. -The ``Model_bypass`` adds a new ``bypass`` layer, which is not present in the original ``Model1``. -To resume training, a custom ``adapt_state_dict_missing_param`` hook is used to adapt the optimizer's ``state_dict``, -ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged -(as initialized in this example). -This approach enables smooth loading and resuming of the optimizer state despite model changes. -The new bypass layer will be trained from scratch:: - - class Model1(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(5, 5) - - def forward(self, x): - return self.fc(x) + x - - - model = Model1() - optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) - # training.. - torch.save(optimizer.state_dict(), PATH) - - class Model_bypass(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(5, 5) - self.bypass = nn.Linear(5, 5, bias=False) - torch.nn.init.eye_(self.bypass.weight) - - def forward(self, x): - return self.fc(x) + self.bypass(x) - - model2 = Model_bypass() - optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) - - def adapt_state_dict_missing_param(optimizer, state_dict): - adapted_state_dict = deepcopy(optimizer.state_dict()) - # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. - for k, v in state_dict['param_groups'][0].items(): - if k not in ['params', 'param_names']: - adapted_state_dict['param_groups'][0][k] = v - - lookup_dict = { - 'fc.weight': 'fc.weight', - 'fc.bias': 'fc.bias', - 'bypass.weight': None, - } - - clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} - for param_id, param_name in zip( - optimizer.state_dict()['param_groups'][0]['params'], - optimizer.state_dict()['param_groups'][0]['param_names']): - name_in_loaded = lookup_dict[param_name] - if name_in_loaded in state_dict['param_groups'][0]['param_names']: - index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) - id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] - # Copy the state of the corresponding parameter - if id_in_loaded in state_dict['state']: - adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) - - return adapted_state_dict - - optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) - optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict - - - -As a third example, instead of loading a state according to the order of parameters (the default approach), -this hook can be used to load according to the parameters' names:: - - def names_matching(optimizer, state_dict): - assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups']) - adapted_state_dict = deepcopy(optimizer.state_dict()) - for g_ind in range(len(state_dict['param_groups'])): - assert len(state_dict['param_groups'][g_ind]['params']) == len( - optimizer.state_dict()['param_groups'][g_ind]['params']) - - for k, v in state_dict['param_groups'][g_ind].items(): - if k not in ['params', 'param_names']: - adapted_state_dict['param_groups'][g_ind][k] = v - - for param_id, param_name in zip( - optimizer.state_dict()['param_groups'][g_ind]['params'], - optimizer.state_dict()['param_groups'][g_ind]['param_names']): - index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name) - id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list] - # Copy the state of the corresponding parameter - if id_in_loaded in state_dict['state']: - adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded]) - - return adapted_state_dict - - - -Weight Averaging (SWA and EMA) ------------------------------- - -:class:`torch.optim.swa_utils.AveragedModel` implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), -:class:`torch.optim.swa_utils.SWALR` implements the SWA learning rate scheduler and -:func:`torch.optim.swa_utils.update_bn` is a utility function used to update SWA/EMA batch -normalization statistics at the end of training. - -SWA has been proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_. - -EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. It is a variation of `Polyak averaging`_, but using exponential weights instead of equal weights across iterations. - -.. _`Averaging Weights Leads to Wider Optima and Better Generalization`: https://arxiv.org/abs/1803.05407 - -.. _`Polyak averaging`: https://paperswithcode.com/method/polyak-averaging - -Constructing averaged models -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The `AveragedModel` class serves to compute the weights of the SWA or EMA model. - -You can create an SWA averaged model by running: - ->>> averaged_model = AveragedModel(model) - -EMA models are constructed by specifying the ``multi_avg_fn`` argument as follows: - ->>> decay = 0.999 ->>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay)) - -Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues. - -:func:`torch.optim.swa_utils.get_ema_multi_avg_fn` returns a function that applies the following EMA equation to the weights: - -.. math:: W^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t - -where alpha is the EMA decay. - -Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``averaged_model`` -will keep track of the running averages of the parameters of the ``model``. To update these -averages, you should use the :func:`update_parameters` function after the `optimizer.step()`: - ->>> averaged_model.update_parameters(model) - -For SWA and EMA, this call is usually done right after the optimizer ``step()``. In the case of SWA, this is usually skipped for some numbers of steps at the beginning of the training. - -Custom averaging strategies -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of -the parameters that you provide, but you can also use custom averaging functions with the -``avg_fn`` or ``multi_avg_fn`` parameters: - -- ``avg_fn`` allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter. -- ``multi_avg_fn`` allows defining more efficient operations acting on a tuple of parameter lists, (averaged parameter list, model parameter list), at the same time, for example using the ``torch._foreach*`` functions. This function must update the averaged parameters in-place. - -In the following example ``ema_model`` computes an exponential moving average using the ``avg_fn`` parameter: - ->>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\ ->>> 0.9 * averaged_model_parameter + 0.1 * model_parameter ->>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg) - - -In the following example ``ema_model`` computes an exponential moving average using the more efficient ``multi_avg_fn`` parameter: - ->>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9)) - - -SWA learning rate schedules -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Typically, in SWA the learning rate is set to a high constant value. :class:`SWALR` is a -learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it -constant. For example, the following code creates a scheduler that linearly anneals the -learning rate from its initial value to 0.05 in 5 epochs within each parameter group: - ->>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \ ->>> anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) - -You can also use cosine annealing to a fixed value instead of linear annealing by setting -``anneal_strategy="cos"``. - - -Taking care of batch normalization -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -:func:`update_bn` is a utility function that allows to compute the batchnorm statistics for the SWA model -on a given dataloader ``loader`` at the end of training: - ->>> torch.optim.swa_utils.update_bn(loader, swa_model) - -:func:`update_bn` applies the ``swa_model`` to every element in the dataloader and computes the activation -statistics for each batch normalization layer in the model. - -.. warning:: - :func:`update_bn` assumes that each batch in the dataloader ``loader`` is either a tensors or a list of - tensors where the first element is the tensor that the network ``swa_model`` should be applied to. - If your dataloader has a different structure, you can update the batch normalization statistics of the - ``swa_model`` by doing a forward pass with the ``swa_model`` on each element of the dataset. - - - - -Putting it all together: SWA -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the example below, ``swa_model`` is the SWA model that accumulates the averages of the weights. -We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule -and start to collect SWA averages of the parameters at epoch 160: - ->>> loader, optimizer, model, loss_fn = ... ->>> swa_model = torch.optim.swa_utils.AveragedModel(model) ->>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) ->>> swa_start = 160 ->>> swa_scheduler = SWALR(optimizer, swa_lr=0.05) ->>> ->>> for epoch in range(300): ->>> for input, target in loader: ->>> optimizer.zero_grad() ->>> loss_fn(model(input), target).backward() ->>> optimizer.step() ->>> if epoch > swa_start: ->>> swa_model.update_parameters(model) ->>> swa_scheduler.step() ->>> else: ->>> scheduler.step() ->>> ->>> # Update bn statistics for the swa_model at the end ->>> torch.optim.swa_utils.update_bn(loader, swa_model) ->>> # Use swa_model to make predictions on test data ->>> preds = swa_model(test_input) - - -Putting it all together: EMA -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the example below, ``ema_model`` is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999. -We train the model for a total of 300 epochs and start to collect EMA averages immediately. - ->>> loader, optimizer, model, loss_fn = ... ->>> ema_model = torch.optim.swa_utils.AveragedModel(model, \ ->>> multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999)) ->>> ->>> for epoch in range(300): ->>> for input, target in loader: ->>> optimizer.zero_grad() ->>> loss_fn(model(input), target).backward() ->>> optimizer.step() ->>> ema_model.update_parameters(model) ->>> ->>> # Update bn statistics for the ema_model at the end ->>> torch.optim.swa_utils.update_bn(loader, ema_model) ->>> # Use ema_model to make predictions on test data ->>> preds = ema_model(test_input) - -.. autosummary:: - :toctree: generated - :nosignatures: - - swa_utils.AveragedModel - swa_utils.SWALR - - -.. autofunction:: torch.optim.swa_utils.get_ema_multi_avg_fn -.. autofunction:: torch.optim.swa_utils.update_bn - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.optim.adadelta -.. py:module:: torch.optim.adagrad -.. py:module:: torch.optim.adam -.. py:module:: torch.optim.adamax -.. py:module:: torch.optim.adamw -.. py:module:: torch.optim.asgd -.. py:module:: torch.optim.lbfgs -.. py:module:: torch.optim.lr_scheduler -.. py:module:: torch.optim.nadam -.. py:module:: torch.optim.optimizer -.. py:module:: torch.optim.radam -.. py:module:: torch.optim.rmsprop -.. py:module:: torch.optim.rprop -.. py:module:: torch.optim.sgd -.. py:module:: torch.optim.sparse_adam -.. py:module:: torch.optim.swa_utils diff --git a/docs/source/package.md b/docs/source/package.md new file mode 100644 index 00000000000000..e337fedde3e6bb --- /dev/null +++ b/docs/source/package.md @@ -0,0 +1,756 @@ +```{eval-rst} +.. automodule:: torch.package +.. py:module:: torch.package.analyze + +.. currentmodule:: torch.package +``` + +# torch.package +`torch.package` adds support for creating packages containing both artifacts and arbitrary +PyTorch code. These packages can be saved, shared, used to load and execute models +at a later date or on a different machine, and can even be deployed to production using +`torch::deploy`. + +This document contains tutorials, how-to guides, explanations, and an API reference that +will help you learn more about `torch.package` and how to use it. + +```{warning} +This module depends on the `pickle` module which is not secure. Only unpackage data you trust. + +It is possible to construct malicious pickle data which will **execute arbitrary code during unpickling**. +Never unpackage data that could have come from an untrusted source, or that could have been tampered with. + +For more information, review the [documentation](https://docs.python.org/3/library/pickle.html) for the `pickle` module. +``` + +```{contents} +:local: +:depth: 2 +``` + +## Tutorials +### Packaging your first model +A tutorial that guides you through packaging and unpackaging a simple model is available +[on Colab](https://colab.research.google.com/drive/1lFZkLyViGfXxB-m3jqlyTQuYToo3XLo-). +After completing this exercise, you will be familiar with the basic API for creating and using +Torch packages. + +## How do I... + +### See what is inside a package? + +#### Treat the package like a ZIP archive + +The container format for a `torch.package` is ZIP, so any tools that work with standard ZIP files should +work for exploring the contents. Some common ways to interact with ZIP files: + +* `unzip my_package.pt` will unzip the `torch.package` archive to disk, where you can freely inspect its contents. + +``` +$ unzip my_package.pt && tree my_package +my_package +├── .data +│ ├── 94304870911616.storage +│ ├── 94304900784016.storage +│ ├── extern_modules +│ └── version +├── models +│ └── model_1.pkl +└── torchvision + └── models + ├── resnet.py + └── utils.py +~ cd my_package && cat torchvision/models/resnet.py +... +``` + +* The Python `zipfile` module provides a standard way to read and write ZIP archive contents. + +```python +from zipfile import ZipFile +with ZipFile("my_package.pt") as myzip: + file_bytes = myzip.read("torchvision/models/resnet.py") + # edit file_bytes in some way + myzip.writestr("torchvision/models/resnet.py", new_file_bytes) +``` + +* vim has the ability to natively read ZIP archives. You can even edit files and :`write` them back into the archive! + +```vim +# add this to your .vimrc to treat `*.pt` files as zip files +au BufReadCmd *.pt call zip#Browse(expand("")) + +~ vi my_package.pt +``` + +#### Use the `file_structure()` API +{class}`PackageImporter` provides a `file_structure()` method, which will return a printable +and queryable {class}`Directory` object. The {class}`Directory` object is a simple directory structure that you can use to explore the +current contents of a `torch.package`. + +The {class}`Directory` object itself is directly printable and will print out a file tree representation. To filter what is returned, +use the glob-style `include` and `exclude` filtering arguments. + +```python +with PackageExporter('my_package.pt') as pe: + pe.save_pickle('models', 'model_1.pkl', mod) + +importer = PackageImporter('my_package.pt') +# can limit printed items with include/exclude args +print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage")) +print(importer.file_structure()) # will print out all files +``` + +Output: + +``` +# filtered with glob pattern: +# include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage" +─── my_package.pt + ├── models + │ └── model_1.pkl + └── torchvision + └── models + └── utils.py + +# all files +─── my_package.pt + ├── .data + │ ├── 94304870911616.storage + │ ├── 94304900784016.storage + │ ├── extern_modules + │ └── version + ├── models + │ └── model_1.pkl + └── torchvision + └── models + ├── resnet.py + └── utils.py +``` + +You can also query {class}`Directory` objects with the `has_file()` method. + +```python +importer_file_structure = importer.file_structure() +found: bool = importer_file_structure.has_file("package_a/subpackage.py") +``` + +### See why a given module was included as a dependency? + +Say there is a given module `foo`, and you want to know why your {class}`PackageExporter` is pulling in `foo` as a dependency. + +{meth}`PackageExporter.get_rdeps` will return all modules that directly depend on `foo`. + +If you would like to see how a given module `src` depends on `foo`, the {meth}`PackageExporter.all_paths` method will +return a DOT-formatted graph showing all the dependency paths between `src` and `foo`. + +If you would just like to see the whole dependency graph of your :class:`PackageExporter`, you can use {meth}`PackageExporter.dependency_graph_string`. + + +### Include arbitrary resources with my package and access them later? +{class}`PackageExporter` exposes three methods, `save_pickle`, `save_text` and `save_binary` that allow you to save +Python objects, text, and binary data to a package. + +```python +with torch.PackageExporter("package.pt") as exporter: + # Pickles the object and saves to `my_resources/tensor.pkl` in the archive. + exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4)) + exporter.save_text("config_stuff", "words.txt", "a sample string") + exporter.save_binary("raw_data", "binary", my_bytes) + +``` +{class}`PackageImporter` exposes complementary methods named `load_pickle`, `load_text` and `load_binary` that allow you to load +Python objects, text and binary data from a package. + +```python +importer = torch.PackageImporter("package.pt") +my_tensor = importer.load_pickle("my_resources", "tensor.pkl") +text = importer.load_text("config_stuff", "words.txt") +binary = importer.load_binary("raw_data", "binary") +``` + +### Customize how a class is packaged? +`torch.package` allows for the customization of how classes are packaged. This behavior is accessed through defining the method +`__reduce_package__` on a class and by defining a corresponding de-packaging function. This is similar to defining `__reduce__` for +Python’s normal pickling process. + +Steps: + +1. Define the method `__reduce_package__(self, exporter: PackageExporter)` on the target class. This method should do the work to save the class instance inside of the package, and should return a tuple of the corresponding de-packaging function with the arguments needed to invoke the de-packaging function. This method is called by the `PackageExporter` when it encounters an instance of the target class. +2. Define a de-packaging function for the class. This de-packaging function should do the work to reconstruct and return an instance of the class. The function signature’s first parameter should be a `PackageImporter` instance, and the rest of the parameters are user defined. + + +```python +# foo.py [Example of customizing how class Foo is packaged] +from torch.package import PackageExporter, PackageImporter +import time + + +class Foo: + def __init__(self, my_string: str): + super().__init__() + self.my_string = my_string + self.time_imported = 0 + self.time_exported = 0 + + def __reduce_package__(self, exporter: PackageExporter): + """ + Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when + saving an instance of this object. This method should do the work to save this + object inside of the ``torch.package`` archive. + + Returns function w/ arguments to load the object from a + ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function. + """ + + # use this pattern to ensure no naming conflicts with normal dependencies, + # anything saved under this module name shouldn't conflict with other + # items in the package + generated_module_name = f"foo-generated._{exporter.get_unique_id()}" + exporter.save_text( + generated_module_name, + "foo.txt", + self.my_string + ", with exporter modification!", + ) + time_exported = time.clock_gettime(1) + + # returns de-packaging function w/ arguments to invoke with + return (unpackage_foo, (generated_module_name, time_exported,)) + + +def unpackage_foo( + importer: PackageImporter, generated_module_name: str, time_exported: float +) -> Foo: + """ + Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function + when depickling a Foo object. + Performs work of loading and returning a Foo instance from a ``torch.package`` archive. + """ + time_imported = time.clock_gettime(1) + foo = Foo(importer.load_text(generated_module_name, "foo.txt")) + foo.time_imported = time_imported + foo.time_exported = time_exported + return foo + +``` + + +```python +# example of saving instances of class Foo + +import torch +from torch.package import PackageImporter, PackageExporter +import foo + +foo_1 = foo.Foo("foo_1 initial string") +foo_2 = foo.Foo("foo_2 initial string") +with PackageExporter('foo_package.pt') as pe: + # save as normal, no extra work necessary + pe.save_pickle('foo_collection', 'foo1.pkl', foo_1) + pe.save_pickle('foo_collection', 'foo2.pkl', foo_2) + +pi = PackageImporter('foo_package.pt') +print(pi.file_structure()) +imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl') +print(f"foo_1 string: '{imported_foo.my_string}'") +print(f"foo_1 export time: {imported_foo.time_exported}") +print(f"foo_1 import time: {imported_foo.time_imported}") +``` + +``` +# output of running above script +─── foo_package + ├── foo-generated + │ ├── _0 + │ │ └── foo.txt + │ └── _1 + │ └── foo.txt + ├── foo_collection + │ ├── foo1.pkl + │ └── foo2.pkl + └── foo.py + +foo_1 string: 'foo_1 initial string, with reduction modification!' +foo_1 export time: 9857706.650140837 +foo_1 import time: 9857706.652698385 +``` + +### Test in my source code whether or not it is executing inside a package? + +A {class}`PackageImporter` will add the attribute `__torch_package__` to every module that it initializes. Your code can check for the +presence of this attribute to determine whether it is executing in a packaged context or not. + +```python +# In foo/bar.py: + +if "__torch_package__" in dir(): # true if the code is being loaded from a package + def is_in_package(): + return True + + UserException = Exception +else: + def is_in_package(): + return False + + UserException = UnpackageableException +``` + +Now, the code will behave differently depending on whether it’s imported normally through your Python environment or imported from a +`torch.package`. + +```python +from foo.bar import is_in_package + +print(is_in_package()) # False + +loaded_module = PackageImporter(my_package).import_module("foo.bar") +loaded_module.is_in_package() # True +``` + +**Warning**: in general, it’s bad practice to have code that behaves differently depending on whether it’s packaged or not. This can lead to +hard-to-debug issues that are sensitive to how you imported your code. If your package is intended to be heavily used, consider restructuring +your code so that it behaves the same way no matter how it was loaded. + + +### Patch code into a package? +{class}`PackageExporter` offers a `save_source_string()` method that allows one to save arbitrary Python source code to a module of your choosing. +```python +with PackageExporter(f) as exporter: + # Save the my_module.foo available in your current Python environment. + exporter.save_module("my_module.foo") + + # This saves the provided string to my_module/foo.py in the package archive. + # It will override the my_module.foo that was previously saved. + exporter.save_source_string("my_module.foo", textwrap.dedent( + """\ + def my_function(): + print('hello world') + """ + )) + + # If you want to treat my_module.bar as a package + # (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py) + # pass is_package=True, + exporter.save_source_string("my_module.bar", + "def foo(): print('hello')\n", + is_package=True) + +importer = PackageImporter(f) +importer.import_module("my_module.foo").my_function() # prints 'hello world' +``` + +### Access package contents from packaged code? +{class}`PackageImporter` implements the +[`importlib.resources`](https://docs.python.org/3/library/importlib.html#module-importlib.resources) +API for accessing resources from inside a package. + +```python +with PackageExporter(f) as exporter: + # saves text to my_resource/a.txt in the archive + exporter.save_text("my_resource", "a.txt", "hello world!") + # saves the tensor to my_pickle/obj.pkl + exporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2)) + + # see below for module contents + exporter.save_module("foo") + exporter.save_module("bar") +``` + +The `importlib.resources` API allows access to resources from within packaged code. + + +```python +# foo.py: +import importlib.resources +import my_resource + +# returns "hello world!" +def get_my_resource(): + return importlib.resources.read_text(my_resource, "a.txt") +``` + +Using `importlib.resources` is the recommended way to access package contents from within packaged code, since it complies +with the Python standard. However, it is also possible to access the parent :class:`PackageImporter` instance itself from within +packaged code. + +```python +# bar.py: +import torch_package_importer # this is the PackageImporter that imported this module. + +# Prints "hello world!", equivalent to importlib.resources.read_text +def get_my_resource(): + return torch_package_importer.load_text("my_resource", "a.txt") + +# You also do things that the importlib.resources API does not support, like loading +# a pickled object from the package. +def get_my_pickle(): + return torch_package_importer.load_pickle("my_pickle", "obj.pkl") +``` + +### Distinguish between packaged code and non-packaged code? +To tell if an object’s code is from a `torch.package`, use the `torch.package.is_from_package()` function. +Note: if an object is from a package but its definition is from a module marked `extern` or from `stdlib`, +this check will return `False`. + +```python +importer = PackageImporter(f) +mod = importer.import_module('foo') +obj = importer.load_pickle('model', 'model.pkl') +txt = importer.load_text('text', 'my_test.txt') + +assert is_from_package(mod) +assert is_from_package(obj) +assert not is_from_package(txt) # str is from stdlib, so this will return False +``` + +### Re-export an imported object? +To re-export an object that was previously imported by a {class}`PackageImporter`, you must make the new {class}`PackageExporter` +aware of the original {class}`PackageImporter` so that it can find source code for your object’s dependencies. + +```python +importer = PackageImporter(f) +obj = importer.load_pickle("model", "model.pkl") + +# re-export obj in a new package +with PackageExporter(f2, importer=(importer, sys_importer)) as exporter: + exporter.save_pickle("model", "model.pkl", obj) +``` + +### Package a TorchScript module? +To package a TorchScript model, use the same `save_pickle` and `load_pickle` APIs as you would with any other object. +Saving TorchScript objects that are attributes or submodules is supported as well with no extra work. + +```python +# save TorchScript just like any other object +with PackageExporter(file_name) as e: + e.save_pickle("res", "script_model.pkl", scripted_model) + e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule) +# load as normal +importer = PackageImporter(file_name) +loaded_script = importer.load_pickle("res", "script_model.pkl") +loaded_mixed = importer.load_pickle("res", "mixed_model.pkl" +``` + +## Explanation + +### `torch.package` Format Overview +A `torch.package` file is a ZIP archive which conventionally uses the `.pt` extension. Inside the ZIP archive, there are two kinds of files: + +* Framework files, which are placed in the `.data/`. +* User files, which is everything else. + +As an example, this is what a fully packaged ResNet model from `torchvision` looks like: + +``` +resnet +├── .data # All framework-specific data is stored here. +│ │ # It's named to avoid conflicts with user-serialized code. +│ ├── 94286146172688.storage # tensor data +│ ├── 94286146172784.storage +│ ├── extern_modules # text file with names of extern modules (e.g. 'torch') +│ ├── version # version metadata +│ ├── ... +├── model # the pickled model +│ └── model.pkl +└── torchvision # all code dependencies are captured as source files + └── models + ├── resnet.py + └── utils.py +``` + +#### Framework files +The `.data/` directory is owned by torch.package, and its contents are considered to be a private implementation detail. +The `torch.package` format makes no guarantees about the contents of `.data/`, but any changes made will be backward compatible +(that is, newer version of PyTorch will always be able to load older `torch.packages`). + +Currently, the `.data/` directory contains the following items: + +* `version`: a version number for the serialized format, so that the `torch.package` import infrastructures knows how to load this package. +* `extern_modules`: a list of modules that are considered `extern`. `extern` modules will be imported using the loading environment’s system importer. +* `*.storage`: serialized tensor data. + +``` +.data +├── 94286146172688.storage +├── 94286146172784.storage +├── extern_modules +├── version +├── ... +``` + +#### User files +All other files in the archive were put there by a user. The layout is identical to a Python +[regular package](https://docs.python.org/3/reference/import.html#regular-packages). For a deeper dive in how Python packaging works, +please consult [this essay](https://www.python.org/doc/essays/packages/) (it’s slightly out of date, so double-check implementation details +with the [Python reference documentation](https://docs.python.org/3/library/importlib.html). + +``` + +├── model # the pickled model +│ └── model.pkl +├── another_package +│ ├── __init__.py +│ ├── foo.txt # a resource file , see importlib.resources +│ └── ... +└── torchvision + └── models + ├── resnet.py # torchvision.models.resnet + └── utils.py # torchvision.models.utils +``` + +### How `torch.package` finds your code's dependencies +#### Analyzing an object's dependencies +When you issue a `save_pickle(obj, ...)` call, {class}`PackageExporter` will pickle the object normally. Then, it uses the +`pickletools` standard library module to parse the pickle bytecode. + +In a pickle, an object is saved along with a `GLOBAL` opcode that describes where to find the implementation of the object’s type, like: + +``` +GLOBAL 'torchvision.models.resnet Resnet` +``` + +The dependency resolver will gather up all `GLOBAL` ops and mark them as dependencies of your pickled object. +For more information about pickling and the pickle format, please consult [the Python docs](https://docs.python.org/3/library/pickle.html). + +#### Analyzing a module's dependencies +When a Python module is identified as a dependency, `torch.package` walks the module’s python AST representation and looks for import statements with +full support for the standard forms: `from x import y`, `import z`, `from w import v as u`, etc. When one of these import statements are +encountered, `torch.package` registers the imported modules as dependencies that are then themselves parsed in the same AST walking way. + +**Note**: AST parsing has limited support for the `__import__(...)` syntax and does not support `importlib.import_module` calls. In general, you should +not expect dynamic imports to be detected by `torch.package`. + + +### Dependency Management +`torch.package` automatically finds the Python modules that your code and objects depend on. This process is called dependency resolution. +For each module that the dependency resolver finds, you must specify an *action* to take. + +The allowed actions are: + +* `intern`: put this module into the package. +* `extern`: declare this module as an external dependency of the package. +* `mock`: stub out this module. +* `deny`: depending on this module will raise an error during package export. + +Finally, there is one more important action that is not technically part of `torch.package`: + +* Refactoring: remove or change the dependencies in your code. + +Note that actions are only defined on entire Python modules. There is no way to package “just” a function or class from a module and leave the rest out. +This is by design. Python does not offer clean boundaries between objects defined in a module. The only defined unit of dependency organization is a +module, so that’s what `torch.package` uses. + +Actions are applied to modules using patterns. Patterns can either be module names (`"foo.bar"`) or globs (like `"foo.**"`). You associate a pattern +with an action using methods on {class}`PackageExporter`, e.g. + +```python +my_exporter.intern("torchvision.**") +my_exporter.extern("numpy") +``` + +If a module matches a pattern, the corresponding action is applied to it. For a given module, patterns will be checked in the order that they were defined, +and the first action will be taken. + + +#### `intern` +If a module is `intern`-ed, it will be placed into the package. + +This action is your model code, or any related code you want to package. For example, if you are trying to package a ResNet from `torchvision`, +you will need to `intern` the module torchvision.models.resnet. + +On package import, when your packaged code tries to import an `intern`-ed module, PackageImporter will look inside your package for that module. +If it can’t find that module, an error will be raised. This ensures that each {class}`PackageImporter` is isolated from the loading environment—even +if you have `my_interned_module` available in both your package and the loading environment, {class}`PackageImporter` will only use the version in your +package. + +**Note**: Only Python source modules can be `intern`-ed. Other kinds of modules, like C extension modules and bytecode modules, will raise an error if +you attempt to `intern` them. These kinds of modules need to be `mock`-ed or `extern`-ed. + + +#### `extern` +If a module is `extern`-ed, it will not be packaged. Instead, it will be added to a list of external dependencies for this package. You can find this +list on `package_exporter.extern_modules`. + +On package import, when the packaged code tries to import an `extern`-ed module, {class}`PackageImporter` will use the default Python importer to find +that module, as if you did `importlib.import_module("my_externed_module")`. If it can’t find that module, an error will be raised. + +In this way, you can depend on third-party libraries like `numpy` and `scipy` from within your package without having to package them too. + +**Warning**: If any external library changes in a backwards-incompatible way, your package may fail to load. If you need long-term reproducibility +for your package, try to limit your use of `extern`. + + +#### `mock` +If a module is `mock`-ed, it will not be packaged. Instead a stub module will be packaged in its place. The stub module will allow you to retrieve +objects from it (so that `from my_mocked_module import foo` will not error), but any use of that object will raise a `NotImplementedError`. + +`mock` should be used for code that you “know” will not be needed in the loaded package, but you still want available for use in non-packaged contents. +For example, initialization/configuration code, or code only used for debugging/training. + +**Warning**: In general, `mock` should be used as a last resort. It introduces behavioral differences between packaged code and non-packaged code, +which may lead to later confusion. Prefer instead to refactor your code to remove unwanted dependencies. + + +#### Refactoring +The best way to manage dependencies is to not have dependencies at all! Often, code can be refactored to remove unnecessary dependencies. Here are some +guidelines for writing code with clean dependencies (which are also generally good practices!): + +**Include only what you use**. Do not leave unused imports in your code. The dependency resolver is not smart enough to tell that they are indeed unused, +and will try to process them. + +**Qualify your imports**. For example, instead of writing import foo and later using `foo.bar.baz`, prefer to write `from foo.bar import baz`. This more +precisely specifies your real dependency (`foo.bar`) and lets the dependency resolver know you don’t need all of `foo`. + +**Split up large files with unrelated functionality into smaller ones**. If your `utils` module contains a hodge-podge of unrelated functionality, any module +that depends on `utils` will need to pull in lots of unrelated dependencies, even if you only needed a small part of it. Prefer instead to define +single-purpose modules that can be packaged independently of one another. + + +#### Patterns +Patterns allow you to specify groups of modules with a convenient syntax. The syntax and behavior of patterns follows the Bazel/Buck +[glob()](https://docs.bazel.build/versions/master/be/functions.html#glob). + +A module that we are trying to match against a pattern is called a candidate. A candidate is composed of a list of segments separated by a +separator string, e.g. `foo.bar.baz`. + +A pattern contains one or more segments. Segments can be: + +* A literal string (e.g. `foo`), which matches exactly. +* A string containing a wildcard (e.g. `torch`, or `foo*baz*`). The wildcard matches any string, including the empty string. +* A double wildcard (`**`). This matches against zero or more complete segments. + +Examples: + +* `torch.**`: matches `torch` and all its submodules, e.g. `torch.nn` and `torch.nn.functional`. +* `torch.*`: matches `torch.nn` or `torch.functional`, but not `torch.nn.functional` or `torch` +* `torch*.**`: matches `torch`, `torchvision`, and all of their submodules + +When specifying actions, you can pass multiple patterns, e.g. + +```python +exporter.intern(["torchvision.models.**", "torchvision.utils.**"]) +``` + +A module will match against this action if it matches any of the patterns. + +You can also specify patterns to exclude, e.g. + +```python +exporter.mock("**", exclude=["torchvision.**"]) +``` + + +A module will not match against this action if it matches any of the exclude patterns. In this example, we are mocking all modules except +`torchvision` and its submodules. + +When a module could potentially match against multiple actions, the first action defined will be taken. + + +### `torch.package` sharp edges +#### Avoid global state in your modules +Python makes it really easy to bind objects and run code at module-level scope. This is generally fine—after all, functions and classes are bound to +names this way. However, things become more complicated when you define an object at module scope with the intention of mutating it, introducing mutable +global state. + +Mutable global state is quite useful—it can reduce boilerplate, allow for open registration into tables, etc. But unless employed very carefully, it can +cause complications when used with `torch.package`. + +Every {class}`PackageImporter` creates an independent environment for its contents. This is nice because it means we load multiple packages and ensure +they are isolated from each other, but when modules are written in a way that assumes shared mutable global state, this behavior can create hard-to-debug +errors. + +#### Types are not shared between packages and the loading environment +Any class that you import from a {class}`PackageImporter` will be a version of the class specific to that importer. For example: + + +```python +from foo import MyClass + +my_class_instance = MyClass() + +with PackageExporter(f) as exporter: + exporter.save_module("foo") + +importer = PackageImporter(f) +imported_MyClass = importer.import_module("foo").MyClass + +assert isinstance(my_class_instance, MyClass) # works +assert isinstance(my_class_instance, imported_MyClass) # ERROR! +``` + +In this example, `MyClass` and `imported_MyClass` are *not the same type*. In this specific example, `MyClass` and `imported_MyClass` have exactly the +same implementation, so you might think it’s okay to consider them the same class. But consider the situation where `imported_MyClass` is coming from an +older package with an entirely different implementation of `MyClass` — in that case, it’s unsafe to consider them the same class. + +Under the hood, each importer has a prefix that allows it to uniquely identify classes: + +```python +print(MyClass.__name__) # prints "foo.MyClass" +print(imported_MyClass.__name__) # prints .foo.MyClass +``` + +That means you should not expect `isinstance` checks to work when one of the arguments is from a package and the other is not. If you need this +functionality, consider the following options: + +* Doing duck typing (just using the class instead of explicitly checking that it is of a given type). +* Make the typing relationship an explicit part of the class contract. For example, you can add an attribute tag `self.handler = "handle_me_this_way"` and have client code check for the value of `handler` instead of checking the type directly. + + +### How `torch.package` keeps packages isolated from each other +Each {class}`PackageImporter` instance creates an independent, isolated environment for its modules and objects. Modules in a package can only import +other packaged modules, or modules marked `extern`. If you use multiple {class}`PackageImporter` instances to load a single package, you will get +multiple independent environments that do not interact. + +This is achieved by extending Python’s import infrastructure with a custom importer. {class}`PackageImporter` provides the same core API as the +`importlib` importer; namely, it implements the `import_module` and `__import__` methods. + +When you invoke {meth}`PackageImporter.import_module`, {class}`PackageImporter` will construct and return a new module, much as the system importer does. +However, {class}`PackageImporter` patches the returned module to use `self` (i.e. that {class}`PackageImporter` instance) to fulfill future import +requests by looking in the package rather than searching the user’s Python environment. + +#### Mangling +To avoid confusion (“is this `foo.bar` object the one from my package, or the one from my Python environment?”), {class}`PackageImporter` mangles the +`__name__` and `__file__` of all imported modules, by adding a *mangle prefix* to them. + +For `__name__`, a name like `torchvision.models.resnet18` becomes `.torchvision.models.resnet18`. + +For `__file__`, a name like `torchvision/models/resnet18.py` becomes `.torchvision/modules/resnet18.py`. + +Name mangling helps avoid inadvertent punning of module names between different packages, and helps you debug by making stack traces and print +statements more clearly show whether they are referring to packaged code or not. For developer-facing details about mangling, consult +`mangling.md` in `torch/package/`. + + +## API Reference +```{eval-rst} +.. autoclass:: torch.package.PackagingError + +.. autoclass:: torch.package.EmptyMatchError + +.. autoclass:: torch.package.PackageExporter + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.package.PackageImporter + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.package.Directory + :members: +``` + + +```{eval-rst} +.. py:module:: torch.package.analyze.find_first_use_of_broken_modules +.. py:module:: torch.package.analyze.is_from_package +.. py:module:: torch.package.analyze.trace_dependencies +.. py:module:: torch.package.file_structure_representation +.. py:module:: torch.package.find_file_dependencies +.. py:module:: torch.package.glob_group +.. py:module:: torch.package.importer +.. py:module:: torch.package.package_exporter +.. py:module:: torch.package.package_importer +``` diff --git a/docs/source/package.rst b/docs/source/package.rst deleted file mode 100644 index d8d6e3e28f1f48..00000000000000 --- a/docs/source/package.rst +++ /dev/null @@ -1,832 +0,0 @@ -.. automodule:: torch.package -.. py:module:: torch.package.analyze - -.. currentmodule:: torch.package - -torch.package -============= -``torch.package`` adds support for creating packages containing both artifacts and arbitrary -PyTorch code. These packages can be saved, shared, used to load and execute models -at a later date or on a different machine, and can even be deployed to production using -``torch::deploy``. - -This document contains tutorials, how-to guides, explanations, and an API reference that -will help you learn more about ``torch.package`` and how to use it. - - -.. warning:: - - This module depends on the ``pickle`` module which is not secure. Only unpackage data you trust. - - It is possible to construct malicious pickle data which will **execute arbitrary code during unpickling**. - Never unpackage data that could have come from an untrusted source, or that could have been tampered with. - - For more information, review the `documentation `_ for the ``pickle`` module. - - -.. contents:: :local: - :depth: 2 - - -Tutorials ---------- -Packaging your first model -^^^^^^^^^^^^^^^^^^^^^^^^^^ -A tutorial that guides you through packaging and unpackaging a simple model is available -`on Colab `_. -After completing this exercise, you will be familiar with the basic API for creating and using -Torch packages. - -How do I... ------------ -See what is inside a package? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Treat the package like a ZIP archive -"""""""""""""""""""""""""""""""""""" -The container format for a ``torch.package`` is ZIP, so any tools that work with standard ZIP files should -work for exploring the contents. Some common ways to interact with ZIP files: - -* ``unzip my_package.pt`` will unzip the ``torch.package`` archive to disk, where you can freely inspect its contents. - - -:: - - $ unzip my_package.pt && tree my_package - my_package - ├── .data - │ ├── 94304870911616.storage - │ ├── 94304900784016.storage - │ ├── extern_modules - │ └── version - ├── models - │ └── model_1.pkl - └── torchvision - └── models - ├── resnet.py - └── utils.py - ~ cd my_package && cat torchvision/models/resnet.py - ... - - -* The Python ``zipfile`` module provides a standard way to read and write ZIP archive contents. - - -:: - - from zipfile import ZipFile - with ZipFile("my_package.pt") as myzip: - file_bytes = myzip.read("torchvision/models/resnet.py") - # edit file_bytes in some way - myzip.writestr("torchvision/models/resnet.py", new_file_bytes) - - -* vim has the ability to natively read ZIP archives. You can even edit files and :``write`` them back into the archive! - - -:: - - # add this to your .vimrc to treat `*.pt` files as zip files - au BufReadCmd *.pt call zip#Browse(expand("")) - - ~ vi my_package.pt - - -Use the ``file_structure()`` API -"""""""""""""""""""""""""""""""" -:class:`PackageImporter` provides a ``file_structure()`` method, which will return a printable -and queryable :class:`Directory` object. The :class:`Directory` object is a simple directory structure that you can use to explore the -current contents of a ``torch.package``. - -The :class:`Directory` object itself is directly printable and will print out a file tree representation. To filter what is returned, -use the glob-style ``include`` and ``exclude`` filtering arguments. - - -:: - - with PackageExporter('my_package.pt') as pe: - pe.save_pickle('models', 'model_1.pkl', mod) - - importer = PackageImporter('my_package.pt') - # can limit printed items with include/exclude args - print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage")) - print(importer.file_structure()) # will print out all files - - -Output: - - -:: - - # filtered with glob pattern: - # include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage" - ─── my_package.pt - ├── models - │ └── model_1.pkl - └── torchvision - └── models - └── utils.py - - # all files - ─── my_package.pt - ├── .data - │ ├── 94304870911616.storage - │ ├── 94304900784016.storage - │ ├── extern_modules - │ └── version - ├── models - │ └── model_1.pkl - └── torchvision - └── models - ├── resnet.py - └── utils.py - - -You can also query :class:`Directory` objects with the ``has_file()`` method. - - -:: - - importer_file_structure = importer.file_structure() - found: bool = importer_file_structure.has_file("package_a/subpackage.py") - -See why a given module was included as a dependency? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Say there is a given module ``foo``, and you want to know why your :class:`PackageExporter` is pulling in ``foo`` as a dependency. - -:meth:`PackageExporter.get_rdeps` will return all modules that directly depend on ``foo``. - -If you would like to see how a given module ``src`` depends on ``foo``, the :meth:`PackageExporter.all_paths` method will -return a DOT-formatted graph showing all the dependency paths between ``src`` and ``foo``. - -If you would just like to see the whole dependency graph of your :class:`PackageExporter`, you can use :meth:`PackageExporter.dependency_graph_string`. - - -Include arbitrary resources with my package and access them later? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:class:`PackageExporter` exposes three methods, ``save_pickle``, ``save_text`` and ``save_binary`` that allow you to save -Python objects, text, and binary data to a package. - - -:: - - with torch.PackageExporter("package.pt") as exporter: - # Pickles the object and saves to `my_resources/tensor.pkl` in the archive. - exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4)) - exporter.save_text("config_stuff", "words.txt", "a sample string") - exporter.save_binary("raw_data", "binary", my_bytes) - - -:class:`PackageImporter` exposes complementary methods named ``load_pickle``, ``load_text`` and ``load_binary`` that allow you to load -Python objects, text and binary data from a package. - - -:: - - importer = torch.PackageImporter("package.pt") - my_tensor = importer.load_pickle("my_resources", "tensor.pkl") - text = importer.load_text("config_stuff", "words.txt") - binary = importer.load_binary("raw_data", "binary") - - -Customize how a class is packaged? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``torch.package`` allows for the customization of how classes are packaged. This behavior is accessed through defining the method -``__reduce_package__`` on a class and by defining a corresponding de-packaging function. This is similar to defining ``__reduce__`` for -Python’s normal pickling process. - -Steps: - -1. Define the method ``__reduce_package__(self, exporter: PackageExporter)`` on the target class. This method should do the work to save the class instance inside of the package, and should return a tuple of the corresponding de-packaging function with the arguments needed to invoke the de-packaging function. This method is called by the ``PackageExporter`` when it encounters an instance of the target class. -2. Define a de-packaging function for the class. This de-packaging function should do the work to reconstruct and return an instance of the class. The function signature’s first parameter should be a ``PackageImporter`` instance, and the rest of the parameters are user defined. - - -:: - - # foo.py [Example of customizing how class Foo is packaged] - from torch.package import PackageExporter, PackageImporter - import time - - - class Foo: - def __init__(self, my_string: str): - super().__init__() - self.my_string = my_string - self.time_imported = 0 - self.time_exported = 0 - - def __reduce_package__(self, exporter: PackageExporter): - """ - Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when - saving an instance of this object. This method should do the work to save this - object inside of the ``torch.package`` archive. - - Returns function w/ arguments to load the object from a - ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function. - """ - - # use this pattern to ensure no naming conflicts with normal dependencies, - # anything saved under this module name shouldn't conflict with other - # items in the package - generated_module_name = f"foo-generated._{exporter.get_unique_id()}" - exporter.save_text( - generated_module_name, - "foo.txt", - self.my_string + ", with exporter modification!", - ) - time_exported = time.clock_gettime(1) - - # returns de-packaging function w/ arguments to invoke with - return (unpackage_foo, (generated_module_name, time_exported,)) - - - def unpackage_foo( - importer: PackageImporter, generated_module_name: str, time_exported: float - ) -> Foo: - """ - Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function - when depickling a Foo object. - Performs work of loading and returning a Foo instance from a ``torch.package`` archive. - """ - time_imported = time.clock_gettime(1) - foo = Foo(importer.load_text(generated_module_name, "foo.txt")) - foo.time_imported = time_imported - foo.time_exported = time_exported - return foo - - -:: - - # example of saving instances of class Foo - - import torch - from torch.package import PackageImporter, PackageExporter - import foo - - foo_1 = foo.Foo("foo_1 initial string") - foo_2 = foo.Foo("foo_2 initial string") - with PackageExporter('foo_package.pt') as pe: - # save as normal, no extra work necessary - pe.save_pickle('foo_collection', 'foo1.pkl', foo_1) - pe.save_pickle('foo_collection', 'foo2.pkl', foo_2) - - pi = PackageImporter('foo_package.pt') - print(pi.file_structure()) - imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl') - print(f"foo_1 string: '{imported_foo.my_string}'") - print(f"foo_1 export time: {imported_foo.time_exported}") - print(f"foo_1 import time: {imported_foo.time_imported}") - - -:: - - # output of running above script - ─── foo_package - ├── foo-generated - │ ├── _0 - │ │ └── foo.txt - │ └── _1 - │ └── foo.txt - ├── foo_collection - │ ├── foo1.pkl - │ └── foo2.pkl - └── foo.py - - foo_1 string: 'foo_1 initial string, with reduction modification!' - foo_1 export time: 9857706.650140837 - foo_1 import time: 9857706.652698385 - - -Test in my source code whether or not it is executing inside a package? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A :class:`PackageImporter` will add the attribute ``__torch_package__`` to every module that it initializes. Your code can check for the -presence of this attribute to determine whether it is executing in a packaged context or not. - - -:: - - # In foo/bar.py: - - if "__torch_package__" in dir(): # true if the code is being loaded from a package - def is_in_package(): - return True - - UserException = Exception - else: - def is_in_package(): - return False - - UserException = UnpackageableException - - -Now, the code will behave differently depending on whether it’s imported normally through your Python environment or imported from a -``torch.package``. - - -:: - - from foo.bar import is_in_package - - print(is_in_package()) # False - - loaded_module = PackageImporter(my_package).import_module("foo.bar") - loaded_module.is_in_package() # True - - -**Warning**: in general, it’s bad practice to have code that behaves differently depending on whether it’s packaged or not. This can lead to -hard-to-debug issues that are sensitive to how you imported your code. If your package is intended to be heavily used, consider restructuring -your code so that it behaves the same way no matter how it was loaded. - - -Patch code into a package? -^^^^^^^^^^^^^^^^^^^^^^^^^^ -:class:`PackageExporter` offers a ``save_source_string()`` method that allows one to save arbitrary Python source code to a module of your choosing. - - -:: - - with PackageExporter(f) as exporter: - # Save the my_module.foo available in your current Python environment. - exporter.save_module("my_module.foo") - - # This saves the provided string to my_module/foo.py in the package archive. - # It will override the my_module.foo that was previously saved. - exporter.save_source_string("my_module.foo", textwrap.dedent( - """\ - def my_function(): - print('hello world') - """ - )) - - # If you want to treat my_module.bar as a package - # (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py) - # pass is_package=True, - exporter.save_source_string("my_module.bar", - "def foo(): print('hello')\n", - is_package=True) - - importer = PackageImporter(f) - importer.import_module("my_module.foo").my_function() # prints 'hello world' - - -Access package contents from packaged code? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -:class:`PackageImporter` implements the -`importlib.resources `_ -API for accessing resources from inside a package. - - -:: - - with PackageExporter(f) as exporter: - # saves text to my_resource/a.txt in the archive - exporter.save_text("my_resource", "a.txt", "hello world!") - # saves the tensor to my_pickle/obj.pkl - exporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2)) - - # see below for module contents - exporter.save_module("foo") - exporter.save_module("bar") - - -The ``importlib.resources`` API allows access to resources from within packaged code. - - -:: - - # foo.py: - import importlib.resources - import my_resource - - # returns "hello world!" - def get_my_resource(): - return importlib.resources.read_text(my_resource, "a.txt") - - -Using ``importlib.resources`` is the recommended way to access package contents from within packaged code, since it complies -with the Python standard. However, it is also possible to access the parent :class:`PackageImporter` instance itself from within -packaged code. - - -:: - - # bar.py: - import torch_package_importer # this is the PackageImporter that imported this module. - - # Prints "hello world!", equivalent to importlib.resources.read_text - def get_my_resource(): - return torch_package_importer.load_text("my_resource", "a.txt") - - # You also do things that the importlib.resources API does not support, like loading - # a pickled object from the package. - def get_my_pickle(): - return torch_package_importer.load_pickle("my_pickle", "obj.pkl") - - -Distinguish between packaged code and non-packaged code? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -To tell if an object’s code is from a ``torch.package``, use the ``torch.package.is_from_package()`` function. -Note: if an object is from a package but its definition is from a module marked ``extern`` or from ``stdlib``, -this check will return ``False``. - - -:: - - importer = PackageImporter(f) - mod = importer.import_module('foo') - obj = importer.load_pickle('model', 'model.pkl') - txt = importer.load_text('text', 'my_test.txt') - - assert is_from_package(mod) - assert is_from_package(obj) - assert not is_from_package(txt) # str is from stdlib, so this will return False - - -Re-export an imported object? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -To re-export an object that was previously imported by a :class:`PackageImporter`, you must make the new :class:`PackageExporter` -aware of the original :class:`PackageImporter` so that it can find source code for your object’s dependencies. - - -:: - - importer = PackageImporter(f) - obj = importer.load_pickle("model", "model.pkl") - - # re-export obj in a new package - with PackageExporter(f2, importer=(importer, sys_importer)) as exporter: - exporter.save_pickle("model", "model.pkl", obj) - - -Package a TorchScript module? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -To package a TorchScript model, use the same ``save_pickle`` and ``load_pickle`` APIs as you would with any other object. -Saving TorchScript objects that are attributes or submodules is supported as well with no extra work. - - -:: - - # save TorchScript just like any other object - with PackageExporter(file_name) as e: - e.save_pickle("res", "script_model.pkl", scripted_model) - e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule) - # load as normal - importer = PackageImporter(file_name) - loaded_script = importer.load_pickle("res", "script_model.pkl") - loaded_mixed = importer.load_pickle("res", "mixed_model.pkl" - - -Explanation ------------ -``torch.package`` Format Overview -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A ``torch.package`` file is a ZIP archive which conventionally uses the ``.pt`` extension. Inside the ZIP archive, there are two kinds of files: - -* Framework files, which are placed in the ``.data/``. -* User files, which is everything else. - -As an example, this is what a fully packaged ResNet model from ``torchvision`` looks like: - - -:: - - resnet - ├── .data # All framework-specific data is stored here. - │ │ # It's named to avoid conflicts with user-serialized code. - │ ├── 94286146172688.storage # tensor data - │ ├── 94286146172784.storage - │ ├── extern_modules # text file with names of extern modules (e.g. 'torch') - │ ├── version # version metadata - │ ├── ... - ├── model # the pickled model - │ └── model.pkl - └── torchvision # all code dependencies are captured as source files - └── models - ├── resnet.py - └── utils.py - - -Framework files -""""""""""""""" -The ``.data/`` directory is owned by torch.package, and its contents are considered to be a private implementation detail. -The ``torch.package`` format makes no guarantees about the contents of ``.data/``, but any changes made will be backward compatible -(that is, newer version of PyTorch will always be able to load older ``torch.packages``). - -Currently, the ``.data/`` directory contains the following items: - -* ``version``: a version number for the serialized format, so that the ``torch.package`` import infrastructures knows how to load this package. -* ``extern_modules``: a list of modules that are considered ``extern``. ``extern`` modules will be imported using the loading environment’s system importer. -* ``*.storage``: serialized tensor data. - - -:: - - .data - ├── 94286146172688.storage - ├── 94286146172784.storage - ├── extern_modules - ├── version - ├── ... - - -User files -"""""""""" -All other files in the archive were put there by a user. The layout is identical to a Python -`regular package `_. For a deeper dive in how Python packaging works, -please consult `this essay `_ (it’s slightly out of date, so double-check implementation details -with the `Python reference documentation `_). - - -:: - - - ├── model # the pickled model - │ └── model.pkl - ├── another_package - │ ├── __init__.py - │ ├── foo.txt # a resource file , see importlib.resources - │ └── ... - └── torchvision - └── models - ├── resnet.py # torchvision.models.resnet - └── utils.py # torchvision.models.utils - - -How ``torch.package`` finds your code's dependencies -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Analyzing an object's dependencies -"""""""""""""""""""""""""""""""""" -When you issue a ``save_pickle(obj, ...)`` call, :class:`PackageExporter` will pickle the object normally. Then, it uses the -``pickletools`` standard library module to parse the pickle bytecode. - -In a pickle, an object is saved along with a ``GLOBAL`` opcode that describes where to find the implementation of the object’s type, like: - - -:: - - GLOBAL 'torchvision.models.resnet Resnet` - - -The dependency resolver will gather up all ``GLOBAL`` ops and mark them as dependencies of your pickled object. -For more information about pickling and the pickle format, please consult `the Python docs `_. - -Analyzing a module's dependencies -""""""""""""""""""""""""""""""""" -When a Python module is identified as a dependency, ``torch.package`` walks the module’s python AST representation and looks for import statements with -full support for the standard forms: ``from x import y``, ``import z``, ``from w import v as u``, etc. When one of these import statements are -encountered, ``torch.package`` registers the imported modules as dependencies that are then themselves parsed in the same AST walking way. - -**Note**: AST parsing has limited support for the ``__import__(...)`` syntax and does not support ``importlib.import_module`` calls. In general, you should -not expect dynamic imports to be detected by ``torch.package``. - - -Dependency Management -^^^^^^^^^^^^^^^^^^^^^ -``torch.package`` automatically finds the Python modules that your code and objects depend on. This process is called dependency resolution. -For each module that the dependency resolver finds, you must specify an *action* to take. - -The allowed actions are: - -* ``intern``: put this module into the package. -* ``extern``: declare this module as an external dependency of the package. -* ``mock``: stub out this module. -* ``deny``: depending on this module will raise an error during package export. - -Finally, there is one more important action that is not technically part of ``torch.package``: - -* Refactoring: remove or change the dependencies in your code. - -Note that actions are only defined on entire Python modules. There is no way to package “just” a function or class from a module and leave the rest out. -This is by design. Python does not offer clean boundaries between objects defined in a module. The only defined unit of dependency organization is a -module, so that’s what ``torch.package`` uses. - -Actions are applied to modules using patterns. Patterns can either be module names (``"foo.bar"``) or globs (like ``"foo.**"``). You associate a pattern -with an action using methods on :class:`PackageExporter`, e.g. - - -:: - - my_exporter.intern("torchvision.**") - my_exporter.extern("numpy") - - -If a module matches a pattern, the corresponding action is applied to it. For a given module, patterns will be checked in the order that they were defined, -and the first action will be taken. - - -``intern`` -"""""""""" -If a module is ``intern``-ed, it will be placed into the package. - -This action is your model code, or any related code you want to package. For example, if you are trying to package a ResNet from ``torchvision``, -you will need to ``intern`` the module torchvision.models.resnet. - -On package import, when your packaged code tries to import an ``intern``-ed module, PackageImporter will look inside your package for that module. -If it can’t find that module, an error will be raised. This ensures that each :class:`PackageImporter` is isolated from the loading environment—even -if you have ``my_interned_module`` available in both your package and the loading environment, :class:`PackageImporter` will only use the version in your -package. - -**Note**: Only Python source modules can be ``intern``-ed. Other kinds of modules, like C extension modules and bytecode modules, will raise an error if -you attempt to ``intern`` them. These kinds of modules need to be ``mock``-ed or ``extern``-ed. - - -``extern`` -"""""""""" -If a module is ``extern``-ed, it will not be packaged. Instead, it will be added to a list of external dependencies for this package. You can find this -list on ``package_exporter.extern_modules``. - -On package import, when the packaged code tries to import an ``extern``-ed module, :class:`PackageImporter` will use the default Python importer to find -that module, as if you did ``importlib.import_module("my_externed_module")``. If it can’t find that module, an error will be raised. - -In this way, you can depend on third-party libraries like ``numpy`` and ``scipy`` from within your package without having to package them too. - -**Warning**: If any external library changes in a backwards-incompatible way, your package may fail to load. If you need long-term reproducibility -for your package, try to limit your use of ``extern``. - - -``mock`` -"""""""" -If a module is ``mock``-ed, it will not be packaged. Instead a stub module will be packaged in its place. The stub module will allow you to retrieve -objects from it (so that ``from my_mocked_module import foo`` will not error), but any use of that object will raise a ``NotImplementedError``. - -``mock`` should be used for code that you “know” will not be needed in the loaded package, but you still want available for use in non-packaged contents. -For example, initialization/configuration code, or code only used for debugging/training. - -**Warning**: In general, ``mock`` should be used as a last resort. It introduces behavioral differences between packaged code and non-packaged code, -which may lead to later confusion. Prefer instead to refactor your code to remove unwanted dependencies. - - -Refactoring -""""""""""" -The best way to manage dependencies is to not have dependencies at all! Often, code can be refactored to remove unnecessary dependencies. Here are some -guidelines for writing code with clean dependencies (which are also generally good practices!): - -**Include only what you use**. Do not leave unused imports in your code. The dependency resolver is not smart enough to tell that they are indeed unused, -and will try to process them. - -**Qualify your imports**. For example, instead of writing import foo and later using ``foo.bar.baz``, prefer to write ``from foo.bar import baz``. This more -precisely specifies your real dependency (``foo.bar``) and lets the dependency resolver know you don’t need all of ``foo``. - -**Split up large files with unrelated functionality into smaller ones**. If your ``utils`` module contains a hodge-podge of unrelated functionality, any module -that depends on ``utils`` will need to pull in lots of unrelated dependencies, even if you only needed a small part of it. Prefer instead to define -single-purpose modules that can be packaged independently of one another. - - -Patterns -"""""""" -Patterns allow you to specify groups of modules with a convenient syntax. The syntax and behavior of patterns follows the Bazel/Buck -`glob() `_. - -A module that we are trying to match against a pattern is called a candidate. A candidate is composed of a list of segments separated by a -separator string, e.g. ``foo.bar.baz``. - -A pattern contains one or more segments. Segments can be: - -* A literal string (e.g. ``foo``), which matches exactly. -* A string containing a wildcard (e.g. ``torch``, or ``foo*baz*``). The wildcard matches any string, including the empty string. -* A double wildcard (``**``). This matches against zero or more complete segments. - -Examples: - -* ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``. -* ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional`` or ``torch`` -* ``torch*.**``: matches ``torch``, ``torchvision``, and all of their submodules - -When specifying actions, you can pass multiple patterns, e.g. - - -:: - - exporter.intern(["torchvision.models.**", "torchvision.utils.**"]) - - -A module will match against this action if it matches any of the patterns. - -You can also specify patterns to exclude, e.g. - - -:: - - exporter.mock("**", exclude=["torchvision.**"]) - - -A module will not match against this action if it matches any of the exclude patterns. In this example, we are mocking all modules except -``torchvision`` and its submodules. - -When a module could potentially match against multiple actions, the first action defined will be taken. - - -``torch.package`` sharp edges -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Avoid global state in your modules -"""""""""""""""""""""""""""""""""" -Python makes it really easy to bind objects and run code at module-level scope. This is generally fine—after all, functions and classes are bound to -names this way. However, things become more complicated when you define an object at module scope with the intention of mutating it, introducing mutable -global state. - -Mutable global state is quite useful—it can reduce boilerplate, allow for open registration into tables, etc. But unless employed very carefully, it can -cause complications when used with ``torch.package``. - -Every :class:`PackageImporter` creates an independent environment for its contents. This is nice because it means we load multiple packages and ensure -they are isolated from each other, but when modules are written in a way that assumes shared mutable global state, this behavior can create hard-to-debug -errors. - -Types are not shared between packages and the loading environment -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" -Any class that you import from a :class:`PackageImporter` will be a version of the class specific to that importer. For example: - - -:: - - from foo import MyClass - - my_class_instance = MyClass() - - with PackageExporter(f) as exporter: - exporter.save_module("foo") - - importer = PackageImporter(f) - imported_MyClass = importer.import_module("foo").MyClass - - assert isinstance(my_class_instance, MyClass) # works - assert isinstance(my_class_instance, imported_MyClass) # ERROR! - - -In this example, ``MyClass`` and ``imported_MyClass`` are *not the same type*. In this specific example, ``MyClass`` and ``imported_MyClass`` have exactly the -same implementation, so you might think it’s okay to consider them the same class. But consider the situation where ``imported_MyClass`` is coming from an -older package with an entirely different implementation of ``MyClass`` — in that case, it’s unsafe to consider them the same class. - -Under the hood, each importer has a prefix that allows it to uniquely identify classes: - - -:: - - print(MyClass.__name__) # prints "foo.MyClass" - print(imported_MyClass.__name__) # prints .foo.MyClass - - -That means you should not expect ``isinstance`` checks to work when one of the arguments is from a package and the other is not. If you need this -functionality, consider the following options: - -* Doing duck typing (just using the class instead of explicitly checking that it is of a given type). -* Make the typing relationship an explicit part of the class contract. For example, you can add an attribute tag ``self.handler = "handle_me_this_way"`` and have client code check for the value of ``handler`` instead of checking the type directly. - - -How ``torch.package`` keeps packages isolated from each other -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Each :class:`PackageImporter` instance creates an independent, isolated environment for its modules and objects. Modules in a package can only import -other packaged modules, or modules marked ``extern``. If you use multiple :class:`PackageImporter` instances to load a single package, you will get -multiple independent environments that do not interact. - -This is achieved by extending Python’s import infrastructure with a custom importer. :class:`PackageImporter` provides the same core API as the -``importlib`` importer; namely, it implements the ``import_module`` and ``__import__`` methods. - -When you invoke :meth:`PackageImporter.import_module`, :class:`PackageImporter` will construct and return a new module, much as the system importer does. -However, :class:`PackageImporter` patches the returned module to use ``self`` (i.e. that :class:`PackageImporter` instance) to fulfill future import -requests by looking in the package rather than searching the user’s Python environment. - -Mangling -"""""""" -To avoid confusion (“is this ``foo.bar`` object the one from my package, or the one from my Python environment?”), :class:`PackageImporter` mangles the -``__name__`` and ``__file__`` of all imported modules, by adding a *mangle prefix* to them. - -For ``__name__``, a name like ``torchvision.models.resnet18`` becomes ``.torchvision.models.resnet18``. - -For ``__file__``, a name like ``torchvision/models/resnet18.py`` becomes ``.torchvision/modules/resnet18.py``. - -Name mangling helps avoid inadvertent punning of module names between different packages, and helps you debug by making stack traces and print -statements more clearly show whether they are referring to packaged code or not. For developer-facing details about mangling, consult -``mangling.md`` in ``torch/package/``. - - -API Reference -------------- -.. autoclass:: torch.package.PackagingError - -.. autoclass:: torch.package.EmptyMatchError - -.. autoclass:: torch.package.PackageExporter - :members: - - .. automethod:: __init__ - -.. autoclass:: torch.package.PackageImporter - :members: - - .. automethod:: __init__ - -.. autoclass:: torch.package.Directory - :members: - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.package.analyze.find_first_use_of_broken_modules -.. py:module:: torch.package.analyze.is_from_package -.. py:module:: torch.package.analyze.trace_dependencies -.. py:module:: torch.package.file_structure_representation -.. py:module:: torch.package.find_file_dependencies -.. py:module:: torch.package.glob_group -.. py:module:: torch.package.importer -.. py:module:: torch.package.package_exporter -.. py:module:: torch.package.package_importer diff --git a/docs/source/profiler.md b/docs/source/profiler.md new file mode 100644 index 00000000000000..1578b7334d849e --- /dev/null +++ b/docs/source/profiler.md @@ -0,0 +1,49 @@ +```{eval-rst} +.. currentmodule:: torch.profiler +``` + +# torch.profiler + +## Overview +```{eval-rst} +.. automodule:: torch.profiler +``` + +## API Reference +```{eval-rst} +.. autoclass:: torch.profiler._KinetoProfile + :members: + +.. autoclass:: torch.profiler.profile + :members: + +.. autoclass:: torch.profiler.ProfilerAction + :members: + +.. autoclass:: torch.profiler.ProfilerActivity + :members: + +.. autofunction:: torch.profiler.schedule + +.. autofunction:: torch.profiler.tensorboard_trace_handler +``` + +## Intel Instrumentation and Tracing Technology APIs + +```{eval-rst} +.. autofunction:: torch.profiler.itt.is_available + +.. autofunction:: torch.profiler.itt.mark + +.. autofunction:: torch.profiler.itt.range_push + +.. autofunction:: torch.profiler.itt.range_pop +``` + + +```{eval-rst} +.. py:module:: torch.profiler.itt +.. py:module:: torch.profiler.profiler +.. py:module:: torch.profiler.python_tracer +``` diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst deleted file mode 100644 index 38871882fa2ad9..00000000000000 --- a/docs/source/profiler.rst +++ /dev/null @@ -1,45 +0,0 @@ -.. currentmodule:: torch.profiler - -torch.profiler -============== - -Overview --------- -.. automodule:: torch.profiler - - -API Reference -------------- - -.. autoclass:: torch.profiler._KinetoProfile - :members: - -.. autoclass:: torch.profiler.profile - :members: - -.. autoclass:: torch.profiler.ProfilerAction - :members: - -.. autoclass:: torch.profiler.ProfilerActivity - :members: - -.. autofunction:: torch.profiler.schedule - -.. autofunction:: torch.profiler.tensorboard_trace_handler - -Intel Instrumentation and Tracing Technology APIs -------------------------------------------------- - -.. autofunction:: torch.profiler.itt.is_available - -.. autofunction:: torch.profiler.itt.mark - -.. autofunction:: torch.profiler.itt.range_push - -.. autofunction:: torch.profiler.itt.range_pop - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.profiler.itt -.. py:module:: torch.profiler.profiler -.. py:module:: torch.profiler.python_tracer \ No newline at end of file diff --git a/docs/source/quantization-accuracy-debugging.md b/docs/source/quantization-accuracy-debugging.md new file mode 100644 index 00000000000000..d13d83129570a5 --- /dev/null +++ b/docs/source/quantization-accuracy-debugging.md @@ -0,0 +1,96 @@ +# Quantization Accuracy Debugging + +This document provides high level strategies for improving quantization +accuracy. If a quantized model has error compared to the original model, +we can categorize the error into: + +1. **data insensitive error** - caused by intrinsic model quantization error, + large portion of input data has large error +2. **data sensitive error** - caused by outlier input data, small + portion of input data has large error +3. **implementation error** - quantized kernel is not matching reference implementation + +## Data insensitive error + +### General tips + +1. For PTQ, ensure that the data you are calibrating with is representative + of your dataset. For example, for a classification problem a general + guideline is to have multiple samples in every category, and the overall + number of samples should be at least 100. There is no penalty for + calibrating with more data other than calibration time. +2. If your model has Conv-BN or Linear-BN patterns, consider fusing them. + If you are using FX graph mode quantization, this is done automatically + by the workflow. If you are using Eager mode quantization, you can do + this manually with the ``torch.ao.quantization.fuse_modules`` API. +3. Increase the precision of dtype of the problematic ops. Usually, fp32 + will have the highest accuracy, followed by fp16, followed by dynamically + quantized int8, followed by statically quantized int8. + + 1. Note: this is trading off performance for accuracy. + 2. Note: availability of kernels per dtype per op can vary by backend. + 3. Note: dtype conversions add an additional performance cost. For example, + ``fp32_op -> quant -> int8_op -> dequant -> fp32_op -> quant -> int8_op -> dequant`` + will have a performance penalty compared to + ``fp32_op -> fp32_op -> quant -> int8_op -> int8_op -> dequant`` + because of a higher number of required dtype conversions. + +4. If you are using PTQ, consider using QAT to recover some of the accuracy loss + from quantization. + +### Int8 quantization tips + +1. If you are using per-tensor weight quantization, consider using per-channel + weight quantization. +2. If you are doing inference on `fbgemm`, ensure that you set the `reduce_range` + argument to `False` if your CPU is Cooperlake or newer, and to `True` otherwise. +3. Audit the input activation distribution variation across different samples. + If this variation is high, the layer may be suitable for dynamic quantization + but not static quantization. + +## Data sensitive error + +If you are using static quantization and a small portion of your input data is +resulting in high quantization error, you can try: + +1. Adjust your calibration dataset to make it more representative of your + inference dataset. +2. Manually inspect (using Numeric Suite) which layers have high quantization + error. For these layers, consider leaving them in floating point or adjusting + the observer settings to choose a better scale and zero_point. + + +## Implementation error + +If you are using PyTorch quantization with your own backend +you may see differences between the reference implementation of an +operation (such as ``dequant -> op_fp32 -> quant``) and the quantized implementation +(such as `op_int8`) of the op on the target hardware. This could mean one of two things: + +1. the differences (usually small) are expected due to specific behavior of + the target kernel on the target hardware compared to fp32/cpu. An example of this + is accumulating in an integer dtype. Unless the kernel guarantees bitwise + equivalency with the reference implementation, this is expected. +2. the kernel on the target hardware has an accuracy issue. In this case, reach + out to the kernel developer. + +## Numerical Debugging Tooling (prototype) + +```{eval-rst} +.. toctree:: + :hidden: + + torch.ao.ns._numeric_suite + torch.ao.ns._numeric_suite_fx +``` + +```{warning} +Numerical debugging tooling is early prototype and subject to change. +``` + +```{eval-rst} +* :ref:`torch_ao_ns_numeric_suite` + Eager mode numeric suite +* :ref:`torch_ao_ns_numeric_suite_fx` + FX numeric suite +``` diff --git a/docs/source/quantization-accuracy-debugging.rst b/docs/source/quantization-accuracy-debugging.rst deleted file mode 100644 index 0fa590abd2f0cb..00000000000000 --- a/docs/source/quantization-accuracy-debugging.rst +++ /dev/null @@ -1,98 +0,0 @@ -Quantization Accuracy Debugging -------------------------------- - -This document provides high level strategies for improving quantization -accuracy. If a quantized model has error compared to the original model, -we can categorize the error into: - -1. **data insensitive error** - caused by intrinsic model quantization error, - large portion of input data has large error -2. **data sensitive error** - caused by outlier input data, small - portion of input data has large error -3. **implementation error** - quantized kernel is not matching reference implementation - -Data insensitive error -~~~~~~~~~~~~~~~~~~~~~~ - -General tips -^^^^^^^^^^^^ - -1. For PTQ, ensure that the data you are calibrating with is representative - of your dataset. For example, for a classification problem a general - guideline is to have multiple samples in every category, and the overall - number of samples should be at least 100. There is no penalty for - calibrating with more data other than calibration time. -2. If your model has Conv-BN or Linear-BN patterns, consider fusing them. - If you are using FX graph mode quantization, this is done automatically - by the workflow. If you are using Eager mode quantization, you can do - this manually with the ``torch.ao.quantization.fuse_modules`` API. -3. Increase the precision of dtype of the problematic ops. Usually, fp32 - will have the highest accuracy, followed by fp16, followed by dynamically - quantized int8, followed by statically quantized int8. - - 1. Note: this is trading off performance for accuracy. - 2. Note: availability of kernels per dtype per op can vary by backend. - 3. Note: dtype conversions add an additional performance cost. For example, - ``fp32_op -> quant -> int8_op -> dequant -> fp32_op -> quant -> int8_op -> dequant`` - will have a performance penalty compared to - ``fp32_op -> fp32_op -> quant -> int8_op -> int8_op -> dequant`` - because of a higher number of required dtype conversions. - -4. If you are using PTQ, consider using QAT to recover some of the accuracy loss - from quantization. - -Int8 quantization tips -^^^^^^^^^^^^^^^^^^^^^^ - -1. If you are using per-tensor weight quantization, consider using per-channel - weight quantization. -2. If you are doing inference on `fbgemm`, ensure that you set the `reduce_range` - argument to `False` if your CPU is Cooperlake or newer, and to `True` otherwise. -3. Audit the input activation distribution variation across different samples. - If this variation is high, the layer may be suitable for dynamic quantization - but not static quantization. - -Data sensitive error -~~~~~~~~~~~~~~~~~~~~ - -If you are using static quantization and a small portion of your input data is -resulting in high quantization error, you can try: - -1. Adjust your calibration dataset to make it more representative of your - inference dataset. -2. Manually inspect (using Numeric Suite) which layers have high quantization - error. For these layers, consider leaving them in floating point or adjusting - the observer settings to choose a better scale and zero_point. - - -Implementation error -~~~~~~~~~~~~~~~~~~~~ - -If you are using PyTorch quantization with your own backend -you may see differences between the reference implementation of an -operation (such as ``dequant -> op_fp32 -> quant``) and the quantized implementation -(such as `op_int8`) of the op on the target hardware. This could mean one of two things: - -1. the differences (usually small) are expected due to specific behavior of - the target kernel on the target hardware compared to fp32/cpu. An example of this - is accumulating in an integer dtype. Unless the kernel guarantees bitwise - equivalency with the reference implementation, this is expected. -2. the kernel on the target hardware has an accuracy issue. In this case, reach - out to the kernel developer. - -Numerical Debugging Tooling (prototype) ---------------------------------------- - -.. toctree:: - :hidden: - - torch.ao.ns._numeric_suite - torch.ao.ns._numeric_suite_fx - -.. warning :: - Numerical debugging tooling is early prototype and subject to change. - -* :ref:`torch_ao_ns_numeric_suite` - Eager mode numeric suite -* :ref:`torch_ao_ns_numeric_suite_fx` - FX numeric suite diff --git a/docs/source/quantization-backend-configuration.md b/docs/source/quantization-backend-configuration.md new file mode 100644 index 00000000000000..fb28fbef543872 --- /dev/null +++ b/docs/source/quantization-backend-configuration.md @@ -0,0 +1,19 @@ +# Quantization Backend Configuration + +FX Graph Mode Quantization allows the user to configure various +quantization behaviors of an op in order to match the expectation +of their backend. + +In the future, this document will contain a detailed spec of +these configurations. + +## Default values for native configurations + +Below is the output of the configuration for quantization of ops +in x86 and qnnpack (PyTorch's default quantized backends). + +Results: + +```{eval-rst} +.. literalinclude:: scripts/quantization_backend_configs/default_backend_config.txt +``` diff --git a/docs/source/quantization-backend-configuration.rst b/docs/source/quantization-backend-configuration.rst deleted file mode 100644 index bfe93ce701e623..00000000000000 --- a/docs/source/quantization-backend-configuration.rst +++ /dev/null @@ -1,20 +0,0 @@ -Quantization Backend Configuration ----------------------------------- - -FX Graph Mode Quantization allows the user to configure various -quantization behaviors of an op in order to match the expectation -of their backend. - -In the future, this document will contain a detailed spec of -these configurations. - - -Default values for native configurations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Below is the output of the configuration for quantization of ops -in x86 and qnnpack (PyTorch's default quantized backends). - -Results: - -.. literalinclude:: scripts/quantization_backend_configs/default_backend_config.txt diff --git a/docs/source/quantization-support.md b/docs/source/quantization-support.md new file mode 100644 index 00000000000000..2f17a06265954c --- /dev/null +++ b/docs/source/quantization-support.md @@ -0,0 +1,785 @@ +# Quantization API Reference + +## torch.ao.quantization + +This module contains Eager mode quantization APIs. + +```{eval-rst} +.. currentmodule:: torch.ao.quantization +``` + +### Top level APIs + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + quantize + quantize_dynamic + quantize_qat + prepare + prepare_qat + convert +``` + +### Preparing model for quantization + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + fuse_modules.fuse_modules + QuantStub + DeQuantStub + QuantWrapper + add_quant_dequant +``` + +### Utility functions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + swap_module + propagate_qconfig_ + default_eval_fn +``` + +## torch.ao.quantization.quantize_fx + +This module contains FX graph mode quantization APIs (prototype). + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.quantize_fx +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + prepare_fx + prepare_qat_fx + convert_fx + fuse_fx +``` + +## torch.ao.quantization.qconfig_mapping + +This module contains QConfigMapping for configuring FX graph mode quantization. + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.qconfig_mapping +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + QConfigMapping + get_default_qconfig_mapping + get_default_qat_qconfig_mapping +``` + +## torch.ao.quantization.backend_config + +This module contains BackendConfig, a config object that defines how quantization is supported +in a backend. Currently only used by FX Graph Mode Quantization, but we may extend Eager Mode +Quantization to work with this as well. + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.backend_config +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + BackendConfig + BackendPatternConfig + DTypeConfig + DTypeWithConstraints + ObservationType +``` + +## torch.ao.quantization.fx.custom_config + +This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.fx.custom_config +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + FuseCustomConfig + PrepareCustomConfig + ConvertCustomConfig + StandaloneModuleConfigEntry +``` + +## torch.ao.quantization.quantizer + +```{eval-rst} +.. automodule:: torch.ao.quantization.quantizer +``` + +## torch.ao.quantization.pt2e (quantization in pytorch 2.0 export implementation) + +```{eval-rst} +.. automodule:: torch.ao.quantization.pt2e +.. automodule:: torch.ao.quantization.pt2e.representation +``` + +## torch.ao.quantization.pt2e.export_utils + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.pt2e.export_utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + model_is_exported +``` + +```{eval-rst} +.. currentmodule:: torch.ao.quantization +``` + +## torch.ao.quantization.pt2e.lowering + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.pt2e.lowering +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + lower_pt2e_quantized_to_x86 +``` + +```{eval-rst} +.. currentmodule:: torch.ao.quantization +``` + +## PT2 Export (pt2e) Numeric Debugger + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + generate_numeric_debug_handle + CUSTOM_KEY + NUMERIC_DEBUG_HANDLE_KEY + prepare_for_propagation_comparison + extract_results_from_loggers + compare_results +``` + +## torch (quantization related functions) + +This describes the quantization related functions of the `torch` namespace. + +```{eval-rst} +.. currentmodule:: torch +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + quantize_per_tensor + quantize_per_channel + dequantize +``` + +## torch.Tensor (quantization related methods) + +Quantized Tensors support a limited subset of data manipulation methods of the +regular full-precision tensor. + +```{eval-rst} +.. currentmodule:: torch.Tensor +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + view + as_strided + expand + flatten + select + ne + eq + ge + le + gt + lt + copy_ + clone + dequantize + equal + int_repr + max + mean + min + q_scale + q_zero_point + q_per_channel_scales + q_per_channel_zero_points + q_per_channel_axis + resize_ + sort + topk +``` + +## torch.ao.quantization.observer + +This module contains observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.observer +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + ObserverBase + MinMaxObserver + MovingAverageMinMaxObserver + PerChannelMinMaxObserver + MovingAveragePerChannelMinMaxObserver + HistogramObserver + PlaceholderObserver + RecordingObserver + NoopObserver + get_observer_state_dict + load_observer_state_dict + default_observer + default_placeholder_observer + default_debug_observer + default_weight_observer + default_histogram_observer + default_per_channel_weight_observer + default_dynamic_quant_observer + default_float_qparams_observer + AffineQuantizedObserverBase + Granularity + MappingType + PerAxis + PerBlock + PerGroup + PerRow + PerTensor + PerToken + TorchAODType + ZeroPointDomain + get_block_size +``` + +## torch.ao.quantization.fake_quantize + +This module implements modules which are used to perform fake quantization +during QAT. + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.fake_quantize +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + FakeQuantizeBase + FakeQuantize + FixedQParamsFakeQuantize + FusedMovingAvgObsFakeQuantize + default_fake_quant + default_weight_fake_quant + default_per_channel_weight_fake_quant + default_histogram_fake_quant + default_fused_act_fake_quant + default_fused_wt_fake_quant + default_fused_per_channel_wt_fake_quant + disable_fake_quant + enable_fake_quant + disable_observer + enable_observer +``` + +## torch.ao.quantization.qconfig + +This module defines `QConfig` objects which are used +to configure quantization settings for individual ops. + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.qconfig +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + QConfig + default_qconfig + default_debug_qconfig + default_per_channel_qconfig + default_dynamic_qconfig + float16_dynamic_qconfig + float16_static_qconfig + per_channel_dynamic_qconfig + float_qparams_weight_only_qconfig + default_qat_qconfig + default_weight_only_qconfig + default_activation_only_qconfig + default_qat_qconfig_v2 +``` + +## torch.ao.nn.intrinsic + +```{eval-rst} +.. automodule:: torch.ao.nn.intrinsic +.. automodule:: torch.ao.nn.intrinsic.modules +``` + +This module implements the combined (fused) modules conv + relu which can +then be quantized. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + ConvReLU1d + ConvReLU2d + ConvReLU3d + LinearReLU + ConvBn1d + ConvBn2d + ConvBn3d + ConvBnReLU1d + ConvBnReLU2d + ConvBnReLU3d + BNReLU2d + BNReLU3d +``` + +## torch.ao.nn.intrinsic.qat + +```{eval-rst} +.. automodule:: torch.ao.nn.intrinsic.qat +.. automodule:: torch.ao.nn.intrinsic.qat.modules +``` + +This module implements the versions of those fused operations needed for +quantization aware training. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.qat +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + LinearReLU + ConvBn1d + ConvBnReLU1d + ConvBn2d + ConvBnReLU2d + ConvReLU2d + ConvBn3d + ConvBnReLU3d + ConvReLU3d + update_bn_stats + freeze_bn_stats +``` + +## torch.ao.nn.intrinsic.quantized + +```{eval-rst} +.. automodule:: torch.ao.nn.intrinsic.quantized +.. automodule:: torch.ao.nn.intrinsic.quantized.modules +``` + +This module implements the quantized implementations of fused operations +like conv + relu. No BatchNorm variants as it's usually folded into convolution +for inference. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.quantized +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + BNReLU2d + BNReLU3d + ConvReLU1d + ConvReLU2d + ConvReLU3d + LinearReLU +``` + +## torch.ao.nn.intrinsic.quantized.dynamic + +```{eval-rst} +.. automodule:: torch.ao.nn.intrinsic.quantized.dynamic +.. automodule:: torch.ao.nn.intrinsic.quantized.dynamic.modules +``` + +This module implements the quantized dynamic implementations of fused operations +like linear + relu. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.intrinsic.quantized.dynamic +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + LinearReLU +``` + +## torch.ao.nn.qat + +```{eval-rst} +.. automodule:: torch.ao.nn.qat +.. automodule:: torch.ao.nn.qat.modules +``` + +This module implements versions of the key nn modules **Conv2d()** and +**Linear()** which run in FP32 but with rounding applied to simulate the +effect of INT8 quantization. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.qat +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + Conv2d + Conv3d + Linear +``` + +## torch.ao.nn.qat.dynamic + +```{eval-rst} +.. automodule:: torch.ao.nn.qat.dynamic +.. automodule:: torch.ao.nn.qat.dynamic.modules +``` + +This module implements versions of the key nn modules such as **Linear()** +which run in FP32 but with rounding applied to simulate the effect of INT8 +quantization and will be dynamically quantized during inference. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.qat.dynamic +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + Linear +``` + +## torch.ao.nn.quantized + +```{eval-rst} +.. automodule:: torch.ao.nn.quantized + :noindex: +.. automodule:: torch.ao.nn.quantized.modules +``` + +This module implements the quantized versions of the nn layers such as +`~torch.nn.Conv2d` and `torch.nn.ReLU`. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantized +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + ReLU6 + Hardswish + ELU + LeakyReLU + Sigmoid + BatchNorm2d + BatchNorm3d + Conv1d + Conv2d + Conv3d + ConvTranspose1d + ConvTranspose2d + ConvTranspose3d + Embedding + EmbeddingBag + FloatFunctional + FXFloatFunctional + QFunctional + Linear + LayerNorm + GroupNorm + InstanceNorm1d + InstanceNorm2d + InstanceNorm3d +``` + +## torch.ao.nn.quantized.functional + +```{eval-rst} +.. automodule:: torch.ao.nn.quantized.functional +``` + +```{eval-rst} +This module implements the quantized versions of the functional layers such as +`~torch.nn.functional.conv2d` and `torch.nn.functional.relu`. Note: +:math:`~torch.nn.functional.relu` supports quantized inputs. +``` + +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantized.functional +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + avg_pool2d + avg_pool3d + adaptive_avg_pool2d + adaptive_avg_pool3d + conv1d + conv2d + conv3d + interpolate + linear + max_pool1d + max_pool2d + celu + leaky_relu + hardtanh + hardswish + threshold + elu + hardsigmoid + clamp + upsample + upsample_bilinear + upsample_nearest +``` + +## torch.ao.nn.quantizable + +This module implements the quantizable versions of some of the nn layers. +These modules can be used in conjunction with the custom module mechanism, +by providing the ``custom_module_config`` argument to both prepare and convert. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantizable +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + LSTM + MultiheadAttention +``` + +## torch.ao.nn.quantized.dynamic + +```{eval-rst} +.. automodule:: torch.ao.nn.quantized.dynamic +.. automodule:: torch.ao.nn.quantized.dynamic.modules +``` + +Dynamically quantized {class}`~torch.nn.Linear`, {class}`~torch.nn.LSTM`, +{class}`~torch.nn.LSTMCell`, {class}`~torch.nn.GRUCell`, and +{class}`~torch.nn.RNNCell`. + +```{eval-rst} +.. currentmodule:: torch.ao.nn.quantized.dynamic +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + Linear + LSTM + GRU + RNNCell + LSTMCell + GRUCell +``` + +## Quantized dtypes and quantization schemes + +Note that operator implementations currently only +support per channel quantization for weights of the **conv** and **linear** +operators. Furthermore, the input data is +mapped linearly to the quantized data and vice versa +as follows: + +```{eval-rst} + .. math:: + + \begin{aligned} + \text{Quantization:}&\\ + &Q_\text{out} = \text{clamp}(x_\text{input}/s+z, Q_\text{min}, Q_\text{max})\\ + \text{Dequantization:}&\\ + &x_\text{out} = (Q_\text{input}-z)*s + \end{aligned} +``` + +```{eval-rst} +where :math:`\text{clamp}(.)` is the same as :func:`~torch.clamp` while the +scale :math:`s` and zero point :math:`z` are then computed +as described in :class:`~torch.ao.quantization.observer.MinMaxObserver`, specifically: +``` + +```{eval-rst} + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} +``` + +where :math:`[x_\text{min}, x_\text{max}]` denotes the range of the input data while +:math:`Q_\text{min}` and :math:`Q_\text{max}` are respectively the minimum and maximum values of the quantized dtype. + +Note that the choice of :math:`s` and :math:`z` implies that zero is represented with no quantization error whenever zero is within +the range of the input data or symmetric quantization is being used. + +Additional data types and quantization schemes can be implemented through +the `custom operator mechanism `_. + +```{eval-rst} +* :attr:`torch.qscheme` — Type to describe the quantization scheme of a tensor. + Supported types: + + * :attr:`torch.per_tensor_affine` — per tensor, asymmetric + * :attr:`torch.per_channel_affine` — per channel, asymmetric + * :attr:`torch.per_tensor_symmetric` — per tensor, symmetric + * :attr:`torch.per_channel_symmetric` — per channel, symmetric + +* ``torch.dtype`` — Type to describe the data. Supported types: + + * :attr:`torch.quint8` — 8-bit unsigned integer + * :attr:`torch.qint8` — 8-bit signed integer + * :attr:`torch.qint32` — 32-bit signed integer +``` + +```{eval-rst} +.. These modules are missing docs. Adding them here only for tracking +.. automodule:: torch.ao.nn.quantizable.modules + :noindex: +.. automodule:: torch.ao.nn.quantized.reference + :noindex: +.. automodule:: torch.ao.nn.quantized.reference.modules + :noindex: + +.. automodule:: torch.nn.quantizable +.. automodule:: torch.nn.qat.dynamic.modules +.. automodule:: torch.nn.qat.modules +.. automodule:: torch.nn.qat +.. automodule:: torch.nn.intrinsic.qat.modules +.. automodule:: torch.nn.quantized.dynamic +.. automodule:: torch.nn.intrinsic +.. automodule:: torch.nn.intrinsic.quantized.modules +.. automodule:: torch.quantization.fx +.. automodule:: torch.nn.intrinsic.quantized.dynamic +.. automodule:: torch.nn.qat.dynamic +.. automodule:: torch.nn.intrinsic.qat +.. automodule:: torch.nn.quantized.modules +.. automodule:: torch.nn.intrinsic.quantized +.. automodule:: torch.nn.quantizable.modules +.. automodule:: torch.nn.quantized +.. automodule:: torch.nn.intrinsic.quantized.dynamic.modules +.. automodule:: torch.nn.quantized.dynamic.modules +.. automodule:: torch.quantization +.. automodule:: torch.nn.intrinsic.modules +``` diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst deleted file mode 100644 index 83ad054514ef0a..00000000000000 --- a/docs/source/quantization-support.rst +++ /dev/null @@ -1,680 +0,0 @@ -Quantization API Reference -------------------------------- - -torch.ao.quantization -~~~~~~~~~~~~~~~~~~~~~ - -This module contains Eager mode quantization APIs. - -.. currentmodule:: torch.ao.quantization - -Top level APIs -^^^^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - quantize - quantize_dynamic - quantize_qat - prepare - prepare_qat - convert - -Preparing model for quantization -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - fuse_modules.fuse_modules - QuantStub - DeQuantStub - QuantWrapper - add_quant_dequant - -Utility functions -^^^^^^^^^^^^^^^^^ - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - swap_module - propagate_qconfig_ - default_eval_fn - - -torch.ao.quantization.quantize_fx -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains FX graph mode quantization APIs (prototype). - -.. currentmodule:: torch.ao.quantization.quantize_fx - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - prepare_fx - prepare_qat_fx - convert_fx - fuse_fx - -torch.ao.quantization.qconfig_mapping -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains QConfigMapping for configuring FX graph mode quantization. - -.. currentmodule:: torch.ao.quantization.qconfig_mapping - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - QConfigMapping - get_default_qconfig_mapping - get_default_qat_qconfig_mapping - -torch.ao.quantization.backend_config -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains BackendConfig, a config object that defines how quantization is supported -in a backend. Currently only used by FX Graph Mode Quantization, but we may extend Eager Mode -Quantization to work with this as well. - -.. currentmodule:: torch.ao.quantization.backend_config - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - BackendConfig - BackendPatternConfig - DTypeConfig - DTypeWithConstraints - ObservationType - -torch.ao.quantization.fx.custom_config -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization - - -.. currentmodule:: torch.ao.quantization.fx.custom_config - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - FuseCustomConfig - PrepareCustomConfig - ConvertCustomConfig - StandaloneModuleConfigEntry - -torch.ao.quantization.quantizer -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.ao.quantization.quantizer - -torch.ao.quantization.pt2e (quantization in pytorch 2.0 export implementation) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: torch.ao.quantization.pt2e -.. automodule:: torch.ao.quantization.pt2e.representation - -torch.ao.quantization.pt2e.export_utils -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.ao.quantization.pt2e.export_utils - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - model_is_exported - -.. currentmodule:: torch.ao.quantization - -torch.ao.quantization.pt2e.lowering -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. currentmodule:: torch.ao.quantization.pt2e.lowering - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - lower_pt2e_quantized_to_x86 - -.. currentmodule:: torch.ao.quantization - -PT2 Export (pt2e) Numeric Debugger -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - generate_numeric_debug_handle - CUSTOM_KEY - NUMERIC_DEBUG_HANDLE_KEY - prepare_for_propagation_comparison - extract_results_from_loggers - compare_results - -torch (quantization related functions) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This describes the quantization related functions of the `torch` namespace. - -.. currentmodule:: torch - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - quantize_per_tensor - quantize_per_channel - dequantize - -torch.Tensor (quantization related methods) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Quantized Tensors support a limited subset of data manipulation methods of the -regular full-precision tensor. - -.. currentmodule:: torch.Tensor - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - view - as_strided - expand - flatten - select - ne - eq - ge - le - gt - lt - copy_ - clone - dequantize - equal - int_repr - max - mean - min - q_scale - q_zero_point - q_per_channel_scales - q_per_channel_zero_points - q_per_channel_axis - resize_ - sort - topk - - -torch.ao.quantization.observer -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains observers which are used to collect statistics about -the values observed during calibration (PTQ) or training (QAT). - -.. currentmodule:: torch.ao.quantization.observer - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - ObserverBase - MinMaxObserver - MovingAverageMinMaxObserver - PerChannelMinMaxObserver - MovingAveragePerChannelMinMaxObserver - HistogramObserver - PlaceholderObserver - RecordingObserver - NoopObserver - get_observer_state_dict - load_observer_state_dict - default_observer - default_placeholder_observer - default_debug_observer - default_weight_observer - default_histogram_observer - default_per_channel_weight_observer - default_dynamic_quant_observer - default_float_qparams_observer - AffineQuantizedObserverBase - Granularity - MappingType - PerAxis - PerBlock - PerGroup - PerRow - PerTensor - PerToken - TorchAODType - ZeroPointDomain - get_block_size - -torch.ao.quantization.fake_quantize -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module implements modules which are used to perform fake quantization -during QAT. - -.. currentmodule:: torch.ao.quantization.fake_quantize - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - FakeQuantizeBase - FakeQuantize - FixedQParamsFakeQuantize - FusedMovingAvgObsFakeQuantize - default_fake_quant - default_weight_fake_quant - default_per_channel_weight_fake_quant - default_histogram_fake_quant - default_fused_act_fake_quant - default_fused_wt_fake_quant - default_fused_per_channel_wt_fake_quant - disable_fake_quant - enable_fake_quant - disable_observer - enable_observer - -torch.ao.quantization.qconfig -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module defines `QConfig` objects which are used -to configure quantization settings for individual ops. - -.. currentmodule:: torch.ao.quantization.qconfig - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - QConfig - default_qconfig - default_debug_qconfig - default_per_channel_qconfig - default_dynamic_qconfig - float16_dynamic_qconfig - float16_static_qconfig - per_channel_dynamic_qconfig - float_qparams_weight_only_qconfig - default_qat_qconfig - default_weight_only_qconfig - default_activation_only_qconfig - default_qat_qconfig_v2 - -torch.ao.nn.intrinsic -~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.intrinsic -.. automodule:: torch.ao.nn.intrinsic.modules - -This module implements the combined (fused) modules conv + relu which can -then be quantized. - -.. currentmodule:: torch.ao.nn.intrinsic - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - ConvReLU1d - ConvReLU2d - ConvReLU3d - LinearReLU - ConvBn1d - ConvBn2d - ConvBn3d - ConvBnReLU1d - ConvBnReLU2d - ConvBnReLU3d - BNReLU2d - BNReLU3d - -torch.ao.nn.intrinsic.qat -~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.intrinsic.qat -.. automodule:: torch.ao.nn.intrinsic.qat.modules - - -This module implements the versions of those fused operations needed for -quantization aware training. - -.. currentmodule:: torch.ao.nn.intrinsic.qat - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - LinearReLU - ConvBn1d - ConvBnReLU1d - ConvBn2d - ConvBnReLU2d - ConvReLU2d - ConvBn3d - ConvBnReLU3d - ConvReLU3d - update_bn_stats - freeze_bn_stats - -torch.ao.nn.intrinsic.quantized -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.intrinsic.quantized -.. automodule:: torch.ao.nn.intrinsic.quantized.modules - - -This module implements the quantized implementations of fused operations -like conv + relu. No BatchNorm variants as it's usually folded into convolution -for inference. - -.. currentmodule:: torch.ao.nn.intrinsic.quantized - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - BNReLU2d - BNReLU3d - ConvReLU1d - ConvReLU2d - ConvReLU3d - LinearReLU - -torch.ao.nn.intrinsic.quantized.dynamic -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.intrinsic.quantized.dynamic -.. automodule:: torch.ao.nn.intrinsic.quantized.dynamic.modules - -This module implements the quantized dynamic implementations of fused operations -like linear + relu. - -.. currentmodule:: torch.ao.nn.intrinsic.quantized.dynamic - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - LinearReLU - -torch.ao.nn.qat -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.qat -.. automodule:: torch.ao.nn.qat.modules - -This module implements versions of the key nn modules **Conv2d()** and -**Linear()** which run in FP32 but with rounding applied to simulate the -effect of INT8 quantization. - -.. currentmodule:: torch.ao.nn.qat - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - Conv2d - Conv3d - Linear - -torch.ao.nn.qat.dynamic -~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.qat.dynamic -.. automodule:: torch.ao.nn.qat.dynamic.modules - -This module implements versions of the key nn modules such as **Linear()** -which run in FP32 but with rounding applied to simulate the effect of INT8 -quantization and will be dynamically quantized during inference. - -.. currentmodule:: torch.ao.nn.qat.dynamic - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - Linear - -torch.ao.nn.quantized -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.quantized - :noindex: -.. automodule:: torch.ao.nn.quantized.modules - -This module implements the quantized versions of the nn layers such as -~`torch.nn.Conv2d` and `torch.nn.ReLU`. - -.. currentmodule:: torch.ao.nn.quantized - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - ReLU6 - Hardswish - ELU - LeakyReLU - Sigmoid - BatchNorm2d - BatchNorm3d - Conv1d - Conv2d - Conv3d - ConvTranspose1d - ConvTranspose2d - ConvTranspose3d - Embedding - EmbeddingBag - FloatFunctional - FXFloatFunctional - QFunctional - Linear - LayerNorm - GroupNorm - InstanceNorm1d - InstanceNorm2d - InstanceNorm3d - -torch.ao.nn.quantized.functional -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.quantized.functional - -This module implements the quantized versions of the functional layers such as -~`torch.nn.functional.conv2d` and `torch.nn.functional.relu`. Note: -:meth:`~torch.nn.functional.relu` supports quantized inputs. - -.. currentmodule:: torch.ao.nn.quantized.functional - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - avg_pool2d - avg_pool3d - adaptive_avg_pool2d - adaptive_avg_pool3d - conv1d - conv2d - conv3d - interpolate - linear - max_pool1d - max_pool2d - celu - leaky_relu - hardtanh - hardswish - threshold - elu - hardsigmoid - clamp - upsample - upsample_bilinear - upsample_nearest - -torch.ao.nn.quantizable -~~~~~~~~~~~~~~~~~~~~~~~ - -This module implements the quantizable versions of some of the nn layers. -These modules can be used in conjunction with the custom module mechanism, -by providing the ``custom_module_config`` argument to both prepare and convert. - -.. currentmodule:: torch.ao.nn.quantizable - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - LSTM - MultiheadAttention - - -torch.ao.nn.quantized.dynamic -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: torch.ao.nn.quantized.dynamic -.. automodule:: torch.ao.nn.quantized.dynamic.modules - -Dynamically quantized :class:`~torch.nn.Linear`, :class:`~torch.nn.LSTM`, -:class:`~torch.nn.LSTMCell`, :class:`~torch.nn.GRUCell`, and -:class:`~torch.nn.RNNCell`. - -.. currentmodule:: torch.ao.nn.quantized.dynamic - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - Linear - LSTM - GRU - RNNCell - LSTMCell - GRUCell - -Quantized dtypes and quantization schemes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Note that operator implementations currently only -support per channel quantization for weights of the **conv** and **linear** -operators. Furthermore, the input data is -mapped linearly to the quantized data and vice versa -as follows: - - .. math:: - - \begin{aligned} - \text{Quantization:}&\\ - &Q_\text{out} = \text{clamp}(x_\text{input}/s+z, Q_\text{min}, Q_\text{max})\\ - \text{Dequantization:}&\\ - &x_\text{out} = (Q_\text{input}-z)*s - \end{aligned} - -where :math:`\text{clamp}(.)` is the same as :func:`~torch.clamp` while the -scale :math:`s` and zero point :math:`z` are then computed -as described in :class:`~torch.ao.quantization.observer.MinMaxObserver`, specifically: - - .. math:: - - \begin{aligned} - \text{if Symmetric:}&\\ - &s = 2 \max(|x_\text{min}|, x_\text{max}) / - \left( Q_\text{max} - Q_\text{min} \right) \\ - &z = \begin{cases} - 0 & \text{if dtype is qint8} \\ - 128 & \text{otherwise} - \end{cases}\\ - \text{Otherwise:}&\\ - &s = \left( x_\text{max} - x_\text{min} \right ) / - \left( Q_\text{max} - Q_\text{min} \right ) \\ - &z = Q_\text{min} - \text{round}(x_\text{min} / s) - \end{aligned} - -where :math:`[x_\text{min}, x_\text{max}]` denotes the range of the input data while -:math:`Q_\text{min}` and :math:`Q_\text{max}` are respectively the minimum and maximum values of the quantized dtype. - -Note that the choice of :math:`s` and :math:`z` implies that zero is represented with no quantization error whenever zero is within -the range of the input data or symmetric quantization is being used. - -Additional data types and quantization schemes can be implemented through -the `custom operator mechanism `_. - -* :attr:`torch.qscheme` — Type to describe the quantization scheme of a tensor. - Supported types: - - * :attr:`torch.per_tensor_affine` — per tensor, asymmetric - * :attr:`torch.per_channel_affine` — per channel, asymmetric - * :attr:`torch.per_tensor_symmetric` — per tensor, symmetric - * :attr:`torch.per_channel_symmetric` — per channel, symmetric - -* ``torch.dtype`` — Type to describe the data. Supported types: - - * :attr:`torch.quint8` — 8-bit unsigned integer - * :attr:`torch.qint8` — 8-bit signed integer - * :attr:`torch.qint32` — 32-bit signed integer - - -.. These modules are missing docs. Adding them here only for tracking -.. automodule:: torch.ao.nn.quantizable.modules - :noindex: -.. automodule:: torch.ao.nn.quantized.reference - :noindex: -.. automodule:: torch.ao.nn.quantized.reference.modules - :noindex: - -.. automodule:: torch.nn.quantizable -.. automodule:: torch.nn.qat.dynamic.modules -.. automodule:: torch.nn.qat.modules -.. automodule:: torch.nn.qat -.. automodule:: torch.nn.intrinsic.qat.modules -.. automodule:: torch.nn.quantized.dynamic -.. automodule:: torch.nn.intrinsic -.. automodule:: torch.nn.intrinsic.quantized.modules -.. automodule:: torch.quantization.fx -.. automodule:: torch.nn.intrinsic.quantized.dynamic -.. automodule:: torch.nn.qat.dynamic -.. automodule:: torch.nn.intrinsic.qat -.. automodule:: torch.nn.quantized.modules -.. automodule:: torch.nn.intrinsic.quantized -.. automodule:: torch.nn.quantizable.modules -.. automodule:: torch.nn.quantized -.. automodule:: torch.nn.intrinsic.quantized.dynamic.modules -.. automodule:: torch.nn.quantized.dynamic.modules -.. automodule:: torch.quantization -.. automodule:: torch.nn.intrinsic.modules diff --git a/docs/source/random.md b/docs/source/random.md new file mode 100644 index 00000000000000..432c0a83293bc5 --- /dev/null +++ b/docs/source/random.md @@ -0,0 +1,10 @@ +# torch.random + +```{eval-rst} +.. currentmodule:: torch.random +``` + +```{eval-rst} +.. automodule:: torch.random + :members: +``` diff --git a/docs/source/random.rst b/docs/source/random.rst deleted file mode 100644 index 45f98dff591d82..00000000000000 --- a/docs/source/random.rst +++ /dev/null @@ -1,7 +0,0 @@ -torch.random -=================================== - -.. currentmodule:: torch.random - -.. automodule:: torch.random - :members: diff --git a/docs/source/rpc.md b/docs/source/rpc.md new file mode 100644 index 00000000000000..77f8ec439aea67 --- /dev/null +++ b/docs/source/rpc.md @@ -0,0 +1,306 @@ +(distributed-rpc-framework)= + +# Distributed RPC Framework + +The distributed RPC framework provides mechanisms for multi-machine model +training through a set of primitives to allow for remote communication, and a +higher-level API to automatically differentiate models split across several +machines. + +```{warning} +APIs in the RPC package are stable. There are multiple ongoing work items +to improve performance and error handling, which will ship in future releases. +``` + +```{warning} +CUDA support was introduced in PyTorch 1.9 and is still a **beta** feature. +Not all features of the RPC package are yet compatible with CUDA support and +thus their use is discouraged. These unsupported features include: RRefs, +JIT compatibility, dist autograd and dist optimizer, and profiling. These +shortcomings will be addressed in future releases. +``` + +```{note} +Please refer to `PyTorch Distributed Overview `__ +for a brief introduction to all features related to distributed training. +``` + +## Basics + +The distributed RPC framework makes it easy to run functions remotely, supports +referencing remote objects without copying the real data around, and provides +autograd and optimizer APIs to transparently run backward and update parameters +across RPC boundaries. These features can be categorized into four sets of APIs. + +1) **Remote Procedure Call (RPC)** supports running a function on the specified + destination worker with the given arguments and getting the return value back + or creating a reference to the return value. There are three main RPC APIs: + {meth}`~torch.distributed.rpc.rpc_sync` (synchronous), + {meth}`~torch.distributed.rpc.rpc_async` (asynchronous), and + {meth}`~torch.distributed.rpc.remote` (asynchronous and returns a reference + to the remote return value). Use the synchronous API if the user code cannot + proceed without the return value. Otherwise, use the asynchronous API to get + a future, and wait on the future when the return value is needed on the + caller. The {meth}`~torch.distributed.rpc.remote` API is useful when the + requirement is to create something remotely but never need to fetch it to + the caller. Imagine the case that a driver process is setting up a parameter + server and a trainer. The driver can create an embedding table on the + parameter server and then share the reference to the embedding table with the + trainer, but itself will never use the embedding table locally. In this case, + {meth}`~torch.distributed.rpc.rpc_sync` and + {meth}`~torch.distributed.rpc.rpc_async` are no longer appropriate, as they + always imply that the return value will be returned to the caller + immediately or in the future. +2) **Remote Reference (RRef)** serves as a distributed shared pointer to a local + or remote object. It can be shared with other workers and reference counting + will be handled transparently. Each RRef only has one owner and the object + only lives on that owner. Non-owner workers holding RRefs can get copies of + the object from the owner by explicitly requesting it. This is useful when + a worker needs to access some data object, but itself is neither the creator + (the caller of {meth}`~torch.distributed.rpc.remote`) or the owner of the + object. The distributed optimizer, as we will discuss below, is one example + of such use cases. +3) **Distributed Autograd** stitches together local autograd engines on all the + workers involved in the forward pass, and automatically reach out to them + during the backward pass to compute gradients. This is especially helpful if + the forward pass needs to span multiple machines when conducting, e.g., + distributed model parallel training, parameter-server training, etc. With + this feature, user code no longer needs to worry about how to send gradients + across RPC boundaries and in which order should the local autograd engines + be launched, which can become quite complicated where there are nested and + inter-dependent RPC calls in the forward pass. +4) **Distributed Optimizer**'s constructor takes a + {meth}`~torch.optim.Optimizer` (e.g., {meth}`~torch.optim.SGD`, + {meth}`~torch.optim.Adagrad`, etc.) and a list of parameter RRefs, creates an + {meth}`~torch.optim.Optimizer` instance on each distinct RRef owner, and + updates parameters accordingly when running ``step()``. When you have + distributed forward and backward passes, parameters and gradients will be + scattered across multiple workers, and hence it requires an optimizer on each + of the involved workers. Distributed Optimizer wraps all those local + optimizers into one, and provides a concise constructor and ``step()`` API. + + +(rpc)= +## RPC + +Before using RPC and distributed autograd primitives, initialization must take +place. To initialize the RPC framework we need to use +{meth}`~torch.distributed.rpc.init_rpc` which would initialize the RPC +framework, RRef framework and distributed autograd. + +```{eval-rst} +.. automodule:: torch.distributed.rpc +.. autofunction:: init_rpc +``` + +The following APIs allow users to remotely execute functions as well as create +references (RRefs) to remote data objects. In these APIs, when passing a +``Tensor`` as an argument or a return value, the destination worker will try to +create a ``Tensor`` with the same meta (i.e., shape, stride, etc.). We +intentionally disallow transmitting CUDA tensors because it might crash if the +device lists on source and destination workers do not match. In such cases, +applications can always explicitly move the input tensors to CPU on the caller +and move it to the desired devices on the callee if necessary. + +```{warning} + TorchScript support in RPC is a prototype feature and subject to change. Since + v1.5.0, ``torch.distributed.rpc`` supports calling TorchScript functions as + RPC target functions, and this will help improve parallelism on the callee + side as executing TorchScript functions does not require GIL. +``` + +```{eval-rst} +.. autofunction:: rpc_sync +.. autofunction:: rpc_async +.. autofunction:: remote +.. autofunction:: get_worker_info +.. autofunction:: shutdown +.. autoclass:: WorkerInfo + :members: +``` + +The RPC package also provides decorators which allow applications to specify +how a given function should be treated on the callee side. + +```{eval-rst} +.. autofunction:: torch.distributed.rpc.functions.async_execution +``` + +(rpc-backends)= +### Backends + +The RPC module can leverage different backends to perform the communication +between the nodes. The backend to be used can be specified in the +{func}`~torch.distributed.rpc.init_rpc` function, by passing a certain value of +the {class}`~torch.distributed.rpc.BackendType` enum. Regardless of what backend +is used, the rest of the RPC API won't change. Each backend also defines its own +subclass of the {class}`~torch.distributed.rpc.RpcBackendOptions` class, an +instance of which can also be passed to {func}`~torch.distributed.rpc.init_rpc` +to configure the backend's behavior. + +```{eval-rst} +.. autoclass:: BackendType + +.. autoclass:: RpcBackendOptions + :members: +``` + +#### TensorPipe Backend + +The TensorPipe agent, which is the default, leverages [the TensorPipe library](https://github.com/pytorch/tensorpipe), which provides a natively +point-to-point communication primitive specifically suited for machine learning +that fundamentally addresses some of the limitations of Gloo. Compared to Gloo, +it has the advantage of being asynchronous, which allows a large number of +transfers to occur simultaneously, each at their own speed, without blocking +each other. It will only open pipes between pairs of nodes when needed, on +demand, and when one node fails only its incident pipes will be closed, while +all other ones will keep working as normal. In addition, it is able to support +multiple different transports (TCP, of course, but also shared memory, NVLink, +InfiniBand, ...) and can automatically detect their availability and negotiate +the best transport to use for each pipe. + +The TensorPipe backend has been introduced in PyTorch v1.6 and is being actively +developed. At the moment, it only supports CPU tensors, with GPU support coming +soon. It comes with a TCP-based transport, just like Gloo. It is also able to +automatically chunk and multiplex large tensors over multiple sockets and +threads in order to achieve very high bandwidths. The agent will be able to pick +the best transport on its own, with no intervention required. + +Example: + +```{code-block} python +import os +from torch.distributed import rpc +os.environ['MASTER_ADDR'] = 'localhost' +os.environ['MASTER_PORT'] = '29500' + +rpc.init_rpc( + "worker1", + rank=0, + world_size=2, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + num_worker_threads=8, + rpc_timeout=20 # 20 second timeout + ) +) + +# omitting init_rpc invocation on worker2 +``` + +```{eval-rst} +.. autoclass:: TensorPipeRpcBackendOptions + :members: + :inherited-members: +``` + +```{note} +The RPC framework does not automatically retry any +{meth}`~torch.distributed.rpc.rpc_sync`, +{meth}`~torch.distributed.rpc.rpc_async` and +{meth}`~torch.distributed.rpc.remote` calls. The reason being that there is +no way the RPC framework can determine whether an operation is idempotent or +not and whether it is safe to retry. As a result, it is the application's +responsibility to deal with failures and retry if necessary. RPC communication +is based on TCP and as a result failures could happen due to network failures +or intermittent network connectivity issues. In such scenarios, the application +needs to retry appropriately with reasonable backoffs to ensure the network +isn't overwhelmed by aggressive retries. +``` +(rref)= +## RRef + +```{warning} +RRefs are not currently supported when using CUDA tensors +``` + +An ``RRef`` (Remote REFerence) is a reference to a value of some type ``T`` +(e.g. ``Tensor``) on a remote worker. This handle keeps the referenced remote +value alive on the owner, but there is no implication that the value will be +transferred to the local worker in the future. RRefs can be used in +multi-machine training by holding references to [nn.Modules](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) that exist on +other workers, and calling the appropriate functions to retrieve or modify their +parameters during training. See {ref}`remote-reference-protocol` for more +details. + +```{eval-rst} +.. autoclass:: PyRRef(RRef) + :members: + :inherited-members: +``` + +```{toctree} +:caption: More Information about RRef + +rpc/rref +``` + +(remote_module)= + +## RemoteModule + +```{warning} +RemoteModule is not currently supported when using CUDA tensors +``` + +``RemoteModule`` is an easy way to create an nn.Module remotely on a different +process. The actual module resides on a remote host, but the local host has a +handle to this module and invoke this module similar to a regular nn.Module. +The invocation however incurs RPC calls to the remote end and can be performed +asynchronously if needed via additional APIs supported by RemoteModule. + +```{eval-rst} +.. autoclass:: torch.distributed.nn.api.remote_module.RemoteModule + :members: remote_parameters, get_module_rref +``` + +## Distributed Autograd Framework + +```{warning} +Distributed autograd is not currently supported when using CUDA tensors +``` + +This module provides an RPC-based distributed autograd framework that can be +used for applications such as model parallel training. In short, applications +may send and receive gradient recording tensors over RPC. In the forward pass, +we record when gradient recording tensors are sent over RPC and during the +backward pass we use this information to perform a distributed backward pass +using RPC. For more details see {ref}`distributed-autograd-design`. + +```{eval-rst} +.. automodule:: torch.distributed.autograd + :members: context, backward, get_gradients +``` + +```{toctree} +:caption: More Information about RPC Autograd + +rpc/distributed_autograd +``` + + +## Distributed Optimizer + +See the [torch.distributed.optim](https://pytorch.org/docs/main/distributed.optim.html) page for documentation on distributed optimizers. + +## Design Notes + +The distributed autograd design note covers the design of the RPC-based distributed autograd framework that is useful for applications such as model parallel training. + +- {ref}`distributed-autograd-design` + +The RRef design note covers the design of the {ref}`rref` (Remote REFerence) protocol used to refer to values on remote workers by the framework. + +- {ref}`remote-reference-protocol` + +## Tutorials + +The RPC tutorials introduce users to the RPC framework, provide several example applications +using {ref}`torch.distributed.rpc` APIs, and demonstrate how +to use [the profiler](https://pytorch.org/docs/stable/autograd.html#profiler) to profile RPC-based workloads. + +- [Getting started with Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_tutorial.html) +- [Implementing a Parameter Server using Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html) +- [Combining Distributed DataParallel with Distributed RPC Framework](https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html) (covers **RemoteModule** as well) +- [Profiling RPC-based Workloads](https://pytorch.org/tutorials/recipes/distributed_rpc_profiling.html) +- [Implementing batch RPC processing](https://pytorch.org/tutorials/intermediate/rpc_async_execution.html) +- [Distributed Pipeline Parallel](https://pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html) diff --git a/docs/source/rpc.rst b/docs/source/rpc.rst deleted file mode 100644 index 5c65a79aabf333..00000000000000 --- a/docs/source/rpc.rst +++ /dev/null @@ -1,300 +0,0 @@ -.. _distributed-rpc-framework: - -Distributed RPC Framework -========================= - -The distributed RPC framework provides mechanisms for multi-machine model -training through a set of primitives to allow for remote communication, and a -higher-level API to automatically differentiate models split across several -machines. - -.. warning :: - APIs in the RPC package are stable. There are multiple ongoing work items - to improve performance and error handling, which will ship in future releases. - -.. warning :: - CUDA support was introduced in PyTorch 1.9 and is still a **beta** feature. - Not all features of the RPC package are yet compatible with CUDA support and - thus their use is discouraged. These unsupported features include: RRefs, - JIT compatibility, dist autograd and dist optimizer, and profiling. These - shortcomings will be addressed in future releases. - -.. note :: - Please refer to `PyTorch Distributed Overview `__ - for a brief introduction to all features related to distributed training. - -Basics ------- - -The distributed RPC framework makes it easy to run functions remotely, supports -referencing remote objects without copying the real data around, and provides -autograd and optimizer APIs to transparently run backward and update parameters -across RPC boundaries. These features can be categorized into four sets of APIs. - -1) **Remote Procedure Call (RPC)** supports running a function on the specified - destination worker with the given arguments and getting the return value back - or creating a reference to the return value. There are three main RPC APIs: - :meth:`~torch.distributed.rpc.rpc_sync` (synchronous), - :meth:`~torch.distributed.rpc.rpc_async` (asynchronous), and - :meth:`~torch.distributed.rpc.remote` (asynchronous and returns a reference - to the remote return value). Use the synchronous API if the user code cannot - proceed without the return value. Otherwise, use the asynchronous API to get - a future, and wait on the future when the return value is needed on the - caller. The :meth:`~torch.distributed.rpc.remote` API is useful when the - requirement is to create something remotely but never need to fetch it to - the caller. Imagine the case that a driver process is setting up a parameter - server and a trainer. The driver can create an embedding table on the - parameter server and then share the reference to the embedding table with the - trainer, but itself will never use the embedding table locally. In this case, - :meth:`~torch.distributed.rpc.rpc_sync` and - :meth:`~torch.distributed.rpc.rpc_async` are no longer appropriate, as they - always imply that the return value will be returned to the caller - immediately or in the future. -2) **Remote Reference (RRef)** serves as a distributed shared pointer to a local - or remote object. It can be shared with other workers and reference counting - will be handled transparently. Each RRef only has one owner and the object - only lives on that owner. Non-owner workers holding RRefs can get copies of - the object from the owner by explicitly requesting it. This is useful when - a worker needs to access some data object, but itself is neither the creator - (the caller of :meth:`~torch.distributed.rpc.remote`) or the owner of the - object. The distributed optimizer, as we will discuss below, is one example - of such use cases. -3) **Distributed Autograd** stitches together local autograd engines on all the - workers involved in the forward pass, and automatically reach out to them - during the backward pass to compute gradients. This is especially helpful if - the forward pass needs to span multiple machines when conducting, e.g., - distributed model parallel training, parameter-server training, etc. With - this feature, user code no longer needs to worry about how to send gradients - across RPC boundaries and in which order should the local autograd engines - be launched, which can become quite complicated where there are nested and - inter-dependent RPC calls in the forward pass. -4) **Distributed Optimizer**'s constructor takes a - :meth:`~torch.optim.Optimizer` (e.g., :meth:`~torch.optim.SGD`, - :meth:`~torch.optim.Adagrad`, etc.) and a list of parameter RRefs, creates an - :meth:`~torch.optim.Optimizer` instance on each distinct RRef owner, and - updates parameters accordingly when running ``step()``. When you have - distributed forward and backward passes, parameters and gradients will be - scattered across multiple workers, and hence it requires an optimizer on each - of the involved workers. Distributed Optimizer wraps all those local - optimizers into one, and provides a concise constructor and ``step()`` API. - - -.. _rpc: - -RPC ---- - -Before using RPC and distributed autograd primitives, initialization must take -place. To initialize the RPC framework we need to use -:meth:`~torch.distributed.rpc.init_rpc` which would initialize the RPC -framework, RRef framework and distributed autograd. - -.. automodule:: torch.distributed.rpc -.. autofunction:: init_rpc - -The following APIs allow users to remotely execute functions as well as create -references (RRefs) to remote data objects. In these APIs, when passing a -``Tensor`` as an argument or a return value, the destination worker will try to -create a ``Tensor`` with the same meta (i.e., shape, stride, etc.). We -intentionally disallow transmitting CUDA tensors because it might crash if the -device lists on source and destination workers do not match. In such cases, -applications can always explicitly move the input tensors to CPU on the caller -and move it to the desired devices on the callee if necessary. - -.. warning:: - TorchScript support in RPC is a prototype feature and subject to change. Since - v1.5.0, ``torch.distributed.rpc`` supports calling TorchScript functions as - RPC target functions, and this will help improve parallelism on the callee - side as executing TorchScript functions does not require GIL. - - -.. autofunction:: rpc_sync -.. autofunction:: rpc_async -.. autofunction:: remote -.. autofunction:: get_worker_info -.. autofunction:: shutdown -.. autoclass:: WorkerInfo - :members: - - -The RPC package also provides decorators which allow applications to specify -how a given function should be treated on the callee side. - - -.. autofunction:: torch.distributed.rpc.functions.async_execution - - -.. _rpc-backends: - -Backends -^^^^^^^^ - -The RPC module can leverage different backends to perform the communication -between the nodes. The backend to be used can be specified in the -:func:`~torch.distributed.rpc.init_rpc` function, by passing a certain value of -the :class:`~torch.distributed.rpc.BackendType` enum. Regardless of what backend -is used, the rest of the RPC API won't change. Each backend also defines its own -subclass of the :class:`~torch.distributed.rpc.RpcBackendOptions` class, an -instance of which can also be passed to :func:`~torch.distributed.rpc.init_rpc` -to configure the backend's behavior. - -.. autoclass:: BackendType - -.. autoclass:: RpcBackendOptions - :members: - - -TensorPipe Backend -"""""""""""""""""" - -The TensorPipe agent, which is the default, leverages `the TensorPipe library -`_, which provides a natively -point-to-point communication primitive specifically suited for machine learning -that fundamentally addresses some of the limitations of Gloo. Compared to Gloo, -it has the advantage of being asynchronous, which allows a large number of -transfers to occur simultaneously, each at their own speed, without blocking -each other. It will only open pipes between pairs of nodes when needed, on -demand, and when one node fails only its incident pipes will be closed, while -all other ones will keep working as normal. In addition, it is able to support -multiple different transports (TCP, of course, but also shared memory, NVLink, -InfiniBand, ...) and can automatically detect their availability and negotiate -the best transport to use for each pipe. - -The TensorPipe backend has been introduced in PyTorch v1.6 and is being actively -developed. At the moment, it only supports CPU tensors, with GPU support coming -soon. It comes with a TCP-based transport, just like Gloo. It is also able to -automatically chunk and multiplex large tensors over multiple sockets and -threads in order to achieve very high bandwidths. The agent will be able to pick -the best transport on its own, with no intervention required. - -Example:: - - >>> import os - >>> from torch.distributed import rpc - >>> os.environ['MASTER_ADDR'] = 'localhost' - >>> os.environ['MASTER_PORT'] = '29500' - >>> - >>> rpc.init_rpc( - >>> "worker1", - >>> rank=0, - >>> world_size=2, - >>> rpc_backend_options=rpc.TensorPipeRpcBackendOptions( - >>> num_worker_threads=8, - >>> rpc_timeout=20 # 20 second timeout - >>> ) - >>> ) - >>> - >>> # omitting init_rpc invocation on worker2 - -.. autoclass:: TensorPipeRpcBackendOptions - :members: - :inherited-members: - -.. note :: - The RPC framework does not automatically retry any - :meth:`~torch.distributed.rpc.rpc_sync`, - :meth:`~torch.distributed.rpc.rpc_async` and - :meth:`~torch.distributed.rpc.remote` calls. The reason being that there is - no way the RPC framework can determine whether an operation is idempotent or - not and whether it is safe to retry. As a result, it is the application's - responsibility to deal with failures and retry if necessary. RPC communication - is based on TCP and as a result failures could happen due to network failures - or intermittent network connectivity issues. In such scenarios, the application - needs to retry appropriately with reasonable backoffs to ensure the network - isn't overwhelmed by aggressive retries. - -.. _rref: - -RRef ----- - -.. warning :: - RRefs are not currently supported when using CUDA tensors - -An ``RRef`` (Remote REFerence) is a reference to a value of some type ``T`` -(e.g. ``Tensor``) on a remote worker. This handle keeps the referenced remote -value alive on the owner, but there is no implication that the value will be -transferred to the local worker in the future. RRefs can be used in -multi-machine training by holding references to `nn.Modules -`_ that exist on -other workers, and calling the appropriate functions to retrieve or modify their -parameters during training. See :ref:`remote-reference-protocol` for more -details. - -.. autoclass:: PyRRef(RRef) - :members: - :inherited-members: - - -.. toctree:: - :caption: More Information about RRef - - rpc/rref - -.. _remote_module: - -RemoteModule ------------- - -.. warning :: - RemoteModule is not currently supported when using CUDA tensors - -``RemoteModule`` is an easy way to create an nn.Module remotely on a different -process. The actual module resides on a remote host, but the local host has a -handle to this module and invoke this module similar to a regular nn.Module. -The invocation however incurs RPC calls to the remote end and can be performed -asynchronously if needed via additional APIs supported by RemoteModule. - -.. autoclass:: torch.distributed.nn.api.remote_module.RemoteModule - :members: remote_parameters, get_module_rref - - -Distributed Autograd Framework ------------------------------- - -.. warning :: - Distributed autograd is not currently supported when using CUDA tensors - -This module provides an RPC-based distributed autograd framework that can be -used for applications such as model parallel training. In short, applications -may send and receive gradient recording tensors over RPC. In the forward pass, -we record when gradient recording tensors are sent over RPC and during the -backward pass we use this information to perform a distributed backward pass -using RPC. For more details see :ref:`distributed-autograd-design`. - -.. automodule:: torch.distributed.autograd - :members: context, backward, get_gradients - -.. toctree:: - :caption: More Information about RPC Autograd - - rpc/distributed_autograd - - -Distributed Optimizer ---------------------- - -See the `torch.distributed.optim `__ page for documentation on distributed optimizers. - -Design Notes ------------- -The distributed autograd design note covers the design of the RPC-based distributed autograd framework that is useful for applications such as model parallel training. - -- :ref:`distributed-autograd-design` - -The RRef design note covers the design of the :ref:`rref` (Remote REFerence) protocol used to refer to values on remote workers by the framework. - -- :ref:`remote-reference-protocol` - -Tutorials ---------- -The RPC tutorials introduce users to the RPC framework, provide several example applications -using :ref:`torch.distributed.rpc` APIs, and demonstrate how -to use `the profiler `__ to profile RPC-based workloads. - -- `Getting started with Distributed RPC Framework `__ -- `Implementing a Parameter Server using Distributed RPC Framework `__ -- `Combining Distributed DataParallel with Distributed RPC Framework `__ (covers **RemoteModule** as well) -- `Profiling RPC-based Workloads `__ -- `Implementing batch RPC processing `__ -- `Distributed Pipeline Parallel `__ diff --git a/docs/source/signal.md b/docs/source/signal.md new file mode 100644 index 00000000000000..b73609d11c3c25 --- /dev/null +++ b/docs/source/signal.md @@ -0,0 +1,37 @@ +```{role} hidden +:class: hidden-section +``` + +# torch.signal + +```{eval-rst} +.. automodule:: torch.signal +.. currentmodule:: torch.signal +``` + +The `torch.signal` module, modeled after SciPy's [signal](https://docs.scipy.org/doc/scipy/reference/signal.html)module. + +## torch.signal.windows + +```{eval-rst} +.. automodule:: torch.signal.windows +.. currentmodule:: torch.signal.windows +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + bartlett + blackman + cosine + exponential + gaussian + general_cosine + general_hamming + hamming + hann + kaiser + nuttall +``` \ No newline at end of file diff --git a/docs/source/signal.rst b/docs/source/signal.rst deleted file mode 100644 index ebb1cdb8a089ee..00000000000000 --- a/docs/source/signal.rst +++ /dev/null @@ -1,31 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.signal -============ -.. automodule:: torch.signal -.. currentmodule:: torch.signal - -The `torch.signal` module, modeled after SciPy's `signal `_ module. - -torch.signal.windows --------------------- - -.. automodule:: torch.signal.windows -.. currentmodule:: torch.signal.windows - -.. autosummary:: - :toctree: generated - :nosignatures: - - bartlett - blackman - cosine - exponential - gaussian - general_cosine - general_hamming - hamming - hann - kaiser - nuttall diff --git a/docs/source/size.md b/docs/source/size.md new file mode 100644 index 00000000000000..5ebba9a2e401d3 --- /dev/null +++ b/docs/source/size.md @@ -0,0 +1,26 @@ +# torch.Size + +{class}`torch.Size` is the result type of a call to {func}`torch.Tensor.size`. It describes the size of all dimensions +of the original tensor. As a subclass of {class}`tuple`, it supports common sequence operations like indexing and +length. + + +Example: + +```{code-block} python + >>> x = torch.ones(10, 20, 30) + >>> s = x.size() + >>> s + torch.Size([10, 20, 30]) + >>> s[1] + 20 + >>> len(s) + 3 +``` + +```{eval-rst} +.. autoclass:: torch.Size + :members: + :undoc-members: + :inherited-members: +``` \ No newline at end of file diff --git a/docs/source/size.rst b/docs/source/size.rst deleted file mode 100644 index 340836e000e1f5..00000000000000 --- a/docs/source/size.rst +++ /dev/null @@ -1,25 +0,0 @@ -torch.Size -=================================== - -:class:`torch.Size` is the result type of a call to :func:`torch.Tensor.size`. It describes the size of all dimensions -of the original tensor. As a subclass of :class:`tuple`, it supports common sequence operations like indexing and -length. - - -Example:: - - >>> x = torch.ones(10, 20, 30) - >>> s = x.size() - >>> s - torch.Size([10, 20, 30]) - >>> s[1] - 20 - >>> len(s) - 3 - - - -.. autoclass:: torch.Size - :members: - :undoc-members: - :inherited-members: diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 8f7b025f354ccb..652667b5ea7981 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -360,8 +360,7 @@ Suppose we want to define a sparse tensor with the entry 3 at location Unspecified elements are assumed to have the same value, fill value, which is zero by default. We would then write: - >>> i = [[0, 1, 1], - [2, 0, 2]] + >>> i = [[0, 1, 1], [2, 0, 2]] >>> v = [3, 4, 5] >>> s = torch.sparse_coo_tensor(i, v, (2, 3)) >>> s @@ -1070,7 +1069,7 @@ Tools for working with sparse compressed tensors ------------------------------------------------ All sparse compressed tensors --- CSR, CSC, BSR, and BSC tensors --- -are conceptionally very similar in that their indices data is split +are conceptually very similar in that their indices data is split into two parts: so-called compressed indices that use the CSR encoding, and so-called plain indices that are orthogonal to the compressed indices. This allows various tools on these tensors to diff --git a/docs/source/special.md b/docs/source/special.md new file mode 100644 index 00000000000000..24dc10756a1eb1 --- /dev/null +++ b/docs/source/special.md @@ -0,0 +1,73 @@ +```{role} hidden +:class: hidden-section +``` + +# torch.special + +The torch.special module, modeled after SciPy's [special](https://docs.scipy.org/doc/scipy/reference/special.html) module. + +```{eval-rst} +.. automodule:: torch.special +.. currentmodule:: torch.special +``` + +## Functions + +```{eval-rst} +.. autofunction:: airy_ai +.. autofunction:: bessel_j0 +.. autofunction:: bessel_j1 +.. autofunction:: bessel_y0 +.. autofunction:: bessel_y1 +.. autofunction:: chebyshev_polynomial_t +.. autofunction:: chebyshev_polynomial_u +.. autofunction:: chebyshev_polynomial_v +.. autofunction:: chebyshev_polynomial_w +.. autofunction:: digamma +.. autofunction:: entr +.. autofunction:: erf +.. autofunction:: erfc +.. autofunction:: erfcx +.. autofunction:: erfinv +.. autofunction:: exp2 +.. autofunction:: expit +.. autofunction:: expm1 +.. autofunction:: gammainc +.. autofunction:: gammaincc +.. autofunction:: gammaln +.. autofunction:: hermite_polynomial_h +.. autofunction:: hermite_polynomial_he +.. autofunction:: i0 +.. autofunction:: i0e +.. autofunction:: i1 +.. autofunction:: i1e +.. autofunction:: laguerre_polynomial_l +.. autofunction:: legendre_polynomial_p +.. autofunction:: log1p +.. autofunction:: log_ndtr +.. autofunction:: log_softmax +.. autofunction:: logit +.. autofunction:: logsumexp +.. autofunction:: modified_bessel_i0 +.. autofunction:: modified_bessel_i1 +.. autofunction:: modified_bessel_k0 +.. autofunction:: modified_bessel_k1 +.. autofunction:: multigammaln +.. autofunction:: ndtr +.. autofunction:: ndtri +.. autofunction:: polygamma +.. autofunction:: psi +.. autofunction:: round +.. autofunction:: scaled_modified_bessel_k0 +.. autofunction:: scaled_modified_bessel_k1 +.. autofunction:: shifted_chebyshev_polynomial_t +.. autofunction:: shifted_chebyshev_polynomial_u +.. autofunction:: shifted_chebyshev_polynomial_v +.. autofunction:: shifted_chebyshev_polynomial_w +.. autofunction:: sinc +.. autofunction:: softmax +.. autofunction:: spherical_bessel_j0 +.. autofunction:: xlog1py +.. autofunction:: xlogy +.. autofunction:: zeta +``` \ No newline at end of file diff --git a/docs/source/special.rst b/docs/source/special.rst deleted file mode 100644 index 96179475469aaf..00000000000000 --- a/docs/source/special.rst +++ /dev/null @@ -1,52 +0,0 @@ -.. role:: hidden - :class: hidden-section - -torch.special -============= - -The torch.special module, modeled after SciPy's `special `_ module. - -.. automodule:: torch.special -.. currentmodule:: torch.special - -Functions ------------------------ - -.. autofunction:: airy_ai -.. autofunction:: bessel_j0 -.. autofunction:: bessel_j1 -.. autofunction:: digamma -.. autofunction:: entr -.. autofunction:: erf -.. autofunction:: erfc -.. autofunction:: erfcx -.. autofunction:: erfinv -.. autofunction:: exp2 -.. autofunction:: expit -.. autofunction:: expm1 -.. autofunction:: gammainc -.. autofunction:: gammaincc -.. autofunction:: gammaln -.. autofunction:: i0 -.. autofunction:: i0e -.. autofunction:: i1 -.. autofunction:: i1e -.. autofunction:: log1p -.. autofunction:: log_ndtr -.. autofunction:: log_softmax -.. autofunction:: logit -.. autofunction:: logsumexp -.. autofunction:: multigammaln -.. autofunction:: ndtr -.. autofunction:: ndtri -.. autofunction:: polygamma -.. autofunction:: psi -.. autofunction:: round -.. autofunction:: scaled_modified_bessel_k0 -.. autofunction:: scaled_modified_bessel_k1 -.. autofunction:: sinc -.. autofunction:: softmax -.. autofunction:: spherical_bessel_j0 -.. autofunction:: xlog1py -.. autofunction:: xlogy -.. autofunction:: zeta diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst index deb85a8773f8d1..eda8dbce234ce1 100644 --- a/docs/source/tensor_attributes.rst +++ b/docs/source/tensor_attributes.rst @@ -17,30 +17,65 @@ torch.dtype A :class:`torch.dtype` is an object that represents the data type of a :class:`torch.Tensor`. PyTorch has several different data types: -========================== =========================================== =========================== -Data type dtype Legacy Constructors -========================== =========================================== =========================== -32-bit floating point ``torch.float32`` or ``torch.float`` ``torch.*.FloatTensor`` -64-bit floating point ``torch.float64`` or ``torch.double`` ``torch.*.DoubleTensor`` -32-bit complex ``torch.complex32`` or ``torch.chalf`` -64-bit complex ``torch.complex64`` or ``torch.cfloat`` -128-bit complex ``torch.complex128`` or ``torch.cdouble`` -16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` ``torch.*.HalfTensor`` -16-bit floating point [2]_ ``torch.bfloat16`` ``torch.*.BFloat16Tensor`` -8-bit integer (unsigned) ``torch.uint8`` ``torch.*.ByteTensor`` -8-bit integer (signed) ``torch.int8`` ``torch.*.CharTensor`` -16-bit integer (signed) ``torch.int16`` or ``torch.short`` ``torch.*.ShortTensor`` -32-bit integer (signed) ``torch.int32`` or ``torch.int`` ``torch.*.IntTensor`` -64-bit integer (signed) ``torch.int64`` or ``torch.long`` ``torch.*.LongTensor`` -Boolean ``torch.bool`` ``torch.*.BoolTensor`` -========================== =========================================== =========================== - -.. [1] Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10 - significand bits. Useful when precision is important. - -.. [2] Sometimes referred to as Brain Floating Point: use 1 sign, 8 exponent and 7 - significand bits. Useful when range is important, since it has the same - number of exponent bits as ``float32`` +**Floating point dtypes** + +========================================= =============================== +dtype description +========================================= =============================== +``torch.float32`` or ``torch.float`` 32-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754 +``torch.float64`` or ``torch.double`` 64-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754 +``torch.float16`` or ``torch.half`` 16-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754, S-E-M 1-5-10 +``torch.bfloat16`` 16-bit floating point, sometimes referred to as Brain floating point, S-E-M 1-8-7 +``torch.complex32`` or ``torch.chalf`` 32-bit complex with two `float16` components +``torch.complex64`` or ``torch.cfloat`` 64-bit complex with two `float32` components +``torch.complex128`` or ``torch.cdouble`` 128-bit complex with two `float64` components +``torch.float8_e4m3fn`` [shell]_, [1]_ 8-bit floating point, S-E-M 1-4-3, from https://arxiv.org/abs/2209.05433 +``torch.float8_e5m2`` [shell]_ 8-bit floating point, S-E-M 1-5-2, from https://arxiv.org/abs/2209.05433 +``torch.float8_e4m3fnuz`` [shell]_, [1]_ 8-bit floating point, S-E-M 1-4-3, from https://arxiv.org/pdf/2206.02915 +``torch.float8_e5m2fnuz`` [shell]_, [1]_ 8-bit floating point, S-E-M 1-5-2, from https://arxiv.org/pdf/2206.02915 +``torch.float8_e8m0fnu`` [shell]_, [1]_ 8-bit floating point, S-E-M 0-8-0, from https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +``torch.float4_e2m1fn_x2`` [shell]_, [1]_ packed 4-bit floating point, S-E-M 1-2-1, from https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +========================================= =============================== + +**Integer dtypes** + +========================================= =============================== +dtype description +========================================= =============================== +``torch.uint8`` 8-bit integer (unsigned) +``torch.int8`` 8-bit integer (signed) +``torch.uint16`` [shell]_, [2]_ 16-bit integer (unsigned) +``torch.int16`` or ``torch.short`` 16-bit integer (signed) +``torch.uint32`` [shell]_, [2]_ 32-bit integer (unsigned) +``torch.int32`` or ``torch.int`` 32-bit integer (signed) +``torch.uint64`` [shell]_, [2]_ 64-bit integer (unsigned) +``torch.int64`` or ``torch.long`` 64-bit integer (signed) +``torch.bool`` Boolean +========================================= =============================== + +.. [shell] a shell dtype a specialized dtype with limited op and backend support. + Specifically, ops that support tensor creation (``torch.empty``, ``torch.fill``, ``torch.zeros``) + and operations which do not peek inside the data elements (``torch.cat``, ``torch.view``, ``torch.reshape``) + are supported. Ops that peek inside the data elements such as casting, + matrix multiplication, nan/inf checks are supported only on a case by + case basis, depending on maturity and presence of hardware accelerated kernels + and established use cases. + +.. [1] The "fn", "fnu" and "fnuz" dtype suffixes mean: + "f" - finite value encodings only, no infinity; + "n" - nan value encodings differ from the IEEE spec; + "uz" - "unsigned zero" only, i.e. no negative zero encoding + +.. [2] + Unsigned types asides from ``uint8`` are currently planned to only have + limited support in eager mode (they primarily exist to assist usage with + torch.compile); if you need eager support and the extra range is not needed, + we recommend using their signed variants instead. See + https://github.com/pytorch/pytorch/issues/58734 for more details. + +**Note**: legacy constructors such as ``torch.*.FloatTensor``, ``torch.*.DoubleTensor``, ``torch.*.HalfTensor``, +``torch.*.BFloat16Tensor``, ``torch.*.ByteTensor``, ``torch.*.CharTensor``, ``torch.*.ShortTensor``, ``torch.*.IntTensor``, +``torch.*.LongTensor``, ``torch.*.BoolTensor`` only remain for backwards compatibility and should no longer be used. To find out if a :class:`torch.dtype` is a floating point data type, the property :attr:`is_floating_point` can be used, which returns ``True`` if the data type is a floating point data type. @@ -64,8 +99,8 @@ by finding the minimum dtype that satisfies the following rules: A floating point scalar operand has dtype `torch.get_default_dtype()` and an integral non-boolean scalar operand has dtype `torch.int64`. Unlike numpy, we do not inspect -values when determining the minimum `dtypes` of an operand. Quantized and complex types -are not yet supported. +values when determining the minimum `dtypes` of an operand. Complex types +are not yet supported. Promotion for shell dtypes is not defined. Promotion Examples:: @@ -149,9 +184,13 @@ the result of :func:`torch.cuda.current_device()`. A :class:`torch.Tensor`'s device can be accessed via the :attr:`Tensor.device` property. -A :class:`torch.device` can be constructed via a string or via a string and device ordinal +A :class:`torch.device` can be constructed using: -Via a string: + * A device string, which is a string representation of the device type and optionally the device ordinal. + * A device type and a device ordinal. + * A device ordinal, where the current :ref:`accelerator` type will be used. + +Via a device string: :: >>> torch.device('cuda:0') @@ -163,10 +202,10 @@ Via a string: >>> torch.device('mps') device(type='mps') - >>> torch.device('cuda') # current cuda device + >>> torch.device('cuda') # implicit index is the "current device index" device(type='cuda') -Via a string and device ordinal: +Via a device type and a device ordinal: :: @@ -179,6 +218,24 @@ Via a string and device ordinal: >>> torch.device('cpu', 0) device(type='cpu', index=0) +Via a device ordinal: + +.. note:: + This method will raise a RuntimeError if no accelerator is currently detected. + +:: + + >>> torch.device(0) # the current accelerator is cuda + device(type='cuda', index=0) + + >>> torch.device(1) # the current accelerator is xpu + device(type='xpu', index=1) + + >>> torch.device(0) # no current accelerator detected + Traceback (most recent call last): + File "", line 1, in + RuntimeError: Cannot access accelerator device when none is available. + The device object can also be used as a context manager to change the default device tensors are allocated on: @@ -211,22 +268,13 @@ non-None device argument. To globally change the default device, see also >>> # You can substitute the torch.device with a string >>> torch.randn((2,3), device='cuda:1') -.. note:: - For legacy reasons, a device can be constructed via a single device ordinal, which is treated - as the current :ref:`accelerator` type. - This matches :meth:`Tensor.get_device`, which returns an ordinal for device - tensors and is not supported for cpu tensors. - - >>> torch.device(1) - device(type='cuda', index=1) - .. note:: Methods which take a device will generally accept a (properly formatted) string - or (legacy) integer device ordinal, i.e. the following are all equivalent: + or an integer device ordinal, i.e. the following are all equivalent: >>> torch.randn((2,3), device=torch.device('cuda:1')) >>> torch.randn((2,3), device='cuda:1') - >>> torch.randn((2,3), device=1) # legacy + >>> torch.randn((2,3), device=1) # equivalent to 'cuda:1' if the current accelerator is cuda .. note:: Tensors are never moved automatically between devices and require an explicit call from the user. Scalar Tensors (with tensor.dim()==0) are the only exception to this rule and they are automatically transferred from CPU to GPU when needed as this operation can be done "for free". diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 3f9a96ac7da659..c2336dfd81ec0e 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -6,84 +6,7 @@ torch.Tensor =================================== A :class:`torch.Tensor` is a multi-dimensional matrix containing elements of -a single data type. - - -Data types ----------- - -Torch defines tensor types with the following data types: - -======================================= =========================================== -Data type dtype -======================================= =========================================== -32-bit floating point ``torch.float32`` or ``torch.float`` -64-bit floating point ``torch.float64`` or ``torch.double`` -16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` -16-bit floating point [2]_ ``torch.bfloat16`` -32-bit complex ``torch.complex32`` or ``torch.chalf`` -64-bit complex ``torch.complex64`` or ``torch.cfloat`` -128-bit complex ``torch.complex128`` or ``torch.cdouble`` -8-bit integer (unsigned) ``torch.uint8`` -16-bit integer (unsigned) ``torch.uint16`` (limited support) [4]_ -32-bit integer (unsigned) ``torch.uint32`` (limited support) [4]_ -64-bit integer (unsigned) ``torch.uint64`` (limited support) [4]_ -8-bit integer (signed) ``torch.int8`` -16-bit integer (signed) ``torch.int16`` or ``torch.short`` -32-bit integer (signed) ``torch.int32`` or ``torch.int`` -64-bit integer (signed) ``torch.int64`` or ``torch.long`` -Boolean ``torch.bool`` -quantized 8-bit integer (unsigned) ``torch.quint8`` -quantized 8-bit integer (signed) ``torch.qint8`` -quantized 32-bit integer (signed) ``torch.qint32`` -quantized 4-bit integer (unsigned) [3]_ ``torch.quint4x2`` -8-bit floating point, e4m3 [5]_ ``torch.float8_e4m3fn`` (limited support) -8-bit floating point, e5m2 [5]_ ``torch.float8_e5m2`` (limited support) -======================================= =========================================== - -.. [1] - Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10 - significand bits. Useful when precision is important at the expense of range. -.. [2] - Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7 - significand bits. Useful when range is important, since it has the same - number of exponent bits as ``float32`` -.. [3] - quantized 4-bit integer is stored as a 8-bit signed integer. Currently it's only supported in EmbeddingBag operator. -.. [4] - Unsigned types asides from ``uint8`` are currently planned to only have - limited support in eager mode (they primarily exist to assist usage with - torch.compile); if you need eager support and the extra range is not needed, - we recommend using their signed variants instead. See - https://github.com/pytorch/pytorch/issues/58734 for more details. -.. [5] - ``torch.float8_e4m3fn`` and ``torch.float8_e5m2`` implement the spec for 8-bit - floating point types from https://arxiv.org/abs/2209.05433. The op support - is very limited. - - -For backwards compatibility, we support the following alternate class names -for these data types: - -======================================= ============================= ================================ -Data type CPU tensor GPU tensor -======================================= ============================= ================================ -32-bit floating point :class:`torch.FloatTensor` :class:`torch.cuda.FloatTensor` -64-bit floating point :class:`torch.DoubleTensor` :class:`torch.cuda.DoubleTensor` -16-bit floating point :class:`torch.HalfTensor` :class:`torch.cuda.HalfTensor` -16-bit floating point :class:`torch.BFloat16Tensor` :class:`torch.cuda.BFloat16Tensor` -8-bit integer (unsigned) :class:`torch.ByteTensor` :class:`torch.cuda.ByteTensor` -8-bit integer (signed) :class:`torch.CharTensor` :class:`torch.cuda.CharTensor` -16-bit integer (signed) :class:`torch.ShortTensor` :class:`torch.cuda.ShortTensor` -32-bit integer (signed) :class:`torch.IntTensor` :class:`torch.cuda.IntTensor` -64-bit integer (signed) :class:`torch.LongTensor` :class:`torch.cuda.LongTensor` -Boolean :class:`torch.BoolTensor` :class:`torch.cuda.BoolTensor` -======================================= ============================= ================================ - -However, to construct tensors, we recommend using factory functions such as -:func:`torch.empty` with the ``dtype`` argument instead. The -:class:`torch.Tensor` constructor is an alias for the default tensor type -(:class:`torch.FloatTensor`). +a single data type. Please see :ref:`dtype-doc` for more details about dtype support. Initializing and basic operations --------------------------------- diff --git a/docs/source/testing.md b/docs/source/testing.md new file mode 100644 index 00000000000000..c18a4a1d9ed4d2 --- /dev/null +++ b/docs/source/testing.md @@ -0,0 +1,21 @@ +# torch.testing + +```{eval-rst} +.. automodule:: torch.testing +``` + +```{eval-rst} +.. currentmodule:: torch.testing +``` + +```{eval-rst} +.. autofunction:: assert_close +``` + +```{eval-rst} +.. autofunction:: make_tensor +``` + +```{eval-rst} +.. autofunction:: assert_allclose +``` diff --git a/docs/source/testing.rst b/docs/source/testing.rst deleted file mode 100644 index 8837c4a0ec1a76..00000000000000 --- a/docs/source/testing.rst +++ /dev/null @@ -1,9 +0,0 @@ -torch.testing -============= - -.. automodule:: torch.testing -.. currentmodule:: torch.testing - -.. autofunction:: assert_close -.. autofunction:: make_tensor -.. autofunction:: assert_allclose diff --git a/docs/source/threading_environment_variables.md b/docs/source/threading_environment_variables.md new file mode 100644 index 00000000000000..590ae6391ecb5b --- /dev/null +++ b/docs/source/threading_environment_variables.md @@ -0,0 +1,14 @@ +(threading_environment_variables)= + +# Threading Environment Variables + +```{list-table} +:header-rows: 1 + +* - Variable + - Description +* - ``OMP_NUM_THREADS`` + - Sets the maximum number of threads to use for OpenMP parallel regions. +* - ``MKL_NUM_THREADS`` + - Sets the maximum number of threads to use for the Intel MKL library. Note that MKL_NUM_THREADS takes precedence over ``OMP_NUM_THREADS``. +``` diff --git a/docs/source/threading_environment_variables.rst b/docs/source/threading_environment_variables.rst deleted file mode 100644 index aaca33392961cc..00000000000000 --- a/docs/source/threading_environment_variables.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. _threading_environment_variables: - -Threading Environment Variables -=============================== -.. list-table:: - :header-rows: 1 - - * - Variable - - Description - * - ``OMP_NUM_THREADS`` - - Sets the maximum number of threads to use for OpenMP parallel regions. - * - ``MKL_NUM_THREADS`` - - Sets the maximum number of threads to use for the Intel MKL library. Note that MKL_NUM_THREADS takes precedence over ``OMP_NUM_THREADS``. \ No newline at end of file diff --git a/docs/source/torch.ao.ns._numeric_suite.md b/docs/source/torch.ao.ns._numeric_suite.md new file mode 100644 index 00000000000000..b1466470fe26c4 --- /dev/null +++ b/docs/source/torch.ao.ns._numeric_suite.md @@ -0,0 +1,16 @@ +(torch_ao_ns_numeric_suite)= + +# torch.ao.ns._numeric_suite + +```{warning} +This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. currentmodule:: torch.ao.ns._numeric_suite +``` +```{eval-rst} +.. automodule:: torch.ao.ns._numeric_suite + :members: + :member-order: bysource +``` diff --git a/docs/source/torch.ao.ns._numeric_suite.rst b/docs/source/torch.ao.ns._numeric_suite.rst deleted file mode 100644 index a3d6d4b8ff5b7d..00000000000000 --- a/docs/source/torch.ao.ns._numeric_suite.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. _torch_ao_ns_numeric_suite: - -torch.ao.ns._numeric_suite --------------------------- - -.. warning :: - This module is an early prototype and is subject to change. - -.. currentmodule:: torch.ao.ns._numeric_suite - -.. automodule:: torch.ao.ns._numeric_suite - :members: - :member-order: bysource diff --git a/docs/source/torch.ao.ns._numeric_suite_fx.md b/docs/source/torch.ao.ns._numeric_suite_fx.md new file mode 100644 index 00000000000000..46a46d598f4f54 --- /dev/null +++ b/docs/source/torch.ao.ns._numeric_suite_fx.md @@ -0,0 +1,39 @@ +(torch_ao_ns_numeric_suite_fx)= + +# torch.ao.ns._numeric_suite_fx + + +```{warning} + This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. automodule:: torch.ao.ns._numeric_suite_fx + :members: + :member-order: bysource + +``` +--- + +# torch.ao.ns.fx.utils + + +```{warning} + This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. currentmodule:: torch.ao.ns.fx.utils +``` + +```{eval-rst} +.. function:: compute_sqnr(x, y) +``` + +```{eval-rst} +.. function:: compute_normalized_l2_error(x, y) +``` + +```{eval-rst} +.. function:: compute_cosine_similarity(x, y) +``` \ No newline at end of file diff --git a/docs/source/torch.ao.ns._numeric_suite_fx.rst b/docs/source/torch.ao.ns._numeric_suite_fx.rst deleted file mode 100644 index eb466b45cb601d..00000000000000 --- a/docs/source/torch.ao.ns._numeric_suite_fx.rst +++ /dev/null @@ -1,26 +0,0 @@ -.. _torch_ao_ns_numeric_suite_fx: - -torch.ao.ns._numeric_suite_fx ------------------------------ - -.. warning :: - This module is an early prototype and is subject to change. - -.. currentmodule:: torch.ao.ns._numeric_suite_fx - -.. automodule:: torch.ao.ns._numeric_suite_fx - :members: - :member-order: bysource - - -torch.ao.ns.fx.utils --------------------------------------- - -.. warning :: - This module is an early prototype and is subject to change. - -.. currentmodule:: torch.ao.ns.fx.utils - -.. autofunction:: torch.ao.ns.fx.utils.compute_sqnr(x, y) -.. autofunction:: torch.ao.ns.fx.utils.compute_normalized_l2_error(x, y) -.. autofunction:: torch.ao.ns.fx.utils.compute_cosine_similarity(x, y) diff --git a/docs/source/torch.compiler.config.md b/docs/source/torch.compiler.config.md new file mode 100644 index 00000000000000..66059f07ea5b0b --- /dev/null +++ b/docs/source/torch.compiler.config.md @@ -0,0 +1,14 @@ +```{eval-rst} +.. currentmodule:: torch.compiler.config + +``` + +# torch.compiler.config + +```{eval-rst} +.. automodule:: torch.compiler.config +``` + +```{eval-rst} +.. autodata:: torch.compiler.config.job_id +``` diff --git a/docs/source/torch.compiler.config.rst b/docs/source/torch.compiler.config.rst deleted file mode 100644 index c40b41fdb5d31c..00000000000000 --- a/docs/source/torch.compiler.config.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. currentmodule:: torch.compiler.config - - -torch.compiler.config -===================== - -.. automodule:: torch.compiler.config - -.. autodata:: torch.compiler.config.job_id diff --git a/docs/source/torch.compiler.md b/docs/source/torch.compiler.md new file mode 100644 index 00000000000000..5f12670f5e1def --- /dev/null +++ b/docs/source/torch.compiler.md @@ -0,0 +1,128 @@ +(torch.compiler_overview)= + +# torch.compiler + +`torch.compiler` is a namespace through which some of the internal compiler +methods are surfaced for user consumption. The main function and the feature in +this namespace is `torch.compile`. + +`torch.compile` is a PyTorch function introduced in PyTorch 2.x that aims to +solve the problem of accurate graph capturing in PyTorch and ultimately enable +software engineers to run their PyTorch programs faster. `torch.compile` is +written in Python and it marks the transition of PyTorch from C++ to Python. + +`torch.compile` leverages the following underlying technologies: + +- **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython + feature called the Frame Evaluation API to safely capture PyTorch graphs. + Methods that are available externally for PyTorch users are surfaced + through the `torch.compiler` namespace. +- **TorchInductor** is the default `torch.compile` deep learning compiler + that generates fast code for multiple accelerators and backends. You + need to use a backend compiler to make speedups through `torch.compile` + possible. For NVIDIA, AMD and Intel GPUs, it leverages OpenAI Triton as the key + building block. +- **AOT Autograd** captures not only the user-level code, but also backpropagation, + which results in capturing the backwards pass "ahead-of-time". This enables + acceleration of both forwards and backwards pass using TorchInductor. + +:::{note} +In some cases, the terms `torch.compile`, TorchDynamo, `torch.compiler` +might be used interchangeably in this documentation. +::: + +As mentioned above, to run your workflows faster, `torch.compile` through +TorchDynamo requires a backend that converts the captured graphs into a fast +machine code. Different backends can result in various optimization gains. +The default backend is called TorchInductor, also known as *inductor*, +TorchDynamo has a list of supported backends developed by our partners, +which can be seen by running `torch.compiler.list_backends()` each of which +with its optional dependencies. + +Some of the most commonly used backends include: + +**Training & inference backends** + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Backend + - Description + * - ``torch.compile(m, backend="inductor")`` + - Uses the TorchInductor backend. `Read more `__ + * - ``torch.compile(m, backend="cudagraphs")`` + - CUDA graphs with AOT Autograd. `Read more `__ + * - ``torch.compile(m, backend="ipex")`` + - Uses IPEX on CPU. `Read more `__ + * - ``torch.compile(m, backend="onnxrt")`` + - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more ` +``` + +**Inference-only backends** + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Backend + - Description + * - ``torch.compile(m, backend="tensorrt")`` + - Uses Torch-TensorRT for inference optimizations. Requires ``import torch_tensorrt`` in the calling script to register backend. `Read more `__ + * - ``torch.compile(m, backend="ipex")`` + - Uses IPEX for inference on CPU. `Read more `__ + * - ``torch.compile(m, backend="tvm")`` + - Uses Apache TVM for inference optimizations. `Read more `__ + * - ``torch.compile(m, backend="openvino")`` + - Uses OpenVINO for inference optimizations. `Read more `__ +``` + +## Read More + +```{eval-rst} +.. toctree:: + :caption: Getting Started for PyTorch Users + :maxdepth: 1 + + torch.compiler_get_started + torch.compiler_api + torch.compiler.config + torch.compiler_fine_grain_apis + torch.compiler_backward + torch.compiler_aot_inductor + torch.compiler_inductor_profiling + torch.compiler_profiling_torch_compile + torch.compiler_faq + torch.compiler_troubleshooting + torch.compiler_performance_dashboard + torch.compiler_inductor_provenance +``` + +% _If you want to contribute a developer-level topic +% that provides in-depth overview of a torch._dynamo feature, +% add in the below toc. + +```{eval-rst} +.. toctree:: + :caption: Deep Dive for PyTorch Developers + :maxdepth: 1 + + torch.compiler_dynamo_overview + torch.compiler_dynamo_deepdive + torch.compiler_dynamic_shapes + torch.compiler_nn_module + torch.compiler_cudagraph_trees + torch.compiler_fake_tensor +``` + +```{eval-rst} +.. toctree:: + :caption: HowTo for PyTorch Backend Vendors + :maxdepth: 1 + + torch.compiler_custom_backends + torch.compiler_transformations + torch.compiler_ir +``` diff --git a/docs/source/torch.compiler.rst b/docs/source/torch.compiler.rst deleted file mode 100644 index 5a3a2977e4b82f..00000000000000 --- a/docs/source/torch.compiler.rst +++ /dev/null @@ -1,120 +0,0 @@ -.. _torch.compiler_overview: - -torch.compiler -============== - -``torch.compiler`` is a namespace through which some of the internal compiler -methods are surfaced for user consumption. The main function and the feature in -this namespace is ``torch.compile``. - -``torch.compile`` is a PyTorch function introduced in PyTorch 2.x that aims to -solve the problem of accurate graph capturing in PyTorch and ultimately enable -software engineers to run their PyTorch programs faster. ``torch.compile`` is -written in Python and it marks the transition of PyTorch from C++ to Python. - -``torch.compile`` leverages the following underlying technologies: - -* **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython - feature called the Frame Evaluation API to safely capture PyTorch graphs. - Methods that are available externally for PyTorch users are surfaced - through the ``torch.compiler`` namespace. - -* **TorchInductor** is the default ``torch.compile`` deep learning compiler - that generates fast code for multiple accelerators and backends. You - need to use a backend compiler to make speedups through ``torch.compile`` - possible. For NVIDIA, AMD and Intel GPUs, it leverages OpenAI Triton as the key - building block. - -* **AOT Autograd** captures not only the user-level code, but also backpropagation, - which results in capturing the backwards pass "ahead-of-time". This enables - acceleration of both forwards and backwards pass using TorchInductor. - -.. note:: In some cases, the terms ``torch.compile``, TorchDynamo, ``torch.compiler`` - might be used interchangeably in this documentation. - -As mentioned above, to run your workflows faster, ``torch.compile`` through -TorchDynamo requires a backend that converts the captured graphs into a fast -machine code. Different backends can result in various optimization gains. -The default backend is called TorchInductor, also known as *inductor*, -TorchDynamo has a list of supported backends developed by our partners, -which can be see by running ``torch.compiler.list_backends()`` each of which -with its optional dependencies. - -Some of the most commonly used backends include: - -**Training & inference backends** - -.. list-table:: - :widths: 50 50 - :header-rows: 1 - - * - Backend - - Description - * - ``torch.compile(m, backend="inductor")`` - - Uses the TorchInductor backend. `Read more `__ - * - ``torch.compile(m, backend="cudagraphs")`` - - CUDA graphs with AOT Autograd. `Read more `__ - * - ``torch.compile(m, backend="ipex")`` - - Uses IPEX on CPU. `Read more `__ - * - ``torch.compile(m, backend="onnxrt")`` - - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more ` - -**Inference-only backends** - -.. list-table:: - :widths: 50 50 - :header-rows: 1 - - * - Backend - - Description - * - ``torch.compile(m, backend="tensorrt")`` - - Uses Torch-TensorRT for inference optimizations. Requires ``import torch_tensorrt`` in the calling script to register backend. `Read more `__ - * - ``torch.compile(m, backend="ipex")`` - - Uses IPEX for inference on CPU. `Read more `__ - * - ``torch.compile(m, backend="tvm")`` - - Uses Apache TVM for inference optimizations. `Read more `__ - * - ``torch.compile(m, backend="openvino")`` - - Uses OpenVINO for inference optimizations. `Read more `__ - -Read More -~~~~~~~~~ - -.. toctree:: - :caption: Getting Started for PyTorch Users - :maxdepth: 1 - - torch.compiler_get_started - torch.compiler_api - torch.compiler.config - torch.compiler_fine_grain_apis - torch.compiler_aot_inductor - torch.compiler_inductor_profiling - torch.compiler_profiling_torch_compile - torch.compiler_faq - torch.compiler_troubleshooting - torch.compiler_performance_dashboard - torch.compiler_inductor_provenance -.. - _If you want to contribute a developer-level topic - that provides in-depth overview of a torch._dynamo feature, - add in the below toc. - -.. toctree:: - :caption: Deep Dive for PyTorch Developers - :maxdepth: 1 - - torch.compiler_dynamo_overview - torch.compiler_dynamo_deepdive - torch.compiler_dynamic_shapes - torch.compiler_nn_module - torch.compiler_best_practices_for_backends - torch.compiler_cudagraph_trees - torch.compiler_fake_tensor - -.. toctree:: - :caption: HowTo for PyTorch Backend Vendors - :maxdepth: 1 - - torch.compiler_custom_backends - torch.compiler_transformations - torch.compiler_ir diff --git a/docs/source/torch.compiler_aot_inductor.md b/docs/source/torch.compiler_aot_inductor.md new file mode 100644 index 00000000000000..d2a7c93392647a --- /dev/null +++ b/docs/source/torch.compiler_aot_inductor.md @@ -0,0 +1,212 @@ +# AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models + +```{warning} +AOTInductor and its related features are in prototype status and are +subject to backwards compatibility breaking changes. +``` + +AOTInductor is a specialized version of +[TorchInductor](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747), +designed to process exported PyTorch models, optimize them, and produce shared libraries as well +as other relevant artifacts. +These compiled artifacts are specifically crafted for deployment in non-Python environments, +which are frequently employed for inference deployments on the server side. + +In this tutorial, you will gain insight into the process of taking a PyTorch model, exporting it, +compiling it into an artifact, and conducting model predictions using C++. + +## Model Compilation + +To compile a model using AOTInductor, we first need to use +{func}`torch.export.export` to capture a given PyTorch model into a +computational graph. {ref}`torch.export ` provides soundness +guarantees and a strict specification on the IR captured, which AOTInductor +relies on. + +We will then use {func}`torch._inductor.aoti_compile_and_package` to compile the +exported program using TorchInductor, and save the compiled artifacts into one +package. + +```{note} +If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, +the following code will compile the model into a shared library for CUDA execution. +Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference, +it is suggested to enable freezing by setting `export TORCHINDUCTOR_FREEZING=1` +before running the Python script below. The same behavior works in an environment with Intel® +GPU as well. +``` + +```python +import os +import torch + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(16, 1) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return x + +with torch.no_grad(): + device = "cuda" if torch.cuda.is_available() else "cpu" + model = Model().to(device=device) + example_inputs=(torch.randn(8, 10, device=device),) + batch_dim = torch.export.Dim("batch", min=1, max=1024) + # [Optional] Specify the first dimension of the input x as dynamic. + exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}) + # [Note] In this example we directly feed the exported module to aoti_compile_and_package. + # Depending on your use case, e.g. if your training platform and inference platform + # are different, you may choose to save the exported model using torch.export.save and + # then load it back using torch.export.load on your inference platform to run AOT compilation. + output_path = torch._inductor.aoti_compile_and_package( + exported, + # [Optional] Specify the generated shared library path. If not specified, + # the generated artifact is stored in your system temp directory. + package_path=os.path.join(os.getcwd(), "model.pt2"), + ) +``` + +In this illustrative example, the `Dim` parameter is employed to designate the first dimension of +the input variable "x" as dynamic. Notably, the path and name of the compiled library remain unspecified, +resulting in the shared library being stored in a temporary directory. +To access this path from the C++ side, we save it to a file for later retrieval within the C++ code. + +## Inference in Python + +There are multiple ways to deploy the compiled artifact for inference, and one of that is using Python. +We have provided a convenient utility API in Python {func}`torch._inductor.aoti_load_package` for loading +and running the artifact, as shown in the following example: + +```python +import os +import torch + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "model.pt2")) +print(model(torch.randn(8, 10, device=device))) +``` + +The input at inference time should have the same size, dtype, and stride as the input at export time. + +## Inference in C++ + +Next, we use the following example C++ file `inference.cpp` to load the compiled artifact, +enabling us to conduct model predictions directly within a C++ environment. + +```cpp +#include +#include + +#include +#include + +int main() { + c10::InferenceMode mode; + + torch::inductor::AOTIModelPackageLoader loader("model.pt2"); + // Assume running on CUDA + std::vector inputs = {torch::randn({8, 10}, at::kCUDA)}; + std::vector outputs = loader.run(inputs); + std::cout << "Result from the first inference:"<< std::endl; + std::cout << outputs[0] << std::endl; + + // The second inference uses a different batch size and it works because we + // specified that dimension as dynamic when compiling model.pt2. + std::cout << "Result from the second inference:"<< std::endl; + // Assume running on CUDA + std::cout << loader.run({torch::randn({1, 10}, at::kCUDA)})[0] << std::endl; + + return 0; +} +``` + +For building the C++ file, you can make use of the provided `CMakeLists.txt` file, which +automates the process of invoking `python model.py` for AOT compilation of the model and compiling +`inference.cpp` into an executable binary named `aoti_example`. + +```cmake +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +project(aoti_example) + +find_package(Torch REQUIRED) + +add_executable(aoti_example inference.cpp model.pt2) + +add_custom_command( + OUTPUT model.pt2 + COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py + DEPENDS model.py +) + +target_link_libraries(aoti_example "${TORCH_LIBRARIES}") +set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17) +``` + +Provided the directory structure resembles the following, you can execute the subsequent commands +to construct the binary. It is essential to note that the `CMAKE_PREFIX_PATH` variable +is crucial for CMake to locate the LibTorch library, and it should be set to an absolute path. +Please be mindful that your path may vary from the one illustrated in this example. + +``` +aoti_example/ + CMakeLists.txt + inference.cpp + model.py +``` + +```bash +$ mkdir build +$ cd build +$ CMAKE_PREFIX_PATH=/path/to/python/install/site-packages/torch/share/cmake cmake .. +$ cmake --build . --config Release +``` + +After the `aoti_example` binary has been generated in the `build` directory, executing it will +display results akin to the following: + +```bash +$ ./aoti_example +Result from the first inference: +0.4866 +0.5184 +0.4462 +0.4611 +0.4744 +0.4811 +0.4938 +0.4193 +[ CUDAFloatType{8,1} ] +Result from the second inference: +0.4883 +0.4703 +[ CUDAFloatType{2,1} ] +``` + +## Troubleshooting + +Below are some useful tools for debugging AOT Inductor. + +```{toctree} +:caption: Debugging Tools +:maxdepth: 1 + +logging +torch.compiler_aot_inductor_minifier +``` + +To enable runtime checks on inputs, set the environment variable `AOTI_RUNTIME_CHECK_INPUTS` to 1. This will raise a `RuntimeError` if the inputs to the compiled model differ in size, data type, or strides from those used during export. + +## API Reference + +```{eval-rst} +.. autofunction:: torch._inductor.aoti_compile_and_package +.. autofunction:: torch._inductor.aoti_load_package +``` diff --git a/docs/source/torch.compiler_aot_inductor.rst b/docs/source/torch.compiler_aot_inductor.rst deleted file mode 100644 index f30d38e78798ef..00000000000000 --- a/docs/source/torch.compiler_aot_inductor.rst +++ /dev/null @@ -1,221 +0,0 @@ - - -AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models -================================================================= - -.. warning:: - - AOTInductor and its related features are in prototype status and are - subject to backwards compatibility breaking changes. - -AOTInductor is a specialized version of -`TorchInductor `__ -, designed to process exported PyTorch models, optimize them, and produce shared libraries as well -as other relevant artifacts. -These compiled artifacts are specifically crafted for deployment in non-Python environments, -which are frequently employed for inference deployments on the server side. - -In this tutorial, you will gain insight into the process of taking a PyTorch model, exporting it, -compiling it into an artifact, and conducting model predictions using C++. - - -Model Compilation ---------------------------- - -To compile a model using AOTInductor, we first need to use -:func:`torch.export.export` to capture a given PyTorch model into a -computational graph. :ref:`torch.export ` provides soundness -guarantees and a strict specification on the IR captured, which AOTInductor -relies on. - -We will then use :func:`torch._inductor.aoti_compile_and_package` to compile the -exported program using TorchInductor, and save the compiled artifacts into one -package. - -.. note:: - - If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, - the following code will compile the model into a shared library for CUDA execution. - Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference, - it is suggested to enable freezing by setting ``export TORCHINDUCTOR_FREEZING=1`` - before running the Python script below. The same behavior works in an environment with Intel® - GPU as well. - -.. code-block:: python - - import os - import torch - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.fc1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.fc2 = torch.nn.Linear(16, 1) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.sigmoid(x) - return x - - with torch.no_grad(): - device = "cuda" if torch.cuda.is_available() else "cpu" - model = Model().to(device=device) - example_inputs=(torch.randn(8, 10, device=device),) - batch_dim = torch.export.Dim("batch", min=1, max=1024) - # [Optional] Specify the first dimension of the input x as dynamic. - exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}}) - # [Note] In this example we directly feed the exported module to aoti_compile_and_package. - # Depending on your use case, e.g. if your training platform and inference platform - # are different, you may choose to save the exported model using torch.export.save and - # then load it back using torch.export.load on your inference platform to run AOT compilation. - output_path = torch._inductor.aoti_compile_and_package( - exported, - # [Optional] Specify the generated shared library path. If not specified, - # the generated artifact is stored in your system temp directory. - package_path=os.path.join(os.getcwd(), "model.pt2"), - ) - - -In this illustrative example, the ``Dim`` parameter is employed to designate the first dimension of -the input variable "x" as dynamic. Notably, the path and name of the compiled library remain unspecified, -resulting in the shared library being stored in a temporary directory. -To access this path from the C++ side, we save it to a file for later retrieval within the C++ code. - - -Inference in Python ---------------------------- -There are multiple ways to deploy the compiled artifact for inference, and one of that is using Python. -We have provided a convinient utility API in Python :func:`torch._inductor.aoti_load_package` for loading -and running the artifact, as shown in the following example: - -.. code-block:: python - - import os - import torch - - device = "cuda" if torch.cuda.is_available() else "cpu" - model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "model.pt2")) - print(model(torch.randn(8, 10, device=device))) - -The input at inference time should have the same size, dtype, and stride as the input at export time. - -Inference in C++ ---------------------------- - -Next, we use the following example C++ file ``inference.cpp`` to load the compiled artifact, -enabling us to conduct model predictions directly within a C++ environment. - -.. code-block:: cpp - - #include - #include - - #include - #include - - int main() { - c10::InferenceMode mode; - - torch::inductor::AOTIModelPackageLoader loader("model.pt2"); - // Assume running on CUDA - std::vector inputs = {torch::randn({8, 10}, at::kCUDA)}; - std::vector outputs = loader.run(inputs); - std::cout << "Result from the first inference:"<< std::endl; - std::cout << outputs[0] << std::endl; - - // The second inference uses a different batch size and it works because we - // specified that dimension as dynamic when compiling model.pt2. - std::cout << "Result from the second inference:"<< std::endl; - // Assume running on CUDA - std::cout << loader.run({torch::randn({1, 10}, at::kCUDA)})[0] << std::endl; - - return 0; - } - -For building the C++ file, you can make use of the provided ``CMakeLists.txt`` file, which -automates the process of invoking ``python model.py`` for AOT compilation of the model and compiling -``inference.cpp`` into an executable binary named ``aoti_example``. - -.. code-block:: cmake - - cmake_minimum_required(VERSION 3.18 FATAL_ERROR) - project(aoti_example) - - find_package(Torch REQUIRED) - - add_executable(aoti_example inference.cpp model.pt2) - - add_custom_command( - OUTPUT model.pt2 - COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py - DEPENDS model.py - ) - - target_link_libraries(aoti_example "${TORCH_LIBRARIES}") - set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17) - - -Provided the directory structure resembles the following, you can execute the subsequent commands -to construct the binary. It is essential to note that the ``CMAKE_PREFIX_PATH`` variable -is crucial for CMake to locate the LibTorch library, and it should be set to an absolute path. -Please be mindful that your path may vary from the one illustrated in this example. - -.. code-block:: shell - - aoti_example/ - CMakeLists.txt - inference.cpp - model.py - - -.. code-block:: shell - - $ mkdir build - $ cd build - $ CMAKE_PREFIX_PATH=/path/to/python/install/site-packages/torch/share/cmake cmake .. - $ cmake --build . --config Release - -After the ``aoti_example`` binary has been generated in the ``build`` directory, executing it will -display results akin to the following: - -.. code-block:: shell - - $ ./aoti_example - Result from the first inference: - 0.4866 - 0.5184 - 0.4462 - 0.4611 - 0.4744 - 0.4811 - 0.4938 - 0.4193 - [ CUDAFloatType{8,1} ] - Result from the second inference: - 0.4883 - 0.4703 - [ CUDAFloatType{2,1} ] - - -Troubleshooting ---------------------------- -Below are some useful tools for debugging AOT Inductor. - -.. toctree:: - :caption: Debugging Tools - :maxdepth: 1 - - logging - torch.compiler_aot_inductor_minifier - -To enable runtime checks on inputs, set the environment variable `AOTI_RUNTIME_CHECK_INPUTS` to 1. This will raise a `RuntimeError` if the inputs to the compiled model differ in size, data type, or strides from those used during export. - -API Reference -------------- - -.. autofunction:: torch._inductor.aoti_compile_and_package -.. autofunction:: torch._inductor.aoti_load_package diff --git a/docs/source/torch.compiler_aot_inductor_minifier.md b/docs/source/torch.compiler_aot_inductor_minifier.md new file mode 100644 index 00000000000000..75a06159ff08a5 --- /dev/null +++ b/docs/source/torch.compiler_aot_inductor_minifier.md @@ -0,0 +1,215 @@ +# AOTInductor Minifier + +If you encounter an error while using AOT Inductor APIs such as +`torch._inductor.aoti_compile_and_package`, `torch._indcutor.aoti_load_package`, +or running the loaded model of `aoti_load_package` on some inputs, you can use the AOTInductor Minifier +to create a minimal nn.Module that reproduce the error by setting `from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True`. + +One a high-level, there are two steps in using the minifier: + +- Set `from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True` or set the environment variable `DUMP_AOTI_MINIFIER=1`. Then running the script that errors would produce a `minifier_launcher.py` script. The output directory is configurable by setting `torch._dynamo.config.debug_dir_root` to a valid directory name. + +- Run the `minifier_launcher.py` script. If the minifier runs successfully, it generates runnable python code in `repro.py` which reproduces the exact error. + +## Example Code + +Here is sample code which will generate an error because we injected an error on relu with +`torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"`. + + +``` +import torch +from torch._inductor import config as inductor_config + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + return x + + +inductor_config.aot_inductor.dump_aoti_minifier = True +torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error" + +with torch.no_grad(): + model = Model().to("cuda") + example_inputs = (torch.randn(8, 10).to("cuda"),) + ep = torch.export.export(model, example_inputs) + package_path = torch._inductor.aoti_compile_and_package(ep) + compiled_model = torch._inductor.aoti_load_package(package_path) + result = compiled_model(*example_inputs) +``` + +The code above generates the following error: + +```text +RuntimeError: Failed to import /tmp/torchinductor_shangdiy/fr/cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py +SyntaxError: invalid syntax (cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py, line 29) +``` + + +This is because we injected an error on relu, and so the generated triton kernel looks like below. Note that we have `compile error!` +instead if `relu`, so we get a `SyntaxError`. + +``` +@triton.jit +def triton_poi_fused_addmm_relu_sigmoid_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 128 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = xindex % 16 + tmp0 = tl.load(in_out_ptr0 + (x2), xmask) + tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last') + tmp2 = tmp0 + tmp1 + tmp3 = compile error! + tmp4 = tl.sigmoid(tmp3) + tl.store(in_out_ptr0 + (x2), tmp4, xmask) +``` + + +Since we have `torch._inductor.config.aot_inductor.dump_aoti_minifier=True`, we also see an additional line indicating where `minifier_launcher.py` has +been written to. The output directory is configurable by setting +`torch._dynamo.config.debug_dir_root` to a valid directory name. + +```text +W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] Writing minified repro to: +W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_21_08_602433-pid_2861654/minifier/minifier_launcher.py +``` + + +## Minifier Launcher + +The `minifier_launcher.py` file has the following code. The `exported_program` contains the inputs to `torch._inductor.aoti_compile_and_package`. +The `command='minify'` parameter means the script will run the minifier to create a minimal graph module that reproduce the error. Alternatively, you set +use `command='run'` to just compile, load, and run the loaded model (without running the minifier). + + +``` +import torch +import torch._inductor.inductor_prims + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config + +torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error' +torch._inductor.config.aot_inductor.dump_aoti_minifier = True + + + + +isolate_fails_code_str = None + + + +# torch version: 2.6.0a0+gitcd9c6e9 +# torch cuda version: 12.0 +# torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84 + + +# CUDA Info: +# nvcc: NVIDIA (R) Cuda compiler driver +# Copyright (c) 2005-2023 NVIDIA Corporation +# Built on Fri_Jan__6_16:45:21_PST_2023 +# Cuda compilation tools, release 12.0, V12.0.140 +# Build cuda_12.0.r12.0/compiler.32267302_0 + +# GPU Hardware Info: +# NVIDIA PG509-210 : 8 + +exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints/exported_program.pt2') +# print(exported_program.graph) +config_patches={} +if __name__ == '__main__': + from torch._dynamo.repro.aoti import run_repro + with torch.no_grad(): + run_repro(exported_program, config_patches=config_patches, accuracy=False, command='minify', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints', check_str=None) +``` + + +Suppose we kept the `command='minify'` option, and run the script, we would get the following output: + +```text +... +W1031 16:48:08.938000 3598491 torch/_dynamo/repro/aoti.py:89] Writing checkpoint with 3 nodes to /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/3.py +W1031 16:48:08.975000 3598491 torch/_dynamo/repro/aoti.py:101] Copying repro file for convenience to /data/users/shangdiy/pytorch/repro.py +Wrote minimal repro out to repro.py +``` + + +If you get an `AOTIMinifierError` when running `minifier_launcher.py`, please report a bug [here](https://github.com/pytorch/pytorch/issues/new?assignees=&labels=&projects=&template=bug-report.yml). + +## Minified Result + +The `repro.py` looks like this. Notice that the exported program is printed at the top of the file, and it contains only the relu node. The minifier successfully reduced the graph to the op that raises the error. + + +``` +# from torch.nn import * +# class Repro(torch.nn.Module): +# def __init__(self) -> None: +# super().__init__() + + + +# def forward(self, linear): +# relu = torch.ops.aten.relu.default(linear); linear = None +# return (relu,) + +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims + +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config + +torch._inductor.config.generate_intermediate_hooks = True +torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error' +torch._inductor.config.aot_inductor.dump_aoti_minifier = True + + + + +isolate_fails_code_str = None + + + +# torch version: 2.6.0a0+gitcd9c6e9 +# torch cuda version: 12.0 +# torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84 + + +# CUDA Info: +# nvcc: NVIDIA (R) Cuda compiler driver +# Copyright (c) 2005-2023 NVIDIA Corporation +# Built on Fri_Jan__6_16:45:21_PST_2023 +# Cuda compilation tools, release 12.0, V12.0.140 +# Build cuda_12.0.r12.0/compiler.32267302_0 + +# GPU Hardware Info: +# NVIDIA PG509-210 : 8 + + +exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_25_13_59_33_102283-pid_3658904/minifier/checkpoints/exported_program.pt2') +# print(exported_program.graph) +config_patches={'aot_inductor.package': True} +if __name__ == '__main__': + from torch._dynamo.repro.aoti import run_repro + with torch.no_grad(): + run_repro(exported_program, config_patches=config_patches, accuracy=False, command='run', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_25_13_59_33_102283-pid_3658904/minifier/checkpoints', check_str=None) +``` \ No newline at end of file diff --git a/docs/source/torch.compiler_aot_inductor_minifier.rst b/docs/source/torch.compiler_aot_inductor_minifier.rst deleted file mode 100644 index c85291089e647f..00000000000000 --- a/docs/source/torch.compiler_aot_inductor_minifier.rst +++ /dev/null @@ -1,221 +0,0 @@ -AOTInductor Minifier -=========================== - -If you encounter an error while using AOT Inductor APIs such as -``torch._inductor.aoti_compile_and_package``, ``torch._indcutor.aoti_load_package``, -or running the loaded model of ``aoti_load_package`` on some inputs, you can use the AOTInductor Minifier -to create a minimal nn.Module that reproduce the error by setting ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True``. - - -One a high-level, there are two steps in using the minifier: - -- Set ``from torch._inductor import config; config.aot_inductor.dump_aoti_minifier = True`` or set the environment variable ``DUMP_AOTI_MINIFIER=1``. Then running the script that errors would produce a ``minifier_launcher.py`` script. The output directory is configurable by setting ``torch._dynamo.config.debug_dir_root`` to a valid directory name. - -- Run the ``minifier_launcher.py`` script. If the minifier runs successfully, it generates runnable python code in ``repro.py`` which reproduces the exact error. - - -Example Code ---------------------------- - -Here is sample code which will generate an error because we injected an error on relu with -``torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"``. - - -.. code-block:: py - - import torch - from torch._inductor import config as inductor_config - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.fc1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - x = self.fc1(x) - x = self.relu(x) - x = self.sigmoid(x) - return x - - - inductor_config.aot_inductor.dump_aoti_minifier = True - torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error" - - with torch.no_grad(): - model = Model().to("cuda") - example_inputs = (torch.randn(8, 10).to("cuda"),) - ep = torch.export.export(model, example_inputs) - package_path = torch._inductor.aoti_compile_and_package(ep) - compiled_model = torch._inductor.aoti_load_package(package_path) - result = compiled_model(*example_inputs) - - -The code above generates the following error: - -:: - - RuntimeError: Failed to import /tmp/torchinductor_shangdiy/fr/cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py - SyntaxError: invalid syntax (cfrlf4smkwe4lub4i4cahkrb3qiczhf7hliqqwpewbw3aplj5g3s.py, line 29) - -This is because we injected an error on relu, and so the generated triton kernel looks like below. Note that we have ``compile error!`` -instead if ``relu``, so we get a ``SyntaxError``. - -.. code-block:: - - @triton.jit - def triton_poi_fused_addmm_relu_sigmoid_0(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): - xnumel = 128 - xoffset = tl.program_id(0) * XBLOCK - xindex = xoffset + tl.arange(0, XBLOCK)[:] - xmask = xindex < xnumel - x2 = xindex - x0 = xindex % 16 - tmp0 = tl.load(in_out_ptr0 + (x2), xmask) - tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last') - tmp2 = tmp0 + tmp1 - tmp3 = compile error! - tmp4 = tl.sigmoid(tmp3) - tl.store(in_out_ptr0 + (x2), tmp4, xmask) - - -Since we have ``torch._inductor.config.aot_inductor.dump_aoti_minifier=True``, we also see an additional line indicating where ``minifier_launcher.py`` has -been written to. The output directory is configurable by setting -``torch._dynamo.config.debug_dir_root`` to a valid directory name. - -:: - - W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] Writing minified repro to: - W1031 16:21:08.612000 2861654 pytorch/torch/_dynamo/debug_utils.py:279] /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_21_08_602433-pid_2861654/minifier/minifier_launcher.py - - -Minifier Launcher ---------------------------- - - -The ``minifier_launcher.py`` file has the following code. The ``exported_program`` contains the inputs to ``torch._inductor.aoti_compile_and_package``. -The ``command='minify'`` parameter means the script will run the minifier to create a minimal graph module that reproduce the error. Alternatively, you set -use ``command='run'`` to just compile, load, and run the loaded model (without running the minifier). - -.. code-block:: py - - import torch - import torch._inductor.inductor_prims - - import torch._dynamo.config - import torch._inductor.config - import torch._functorch.config - import torch.fx.experimental._config - - torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error' - torch._inductor.config.aot_inductor.dump_aoti_minifier = True - - - - - isolate_fails_code_str = None - - - - # torch version: 2.6.0a0+gitcd9c6e9 - # torch cuda version: 12.0 - # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84 - - - # CUDA Info: - # nvcc: NVIDIA (R) Cuda compiler driver - # Copyright (c) 2005-2023 NVIDIA Corporation - # Built on Fri_Jan__6_16:45:21_PST_2023 - # Cuda compilation tools, release 12.0, V12.0.140 - # Build cuda_12.0.r12.0/compiler.32267302_0 - - # GPU Hardware Info: - # NVIDIA PG509-210 : 8 - - exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints/exported_program.pt2') - # print(exported_program.graph) - config_patches={} - if __name__ == '__main__': - from torch._dynamo.repro.aoti import run_repro - with torch.no_grad(): - run_repro(exported_program, config_patches=config_patches, accuracy=False, command='minify', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_06_13_52_35_711642-pid_3567062/minifier/checkpoints', check_str=None) - - -Suppose we kept the ``command='minify'`` option, and run the script, we would get the following output: - -:: - - ... - W1031 16:48:08.938000 3598491 torch/_dynamo/repro/aoti.py:89] Writing checkpoint with 3 nodes to /data/users/shangdiy/pytorch/torch_compile_debug/run_2024_10_31_16_48_02_720863-pid_3598491/minifier/checkpoints/3.py - W1031 16:48:08.975000 3598491 torch/_dynamo/repro/aoti.py:101] Copying repro file for convenience to /data/users/shangdiy/pytorch/repro.py - Wrote minimal repro out to repro.py - -If you get an ``AOTIMinifierError`` when running ``minifier_launcher.py``, please report a bug `here `__. - - -Minified Result ---------------------------- - -The ``repro.py`` looks like this. Notice that the exported program is printed at the top of the file, and it contains only the relu node. The minifier successfully reduced the graph to the op that raises the -error. - -.. code-block:: py - - # from torch.nn import * - # class Repro(torch.nn.Module): - # def __init__(self) -> None: - # super().__init__() - - - - # def forward(self, linear): - # relu = torch.ops.aten.relu.default(linear); linear = None - # return (relu,) - - import torch - from torch import tensor, device - import torch.fx as fx - from torch._dynamo.testing import rand_strided - from math import inf - import torch._inductor.inductor_prims - - import torch._dynamo.config - import torch._inductor.config - import torch._functorch.config - import torch.fx.experimental._config - - torch._inductor.config.generate_intermediate_hooks = True - torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = 'compile_error' - torch._inductor.config.aot_inductor.dump_aoti_minifier = True - - - - - isolate_fails_code_str = None - - - - # torch version: 2.6.0a0+gitcd9c6e9 - # torch cuda version: 12.0 - # torch git version: cd9c6e9408dd79175712223895eed36dbdc84f84 - - - # CUDA Info: - # nvcc: NVIDIA (R) Cuda compiler driver - # Copyright (c) 2005-2023 NVIDIA Corporation - # Built on Fri_Jan__6_16:45:21_PST_2023 - # Cuda compilation tools, release 12.0, V12.0.140 - # Build cuda_12.0.r12.0/compiler.32267302_0 - - # GPU Hardware Info: - # NVIDIA PG509-210 : 8 - - - exported_program = torch.export.load('/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_25_13_59_33_102283-pid_3658904/minifier/checkpoints/exported_program.pt2') - # print(exported_program.graph) - config_patches={'aot_inductor.package': True} - if __name__ == '__main__': - from torch._dynamo.repro.aoti import run_repro - with torch.no_grad(): - run_repro(exported_program, config_patches=config_patches, accuracy=False, command='run', save_dir='/data/users/shangdiy/pytorch/torch_compile_debug/run_2024_11_25_13_59_33_102283-pid_3658904/minifier/checkpoints', check_str=None) diff --git a/docs/source/torch.compiler_api.md b/docs/source/torch.compiler_api.md new file mode 100644 index 00000000000000..2b79b0e670073a --- /dev/null +++ b/docs/source/torch.compiler_api.md @@ -0,0 +1,34 @@ +```{eval-rst} +.. currentmodule:: torch.compiler +.. automodule:: torch.compiler +``` + +(torch.compiler_api)= +# torch.compiler API reference + +For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`. + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + compile + reset + allow_in_graph + substitute_in_graph + assume_constant_result + list_backends + disable + set_stance + set_enable_guard_collectives + cudagraph_mark_step_begin + is_compiling + is_dynamo_compiling + is_exporting + skip_guard_on_inbuilt_nn_modules_unsafe + skip_guard_on_all_nn_modules_unsafe + keep_tensor_guards_unsafe + skip_guard_on_globals_unsafe + nested_compile_region +``` diff --git a/docs/source/torch.compiler_api.rst b/docs/source/torch.compiler_api.rst deleted file mode 100644 index 88a373067f1cb0..00000000000000 --- a/docs/source/torch.compiler_api.rst +++ /dev/null @@ -1,27 +0,0 @@ -.. currentmodule:: torch.compiler - -.. automodule:: torch.compiler - -.. _torch.compiler_api: - -torch.compiler API reference -============================ - -For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`. - -.. autosummary:: - :toctree: generated - :nosignatures: - - compile - reset - allow_in_graph - substitute_in_graph - assume_constant_result - list_backends - disable - set_stance - cudagraph_mark_step_begin - is_compiling - is_dynamo_compiling - is_exporting diff --git a/docs/source/torch.compiler_backward.md b/docs/source/torch.compiler_backward.md new file mode 100644 index 00000000000000..27cd66dc419c89 --- /dev/null +++ b/docs/source/torch.compiler_backward.md @@ -0,0 +1,75 @@ +``torch.compile`` has different autograd semantics +================================================== + +When you apply ``torch.compile`` to a function in your model's forward pass, +it will automatically generate a backward pass for the compiled function. +During compilation, it will trace out a graph for the backward pass that +is used whenever autograd is invoked. We refer to the component inside +``torch.compile`` that is responsible for this as ``AOTDispatcher`` +(sometimes known as ``AOTAutograd``). + +As so, ``torch.compile`` bakes in details of the computation into the +traced-out backward graph during compilation of the function +in the forward pass. +However, in eager-mode PyTorch, the backward computation is dynamic: +outside of the forward pass, you can wrap the call to +``tensor.backward()`` or ``torch.autograd.grad(...)`` +in a context manager that may change its behavior. + +This page documents how ``torch.compile``'s autograd semantics differ from +eager-mode PyTorch and how to work around it. + +``Autocast`` behavior +--------------------- + +``torch.compile`` bakes in an assumption on if the backward pass will be +run under an ambient autocast context manager. By default, +Use ``torch._functorch.config.backward_pass_autocast`` +to control that assumption; an incorrect assumption may lead to silent +incorrectness. + +The options are either: +- `"same_as_forward"` (default). + We assume that the backward of the ``torch.compile``'ed region + will be run under the same autocast context manager that the region was run + under (if any). Use this if your code looks like the following: + ```py + with torch.amp.autocast(...): + y = torch.compile(region)(x) + ... + # backward pass run under the same autocast context as the compiled region + z.backward() + ``` +- `"off"`. We assume that the backward of the torch.compile'd region will + not be run under any autocast context managers. + Use this if your code looks like the following: + ```py + with torch.amp.autocast(...): + y = torch.compile(region)(x) + ... + # Backward pass runs under no autocast. + z.backward() + ``` +- There is a third option. If you set ``torch._functorch.config.backward_pass_autocast`` + to a list of kwargs, we will assume the backward pass runs under an autocast context + constructed by those kwargs. + + For example, if your code looks like the following: + ```py + y = torch.compile(region)(x) + ... + # Backward pass runs under special context manager + with torch.amp.autocast(**kwargs): + z.backward() + ``` + then set ``torch._functorch.config.backward_pass_autocast = kwargs``. + +Use ``patch`` to apply the option to a specific ``torch.compile`` call: +```py +with torch.amp.autocast(...): + with torch._functorch.config.patch(backward_pass_autocast="same_as_forward") + y = torch.compile(region)(x) + ... + # backward pass run under the same autocast context as the compiled region + z.backward() +``` diff --git a/docs/source/torch.compiler_best_practices_for_backends.rst b/docs/source/torch.compiler_best_practices_for_backends.rst deleted file mode 100644 index 32052403511df6..00000000000000 --- a/docs/source/torch.compiler_best_practices_for_backends.rst +++ /dev/null @@ -1,17 +0,0 @@ -Best Practices for Backends -=========================== - -x86 CPU -------- - -Compiled workloads on modern x86 CPUs are usually optimized by Single Instruction Multiple Data (SIMD) instruction sets. SIMD is a typical parallel processing technique for high performance computing, such as deep learning model training and inference. With SIMD applied, each compute unit performs the same instruction with different allocated data at any given time slot. The most commonly deployed x86 instruction set architectures (ISAs) enabling SIMD include `AVX, AVX2, AVX-512 `_ and `AMX `_. - -You can check supported ISAs for your machine by using the `collect_env script `_. As the script provides complete environment information for PyTorch, we can use ``grep`` to extract the line containing ISA information: - -:: - - python collect_env.py | grep "a[(v|m)]x" - -Normally, if AVX-512 is supported, instructions start with "avx512" (like ``avx512f``, ``avx512bw``, ``avx512_vnni``) should be observed. If AMX is supported, instructions start with "amx" (like ``amx_tile``, ``amx_bf16``, ``amx_int8``) should be observed. - -Specifically, with a server having AMX instructions enabled, workloads performance can be further boosted by `leveraging AMX `_. diff --git a/docs/source/torch.compiler_cudagraph_trees.md b/docs/source/torch.compiler_cudagraph_trees.md new file mode 100644 index 00000000000000..6fd52edf97a4d3 --- /dev/null +++ b/docs/source/torch.compiler_cudagraph_trees.md @@ -0,0 +1,287 @@ +# CUDAGraph Trees + +## **Background** + +### CUDAGraph + +For a longer background on CUDAGraphs, read [accelerating pytorch with CUDAGraphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/). + +[CUDA Graphs](https://developer.nvidia.com/blog/cuda-10-features-revealed/), which made its debut in CUDA 10, let a series of CUDA kernels to be defined and encapsulated as a single unit, i.e., a graph of operations, rather than a sequence of individually-launched operations. It provides a mechanism to launch multiple GPU operations through a single CPU operation, and hence reduces the launching overheads. + +CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations from requiring the same kernels to be run with the same arguments and dependencies, and memory addresses. + +- Control Flow is not possible +- Kernels which trigger host to device syncs (such as .item()) errors +- All input arguments to kernels are fixed to what they were recorded +- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change +- No Essential CPU ops or CPU side effects + +### PyTorch CUDAGraph Integration + +PyTorch provides a [convenience wrapper](https://pytorch.org/docs/stable/generated/torch.cuda.CUDAGraph.html) around CUDAGraphs that handles a couple of tricky interactions with PyTorch’s caching allocator. + +The CachingAllocator uses a separate memory pool for all the new allocations. During CUDAGraph recording, memory is accounted for, allocated, and freed exactly as during eager run. On replay, just the kernels are invoked, and there are no changes to the allocator. Subsequent to initial recording, the allocator does not know which memory is actively being used in user programs. + +Using a separate memory pool between eager allocations and cudagraph allocations may increase the memory of your program if there is substantial memory allocated to both. + +### Make Graphed Callables + +[Make Graphed Callables](https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html) is a PyTorch Abstraction to share a single memory pool over a series of callables. Graphed Callables takes advantage of the fact that on CUDA Graph recording, memory is exactly accounted for by the caching allocator to safely share memory between separate CUDA Graph recordings. In each invocation, outputs are preserved as live memory, preventing one callable from overwriting the live memory of another. Graphed Callables can only be invoked in a single order; memory addresses from the first run are burned into the second, and so forth. + +### TorchDynamo Previous CUDA Graphs Integration + +Running with `cudagraph_trees=False` does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward. + +## **CUDAGraph Trees Integration** + +Like Graph Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let’s take a look at an illustrative example: + +```python +@torch.compile(mode="reduce-overhead") +def foo(x): + # GRAPH 1 + y = x * x * x + # graph break triggered here + if y.sum() > 0: + # GRAPH 2 + z = y ** y + else: + # GRAPH 3 + z = (y.abs() ** y.abs()) + torch._dynamo.graph_break() + # GRAPH 4 + return z * torch.rand_like(z) + +# the first run warms up each graph, which does things like CuBlas or Triton benchmarking +foo(torch.arange(0, 10, device="cuda")) +# The second run does a CUDA Graph recording, and replays it +foo(torch.arange(0, 10, device="cuda")) +# Finally we hit the optimized, CUDA Graph replay path +foo(torch.arange(0, 10, device="cuda")) +``` + +In this example, there are two separate paths that we make through the function: 1 -> 2 -> 4, or 1 -> 3 -> 4. + +We share all of the memory in a single memory pool between separate recordings by building up a tape of CUDA Graph recordings, in this instance, 1 -> 2 -> 4. We add invariants to ensure that memory is always in the same location as it were recorded, and no live tensors exist in user programs that might be overwritten. + +- Same constraints from CUDA Graphs apply: same kernels must be invoked with the same arguments (static sizes, addresses, etc) +- The same pattern of memory must be observed between recording and replay: if a tensor output of one graph dies subsequent to another graph during recording, it must also do so during replay. +- Live memory in the CUDA pool forces a dependence between two recordings +- These recordings can only be invoked in a single order 1 - > 2 -> 4 + +All of the memory is shared in a single memory pool, so there is no additional memory overhead compared to eager. Now, what happens if we were to hit a new path and run Graph 3? + +Graph 1 gets replayed, and then we hit Graph 3, which we have not yet recorded. On graph replays, the private memory pool is not updated, so y is not reflected in the allocator. Without care, we would overwrite it. To support reusing the same memory pool after replaying other graphs, we checkpoint the memory pool back to its state at the end of graph 1. Now that our live tensors are reflected in the caching allocator, we are safe to run a new graph. + +First, we would hit the optimized, CUDAGraph.replay() path that we have already recorded in graph 1. Then we would hit Graph 3. Just as before, we will need to warm up the graph once before recording. On the warmup run, the memory addresses are not fixed, so graph 4 will also fallback to the inductor, non-cudagraph invocation. + +The second time we hit graph 3 we are warmed up and ready to record. We record graph 3 and then record graph 4 again since the input memory addresses have changed. This creates a tree of CUDA Graph recordings. A CUDA Graph Tree! + +``` + 1 + / \\ +2 3 + \\ \\ + 4 4 +``` + +### Input Mutation Support + +Input mutation function refers to a function conducting in-place writes to an input tensor, +as illustrated below: + +```python +def foo(x, y): + # mutates input x + x.add_(1) + return x + y +``` + +Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static +CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may +allocate a static memory address x'. During execution, CUDAGraph Trees first copy the input +tensor x to the static memory address x', and then replay the recorded CUDAGraph. For input +mutation function, x' is in-place updated, which is not reflected on the input tensor x since +x and x' reside on different CUDA memory addresses. + +A closer look at input mutation functions reveals that there are three types of inputs: + +- **inputs from eager**: These tensors we assume will vary input tensor addresses from + execution to execution. Because cudagraphs freeze memory addresses, we need to copy these + inputs to a static address tensor prior to graph recording and execution. +- **Parameters and buffers**: These tensors we assume (and runtime-check) have the same tensor + addresses on every execution. We do not need to copy over their contents because the recorded + memory address will be the same as the executed memory address. +- **Tensors which are prior outputs from CUDAGraph Trees**: Because the output tensor addresses + of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from + CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and + buffers, do not require copying over to a static address tensor. We check to make sure that + these inputs are stable at runtime, and if they're not we will re-record. + +CUDAGraph Trees support input mutation on parameters and buffers, and tensors which are prior +outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees will run the +function without CUDAGraph and emit *skipping due to mutated inputs* log. The following example +shows CUDAGraph Trees' support for tensors which are prior outputs from CUDAGraph Trees. + +```python +import torch + +@torch.compile(mode="reduce-overhead") +def foo(x): + return x + 1 + +@torch.compile(mode="reduce-overhead") +def mut(x): + return x.add_(2) + +# Enable input mutation support +torch._inductor.config.triton.cudagraph_support_input_mutation = True + +for i in range(3): + torch.compiler.cudagraph_mark_step_begin() + inp = torch.rand([4], device="cuda") + + # CUDAGraph is applied since `foo` does not mutate `inp` + tmp = foo(inp) + # Although `mut` mutates `tmp`, which is an output of a CUDAGraph + # managed function. So CUDAGraph is still applied. + mut(tmp) + + +torch.compiler.cudagraph_mark_step_begin() +inp = torch.rand([4], device="cuda") + +tmp = foo(inp) +# While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()` +# is not. So CUDAGraph is not applied to `mut` and there is a log +# `skipping cudagraphs due to mutated inputs` +mut(tmp.clone()) +``` + +To enable CUDAGraph Trees for a function mutating inputs from eager, please re-write +the function to avoid input mutation. + + + +> **Note**\ +> Enable input mutation support by setting +[torch.\_inductor.config.cudagraph_support_input_mutation = True](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L662) for "reduce-overhead" mode. + + +### Dynamic Shape Support + +[Dynamic shape](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html) +means that an input tensor has different shapes across function calls. Since CUDAGraph +requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique +shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph. +When there are limited shapes (e.g., batch sizes in inference), it is profitable to +re-record CUDAGraphs. However, if input tensor shapes change frequently or even on +every invocation, re-recording CUDAGraph may not be profitable. Nvidia uses 64 KB of +device memory per kernel launch in CUDAGraph, up until CUDA 12.4 and Driver Version 550+. +This memory cost can be significant with many CUDAGraph re-recordings. + +For functions with frequently changing input tensor shapes, we suggest padding input +tensors to a few fixed tensor shapes to still enjoy benefits from CUDAGraph. In addition, +setting [torch.\_inductor.config.triton.cudagraph_skip_dynamic_graphs=True](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L653) +allows to skip cudagraphing functions with dynamic shape inputs and only cudagraphing +functions with static input tensor shapes. + +### NCCL Support + +CUDAGraph Trees support functions with nccl operators. While CUDAGraph Trees perform per-device +record for CUDAGraph, NCCL support allows cross-device communication. + +```python +@torch.compile(mode="reduce-overhead") +def func(x): + y = x * x + y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM) + x = torch.nn.functional.silu(x) + return x * y +``` + +### Reasons for Skipping CUDAGraph + +Since CUDAGraph has requirements such as static input tensor addresses and not supporting +CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and +may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph. + +- **Input mutation**: CUDAGraph Trees skip functions that in-place mutates eager input. + In-place mutating parameters and buffers, or output tensors from CUDAGraph Tree managed + functions are still supported. Please see *Input Mutation Support* section for more details. +- **CPU operators**: Functions containing CPU operator are skipped. Please split the + function into multiple functions and apply CUDAGraph Trees on functions with only GPU operators. +- **Multi-device operators**: A function is skipped if it contains operators on multiple + devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported + libraries such as NCCL for cross-device communication. Please see *NCCL Support* + section for more details. +- **Free unbacked symbols**: Free unbacked symbols usually happen during + [dynamic shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html). + CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shapes. + Please see *Dynamic Shape Support* for more details. +- **Incompatible operators**: CUDAGraph Trees skip a function if it contain incompatible + operators. Please replace these operators in a function with supported operators. We + show an exhaustive list of incompatible operators: + +```python +aten._fused_moving_avg_obs_fq_helper.default +aten._fused_moving_avg_obs_fq_helper_functional.default +aten.multinomial.default +fbgemm.dense_to_jagged.default +fbgemm.jagged_to_padded_dense.default +run_and_save_rng_state +run_with_rng_state +aten._local_scalar_dense +aten._assert_scalar +``` + +The following operators are incompatible when [torch.are_deterministic_algorithms_enabled()](https://pytorch.org/docs/stable/generated/torch.are_deterministic_algorithms_enabled.html). + +```python +aten._fused_moving_avg_obs_fq_helper.default +aten._fused_moving_avg_obs_fq_helper_functional.default +aten.multinomial.default +fbgemm.dense_to_jagged.default +fbgemm.jagged_to_padded_dense.default +run_and_save_rng_state +run_with_rng_state +aten._local_scalar_dense +aten._assert_scalar +``` + +### Limitations + +Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation. + +Let’s say we are benchmarking running inference with the following code: + +```python +import torch + +@torch.compile(mode="reduce-overhead") +def my_model(x): + y = torch.matmul(x, x) + return y + +x = torch.randn(10, 10, device="cuda") +y1 = my_model(x) +y2 = my_model(x) +print(y1) +# RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. +``` + +In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph +Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want +to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for +torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics +are wrong, you can mark the start of a new iteration with +[torch.compiler.mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone +tensors of a prior iteration (outside of torch.compile) before you begin the next run. + +### Comparisons + +| Footguns | Separate CudaGraph | CUDAGraph Trees | +|---------------|------------------------------------------------------------|------------------------------------------------------------------------| +| Memory Can Increase | On each graph compilation (new sizes, etc.) | If you are also running non-cudagraph memory | +| Recordings | On any new invocation of a graph | Will re-record on any new, unique path you take through your program | +| Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop training, or one run of inference | \ No newline at end of file diff --git a/docs/source/torch.compiler_cudagraph_trees.rst b/docs/source/torch.compiler_cudagraph_trees.rst deleted file mode 100644 index 4eef3482f1076c..00000000000000 --- a/docs/source/torch.compiler_cudagraph_trees.rst +++ /dev/null @@ -1,320 +0,0 @@ -CUDAGraph Trees -================ - -**Background** -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -CUDAGraph --------------------- - -For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraphs `_. - -`CUDA Graphs `_, which made its debut in CUDA 10, let a series of CUDA kernels to be defined and encapsulated as a single unit, i.e., a graph of operations, rather than a sequence of individually-launched operations. It provides a mechanism to launch multiple GPU operations through a single CPU operation, and hence reduces the launching overheads. - -CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations from requiring the same kernels to be run with the same arguments and dependencies, and memory addresses. - -- Control Flow is not possible -- Kernels which trigger host to device syncs (such as .item()) errors -- All input arguments to kernels are fixed to what they were recorded -- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change -- No Essential CPU ops or CPU side effects - -PyTorch CUDAGraph Integration ------------------------------ - -PyTorch provides a `convenience wrapper `_ around CUDAGraphs that handles a couple of tricky interactions with PyTorch’s caching allocator. - -The CachingAllocator uses a separate memory pool for all the new allocations. During CUDAGraph recording, memory is accounted for, allocated, and freed exactly as during eager run. On replay, just the kernels are invoked, and there are no changes to the allocator. Subsequent to initial recording, the allocator does not know which memory is actively being used in user programs. - -Using a separate memory pool between eager allocations and cudagraph allocations may increase the memory of your program if there is substantial memory allocated to both. - -Make Graphed Callables ----------------------- - -`Make Graphed Callables `_ is a PyTorch Abstraction to share a single memory pool over a series of callables. Graphed Callables takes advantage of the fact that on CUDA Graph recording, memory is exactly accounted for by the caching allocator to safely share memory between separate CUDA Graph recordings. In each invocation, outputs are preserved as live memory, preventing one callable from overwriting the live memory of another. Graphed Callables can only be invoked in a single order; memory addresses from the first run are burned into the second, and so forth. - -TorchDynamo Previous CUDA Graphs Integration --------------------------------------------- - -Running with ``cudagraph_trees=False`` does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward. - -**CUDAGraph Trees Integration** -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Like Graph Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let’s take a look at an illustrative example: - -.. code-block:: python - - @torch.compile(mode="reduce-overhead") - def foo(x): - # GRAPH 1 - y = x * x * x - # graph break triggered here - if y.sum() > 0: - # GRAPH 2 - z = y ** y - else: - # GRAPH 3 - z = (y.abs() ** y.abs()) - torch._dynamo.graph_break() - # GRAPH 4 - return z * torch.rand_like(z) - - # the first run warms up each graph, which does things like CuBlas or Triton benchmarking - foo(torch.arange(0, 10, device="cuda")) - # The second run does a CUDA Graph recording, and replays it - foo(torch.arange(0, 10, device="cuda")) - # Finally we hit the optimized, CUDA Graph replay path - foo(torch.arange(0, 10, device="cuda")) - - -In this example, there are two separate paths that we make through the function: 1 -> 2 -> 4, or 1 -> 3 -> 4. - -We share all of the memory in a single memory pool between separate recordings by building up a tape of CUDA Graph recordings, in this instance, 1 -> 2 -> 4. We add invariants to ensure that memory is always in the same location as it were recorded, and no live tensors exist in user programs that might be overwritten. - -- Same constraints from CUDA Graphs apply: same kernels must be invoked with the same arguments (static sizes, addresses, etc) -- The same pattern of memory must be observed between recording and replay: if a tensor output of one graph dies subsequent to another graph during recording, it must also do so during replay. -- Live memory in the CUDA pool forces a dependence between two recordings -- These recordings can only be invoked in a single order 1 - > 2 -> 4 - -All of the memory is shared in a single memory pool, so there is no additional memory overhead compared to eager. Now, what happens if we were to hit a new path and run Graph 3? - -Graph 1 gets replayed, and then we hit Graph 3, which we have not yet recorded. On graph replays, the private memory pool is not updated, so y is not reflected in the allocator. Without care, we would overwrite it. To support reusing the same memory pool after replaying other graphs, we checkpoint the memory pool back to its state at the end of graph 1. Now that our live tensors are reflected in the caching allocator, we are safe to run a new graph. - -First, we would hit the optimized, CUDAGraph.replay() path that we have already recorded in graph 1. Then we would hit Graph 3. Just as before, we will need to warm up the graph once before recording. On the warmup run, the memory addresses are not fixed, so graph 4 will also fallback to the inductor, non-cudagraph invocation. - -The second time we hit graph 3 we are warmed up and ready to record. We record graph 3 and then record graph 4 again since the input memory addresses have changed. This creates a tree of CUDA Graph recordings. A CUDA Graph Tree! - -:: - - 1 - / \\ - 2 3 - \\ \\ - 4 4 - - -Input Mutation Support ----------------------- - -Input mutation function refers to a function conducting in-place writes to an input tensor, -as illustrated below: - -.. code-block:: python - - def foo(x, y): - # mutates input x - x.add_(1) - return x + y - -Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static -CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may -allocate a static memory address x'. During execution, CUDAGraph Trees first copy the input -tensor x to the static memory address x', and then replay the recorded CUDAGraph. For input -mutation function, x' is in-place updated, which is not reflected on the input tensor x since -x and x' reside on different CUDA memory addresses. - -A closer look at input mutation functions reveals that there are three types of inputs: - -* **inputs from eager**: These tensors we assume will vary input tensor addresses from - execution to execution. Because cudagraphs freeze memory addresses, we need to copy these - inputs to a static address tensor prior to graph recording and execution. -* **Parameters and buffers**: These tensors we assume (and runtime-check) have the same tensor - addresses on every execution. We do not need to copy over their contents because the recorded - memory address will be the same as the executed memory address. -* **Tensors which are prior outputs from CUDAGraph Trees**: Because the output tensor addresses - of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from - CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and - buffers, do not require copying over to a static address tensor. We check to make sure that - these inputs are stable at runtime, and if they're not we will re-record. - -CUDAGraph Trees support input mutation on parameters and buffers, and tensors which are prior -outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees will run the -function without CUDAGraph and emit *skipping due to mutated inputs* log. The following example -shows CUDAGraph Trees' support for tensors which are prior outputs from CUDAGraph Trees. - - -.. code-block:: python - - import torch - - @torch.compile(mode="reduce-overhead") - def foo(x): - return x + 1 - - @torch.compile(mode="reduce-overhead") - def mut(x): - return x.add_(2) - - # Enable input mutation support - torch._inductor.config.triton.cudagraph_support_input_mutation = True - - for i in range(3): - torch.compiler.cudagraph_mark_step_begin() - inp = torch.rand([4], device="cuda") - - # CUDAGraph is applied since `foo` does not mutate `inp` - tmp = foo(inp) - # Although `mut` mutates `tmp`, which is an output of a CUDAGraph - # managed function. So CUDAGraph is still applied. - mut(tmp) - - - torch.compiler.cudagraph_mark_step_begin() - inp = torch.rand([4], device="cuda") - - tmp = foo(inp) - # While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()` - # is not. So CUDAGraph is not applied to `mut` and there is a log - # `skipping cudagraphs due to mutated inputs` - mut(tmp.clone()) - - -To enable CUDAGraph Trees for a function mutating inputs from eager, please re-write -the function to avoid input mutation. - -.. note:: Enable input mutation support by setting - `torch._inductor.config.cudagraph_support_input_mutation = True `_ - for "reduce-overhead" mode. - - -Dynamic Shape Support ---------------------- - -`Dynamic shape `_ -means that an input tensor has different shapes across function calls. Since CUDAGraph -requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique -shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph. -When there are limited shapes (e.g., batch sizes in inference), it is profitable to -re-record CUDAGraphs. However, if input tensor shapes change frequently or even on -every invocation, re-recording CUDAGraph may not be profitable. Nvidia uses 64 KB of -device memory per kernel launch in CUDAGraph, up until CUDA 12.4 and Driver Version 550+. -This memory cost can be significant with many CUDAGraph re-recordings. - -For functions with frequently changing input tensor shapes, we suggest padding input -tensors to a few fixed tensor shapes to still enjoy benefits from CUDAGraph. In addition, -setting `torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True `_ -allows to skip cudagraphing functions with dynamic shape inputs and only cudagraphing -functions with static input tensor shapes. - - -NCCL Support ------------- - -CUDAGraph Trees support functions with nccl operators. While CUDAGraph Trees perform per-device -record for CUDAGraph, NCCL support allows cross-device communication. - -.. code-block:: python - - @torch.compile(mode="reduce-overhead") - def func(x): - y = x * x - y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM) - x = torch.nn.functional.silu(x) - return x * y - - -Reasons for Skipping CUDAGraph ------------------------------- - -Since CUDAGraph has requirements such as static input tensor addresses and not supporting -CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and -may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph. - -* **Input mutation**: CUDAGraph Trees skip functions that in-place mutates eager input. - In-place mutating parameters and buffers, or output tensors from CUDAGraph Tree managed - functions are still supported. Please see *Input Mutation Support* section for more details. -* **CPU operators**: Functions containing CPU operator are skipped. Please split the - function into multiple functions and apply CUDAGraph Trees on functions with only GPU operators. -* **Multi-device operators**: A function is skipped if it contains operators on multiple - devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported - libraries such as NCCL for cross-device communication. Please see *NCCL Support* - section for more details. -* **Free unbacked symbols**: Free unbacked symbols usually happen during - `dynamic shapes `_. - CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shapes. - Please see *Dynamic Shape Support* for more details. -* **Incompatible operators**: CUDAGraph Trees skip a function if it contain incompatible - operators. Please replace these operators in a function with supported operators. We - show an exhaustive list of incompatible operators: - - -.. code-block:: python - - aten._fused_moving_avg_obs_fq_helper.default - aten._fused_moving_avg_obs_fq_helper_functional.default - aten.multinomial.default - fbgemm.dense_to_jagged.default - fbgemm.jagged_to_padded_dense.default - run_and_save_rng_state - run_with_rng_state - aten._local_scalar_dense - aten._assert_scalar - - -The following operators are incompatible when `torch.are_deterministic_algorithms_enabled() `_. - - -.. code-block:: python - - aten._fused_moving_avg_obs_fq_helper.default - aten._fused_moving_avg_obs_fq_helper_functional.default - aten.multinomial.default - fbgemm.dense_to_jagged.default - fbgemm.jagged_to_padded_dense.default - run_and_save_rng_state - run_with_rng_state - aten._local_scalar_dense - aten._assert_scalar - - -Limitations ------------ - -Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation. - -Let’s say we are benchmarking running inference with the following code: - -.. code-block:: python - - import torch - - @torch.compile(mode="reduce-overhead") - def my_model(x): - y = torch.matmul(x, x) - return y - - x = torch.randn(10, 10, device="cuda") - y1 = my_model(x) - y2 = my_model(x) - print(y1) - # RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. - -In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph -Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want -to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for -torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics -are wrong, you can mark the start of a new iteration with -`torch.compiler.mark_step_begin() `_, or clone -tensors of a prior iteration (outside of torch.compile) before you begin the next run. - - -Comparisons ------------ - -.. list-table:: - :widths: 20 40 40 - :header-rows: 1 - - * - Footguns - - Separate CudaGraph - - CUDAGraph Trees - * - Memory Can Increase - - On each graph compilation (new sizes, etc.) - - If you are also running non-cudagraph memory - * - Recordings - - On any new invocation of a graph - - Will re-record on any new, unique path you take through your program - * - Footguns - - Invocation of one graph will overwrite prior invocation - - Cannot persist memory between separate runs through your model - one training loop training, or one run of inference diff --git a/docs/source/torch.compiler_custom_backends.md b/docs/source/torch.compiler_custom_backends.md new file mode 100644 index 00000000000000..df43bb4c9e860f --- /dev/null +++ b/docs/source/torch.compiler_custom_backends.md @@ -0,0 +1,280 @@ +# Custom Backends + +## Overview + +`torch.compile` provides a straightforward method to enable users +to define custom backends. + +A backend function has the contract +`(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable`. + +Backend functions can be called by TorchDynamo, the graph tracing component of `torch.compile`, +after tracing an FX graph and are +expected to return a compiled function that is equivalent to the traced FX graph. +The returned callable should have the same contract as the `forward` function of the original `torch.fx.GraphModule` +passed into the backend: +`(*args: torch.Tensor) -> List[torch.Tensor]`. + +In order for TorchDynamo to call your backend, pass your backend function as the `backend` kwarg in +`torch.compile`. For example, + +```python +import torch + +def my_custom_backend(gm, example_inputs): + return gm.forward + +def f(...): + ... + +f_opt = torch.compile(f, backend=my_custom_backend) + +@torch.compile(backend=my_custom_backend) +def g(...): + ... +``` + +See below for more examples. + +## Registering Custom Backends + +You can register your backend using the `register_backend` decorator, for example, + +```python +from torch._dynamo import register_backend + +@register_backend +def my_compiler(gm, example_inputs): + ... +``` + +Besides the `register_backend` decorator, if your backend is in another python package, you could also register your +backend through entry points of python package, which provides a way for a package to register a plugin for another one. + +:::{hint} +You can learn more about `entry_points` in the +[python packaging documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html). +::: + +To register your backend through `entry_points`, you could add your backend function to the `torch_dynamo_backends` entry point group in the +`setup.py` file of your package like: + +```python +... +setup( + ... + 'torch_dynamo_backends': [ + 'my_compiler = your_module.submodule:my_compiler', + ] + ... +) +``` + +Please replace the `my_compiler` before `=` to the name of your backend's name and replace the part after `=` to +the module and function name of your backend function. +The entry point will be added to your python environment after the installation of the package. +When you call `torch.compile(model, backend="my_compiler")`, PyTorch would first search the backend named `my_compiler` +that has been registered with `register_backend`. If not found, it will continue to search in all backends registered +via `entry_points`. + +Registration serves two purposes: + +- You can pass a string containing your backend function's name to `torch.compile` instead of the function itself, + for example, `torch.compile(model, backend="my_compiler")`. +- It is required for use with the [minifier](https://pytorch.org/docs/main/torch.compiler_troubleshooting_old.html#minifier). Any generated + code from the minifier must call your code that registers your backend function, typically through an `import` statement. + +## Custom Backends after AOTAutograd + +It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo. +This is useful for 2 main reasons: + +- Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation. +- AOTAutograd produces FX graphs consisting of [core Aten ops](https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir). As a result, + custom backends only need to support the core Aten opset, which is a significantly smaller opset than the entire torch/Aten opset. + +Wrap your backend with +`torch._dynamo.backends.common.aot_autograd` and use `torch.compile` with the `backend` kwarg as before. +Backend functions wrapped by `aot_autograd` should have the same contract as before. + +Backend functions are passed to `aot_autograd` through the `fw_compiler` (forward compiler) +or `bw_compiler` (backward compiler) kwargs. If `bw_compiler` is not specified, the backward compile function +defaults to the forward compile function. + +One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping +the compiled function with `functorch.compile.make_boxed_func`. + +For example, + +```python +from torch._dynamo.backends.common import aot_autograd +from functorch.compile import make_boxed_func + +def my_compiler(gm, example_inputs): + return make_boxed_func(gm.forward) + +my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler + +model_opt = torch.compile(model, backend=my_backend) +``` + +## Examples + +### Debugging Backend + +If you want to better understand what is going on during a +compilation, you can create a custom compiler, which is referred to as +backend in this section, that will print pretty print the fx +`GraphModule` extracted from Dynamo’s bytecode analysis +and return a `forward()` callable. + +For example: + +```python +from typing import List +import torch +def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable +@torch.compile(backend=my_compiler) +def fn(x, y): + a = torch.cos(x) + b = torch.sin(y) + return a + b +fn(torch.randn(10), torch.randn(10)) +``` + +Running the above example produces the following output: + +``` +my_compiler() called with FX graph: +opcode name target args kwargs +------------- ------ ------------------------------------------------------ ---------- -------- +placeholder x x () {} +placeholder y y () {} +call_function cos (x,) {} +call_function sin (y,) {} +call_function add (cos, sin) {} +output output output ((add,),) {} +``` + +This works for `torch.nn.Module` as well as shown below: + +```python +from typing import List +import torch +def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + def forward(self, x): + return self.relu(torch.cos(x)) +mod = MockModule() +optimized_mod = torch.compile(mod, backend=my_compiler) +optimized_mod(torch.randn(10)) +``` + +Let’s take a look at one more example with control flow: + +```python +from typing import List +import torch +def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable +@torch.compile(backend=my_compiler) +def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b +for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) +``` + +Running this example produces the following output: + +``` +my_compiler() called with FX graph: +opcode name target args kwargs +------------- ------- ------------------------------------------------------ ---------------- -------- +placeholder a a () {} +placeholder b b () {} +call_function abs_1 (a,) {} +call_function add (abs_1, 1) {} +call_function truediv (a, add) {} +call_method sum_1 sum (b,) {} +call_function lt (sum_1, 0) {} +output output output ((truediv, lt),) {} + +my_compiler() called with FX graph: +opcode name target args kwargs +------------- ------ ----------------------- ----------- -------- +placeholder b b () {} +placeholder x x () {} +call_function mul (b, -1) {} +call_function mul_1 (x, mul) {} +output output output ((mul_1,),) {} + +my_compiler() called with FX graph: +opcode name target args kwargs +------------- ------ ----------------------- --------- -------- +placeholder b b () {} +placeholder x x () {} +call_function mul (x, b) {} +output output output ((mul,),) {} + +The order of the last two graphs is nondeterministic depending +on which one is encountered first by the just-in-time compiler. +``` + +### Speedy Backend + +Integrating a custom backend that offers superior performance is also +easy and we’ll integrate a real one +with [optimize_for_inference](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html): + +```python +def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + scripted = torch.jit.script(gm) + return torch.jit.optimize_for_inference(scripted) +``` + +And then you should be able to optimize any existing code with: + +```python +@torch.compile(backend=optimize_for_inference_compiler) +def code_to_accelerate(): + ... +``` + +### Composable Backends + +TorchDynamo includes many backends, which can be listed with +`torch._dynamo.list_backends()`. You can combine these backends +together with the following code: + +```python +from torch._dynamo import lookup_backend +def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + try: + trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) + if trt_compiled is not None: + return trt_compiled + except Exception: + pass + # first backend failed, try something else... + try: + inductor_compiled = lookup_backend("inductor")(gm, example_inputs) + if inductor_compiled is not None: + return inductor_compiled + except Exception: + pass + return gm.forward +``` diff --git a/docs/source/torch.compiler_custom_backends.rst b/docs/source/torch.compiler_custom_backends.rst deleted file mode 100644 index 611cc0bff7b084..00000000000000 --- a/docs/source/torch.compiler_custom_backends.rst +++ /dev/null @@ -1,288 +0,0 @@ -Custom Backends -=============== - -Overview --------- - -``torch.compile`` provides a straightforward method to enable users -to define custom backends. - -A backend function has the contract -``(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable``. - -Backend functions can be called by TorchDynamo, the graph tracing component of ``torch.compile``, -after tracing an FX graph and are -expected to return a compiled function that is equivalent to the traced FX graph. -The returned callable should have the same contract as the ``forward`` function of the original ``torch.fx.GraphModule`` -passed into the backend: -``(*args: torch.Tensor) -> List[torch.Tensor]``. - -In order for TorchDynamo to call your backend, pass your backend function as the ``backend`` kwarg in -``torch.compile``. For example, - -.. code-block:: python - - import torch - - def my_custom_backend(gm, example_inputs): - return gm.forward - - def f(...): - ... - - f_opt = torch.compile(f, backend=my_custom_backend) - - @torch.compile(backend=my_custom_backend) - def g(...): - ... - -See below for more examples. - -Registering Custom Backends ---------------------------- - -You can register your backend using the ``register_backend`` decorator, for example, - -.. code-block:: python - - from torch._dynamo import register_backend - - @register_backend - def my_compiler(gm, example_inputs): - ... - -Besides the ``register_backend`` decorator, if your backend is in another python package, you could also register your -backend through entry points of python package, which provides a way for a package to register a plugin for another one. - -.. hint:: - - You can learn more about ``entry_points`` in the - `python packaging documentation `__. - -To register your backend through ``entry_points``, you could add your backend function to the ``torch_dynamo_backends`` entry point group in the -``setup.py`` file of your package like: - -.. code-block:: python - - ... - setup( - ... - 'torch_dynamo_backends': [ - 'my_compiler = your_module.submodule:my_compiler', - ] - ... - ) - -Please replace the ``my_compiler`` before ``=`` to the name of your backend's name and replace the part after ``=`` to -the module and function name of your backend function. -The entry point will be added to your python environment after the installation of the package. -When you call ``torch.compile(model, backend="my_compiler")``, PyTorch would first search the backend named ``my_compiler`` -that has been registered with ``register_backend``. If not found, it will continue to search in all backends registered -via ``entry_points``. - -Registration serves two purposes: - -* You can pass a string containing your backend function's name to ``torch.compile`` instead of the function itself, - for example, ``torch.compile(model, backend="my_compiler")``. -* It is required for use with the :ref:`minifier `. Any generated - code from the minifier must call your code that registers your backend function, typically through an ``import`` statement. - -Custom Backends after AOTAutograd ---------------------------------- - -It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo. -This is useful for 2 main reasons: - -* Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation. -* AOTAutograd produces FX graphs consisting of `core Aten ops `__. As a result, - custom backends only need to support the core Aten opset, which is a significantly smaller opset than the entire torch/Aten opset. - -Wrap your backend with -``torch._dynamo.backends.common.aot_autograd`` and use ``torch.compile`` with the ``backend`` kwarg as before. -Backend functions wrapped by ``aot_autograd`` should have the same contract as before. - -Backend functions are passed to ``aot_autograd`` through the ``fw_compiler`` (forward compiler) -or ``bw_compiler`` (backward compiler) kwargs. If ``bw_compiler`` is not specified, the backward compile function -defaults to the forward compile function. - -One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping -the compiled function with ``functorch.compile.make_boxed_func``. - -For example, - -.. code-block:: python - - from torch._dynamo.backends.common import aot_autograd - from functorch.compile import make_boxed_func - - def my_compiler(gm, example_inputs): - return make_boxed_func(gm.forward) - - my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler - - model_opt = torch.compile(model, backend=my_backend) - -Examples --------- - -Debugging Backend -^^^^^^^^^^^^^^^^^ - -If you want to better understand what is going on during a -compilation, you can create a custom compiler, which is referred to as -backend in this section, that will print pretty print the fx -``GraphModule`` extracted from Dynamo’s bytecode analysis -and return a ``forward()`` callable. - -For example: - -.. code-block:: python - - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - @torch.compile(backend=my_compiler) - def fn(x, y): - a = torch.cos(x) - b = torch.sin(y) - return a + b - fn(torch.randn(10), torch.randn(10)) - -Running the above example produces the following output: - -:: - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ------------------------------------------------------ ---------- -------- - placeholder x x () {} - placeholder y y () {} - call_function cos (x,) {} - call_function sin (y,) {} - call_function add (cos, sin) {} - output output output ((add,),) {} - -This works for ``torch.nn.Module`` as well as shown below: - -.. code-block:: python - - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - def forward(self, x): - return self.relu(torch.cos(x)) - mod = MockModule() - optimized_mod = torch.compile(mod, backend=my_compiler) - optimized_mod(torch.randn(10)) - -Let’s take a look at one more example with control flow: - -.. code-block:: python - - from typing import List - import torch - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - @torch.compile(backend=my_compiler) - def toy_example(a, b): - x = a / (torch.abs(a) + 1) - if b.sum() < 0: - b = b * -1 - return x * b - for _ in range(100): - toy_example(torch.randn(10), torch.randn(10)) - -Running this example produces the following output: - -:: - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------- ------------------------------------------------------ ---------------- -------- - placeholder a a () {} - placeholder b b () {} - call_function abs_1 (a,) {} - call_function add (abs_1, 1) {} - call_function truediv (a, add) {} - call_method sum_1 sum (b,) {} - call_function lt (sum_1, 0) {} - output output output ((truediv, lt),) {} - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ----------------------- ----------- -------- - placeholder b b () {} - placeholder x x () {} - call_function mul (b, -1) {} - call_function mul_1 (x, mul) {} - output output output ((mul_1,),) {} - - my_compiler() called with FX graph: - opcode name target args kwargs - ------------- ------ ----------------------- --------- -------- - placeholder b b () {} - placeholder x x () {} - call_function mul (x, b) {} - output output output ((mul,),) {} - - The order of the last two graphs is nondeterministic depending - on which one is encountered first by the just-in-time compiler. - -Speedy Backend -^^^^^^^^^^^^^^ - -Integrating a custom backend that offers superior performance is also -easy and we’ll integrate a real one -with `optimize_for_inference `__: - -.. code-block:: python - - def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - scripted = torch.jit.script(gm) - return torch.jit.optimize_for_inference(scripted) - -And then you should be able to optimize any existing code with: - -.. code-block:: python - - @torch.compile(backend=optimize_for_inference_compiler) - def code_to_accelerate(): - ... - -Composable Backends -^^^^^^^^^^^^^^^^^^^ - -TorchDynamo includes many backends, which can be listed with -``torch._dynamo.list_backends()``. You can combine these backends -together with the following code: - -.. code-block:: python - - from torch._dynamo import lookup_backend - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - try: - trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) - if trt_compiled is not None: - return trt_compiled - except Exception: - pass - # first backend failed, try something else... - try: - inductor_compiled = lookup_backend("inductor")(gm, example_inputs) - if inductor_compiled is not None: - return inductor_compiled - except Exception: - pass - return gm.forward diff --git a/docs/source/torch.compiler_dynamic_shapes.md b/docs/source/torch.compiler_dynamic_shapes.md new file mode 100644 index 00000000000000..95998ffe8491c6 --- /dev/null +++ b/docs/source/torch.compiler_dynamic_shapes.md @@ -0,0 +1,129 @@ +# Dynamic Shapes + +Code: [symbolic_shapes.py](https://github.com/pytorch/pytorch/blob/db4572dbf18f1cf50cf662547e272d3117063747/torch/fx/experimental/symbolic_shapes.py) + +See also: [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng) + +## Motivation + +Deep learning compilers commonly only work for static shapes, that is to say, they produced compiled programs which only work for a single specific configuration of input shapes, and must recompile if any input shape changes. This assumption works great for the majority of commonly run deep learning models today, but there are a few situations where it is insufficient: + +- Some dimensions, such as batch size or sequence length, may vary. For example, an inference service performing adaptive batching will execute inference requests with varying batch sizes depending on how many requests it received within its batching window. We may also want to consider padding out variable size sequences only to the maximum sequence length within a batch, which may vary from batch-to-batch. +- Some models exhibit data-dependent output shapes, that is to say, the size of their outputs and intermediates may depend on the actual input data which may vary across runs. For example, detection models may first generate a variable number of potential bounding boxes before running a more expensive image recognition model to identify if the subject is in a bounding box. The number of bounding boxes is data dependent. +- One particularly important case of data-dependent shapes occurs when dealing with sparse representations, such as sparse tensors, jagged tensors, and graph neural networks. In all of these cases, the amount of data to be processed depends on the sparse structure of the problem, which will typically vary in a data-dependent way. + +In supporting dynamic shapes, we chose not to support dynamic rank programs, e.g., programs whose inputs tensors change in dimensionality, as this pattern rarely occurs in real-world deep learning programs, and it avoids the need to reason inductively over symbolic lists of shapes. + +## Abridged public API + +The default dynamic behavior in PyTorch 2.1 is: + +- PT2 assumes everything is static by default +- If we recompile because a size changed, we will instead attempt to recompile + that size as being dynamic (sizes that have changed are likely to change in + the future). This generalization may fail (e.g., because user code does a + conditional branch on the size in question or missing dynamic shapes support + in PT2). If you are trying to understand why PT2 has overspecialized some + code, run with `TORCH_LOGS=dynamic` and look for "eval" entries that say + when guards are added and why. +- If you know ahead of time something will be dynamic, you can skip the first + recompile with `torch._dynamo.mark_dynamic(tensor, dim)`. If you know ahead of time + the `min` and `max` value this dimension can take, you can specify `torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)` +- If you say `torch.compile(dynamic=False)`, we will turn off automatic + dynamic shapes on recompiles and always recompile for each distinct size. + Conversely, if you say `torch.compile(dynamic=True)`, we will try to make + everything as dynamic as possible. This is mostly useful for small + operators; if you try it on a big model it will (1) probably crash PT2 and (2) run slow for no good reason. +- You can whitelist specific sources to be marked as dynamic using the + `TORCH_COMPILE_DYNAMIC_SOURCES` environment variable or by setting + `torch.compiler.config.dynamic_sources`. This is particularly useful for large + models with graph breaks, as you can maintain dynamism across graph breaks since + source names stay consistent. You can also use this to mark integers as dynamic. + The format is a comma-delimited list of source names, e.g., `"L['x'], L['y']"`. + You can also use regexes, e.g., `"L\['x.*'\], L\['y.*'\]")`. + This whitelist takes precedence over other flags like `dynamic=False`, + `force_nn_module_property_static_shapes`, and `force_parameter_static_shapes`. +- Sometimes it can be cumbersome to find the right inputs to mark as dynamic. If + you're willing to take a performance hit for the first batch, one other affordable + option we have are the eager_then_compile stances which derive dynamism for you. + See [torch.compiler.set_stance](https://docs.pytorch.org/docs/stable/generated/torch.compiler.set_stance.html) for more details. + +## The Guard Model + +When considering how to add support for dynamic shapes to TorchDynamo and TorchInductor, we made a major design decision: in order to reuse decompositions and other preexisting code written in Python/C++ targeting the PyTorch API, we must be able to trace through dynamic shapes. Unlike a fully symbolic system which might capture both branches of a conditional, we always pick one branch and specialize our trace under the assumption that we only use this trace when we would have made the same choice for that branch in the future. To do this, we maintain a "hint" for every symbolic size saying what its concrete value is at compile time (as TorchDynamo is a just-in-time compiler, it always knows what the actual input sizes are.) When we perform a condition on a tensor, we simply consult the hint to find out which branch to take. + +This greatly simplifies the symbolic shape formulas we produce, but means we have a much more involved system for managing guards. Consider, for example, the following program: + +```python +def f(x, y): + z = torch.cat([x, y]) + if z.size(0) > 2: + return z.mul(2) + else: + return z.add(2) +``` + +The final IR we will compile with TorchInductor will either be `torch.cat([x, y]).add(2)` or `torch.cat([x, y]).mul(2)` (with the condition flattened away), but to determine which branch we are in, we would need to know the size of `z`, an intermediate. Because TorchDynamo must know upfront if a compiled trace is valid (we do not support bailouts, like some JIT compilers), we must be able to reduce `z.size(0)` as an expression in terms of the inputs, `x.size(0) + y.size(0)`. This is done by writing meta functions for all operators in PyTorch which can propagate size information to the output of a tensor without actually performing computation on the node. + +## Overall architecture + +Symbolic shapes workflow: + +1. When we start compiling a frame in Dynamo, we allocate a ShapeEnv (attached to FakeTensorMode) which keeps track of symbolic shapes state. +2. We allocate symbolic sizes for tensors on entry (what is static or dynamic is a policy decision, with some knobs). +3. We propagate the symbolic sizes through operators, maintaining both (1) FX IR so that we can faithfully export symbolic compute, and (2) Sympy expressions representing the size vars, so we can reason about them. +4. When we condition on symbolic sizes, either in Dynamo tracing or in Inductor optimization, we add guards based on the conditional. These can be induced from both Python and C++. +5. These guards can induce further simplifications on symbolic variables. For example, if you assert `s0 == 4`, we can now replace all occurrences of `s0` with `4`. +6. When we're done tracing and optimizing, we install all of these guards with the compiled code; the compiled code is only reusable if all the guards evaluate true. + +Important files: + +- C++ SymInt API: `c10/core/SymInt.h`, `SymFloat.h`, `SymBool.h` +- Python SymInt API: `torch/__init__.py` (look for `SymInt/SymFloat/SymBool`) +- C++ plumbing: `c10/core/SymNodeImpl.h`, `torch/csrc/utils/python_symnode.h`, `torch/csrc/jit/python/init.cpp` +- Python infrastructure: `torch/fx/experimental/symbolic_shapes.py` +- Other important files: `torch/_subclasses/fake_tensor.py`, `torch/_meta_registrations.py`, decomps, PrimTorch refs + +## Abridged internal API + +Understanding the Python class hierarchy: + +- SymInt/SymFloat/SymBool: these are user-visible classes that simulate their int/float/bool counterparts. If you add two SymInts, we give you a new SymInt that symbolically tracks that the integer addition had occurred. +- SymNode: this is the internal structure (accessible via e.g., `symint.node`) which holds the actual symbolic tracking info. SymNode is type erased; this makes it more convenient to represent mixed-type operations. Note that technically you don't have to call into Python SymNode from SymInt; for example, XLA's C++ `SymNodeImpl` would take the place of SymNode. +- ShapeEnv: per-compile context state which keeps track of all the free symbols and guards we have accumulated so far. Every SymNode records its ShapeEnv (but not vice versa; SymNodes only get used if they participate in a guard). + +C++ is fairly similar: + +- c10::SymInt/SymFloat/SymBool: user-visible classes that simulate int/float/bool. +- c10::SymNode/SymNodeImpl: analogous to SymNode +- There is no ShapeEnv in C++; for ease of debugging, the entire symbolic reasoning apparatus is in Python. + +When you write code that is traceable with `make_fx`, it must be able to deal with SymInt/SymFloat/SymBool flowing through it. [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng) gives some guidance for how to do this. + +## DimDynamic policy + +Symbolic reasoning: + +- Value ranges +- Sympy usage notes +- Constraints +- DimDynamic/Constraint + +## Unbacked SymInts + +To resolve control flow, we check the hint, aka actual value, of a symbolic integer to determine which branch to go. However, in some cases, we may not have a hint: so-called unbacked symbolic integers arise when a size variable emerges from a data-dependent operation like `.nonzero()` or `.item()`. It is illegal to perform control flow on these symbolic integers, so we must graph break on these operations. + +Naively implemented, this is too restrictive: most PyTorch programs will immediately fail if you try to do anything with unbacked symbolic integers. Here are the most important enhancements to make this actually work: + +- On tensor creation, PyTorch precomputes a lot of data about a tensor; for example, if you use `empty_strided` to create a tensor, we will eagerly sort the strides and determine if the tensor is non-overlapping and dense. Sorts produce a lot of guards. However, it is more common to produce a tensor directly with a higher-level API like `empty`, which is guaranteed to produce a non-overlapping and dense tensor. We modified PyTorch to avoid needlessly recomputing these properties. +- Even if nontrivial compute is needed, sometimes a property is never actually queried at all. Making these precomputed properties lazy allows us to avoid guarding on an unbacked symbolic integer unless it is actually needed. +- The data in an integer tensor is generally not known to be non-negative. However, we provide an API `constrain_range` whereby a user can specify that a size is bounded above and below by known limits. + +Similar to the dynamic APIs, there are corresponding unbacked APIs: namely you can use mark_unbacked instead of `mark_dynamic` and `TORCH_COMPILE_UNBACKED_SOURCES` instead of `TORCH_COMPILE_DYNAMIC_SOURCES` to tell the compiler to mark an input as unbacked. + +In future versions of PT2 (beyond PT2.1), we will extend our reasoning system +to infer that an unbacked symbolic integer is size-like based on usage. For +example, if you pass the result of an `.item()` call to a factory function +like `torch.empty`, we will automatically infer that the result is a size +(because if it was not, it would fail.) This assumption would get validated +at runtime, raising an error if it was not fulfilled. diff --git a/docs/source/torch.compiler_dynamic_shapes.rst b/docs/source/torch.compiler_dynamic_shapes.rst deleted file mode 100644 index 33256840cf6b55..00000000000000 --- a/docs/source/torch.compiler_dynamic_shapes.rst +++ /dev/null @@ -1,126 +0,0 @@ -Dynamic shapes -============== - -Code: `symbolic_shapes.py `_ - -See also: `The dynamic shapes manual `_ - -Motivation ----------- - -Deep learning compilers commonly only work for static shapes, that is to say, they produced compiled programs which only work for a single specific configuration of input shapes, and must recompile if any input shape changes. This assumption works great for the majority of commonly run deep learning models today, but there are a few situations where it is insufficient: - -- Some dimensions, such as batch size or sequence length, may vary. For example, an inference service performing adaptive batching will execute inference requests with varying batch sizes depending on how many requests it received within its batching window. We may also want to consider padding out variable size sequences only to the maximum sequence length within a batch, which may vary from batch-to-batch. -- Some models exhibit data-dependent output shapes, that is to say, the size of their outputs and intermediates may depend on the actual input data which may vary across runs. For example, detection models may first generate a variable number of potential bounding boxes before running a more expensive image recognition model to identify if the subject is in a bounding box. The number of bounding boxes is data dependent. -- One particularly important case of data-dependent shapes occurs when dealing with sparse representations, such as sparse tensors, jagged tensors, and graph neural networks. In all of these cases, the amount of data to be processed depends on the sparse structure of the problem, which will typically vary in a data-dependent way. - -In supporting dynamic shapes, we chose not to support dynamic rank programs, e.g., programs whose inputs tensors change in dimensionality, as this pattern rarely occurs in real-world deep learning programs, and it avoids the need to reason inductively over symbolic lists of shapes. - -Abridged public API -------------------- - -The default dynamic behavior in PyTorch 2.1 is: - -- PT2 assumes everything is static by default - -- If we recompile because a size changed, we will instead attempt to recompile - that size as being dynamic (sizes that have changed are likely to change in - the future). This generalization may fail (e.g., because user code does a - conditional branch on the size in question or missing dynamic shapes support - in PT2). If you are trying to understand why PT2 has overspecialized some - code, run with ``TORCH_LOGS=dynamic`` and look for "eval" entries that say - when guards are added and why. - -- If you know ahead of time something will be dynamic, you can skip the first - recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``. If you know ahead of time - the ``min`` and ``max`` value this dimension can take, you can specify ``torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)`` - -- If you say ``torch.compile(dynamic=False)``, we will turn off automatic - dynamic shapes on recompiles and always recompile for each distinct size. - Conversely, if you say ``torch.compile(dynamic=True)``, we will try to make - everything as dynamic as possible. This is mostly useful for small - operators; if you try it on a big model it will (1) probably crash PT2 and - (2) run slow for no good reason. - -The Guard Model ---------------- - -When considering how to add support for dynamic shapes to TorchDynamo and TorchInductor, we made a major design decision: in order to reuse decompositions and other preexisting code written in Python/C++ targeting the PyTorch API, we must be able to trace through dynamic shapes. Unlike a fully symbolic system which might capture both branches of a conditional, we always pick one branch and specialize our trace under the assumption that we only use this trace when we would have made the same choice for that branch in the future. To do this, we maintain a "hint" for every symbolic size saying what its concrete value is at compile time (as TorchDynamo is a just-in-time compiler, it always knows what the actual input sizes are.) When we perform a condition on a tensor, we simply consult the hint to find out which branch to take. - -This greatly simplifies the symbolic shape formulas we produce, but means we have a much more involved system for managing guards. Consider, for example, the following program: - -.. code-block:: python - - def f(x, y): - z = torch.cat([x, y]) - if z.size(0) > 2: - return z.mul(2) - else: - return z.add(2) - -The final IR we will compile with TorchInductor will either be ``torch.cat([x, y]).add(2)`` or ``torch.cat([x, y]).mul(2)`` (with the condition flattened away), but to determine which branch we are in, we would need to know the size of ``z``, an intermediate. Because TorchDynamo must know upfront if a compiled trace is valid (we do not support bailouts, like some JIT compilers), we must be able to reduce ``z.size(0)`` as an expression in terms of the inputs, ``x.size(0) + y.size(0)``. This is done by writing meta functions for all operators in PyTorch which can propagate size information to the output of a tensor without actually performing computation on the node. - -Overall architecture --------------------- - -Symbolic shapes workflow: - -1. When we start compiling a frame in Dynamo, we allocate a ShapeEnv (attached to FakeTensorMode) which keeps track of symbolic shapes state. -2. We allocate symbolic sizes for tensors on entry (what is static or dynamic is a policy decision, with some knobs). -3. We propagate the symbolic sizes through operators, maintaining both (1) FX IR so that we can faithfully export symbolic compute, and (2) Sympy expressions representing the size vars, so we can reason about them. -4. When we condition on symbolic sizes, either in Dynamo tracing or in Inductor optimization, we add guards based on the conditional. These can be induced from both Python and C++. -5. These guards can induce further simplifications on symbolic variables. For example, if you assert ``s0 == 4``, we can now replace all occurrences of ``s0`` with ``4``. -6. When we're done tracing and optimizing, we install all of these guards with the compiled code; the compiled code is only reusable if all the guards evaluate true. - -Important files: - -- C++ SymInt API: ``c10/core/SymInt.h``, ``SymFloat.h``, ``SymBool.h`` -- Python SymInt API: ``torch/__init__.py`` (look for ``SymInt/SymFloat/SymBool``) -- C++ plumbing: ``c10/core/SymNodeImpl.h``, ``torch/csrc/utils/python_symnode.h``, ``torch/csrc/jit/python/init.cpp`` -- Python infrastructure: ``torch/fx/experimental/symbolic_shapes.py`` -- Other important files: ``torch/_subclasses/fake_tensor.py``, ``torch/_meta_registrations.py``, decomps, PrimTorch refs - -Abridged internal API ---------------------- - -Understanding the Python class hierarchy: - -- SymInt/SymFloat/SymBool: these are user-visible classes that simulate their int/float/bool counterparts. If you add two SymInts, we give you a new SymInt that symbolically tracks that the integer addition had occurred. -- SymNode: this is the internal structure (accessible via e.g., ``symint.node``) which holds the actual symbolic tracking info. SymNode is type erased; this makes it more convenient to represent mixed-type operations. Note that technically you don't have to call into Python SymNode from SymInt; for example, XLA's C++ ``SymNodeImpl`` would take the place of SymNode. -- ShapeEnv: per-compile context state which keeps track of all the free symbols and guards we have accumulated so far. Every SymNode records its ShapeEnv (but not vice versa; SymNodes only get used if they participate in a guard). - -C++ is fairly similar: - -- c10::SymInt/SymFloat/SymBool: user-visible classes that simulate int/float/bool. -- c10::SymNode/SymNodeImpl: analogous to SymNode -- There is no ShapeEnv in C++; for ease of debugging, the entire symbolic reasoning apparatus is in Python. - -When you write code that is traceable with ``make_fx``, it must be able to deal with SymInt/SymFloat/SymBool flowing through it. `The dynamic shapes manual `_ gives some guidance for how to do this. - -DimDynamic policy ------------------ - -Symbolic reasoning: - -- Value ranges -- Sympy usage notes -- Constraints -- DimDynamic/Constraint - -Unbacked SymInts ----------------- - -To resolve control flow, we check the hint, aka actual value, of a symbolic integer to determine which branch to go. However, in some cases, we may not have a hint: so-called unbacked symbolic integers arise when a size variable emerges from a data-dependent operation like ``.nonzero()`` or ``.item()``. It is illegal to perform control flow on these symbolic integers, so we must graph break on these operations. - -Naively implemented, this is too restrictive: most PyTorch programs will immediately fail if you try to do anything with unbacked symbolic integers. Here are the most important enhancements to make this actually work: - -- On tensor creation, PyTorch precomputes a lot of data about a tensor; for example, if you use ``empty_strided`` to create a tensor, we will eagerly sort the strides and determine if the tensor is non-overlapping and dense. Sorts produce a lot of guards. However, it is more common to produce a tensor directly with a higher-level API like ``empty``, which is guaranteed to produce a non-overlapping and dense tensor. We modified PyTorch to avoid needlessly recomputing these properties. -- Even if nontrivial compute is needed, sometimes a property is never actually queried at all. Making these precomputed properties lazy allows us to avoid guarding on an unbacked symbolic integer unless it is actually needed. -- The data in an integer tensor is generally not known to be non-negative. However, we provide an API ``constrain_range`` whereby a user can specify that a size is bounded above and below by known limits. - -In future versions of PT2 (beyond PT2.1), we will extend our reasoning system -to infer that an unbacked symbolic integer is size-like based on usage. For -example, if you pass the result of an ``.item()`` call to a factory function -like ``torch.empty``, we will automatically infer that the result is a size -(because if it was not, it would fail.) This assumption would get validated -at runtime, raising an error if it was not fulfilled. diff --git a/docs/source/torch.compiler_dynamo_deepdive.md b/docs/source/torch.compiler_dynamo_deepdive.md new file mode 100644 index 00000000000000..6bbb03170e549c --- /dev/null +++ b/docs/source/torch.compiler_dynamo_deepdive.md @@ -0,0 +1,856 @@ +(torch.compiler_dynamo_deepdive)= + +# Dynamo Deep-Dive + +TorchDynamo (or simply Dynamo) is the tracer within `torch.compile`, +and it is, more often than not, the one to blame for those insane +backtraces. However, we cannot blindly blame Dynamo for these errors. In +order to provide the user with the flexibility it does, Dynamo is given +the arduous task of understanding any Python program. In particular, +Dynamo has to implement a good part of the Python programming language +internally! + +In this post, we will go over the internal design of Dynamo from the +ground up. We will discuss the functionality it provides, and how it is +implemented. By the end of this post, you will have a better +understanding of what went wrong when you `torch.compiled` a PyTorch +program and the compilation errored out, or succeeded but the speed-up +was not what you expected. + +## A Gentle Introduction to Dynamo + +Before getting our hands dirty with all the implementation details, +let’s start by discussing what it is that Dynamo does. + +Dynamo is a tracer. This means, given and function and inputs to it, it +executes the function and records a linear sequence of instructions +(without control flow) into a graph. For example, consider the following +program: + +```python +import torch + +@torch.compile +def mse(x, y): + z = (x - y) ** 2 + return z.sum() + +x = torch.randn(200) +y = torch.randn(200) +mse(x, y) +``` + +If we save this program into the file `example.py` and we run + +```bash +TORCH_LOGS=graph_code python example.py +``` + +we see the output that Dynamo traced + +```python +def forward(l_x_: torch.Tensor, l_y_: torch.Tensor): + # File: example.py:5, code: z = (x - y) ** 2 + sub = l_x_ - l_y_ + z = sub ** 2 + # File: example.py:6, code: return z.sum() + sum_1 = z.sum() + return (sum_1,) +``` + +We call this a **graph (or trace) of the function for the given +inputs**. This is represented via an [FX +graph](https://pytorch.org/docs/main/fx.html). We will simply think +of an FX graph as a container that stores a list of function calls. + +The first thing we should notice is that the graph is a linear sequence +of PyTorch operations. [^1] Dynamo records all the PyTorch operations +and stores them sequentially. For example, it split `z = (x - y) ** 2` +into its two constituting operations, `sub = l_x_ - l_y_` and +`z = sub ** 2`. + +When we say that the trace is linear, we mean that there is no branching +or any control flow. To see this, consider + +```python +import torch + +@torch.compile +def fn(x, n): + y = x ** 2 + if n >= 0: + return (n + 1) * y + else: + return y / n + +x = torch.randn(200) +fn(x, 2) +``` + +which, when executed with `TORCH_LOGS=graph_code`, returns + +```python +def forward(l_x_: torch.Tensor): + # File: example.py:5, code: y = x ** 2 + y = l_x_ ** 2 + # File: example.py:7, code: return (n + 1) * y + mul = 3 * y + return (mul,) +``` + +We see that Dynamo completely removed the `if` statement from the +trace and just recorded the operations that were executed with the +inputs. + +As such, it should be clear that **the trace of a function depends on +the inputs**. In particular, this means that the trace is not generated +when we write `@torch.compile`, but when we execute the function +`fn(x, 2)` with the actual arguments. + +The other interesting thing to note here is that Dynamo removed the +second argument to the function. Instead, it treated it as a constant +and recorded the result of the operation `n + 1` in the graph. This is +another feature of Dynamo: Dynamo will treat as constant any non-tensor +value… other than ints. Let’s see now how are ints special. + +The last defining property of Dynamo is that it knows how to handle +dynamic shapes. Symbolic shapes refer to Dynamo’s ability of tracing +shapes, and more generally, integers, rather than leaving them as +constants. This allows for avoiding recompilations and deploying generic +models that work for any size in production. The main examples of places +where dynamic shapes appear are the batch size, where we might train a +model with a fixed batch size but then perform inference for an +arbitrary batch size, or the variable sequence length that one +encounters when processing text or audio. + +We can see this by executing a few more times the example above + +```python +import torch + +@torch.compile +def fn(x, n): + y = x ** 2 + if n >= 0: + return (n + 1) * y + else: + return y / n + +x = torch.randn(200) +fn(x, 2) +fn(x, 3) +fn(x, -2) +``` + +In this case, `TORCH_LOGS=graph_code` generates two more graphs + +```python +# Graph for n==2 omitted + +def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt): + # File: a.py:5, code: y = x ** 2 + y = l_x_ ** 2 + + # File: a.py:7, code: return (n + 1) * y + add = l_n_ + 1 + mul = add * y + return (mul,) +``` + +```python +def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt): + # File: a.py:5, code: y = x ** 2 + y = l_x_ ** 2 + + # File: a.py:9, code: return y / n + truediv = y / l_n_ + return (truediv,) +``` + +Dynamo detected that one integer changed its value after the first call +and started tracing it. We see that these graphs are generic, and trace +the variable `n` symbolically via an object of type `SymInt`. + +If after these calls we call `fn(x, 4)`, Dynamo would not recompile, +but rather reuse the graph that was already traced. + +To summarize: 1. Dynamo is a Python tracer 2. Given some inputs, it +returns an FX graph with the PyTorch functions that were executed 3. It +can also trace integers if it detects that they changed between calls 4. +It specializes any other value that is not a tensor or a scalar + +Of course, Dynamo does many more things, like figuring out when it needs +to retrace, rewriting the bytecode of the function, implementing graph +breaks… To keep the introduction short, we will incrementally discuss +all these in the sequel. + +## PEP 523: Adding a frame evaluation API to CPython + +Imagine now that we are given the task to implement Dynamo. Where would +we even start? Rather conveniently for us, [PEP +523](https://peps.python.org/pep-0523/) was released with Python 3.6. +This PEP [was +designed](https://peps.python.org/pep-0523/#a-jit-for-cpython) to +allow third parties to create JIT compilers for Python. Let’s see how. + +**A note on CPython**: CPython is internally implemented as a [stack +machine](https://en.wikipedia.org/wiki/Stack_machine). A Python +program is compiled into +[bytecodes](https://en.wikipedia.org/wiki/Bytecode) that then are +executed by this interpreter. To learn more about these bytecodes, see +the [dis module](https://docs.python.org/3/library/dis.html) from the +standard library. See also [the developer +docs](https://devguide.python.org/internals/interpreter/) for an +introduction to CPython’s interpreter. We will assume that the reader is +familiar with the notion of a stack machine. + +PEP 523 exposes an API where a user can add a custom per-function +interpreter. Then, CPython will use this interpreter rather than its own +to execute the function. In order to be able to execute the function, on +entry, CPython provides the custom interpreter with things like - The +bytecode of the function - The value of the arguments of the function +(i.e., the local variables) and their names - The value of the global +variables and their names - The builtin functions like `abs` or +`print` + +You can see all the fields +[here](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L50-L59). [^2] + +In summary, CPython provides the user’s interpreter with all the +information necessary to execute the function. [^3] + +With this API, we can implement a tracer by implementing an interpreter +that runs the code and records in a graph all the PyTorch operations +that occur during this execution. This is exactly what Dynamo does. + +Dynamo uses this CPython API to parse all these objects and packs them +into [a Python +structure](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L93-L108). +After it has done so… it goes back from C to python. Other than for this +piece of code that communicates with CPython, Dynamo is fully +implemented in Python. + +It should be clear that it is the decorator `@torch.compile`’s job +to install the necessary scaffolding that will pass the bytecode, the +args, global variables and so on to Dynamo when the function is called. +Again, `@torch.compile` does not actually compile anything. + +## Implementing CPython in Python + +So, we are back in the Python world. We have the bytecode of a function, +and all the context necessary to execute it. In particular, we have +landed at +[_convert_frame_assert](https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/convert_frame.py#L272-L274). +This is the function that the decorator `torch.compile` returns! We +get to this function from +[_dynamo.optimize](https://github.com/pytorch/pytorch/blob/b6df8414601e1e086e830ca9e919e7fdc8874e71/torch/_dynamo/eval_frame.py#L715-L727). +The decorator `torch.compile` is just a nice API around +`_dynamo.optimize`. + +Before getting into implementing a Python interpreter, we want to define +an [IR](https://en.wikipedia.org/wiki/Intermediate_representation). +In particular, we want to wrap all the local and global variables in our +own internal classes. This allows us to better track these objects and +group together objects that can be treated in the same way to the eyes +of Dynamo. + +The parent class of the internal class structure is `VariableTracker` +and represents the different objects that Dynamo understands. For +example, `ListVariable`, represents a `list` object, and keeps +internally a [list of VariableTrackers](https://github.com/pytorch/pytorch/blob/e38a3a6079a3861b4bc9f256120ec661f34e726d/torch/_dynamo/variables/lists.py#L48-L56). +Another example of `VariableTracker` is +[ConstantVariable](https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/constant.py#L30). +ConstantVariable wraps all the [objects considered constant by +Dynamo](https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/constant.py#L98-L107). +We also have special subclasses for objects that require special +attention, like +[TensorVariable](https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/tensor.py#L68-L69). +All these internal classes are defined in the +[torch/_dynamo/variables](https://github.com/pytorch/pytorch/tree/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables) +folder. + +Python objects are wrapped into their corresponding `VariableTracker` +class in +[VariableBuilder._wrap](https://github.com/pytorch/pytorch/blob/83c0763dda1f93c6cf552ba88260a0dc7a3ecb70/torch/_dynamo/variables/builder.py#L365). +This function is just a very long chain of `elif`s that tries to +recursively pattern-match the Python inputs into the appropriate type of +`VariableTracker`. + +**Debugging tip**. When we get unexpected results from dynamo, it is +sometimes caused by the builder. If the logic of the builder is wrong, +sometimes Dynamo may wrap a variable in the incorrect +`VariableTracker` type, and this may cause issues later on. It is +rather useful to have a look at the `VariableTracker` types that +appear in the errors, and the `VariableTracker` method that throws the +exception when you encounter a Dynamo error. In particular, sometimes we +find that an object is tracked as a `UserDefinedObjectVariable` (this +is Dynamo’s catch-all class), when it should have been tracked as +something more specific. In these cases, the `SourceBuilder.__call__` +logic is often to blame. + +**Debugging tip**. When running a program with `TORCH_LOGS=dynamo`, +one of the artifacts that are printed out is lines of the form + +``` +TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(), TensorVariable()] +``` + +This is the bytecode for the original program and the state of the stack +at that point. This is very useful to find where an object was not +traced into the right `VariableTracker`. + +Ok, so we have an IR for our tracer, now we *just* need to reimplement +CPython’s stack machine. This is implemented by +[InstructorTranslatorBase](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L576-L594) +in +[symbolic_convert.py](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py). + +`InstructionTranslatorBase` has about 200 methods, implementing almost +all of Python bytecodes. As an example, we can see the implementation of +`BUILD_LIST` + +```python +def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(ListVariable(items, mutation_type=ValueMutationNew())) +``` + +This is the bytecode generated by constructions like `l = [2, 3, 4]`. +In this case, since there are three elements, the generated bytecode is +`BUILD_LIST 3`. This means that we pop the top `3` elements of the +stack and push a new list object to the top of the stack formed by these +three elements. + +## Generating the Output Graph + +With a way to symbolically execute Python code, we are set to extract +the PyTorch operations that happen during the symbolic execution of a +program given some inputs. This is implemented in Dynamo via the +[OutputGraph](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/output_graph.py#L221-L230) +object. The `OutputGraph` object is [bound to an +`InstructionTranslator object](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/symbolic_convert.py#L2060-L2071) +and it tracks all the data necessary to create the FX graph which will +be returned by Dynamo. + +All the inputs and intermediary elements of the FX graph are +`fx.Node`s. In Dynamo, `fx.Node`s are wrapped in +`fx.Proxy`s. `fx.Proxy`s are used to build the FX graph. +In particular, they record every PyTorch operation performed on them +into the graph. You can create a new operation to be added to +the graph by calling [create_proxy](https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/output_graph.py#L430-L431). +Then, we can add it to the graph through the function +[wrap_fx_proxy](https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/_dynamo/variables/builder.py#L1311). + +A graph stores operations on tensors… and operations on symbolic +integers. We will discuss symbolic integers later on, but first we will +discuss how Dynamo addresses a rather important correctness issue. + +(making-dynamo-sound-guards)= +## Making Dynamo Sound: Guards + +At this point, we have a way to trace programs completely disregarding control flow. +And for that, we have reimplemented all of CPython… If this sounds like a bit of an +overkill, that is because it is. +[torch.jit.trace](https://pytorch.org/docs/main/generated/torch.jit.trace.html) +already implements this without all this machinery, so what gives? + +The issue with `torch.jit.trace`, as it is warned in its docs, is that +it just works if the traced program is not data dependent. In other +words, it will just work if the program itself is linear. This means +writing our program without using if-elses, for-while loops, exceptions. +Even more, none of the libraries that we use can use any control flow! +All in all, not using control flow in a language as dynamic as Python +is, in fact, a huge constraint. + +JAX solves this problem by always retracing and caching the graph after +retracing. Dynamo, on the other hand, uses guards to avoid retracing the +whole program every time. + +A **guard** is an assumption (a boolean expression on an input) made in +order to specialize a frame for one set of example inputs. Reusing the +graph is only valid if these assumptions hold on the new inputs. + +For example, any constant input to a function, like a string, installs a +guard stating that that input should be of type `str` and equal to the +string we passed. Running + +```python +import torch + +@torch.compile +def fn(a, b): + return a * len(b) + +fn(torch.arange(10), "Hello") +``` + +with `TORCH_LOGS=guards` prints (among other guards) + +```python +___check_type_id(L['b'], 94334122025024) +L['b'] == 'Hello' +``` + +This reads as “the local variable `b` should have a specific type +(`str` in this case, represented by the constant `9433...`) and +its value should be `'Hello'`”. If we then execute the function +again passing a different argument + +```python +import torch + +@torch.compile +def fn(a, b): + return a * len(b) + +fn(torch.arange(10), "Hello") +fn(torch.arange(10), "Hi") +``` + +we can see the guard that failed by running `TORCH_LOGS=recompiles` + +```python +Recompiling function fn in script.py:3 +triggered by the following guard failure(s): + - L['b'] == 'Hello' +``` + +Guards are accumulated while [the inputs to the function are wrapped in +the +builder](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/variables/builder.py#L808-L810) +and [during the execution of the +program](https://github.com/pytorch/pytorch/blob/69f112d5867f785a3a090a0c6d6644ae047033ac/torch/_dynamo/variables/dicts.py#L763-L769). +We will show many more examples of guards in the next section, but first +let us discuss sources. + +A **source** tracks how to reconstruct a variable from the original +local or global variables present when entering the current frame. In +particular, it tracks the original local and global objects and any of +the objects they contain. In + +```python +def foo(x: Tensor, y: List[Tensor]): + a = x * y[0] + return a * x +``` + +`x` and `y` have +[LocalSource](https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L80-L92) +as their source, and `y[0]` has +[GetItemSource](https://github.com/pytorch/pytorch/blob/40dc0580a69565b06ec5263efe5d87cecc8200f7/torch/_dynamo/source.py#L302), +which stores a `LocalSource` inside. On the other hand, `a` will not +have a source as it is an intermediate variable that only exists within +the fx graph. + +All these are defined in +[torch/_dynamo/source.py](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/source.py). +We can see the guard generated by `GetItemSource` in the following +example: + +```python +import torch + +@torch.compile +def fn(x, l): + return x * len(l[0]) + +fn(torch.randn(8), ["Hi", "Hello"]) +``` + +generates the following guards + +```python +___check_type_id(L['l'], 94439025877664) +len(L['l']) == 2 +___check_type_id(L['l'][0], 94439025840192) +L['l'][0] == 'Hi' +___check_type_id(L['l'][1], 94439025840192) +L['l'][1] == 'Hello' +``` + +Here, we see the code generated by `GetItemSource` (`[0]` and +`[1]`) wrapping a `LocalSource` (`L['l']`). + +At this point, with sources and guards, we are able to implement a +caching system to avoid recompilation without having to retrace every +time. We will discuss a bit more in detail this caching system in the +sequel. + +The attentive reader will have noticed that this does not explain yet +why we need to have such fine control over the Python interpreter as to +having to reimplement it. The examples of guards that we have shown +depend on the input objects, so we could still compute these before +executing the function. In other words, we could implement this guard +system on top of `torch.jit.trace` and get the same functionality with +much less effort… Enter symbolic shapes. + +## Symbolic Shapes + +Another point we discussed in the introduction is that Dynamo knows how +to trace integers. In order to implement this, we use a symbolic class +[torch.SymInt](https://github.com/pytorch/pytorch/blob/fb80f05ee2e1cba17892980701bfd5dbce58349f/torch/__init__.py#L244-L249) +that acts like an `int` but it records all the operations performed on +it in the output FX graph. [^4] We already saw this class in the introduction +when introducing symbolic integer tracing. + +Let us now discuss the three properties that define symbolic shape +tracing in Dynamo, and how to implement them. + +### Static by default + +Dynamo assumes that every integer, let that be an input or the shape of +a tensor, is static by default. In other words, no integers will be +traced on the first execution of a function. Then, only if it detects +that an integer or a shape changed value during the execution, it will +trace it and generate a graph generic on that variable. + +We already saw this behavior in the introduction using integers. Let us +now look at an example using shapes of tensors. + +```python +import torch + +@torch.compile +def fn(a, b): + return a.shape[0] * a * b + +fn(torch.randn(4, 3), torch.randn(4, 3)) +fn(torch.randn(8, 3), torch.randn(8, 3)) +``` + +Running this program with `TORCH_LOGS=graph_code` we see that these +two calls are traced as + +```python +def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor): + mul = 4 * l_a_ + mul_1 = mul * l_b_ + return (mul_1,) + +def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor): + size = l_a_.size() + getitem = size[0] + mul = getitem * l_a_ + mul_1 = mul * l_b_ + return (mul_1,) +``` + +In the first graph the shape is traced as a constant, but once it +changes, it traces it symbolically using a `SymInt`s. In general, a +simpler way to see the shapes of the intermediary values is by running +the program with `TORCH_LOGS=graph_sizes` + +``` +TRACED GRAPH TENSOR SIZES +===== __compiled_fn_1 ===== +l_a_: (s0, 3) +l_a_ (concrete): (8, 3) +l_b_: (s0, 3) +l_b_ (concrete): (8, 3) +mul: (s0, 3) +mul (concrete): (8, 3) +mul_1: (s0, 3) +mul_1 (concrete): (8, 3) +``` + +where we can see that the first dimension of the two tensor args is +dynamic, given that it is represented by the `s0` variable. + +We can find how Dynamo implements this by running `TORCH_LOGS=guards` + +```python +# Guards first call +check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1]) +check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1]) + +# Guards second call +check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1]) +check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1]) + +L['b'].size()[0] == L['a'].size()[0] +2 <= L['a'].size()[0] +``` + +We see that on the first call, the guards check that the tensors have +some fixed sizes and strides. These guards fail in the second execution, +so it retraces. Since it was an `int` guard that failed, in this +second iteration it traces this `int` symbolically and it installs +more general guards on this more generic kernel. + +**Compilation performance tip**. If you know that a dimension will vary +in size, you can mark it as dynamic by calling +[torch._dynamo.mark_dynamic](https://github.com/pytorch/pytorch/blob/66a76516bfc341b2b55bb2056d2faa9c2de46d69/torch/_dynamo/decorators.py#L176) +before calling `torch.compile`. This will avoid the first compilation +with a static shape. There are other useful utility functions like +`maybe_mark_dynamic` or `mark_static`. You can also have all +integers and shapes traced by calling `torch.compile(dynamic=True)`. +This is mostly useful for debugging purposes. + +### 0, 1 are always specialized + +Regardless of whether we mark a dimension as dynamic, if we pass an input +where that dimension is 0 or 1, Dynamo will trace it as non-dynamic and it +will generate a specific graph for it. This is the reason why in the example +above we find guards of the form `2 <= L['a'].size()[0]`. + +There are several reasons for this choice. There are two particularly +important - A tensor is empty if and only if any of its dimensions is +zero - A tensor can only be contiguous if one of the strides is one + +This policy decision does NOT apply to plain Python ints; if we think a Python +int should be compiled dynamically, we won't specialize them by default; +instead, whether or not it gets specialized depends on its usage. + +### Duck shaping + +Dynamo performs what we call “duck shaping”. If two dynamic integers +have the same value at trace time, we will assume that they are equal +and guard on it. Effectively, this means that rather than having two +symbols `s0`, `s1` in the example above, we just unified them to +`s0` and had the guard `L['b'].size()[0] == L['a'].size()[0]`. This +enables performing fusions within the compiler while being able to +generate kernels that are generic enough. + +### Guards on symbolic ints + +We now understand how symbolic shapes are implemented at a high level +and the properties they have. Now, why is that symbolic shapes forced us +through the tricky route of getting control of the CPython interpreter? +Consider the following example: + +```python +import torch + +@torch.compile(dynamic=True) +def fn(a): + if a.shape[0] * 2 < 16: + return a + else: + return a + 1 + +fn(torch.randn(8)) +``` + +This code has a guard of the form `2*L['a'].size()[0] >= 16`. This is +a non-trivial guard in terms of the inputs of the function, but it is +registered in the middle of the execution of the program. Even more so, +we cannot know this guard is needed until we see the `if` statement +conditional on a `SymNodeVariable` argument. Such conditions are +invisible to `torch.jit.trace` and require deep analysis of the python +code. + +**Debugging tip** Running this code with `TORCH_LOGS=dynamo` tells us +where this guard was added + +``` +eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr) +``` + +Placing a breakpoint there and looking at the backtrace is rather useful +to understand where a guard came from. + +## Making Dynamo Complete: Graph Breaks + +With all the tools we have discussed, we have a tracer that can trace +PyTorch operations on tensors and integers and has a caching system that +knows when it can reuse a previously traced graph and when it needs to +retrace. All this executing arbitrary Python code! + +There is just one small issue with this. The statement “executing +arbitrary Python code” is perhaps a bit too general. Dynamo implements a +good part of Python, but does it implement the more complex parts, like +coroutines or async? Does it implement the whole Python standard +library? NumPy also has a Python API. Does `torch.compile` also +understand NumPy? and Django? [^5] + +Python’s ecosystem is massive, and a good part of it is written in other +more performant languages like C++ or Rust, and it just exposes Python +bindings. There is no hope in Dynamo tracing through Python objects that +are implemented in C++. What can a tracer do when it finds an operation +that it does not understand? + +The usual way machine learning tracers handle this issue is by informing +the user that the operation they choked on and giving up tracing +altogether. This would pose a real usability issue in the case of +PyTorch, where its users are used to the flexibility it gives them. As a +real-world example the `doctr_det_predictor` model uses NumPy and the +`cv2` library to [postprocess the model’s +result](https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L86). + +Here is another place where having access to CPython is interesting. +Rather than erroring out, Dynamo can let CPython run that problematic +code! To do this, Dynamo generates at trace time one graph with all the +operations before the problematic code, and one with all the operations +after. [^6] Then, at runtime, it will delegate to CPython to execute the +first graph, then the problematic code, and then the second graph. This +process of stopping the tracing and generating multiple graphs is called +a **graph break**. + +A small confession: I lied all throughout the introduction and the first +sections. Dynamo does not generate one graph, but **multiple graphs**! +For all practical purposes, starting retracing after a second graph can +be thought of as starting tracing a new function. The new graph after +the graph break will have its own guards, its new set of local +variables, and so on. + +To discuss how to implement graph breaks, we need to first revisit how +Dynamo interacts with CPython. Using PEP 523, CPython allows a user to +use their own frame evaluation mechanism. What we had not discussed is +that CPython also exposes its own frame evaluation for others to use. +Dynamo leverages this to let the fast CPython interpreter run the +compiled code. For a function without graph breaks, the whole tracing / +execution process of a program that calls the function 2 times with the +same arguments looks like this: + +1. In the first call to the function + + 1. Dynamo traces the function into an FX graph + + 1. The FX graph is compiled by the compiler (Inductor) into + efficient low-level code… but that’s a story for another day + + 2. It rewrites the bytecode of the function so that it simply calls + the compiled function + + 3. It gives CPython this new bytecode and asks it to run it + [here](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L1006) + +2. In the second call to the function + + 1. It checks the guards from the first call against the new arguments + [here](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L658). + Since they are the same arguments as before, they pass + 2. It asks CPython to run the bytecode associated to those guards + [here](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/csrc/dynamo/eval_frame.c#L972-L975) + +This process on its own looks overly complicated. Why generate new +bytecode and ask CPython to run it rather than simply creating a C++ +binding to the compiled function and executing it? Well, this pattern +allows us to implement graph breaks! The bytecode generated by a graph +break has the following structure: + +1. Bytecode that executes the first graph +2. Bytecode that leaves the stack as it would be if CPython would have + executed the first graph. It also replays any modifications to local + or global variables that would be visible at this point +3. The bytecode that made Dynamo graph break +4. Bytecode that executes the second graph + +Let us see this in a simple example + +```python +import torch + +@torch.compile +def fn(a): + b = a + 2 + print("Hi") + return b + a + +fn(torch.randn(4)) +``` + +Running this with `TORCH_LOGS=bytecode` shows us the initial bytecode +and the modified bytecode + +```python +MODIFIED BYTECODE fn script.py line 3 + 0 LOAD_GLOBAL 1 (__compiled_fn_0) + 2 LOAD_FAST 0 (a) + 4 CALL_FUNCTION 1 + 6 STORE_FAST 3 (graph_out_0) + 8 LOAD_GLOBAL 0 (print) +10 LOAD_CONST 2 ('Hi') +12 LOAD_FAST 3 (graph_out_0) +14 LOAD_CONST 3 (0) +16 BINARY_SUBSCR +18 STORE_FAST 1 (b) + +20 CALL_FUNCTION 1 +22 LOAD_GLOBAL 2 (__resume_at_14_1) +24 ROT_TWO +26 LOAD_FAST 0 (a) +28 LOAD_FAST 1 (b) +30 CALL_FUNCTION 3 +32 RETURN_VALUE + +MODIFIED BYTECODE resume_in_fn script.py line 6 + 0 LOAD_GLOBAL 1 (__compiled_fn_2) + 2 LOAD_FAST 2 (b) + 4 LOAD_FAST 1 (a) + 6 CALL_FUNCTION 2 + 8 UNPACK_SEQUENCE 1 +10 RETURN_VALUE +``` + +We can see that the modified bytecode is split into two functions, +`fn`, the original function, and a function called `resume_in_fn`. +This second function is a function created by Dynamo to implement the +execution of the program starting at the graph break. This is often +called a [continuation +function](https://en.wikipedia.org/wiki/Continuation). This +continuation function simply calls the second compiled function with the +right arguments. The code for the initial function is rewritten +implementing the strategy that we described before + +- L0-4. Call the compiled function (`a + 2`). +- L6. Store its result in a local variable called `graph_out_0`. + `graph_out_0` is a tuple +- L8-18. Leave the stack as it would be at the point of the graph break +- L20. Execute the code that caused the graph break +- L22-32. Call the compiled continuation function (`a + b`) + +The code generation of the stack in Dynamo is delegated to +`VariableTracker` subclasses. Every `VariableTracker` object in +Dynamo has a [reconstruct](https://github.com/pytorch/pytorch/blob/e891a3bba9f05697d72776f6e89347231a141f03/torch/_dynamo/variables/lists.py#L307-L309) +method that generates the necessary bytecode to create the python object +it represents on the stack. + +**Debugging tip**. Graph breaks hamper performance, and as such, it is +best to avoid them. Running a program with `TORCH_LOGS=graph_breaks` +is a great way to find how many graph breaks did our program hit. The +information it returns is in terms of `VariableTracker` objects, so +the debugging tips above are sometimes also helpful to figure out what +caused that graph break. + +## Conclusion + +Dynamo is a complex piece of software. Once you sign up to implement a +CPython interpreter you know you are in for a ride. That being said, we +hope that this post helps demystify it a bit. + +Dynamo is (mostly) implemented in Python. We left plenty of links to the +pieces of the code that we discussed. We hope that reading those pieces +of code and grepping for the places that call them, or putting +breakpoints on them and looking at the call stack helps understanding +the rest of the code base. + +Of course, the best way to learn how a piece of software works is by +extending it. In this case, the best way is to have a look at the [open +dynamo issues on +github](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+label%3A%22module%3A+dynamo%22+). +Many of them require very minor changes in the code, once you find where +you need to make those changes. + +## Footnotes + +Below are additional details and references for concepts mentioned in this document. + +[^1]: In the literature, this is called a Directed Acyclical Graph (DAG). + +[^2]: All this binding code lives in `torch/csrc/dynamo/eval_frame.c`. + +[^3]: In CPython lingo, the set of all these objects are called [a + frame](https://github.com/python/cpython/blob/f26bfe4b25f7e5a4f68fcac26207b7175abad208/Include/internal/pycore_frame.h#L57-L71). + +[^4]: There are also `SymBool` and `SymFloat` classes. The latter one + is not used all that much at the time of this writing. + +[^5]: Interestingly enough, it does understand NumPy code! Have a look at + [this blogpost](https://pytorch.org/blog/compiling-numpy-code/) + and [the docs](https://pytorch.org/docs/main/torch.compiler_faq.html#does-numpy-work-with-torch-compile). + Now, this is just possible because we reimplemented NumPy using + PyTorch. Good luck implementing Django in PyTorch though… + +[^6]: Assuming there is just one piece of problematic code. If there are + more, Dynamo can split the code into as many graphs as it needs. diff --git a/docs/source/torch.compiler_dynamo_deepdive.rst b/docs/source/torch.compiler_dynamo_deepdive.rst deleted file mode 100644 index d63e8a4e7d3f2c..00000000000000 --- a/docs/source/torch.compiler_dynamo_deepdive.rst +++ /dev/null @@ -1,868 +0,0 @@ -.. _torch.compiler_dynamo_deepdive: - -Dynamo Deep-Dive -================ - -TorchDynamo (or simply Dynamo) is the tracer within ``torch.compile``, -and it is, more often than not, the one to blame for those insane -backtraces. However, we cannot blindly blame Dynamo for these errors. In -order to provide the user with the flexibility it does, Dynamo is given -the arduous task of understanding any Python program. In particular, -Dynamo has to implement a good part of the Python programming language -internally! - -In this post, we will go over the internal design of Dynamo from the -ground up. We will discuss the functionality it provides, and how it is -implemented. By the end of this post, you will have a better -understanding of what went wrong when you ``torch.compiled`` a PyTorch -program and the compilation errored out, or succeeded but the speed-up -was not what you expected. - -A Gentle Introduction to Dynamo -------------------------------- - -Before getting our hands dirty with all the implementation details, -let’s start by discussing what it is that Dynamo does. - -Dynamo is a tracer. This means, given and function and inputs to it, it -executes the function and records a linear sequence of instructions -(without control flow) into a graph. For example, consider the following -program: - -.. code:: python - - import torch - - @torch.compile - def mse(x, y): - z = (x - y) ** 2 - return z.sum() - - x = torch.randn(200) - y = torch.randn(200) - mse(x, y) - -If we save this program into the file ``example.py`` and we run - -.. code:: bash - - TORCH_LOGS=graph_code python example.py - -we see the output that Dynamo traced - -.. code:: python - - def forward(l_x_: torch.Tensor, l_y_: torch.Tensor): - # File: example.py:5, code: z = (x - y) ** 2 - sub = l_x_ - l_y_ - z = sub ** 2 - # File: example.py:6, code: return z.sum() - sum_1 = z.sum() - return (sum_1,) - -We call this a **graph (or trace) of the function for the given -inputs**. This is represented via an `FX -graph `__. We will simply think -of an FX graph as a container that stores a list of function calls. - -The first thing we should notice is that the graph is a linear sequence -of PyTorch operations. [1]_ Dynamo records all the PyTorch operations -and stores them sequentially. For example, it split ``z = (x - y) ** 2`` -into its two constituting operations, ``sub = l_x_ - l_y_`` and -``z = sub ** 2``. - -When we say that the trace is linear, we mean that there is no branching -or any control flow. To see this, consider - -.. code:: python - - import torch - - @torch.compile - def fn(x, n): - y = x ** 2 - if n >= 0: - return (n + 1) * y - else: - return y / n - - x = torch.randn(200) - fn(x, 2) - -which, when executed with ``TORCH_LOGS=graph_code``, returns - -.. code:: python - - def forward(l_x_: torch.Tensor): - # File: example.py:5, code: y = x ** 2 - y = l_x_ ** 2 - # File: example.py:7, code: return (n + 1) * y - mul = 3 * y - return (mul,) - -We see that Dynamo completely removed the ``if`` statement from the -trace and just recorded the operations that were executed with the -inputs. - -As such, it should be clear that **the trace of a function depends on -the inputs**. In particular, this means that the trace is not generated -when we write ``@torch.compile``, but when we execute the function -``fn(x, 2)`` with the actual arguments. - -The other interesting thing to note here is that Dynamo removed the -second argument to the function. Instead, it treated it as a constant -and recorded the result of the operation ``n + 1`` in the graph. This is -another feature of Dynamo: Dynamo will treat as constant any non-tensor -value… other than ints. Let’s see now how are ints special. - -The last defining property of Dynamo is that it knows how to handle -dynamic shapes. Symbolic shapes refer to Dynamo’s ability of tracing -shapes, and more generally, integers, rather than leaving them as -constants. This allows for avoiding recompilations and deploying generic -models that work for any size in production. The main examples of places -where dynamic shapes appear are the batch size, where we might train a -model with a fixed batch size but then perform inference for an -arbitrary batch size, or the variable sequence length that one -encounters when processing text or audio. - -We can see this by executing a few more times the example above - -.. code:: python - - import torch - - @torch.compile - def fn(x, n): - y = x ** 2 - if n >= 0: - return (n + 1) * y - else: - return y / n - - x = torch.randn(200) - fn(x, 2) - fn(x, 3) - fn(x, -2) - -In this case, ``TORCH_LOGS=graph_code`` generates two more graphs - -.. code:: python - - # Graph for n==2 omitted - - def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt): - # File: a.py:5, code: y = x ** 2 - y = l_x_ ** 2 - - # File: a.py:7, code: return (n + 1) * y - add = l_n_ + 1 - mul = add * y - return (mul,) - -.. code:: python - - def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt): - # File: a.py:5, code: y = x ** 2 - y = l_x_ ** 2 - - # File: a.py:9, code: return y / n - truediv = y / l_n_ - return (truediv,) - -Dynamo detected that one integer changed its value after the first call -and started tracing it. We see that these graphs are generic, and trace -the variable ``n`` symbolically via an object of type ``SymInt``. - -If after these calls we call ``fn(x, 4)``, Dynamo would not recompile, -but rather reuse the graph that was already traced. - -To summarize: 1. Dynamo is a Python tracer 2. Given some inputs, it -returns an FX graph with the PyTorch functions that were executed 3. It -can also trace integers if it detects that they changed between calls 4. -It specializes any other value that is not a tensor or a scalar - -Of course, Dynamo does many more things, like figuring out when it needs -to retrace, rewriting the bytecode of the function, implementing graph -breaks… To keep the introduction short, we will incrementally discuss -all these in the sequel. - -PEP 523: Adding a frame evaluation API to CPython -------------------------------------------------- - -Imagine now that we are given the task to implement Dynamo. Where would -we even start? Rather conveniently for us, `PEP -523 `__ was released with Python 3.6. -This PEP `was -designed `__ to -allow third parties to create JIT compilers for Python. Let’s see how. - -**A note on CPython**: CPython is internally implemented as a `stack -machine `__. A Python -program is compiled into -`bytecodes `__ that then are -executed by this interpreter. To learn more about these bytecodes, see -the `dis module `__ from the -standard library. See also `the developer -docs `__ for an -introduction to CPython’s interpreter. We will assume that the reader is -familiar with the notion of a stack machine. - -PEP 523 exposes an API where a user can add a custom per-function -interpreter. Then, CPython will use this interpreter rather than its own -to execute the function. In order to be able to execute the function, on -entry, CPython provides the custom interpreter with things like - The -bytecode of the function - The value of the arguments of the function -(i.e., the local variables) and their names - The value of the global -variables and their names - The builtin functions like ``abs`` or -``print`` - -You can see all the fields -`here `__. [2]_ - -In summary, CPython provides the user’s interpreter with all the -information necessary to execute the function. [3]_ - -With this API, we can implement a tracer by implementing an interpreter -that runs the code and records in a graph all the PyTorch operations -that occur during this execution. This is exactly what Dynamo does. - -Dynamo uses this CPython API to parse all these objects and packs them -into `a Python -structure `__. -After it has done so… it goes back from C to python. Other than for this -piece of code that communicates with CPython, Dynamo is fully -implemented in Python. - -It should be clear that it is the decorator ``@torch.compile``\ ’s job -to install the necessary scaffolding that will pass the bytecode, the -args, global variables and so on to Dynamo when the function is called. -Again, ``@torch.compile`` does not actually compile anything. - -Implementing CPython in Python ------------------------------- - -So, we are back in the Python world. We have the bytecode of a function, -and all the context necessary to execute it. In particular, we have -landed at -`_convert_frame_assert `__. -This is the function that the decorator ``torch.compile`` returns! We -get to this function from -`_dynamo.optimize `__. -The decorator ``torch.compile`` is just a nice API around -``_dynamo.optimize``. - -Before getting into implementing a Python interpreter, we want to define -an `IR `__. -In particular, we want to wrap all the local and global variables in our -own internal classes. This allows us to better track these objects and -group together objects that can be treated in the same way to the eyes -of Dynamo. - -The parent class of the internal class structure is ``VariableTracker`` -and represents the different objects that Dynamo understands. For -example, ``ListVariable``, represents a ``list`` object, and keeps -internally a `list of VariableTrackers `__. -Another example of ``VariableTracker`` is -`ConstantVariable `__. -ConstantVariable wraps all the `objects considered constant by -Dynamo `__. -We also have special subclasses for objects that require special -attention, like -`TensorVariable `__. -All these internal classes are defined in the -`torch/_dynamo/variables `__ -folder. - -Python objects are wrapped into their corresponding ``VariableTracker`` -class in -`VariableBuilder._wrap `__. -This function is just a very long chain of ``elif``\ s that tries to -recursively pattern-match the Python inputs into the appropriate type of -``VariableTracker``. - -**Debugging tip**. When we get unexpected results from dynamo, it is -sometimes caused by the builder. If the logic of the builder is wrong, -sometimes Dynamo may wrap a variable in the incorrect -``VariableTracker`` type, and this may cause issues later on. It is -rather useful to have a look at the ``VariableTracker`` types that -appear in the errors, and the ``VariableTracker`` method that throws the -exception when you encounter a Dynamo error. In particular, sometimes we -find that an object is tracked as a ``UserDefinedObjectVariable`` (this -is Dynamo’s catch-all class), when it should have been tracked as -something more specific. In these cases, the ``SourceBuilder.__call__`` -logic is often to blame. - -**Debugging tip**. When running a program with ``TORCH_LOGS=dynamo``, -one of the artifacts that are printed out is lines of the form - -:: - - TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(), TensorVariable()] - -This is the bytecode for the original program and the state of the stack -at that point. This is very useful to find where an object was not -traced into the right ``VariableTracker``. - -Ok, so we have an IR for our tracer, now we *just* need to reimplement -CPython’s stack machine. This is implemented by -`InstructorTranslatorBase `__ -in -`symbolic_convert.py `__. - -``InstructionTranslatorBase`` has about 200 methods, implementing almost -all of Python bytecodes. As an example, we can see the implementation of -``BUILD_LIST`` - -.. code:: python - - def BUILD_LIST(self, inst): - items = self.popn(inst.argval) - self.push(ListVariable(items, mutation_type=ValueMutationNew())) - -This is the bytecode generated by constructions like ``l = [2, 3, 4]``. -In this case, since there are three elements, the generated bytecode is -``BUILD_LIST 3``. This means that we pop the top ``3`` elements of the -stack and push a new list object to the top of the stack formed by these -three elements. - -Generating the Output Graph ---------------------------- - -With a way to symbolically execute Python code, we are set to extract -the PyTorch operations that happen during the symbolic execution of a -program given some inputs. This is implemented in Dynamo via the -`OutputGraph `__ -object. The ``OutputGraph`` object is `bound to an -`InstructionTranslator object `__ -and it tracks all the data necessary to create the FX graph which will -be returned by Dynamo. - -All the inputs and intermediary elements of the FX graph are -``fx.Node``\ s. In Dynamo, ``fx.Node``\ s are wrapped in -``fx.Proxy``\ s. ``fx.Proxy``\ s are used to build the FX graph. -In particular, they record every PyTorch operation performed on them -into the graph. You can create a new operation to be added to -the graph by calling `create_proxy `__. -Then, we can add it to the graph through the function -`wrap_fx_proxy `__. - -A graph stores operations on tensors… and operations on symbolic -integers. We will discuss symbolic integers later on, but first we will -discuss how Dynamo addresses a rather important correctness issue. - -.. _making-dynamo-sound-guards: - -Making Dynamo Sound: Guards ---------------------------- - -At this point, we have a way to trace programs completely disregarding control flow. -And for that, we have reimplemented all of CPython… If this sounds like a bit of an -overkill, that is because it is. -`torch.jit.trace `__ -already implements this without all this machinery, so what gives? - -The issue with ``torch.jit.trace``, as it is warned in its docs, is that -it just works if the traced program is not data dependent. In other -words, it will just work if the program itself is linear. This means -writing our program without using if-elses, for-while loops, exceptions. -Even more, none of the libraries that we use can use any control flow! -All in all, not using control flow in a language as dynamic as Python -is, in fact, a huge constraint. - -JAX solves this problem by always retracing and caching the graph after -retracing. Dynamo, on the other hand, uses guards to avoid retracing the -whole program every time. - -A **guard** is an assumption (a boolean expression on an input) made in -order to specialize a frame for one set of example inputs. Reusing the -graph is only valid if these assumptions hold on the new inputs. - -For example, any constant input to a function, like a string, installs a -guard stating that that input should be of type ``str`` and equal to the -string we passed. Running - -.. code:: python - - import torch - - @torch.compile - def fn(a, b): - return a * len(b) - - fn(torch.arange(10), "Hello") - -with ``TORCH_LOGS=guards`` prints (among other guards) - -.. code:: python - - ___check_type_id(L['b'], 94334122025024) - L['b'] == 'Hello' - -This reads as “the local variable ``b`` should have a specific type -(``str`` in this case, represented by the constant ``9433...``) and -its value should be ``'Hello'``”. If we then execute the function -again passing a different argument - -.. code:: python - - import torch - - @torch.compile - def fn(a, b): - return a * len(b) - - fn(torch.arange(10), "Hello") - fn(torch.arange(10), "Hi") - -we can see the guard that failed by running ``TORCH_LOGS=recompiles`` - -.. code:: python - - Recompiling function fn in script.py:3 - triggered by the following guard failure(s): - - L['b'] == 'Hello' - -Guards are accumulated while `the inputs to the function are wrapped in -the -builder `__ -and `during the execution of the -program `__. -We will show many more examples of guards in the next section, but first -let us discuss sources. - -A **source** tracks how to reconstruct a variable from the original -local or global variables present when entering the current frame. In -particular, it tracks the original local and global objects and any of -the objects they contain. In - -.. code:: python - - def foo(x: Tensor, y: List[Tensor]): - a = x * y[0] - return a * x - -``x`` and ``y`` have -`LocalSource `__ -as their source, and ``y[0]`` has -`GetItemSource `__, -which stores a ``LocalSource`` inside. On the other hand, ``a`` will not -have a source as it is an intermediate variable that only exists within -the fx graph. - -All these are defined in -`torch/_dynamo/source.py `__. -We can see the guard generated by ``GetItemSource`` in the following -example: - -.. code:: python - - import torch - - @torch.compile - def fn(x, l): - return x * len(l[0]) - - fn(torch.randn(8), ["Hi", "Hello"]) - -generates the following guards - -.. code:: python - - ___check_type_id(L['l'], 94439025877664) - len(L['l']) == 2 - ___check_type_id(L['l'][0], 94439025840192) - L['l'][0] == 'Hi' - ___check_type_id(L['l'][1], 94439025840192) - L['l'][1] == 'Hello' - -Here, we see the code generated by ``GetItemSource`` (``[0]`` and -``[1]``) wrapping a ``LocalSource`` (``L['l']``). - -At this point, with sources and guards, we are able to implement a -caching system to avoid recompilation without having to retrace every -time. We will discuss a bit more in detail this caching system in the -sequel. - -The attentive reader will have noticed that this does not explain yet -why we need to have such fine control over the Python interpreter as to -having to reimplement it. The examples of guards that we have shown -depend on the input objects, so we could still compute these before -executing the function. In other words, we could implement this guard -system on top of ``torch.jit.trace`` and get the same functionality with -much less effort… Enter symbolic shapes. - -Symbolic Shapes ---------------- - -Another point we discussed in the introduction is that Dynamo knows how -to trace integers. In order to implement this, we use a symbolic class -`torch.SymInt `__ -that acts like an ``int`` but it records all the operations performed on -it in the output FX graph. [4]_ We already saw this class in the introduction -when introducing symbolic integer tracing. - -Let us now discuss the three properties that define symbolic shape -tracing in Dynamo, and how to implement them. - -Static by default -^^^^^^^^^^^^^^^^^ - -Dynamo assumes that every integer, let that be an input or the shape of -a tensor, is static by default. In other words, no integers will be -traced on the first execution of a function. Then, only if it detects -that an integer or a shape changed value during the execution, it will -trace it and generate a graph generic on that variable. - -We already saw this behavior in the introduction using integers. Let us -now look at an example using shapes of tensors. - -.. code:: python - - import torch - - @torch.compile - def fn(a, b): - return a.shape[0] * a * b - - fn(torch.randn(4, 3), torch.randn(4, 3)) - fn(torch.randn(8, 3), torch.randn(8, 3)) - -Running this program with ``TORCH_LOGS=graph_code`` we see that these -two calls are traced as - -.. code:: python - - def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor): - mul = 4 * l_a_ - mul_1 = mul * l_b_ - return (mul_1,) - - def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor): - size = l_a_.size() - getitem = size[0] - mul = getitem * l_a_ - mul_1 = mul * l_b_ - return (mul_1,) - -In the first graph the shape is traced as a constant, but once it -changes, it traces it symbolically using a ``SymInt``\ s. In general, a -simpler way to see the shapes of the intermediary values is by running -the program with ``TORCH_LOGS=graph_sizes`` - -:: - - TRACED GRAPH TENSOR SIZES - ===== __compiled_fn_1 ===== - l_a_: (s0, 3) - l_a_ (concrete): (8, 3) - l_b_: (s0, 3) - l_b_ (concrete): (8, 3) - mul: (s0, 3) - mul (concrete): (8, 3) - mul_1: (s0, 3) - mul_1 (concrete): (8, 3) - -where we can see that the first dimension of the two tensor args is -dynamic, given that it is represented by the ``s0`` variable. - -We can find how Dynamo implements this by running ``TORCH_LOGS=guards`` - -.. code:: python - - # Guards first call - check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1]) - check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1]) - - # Guards second call - check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1]) - check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1]) - - L['b'].size()[0] == L['a'].size()[0] - 2 <= L['a'].size()[0] - -We see that on the first call, the guards check that the tensors have -some fixed sizes and strides. These guards fail in the second execution, -so it retraces. Since it was an ``int`` guard that failed, in this -second iteration it traces this ``int`` symbolically and it installs -more general guards on this more generic kernel. - -**Compilation performance tip**. If you know that a dimension will vary -in size, you can mark it as dynamic by calling -`torch._dynamo.mark_dynamic `__ -before calling ``torch.compile``. This will avoid the first compilation -with a static shape. There are other useful utility functions like -``maybe_mark_dynamic`` or ``mark_static``. You can also have all -integers and shapes traced by calling ``torch.compile(dynamic=True)``. -This is mostly useful for debugging purposes. - -0, 1 are always specialized -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Regardless of whether we mark a dimension as dynamic, if we pass an input -where that dimension is 0 or 1, Dynamo will trace it as non-dynamic and it -will generate a specific graph for it. This is the reason why in the example -above we find guards of the form ``2 <= L['a'].size()[0]``. - -There are several reasons for this choice. There are two particularly -important - A tensor is empty if and only if any of its dimensions is -zero - A tensor can only be contiguous if one of the strides is one - -This policy decision does NOT apply to plain Python ints; if we think a Python -int should be compiled dynamically, we won't specialize them by default; -instead, whether or not it gets specialized depends on its usage. - -Duck shaping -^^^^^^^^^^^^ - -Dynamo performs what we call “duck shaping”. If two dynamic integers -have the same value at trace time, we will assume that they are equal -and guard on it. Effectively, this means that rather than having two -symbols ``s0``, ``s1`` in the example above, we just unified them to -``s0`` and had the guard ``L['b'].size()[0] == L['a'].size()[0]``. This -enables performing fusions within the compiler while being able to -generate kernels that are generic enough. - -Guards on symbolic ints -^^^^^^^^^^^^^^^^^^^^^^^ - -We now understand how symbolic shapes are implemented at a high level -and the properties they have. Now, why is that symbolic shapes forced us -through the tricky route of getting control of the CPython interpreter? -Consider the following example: - -.. code:: python - - import torch - - @torch.compile(dynamic=True) - def fn(a): - if a.shape[0] * 2 < 16: - return a - else: - return a + 1 - - fn(torch.randn(8)) - -This code has a guard of the form ``2*L['a'].size()[0] >= 16``. This is -a non-trivial guard in terms of the inputs of the function, but it is -registered in the middle of the execution of the program. Even more so, -we cannot know this guard is needed until we see the ``if`` statement -conditional on a ``SymNodeVariable`` argument. Such conditions are -invisible to ``torch.jit.trace`` and require deep analysis of the python -code. - -**Debugging tip** Running this code with ``TORCH_LOGS=dynamo`` tells us -where this guard was added - -:: - - eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr) - -Placing a breakpoint there and looking at the backtrace is rather useful -to understand where a guard came from. - -Making Dynamo Complete: Graph Breaks ------------------------------------- - -With all the tools we have discussed, we have a tracer that can trace -PyTorch operations on tensors and integers and has a caching system that -knows when it can reuse a previously traced graph and when it needs to -retrace. All this executing arbitrary Python code! - -There is just one small issue with this. The statement “executing -arbitrary Python code” is perhaps a bit too general. Dynamo implements a -good part of Python, but does it implement the more complex parts, like -coroutines or async? Does it implement the whole Python standard -library? NumPy also has a Python API. Does ``torch.compile`` also -understand NumPy? and Django? [5]_ - -Python’s ecosystem is massive, and a good part of it is written in other -more performant languages like C++ or Rust, and it just exposes Python -bindings. There is no hope in Dynamo tracing through Python objects that -are implemented in C++. What can a tracer do when it finds an operation -that it does not understand? - -The usual way machine learning tracers handle this issue is by informing -the user that the operation they choked on and giving up tracing -altogether. This would pose a real usability issue in the case of -PyTorch, where its users are used to the flexibility it gives them. As a -real-world example the ``doctr_det_predictor`` model uses NumPy and the -``cv2`` library to `postprocess the model’s -result `__. - -Here is another place where having access to CPython is interesting. -Rather than erroring out, Dynamo can let CPython run that problematic -code! To do this, Dynamo generates at trace time one graph with all the -operations before the problematic code, and one with all the operations -after. [6]_ Then, at runtime, it will delegate to CPython to execute the -first graph, then the problematic code, and then the second graph. This -process of stopping the tracing and generating multiple graphs is called -a **graph break**. - -A small confession: I lied all throughout the introduction and the first -sections. Dynamo does not generate one graph, but **multiple graphs**! -For all practical purposes, starting retracing after a second graph can -be thought of as starting tracing a new function. The new graph after -the graph break will have its own guards, its new set of local -variables, and so on. - -To discuss how to implement graph breaks, we need to first revisit how -Dynamo interacts with CPython. Using PEP 523, CPython allows a user to -use their own frame evaluation mechanism. What we had not discussed is -that CPython also exposes its own frame evaluation for others to use. -Dynamo leverages this to let the fast CPython interpreter run the -compiled code. For a function without graph breaks, the whole tracing / -execution process of a program that calls the function 2 times with the -same arguments looks like this: - -1. In the first call to the function - - 1. Dynamo traces the function into an FX graph - - 1. The FX graph is compiled by the compiler (Inductor) into - efficient low-level code… but that’s a story for another day - - 2. It rewrites the bytecode of the function so that it simply calls - the compiled function - 3. It gives CPython this new bytecode and asks it to run it - [`here `__] - -2. In the second call to the function - - 1. It checks the guards from the first call against the new arguments - [`here `__]. - Since they are the same arguments as before, they pass - 2. It asks CPython to run the bytecode associated to those guards - [`here `__] - -This process on its own looks overly complicated. Why generate new -bytecode and ask CPython to run it rather than simply creating a C++ -binding to the compiled function and executing it? Well, this pattern -allows us to implement graph breaks! The bytecode generated by a graph -break has the following structure: - -1. Bytecode that executes the first graph -2. Bytecode that leaves the stack as it would be if CPython would have - executed the first graph. It also replays any modifications to local - or global variables that would be visible at this point -3. The bytecode that made Dynamo graph break -4. Bytecode that executes the second graph - -Let us see this in a simple example - -.. code:: python - - import torch - - @torch.compile - def fn(a): - b = a + 2 - print("Hi") - return b + a - - fn(torch.randn(4)) - -Running this with ``TORCH_LOGS=bytecode`` shows us the initial bytecode -and the modified bytecode - -.. code:: python - - MODIFIED BYTECODE fn script.py line 3 - 0 LOAD_GLOBAL 1 (__compiled_fn_0) - 2 LOAD_FAST 0 (a) - 4 CALL_FUNCTION 1 - 6 STORE_FAST 3 (graph_out_0) - 8 LOAD_GLOBAL 0 (print) - 10 LOAD_CONST 2 ('Hi') - 12 LOAD_FAST 3 (graph_out_0) - 14 LOAD_CONST 3 (0) - 16 BINARY_SUBSCR - 18 STORE_FAST 1 (b) - - 20 CALL_FUNCTION 1 - 22 LOAD_GLOBAL 2 (__resume_at_14_1) - 24 ROT_TWO - 26 LOAD_FAST 0 (a) - 28 LOAD_FAST 1 (b) - 30 CALL_FUNCTION 3 - 32 RETURN_VALUE - - MODIFIED BYTECODE resume_in_fn script.py line 6 - 0 LOAD_GLOBAL 1 (__compiled_fn_2) - 2 LOAD_FAST 2 (b) - 4 LOAD_FAST 1 (a) - 6 CALL_FUNCTION 2 - 8 UNPACK_SEQUENCE 1 - 10 RETURN_VALUE - -We can see that the modified bytecode is split into two functions, -``fn``, the original function, and a function called ``resume_in_fn``. -This second function is a function created by Dynamo to implement the -execution of the program starting at the graph break. This is often -called a `continuation -function `__. This -continuation function simply calls the second compiled function with the -right arguments. The code for the initial function is rewritten -implementing the strategy that we described before - -- L0-4. Call the compiled function (``a + 2``). -- L6. Store its result in a local variable called ``graph_out_0``. - ``graph_out_0`` is a tuple -- L8-18. Leave the stack as it would be at the point of the graph break -- L20. Execute the code that caused the graph break -- L22-32. Call the compiled continuation function (``a + b``) - -The code generation of the stack in Dynamo is delegated to -``VariableTracker`` subclasses. Every ``VariableTracker`` object in -Dynamo has a `reconstruct `__ -method that generates the necessary bytecode to create the python object -it represents on the stack. - -**Debugging tip**. Graph breaks hamper performance, and as such, it is -best to avoid them. Running a program with ``TORCH_LOGS=graph_breaks`` -is a great way to find how many graph breaks did our program hit. The -information it returns is in terms of ``VariableTracker`` objects, so -the debugging tips above are sometimes also helpful to figure out what -caused that graph break. - -Conclusion ----------- - -Dynamo is a complex piece of software. Once you sign up to implement a -CPython interpreter you know you are in for a ride. That being said, we -hope that this post helps demystify it a bit. - -Dynamo is (mostly) implemented in Python. We left plenty of links to the -pieces of the code that we discussed. We hope that reading those pieces -of code and grepping for the places that call them, or putting -breakpoints on them and looking at the call stack helps understanding -the rest of the code base. - -Of course, the best way to learn how a piece of software works is by -extending it. In this case, the best way is to have a look at the `open -dynamo issues on -github `__. -Many of them require very minor changes in the code, once you find where -you need to make those changes. - -Footnotes ---------- - -.. [1] In the literature, this is called a Directed Acyclical Graph (DAG). - -.. [2] All this binding code lives in ``torch/csrc/dynamo/eval_frame.c``. - -.. [3] In CPython lingo, the set of all these objects are called `a - frame `__. - -.. [4] There are also ``SymBool`` and ``SymFloat`` classes. The latter one - is not used all that much at the time of this writing. - -.. [5] Interestingly enough, it does understand NumPy code! Have a look at - `this blogpost `__ - and `the docs `__. - Now, this is just possible because we reimplemented NumPy using - PyTorch. Good luck implementing Django in PyTorch though… - -.. [6] Assuming there is just one piece of problematic code. If there are - more, Dynamo can split the code into as many graphs as it needs. diff --git a/docs/source/torch.compiler_dynamo_overview.md b/docs/source/torch.compiler_dynamo_overview.md new file mode 100644 index 00000000000000..6baf75058a8e43 --- /dev/null +++ b/docs/source/torch.compiler_dynamo_overview.md @@ -0,0 +1,333 @@ +# Dynamo Overview + +Before you read this section, read {ref}`torch.compiler_overview`. + +TorchDynamo (or simply Dynamo) is a Python-level Just-In-Time (JIT) compiler designed to make +unmodified PyTorch programs faster. Dynamo hooks into the frame evaluation +API in CPython ([PEP 523](https://peps.python.org/pep-0523/)) to +dynamically modify Python bytecode right before it is executed. It +rewrites Python bytecode to extract sequences of PyTorch +operations into an [FX Graph](https://pytorch.org/docs/stable/fx.html) +which is then compiled with a customizable backend. +It creates this FX Graph through bytecode analysis and is designed to +mix Python execution with compiled backends to get the best of both +worlds — usability and performance. + +Dynamo makes it easy to experiment with different compiler +backends to make PyTorch code faster with a single line decorator +`torch._dynamo.optimize()` which is wrapped for convenience by `torch.compile()` + +The following diagram demonstrates how PyTorch works with `torch.compile` +and without it: + +```{image} _static/img/dynamo/TorchDynamo.png +``` + +`TorchInductor` is one of the backends +supported by [Dynamo Graph](https://pytorch.org/docs/stable/fx.html) +into [Triton](https://github.com/openai/triton) for GPUs or +[C++/OpenMP](https://www.openmp.org/) for CPUs. We have a +[training performance dashboard](https://github.com/pytorch/torchdynamo/issues/681#issuecomment-1233828468) +that provides performance comparison for different training backends. You can read +more in the [TorchInductor post on PyTorch +dev-discuss](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747). + +For an in-depth overview, read the sections below, watch the deep-dive video, +and check out the dev-discuss topics. + +- [Dynamo deep-dive video](https://www.youtube.com/watch?v=egZB5Uxki0I) +- [dev-discuss topics](https://dev-discuss.pytorch.org/search?q=TorchDynamo%20order%3Alatest) +## Dynamo Internals + +**Author**: [Jason Ansel](https://github.com/jansel) and [Kaichao You](https://github.com/youkaichao) + +This section will go over some of the Dynamo internals and will +demonstrate how Dynamo works under the hood. + +### What is a guard? + +Dynamo operates just-in-time and specializes graphs based on +dynamic properties. Below is a basic example of how to use Dynamo. +One can decorate a function or a method using `torchdynamo.optimize` to enable +Dynamo optimization: + +```python +from typing import List +import torch +from torch import _dynamo as torchdynamo +def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + +@torchdynamo.optimize(my_compiler) +def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b +for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) +``` + +For example, the first graph above has the following +guards: + +``` +GUARDS: +hasattr(L['a'], '_dynamo_dynamic_indices') == False +hasattr(L['b'], '_dynamo_dynamic_indices') == False +utils_device.CURRENT_DEVICE == None +___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256) +check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1]) +check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1]) +``` + +If any of those guards fail, the graph will be recaptured and +recompiled. The interesting guard there is `check_tensor`, which +checks the following `torch.Tensor` properties: + +- Python class of the tensor (tensor subclassing, etc) +- dtype +- device +- requires_grad +- dispatch_key (with thread-local includes/excludes applied) +- ndim +- sizes\* +- strides\* + +The full specialization mode allows the backend compiler to assume an +entirely static graph. Unfortunately, most backends require this. +Operators which return dynamic shapes will trigger a graph break when +not in dynamic shape mode. + +### What is Dynamo doing? + +If you want to understand better what Dynamo is doing, you can run your code with: + +``` +TORCH_LOGS="+dynamo,guards,bytecode" +``` + +If you are not familiar with Python bytecode, you can add a decompiler hook +to decompile the bytecode into human-readable source code. One available +tool is [depyf](https://github.com/youkaichao/depyf). If you don't have +`depyf` already installed, run `pip install depyf`. Then, add the +following code to install decompilation hooks before you run any code. + +```python +import depyf +depyf.install() +``` + +This code triggers useful (but spammy) printouts. + +For example, the printouts for the first graph in the `toy_example` +are: + +``` +__compiled_fn_0 .1 +opcode name target args kwargs +------------- ------- ------------------------------------------------------ ---------------- -------- +placeholder a a () {} +placeholder b b () {} +call_function abs_1 (a,) {} +call_function add (abs_1, 1) {} +call_function truediv (a, add) {} +call_method sum_1 sum (b,) {} +call_function lt (sum_1, 0) {} +output output output ((truediv, lt),) {} +ORIGINAL BYTECODE toy_example example.py line 12 + 14 0 LOAD_FAST 0 (a) + 2 LOAD_GLOBAL 0 (torch) + 4 LOAD_METHOD 1 (abs) + 6 LOAD_FAST 0 (a) + 8 CALL_METHOD 1 + 10 LOAD_CONST 1 (1) + 12 BINARY_ADD + 14 BINARY_TRUE_DIVIDE + 16 STORE_FAST 2 (x) + 15 18 LOAD_FAST 1 (b) + 20 LOAD_METHOD 2 (sum) + 22 CALL_METHOD 0 + 24 LOAD_CONST 2 (0) + 26 COMPARE_OP 0 (<) + 28 POP_JUMP_IF_FALSE 19 (to 38) + 16 30 LOAD_FAST 1 (b) + 32 LOAD_CONST 3 (-1) + 34 BINARY_MULTIPLY + 36 STORE_FAST 1 (b) + 17 >> 38 LOAD_FAST 2 (x) + 40 LOAD_FAST 1 (b) + 42 BINARY_MULTIPLY + 44 RETURN_VALUE +MODIFIED BYTECODE toy_example example.py line 12 + 12 0 LOAD_GLOBAL 3 (__compiled_fn_0) + 2 LOAD_FAST 0 (a) + 4 LOAD_FAST 1 (b) + 6 CALL_FUNCTION 2 + 8 UNPACK_SEQUENCE 2 + 10 STORE_FAST 2 (x) + 12 POP_JUMP_IF_FALSE 12 (to 24) + 14 LOAD_GLOBAL 4 (__resume_at_30_1) + 16 LOAD_FAST 1 (b) + 18 LOAD_FAST 2 (x) + 20 CALL_FUNCTION 2 + 22 RETURN_VALUE + >> 24 LOAD_GLOBAL 5 (__resume_at_38_2) + 26 LOAD_FAST 1 (b) + 28 LOAD_FAST 2 (x) + 30 CALL_FUNCTION 2 + 32 RETURN_VALUE +possible source code: +def toy_example(a, b): + __temp_1 = __compiled_fn_0(a, b) + x = __temp_1[0] + if __temp_1[1]: + return __resume_at_30_1(b, x) + return __resume_at_38_2(b, x) +If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues. +``` + +At the top you can see the FX graph. +Next, you see the original bytecode of the function, followed by the +modified bytecode generated by Dynamo, and the decompiled source +code for reference. Finally, you see the guards which we covered above. + +In the modified bytecode, `__compiled_fn_0` is the return value of +`my_compiler()` (the compiled graph). `__resume_at_30_1` and +`__resume_at_38_2` are both generated continuation functions that pick +up execution after a graph break (at bytecode offsets 30 and 38). Each +of these functions take the form: + +``` +__resume_at_: + ... restore stack state if needed ... + JUMP_ABSOLUTE into toy_example + ... original bytecode of toy_example ... +``` + +By generating this `resume_at` function, we force the remainder of the +function to be executed in a new Python frame which recursively +triggers Dynamo to restart its capture once execution reaches that +point for the first time. + +### How to inspect artifacts generated by Dynamo? + +To inspect the artifacts generated by Dynamo, there is an API `torch._dynamo.eval_frame._debug_get_cache_entry_list` that retrieves compiled code and guards out of a function's `__code__` object. A compiled function can have several cache entries, and each cache entry consists a generated function to check guards, and a `types.CodeType` object to keep the code to be executed if the guarding conditions are satisfied. + +```python +from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn +cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example)) +cache_entry = cache_entries[0] +guard, code = cache_entry.check_fn, cache_entry.code +# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered. +import dis +dis.dis(guard) +dis.dis(code) +``` + +If you know Python bytecode, you can understand the above output. + +For the guard function, there is no need to inspect the bytecode. We can directly access its guarding conditions: + +```python +for code_part in guard.code_parts: + print(code_part) +``` + +The output is: + +``` +___guarded_code.valid +___check_global_state() +hasattr(L['a'], '_dynamo_dynamic_indices') == False +hasattr(L['b'], '_dynamo_dynamic_indices') == False +utils_device.CURRENT_DEVICE == None +___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528) +___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names) +``` + +Only when all the conditions are satisfied, the guard function returns true, and the compiled code is executed. + +For the compiled code, we cannot directly access its source but have to decompile it. + +```python +from depyf import decompile +print(decompile(code)) +``` + +The output is: + +``` +def toy_example(a, b): + __temp_1 = __compiled_fn_0(a, b) + x = __temp_1[0] + if __temp_1[1]: + return __resume_at_30_1(b, x) + return __resume_at_38_2(b, x) +``` + +Some names referenced in the code are: + +- Compiled functions, stored in the global namespace of the module containing the original function `toy_example`. These include names like `__compiled_fn_0` / `__resume_at_30_1` / `__resume_at_38_2`. +- Closure variables used for checking guards. The names can be accessed from `guard.__code__.co_freevars`, and the values are stored in `guard.__closure__`. These include names like `___guarded_code` / `___is_grad_enabled` / `___are_deterministic_algorithms_enabled` / `___is_torch_function_enabled` / `utils_device` / `___check_tensors` / `tensor_check_names`. +- Argument `L` of the `guard` function. This is a dict mapping the name of arguments of `toy_example` to its values. This is only available when the function is called, where the frame evaluation API comes into play. In short, `L` is a `dict` with structure of `{'a': value_a, 'b': value_b}`. Therefore, you can see the code uses `L['a']` to refer to the input variable `a`. + +The graph break is shown in the code of compiled `toy_example`, where we have to use Python interpreter to select the following graph to execute. + +Note that we pass a simple `my_compiler` function as the backend compiler, therefore the subgraph code `__resume_at_38_2`, `__resume_at_30_1`, and `__compiled_fn_0` remain Python code. This can also be inspected (please ignore the function name, and only use the function signature and function body code): + +```python +print("source code of __compiled_fn_0:") +print(innermost_fn(__compiled_fn_0).__self__.code) +print("=" * 60) +print("source code of __resume_at_30_1:") +print(decompile(__resume_at_30_1)) +print("=" * 60) +print("source code of __resume_at_38_2:") +print(decompile(__resume_at_38_2)) +``` + +``` +source code of __compiled_fn_0: +def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor): + l_a_ = L_a_ + l_b_ = L_b_ + abs_1 = torch.abs(l_a_) + add = abs_1 + 1; abs_1 = None + truediv = l_a_ / add; l_a_ = add = None + sum_1 = l_b_.sum(); l_b_ = None + lt = sum_1 < 0; sum_1 = None + return (truediv, lt) +# To see more debug info, please use ``graph_module.print_readable()`` +============================================================ +source code of __resume_at_30_1: +def (b, x): + b = b * -1 + return x * b +============================================================ +source code of __resume_at_38_2: +def (b, x): + return x * b +``` + +However, if we use other backends like the built-in `inductor`, the subgraph code will be compiled CUDA kernels for GPU or C++ code for CPU. + +To summarize, the compiled code is conceptually equivalent to the code below: + +```python +def compiled_example(a, b): + L = {'a': a, 'b': b} + for guard, code in get_cache_entries(): + if guard(L): + return code(a, b) + recompile_and_add_another_cache_entry() +``` + +The following diagram demonstrates how `torch.compile` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, and compiles these graphs into optimized functions, then assembles them into a new function, which is functionally equivalent to the user-written code but optimized to have a good computation speed. + +```{image} _static/img/dynamo/flowchart.jpg +``` + +To learn more about how all this is implemented internally, see {ref}`torch.compiler_dynamo_deepdive`. \ No newline at end of file diff --git a/docs/source/torch.compiler_dynamo_overview.rst b/docs/source/torch.compiler_dynamo_overview.rst deleted file mode 100644 index fc7a8c5c292e93..00000000000000 --- a/docs/source/torch.compiler_dynamo_overview.rst +++ /dev/null @@ -1,350 +0,0 @@ -Dynamo Overview -=============== - -Before you read this section, read :ref:`torch.compiler_overview`. - -TorchDynamo (or simply Dynamo) is a Python-level Just-In-Time (JIT) compiler designed to make -unmodified PyTorch programs faster. Dynamo hooks into the frame evaluation -API in CPython (`PEP 523 `__) to -dynamically modify Python bytecode right before it is executed. It -rewrites Python bytecode to extract sequences of PyTorch -operations into an `FX Graph `__ -which is then compiled with a customizable backend. -It creates this FX Graph through bytecode analysis and is designed to -mix Python execution with compiled backends to get the best of both -worlds — usability and performance. - -Dynamo makes it easy to experiment with different compiler -backends to make PyTorch code faster with a single line decorator -``torch._dynamo.optimize()`` which is wrapped for convenience by ``torch.compile()`` - -The following diagram demonstrates how PyTorch works with ``torch.compile`` -and without it: - -.. image:: _static/img/dynamo/TorchDynamo.png - -`TorchInductor` is one of the backends -supported by `Dynamo Graph `__ -into `Triton `__ for GPUs or -`C++/OpenMP `__ for CPUs. We have a -`training performance dashboard `__ -that provides performance comparison for different training backends. You can read -more in the `TorchInductor post on PyTorch -dev-discuss `__. - -For an in-depth overview, read the sections below, watch the deep-dive video, -and check out the dev-discuss topics. - - * `Dynamo deep-dive video `__ - * `dev-discuss topics `__ - -Dynamo Internals -~~~~~~~~~~~~~~~~ -**Author**: `Jason Ansel `_ and `Kaichao You `_ - -This section will go over some of the Dynamo internals and will -demonstrate how Dynamo works under the hood. - -What is a guard? ----------------- - -Dynamo operates just-in-time and specializes graphs based on -dynamic properties. Below is a basic example of how to use Dynamo. -One can decorate a function or a method using ``torchdynamo.optimize`` to enable -Dynamo optimization: - -.. code-block:: python - - from typing import List - import torch - from torch import _dynamo as torchdynamo - def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - print("my_compiler() called with FX graph:") - gm.graph.print_tabular() - return gm.forward # return a python callable - - @torchdynamo.optimize(my_compiler) - def toy_example(a, b): - x = a / (torch.abs(a) + 1) - if b.sum() < 0: - b = b * -1 - return x * b - for _ in range(100): - toy_example(torch.randn(10), torch.randn(10)) - -For example, the first graph above has the following -guards: - -:: - - GUARDS: - hasattr(L['a'], '_dynamo_dynamic_indices') == False - hasattr(L['b'], '_dynamo_dynamic_indices') == False - utils_device.CURRENT_DEVICE == None - ___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256) - check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1]) - check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1]) - -If any of those guards fail, the graph will be recaptured and -recompiled. The interesting guard there is ``check_tensor``, which -checks the following ``torch.Tensor`` properties: - -- Python class of the tensor (tensor subclassing, etc) -- dtype -- device -- requires_grad -- dispatch_key (with thread-local includes/excludes applied) -- ndim -- sizes\* -- strides\* - -The full specialization mode allows the backend compiler to assume an -entirely static graph. Unfortunately, most backends require this. -Operators which return dynamic shapes will trigger a graph break when -not in dynamic shape mode. - -What is Dynamo doing? ---------------------- - -If you want to understand better what Dynamo is doing, you can run your code with: - -:: - - TORCH_LOGS="+dynamo,guards,bytecode" - -If you are not familiar with Python bytecode, you can add a decompiler hook -to decompile the bytecode into human-readable source code. One available -tool is `depyf `__. If you don't have -``depyf`` already installed, run ``pip install depyf``. Then, add the -following code to install decompilation hooks before you run any code. - -.. code-block:: python - - import depyf - depyf.install() - -This code triggers useful (but spammy) printouts. - -For example, the printouts for the first graph in the ``toy_example`` -are: - -:: - - __compiled_fn_0 .1 - opcode name target args kwargs - ------------- ------- ------------------------------------------------------ ---------------- -------- - placeholder a a () {} - placeholder b b () {} - call_function abs_1 (a,) {} - call_function add (abs_1, 1) {} - call_function truediv (a, add) {} - call_method sum_1 sum (b,) {} - call_function lt (sum_1, 0) {} - output output output ((truediv, lt),) {} - - ORIGINAL BYTECODE toy_example example.py line 12 - 14 0 LOAD_FAST 0 (a) - 2 LOAD_GLOBAL 0 (torch) - 4 LOAD_METHOD 1 (abs) - 6 LOAD_FAST 0 (a) - 8 CALL_METHOD 1 - 10 LOAD_CONST 1 (1) - 12 BINARY_ADD - 14 BINARY_TRUE_DIVIDE - 16 STORE_FAST 2 (x) - - 15 18 LOAD_FAST 1 (b) - 20 LOAD_METHOD 2 (sum) - 22 CALL_METHOD 0 - 24 LOAD_CONST 2 (0) - 26 COMPARE_OP 0 (<) - 28 POP_JUMP_IF_FALSE 19 (to 38) - - 16 30 LOAD_FAST 1 (b) - 32 LOAD_CONST 3 (-1) - 34 BINARY_MULTIPLY - 36 STORE_FAST 1 (b) - - 17 >> 38 LOAD_FAST 2 (x) - 40 LOAD_FAST 1 (b) - 42 BINARY_MULTIPLY - 44 RETURN_VALUE - - - MODIFIED BYTECODE toy_example example.py line 12 - 12 0 LOAD_GLOBAL 3 (__compiled_fn_0) - 2 LOAD_FAST 0 (a) - 4 LOAD_FAST 1 (b) - 6 CALL_FUNCTION 2 - 8 UNPACK_SEQUENCE 2 - 10 STORE_FAST 2 (x) - 12 POP_JUMP_IF_FALSE 12 (to 24) - 14 LOAD_GLOBAL 4 (__resume_at_30_1) - 16 LOAD_FAST 1 (b) - 18 LOAD_FAST 2 (x) - 20 CALL_FUNCTION 2 - 22 RETURN_VALUE - >> 24 LOAD_GLOBAL 5 (__resume_at_38_2) - 26 LOAD_FAST 1 (b) - 28 LOAD_FAST 2 (x) - 30 CALL_FUNCTION 2 - 32 RETURN_VALUE - - - possible source code: - def toy_example(a, b): - __temp_1 = __compiled_fn_0(a, b) - x = __temp_1[0] - if __temp_1[1]: - return __resume_at_30_1(b, x) - return __resume_at_38_2(b, x) - - If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues. - -At the top you can see the FX graph. -Next, you see the original bytecode of the function, followed by the -modified bytecode generated by Dynamo, and the decompiled source -code for reference. Finally, you see the guards which we covered above. - -In the modified bytecode, ``__compiled_fn_0`` is the return value of -``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and -``__resume_at_38_2`` are both generated continuation functions that pick -up execution after a graph break (at bytecode offsets 30 and 38). Each -of these functions take the form: - -:: - - __resume_at_: - ... restore stack state if needed ... - JUMP_ABSOLUTE into toy_example - ... original bytecode of toy_example ... - -By generating this ``resume_at`` function, we force the remainder of the -function to be executed in a new Python frame which recursively -triggers Dynamo to restart its capture once execution reaches that -point for the first time. - -How to inspect artifacts generated by Dynamo? ---------------------------------------------- - -To inspect the artifacts generated by Dynamo, there is an API ``torch._dynamo.eval_frame._debug_get_cache_entry_list`` that retrieves compiled code and guards out of a function's ``__code__`` object. A compiled function can have several cache entries, and each cache entry consists a generated function to check guards, and a ``types.CodeType`` object to keep the code to be executed if the guarding conditions are satisfied. - -.. code-block:: python - - from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn - cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example)) - cache_entry = cache_entries[0] - guard, code = cache_entry.check_fn, cache_entry.code - # the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered. - import dis - dis.dis(guard) - dis.dis(code) - -If you know Python bytecode, you can understand the above output. - -For the guard function, there is no need to inspect the bytecode. We can directly access its guarding conditions: - -.. code-block:: python - - for code_part in guard.code_parts: - print(code_part) - -The output is: - -:: - - ___guarded_code.valid - ___check_global_state() - hasattr(L['a'], '_dynamo_dynamic_indices') == False - hasattr(L['b'], '_dynamo_dynamic_indices') == False - utils_device.CURRENT_DEVICE == None - ___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528) - ___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names) - -Only when all the conditions are satisfied, the guard function returns true, and the compiled code is executed. - -For the compiled code, we cannot directly access its source but have to decompile it. - -.. code-block:: python - - from depyf import decompile - print(decompile(code)) - -The output is: - -:: - - def toy_example(a, b): - __temp_1 = __compiled_fn_0(a, b) - x = __temp_1[0] - if __temp_1[1]: - return __resume_at_30_1(b, x) - return __resume_at_38_2(b, x) - -Some names referenced in the code are: - -- Compiled functions, stored in the global namespace of the module containing the original function ``toy_example``. These include names like ``__compiled_fn_0`` / ``__resume_at_30_1`` / ``__resume_at_38_2``. - -- Closure variables used for checking guards. The names can be accessed from ``guard.__code__.co_freevars``, and the values are stored in ``guard.__closure__``. These include names like ``___guarded_code`` / ``___is_grad_enabled`` / ``___are_deterministic_algorithms_enabled`` / ``___is_torch_function_enabled`` / ``utils_device`` / ``___check_tensors`` / ``tensor_check_names``. - -- Argument ``L`` of the ``guard`` function. This is a dict mapping the name of arguments of ``toy_example`` to its values. This is only available when the function is called, where the frame evaluation API comes into play. In short, ``L`` is a ``dict`` with structure of ``{'a': value_a, 'b': value_b}``. Therefore, you can see the code uses ``L['a']`` to refer to the input variable ``a``. - -The graph break is shown in the code of compiled ``toy_example``, where we have to use Python interpreter to select the following graph to execute. - -Note that we pass a simple ``my_compiler`` function as the backend compiler, therefore the subgraph code ``__resume_at_38_2``, ``__resume_at_30_1``, and ``__compiled_fn_0`` remain Python code. This can also be inspected (please ignore the function name, and only use the function signature and function body code): - -.. code-block:: python - - print("source code of __compiled_fn_0:") - print(innermost_fn(__compiled_fn_0).__self__.code) - print("=" * 60) - print("source code of __resume_at_30_1:") - print(decompile(__resume_at_30_1)) - print("=" * 60) - print("source code of __resume_at_38_2:") - print(decompile(__resume_at_38_2)) - -:: - - source code of __compiled_fn_0: - - def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor): - l_a_ = L_a_ - l_b_ = L_b_ - abs_1 = torch.abs(l_a_) - add = abs_1 + 1; abs_1 = None - truediv = l_a_ / add; l_a_ = add = None - sum_1 = l_b_.sum(); l_b_ = None - lt = sum_1 < 0; sum_1 = None - return (truediv, lt) - - # To see more debug info, please use ``graph_module.print_readable()`` - ============================================================ - source code of __resume_at_30_1: - def (b, x): - b = b * -1 - return x * b - - ============================================================ - source code of __resume_at_38_2: - def (b, x): - return x * b - -However, if we use other backends like the built-in ``inductor``, the subgraph code will be compiled CUDA kernels for GPU or C++ code for CPU. - -To summarize, the compiled code is conceptually equivalent to the code below: - -.. code-block:: python - - def compiled_example(a, b): - L = {'a': a, 'b': b} - for guard, code in get_cache_entries(): - if guard(L): - return code(a, b) - recompile_and_add_another_cache_entry() - -The following diagram demonstrates how ``torch.compile`` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, and compiles these graphs into optimized functions, then assembles them into a new function, which is functionally equivalent to the user-written code but optimized to have a good computation speed. - -.. image:: _static/img/dynamo/flowchart.jpg - -To learn more about how all this is implemented internally, see :ref:`torch.compiler_dynamo_deepdive`. diff --git a/docs/source/torch.compiler_fake_tensor.md b/docs/source/torch.compiler_fake_tensor.md new file mode 100644 index 00000000000000..b53b1d87ad19e5 --- /dev/null +++ b/docs/source/torch.compiler_fake_tensor.md @@ -0,0 +1,157 @@ +# Fake tensor + +Code: [fake_tensor.py](https://github.com/pytorch/pytorch/blob/db4572dbf18f1cf50cf662547e272d3117063747/torch/_subclasses/fake_tensor.py) + +## Motivation + +When doing Dynamo symbolic evaluation and compiler passes, we often want to be able to run tensor operations to understand what output sizes/dtypes/devices are, without actually running those operations (or trashing preexisting tensors), which would be slower (if you're doing a lot of compute) and take a lot of memory (it's bad if your compiler needs to use GPU memory while you are compiling the program). A fake tensor is like a real tensor in all respects, except that it doesn't actually have any data. For example, when we do Dynamo tracing, we need to trace through user Tensor code and answer questions about intermediates (e.g., if a user does a conditional on an intermediate tensor). Without fake tensor, we would not have accurate information for these queries. + +Similarly, suppose you want to store metadata for a tensor, e.g., on an FX IR node (meta['val']). You can instead store a fake tensor directly on the node, which will give you all the metadata you need for the tensor, including subtle stuff that you probably wouldn't have handled (e.g., aliasing relationships). + +## Related work + +- A meta tensor is a tensor with device='meta'. This is actually a lot of what you want for fake tensor, but meta tensors don't model devices, and sometimes stride behavior varies depending on your device, so fake tensors really can get a lot more accurate info this way. Also, meta tensors are "global" (they exist on their own, similar to how a CPU/CUDA tensor exist on their own), whereas fake tensors are scoped to a FakeTensorMode. +- A tensor subclass lets you subclass torch.Tensor and customize their behavior. Fake tensors are implemented as a tensor subclass; that means almost all of its implementation lives in Python! For more simple examples of tensor subclasses check out [subclass_zoo](https://github.com/albanD/subclass_zoo/). +- Dynamic shapes allow you to create tensors with symbolic sizes rather than only concrete sizes, and propagate these sizes symbolically through operations. Dynamic shapes maintain state in a ShapeEnv, which is always associated with a FakeTensorMode (so fake tensors also are responsible for managing symbolic sizes.) In general, whenever we compile a subgraph with PT2, there is a tracing context associated with this compilation, which contains, among other things, a FakeTensorMode and (possibly) a ShapeEnv. + +## Overall architecture + +All fake tensors are associated with a FakeTensorMode. Because fake tensor's primary use case is to do analysis on real tensors, the general workflow is you have a bunch of real tensors, you allocate a FakeTensorMode, and then you use from_real_tensor to convert all those real tensors into fake tensors, and then you do things to the fake tensors. In particular, the FakeTensorMode maintains a memo table persistently mapping tensors (and storages) to the same storages. If you fakeify the same tensor multiple times, you will get the same fake tensor; if you fakeify two tensors which alias each other, you will get two fake tensors which alias the same fake storage. FakeTensors are tensor subclasses, so if you do operations on them, you'll automatically get a fake tensor, but in general you will want to do operations on fake tensors (e.g., if you're running an FX pass) with the FakeTensorMode active; what a tensor operation will do is automatically turn on the fake tensor mode and try again. + +A fake tensor is represented as a \_\_torch_dispatch\_\_ tensor subclass of a meta tensor. This means under the hood, fake tensors are meta device tensors; they then use extra extensibility hooks, specifically dispatch_device, to lie about what the actual device of the tensor is. This was one of the more error-prone parts of fake tensors in the early days: sometimes, fake tensors were too good at lying about being CPU/CUDA whatever, and you'd end up with a CPU kernel getting called with a fake tensor trying to dereference the data pointer, which obviously won't work. If you are segfaulting in fake tensor code, this is the first thing you should check: is the C++ backtrace in a CPU kernel (unexpected!) or a meta kernel (expected!) A meta kernel is like a real kernel, but all it does is allocate the outputs, it doesn't do any data compute. + +A tensor subclass has to define how to implement various operations. Here is the general fake tensor recipe: + +- Run the meta kernel on the input fake tensors, reinterpreting them as meta tensors. This is done via a magic context manager in_kernel_invocation_manager which instructs all of PyTorch to view fake tensors as their underlying meta tensors, rather than "unwrapping" fake tensors into meta tensors (a fake tensor is a meta tensor). Fake tensors are represented this way to avoid having to keep two sets of metadata in sync (the meta tensor's metadata, and the fake tensor's metadata); the "is a" relationship ensures there is only one canonical copy of metadata. +- If you're a factory function, you'll instead call the underlying factory function with device='meta'. +- Convert the resulting meta tensor into a fake tensor, computing what the output device of the tensor should be (this is usually trivial, but sometimes it is not, e.g., cpu scalar promotion, or device-converting operations.) + +## API: the important bits + +Non-PT2 usage (check out test/test_fake_tensor.py for more examples): + +```python +# Create a fake mode +from torch._subclasses.fake_tensor import FakeTensorMode +fake_mode = FakeTensorMode() +converter = fake_mode.fake_tensor_converter +# Fakeify some real tensors +fake_x = converter.from_real_tensor(fake_mode, x) +with fake_mode: + # Do some operations on the fake tensors + fake_y = fake_x * 2 + # Factory operations automatically get fakeified in the context manager + fake_z = torch.empty(20) +``` + +Q: Why do you have real tensors as inputs? + +A: In a PT2 context, this is because you typically are compiling just-in-time, so for all the inputs to a graph you're compiling, you already have the "real" inputs, because you're compiling while you're executing the program. + +PT2 pre-AOTAutograd usage (this is unusual, you probably don't want to do this): + +```python +# Fake mode is not enabled! +from torch._guards import detect_fake_mode +fake_mode = detect_fake_mode(args) +# if fake_mode isn't None +converter = fake_mode.fake_tensor_converter +fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args] +with fake_mode: + ... # do stuff with the fake args, if needed ... +``` + +detect_fake_mode will search a number of locations to try to find "the" fake tensor mode associated with the lifecycle. Typically it will be pulled off of the tracing context. + +PT2 post-AOTAutograd usage: + +```python +# Fake mode is enabled! example_inputs is typically fake already +# TODO: we probably want to change this +# Still do this to access fake mode +fake_mode = detect_fake_mode(example_inputs) +# But in general you don't have to turn it on +``` + +Other useful stuff: + +```python +from torch._subclasses.fake_tensor import unset_fake_temporarily +with unset_fake_temporarily(): + ... # fake mode is disabled here, you can do real tensor compute +``` + +When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode. + +```python +import FakeTensorProp from torch.fx.passes.fake_tensor_prop +gm: GraphModule +real_inputs: List[Tensor] +FakeTensorProp(gm).propagate(*real_inputs) +# This will populate meta['val'] on all the FX nodes with a fake tensor +# or if you have a preexisting fake mode, you should use it +FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs) +# There is also propagate_dont_convert_inputs if your inputs are already fake +fake_inputs: List[FakeTensor] +FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs) +``` + +## Details + +Auto-convert or not? +Originally, FakeTensorMode would not automatically fakeify real tensors if you tried to do compute on them inside a FakeTensorMode region. The motivation behind this was to prevent the following footgun: + +```python +with FakeTensorMode(): + real_tensor.t_() +``` + +What should this code do? It would be surprising if we actually modified the metadata on the real tensor. But at the same time, there isn't any obvious opportunity to create a FakeTensor. So we conservatively decided to make this raise an error: "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. Please convert all Tensors to FakeTensors first." + +This error is pretty annoying in practice. For example, suppose you have a real nn.Module and you want to feed fake tensors through it. You need to somehow fakeify the nn.Module. This motivated FakeCopyMode. + +Eventually, we gave up and added automatic fakeification. However, this is still not yet enabled by default in many uses of FakeTensorMode. + +Metadata mutation on fake tensor +If you have a fake tensor, and you t\_() it, the metadata on the fake tensor changes. This is reasonable on its face, but sometimes you want to also store fake tensors as metadata on FX nodes; mutating a fake tensor is bad because this will invalidate old metadata! + +In fact, there is a fundamental tension here, which is that fake tensors maintain extremely accurate metadata about tensors, up to and including object identity. If object metadata changes over time in an FX graph, there is not actually any way to represent this change over time. Most of the time, our serious FX analyses are done on functionalized graphs, which don't have this, but occasionally you need to do an analysis on a non-functionalized graph. Maybe it was a mistake to put fake tensor in meta['val'] + +## About the tensor subclass + +Fake tensor uses both a subclass and a mode tensor subclass pattern, where FakeTensor.\_\_torch_dispatch\_\_ enables the FakeTensorMode associated with the fake tensor, and then redispatches (relying on FakeTensorMode to do the heavy lifting). If fake tensor operations get a subclass argument it doesn't recognize, it will return NotImplemented, giving the other subclass a chance to run first (hopefully desugaring into plain tensor operations), before it tries again. This can cause infinite loops. + +## How is each individual operator implemented? + +Unfortunately, there is a pretty complicated set of places where any given operator may be implemented. Some important cases to know about: + +- Tensor subclasses support limited constant propagation if the number of elements is very small (this helps deal with some cases where we immediately call item() on such tensors.) +- We have some fastpath implementations for certain operators, which are done entirely in fake tensor, for performance reasons. +- If you use @custom_op to generate a custom tensor, these will register impl_abstract directly to fake tensor. +- Fake tensor itself has some hardcoded special cases for device-converting operations. +- If there is no meta implementation nor any decomposition, we will generate real zero-filled tensors and attempt to run the operator directly to find out what the results will be. This can cause segfaults if the operator attempts to do indexing with data, so we don't turn this on by default for custom ops. + +## How does the converter work? + +Because fake tensors are used in situations that are very sensitive to the exact properties of a tensor, fake tensors do conversion very carefully, preserving leaf-ness, requires_grad'ness, aliasing, and a whole host of other properties. The bulk of the heavy lifting is in MetaConverter. + +## Performance characteristics + +You would think fake tensors are fast because they don't do any tensor compute. But at small tensor sizes we are actually entirely overhead bound, and, well, fake tensor is in Python, and we often do a LOT of work to do a single tensor operation (because they are implemented as decompositions). So fake tensors are actually pretty slow in practice, especially when symbolic shapes are involved. There are two important fastpaths we currently have in fake tensor that make a big difference in practice: + +- Pointwise ops don't go through PrimTorch decomps, instead we've hand-coded their propagation rule. +- If possible, we should. + +## Fake tensor of fake tensor? + +There is interest in sending fake tensors as user inputs into the PT2 stack, which would imply we would need to be able to create a fake tensor of a fake tensor. This isn't really supported right now, but maybe it would not be too difficult to do. + +## Interaction with dynamic shapes + +Every FakeTensorMode contains a ShapeEnv, which tracks all symbolic shapes information. Their lifetimes are typically tied: they live and die together. + +Because FakeTensorMode has a ShapeEnv (but meta implementations do not), meta functions that are data-dependent and require allocating an unbacked SymInt live in fake tensor. Fake tensor also takes care of memoizing unbacked SymInts, so that, e.g., if you call nonzero() on the same fake tensor twice, you get the same symbolic size. + +## Other resources + +[Colab Tutorial On Using FakeTensor To Determine Max Batch Size](https://colab.research.google.com/drive/1zjAisRrc8R6uixKsrs1DRm3lwz5MWN68) diff --git a/docs/source/torch.compiler_fake_tensor.rst b/docs/source/torch.compiler_fake_tensor.rst deleted file mode 100644 index 41d9b25d662674..00000000000000 --- a/docs/source/torch.compiler_fake_tensor.rst +++ /dev/null @@ -1,176 +0,0 @@ -Fake tensor -=========== - -Code: `fake_tensor.py `_ - -Motivation ----------- - -When doing Dynamo symbolic evaluation and compiler passes, we often want to be able to run tensor operations to understand what output sizes/dtypes/devices are, without actually running those operations (or trashing preexisting tensors), which would be slower (if you're doing a lot of compute) and take a lot of memory (it's bad if your compiler needs to use GPU memory while you are compiling the program). A fake tensor is like a real tensor in all respects, except that it doesn't actually have any data. For example, when we do Dynamo tracing, we need to trace through user Tensor code and answer questions about intermediates (e.g., if a user does a conditional on an intermediate tensor). Without fake tensor, we would not have accurate information for these queries. - -Similarly, suppose you want to store metadata for a tensor, e.g., on an FX IR node (meta['val']). You can instead store a fake tensor directly on the node, which will give you all the metadata you need for the tensor, including subtle stuff that you probably wouldn't have handled (e.g., aliasing relationships). - -Related work ------------- - -- A meta tensor is a tensor with device='meta'. This is actually a lot of what you want for fake tensor, but meta tensors don't model devices, and sometimes stride behavior varies depending on your device, so fake tensors really can get a lot more accurate info this way. Also, meta tensors are "global" (they exist on their own, similar to how a CPU/CUDA tensor exist on their own), whereas fake tensors are scoped to a FakeTensorMode. - -- A tensor subclass lets you subclass torch.Tensor and customize their behavior. Fake tensors are implemented as a tensor subclass; that means almost all of its implementation lives in Python! For more simple examples of tensor subclasses check out `subclass_zoo `_. - -- Dynamic shapes allow you to create tensors with symbolic sizes rather than only concrete sizes, and propagate these sizes symbolically through operations. Dynamic shapes maintain state in a ShapeEnv, which is always associated with a FakeTensorMode (so fake tensors also are responsible for managing symbolic sizes.) In general, whenever we compile a subgraph with PT2, there is a tracing context associated with this compilation, which contains, among other things, a FakeTensorMode and (possibly) a ShapeEnv. - -Overall architecture --------------------- - -All fake tensors are associated with a FakeTensorMode. Because fake tensor's primary use case is to do analysis on real tensors, the general workflow is you have a bunch of real tensors, you allocate a FakeTensorMode, and then you use from_real_tensor to convert all those real tensors into fake tensors, and then you do things to the fake tensors. In particular, the FakeTensorMode maintains a memo table persistently mapping tensors (and storages) to the same storages. If you fakeify the same tensor multiple times, you will get the same fake tensor; if you fakeify two tensors which alias each other, you will get two fake tensors which alias the same fake storage. FakeTensors are tensor subclasses, so if you do operations on them, you'll automatically get a fake tensor, but in general you will want to do operations on fake tensors (e.g., if you're running an FX pass) with the FakeTensorMode active; what a tensor operation will do is automatically turn on the fake tensor mode and try again. - -A fake tensor is represented as a __torch_dispatch__ tensor subclass of a meta tensor. This means under the hood, fake tensors are meta device tensors; they then use extra extensibility hooks, specifically dispatch_device, to lie about what the actual device of the tensor is. This was one of the more error-prone parts of fake tensors in the early days: sometimes, fake tensors were too good at lying about being CPU/CUDA whatever, and you'd end up with a CPU kernel getting called with a fake tensor trying to dereference the data pointer, which obviously won't work. If you are segfaulting in fake tensor code, this is the first thing you should check: is the C++ backtrace in a CPU kernel (unexpected!) or a meta kernel (expected!) A meta kernel is like a real kernel, but all it does is allocate the outputs, it doesn't do any data compute. - -A tensor subclass has to define how to implement various operations. Here is the general fake tensor recipe: - -- Run the meta kernel on the input fake tensors, reinterpreting them as meta tensors. This is done via a magic context manager in_kernel_invocation_manager which instructs all of PyTorch to view fake tensors as their underlying meta tensors, rather than "unwrapping" fake tensors into meta tensors (a fake tensor is a meta tensor). Fake tensors are represented this way to avoid having to keep two sets of metadata in sync (the meta tensor's metadata, and the fake tensor's metadata); the "is a" relationship ensures there is only one canonical copy of metadata. - -- If you're a factory function, you'll instead call the underlying factory function with device='meta'. - -- Convert the resulting meta tensor into a fake tensor, computing what the output device of the tensor should be (this is usually trivial, but sometimes it is not, e.g., cpu scalar promotion, or device-converting operations.) - -API: the important bits ------------------------ - -Non-PT2 usage (check out test/test_fake_tensor.py for more examples): - -.. code:: python - - # Create a fake mode - from torch._subclasses.fake_tensor import FakeTensorMode - fake_mode = FakeTensorMode() - converter = fake_mode.fake_tensor_converter - # Fakeify some real tensors - fake_x = converter.from_real_tensor(fake_mode, x) - with fake_mode: - # Do some operations on the fake tensors - fake_y = fake_x * 2 - # Factory operations automatically get fakeified in the context manager - fake_z = torch.empty(20) - -Q: Why do you have real tensors as inputs? - -A: In a PT2 context, this is because you typically are compiling just-in-time, so for all the inputs to a graph you're compiling, you already have the "real" inputs, because you're compiling while you're executing the program. - -PT2 pre-AOTAutograd usage (this is unusual, you probably don't want to do this): - -.. code:: python - - - # Fake mode is not enabled! - from torch._guards import detect_fake_mode - fake_mode = detect_fake_mode(args) - # if fake_mode isn't None - converter = fake_mode.fake_tensor_converter - fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args] - with fake_mode: - ... # do stuff with the fake args, if needed ... - -detect_fake_mode will search a number of locations to try to find "the" fake tensor mode associated with the lifecycle. Typically it will be pulled off of the tracing context. - -PT2 post-AOTAutograd usage: - -.. code:: python - - - # Fake mode is enabled! example_inputs is typically fake already - # TODO: we probably want to change this - # Still do this to access fake mode - fake_mode = detect_fake_mode(example_inputs) - # But in general you don't have to turn it on - -Other useful stuff: - -.. code:: python - - from torch._subclasses.fake_tensor import unset_fake_temporarily - with unset_fake_temporarily(): - ... # fake mode is disabled here, you can do real tensor compute - -When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode. - -.. code:: python - - import FakeTensorProp from torch.fx.passes.fake_tensor_prop - gm: GraphModule - real_inputs: List[Tensor] - FakeTensorProp(gm).propagate(*real_inputs) - # This will populate meta['val'] on all the FX nodes with a fake tensor - # or if you have a preexisting fake mode, you should use it - FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs) - # There is also propagate_dont_convert_inputs if your inputs are already fake - fake_inputs: List[FakeTensor] - FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs) - -Details -------- - -Auto-convert or not? -Originally, FakeTensorMode would not automatically fakeify real tensors if you tried to do compute on them inside a FakeTensorMode region. The motivation behind this was to prevent the following footgun: - -.. code:: python - - with FakeTensorMode(): - real_tensor.t_() - -What should this code do? It would be surprising if we actually modified the metadata on the real tensor. But at the same time, there isn't any obvious opportunity to create a FakeTensor. So we conservatively decided to make this raise an error: "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. Please convert all Tensors to FakeTensors first." - -This error is pretty annoying in practice. For example, suppose you have a real nn.Module and you want to feed fake tensors through it. You need to somehow fakeify the nn.Module. This motivated FakeCopyMode. - -Eventually, we gave up and added automatic fakeification. However, this is still not yet enabled by default in many uses of FakeTensorMode. - -Metadata mutation on fake tensor -If you have a fake tensor, and you t_() it, the metadata on the fake tensor changes. This is reasonable on its face, but sometimes you want to also store fake tensors as metadata on FX nodes; mutating a fake tensor is bad because this will invalidate old metadata! - -In fact, there is a fundamental tension here, which is that fake tensors maintain extremely accurate metadata about tensors, up to and including object identity. If object metadata changes over time in an FX graph, there is not actually any way to represent this change over time. Most of the time, our serious FX analyses are done on functionalized graphs, which don't have this, but occasionally you need to do an analysis on a non-functionalized graph. Maybe it was a mistake to put fake tensor in meta['val'] - -About the tensor subclass -------------------------- - -Fake tensor uses both a subclass and a mode tensor subclass pattern, where FakeTensor.__torch_dispatch__ enables the FakeTensorMode associated with the fake tensor, and then redispatches (relying on FakeTensorMode to do the heavy lifting). If fake tensor operations get a subclass argument it doesn't recognize, it will return NotImplemented, giving the other subclass a chance to run first (hopefully desugaring into plain tensor operations), before it tries again. This can cause infinite loops. - -How is each individual operator implemented? --------------------------------------------- - -Unfortunately, there is a pretty complicated set of places where any given operator may be implemented. Some important cases to know about: - -- Tensor subclasses support limited constant propagation if the number of elements is very small (this helps deal with some cases where we immediately call item() on such tensors.) -- We have some fastpath implementations for certain operators, which are done entirely in fake tensor, for performance reasons. -- If you use @custom_op to generate a custom tensor, these will register impl_abstract directly to fake tensor. -- Fake tensor itself has some hardcoded special cases for device-converting operations. -- If there is no meta implementation nor any decomposition, we will generate real zero-filled tensors and attempt to run the operator directly to find out what the results will be. This can cause segfaults if the operator attempts to do indexing with data, so we don't turn this on by default for custom ops. - -How does the converter work? ----------------------------- - -Because fake tensors are used in situations that are very sensitive to the exact properties of a tensor, fake tensors do conversion very carefully, preserving leaf-ness, requires_grad'ness, aliasing, and a whole host of other properties. The bulk of the heavy lifting is in MetaConverter. - -Performance characteristics ---------------------------- - -You would think fake tensors are fast because they don't do any tensor compute. But at small tensor sizes we are actually entirely overhead bound, and, well, fake tensor is in Python, and we often do a LOT of work to do a single tensor operation (because they are implemented as decompositions). So fake tensors are actually pretty slow in practice, especially when symbolic shapes are involved. There are two important fastpaths we currently have in fake tensor that make a big difference in practice: - -- Pointwise ops don't go through PrimTorch decomps, instead we've hand-coded their propagation rule. -- If possible, we should. - -Fake tensor of fake tensor? ----------------------------- - -There is interest in sending fake tensors as user inputs into the PT2 stack, which would imply we would need to be able to create a fake tensor of a fake tensor. This isn't really supported right now, but maybe it would not be too difficult to do. - -Interaction with dynamic shapes -------------------------------- - -Every FakeTensorMode contains a ShapeEnv, which tracks all symbolic shapes information. Their lifetimes are typically tied: they live and die together. - -Because FakeTensorMode has a ShapeEnv (but meta implementations do not), meta functions that are data-dependent and require allocating an unbacked SymInt live in fake tensor. Fake tensor also takes care of memoizing unbacked SymInts, so that, e.g., if you call nonzero() on the same fake tensor twice, you get the same symbolic size. - -Other resources ---------------- - -`Colab Tutorial On Using FakeTensor To Determine Max Batch Size `_ diff --git a/docs/source/torch.compiler_faq.md b/docs/source/torch.compiler_faq.md new file mode 100644 index 00000000000000..7a8eaaa5215fab --- /dev/null +++ b/docs/source/torch.compiler_faq.md @@ -0,0 +1,630 @@ +# Frequently Asked Questions + +**Author**: [Mark Saroufim](https://github.com/msaroufim) + +## Does `torch.compile` support training? + +`torch.compile` supports training, using AOTAutograd to capture backwards: + +1. The `.forward()` graph and `optimizer.step()` is captured by + TorchDynamo’s python `evalframe` frontend. +2. For each segment of `.forward()` that torchdynamo captures, it uses + AOTAutograd to generate a backward graph segment. +3. Each pair of forward and backward graph are (optionally) min-cut + partitioned to save the minimal state between forward and backward. +4. The forward and backward pairs are wrapped in `autograd.function` modules. +5. User code calling `.backward()` still triggers eager’s autograd engine, + which runs each *compiled backward* graph as if it were one op, also running + any non-compiled eager ops’ `.backward()` functions. + +## Do you support Distributed code? + +`torch.compile` supports `DistributedDataParallel` (DDP). +Support for other distributed training libraries is being considered. + +The main reason why Distributed code is challenging with dynamo is +because AOTAutograd unrolls both the forward and backward pass and +provides 2 graphs for backends to optimize. This is a problem for +distributed code because we’d like to ideally overlap communication +operations with computations. Eager pytorch accomplishes this in +different ways for DDP/FSDP- using autograd hooks, module hooks, and +modifications/mutations of module states. In a naive application of +dynamo, hooks that should run directly after an operation during +backwards may be delayed until after the entire compiled region of +backwards ops, due to how AOTAutograd compiled functions interact with +dispatcher hooks. + +The basic strategy for optimizing DDP with Dynamo is outlined in +[distributed.py](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/backends/distributed.py) +where the main idea will be to graph break on [DDP bucket +boundaries](https://pytorch.org/docs/stable/notes/ddp.html#internal-design). + +When each node in DDP needs to synchronize its weights with the other +nodes it organizes its gradients and parameters into buckets which +reduces communication times and allows a node to broadcast a fraction of +its gradients to other waiting nodes. + +Graph breaks in distributed code mean you can expect dynamo and its +backends to optimize the compute overhead of a distributed program but +not its communication overhead. Graph-breaks may interfere with +compilation speedups, if the reduced graph-size robs the compiler of +fusion opportunities. However, there are diminishing returns with +increasing graph size since most of the current compute optimizations +are local fusions. So in practice this approach may be sufficient. + +## Do I still need to export whole graphs? + +For the vast majority of models you probably don’t and you can use +`torch.compile()` as is but there are a few situations where +full graphs are necessary and you can can ensure a full graph by simply +running `torch.compile(..., fullgraph=True)`. These situations include: + +- Large scale training runs, such as $250K+ that require pipeline parallelism + and other advanced sharding strategies. +- Inference optimizers like [TensorRT](https://github.com/pytorch/TensorRT) + or [AITemplate](https://github.com/facebookincubator/AITemplate) that + rely on fusing much more aggressively than training optimizers. +- Mobile training or inference. + +Future work will include tracing communication operations into graphs, +coordinating these operations with compute optimizations, and optimizing +the communication operations. + +## Why is my code crashing? + +If your code ran just fine without `torch.compile` and started to +crash with it is enabled, then the most important first step is figuring +out which part of the stack your failure occurred. To troubleshoot that, +follow the steps below and only try the next step if the previous one +succeeded. + +1. `torch.compile(..., backend="eager")` which only runs TorchDynamo + forward graph capture and then runs the captured graph with PyTorch. + If this fails then there’s an issue with TorchDynamo. +2. `torch.compile(..., backend="aot_eager")` + which runs TorchDynamo to capture a forward graph, and then AOTAutograd + to trace the backward graph without any additional backend compiler + steps. PyTorch eager will then be used to run the forward and backward + graphs. If this fails then there’s an issue with AOTAutograd. +3. `torch.compile(..., backend="inductor")` which runs TorchDynamo to capture a + forward graph, and then AOTAutograd to trace the backward graph with the + TorchInductor compiler. If this fails then there’s an issue with TorchInductor + +## Why is compilation slow? + +- **Dynamo Compilation**– TorchDynamo has a builtin stats function for + collecting and displaying the time spent in each compilation phase. + These stats can be accessed by calling `torch._dynamo.utils.compile_times()` + after executing `torch._dynamo`. By default, this returns a string + representation of the compile times spent in each TorchDynamo function by name. +- **Inductor Compilation**– TorchInductor has a builtin stats and trace function + for displaying time spent in each compilation phase, output code, output + graph visualization and IR dump. `env TORCH_COMPILE_DEBUG=1 python repro.py`. + This is a debugging tool designed to make it easier to debug/understand the + internals of TorchInductor with an output that will look something like + [this](https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396) + Each file in that debug trace can be enabled/disabled via + `torch._inductor.config.trace.*`. The profile and the diagram are both + disabled by default since they are expensive to generate. See the + [example debug directory + output](https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396) + for more examples. +- **Excessive Recompilation** + When TorchDynamo compiles a function (or part of one), it makes certain + assumptions about locals and globals in order to allow compiler + optimizations, and expresses these assumptions as guards that check + particular values at runtime. If any of these guards fail, Dynamo will + recompile that function (or part) up to + `torch._dynamo.config.recompile_limit` times. If your program is + hitting the cache limit, you will first need to determine which guard is + failing and what part of your program is triggering it. The + Use `TORCH_TRACE/tlparse` or `TORCH_LOGS=recompiles` to trace the root of the issue, check {ref}`torch.compiler_troubleshooting` for more details. +## Why are you recompiling in production? + +In some cases, you may not want unexpected compiles after a program has +warmed up. For example, if you are serving production traffic in a +latency critical application. For this, TorchDynamo provides an +alternate mode where prior compiled graphs are used, but no new ones are +generated: + +```python +frozen_toy_example = dynamo.run(toy_example) +frozen_toy_example(torch.randn(10), torch.randn(10)) +``` + +## How are you speeding up my code? + +There are 3 major ways to accelerate PyTorch code: + +1. Kernel fusion via vertical fusions which fuse sequential operations to avoid + excessive read/writes. For example, fuse 2 subsequent cosines means you + can can do 1 read 1 write instead 2 reads 2 writes 2. Horizontal fusion: + the simplest example being batching where a single matrix is multiplied + with a batch of examples but the more general scenario is a grouped GEMM + where a group of matrix multiplications are scheduled together +2. Out of order execution: A general optimization for compilers, by looking ahead + at the exact data dependencies within a graph we can decide on the most + opportune time to execute a node and which buffers can be reused +3. Automatic work placement: Similar of the out of order execution point, + but by matching nodes of a graph to resources like physical hardware or + memory we can design an appropriate schedule + +The above are general principles for accelerating PyTorch code but +different backends will each make different tradeoffs on what to +optimize. For example Inductor first takes care of fusing whatever it +can and only then generates [Triton](https://openai.com/blog/triton/) +kernels. + +Triton in addition offers speedups because of automatic memory +coalescing, memory management and scheduling within each Streaming +Multiprocessor and has been designed to handle tiled computations. + +However, regardless of the backend you use it’s best to use a benchmark +and see approach so try out the PyTorch profiler, visually inspect the +generated kernels and try to see what’s going on for yourself. + + +(torch.compiler_graph_breaks)= +## Why am I not seeing speedups? + +### Graph Breaks + +The main reason you won’t see the speedups you’d like to by using dynamo +is excessive graph breaks. So what’s a graph break? + +Given a program like: + +```python +def some_fun(x): + ... + +torch.compile(some_fun)(x) +... +``` + +Torchdynamo will attempt to compile all of the torch/tensor operations +within `some_fun()` into a single FX graph, but it may fail to capture +everything into one graph. + +Some graph break reasons are insurmountable to TorchDynamo like calling +into a C extension other than PyTorch is invisible to TorchDynamo, and +could do arbitrary things without TorchDynamo being able to introduce +necessary guards to ensure that the compiled program would be safe to reuse. + +> To maximize performance, it’s important to have as few graph breaks +> as possible. +### Identifying the cause of a graph break + +To identify all graph breaks in a program and the associated reasons for +the breaks, `torch._dynamo.explain` can be used. This tool runs +TorchDynamo on the supplied function and aggregates the graph breaks +that are encountered. Here is an example usage: + +```python +import torch +import torch._dynamo as dynamo +def toy_example(a, b): + x = a / (torch.abs(a) + 1) + print("woo") + if b.sum() < 0: + b = b * -1 + return x * b +explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) +print(explanation) +""" +Graph Count: 3 +Graph Break Count: 2 +Op Count: 5 +Break Reasons: + Break Reason 1: + Reason: builtin: print [] False + User Stack: + + Break Reason 2: + Reason: generic_jump TensorVariable() + User Stack: + +Ops per Graph: + ... +Out Guards: + ... +""" +``` + +To throw an error on the first graph break encountered you can +disable python fallbacks by using `fullgraph=True`, this should be +familiar if you’ve worked with export based compilers. + +```python +def toy_example(a, b): + ... + +torch.compile(toy_example, fullgraph=True, backend=)(a, b) +``` + +### Why didn’t my code recompile when I changed it? + +If you enabled dynamic shapes by setting +`env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py` then your code +won’t recompile on shape changes. We’ve added support for dynamic shapes +which avoids recompilations in the case when shapes vary by less than a +factor of 2. This is especially useful in scenarios like varying image +sizes in CV or variable sequence length in NLP. In inference scenarios +it’s often not possible to know what a batch size will be beforehand +because you take what you can get from different client apps. + +In general, TorchDynamo tries very hard not to recompile things +unnecessarily so if for example TorchDynamo finds 3 graphs and your +change only modified one graph then only that graph will recompile. So +another tip to avoid potentially slow compilation times is to warmup a +model by compiling it once after which subsequent compilations will be +much faster. Cold start compile times is still a metric we track +visibly. + +## Why am I getting incorrect results? + +Accuracy issues can also be minified if you set the environment variable +`TORCHDYNAMO_REPRO_LEVEL=4`, it operates with a similar git bisect +model and a full repro might be something like +`TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4` the reason +we need this is downstream compilers will codegen code whether it’s +Triton code or the C++ backend, the numerics from those downstream +compilers can be different in subtle ways yet have dramatic impact on +your training stability. So the accuracy debugger is very useful for us +to detect bugs in our codegen or with a backend compiler. + +If you'd like to ensure that random number generation is the same across both torch +and triton then you can enable `torch._inductor.config.fallback_random = True` + +## Why am I getting OOMs? + +Dynamo is still an alpha product so there’s a few sources of OOMs and if +you’re seeing an OOM try disabling the following configurations in this +order and then open an issue on GitHub so we can solve the root problem +1\. If you’re using dynamic shapes try disabling them, we’ve disabled +them by default: `env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py` 2. +CUDA graphs with Triton are enabled by default in inductor but removing +them may alleviate some OOM issues: `torch._inductor.config.triton.cudagraphs = False`. + +## Does `torch.func` work with `torch.compile` (for `grad` and `vmap` transforms)? + +Applying a `torch.func` transform to a function that uses `torch.compile` +does work: + +```python +import torch + +@torch.compile +def f(x): + return torch.sin(x) + +def g(x): + return torch.grad(f)(x) + +x = torch.randn(2, 3) +g(x) +``` + +### Calling `torch.func` transform inside of a function handled with `torch.compile` + +### Compiling `torch.func.grad` with `torch.compile` + +```python +import torch + +def wrapper_fn(x): + return torch.func.grad(lambda x: x.sin().sum())(x) + +x = torch.randn(3, 3, 3) +grad_x = torch.compile(wrapper_fn)(x) +``` + +### Compiling `torch.vmap` with `torch.compile` + +```python +import torch + +def my_fn(x): + return torch.vmap(lambda x: x.sum(1))(x) + +x = torch.randn(3, 3, 3) +output = torch.compile(my_fn)(x) +``` + +### Compiling functions besides the ones which are supported (escape hatch) + +For other transforms, as a workaround, use `torch._dynamo.allow_in_graph` + +`allow_in_graph` is an escape hatch. If your code does not work with +`torch.compile`, which introspects Python bytecode, but you believe it +will work via a symbolic tracing approach (like `jax.jit`), then use +`allow_in_graph`. + +By using `allow_in_graph` to annotate a function, you must make sure +your code meets the following requirements: + +- All outputs in your function only depend on the inputs and + do not depend on any captured Tensors. +- Your function is functional. That is, it does not mutate any state. This may + be relaxed; we actually support functions that appear to be functional from + the outside: they may have in-place PyTorch operations, but may not mutate + global state or inputs to the function. +- Your function does not raise data-dependent errors. + +```python +import torch + +@torch.compile +def f(x): + return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x) + +x = torch.randn(2, 3) +f(x) +``` + +A common pitfall is using `allow_in_graph` to annotate a function that +invokes an `nn.Module`. This is because the outputs now depend on the +parameters of the `nn.Module`. To get this to work, use +`torch.func.functional_call` to extract the module state. + +## Does NumPy work with `torch.compile`? + +Starting in 2.1, `torch.compile` understands native NumPy programs that +work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch +to NumPy and back via `x.numpy()`, `torch.from_numpy`, and related functions. + +(nonsupported-numpy-feats)= + +### Which NumPy features does `torch.compile` support? + +NumPy within `torch.compile` follows NumPy 2.0 pre-release. + +Generally, `torch.compile` is able to trace through most NumPy constructions, +and when it cannot, it falls back to eager and lets NumPy execute that piece of +code. Even then, there are a few features where `torch.compile` semantics +slightly deviate from those of NumPy: + +- NumPy scalars: We model them as 0-D arrays. That is, `np.float32(3)` returns + a 0-D array under `torch.compile`. To avoid a graph break, it is best to use this 0-D + array. If this breaks your code, you can workaround this by casting the NumPy scalar + to the relevant Python scalar type `bool/int/float`. +- Negative strides: `np.flip` and slicing with a negative step return a copy. +- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules + are described in [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html). + `torch.compile` implements NEP 50 rather than the current soon-to-be deprecated rules. +- `{tril,triu}_indices_from/{tril,triu}_indices` return arrays rather than a tuple of arrays. + +There are other features for which we do not support tracing and we gracefully +fallback to NumPy for their execution: + +- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays. +- Long dtypes `np.float128/np.complex256` and some unsigned dtypes `np.uint16/np.uint32/np.uint64`. +- `ndarray` subclasses. +- Masked arrays. +- Esoteric ufunc machinery like `axes=[(n,k),(k,m)->(n,m)]` and ufunc methods (e.g., `np.add.reduce`). +- Sorting / ordering `complex64/complex128` arrays. +- NumPy `np.poly1d` and `np.polynomial`. +- Positional `out1, out2` args in functions with 2 or more returns (`out=tuple` does work). +- `__array_function__`, `__array_interface__` and `__array_wrap__`. +- `ndarray.ctypes` attribute. + +### Can I compile NumPy code using `torch.compile`? + +Of course you do! `torch.compile` understands NumPy code natively, and treats it +as if it were PyTorch code. To do so, simply wrap NumPy code with the `torch.compile` +decorator. + +```python +import torch +import numpy as np + +@torch.compile +def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: + return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) + +X = np.random.randn(1024, 64) +Y = np.random.randn(1024, 64) +Z = numpy_fn(X, Y) +assert isinstance(Z, np.ndarray) +``` + +Executing this example with the environment variable `TORCH_LOGS=output_code`, we can see +that `torch.compile` was able to fuse the multiplication and the sum into one C++ kernel. +It was also able to execute them in parallel using OpenMP (native NumPy is single-threaded). +This can easily make your NumPy code `n` times faster, where `n` is the number of cores +in your processor! + +Tracing NumPy code this way also supports graph breaks within the compiled code. + +### Can I execute NumPy code on CUDA and compute gradients via `torch.compile`? + +Yes you can! To do so, you may simply execute your code within a `torch.device("cuda")` +context. Consider the example + +```python +import torch +import numpy as np + +@torch.compile +def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: + return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) + +X = np.random.randn(1024, 64) +Y = np.random.randn(1024, 64) +with torch.device("cuda"): + Z = numpy_fn(X, Y) +assert isinstance(Z, np.ndarray) +``` + +In this example, `numpy_fn` will be executed in CUDA. For this to be +possible, `torch.compile` automatically moves `X` and `Y` from CPU +to CUDA, and then it moves the result `Z` from CUDA to CPU. If we are +executing this function several times in the same program run, we may want +to avoid all these rather expensive memory copies. To do so, we just need +to tweak our `numpy_fn` so that it accepts cuda Tensors and returns tensors. +We can do so by using `torch.compiler.wrap_numpy`: + +```python +@torch.compile(fullgraph=True) +@torch.compiler.wrap_numpy +def numpy_fn(X, Y): + return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) + +X = torch.randn(1024, 64, device="cuda") +Y = torch.randn(1024, 64, device="cuda") +Z = numpy_fn(X, Y) +assert isinstance(Z, torch.Tensor) +assert Z.device.type == "cuda" +``` + +Here, we explicitly create the tensors in CUDA memory, and pass them to the +function, which performs all the computations on the CUDA device. +`wrap_numpy` is in charge of marking any `torch.Tensor` input as an input +with `np.ndarray` semantics at a `torch.compile` level. Marking tensors +inside the compiler is a very cheap operation, so no data copy or data movement +happens during runtime. + +Using this decorator, we can also differentiate through NumPy code! + +```python +@torch.compile(fullgraph=True) +@torch.compiler.wrap_numpy +def numpy_fn(X, Y): + return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))) + +X = torch.randn(1024, 64, device="cuda", requires_grad=True) +Y = torch.randn(1024, 64, device="cuda") +Z = numpy_fn(X, Y) +assert isinstance(Z, torch.Tensor) +Z.backward() +# X.grad now holds the gradient of the computation +print(X.grad) +``` + +We have been using `fullgraph=True` as graph break are problematic in this context. +When a graph break occurs, we need to materialize the NumPy arrays. Since NumPy arrays +do not have a notion of `device` or `requires_grad`, this information is lost during +a graph break. + +We cannot propagate gradients through a graph break, as the graph break code may execute +arbitrary code that don't know how to differentiate. On the other hand, in the case of +the CUDA execution, we can work around this problem as we did in the first example, by +using the `torch.device("cuda")` context manager: + +```python +@torch.compile +@torch.compiler.wrap_numpy +def numpy_fn(X, Y): + prod = X[:, :, None] * Y[:, None, :] + print("oops, a graph break!") + return np.sum(prod, axis=(-2, -1)) + +X = torch.randn(1024, 64, device="cuda") +Y = torch.randn(1024, 64, device="cuda") + +with torch.device("cuda"): + Z = numpy_fn(X, Y) +assert isinstance(Z, torch.Tensor) +assert Z.device.type == "cuda" +``` + +During the graph break, the intermediary tensors still need to be moved to CPU, but when the +tracing is resumed after the graph break, the rest of the graph is still traced on CUDA. +Given this CUDA <> CPU and CPU <> CUDA movement, graph breaks are fairly costly in the NumPy +context and should be avoided, but at least they allow tracing through complex pieces of code. + +### How do I debug NumPy code under `torch.compile`? + +Debugging JIT compiled code is challenging, given the complexity of modern +compilers and the daunting errors that they raise. +{ref}`The torch.compile troubleshooting doc ` +contains a few tips and tricks on how to tackle this task. + +If the above is not enough to pinpoint the origin of the issue, there are still +a few other NumPy-specific tools we can use. We can discern whether the bug +is entirely in the PyTorch code by disabling tracing through NumPy functions: + +```python +from torch._dynamo import config +config.trace_numpy = False +``` + +If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without `torch.compile`) +using PyTorch as a backend by importing `import torch._numpy as np`. +This should just be used for **debugging purposes** and is in no way a +replacement for the PyTorch API, as it is **much less performant** and, as a +private API, **may change without notice**. At any rate, `torch._numpy` is a +Python implementation of NumPy in terms of PyTorch and it is used internally by `torch.compile` to +transform NumPy code into Pytorch code. It is rather easy to read and modify, +so if you find any bug in it feel free to submit a PR fixing it or simply open +an issue. + +If the program does work when importing `torch._numpy as np`, chances are +that the bug is in TorchDynamo. If this is the case, please feel free to open an issue +with a {ref}`minimal reproducer `. + +### I `torch.compile` some NumPy code and I did not see any speed-up. + +The best place to start is the +[tutorial with general advice for how to debug these sort of torch.compile issues](https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups). + +Some graph breaks may happen because of the use of unsupported features. See +{ref}`nonsupported-numpy-feats`. More generally, it is useful to keep in mind +that some widely used NumPy features do not play well with compilers. For +example, in-place modifications make reasoning difficult within the compiler and +often yield worse performance than their out-of-place counterparts.As such, it is best to avoid +them. Same goes for the use of the `out=` parameter. Instead, prefer +out-of-place ops and let `torch.compile` optimize the memory use. Same goes +for data-dependent ops like masked indexing through boolean masks, or +data-dependent control flow like `if` or `while` constructions. + +## Which API to use for fine grain tracing? + +In some cases, you might need to exclude small parts of your code from the +torch.compile compilations. This section provides some of the answers and +you can find more information in {ref}`torchdynamo_fine_grain_tracing`. + +### How do I graph break on a function? + +Graph break on a function is not enough to sufficiently express what you want +PyTorch to do. You need to be more specific about your use case. Some of the +most common use cases you might want to consider: + +- If you want to disable compilation on this function frame and the recursively + invoked frames, use `torch._dynamo.disable`. +- If you want a particular operator, such as `fbgemm` to use the eager mode, + use `torch._dynamo.disallow_in_graph`. + +Some of the uncommon use cases include: + +- If you want to disable TorchDynamo on the function frame but enable it back + on the recursively invoked frames – use `torch._dynamo.disable(recursive=False)`. +- If you want to prevent inlining of a function frame – use `torch._dynamo.graph_break` + at the beginning of the function you want to prevent inlining. + +### What's the difference between `torch._dynamo.disable` and `torch._dynamo.disallow_in_graph` + +Disallow-in-graph works at the level of operators, or more specifically, +the operators that you see in the TorchDynamo extracted graphs. + +Disable works at the function frame level and decides if TorchDynamo +should look into the function frame or not. + +### What's the difference between `torch._dynamo.disable` and `torch._dynamo_skip` + +:::{note} +`torch._dynamo_skip` is deprecated. +::: + +You most likely need `torch._dynamo.disable`. But in an unlikely scenario, you +might need even finer control. Suppose you want to disable the tracing on just +the `a_fn` function, but want to continue the tracing back in `aa_fn` and +`ab_fn`. The image below demonstrates this use case: + +:::{figure} _static/img/fine_grained_apis/call_stack_diagram.png +:alt: diagram of torch.compile + disable(a_fn, recursive=False) +::: + +In this case, you can use `torch._dynamo.disable(recursive=False)`. +In previous versions, this functionality was provided by `torch._dynamo.skip`. +This is now supported by the `recursive` flag inside `torch._dynamo.disable`. \ No newline at end of file diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst deleted file mode 100644 index 07bf7d681ac36d..00000000000000 --- a/docs/source/torch.compiler_faq.rst +++ /dev/null @@ -1,692 +0,0 @@ -Frequently Asked Questions -========================== -**Author**: `Mark Saroufim `_ - -Does ``torch.compile`` support training? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``torch.compile`` supports training, using AOTAutograd to capture backwards: - -1. The ``.forward()`` graph and ``optimizer.step()`` is captured by - TorchDynamo’s python ``evalframe`` frontend. -2. For each segment of ``.forward()`` that torchdynamo captures, it uses - AOTAutograd to generate a backward graph segment. -3. Each pair of forward and backward graph are (optionally) min-cut - partitioned to save the minimal state between forward and backward. -4. The forward and backward pairs are wrapped in ``autograd.function`` modules. -5. Usercode calling\ ``.backward()`` still triggers eager’s autograd engine, - which runs each *compiled backward* graph as if it were one op, also running - any non-compiled eager ops’ ``.backward()`` functions. - -Do you support Distributed code? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``torch.compile`` supports ``DistributedDataParallel`` (DDP). -Support for other distributed training libraries is being considered. - -The main reason why Distributed code is challenging with dynamo is -because AOTAutograd unrolls both the forward and backward pass and -provides 2 graphs for backends to optimize. This is a problem for -distributed code because we’d like to ideally overlap communication -operations with computations. Eager pytorch accomplishes this in -different ways for DDP/FSDP- using autograd hooks, module hooks, and -modifications/mutations of module states. In a naive application of -dynamo, hooks that should run directly after an operation during -backwards may be delayed until after the entire compiled region of -backwards ops, due to how AOTAutograd compiled functions interact with -dispatcher hooks. - -The basic strategy for optimizing DDP with Dynamo is outlined in -`distributed.py `__ -where the main idea will be to graph break on `DDP bucket -boundaries `__. - -When each node in DDP needs to synchronize its weights with the other -nodes it organizes its gradients and parameters into buckets which -reduces communication times and allows a node to broadcast a fraction of -its gradients to other waiting nodes. - -Graph breaks in distributed code mean you can expect dynamo and its -backends to optimize the compute overhead of a distributed program but -not its communication overhead. Graph-breaks may interfere with -compilation speedups, if the reduced graph-size robs the compiler of -fusion opportunities. However, there are diminishing returns with -increasing graph size since most of the current compute optimizations -are local fusions. So in practice this approach may be sufficient. - -Do I still need to export whole graphs? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For the vast majority of models you probably don’t and you can use -``torch.compile()`` as is but there are a few situations where -full graphs are necessary and you can can ensure a full graph by simply -running ``torch.compile(..., fullgraph=True)``. These situations include: - -* Large scale training runs, such as $250K+ that require pipeline parallelism - and other advanced sharding strategies. - -* Inference optimizers like `TensorRT `__ - or `AITemplate `__ that - rely on fusing much more aggressively than training optimizers. - -* Mobile training or inference. - -Future work will include tracing communication operations into graphs, -coordinating these operations with compute optimizations, and optimizing -the communication operations. - -Why is my code crashing? -~~~~~~~~~~~~~~~~~~~~~~~~ - -If your code ran just fine without ``torch.compile`` and started to -crash with it is enabled, then the most important first step is figuring -out which part of the stack your failure occurred. To troubleshoot that, -follow the steps below and only try the next step if the previous one -succeeded. - -1. ``torch.compile(..., backend="eager")`` which only runs TorchDynamo - forward graph capture and then runs the captured graph with PyTorch. - If this fails then there’s an issue with TorchDynamo. - -2. ``torch.compile(..., backend="aot_eager")`` - which runs TorchDynamo to capture a forward graph, and then AOTAutograd - to trace the backward graph without any additional backend compiler - steps. PyTorch eager will then be used to run the forward and backward - graphs. If this fails then there’s an issue with AOTAutograd. - -3. ``torch.compile(..., backend="inductor")`` which runs TorchDynamo to capture a - forward graph, and then AOTAutograd to trace the backward graph with the - TorchInductor compiler. If this fails then there’s an issue with TorchInductor - -Why is compilation slow? -~~~~~~~~~~~~~~~~~~~~~~~~ - -* **Dynamo Compilation**– TorchDynamo has a builtin stats function for - collecting and displaying the time spent in each compilation phase. - These stats can be accessed by calling ``torch._dynamo.utils.compile_times()`` - after executing ``torch._dynamo``. By default, this returns a string - representation of the compile times spent in each TorchDynamo function by name. - -* **Inductor Compilation**– TorchInductor has a builtin stats and trace function - for displaying time spent in each compilation phase, output code, output - graph visualization and IR dump. ``env TORCH_COMPILE_DEBUG=1 python repro.py``. - This is a debugging tool designed to make it easier to debug/understand the - internals of TorchInductor with an output that will look something like - `this `__ - Each file in that debug trace can be enabled/disabled via - ``torch._inductor.config.trace.*``. The profile and the diagram are both - disabled by default since they are expensive to generate. See the - `example debug directory - output `__ - for more examples. - -* **Excessive Recompilation** - When TorchDynamo compiles a function (or part of one), it makes certain - assumptions about locals and globals in order to allow compiler - optimizations, and expresses these assumptions as guards that check - particular values at runtime. If any of these guards fail, Dynamo will - recompile that function (or part) up to - ``torch._dynamo.config.recompile_limit`` times. If your program is - hitting the cache limit, you will first need to determine which guard is - failing and what part of your program is triggering it. The - `recompilation profiler <#recompilation-profiler>`__ automates the - process of setting TorchDynamo’s cache limit to 1 and running your - program under an observation-only ‘compiler’ that records the causes of - any guard failures. You should be sure to run your program for at least - as long (as many iterations) as you were running when you ran into - trouble, and the profiler will accumulate statistics over this duration. - - -Why are you recompiling in production? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In some cases, you may not want unexpected compiles after a program has -warmed up. For example, if you are serving production traffic in a -latency critical application. For this, TorchDynamo provides an -alternate mode where prior compiled graphs are used, but no new ones are -generated: - -.. code-block:: python - - frozen_toy_example = dynamo.run(toy_example) - frozen_toy_example(torch.randn(10), torch.randn(10)) - -How are you speeding up my code? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -There are 3 major ways to accelerate PyTorch code: - -1. Kernel fusion via vertical fusions which fuse sequential operations to avoid - excessive read/writes. For example, fuse 2 subsequent cosines means you - can can do 1 read 1 write instead 2 reads 2 writes 2. Horizontal fusion: - the simplest example being batching where a single matrix is multiplied - with a batch of examples but the more general scenario is a grouped GEMM - where a group of matrix multiplications are scheduled together - -2. Out of order execution: A general optimization for compilers, by looking ahead - at the exact data dependencies within a graph we can decide on the most - opportune time to execute a node and which buffers can be reused - -3. Automatic work placement: Similar of the out of order execution point, - but by matching nodes of a graph to resources like physical hardware or - memory we can design an appropriate schedule - -The above are general principles for accelerating PyTorch code but -different backends will each make different tradeoffs on what to -optimize. For example Inductor first takes care of fusing whatever it -can and only then generates `Triton `__ -kernels. - -Triton in addition offers speedups because of automatic memory -coalescing, memory management and scheduling within each Streaming -Multiprocessor and has been designed to handle tiled computations. - -However, regardless of the backend you use it’s best to use a benchmark -and see approach so try out the PyTorch profiler, visually inspect the -generated kernels and try to see what’s going on for yourself. - -Why am I not seeing speedups? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. _torch.compiler_graph_breaks: - -Graph Breaks ------------- - -The main reason you won’t see the speedups you’d like to by using dynamo -is excessive graph breaks. So what’s a graph break? - -Given a program like: - -.. code-block:: python - - def some_fun(x): - ... - - torch.compile(some_fun)(x) - ... - -Torchdynamo will attempt to compile all of the torch/tensor operations -within ``some_fun()`` into a single FX graph, but it may fail to capture -everything into one graph. - -Some graph break reasons are insurmountable to TorchDynamo like calling -into a C extension other than PyTorch is invisible to TorchDynamo, and -could do arbitrary things without TorchDynamo being able to introduce -necessary guards to ensure that the compiled program would be safe to reuse. - - To maximize performance, it’s important to have as few graph breaks - as possible. - -Identifying the cause of a graph break --------------------------------------- - -To identify all graph breaks in a program and the associated reasons for -the breaks, ``torch._dynamo.explain`` can be used. This tool runs -TorchDynamo on the supplied function and aggregates the graph breaks -that are encountered. Here is an example usage: - -.. code-block:: python - - import torch - import torch._dynamo as dynamo - def toy_example(a, b): - x = a / (torch.abs(a) + 1) - print("woo") - if b.sum() < 0: - b = b * -1 - return x * b - explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) - print(explanation) - """ - Graph Count: 3 - Graph Break Count: 2 - Op Count: 5 - Break Reasons: - Break Reason 1: - Reason: builtin: print [] False - User Stack: - - Break Reason 2: - Reason: generic_jump TensorVariable() - User Stack: - - Ops per Graph: - ... - Out Guards: - ... - """ - -To throw an error on the first graph break encountered you can -disable python fallbacks by using ``fullgraph=True``, this should be -familiar if you’ve worked with export based compilers. - -.. code-block:: python - - def toy_example(a, b): - ... - - torch.compile(toy_example, fullgraph=True, backend=)(a, b) - -Why didn’t my code recompile when I changed it? ------------------------------------------------ - -If you enabled dynamic shapes by setting -``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py`` then your code -won’t recompile on shape changes. We’ve added support for dynamic shapes -which avoids recompilations in the case when shapes vary by less than a -factor of 2. This is especially useful in scenarios like varying image -sizes in CV or variable sequence length in NLP. In inference scenarios -it’s often not possible to know what a batch size will be beforehand -because you take what you can get from different client apps. - -In general, TorchDynamo tries very hard not to recompile things -unnecessarily so if for example TorchDynamo finds 3 graphs and your -change only modified one graph then only that graph will recompile. So -another tip to avoid potentially slow compilation times is to warmup a -model by compiling it once after which subsequent compilations will be -much faster. Cold start compile times is still a metric we track -visibly. - -Why am I getting incorrect results? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Accuracy issues can also be minified if you set the environment variable -``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect -model and a full repro might be something like -``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason -we need this is downstream compilers will codegen code whether it’s -Triton code or the C++ backend, the numerics from those downstream -compilers can be different in subtle ways yet have dramatic impact on -your training stability. So the accuracy debugger is very useful for us -to detect bugs in our codegen or with a backend compiler. - -If you'd like to ensure that random number generation is the same across both torch -and triton then you can enable ``torch._inductor.config.fallback_random = True`` - -Why am I getting OOMs? -~~~~~~~~~~~~~~~~~~~~~~ - -Dynamo is still an alpha product so there’s a few sources of OOMs and if -you’re seeing an OOM try disabling the following configurations in this -order and then open an issue on GitHub so we can solve the root problem -1. If you’re using dynamic shapes try disabling them, we’ve disabled -them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2. -CUDA graphs with Triton are enabled by default in inductor but removing -them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``. - -Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Applying a ``torch.func`` transform to a function that uses ``torch.compile`` -does work: - -.. code-block:: python - - import torch - - @torch.compile - def f(x): - return torch.sin(x) - - def g(x): - return torch.grad(f)(x) - - x = torch.randn(2, 3) - g(x) - -Calling ``torch.func`` transform inside of a function handled with ``torch.compile`` ------------------------------------------------------------------------------------- - - -Compiling ``torch.func.grad`` with ``torch.compile`` ----------------------------------------------------- - -.. code-block:: python - - import torch - - def wrapper_fn(x): - return torch.func.grad(lambda x: x.sin().sum())(x) - - x = torch.randn(3, 3, 3) - grad_x = torch.compile(wrapper_fn)(x) - -Compiling ``torch.vmap`` with ``torch.compile`` ------------------------------------------------ - -.. code-block:: python - - import torch - - def my_fn(x): - return torch.vmap(lambda x: x.sum(1))(x) - - x = torch.randn(3, 3, 3) - output = torch.compile(my_fn)(x) - - -Compiling functions besides the ones which are supported (escape hatch) ------------------------------------------------------------------------ - -For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph`` - -``allow_in_graph`` is an escape hatch. If your code does not work with -``torch.compile``, which introspects Python bytecode, but you believe it -will work via a symbolic tracing approach (like ``jax.jit``), then use -``allow_in_graph``. - -By using ``allow_in_graph`` to annotate a function, you must make sure -your code meets the following requirements: - -- All outputs in your function only depend on the inputs and - do not depend on any captured Tensors. -- Your function is functional. That is, it does not mutate any state. This may - be relaxed; we actually support functions that appear to be functional from - the outside: they may have in-place PyTorch operations, but may not mutate - global state or inputs to the function. -- Your function does not raise data-dependent errors. - -.. code-block:: python - - import torch - - @torch.compile - def f(x): - return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x) - - x = torch.randn(2, 3) - f(x) - -A common pitfall is using ``allow_in_graph`` to annotate a function that -invokes an ``nn.Module``. This is because the outputs now depend on the -parameters of the ``nn.Module``. To get this to work, use -``torch.func.functional_call`` to extract the module state. - -Does NumPy work with ``torch.compile``? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Starting in 2.1, ``torch.compile`` understands native NumPy programs that -work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch -to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions. - -.. _nonsupported-numpy-feats: - -Which NumPy features does ``torch.compile`` support? ----------------------------------------------------- - -NumPy within ``torch.compile`` follows NumPy 2.0 pre-release. - -Generally, ``torch.compile`` is able to trace through most NumPy constructions, -and when it cannot, it falls back to eager and lets NumPy execute that piece of -code. Even then, there are a few features where ``torch.compile`` semantics -slightly deviate from those of NumPy: - -- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns - a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D - array. If this breaks your code, you can workaround this by casting the NumPy scalar - to the relevant Python scalar type ``bool/int/float``. - -- Negative strides: ``np.flip`` and slicing with a negative step return a copy. - -- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules - are described in `NEP 50 `__. - ``torch.compile`` implements NEP 50 rather than the current soon-to-be deprecated rules. - -- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays. - -There are other features for which we do not support tracing and we gracefully -fallback to NumPy for their execution: - -- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays. - -- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``. - -- ``ndarray`` subclasses. - -- Masked arrays. - -- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``). - -- Sorting / ordering ``complex64/complex128`` arrays. - -- NumPy ``np.poly1d`` and ``np.polynomial``. - -- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work). - -- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``. - -- ``ndarray.ctypes`` attribute. - -Can I compile NumPy code using ``torch.compile``? -------------------------------------------------- - -Of course you do! ``torch.compile`` understands NumPy code natively, and treats it -as if it were PyTorch code. To do so, simply wrap NumPy code with the ``torch.compile`` -decorator. - -.. code-block:: python - - import torch - import numpy as np - - @torch.compile - def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: - return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) - - X = np.random.randn(1024, 64) - Y = np.random.randn(1024, 64) - Z = numpy_fn(X, Y) - assert isinstance(Z, np.ndarray) - -Executing this example with the environment variable ``TORCH_LOGS=output_code``, we can see -that ``torch.compile`` was able to fuse the multiplication and the sum into one C++ kernel. -It was also able to execute them in parallel using OpenMP (native NumPy is single-threaded). -This can easily make your NumPy code ``n`` times faster, where ``n`` is the number of cores -in your processor! - -Tracing NumPy code this way also supports graph breaks within the compiled code. - -Can I execute NumPy code on CUDA and compute gradients via ``torch.compile``? ------------------------------------------------------------------------------ - -Yes you can! To do so, you may simply execute your code within a ``torch.device("cuda")`` -context. Consider the example - -.. code-block:: python - - import torch - import numpy as np - - @torch.compile - def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray: - return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) - - X = np.random.randn(1024, 64) - Y = np.random.randn(1024, 64) - with torch.device("cuda"): - Z = numpy_fn(X, Y) - assert isinstance(Z, np.ndarray) - -In this example, ``numpy_fn`` will be executed in CUDA. For this to be -possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU -to CUDA, and then it moves the result ``Z`` from CUDA to CPU. If we are -executing this function several times in the same program run, we may want -to avoid all these rather expensive memory copies. To do so, we just need -to tweak our ``numpy_fn`` so that it accepts cuda Tensors and returns tensors. -We can do so by using ``torch.compiler.wrap_numpy``: - -.. code-block:: python - - @torch.compile(fullgraph=True) - @torch.compiler.wrap_numpy - def numpy_fn(X, Y): - return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)) - - X = torch.randn(1024, 64, device="cuda") - Y = torch.randn(1024, 64, device="cuda") - Z = numpy_fn(X, Y) - assert isinstance(Z, torch.Tensor) - assert Z.device.type == "cuda" - -Here, we explicitly create the tensors in CUDA memory, and pass them to the -function, which performs all the computations on the CUDA device. -``wrap_numpy`` is in charge of marking any ``torch.Tensor`` input as an input -with ``np.ndarray`` semantics at a ``torch.compile`` level. Marking tensors -inside the compiler is a very cheap operation, so no data copy or data movement -happens during runtime. - -Using this decorator, we can also differentiate through NumPy code! - -.. code-block:: python - - @torch.compile(fullgraph=True) - @torch.compiler.wrap_numpy - def numpy_fn(X, Y): - return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))) - - X = torch.randn(1024, 64, device="cuda", requires_grad=True) - Y = torch.randn(1024, 64, device="cuda") - Z = numpy_fn(X, Y) - assert isinstance(Z, torch.Tensor) - Z.backward() - # X.grad now holds the gradient of the computation - print(X.grad) - -We have been using ``fullgraph=True`` as graph break are problematic in this context. -When a graph break occurs, we need to materialize the NumPy arrays. Since NumPy arrays -do not have a notion of ``device`` or ``requires_grad``, this information is lost during -a graph break. - -We cannot propagate gradients through a graph break, as the graph break code may execute -arbitrary code that don't know how to differentiate. On the other hand, in the case of -the CUDA execution, we can work around this problem as we did in the first example, by -using the ``torch.device("cuda")`` context manager: - -.. code-block:: python - - @torch.compile - @torch.compiler.wrap_numpy - def numpy_fn(X, Y): - prod = X[:, :, None] * Y[:, None, :] - print("oops, a graph break!") - return np.sum(prod, axis=(-2, -1)) - - X = torch.randn(1024, 64, device="cuda") - Y = torch.randn(1024, 64, device="cuda") - - with torch.device("cuda"): - Z = numpy_fn(X, Y) - assert isinstance(Z, torch.Tensor) - assert Z.device.type == "cuda" - -During the graph break, the intermediary tensors still need to be moved to CPU, but when the -tracing is resumed after the graph break, the rest of the graph is still traced on CUDA. -Given this CUDA <> CPU and CPU <> CUDA movement, graph breaks are fairly costly in the NumPy -context and should be avoided, but at least they allow tracing through complex pieces of code. - - -How do I debug NumPy code under ``torch.compile``? --------------------------------------------------- - -Debugging JIT compiled code is challenging, given the complexity of modern -compilers and the daunting errors that they raise. -:ref:`The torch.compile troubleshooting doc ` -contains a few tips and tricks on how to tackle this task. - -If the above is not enough to pinpoint the origin of the issue, there are still -a few other NumPy-specific tools we can use. We can discern whether the bug -is entirely in the PyTorch code by disabling tracing through NumPy functions: - - -.. code-block:: python - - from torch._dynamo import config - config.trace_numpy = False - -If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly (without ``torch.compile``) -using PyTorch as a backend by importing ``import torch._numpy as np``. -This should just be used for **debugging purposes** and is in no way a -replacement for the PyTorch API, as it is **much less performant** and, as a -private API, **may change without notice**. At any rate, ``torch._numpy`` is a -Python implementation of NumPy in terms of PyTorch and it is used internally by ``torch.compile`` to -transform NumPy code into Pytorch code. It is rather easy to read and modify, -so if you find any bug in it feel free to submit a PR fixing it or simply open -an issue. - -If the program does work when importing ``torch._numpy as np``, chances are -that the bug is in TorchDynamo. If this is the case, please feel open an issue -with a :ref:`minimal reproducer `. - -I ``torch.compile`` some NumPy code and I did not see any speed-up. -------------------------------------------------------------------- - -The best place to start is the -`tutorial with general advice for how to debug these sort of torch.compile issues `__. - -Some graph breaks may happen because of the use of unsupported features. See -:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind -that some widely used NumPy features do not play well with compilers. For -example, in-place modifications make reasoning difficult within the compiler and -often yield worse performance than their out-of-place counterparts.As such, it is best to avoid -them. Same goes for the use of the ``out=`` parameter. Instead, prefer -out-of-place ops and let ``torch.compile`` optimize the memory use. Same goes -for data-dependent ops like masked indexing through boolean masks, or -data-dependent control flow like ``if`` or ``while`` constructions. - - -Which API to use for fine grain tracing? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In some cases, you might need to exclude small parts of your code from the -torch.compile compilations. This section provides some of the answers and -you can find more information in :ref:`torchdynamo_fine_grain_tracing`. - -How do I graph break on a function? ------------------------------------ - -Graph break on a function is not enough to sufficiently express what you want -PyTorch to do. You need to be more specific about your use case. Some of the -most common use cases you might want to consider: - -* If you want to disable compilation on this function frame and the recursively - invoked frames, use ``torch._dynamo.disable``. - -* If you want a particular operator, such as ``fbgemm`` to use the eager mode, - use ``torch._dynamo.disallow_in_graph``. - -Some of the uncommon use cases include: - -* If you want to disable TorchDynamo on the function frame but enable it back - on the recursively invoked frames – use ``torch._dynamo.disable(recursive=False)``. - -* If you want to prevent inlining of a function frame – use ``torch._dynamo.graph_break`` - at the beginning of the function you want to prevent inlining. - -What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo.disallow_in_graph`` ------------------------------------------------------------------------------------------------ - -Disallow-in-graph works at the level of operators, or more specifically, -the operators that you see in the TorchDynamo extracted graphs. - -Disable works at the function frame level and decides if TorchDynamo -should look into the function frame or not. - -What's the difference between ``torch._dynamo.disable`` and ``torch._dynamo_skip`` ----------------------------------------------------------------------------------- - -.. note:: - ``torch._dynamo_skip`` is deprecated. - -You most likely need ``torch._dynamo.disable``. But in an unlikely scenario, you -might need even finer control. Suppose you want to disable the tracing on just -the ``a_fn`` function, but want to continue the tracing back in ``aa_fn`` and -``ab_fn``. The image below demonstrates this use case: - - -.. figure:: _static/img/fine_grained_apis/call_stack_diagram.png - :alt: diagram of torch.compile + disable(a_fn, recursive=False) - -In this case, you can use ``torch._dynamo.disable(recursive=False)``. -In previous versions, this functionality was provided by ``torch._dynamo.skip``. -This is now supported by the ``recursive`` flag inside ``torch._dynamo.disable``. diff --git a/docs/source/torch.compiler_fine_grain_apis.md b/docs/source/torch.compiler_fine_grain_apis.md new file mode 100644 index 00000000000000..fc4768ce2ebc07 --- /dev/null +++ b/docs/source/torch.compiler_fine_grain_apis.md @@ -0,0 +1,108 @@ +(torchdynamo_fine_grain_tracing)= + +# TorchDynamo APIs for fine-grained tracing + +:::{note} +In this document `torch.compiler.compile` and `torch.compile` are used interchangeably. +Both versions will work in your code. +::: + +`torch.compile` performs TorchDynamo tracing on the whole user model. +However, it is possible that a small part of the model code cannot be +handled by `torch.compiler`. In this case, you might want to disable +the compiler on that particular portion, while running compilation on +the rest of the model. This section describe the existing APIs that +use to define parts of your code in which you want to skip compilation +and the relevant use cases. + +The API that you can use to define portions of the code on which you can +disable compilation are listed in the following table: + +```{eval-rst} +.. csv-table:: TorchDynamo APIs to control fine-grained tracing + :header: "API", "Description", "When to use?" + :widths: auto + + "``torch.compiler.disable``", "Disables Dynamo on the decorated function as well as recursively invoked functions.", "Excellent for unblocking a user, if a small portion of the model cannot be handled with ``torch.compile``." + "``torch._dynamo.disallow_in_graph``", "Disallows the marked op in the TorchDynamo graph. TorchDynamo causes graph break, and runs the op in the eager (no compile) mode.\n\nThis is suitable for the ops, while ``torch.compiler.disable`` is suitable for decorating functions.", "This API is excellent for both debugging and unblocking if a custom op like ``torch.ops.fbgemm.*`` is causing issues with the ``torch.compile`` function." + "``torch.compile.allow_in_graph``", "The annotated callable goes as is in the TorchDynamo graph. For example, a black-box for TorchDynamo Dynamo.\n\nNote that AOT Autograd will trace through it, so the ``allow_in_graph`` is only a Dynamo-level concept.", "This API is useful for portions of the model which have known TorchDynamo hard-to-support features, like hooks or ``autograd.Function``. However, each usage of ``allow_in_graph`` **must be carefully screened** (no graph breaks, no closures)." + "``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``." + "``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()." + "``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used." + "``torch.compiler.is_exporting``", "Indicates whether a graph is traced via export. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when torch.export is used." +``` + +## `torch.compiler.disable` + +`torch.compiler.disable` disables compilation on the decorated function frame and all the function frames recursively invoked from the decorated function frame. + +TorchDynamo intercepts the execution of each Python function frame. So, suppose you have a code structure (image below) where the function `fn` calls functions `a_fn` and `b_fn`. And `a_fn` calls `aa_fn` and `ab_fn`. When you use the PyTorch eager mode rather than `torch.compile`, these function frames run as is. With `torch.compile`, TorchDynamo intercepts each of these function frames (indicated by the green color): + +:::{figure} _static/img/fine_grained_apis/api_diagram.png +:alt: Callstack diagram of different apis. +::: + +Let's imagine, that function `a_fn` is causing troubles with `torch.compile`. +And this is a non-critical portion of the model. You can use `compiler.disable` +on function `a_fn`. As shown above, TorchDynamo will stop looking at frames +originating from the `a_fn` call (white color indicates original Python behavior). + +To skip compilation, you can decorate the offending function with +`@torch.compiler.disable`. + +You can also use the non-decorator syntax if you don’t want to change the source +code +However, we recommend that you avoid this style if possible. Here, you have to +take care that all users of the original function are now using the patched +version. + +## `torch._dynamo.disallow_in_graph` + +`torch._dynamo.disallow_in_graph` disallows an operator but not the function +to be present in the TorchDynamo extracted graph. Note that this is suitable +for operators and not general functions as in the case of `_dynamo.disable`. + +Let's imagine you compile your model with PyTorch. TorchDynamo is able to +extract a graph, but then you see the downstream compiler failing. For example, +the meta kernel is missing, or some Autograd dispatch key is set incorrectly +for a particular operator. Then you can mark that operator as +`disallow_in_graph`, and TorchDynamo will cause a graph break and run that +operator by using the PyTorch eager mode. + +The catch is that you will have to find the corresponding Dynamo level operator, +and not the ATen level operator. See more in the Limitations section of the doc. + +:::{warning} +`torch._dynamo.disallow_in_graph` is a global flag. If you are comparing +different backend compilers, you might have to call `allow_in_graph` for +the disallowed operator when switching to the other compiler. +::: + +## `torch.compiler.allow_in_graph` + +`torch.compiler.allow_in_graph` is useful when the relevant function frame +has some known hard-to-support TorchDynamo feature, such as hooks and +`autograd.Function`, and you are confident that downstream PyTorch components +such as AOTAutograd can safely trace through the decorated function. When a +function is decorated with `allow_in_graph`, TorchDynamo treats it as a +black-box and puts it as is in the generated graph. + +:::{warning} +`allow_in_graph` skips TorchDynamo completely on the decorated function +omitting all TorchDynamo safety checks, including graph breaks, handling +closures, and others. Use `allow_in_graph` with caution. PyTorch downstream +components, such as AOTAutograd rely on TorchDynamo to handle complex Python +features, but `allow_in_graph` bypasses TorchDynamo. Using `allow_in_graph` +could lead to soundness and hard-to-debug issues. +::: + +## Limitations + +All the existing APIs are applied at the TorchDynamo level. Therefore, these +APIs have visibility to only what TorchDynamo sees. This can lead to confusing +scenarios. + +For example, `torch._dynamo.disallow_in_graph` will not work for ATen operators +because they are visible to AOT Autograd. For example, +`torch._dynamo.disallow_in_graph(torch.ops.aten.add)` will not work in the +above example. diff --git a/docs/source/torch.compiler_fine_grain_apis.rst b/docs/source/torch.compiler_fine_grain_apis.rst deleted file mode 100644 index 7f61d88a26967c..00000000000000 --- a/docs/source/torch.compiler_fine_grain_apis.rst +++ /dev/null @@ -1,107 +0,0 @@ -.. _torchdynamo_fine_grain_tracing: - -TorchDynamo APIs for fine-grained tracing -========================================= - -.. note:: In this document ``torch.compiler.compile`` and - ``torch.compile`` are used interchangeably. Both versions - will work in your code. - -``torch.compile`` performs TorchDynamo tracing on the whole user model. -However, it is possible that a small part of the model code cannot be -handled by ``torch.compiler``. In this case, you might want to disable -the compiler on that particular portion, while running compilation on -the rest of the model. This section describe the existing APIs that -use to define parts of your code in which you want to skip compilation -and the relevant use cases. - -The API that you can use to define portions of the code on which you can -disable compilation are listed in the following table: - -.. csv-table:: TorchDynamo APIs to control fine-grained tracing - :header: "API", "Description", "When to use?" - :widths: auto - - "``torch.compiler.disable``", "Disables Dynamo on the decorated function as well as recursively invoked functions.", "Excellent for unblocking a user, if a small portion of the model cannot be handled with ``torch.compile``." - "``torch._dynamo.disallow_in_graph``", "Disallows the marked op in the TorchDynamo graph. TorchDynamo causes graph break, and runs the op in the eager (no compile) mode.\n\nThis is suitable for the ops, while ``torch.compiler.disable`` is suitable for decorating functions.", "This API is excellent for both debugging and unblocking if a custom op like ``torch.ops.fbgemm.*`` is causing issues with the ``torch.compile`` function." - "``torch.compile.allow_in_graph``", "The annotated callable goes as is in the TorchDynamo graph. For example, a black-box for TorchDynamo Dynamo.\n\nNote that AOT Autograd will trace through it, so the ``allow_in_graph`` is only a Dynamo-level concept.", "This API is useful for portions of the model which have known TorchDynamo hard-to-support features, like hooks or ``autograd.Function``. However, each usage of ``allow_in_graph`` **must be carefully screened** (no graph breaks, no closures)." - "``torch._dynamo.graph_break``", "Adds a graph break. The code before and after the graph break goes through TorchDynamo.", "**Rarely useful for deployment** - If you think you need this, most probably you need either ``disable`` or ``disallow_in_graph``." - "``torch.compiler.is_compiling``", "Indicates whether a graph is executed/traced as part of torch.compile() or torch.export()." - "``torch.compiler.is_dynamo_compiling``", "Indicates whether a graph is traced via TorchDynamo. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when TorchDynamo is used." - "``torch.compiler.is_exporting``", "Indicates whether a graph is traced via export. It's stricter than torch.compiler.is_compiling() flag, as it would only be set to True when torch.export is used." - -``torch.compiler.disable`` -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``torch.compiler.disable`` disables compilation on the decorated function frame and all the function frames recursively invoked from the decorated function frame. - -TorchDynamo intercepts the execution of each Python function frame. So, suppose you have a code structure (image below) where the function ``fn`` calls functions ``a_fn`` and ``b_fn``. And ``a_fn`` calls ``aa_fn`` and ``ab_fn``. When you use the PyTorch eager mode rather than ``torch.compile``, these function frames run as is. With ``torch.compile``, TorchDynamo intercepts each of these function frames (indicated by the green color): - -.. figure:: _static/img/fine_grained_apis/api_diagram.png - :alt: Callstack diagram of different apis. - -Let's imagine, that function ``a_fn`` is causing troubles with ``torch.compile``. -And this is a non-critical portion of the model. You can use ``compiler.disable`` -on function ``a_fn``. As shown above, TorchDynamo will stop looking at frames -originating from the ``a_fn`` call (white color indicates original Python behavior). - -To skip compilation, you can decorate the offending function with -``@torch.compiler.disable``. - -You can also use the non-decorator syntax if you don’t want to change the source -code -However, we recommend that you avoid this style if possible. Here, you have to -take care that all users of the original function are now using the patched -version. - -``torch._dynamo.disallow_in_graph`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``torch._dynamo.disallow_in_graph`` disallows an operator but not the function -to be present in the TorchDynamo extracted graph. Note that this is suitable -for operators and not general functions as in the case of ``_dynamo.disable``. - -Let's imagine you compile your model with PyTorch. TorchDynamo is able to -extract a graph, but then you see the downstream compiler failing. For example, -the meta kernel is missing, or some Autograd dispatch key is set incorrectly -for a particular operator. Then you can mark that operator as -``disallow_in_graph``, and TorchDynamo will cause a graph break and run that -operator by using the PyTorch eager mode. - -The catch is that you will have to find the corresponding Dynamo level operator, -and not the ATen level operator. See more in the Limitations section of the doc. - -.. warning:: - ``torch._dynamo.disallow_in_graph`` is a global flag. If you are comparing - different backend compilers, you might have to call ``allow_in_graph`` for - the disallowed operator when switching to the other compiler. - -``torch.compiler.allow_in_graph`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``torch.compiler.allow_in_graph`` is useful when the relevant function frame -has some known hard-to-support TorchDynamo feature, such as hooks and -``autograd.Function``, and you are confident that downstream PyTorch components -such as AOTAutograd can safely trace through the decorated function. When a -function is decorated with ``allow_in_graph``, TorchDynamo treats it as a -black-box and puts it as is in the generated graph. - -.. warning:: - ``allow_in_graph`` skips TorchDynamo completely on the decorated function - omitting all TorchDynamo safety checks, including graph breaks, handling - closures, and others. Use `allow_in_graph` with caution. PyTorch downstream - components, such as AOTAutograd rely on TorchDynamo to handle complex Python - features, but ``allow_in_graph`` bypasses TorchDynamo. Using ``allow_in_graph`` - could lead to soundness and hard-to-debug issues. - -Limitations -~~~~~~~~~~~ - -All the existing APIs are applied at the TorchDynamo level. Therefore, these -APIs have visibility to only what TorchDynamo sees. This can lead to confusing -scenarios. - -For example, ``torch._dynamo.disallow_in_graph`` will not work for ATen operators -because they are visible to AOT Autograd. For example, -``torch._dynamo.disallow_in_graph(torch.ops.aten.add)`` will not work in the -above example. diff --git a/docs/source/torch.compiler_get_started.md b/docs/source/torch.compiler_get_started.md new file mode 100644 index 00000000000000..adbc2184df250b --- /dev/null +++ b/docs/source/torch.compiler_get_started.md @@ -0,0 +1,148 @@ +(torch_compiler_get_started)= + +# Getting Started + +Before you read this section, make sure to read the {ref}`torch.compiler_overview` + +let's start by looking at a simple `torch.compile` example that demonstrates +how to use `torch.compile` for inference. This example demonstrates the +`torch.cos()` and `torch.sin()` features which are examples of pointwise +operators as they operate element by element on a vector. This example might +not show significant performance gains but should help you form an intuitive +understanding of how you can use `torch.compile` in your own programs. + +:::{note} +To run this script, you need to have at least one GPU on your machine. +If you do not have a GPU, you can remove the `.to(device="cuda:0")` code +in the snippet below and it will run on CPU. You can also set device to +`xpu:0` to run on Intel® GPUs. +::: + +```python +import torch +def fn(x): + a = torch.cos(x) + b = torch.sin(a) + return b +new_fn = torch.compile(fn, backend="inductor") +input_tensor = torch.randn(10000).to(device="cuda:0") +a = new_fn(input_tensor) +``` + +A more famous pointwise operator you might want to use would +be something like `torch.relu()`. Pointwise ops in eager mode are +suboptimal because each one would need to read a tensor from the +memory, make some changes, and then write back those changes. The single +most important optimization that inductor performs is fusion. In the +example above we can turn 2 reads (`x`, `a`) and +2 writes (`a`, `b`) into 1 read (`x`) and 1 write (`b`), which +is crucial especially for newer GPUs where the bottleneck is memory +bandwidth (how quickly you can send data to a GPU) rather than compute +(how quickly your GPU can crunch floating point operations). + +Another major optimization that inductor provides is automatic +support for CUDA graphs. +CUDA graphs help eliminate the overhead from launching individual +kernels from a Python program which is especially relevant for newer GPUs. + +TorchDynamo supports many different backends, but TorchInductor specifically works +by generating [Triton](https://github.com/openai/triton) kernels. Let's save +our example above into a file called `example.py`. We can inspect the code +generated Triton kernels by running `TORCH_COMPILE_DEBUG=1 python example.py`. +As the script executes, you should see `DEBUG` messages printed to the +terminal. Closer to the end of the log, you should see a path to a folder +that contains `torchinductor_`. In that folder, you can find +the `output_code.py` file that contains the generated kernel code similar to +the following: + +```python +@pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) +@triton.jit +def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 10000 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0) + tmp1 = tl.cos(tmp0) + tmp2 = tl.sin(tmp1) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) +``` + +:::{note} +The above code snippet is an example. Depending on your hardware, +you might see different code generated. +::: + +And you can verify that fusing the `cos` and `sin` did actually occur +because the `cos` and `sin` operations occur within a single Triton kernel +and the temporary variables are held in registers with very fast access. + +Read more on Triton's performance +[here](https://openai.com/blog/triton/). Because the code is written +in Python, it's fairly easy to understand even if you have not written all that +many CUDA kernels. + +Next, let's try a real model like resnet50 from the PyTorch +hub. + +```python +import torch +model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) +opt_model = torch.compile(model, backend="inductor") +opt_model(torch.randn(1,3,64,64)) +``` + +And that is not the only available backend, you can run in a REPL +`torch.compiler.list_backends()` to see all the available backends. Try out the +`cudagraphs` next as inspiration. + +## Using a pretrained model + +PyTorch users frequently leverage pretrained models from +[transformers](https://github.com/huggingface/transformers) or +[TIMM](https://github.com/rwightman/pytorch-image-models) and one of +the design goals is TorchDynamo and TorchInductor is to work out of the box with +any model that people would like to author. + +Let's download a pretrained model directly from the HuggingFace hub and optimize +it: + +```python +import torch +from transformers import BertTokenizer, BertModel +# Copy pasted from here https://huggingface.co/bert-base-uncased +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0") +model = torch.compile(model, backend="inductor") # This is the only line of code that we changed +text = "Replace me by any text you'd like." +encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0") +output = model(**encoded_input) +``` + +If you remove the `to(device="cuda:0")` from the model and +`encoded_input`, then Triton will generate C++ kernels that will be +optimized for running on your CPU. You can inspect both Triton or C++ +kernels for BERT. They are more complex than the trigonometry +example we tried above but you can similarly skim through it and see if you +understand how PyTorch works. + +Similarly, let's try out a TIMM example: + +```python +import timm +import torch +model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2) +opt_model = torch.compile(model, backend="inductor") +opt_model(torch.randn(64,3,7,7)) +``` + +## Next Steps + +In this section, we have reviewed a few inference examples and developed a +basic understanding of how torch.compile works. Here is what you check out next: + +- [torch.compile tutorial on training](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) +- {ref}`torch.compiler_api` +- {ref}`torchdynamo_fine_grain_tracing` \ No newline at end of file diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst deleted file mode 100644 index 7661c884177d67..00000000000000 --- a/docs/source/torch.compiler_get_started.rst +++ /dev/null @@ -1,148 +0,0 @@ -.. _torch.compiler_get_started: - -Getting Started -=============== - -Before you read this section, make sure to read the :ref:`torch.compiler_overview`. - -Let's start by looking at a simple ``torch.compile`` example that demonstrates -how to use ``torch.compile`` for inference. This example demonstrates the -``torch.cos()`` and ``torch.sin()`` features which are examples of pointwise -operators as they operate element by element on a vector. This example might -not show significant performance gains but should help you form an intuitive -understanding of how you can use ``torch.compile`` in your own programs. - -.. note:: - To run this script, you need to have at least one GPU on your machine. - If you do not have a GPU, you can remove the ``.to(device="cuda:0")`` code - in the snippet below and it will run on CPU. You can also set device to - ``xpu:0`` to run on Intel® GPUs. - -.. code:: python - - import torch - def fn(x): - a = torch.cos(x) - b = torch.sin(a) - return b - new_fn = torch.compile(fn, backend="inductor") - input_tensor = torch.randn(10000).to(device="cuda:0") - a = new_fn(input_tensor) - -A more famous pointwise operator you might want to use would -be something like ``torch.relu()``. Pointwise ops in eager mode are -suboptimal because each one would need to read a tensor from the -memory, make some changes, and then write back those changes. The single -most important optimization that inductor performs is fusion. In the -example above we can turn 2 reads (``x``, ``a``) and -2 writes (``a``, ``b``) into 1 read (``x``) and 1 write (``b``), which -is crucial especially for newer GPUs where the bottleneck is memory -bandwidth (how quickly you can send data to a GPU) rather than compute -(how quickly your GPU can crunch floating point operations). - -Another major optimization that inductor provides is automatic -support for CUDA graphs. -CUDA graphs help eliminate the overhead from launching individual -kernels from a Python program which is especially relevant for newer GPUs. - -TorchDynamo supports many different backends, but TorchInductor specifically works -by generating `Triton `__ kernels. Let's save -our example above into a file called ``example.py``. We can inspect the code -generated Triton kernels by running ``TORCH_COMPILE_DEBUG=1 python example.py``. -As the script executes, you should see ``DEBUG`` messages printed to the -terminal. Closer to the end of the log, you should see a path to a folder -that contains ``torchinductor_``. In that folder, you can find -the ``output_code.py`` file that contains the generated kernel code similar to -the following: - -.. code-block:: python - - @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) - @triton.jit - def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): - xnumel = 10000 - xoffset = tl.program_id(0) * XBLOCK - xindex = xoffset + tl.arange(0, XBLOCK)[:] - xmask = xindex < xnumel - x0 = xindex - tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0) - tmp1 = tl.cos(tmp0) - tmp2 = tl.sin(tmp1) - tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) - -.. note:: The above code snippet is an example. Depending on your hardware, - you might see different code generated. - -And you can verify that fusing the ``cos`` and ``sin`` did actually occur -because the ``cos`` and ``sin`` operations occur within a single Triton kernel -and the temporary variables are held in registers with very fast access. - -Read more on Triton's performance -`here `__. Because the code is written -in Python, it's fairly easy to understand even if you have not written all that -many CUDA kernels. - -Next, let's try a real model like resnet50 from the PyTorch -hub. - -.. code-block:: python - - import torch - model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) - opt_model = torch.compile(model, backend="inductor") - opt_model(torch.randn(1,3,64,64)) - -And that is not the only available backend, you can run in a REPL -``torch.compiler.list_backends()`` to see all the available backends. Try out the -``cudagraphs`` next as inspiration. - -Using a pretrained model -~~~~~~~~~~~~~~~~~~~~~~~~ - -PyTorch users frequently leverage pretrained models from -`transformers `__ or -`TIMM `__ and one of -the design goals is TorchDynamo and TorchInductor is to work out of the box with -any model that people would like to author. - -Let's download a pretrained model directly from the HuggingFace hub and optimize -it: - -.. code-block:: python - - import torch - from transformers import BertTokenizer, BertModel - # Copy pasted from here https://huggingface.co/bert-base-uncased - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0") - model = torch.compile(model, backend="inductor") # This is the only line of code that we changed - text = "Replace me by any text you'd like." - encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0") - output = model(**encoded_input) - -If you remove the ``to(device="cuda:0")`` from the model and -``encoded_input``, then Triton will generate C++ kernels that will be -optimized for running on your CPU. You can inspect both Triton or C++ -kernels for BERT. They are more complex than the trigonometry -example we tried above but you can similarly skim through it and see if you -understand how PyTorch works. - -Similarly, let's try out a TIMM example: - -.. code-block:: python - - import timm - import torch - model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2) - opt_model = torch.compile(model, backend="inductor") - opt_model(torch.randn(64,3,7,7)) - -Next Steps -~~~~~~~~~~ - -In this section, we have reviewed a few inference examples and developed a -basic understanding of how torch.compile works. Here is what you check out next: - -- `torch.compile tutorial on training `_ -- :ref:`torch.compiler_api` -- :ref:`torchdynamo_fine_grain_tracing` diff --git a/docs/source/torch.compiler_inductor_profiling.md b/docs/source/torch.compiler_inductor_profiling.md new file mode 100644 index 00000000000000..c8e69e836b957b --- /dev/null +++ b/docs/source/torch.compiler_inductor_profiling.md @@ -0,0 +1,171 @@ +# TorchInductor GPU Profiling + +This section lists useful commands and workflows that can help +you dive into a model’s performance in TorchInductor. When a model is not +running as fast as expected, you may want to check individual kernels of the +model. Usually, those kernels taking the majority of the +GPU time are the most interesting ones. After that, you +may also want to run individual kernels directly and inspect its perf. +PyTorch provides tools to cover everything mentioned above. + +## Relevant Environment Variables + +You can use the following environment variables in your analysis: + +- ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` + + - By default, TorchInductor names a Triton kernel as ``‘triton\_’``. When + this environmental variable is enabled, inductor generates a more + meaningful kernel name in the trace, for example, + ``triton_poi_fused_cat_155`` which contains the kernel category + (``poi`` for pointwise) and original ATen + operator. This config is disabled by default to improve the chance of + compilation cache hit. + +- ``TORCHINDUCTOR_BENCHMARK_KERNEL`` + + - Enabling this will make inductor codegen harness to benchmark + individual triton kernels. + +- ``TORCHINDUCTOR_MAX_AUTOTUNE`` + + - Inductor autotuner will benchmark more ``triton.Configs`` and pick the + one with the best performance results. This will increase compilation + time with the hope to improve performance. + +## Breakdown Model GPU Time + +Below are the steps to breakdown execution time of a model into +individual kernels. We take ``mixnet_l`` as an example. + +1. Run the benchmark script for the model: + + ```bash + TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 + python -u benchmarks/dynamo/timm_models.py –backend inductor –amp + –performance –dashboard –only mixnet_l –disable-cudagraphs –training + ``` + ```{note} + The tool relies on kernel name to decide its category. Enabling + ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` is crucial for that. + ``` +2. In the output log, look for lines: + + ```bash + **Compiled module path: + /tmp/torchinductor_shunting/qz/cqz7hvhood7y3psp7fy6msjxsxyli7qiwiybizdwtjw6ffyq5wwd.py** + ``` + +We have one line for each compiled module. If there are no extra graph +breaks, we would see 2 such lines in the log, one for the forward graph +and one for the backward graph. + +For our example command, we get the following compiled module for the +forward and backward graphs respectively: + +- [Forward graph compiled module](https://gist.github.com/shunting314/c2a4d8a28b00fcb5586d0e9d9bf77f9f) +- [Backward graph compiled module](https://gist.github.com/shunting314/48efc83b12ec3ead950052e4a0220b10) + +3. Now we can dive into the perf for each individual compiled module. + Let’s pick the one for the forward graph for illustration purposes. + I’ll name it ``fwd.py`` for convenience. Run it directly with the + ``-p`` argument: + + ```bash + **> python fwd.py -p** + ``` + +See the full output log in this [example gist](https://gist.github.com/shunting314/8243734a38b5733ea78479209c0ae893) + +In the output, you can notice the following: + +* We write a chrome trace file for the profile so we can load the trace and interact with it. In the log, look for lines as follows to find the path of the trace file. + + + **Chrome trace for the profile is written to /tmp/compiled_module_profile.json** + + Loading the trace into Chrome (visit chrome://tracing in the chrome browser and load the file as the UI suggested) will show UI as follows: + + ```{image} _static/img/inductor_profiling/trace.png + ``` + + You can zoom in and out to check the profile. + +* We report the percent of GPU time regarding to the wall time by log line like: + + **Percent of time when GPU is busy: 102.88%** + + Sometimes you may see a value larger than 100%. The reason is because PyTorch + uses the kernel execution time with profiling enabled while using wall time + with profiling disabled. Profiling may distort the kernel execution time a + bit. But overall it should not be a big deal. + + If we run the model like ``densenet121`` with a small batch size, we would see + low percent of time when GPU is busy: + + ```bash + (Forward graph) Percent of time when GPU is busy: 32.69% + ``` + + This means the model has a lot of CPU overhead. This is consistent with + the fact that enabling cudagraphs improve densenet121’s perf a lot. + +* We can break down the GPU time to different categories of kernels. + In the ``mixnet_l`` example, we see + + - pointwise kernel takes 28.58% + - reduction kernel takes 13.85% + - persistent reduction kernel takes 3.89% + - the rest are cutlass/cudnn kernels for mm/conv which takes 56.57% + + This information can be found in the summary line (last line) + of the report for each kernel category. + +* We also call zoom into a certain category of kernels. For example, + let’s check reduction kernels: + + ```{image} _static/img/inductor_profiling/kernel_breakdown.png + ``` + + We can see an ordered table of execution time for each individual + reduction kernel. We also see how many times a kernel is executed. This + is helpful for a few reasons: + + - If a kernel only takes a tiny amount of time, for example, 0.1%, + improving it will at most bring 0.1% overall gain. It is not + worth spending a lot of effort on it. + - Ff a kernel takes 2% of time, improving it by 2x will bring in 1% + overall gain which justifies the effort. + +## Benchmark Individual Triton Kernel + +Let’s say we want to take a closer look at +``triton_red_fused\__native_batch_norm_legit_functional_16`` which is the +most expensive reduction kernel and takes 2.19% of overall wall time for +the forward graph. + +We can lookup the kernel name in the ``fwd.py``, and find comment like: + +**# kernel path: +/tmp/torchinductor_shunting/jk/cjk2vm3446xrk7rth7hr6pun7xxo3dnzubwcn6ydrpifal4eykrz.py** + +```{image} _static/img/inductor_profiling/inductor_code.png +``` + +I’ll rename it k.py for convenience. Here is a paste for this [file](https://gist.github.com/shunting314/96a0afef9dce53d6357bf1633094f358). + +``k.py`` is a standalone Python module containing the kernel code and its +benchmark. + +Run ``k.py`` directly will report its execution time and bandwidth: + + ```{image} _static/img/inductor_profiling/terminal_printout.png + ``` + +We can check if max-autotune helps this kernel, by running: + +```bash + **TORCHINDUCTOR_MAX_AUTOTUNE=1 python /tmp/k.py** +``` +We may also temporarily add more reduction heuristics and run the script +again to check how that helps with the kernel. diff --git a/docs/source/torch.compiler_inductor_profiling.rst b/docs/source/torch.compiler_inductor_profiling.rst deleted file mode 100644 index fb060137b507d5..00000000000000 --- a/docs/source/torch.compiler_inductor_profiling.rst +++ /dev/null @@ -1,177 +0,0 @@ -.. _torchinductor-gpu-profiling: - -TorchInductor GPU Profiling -=========================== - -This section lists useful commands and workflows that can help -you dive into a model’s performance in TorchInductor. When a model is not -running as fast as expected, you may want to check individual kernels of the -model. Usually, those kernels taking the majority of the -GPU time are the most interesting ones. After that, you -may also want to run individual kernels directly and inspect its perf. -PyTorch provides tools to cover everything mentioned above. - -Relevant Environment Variables -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -You can use the following environment variables in your analysis: - -- ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` - - - By default, TorchInductor names a Triton kernel as ``‘triton\_’``. When - this environmental variable is enabled, inductor generates a more - meaningful kernel name in the trace, for example, - ``triton_poi_fused_cat_155`` which contains the kernel category - (``poi`` for pointwise) and original ATen - operator. This config is disabled by default to improve the chance of - compilation cache hit. - -- ``TORCHINDUCTOR_BENCHMARK_KERNEL`` - - - Enabling this will make inductor codegen harness to benchmark - individual triton kernels. - -- ``TORCHINDUCTOR_MAX_AUTOTUNE`` - - - Inductor autotuner will benchmark more ``triton.Configs`` and pick the - one with the best performance results. This will increase compilation - time with the hope to improve performance. - -Breakdown Model GPU Time -~~~~~~~~~~~~~~~~~~~~~~~~ - -Below are the steps to breakdown execution time of a model into -individual kernels. We take ``mixnet_l`` as an example. - -1. Run the benchmark script for the model: - - .. code-block:: bash - - TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 - python -u benchmarks/dynamo/timm_models.py –backend inductor –amp - –performance –dashboard –only mixnet_l –disable-cudagraphs –training - - .. note:: The tool relies on kernel name to decide its category. Enabling - ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` is crucial for that. - -2. In the output log, look for lines: - - .. code-block:: bash - - **Compiled module path: - /tmp/torchinductor_shunting/qz/cqz7hvhood7y3psp7fy6msjxsxyli7qiwiybizdwtjw6ffyq5wwd.py** - -We have one line for each compiled module. If there are no extra graph -breaks, we would see 2 such lines in the log, one for the forward graph -and one for the backward graph. - -For our example command, we get the following compiled module for the -forward and backward graphs respectively: - -- https://gist.github.com/shunting314/c2a4d8a28b00fcb5586d0e9d9bf77f9f -- https://gist.github.com/shunting314/48efc83b12ec3ead950052e4a0220b10 - -3. Now we can dive into the perf for each individual compiled module. - Let’s pick the one for the forward graph for illustration purposes. - I’ll name it ``fwd.py`` for convenience. Run it directly with the - ``-p`` argument: - - .. code-block:: bash - - **> python fwd.py -p** - -See the full output log in this -`example gist `__. - -In the output, you can notice the following: - -* We write a chrome trace file for the profile so we can load the trace and interact with it. In the log, look for lines as follows to find the path of the trace file. - - **Chrome trace for the profile is written to - /tmp/compiled_module_profile.json** - - Loading the trace into Chrome (visit chrome://tracing in the chrome - browser and load the file as the UI suggested) will show UI as follows: - - .. image:: _static/img/inductor_profiling/trace.png - - You can zoom in and out to check the profile. - -* We report the percent of GPU time regarding to the wall time by log line like: - - **Percent of time when GPU is busy: 102.88%** - - Sometimes you may see a value larger than 100%. The reason is because PyTorch - uses the kernel execution time with profiling enabled while using wall time - with profiling disabled. Profiling may distort the kernel execution time a - bit. But overall it should not be a big deal. - - If we run the model like ``densenet121`` with a small batch size, we would see - low percent of time when GPU is busy: - - :: - - (Forward graph) Percent of time when GPU is busy: 32.69% - - This means the model has a lot of CPU overhead. This is consistent with - the fact that enabling cudagraphs improve densenet121’s perf a lot. - -* We can break down the GPU time to different categories of kernels. - In the ``mixnet_l`` example, we see - - - pointwise kernel takes 28.58% - - reduction kernel takes 13.85% - - persistent reduction kernel takes 3.89% - - the rest are cutlass/cudnn kernels for mm/conv which takes 56.57% - - This information can be found in the summary line (last line) - of the report for each kernel category. - -* We also call zoom into a certain category of kernels. For example, - let’s check reduction kernels: - - .. image:: _static/img/inductor_profiling/kernel_breakdown.png - - We can see an ordered table of execution time for each individual - reduction kernel. We also see how many times a kernel is executed. This - is helpful for a few reasons: - - - If a kernel only takes a tiny amount of time, for example, 0.1%, - improving it will at most bring 0.1% overall gain. It is not - worth spending a lot of effort on it. - - Ff a kernel takes 2% of time, improving it by 2x will bring in 1% - overall gain which justifies the effort. - -Benchmark Individual Triton Kernel -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Let’s say we want to take a closer look at -``triton_red_fused\__native_batch_norm_legit_functional_16`` which is the -most expensive reduction kernel and takes 2.19% of overall wall time for -the forward graph. - -We can lookup the kernel name in the ``fwd.py``, and find comment like: - -**# kernel path: -/tmp/torchinductor_shunting/jk/cjk2vm3446xrk7rth7hr6pun7xxo3dnzubwcn6ydrpifal4eykrz.py** - -.. image:: _static/img/inductor_profiling/inductor_code.png - -I’ll rename it k.py for convenience. Here is a paste for this -`file `__. - -``k.py`` is a standalone Python module containing the kernel code and its -benchmark. - -Run ``k.py`` directly will report its execution time and bandwidth: - -.. image:: _static/img/inductor_profiling/terminal_printout.png - -We can check if max-autotune helps this kernel, by running: - -.. code-block:: bash - - **TORCHINDUCTOR_MAX_AUTOTUNE=1 python /tmp/k.py** - -We may also temporarily add more reduction heuristics and run the script -again to check how that helps with the kernel. diff --git a/docs/source/torch.compiler_ir.md b/docs/source/torch.compiler_ir.md new file mode 100644 index 00000000000000..ed920a064a68db --- /dev/null +++ b/docs/source/torch.compiler_ir.md @@ -0,0 +1,38 @@ +# IRs + +PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR. + +## Core Aten IR + +Core aten ops is the core subset of aten operators that can be used to compose other operators. +Core aten IR is fully functional, and there is no `inplace` or `_out` variants in this opset. +In contrast to Prims IR, core aten ops reuses the existing aten ops in "native_functions.yaml", +and it doesn't further decompose ops into explicit type promotion and broadcasting ops. +This opset is designed to serve as the functional IR to interface with backends. + +```{warning} + This opset is still under active development, more ops will be added in the future. +``` + +```{csv-table} + :file: ../build/ir/aten_ops.csv + :widths: auto + :header-rows: 1 +``` + +## Prims IR + +Prims IR is a set of primitive operators that can be used to compose other operators. +Prims IR is a lower level opset than core aten IR, and it further decomposes ops into explicit +type promotion and broadcasting ops: prims.convert_element_type and prims.broadcast_in_dim. +This opset is designed to interface with compiler backends. + +```{warning} + This opset is still under active development, more ops will be added in the future. +``` + +```{csv-table} + :file: ../build/ir/prims_ops.csv + :widths: auto + :header-rows: 1 +``` diff --git a/docs/source/torch.compiler_ir.rst b/docs/source/torch.compiler_ir.rst deleted file mode 100644 index 7da1e102756350..00000000000000 --- a/docs/source/torch.compiler_ir.rst +++ /dev/null @@ -1,39 +0,0 @@ -.. _torch.compiler_ir: - -IRs -=============== - -PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR. - -Core Aten IR --------------------- - -Core aten ops is the core subset of aten operators that can be used to compose other operators. -Core aten IR is fully functional, and there is no `inplace` or `_out` variants in this opset. -In contrast to Prims IR, core aten ops reuses the existing aten ops in "native_functions.yaml", -and it doesn't further decompose ops into explicit type promotion and broadcasting ops. -This opset is designed to serve as the functional IR to interface with backends. - -.. warning:: - This opset is still under active development, more ops will be added in the future. - -.. csv-table:: - :file: ../build/ir/aten_ops.csv - :widths: auto - :header-rows: 1 - -Prims IR ------------ - -Prims IR is a set of primitive operators that can be used to compose other operators. -Prims IR is a lower level opset than core aten IR, and it further decomposes ops into explicit -type promotion and broadcasting ops: prims.convert_element_type and prims.broadcast_in_dim. -This opset is designed to interface with compiler backends. - -.. warning:: - This opset is still under active development, more ops will be added in the future. - -.. csv-table:: - :file: ../build/ir/prims_ops.csv - :widths: auto - :header-rows: 1 diff --git a/docs/source/torch.compiler_nn_module.md b/docs/source/torch.compiler_nn_module.md new file mode 100644 index 00000000000000..a694e2c88dbd6c --- /dev/null +++ b/docs/source/torch.compiler_nn_module.md @@ -0,0 +1,59 @@ +# PyTorch 2.0 NNModule Support + +**Author**: [Will Constable](https://github.com/wconstab) + +`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces +arbitrary python classes, with the intent of producing faster code by making assumptions about the structure. + +This doc describes some of the tradeoffs or edge cases that come up due to this specialization. + +## NNModule Hooks Support + +Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered +they would simply be ignored in the compiled program. Indeed many users do not +use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases +for composing nn.Module hooks with `torch.compile`. + +Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`, +`forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'. +These hooks are partially supported by `torch.compile` with limitations described below. + +Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants, and are still +unsupported by `torch.compile`. + +## `nn.Module.__call__` Hooks Usage and limitations + +By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter +and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove +or alter the hooks later, your use case should be supported by default. + +Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo +occur when accessing backward_hooks dicts, which is probably avoiable with some work. Graph-breaks also impact the +timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at +the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would +still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward +ran. + +**hooks on 'allowed modules'** +`torch.compile` treats common modules such as torch.conv, as well as modules that are difficult to trace, specially +by allowing them to be called opaquely in the dynamo graph instead of traced into by dynamo. For such modules, hooks +currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could +introduce a significant performance regression, and additional work is required to improve this support. + +**skip_nnmodule_hook_guards** +By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed +on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing +if any hook dict is changed after compilation. + +If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately +(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added +guards. + +TODO: confirm if backward/pre_backward hooks are working or not and document accordingly + +## state_dict Hooks + +State dict hooks have not yet been supported in `torch.compile`. + + +TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present. \ No newline at end of file diff --git a/docs/source/torch.compiler_nn_module.rst b/docs/source/torch.compiler_nn_module.rst deleted file mode 100644 index 21a8e624a247a1..00000000000000 --- a/docs/source/torch.compiler_nn_module.rst +++ /dev/null @@ -1,60 +0,0 @@ -PyTorch 2.0 NNModule Support -============================ - -**Author**: `Will Constable `_ - -`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces -arbitrary python classes, with the intent of producing faster code by making assumptions about the structure. - -This doc describes some of the tradeoffs or edge cases that come up due to this specialization. - -NNModule Hooks Support ----------------------- -Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered -they would simply be ignored in the compiled program. Indeed many users do not -use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases -for composing nn.Module hooks with `torch.compile`. - -Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`, -`forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'. -These hooks are partially supported by `torch.compile` with limitations described below. - -Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants, and are still -unsupported by `torch.compile`. - -`nn.Module.__call__` Hooks Usage and limitations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter -and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove -or alter the hooks later, your use case should be supported by default. - -Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo -occur when accessing backward_hooks dicts, which is probably avoiable with some work. Graph-breaks also impact the -timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at -the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would -still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward -ran. - -**hooks on 'allowed modules'** -`torch.compile` treats common modules such as torch.conv, as well as modules that are difficult to trace, specially -by allowing them to be called opaquely in the dynamo graph instead of traced into by dynamo. For such modules, hooks -currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could -introduce a significant performance regression, and additional work is required to improve this support. - -**skip_nnmodule_hook_guards** -By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed -on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing -if any hook dict is changed after compilation. - -If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately -(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added -guards. - -TODO: confirm if backward/pre_backward hooks are working or not and document accordingly - -state_dict Hooks -~~~~~~~~~~~~~~~~ -State dict hooks have not yet been supported in `torch.compile`. - - -TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present. \ No newline at end of file diff --git a/docs/source/torch.compiler_performance_dashboard.md b/docs/source/torch.compiler_performance_dashboard.md new file mode 100644 index 00000000000000..468d70f0dd5ccd --- /dev/null +++ b/docs/source/torch.compiler_performance_dashboard.md @@ -0,0 +1,47 @@ +# PyTorch 2.0 Performance Dashboard + +**Author:** [Bin Bao](https://github.com/desertfire) and [Huy Do](https://github.com/huydhn) + +PyTorch 2.0's performance is tracked nightly on this [dashboard](https://hud.pytorch.org/benchmark/compilers). +The performance collection runs on 12 GCP A100 nodes every night. Each node contains a 40GB A100 Nvidia GPU and +a 6-core 2.2GHz Intel Xeon CPU. The corresponding CI workflow file can be found +[here](https://github.com/pytorch/pytorch/blob/main/.github/workflows/inductor-perf-test-nightly.yml). + +## How to read the dashboard? + +The landing page shows tables for all three benchmark suites we measure, ``TorchBench``, ``Huggingface``, and ``TIMM``, +and graphs for one benchmark suite with the default setting. For example, the default graphs currently show the AMP +training performance trend in the past 7 days for ``TorchBench``. Droplists on the top of that page can be +selected to view tables and graphs with different options. In addition to the pass rate, there are 3 key +performance metrics reported there: ``Geometric mean speedup``, ``Mean compilation time``, and +``Peak memory footprint compression ratio``. +Both ``Geometric mean speedup`` and ``Peak memory footprint compression ratio`` are compared against +the PyTorch eager performance, and the larger the better. Each individual performance number on those tables can be clicked, +which will bring you to a view with detailed numbers for all the tests in that specific benchmark suite. + +## What is measured on the dashboard? + +All the dashboard tests are defined in this +[function](https://github.com/pytorch/pytorch/blob/3e18d3958be3dfcc36d3ef3c481f064f98ebeaf6/.ci/pytorch/test.sh#L305). +The exact test configurations are subject to change, but at the moment, we measure both inference and training +performance with AMP precision on the three benchmark suites. We also measure different settings of TorchInductor, +including ``default``, ``with_cudagraphs (default + cudagraphs)``, and ``dynamic (default + dynamic_shapes)``. + +## Can I check if my PR affects TorchInductor's performance on the dashboard before merging? + +Individual dashboard runs can be triggered manually by clicking the ``Run workflow`` button +[here](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml) +and submitting with your PR's branch selected. This will kick off a whole dashboard run with your PR's changes. +Once it is done, you can check the results by selecting the corresponding branch name and commit ID +on the performance dashboard UI. Be aware that this is an expensive CI run. With the limited +resources, please use this functionality wisely. + +## How can I run any performance test locally? + +The exact command lines used during a complete dashboard run can be found in any recent CI run logs. +The [workflow page](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml) +is a good place to look for logs from some of the recent runs. +In those logs, you can search for lines like +`python benchmarks/dynamo/huggingface.py --performance --cold-start-latency --inference --amp --backend inductor --disable-cudagraphs --device cuda` +and run them locally if you have a GPU working with PyTorch 2.0. +``python benchmarks/dynamo/huggingface.py -h`` will give you a detailed explanation on options of the benchmarking script. diff --git a/docs/source/torch.compiler_performance_dashboard.rst b/docs/source/torch.compiler_performance_dashboard.rst deleted file mode 100644 index 8704365e00741f..00000000000000 --- a/docs/source/torch.compiler_performance_dashboard.rst +++ /dev/null @@ -1,52 +0,0 @@ -PyTorch 2.0 Performance Dashboard -================================= - -**Author:** `Bin Bao `__ and `Huy Do `__ - -PyTorch 2.0's performance is tracked nightly on this `dashboard `__. -The performance collection runs on 12 GCP A100 nodes every night. Each node contains a 40GB A100 Nvidia GPU and -a 6-core 2.2GHz Intel Xeon CPU. The corresponding CI workflow file can be found -`here `__. - -How to read the dashboard? ---------------------------- - -The landing page shows tables for all three benchmark suites we measure, ``TorchBench``, ``Huggingface``, and ``TIMM``, -and graphs for one benchmark suite with the default setting. For example, the default graphs currently show the AMP -training performance trend in the past 7 days for ``TorchBench``. Droplists on the top of that page can be -selected to view tables and graphs with different options. In addition to the pass rate, there are 3 key -performance metrics reported there: ``Geometric mean speedup``, ``Mean compilation time``, and -``Peak memory footprint compression ratio``. -Both ``Geometric mean speedup`` and ``Peak memory footprint compression ratio`` are compared against -the PyTorch eager performance, and the larger the better. Each individual performance number on those tables can be clicked, -which will bring you to a view with detailed numbers for all the tests in that specific benchmark suite. - -What is measured on the dashboard? ------------------------------------ - -All the dashboard tests are defined in this -`function `__. -The exact test configurations are subject to change, but at the moment, we measure both inference and training -performance with AMP precision on the three benchmark suites. We also measure different settings of TorchInductor, -including ``default``, ``with_cudagraphs (default + cudagraphs)``, and ``dynamic (default + dynamic_shapes)``. - -Can I check if my PR affects TorchInductor's performance on the dashboard before merging? ------------------------------------------------------------------------------------------ - -Individual dashboard runs can be triggered manually by clicking the ``Run workflow`` button -`here `__ -and submitting with your PR's branch selected. This will kick off a whole dashboard run with your PR's changes. -Once it is done, you can check the results by selecting the corresponding branch name and commit ID -on the performance dashboard UI. Be aware that this is an expensive CI run. With the limited -resources, please use this functionality wisely. - -How can I run any performance test locally? --------------------------------------------- - -The exact command lines used during a complete dashboard run can be found in any recent CI run logs. -The `workflow page `__ -is a good place to look for logs from some of the recent runs. -In those logs, you can search for lines like -``python benchmarks/dynamo/huggingface.py --performance --cold-start-latency --inference --amp --backend inductor --disable-cudagraphs --device cuda`` -and run them locally if you have a GPU working with PyTorch 2.0. -``python benchmarks/dynamo/huggingface.py -h`` will give you a detailed explanation on options of the benchmarking script. diff --git a/docs/source/torch.compiler_profiling_torch_compile.md b/docs/source/torch.compiler_profiling_torch_compile.md new file mode 100644 index 00000000000000..885b43dc2eeffd --- /dev/null +++ b/docs/source/torch.compiler_profiling_torch_compile.md @@ -0,0 +1,252 @@ +# Profiling to understand torch.compile performance + +## What to use torch.profiler for: + +torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and resources utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance. + +To understand kernel-level performance, other tools exist, such as [Nvidia Nsight compute tool](https://developer.nvidia.com/nsight-compute), [AMD Omnitrace](https://rocm.docs.amd.com/projects/omnitrace/en/latest/), Intel® VTune™ Profiler or [inductor's profiling tools](https://docs.pytorch.org/docs/stable/torch.compiler_inductor_profiling.html#torchinductor-gpu-profiling) can be used. + +See also the [general pytorch profiler guide](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html). + +## Basics of using torch.profiler and viewing traces + +**Example program**: We'll use this example of profiling resnet18. Notice the following parts of this example program: + +* Include a warm-up run to wait for compilation to complete (this will warm up systems like the CUDA caching allocator) +* Use `torch.profiler.profile()` context for profiling the section we are interested in +* Use `prof.export_chrome_trace("trace.json")` to export the profiling artifact. + +```python + + import torch + from torchvision.models import resnet18 + + device = 'cuda' # or 'cpu', 'xpu', etc. + model = resnet18().to(device) + + inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)] + + model_c = torch.compile(model) + + def fwd_bwd(inp): + out = model_c(inp) + out.sum().backward() + + # warm up + fwd_bwd(inputs[0]) + + with torch.profiler.profile() as prof: + for i in range(1, 4): + fwd_bwd(inputs[i]) + prof.step() + + prof.export_chrome_trace("trace.json") +``` + +**Viewing chrome traces**: In the Chrome browser, open chrome://tracing and load the json file. Use the “w” and “s” keys to zoom in and out, and use “a” and “d” to scroll left and right. “?” will show a “help” screen with a list of shortcuts. + +```{figure} _static/img/profiling_torch_compile/basic_chrome_trace.png +:alt: Example of a basic chrome trace, visualized in the chrome://tracing viewer +``` + +Here, we observe: +* CompiledFunction and CompiledFunctionBackward events, which correspond to the dynamo-compiled regions. +* CPU events at the top, and GPU events at the bottom. + +**Flows between CPU and accelerator events** + +Every kernel on the accelerator occurs after being launched by code running on the CPU. The profiler can draw connections (i.e. “flows”) between the accelerator and CPU events to show which CPU event launched a accelerator kernel. This is particularly helpful because, with a few exceptions, accelerator kernels are launched asynchronously. + +To view a flow connection, click on a GPU kernel and click “ac2g”: + +```{figure} _static/img/profiling_torch_compile/ac2g.png +:alt: Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location. +``` + +Alternatively, turn on *all* flows with the “Flow events” dropdown at the top. + +## Working around CUDA Graph profiling issues + +When CUDA graphs are enabled, some CUDA configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: + +```python + + import torch + + torch.profiler._utils._init_for_cuda_graphs() + + # ... rest of program +``` + +## Understanding compilation time + +To understand why compilation is taking a long time, you can profile the first invocation of a torch.compile-ed program. Keep in mind that profile traces of compilations can be distorted more than typical profiling, because compilation workloads can be quite different from typical PyTorch workloads. In some cases, trace files may also be quite large. Traces > 1GB can be difficult to open with the chrome tracing tool. + +Note: roughly the same information can also be obtained in non-graphical format with :code:`torch._dynamo.utils.compile_times()`. This utility won’t show when the compilation steps occur, but it will show the amount of time spent on each step - and times will not be affected by any profiling overhead. + +See an example below: + +```python + + import torch + from torchvision.models import resnet18 + + # user can switch between cuda and xpu + device = 'cuda' + model = resnet18().to(device) + inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)] + + model_c = torch.compile(model) + + def fwd_bwd(inp): + out = model_c(inp) + out.sum().backward() + + def warmup_compile(): + def fn(x): + return x.sin().relu() + + x = torch.rand((2, 2), device=device, requires_grad=True) + fn_c = torch.compile(fn) + out = fn_c(x) + out.sum().backward() + + with torch.profiler.profile() as prof: + with torch.profiler.record_function("warmup compile"): + warmup_compile() + + with torch.profiler.record_function("resnet18 compile"): + fwd_bwd(inputs[0]) + + prof.export_chrome_trace("trace_compile.json") +``` + +```{figure} _static/img/profiling_torch_compile/compilation_profiling.png +:alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps +``` + +Note a few things: + +* The first invocation should occur *during* profiling in order to capture compilation +* Add a warm-up compilation in order to initialize any systems that need to be lazily initialized. + +# Finding graph breaks: "Torch-Compiled Region" and "CompiledFunction" + +Although there are logging tools for identifying graph breaks, the profiler provides a quick visual method of identifying :ref:`graph breaks `. There are two profiler events to look for: **Torch-Compiled Region** and **CompiledFunction**. + +**Torch-Compiled Region** - which was introduced in PyTorch 2.2 - is a profiler event that covers the entire compiled region. Graph breaks almost always look the same: nested “Torch-Compiled Region” events. + +If you run two separate functions with torch.compile() applied independently on each of them, you should generally expect to see two adjacent (i.e NOT stacked/nested) Torch-Compiled regions. Meanwhile, if you encounter graph breaks (or disable()'ed/skipped regions), expect nested “Torch-Compiled Region” events. + +**CompiledFunction** - introduced in PyTorch 2.0 - is a profiler event that appears when gradients are required for any inputs. Each graph break will interrupt a CompiledFunction block, splitting it in two. CompiledFunction events only appear when Autograd is involved, i.e. some of the input tensors to the graph have requires_grad=True. + +When a CompiledFunction appears in a trace, it is typically paired with a CompiledFunctionBackward event in the backward pass. A “fwd-bwd link” should appear in the trace connecting the two, if the backward function is called. + +If your use case includes a graph that doesn't require grad and doesn't include "Torch-Compiled Region" events, it can be more difficult to identify whether torch.compile is being applied correctly. One clue can be the existence of Inductor-generated Triton kernels. + +See the synthetic example below for a demonstration: + +```python + + import torch + import torch._dynamo + # user can switch between cuda and xpu + device = 'cuda' + + class ModelWithBreaks(torch.nn.Module): + def __init__(self): + super().__init__() + def create_sequential(): + return torch.nn.Sequential( + torch.nn.Linear(128, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 128), + torch.nn.ReLU(), + ) + self.mod1 = create_sequential() + self.mod2 = create_sequential() + self.mod3 = create_sequential() + self.mod4 = create_sequential() + + def forward(self, inp): + mod1 = self.mod1(inp) + torch._dynamo.graph_break() + mod2 = self.mod2(mod1) + torch._dynamo.graph_break() + mod3 = self.mod3(mod2) + torch._dynamo.graph_break() + mod4 = self.mod4(mod3) + return mod4 + + model = ModelWithBreaks().to(device) + inputs = [torch.randn((128, 128), device=device) for _ in range(10)] + + model_c = torch.compile(model) + + def fwd_bwd(inp): + out = model_c(inp) + out.sum().backward() + + # warm up + fwd_bwd(inputs[0]) + + with torch.profiler.profile() as prof: + for i in range(1, 4): + fwd_bwd(inputs[i]) + prof.step() + + prof.export_chrome_trace("trace_break.json") +``` + +```{figure} _static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png +:alt: Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks. +``` + +## Operator Kernels + +When an operator is launched, we expect to see a few events: + +1. CPU-side event +2. Kernel launch (if dealing with a GPU kernel) +3. GPU-side event + +```{figure} _static/img/profiling_torch_compile/kernel_launch_labeled.png +:alt: Visualization in the chrome://trace viewer, showing the three types of events - CPU-side event, kernel launch, and GPU-side event +``` + +**Inductor-generated Triton kernels:** +1. The **CPU-side event** should appear as an event prefixed with "triton\_". The events currently have minimal information - the kernel name and a launch, but less information than typical aten kernel launches (which contain input shapes, types, etc.). +2. The **kernel launch** should appear as cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) +3. The **GPU-side event** should appear, and how descriptive the name will be depends on the inductor config for unique_kernel_names + +```{figure} _static/img/profiling_torch_compile/triton_kernel_launch.png +``` + +**Non-Inductor generated Triton kernels:** + +1. The **CPU-side** event may not appear in traces; the machinery for automatically inserting a profiler event is currently implemented at the Inductor level, so Triton kernels that bypass Inductor may not appear in traces, unless users have annotated them manually +2. The **kernel launch** should appear s cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) +3. The **GPU-side** event should appear, named similarly to the triton kernel that was authored. + +```{figure} _static/img/profiling_torch_compile/noninductor_triton_kernel.png +``` + +**Inductor-generated CPU kernels:** + +1. The **CPU-side event** will not appear in traces; we haven't added profiling for this yet. +2. The **kernel launch** and **GPU-side events** don't exist + +**Non-Triton kernels** (i.e. aten kernels or custom ops) should also be expected to sometimes appear in traces. Sometimes, Inductor will fall back to the original op implementation, in which case you will see a call to the aten op. + + +## Launch overhead + +One common issue is bad GPU utilization. A quick way to identify this is if there are large gaps between kernels on the GPU: + +```{figure} _static/img/profiling_torch_compile/cpu_bound.png +:alt: Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches. +``` + +This is often the result of CPU overhead, e.g. if the amount of time spent on the CPU between kernel launches is larger than the amount of time spent by the GPU to process the kernels. The issue is more common for small batch sizes. + +When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern. \ No newline at end of file diff --git a/docs/source/torch.compiler_profiling_torch_compile.rst b/docs/source/torch.compiler_profiling_torch_compile.rst deleted file mode 100644 index 4462e921848f85..00000000000000 --- a/docs/source/torch.compiler_profiling_torch_compile.rst +++ /dev/null @@ -1,249 +0,0 @@ -Profiling to understand torch.compile performance -================================================= - -What to use torch.profiler for: -------------------------------- - -torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and resources utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance. - -To understand kernel-level performance, other tools exist, such as `Nvidia Nsight compute tool `_, `AMD Omnitrace `_, Intel® VTune™ Profiler or :ref:`inductor's profiling tools ` can be used. - -See also the `general pytorch profiler guide `_. - -Basics of using torch.profiler and viewing traces -------------------------------------------------- - -**Example program**: We'll use this example of profiling resnet18. Notice the following parts of this example program: - -* Include a warm-up run to wait for compilation to complete (this will warm up systems like the CUDA caching allocator) -* Use :code:`torch.profiler.profile()` context for profiling the section we are interested in -* Use :code:`prof.export_chrome_trace("trace.json")` to export the profiling artifact. - -.. code-block:: python - - import torch - from torchvision.models import resnet18 - - device = 'cuda' # or 'cpu', 'xpu', etc. - model = resnet18().to(device) - - inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)] - - model_c = torch.compile(model) - - def fwd_bwd(inp): - out = model_c(inp) - out.sum().backward() - - # warm up - fwd_bwd(inputs[0]) - - with torch.profiler.profile() as prof: - for i in range(1, 4): - fwd_bwd(inputs[i]) - prof.step() - - prof.export_chrome_trace("trace.json") - -**Viewing chrome traces**: In the Chrome browser, open chrome://tracing and load the json file. Use the “w” and “s” keys to zoom in and out, and use “a” and “d” to scroll left and right. “?” will show a “help” screen with a list of shortcuts. - -.. figure:: _static/img/profiling_torch_compile/basic_chrome_trace.png - :alt: Example of a basic chrome trace, visualized in the chrome://tracing viewer - -Here, we observe: -* CompiledFunction and CompiledFunctionBackward events, which correspond to the dynamo-compiled regions. -* CPU events at the top, and GPU events at the bottom. - -**Flows between CPU and accelerator events** - -Every kernel on the accelerator occurs after being launched by code running on the CPU. The profiler can draw connections (i.e. “flows”) between the accelerator and CPU events to show which CPU event launched a accelerator kernel. This is particularly helpful because, with a few exceptions, accelerator kernels are launched asynchronously. - -To view a flow connection, click on a GPU kernel and click “ac2g”: - -.. figure:: _static/img/profiling_torch_compile/ac2g.png - :alt: Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location. - -Alternatively, turn on *all* flows with the “Flow events” dropdown at the top. - -Working around CUDA Graph profiling issues ------------------------------------------- - -When CUDA graphs are enabled, some CUDA configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: - -.. code-block:: python - - import torch - - torch.profiler._utils._init_for_cuda_graphs() - - # ... rest of program - -Understanding compilation time ------------------------------- - -To understand why compilation is taking a long time, you can profile the first invocation of a torch.compile-ed program. Keep in mind that profile traces of compilations can be distorted more than typical profiling, because compilation workloads can be quite different from typical PyTorch workloads. In some cases, trace files may also be quite large. Traces > 1GB can be difficult to open with the chrome tracing tool. - -Note: roughly the same information can also be obtained in non-graphical format with :code:`torch._dynamo.utils.compile_times()`. This utility won’t show when the compilation steps occur, but it will show the amount of time spent on each step - and times will not be affected by any profiling overhead. - -See an example below: - -.. code-block:: python - - import torch - from torchvision.models import resnet18 - - # user can switch between cuda and xpu - device = 'cuda' - model = resnet18().to(device) - inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)] - - model_c = torch.compile(model) - - def fwd_bwd(inp): - out = model_c(inp) - out.sum().backward() - - def warmup_compile(): - def fn(x): - return x.sin().relu() - - x = torch.rand((2, 2), device=device, requires_grad=True) - fn_c = torch.compile(fn) - out = fn_c(x) - out.sum().backward() - - with torch.profiler.profile() as prof: - with torch.profiler.record_function("warmup compile"): - warmup_compile() - - with torch.profiler.record_function("resnet18 compile"): - fwd_bwd(inputs[0]) - - prof.export_chrome_trace("trace_compile.json") - -.. figure:: _static/img/profiling_torch_compile/compilation_profiling.png - :alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps - - -Note a few things: - -* The first invocation should occur *during* profiling in order to capture compilation -* Add a warm-up compilation in order to initialize any systems that need to be lazily initialized. - -Finding graph breaks: "Torch-Compiled Region" and "CompiledFunction" --------------------------------------------------------------------- - -Although there are logging tools for identifying graph breaks, the profiler provides a quick visual method of identifying :ref:`graph breaks `. There are two profiler events to look for: **Torch-Compiled Region** and **CompiledFunction**. - -**Torch-Compiled Region** - which was introduced in PyTorch 2.2 - is a profiler event that covers the entire compiled region. Graph breaks almost always look the same: nested “Torch-Compiled Region” events. - -If you run two separate functions with torch.compile() applied independently on each of them, you should generally expect to see two adjacent (i.e NOT stacked/nested) Torch-Compiled regions. Meanwhile, if you encounter graph breaks (or disable()'ed/skipped regions), expect nested “Torch-Compiled Region” events. - -**CompiledFunction** - introduced in PyTorch 2.0 - is a profiler event that appears when gradients are required for any inputs. Each graph break will interrupt a CompiledFunction block, splitting it in two. CompiledFunction events only appear when Autograd is involved, i.e. some of the input tensors to the graph have requires_grad=True. - -When a CompiledFunction appears in a trace, it is typically paired with a CompiledFunctionBackward event in the backward pass. A “fwd-bwd link” should appear in the trace connecting the two, if the backward function is called. - -If your use case includes a graph that doesn't require grad and doesn't include "Torch-Compiled Region" events, it can be more difficult to identify whether torch.compile is being applied correctly. One clue can be the existence of Inductor-generated Triton kernels. - -See the synthetic example below for a demonstration: - -.. code-block:: python - - import torch - import torch._dynamo - # user can switch between cuda and xpu - device = 'cuda' - - class ModelWithBreaks(torch.nn.Module): - def __init__(self): - super().__init__() - def create_sequential(): - return torch.nn.Sequential( - torch.nn.Linear(128, 128), - torch.nn.ReLU(), - torch.nn.Linear(128, 128), - torch.nn.ReLU(), - ) - self.mod1 = create_sequential() - self.mod2 = create_sequential() - self.mod3 = create_sequential() - self.mod4 = create_sequential() - - def forward(self, inp): - mod1 = self.mod1(inp) - torch._dynamo.graph_break() - mod2 = self.mod2(mod1) - torch._dynamo.graph_break() - mod3 = self.mod3(mod2) - torch._dynamo.graph_break() - mod4 = self.mod4(mod3) - return mod4 - - model = ModelWithBreaks().to(device) - inputs = [torch.randn((128, 128), device=device) for _ in range(10)] - - model_c = torch.compile(model) - - def fwd_bwd(inp): - out = model_c(inp) - out.sum().backward() - - # warm up - fwd_bwd(inputs[0]) - - with torch.profiler.profile() as prof: - for i in range(1, 4): - fwd_bwd(inputs[i]) - prof.step() - - prof.export_chrome_trace("trace_break.json") - -.. figure:: _static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png - :alt: Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks. - -Operator Kernels ----------------- - -When an operator is launched, we expect to see a few events: - -1. CPU-side event -2. Kernel launch (if dealing with a GPU kernel) -3. GPU-side event - -.. figure:: _static/img/profiling_torch_compile/kernel_launch_labeled.png - :alt: Visualization in the chrome://trace viewer, showing the three types of events: CPU-side event, kernel launch, and GPU-side event - -**Inductor-generated Triton kernels:** -1. The **CPU-side event** should appear as an event prefixed with "triton\_". The events currently have minimal information - the kernel name and a launch, but less information than typical aten kernel launches (which contain input shapes, types, etc.). -2. The **kernel launch** should appear as cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) -3. The **GPU-side event** should appear, and how descriptive the name will be depends on the inductor config for unique_kernel_names - -.. figure:: _static/img/profiling_torch_compile/triton_kernel_launch.png - -**Non-Inductor generated Triton kernels:** - -1. The **CPU-side** event may not appear in traces; the machinery for automatically inserting a profiler event is currently implemented at the Inductor level, so Triton kernels that bypass Inductor may not appear in traces, unless users have annotated them manually -2. The **kernel launch** should appear s cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops) -3. The **GPU-side** event should appear, named similarly to the triton kernel that was authored. - -.. figure:: _static/img/profiling_torch_compile/noninductor_triton_kernel.png - -**Inductor-generated CPU kernels:** - -1. The **CPU-side event** will not appear in traces; we haven't added profiling for this yet. -2. The **kernel launch** and **GPU-side events** don't exist - -**Non-Triton kernels** (i.e. aten kernels or custom ops) should also be expected to sometimes appear in traces. Sometimes, Inductor will fall back to the original op implementation, in which case you will see a call to the aten op. - - -Launch overhead ---------------- - -One common issue is bad GPU utilization. A quick way to identify this is if there are large gaps between kernels on the GPU: - -.. figure:: _static/img/profiling_torch_compile/cpu_bound.png - :alt: Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches. - -This is often the result of CPU overhead, e.g. if the amount of time spent on the CPU between kernel launches is larger than the amount of time spent by the GPU to process the kernels. The issue is more common for small batch sizes. - -When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern. diff --git a/docs/source/torch.compiler_transformations.md b/docs/source/torch.compiler_transformations.md new file mode 100644 index 00000000000000..7291df298f3748 --- /dev/null +++ b/docs/source/torch.compiler_transformations.md @@ -0,0 +1,424 @@ +# Writing Graph Transformations on ATen IR + +## Passes + +Since the ATen IR sits at the FX Graph/GraphModule level, any +transformations written for FX Graphs can be easily applied onto the +ATen IR. If you’re familiar with writing FX graph transformations, then +this will be the same. + +The most direct way of writing transformations is by looping through the +given graph and directly manipulating the nodes within the graph. + +For example, let’s say we want to replace +`torch.ops.aten.add.Tensor()` calls with +`torch.ops.aten.mul.Tensor()` calls: + +```python +import torch + +def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: + node.target = torch.ops.aten.mul.Tensor +``` + +We can also delete and append new nodes through FX utility functions +that can be found in the +[Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) +documentation. For example, if we want to insert a +`torch.ops.aten.relu.default()` after the `add` call: + +```python +import torch + +def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: + + # Specifies the insertion point. Any nodes added to the graph within + # this scope will be inserted after `node` + with gm.graph.inserting_after(node): + # Insert a new `call_function` node with op `torch.ops.aten.relu.default` + new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,)) + # Replace all the places that use `node` to now use the `new_relu_node` + node.replace_all_uses_with(new_relu_node) +``` + +In general, transformations can be roughly categorized into a couple of +axis: + +Axis A: 1. Creating one-to-X mapping (eg. decomposition) 2. Creating +many-to-one mapping (eg. fusion) + +Axis B: 1. Doing forwards iteration (eg. shape propagation) 2. Doing +backwards iteration (eg. dead code elimination) + +Axis C: 1. Dependent on local node information (eg. out-variant +conversion) 2. Dependent on global graph information (eg. memory +planning) + +Our projection on the frequency of these use cases are: 1. A.1, B.1, C.1 +2\. A.2 3. B.2, C.2 + +Although we can make all graph transformations through directly +manipulating the graph, we also provide some helper utilities for some +ease of use for the level 1 and 2 use-cases. + +### Transformer + +For level 1 uses cases (creating one-to-X mappings, doing forwards +iterations, and looking at local node information), we can utilize the +[Transformer](https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer) +class to execute each node and recreate a graph, except with the +transformations specified. + +#### One-to-One Pass + +An example for one-to-one mappings, if we wanted to replace an op A with +another op B, we can run the GraphModule, and very time we see op A, +return op B. + +An example is: + +```python +class ReplaceAddWithMul(torch.fx.Transformer): + def call_function(self, target, args, kwargs): + if target != torch.ops.aten.add.Tensor: + return super().call_function(target, args, kwargs) + return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs) + +transformed_graph_module = ReplaceAddWithMul(graph_module).transform() +``` + +The `super().call_function(target, args, kwargs, meta)` call creates a +`call_function` FX node, and returns the result of running the +operator with the given arguments. + +#### One-to-X Pass + +If we wanted to do one-to-X mappings, like replacing op A with 2 other +ops B and C, we would then make 2 calls to `super().call_function` to +create 2 FX nodes, one with op B and another with op C, and return the +result of running op C. + +For example: + +```python +class ReplaceAddWithMulSub(torch.fx.Transformer): + """ + Original: + def f(x, y): + return x + y + + After pass: + def f(x, y): + z = x * y + return z - y + """ + def call_function(self, target, args, kwargs): + if target != torch.ops.aten.add.Tensor: + return super().call_function(target, args, kwargs) + + x, y = args + + mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {}) + return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {}) + +transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform() +``` + +#### One-to-None Pass + +If we wanted to remove an op, we can just return the value passed into +the function: + +```python +class RemoveDetachPass(torch.fx.Transformer): + def call_function(self, target, args, kwargs): + if target not in ( + torch.ops.aten.detach.default, + torch.ops.aten.detach_copy.default, + ): + return super().call_function(target, args, kwargs, meta) + + assert len(args) == 1 + return args[0] + +transformed_graph_module = RemoveDetachPass(graph_module).transform() +``` + +#### Utilizing Local Information + +An example of utilizing local node information is, if we wanted to +convert all the scalars within the graph to tensors, we can run the +given `fx.GraphModule`, and for every argument that contains a scalar, +we convert it to a tensor. It might look something like: + +```python +def args_map(target, fn, args, kwargs): + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + args = list(args) + kwargs = kwargs.copy() + + # Update the argument based on the function passed + def update(key, args, schema): + args[key] = fn(args[key], schema) + + # Update each argument in the schema + for i, schema in enumerate(target._schema.arguments): + if schema.name in kwargs: + update(schema.name, kwargs, schema) + elif not schema.kwarg_only and i < len(args): + update(i, args, schema) + return tuple(args), kwargs + +class ScalarToTensorPass(torch.fx.Transformer): + def call_function(self, target, args, kwargs): + breakpoint() + def try_coerce(value, arg): + return ( + torch.tensor(value) + if isinstance(value, (float, int, bool)) + and type(arg.type) == torch.TensorType + else value + ) + + args, kwargs = args_map(target, try_coerce, args, kwargs) + return super().call_function(target, args, kwargs) + +transformed_graph_module = ScalarToTensorPass(graph_module).transform() +``` + +### Subgraph Rewriter + +For creating many-to-one mappings, we can utilize FX’s [subgraph +rewriter](https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py). +Given a `pattern`, it creates a subgraph of operators matching to the +pattern, and then replaces each matched subgraph with the +`replacement`. + +Note: + +``` +This is an inplace operation. +``` + +The `pattern` and `replacement` inputs must be callable functions or +GraphModules containing the same operators that are used within the +graph (ATen ops) so that the subgraph rewriter can find the correct +pattern in the graph. Inputs to the pattern/replacement callables will +be treated as wildcards when matching. + +An example: + +```python +from torch.fx import subgraph_rewriter + +def replace_patterns(graph_module): + def pattern(x, y): + x = torch.ops.aten.add.Tensor(x, y) + x = torch.ops.aten.mul.Tensor(x, y) + return x + + def replacement(x, y): + return torch.ops.aten.sub.Tensor(x, y) + +replaced_patterns = subgraph_rewriter.replace_pattern_with_filters( + traced_module, pattern, replacement +) +``` + +The subgraph rewriter returns a list of `ReplacedPatterns`: + +```python +@dataclass +class ReplacedPatterns: + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + # List of nodes that were added into the graph + replacements: List[Node] +``` + +Note: + +``` +The nodes created by the subgraph rewriter will not have the metadata that +is populated in the matched nodes, but you can use +`ReplacedPatterns.nodes_map` to find the nodes in the original graph that +were matched, and `ReplacedPatterns.replacements` to find the nodes that +were replaced in the transformed graph. +``` + +## Pass Manager + +The +[PassManager](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py) +is a class used to run multiple passes on a given graph module. When +initializing a `PassManager` instance, we pass in a list of passes +that we want to run and set a couple of flags. To run the collection of +passes on a graph module, we can pass the graph module directly to the +`PassManager` instance. + +An example: + +```python +from torch.fx.passes.infra.pass_manager import PassManager + +pm = PassManager( + passes=[replace_add_with_div, replace_div_with_mul], + run_checks_after_each_pass=True, + suppress_check_failures=False, +) +graph_module_out = pm(graph_module) +``` + +To add a common set of checks that are run after each pass, we can call +the function `set_checks(check: Callable)` which takes in a callable +function as input. If the `run_checks_after_each_pass` flag is set, +the `check` will be called after each pass is run on the graph module. + +An example: + +```python +pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul]) + +def check_div_target(graph_module): + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target != torch.div: + raise ValueError("Target should be div!") + +pm.add_checks(check_div_target) + +pm(graph_module) # raises ValueError after replace_div_with_mul pass +``` + +## Partitioner + +There are a couple of common FX graph based partitioners we can use to +partition the graph. + +### Subgraph Matcher + +For finding subgraphs within a graph that match a specific pattern, we +can utilize FX’s +[`SubgraphMatcher`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py). + +Class Attributes: + +- `pattern (Graph)`: The targeted matching pattern. Placeholder nodes + in the graph will be treated as wildcards when matching. +- `match_output (bool)`: If True, output node in the pattern graph + will be treated as a part of the targeted pattern. If False, output + node is ignored during match. +- `match_placeholder (bool)`: If True, placeholder node in the + pattern graph will be treated as a part of the targeted pattern. If + False, placeholder nodes will be used a wildcard. +- `remove_overlapping_matches (bool)`: If True, in the case of + overlapping matches, only the first match will be returned. +- `ignore_literals (bool)`: If True, will not check if literals are + equal and will instead treat them as wildcards. + +An example: + +```python +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher + +class LargeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self._weight = torch.nn.Parameter(torch.ones(3, 3)) + self._bias = torch.nn.Parameter(torch.ones(3, 3)) + + def forward(self, x): + return torch.ops.aten.addmm.default(self._bias, x, self._weight) + +large_model_graph = torch.export(LargeModel(), inputs).graph + +class PatternModel(torch.nn.Module): + def __init__(self): + super().__init__() + self._weight_1 = torch.nn.Parameter(torch.ones(5, 5)) + self._bias_1 = torch.nn.Parameter(torch.ones(5, 5)) + + def forward(self, x): + return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1) + +pattern_graph = torch.export(PatternModel(), inputs).graph + +subgraph_matcher = SubgraphMatcher(pattern_graph) +match_result = subgraph_matcher.match(large_model_graph) +``` + +The `match` function returns a list of `InternalMatch`: + +```python +@dataclass +class InternalMatch(): + # Nodes from which the match was found + anchors: List[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] = field(default_factory=dict) + # Nodes in target graph that are matched placeholder in pattern + placeholder_nodes: List[Node] = field(default_factory=list) + # Nodes in matched subgraph returned by output + returning_nodes: List[Node] = field(default_factory=list) +``` + +### Capability Based Partitioner + +To find the largest subgraphs of nodes that support a specific +invariant, we can utilize FX’s +[`CapabilityBasedPartitioner`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34). + +Class Attributes + +- `graph_module (torch.fx.GraphModule)`: The graph module we are + partitioning on. +- `operator_support (OperatorSupportBase)`: The object used to + determine if a node in the graph is supported in the partition. +- `allows_single_node_partition (bool)`: If True, allows single node + partitions to be formed. +- `non_compute_ops (Optional[Sequence[str]])`: A set of ops that are + considered to be “non-compute” (ex `torch.ops.aten.view` and + `_operator.getitem`, so that the partitioner will not create graphs + that only contain these non-compute ops +- `allowed_single_node_partition_ops (Optional[Sequence[str]])`: A + set of ops that are allowed to be in a single node partition. + +The +[`OperatorSupportBase`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1) +class is used by the partitioner to determine if a specific node in the +graph belongs in the partition. This is done by overriding the +`is_node_supported` function. You can chain multiple +`OperatorSupportBase` by using +[`chain`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150) (which +returns False if any of the OperatorSupportBase return False) and +[`any_chain`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164) +(which returns True if any of the OperatorSupportBase returns True). + +An example: + +```python +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase + +class AddMulOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor, + ] + +capability_partitioner = CapabilityBasedPartitioner( + graph_module, + op_support, +) + +# Returns a list of partitions (list of nodes that belong in each partition) +partition_list = capability_partitioner.propose_partitions() +# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph +fused_graph_module = capability_partitioner.fuse_partitions(partition_list) +``` diff --git a/docs/source/torch.compiler_transformations.rst b/docs/source/torch.compiler_transformations.rst deleted file mode 100644 index 4cc15b7b2339f3..00000000000000 --- a/docs/source/torch.compiler_transformations.rst +++ /dev/null @@ -1,436 +0,0 @@ -Writing Graph Transformations on ATen IR -======================================== - -Passes ------- - -Since the ATen IR sits at the FX Graph/GraphModule level, any -transformations written for FX Graphs can be easily applied onto the -ATen IR. If you’re familiar with writing FX graph transformations, then -this will be the same. - -The most direct way of writing transformations is by looping through the -given graph and directly manipulating the nodes within the graph. - -For example, let’s say we want to replace -``torch.ops.aten.add.Tensor()`` calls with -``torch.ops.aten.mul.Tensor()`` calls: - -.. code:: python - - import torch - - def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: - node.target = torch.ops.aten.mul.Tensor - -We can also delete and append new nodes through FX utility functions -that can be found in the -`Graph `__ -documentation. For example, if we want to insert a -``torch.ops.aten.relu.default()`` after the ``add`` call: - -.. code:: python - - import torch - - def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: - - # Specifies the insertion point. Any nodes added to the graph within - # this scope will be inserted after `node` - with gm.graph.inserting_after(node): - # Insert a new `call_function` node with op `torch.ops.aten.relu.default` - new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,)) - # Replace all the places that use `node` to now use the `new_relu_node` - node.replace_all_uses_with(new_relu_node) - -In general, transformations can be roughly categorized into a couple of -axis: - -Axis A: 1. Creating one-to-X mapping (eg. decomposition) 2. Creating -many-to-one mapping (eg. fusion) - -Axis B: 1. Doing forwards iteration (eg. shape propagation) 2. Doing -backwards iteration (eg. dead code elimination) - -Axis C: 1. Dependent on local node information (eg. out-variant -conversion) 2. Dependent on global graph information (eg. memory -planning) - -Our projection on the frequency of these use cases are: 1. A.1, B.1, C.1 -2. A.2 3. B.2, C.2 - -Although we can make all graph transformations through directly -manipulating the graph, we also provide some helper utilities for some -ease of use for the level 1 and 2 use-cases. - -Transformer -~~~~~~~~~~~ - -For level 1 uses cases (creating one-to-X mappings, doing forwards -iterations, and looking at local node information), we can utilize the -`Transformer `__ -class to execute each node and recreate a graph, except with the -transformations specified. - -One-to-One Pass -^^^^^^^^^^^^^^^ - -An example for one-to-one mappings, if we wanted to replace an op A with -another op B, we can run the GraphModule, and very time we see op A, -return op B. - -An example is: - -.. code:: python - - class ReplaceAddWithMul(torch.fx.Transformer): - def call_function(self, target, args, kwargs): - if target != torch.ops.aten.add.Tensor: - return super().call_function(target, args, kwargs) - return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs) - - transformed_graph_module = ReplaceAddWithMul(graph_module).transform() - -The ``super().call_function(target, args, kwargs, meta)`` call creates a -``call_function`` FX node, and returns the result of running the -operator with the given arguments. - -One-to-X Pass -^^^^^^^^^^^^^ - -If we wanted to do one-to-X mappings, like replacing op A with 2 other -ops B and C, we would then make 2 calls to ``super().call_function`` to -create 2 FX nodes, one with op B and another with op C, and return the -result of running op C. - -For example: - -.. code:: python - - class ReplaceAddWithMulSub(torch.fx.Transformer): - """ - Original: - def f(x, y): - return x + y - - After pass: - def f(x, y): - z = x * y - return z - y - """ - def call_function(self, target, args, kwargs): - if target != torch.ops.aten.add.Tensor: - return super().call_function(target, args, kwargs) - - x, y = args - - mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {}) - return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {}) - - transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform() - -One-to-None Pass -^^^^^^^^^^^^^^^^ - -If we wanted to remove an op, we can just return the value passed into -the function: - -.. code:: python - - class RemoveDetachPass(torch.fx.Transformer): - def call_function(self, target, args, kwargs): - if target not in ( - torch.ops.aten.detach.default, - torch.ops.aten.detach_copy.default, - ): - return super().call_function(target, args, kwargs, meta) - - assert len(args) == 1 - return args[0] - - transformed_graph_module = RemoveDetachPass(graph_module).transform() - -Utilizing Local Information -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -An example of utilizing local node information is, if we wanted to -convert all the scalars within the graph to tensors, we can run the -given ``fx.GraphModule``, and for every argument that contains a scalar, -we convert it to a tensor. It might look something like: - -.. code:: python - - def args_map(target, fn, args, kwargs): - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - args = list(args) - kwargs = kwargs.copy() - - # Update the argument based on the function passed - def update(key, args, schema): - args[key] = fn(args[key], schema) - - # Update each argument in the schema - for i, schema in enumerate(target._schema.arguments): - if schema.name in kwargs: - update(schema.name, kwargs, schema) - elif not schema.kwarg_only and i < len(args): - update(i, args, schema) - return tuple(args), kwargs - - class ScalarToTensorPass(torch.fx.Transformer): - def call_function(self, target, args, kwargs): - breakpoint() - def try_coerce(value, arg): - return ( - torch.tensor(value) - if isinstance(value, (float, int, bool)) - and type(arg.type) == torch.TensorType - else value - ) - - args, kwargs = args_map(target, try_coerce, args, kwargs) - return super().call_function(target, args, kwargs) - - transformed_graph_module = ScalarToTensorPass(graph_module).transform() - -Subgraph Rewriter -~~~~~~~~~~~~~~~~~ - -For creating many-to-one mappings, we can utilize FX’s `subgraph -rewriter `__. -Given a ``pattern``, it creates a subgraph of operators matching to the -pattern, and then replaces each matched subgraph with the -``replacement``. - -Note: - -:: - - This is an inplace operation. - -The ``pattern`` and ``replacement`` inputs must be callable functions or -GraphModules containing the same operators that are used within the -graph (ATen ops) so that the subgraph rewriter can find the correct -pattern in the graph. Inputs to the pattern/replacement callables will -be treated as wildcards when matching. - -An example: - -.. code:: python - - from torch.fx import subgraph_rewriter - - def replace_patterns(graph_module): - def pattern(x, y): - x = torch.ops.aten.add.Tensor(x, y) - x = torch.ops.aten.mul.Tensor(x, y) - return x - - def replacement(x, y): - return torch.ops.aten.sub.Tensor(x, y) - - replaced_patterns = subgraph_rewriter.replace_pattern_with_filters( - traced_module, pattern, replacement - ) - -The subgraph rewriter returns a list of ``ReplacedPatterns``: - -.. code:: python - - @dataclass - class ReplacedPatterns: - # Node from which the match was found - anchor: Node - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] - # List of nodes that were added into the graph - replacements: List[Node] - -Note: - -:: - - The nodes created by the subgraph rewriter will not have the metadata that - is populated in the matched nodes, but you can use - `ReplacedPatterns.nodes_map` to find the nodes in the original graph that - were matched, and `ReplacedPatterns.replacements` to find the nodes that - were replaced in the transformed graph. - -Pass Manager ------------- - -The -```PassManager`` `__ -is a class used to run multiple passes on a given graph module. When -initializing a ``PassManager`` instance, we pass in a list of passes -that we want to run and set a couple of flags. To run the collection of -passes on a graph module, we can pass the graph module directly to the -``PassManager`` instance. - -An example: - -.. code:: python - - from torch.fx.passes.infra.pass_manager import PassManager - - pm = PassManager( - passes=[replace_add_with_div, replace_div_with_mul], - run_checks_after_each_pass=True, - suppress_check_failures=False, - ) - graph_module_out = pm(graph_module) - -To add a common set of checks that are run after each pass, we can call -the function ``set_checks(check: Callable)`` which takes in a callable -function as input. If the ``run_checks_after_each_pass`` flag is set, -the ``check`` will be called after each pass is run on the graph module. - -An example: - -.. code:: python - - pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul]) - - def check_div_target(graph_module): - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target != torch.div: - raise ValueError("Target should be div!") - - pm.add_checks(check_div_target) - - pm(graph_module) # raises ValueError after replace_div_with_mul pass - -Partitioner ------------ - -There are a couple of common FX graph based partitioners we can use to -partition the graph. - -Subgraph Matcher -~~~~~~~~~~~~~~~~ - -For finding subgraphs within a graph that match a specific pattern, we -can utilize FX’s -```SubgraphMatcher`` `__. - -Class Attributes: - -- ``pattern (Graph)``: The targeted matching pattern. Placeholder nodes - in the graph will be treated as wildcards when matching. -- ``match_output (bool)``: If True, output node in the pattern graph - will be treated as a part of the targeted pattern. If False, output - node is ignored during match. -- ``match_placeholder (bool)``: If True, placeholder node in the - pattern graph will be treated as a part of the targeted pattern. If - False, placeholder nodes will be used a wildcard. -- ``remove_overlapping_matches (bool)``: If True, in the case of - overlapping matches, only the first match will be returned. -- ``ignore_literals (bool)``: If True, will not check if literals are - equal and will instead treat them as wildcards. - -An example: - -.. code:: python - - from torch.fx.passes.utils.matcher_utils import SubgraphMatcher - - class LargeModel(torch.nn.Module): - def __init__(self): - super().__init__() - self._weight = torch.nn.Parameter(torch.ones(3, 3)) - self._bias = torch.nn.Parameter(torch.ones(3, 3)) - - def forward(self, x): - return torch.ops.aten.addmm.default(self._bias, x, self._weight) - - large_model_graph = torch.export(LargeModel(), inputs).graph - - class PatternModel(torch.nn.Module): - def __init__(self): - super().__init__() - self._weight_1 = torch.nn.Parameter(torch.ones(5, 5)) - self._bias_1 = torch.nn.Parameter(torch.ones(5, 5)) - - def forward(self, x): - return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1) - - pattern_graph = torch.export(PatternModel(), inputs).graph - - subgraph_matcher = SubgraphMatcher(pattern_graph) - match_result = subgraph_matcher.match(large_model_graph) - -The ``match`` function returns a list of ``InternalMatch``: - -.. code:: python - - @dataclass - class InternalMatch(): - # Nodes from which the match was found - anchors: List[Node] - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] = field(default_factory=dict) - # Nodes in target graph that are matched placeholder in pattern - placeholder_nodes: List[Node] = field(default_factory=list) - # Nodes in matched subgraph returned by output - returning_nodes: List[Node] = field(default_factory=list) - -Capability Based Partitioner -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To find the largest subgraphs of nodes that support a specific -invariant, we can utilize FX’s -```CapabilityBasedPartitioner`` `__. - -Class Attributes - -- ``graph_module (torch.fx.GraphModule)``: The graph module we are - partitioning on. -- ``operator_support (OperatorSupportBase)``: The object used to - determine if a node in the graph is supported in the partition. -- ``allows_single_node_partition (bool)``: If True, allows single node - partitions to be formed. -- ``non_compute_ops (Optional[Sequence[str]])``: A set of ops that are - considered to be “non-compute” (ex ``torch.ops.aten.view`` and - ``_operator.getitem``, so that the partitioner will not create graphs - that only contain these non-compute ops -- ``allowed_single_node_partition_ops (Optional[Sequence[str]])``: A - set of ops that are allowed to be in a single node partition. - -The -```OperatorSupportBase`` `__ -class is used by the partitioner to determine if a specific node in the -graph belongs in the partition. This is done by overriding the -``is_node_supported`` function. You can chain multiple -``OperatorSupportBase`` by using -```chain`` `__\ (which -returns False if any of the OperatorSupportBase return False) and -```any_chain`` `__ -(which returns True if any of the OperatorSupportBase returns True). - -An example: - -.. code:: python - - from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner - from torch.fx.passes.operator_support import any_chain, OperatorSupportBase - - class AddMulOperatorSupport(OperatorSupportBase): - def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return node.op == "call_function" and node.target in [ - torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor, - ] - - capability_partitioner = CapabilityBasedPartitioner( - graph_module, - op_support, - ) - - # Returns a list of partitions (list of nodes that belong in each partition) - partition_list = capability_partitioner.propose_partitions() - # Fuses the partitions into graph modules and inserts `call_module` nodes in the graph - fused_graph_module = capability_partitioner.fuse_partitions(partition_list) diff --git a/docs/source/torch.compiler_troubleshooting.md b/docs/source/torch.compiler_troubleshooting.md new file mode 100644 index 00000000000000..041d61cf9b9019 --- /dev/null +++ b/docs/source/torch.compiler_troubleshooting.md @@ -0,0 +1,1083 @@ +(torch.compiler_troubleshooting)= + +# torch.compile Troubleshooting + +You're trying to use `torch.compile` on your PyTorch model to enhance its performance +but it's not working as expected. Perhaps performance isn't improving, crashes are happening, or compilation time is too long. This article provides tips, workarounds, and debugging tools to help you overcome these challenges. + +**Contents** + +```{contents} +:local: true +``` + +## Setting Expectations + +`torch.compile` is designed as a general-purpose PyTorch compiler. +Unlike the previous compiler solution, TorchScript, `torch.compile` +requires fewer code changes, meaning models typically don't need to be rewritten from scratch. +It also manages unsupported code more gracefully - unsupported code results in a lost optimization opportunity rather than a crash. + +In the ideal world, one can simply apply `torch.compile` to any PyTorch model and enjoy automatic speedups. +However, in reality, code complexities can lead to one of three scenarios: + +1. `torch.compile` works seamlessly, providing speedups. +2. Some code modifications are necessary. `torch.compile` doesn't crash or take too long, + but you might not be seeing significant performance gains. +3. Extensive changes to your code are required. + +We anticipate most code will fall under scenarios (1) and (2). +This document provides tips, arranged by level of involvement, to help address code issues in scenario (2). + +### Compile times + +`torch.compile` functions as a just-in-time compiler, so the initial one or two runs +of the compiled function are expected to be significantly slower. Recompilations, which can occur under certain conditions (detailed below), +will also make runs slower. Various `torch.compile` components cache results to +reduce compilation time for future invocations, even in different processes. +Cold-start (uncached) compilation time typically ranges from seconds to minutes for common or benchmarked models. +Larger models may take upwards of 30 minutes to a few hours. + +## Terminology + +The following terms are relevant to troubleshooting `torch.compile` problems. + +### Graph break + +`torch.compile` traces your code and attempts to capture your PyTorch code into a +single computation graph of PyTorch operators (FX graph). However, this is not always possible. +When encountering code that can't be traced, a "graph break" occurs. +A graph break involves compiling the FX graph has been determined so far, running the unsupported code, +then resuming tracing after the unsupported code with a new FX graph. +Because the computation graph is broken up, we lose optimization opportunities, +so model code should avoid graph breaks whenever possible. +Graph breaks occur on things like: + +- Data-dependent if-statements +- Many Python built-in functions +- C functions + +Below is an example of a graph break due to the function `copy.deepcopy` from a Python builtin library +(exact output may differ). + +```py +import torch + +@torch.compile +def fn(x): + x = x + 1 + with open("test.txt", "r") as f: + return x + len(f.read()) + +fn(torch.ones(3, 3)) +``` + +``` +$TORCH_LOGS="graph_breaks" python playground.py +Graph break in user code at /data/users/williamwen/pytorch/playground.py:7 +Reason: Unsupported: builtin: open [, ] False +User code traceback: +File "/data/users/williamwen/pytorch/playground.py", line 7, in fn + with open("test.txt", "r") as f: +Traceback (most recent call last): +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL + self._call(inst) +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call + self.call_function(fn, args, kwargs) +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function + self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function + return handler(tx, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in + return lambda *args: unimplemented(error_msg) + ^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented + raise Unsupported(msg, case_name=case_name) +torch._dynamo.exc.Unsupported: builtin: open [, ] False +``` + +### Guards + +`torch.compile` makes some assumptions about runtime values as we trace through code. +During tracing, we generate "guards", which are runtime checks for these assumptions. +Guards are run in future calls to the compiled function to determine if we can reuse previously compiled code. +Examples of runtime checks are constant values, types, and object IDs. + +Below is an example of generated guards. The `TENSOR_MATCH` guard checks for the input's type, device, dtype, shape, etc. + +```py +import torch + +@torch.compile +def fn(x): + return x + 1 + +fn(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="guards" python playground.py +GUARDS: + +TREE_GUARD_MANAGER: ++- RootGuardManager +| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:471 in init_ambient_guards +| +- GLOBAL_STATE: ___check_global_state() +| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack() +| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x) +| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # playground.py:6 in fn +| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # playground.py:6 in fn +``` + +### Recompilation + +If the guards fail for every instance of previously compiled code, +then `torch.compile` must "recompile" the function, requiring the original code to be traced again. + +In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed. + +```py +import torch + +@torch.compile +def fn(x): + return x + 1 + +fn(torch.ones(3, 3)) +fn(torch.ones(4, 4)) +``` + +``` +$ TORCH_LOGS="recompiles" python playground.py +Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3 + triggered by the following guard failure(s): + - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4 +``` + +### Dynamic Shapes + +`torch.compile` initially assumes tensor shapes are static/constant and guards based on these assumptions. +By using "dynamic shapes," we can get `torch.compile` to produce compiled code that can accept +tensor inputs with different shapes - we avoid recompiling every time shapes differ. +By default, automatic dynamic shapes are enabled `torch.compile(dynamic=None)` - +if compilation fails due to shape mismatch, recompilation is attempted with dynamic shapes. +Dynamic shapes can also be fully enabled `dynamic=True` or disabled `dynamic=False`. + +Below, we enable dynamic shapes and note that we no longer need to recompile. + +```py +import torch + +@torch.compile(dynamic=True) +def fn(x): + return x + 1 + +fn(torch.ones(3, 3)) +fn(torch.ones(4, 4)) +``` + +``` +$ TORCH_LOGS="dynamic,recompiles" python playground.py +create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" +produce_guards +produce_guards +``` + +For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng). + +## Logging Tools + +### tlparse / TORCH_TRACE + +`tlparse` / `TORCH_TRACE` are a pair of tools that produce compilation reports that look like this: +. + +Traces are very easy to collect. To collect a trace, run your reproduction command with + +``` +TORCH_TRACE="/tmp/tracedir" python foo.py +pip install tlparse +tlparse /tmp/tracedir +``` + +This approach works even if you are running a distributed job, providing a trace for each rank. +It will open your browser with HTML similar to what's generated above. +If you are making a bug report for a complicated problem that you don't have a standalone reproduction for, +you can still greatly assist PyTorch developers by attaching the trace log generated in `/tmp/tracedir`. + +```{warning} +The trace log contains all of your model code. +Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights. +``` + +```{raw} html + +``` + +```{eval-rst} +.. role:: red + +.. role:: green + +.. role:: dark-green +``` + +The output of `tlparse` is primarily aimed for PyTorch developers, +and the log format is easy to upload and share on GitHub. +However, as a non-PyTorch developer, you can still extract useful information from it. +We recommend starting with the inline help text in the report, which explains its contents. +Here are some insights you can gain from a `tlparse`: + +- What model code was compiled by looking at the stack trie? + This is especially useful if you're not familiar with the codebase being compiled! +- How many graph breaks / distinct compilation regions are there? + (Each distinct compile is its own color coded block like {dark-green}`[0/0]`). + Frames that are potentially graph-broken are light green {green}`[2/4]`. + If there are a lot of frames, that is suspicious, and suggests that you had some catastrophic graph breaks, + or maybe your code isn't a good match for `torch.compile`. +- How many times did I recompile a particular frame? Something that recompiled a lot will look like: + {dark-green}`[10/0]` {dark-green}`[10/1]` {dark-green}`[10/2]` + \- if something is being recompiled a lot, that is very suspicious and worth looking into, even if it isn't the root cause of your problem. +- Was there a compilation error? Frames that errored will look like {red}`[0/1]`. +- What intermediate compiler products did I generate for a given frame? + For example, you can look at the high-level generated FX graph or the generated Triton code. +- Is there relevant information for a particular frame? You can find these in `compilation_metrics`. + +### TORCH_LOGS + +You can use the `TORCH_LOGS` environment variable to selectively enable parts of the `torch.compile` stack to log. +`TORCH_LOGS` is in fact the source of logs for `tlparse`. The format of the `TORCH_LOGS` environment variable looks like this: + +``` +TORCH_LOGS=",,..." python foo.py +``` + +Useful high-level options include: + +- `graph_breaks`: logs locations of graph breaks in user code and the reason for the graph break +- `guards`: logs guards that are generated +- `recompiles`: logs which function recompiled and the guards that failed, leading to the recompilation +- `dynamic`: logs related to dynamic shapes + +Also, you can programmatically set logging options using `torch._logging.set_logs`: + +```py +import logging +torch._logging.set_logs(graph_breaks=True) +... +``` + +More `TORCH_LOGS` options are {ref}`troubleshooting-torch-logs-options`. +For the full list of options, see [torch.\_logging](https://pytorch.org/docs/stable/logging.html) +and [torch.\_logging.set_logs](https://pytorch.org/docs/stable/generated/torch._logging.set_logs.html#torch._logging.set_logs). + +### tlparse vs. TORCH_LOGS + +Generally, we suggest first using `tlparse` when encountering issues. +`tlparse` is ideal for debugging large models and gaining a high-level overview of how your model was compiled. +On the other hand, `TORCH_LOGS` is preferred for small examples and fine-grained debugging detail, +when we already have an idea of which `torch.compile` component is causing the problem. + +## Simple Workarounds + +Here, we describe some workarounds to `torch.compile` issues involving small code modifications +or changing some `torch.compile` settings. + +### Where to apply torch.compile? + +We recommend applying `torch.compile` to the highest-level function that doesn't cause excessive problems. +Typically, it is your train or eval step with the optimizer but without the loop, your top-level `nn.Module`, +or some sub-``` nn.Module``s. ``torch.compile ``` specifically doesn't handle distributed wrapper modules like +DDP or FSDP very well, so consider applying `torch.compile` to the inner module passed to the wrapper. + +```py +# inference +model = ... +opt_model = torch.compile(model) + +for _ in range(N_ITERS): + inp = ... + out = opt_model(inp) +``` + +```py +# training +model = ... +opt = torch.optim.Adam(model.parameters()) + +@torch.compile +def train(mod, data): + opt.zero_grad(True) + pred = mod(data[0]) + loss = torch.nn.CrossEntropyLoss()(pred, data[1]) + loss.backward() + opt.step() + +for _ in range(N_ITERS): + inp = ... + train(model, inp) +``` + +```py +# DistributedDataParallel +model = ... +opt_model = torch.compile(model) +model_ddp = DistributedDataParallel(opt_model, ...) + +for _ in range(N_ITERS): + inp = ... + out = model_ddp(inp) +``` + +### Disabling and Suppressing Errors + +For some model architectures, there are portions of the model which are particularly difficult to compile +\- either there are many graph breaks, or there are crashes. You may want to explicitly disable these +portions of the model which are problematic so that you can apply `torch.compile` to the parts that work. +You can do this by using the `@torch.compiler.disable` decorator. When `torch.compile` attempts to call a +disabled function, it breaks the graph and skips tracing the disabled function, resuming tracing after the call. +By default, all recursive calls made from a disabled function are also disabled. Use the `recursive=False` +option to allow compilation for recursive calls. + +```py +def bad1_inner(...): + # skipped + +@torch.compiler.disable +def bad1_outer(...): + # skipped + bad1_inner(...) + +def bad2_inner(...) + # traced + +@torch.compiler.disable(recursive=False) +def bad2_outer(...): + # skipped + bad2_inner(...) + +@torch.compile +def fn(...): + # graph break + bad1_outer(...) + ... + # graph break + bad2_outer(...) +``` + +For example, we use `torch.compiler.disable` to disable `torch.compile` on sparse architecture in +recommendation models, as the sparse arch is difficult to compile. Preprocessing and logging functions +are other examples of functions that typically cause a lot of graph breaks and do not get value from being compiled. + +If you are experiencing compiler crashes and you want to continue regardless, you can set +`torch._dynamo.config.suppress_errors = True`. When the compiler crashes, we will just skip tracing +the function and try again later. This is not best practice - it is better to eventually manually add +disable annotations as necessary. + +### Resolving graph breaks + +To maximize optimization opportunities, it's important to reduce the number of graph breaks. +Recall that you can see what graph breaks are happening using `tlparse` or `TORCH_LOGS="graph_breaks"`. +In general, graph breaks are caused by one of the following: + +1. You're trying to do something that fundamentally cannot be traced, such as data-dependent control flow. +2. You're trying to do something not yet supported. . + For example, we currently have limited support for tracing code that uses the built-in Python `inspect` module. +3. Your code has an error in it. For example, you may have tried calling a function with an incorrect number of arguments. + +Graph break logs will tell you the user code location and reason for the graph break. +Unfortunately, many graph breaks are not actionable without a deeper understanding of Dynamo. +It can even be challenging to determine which of the three causes was the true cause of your graph break. +We are working on making graph break messages more actionable. + +Additionally, the impact of lost optimization opportunities differs between graph breaks. +For example, graph breaks that happen in the middle of your model's `forward` are likely to have a more negatie impact than +graph breaks in a preprocessing part at the beginning of the `forward`. So it is not crucial to prevent *every single* +break, but rather to prevent the ones that cause significant performance hits. + +If a graph break message doesn't suggest any action, you suspect that the cause of your graph break is (2), +and you believe that the graph break is causing performance hits, +then please report the graph break as an issue. If a function has many graph breaks, +consider disabling compilation on that function, as the overhead cost for the graph breaks may become prohibitive. + +Below are some common graph breaks and some workarounds. + +#### Data-dependent operations + +`torch.compile` graph breaks on data-dependent operations such as data-dependent control flow +(if-statements, loops with tensors) and direct tensor data accesses (`.item`, `.data_ptr`). + +```py +import torch + +@torch.compile +def fn(x): + y = x.sum() + if y > 0: + return x + y.item() + return x - y.item() + +fn(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="graph_breaks" python playground.py +Graph break in user code at /data/users/williamwen/pytorch/playground.py:6 +Reason: Data-dependent jump +User code traceback: +File "/data/users/williamwen/pytorch/playground.py", line 6, in fn + if y > 0: + +Graph break in user code at /data/users/williamwen/pytorch/playground.py:7 +Reason: Unsupported: Tensor.item +User code traceback: +File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6 + return x + y.item() +Traceback (most recent call last): +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper + return inner_fn(self, inst) + ^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL + self._call(inst) +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call + self.call_function(fn, args, kwargs) +File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function + self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function + return self.obj.call_method(tx, self.name, args, kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method + result = handler_method(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item + unimplemented("Tensor.item") +File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented + raise Unsupported(msg, case_name=case_name) +torch._dynamo.exc.Unsupported: Tensor.item +``` + +The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are: + +- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants. + +```py +# old +x = torch.randn(3, 3) +@torch.compile +def fn(y): + if x.sum() > 0: + return y + x + else: + return y - x + +# new +x = torch.randn(3, 3) +cond = (x.sum() > 0).item() +@torch.compile +def fn(y): + if cond: + return y + x + else: + return y - x +``` + +- Use higher-order ops like `torch.cond` () in place of data-dependent control flow + +```py +# old +@torch.compile +def fn(x): + if x.sum() > 0: + return x + 1 + return x - 1 + +# new +@torch.compile +def fn(x): + return torch.cond( + x.sum() > 0, + lambda x: x + 1, + lambda x: x - 1, + (x,), + ) +``` + +- If you have a `.item()` call, try `torch._dynamo.config.capture_scalar_outputs = True` or `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` +- Wrap problematic parts of the function in a custom op + +#### Custom ops + +If you have code that `torch.compile` has trouble tracing through, either due to missing support or fundamental incompatibility, +you can consider wrapping the problematic code in a custom op. + +Custom ops require a little bit of additional work to get them to be compatible with `torch.compile`. +See for more details. + +#### Printing + +Printing/logging/issuing warnings will result in a graph break. If you have a function that makes many logging calls, +for example, a function that logs data about a training iteration, consider applying `torch.compiler.disable` on it. + +Alternatively, you can try using `torch._dynamo.config.reorderable_logging_functions`. +This config is used to reorder logging functions so that they are called at the end of the traced function, +thus avoiding a graph break. However, the logged contents may differ if, for example, a mutation occurs. + +```py +import torch + +torch._dynamo.config.reorderable_logging_functions.add(print) + +@torch.compile +def fn(x): + x += 1 + print("log!") + return torch.sin(x) + +fn(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="graph_breaks" python playground.py +log! +``` + +#### Incorrect code + +Your code may be wrong, or is otherwise encountering an error from outside `torch.compile`. +In the code below, we made a typo in the `torch.sin` call by providing an extra argument. + +```py +import torch + +@torch.compile +def fn(x): + y = torch.sin(x, x) + return y + +fn(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="graph_breaks" python playground.py +Graph break in user code at /data/users/williamwen/pytorch/playground.py:5 +Reason: Unsupported: TypeError : sin() takes 1 positional argument but 2 were given +User code traceback: +File "/data/users/williamwen/pytorch/playground.py", line 5, in fn + y = torch.sin(x, x) +... +``` + +It can be difficult to tell from the logs if the error is caused by your code or because of a `torch.compile` bug. +In order to differentiate, we recommend trying to run your code without `torch.compile` to see if you still get the error. + +### Dealing with recompilations + +You can view recompilations and their reasons using `tlparse` or `TORCH_LOGS=recompiles`. + +#### Is dynamic shapes enabled? + +Recompilations due to mismatched shapes are in the form: + +``` +tensor 'L['x']' size mismatch at index 0. expected 3, actual 4 +``` + +Make sure that the `dynamic` option of `torch.compile` is not set to `False`. +The default option, `dynamic=None`, will only attempt dynamic shapes after the first compilation. +You can set `dynamic=True` to upfront compile as dynamic as possible. + +For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng). + +#### Changing the cache size limit + +There is a limit to how many times a function can be recompiled, determined by `torch._dynamo.config.recompile_limit` +and `torch._dynamo.config.accumulated_recompile_limit`. +If either limit is exceeded, then we will not attempt to compile the function again and instead will run the function eagerly. +`torch.compile` will also issue a warning containing the affected function and which limit was hit. +In the example below, each function call results in a recompile attempt. +When we hit the cache size limit (8), we stop attempting to recompile. + +```py +import torch + +@torch.compile(dynamic=False) +def fn(x): + return x + 1 + +for i in range(1, 10): + fn(torch.ones(i)) +``` + +``` +$ python playground.py +torch._dynamo hit config.recompile_limit (8) + function: 'fn' (/data/users/williamwen/pytorch/playground.py:5) + last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9 +``` + +If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. +If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit. + +#### Wrapping constants with tensors + +By default, `int` / `float` variables are treated as constants and are guarded as such. +In the below example, we have a recompilation for each function call. + +```py +import torch + +@torch.compile +def fn(x, c): + return x + c + +for i in range(1, 10): + fn(torch.ones(i), 0.5 + i) +``` + +``` +$ TORCH_LOGS="recompiles" python playground.py +Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3 + triggered by the following guard failure(s): + - 0/7: L['c'] == 8.5 + - 0/6: L['c'] == 7.5 + - 0/5: L['c'] == 6.5 + - 0/4: L['c'] == 5.5 + - 0/3: L['c'] == 4.5 + - 0/2: L['c'] == 3.5 + - 0/1: L['c'] == 2.5 + - 0/0: L['c'] == 1.5 +torch._dynamo hit config.recompile_limit (8) + function: 'fn' (/data/users/williamwen/pytorch/playground.py:3) + last reason: 0/0: L['c'] == 1.5 +``` + +In particular, for LR schedulers, initializing with a constant can lead to recompilations: + +```py +import torch + +mod = torch.nn.Linear(3, 3) +opt = torch.optim.Adam(mod.parameters(), lr=0.01) +sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9) + +@torch.compile +def fn(inp): + opt.zero_grad(True) + out = mod(inp).sum() + out.backward() + opt.step() + sched.step() + +for i in range(1, 10): + fn(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="recompiles" python playground.py +Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189 + triggered by the following guard failure(s): + - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002 + - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002 + - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002 + - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002 + - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001 + - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001 + - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001 + - 3/0: L['self'].param_groups[0]['lr'] == 0.01 +torch._dynamo hit config.recompile_limit (8) + function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189) + last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01 +``` + +In both examples, we can wrap float variables in tensors in order to prevent recompilations. + +```py +# first example +for i in range(1, 10): + fn(torch.ones(i), torch.tensor(0.5 + i)) + +# second example +opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01)) +sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9)) +``` + +## Reporting Issues + +If the workarounds provided above were not enough to get `torch.compile` working, +then you should consider reporting the issue to PyTorch. +But there are a few things that you can do to make our lives significantly easier. + +### Ablation + +Check which component of the `torch.compile` stack is the one causing the issue using the `backend=` option for `torch.compile`. +In particular, try: + +- `torch.compile(fn, backend="eager")`, which only runs TorchDynamo, the graph capture component of `torch.compile`. +- `torch.compile(fn, backend="aot_eager")`, which runs TorchDynamo and AOTAutograd, which additionally generates the backward graph during compilation. +- `torch.compile(fn, backend="aot_eager_decomp_partition")`, which runs TorchDynamo and AOTAutograd with operator decompositions/partitions. +- `torch.compile(fn, backend="inductor")`, which runs TorchDynamo, AOTAutograd, and TorchInductor, the backend ML compiler that generates compiled kernels. + +If you only fail with the Inductor backend, you can additionally test various Inductor modes: + +- `torch.compile(fn, backend="inductor", mode="default")` +- `torch.compile(fn, backend="inductor", mode="reduce-overhead")` +- `torch.compile(fn, backend="inductor", mode="max-autotune")` + +You can also check if dynamic shapes is causing issues with any backend: + +- `torch.compile(fn, dynamic=True)` (always use dynamic shapes) +- `torch.compile(fn, dynamic=False)` (never use dynamic shapes) +- `torch.compile(fn, dynamic=None)` (automatic dynamic shapes) + +### Bisecting + +Did you try on the latest nightly? Did something work in the past but now no longer works? +Can you bisect to determine the first nightly where your issue occurs? +Bisecting is especially helpful for performance, accuracy, or compile time regressions, +where it is not immediately obvious where the problem originates from. + +### Creating a reproducer + +Creating reproducers is a lot of work, and it is perfectly fine if you do not have the time to do it. +However, if you are a motivated user unfamiliar with the internals of `torch.compile`, +creating a standalone reproducer can have a huge impact on our ability to fix the bug. +Without a reproducer, your bug report must contain enough information for us to identify the root cause of the problem and write a reproducer from scratch. + +Here's a list of useful reproducers, ranked from most to least preferred: + +1. **Self-contained, small reproducer:** A script with no external dependencies, under 100 lines of code, that reproduces the problem when run. +2. **Self-contained, large reproducer:** Even if it's large, being self-contained is a huge advantage! +3. **Non-self-contained reproducer with manageable dependencies:** + For example, if you can reproduce the problem by running a script after `pip install transformers`, + that's manageable. We can likely run it and investigate. +4. **Non-self-contained reproducer requiring substantial setup:** This might involve downloading datasets, + multiple environment setup steps, or specific system library versions requiring a Docker image. + The more complex the setup, the harder it is for us to recreate the environment. + + :::{note} + Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary. + ::: + +Somewhat orthogonally, a reproducer that can be run in a single process is better than a reproducer +that requires multiprocess training (but once again, if you only have a multiprocess reproducer, we'll take it!). + +Additionally, below is a non-exhaustive list of aspects to check in your +issue that you can attempt to replicate in your reproducer: + +- **Autograd**. Did you have tensor inputs with `requires_grad=True`? Did you call `backward()` on the output? +- **Dynamic shapes**. Did you set `dynamic=True`? Or did you run the test code multiple times with varying shapes? +- **Custom operators**. Is there a custom operator involved in the real workflow? + Can you replicate some of its important characteristics using the Python custom operator API? +- **Configuration**. Did you set all the same configuration? + This includes `torch._dynamo.config` and `torch._inductor.config` settings, + as well as arguments to `torch.compile` like `backend` / `mode`. +- **Context managers**. Did you replicate any active context managers? + This could be `torch.no_grad`, automatic mixed precision, `TorchFunctionMode` / `TorchDispatchMode`, + activation checkpointing, compiled autograd etc. +- **Tensor subclasses**. Is there a tensor subclass involved? + +### Minifier + +The minifier is an early `torch.compile` tool that, given an FX graph that crashes when we attempt to run or compile it, +finds a subgraph that also crashes and outputs the code that performs that subgraph's operations. +Essentially, the minifier finds a minimal repro for a certain class of `torch.compile`-related crashes. +This assumes that we were able to successfully trace through code. + +Unfortunately, most of the time nowadays, the minifier doesn't work as expected, and alternative methods may be necessary. +This is likely because bugs that can be automatically reproduced in this manner are generally easier to fix +and have already been addressed, leaving more complex issues that do not reproduce easily. +However, it is straightforward to attempt using the minifier, so it is worth trying even if it may not succeed. + +Instructions for operating the minifier can be found [here](https://pytorch.org/docs/stable/torch.compiler_troubleshooting_old.html). +If the compiler is crashing, you can set `TORCHDYNAMO_REPRO_AFTER="dynamo"` or `TORCHDYNAMO_REPRO_AFTER="aot"` +The `aot` option is more likely to succeed, although it may not identify the `AOTAutograd` issues. This will generate the `repro.py` file which may help to diagnose the problem. +For accuracy-related issues, consider setting `TORCHDYNAMO_REPRO_LEVEL=4`. Please note that this may not always successfully identify the problematic subgraph. + +## Debugging Deeper + +This section provides tools and techniques for independently debugging `torch.compile` issues +or for gaining a deeper understanding of the `torch.compile` stack. +These methods are more involved than those presented above and are used by PyTorch developers regularly +to debug real `torch.compile` issues. + +Below is a high-level overview of the stack: + +![Torch Dynamo Stack](_static/img/dynamo/td_stack.png) + +The stack comprises three main components: TorchDynamo, AOTAutograd, and Inductor. +Our debugging strategy involves first identifying the component in which the error occurs +and then individually debugging the component. To determine the component responsible for the issue, +see the `Ablation` section under `Reporting Issues` above. For guidance on debugging a specific component, consult the sections below. + +### TorchDynamo + +#### Logging what Dynamo is tracing + +The `TORCH_LOGS=trace_bytecode` option enables you to view the precise bytecode instructions that Dynamo is tracing, +as well as a symbolic representation of the Python interpreter stack. When encountering a graph break or crash, +it is advisable to inspect the last few bytecode instructions traced. + +You can also use `TORCH_LOGS=trace_source` to see which lines of source code Dynamo is tracing through. +This is useful in combination with `trace_bytecode` to see the line of source code each traced bytecode instruction corresponds to. + +Finally, you can use `TORCH_LOGS=graph_code` to see the Python code representing the FX graph that Dynamo traced. +You can view this code to double check that the correct ops are being traced. + +```py +import torch + +def g(x, y): + return x + y + +@torch.compile(backend="eager") +def f(x): + x = torch.sin(x) + x = g(x, x) + return x + +f(torch.ones(3, 3)) +``` + +``` +$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py +TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f () + @torch.compile(backend="eager") +TRACE RESUME 0 [] +TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f) + x = torch.sin(x) +TRACE LOAD_GLOBAL torch [] +TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable()] +TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable()] +TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(), LazyVariableTracker()] +TRACE STORE_FAST x [TensorVariable()] +TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f) + x = g(x, x) +TRACE LOAD_GLOBAL g [] +TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()] +TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()] +TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()] +TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1) + def g(x, y): +TRACE RESUME 0 [] +TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1) + return x + y +TRACE LOAD_FAST x [] +TRACE LOAD_FAST y [TensorVariable()] +TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()] +TRACE RETURN_VALUE None [TensorVariable()] +TRACE STORE_FAST x [TensorVariable()] +TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f) + return x +TRACE LOAD_FAST x [] +TRACE RETURN_VALUE None [TensorVariable()] +TRACED GRAPH +===== __compiled_fn_1 ===== +/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3][3, 1]cpu"): + l_x_ = L_x_ + + # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x) + x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None + + # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y + x_1: "f32[3, 3][3, 1]cpu" = x + x; x = None + return (x_1,) +``` + +#### Breakpointing Dynamo tracing + +Inserting a breakpoint in Dynamo/user code is helpful at times to see what the state of Dynamo is when tracing through user code. +Unfortunately, inserting a breakpoint in the normal Python fashion will result in a graph break in TorchDynamo, +so we will not be able to view the state of Dynamo at the point where we intended to breakpoint. + +The first method for setting a breakpoint is to insert it within the Dynamo source code. Three recommended locations to place a breakpoint are: + +- In `torch/_dynamo/symbolic_convert.py`, breakpoint at functions that are named after the problematic bytecode instruction, + such as `def CALL_FUNCTION` and `def STORE_ATTR`. You can conditionally breakpoint depending on inputs, + for example, the `argval` of the instruction, or the name of the object at the top of the stack since some bytecode opcodes are frequently used. +- Breakpoint where the graph break or error originates from. Typically, graph breaks are emitted from a call to `unimplemented(...)`. +- Breakpoint in `torch/_dynamo/variables/builder.py, function:_wrap`. You will likely have to conditionally breakpoint on the input. + This function determines how to symbolically represent a given value. Consider breakpointing here if you suspect that a value is represented incorrectly. + +The second way to insert a breakpoint is to use `torch._dynamo.comptime.comptime.breakpoint`: + +```py +from torch._dynamo.comptime import comptime + +@torch.compile +def f(...): + ... + comptime.breakpoint() + ... +``` + +A comptime breakpoint is convenient as it enables you to inspect the Dynamo state at a specific location within the user code being traced. +It does not require you to insert a breakpoint in the Dynamo source or to conditionally breakpoint based on variables. + +When a comptime breakpoint is triggered, you can do the following: + +- `ctx.print_bt()` to print the user stack trace +- `ctx.print_locals()` to print all current locals +- `ctx.print_graph()` to print the currently traced graph +- `ctx.disas()` to print the currently traced function's bytecode +- Use standard `pdb` commands, such as `bt/u/d/n/s/r`, - you can go up the `pdb` stack to inspect more Dynamo internals + +```py +import torch +from torch._dynamo.comptime import comptime + +@torch.compile(backend="eager") +def f(x): + y = x + 1 + comptime.breakpoint() + y = y + 1 + return y + +f(torch.ones(3, 3)) +``` + +``` +$ python playground.py +--Return-- +> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None +-> builtins.breakpoint() +(Pdb) ctx.print_bt() +File "/data/users/williamwen/pytorch/playground.py", line 7, in f + comptime.breakpoint() + +(Pdb) ctx.print_locals() +x = FakeTensor(..., size=(3, 3)) +y = FakeTensor(..., size=(3, 3)) +(Pdb) bt +... +/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function() +-> self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] +/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function() +-> func(ComptimeContext(tx)) +> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None +-> builtins.breakpoint() +(Pdb) ctx.print_graph() + + + +def forward(self, L_x_: "f32[3, 3]"): + l_x_ = L_x_ + + # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1 + y: "f32[3, 3]" = l_x_ + 1; l_x_ = y = None +``` + +% TODO(uncomment/update once we improve this API) +% Debugging large models +% ^^^^^^^^^^^^^^^^^^^^^^ +% +% Debugging TorchDynamo on large models can be tricky, mainly because Dynamo traces through large amounts of code. +% It can be difficult to find the problematic function, or to determine where to place a breakpoint. +% Even if we've found the problematic function, we don't want to deal with logging spam. +% Fortunately, you can use ``TORCHDYNAMO_DEBUG_FUNCTION=``, which limits dynamo tracing to only functions with a specific name +% (exact match). This will allow you to filter all of the functions in the model to the function(s) of interest. +% Use this in combination with the above debugging strategies. + +#### Bytecode generation errors + +Although uncommon, Dynamo may generate incorrect bytecode. This may occur if you determine the following: + +- Ablation reveals the error is happening at the TorchDynamo level +- The error is not being emitted from TorchDynamo stack frames +- The error looks more like a user error rather than a Dynamo error, or is a segmentation fault +- The error does not occur without `torch.compile` + +Bytecode generation bugs are generally tricky to fix and we recommend submitting an issue instead of trying to fix those yourself. +If you are interested in seeing the bytecode that Dynamo generates, you can use `TORCH_LOGS=bytecode`. +You can see a high-level overview on what bytecode Dynamo generates [here](https://docs.google.com/presentation/d/1tMZOoAoNKF32CAm1C-WfzdVVgoEvJ3lp/edit?usp=sharing&ouid=114922067987692817315&rtpof=true&sd=true). + +### AOTAutograd + +AOTAutograd errors are typically difficult to debug - we recommend just submitting an issue. +AOTAutograd logging output is primarily helpful to see what the input to Inductor is. + +% TODO +% TorchInductor +% ------------- + +% TODO + +(troubleshooting-torch-logs-options)= + +### Summary of TORCH_LOGS options + +A summary of helpful `TORCH_LOGS` options is: + +```{eval-rst} +.. list-table:: + :widths: 25 50 + :header-rows: 1 + + * - Option + - Description + * - +all + - Output debug logs from all ``torch.compile`` components + * - +dynamo + - Output debug logs from TorchDynamo + * - +aot + - Output debug logs from AOTAutograd + * - +inductor + - Output debug logs from TorchInductor + * - dynamic + - Output logs from dynamic shapes + * - graph_code + - Output the Python code for the FX graph that Dynamo generated + * - graph_sizes + - Output the tensor sizes of the FX graph that Dynamo generated + * - trace_bytecode + - Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of + * - trace_source + - Output the line of code in the original source that Dynamo is currently tracing through + * - bytecode + - Output Dynamo-generated bytecode + * - guards + - Output generated guards + * - recompiles + - Output recompilation reasons (only the first guard check that fails) + * - recompiles_verbose + - Output all guard checks that fail when a recompilation occurs + * - aot_graphs + - Output graph generated by AOTAutograd + * - aot_joint_graphs + - Output the joint forward-backward graph generated by AOTAutograd + * - output_code + - Output code generated by Inductor + * - kernel_code + - Output code generated by Inductor on a per-kernel basis + * - schedule + - Output Inductor scheduling logs + * - perf_hints + - Output Inductor perf hint logs + * - fusion + - Output Inductor fusion logs +``` + +For the full list of options, see [torch.\_logging](https://pytorch.org/docs/stable/logging.html) +and [torch.\_logging.set_logs](https://pytorch.org/docs/stable/generated/torch._logging.set_logs.html#torch._logging.set_logs). + +## Related Articles + +- [torch.compile tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) +- [torch.compile fine-grained APIs](https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html) +- [torch.compile FAQ](https://pytorch.org/docs/stable/torch.compiler_faq.html) +- [torch.compiler namespace overview](https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler-overview) +- [torch.compiler API reference](https://pytorch.org/docs/stable/torch.compiler_api.html) +- [Profiling torch.compile](https://pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html) +- [torch.compile missing manual](https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?usp=sharing) +- [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng) +- [TorchInductor caching tutorial](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) diff --git a/docs/source/torch.compiler_troubleshooting.rst b/docs/source/torch.compiler_troubleshooting.rst deleted file mode 100644 index 89731fac13ad04..00000000000000 --- a/docs/source/torch.compiler_troubleshooting.rst +++ /dev/null @@ -1,1113 +0,0 @@ -.. _torch.compiler_troubleshooting: - -torch.compile Troubleshooting -================================= - -You're trying to use ``torch.compile`` on your PyTorch model to enhance its performance -but it's not working as expected. Perhaps performance isn't improving, crashes are happening, or compilation time is too long. This article provides tips, workarounds, and debugging tools to help you overcome these challenges. - -**Contents** - -.. contents:: - :local: - -Setting Expectations -~~~~~~~~~~~~~~~~~~~~ - -``torch.compile`` is designed as a general-purpose PyTorch compiler. -Unlike the previous compiler solution, TorchScript, ``torch.compile`` -requires fewer code changes, meaning models typically don't need to be rewritten from scratch. -It also manages unsupported code more gracefully - unsupported code results in a lost optimization opportunity rather than a crash. - -In the ideal world, one can simply apply ``torch.compile`` to any PyTorch model and enjoy automatic speedups. -However, in reality, code complexities can lead to one of three scenarios: - -1. ``torch.compile`` works seamlessly, providing speedups. -2. Some code modifications are necessary. ``torch.compile`` doesn't crash or take too long, - but you might not be seeing significant performance gains. -3. Extensive changes to your code are required. - -We anticipate most code will fall under scenarios (1) and (2). -This document provides tips, arranged by level of involvement, to help address code issues in scenario (2). - -Compile times -------------- - -``torch.compile`` functions as a just-in-time compiler, so the initial one or two runs -of the compiled function are expected to be significantly slower. Recompilations, which can occur under certain conditions (detailed below), -will also make runs slower. Various ``torch.compile`` components cache results to -reduce compilation time for future invocations, even in different processes. -Cold-start (uncached) compilation time typically ranges from seconds to minutes for common or benchmarked models. -Larger models may take upwards of 30 minutes to a few hours. - -Terminology -~~~~~~~~~~~ - -The following terms are relevant to troubleshooting ``torch.compile`` problems. - -Graph break ------------ - -``torch.compile`` traces your code and attempts to capture your PyTorch code into a -single computation graph of PyTorch operators (FX graph). However, this is not always possible. -When encountering code that can't be traced, a "graph break" occurs. -A graph break involves compiling the FX graph has been determined so far, running the unsupported code, -then resuming tracing after the unsupported code with a new FX graph. -Because the computation graph is broken up, we lose optimization opportunities, -so model code should avoid graph breaks whenever possible. -Graph breaks occur on things like: - -- Data-dependent if-statements -- Many Python built-in functions -- C functions - -Below is an example of a graph break due to the function ``copy.deepcopy`` from a Python builtin library -(exact output may differ). - -.. code-block:: py - - import torch - - @torch.compile - def fn(x): - x = x + 1 - with open("test.txt", "r") as f: - return x + len(f.read()) - - fn(torch.ones(3, 3)) - -:: - - $TORCH_LOGS="graph_breaks" python playground.py - Graph break in user code at /data/users/williamwen/pytorch/playground.py:7 - Reason: Unsupported: builtin: open [, ] False - User code traceback: - File "/data/users/williamwen/pytorch/playground.py", line 7, in fn - with open("test.txt", "r") as f: - Traceback (most recent call last): - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper - return inner_fn(self, inst) - ^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL - self._call(inst) - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call - self.call_function(fn, args, kwargs) - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function - self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function - return handler(tx, args, kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in - return lambda *args: unimplemented(error_msg) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented - raise Unsupported(msg, case_name=case_name) - torch._dynamo.exc.Unsupported: builtin: open [, ] False - -Guards ------- - -``torch.compile`` makes some assumptions about runtime values as we trace through code. -During tracing, we generate "guards", which are runtime checks for these assumptions. -Guards are run in future calls to the compiled function to determine if we can reuse previously compiled code. -Examples of runtime checks are constant values, types, and object IDs. - -Below is an example of generated guards. The ``TENSOR_MATCH`` guard checks for the input's type, device, dtype, shape, etc. - -.. code-block:: py - - import torch - - @torch.compile - def fn(x): - return x + 1 - - fn(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="guards" python playground.py - GUARDS: - - TREE_GUARD_MANAGER: - +- RootGuardManager - | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:471 in init_ambient_guards - | +- GLOBAL_STATE: ___check_global_state() - | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack() - | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x) - | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # playground.py:6 in fn - | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # playground.py:6 in fn - -Recompilation -------------- - -If the guards fail for every instance of previously compiled code, -then ``torch.compile`` must "recompile" the function, requiring the original code to be traced again. - -In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed. - -.. code-block:: py - - import torch - - @torch.compile - def fn(x): - return x + 1 - - fn(torch.ones(3, 3)) - fn(torch.ones(4, 4)) - -:: - - $ TORCH_LOGS="recompiles" python playground.py - Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3 - triggered by the following guard failure(s): - - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4 - -Dynamic Shapes -------------------------- -``torch.compile`` initially assumes tensor shapes are static/constant and guards based on these assumptions. -By using "dynamic shapes," we can get ``torch.compile`` to produce compiled code that can accept -tensor inputs with different shapes - we avoid recompiling every time shapes differ. -By default, automatic dynamic shapes are enabled ``torch.compile(dynamic=None)`` - -if compilation fails due to shape mismatch, recompilation is attempted with dynamic shapes. -Dynamic shapes can also be fully enabled ``dynamic=True`` or disabled ``dynamic=False``. - -Below, we enable dynamic shapes and note that we no longer need to recompile. - -.. code-block:: py - - import torch - - @torch.compile(dynamic=True) - def fn(x): - return x + 1 - - fn(torch.ones(3, 3)) - fn(torch.ones(4, 4)) - -:: - - $ TORCH_LOGS="dynamic,recompiles" python playground.py - create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" - produce_guards - produce_guards - -For more information on dynamic shapes, see `The dynamic shapes manual `__. - -Logging Tools -~~~~~~~~~~~~~ - -tlparse / TORCH_TRACE ------------------------------ - -``tlparse`` / ``TORCH_TRACE`` are a pair of tools that produce compilation reports that look like this: -https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html. - -Traces are very easy to collect. To collect a trace, run your reproduction command with - -:: - - TORCH_TRACE="/tmp/tracedir" python foo.py - pip install tlparse - tlparse /tmp/tracedir - -This approach works even if you are running a distributed job, providing a trace for each rank. -It will open your browser with HTML similar to what's generated above. -If you are making a bug report for a complicated problem that you don't have a standalone reproduction for, -you can still greatly assist PyTorch developers by attaching the trace log generated in ``/tmp/tracedir``. - -.. warning:: The trace log contains all of your model code. - Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights. - -.. raw:: html - - - -.. role:: red - -.. role:: green - -.. role:: dark-green - -The output of ``tlparse`` is primarily aimed for PyTorch developers, -and the log format is easy to upload and share on GitHub. -However, as a non-PyTorch developer, you can still extract useful information from it. -We recommend starting with the inline help text in the report, which explains its contents. -Here are some insights you can gain from a ``tlparse``: - -- What model code was compiled by looking at the stack trie? - This is especially useful if you're not familiar with the codebase being compiled! -- How many graph breaks / distinct compilation regions are there? - (Each distinct compile is its own color coded block like :dark-green:`[0/0]`). - Frames that are potentially graph-broken are light green :green:`[2/4]`. - If there are a lot of frames, that is suspicious, and suggests that you had some catastrophic graph breaks, - or maybe your code isn't a good match for ``torch.compile``. -- How many times did I recompile a particular frame? Something that recompiled a lot will look like: - :dark-green:`[10/0]` :dark-green:`[10/1]` :dark-green:`[10/2]` - - if something is being recompiled a lot, that is very suspicious and worth looking into, even if it isn't the root cause of your problem. -- Was there a compilation error? Frames that errored will look like :red:`[0/1]`. -- What intermediate compiler products did I generate for a given frame? - For example, you can look at the high-level generated FX graph or the generated Triton code. -- Is there relevant information for a particular frame? You can find these in ``compilation_metrics``. - -TORCH_LOGS --------------- - -You can use the ``TORCH_LOGS`` environment variable to selectively enable parts of the ``torch.compile`` stack to log. -``TORCH_LOGS`` is in fact the source of logs for ``tlparse``. The format of the ``TORCH_LOGS`` environment variable looks like this: - -:: - - TORCH_LOGS=",,..." python foo.py - - -Useful high-level options include: - -- ``graph_breaks``: logs locations of graph breaks in user code and the reason for the graph break -- ``guards``: logs guards that are generated -- ``recompiles``: logs which function recompiled and the guards that failed, leading to the recompilation -- ``dynamic``: logs related to dynamic shapes - -Also, you can programmatically set logging options using ``torch._logging.set_logs``: - -.. code-block:: py - - import logging - torch._logging.set_logs(graph_breaks=True) - ... - -More ``TORCH_LOGS`` options are :ref:`detailed below `. -For the full list of options, see `torch._logging `__ -and `torch._logging.set_logs `__. - -tlparse vs. TORCH_LOGS ----------------------- - -Generally, we suggest first using ``tlparse`` when encountering issues. -``tlparse`` is ideal for debugging large models and gaining a high-level overview of how your model was compiled. -On the other hand, ``TORCH_LOGS`` is preferred for small examples and fine-grained debugging detail, -when we already have an idea of which ``torch.compile`` component is causing the problem. - -Simple Workarounds -~~~~~~~~~~~~~~~~~~ - -Here, we describe some workarounds to ``torch.compile`` issues involving small code modifications -or changing some ``torch.compile`` settings. - -Where to apply torch.compile? ---------------------------------- - -We recommend applying ``torch.compile`` to the highest-level function that doesn't cause excessive problems. -Typically, it is your train or eval step with the optimizer but without the loop, your top-level ``nn.Module``, -or some sub-``nn.Module``s. ``torch.compile`` specifically doesn't handle distributed wrapper modules like -DDP or FSDP very well, so consider applying ``torch.compile`` to the inner module passed to the wrapper. - -.. code-block:: py - - # inference - model = ... - opt_model = torch.compile(model) - - for _ in range(N_ITERS): - inp = ... - out = opt_model(inp) - -.. code-block:: py - - # training - model = ... - opt = torch.optim.Adam(model.parameters()) - - @torch.compile - def train(mod, data): - opt.zero_grad(True) - pred = mod(data[0]) - loss = torch.nn.CrossEntropyLoss()(pred, data[1]) - loss.backward() - opt.step() - - for _ in range(N_ITERS): - inp = ... - train(model, inp) - -.. code-block:: py - - # DistributedDataParallel - model = ... - opt_model = torch.compile(model) - model_ddp = DistributedDataParallel(opt_model, ...) - - for _ in range(N_ITERS): - inp = ... - out = model_ddp(inp) - -Disabling and Suppressing Errors ---------------------------------- - -For some model architectures, there are portions of the model which are particularly difficult to compile -- either there are many graph breaks, or there are crashes. You may want to explicitly disable these -portions of the model which are problematic so that you can apply ``torch.compile`` to the parts that work. -You can do this by using the ``@torch.compiler.disable`` decorator. When ``torch.compile`` attempts to call a -disabled function, it breaks the graph and skips tracing the disabled function, resuming tracing after the call. -By default, all recursive calls made from a disabled function are also disabled. Use the ``recursive=False`` -option to allow compilation for recursive calls. - -.. code-block:: py - - def bad1_inner(...): - # skipped - - @torch.compiler.disable - def bad1_outer(...): - # skipped - bad1_inner(...) - - def bad2_inner(...) - # traced - - @torch.compiler.disable(recursive=False) - def bad2_outer(...): - # skipped - bad2_inner(...) - - @torch.compile - def fn(...): - # graph break - bad1_outer(...) - ... - # graph break - bad2_outer(...) - -For example, we use ``torch.compiler.disable`` to disable ``torch.compile`` on sparse architecture in -recommendation models, as the sparse arch is difficult to compile. Preprocessing and logging functions -are other examples of functions that typically cause a lot of graph breaks and do not get value from being compiled. - -If you are experiencing compiler crashes and you want to continue regardless, you can set -``torch._dynamo.config.suppress_errors = True``. When the compiler crashes, we will just skip tracing -the function and try again later. This is not best practice - it is better to eventually manually add -disable annotations as necessary. - -Resolving graph breaks ----------------------- - -To maximize optimization opportunities, it's important to reduce the number of graph breaks. -Recall that you can see what graph breaks are happening using ``tlparse`` or ``TORCH_LOGS="graph_breaks"``. -In general, graph breaks are caused by one of the following: - -1. You're trying to do something that fundamentally cannot be traced, such as data-dependent control flow. -2. You're trying to do something not yet supported. . - For example, we currently have limited support for tracing code that uses the built-in Python ``inspect`` module. -3. Your code has an error in it. For example, you may have tried calling a function with an incorrect number of arguments. - -Graph break logs will tell you the user code location and reason for the graph break. -Unfortunately, many graph breaks are not actionable without a deeper understanding of Dynamo. -It can even be challenging to determine which of the three causes was the true cause of your graph break. -We are working on making graph break messages more actionable. - -Additionally, the impact of lost optimization opportunities differs between graph breaks. -For example, graph breaks that happen in the middle of your model's ``forward`` are likely to have a more negatie impact than -graph breaks in a preprocessing part at the beginning of the ``forward``. So it is not crucial to prevent *every single* -break, but rather to prevent the ones that cause significant performance hits. - -If a graph break message doesn't suggest any action, you suspect that the cause of your graph break is (2), -and you believe that the graph break is causing performance hits, -then please report the graph break as an issue. If a function has many graph breaks, -consider disabling compilation on that function, as the overhead cost for the graph breaks may become prohibitive. - -Below are some common graph breaks and some workarounds. - -Data-dependent operations -^^^^^^^^^^^^^^^^^^^^^^^^^ - -``torch.compile`` graph breaks on data-dependent operations such as data-dependent control flow -(if-statements, loops with tensors) and direct tensor data accesses (``.item``, ``.data_ptr``). - -.. code-block:: py - - import torch - - @torch.compile - def fn(x): - y = x.sum() - if y > 0: - return x + y.item() - return x - y.item() - - fn(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="graph_breaks" python playground.py - Graph break in user code at /data/users/williamwen/pytorch/playground.py:6 - Reason: Data-dependent jump - User code traceback: - File "/data/users/williamwen/pytorch/playground.py", line 6, in fn - if y > 0: - - Graph break in user code at /data/users/williamwen/pytorch/playground.py:7 - Reason: Unsupported: Tensor.item - User code traceback: - File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6 - return x + y.item() - Traceback (most recent call last): - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper - return inner_fn(self, inst) - ^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL - self._call(inst) - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call - self.call_function(fn, args, kwargs) - File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function - self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function - return self.obj.call_method(tx, self.name, args, kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method - result = handler_method(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item - unimplemented("Tensor.item") - File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented - raise Unsupported(msg, case_name=case_name) - torch._dynamo.exc.Unsupported: Tensor.item - -The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are: - -- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants. - -.. code-block:: py - - # old - x = torch.randn(3, 3) - @torch.compile - def fn(y): - if x.sum() > 0: - return y + x - else: - return y - x - - # new - x = torch.randn(3, 3) - cond = (x.sum() > 0).item() - @torch.compile - def fn(y): - if cond: - return y + x - else: - return y - x - -- Use higher-order ops like ``torch.cond`` (https://pytorch.org/docs/main/cond.html) in place of data-dependent control flow - -.. code-block:: py - - # old - @torch.compile - def fn(x): - if x.sum() > 0: - return x + 1 - return x - 1 - - # new - @torch.compile - def fn(x): - return torch.cond( - x.sum() > 0, - lambda x: x + 1, - lambda x: x - 1, - (x,), - ) - -- If you have a ``.item()`` call, try ``torch._dynamo.config.capture_scalar_outputs = True`` or ``TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`` -- Wrap problematic parts of the function in a custom op - -Custom ops -^^^^^^^^^^ - -If you have code that ``torch.compile`` has trouble tracing through, either due to missing support or fundamental incompatibility, -you can consider wrapping the problematic code in a custom op. - -Custom ops require a little bit of additional work to get them to be compatible with ``torch.compile``. -See https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details. - -Printing -^^^^^^^^ - -Printing/logging/issuing warnings will result in a graph break. If you have a function that makes many logging calls, -for example, a function that logs data about a training iteration, consider applying ``torch.compiler.disable`` on it. - -Alternatively, you can try using ``torch._dynamo.config.reorderable_logging_functions``. -This config is used to reorder logging functions so that they are called at the end of the traced function, -thus avoiding a graph break. However, the logged contents may differ if, for example, a mutation occurs. - -.. code-block:: py - - import torch - - torch._dynamo.config.reorderable_logging_functions.add(print) - - @torch.compile - def fn(x): - x += 1 - print("log!") - return torch.sin(x) - - fn(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="graph_breaks" python playground.py - log! - -Incorrect code -^^^^^^^^^^^^^^ - -Your code may be wrong, or is otherwise encountering an error from outside ``torch.compile``. -In the code below, we made a typo in the ``torch.sin`` call by providing an extra argument. - -.. code-block:: py - - import torch - - @torch.compile - def fn(x): - y = torch.sin(x, x) - return y - - fn(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="graph_breaks" python playground.py - Graph break in user code at /data/users/williamwen/pytorch/playground.py:5 - Reason: Unsupported: TypeError : sin() takes 1 positional argument but 2 were given - User code traceback: - File "/data/users/williamwen/pytorch/playground.py", line 5, in fn - y = torch.sin(x, x) - ... - -It can be difficult to tell from the logs if the error is caused by your code or because of a ``torch.compile`` bug. -In order to differentiate, we recommend trying to run your code without ``torch.compile`` to see if you still get the error. - -Dealing with recompilations ---------------------------- - -You can view recompilations and their reasons using ``tlparse`` or ``TORCH_LOGS=recompiles``. - -Is dynamic shapes enabled? -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Recompilations due to mismatched shapes are in the form: - -:: - - tensor 'L['x']' size mismatch at index 0. expected 3, actual 4 - -Make sure that the ``dynamic`` option of ``torch.compile`` is not set to ``False``. -The default option, ``dynamic=None``, will only attempt dynamic shapes after the first compilation. -You can set ``dynamic=True`` to upfront compile as dynamic as possible. - -For more information on dynamic shapes, see `The dynamic shapes manual `__. - -Changing the cache size limit -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -There is a limit to how many times a function can be recompiled, determined by ``torch._dynamo.config.recompile_limit`` -and ``torch._dynamo.config.accumulated_recompile_limit``. -If either limit is exceeded, then we will not attempt to compile the function again and instead will run the function eagerly. -``torch.compile`` will also issue a warning containing the affected function and which limit was hit. -In the example below, each function call results in a recompile attempt. -When we hit the cache size limit (8), we stop attempting to recompile. - -.. code-block:: py - - import torch - - @torch.compile(dynamic=False) - def fn(x): - return x + 1 - - for i in range(1, 10): - fn(torch.ones(i)) - -:: - - $ python playground.py - torch._dynamo hit config.recompile_limit (8) - function: 'fn' (/data/users/williamwen/pytorch/playground.py:5) - last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9 - -If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. -If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit. - -Wrapping constants with tensors -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -By default, ``int`` / ``float`` variables are treated as constants and are guarded as such. -In the below example, we have a recompilation for each function call. - -.. code-block:: py - - import torch - - @torch.compile - def fn(x, c): - return x + c - - for i in range(1, 10): - fn(torch.ones(i), 0.5 + i) - -:: - - $ TORCH_LOGS="recompiles" python playground.py - Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3 - triggered by the following guard failure(s): - - 0/7: L['c'] == 8.5 - - 0/6: L['c'] == 7.5 - - 0/5: L['c'] == 6.5 - - 0/4: L['c'] == 5.5 - - 0/3: L['c'] == 4.5 - - 0/2: L['c'] == 3.5 - - 0/1: L['c'] == 2.5 - - 0/0: L['c'] == 1.5 - torch._dynamo hit config.recompile_limit (8) - function: 'fn' (/data/users/williamwen/pytorch/playground.py:3) - last reason: 0/0: L['c'] == 1.5 - -In particular, for LR schedulers, initializing with a constant can lead to recompilations: - -.. code-block:: py - - import torch - - mod = torch.nn.Linear(3, 3) - opt = torch.optim.Adam(mod.parameters(), lr=0.01) - sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9) - - @torch.compile - def fn(inp): - opt.zero_grad(True) - out = mod(inp).sum() - out.backward() - opt.step() - sched.step() - - for i in range(1, 10): - fn(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="recompiles" python playground.py - Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189 - triggered by the following guard failure(s): - - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002 - - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002 - - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002 - - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002 - - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001 - - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001 - - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001 - - 3/0: L['self'].param_groups[0]['lr'] == 0.01 - torch._dynamo hit config.recompile_limit (8) - function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189) - last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01 - -In both examples, we can wrap float variables in tensors in order to prevent recompilations. - -.. code-block:: py - - # first example - for i in range(1, 10): - fn(torch.ones(i), torch.tensor(0.5 + i)) - - # second example - opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01)) - sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9)) - -Reporting Issues -~~~~~~~~~~~~~~~~ - -If the workarounds provided above were not enough to get ``torch.compile`` working, -then you should consider reporting the issue to PyTorch. -But there are a few things that you can do to make our lives significantly easier. - -Ablation --------- - -Check which component of the ``torch.compile`` stack is the one causing the issue using the ``backend=`` option for ``torch.compile``. -In particular, try: - -- ``torch.compile(fn, backend="eager")``, which only runs TorchDynamo, the graph capture component of ``torch.compile``. -- ``torch.compile(fn, backend="aot_eager")``, which runs TorchDynamo and AOTAutograd, which additionally generates the backward graph during compilation. -- ``torch.compile(fn, backend="aot_eager_decomp_partition")``, which runs TorchDynamo and AOTAutograd with operator decompositions/partitions. -- ``torch.compile(fn, backend="inductor")``, which runs TorchDynamo, AOTAutograd, and TorchInductor, the backend ML compiler that generates compiled kernels. - -If you only fail with the Inductor backend, you can additionally test various Inductor modes: - -- ``torch.compile(fn, backend="inductor", mode="default")`` -- ``torch.compile(fn, backend="inductor", mode="reduce-overhead")`` -- ``torch.compile(fn, backend="inductor", mode="max-autotune")`` - -You can also check if dynamic shapes is causing issues with any backend: - -- ``torch.compile(fn, dynamic=True)`` (always use dynamic shapes) -- ``torch.compile(fn, dynamic=False)`` (never use dynamic shapes) -- ``torch.compile(fn, dynamic=None)`` (automatic dynamic shapes) - -Bisecting ---------- -Did you try on the latest nightly? Did something work in the past but now no longer works? -Can you bisect to determine the first nightly where your issue occurs? -Bisecting is especially helpful for performance, accuracy, or compile time regressions, -where it is not immediately obvious where the problem originates from. - -Creating a reproducer ---------------------- - -Creating reproducers is a lot of work, and it is perfectly fine if you do not have the time to do it. -However, if you are a motivated user unfamiliar with the internals of ``torch.compile``, -creating a standalone reproducer can have a huge impact on our ability to fix the bug. -Without a reproducer, your bug report must contain enough information for us to identify the root cause of the problem and write a reproducer from scratch. - -Here's a list of useful reproducers, ranked from most to least preferred: - -1. **Self-contained, small reproducer:** A script with no external dependencies, under 100 lines of code, that reproduces the problem when run. -2. **Self-contained, large reproducer:** Even if it's large, being self-contained is a huge advantage! -3. **Non-self-contained reproducer with manageable dependencies:** - For example, if you can reproduce the problem by running a script after ``pip install transformers``, - that's manageable. We can likely run it and investigate. -4. **Non-self-contained reproducer requiring substantial setup:** This might involve downloading datasets, - multiple environment setup steps, or specific system library versions requiring a Docker image. - The more complex the setup, the harder it is for us to recreate the environment. - - .. note:: - Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary. - -Somewhat orthogonally, a reproducer that can be run in a single process is better than a reproducer -that requires multiprocess training (but once again, if you only have a multiprocess reproducer, we'll take it!). - -Additionally, below is a non-exhaustive list of aspects to check in your -issue that you can attempt to replicate in your reproducer: - -- **Autograd**. Did you have tensor inputs with ``requires_grad=True``? Did you call ``backward()`` on the output? -- **Dynamic shapes**. Did you set ``dynamic=True``? Or did you run the test code multiple times with varying shapes? -- **Custom operators**. Is there a custom operator involved in the real workflow? - Can you replicate some of its important characteristics using the Python custom operator API? -- **Configuration**. Did you set all the same configuration? - This includes ``torch._dynamo.config`` and ``torch._inductor.config`` settings, - as well as arguments to ``torch.compile`` like ``backend`` / ``mode``. -- **Context managers**. Did you replicate any active context managers? - This could be ``torch.no_grad``, automatic mixed precision, ``TorchFunctionMode`` / ``TorchDispatchMode``, - activation checkpointing, compiled autograd etc. -- **Tensor subclasses**. Is there a tensor subclass involved? - -Minifier --------- - -The minifier is an early ``torch.compile`` tool that, given an FX graph that crashes when we attempt to run or compile it, -finds a subgraph that also crashes and outputs the code that performs that subgraph's operations. -Essentially, the minifier finds a minimal repro for a certain class of ``torch.compile``-related crashes. -This assumes that we were able to successfully trace through code. - -Unfortunately, most of the time nowadays, the minifier doesn't work as expected, and alternative methods may be necessary. -This is likely because bugs that can be automatically reproduced in this manner are generally easier to fix -and have already been addressed, leaving more complex issues that do not reproduce easily. -However, it is straightforward to attempt using the minifier, so it is worth trying even if it may not succeed. - -Instructions for operating the minifier can be found `here `__. -If the compiler is crashing, you can set ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` or ``TORCHDYNAMO_REPRO_AFTER="aot"`` -The ``aot`` option is more likely to succeed, although it may not identify the ``AOTAutograd`` issues. This will generate the ``repro.py`` file which may help to diagnose the problem. -For accuracy-related issues, consider setting ``TORCHDYNAMO_REPRO_LEVEL=4``. Please note that this may not always successfully identify the problematic subgraph. - -Debugging Deeper -~~~~~~~~~~~~~~~~ - -This section provides tools and techniques for independently debugging ``torch.compile`` issues -or for gaining a deeper understanding of the ``torch.compile`` stack. -These methods are more involved than those presented above and are used by PyTorch developers regularly -to debug real ``torch.compile`` issues. - -Below is a high-level overview of the stack: - -.. image:: _static/img/dynamo/td_stack.png - -The stack comprises three main components: TorchDynamo, AOTAutograd, and Inductor. -Our debugging strategy involves first identifying the component in which the error occurs -and then individually debugging the component. To determine the component responsible for the issue, -see the `Ablation` section under `Reporting Issues` above. For guidance on debugging a specific component, consult the sections below. - -TorchDynamo ------------ - -Logging what Dynamo is tracing -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The ``TORCH_LOGS=trace_bytecode`` option enables you to view the precise bytecode instructions that Dynamo is tracing, -as well as a symbolic representation of the Python interpreter stack. When encountering a graph break or crash, -it is advisable to inspect the last few bytecode instructions traced. - -You can also use ``TORCH_LOGS=trace_source`` to see which lines of source code Dynamo is tracing through. -This is useful in combination with ``trace_bytecode`` to see the line of source code each traced bytecode instruction corresponds to. - -Finally, you can use ``TORCH_LOGS=graph_code`` to see the Python code representing the FX graph that Dynamo traced. -You can view this code to double check that the correct ops are being traced. - -.. code-block:: py - - import torch - - def g(x, y): - return x + y - - @torch.compile(backend="eager") - def f(x): - x = torch.sin(x) - x = g(x, x) - return x - - f(torch.ones(3, 3)) - -:: - - $ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py - TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f () - @torch.compile(backend="eager") - TRACE RESUME 0 [] - TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f) - x = torch.sin(x) - TRACE LOAD_GLOBAL torch [] - TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable()] - TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable()] - TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(), LazyVariableTracker()] - TRACE STORE_FAST x [TensorVariable()] - TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f) - x = g(x, x) - TRACE LOAD_GLOBAL g [] - TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()] - TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()] - TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()] - TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1) - def g(x, y): - TRACE RESUME 0 [] - TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1) - return x + y - TRACE LOAD_FAST x [] - TRACE LOAD_FAST y [TensorVariable()] - TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()] - TRACE RETURN_VALUE None [TensorVariable()] - TRACE STORE_FAST x [TensorVariable()] - TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f) - return x - TRACE LOAD_FAST x [] - TRACE RETURN_VALUE None [TensorVariable()] - TRACED GRAPH - ===== __compiled_fn_1 ===== - /data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[3, 3][3, 1]cpu"): - l_x_ = L_x_ - - # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x) - x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None - - # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y - x_1: "f32[3, 3][3, 1]cpu" = x + x; x = None - return (x_1,) - -Breakpointing Dynamo tracing -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Inserting a breakpoint in Dynamo/user code is helpful at times to see what the state of Dynamo is when tracing through user code. -Unfortunately, inserting a breakpoint in the normal Python fashion will result in a graph break in TorchDynamo, -so we will not be able to view the state of Dynamo at the point where we intended to breakpoint. - -The first method for setting a breakpoint is to insert it within the Dynamo source code. Three recommended locations to place a breakpoint are: - -- In ``torch/_dynamo/symbolic_convert.py``, breakpoint at functions that are named after the problematic bytecode instruction, - such as ``def CALL_FUNCTION`` and ``def STORE_ATTR``. You can conditionally breakpoint depending on inputs, - for example, the ``argval`` of the instruction, or the name of the object at the top of the stack since some bytecode opcodes are frequently used. -- Breakpoint where the graph break or error originates from. Typically, graph breaks are emitted from a call to ``unimplemented(...)``. -- Breakpoint in ``torch/_dynamo/variables/builder.py, function:_wrap``. You will likely have to conditionally breakpoint on the input. - This function determines how to symbolically represent a given value. Consider breakpointing here if you suspect that a value is represented incorrectly. - -The second way to insert a breakpoint is to use ``torch._dynamo.comptime.comptime.breakpoint``: - -.. code-block:: py - - from torch._dynamo.comptime import comptime - - @torch.compile - def f(...): - ... - comptime.breakpoint() - ... - -A comptime breakpoint is convenient as it enables you to inspect the Dynamo state at a specific location within the user code being traced. -It does not require you to insert a breakpoint in the Dynamo source or to conditionally breakpoint based on variables. - -When a comptime breakpoint is triggered, you can do the following: - -- ``ctx.print_bt()`` to print the user stack trace -- ``ctx.print_locals()`` to print all current locals -- ``ctx.print_graph()`` to print the currently traced graph -- ``ctx.disas()`` to print the currently traced function's bytecode -- Use standard ``pdb`` commands, such as ``bt/u/d/n/s/r``, - you can go up the ``pdb`` stack to inspect more Dynamo internals - -.. code-block:: py - - import torch - from torch._dynamo.comptime import comptime - - @torch.compile(backend="eager") - def f(x): - y = x + 1 - comptime.breakpoint() - y = y + 1 - return y - - f(torch.ones(3, 3)) - -:: - - $ python playground.py - --Return-- - > /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None - -> builtins.breakpoint() - (Pdb) ctx.print_bt() - File "/data/users/williamwen/pytorch/playground.py", line 7, in f - comptime.breakpoint() - - (Pdb) ctx.print_locals() - x = FakeTensor(..., size=(3, 3)) - y = FakeTensor(..., size=(3, 3)) - (Pdb) bt - ... - /data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function() - -> self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] - /data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function() - -> func(ComptimeContext(tx)) - > /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None - -> builtins.breakpoint() - (Pdb) ctx.print_graph() - - - - def forward(self, L_x_: "f32[3, 3]"): - l_x_ = L_x_ - - # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1 - y: "f32[3, 3]" = l_x_ + 1; l_x_ = y = None - -.. - TODO(uncomment/update once we improve this API) - Debugging large models - ^^^^^^^^^^^^^^^^^^^^^^ - - Debugging TorchDynamo on large models can be tricky, mainly because Dynamo traces through large amounts of code. - It can be difficult to find the problematic function, or to determine where to place a breakpoint. - Even if we've found the problematic function, we don't want to deal with logging spam. - Fortunately, you can use ``TORCHDYNAMO_DEBUG_FUNCTION=``, which limits dynamo tracing to only functions with a specific name - (exact match). This will allow you to filter all of the functions in the model to the function(s) of interest. - Use this in combination with the above debugging strategies. - -Bytecode generation errors -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Although uncommon, Dynamo may generate incorrect bytecode. This may occur if you determine the following: - -- Ablation reveals the error is happening at the TorchDynamo level -- The error is not being emitted from TorchDynamo stack frames -- The error looks more like a user error rather than a Dynamo error, or is a segmentation fault -- The error does not occur without ``torch.compile`` - -Bytecode generation bugs are generally tricky to fix and we recommend submitting an issue instead of trying to fix those yourself. -If you are interested in seeing the bytecode that Dynamo generates, you can use ``TORCH_LOGS=bytecode``. -You can see a high-level overview on what bytecode Dynamo generates `here `__. - -AOTAutograd ------------ - -AOTAutograd errors are typically difficult to debug - we recommend just submitting an issue. -AOTAutograd logging output is primarily helpful to see what the input to Inductor is. - -.. - TODO - TorchInductor - ------------- - -.. TODO - -.. _troubleshooting_torch_logs_options: - -Summary of TORCH_LOGS options ---------------------------------- - -A summary of helpful ``TORCH_LOGS`` options is: - -.. list-table:: - :widths: 25 50 - :header-rows: 1 - - * - Option - - Description - * - +all - - Output debug logs from all ``torch.compile`` components - * - +dynamo - - Output debug logs from TorchDynamo - * - +aot - - Output debug logs from AOTAutograd - * - +inductor - - Output debug logs from TorchInductor - * - dynamic - - Output logs from dynamic shapes - * - graph_code - - Output the Python code for the FX graph that Dynamo generated - * - graph_sizes - - Output the tensor sizes of the FX graph that Dynamo generated - * - trace_bytecode - - Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of - * - trace_source - - Output the line of code in the original source that Dynamo is currently tracing through - * - bytecode - - Output Dynamo-generated bytecode - * - guards - - Output generated guards - * - recompiles - - Output recompilation reasons (only the first guard check that fails) - * - recompiles_verbose - - Output all guard checks that fail when a recompilation occurs - * - aot_graphs - - Output graph generated by AOTAutograd - * - aot_joint_graphs - - Output the joint forward-backward graph generated by AOTAutograd - * - output_code - - Output code generated by Inductor - * - kernel_code - - Output code generated by Inductor on a per-kernel basis - * - schedule - - Output Inductor scheduling logs - * - perf_hints - - Output Inductor perf hint logs - * - fusion - - Output Inductor fusion logs - -For the full list of options, see `torch._logging `__ -and `torch._logging.set_logs `__. - -Related Articles -~~~~~~~~~~~~~~~~ - -- `torch.compile tutorial `__ -- `torch.compile fine-grained APIs `__ -- `torch.compile FAQ `__ -- `torch.compiler namespace overview `__ -- `torch.compiler API reference `__ -- `Profiling torch.compile `__ -- `torch.compile missing manual `__ -- `The dynamic shapes manual `__ -- `TorchInductor caching tutorial `__ diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md new file mode 100644 index 00000000000000..03555d74e817ce --- /dev/null +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -0,0 +1,721 @@ +--- +orphan: true +--- + +(torch.compiler_troubleshooting_old)= + +# PyTorch 2.0 Troubleshooting (old) + +**Author**: [Michael Lazos](https://github.com/mlazos) + +:::{note} +This document is outdated and is now mainly a primary resource on how to run the `torch.compile` minifier. +Please see the [updated troubleshooting document](https://pytorch.org/docs/main/torch.compiler_troubleshooting.html). +There is also a more [comprehensive manual for torch.compile](https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab) +available. +::: + +We are actively developing debug tools, profilers, and improving our +error and warning messages. Below is a table of the available +tools and their typical usage. For additional help see +{ref}`diagnosing-runtime-errors`. + +```{eval-rst} +.. list-table:: Title + :widths: 25 25 50 + :header-rows: 1 + + * - Tool + - Purpose + - Usage + * - Info logging + - View summarized steps of compilation + - ``torch._logging.set_logs(dynamo = logging.INFO)`` or ``TORCH_LOGS="dynamo"`` + * - Debug logging + - View detailed steps of compilation (print every instruction traced) + - ``torch._logging.set_logs(dynamo = logging.DEBUG)`` and + ``torch._dynamo.config.verbose = True``, or ``TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1`` + * - Minifier for any backend + - Find smallest subgraph which reproduces errors for any backend + - set environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` + * - Minifier for ``TorchInductor`` + - If the error is known to occur after ``AOTAutograd`` find + smallest subgraph which reproduces errors during ``TorchInductor`` lowering + - set environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` + * - Dynamo accuracy minifier + - Finds the smallest subgraph which reproduces an accuracy issue + between an eager mode model and optimized model, when you + suspect the problem is in ``AOTAutograd`` + - ``TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=4`` + * - Inductor accuracy minifier + - Finds the smallest subgraph which reproduces an accuracy issue + between an eager mode model and optimized model, when you + suspect the problem is in the backend (e.g., inductor). + If this doesn't work, try the Dynamo accuracy minifier + instead. + - ``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` + * - ``torch._dynamo.explain`` + - Find graph breaks and display reasoning for them + - ``torch._dynamo.explain(fn)(*inputs)`` + * - Record/Replay + - Record and replay frames which to reproduce errors during graph capture + - ``torch._dynamo.config.replay_record_enabled = True`` + * - TorchDynamo function name filtering + - Only compile functions with the given name to reduce noise when + debugging an issue + - set environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=`` + * - TorchInductor Debug logging + - Print general TorchInductor debug info and generated Triton/C++ code + - ``torch._inductor.config.debug = True`` + * - TorchInductor Tracing + - Show time taken in each TorchInductor stage + output code and graph + visualization + - set the environment variable TORCH_COMPILE_DEBUG=1 or + ``torch._inductor.config.trace.enabled = True`` +``` + +In addition to info and debug logging, +you can use [torch.\_logging](https://pytorch.org/docs/main/logging.html) +for more fine-grained logging. + +(diagnosing-runtime-errors)= +## Diagnosing Runtime Errors + +At a high level, the TorchDynamo stack consists of a graph capture from +Python code (TorchDynamo) and a backend compiler. For example, a +backend compiler may consist of backward graph tracing (AOTAutograd) and +graph lowering (TorchInductor)\*. Errors can occur in any component of +the stack and will provide full stack traces. + +To determine in which component an error occurred, +you may use info-level logging +`torch._logging.set_logs(dynamo = logging.INFO)` or `TORCH_LOGS="dynamo"` +and look for `Step #: ...` outputs. Logs are made at the beginning and end of +each step, so the step that an error should correspond to is the most recently +logged step whose end has not yet been logged. The steps correspond to the +following parts of the stack: + +| Step | Component | +| ---- | ---------------- | +| 1 | TorchDynamo | +| 2 | Compiler Backend | +| 3 | TorchInductor | + +If info logging is insufficient, you can use available backend +options. These options include: + +- `"eager"`: only runs TorchDynamo forward graph capture and then + runs the captured graph with PyTorch. This provides an indication as + to whether TorchDynamo is raising the error. +- `"aot_eager"`: runs TorchDynamo to capture a forward graph, and + then AOTAutograd to trace the backward graph without any additional + backend compiler steps. PyTorch eager will then be used to run the + forward and backward graphs. This is useful to narrow down the issue + to AOTAutograd. + +The general procedure to narrow down an issue is the following: + +1. Run your program with the `"eager"` backend. If the error no longer + occurs, the issue is in the backend compiler that is being used (if + using TorchInductor, proceed to step 2. If not, see + {ref}`minifying-backend-compiler-errors`). If the error still + occurs with the `"eager"` backend, it is due to + {ref}`torchdynamo-errors`. +2. This step is only necessary if `TorchInductor` is used as the backend + compiler. Run the model with the `"aot_eager"` backend. If this + backend raises an error then the error is occurring during + AOTAutograd tracing. If the error no longer occurs with this backend, + then {ref}`minifying-torchinductor-errors`. + +Each of these cases are analyzed in the following sections. + +:::{note} +The TorchInductor backend consists of +both AOTAutograd tracing and the TorchInductor compiler itself. We will +disambiguate by referring to `TorchInductor` as the backend, and +TorchInductor lowering as the phase which lowers the graph traced by +AOTAutograd. +::: + +(torchdynamo-errors)= + +### Torchdynamo Errors + +If the error that is generated occurs with the `"eager"` backend, then +TorchDynamo is most likely the source of the error. Here is a sample code +which will generate an error. + +```py +import torch + +import torch._dynamo as dynamo + + +def test_assertion_error(): + y = torch.ones(200, 200) + z = {y: 5} + return z + +compiled_test_assertion_error = torch.compile(test_assertion_error, backend="eager") + +compiled_test_assertion_error() +``` + +The code above generates the following error: + +``` +torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26 +due to: +Traceback (most recent call last): + File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP + assert isinstance(k, ConstantVariable) or ( +AssertionError + +from user code: + File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error + z = {y: 5} + +Set torch._dynamo.config.verbose=True for more information +========== +``` + +As the message suggests you can set +`torch._dynamo.config.verbose=True` to get a full stack trace to both +the error in TorchDynamo and the user code. In addition to this flag, +you can also set the `log_level` of TorchDynamo through +`torch._logging.set_logs(dynamo = logging.INFO)` or `TORCH_LOGS="dynamo"`. These levels include: + +- `logging.DEBUG` or `TORCH_LOGS="+dynamo"`: Print every instruction that is + encountered in addition to all the log levels listed below. +- `logging.INFO`: + Print each function that is compiled (original and modified bytecode) + and the graph that is captured in addition to all the log levels listed below. +- `logging.WARNING` (default): Print graph breaks in addition to all + the log levels listed below. +- `logging.ERROR`: Print errors only. + +If a model is very large, the logs can become overwhelming. If +an error occurs deep within a model's Python code, it can be useful to +execute only the frame in which the error occurs to enable easier +debugging. There are two tools available to enable this: + +- Setting the environment variable `TORCHDYNAMO_DEBUG_FUNCTION` + to the desired function name will only run torchdynamo on functions with that + name. +- Enabling the record/replay tool (set `torch._dynamo.config.replay_record_enabled = True`) + which dumps an execution record when an error is encountered. This record can + then be replayed to run only the frame where an error occurred. + +### Diagnosing TorchInductor Errors + +If the error does not occur with the `"eager"` backend, then the +backend compiler is the source of the error ([example +error](https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de)). +There are [different choices](./torch.compiler.md) +for backend compilers for TorchDynamo, with TorchInductor +fitting the needs of most users. This section focuses on TorchInductor +as the motivating example, but some tools can also be used with other +backend compilers. + +Below is the portion of the stack which we are focusing on: + +With TorchInductor as the chosen backend, AOTAutograd is used to +generate the backward graph from the forward graph captured by +torchdynamo. It is important to note that errors can occur during this +tracing and also while TorchInductor lowers the forward and backward +graphs to GPU code or C++. A model can often consist of hundreds or +thousands of FX nodes, so narrowing the exact nodes where this problem +occurred can be very difficult. Fortunately, there are tools available to +automatically minify these input graphs to the nodes which are causing +the issue. The first step is to determine whether the error occurs +during tracing of the backward graph with AOTAutograd or during +TorchInductor lowering. As mentioned above in step 2, the +`"aot_eager"` backend can be used to run only AOTAutograd in isolation +without lowering. If the error still occurs with this backend, this +indicates that the error is occurring during AOTAutograd tracing. + +Here is an example: + +```py +import torch + +import torch._dynamo as dynamo + +model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) + +def test_backend_error(): + + y = torch.ones(200, 200) + x = torch.ones(200, 200) + z = x + y + a = torch.ops.aten._foobar(z) # dummy function which errors + return model(a) + + +compiled_test_backend_error = torch.compile(test_backend_error, backend="inductor") +compiled_test_backend_error() +``` + +Running this should give you this error with a longer stack trace below +it: + +``` +Traceback (most recent call last): + File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function + return lowerings[target](*args, **kwargs) + File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped + return decomp_fn(*args, **kwargs) + File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar + assert False +AssertionError +... +``` + +[error with full stack +trace](https://gist.github.com/mlazos/d6947854aa56d686800259a164c62100) + +If you then change `torch.compile(backend="inductor")` to +`torch.compile(backend="aot_eager")`, it will run without error, because +[the +issue](https://github.com/pytorch/torchdynamo/blob/d09e50fbee388d466b5252a63045643166006f77/torchinductor/lowering.py#:~:text=%23%20This%20shouldn%27t%20be,assert%20False) +is in the TorchInductor lowering process, not in AOTAutograd. + +(minifying-torchinductor-errors)= + +### Minifying TorchInductor Errors + +From here, let’s run the minifier to get a minimal repro. Setting the +environment variable `TORCHDYNAMO_REPRO_AFTER="aot"` (or setting +`torch._dynamo.config.repro_after="aot"` directly) will generate a +Python program which reduces the graph produced by AOTAutograd to the +smallest subgraph which reproduces the error. (See below for an example +where we minify the graph produced by TorchDynamo) Running the program +with this environment variable should show nearly [identical +output](https://gist.github.com/mlazos/0458ab828aa403c779fe73c012aa5982), +with an additional line indicating where `minifier_launcher.py` has +been written to. The output directory is configurable by setting +`torch._dynamo.config.base_dir` to a valid directory name. The final +step is to run the minifier and check that it runs successfully. A +successful run looks like +[this](https://gist.github.com/mlazos/e6ea41ccce68a7b1b8a7a09acb1b206a). +If the minifier runs successfully, it generates runnable python code +which reproduces the exact error. For our example this is the following +code: + +```python +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +from torch.fx.experimental.proxy_tensor import make_fx + +# torch version: 1.13.0a0+gitfddfc44 +# torch cuda version: 11.6 +# torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5 + + +# CUDA Info: +# nvcc: NVIDIA (R) Cuda compiler driver +# Copyright (c) 2005-2022 NVIDIA Corporation +# Built on Thu_Feb_10_18:23:41_PST_2022 +# Cuda compilation tools, release 11.6, V11.6.112 +# Build cuda_11.6.r11.6/compiler.30978841_0 + +# GPU Hardware Info: +# NVIDIA A100-SXM4-40GB : 8 + +from torch.nn import * + +class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, add): + _foobar = torch.ops.aten._foobar.default(add); add = None + return (_foobar,) + +args = [((200, 200), (200, 1), torch.float32, 'cpu')] +args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args] +mod = make_fx(Repro())(*args) +from torch._inductor.compile_fx import compile_fx_inner + +compiled = compile_fx_inner(mod, args) +compiled(*args) +``` + +The `forward` method of the `Repro` module contains the exact op +which causes the issue. When filing an issue, please include any +minified repros to aid in debugging. + +(minifying-backend-compiler-errors)= + +### Minifying Backend Compiler Errors + +With backend compilers other than TorchInductor the process for finding +the subgraph causing the error is nearly identical to the procedure in +{ref}`minifying-torchinductor-errors` with one important +caveat. Namely, that the minifier will now be run on the graph that is +traced by TorchDynamo, not the output graph of AOTAutograd. Let’s walk +through an example. + +```py +import torch + +import torch._dynamo as dynamo + +model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) +# toy compiler which fails if graph contains relu +def toy_compiler(gm: torch.fx.GraphModule, _): + for node in gm.graph.nodes: + if node.target == torch.relu: + assert False + + return gm + + +def test_backend_error(): + y = torch.ones(200, 200) + x = torch.ones(200, 200) + z = x + y + a = torch.relu(z) + return model(a) + + +compiled_test_backend_error = torch.compile(test_backend_error, backend=toy_compiler) +compiled_test_backend_error() +``` + +In order to run the code after TorchDynamo has traced the forward graph, +you can use the `TORCHDYNAMO_REPRO_AFTER` environment variable. Running +this program with `TORCHDYNAMO_REPRO_AFTER="dynamo"` (or +`torch._dynamo.config.repro_after="dynamo"`) should produce [this +output](https://gist.github.com/mlazos/244e3d5b53667e44078e194762c0c92b)and +the following code in `{torch._dynamo.config.base_dir}/repro.py`. + +:::{note} +The other option for TORCHDYNAMO_REPRO_AFTER is `"aot"`, which +will run the minifier after the backward graph has been generated. +::: + +```python +import torch +import torch._dynamo as dynamo +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +from torch._dynamo.debug_utils import run_fwd_maybe_bwd + +from torch.nn import * + +class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, add): + relu = torch.relu(add); add = None + return (relu,) + + +mod = Repro().cuda() +opt_mod = torch.compile(mod, backend="None") + + +args = [((200, 200), (200, 1), torch.float32, 'cpu', False)] +args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] + + +with torch.cuda.amp.autocast(enabled=False): + ref = run_fwd_maybe_bwd(mod, args) + res = run_fwd_maybe_bwd(opt_mod, args) +``` + +The minifier successfully reduced the graph to the op that raises the +error in `toy_compiler`. The other difference from the procedure in +{ref}`minifying-torchinductor-errors` is that the minifier is +automatically run after encountering a backend compiler error. After a +successful run, the minifier writes `repro.py` to +`torch._dynamo.config.base_dir`. + +## Performance Profiling + +### Accessing TorchDynamo Profiler + +TorchDynamo has a built-in stats function for collecting and displaying +the time spent in each compilation phase. These stats can be accessed by +calling `torch._dynamo.utils.compile_times()` after executing +Torch.\_Dynamo. By default, this returns a string representation of the +compile times spent in each TorchDynamo function by name. + +### TorchInductor Debugging using TORCH_COMPILE_DEBUG + +TorchInductor has a builtin stats and trace function for displaying time +spent in each compilation phase, output code, output graph visualization +and IR dump. This is a debugging tool designed to make it easier to +understand and troubleshoot the internals of TorchInductor. + +Let's run an example with the following test program (`repro.py`): + +``` +import torch + +@torch.compile() +def test_model(x): + model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.LayerNorm(10), + torch.nn.ReLU(), + ) + return model(x) + + +y = test_model(torch.ones(10, 10)) +``` + +Setting the environment variable `TORCH_COMPILE_DEBUG=1` will cause a +debug trace directory to be created, by default this directory will be in the +current directory and named torch_compile_debug (this can be overridden in +the torchdynamo configuration field `debug_dir_root` and also the +`env var TORCH_COMPILE_DEBUG_DIR`). Inside this directory, each run will +have a separate folder named with the timestamp and process id of the run: + +``` +$ env TORCH_COMPILE_DEBUG=1 python repro.py +$ cd torch_compile_debug +$ ls +run_2023_03_01_08_20_52_143510-pid_180167 +``` + +In the run folder there will be a `torchdynamo` directory which contains +debug logs, and an `torchinductor` folder which contains a subfolder for each +compiled kernel with inductor debug artifacts. + +``` +$ cd +run_2023_03_01_08_20_52_143510-pid_180167 +$ ls +torchinductor torchdynamo +``` + +Moving further into the `torchinductor` directory, the `\*.log` files are +logs from the AOT Autograd phase of compilation, `model__0_forward_1.0` contains +the inductor debug artifacts. + +``` +$ cd torchinductor +$ ls +aot_model___0_debug.log model__0_forward_1.0 +$ cd model__0_forward_1.0 +$ ls +debug.log fx_graph_readable.py fx_graph_runnable.py fx_graph_transformed.py ir_post_fusion.txt ir_pre_fusion.txt output_code.py +``` + +Here is a summary of the contents: + +- `fx_graph_readable.py` and `fx_graph_runnable.py` are the readable and + runnable versions of the `fx_graph` received by inductor. +- `fx_graph_transformed.py` is the fx graph after inductor has run all fx passes. +- `ir\*.txt` is the inductor ir pre and post fusion. +- `output_code.py` is the compiled triton kernel for the subgraph. + +Here are [example debug directory contents](https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396) +for the test program: + +``` +import torch + +@torch.compile() +def test_model(x): + model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.LayerNorm(10), + torch.nn.ReLU(), + ) + return model(x) + + +y = test_model(torch.ones(10, 10)) +``` + +Each file in that debug trace can be enabled and disabled through +`torch._inductor.config.trace.*`. The profile and the diagram are both +disabled by default since they are expensive to generate. + +A single node in this new debug format looks like: + +``` +buf1: SchedulerNode(ComputedBuffer) +buf1.writes = + { MemoryDep(name='buf1', index=0, size=()), + MemoryDep(name='buf1', index=0, size=(s0,))} +buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))} +buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))} +buf1.group.device = cuda:0 +buf1.group.iteration = (1, s0) +buf1.sizes = ([], [s0]) +class buf1_loop_body: + var_ranges = {z0: s0} + index0 = z0 + index1 = 0 + def body(self, ops): + get_index = self.get_index('index0') + load = ops.load('buf0', get_index, False) + get_index_1 = self.get_index('index0') + load_1 = ops.load('primals_2', get_index_1, False) + add = ops.add(load, load_1) + get_index_2 = self.get_index('index1') + reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add) + return reduction +``` + +See the [example debug directory +output](https://gist.github.com/jansel/f4af078791ad681a0d4094adeb844396) +for more examples. + +% _Memory Profiling +% ---------------- +% +% TBD + +### Graph Breaks + +Given a program like this: + +```python +def some_fun(x): + ... + +compiled_fun = torch.compile(some_fun, ...) +... +``` + +TorchDynamo will attempt to compile all of the torch/tensor operations +within some_fun into a single FX graph, but it may fail to capture +everything into one graph. + +Some graph break reasons are insurmountable to TorchDynamo, and can't be +easily fixed. - calling into a C extension other than torch is invisible +to torchdynamo, and could do arbitrary things without TorchDynamo being +able to introduce necessary guards (see {ref}`making-dynamo-sound-guards`) +to ensure that the compiled program would be safe to reuse. Graph breaks +can hinder performance if the resulting fragments are small. To maximize +performance, it's important to have as few graph breaks as possible. + +## Identifying the Cause of a Graph Break + +To identify all graph breaks in a program and the associated reasons for +the breaks, `torch._dynamo.explain` can be used. This tool runs +TorchDynamo on the supplied function and aggregates the graph breaks +that are encountered. Here is an example usage: + +```python +import torch +import torch._dynamo as dynamo +def toy_example(a, b): + x = a / (torch.abs(a) + 1) + print("woo") + if b.sum() < 0: + b = b * -1 + return x * b +explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) +print(explanation_verbose) +""" +Graph Count: 3 +Graph Break Count: 2 +Op Count: 5 +Break Reasons: + Break Reason 1: + Reason: builtin: print [] False + User Stack: + + Break Reason 2: + Reason: generic_jump TensorVariable() + User Stack: + +Ops per Graph: + ... +Out Guards: + ... +""" +``` + +Outputs include: + +- `out_guards` - a list of lists where each sublist contains the guards that must pass to ensure the traced graphs are valid. +- `graphs` - a list of graph modules which were successfully traced. +- `ops_per_graph` - a list of lists where each sublist contains the ops that are run in the graph. + +To throw an error on the first graph break encountered, use the `fullgraph` +mode. This mode disables TorchDynamo’s Python fallback, and only +succeeds if the entire program is convertible into a single graph. Example +usage: + +```python +def toy_example(a, b): + ... + +compiled_toy = torch.compile(toy_example, fullgraph=True, backend=)(a, b) +``` + +### Excessive Recompilation + +When TorchDynamo compiles a function (or part of one), it makes certain +assumptions about locals and globals in order to allow compiler +optimizations, and expresses these assumptions as guards that check +particular values at runtime. If any of these guards fail, Dynamo will +recompile that function (or part) up to +`torch._dynamo.config.recompile_limit` times. If your program is +hitting the cache limit, you will first need to determine which guard is +failing and what part of your program is triggering it. + +If your program exhibits a bounded amount of dynamism, you may be able +to tune the TorchDynamo cache limit to allow for each variation to be +compiled and cached, but if the cache limit is too high you may find the +cost of recompilation outweighs any optimization benefits. + +``` +torch._dynamo.config.recompile_limit = +``` + +TorchDynamo plans to support many common cases of dynamic tensor shapes, +such as varying batch size or sequence length. It does not plan to +support rank-dynamism. In the meantime, setting a specific cache limit +can be used in coordination with bucketing techniques to achieve an +acceptable number of recompilations for some dynamic models. + +## Accuracy Debugging + +Accuracy issues can also be minified if you set the environment variable +`TORCHDYNAMO_REPRO_LEVEL=4`, it operates with a similar git bisect +model and a full repro might be something like +`TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4` the reason +we need this is downstream compilers will codegen code whether it’s +Triton code or the C++ backend, the numerics from those downstream +compilers can be different in subtle ways yet have dramatic impact on +your training stability. So the accuracy debugger is very useful for us +to detect bugs in our codegen or with a backend compiler. + +If you'd like to ensure that random number generation is the same across both torch +and triton then you can enable `torch._inductor.config.fallback_random = True` + +## Extended Debugging + +Extended debugging can be enabled by using the following experimental flags. + +`TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED` - provides extended debug information if the +string representation of a guard matches this flag value. For example, set it to +"Ne(s0, 10)" to generate full Python and C++ backtrace whenever guard was issued. +`TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL` - provides extended debug information when +a particular symbol is allocated. For example, set this to "u2" to generate full Python +and C++ backtrace whenever this symbol was created. +`TORCHDYNAMO_EXTENDED_DEBUG_CPP` - provides extended debug information (C++ backtrace) +for all extended debug settings as well as errors. For example, set this to "1". The C++ +backtrace is slow and very spammy so it is not included by default with extended debugging. + +## Cold Start Timing and Cache Corruption Debugging + +In order to measure the cold start compilation time or debug a cache corruption, +it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set +`torch._inductor.config.force_disable_caches = True` which will override any +other caching config option and disable all compile time caching. diff --git a/docs/source/torch.compiler_troubleshooting_old.rst b/docs/source/torch.compiler_troubleshooting_old.rst deleted file mode 100644 index 5f693741e94c77..00000000000000 --- a/docs/source/torch.compiler_troubleshooting_old.rst +++ /dev/null @@ -1,727 +0,0 @@ -:orphan: - -.. _torch.compiler_troubleshooting_old: - -PyTorch 2.0 Troubleshooting (old) -================================= - -**Author**: `Michael Lazos `_ - -.. note:: This document is outdated and is now mainly a primary resource on how to run the ``torch.compile`` minifier. - Please see the `updated troubleshooting document `__. - There is also a more `comprehensive manual for torch.compile `__ - available. - -We are actively developing debug tools, profilers, and improving our -error and warning messages. Below is a table of the available -tools and their typical usage. For additional help see -`Diagnosing Runtime Errors <#diagnosing-runtime-errors>`__. - -.. list-table:: Title - :widths: 25 25 50 - :header-rows: 1 - - * - Tool - - Purpose - - Usage - * - Info logging - - View summarized steps of compilation - - ``torch._logging.set_logs(dynamo = logging.INFO)`` or ``TORCH_LOGS="dynamo"`` - * - Debug logging - - View detailed steps of compilation (print every instruction traced) - - ``torch._logging.set_logs(dynamo = logging.DEBUG)`` and - ``torch._dynamo.config.verbose = True``, or ``TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1`` - * - Minifier for any backend - - Find smallest subgraph which reproduces errors for any backend - - set environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` - * - Minifier for ``TorchInductor`` - - If the error is known to occur after ``AOTAutograd`` find - smallest subgraph which reproduces errors during ``TorchInductor`` lowering - - set environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` - * - Dynamo accuracy minifier - - Finds the smallest subgraph which reproduces an accuracy issue - between an eager mode model and optimized model, when you - suspect the problem is in ``AOTAutograd`` - - ``TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=4`` - * - Inductor accuracy minifier - - Finds the smallest subgraph which reproduces an accuracy issue - between an eager mode model and optimized model, when you - suspect the problem is in the backend (e.g., inductor). - If this doesn't work, try the Dynamo accuracy minifier - instead. - - ``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` - * - ``torch._dynamo.explain`` - - Find graph breaks and display reasoning for them - - ``torch._dynamo.explain(fn)(*inputs)`` - * - Record/Replay - - Record and replay frames which to reproduce errors during graph capture - - ``torch._dynamo.config.replay_record_enabled = True`` - * - TorchDynamo function name filtering - - Only compile functions with the given name to reduce noise when - debugging an issue - - set environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=`` - * - TorchInductor Debug logging - - Print general TorchInductor debug info and generated Triton/C++ code - - ``torch._inductor.config.debug = True`` - * - TorchInductor Tracing - - Show time taken in each TorchInductor stage + output code and graph - visualization - - set the environment variable TORCH_COMPILE_DEBUG=1 or - ``torch._inductor.config.trace.enabled = True`` - -In addition to info and debug logging, -you can use `torch._logging `__ -for more fine-grained logging. - -Diagnosing Runtime Errors -~~~~~~~~~~~~~~~~~~~~~~~~~ - -At a high level, the TorchDynamo stack consists of a graph capture from -Python code (TorchDynamo) and a backend compiler. For example, a -backend compiler may consist of backward graph tracing (AOTAutograd) and -graph lowering (TorchInductor)*. Errors can occur in any component of -the stack and will provide full stack traces. - -To determine in which component an error occurred, -you may use info-level logging -``torch._logging.set_logs(dynamo = logging.INFO)`` or ``TORCH_LOGS="dynamo"`` -and look for ``Step #: ...`` outputs. Logs are made at the beginning and end of -each step, so the step that an error should correspond to is the most recently -logged step whose end has not yet been logged. The steps correspond to the -following parts of the stack: - -==== ================ -Step Component -==== ================ -1 TorchDynamo -2 Compiler Backend -3 TorchInductor -==== ================ - -If info logging is insufficient, you can use available backend -options. These options include: - -- ``"eager"``: only runs TorchDynamo forward graph capture and then - runs the captured graph with PyTorch. This provides an indication as - to whether TorchDynamo is raising the error. - -- ``"aot_eager"``: runs TorchDynamo to capture a forward graph, and - then AOTAutograd to trace the backward graph without any additional - backend compiler steps. PyTorch eager will then be used to run the - forward and backward graphs. This is useful to narrow down the issue - to AOTAutograd. - -The general procedure to narrow down an issue is the following: - -1. Run your program with the ``"eager"`` backend. If the error no longer - occurs, the issue is in the backend compiler that is being used (if - using TorchInductor, proceed to step 2. If not, see `this - section <#minifying-backend-compiler-errors>`__). If the error still - occurs with the ``"eager"`` backend, it is an `error while running - torchdynamo <#torchdynamo-errors>`__. - -2. This step is only necessary if ``TorchInductor`` is used as the backend - compiler. Run the model with the ``"aot_eager"`` backend. If this - backend raises an error then the error is occurring during - AOTAutograd tracing. If the error no longer occurs with this backend, - then `the error is in - TorchInductor\* <#minifying-torchinductor-errors>`__. - -Each of these cases are analyzed in the following sections. - -.. note:: The TorchInductor backend consists of - both AOTAutograd tracing and the TorchInductor compiler itself. We will - disambiguate by referring to ``TorchInductor`` as the backend, and - TorchInductor lowering as the phase which lowers the graph traced by - AOTAutograd. - -Torchdynamo Errors ------------------- - -If the error that is generated occurs with the ``"eager"`` backend, then -TorchDynamo is most likely the source of the error. Here is a sample code -which will generate an error. - -.. code-block:: py - - import torch - - import torch._dynamo as dynamo - - - def test_assertion_error(): - y = torch.ones(200, 200) - z = {y: 5} - return z - - compiled_test_assertion_error = torch.compile(test_assertion_error, backend="eager") - - compiled_test_assertion_error() - -The code above generates the following error: - -:: - - torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26 - due to: - Traceback (most recent call last): - File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP - assert isinstance(k, ConstantVariable) or ( - AssertionError - - from user code: - File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error - z = {y: 5} - - Set torch._dynamo.config.verbose=True for more information - ========== - -As the message suggests you can set -``torch._dynamo.config.verbose=True`` to get a full stack trace to both -the error in TorchDynamo and the user code. In addition to this flag, -you can also set the ``log_level`` of TorchDynamo through -``torch._logging.set_logs(dynamo = logging.INFO)`` or ``TORCH_LOGS="dynamo"``. These levels include: - -- ``logging.DEBUG`` or ``TORCH_LOGS="+dynamo"``: Print every instruction that is - encountered in addition to all the log levels listed below. -- ``logging.INFO``: - Print each function that is compiled (original and modified bytecode) - and the graph that is captured in addition to all the log levels listed below. -- ``logging.WARNING`` (default): Print graph breaks in addition to all - the log levels listed below. -- ``logging.ERROR``: Print errors only. - -If a model is very large, the logs can become overwhelming. If -an error occurs deep within a model's Python code, it can be useful to -execute only the frame in which the error occurs to enable easier -debugging. There are two tools available to enable this: - -- Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` - to the desired function name will only run torchdynamo on functions with that - name. - -- Enabling the record/replay tool (set ``torch._dynamo.config.replay_record_enabled = True``) - which dumps an execution record when an error is encountered. This record can - then be replayed to run only the frame where an error occurred. - -Diagnosing TorchInductor Errors -------------------------------- - -If the error does not occur with the ``"eager"`` backend, then the -backend compiler is the source of the error (`example -error `__). -There are `different choices <./torch.compiler.rst>`__ -for backend compilers for TorchDynamo, with TorchInductor -fitting the needs of most users. This section focuses on TorchInductor -as the motivating example, but some tools can also be used with other -backend compilers. - -Below is the portion of the stack which we are focusing on: - -With TorchInductor as the chosen backend, AOTAutograd is used to -generate the backward graph from the forward graph captured by -torchdynamo. It is important to note that errors can occur during this -tracing and also while TorchInductor lowers the forward and backward -graphs to GPU code or C++. A model can often consist of hundreds or -thousands of FX nodes, so narrowing the exact nodes where this problem -occurred can be very difficult. Fortunately, there are tools available to -automatically minify these input graphs to the nodes which are causing -the issue. The first step is to determine whether the error occurs -during tracing of the backward graph with AOTAutograd or during -TorchInductor lowering. As mentioned above in step 2, the -``"aot_eager"`` backend can be used to run only AOTAutograd in isolation -without lowering. If the error still occurs with this backend, this -indicates that the error is occurring during AOTAutograd tracing. - -Here is an example: - -.. code-block:: py - - import torch - - import torch._dynamo as dynamo - - model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) - - def test_backend_error(): - - y = torch.ones(200, 200) - x = torch.ones(200, 200) - z = x + y - a = torch.ops.aten._foobar(z) # dummy function which errors - return model(a) - - - compiled_test_backend_error = torch.compile(test_backend_error, backend="inductor") - compiled_test_backend_error() - -Running this should give you this error with a longer stack trace below -it: - -:: - - Traceback (most recent call last): - File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function - return lowerings[target](*args, **kwargs) - File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped - return decomp_fn(*args, **kwargs) - File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar - assert False - AssertionError - ... - -`error with full stack -trace `__ - -If you then change ``torch.compile(backend="inductor")`` to -``torch.compile(backend="aot_eager")``, it will run without error, because -`the -issue `__ -is in the TorchInductor lowering process, not in AOTAutograd. - -Minifying TorchInductor Errors ------------------------------- - -From here, let’s run the minifier to get a minimal repro. Setting the -environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` (or setting -``torch._dynamo.config.repro_after="aot"`` directly) will generate a -Python program which reduces the graph produced by AOTAutograd to the -smallest subgraph which reproduces the error. (See below for an example -where we minify the graph produced by TorchDynamo) Running the program -with this environment variable should show nearly `identical -output `__, -with an additional line indicating where ``minifier_launcher.py`` has -been written to. The output directory is configurable by setting -``torch._dynamo.config.base_dir`` to a valid directory name. The final -step is to run the minifier and check that it runs successfully. A -successful run looks like -`this `__. -If the minifier runs successfully, it generates runnable python code -which reproduces the exact error. For our example this is the following -code: - -.. code-block:: python - - import torch - from torch import tensor, device - import torch.fx as fx - from torch._dynamo.testing import rand_strided - from math import inf - from torch.fx.experimental.proxy_tensor import make_fx - - # torch version: 1.13.0a0+gitfddfc44 - # torch cuda version: 11.6 - # torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5 - - - # CUDA Info: - # nvcc: NVIDIA (R) Cuda compiler driver - # Copyright (c) 2005-2022 NVIDIA Corporation - # Built on Thu_Feb_10_18:23:41_PST_2022 - # Cuda compilation tools, release 11.6, V11.6.112 - # Build cuda_11.6.r11.6/compiler.30978841_0 - - # GPU Hardware Info: - # NVIDIA A100-SXM4-40GB : 8 - - from torch.nn import * - - class Repro(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, add): - _foobar = torch.ops.aten._foobar.default(add); add = None - return (_foobar,) - - args = [((200, 200), (200, 1), torch.float32, 'cpu')] - args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args] - mod = make_fx(Repro())(*args) - from torch._inductor.compile_fx import compile_fx_inner - - compiled = compile_fx_inner(mod, args) - compiled(*args) - -The ``forward`` method of the ``Repro`` module contains the exact op -which causes the issue. When filing an issue, please include any -minified repros to aid in debugging. - -Minifying Backend Compiler Errors ---------------------------------- - -With backend compilers other than TorchInductor the process for finding -the subgraph causing the error is nearly identical to the procedure in -`errors in TorchInductor <#torchinductor-errors>`__ with one important -caveat. Namely, that the minifier will now be run on the graph that is -traced by TorchDynamo, not the output graph of AOTAutograd. Let’s walk -through an example. - -.. code-block:: py - - import torch - - import torch._dynamo as dynamo - - model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) - # toy compiler which fails if graph contains relu - def toy_compiler(gm: torch.fx.GraphModule, _): - for node in gm.graph.nodes: - if node.target == torch.relu: - assert False - - return gm - - - def test_backend_error(): - y = torch.ones(200, 200) - x = torch.ones(200, 200) - z = x + y - a = torch.relu(z) - return model(a) - - - compiled_test_backend_error = torch.compile(test_backend_error, backend=toy_compiler) - compiled_test_backend_error() - -In order to run the code after TorchDynamo has traced the forward graph, -you can use the ``TORCHDYNAMO_REPRO_AFTER`` environment variable. Running -this program with ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` (or -``torch._dynamo.config.repro_after="dynamo"``) should produce `this -output `__\ and -the following code in ``{torch._dynamo.config.base_dir}/repro.py``. - -.. note:: The other option for TORCHDYNAMO_REPRO_AFTER is ``"aot"``, which - will run the minifier after the backward graph has been generated. - -.. code-block:: python - - import torch - import torch._dynamo as dynamo - from torch import tensor, device - import torch.fx as fx - from torch._dynamo.testing import rand_strided - from math import inf - from torch._dynamo.debug_utils import run_fwd_maybe_bwd - - from torch.nn import * - - class Repro(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, add): - relu = torch.relu(add); add = None - return (relu,) - - - mod = Repro().cuda() - opt_mod = torch.compile(mod, backend="None") - - - args = [((200, 200), (200, 1), torch.float32, 'cpu', False)] - args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] - - - with torch.cuda.amp.autocast(enabled=False): - ref = run_fwd_maybe_bwd(mod, args) - res = run_fwd_maybe_bwd(opt_mod, args) - -The minifier successfully reduced the graph to the op that raises the -error in ``toy_compiler``. The other difference from the procedure in -`TorchInductor Errors <#torchinductor-errors>`__ is that the minifier is -automatically run after encountering a backend compiler error. After a -successful run, the minifier writes ``repro.py`` to -``torch._dynamo.config.base_dir``. - -Performance Profiling -~~~~~~~~~~~~~~~~~~~~~ - -Accessing TorchDynamo Profiler ------------------------------- - -TorchDynamo has a built-in stats function for collecting and displaying -the time spent in each compilation phase. These stats can be accessed by -calling ``torch._dynamo.utils.compile_times()`` after executing -Torch._Dynamo. By default, this returns a string representation of the -compile times spent in each TorchDynamo function by name. - -TorchInductor Debugging using TORCH_COMPILE_DEBUG -------------------------------------------------- - -TorchInductor has a builtin stats and trace function for displaying time -spent in each compilation phase, output code, output graph visualization -and IR dump. This is a debugging tool designed to make it easier to -understand and troubleshoot the internals of TorchInductor. - -Let's run an example with the following test program (``repro.py``): - -:: - - import torch - - @torch.compile() - def test_model(x): - model = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.LayerNorm(10), - torch.nn.ReLU(), - ) - return model(x) - - - y = test_model(torch.ones(10, 10)) - -Setting the environment variable ``TORCH_COMPILE_DEBUG=1`` will cause a -debug trace directory to be created, by default this directory will be in the -current directory and named torch_compile_debug (this can be overridden in -the torchdynamo configuration field ``debug_dir_root`` and also the -``env var TORCH_COMPILE_DEBUG_DIR``). Inside this directory, each run will -have a separate folder named with the timestamp and process id of the run: - -:: - - $ env TORCH_COMPILE_DEBUG=1 python repro.py - $ cd torch_compile_debug - $ ls - run_2023_03_01_08_20_52_143510-pid_180167 - -In the run folder there will be a ``torchdynamo`` directory which contains -debug logs, and an ``torchinductor`` folder which contains a subfolder for each -compiled kernel with inductor debug artifacts. - -:: - - $ cd - run_2023_03_01_08_20_52_143510-pid_180167 - $ ls - torchinductor torchdynamo - -Moving further into the ``torchinductor`` directory, the ``\*.log`` files are -logs from the AOT Autograd phase of compilation, ``model__0_forward_1.0`` contains -the inductor debug artifacts. - -:: - - $ cd torchinductor - $ ls - aot_model___0_debug.log model__0_forward_1.0 - $ cd model__0_forward_1.0 - $ ls - debug.log fx_graph_readable.py fx_graph_runnable.py fx_graph_transformed.py ir_post_fusion.txt ir_pre_fusion.txt output_code.py - -Here is a summary of the contents: - -- ``fx_graph_readable.py`` and ``fx_graph_runnable.py`` are the readable and - runnable versions of the ``fx_graph`` received by inductor. -- ``fx_graph_transformed.py`` is the fx graph after inductor has run all fx passes. -- ``ir\*.txt`` is the inductor ir pre and post fusion. -- ``output_code.py`` is the compiled triton kernel for the subgraph. - -Here are `example debug directory contents -`__ -for the test program: - -:: - - import torch - - @torch.compile() - def test_model(x): - model = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.LayerNorm(10), - torch.nn.ReLU(), - ) - return model(x) - - - y = test_model(torch.ones(10, 10)) - -Each file in that debug trace can be enabled and disabled through -``torch._inductor.config.trace.*``. The profile and the diagram are both -disabled by default since they are expensive to generate. - -A single node in this new debug format looks like: - -:: - - buf1: SchedulerNode(ComputedBuffer) - buf1.writes = - { MemoryDep(name='buf1', index=0, size=()), - MemoryDep(name='buf1', index=0, size=(s0,))} - buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))} - buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))} - buf1.group.device = cuda:0 - buf1.group.iteration = (1, s0) - buf1.sizes = ([], [s0]) - class buf1_loop_body: - var_ranges = {z0: s0} - index0 = z0 - index1 = 0 - def body(self, ops): - get_index = self.get_index('index0') - load = ops.load('buf0', get_index, False) - get_index_1 = self.get_index('index0') - load_1 = ops.load('primals_2', get_index_1, False) - add = ops.add(load, load_1) - get_index_2 = self.get_index('index1') - reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add) - return reduction - -See the `example debug directory -output `__ -for more examples. - -.. - _Memory Profiling - ---------------- - - TBD - -Graph Breaks ------------- - -Given a program like this: - -.. code-block:: python - - def some_fun(x): - ... - - compiled_fun = torch.compile(some_fun, ...) - ... - -TorchDynamo will attempt to compile all of the torch/tensor operations -within some_fun into a single FX graph, but it may fail to capture -everything into one graph. - -Some graph break reasons are insurmountable to TorchDynamo, and can't be -easily fixed. - calling into a C extension other than torch is invisible -to torchdynamo, and could do arbitrary things without TorchDynamo being -able to introduce necessary guards (see :ref:`making-dynamo-sound-guards`) -to ensure that the compiled program would be safe to reuse. Graph breaks -can hinder performance if the resulting fragments are small. To maximize -performance, it's important to have as few graph breaks as possible. - -Identifying the Cause of a Graph Break -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To identify all graph breaks in a program and the associated reasons for -the breaks, ``torch._dynamo.explain`` can be used. This tool runs -TorchDynamo on the supplied function and aggregates the graph breaks -that are encountered. Here is an example usage: - -.. code-block:: python - - import torch - import torch._dynamo as dynamo - def toy_example(a, b): - x = a / (torch.abs(a) + 1) - print("woo") - if b.sum() < 0: - b = b * -1 - return x * b - explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10)) - print(explanation_verbose) - """ - Graph Count: 3 - Graph Break Count: 2 - Op Count: 5 - Break Reasons: - Break Reason 1: - Reason: builtin: print [] False - User Stack: - - Break Reason 2: - Reason: generic_jump TensorVariable() - User Stack: - - Ops per Graph: - ... - Out Guards: - ... - """ - -Outputs include: - -- ``out_guards`` - a list of lists where each sublist contains the guards that must pass to ensure the traced graphs are valid. -- ``graphs`` - a list of graph modules which were successfully traced. -- ``ops_per_graph`` - a list of lists where each sublist contains the ops that are run in the graph. - -To throw an error on the first graph break encountered, use the ``fullgraph`` -mode. This mode disables TorchDynamo’s Python fallback, and only -succeeds if the entire program is convertible into a single graph. Example -usage: - -.. code-block:: python - - def toy_example(a, b): - ... - - compiled_toy = torch.compile(toy_example, fullgraph=True, backend=)(a, b) - -Excessive Recompilation ------------------------ - -When TorchDynamo compiles a function (or part of one), it makes certain -assumptions about locals and globals in order to allow compiler -optimizations, and expresses these assumptions as guards that check -particular values at runtime. If any of these guards fail, Dynamo will -recompile that function (or part) up to -``torch._dynamo.config.recompile_limit`` times. If your program is -hitting the cache limit, you will first need to determine which guard is -failing and what part of your program is triggering it. - -If your program exhibits a bounded amount of dynamism, you may be able -to tune the TorchDynamo cache limit to allow for each variation to be -compiled and cached, but if the cache limit is too high you may find the -cost of recompilation outweighs any optimization benefits. - -:: - - torch._dynamo.config.recompile_limit = - -TorchDynamo plans to support many common cases of dynamic tensor shapes, -such as varying batch size or sequence length. It does not plan to -support rank-dynamism. In the meantime, setting a specific cache limit -can be used in coordination with bucketing techniques to achieve an -acceptable number of recompilations for some dynamic models. - -Accuracy Debugging -~~~~~~~~~~~~~~~~~~ - -Accuracy issues can also be minified if you set the environment variable -``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect -model and a full repro might be something like -``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason -we need this is downstream compilers will codegen code whether it’s -Triton code or the C++ backend, the numerics from those downstream -compilers can be different in subtle ways yet have dramatic impact on -your training stability. So the accuracy debugger is very useful for us -to detect bugs in our codegen or with a backend compiler. - -If you'd like to ensure that random number generation is the same across both torch -and triton then you can enable ``torch._inductor.config.fallback_random = True`` - -Extended Debugging -~~~~~~~~~~~~~~~~~~ - -Extended debugging can be enabled by using the following experimental flags. - -``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED`` - provides extended debug information if the -string representation of a guard matches this flag value. For example, set it to -"Ne(s0, 10)" to generate full Python and C++ backtrace whenever guard was issued. -``TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL`` - provides extended debug information when -a particular symbol is allocated. For example, set this to "u2" to generate full Python -and C++ backtrace whenever this symbol was created. -``TORCHDYNAMO_EXTENDED_DEBUG_CPP`` - provides extended debug information (C++ backtrace) -for all extended debug settings as well as errors. For example, set this to "1". The C++ -backtrace is slow and very spammy so it is not included by default with extended debugging. - -Cold Start Timing and Cache Corruption Debugging -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In order to measure the cold start compilation time or debug a cache corruption, -it is possible pass ``TORCHINDUCTOR_FORCE_DISABLE_CACHES=1`` or set -``torch._inductor.config.force_disable_caches = True`` which will override any -other caching config option and disable all compile time caching. diff --git a/docs/source/torch.overrides.md b/docs/source/torch.overrides.md new file mode 100644 index 00000000000000..42e75bab950893 --- /dev/null +++ b/docs/source/torch.overrides.md @@ -0,0 +1,49 @@ +```{eval-rst} +.. currentmodule:: torch.overrides +``` + +# torch.overrides +```{eval-rst} +.. py:module:: torch.overrides +``` + +This module exposes various helper functions for the ``__torch_function__`` +protocol. See {ref}`extending-torch-python` for more details on the +``__torch_function__`` protocol. + +## Functions +```{eval-rst} +.. autofunction:: get_ignored_functions +``` + +```{eval-rst} +.. autofunction:: get_overridable_functions +``` + +```{eval-rst} +.. autofunction:: resolve_name +``` + +```{eval-rst} +.. autofunction:: get_testing_overrides +``` + +```{eval-rst} +.. autofunction:: handle_torch_function +``` + +```{eval-rst} +.. autofunction:: has_torch_function +``` + +```{eval-rst} +.. autofunction:: is_tensor_like +``` + +```{eval-rst} +.. autofunction:: is_tensor_method_or_property +``` + +```{eval-rst} +.. autofunction:: wrap_torch_function +``` diff --git a/docs/source/torch.overrides.rst b/docs/source/torch.overrides.rst deleted file mode 100644 index 5695372240fea8..00000000000000 --- a/docs/source/torch.overrides.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. currentmodule:: torch.overrides - -torch.overrides ---------------- -.. py:module:: torch.overrides - -This module exposes various helper functions for the ``__torch_function__`` -protocol. See :ref:`extending-torch-python` for more details on the -``__torch_function__`` protocol. - -Functions -~~~~~~~~~ - -.. autofunction:: get_ignored_functions - -.. autofunction:: get_overridable_functions - -.. autofunction:: resolve_name - -.. autofunction:: get_testing_overrides - -.. autofunction:: handle_torch_function - -.. autofunction:: has_torch_function - -.. autofunction:: is_tensor_like - -.. autofunction:: is_tensor_method_or_property - -.. autofunction:: wrap_torch_function diff --git a/docs/source/torch_cuda_memory.md b/docs/source/torch_cuda_memory.md new file mode 100644 index 00000000000000..bb50e5fd575137 --- /dev/null +++ b/docs/source/torch_cuda_memory.md @@ -0,0 +1,97 @@ +(torch_cuda_memory)= + +# Understanding CUDA Memory Usage + +To debug CUDA memory use, PyTorch provides a way to generate memory snapshots that record the state of allocated CUDA memory +at any point in time, and optionally record the history of allocation events that led up to that snapshot. + +The generated snapshots can then be drag and dropped onto the interactiver viewer hosted at [pytorch.org/memory_viz](https://pytorch.org/memory_viz) which +can be used to explore the snapshot. + +```{note} +The memory profiler and visualizer described in this document only have visibility into the CUDA memory that is +allocated and managed through the PyTorch allocator. Any memory allocated directly from CUDA APIs will not be +visible in the PyTorch memory profiler. + +NCCL (used for distributed communication on CUDA devices) is a common example of a library that allocates some +GPU memory that is invisible to the PyTorch memory profiler. See {ref}`non_pytorch_alloc` for more info. +``` + +## Generating a Snapshot + +The common pattern for recording a snapshot is to enable memory history, run the code to be observed, and then save a file with a pickled snapshot: + +```python +# enable memory history, which will +# add tracebacks and event history to snapshots +torch.cuda.memory._record_memory_history() + +run_your_code() +torch.cuda.memory._dump_snapshot("my_snapshot.pickle") +``` + +## Using the visualizer + +Open [pytorch.org/memory_viz](https://pytorch.org/memory_viz>) and drag/drop the pickled snapshot file into the visualizer. +The visualizer is a javascript application that runs locally on your computer. It does not upload any snapshot data. + + +## Active Memory Timeline + +The Active Memory Timeline shows all the live tensors over time in the snapshot on a particular GPU. Pan/Zoom over the plot to look at smaller allocations. +Mouse over allocated blocks to see a stack trace for when that block was allocated, and details like its address. The detail slider can be adjusted to +render fewer allocations and improve performance when there is a lot of data. + +```{image} _static/img/torch_cuda_memory/active_memory_timeline.png +``` + + +## Allocator State History + +The Allocator State History shows individual allocator events in a timeline on the left. Select an event in the timeline to see a visual summary of the +allocator state at that event. This summary shows each individual segment returned from cudaMalloc and how it is split up into blocks of individual allocations +or free space. Mouse over segments and blocks to see the stack trace when the memory was allocated. Mouse over events to see the stack trace when the event occurred, +such as when a tensor was freed. Out of memory errors are reported as OOM events. Looking at the state of memory during an OOM may provide insight into why +an allocation failed even though reserved memory still exists. + +```{image} _static/img/torch_cuda_memory/allocator_state_history.png +``` + +The stack trace information also reports the address at which an allocation occurred. +The address b7f064c000000_0 refers to the (b)lock at address 7f064c000000 which is the "_0"th time this address was allocated. +This unique string can be looked up in the Active Memory Timeline and searched +in the Active State History to examine the memory state when a tensor was allocated or freed. + +(non_pytorch_alloc)= +## Identifying Non-PyTorch allocations + +If you suspect CUDA memory is being allocated outside of PyTorch, you can collect the raw CUDA allocation info using +the pynvml package, and compare that to the allocation reported by pytorch. + + +To collect raw memory usage outside pytorch, use {func}`device_memory_used` + +```python +import torch +device_idx = ... +print(torch.cuda.device_memory_used(device_idx)) +``` + +## Snapshot API Reference + +```{eval-rst} +.. currentmodule:: torch.cuda.memory +``` + +```{eval-rst} +.. autofunction:: _record_memory_history +``` + +```{eval-rst} +.. autofunction:: _snapshot +``` + + +```{eval-rst} +.. autofunction:: _dump_snapshot +``` diff --git a/docs/source/torch_cuda_memory.rst b/docs/source/torch_cuda_memory.rst deleted file mode 100644 index 697ba26714860d..00000000000000 --- a/docs/source/torch_cuda_memory.rst +++ /dev/null @@ -1,90 +0,0 @@ -.. _torch_cuda_memory: - -Understanding CUDA Memory Usage -=============================== -To debug CUDA memory use, PyTorch provides a way to generate memory snapshots that record the state of allocated CUDA memory -at any point in time, and optionally record the history of allocation events that led up to that snapshot. - -The generated snapshots can then be drag and dropped onto the interactiver viewer hosted at `pytorch.org/memory_viz `_ which -can be used to explore the snapshot. - -.. note:: - - The memory profiler and visualizer described in this document only have visibility into the CUDA memory that is - allocated and managed through the PyTorch allocator. Any memory allocated directly from CUDA APIs will not be - visible in the PyTorch memory profiler. - - NCCL (used for distributed communication on CUDA devices) is a common example of a library that allocates some - GPU memory that is invisible to the PyTorch memory profiler. See :ref:`non_pytorch_alloc` for more info. - -Generating a Snapshot -===================== -The common pattern for recording a snapshot is to enable memory history, run the code to be observed, and then save a file with a pickled snapshot: - -.. code-block:: python - - # enable memory history, which will - # add tracebacks and event history to snapshots - torch.cuda.memory._record_memory_history() - - run_your_code() - torch.cuda.memory._dump_snapshot("my_snapshot.pickle") - -Using the visualizer -==================== - -Open `pytorch.org/memory_viz `_ and drag/drop the pickled snapshot file into the visualizer. -The visualizer is a javascript application that runs locally on your computer. It does not upload any snapshot data. - - -Active Memory Timeline ----------------------- - -The Active Memory Timeline shows all the live tensors over time in the snapshot on a particular GPU. Pan/Zoom over the plot to look at smaller allocations. -Mouse over allocated blocks to see a stack trace for when that block was allocated, and details like its address. The detail slider can be adjusted to -render fewer allocations and improve performance when there is a lot of data. - -.. image:: _static/img/torch_cuda_memory/active_memory_timeline.png - - -Allocator State History ------------------------ - -The Allocator State History shows individual allocator events in a timeline on the left. Select an event in the timeline to see a visual summary of the -allocator state at that event. This summary shows each individual segment returned from cudaMalloc and how it is split up into blocks of individual allocations -or free space. Mouse over segments and blocks to see the stack trace when the memory was allocated. Mouse over events to see the stack trace when the event occurred, -such as when a tensor was freed. Out of memory errors are reported as OOM events. Looking at the state of memory during an OOM may provide insight into why -an allocation failed even though reserved memory still exists. - -.. image:: _static/img/torch_cuda_memory/allocator_state_history.png - -The stack trace information also reports the address at which an allocation occurred. -The address b7f064c000000_0 refers to the (b)lock at address 7f064c000000 which is the "_0"th time this address was allocated. -This unique string can be looked up in the Active Memory Timeline and searched -in the Active State History to examine the memory state when a tensor was allocated or freed. - -.. _non_pytorch_alloc: - -Identifying Non-PyTorch allocations ------------------------------------ - -If you suspect CUDA memory is being allocated outside of PyTorch, you can collect the raw CUDA allocation info using -the pynvml package, and compare that to the allocation reported by pytorch. - - -To collect raw memory usage outside pytorch, use :func:`device_memory_used`: - -.. code:: - - import torch - device_idx = ... - print(torch.cuda.device_memory_used(device_idx)) - - -Snapshot API Reference -====================== - -.. currentmodule:: torch.cuda.memory -.. autofunction:: _record_memory_history -.. autofunction:: _snapshot -.. autofunction:: _dump_snapshot diff --git a/docs/source/torch_environment_variables.md b/docs/source/torch_environment_variables.md new file mode 100644 index 00000000000000..7bf429db0033d3 --- /dev/null +++ b/docs/source/torch_environment_variables.md @@ -0,0 +1,29 @@ +(torch_environment_variables)= +# Torch Environment Variables + +PyTorch leverages environment variables for adjusting various settings that influence its runtime behavior. +These variables offer control over key functionalities, such as displaying the C++ stack trace upon encountering errors, synchronizing the execution of CUDA kernels, +specifying the number of threads for parallel processing tasks and many more. + +Moreover, PyTorch leverages several high-performance libraries, such as MKL and cuDNN, +which also utilize environment variables to modify their functionality. +This interplay of settings allows for a highly customizable development environment that can be +optimized for efficiency, debugging, and computational resource management. + +Please note that while this documentation covers a broad spectrum of environment variables relevant to PyTorch and its associated libraries, it is not exhaustive. +If you find anything in this documentation that is missing, incorrect, or could be improved, please let us know by filing an issue or opening a pull request. + + +```{eval-rst} +.. toctree:: + :maxdepth: 1 + + threading_environment_variables + cuda_environment_variables + mps_environment_variables + debugging_environment_variables + miscellaneous_environment_variables + logging + torch_nccl_environment_variables + +``` diff --git a/docs/source/torch_environment_variables.rst b/docs/source/torch_environment_variables.rst deleted file mode 100644 index fddb090690a1a8..00000000000000 --- a/docs/source/torch_environment_variables.rst +++ /dev/null @@ -1,28 +0,0 @@ -.. _torch_environment_variables: - -Torch Environment Variables -=============================== - -PyTorch leverages environment variables for adjusting various settings that influence its runtime behavior. -These variables offer control over key functionalities, such as displaying the C++ stack trace upon encountering errors, synchronizing the execution of CUDA kernels, -specifying the number of threads for parallel processing tasks and many more. - -Moreover, PyTorch leverages several high-performance libraries, such as MKL and cuDNN, -which also utilize environment variables to modify their functionality. -This interplay of settings allows for a highly customizable development environment that can be -optimized for efficiency, debugging, and computational resource management. - -Please note that while this documentation covers a broad spectrum of environment variables relevant to PyTorch and its associated libraries, it is not exhaustive. -If you find anything in this documentation that is missing, incorrect, or could be improved, please let us know by filing an issue or opening a pull request. - - -.. toctree:: - :maxdepth: 1 - - threading_environment_variables - cuda_environment_variables - mps_environment_variables - debugging_environment_variables - miscellaneous_environment_variables - logging - torch_nccl_environment_variables diff --git a/docs/source/torch_nccl_environment_variables.md b/docs/source/torch_nccl_environment_variables.md new file mode 100644 index 00000000000000..8293cdbbfc1767 --- /dev/null +++ b/docs/source/torch_nccl_environment_variables.md @@ -0,0 +1,41 @@ +(_torch_nccl_environment_variables)= +# PYTORCH ProcessGroupNCCL Environment Variables + +For more information on the environment variables, see [ProcessGroupNCCL Environment Variables](https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp). + +```{list-table} +:header-rows: 1 + +* - **Variable** + - **Description** +* - ``TORCH_NCCL_ASYNC_ERROR_HANDLING`` + - Control how we perform Async Error Handling with NCCL when an exception is observed in watchdog. If set to 0, no handling of asynchronous NCCL errors. If set to 1, aborting NCCL communicator and tearing down process upon error. If set to 2, only abort NCCL communicator and if set to 3, tearing down process without aborting NCCL communicator. By default, it is set to 3. +* - ``TORCH_NCCL_HIGH_PRIORITY`` + - Control whether to use high priority stream for the NCCL communicator. +* - ``TORCH_NCCL_BLOCKING_WAIT`` + - Control whether or not wait() is blocking or non-blocking. +* - ``TORCH_NCCL_DUMP_ON_TIMEOUT`` + - Control whether dumping debug info on watchdog timeout or exception is detected. This variable must be set together with TORCH_NCCL_TRACE_BUFFER_SIZE larger than 0. +* - ``TORCH_NCCL_DESYNC_DEBUG`` + - Control whether Desync Debug is enabled. This is helpful in figuring out the culprit rank of collective desync. +* - ``TORCH_NCCL_ENABLE_TIMING`` + - If set to ``1``, enable recording start-events for all ProcessGroupNCCL collectives, and compute accurate collective timing per-collective. +* - ``TORCH_NCCL_ENABLE_MONITORING`` + - If set to ``1``,enable monitoring thread which aborts the process when the ProcessGroupNCCL Watchdog thread gets stuck and no heartbeat is detected after TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged time than necessary tying up cluster resources. +* - ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC`` + - Control the watchdog heartbeat timeout period after which the monitoring thread will abort the process. +* - ``TORCH_NCCL_TRACE_BUFFER_SIZE`` + - The maximum number of events we store in the flight recorder's ring buffer. One event could be the start or end of a collective, for example. Set to 0 to disable the tracebuffer and debugging info dump. +* - ``TORCH_NCCL_TRACE_CPP_STACK`` + - Whether to collect cpp stack traces for flight recorder. Default value is False. +* - ``TORCH_NCCL_COORD_CHECK_MILSEC`` + - Control the interval inside the monitoring thread to check the coordinated signal from other ranks, e.g. to dump the debugging information. Default value is 1000 ms. +* - ``TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC`` + - Control how much extra time we will wait for dumping the debugging info before we exit and throws timeout exception. +* - ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE`` + - The file into which the debugging info would be dumped. +* - ``TORCH_NCCL_DEBUG_INFO_PIPE_FILE`` + - The pipe file to trigger debugging dump manually, write anything into the pipe would trigger the dump. +* - ``TORCH_NCCL_NAN_CHECK`` + - Control whether to enable NAN check for the input, Error would be thrown if NAN is detected. +``` diff --git a/docs/source/torch_nccl_environment_variables.rst b/docs/source/torch_nccl_environment_variables.rst deleted file mode 100644 index 0d9070e34c72c1..00000000000000 --- a/docs/source/torch_nccl_environment_variables.rst +++ /dev/null @@ -1,41 +0,0 @@ -.. _torch_nccl_environment_variables: - -PYTORCH ProcessGroupNCCL Environment Variables -============================================== -For more information on the environment variables, see `ProcessGroupNCCL Environment Variables `_. - -.. list-table:: - :header-rows: 1 - - * - Variable - - Description - * - ``TORCH_NCCL_ASYNC_ERROR_HANDLING`` - - Control how we perform Async Error Handling with NCCL when an exception is observed in watchdog. If set to 0, no handling of asynchronous NCCL errors. If set to 1, aborting NCCL communicator and tearing down process upon error. If set to 2, only abort NCCL communicator and if set to 3, tearing down process without aborting NCCL communicator. By default, it is set to 3. - * - ``TORCH_NCCL_HIGH_PRIORITY`` - - Control whether to use high priority stream for the NCCL communicator. - * - ``TORCH_NCCL_BLOCKING_WAIT`` - - Control whether or not wait() is blocking or non-blocking. - * - ``TORCH_NCCL_DUMP_ON_TIMEOUT`` - - Control whether dumping debug info on watchdog timeout or exception is detected. This variable must be set together with TORCH_NCCL_TRACE_BUFFER_SIZE larger than 0. - * - ``TORCH_NCCL_DESYNC_DEBUG`` - - Control whether Desync Debug is enabled. This is helpful in figuring out the culprit rank of collective desync. - * - ``TORCH_NCCL_ENABLE_TIMING`` - - If set to ``1``, enable recording start-events for all ProcessGroupNCCL collectives, and compute accurate collective timing per-collective. - * - ``TORCH_NCCL_ENABLE_MONITORING`` - - If set to ``1``,enable monitoring thread which aborts the process when the ProcessGroupNCCL Watchdog thread gets stuck and no heartbeat is detected after TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged time than necessary tying up cluster resources. - * - ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC`` - - Control the watchdog heartbeat timeout period after which the monitoring thread will abort the process. - * - ``TORCH_NCCL_TRACE_BUFFER_SIZE`` - - The maximum number of events we store in the flight recorder's ring buffer. One event could be the start or end of a collective, for example. Set to 0 to disable the tracebuffer and debugging info dump. - * - ``TORCH_NCCL_TRACE_CPP_STACK`` - - Whether to collect cpp stack traces for flight recorder. Default value is False. - * - ``TORCH_NCCL_COORD_CHECK_MILSEC`` - - Control the interval inside the monitoring thread to check the coordinated signal from other ranks, e.g. to dump the debugging information. Default value is 1000 ms. - * - ``TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC`` - - Control how much extra time we will wait for dumping the debugging info before we exit and throws timeout exception. - * - ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE`` - - The file into which the debugging info would be dumped. - * - ``TORCH_NCCL_DEBUG_INFO_PIPE_FILE`` - - The pipe file to trigger debugging dump manually, write anything into the pipe would trigger the dump. - * - ``TORCH_NCCL_NAN_CHECK`` - - Control whether to enable NAN check for the input, Error would be thrown if NAN is detected. diff --git a/docs/source/type_info.md b/docs/source/type_info.md new file mode 100644 index 00000000000000..9fc2ce56c4bea8 --- /dev/null +++ b/docs/source/type_info.md @@ -0,0 +1,61 @@ +```{eval-rst} +.. currentmodule:: torch +``` + +(type-info-doc)= +# Type Info + +The numerical properties of a {class}`torch.dtype` can be accessed through either the {class}`torch.finfo` or the {class}`torch.iinfo`. + +(finfo-doc)= +## torch.finfo + +```{eval-rst} +.. class:: torch.finfo +``` + +A {class}`torch.finfo` is an object that represents the numerical properties of a floating point +{class}`torch.dtype`, (i.e. ``torch.float32``, ``torch.float64``, ``torch.float16``, and ``torch.bfloat16``). +This is similar to [numpy.finfo](https://numpy.org/doc/stable/reference/generated/numpy.finfo.html). + +A {class}`torch.finfo` provides the following attributes: + +| Name | Type | Description | +| :-------------- | :---- | :------------------------------------------------------------------------- | +| bits | int | The number of bits occupied by the type. | +| eps | float | The smallest representable number such that ``1.0 + eps != 1.0``. | +| max | float | The largest representable number. | +| min | float | The smallest representable number (typically ``-max``). | +| tiny | float | The smallest positive normal number. Equivalent to ``smallest_normal``. | +| smallest_normal | float | The smallest positive normal number. See notes. | +| resolution | float | The approximate decimal resolution of this type, i.e., ``10**-precision``. | + +```{note} + The constructor of {class}`torch.finfo` can be called without argument, + in which case the class is created for the pytorch default dtype (as returned by {func}`torch.get_default_dtype`). +``` + +```{note} + `smallest_normal` returns the smallest *normal* number, but there are smaller + subnormal numbers. See https://en.wikipedia.org/wiki/Denormal_number + for more information. +``` + +(iinfo-doc)= +## torch.iinfo + +```{eval-rst} +.. class:: torch.iinfo +``` + +A {class}`torch.iinfo` is an object that represents the numerical properties of a integer +{class}`torch.dtype` (i.e. ``torch.uint8``, ``torch.int8``, ``torch.int16``, ``torch.int32``, and ``torch.int64``). +This is similar to [numpy.iinfo](https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html). + +A {class}`torch.iinfo` provides the following attributes: + +| Name | Type | Description | +| :--- | :--- | :--------------------------------------- | +| bits | int | The number of bits occupied by the type. | +| max | int | The largest representable number. | +| min | int | The smallest representable number. | diff --git a/docs/source/type_info.rst b/docs/source/type_info.rst deleted file mode 100644 index 29a5ca28269735..00000000000000 --- a/docs/source/type_info.rst +++ /dev/null @@ -1,62 +0,0 @@ -.. currentmodule:: torch - -.. _type-info-doc: - -Type Info -========= - -The numerical properties of a :class:`torch.dtype` can be accessed through either the :class:`torch.finfo` or the :class:`torch.iinfo`. - -.. _finfo-doc: - -torch.finfo ------------ - -.. class:: torch.finfo - -A :class:`torch.finfo` is an object that represents the numerical properties of a floating point -:class:`torch.dtype`, (i.e. ``torch.float32``, ``torch.float64``, ``torch.float16``, and ``torch.bfloat16``). This is similar to `numpy.finfo `_. - -A :class:`torch.finfo` provides the following attributes: - -=============== ===== ========================================================================== -Name Type Description -=============== ===== ========================================================================== -bits int The number of bits occupied by the type. -eps float The smallest representable number such that ``1.0 + eps != 1.0``. -max float The largest representable number. -min float The smallest representable number (typically ``-max``). -tiny float The smallest positive normal number. Equivalent to ``smallest_normal``. -smallest_normal float The smallest positive normal number. See notes. -resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``. -=============== ===== ========================================================================== - -.. note:: - The constructor of :class:`torch.finfo` can be called without argument, in which case the class is created for the pytorch default dtype (as returned by :func:`torch.get_default_dtype`). - -.. note:: - `smallest_normal` returns the smallest *normal* number, but there are smaller - subnormal numbers. See https://en.wikipedia.org/wiki/Denormal_number - for more information. - - -.. _iinfo-doc: - -torch.iinfo ------------- - -.. class:: torch.iinfo - - -A :class:`torch.iinfo` is an object that represents the numerical properties of a integer -:class:`torch.dtype` (i.e. ``torch.uint8``, ``torch.int8``, ``torch.int16``, ``torch.int32``, and ``torch.int64``). This is similar to `numpy.iinfo `_. - -A :class:`torch.iinfo` provides the following attributes: - -========= ===== ======================================== -Name Type Description -========= ===== ======================================== -bits int The number of bits occupied by the type. -max int The largest representable number. -min int The smallest representable number. -========= ===== ======================================== diff --git a/docs/source/utils.md b/docs/source/utils.md new file mode 100644 index 00000000000000..6742866a8b25bd --- /dev/null +++ b/docs/source/utils.md @@ -0,0 +1,97 @@ +# torch.utils +```{eval-rst} +.. automodule:: torch.utils +``` + +```{eval-rst} +.. currentmodule:: torch.utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + rename_privateuse1_backend + generate_methods_for_privateuse1_backend + get_cpp_backtrace + set_module + swap_tensors +``` + + +```{eval-rst} +.. py:module:: torch.utils.backend_registration +.. py:module:: torch.utils.benchmark.examples.compare +.. py:module:: torch.utils.benchmark.examples.fuzzer +.. py:module:: torch.utils.benchmark.examples.op_benchmark +.. py:module:: torch.utils.benchmark.examples.simple_timeit +.. py:module:: torch.utils.benchmark.examples.spectral_ops_fuzz_test +.. py:module:: torch.utils.benchmark.op_fuzzers.binary +.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_binary +.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_unary +.. py:module:: torch.utils.benchmark.op_fuzzers.spectral +.. py:module:: torch.utils.benchmark.op_fuzzers.unary +.. py:module:: torch.utils.benchmark.utils.common +.. py:module:: torch.utils.benchmark.utils.compare +.. py:module:: torch.utils.benchmark.utils.compile +.. py:module:: torch.utils.benchmark.utils.cpp_jit +.. py:module:: torch.utils.benchmark.utils.fuzzer +.. py:module:: torch.utils.benchmark.utils.sparse_fuzzer +.. py:module:: torch.utils.benchmark.utils.timer +.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface +.. py:module:: torch.utils.bundled_inputs +.. py:module:: torch.utils.checkpoint +.. py:module:: torch.utils.collect_env +.. py:module:: torch.utils.cpp_backtrace +.. py:module:: torch.utils.cpp_extension +.. py:module:: torch.utils.data.backward_compatibility +.. py:module:: torch.utils.data.dataloader +.. py:module:: torch.utils.data.datapipes.dataframe.dataframe_wrapper +.. py:module:: torch.utils.data.datapipes.dataframe.dataframes +.. py:module:: torch.utils.data.datapipes.dataframe.datapipes +.. py:module:: torch.utils.data.datapipes.dataframe.structures +.. py:module:: torch.utils.data.datapipes.datapipe +.. py:module:: torch.utils.data.datapipes.gen_pyi +.. py:module:: torch.utils.data.datapipes.iter.callable +.. py:module:: torch.utils.data.datapipes.iter.combinatorics +.. py:module:: torch.utils.data.datapipes.iter.combining +.. py:module:: torch.utils.data.datapipes.iter.filelister +.. py:module:: torch.utils.data.datapipes.iter.fileopener +.. py:module:: torch.utils.data.datapipes.iter.grouping +.. py:module:: torch.utils.data.datapipes.iter.routeddecoder +.. py:module:: torch.utils.data.datapipes.iter.selecting +.. py:module:: torch.utils.data.datapipes.iter.sharding +.. py:module:: torch.utils.data.datapipes.iter.streamreader +.. py:module:: torch.utils.data.datapipes.iter.utils +.. py:module:: torch.utils.data.datapipes.map.callable +.. py:module:: torch.utils.data.datapipes.map.combinatorics +.. py:module:: torch.utils.data.datapipes.map.combining +.. py:module:: torch.utils.data.datapipes.map.grouping +.. py:module:: torch.utils.data.datapipes.map.utils +.. py:module:: torch.utils.data.datapipes.utils.common +.. py:module:: torch.utils.data.datapipes.utils.decoder +.. py:module:: torch.utils.data.datapipes.utils.snapshot +.. py:module:: torch.utils.data.dataset +.. py:module:: torch.utils.data.distributed +.. py:module:: torch.utils.data.graph +.. py:module:: torch.utils.data.graph_settings +.. py:module:: torch.utils.data.sampler +.. py:module:: torch.utils.dlpack +.. py:module:: torch.utils.file_baton +.. py:module:: torch.utils.flop_counter +.. py:module:: torch.utils.hipify.constants +.. py:module:: torch.utils.hipify.cuda_to_hip_mappings +.. py:module:: torch.utils.hipify.hipify_python +.. py:module:: torch.utils.hipify.version +.. py:module:: torch.utils.hooks +.. py:module:: torch.utils.jit.log_extract +.. py:module:: torch.utils.mkldnn +.. py:module:: torch.utils.mobile_optimizer +.. py:module:: torch.utils.show_pickle +.. py:module:: torch.utils.tensorboard.summary +.. py:module:: torch.utils.tensorboard.writer +.. py:module:: torch.utils.throughput_benchmark +.. py:module:: torch.utils.weak +``` diff --git a/docs/source/utils.rst b/docs/source/utils.rst deleted file mode 100644 index 307872f359d06b..00000000000000 --- a/docs/source/utils.rst +++ /dev/null @@ -1,89 +0,0 @@ -torch.utils -=================================== -.. automodule:: torch.utils -.. currentmodule:: torch.utils - -.. autosummary:: - :toctree: generated - :nosignatures: - - rename_privateuse1_backend - generate_methods_for_privateuse1_backend - get_cpp_backtrace - set_module - swap_tensors - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.utils.backend_registration -.. py:module:: torch.utils.benchmark.examples.compare -.. py:module:: torch.utils.benchmark.examples.fuzzer -.. py:module:: torch.utils.benchmark.examples.op_benchmark -.. py:module:: torch.utils.benchmark.examples.simple_timeit -.. py:module:: torch.utils.benchmark.examples.spectral_ops_fuzz_test -.. py:module:: torch.utils.benchmark.op_fuzzers.binary -.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_binary -.. py:module:: torch.utils.benchmark.op_fuzzers.sparse_unary -.. py:module:: torch.utils.benchmark.op_fuzzers.spectral -.. py:module:: torch.utils.benchmark.op_fuzzers.unary -.. py:module:: torch.utils.benchmark.utils.common -.. py:module:: torch.utils.benchmark.utils.compare -.. py:module:: torch.utils.benchmark.utils.compile -.. py:module:: torch.utils.benchmark.utils.cpp_jit -.. py:module:: torch.utils.benchmark.utils.fuzzer -.. py:module:: torch.utils.benchmark.utils.sparse_fuzzer -.. py:module:: torch.utils.benchmark.utils.timer -.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface -.. py:module:: torch.utils.bundled_inputs -.. py:module:: torch.utils.checkpoint -.. py:module:: torch.utils.collect_env -.. py:module:: torch.utils.cpp_backtrace -.. py:module:: torch.utils.cpp_extension -.. py:module:: torch.utils.data.backward_compatibility -.. py:module:: torch.utils.data.dataloader -.. py:module:: torch.utils.data.datapipes.dataframe.dataframe_wrapper -.. py:module:: torch.utils.data.datapipes.dataframe.dataframes -.. py:module:: torch.utils.data.datapipes.dataframe.datapipes -.. py:module:: torch.utils.data.datapipes.dataframe.structures -.. py:module:: torch.utils.data.datapipes.datapipe -.. py:module:: torch.utils.data.datapipes.gen_pyi -.. py:module:: torch.utils.data.datapipes.iter.callable -.. py:module:: torch.utils.data.datapipes.iter.combinatorics -.. py:module:: torch.utils.data.datapipes.iter.combining -.. py:module:: torch.utils.data.datapipes.iter.filelister -.. py:module:: torch.utils.data.datapipes.iter.fileopener -.. py:module:: torch.utils.data.datapipes.iter.grouping -.. py:module:: torch.utils.data.datapipes.iter.routeddecoder -.. py:module:: torch.utils.data.datapipes.iter.selecting -.. py:module:: torch.utils.data.datapipes.iter.sharding -.. py:module:: torch.utils.data.datapipes.iter.streamreader -.. py:module:: torch.utils.data.datapipes.iter.utils -.. py:module:: torch.utils.data.datapipes.map.callable -.. py:module:: torch.utils.data.datapipes.map.combinatorics -.. py:module:: torch.utils.data.datapipes.map.combining -.. py:module:: torch.utils.data.datapipes.map.grouping -.. py:module:: torch.utils.data.datapipes.map.utils -.. py:module:: torch.utils.data.datapipes.utils.common -.. py:module:: torch.utils.data.datapipes.utils.decoder -.. py:module:: torch.utils.data.datapipes.utils.snapshot -.. py:module:: torch.utils.data.dataset -.. py:module:: torch.utils.data.distributed -.. py:module:: torch.utils.data.graph -.. py:module:: torch.utils.data.graph_settings -.. py:module:: torch.utils.data.sampler -.. py:module:: torch.utils.dlpack -.. py:module:: torch.utils.file_baton -.. py:module:: torch.utils.flop_counter -.. py:module:: torch.utils.hipify.constants -.. py:module:: torch.utils.hipify.cuda_to_hip_mappings -.. py:module:: torch.utils.hipify.hipify_python -.. py:module:: torch.utils.hipify.version -.. py:module:: torch.utils.hooks -.. py:module:: torch.utils.jit.log_extract -.. py:module:: torch.utils.mkldnn -.. py:module:: torch.utils.mobile_optimizer -.. py:module:: torch.utils.show_pickle -.. py:module:: torch.utils.tensorboard.summary -.. py:module:: torch.utils.tensorboard.writer -.. py:module:: torch.utils.throughput_benchmark -.. py:module:: torch.utils.weak diff --git a/docs/source/xpu.md b/docs/source/xpu.md new file mode 100644 index 00000000000000..46d36451d4b8a7 --- /dev/null +++ b/docs/source/xpu.md @@ -0,0 +1,92 @@ +# torch.xpu +```{eval-rst} +.. automodule:: torch.xpu +``` +```{eval-rst} +.. currentmodule:: torch.xpu +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + StreamContext + current_device + current_stream + device + device_count + device_of + get_arch_list + get_device_capability + get_device_name + get_device_properties + get_gencode_flags + get_stream_from_external + init + is_available + is_initialized + set_device + set_stream + stream + synchronize +``` + +## Random Number Generator +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + get_rng_state + get_rng_state_all + initial_seed + manual_seed + manual_seed_all + seed + seed_all + set_rng_state + set_rng_state_all +``` + +## Streams and events +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Event + Stream +``` + +```{eval-rst} +.. automodule:: torch.xpu.memory +``` +```{eval-rst} +.. currentmodule:: torch.xpu.memory +``` + +## Memory management +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + mem_get_info + memory_allocated + memory_reserved + memory_stats + memory_stats_as_nested_dict + reset_accumulated_memory_stats + reset_peak_memory_stats +``` + + +```{eval-rst} +.. py:module:: torch.xpu.random +.. py:module:: torch.xpu.streams +``` diff --git a/docs/source/xpu.rst b/docs/source/xpu.rst deleted file mode 100644 index 2b1010fb1c03ce..00000000000000 --- a/docs/source/xpu.rst +++ /dev/null @@ -1,78 +0,0 @@ -torch.xpu -=================================== -.. automodule:: torch.xpu -.. currentmodule:: torch.xpu - -.. autosummary:: - :toctree: generated - :nosignatures: - - StreamContext - current_device - current_stream - device - device_count - device_of - get_arch_list - get_device_capability - get_device_name - get_device_properties - get_gencode_flags - get_stream_from_external - init - is_available - is_initialized - set_device - set_stream - stream - synchronize - -Random Number Generator -------------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - get_rng_state - get_rng_state_all - initial_seed - manual_seed - manual_seed_all - seed - seed_all - set_rng_state - set_rng_state_all - -Streams and events ------------------- -.. autosummary:: - :toctree: generated - :nosignatures: - - Event - Stream - - -Memory management ------------------ -.. autosummary:: - :toctree: generated - :nosignatures: - - empty_cache - max_memory_allocated - max_memory_reserved - mem_get_info - memory_allocated - memory_reserved - memory_stats - memory_stats_as_nested_dict - reset_accumulated_memory_stats - reset_peak_memory_stats - - -.. This module needs to be documented. Adding here in the meantime -.. for tracking purposes -.. py:module:: torch.xpu.memory -.. py:module:: torch.xpu.random -.. py:module:: torch.xpu.streams diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 33e1c080dabda7..19270d2f9225de 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -710,7 +710,7 @@ struct Tensor : public mpy::base { auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); tensor_ = t->tensor(A); delayed_.reset(); - // don't force creation of batch tensor if it wasn't alreay provided. + // don't force creation of batch tensor if it wasn't already provided. batchtensor_ = t->batchtensor_; AT_ASSERT(levels() == t->levels()); } @@ -1739,7 +1739,7 @@ static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice if (lr_dims.dims.size() != sum.size()) { for (auto & d : sum) { if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { - mpy::raise_error(DimensionBindError(), "summing over non-existant dimension %S", d.dim().ptr()); + mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr()); } } } @@ -2206,7 +2206,7 @@ mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indi self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op - // we need to be careful not to rely the dimensions size because it doesnt match the size of the whole group + // we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group } bool has_dimpacks = false; for (auto idx : indices_list) { @@ -2219,7 +2219,7 @@ mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indi return invoke_getitem(A, info); } -// true -- the indices were flattend out of a tuple, list or sequence... +// true -- the indices were flattened out of a tuple, list or sequence... Slice slice_from_sequence(Arena& A, mpy::handle value) { if (mpy::tuple_view::check(value)) { @@ -2539,7 +2539,7 @@ IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice } } else if (Dim::check_exact(inp)) { auto d = Dim::unchecked_wrap(inp); - // dimesions used once are just binding operations + // dimensions used once are just binding operations if (1 == seen_dims_nuses[*seen_dims.index(d)]) { flat_inputs[i] = no_slice; result_levels.append(A, d); @@ -2798,7 +2798,7 @@ PyObject* py_split(PyObject *_, if (!dim.ptr()) { dim = A.autorelease(mpy::from_int(0)); } - mpy::raise_error(PyExc_TypeError, "tensor does not comtain dimension %R", dim.ptr()); + mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); } Slice indices; diff --git a/functorch/csrc/dim/python_variable_simple.h b/functorch/csrc/dim/python_variable_simple.h index fbd5cfd8281579..d8c22ca312e35a 100644 --- a/functorch/csrc/dim/python_variable_simple.h +++ b/functorch/csrc/dim/python_variable_simple.h @@ -6,7 +6,7 @@ #pragma once // note: pytorch's python variable simple includes pybind which conflicts with minpybind -// so this file just reproduces the minimial API needed to extract Tensors from python objects. +// so this file just reproduces the minimal API needed to extract Tensors from python objects. #include #include diff --git a/functorch/dim/README.md b/functorch/dim/README.md index 74c25d949c0ba3..517930cb844b57 100644 --- a/functorch/dim/README.md +++ b/functorch/dim/README.md @@ -5,7 +5,7 @@ Named Tensors using First-class Dimensions in PyTorch _An implementation of [named tensors](https://namedtensor.github.io) with the functionality of [einsum](http://einops.rocks]http://einops.rocks) , batching ([vmap](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap), [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html)), and tensor indexing by adding dimension objects to PyTorch_. -The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Eventhough 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension. +The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Even though 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension. Named tensors gives these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing. @@ -751,7 +751,7 @@ In this way, first-class dims are a way of adapting the nicer syntax of these ar Performance Expectations ======================== -First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can encorporate more fusion optimization to further improve performance of this style of code. +First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve performance of this style of code. ## License diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index 691b1b984f8d55..f52d417d2ba275 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -58,7 +58,7 @@ def __repr__(self): class Dim(_C.Dim, _Tensor): - # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence. + # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence. # Tensor defines format, but we want to print Dims with special formatting __format__ = object.__format__ diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py index 5c6178c0981c75..fd934011d82388 100644 --- a/functorch/dim/reference.py +++ b/functorch/dim/reference.py @@ -507,7 +507,7 @@ def add_dims(t): for i in reversed(dim_packs): input[i : i + 1] = input[i] - # currenty: + # currently: # input is flat, containing either Dim, or Tensor, or something valid for standard indexing # self may have first-class dims as well. @@ -515,7 +515,7 @@ def add_dims(t): # drop the first class dims from self, they just become direct indices of their positions # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. - # these dimensions will appear and need to be bound at the first place tensor occures + # these dimensions will appear and need to be bound at the first place tensor occurs if isinstance(self, _Tensor): ptensor_self, levels = self._tensor, list(self._levels) diff --git a/functorch/examples/ensembling/parallel_train.py b/functorch/examples/ensembling/parallel_train.py index a674a24c738dc3..0a9abddc9cb5e3 100644 --- a/functorch/examples/ensembling/parallel_train.py +++ b/functorch/examples/ensembling/parallel_train.py @@ -138,7 +138,7 @@ def step6(): # Step 7: Now, the flaw with step 6 is that we were training on the same exact # data. This can lead to all of the models in the ensemble overfitting in the # same way. The solution that http://willwhitney.com/parallel-training-jax.html -# applies is to randomly subset the data in a way that the models do not recieve +# applies is to randomly subset the data in a way that the models do not receive # exactly the same data in each training step! # Because the goal of this doc is to show that we can use eager-mode vmap to # achieve similar things as JAX, the rest of this is left as an exercise to the reader. diff --git a/functorch/examples/lennard_jones/lennard_jones.py b/functorch/examples/lennard_jones/lennard_jones.py index 30a50c14a7f794..7d8a6be445ab83 100644 --- a/functorch/examples/lennard_jones/lennard_jones.py +++ b/functorch/examples/lennard_jones/lennard_jones.py @@ -1,4 +1,4 @@ -# This example was adapated from https://github.com/muhrin/milad +# This example was adapted from https://github.com/muhrin/milad # It is licensed under the GLPv3 license. You can find a copy of it # here: https://www.gnu.org/licenses/gpl-3.0.en.html . diff --git a/functorch/notebooks/_src/plot_jacobians_and_hessians.py b/functorch/notebooks/_src/plot_jacobians_and_hessians.py index 3faeaa9a167521..cab6a0d989edbb 100644 --- a/functorch/notebooks/_src/plot_jacobians_and_hessians.py +++ b/functorch/notebooks/_src/plot_jacobians_and_hessians.py @@ -100,7 +100,7 @@ def compute_jac(xp): # vjp and vmap transforms. # - jacfwd uses forward-mode AD. It is implemented as a composition of our # jvp and vmap transforms. -# jacfwd and jacrev can be subsituted for each other and have different +# jacfwd and jacrev can be substituted for each other and have different # performance characteristics. # # As a general rule of thumb, if you're computing the jacobian of an R^N -> R^M diff --git a/functorch/notebooks/jacobians_hessians.ipynb b/functorch/notebooks/jacobians_hessians.ipynb index 5b986a592b722c..4acf2ec609ff34 100644 --- a/functorch/notebooks/jacobians_hessians.ipynb +++ b/functorch/notebooks/jacobians_hessians.ipynb @@ -350,7 +350,7 @@ { "cell_type": "markdown", "source": [ - "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." + "Furthermore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." ], "metadata": { "id": "EQAB99EQflUJ" diff --git a/functorch/notebooks/per_sample_grads.ipynb b/functorch/notebooks/per_sample_grads.ipynb index a34c80d07ac4ac..e2317351f7eb1c 100644 --- a/functorch/notebooks/per_sample_grads.ipynb +++ b/functorch/notebooks/per_sample_grads.ipynb @@ -123,7 +123,7 @@ "predictions = model(data) # move the entire mini-batch through the model\n", "\n", "loss = loss_fn(predictions, targets)\n", - "loss.backward() # back propogate the 'average' gradient of this mini-batch" + "loss.backward() # back propagate the 'average' gradient of this mini-batch" ], "metadata": { "id": "WYjMx8QTUvRu" diff --git a/mypy-strict.ini b/mypy-strict.ini index 2feea92cb8c056..dddbb623047f71 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -20,7 +20,7 @@ disallow_any_unimported = True strict = True implicit_reexport = False -# do not reenable this: +# do not re-enable this: # https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657 warn_unused_ignores = False diff --git a/mypy.ini b/mypy.ini index 65f9ee43a6b8d1..e6a8af4c88c20c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ # test_run_mypy in test/test_type_hints.py uses this string) [mypy] -plugins = mypy_plugins/check_mypy_version.py, mypy_plugins/sympy_mypy_plugin.py, numpy.typing.mypy_plugin +plugins = mypy_plugins/check_mypy_version.py, mypy_plugins/sympy_mypy_plugin.py cache_dir = .mypy_cache/normal allow_redefinition = True @@ -17,7 +17,7 @@ follow_imports = normal local_partial_types = True enable_error_code = possibly-undefined -# do not reenable this: +# do not re-enable this: # https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657 warn_unused_ignores = False @@ -55,9 +55,6 @@ python_version = 3.11 # Extension modules without stubs. # -[mypy-torch._C._jit_tree_views] -ignore_missing_imports = True - [mypy-torch.for_onnx.onnx] ignore_missing_imports = True @@ -311,4 +308,4 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-redis] -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index 6d42026ba6ca98..d3a8dcabaa7ed3 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -210,7 +210,7 @@ def get_metal_registration_files_outs(): # There is a really weird issue with the arvr windows builds where # the custom op files are breaking them. See https://fburl.com/za87443c -# The hack is just to not build them for that platform and pray they arent needed. +# The hack is just to not build them for that platform and pray they aren't needed. def get_metal_registration_files_outs_windows(): outs = {} for file_path in METAL_SOURCE_LIST: diff --git a/pyproject.toml b/pyproject.toml index 054eb4d6ecb760..a2939483fc355c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,46 +1,76 @@ -[project] -name = "torch" -requires-python = ">=3.9" -license = {text = "BSD-3-Clause"} -dynamic = [ - "authors", - "classifiers", - "entry-points", - "dependencies", - "description", - "keywords", - "optional-dependencies", - "readme", - "scripts", - "version", -] - -[project.urls] -Homepage = "https://pytorch.org/" -Documentation = "https://pytorch.org/docs/" -Source = "https://github.com/pytorch/pytorch" -Forum = "https://discuss.pytorch.org/" - +# Package ###################################################################### [build-system] requires = [ # After 75.8.2 dropped dep disttools API. Please fix # API temporarily restored and shim used. Please fix # Setuptools will drop support for setup.py past 80 - # min version for recursive glob package data support + # 62.3.0: min version for recursive glob package data support + # 77.0.0: min version for SPDX expression support for project.license "setuptools>=62.3.0,<80.0", "wheel", "astunparse", - "numpy", + "cmake>=3.27", "ninja", + "numpy", + "packaging", "pyyaml", - "cmake", - "typing-extensions>=4.10.0", "requests", + "typing-extensions>=4.10.0", +] +build-backend = "setuptools.build_meta" + +[project] +name = "torch" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +readme = "README.md" +requires-python = ">=3.9,<3.14" +# TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 +# FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. +# TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed +# to an error on 2026.02.18. See also: https://github.com/pypa/setuptools/issues/4903 +license = { text = "BSD-3-Clause" } +authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] +keywords = ["pytorch", "machine learning"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: C++", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dynamic = [ + "entry-points", + "dependencies", + "scripts", + "version", ] -# Use legacy backend to import local packages in setup.py -build-backend = "setuptools.build_meta:__legacy__" +[project.urls] +Homepage = "https://pytorch.org" +Repository = "https://github.com/pytorch/pytorch" +Documentation = "https://pytorch.org/docs" +"Issue Tracker" = "https://github.com/pytorch/pytorch/issues" +Forum = "https://discuss.pytorch.org" + +[project.optional-dependencies] +optree = ["optree>=0.13.0"] +opt-einsum = ["opt-einsum>=3.3"] +pyyaml = ["pyyaml"] + +# Linter tools ################################################################# [tool.black] line-length = 88 @@ -59,12 +89,10 @@ multi_line_output = 3 include_trailing_comma = true combine_as_imports = true - [tool.usort.known] first_party = ["caffe2", "torch", "torchgen", "functorch", "test"] standard_library = ["typing_extensions"] - [tool.ruff] line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -272,10 +300,6 @@ select = [ "F401", "F403", ] -"torchgen/executorch/api/types/__init__.py" = [ - "F401", - "F403", -] "torch/utils/collect_env.py" = [ "UP", # collect_env.py needs to work with older versions of Python ] @@ -285,3 +309,6 @@ select = [ "tools/linter/**" = [ "LOG015" # please fix ] + +[tool.codespell] +ignore-words = "tools/linter/dictionary.txt" diff --git a/requirements.txt b/requirements.txt index 18f7810de95121..9bd9b54258146c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # Python dependencies required for development astunparse -cmake +build[uv] # for building sdist and wheel +cmake>=3.27 expecttest>=0.3.0 filelock fsspec @@ -19,4 +20,4 @@ requests setuptools>=62.3.0,<80.0 sympy>=1.13.3 types-dataclasses -typing-extensions>=4.10.0 +typing-extensions>=4.13.2 diff --git a/scripts/build_android.sh b/scripts/build_android.sh index de0bed7c26d417..43f11b86828d4c 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -157,7 +157,7 @@ if [ -n "${USE_VULKAN}" ]; then fi fi -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) # Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh index 06cae0dd41a3c3..7b1995a61ebc75 100755 --- a/scripts/build_mobile.sh +++ b/scripts/build_mobile.sh @@ -80,7 +80,7 @@ if [ "${VERBOSE:-}" == '1' ]; then CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") fi -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=("$@") # Now, actually build the Android target. diff --git a/scripts/compile_tests/download_reports.py b/scripts/compile_tests/download_reports.py index 03804b11f7eb60..1f48522a5e2d46 100644 --- a/scripts/compile_tests/download_reports.py +++ b/scripts/compile_tests/download_reports.py @@ -9,24 +9,26 @@ CONFIGS = { "dynamo39": { - "linux-focal-py3.9-clang10 / test (dynamo_wrapped, 1, 3, linux.2xlarge)", - "linux-focal-py3.9-clang10 / test (dynamo_wrapped, 2, 3, linux.2xlarge)", - "linux-focal-py3.9-clang10 / test (dynamo_wrapped, 3, 3, linux.2xlarge)", + "linux-jammy-py3.9-clang12 / test (dynamo_wrapped, 1, 3, linux.2xlarge)", + "linux-jammy-py3.9-clang12 / test (dynamo_wrapped, 2, 3, linux.2xlarge)", + "linux-jammy-py3.9-clang12 / test (dynamo_wrapped, 3, 3, linux.2xlarge)", }, - "dynamo311": { - "linux-focal-py3.11-clang10 / test (dynamo_wrapped, 1, 3, linux.2xlarge)", - "linux-focal-py3.11-clang10 / test (dynamo_wrapped, 2, 3, linux.2xlarge)", - "linux-focal-py3.11-clang10 / test (dynamo_wrapped, 3, 3, linux.2xlarge)", + "dynamo313": { + "linux-jammy-py3.13-clang12 / test (dynamo_wrapped, 1, 3, linux.2xlarge)", + "linux-jammy-py3.13-clang12 / test (dynamo_wrapped, 2, 3, linux.2xlarge)", + "linux-jammy-py3.13-clang12 / test (dynamo_wrapped, 3, 3, linux.2xlarge)", }, - "eager311": { - "linux-focal-py3.11-clang10 / test (default, 1, 3, linux.2xlarge)", - "linux-focal-py3.11-clang10 / test (default, 2, 3, linux.2xlarge)", - "linux-focal-py3.11-clang10 / test (default, 3, 3, linux.2xlarge)", + "eager313": { + "linux-jammy-py3.13-clang12 / test (default, 1, 5, linux.4xlarge)", + "linux-jammy-py3.13-clang12 / test (default, 2, 5, linux.4xlarge)", + "linux-jammy-py3.13-clang12 / test (default, 3, 5, linux.4xlarge)", + "linux-jammy-py3.13-clang12 / test (default, 4, 5, linux.4xlarge)", + "linux-jammy-py3.13-clang12 / test (default, 5, 5, linux.4xlarge)", }, } -def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")): +def download_reports(commit_sha, configs=("dynamo39", "dynamo313", "eager313")): log_dir = "tmp_test_reports_" + commit_sha def subdir_path(config): diff --git a/scripts/compile_tests/update_failures.py b/scripts/compile_tests/update_failures.py index 73fb354a8d1533..2e38738059a018 100755 --- a/scripts/compile_tests/update_failures.py +++ b/scripts/compile_tests/update_failures.py @@ -221,5 +221,5 @@ def read_test_results(directory): args = parser.parse_args() assert Path(args.filename).exists(), args.filename assert Path(args.test_dir).exists(), args.test_dir - dynamo39, dynamo311 = download_reports(args.commit, ("dynamo39", "dynamo311")) - update(args.filename, args.test_dir, dynamo39, dynamo311, args.also_remove_skips) + dynamo39, dynamo313 = download_reports(args.commit, ("dynamo39", "dynamo313")) + update(args.filename, args.test_dir, dynamo39, dynamo313, args.also_remove_skips) diff --git a/scripts/install_triton_wheel.sh b/scripts/install_triton_wheel.sh index a3e1736362a5b0..5c4e72b8d3e217 100755 --- a/scripts/install_triton_wheel.sh +++ b/scripts/install_triton_wheel.sh @@ -1,5 +1,7 @@ #!/bin/bash # Updates Triton to the pinned version for this copy of PyTorch +PYTHON="python3" +PIP="$PYTHON -m pip" BRANCH=$(git rev-parse --abbrev-ref HEAD) DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" @@ -7,25 +9,14 @@ if [[ -z "${USE_XPU}" ]]; then # Default install from PyTorch source TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" - if [[ "$BRANCH" =~ .*release.* ]]; then - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION - else - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git$(head -c 8 .ci/docker/ci_commit_pins/triton.txt) - fi + TRITON_COMMIT_ID="$(head -c 8 .ci/docker/ci_commit_pins/triton.txt)" else - # The Triton xpu logic is as follows: - # 1. By default, install pre-built whls. - # 2. [Not exposed to user] If the user set `TRITON_XPU_BUILD_FROM_SOURCE=1` flag, - # it will install Triton from the source. - - TRITON_VERSION="pytorch-triton-xpu==$(cat .ci/docker/triton_version.txt)" - TRITON_XPU_COMMIT_ID="$(head -c 8 .ci/docker/ci_commit_pins/triton-xpu.txt)" - if [[ -z "${TRITON_XPU_BUILD_FROM_SOURCE}" ]]; then - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ ${TRITON_VERSION}+git${TRITON_XPU_COMMIT_ID} - else - TRITON_XPU_REPO="https://github.com/intel/intel-xpu-backend-for-triton" + TRITON_VERSION="pytorch-triton-xpu==$(cat .ci/docker/triton_xpu_version.txt)" + TRITON_COMMIT_ID="$(head -c 8 .ci/docker/ci_commit_pins/triton-xpu.txt)" +fi - # force-reinstall to ensure the pinned version is installed - pip install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python" - fi +if [[ "$BRANCH" =~ .*release.* ]]; then + ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION +else + ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git${TRITON_COMMIT_ID} fi diff --git a/scripts/jit/log_extract.py b/scripts/jit/log_extract.py index 95d882b461d487..60aeaab92fc8ca 100644 --- a/scripts/jit/log_extract.py +++ b/scripts/jit/log_extract.py @@ -95,7 +95,7 @@ def run(): "--no-nnc-dynamic", dest="nnc_dynamic", action="store_false", - help="DONT't benchmark nnc with dynamic shapes", + help="don't benchmark nnc with dynamic shapes", ) parser.set_defaults(nnc_dynamic=False) diff --git a/scripts/lint_urls.sh b/scripts/lint_urls.sh index c8f4e183e17704..0a8fddce00b364 100755 --- a/scripts/lint_urls.sh +++ b/scripts/lint_urls.sh @@ -19,7 +19,7 @@ while IFS=: read -r filepath url; do code=$(curl -k -gsLm30 --retry 3 --retry-delay 3 --retry-connrefused -o /dev/null -w "%{http_code}" -I "$url") || code=000 if [ "$code" -lt 200 ] || [ "$code" -ge 400 ]; then sleep 1 - code=$(curl -k -gsLm30 --retry 3 --retry-delay 3 --retry-connrefused -o /dev/null -w "%{http_code}" -r 0-0 -A "$user_agent" "$url") || code=000 + code=$(curl -k -gsLm30 --retry 3 --retry-delay 3 --retry-connrefused -o /dev/null -w "%{http_code}" -r 0-0 -A "$user_agent" -H "Accept-Language: en-US,en" -H "Connection: keep-alive" "$url") || code=000 fi if [ "$code" -lt 200 ] || [ "$code" -ge 400 ]; then sleep 1 @@ -62,7 +62,7 @@ while IFS=: read -r filepath url; do sleep 1 done done < <( - pattern='(?!.*@lint-ignore)(?\")]*[<>\{\}\$])[^[:space:]<>")\[\]\\|]+' + pattern='(?!.*@lint-ignore)(?\")]*[<>\{\}\$])[[:alnum:]][^[:space:]<>")\[\]\\|]*' excludes=( ':(exclude,glob)**/.*' ':(exclude,glob)**/*.lock' diff --git a/scripts/release/README.md b/scripts/release/README.md index 5f35d4e4d771c0..bc32bd0cb656c6 100644 --- a/scripts/release/README.md +++ b/scripts/release/README.md @@ -10,8 +10,7 @@ These are a collection of scripts that are to be used for release activities. ### Order of Execution 1. Run cut-release-branch.sh to cut the release branch -2. Run tag-docker-images.sh to tag current docker images with release tag and push them to docker.io. These images will be used to build the release. -3. Run apply-release-changes.sh to apply release only changes to create a PR with release only changes similar to this [PR](https://github.com/pytorch/pytorch/pull/149056) +2. Run apply-release-changes.sh to apply release only changes to create a PR with release only changes similar to this [PR](https://github.com/pytorch/pytorch/pull/149056) #### Promoting packages diff --git a/scripts/release/tag-docker-images.sh b/scripts/release/tag-docker-images.sh deleted file mode 100644 index f2299d6c463ee2..00000000000000 --- a/scripts/release/tag-docker-images.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env bash -# -# Step 1 after branch cut is complete. -# -# Tags latest docker images for release branch. -# In case of failure. The script can be rerun. -# -# Before executing this script do: -# 1. Create and Check out to Release Branch -# git checkout -b "${RELEASE_BRANCH}" -# 2. Update submodules -# git submodule update --init --recursive -# -# Usage (run from root of project): -# DRY_RUN=disabled ./scripts/release/tag-docker-images.sh -# - -set -eou pipefail - -GIT_TOP_DIR=$(git rev-parse --show-toplevel) -RELEASE_VERSION=${RELEASE_VERSION:-$(cut -d'.' -f1-2 "${GIT_TOP_DIR}/version.txt")} -DRY_RUN=${DRY_RUN:-enabled} - -python3 .github/scripts/tag_docker_images_for_release.py --version ${RELEASE_VERSION} --dry-run ${DRY_RUN} diff --git a/scripts/release_notes/apply_categories.py b/scripts/release_notes/apply_categories.py index 786b1a95908b1c..9711737fc6537e 100644 --- a/scripts/release_notes/apply_categories.py +++ b/scripts/release_notes/apply_categories.py @@ -1,4 +1,4 @@ -# Quick scipt to apply categorized items to the +# Quick script to apply categorized items to the # base commitlist . Useful if you are refactoring any code # but want to keep the previous data on categories diff --git a/scripts/release_notes/classifier.py b/scripts/release_notes/classifier.py index c64bad818e4efe..a517ea7e77da56 100644 --- a/scripts/release_notes/classifier.py +++ b/scripts/release_notes/classifier.py @@ -156,9 +156,9 @@ def convert_index_to_category_name(self, most_likely_index): elif isinstance(most_likely_index, torch.Tensor): return [self.categories[i] for i in most_likely_index] - def get_most_likely_category_name(self, inpt): + def get_most_likely_category_name(self, input): # Input will be a dict with title and author keys - logits = self.forward(inpt) + logits = self.forward(input) most_likely_index = torch.argmax(logits, dim=1) return self.convert_index_to_category_name(most_likely_index) @@ -264,9 +264,9 @@ def generate_batch(batch): def train_step(batch, model, optimizer, loss): - inpt, targets = batch + input, targets = batch optimizer.zero_grad() - output = model(inpt) + output = model(input) l = loss(output, targets) l.backward() optimizer.step() @@ -275,8 +275,8 @@ def train_step(batch, model, optimizer, loss): @torch.no_grad() def eval_step(batch, model, loss): - inpt, targets = batch - output = model(inpt) + input, targets = batch + output = model(input) l = loss(output, targets) return l diff --git a/setup.py b/setup.py index 9ee5600be12d7d..ab42a9aa562e54 100644 --- a/setup.py +++ b/setup.py @@ -30,13 +30,19 @@ # CC # the C/C++ compiler to use # +# CMAKE_FRESH=1 +# force a fresh cmake configuration run, ignoring the existing cmake cache +# +# CMAKE_ONLY=1 +# run cmake and stop; do not build the project +# # Environment variables for feature toggles: # # DEBUG_CUDA=1 # if used in conjunction with DEBUG or REL_WITH_DEB_INFO, will also # build CUDA kernels with -lineinfo --source-in-ptx. Note that # on CUDA 12 this may cause nvcc to OOM, so this is disabled by default. - +# # USE_CUDNN=0 # disables the cuDNN build # @@ -221,66 +227,146 @@ # BUILD_PYTHON_ONLY # Builds pytorch as a wheel using libtorch.so from a separate wheel +from __future__ import annotations + import os import sys if sys.platform == "win32" and sys.maxsize.bit_length() == 31: print( - "32-bit Windows Python runtime is not supported. Please switch to 64-bit Python." + "32-bit Windows Python runtime is not supported. " + "Please switch to 64-bit Python.", + file=sys.stderr, ) sys.exit(-1) import platform -BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1" -BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1" - +# Also update `project.requires-python` in pyproject.toml when changing this python_min_version = (3, 9, 0) python_min_version_str = ".".join(map(str, python_min_version)) if sys.version_info < python_min_version: print( - f"You are using Python {platform.python_version()}. Python >={python_min_version_str} is required." + f"You are using Python {platform.python_version()}. " + f"Python >={python_min_version_str} is required.", + file=sys.stderr, ) sys.exit(-1) import filecmp import glob import importlib -import importlib.util +import itertools import json import shutil import subprocess import sysconfig import time from collections import defaultdict +from pathlib import Path +from typing import Any, ClassVar, IO import setuptools.command.build_ext -import setuptools.command.install import setuptools.command.sdist -from setuptools import Extension, find_packages, setup +import setuptools.errors +from setuptools import Command, Extension, find_packages, setup from setuptools.dist import Distribution + + +CWD = Path(__file__).absolute().parent + +# Add the current directory to the Python path so that we can import `tools`. +# This is required when running this script with a PEP-517-enabled build backend. +# +# From the PEP-517 documentation: https://peps.python.org/pep-0517 +# +# > When importing the module path, we do *not* look in the directory containing +# > the source tree, unless that would be on `sys.path` anyway (e.g. because it +# > is specified in `PYTHONPATH`). +# +sys.path.insert(0, str(CWD)) # this only affects the current process +# Add the current directory to PYTHONPATH so that we can import `tools` in subprocesses +os.environ["PYTHONPATH"] = os.pathsep.join( + [ + str(CWD), + os.getenv("PYTHONPATH", ""), + ] +).rstrip(os.pathsep) + from tools.build_pytorch_libs import build_pytorch from tools.generate_torch_version import get_torch_version -from tools.setup_helpers.cmake import CMake -from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS +from tools.setup_helpers.cmake import CMake, CMakeValue +from tools.setup_helpers.env import ( + BUILD_DIR, + build_type, + IS_DARWIN, + IS_LINUX, + IS_WINDOWS, +) from tools.setup_helpers.generate_linker_script import gen_linker_script -def _get_package_path(package_name): - spec = importlib.util.find_spec(package_name) +def str2bool(value: str | None) -> bool: + """Convert environment variables to boolean values.""" + if not value: + return False + if not isinstance(value, str): + raise ValueError( + f"Expected a string value for boolean conversion, got {type(value)}" + ) + value = value.strip().lower() + if value in ( + "1", + "true", + "t", + "yes", + "y", + "on", + "enable", + "enabled", + "found", + ): + return True + if value in ( + "0", + "false", + "f", + "no", + "n", + "off", + "disable", + "disabled", + "notfound", + "none", + "null", + "nil", + "undefined", + "n/a", + ): + return False + raise ValueError(f"Invalid string value for boolean conversion: {value}") + + +def _get_package_path(package_name: str) -> Path: + from importlib.util import find_spec + + spec = find_spec(package_name) if spec: # The package might be a namespace package, so get_data may fail try: loader = spec.loader if loader is not None: file_path = loader.get_filename() # type: ignore[attr-defined] - return os.path.dirname(file_path) + return Path(file_path).parent except AttributeError: pass - return None + return CWD / package_name + +BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL")) +BUILD_PYTHON_ONLY = str2bool(os.getenv("BUILD_PYTHON_ONLY")) # set up appropriate env variables if BUILD_LIBTORCH_WHL: @@ -288,22 +374,21 @@ def _get_package_path(package_name): # functorch is not supported without python os.environ["BUILD_FUNCTORCH"] = "OFF" - if BUILD_PYTHON_ONLY: os.environ["BUILD_LIBTORCHLESS"] = "ON" - os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib" + os.environ["LIBTORCH_LIB_PATH"] = (_get_package_path("torch") / "lib").as_posix() ################################################################################ # Parameters parsed from environment ################################################################################ -VERBOSE_SCRIPT = True +VERBOSE_SCRIPT = str2bool(os.getenv("VERBOSE", "1")) RUN_BUILD_DEPS = True # see if the user passed a quiet flag to setup.py arguments and respect # that in our parts of the build EMIT_BUILD_WARNING = False -RERUN_CMAKE = False -CMAKE_ONLY = False +RERUN_CMAKE = str2bool(os.environ.pop("CMAKE_FRESH", None)) +CMAKE_ONLY = str2bool(os.environ.pop("CMAKE_ONLY", None)) filtered_args = [] for i, arg in enumerate(sys.argv): if arg == "--cmake": @@ -322,67 +407,72 @@ def _get_package_path(package_name): break if arg == "-q" or arg == "--quiet": VERBOSE_SCRIPT = False - if arg in ["clean", "egg_info", "sdist"]: + if arg in ["clean", "dist_info", "egg_info", "sdist"]: RUN_BUILD_DEPS = False filtered_args.append(arg) sys.argv = filtered_args if VERBOSE_SCRIPT: - def report(*args): - print(*args) + def report( + *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any + ) -> None: + print(*args, file=file, flush=flush, **kwargs) else: - def report(*args): + def report( + *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any + ) -> None: pass # Make distutils respect --quiet too - setuptools.distutils.log.warn = report + setuptools.distutils.log.warn = report # type: ignore[attr-defined] # Constant known variables used throughout this file -cwd = os.path.dirname(os.path.abspath(__file__)) -lib_path = os.path.join(cwd, "torch", "lib") -third_party_path = os.path.join(cwd, "third_party") +TORCH_DIR = CWD / "torch" +TORCH_LIB_DIR = TORCH_DIR / "lib" +THIRD_PARTY_DIR = CWD / "third_party" # CMAKE: full path to python library if IS_WINDOWS: - cmake_python_library = "{}/libs/python{}.lib".format( - sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION") + CMAKE_PYTHON_LIBRARY = ( + Path(sysconfig.get_config_var("prefix")) + / "libs" + / f"python{sysconfig.get_config_var('VERSION')}.lib" ) # Fix virtualenv builds - if not os.path.exists(cmake_python_library): - cmake_python_library = "{}/libs/python{}.lib".format( - sys.base_prefix, sysconfig.get_config_var("VERSION") + if not CMAKE_PYTHON_LIBRARY.exists(): + CMAKE_PYTHON_LIBRARY = ( + Path(sys.base_prefix) + / "libs" + / f"python{sysconfig.get_config_var('VERSION')}.lib" ) else: - cmake_python_library = "{}/{}".format( - sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME") - ) -cmake_python_include_dir = sysconfig.get_path("include") + CMAKE_PYTHON_LIBRARY = Path( + sysconfig.get_config_var("LIBDIR") + ) / sysconfig.get_config_var("INSTSONAME") ################################################################################ # Version, create_version_file, and package_name ################################################################################ -package_name = os.getenv("TORCH_PACKAGE_NAME", "torch") +TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch") LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python") if BUILD_LIBTORCH_WHL: - package_name = LIBTORCH_PKG_NAME - + TORCH_PACKAGE_NAME = LIBTORCH_PKG_NAME -package_type = os.getenv("PACKAGE_TYPE", "wheel") -version = get_torch_version() -report(f"Building wheel {package_name}-{version}") +TORCH_VERSION = get_torch_version() +report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}") cmake = CMake() -def get_submodule_folders(): - git_modules_path = os.path.join(cwd, ".gitmodules") +def get_submodule_folders() -> list[Path]: + git_modules_file = CWD / ".gitmodules" default_modules_path = [ - os.path.join(third_party_path, name) + THIRD_PARTY_DIR / name for name in [ "gloo", "cpuinfo", @@ -391,29 +481,29 @@ def get_submodule_folders(): "cutlass", ] ] - if not os.path.exists(git_modules_path): + if not git_modules_file.exists(): return default_modules_path - with open(git_modules_path) as f: + with git_modules_file.open(encoding="utf-8") as f: return [ - os.path.join(cwd, line.split("=", 1)[1].strip()) + CWD / line.partition("=")[-1].strip() for line in f if line.strip().startswith("path") ] -def check_submodules(): - def check_for_files(folder, files): - if not any(os.path.exists(os.path.join(folder, f)) for f in files): +def check_submodules() -> None: + def check_for_files(folder: Path, files: list[str]) -> None: + if not any((folder / f).exists() for f in files): report("Could not find any of {} in {}".format(", ".join(files), folder)) report("Did you run 'git submodule update --init --recursive'?") sys.exit(1) - def not_exists_or_empty(folder): - return not os.path.exists(folder) or ( - os.path.isdir(folder) and len(os.listdir(folder)) == 0 + def not_exists_or_empty(folder: Path) -> bool: + return not folder.exists() or ( + folder.is_dir() and next(folder.iterdir(), None) is None ) - if bool(os.getenv("USE_SYSTEM_LIBS", False)): + if str2bool(os.getenv("USE_SYSTEM_LIBS")): return folders = get_submodule_folders() # If none of the submodule folders exists, try to initialize them @@ -422,12 +512,12 @@ def not_exists_or_empty(folder): report(" --- Trying to initialize submodules") start = time.time() subprocess.check_call( - ["git", "submodule", "update", "--init", "--recursive"], cwd=cwd + ["git", "submodule", "update", "--init", "--recursive"], cwd=CWD ) end = time.time() report(f" --- Submodule initialization took {end - start:.2f} sec") except Exception: - report(" --- Submodule initalization failed") + report(" --- Submodule initialization failed") report("Please run:\n\tgit submodule update --init --recursive") sys.exit(1) for folder in folders: @@ -443,37 +533,49 @@ def not_exists_or_empty(folder): ], ) check_for_files( - os.path.join(third_party_path, "fbgemm", "external", "asmjit"), + THIRD_PARTY_DIR / "fbgemm" / "external" / "asmjit", ["CMakeLists.txt"], ) # Windows has very bad support for symbolic links. # Instead of using symlinks, we're going to copy files over -def mirror_files_into_torchgen(): +def mirror_files_into_torchgen() -> None: # (new_path, orig_path) # Directories are OK and are recursively mirrored. paths = [ ( - "torchgen/packaged/ATen/native/native_functions.yaml", - "aten/src/ATen/native/native_functions.yaml", + CWD / "torchgen/packaged/ATen/native/native_functions.yaml", + CWD / "aten/src/ATen/native/native_functions.yaml", + ), + ( + CWD / "torchgen/packaged/ATen/native/tags.yaml", + CWD / "aten/src/ATen/native/tags.yaml", + ), + ( + CWD / "torchgen/packaged/ATen/templates", + CWD / "aten/src/ATen/templates", + ), + ( + CWD / "torchgen/packaged/autograd", + CWD / "tools/autograd", + ), + ( + CWD / "torchgen/packaged/autograd/templates", + CWD / "tools/autograd/templates", ), - ("torchgen/packaged/ATen/native/tags.yaml", "aten/src/ATen/native/tags.yaml"), - ("torchgen/packaged/ATen/templates", "aten/src/ATen/templates"), - ("torchgen/packaged/autograd", "tools/autograd"), - ("torchgen/packaged/autograd/templates", "tools/autograd/templates"), ] for new_path, orig_path in paths: # Create the dirs involved in new_path if they don't exist - if not os.path.exists(new_path): - os.makedirs(os.path.dirname(new_path), exist_ok=True) + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) # Copy the files from the orig location to the new location - if os.path.isfile(orig_path): + if orig_path.is_file(): shutil.copyfile(orig_path, new_path) continue - if os.path.isdir(orig_path): - if os.path.exists(new_path): + if orig_path.is_dir(): + if new_path.exists(): # copytree fails if the tree exists already, so remove it. shutil.rmtree(new_path) shutil.copytree(orig_path, new_path) @@ -482,15 +584,14 @@ def mirror_files_into_torchgen(): # all the work we need to do _before_ setup runs -def build_deps(): - report("-- Building version " + version) +def build_deps() -> None: + report(f"-- Building version {TORCH_VERSION}") check_submodules() check_pydep("yaml", "pyyaml") - build_python = not BUILD_LIBTORCH_WHL build_pytorch( - version=version, - cmake_python_library=cmake_python_library, - build_python=build_python, + version=TORCH_VERSION, + cmake_python_library=CMAKE_PYTHON_LIBRARY.as_posix(), + build_python=not BUILD_LIBTORCH_WHL, rerun_cmake=RERUN_CMAKE, cmake_only=CMAKE_ONLY, cmake=cmake, @@ -507,22 +608,22 @@ def build_deps(): # Use copies instead of symbolic files. # Windows has very poor support for them. sym_files = [ - "tools/shared/_utils_internal.py", - "torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h", - "torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h", + CWD / "tools/shared/_utils_internal.py", + CWD / "torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h", + CWD / "torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h", ] orig_files = [ - "torch/_utils_internal.py", - "third_party/valgrind-headers/callgrind.h", - "third_party/valgrind-headers/valgrind.h", + CWD / "torch/_utils_internal.py", + CWD / "third_party/valgrind-headers/callgrind.h", + CWD / "third_party/valgrind-headers/valgrind.h", ] for sym_file, orig_file in zip(sym_files, orig_files): same = False - if os.path.exists(sym_file): + if sym_file.exists(): if filecmp.cmp(sym_file, orig_file): same = True else: - os.remove(sym_file) + sym_file.unlink() if not same: shutil.copyfile(orig_file, sym_file) @@ -537,7 +638,7 @@ def build_deps(): """.strip() -def check_pydep(importname, module): +def check_pydep(importname: str, module: str) -> None: try: importlib.import_module(importname) except ImportError as e: @@ -547,19 +648,22 @@ def check_pydep(importname, module): class build_ext(setuptools.command.build_ext.build_ext): - def _embed_libomp(self): + def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS - lib_dir = os.path.join(self.build_lib, "torch", "lib") - libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib") - if not os.path.exists(libtorch_cpu_path): + build_lib = Path(self.build_lib) + build_torch_lib_dir = build_lib / "torch" / "lib" + build_torch_include_dir = build_lib / "torch" / "include" + libtorch_cpu_path = build_torch_lib_dir / "libtorch_cpu.dylib" + if not libtorch_cpu_path.exists(): return # Parse libtorch_cpu load commands otool_cmds = ( - subprocess.check_output(["otool", "-l", libtorch_cpu_path]) + subprocess.check_output(["otool", "-l", str(libtorch_cpu_path)]) .decode("utf-8") .split("\n") ) - rpaths, libs = [], [] + rpaths: list[str] = [] + libs: list[str] = [] for idx, line in enumerate(otool_cmds): if line.strip() == "cmd LC_LOAD_DYLIB": lib_name = otool_cmds[idx + 2].strip() @@ -571,8 +675,9 @@ def _embed_libomp(self): assert rpath.startswith("path ") rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1]) - omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] - omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib" + omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] # type: ignore[assignment] + omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] # type: ignore[assignment] + omplib_name += ".dylib" omplib_rpath_path = os.path.join("@rpath", omplib_name) # This logic is fragile and checks only two cases: @@ -582,8 +687,9 @@ def _embed_libomp(self): return # Copy libomp/libiomp5 from rpath locations - target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name) + target_lib = build_torch_lib_dir / omplib_name libomp_relocated = False + install_name_tool_args: list[str] = [] for rpath in rpaths: source_lib = os.path.join(rpath, omplib_name) if not os.path.exists(source_lib): @@ -614,24 +720,31 @@ def _embed_libomp(self): ] libomp_relocated = True if libomp_relocated: - install_name_tool_args.insert(0, "install_name_tool") - install_name_tool_args.append(libtorch_cpu_path) + install_name_tool_args = [ + "install_name_tool", + *install_name_tool_args, + str(libtorch_cpu_path), + ] subprocess.check_call(install_name_tool_args) # Copy omp.h from OpenMP_C_FLAGS and copy it into include folder - omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"] + omp_cflags: str = get_cmake_cache_vars()["OpenMP_C_FLAGS"] # type: ignore[assignment] if not omp_cflags: return - for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]: - omp_h = os.path.join(include_dir, "omp.h") - if not os.path.exists(omp_h): + for include_dir in [ + Path(f.removeprefix("-I")) + for f in omp_cflags.split(" ") + if f.startswith("-I") + ]: + omp_h = include_dir / "omp.h" + if not omp_h.exists(): continue - target_omp_h = os.path.join(self.build_lib, "torch", "include", "omp.h") + target_omp_h = build_torch_include_dir / "omp.h" self.copy_file(omp_h, target_omp_h) break - def run(self): - # Report build options. This is run after the build completes so # `CMakeCache.txt` exists and we can get an - # accurate report on what is used and what is not. + def run(self) -> None: + # Report build options. This is run after the build completes so # `CMakeCache.txt` exists + # and we can get an accurate report on what is used and what is not. cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) if cmake_cache_vars["USE_NUMPY"]: report("-- Building with NumPy bindings") @@ -640,18 +753,17 @@ def run(self): if cmake_cache_vars["USE_CUDNN"]: report( "-- Detected cuDNN at " - + cmake_cache_vars["CUDNN_LIBRARY"] - + ", " - + cmake_cache_vars["CUDNN_INCLUDE_DIR"] + f"{cmake_cache_vars['CUDNN_LIBRARY']}, " + f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}" ) else: report("-- Not using cuDNN") if cmake_cache_vars["USE_CUDA"]: - report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"]) + report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}") else: report("-- Not using CUDA") if cmake_cache_vars["USE_XPU"]: - report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"]) + report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}") else: report("-- Not using XPU") if cmake_cache_vars["USE_MKLDNN"]: @@ -670,10 +782,9 @@ def run(self): report("-- Not using MKLDNN") if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]: report( - "-- Using system provided NCCL library at {}, {}".format( - cmake_cache_vars["NCCL_LIBRARIES"], - cmake_cache_vars["NCCL_INCLUDE_DIRS"], - ) + "-- Using system provided NCCL library at " + f"{cmake_cache_vars['NCCL_LIBRARIES']}, " + f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}" ) elif cmake_cache_vars["USE_NCCL"]: report("-- Building NCCL library") @@ -684,23 +795,18 @@ def run(self): report("-- Building without distributed package") else: report("-- Building with distributed package: ") - report( - " -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"]) - ) - report(" -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"])) - report(" -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"])) + report(f" -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}") + report(f" -- USE_GLOO={cmake_cache_vars['USE_GLOO']}") + report(f" -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}") else: report("-- Building without distributed package") if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]: report( - "-- Using static dispatch with backend {}".format( - cmake_cache_vars["STATIC_DISPATCH_BACKEND"] - ) + "-- Using static dispatch with " + f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}" ) if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]: report("-- Using lightweight dispatch") - if cmake_cache_vars["BUILD_EXECUTORCH"]: - report("-- Building Executorch") if cmake_cache_vars["USE_ITT"]: report("-- Using ITT") @@ -709,99 +815,90 @@ def run(self): # Do not use clang to compile extensions if `-fstack-clash-protection` is defined # in system CFLAGS - c_flags = str(os.getenv("CFLAGS", "")) + c_flags = os.getenv("CFLAGS", "") if ( IS_LINUX and "-fstack-clash-protection" in c_flags - and "clang" in os.environ.get("CC", "") + and "clang" in os.getenv("CC", "") ): os.environ["CC"] = str(os.environ["CC"]) - # It's an old-style class in Python 2.7... - setuptools.command.build_ext.build_ext.run(self) + super().run() if IS_DARWIN: self._embed_libomp() # Copy the essential export library to compile C++ extensions. if IS_WINDOWS: - build_temp = self.build_temp + build_temp = Path(self.build_temp) + build_lib = Path(self.build_lib) ext_filename = self.get_ext_filename("_C") lib_filename = ".".join(ext_filename.split(".")[:-1]) + ".lib" - export_lib = os.path.join( - build_temp, "torch", "csrc", lib_filename - ).replace("\\", "/") - - build_lib = self.build_lib - - target_lib = os.path.join(build_lib, "torch", "lib", "_C.lib").replace( - "\\", "/" - ) + export_lib = build_temp / "torch" / "csrc" / lib_filename + target_lib = build_lib / "torch" / "lib" / "_C.lib" # Create "torch/lib" directory if not exists. # (It is not created yet in "develop" mode.) - target_dir = os.path.dirname(target_lib) - if not os.path.exists(target_dir): - os.makedirs(target_dir) - + target_dir = target_lib.parent + target_dir.mkdir(parents=True, exist_ok=True) self.copy_file(export_lib, target_lib) # In ROCm on Windows case copy rocblas and hipblaslt files into # torch/lib/rocblas/library and torch/lib/hipblaslt/library - use_rocm = os.environ.get("USE_ROCM") - if use_rocm: - rocm_dir_path = os.environ.get("ROCM_DIR") - rocm_bin_path = os.path.join(rocm_dir_path, "bin") - - rocblas_dir = os.path.join(rocm_bin_path, "rocblas") - target_rocblas_dir = os.path.join(target_dir, "rocblas") - os.makedirs(target_rocblas_dir, exist_ok=True) - self.copy_tree(rocblas_dir, target_rocblas_dir) - - hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt") - target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt") - os.makedirs(target_hipblaslt_dir, exist_ok=True) - self.copy_tree(hipblaslt_dir, target_hipblaslt_dir) + if str2bool(os.getenv("USE_ROCM")): + rocm_dir_path = Path(os.environ["ROCM_DIR"]) + rocm_bin_path = rocm_dir_path / "bin" + rocblas_dir = rocm_bin_path / "rocblas" + target_rocblas_dir = target_dir / "rocblas" + target_rocblas_dir.mkdir(parents=True, exist_ok=True) + self.copy_tree(rocblas_dir, str(target_rocblas_dir)) + + hipblaslt_dir = rocm_bin_path / "hipblaslt" + target_hipblaslt_dir = target_dir / "hipblaslt" + target_hipblaslt_dir.mkdir(parents=True, exist_ok=True) + self.copy_tree(hipblaslt_dir, str(target_hipblaslt_dir)) else: report("The specified environment variable does not exist.") - def build_extensions(self): + def build_extensions(self) -> None: self.create_compile_commands() + build_lib = Path(self.build_lib).resolve() + # Copy functorch extension - for i, ext in enumerate(self.extensions): + for ext in self.extensions: if ext.name != "functorch._C": continue fullname = self.get_ext_fullname(ext.name) - filename = self.get_ext_filename(fullname) - fileext = os.path.splitext(filename)[1] - src = os.path.join(os.path.dirname(filename), "functorch" + fileext) - dst = os.path.join(os.path.realpath(self.build_lib), filename) - if os.path.exists(src): + filename = Path(self.get_ext_filename(fullname)) + src = filename.with_stem("functorch") + dst = build_lib / filename + if src.exists(): report(f"Copying {ext.name} from {src} to {dst}") - dst_dir = os.path.dirname(dst) - if not os.path.exists(dst_dir): - os.makedirs(dst_dir) + dst.parent.mkdir(parents=True, exist_ok=True) self.copy_file(src, dst) - setuptools.command.build_ext.build_ext.build_extensions(self) + super().build_extensions() - def get_outputs(self): - outputs = setuptools.command.build_ext.build_ext.get_outputs(self) + def get_outputs(self) -> list[str]: + outputs = super().get_outputs() outputs.append(os.path.join(self.build_lib, "caffe2")) report(f"setup.py::get_outputs returning {outputs}") return outputs - def create_compile_commands(self): - def load(filename): - with open(filename) as f: - return json.load(f) + def create_compile_commands(self) -> None: + def load(file: Path) -> list[dict[str, Any]]: + return json.loads(file.read_text(encoding="utf-8")) - ninja_files = glob.glob("build/*compile_commands.json") - cmake_files = glob.glob("torch/lib/build/*/compile_commands.json") - all_commands = [entry for f in ninja_files + cmake_files for entry in load(f)] + ninja_files = (CWD / BUILD_DIR).glob("*compile_commands.json") + cmake_files = (CWD / "torch" / "lib" / "build").glob("*/compile_commands.json") + all_commands = [ + entry + for f in itertools.chain(ninja_files, cmake_files) + for entry in load(f) + ] # cquery does not like c++ compiles that start with gcc. # It forgets to include the c++ header directories. @@ -813,12 +910,11 @@ def load(filename): new_contents = json.dumps(all_commands, indent=2) contents = "" - if os.path.exists("compile_commands.json"): - with open("compile_commands.json") as f: - contents = f.read() + compile_commands_json = CWD / "compile_commands.json" + if compile_commands_json.exists(): + contents = compile_commands_json.read_text(encoding="utf-8") if contents != new_contents: - with open("compile_commands.json", "w") as f: - f.write(new_contents) + compile_commands_json.write_text(new_contents, encoding="utf-8") class concat_license_files: @@ -830,123 +926,120 @@ class concat_license_files: licensing info. """ - def __init__(self, include_files=False): - self.f1 = "LICENSE" - self.f2 = "third_party/LICENSES_BUNDLED.txt" + def __init__(self, include_files: bool = False) -> None: + self.f1 = CWD / "LICENSE" + self.f2 = THIRD_PARTY_DIR / "LICENSES_BUNDLED.txt" self.include_files = include_files + self.bsd_text = "" - def __enter__(self): + def __enter__(self) -> None: """Concatenate files""" old_path = sys.path - sys.path.append(third_party_path) + sys.path.append(str(THIRD_PARTY_DIR)) try: - from build_bundled import create_bundled + from build_bundled import create_bundled # type: ignore[import-not-found] finally: sys.path = old_path - with open(self.f1) as f1: - self.bsd_text = f1.read() + self.bsd_text = self.f1.read_text(encoding="utf-8") - with open(self.f1, "a") as f1: + with self.f1.open(mode="a", encoding="utf-8") as f1: f1.write("\n\n") create_bundled( - os.path.relpath(third_party_path), f1, include_files=self.include_files + str(THIRD_PARTY_DIR.resolve()), + f1, + include_files=self.include_files, ) - def __exit__(self, exception_type, exception_value, traceback): + def __exit__(self, *exc_info: object) -> None: """Restore content of f1""" - with open(self.f1, "w") as f: - f.write(self.bsd_text) + self.f1.write_text(self.bsd_text, encoding="utf-8") try: - from wheel.bdist_wheel import bdist_wheel + from wheel.bdist_wheel import bdist_wheel # type: ignore[import-untyped] except ImportError: # This is useful when wheel is not installed and bdist_wheel is not # specified on the command line. If it _is_ specified, parsing the command # line will fail before wheel_concatenate is needed - wheel_concatenate = None + wheel_concatenate: type[Command] | None = None else: # Need to create the proper LICENSE.txt for the wheel - class wheel_concatenate(bdist_wheel): + class wheel_concatenate(bdist_wheel): # type: ignore[no-redef] """check submodules on sdist to prevent incomplete tarballs""" - def run(self): + def run(self) -> None: with concat_license_files(include_files=True): super().run() - def write_wheelfile(self, *args, **kwargs): + def write_wheelfile(self, *args: Any, **kwargs: Any) -> None: super().write_wheelfile(*args, **kwargs) if BUILD_LIBTORCH_WHL: + bdist_dir = Path(self.bdist_dir) # Remove extraneneous files in the libtorch wheel - for root, dirs, files in os.walk(self.bdist_dir): - for file in files: - if file.endswith((".a", ".so")) and os.path.isfile( - os.path.join(self.bdist_dir, file) - ): - os.remove(os.path.join(root, file)) - elif file.endswith(".py"): - os.remove(os.path.join(root, file)) + for file in itertools.chain( + bdist_dir.rglob("*.a"), + bdist_dir.rglob("*.so"), + ): + if (bdist_dir / file.name).is_file(): + file.unlink() + for file in bdist_dir.rglob("*.py"): + file.unlink() # need an __init__.py file otherwise we wouldn't have a package - open(os.path.join(self.bdist_dir, "torch", "__init__.py"), "w").close() - - -class install(setuptools.command.install.install): - def run(self): - super().run() + (bdist_dir / "torch" / "__init__.py").touch() -class clean(setuptools.Command): - user_options = [] +class clean(Command): + user_options: ClassVar[list[tuple[str, str | None, str]]] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): - import glob - import re - - with open(".gitignore") as f: - ignores = f.read() - pat = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?") - for wildcard in filter(None, ignores.split("\n")): - match = pat.match(wildcard) - if match: - if match.group(1): - # Marker is found and stop reading .gitignore. - break - # Ignore lines which begin with '#'. - else: - # Don't remove absolute paths from the system - wildcard = wildcard.lstrip("./") - - for filename in glob.glob(wildcard): - try: - os.remove(filename) - except OSError: - shutil.rmtree(filename, ignore_errors=True) + def run(self) -> None: + ignores = (CWD / ".gitignore").read_text(encoding="utf-8") + for wildcard in filter(None, ignores.splitlines()): + if wildcard.strip().startswith("#"): + if "BEGIN NOT-CLEAN-FILES" in wildcard: + # Marker is found and stop reading .gitignore. + break + # Ignore lines which begin with '#'. + else: + # Don't remove absolute paths from the system + wildcard = wildcard.lstrip("./") + for filename in glob.iglob(wildcard): + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) class sdist(setuptools.command.sdist.sdist): - def run(self): + def run(self) -> None: with concat_license_files(): super().run() -def get_cmake_cache_vars(): +def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]: try: return defaultdict(lambda: False, cmake.get_cmake_cache_variables()) except FileNotFoundError: - # CMakeCache.txt does not exist. Probably running "python setup.py clean" over a clean directory. + # CMakeCache.txt does not exist. + # Probably running "python setup.py clean" over a clean directory. return defaultdict(lambda: False) -def configure_extension_build(): +def configure_extension_build() -> tuple[ + list[Extension], # ext_modules + dict[str, type[Command]], # cmdclass + list[str], # packages + dict[str, list[str]], # entry_points + list[str], # extra_install_requires +]: r"""Configures extension build options according to system environment and user's choice. Returns: @@ -959,17 +1052,17 @@ def configure_extension_build(): # Configure compile flags ################################################################################ - library_dirs = [] - extra_install_requires = [] + library_dirs: list[str] = [str(TORCH_LIB_DIR)] + extra_install_requires: list[str] = [] if IS_WINDOWS: # /NODEFAULTLIB makes sure we only link to DLL runtime # and matches the flags set for protobuf and ONNX - extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"] + extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] # /MD links against DLL runtime # and matches the flags set for protobuf and ONNX # /EHsc is about standard C++ exception handling - extra_compile_args = ["/MD", "/FS", "/EHsc"] + extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"] else: extra_link_args = [] extra_compile_args = [ @@ -985,13 +1078,11 @@ def configure_extension_build(): "-fno-strict-aliasing", ] - library_dirs.append(lib_path) - - main_compile_args = [] - main_libraries = ["torch_python"] + main_compile_args: list[str] = [] + main_libraries: list[str] = ["torch_python"] - main_link_args = [] - main_sources = ["torch/csrc/stub.c"] + main_link_args: list[str] = [] + main_sources: list[str] = ["torch/csrc/stub.c"] if BUILD_LIBTORCH_WHL: main_libraries = ["torch"] @@ -999,30 +1090,28 @@ def configure_extension_build(): if build_type.is_debug(): if IS_WINDOWS: - extra_compile_args.append("/Z7") - extra_link_args.append("/DEBUG:FULL") + extra_compile_args += ["/Z7"] + extra_link_args += ["/DEBUG:FULL"] else: extra_compile_args += ["-O0", "-g"] extra_link_args += ["-O0", "-g"] if build_type.is_rel_with_deb_info(): if IS_WINDOWS: - extra_compile_args.append("/Z7") - extra_link_args.append("/DEBUG:FULL") + extra_compile_args += ["/Z7"] + extra_link_args += ["/DEBUG:FULL"] else: extra_compile_args += ["-g"] extra_link_args += ["-g"] # pypi cuda package that requires installation of cuda runtime, cudnn and cublas # should be included in all wheels uploaded to pypi - pytorch_extra_install_requirements = os.getenv( - "PYTORCH_EXTRA_INSTALL_REQUIREMENTS", "" - ) - if pytorch_extra_install_requirements: - report( - f"pytorch_extra_install_requirements: {pytorch_extra_install_requirements}" + pytorch_extra_install_requires = os.getenv("PYTORCH_EXTRA_INSTALL_REQUIREMENTS") + if pytorch_extra_install_requires: + report(f"pytorch_extra_install_requirements: {pytorch_extra_install_requires}") + extra_install_requires.extend( + map(str.strip, pytorch_extra_install_requires.split("|")) ) - extra_install_requires += pytorch_extra_install_requirements.split("|") # Cross-compile for M1 if IS_DARWIN: @@ -1045,7 +1134,7 @@ def configure_extension_build(): ] extra_link_args += ["-arch", macos_target_arch] - def make_relative_rpath_args(path): + def make_relative_rpath_args(path: str) -> list[str]: if IS_DARWIN: return ["-Wl,-rpath,@loader_path/" + path] elif IS_WINDOWS: @@ -1057,39 +1146,47 @@ def make_relative_rpath_args(path): # Declare extensions and package ################################################################################ - extensions = [] + ext_modules: list[Extension] = [] + # packages that we want to install into site-packages and include them in wheels + includes = ["torch", "torch.*", "torchgen", "torchgen.*"] + # exclude folders that they look like Python packages but are not wanted in wheels excludes = ["tools", "tools.*", "caffe2", "caffe2.*"] - if not cmake_cache_vars["BUILD_FUNCTORCH"]: + if cmake_cache_vars["BUILD_FUNCTORCH"]: + includes.extend(["functorch", "functorch.*"]) + else: excludes.extend(["functorch", "functorch.*"]) - packages = find_packages(exclude=excludes) + packages = find_packages(include=includes, exclude=excludes) C = Extension( "torch._C", libraries=main_libraries, sources=main_sources, language="c", - extra_compile_args=main_compile_args + extra_compile_args, + extra_compile_args=[ + *main_compile_args, + *extra_compile_args, + ], include_dirs=[], library_dirs=library_dirs, - extra_link_args=extra_link_args - + main_link_args - + make_relative_rpath_args("lib"), + extra_link_args=[ + *extra_link_args, + *main_link_args, + *make_relative_rpath_args("lib"), + ], ) - extensions.append(C) + ext_modules.append(C) # These extensions are built by cmake and copied manually in build_extensions() # inside the build_ext implementation if cmake_cache_vars["BUILD_FUNCTORCH"]: - extensions.append( - Extension(name="functorch._C", sources=[]), - ) + ext_modules.append(Extension(name="functorch._C", sources=[])) cmdclass = { - "bdist_wheel": wheel_concatenate, "build_ext": build_ext, "clean": clean, - "install": install, "sdist": sdist, } + if wheel_concatenate is not None: + cmdclass["bdist_wheel"] = wheel_concatenate entry_points = { "console_scripts": [ @@ -1105,7 +1202,7 @@ def make_relative_rpath_args(path): entry_points["console_scripts"].append( "torchfrtrace = tools.flight_recorder.fr_trace:main", ) - return extensions, cmdclass, packages, entry_points, extra_install_requires + return ext_modules, cmdclass, packages, entry_points, extra_install_requires # post run, warnings, printed at the end to make them more visible @@ -1117,11 +1214,11 @@ def make_relative_rpath_args(path): To develop locally: $ python setup.py develop To force cmake to re-generate native build files (off by default): - $ python setup.py develop --cmake + $ CMAKE_FRESH=1 python setup.py develop """ -def print_box(msg): +def print_box(msg: str) -> None: lines = msg.split("\n") size = max(len(l) + 1 for l in lines) print("-" * (size + 2)) @@ -1130,11 +1227,13 @@ def print_box(msg): print("-" * (size + 2)) -def main(): +def main() -> None: if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY: raise RuntimeError( - "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. Set one to 0 and rerun." + "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. " + "Set one to 0 and rerun." ) + install_requires = [ "filelock", "typing-extensions>=4.10.0", @@ -1144,23 +1243,10 @@ def main(): "jinja2", "fsspec", ] - if BUILD_PYTHON_ONLY: - install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}") + install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] - use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", "")) - if ( - use_prioritized_text == "" - and platform.system() == "Linux" - and platform.processor() == "aarch64" - ): - print_box( - """ - WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA. - To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 - """ - ) - if use_prioritized_text == "1" or use_prioritized_text == "True": + if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")): gen_linker_script( filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld" ) @@ -1172,6 +1258,13 @@ def main(): os.environ["CXXFLAGS"] = ( os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections" ) + elif platform.system() == "Linux" and platform.processor() == "aarch64": + print_box( + """ + WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA. + To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 + """ + ) # Parse the command line and check the arguments before we proceed with # building deps and setup. We need to set values so `--help` works. @@ -1180,8 +1273,8 @@ def main(): dist.script_args = sys.argv[1:] try: dist.parse_command_line() - except setuptools.distutils.errors.DistutilsArgError as e: - print(e) + except setuptools.errors.BaseError as e: + print(e, file=sys.stderr) sys.exit(1) mirror_files_into_torchgen() @@ -1189,7 +1282,7 @@ def main(): build_deps() ( - extensions, + ext_modules, cmdclass, packages, entry_points, @@ -1197,17 +1290,6 @@ def main(): ) = configure_extension_build() install_requires += extra_install_requires - extras_require = { - "optree": ["optree>=0.13.0"], - "opt-einsum": ["opt-einsum>=3.3"], - "pyyaml": ["pyyaml"], - } - - # Read in README.md for our long_description - with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - version_range_max = max(sys.version_info[1], 13) + 1 torch_package_data = [ "py.typed", "bin/*", @@ -1246,46 +1328,40 @@ def main(): "utils/model_dump/skeleton.html", "utils/model_dump/code.js", "utils/model_dump/*.mjs", + "_dynamo/graph_break_registry.json", ] if not BUILD_LIBTORCH_WHL: - torch_package_data.extend( - [ - "lib/libtorch_python.so", - "lib/libtorch_python.dylib", - "lib/libtorch_python.dll", - ] - ) + torch_package_data += [ + "lib/libtorch_python.so", + "lib/libtorch_python.dylib", + "lib/libtorch_python.dll", + ] if not BUILD_PYTHON_ONLY: - torch_package_data.extend( - [ - "lib/*.so*", - "lib/*.dylib*", - "lib/*.dll", - "lib/*.lib", - ] - ) - aotriton_image_path = os.path.join(lib_path, "aotriton.images") - aks2_files = [] - for root, dirs, files in os.walk(aotriton_image_path): - subpath = os.path.relpath(root, start=aotriton_image_path) - for fn in files: - aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += [ + "lib/*.so*", + "lib/*.dylib*", + "lib/*.dll", + "lib/*.lib", + ] + # XXX: Why not use wildcards ["lib/aotriton.images/*", "lib/aotriton.images/**/*"] here? + aotriton_image_path = TORCH_DIR / "lib" / "aotriton.images" + aks2_files = [ + file.relative_to(TORCH_DIR).as_posix() + for file in aotriton_image_path.rglob("*") + if file.is_file() + ] torch_package_data += aks2_files if get_cmake_cache_vars()["USE_TENSORPIPE"]: - torch_package_data.extend( - [ - "include/tensorpipe/*.h", - "include/tensorpipe/**/*.h", - ] - ) + torch_package_data += [ + "include/tensorpipe/*.h", + "include/tensorpipe/**/*.h", + ] if get_cmake_cache_vars()["USE_KINETO"]: - torch_package_data.extend( - [ - "include/kineto/*.h", - "include/kineto/**/*.h", - ] - ) + torch_package_data += [ + "include/kineto/*.h", + "include/kineto/**/*.h", + ] torchgen_package_data = [ "packaged/*", "packaged/**/*", @@ -1293,57 +1369,28 @@ def main(): package_data = { "torch": torch_package_data, } + exclude_package_data = {} if not BUILD_LIBTORCH_WHL: package_data["torchgen"] = torchgen_package_data + exclude_package_data["torchgen"] = ["*.py[co]"] else: # no extensions in BUILD_LIBTORCH_WHL mode - extensions = [] + ext_modules = [] setup( - name=package_name, - version=version, - description=( - "Tensors and Dynamic neural networks in Python with strong GPU acceleration" - ), - long_description=long_description, - long_description_content_type="text/markdown", - ext_modules=extensions, + name=TORCH_PACKAGE_NAME, + version=TORCH_VERSION, + ext_modules=ext_modules, cmdclass=cmdclass, packages=packages, entry_points=entry_points, install_requires=install_requires, - extras_require=extras_require, package_data=package_data, - # TODO fix later Manifest.IN file was previously ignored - include_package_data=False, # defaults to True with pyproject.toml file - url="https://pytorch.org/", - download_url="https://github.com/pytorch/pytorch/tags", - author="PyTorch Team", - author_email="packages@pytorch.org", - python_requires=f">={python_min_version_str}", - # PyPI package information. - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: BSD License", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Programming Language :: C++", - "Programming Language :: Python :: 3", - ] - + [ - f"Programming Language :: Python :: 3.{i}" - for i in range(python_min_version[1], version_range_max) - ], - license="BSD-3-Clause", - keywords="pytorch, machine learning", + exclude_package_data=exclude_package_data, + # Disable automatic inclusion of data files because we want to + # explicitly control with `package_data` above. + include_package_data=False, ) if EMIT_BUILD_WARNING: print_box(build_update_message) diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 56a88596f99478..9c2a10d3355a1b 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -1,7 +1,6 @@ # Owner(s): ["module: unknown"] import copy -import logging import torch import torch.nn as nn @@ -10,11 +9,10 @@ ActivationSparsifier, ) from torch.ao.pruning.sparsifier.utils import module_to_fqn -from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, + TestCase, ) @@ -405,3 +403,7 @@ def _vanilla_norm_sparsifier(data, sparsity_level): # check state_dict() after squash_mask() self._check_state_dict(activation_sparsifier) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py index 8b4586f9979cc2..b3aaf1c6dfbea5 100644 --- a/test/ao/sparsity/test_composability.py +++ b/test/ao/sparsity/test_composability.py @@ -1,8 +1,6 @@ # Owner(s): ["module: unknown"] -import logging - import torch import torch.ao.quantization as tq from torch import nn @@ -15,13 +13,13 @@ prepare_qat_fx, ) from torch.testing._internal.common_quantization import skipIfNoFBGEMM -from torch.testing._internal.common_utils import TestCase, xfailIfS390X - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + TestCase, + xfailIfS390X, ) + sparse_defaults = { "sparsity_level": 0.8, "sparse_block_shape": (1, 4), @@ -642,3 +640,7 @@ def test_s_prep_q_prep_fx_ref(self): sparsity_level, sparse_config[0]["sparsity_level"] ) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_data_scheduler.py b/test/ao/sparsity/test_data_scheduler.py index 6481867292e479..cc4d8ddae63f14 100644 --- a/test/ao/sparsity/test_data_scheduler.py +++ b/test/ao/sparsity/test_data_scheduler.py @@ -1,19 +1,13 @@ # Owner(s): ["module: unknown"] import copy -import logging import warnings import torch from torch import nn from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier -from torch.testing._internal.common_utils import TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class ImplementedDataScheduler(BaseDataScheduler): @@ -180,3 +174,7 @@ def test_state_dict(self): name, _, _ = self._get_name_data_config(some_data, defaults) assert scheduler1.base_param[name] == scheduler2.base_param[name] assert scheduler1._last_param[name] == scheduler2._last_param[name] + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 4f987b994ae86b..99943821574352 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -2,7 +2,6 @@ import copy import itertools -import logging import math import torch @@ -15,12 +14,7 @@ post_training_sparse_quantize, ) from torch.nn.utils.parametrize import is_parametrized -from torch.testing._internal.common_utils import TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class ImplementedSparsifier(BaseDataSparsifier): @@ -500,9 +494,7 @@ def test_nn_embeddings(self): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) @@ -633,9 +625,7 @@ def test_nn_embeddings(self): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) @@ -792,3 +782,7 @@ def test_ptq_quantize_first(self): assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index 132fa66edf0393..1ffdca5fd343a1 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -19,20 +19,16 @@ qengine_is_qnnpack, qengine_is_x86, ) -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, + TestCase, +) # TODO: Once more test files are created, move the contents to a ao folder. logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -handler = logging.StreamHandler() -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -handler.setFormatter(formatter) - -logger.addHandler(handler) -logger.propagate = False # Prevent duplicate logs if root logger also has handlers class TestQuantizedSparseKernels(TestCase): @@ -222,12 +218,12 @@ def _sparse_layer_test_helper( qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) # check that the modules were converted as expected - assert isinstance( - sqmodule_to_check, sqmodule_expected_converted_class - ), "Convert failed" - assert isinstance( - qmodule_to_check, qmodule_expected_converted_class - ), "Mapping failed" + assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), ( + "Convert failed" + ) + assert isinstance(qmodule_to_check, qmodule_expected_converted_class), ( + "Mapping failed" + ) row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[ 2: @@ -331,4 +327,4 @@ def test_sparse_qlinear_serdes(self): if __name__ == "__main__": - run_tests() + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_parametrization.py b/test/ao/sparsity/test_parametrization.py index 820c133d894013..ac79b6309cf996 100644 --- a/test/ao/sparsity/test_parametrization.py +++ b/test/ao/sparsity/test_parametrization.py @@ -1,18 +1,11 @@ # Owner(s): ["module: unknown"] -import logging - import torch from torch import nn from torch.ao.pruning.sparsifier import utils from torch.nn.utils import parametrize -from torch.testing._internal.common_utils import TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class ModelUnderTest(nn.Module): @@ -173,3 +166,7 @@ def test_jit_trace(self): y = model(x) y_hat = model_trace(x) self.assertEqual(y_hat, y) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_scheduler.py b/test/ao/sparsity/test_scheduler.py index 7ac2f42193ca64..38e8fca4cdd845 100644 --- a/test/ao/sparsity/test_scheduler.py +++ b/test/ao/sparsity/test_scheduler.py @@ -4,7 +4,7 @@ from torch import nn from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class ImplementedScheduler(BaseScheduler): @@ -190,3 +190,7 @@ def test_step(self): self.sorted_sparse_levels, msg="Sparsity level is not reaching the target level afer delta_t * n steps ", ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index 097d4890dc8f21..ca80fa7dde7fe1 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -1,7 +1,6 @@ # Owner(s): ["module: unknown"] import itertools -import logging import re import torch @@ -18,12 +17,7 @@ MockSparseLinear, SimpleLinear, ) -from torch.testing._internal.common_utils import TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class TestBaseSparsifier(TestCase): @@ -484,3 +478,7 @@ def _verify_nearliness(self, mask: torch.Tensor, nearliness: int): assert mask[row, col] == 1 else: assert mask[row, col] == 0 + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_sparsity_utils.py b/test/ao/sparsity/test_sparsity_utils.py index b29be49d571d78..45385bca6f6dd0 100644 --- a/test/ao/sparsity/test_sparsity_utils.py +++ b/test/ao/sparsity/test_sparsity_utils.py @@ -18,7 +18,7 @@ SingleLayerLinearModel, TwoLayerLinearModel, ) -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase logging.basicConfig( @@ -147,3 +147,7 @@ def test_get_arg_info_from_tensor_fqn_fail(self): self.assertEqual(arg_info["module_fqn"], "foo.bar") self.assertEqual(arg_info["tensor_name"], "baz") self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py index 00fdbed68afacf..c62cc3d3053946 100644 --- a/test/ao/sparsity/test_structured_sparsifier.py +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -1,6 +1,5 @@ # Owner(s): ["module: unknown"] import copy -import logging import random import torch @@ -29,13 +28,13 @@ SimpleConv2d, SimpleLinear, ) -from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase - - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, + TestCase, ) + DEVICES = { torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), @@ -1056,9 +1055,9 @@ def _test_update_mask_on_multiple_layer( mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1] mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2] # Check if either of the least-norm filters is not pruned - assert ( - mask1.item() is not False or mask2.item() is not False - ), "Do not prune all least-norm filters" + assert mask1.item() is not False or mask2.item() is not False, ( + "Do not prune all least-norm filters" + ) # fusion step pruned_model = pruner.prune() @@ -1089,3 +1088,7 @@ def test_update_mask(self): self._test_update_mask_on_multiple_layer( expected_conv1, expected_conv2, device ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_ao_sparsity.py") diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index 319c8eb9ef4023..a7d647c120f42a 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -71,6 +71,15 @@ def bench_binary( return rc +def check_eager_vs_compile(rc_c, rc_e, func, dtype): + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile reduction do not match for {func.__name__} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) + + def bench_reduction( reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32 ) -> list[Measurement]: @@ -87,17 +96,63 @@ def f(t): x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) rc_c, rc_e = f(x), f_c(x) rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e) - if not torch.allclose(rc_c, rc_e): - mdiff = (rc_c - rc_e).abs().max() - warnings.warn( - f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}", - stacklevel=2, - ) + check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype) rc.append(bench_unary_op(f, x, f"eager-{size}x{size}")) rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}")) return rc +def bench_scan( + scan_func, + device: str = "mps", + dtype: torch.dtype = torch.float32, + with_indices: bool = False, +) -> list[Measurement]: + rc = [] + + # Bench cumsum along different dimensions + for dim in [0, 1]: + + def f(t): + return scan_func(t, dim=dim) + + f_c = torch.compile(f, dynamic=False) + + for size in (32, 128, 512, 1024): + f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}" + f_c.__name__ = f.__name__ + x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) + rc_c, rc_e = f(x), f_c(x) + if with_indices: + check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype) + check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype) + else: + check_eager_vs_compile(rc_c, rc_e, scan_func, dtype) + rc.append(bench_unary_op(f, x, "eager")) + rc.append(bench_unary_op(f_c, x, "compile")) + + # Bench 1D cumsum for different sizes + def f_1d(t): + return scan_func(t, dim=0) + + f_1d_c = torch.compile(f_1d, dynamic=False) + + for size in (100, 10000, 1000000): + f_1d.__name__ = f"{scan_func.__name__}-1d-{size}" + f_1d_c.__name__ = f_1d.__name__ + x = torch.testing.make_tensor(size, device=device, dtype=dtype) + rc_c, rc_e = f_1d(x), f_1d_c(x) + if with_indices: + check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype) + check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype) + else: + check_eager_vs_compile(rc_c, rc_e, scan_func, dtype) + rc.append(bench_unary_op(f_1d, x, "eager")) + rc.append(bench_unary_op(f_1d_c, x, "compile")) + + return rc + + def main() -> None: dtypes = [torch.float16, torch.float32] if torch.backends.mps.is_macos_or_newer(14, 0): @@ -115,6 +170,18 @@ def main() -> None: rc.extend(bench_reduction(op)) Compare(rc).print() + # Profile scan ops (cumsum) + rc = [] + for dtype in dtypes: + rc.extend(bench_scan(torch.cumsum, dtype=dtype)) + Compare(rc).print() + + # Profile scan with indices ops (cummin) + rc = [] + for dtype in dtypes: + rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True)) + Compare(rc).print() + # Profile binary ops rc = [] ops = [torch.fmax, torch.add] diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index 969a1584b68c2d..1d8d8e7e35948e 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -66,9 +66,10 @@ def to_entry(fn_counts): json.dump(artifacts, f, indent=4) -def load_callgrind_artifacts() -> ( - tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats] -): +def load_callgrind_artifacts() -> tuple[ + benchmark_utils.CallgrindStats, + benchmark_utils.CallgrindStats, +]: """Hermetic artifact to unit test Callgrind wrapper. In addition to collecting counts, this wrapper provides some facilities for diff --git a/test/delete.py b/test/compiled_autograd_skips/FakeTensorDispatchCache.test__upsample_bilinear2d_aa_backward_dynamic_shapes similarity index 100% rename from test/delete.py rename to test/compiled_autograd_skips/FakeTensorDispatchCache.test__upsample_bilinear2d_aa_backward_dynamic_shapes diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_truediv_zero_division b/test/compiled_autograd_skips/MiniOpTest.test_aot_dispatch_dynamic__test_mm similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_truediv_zero_division rename to test/compiled_autograd_skips/MiniOpTest.test_aot_dispatch_dynamic__test_mm diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_independence b/test/compiled_autograd_skips/MiniOpTest.test_aot_dispatch_static__test_mm similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_independence rename to test/compiled_autograd_skips/MiniOpTest.test_aot_dispatch_static__test_mm diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_for b/test/compiled_autograd_skips/PackedSequenceTest.test_pack_padded_sequence similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_for rename to test/compiled_autograd_skips/PackedSequenceTest.test_pack_padded_sequence diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_iter b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_backward_out_of_context similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_iter rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_backward_out_of_context diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_enumerate b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_basic similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_enumerate rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_basic diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_yield b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_double_backward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_yield rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_double_backward diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_inplace_foreach similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_inplace_foreach diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_save_base_and_modify_view similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_save_base_and_modify_view diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_save_view_modify_base similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_save_view_modify_base diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_saved_but_not_anymore similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_saved_but_not_anymore diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_views similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_views diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_with_math_views similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_with_math_views diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union b/test/compiled_autograd_skips/TestAllowMutationOnSaved.test_with_out_variant similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union rename to test/compiled_autograd_skips/TestAllowMutationOnSaved.test_with_out_variant diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality b/test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_bias_as_arg similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality rename to test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_bias_as_arg diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty b/test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_bias_as_module_attr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty rename to test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_bias_as_module_attr diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length b/test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_constructed_bias similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length rename to test/compiled_autograd_skips/TestAutodiffSubgraphSlicing.test_constructed_bias diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference b/test/compiled_autograd_skips/TestAutograd.test_accumulate_grad_tensor_reference similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference rename to test/compiled_autograd_skips/TestAutograd.test_accumulate_grad_tensor_reference diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality b/test/compiled_autograd_skips/TestAutograd.test_anomaly_assign_parent_cleanup similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality rename to test/compiled_autograd_skips/TestAutograd.test_anomaly_assign_parent_cleanup diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection b/test/compiled_autograd_skips/TestAutograd.test_anomaly_detect_nan similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection rename to test/compiled_autograd_skips/TestAutograd.test_anomaly_detect_nan diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint b/test/compiled_autograd_skips/TestAutograd.test_anomaly_grad_warnings similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint rename to test/compiled_autograd_skips/TestAutograd.test_anomaly_grad_warnings diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference b/test/compiled_autograd_skips/TestAutograd.test_autograd_inplace_views_cross_dtype similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference rename to test/compiled_autograd_skips/TestAutograd.test_autograd_inplace_views_cross_dtype diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union b/test/compiled_autograd_skips/TestAutograd.test_autograd_node_isinstance similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union rename to test/compiled_autograd_skips/TestAutograd.test_autograd_node_isinstance diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty b/test/compiled_autograd_skips/TestAutograd.test_callback_adds_callback similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty rename to test/compiled_autograd_skips/TestAutograd.test_callback_adds_callback diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and b/test/compiled_autograd_skips/TestAutograd.test_callback_propagates_errors_from_device_thread similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and rename to test/compiled_autograd_skips/TestAutograd.test_callback_propagates_errors_from_device_thread diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_contains b/test/compiled_autograd_skips/TestAutograd.test_create_graph_and_full_backward_hook_cycle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_contains rename to test/compiled_autograd_skips/TestAutograd.test_create_graph_and_full_backward_hook_cycle diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_difference b/test/compiled_autograd_skips/TestAutograd.test_current_graph_task_execution_order similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_difference rename to test/compiled_autograd_skips/TestAutograd.test_current_graph_task_execution_order diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality b/test/compiled_autograd_skips/TestAutograd.test_current_node similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality rename to test/compiled_autograd_skips/TestAutograd.test_current_node diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_or b/test/compiled_autograd_skips/TestAutograd.test_custom_autograd_no_early_free similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_or rename to test/compiled_autograd_skips/TestAutograd.test_custom_autograd_no_early_free diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_setOfFrozensets b/test/compiled_autograd_skips/TestAutograd.test_custom_function_cycle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_setOfFrozensets rename to test/compiled_autograd_skips/TestAutograd.test_custom_function_cycle diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub b/test/compiled_autograd_skips/TestAutograd.test_custom_function_error similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub rename to test/compiled_autograd_skips/TestAutograd.test_custom_function_error diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_symmetric_difference b/test/compiled_autograd_skips/TestAutograd.test_custom_function_forward_mode_forward_is_no_op similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_symmetric_difference rename to test/compiled_autograd_skips/TestAutograd.test_custom_function_forward_mode_forward_is_no_op diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_union b/test/compiled_autograd_skips/TestAutograd.test_custom_function_forward_mode_wrong_formula similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_union rename to test/compiled_autograd_skips/TestAutograd.test_custom_function_forward_mode_wrong_formula diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_uniquification b/test/compiled_autograd_skips/TestAutograd.test_custom_function_save_for_forward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_uniquification rename to test/compiled_autograd_skips/TestAutograd.test_custom_function_save_for_forward diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor b/test/compiled_autograd_skips/TestAutograd.test_default_saved_tensors_hooks_double_backward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor rename to test/compiled_autograd_skips/TestAutograd.test_default_saved_tensors_hooks_double_backward diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_and b/test/compiled_autograd_skips/TestAutograd.test_grad_batched_grad similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_and rename to test/compiled_autograd_skips/TestAutograd.test_grad_batched_grad diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_constructor_identity b/test/compiled_autograd_skips/TestAutograd.test_grad_nonleaf_register_hook similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_constructor_identity rename to test/compiled_autograd_skips/TestAutograd.test_grad_nonleaf_register_hook diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_difference b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_backward_mul_by_grad_output similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_difference rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_backward_mul_by_grad_output diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_equality b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_batched_grad similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_equality rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_batched_grad diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_init b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_forward_or_backward_only similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_init rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_forward_or_backward_only diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_isdisjoint b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_no_differentiable_outputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_isdisjoint rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_check_no_differentiable_outputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_len b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_complex_non_complex_outputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_len rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_complex_non_complex_outputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_nested_empty_constructor b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_dense_and_sparse_inputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_nested_empty_constructor rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_dense_and_sparse_inputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_or b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_forward_ad similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_or rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_forward_ad diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_setOfFrozensets b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_forward_ad_batched_grad similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_setOfFrozensets rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_forward_ad_batched_grad diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_multiple_mkldnn_inputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_multiple_mkldnn_inputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub_and_super b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_undefined_grad similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub_and_super rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_undefined_grad diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_symmetric_difference b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_validates_input_mkldnn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_symmetric_difference rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_validates_input_mkldnn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_union b/test/compiled_autograd_skips/TestAutograd.test_gradcheck_validates_inputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_union rename to test/compiled_autograd_skips/TestAutograd.test_gradcheck_validates_inputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_uniquification b/test/compiled_autograd_skips/TestAutograd.test_graph_save_on_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_uniquification rename to test/compiled_autograd_skips/TestAutograd.test_graph_save_on_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_xor b/test/compiled_autograd_skips/TestAutograd.test_graph_save_on_cpu_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_xor rename to test/compiled_autograd_skips/TestAutograd.test_graph_save_on_cpu_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cube b/test/compiled_autograd_skips/TestAutograd.test_hooks similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cube rename to test/compiled_autograd_skips/TestAutograd.test_hooks diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_union b/test/compiled_autograd_skips/TestAutograd.test_input_buffer_accum similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_union rename to test/compiled_autograd_skips/TestAutograd.test_input_buffer_accum diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_mark_non_differentiable_none similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_mark_non_differentiable_none diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_difference_update_operator b/test/compiled_autograd_skips/TestAutograd.test_naughty_autograd_function_stashing_ctx similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_difference_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_naughty_autograd_function_stashing_ctx diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_intersection_update_operator b/test/compiled_autograd_skips/TestAutograd.test_no_grad_copy similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_intersection_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_no_grad_copy diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_sym_difference_update_operator b/test/compiled_autograd_skips/TestAutograd.test_no_grad_copy_sparse similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_sym_difference_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_no_grad_copy_sparse diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_node_post_hook_registered_during_unpack_hook similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_node_post_hook_registered_during_unpack_hook diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_union b/test/compiled_autograd_skips/TestAutograd.test_post_accumulate_grad_hook_ordering similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_union rename to test/compiled_autograd_skips/TestAutograd.test_post_accumulate_grad_hook_ordering diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_profiler_seq_nr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_profiler_seq_nr diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection b/test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_both_depths similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection rename to test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_both_depths diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection_update b/test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_depth_0 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection_update rename to test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_depth_0 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_union b/test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_depth_1 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_union rename to test/compiled_autograd_skips/TestAutograd.test_reentrant_with_callbacks_depth_1 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_reentrant_with_non_leaf_variable_hook similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_reentrant_with_non_leaf_variable_hook diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection b/test/compiled_autograd_skips/TestAutograd.test_save_output_nr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection rename to test/compiled_autograd_skips/TestAutograd.test_save_output_nr diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection_update b/test/compiled_autograd_skips/TestAutograd.test_saved_variable_packing_unpacking_saved_original_with_hooks similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection_update rename to test/compiled_autograd_skips/TestAutograd.test_saved_variable_packing_unpacking_saved_original_with_hooks diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_union b/test/compiled_autograd_skips/TestAutograd.test_setup_context_when_forward_has_default_args similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_union rename to test/compiled_autograd_skips/TestAutograd.test_setup_context_when_forward_has_default_args diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_both_scalar similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_both_scalar diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_union b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim0 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_union rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim0 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_update_operator b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim1 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_update_operator rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim1 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim_neg similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_dim_neg diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference_update b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_ind_scalar similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference_update rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_ind_scalar diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_intersection_update b/test/compiled_autograd_skips/TestAutograd.test_sparse_gather_x_scalar similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_intersection_update rename to test/compiled_autograd_skips/TestAutograd.test_sparse_gather_x_scalar diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_or b/test/compiled_autograd_skips/TestAutograd.test_sparse_mm_backward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_or rename to test/compiled_autograd_skips/TestAutograd.test_sparse_mm_backward diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_remove_keyerror_unpacking b/test/compiled_autograd_skips/TestAutograd.test_to_sparse_backward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_remove_keyerror_unpacking rename to test/compiled_autograd_skips/TestAutograd.test_to_sparse_backward diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_setOfFrozensets b/test/compiled_autograd_skips/TestAutograd.test_unrelated_inputs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_setOfFrozensets rename to test/compiled_autograd_skips/TestAutograd.test_unrelated_inputs diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_set_literal_evaluation_order b/test/compiled_autograd_skips/TestAutogradComplex.test_view_func_for_complex_views similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_set_literal_evaluation_order rename to test/compiled_autograd_skips/TestAutogradComplex.test_view_func_for_complex_views diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_inplace_on_view_gradcheck_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_inplace_on_view_gradcheck_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference_update b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_backward_cpu_complex128 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference_update rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_backward_cpu_complex128 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_union b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_backward_cpu_float64 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_union rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_backward_cpu_float64 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_uniquification b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_uniquification rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_update b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSet.test_update rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetOfSets.test_constructor b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_mask_autograd_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetOfSets.test_constructor rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_sparse_mask_autograd_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_add b/test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_where_functional_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_add rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCPU.test_where_functional_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_and b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_inplace_on_view_python_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_and rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_inplace_on_view_python_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_clear b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_pin_memory_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_clear rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_pin_memory_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_constructor_identity b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_reentrant_parent_error_on_cpu_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_constructor_identity rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_reentrant_parent_error_on_cpu_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_backward_cuda_complex128 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_backward_cuda_complex128 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference_update b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_backward_cuda_float64 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference_update rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_backward_cuda_float64 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_equality b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_ctor_getter_backward_cuda_complex128 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_equality rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_ctor_getter_backward_cuda_complex128 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_iand b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_ctor_getter_backward_cuda_float64 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_iand rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_ctor_getter_backward_cuda_float64 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_init b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_mask_autograd_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_init rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_sparse_mask_autograd_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_inplace_on_self b/test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_where_functional_cuda similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_inplace_on_self rename to test/compiled_autograd_skips/TestAutogradDeviceTypeCUDA.test_where_functional_cuda diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection_update b/test/compiled_autograd_skips/TestAutogradFallback.test_autograd_function_registered_to_cpu_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection_update rename to test/compiled_autograd_skips/TestAutogradFallback.test_autograd_function_registered_to_cpu_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ior b/test/compiled_autograd_skips/TestAutogradFallback.test_base_does_not_require_grad_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ior rename to test/compiled_autograd_skips/TestAutogradFallback.test_base_does_not_require_grad_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isdisjoint b/test/compiled_autograd_skips/TestAutogradFallback.test_composite_registered_to_cpu_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isdisjoint rename to test/compiled_autograd_skips/TestAutogradFallback.test_composite_registered_to_cpu_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isub b/test/compiled_autograd_skips/TestAutogradFallback.test_cpu_return_self_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isub rename to test/compiled_autograd_skips/TestAutogradFallback.test_cpu_return_self_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ixor b/test/compiled_autograd_skips/TestAutogradFallback.test_inplace_autograd_function_registered_to_cpu_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ixor rename to test/compiled_autograd_skips/TestAutogradFallback.test_inplace_autograd_function_registered_to_cpu_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_len b/test/compiled_autograd_skips/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_len rename to test/compiled_autograd_skips/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_or b/test/compiled_autograd_skips/TestAutogradFallback.test_no_autograd_kernel_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_or rename to test/compiled_autograd_skips/TestAutogradFallback.test_no_autograd_kernel_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_pop b/test/compiled_autograd_skips/TestAutogradFallback.test_post_autograd_returns_leaf_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_pop rename to test/compiled_autograd_skips/TestAutogradFallback.test_post_autograd_returns_leaf_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_set b/test/compiled_autograd_skips/TestAutogradFallback.test_post_autograd_returns_mix_of_requires_grad_tensors_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_set rename to test/compiled_autograd_skips/TestAutogradFallback.test_post_autograd_returns_mix_of_requires_grad_tensors_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_unpacking b/test/compiled_autograd_skips/TestAutogradFallback.test_supports_tensor_lists_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_unpacking rename to test/compiled_autograd_skips/TestAutogradFallback.test_supports_tensor_lists_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_setOfFrozensets b/test/compiled_autograd_skips/TestAutogradFallback.test_undefined_grads_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_setOfFrozensets rename to test/compiled_autograd_skips/TestAutogradFallback.test_undefined_grads_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_set_literal_evaluation_order b/test/compiled_autograd_skips/TestAutogradFallback.test_undefined_inputs_outputs_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_set_literal_evaluation_order rename to test/compiled_autograd_skips/TestAutogradFallback.test_undefined_inputs_outputs_mode_warn diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub b/test/compiled_autograd_skips/TestAutogradForwardMode.test_advanced_packing_unpacking similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub rename to test/compiled_autograd_skips/TestAutogradForwardMode.test_advanced_packing_unpacking diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub_and_super b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_False_base_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub_and_super rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_False_base_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_False_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_False_logging_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference_update b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_True_base_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference_update rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_True_base_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_union b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_True_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_union rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_create_graph_vectorize_True_logging_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_uniquification b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_err_check_vectorize_True_base_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_uniquification rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_err_check_vectorize_True_base_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_update b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_err_check_vectorize_True_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_update rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_err_check_vectorize_True_logging_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_xor b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_output_vectorized_base_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_xor rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_output_vectorized_base_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_exception b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_output_vectorized_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_exception rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_output_vectorized_logging_tensor diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_none_in_tuples b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_multi_input_base_tensor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_none_in_tuples rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_multi_input_base_tensor diff --git a/test/dynamo_expected_failures/TestBaseSparsifier.test_state_dict b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_multi_input_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/TestBaseSparsifier.test_state_dict rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_multi_input_logging_tensor diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_complex128 b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_simple_base_tensor similarity index 100% rename from test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_complex128 rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_simple_base_tensor diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_complex64 b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_simple_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_complex64 rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_simple_logging_tensor diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_float32 b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_unrelated_outputs_base_tensor similarity index 100% rename from test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_float32 rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_unrelated_outputs_base_tensor diff --git a/test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_float64 b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_unrelated_outputs_logging_tensor similarity index 100% rename from test/dynamo_expected_failures/TestReductionsCPU.test_warn_invalid_degrees_of_freedom_cpu_float64 rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_correctness_unrelated_outputs_logging_tensor diff --git a/test/dynamo_expected_failures/TestVmapAPI.test_fallback_warns_when_warnings_are_enabled b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_raises_no_warnings_base_tensor similarity index 100% rename from test/dynamo_expected_failures/TestVmapAPI.test_fallback_warns_when_warnings_are_enabled rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_raises_no_warnings_base_tensor diff --git a/test/hi.py b/test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_raises_no_warnings_logging_tensor similarity index 100% rename from test/hi.py rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hessian_vectorize_raises_no_warnings_logging_tensor diff --git a/torchgen/executorch/__init__.py b/test/compiled_autograd_skips/TestAutogradFunctional.test_hvp_create_graph_base_tensor similarity index 100% rename from torchgen/executorch/__init__.py rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hvp_create_graph_base_tensor diff --git a/torchgen/executorch/api/__init__.py b/test/compiled_autograd_skips/TestAutogradFunctional.test_hvp_create_graph_logging_tensor similarity index 100% rename from torchgen/executorch/api/__init__.py rename to test/compiled_autograd_skips/TestAutogradFunctional.test_hvp_create_graph_logging_tensor diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_False_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_False_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_False_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_False_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_True_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_True_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_True_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_create_graph_vectorize_True_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_err_check_vectorize_True_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_err_check_vectorize_True_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_err_check_vectorize_True_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_err_check_vectorize_True_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_output_vectorized_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_output_vectorized_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_output_vectorized_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_output_vectorized_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_scalar_vectorized_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_scalar_vectorized_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_scalar_vectorized_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_scalar_vectorized_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_devices_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_devices_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_devices_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_devices_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_dtype_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_dtype_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_dtype_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_different_dtype_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_multi_output_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_multi_output_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_multi_output_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_multi_input_multi_output_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_simple_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_simple_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_simple_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_simple_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_unrelated_outputs_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_unrelated_outputs_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_unrelated_outputs_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_unrelated_outputs_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_zero_dim_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_zero_dim_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_zero_dim_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_correctness_zero_dim_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_raises_no_warnings_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_raises_no_warnings_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_raises_no_warnings_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jacobian_vectorize_raises_no_warnings_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jvp_create_graph_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jvp_create_graph_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_jvp_create_graph_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_jvp_create_graph_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_vhp_create_graph_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_vhp_create_graph_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_vhp_create_graph_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_vhp_create_graph_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_vjp_create_graph_base_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_vjp_create_graph_base_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradFunctional.test_vjp_create_graph_logging_tensor b/test/compiled_autograd_skips/TestAutogradFunctional.test_vjp_create_graph_logging_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_autograd_composite_implicit_and_dispatch_registration_cpu b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_autograd_composite_implicit_and_dispatch_registration_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_autograd_multiple_dispatch_registrations_cpu b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_autograd_multiple_dispatch_registrations_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_per_dispatch_key_input_saving_cpu b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCPU.test_per_dispatch_key_input_saving_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_autograd_composite_implicit_and_dispatch_registration_cuda b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_autograd_composite_implicit_and_dispatch_registration_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_autograd_multiple_dispatch_registrations_cuda b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_autograd_multiple_dispatch_registrations_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_backward_single_threaded_cuda b/test/compiled_autograd_skips/TestAutogradMultipleDispatchCUDA.test_backward_single_threaded_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestAutogradStreamSynchronization.test_side_stream_backward_overlap b/test/compiled_autograd_skips/TestAutogradStreamSynchronization.test_side_stream_backward_overlap new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_nn_unfold_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_nn_unfold_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_softmax_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_softmax_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_stack_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_stack_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_and_sparse_coo_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_and_sparse_coo_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_and_sparse_csr_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_and_sparse_csr_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_to_dense_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_to_sparse_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_to_sparse_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_unfold_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_unfold_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestBasicsCPU.test_where_cpu b/test/compiled_autograd_skips/TestBasicsCPU.test_where_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_aot_functionalized b/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_aot_functionalized new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_dict b/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_dict new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_list b/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_list new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_simple b/test/compiled_autograd_skips/TestControlFlowTraced.test_tracing_map_autograd_symbolic_simple new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_Conv2d_backward_depthwise_cpu_complex128 b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_Conv2d_backward_depthwise_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv3d_same_padding_backward_cpu_complex128 b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv3d_same_padding_backward_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv3d_valid_padding_backward_cpu_complex128 b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv3d_valid_padding_backward_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn1d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn2d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn3d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch1d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch2d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch3d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel1d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel2d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_batch_channel3d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel1d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel2d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_False_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_False_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_False_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_False_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_False_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_True_contiguous_False_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_True_contiguous_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_True_contiguous_True_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_backend_mkldnn_empty_channel3d_has_bias_True_strided_True_contiguous_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_double_backward_stride_cpu b/test/compiled_autograd_skips/TestConvolutionNNDeviceTypeCPU.test_conv_double_backward_stride_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCppApiParity.test_torch_nn_EmbeddingBag_sparse b/test/compiled_autograd_skips/TestCppApiParity.test_torch_nn_EmbeddingBag_sparse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCppApiParity.test_torch_nn_Embedding_sparse b/test/compiled_autograd_skips/TestCppApiParity.test_torch_nn_Embedding_sparse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_False_dynamic_False_cpu b/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_False_dynamic_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_False_dynamic_True_cpu b/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_False_dynamic_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_auto_dynamic_False_cpu b/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_auto_dynamic_False_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_auto_dynamic_True_cpu b/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_aot_autograd_check_degenerate_cases_check_gradients_auto_dynamic_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_global_state_mutation_cpu b/test/compiled_autograd_skips/TestCustomOpTestingCPU.test_global_state_mutation_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNN.test_embedding_sparse_basic b/test/compiled_autograd_skips/TestEmbeddingNN.test_embedding_sparse_basic new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/compiled_autograd_skips/TestEmbeddingNN.test_embedding_sparse_empty_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int32_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int32_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int32_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int32_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int64_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int64_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int64_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_EmbeddingBag_per_sample_weights_and_no_offsets_cpu_int64_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_backward_cpu_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_backward_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_2D_padding_idx_cpu_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_2D_padding_idx_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_2D_padding_idx_cpu_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_2D_padding_idx_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int32_int32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int32_int32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int32_int64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int32_int64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int64_int32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int64_int32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int64_int64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_bfloat16_cpu_int64_int64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_bfloat16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int32_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_bfloat16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int32_int64_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_bfloat16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int32_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_bfloat16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float16 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_device_cpu_int64_int64_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int32_int32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int32_int32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int32_int64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int32_int64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int64_int32 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int64_int32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int64_int64 b/test/compiled_autograd_skips/TestEmbeddingNNDeviceTypeCPU.test_embedding_bag_half_cpu_int64_int64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExamplesCorrectnessCPU.test_maml_omniglot_mechanism_functional_call_cpu b/test/compiled_autograd_skips/TestExamplesCorrectnessCPU.test_maml_omniglot_mechanism_functional_call_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExamplesCorrectnessCPU.test_maml_omniglot_mechanism_make_functional_cpu b/test/compiled_autograd_skips/TestExamplesCorrectnessCPU.test_maml_omniglot_mechanism_make_functional_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_cnn_model_mean_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_cnn_model_mean_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_cnn_model_sum_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_cnn_model_sum_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_embedding_model_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_embedding_model_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_1_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_2_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_3_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_group_norm_model_num_dim_3_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_1_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_2_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_3_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_instance_norm_model_num_dim_3_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_1_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_2_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_3_cpu b/test/compiled_autograd_skips/TestExpandedWeightFunctionalCPU.test_layer_norm_model_num_dim_3_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_circular_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_circular_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_circular_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_circular_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1size1_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1size1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1size1_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad1size1_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2size1_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2size1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2size1_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_pad2size1_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_reflect_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_reflect_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_reflect_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_reflect_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_replicate_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_replicate_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_replicate_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_replicate_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_stride_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_stride_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_stride_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_stride_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_zeros_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_zeros_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_zeros_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv1d_zeros_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_circular_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_circular_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_circular_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_circular_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_dilated_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_dilated_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_dilated_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_dilated_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_no_bias_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_no_bias_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_no_bias_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_no_bias_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_padding_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_padding_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_padding_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_padding_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_reflect_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_reflect_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_reflect_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_reflect_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_replicate_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_replicate_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_replicate_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_replicate_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_strided_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_strided_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_strided_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_strided_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_zeros_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_zeros_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_zeros_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv2d_zeros_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_1x1x1_no_bias_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_1x1x1_no_bias_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_1x1x1_no_bias_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_1x1x1_no_bias_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_circular_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_circular_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_circular_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_circular_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_no_bias_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_no_bias_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_no_bias_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_no_bias_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_replicate_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_replicate_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_replicate_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_replicate_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_padding_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_padding_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_padding_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_stride_padding_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_zeros_stride2_pad2_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_zeros_stride2_pad2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_zeros_stride2_pad2_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Conv3d_zeros_stride2_pad2_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_no_bias_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_no_bias_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_no_bias_multiple_inputs_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_Linear_no_bias_multiple_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_per_sample_api_failing_cpu b/test/compiled_autograd_skips/TestExpandedWeightModuleCPU.test_per_sample_api_failing_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestFakeQuantize.test_fq_module_per_channel b/test/compiled_autograd_skips/TestFakeQuantize.test_fq_module_per_channel new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenerateOpcheckTests.test_opcheck b/test/compiled_autograd_skips/TestGenerateOpcheckTests.test_opcheck new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenerateOpcheckTests.test_opcheck_customopdef b/test/compiled_autograd_skips/TestGenerateOpcheckTests.test_opcheck_customopdef new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_make_fx_model_fwd_bwd b/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_make_fx_model_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_make_fx_model_fwd_bwd_wgtupdate b/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_make_fx_model_fwd_bwd_wgtupdate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_proxy_tensor b/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_proxy_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_resnet18_backward_trace b/test/compiled_autograd_skips/TestGenericProxyTensorFake.test_resnet18_backward_trace new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_make_fx_model_fwd_bwd b/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_make_fx_model_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_make_fx_model_fwd_bwd_wgtupdate b/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_make_fx_model_fwd_bwd_wgtupdate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_proxy_tensor b/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_proxy_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_resnet18_backward_trace b/test/compiled_autograd_skips/TestGenericProxyTensorReal.test_resnet18_backward_trace new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_make_fx_model_fwd_bwd b/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_make_fx_model_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_make_fx_model_fwd_bwd_wgtupdate b/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_make_fx_model_fwd_bwd_wgtupdate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_proxy_tensor b/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_proxy_tensor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_resnet18_backward_trace b/test/compiled_autograd_skips/TestGenericProxyTensorSymbolic.test_resnet18_backward_trace new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJit.test_script_autograd_grad b/test/compiled_autograd_skips/TestJit.test_script_autograd_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJit.test_script_backward_twice b/test/compiled_autograd_skips/TestJit.test_script_backward_twice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv2d_no_bias b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv2d_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv3d_1x1x1_no_bias b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv3d_1x1x1_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv3d_no_bias b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Conv3d_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_no_bias b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_dilated b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_dilated new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_no_bias b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Softsign_no_batch_dim b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Softsign_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Tanhshrink_no_batch_dim b/test/compiled_autograd_skips/TestJitGeneratedModule.test_nn_Tanhshrink_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestLinalgCPU.test_invariance_error_spectral_decompositions_cpu_complex128 b/test/compiled_autograd_skips/TestLinalgCPU.test_invariance_error_spectral_decompositions_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_adaptive_avg_pool2d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_adaptive_avg_pool2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_autograd_from_mkldnn_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_autograd_from_mkldnn_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_autograd_to_mkldnn_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_autograd_to_mkldnn_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_avg_pool2d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_avg_pool2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_avg_pool3d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_avg_pool3d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_batch_norm_2d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_batch_norm_2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_conv2d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_conv2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_conv3d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_conv3d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_gelu_bf16_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_gelu_bf16_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_gelu_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_gelu_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_linear_backward_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_linear_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_linear_non_contiguous_weight_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_linear_non_contiguous_weight_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_max_pool2d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_max_pool2d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_max_pool3d_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_max_pool3d_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_prelu_bf16_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_prelu_bf16_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_prelu_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_prelu_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_relu__cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_relu__cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_relu_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_relu_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMkldnnCPU.test_reshape_backward_cpu b/test/compiled_autograd_skips/TestMkldnnCPU.test_reshape_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModels.test_dcgan_models b/test/compiled_autograd_skips/TestModels.test_dcgan_models new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModels.test_mnist b/test/compiled_autograd_skips/TestModels.test_mnist new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModels.test_reinforcement_learning b/test/compiled_autograd_skips/TestModels.test_reinforcement_learning new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModels.test_vae b/test/compiled_autograd_skips/TestModels.test_vae new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModuleGlobalHooks.test_module_backward_global_hook_writeable b/test/compiled_autograd_skips/TestModuleGlobalHooks.test_module_backward_global_hook_writeable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModuleGlobalHooks.test_module_global_hook_invalid_outputs b/test/compiled_autograd_skips/TestModuleGlobalHooks.test_module_global_hook_invalid_outputs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModuleHookNN.test_hook_invalid_outputs b/test/compiled_autograd_skips/TestModuleHookNN.test_hook_invalid_outputs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestModuleHookNN.test_hooks b/test/compiled_autograd_skips/TestModuleHookNN.test_hooks new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_fork_join_in_middle b/test/compiled_autograd_skips/TestMultithreadAutograd.test_fork_join_in_middle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_multi_grad_all_hooks b/test/compiled_autograd_skips/TestMultithreadAutograd.test_multi_grad_all_hooks new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_multi_grad_any_hooks b/test/compiled_autograd_skips/TestMultithreadAutograd.test_multi_grad_any_hooks new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_python_thread_in_middle b/test/compiled_autograd_skips/TestMultithreadAutograd.test_python_thread_in_middle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_set_multithreading_enabled_as_context_manager_and_function b/test/compiled_autograd_skips/TestMultithreadAutograd.test_set_multithreading_enabled_as_context_manager_and_function new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_simple_backward b/test/compiled_autograd_skips/TestMultithreadAutograd.test_simple_backward new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestMultithreadAutograd.test_simple_backward_same_input b/test/compiled_autograd_skips/TestMultithreadAutograd.test_simple_backward_same_input new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_AdaptiveLogSoftmax b/test/compiled_autograd_skips/TestNN.test_AdaptiveLogSoftmax new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv1d_circular_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv1d_circular_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv1d_reflect_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv1d_reflect_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv1d_replicate_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv1d_replicate_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv2d_circular_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv2d_circular_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv2d_reflect_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv2d_reflect_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv2d_replicate_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv2d_replicate_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv3d_circular_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv3d_circular_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Conv3d_replicate_stride2_pad2 b/test/compiled_autograd_skips/TestNN.test_Conv3d_replicate_stride2_pad2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_CosineEmbeddingLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_EmbeddingBag_sparse b/test/compiled_autograd_skips/TestNN.test_EmbeddingBag_sparse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Embedding_sparse b/test/compiled_autograd_skips/TestNN.test_Embedding_sparse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_margin_no_reduce b/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_margin_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_reduce b/test/compiled_autograd_skips/TestNN.test_HingeEmbeddingLoss_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Linear b/test/compiled_autograd_skips/TestNN.test_Linear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Linear_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_Linear_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Linear_no_bias b/test/compiled_autograd_skips/TestNN.test_Linear_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_MarginRankingLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Mish_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_Mish_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_reduce b/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_weights_no_reduce b/test/compiled_autograd_skips/TestNN.test_MultiLabelSoftMarginLoss_weights_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PairwiseDistance b/test/compiled_autograd_skips/TestNN.test_PairwiseDistance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_broadcast_lhs b/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_broadcast_lhs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_broadcast_rhs b/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_broadcast_rhs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_with_non_default_args b/test/compiled_autograd_skips/TestNN.test_PairwiseDistance_with_non_default_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_reduce b/test/compiled_autograd_skips/TestNN.test_PoissonNLLLoss_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_SiLU_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_SiLU_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Softsign_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_Softsign_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Tanhshrink_no_batch_dim b/test/compiled_autograd_skips/TestNN.test_Tanhshrink_no_batch_dim new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TransformerDecoderLayer_gelu_activation b/test/compiled_autograd_skips/TestNN.test_TransformerDecoderLayer_gelu_activation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TransformerDecoderLayer_relu_activation b/test/compiled_autograd_skips/TestNN.test_TransformerDecoderLayer_relu_activation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TransformerEncoderLayer_gelu_activation b/test/compiled_autograd_skips/TestNN.test_TransformerEncoderLayer_gelu_activation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TransformerEncoderLayer_relu_activation b/test/compiled_autograd_skips/TestNN.test_TransformerEncoderLayer_relu_activation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_Transformer_multilayer_coder b/test/compiled_autograd_skips/TestNN.test_Transformer_multilayer_coder new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_mean b/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_mean new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_none b/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_sum b/test/compiled_autograd_skips/TestNN.test_TripletMarginLoss_no_batch_dim_sum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_affine_grid b/test/compiled_autograd_skips/TestNN.test_affine_grid new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_affine_grid_3d b/test/compiled_autograd_skips/TestNN.test_affine_grid_3d new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_bilinear_no_bias b/test/compiled_autograd_skips/TestNN.test_bilinear_no_bias new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_cosine_embedding_loss_margin_no_reduce b/test/compiled_autograd_skips/TestNN.test_cosine_embedding_loss_margin_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_cosine_embedding_loss_no_reduce b/test/compiled_autograd_skips/TestNN.test_cosine_embedding_loss_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_elu_inplace_on_view b/test/compiled_autograd_skips/TestNN.test_elu_inplace_on_view new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_grid_sample b/test/compiled_autograd_skips/TestNN.test_grid_sample new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_interpolate b/test/compiled_autograd_skips/TestNN.test_interpolate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCOO new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCSC new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_bias_weightCSR new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCSC b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCSC new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCSR b/test/compiled_autograd_skips/TestNN.test_linear_autograd_device_cpu_nobias_weightCSR new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_normalize b/test/compiled_autograd_skips/TestNN.test_normalize new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_relu_inplace_on_view b/test/compiled_autograd_skips/TestNN.test_relu_inplace_on_view new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_swap_module_params_poisons_acc_grad b/test/compiled_autograd_skips/TestNN.test_swap_module_params_poisons_acc_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss b/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_no_reduce b/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_swap b/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_swap new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_swap_no_reduce b/test/compiled_autograd_skips/TestNN.test_triplet_margin_loss_swap_no_reduce new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_upsamplingLinear1d b/test/compiled_autograd_skips/TestNN.test_upsamplingLinear1d new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_upsampling_bfloat16 b/test/compiled_autograd_skips/TestNN.test_upsampling_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNN.test_upsampling_not_recompute_scale_factor b/test/compiled_autograd_skips/TestNN.test_upsampling_not_recompute_scale_factor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_GRU_grad_and_gradgrad_cpu_float64 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_GRU_grad_and_gradgrad_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_GroupNorm_general_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_GroupNorm_general_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_differentiable_backward_using_oneDNN_cpu_bfloat16 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_differentiable_backward_using_oneDNN_cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_differentiable_backward_using_oneDNN_cpu_float32 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_differentiable_backward_using_oneDNN_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_grad_and_gradgrad_cpu_float64 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_LSTM_grad_and_gradgrad_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_elu_inplace_with_neg_alpha_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_elu_inplace_with_neg_alpha_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_bfloat16 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float16 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float32 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float64 b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_groupnorm_nhwc_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_leaky_relu_inplace_with_neg_slope_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_leaky_relu_inplace_with_neg_slope_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_triplet_margin_with_distance_loss_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_triplet_margin_with_distance_loss_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_triplet_margin_with_distance_loss_default_parity_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_triplet_margin_with_distance_loss_default_parity_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bicubic_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bicubic_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bicubic_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bicubic_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bilinear_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bilinear_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bilinear_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_False_mode_bilinear_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bicubic_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bicubic_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bicubic_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bicubic_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bilinear_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bilinear_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bilinear_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_False_align_corners_True_mode_bilinear_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_False_mode_bicubic_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_False_mode_bicubic_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_False_mode_bicubic_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_False_mode_bicubic_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_True_mode_bicubic_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_True_mode_bicubic_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_True_mode_bicubic_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingBiMode2d_antialias_True_align_corners_True_mode_bicubic_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingNearest1d_mode_nearest-exact_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingNearest1d_mode_nearest-exact_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingNearest1d_mode_nearest_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingNearest1d_mode_nearest_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_False_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format0_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format0_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format1_cpu b/test/compiled_autograd_skips/TestNNDeviceTypeCPU.test_upsamplingTrilinear3d_align_corners_True_memory_format1_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNNParametrization.test_new_spectral_norm_swap_False b/test/compiled_autograd_skips/TestNNParametrization.test_new_spectral_norm_swap_False new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNamedTensor.test_autograd_ignores_names b/test/compiled_autograd_skips/TestNamedTensor.test_autograd_ignores_names new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNamedTensor.test_autograd_smoke b/test/compiled_autograd_skips/TestNamedTensor.test_autograd_smoke new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNamedTensor.test_autograd_warns_named_grad b/test/compiled_autograd_skips/TestNamedTensor.test_autograd_warns_named_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNamedTensor.test_tensor_grad_is_unnamed b/test/compiled_autograd_skips/TestNamedTensor.test_tensor_grad_is_unnamed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_early_stop_False b/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_early_stop_False new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_early_stop_True b/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_early_stop_True new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_set_early_stop_no_recompution_needed b/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_set_early_stop_no_recompution_needed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_False b/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_False new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_True b/test/compiled_autograd_skips/TestNestedCheckpoint.test_nested_checkpoint_two_children_early_stop_True new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_abs_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_abs_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_accumulate_grad_different_strides_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_accumulate_grad_different_strides_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_as_nested_tensor_propagates_gradients_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_as_nested_tensor_propagates_gradients_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_add_strided_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_add_strided_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_for_add_op_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_for_add_op_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_for_sub_op_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_for_sub_op_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_sub_strided_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_backward_sub_strided_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_dropout_backward_jagged_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_dropout_backward_jagged_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_dropout_backward_strided_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_dropout_backward_strided_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_gelu_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_gelu_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_indexing_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_indexing_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_128_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_128_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_2_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_32_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_32_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_4_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_5d_size_4_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_edge_case_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_edge_case_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_1023_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_1023_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_1024_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_1024_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_128_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_128_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_256_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_256_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_2_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_32_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_32_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_4_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_4_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_512_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_512_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_513_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_layer_norm_backward_size_513_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_masked_fill_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_masked_fill_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_bmm_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_bmm_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_bmm_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_bmm_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_list_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_list_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_mask_and_to_padded_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_mask_and_to_padded_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_padded_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_padded_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_padded_fused_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_from_padded_fused_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_generates_leaf_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_generates_leaf_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_plus_transpose_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_linear_plus_transpose_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_matmul_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_matmul_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_matmul_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_matmul_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_reshape_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_reshape_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_reshape_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_reshape_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_softmax_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_softmax_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_squeeze_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_squeeze_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_squeeze_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_squeeze_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_to_padded_tensor_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_to_padded_tensor_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_transpose_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_transpose_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_transpose_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_transpose_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_unsqueeze_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_unsqueeze_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_unsqueeze_gradcheck_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_nested_tensor_unsqueeze_gradcheck_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_relu_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_relu_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_selu_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_selu_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_split_with_sizes_flow_through_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_split_with_sizes_flow_through_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_to_buffer_series_ops_grad_with_broadcast_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_to_buffer_series_ops_grad_with_broadcast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_unbind_flow_through_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_unbind_flow_through_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_values_grad_with_broadcast_cpu b/test/compiled_autograd_skips/TestNestedTensorAutogradCPU.test_values_grad_with_broadcast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_abs_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_abs_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_accumulate_grad_different_strides_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_accumulate_grad_different_strides_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_as_nested_tensor_propagates_gradients_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_as_nested_tensor_propagates_gradients_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_add_strided_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_add_strided_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_for_add_op_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_for_add_op_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_for_sub_op_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_for_sub_op_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_sub_strided_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_backward_sub_strided_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_dropout_backward_jagged_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_dropout_backward_jagged_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_dropout_backward_strided_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_dropout_backward_strided_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_gelu_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_gelu_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_indexing_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_indexing_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_128_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_128_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_2_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_2_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_32_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_32_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_4_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_5d_size_4_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_edge_case_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_edge_case_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_1023_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_1023_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_1024_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_1024_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_128_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_128_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_256_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_256_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_2_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_2_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_32_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_32_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_4_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_4_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_512_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_512_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_513_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_layer_norm_backward_size_513_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_masked_fill_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_masked_fill_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_bmm_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_bmm_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_bmm_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_bmm_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_list_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_list_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_mask_and_to_padded_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_mask_and_to_padded_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_padded_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_padded_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_padded_fused_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_from_padded_fused_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_generates_leaf_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_generates_leaf_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_plus_transpose_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_linear_plus_transpose_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_matmul_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_matmul_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_matmul_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_matmul_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_reshape_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_reshape_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_reshape_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_reshape_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_softmax_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_softmax_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_squeeze_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_squeeze_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_squeeze_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_squeeze_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_to_padded_tensor_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_to_padded_tensor_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_transpose_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_transpose_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_transpose_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_transpose_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_unsqueeze_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_unsqueeze_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_unsqueeze_gradcheck_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_nested_tensor_unsqueeze_gradcheck_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_relu_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_relu_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_selu_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_selu_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_split_with_sizes_flow_through_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_split_with_sizes_flow_through_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_to_buffer_series_ops_grad_with_broadcast_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_to_buffer_series_ops_grad_with_broadcast_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_unbind_flow_through_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_unbind_flow_through_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_values_grad_with_broadcast_cuda b/test/compiled_autograd_skips/TestNestedTensorAutogradCUDA.test_values_grad_with_broadcast_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_detach_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_chunk_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_indexing_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_sum_dim_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCPU.test_nested_tensor_sum_dim_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_detach_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_embedding_jagged_cuda b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_embedding_jagged_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_chunk_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorDeviceTypeCUDA.test_nested_tensor_indexing_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorOpInfoCPU.test_nested_tensor_input_mutation_backward_cpu b/test/compiled_autograd_skips/TestNestedTensorOpInfoCPU.test_nested_tensor_input_mutation_backward_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorOpInfoCUDA.test_nested_tensor_input_mutation_backward_cuda b/test/compiled_autograd_skips/TestNestedTensorOpInfoCUDA.test_nested_tensor_input_mutation_backward_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_apply__cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_apply__cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_autograd_function_with_None_grad_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_autograd_function_with_None_grad_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_broadcasting_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_broadcasting_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_transposed_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_binary_pointwise_transposed_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_composite_op_in_inference_mode_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_composite_op_in_inference_mode_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_construction_from_list_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_construction_from_list_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_is_contiguous_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_is_contiguous_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_False_components_require_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_op_different_output_shape_dim_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_2d_input_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_operate_on_batch_dim_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_reduce_ragged_idx_1_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_reduce_ragged_idx_1_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_reduce_ragged_idx_1_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_layer_norm_reduce_ragged_idx_1_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_3_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_3_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_4_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_4_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_5_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_linear_nt_dim_5_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_njt_cat_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_njt_cat_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_batch_only_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_1_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_mean_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_1_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape_sum_transpose_offset_2_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_transpose_non_ragged_dim_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_mean_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_op_dim_with_lengths_different_output_shape_sum_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_reshape_decomp_requires_grad_True_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_reshape_decomp_requires_grad_True_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_bfloat16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sdpa_with_constant_sequence_length_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_1_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_1_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape_transpose_offset_2_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_transpose_non_ragged_dim_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_dim_with_lengths_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_softmax_reduce_batch_dim_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_squeeze_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_squeeze_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_batch_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_False_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_False_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_False_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_sum_dim_reduce_ragged_and_non_batch_keepdim_True_requires_grad_True_components_require_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_tensor_attributes_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_tensor_attributes_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_copy_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_copy_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_2_requires_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_3_requires_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_to_padded_tensor_nt_dim_4_requires_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unary_pointwise_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unary_pointwise_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unary_pointwise_transposed_inputs_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unary_pointwise_transposed_inputs_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_backward_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_2_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_2_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_3_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_3_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_last_dim_cpu b/test/compiled_autograd_skips/TestNestedTensorSubclassCPU.test_unbind_transpose_ragged_idx_last_dim_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_2_layout_strided_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_3_layout_strided_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_jagged_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_as_nested_tensor_from_tensor_dim_4_layout_strided_requires_grad_True_contiguous_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_autograd_function_with_None_grad_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_autograd_function_with_None_grad_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_broadcasting_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_broadcasting_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_transposed_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_binary_pointwise_transposed_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_chunk_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_chunk_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_converts_stacked_seq_indices_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_converts_stacked_seq_indices_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_noncontig_with_holes_False_cross_attention_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_noncontig_with_holes_False_cross_attention_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_noncontig_with_holes_False_cross_attention_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_flex_attention_noncontig_with_holes_False_cross_attention_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_as_nested_tensor_components_require_grad_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_layout_construction_nested_tensor_requires_grad_True_components_require_grad_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_False_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_jagged_view_from_values_offsets_requires_grad_True_values_is_view_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_backward_memory_usage_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_backward_memory_usage_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_3_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_3_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_4_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_4_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_5_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_linear_nt_dim_5_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_njt_cat_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_njt_cat_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_reshape_decomp_requires_grad_True_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_reshape_decomp_requires_grad_True_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_bfloat16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_sdpa_with_constant_sequence_length_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_softmax_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_softmax_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_2_requires_grad_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_3_requires_grad_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_to_padded_tensor_nt_dim_4_requires_grad_True_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unary_pointwise_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unary_pointwise_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unary_pointwise_transposed_inputs_cuda b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unary_pointwise_transposed_inputs_cuda new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float16 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float32 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float64 b/test/compiled_autograd_skips/TestNestedTensorSubclassCUDA.test_unbind_backward_cuda_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestOpenReg.test_autograd_init b/test/compiled_autograd_skips/TestOpenReg.test_autograd_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_adaptive_pooling_empty_output_size_cpu_float32 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_adaptive_pooling_empty_output_size_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_adaptive_pooling_empty_output_size_cpu_float64 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_adaptive_pooling_empty_output_size_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_bfloat16 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float16 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float32 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float64 b/test/compiled_autograd_skips/TestPoolingNNDeviceTypeCPU.test_max_pool3d_ndhwc_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestProfiler.test_profiler_fwd_bwd_link b/test/compiled_autograd_skips/TestProfiler.test_profiler_fwd_bwd_link new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestProfiler.test_source b/test/compiled_autograd_skips/TestProfiler.test_source new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPythonDispatch.test_custom_autograd b/test/compiled_autograd_skips/TestPythonDispatch.test_custom_autograd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPythonDispatch.test_shallow_copy_and_detach b/test/compiled_autograd_skips/TestPythonDispatch.test_shallow_copy_and_detach new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestPythonDispatch.test_subclass_autograd_device_check b/test/compiled_autograd_skips/TestPythonDispatch.test_subclass_autograd_device_check new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_amax_grad b/test/compiled_autograd_skips/TestReductions.test_amax_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_amin_grad b/test/compiled_autograd_skips/TestReductions.test_amin_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_backward b/test/compiled_autograd_skips/TestReductions.test_backward new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_mean_dim_grad b/test/compiled_autograd_skips/TestReductions.test_mean_dim_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1a b/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1a new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1b b/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1b new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1e b/test/compiled_autograd_skips/TestReductions.test_mean_grad_case_1e new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_prod_grad b/test/compiled_autograd_skips/TestReductions.test_prod_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestReductions.test_sum_grad b/test/compiled_autograd_skips/TestReductions.test_sum_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScatterGatherCPU.test_gather_backward_with_empty_index_tensor_sparse_grad_True_cpu_float32 b/test/compiled_autograd_skips/TestScatterGatherCPU.test_gather_backward_with_empty_index_tensor_sparse_grad_True_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScatterGatherCPU.test_gather_backward_with_empty_index_tensor_sparse_grad_True_cpu_float64 b/test/compiled_autograd_skips/TestScatterGatherCPU.test_gather_backward_with_empty_index_tensor_sparse_grad_True_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScript.test_cat b/test/compiled_autograd_skips/TestScript.test_cat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScript.test_linear_grad b/test/compiled_autograd_skips/TestScript.test_linear_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScript.test_mm_batching b/test/compiled_autograd_skips/TestScript.test_mm_batching new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestScript.test_stack b/test/compiled_autograd_skips/TestScript.test_stack new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSelectiveActivationCheckpoint.test_flops_and_mem b/test/compiled_autograd_skips/TestSelectiveActivationCheckpoint.test_flops_and_mem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_masked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_masked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_masked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_masked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_nonmasked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_nonmasked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_nonmasked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSC_nonmasked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_masked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_masked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_masked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_masked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_nonmasked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_nonmasked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_nonmasked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseBSR_nonmasked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_masked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_masked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_masked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_masked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_nonmasked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_nonmasked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_nonmasked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCOO_nonmasked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_masked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_masked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_masked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_masked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_nonmasked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_nonmasked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_nonmasked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSC_nonmasked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_masked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_masked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_masked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_masked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_nonmasked_fast_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_nonmasked_fast_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_nonmasked_slow_cpu b/test/compiled_autograd_skips/TestSparseAnyCPU.test_as_sparse_gradcheck_SparseCSR_nonmasked_slow_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_masked_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseBSR_sparse_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_masked_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSC_sparse_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_masked_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_fast_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_fast_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_fast_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_fast_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_slow_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_slow_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_slow_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_mm_SparseCSR_sparse_slow_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_masked_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_masked_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_sparse_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_sparse_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSC_int64_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_masked_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_masked_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_sparse_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_sparse_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseBSR_int64_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_masked_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_masked_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_sparse_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_sparse_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCOO_int64_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_masked_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_masked_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_sparse_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_sparse_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSC_int64_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_masked_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_masked_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_sparse_cpu_complex128 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_sparse_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseAnyCPU.test_gradcheck_to_dense_SparseCSR_int64_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_bfloat16 b/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_complex128 b/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_Sparse_to_Sparse_copy__cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_log_softmax_zero_nnz_cpu_float32 b/test/compiled_autograd_skips/TestSparseCPU.test_log_softmax_zero_nnz_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_log_softmax_zero_nnz_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_log_softmax_zero_nnz_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_softmax_zero_nnz_cpu_float32 b/test/compiled_autograd_skips/TestSparseCPU.test_softmax_zero_nnz_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_softmax_zero_nnz_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_softmax_zero_nnz_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mask_backward_cpu_complex128 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mask_backward_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mask_backward_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mask_backward_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_complex128 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_complex64 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_complex64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_float32 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_matmul_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mul_masked_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mul_masked_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mul_sparse_cpu_float64 b/test/compiled_autograd_skips/TestSparseCPU.test_sparse_mul_sparse_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_autograd_dense_output_addmm_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_autograd_dense_output_addmm_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_autograd_dense_output_addmv_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_autograd_dense_output_addmv_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_mul_cpu_float32 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_mul_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_mul_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_mul_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_complex128 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_complex128 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_complex64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_complex64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_float32 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sampled_addmm_autograd_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_bfloat16 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_bfloat16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float16 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float16 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float32 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sparse_mm_reduce_sum_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sum_cpu_float32 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sum_cpu_float32 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestSparseCSRCPU.test_sum_cpu_float64 b/test/compiled_autograd_skips/TestSparseCSRCPU.test_sum_cpu_float64 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_addcmul b/test/compiled_autograd_skips/TestTEFuserDynamic.test_addcmul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_clamp b/test/compiled_autograd_skips/TestTEFuserDynamic.test_clamp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_eq_ne b/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_eq_ne new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_ge_le b/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_ge_le new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_gt_lt b/test/compiled_autograd_skips/TestTEFuserDynamic.test_comparison_gt_lt new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_concat_invariant b/test/compiled_autograd_skips/TestTEFuserDynamic.test_concat_invariant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_exp b/test/compiled_autograd_skips/TestTEFuserDynamic.test_exp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_hardsigmoid_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserDynamic.test_hardsigmoid_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_hardswish_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserDynamic.test_hardswish_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_inlined_optimized_graph b/test/compiled_autograd_skips/TestTEFuserDynamic.test_inlined_optimized_graph new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_lerp b/test/compiled_autograd_skips/TestTEFuserDynamic.test_lerp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_lstm_concat b/test/compiled_autograd_skips/TestTEFuserDynamic.test_lstm_concat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_lstm_traced b/test/compiled_autograd_skips/TestTEFuserDynamic.test_lstm_traced new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_relu b/test/compiled_autograd_skips/TestTEFuserDynamic.test_relu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_relu_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserDynamic.test_relu_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserDynamic.test_small_constant b/test/compiled_autograd_skips/TestTEFuserDynamic.test_small_constant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_addcmul b/test/compiled_autograd_skips/TestTEFuserStatic.test_addcmul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_chunk_distributes b/test/compiled_autograd_skips/TestTEFuserStatic.test_chunk_distributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_clamp b/test/compiled_autograd_skips/TestTEFuserStatic.test_clamp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_eq_ne b/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_eq_ne new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_ge_le b/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_ge_le new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_gt_lt b/test/compiled_autograd_skips/TestTEFuserStatic.test_comparison_gt_lt new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_concat_invariant b/test/compiled_autograd_skips/TestTEFuserStatic.test_concat_invariant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_constant_chunk_shapes b/test/compiled_autograd_skips/TestTEFuserStatic.test_constant_chunk_shapes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_exp b/test/compiled_autograd_skips/TestTEFuserStatic.test_exp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_hardsigmoid_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserStatic.test_hardsigmoid_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_hardswish_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserStatic.test_hardswish_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_lerp b/test/compiled_autograd_skips/TestTEFuserStatic.test_lerp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_lstm_concat b/test/compiled_autograd_skips/TestTEFuserStatic.test_lstm_concat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_lstm_traced b/test/compiled_autograd_skips/TestTEFuserStatic.test_lstm_traced new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_milstm b/test/compiled_autograd_skips/TestTEFuserStatic.test_milstm new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_relu b/test/compiled_autograd_skips/TestTEFuserStatic.test_relu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_relu_fwd_bwd b/test/compiled_autograd_skips/TestTEFuserStatic.test_relu_fwd_bwd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestTEFuserStatic.test_small_constant b/test/compiled_autograd_skips/TestTEFuserStatic.test_small_constant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestViewOpsCPU.test_as_strided_gradients_cpu b/test/compiled_autograd_skips/TestViewOpsCPU.test_as_strided_gradients_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestViewOpsLAZY.test_as_strided_gradients_lazy b/test/compiled_autograd_skips/TestViewOpsLAZY.test_as_strided_gradients_lazy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestViewOpsLAZY.test_view_copy_lazy b/test/compiled_autograd_skips/TestViewOpsLAZY.test_view_copy_lazy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapAPILegacy.test_batched_gradient_basic b/test/compiled_autograd_skips/TestVmapAPILegacy.test_batched_gradient_basic new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapAPILegacy.test_fallback_with_undefined_grad b/test/compiled_autograd_skips/TestVmapAPILegacy.test_fallback_with_undefined_grad new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_add_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_add_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_binary_cross_entropy_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_binary_cross_entropy_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_diagonal_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_diagonal_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_div_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_div_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_expand_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_expand_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_index_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_index_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_inplace_manyview_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_inplace_manyview_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_inplace_on_view_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_inplace_on_view_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_lgamma_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_lgamma_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_log1p_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_log1p_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_log_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_log_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_logsumexp_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_logsumexp_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_max_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_max_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_median_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_median_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_min_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_min_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_mul_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_mul_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_permute_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_permute_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_reshape_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_reshape_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_select_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_select_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_sigmoid_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_sigmoid_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_slice_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_slice_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_stack_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_stack_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_sub_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_sub_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_threshold_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_threshold_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_trace_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_trace_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_unrelated_output_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_unrelated_output_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_unrelated_output_multiple_grad_cpu b/test/compiled_autograd_skips/TestVmapBatchedGradientLegacyCPU.test_unrelated_output_multiple_grad_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestWithEffects.test_effectful_custom_op_with_subclasses b/test/compiled_autograd_skips/TestWithEffects.test_effectful_custom_op_with_subclasses new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestWithEffects.test_effects_and_aliased_outputs b/test/compiled_autograd_skips/TestWithEffects.test_effects_and_aliased_outputs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestWithEffects.test_regular_effectful_op_in_forward_and_backward b/test/compiled_autograd_skips/TestWithEffects.test_regular_effectful_op_in_forward_and_backward new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/compiled_autograd_skips/TestWithEffects.test_regular_effectful_op_only_in_backward b/test/compiled_autograd_skips/TestWithEffects.test_regular_effectful_op_only_in_backward new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/conftest.py b/test/conftest.py index e02f24ad9cbbcb..d742430f886d72 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -341,5 +341,5 @@ def pytest_runtest_protocol(self, item, nextitem) -> None: self.cache.set(self.directory, self.lastrun) def pytest_sessionfinish(self, session, exitstatus): - if exitstatus == 0 and not self.run_single: + if exitstatus == 0: self.cache.set(self.directory, self.initial_val) diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 9a4624dfd69b86..b317e0400155c9 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -5,6 +5,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp diff --git a/test/cpp/aoti_abi_check/test_macros.cpp b/test/cpp/aoti_abi_check/test_macros.cpp new file mode 100644 index 00000000000000..a42e89d524ac4b --- /dev/null +++ b/test/cpp/aoti_abi_check/test_macros.cpp @@ -0,0 +1,18 @@ +#include + +#include + +namespace torch { +namespace aot_inductor { + +C10_API bool equal(int a, int b) { + return a == b; +} + +TEST(TestMacros, TestC10API) { + EXPECT_TRUE(equal(1, 1)); + EXPECT_FALSE(equal(1, 2)); +} + +} // namespace aot_inductor +} // namespace torch diff --git a/test/cpp/aoti_inference/CMakeLists.txt b/test/cpp/aoti_inference/CMakeLists.txt index 5ac32ef3b91c57..cd87ba6c5053d7 100644 --- a/test/cpp/aoti_inference/CMakeLists.txt +++ b/test/cpp/aoti_inference/CMakeLists.txt @@ -18,8 +18,9 @@ add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt + # This script requires the torch package to be installed. COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py - DEPENDS compile_model.py + DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py ) add_custom_target(aoti_script_model ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 8541c53d402a28..bea27e8d1ec804 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -2890,7 +2890,6 @@ TEST_F(ModulesTest, TanhGELU) { ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05)); } -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST_F(ModulesTest, Mish) { Mish model; auto x = torch::randn(100) * 10; diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 6af7e9230d7755..b038db2eaabaca 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -164,7 +164,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { // so we have this hack to manually set the desync debug flag after PG // creation. void forceSetDesyncDebugFlag() { - desyncDebug_ = true; + watchdog_->setDesyncDebug(true); } private: diff --git a/test/cpp/jit/test_subgraph_utils.cpp b/test/cpp/jit/test_subgraph_utils.cpp index bc05b8e82991db..90e7218f35bfec 100644 --- a/test/cpp/jit/test_subgraph_utils.cpp +++ b/test/cpp/jit/test_subgraph_utils.cpp @@ -19,8 +19,7 @@ TEST(SubgraphUtilsTest, Basic) { for (bool reverse_iterate : {true, false}) { // Merge everything into a single subgraph bool first = true; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* subgraph; + Node* subgraph = nullptr; auto it = reverse_iterate ? graph->nodes().rbegin() : graph->nodes().begin(); auto end = reverse_iterate ? graph->nodes().rend() : graph->nodes().end(); @@ -84,8 +83,7 @@ graph(%a : Tensor, %b : Tensor, %c : Tensor): while (graph2->next() != *graph->nodes().end()) { SubgraphUtils::mergeNodeIntoSubgraph(graph2->next(), graph2); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* subgraph; + Node* subgraph = nullptr; if (reverse_merge) { SubgraphUtils::mergeNodeIntoSubgraph(graph2, graph1); subgraph = graph1; diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 620636330a4fc8..10b750d8b39ad7 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -6,9 +6,24 @@ file(GLOB_RECURSE NATIVERT_ALL_TEST_FILES "${NATIVERT_TEST_ROOT}/test_*.cpp") set(NATIVERT_TEST_SRCS ${NATIVERT_ALL_TEST_FILES} ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp + ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp + ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp + ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp + ${TORCH_ROOT}/torch/nativert/executor/Weights.cpp ${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/FunctionSchema.cpp + ${TORCH_ROOT}/torch/nativert/executor/ExecutionPlanner.cpp + ${TORCH_ROOT}/torch/nativert/detail/ITree.cpp + ${TORCH_ROOT}/torch/nativert/executor/ExecutionFrame.cpp + ${TORCH_ROOT}/torch/nativert/kernels/C10Kernel.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/GreedyBySize.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/Bump.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/DisjointStorageGroups.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp + ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp ) add_executable(test_nativert diff --git a/test/cpp/nativert/test_c10_kernel.cpp b/test/cpp/nativert/test_c10_kernel.cpp new file mode 100644 index 00000000000000..f731128b1c8106 --- /dev/null +++ b/test/cpp/nativert/test_c10_kernel.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include + +namespace torch::nativert { + +at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) { + return a + b; +} + +TEST(C10KernelTest, computeInternal) { + auto registrar = c10::RegisterOperators().op( + "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel); + + static constexpr std::string_view source = + R"(graph(%a, %b): +%x = test.foo.default(a=%a, b=%b) +return (%x) +)"; + + auto graph = stringToGraph(source); + const auto& nodes = graph->nodes(); + auto it = nodes.begin(); + std::advance(it, 1); + const Node& node = *it; + + c10::Device device = torch::Device(torch::kCPU, 0); + + auto a = at::randn({6, 6, 6}); + auto b = at::randn({6, 6, 6}); + + auto frame = ExecutionFrame(*graph); + frame.setIValue(graph->getValue("a")->id(), a); + frame.setIValue(graph->getValue("b")->id(), b); + + auto kernel = C10Kernel(&node, device); + + kernel.computeInternal(frame); + + at::Tensor expected = a + b; + EXPECT_TRUE( + torch::equal(frame.getTensor(graph->getValue("x")->id()), expected)); +} + +TEST(ScalarBinaryOpKernelTest, computeInternal) { + static constexpr std::string_view source = + R"(graph(%a, %b): +%x = _operator.add(a=%a, b=%b) +return (%x) +)"; + + auto graph = stringToGraph(source); + const auto& nodes = graph->nodes(); + auto it = nodes.begin(); + std::advance(it, 1); + const Node& node = *it; + + auto a = 1; + auto b = 2; + + auto frame = ExecutionFrame(*graph); + frame.setIValue(graph->getValue("a")->id(), a); + frame.setIValue(graph->getValue("b")->id(), b); + + auto kernel = ScalarBinaryOpKernel(&node); + + kernel.computeInternal(frame); + + auto expected = a + b; + EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected); +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_execution_frame.cpp b/test/cpp/nativert/test_execution_frame.cpp new file mode 100644 index 00000000000000..1f4a6975ad9376 --- /dev/null +++ b/test/cpp/nativert/test_execution_frame.cpp @@ -0,0 +1,98 @@ +#include +#include + +namespace torch::nativert { + +TEST(ExecutionFrameTest, CreateFrame) { + auto graph = stringToGraph(R"( + graph(%x, %y): + %a = foo(a=%x, b=%y) + %b = foo1(a=%x, b=%y) + %c = foo2(c=%a, d=%b) + return(%c) + )"); + + auto frame = ExecutionFrame(*graph); + + for (auto* v : graph->values()) { + frame.setIValue(v->id(), c10::IValue(at::tensor({v->id()}, at::kInt))); + auto& frame_v = frame.getIValue(v->id()); + EXPECT_EQ(frame_v.tagKind(), "Tensor"); + } + + auto outputs = frame.tryMoveUserOutputs(); + + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs[0].tagKind(), "Tensor"); + EXPECT_EQ(outputs[0].toTensor().item().toInt(), graph->getValue("c")->id()); +} + +TEST(ExecutionFrameTest, TestSetBorrowedValue) { + auto graph = stringToGraph(R"( + graph(%x, %y): + %a = foo(a=%x, b=%y) + %b = foo1(a=%x, b=%y) + %c = foo2(c=%a, d=%b) + return(%c) + )"); + + auto x = c10::IValue(at::tensor({1}, at::kInt)); + auto y = c10::IValue(at::tensor({2}, at::kInt)); + + { + auto frame = ExecutionFrame(*graph); + + frame.setBorrowedIValue( + graph->getValue("x")->id(), + c10::MaybeOwnedTraits::createBorrow(x)); + frame.setBorrowedIValue( + graph->getValue("y")->id(), + c10::MaybeOwnedTraits::createBorrow(y)); + + [[maybe_unused]] auto& w = frame.getIValue(graph->getValue("x")->id()); + [[maybe_unused]] auto& z = frame.getIValue(graph->getValue("y")->id()); + + EXPECT_EQ(x.use_count(), 1); + EXPECT_EQ(y.use_count(), 1); + + EXPECT_TRUE(c10::MaybeOwnedTraits{}.debugBorrowIsValid( + frame.getIValue(graph->getValue("x")->id()))); + EXPECT_TRUE(c10::MaybeOwnedTraits{}.debugBorrowIsValid( + frame.getIValue(graph->getValue("y")->id()))); + } + + EXPECT_EQ(x.use_count(), 1); + EXPECT_EQ(y.use_count(), 1); +} + +TEST(ExecutionFrameTest, TestPersistentValue) { + auto graph = stringToGraph(R"( + graph(%x, %y, %my_weight): + %a = foo(a=%x, b=%y) + %b = foo1(a=%x, b=%y) + %c = foo2(c=%a, d=%b) + return(%c) + )"); + + Weights weights(graph.get()); + weights.setValue("my_weight", at::tensor({1}, at::kInt)); + + auto new_sig = graph->signature(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast>&>( + new_sig.inputsToWeights()) + .emplace_back("my_weight", "my_weight"); + graph->setSignature(new_sig); + + auto frame = ExecutionFrame(*graph, weights); + + EXPECT_EQ(frame.weightVersion(), 0); + auto wid = graph->getValue("my_weight")->id(); + + EXPECT_NO_THROW(frame.getTensor(wid)); + // can't release persistent value + frame.releaseValueIfNeeded(wid); + EXPECT_FALSE(frame.getIValue(wid).isNone()); +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_execution_planner.cpp b/test/cpp/nativert/test_execution_planner.cpp new file mode 100644 index 00000000000000..8162d0fdbf2f0b --- /dev/null +++ b/test/cpp/nativert/test_execution_planner.cpp @@ -0,0 +1,47 @@ +#include +#include + +namespace torch::nativert { + +TEST(ExecutionPlannerTest, CreatePlan) { + auto graph = stringToGraph(R"( + graph(%x, %y): + %a = foo(a=%x, b=%y) + %b = foo1(a=%x, b=%y) + %c = foo2(c=%a, d=%b) + return(%c) + )"); + + { + auto plan = ExecutionPlanner{*graph}.createPlan(); + + auto& values_to_free = plan->valuesToFree; + EXPECT_EQ(values_to_free.size(), 5); + + for (const auto i : c10::irange(3)) { + EXPECT_TRUE(values_to_free[i].empty()); + } + + EXPECT_EQ(values_to_free[3].size(), 2); + std::set ids{values_to_free[3].begin(), values_to_free[3].end()}; + EXPECT_EQ( + ids, + std::set( + {graph->tryGetValue("a")->id(), graph->tryGetValue("b")->id()})); + + EXPECT_EQ(values_to_free[4].size(), 0); + } + + { + auto static_values = ExecutionPlanner::staticValues(*graph); + std::set static_ids{static_values.begin(), static_values.end()}; + EXPECT_EQ( + static_ids, + std::set( + {graph->tryGetValue("x")->id(), + graph->tryGetValue("y")->id(), + graph->tryGetValue("c")->id()})); + } +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_function_schema.cpp b/test/cpp/nativert/test_function_schema.cpp new file mode 100644 index 00000000000000..c80e6273f12aa5 --- /dev/null +++ b/test/cpp/nativert/test_function_schema.cpp @@ -0,0 +1,70 @@ +#include + +#include +#include +#include + +using namespace ::testing; + +int64_t increment_kernel(const at::Tensor& tensor, int64_t input) { + return input + 1; +} + +at::Tensor slice_kernel(const at::Tensor& tensor, int64_t dim) { + return tensor.slice(dim); +} + +TEST(TestFunctionSchema, testNoAlias) { + auto registrar = c10::RegisterOperators().op( + "_test::my_op(Tensor dummy, int input) -> int", &increment_kernel); + auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); + + EXPECT_TRUE(handle.has_value()); + EXPECT_TRUE(handle->hasSchema()); + + auto nativert_schema = torch::nativert::FunctionSchema(handle->schema()); + + EXPECT_FALSE(nativert_schema.alias(0, 0)); + EXPECT_FALSE(nativert_schema.alias(1, 0)); + + // bounds check + EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error); + EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error); +} + +TEST(TestFunctionSchema, testAliasOverride) { + auto registrar = c10::RegisterOperators().op( + "_test::my_op(Tensor dummy, int input) -> int", &increment_kernel); + auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); + + EXPECT_TRUE(handle.has_value()); + EXPECT_TRUE(handle->hasSchema()); + + auto nativert_schema = + torch::nativert::FunctionSchema(handle->schema(), {{0, 0}}); + + EXPECT_TRUE(nativert_schema.alias(0, 0)); + EXPECT_FALSE(nativert_schema.alias(1, 0)); + + // bounds check + EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error); + EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error); +} + +TEST(TestFunctionSchema, testAlias) { + auto registrar = c10::RegisterOperators().op( + "_test::my_op(Tensor(a) dummy, int input) -> Tensor(a)", &slice_kernel); + auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); + + EXPECT_TRUE(handle.has_value()); + EXPECT_TRUE(handle->hasSchema()); + + auto nativert_schema = torch::nativert::FunctionSchema(handle->schema()); + + EXPECT_TRUE(nativert_schema.alias(0, 0)); + EXPECT_FALSE(nativert_schema.alias(1, 0)); + + // bounds check + EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error); + EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error); +} diff --git a/test/cpp/nativert/test_graph.cpp b/test/cpp/nativert/test_graph.cpp new file mode 100644 index 00000000000000..945d2d51252f14 --- /dev/null +++ b/test/cpp/nativert/test_graph.cpp @@ -0,0 +1,647 @@ +#include +#include +#include + +#include + +using namespace ::testing; + +namespace torch::nativert { +TEST(GraphTest, Basic) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + auto graph = stringToGraph(source); + EXPECT_EQ(graph->inputs().size(), 3); + EXPECT_EQ(graph->inputs()[0]->name(), "foo"); + EXPECT_EQ(graph->inputs()[1]->name(), "bar"); + EXPECT_EQ(graph->inputs()[2]->name(), "baz"); + + const auto& nodes = graph->nodes(); + EXPECT_EQ(nodes.size(), 3); + // First node is the input node + auto it = nodes.begin(); + { + const auto& node = *it; + EXPECT_EQ(node.target(), "prim.Input"); + EXPECT_EQ(node.inputs().size(), 0); + EXPECT_EQ(node.outputs().size(), 3); + EXPECT_EQ(node.outputs()[0]->name(), "foo"); + EXPECT_EQ(node.outputs()[1]->name(), "bar"); + EXPECT_EQ(node.outputs()[2]->name(), "baz"); + } + { + std::advance(it, 1); + const auto& node = *it; + EXPECT_EQ(node.target(), "aten.foo"); + EXPECT_EQ(node.inputs().size(), 2); + EXPECT_EQ(node.inputs()[0].name, "self"); + EXPECT_EQ(node.inputs()[1].name, "target"); + + EXPECT_EQ(node.attributes().size(), 1); + EXPECT_EQ(node.attributes()[0].name, "alpha"); + } + { + std::advance(it, 1); + const auto& node = *it; + EXPECT_EQ(node.target(), "prim.Output"); + EXPECT_EQ(node.inputs().size(), 2); + EXPECT_EQ(node.inputs()[0].name, "o2"); + EXPECT_EQ(node.inputs()[1].name, "baz"); + } + EXPECT_EQ(graph->outputs().size(), 2); + EXPECT_EQ(graph->outputs()[0]->name(), "o2"); + EXPECT_EQ(graph->outputs()[1]->name(), "baz"); + + const auto& values = graph->values(); + EXPECT_EQ(values.size(), 5); + std::vector valueNames; + valueNames.reserve(values.size()); + for (const auto& v : values) { + valueNames.emplace_back(v->name()); + } + std::sort(valueNames.begin(), valueNames.end()); + + EXPECT_THAT( + valueNames, + ContainerEq(std::vector({"bar", "baz", "foo", "o1", "o2"}))); +} + +TEST(GraphTest, ValueProducer) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + auto graph = stringToGraph(source); + auto foo = graph->getValue("foo"); + EXPECT_EQ(foo->producer()->target(), "prim.Input"); + auto o1 = graph->getValue("o1"); + EXPECT_EQ(o1->producer()->target(), "aten.foo"); +} + +TEST(GraphTest, InsertBeforeAfter) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + auto graph = stringToGraph(source); + auto it = graph->nodes().begin(); + ++it; + auto& node = *it; + EXPECT_EQ(node.target(), "aten.foo"); + auto before = graph->createNode("before", {}); + auto after = graph->createNode("after", {}); + auto atEnd = graph->createNode("atEnd", {}); + + graph->insertBefore(before, &node); + graph->insertAfter(after, &node); + graph->insert(atEnd); + + static constexpr std::string_view expected = + R"(graph(%foo, %bar, %baz): + = before() +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) + = after() + = atEnd() +return(%o2, %baz) +)"; + EXPECT_EQ(graphToString(*graph), expected); +} + +TEST(GraphTest, ValueUses) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + auto graph = stringToGraph(source); + auto o2 = graph->getValue("o2"); + EXPECT_EQ(o2->users().size(), 1); + EXPECT_EQ(o2->users()[0]->target(), "prim.Output"); +} + +TEST(GraphTest, ApplyDevicePlacement) { + auto graph = Graph::createGraph(); + auto node1 = graph->insertNode("node1"); + auto node2 = graph->insertNode("node2"); + + node1->addAttribute({"a", c10::Device(c10::DeviceType::CPU)}); + node1->addAttribute({"b", c10::Device(c10::DeviceType::CUDA, 0)}); + node1->addAttribute({"c", c10::Device(c10::DeviceType::CUDA, 1)}); + + node2->addAttribute({"d", c10::Device(c10::DeviceType::CUDA, 0)}); + + graph->applyDevicePlacement( + Placement(std::unordered_map{ + {c10::Device(c10::DeviceType::CUDA, 0), + c10::Device(c10::DeviceType::CUDA, 1)}})); + + EXPECT_EQ( + std::get(node1->getAttribute("a").value), + c10::Device(c10::DeviceType::CPU)); + EXPECT_EQ( + std::get(node1->getAttribute("b").value), + c10::Device(c10::DeviceType::CUDA, 1)); + EXPECT_EQ( + std::get(node1->getAttribute("c").value), + c10::Device(c10::DeviceType::CUDA, 1)); + EXPECT_EQ( + std::get(node2->getAttribute("d").value), + c10::Device(c10::DeviceType::CUDA, 1)); +} + +TEST(GraphTest, ReplaceAllUses) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + auto graph = stringToGraph(source); + auto o2 = graph->getValue("o2"); + auto bar = graph->getValue("bar"); + auto foo = graph->getValue("foo"); + + EXPECT_EQ(o2->users().size(), 1); + EXPECT_EQ(bar->users().size(), 1); + EXPECT_EQ(foo->users().size(), 1); + + graph->replaceAllUses(o2, bar); + EXPECT_EQ(o2->users().size(), 0); + EXPECT_EQ(bar->users().size(), 2); + + graph->replaceAllUses(bar, foo); + EXPECT_EQ(bar->users().size(), 0); + EXPECT_EQ(foo->users().size(), 2); + static constexpr std::string_view expected = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%foo, alpha=0.1) +return(%foo, %baz) +)"; + EXPECT_EQ(graphToString(*graph), expected); +} + +TEST(GraphTest, GetUniqueValueName) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %bar) +)"; + auto graph = stringToGraph(source); + auto o2 = graph->getValue("o2"); + auto fooNode = o2->producer(); + auto v0 = graph->getUniqueValueName(); + graph->addValue(v0, Type::Kind::None, fooNode); + auto v1 = graph->getUniqueValueName(); + graph->addValue(v1, Type::Kind::None, fooNode); + auto v2 = graph->getUniqueValueName(); + EXPECT_EQ(v0, "v0"); + EXPECT_EQ(v1, "v1"); + EXPECT_EQ(v2, "v2"); +} + +TEST(GraphTest, ReplaceAllUsesMultiUse) { + static constexpr std::string_view source = + R"(graph(%foo, %bar): +%o1 = aten.foo(a=%foo, b=%foo, c=%bar) +return(%o1) +)"; + auto graph = stringToGraph(source); + auto foo = graph->getValue("foo"); + auto bar = graph->getValue("bar"); + graph->replaceAllUses(foo, bar); + + static constexpr std::string_view expected = + R"(graph(%foo, %bar): +%o1 = aten.foo(a=%bar, b=%bar, c=%bar) +return(%o1) +)"; + EXPECT_EQ(graphToString(*graph), expected); +} + +TEST(GraphTest, ReplaceAllUsesAfter) { + static constexpr std::string_view source = + R"(graph(%foo): +%o1 = aten.foo1(a=%foo) +%o2 = aten.foo2(a=%o1, b=%foo) +%o3 = aten.foo3(a=%o2, b=%o2, c=%foo) +return(%foo, %o1, %o2, %o3) +)"; + auto graph = stringToGraph(source); + auto foo = graph->getValue("foo"); + auto o1 = graph->getValue("o1"); + auto foo3Node = graph->getValue("o3")->producer(); + graph->replaceAllUsesAfterNode(foo, o1, foo3Node); + + static constexpr std::string_view expected = + R"(graph(%foo): +%o1 = aten.foo1(a=%foo) +%o2 = aten.foo2(a=%o1, b=%foo) +%o3 = aten.foo3(a=%o2, b=%o2, c=%foo) +return(%o1, %o1, %o2, %o3) +)"; + EXPECT_EQ(graphToString(*graph), expected); + EXPECT_EQ(foo->users().size(), 3); + EXPECT_EQ(o1->users().size(), 2); +} + +TEST(GraphTest, InsertingAfter) { + static constexpr std::string_view source = + R"(graph(%foo, %bar): +%o1 = aten.first(a=%foo) +%o2 = aten.foo(c=%bar) +return(%o1, %o2) +)"; + auto graph = stringToGraph(source); + auto origNode = graph->getValue("o1")->producer(); + { + InsertingAfter guard(origNode); + graph->insertNode("one"); + graph->insertNode("two"); + graph->insertNode("three"); + } + graph->insertNode("four"); + static constexpr std::string_view expected = + R"(graph(%foo, %bar): +%o1 = aten.first(a=%foo) + = one() + = two() + = three() +%o2 = aten.foo(c=%bar) + = four() +return(%o1, %o2) +)"; + EXPECT_EQ(graphToString(*graph), expected); +} + +TEST(NodeTest, GetInputAndAttribute) { + auto graph = Graph::createGraph(); + auto input1 = graph->addInput("input1", Type::Kind::Tensor); + auto input2 = graph->addInput("input2", Type::Kind::Tensor); + auto input3 = graph->addInput("input3", Type::Kind::Tensor); + auto node = graph->createNode("foo.bar"); + + node->addInput({"out_of_order", input1}); + node->addInput({"arg1", input2}); + node->addInput({"arg2", input3}); + + node->addAttribute({"b", static_cast(0)}); + node->addAttribute({"a", static_cast(2)}); + node->addAttribute({"c", static_cast(1)}); + { + const auto& input = node->getInput("out_of_order"); + EXPECT_EQ(input.name, "out_of_order"); + EXPECT_EQ(input.value, input1); + } + { + const auto& input = node->getInput("arg1"); + EXPECT_EQ(input.name, "arg1"); + EXPECT_EQ(input.value, input2); + } + { + const auto& input = node->getInput("arg2"); + EXPECT_EQ(input.name, "arg2"); + EXPECT_EQ(input.value, input3); + } + { + const auto& attr = node->getAttribute("a"); + EXPECT_EQ(attr.name, "a"); + EXPECT_EQ(attr.value, Constant(static_cast(2))); + } + { + const auto& attr = node->getAttribute("b"); + EXPECT_EQ(attr.name, "b"); + EXPECT_EQ(attr.value, Constant(static_cast(0))); + } + { + const auto& attr = node->getAttribute("c"); + EXPECT_EQ(attr.name, "c"); + EXPECT_EQ(attr.value, Constant(static_cast(1))); + } + + EXPECT_EQ(node->tryGetInput("doesnotexist"), nullptr); + EXPECT_EQ(node->tryGetAttribute("doesnotexist"), nullptr); +} + +TEST(NodeTest, NextPrev) { + static constexpr std::string_view source = + R"(graph(%foo): +%o1 = aten.foo1(a=%foo) +%o2 = aten.foo2(a=%o1, b=%foo) +%o3 = aten.foo3(a=%o2, b=%o2, c=%foo) +return(%foo, %o1, %o2, %o3) +)"; + auto graph = stringToGraph(source); + auto foo1 = graph->getValue("o1")->producer(); + auto foo2 = graph->getValue("o2")->producer(); + auto foo3 = graph->getValue("o3")->producer(); + EXPECT_EQ(foo1->next(), foo2); + EXPECT_EQ(foo2->next(), foo3); + EXPECT_EQ(foo3->prev(), foo2); + EXPECT_EQ(foo3->next(), graph->outputNode()); + EXPECT_EQ(foo2->prev(), foo1); + EXPECT_EQ(foo1->prev(), graph->inputNode()); + EXPECT_EQ(graph->inputNode()->prev(), nullptr); + EXPECT_EQ(graph->outputNode()->next(), nullptr); +} + +TEST(GraphTest, IsBefore) { + auto source = R"IR( + graph(%foo): + %o1 = aten.foo1(a=%foo) + %o2 = aten.foo2(a=%o1) + %o3 = aten.foo3(a=%o2) + return (%o3) + )IR"; + + auto graph = stringToGraph(source); + ASSERT_NE(graph, nullptr); + + auto* o1 = graph->tryGetValue("o1"); + auto* o2 = graph->tryGetValue("o2"); + auto* o3 = graph->tryGetValue("o3"); + + auto* foo1 = o1->producer(); + auto* foo2 = o2->producer(); + auto* foo3 = o3->producer(); + + EXPECT_TRUE(foo1->isBefore(foo2)) << "foo1 should appear before foo2"; + EXPECT_TRUE(foo2->isBefore(foo3)) << "foo2 should appear before foo3"; + EXPECT_TRUE(foo1->isBefore(foo3)) << "foo1 should appear before foo3"; + + EXPECT_FALSE(foo2->isBefore(foo1)) << "foo2 should not appear before foo1"; + EXPECT_FALSE(foo3->isBefore(foo2)) << "foo3 should not appear before foo2"; +} + +TEST(GraphTest, RemoveNodeWithUsers) { + // Check we shouldn't be able to remove a node that still has users + auto source = R"IR( + graph(%foo): + %o1 = aten.foo1(a=%foo) + %o2 = aten.foo2(a=%o1, b=%foo) + %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) + return (%foo, %o1, %o3) + )IR"; + + auto graph = stringToGraph(source); + ASSERT_NE(graph, nullptr); + + auto* o2 = graph->tryGetValue("o2"); + auto* foo2 = o2->producer(); + + EXPECT_THROW(graph->removeNode(foo2), c10::Error); +} + +TEST(GraphTest, RemoveNodeUnused) { + // Check node removal works as expected + auto source = R"IR( + graph(%foo): + %o1 = aten.foo1(a=%foo) + %o2 = aten.foo2(a=%o1, b=%foo) + %unused = aten.fooUnused(a=%o2) + return(%foo, %o1, %o2) + )IR"; + auto graph = stringToGraph(source); + + auto* valUnused = graph->tryGetValue("unused"); + Node* nodeUnused = valUnused->producer(); + EXPECT_EQ(nodeUnused->target(), "aten.fooUnused"); + + graph->removeNode(nodeUnused); + graph->lint(); + + // %unused should now be gone + EXPECT_EQ(graph->tryGetValue("unused"), nullptr) + << "Value %unused should no longer exist in the graph"; + + for (const auto& node : graph->nodes()) { + EXPECT_NE(node.target(), "aten.fooUnused"); + for (const auto* output : node.outputs()) { + EXPECT_NE(output->name(), "unused") + << "Should not find %unused in any remaining node's outputs"; + } + } +} + +TEST(GraphTest, RemoveValue) { + auto source = R"IR( + graph(%foo): + %o1 = aten.foo1(a=%foo) + %o2 = aten.foo2(a=%o1, b=%foo) + %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) + return (%foo, %o1, %o3) + )IR"; + + auto graph = stringToGraph(source); + auto* val_o1 = graph->tryGetValue("o1"); + + { + // Check we shouldn't be able to remove a value that still has users + EXPECT_THROW(graph->removeValue(val_o1), c10::Error); + } + + { + // Check value removal works as expected + graph->replaceAllUses(val_o1, graph->tryGetValue("foo")); + graph->removeValue(val_o1); + EXPECT_EQ(graph->tryGetValue("%o1"), nullptr); + } +} + +TEST(GraphTest, InsertGraph) { + auto source = R"IR( + graph(%foo): + %o1 = aten.foo1(a=%foo) + return (%o1) + )IR"; + + // Subgraph to be inserted + auto subgraphSource = R"IR( + graph(%x): + %s1 = aten.subFoo1(a=%x) + %s2 = aten.subFoo2(a=%s1) + return (%s2) + )IR"; + + auto mainGraph = stringToGraph(source); + auto subGraph = stringToGraph(subgraphSource); + + // Insert subGraph into mainGraph. Use %o1 as the subGraph's %x + auto val_o1 = mainGraph->tryGetValue("o1"); + std::unordered_map valueMap; + std::vector insertedOutputs = + mainGraph->insertGraph(*subGraph, {val_o1}, valueMap); + + EXPECT_EQ(insertedOutputs.size(), 1); + + // Check all new nodes are inserted correctly from the copied %s2 + auto* newS2 = insertedOutputs.front(); + + auto* newSubFoo2 = newS2->producer(); + EXPECT_EQ(newSubFoo2->target(), "aten.subFoo2"); + + auto* newS1 = newSubFoo2->inputs().front().value; + auto* newSubFoo1 = newS1->producer(); + EXPECT_EQ(newSubFoo1->target(), "aten.subFoo1"); + + EXPECT_EQ(newSubFoo1->inputs().front().value, val_o1); + + auto* subInputVal = subGraph->inputs().front(); + EXPECT_EQ(valueMap[subInputVal], val_o1); + for (const auto& [val1, val2] : valueMap) { + if (val1->name() == "s1") { + EXPECT_EQ(val2->name(), newS1->name()); + } + if (val1->name() == "s2") { + EXPECT_EQ(val2->name(), newS2->name()); + } + if (val1->name() == "x") { + EXPECT_EQ(val2->name(), val_o1->name()); + } + } + + mainGraph->lint(); +} + +TEST(GraphTest, CleanupDeadNodes) { + // %c is unused + const std::string source = R"( + graph(%x, %y): +%a = foo(a=%x, b=%y) +%b = foo1(c=%a) +%c = foo2(a=%b, b=%y) +return(%b) +)"; + auto graph = stringToGraph(source); + + // Verify that %c exists initially + auto* cVal = graph->tryGetValue("c"); + ASSERT_NE(nullptr, cVal); + size_t nodeCountBefore = graph->nodes().size(); + + graph->cleanupDeadNodes(); + + // %c should now be gone + EXPECT_EQ(nullptr, graph->tryGetValue("c")); + // %b should still be there + EXPECT_NE(nullptr, graph->tryGetValue("b")); + EXPECT_EQ(nodeCountBefore - 1, graph->nodes().size()); +} + +TEST(GraphTest, RenumberValues) { + const std::string source = R"( + graph(%x): +%a = foo(a=%x) +%b = foo1(a=%a) +return (%a) +)"; + auto graph = stringToGraph(source); + graph->cleanupDeadNodes(); + + // %b should now be gone + EXPECT_EQ(nullptr, graph->tryGetValue("b")); + + // %a should now be the last value + EXPECT_EQ(graph->tryGetValue("a")->id(), graph->numValues() - 1); + + // All values should be renumbered + size_t numVals = graph->numValues(); + std::unordered_set ids; + ids.reserve(numVals); + for (const auto* val : graph->values()) { + ASSERT_LT(val->id(), numVals); + ids.insert(val->id()); + } + + // Check ids are contiguous and unique b/w 0 and numVals + EXPECT_EQ(numVals, ids.size()); + for (size_t i = 0; i < numVals; ++i) { + EXPECT_NE(ids.end(), ids.find(i)); + } +} + +TEST(SerializationTest, RoundTrip) { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o1, %baz) +)"; + const auto graph = stringToGraph(source); + const auto serialized = graphToString(*graph); + EXPECT_EQ(source, serialized); +} + +TEST(SerializationTest, EscapedStringConstant) { + const auto parsed = + std::get(convertAtomicConstant(R"("string_\"escape")")); + std::string expected = "string_\\\"escape"; + EXPECT_EQ(parsed, expected); +} + +TEST(SerializationTest, DeviceConstant) { + const auto device = + std::get(convertAtomicConstant("Device{cuda:1}")); + EXPECT_EQ(device.index(), 1); + EXPECT_EQ(device.type(), c10::DeviceType::CUDA); +} + +TEST(SerializationTest, TrueConstant) { + const auto parsedTrue = std::get(convertAtomicConstant("true")); + EXPECT_EQ(parsedTrue, true); + const auto parsedFalse = std::get(convertAtomicConstant("false")); + EXPECT_EQ(parsedFalse, false); +} + +TEST(SerializationTest, MemoryFormatConstant) { + const auto parsed = std::get( + convertAtomicConstant("MemoryFormat::ContiguousFormat")); + EXPECT_EQ(parsed, c10::MemoryFormat::Contiguous); +} + +TEST(SerializationTest, FloatConstant) { + const auto parsed = std::get(convertAtomicConstant("5.0")); + EXPECT_EQ(parsed, 5.0); +} + +TEST(SerializationTest, IntConstant) { + const auto parsed = std::get(convertAtomicConstant("5")); + EXPECT_EQ(parsed, 5); +} + +TEST(SerializationTest, FloatExponentConstant) { + const auto parsed = std::get(convertAtomicConstant("1e-05")); + EXPECT_EQ(parsed, 0.00001); +} + +TEST(SerializationTest, SingleElementListConstant) { + const auto parsed = + std::get>(convertListConstant("[1]")); + const auto expected = std::vector{1}; + EXPECT_EQ(parsed, expected); +} + +TEST(SerializationTest, IntListConstant) { + const auto parsed = + std::get>(convertListConstant("[1, 2, 3, 4]")); + const auto expected = std::vector{1, 2, 3, 4}; + EXPECT_EQ(parsed, expected); +} + +TEST(SerializationTest, FloatListConstant) { + const auto parsed = std::get>( + convertListConstant("[1.0, 2.0, 3.0, 4.0]")); + const auto expected = std::vector{1.0, 2.0, 3.0, 4.0}; + EXPECT_EQ(parsed, expected); +} + +TEST(SerializationTest, BoolListConstant) { + const auto parsed = + std::get>(convertListConstant("[false, true, false]")); + const auto expected = std::vector{false, true, false}; + EXPECT_EQ(parsed, expected); +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_itree.cpp b/test/cpp/nativert/test_itree.cpp new file mode 100644 index 00000000000000..e0004f7db77e4d --- /dev/null +++ b/test/cpp/nativert/test_itree.cpp @@ -0,0 +1,1150 @@ +#include +#include + +#include + +#include +#include + +namespace torch::nativert::detail { + +using torch::nativert::Graph; +using torch::nativert::stringToGraph; +using torch::nativert::Type; +using torch::nativert::Value; + +std::pair, std::vector> makeValues( + int count) { + auto graph = Graph::createGraph(); + std::vector values; + + for (int i = 0; i < count; i++) { + std::string name = fmt::format("v{}", i); + Value* value = graph->addValue(name, Type::Kind::None, nullptr); + values.push_back(value); + } + + return std::make_pair(std::move(graph), values); +} + +TEST(ITreeTest, Unflatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, (10,), {"11": 12}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "torch.fx.immutable_collections.immutable_list", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "torch.fx.immutable_collections.immutable_dict", + "context": "[\"11\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + + auto [graph, valuePtrs] = makeValues(8); + + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + std::vector flats = { + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + c10::IValue(8), + c10::IValue(9), + c10::IValue(10), + c10::IValue(12), + }; + auto itree = itreeUnflatten(flats, spec); + EXPECT_TRUE(itree.isList()); + EXPECT_EQ(itree.toListRef().size(), 5); + + EXPECT_TRUE(itree.toListRef().at(0).isTuple()); + EXPECT_EQ(itree.toListRef().at(0).toTupleRef().elements()[0], c10::IValue(0)); + EXPECT_EQ(itree.toListRef().at(0).toTupleRef().elements()[1], c10::IValue(1)); + + EXPECT_TRUE(itree.toListRef().at(1).isInt()); + EXPECT_EQ(itree.toListRef().at(1), c10::IValue(2)); + + EXPECT_TRUE(itree.toListRef().at(2).isGenericDict()); + EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("4"), c10::IValue(7)); + EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("5"), c10::IValue(8)); + EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("6"), c10::IValue(9)); + + EXPECT_TRUE(itree.toListRef().at(3).isList()); + EXPECT_EQ(itree.toListRef().at(3).toListRef().at(0), c10::IValue(10)); + + EXPECT_TRUE(itree.toListRef().at(4).isGenericDict()); + EXPECT_EQ(itree.toListRef().at(4).toGenericDict().at("11"), c10::IValue(12)); + + const auto flattened = itreeFlatten(itree, spec); + EXPECT_EQ(flattened.size(), flats.size()); + for (size_t i = 0; i < flattened.size(); i++) { + EXPECT_EQ(flattened[i], flats[i]); + } +} + +TEST(ITreeTest, NoVersion) { + auto jsonSpec = R"( + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } + )"; + + auto [graph, valuePtrs] = makeValues(2); + EXPECT_THROW({ itreeSpecLoads(jsonSpec, valuePtrs); }, std::exception); +} + +TEST(ITreeTest, NoField) { + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + + auto [graph, valuePtrs] = makeValues(3); + EXPECT_THROW(itreeSpecLoads(jsonSpec, valuePtrs), std::exception); +} + +TEST(ITreeTest, NoContext) { + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.dict", + "context": "[]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(3); + auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + + std::vector flats = { + c10::IValue(7), + c10::IValue(8), + c10::IValue(9), + }; + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +} + +TEST(ITreeTest, TooManyContext) { + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\", \"10\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + + auto [graph, valuePtrs] = makeValues(3); + auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + + std::vector flats = { + c10::IValue(7), + c10::IValue(8), + c10::IValue(9), + }; + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +} + +TEST(ITreeTest, DoubleRegister) { + EXPECT_THROW( + { registerPytreeNode("builtins.dict", NodeDef{}); }, std::exception); +} + +TEST(ITreeTest, NotEnoughUnflatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(6); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + std::vector flats = { + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + }; + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +} + +TEST(ITreeTest, TooManyUnflatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(6); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + std::vector flats = { + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + }; + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +} + +TEST(ITreeTest, Flatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, (10,), {"11": 12}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "torch.fx.immutable_collections.immutable_list", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "torch.fx.immutable_collections.immutable_dict", + "context": "[\"11\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(8); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + auto tup = c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)}); + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + dict.insert("4", c10::IValue(7)); + dict.insert("5", c10::IValue(8)); + dict.insert("6", c10::IValue(9)); + c10::List ilist(c10::AnyType::get()); + ilist.push_back(c10::IValue(10)); + c10::Dict idict( + c10::StringType::get(), c10::AnyType::get()); + idict.insert("11", c10::IValue(12)); + c10::List list(c10::AnyType::get()); + list.push_back(std::move(tup)); + list.push_back(c10::IValue(2)); + list.push_back(std::move(dict)); + list.push_back(std::move(ilist)); + list.push_back(std::move(idict)); + auto flats = itreeFlatten(c10::IValue{list}, spec); + std::vector expected = { + c10::IValue(0), + c10::IValue(1), + c10::IValue(2), + c10::IValue(7), + c10::IValue(8), + c10::IValue(9), + c10::IValue(10), + c10::IValue(12), + }; + for (const auto& [i, flat] : c10::enumerate(flats)) { + EXPECT_EQ(flat, expected.at(i)); + } +} + +TEST(ITreeTest, IValueApplyFromArgs) { + // inputSpec for testing is generated from E2ETestModelWithNestedDictInput + /* + args = ( + { + "a": ( + torch.rand(4, 4), + { + 123: (torch.rand(4, 4), torch.rand(4, 4)), + 234: (torch.rand(4, 4), torch.rand(4, 4)), + }, + ), + "b": ( + torch.rand(4, 4), + { + 345: (torch.rand(4, 4), torch.rand(4, 4)), + 456: (torch.rand(4, 4), torch.rand(4, 4)), + }, + ), + }, + )*/ + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": "builtins.dict", + "context": "[\"a\", \"b\"]", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[123, 234]", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } + ] + }, + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[345, 456]", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } + ] + } + ] + } + ] + }, + { + "type": "builtins.dict", + "context": "[]", + "children_spec": [] + } + ] + } +] + )"; + + auto tup_a1_123 = + c10::ivalue::Tuple::create({c10::IValue(1), c10::IValue(2)}); + auto tup_a1_234 = + c10::ivalue::Tuple::create({c10::IValue(3), c10::IValue(4)}); + c10::Dict dict_a1( + c10::StringType::get(), c10::AnyType::get()); + dict_a1.insert(123, tup_a1_123); + dict_a1.insert(234, tup_a1_234); + auto tup_a = + c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(dict_a1)}); + + auto tup_b1_345 = + c10::ivalue::Tuple::create({c10::IValue(6), c10::IValue(7)}); + auto tup_b1_456 = + c10::ivalue::Tuple::create({c10::IValue(8), c10::IValue(9)}); + c10::Dict dict_b1( + c10::StringType::get(), c10::AnyType::get()); + dict_b1.insert(345, tup_b1_345); + dict_b1.insert(456, tup_b1_456); + auto tup_b = + c10::ivalue::Tuple::create({c10::IValue(5), c10::IValue(dict_b1)}); + + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + dict.insert("a", tup_a); + dict.insert("b", tup_b); + std::vector args = {c10::IValue(dict)}; + + for (int usedIdx = 0; usedIdx < 10; usedIdx++) { + std::vector isUsed(10, false); + isUsed[usedIdx] = true; + std::stringstream ss; + for (int i = 0; i < 10; ++i) { + if (isUsed[i]) { + ss << fmt::format("%o1 = aten.foo(a=%a{})\n", i); + } + } + std::string source = fmt::format( + R"(graph(%a0, %a1, %a2, %a3, %a4, %a5, %a6, %a7, %a8, %a9): +{} +return(%o1) +)", + ss.str()); + + auto graph = stringToGraph(source); + std::vector userInputs( + graph->userInputs().begin(), graph->userInputs().end()); + + const auto spec = itreeSpecLoads(jsonSpec, userInputs); + + std::vector visited; + auto fn = [&](const c10::IValue& leaf, const Value* value) { + visited.push_back(value->id()); + }; + ivalueApplyFromArgs(fn, args, {}, spec); + + EXPECT_EQ(visited.size(), 1); + EXPECT_EQ(visited[0], usedIdx); + } +} + +TEST(ITreeTest, UnmatchedFlattenType) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(6); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + auto tup = c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)}); + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + dict.insert("4", c10::IValue(7)); + dict.insert("5", c10::IValue(8)); + dict.insert("6", c10::IValue(9)); + EXPECT_THROW( + { itreeFlatten(c10::IValue{std::move(dict)}, spec); }, std::exception); +} + +TEST(ITreeTest, UnmatchedDictFlatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(6); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + auto tup = c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)}); + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + dict.insert("4", c10::IValue(7)); + dict.insert("5", c10::IValue(8)); + dict.insert("100", c10::IValue(8)); + dict.insert("101", c10::IValue(8)); + c10::List list(c10::AnyType::get()); + list.push_back(std::move(tup)); + list.push_back(c10::IValue(2)); + list.push_back(std::move(dict)); + ASSERT_DEATH( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); +} + +TEST(ITreeTest, DictFlattenTest) { + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(3); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + // allow dict.size < context + // test dict.size=2 , context,size=3, + dict.insert("4", c10::IValue(7)); + dict.insert("5", c10::IValue(8)); + c10::List list(c10::AnyType::get()); + list.push_back(std::move(dict)); + itreeFlatten(c10::IValue{std::move(list)}, spec); +} + +TEST(ITreeTest, UnmatchedTupleFlatten) { + // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}] + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\", \"6\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(6); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + auto tup = c10::ivalue::Tuple::create({c10::IValue(0)}); + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + dict.insert("4", c10::IValue(7)); + dict.insert("5", c10::IValue(8)); + dict.insert("6", c10::IValue(8)); + c10::List list(c10::AnyType::get()); + list.push_back(std::move(tup)); + list.push_back(c10::IValue(2)); + list.push_back(std::move(dict)); + ASSERT_DEATH( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); +} + +TEST(ITreeTest, ToAtenType) { + // Original data: ((0, 1), 2, {"4": 7, "5": 8}, [10], {6: 9}) + auto jsonSpec = R"( +[ + 1, + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": "builtins.tuple", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": "builtins.dict", + "context": "[\"4\", \"5\"]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + }, + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "builtins.list", + "context": "null", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + }, + { + "type": "builtins.dict", + "context": "[6]", + "children_spec": [ + { + "type": null, + "context": null, + "children_spec": [] + } + ] + } + ] + } +] + )"; + auto [graph, valuePtrs] = makeValues(7); + const auto spec = itreeSpecLoads(jsonSpec, valuePtrs); + auto atenType = spec.toAtenType(); + + // Root level is tuple. + EXPECT_EQ(atenType->kind(), c10::TypeKind::TupleType); + const c10::TupleType& rootType = atenType->expectRef(); + EXPECT_EQ(rootType.elements().size(), 5); + + at::TypePtr elementType = rootType.elements()[0]; + EXPECT_EQ(elementType->kind(), c10::TypeKind::TupleType); + EXPECT_EQ( + elementType->expectRef().elements()[0]->kind(), + c10::TypeKind::AnyType); + EXPECT_EQ( + elementType->expectRef().elements()[1]->kind(), + c10::TypeKind::AnyType); + + elementType = rootType.elements()[1]; + EXPECT_EQ(elementType->kind(), c10::TypeKind::AnyType); + + elementType = rootType.elements()[2]; + EXPECT_EQ(elementType->kind(), c10::TypeKind::DictType); + EXPECT_EQ( + elementType->expectRef().getKeyType()->kind(), + c10::TypeKind::StringType); + EXPECT_EQ( + elementType->expectRef().getValueType()->kind(), + c10::TypeKind::AnyType); + + elementType = rootType.elements()[3]; + EXPECT_EQ(elementType->kind(), c10::TypeKind::ListType); + EXPECT_EQ( + elementType->expectRef().getElementType()->kind(), + c10::TypeKind::AnyType); + + elementType = rootType.elements()[4]; + EXPECT_EQ(elementType->kind(), c10::TypeKind::DictType); + EXPECT_EQ( + elementType->expectRef().getKeyType()->kind(), + c10::TypeKind::IntType); + EXPECT_EQ( + elementType->expectRef().getValueType()->kind(), + c10::TypeKind::AnyType); +} + +} // namespace torch::nativert::detail diff --git a/test/cpp/nativert/test_layout_planner_algorithm.cpp b/test/cpp/nativert/test_layout_planner_algorithm.cpp new file mode 100644 index 00000000000000..0d4f8fb0d27378 --- /dev/null +++ b/test/cpp/nativert/test_layout_planner_algorithm.cpp @@ -0,0 +1,86 @@ +#include +#include + +#include +#include +#include + +using namespace ::testing; +using namespace torch::nativert; + +std::vector create_test_allocation_specs() { + std::vector specs; + + const std::vector> test_cases = { + {0, 1, 32}, + {1, 4, 28}, + {2, 5, 36}, + {3, 5, 16}, + {4, 5, 8}, + {5, 7, 64}, + {6, 8, 10}, + {7, 8, 40}, + }; + + specs.reserve(test_cases.size()); + for (const auto& [l_start, l_end, size] : test_cases) { + specs.push_back(AllocationSpec{AllocationLifetime(l_start, l_end), size}); + }; + + return specs; +} + +// figure 6 -- https://arxiv.org/pdf/2001.03288 +TEST(LayoutPlannerAlgorithmTests, TestGreedyBySize) { + auto result = GreedyBySizeAllocationPlanner(create_test_allocation_specs()); + + EXPECT_EQ(result.total_size, 124); + + auto& allocations = result.allocations; + + EXPECT_EQ(allocations[0].offset, 0); + EXPECT_EQ(allocations[1].offset, 32); + EXPECT_EQ(allocations[2].offset, 64); + EXPECT_EQ(allocations[3].offset, 100); + EXPECT_EQ(allocations[4].offset, 116); + EXPECT_EQ(allocations[5].offset, 0); + EXPECT_EQ(allocations[6].offset, 104); + EXPECT_EQ(allocations[7].offset, 64); +} + +TEST(LayoutPlannerAlgorithmTests, TestBump) { + auto specs = create_test_allocation_specs(); + auto result = BumpAllocationPlanner(create_test_allocation_specs()); + + auto& allocations = result.allocations; + + size_t offset = 0; + for (auto&& [i, spec] : c10::enumerate(specs)) { + EXPECT_EQ(allocations[i].offset, offset); + offset += spec.size; + } + + EXPECT_EQ(result.total_size, offset); +} + +TEST(LayoutPlannerAlgorithmTests, TestStorageGroup) { + auto specs = create_test_allocation_specs(); + auto result = DisjointStorageGroupsPlanner(create_test_allocation_specs()); + + auto& allocations = result.allocations; + + EXPECT_EQ(allocations[0].offset, 0); + EXPECT_EQ(allocations[1].offset, 36); + EXPECT_EQ(allocations[2].offset, 0); + EXPECT_EQ(allocations[3].offset, 100); + EXPECT_EQ(allocations[4].offset, 140); + EXPECT_EQ(allocations[5].offset, 36); + EXPECT_EQ(allocations[6].offset, 140); + EXPECT_EQ(allocations[7].offset, 100); + + for (auto&& [i, spec] : c10::enumerate(specs)) { + EXPECT_EQ(allocations[i].size, spec.size); + } + + EXPECT_EQ(result.total_size, 150); +} diff --git a/test/cpp/nativert/test_op_kernel.cpp b/test/cpp/nativert/test_op_kernel.cpp new file mode 100644 index 00000000000000..312e355f9ca21a --- /dev/null +++ b/test/cpp/nativert/test_op_kernel.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include + +namespace torch::nativert { + +int64_t increment_kernel(const at::Tensor& tensor, int64_t input) { + return input + 1; +} + +TEST(OpKernelTest, GetOperatorForTargetValid) { + auto registrar = c10::RegisterOperators().op( + "test::foo(Tensor dummy, int input) -> int", &increment_kernel); + std::string target = "test.foo.default"; + EXPECT_NO_THROW({ + c10::OperatorHandle handle = getOperatorForTarget(target); + EXPECT_TRUE(handle.hasSchema()); + EXPECT_EQ(handle.operator_name().name, "test::foo"); + EXPECT_EQ(handle.operator_name().overload_name, ""); + }); +} + +TEST(OpKernelTest, GetOperatorForTargetInvalid) { + std::string target = "invalid.target"; + EXPECT_THROW(getOperatorForTarget(target), c10::Error); +} + +TEST(OpKernelTest, GetReadableArgs) { + c10::FunctionSchema schema = c10::FunctionSchema( + "test_op", + "", + {c10::Argument("tensor_arg"), + c10::Argument("tensor_list_arg"), + c10::Argument("int_arg"), + c10::Argument("none_arg")}, + {}); + std::vector stack = { + at::tensor({1, 2, 3}), + c10::IValue( + std::vector{at::tensor({1, 2}), at::tensor({3, 4})}), + c10::IValue(1), + c10::IValue(), + }; + std::string expected = + "arg0 tensor_arg: Tensor int[3]cpu\n" + "arg1 tensor_list_arg: GenericList [int[2]cpu, int[2]cpu, ]\n" + "arg2 int_arg: Int 1\n" + "arg3 none_arg: None \n"; + + std::string result = readableArgs(schema, stack); + EXPECT_EQ(result, expected); +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_serialization.cpp b/test/cpp/nativert/test_serialization.cpp new file mode 100644 index 00000000000000..3504f02b53a9ad --- /dev/null +++ b/test/cpp/nativert/test_serialization.cpp @@ -0,0 +1,51 @@ +#include +#include + +namespace torch::nativert { +TEST(SerializationTest, CheckIsSymbolic) { + torch::_export::TensorArgument tensor_arg; + torch::_export::Argument as_tensor_arg; + as_tensor_arg.set_as_tensor(tensor_arg); + EXPECT_TRUE(isSymbolic(as_tensor_arg)); + + std::vector tensor_args; + torch::_export::Argument as_tensors_arg; + as_tensors_arg.set_as_tensors(tensor_args); + EXPECT_TRUE(isSymbolic(as_tensors_arg)); + + torch::_export::SymIntArgument sym_int_arg; + torch::_export::Argument as_sym_int_arg; + as_sym_int_arg.set_as_sym_int(sym_int_arg); + EXPECT_TRUE(isSymbolic(as_sym_int_arg)); + + torch::_export::Argument as_int_arg; + as_int_arg.set_as_int(static_cast(1)); + EXPECT_FALSE(isSymbolic(as_int_arg)); + + torch::_export::Argument as_bool_arg; + as_bool_arg.set_as_bool(true); + EXPECT_FALSE(isSymbolic(as_bool_arg)); + + torch::_export::Argument as_string_arg; + as_string_arg.set_as_string("test_string"); + EXPECT_FALSE(isSymbolic(as_string_arg)); +} + +TEST(SerializationTest, ConstantToValue) { + torch::_export::Argument as_int_arg; + as_int_arg.set_as_int(static_cast(42)); + auto value = constantToValue(as_int_arg, false); + EXPECT_EQ(value, Constant(static_cast(42))); + + torch::_export::Argument as_bool_arg; + as_bool_arg.set_as_bool(true); + value = constantToValue(as_bool_arg, false); + EXPECT_EQ(value, Constant(true)); + + torch::_export::Argument as_string_arg; + as_string_arg.set_as_string("test_string"); + value = constantToValue(as_string_arg, false); + EXPECT_EQ(value, Constant("test_string")); +} + +} // namespace torch::nativert diff --git a/test/cpp/nativert/test_weights.cpp b/test/cpp/nativert/test_weights.cpp new file mode 100644 index 00000000000000..43d05d5ad887ab --- /dev/null +++ b/test/cpp/nativert/test_weights.cpp @@ -0,0 +1,92 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nativert { +class WeightsTest : public ::testing::Test { + protected: + void SetUp() override { + static constexpr std::string_view source = + R"(graph(%foo, %bar, %baz): +%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) +return(%o2, %baz) +)"; + graph = stringToGraph(source); + placement = std::make_unique(c10::Device(c10::DeviceType::CPU)); + } + std::shared_ptr graph; + std::unique_ptr placement; +}; +TEST_F(WeightsTest, ConstructEmptyStateDict) { + std::unordered_map stateDict; + Weights weights(graph.get(), stateDict, *placement); + // Check that weights are initialized correctly + EXPECT_TRUE(weights.parameters().empty()); + EXPECT_TRUE(weights.buffers().empty()); + EXPECT_FALSE(weights.contains("non_existent_weight")); +} +TEST_F(WeightsTest, SetAndGetValue) { + std::unordered_map stateDict; + Weights weights(graph.get(), stateDict, *placement); + at::Tensor tensor = at::ones({2, 2}); + weights.setValue("added_weight", tensor); + EXPECT_TRUE(weights.contains("added_weight")); + EXPECT_EQ(weights.at("added_weight").sizes(), tensor.sizes()); +} + +} // namespace torch::nativert + +using namespace ::testing; +struct ContainsTensorDict : torch::CustomClassHolder { + explicit ContainsTensorDict(at::Tensor t) : t_(t) {} + + explicit ContainsTensorDict(c10::Dict dict) { + t_ = dict.at(std::string("init_tensor")); + } + + c10::Dict serialize() const { + c10::Dict dict; + dict.insert(std::string("init_tensor"), t_); + return dict; + } + + at::Tensor t_; +}; + +static auto reg = + torch::class_("testing", "ContainsTensorDict") + .def(torch::init()) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + +TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) { + // Save + auto customObj = + c10::make_intrusive(torch::tensor({1, 2, 3})); + const auto bytes = torch::jit::pickle_save(c10::IValue(std::move(customObj))); + + // Load + const auto loadedCustomObj = + torch::jit::pickle_load_obj(std::string{bytes.begin(), bytes.end()}); + EXPECT_TRUE(loadedCustomObj.isObject()); + EXPECT_EQ( + loadedCustomObj.to>() + ->t_[0] + .item(), + 1); +} diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md index 055d2201b009d9..f86a50a65e8047 100644 --- a/test/cpp/tensorexpr/README.md +++ b/test/cpp/tensorexpr/README.md @@ -40,8 +40,8 @@ We glob all the test files together in `CMakeLists.txt` so that you don't have to edit it every time you add a test. Unfortunately, this means that in order to get the build to pick up your new test file, you need to re-run cmake: -``` -python setup.py build --cmake +```bash +CMAKE_FRESH=1 python setup.py build ``` ## How do I run the tests? diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index a7df88b8ab990e..2605842d6e74de 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -202,9 +202,7 @@ TEST(BoundsInference, _5) { Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getLoopStmtsFor(b); LoopNest::splitWithTail(loops[0], 16, &inner, &tail); @@ -680,7 +678,6 @@ TEST(BoundsInference, GetPotentialHazardsLoopSplit) { }); LoopNest l({A}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner, tail; // Splitting with tail by something offset creates a tail which also writes to diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index ab387458fb121e..a8bda8814dbaeb 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -164,9 +164,7 @@ TEST(LoopNest, ExprSliceHeadWithLoopOptions) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); @@ -187,16 +185,12 @@ TEST(LoopNest, ExprSliceTailWithLoopOptions) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail_head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail_tail; tail->set_gpu_block_index(LoopOptions::IDX_Y); LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); @@ -219,9 +213,7 @@ TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 10, &head, &tail); @@ -239,9 +231,7 @@ TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 100, &head, &tail); @@ -259,9 +249,7 @@ TEST(LoopNest, ExprSliceHead) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 4, &head, &tail); @@ -283,9 +271,7 @@ TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { LoopNest l({tensor}); std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; LoopNest::sliceTail(loops[0], 4, &head, &tail); // head: [0, 6) @@ -307,9 +293,7 @@ TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 10, &head, &tail); @@ -329,9 +313,7 @@ TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 100, &head, &tail); @@ -349,9 +331,7 @@ TEST(LoopNest, ExprSliceTail) { }; Tensor tensor = Compute("f", {10}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); @@ -375,9 +355,7 @@ TEST(LoopNest, ExprSplitAndSlice) { Tensor tensor = Compute("f", {100}, func); LoopNest l({tensor}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // outer: [0, 4) @@ -428,9 +406,7 @@ TEST(LoopNest, ExprSliceAndNormalize) { LoopNest l({tensor}); std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; LoopNest::sliceHead(loops[0], 2, &head, &tail); // head: [0, 2) @@ -460,9 +436,7 @@ TEST(LoopNest, ExprSliceWithVariableDimension) { std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; LoopNest::sliceHead(loops[0], 2, &head, &tail); @@ -850,7 +824,6 @@ TEST(LoopNest, SplitWithTailWithLoopOptions) { Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner, tail; LoopNest l({tensor}); @@ -880,7 +853,6 @@ TEST(LoopNest, SplitWithMaskWithLoopOptions) { Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; LoopNest l({tensor}); @@ -1433,7 +1405,6 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) { Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); Tensor b = Compute( "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr i_inner; LoopNest l({b}, {a, b}); @@ -3410,9 +3381,7 @@ TEST(LoopNest, NormalizeAndSplitWithTail) { LoopNest::normalize(for_stmt); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr x_tail; LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); @@ -3454,9 +3423,7 @@ TEST(LoopNest, NotNormalizeAndSplitWithTail) { auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); auto parent_block = Block::make({for_stmt}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr x_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr x_tail; LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); @@ -5349,7 +5316,6 @@ TEST(LoopNest, fuseLoopsSimple) { auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); @@ -5389,7 +5355,6 @@ TEST(LoopNest, fuseLoopsMultiple) { auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); auto par = Block::make({forI, forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); @@ -5446,7 +5411,6 @@ TEST(LoopNest, fuseLoopsNested) { auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); auto forN = For::make(n, 0, 20, Block::make({initB, forK})); auto par = Block::make({forM, forN}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); @@ -5506,7 +5470,6 @@ TEST(LoopNest, fuseLoopsNested2D) { 50, Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); @@ -5547,7 +5510,6 @@ TEST(LoopNest, fuseLoopsNested2DInner) { auto forN = For::make( n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); @@ -5583,7 +5545,6 @@ TEST(LoopNest, fuseLoopsDifferentStopBounds) { auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5604,7 +5565,6 @@ TEST(LoopNest, fuseLoopsDifferentStartBounds) { auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5627,7 +5587,6 @@ TEST(LoopNest, fuseLoopsNotContiguous) { auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, initB, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5654,7 +5613,6 @@ TEST(LoopNest, fuseLoopsWithDifferentParents) { auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forI, initB, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -5676,7 +5634,6 @@ TEST(LoopNest, fuseLoopsWithVariableBounds) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); @@ -5712,7 +5669,6 @@ TEST(LoopNest, fuseLoopsWithExprBounds) { auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); @@ -5749,7 +5705,6 @@ TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); @@ -5784,7 +5739,6 @@ TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); @@ -5830,7 +5784,6 @@ TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { auto forM = For::make(m, 0, 20, forN); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); @@ -5876,7 +5829,6 @@ TEST(LoopNest, fuseLoopsWithReductions) { auto forM = For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); @@ -5932,7 +5884,6 @@ TEST(LoopNest, fuseLoopsWith2DReductions) { auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); @@ -5980,7 +5931,6 @@ TEST(LoopNest, fuseLoopsWithComplexIndices) { auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); @@ -6025,7 +5975,6 @@ TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6054,7 +6003,6 @@ TEST(LoopNest, fuseLoopsWithTranspose) { auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6075,7 +6023,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies1) { For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6096,7 +6043,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies2) { For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6139,7 +6085,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies3) { auto forN = For::make(n, 0, 20, Block::make({initB, forK})); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forM, forN}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); } @@ -6181,7 +6126,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies4) { Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forI, forM}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); } @@ -6209,7 +6153,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies5) { Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); } @@ -6235,7 +6178,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies6) { b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forJ, forK}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); } @@ -6261,7 +6203,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies7) { auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) auto par = Block::make({forK, forJ}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr fused_loop; ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); } diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index ddb63431fe3f6e..bdc744ae4e0339 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1066,7 +1066,6 @@ TEST(Reductions, ReduceOverSplitRfactor) { Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr i, t; LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); LoopNest::reorderAxis(loops[0], i); @@ -1573,7 +1572,6 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; // Split outer reduction axis. @@ -1623,7 +1621,6 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { LoopNest l({e}, {c, d, e}); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; // reorder outer reduction axes. @@ -1678,7 +1675,6 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) BufPtr rfac_buf; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); loop.distributeLoop(loops.at(0)); @@ -1744,7 +1740,6 @@ TEST(Reductions, ReductionRfactorCacheTempInner) { LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) BufPtr rfac_buf; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); loop.distributeLoop(loops.at(0)); diff --git a/test/cpp_api_parity/functional_impl_check.py b/test/cpp_api_parity/functional_impl_check.py index b4272a2df1bd8e..34b9ac1581272f 100644 --- a/test/cpp_api_parity/functional_impl_check.py +++ b/test/cpp_api_parity/functional_impl_check.py @@ -158,7 +158,8 @@ def camel_case_to_snake_case(camel_case_str): return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "") else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name): ) else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -217,7 +219,8 @@ def write_test_to_test_class( or "cpp_function_call" in test_params_dict ), ( "To enable C++ API parity test, " - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}. \n" "If you are interested in adding the C++ API parity test, please see:\n" "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n" "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this." @@ -233,14 +236,16 @@ def write_test_to_test_class( functional_name = compute_functional_name(test_params_dict) - assert hasattr( - torch.nn.functional, functional_name - ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)" # noqa: B950 + assert hasattr(torch.nn.functional, functional_name), ( + f"`torch.nn.functional` doesn't have function `{functional_name}`. " + f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" + ) functional_full_name = "F::" + functional_name assert functional_full_name in parity_table["torch::nn::functional"], ( - f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. " + f"Please add `{functional_full_name}` entry to `torch::nn::functional` " + "section of `test/cpp_api_parity/parity-tracker.md`. " f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" ) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 9ba24df283777b..554203752479bc 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,11 +1,9 @@ #include -#include #include +#include #include -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; - void inline sgd_math( float* param_ptr, float* grad_ptr, @@ -26,17 +24,14 @@ void inline sgd_math( } } +using torch::stable::Tensor; -RAIIATH sgd_out_of_place( - const RAIIATH param, - const RAIIATH grad, +Tensor sgd_out_of_place( + const Tensor param, + const Tensor grad, const float weight_decay, const double lr, const bool maximize) { - - int64_t param_dim; - aoti_torch_get_dim(param.get(), ¶m_dim); - int64_t *param_sizes; int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); @@ -46,56 +41,34 @@ RAIIATH sgd_out_of_place( aoti_torch_get_dtype(param.get(), ¶m_dtype); int32_t param_device_type; - int32_t param_device_index; aoti_torch_get_device_type(param.get(), ¶m_device_type); - aoti_torch_get_device_index(param.get(), ¶m_device_index); - - AtenTensorHandle out; - aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out); - - void* param_ptr; - aoti_torch_get_data_ptr(param.get(), ¶m_ptr); - void* grad_ptr; - aoti_torch_get_data_ptr(grad.get(), &grad_ptr); - void* out_ptr; - aoti_torch_get_data_ptr(out, &out_ptr); - - auto param_fp_ptr = reinterpret_cast(param_ptr); - auto grad_fp_ptr = reinterpret_cast(grad_ptr); - auto out_fp_ptr = reinterpret_cast(out_ptr); - int64_t param_numel; - aoti_torch_get_numel(param.get(), ¶m_numel); + AtenTensorHandle out_ath; + aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); + auto out = Tensor(out_ath); sgd_math( - param_fp_ptr, - grad_fp_ptr, - out_fp_ptr, + reinterpret_cast(param.data_ptr()), + reinterpret_cast(grad.data_ptr()), + reinterpret_cast(out.data_ptr()), weight_decay, lr, maximize, - param_numel + param.numel() ); - return RAIIATH(out); + return out; } - void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH param(to(stack[0])); - RAIIATH grad(to(stack[1])); - auto weight_decay = to(stack[2]); - auto lr = to(stack[3]); - auto maximize = to(stack[4]); - - RAIIATH raiiath_res = sgd_out_of_place( - std::move(param), - std::move(grad), - float(weight_decay), - lr, - maximize); - - stack[0] = from(raiiath_res.release()); + Tensor res = sgd_out_of_place( + to(stack[0]), + to(stack[1]), + float(to(stack[2])), + to(stack[3]), + to(stack[4])); + + stack[0] = from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { @@ -106,14 +79,13 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { m.impl("sgd_out_of_place", &boxed_sgd_out_of_place); } -RAIIATH identity(RAIIATH t) { - return std::move(t); +Tensor identity(Tensor t) { + return t; } void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t(to(stack[0])); - RAIIATH raiiath_res = identity(std::move(t)); - stack[0] = from(raiiath_res.release()); + Tensor res = identity(to(stack[0])); + stack[0] = from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -128,18 +100,17 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { m.impl("identity", &boxed_identity); } -RAIIATH my_abs(RAIIATH t) { +Tensor my_abs(Tensor t) { const auto num_args = 1; StableIValue stack[num_args]; - stack[0] = from(t.release()); + stack[0] = from(t); aoti_torch_call_dispatcher("aten::abs", "", stack); - return RAIIATH(to(stack[0])); + return to(stack[0]); } void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t(to(stack[0])); - RAIIATH raiiath_res = my_abs(std::move(t)); - stack[0] = from(raiiath_res.release()); + Tensor tensor_res = my_abs(to(stack[0])); + stack[0] = from(tensor_res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -150,7 +121,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_abs", &boxed_my_abs); } -RAIIATH my_ones_like(RAIIATH t, StableIValue device) { +Tensor my_ones_like(Tensor t, StableIValue device) { const auto num_args = 6; StableIValue stack[num_args]; @@ -158,7 +129,7 @@ RAIIATH my_ones_like(RAIIATH t, StableIValue device) { aoti_torch_get_dtype(t.get(), &t_dtype); auto mf = aoti_torch_memory_format_contiguous_format(); - stack[0] = from(t.release()); + stack[0] = from(t); stack[1] = from(std::optional(t_dtype)); // dtype stack[2] = from(std::nullopt); // layout stack[3] = from(std::optional(device)); // device @@ -167,15 +138,12 @@ RAIIATH my_ones_like(RAIIATH t, StableIValue device) { aoti_torch_call_dispatcher("aten::ones_like", "", stack); - return RAIIATH(to(stack[0])); + return to(stack[0]); } void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t(to(stack[0])); - StableIValue device = stack[1]; - - RAIIATH raiiath_res = my_ones_like(std::move(t), device); - stack[0] = from(raiiath_res.release()); + Tensor res = my_ones_like(to(stack[0]), stack[1]); + stack[0] = from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -186,32 +154,29 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_ones_like", &boxed_my_ones_like); } -std::tuple exp_neg_is_leaf(RAIIATH t1, RAIIATH t2, RAIIATH t3) { - StableIValue stack1[1]; - stack1[0] = from(t1.release()); - aoti_torch_call_dispatcher("aten::exp", "", stack1); +std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { + StableIValue stack_exp[1]; + stack_exp[0] = from(t1); + aoti_torch_call_dispatcher("aten::exp", "", stack_exp); - StableIValue stack2[1]; - stack2[0] = from(t2.release()); - aoti_torch_call_dispatcher("aten::neg", "", stack2); + StableIValue stack_neg[1]; + stack_neg[0] = from(t2); + aoti_torch_call_dispatcher("aten::neg", "", stack_neg); - StableIValue stack3[1]; - stack3[0] = from(t3.release()); - aoti_torch_call_dispatcher("aten::is_leaf", "", stack3); + StableIValue stack_is_leaf[1]; + stack_is_leaf[0] = from(t3); + aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - RAIIATH(to(stack1[0])), - RAIIATH(to(stack2[0])), - to(stack3[0])); + to(stack_exp[0]), + to(stack_neg[0]), + to(stack_is_leaf[0])); } void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - auto tuple = exp_neg_is_leaf(std::move(t1), std::move(t2), std::move(t3)); - stack[0] = from(std::get<0>(tuple).release()); - stack[1] = from(std::get<1>(tuple).release()); + auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); + stack[0] = from(std::get<0>(tuple)); + stack[1] = from(std::get<1>(tuple)); stack[2] = from(std::get<2>(tuple)); } @@ -222,3 +187,70 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("exp_neg_is_leaf", &boxed_exp_neg_is_leaf); } + +Tensor neg_exp(Tensor t) { + StableIValue stack[1]; + stack[0] = from(t); + aoti_torch_call_dispatcher("aten::exp", "", stack); + aoti_torch_call_dispatcher("aten::neg", "", stack); + return to(stack[0]); +} + +void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor res = neg_exp(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("neg_exp(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("neg_exp", &boxed_neg_exp); +} + +Tensor divide_neg_exp(Tensor t) { + StableIValue stack_neg[1]; + stack_neg[0] = from(t); + + StableIValue stack_exp[1]; + stack_exp[0] = from(t); + aoti_torch_call_dispatcher("aten::exp", "", stack_exp); + aoti_torch_call_dispatcher("aten::neg", "", stack_neg); + + StableIValue stack_div[2]; + stack_div[0] = stack_neg[0]; + stack_div[1] = stack_exp[0]; + aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div); + return to(stack_div[0]); +} + +void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor res = divide_neg_exp(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("divide_neg_exp(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("divide_neg_exp", &boxed_divide_neg_exp); +} + +bool is_contiguous(Tensor t) { + return t.is_contiguous(); +} + +void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + bool res = is_contiguous(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("is_contiguous(Tensor t) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("is_contiguous", &boxed_is_contiguous); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 2b6e5e0e1dd56b..2b4fbd40eb1a21 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -80,3 +80,39 @@ def exp_neg_is_leaf(t1, t2, t3) -> tuple[Tensor, Tensor, bool]: (exp(t1), neg(t2), is_leaf(t3)) """ return torch.ops.libtorch_agnostic.exp_neg_is_leaf.default(t1, t2, t3) + + +def neg_exp(t) -> Tensor: + """ + Returns a Tensor composing neg of exp + + Args: + t: Tensor + + Returns: neg(exp(t)) + """ + return torch.ops.libtorch_agnostic.neg_exp.default(t) + + +def divide_neg_exp(t) -> Tensor: + """ + Returns a Tensor division of neg and exp + + Args: + t: Tensor + + Returns: divide(neg(t), exp(t)) + """ + return torch.ops.libtorch_agnostic.divide_neg_exp.default(t) + + +def is_contiguous(t) -> bool: + """ + Returns a bool indicating if the input tensor is contiguous + + Args: + t: Tensor + + Returns: is_contiguous(t) + """ + return torch.ops.libtorch_agnostic.is_contiguous.default(t) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 5324f5c67df89a..ba1d6411b0984a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -24,7 +24,10 @@ class TestLibtorchAgnostic(TestCase): @classmethod def setUpClass(cls): - install_cpp_extension(extension_root=Path(__file__).parent.parent) + try: + import libtorch_agnostic # noqa: F401 + except Exception: + install_cpp_extension(extension_root=Path(__file__).parent.parent) @onlyCPU def test_slow_sgd(self, device): @@ -86,8 +89,8 @@ def test_my_abs(self, device): import libtorch_agnostic t = torch.rand(32, 16, device=device) - 0.5 - cpu_t = libtorch_agnostic.ops.my_abs(t) - self.assertEqual(cpu_t, torch.abs(t)) + res = libtorch_agnostic.ops.my_abs(t) + self.assertEqual(res, torch.abs(t)) def _make_cuda_tensors(prior_mem): cuda_t = libtorch_agnostic.ops.my_abs(t) @@ -101,6 +104,51 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) + def test_neg_exp(self, device): + import libtorch_agnostic + + t = torch.rand(32, 16, device=device) - 0.5 + res = libtorch_agnostic.ops.neg_exp(t) + self.assertEqual(res, torch.neg(torch.exp(t))) + + def _make_cuda_tensors(prior_mem): + cuda_res = libtorch_agnostic.ops.neg_exp(t) + self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) + self.assertEqual(cuda_res, torch.neg(torch.exp(t))) + + if t.is_cuda: + init_mem = torch.cuda.memory_allocated(device) + for _ in range(3): + _make_cuda_tensors(init_mem) + curr_mem = torch.cuda.memory_allocated(device) + self.assertEqual(curr_mem, init_mem) + + def test_divide_neg_exp(self, device): + import libtorch_agnostic + + t = torch.zeros(2, 3, device=device) - 0.5 + res = libtorch_agnostic.ops.divide_neg_exp(t) + self.assertEqual(res, torch.neg(t) / torch.exp(t)) + + def _make_cuda_tensors(prior_mem): + cuda_res = libtorch_agnostic.ops.divide_neg_exp(t) + self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) + self.assertEqual(cuda_res, torch.neg(t) / torch.exp(t)) + + if t.is_cuda: + init_mem = torch.cuda.memory_allocated(device) + for _ in range(3): + _make_cuda_tensors(init_mem) + curr_mem = torch.cuda.memory_allocated(device) + self.assertEqual(curr_mem, init_mem) + + def test_is_contiguous(self, device): + import libtorch_agnostic + + t = torch.rand(2, 7, device=device) + self.assertTrue(libtorch_agnostic.ops.is_contiguous(t)) + self.assertFalse(libtorch_agnostic.ops.is_contiguous(t.transpose(0, 1))) + # TODO: Debug this: # torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: # call_function libtorch_agnostic.my_ones_like.default(*(FakeTensor(..., size=(3, 1)), 'cpu'), diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index 47b8eed2fb97b4..fbd53b96234b22 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -16,186 +16,18 @@ #include #include #include -#include #include #include -#include #include #include static uint64_t add_counter = 0; static uint64_t last_saved_value = 0; -static c10::DeviceIndex custom_device_index = 0; - -static uint64_t abs_counter = 0; -static uint64_t last_abs_saved_value = 0; static uint64_t storageImpl_counter = 0; static uint64_t last_storageImpl_saved_value = 0; -namespace { - -// Using the simplest way to obtain continuous Tensor data and process it. -// This is a demo for using operand API, and you can add more complex logic -// for input and output tensor based on your custom device kernel. -void abs_kernel(at::TensorIteratorBase& iter) { - // Abs only have a input tensor and a output tensor. - auto& output_operand = iter.operand(0); - auto& input_operand = iter.operand(1); - auto& output_tensor_base = output_operand.tensor_base(); - auto& input_tensor_base = input_operand.tensor_base(); - TORCH_CHECK(!input_operand.original_tensor_base().defined(), - "input original tensor is defined."); - TORCH_CHECK(!output_operand.original_tensor_base().defined(), - "output original tensor is defined."); - // For easy test, only accept contiguous input tensor for calculate. - auto memory_format = input_tensor_base.suggest_memory_format(); - TORCH_CHECK(input_tensor_base.is_contiguous(memory_format), - "Input tensor need be contiguous."); - // Add necessary restrictions to ensure the security of the demo. - TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(), - "Intput and output tensor size are not equal."); - // Common dtype is calculate in TensorIteratorBase. - TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float, - "Only support float type.") - // Using for loop for abs calculate. - auto abs_function = [](float* output_ptr, const float* input_ptr, - const int64_t NUM) { - for (int64_t i = 0; i < NUM; ++i) { - *(output_ptr + i) = std::abs(*(input_ptr + i)); - } - }; - // To simplify the logic of the test demo code, - // we only use contiguous tensor to calculate on device side. - // And using input tensor memory format. - if (iter.is_contiguous()) { - // Add for will_resize flag check. You can convert to differernt - // tensor memory format when will_resize is True. - // If TensorIteratorConfig resize_outputs_ flag is true, and there are two - // situations: - // 1) Out tensor is undefined, and TensorIterator set will_resize to true; - // 2) Out tensor is defined and tensor size is not equal to input tensor size; - // TensorIterator set will_resize to true, and call set_output_raw_strided - // to resize output tensor. - // When output operand will_resize flag is ture, dummy - // device can convert tensor to dummy device preferred memory format. - // Here we don't convert tensor memory format, because it will become complex - // when dummy device want keep same memory format for training network. - TORCH_CHECK(output_operand.will_resize, - "output operand will_resize flag need be True."); - abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); - } else { - // Stride copy is not support for foo device, using cpu device instead. - // For abs op, the last situation is: output tensor is not contiguous with - // operand will_resize is False. - TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True."); - // Get a contiguous tensor with input memory format. - at::Tensor output = at::empty(output_tensor_base.sizes(), - input_tensor_base.options() - .memory_format(memory_format)); - // For structured op which inheried from TensorIteratorBase, maybe you need to - // call set_output_raw_strided function to update output stored in op sturctured. - // abs op is no need to do this. - output_operand.exchange_tensor(c10::MaybeOwned::owned(std::in_place, output)); - abs_function((float*)output_operand.tensor_base().mutable_data_ptr(), - (float*)iter.data_ptr(1), iter.numel()); - // Copy tensor base to original tensor base, and keep same scalar type and - // stride with cpu and gpu. - if (output_operand.original_tensor_base().defined() && - !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) { - output_operand.original_tensor().copy_(output_operand.tensor()); - output_operand.restore_original_tensor(); - } - } -} - -void quantize_tensor_per_tensor_affine_privateuse1( - const at::Tensor& rtensor, - at::Tensor& qtensor, - double scale, - int64_t zero_point) { - // do nothing -} - -} // namespace - -namespace at::native { - -REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel); -REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1); - -} // namespace at::native -struct CustomBackendMetadata : public c10::BackendMeta { - // for testing this field will mutate when clone() is called by shallow_copy_from. - int backend_version_format_{-1}; - int format_number_{-1}; - mutable bool cloned_{false}; - // define the constructor - CustomBackendMetadata(int backend_version_format, int format_number) : - backend_version_format_(backend_version_format), format_number_(format_number) {} - c10::intrusive_ptr clone( - const c10::intrusive_ptr& ptr) const override { - cloned_ = true; - return c10::BackendMeta::clone(ptr); - } -}; - -// we need to register two functions for serialization -void for_serialization(const at::Tensor& t, std::unordered_map& m) { - if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) { - return; - } - auto tmeta = dynamic_cast(t.unsafeGetTensorImpl()->get_backend_meta()); - if (tmeta->backend_version_format_ == 1) { - m["backend_version_format"] = true; - } - if (tmeta->format_number_ == 29) { - m["format_number"] = true; - } -} - -void for_deserialization(const at::Tensor& t, std::unordered_map& m) { - int backend_version_format{-1}; - int format_number{-1}; - if (m.find("backend_version_format") != m.end()) { - backend_version_format = 1; - } - if (m.find("format_number") != m.end()) { - format_number = 29; - } - c10::intrusive_ptr new_tmeta{std::unique_ptr( - new CustomBackendMetadata(backend_version_format, format_number))}; - t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta); -} - -void custom_serialization_registry() { - torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, - &for_serialization, - &for_deserialization); -} - -//check if BackendMeta serialization correctly -bool check_backend_meta(const at::Tensor& t) { - if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) { - CustomBackendMetadata* tmeta = dynamic_cast( - t.unsafeGetTensorImpl()->get_backend_meta()); - if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) { - return true; - } - } - return false; -} - -// a fake set function is exposed to the Python side -void custom_set_backend_meta(const at::Tensor& t) { - int backend_version_format{1}; - int format_number{29}; - c10::intrusive_ptr new_tmeta{std::unique_ptr( - new CustomBackendMetadata(backend_version_format, format_number))}; - t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta); -} - // A dummy storageImpl for our custom device, that secretly uses the CPU c10::intrusive_ptr make_custom_storage_impl(c10::StorageImpl::use_byte_size_t, c10::SymInt size_bytes, @@ -263,7 +95,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("add.Tensor", &custom_add_Tensor); m.impl("_copy_from_and_resize", &custom__copy_from_and_resize); m.impl("set_.source_Storage", &custom_set_source_Storage); - m.impl("quantize_per_tensor", at::native::quantize_per_tensor); } void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -294,16 +125,8 @@ bool custom_add_called() { return called; } -void set_custom_device_index(c10::DeviceIndex device_index) { - custom_device_index = device_index; -} - -const at::Generator& default_generator(c10::DeviceIndex device_index) { - return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));; -} - void fallback_with_undefined_tensor() { - at::Tensor first = at::empty((2,3)).to(at::DeviceType::PrivateUse1); + at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1); at::Tensor second = at::Tensor(); at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1); at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1); @@ -316,36 +139,6 @@ void fallback_with_undefined_tensor() { grad_scale, found_inf); } -struct CustomAutogradFnReturnsSelf : public torch::autograd::Function { - - static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) { - return self; - } - - static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -struct CustomAutogradFnAliasing : public torch::autograd::Function { - - static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) { - return self.view_symint(self.sym_sizes()); - } - - static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { - return CustomAutogradFnReturnsSelf::apply(x); -} - -at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { - return CustomAutogradFnAliasing::apply(x); -} - // Here, we're exposing a custom device object that corresponds to our custom backend. // We do this using pybind: exposing an "extension_name.custom_device()" function in python, // that's implemented in C++. @@ -353,22 +146,7 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_device", &get_custom_device, "get custom device object"); m.def("custom_add_called", &custom_add_called, "check if our custom add function was called"); - m.def("set_custom_device_index", &set_custom_device_index, "set custom device index"); m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method"); m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called"); - m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function"); - m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly"); - m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function"); - m.def("default_generator", &default_generator, "default_generator for privateuse1"); m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1"); - - // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++ - m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self); -} - -TORCH_LIBRARY(_test_funcs, m) { - m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); -} -TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) { - m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing); } diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py index 1e07dd49ff9517..05b8955b6557b8 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py @@ -43,7 +43,7 @@ def is_available(): def current_device(): return torch.accelerator.current_device_index() - def get_rng_state(device): + def get_rng_state(device="openreg"): if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): @@ -54,7 +54,7 @@ def get_rng_state(device): default_generator = pytorch_openreg._C._get_default_generator(idx) return default_generator.get_state() - def set_rng_state(new_state, device): + def set_rng_state(new_state, device="openreg"): if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): @@ -65,9 +65,32 @@ def set_rng_state(new_state, device): default_generator = pytorch_openreg._C._get_default_generator(idx) default_generator.set_state(new_state) + def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = pytorch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + def is_initialized(): return module._initialized + def _is_in_bad_fork(): + return False + def _lazy_init(): if is_initialized(): return @@ -85,6 +108,10 @@ def _lazy_init(): module.current_device = current_device # type: ignore[assignment] module.get_rng_state = get_rng_state # type: ignore[assignment] module.set_rng_state = set_rng_state # type: ignore[assignment] + module._is_in_bad_fork = _is_in_bad_fork # type: ignore[assignment] + module.initial_seed = initial_seed # type: ignore[assignment] + module.manual_seed = manual_seed # type: ignore[assignment] + module.manual_seed_all = manual_seed_all # type: ignore[assignment] return module @@ -92,3 +119,4 @@ def _lazy_init(): # Set all the appropriate state on PyTorch torch.utils.rename_privateuse1_backend("openreg") torch._register_device_module("openreg", _create_module()) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index 8d98387cf5f28c..d4c49bd28d458c 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -78,9 +78,9 @@ def _post_process(): elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: # Only handle inplace ops returning their first arg assert len(args) >= 1, f"Inplace {op} needs at least one arg" - assert ( - len(op._schema.returns) == 1 - ), f"NYI Inplace {op} with more than one return" + assert len(op._schema.returns) == 1, ( + f"NYI Inplace {op} with more than one return" + ) op_name = op.overloadpacket._qualified_op_name real_res = args[0] elif any(r.alias_info is not None for r in op._schema.returns): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py index 8489a7bcd9a8e8..d339869635001b 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -104,11 +104,6 @@ def __init__(self, num_devices): super().__init__() self.num_devices = num_devices self.is_initialized = False - self.rlock = threading.RLock() - - def _lazy_init(self): - if self.is_initialized: - return # State of our driver self.curr_device_idx = 0 @@ -119,6 +114,11 @@ def _lazy_init(self): self.host_allocator = HostAllocator() self.event_belong = {} + self.rlock = threading.RLock() + + def _lazy_init(self): + if self.is_initialized: + return self.devices = [] for i in range(self.num_devices): @@ -136,7 +136,6 @@ def _lazy_init(self): def exec(self, cmd, *args): with self.rlock: - self._lazy_init() log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) if cmd in Driver.registry: @@ -151,6 +150,7 @@ def exec(self, cmd, *args): return res def run_on_executor(self, device_idx, cmd, *args): + self._lazy_init() req_queue, ans_queue, _ = self.devices[device_idx] stream = self.getStream(device_idx) validate_send_queue_args(cmd, args) @@ -161,7 +161,7 @@ def run_on_executor(self, device_idx, cmd, *args): @register(registry) def hasPrimaryContext(self, device_idx): - return device_idx >= 0 and device_idx < len(self.devices) + return device_idx >= 0 and device_idx < self.num_devices @register(registry) def deviceCount(self, *args): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py index 80194b38aaebf0..0f54f2ec4df000 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_IN: raise RuntimeError( - f"Cannot send object of type {type(obj)} " "over openreg device pipe." + f"Cannot send object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, torch.Tensor): @@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_OUT: raise RuntimeError( - f"Received invalid object of type {type(obj)} " - "over openreg device pipe." + f"Received invalid object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, OpenRegTensorMeta): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h index ca6d7903055cf6..a04248f2e50294 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h @@ -38,4 +38,13 @@ static void ReportAndDelete(void* ptr) { PyErr_Restore(type, value, traceback); } +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); + } // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp index d673b763424a5e..4d9bde06011833 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp @@ -1,14 +1,25 @@ #include "OpenReg.h" #include -#include -#include +#include +#include #include +#include +#include #include #include +#include +#include +#include +#include +#include #include +#include +#include +#include + #include namespace openreg { @@ -102,7 +113,15 @@ at::Tensor as_strided_openreg( return at::cpu::as_strided(self, size, stride, storage_offset_); } -at::Tensor& set_openreg( +const at::Tensor& resize__openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor& set_source_Storage_storage_offsetset_openreg( at::Tensor& result, at::Storage storage, int64_t storage_offset, @@ -165,6 +184,80 @@ custom_scaled_dot_product_fused_attention_overrideable_backward( } } +// Using the simplest way to obtain continuous Tensor data and process it. +// This is a demo for using operand API, and you can add more complex logic +// for input and output tensor based on your custom device kernel. +void abs_kernel(at::TensorIteratorBase& iter) { + // Abs only have a input tensor and a output tensor. + auto& output_operand = iter.operand(0); + auto& input_operand = iter.operand(1); + auto& output_tensor_base = output_operand.tensor_base(); + auto& input_tensor_base = input_operand.tensor_base(); + TORCH_CHECK(!input_operand.original_tensor_base().defined(), + "input original tensor is defined."); + TORCH_CHECK(!output_operand.original_tensor_base().defined(), + "output original tensor is defined."); + // For easy test, only accept contiguous input tensor for calculate. + auto memory_format = input_tensor_base.suggest_memory_format(); + TORCH_CHECK(input_tensor_base.is_contiguous(memory_format), + "Input tensor need be contiguous."); + // Add necessary restrictions to ensure the security of the demo. + TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(), + "Intput and output tensor size are not equal."); + // Common dtype is calculate in TensorIteratorBase. + TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float, + "Only support float type.") + // Using for loop for abs calculate. + auto abs_function = [](float* output_ptr, const float* input_ptr, + const int64_t NUM) { + for (int64_t i = 0; i < NUM; ++i) { + *(output_ptr + i) = std::abs(*(input_ptr + i)); + } + }; + // To simplify the logic of the test demo code, + // we only use contiguous tensor to calculate on device side. + // And using input tensor memory format. + if (iter.is_contiguous()) { + // Add for will_resize flag check. You can convert to differernt + // tensor memory format when will_resize is True. + // If TensorIteratorConfig resize_outputs_ flag is true, and there are two + // situations: + // 1) Out tensor is undefined, and TensorIterator set will_resize to true; + // 2) Out tensor is defined and tensor size is not equal to input tensor size; + // TensorIterator set will_resize to true, and call set_output_raw_strided + // to resize output tensor. + // When output operand will_resize flag is ture, dummy + // device can convert tensor to dummy device preferred memory format. + // Here we don't convert tensor memory format, because it will become complex + // when dummy device want keep same memory format for training network. + TORCH_CHECK(output_operand.will_resize, + "output operand will_resize flag need be True."); + abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); + } else { + // Stride copy is not support for foo device, using cpu device instead. + // For abs op, the last situation is: output tensor is not contiguous with + // operand will_resize is False. + TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True."); + // Get a contiguous tensor with input memory format. + at::Tensor output = at::empty(output_tensor_base.sizes(), + input_tensor_base.options() + .memory_format(memory_format)); + // For structured op which inheried from TensorIteratorBase, maybe you need to + // call set_output_raw_strided function to update output stored in op sturctured. + // abs op is no need to do this. + output_operand.exchange_tensor(c10::MaybeOwned::owned(std::in_place, output)); + abs_function((float*)output_operand.tensor_base().mutable_data_ptr(), + (float*)iter.data_ptr(1), iter.numel()); + // Copy tensor base to original tensor base, and keep same scalar type and + // stride with cpu and gpu. + if (output_operand.original_tensor_base().defined() && + !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) { + output_operand.original_tensor().copy_(output_operand.tensor()); + output_operand.restore_original_tensor(); + } + } +} + int64_t _fused_sdp_choice_privateuse1( const at::Tensor& query, const at::Tensor& key, @@ -178,19 +271,148 @@ int64_t _fused_sdp_choice_privateuse1( return static_cast(backend); } +void quantize_tensor_per_tensor_affine_privateuse1( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) { + // Just test the process, so do nothing +} + +struct CustomAutogradFnReturnsSelf + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +/* Notes: + * + * OpenReg is currently designed to simulate device memory through multiple + * subprocesses on purpose to ensure we don't mistakenly poke at the "device's + * memory" from the main process. And be able to simulate the same thing that + * happens with other accelerators: any metadata-only change is cpu-only + * (main process), any data change must go through to the device (other process) + * and any data transfer between the two is expensive (serializing the whole + * Tensor). + * + * Currently, for the efficiency of IPC, most operations are to pass the Tensor + * metadata, and only a small number of operations involving copy will serialize + * and pass the Tensor body by custom pickler provided by torch.multiprocess. + * + * Therefore, in principle, only operations related to Metadata modification can + * be directly implemented at the C++ level and registered in PrivateUse1; but + * if memory access is involved, the relevant operations must be implemented at + * the Python level, otherwise invalid memory access will result. + */ + TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("empty.memory_format", empty_openreg); m.impl("empty_strided", empty_strided_openreg); m.impl("as_strided", as_strided_openreg); - m.impl("set_.source_Storage_storage_offset", set_openreg); + m.impl("resize_", resize__openreg); + m.impl("set_.source_Storage", at::native::set_); + m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg); + m.impl("quantize_per_tensor", at::native::quantize_per_tensor); m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1); m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable); m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward); } + +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) } // namespace openreg namespace at::native { +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &openreg::quantize_tensor_per_tensor_affine_privateuse1); REGISTER_PRIVATEUSE1_DISPATCH( _fused_sdp_choice_stub, &openreg::_fused_sdp_choice_privateuse1); } // namespace at::native + +TORCH_LIBRARY(openreg, m) { + m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} + +TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { + m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing); + m.impl( + "custom_autograd_fn_returns_self", + &openreg::custom_autograd_fn_returns_self); +} diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 2744591307f28b..6067a25883993f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -3,7 +3,10 @@ import copy import functools import itertools +import os +import tempfile from typing import Callable, Optional, Union +from unittest.mock import MagicMock import torch import torch.distributed as dist @@ -17,9 +20,12 @@ MixedPrecisionPolicy, OffloadPolicy, ) +from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather from torch.distributed.fsdp._fully_shard._fsdp_collectives import ( _div_if_needed, _get_gradient_divide_factors, + DefaultAllGather, + DefaultReduceScatter, foreach_all_gather, foreach_all_gather_copy_out, foreach_reduce, @@ -34,7 +40,10 @@ from torch.distributed.tensor import DTensor from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.experimental import implicit_replication -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_multicast_support, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_fsdp import ( check_sharded_parity, DoubleLinear, @@ -157,6 +166,7 @@ def _test_all_gather( all_gather_stream, ): def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup): + all_gather_comm = DefaultAllGather() all_gather_result = foreach_all_gather( fsdp_param_group.fsdp_params, group, @@ -164,6 +174,7 @@ def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup): all_gather_copy_in_stream=all_gather_copy_in_stream, all_gather_stream=all_gather_stream, device=self.device, + all_gather_comm=all_gather_comm, ) foreach_all_gather_copy_out(all_gather_result, fsdp_params, group) # Transition to unsharded state to register unsharded parameters @@ -256,6 +267,7 @@ def _test_reduce_scatter( group = fsdp_param_group.mesh_info.shard_process_group self.assertEqual(group.size(), self.world_size) all_reduce_stream = device_module.Stream() + comm = DefaultReduceScatter() ( _, _, @@ -268,10 +280,11 @@ def _test_reduce_scatter( unsharded_grads, group, reduce_scatter_stream, + comm, orig_dtype=orig_params[0].dtype, reduce_dtype=reduce_scatter_dtype, device=self.device, - reduce_scatter_reduce_op=None, + gradient_divide_factor=None, all_reduce_group=None, all_reduce_stream=all_reduce_stream, all_reduce_hook=None, @@ -283,16 +296,19 @@ def _test_reduce_scatter( ) # Check reduce-scatter correctness - predivide_factor, postdivide_factor = _get_gradient_divide_factors( - group, None, reduce_scatter_dtype - ) + ( + predivide_factor, + postdivide_factor, + _, + all_reduce_op, + ) = _get_gradient_divide_factors(group, None, reduce_scatter_dtype) reduced_grads = [grad.detach().clone() for grad in unsharded_grads] for grad in reduced_grads: _div_if_needed(grad, predivide_factor) dist.all_reduce( grad, group=group, - op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM, + op=all_reduce_op, ) _div_if_needed(grad, postdivide_factor) for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads): @@ -313,13 +329,13 @@ def test_fully_shard_communication_count(self): reduce-scatters during forward and backward. """ self.run_subtests( - {"reshard_after_forward": [True, False, 2]}, + {"reshard_after_forward": [True, False, 2, None]}, self._test_communication_count, ) def _test_communication_count( self, - reshard_after_forward: Union[bool, int], + reshard_after_forward: Union[bool, int, None], ): torch.manual_seed(42) model_args = ModelArgs() @@ -345,12 +361,16 @@ def _test_communication_count( with CommDebugMode() as bwd_comm_mode: loss.sum().backward() bwd_comm_counts = bwd_comm_mode.get_comm_counts() - if reshard_after_forward is False: - self.assertEqual(len(bwd_comm_counts), 1) - else: - # The root always does not reshard after forward + if reshard_after_forward is None: + # 2 means two types of collectives (all-gather, reduce-scatter) self.assertEqual(len(bwd_comm_counts), 2) + # do not reshard root model self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks) + elif reshard_after_forward: + self.assertEqual(len(bwd_comm_counts), 2) + self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks + 1) + else: + self.assertEqual(len(bwd_comm_counts), 1) self.assertEqual( bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_blocks + 1 ) @@ -439,21 +459,28 @@ def test_set_reshard_after_forward(self): comm_count should perform same as test_fully_shard_communication_count. """ self.run_subtests( - {"set_reshard_after_forward": [True, False], "recurse": [True, False]}, + { + "set_reshard_after_forward": [True, False, None], + "recurse": [True, False], + }, self._test_set_reshard_after_forward_by_communication_count, ) def _test_set_reshard_after_forward_by_communication_count( self, - set_reshard_after_forward: bool, + set_reshard_after_forward: Union[bool, None], recurse: bool, ): torch.manual_seed(42) model_args = ModelArgs() model = Transformer(model_args).to(device_type) - fully_shard_fn = functools.partial( - fully_shard, reshard_after_forward=not set_reshard_after_forward - ) + if set_reshard_after_forward is None: + fully_shard_fn = fully_shard + else: + fully_shard_fn = functools.partial( + fully_shard, reshard_after_forward=not set_reshard_after_forward + ) + num_blocks = 0 for module in model.modules(): if isinstance(module, TransformerBlock): @@ -463,9 +490,10 @@ def _test_set_reshard_after_forward_by_communication_count( num_fsdp_modules = sum( isinstance(module, FSDPModule) for module in model.modules() ) - model.set_reshard_after_forward( - reshard_after_forward=set_reshard_after_forward, recurse=recurse - ) + if set_reshard_after_forward is not None: + model.set_reshard_after_forward( + reshard_after_forward=set_reshard_after_forward, recurse=recurse + ) torch.manual_seed(42 + self.rank) inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type) @@ -478,13 +506,23 @@ def _test_set_reshard_after_forward_by_communication_count( with CommDebugMode() as bwd_comm_mode: loss.sum().backward() bwd_comm_counts = bwd_comm_mode.get_comm_counts() - # If recurse is False, set_reshard_after_forward only affects the root module, - # resulting in comm_counts identical to those without set_reshard_after_forward. - if recurse == set_reshard_after_forward: + # If recurse is False, set_reshard_after_forward only affects the root module + if set_reshard_after_forward is None: self.assertEqual(len(bwd_comm_counts), 2) self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks) + elif set_reshard_after_forward: + self.assertEqual(len(bwd_comm_counts), 2) + self.assertEqual( + bwd_comm_counts[c10d_ops._allgather_base_], + num_blocks + 1 if recurse else 1, + ) else: - self.assertEqual(len(bwd_comm_counts), 1) + if recurse: + self.assertEqual(len(bwd_comm_counts), 1) + else: + self.assertEqual(len(bwd_comm_counts), 2) + self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_blocks) + self.assertEqual( bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_blocks + 1 ) @@ -500,14 +538,14 @@ def test_fully_shard_backward_prefetch(self): # Activation checkpointing should not affect the expected FSDP events self.run_subtests( { - "reshard_after_forward": [True, False, 2], + "reshard_after_forward": [True, False, 2, None], "checkpoint_impl": [None, "utils", "composable"], }, self._test_backward_prefetch_forward_backward, ) self.run_subtests( { - "reshard_after_forward": [True, False, 2], + "reshard_after_forward": [True, False, 2, None], "checkpoint_impl": [None, "utils", "composable"], }, self._test_backward_prefetch_multi_forward, @@ -515,7 +553,9 @@ def test_fully_shard_backward_prefetch(self): self._test_backward_prefetch_unused_in_backward(True) def _test_backward_prefetch_forward_backward( - self, reshard_after_forward: Union[bool, int], checkpoint_impl: Optional[str] + self, + reshard_after_forward: Union[bool, int, None], + checkpoint_impl: Optional[str], ): n_layers = 3 model, optim, inp = self._init_transformer( @@ -529,8 +569,9 @@ def _test_backward_prefetch_forward_backward( FSDPParamGroup.post_backward, events ) # Check the order for normal 1 forward, 1 backward, 1 optimizer step - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for iter_idx in range(3): loss = model(inp) @@ -543,20 +584,25 @@ def _test_backward_prefetch_forward_backward( self.assertEqual(events, expected_events) events.clear() loss.sum().backward() - expected_events = [ - # Root does not reshard after forward so there is no - # unshard event for it in backward - ("unshard", "layers.2", TrainingState.PRE_BACKWARD), - # Explicit backward prefetching moves the unshards early - # by one module (note how swapping each unshard down one - # event would give the natural event order) - ("unshard", "layers.1", TrainingState.PRE_BACKWARD), - ("post_backward", "layers.2", TrainingState.POST_BACKWARD), - ("unshard", "layers.0", TrainingState.PRE_BACKWARD), - ("post_backward", "layers.1", TrainingState.POST_BACKWARD), - ("post_backward", "layers.0", TrainingState.POST_BACKWARD), - ("post_backward", "", TrainingState.POST_BACKWARD), - ] + expected_events = [] + # Root does not reshard after forward so there is no + # unshard event for it in backward + if reshard_after_forward is not None: + expected_events.append(("unshard", "", TrainingState.PRE_BACKWARD)) + expected_events.extend( + [ + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + # Explicit backward prefetching moves the unshards early + # by one module (note how swapping each unshard down one + # event would give the natural event order) + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + ) if reshard_after_forward is False: # No reshard after forward means no backward unshards expected_events = [e for e in expected_events if e[0] != "unshard"] @@ -580,8 +626,9 @@ def _test_backward_prefetch_multi_forward( FSDPParamGroup.post_backward, events ) # Check the order for multiple forwards before 1 backward - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1 = model(inp) loss2 = model(inp) @@ -590,31 +637,40 @@ def _test_backward_prefetch_multi_forward( ("unshard", "layers.0", TrainingState.FORWARD), ("unshard", "layers.1", TrainingState.FORWARD), ("unshard", "layers.2", TrainingState.FORWARD), - # Root does not reshard after forward so there is not another - # unshard event for it - ("unshard", "layers.0", TrainingState.FORWARD), - ("unshard", "layers.1", TrainingState.FORWARD), - ("unshard", "layers.2", TrainingState.FORWARD), ] + if reshard_after_forward is not None: + expected_events.append(("unshard", "", TrainingState.FORWARD)) + expected_events.extend( + [ + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ] + ) if reshard_after_forward is False: # No reshard after forward means no second set of unshards - expected_events = expected_events[:-3] + expected_events = expected_events[:-4] self.assertEqual(events, expected_events) events.clear() (loss1 + loss2).sum().backward() - expected_events = [ - # Same as the single forward/backward case except the root's - # post-backward does not run until the end of backward in the - # final callback (since the input not requiring gradient means - # that we do not have a tensor on which to hook for - # post-backward) - ("unshard", "layers.2", TrainingState.PRE_BACKWARD), - ("unshard", "layers.1", TrainingState.PRE_BACKWARD), - ("post_backward", "layers.2", TrainingState.POST_BACKWARD), - ("unshard", "layers.0", TrainingState.PRE_BACKWARD), - ("post_backward", "layers.1", TrainingState.POST_BACKWARD), - ("post_backward", "layers.0", TrainingState.POST_BACKWARD), - ] + expected_events = [] + if reshard_after_forward is not None: + expected_events.append(("unshard", "", TrainingState.PRE_BACKWARD)) + expected_events.extend( + [ + # Same as the single forward/backward case except the root's + # post-backward does not run until the end of backward in the + # final callback (since the input not requiring gradient means + # that we do not have a tensor on which to hook for + # post-backward) + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ] + ) if reshard_after_forward is False: # No reshard after forward means no backward unshards expected_events = [e for e in expected_events if e[0] != "unshard"] @@ -635,7 +691,7 @@ def _test_backward_prefetch_multi_forward( events.clear() def _test_backward_prefetch_unused_in_backward( - self, reshard_after_forward: Union[bool, int] + self, reshard_after_forward: Union[bool, int, None] ): """ Test a model with a linear module then a split into two linear modules, @@ -657,8 +713,9 @@ def _test_backward_prefetch_unused_in_backward( post_backward_with_record = self._get_post_backward_with_record( FSDPParamGroup.post_backward, events ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1, loss2 = model(inp) expected_events = [ @@ -732,6 +789,7 @@ def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: ) expected_backward_events = [ # Default backward prefetching + ("unshard", "", TrainingState.PRE_BACKWARD), ("unshard", "layers.3", TrainingState.PRE_BACKWARD), ("unshard", "layers.2", TrainingState.PRE_BACKWARD), ("reshard", "layers.3", TrainingState.POST_BACKWARD), @@ -747,9 +805,11 @@ def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: ("reshard", "", TrainingState.POST_BACKWARD), ("post_backward", "", TrainingState.POST_BACKWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_forward_prefetch(model, num_to_prefetch=1) loss = model(inp) expected_forward_events = [ @@ -763,6 +823,7 @@ def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: ("unshard", "layers.3", TrainingState.FORWARD), ("reshard", "layers.2", TrainingState.FORWARD), ("reshard", "layers.3", TrainingState.FORWARD), + ("reshard", "", TrainingState.FORWARD), ] self.assertEqual(events, expected_forward_events) events.clear() @@ -783,6 +844,7 @@ def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: ("reshard", "layers.1", TrainingState.FORWARD), ("reshard", "layers.2", TrainingState.FORWARD), ("reshard", "layers.3", TrainingState.FORWARD), + ("reshard", "", TrainingState.FORWARD), ] self.assertEqual(events, expected_forward_events) events.clear() @@ -831,16 +893,20 @@ def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: ("reshard", "layers.2", TrainingState.FORWARD), ("unshard", "layers.3", TrainingState.FORWARD), ("reshard", "layers.3", TrainingState.FORWARD), + ("reshard", "", TrainingState.FORWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_backward_prefetch(model, num_to_prefetch=1) loss = model(inp) self.assertEqual(events, expected_forward_events) events.clear() loss.sum().backward() expected_backward_events = [ + ("unshard", "", TrainingState.PRE_BACKWARD), # Root prefetches `layers.3` per default ("unshard", "layers.3", TrainingState.PRE_BACKWARD), # `layers.i` prefetches for `layers.i-1` (same as default) @@ -867,6 +933,7 @@ def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: events.clear() loss.sum().backward() expected_backward_events = [ + ("unshard", "", TrainingState.PRE_BACKWARD), # Root prefetches `layers.3` per default ("unshard", "layers.3", TrainingState.PRE_BACKWARD), # `layers.i` prefetches for `layers.i-1` and `layers.i-2` @@ -915,8 +982,9 @@ def test_fully_shard_multi_module_backward_prefetch(self): (2, model_args.max_seq_len), device=device_type.type, ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) @@ -994,8 +1062,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: FSDPParamGroup.post_backward, events ) inp = torch.randn((2, 16), device=device_type.type) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) @@ -1062,7 +1131,7 @@ def test_backward_misprefetch(self): def _init_transformer( self, n_layers: int, - reshard_after_forward: Union[bool, int], + reshard_after_forward: Union[bool, int, None], checkpoint_impl: Optional[str], ): model_args = ModelArgs( @@ -1239,5 +1308,214 @@ def test_unshard_without_lazy_init(self): self.assertEqual(ref_param, param) +class TestFullyShardAllocFromPG(FSDPTest): + # The messages might change when we move to a different NCCL version. + # Please update this test if it starts failing. + MEMORY_REGISTER_RE = ( + "NCCL INFO register comm 0x[0-9a-f]+ buffer 0x[0-9a-f]+ size [0-9]+" + ) + + @classmethod + def _run(cls, *args, **kwargs): + cls.nccl_log_dir = tempfile.TemporaryDirectory() + os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,ENV,REG" + os.environ["NCCL_DEBUG_FILE"] = cls.nccl_log_dir.name + "/nccl_log" + super()._run(*args, **kwargs) + + @skip_if_lt_x_gpu(2) + # The NCCL PG refuses to allocate tensors if multicast is unavailable, see + # https://github.com/pytorch/pytorch/blob/503362d019b3782581492af7767945dbd75ca1c9/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L5634 + @requires_multicast_support() + def test_fully_shard_alloc_from_pg(self): + torch.manual_seed(42) + model_args = ModelArgs() + model = Transformer(model_args) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module) + fully_shard(model) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + self.assertNotRegex(f.read(), self.MEMORY_REGISTER_RE) + + for module in model.modules(): + if isinstance(module, TransformerBlock): + module.set_allocate_memory_from_process_group_for_comm(True) + model.set_allocate_memory_from_process_group_for_comm(True) + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + self.assertRegex(f.read(), self.MEMORY_REGISTER_RE) + + @skip_if_lt_x_gpu(2) + def test_exception_when_used_together_with_comm_hooks(self): + model = nn.Linear(16, 16) + model = fully_shard(model) + # ok + model.set_allocate_memory_from_process_group_for_comm(True) + + # setting custom hook after is also ok + # (overrides set_allocate_memory_from_process_group_for_comm) + mock_all_gather = MagicMock(spec=AllGather) + model.set_custom_all_gather(mock_all_gather) + + # setting this after custom comm is used is ko + with self.assertRaises(AssertionError): + model.set_allocate_memory_from_process_group_for_comm(True) + + +class TestFullyShardForceSumReduction(FSDPTest): + # The messages might change when we move to a different NCCL version. + # Please update this test if it starts failing. + COLLECTIVE_RE = ( + "NCCL INFO {coll}: opCount [0-9a-f]+ sendbuff 0x[0-9a-f]+ recvbuff 0x[0-9a-f]+ " + "count {count} datatype [0-9]+ op {reduce_op} root [0-9]+ comm 0x[0-9a-f]+" + ) + # See here for the numerical values for each reduction op: + # https://github.com/NVIDIA/nccl/blob/72d2432094d6ae36abd6e511c3a16a2d052dbf94/src/nccl.h.in#L260-L275 + SUM_REDUCTION = 0 + AVG_REDUCTION = 4 + + @classmethod + def _run(cls, *args, **kwargs): + cls.nccl_log_dir = tempfile.TemporaryDirectory() + os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_DEBUG_SUBSYS"] = "COLL" + os.environ["NCCL_DEBUG_FILE"] = cls.nccl_log_dir.name + "/nccl_log" + super()._run(*args, **kwargs) + + # Test reduce-scatter only on plain FSDP on 2 GPUs + @skip_if_lt_x_gpu(2) + def test_fully_shard_force_sum_reduce_scatter(self): + torch.manual_seed(42) + model_args = ModelArgs() + model = Transformer(model_args) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module) + fully_shard(model) + + # We target a specific count so that we don't pick up the barrier ops + layer_numel = sum(w.numel() for w in model.layers[0].parameters()) + comms_size = layer_numel // self.world_size + reduce_scatter_avg_re = self.COLLECTIVE_RE.format( + coll="ReduceScatter", count=comms_size, reduce_op=self.AVG_REDUCTION + ) + reduce_scatter_sum_re = self.COLLECTIVE_RE.format( + coll="ReduceScatter", count=comms_size, reduce_op=self.SUM_REDUCTION + ) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + logs = f.read() + # At this stage we should have only AVG, no SUM + self.assertRegex(logs, reduce_scatter_avg_re) + self.assertNotRegex(logs, reduce_scatter_sum_re) + + for module in model.modules(): + if isinstance(module, TransformerBlock): + module.set_force_sum_reduction_for_comms(True) + model.set_force_sum_reduction_for_comms(True) + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + logs = f.read() + # Now we should also have SUM + self.assertRegex(logs, reduce_scatter_sum_re) + + # Test both reduce-scatter and all-reduce on HSDP (DDP+FSDP) on 4 GPUs + @skip_if_lt_x_gpu(4) + def test_fully_shard_force_sum_both_reductions(self): + mesh = init_device_mesh( + device_type.type, (2, self.world_size // 2), mesh_dim_names=("ddp", "fsdp") + ) + + torch.manual_seed(42) + model_args = ModelArgs() + model = Transformer(model_args) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard(module, mesh=mesh) + fully_shard(model, mesh=mesh) + + # We target a specific count so that we don't pick up the barrier ops + layer_numel = sum(w.numel() for w in model.layers[0].parameters()) + comms_size = layer_numel // (self.world_size // 2) + reduce_scatter_avg_re = self.COLLECTIVE_RE.format( + coll="ReduceScatter", count=comms_size, reduce_op=self.AVG_REDUCTION + ) + reduce_scatter_sum_re = self.COLLECTIVE_RE.format( + coll="ReduceScatter", count=comms_size, reduce_op=self.SUM_REDUCTION + ) + all_reduce_avg_re = self.COLLECTIVE_RE.format( + coll="AllReduce", count=comms_size, reduce_op=self.AVG_REDUCTION + ) + all_reduce_sum_re = self.COLLECTIVE_RE.format( + coll="AllReduce", count=comms_size, reduce_op=self.SUM_REDUCTION + ) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + logs = f.read() + # At this stage we should have only AVG, no SUM + self.assertRegex(logs, reduce_scatter_avg_re) + self.assertRegex(logs, all_reduce_avg_re) + self.assertNotRegex(logs, reduce_scatter_sum_re) + self.assertNotRegex(logs, all_reduce_sum_re) + + for module in model.modules(): + if isinstance(module, TransformerBlock): + module.set_force_sum_reduction_for_comms(True) + model.set_force_sum_reduction_for_comms(True) + + loss = model(inp) + loss.sum().backward() + + torch.distributed.barrier() + torch.cuda.synchronize() + + with open(self.nccl_log_dir.name + "/nccl_log") as f: + logs = f.read() + # Now we should also have SUM + self.assertRegex(logs, reduce_scatter_sum_re) + self.assertRegex(logs, all_reduce_sum_re) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 4599c26fe13b3b..29200e5884e665 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -222,9 +222,7 @@ def _assert_no_aliased_unsharded_params_in_graph_inputs( ): unsharded_param_graph_inputs.add(node.args[0]) assert len(unsharded_param_graph_inputs) > 0 - assert len(unsharded_param_graph_inputs) == len( - list(model.parameters()) - ), """\ + assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\ Expected all model parameters to be wrapped by FSDP2 and have their unsharded version as graph input, but it's not true! """ @@ -237,7 +235,7 @@ def _assert_no_aliased_unsharded_params_in_graph_inputs( no_aliased_unsharded_params_in_graph_inputs = False err_msg += f"""\n Found aliased unsharded param in graph inputs: {aliased_graph_inputs}, -val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, +val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]}, """ self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg) @@ -466,10 +464,9 @@ def inductor_code_check_fsdp_reduce_scatter( @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_compiled_autograd_ctx(self): self.skipTestForOldSm() - with torch._dynamo.config.patch( - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - recompute_views=True, + with ( + torch._dynamo.config.patch(skip_fsdp_hooks=False), + torch._functorch.config.patch(recompute_views=True), ): inputs = torch.randn(8, 8) model = torch.nn.Linear(8, 8) @@ -543,7 +540,16 @@ def test_compiled(): ) if fwd_fullgraph: self.assertEqual(len(counters["graph_break"]), 1) - self.assertIn("Tensor.backward", counters["graph_break"]) + self.assertExpectedInline( + next(iter(counters["graph_break"].keys())), + """\ +Unsupported Tensor.backward() call + Explanation: Dynamo currently does not support tracing `Tensor.backward()`. + Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. + + Developer debug context: call_method TensorVariable() backward () {} +""", # noqa: B950 + ) else: self.assertGreater(len(counters["graph_break"]), 1) return res @@ -557,24 +563,28 @@ def test_eager(): torch._dynamo.reset() torch._dynamo.compiled_autograd.reset() - with torch._dynamo.config.patch( - compiled_autograd=True, - compiled_autograd_kwargs_override={ - "fullgraph": True, - }, - inline_inbuilt_nn_modules=True, - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - enable_autograd_cache=False, - recompute_views=True, - ), torch._inductor.config.patch( - force_disable_caches=True, - reorder_for_compute_comm_overlap=True, - reorder_for_compute_comm_overlap_passes=[ - "sink_waits", - "raise_comms", - "reorder_compute_for_overlap", - ], + with ( + torch._dynamo.config.patch( + compiled_autograd=True, + compiled_autograd_kwargs_override={ + "fullgraph": True, + }, + inline_inbuilt_nn_modules=True, + skip_fsdp_hooks=False, + ), + torch._functorch.config.patch( + enable_autograd_cache=False, + recompute_views=True, + ), + torch._inductor.config.patch( + force_disable_caches=True, + reorder_for_compute_comm_overlap=True, + reorder_for_compute_comm_overlap_passes=[ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ], + ), ): losses_compiled = test_compiled() losses_eager = test_eager() @@ -731,20 +741,21 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): def _test_nested_fully_shard_backend_inductor_fullgraph_True(self): self.skipTestForOldSm() for fwd_fullgraph in [True]: - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - fwd_copy_count=0, - fwd_resize_count=0, - bwd_copy_count=0, - bwd_resize_count=0, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + fwd_copy_count=0, + fwd_resize_count=0, + bwd_copy_count=0, + bwd_resize_count=0, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( @@ -898,7 +909,7 @@ def model_init_fn(): for _, mod in enumerate(model.layers): fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config) model = fully_shard( - model, mesh=mesh, reshard_after_forward=True, **fsdp_config + model, mesh=mesh, reshard_after_forward=False, **fsdp_config ) optim = torch.optim.SGD(model.parameters(), lr=1e-4) return model, optim @@ -933,9 +944,10 @@ def test_transformer_backend_aot_eager(self): for fwd_fullgraph, all_requires_grad in itertools.product( [True], [True, False] ): - with self._maybe_add_graph_break_to_sdpa( - fwd_fullgraph - ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph): + with ( + self._maybe_add_graph_break_to_sdpa(fwd_fullgraph), + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + ): self._test_traceable_fsdp( *self._create_transformer_factory_fns( all_requires_grad=all_requires_grad @@ -972,23 +984,24 @@ def _test_transformer_backend_inductor_fullgraph_True(self): log.warning( f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001, B950 ) - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - # NOTE: For the root unsharded params, we don't reshard after forward since for training, - # the parameters would be freed and all-gathered immediately. Hence we still have - # their resize and copy ops in the graph. - fwd_copy_count=4, - fwd_resize_count=4, - bwd_copy_count=0, - bwd_resize_count=4, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + # NOTE: For the root unsharded params, we don't reshard after forward since for training, + # the parameters would be freed and all-gathered immediately. Hence we still have + # their resize and copy ops in the graph. + fwd_copy_count=4, + fwd_resize_count=4, + bwd_copy_count=0, + bwd_resize_count=4, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( @@ -1078,6 +1091,7 @@ def _test_transformer_backend_inductor_fullgraph_True(self): pass file_check.run(bwd_code) + @unittest.skip('"Traceable FSDP2" is not being maintained anymore.') @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @@ -1085,6 +1099,7 @@ def _test_transformer_backend_inductor_fullgraph_True(self): def test_transformer_backend_inductor_fullgraph_True(self): self._test_transformer_backend_inductor_fullgraph_True() + @unittest.skip('"Traceable FSDP2" is not being maintained anymore.') @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why diff --git a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py index d653f6a0bcb8bc..0b25e09b3defc5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py @@ -385,16 +385,18 @@ def test_all_gather_extension_outer_size_stride(self): only some ranks may require padding, in which case only those ranks will error out and the all-gather will timeout. """ - assert ( - self.world_size >= 2 - ), f"Assumes world size of at least 2 but got {self.world_size=}" + assert self.world_size >= 2, ( + f"Assumes world size of at least 2 but got {self.world_size=}" + ) model = MLP(dim=3, dim_multiplier=3) for module in model.modules(): for param_name, param in module.named_parameters(recurse=False): if "weight" in param_name: param = nn.Parameter(BFloat16AllGatherTensor(param)) setattr(module, param_name, param) - fully_shard(model) + # need to fix reshard_after_forward=True + # https://github.com/pytorch/pytorch/issues/154836 + fully_shard(model, reshard_after_forward=False) optim = torch.optim.AdamW(model.parameters(), lr=1e-2, fused=True) torch.manual_seed(42 + self.rank + 1) inp = torch.randn((2, 3), device=device_type) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py index 467b63563b82b6..f56c5e76c12247 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py @@ -115,9 +115,10 @@ def backward_with_count(*args, **kwargs): torch.manual_seed(42 + self.rank + 1) device = device_type - with patch_reduce_scatter( - reduce_scatter - ), patch_register_post_backward_hook_backward(backward_with_count): + with ( + patch_reduce_scatter(reduce_scatter), + patch_register_post_backward_hook_backward(backward_with_count), + ): for iter_idx in range(10): inp = torch.randn((8, lin_dim), device=device) losses: list[torch.Tensor] = [] diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 41a083a9e8d480..714145f8b976cc 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -910,9 +910,9 @@ def test_1d_process_group_init(self): @skip_if_lt_x_gpu(1) def test_2d_process_group_init(self): shard_mesh_dim_size = 2 - assert ( - self.world_size % shard_mesh_dim_size == 0 - ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + assert self.world_size % shard_mesh_dim_size == 0, ( + f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + ) replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size mesh_dim_names = ("replicate", "shard") ref_mesh = init_device_mesh( @@ -1322,5 +1322,33 @@ def test_old_import_training(self): model(inp).sum().backward() +class TestFullyShardMixedDtypeParam(FSDPTestMultiThread): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(2) + def test_mixed_dtypes_no_grad_param(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + # no grad params with different dtypes + self.w_fp8 = torch.nn.Parameter( + torch.empty((256, 256), dtype=torch.float8_e4m3fn), + requires_grad=False, + ) + self.w_fp32 = torch.nn.Parameter( + torch.empty((256, 256), dtype=torch.float32) + ) + + def forward(self, input): + return + + mesh = init_device_mesh(device_type.type, (self.world_size,)) + model = Model() + fully_shard(model, mesh=mesh) + model(0) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py index c3b8f04688ef41..44d05ade98f75a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py @@ -54,9 +54,9 @@ def _test_fully_shard_training_memory( ) ): return # skip since not a common use case - assert ( - self.world_size == 2 - ), f"Requires world size of 2 since some values are hard coded: {self.world_size}" + assert self.world_size == 2, ( + f"Requires world size of 2 since some values are hard coded: {self.world_size}" + ) torch.manual_seed(42) # Pre-run a linear forward (gemm and bias) and backward (gemm) to # allocate the cuBLAS workspaces before measuring the memory usage diff --git a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py index 87e5e1900ce344..06881442b748e8 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import copy +import dataclasses import functools from typing import Optional, Union @@ -122,7 +123,7 @@ def assert_fn(output: torch.Tensor): reduce_scatter = functools.partial( reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn ) - predivide_factor, postdivide_factor = _get_gradient_divide_factors( + predivide_factor, postdivide_factor, _, _ = _get_gradient_divide_factors( self.process_group, all_reduce_group=None, reduce_dtype=param_dtype ) @@ -283,9 +284,7 @@ def assert_fn(output: torch.Tensor): ) # bf16 reduction param.grad = funcol.all_gather_tensor( sharded_grad, gather_dim=0, group=group - ).to( - param.dtype - ) # upcast to fp32 + ).to(param.dtype) # upcast to fp32 ref_optim.step() # fp32 optimizer step self.assertEqual(fsdp_loss, ref_loss) @@ -593,6 +592,30 @@ def assert_fn(output: torch.Tensor): loss = model(inp).sum() loss.backward() + @skip_if_lt_x_gpu(1) + def test_dataclass_input(self): + @dataclasses.dataclass + class Input: + x: torch.Tensor + + class Model(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._layer = nn.Linear(10, 10) + + def forward(self, input: Input): + return self._layer(input.x) + + mp_policy = MixedPrecisionPolicy( + torch.bfloat16, torch.bfloat16, torch.bfloat16, True + ) + model = Model() + inp = Input(torch.randn(2, 10).cuda()) + + fully_shard(model, mp_policy=mp_policy) + loss = model(inp).sum() + loss.backward() + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index c9653d06adeade..e8d52f70e0f40a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -139,8 +139,9 @@ def ref_fwd_bwd(): dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input) def fwd_bwd(): - with patch_all_gather(delayed_all_gather), patch_reduce_scatter( - delayed_reduce_scatter + with ( + patch_all_gather(delayed_all_gather), + patch_reduce_scatter(delayed_reduce_scatter), ): loss = model(inp).sum() loss.backward() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index b6b6dfaa017bc5..cf8b86cc8e06d0 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -74,12 +74,12 @@ def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]): # Check that FSDP moved the inputs to GPU, including recursing # into the tuple data structure assert x.device == device, f"Expects {device} but got {x.device}" - assert ( - ys[0].device == device - ), f"Expects {device} but got {ys[0].device}" - assert ( - ys[1].device == device - ), f"Expects {device} but got {ys[1].device}" + assert ys[0].device == device, ( + f"Expects {device} but got {ys[0].device}" + ) + assert ys[1].device == device, ( + f"Expects {device} but got {ys[1].device}" + ) y = ys[0] + ys[1] return x + y + 1 @@ -103,7 +103,7 @@ def test_param_registration_after_forward(self): """Tests the parameter registration after forward.""" device = torch.device(device_type.type, 0) # Single FSDP group - for reshard_after_forward in (True, False, 2): + for reshard_after_forward in (True, False, 2, None): torch.manual_seed(42) model = MLP(3, device) # Since seed is per process, not per thread, we broadcast to ensure @@ -115,15 +115,18 @@ def test_param_registration_after_forward(self): inp = torch.randn((2, 3), device=device_type.type) self._assert_dtensor_params(model.parameters()) self._assert_same_params(model.parameters(), ref_model.parameters()) - model(inp) # root does not reshard after forward - self._assert_tensor_params(model.parameters()) + model(inp) + if reshard_after_forward: + self._assert_dtensor_params(model.parameters()) + else: + self._assert_tensor_params(model.parameters()) self._assert_same_params(model.parameters(), ref_model.parameters()) model.reshard() # however, we can manually reshard self._assert_dtensor_params(model.parameters()) self._assert_same_params(model.parameters(), ref_model.parameters()) # Multiple FSDP groups - for reshard_after_forward in (True, False, 2): + for reshard_after_forward in (True, False, 2, None): torch.manual_seed(42) model = nn.Sequential(MLP(3, device), MLP(3, device)) for param in model.parameters(): @@ -140,11 +143,15 @@ def test_param_registration_after_forward(self): model[0].out_proj.parameters() ) root_params = list(set(model.parameters()) - set(non_root_params)) - if reshard_after_forward is False: - self._assert_tensor_params(non_root_params) - else: + if reshard_after_forward is None: + self._assert_dtensor_params(non_root_params) + self._assert_tensor_params(root_params) + elif reshard_after_forward: self._assert_dtensor_params(non_root_params) - self._assert_tensor_params(root_params) + self._assert_dtensor_params(root_params) + else: + self._assert_tensor_params(non_root_params) + self._assert_tensor_params(root_params) self._assert_same_params(model.parameters(), ref_model.parameters()) for module in model.modules(): if isinstance(module, FSDPModule): @@ -176,13 +183,16 @@ def test_param_registration_after_backward(self): self._assert_dtensor_params(model.parameters()) def _assert_tensor_params(self, params: Iterable[nn.Parameter]): - self.assertGreater(len(list(params)), 0) + # need to iterate over the list multiple times + params = list(params) + self.assertGreater(len(params), 0) for param in params: self.assertNotIsInstance(param, DTensor) self.assertIsInstance(param, torch.Tensor) def _assert_dtensor_params(self, params: Iterable[nn.Parameter]): - self.assertGreater(len(list(params)), 0) + params = list(params) + self.assertGreater(len(params), 0) for param in params: self.assertIsInstance(param, DTensor) @@ -1074,9 +1084,7 @@ def set_backward_flags(_model: nn.Module, is_last_microbatch: bool): # the first microbatch's forward expected_all_gather_count = num_mlps + 1 if reshard_after_forward is not False: # `True` or `2` - # Add the number of MLPs without the +1 for the backward - # all-gathers since the root does not reshard after forward - expected_all_gather_count += num_mlps + expected_all_gather_count += num_mlps + 1 # Multiply by the number of microbatches since these # all-gathers run every microbatch expected_all_gather_count *= num_microbatches @@ -1303,7 +1311,7 @@ def _test_3d_mlp_with_nd_mesh( use_activation_checkpointing, reshard_after_forward=reshard_after_forward, ) - # Checking paramters match orig model is critical to validate .full_tensor correctly replicates the + # Checking parameters match orig model is critical to validate .full_tensor correctly replicates the # strided-sharded layers. for ref_p, p in zip(ref_model.parameters(), model.parameters()): self.assertIsInstance(p, DTensor) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index edb24ef01ab9da..5ad96979717c4b 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -47,7 +47,6 @@ instantiate_parametrized_tests, parametrize, run_tests, - skipIfRocm, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -116,7 +115,6 @@ def init_global_mesh(self) -> DeviceMesh: ) @skip_if_lt_x_gpu(2) - @skipIfRocm def test_train_parity_2d_mlp(self): global_mesh = self.init_global_mesh() self.run_subtests( @@ -164,7 +162,6 @@ def _test_train_parity_2d_mlp( self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) - @skipIfRocm def test_train_parity_2d_transformer(self): self.run_subtests( {"use_shard_placement_fn": [False, True]}, @@ -245,7 +242,6 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: self.assertEqual(full_param, ref_param) @skip_if_lt_x_gpu(2) - @skipIfRocm def test_tp_with_fsdp_offloading(self): global_mesh = init_device_mesh( "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp") diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index d30d0424a04009..8f0b938da41b03 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -38,7 +38,6 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, - skipIfRocm, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir @@ -101,7 +100,6 @@ def world_size(self): def device(self): return self.rank - @skipIfRocm() @requires_nccl() @skip_if_lt_x_gpu(4) @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs") diff --git a/test/distributed/_tools/test_memory_tracker.py b/test/distributed/_tools/test_memory_tracker.py index eaa5d24ce369ba..ccf7f0beefd072 100644 --- a/test/distributed/_tools/test_memory_tracker.py +++ b/test/distributed/_tools/test_memory_tracker.py @@ -6,17 +6,18 @@ import torch.nn as nn from torch.distributed._tools import MemoryTracker from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase class TestMemoryTracker(TestCase): - @unittest.skipIf(not TEST_CUDA, "no cuda") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu") def test_local_model(self): """ Minimal test case to check the memory tracker can collect the expected memory stats at operator level, as well as can print the summary result without crash. """ + device = "cuda" if TEST_CUDA else "xpu" # Create a model with a hierarchy of modules torch.manual_seed(0) model = nn.Sequential( @@ -28,16 +29,16 @@ def test_local_model(self): ), nn.Flatten(start_dim=1), nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)), - ).cuda() + ).to(device) # Run one iteration of forward and backward pass tracker = MemoryTracker() tracker.start_monitor(model) - x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda")) - # torch.LongTensor expects cpu device type, not cuda device type in - # constructor, so calling .cuda() outside constructor here. - target = torch.LongTensor([0, 1]).cuda() + x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device)) + # torch.LongTensor expects cpu device type, not device type in + # constructor, so calling .to(device) outside constructor here. + target = torch.LongTensor([0, 1]).to(device) criterion = nn.CrossEntropyLoss() criterion(model(x), target).backward() @@ -61,7 +62,7 @@ def test_local_model(self): self.assertEqual(len(tracker.memories_reserved), tracker._op_index) self.assertTrue(len(tracker._markers) == 2) self.assertTrue(tracker._cur_module_name != "") - self.assertTrue(hasattr(tracker, "_num_cuda_retries")) + self.assertTrue(hasattr(tracker, "_num_alloc_retries")) if __name__ == "__main__": diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py index 05c7dbb1a63eb4..10d4d7a030f10f 100644 --- a/test/distributed/_tools/test_sac_ilp.py +++ b/test/distributed/_tools/test_sac_ilp.py @@ -157,7 +157,7 @@ def test_sac_ilp_case1(self): # Due to symmetry, the layer that has 0.7964 can be any of the first three layers. On CI, # due to machine variance and difference in flops, the results can be different -- e.g., # the ratios are 0.672, 0.5646, 0.5646, 0.5646 for the four transformer layers for test - # linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, lf.linux.8xlarge.nvidia.gpu). + # linux-jammy-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, lf.linux.8xlarge.nvidia.gpu). # and recomputation_time = 58.14; compute_time = 902.26 modules_to_ac = set(ac_decisions.keys()) sorted_discard_ratio = sorted(ac_decisions.values()) diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index ec85a668d74ff4..89a893037c3b58 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -234,8 +234,8 @@ def hook(flags, bucket): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/algorithms/test_join.py b/test/distributed/algorithms/test_join.py index 60982d29cc6252..8fd613a47d7754 100644 --- a/test/distributed/algorithms/test_join.py +++ b/test/distributed/algorithms/test_join.py @@ -250,9 +250,11 @@ def _test_join_base( else "Detected at least one rank that exhausted inputs. " "Throwing across all ranks." ) - with self.assertRaisesRegex( - RuntimeError, expected_msg - ) if throw_on_early_termination else contextlib.nullcontext(): + with ( + self.assertRaisesRegex(RuntimeError, expected_msg) + if throw_on_early_termination + else contextlib.nullcontext() + ): with Join( allreducers, enable=enable, diff --git a/test/distributed/checkpoint/_experimental/test_barriers.py b/test/distributed/checkpoint/_experimental/test_barriers.py new file mode 100644 index 00000000000000..b483659ba00532 --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_barriers.py @@ -0,0 +1,110 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import unittest.mock as mock + +from torch.distributed.checkpoint._experimental.barriers import TCPStoreBarrier +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestBarriers(TestCase): + @mock.patch("torch.distributed.TCPStore") + @mock.patch("torch.distributed.elastic.utils.store.barrier") + def test_tcpstore_barrier_initialization(self, _, mock_tcpstore): + """Test that TCPStoreBarrier initializes correctly.""" + # Setup + timeout_barrier_init_secs = 60 + barrier_prefix = "test_barrier" + world_size = 4 + use_checkpoint_barrier_tcpstore_libuv = True + tcpstore_port = 12345 + master_address = "localhost" + rank = 0 + timeout_secs = 30 + + # Create rank_info + rank_info = RankInfo(global_rank=rank, global_world_size=world_size) + + # Create the barrier (used for verification) + _ = TCPStoreBarrier( + global_rank=rank_info.global_rank, + global_world_size=rank_info.global_world_size, + barrier_prefix=barrier_prefix, + timeout_barrier_init_secs=timeout_barrier_init_secs, + use_checkpoint_barrier_tcpstore_libuv=use_checkpoint_barrier_tcpstore_libuv, + tcpstore_port=tcpstore_port, + master_address=master_address, + timeout_secs=timeout_secs, + ) + + # Verify that TCPStore was initialized with the correct parameters + mock_tcpstore.assert_called_once_with( + master_address, + tcpstore_port, + world_size=rank_info.global_world_size, + timeout=mock.ANY, # timedelta is hard to compare directly + is_master=(rank_info.global_rank == 0), + ) + + @mock.patch("torch.distributed.TCPStore") + @mock.patch("torch.distributed.elastic.utils.store.barrier") + def test_execute_barrier(self, mock_barrier, mock_tcpstore): + """Test that execute_barrier calls the barrier function correctly.""" + # Setup + barrier_prefix = "test_barrier" + timeout_barrier_init_secs = 60 + world_size = 4 + use_checkpoint_barrier_tcpstore_libuv = True + tcpstore_port = 12345 + master_address = "localhost" + rank = 0 + timeout_secs = 30 + + # Create rank_info + rank_info = RankInfo(global_rank=rank, global_world_size=world_size) + + # Mock the TCPStore instance + mock_tcpstore_instance = mock.MagicMock() + mock_tcpstore.return_value = mock_tcpstore_instance + + # Create the barrier + barrier = TCPStoreBarrier( + global_rank=rank_info.global_rank, + global_world_size=rank_info.global_world_size, + barrier_prefix=barrier_prefix, + timeout_barrier_init_secs=timeout_barrier_init_secs, + use_checkpoint_barrier_tcpstore_libuv=use_checkpoint_barrier_tcpstore_libuv, + tcpstore_port=tcpstore_port, + master_address=master_address, + timeout_secs=timeout_secs, + ) + + # Execute the barrier + barrier.execute_barrier() + + # Verify that the TCPStore's set method was called with the correct parameters + mock_tcpstore_instance.set.assert_called_once_with("rank0", "0") + + # Verify that the barrier function was called with the correct parameters + mock_barrier.assert_called_once_with( + store=mock_tcpstore_instance, + world_size=rank_info.global_world_size, + key_prefix=barrier_prefix + "0", + ) + + # Execute the barrier again to test sequence number increment + barrier.execute_barrier() + + # Verify that the TCPStore's set method was called with the incremented sequence number + mock_tcpstore_instance.set.assert_called_with("rank0", "1") + + # Verify that the barrier function was called with the incremented sequence number + mock_barrier.assert_called_with( + store=mock_tcpstore_instance, + world_size=rank_info.global_world_size, + key_prefix=barrier_prefix + "1", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py new file mode 100644 index 00000000000000..7eed02755610b0 --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -0,0 +1,165 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import os +import shutil +import tempfile + +import torch +from torch.distributed.checkpoint._experimental.barriers import BarrierConfig +from torch.distributed.checkpoint._experimental.builder import ( + make_async_checkpointer, + make_sync_checkpointer, +) +from torch.distributed.checkpoint._experimental.checkpointer import ( + AsyncCheckpointer, + SyncCheckpointer, +) +from torch.distributed.checkpoint._experimental.config import CheckpointerConfig +from torch.distributed.checkpoint._experimental.staging import CheckpointStagerConfig +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestMakeCheckpointer(TestCase): + def setUp(self) -> None: + # Create a temporary directory for checkpoints + self.temp_dir = tempfile.mkdtemp() + + # Create real objects for testing + self.rank_info = RankInfo( + global_world_size=1, + global_rank=0, + ) + + # Create a test state dictionary + self.state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + "step": 1000, + } + + def tearDown(self) -> None: + # Clean up the temporary directory + shutil.rmtree(self.temp_dir) + + def test_make_sync_checkpointer(self) -> None: + """Test creating a synchronous checkpointer using make_sync_checkpointer.""" + + # Create sync checkpointer using factory function with no barrier + config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None)) + checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info) + + # Verify it's a SyncCheckpointer instance + self.assertIsInstance(checkpointer, SyncCheckpointer) + + # Test that it works for sync operations + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_sync") + result = checkpointer.save(self.state_dict, checkpoint_path) + self.assertIsNone(result) # Sync mode returns None + + # Verify checkpoint was created + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(checkpoint_file)) + + # Test loading + loaded_state_dict = checkpointer.load(checkpoint_path) + self.assertEqual(loaded_state_dict["epoch"], 5) + + def test_make_sync_checkpointer_with_config_first(self) -> None: + """Test creating a synchronous checkpointer with config as first parameter.""" + # Create sync checkpointer with config as first parameter + config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None)) + checkpointer = make_sync_checkpointer(config=config, rank_info=self.rank_info) + + # Verify it's a SyncCheckpointer instance + self.assertIsInstance(checkpointer, SyncCheckpointer) + + # Test that it works for sync operations + checkpoint_path = os.path.join( + self.temp_dir, "checkpoint_factory_sync_config_first" + ) + result = checkpointer.save(self.state_dict, checkpoint_path) + self.assertIsNone(result) # Sync mode returns None + + # Verify checkpoint was created + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(checkpoint_file)) + + def test_make_sync_checkpointer_with_custom_config(self) -> None: + """Test creating a synchronous checkpointer with a custom config.""" + # Create a custom config with no barrier + config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None)) + + # Create sync checkpointer with the custom config + checkpointer = make_sync_checkpointer(rank_info=self.rank_info, config=config) + + # Verify it's a SyncCheckpointer instance + self.assertIsInstance(checkpointer, SyncCheckpointer) + + # Test that it works for sync operations + checkpoint_path = os.path.join( + self.temp_dir, "checkpoint_factory_sync_custom_config" + ) + result = checkpointer.save(self.state_dict, checkpoint_path) + self.assertIsNone(result) # Sync mode returns None + + # Verify checkpoint was created + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(checkpoint_file)) + + # Test loading + loaded_state_dict = checkpointer.load(checkpoint_path) + self.assertEqual(loaded_state_dict["epoch"], 5) + + def test_make_async_checkpointer(self) -> None: + """Test creating an asynchronous checkpointer using make_async_checkpointer.""" + # Create async checkpointer using factory function with default parameters + config: CheckpointerConfig = CheckpointerConfig() + config.staging_config = CheckpointStagerConfig( + use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_pinned_memory=torch.cuda.is_available(), + ) + checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info) + + try: + # Verify it's an AsyncCheckpointer instance + self.assertIsInstance(checkpointer, AsyncCheckpointer) + + # Test that it works for async operations + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_factory_async") + stage_future, write_future = checkpointer.save( + self.state_dict, checkpoint_path + ) + + # Verify futures are returned + self.assertIsNotNone(stage_future) + self.assertIsNotNone(write_future) + + # Wait for completion + stage_future.result() + write_future.result() + + # Verify checkpoint was created + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(checkpoint_file)) + + # Test loading + loaded_state_dict = checkpointer.load(checkpoint_path) + self.assertEqual(loaded_state_dict["epoch"], 5) + + finally: + # Clean up + checkpointer.close() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_process.py b/test/distributed/checkpoint/_experimental/test_checkpoint_process.py new file mode 100644 index 00000000000000..1220d5f07235b4 --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_process.py @@ -0,0 +1,465 @@ +# Owner(s): ["oncall: distributed checkpointing"] + + +import os +import tempfile +import time +from concurrent.futures import Future +from typing import Any + +import torch +from torch.distributed.checkpoint._experimental.checkpoint_process import ( + CheckpointProcess, + CheckpointProcessConfig, + RequestType, + WorkerRequest, + WorkerResponse, +) +from torch.distributed.checkpoint._experimental.checkpoint_writer import ( + CheckpointWriter, + CheckpointWriterConfig, +) +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +def subprocess_init_fn(name: str, parent_pid: int) -> None: + """Initialize the subprocess with some basic checks. + + This is similar to the subprocess_init_routine in checkpointing_test.py. + """ + assert name == "test-checkpointer", f"Unexpected subprocess name: {name}" + assert os.getpid() != parent_pid, "This was supposed to run in a different process" + assert os.getppid() == parent_pid, ( + "This was supposed to run as a child to main process" + ) + + +def failing_subprocess_init_fn(name: str, parent_pid: int) -> None: + """Initialize function that raises an exception.""" + # Acknowledge parameters to avoid unused variable warnings + _ = name + _ = parent_pid + raise RuntimeError("Subprocess initialization failed") + + +def timedout_subprocess_init_fn(**kwargs: Any) -> None: + # Acknowledge parameters to avoid unused variable warnings + _ = kwargs + time.sleep(3) # Simulate a long initialization + + +def ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter: + """Initialize a CheckpointWriter in the subprocess. + + This function is called in the subprocess to create a CheckpointWriter instance. + It's important that this function is defined at the module level so it can be pickled. + """ + return CheckpointWriter( + config=kwargs.get("config"), + rank_info=kwargs.get("rank_info"), + ) + + +def failing_ckpt_writer_init_fn(**kwargs: Any) -> CheckpointWriter: + """Initialize function that raises an exception.""" + # Acknowledge parameters to avoid unused variable warnings + _ = kwargs + raise RuntimeError("CheckpointWriter initialization failed") + + +def shared_tensor_verifier_init_fn(**kwargs: Any) -> CheckpointWriter: + """Initialize a CheckpointWriter that verifies shared memory tensors.""" + + class SharedTensorVerifier(CheckpointWriter): + def __init__(self, config=None, rank_info=None, **init_kwargs): + # Acknowledge unused kwargs to avoid linting warnings + _ = init_kwargs + super().__init__( + config=config or CheckpointWriterConfig(), + rank_info=rank_info, + barrier=None, + commit_hook=None, + ) + + def write(self, state_dict, path, **__): + # Acknowledge parameters to avoid unused variable warnings + _ = path + + # Verify shared memory tensor behavior directly with assertions + if "shared_tensor" in state_dict: + shared_tensor = state_dict["shared_tensor"] + # Critical assertion: shared tensor should remain in shared memory in subprocess + assert shared_tensor.is_shared(), ( + "Shared tensor should be in shared memory in subprocess" + ) + + shared_tensor[0] = 42.0 + + if "regular_tensor" in state_dict: + # Note: ForkingPickler moves regular tensors to shared memory during IPC - this is acceptable + assert state_dict["regular_tensor"].is_shared(), ( + "Regular tensor should also be in shared memory in subprocess" + ) + + return None + + verifier = SharedTensorVerifier( + config=kwargs.get("config"), + rank_info=kwargs.get("rank_info"), + ) + return verifier + + +class TestRequestTypes(TestCase): + """Test the request/response data structures.""" + + def test_request_type_enum(self) -> None: + """Test RequestType enum values.""" + self.assertEqual(RequestType.PING.value, "ping") + self.assertEqual(RequestType.WRITE_CHECKPOINT.value, "write_checkpoint") + self.assertEqual(RequestType.TERMINATE_PROCESS.value, "exit") + + def test_worker_request(self) -> None: + """Test WorkerRequest dataclass.""" + request = WorkerRequest(request_type=RequestType.PING, payload={"test": "data"}) + self.assertEqual(request.request_type, RequestType.PING) + self.assertEqual(request.payload["test"], "data") + + def test_worker_response(self) -> None: + """Test WorkerResponse dataclass.""" + response = WorkerResponse( + request_type=RequestType.PING, + success=True, + error_msg=None, + payload={"result": "success"}, + ) + self.assertEqual(response.request_type, RequestType.PING) + self.assertTrue(response.success) + self.assertIsNone(response.error_msg) + self.assertEqual(response.payload["result"], "success") + + +class TestCheckpointProcessConfig(TestCase): + """Test CheckpointProcessConfig configuration.""" + + def test_default_options(self) -> None: + """Test default CheckpointProcessConfig.""" + options = CheckpointProcessConfig() + # Test default values + self.assertEqual(options.subprocess_init_timeout_secs, 30) + self.assertEqual(options.subprocess_shutdown_timeout_secs, 60) + + def test_custom_options(self) -> None: + """Test custom CheckpointProcessConfig.""" + options = CheckpointProcessConfig( + subprocess_init_timeout_secs=10, subprocess_shutdown_timeout_secs=30 + ) + self.assertEqual(options.subprocess_init_timeout_secs, 10) + self.assertEqual(options.subprocess_shutdown_timeout_secs, 30) + + +class TestCheckpointProcess(TestCase): + def setUp(self) -> None: + """Set up common test fixtures.""" + self.rank_info = RankInfo( + global_world_size=1, + global_rank=0, + ) + self.writer_config = CheckpointWriterConfig() + self.test_state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + "step": 1000, + } + + def _create_checkpoint_process( + self, + subprocess_init_fn_override=None, + subprocess_init_args_override=None, + writer_init_fn_override=None, + subprocess_init_timeout_secs=30, + ): + """Helper to create CheckpointProcess.""" + config = CheckpointProcessConfig( + subprocess_init_timeout_secs=subprocess_init_timeout_secs, + ) + + return CheckpointProcess( + rank_info=self.rank_info, + config=config, + subprocess_init_fn=subprocess_init_fn_override or subprocess_init_fn, + subprocess_init_args=subprocess_init_args_override + or ( + "test-checkpointer", + os.getpid(), + ), + checkpoint_writer_init_fn=writer_init_fn_override or ckpt_writer_init_fn, + checkpoint_writer_init_args={ + "config": self.writer_config, + "rank_info": self.rank_info, + }, + ) + + def test_checkpoint_process_initialization(self) -> None: + """Test that CheckpointProcess initializes and closes correctly.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for the process creation future to complete + checkpoint_process.process_creation_future.result() + + # Verify process is alive + self.assertTrue(checkpoint_process.process.processes[0].is_alive()) + + checkpoint_process.close() + + # Verify process is terminated + self.assertFalse(checkpoint_process.process.processes[0].is_alive()) + + def test_checkpoint_write_sync_state_dict(self) -> None: + """Test writing a checkpoint with synchronous state dict.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for initialization + checkpoint_process.process_creation_future.result() + + # Create a temporary directory for the checkpoint + with tempfile.TemporaryDirectory() as temp_dir: + checkpoint_path = os.path.join(temp_dir, "test_checkpoint") + + # Write checkpoint + future = checkpoint_process.write(self.test_state_dict, checkpoint_path) + + # Verify future is returned + self.assertIsInstance(future, Future) + + # Wait for completion + future.result() + + # Verify checkpoint file was created + expected_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file)) + + # Verify checkpoint content + loaded_state_dict = torch.load(expected_file) + self.assertIn("model", loaded_state_dict) + self.assertIn("optimizer", loaded_state_dict) + self.assertEqual(loaded_state_dict["epoch"], 5) + self.assertEqual(loaded_state_dict["step"], 1000) + + checkpoint_process.close() + + def test_checkpoint_write_future_state_dict(self) -> None: + """Test writing a checkpoint with Future state dict.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for initialization + checkpoint_process.process_creation_future.result() + + # Create a Future that resolves to the state dict + from concurrent.futures import ThreadPoolExecutor + + executor = ThreadPoolExecutor(max_workers=1) + + def get_state_dict(): + time.sleep(0.1) # Simulate some processing time + return self.test_state_dict + + future_state_dict = executor.submit(get_state_dict) + + # Create a temporary directory for the checkpoint + with tempfile.TemporaryDirectory() as temp_dir: + checkpoint_path = os.path.join(temp_dir, "test_checkpoint") + + # Write checkpoint with Future state dict + write_future = checkpoint_process.write(future_state_dict, checkpoint_path) + + # Wait for completion + write_future.result() + + # Verify checkpoint file was created + expected_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file)) + + executor.shutdown(wait=True) + checkpoint_process.close() + + def test_checkpoint_write_with_kwargs(self) -> None: + """Test checkpoint writing with additional kwargs.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for initialization + checkpoint_process.process_creation_future.result() + + with tempfile.TemporaryDirectory() as temp_dir: + checkpoint_path = os.path.join(temp_dir, "test_checkpoint") + + # Write checkpoint with kwargs + future = checkpoint_process.write( + self.test_state_dict, + checkpoint_path, + custom_arg="test_value", + another_arg=42, + ) + + # Wait for completion + future.result() + + # Verify checkpoint was created + expected_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file)) + + checkpoint_process.close() + + def test_subprocess_initialization_timeout(self) -> None: + """Test subprocess initialization timeout.""" + + # Create checkpoint process with a very short timeout by mocking the initialization + checkpoint_process = self._create_checkpoint_process( + subprocess_init_fn_override=timedout_subprocess_init_fn, + subprocess_init_timeout_secs=1, + ) + + # This should timeout + with self.assertRaises(TimeoutError) as cm: + checkpoint_process.process_creation_future.result() + + self.assertIn("Timed out", str(cm.exception)) + + def test_subprocess_initialization_failure(self) -> None: + """Test subprocess initialization failure.""" + checkpoint_process = self._create_checkpoint_process( + subprocess_init_fn_override=failing_subprocess_init_fn + ) + + # The subprocess should fail to initialize + # We expect this to raise an exception when we try to use it + with self.assertRaises(RuntimeError): + checkpoint_process.process_creation_future.result() + + def test_graceful_termination(self) -> None: + """Test graceful termination of subprocess.""" + checkpoint_process = self._create_checkpoint_process() + + checkpoint_process.process_creation_future.result() + self.assertTrue(checkpoint_process.process.processes[0].is_alive()) + checkpoint_process.close() + self.assertFalse(checkpoint_process.process.processes[0].is_alive()) + + def test_forced_termination(self) -> None: + """Test forced termination when graceful termination fails.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for initialization + checkpoint_process.process_creation_future.result() + + # Mock the join method to simulate timeout + def mock_join(timeout=None): + # Acknowledge timeout parameter to avoid unused variable warning + _ = timeout + return False # Simulate timeout + + checkpoint_process.process.join = mock_join + + # This should trigger forced termination + checkpoint_process.close() + + # Process should still be terminated (killed) + # Note: This test might be flaky depending on timing + + def test_communication_error_handling(self): + """Test handling of communication errors.""" + checkpoint_process = self._create_checkpoint_process() + + # Wait for initialization + checkpoint_process.process_creation_future.result() + + # Close the pipe to simulate communication failure + checkpoint_process._parent_end.close() + + # Attempting to write should raise an error + with self.assertRaises(RuntimeError) as cm: + future = checkpoint_process.write(self.test_state_dict, "/tmp/test") + future.result() + + self.assertIn("Child process terminated unexpectedly", str(cm.exception)) + + def test_shared_memory_tensor_ipc(self): + """Test that shared memory tensors are backed by the same memory across processes.""" + + checkpoint_process = self._create_checkpoint_process( + writer_init_fn_override=shared_tensor_verifier_init_fn, + ) + + checkpoint_process.process_creation_future.result() + + # Create tensors and put them in shared memory + shared_tensor = torch.randn(100, 100) + shared_tensor.share_memory_() + + shared_tensor_data_ptr = shared_tensor.data_ptr() + + regular_tensor = torch.randn(50, 50) + # Don't put regular tensor in shared memory for comparison + + # Verify initial shared memory status + self.assertTrue( + shared_tensor.is_shared(), "Shared tensor should be in shared memory" + ) + self.assertFalse( + regular_tensor.is_shared(), "Regular tensor should not be in shared memory" + ) + + # Create state dict with mixed tensor types + test_state_dict = { + "shared_tensor": shared_tensor, + "regular_tensor": regular_tensor, + } + + # Write to subprocess - the SharedTensorVerifier will: + # 1. Verify the tensor is still in shared memory + # 2. Check the marker value (42.0) to confirm same memory + # 3. Modify specific positions to prove same memory access + future = checkpoint_process.write(test_state_dict, "") + + try: + result = ( + future.result() + ) # This will raise an exception if the subprocess assertions fail + self.assertIsNone(result) # SharedTensorVerifier returns None on success + except Exception as e: + self.fail(f"Subprocess assertions failed: {e}") + + # assert shared tensor is still in same shared memory + self.assertEqual( + shared_tensor_data_ptr, + shared_tensor.data_ptr(), + "Shared tensor should still be in same shared memory", + ) + self.assertTrue( + shared_tensor.is_shared(), "Shared tensor should still be in shared memory" + ) + + # CRITICAL TEST: Verify that modifications made by subprocess are visible in main process + # This definitively proves that both processes access the same memory + + self.assertAlmostEqual( + shared_tensor[0][0], + 42.0, + places=6, + msg=f"Expected subprocess signature 42.0, got {shared_tensor[0]}. " + f"Shared memory not working - subprocess modifications not visible!", + ) + + checkpoint_process.close() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py b/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py new file mode 100644 index 00000000000000..88feb0bffee5db --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py @@ -0,0 +1,250 @@ +# Owner(s): ["oncall: distributed checkpointing"] +import os +import shutil +import tempfile +from typing import Any + +import torch +from torch.distributed.checkpoint._experimental.checkpoint_reader import ( + CheckpointReader, +) +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestCheckpointReader(TestCase): + def setUp(self): + # Create a temporary directory for test checkpoints + self.temp_dir = tempfile.mkdtemp() + + # Create test objects + self.rank_info = RankInfo( + global_rank=0, + global_world_size=1, + ) + + # Create the checkpoint reader + self.reader = CheckpointReader( + rank_info=self.rank_info, + ) + + # Create a test state dictionary + self.state_dict = { + "model": { + "weight": torch.randn(10, 5), + "bias": torch.randn(5), + "test_list": [torch.randn(2), torch.randn(2)], + }, + "optimizer": { + "param_groups": [ + {"lr": 0.01, "test_list": [torch.randn(2), torch.randn(2)]} + ] + }, + "epoch": 5, + "step": 1000, + } + + # Create a test checkpoint file + self.checkpoint_path = os.path.join(self.temp_dir, "checkpoint") + os.makedirs(self.checkpoint_path, exist_ok=True) + checkpoint_file = os.path.join( + self.checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + torch.save(self.state_dict, checkpoint_file) + + def move_tensors_to_device(self, state_dict: Any, device: str) -> Any: + """ + Recursively move all tensors in a nested dictionary to CUDA. + + Args: + state_dict (dict): A dictionary potentially containing nested dictionaries and tensors. + + Returns: + dict: A new dictionary with all tensors moved to CUDA. + """ + if isinstance(state_dict, dict): + return { + key: self.move_tensors_to_device(value, device) + for key, value in state_dict.items() + } + elif isinstance(state_dict, list): + return [self.move_tensors_to_device(item, device) for item in state_dict] + elif isinstance(state_dict, torch.Tensor): + return state_dict.cuda() if device == "cpu" else state_dict.cpu() + else: + return state_dict + + def deep_compare(self, obj1: Any, obj2: Any) -> bool: + if isinstance(obj1, dict) and isinstance(obj2, dict): + if obj1.keys() != obj2.keys(): + return False + return all(self.deep_compare(obj1[key], obj2[key]) for key in obj1) + elif isinstance(obj1, (list, tuple)) and isinstance(obj2, (list, tuple)): + if len(obj1) != len(obj2): + return False + return all( + self.deep_compare(item1, item2) for item1, item2 in zip(obj1, obj2) + ) + elif isinstance(obj1, torch.Tensor) and isinstance(obj2, torch.Tensor): + return torch.equal(obj1, obj2) + else: + return obj1 == obj2 + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.temp_dir) + + def test_read_checkpoint(self): + """Test that read correctly reads a checkpoint file.""" + + # Call read + read_state_dict, missing_keys = self.reader.read(self.checkpoint_path) + self.assertEqual(missing_keys, []) + + # Verify that the read state dictionary contains the expected values + self.assertIn("model", read_state_dict) + self.assertIn("optimizer", read_state_dict) + self.assertTrue(self.deep_compare(read_state_dict, self.state_dict)) + + # No hooks to verify since we removed them + + def test_read_with_map_location(self): + """Test that read correctly uses the map_location parameter.""" + # Call read with map_location='cpu' + map_location = "cuda" if torch.cuda.is_available() else "cpu" + read_state_dict, _ = self.reader.read( + self.checkpoint_path, map_location=map_location + ) + + # Verify that the read state dictionary contains the expected values + self.assertIn("model", read_state_dict) + self.assertIn("optimizer", read_state_dict) + self.assertEqual(read_state_dict["epoch"], 5) + self.assertEqual(read_state_dict["step"], 1000) + self.assertEqual(read_state_dict["model"]["weight"].device.type, map_location) + read_state_dict, _ = self.reader.read( + self.checkpoint_path, map_location=map_location + ) + + # Verify that the read state dictionary contains the expected values + self.assertIn("model", read_state_dict) + self.assertIn("optimizer", read_state_dict) + self.assertEqual(read_state_dict["epoch"], 5) + self.assertEqual(read_state_dict["step"], 1000) + self.assertEqual(read_state_dict["model"]["weight"].device.type, map_location) + + def test_read_nonexistent_checkpoint(self): + """Test that read raises FileNotFoundError for a nonexistent checkpoint.""" + # Set up a path to a nonexistent checkpoint + nonexistent_path = os.path.join(self.temp_dir, "nonexistent_checkpoint") + + # Call read and expect a FileNotFoundError + with self.assertRaises(FileNotFoundError): + self.reader.read(nonexistent_path) + + def test_read_with_kwargs(self): + """Test that read correctly passes kwargs.""" + # Call read with additional kwargs + kwargs = {"extra": "value"} + self.reader.read(self.checkpoint_path, **kwargs) + + def test_partial_read(self): + """Test that read with state_dict correctly loads only the requested keys.""" + # Create a partial state dictionary with only some keys + partial_state_dict = {} + partial_state_dict["optimizer"] = None + partial_state_dict["model"] = {"weight": torch.randn(10, 5)} + partial_state_dict["epoch"] = None + # Call read with state_dict + updated_state_dict, _ = self.reader.read( + self.checkpoint_path, + partial_state_dict, + ) + + # Verify that the updated state dictionary contains values from both dictionaries + self.assertIn("model", updated_state_dict) + self.assertIn("epoch", updated_state_dict) + self.assertTrue( + torch.equal( + updated_state_dict["model"]["weight"], + self.state_dict["model"]["weight"], + ) + ) + + self.assertTrue( + self.deep_compare( + updated_state_dict["optimizer"], self.state_dict["optimizer"] + ) + ) + self.assertEqual(updated_state_dict["epoch"], 5) # From checkpoint + + self.assertNotIn("bias", updated_state_dict["model"]) + self.assertNotIn("step", updated_state_dict) + + def test_partial_read_missing_keys(self): + """Test that partial_read correctly reports missing keys.""" + # Create a partial state dictionary with keys that don't exist in the checkpoint + partial_state_dict = { + "model": None, + "nonexistent_key": None, # This key doesn't exist in the checkpoint + "another_missing_key": {"nested": None}, # This key also doesn't exist + } + + # Call read with state_dict + _, missing_keys = self.reader.read( + self.checkpoint_path, + partial_state_dict, + ) + + # Verify that missing keys are correctly reported + self.assertIn("nonexistent_key", missing_keys) + self.assertIn("another_missing_key", missing_keys) + + # Verify that keys that exist in the checkpoint are not in missing_keys + self.assertNotIn("model", missing_keys) + + def test_partial_read_different_dtypes(self): + """Test that partial_read correctly handles different tensor dtypes.""" + # Create a state dictionary with tensors of different dtypes + dtype_state_dict = { + "float32": torch.randn(10, 10, dtype=torch.float32), + "float64": torch.randn(10, 10, dtype=torch.float64), + "int32": torch.randint(-100, 100, (10, 10), dtype=torch.int32), + "int64": torch.randint(-100, 100, (10, 10), dtype=torch.int64), + "bool": torch.randint(0, 2, (10, 10), dtype=torch.bool), + } + + # Save the state dictionary + dtype_checkpoint_path = os.path.join(self.temp_dir, "dtype_checkpoint") + os.makedirs(dtype_checkpoint_path, exist_ok=True) + checkpoint_file = os.path.join( + dtype_checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + torch.save(dtype_state_dict, checkpoint_file) + + # Create a partial state dictionary requesting tensors of each dtype + partial_state_dict = { + "float32": torch.randn(10, 10, dtype=torch.float32), + "float64": None, + "int32": None, + "int64": None, + "bool": None, + } + + # Load the partial state dictionary + updated_state_dict, _ = self.reader.read( + os.path.dirname(checkpoint_file), + partial_state_dict, + ) + + # Verify that tensors of each dtype were loaded correctly + for key in dtype_state_dict: + self.assertIn(key, updated_state_dict) + self.assertEqual(updated_state_dict[key].dtype, dtype_state_dict[key].dtype) + self.assertTrue( + torch.allclose(updated_state_dict[key], dtype_state_dict[key]) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py b/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py new file mode 100644 index 00000000000000..ce3945c455abdd --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py @@ -0,0 +1,200 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import os +import shutil +import tempfile +from typing import Any, Optional +from unittest.mock import MagicMock + +import torch +from torch.distributed.checkpoint._experimental.checkpoint_writer import ( + CheckpointWriter, + CheckpointWriterConfig, + WriterHook, +) +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +class MockWriterHook(WriterHook): + """Mock implementation of WriterHook for testing.""" + + def __init__(self): + self.pre_commit_called = False + self.commit_called = False + self.pre_commit_path: Optional[str] = None + self.commit_path: Optional[str] = None + self.pre_commit_kwargs: Optional[dict[str, Any]] = None + self.commit_kwargs: Optional[dict[str, Any]] = None + + def pre_commit(self, path: str, **kwargs: Any): + self.pre_commit_called = True + self.pre_commit_path = path + self.pre_commit_kwargs = kwargs + + def post_commit(self, path: str, **kwargs: Any): + self.commit_called = True + self.commit_path = path + self.commit_kwargs = kwargs + + +class TestCheckpointWriterConfig(TestCase): + def test_default_values(self): + """Test that CheckpointWriterConfig has the correct default values.""" + options = CheckpointWriterConfig() + self.assertEqual(options.write_barrier_timeout_secs, 600) + + def test_custom_values(self): + """Test that CheckpointWriterConfig can be initialized with custom values.""" + options = CheckpointWriterConfig(write_barrier_timeout_secs=300) + self.assertEqual(options.write_barrier_timeout_secs, 300) + + +class TestCheckpointWriter(TestCase): + def setUp(self): + # Create a temporary directory for test checkpoints + self.temp_dir = tempfile.mkdtemp() + + # Create test objects + self.rank_info = RankInfo( + global_rank=0, + global_world_size=1, + ) + self.options = CheckpointWriterConfig() + self.mock_barrier = MagicMock() + self.mock_hook = MockWriterHook() + + # Create the checkpoint writer + self.writer = CheckpointWriter( + config=self.options, + rank_info=self.rank_info, + barrier=self.mock_barrier, + commit_hook=self.mock_hook, + ) + + # Create a test state dictionary + self.state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + "step": 1000, + } + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.temp_dir) + + def test_write_creates_checkpoint_file(self): + """Test that write creates a checkpoint file with the correct content.""" + # Set up the checkpoint path + checkpoint_path = os.path.join(self.temp_dir, "checkpoint") + + # Call write + self.writer.write(self.state_dict, checkpoint_path) + + # Verify that the checkpoint file exists + expected_file_path = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file_path)) + + # Load the checkpoint and verify its contents + loaded_state_dict = torch.load(expected_file_path) + self.assertIn("model", loaded_state_dict) + self.assertIn("optimizer", loaded_state_dict) + self.assertEqual(loaded_state_dict["epoch"], 5) + self.assertEqual(loaded_state_dict["step"], 1000) + + def test_write_calls_barrier(self): + """Test that write calls the barrier with the correct parameters.""" + # Set up the checkpoint path + checkpoint_path = os.path.join(self.temp_dir, "checkpoint") + + # Call write + self.writer.write(self.state_dict, checkpoint_path) + + # Verify that the barrier was called + self.mock_barrier.execute_barrier.assert_called_once() + + def test_write_calls_commit_hooks(self): + """Test that write calls the commit hooks with the correct parameters.""" + # Set up the checkpoint path + checkpoint_path = os.path.join(self.temp_dir, "checkpoint") + + # Call write with additional kwargs + kwargs = {"extra": "value"} + self.writer.write(self.state_dict, checkpoint_path, **kwargs) + + # Verify that the pre_commit hook was called with the correct parameters + self.assertTrue(self.mock_hook.pre_commit_called) + self.assertEqual(self.mock_hook.pre_commit_path, checkpoint_path) + self.assertEqual( + self.mock_hook.pre_commit_kwargs is not None + and self.mock_hook.pre_commit_kwargs["extra"], + "value", + ) + + # Verify that the commit hook was called with the correct parameters + self.assertTrue(self.mock_hook.commit_called) + self.assertEqual(self.mock_hook.commit_path, checkpoint_path) + self.assertEqual( + self.mock_hook.commit_kwargs is not None + and self.mock_hook.commit_kwargs["extra"], + "value", + ) + + def test_write_without_barrier(self): + """Test that write works correctly without a barrier.""" + # Create a writer without a barrier + writer = CheckpointWriter( + config=self.options, + rank_info=self.rank_info, + barrier=None, + commit_hook=self.mock_hook, + ) + + # Set up the checkpoint path + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_no_barrier") + + # Call write + writer.write(self.state_dict, checkpoint_path) + + # Verify that the checkpoint file exists + expected_file_path = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file_path)) + + def test_write_without_commit_hook(self): + """Test that write works correctly without a commit hook.""" + # Create a writer without a commit hook + writer = CheckpointWriter( + config=self.options, + rank_info=self.rank_info, + barrier=self.mock_barrier, + commit_hook=None, + ) + + # Set up the checkpoint path + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_no_hook") + + # Call write + writer.write(self.state_dict, checkpoint_path) + + # Verify that the checkpoint file exists + expected_file_path = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(expected_file_path)) + + # Verify that the barrier was still called + self.mock_barrier.execute_barrier.assert_called_once() + + def test_close(self): + """Test that close doesn't raise any exceptions.""" + # This is a no-op in the base class, so just verify it doesn't raise + self.writer.close() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_checkpointer.py b/test/distributed/checkpoint/_experimental/test_checkpointer.py new file mode 100644 index 00000000000000..e2c030385c89df --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_checkpointer.py @@ -0,0 +1,182 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import os +import shutil +import tempfile + +import torch +from torch.distributed.checkpoint._experimental.checkpoint_reader import ( + CheckpointReader, +) +from torch.distributed.checkpoint._experimental.checkpoint_writer import ( + CheckpointWriter, + CheckpointWriterConfig, +) +from torch.distributed.checkpoint._experimental.checkpointer import SyncCheckpointer +from torch.distributed.checkpoint._experimental.types import RankInfo +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestSyncCheckpointer(TestCase): + def setUp(self): + # Create a temporary directory for checkpoints + self.temp_dir = tempfile.mkdtemp() + + # Create real objects for testing + self.rank_info = RankInfo( + global_world_size=1, + global_rank=0, + ) + self.writer_config = CheckpointWriterConfig() + self.writer = CheckpointWriter( + config=self.writer_config, + rank_info=self.rank_info, + ) + + # Create reader for testing + self.reader = CheckpointReader( + rank_info=self.rank_info, + ) + + # Create sync checkpointer + self.checkpointer = SyncCheckpointer(self.writer, self.reader) + + # Create a test state dictionary + self.state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + "step": 1000, + } + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.temp_dir) + + def test_sync_save_and_read(self): + """Test saving and reading a checkpoint synchronously.""" + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_sync") + + # Save the checkpoint synchronously + result = self.checkpointer.save(self.state_dict, checkpoint_path) + self.assertIsNone(result) # Sync mode returns None + + # Verify that the checkpoint file exists + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_{self.rank_info.global_rank}.pt" + ) + self.assertTrue(os.path.exists(checkpoint_file)) + + # Load the checkpoint using the checkpointer + loaded_state_dict = self.checkpointer.load(checkpoint_path) + + # Verify the loaded state dictionary + self.assertIn("model", loaded_state_dict) + self.assertIn("optimizer", loaded_state_dict) + self.assertEqual(loaded_state_dict["epoch"], 5) + self.assertEqual(loaded_state_dict["step"], 1000) + + def test_read_with_map_location(self): + """Test reading a checkpoint with a specific map_location.""" + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_map_location") + + # Save the checkpoint + self.checkpointer.save(self.state_dict, checkpoint_path) + + # Load the checkpoint with map_location='cpu' + loaded_state_dict = self.checkpointer.load( + checkpoint_path, default_map_location="cpu" + ) + + # Verify the loaded state dictionary + self.assertIn("model", loaded_state_dict) + self.assertIn("optimizer", loaded_state_dict) + self.assertEqual(loaded_state_dict["epoch"], 5) + self.assertEqual(loaded_state_dict["step"], 1000) + + def test_partial_load(self): + """Test loading only specific keys from a checkpoint.""" + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_partial") + + # Save the full checkpoint + self.checkpointer.save(self.state_dict, checkpoint_path) + + # Create a partial state dictionary with only some keys + partial_state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "epoch": None, # Will be loaded from checkpoint + } + + # Load only the keys in partial_state_dict + loaded_state_dict = self.checkpointer.load( + checkpoint_path, state_dict=partial_state_dict, default_map_location="cpu" + ) + + # Verify that the loaded state dictionary contains values from the checkpoint + self.assertIn("model", loaded_state_dict) + self.assertIn("epoch", loaded_state_dict) + self.assertEqual(loaded_state_dict["epoch"], 5) # From checkpoint + + # Verify that keys not in the partial_state_dict are not loaded + self.assertNotIn("step", loaded_state_dict) + self.assertNotIn("optimizer", loaded_state_dict) + + # Verify that the loaded state dictionary is the same object as the input + self.assertIs(loaded_state_dict, partial_state_dict) + + def test_partial_load_with_nested_dict(self): + """Test loading only specific nested keys from a checkpoint.""" + # Create a checkpoint with nested dictionaries + nested_state_dict = { + "model": { + "layer1": {"weight": torch.randn(5, 10), "bias": torch.randn(5)}, + "layer2": {"weight": torch.randn(2, 5), "bias": torch.randn(2)}, + }, + "metadata": {"epoch": 10, "step": 2000}, + } + + checkpoint_path = os.path.join(self.temp_dir, "checkpoint_nested") + + # Create a writer and save the nested state dict + writer = CheckpointWriter( + config=self.writer_config, + rank_info=self.rank_info, + ) + writer.write(nested_state_dict, checkpoint_path) + + # Create a partial state dictionary with nested structure + partial_state_dict = { + "model": { + "layer1": {"weight": None}, # Only request layer1.weight + }, + "metadata": {"epoch": None}, # Only request metadata.epoch + } + + # Load only the keys in partial_state_dict + loaded_state_dict = self.checkpointer.load( + checkpoint_path, state_dict=partial_state_dict, default_map_location="cpu" + ) + + # Verify that the nested keys were correctly loaded + self.assertIn("model", loaded_state_dict) + self.assertIn("layer1", loaded_state_dict["model"]) + self.assertIn("weight", loaded_state_dict["model"]["layer1"]) + self.assertIn("metadata", loaded_state_dict) + self.assertIn("epoch", loaded_state_dict["metadata"]) + + # Verify values were loaded correctly + self.assertTrue( + torch.allclose( + loaded_state_dict["model"]["layer1"]["weight"], + nested_state_dict["model"]["layer1"]["weight"], + ) + ) + self.assertEqual(loaded_state_dict["metadata"]["epoch"], 10) + + # Verify that keys not in the partial_state_dict are not loaded + self.assertNotIn("layer2", loaded_state_dict["model"]) + self.assertNotIn("step", loaded_state_dict["metadata"]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py new file mode 100644 index 00000000000000..0eeba5d63524d5 --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -0,0 +1,216 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +from concurrent.futures import Future + +import torch +from torch.distributed.checkpoint._experimental.staging import ( + CheckpointStagerConfig, + DefaultStager, +) +from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase + + +class TestDefaultStager(TestCase): + def setUp(self) -> None: + # Create a test state dictionary with various data types + self.state_dict = { + "model": torch.nn.Linear(10, 5).state_dict(), + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + "step": 1000, + "tensor": torch.randn(3, 4), + "nested": {"inner_tensor": torch.ones(2, 2), "inner_value": 42}, + } + + @requires_cuda + def test_sync_staging(self) -> None: + """Test synchronous staging.""" + options = CheckpointStagerConfig(use_async_staging=False) + stager = DefaultStager(options) + + # Stage the state dict + staged_dict = stager.stage(self.state_dict) + + # Verify that a state dict is returned (not a Future) + self.assertIsInstance(staged_dict, dict) + + # Verify the staged state dictionary + self.assertIn("model", staged_dict) + self.assertIn("optimizer", staged_dict) + self.assertEqual(staged_dict["epoch"], 5) + self.assertEqual(staged_dict["step"], 1000) + self.assertIn("tensor", staged_dict) + self.assertIn("nested", staged_dict) + + # Clean up + stager.close() + + @requires_cuda + def test_async_staging(self) -> None: + """Test asynchronous staging.""" + options = CheckpointStagerConfig(use_async_staging=True) + stager = DefaultStager(options) + + # Stage the state dict + result = stager.stage(self.state_dict) + + # Verify that a Future is returned + self.assertIsInstance(result, Future) + + # Wait for the Future to complete + staged_dict = result.result() + + # Verify the staged state dictionary + self.assertIn("model", staged_dict) + self.assertIn("optimizer", staged_dict) + self.assertEqual(staged_dict["epoch"], 5) + self.assertEqual(staged_dict["step"], 1000) + + # Clean up + stager.close() + + def test_cuda_non_blocking_without_cuda(self) -> None: + """Test that non-blocking copy fails when CUDA is not available.""" + if torch.cuda.is_available(): + self.skipTest("CUDA is available, cannot test CUDA unavailable scenario") + + options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True) + with self.assertRaises(AssertionError): + DefaultStager(options) + + def test_different_option_combinations(self) -> None: + """Test various combinations of staging options.""" + test_cases = [ + # All disabled + CheckpointStagerConfig( + use_pinned_memory=False, + use_shared_memory=False, + use_async_staging=False, + use_cuda_non_blocking_copy=False, + ), + # Only pinned memory + CheckpointStagerConfig( + use_pinned_memory=True, + use_shared_memory=False, + use_async_staging=False, + use_cuda_non_blocking_copy=False, + ), + # Only shared memory + CheckpointStagerConfig( + use_pinned_memory=False, + use_shared_memory=True, + use_async_staging=False, + use_cuda_non_blocking_copy=False, + ), + ] + + if torch.cuda.is_available(): + # Only async staging + test_cases.append( + CheckpointStagerConfig( + use_pinned_memory=torch.cuda.is_available(), + use_shared_memory=False, + use_async_staging=True, + use_cuda_non_blocking_copy=False, + ) + ) + # Only CUDA non-blocking copy + test_cases.append( + CheckpointStagerConfig( + use_pinned_memory=torch.cuda.is_available(), + use_shared_memory=False, + use_async_staging=False, + use_cuda_non_blocking_copy=torch.cuda.is_available(), + ) + ) + + for options in test_cases: + with self.subTest(options=options): + stager = DefaultStager(options) + + # Test staging works with these options + if options.use_async_staging and torch.cuda.is_available(): + result = stager.stage(self.state_dict) + self.assertIsInstance(result, Future) + staged_dict = result.result() + else: + staged_dict = stager.stage(self.state_dict) + + self.assertIsInstance(staged_dict, dict) + self.assertIn("model", staged_dict) + + stager.close() + + @requires_cuda + def test_cuda_tensors_staging(self) -> None: + """Test staging with CUDA tensors.""" + # Create state dict with CUDA tensors + cuda_state_dict = { + "cuda_tensor": torch.randn(3, 4).cuda(), + "cpu_tensor": torch.randn(2, 3), + "mixed_model": { + "weight": torch.randn(5, 5).cuda(), + "bias": torch.randn(5).cuda(), + }, + } + + options = CheckpointStagerConfig(use_async_staging=False) + stager = DefaultStager(options) + + staged_dict = stager.stage(cuda_state_dict) + assert isinstance(staged_dict, dict) + + # Verify tensors are staged (should be moved to CPU) + self.assertIn("cuda_tensor", staged_dict) + self.assertIn("cpu_tensor", staged_dict) + self.assertIn("mixed_model", staged_dict) + + stager.close() + + @requires_cuda + def test_resource_cleanup(self) -> None: + """Test that resources are properly cleaned up.""" + options = CheckpointStagerConfig(use_async_staging=False) + stager = DefaultStager(options) + + # Verify initial state + self.assertIsNotNone(stager._state_dict_stager) + + # Close and verify cleanup + stager.close() + + def test_multiple_staging_operations(self) -> None: + """Test multiple staging operations with the same stager.""" + options = CheckpointStagerConfig( + use_async_staging=False, + use_pinned_memory=torch.cuda.is_available(), + use_shared_memory=False, + use_cuda_non_blocking_copy=torch.cuda.is_available(), + ) + stager = DefaultStager(options) + + # Stage multiple different state dicts + state_dicts = [ + {"model1": torch.nn.Linear(5, 3).state_dict()}, + {"model2": torch.nn.Conv2d(3, 16, 3).state_dict()}, + {"optimizer": {"lr": 0.001, "momentum": 0.9}}, + ] + + staged_results = [] + for state_dict in state_dicts: + staged_dict = stager.stage(state_dict) + staged_results.append(staged_dict) + + # Verify all staging operations succeeded + self.assertEqual(len(staged_results), 3) + for i, result in enumerate(staged_results): + self.assertIsInstance(result, dict) + # Verify the result contains the expected keys + for key in state_dicts[i].keys(): + self.assertIn(key, result) + + stager.close() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/_experimental/test_types.py b/test/distributed/checkpoint/_experimental/test_types.py new file mode 100644 index 00000000000000..6f67f619b768c0 --- /dev/null +++ b/test/distributed/checkpoint/_experimental/test_types.py @@ -0,0 +1,44 @@ +# Owner(s): ["oncall: distributed checkpointing"] + + +from torch.distributed.checkpoint._experimental.types import RankInfo, STATE_DICT +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestRankInfo(TestCase): + def test_rank_info_initialization(self): + """Test that RankInfo initializes correctly with all parameters.""" + # Create a RankInfo instance with all parameters + rank_info = RankInfo( + global_rank=0, + global_world_size=4, + ) + + # Verify that all attributes are set correctly + self.assertEqual(rank_info.global_rank, 0) + self.assertEqual(rank_info.global_world_size, 4) + + def test_rank_info_default_initialization(self): + """Test that RankInfo initializes correctly with default parameters.""" + # Create a RankInfo instance with only required parameters + rank_info = RankInfo( + global_rank=0, + global_world_size=1, + ) + + # Verify that all attributes are set correctly + self.assertEqual(rank_info.global_rank, 0) + self.assertEqual(rank_info.global_world_size, 1) + + def test_state_dict_type_alias(self): + """Test that STATE_DICT type alias works correctly.""" + # Create a state dictionary + state_dict = {"model": {"weight": [1, 2, 3]}, "optimizer": {"lr": 0.01}} + + # Verify that it can be assigned to a variable of type STATE_DICT + state_dict_var: STATE_DICT = state_dict + self.assertEqual(state_dict_var, state_dict) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py index 2e59c1e4fdd285..c2e37850d9d708 100644 --- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py +++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import time +from concurrent.futures import Future from dataclasses import dataclass, field from enum import auto, Enum from functools import partial @@ -13,6 +14,7 @@ import torch.distributed.checkpoint.state_dict_saver as saver import torch.nn as nn import torch.nn.functional as F +from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, @@ -22,7 +24,10 @@ set_state_dict, ) from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys -from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType +from torch.distributed.checkpoint.state_dict_saver import ( + AsyncCheckpointerType, + AsyncSaveResponse, +) from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.device_mesh import init_device_mesh @@ -216,21 +221,26 @@ def test_e2e(self, compile, model_type): @skip_if_lt_x_gpu(4) @with_temp_dir @parametrize( - "cache_staged_state_dict, async_checkpointer_type", + "cache_staged_state_dict, async_checkpointer_type, zoc", [ - (False, AsyncCheckpointerType.THREAD), - (True, AsyncCheckpointerType.THREAD), - (False, AsyncCheckpointerType.PROCESS), - (True, AsyncCheckpointerType.PROCESS), + (False, AsyncCheckpointerType.THREAD, False), + (True, AsyncCheckpointerType.THREAD, False), + (False, AsyncCheckpointerType.PROCESS, False), + (True, AsyncCheckpointerType.PROCESS, False), + (False, AsyncCheckpointerType.PROCESS, True), + (False, AsyncCheckpointerType.THREAD, True), ], ) - def test_e2e_async_cached(self, cache_staged_state_dict, async_checkpointer_type): + def test_e2e_async_cached( + self, cache_staged_state_dict, async_checkpointer_type, zoc + ): self._run_e2e_test( compile=False, model_type=ModelType.FSDP, async_op=True, cache_staged_state_dict=cache_staged_state_dict, async_checkpointer_type=async_checkpointer_type, + zoc=zoc, ) def _run_e2e_test( @@ -240,6 +250,7 @@ def _run_e2e_test( async_op=False, cache_staged_state_dict=False, async_checkpointer_type=None, + zoc=False, ): model, optim = self._create_model(compile, ModelType.NONE) _train(model, optim, train_steps=2) @@ -259,7 +270,19 @@ def _run_e2e_test( writer = DCP.FileSystemWriter( self.temp_dir, cache_staged_state_dict=cache_staged_state_dict ) - f = saver.async_save( + stager = None + if not cache_staged_state_dict: + use_shared_memory = ( + async_checkpointer_type == AsyncCheckpointerType.PROCESS + ) + staging_options = StagingOptions( + use_async_staging=zoc, + use_shared_memory=use_shared_memory, + use_pinned_memory=zoc, + use_cuda_non_blocking_copy=zoc, + ) + stager = DefaultStager(staging_options) + async_save_response_or_future = saver.async_save( sd, storage_writer=writer, async_checkpointer_type=( @@ -267,13 +290,20 @@ def _run_e2e_test( if async_checkpointer_type else AsyncCheckpointerType.THREAD ), + async_stager=stager, ) + if isinstance(async_save_response_or_future, Future): + save_future = async_save_response_or_future + else: + assert isinstance(async_save_response_or_future, AsyncSaveResponse) + save_future = async_save_response_or_future.upload_completion + # wait for the future to complete t = time.monotonic() - while not f.done(): + while not save_future.done(): time.sleep(1) print(f"still waiting... {time.monotonic() - t}") - f.result() + save_future.result() else: DCP.save(sd, checkpoint_id=self.temp_dir) diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py new file mode 100644 index 00000000000000..ba07c62728d71f --- /dev/null +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -0,0 +1,158 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import importlib +import json +import os + +import torch +import torch.distributed.checkpoint as dist_cp +from torch import distributed as dist +from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + consolidate_safetensors_files, +) +from torch.distributed.checkpoint._hf_utils import _metadata_fn +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import DTensor, Shard +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + skip_if_lt_x_gpu, + with_comms, +) +from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir + + +class TestConsolidateHFSafeTensors(DTensorTestBase): + def _create_d_tensors(self) -> None: + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (self.world_size,) + mesh_1d = init_device_mesh(self.device_type, mesh_shape) + + # Create local tensor with row-wise sharding + rows_per_rank = global_tensor.shape[0] // self.world_size + start_row = self.rank * rows_per_rank + end_row = start_row + rows_per_rank + local_tensor = global_tensor[start_row:end_row].clone() + + # Create DTensor with row-wise sharding + dtensor = DTensor.from_local( + local_tensor, + device_mesh=mesh_1d, + placements=[Shard(0)], + shape=global_tensor.shape, + stride=(4, 1), + ) + + # Create local tensor with column-wise sharding + cols_per_rank = global_tensor.shape[1] // self.world_size + start_col = self.rank * cols_per_rank + end_col = start_col + cols_per_rank + local_tensor_col = global_tensor[:, start_col:end_col].clone() + + # Create DTensor with column-wise sharding + dtensor_col = DTensor.from_local( + local_tensor_col, + device_mesh=mesh_1d, + placements=[Shard(1)], # Column-wise sharding + shape=global_tensor.shape, + stride=(4, 1), + ) + + state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col} + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=self.temp_dir, save_distributed=True + ), + ) + dist.barrier() + os.sync() + + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_consolidate_to_one_file(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + import safetensors + + checkpoint_dir = self.temp_dir + output_dir = os.path.join(checkpoint_dir, "consolidated") + os.makedirs(output_dir, exist_ok=True) + + self._create_d_tensors() + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + + if self.rank == 0: + consolidate_safetensors_files(checkpoint_dir, output_dir) + + file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors") + loaded_dict = safetensors.torch.load_file(file_path) + self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"}) + self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) + self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor)) + + with open(os.path.join(output_dir, _metadata_fn)) as f: + metadata = json.load(f) + self.assertEqual(metadata["metadata"]["total_size"], 16 * 4 * 2) + self.assertEqual( + metadata["weight_map"], + { + "dtensor": "model-00001-of-00001.safetensors", + "dtensor_col": "model-00001-of-00001.safetensors", + }, + ) + + dist.barrier() + + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_consolidate_to_two_files(self): + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + import safetensors + + checkpoint_dir = self.temp_dir + output_dir = os.path.join(checkpoint_dir, "consolidated") + os.makedirs(output_dir, exist_ok=True) + + self._create_d_tensors() + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + + if self.rank == 0: + fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2} + consolidate_safetensors_files( + checkpoint_dir, output_dir, fqn_to_index_mapping=fqn_to_index_mapping + ) + + file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors") + file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors") + + loaded_dict = safetensors.torch.load_file(file1_path) + self.assertEqual(loaded_dict.keys(), {"dtensor"}) + self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) + + loaded_dict_col = safetensors.torch.load_file(file2_path) + self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"}) + self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor)) + + with open(os.path.join(output_dir, _metadata_fn)) as f: + metadata = json.load(f) + self.assertEqual(metadata["metadata"]["total_size"], 16 * 4 * 2) + self.assertEqual( + metadata["weight_map"], + { + "dtensor": "model-00001-of-00002.safetensors", + "dtensor_col": "model-00002-of-00002.safetensors", + }, + ) + dist.barrier() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py new file mode 100644 index 00000000000000..0220ae5138fc1b --- /dev/null +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -0,0 +1,461 @@ +# Owner(s): ["oncall: distributed checkpointing"] + +import importlib +import os + +import torch +import torch.distributed.checkpoint as dist_cp +from torch import distributed as dist +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard, zeros +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + TestCase, +) +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + skip_if_lt_x_gpu, + with_comms, +) +from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir + + +CHECKPOINT_DIR = "checkpoint" + + +class MyTestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = torch.nn.Linear(5, 5) + self.linear_2 = torch.nn.Linear(5, 1) + self.emb = torch.nn.EmbeddingBag(5, 10) + + +class TestSingleRankSaveLoad(TestCase): + @with_temp_dir + def test_save(self) -> None: + try: + from safetensors.torch import load_file + except ImportError: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + + state_dict_to_save = MyTestModule().state_dict() + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter(path=CHECKPOINT_DIR), + ) + + state_dict_loaded = load_file( + CHECKPOINT_DIR + "/model-00001-of-00001.safetensors" + ) + self.assertEqual( + sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()) + ) + for key in state_dict_to_save.keys(): + self.assertTrue( + torch.equal(state_dict_to_save[key], state_dict_loaded[key]) + ) + + @with_temp_dir + def test_load(self) -> None: + try: + from safetensors.torch import save_file + except ImportError: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + + state_dict_to_save = MyTestModule().state_dict() + state_dict_to_load = MyTestModule().state_dict() + save_file( + state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors" + ) + + dist_cp.load( + state_dict=state_dict_to_load, + storage_reader=dist_cp.HuggingFaceStorageReader(path=CHECKPOINT_DIR), + ) + + self.assertEqual( + sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()) + ) + for key in state_dict_to_save.keys(): + self.assertTrue( + torch.equal(state_dict_to_save[key], state_dict_to_load[key]) + ) + + @with_temp_dir + def test_load_into_empty_dict(self) -> None: + try: + from safetensors.torch import save_file + except ImportError: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + + state_dict_to_save = MyTestModule().state_dict() + save_file( + state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors" + ) + + state_dict_loaded = _load_state_dict_from_keys( + storage_reader=dist_cp.HuggingFaceStorageReader(path=CHECKPOINT_DIR), + ) + + self.assertEqual( + sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()) + ) + for key in state_dict_to_save.keys(): + self.assertTrue( + torch.equal(state_dict_to_save[key], state_dict_loaded[key]) + ) + + +class TestDistributedHFSafetensorsConsolidation(DTensorTestBase): + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_consolidate_to_one_file(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + import safetensors + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (self.world_size,) + mesh_1d = init_device_mesh(self.device_type, mesh_shape) + + # Create local tensor with row-wise sharding + rows_per_rank = global_tensor.shape[0] // self.world_size + start_row = self.rank * rows_per_rank + end_row = start_row + rows_per_rank + local_tensor = global_tensor[start_row:end_row].clone() + + # Create DTensor with row-wise sharding + dtensor = DTensor.from_local( + local_tensor, + device_mesh=mesh_1d, + placements=[Shard(0)], + shape=global_tensor.shape, + stride=(4, 1), + ) + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + + checkpoint_dir = self.temp_dir + consolidated_output_dir = os.path.join(checkpoint_dir, "consolidated") + os.makedirs(consolidated_output_dir, exist_ok=True) + + state_dict_to_save = {"dtensor": dtensor} + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=checkpoint_dir, + save_distributed=True, + consolidated_output_path=consolidated_output_dir, + ), + ) + dist.barrier() + + if self.rank == 0: + file_path = os.path.join( + consolidated_output_dir, "model-00001-of-00001.safetensors" + ) + loaded_dict = safetensors.torch.load_file(file_path) + self.assertEqual(loaded_dict.keys(), {"dtensor"}) + self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) + + dist.barrier() + + +ONE_D_PLACEMENTS = [ + [Shard(0)], + [Replicate()], +] +ONE_D_TO_ONE_D_PLACEMENTS = [ + ([Replicate()], [Shard(0)]), + ([Shard(0)], [Replicate()]), +] + +TWO_D_PLACEMENTS = [ + [Replicate(), Replicate()], + [Replicate(), Shard(0)], + [Shard(0), Replicate()], + [Shard(0), Shard(0)], +] +TWO_D_TO_TWO_D_PLACEMENTS = [] +for p1 in TWO_D_PLACEMENTS: + for p2 in TWO_D_PLACEMENTS: + if p1 != p2: + TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2)) + + +@instantiate_parametrized_tests +class TestDTensorReshardPlacementChange(DTensorTestBase): + """ + Test DCP reshard for DTensor with placements changes and without world_size change and mesh_tensor change. + """ + + @with_comms + @skip_if_lt_x_gpu(2) + @with_temp_dir + def test_1d_to_1d_reshard_placement_change(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + + for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS: + original_placement, new_placement = one_d_to_one_d_placements + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (self.world_size,) + device_mesh = init_device_mesh(self.device_type, mesh_shape) + dtensor = distribute_tensor( + global_tensor, device_mesh, placements=original_placement + ) + state_dict_to_save = {"dtensor": dtensor} + + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=CHECKPOINT_DIR, + save_distributed=True, + ), + ) + + zero_dtensor = zeros( + [4, 4], device_mesh=device_mesh, placements=new_placement + ) + state_dict_to_load = {"dtensor": zero_dtensor} + + dist_cp.load( + state_dict=state_dict_to_load, + storage_reader=dist_cp.HuggingFaceStorageReader( + CHECKPOINT_DIR, + ), + ) + + # materialize the whole tensor to compare with the original global_tensor + state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( + device_mesh, + placements=[Replicate()], + ) + self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) + + # redistribute the tensor back to its original placement for comparison. + state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( + device_mesh, + placements=original_placement, + ) + self.assertEqual( + state_dict_to_save["dtensor"].to_local(), + state_dict_to_load["dtensor"].to_local(), + ) + + @with_comms + @skip_if_lt_x_gpu(4) + @with_temp_dir + def test_2d_to_2d_reshard_placement_change(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS: + original_placement, new_placement = two_d_to_two_d_placements + + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (2, self.world_size // 2) + mesh_2d = init_device_mesh(self.device_type, mesh_shape) + dtensor = distribute_tensor( + global_tensor, + mesh_2d, + placements=original_placement, + ) + state_dict_to_save = {"dtensor": dtensor} + + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=CHECKPOINT_DIR, save_distributed=True + ), + planner=dist_cp.DefaultSavePlanner(), + ) + + zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement) + state_dict_to_load = {"dtensor": zero_dtensor} + + dist_cp.load( + state_dict=state_dict_to_load, + storage_reader=dist_cp.HuggingFaceStorageReader(CHECKPOINT_DIR), + ) + + state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( + mesh_2d, + placements=[Replicate(), Replicate()], + ) + self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) + + state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( + mesh_2d, + placements=original_placement, + ) + self.assertEqual( + state_dict_to_save["dtensor"].to_local(), + state_dict_to_load["dtensor"].to_local(), + ) + + +class TestDTensorReshardMeshChange(DTensorTestBase): + """ + Test DCP reshard for DTensor with placements changes and mesh_tensor change. + """ + + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_1d_to_2d_reshard_mesh_change(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + for placements_1d in ONE_D_PLACEMENTS: + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (self.world_size,) + mesh_1d = init_device_mesh(self.device_type, mesh_shape) + dtensor = distribute_tensor( + global_tensor, mesh_1d, placements=placements_1d + ) + state_dict_to_save = {"dtensor": dtensor} + + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=CHECKPOINT_DIR, save_distributed=True + ), + ) + + for placements_2d in TWO_D_PLACEMENTS: + mesh_shape = (2, self.world_size // 2) + mesh_2d = init_device_mesh(self.device_type, mesh_shape) + + zero_dtensor = zeros( + [4, 4], device_mesh=mesh_2d, placements=placements_2d + ) + state_dict_to_load = {"dtensor": zero_dtensor} + + dist_cp.load( + state_dict=state_dict_to_load, + storage_reader=dist_cp.HuggingFaceStorageReader(CHECKPOINT_DIR), + planner=dist_cp.DefaultLoadPlanner(), + ) + + # materialzie the whole tensor to compare with the original global_tensor + state_dict_to_load["dtensor"] = state_dict_to_load[ + "dtensor" + ].redistribute( + mesh_2d, + placements=[Replicate(), Replicate()], + ) + self.assertEqual( + global_tensor, state_dict_to_load["dtensor"].to_local() + ) + + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(4) + def test_2d_to_1d_reshard_mesh_change(self) -> None: + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + CHECKPOINT_DIR = self.temp_dir + for placements_2d in TWO_D_PLACEMENTS: + global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) + mesh_shape = (2, self.world_size // 2) + mesh_2d = init_device_mesh(self.device_type, mesh_shape) + dtensor = distribute_tensor( + global_tensor, mesh_2d, placements=placements_2d + ) + state_dict_to_save = {"dtensor": dtensor} + + dist_cp.save( + state_dict=state_dict_to_save, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=CHECKPOINT_DIR, save_distributed=True + ), + planner=dist_cp.DefaultSavePlanner(), + ) + + for placements_1d in ONE_D_PLACEMENTS: + mesh_shape = (self.world_size,) + mesh_1d = init_device_mesh(self.device_type, mesh_shape) + + zero_dtensor = zeros( + [4, 4], device_mesh=mesh_1d, placements=placements_1d + ) + state_dict_to_load = {"dtensor": zero_dtensor} + + dist_cp.load( + state_dict=state_dict_to_load, + storage_reader=dist_cp.HuggingFaceStorageReader(CHECKPOINT_DIR), + planner=dist_cp.DefaultLoadPlanner(), + ) + + # materialzie the whole tensor to compare with the original global_tensor + state_dict_to_load["dtensor"] = state_dict_to_load[ + "dtensor" + ].redistribute( + mesh_1d, + placements=[Replicate()], + ) + self.assertEqual( + global_tensor, state_dict_to_load["dtensor"].to_local() + ) + + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(2) + def test_dtensor_checkpoint_resharding_with_empty_shard(self): + """ + Test dtensor checkpoint resharding with dtensor containing empty shards. + """ + if importlib.util.find_spec("safetensors") is None: + print("safetensors not installed") + return + + tensor = torch.rand(1).cuda() + mesh = init_device_mesh(self.device_type, (self.world_size,)) + dtensor = distribute_tensor(tensor, mesh, [Shard(0)]) + ref_state_dict = {"dtensor": dtensor} + + dist_cp.save( + state_dict=ref_state_dict, + storage_writer=dist_cp.HuggingFaceStorageWriter( + path=self.temp_dir, save_distributed=True + ), + ) + + tensor = torch.rand(1).cuda() + mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2)) + dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)]) + state_dict = {"dtensor": dtensor} + dist_cp.load( + state_dict=state_dict, + storage_reader=dist_cp.HuggingFaceStorageReader(self.temp_dir), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 6d5e04a0ca7a44..a07df2535bed33 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -8,27 +8,30 @@ from unittest.mock import MagicMock import torch -from torch.distributed.checkpoint._hf_planner import ( - _FqnToFileMapping, - _HuggingFaceLoadPlanner, -) -from torch.distributed.checkpoint._hf_storage import ( - _HuggingFaceStorageReader, - _HuggingFaceStorageWriter, - _metadata_fn, -) +from torch.distributed.checkpoint import DefaultLoadPlanner +from torch.distributed.checkpoint._hf_utils import _HFStorageInfo from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem +from torch.distributed.checkpoint.hf_storage import ( + _metadata_fn, + HuggingFaceStorageReader, + HuggingFaceStorageWriter, +) from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, + ChunkStorageMetadata, Metadata, MetadataIndex, + TensorProperties, + TensorStorageMetadata, ) -from torch.distributed.checkpoint.planner import LoadPlan, SavePlan -from torch.distributed.checkpoint.planner_helpers import ( - _create_read_items, - _create_write_item_for_tensor, +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlan, + ReadItem, + SavePlan, ) +from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor from torch.distributed.checkpoint.storage import WriteResult from torch.testing._internal.common_utils import run_tests, TestCase @@ -36,17 +39,74 @@ class TestHfStorage(TestCase): def test_write_data_hf(self) -> None: mock_module = MagicMock() - sys.modules["safetensors"] = mock_module - sys.modules["huggingface_hub"] = mock_module + mock_module.save.return_value = b"" + sys.modules["safetensors.torch"] = mock_module + + with tempfile.TemporaryDirectory() as path: + writer = HuggingFaceStorageWriter( + path=path, + fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2}, + ) + writer.fs = FileSystem() + + tensor0 = torch.rand(4) + tensor1 = torch.rand(10) + write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0) + write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1) + + state_dict = {"tensor_0": tensor0, "tensor_1": tensor1} + + save_plan = SavePlan( + [write_item_1, write_item_2], + storage_data={"fqn_to_index_mapping": {"tensor_0": 1, "tensor_1": 2}}, + ) + save_planner = DefaultSavePlanner() + save_planner.set_up_planner(state_dict=state_dict) + + write_results = writer.write_data(save_plan, save_planner) + + write_results.wait() + actual_write_results = write_results.value() + + expected_write_results = [ + WriteResult( + index=MetadataIndex( + fqn="tensor_0", offset=torch.Size([0]), index=None + ), + size_in_bytes=tensor0.numel() * tensor0.element_size(), + storage_data=_StorageInfo( + relative_path="model-00001-of-00002.safetensors", + offset=0, + length=tensor0.numel() * tensor0.element_size(), + ), + ), + WriteResult( + index=MetadataIndex( + fqn="tensor_1", offset=torch.Size([0]), index=None + ), + size_in_bytes=tensor1.numel() * tensor1.element_size(), + storage_data=_StorageInfo( + relative_path="model-00002-of-00002.safetensors", + offset=0, + length=tensor1.numel() * tensor1.element_size(), + ), + ), + ] + self.assertEqual( + actual_write_results, + expected_write_results, + ) + + def test_write_data_with_sharding(self) -> None: mock_module = MagicMock() mock_module.save.return_value = b"" sys.modules["safetensors.torch"] = mock_module with tempfile.TemporaryDirectory() as path: - writer = _HuggingFaceStorageWriter( + writer = HuggingFaceStorageWriter( path=path, - fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1}, + save_distributed=True, ) writer.fs = FileSystem() @@ -59,7 +119,7 @@ def test_write_data_hf(self) -> None: save_plan = SavePlan( [write_item_1, write_item_2], - storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}), + storage_data={"shard_index": 1}, ) save_planner = DefaultSavePlanner() save_planner.set_up_planner(state_dict=state_dict) @@ -76,7 +136,7 @@ def test_write_data_hf(self) -> None: ), size_in_bytes=tensor0.numel() * tensor0.element_size(), storage_data=_StorageInfo( - relative_path="model-00001-of-00001.safetensors", + relative_path="shard-00001-model-00001-of-00001.safetensors", offset=0, length=tensor0.numel() * tensor0.element_size(), ), @@ -87,7 +147,7 @@ def test_write_data_hf(self) -> None: ), size_in_bytes=tensor1.numel() * tensor1.element_size(), storage_data=_StorageInfo( - relative_path="model-00001-of-00001.safetensors", + relative_path="shard-00001-model-00001-of-00001.safetensors", offset=0, length=tensor1.numel() * tensor1.element_size(), ), @@ -100,43 +160,99 @@ def test_write_data_hf(self) -> None: ) def test_read_data_hf(self) -> None: - mock_module = MagicMock() - sys.modules["safetensors"] = mock_module - sys.modules["huggingface_hub"] = mock_module + mock_safetensors = MagicMock() + sys.modules["safetensors"] = mock_safetensors - name = "tensor_0" - tensor_0 = torch.rand(4) - mock_module = MagicMock() - mock_module.load.return_value = {name: tensor_0} - sys.modules["safetensors.torch"] = mock_module + # Create test tensors + tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + # Mock the deserialize function to return our test tensors + # The format matches what's expected in the read_data method + mock_safetensors.deserialize.return_value = [ + ( + "tensor_0", + {"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]}, + ), + ] with tempfile.TemporaryDirectory() as path: - reader = _HuggingFaceStorageReader(path=path) + # Create the reader + reader = HuggingFaceStorageReader(path=path) reader.fs = FileSystem() - file_name = "model-00001-of-00001" - - pathlib.Path(os.path.join(path, file_name)).touch() - reader.set_up_storage_reader( - Metadata( - state_dict_metadata={name: BytesStorageMetadata()}, - storage_data={name: file_name}, + # Create test file + file_name = "model-00001-of-00001.safetensors" + file_path = os.path.join(path, file_name) + pathlib.Path(file_path).touch() + + # Set up storage data with _StorageInfo objects + storage_data = { + MetadataIndex( + fqn="tensor_0", offset=torch.Size([0]), index=None + ): _HFStorageInfo( + file_path, + 0, + tensor_0.numel() * tensor_0.element_size(), + tensor_0.shape, + tensor_0.dtype, ), - is_coordinator=True, - ) + } - read_items = _create_read_items(name, BytesStorageMetadata(), file_name) + reader.storage_data = storage_data + + # Create target tensors that will be updated by read_data + target_tensor_0 = torch.zeros(4) + state_dict = { + "tensor_0": target_tensor_0, + } + + # Create read items for the load plan + read_items = [] + for name, tensor in state_dict.items(): + storage_index = MetadataIndex( + fqn=name, offset=torch.Size([0]), index=None + ) + dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None) + read_items.append( + ReadItem( + type=LoadItemType.TENSOR, + storage_index=storage_index, + dest_index=dest_index, + storage_offsets=[0, 0], + dest_offsets=[0, 0], + lengths=tensor.size(), + ) + ) + + # Create load plan and planner load_plan = LoadPlan(read_items) - load_planner = _HuggingFaceLoadPlanner() - load_planner.set_up_planner(state_dict={name: torch.rand(4)}) + load_planner = DefaultLoadPlanner() + load_planner.set_up_planner( + state_dict=state_dict, + metadata=Metadata( + state_dict_metadata={ + "tensor_0": TensorStorageMetadata( + properties=TensorProperties(dtype=torch.float32), + size=torch.Size([4]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0]), sizes=torch.Size([4]) + ) + ], + ) + }, + storage_data=storage_data, + ), + ) - read_data = reader.read_data(load_plan, load_planner) - read_data.wait() + # Call read_data + future = reader.read_data(load_plan, load_planner) + future.wait() - loaded_tensor = load_planner.original_state_dict[name] - self.assertEqual(loaded_tensor, tensor_0) + # Verify results - the target tensors should now contain the values from our test tensor + self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) - def test_metadata_hf(self) -> None: + def test_write_metadata_hf(self) -> None: mock_module = MagicMock() sys.modules["huggingface_hub"] = mock_module with tempfile.TemporaryDirectory() as path: @@ -146,21 +262,24 @@ def test_metadata_hf(self) -> None: index=MetadataIndex(fqn="tensor_0", offset=None, index=None), size_in_bytes=100, storage_data=_StorageInfo( - relative_path=file_name, offset=0, length=100 + relative_path=file_name, + offset=0, + length=100, ), ), WriteResult( index=MetadataIndex(fqn="tensor_1", offset=None, index=None), size_in_bytes=100, storage_data=_StorageInfo( - relative_path=file_name, offset=0, length=100 + relative_path=file_name, + offset=0, + length=100, ), ), ] - writer = _HuggingFaceStorageWriter( + writer = HuggingFaceStorageWriter( path=path, - fqn_to_index_mapping=_FqnToFileMapping({}), ) writer.fs = FileSystem() writer.finish( @@ -185,26 +304,22 @@ def test_metadata_hf(self) -> None: metadata = json.load(f) self.assertEqual(metadata, expected_metadata) - reader = _HuggingFaceStorageReader(path=path) - reader.fs = FileSystem() - metadata = reader.read_metadata() - self.assertEqual(metadata.storage_data, expected_metadata["weight_map"]) - - def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: - mock_module = MagicMock() - sys.modules["huggingface_hub"] = mock_module - + def test_read_metadata_hf(self): with tempfile.TemporaryDirectory() as path: - reader = _HuggingFaceStorageReader(path=path) - reader.fs = FileSystem() - # there is one safetensor file, but no metadata file, - # so we create metadata from the safetensor file - keys = ["tensor_0", "tensor_1"] + reader = HuggingFaceStorageReader(path=path) + + key = "tensor_0" file_name = "test.safetensors" with open(os.path.join(path, file_name), "wb") as f: # write metadata the same way it would be in safetensors file metadata_contents = json.dumps( - {"tensor_0": "value_0", "tensor_1": "value_1"} + { + "tensor_0": { + "dtype": "F32", + "shape": [5, 10], + "data_offsets": [0, 200], + } + } ) metadata_bytes = metadata_contents.encode("utf-8") @@ -216,13 +331,30 @@ def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: self.assertEqual( metadata.state_dict_metadata, { - keys[0]: BytesStorageMetadata(), - keys[1]: BytesStorageMetadata(), + key: TensorStorageMetadata( + properties=TensorProperties(dtype=torch.float32), + size=torch.Size([5, 10]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([5, 10]) + ) + ], + ), }, ) self.assertEqual( metadata.storage_data, - {keys[0]: file_name, keys[1]: file_name}, + { + MetadataIndex( + fqn=key, offset=torch.Size([0, 0]), index=None + ): _HFStorageInfo( + os.path.join(path, file_name), + 0, + 200, + torch.Size([5, 10]), + torch.float32, + ) + }, ) diff --git a/test/distributed/checkpoint/test_pg_transport.py b/test/distributed/checkpoint/test_pg_transport.py new file mode 100644 index 00000000000000..df64e9451b4677 --- /dev/null +++ b/test/distributed/checkpoint/test_pg_transport.py @@ -0,0 +1,516 @@ +# Owner(s): ["oncall: distributed"] + +import logging +import os +from datetime import timedelta +from typing import Optional +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn +from torch.distributed.checkpoint._pg_transport import ( + _cast_tensor, + _prepare_state_dict, + _prepare_tensor, + _StateDictMeta, + _TensorMeta, + PGTransport, +) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.tensor import DTensor +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + requires_nccl, +) +from torch.testing._internal.common_utils import ( + run_tests, + skip_but_pass_in_sandcastle_if, + TestCase, +) + + +logger = logging.getLogger(__name__) + + +class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 10) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def ring_send_recv_checkpoint( + transport: PGTransport, state_dict, rank, world_size, step=0 +): + """ + Use the transport to send to rank + 1 and receive from rank - 1. + """ + next_rank = (rank + 1) % world_size + prev_rank = (rank - 1) % world_size + if rank == 0: + transport.send_checkpoint([next_rank], state_dict) + received_checkpoint = transport.recv_checkpoint(prev_rank) + else: + received_checkpoint = transport.recv_checkpoint(prev_rank) + transport.send_checkpoint([next_rank], received_checkpoint) + return received_checkpoint + + +def _test_pg_transport(self, device) -> None: + # python test/distributed/checkpoint/test_pg_transport.py -k test_pg_transport + print(f"{self.rank=} pid: {os.getpid()} {device=}") + print("in test") + + model = SimpleModel().to(device) + transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) + original_state_dict = model.state_dict() + received_checkpoint = ring_send_recv_checkpoint( + transport=transport, + state_dict=original_state_dict, + rank=self.rank, + world_size=self.world_size, + ) + self.assertEqual(original_state_dict, received_checkpoint) + + +def _test_pg_transport_with_mixed_content(self, device) -> None: + # Create a device mesh for DTensor + device_mesh = init_device_mesh(device.type, (self.world_size,)) + + # Create a DTensor + local_tensor = torch.randn(10, 10, device=device) + dtensor = DTensor.from_local(local_tensor, device_mesh) + + # Include mixed content in the state dict + # Dtensor, Tensor, and non-tensor + model = SimpleModel().to(device) + state_dict = { + "net1.weight": model.net1.weight.data, + "net1.bias": model.net1.bias.data, + "net2.weight": model.net2.weight.data, + "net2.bias": model.net2.bias.data, + "dtensor": dtensor, + "non-tensor": "some string", + "nested": {"tensor": torch.randn(1, 2), "value": 42}, + "list": [1, 2, 3], + } + + transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) + received_checkpoint = ring_send_recv_checkpoint( + transport=transport, + state_dict=state_dict, + rank=self.rank, + world_size=self.world_size, + ) + self.assertEqual(state_dict, received_checkpoint) + + +class PgTransportCPU(MultiProcContinousTest): + world_size = 8 + timeout: timedelta = timedelta(seconds=20) + + @classmethod + def backend_str(cls) -> Optional[str]: + return "gloo" + + @classmethod + def device_type(cls) -> str: + return "cpu" + + @property + def device(self) -> torch.device: + return torch.device(self.device_type()) + + def test_pg_transport(self) -> None: + _test_pg_transport(self, self.device) + + def test_pg_transport_with_mixed_content(self) -> None: + _test_pg_transport_with_mixed_content(self, self.device) + + +class PgTransportCUDA(MultiProcContinousTest): + world_size = 2 + timeout: timedelta = timedelta(seconds=20) + + @classmethod + def backend_str(cls) -> Optional[str]: + return "nccl" + + @classmethod + def device_type(cls) -> str: + return "cuda" + + @property + def device(self) -> torch.device: + return torch.device(f"{self.device_type()}:{self.rank}") + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_pg_transport(self) -> None: + _test_pg_transport(self, self.device) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_pg_transport_with_mixed_content(self) -> None: + _test_pg_transport_with_mixed_content(self, self.device) + + +class TestCastTensor(TestCase): + def test_cast_tensor_different_dtypes(self): + """Test casting tensors of different dtypes.""" + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64, torch.bool] + + for dtype in dtypes: + original = torch.tensor([1, 2, 3], dtype=dtype) + casted = _cast_tensor(original, torch.uint8) + + # Check that the storage is the same + self.assertIs(original.untyped_storage(), casted.untyped_storage()) + + # Check that the size is correct + self.assertEqual(casted.numel(), original.untyped_storage().nbytes()) + + def test_cast_tensor_with_stride(self): + """Test casting tensors with non-standard strides.""" + # Create a tensor with non-standard stride + original = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + transposed = original.t() # Transpose to get non-standard stride + + casted = _cast_tensor(transposed, torch.uint8) + + # Check that the storage is the same + self.assertIs(transposed.untyped_storage(), casted.untyped_storage()) + + # Check that the size is correct + self.assertEqual(casted.numel(), transposed.untyped_storage().nbytes()) + + def test_cast_tensor_with_offset(self): + """Test casting tensors with storage offset.""" + # Create a tensor with storage offset + original = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + sliced = original[2:] # This creates a tensor with storage offset + + casted = _cast_tensor(sliced, torch.uint8) + + # Check that the storage is the same + self.assertIs(sliced.untyped_storage(), casted.untyped_storage()) + + # Check that the size is correct + self.assertEqual(casted.numel(), sliced.untyped_storage().nbytes()) + + +class TestPrepareTensor(TestCase): + def test_prepare_tensor_basic(self): + """Test basic tensor preparation.""" + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + prepared_tensor, meta = _prepare_tensor(tensor) + + # Check metadata + self.assertEqual(meta.shape, tensor.shape) + self.assertEqual(meta.dtype, tensor.dtype) + self.assertEqual(meta.storage_offset, tensor.storage_offset()) + self.assertEqual(meta.stride, tensor.stride()) + self.assertEqual(meta.nbytes, tensor.untyped_storage().nbytes()) + + # Check prepared tensor + self.assertEqual(prepared_tensor.dtype, torch.uint8) + self.assertEqual(prepared_tensor.numel(), tensor.untyped_storage().nbytes()) + + def test_prepare_tensor_different_shapes(self): + """Test preparing tensors with different shapes.""" + shapes = [(3,), (2, 3), (2, 3, 4)] + + for shape in shapes: + tensor = torch.randn(shape) + prepared_tensor, meta = _prepare_tensor(tensor) + + # Check metadata + self.assertEqual(meta.shape, tensor.shape) + self.assertEqual(meta.dtype, tensor.dtype) + self.assertEqual(meta.storage_offset, tensor.storage_offset()) + self.assertEqual(meta.stride, tensor.stride()) + self.assertEqual(meta.nbytes, tensor.untyped_storage().nbytes()) + + def test_prepare_tensor_with_stride(self): + """Test preparing tensors with non-standard strides.""" + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + transposed = tensor.t() # Transpose to get non-standard stride + + prepared_tensor, meta = _prepare_tensor(transposed) + + # Check metadata + self.assertEqual(meta.shape, transposed.shape) + self.assertEqual(meta.dtype, transposed.dtype) + self.assertEqual(meta.storage_offset, transposed.storage_offset()) + self.assertEqual(meta.stride, transposed.stride()) + self.assertEqual(meta.nbytes, transposed.untyped_storage().nbytes()) + + +class TestPrepareStateDict(TestCase): + def test_prepare_state_dict_basic(self): + """Test basic state dict preparation.""" + state_dict = {"weight": torch.randn(3, 4), "bias": torch.randn(4)} + device = torch.device("cpu") + + meta, tensors = _prepare_state_dict(state_dict, device) + + # Check metadata + self.assertEqual(len(meta.paths), 2) + self.assertEqual(len(meta.non_tensor_leaves), 2) + self.assertEqual(len(tensors), 2) + + # Check that all non_tensor_leaves are _TensorMeta instances + for leaf in meta.non_tensor_leaves: + self.assertIsInstance(leaf, _TensorMeta) + + def test_prepare_state_dict_nested(self): + """Test preparing nested state dict.""" + state_dict = { + "layer1": {"weight": torch.randn(3, 4), "bias": torch.randn(4)}, + "layer2": {"weight": torch.randn(4, 5), "bias": torch.randn(5)}, + } + device = torch.device("cpu") + + meta, tensors = _prepare_state_dict(state_dict, device) + + # Check metadata + self.assertEqual(len(meta.paths), 4) + self.assertEqual(len(meta.non_tensor_leaves), 4) + self.assertEqual(len(tensors), 4) + + def test_prepare_state_dict_with_non_tensor_values(self): + """Test preparing state dict with non-tensor values.""" + state_dict = { + "weight": torch.randn(3, 4), + "bias": torch.randn(4), + "config": {"lr": 0.01, "momentum": 0.9}, + "step": 42, + } + device = torch.device("cpu") + + meta, tensors = _prepare_state_dict(state_dict, device) + + # Check metadata - the actual number of paths depends on how the pytree flattens the dict + # The nested config dict might be flattened differently + self.assertEqual(len(meta.non_tensor_leaves), len(meta.paths)) + self.assertEqual(len(tensors), 2) + + # Check that non-tensor values are preserved + non_tensor_values = [ + leaf for leaf in meta.non_tensor_leaves if not isinstance(leaf, _TensorMeta) + ] + self.assertEqual(len(non_tensor_values), 3) # config (2) and step + + +class TestPGTransportMocked(TestCase): + def setUp(self): + self.device = torch.device("cpu") + self.pg = MagicMock() + self.timeout = timedelta(seconds=10) + + # Mock Work object + self.mock_work = MagicMock() + self.mock_work.wait = MagicMock() + + # Setup process group mock to return mock_work + self.pg.send = MagicMock(return_value=self.mock_work) + self.pg.recv = MagicMock(return_value=self.mock_work) + + def test_send_checkpoint_basic(self): + """Test basic send_checkpoint functionality with mocked process group.""" + transport = PGTransport(self.pg, self.timeout, self.device) + state_dict = {"weight": torch.randn(3, 4), "bias": torch.randn(4)} + dst_ranks = [1, 2] + + transport.send_checkpoint(dst_ranks, state_dict) + + # Check that send was called with correct parameters + # First for metadata length, then for metadata, then for each tensor + expected_calls = len(dst_ranks) * (2 + len(state_dict)) + self.assertEqual(self.pg.send.call_count, expected_calls) + + # Check that wait was called on all work objects + self.assertEqual(self.mock_work.wait.call_count, expected_calls) + + def test_recv_checkpoint_basic(self): + """Test basic recv_checkpoint functionality with mocked process group.""" + # Setup mock for pickle.loads to return a valid _StateDictMeta + with patch("pickle.loads") as mock_loads: + # Create a mock state dict metadata + from torch.utils._pytree import tree_flatten_with_path + + state_dict = {"weight": torch.randn(3, 4), "bias": torch.randn(4)} + leaves, treespec = tree_flatten_with_path(state_dict) + paths = [path for path, _ in leaves] + + # Create mock tensor metadata + tensor_metas = [] + for _, v in leaves: + tensor_metas.append( + _TensorMeta( + shape=v.shape, + dtype=v.dtype, + storage_offset=v.storage_offset(), + stride=v.stride(), + nbytes=v.untyped_storage().nbytes(), + ) + ) + + mock_meta = _StateDictMeta( + treespec=treespec, paths=paths, non_tensor_leaves=tensor_metas + ) + mock_loads.return_value = mock_meta + + # Setup len_t and buf tensors for the mock recv + def side_effect(tensor_list, *args, **kwargs): + if tensor_list[0].numel() == 1: # This is len_t + tensor_list[0].fill_(100) # Some arbitrary length + return self.mock_work + + self.pg.recv.side_effect = side_effect + + # Create transport and call recv_checkpoint + transport = PGTransport(self.pg, self.timeout, self.device) + transport.recv_checkpoint(src_rank=0) + + # Check that recv was called + self.assertGreaterEqual( + self.pg.recv.call_count, 2 + ) # At least for len_t and buf + + # Check that wait was called + self.assertGreaterEqual(self.mock_work.wait.call_count, 2) + + def test_send_checkpoint_empty_state_dict(self): + """Test send_checkpoint with empty state dict.""" + transport = PGTransport(self.pg, self.timeout, self.device) + state_dict = {} + dst_ranks = [1] + + transport.send_checkpoint(dst_ranks, state_dict) + + # Check that send was called only for metadata + self.assertEqual(self.pg.send.call_count, 2) # len_t and buf_t + + # Check that wait was called + self.assertEqual(self.mock_work.wait.call_count, 2) + + def test_send_checkpoint_with_non_tensor_values(self): + """Test send_checkpoint with non-tensor values in state dict.""" + transport = PGTransport(self.pg, self.timeout, self.device) + state_dict = {"weight": torch.randn(3, 4), "config": {"lr": 0.01}} + dst_ranks = [1] + + transport.send_checkpoint(dst_ranks, state_dict) + + # Check that send was called for metadata and one tensor + self.assertEqual(self.pg.send.call_count, 3) # len_t, buf_t, and one tensor + + # Check that wait was called + self.assertEqual(self.mock_work.wait.call_count, 3) + + def test_recv_checkpoint_with_state_dict_callback(self): + """Test recv_checkpoint with state_dict callback.""" + # Setup mock for pickle.loads to return a valid _StateDictMeta + with patch("pickle.loads") as mock_loads: + # Create a mock state dict metadata + from torch.utils._pytree import tree_flatten_with_path + + state_dict = {"weight": torch.randn(3, 4), "bias": torch.randn(4)} + leaves, treespec = tree_flatten_with_path(state_dict) + paths = [path for path, _ in leaves] + + # Create mock tensor metadata + tensor_metas = [] + for _, v in leaves: + tensor_metas.append( + _TensorMeta( + shape=v.shape, + dtype=v.dtype, + storage_offset=v.storage_offset(), + stride=v.stride(), + nbytes=v.untyped_storage().nbytes(), + ) + ) + + mock_meta = _StateDictMeta( + treespec=treespec, paths=paths, non_tensor_leaves=tensor_metas + ) + mock_loads.return_value = mock_meta + + # Setup len_t and buf tensors for the mock recv + def side_effect(tensor_list, *args, **kwargs): + if tensor_list[0].numel() == 1: # This is len_t + tensor_list[0].fill_(100) # Some arbitrary length + return self.mock_work + + self.pg.recv.side_effect = side_effect + + # Create a state_dict callback + callback_state_dict = {"weight": torch.randn(3, 4), "bias": torch.randn(4)} + state_dict_callback = MagicMock(return_value=callback_state_dict) + + # Create transport with state_dict callback and call recv_checkpoint + transport = PGTransport( + self.pg, self.timeout, self.device, state_dict=state_dict_callback + ) + transport.recv_checkpoint(src_rank=0) + + # Check that state_dict callback was called + state_dict_callback.assert_called_once() + + +class TestPGTransportEdgeCases(TestCase): + def setUp(self): + self.device = torch.device("cpu") + self.pg = MagicMock() + self.timeout = timedelta(seconds=10) + + # Mock Work object + self.mock_work = MagicMock() + self.mock_work.wait = MagicMock() + + # Setup process group mock to return mock_work + self.pg.send = MagicMock(return_value=self.mock_work) + self.pg.recv = MagicMock(return_value=self.mock_work) + + def test_send_checkpoint_with_cpu_tensors(self): + """Test send_checkpoint with CPU tensors when device is CUDA.""" + # Skip if CUDA is not available + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + device = torch.device("cuda:0") + + # Create a state dict with CPU tensors + state_dict = { + "cpu_tensor1": torch.randn(2, 3), + "cpu_tensor2": torch.randn(3, 4), + } + + # Create transport with CUDA device + transport = PGTransport(self.pg, self.timeout, device) + + # Call send_checkpoint + transport.send_checkpoint([1], state_dict) + + # Check that send was called + self.assertGreaterEqual( + self.pg.send.call_count, 4 + ) # len_t, buf_t, and 2 tensors + + # Check that wait was called + self.assertGreaterEqual(self.mock_work.wait.call_count, 4) + + +# import fbvscode +# fbvscode.attach_debugger() + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index b0a8ae3f58c94f..9c4f6fb005a309 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -677,9 +677,7 @@ def test_optim_state_dict_param_matching(self) -> None: fully_shard(layer) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - torch.optim.lr_scheduler.LambdaLR( - optim, lr_lambda=[lambda epoch: 0.95**epoch] - ) + torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch]) opt_state_dict = ptd_state_dict.get_optimizer_state_dict( model, optim, diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py new file mode 100644 index 00000000000000..86a952e0701d22 --- /dev/null +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -0,0 +1,822 @@ +# Owner(s): ["oncall: distributed"] + +import dataclasses + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor +from torch.distributed._tensor.placement_types import Shard +from torch.distributed.checkpoint._state_dict_stager import StateDictStager +from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +def create_cpu_state_dict(state_dict): + cpu_state_dict = {} + for key, value in state_dict.items(): + cpu_state_dict[key] = value.cpu() + return cpu_state_dict + + +def compare_state_dicts(cuda_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8): + """ + Compare if two state dictionaries (one on CUDA, one on CPU) are otherwise the same. + + This function checks if the tensors in both state dictionaries have the same values, + shapes, dtypes, etc., ignoring the device difference. It also checks if tensors that + share storage in one state dict also share storage in the other. + + Args: + cuda_state_dict: The state dictionary with tensors on CUDA + cpu_state_dict: The state dictionary with tensors on CPU + rtol: Relative tolerance for comparing tensor values + atol: Absolute tolerance for comparing tensor values + + Returns: + bool: True if the state dictionaries are equivalent, False otherwise + str: Error message if the state dictionaries are not equivalent, empty string otherwise + """ + # Track storage data pointers to check storage sharing + cuda_storage_ptrs = {} + cpu_storage_ptrs = {} + + def compare_objects(cuda_obj, cpu_obj, path=""): + # If objects are tensors, compare them + if isinstance(cuda_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor): + # Check if devices are as expected + if cuda_obj.device.type != "cuda": + return ( + False, + f"Expected CUDA tensor, got {cuda_obj.device.type} tensor at {path}", + ) + if cpu_obj.device.type != "cpu": + return ( + False, + f"Expected CPU tensor, got {cpu_obj.device.type} tensor at {path}", + ) + if cuda_obj.storage_offset() != cpu_obj.storage_offset(): + return ( + False, + f"Storage offset mismatch at {path}: {cuda_obj.storage_offset()} vs {cpu_obj.storage_offset()}", + ) + + if not torch.equal(cuda_obj.cpu(), cpu_obj): + return ( + False, + f"Tensors are not same at {path}", + ) + + # Track storage sharing + cuda_storage_ptr = cuda_obj.storage().data_ptr() + cpu_storage_ptr = cpu_obj.storage().data_ptr() + + if cuda_storage_ptr in cuda_storage_ptrs: + # This CUDA tensor shares storage with another tensor + # Check if the corresponding CPU tensors also share storage + if cpu_storage_ptr != cuda_storage_ptrs[cuda_storage_ptr]: + return ( + False, + f"Storage sharing mismatch: CUDA tensors share storage but CPU tensors don't at {path}", + ) + else: + # First time seeing this storage + cuda_storage_ptrs[cuda_storage_ptr] = cpu_storage_ptr + cpu_storage_ptrs[cpu_storage_ptr] = cuda_storage_ptr + + return True, "" + + # If objects are dictionaries, compare them recursively + elif isinstance(cuda_obj, dict) and isinstance(cpu_obj, dict): + if cuda_obj.keys() != cpu_obj.keys(): + return ( + False, + f"Dictionary keys mismatch at {path}: {cuda_obj.keys()} vs {cpu_obj.keys()}", + ) + + for key in cuda_obj: + result, error = compare_objects( + cuda_obj[key], cpu_obj[key], f"{path}.{key}" if path else key + ) + if not result: + return False, error + + return True, "" + + # If objects are lists, tuples, or sets, compare them recursively + elif isinstance(cuda_obj, (list, tuple, set)) and isinstance( + cpu_obj, (list, tuple, set) + ): + if len(cuda_obj) != len(cpu_obj): + return ( + False, + f"Collection length mismatch at {path}: {len(cuda_obj)} vs {len(cpu_obj)}", + ) + if type(cuda_obj) != type(cpu_obj): + return ( + False, + f"Collection type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}", + ) + + for i, (cuda_item, cpu_item) in enumerate(zip(cuda_obj, cpu_obj)): + result, error = compare_objects(cuda_item, cpu_item, f"{path}[{i}]") + if not result: + return False, error + + return True, "" + + # If objects are custom classes, compare their attributes + elif hasattr(cuda_obj, "__dict__") and hasattr(cpu_obj, "__dict__"): + if type(cuda_obj) != type(cpu_obj): + return ( + False, + f"Object type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}", + ) + + result, error = compare_objects( + cuda_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__" + ) + if not result: + return False, error + + return True, "" + + # For other types, use direct equality comparison + else: + if type(cuda_obj) != type(cpu_obj): + return ( + False, + f"Type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}", + ) + if cuda_obj != cpu_obj: + return False, f"Value mismatch at {path}: {cuda_obj} vs {cpu_obj}" + + return True, "" + + # Start the recursive comparison + result, error = compare_objects(cuda_state_dict, cpu_state_dict) + return result, error + + +@dataclasses.dataclass +class TestStruct: + tensor1: torch.Tensor + + +@dataclasses.dataclass +class NestedTensorStruct: + tensor: torch.Tensor + value: int = 42 + + +@dataclasses.dataclass +class ComplexDataClass: + tensor: torch.Tensor + name: str + values: list[float] + nested: NestedTensorStruct + + +@dataclasses.dataclass(frozen=True) +class FrozenDataClass: + tensor: torch.Tensor + value: int = 100 + + +class TestStateDictStager(TestCase): + @requires_cuda + def test_views(self): + test_configs = [ + (False, False), # pin_memory=False, share_memory=False, + (True, False), # pin_memory=True, share_memory=False + (False, True), # pin_memory=False, share_memory=True + (True, True), # pin_memory=True, share_memory=True + ] + for pin_memory, share_memory in test_configs: + with self.subTest(pin_memory=pin_memory, share_memory=share_memory): + tensor1 = torch.randn(4, 4).cuda() + tensor2 = tensor1.view(16) + tensor3 = torch.randn(4, 4).cuda() + state_dict = { + "tensor1": tensor1, + "tensor2": tensor2, + "recursive": { + "tensor3": tensor3, + "type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)), + }, + } + assert ( + state_dict["tensor1"].storage().data_ptr() + == state_dict["tensor2"].storage().data_ptr() + ) + + stager = StateDictStager( + pin_memory=pin_memory, share_memory=share_memory + ) + + cpu_state_dict = stager.stage(state_dict) + + # Calculate stats + num_storages = len(stager._cached_storage_mapping) + num_bytes = sum( + storage.nbytes() + for storage in stager._cached_storage_mapping.values() + ) + + # Validate tensor count and bytes + expected_storage_cnt = 2 + assert num_storages == expected_storage_cnt, ( + f"Expected {expected_storage_cnt} storages, got {num_storages}" + ) + + # Calculate expected bytes + # Note: Only unique storages are counted in the byte count + expected_bytes = ( + tensor1.numel() * tensor1.element_size() + + tensor3.numel() # tensor1 and tensor2 share storage + * tensor3.element_size() # tensor3 and its narrow view share storage + ) + assert num_bytes == expected_bytes, ( + f"Expected {expected_bytes} bytes, got {num_bytes}" + ) + # Verify that the CPU state dict is equivalent to the original CUDA state dict + result, error = compare_state_dicts(state_dict, cpu_state_dict) + assert result, f"State dicts are not equivalent: {error}" + + # Additional checks for storage sharing + assert cpu_state_dict["tensor1"].device == torch.device("cpu") + assert cpu_state_dict["tensor2"].device == torch.device("cpu") + assert ( + cpu_state_dict["tensor1"].storage().data_ptr() + == cpu_state_dict["tensor2"].storage().data_ptr() + ) + + recursive = cpu_state_dict["recursive"] + assert recursive["tensor3"].device == torch.device("cpu") + assert recursive["type"].tensor1.device == torch.device("cpu") + assert ( + recursive["tensor3"].storage().data_ptr() + == recursive["type"].tensor1.storage().data_ptr() + ) + + @requires_cuda + def test_caching(self): + """ + Test that the StateDictStager correctly caches and reuses storages. + """ + test_configs = [ + (False, False), # pin_memory=False, share_memory=False, + (True, False), # pin_memory=True, share_memory=False + (False, True), # pin_memory=False, share_memory=True + (True, True), # pin_memory=True, share_memory=True + ] + for pin_memory, share_memory in test_configs: + with self.subTest(pin_memory=pin_memory, share_memory=share_memory): + # Create test tensors and state dict + tensor1 = torch.randn(4, 4).cuda() + tensor2 = tensor1.view(16) + tensor3 = torch.randn(4, 4).cuda() + state_dict = { + "tensor1": tensor1, + "tensor2": tensor2, + "recursive": { + "tensor3": tensor3, + "type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)), + }, + } + + # Create a StateDictStager instance + stager = StateDictStager( + pin_memory=pin_memory, share_memory=share_memory + ) + + # First call to stage with staging context + cpu_state_dict1 = stager.stage(state_dict) + + # Get the number of cached storages after first stage + num_storages1 = len(stager._cached_storage_mapping) + + # Verify the first result is correct + result, error = compare_state_dicts(state_dict, cpu_state_dict1) + assert result, ( + f"First state dict is not equivalent to original: {error}" + ) + + # Modify the original tensors + tensor1.fill_(0) + tensor3.fill_(0) + + # Second call to stage with staging context + cpu_state_dict2 = stager.stage(state_dict) + + # Get the number of cached storages after second stage + num_storages2 = len(stager._cached_storage_mapping) + + # Verify that the second CPU state dict is equivalent to the modified original state dict + result, error = compare_state_dicts(state_dict, cpu_state_dict2) + assert result, ( + f"Second state dict is not equivalent to modified original: {error}" + ) + + # Verify that the number of cached storages hasn't changed + assert num_storages1 == num_storages2, ( + f"Storage count changed: {num_storages1} vs {num_storages2}" + ) + + # Verify that the tensors in the second state dict have the same storage pointers as the first + assert ( + cpu_state_dict1["tensor1"].storage().data_ptr() + == cpu_state_dict2["tensor1"].storage().data_ptr() + ), "Storage pointers should match for tensor1" + assert ( + cpu_state_dict1["tensor2"].storage().data_ptr() + == cpu_state_dict2["tensor2"].storage().data_ptr() + ), "Storage pointers should match for tensor2" + assert ( + cpu_state_dict1["recursive"]["tensor3"].storage().data_ptr() + == cpu_state_dict2["recursive"]["tensor3"].storage().data_ptr() + ), "Storage pointers should match for tensor3" + + # Modify the original tensors again with different values + tensor1.fill_(42.0) + + # Third call to stage with staging context + cpu_state_dict3 = stager.stage(state_dict) + + # Verify that the third CPU state dict reflects the updated values + assert torch.all(cpu_state_dict3["tensor1"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) + assert torch.all(cpu_state_dict3["tensor2"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) + + @requires_cuda + def test_tensor_attrs(self): + """ + Test that tensor attributes are preserved during stage with StateDictStager. + """ + tensor1 = torch.randn(4, 4).cuda() + tensor2 = tensor1.view(16) + tensor3 = torch.randn(4, 4).cuda() + + # Add custom attributes to tensors + tensor1.a = 42 + tensor1.b = 43 + tensor3.c = 44 + + state_dict = { + "tensor1": tensor1, + "tensor2": tensor2, + "recursive": { + "tensor3": tensor3, + "type": TestStruct(tensor1=tensor3.narrow(0, 0, 2)), + }, + } + + stager = StateDictStager(pin_memory=True, share_memory=True) + cpu_state_dict = stager.stage(state_dict) + + # Verify that tensor attributes are preserved + assert hasattr(cpu_state_dict["tensor1"], "a"), ( + "Tensor attribute 'a' was not preserved" + ) + assert cpu_state_dict["tensor1"].a == 42, ( + "Tensor attribute 'a' has incorrect value" + ) + assert hasattr(cpu_state_dict["tensor1"], "b"), ( + "Tensor attribute 'b' was not preserved" + ) + assert cpu_state_dict["tensor1"].b == 43, ( + "Tensor attribute 'b' has incorrect value" + ) + assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), ( + "Tensor attribute 'c' was not preserved" + ) + assert cpu_state_dict["recursive"]["tensor3"].c == 44, ( + "Tensor attribute 'c' has incorrect value" + ) + + @requires_cuda + def test_different_dtypes(self): + """ + Test that StateDictStager works correctly with tensors of different data types. + """ + # Create tensors with different dtypes + tensors = { + "float32": torch.randn(4, 4, dtype=torch.float32).cuda(), + "float64": torch.randn(4, 4, dtype=torch.float64).cuda(), + "int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).cuda(), + "int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).cuda(), + "bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).cuda(), + } + + # Create a state dict with these tensors + state_dict = tensors.copy() + + stager = StateDictStager() + cpu_state_dict = stager.stage(state_dict) + + # Verify that all tensors have been correctly copied to CPU with the right dtypes + for dtype_name, original_tensor in tensors.items(): + cpu_tensor = cpu_state_dict[dtype_name] + self.assertEqual( + cpu_tensor.device.type, "cpu", f"Tensor {dtype_name} should be on CPU" + ) + self.assertEqual( + cpu_tensor.dtype, + original_tensor.dtype, + f"Tensor {dtype_name} has incorrect dtype", + ) + self.assertTrue( + torch.allclose(cpu_tensor, original_tensor.cpu()), + f"Tensor {dtype_name} has incorrect values", + ) + + @requires_cuda + def test_empty_tensors(self): + """ + Test that StateDictStager works correctly with empty tensors. + """ + test_configs = [ + (False, False), # pin_memory=False, share_memory=False, + (True, False), # pin_memory=True, share_memory=False + (False, True), # pin_memory=False, share_memory=True + (True, True), # pin_memory=True, share_memory=True + ] + for pin_memory, share_memory in test_configs: + with self.subTest(pin_memory=pin_memory, share_memory=share_memory): + # Create empty tensors with different shapes + tensors = { + "empty_0d": torch.tensor([], dtype=torch.float32).cuda(), + "empty_1d": torch.tensor([], dtype=torch.float32).reshape(0).cuda(), + "empty_2d": torch.tensor([], dtype=torch.float32) + .reshape(0, 0) + .cuda(), + "empty_3d": torch.tensor([], dtype=torch.float32) + .reshape(0, 0, 0) + .cuda(), + "zero_dim": torch.tensor(0.0).cuda(), # scalar tensor + } + + # Create a state dict with these tensors + state_dict = tensors.copy() + + cpu_state_dict = StateDictStager(pin_memory, share_memory).stage( + state_dict + ) + + # Verify that all tensors have been correctly copied to CPU + for tensor_name, original_tensor in tensors.items(): + cpu_tensor = cpu_state_dict[tensor_name] + + self.assertEqual( + cpu_tensor.device.type, + "cpu", + f"Tensor {tensor_name} should be on CPU", + ) + self.assertEqual( + cpu_tensor.shape, + original_tensor.shape, + f"Tensor {tensor_name} has incorrect shape", + ) + self.assertEqual( + cpu_tensor.dtype, + original_tensor.dtype, + f"Tensor {tensor_name} has incorrect dtype", + ) + + @requires_cuda + def test_complex_storage_sharing(self): + """ + Test that StateDictStager correctly handles complex storage sharing scenarios. + """ + # Create a base tensor + base_tensor = torch.randn(10, 10).cuda() + + # Create various views and slices that share storage + view1 = base_tensor.view(100) + view2 = base_tensor.view(10, 10) + slice1 = base_tensor[2:8, 2:8] + slice2 = base_tensor[:, :5] + slice3 = view1[10:60] + + # Create a state dict with these tensors + state_dict = { + "base": base_tensor, + "view1": view1, + "view2": view2, + "slice1": slice1, + "slice2": slice2, + "slice3": slice3, + } + cpu_state_dict = StateDictStager().stage(state_dict) + + # Verify that all tensors have been correctly copied to CPU + result, error = compare_state_dicts(state_dict, cpu_state_dict) + self.assertTrue(result, f"State dicts are not equivalent: {error}") + + # Verify storage sharing is preserved + # All these tensors should share the same storage + storage_ptr = cpu_state_dict["base"].storage().data_ptr() + self.assertEqual( + cpu_state_dict["view1"].storage().data_ptr(), + storage_ptr, + "view1 should share storage with base", + ) + self.assertEqual( + cpu_state_dict["view2"].storage().data_ptr(), + storage_ptr, + "view2 should share storage with base", + ) + self.assertEqual( + cpu_state_dict["slice1"].storage().data_ptr(), + storage_ptr, + "slice1 should share storage with base", + ) + self.assertEqual( + cpu_state_dict["slice2"].storage().data_ptr(), + storage_ptr, + "slice2 should share storage with base", + ) + self.assertEqual( + cpu_state_dict["slice3"].storage().data_ptr(), + storage_ptr, + "slice3 should share storage with base", + ) + + # Verify that modifying the base tensor affects all views and slices + cpu_state_dict["base"].fill_(42.0) + self.assertTrue( + torch.all(cpu_state_dict["view1"] == 42.0), + "view1 should reflect changes to base", + ) + self.assertTrue( + torch.all(cpu_state_dict["view2"] == 42.0), + "view2 should reflect changes to base", + ) + self.assertTrue( + torch.all(cpu_state_dict["slice1"] == 42.0), + "slice1 should reflect changes to base", + ) + self.assertTrue( + torch.all(cpu_state_dict["slice2"] == 42.0), + "slice2 should reflect changes to base", + ) + self.assertTrue( + torch.all(cpu_state_dict["slice3"] == 42.0), + "slice3 should reflect changes to base", + ) + + @requires_cuda + def test_dataclasses(self): + # Create tensors + tensor1 = torch.randn(4, 4).cuda() + tensor2 = torch.randn(8, 8).cuda() + tensor3 = torch.randn(2, 6).cuda() + tensor4 = torch.randn(3, 5).cuda() + + # Create dataclass instances + nested = NestedTensorStruct(tensor=tensor3) + complex_dc = ComplexDataClass( + tensor=tensor1, name="test", values=[1.0, 2.0, 3.0], nested=nested + ) + frozen_dc = FrozenDataClass(tensor=tensor4) + + # Create a state dict with these dataclasses + state_dict = { + "regular_tensor": tensor2, + "complex_dataclass": complex_dc, + "frozen_dataclass": frozen_dc, + } + + # Stage the state dict + stager = StateDictStager(pin_memory=False, share_memory=False) + cpu_state_dict = stager.stage(state_dict) + + # Verify regular tensor + self.assertEqual(cpu_state_dict["regular_tensor"].device.type, "cpu") + self.assertTrue(torch.allclose(cpu_state_dict["regular_tensor"], tensor2.cpu())) + + # Verify complex dataclass + complex_cpu = cpu_state_dict["complex_dataclass"] + self.assertEqual(complex_cpu.name, "test") + self.assertEqual(complex_cpu.values, [1.0, 2.0, 3.0]) + self.assertEqual(complex_cpu.tensor.device.type, "cpu") + self.assertTrue(torch.allclose(complex_cpu.tensor, tensor1.cpu())) + + # Verify nested dataclass inside complex dataclass + nested_cpu = complex_cpu.nested + self.assertEqual(nested_cpu.value, 42) + self.assertEqual(nested_cpu.tensor.device.type, "cpu") + self.assertTrue(torch.allclose(nested_cpu.tensor, tensor3.cpu())) + + # Verify frozen dataclass + frozen_cpu = cpu_state_dict["frozen_dataclass"] + self.assertEqual(frozen_cpu.value, 100) + self.assertEqual(frozen_cpu.tensor.device.type, "cpu") + self.assertTrue(torch.allclose(frozen_cpu.tensor, tensor4.cpu())) + + # Verify that modifying the original tensors doesn't affect the staged ones + tensor1.fill_(99.0) + tensor3.fill_(88.0) + tensor4.fill_(77.0) + + self.assertFalse(torch.allclose(complex_cpu.tensor, tensor1.cpu())) + self.assertFalse(torch.allclose(nested_cpu.tensor, tensor3.cpu())) + self.assertFalse(torch.allclose(frozen_cpu.tensor, tensor4.cpu())) + + def test_cpu_storage_independence(self): + """ + Test ensures CPU tensors passed to StateDictStager are actually cloned + """ + # Create test tensors + tensor1 = torch.randn(4, 4) + tensor2 = torch.randn(8, 8) + + # Create a state dict with these tensors + state_dict = { + "tensor1": tensor1, + "tensor2": tensor2, + } + + cpu_state_dict = StateDictStager().stage(state_dict) + cpu_tensor1 = cpu_state_dict["tensor1"] + cpu_tensor2 = cpu_state_dict["tensor2"] + + # Verify that the CPU tensors have different storage pointers than the original tensors + self.assertNotEqual( + tensor1.storage().data_ptr(), + cpu_tensor1.storage().data_ptr(), + "CPU tensor should have a different storage pointer than the original tensor", + ) + self.assertNotEqual( + tensor2.storage().data_ptr(), + cpu_tensor2.storage().data_ptr(), + "CPU tensor should have a different storage pointer than the original tensor", + ) + + self.assertTrue( + torch.allclose(tensor1, cpu_tensor1), + "CPU tensor should have the same values as the original tensor", + ) + self.assertTrue( + torch.allclose(tensor2, cpu_tensor2), + "CPU tensor should have the same values as the original tensor", + ) + + # Modify the original CPU tensors and validate staged tensors are not modified + cloned_orginial1 = tensor1.clone() + cloned_orginia2 = tensor2.clone() + tensor1.fill_(99.0) + tensor2.fill_(88.0) + + self.assertFalse(torch.allclose(cloned_orginial1, tensor1)) + self.assertTrue( + torch.allclose(cloned_orginial1, cpu_tensor1), + "CPU tensor should have the same values as the original tensor", + ) + self.assertTrue( + torch.allclose(cloned_orginia2, cpu_tensor2), + "CPU tensor should have the same values as the original tensor", + ) + + @requires_cuda + def test_tensor_pinned_and_shared(self): + """ + Test that verifies tensors are actually pinned and shared using tensor.is_pinned() and tensor.is_shared() methods. + """ + # Create test tensors + tensor1 = torch.randn(4, 4).cuda() + tensor2 = torch.randn(8, 8).cuda() + + # Create a state dict with these tensors + state_dict = { + "tensor1": tensor1, + "tensor2": tensor2, + } + + # Test all combinations of pin_memory and share_memory + test_configs = [ + (False, False), # pin_memory=False, share_memory=False + (True, False), # pin_memory=True, share_memory=False + (False, True), # pin_memory=False, share_memory=True + (True, True), # pin_memory=True, share_memory=True + ] + + for pin_memory, share_memory in test_configs: + with self.subTest(pin_memory=pin_memory, share_memory=share_memory): + # Create stager with specific configuration + stager = StateDictStager( + pin_memory=pin_memory, share_memory=share_memory + ) + cpu_state_dict = stager.stage(state_dict) + + # Get the staged tensors + cpu_tensor1 = cpu_state_dict["tensor1"] + cpu_tensor2 = cpu_state_dict["tensor2"] + + # Verify tensor device + self.assertEqual( + cpu_tensor1.device.type, "cpu", "Staged tensor should be on CPU" + ) + self.assertEqual( + cpu_tensor2.device.type, "cpu", "Staged tensor should be on CPU" + ) + + # Verify tensor values + self.assertTrue( + torch.allclose(cpu_tensor1, tensor1.cpu()), + "CPU tensor should have the same values as the original tensor", + ) + self.assertTrue( + torch.allclose(cpu_tensor2, tensor2.cpu()), + "CPU tensor should have the same values as the original tensor", + ) + + # Verify pinned memory status + self.assertEqual( + cpu_tensor1.is_pinned(), + pin_memory, + f"Tensor pinned status should be {pin_memory}", + ) + self.assertEqual( + cpu_tensor2.is_pinned(), + pin_memory, + f"Tensor pinned status should be {pin_memory}", + ) + + # Verify shared memory status + self.assertEqual( + cpu_tensor1.is_shared(), + share_memory, + f"Tensor shared status should be {share_memory}", + ) + self.assertEqual( + cpu_tensor2.is_shared(), + share_memory, + f"Tensor shared status should be {share_memory}", + ) + + # Verify storage sharing is consistent with tensor sharing + if share_memory: + # When share_memory is True, the storage should also be shared + self.assertTrue( + cpu_tensor1.storage().is_shared(), + "When share_memory=True, tensor storage should be shared", + ) + self.assertTrue( + cpu_tensor2.storage().is_shared(), + "When share_memory=True, tensor storage should be shared", + ) + else: + # When share_memory is False, the storage should not be shared + self.assertFalse( + cpu_tensor1.storage().is_shared(), + "When share_memory=False, tensor storage should not be shared", + ) + self.assertFalse( + cpu_tensor2.storage().is_shared(), + "When share_memory=False, tensor storage should not be shared", + ) + + +class TestDTensorStateDictStager(DTensorTestBase): + @with_comms + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_dtensor(self): + """ + Test that StateDictStager works correctly with DTensors. + """ + # Create a DTensor + device_mesh = dist.DeviceMesh("cuda", list(range(dist.get_world_size()))) + tensor = torch.randn(3, 3, device="cuda") + dtensor = DTensor.from_local(tensor, device_mesh, [Shard(0)]) + + dtensor = dtensor + 1 + dtensor = dtensor * 2 + + state_dict = { + "dtensor": dtensor, + } + + stager = StateDictStager(pin_memory=True, share_memory=True) + cpu_state_dict = stager.stage(state_dict) + + # Verify the original DTensor has the expected values + self.assertTrue(torch.allclose(dtensor.to_local(), (tensor + 1) * 2)) + self.assertTrue( + torch.allclose( + cpu_state_dict["dtensor"].to_local(), dtensor.to_local().cpu() + ) + ) + self.assertEqual(cpu_state_dict["dtensor"]._spec, dtensor._spec) + self.assertEqual(cpu_state_dict["dtensor"].size(), dtensor.size()) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index 590bf2c2e2c964..010ebf02ecd609 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -5,6 +5,12 @@ import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) from torch.distributed._state_dict_utils import ( _check_state_dict_similarity, _copy_state_dict, @@ -120,15 +126,42 @@ def create_dtensor(): } self.assertEqual(state_dict, _gather_state_dict(dist_state_dict)) + @with_comms @skip_if_lt_x_gpu(2) def test_create_cpu_state_dict(self): device = torch.device("cuda") + rank = dist.get_rank() + # Scale tensors based on world size + # to fit in the tensor shards accurately. + scale_factor = self.world_size buffer = io.BytesIO() torch.save(torch.ones(10), buffer) buffer.seek(0) state_dict = { "tensor1": torch.arange(10, device=device), "tensor2": torch.ones(10, device=device), + "sharded_tensor": init_from_local_shards( + [ + ShardedTensorShard( + tensor=torch.arange( + 50 * rank, 50 + 50 * rank, device=device + ).reshape(5, 10), + metadata=ShardMetadata( + shard_offsets=[5 * rank, 0], + shard_sizes=[5, 10], + placement=f"rank:{rank}/cuda:{rank}", + ), + ) + ], + torch.Size([5 * scale_factor, 10]), + ), + "dtensor": distribute_tensor( + torch.arange(50 * scale_factor, device=device).reshape( + 5 * scale_factor, 10 + ), + init_device_mesh("cuda", mesh_shape=(self.world_size,)), + [Shard(0)], + ), "non_tensor_bytes_io": copy.deepcopy(buffer), "non_tensor_bytes": buffer.read(), "step": torch.tensor(7, dtype=torch.float), @@ -148,10 +181,34 @@ def _verify(cpu_state_dict): # Verify if _copy_state_dict works for v in cpu_state_dict.values(): - if isinstance(v, torch.Tensor): - self.assertFalse(v.is_cuda) + if isinstance(v, (torch.Tensor, DTensor, ShardedTensor)): + self.assertTrue(v.device == torch.device("cpu")) self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) + self.assertEqual( + cpu_state_dict["sharded_tensor"].local_tensor(), + torch.arange(50 * rank, 50 + 50 * rank).reshape(5, 10), + ) + self.assertEqual( + cpu_state_dict["dtensor"].to_local(), + torch.arange(50 * rank, 50 + 50 * rank).reshape(5, 10), + ) + self.assertNotEqual( + cpu_state_dict["tensor1"].storage().data_ptr(), + state_dict["tensor1"].storage().data_ptr(), + ) + self.assertNotEqual( + cpu_state_dict["tensor2"].storage().data_ptr(), + state_dict["tensor2"].storage().data_ptr(), + ) + self.assertNotEqual( + cpu_state_dict["sharded_tensor"].local_tensor().storage().data_ptr(), + state_dict["sharded_tensor"].local_tensor().storage().data_ptr(), + ) + self.assertNotEqual( + cpu_state_dict["dtensor"].to_local().storage().data_ptr(), + state_dict["dtensor"].to_local().storage().data_ptr(), + ) buffer.seek(0) cpu_state_dict["non_tensor_bytes_io"].seek(0) self.assertEqual( @@ -163,6 +220,8 @@ def _verify(cpu_state_dict): self.assertEqual(cpu_state_dict["step"], 7) self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]}) + cpu_state_dict = _create_cpu_state_dict(state_dict) + _verify(cpu_state_dict) cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True) _verify(cpu_state_dict) cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True) diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index 9dc730379ecfc3..1074d11a77b08a 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -2,6 +2,7 @@ import io import sys +from typing import Optional import torch import torch.distributed as dist @@ -242,6 +243,48 @@ def test_scatter_object(self): expected_objects = rank assert scattered_objects == expected_objects + @with_comms + @skip_if_lt_x_gpu(2) + def test_broadcast_object_with_nonzero_coordinator(self): + # Everybody uses WORLD, but src is coordinator_rank=1 + dist_wrapper = _DistWrapper( + group=dist.group.WORLD, + use_dist=True, + coordinator_rank=1, + ) + + rank = dist.get_rank() + # only local rank 1 supplies the payload + payload: Optional[int] = rank if rank == 1 else None + + result = dist_wrapper.broadcast_object(payload) + # every rank should receive the value from global rank 1 + assert result == 1 + + @with_comms + @skip_if_lt_x_gpu(4) + def test_broadcast_object_global_local_mismatch(self): + # reproduces issue 152310 + + mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2)) + dist_wrapper = _DistWrapper( + group=mesh_2d.get_group(1), + use_dist=True, + coordinator_rank=1, # local coordinator index within the subgroup + ) + + rank = mesh_2d.get_rank() + + # only the local coordinator in each subgroup provides payload + payload: Optional[int] = rank if dist_wrapper.is_coordinator else None + got = dist_wrapper.broadcast_object(payload) + + # ensure we broadcast from the *global* coordinator rank, + # not the local index. For rows [0,1] this is global rank 1; + # for rows [2,3] this is global rank 3. + expected = dist_wrapper.global_coordinator_rank + assert got == expected + @with_comms @skip_if_lt_x_gpu(2) def test_barrier(self): diff --git a/test/distributed/elastic/agent/server/test/api_test.py b/test/distributed/elastic/agent/server/test/api_test.py index 4fa8ba8fb5ae94..517eb8eca367e1 100644 --- a/test/distributed/elastic/agent/server/test/api_test.py +++ b/test/distributed/elastic/agent/server/test/api_test.py @@ -34,7 +34,6 @@ from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters from torch.distributed.elastic.rendezvous.api import RendezvousGracefulExitError from torch.distributed.elastic.utils.distributed import get_free_port -from torch.testing._internal.common_utils import run_tests def do_nothing(): @@ -166,6 +165,7 @@ def _get_worker_spec( role="test_trainer", local_world_size=8, local_addr=None, + event_log_handler="null", ): run_id = str(uuid.uuid4().int) port = get_free_port() @@ -192,6 +192,7 @@ def _get_worker_spec( max_restarts=max_restarts, monitor_interval=monitor_interval, local_addr=local_addr, + event_log_handler=event_log_handler, ) return spec @@ -349,7 +350,8 @@ def test_rendezvous_master_addr_with_local_addr(self): self.assertGreater(worker_group.master_port, 0) @patch.object(TestAgent, "_construct_event") - def test_initialize_workers(self, mock_construct_event): + @patch("torch.distributed.elastic.agent.server.api.record") + def test_initialize_workers(self, mock_record, mock_construct_event): spec = self._get_worker_spec(max_restarts=1) agent = TestAgent(spec) worker_group = agent.get_worker_group() @@ -362,6 +364,30 @@ def test_initialize_workers(self, mock_construct_event): mock_construct_event.assert_called() self.assertEqual(mock_construct_event.call_count, 10) + mock_record.assert_called() + second_arg = mock_record.call_args_list[0][0][1] + self.assertEqual(second_arg, "null") + + @patch.object(TestAgent, "_construct_event") + @patch("torch.distributed.elastic.agent.server.api.record") + def test_initialize_workers_with_new_spec(self, mock_record, mock_construct_event): + spec = self._get_worker_spec( + max_restarts=1, event_log_handler="framework_logger" + ) + agent = TestAgent(spec) + worker_group = agent.get_worker_group() + agent._initialize_workers(worker_group) + + self.assertEqual(WorkerState.HEALTHY, worker_group.state) + for i in range(spec.local_world_size): + worker = worker_group.workers[i] + self.assertEqual(worker.id, worker.global_rank) + + mock_construct_event.assert_called() + self.assertEqual(mock_construct_event.call_count, 10) + mock_record.assert_called() + second_arg = mock_record.call_args_list[0][0][1] + self.assertEqual(second_arg, "framework_logger") def test_restart_workers(self): spec = self._get_worker_spec() @@ -650,4 +676,7 @@ def test_agent_process_handler_graceful_exception(self, invoke_run, _): if __name__ == "__main__": - run_tests() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py index f689f8f41f546c..ff89b5c51f01a6 100644 --- a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py +++ b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py @@ -1466,3 +1466,10 @@ def fail_rank_one_once(self): ) def test_rank_restart_after_failure(self): self.run_test_with_backend(backend="c10d", test_to_run=self.fail_rank_one_once) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index a6acc177ec81c8..6e0f273a7c8ebe 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -500,11 +500,13 @@ def test_wait_for_all_child_procs_to_exit(self): logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), ) - with mock.patch.object( - mpc, "_is_done", return_value=True - ), mock.patch.object(mpc, "_pc"), mock.patch.object( - mpc._pc, "join", side_effect=[True, False, False, True] - ) as mock_join: + with ( + mock.patch.object(mpc, "_is_done", return_value=True), + mock.patch.object(mpc, "_pc"), + mock.patch.object( + mpc._pc, "join", side_effect=[True, False, False, True] + ) as mock_join, + ): mpc._poll() self.assertEqual(4, mock_join.call_count) diff --git a/test/distributed/elastic/multiprocessing/redirects_test.py b/test/distributed/elastic/multiprocessing/redirects_test.py index 2fa507a15a36bb..8a34d3ba81d346 100644 --- a/test/distributed/elastic/multiprocessing/redirects_test.py +++ b/test/distributed/elastic/multiprocessing/redirects_test.py @@ -141,4 +141,7 @@ def c_print(i): if __name__ == "__main__": - unittest.main() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index cc0f1ad0375486..2822ed5dcb3a20 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -187,6 +187,7 @@ def test_match_info(self): "entries": [], "pg_config": { "0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"}, + "1": {"name": "1", "desc": "sub_pg", "ranks": "[0]"}, }, "rank": 0, }, @@ -194,6 +195,7 @@ def test_match_info(self): "entries": [], "pg_config": { "0": {"name": "0", "desc": "default_pg", "ranks": "[0, 1]"}, + "1": {"name": "1", "desc": "sub_pg", "ranks": "[1]"}, }, "rank": 1, }, @@ -209,10 +211,11 @@ def create_one_entry( collective_seq_id=0, p2p_seq_id=0, output_dtypes="float32", + pg_info=("0", "default"), ): event = create_one_event( collective_name, - ("0", "default"), + pg_info, input_sizes, output_sizes, state, @@ -229,7 +232,7 @@ class FlightRecorderE2ETest(TestCase): def testBuildDB(self): config = JobConfig() args = config.parse_args([]) - version = "2.7" # Same as the version in FlightRecorder.hpp + version = "2.8" # Same as the version in FlightRecorder.hpp LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_0"]["version"] = version LOADED_FR_DETAIL_TEMPLATE["dump_file_rank_1"]["version"] = version # Test case 1: matched all_reduce case. @@ -240,11 +243,25 @@ def testBuildDB(self): details1["dump_file_rank_1"]["entries"].append( create_one_entry(0, "all_reduce", [[4, 4]], [[4, 4]]) ) + details1["dump_file_rank_0"]["entries"].append( + create_one_entry( + 1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg") + ) + ) + details1["dump_file_rank_1"]["entries"].append( + create_one_entry( + 1, "all_reduce", [[5, 5]], [[5, 5]], pg_info=("1", "sub_pg") + ) + ) db = build_db(details1, args, version) - self.assertEqual(len(db.collectives), 1) + self.assertEqual(len(db.collectives), 3) self.assertEqual(db.collectives[0].record_id, 0) self.assertEqual(db.collectives[0].collective_name, "nccl:all_reduce") self.assertEqual(db.collectives[0].pass_check, True) + self.assertEqual(db.collectives[1].record_id, 1) + self.assertEqual(db.collectives[1].collective_name, "nccl:all_reduce") + self.assertEqual(db.collectives[1].pass_check, True) + self.assertEqual(db.collectives[2].pass_check, True) # Test case 2: matched allreduce_coalesced case. details2 = copy.deepcopy(LOADED_FR_DETAIL_TEMPLATE) details2["dump_file_rank_0"]["entries"].append( diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 42111efc8922dc..ac34246ee64322 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -56,32 +56,36 @@ def test_distributed_checkpoint(self, state_dict_type) -> None: torch.manual_seed(200) new_model = wrap(SkipModel(double_nest=True)) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertNotEqual(params, new_params) writer = FileSystemWriter(self.temp_dir) reader = FileSystemReader(self.temp_dir) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = model.state_dict() save(state_dict, writer) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = new_model.state_dict() load(state_dict, reader) new_model.load_state_dict(state_dict) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertEqual(params, new_params) diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py index aedeb688977d26..42fa0316230144 100644 --- a/test/distributed/fsdp/test_fsdp_comm.py +++ b/test/distributed/fsdp/test_fsdp_comm.py @@ -242,11 +242,10 @@ def test_communication( # and if `use_no_sync=False`, we only run `num_iters` iterations # outside `no_sync()` num_iters = 3 - with patch( - "torch.distributed.all_gather_into_tensor" - ) as mock_all_gather, patch( - "torch.distributed.reduce_scatter_tensor" - ) as mock_reduce_scatter: + with ( + patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, + patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter, + ): def reset_mocks(): mock_all_gather.reset_mock() diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 5f8b88bb6e5925..d6ee32c1f2e35f 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -379,12 +379,15 @@ def _register_pre_backward_hooks_with_count(*args, **kwargs): register_pre_backward_hooks_call_count += 1 return orig_register_pre_backward_hooks(*args, **kwargs) - with mock.patch( - "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", - _register_pre_backward_hooks_with_count, - ), mock.patch( - "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" - ) as register_post_bwd_mock: + with ( + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", + _register_pre_backward_hooks_with_count, + ), + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" + ) as register_post_bwd_mock, + ): self.assertEqual(register_pre_backward_hooks_call_count, 0) self.assertFalse(register_post_bwd_mock.called) fsdp_model(*input) diff --git a/test/distributed/fsdp/test_fsdp_flatten_params.py b/test/distributed/fsdp/test_fsdp_flatten_params.py index 5581318b1c386f..1e4a408b872924 100644 --- a/test/distributed/fsdp/test_fsdp_flatten_params.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -647,6 +647,36 @@ def test_flat_param_shard_metadata_with_memory_format(self, memory_format): ), ) + @skip_if_lt_x_gpu(1) + def test_writeback_orig_params_no_shard(self): + class EmbeddingModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = nn.Embedding(5, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.emb(x).sum() + + model = EmbeddingModel().half().to(self.rank) + fsdp_model = FSDP( + model, + sharding_strategy=HandleShardingStrategy.NO_SHARD, + use_orig_params=True, + ) + + # Copied from https://github.com/huggingface/accelerate/blob/main/src/accelerate/accelerator.py#L1679-1719 + for fsdp_module in FSDP.fsdp_modules(fsdp_model): + if not fsdp_module._has_params: + continue + param = fsdp_module._flat_param + param.data = param.data.float() + fsdp_module._handle._orig_param_dtype = torch.float32 + + x = torch.randint(0, 5, (20,), device=self.rank) + with torch.no_grad(): + out = fsdp_model(x) + self.assertEqual(out.shape, torch.Size([])) + instantiate_parametrized_tests(TestFlattenParams) diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index 1e51938a033fc5..b674b408462cd1 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -152,9 +152,9 @@ def permute_tensor(x: torch.Tensor): batches.append(tuple(permute_tensor(t) for t in batch)) for batch1, batch2 in itertools.combinations(batches, r=2): for t1, t2 in zip(batch1, batch2): - assert not torch.all( - t1 == t2 - ), "Check the test to make sure that batches are distinct" + assert not torch.all(t1 == t2), ( + "Check the test to make sure that batches are distinct" + ) # Concatenate the batches along the given batch dimension concat_batch: tuple[torch.Tensor, ...] = tuple( diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py index dc9b54be2dd7c7..70c415ae1fe7fd 100644 --- a/test/distributed/fsdp/test_fsdp_hybrid_shard.py +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -121,8 +121,9 @@ def test_raises_manual_wrap_hybrid_shard_when_none_policy(self): def test_hsdp_save_load_state_dict(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -171,8 +172,9 @@ def test_hsdp_save_load_state_dict(self): def test_hsdp_sync_module_state(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -310,8 +312,9 @@ def patched_collective(orig_collective, counter, *args, **kwargs): cntr = Counter() patched_allreduce = partial(patched_collective, orig_ar, cntr) patched_reduce_scatter = partial(patched_collective, orig_rs, cntr) - with patch_allreduce(patched_allreduce), patch_reduce_scatter( - patched_reduce_scatter + with ( + patch_allreduce(patched_allreduce), + patch_reduce_scatter(patched_reduce_scatter), ): inp = hsdp_model.get_input(device=torch.cuda.current_device()) out = hsdp_model(inp[0], inp[1]) @@ -355,9 +358,9 @@ def _test_fsdp_hybrid_shard_parity( use_orig_params, hsdp_process_groups=hsdp_pgs, ) - assert ( - hsdp_model._inter_node_pg.size() > 1 - ), "HSDP model initialized without replication" + assert hsdp_model._inter_node_pg.size() > 1, ( + "HSDP model initialized without replication" + ) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2) torch.manual_seed(global_pg.rank() + 1) diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index bb54f1c2d2c99d..dee38d04034677 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -766,9 +766,9 @@ def forward(self, x, expect_use_full_prec_in_eval): if expect_use_full_prec_in_eval: assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}" else: - assert ( - x.dtype == low_prec_dtype - ), f"Expected {low_prec_dtype}, got {x.dtype}" + assert x.dtype == low_prec_dtype, ( + f"Expected {low_prec_dtype}, got {x.dtype}" + ) return self.a(x) mp_config = MixedPrecision( diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index 326157ec9e4148..2cc3858e12696e 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -91,9 +91,9 @@ def _get_params_and_sharding_info( tensor_parallel_size: int, ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]: """ """ - assert ( - type(model) is SimpleModel - ), "Expects a `SimpleModel` since the sharding cases on the model definition" + assert type(model) is SimpleModel, ( + "Expects a `SimpleModel` since the sharding cases on the model definition" + ) param_name_to_numel = OrderedDict() param_name_to_sharding_info = OrderedDict() for param_name, param in model.named_parameters(): diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index a0e1d0a50cc078..7efe6ec6661ca9 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -654,9 +654,12 @@ def _test_multiple_forward( losses1 = [] losses2 = [] losses = [] - for _model, _optim in (fsdp_model, optim), ( - fsdp_model_orig_params, - optim_orig_params, + for _model, _optim in ( + (fsdp_model, optim), + ( + fsdp_model_orig_params, + optim_orig_params, + ), ): _optim.zero_grad() loss1 = _model(*inp1) @@ -1166,9 +1169,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: clean_tensor_name(tup[0]) for tup in self.named_parameters() ] params = [tup[1] for tup in self.named_parameters()] - assert ( - param_shapes[0] is not None and param_shapes[1] is not None - ), "`param_sizes` should be set" + assert param_shapes[0] is not None and param_shapes[1] is not None, ( + "`param_sizes` should be set" + ) assert_equal_fn( param_names, [ diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py index c167a1e03c2058..48465516a913b3 100644 --- a/test/distributed/launcher/api_test.py +++ b/test/distributed/launcher/api_test.py @@ -411,3 +411,10 @@ def test_rdzv_handler_shutdown_on_agent_error(self, mock_get_rdzv, mock_agent_ru launch_agent(config, simple_rank_scale, []) rdzv_handler_mock.shutdown.assert_called_once() record_event_mock.assert_called_once() + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py index 691c43ddb54290..f3ab4090e8dc94 100755 --- a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py +++ b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py @@ -19,6 +19,7 @@ see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched() - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched() """ + import argparse import torch.distributed as dist diff --git a/test/distributed/launcher/launch_test.py b/test/distributed/launcher/launch_test.py index 1ef7fa7e284bb4..a3b17b93d18e73 100644 --- a/test/distributed/launcher/launch_test.py +++ b/test/distributed/launcher/launch_test.py @@ -84,3 +84,10 @@ def test_launch_with_env(self): self.assertSetEqual( {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 5cf402a3b6e79f..603f671546a53b 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -6,7 +6,6 @@ # LICENSE file in the root directory of this source tree. import copy -import os import sys from contextlib import nullcontext from typing import Any, cast @@ -30,12 +29,23 @@ from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW, SGD -from torch.testing._internal import common_distributed +from torch.testing._internal.common_distributed import ( + DistributedTestBase, + logger, + requires_accelerator_dist_backend, + requires_ddp_rank, + requires_gloo, + skip_if_lt_x_gpu, + skip_if_no_gpu, + skip_if_rocm_multiprocess, + skip_if_win32, +) +from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - IS_WINDOWS, parametrize, run_tests, + skipIfHpu, ) @@ -47,63 +57,24 @@ HAS_TORCHVISION = False -# Use GLOO on GPU when running CUDA + Windows -def _get_backend_for_tests(): - return ( - dist.Backend.NCCL - if not IS_WINDOWS and torch.cuda.is_available() - # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when - # no GPUs are available. - else dist.Backend.GLOO - ) - - -BACKEND = _get_backend_for_tests() - +device_type = str(get_devtype()) -class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase): - def setUp(self): - super().setUp() - os.environ["WORLD_SIZE"] = str(self.world_size) - self._spawn_processes() +class TestZeroRedundancyOptimizer(DistributedTestBase): @property def device(self): - return ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) + return device_type @property def world_size(self): return 1 - def tearDown(self): - try: - torch.distributed.destroy_process_group() - except AssertionError: - pass - try: - os.remove(self.file_name) - except OSError: - pass - - def dist_init(self, rank, world_size=-1, backend=BACKEND): - if world_size < 1: - world_size = self.world_size - store = dist.FileStore(self.file_name, world_size) - return dist.init_process_group( - backend=backend, - store=store, - rank=rank, - world_size=world_size, - ) - class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer): def test_state_dict(self): """Check that ZeroRedundancyOptimizer exposes the expected state dict interface, irrespective of the sharding.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR1 = 0.1 LR2 = 0.01 MOMENTUM = 0.9 @@ -171,7 +142,7 @@ def test_state_dict(self): def test_lr_scheduler(self): """Check that a normal PyTorch ``lr_scheduler`` is usable with ZeroRedundancyOptimizer.""" - self.dist_init(self.rank) + self.create_pg(self.device) NUM_ITERS = 5 LR = 0.01 x = torch.tensor([1.0], device=self.device, requires_grad=True) @@ -193,7 +164,7 @@ def test_lr_scheduler(self): def test_step_with_kwargs(self): """Check that the ``step(**kwargs)`` interface is properly exposed.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.1 class SGDWithStepKWArg(torch.optim.SGD): @@ -217,7 +188,7 @@ def test_step_with_extra_inner_key(self): """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds extra keys to ``param_groups`` exposes those keys through ZeRO's own ``param_groups``.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.1 class SGDWithNewKey(torch.optim.SGD): @@ -236,7 +207,7 @@ def step(self, closure=None): def test_step_without_closure(self): """Check that the ``step()`` method (without closure) is handled as expected.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.1 class SGDWithoutClosure(torch.optim.SGD): @@ -255,7 +226,7 @@ def step(self): def test_zero_grad(self): """Check that the ``zero_grad`` method is properly handled.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.01 x = torch.rand(1) m = torch.nn.Linear(1, 1) @@ -271,7 +242,7 @@ def test_zero_grad(self): def test_constructor(self): """Check the robustness of the ZeroRedundancyOptimizer constructor by passing different values for the ``params`` argument.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.01 m = torch.nn.Sequential( torch.nn.Linear(5, 10), @@ -321,9 +292,9 @@ def test_constructor(self): betas=BETAS, eps=EPS, ) - assert ( - len(o.param_groups) == 2 - ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + assert len(o.param_groups) == 2, ( + f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + ) assert len(o.optim.param_groups) == 2, ( "Expected 2 local optimizer param groups, but got " f"{len(o.optim.param_groups)}" @@ -336,7 +307,7 @@ def test_same_dense_param_type(self): NOTE: This test should be removed once support for sparse parameters and varying parameter types is added. """ - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.01 inputs = [ [torch.sparse_coo_tensor(size=(2, 3))], @@ -353,25 +324,16 @@ def test_same_dense_param_type(self): class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): - @property - def device(self): - return ( - torch.device(self.rank) - if torch.cuda.is_available() - else torch.device("cpu") - ) - @property def world_size(self): - return min(4, max(2, torch.cuda.device_count())) + return min(4, max(2, torch.get_device_module(self.device).device_count())) @property def context(self): - return ( - nullcontext() - if not torch.cuda.is_available() - else torch.cuda.device(self.rank) - ) + if requires_ddp_rank(self.device): + return torch.get_device_module(self.device).device(self.rank) + else: + return nullcontext() def _check_same_model_params( self, @@ -396,12 +358,12 @@ def _check_same_model_params( msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message, ) - @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm_multiprocess + @skip_if_no_gpu + @skip_if_rocm_multiprocess def test_step(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step()`` interface.""" - self.dist_init(self.rank, world_size=self.world_size) + self.create_pg(self.device) LR = 0.01 with self.context: @@ -436,13 +398,12 @@ def test_step(self): self.assertEqual(m.weight, m_zero.weight) self.assertEqual(m.bias, m_zero.bias) - @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm_multiprocess + @skip_if_no_gpu + @skip_if_rocm_multiprocess def test_step_with_closure(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step(closure)`` interface.""" - self.dist_init(self.rank, world_size=self.world_size) - + self.create_pg(self.device) with self.context: for bucket_view in [False, True]: x_val = self.rank + 1 @@ -487,11 +448,11 @@ def closure(): self.assertEqual(m.weight, torch.tensor([[1.1]])) self.assertEqual(m.bias, torch.tensor([2.1])) - @common_distributed.skip_if_no_gpu + @skip_if_no_gpu def test_lr_scheduler(self): """Check that a normal PyTorch ``lr_scheduler`` is usable with ZeroRedundancyOptimizer.""" - self.dist_init(self.rank) + self.create_pg(self.device) x = torch.tensor([1.0], device=self.device, requires_grad=True) x2 = torch.tensor([1.0], device=self.device, requires_grad=True) o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01) @@ -519,7 +480,7 @@ def test_sharding(self): ``ZeroRedundancyOptimizer._partition_parameters()`` in zero_redundancy_optimizer.py. """ - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.01 sizes = [9, 7, 5, 3] params = [] @@ -541,7 +502,7 @@ def test_add_param_group(self): ``ZeroRedundancyOptimizer._partition_parameters()`` in zero_redundancy_optimizer.py. """ - self.dist_init(self.rank) + self.create_pg(self.device) LR = 0.01 # Test with all parameters trainable to begin with @@ -589,14 +550,14 @@ def some_trainable(): all_trainable() some_trainable() - @common_distributed.skip_if_no_gpu + @skip_if_no_gpu def test_multiple_param_groups(self): """ Check parity between constructing ZeRO with multiple parameter groups upfront versus adding parameter groups to ZeRO after construction versus a non-sharded optimizer. """ - self.dist_init(self.rank) + self.create_pg(self.device) BATCH_SIZE, NUM_ITERS = 8, 3 INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5 WD, LR = 0.01, 0.01 @@ -656,12 +617,12 @@ def test_multiple_param_groups(self): torch.testing.assert_close(layer1.bias, layer2.bias) torch.testing.assert_close(layer1.bias, layer3.bias) - @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm_multiprocess + @skip_if_no_gpu + @skip_if_rocm_multiprocess def test_collect_shards(self): """Check the state consolidation mechanism and the state dict exposed by ZeroRedundancyOptimizer.""" - self.dist_init(self.rank) + self.create_pg(self.device) LR = 1e-3 MOMENTUM = 0.99 BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5 @@ -719,27 +680,25 @@ def test_nondefault_process_group(self): # trivial MIN_WORLD_SIZE = 4 if self.world_size < MIN_WORLD_SIZE: - common_distributed.logger.info( + logger.info( "Skipping `test_nondefault_process_group()` since world size " "of %s is less than %s", self.world_size, MIN_WORLD_SIZE, ) return - BACKEND = dist.Backend.GLOO - self.dist_init(self.rank, self.world_size, BACKEND) - # Use GPU if enough are available, or fall back to CPU otherwise, which - # is fine since Gloo backend supports both - if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size: - device = torch.device(self.rank) - else: + # Use GPU if enough are available, or fall back to CPU otherwise + if torch.get_device_module(self.device).device_count() < self.world_size: device = torch.device("cpu") + else: + device = torch.device(self.device) + self.create_pg(device.type) # Create a new process group consisting of the even ranks to exercise # the case where the global and local ranks do not necessarily match subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0] process_group = dist.new_group( ranks=subgroup_ranks, - backend=BACKEND, + backend=self.backend(device.type), ) # Ranks not participating in the new process group are no longer needed if self.rank not in subgroup_ranks: @@ -754,9 +713,9 @@ def test_nondefault_process_group(self): LR = 1e-3 MOMENTUM = 0.99 REFERENCE_RANK = 0 - assert ( - REFERENCE_RANK in subgroup_ranks - ), "Reference rank must be in the new process group" + assert REFERENCE_RANK in subgroup_ranks, ( + "Reference rank must be in the new process group" + ) loss_fn = torch.nn.L1Loss().to(device) def check(optimizer): @@ -811,7 +770,7 @@ def closure(): ) check(optimizer) - @common_distributed.skip_if_no_gpu + @skip_if_no_gpu @parametrize( "optimizer_class_str", ["Adam", "AdamW", "SGD"], @@ -828,7 +787,7 @@ def test_local_optimizer_parity( ): """When combined with DDP, check that a local optimizer gives the same results as wrapping that optimizer with ZeroRedundancyOptimizer.""" - self.dist_init(self.rank) + self.create_pg(self.device) BATCHES = 20 BATCH_SIZE = 64 LR = 1e-3 @@ -867,7 +826,7 @@ def test_local_optimizer_parity( ) sharded_ddp_model = DDP( module=model, - device_ids=[self.rank], + device_ids=[self.rank] if requires_ddp_rank(self.device) else None, broadcast_buffers=True, find_unused_parameters=True, ) @@ -879,7 +838,7 @@ def test_local_optimizer_parity( ) ddp_model = DDP( local_model, - device_ids=[self.rank], + device_ids=[self.rank] if requires_ddp_rank(self.device) else None, broadcast_buffers=True, find_unused_parameters=True, ) @@ -892,7 +851,7 @@ def test_local_optimizer_parity( ) def check_step(): - input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)) + input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)).to(self.device) def closure_ddp(input_tensor=input_tensor): ddp_optimizer.zero_grad() @@ -970,13 +929,12 @@ def _test_zero_join(self, device): NUM_EPOCHS = 2 LR = 0.01 torch.manual_seed(0) - torch.cuda.manual_seed(0) + if "cpu" not in device: + torch.get_device_module(device).manual_seed(0) rank = self.rank world_size = self.world_size - is_gpu = device.type == "cuda" - backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO - self.dist_init(rank, world_size, backend) + self.create_pg(device) model = torch.nn.Sequential( torch.nn.Linear(2, 3), @@ -988,7 +946,9 @@ def _test_zero_join(self, device): # DDP ensures correct gradients in data parallel training, so DDP with # local optimizers on uneven inputs should be equivalent to ZeRO on # uneven inputs with gradients being manually set - ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model) + ddp_model = ( + DDP(model, device_ids=[rank]) if requires_ddp_rank(device) else DDP(model) + ) local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR) zero_model = copy.deepcopy(model) zero_model.to(device) @@ -1111,27 +1071,28 @@ def join_process_group(self): ) iter += 1 - @common_distributed.requires_nccl() - @common_distributed.skip_if_no_gpu + @requires_accelerator_dist_backend() + @skip_if_no_gpu def test_zero_join_gpu(self): """Check that the ZeRO join hook allows training with uneven inputs on GPU.""" self._test_zero_join(self.device) - @common_distributed.requires_gloo() + @requires_gloo() def test_zero_join_cpu(self): """Check that the ZeRO join hook allows training with uneven inputs on CPU.""" - self._test_zero_join(torch.device("cpu")) + self._test_zero_join("cpu") - def _test_zero_model_parallel(self, parameters_as_bucket_view: bool): + def _test_zero_model_parallel(self, parameters_as_bucket_view: bool, device: str): # Use two processes each with two GPUs assert self.rank < 2 NUM_EPOCHS = 2 NUM_INPUTS = 4 LR = 0.01 torch.manual_seed(0) - torch.cuda.manual_seed(0) + if "cpu" not in device: + torch.get_device_module(device).manual_seed(0) class ModelParallelModel(torch.nn.Module): def __init__(self, dev0, dev1): @@ -1204,24 +1165,31 @@ def closure_ddp(): # Increased tolerances are needed to pass when using TF32 # See: https://github.com/pytorch/pytorch/issues/67764 - torch.testing.assert_close( - local_loss.cpu(), - ddp_loss.cpu(), - rtol=1e-03, - atol=1e-08, - ), "Losses differ between local optimizer and ZeRO" + ( + torch.testing.assert_close( + local_loss.cpu(), + ddp_loss.cpu(), + rtol=1e-03, + atol=1e-08, + ), + "Losses differ between local optimizer and ZeRO", + ) for local_p, ddp_p in zip( local_model.parameters(), ddp_model.parameters() ): - torch.testing.assert_close( - local_p.cpu(), - ddp_p.cpu(), - rtol=1e-03, - atol=1e-04, - ), "Models differ after a step" - - @common_distributed.skip_if_lt_x_gpu(4) + ( + torch.testing.assert_close( + local_p.cpu(), + ddp_p.cpu(), + rtol=1e-03, + atol=1e-04, + ), + "Models differ after a step", + ) + + @skipIfHpu + @skip_if_lt_x_gpu(4) @parametrize( "parameters_as_bucket_view", [False, True], @@ -1234,8 +1202,8 @@ def test_zero_model_parallel( layers are assigned to different devices.""" if self.rank >= 2: return - self.dist_init(self.rank, world_size=2) - self._test_zero_model_parallel(parameters_as_bucket_view) + self.create_pg(self.device, world_size=2) + self._test_zero_model_parallel(parameters_as_bucket_view, self.device) def _test_ddp_zero_overlap( self, @@ -1250,12 +1218,10 @@ def _test_ddp_zero_overlap( SGD_WEIGHT_DECAY = 0.001 NUM_INPUTS = 5 torch.manual_seed(0) - torch.cuda.manual_seed(0) + if "cpu" not in device: + torch.get_device_module(device).manual_seed(0) rank = self.rank - is_gpu = device.type == "cuda" - if is_gpu: - torch.cuda.set_device(device) models_to_test = [ ( torch.nn.Sequential( @@ -1273,11 +1239,16 @@ def _test_ddp_zero_overlap( ) ) for model, inputs in models_to_test: - # Enable determinism in cudnn operators - with torch.backends.cudnn.flags( - enabled=True, deterministic=True, benchmark=False - ): - device_ids = [rank] if is_gpu else None + # Select deterministic context based on device + det_ctx = ( + torch.backends.cudnn.flags( + enabled=True, deterministic=True, benchmark=False + ) + if "cuda" in device + else torch.use_deterministic_algorithms(True) + ) + with det_ctx: + device_ids = [rank] if requires_ddp_rank(device) else None # Set up the DDP model overlapping with ZeRO ddp_model_overlap = DDP( copy.deepcopy(model).to(device), @@ -1374,10 +1345,10 @@ def _test_ddp_zero_overlap( # NOTE: The test is skipped if using Windows since functional optimizers # are not currently supported. - @common_distributed.skip_if_win32() - @common_distributed.requires_nccl() - @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm_multiprocess + @skip_if_win32() + @requires_accelerator_dist_backend() + @skip_if_no_gpu + @skip_if_rocm_multiprocess @parametrize( "use_gpu", [True], @@ -1413,9 +1384,7 @@ def test_ddp_zero_overlap( by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO and DDP arguments achieves parity with DDP using a local optimizer. """ - device = torch.device(self.rank) if use_gpu else torch.device("cpu") - backend = _get_backend_for_tests() - self.dist_init(self.rank, self.world_size, backend) + self.create_pg(self.device) hook_constructor = ( hook_with_zero_step if not use_interleaved_hook @@ -1423,7 +1392,7 @@ def test_ddp_zero_overlap( ) self._test_ddp_zero_overlap( - device, + self.device if use_gpu else "cpu", hook_constructor, gradient_as_bucket_view, static_graph, diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index 3e02c4de3c93b6..8ddb5634811cb9 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -89,9 +89,9 @@ def test_model_split(self, ModelClass): mb_args=(x, y), ) - assert ( - pipe.num_stages == EXPECTED_N_STAGES[ModelClass] - ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], ( + f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + ) ref_out = mod(x, y) out = pipe(x, y)[0] @@ -109,9 +109,7 @@ def test_model_split(self, ModelClass): new_names.update(stage_fqns) if CHECK_FQN_SET_EQUALITY: - assert ( - old_names == new_names - ), f""" + assert old_names == new_names, f""" old names {old_names} new names {new_names} """ diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index e0db341254399c..50aa9ff21ba08c 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -44,13 +44,13 @@ d_hid = 512 batch_size = 256 - torch.manual_seed(0) - device_type = "cuda" class ScheduleTest(MultiProcContinousTest): + world_size = 2 + @classmethod def backend_str(cls) -> str: # Testing with NCCL backend @@ -513,15 +513,13 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: - torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=1e-3) except AssertionError: print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise @requires_nccl() - @skip_but_pass_in_sandcastle_if( - not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" - ) + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) @@ -615,9 +613,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if( - not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" - ) + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "ScheduleClass", [ @@ -722,9 +718,7 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if( - not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" - ) + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] ) @@ -829,9 +823,7 @@ def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): raise @requires_nccl() - @skip_but_pass_in_sandcastle_if( - not torch.cuda.device_count() == 2, "This test requires exactly 2 GPUs" - ) + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 5fb30b5e1d17f2..ae1e684d7c222f 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -60,9 +60,9 @@ def test_unflatten(self, device): for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) for param_name, _ in stage_mod.named_parameters(): - assert ( - param_name in orig_state_dict - ), f"{param_name} not in original state dict" + assert param_name in orig_state_dict, ( + f"{param_name} not in original state dict" + ) print("Param qualname test passed") # Check equivalence diff --git a/test/distributed/rpc/test_share_memory.py b/test/distributed/rpc/test_share_memory.py index bda98b1df94952..97273981d08297 100644 --- a/test/distributed/rpc/test_share_memory.py +++ b/test/distributed/rpc/test_share_memory.py @@ -45,9 +45,9 @@ def __init__(self) -> None: for t in torch._tensor_classes: self._dispatch_table[t] = TorchMpReductions.reduce_tensor self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor - self._dispatch_table[ - torch.nn.parameter.Parameter - ] = TorchMpReductions.reduce_tensor + self._dispatch_table[torch.nn.parameter.Parameter] = ( + TorchMpReductions.reduce_tensor + ) def worker_loop(a): diff --git a/test/distributed/tensor/debug/test_comm_mode_features.py b/test/distributed/tensor/debug/test_comm_mode_features.py index 4242cd20a71af9..6c07431291508e 100644 --- a/test/distributed/tensor/debug/test_comm_mode_features.py +++ b/test/distributed/tensor/debug/test_comm_mode_features.py @@ -11,7 +11,7 @@ parallelize_module, RowwiseParallel, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, skipIfHpu from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, MLPModule, @@ -111,6 +111,7 @@ def test_MLP_distributed_sharding_display(self): ) self.check_same_set_of_keys(module_sharding_dict, comm_mode.get_sharding_info()) + @skipIfHpu @with_comms def test_MLPStacked_distributed_sharding_display(self): """ @@ -143,10 +144,10 @@ def test_MLPStacked_distributed_sharding_display(self): model2 = MLPStacked(self.device_type) parallelize_plan = { - "MLPStacked.layers.0.net1": ColwiseParallel(), - "MLPStacked.layers.0.net2": RowwiseParallel(), - "MLPStacked.layers.1.net1": ColwiseParallel(), - "MLPStacked.layers.1.net2": RowwiseParallel(), + "layers.0.net1": ColwiseParallel(), + "layers.0.net2": RowwiseParallel(), + "layers.1.net1": ColwiseParallel(), + "layers.1.net2": RowwiseParallel(), } model2 = parallelize_module(model2, device_mesh, parallelize_plan) @@ -218,6 +219,7 @@ def test_MLP_module_tracing(self): 1, ) + @skipIfHpu @skip_unless_torch_gpu @with_comms def test_transformer_module_tracing(self, is_seq_parallel=False): diff --git a/test/distributed/tensor/experimental/test_local_map.py b/test/distributed/tensor/experimental/test_local_map.py index 3b8e3b4c78a4c9..fbbec59293ba0c 100644 --- a/test/distributed/tensor/experimental/test_local_map.py +++ b/test/distributed/tensor/experimental/test_local_map.py @@ -5,9 +5,16 @@ import torch import torch.distributed._functional_collectives as funcol from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard +from torch.distributed.tensor import ( + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.experimental import local_map +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -318,6 +325,109 @@ def test_local_map_redistribute(self): with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) + # check for `in_grad_placements` handling + @with_comms() + def test_local_map_with_grad_placement(self): + """ + Test the gradient result is correct when we specify the right + `in_grad_placements`. + """ + device_mesh = init_device_mesh( + device_type=self.device_type, mesh_shape=(self.world_size,) + ) + torch.manual_seed(12) + + # ground truth output, consider X as a batch of 2 on dim 0. + X = torch.randn(4, 2, device=self.device_type, requires_grad=True) + X1, X2 = torch.chunk(X, 2, dim=0) + X1 = X1.detach().requires_grad_() + X2 = X2.detach().requires_grad_() + W = torch.randn(2, 4, device=self.device_type, requires_grad=True) + Y1 = torch.mm(X1, W) + Y2 = torch.mm(X2, W) + loss = Y1.sum() + Y2.sum() + loss.backward() + + in_placement_mismatch_choice = (False, True) + for is_in_placement_mismatch in in_placement_mismatch_choice: + if is_in_placement_mismatch: + # in_placements for local_map() will take effect + X_dt = distribute_tensor(X, device_mesh, replicate) + else: + # in_placements for local_map() will not take effect + X_dt = distribute_tensor(X, device_mesh, row_wise) + W_dt = distribute_tensor(W, device_mesh, replicate) + in_grad_placements = ([Shard(0)], [Partial()]) + + local_mm_forward = local_map( + mm_forward, + out_placements=[Shard(0)], + in_placements=(row_wise, replicate), + in_grad_placements=in_grad_placements, + device_mesh=device_mesh, + redistribute_inputs=True, + ) + Y_dt = local_mm_forward(X_dt, W_dt) + self.assertEqual(Y_dt.full_tensor(), torch.cat([Y1, Y2], dim=0)) + + # Note: this is a way to simulate how DPP works. We don't need to + # all_gather the loss. Instead, we do all_reduce to each distributed + # weight. + loss = Y_dt.to_local().sum() + loss.backward() + + if not is_in_placement_mismatch: + self.assertEqual(X_dt.grad.placements, in_grad_placements[0]) + self.assertEqual(W_dt.grad.placements, in_grad_placements[1]) + # regardless of is_in_placement_mismatch, grad output should always + # match + self.assertEqual( + X_dt.grad.full_tensor(), torch.cat([X1.grad, X2.grad], dim=0) + ) + self.assertEqual(W_dt.grad.full_tensor(), W.grad) + + @skip_if_lt_x_gpu(4) + @with_comms + def test_multi_mesh_inputs(self): + """ + Test the function can be applied to accept DTensors that lives + on different device meshes. + """ + mesh_full = init_device_mesh( + device_type=self.device_type, mesh_shape=(self.world_size,) + ) + mesh_2d = init_device_mesh( + device_type=self.device_type, mesh_shape=(self.world_size // 2, 2) + ) + comm_mode = CommDebugMode() + + X = torch.randn(8, 32, device=self.device_type, requires_grad=False) + x_placements = [Shard(1)] + W = torch.randn(16, 8, device=self.device_type, requires_grad=False) + w_placements = [Shard(0), Shard(1)] + + X_dt = distribute_tensor(X, mesh_full, x_placements) + W_dt = distribute_tensor(W, mesh_2d, w_placements) + + # local output shape should be (8, 4) + output_placements = [Replicate(), Shard(1)] + + local_mm_forward = local_map( + mm_forward, + out_placements=output_placements, + in_placements=(x_placements, w_placements), + device_mesh=mesh_2d, + ) + + with comm_mode: + Y_dt = local_mm_forward(X_dt, W_dt) + + self.assertEqual(comm_mode.get_total_counts(), 0) + # output local shape should be (8, 4) + self.assertEqual(Y_dt.to_local().shape, (8, 4)) + # output lives in mesh_2d + self.assertEqual(Y_dt.device_mesh, mesh_2d) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 7a64e1fce19747..906b7d1a4a52b4 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -13,7 +13,7 @@ micro_pipeline_tp_pass, ) from torch._inductor.fx_passes.post_grad import remove_noop_ops, view_to_reshape -from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code +from torch._inductor.utils import fresh_cache, run_and_get_triton_code from torch.distributed._functional_collectives import ( all_gather_tensor, reduce_scatter_tensor, @@ -81,7 +81,7 @@ def tearDown(self): dist.destroy_process_group() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_find_all_gather_patterns(self): group = dist.group.WORLD @@ -134,7 +134,7 @@ def func( ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_find_reduce_scatter_patterns(self): group = dist.group.WORLD @@ -173,7 +173,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: self.assertEqual(reduce_scatters[1].scatter_dim, 1) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_get_unexposed_collectives(self): group = dist.group.WORLD @@ -201,7 +201,7 @@ def func(inp: torch.Tensor) -> torch.Tensor: @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @parametrize("return_A", [True, False]) - @fresh_inductor_cache() + @fresh_cache() def test_fuse_all_gather_matmul(self, A_dims, gather_dim, return_A): if gather_dim >= A_dims: return @@ -248,7 +248,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @parametrize("return_A", [True, False]) - @fresh_inductor_cache() + @fresh_cache() def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim, return_A): if gather_dim >= A_dims: return @@ -321,7 +321,7 @@ def func( @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) - @fresh_inductor_cache() + @fresh_cache() def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim): if scatter_dim >= A_dims: return @@ -350,7 +350,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) - @fresh_inductor_cache() + @fresh_cache() def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim): if scatter_dim >= A_dims: return @@ -403,7 +403,7 @@ def func( @runOnRocmArch(MI300_ARCH) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("scatter_dim", [0, 1, 2]) - @fresh_inductor_cache() + @fresh_cache() def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape( self, scatter_dim ): @@ -465,7 +465,7 @@ def reshape_mm_reshape( @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @parametrize("shard_dim", [0, 1]) - @fresh_inductor_cache() + @fresh_cache() def test_dtensor_seq_par(self, shard_dim: int): model: torch.nn.Module = MLPModule(device="cuda", bias=False) device_mesh = DeviceMesh( diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 335d75522ed702..cc41b250e34aa2 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -332,6 +332,49 @@ def test_parallelize_module_multi_wildcard(self): ) self._compare_module(model, model_tp, inp_size, rank0_only=False) + @with_comms + def test_parallelize_module_with_root_module(self): + inp_size = [16, 10] + model = MLPModule(self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + model_tp = deepcopy(model) + model_tp = parallelize_module( + model_tp, + device_mesh, + { + "": PrepareModuleInputOutput( + input_layouts=Replicate(), + desired_input_layouts=Shard(0), + output_layouts=Shard(0), + desired_output_layouts=Replicate(), + ), + "net1": ColwiseParallel(input_layouts=Shard(0)), + "net2": RowwiseParallel(output_layouts=Shard(0)), + }, + ) + self._compare_module(model, model_tp, inp_size, rank0_only=False) + + @with_comms + def test_parallelize_module_with_no_match(self): + inp_size = [16, 10] + model = MLPModule(self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + model_tp = deepcopy(model) + with self.assertWarns(UserWarning): + model_tp = parallelize_module( + model_tp, + device_mesh, + { + "net0.hello.world": ColwiseParallel(), + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + }, + ) + self._compare_module(model, model_tp, inp_size, rank0_only=False) + @with_comms def test_under_devicemesh_context(self): # test ColwiseParallel @@ -357,7 +400,8 @@ def test_empty_plan(self): # Call parallelize_module with empty plan. # Goal is not to crash. device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - parallelize_module(model, device_mesh) + with self.assertWarns(UserWarning): + parallelize_module(model, device_mesh) if __name__ == "__main__": diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py index ae47198a295636..a12bf017932f2a 100644 --- a/test/distributed/tensor/parallel/test_tp_random_state.py +++ b/test/distributed/tensor/parallel/test_tp_random_state.py @@ -21,10 +21,10 @@ def get_tensor_slice(self, idx, n, large_tensor): assert shape[0] % n == 0 local_shape = [shape[0] // n, shape[1]] - slice_idx = [ + slice_idx = ( slice(idx * local_shape[0], (idx + 1) * local_shape[0]), slice(local_shape[1]), - ] + ) return large_tensor[slice_idx] def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc): diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 5d40a18f06742a..b6588c2ad95eb1 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -187,7 +187,7 @@ def test_depthwise_convolution(self): @skip_if_lt_x_gpu(2) def test_conv_backward_none_grad_inp(self): device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(self.world_size,) + device_type=self.device_type, mesh_shape=(self.world_size,) ) conv = nn.Conv2d(64, 64, 3, padding=1).train() x = torch.randn(1, 64, 32, 32) diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index eecdb08ecec69c..b41d87c7e44b4c 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -29,7 +29,7 @@ parallelize_module, RowwiseParallel, ) -from torch.testing._internal.common_utils import IS_FBCODE, run_tests +from torch.testing._internal.common_utils import IS_FBCODE, run_tests, skipIfHpu from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -540,6 +540,7 @@ def test_dtensor_save_load(self): reloaded_st = torch.load(buffer, weights_only=True) self.assertEqual(sharded_tensor, reloaded_st) + @skipIfHpu @with_comms @unittest.skipIf( IS_FBCODE, diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 83c4de81a5d978..23114f87f46a83 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -36,6 +36,7 @@ instantiate_parametrized_tests, parametrize, run_tests, + skipIfHpu, skipIfTorchDynamo, TEST_CUDA, TEST_HPU, @@ -255,6 +256,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfHpu def test_dtensor_dynamic(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -481,6 +483,7 @@ def fn(x): self.assertEqual(fn(x3), opt_fn(x3)) self.assertEqual(cnt.frame_count, 2) + @skipIfHpu def test_dtensor_partial_placement_redistribute_unbalanced_correct_strides(self): # Partial -> Shard on an unbalanced tensor results in: # - A contiguous DTensor @@ -650,6 +653,7 @@ def redistribute_kwargs_fn(x): res = opt_kwargs_fn(x) self.assertEqual(res, ref) + @skipIfHpu def test_dynamo_dtensor_from_local_redistribute_async(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) from torch.distributed._functional_collectives import AsyncCollectiveTensor @@ -721,6 +725,7 @@ def fn(x_dt): res = opt_fn(x_dt) self.assertEqual(ref, res) + @skipIfHpu def test_graph_input_is_async(self): from torch.distributed._functional_collectives import AsyncCollectiveTensor @@ -801,12 +806,7 @@ def fn(x): out_dt = torch.matmul(tmp_dt, y_dt) out_dt.sum().backward() - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_lt_x_gpu(1) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) - @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True) - def test_tp_compile_comm_reordering(self): + def _test_tp_compile_comm_reordering(self): class FakeAttention(nn.Module): def __init__(self) -> None: super().__init__() @@ -872,9 +872,24 @@ def forward(self, input): "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal" ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check( "extern_kernels.mm(buf0," - ).run( - code - ) + ).run(code) + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @skip_if_lt_x_gpu(1) + # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor + @patch.object(torch._inductor.config, "compile_threads", 1) + @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True) + def test_tp_compile_comm_reordering(self): + self._test_tp_compile_comm_reordering() + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @skip_if_lt_x_gpu(1) + # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor + @patch.object(torch._inductor.config, "compile_threads", 1) + @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True) + @torch._inductor.config.patch("graph_partition", True) + def test_tp_compile_comm_reordering_graph_partition(self): + self._test_tp_compile_comm_reordering() @instantiate_parametrized_tests @@ -962,7 +977,7 @@ def test_2d_fsdp_tp_compile(self, use_ca): # 2-D mesh is [dp, tp] twod_mesh = init_device_mesh( - "cuda", + self.device_type, (data_parallel_size, self.world_size // data_parallel_size), mesh_dim_names=["dp", "tp"], ) @@ -1015,7 +1030,9 @@ def test_2d_fsdp_tp_ac_compile(self, use_ca): # 2-D mesh is [dp, tp] mesh_2d = init_device_mesh( - "cuda", mesh_shape=(dp_degree, tp_degree), mesh_dim_names=("dp", "tp") + self.device_type, + mesh_shape=(dp_degree, tp_degree), + mesh_dim_names=("dp", "tp"), ) inp = torch.rand(20, 10, device=self.device_type) @@ -1061,7 +1078,9 @@ def test_2d_fsdp_tp_ac_compile(self, use_ca): @skip_if_lt_x_gpu(4) @parametrize("use_ca", [True, False]) def test_compile_dtensor_redistribute_backward(self, use_ca): - mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size)) + mesh = DeviceMesh( + device_type=self.device_type, mesh=torch.arange(self.world_size) + ) def fn(x, y): dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index cddc352eda8388..ba43335d1ddcb2 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -385,6 +385,8 @@ def wrapped(fn): xfail("special.bessel_y1"), xfail("special.chebyshev_polynomial_t"), xfail("special.chebyshev_polynomial_u"), + xfail("special.chebyshev_polynomial_v"), + xfail("special.chebyshev_polynomial_w"), xfail("special.entr"), xfail("special.erfcx"), xfail("special.hermite_polynomial_h"), @@ -393,6 +395,7 @@ def wrapped(fn): xfail("special.i1"), xfail("special.i1e"), xfail("special.laguerre_polynomial_l"), + xfail("special.legendre_polynomial_p"), xfail("special.log_ndtr"), xfail("special.modified_bessel_i0"), xfail("special.modified_bessel_i1"), @@ -401,6 +404,10 @@ def wrapped(fn): xfail("special.ndtri"), xfail("special.scaled_modified_bessel_k0"), xfail("special.scaled_modified_bessel_k1"), + xfail("special.shifted_chebyshev_polynomial_t"), + xfail("special.shifted_chebyshev_polynomial_u"), + xfail("special.shifted_chebyshev_polynomial_v"), + xfail("special.shifted_chebyshev_polynomial_w"), xfail("special.spherical_bessel_j0"), xfail("special.xlog1py"), xfail("special.zeta"), diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 0ce1206ae1bd4f..48f92c4ecd7424 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -441,9 +441,11 @@ class SubTest(NamedTuple): out_req_grad: bool subtest_fails = {} - valid_filter = lambda cfg: not ( # noqa: E731 - cfg.ln_req_grad and not cfg.elementwise_affine - ) and any(cfg[2:]) + valid_filter = ( # noqa: E731 + lambda cfg: ( + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) + ) + ) subtest_cfgs = list( filter( valid_filter, @@ -566,9 +568,9 @@ def forward(self, tokens): except Exception as e: subtest_fails[subtest_cfg] = e # if any subtest fails, provide the failed subtests and report the overall failure - assert ( - not subtest_fails - ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + assert not subtest_fails, ( + f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + ) @with_comms def test_topk(self): diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index b1ca23cfc05fbe..886fb7f1468a28 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -6,7 +6,7 @@ from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy +from torch.distributed.tensor._op_schema import OpSchema, OpSpec, OpStrategy from torch.distributed.tensor._ops._einsum_strategy import ( EinsumDims, gen_einsum_strategies, @@ -184,9 +184,9 @@ def test_redistribute_cost_latency(self): op_schema = OpSchema( torch.ops.aten.addmm.default, ( - OpStrategy([PlacementStrategy(shard0_spec)]), - OpStrategy([PlacementStrategy(partial_spec)]), - OpStrategy([PlacementStrategy(shard1_spec)]), + OpStrategy([OpSpec(shard0_spec)]), + OpStrategy([OpSpec(partial_spec)]), + OpStrategy([OpSpec(shard1_spec)]), ), {}, ) @@ -261,8 +261,8 @@ def test_mm_strategies(self): op_schema = OpSchema( torch.ops.aten.mm.default, ( - OpStrategy([PlacementStrategy(lhs_spec)]), - OpStrategy([PlacementStrategy(rhs_spec)]), + OpStrategy([OpSpec(lhs_spec)]), + OpStrategy([OpSpec(rhs_spec)]), ), {}, ) @@ -308,8 +308,8 @@ def test_bmm_strategies(self): op_schema = OpSchema( torch.ops.aten.bmm.default, ( - OpStrategy([PlacementStrategy(lhs_spec)]), - OpStrategy([PlacementStrategy(rhs_spec)]), + OpStrategy([OpSpec(lhs_spec)]), + OpStrategy([OpSpec(rhs_spec)]), ), {}, ) diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index 2fce3efd7fcdcb..28dd1ac9def51a 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -17,6 +17,7 @@ Replicate, Shard, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorOpTestBase, @@ -147,14 +148,6 @@ def test_partial_add(self): d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) - def test_partial_mul(self): - device_mesh = self.build_device_mesh() - d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) - d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) - d_3 = d_1 * d_2 - self.assertTrue(d_3._spec.placements[0].is_replicate()) - self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2)) - def test_activations(self): device_mesh = self.build_device_mesh() self._run_sharded_elementwise_ops( @@ -282,6 +275,62 @@ def test_mul_out(self): self.assertEqual(input_tensor, dtensor.to_local()) self.assertEqual(expected, dt.to_local()) + def test_mul_partial(self): + # we only test the partial behavior for mul op as other placement + # behaviors should be well tested in test_dtensor_ops.py + device_mesh = self.build_device_mesh() + comm_mode = CommDebugMode() + # 1. simple test for partial * partial + d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + with comm_mode: + d_3 = d_1 * d_2 + comm_counts = comm_mode.get_total_counts() + self.assertEqual(comm_counts, 1) + self.assertTrue(isinstance(d_3, DTensor)) + self.assertEqual(d_3.placements, (Partial(),)) + self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size)) + + # 2. test the partial input DTensor * scalar/replicate input + input = torch.full((8, 8), 1.0, device=self.device_type) + + # test for different types of other inputs + other_inps = ( + 2.0, # scalar + torch.tensor(2.0, device=self.device_type), # scalar tensor + torch.full((8, 8), 2.0, device=self.device_type), # tensor + ) + + for partial_op in ["sum", "avg"]: + expected_p_out = ( + input * self.world_size * 2.0 if partial_op == "sum" else input * 2.0 + ) + + d_input = DTensor.from_local(input, device_mesh, [Partial(partial_op)]) + + for other_inp in other_inps: + if isinstance(other_inp, Tensor) and other_inp.numel() > 1: + d_other = distribute_tensor(other_inp, device_mesh, [Replicate()]) + else: + d_other = other_inp + + with comm_mode: + z = d_input * d_other + + comm_counts = comm_mode.get_total_counts() + self.assertEqual(comm_counts, 0) + self.assertTrue(isinstance(z, DTensor)) + self.assertEqual(z.placements, (Partial(partial_op),)) + self.assertEqual(z.full_tensor(), expected_p_out) + + # test other partial to assert the partial not getting propagated + d_input = DTensor.from_local(input, device_mesh, [Partial("max")]) + d_other = distribute_tensor(torch.ones(8, 8), device_mesh, [Replicate()]) + + z = d_input * d_other + self.assertEqual(z.placements, (Replicate(),)) + self.assertEqual(z.to_local(), input) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 3b37f338f47e87..5e98934249e97c 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -24,7 +24,7 @@ from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module -from torch.testing._internal.common_utils import run_tests, TEST_HPU +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_if_lt_x_gpu, @@ -33,9 +33,6 @@ ) -TYPE_DEVICE = "hpu" if TEST_HPU else "cuda" - - class DistTensorRandomInitTest(DTensorTestBase): def _run_init_op(self, init_op, *args, **kwargs): device_mesh = self.build_device_mesh() @@ -55,7 +52,7 @@ def _run_init_op(self, init_op, *args, **kwargs): self.assertEqual(local_tensor_clone, dtensor.to_local()) else: # create DTensor from Tensor - _tensor = torch.empty(*input_size, device=TYPE_DEVICE) + _tensor = torch.empty(*input_size, device=self.device_type) dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)]) # DTensor random init @@ -65,12 +62,12 @@ def _run_init_op(self, init_op, *args, **kwargs): # compare with local tensors from other ranks for other_rank in range(self.world_size): if self.rank != other_rank: - slice_idx = [ + slice_idx = ( slice(input_size[0]), slice( other_rank * input_size[1], (other_rank + 1) * input_size[1] ), - ] + ) # other rank should have a different local tensor self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor) @@ -173,7 +170,9 @@ def test_tp_model_meta_init(self): self.assertEqual(model.weight.device, torch.device("meta")) # actual initialization - device = torch.device("cuda", torch.cuda.current_device()) + device = torch.device( + self.device_type, torch.get_device_module(self.device_type).current_device() + ) model.to_empty(device=device) model.reset_parameters() self.assertTrue( @@ -224,7 +223,9 @@ def test_fsdp_tp_model_meta_init(self): self.assertEqual(model.weight.device, torch.device("meta")) # actual initialization - device = torch.device("cuda", torch.cuda.current_device()) + device = torch.device( + self.device_type, torch.get_device_module(self.device_type).current_device() + ) model.to_empty(device=device) model.reset_parameters() self.assertTrue( @@ -266,7 +267,9 @@ def test_rng_tracker_init(self): # seed synchronization now does NOT happen after the first `distribute_tensor` # call dt = distribute_tensor( - torch.empty([self.world_size], device=TYPE_DEVICE), device_mesh, [Shard(0)] + torch.empty([self.world_size], device=self.device_type), + device_mesh, + [Shard(0)], ) self.assertTrue(random._rng_tracker is None) # seed synchronization only happens after `manual_seed` or the first DTensor @@ -366,7 +369,7 @@ def test_deterministic_dropout_1d(self): size = [4, 4] dtensor = distribute_tensor( - torch.empty(*size, device=TYPE_DEVICE), device_mesh, [Shard(1)] + torch.empty(*size, device=self.device_type), device_mesh, [Shard(1)] ) # a random op call shifts the offset @@ -537,9 +540,9 @@ def test_deterministic_uniform_2d(self): slice(offset, offset + size) for offset, size in other_local_shard ] if local_shard_offset == other_local_shard_offset: - self.assertEqual(full_tensor[slice_idx], local_tensor) + self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) else: - self.assertNotEqual(full_tensor[slice_idx], local_tensor) + self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) class DistTensorRandomOpsTest3D(DTensorTestBase): @@ -571,7 +574,9 @@ def test_hsdp_tp_model_meta_init(self): self.assertEqual(model.weight.device, torch.device("meta")) # actual initialization - device = torch.device("cuda", torch.cuda.current_device()) + device = torch.device( + self.device_type, torch.get_device_module(self.device_type).current_device() + ) model.to_empty(device=device) model.reset_parameters() self.assertTrue( diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 341cc9a8b9ad7e..8087b0144f367a 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -15,7 +15,13 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor.debug import CommDebugMode -from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_CUDA, + TEST_HPU, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -31,7 +37,8 @@ def world_size(self): return 4 @with_comms - def test_shard_to_replicate_forward_backward(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_shard_to_replicate_forward_backward(self, dtype): # 1) test shard -> replicate forward device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) replica_spec = [Replicate()] @@ -49,7 +56,7 @@ def test_shard_to_replicate_forward_backward(self): for input_size, shard_dim in input_sizes_and_shard_dim: shard_spec = [Shard(shard_dim)] expected_tensor = torch.randn( - input_size, device=self.device_type, requires_grad=True + input_size, device=self.device_type, requires_grad=True, dtype=dtype ) dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec) with comm_mode: @@ -68,7 +75,8 @@ def test_shard_to_replicate_forward_backward(self): grad_input = dtensor.grad self.assertEqual(grad_input.placements, shard_spec) self.assertEqual( - grad_input.to_local(), torch.ones(dtensor.to_local().size()) + grad_input.to_local(), + torch.ones(dtensor.to_local().size(), dtype=dtype), ) self.assertEqual(comm_mode.get_total_counts(), 0) @@ -101,10 +109,13 @@ def test_replicate_to_replicate_forward_backward(self): self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms - def test_replicate_to_local_partial_grad(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_replicate_to_local_partial_grad(self, dtype): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) replica_spec = [Replicate()] - local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) + local_tensor = torch.randn( + 12, 3, device=self.device_type, requires_grad=True, dtype=dtype + ) replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec) @@ -168,13 +179,16 @@ def test_replicate_to_shard_forward_backward(self): ) @with_comms - def test_partial_to_replicate_forward_backward(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_partial_to_replicate_forward_backward(self, dtype): # Although we don't allow user to reshard to produce a partial # placement (i.e. user can't reshard to partial), we do allow # replicate to partial internally, and also partial to replicate # backward should work as expected device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True) + partial_local = torch.ones( + 12, 3, device=self.device_type, requires_grad=True, dtype=dtype + ) partial_spec = [Partial()] replica_spec = [Replicate()] @@ -199,7 +213,9 @@ def test_partial_to_replicate_forward_backward(self): global_partial_tensor.backward(torch.ones_like(global_partial_tensor)) self.assertIsNotNone(partial_local.grad) self.assertEqual(partial_local.grad.size(), partial_local.size()) - self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) + self.assertEqual( + partial_local.grad, torch.ones_like(partial_local, dtype=dtype) + ) self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms @@ -312,7 +328,9 @@ def test_shard_to_replicate_forward_backward_datatype_conversion(self): backward_dtype=backward_dtype, ) self.assertEqual(reshard_dtensor.size(), torch.Size(input_size)) - self.assertEqual(expected_tensor, reshard_dtensor.to_local()) + self.assertEqual( + expected_tensor.to(forward_dtype), reshard_dtensor.to_local() + ) self.assertEqual( comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1 ) @@ -378,7 +396,8 @@ def test_replicate_to_partial(self): self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms - def test_partial_to_shard(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_partial_to_shard(self, dtype): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) partial_spec = [Partial()] my_rank = device_mesh.get_rank() @@ -397,7 +416,7 @@ def test_partial_to_shard(self): for input_size, shard_dim in input_sizes_and_shard_dim: shard_spec = [Shard(shard_dim)] - partial_local = torch.ones(input_size, device=self.device_type) + partial_local = torch.ones(input_size, device=self.device_type, dtype=dtype) partial_tensor = DTensor.from_local( partial_local, device_mesh, partial_spec, run_check=False ) @@ -426,7 +445,7 @@ def test_partial_to_shard(self): self.assertEqual(scatter_shard_tensor.placements, shard_spec) self.assertEqual( scatter_shard_tensor.to_local(), - torch.ones(local_shape) * self.world_size, + torch.ones(local_shape, dtype=dtype) * self.world_size, ) self.assertEqual( comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1 @@ -469,20 +488,21 @@ def test_redistribute_uneven_sharding(self): self.assertEqual(dt_full_tensor, input_tensor) @with_comms - def test_redistribute_shard_dim_change(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_redistribute_shard_dim_change(self, dtype): # test 1d device mesh mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size)) data_to_test = [ # evenly sharded case - torch.randn((8, 8), device=self.device_type), + torch.randn((8, 8), device=self.device_type, dtype=dtype), # 3d or more dims - torch.randn((8, 8, 8), device=self.device_type), + torch.randn((8, 8, 8), device=self.device_type, dtype=dtype), # uneven case 1 - torch.randn((8, 5), device=self.device_type), + torch.randn((8, 5), device=self.device_type, dtype=dtype), # uneven case 2 - torch.randn((5, 8), device=self.device_type), + torch.randn((5, 8), device=self.device_type, dtype=dtype), # uneven case 3 - torch.randn((5, 5), device=self.device_type), + torch.randn((5, 5), device=self.device_type, dtype=dtype), ] sharding_src_dst_pairs = [([Shard(0)], [Shard(1)]), ([Shard(1)], [Shard(0)])] @@ -518,15 +538,15 @@ def test_redistribute_shard_dim_change(self): ) data_to_test_2d = [ # evenly sharded case - torch.randn((8, 8), device=self.device_type), + torch.randn((8, 8), device=self.device_type, dtype=dtype), # 3d or more dims - torch.randn((8, 8, 8), device=self.device_type), + torch.randn((8, 8, 8), device=self.device_type, dtype=dtype), # uneven case 1 - torch.randn((8, 5), device=self.device_type), + torch.randn((8, 5), device=self.device_type, dtype=dtype), # uneven case 2 - torch.randn((5, 8), device=self.device_type), + torch.randn((5, 8), device=self.device_type, dtype=dtype), # uneven case 3 - torch.randn((5, 5), device=self.device_type), + torch.randn((5, 5), device=self.device_type, dtype=dtype), ] sharding_src_dst_pairs_2d = [ ([Shard(0), Shard(1)], [Shard(0), Shard(0)]), @@ -566,10 +586,11 @@ def test_redistribute_shard_dim_change(self): self.assertEqual(local_out_dt, local_expected_dt) @with_comms - def test_shard_dim_alltoall(self): + @parametrize("dtype", [torch.float32, torch.cfloat]) + def test_shard_dim_alltoall(self, dtype): # init 2d mesh here so we can test when group_rank != global_rank mesh = init_device_mesh(self.device_type, (2, 2)) - tensor = torch.randn(12, self.world_size, device=self.device_type) + tensor = torch.randn(12, self.world_size, device=self.device_type, dtype=dtype) new_tensor = shard_dim_alltoall(tensor, 0, 1, mesh, 0) meta_tensor = torch.randn(12, self.world_size, device="meta") @@ -579,6 +600,9 @@ def test_shard_dim_alltoall(self): self.assertEqual(new_tensor.stride(), new_meta_tensor.stride()) +instantiate_parametrized_tests(RedistributeTest) + + class MultiDimRedistributeTest(DTensorTestBase): @property def world_size(self) -> int: diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 97e3317835e173..c77c97993d3338 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -6,6 +6,7 @@ DeviceMesh, distribute_tensor, DTensor, + init_device_mesh, Partial, Replicate, Shard, @@ -592,6 +593,47 @@ def test_index(self): torch.randint(5, (12, 8, 12)), ) + @with_comms + def test_index_put_scalar(self): + device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) + global_input = torch.randn(2, 4, 8, device=self.device_type) + global_index = [ + torch.randint(global_input.shape[i], size=(), device=self.device_type) + for i in range(3) + ] + global_value = torch.randn(size=(), device=self.device_type) + value_dt = distribute_tensor( + global_value, device_mesh, [Replicate(), Replicate()] + ) + placement_choice_pool = [Shard(0), Shard(1), Replicate()] + for i in placement_choice_pool: + for j in placement_choice_pool: + input_dt = distribute_tensor(global_input, device_mesh, [i, j]) + ref = torch.index_put(global_input, global_index, global_value) + output_dt = torch.index_put(input_dt, global_index, value_dt) + assert isinstance(output_dt, DTensor) + # for value is a scalar case, output placement must be replicate + self.assertEqual(output_dt.placements, (Replicate(), Replicate())) + self.assertEqual(output_dt.full_tensor(), ref) + + @with_comms + def test_index_put_tensor(self): + device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) + global_input = torch.randn(2, 4, 8, device=self.device_type) + global_index = [ + torch.randint(global_input.shape[0], size=(), device=self.device_type) + ] + global_value = torch.zeros([4, 8], device=self.device_type) + value_dt = distribute_tensor(global_value, device_mesh, [Shard(1), Replicate()]) + input_dt = distribute_tensor(global_input, device_mesh, [Shard(0), Replicate()]) + ref = torch.index_put(global_input, global_index, global_value) + output_dt = torch.index_put(input_dt, global_index, value_dt) + assert isinstance(output_dt, DTensor) + # `input_dt` follows `value_dt`'s Shard(1) plus a offset value of + # global_value.ndim-global_input.ndim, which results in Shard(2) + self.assertEqual(output_dt.placements, (Shard(2), Replicate())) + self.assertEqual(output_dt.full_tensor(), ref) + @with_comms def test_where_type_promotion(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh diff --git a/test/distributed/tensor/test_xla_integration.py b/test/distributed/tensor/test_xla_integration.py index 179b5bc796c81b..3fbfcffbd76c6f 100644 --- a/test/distributed/tensor/test_xla_integration.py +++ b/test/distributed/tensor/test_xla_integration.py @@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable: @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag. os.environ["XLA_USE_SPMD"] = "1" diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 3a096b2781fa21..efac131e6c380d 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -15,7 +15,6 @@ from itertools import product from sys import platform from typing import Optional -from unittest.mock import patch import torch import torch.distributed as dist @@ -1900,66 +1899,6 @@ def test_send_recv(self): # intentionally not calling into `destroy_process_group` as not all # user applications would explicitly that. - @patch.object(dist.ProcessGroup, "group_name", "custom") - def test_comm_split_group(self): - dist.Backend.register_backend( - "dummy", - PythonProcessGroupExtensionTest.create_dummy_ext, - extended_api=True, - devices=["cuda"], - ) - - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "6789" - - dist.init_process_group( - "dummy", - rank=self.rank, - world_size=self.world_size, - device_id=torch.device(f"cuda:{self.rank}"), - ) - - split_group_size = self.world_size // 2 - split_group_rank = self.rank % split_group_size - - all_group_ranks = [ - list(range(i * split_group_size, (i + 1) * split_group_size)) - for i in range(2) - ] - pg_opts = PythonProcessGroupExtensionTest.Options() - - registered_backend = None - - def _register_backend_side_effect(*args, **kwargs): - nonlocal registered_backend - registered_backend = args[2] - - def _get_backend_side_effect(*args, **kwargs): - return registered_backend - - # Need to patch these methods in absence of a true c10d::Backend wrapper - with patch.object( - dist.ProcessGroup, - "_register_backend", - side_effect=_register_backend_side_effect, - ), patch.object( - dist.ProcessGroup, "_get_backend", side_effect=_get_backend_side_effect - ): - split_pg = dist.split_group( - split_ranks=all_group_ranks, - group_desc="split_pg", - pg_options=pg_opts, - ) - - if split_pg is not None: - self.assertEqual( - dist.get_group_rank(split_pg, self.rank), split_group_rank - ) - self.assertEqual(dist.get_world_size(split_pg), split_group_size) - - dist.destroy_process_group(split_pg) - dist.destroy_process_group() - def test_shutdown(self) -> None: dist.Backend.register_backend( "dummy", PythonProcessGroupExtensionTest.create_dummy @@ -2295,8 +2234,8 @@ def testNodeLocalRank(self): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index ff51d18e415f56..17a7966a15848c 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -9,7 +9,7 @@ import torch.distributed as dist import torch.distributed._functional_collectives as funcol from torch._C import FileCheck -from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code +from torch._inductor.utils import fresh_cache, run_and_get_code, run_and_get_triton_code from torch.distributed._functional_collectives import ( all_gather_into_tensor_coalesced, all_gather_tensor, @@ -464,7 +464,7 @@ def test_unwaited(self) -> None: @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - @fresh_inductor_cache() + @fresh_cache() def test_threading(self): self._init_process_group() device = torch.device(f"cuda:{self.rank}") @@ -510,7 +510,7 @@ def join(self): "_scaled_mm currently only supports sm>=90", ) @skip_if_lt_x_gpu(2) - @fresh_inductor_cache() + @fresh_cache() def test_fixed_striding(self): self._init_process_group() @@ -713,6 +713,61 @@ def test_collectives(self) -> None: self.assertEqual(pg.dels, 4) +class CompileTestCPU(TestCase): + def setUp(self): + super().setUp() + + if not dist.is_initialized(): + self.rank = 0 + self.world_size = 2 + + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + + def tearDown(self): + dist.destroy_process_group() + + @fresh_cache() + def _test_inductor_all_reduce_cpu(self, cpp_wrapper=False): + def func(arg: torch.Tensor) -> torch.Tensor: + buf0 = arg + 42 + ar0 = funcol.all_reduce(buf0, "avg", "0") + ar0 = funcol.wait_tensor(ar0) + return ar0 + + arg = torch.rand(4, 4, device="cpu") + torch._inductor.config.cpp_wrapper = cpp_wrapper + compiled = torch.compile(func) + + _, (code,) = run_and_get_code(compiled, arg) + include_ops = ( + [ + "aoti_torch_cpu__c10d_functional_all_reduce_", + "aoti_torch_cpu__c10d_functional_wait_tensor", + ] + if cpp_wrapper + else [ + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.wait_tensor.default", + ] + ) + for op in include_ops: + self.assertIn(op, code) + + # Test aoti + AOTIRunnerUtil.run(func, (arg,)) + torch.cpu.synchronize() + + def test_inductor_all_reduce_cpu(self): + self._test_inductor_all_reduce_cpu(cpp_wrapper=False) + self._test_inductor_all_reduce_cpu(cpp_wrapper=True) + + class CompileTest(TestCase): def setUp(self): super().setUp() @@ -736,7 +791,7 @@ def tearDown(self): dist.destroy_process_group() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_reduce_single(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = arg + 42 @@ -773,7 +828,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_reduce_coalesced(self): def func(args: list[torch.Tensor]) -> torch.Tensor: bufs = [arg + 42 for arg in args] @@ -796,13 +851,11 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: .check("buf6 = empty") # Expect in-place with inductor allocated buf .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf0, buf1]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]" ) # Expect no in-place with graph input (buf5, buf6 are clones) .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf5, buf6]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]" ) .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("torch.ops._c10d_functional.wait_tensor.default(buf1") @@ -819,7 +872,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_inplace_op_on_view(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = (arg + 10)[:2] @@ -843,7 +896,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_reduce_non_contig_input(self): def func(arg: torch.Tensor) -> torch.Tensor: ar0 = funcol.all_reduce(arg, "avg", "0") @@ -869,7 +922,7 @@ def func2(arg: torch.Tensor) -> torch.Tensor: assert "torch.ops._c10d_functional.wait_tensor.default" in code @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_reuse_buffer_after_inplace_collective(self): def func(arg: torch.Tensor) -> torch.Tensor: # Expect allocation @@ -904,7 +957,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: assert "= torch.ops._c10d_functional.wait_tensor.default" not in code @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_gather_into_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: ag0 = funcol.all_gather_tensor(arg, 0, "0") @@ -931,7 +984,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_gather_into_tensor_coalesced(self): def func(args: list[torch.Tensor]) -> torch.Tensor: ag0 = funcol.all_gather_into_tensor_coalesced(args, "0") @@ -965,7 +1018,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "This is a GPU test!") - @fresh_inductor_cache() + @fresh_cache() def test_wait_tensor(self): def func(arg: torch.Tensor) -> torch.Tensor: t = torch.ops._c10d_functional.all_reduce(arg, "avg", "0") @@ -987,7 +1040,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_reduce_scatter_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0") @@ -1013,7 +1066,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_reduce_scatter_tensor_coalesced(self): def func(args: list[torch.Tensor]) -> torch.Tensor: rs0 = funcol.reduce_scatter_tensor_coalesced( @@ -1049,7 +1102,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_all_to_all_single(self): def _tolist_with_constrain_as_size(tensor): lst = tensor.tolist() @@ -1097,7 +1150,7 @@ def func( ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_inductor_broadcast(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = arg + 42 @@ -1134,7 +1187,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: torch.cuda.synchronize() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @fresh_inductor_cache() + @fresh_cache() def test_ranks_and_tag(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = arg + 42 diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 04805a2e1ddf4a..96ad01b95b18c2 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -1,13 +1,16 @@ # Owner(s): ["oncall: distributed"] import copy +import json import logging import math import operator import os +import pickle import random import sys import tempfile +import time from datetime import timedelta from functools import reduce from itertools import groupby @@ -54,6 +57,7 @@ retry_on_connect_failures, run_tests, skip_but_pass_in_sandcastle, + skipIfRocm, skipIfRocmArch, TestCase, ) @@ -242,11 +246,12 @@ def setUp(self): super().setUp() self._spawn_processes() - def opts(self, threads=2): + def opts(self, threads=2, group_name="0"): opts = c10d.ProcessGroupGloo._Options() opts._timeout = 50.0 opts._devices = [create_device(interface=LOOPBACK, lazy_init=self.lazy_init)] opts._threads = threads + opts.group_name = group_name return opts @requires_gloo() @@ -408,6 +413,7 @@ def test_broadcast_stress(self): @skip_if_lt_x_gpu(2) @requires_gloo() + @skipIfRocm def test_broadcast_stress_cuda(self): inputs = [ torch.tensor([i * self.world_size + self.rank]).cuda() for i in range(1000) @@ -476,6 +482,30 @@ def _test_allreduce_basics(self, fn): result[0], ) + # Test fp16 numerical correctness for all-reduce SUM. + torch.manual_seed(self.rank) + # TODO: when create larger sizes of tensors, numerical instability will be observed. + # We need to investigate the root cause and ensure it is fixed. + tensor = ( + (torch.rand(200, 1, dtype=torch.float32) * 2 - 1) * 65504 / self.world_size + ) + opts = c10d.AllreduceOptions() + tensor = tensor.to(torch.float16) + output = [[torch.zeros_like(tensor) for _ in range(self.world_size)]] + # allgather all local tensors first and then sum up. + fut = pg.allgather(output, [tensor]).get_future() + fut.wait() + ag_result = fut.value() + total = torch.stack(ag_result, dim=0).sum(dim=0) + + # result from fp16 all-reduce. + fut = pg.allreduce([tensor], opts).get_future() + fut.wait() + result_fp16 = fut.value() + # float16 has only ~11 bits of mantissa, and is sensitive to accumulation + # order and rounding errors so we use a larger tolerance. + self.assertEqual(total, result_fp16[0], rtol=1e-2, atol=1e-3) + @requires_gloo() def test_allreduce_basics(self): self._test_allreduce_basics(lambda t: t.clone()) @@ -513,6 +543,7 @@ def test_allreduce_stress(self): @skip_if_lt_x_gpu(2) @requires_gloo() + @skipIfRocm def test_allreduce_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_allreduce_stress(inputs) @@ -966,6 +997,7 @@ def test_scatter_stress(self): ) @skip_if_lt_x_gpu(2) @requires_gloo() + @skipIfRocm def test_scatter_stress_cuda(self): inputs = [ [torch.tensor([i + self.rank]) for _ in range(self.world_size)] @@ -1276,6 +1308,7 @@ def test_allgather_stress(self): @skip_if_lt_x_gpu(2) @requires_gloo() + @skipIfRocm def test_allgather_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_allgather_stress(inputs, lambda t: t.clone().cuda()) @@ -1462,6 +1495,7 @@ def test_reduce_stress(self): @skip_if_lt_x_gpu(2) @requires_gloo() + @skipIfRocm def test_reduce_stress_cuda(self): inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)] self._test_reduce_stress(inputs) @@ -2361,6 +2395,122 @@ def tearDown(self) -> None: return super().tearDown() +class ProcessGroupGlooFRTest(ProcessGroupGlooTest): + def setUp(self): + os.environ["TORCH_FR_BUFFER_SIZE"] = "10" + super().setUp() + + def tearDown(self) -> None: + del os.environ["TORCH_FR_BUFFER_SIZE"] + return super().tearDown() + + def _verify_trace(self, t, is_json): + ver = t["version"] + self.assertEqual(ver, "2.9") + pg_config = t["pg_config"] + self.assertEqual(len(pg_config), 1) + default_pg_info = pg_config["0"] + self.assertIn("name", default_pg_info) + self.assertIn("desc", default_pg_info) + self.assertIn("ranks", default_pg_info) + pg_status = t["pg_status"] + self.assertEqual(len(pg_status), 1) + self.assertEqual(str(pg_status["0"]["last_enqueued_collective"]), "3") + self.assertEqual(str(pg_status["0"]["last_completed_collective"]), "3") + self.assertEqual( + str(pg_status["0"]["last_started_collective"]), + "-1", + ) + global_ranks = pg_config["0"]["ranks"] + self.assertEqual(len(json.loads(global_ranks)), self.world_size) + self.assertEqual(len(t["entries"]), 3) + t = t["entries"] + last = t[-1] + self.assertEqual(last["process_group"], ("0", "")) + # No event recorded for Gloo. + self.assertEqual(last["state"], "scheduled") + # we don't collect stack traces in JSON at the moment + if not is_json: + self.assertIn("test_c10d_gloo.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["collective_seq_id"], 3) + # TODO: Needs verification + self.assertEqual(last["timeout_ms"], 50000) + self.assertTrue("duration_ms" not in last) + + @requires_gloo() + def test_short_json(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = self._create_process_group_gloo( + store, self.rank, self.world_size, self.opts(group_name="0") + ) + a = torch.full((3, 4), float(self.rank)) + for _ in range(2): + f = pg.allreduce(a) + f.wait() + time.sleep(1) + t = json.loads( + torch._C._distributed_c10d._dump_fr_trace_json(includeCollectives=True) + ) + self._verify_trace(t, True) + + @requires_gloo() + def test_short_pickle(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = self._create_process_group_gloo( + store, self.rank, self.world_size, self.opts(group_name="0") + ) + a = torch.full((3, 4), float(self.rank)) + for _ in range(2): + f = pg.allreduce(a) + f.wait() + time.sleep(1) + t = pickle.loads( + torch._C._distributed_c10d._dump_fr_trace(includeCollectives=True) + ) + self._verify_trace( + t, + is_json=False, + ) + + @requires_gloo() + def test_long(self): + store = c10d.FileStore(self.file_name, self.world_size) + pg = self._create_process_group_gloo( + store, self.rank, self.world_size, self.opts(group_name="0") + ) + a = torch.full((3, 4), float(self.rank)) + for _ in range(2): + # test some other primitives to make sure + # their strings are valid + xs = [torch.ones(3, 4)] + pg.broadcast(xs).wait() + pg.allreduce(xs).wait() + pg.reduce(xs).wait() + ys = [[torch.empty(3, 4) for _ in range(self.world_size)]] + pg.allgather(ys, xs).wait() + pg.reduce_scatter(xs, ys).wait() + f = pg.allreduce(a) + f.wait() + t = pickle.loads(torch._C._distributed_c10d._dump_fr_trace()) + t = t["entries"] + self.assertEqual(len(t), 10) + first = t[0] + last = t[-1] + self.assertEqual(last["profiling_name"], "gloo:all_reduce") + self.assertEqual(last["state"], "scheduled") + self.assertIn("test_c10d_gloo.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["timeout_ms"], 50000) + self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) + + class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): @property def device(self): @@ -2555,8 +2705,8 @@ def test_new_group_local_sync_duplicate_pg(self): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_logger.py b/test/distributed/test_c10d_logger.py index de72646405af58..bbbcd2c751a6ae 100644 --- a/test/distributed/test_c10d_logger.py +++ b/test/distributed/test_c10d_logger.py @@ -2,7 +2,6 @@ import json import logging -import os import re import sys from functools import partial, wraps @@ -16,10 +15,13 @@ print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) -from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS +from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS +from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN +device_type = str(get_devtype()) + if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", @@ -27,8 +29,7 @@ ) sys.exit(0) -BACKEND = dist.Backend.NCCL -WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) +WORLD_SIZE = min(4, max(2, torch.get_device_module(device_type).device_count())) def with_comms(func=None): @@ -39,30 +40,16 @@ def with_comms(func=None): @wraps(func) def wrapper(self, *args, **kwargs): - if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: + if torch.get_device_module(device_type).device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) - self.dist_init() + self.create_pg(device_type) func(self) self.destroy_comms() return wrapper -class C10dErrorLoggerTest(MultiProcessTestCase): - def setUp(self): - super().setUp() - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["BACKEND"] = BACKEND - self._spawn_processes() - - @property - def device(self): - return ( - torch.device(self.rank) - if BACKEND == dist.Backend.NCCL - else torch.device("cpu") - ) - +class C10dErrorLoggerTest(DistributedTestBase): @property def world_size(self): return WORLD_SIZE @@ -76,18 +63,6 @@ def destroy_comms(self): dist.barrier() dist.destroy_process_group() - def dist_init(self): - dist.init_process_group( - backend=BACKEND, - world_size=self.world_size, - rank=self.rank, - init_method=f"file://{self.file_name}", - ) - - # set device for nccl pg for collectives - if BACKEND == "nccl": - torch.cuda.set_device(self.rank) - def test_get_or_create_logger(self): self.assertIsNotNone(_c10d_logger) self.assertEqual(1, len(_c10d_logger.handlers)) @@ -117,7 +92,11 @@ def test_exception_logger(self) -> None: re.search("({.+})", captured.output[0]).group(0).replace("'", '"') ) - self.assertEqual(len(error_msg_dict), 9) + # NCCL adds additional nccl_version data to the error_msg_dict + if self.backend(device_type) == dist.Backend.NCCL: + self.assertEqual(len(error_msg_dict), 9) + else: + self.assertEqual(len(error_msg_dict), 8) self.assertIn("pg_name", error_msg_dict.keys()) self.assertEqual("None", error_msg_dict["pg_name"]) @@ -126,13 +105,14 @@ def test_exception_logger(self) -> None: self.assertEqual("broadcast", error_msg_dict["func_name"]) self.assertIn("backend", error_msg_dict.keys()) - self.assertEqual("nccl", error_msg_dict["backend"]) - - self.assertIn("nccl_version", error_msg_dict.keys()) - nccl_ver = torch.cuda.nccl.version() - self.assertEqual( - ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"] - ) + self.assertEqual(self.backend(device_type), error_msg_dict["backend"]) + + if self.backend(device_type) == dist.Backend.NCCL: + self.assertIn("nccl_version", error_msg_dict.keys()) + nccl_ver = torch.cuda.nccl.version() + self.assertEqual( + ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"] + ) # In this test case, group_size = world_size, since we don't have multiple processes on one node. self.assertIn("group_size", error_msg_dict.keys()) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index f1691eeb877b80..c02e968e23fb62 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -39,7 +39,7 @@ from torch import nn from torch._C._distributed_c10d import ErrorType, OpType, WorkResult from torch.nn.parallel import DistributedDataParallel -from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( get_timeout, init_multigpu_helper, @@ -285,10 +285,9 @@ def opts(self, high_priority_stream=False): def setUp(self): super().setUp() - # These tests are expected to throw SIGABRT(6); adding the negative sign - # bc the test return code is actually -6 + # These tests are expected to throw SIGABRT(6); # But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0. - TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else -signal.SIGABRT + TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else signal.SIGABRT self.special_return_code_checks = { self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN, self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN, @@ -455,7 +454,6 @@ def test_cuda_event_cache_mthd_race(self): # This unit test is to test the case when the collective is launched in # a side thread and the thread dies before the cache has been fully recycled. # More details can be found in this issue: https://github.com/pytorch/pytorch/issues/143470. - import threading # initiate collectives here def init_collective_task(t): @@ -478,6 +476,9 @@ def init_collective_task(t): side_thread.join() torch.cuda.synchronize() + # reset ENV + os.environ["TORCH_NCCL_CUDA_EVENT_CACHE"] = "0" + CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 ) @@ -485,7 +486,7 @@ def init_collective_task(t): @requires_nccl() @skip_but_pass_in_sandcastle_if( # skip for cu126 as well due to https://github.com/pytorch/pytorch/issues/153479 - not (TEST_MULTIGPU and CUDA_12_AND_ABOVE and False), + not (TEST_MULTIGPU and CUDA_12_AND_ABOVE), "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA", ) @parametrize( @@ -539,10 +540,15 @@ def test_nan_assert(self, type): backend._set_enable_nan_check(False) # Note: using all-gather here bc some NCCL/SM version does not support # FP8 reduction - pg._allgather_base(output, nan_tensor) + # temporarily skip due to https://github.com/pytorch/pytorch/issues/153479 + # pg._allgather_base(output, nan_tensor) backend._set_enable_nan_check(True) - pg._allgather_base(output, nan_tensor) + try: + pg._allgather_base(output, nan_tensor) + except Exception: + sys.exit(signal.SIGABRT) + dist.destroy_process_group() # reset env @@ -1102,7 +1108,6 @@ def test_non_blocking_init(self): def test_non_blocking_with_eager_init(self): # Test creating a pg eagerly with nonblocking mode when # we've passed a specific device_id to init_process_group. - raise SkipTest("Skip due to https://github.com/pytorch/pytorch/issues/153517") os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100" store = c10d.FileStore(self.file_name, self.world_size) @@ -1189,6 +1194,29 @@ def test_set_process_group_desc(self): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_deterministic_mode_no_break(self): + torch.use_deterministic_algorithms(True) + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + self._create_process_group_nccl(store, self.opts(), device_id=device) + tensor = torch.empty(10, 10, device=device) + dist.all_reduce(tensor) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_init_with_idx(self): + store = c10d.FileStore(self.file_name, self.world_size) + device_idx = self.rank + dist.init_process_group( + world_size=self.world_size, + rank=self.rank, + store=store, + device_id=device_idx, + ) + dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx))) + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase @@ -2841,9 +2869,9 @@ def test_nccl_errors_nonblocking(self): self.assertTrue(t.is_alive()) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @skip_if_lt_x_gpu(3) @@ -2903,9 +2931,9 @@ def test_nccl_non_blocking_wait_with_barrier(self): os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" self._test_barrier_error() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -2956,9 +2984,9 @@ def assert_fut_success(fut): process_group.abort() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -3037,9 +3065,9 @@ def test_restart_pg_after_error(self): os.remove(new_file_name) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) def _run_invalid_nccl_blocking_wait_env(self, val): os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val @@ -3056,7 +3084,7 @@ def test_invalid_nccl_blocking_wait_env(self): self._run_invalid_nccl_blocking_wait_env("4294967295") -class NcclUserBufferRegistrationTest(MultiProcessTestCase): +class NcclRegistrationTest(MultiProcessTestCase): def setUp(self): super().setUp() # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests @@ -3067,7 +3095,7 @@ def setUp(self): os.environ["NCCL_DEBUG"] = "INFO" os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS" if torch.cuda.nccl.version() >= (2, 24, 3): - os.environ["NCCL_DEBUG_SUBSYS"] = "REG" + os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING" os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name self._spawn_processes() @@ -3123,6 +3151,48 @@ def test_nccl_user_buffer_registration(self): else: self.assertRegex(nccl_debug_file_content, "local-registered") + @requires_nccl() + @requires_nccl_version((2, 27), "Need NCCL 2.27 for window registration") + @skip_if_lt_x_gpu(4) + @requires_multicast_support() + def test_nccl_window_registration(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + device = torch.device(f"cuda:{self.rank}") + torch.cuda.set_device(self.rank) + pg = c10d.distributed_c10d._get_default_group() + backend = pg._get_backend(torch.device(device)) + + # Use NCCL memory allocator + # enable symmetric memory usage in NCCL + pool = torch.cuda.MemPool(backend.mem_allocator, symm_mem=True) + + # allocate memory with ncclMemAlloc + # note: symmetric kernels are not available for dtypes like torch.int64 + with torch.cuda.use_mem_pool(pool): + tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32) + + # register buffers to NCCL + backend.register_mem_pool(pool) + + # allreduce now should use NVIDIA Switches + pg.allreduce(tensor).wait() + torch.cuda.synchronize(device=device) + + # de-register buffers from NCCL + backend.deregister_mem_pool(pool) + + # clean up memory + del tensor, pool + + with open(os.environ["NCCL_DEBUG_FILE"]) as f: + nccl_debug_file_content = f.read() + # if buffers were registered and symmetric kernels ran, NCCL_DEBUG + # should show successful registration in debug output + self.assertRegex(nccl_debug_file_content, "[Symmetric]") + class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): @property @@ -3290,9 +3360,7 @@ def test_intra_node_comm_all_reduce(self): self.assertEqual(_get_intra_node_comm_usage_counter(), 3) # Verify that IntraNodeComm is not used beyond 10MB - t = torch.full( - (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16 - ).cuda() + t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda() c10d.all_reduce(t, c10d.ReduceOp.SUM) self.assertTrue(t.eq(expect).all()) self.assertEqual(_get_intra_node_comm_usage_counter(), 3) @@ -3409,6 +3477,21 @@ def test_pass_nccl_options_config(self): self.assertEqual(pg_opts.config.net_name, net_name.decode()) self.assertEqual(pg_opts.config.split_share, int(split_share)) + # Tests that config is inited correctly + pg_opts = c10d.ProcessGroupNCCL.Options() + nccl_cfg = c10d.ProcessGroupNCCL.NCCLConfig() + self.assertEqual(pg_opts.config.min_ctas, -2147483648) + self.assertEqual(nccl_cfg.min_ctas, -2147483648) + + # Tests that opts and config can be copied + pg_opts_2 = copy.deepcopy(pg_opts) + nccl_cfg_2 = copy.copy(pg_opts_2.config) + pg_opts_2.config.min_ctas = 2 + nccl_cfg_2.min_ctas = 4 + self.assertEqual(pg_opts.config.min_ctas, -2147483648) + self.assertEqual(pg_opts_2.config.min_ctas, 2) + self.assertEqual(nccl_cfg_2.min_ctas, 4) + @requires_nccl() @skip_if_lt_x_gpu(4) def test_nccl_barrier(self): @@ -4164,9 +4247,9 @@ def test_ddp_set_sparse_metadata(self): class NCCLTraceTestBase(MultiProcessTestCase): def setUp(self): super().setUp() - os.environ[ - "TORCH_NCCL_ENABLE_TIMING" - ] = "0" # see 'timing_enabled' parametrized tests + os.environ["TORCH_NCCL_ENABLE_TIMING"] = ( + "0" # see 'timing_enabled' parametrized tests + ) os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000" os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1" self.tempdir = tempfile.TemporaryDirectory() @@ -4252,7 +4335,7 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] - self.assertEqual(ver, "2.7") + self.assertEqual(ver, "2.9") nccl_version = t["nccl_version"] torch_nccl_version = torch.cuda.nccl.version() self.assertEqual(nccl_version, ".".join(str(v) for v in torch_nccl_version)) @@ -4276,6 +4359,8 @@ def _verify_trace(self, t, include_collectives, timing_enabled, is_json): self.assertEqual(len(t["entries"]), 2) t = t["entries"] last = t[-1] + self.assertEqual(last["thread_id"], str(threading.current_thread().ident)) + self.assertEqual(last["thread_name"], "fr_test_thread") self.assertEqual(last["process_group"], ("0", "default_pg")) self.assertEqual(last["state"], "completed") s = last["time_discovered_started_ns"] @@ -4308,6 +4393,35 @@ def _verify_trace(self, t, include_collectives, timing_enabled, is_json): else: self.assertTrue("entries" not in t) + def load_libpthread_or_libc(self): + import ctypes.util + + for base in ("pthread", "c"): + path = ctypes.util.find_library(base) + if path: + try: + return ctypes.CDLL(path) + except OSError: + continue + raise RuntimeError("Could not load pthread or libc") + + # Directly set thread name using threading.current_thread().name does not work + # because we use pthread_getname_np to get the thread’s OS-level name in C++ + def set_thread_name(self, name): + import ctypes + + lib = self.load_libpthread_or_libc() + pthread_self = lib.pthread_self + pthread_self.restype = ctypes.c_void_p + pthread_setname_np = lib.pthread_setname_np + pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + + # Get current pthread handle + tid = pthread_self() + + # Set name + pthread_setname_np(tid, name.encode()) + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) @@ -4319,6 +4433,7 @@ def test_short_json(self, timing_enabled, include_collectives): if timing_enabled: pg._enable_collectives_timing() device = self.local_device + self.set_thread_name("fr_test_thread") a = torch.full((3, 4), float(self.rank), device=device) for _ in range(2): f = pg.allreduce(a) @@ -4345,6 +4460,7 @@ def test_short_pickle(self, timing_enabled, include_collectives): if timing_enabled: pg._enable_collectives_timing() device = self.local_device + self.set_thread_name("fr_test_thread") a = torch.full((3, 4), float(self.rank), device=device) for _ in range(2): f = pg.allreduce(a) @@ -4526,9 +4642,20 @@ def test_trace_while_active(self, timing_enabled, only_active): else: self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") self.assertEqual(t[-1]["collective_seq_id"], 2) - self.assertEqual( - t[-1]["state"], self.started_or_scheduled(timing_enabled) - ) + + # ROCm runtime used to call uSleep(20 µs)inside the default‑signal busy-wait loop. + # Now, this sleep is removed which lets the host thread spin continuously + # Therefore, the state can either be scheduled or started before test dumps the trace. + if ( + torch.version.hip + and _get_torch_rocm_version() >= (6, 4) + and timing_enabled + ): + assert t[-1]["state"] in ("scheduled", "started") + else: + self.assertEqual( + t[-1]["state"], self.started_or_scheduled(timing_enabled) + ) self.parent.send("next") self.assertEqual("next", self.parent.recv()) @@ -5202,8 +5329,8 @@ def test_comm_recursive_split_group(self): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index 01f7d8adb55190..adddc7f71afda3 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -8,7 +8,7 @@ import torch.distributed as c10d import torch.multiprocessing as mp from torch.testing._internal.common_distributed import MultiProcessTestCase -from torch.testing._internal.common_utils import load_tests +from torch.testing._internal.common_utils import load_tests, run_tests # Torch distributed.nn is not available in windows @@ -246,3 +246,7 @@ def _test_all_to_all_single(self, backend): z.backward() x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos() self.assertEqual(x.grad, x_s) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_c10d_ucc.py b/test/distributed/test_c10d_ucc.py index e63c5f81924e84..e3a4764d594f2a 100644 --- a/test/distributed/test_c10d_ucc.py +++ b/test/distributed/test_c10d_ucc.py @@ -42,7 +42,6 @@ run_tests, skip_but_pass_in_sandcastle, TestCase, - xfailIfLinux, ) @@ -674,7 +673,6 @@ def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model): vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce() ) - @xfailIfLinux @requires_ucc() @skip_if_lt_x_gpu(2) def test_save_load_checkpoint(self): @@ -1092,8 +1090,8 @@ def test_allgather_base(self): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py index ee93d56efb8fee..a150a55f77be60 100644 --- a/test/distributed/test_collective_utils.py +++ b/test/distributed/test_collective_utils.py @@ -5,6 +5,7 @@ import torch.distributed as c10d from torch.distributed.collective_utils import all_gather, broadcast from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import run_tests class TestCollectiveUtils(MultiProcessTestCase): @@ -89,9 +90,9 @@ def test_all_gather_result(self) -> None: res = all_gather(data_or_fn=func, pg=pg) func.assert_called_once() - assert res == list( - range(self.world_size) - ), f"Expect res to be list of 0 through {self.world_size} (got {res})" + assert res == list(range(self.world_size)), ( + f"Expect res to be list of 0 through {self.world_size} (got {res})" + ) def test_all_gather_result_no_pg(self) -> None: """ @@ -114,3 +115,7 @@ def test_all_gather_result_raises_exceptions_from_func( expected_exception = "test exception" with self.assertRaisesRegex(Exception, expected_exception): all_gather(data_or_fn=func) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index d7a84876dc9730..56611c6891e53e 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -26,12 +26,16 @@ _dynamo_dist_per_rank_init, at_least_x_gpu, DynamoDistributedMultiProcTestCase, - requires_nccl, + requires_accelerator_dist_backend, ) +from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.inductor_utils import HAS_GPU +device_type = str(get_devtype()) + + def get_snode_runtime_for_reorder_compute_test(snode): # NOTE: custom cost model to show that the compute reordering algorithm is working # Collective kernels @@ -74,7 +78,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes): return new_snode_order -@requires_nccl() +@requires_accelerator_dist_backend() class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase): """ Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under @@ -113,9 +117,12 @@ def func(a): return torch.matmul(ar, b) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs) # Verify that the wait_tensor is sinked below the 1st matmul but @@ -154,9 +161,12 @@ def func(a): return torch.matmul(d, e) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs) # Verify that the all_reduce_ has been raised above the 2nd matmul @@ -202,9 +212,12 @@ def func(a, *, tag, ranks, group_size): return torch.mm(e, g) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # Things to verify: @@ -255,9 +268,12 @@ def func(a, *, tag, ranks, group_size): return (e,) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # NOTE: after scheduling the first all_reduce: @@ -312,9 +328,12 @@ def func(a, *, tag, ranks, group_size): return (e,) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # NOTE: after scheduling the first all_reduce: @@ -362,9 +381,12 @@ def func(a, *, tag, ranks, group_size): return (mm,) with _dynamo_dist_per_rank_init( - self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), ): - inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank + inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # Expectations: @@ -387,9 +409,9 @@ def test_inductor_default_comms_ordering(self): ranks = pg_info["ranks"] group_size = pg_info["group_size"] - g1 = torch.ones(10, 10, device="cuda") - g2 = torch.ones(11, 11, device="cuda") - g3 = torch.ones(12, 12, device="cuda") + g1 = torch.ones(10, 10, device=device_type) + g2 = torch.ones(11, 11, device=device_type) + g3 = torch.ones(12, 12, device=device_type) def assert_pass(graph): # all_reduces need to remain in order! @@ -429,7 +451,9 @@ def fn(g1, g2, g3): grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1) return grad3, grad2, grad1 - with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True): + with _dynamo_dist_per_rank_init( + self.rank, self.world_size, self.backend(device_type), fake_pg=True + ): fn(g1, g2, g3) def test_nccl_heuristics(self): diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py index 594c028ae9d47c..8e48735c777994 100644 --- a/test/distributed/test_control_collectives.py +++ b/test/distributed/test_control_collectives.py @@ -207,8 +207,8 @@ def f(rank: int) -> None: if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 7ad4c33de431e3..06502943934f0e 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -490,8 +490,9 @@ def test_from_group_with_mesh_shape_2d(self): # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7)) # and assign the correct shard group to each rank - shard_rank_lists = list(range(0, self.world_size // 2)), list( - range(self.world_size // 2, self.world_size) + shard_rank_lists = ( + list(range(0, self.world_size // 2)), + list(range(self.world_size // 2, self.world_size)), ) shard_groups = ( new_group(shard_rank_lists[0]), diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 7d43bf730b9eb2..73ac6eb0da7bd9 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -25,6 +25,7 @@ from torch._dynamo.testing import collect_results from torch._dynamo.utils import same from torch._higher_order_ops.wrap import tag_activation_checkpoint +from torch.compiler import set_enable_guard_collectives from torch.distributed._functional_collectives import _maybe_wrap_tensor from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import ( @@ -61,6 +62,15 @@ def init_weights(m): m.bias.data.fill_(0.01) +@contextmanager +def enable_guard_collectives(): + old = set_enable_guard_collectives(True) + try: + yield + finally: + set_enable_guard_collectives(old) + + class ToyModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None): super().__init__() @@ -668,6 +678,50 @@ def test_fsdp_aot_eager(self): outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) + @skip_if_lt_x_gpu(2) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_ddp_optimizer_cudagraph(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + # need a large channel to trigger ddp optimizer split module + self.CHANNELS = 640 + self.convi = nn.Conv2d(46, self.CHANNELS, 3, padding=1, bias=False) + self.convp = nn.Conv2d( + self.CHANNELS, self.CHANNELS, 1, padding=0, bias=False + ) + self.bni = nn.BatchNorm2d(self.CHANNELS) + + def forward(self, bitmap_channels): + x = self.convi(bitmap_channels) + x = self.bni(x) + x = self.convp(x) + return x + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + net = Net().to(self.rank) + optimizer = torch.optim.SGD( + net.parameters(), + lr=5e-2, + ) + + net = DDP(net, device_ids=[self.rank]) + opt_net = torch.compile(net, mode="reduce-overhead") + opt_net.train() + + for _ in range(10): + optimizer.zero_grad() + data = torch.randn((16, 46, 8, 8), dtype=torch.float32, device="cuda") + opt_net(data).sum().backward() + + # 2 fwd and 2 bwd graph such that 4 graphs in total + graph_id = ( + torch._inductor.cudagraph_trees.get_container(self.rank) + .tree_manager.new_graph_id() + .id + ) + self.assertTrue(graph_id == 4) + @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) def test_fsdp_setattr(self): @@ -1097,6 +1151,31 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @enable_guard_collectives() + def test_guard_collective(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(x): + return x.sum() + + x = torch.randn(10, device=self.rank) + f(x) + + if self.rank == 0: + x = torch.randn(10, device=self.rank) + else: + x = torch.randn(12, device=self.rank) # recompile on one rank + f(x) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_get_pg_attr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): @@ -1174,11 +1253,9 @@ def f(x): @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10) def test_asymmetric_compilation_with_fx_cache(self): from torch._dynamo.utils import counters - from torch._inductor.utils import fresh_inductor_cache + from torch._inductor.utils import fresh_cache - with fresh_inductor_cache(), _dynamo_dist_per_rank_init( - self.rank, self.world_size - ): + with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() device = f"cuda:{self.rank}" @@ -1208,7 +1285,7 @@ def f(x): torch._dynamo.reset() if self.rank == 0: - with fresh_inductor_cache(): + with fresh_cache(): f(x) self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -1788,9 +1865,7 @@ def _(ctx): f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" - ).run( - GUARDS_FILE.getvalue() - ) + ).run(GUARDS_FILE.getvalue()) self.assertTrue(same(correct_outputs, outputs)) diff --git a/test/distributed/test_fake_pg.py b/test/distributed/test_fake_pg.py index 8f2786bdc8b1c8..bc65fab2c67f55 100644 --- a/test/distributed/test_fake_pg.py +++ b/test/distributed/test_fake_pg.py @@ -17,7 +17,9 @@ ) from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_distributed import HAS_ACCELERATOR +from torch.testing._internal.common_fsdp import get_devtype +from torch.testing._internal.common_utils import run_tests, skipIfHpu, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore @@ -26,13 +28,16 @@ print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) -HAS_CUDA = torch.cuda.is_available() +device_type = get_devtype().type class TestFakePG(TestCase): def tearDown(self): super().tearDown() - dist.destroy_process_group() + try: + dist.destroy_process_group() + except AssertionError: + pass def test_all_reduce(self): store = FakeStore() @@ -62,20 +67,21 @@ def test_reduce_scatter(self): dist.reduce_scatter(output_tensor, to_reduce_scatter) self.assertEqual(tuple(output_tensor.shape), (3, 3)) - @unittest.skipIf(not HAS_CUDA, "No CUDA") + @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator") def test_construct_fsdp(self): store = FakeStore() dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) - FSDP(nn.Linear(2, 3, device="cuda")) + FSDP(nn.Linear(2, 3, device=device_type)) - @unittest.skipIf(not HAS_CUDA, "No CUDA") + @skipIfHpu + @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator") def test_fsdp_fake_e2e(self): store = dist.HashStore() dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) my_module = nn.Sequential( - nn.Linear(2, 3, device="cuda"), + nn.Linear(2, 3, device=device_type), nn.ReLU(), - nn.Linear(3, 2, device="cuda"), + nn.Linear(3, 2, device=device_type), ) sharded_module = FSDP(my_module, use_orig_params=True) optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) @@ -85,7 +91,8 @@ def test_fsdp_fake_e2e(self): loss.backward() optim.step() - @unittest.skipIf(not HAS_CUDA, "No CUDA") + @skipIfHpu + @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator") def test_fake_pg_tracing(self): store = dist.HashStore() dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) @@ -95,7 +102,7 @@ def test_fake_pg_tracing(self): def allgather_fn(tensor): return funcol.all_gather_tensor(tensor, 0, default_pg) - gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda")) + gm = make_fx(allgather_fn)(torch.randn(2, 2, device=device_type)) FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph)) def test_broadcast(self): @@ -165,7 +172,8 @@ def test_recv(self): dist.recv(output, 1) self.assertEqual(tuple(output.shape), (3, 3)) - @unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP") + @skipIfHpu + @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator") def test_fsdp_tp_fake_e2e(self): world_size = 4 tp_size = 2 @@ -175,9 +183,11 @@ def test_fsdp_tp_fake_e2e(self): backend="fake", rank=0, world_size=world_size, store=store ) - device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size)) + device_mesh = DeviceMesh( + device_type, torch.arange(0, world_size).view(-1, tp_size) + ) device_mesh = init_device_mesh( - "cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"] + device_type, (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"] ) sequence_parallelize_plan = { @@ -190,7 +200,7 @@ def test_fsdp_tp_fake_e2e(self): } for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]: my_module = parallelize_module( - MLPModule(device="cuda"), + MLPModule(device=device_type), device_mesh["tp"], parallel_plan, ) @@ -203,7 +213,7 @@ def test_fsdp_tp_fake_e2e(self): for i in range(10): dp_rank = dist.get_rank() torch.manual_seed(i + dp_rank) - input = torch.randn(20, 10).cuda(dist.get_rank()) + input = torch.randn(20, 10, device=f"{device_type}:{dp_rank}") x = sharded_module(input) loss = x.sum() loss.backward() diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 5e6510f2c22f9e..3b93e4d2b19ad6 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -715,6 +715,13 @@ def run_with_backward(): _, codes = run_and_get_code(run_with_backward) for code in codes: + assert_keywords = ["assert_size_stride", "assert_alignment"] + filtered_lines = [ + line + for line in code.splitlines() + if not any(assert_key in line for assert_key in assert_keywords) + ] + code = "\n".join(filtered_lines) FileCheck().check_count( "_c10d_functional.all_to_all_single.default", 1, exactly=True ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run( diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 6c33a6031d2870..47550a43b722bb 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -25,10 +25,12 @@ from torch._inductor.utils import run_and_get_triton_code from torch.distributed.distributed_c10d import GroupMember from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_distributed import ( _dynamo_dist_per_rank_init, DynamoDistributedMultiProcTestCase, DynamoDistributedSingleProcTestCase, + MultiProcessTestCase, requires_nccl, skip_if_lt_x_gpu, ) @@ -519,12 +521,13 @@ def example( out = a2a / a2a.sum(dim=0) return out - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), ): row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2 input_split_sizes_tensor = torch.tensor( @@ -680,15 +683,15 @@ def example( return torch.ops.custom_ns.foo(a2a) - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, - ), torch.library._scoped_library( - "custom_ns", "FRAGMENT" - ) as lib: + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), + torch.library._scoped_library("custom_ns", "FRAGMENT") as lib, + ): lib.define( "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor" # noqa: B950 ) @@ -1502,6 +1505,319 @@ def _reorder_communication_preserving_peak_memory( self.assertEqual(stats.limiting_factor, "data dependency") self.assertEqual(stats.moves, 0) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not SM80OrLater, "bfloat16") + def test_all_gather_bucket(self): + def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + # do some unrelated matmuls + y = torch.mm(x, w) + + # cast the inputs + ag_0_cast = ag_0.to(torch.bfloat16) + ag_1_cast = ag_1.to(torch.bfloat16) + + # allgather + group_name = ( + torch.distributed.distributed_c10d._get_default_group().group_name + ) + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_0_cast, group_size, group_name + ) + ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_1_cast, group_size, group_name + ) + + # wait op + ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out) + ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out) + + return y, ag_0_out, ag_1_out + + x = torch.ones(4, 384, device="cuda", dtype=torch.float32) + w = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1] + + with torch._inductor.config.patch( + { + "bucket_all_gathers_fx": "fsdp", + "reorder_for_compute_comm_overlap": False, + } + ): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) + # NOTE: The first return value should be the output of the first wait_tensor. + # We want to make sure no unneccessary copy is made. + (FileCheck().check("all_gather_into_tensor_out").run(code)) + out = compiled(*inputs, **self.get_world_trs()) + correct = func(*inputs, **self.get_world_trs()) + assert same(out, correct), f"{out} va {correct}" + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @unittest.skipIf(not SM80OrLater, "bfloat16") + def test_reorder_peak_memory_bucketed(self): + """ + Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather. + Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the + comm from moving due to data dependency. + """ + + def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + # do some unrelated matmuls + y = torch.mm(x, w) + + # cast the inputs + ag_0_cast = ag_0.to(torch.bfloat16) + ag_1_cast = ag_1.to(torch.bfloat16) + + # allgather + group_name = ( + torch.distributed.distributed_c10d._get_default_group().group_name + ) + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_0_cast, group_size, group_name + ) + ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_1_cast, group_size, group_name + ) + + # wait op + ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out) + ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out) + + return y, ag_0_out, ag_1_out + + x = torch.ones(4, 384, device="cuda", dtype=torch.float32) + w = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(512, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1] + + # get stats directly from the internal helper without affecting the real pass's signature + node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None + + def _reorder_communication_preserving_peak_memory( + snodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + nonlocal node_stats + ( + reordered_snodes, + node_stats, + ) = _reorder_communication_preserving_peak_memory_internal(snodes) + return reordered_snodes + + with torch._inductor.config.patch( + { + "bucket_all_gathers_fx": "all", + "reorder_for_compute_comm_overlap": True, + "reorder_for_compute_comm_overlap_passes": [ + "sink_waits", + # same as reorder_communication_preserving_peak_memory but returns debug info structures directly + _reorder_communication_preserving_peak_memory, + ], + } + ): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) + # NOTE: The first return value should be the output of the first wait_tensor. + # We want to make sure no unneccessary copy is made. + (FileCheck().check("all_gather_into_tensor_out").run(code)) + out = compiled(*inputs, **self.get_world_trs()) + correct = func(*inputs, **self.get_world_trs()) + assert same(out, correct), f"{out} va {correct}" + assert node_stats is not None + self.assertTrue(isinstance(node_stats, dict)) + self.assertEqual(len(node_stats), 1) + + # TODO: Debug why reordering does not move collective after bucketing + # for stats in node_stats.values(): + # self.assertEqual(stats.initial_exposed, 0) + def _reorder_communication_preserving_peak_memory( + snodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + nonlocal node_stats + ( + reordered_snodes, + node_stats, + ) = _reorder_communication_preserving_peak_memory_internal(snodes) + return reordered_snodes + + with torch._inductor.config.patch( + { + "reorder_for_compute_comm_overlap": True, + "reorder_for_compute_comm_overlap_passes": [ + "sink_waits", + # same as reorder_communication_preserving_peak_memory but returns debug info structures directly + _reorder_communication_preserving_peak_memory, + ], + } + ): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) + # NOTE: The first return value should be the output of the first wait_tensor. + # We want to make sure no unneccessary copy is made. + ( + FileCheck() + .check("all_gather") + .check("wait") + .check("all_gather") + .check("wait") + .run(code) + ) + out = compiled(*inputs, **self.get_world_trs()) + correct = func(*inputs, **self.get_world_trs()) + assert same(out, correct), f"{out} va {correct}" + + # TODO make the test case more interesting and validate the actual desired behavior + assert node_stats is not None + self.assertTrue(isinstance(node_stats, dict)) + self.assertEqual(len(node_stats), 2) + # for stats in node_stats.values(): + # self.assertEqual(stats.moves, 0) + # self.assertEqual(stats.limiting_factor, "data dependency") + # self.assertEqual(stats.moves, 3) + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_reorder_respects_wait_dep(self): + """ + Covers the case where the output of one collective feeds the input of another collective. + e.g. TP + FSDP - all_gather(tp+dp sharded param on TP dim) -> allgather dp_sharded buffer on DP dim + """ + + def func(inp, *, tag, ranks, group_size): + group_name = ( + torch.distributed.distributed_c10d._get_default_group().group_name + ) + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( + inp, group_size, group_name + ) + ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out) + ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_0_wait, group_size, group_name + ) + ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out) + # ensure other is not incorrectly aliasing ar's buffer + return ag_1_wait + + inputs = torch.ones(4, 4, device="cuda") + + # get stats directly from the internal helper without affecting the real pass's signature + node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None + + def _reorder_communication_preserving_peak_memory( + snodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + nonlocal node_stats + ( + reordered_snodes, + node_stats, + ) = _reorder_communication_preserving_peak_memory_internal(snodes) + return reordered_snodes + + with torch._inductor.config.patch( + { + "reorder_for_compute_comm_overlap": True, + "reorder_for_compute_comm_overlap_passes": [ + "sink_waits", + # same as reorder_communication_preserving_peak_memory but returns debug info structures directly + _reorder_communication_preserving_peak_memory, + ], + } + ): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) + # NOTE: The first return value should be the output of the first wait_tensor. + # We want to make sure no unneccessary copy is made. + ( + FileCheck() + .check("all_gather") + .check("wait") + .check("all_gather") + .check("wait") + .run(code) + ) + out = compiled(inputs, **self.get_world_trs()) + correct = func(inputs, **self.get_world_trs()) + assert same(out, correct), f"{out} va {correct}" + + # TODO make the test case more interesting and validate the actual desired behavior + assert node_stats is not None + self.assertTrue(isinstance(node_stats, dict)) + self.assertEqual(len(node_stats), 2) + for stats in node_stats.values(): + self.assertEqual(stats.moves, 0) + + +@requires_nccl() +class TestSyncDecisionCrossRanks(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 2 + + @property + def ranks(self) -> list[int]: + return list(range(self.world_size)) + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process_group(self) -> None: + torch._inductor.config.triton.store_cubin = True + torch._inductor.config.debug = True + + torch.cuda.set_device(self.device) + store = torch.distributed.FileStore(self.file_name, self.world_size) + torch.distributed.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + torch._C._distributed_c10d._register_process_group( + "default", torch.distributed.group.WORLD + ) + + @skip_if_lt_x_gpu(2) + def test_sync_decision_cross_ranks(self): + from torch._functorch.partitioners import _sync_decision_cross_ranks + + test_graph = torch.fx.Graph() + node1 = test_graph.placeholder("x") + + ag1 = test_graph.create_node( + "call_function", + torch.ops._c10d_functional.all_gather_into_tensor.default, + (node1,), + ) + wt1 = test_graph.create_node( + "call_function", torch.ops._c10d_functional.wait_tensor.default, (ag1,) + ) + wt1.meta["val"] = torch.randn(10, 10) + + ag2 = test_graph.create_node( + "call_function", + torch.ops._c10d_functional.all_gather_into_tensor.default, + (node1,), + ) + wt2 = test_graph.create_node( + "call_function", torch.ops._c10d_functional.wait_tensor.default, (ag2,) + ) + wt2.meta["val"] = torch.randn(10, 20) + if self.rank == 0: + saved_values = [wt1] + else: + saved_values = [wt2] + + self._init_process_group() + saved_values = _sync_decision_cross_ranks(test_graph, saved_values) + self.assertEqual(saved_values, [wt1]) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/distributed/test_nccl.py b/test/distributed/test_nccl.py index f9bb4f6543ee5f..8c7f0b3073b003 100644 --- a/test/distributed/test_nccl.py +++ b/test/distributed/test_nccl.py @@ -7,15 +7,21 @@ import torch.cuda import torch.cuda.nccl as nccl import torch.distributed as c10d +import torch.distributed._symmetric_memory as symm_mem from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, ) +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import ( IS_WINDOWS, load_tests, NoTest, + requires_cuda_p2p_access, run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_ROCM, @@ -239,6 +245,38 @@ def test_reduce_scatter(self, device, dtype): self.assertEqual(outputs[i], expected[i]) +@requires_cuda_p2p_access() +class NCCLSymmetricMemoryTest(MultiProcContinousTest): + @property + def device(self) -> torch.device: + return torch.device("cuda", self.rank) + + @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm") + @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows") + @skip_if_lt_x_gpu(2) + def test_nccl_symmem_alloc(self): + symm_mem.set_backend("NCCL") + torch.cuda.set_device(self.rank) + # Need this all_reduce to initialize NCCL communicator. Otherwise, the + # test will hang. TODO: investigate how NCCLSymmetricMemory can + # initialize NCCL communicator. + c10d.all_reduce(torch.ones(1, device=self.device)) + group_name = c10d.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + numel = 1024 + + def foo(): + inp = symm_mem.empty(numel, dtype=dtype, device=self.device) + symm_mem.rendezvous(inp, group=group_name) + + foo() + + out = symm_mem.empty(numel, dtype=dtype, device=self.device) + symm_mem.rendezvous(out, group=group_name) + + instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda") if __name__ == "__main__": diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py index e56f55efa32ad1..88876c70e5a7a2 100644 --- a/test/distributed/test_nvshmem.py +++ b/test/distributed/test_nvshmem.py @@ -1,39 +1,27 @@ # Owner(s): ["oncall: distributed"] # To run: -# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -# OR -# TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node 4 test/distributed/test_nvshmem.py +# python test/distributed/test_nvshmem.py -import os -import sys import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem from torch.testing._internal.common_distributed import MultiProcContinousTest from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, ) -symm_mem_backend = os.getenv("TORCH_SYMMMEM") - -if symm_mem_backend != "NVSHMEM": - print( - "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`, skipping tests", - file=sys.stderr, - ) - sys.exit(0) - - # Decorator def requires_nvshmem(): return skip_but_pass_in_sandcastle_if( - symm_mem_backend != "NVSHMEM", - "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`", + not symm_mem.is_nvshmem_available(), + "test_nvshmem requires NVSHMEM, skipping tests", ) @@ -42,6 +30,7 @@ def requires_nvshmem(): device_module = torch.get_device_module(device_type) +@instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest): def _init_device(self) -> None: @@ -49,11 +38,80 @@ def _init_device(self) -> None: device_module.set_device(self.device) # NOTE: required for nvshmem allocation torch.empty(1, device=self.device) + # Set NVSHMEM as SymmMem backend + symm_mem.set_backend("NVSHMEM") @property def device(self) -> torch.device: return torch.device(device_type, self.rank) + @skipIfRocm + def test_alloc(self) -> None: + self._init_device() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + numel = 1024 + + def foo(): + inp = symm_mem.empty(numel, dtype=dtype, device=self.device) + symm_mem.rendezvous(inp, group=group_name) + + foo() + + out = symm_mem.empty(numel, dtype=dtype, device=self.device) + symm_mem.rendezvous(out, group=group_name) + + @skipIfRocm + def test_nvshmem_put(self) -> None: + self._init_device() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + numel = 1024 + tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank) + symm_mem.rendezvous(tensor, group=group_name) + + if self.rank == 0: + torch.ops.symm_mem.nvshmem_put(tensor, 1) + # TODO: remove after we have wait_signal + dist.barrier() + elif self.rank == 1: + # handle.wait_signal(src_rank=0) + # TODO: remove after we have wait_signal + dist.barrier() + torch.testing.assert_close( + tensor, torch.zeros(numel, dtype=dtype, device=self.device) + ) + else: + dist.barrier() + + @skipIfRocm + def test_nvshmem_get(self) -> None: + self._init_device() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + numel = 1024 + tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank) + symm_mem.rendezvous(tensor, group=group_name) + + if self.rank == 0: + torch.ops.symm_mem.nvshmem_get(tensor, 1) + # TODO: remove after we have wait_signal + dist.barrier() + torch.testing.assert_close( + tensor, torch.ones(numel, dtype=dtype, device=self.device) + ) + else: + # handle.wait_signal(src_rank=0) + # TODO: remove after we have wait_signal + dist.barrier() + @skipIfRocm def test_nvshmem_all_to_all(self) -> None: self._init_device() @@ -80,7 +138,7 @@ def test_nvshmem_all_to_all(self) -> None: torch.testing.assert_close(out, expected) @skipIfRocm - def test_nvshmem_all_to_all_vdev(self) -> None: + def test_all_to_all_vdev(self) -> None: self._init_device() group_name = dist.group.WORLD.group_name @@ -95,21 +153,24 @@ def test_nvshmem_all_to_all_vdev(self) -> None: out_splits = torch.zeros_like(inp_splits) dist.all_to_all_single(out_splits, inp_splits) out_numel = out_splits.sum().item() - # Align up to make it bigger - align = 16 - out_numel_max = (out_numel + align - 1) // align * align - inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_( + # Max number of input elements (must be a constant across ranks for symmetric memory allocation) + max_inp_numel = k * self.world_size + # Max number of output elements (must be a constant across ranks for symmetric memory allocation) + overflow_factor = self.world_size # worst case: one rank receives all data + max_out_numel = max_inp_numel * overflow_factor + + inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_( self.rank ) - out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1) + out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1) in_out_splits = symm_mem.empty( (3, self.world_size), dtype=torch.int64, device=self.device ) # Row 0 is input splits in_out_splits[0].copy_(inp_splits) - torch.ops.symm_mem.nvshmem_all_to_all_vdev(inp, out, in_out_splits, group_name) + torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name) # Check input splits (row 0) -- should not change torch.testing.assert_close(in_out_splits[0], inp_splits) @@ -119,15 +180,123 @@ def test_nvshmem_all_to_all_vdev(self) -> None: # Check output offsets (row 2) out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan - # output offsets from `nvshmem_all_to_all_vdev` is exclusive scan + # output offsets from `all_to_all_vdev` is exclusive scan self.assertEqual(in_out_splits[2][0], 0) torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1]) # Check data expected = torch.empty(out_numel, dtype=dtype, device=self.device) - dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist()) + dist.all_to_all_single( + expected, inp[:inp_numel], out_splits.tolist(), inp_splits.tolist() + ) torch.testing.assert_close(out[:out_numel], expected) + @skipIfRocm + @parametrize("align", [1, 8, 16]) # `major_align` of output + def test_all_to_all_vdev_2d(self, align: int) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + # Number of experts per rank + ne = 8 + nsplits = ne * self.world_size + + # Number of elements for an expert is random between [0, k) + k = 10 + inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=self.device) + + # Exchange input splits to get output splits + out_splits = torch.zeros_like(inp_splits) + dist.all_to_all_single(out_splits, inp_splits) + # We do a .t() here because there is a rank-major to expert-major shuffle + out_splits_t = out_splits.reshape(self.world_size, ne).t() + + # Actual number of input elements + inp_numel = inp_splits.sum().item() + # Actual number of output elements + out_numel = out_splits.sum().item() + # Max number of input elements (must be a constant across ranks for symmetric memory allocation) + max_inp_numel = k * nsplits + # Max number of output elements (must be a constant across ranks for symmetric memory allocation) + overflow_factor = self.world_size # worst case: one rank receives all data + max_out_numel = max_inp_numel * overflow_factor + + inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_( + self.rank + ) + out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1) + # 3 rows: input splits, output splits, output offsets + # Initiallizing all values to -1 to check if they are updated + in_out_splits = symm_mem.empty( + (3, nsplits), dtype=torch.int64, device=self.device + ).fill_(-1) + # Row 0 is input splits + in_out_splits[0].copy_(inp_splits) + + torch.ops.symm_mem.all_to_all_vdev_2d( + inp, out, in_out_splits, group_name, major_align=align + ) + received_out_splits = in_out_splits[1] + received_out_offsets = in_out_splits[2] + + # Check input splits (row 0) -- should not change + torch.testing.assert_close(in_out_splits[0], inp_splits) + + # Check output splits (row 1) + torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1)) + + # Check output offsets (row 2) + out_split_list = out_splits_t.tolist() + for i in range(ne): + expert_sum = 0 + for j in range(self.world_size): + expert_sum += out_split_list[i][j] + # Align up expert_sum + expert_sum_aligned = (expert_sum + align - 1) // align * align + # If 0, make it at least `align` (bc cutlass currently does not support empty bins) + expert_sum_aligned = max(expert_sum_aligned, align) + # last element absorbs the padding + out_split_list[i][-1] += expert_sum_aligned - expert_sum + + out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1) + out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan + # Make it exclusive scan because that's what `all_to_all_vdev_2d` returns + out_offsets = torch.cat( + [torch.zeros(1, device=self.device), out_offsets[:-1]] + ).to(torch.int64) + torch.testing.assert_close(received_out_offsets, out_offsets) + + # Check data + expected = torch.empty(out_numel, dtype=dtype, device=self.device) + inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1) + out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1) + dist.all_to_all_single( + expected, + inp[:inp_numel], + out_splits_rank.tolist(), + inp_splits_rank.tolist(), + ) + # We still need to shuffle `expected` + out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan + result_list = [] + for j in range(ne): + for i in range(self.world_size): + chunk_id = i * ne + j + offset = out_offsets[chunk_id] + chunk = expected[offset - out_splits[chunk_id] : offset] + result_list.append(chunk) + + # Do a chunk-wise comparison + for c, chunk in enumerate(result_list): + start = received_out_offsets[c].item() + split = received_out_splits[c].item() + received_chunk = out[start : start + split] + torch.testing.assert_close(received_chunk, chunk) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py new file mode 100644 index 00000000000000..2aabf92427841f --- /dev/null +++ b/test/distributed/test_nvshmem_triton.py @@ -0,0 +1,687 @@ +# Owner(s): ["oncall: distributed"] + +# To run: +# python test/distributed/test_nvshmem_triton.py + + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem +from torch._inductor.runtime.triton_compat import triton +from torch.testing._internal.common_distributed import MultiProcContinousTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + skip_but_pass_in_sandcastle, + skip_but_pass_in_sandcastle_if, + skipIfRocm, +) +from torch.testing._internal.inductor_utils import requires_triton + + +# Decorator +def requires_nvshmem(): + return skip_but_pass_in_sandcastle_if( + not symm_mem.is_nvshmem_available(), + "test_nvshmem requires NVSHMEM, skipping tests", + ) + + +# So that tests are written in device-agnostic way +device_type = "cuda" +device_module = torch.get_device_module(device_type) + + +# Shared Triton JIT kernels +@triton.jit +def put_kernel( + dst_ptr, + src_ptr, + numel, + peer, +): + nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + + +@triton.jit +def get_kernel( + dst_ptr, + src_ptr, + numel, + peer, +): + nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer) + + +@triton.jit +def put_signal_kernel( + dst_ptr, + src_ptr, + numel, + sig_ptr, + signal_val, + sig_op, + peer, +): + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + ) + + +@triton.jit +def signal_wait_until_kernel(sig_ptr, cmp_op, cmp_val): + nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) + + +@triton.jit +def signal_op_kernel( + sig_addr, + signal, + sig_op, + peer, +): + nvshmem.signal_op(sig_addr, signal, sig_op, peer) + + +@triton.jit +def wait_until_kernel( + ivar_ptr, + cmp_op, + cmp_val, +): + nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + + +@triton.jit +def put_and_signal_kernel( + dst_ptr, + src_ptr, + numel, + sig_ptr, + signal_val, + sig_op, + peer, +): + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + ) + + +@triton.jit +def put_with_fence_kernel( + dst_ptr1, + dst_ptr2, + src_ptr1, + src_ptr2, + flag_ptr, + flag_src_ptr, + numel, + peer, +): + # First put + nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer) + # Ensure the first put is ordered before the next. + nvshmem.fence() + # Second put + nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer) + # Order the second put before flag update. + nvshmem.fence() + # Write the flag (single int64) to signal completion. + nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer) + + +@triton.jit +def put_with_quiet_kernel( + dst_ptr, + src_ptr, + flag_dst_ptr, + flag_src_ptr, + numel, + peer, +): + # Put data + nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + # Call quiet to ensure put is complete + nvshmem.quiet() + # Only after quiet, set the completion flag + # This ensures the data put is complete before flag is set + nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) + + +@instantiate_parametrized_tests +@requires_nvshmem() +class NVSHMEMTritonTest(MultiProcContinousTest): + def _init_device(self) -> None: + # TODO: relieve this (seems to hang if without) + device_module.set_device(self.device) + # NOTE: required for nvshmem allocation + torch.empty(1, device=self.device) + # Set NVSHMEM as SymmMem backend + symm_mem.set_backend("NVSHMEM") + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) + + @skipIfRocm + @requires_triton() + def test_triton_put(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + val = 5 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + peer = 1 - rank + if rank == 0: + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + dist.barrier() + if rank == 1: + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_get(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val = 7 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_( + val if rank == 0 else -1 + ) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + dist.barrier() + peer = 1 - rank + if rank == 1: + # Rank 1 gets data from rank 0 + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + get_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + if rank == 1: + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_get_ring(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + world_size = dist.get_world_size() + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Each rank fills its input buffer with its own rank value + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + dist.barrier() + + # Ring topology: each rank gets data from the rank to its left + # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc. + peer = (rank - 1) % world_size + + # All ranks execute the get operation + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + get_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + expected_value = peer + torch.testing.assert_close( + out, expected_value * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_put_signal_set(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Data buffers + val = 11 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + # Use the signal pad attached to the output symmetric memory handle + # as the flag buffer for signaling completion. + flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + + peer = 1 - rank + NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set + SIGNAL_VAL = 1 # Signal completion value + NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until + + if rank == 0: + # Rank 0 puts into Rank 1 + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + sig_ptr=sig_ptr, + signal_val=SIGNAL_VAL, + sig_op=NVSHMEM_SIGNAL_SET, + peer=peer, + extern_libs=nvshmem_lib, + ) + + if rank == 1: + # Wait until signal flag is set by Rank 0 + sig_ptr_local = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1,)]( + sig_ptr_local, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=SIGNAL_VAL, + extern_libs=nvshmem_lib, + ) + # After wait completes, verify data and flag contents + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_put_signal_add(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Data buffers + val = 11 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + # Use the signal pad attached to the output symmetric memory handle + # as the flag buffer for signaling completion. + flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + + peer = 1 - rank + NVSHMEM_SIGNAL_ADD = 5 # atomic add operation + SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD + NVSHMEM_CMP_EQ = 0 + + if rank == 0: + # Rank 0 puts into Rank 1 + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + sig_ptr=sig_ptr, + signal_val=SIGNAL_VAL, + sig_op=NVSHMEM_SIGNAL_ADD, + peer=peer, + extern_libs=nvshmem_lib, + ) + + if rank == 1: + sig_ptr_local = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1, 1, 1)]( + sig_ptr_local, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=SIGNAL_VAL, + extern_libs=nvshmem_lib, + ) + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) + ) + + # This test hangs. TODO: investigate why. + @skip_but_pass_in_sandcastle("Hangs") + @skipIfRocm + @requires_triton() + def test_triton_wait_until(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + # Data buffers + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val = 13 + flag_val = 21 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + peer = 1 - rank + NVSHMEM_CMP_EQ = 0 # from nvshmem.h + NVSHMEM_SIGNAL_SET = 0 # atomic set operation + + if rank == 0: + # Rank 0 waits for the flag to be set by Rank 1, then checks the data + ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + if rank == 1: + # Rank 1 puts data into Rank 0's output buffer + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + # Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op + sig_addr = out_hdl.signal_pad_ptrs[rank] + signal_op_kernel[(1, 1, 1)]( + sig_addr, + signal=flag_val, + sig_op=NVSHMEM_SIGNAL_SET, + peer=peer, + extern_libs=nvshmem_lib, + ) + + @skipIfRocm + @requires_triton() + def test_triton_signal_wait_until(self) -> None: + self._init_device() + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = 1 - rank + + # NVSHMEM constants from documentation + NVSHMEM_CMP_EQ = 0 # equal comparison + NVSHMEM_SIGNAL_SET = 0 # atomic set operation + + # Message configuration + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val_to_put = 123 # arbitrary test value + COMPLETION_FLAG_VAL = 1 + + # Producer (rank 0) prepares the data to send + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val_to_put) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + # Consumer (rank 1) prepares the destination buffer + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + out_hdl = symm_mem.rendezvous(out, group=group_name) + # Use the signal pad for synchronization, as in previous tests + flag_dtype = torch.int64 + flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0) + + if rank == 0: + # Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_and_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel, + sig_ptr, + signal_val=COMPLETION_FLAG_VAL, + sig_op=NVSHMEM_SIGNAL_SET, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 1: + # Consumer (rank 1): Waits on the signal variable using `signal_wait_until`. + sig_ptr = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1, 1, 1)]( + sig_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=COMPLETION_FLAG_VAL, + extern_libs=nvshmem_lib, + ) + # After the wait returns, verify data and flag + torch.testing.assert_close( + out, val_to_put * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, + torch.tensor( + [COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device + ), + ) + + @skipIfRocm + @requires_triton() + def test_triton_fence(self) -> None: + """ + Rank 0 performs two put operations into Rank 1's buffers with a fence + between them, followed by another fence and a flag update. Rank 1 waits + for the flag, then verifies that both destination buffers contain the + expected values. The flag is transferred after the final fence, so + its arrival implies that both preceding puts have been delivered in + order. + """ + + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = 1 - rank + # Message configuration + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val1 = 10 + val2 = 20 + flag_val = 1 + # Symmetric buffers + inp1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val1) + inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2) + out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp1_hdl = symm_mem.rendezvous(inp1, group=group_name) + inp2_hdl = symm_mem.rendezvous(inp2, group=group_name) + out1_hdl = symm_mem.rendezvous(out1, group=group_name) + out2_hdl = symm_mem.rendezvous(out2, group=group_name) + + # Flag buffer resides in the signal pad of out2. + flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + NVSHMEM_CMP_EQ = 0 # compare equal + + if rank == 0: + dst_ptr1 = out1_hdl.buffer_ptrs[rank] + dst_ptr2 = out2_hdl.buffer_ptrs[rank] + src_ptr1 = inp1_hdl.buffer_ptrs[rank] + src_ptr2 = inp2_hdl.buffer_ptrs[rank] + flag_ptr = out2_hdl.signal_pad_ptrs[rank] + flag_src_ptr = flag_update_val.data_ptr() + + put_with_fence_kernel[(1, 1, 1)]( + dst_ptr1, + dst_ptr2, + src_ptr1, + src_ptr2, + flag_ptr, + flag_src_ptr, + numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 1: + # Wait until flag is set by Rank 0. + ivar_ptr = out2_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + + # Verify ordered data arrival. + torch.testing.assert_close( + out1, val1 * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + out2, val2 * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_quiet(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + # Data buffers + val = 15 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + # Use signal pad as completion flag + flag_val = 42 + peer = 1 - rank + NVSHMEM_CMP_EQ = 0 + + if rank == 0: + # Rank 0 waits for flag from Rank 1 + ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + # After flag is set, data should be complete due to quiet + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + if rank == 1: + # Rank 1 puts data and flag with quiet in between + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + flag_dst_ptr = out_hdl.signal_pad_ptrs[rank] + # Create a tensor for the flag value + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + flag_src_ptr = flag_update_val.data_ptr() + put_with_quiet_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + flag_dst_ptr, + flag_src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index d7e59f1c90a76e..4c96d4b564d62e 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -464,8 +464,8 @@ def test_collective_shape_mismatch_cuda(self): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_pg_wrapper must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_pg_wrapper must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 25a554942c8251..e9abb1d9071785 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -372,12 +372,8 @@ def test_address_already_in_use(self): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 - store2 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 + store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 + store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 self.assertEqual(store1.libuvBackend, self._use_libuv) self.assertEqual(store2.libuvBackend, self._use_libuv) @@ -767,7 +763,7 @@ def test_common_errors(self): def test_nominal(self): with tempfile.NamedTemporaryFile(delete=False) as file: - url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' + url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2" gen0 = dist.rendezvous(url + "&rank=0") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) @@ -1178,8 +1174,8 @@ def listen() -> None: if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index ce16ec737ff399..55db0d3b7283b2 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -10,7 +10,7 @@ import torch.distributed._symmetric_memory as symm_mem from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory -from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code +from torch._inductor.utils import fresh_cache, run_and_get_triton_code from torch.distributed._functional_collectives import all_gather_tensor from torch.distributed._symmetric_memory import ( _fused_all_gather_matmul_fallback, @@ -34,9 +34,9 @@ MI300_ARCH, parametrize, requires_cuda, + requires_cuda_p2p_access, run_tests, runOnRocmArch, - skip_but_pass_in_sandcastle_if, skipIfRocm, TEST_WITH_ROCM, TestCase, @@ -50,27 +50,6 @@ device_module = torch.get_device_module(device_type) -def requires_cuda_p2p_access(): - cuda_p2p_access_available = ( - torch.cuda.is_available() - and torch.cuda.get_device_capability() >= (8, 0) - and torch.cuda.device_count() >= 2 - ) - num_devices = torch.cuda.device_count() - for i in range(num_devices - 1): - for j in range(i + 1, num_devices): - if not torch.cuda.can_device_access_peer(i, j): - cuda_p2p_access_available = False - break - if not cuda_p2p_access_available: - break - - return skip_but_pass_in_sandcastle_if( - not cuda_p2p_access_available, - "cuda p2p access is not available", - ) - - @instantiate_parametrized_tests @requires_cuda_p2p_access() class SymmetricMemoryTest(MultiProcContinousTest): @@ -88,6 +67,14 @@ def test_has_multicast_support(self) -> None: self.assertFalse(_SymmetricMemory.has_multicast_support(DeviceType.CPU, 0)) # NOTE: DeviceType.CUDA is implicitly tested through @requires_multicast_support + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_get_backend(self) -> None: + backend = symm_mem.get_backend(torch.device("cuda")) + self.assertIsNotNone(backend) + backend = symm_mem.get_backend("cuda") + self.assertIsNotNone(backend) + @skipIfRocm @skip_if_lt_x_gpu(2) def test_cuda_nvlink_connectivity_detection(self) -> None: @@ -974,7 +961,7 @@ def _verify_reduce_scatter_result(self, inp, res): gathered_res[i], sum_inps[..., i * slice_width : (i + 1) * slice_width], rtol=1e-01, - atol=1e-01, + atol=1.1e-01, ) @skip_if_lt_x_gpu(4) @@ -1020,7 +1007,7 @@ def device(self) -> torch.device: @skip("Fails with 'one_shot_all_reduce' not found in AOT graph, TODO: fix") @skipIfRocm # requires registered-buffer support @skip_if_lt_x_gpu(2) - @fresh_inductor_cache() + @fresh_cache() def test_lowering_one_shot_all_reduce(self): self._init_process() arg = torch.rand(4, 4, device=self.device) @@ -1079,7 +1066,7 @@ class SymmMemSingleProcTest(TestCase): "stream_write_value32 currently only supports cuda version>=12.0", ) @skipIf( - _get_torch_cuda_version() == (12, 6), + _get_torch_cuda_version() >= (12, 6), "https://github.com/pytorch/pytorch/issues/154073", ) @runOnRocmArch(MI300_ARCH) diff --git a/test/dynamo/cpython/3_13/data/README b/test/dynamo/cpython/3_13/data/README new file mode 100644 index 00000000000000..bd05984e439078 --- /dev/null +++ b/test/dynamo/cpython/3_13/data/README @@ -0,0 +1,2 @@ +This empty directory serves as destination for temporary files +created by some tests, in particular, the test_codecmaps_* tests. diff --git a/test/dynamo/cpython/3_13/exception_hierarchy.txt b/test/dynamo/cpython/3_13/exception_hierarchy.txt new file mode 100644 index 00000000000000..5e83faab9a6158 --- /dev/null +++ b/test/dynamo/cpython/3_13/exception_hierarchy.txt @@ -0,0 +1,69 @@ +BaseException + ├── BaseExceptionGroup + ├── GeneratorExit + ├── KeyboardInterrupt + ├── SystemExit + └── Exception + ├── ArithmeticError + │ ├── FloatingPointError + │ ├── OverflowError + │ └── ZeroDivisionError + ├── AssertionError + ├── AttributeError + ├── BufferError + ├── EOFError + ├── ExceptionGroup [BaseExceptionGroup] + ├── ImportError + │ └── ModuleNotFoundError + ├── LookupError + │ ├── IndexError + │ └── KeyError + ├── MemoryError + ├── NameError + │ └── UnboundLocalError + ├── OSError + │ ├── BlockingIOError + │ ├── ChildProcessError + │ ├── ConnectionError + │ │ ├── BrokenPipeError + │ │ ├── ConnectionAbortedError + │ │ ├── ConnectionRefusedError + │ │ └── ConnectionResetError + │ ├── FileExistsError + │ ├── FileNotFoundError + │ ├── InterruptedError + │ ├── IsADirectoryError + │ ├── NotADirectoryError + │ ├── PermissionError + │ ├── ProcessLookupError + │ └── TimeoutError + ├── ReferenceError + ├── RuntimeError + │ ├── NotImplementedError + │ ├── PythonFinalizationError + │ └── RecursionError + ├── StopAsyncIteration + ├── StopIteration + ├── SyntaxError + │ └── _IncompleteInputError + │ └── IndentationError + │ └── TabError + ├── SystemError + ├── TypeError + ├── ValueError + │ └── UnicodeError + │ ├── UnicodeDecodeError + │ ├── UnicodeEncodeError + │ └── UnicodeTranslateError + └── Warning + ├── BytesWarning + ├── DeprecationWarning + ├── EncodingWarning + ├── FutureWarning + ├── ImportWarning + ├── PendingDeprecationWarning + ├── ResourceWarning + ├── RuntimeWarning + ├── SyntaxWarning + ├── UnicodeWarning + └── UserWarning diff --git a/test/dynamo/cpython/3_13/list_tests.diff b/test/dynamo/cpython/3_13/list_tests.diff new file mode 100644 index 00000000000000..ab26ef23f44e4d --- /dev/null +++ b/test/dynamo/cpython/3_13/list_tests.diff @@ -0,0 +1,67 @@ +diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py +index dbc5ef4f9f2..2b9f3b9311f 100644 +--- a/test/dynamo/cpython/3_13/list_tests.py ++++ b/test/dynamo/cpython/3_13/list_tests.py +@@ -1,3 +1,53 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + """ + Tests common to list and UserList.UserList + """ +@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList + import sys + from functools import cmp_to_key + +-from test import seq_tests ++import seq_tests + from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit + + diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py new file mode 100644 index 00000000000000..2b9f3b9311f424 --- /dev/null +++ b/test/dynamo/cpython/3_13/list_tests.py @@ -0,0 +1,627 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +""" +Tests common to list and UserList.UserList +""" + +import sys +from functools import cmp_to_key + +import seq_tests +from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit + + +class CommonTest(seq_tests.CommonTest): + + def test_init(self): + # Iterable arg is optional + self.assertEqual(self.type2test([]), self.type2test()) + + # Init clears previous values + a = self.type2test([1, 2, 3]) + a.__init__() + self.assertEqual(a, self.type2test([])) + + # Init overwrites previous values + a = self.type2test([1, 2, 3]) + a.__init__([4, 5, 6]) + self.assertEqual(a, self.type2test([4, 5, 6])) + + # Mutables always return a new object + b = self.type2test(a) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + + def test_getitem_error(self): + a = [] + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a['a'] + + def test_setitem_error(self): + a = [] + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a['a'] = "python" + + def test_repr(self): + l0 = [] + l2 = [0, 1, 2] + a0 = self.type2test(l0) + a2 = self.type2test(l2) + + self.assertEqual(str(a0), str(l0)) + self.assertEqual(repr(a0), repr(l0)) + self.assertEqual(repr(a2), repr(l2)) + self.assertEqual(str(a2), "[0, 1, 2]") + self.assertEqual(repr(a2), "[0, 1, 2]") + + a2.append(a2) + a2.append(3) + self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") + self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") + + def test_repr_deep(self): + a = self.type2test([]) + for i in range(get_c_recursion_limit() + 1): + a = self.type2test([a]) + self.assertRaises(RecursionError, repr, a) + + def test_set_subscript(self): + a = self.type2test(range(20)) + self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 0), [1,2,3]) + self.assertRaises(TypeError, a.__setitem__, slice(0, 10), 1) + self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 2), [1,2]) + self.assertRaises(TypeError, a.__getitem__, 'x', 1) + a[slice(2,10,3)] = [1,2,3] + self.assertEqual(a, self.type2test([0, 1, 1, 3, 4, 2, 6, 7, 3, + 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19])) + + def test_reversed(self): + a = self.type2test(range(20)) + r = reversed(a) + self.assertEqual(list(r), self.type2test(range(19, -1, -1))) + self.assertRaises(StopIteration, next, r) + self.assertEqual(list(reversed(self.type2test())), + self.type2test()) + # Bug 3689: make sure list-reversed-iterator doesn't have __len__ + self.assertRaises(TypeError, len, reversed([1,2,3])) + + def test_setitem(self): + a = self.type2test([0, 1]) + a[0] = 0 + a[1] = 100 + self.assertEqual(a, self.type2test([0, 100])) + a[-1] = 200 + self.assertEqual(a, self.type2test([0, 200])) + a[-2] = 100 + self.assertEqual(a, self.type2test([100, 200])) + self.assertRaises(IndexError, a.__setitem__, -3, 200) + self.assertRaises(IndexError, a.__setitem__, 2, 200) + + a = self.type2test([]) + self.assertRaises(IndexError, a.__setitem__, 0, 200) + self.assertRaises(IndexError, a.__setitem__, -1, 200) + self.assertRaises(TypeError, a.__setitem__) + + a = self.type2test([0,1,2,3,4]) + a[0] = 1 + a[1] = 2 + a[2] = 3 + self.assertEqual(a, self.type2test([1,2,3,3,4])) + a[0] = 5 + a[1] = 6 + a[2] = 7 + self.assertEqual(a, self.type2test([5,6,7,3,4])) + a[-2] = 88 + a[-1] = 99 + self.assertEqual(a, self.type2test([5,6,7,88,99])) + a[-2] = 8 + a[-1] = 9 + self.assertEqual(a, self.type2test([5,6,7,8,9])) + + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a['a'] = "python" + + def test_delitem(self): + a = self.type2test([0, 1]) + del a[1] + self.assertEqual(a, [0]) + del a[0] + self.assertEqual(a, []) + + a = self.type2test([0, 1]) + del a[-2] + self.assertEqual(a, [1]) + del a[-1] + self.assertEqual(a, []) + + a = self.type2test([0, 1]) + self.assertRaises(IndexError, a.__delitem__, -3) + self.assertRaises(IndexError, a.__delitem__, 2) + + a = self.type2test([]) + self.assertRaises(IndexError, a.__delitem__, 0) + + self.assertRaises(TypeError, a.__delitem__) + + def test_setslice(self): + l = [0, 1] + a = self.type2test(l) + + for i in range(-3, 4): + a[:i] = l[:i] + self.assertEqual(a, l) + a2 = a[:] + a2[:i] = a[:i] + self.assertEqual(a2, a) + a[i:] = l[i:] + self.assertEqual(a, l) + a2 = a[:] + a2[i:] = a[i:] + self.assertEqual(a2, a) + for j in range(-3, 4): + a[i:j] = l[i:j] + self.assertEqual(a, l) + a2 = a[:] + a2[i:j] = a[i:j] + self.assertEqual(a2, a) + + aa2 = a2[:] + aa2[:0] = [-2, -1] + self.assertEqual(aa2, [-2, -1, 0, 1]) + aa2[0:] = [] + self.assertEqual(aa2, []) + + a = self.type2test([1, 2, 3, 4, 5]) + a[:-1] = a + self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 5])) + a = self.type2test([1, 2, 3, 4, 5]) + a[1:] = a + self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5])) + a = self.type2test([1, 2, 3, 4, 5]) + a[1:-1] = a + self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5, 5])) + + a = self.type2test([]) + a[:] = tuple(range(10)) + self.assertEqual(a, self.type2test(range(10))) + + self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5)) + + self.assertRaises(TypeError, a.__setitem__) + + def test_slice_assign_iterator(self): + x = self.type2test(range(5)) + x[0:3] = reversed(range(3)) + self.assertEqual(x, self.type2test([2, 1, 0, 3, 4])) + + x[:] = reversed(range(3)) + self.assertEqual(x, self.type2test([2, 1, 0])) + + def test_delslice(self): + a = self.type2test([0, 1]) + del a[1:2] + del a[0:1] + self.assertEqual(a, self.type2test([])) + + a = self.type2test([0, 1]) + del a[1:2] + del a[0:1] + self.assertEqual(a, self.type2test([])) + + a = self.type2test([0, 1]) + del a[-2:-1] + self.assertEqual(a, self.type2test([1])) + + a = self.type2test([0, 1]) + del a[-2:-1] + self.assertEqual(a, self.type2test([1])) + + a = self.type2test([0, 1]) + del a[1:] + del a[:1] + self.assertEqual(a, self.type2test([])) + + a = self.type2test([0, 1]) + del a[1:] + del a[:1] + self.assertEqual(a, self.type2test([])) + + a = self.type2test([0, 1]) + del a[-1:] + self.assertEqual(a, self.type2test([0])) + + a = self.type2test([0, 1]) + del a[-1:] + self.assertEqual(a, self.type2test([0])) + + a = self.type2test([0, 1]) + del a[:] + self.assertEqual(a, self.type2test([])) + + def test_append(self): + a = self.type2test([]) + a.append(0) + a.append(1) + a.append(2) + self.assertEqual(a, self.type2test([0, 1, 2])) + + self.assertRaises(TypeError, a.append) + + def test_extend(self): + a1 = self.type2test([0]) + a2 = self.type2test((0, 1)) + a = a1[:] + a.extend(a2) + self.assertEqual(a, a1 + a2) + + a.extend(self.type2test([])) + self.assertEqual(a, a1 + a2) + + a.extend(a) + self.assertEqual(a, self.type2test([0, 0, 1, 0, 0, 1])) + + a = self.type2test("spam") + a.extend("eggs") + self.assertEqual(a, list("spameggs")) + + self.assertRaises(TypeError, a.extend, None) + self.assertRaises(TypeError, a.extend) + + # overflow test. issue1621 + class CustomIter: + def __iter__(self): + return self + def __next__(self): + raise StopIteration + def __length_hint__(self): + return sys.maxsize + a = self.type2test([1,2,3,4]) + a.extend(CustomIter()) + self.assertEqual(a, [1,2,3,4]) + + + def test_insert(self): + a = self.type2test([0, 1, 2]) + a.insert(0, -2) + a.insert(1, -1) + a.insert(2, 0) + self.assertEqual(a, [-2, -1, 0, 0, 1, 2]) + + b = a[:] + b.insert(-2, "foo") + b.insert(-200, "left") + b.insert(200, "right") + self.assertEqual(b, self.type2test(["left",-2,-1,0,0,"foo",1,2,"right"])) + + self.assertRaises(TypeError, a.insert) + + def test_pop(self): + a = self.type2test([-1, 0, 1]) + a.pop() + self.assertEqual(a, [-1, 0]) + a.pop(0) + self.assertEqual(a, [0]) + self.assertRaises(IndexError, a.pop, 5) + a.pop(0) + self.assertEqual(a, []) + self.assertRaises(IndexError, a.pop) + self.assertRaises(TypeError, a.pop, 42, 42) + a = self.type2test([0, 10, 20, 30, 40]) + + def test_remove(self): + a = self.type2test([0, 0, 1]) + a.remove(1) + self.assertEqual(a, [0, 0]) + a.remove(0) + self.assertEqual(a, [0]) + a.remove(0) + self.assertEqual(a, []) + + self.assertRaises(ValueError, a.remove, 0) + + self.assertRaises(TypeError, a.remove) + + a = self.type2test([1, 2]) + self.assertRaises(ValueError, a.remove, NEVER_EQ) + self.assertEqual(a, [1, 2]) + a.remove(ALWAYS_EQ) + self.assertEqual(a, [2]) + a = self.type2test([ALWAYS_EQ]) + a.remove(1) + self.assertEqual(a, []) + a = self.type2test([ALWAYS_EQ]) + a.remove(NEVER_EQ) + self.assertEqual(a, []) + a = self.type2test([NEVER_EQ]) + self.assertRaises(ValueError, a.remove, ALWAYS_EQ) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + a = self.type2test([0, 1, 2, 3]) + self.assertRaises(BadExc, a.remove, BadCmp()) + + class BadCmp2: + def __eq__(self, other): + raise BadExc() + + d = self.type2test('abcdefghcij') + d.remove('c') + self.assertEqual(d, self.type2test('abdefghcij')) + d.remove('c') + self.assertEqual(d, self.type2test('abdefghij')) + self.assertRaises(ValueError, d.remove, 'c') + self.assertEqual(d, self.type2test('abdefghij')) + + # Handle comparison errors + d = self.type2test(['a', 'b', BadCmp2(), 'c']) + e = self.type2test(d) + self.assertRaises(BadExc, d.remove, 'c') + for x, y in zip(d, e): + # verify that original order and values are retained. + self.assertIs(x, y) + + def test_index(self): + super().test_index() + a = self.type2test([-2, -1, 0, 0, 1, 2]) + a.remove(0) + self.assertRaises(ValueError, a.index, 2, 0, 4) + self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2])) + + # Test modifying the list during index's iteration + class EvilCmp: + def __init__(self, victim): + self.victim = victim + def __eq__(self, other): + del self.victim[:] + return False + a = self.type2test() + a[:] = [EvilCmp(a) for _ in range(100)] + # This used to seg fault before patch #1005778 + self.assertRaises(ValueError, a.index, None) + + def test_reverse(self): + u = self.type2test([-2, -1, 0, 1, 2]) + u2 = u[:] + u.reverse() + self.assertEqual(u, [2, 1, 0, -1, -2]) + u.reverse() + self.assertEqual(u, u2) + + self.assertRaises(TypeError, u.reverse, 42) + + def test_clear(self): + u = self.type2test([2, 3, 4]) + u.clear() + self.assertEqual(u, []) + + u = self.type2test([]) + u.clear() + self.assertEqual(u, []) + + u = self.type2test([]) + u.append(1) + u.clear() + u.append(2) + self.assertEqual(u, [2]) + + self.assertRaises(TypeError, u.clear, None) + + def test_copy(self): + u = self.type2test([1, 2, 3]) + v = u.copy() + self.assertEqual(v, [1, 2, 3]) + + u = self.type2test([]) + v = u.copy() + self.assertEqual(v, []) + + # test that it's indeed a copy and not a reference + u = self.type2test(['a', 'b']) + v = u.copy() + v.append('i') + self.assertEqual(u, ['a', 'b']) + self.assertEqual(v, u + ['i']) + + # test that it's a shallow, not a deep copy + u = self.type2test([1, 2, [3, 4], 5]) + v = u.copy() + self.assertEqual(u, v) + self.assertIs(v[3], u[3]) + + self.assertRaises(TypeError, u.copy, None) + + def test_sort(self): + u = self.type2test([1, 0]) + u.sort() + self.assertEqual(u, [0, 1]) + + u = self.type2test([2,1,0,-1,-2]) + u.sort() + self.assertEqual(u, self.type2test([-2,-1,0,1,2])) + + self.assertRaises(TypeError, u.sort, 42, 42) + + def revcmp(a, b): + if a == b: + return 0 + elif a < b: + return 1 + else: # a > b + return -1 + u.sort(key=cmp_to_key(revcmp)) + self.assertEqual(u, self.type2test([2,1,0,-1,-2])) + + # The following dumps core in unpatched Python 1.5: + def myComparison(x,y): + xmod, ymod = x%3, y%7 + if xmod == ymod: + return 0 + elif xmod < ymod: + return -1 + else: # xmod > ymod + return 1 + z = self.type2test(range(12)) + z.sort(key=cmp_to_key(myComparison)) + + self.assertRaises(TypeError, z.sort, 2) + + def selfmodifyingComparison(x,y): + z.append(1) + if x == y: + return 0 + elif x < y: + return -1 + else: # x > y + return 1 + self.assertRaises(ValueError, z.sort, + key=cmp_to_key(selfmodifyingComparison)) + + self.assertRaises(TypeError, z.sort, 42, 42, 42, 42) + + def test_slice(self): + u = self.type2test("spam") + u[:2] = "h" + self.assertEqual(u, list("ham")) + + def test_iadd(self): + super().test_iadd() + u = self.type2test([0, 1]) + u2 = u + u += [2, 3] + self.assertIs(u, u2) + + u = self.type2test("spam") + u += "eggs" + self.assertEqual(u, self.type2test("spameggs")) + + self.assertRaises(TypeError, u.__iadd__, None) + + def test_imul(self): + super().test_imul() + s = self.type2test([]) + oldid = id(s) + s *= 10 + self.assertEqual(id(s), oldid) + + def test_extendedslicing(self): + # subscript + a = self.type2test([0,1,2,3,4]) + + # deletion + del a[::2] + self.assertEqual(a, self.type2test([1,3])) + a = self.type2test(range(5)) + del a[1::2] + self.assertEqual(a, self.type2test([0,2,4])) + a = self.type2test(range(5)) + del a[1::-2] + self.assertEqual(a, self.type2test([0,2,3,4])) + a = self.type2test(range(10)) + del a[::1000] + self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 6, 7, 8, 9])) + # assignment + a = self.type2test(range(10)) + a[::2] = [-1]*5 + self.assertEqual(a, self.type2test([-1, 1, -1, 3, -1, 5, -1, 7, -1, 9])) + a = self.type2test(range(10)) + a[::-4] = [10]*3 + self.assertEqual(a, self.type2test([0, 10, 2, 3, 4, 10, 6, 7, 8 ,10])) + a = self.type2test(range(4)) + a[::-1] = a + self.assertEqual(a, self.type2test([3, 2, 1, 0])) + a = self.type2test(range(10)) + b = a[:] + c = a[:] + a[2:3] = self.type2test(["two", "elements"]) + b[slice(2,3)] = self.type2test(["two", "elements"]) + c[2:3:] = self.type2test(["two", "elements"]) + self.assertEqual(a, b) + self.assertEqual(a, c) + a = self.type2test(range(10)) + a[::2] = tuple(range(5)) + self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9])) + # test issue7788 + a = self.type2test(range(10)) + del a[9::1<<333] + + def test_constructor_exception_handling(self): + # Bug #1242657 + class F(object): + def __iter__(self): + raise KeyboardInterrupt + self.assertRaises(KeyboardInterrupt, list, F()) + + def test_exhausted_iterator(self): + a = self.type2test([1, 2, 3]) + exhit = iter(a) + empit = iter(a) + for x in exhit: # exhaust the iterator + next(empit) # not exhausted + a.append(9) + self.assertEqual(list(exhit), []) + self.assertEqual(list(empit), [9]) + self.assertEqual(a, self.type2test([1, 2, 3, 9])) + + # gh-115733: Crash when iterating over exhausted iterator + exhit = iter(self.type2test([1, 2, 3])) + for _ in exhit: + next(exhit, 1) diff --git a/test/dynamo/cpython/3_13/mapping_tests.diff b/test/dynamo/cpython/3_13/mapping_tests.diff new file mode 100644 index 00000000000000..03ae75513d6646 --- /dev/null +++ b/test/dynamo/cpython/3_13/mapping_tests.diff @@ -0,0 +1,67 @@ +diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py +index ed89a81a6ea..eed59a68e94 100644 +--- a/test/dynamo/cpython/3_13/mapping_tests.py ++++ b/test/dynamo/cpython/3_13/mapping_tests.py +@@ -1,10 +1,61 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # tests common to dict and UserDict + import unittest + import collections + from test.support import get_c_recursion_limit + + +-class BasicTestMappingProtocol(unittest.TestCase): ++class BasicTestMappingProtocol(__TestCase): + # This base class can be used to check that an object conforms to the + # mapping protocol + diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py new file mode 100644 index 00000000000000..eed59a68e9443a --- /dev/null +++ b/test/dynamo/cpython/3_13/mapping_tests.py @@ -0,0 +1,719 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# tests common to dict and UserDict +import unittest +import collections +from test.support import get_c_recursion_limit + + +class BasicTestMappingProtocol(__TestCase): + # This base class can be used to check that an object conforms to the + # mapping protocol + + # Functions that can be useful to override to adapt to dictionary + # semantics + type2test = None # which class is being tested (overwrite in subclasses) + + def _reference(self): + """Return a dictionary of values which are invariant by storage + in the object under test.""" + return {"1": "2", "key1":"value1", "key2":(1,2,3)} + def _empty_mapping(self): + """Return an empty mapping object""" + return self.type2test() + def _full_mapping(self, data): + """Return a mapping object with the value contained in data + dictionary""" + x = self._empty_mapping() + for key, value in data.items(): + x[key] = value + return x + + def __init__(self, *args, **kw): + unittest.TestCase.__init__(self, *args, **kw) + self.reference = self._reference().copy() + + # A (key, value) pair not in the mapping + key, value = self.reference.popitem() + self.other = {key:value} + + # A (key, value) pair in the mapping + key, value = self.reference.popitem() + self.inmapping = {key:value} + self.reference[key] = value + + def test_read(self): + # Test for read only operations on mapping + p = self._empty_mapping() + p1 = dict(p) #workaround for singleton objects + d = self._full_mapping(self.reference) + if d is p: + p = p1 + #Indexing + for key, value in self.reference.items(): + self.assertEqual(d[key], value) + knownkey = list(self.other.keys())[0] + self.assertRaises(KeyError, lambda:d[knownkey]) + #len + self.assertEqual(len(p), 0) + self.assertEqual(len(d), len(self.reference)) + #__contains__ + for k in self.reference: + self.assertIn(k, d) + for k in self.other: + self.assertNotIn(k, d) + #cmp + self.assertEqual(p, p) + self.assertEqual(d, d) + self.assertNotEqual(p, d) + self.assertNotEqual(d, p) + #bool + if p: self.fail("Empty mapping must compare to False") + if not d: self.fail("Full mapping must compare to True") + # keys(), items(), iterkeys() ... + def check_iterandlist(iter, lst, ref): + self.assertTrue(hasattr(iter, '__next__')) + self.assertTrue(hasattr(iter, '__iter__')) + x = list(iter) + self.assertTrue(set(x)==set(lst)==set(ref)) + check_iterandlist(iter(d.keys()), list(d.keys()), + self.reference.keys()) + check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) + check_iterandlist(iter(d.values()), list(d.values()), + self.reference.values()) + check_iterandlist(iter(d.items()), list(d.items()), + self.reference.items()) + #get + key, value = next(iter(d.items())) + knownkey, knownvalue = next(iter(self.other.items())) + self.assertEqual(d.get(key, knownvalue), value) + self.assertEqual(d.get(knownkey, knownvalue), knownvalue) + self.assertNotIn(knownkey, d) + + def test_write(self): + # Test for write operations on mapping + p = self._empty_mapping() + #Indexing + for key, value in self.reference.items(): + p[key] = value + self.assertEqual(p[key], value) + for key in self.reference.keys(): + del p[key] + self.assertRaises(KeyError, lambda:p[key]) + p = self._empty_mapping() + #update + p.update(self.reference) + self.assertEqual(dict(p), self.reference) + items = list(p.items()) + p = self._empty_mapping() + p.update(items) + self.assertEqual(dict(p), self.reference) + d = self._full_mapping(self.reference) + #setdefault + key, value = next(iter(d.items())) + knownkey, knownvalue = next(iter(self.other.items())) + self.assertEqual(d.setdefault(key, knownvalue), value) + self.assertEqual(d[key], value) + self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) + self.assertEqual(d[knownkey], knownvalue) + #pop + self.assertEqual(d.pop(knownkey), knownvalue) + self.assertNotIn(knownkey, d) + self.assertRaises(KeyError, d.pop, knownkey) + default = 909 + d[knownkey] = knownvalue + self.assertEqual(d.pop(knownkey, default), knownvalue) + self.assertNotIn(knownkey, d) + self.assertEqual(d.pop(knownkey, default), default) + #popitem + key, value = d.popitem() + self.assertNotIn(key, d) + self.assertEqual(value, self.reference[key]) + p=self._empty_mapping() + self.assertRaises(KeyError, p.popitem) + + def test_constructor(self): + self.assertEqual(self._empty_mapping(), self._empty_mapping()) + + def test_bool(self): + self.assertTrue(not self._empty_mapping()) + self.assertTrue(self.reference) + self.assertTrue(bool(self._empty_mapping()) is False) + self.assertTrue(bool(self.reference) is True) + + def test_keys(self): + d = self._empty_mapping() + self.assertEqual(list(d.keys()), []) + d = self.reference + self.assertIn(list(self.inmapping.keys())[0], d.keys()) + self.assertNotIn(list(self.other.keys())[0], d.keys()) + self.assertRaises(TypeError, d.keys, None) + + def test_values(self): + d = self._empty_mapping() + self.assertEqual(list(d.values()), []) + + self.assertRaises(TypeError, d.values, None) + + def test_items(self): + d = self._empty_mapping() + self.assertEqual(list(d.items()), []) + + self.assertRaises(TypeError, d.items, None) + + def test_len(self): + d = self._empty_mapping() + self.assertEqual(len(d), 0) + + def test_getitem(self): + d = self.reference + self.assertEqual(d[list(self.inmapping.keys())[0]], + list(self.inmapping.values())[0]) + + self.assertRaises(TypeError, d.__getitem__) + + def test_update(self): + # mapping argument + d = self._empty_mapping() + d.update(self.other) + self.assertEqual(list(d.items()), list(self.other.items())) + + # No argument + d = self._empty_mapping() + d.update() + self.assertEqual(d, self._empty_mapping()) + + # item sequence + d = self._empty_mapping() + d.update(self.other.items()) + self.assertEqual(list(d.items()), list(self.other.items())) + + # Iterator + d = self._empty_mapping() + d.update(self.other.items()) + self.assertEqual(list(d.items()), list(self.other.items())) + + # FIXME: Doesn't work with UserDict + # self.assertRaises((TypeError, AttributeError), d.update, None) + self.assertRaises((TypeError, AttributeError), d.update, 42) + + outerself = self + class SimpleUserDict: + def __init__(self): + self.d = outerself.reference + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] + d.clear() + d.update(SimpleUserDict()) + i1 = sorted(d.items()) + i2 = sorted(self.reference.items()) + self.assertEqual(i1, i2) + + class Exc(Exception): pass + + d = self._empty_mapping() + class FailingUserDict: + def keys(self): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + d.clear() + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = 1 + def __iter__(self): + return self + def __next__(self): + if self.i: + self.i = 0 + return 'a' + raise Exc + return BogonIter() + def __getitem__(self, key): + return key + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = ord('a') + def __iter__(self): + return self + def __next__(self): + if self.i <= ord('z'): + rtn = chr(self.i) + self.i += 1 + return rtn + raise StopIteration + return BogonIter() + def __getitem__(self, key): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + d = self._empty_mapping() + class badseq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, d.update, badseq()) + + self.assertRaises(ValueError, d.update, [(1, 2, 3)]) + + # no test_fromkeys or test_copy as both os.environ and selves don't support it + + def test_get(self): + d = self._empty_mapping() + self.assertTrue(d.get(list(self.other.keys())[0]) is None) + self.assertEqual(d.get(list(self.other.keys())[0], 3), 3) + d = self.reference + self.assertTrue(d.get(list(self.other.keys())[0]) is None) + self.assertEqual(d.get(list(self.other.keys())[0], 3), 3) + self.assertEqual(d.get(list(self.inmapping.keys())[0]), + list(self.inmapping.values())[0]) + self.assertEqual(d.get(list(self.inmapping.keys())[0], 3), + list(self.inmapping.values())[0]) + self.assertRaises(TypeError, d.get) + self.assertRaises(TypeError, d.get, None, None, None) + + def test_setdefault(self): + d = self._empty_mapping() + self.assertRaises(TypeError, d.setdefault) + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + self.assertRaises(TypeError, d.popitem, 42) + + def test_pop(self): + d = self._empty_mapping() + k, v = list(self.inmapping.items())[0] + d[k] = v + self.assertRaises(KeyError, d.pop, list(self.other.keys())[0]) + + self.assertEqual(d.pop(k), v) + self.assertEqual(len(d), 0) + + self.assertRaises(KeyError, d.pop, k) + + +class TestMappingProtocol(BasicTestMappingProtocol): + def test_constructor(self): + BasicTestMappingProtocol.test_constructor(self) + self.assertTrue(self._empty_mapping() is not self._empty_mapping()) + self.assertEqual(self.type2test(x=1, y=2), {"x": 1, "y": 2}) + + def test_bool(self): + BasicTestMappingProtocol.test_bool(self) + self.assertTrue(not self._empty_mapping()) + self.assertTrue(self._full_mapping({"x": "y"})) + self.assertTrue(bool(self._empty_mapping()) is False) + self.assertTrue(bool(self._full_mapping({"x": "y"})) is True) + + def test_keys(self): + BasicTestMappingProtocol.test_keys(self) + d = self._empty_mapping() + self.assertEqual(list(d.keys()), []) + d = self._full_mapping({'a': 1, 'b': 2}) + k = d.keys() + self.assertIn('a', k) + self.assertIn('b', k) + self.assertNotIn('c', k) + + def test_values(self): + BasicTestMappingProtocol.test_values(self) + d = self._full_mapping({1:2}) + self.assertEqual(list(d.values()), [2]) + + def test_items(self): + BasicTestMappingProtocol.test_items(self) + + d = self._full_mapping({1:2}) + self.assertEqual(list(d.items()), [(1, 2)]) + + def test_contains(self): + d = self._empty_mapping() + self.assertNotIn('a', d) + self.assertTrue(not ('a' in d)) + self.assertTrue('a' not in d) + d = self._full_mapping({'a': 1, 'b': 2}) + self.assertIn('a', d) + self.assertIn('b', d) + self.assertNotIn('c', d) + + self.assertRaises(TypeError, d.__contains__) + + def test_len(self): + BasicTestMappingProtocol.test_len(self) + d = self._full_mapping({'a': 1, 'b': 2}) + self.assertEqual(len(d), 2) + + def test_getitem(self): + BasicTestMappingProtocol.test_getitem(self) + d = self._full_mapping({'a': 1, 'b': 2}) + self.assertEqual(d['a'], 1) + self.assertEqual(d['b'], 2) + d['c'] = 3 + d['a'] = 4 + self.assertEqual(d['c'], 3) + self.assertEqual(d['a'], 4) + del d['b'] + self.assertEqual(d, {'a': 4, 'c': 3}) + + self.assertRaises(TypeError, d.__getitem__) + + def test_clear(self): + d = self._full_mapping({1:1, 2:2, 3:3}) + d.clear() + self.assertEqual(d, {}) + + self.assertRaises(TypeError, d.clear, None) + + def test_update(self): + BasicTestMappingProtocol.test_update(self) + # mapping argument + d = self._empty_mapping() + d.update({1:100}) + d.update({2:20}) + d.update({1:1, 2:2, 3:3}) + self.assertEqual(d, {1:1, 2:2, 3:3}) + + # no argument + d.update() + self.assertEqual(d, {1:1, 2:2, 3:3}) + + # keyword arguments + d = self._empty_mapping() + d.update(x=100) + d.update(y=20) + d.update(x=1, y=2, z=3) + self.assertEqual(d, {"x":1, "y":2, "z":3}) + + # item sequence + d = self._empty_mapping() + d.update([("x", 100), ("y", 20)]) + self.assertEqual(d, {"x":100, "y":20}) + + # Both item sequence and keyword arguments + d = self._empty_mapping() + d.update([("x", 100), ("y", 20)], x=1, y=2) + self.assertEqual(d, {"x":1, "y":2}) + + # iterator + d = self._full_mapping({1:3, 2:4}) + d.update(self._full_mapping({1:2, 3:4, 5:6}).items()) + self.assertEqual(d, {1:2, 2:4, 3:4, 5:6}) + + class SimpleUserDict: + def __init__(self): + self.d = {1:1, 2:2, 3:3} + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] + d.clear() + d.update(SimpleUserDict()) + self.assertEqual(d, {1:1, 2:2, 3:3}) + + def test_fromkeys(self): + self.assertEqual(self.type2test.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) + d = self._empty_mapping() + self.assertTrue(not(d.fromkeys('abc') is d)) + self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) + self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0}) + self.assertEqual(d.fromkeys([]), {}) + def g(): + yield 1 + self.assertEqual(d.fromkeys(g()), {1:None}) + self.assertRaises(TypeError, {}.fromkeys, 3) + class dictlike(self.type2test): pass + self.assertEqual(dictlike.fromkeys('a'), {'a':None}) + self.assertEqual(dictlike().fromkeys('a'), {'a':None}) + self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike) + self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike) + self.assertTrue(type(dictlike.fromkeys('a')) is dictlike) + class mydict(self.type2test): + def __new__(cls): + return collections.UserDict() + ud = mydict.fromkeys('ab') + self.assertEqual(ud, {'a':None, 'b':None}) + self.assertIsInstance(ud, collections.UserDict) + self.assertRaises(TypeError, dict.fromkeys) + + class Exc(Exception): pass + + class baddict1(self.type2test): + def __init__(self, *args, **kwargs): + raise Exc() + + self.assertRaises(Exc, baddict1.fromkeys, [1]) + + class BadSeq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, self.type2test.fromkeys, BadSeq()) + + class baddict2(self.type2test): + def __setitem__(self, key, value): + raise Exc() + + self.assertRaises(Exc, baddict2.fromkeys, [1]) + + def test_copy(self): + d = self._full_mapping({1:1, 2:2, 3:3}) + self.assertEqual(d.copy(), {1:1, 2:2, 3:3}) + d = self._empty_mapping() + self.assertEqual(d.copy(), d) + self.assertIsInstance(d.copy(), d.__class__) + self.assertRaises(TypeError, d.copy, None) + + def test_get(self): + BasicTestMappingProtocol.test_get(self) + d = self._empty_mapping() + self.assertTrue(d.get('c') is None) + self.assertEqual(d.get('c', 3), 3) + d = self._full_mapping({'a' : 1, 'b' : 2}) + self.assertTrue(d.get('c') is None) + self.assertEqual(d.get('c', 3), 3) + self.assertEqual(d.get('a'), 1) + self.assertEqual(d.get('a', 3), 1) + + def test_setdefault(self): + BasicTestMappingProtocol.test_setdefault(self) + d = self._empty_mapping() + self.assertTrue(d.setdefault('key0') is None) + d.setdefault('key0', []) + self.assertTrue(d.setdefault('key0') is None) + d.setdefault('key', []).append(3) + self.assertEqual(d['key'][0], 3) + d.setdefault('key', []).append(4) + self.assertEqual(len(d['key']), 2) + + def test_popitem(self): + BasicTestMappingProtocol.test_popitem(self) + for copymode in -1, +1: + # -1: b has same structure as a + # +1: b is a.copy() + for log2size in range(12): + size = 2**log2size + a = self._empty_mapping() + b = self._empty_mapping() + for i in range(size): + a[repr(i)] = i + if copymode < 0: + b[repr(i)] = i + if copymode > 0: + b = a.copy() + for i in range(size): + ka, va = ta = a.popitem() + self.assertEqual(va, int(ka)) + kb, vb = tb = b.popitem() + self.assertEqual(vb, int(kb)) + self.assertTrue(not(copymode < 0 and ta != tb)) + self.assertTrue(not a) + self.assertTrue(not b) + + def test_pop(self): + BasicTestMappingProtocol.test_pop(self) + + # Tests for pop with specified key + d = self._empty_mapping() + k, v = 'abc', 'def' + + self.assertEqual(d.pop(k, v), v) + d[k] = v + self.assertEqual(d.pop(k, 1), v) + + +class TestHashMappingProtocol(TestMappingProtocol): + + def test_getitem(self): + TestMappingProtocol.test_getitem(self) + class Exc(Exception): pass + + class BadEq(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 24 + + d = self._empty_mapping() + d[BadEq()] = 42 + self.assertRaises(KeyError, d.__getitem__, 23) + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + d = self._empty_mapping() + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.__getitem__, x) + + def test_fromkeys(self): + TestMappingProtocol.test_fromkeys(self) + class mydict(self.type2test): + def __new__(cls): + return collections.UserDict() + ud = mydict.fromkeys('ab') + self.assertEqual(ud, {'a':None, 'b':None}) + self.assertIsInstance(ud, collections.UserDict) + + def test_pop(self): + TestMappingProtocol.test_pop(self) + + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + d = self._empty_mapping() + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.pop, x) + + def test_mutatingiteration(self): + d = self._empty_mapping() + d[1] = 1 + try: + count = 0 + for i in d: + d[i+1] = 1 + if count >= 1: + self.fail("changing dict size during iteration doesn't raise Error") + count += 1 + except RuntimeError: + pass + + def test_repr(self): + d = self._empty_mapping() + self.assertEqual(repr(d), '{}') + d[1] = 2 + self.assertEqual(repr(d), '{1: 2}') + d = self._empty_mapping() + d[1] = d + self.assertEqual(repr(d), '{1: {...}}') + + class Exc(Exception): pass + + class BadRepr(object): + def __repr__(self): + raise Exc() + + d = self._full_mapping({1: BadRepr()}) + self.assertRaises(Exc, repr, d) + + def test_repr_deep(self): + d = self._empty_mapping() + for i in range(get_c_recursion_limit() + 1): + d0 = d + d = self._empty_mapping() + d[1] = d0 + self.assertRaises(RecursionError, repr, d) + + def test_eq(self): + self.assertEqual(self._empty_mapping(), self._empty_mapping()) + self.assertEqual(self._full_mapping({1: 2}), + self._full_mapping({1: 2})) + + class Exc(Exception): pass + + class BadCmp(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 1 + + d1 = self._full_mapping({BadCmp(): 1}) + d2 = self._full_mapping({1: 1}) + self.assertRaises(Exc, lambda: BadCmp()==1) + self.assertRaises(Exc, lambda: d1==d2) + + def test_setdefault(self): + TestMappingProtocol.test_setdefault(self) + + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + d = self._empty_mapping() + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.setdefault, x, []) diff --git a/test/dynamo/cpython/3_13/mathdata/cmath_testcases.txt b/test/dynamo/cpython/3_13/mathdata/cmath_testcases.txt new file mode 100644 index 00000000000000..0165e17634f41c --- /dev/null +++ b/test/dynamo/cpython/3_13/mathdata/cmath_testcases.txt @@ -0,0 +1,2514 @@ +-- Testcases for functions in cmath. +-- +-- Each line takes the form: +-- +-- -> +-- +-- where: +-- +-- is a short name identifying the test, +-- +-- is the function to be tested (exp, cos, asinh, ...), +-- +-- is a pair of floats separated by whitespace +-- representing real and imaginary parts of a complex number, and +-- +-- is the expected (ideal) output value, again +-- represented as a pair of floats. +-- +-- is a list of the floating-point flags required by C99 +-- +-- The possible flags are: +-- +-- divide-by-zero : raised when a finite input gives a +-- mathematically infinite result. +-- +-- overflow : raised when a finite input gives a finite result whose +-- real or imaginary part is too large to fit in the usual range +-- of an IEEE 754 double. +-- +-- invalid : raised for invalid inputs. +-- +-- ignore-real-sign : indicates that the sign of the real part of +-- the result is unspecified; if the real part of the result is +-- given as inf, then both -inf and inf should be accepted as +-- correct. +-- +-- ignore-imag-sign : indicates that the sign of the imaginary part +-- of the result is unspecified. +-- +-- Flags may appear in any order. +-- +-- Lines beginning with '--' (like this one) start a comment, and are +-- ignored. Blank lines, or lines containing only whitespace, are also +-- ignored. + +-- The majority of the values below were computed with the help of +-- version 2.3 of the MPFR library for multiple-precision +-- floating-point computations with correct rounding. All output +-- values in this file are (modulo yet-to-be-discovered bugs) +-- correctly rounded, provided that each input and output decimal +-- floating-point value below is interpreted as a representation of +-- the corresponding nearest IEEE 754 double-precision value. See the +-- MPFR homepage at http://www.mpfr.org for more information about the +-- MPFR project. + +-- A minority of the test cases were generated with the help of +-- mpmath 0.19 at 100 bit accuracy (http://mpmath.org) to improve +-- coverage of real functions with real-valued arguments. These are +-- used in test.test_math.MathTests.test_testfile, as well as in +-- test_cmath. + + +-------------------------- +-- acos: Inverse cosine -- +-------------------------- + +-- zeros +acos0000 acos 0.0 0.0 -> 1.5707963267948966 -0.0 +acos0001 acos 0.0 -0.0 -> 1.5707963267948966 0.0 +acos0002 acos -0.0 0.0 -> 1.5707963267948966 -0.0 +acos0003 acos -0.0 -0.0 -> 1.5707963267948966 0.0 + +-- branch points: +/-1 +acos0010 acos 1.0 0.0 -> 0.0 -0.0 +acos0011 acos 1.0 -0.0 -> 0.0 0.0 +acos0012 acos -1.0 0.0 -> 3.1415926535897931 -0.0 +acos0013 acos -1.0 -0.0 -> 3.1415926535897931 0.0 + +-- values along both sides of real axis +acos0020 acos -9.8813129168249309e-324 0.0 -> 1.5707963267948966 -0.0 +acos0021 acos -9.8813129168249309e-324 -0.0 -> 1.5707963267948966 0.0 +acos0022 acos -1e-305 0.0 -> 1.5707963267948966 -0.0 +acos0023 acos -1e-305 -0.0 -> 1.5707963267948966 0.0 +acos0024 acos -1e-150 0.0 -> 1.5707963267948966 -0.0 +acos0025 acos -1e-150 -0.0 -> 1.5707963267948966 0.0 +acos0026 acos -9.9999999999999998e-17 0.0 -> 1.5707963267948968 -0.0 +acos0027 acos -9.9999999999999998e-17 -0.0 -> 1.5707963267948968 0.0 +acos0028 acos -0.001 0.0 -> 1.5717963269615634 -0.0 +acos0029 acos -0.001 -0.0 -> 1.5717963269615634 0.0 +acos0030 acos -0.57899999999999996 0.0 -> 2.1882979816120667 -0.0 +acos0031 acos -0.57899999999999996 -0.0 -> 2.1882979816120667 0.0 +acos0032 acos -0.99999999999999989 0.0 -> 3.1415926386886319 -0.0 +acos0033 acos -0.99999999999999989 -0.0 -> 3.1415926386886319 0.0 +acos0034 acos -1.0000000000000002 0.0 -> 3.1415926535897931 -2.1073424255447014e-08 +acos0035 acos -1.0000000000000002 -0.0 -> 3.1415926535897931 2.1073424255447014e-08 +acos0036 acos -1.0009999999999999 0.0 -> 3.1415926535897931 -0.044717633608306849 +acos0037 acos -1.0009999999999999 -0.0 -> 3.1415926535897931 0.044717633608306849 +acos0038 acos -2.0 0.0 -> 3.1415926535897931 -1.3169578969248168 +acos0039 acos -2.0 -0.0 -> 3.1415926535897931 1.3169578969248168 +acos0040 acos -23.0 0.0 -> 3.1415926535897931 -3.8281684713331012 +acos0041 acos -23.0 -0.0 -> 3.1415926535897931 3.8281684713331012 +acos0042 acos -10000000000000000.0 0.0 -> 3.1415926535897931 -37.534508668464674 +acos0043 acos -10000000000000000.0 -0.0 -> 3.1415926535897931 37.534508668464674 +acos0044 acos -9.9999999999999998e+149 0.0 -> 3.1415926535897931 -346.08091112966679 +acos0045 acos -9.9999999999999998e+149 -0.0 -> 3.1415926535897931 346.08091112966679 +acos0046 acos -1.0000000000000001e+299 0.0 -> 3.1415926535897931 -689.16608998577965 +acos0047 acos -1.0000000000000001e+299 -0.0 -> 3.1415926535897931 689.16608998577965 +acos0048 acos 9.8813129168249309e-324 0.0 -> 1.5707963267948966 -0.0 +acos0049 acos 9.8813129168249309e-324 -0.0 -> 1.5707963267948966 0.0 +acos0050 acos 1e-305 0.0 -> 1.5707963267948966 -0.0 +acos0051 acos 1e-305 -0.0 -> 1.5707963267948966 0.0 +acos0052 acos 1e-150 0.0 -> 1.5707963267948966 -0.0 +acos0053 acos 1e-150 -0.0 -> 1.5707963267948966 0.0 +acos0054 acos 9.9999999999999998e-17 0.0 -> 1.5707963267948966 -0.0 +acos0055 acos 9.9999999999999998e-17 -0.0 -> 1.5707963267948966 0.0 +acos0056 acos 0.001 0.0 -> 1.56979632662823 -0.0 +acos0057 acos 0.001 -0.0 -> 1.56979632662823 0.0 +acos0058 acos 0.57899999999999996 0.0 -> 0.95329467197772655 -0.0 +acos0059 acos 0.57899999999999996 -0.0 -> 0.95329467197772655 0.0 +acos0060 acos 0.99999999999999989 0.0 -> 1.4901161193847656e-08 -0.0 +acos0061 acos 0.99999999999999989 -0.0 -> 1.4901161193847656e-08 0.0 +acos0062 acos 1.0000000000000002 0.0 -> 0.0 -2.1073424255447014e-08 +acos0063 acos 1.0000000000000002 -0.0 -> 0.0 2.1073424255447014e-08 +acos0064 acos 1.0009999999999999 0.0 -> 0.0 -0.044717633608306849 +acos0065 acos 1.0009999999999999 -0.0 -> 0.0 0.044717633608306849 +acos0066 acos 2.0 0.0 -> 0.0 -1.3169578969248168 +acos0067 acos 2.0 -0.0 -> 0.0 1.3169578969248168 +acos0068 acos 23.0 0.0 -> 0.0 -3.8281684713331012 +acos0069 acos 23.0 -0.0 -> 0.0 3.8281684713331012 +acos0070 acos 10000000000000000.0 0.0 -> 0.0 -37.534508668464674 +acos0071 acos 10000000000000000.0 -0.0 -> 0.0 37.534508668464674 +acos0072 acos 9.9999999999999998e+149 0.0 -> 0.0 -346.08091112966679 +acos0073 acos 9.9999999999999998e+149 -0.0 -> 0.0 346.08091112966679 +acos0074 acos 1.0000000000000001e+299 0.0 -> 0.0 -689.16608998577965 +acos0075 acos 1.0000000000000001e+299 -0.0 -> 0.0 689.16608998577965 + +-- random inputs +acos0100 acos -3.3307113324596682 -10.732007530863266 -> 1.8706085694482339 3.113986806554613 +acos0101 acos -2863.952991743291 -2681013315.2571239 -> 1.5707973950301699 22.402607843274758 +acos0102 acos -0.33072639793220088 -0.85055464658253055 -> 1.8219426895922601 0.79250166729311966 +acos0103 acos -2.5722325842097802 -12.703940809821574 -> 1.7699942413107408 3.2565170156527325 +acos0104 acos -42.495233785459583 -0.54039320751337161 -> 3.1288732573153304 4.4424815519735601 +acos0105 acos -1.1363818625856401 9641.1325498630376 -> 1.5709141948820049 -9.8669410553254284 +acos0106 acos -2.4398426824157866e-11 0.33002051890266165 -> 1.570796326818066 -0.32430578041578667 +acos0107 acos -1.3521340428186552 2.9369737912076772 -> 1.9849059192339338 -1.8822893674117942 +acos0108 acos -1.827364706477915 1.0355459232147557 -> 2.5732246307960032 -1.4090688267854969 +acos0109 acos -0.25978373706403546 10.09712669185833 -> 1.5963940386378306 -3.0081673050196063 +acos0110 acos 0.33561778471072551 -4587350.6823999118 -> 1.5707962536333251 16.031960402579539 +acos0111 acos 0.49133444610998445 -0.8071422362990015 -> 1.1908761712801788 0.78573345813187867 +acos0112 acos 0.42196734507823974 -2.4812965431745115 -> 1.414091186100692 1.651707260988172 +acos0113 acos 2.961426210100655 -219.03295695248664 -> 1.5572768319822778 6.0824659885827304 +acos0114 acos 2.886209063652641 -20.38011207220606 -> 1.4302765252297889 3.718201853147642 +acos0115 acos 0.4180568075276509 1.4833433990823484 -> 1.3393834558303042 -1.2079847758301576 +acos0116 acos 52.376111405924718 0.013930429001941001 -> 0.00026601761804024188 -4.6515066691204714 +acos0117 acos 41637948387.625969 1.563418292894041 -> 3.7547918507883548e-11 -25.145424989809381 +acos0118 acos 0.061226659122249526 0.8447234394615154 -> 1.5240280306367315 -0.76791798971140812 +acos0119 acos 2.4480466420442959e+26 0.18002339201384662 -> 7.353756620564798e-28 -61.455650015996376 + +-- values near infinity +acos0200 acos 1.6206860518683021e+308 1.0308426226285283e+308 -> 0.56650826093826223 -710.54206874241561 +acos0201 acos 1.2067735875070062e+308 -1.3429173724390276e+308 -> 0.83874369390864889 710.48017794027498 +acos0202 acos -7.4130145132549047e+307 1.1759130543927645e+308 -> 2.1332729346478536 -710.21871115698752 +acos0203 acos -8.6329426442257249e+307 -1.2316282952184133e+308 -> 2.1821511032444838 710.29752145697148 +acos0204 acos 0.0 1.4289713855849746e+308 -> 1.5707963267948966 -710.24631069738996 +acos0205 acos -0.0 1.3153524545987432e+308 -> 1.5707963267948966 -710.1634604787539 +acos0206 acos 0.0 -9.6229037669269321e+307 -> 1.5707963267948966 709.85091679573691 +acos0207 acos -0.0 -4.9783616421107088e+307 -> 1.5707963267948966 709.19187157911233 +acos0208 acos 1.3937541925739389e+308 0.0 -> 0.0 -710.22135678707264 +acos0209 acos 9.1362388967371536e+307 -0.0 -> 0.0 709.79901953124613 +acos0210 acos -1.3457361220697436e+308 0.0 -> 3.1415926535897931 -710.18629698871848 +acos0211 acos -5.4699090056144284e+307 -0.0 -> 3.1415926535897931 709.28603271085649 +acos0212 acos 1.5880716932358901e+308 5.5638401252339929 -> 3.503519487773873e-308 -710.35187633140583 +acos0213 acos 1.2497211663463164e+308 -3.0456477717911024 -> 2.4370618453197486e-308 710.11227628223412 +acos0214 acos -9.9016224006029528e+307 4.9570427340789056 -> 3.1415926535897931 -709.87946935229468 +acos0215 acos -1.5854071066874139e+308 -4.4233577741497783 -> 3.1415926535897931 710.35019704672004 +acos0216 acos 9.3674623083647628 1.5209559051877979e+308 -> 1.5707963267948966 -710.30869484491086 +acos0217 acos 8.1773832021784383 -6.6093445795000056e+307 -> 1.5707963267948966 709.4752552227792 +acos0218 acos -3.1845935000665104 1.5768856396650893e+308 -> 1.5707963267948966 -710.34480761042687 +acos0219 acos -1.0577303880953903 -6.4574626815735613e+307 -> 1.5707963267948966 709.45200719662046 + +-- values near 0 +acos0220 acos 1.8566986970714045e-320 3.1867234156760402e-321 -> 1.5707963267948966 -3.1867234156760402e-321 +acos0221 acos 7.9050503334599447e-323 -8.8931816251424378e-323 -> 1.5707963267948966 8.8931816251424378e-323 +acos0222 acos -4.4465908125712189e-323 2.4654065097222727e-311 -> 1.5707963267948966 -2.4654065097222727e-311 +acos0223 acos -6.1016916408192619e-311 -2.4703282292062327e-323 -> 1.5707963267948966 2.4703282292062327e-323 +acos0224 acos 0.0 3.4305783621842729e-311 -> 1.5707963267948966 -3.4305783621842729e-311 +acos0225 acos -0.0 1.6117409498633145e-319 -> 1.5707963267948966 -1.6117409498633145e-319 +acos0226 acos 0.0 -4.9900630229965901e-322 -> 1.5707963267948966 4.9900630229965901e-322 +acos0227 acos -0.0 -4.4889279210592818e-311 -> 1.5707963267948966 4.4889279210592818e-311 +acos0228 acos 5.3297678681477214e-312 0.0 -> 1.5707963267948966 -0.0 +acos0229 acos 6.2073425897211614e-313 -0.0 -> 1.5707963267948966 0.0 +acos0230 acos -4.9406564584124654e-324 0.0 -> 1.5707963267948966 -0.0 +acos0231 acos -1.7107517052899003e-318 -0.0 -> 1.5707963267948966 0.0 + +-- special values +acos1000 acos 0.0 0.0 -> 1.5707963267948966 -0.0 +acos1001 acos 0.0 -0.0 -> 1.5707963267948966 0.0 +acos1002 acos -0.0 0.0 -> 1.5707963267948966 -0.0 +acos1003 acos -0.0 -0.0 -> 1.5707963267948966 0.0 +acos1004 acos 0.0 nan -> 1.5707963267948966 nan +acos1005 acos -0.0 nan -> 1.5707963267948966 nan +acos1006 acos -2.3 inf -> 1.5707963267948966 -inf +acos1007 acos -0.0 inf -> 1.5707963267948966 -inf +acos1008 acos 0.0 inf -> 1.5707963267948966 -inf +acos1009 acos 2.3 inf -> 1.5707963267948966 -inf +acos1010 acos -2.3 nan -> nan nan +acos1011 acos 2.3 nan -> nan nan +acos1012 acos -inf 2.3 -> 3.1415926535897931 -inf +acos1013 acos -inf 0.0 -> 3.1415926535897931 -inf +acos1014 acos inf 2.3 -> 0.0 -inf +acos1015 acos inf 0.0 -> 0.0 -inf +acos1016 acos -inf inf -> 2.3561944901923448 -inf +acos1017 acos inf inf -> 0.78539816339744828 -inf +acos1018 acos inf nan -> nan inf ignore-imag-sign +acos1019 acos -inf nan -> nan inf ignore-imag-sign +acos1020 acos nan 0.0 -> nan nan +acos1021 acos nan 2.3 -> nan nan +acos1022 acos nan inf -> nan -inf +acos1023 acos nan nan -> nan nan +acos1024 acos -2.3 -inf -> 1.5707963267948966 inf +acos1025 acos -0.0 -inf -> 1.5707963267948966 inf +acos1026 acos 0.0 -inf -> 1.5707963267948966 inf +acos1027 acos 2.3 -inf -> 1.5707963267948966 inf +acos1028 acos -inf -2.3 -> 3.1415926535897931 inf +acos1029 acos -inf -0.0 -> 3.1415926535897931 inf +acos1030 acos inf -2.3 -> 0.0 inf +acos1031 acos inf -0.0 -> 0.0 inf +acos1032 acos -inf -inf -> 2.3561944901923448 inf +acos1033 acos inf -inf -> 0.78539816339744828 inf +acos1034 acos nan -0.0 -> nan nan +acos1035 acos nan -2.3 -> nan nan +acos1036 acos nan -inf -> nan inf + + +-------------------------------------- +-- acosh: Inverse hyperbolic cosine -- +-------------------------------------- + +-- zeros +acosh0000 acosh 0.0 0.0 -> 0.0 1.5707963267948966 +acosh0001 acosh 0.0 -0.0 -> 0.0 -1.5707963267948966 +acosh0002 acosh -0.0 0.0 -> 0.0 1.5707963267948966 +acosh0003 acosh -0.0 -0.0 -> 0.0 -1.5707963267948966 + +-- branch points: +/-1 +acosh0010 acosh 1.0 0.0 -> 0.0 0.0 +acosh0011 acosh 1.0 -0.0 -> 0.0 -0.0 +acosh0012 acosh -1.0 0.0 -> 0.0 3.1415926535897931 +acosh0013 acosh -1.0 -0.0 -> 0.0 -3.1415926535897931 + +-- values along both sides of real axis +acosh0020 acosh -9.8813129168249309e-324 0.0 -> 0.0 1.5707963267948966 +acosh0021 acosh -9.8813129168249309e-324 -0.0 -> 0.0 -1.5707963267948966 +acosh0022 acosh -1e-305 0.0 -> 0.0 1.5707963267948966 +acosh0023 acosh -1e-305 -0.0 -> 0.0 -1.5707963267948966 +acosh0024 acosh -1e-150 0.0 -> 0.0 1.5707963267948966 +acosh0025 acosh -1e-150 -0.0 -> 0.0 -1.5707963267948966 +acosh0026 acosh -9.9999999999999998e-17 0.0 -> 0.0 1.5707963267948968 +acosh0027 acosh -9.9999999999999998e-17 -0.0 -> 0.0 -1.5707963267948968 +acosh0028 acosh -0.001 0.0 -> 0.0 1.5717963269615634 +acosh0029 acosh -0.001 -0.0 -> 0.0 -1.5717963269615634 +acosh0030 acosh -0.57899999999999996 0.0 -> 0.0 2.1882979816120667 +acosh0031 acosh -0.57899999999999996 -0.0 -> 0.0 -2.1882979816120667 +acosh0032 acosh -0.99999999999999989 0.0 -> 0.0 3.1415926386886319 +acosh0033 acosh -0.99999999999999989 -0.0 -> 0.0 -3.1415926386886319 +acosh0034 acosh -1.0000000000000002 0.0 -> 2.1073424255447014e-08 3.1415926535897931 +acosh0035 acosh -1.0000000000000002 -0.0 -> 2.1073424255447014e-08 -3.1415926535897931 +acosh0036 acosh -1.0009999999999999 0.0 -> 0.044717633608306849 3.1415926535897931 +acosh0037 acosh -1.0009999999999999 -0.0 -> 0.044717633608306849 -3.1415926535897931 +acosh0038 acosh -2.0 0.0 -> 1.3169578969248168 3.1415926535897931 +acosh0039 acosh -2.0 -0.0 -> 1.3169578969248168 -3.1415926535897931 +acosh0040 acosh -23.0 0.0 -> 3.8281684713331012 3.1415926535897931 +acosh0041 acosh -23.0 -0.0 -> 3.8281684713331012 -3.1415926535897931 +acosh0042 acosh -10000000000000000.0 0.0 -> 37.534508668464674 3.1415926535897931 +acosh0043 acosh -10000000000000000.0 -0.0 -> 37.534508668464674 -3.1415926535897931 +acosh0044 acosh -9.9999999999999998e+149 0.0 -> 346.08091112966679 3.1415926535897931 +acosh0045 acosh -9.9999999999999998e+149 -0.0 -> 346.08091112966679 -3.1415926535897931 +acosh0046 acosh -1.0000000000000001e+299 0.0 -> 689.16608998577965 3.1415926535897931 +acosh0047 acosh -1.0000000000000001e+299 -0.0 -> 689.16608998577965 -3.1415926535897931 +acosh0048 acosh 9.8813129168249309e-324 0.0 -> 0.0 1.5707963267948966 +acosh0049 acosh 9.8813129168249309e-324 -0.0 -> 0.0 -1.5707963267948966 +acosh0050 acosh 1e-305 0.0 -> 0.0 1.5707963267948966 +acosh0051 acosh 1e-305 -0.0 -> 0.0 -1.5707963267948966 +acosh0052 acosh 1e-150 0.0 -> 0.0 1.5707963267948966 +acosh0053 acosh 1e-150 -0.0 -> 0.0 -1.5707963267948966 +acosh0054 acosh 9.9999999999999998e-17 0.0 -> 0.0 1.5707963267948966 +acosh0055 acosh 9.9999999999999998e-17 -0.0 -> 0.0 -1.5707963267948966 +acosh0056 acosh 0.001 0.0 -> 0.0 1.56979632662823 +acosh0057 acosh 0.001 -0.0 -> 0.0 -1.56979632662823 +acosh0058 acosh 0.57899999999999996 0.0 -> 0.0 0.95329467197772655 +acosh0059 acosh 0.57899999999999996 -0.0 -> 0.0 -0.95329467197772655 +acosh0060 acosh 0.99999999999999989 0.0 -> 0.0 1.4901161193847656e-08 +acosh0061 acosh 0.99999999999999989 -0.0 -> 0.0 -1.4901161193847656e-08 +acosh0062 acosh 1.0000000000000002 0.0 -> 2.1073424255447014e-08 0.0 +acosh0063 acosh 1.0000000000000002 -0.0 -> 2.1073424255447014e-08 -0.0 +acosh0064 acosh 1.0009999999999999 0.0 -> 0.044717633608306849 0.0 +acosh0065 acosh 1.0009999999999999 -0.0 -> 0.044717633608306849 -0.0 +acosh0066 acosh 2.0 0.0 -> 1.3169578969248168 0.0 +acosh0067 acosh 2.0 -0.0 -> 1.3169578969248168 -0.0 +acosh0068 acosh 23.0 0.0 -> 3.8281684713331012 0.0 +acosh0069 acosh 23.0 -0.0 -> 3.8281684713331012 -0.0 +acosh0070 acosh 10000000000000000.0 0.0 -> 37.534508668464674 0.0 +acosh0071 acosh 10000000000000000.0 -0.0 -> 37.534508668464674 -0.0 +acosh0072 acosh 9.9999999999999998e+149 0.0 -> 346.08091112966679 0.0 +acosh0073 acosh 9.9999999999999998e+149 -0.0 -> 346.08091112966679 -0.0 +acosh0074 acosh 1.0000000000000001e+299 0.0 -> 689.16608998577965 0.0 +acosh0075 acosh 1.0000000000000001e+299 -0.0 -> 689.16608998577965 -0.0 + +-- random inputs +acosh0100 acosh -1.4328589581250843 -1.8370347775558309 -> 1.5526962646549587 -2.190250168435786 +acosh0101 acosh -0.31075819156220957 -1.0772555786839297 -> 0.95139168286193709 -1.7812228089636479 +acosh0102 acosh -1.9044776578070453 -20.485370158932124 -> 3.7177411088932359 -1.6633888745861227 +acosh0103 acosh -0.075642506000858742 -21965976320.873051 -> 24.505907742881991 -1.5707963267983402 +acosh0104 acosh -1.6162271181056307 -3.0369343458696099 -> 1.9407057262861227 -2.0429549461750209 +acosh0105 acosh -0.3103780280298063 0.00018054880018078987 -> 0.00018992877058761416 1.886386995096728 +acosh0106 acosh -9159468751.5897655 5.8014747664273649 -> 23.631201197959193 3.1415926529564078 +acosh0107 acosh -0.037739157550933884 0.21841357493510705 -> 0.21685844960602488 1.6076735133449402 +acosh0108 acosh -8225991.0508394297 0.28318543008913644 -> 16.615956520420287 3.1415926191641019 +acosh0109 acosh -35.620070502302639 0.31303237005015 -> 4.2658980006943965 3.1328013255541873 +acosh0110 acosh 96.729939906820917 -0.029345228372365334 -> 5.2650434775863548 -0.00030338895866972843 +acosh0111 acosh 0.59656024007966491 -2.0412294654163978 -> 1.4923002024287835 -1.312568421900338 +acosh0112 acosh 109.29384112677828 -0.00015454863061533812 -> 5.3871662961545477 -1.4141245154061214e-06 +acosh0113 acosh 8.6705651969361597 -3.6723631649787465 -> 2.9336180958363545 -0.40267362031872861 +acosh0114 acosh 1.8101646445052686 -0.012345132721855478 -> 1.1997148566285769 -0.0081813912760150265 +acosh0115 acosh 52.56897195025288 0.001113916065985443 -> 4.6551827622264135 2.1193445872040307e-05 +acosh0116 acosh 0.28336786164214739 355643992457.40485 -> 27.290343226816528 1.5707963267940999 +acosh0117 acosh 0.73876621291911437 2.8828594541104322e-20 -> 4.2774820978159067e-20 0.73955845836827927 +acosh0118 acosh 0.025865471781718878 37125746064318.492 -> 31.938478989418012 1.5707963267948959 +acosh0119 acosh 2.2047353511780132 0.074712248143489271 -> 1.4286403248698021 0.037997904971626598 + +-- values near infinity +acosh0200 acosh 8.1548592876467785e+307 9.0943779335951128e+307 -> 710.08944620800605 0.83981165425478954 +acosh0201 acosh 1.4237229680972531e+308 -1.0336966617874858e+308 -> 710.4543331094759 -0.6279972876348755 +acosh0202 acosh -1.5014526899738939e+308 1.5670700378448792e+308 -> 710.66420706795464 2.3348137299106697 +acosh0203 acosh -1.0939040375213928e+308 -1.0416960351127978e+308 -> 710.30182863115886 -2.380636147787027 +acosh0204 acosh 0.0 1.476062433559588e+308 -> 710.27873384716929 1.5707963267948966 +acosh0205 acosh -0.0 6.2077210326221094e+307 -> 709.41256457484769 1.5707963267948966 +acosh0206 acosh 0.0 -1.5621899909968308e+308 -> 710.33544449990734 -1.5707963267948966 +acosh0207 acosh -0.0 -8.3556624833839122e+307 -> 709.70971018048317 -1.5707963267948966 +acosh0208 acosh 1.3067079752499342e+308 0.0 -> 710.15686680107228 0.0 +acosh0209 acosh 1.5653640340214026e+308 -0.0 -> 710.33747422926706 -0.0 +acosh0210 acosh -6.9011375992290636e+307 0.0 -> 709.51845699719922 3.1415926535897931 +acosh0211 acosh -9.9539576809926973e+307 -0.0 -> 709.88474095870185 -3.1415926535897931 +acosh0212 acosh 7.6449598518914925e+307 9.5706540768268358 -> 709.62081731754802 1.2518906916769345e-307 +acosh0213 acosh 5.4325410972602197e+307 -7.8064807816522706 -> 709.279177727925 -1.4369851312471974e-307 +acosh0214 acosh -1.1523626112360465e+308 7.0617510038869336 -> 710.03117010216909 3.1415926535897931 +acosh0215 acosh -1.1685027786862599e+308 -5.1568558357925625 -> 710.04507907571417 -3.1415926535897931 +acosh0216 acosh 3.0236370339788721 1.7503248720096417e+308 -> 710.44915723458064 1.5707963267948966 +acosh0217 acosh 6.6108007926031149 -9.1469968225806149e+307 -> 709.80019633903328 -1.5707963267948966 +acosh0218 acosh -5.1096262905623959 6.4484926785412395e+307 -> 709.45061713997973 1.5707963267948966 +acosh0219 acosh -2.8080920608735846 -1.7716118836519368e+308 -> 710.46124562363445 -1.5707963267948966 + +-- values near 0 +acosh0220 acosh 4.5560530326699304e-317 7.3048989121436657e-318 -> 7.3048989121436657e-318 1.5707963267948966 +acosh0221 acosh 4.8754274133585331e-314 -9.8469794897684199e-315 -> 9.8469794897684199e-315 -1.5707963267948966 +acosh0222 acosh -4.6748876009960097e-312 9.7900342887557606e-318 -> 9.7900342887557606e-318 1.5707963267948966 +acosh0223 acosh -4.3136871538399236e-320 -4.9406564584124654e-323 -> 4.9406564584124654e-323 -1.5707963267948966 +acosh0224 acosh 0.0 4.3431013866496774e-314 -> 4.3431013866496774e-314 1.5707963267948966 +acosh0225 acosh -0.0 6.0147334335829184e-317 -> 6.0147334335829184e-317 1.5707963267948966 +acosh0226 acosh 0.0 -1.2880291387081297e-320 -> 1.2880291387081297e-320 -1.5707963267948966 +acosh0227 acosh -0.0 -1.4401563976534621e-317 -> 1.4401563976534621e-317 -1.5707963267948966 +acosh0228 acosh 1.3689680570863091e-313 0.0 -> 0.0 1.5707963267948966 +acosh0229 acosh 1.5304346893494371e-312 -0.0 -> 0.0 -1.5707963267948966 +acosh0230 acosh -3.7450175954766488e-320 0.0 -> 0.0 1.5707963267948966 +acosh0231 acosh -8.4250563080885801e-311 -0.0 -> 0.0 -1.5707963267948966 + +-- special values +acosh1000 acosh 0.0 0.0 -> 0.0 1.5707963267948966 +acosh1001 acosh -0.0 0.0 -> 0.0 1.5707963267948966 +acosh1002 acosh 0.0 inf -> inf 1.5707963267948966 +acosh1003 acosh 2.3 inf -> inf 1.5707963267948966 +acosh1004 acosh -0.0 inf -> inf 1.5707963267948966 +acosh1005 acosh -2.3 inf -> inf 1.5707963267948966 +acosh1006 acosh 0.0 nan -> nan nan +acosh1007 acosh 2.3 nan -> nan nan +acosh1008 acosh -0.0 nan -> nan nan +acosh1009 acosh -2.3 nan -> nan nan +acosh1010 acosh -inf 0.0 -> inf 3.1415926535897931 +acosh1011 acosh -inf 2.3 -> inf 3.1415926535897931 +acosh1012 acosh inf 0.0 -> inf 0.0 +acosh1013 acosh inf 2.3 -> inf 0.0 +acosh1014 acosh -inf inf -> inf 2.3561944901923448 +acosh1015 acosh inf inf -> inf 0.78539816339744828 +acosh1016 acosh inf nan -> inf nan +acosh1017 acosh -inf nan -> inf nan +acosh1018 acosh nan 0.0 -> nan nan +acosh1019 acosh nan 2.3 -> nan nan +acosh1020 acosh nan inf -> inf nan +acosh1021 acosh nan nan -> nan nan +acosh1022 acosh 0.0 -0.0 -> 0.0 -1.5707963267948966 +acosh1023 acosh -0.0 -0.0 -> 0.0 -1.5707963267948966 +acosh1024 acosh 0.0 -inf -> inf -1.5707963267948966 +acosh1025 acosh 2.3 -inf -> inf -1.5707963267948966 +acosh1026 acosh -0.0 -inf -> inf -1.5707963267948966 +acosh1027 acosh -2.3 -inf -> inf -1.5707963267948966 +acosh1028 acosh -inf -0.0 -> inf -3.1415926535897931 +acosh1029 acosh -inf -2.3 -> inf -3.1415926535897931 +acosh1030 acosh inf -0.0 -> inf -0.0 +acosh1031 acosh inf -2.3 -> inf -0.0 +acosh1032 acosh -inf -inf -> inf -2.3561944901923448 +acosh1033 acosh inf -inf -> inf -0.78539816339744828 +acosh1034 acosh nan -0.0 -> nan nan +acosh1035 acosh nan -2.3 -> nan nan +acosh1036 acosh nan -inf -> inf nan + + +------------------------ +-- asin: Inverse sine -- +------------------------ + +-- zeros +asin0000 asin 0.0 0.0 -> 0.0 0.0 +asin0001 asin 0.0 -0.0 -> 0.0 -0.0 +asin0002 asin -0.0 0.0 -> -0.0 0.0 +asin0003 asin -0.0 -0.0 -> -0.0 -0.0 + +-- branch points: +/-1 +asin0010 asin 1.0 0.0 -> 1.5707963267948966 0.0 +asin0011 asin 1.0 -0.0 -> 1.5707963267948966 -0.0 +asin0012 asin -1.0 0.0 -> -1.5707963267948966 0.0 +asin0013 asin -1.0 -0.0 -> -1.5707963267948966 -0.0 + +-- values along both sides of real axis +asin0020 asin -9.8813129168249309e-324 0.0 -> -9.8813129168249309e-324 0.0 +asin0021 asin -9.8813129168249309e-324 -0.0 -> -9.8813129168249309e-324 -0.0 +asin0022 asin -1e-305 0.0 -> -1e-305 0.0 +asin0023 asin -1e-305 -0.0 -> -1e-305 -0.0 +asin0024 asin -1e-150 0.0 -> -1e-150 0.0 +asin0025 asin -1e-150 -0.0 -> -1e-150 -0.0 +asin0026 asin -9.9999999999999998e-17 0.0 -> -9.9999999999999998e-17 0.0 +asin0027 asin -9.9999999999999998e-17 -0.0 -> -9.9999999999999998e-17 -0.0 +asin0028 asin -0.001 0.0 -> -0.0010000001666667416 0.0 +asin0029 asin -0.001 -0.0 -> -0.0010000001666667416 -0.0 +asin0030 asin -0.57899999999999996 0.0 -> -0.61750165481717001 0.0 +asin0031 asin -0.57899999999999996 -0.0 -> -0.61750165481717001 -0.0 +asin0032 asin -0.99999999999999989 0.0 -> -1.5707963118937354 0.0 +asin0033 asin -0.99999999999999989 -0.0 -> -1.5707963118937354 -0.0 +asin0034 asin -1.0000000000000002 0.0 -> -1.5707963267948966 2.1073424255447014e-08 +asin0035 asin -1.0000000000000002 -0.0 -> -1.5707963267948966 -2.1073424255447014e-08 +asin0036 asin -1.0009999999999999 0.0 -> -1.5707963267948966 0.044717633608306849 +asin0037 asin -1.0009999999999999 -0.0 -> -1.5707963267948966 -0.044717633608306849 +asin0038 asin -2.0 0.0 -> -1.5707963267948966 1.3169578969248168 +asin0039 asin -2.0 -0.0 -> -1.5707963267948966 -1.3169578969248168 +asin0040 asin -23.0 0.0 -> -1.5707963267948966 3.8281684713331012 +asin0041 asin -23.0 -0.0 -> -1.5707963267948966 -3.8281684713331012 +asin0042 asin -10000000000000000.0 0.0 -> -1.5707963267948966 37.534508668464674 +asin0043 asin -10000000000000000.0 -0.0 -> -1.5707963267948966 -37.534508668464674 +asin0044 asin -9.9999999999999998e+149 0.0 -> -1.5707963267948966 346.08091112966679 +asin0045 asin -9.9999999999999998e+149 -0.0 -> -1.5707963267948966 -346.08091112966679 +asin0046 asin -1.0000000000000001e+299 0.0 -> -1.5707963267948966 689.16608998577965 +asin0047 asin -1.0000000000000001e+299 -0.0 -> -1.5707963267948966 -689.16608998577965 +asin0048 asin 9.8813129168249309e-324 0.0 -> 9.8813129168249309e-324 0.0 +asin0049 asin 9.8813129168249309e-324 -0.0 -> 9.8813129168249309e-324 -0.0 +asin0050 asin 1e-305 0.0 -> 1e-305 0.0 +asin0051 asin 1e-305 -0.0 -> 1e-305 -0.0 +asin0052 asin 1e-150 0.0 -> 1e-150 0.0 +asin0053 asin 1e-150 -0.0 -> 1e-150 -0.0 +asin0054 asin 9.9999999999999998e-17 0.0 -> 9.9999999999999998e-17 0.0 +asin0055 asin 9.9999999999999998e-17 -0.0 -> 9.9999999999999998e-17 -0.0 +asin0056 asin 0.001 0.0 -> 0.0010000001666667416 0.0 +asin0057 asin 0.001 -0.0 -> 0.0010000001666667416 -0.0 +asin0058 asin 0.57899999999999996 0.0 -> 0.61750165481717001 0.0 +asin0059 asin 0.57899999999999996 -0.0 -> 0.61750165481717001 -0.0 +asin0060 asin 0.99999999999999989 0.0 -> 1.5707963118937354 0.0 +asin0061 asin 0.99999999999999989 -0.0 -> 1.5707963118937354 -0.0 +asin0062 asin 1.0000000000000002 0.0 -> 1.5707963267948966 2.1073424255447014e-08 +asin0063 asin 1.0000000000000002 -0.0 -> 1.5707963267948966 -2.1073424255447014e-08 +asin0064 asin 1.0009999999999999 0.0 -> 1.5707963267948966 0.044717633608306849 +asin0065 asin 1.0009999999999999 -0.0 -> 1.5707963267948966 -0.044717633608306849 +asin0066 asin 2.0 0.0 -> 1.5707963267948966 1.3169578969248168 +asin0067 asin 2.0 -0.0 -> 1.5707963267948966 -1.3169578969248168 +asin0068 asin 23.0 0.0 -> 1.5707963267948966 3.8281684713331012 +asin0069 asin 23.0 -0.0 -> 1.5707963267948966 -3.8281684713331012 +asin0070 asin 10000000000000000.0 0.0 -> 1.5707963267948966 37.534508668464674 +asin0071 asin 10000000000000000.0 -0.0 -> 1.5707963267948966 -37.534508668464674 +asin0072 asin 9.9999999999999998e+149 0.0 -> 1.5707963267948966 346.08091112966679 +asin0073 asin 9.9999999999999998e+149 -0.0 -> 1.5707963267948966 -346.08091112966679 +asin0074 asin 1.0000000000000001e+299 0.0 -> 1.5707963267948966 689.16608998577965 +asin0075 asin 1.0000000000000001e+299 -0.0 -> 1.5707963267948966 -689.16608998577965 + +-- random inputs +asin0100 asin -1.5979555835086083 -0.15003009814595247 -> -1.4515369557405788 -1.0544476399790823 +asin0101 asin -0.57488225895317679 -9.6080397838952743e-13 -> -0.61246024460412851 -1.174238005400403e-12 +asin0102 asin -3.6508087930516249 -0.36027527093220152 -> -1.4685890605305874 -1.9742273007152038 +asin0103 asin -1.5238659792326819 -1.1360813516996364 -> -0.86080051691147275 -1.3223742205689195 +asin0104 asin -1592.0639045555306 -0.72362427935018236 -> -1.5703418071175179 -8.0659336918729228 +asin0105 asin -0.19835471371312019 4.2131508416697709 -> -0.045777831019935149 2.1461732751933171 +asin0106 asin -1.918471054430213 0.40603305079779234 -> -1.3301396585791556 1.30263642314981 +asin0107 asin -254495.01623373642 0.71084414434470822 -> -1.5707935336394359 13.140183712762321 +asin0108 asin -0.31315882715691157 3.9647994288429866 -> -0.076450403840916004 2.0889762138713457 +asin0109 asin -0.90017064284720816 1.2530659485907105 -> -0.53466509741943447 1.1702811557577 +asin0110 asin 2.1615181696571075 -0.14058647488229523 -> 1.4976166323896871 -1.4085811039334604 +asin0111 asin 1.2104749210707795 -0.85732484485298999 -> 0.83913071588343924 -1.0681719250525901 +asin0112 asin 1.7059733185128891 -0.84032966373156581 -> 1.0510900815816229 -1.2967979791361652 +asin0113 asin 9.9137085017290687 -1.4608383970250893 -> 1.4237704820128891 -2.995414677560686 +asin0114 asin 117.12344751041495 -5453908091.5334015 -> 2.1475141411392012e-08 -23.112745450217066 +asin0115 asin 0.081041187798029227 0.067054349860173196 -> 0.080946786856771813 0.067223991060639698 +asin0116 asin 46.635472322049949 2.3835190718056678 -> 1.5197194940010779 4.5366989600972083 +asin0117 asin 3907.0687961127105 19.144021886390181 -> 1.5658965233083235 8.9637018715924217 +asin0118 asin 1.0889312322308273 509.01577883554768 -> 0.0021392803817829316 6.9256294494524706 +asin0119 asin 0.10851518277509224 1.5612510908217476 -> 0.058491014243902621 1.2297075725621327 + +-- values near infinity +asin0200 asin 1.5230241998821499e+308 5.5707228994084525e+307 -> 1.2201446370892068 710.37283486535966 +asin0201 asin 8.1334317698672204e+307 -9.2249425197872451e+307 -> 0.72259991284020042 -710.0962453049026 +asin0202 asin -9.9138506659241768e+307 6.701544526434995e+307 -> -0.97637511742194594 710.06887486671371 +asin0203 asin -1.4141298868173842e+308 -5.401505134514191e+307 -> -1.2059319055160587 -710.30396478954628 +asin0204 asin 0.0 9.1618092977897431e+307 -> 0.0 709.80181441050593 +asin0205 asin -0.0 6.8064342551939755e+307 -> -0.0 709.50463910853489 +asin0206 asin 0.0 -6.4997516454798215e+307 -> 0.0 -709.45853469751592 +asin0207 asin -0.0 -1.6767449053345242e+308 -> -0.0 -710.4062101803022 +asin0208 asin 5.4242749957378916e+307 0.0 -> 1.5707963267948966 709.27765497888902 +asin0209 asin 9.5342145121164749e+307 -0.0 -> 1.5707963267948966 -709.84165758595907 +asin0210 asin -7.0445698006201847e+307 0.0 -> -1.5707963267948966 709.53902780872136 +asin0211 asin -1.0016025569769706e+308 -0.0 -> -1.5707963267948966 -709.89095709697881 +asin0212 asin 1.6552203778877204e+308 0.48761543336249491 -> 1.5707963267948966 710.39328998153474 +asin0213 asin 1.2485712830384869e+308 -4.3489311161278899 -> 1.5707963267948966 -710.1113557467786 +asin0214 asin -1.5117842813353125e+308 5.123452666102434 -> -1.5707963267948966 710.30264641923031 +asin0215 asin -1.3167634313008016e+308 -0.52939679793528982 -> -1.5707963267948966 -710.16453260239768 +asin0216 asin 0.80843929176985907 1.0150851827767876e+308 -> 7.9642507396113875e-309 709.90432835561637 +asin0217 asin 8.2544809829680901 -1.7423548140539474e+308 -> 4.7375430746865733e-308 -710.44459336242164 +asin0218 asin -5.2499000118824295 4.6655578977512214e+307 -> -1.1252459249113292e-307 709.1269781491103 +asin0219 asin -5.9904782760833433 -4.7315689314781163e+307 -> -1.2660659419394637e-307 -709.14102757522312 + +-- special values +asin1000 asin -0.0 0.0 -> -0.0 0.0 +asin1001 asin 0.0 0.0 -> 0.0 0.0 +asin1002 asin -0.0 -0.0 -> -0.0 -0.0 +asin1003 asin 0.0 -0.0 -> 0.0 -0.0 +asin1004 asin -inf 0.0 -> -1.5707963267948966 inf +asin1005 asin -inf 2.2999999999999998 -> -1.5707963267948966 inf +asin1006 asin nan 0.0 -> nan nan +asin1007 asin nan 2.2999999999999998 -> nan nan +asin1008 asin -0.0 inf -> -0.0 inf +asin1009 asin -2.2999999999999998 inf -> -0.0 inf +asin1010 asin -inf inf -> -0.78539816339744828 inf +asin1011 asin nan inf -> nan inf +asin1012 asin -0.0 nan -> -0.0 nan +asin1013 asin -2.2999999999999998 nan -> nan nan +asin1014 asin -inf nan -> nan inf ignore-imag-sign +asin1015 asin nan nan -> nan nan +asin1016 asin inf 0.0 -> 1.5707963267948966 inf +asin1017 asin inf 2.2999999999999998 -> 1.5707963267948966 inf +asin1018 asin 0.0 inf -> 0.0 inf +asin1019 asin 2.2999999999999998 inf -> 0.0 inf +asin1020 asin inf inf -> 0.78539816339744828 inf +asin1021 asin 0.0 nan -> 0.0 nan +asin1022 asin 2.2999999999999998 nan -> nan nan +asin1023 asin inf nan -> nan inf ignore-imag-sign +asin1024 asin inf -0.0 -> 1.5707963267948966 -inf +asin1025 asin inf -2.2999999999999998 -> 1.5707963267948966 -inf +asin1026 asin nan -0.0 -> nan nan +asin1027 asin nan -2.2999999999999998 -> nan nan +asin1028 asin 0.0 -inf -> 0.0 -inf +asin1029 asin 2.2999999999999998 -inf -> 0.0 -inf +asin1030 asin inf -inf -> 0.78539816339744828 -inf +asin1031 asin nan -inf -> nan -inf +asin1032 asin -inf -0.0 -> -1.5707963267948966 -inf +asin1033 asin -inf -2.2999999999999998 -> -1.5707963267948966 -inf +asin1034 asin -0.0 -inf -> -0.0 -inf +asin1035 asin -2.2999999999999998 -inf -> -0.0 -inf +asin1036 asin -inf -inf -> -0.78539816339744828 -inf + + +------------------------------------ +-- asinh: Inverse hyperbolic sine -- +------------------------------------ + +-- zeros +asinh0000 asinh 0.0 0.0 -> 0.0 0.0 +asinh0001 asinh 0.0 -0.0 -> 0.0 -0.0 +asinh0002 asinh -0.0 0.0 -> -0.0 0.0 +asinh0003 asinh -0.0 -0.0 -> -0.0 -0.0 + +-- branch points: +/-i +asinh0010 asinh 0.0 1.0 -> 0.0 1.5707963267948966 +asinh0011 asinh 0.0 -1.0 -> 0.0 -1.5707963267948966 +asinh0012 asinh -0.0 1.0 -> -0.0 1.5707963267948966 +asinh0013 asinh -0.0 -1.0 -> -0.0 -1.5707963267948966 + +-- values along both sides of imaginary axis +asinh0020 asinh 0.0 -9.8813129168249309e-324 -> 0.0 -9.8813129168249309e-324 +asinh0021 asinh -0.0 -9.8813129168249309e-324 -> -0.0 -9.8813129168249309e-324 +asinh0022 asinh 0.0 -1e-305 -> 0.0 -1e-305 +asinh0023 asinh -0.0 -1e-305 -> -0.0 -1e-305 +asinh0024 asinh 0.0 -1e-150 -> 0.0 -1e-150 +asinh0025 asinh -0.0 -1e-150 -> -0.0 -1e-150 +asinh0026 asinh 0.0 -9.9999999999999998e-17 -> 0.0 -9.9999999999999998e-17 +asinh0027 asinh -0.0 -9.9999999999999998e-17 -> -0.0 -9.9999999999999998e-17 +asinh0028 asinh 0.0 -0.001 -> 0.0 -0.0010000001666667416 +asinh0029 asinh -0.0 -0.001 -> -0.0 -0.0010000001666667416 +asinh0030 asinh 0.0 -0.57899999999999996 -> 0.0 -0.61750165481717001 +asinh0031 asinh -0.0 -0.57899999999999996 -> -0.0 -0.61750165481717001 +asinh0032 asinh 0.0 -0.99999999999999989 -> 0.0 -1.5707963118937354 +asinh0033 asinh -0.0 -0.99999999999999989 -> -0.0 -1.5707963118937354 +asinh0034 asinh 0.0 -1.0000000000000002 -> 2.1073424255447014e-08 -1.5707963267948966 +asinh0035 asinh -0.0 -1.0000000000000002 -> -2.1073424255447014e-08 -1.5707963267948966 +asinh0036 asinh 0.0 -1.0009999999999999 -> 0.044717633608306849 -1.5707963267948966 +asinh0037 asinh -0.0 -1.0009999999999999 -> -0.044717633608306849 -1.5707963267948966 +asinh0038 asinh 0.0 -2.0 -> 1.3169578969248168 -1.5707963267948966 +asinh0039 asinh -0.0 -2.0 -> -1.3169578969248168 -1.5707963267948966 +asinh0040 asinh 0.0 -20.0 -> 3.6882538673612966 -1.5707963267948966 +asinh0041 asinh -0.0 -20.0 -> -3.6882538673612966 -1.5707963267948966 +asinh0042 asinh 0.0 -10000000000000000.0 -> 37.534508668464674 -1.5707963267948966 +asinh0043 asinh -0.0 -10000000000000000.0 -> -37.534508668464674 -1.5707963267948966 +asinh0044 asinh 0.0 -9.9999999999999998e+149 -> 346.08091112966679 -1.5707963267948966 +asinh0045 asinh -0.0 -9.9999999999999998e+149 -> -346.08091112966679 -1.5707963267948966 +asinh0046 asinh 0.0 -1.0000000000000001e+299 -> 689.16608998577965 -1.5707963267948966 +asinh0047 asinh -0.0 -1.0000000000000001e+299 -> -689.16608998577965 -1.5707963267948966 +asinh0048 asinh 0.0 9.8813129168249309e-324 -> 0.0 9.8813129168249309e-324 +asinh0049 asinh -0.0 9.8813129168249309e-324 -> -0.0 9.8813129168249309e-324 +asinh0050 asinh 0.0 1e-305 -> 0.0 1e-305 +asinh0051 asinh -0.0 1e-305 -> -0.0 1e-305 +asinh0052 asinh 0.0 1e-150 -> 0.0 1e-150 +asinh0053 asinh -0.0 1e-150 -> -0.0 1e-150 +asinh0054 asinh 0.0 9.9999999999999998e-17 -> 0.0 9.9999999999999998e-17 +asinh0055 asinh -0.0 9.9999999999999998e-17 -> -0.0 9.9999999999999998e-17 +asinh0056 asinh 0.0 0.001 -> 0.0 0.0010000001666667416 +asinh0057 asinh -0.0 0.001 -> -0.0 0.0010000001666667416 +asinh0058 asinh 0.0 0.57899999999999996 -> 0.0 0.61750165481717001 +asinh0059 asinh -0.0 0.57899999999999996 -> -0.0 0.61750165481717001 +asinh0060 asinh 0.0 0.99999999999999989 -> 0.0 1.5707963118937354 +asinh0061 asinh -0.0 0.99999999999999989 -> -0.0 1.5707963118937354 +asinh0062 asinh 0.0 1.0000000000000002 -> 2.1073424255447014e-08 1.5707963267948966 +asinh0063 asinh -0.0 1.0000000000000002 -> -2.1073424255447014e-08 1.5707963267948966 +asinh0064 asinh 0.0 1.0009999999999999 -> 0.044717633608306849 1.5707963267948966 +asinh0065 asinh -0.0 1.0009999999999999 -> -0.044717633608306849 1.5707963267948966 +asinh0066 asinh 0.0 2.0 -> 1.3169578969248168 1.5707963267948966 +asinh0067 asinh -0.0 2.0 -> -1.3169578969248168 1.5707963267948966 +asinh0068 asinh 0.0 20.0 -> 3.6882538673612966 1.5707963267948966 +asinh0069 asinh -0.0 20.0 -> -3.6882538673612966 1.5707963267948966 +asinh0070 asinh 0.0 10000000000000000.0 -> 37.534508668464674 1.5707963267948966 +asinh0071 asinh -0.0 10000000000000000.0 -> -37.534508668464674 1.5707963267948966 +asinh0072 asinh 0.0 9.9999999999999998e+149 -> 346.08091112966679 1.5707963267948966 +asinh0073 asinh -0.0 9.9999999999999998e+149 -> -346.08091112966679 1.5707963267948966 +asinh0074 asinh 0.0 1.0000000000000001e+299 -> 689.16608998577965 1.5707963267948966 +asinh0075 asinh -0.0 1.0000000000000001e+299 -> -689.16608998577965 1.5707963267948966 + +-- random inputs +asinh0100 asinh -0.5946402853710423 -0.044506548910000145 -> -0.56459775392653022 -0.038256221441536356 +asinh0101 asinh -0.19353958046180916 -0.017489624793193454 -> -0.19237926804196651 -0.017171741895336792 +asinh0102 asinh -0.033117585138955893 -8.5256414015933757 -> -2.8327758348650969 -1.5668848791092411 +asinh0103 asinh -1.5184043184035716 -0.73491245339073275 -> -1.2715891419764005 -0.39204624408542355 +asinh0104 asinh -0.60716120271208818 -0.28900743958436542 -> -0.59119299421187232 -0.24745931678118135 +asinh0105 asinh -0.0237177865112429 2.8832601052166313 -> -1.7205820772413236 1.5620261702963094 +asinh0106 asinh -2.3906812342743979 2.6349216848574013 -> -1.9609636249445124 0.8142142660574706 +asinh0107 asinh -0.0027605019787620517 183.85588476550555 -> -5.9072920005445066 1.5707813120847871 +asinh0108 asinh -0.99083661164404713 0.028006797051617648 -> -0.8750185251283995 0.019894099615994653 +asinh0109 asinh -3.0362951937986393 0.86377266758504867 -> -1.8636030714685221 0.26475058859950168 +asinh0110 asinh 0.34438464536152769 -0.71603790174885029 -> 0.43985415690734164 -0.71015037409294324 +asinh0111 asinh 4.4925124413876256 -60604595352.871613 -> 25.520783738612078 -1.5707963267207683 +asinh0112 asinh 2.3213991428170337 -7.5459667007307258 -> 2.7560464993451643 -1.270073210856117 +asinh0113 asinh 0.21291939741682028 -1.2720428814784408 -> 0.77275088137338266 -1.3182099250896895 +asinh0114 asinh 6.6447359379455957 -0.97196191666946996 -> 2.602830695139672 -0.14368247412319965 +asinh0115 asinh 7.1326256655083746 2.1516360452706857 -> 2.7051146374367212 0.29051701669727581 +asinh0116 asinh 0.18846550905063442 3.4705348585339832 -> 1.917697875799296 1.514155593347924 +asinh0117 asinh 0.19065075303281598 0.26216814548222012 -> 0.19603050785932474 0.26013422809614117 +asinh0118 asinh 2.0242004665739719 0.70510281647495787 -> 1.4970366212896002 0.30526007200481453 +asinh0119 asinh 37.336596461576057 717.29157391678234 -> 7.269981997945294 1.5187910219576033 + +-- values near infinity +asinh0200 asinh 1.0760517500874541e+308 1.1497786241240167e+308 -> 710.34346055651815 0.81850936961793475 +asinh0201 asinh 1.1784839328845529e+308 -1.6478429586716638e+308 -> 710.59536255783678 -0.94996311735607697 +asinh0202 asinh -4.8777682248909193e+307 1.4103736217538474e+308 -> -710.28970147376992 1.2378239519096443 +asinh0203 asinh -1.2832478903233108e+308 -1.5732392613155698e+308 -> -710.59750164290745 -0.88657181439322452 +asinh0204 asinh 0.0 6.8431383856345372e+307 -> 709.51001718444604 1.5707963267948966 +asinh0205 asinh -0.0 8.601822432238051e+307 -> -709.73874482126689 1.5707963267948966 +asinh0206 asinh 0.0 -5.5698396067303782e+307 -> 709.30413698733742 -1.5707963267948966 +asinh0207 asinh -0.0 -7.1507777734621804e+307 -> -709.55399186002705 -1.5707963267948966 +asinh0208 asinh 1.6025136110019349e+308 0.0 -> 710.3609292261076 0.0 +asinh0209 asinh 1.3927819858239114e+308 -0.0 -> 710.22065899832899 -0.0 +asinh0210 asinh -6.0442994056210995e+307 0.0 -> -709.38588631057621 0.0 +asinh0211 asinh -1.2775271979042634e+308 -0.0 -> -710.13428215553972 -0.0 +asinh0212 asinh 1.0687496260268489e+308 1.0255615699476961 -> 709.95584521407841 9.5959010882679093e-309 +asinh0213 asinh 1.0050967333370962e+308 -0.87668970117333433 -> 709.89443961168183 -8.7224410556242882e-309 +asinh0214 asinh -5.7161452814862392e+307 8.2377808413450122 -> -709.33006540611166 1.4411426644501116e-307 +asinh0215 asinh -8.2009040727653315e+307 -6.407409526654976 -> -709.69101513070109 -7.8130526461510088e-308 +asinh0216 asinh 6.4239368496483982 1.6365990821551427e+308 -> 710.38197618101287 1.5707963267948966 +asinh0217 asinh 5.4729111423315882 -1.1227237438144211e+308 -> 710.00511346983546 -1.5707963267948966 +asinh0218 asinh -8.3455818297412723 1.443172020182019e+308 -> -710.25619930551818 1.5707963267948966 +asinh0219 asinh -2.6049726230372441 -1.7952291144022702e+308 -> -710.47448847685644 -1.5707963267948966 + +-- values near 0 +asinh0220 asinh 1.2940113339664088e-314 6.9169190417774516e-323 -> 1.2940113339664088e-314 6.9169190417774516e-323 +asinh0221 asinh 2.3848478863874649e-315 -3.1907655025717717e-310 -> 2.3848478863874649e-315 -3.1907655025717717e-310 +asinh0222 asinh -3.0097643679641622e-316 4.6936236354918422e-322 -> -3.0097643679641622e-316 4.6936236354918422e-322 +asinh0223 asinh -1.787997087755751e-308 -8.5619622834902341e-310 -> -1.787997087755751e-308 -8.5619622834902341e-310 +asinh0224 asinh 0.0 1.2491433448427325e-314 -> 0.0 1.2491433448427325e-314 +asinh0225 asinh -0.0 2.5024072154538062e-308 -> -0.0 2.5024072154538062e-308 +asinh0226 asinh 0.0 -2.9643938750474793e-323 -> 0.0 -2.9643938750474793e-323 +asinh0227 asinh -0.0 -2.9396905927554169e-320 -> -0.0 -2.9396905927554169e-320 +asinh0228 asinh 5.64042930029359e-317 0.0 -> 5.64042930029359e-317 0.0 +asinh0229 asinh 3.3833911866596068e-318 -0.0 -> 3.3833911866596068e-318 -0.0 +asinh0230 asinh -4.9406564584124654e-324 0.0 -> -4.9406564584124654e-324 0.0 +asinh0231 asinh -2.2211379227994845e-308 -0.0 -> -2.2211379227994845e-308 -0.0 + +-- special values +asinh1000 asinh 0.0 0.0 -> 0.0 0.0 +asinh1001 asinh 0.0 -0.0 -> 0.0 -0.0 +asinh1002 asinh -0.0 0.0 -> -0.0 0.0 +asinh1003 asinh -0.0 -0.0 -> -0.0 -0.0 +asinh1004 asinh 0.0 inf -> inf 1.5707963267948966 +asinh1005 asinh 2.3 inf -> inf 1.5707963267948966 +asinh1006 asinh 0.0 nan -> nan nan +asinh1007 asinh 2.3 nan -> nan nan +asinh1008 asinh inf 0.0 -> inf 0.0 +asinh1009 asinh inf 2.3 -> inf 0.0 +asinh1010 asinh inf inf -> inf 0.78539816339744828 +asinh1011 asinh inf nan -> inf nan +asinh1012 asinh nan 0.0 -> nan 0.0 +asinh1013 asinh nan 2.3 -> nan nan +asinh1014 asinh nan inf -> inf nan ignore-real-sign +asinh1015 asinh nan nan -> nan nan +asinh1016 asinh 0.0 -inf -> inf -1.5707963267948966 +asinh1017 asinh 2.3 -inf -> inf -1.5707963267948966 +asinh1018 asinh inf -0.0 -> inf -0.0 +asinh1019 asinh inf -2.3 -> inf -0.0 +asinh1020 asinh inf -inf -> inf -0.78539816339744828 +asinh1021 asinh nan -0.0 -> nan -0.0 +asinh1022 asinh nan -2.3 -> nan nan +asinh1023 asinh nan -inf -> inf nan ignore-real-sign +asinh1024 asinh -0.0 -inf -> -inf -1.5707963267948966 +asinh1025 asinh -2.3 -inf -> -inf -1.5707963267948966 +asinh1026 asinh -0.0 nan -> nan nan +asinh1027 asinh -2.3 nan -> nan nan +asinh1028 asinh -inf -0.0 -> -inf -0.0 +asinh1029 asinh -inf -2.3 -> -inf -0.0 +asinh1030 asinh -inf -inf -> -inf -0.78539816339744828 +asinh1031 asinh -inf nan -> -inf nan +asinh1032 asinh -0.0 inf -> -inf 1.5707963267948966 +asinh1033 asinh -2.3 inf -> -inf 1.5707963267948966 +asinh1034 asinh -inf 0.0 -> -inf 0.0 +asinh1035 asinh -inf 2.3 -> -inf 0.0 +asinh1036 asinh -inf inf -> -inf 0.78539816339744828 + + +--------------------------- +-- atan: Inverse tangent -- +--------------------------- + +-- zeros +-- These are tested in testAtanSign in test_cmath.py +-- atan0000 atan 0.0 0.0 -> 0.0 0.0 +-- atan0001 atan 0.0 -0.0 -> 0.0 -0.0 +-- atan0002 atan -0.0 0.0 -> -0.0 0.0 +-- atan0003 atan -0.0 -0.0 -> -0.0 -0.0 + +-- values along both sides of imaginary axis +atan0010 atan 0.0 -9.8813129168249309e-324 -> 0.0 -9.8813129168249309e-324 +atan0011 atan -0.0 -9.8813129168249309e-324 -> -0.0 -9.8813129168249309e-324 +atan0012 atan 0.0 -1e-305 -> 0.0 -1e-305 +atan0013 atan -0.0 -1e-305 -> -0.0 -1e-305 +atan0014 atan 0.0 -1e-150 -> 0.0 -1e-150 +atan0015 atan -0.0 -1e-150 -> -0.0 -1e-150 +atan0016 atan 0.0 -9.9999999999999998e-17 -> 0.0 -9.9999999999999998e-17 +atan0017 atan -0.0 -9.9999999999999998e-17 -> -0.0 -9.9999999999999998e-17 +atan0018 atan 0.0 -0.001 -> 0.0 -0.0010000003333335333 +atan0019 atan -0.0 -0.001 -> -0.0 -0.0010000003333335333 +atan0020 atan 0.0 -0.57899999999999996 -> 0.0 -0.6609570902866303 +atan0021 atan -0.0 -0.57899999999999996 -> -0.0 -0.6609570902866303 +atan0022 atan 0.0 -0.99999999999999989 -> 0.0 -18.714973875118524 +atan0023 atan -0.0 -0.99999999999999989 -> -0.0 -18.714973875118524 +atan0024 atan 0.0 -1.0000000000000002 -> 1.5707963267948966 -18.36840028483855 +atan0025 atan -0.0 -1.0000000000000002 -> -1.5707963267948966 -18.36840028483855 +atan0026 atan 0.0 -1.0009999999999999 -> 1.5707963267948966 -3.8007011672919218 +atan0027 atan -0.0 -1.0009999999999999 -> -1.5707963267948966 -3.8007011672919218 +atan0028 atan 0.0 -2.0 -> 1.5707963267948966 -0.54930614433405489 +atan0029 atan -0.0 -2.0 -> -1.5707963267948966 -0.54930614433405489 +atan0030 atan 0.0 -20.0 -> 1.5707963267948966 -0.050041729278491265 +atan0031 atan -0.0 -20.0 -> -1.5707963267948966 -0.050041729278491265 +atan0032 atan 0.0 -10000000000000000.0 -> 1.5707963267948966 -9.9999999999999998e-17 +atan0033 atan -0.0 -10000000000000000.0 -> -1.5707963267948966 -9.9999999999999998e-17 +atan0034 atan 0.0 -9.9999999999999998e+149 -> 1.5707963267948966 -1e-150 +atan0035 atan -0.0 -9.9999999999999998e+149 -> -1.5707963267948966 -1e-150 +atan0036 atan 0.0 -1.0000000000000001e+299 -> 1.5707963267948966 -9.9999999999999999e-300 +atan0037 atan -0.0 -1.0000000000000001e+299 -> -1.5707963267948966 -9.9999999999999999e-300 +atan0038 atan 0.0 9.8813129168249309e-324 -> 0.0 9.8813129168249309e-324 +atan0039 atan -0.0 9.8813129168249309e-324 -> -0.0 9.8813129168249309e-324 +atan0040 atan 0.0 1e-305 -> 0.0 1e-305 +atan0041 atan -0.0 1e-305 -> -0.0 1e-305 +atan0042 atan 0.0 1e-150 -> 0.0 1e-150 +atan0043 atan -0.0 1e-150 -> -0.0 1e-150 +atan0044 atan 0.0 9.9999999999999998e-17 -> 0.0 9.9999999999999998e-17 +atan0045 atan -0.0 9.9999999999999998e-17 -> -0.0 9.9999999999999998e-17 +atan0046 atan 0.0 0.001 -> 0.0 0.0010000003333335333 +atan0047 atan -0.0 0.001 -> -0.0 0.0010000003333335333 +atan0048 atan 0.0 0.57899999999999996 -> 0.0 0.6609570902866303 +atan0049 atan -0.0 0.57899999999999996 -> -0.0 0.6609570902866303 +atan0050 atan 0.0 0.99999999999999989 -> 0.0 18.714973875118524 +atan0051 atan -0.0 0.99999999999999989 -> -0.0 18.714973875118524 +atan0052 atan 0.0 1.0000000000000002 -> 1.5707963267948966 18.36840028483855 +atan0053 atan -0.0 1.0000000000000002 -> -1.5707963267948966 18.36840028483855 +atan0054 atan 0.0 1.0009999999999999 -> 1.5707963267948966 3.8007011672919218 +atan0055 atan -0.0 1.0009999999999999 -> -1.5707963267948966 3.8007011672919218 +atan0056 atan 0.0 2.0 -> 1.5707963267948966 0.54930614433405489 +atan0057 atan -0.0 2.0 -> -1.5707963267948966 0.54930614433405489 +atan0058 atan 0.0 20.0 -> 1.5707963267948966 0.050041729278491265 +atan0059 atan -0.0 20.0 -> -1.5707963267948966 0.050041729278491265 +atan0060 atan 0.0 10000000000000000.0 -> 1.5707963267948966 9.9999999999999998e-17 +atan0061 atan -0.0 10000000000000000.0 -> -1.5707963267948966 9.9999999999999998e-17 +atan0062 atan 0.0 9.9999999999999998e+149 -> 1.5707963267948966 1e-150 +atan0063 atan -0.0 9.9999999999999998e+149 -> -1.5707963267948966 1e-150 +atan0064 atan 0.0 1.0000000000000001e+299 -> 1.5707963267948966 9.9999999999999999e-300 +atan0065 atan -0.0 1.0000000000000001e+299 -> -1.5707963267948966 9.9999999999999999e-300 + +-- random inputs +atan0100 atan -0.32538873661060214 -1.5530461550412578 -> -1.3682728427554227 -0.69451401598762041 +atan0101 atan -0.45863393495197929 -4799.1747094903594 -> -1.5707963068820623 -0.00020836916050636145 +atan0102 atan -8.3006999685976162 -2.6788890251790938 -> -1.4619862771810199 -0.034811669653327826 +atan0103 atan -1.8836307682985314 -1.1441976638861771 -> -1.1839984370871612 -0.20630956157312796 +atan0104 atan -0.00063230482407491669 -4.9312520961829485 -> -1.5707692093223147 -0.20563867743008304 +atan0105 atan -0.84278137150065946 179012.37493146997 -> -1.5707963267685969 5.5862059836425272e-06 +atan0106 atan -0.95487853984049287 14.311334539886177 -> -1.5661322859434561 0.069676024526232005 +atan0107 atan -1.3513252539663239 6.0500727021632198e-08 -> -0.93371676315220975 2.140800269742656e-08 +atan0108 atan -0.20566254458595795 0.11933771944159823 -> -0.20556463711174916 0.11493405387141732 +atan0109 atan -0.58563718795408559 0.64438965423212868 -> -0.68361089300233124 0.46759762751800249 +atan0110 atan 48.479267751948292 -78.386382460112543 -> 1.5650888770910523 -0.0092276811373297584 +atan0111 atan 1.0575373914056061 -0.75988012377296987 -> 0.94430886722043594 -0.31915698126703118 +atan0112 atan 4444810.4314677203 -0.56553404593942558 -> 1.5707961018134231 -2.8625446437701909e-14 +atan0113 atan 0.010101405082520009 -0.032932668550282478 -> 0.01011202676646334 -0.032941214776834996 +atan0114 atan 1.5353585300154911 -2.1947099346796519 -> 1.3400310739206394 -0.29996003607449045 +atan0115 atan 0.21869457055670882 9.9915684254007093 -> 1.5685846078876444 0.1003716881759439 +atan0116 atan 0.17783290150246836 0.064334689863650957 -> 0.17668728064286277 0.062435808728873846 +atan0117 atan 15.757474087615918 383.57262142534 -> 1.5706894060369621 0.0026026817278826603 +atan0118 atan 10.587017408533317 0.21720238081843438 -> 1.4766594681336236 0.0019199097383010061 +atan0119 atan 0.86026078678781204 0.1230148609359502 -> 0.7147259322534929 0.070551221954286605 + +-- values near infinity +atan0200 atan 7.8764397011195798e+307 8.1647921137746308e+307 -> 1.5707963267948966 6.3439446939604493e-309 +atan0201 atan 1.5873698696131487e+308 -1.0780367422960641e+308 -> 1.5707963267948966 -2.9279309368530781e-309 +atan0202 atan -1.5844551864825834e+308 1.0290657809098675e+308 -> -1.5707963267948966 2.8829614736961417e-309 +atan0203 atan -1.3168792562524032e+308 -9.088432341614825e+307 -> -1.5707963267948966 -3.5499373057390056e-309 +atan0204 atan 0.0 1.0360465742258337e+308 -> 1.5707963267948966 9.6520757355646018e-309 +atan0205 atan -0.0 1.0045063210373196e+308 -> -1.5707963267948966 9.955138947929503e-309 +atan0206 atan 0.0 -9.5155296715763696e+307 -> 1.5707963267948966 -1.050913648020118e-308 +atan0207 atan -0.0 -1.5565700490496501e+308 -> -1.5707963267948966 -6.4243816114189071e-309 +atan0208 atan 1.2956339389525244e+308 0.0 -> 1.5707963267948966 0.0 +atan0209 atan 1.4408126243772151e+308 -0.0 -> 1.5707963267948966 -0.0 +atan0210 atan -1.0631786461936417e+308 0.0 -> -1.5707963267948966 0.0 +atan0211 atan -1.0516056964171069e+308 -0.0 -> -1.5707963267948966 -0.0 +atan0212 atan 1.236162319603838e+308 4.6827953496242936 -> 1.5707963267948966 0.0 +atan0213 atan 7.000516472897218e+307 -5.8631608017844163 -> 1.5707963267948966 -0.0 +atan0214 atan -1.5053444003338508e+308 5.1199197268420313 -> -1.5707963267948966 0.0 +atan0215 atan -1.399172518147259e+308 -3.5687766472913673 -> -1.5707963267948966 -0.0 +atan0216 atan 8.1252833070803021 6.2782953917343822e+307 -> 1.5707963267948966 1.5927890256908564e-308 +atan0217 atan 2.8034285947515167 -1.3378049775753878e+308 -> 1.5707963267948966 -7.4749310756219562e-309 +atan0218 atan -1.4073509988974953 1.6776381785968355e+308 -> -1.5707963267948966 5.9607608646364569e-309 +atan0219 atan -2.7135551527592119 -1.281567445525738e+308 -> -1.5707963267948966 -7.8029447727565326e-309 + +-- imaginary part = +/-1, real part tiny +atan0300 atan -1e-150 -1.0 -> -0.78539816339744828 -173.04045556483339 +atan0301 atan 1e-155 1.0 -> 0.78539816339744828 178.79691829731851 +atan0302 atan 9.9999999999999999e-161 -1.0 -> 0.78539816339744828 -184.55338102980363 +atan0303 atan -1e-165 1.0 -> -0.78539816339744828 190.30984376228875 +atan0304 atan -9.9998886718268301e-321 -1.0 -> -0.78539816339744828 -368.76019403576692 + +-- Additional real values (mpmath) +atan0400 atan 1.7976931348623157e+308 0.0 -> 1.5707963267948966192 0.0 +atan0401 atan -1.7976931348623157e+308 0.0 -> -1.5707963267948966192 0.0 +atan0402 atan 1e-17 0.0 -> 1.0000000000000000715e-17 0.0 +atan0403 atan -1e-17 0.0 -> -1.0000000000000000715e-17 0.0 +atan0404 atan 0.0001 0.0 -> 0.000099999999666666673459 0.0 +atan0405 atan -0.0001 0.0 -> -0.000099999999666666673459 0.0 +atan0406 atan 0.999999999999999 0.0 -> 0.78539816339744781002 0.0 +atan0407 atan 1.000000000000001 0.0 -> 0.78539816339744886473 0.0 +atan0408 atan 14.101419947171719 0.0 -> 1.4999999999999999969 0.0 +atan0409 atan 1255.7655915007897 0.0 -> 1.5700000000000000622 0.0 + +-- special values +atan1000 atan -0.0 0.0 -> -0.0 0.0 +atan1001 atan nan 0.0 -> nan 0.0 +atan1002 atan -0.0 1.0 -> -0.0 inf divide-by-zero +atan1003 atan -inf 0.0 -> -1.5707963267948966 0.0 +atan1004 atan -inf 2.2999999999999998 -> -1.5707963267948966 0.0 +atan1005 atan nan 2.2999999999999998 -> nan nan +atan1006 atan -0.0 inf -> -1.5707963267948966 0.0 +atan1007 atan -2.2999999999999998 inf -> -1.5707963267948966 0.0 +atan1008 atan -inf inf -> -1.5707963267948966 0.0 +atan1009 atan nan inf -> nan 0.0 +atan1010 atan -0.0 nan -> nan nan +atan1011 atan -2.2999999999999998 nan -> nan nan +atan1012 atan -inf nan -> -1.5707963267948966 0.0 ignore-imag-sign +atan1013 atan nan nan -> nan nan +atan1014 atan 0.0 0.0 -> 0.0 0.0 +atan1015 atan 0.0 1.0 -> 0.0 inf divide-by-zero +atan1016 atan inf 0.0 -> 1.5707963267948966 0.0 +atan1017 atan inf 2.2999999999999998 -> 1.5707963267948966 0.0 +atan1018 atan 0.0 inf -> 1.5707963267948966 0.0 +atan1019 atan 2.2999999999999998 inf -> 1.5707963267948966 0.0 +atan1020 atan inf inf -> 1.5707963267948966 0.0 +atan1021 atan 0.0 nan -> nan nan +atan1022 atan 2.2999999999999998 nan -> nan nan +atan1023 atan inf nan -> 1.5707963267948966 0.0 ignore-imag-sign +atan1024 atan 0.0 -0.0 -> 0.0 -0.0 +atan1025 atan nan -0.0 -> nan -0.0 +atan1026 atan 0.0 -1.0 -> 0.0 -inf divide-by-zero +atan1027 atan inf -0.0 -> 1.5707963267948966 -0.0 +atan1028 atan inf -2.2999999999999998 -> 1.5707963267948966 -0.0 +atan1029 atan nan -2.2999999999999998 -> nan nan +atan1030 atan 0.0 -inf -> 1.5707963267948966 -0.0 +atan1031 atan 2.2999999999999998 -inf -> 1.5707963267948966 -0.0 +atan1032 atan inf -inf -> 1.5707963267948966 -0.0 +atan1033 atan nan -inf -> nan -0.0 +atan1034 atan -0.0 -0.0 -> -0.0 -0.0 +atan1035 atan -0.0 -1.0 -> -0.0 -inf divide-by-zero +atan1036 atan -inf -0.0 -> -1.5707963267948966 -0.0 +atan1037 atan -inf -2.2999999999999998 -> -1.5707963267948966 -0.0 +atan1038 atan -0.0 -inf -> -1.5707963267948966 -0.0 +atan1039 atan -2.2999999999999998 -inf -> -1.5707963267948966 -0.0 +atan1040 atan -inf -inf -> -1.5707963267948966 -0.0 + + +--------------------------------------- +-- atanh: Inverse hyperbolic tangent -- +--------------------------------------- + +-- zeros +-- These are tested in testAtanhSign in test_cmath.py +-- atanh0000 atanh 0.0 0.0 -> 0.0 0.0 +-- atanh0001 atanh 0.0 -0.0 -> 0.0 -0.0 +-- atanh0002 atanh -0.0 0.0 -> -0.0 0.0 +-- atanh0003 atanh -0.0 -0.0 -> -0.0 -0.0 + +-- values along both sides of real axis +atanh0010 atanh -9.8813129168249309e-324 0.0 -> -9.8813129168249309e-324 0.0 +atanh0011 atanh -9.8813129168249309e-324 -0.0 -> -9.8813129168249309e-324 -0.0 +atanh0012 atanh -1e-305 0.0 -> -1e-305 0.0 +atanh0013 atanh -1e-305 -0.0 -> -1e-305 -0.0 +atanh0014 atanh -1e-150 0.0 -> -1e-150 0.0 +atanh0015 atanh -1e-150 -0.0 -> -1e-150 -0.0 +atanh0016 atanh -9.9999999999999998e-17 0.0 -> -9.9999999999999998e-17 0.0 +atanh0017 atanh -9.9999999999999998e-17 -0.0 -> -9.9999999999999998e-17 -0.0 +atanh0018 atanh -0.001 0.0 -> -0.0010000003333335333 0.0 +atanh0019 atanh -0.001 -0.0 -> -0.0010000003333335333 -0.0 +atanh0020 atanh -0.57899999999999996 0.0 -> -0.6609570902866303 0.0 +atanh0021 atanh -0.57899999999999996 -0.0 -> -0.6609570902866303 -0.0 +atanh0022 atanh -0.99999999999999989 0.0 -> -18.714973875118524 0.0 +atanh0023 atanh -0.99999999999999989 -0.0 -> -18.714973875118524 -0.0 +atanh0024 atanh -1.0000000000000002 0.0 -> -18.36840028483855 1.5707963267948966 +atanh0025 atanh -1.0000000000000002 -0.0 -> -18.36840028483855 -1.5707963267948966 +atanh0026 atanh -1.0009999999999999 0.0 -> -3.8007011672919218 1.5707963267948966 +atanh0027 atanh -1.0009999999999999 -0.0 -> -3.8007011672919218 -1.5707963267948966 +atanh0028 atanh -2.0 0.0 -> -0.54930614433405489 1.5707963267948966 +atanh0029 atanh -2.0 -0.0 -> -0.54930614433405489 -1.5707963267948966 +atanh0030 atanh -23.0 0.0 -> -0.043505688494814884 1.5707963267948966 +atanh0031 atanh -23.0 -0.0 -> -0.043505688494814884 -1.5707963267948966 +atanh0032 atanh -10000000000000000.0 0.0 -> -9.9999999999999998e-17 1.5707963267948966 +atanh0033 atanh -10000000000000000.0 -0.0 -> -9.9999999999999998e-17 -1.5707963267948966 +atanh0034 atanh -9.9999999999999998e+149 0.0 -> -1e-150 1.5707963267948966 +atanh0035 atanh -9.9999999999999998e+149 -0.0 -> -1e-150 -1.5707963267948966 +atanh0036 atanh -1.0000000000000001e+299 0.0 -> -9.9999999999999999e-300 1.5707963267948966 +atanh0037 atanh -1.0000000000000001e+299 -0.0 -> -9.9999999999999999e-300 -1.5707963267948966 +atanh0038 atanh 9.8813129168249309e-324 0.0 -> 9.8813129168249309e-324 0.0 +atanh0039 atanh 9.8813129168249309e-324 -0.0 -> 9.8813129168249309e-324 -0.0 +atanh0040 atanh 1e-305 0.0 -> 1e-305 0.0 +atanh0041 atanh 1e-305 -0.0 -> 1e-305 -0.0 +atanh0042 atanh 1e-150 0.0 -> 1e-150 0.0 +atanh0043 atanh 1e-150 -0.0 -> 1e-150 -0.0 +atanh0044 atanh 9.9999999999999998e-17 0.0 -> 9.9999999999999998e-17 0.0 +atanh0045 atanh 9.9999999999999998e-17 -0.0 -> 9.9999999999999998e-17 -0.0 +atanh0046 atanh 0.001 0.0 -> 0.0010000003333335333 0.0 +atanh0047 atanh 0.001 -0.0 -> 0.0010000003333335333 -0.0 +atanh0048 atanh 0.57899999999999996 0.0 -> 0.6609570902866303 0.0 +atanh0049 atanh 0.57899999999999996 -0.0 -> 0.6609570902866303 -0.0 +atanh0050 atanh 0.99999999999999989 0.0 -> 18.714973875118524 0.0 +atanh0051 atanh 0.99999999999999989 -0.0 -> 18.714973875118524 -0.0 +atanh0052 atanh 1.0000000000000002 0.0 -> 18.36840028483855 1.5707963267948966 +atanh0053 atanh 1.0000000000000002 -0.0 -> 18.36840028483855 -1.5707963267948966 +atanh0054 atanh 1.0009999999999999 0.0 -> 3.8007011672919218 1.5707963267948966 +atanh0055 atanh 1.0009999999999999 -0.0 -> 3.8007011672919218 -1.5707963267948966 +atanh0056 atanh 2.0 0.0 -> 0.54930614433405489 1.5707963267948966 +atanh0057 atanh 2.0 -0.0 -> 0.54930614433405489 -1.5707963267948966 +atanh0058 atanh 23.0 0.0 -> 0.043505688494814884 1.5707963267948966 +atanh0059 atanh 23.0 -0.0 -> 0.043505688494814884 -1.5707963267948966 +atanh0060 atanh 10000000000000000.0 0.0 -> 9.9999999999999998e-17 1.5707963267948966 +atanh0061 atanh 10000000000000000.0 -0.0 -> 9.9999999999999998e-17 -1.5707963267948966 +atanh0062 atanh 9.9999999999999998e+149 0.0 -> 1e-150 1.5707963267948966 +atanh0063 atanh 9.9999999999999998e+149 -0.0 -> 1e-150 -1.5707963267948966 +atanh0064 atanh 1.0000000000000001e+299 0.0 -> 9.9999999999999999e-300 1.5707963267948966 +atanh0065 atanh 1.0000000000000001e+299 -0.0 -> 9.9999999999999999e-300 -1.5707963267948966 + +-- random inputs +atanh0100 atanh -0.54460925980633501 -0.54038050126721027 -> -0.41984265808446974 -0.60354153938352828 +atanh0101 atanh -1.6934614269829051 -0.48807386108113621 -> -0.58592769102243281 -1.3537837470975898 +atanh0102 atanh -1.3467293985501207 -0.47868354895395876 -> -0.69961624370709985 -1.1994450156570076 +atanh0103 atanh -5.6142232418984888 -544551613.39307702 -> -1.8932657550925744e-17 -1.5707963249585235 +atanh0104 atanh -0.011841460381263651 -3.259978899823385 -> -0.0010183936547405188 -1.2731614020743838 +atanh0105 atanh -0.0073345736950029532 0.35821949670922248 -> -0.0065004869024682466 0.34399359971920895 +atanh0106 atanh -13.866782244320014 0.9541129545860273 -> -0.071896852055058899 1.5658322704631409 +atanh0107 atanh -708.59964982780775 21.984802159266675 -> -0.0014098779074189741 1.5707525842838959 +atanh0108 atanh -30.916832076030602 1.3691897138829843 -> -0.032292682045743676 1.5693652094847115 +atanh0109 atanh -0.57461806339861754 0.29534797443913063 -> -0.56467464472482765 0.39615612824172625 +atanh0110 atanh 0.40089246737415685 -1.632285984300659 -> 0.1063832707890608 -1.0402821335326482 +atanh0111 atanh 2119.6167688262176 -1.5383653437377242e+17 -> 8.9565008518382049e-32 -1.5707963267948966 +atanh0112 atanh 756.86017850941641 -6.6064087133223817 -> 0.0013211481136820046 -1.5707847948702234 +atanh0113 atanh 4.0490617718041602 -2.5784456791040652e-12 -> 0.25218425538553618 -1.5707963267947291 +atanh0114 atanh 10.589254957173523 -0.13956391149624509 -> 0.094700890282197664 -1.5695407140217623 +atanh0115 atanh 1.0171187553160499 0.70766113465354019 -> 0.55260251975367791 0.96619711116641682 +atanh0116 atanh 0.031645502527750849 0.067319983726544394 -> 0.031513018344086742 0.067285437670549036 +atanh0117 atanh 0.13670177624994517 0.43240089361857947 -> 0.11538933151017253 0.41392008145336212 +atanh0118 atanh 0.64173899243596688 2.9008577686695256 -> 0.065680142424134405 1.2518535724053921 +atanh0119 atanh 0.19313813528025942 38.799619150741869 -> 0.00012820765917366644 1.5450292202823612 + +-- values near infinity +atanh0200 atanh 5.3242646831347954e+307 1.3740396080084153e+308 -> 2.4519253616695576e-309 1.5707963267948966 +atanh0201 atanh 1.158701641241358e+308 -6.5579268873375853e+307 -> 6.5365375267795098e-309 -1.5707963267948966 +atanh0202 atanh -1.3435325735762247e+308 9.8947369259601547e+307 -> -4.8256680906589956e-309 1.5707963267948966 +atanh0203 atanh -1.4359857522598942e+308 -9.4701204702391004e+307 -> -4.8531282262872645e-309 -1.5707963267948966 +atanh0204 atanh 0.0 5.6614181068098497e+307 -> 0.0 1.5707963267948966 +atanh0205 atanh -0.0 6.9813212721450139e+307 -> -0.0 1.5707963267948966 +atanh0206 atanh 0.0 -7.4970613060311453e+307 -> 0.0 -1.5707963267948966 +atanh0207 atanh -0.0 -1.5280601880314068e+308 -> -0.0 -1.5707963267948966 +atanh0208 atanh 8.2219472336000745e+307 0.0 -> 1.2162568933954813e-308 1.5707963267948966 +atanh0209 atanh 1.4811519617280899e+308 -0.0 -> 6.7515017083951325e-309 -1.5707963267948966 +atanh0210 atanh -1.2282016263598785e+308 0.0 -> -8.1419856360537615e-309 1.5707963267948966 +atanh0211 atanh -1.0616427760154426e+308 -0.0 -> -9.4193642399489563e-309 -1.5707963267948966 +atanh0212 atanh 1.2971536510180682e+308 5.2847948452333293 -> 7.7091869510998328e-309 1.5707963267948966 +atanh0213 atanh 1.1849860977411851e+308 -7.9781906447459949 -> 8.4389175696339014e-309 -1.5707963267948966 +atanh0214 atanh -1.4029969422586635e+308 0.93891986543663375 -> -7.127599283218073e-309 1.5707963267948966 +atanh0215 atanh -4.7508098912248211e+307 -8.2702421247039908 -> -2.1049042645278043e-308 -1.5707963267948966 +atanh0216 atanh 8.2680742115769998 8.1153898410918065e+307 -> 0.0 1.5707963267948966 +atanh0217 atanh 1.2575325146218885 -1.4746679147661649e+308 -> 0.0 -1.5707963267948966 +atanh0218 atanh -2.4618803682310899 1.3781522717005568e+308 -> -0.0 1.5707963267948966 +atanh0219 atanh -4.0952386694788112 -1.231083376353703e+308 -> -0.0 -1.5707963267948966 + +-- values near 0 +atanh0220 atanh 3.8017563659811628e-314 2.6635484239074319e-312 -> 3.8017563659811628e-314 2.6635484239074319e-312 +atanh0221 atanh 1.7391110733611878e-321 -4.3547800672541419e-313 -> 1.7391110733611878e-321 -4.3547800672541419e-313 +atanh0222 atanh -5.9656816081325078e-317 9.9692253555416263e-313 -> -5.9656816081325078e-317 9.9692253555416263e-313 +atanh0223 atanh -6.5606671178400239e-313 -2.1680936406357335e-309 -> -6.5606671178400239e-313 -2.1680936406357335e-309 +atanh0224 atanh 0.0 2.5230944401820779e-319 -> 0.0 2.5230944401820779e-319 +atanh0225 atanh -0.0 5.6066569490064658e-320 -> -0.0 5.6066569490064658e-320 +atanh0226 atanh 0.0 -2.4222487249468377e-317 -> 0.0 -2.4222487249468377e-317 +atanh0227 atanh -0.0 -3.0861101089206037e-316 -> -0.0 -3.0861101089206037e-316 +atanh0228 atanh 3.1219222884393986e-310 0.0 -> 3.1219222884393986e-310 0.0 +atanh0229 atanh 9.8926337564976196e-309 -0.0 -> 9.8926337564976196e-309 -0.0 +atanh0230 atanh -1.5462535092918154e-312 0.0 -> -1.5462535092918154e-312 0.0 +atanh0231 atanh -9.8813129168249309e-324 -0.0 -> -9.8813129168249309e-324 -0.0 + +-- real part = +/-1, imaginary part tiny +atanh0300 atanh 1.0 1e-153 -> 176.49433320432448 0.78539816339744828 +atanh0301 atanh 1.0 9.9999999999999997e-155 -> 177.64562575082149 0.78539816339744828 +atanh0302 atanh -1.0 1e-161 -> -185.70467357630065 0.78539816339744828 +atanh0303 atanh 1.0 -1e-165 -> 190.30984376228875 -0.78539816339744828 +atanh0304 atanh -1.0 -9.8813129168249309e-324 -> -372.22003596069061 -0.78539816339744828 + +-- special values +atanh1000 atanh 0.0 0.0 -> 0.0 0.0 +atanh1001 atanh 0.0 nan -> 0.0 nan +atanh1002 atanh 1.0 0.0 -> inf 0.0 divide-by-zero +atanh1003 atanh 0.0 inf -> 0.0 1.5707963267948966 +atanh1004 atanh 2.3 inf -> 0.0 1.5707963267948966 +atanh1005 atanh 2.3 nan -> nan nan +atanh1006 atanh inf 0.0 -> 0.0 1.5707963267948966 +atanh1007 atanh inf 2.3 -> 0.0 1.5707963267948966 +atanh1008 atanh inf inf -> 0.0 1.5707963267948966 +atanh1009 atanh inf nan -> 0.0 nan +atanh1010 atanh nan 0.0 -> nan nan +atanh1011 atanh nan 2.3 -> nan nan +atanh1012 atanh nan inf -> 0.0 1.5707963267948966 ignore-real-sign +atanh1013 atanh nan nan -> nan nan +atanh1014 atanh 0.0 -0.0 -> 0.0 -0.0 +atanh1015 atanh 1.0 -0.0 -> inf -0.0 divide-by-zero +atanh1016 atanh 0.0 -inf -> 0.0 -1.5707963267948966 +atanh1017 atanh 2.3 -inf -> 0.0 -1.5707963267948966 +atanh1018 atanh inf -0.0 -> 0.0 -1.5707963267948966 +atanh1019 atanh inf -2.3 -> 0.0 -1.5707963267948966 +atanh1020 atanh inf -inf -> 0.0 -1.5707963267948966 +atanh1021 atanh nan -0.0 -> nan nan +atanh1022 atanh nan -2.3 -> nan nan +atanh1023 atanh nan -inf -> 0.0 -1.5707963267948966 ignore-real-sign +atanh1024 atanh -0.0 -0.0 -> -0.0 -0.0 +atanh1025 atanh -0.0 nan -> -0.0 nan +atanh1026 atanh -1.0 -0.0 -> -inf -0.0 divide-by-zero +atanh1027 atanh -0.0 -inf -> -0.0 -1.5707963267948966 +atanh1028 atanh -2.3 -inf -> -0.0 -1.5707963267948966 +atanh1029 atanh -2.3 nan -> nan nan +atanh1030 atanh -inf -0.0 -> -0.0 -1.5707963267948966 +atanh1031 atanh -inf -2.3 -> -0.0 -1.5707963267948966 +atanh1032 atanh -inf -inf -> -0.0 -1.5707963267948966 +atanh1033 atanh -inf nan -> -0.0 nan +atanh1034 atanh -0.0 0.0 -> -0.0 0.0 +atanh1035 atanh -1.0 0.0 -> -inf 0.0 divide-by-zero +atanh1036 atanh -0.0 inf -> -0.0 1.5707963267948966 +atanh1037 atanh -2.3 inf -> -0.0 1.5707963267948966 +atanh1038 atanh -inf 0.0 -> -0.0 1.5707963267948966 +atanh1039 atanh -inf 2.3 -> -0.0 1.5707963267948966 +atanh1040 atanh -inf inf -> -0.0 1.5707963267948966 + + +---------------------------- +-- log: Natural logarithm -- +---------------------------- + +log0000 log 1.0 0.0 -> 0.0 0.0 +log0001 log 1.0 -0.0 -> 0.0 -0.0 +log0002 log -1.0 0.0 -> 0.0 3.1415926535897931 +log0003 log -1.0 -0.0 -> 0.0 -3.1415926535897931 +-- values along both sides of real axis +log0010 log -9.8813129168249309e-324 0.0 -> -743.74692474082133 3.1415926535897931 +log0011 log -9.8813129168249309e-324 -0.0 -> -743.74692474082133 -3.1415926535897931 +log0012 log -1e-305 0.0 -> -702.28845336318398 3.1415926535897931 +log0013 log -1e-305 -0.0 -> -702.28845336318398 -3.1415926535897931 +log0014 log -1e-150 0.0 -> -345.38776394910684 3.1415926535897931 +log0015 log -1e-150 -0.0 -> -345.38776394910684 -3.1415926535897931 +log0016 log -9.9999999999999998e-17 0.0 -> -36.841361487904734 3.1415926535897931 +log0017 log -9.9999999999999998e-17 -0.0 -> -36.841361487904734 -3.1415926535897931 +log0018 log -0.001 0.0 -> -6.9077552789821368 3.1415926535897931 +log0019 log -0.001 -0.0 -> -6.9077552789821368 -3.1415926535897931 +log0020 log -0.57899999999999996 0.0 -> -0.54645280140914188 3.1415926535897931 +log0021 log -0.57899999999999996 -0.0 -> -0.54645280140914188 -3.1415926535897931 +log0022 log -0.99999999999999989 0.0 -> -1.1102230246251565e-16 3.1415926535897931 +log0023 log -0.99999999999999989 -0.0 -> -1.1102230246251565e-16 -3.1415926535897931 +log0024 log -1.0000000000000002 0.0 -> 2.2204460492503128e-16 3.1415926535897931 +log0025 log -1.0000000000000002 -0.0 -> 2.2204460492503128e-16 -3.1415926535897931 +log0026 log -1.0009999999999999 0.0 -> 0.00099950033308342321 3.1415926535897931 +log0027 log -1.0009999999999999 -0.0 -> 0.00099950033308342321 -3.1415926535897931 +log0028 log -2.0 0.0 -> 0.69314718055994529 3.1415926535897931 +log0029 log -2.0 -0.0 -> 0.69314718055994529 -3.1415926535897931 +log0030 log -23.0 0.0 -> 3.1354942159291497 3.1415926535897931 +log0031 log -23.0 -0.0 -> 3.1354942159291497 -3.1415926535897931 +log0032 log -10000000000000000.0 0.0 -> 36.841361487904734 3.1415926535897931 +log0033 log -10000000000000000.0 -0.0 -> 36.841361487904734 -3.1415926535897931 +log0034 log -9.9999999999999998e+149 0.0 -> 345.38776394910684 3.1415926535897931 +log0035 log -9.9999999999999998e+149 -0.0 -> 345.38776394910684 -3.1415926535897931 +log0036 log -1.0000000000000001e+299 0.0 -> 688.47294280521965 3.1415926535897931 +log0037 log -1.0000000000000001e+299 -0.0 -> 688.47294280521965 -3.1415926535897931 +log0038 log 9.8813129168249309e-324 0.0 -> -743.74692474082133 0.0 +log0039 log 9.8813129168249309e-324 -0.0 -> -743.74692474082133 -0.0 +log0040 log 1e-305 0.0 -> -702.28845336318398 0.0 +log0041 log 1e-305 -0.0 -> -702.28845336318398 -0.0 +log0042 log 1e-150 0.0 -> -345.38776394910684 0.0 +log0043 log 1e-150 -0.0 -> -345.38776394910684 -0.0 +log0044 log 9.9999999999999998e-17 0.0 -> -36.841361487904734 0.0 +log0045 log 9.9999999999999998e-17 -0.0 -> -36.841361487904734 -0.0 +log0046 log 0.001 0.0 -> -6.9077552789821368 0.0 +log0047 log 0.001 -0.0 -> -6.9077552789821368 -0.0 +log0048 log 0.57899999999999996 0.0 -> -0.54645280140914188 0.0 +log0049 log 0.57899999999999996 -0.0 -> -0.54645280140914188 -0.0 +log0050 log 0.99999999999999989 0.0 -> -1.1102230246251565e-16 0.0 +log0051 log 0.99999999999999989 -0.0 -> -1.1102230246251565e-16 -0.0 +log0052 log 1.0000000000000002 0.0 -> 2.2204460492503128e-16 0.0 +log0053 log 1.0000000000000002 -0.0 -> 2.2204460492503128e-16 -0.0 +log0054 log 1.0009999999999999 0.0 -> 0.00099950033308342321 0.0 +log0055 log 1.0009999999999999 -0.0 -> 0.00099950033308342321 -0.0 +log0056 log 2.0 0.0 -> 0.69314718055994529 0.0 +log0057 log 2.0 -0.0 -> 0.69314718055994529 -0.0 +log0058 log 23.0 0.0 -> 3.1354942159291497 0.0 +log0059 log 23.0 -0.0 -> 3.1354942159291497 -0.0 +log0060 log 10000000000000000.0 0.0 -> 36.841361487904734 0.0 +log0061 log 10000000000000000.0 -0.0 -> 36.841361487904734 -0.0 +log0062 log 9.9999999999999998e+149 0.0 -> 345.38776394910684 0.0 +log0063 log 9.9999999999999998e+149 -0.0 -> 345.38776394910684 -0.0 +log0064 log 1.0000000000000001e+299 0.0 -> 688.47294280521965 0.0 +log0065 log 1.0000000000000001e+299 -0.0 -> 688.47294280521965 -0.0 + +-- random inputs +log0066 log -1.9830454945186191e-16 -2.0334448025673346 -> 0.70973130194329803 -1.5707963267948968 +log0067 log -0.96745853024741857 -0.84995816228299692 -> 0.25292811398722387 -2.4207570438536905 +log0068 log -0.1603644313948418 -0.2929942111041835 -> -1.0965857872427374 -2.0715870859971419 +log0069 log -0.15917913168438699 -0.25238799251132177 -> -1.2093477313249901 -2.1334784232033863 +log0070 log -0.68907818535078802 -3.0693105853476346 -> 1.1460398629184565 -1.7916403813913211 +log0071 log -17.268133447565589 6.8165120014604756 -> 2.9212694465974836 2.7656245081603164 +log0072 log -1.7153894479690328 26.434055372802636 -> 3.2767542953718003 1.6355986276341734 +log0073 log -8.0456794648936578e-06 0.19722758057570208 -> -1.6233969848296075 1.5708371206810101 +log0074 log -2.4306442691323173 0.6846919750700996 -> 0.92633592001969589 2.8670160576718331 +log0075 log -3.5488049250888194 0.45324040643185254 -> 1.2747008374256426 3.0145640007885111 +log0076 log 0.18418516851510189 -0.26062518836212617 -> -1.1421287121940344 -0.95558440841183434 +log0077 log 2.7124837795638399 -13.148769067133387 -> 2.5971659975706802 -1.3673583045209439 +log0078 log 3.6521275476169149e-13 -3.7820543023170673e-05 -> -10.182658136741569 -1.5707963171384316 +log0079 log 5.0877545813862239 -1.2834978326786852 -> 1.6576856213076328 -0.24711583497738485 +log0080 log 0.26477986808461512 -0.67659001194187429 -> -0.31944085207999973 -1.197773671987121 +log0081 log 0.0014754261398071962 5.3514691608205442 -> 1.6773711707153829 1.5705206219261802 +log0082 log 0.29667334462157885 0.00020056045042584795 -> -1.2151233667079588 0.00067603114168689204 +log0083 log 0.82104233671099425 3.9005387130133102 -> 1.3827918965299593 1.3633304701848363 +log0084 log 0.27268135358180667 124.42088110945804 -> 4.8236724223559229 1.5686047258789015 +log0085 log 0.0026286959168267485 0.47795808180573013 -> -0.73821712137809126 1.5652965360960087 + +-- values near infinity +log0100 log 1.0512025744003172e+308 7.2621669750664611e+307 -> 709.44123967814494 0.60455434048332968 +log0101 log 5.5344249034372126e+307 -1.2155859158431275e+308 -> 709.48562300345679 -1.143553056717973 +log0102 log -1.3155575403469408e+308 1.1610793541663864e+308 -> 709.75847809546428 2.41848796504974 +log0103 log -1.632366720973235e+308 -1.54299446211448e+308 -> 710.00545236515586 -2.3843326028455087 +log0104 log 0.0 5.9449276692327712e+307 -> 708.67616191258526 1.5707963267948966 +log0105 log -0.0 1.1201850459025692e+308 -> 709.30970253338171 1.5707963267948966 +log0106 log 0.0 -1.6214225933466528e+308 -> 709.6795125501086 -1.5707963267948966 +log0107 log -0.0 -1.7453269791591058e+308 -> 709.75315056087379 -1.5707963267948966 +log0108 log 1.440860577601428e+308 0.0 -> 709.56144920058262 0.0 +log0109 log 1.391515176148282e+308 -0.0 -> 709.52660185041327 -0.0 +log0110 log -1.201354401295296e+308 0.0 -> 709.37965823023956 3.1415926535897931 +log0111 log -1.6704337825976804e+308 -0.0 -> 709.70929198492399 -3.1415926535897931 +log0112 log 7.2276974655190223e+307 7.94879711369164 -> 708.87154406512104 1.0997689307850458e-307 +log0113 log 1.1207859593716076e+308 -6.1956200868221147 -> 709.31023883080104 -5.5279244310803286e-308 +log0114 log -4.6678933874471045e+307 9.947107893220382 -> 708.43433142431388 3.1415926535897931 +log0115 log -1.5108012453950142e+308 -5.3117197179375619 -> 709.60884877835008 -3.1415926535897931 +log0116 log 7.4903750871504435 1.5320703776626352e+308 -> 709.62282865085137 1.5707963267948966 +log0117 log 5.9760325525654778 -8.0149473997349123e+307 -> 708.97493177248396 -1.5707963267948966 +log0118 log -7.880194206386629 1.7861845814767441e+308 -> 709.77629046837137 1.5707963267948966 +log0119 log -9.886438993852865 -6.19235781080747e+307 -> 708.71693946977302 -1.5707963267948966 + +-- values near 0 +log0120 log 2.2996867579227779e-308 6.7861840770939125e-312 -> -708.36343567717392 0.00029509166223339815 +log0121 log 6.9169190417774516e-323 -9.0414013188948118e-322 -> -739.22766796468386 -1.4944423210001669 +log0122 log -1.5378064962914011e-316 1.8243628389354635e-310 -> -713.20014803142965 1.5707971697228842 +log0123 log -2.3319898483706837e-321 -2.2358763941866371e-313 -> -719.9045008332522 -1.570796337224766 +log0124 log 0.0 3.872770101081121e-315 -> -723.96033425374401 1.5707963267948966 +log0125 log -0.0 9.6342800939043076e-322 -> -739.16707236281752 1.5707963267948966 +log0126 log 0.0 -2.266099393427834e-308 -> -708.37814861757965 -1.5707963267948966 +log0127 log -0.0 -2.1184695673766626e-315 -> -724.56361036731812 -1.5707963267948966 +log0128 log 1.1363509854348671e-322 0.0 -> -741.30457770545206 0.0 +log0129 log 3.5572726500569751e-322 -0.0 -> -740.16340580236522 -0.0 +log0130 log -2.3696071074040593e-310 0.0 -> -712.93865466421641 3.1415926535897931 +log0131 log -2.813283897266934e-317 -0.0 -> -728.88512203138862 -3.1415926535897931 + +-- values near the unit circle +log0200 log -0.59999999999999998 0.80000000000000004 -> 2.2204460492503132e-17 2.2142974355881808 +log0201 log 0.79999999999999993 0.60000000000000009 -> 6.1629758220391547e-33 0.64350110879328448 + +-- special values +log1000 log -0.0 0.0 -> -inf 3.1415926535897931 divide-by-zero +log1001 log 0.0 0.0 -> -inf 0.0 divide-by-zero +log1002 log 0.0 inf -> inf 1.5707963267948966 +log1003 log 2.3 inf -> inf 1.5707963267948966 +log1004 log -0.0 inf -> inf 1.5707963267948966 +log1005 log -2.3 inf -> inf 1.5707963267948966 +log1006 log 0.0 nan -> nan nan +log1007 log 2.3 nan -> nan nan +log1008 log -0.0 nan -> nan nan +log1009 log -2.3 nan -> nan nan +log1010 log -inf 0.0 -> inf 3.1415926535897931 +log1011 log -inf 2.3 -> inf 3.1415926535897931 +log1012 log inf 0.0 -> inf 0.0 +log1013 log inf 2.3 -> inf 0.0 +log1014 log -inf inf -> inf 2.3561944901923448 +log1015 log inf inf -> inf 0.78539816339744828 +log1016 log inf nan -> inf nan +log1017 log -inf nan -> inf nan +log1018 log nan 0.0 -> nan nan +log1019 log nan 2.3 -> nan nan +log1020 log nan inf -> inf nan +log1021 log nan nan -> nan nan +log1022 log -0.0 -0.0 -> -inf -3.1415926535897931 divide-by-zero +log1023 log 0.0 -0.0 -> -inf -0.0 divide-by-zero +log1024 log 0.0 -inf -> inf -1.5707963267948966 +log1025 log 2.3 -inf -> inf -1.5707963267948966 +log1026 log -0.0 -inf -> inf -1.5707963267948966 +log1027 log -2.3 -inf -> inf -1.5707963267948966 +log1028 log -inf -0.0 -> inf -3.1415926535897931 +log1029 log -inf -2.3 -> inf -3.1415926535897931 +log1030 log inf -0.0 -> inf -0.0 +log1031 log inf -2.3 -> inf -0.0 +log1032 log -inf -inf -> inf -2.3561944901923448 +log1033 log inf -inf -> inf -0.78539816339744828 +log1034 log nan -0.0 -> nan nan +log1035 log nan -2.3 -> nan nan +log1036 log nan -inf -> inf nan + + +------------------------------ +-- log10: Logarithm base 10 -- +------------------------------ + +logt0000 log10 1.0 0.0 -> 0.0 0.0 +logt0001 log10 1.0 -0.0 -> 0.0 -0.0 +logt0002 log10 -1.0 0.0 -> 0.0 1.3643763538418414 +logt0003 log10 -1.0 -0.0 -> 0.0 -1.3643763538418414 +-- values along both sides of real axis +logt0010 log10 -9.8813129168249309e-324 0.0 -> -323.0051853474518 1.3643763538418414 +logt0011 log10 -9.8813129168249309e-324 -0.0 -> -323.0051853474518 -1.3643763538418414 +logt0012 log10 -1e-305 0.0 -> -305.0 1.3643763538418414 +logt0013 log10 -1e-305 -0.0 -> -305.0 -1.3643763538418414 +logt0014 log10 -1e-150 0.0 -> -150.0 1.3643763538418414 +logt0015 log10 -1e-150 -0.0 -> -150.0 -1.3643763538418414 +logt0016 log10 -9.9999999999999998e-17 0.0 -> -16.0 1.3643763538418414 +logt0017 log10 -9.9999999999999998e-17 -0.0 -> -16.0 -1.3643763538418414 +logt0018 log10 -0.001 0.0 -> -3.0 1.3643763538418414 +logt0019 log10 -0.001 -0.0 -> -3.0 -1.3643763538418414 +logt0020 log10 -0.57899999999999996 0.0 -> -0.23732143627256383 1.3643763538418414 +logt0021 log10 -0.57899999999999996 -0.0 -> -0.23732143627256383 -1.3643763538418414 +logt0022 log10 -0.99999999999999989 0.0 -> -4.821637332766436e-17 1.3643763538418414 +logt0023 log10 -0.99999999999999989 -0.0 -> -4.821637332766436e-17 -1.3643763538418414 +logt0024 log10 -1.0000000000000002 0.0 -> 9.6432746655328696e-17 1.3643763538418414 +logt0025 log10 -1.0000000000000002 -0.0 -> 9.6432746655328696e-17 -1.3643763538418414 +logt0026 log10 -1.0009999999999999 0.0 -> 0.0004340774793185929 1.3643763538418414 +logt0027 log10 -1.0009999999999999 -0.0 -> 0.0004340774793185929 -1.3643763538418414 +logt0028 log10 -2.0 0.0 -> 0.3010299956639812 1.3643763538418414 +logt0029 log10 -2.0 -0.0 -> 0.3010299956639812 -1.3643763538418414 +logt0030 log10 -23.0 0.0 -> 1.3617278360175928 1.3643763538418414 +logt0031 log10 -23.0 -0.0 -> 1.3617278360175928 -1.3643763538418414 +logt0032 log10 -10000000000000000.0 0.0 -> 16.0 1.3643763538418414 +logt0033 log10 -10000000000000000.0 -0.0 -> 16.0 -1.3643763538418414 +logt0034 log10 -9.9999999999999998e+149 0.0 -> 150.0 1.3643763538418414 +logt0035 log10 -9.9999999999999998e+149 -0.0 -> 150.0 -1.3643763538418414 +logt0036 log10 -1.0000000000000001e+299 0.0 -> 299.0 1.3643763538418414 +logt0037 log10 -1.0000000000000001e+299 -0.0 -> 299.0 -1.3643763538418414 +logt0038 log10 9.8813129168249309e-324 0.0 -> -323.0051853474518 0.0 +logt0039 log10 9.8813129168249309e-324 -0.0 -> -323.0051853474518 -0.0 +logt0040 log10 1e-305 0.0 -> -305.0 0.0 +logt0041 log10 1e-305 -0.0 -> -305.0 -0.0 +logt0042 log10 1e-150 0.0 -> -150.0 0.0 +logt0043 log10 1e-150 -0.0 -> -150.0 -0.0 +logt0044 log10 9.9999999999999998e-17 0.0 -> -16.0 0.0 +logt0045 log10 9.9999999999999998e-17 -0.0 -> -16.0 -0.0 +logt0046 log10 0.001 0.0 -> -3.0 0.0 +logt0047 log10 0.001 -0.0 -> -3.0 -0.0 +logt0048 log10 0.57899999999999996 0.0 -> -0.23732143627256383 0.0 +logt0049 log10 0.57899999999999996 -0.0 -> -0.23732143627256383 -0.0 +logt0050 log10 0.99999999999999989 0.0 -> -4.821637332766436e-17 0.0 +logt0051 log10 0.99999999999999989 -0.0 -> -4.821637332766436e-17 -0.0 +logt0052 log10 1.0000000000000002 0.0 -> 9.6432746655328696e-17 0.0 +logt0053 log10 1.0000000000000002 -0.0 -> 9.6432746655328696e-17 -0.0 +logt0054 log10 1.0009999999999999 0.0 -> 0.0004340774793185929 0.0 +logt0055 log10 1.0009999999999999 -0.0 -> 0.0004340774793185929 -0.0 +logt0056 log10 2.0 0.0 -> 0.3010299956639812 0.0 +logt0057 log10 2.0 -0.0 -> 0.3010299956639812 -0.0 +logt0058 log10 23.0 0.0 -> 1.3617278360175928 0.0 +logt0059 log10 23.0 -0.0 -> 1.3617278360175928 -0.0 +logt0060 log10 10000000000000000.0 0.0 -> 16.0 0.0 +logt0061 log10 10000000000000000.0 -0.0 -> 16.0 -0.0 +logt0062 log10 9.9999999999999998e+149 0.0 -> 150.0 0.0 +logt0063 log10 9.9999999999999998e+149 -0.0 -> 150.0 -0.0 +logt0064 log10 1.0000000000000001e+299 0.0 -> 299.0 0.0 +logt0065 log10 1.0000000000000001e+299 -0.0 -> 299.0 -0.0 + +-- random inputs +logt0066 log10 -1.9830454945186191e-16 -2.0334448025673346 -> 0.30823238806798503 -0.68218817692092071 +logt0067 log10 -0.96745853024741857 -0.84995816228299692 -> 0.10984528422284802 -1.051321426174086 +logt0068 log10 -0.1603644313948418 -0.2929942111041835 -> -0.47624115633305419 -0.89967884023059597 +logt0069 log10 -0.15917913168438699 -0.25238799251132177 -> -0.52521304641665956 -0.92655790645688119 +logt0070 log10 -0.68907818535078802 -3.0693105853476346 -> 0.4977187885066448 -0.77809953119328823 +logt0071 log10 -17.268133447565589 6.8165120014604756 -> 1.2686912008098534 1.2010954629104202 +logt0072 log10 -1.7153894479690328 26.434055372802636 -> 1.423076309032751 0.71033145859005309 +logt0073 log10 -8.0456794648936578e-06 0.19722758057570208 -> -0.70503235244987561 0.68220589348055516 +logt0074 log10 -2.4306442691323173 0.6846919750700996 -> 0.40230257845332595 1.2451292533748923 +logt0075 log10 -3.5488049250888194 0.45324040643185254 -> 0.55359553977141063 1.3092085108866405 +logt0076 log10 0.18418516851510189 -0.26062518836212617 -> -0.49602019732913638 -0.41500503556604301 +logt0077 log10 2.7124837795638399 -13.148769067133387 -> 1.1279348613317008 -0.59383616643803216 +logt0078 log10 3.6521275476169149e-13 -3.7820543023170673e-05 -> -4.4222722398941112 -0.68218817272717114 +logt0079 log10 5.0877545813862239 -1.2834978326786852 -> 0.71992371806426847 -0.10732104352159283 +logt0080 log10 0.26477986808461512 -0.67659001194187429 -> -0.13873139935281681 -0.52018649631300229 +logt0081 log10 0.0014754261398071962 5.3514691608205442 -> 0.72847304354528819 0.6820684398178033 +logt0082 log10 0.29667334462157885 0.00020056045042584795 -> -0.52772137299296806 0.00029359659442937261 +logt0083 log10 0.82104233671099425 3.9005387130133102 -> 0.60053889028349361 0.59208690021184018 +logt0084 log10 0.27268135358180667 124.42088110945804 -> 2.094894315538069 0.68123637673656989 +logt0085 log10 0.0026286959168267485 0.47795808180573013 -> -0.32060362226100814 0.67979964816877081 + +-- values near infinity +logt0100 log10 1.0512025744003172e+308 7.2621669750664611e+307 -> 308.10641562682065 0.26255461408256975 +logt0101 log10 5.5344249034372126e+307 -1.2155859158431275e+308 -> 308.12569106009209 -0.496638782296212 +logt0102 log10 -1.3155575403469408e+308 1.1610793541663864e+308 -> 308.24419052091019 1.0503359777705266 +logt0103 log10 -1.632366720973235e+308 -1.54299446211448e+308 -> 308.3514500834093 -1.0355024924378222 +logt0104 log10 0.0 5.9449276692327712e+307 -> 307.77414657501117 0.68218817692092071 +logt0105 log10 -0.0 1.1201850459025692e+308 -> 308.04928977068465 0.68218817692092071 +logt0106 log10 0.0 -1.6214225933466528e+308 -> 308.20989622030174 -0.68218817692092071 +logt0107 log10 -0.0 -1.7453269791591058e+308 -> 308.24187680203539 -0.68218817692092071 +logt0108 log10 1.440860577601428e+308 0.0 -> 308.15862195908755 0.0 +logt0109 log10 1.391515176148282e+308 -0.0 -> 308.14348794720007 -0.0 +logt0110 log10 -1.201354401295296e+308 0.0 -> 308.07967114380773 1.3643763538418414 +logt0111 log10 -1.6704337825976804e+308 -0.0 -> 308.22282926451624 -1.3643763538418414 +logt0112 log10 7.2276974655190223e+307 7.94879711369164 -> 307.85899996571993 4.7762357800858463e-308 +logt0113 log10 1.1207859593716076e+308 -6.1956200868221147 -> 308.04952268169455 -2.4007470767963597e-308 +logt0114 log10 -4.6678933874471045e+307 9.947107893220382 -> 307.66912092839902 1.3643763538418414 +logt0115 log10 -1.5108012453950142e+308 -5.3117197179375619 -> 308.1792073341565 -1.3643763538418414 +logt0116 log10 7.4903750871504435 1.5320703776626352e+308 -> 308.18527871564157 0.68218817692092071 +logt0117 log10 5.9760325525654778 -8.0149473997349123e+307 -> 307.90390067652424 -0.68218817692092071 +logt0118 log10 -7.880194206386629 1.7861845814767441e+308 -> 308.25192633617331 0.68218817692092071 +logt0119 log10 -9.886438993852865 -6.19235781080747e+307 -> 307.79185604308338 -0.68218817692092071 + +-- values near 0 +logt0120 log10 2.2996867579227779e-308 6.7861840770939125e-312 -> -307.63833129662572 0.00012815668056362305 +logt0121 log10 6.9169190417774516e-323 -9.0414013188948118e-322 -> -321.04249706727148 -0.64902805353306059 +logt0122 log10 -1.5378064962914011e-316 1.8243628389354635e-310 -> -309.73888878263222 0.68218854299989429 +logt0123 log10 -2.3319898483706837e-321 -2.2358763941866371e-313 -> -312.65055220919641 -0.68218818145055538 +logt0124 log10 0.0 3.872770101081121e-315 -> -314.41197828323476 0.68218817692092071 +logt0125 log10 -0.0 9.6342800939043076e-322 -> -321.01618073175331 0.68218817692092071 +logt0126 log10 0.0 -2.266099393427834e-308 -> -307.64472104545649 -0.68218817692092071 +logt0127 log10 -0.0 -2.1184695673766626e-315 -> -314.67397777042407 -0.68218817692092071 +logt0128 log10 1.1363509854348671e-322 0.0 -> -321.94448750709819 0.0 +logt0129 log10 3.5572726500569751e-322 -0.0 -> -321.44888284668451 -0.0 +logt0130 log10 -2.3696071074040593e-310 0.0 -> -309.62532365619722 1.3643763538418414 +logt0131 log10 -2.813283897266934e-317 -0.0 -> -316.55078643961042 -1.3643763538418414 + +-- values near the unit circle +logt0200 log10 -0.59999999999999998 0.80000000000000004 -> 9.6432746655328709e-18 0.96165715756846815 +logt0201 log10 0.79999999999999993 0.60000000000000009 -> 2.6765463916147622e-33 0.2794689806475476 + +-- special values +logt1000 log10 -0.0 0.0 -> -inf 1.3643763538418414 divide-by-zero +logt1001 log10 0.0 0.0 -> -inf 0.0 divide-by-zero +logt1002 log10 0.0 inf -> inf 0.68218817692092071 +logt1003 log10 2.3 inf -> inf 0.68218817692092071 +logt1004 log10 -0.0 inf -> inf 0.68218817692092071 +logt1005 log10 -2.3 inf -> inf 0.68218817692092071 +logt1006 log10 0.0 nan -> nan nan +logt1007 log10 2.3 nan -> nan nan +logt1008 log10 -0.0 nan -> nan nan +logt1009 log10 -2.3 nan -> nan nan +logt1010 log10 -inf 0.0 -> inf 1.3643763538418414 +logt1011 log10 -inf 2.3 -> inf 1.3643763538418414 +logt1012 log10 inf 0.0 -> inf 0.0 +logt1013 log10 inf 2.3 -> inf 0.0 +logt1014 log10 -inf inf -> inf 1.0232822653813811 +logt1015 log10 inf inf -> inf 0.34109408846046035 +logt1016 log10 inf nan -> inf nan +logt1017 log10 -inf nan -> inf nan +logt1018 log10 nan 0.0 -> nan nan +logt1019 log10 nan 2.3 -> nan nan +logt1020 log10 nan inf -> inf nan +logt1021 log10 nan nan -> nan nan +logt1022 log10 -0.0 -0.0 -> -inf -1.3643763538418414 divide-by-zero +logt1023 log10 0.0 -0.0 -> -inf -0.0 divide-by-zero +logt1024 log10 0.0 -inf -> inf -0.68218817692092071 +logt1025 log10 2.3 -inf -> inf -0.68218817692092071 +logt1026 log10 -0.0 -inf -> inf -0.68218817692092071 +logt1027 log10 -2.3 -inf -> inf -0.68218817692092071 +logt1028 log10 -inf -0.0 -> inf -1.3643763538418414 +logt1029 log10 -inf -2.3 -> inf -1.3643763538418414 +logt1030 log10 inf -0.0 -> inf -0.0 +logt1031 log10 inf -2.3 -> inf -0.0 +logt1032 log10 -inf -inf -> inf -1.0232822653813811 +logt1033 log10 inf -inf -> inf -0.34109408846046035 +logt1034 log10 nan -0.0 -> nan nan +logt1035 log10 nan -2.3 -> nan nan +logt1036 log10 nan -inf -> inf nan + + +----------------------- +-- sqrt: Square root -- +----------------------- + +-- zeros +sqrt0000 sqrt 0.0 0.0 -> 0.0 0.0 +sqrt0001 sqrt 0.0 -0.0 -> 0.0 -0.0 +sqrt0002 sqrt -0.0 0.0 -> 0.0 0.0 +sqrt0003 sqrt -0.0 -0.0 -> 0.0 -0.0 + +-- values along both sides of real axis +sqrt0010 sqrt -9.8813129168249309e-324 0.0 -> 0.0 3.1434555694052576e-162 +sqrt0011 sqrt -9.8813129168249309e-324 -0.0 -> 0.0 -3.1434555694052576e-162 +sqrt0012 sqrt -1e-305 0.0 -> 0.0 3.1622776601683791e-153 +sqrt0013 sqrt -1e-305 -0.0 -> 0.0 -3.1622776601683791e-153 +sqrt0014 sqrt -1e-150 0.0 -> 0.0 9.9999999999999996e-76 +sqrt0015 sqrt -1e-150 -0.0 -> 0.0 -9.9999999999999996e-76 +sqrt0016 sqrt -9.9999999999999998e-17 0.0 -> 0.0 1e-08 +sqrt0017 sqrt -9.9999999999999998e-17 -0.0 -> 0.0 -1e-08 +sqrt0018 sqrt -0.001 0.0 -> 0.0 0.031622776601683791 +sqrt0019 sqrt -0.001 -0.0 -> 0.0 -0.031622776601683791 +sqrt0020 sqrt -0.57899999999999996 0.0 -> 0.0 0.76092049518987193 +sqrt0021 sqrt -0.57899999999999996 -0.0 -> 0.0 -0.76092049518987193 +sqrt0022 sqrt -0.99999999999999989 0.0 -> 0.0 0.99999999999999989 +sqrt0023 sqrt -0.99999999999999989 -0.0 -> 0.0 -0.99999999999999989 +sqrt0024 sqrt -1.0000000000000002 0.0 -> 0.0 1.0 +sqrt0025 sqrt -1.0000000000000002 -0.0 -> 0.0 -1.0 +sqrt0026 sqrt -1.0009999999999999 0.0 -> 0.0 1.000499875062461 +sqrt0027 sqrt -1.0009999999999999 -0.0 -> 0.0 -1.000499875062461 +sqrt0028 sqrt -2.0 0.0 -> 0.0 1.4142135623730951 +sqrt0029 sqrt -2.0 -0.0 -> 0.0 -1.4142135623730951 +sqrt0030 sqrt -23.0 0.0 -> 0.0 4.7958315233127191 +sqrt0031 sqrt -23.0 -0.0 -> 0.0 -4.7958315233127191 +sqrt0032 sqrt -10000000000000000.0 0.0 -> 0.0 100000000.0 +sqrt0033 sqrt -10000000000000000.0 -0.0 -> 0.0 -100000000.0 +sqrt0034 sqrt -9.9999999999999998e+149 0.0 -> 0.0 9.9999999999999993e+74 +sqrt0035 sqrt -9.9999999999999998e+149 -0.0 -> 0.0 -9.9999999999999993e+74 +sqrt0036 sqrt -1.0000000000000001e+299 0.0 -> 0.0 3.1622776601683796e+149 +sqrt0037 sqrt -1.0000000000000001e+299 -0.0 -> 0.0 -3.1622776601683796e+149 +sqrt0038 sqrt 9.8813129168249309e-324 0.0 -> 3.1434555694052576e-162 0.0 +sqrt0039 sqrt 9.8813129168249309e-324 -0.0 -> 3.1434555694052576e-162 -0.0 +sqrt0040 sqrt 1e-305 0.0 -> 3.1622776601683791e-153 0.0 +sqrt0041 sqrt 1e-305 -0.0 -> 3.1622776601683791e-153 -0.0 +sqrt0042 sqrt 1e-150 0.0 -> 9.9999999999999996e-76 0.0 +sqrt0043 sqrt 1e-150 -0.0 -> 9.9999999999999996e-76 -0.0 +sqrt0044 sqrt 9.9999999999999998e-17 0.0 -> 1e-08 0.0 +sqrt0045 sqrt 9.9999999999999998e-17 -0.0 -> 1e-08 -0.0 +sqrt0046 sqrt 0.001 0.0 -> 0.031622776601683791 0.0 +sqrt0047 sqrt 0.001 -0.0 -> 0.031622776601683791 -0.0 +sqrt0048 sqrt 0.57899999999999996 0.0 -> 0.76092049518987193 0.0 +sqrt0049 sqrt 0.57899999999999996 -0.0 -> 0.76092049518987193 -0.0 +sqrt0050 sqrt 0.99999999999999989 0.0 -> 0.99999999999999989 0.0 +sqrt0051 sqrt 0.99999999999999989 -0.0 -> 0.99999999999999989 -0.0 +sqrt0052 sqrt 1.0000000000000002 0.0 -> 1.0 0.0 +sqrt0053 sqrt 1.0000000000000002 -0.0 -> 1.0 -0.0 +sqrt0054 sqrt 1.0009999999999999 0.0 -> 1.000499875062461 0.0 +sqrt0055 sqrt 1.0009999999999999 -0.0 -> 1.000499875062461 -0.0 +sqrt0056 sqrt 2.0 0.0 -> 1.4142135623730951 0.0 +sqrt0057 sqrt 2.0 -0.0 -> 1.4142135623730951 -0.0 +sqrt0058 sqrt 23.0 0.0 -> 4.7958315233127191 0.0 +sqrt0059 sqrt 23.0 -0.0 -> 4.7958315233127191 -0.0 +sqrt0060 sqrt 10000000000000000.0 0.0 -> 100000000.0 0.0 +sqrt0061 sqrt 10000000000000000.0 -0.0 -> 100000000.0 -0.0 +sqrt0062 sqrt 9.9999999999999998e+149 0.0 -> 9.9999999999999993e+74 0.0 +sqrt0063 sqrt 9.9999999999999998e+149 -0.0 -> 9.9999999999999993e+74 -0.0 +sqrt0064 sqrt 1.0000000000000001e+299 0.0 -> 3.1622776601683796e+149 0.0 +sqrt0065 sqrt 1.0000000000000001e+299 -0.0 -> 3.1622776601683796e+149 -0.0 + +-- random inputs +sqrt0100 sqrt -0.34252542541549913 -223039880.15076211 -> 10560.300180587592 -10560.300196805192 +sqrt0101 sqrt -0.88790791393018909 -5.3307751730827402 -> 1.5027154613689004 -1.7737140896343291 +sqrt0102 sqrt -113916.89291310767 -0.018143374626153858 -> 2.6877817875351178e-05 -337.51576691038952 +sqrt0103 sqrt -0.63187172386197121 -0.26293913366617694 -> 0.16205707495266153 -0.81125471918761971 +sqrt0104 sqrt -0.058185169308906215 -2.3548312990430991 -> 1.0717660342420072 -1.0985752598086966 +sqrt0105 sqrt -1.0580584765935896 0.14400319259151736 -> 0.069837489270111242 1.030987755262468 +sqrt0106 sqrt -1.1667595947504932 0.11159711473953678 -> 0.051598531319315251 1.0813981705111229 +sqrt0107 sqrt -0.5123728411449906 0.026175433648339085 -> 0.018278026262418718 0.71603556293597614 +sqrt0108 sqrt -3.7453400060067228 1.0946500314809635 -> 0.27990088541692498 1.9554243814742367 +sqrt0109 sqrt -0.0027736121575097673 1.0367943000839817 -> 0.71903560338719175 0.72096172651250545 +sqrt0110 sqrt 1501.2559699453188 -1.1997325207283589 -> 38.746047664730959 -0.015481998720355024 +sqrt0111 sqrt 1.4830075326850578 -0.64100878436755349 -> 1.244712815741096 -0.25749264258434584 +sqrt0112 sqrt 0.095395618499734602 -0.48226565701639595 -> 0.54175904053472879 -0.44509239434231551 +sqrt0113 sqrt 0.50109185681863277 -0.54054037379892561 -> 0.7868179858332387 -0.34349772344520979 +sqrt0114 sqrt 0.98779807595367897 -0.00019848758437225191 -> 0.99388031770665153 -9.9854872279921968e-05 +sqrt0115 sqrt 11.845472380792259 0.0010051104581506761 -> 3.4417252072345397 0.00014601840612346451 +sqrt0116 sqrt 2.3558249686735975 0.25605157371744403 -> 1.5371278477386647 0.083288964575761404 +sqrt0117 sqrt 0.77584894123159098 1.0496420627016076 -> 1.0200744386390885 0.51449287568756552 +sqrt0118 sqrt 1.8961715669604893 0.34940793467158854 -> 1.3827991781411615 0.12634080935066902 +sqrt0119 sqrt 0.96025378316565801 0.69573224860140515 -> 1.0358710342209998 0.33581991658093457 + +-- values near 0 +sqrt0120 sqrt 7.3577938365086866e-313 8.1181408465112743e-319 -> 8.5777583531543516e-157 4.732087634251168e-163 +sqrt0121 sqrt 1.2406883874892108e-310 -5.1210133324269776e-312 -> 1.1140990057468052e-155 -2.2982756945349973e-157 +sqrt0122 sqrt -7.1145453001139502e-322 2.9561379244703735e-314 -> 1.2157585807480286e-157 1.2157586100077242e-157 +sqrt0123 sqrt -4.9963244206801218e-314 -8.4718424423690227e-319 -> 1.8950582312540437e-162 -2.2352459419578971e-157 +sqrt0124 sqrt 0.0 7.699553609385195e-318 -> 1.9620848107797476e-159 1.9620848107797476e-159 +sqrt0125 sqrt -0.0 3.3900826606499415e-309 -> 4.1170879639922327e-155 4.1170879639922327e-155 +sqrt0126 sqrt 0.0 -9.8907989772250828e-319 -> 7.032353438652342e-160 -7.032353438652342e-160 +sqrt0127 sqrt -0.0 -1.3722939367590908e-315 -> 2.6194407196566702e-158 -2.6194407196566702e-158 +sqrt0128 sqrt 7.9050503334599447e-323 0.0 -> 8.8910349979403099e-162 0.0 +sqrt0129 sqrt 1.8623241768349486e-309 -0.0 -> 4.3154654173506579e-155 -0.0 +sqrt0130 sqrt -2.665971134499887e-308 0.0 -> 0.0 1.6327801856036491e-154 +sqrt0131 sqrt -1.5477066694467245e-310 -0.0 -> 0.0 -1.2440685951533077e-155 + +-- inputs whose absolute value overflows +sqrt0140 sqrt 1.6999999999999999e+308 -1.6999999999999999e+308 -> 1.4325088230154573e+154 -5.9336458271212207e+153 +sqrt0141 sqrt -1.797e+308 -9.9999999999999999e+306 -> 3.7284476432057307e+152 -1.3410406899802901e+154 + +-- Additional real values (mpmath) +sqrt0150 sqrt 1.7976931348623157e+308 0.0 -> 1.3407807929942596355e+154 0.0 +sqrt0151 sqrt 2.2250738585072014e-308 0.0 -> 1.4916681462400413487e-154 0.0 +sqrt0152 sqrt 5e-324 0.0 -> 2.2227587494850774834e-162 0.0 +sqrt0153 sqrt 5e-324 1.0 -> 0.7071067811865476 0.7071067811865476 + +-- special values +sqrt1000 sqrt 0.0 0.0 -> 0.0 0.0 +sqrt1001 sqrt -0.0 0.0 -> 0.0 0.0 +sqrt1002 sqrt 0.0 inf -> inf inf +sqrt1003 sqrt 2.3 inf -> inf inf +sqrt1004 sqrt inf inf -> inf inf +sqrt1005 sqrt -0.0 inf -> inf inf +sqrt1006 sqrt -2.3 inf -> inf inf +sqrt1007 sqrt -inf inf -> inf inf +sqrt1008 sqrt nan inf -> inf inf +sqrt1009 sqrt 0.0 nan -> nan nan +sqrt1010 sqrt 2.3 nan -> nan nan +sqrt1011 sqrt -0.0 nan -> nan nan +sqrt1012 sqrt -2.3 nan -> nan nan +sqrt1013 sqrt -inf 0.0 -> 0.0 inf +sqrt1014 sqrt -inf 2.3 -> 0.0 inf +sqrt1015 sqrt inf 0.0 -> inf 0.0 +sqrt1016 sqrt inf 2.3 -> inf 0.0 +sqrt1017 sqrt -inf nan -> nan inf ignore-imag-sign +sqrt1018 sqrt inf nan -> inf nan +sqrt1019 sqrt nan 0.0 -> nan nan +sqrt1020 sqrt nan 2.3 -> nan nan +sqrt1021 sqrt nan nan -> nan nan +sqrt1022 sqrt 0.0 -0.0 -> 0.0 -0.0 +sqrt1023 sqrt -0.0 -0.0 -> 0.0 -0.0 +sqrt1024 sqrt 0.0 -inf -> inf -inf +sqrt1025 sqrt 2.3 -inf -> inf -inf +sqrt1026 sqrt inf -inf -> inf -inf +sqrt1027 sqrt -0.0 -inf -> inf -inf +sqrt1028 sqrt -2.3 -inf -> inf -inf +sqrt1029 sqrt -inf -inf -> inf -inf +sqrt1030 sqrt nan -inf -> inf -inf +sqrt1031 sqrt -inf -0.0 -> 0.0 -inf +sqrt1032 sqrt -inf -2.3 -> 0.0 -inf +sqrt1033 sqrt inf -0.0 -> inf -0.0 +sqrt1034 sqrt inf -2.3 -> inf -0.0 +sqrt1035 sqrt nan -0.0 -> nan nan +sqrt1036 sqrt nan -2.3 -> nan nan + + +-- For exp, cosh, sinh, tanh we limit tests to arguments whose +-- imaginary part is less than 10 in absolute value: most math +-- libraries have poor accuracy for (real) sine and cosine for +-- large arguments, and the accuracy of these complex functions +-- suffer correspondingly. +-- +-- Similarly, for cos, sin and tan we limit tests to arguments +-- with relatively small real part. + + +------------------------------- +-- exp: Exponential function -- +------------------------------- + +-- zeros +exp0000 exp 0.0 0.0 -> 1.0 0.0 +exp0001 exp 0.0 -0.0 -> 1.0 -0.0 +exp0002 exp -0.0 0.0 -> 1.0 0.0 +exp0003 exp -0.0 -0.0 -> 1.0 -0.0 + +-- random inputs +exp0004 exp -17.957359009564684 -1.108613895795274 -> 7.0869292576226611e-09 -1.4225929202377833e-08 +exp0005 exp -1.4456149663368642e-15 -0.75359817331772239 -> 0.72923148323917997 -0.68426708517419033 +exp0006 exp -0.76008654883512661 -0.46657235480105019 -> 0.41764393109928666 -0.21035108396792854 +exp0007 exp -5.7071614697735731 -2.3744161818115816e-11 -> 0.0033220890242068356 -7.8880219364953578e-14 +exp0008 exp -0.4653981327927097 -5.2236706667445587e-21 -> 0.62788507378216663 -3.2798648420026468e-21 +exp0009 exp -3.2444565242295518 1.1535625304243959 -> 0.015799936931457641 0.035644950380024749 +exp0010 exp -3.0651456337977727 0.87765086532391878 -> 0.029805595629855953 0.035882775180855669 +exp0011 exp -0.11080823753233926 0.96486386300873106 -> 0.50979112534376314 0.73575512419561562 +exp0012 exp -2.5629722598928648 0.019636235754708079 -> 0.077060452853917397 0.0015133717341137684 +exp0013 exp -3.3201709957983357e-10 1.2684017344487268 -> 0.29780699855434889 0.95462610007689186 +exp0014 exp 0.88767276057993272 -0.18953422986895557 -> 2.3859624049858095 -0.45771559132044426 +exp0015 exp 1.5738333486794742 -2.2576803075544328e-11 -> 4.8251091132458654 -1.0893553826776623e-10 +exp0016 exp 1.6408702341813795 -1.438879484380837 -> 0.6786733590689048 -5.1148284173168825 +exp0017 exp 1.820279424202033 -0.020812040370785722 -> 6.1722462896420902 -0.1284755888435051 +exp0018 exp 1.7273965735945873 -0.61140621328954947 -> 4.6067931898799976 -3.2294267694441308 +exp0019 exp 2.5606034306862995 0.098153136008435504 -> 12.881325889966629 1.2684184812864494 +exp0020 exp 10.280368619483029 3.4564622559748535 -> -27721.283321551502 -9028.9663215568835 +exp0021 exp 1.104007405129741e-155 0.21258803067317278 -> 0.97748813933531764 0.21099037290544478 +exp0022 exp 0.027364777809295172 0.00059226603500623363 -> 1.0277424518451876 0.0006086970181346579 +exp0023 exp 0.94356313429255245 3.418530463518592 -> -2.4712285695346194 -0.70242654900218349 + +-- cases where exp(z) representable, exp(z.real) not +exp0030 exp 710.0 0.78500000000000003 -> 1.5803016909637158e+308 1.5790437551806911e+308 +exp0031 exp 710.0 -0.78500000000000003 -> 1.5803016909637158e+308 -1.5790437551806911e+308 + +-- values for which exp(x) is subnormal, or underflows to 0 +exp0040 exp -735.0 0.78500000000000003 -> 4.3976783136329355e-320 4.3942198541120468e-320 +exp0041 exp -735.0 -2.3559999999999999 -> -4.3952079854037293e-320 -4.396690182341253e-320 +exp0042 exp -745.0 0.0 -> 4.9406564584124654e-324 0.0 +exp0043 exp -745.0 0.7 -> 0.0 0.0 +exp0044 exp -745.0 2.1 -> -0.0 0.0 +exp0045 exp -745.0 3.7 -> -0.0 -0.0 +exp0046 exp -745.0 5.3 -> 0.0 -0.0 + +-- values for which exp(z) overflows +exp0050 exp 710.0 0.0 -> inf 0.0 overflow +exp0051 exp 711.0 0.7 -> inf inf overflow +exp0052 exp 710.0 1.5 -> 1.5802653829857376e+307 inf overflow +exp0053 exp 710.0 1.6 -> -6.5231579995501372e+306 inf overflow +exp0054 exp 710.0 2.8 -> -inf 7.4836177417448528e+307 overflow + +-- Additional real values (mpmath) +exp0070 exp 1e-08 0.0 -> 1.00000001000000005 0.0 +exp0071 exp 0.0003 0.0 -> 1.0003000450045003375 0.0 +exp0072 exp 0.2 0.0 -> 1.2214027581601698475 0.0 +exp0073 exp 1.0 0.0 -> 2.7182818284590452354 0.0 +exp0074 exp -1e-08 0.0 -> 0.99999999000000005 0.0 +exp0075 exp -0.0003 0.0 -> 0.99970004499550033751 0.0 +exp0076 exp -1.0 0.0 -> 0.3678794411714423216 0.0 +exp0077 exp 2.220446049250313e-16 0.0 -> 1.000000000000000222 0.0 +exp0078 exp -1.1102230246251565e-16 0.0 -> 0.99999999999999988898 0.0 +exp0079 exp 2.302585092994046 0.0 -> 10.000000000000002171 0.0 +exp0080 exp -2.302585092994046 0.0 -> 0.099999999999999978292 0.0 +exp0081 exp 709.7827 0.0 -> 1.7976699566638014654e+308 0.0 + +-- special values +exp1000 exp 0.0 0.0 -> 1.0 0.0 +exp1001 exp -0.0 0.0 -> 1.0 0.0 +exp1002 exp 0.0 inf -> nan nan invalid +exp1003 exp 2.3 inf -> nan nan invalid +exp1004 exp -0.0 inf -> nan nan invalid +exp1005 exp -2.3 inf -> nan nan invalid +exp1006 exp 0.0 nan -> nan nan +exp1007 exp 2.3 nan -> nan nan +exp1008 exp -0.0 nan -> nan nan +exp1009 exp -2.3 nan -> nan nan +exp1010 exp -inf 0.0 -> 0.0 0.0 +exp1011 exp -inf 1.4 -> 0.0 0.0 +exp1012 exp -inf 2.8 -> -0.0 0.0 +exp1013 exp -inf 4.2 -> -0.0 -0.0 +exp1014 exp -inf 5.6 -> 0.0 -0.0 +exp1015 exp -inf 7.0 -> 0.0 0.0 +exp1016 exp inf 0.0 -> inf 0.0 +exp1017 exp inf 1.4 -> inf inf +exp1018 exp inf 2.8 -> -inf inf +exp1019 exp inf 4.2 -> -inf -inf +exp1020 exp inf 5.6 -> inf -inf +exp1021 exp inf 7.0 -> inf inf +exp1022 exp -inf inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +exp1023 exp inf inf -> inf nan invalid ignore-real-sign +exp1024 exp -inf nan -> 0.0 0.0 ignore-real-sign ignore-imag-sign +exp1025 exp inf nan -> inf nan ignore-real-sign +exp1026 exp nan 0.0 -> nan 0.0 +exp1027 exp nan 2.3 -> nan nan +exp1028 exp nan inf -> nan nan +exp1029 exp nan nan -> nan nan +exp1030 exp 0.0 -0.0 -> 1.0 -0.0 +exp1031 exp -0.0 -0.0 -> 1.0 -0.0 +exp1032 exp 0.0 -inf -> nan nan invalid +exp1033 exp 2.3 -inf -> nan nan invalid +exp1034 exp -0.0 -inf -> nan nan invalid +exp1035 exp -2.3 -inf -> nan nan invalid +exp1036 exp -inf -0.0 -> 0.0 -0.0 +exp1037 exp -inf -1.4 -> 0.0 -0.0 +exp1038 exp -inf -2.8 -> -0.0 -0.0 +exp1039 exp -inf -4.2 -> -0.0 0.0 +exp1040 exp -inf -5.6 -> 0.0 0.0 +exp1041 exp -inf -7.0 -> 0.0 -0.0 +exp1042 exp inf -0.0 -> inf -0.0 +exp1043 exp inf -1.4 -> inf -inf +exp1044 exp inf -2.8 -> -inf -inf +exp1045 exp inf -4.2 -> -inf inf +exp1046 exp inf -5.6 -> inf inf +exp1047 exp inf -7.0 -> inf -inf +exp1048 exp -inf -inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +exp1049 exp inf -inf -> inf nan invalid ignore-real-sign +exp1050 exp nan -0.0 -> nan -0.0 +exp1051 exp nan -2.3 -> nan nan +exp1052 exp nan -inf -> nan nan + + +----------------------------- +-- cosh: Hyperbolic Cosine -- +----------------------------- + +-- zeros +cosh0000 cosh 0.0 0.0 -> 1.0 0.0 +cosh0001 cosh 0.0 -0.0 -> 1.0 -0.0 +cosh0002 cosh -0.0 0.0 -> 1.0 -0.0 +cosh0003 cosh -0.0 -0.0 -> 1.0 0.0 + +-- random inputs +cosh0004 cosh -0.85395264297414253 -8.8553756148671958 -> -1.1684340348021185 0.51842195359787435 +cosh0005 cosh -19.584904237211223 -0.066582627994906177 -> 159816812.23336992 10656776.050406246 +cosh0006 cosh -0.11072618401130772 -1.484820215073247 -> 0.086397164744949503 0.11054275637717284 +cosh0007 cosh -3.4764840250681752 -0.48440348288275276 -> 14.325931955190844 7.5242053548737955 +cosh0008 cosh -0.52047063604524602 -0.3603805382775585 -> 1.0653940354683802 0.19193293606252473 +cosh0009 cosh -1.39518962975995 0.0074738604700702906 -> 2.1417031027235969 -0.01415518712296308 +cosh0010 cosh -0.37107064757653541 0.14728085307856609 -> 1.0580601496776991 -0.055712531964568587 +cosh0011 cosh -5.8470200958739653 4.0021722388336292 -> -112.86220667618285 131.24734033545013 +cosh0012 cosh -0.1700261444851883 0.97167540135354513 -> 0.57208748253577946 -0.1410904820240203 +cosh0013 cosh -0.44042397902648783 1.0904791964139742 -> 0.50760322393058133 -0.40333966652010816 +cosh0014 cosh 0.052267552491867299 -3.8889011430644174 -> -0.73452303414639297 0.035540704833537134 +cosh0015 cosh 0.98000764177127453 -1.2548829247784097 -> 0.47220747341416142 -1.0879421432180316 +cosh0016 cosh 0.083594701222644008 -0.88847899930181284 -> 0.63279782419312613 -0.064954566816002285 +cosh0017 cosh 1.38173531783776 -0.43185040816732229 -> 1.9221663374671647 -0.78073830858849347 +cosh0018 cosh 0.57315681120148465 -0.22255760951027942 -> 1.1399733125173004 -0.1335512343605956 +cosh0019 cosh 1.8882512333062347 4.5024932182383797 -> -0.7041602065362691 -3.1573822131964615 +cosh0020 cosh 0.5618219206858317 0.92620452129575348 -> 0.69822380405378381 0.47309067471054522 +cosh0021 cosh 0.54361442847062591 0.64176483583018462 -> 0.92234462074193491 0.34167906495845501 +cosh0022 cosh 0.0014777403107920331 1.3682028122677661 -> 0.2012106963899549 0.001447518137863219 +cosh0023 cosh 2.218885944363501 2.0015727395883687 -> -1.94294321081968 4.1290269176083196 + +-- large real part +cosh0030 cosh 710.5 2.3519999999999999 -> -1.2967465239355998e+308 1.3076707908857333e+308 +cosh0031 cosh -710.5 0.69999999999999996 -> 1.4085466381392499e+308 -1.1864024666450239e+308 +cosh0032 cosh 720.0 0.0 -> inf 0.0 overflow + +-- Additional real values (mpmath) +cosh0050 cosh 1e-150 0.0 -> 1.0 0.0 +cosh0051 cosh 1e-18 0.0 -> 1.0 0.0 +cosh0052 cosh 1e-09 0.0 -> 1.0000000000000000005 0.0 +cosh0053 cosh 0.0003 0.0 -> 1.0000000450000003375 0.0 +cosh0054 cosh 0.2 0.0 -> 1.0200667556190758485 0.0 +cosh0055 cosh 1.0 0.0 -> 1.5430806348152437785 0.0 +cosh0056 cosh -1e-18 0.0 -> 1.0 -0.0 +cosh0057 cosh -0.0003 0.0 -> 1.0000000450000003375 -0.0 +cosh0058 cosh -1.0 0.0 -> 1.5430806348152437785 -0.0 +cosh0059 cosh 1.3169578969248168 0.0 -> 2.0000000000000001504 0.0 +cosh0060 cosh -1.3169578969248168 0.0 -> 2.0000000000000001504 -0.0 +cosh0061 cosh 17.328679513998633 0.0 -> 16777216.000000021938 0.0 +cosh0062 cosh 18.714973875118524 0.0 -> 67108864.000000043662 0.0 +cosh0063 cosh 709.7827 0.0 -> 8.9883497833190073272e+307 0.0 +cosh0064 cosh -709.7827 0.0 -> 8.9883497833190073272e+307 -0.0 + +-- special values +cosh1000 cosh 0.0 0.0 -> 1.0 0.0 +cosh1001 cosh 0.0 inf -> nan 0.0 invalid ignore-imag-sign +cosh1002 cosh 0.0 nan -> nan 0.0 ignore-imag-sign +cosh1003 cosh 2.3 inf -> nan nan invalid +cosh1004 cosh 2.3 nan -> nan nan +cosh1005 cosh inf 0.0 -> inf 0.0 +cosh1006 cosh inf 1.4 -> inf inf +cosh1007 cosh inf 2.8 -> -inf inf +cosh1008 cosh inf 4.2 -> -inf -inf +cosh1009 cosh inf 5.6 -> inf -inf +cosh1010 cosh inf 7.0 -> inf inf +cosh1011 cosh inf inf -> inf nan invalid ignore-real-sign +cosh1012 cosh inf nan -> inf nan +cosh1013 cosh nan 0.0 -> nan 0.0 ignore-imag-sign +cosh1014 cosh nan 2.3 -> nan nan +cosh1015 cosh nan inf -> nan nan +cosh1016 cosh nan nan -> nan nan +cosh1017 cosh 0.0 -0.0 -> 1.0 -0.0 +cosh1018 cosh 0.0 -inf -> nan 0.0 invalid ignore-imag-sign +cosh1019 cosh 2.3 -inf -> nan nan invalid +cosh1020 cosh inf -0.0 -> inf -0.0 +cosh1021 cosh inf -1.4 -> inf -inf +cosh1022 cosh inf -2.8 -> -inf -inf +cosh1023 cosh inf -4.2 -> -inf inf +cosh1024 cosh inf -5.6 -> inf inf +cosh1025 cosh inf -7.0 -> inf -inf +cosh1026 cosh inf -inf -> inf nan invalid ignore-real-sign +cosh1027 cosh nan -0.0 -> nan 0.0 ignore-imag-sign +cosh1028 cosh nan -2.3 -> nan nan +cosh1029 cosh nan -inf -> nan nan +cosh1030 cosh -0.0 -0.0 -> 1.0 0.0 +cosh1031 cosh -0.0 -inf -> nan 0.0 invalid ignore-imag-sign +cosh1032 cosh -0.0 nan -> nan 0.0 ignore-imag-sign +cosh1033 cosh -2.3 -inf -> nan nan invalid +cosh1034 cosh -2.3 nan -> nan nan +cosh1035 cosh -inf -0.0 -> inf 0.0 +cosh1036 cosh -inf -1.4 -> inf inf +cosh1037 cosh -inf -2.8 -> -inf inf +cosh1038 cosh -inf -4.2 -> -inf -inf +cosh1039 cosh -inf -5.6 -> inf -inf +cosh1040 cosh -inf -7.0 -> inf inf +cosh1041 cosh -inf -inf -> inf nan invalid ignore-real-sign +cosh1042 cosh -inf nan -> inf nan +cosh1043 cosh -0.0 0.0 -> 1.0 -0.0 +cosh1044 cosh -0.0 inf -> nan 0.0 invalid ignore-imag-sign +cosh1045 cosh -2.3 inf -> nan nan invalid +cosh1046 cosh -inf 0.0 -> inf -0.0 +cosh1047 cosh -inf 1.4 -> inf -inf +cosh1048 cosh -inf 2.8 -> -inf -inf +cosh1049 cosh -inf 4.2 -> -inf inf +cosh1050 cosh -inf 5.6 -> inf inf +cosh1051 cosh -inf 7.0 -> inf -inf +cosh1052 cosh -inf inf -> inf nan invalid ignore-real-sign + + +--------------------------- +-- sinh: Hyperbolic Sine -- +--------------------------- + +-- zeros +sinh0000 sinh 0.0 0.0 -> 0.0 0.0 +sinh0001 sinh 0.0 -0.0 -> 0.0 -0.0 +sinh0002 sinh -0.0 0.0 -> -0.0 0.0 +sinh0003 sinh -0.0 -0.0 -> -0.0 -0.0 + +-- random inputs +sinh0004 sinh -17.282588091462742 -0.38187948694103546 -> -14867386.857248396 -5970648.6553516639 +sinh0005 sinh -343.91971203143208 -5.0172868877771525e-22 -> -1.1518691776521735e+149 -5.7792581214689021e+127 +sinh0006 sinh -14.178122253300922 -1.9387157579351293 -> 258440.37909034826 -670452.58500946441 +sinh0007 sinh -1.0343810581686239 -1.0970235266369905 -> -0.56070858278092739 -1.4098883258046697 +sinh0008 sinh -0.066126561416368204 -0.070461584169961872 -> -0.066010558700938124 -0.070557276738637542 +sinh0009 sinh -0.37630149150308484 3.3621734692162173 -> 0.37591118119332617 -0.23447115926369383 +sinh0010 sinh -0.049941960978670055 0.40323767020414625 -> -0.045955482136329009 0.3928878494430646 +sinh0011 sinh -16.647852603903715 0.0026852219129082098 -> -8492566.5739382561 22804.480671133562 +sinh0012 sinh -1.476625314303694 0.89473773116683386 -> -1.2982943334382224 1.7966593367791204 +sinh0013 sinh -422.36429577556913 0.10366634502307912 -> -1.3400321008920044e+183 1.3941600948045599e+182 +sinh0014 sinh 0.09108340745641981 -0.40408227416070353 -> 0.083863724802237902 -0.39480716553935602 +sinh0015 sinh 2.036064132067386 -2.6831729961386239 -> -3.37621124363175 -1.723868330002817 +sinh0016 sinh 2.5616717223063317 -0.0078978498622717767 -> 6.4399415853815869 -0.051472264400722133 +sinh0017 sinh 0.336804011985188 -6.5654622971649337 -> 0.32962499307574578 -0.29449170159995197 +sinh0018 sinh 0.23774603755649693 -0.92467195799232049 -> 0.14449839490603389 -0.82109449053556793 +sinh0019 sinh 0.0011388273541465494 1.9676196882949855 -> -0.00044014605389634999 0.92229398407098806 +sinh0020 sinh 3.2443870105663759 0.8054287559616895 -> 8.8702890778527426 9.2610748597042196 +sinh0021 sinh 0.040628908857054738 0.098206391190944958 -> 0.04044426841671233 0.098129544739707392 +sinh0022 sinh 4.7252283918217696e-30 9.1198155642656697 -> -4.5071980561644404e-30 0.30025730701661713 +sinh0023 sinh 0.043713693678420068 0.22512549887532657 -> 0.042624198673416713 0.22344201231217961 + +-- large real part +sinh0030 sinh 710.5 -2.3999999999999999 -> -1.3579970564885919e+308 -1.24394470907798e+308 +sinh0031 sinh -710.5 0.80000000000000004 -> -1.2830671601735164e+308 1.3210954193997678e+308 +sinh0032 sinh 720.0 0.0 -> inf 0.0 overflow + +-- Additional real values (mpmath) +sinh0050 sinh 1e-100 0.0 -> 1.00000000000000002e-100 0.0 +sinh0051 sinh 5e-17 0.0 -> 4.9999999999999998955e-17 0.0 +sinh0052 sinh 1e-16 0.0 -> 9.999999999999999791e-17 0.0 +sinh0053 sinh 3.7e-08 0.0 -> 3.7000000000000008885e-8 0.0 +sinh0054 sinh 0.001 0.0 -> 0.0010000001666666750208 0.0 +sinh0055 sinh 0.2 0.0 -> 0.20133600254109399895 0.0 +sinh0056 sinh 1.0 0.0 -> 1.1752011936438014569 0.0 +sinh0057 sinh -3.7e-08 0.0 -> -3.7000000000000008885e-8 0.0 +sinh0058 sinh -0.001 0.0 -> -0.0010000001666666750208 0.0 +sinh0059 sinh -1.0 0.0 -> -1.1752011936438014569 0.0 +sinh0060 sinh 1.4436354751788103 0.0 -> 1.9999999999999999078 0.0 +sinh0061 sinh -1.4436354751788103 0.0 -> -1.9999999999999999078 0.0 +sinh0062 sinh 17.328679513998633 0.0 -> 16777215.999999992136 0.0 +sinh0063 sinh 18.714973875118524 0.0 -> 67108864.000000036211 0.0 +sinh0064 sinh 709.7827 0.0 -> 8.9883497833190073272e+307 0.0 +sinh0065 sinh -709.7827 0.0 -> -8.9883497833190073272e+307 0.0 + +-- special values +sinh1000 sinh 0.0 0.0 -> 0.0 0.0 +sinh1001 sinh 0.0 inf -> 0.0 nan invalid ignore-real-sign +sinh1002 sinh 0.0 nan -> 0.0 nan ignore-real-sign +sinh1003 sinh 2.3 inf -> nan nan invalid +sinh1004 sinh 2.3 nan -> nan nan +sinh1005 sinh inf 0.0 -> inf 0.0 +sinh1006 sinh inf 1.4 -> inf inf +sinh1007 sinh inf 2.8 -> -inf inf +sinh1008 sinh inf 4.2 -> -inf -inf +sinh1009 sinh inf 5.6 -> inf -inf +sinh1010 sinh inf 7.0 -> inf inf +sinh1011 sinh inf inf -> inf nan invalid ignore-real-sign +sinh1012 sinh inf nan -> inf nan ignore-real-sign +sinh1013 sinh nan 0.0 -> nan 0.0 +sinh1014 sinh nan 2.3 -> nan nan +sinh1015 sinh nan inf -> nan nan +sinh1016 sinh nan nan -> nan nan +sinh1017 sinh 0.0 -0.0 -> 0.0 -0.0 +sinh1018 sinh 0.0 -inf -> 0.0 nan invalid ignore-real-sign +sinh1019 sinh 2.3 -inf -> nan nan invalid +sinh1020 sinh inf -0.0 -> inf -0.0 +sinh1021 sinh inf -1.4 -> inf -inf +sinh1022 sinh inf -2.8 -> -inf -inf +sinh1023 sinh inf -4.2 -> -inf inf +sinh1024 sinh inf -5.6 -> inf inf +sinh1025 sinh inf -7.0 -> inf -inf +sinh1026 sinh inf -inf -> inf nan invalid ignore-real-sign +sinh1027 sinh nan -0.0 -> nan -0.0 +sinh1028 sinh nan -2.3 -> nan nan +sinh1029 sinh nan -inf -> nan nan +sinh1030 sinh -0.0 -0.0 -> -0.0 -0.0 +sinh1031 sinh -0.0 -inf -> 0.0 nan invalid ignore-real-sign +sinh1032 sinh -0.0 nan -> 0.0 nan ignore-real-sign +sinh1033 sinh -2.3 -inf -> nan nan invalid +sinh1034 sinh -2.3 nan -> nan nan +sinh1035 sinh -inf -0.0 -> -inf -0.0 +sinh1036 sinh -inf -1.4 -> -inf -inf +sinh1037 sinh -inf -2.8 -> inf -inf +sinh1038 sinh -inf -4.2 -> inf inf +sinh1039 sinh -inf -5.6 -> -inf inf +sinh1040 sinh -inf -7.0 -> -inf -inf +sinh1041 sinh -inf -inf -> inf nan invalid ignore-real-sign +sinh1042 sinh -inf nan -> inf nan ignore-real-sign +sinh1043 sinh -0.0 0.0 -> -0.0 0.0 +sinh1044 sinh -0.0 inf -> 0.0 nan invalid ignore-real-sign +sinh1045 sinh -2.3 inf -> nan nan invalid +sinh1046 sinh -inf 0.0 -> -inf 0.0 +sinh1047 sinh -inf 1.4 -> -inf inf +sinh1048 sinh -inf 2.8 -> inf inf +sinh1049 sinh -inf 4.2 -> inf -inf +sinh1050 sinh -inf 5.6 -> -inf -inf +sinh1051 sinh -inf 7.0 -> -inf inf +sinh1052 sinh -inf inf -> inf nan invalid ignore-real-sign + + +------------------------------ +-- tanh: Hyperbolic Tangent -- +------------------------------ + +-- Disabled test: replaced by test_math.testTanhSign() +-- and test_cmath.testTanhSign() + +-- -- zeros +-- tanh0000 tanh 0.0 0.0 -> 0.0 0.0 +-- tanh0001 tanh 0.0 -0.0 -> 0.0 -0.0 +-- tanh0002 tanh -0.0 0.0 -> -0.0 0.0 +-- tanh0003 tanh -0.0 -0.0 -> -0.0 -0.0 + +-- random inputs +tanh0004 tanh -21.200500450664993 -1.6970729480342996 -> -1.0 1.9241352344849399e-19 +tanh0005 tanh -0.34158771504251928 -8.0848504951747131 -> -2.123711225855613 1.2827526782026006 +tanh0006 tanh -15.454144725193689 -0.23619582288265617 -> -0.99999999999993283 -3.4336684248260036e-14 +tanh0007 tanh -7.6103163119661952 -0.7802748320307008 -> -0.99999999497219438 -4.9064845343755437e-07 +tanh0008 tanh -0.15374717235792129 -0.6351086327306138 -> -0.23246081703561869 -0.71083467433910219 +tanh0009 tanh -0.49101115474392465 0.09723001264886301 -> -0.45844445715492133 0.077191158541805888 +tanh0010 tanh -0.10690612157664491 2.861612800856395 -> -0.11519761626257358 -0.28400488355647507 +tanh0011 tanh -0.91505774192066702 1.5431174597727007 -> -1.381109893068114 0.025160819663709356 +tanh0012 tanh -0.057433367093792223 0.35491159541246459 -> -0.065220499046696953 0.36921788332369498 +tanh0013 tanh -1.3540418621233514 0.18969415642242535 -> -0.88235642861151387 0.043764069984411721 +tanh0014 tanh 0.94864783961003529 -0.11333689578867717 -> 0.74348401861861368 -0.051271042543855221 +tanh0015 tanh 1.9591698133845488 -0.0029654444904578339 -> 0.9610270776968135 -0.00022664240049212933 +tanh0016 tanh 1.0949715796669197 -0.24706642853984456 -> 0.81636574501369386 -0.087767436914149954 +tanh0017 tanh 5770428.2113731047 -3.7160580339833165 -> 1.0 -0.0 +tanh0018 tanh 1.5576782321399629 -1.0357943787966468 -> 1.0403002384895388 -0.081126347894671463 +tanh0019 tanh 0.62378536230552961 2.3471393579560216 -> 0.85582499238960363 -0.53569473646842869 +tanh0020 tanh 17.400628602508025 9.3987059533841979 -> 0.99999999999999845 -8.0175867720530832e-17 +tanh0021 tanh 0.15026177509871896 0.50630349159505472 -> 0.19367536571827768 0.53849847858853661 +tanh0022 tanh 0.57433977530711167 1.0071604546265627 -> 1.0857848159262844 0.69139213955872214 +tanh0023 tanh 0.16291181500449456 0.006972810241567544 -> 0.16149335907551157 0.0067910772903467817 + +-- large real part +tanh0030 tanh 710 0.13 -> 1.0 0.0 +tanh0031 tanh -711 7.4000000000000004 -> -1.0 0.0 +tanh0032 tanh 1000 -2.3199999999999998 -> 1.0 0.0 +tanh0033 tanh -1.0000000000000001e+300 -9.6699999999999999 -> -1.0 -0.0 + +-- Additional real values (mpmath) +tanh0050 tanh 1e-100 0.0 -> 1.00000000000000002e-100 0.0 +tanh0051 tanh 5e-17 0.0 -> 4.9999999999999998955e-17 0.0 +tanh0052 tanh 1e-16 0.0 -> 9.999999999999999791e-17 0.0 +tanh0053 tanh 3.7e-08 0.0 -> 3.6999999999999983559e-8 0.0 +tanh0054 tanh 0.001 0.0 -> 0.00099999966666680002076 0.0 +tanh0055 tanh 0.2 0.0 -> 0.19737532022490401141 0.0 +tanh0056 tanh 1.0 0.0 -> 0.76159415595576488812 0.0 +tanh0057 tanh -3.7e-08 0.0 -> -3.6999999999999983559e-8 0.0 +tanh0058 tanh -0.001 0.0 -> -0.00099999966666680002076 0.0 +tanh0059 tanh -1.0 0.0 -> -0.76159415595576488812 0.0 +tanh0060 tanh 0.5493061443340549 0.0 -> 0.50000000000000003402 0.0 +tanh0061 tanh -0.5493061443340549 0.0 -> -0.50000000000000003402 0.0 +tanh0062 tanh 17.328679513998633 0.0 -> 0.99999999999999822364 0.0 +tanh0063 tanh 18.714973875118524 0.0 -> 0.99999999999999988898 0.0 +tanh0064 tanh 711 0.0 -> 1.0 0.0 +tanh0065 tanh 1.797e+308 0.0 -> 1.0 0.0 + +--special values +tanh1000 tanh 0.0 0.0 -> 0.0 0.0 +tanh1001 tanh 0.0 inf -> nan nan invalid +tanh1002 tanh 2.3 inf -> nan nan invalid +tanh1003 tanh 0.0 nan -> nan nan +tanh1004 tanh 2.3 nan -> nan nan +tanh1005 tanh inf 0.0 -> 1.0 0.0 +tanh1006 tanh inf 0.7 -> 1.0 0.0 +tanh1007 tanh inf 1.4 -> 1.0 0.0 +tanh1008 tanh inf 2.1 -> 1.0 -0.0 +tanh1009 tanh inf 2.8 -> 1.0 -0.0 +tanh1010 tanh inf 3.5 -> 1.0 0.0 +tanh1011 tanh inf inf -> 1.0 0.0 ignore-imag-sign +tanh1012 tanh inf nan -> 1.0 0.0 ignore-imag-sign +tanh1013 tanh nan 0.0 -> nan 0.0 +tanh1014 tanh nan 2.3 -> nan nan +tanh1015 tanh nan inf -> nan nan +tanh1016 tanh nan nan -> nan nan +tanh1017 tanh 0.0 -0.0 -> 0.0 -0.0 +tanh1018 tanh 0.0 -inf -> nan nan invalid +tanh1019 tanh 2.3 -inf -> nan nan invalid +tanh1020 tanh inf -0.0 -> 1.0 -0.0 +tanh1021 tanh inf -0.7 -> 1.0 -0.0 +tanh1022 tanh inf -1.4 -> 1.0 -0.0 +tanh1023 tanh inf -2.1 -> 1.0 0.0 +tanh1024 tanh inf -2.8 -> 1.0 0.0 +tanh1025 tanh inf -3.5 -> 1.0 -0.0 +tanh1026 tanh inf -inf -> 1.0 0.0 ignore-imag-sign +tanh1027 tanh nan -0.0 -> nan -0.0 +tanh1028 tanh nan -2.3 -> nan nan +tanh1029 tanh nan -inf -> nan nan +tanh1030 tanh -0.0 -0.0 -> -0.0 -0.0 +tanh1031 tanh -0.0 -inf -> nan nan invalid +tanh1032 tanh -2.3 -inf -> nan nan invalid +tanh1033 tanh -0.0 nan -> nan nan +tanh1034 tanh -2.3 nan -> nan nan +tanh1035 tanh -inf -0.0 -> -1.0 -0.0 +tanh1036 tanh -inf -0.7 -> -1.0 -0.0 +tanh1037 tanh -inf -1.4 -> -1.0 -0.0 +tanh1038 tanh -inf -2.1 -> -1.0 0.0 +tanh1039 tanh -inf -2.8 -> -1.0 0.0 +tanh1040 tanh -inf -3.5 -> -1.0 -0.0 +tanh1041 tanh -inf -inf -> -1.0 0.0 ignore-imag-sign +tanh1042 tanh -inf nan -> -1.0 0.0 ignore-imag-sign +tanh1043 tanh -0.0 0.0 -> -0.0 0.0 +tanh1044 tanh -0.0 inf -> nan nan invalid +tanh1045 tanh -2.3 inf -> nan nan invalid +tanh1046 tanh -inf 0.0 -> -1.0 0.0 +tanh1047 tanh -inf 0.7 -> -1.0 0.0 +tanh1048 tanh -inf 1.4 -> -1.0 0.0 +tanh1049 tanh -inf 2.1 -> -1.0 -0.0 +tanh1050 tanh -inf 2.8 -> -1.0 -0.0 +tanh1051 tanh -inf 3.5 -> -1.0 0.0 +tanh1052 tanh -inf inf -> -1.0 0.0 ignore-imag-sign + + +----------------- +-- cos: Cosine -- +----------------- + +-- zeros +cos0000 cos 0.0 0.0 -> 1.0 -0.0 +cos0001 cos 0.0 -0.0 -> 1.0 0.0 +cos0002 cos -0.0 0.0 -> 1.0 0.0 +cos0003 cos -0.0 -0.0 -> 1.0 -0.0 + +-- random inputs +cos0004 cos -2.0689194692073034 -0.0016802181751734313 -> -0.47777827208561469 -0.0014760401501695971 +cos0005 cos -0.4209627318177977 -1.8238516774258027 -> 2.9010402201444108 -1.2329207042329617 +cos0006 cos -1.9402181630694557 -2.9751857392891217 -> -3.5465459297970985 -9.1119163586282248 +cos0007 cos -3.3118320290191616 -0.87871302909286142 -> -1.3911528636565498 0.16878141517391701 +cos0008 cos -4.9540404623376872 -0.57949232239026827 -> 0.28062445586552065 0.59467861308508008 +cos0009 cos -0.45374584316245026 1.3950283448373935 -> 1.9247665574290578 0.83004572204761107 +cos0010 cos -0.42578172040176843 1.2715881615413049 -> 1.7517161459489148 0.67863902697363332 +cos0011 cos -0.13862985354300136 0.43587635877670328 -> 1.0859880290361912 0.062157548146672272 +cos0012 cos -0.11073221308966584 9.9384082307326475e-15 -> 0.99387545040722947 1.0982543264065479e-15 +cos0013 cos -1.5027633662054623e-07 0.0069668060249955498 -> 1.0000242682912412 1.0469545565660995e-09 +cos0014 cos 4.9728645490503052 -0.00027479808860952822 -> 0.25754011731975501 -0.00026552849549083186 +cos0015 cos 7.81969303486719 -0.79621523445878783 -> 0.045734882501585063 0.88253139933082991 +cos0016 cos 0.13272421880766716 -0.74668445308718201 -> 1.2806012244432847 0.10825373267437005 +cos0017 cos 4.2396521985973274 -2.2178848380884881 -> -2.1165117057056855 -4.0416492444641401 +cos0018 cos 1.1622206624927296 -0.50400115461197081 -> 0.44884072613370379 0.4823469915034318 +cos0019 cos 1.628772864620884e-08 0.58205705428979282 -> 1.1742319995791435 -1.0024839481956604e-08 +cos0020 cos 2.6385212606111241 2.9886107100937296 -> -8.7209475927161417 -4.7748352107199796 +cos0021 cos 4.8048375263775256 0.0062248852898515658 -> 0.092318702015846243 0.0061983430422306142 +cos0022 cos 7.9914515433858515 0.71659966615501436 -> -0.17375439906936566 -0.77217043527294582 +cos0023 cos 0.45124351152540226 1.6992693993812158 -> 2.543477948972237 -1.1528193694875477 + +-- Additional real values (mpmath) +cos0050 cos 1e-150 0.0 -> 1.0 -0.0 +cos0051 cos 1e-18 0.0 -> 1.0 -0.0 +cos0052 cos 1e-09 0.0 -> 0.9999999999999999995 -0.0 +cos0053 cos 0.0003 0.0 -> 0.9999999550000003375 -0.0 +cos0054 cos 0.2 0.0 -> 0.98006657784124162892 -0.0 +cos0055 cos 1.0 0.0 -> 0.5403023058681397174 -0.0 +cos0056 cos -1e-18 0.0 -> 1.0 0.0 +cos0057 cos -0.0003 0.0 -> 0.9999999550000003375 0.0 +cos0058 cos -1.0 0.0 -> 0.5403023058681397174 0.0 +cos0059 cos 1.0471975511965976 0.0 -> 0.50000000000000009945 -0.0 +cos0060 cos 2.5707963267948966 0.0 -> -0.84147098480789647357 -0.0 +cos0061 cos -2.5707963267948966 0.0 -> -0.84147098480789647357 0.0 +cos0062 cos 7.225663103256523 0.0 -> 0.58778525229247407559 -0.0 +cos0063 cos -8.79645943005142 0.0 -> -0.80901699437494722255 0.0 + +-- special values +cos1000 cos -0.0 0.0 -> 1.0 0.0 +cos1001 cos -inf 0.0 -> nan 0.0 invalid ignore-imag-sign +cos1002 cos nan 0.0 -> nan 0.0 ignore-imag-sign +cos1003 cos -inf 2.2999999999999998 -> nan nan invalid +cos1004 cos nan 2.2999999999999998 -> nan nan +cos1005 cos -0.0 inf -> inf 0.0 +cos1006 cos -1.3999999999999999 inf -> inf inf +cos1007 cos -2.7999999999999998 inf -> -inf inf +cos1008 cos -4.2000000000000002 inf -> -inf -inf +cos1009 cos -5.5999999999999996 inf -> inf -inf +cos1010 cos -7.0 inf -> inf inf +cos1011 cos -inf inf -> inf nan invalid ignore-real-sign +cos1012 cos nan inf -> inf nan +cos1013 cos -0.0 nan -> nan 0.0 ignore-imag-sign +cos1014 cos -2.2999999999999998 nan -> nan nan +cos1015 cos -inf nan -> nan nan +cos1016 cos nan nan -> nan nan +cos1017 cos 0.0 0.0 -> 1.0 -0.0 +cos1018 cos inf 0.0 -> nan 0.0 invalid ignore-imag-sign +cos1019 cos inf 2.2999999999999998 -> nan nan invalid +cos1020 cos 0.0 inf -> inf -0.0 +cos1021 cos 1.3999999999999999 inf -> inf -inf +cos1022 cos 2.7999999999999998 inf -> -inf -inf +cos1023 cos 4.2000000000000002 inf -> -inf inf +cos1024 cos 5.5999999999999996 inf -> inf inf +cos1025 cos 7.0 inf -> inf -inf +cos1026 cos inf inf -> inf nan invalid ignore-real-sign +cos1027 cos 0.0 nan -> nan 0.0 ignore-imag-sign +cos1028 cos 2.2999999999999998 nan -> nan nan +cos1029 cos inf nan -> nan nan +cos1030 cos 0.0 -0.0 -> 1.0 0.0 +cos1031 cos inf -0.0 -> nan 0.0 invalid ignore-imag-sign +cos1032 cos nan -0.0 -> nan 0.0 ignore-imag-sign +cos1033 cos inf -2.2999999999999998 -> nan nan invalid +cos1034 cos nan -2.2999999999999998 -> nan nan +cos1035 cos 0.0 -inf -> inf 0.0 +cos1036 cos 1.3999999999999999 -inf -> inf inf +cos1037 cos 2.7999999999999998 -inf -> -inf inf +cos1038 cos 4.2000000000000002 -inf -> -inf -inf +cos1039 cos 5.5999999999999996 -inf -> inf -inf +cos1040 cos 7.0 -inf -> inf inf +cos1041 cos inf -inf -> inf nan invalid ignore-real-sign +cos1042 cos nan -inf -> inf nan +cos1043 cos -0.0 -0.0 -> 1.0 -0.0 +cos1044 cos -inf -0.0 -> nan 0.0 invalid ignore-imag-sign +cos1045 cos -inf -2.2999999999999998 -> nan nan invalid +cos1046 cos -0.0 -inf -> inf -0.0 +cos1047 cos -1.3999999999999999 -inf -> inf -inf +cos1048 cos -2.7999999999999998 -inf -> -inf -inf +cos1049 cos -4.2000000000000002 -inf -> -inf inf +cos1050 cos -5.5999999999999996 -inf -> inf inf +cos1051 cos -7.0 -inf -> inf -inf +cos1052 cos -inf -inf -> inf nan invalid ignore-real-sign + + +--------------- +-- sin: Sine -- +--------------- + +-- zeros +sin0000 sin 0.0 0.0 -> 0.0 0.0 +sin0001 sin 0.0 -0.0 -> 0.0 -0.0 +sin0002 sin -0.0 0.0 -> -0.0 0.0 +sin0003 sin -0.0 -0.0 -> -0.0 -0.0 + +-- random inputs +sin0004 sin -0.18691829163163759 -0.74388741985507034 -> -0.2396636733773444 -0.80023231101856751 +sin0005 sin -0.45127453702459158 -461.81339920716164 -> -7.9722299331077877e+199 -1.6450205811004628e+200 +sin0006 sin -0.47669228345768921 -2.7369936564987514 -> -3.557238022267124 -6.8308030771226615 +sin0007 sin -0.31024285525950857 -1.4869219939188296 -> -0.70972676047175209 -1.9985029635426839 +sin0008 sin -4.4194573407025608 -1.405999210989288 -> 2.0702480800802685 0.55362250792180601 +sin0009 sin -1.7810832046434898e-05 0.0016439555384379083 -> -1.7810856113185261e-05 0.0016439562786668375 +sin0010 sin -0.8200017874897666 0.61724876887771929 -> -0.8749078195948865 0.44835295550987758 +sin0011 sin -1.4536502806107114 0.63998575534150415 -> -1.2035709929437679 0.080012187489163708 +sin0012 sin -2.2653412155506079 0.13172760685583729 -> -0.77502093809190431 -0.084554426868229532 +sin0013 sin -0.02613983069491858 0.18404766597776073 -> -0.026580778863127943 0.18502525396735642 +sin0014 sin 1.5743065001054617 -0.53125574272642029 -> 1.1444596332092725 0.0019537598099352077 +sin0015 sin 7.3833101791283289e-20 -0.16453221324236217 -> 7.4834720674379429e-20 -0.16527555646466915 +sin0016 sin 0.34763834641254038 -2.8377416421089565 -> 2.918883541504663 -8.0002718053250224 +sin0017 sin 0.077105785180421563 -0.090056027316200674 -> 0.077341973814471304 -0.089909869380524587 +sin0018 sin 3.9063227798142329e-17 -0.05954098654295524 -> 3.9132490348956512e-17 -0.059576172859837351 +sin0019 sin 0.57333917932544598 8.7785221430594696e-06 -> 0.54244029338302935 7.3747869125301368e-06 +sin0020 sin 0.024861722816513169 0.33044620756118515 -> 0.026228801369651 0.3363889671570689 +sin0021 sin 1.4342727387492671 0.81361889790284347 -> 1.3370960060947923 0.12336137961387163 +sin0022 sin 1.1518087354403725 4.8597235966150558 -> 58.919141989603041 26.237003403758852 +sin0023 sin 0.00087773078406649192 34.792379211312095 -> 565548145569.38245 644329685822700.62 + +-- Additional real values (mpmath) +sin0050 sin 1e-100 0.0 -> 1.00000000000000002e-100 0.0 +sin0051 sin 3.7e-08 0.0 -> 3.6999999999999992001e-8 0.0 +sin0052 sin 0.001 0.0 -> 0.00099999983333334168748 0.0 +sin0053 sin 0.2 0.0 -> 0.19866933079506122634 0.0 +sin0054 sin 1.0 0.0 -> 0.84147098480789650665 0.0 +sin0055 sin -3.7e-08 0.0 -> -3.6999999999999992001e-8 0.0 +sin0056 sin -0.001 0.0 -> -0.00099999983333334168748 0.0 +sin0057 sin -1.0 0.0 -> -0.84147098480789650665 0.0 +sin0058 sin 0.5235987755982989 0.0 -> 0.50000000000000004642 0.0 +sin0059 sin -0.5235987755982989 0.0 -> -0.50000000000000004642 0.0 +sin0060 sin 2.6179938779914944 0.0 -> 0.49999999999999996018 -0.0 +sin0061 sin -2.6179938779914944 0.0 -> -0.49999999999999996018 -0.0 +sin0062 sin 7.225663103256523 0.0 -> 0.80901699437494673648 0.0 +sin0063 sin -8.79645943005142 0.0 -> -0.58778525229247340658 -0.0 + +-- special values +sin1000 sin -0.0 0.0 -> -0.0 0.0 +sin1001 sin -inf 0.0 -> nan 0.0 invalid ignore-imag-sign +sin1002 sin nan 0.0 -> nan 0.0 ignore-imag-sign +sin1003 sin -inf 2.2999999999999998 -> nan nan invalid +sin1004 sin nan 2.2999999999999998 -> nan nan +sin1005 sin -0.0 inf -> -0.0 inf +sin1006 sin -1.3999999999999999 inf -> -inf inf +sin1007 sin -2.7999999999999998 inf -> -inf -inf +sin1008 sin -4.2000000000000002 inf -> inf -inf +sin1009 sin -5.5999999999999996 inf -> inf inf +sin1010 sin -7.0 inf -> -inf inf +sin1011 sin -inf inf -> nan inf invalid ignore-imag-sign +sin1012 sin nan inf -> nan inf ignore-imag-sign +sin1013 sin -0.0 nan -> -0.0 nan +sin1014 sin -2.2999999999999998 nan -> nan nan +sin1015 sin -inf nan -> nan nan +sin1016 sin nan nan -> nan nan +sin1017 sin 0.0 0.0 -> 0.0 0.0 +sin1018 sin inf 0.0 -> nan 0.0 invalid ignore-imag-sign +sin1019 sin inf 2.2999999999999998 -> nan nan invalid +sin1020 sin 0.0 inf -> 0.0 inf +sin1021 sin 1.3999999999999999 inf -> inf inf +sin1022 sin 2.7999999999999998 inf -> inf -inf +sin1023 sin 4.2000000000000002 inf -> -inf -inf +sin1024 sin 5.5999999999999996 inf -> -inf inf +sin1025 sin 7.0 inf -> inf inf +sin1026 sin inf inf -> nan inf invalid ignore-imag-sign +sin1027 sin 0.0 nan -> 0.0 nan +sin1028 sin 2.2999999999999998 nan -> nan nan +sin1029 sin inf nan -> nan nan +sin1030 sin 0.0 -0.0 -> 0.0 -0.0 +sin1031 sin inf -0.0 -> nan 0.0 invalid ignore-imag-sign +sin1032 sin nan -0.0 -> nan 0.0 ignore-imag-sign +sin1033 sin inf -2.2999999999999998 -> nan nan invalid +sin1034 sin nan -2.2999999999999998 -> nan nan +sin1035 sin 0.0 -inf -> 0.0 -inf +sin1036 sin 1.3999999999999999 -inf -> inf -inf +sin1037 sin 2.7999999999999998 -inf -> inf inf +sin1038 sin 4.2000000000000002 -inf -> -inf inf +sin1039 sin 5.5999999999999996 -inf -> -inf -inf +sin1040 sin 7.0 -inf -> inf -inf +sin1041 sin inf -inf -> nan inf invalid ignore-imag-sign +sin1042 sin nan -inf -> nan inf ignore-imag-sign +sin1043 sin -0.0 -0.0 -> -0.0 -0.0 +sin1044 sin -inf -0.0 -> nan 0.0 invalid ignore-imag-sign +sin1045 sin -inf -2.2999999999999998 -> nan nan invalid +sin1046 sin -0.0 -inf -> -0.0 -inf +sin1047 sin -1.3999999999999999 -inf -> -inf -inf +sin1048 sin -2.7999999999999998 -inf -> -inf inf +sin1049 sin -4.2000000000000002 -inf -> inf inf +sin1050 sin -5.5999999999999996 -inf -> inf -inf +sin1051 sin -7.0 -inf -> -inf -inf +sin1052 sin -inf -inf -> nan inf invalid ignore-imag-sign + + +------------------ +-- tan: Tangent -- +------------------ + +-- zeros +tan0000 tan 0.0 0.0 -> 0.0 0.0 +tan0001 tan 0.0 -0.0 -> 0.0 -0.0 +tan0002 tan -0.0 0.0 -> -0.0 0.0 +tan0003 tan -0.0 -0.0 -> -0.0 -0.0 + +-- random inputs +tan0004 tan -0.56378561833861074 -1.7110276237187664e+73 -> -0.0 -1.0 +tan0005 tan -3.5451633993471915e-12 -2.855471863564059 -> -4.6622441304889575e-14 -0.99340273843093951 +tan0006 tan -2.502442719638696 -0.26742234390504221 -> 0.66735215252994995 -0.39078997935420956 +tan0007 tan -0.87639597720371365 -55.586225523280206 -> -1.0285264565948176e-48 -1.0 +tan0008 tan -0.015783869596427243 -520.05944436039272 -> -0.0 -1.0 +tan0009 tan -0.84643549990725164 2.0749097935396343 -> -0.031412661676959573 1.0033548479526764 +tan0010 tan -0.43613792248559646 8.1082741629458059 -> -1.3879848444644593e-07 0.99999988344224011 +tan0011 tan -1.0820906367833114 0.28571868992480248 -> -1.3622485737936536 0.99089269377971245 +tan0012 tan -1.1477859580220084 1.9021637002708041 -> -0.034348450042071196 1.0293954097901687 +tan0013 tan -0.12465543176953409 3.0606851016344815e-05 -> -0.12530514290387343 3.1087420769945479e-05 +tan0014 tan 3.7582848717525343 -692787020.44038939 -> 0.0 -1.0 +tan0015 tan 2.2321967655142176e-06 -10.090069423008169 -> 1.5369846120622643e-14 -0.99999999655723759 +tan0016 tan 0.88371172390245012 -1.1635053630132823 -> 0.19705017118625889 -1.0196452280843129 +tan0017 tan 2.1347414231849267 -1.9311339960416831 -> -0.038663576915982524 -1.0174399993980778 +tan0018 tan 5.9027945255899974 -2.1574195684607135e-183 -> -0.39986591539281496 -2.5023753167976915e-183 +tan0019 tan 0.44811489490805362 683216075670.07556 -> 0.0 1.0 +tan0020 tan 4.1459766396068325 12.523017205605756 -> 2.4022514758988068e-11 1.0000000000112499 +tan0021 tan 1.7809617968443272 1.5052381702853379 -> -0.044066222118946903 1.0932684517702778 +tan0022 tan 1.1615313900880577 1.7956298728647107 -> 0.041793186826390362 1.0375339546034792 +tan0023 tan 0.067014779477908945 5.8517361577457097 -> 2.2088639754800034e-06 0.9999836182420061 + +-- Additional real values (mpmath) +tan0050 tan 1e-100 0.0 -> 1.00000000000000002e-100 0.0 +tan0051 tan 3.7e-08 0.0 -> 3.7000000000000017328e-8 0.0 +tan0052 tan 0.001 0.0 -> 0.0010000003333334666875 0.0 +tan0053 tan 0.2 0.0 -> 0.20271003550867249488 0.0 +tan0054 tan 1.0 0.0 -> 1.5574077246549022305 0.0 +tan0055 tan -3.7e-08 0.0 -> -3.7000000000000017328e-8 0.0 +tan0056 tan -0.001 0.0 -> -0.0010000003333334666875 0.0 +tan0057 tan -1.0 0.0 -> -1.5574077246549022305 0.0 +tan0058 tan 0.4636476090008061 0.0 -> 0.49999999999999997163 0.0 +tan0059 tan -0.4636476090008061 0.0 -> -0.49999999999999997163 0.0 +tan0060 tan 1.1071487177940904 0.0 -> 1.9999999999999995298 0.0 +tan0061 tan -1.1071487177940904 0.0 -> -1.9999999999999995298 0.0 +tan0062 tan 1.5 0.0 -> 14.101419947171719388 0.0 +tan0063 tan 1.57 0.0 -> 1255.7655915007896475 0.0 +tan0064 tan 1.5707963267948961 0.0 -> 1978937966095219.0538 0.0 +tan0065 tan 7.225663103256523 0.0 -> 1.3763819204711701522 0.0 +tan0066 tan -8.79645943005142 0.0 -> 0.7265425280053614098 0.0 + +-- special values +tan1000 tan -0.0 0.0 -> -0.0 0.0 +tan1001 tan -inf 0.0 -> nan nan invalid +tan1002 tan -inf 2.2999999999999998 -> nan nan invalid +tan1003 tan nan 0.0 -> nan nan +tan1004 tan nan 2.2999999999999998 -> nan nan +tan1005 tan -0.0 inf -> -0.0 1.0 +tan1006 tan -0.69999999999999996 inf -> -0.0 1.0 +tan1007 tan -1.3999999999999999 inf -> -0.0 1.0 +tan1008 tan -2.1000000000000001 inf -> 0.0 1.0 +tan1009 tan -2.7999999999999998 inf -> 0.0 1.0 +tan1010 tan -3.5 inf -> -0.0 1.0 +tan1011 tan -inf inf -> -0.0 1.0 ignore-real-sign +tan1012 tan nan inf -> -0.0 1.0 ignore-real-sign +tan1013 tan -0.0 nan -> -0.0 nan +tan1014 tan -2.2999999999999998 nan -> nan nan +tan1015 tan -inf nan -> nan nan +tan1016 tan nan nan -> nan nan +tan1017 tan 0.0 0.0 -> 0.0 0.0 +tan1018 tan inf 0.0 -> nan nan invalid +tan1019 tan inf 2.2999999999999998 -> nan nan invalid +tan1020 tan 0.0 inf -> 0.0 1.0 +tan1021 tan 0.69999999999999996 inf -> 0.0 1.0 +tan1022 tan 1.3999999999999999 inf -> 0.0 1.0 +tan1023 tan 2.1000000000000001 inf -> -0.0 1.0 +tan1024 tan 2.7999999999999998 inf -> -0.0 1.0 +tan1025 tan 3.5 inf -> 0.0 1.0 +tan1026 tan inf inf -> -0.0 1.0 ignore-real-sign +tan1027 tan 0.0 nan -> 0.0 nan +tan1028 tan 2.2999999999999998 nan -> nan nan +tan1029 tan inf nan -> nan nan +tan1030 tan 0.0 -0.0 -> 0.0 -0.0 +tan1031 tan inf -0.0 -> nan nan invalid +tan1032 tan inf -2.2999999999999998 -> nan nan invalid +tan1033 tan nan -0.0 -> nan nan +tan1034 tan nan -2.2999999999999998 -> nan nan +tan1035 tan 0.0 -inf -> 0.0 -1.0 +tan1036 tan 0.69999999999999996 -inf -> 0.0 -1.0 +tan1037 tan 1.3999999999999999 -inf -> 0.0 -1.0 +tan1038 tan 2.1000000000000001 -inf -> -0.0 -1.0 +tan1039 tan 2.7999999999999998 -inf -> -0.0 -1.0 +tan1040 tan 3.5 -inf -> 0.0 -1.0 +tan1041 tan inf -inf -> -0.0 -1.0 ignore-real-sign +tan1042 tan nan -inf -> -0.0 -1.0 ignore-real-sign +tan1043 tan -0.0 -0.0 -> -0.0 -0.0 +tan1044 tan -inf -0.0 -> nan nan invalid +tan1045 tan -inf -2.2999999999999998 -> nan nan invalid +tan1046 tan -0.0 -inf -> -0.0 -1.0 +tan1047 tan -0.69999999999999996 -inf -> -0.0 -1.0 +tan1048 tan -1.3999999999999999 -inf -> -0.0 -1.0 +tan1049 tan -2.1000000000000001 -inf -> 0.0 -1.0 +tan1050 tan -2.7999999999999998 -inf -> 0.0 -1.0 +tan1051 tan -3.5 -inf -> -0.0 -1.0 +tan1052 tan -inf -inf -> -0.0 -1.0 ignore-real-sign + + +------------------------------------------------------------------------ +-- rect: Conversion from polar coordinates to rectangular coordinates -- +------------------------------------------------------------------------ +-- +-- For cmath.rect, we can use the same testcase syntax as for the +-- complex -> complex functions above, but here the input arguments +-- should be interpreted as a pair of floating-point numbers rather +-- than the real and imaginary parts of a complex number. +-- +-- Here are the 'spirit of C99' rules for rect. First, the short +-- version: +-- +-- rect(x, t) = exp(log(x)+it) for positive-signed x +-- rect(x, t) = -exp(log(-x)+it) for negative-signed x +-- rect(nan, t) = exp(nan + it), except that in rect(nan, +-0) the +-- sign of the imaginary part is unspecified. +-- +-- and now the long version: +-- +-- rect(x, -t) = conj(rect(x, t)) for all x and t +-- rect(-x, t) = -rect(x, t) for all x and t +-- rect(+0, +0) returns +0 + i0 +-- rect(+0, inf) returns +- 0 +- i0, where the signs of the real and +-- imaginary parts are unspecified. +-- rect(x, inf) returns NaN + i NaN and raises the "invalid" +-- floating-point exception, for finite nonzero x. +-- rect(inf, inf) returns +-inf + i NaN and raises the "invalid" +-- floating-point exception (where the sign of the real part of the +-- result is unspecified). +-- rect(inf, +0) returns inf+i0 +-- rect(inf, x) returns inf*cis(x), for finite nonzero x +-- rect(inf, NaN) returns +-inf+i NaN, where the sign of the real part +-- of the result is unspecified. +-- rect(NaN, x) returns NaN + i NaN for all nonzero numbers (including +-- infinities) x +-- rect(NaN, 0) returns NaN +- i0, where the sign of the imaginary +-- part is unspecified +-- rect(NaN, NaN) returns NaN + i NaN +-- rect(x, NaN) returns NaN + i NaN for finite nonzero x +-- rect(+0, NaN) return +-0 +- i0, where the signs of the real and +-- imaginary parts are unspecified. + +-- special values +rect1000 rect 0.0 0.0 -> 0.0 0.0 +rect1001 rect 0.0 inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1002 rect 2.3 inf -> nan nan invalid +rect1003 rect inf inf -> inf nan invalid ignore-real-sign +rect1004 rect inf 0.0 -> inf 0.0 +rect1005 rect inf 1.4 -> inf inf +rect1006 rect inf 2.8 -> -inf inf +rect1007 rect inf 4.2 -> -inf -inf +rect1008 rect inf 5.6 -> inf -inf +rect1009 rect inf 7.0 -> inf inf +rect1010 rect nan 0.0 -> nan 0.0 ignore-imag-sign +rect1011 rect nan 2.3 -> nan nan +rect1012 rect nan inf -> nan nan +rect1013 rect nan nan -> nan nan +rect1014 rect inf nan -> inf nan ignore-real-sign +rect1015 rect 2.3 nan -> nan nan +rect1016 rect 0.0 nan -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1017 rect 0.0 -0.0 -> 0.0 -0.0 +rect1018 rect 0.0 -inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1019 rect 2.3 -inf -> nan nan invalid +rect1020 rect inf -inf -> inf nan invalid ignore-real-sign +rect1021 rect inf -0.0 -> inf -0.0 +rect1022 rect inf -1.4 -> inf -inf +rect1023 rect inf -2.8 -> -inf -inf +rect1024 rect inf -4.2 -> -inf inf +rect1025 rect inf -5.6 -> inf inf +rect1026 rect inf -7.0 -> inf -inf +rect1027 rect nan -0.0 -> nan 0.0 ignore-imag-sign +rect1028 rect nan -2.3 -> nan nan +rect1029 rect nan -inf -> nan nan +rect1030 rect -0.0 0.0 -> -0.0 -0.0 +rect1031 rect -0.0 inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1032 rect -2.3 inf -> nan nan invalid +rect1033 rect -inf inf -> -inf nan invalid ignore-real-sign +rect1034 rect -inf 0.0 -> -inf -0.0 +rect1035 rect -inf 1.4 -> -inf -inf +rect1036 rect -inf 2.8 -> inf -inf +rect1037 rect -inf 4.2 -> inf inf +rect1038 rect -inf 5.6 -> -inf inf +rect1039 rect -inf 7.0 -> -inf -inf +rect1040 rect -inf nan -> inf nan ignore-real-sign +rect1041 rect -2.3 nan -> nan nan +rect1042 rect -0.0 nan -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1043 rect -0.0 -0.0 -> -0.0 0.0 +rect1044 rect -0.0 -inf -> 0.0 0.0 ignore-real-sign ignore-imag-sign +rect1045 rect -2.3 -inf -> nan nan invalid +rect1046 rect -inf -inf -> -inf nan invalid ignore-real-sign +rect1047 rect -inf -0.0 -> -inf 0.0 +rect1048 rect -inf -1.4 -> -inf inf +rect1049 rect -inf -2.8 -> inf inf +rect1050 rect -inf -4.2 -> inf -inf +rect1051 rect -inf -5.6 -> -inf -inf +rect1052 rect -inf -7.0 -> -inf inf + +------------------------------------------------------------------------- +-- polar: Conversion from rectangular coordinates to polar coordinates -- +------------------------------------------------------------------------- +-- +-- For cmath.polar, we can use the same testcase syntax as for the +-- complex -> complex functions above, but here the output arguments +-- should be interpreted as a pair of floating-point numbers rather +-- than the real and imaginary parts of a complex number. +-- +-- Annex G of the C99 standard describes fully both the real and +-- imaginary parts of polar (as cabs and carg, respectively, which in turn +-- are defined in terms of the functions hypot and atan2). + +-- overflow +polar0100 polar 1.4e308 1.4e308 -> inf 0.78539816339744828 overflow + +-- special values +polar1000 polar 0.0 0.0 -> 0.0 0.0 +polar1001 polar 0.0 -0.0 -> 0.0 -0.0 +polar1002 polar -0.0 0.0 -> 0.0 3.1415926535897931 +polar1003 polar -0.0 -0.0 -> 0.0 -3.1415926535897931 +polar1004 polar inf 0.0 -> inf 0.0 +polar1005 polar inf 2.3 -> inf 0.0 +polar1006 polar inf inf -> inf 0.78539816339744828 +polar1007 polar 2.3 inf -> inf 1.5707963267948966 +polar1008 polar 0.0 inf -> inf 1.5707963267948966 +polar1009 polar -0.0 inf -> inf 1.5707963267948966 +polar1010 polar -2.3 inf -> inf 1.5707963267948966 +polar1011 polar -inf inf -> inf 2.3561944901923448 +polar1012 polar -inf 2.3 -> inf 3.1415926535897931 +polar1013 polar -inf 0.0 -> inf 3.1415926535897931 +polar1014 polar -inf -0.0 -> inf -3.1415926535897931 +polar1015 polar -inf -2.3 -> inf -3.1415926535897931 +polar1016 polar -inf -inf -> inf -2.3561944901923448 +polar1017 polar -2.3 -inf -> inf -1.5707963267948966 +polar1018 polar -0.0 -inf -> inf -1.5707963267948966 +polar1019 polar 0.0 -inf -> inf -1.5707963267948966 +polar1020 polar 2.3 -inf -> inf -1.5707963267948966 +polar1021 polar inf -inf -> inf -0.78539816339744828 +polar1022 polar inf -2.3 -> inf -0.0 +polar1023 polar inf -0.0 -> inf -0.0 +polar1024 polar nan -inf -> inf nan +polar1025 polar nan -2.3 -> nan nan +polar1026 polar nan -0.0 -> nan nan +polar1027 polar nan 0.0 -> nan nan +polar1028 polar nan 2.3 -> nan nan +polar1029 polar nan inf -> inf nan +polar1030 polar nan nan -> nan nan +polar1031 polar inf nan -> inf nan +polar1032 polar 2.3 nan -> nan nan +polar1033 polar 0.0 nan -> nan nan +polar1034 polar -0.0 nan -> nan nan +polar1035 polar -2.3 nan -> nan nan +polar1036 polar -inf nan -> inf nan diff --git a/test/dynamo/cpython/3_13/mathdata/floating_points.txt b/test/dynamo/cpython/3_13/mathdata/floating_points.txt new file mode 100644 index 00000000000000..539073d19d8577 --- /dev/null +++ b/test/dynamo/cpython/3_13/mathdata/floating_points.txt @@ -0,0 +1,1028 @@ +# These numbers are used to test floating point binary-to-decimal conversion. +# They are based on the TCL test suite (tests/expr.test), which is based on +# test data from: +# Brigitte Verdonk, Annie Cuyt, Dennis Verschaeren, A precision and range +# independent tool for testing floating-point arithmetic II: Conversions, +# ACM Transactions on Mathematical Software 27:2 (March 2001), pp. 119-140. + +0E0 +-0E0 +1E0 +15E-1 +125E-2 +1125E-3 +10625E-4 +103125E-5 +1015625E-6 +10078125E-7 +100390625E-8 +1001953125E-9 +10009765625E-10 +100048828125E-11 +1000244140625E-12 +10001220703125E-13 +100006103515625E-14 +1000030517578125E-15 +10000152587890625E-16 ++8E153 +-1E153 ++9E306 +-2E153 ++7E-304 +-3E-49 ++7E-303 +-6E-49 ++9E43 +-9E44 ++8E303 +-1E303 ++7E-287 +-2E-204 ++2E-205 +-9E-47 ++34E195 +-68E195 ++85E194 +-67E97 ++93E-234 +-19E-87 ++38E-87 +-38E-88 +-69E220 ++18E43 +-36E43 ++61E-99 +-43E-92 ++86E-92 +-51E-74 ++283E85 +-566E85 ++589E187 +-839E143 +-744E-234 ++930E-235 +-186E-234 ++604E175 +-302E175 ++755E174 +-151E175 ++662E-213 +-408E-74 ++510E-75 ++6782E55 +-2309E92 ++7963E34 +-3391E55 ++7903E-96 +-7611E-226 ++4907E-196 +-5547E-311 ++5311E241 +-5311E243 ++5311E242 ++9269E-45 +-8559E-289 ++8699E-276 +-8085E-64 ++74819E201 +-82081E41 ++51881E37 +-55061E157 ++77402E-215 +-33891E-92 ++38701E-215 +-82139E-76 ++75859E25 ++89509E140 +-57533E287 ++46073E-32 +-92146E-32 ++83771E-74 +-34796E-276 ++584169E229 ++164162E41 +-328324E41 ++209901E-11 +-419802E-11 ++940189E-112 +-892771E-213 ++757803E120 +-252601E120 ++252601E121 +-505202E120 ++970811E-264 +-654839E-60 ++289767E-178 +-579534E-178 +-8823691E130 ++9346704E229 +-1168338E229 +-6063369E-136 ++3865421E-225 +-5783893E-127 ++2572231E223 +-5144462E223 ++1817623E109 ++6431543E-97 +-5444097E-21 ++8076999E-121 +-9997649E-270 ++50609263E157 ++70589528E130 +-88236910E129 ++87575437E-310 +-23135572E-127 ++85900881E177 +-84863171E113 ++68761586E232 +-50464069E286 ++27869147E-248 +-55738294E-248 ++70176353E-53 +-80555086E-32 +-491080654E121 ++526250918E287 +-245540327E121 +-175150874E-310 ++350301748E-310 +-437877185E-311 ++458117166E52 +-916234332E52 ++229058583E52 +-525789935E98 ++282926897E-227 +-565853794E-227 ++667284113E-240 +-971212611E-126 ++9981396317E-182 +-5035231965E-156 ++8336960483E-153 +-8056371144E-155 ++6418488827E79 +-3981006983E252 ++7962013966E252 +-4713898551E261 ++8715380633E-58 +-9078555839E-109 ++9712126110E-127 ++42333842451E201 +-84667684902E201 ++23792120709E-315 +-78564021519E-227 ++71812054883E-188 +-30311163631E-116 ++71803914657E292 ++36314223356E-109 ++18157111678E-109 +-45392779195E-110 ++778380362293E218 +-685763015669E280 ++952918668151E70 +-548357443505E32 ++384865004907E-285 +-769730009814E-285 ++697015418417E-93 +-915654049301E-28 ++178548656339E169 +-742522891517E259 ++742522891517E258 +-357097312678E169 +-3113521449172E218 ++3891901811465E217 +-1556760724586E218 ++9997878507563E-195 +-7247563029154E-319 ++3623781514577E-319 +-3092446298323E-200 ++6363857920591E145 +-8233559360849E94 ++2689845954547E49 +-5379691909094E49 ++5560322501926E-301 +-7812878489261E-179 ++8439398533053E-256 +-2780161250963E-301 +-87605699161665E155 +-17521139832333E156 +-88218101363513E-170 ++38639244311627E-115 ++35593959807306E261 +-53390939710959E260 ++71187919614612E261 +-88984899518265E260 ++77003665618895E-73 +-15400733123779E-72 ++61602932495116E-72 +-30801466247558E-72 ++834735494917063E-300 +-589795149206434E-151 ++475603213226859E-42 +-294897574603217E-151 ++850813008001913E93 +-203449172043339E185 ++406898344086678E185 +-813796688173356E185 ++6045338514609393E244 +-5145963778954906E142 ++2572981889477453E142 +-6965949469487146E74 ++6182410494241627E-119 +-8510309498186985E-277 ++6647704637273331E-212 +-2215901545757777E-212 ++3771476185376383E276 +-3729901848043846E212 ++3771476185376383E277 +-9977830465649166E119 ++8439928496349319E-142 +-8204230082070882E-59 ++8853686434843997E-244 +-5553274272288559E-104 ++36149023611096162E144 +-36149023611096162E147 ++18074511805548081E146 +-18074511805548081E147 ++97338774138954421E-290 +-88133809804950961E-308 ++94080055902682397E-243 +-24691002732654881E-115 ++52306490527514614E49 +-26153245263757307E49 ++55188692254193604E165 +-68985865317742005E164 ++27176258005319167E-261 +-73169230107256116E-248 ++91461537634070145E-249 +-54352516010638334E-261 ++586144289638535878E280 +-601117006785295431E245 ++293072144819267939E280 +-953184713238516652E272 ++902042358290366539E-281 +-557035730189854663E-294 ++902042358290366539E-280 +-354944100507554393E-238 ++272104041512242479E199 +-816312124536727437E199 ++544208083024484958E199 +-792644927852378159E78 +-679406450132979175E-263 ++543525160106383340E-262 ++7400253695682920196E215 +-1850063423920730049E215 ++3700126847841460098E215 +-9250317119603650245E214 ++8396094300569779681E-252 +-3507665085003296281E-75 ++7015330170006592562E-75 +-7015330170006592562E-74 ++7185620434951919351E205 +-1360520207561212395E198 ++2178999185345151731E-184 +-8691089486201567102E-218 ++4345544743100783551E-218 +-4357998370690303462E-184 ++59825267349106892461E177 +-62259110684423957791E47 ++58380168477038565599E265 +-62259110684423957791E48 +-33584377202279118724E-252 +-57484963479615354808E205 ++71856204349519193510E204 +-14371240869903838702E205 ++36992084760177624177E-318 +-73984169520355248354E-318 ++99257763227713890244E-115 +-87336362425182547697E-280 ++7E289 +-3E153 ++6E153 +-5E243 ++7E-161 +-7E-172 ++8E-63 +-7E-113 ++8E126 +-4E126 ++5E125 +-1E126 ++8E-163 +-1E-163 ++2E-163 +-4E-163 ++51E195 +-37E46 ++74E46 +-56E289 ++69E-145 +-70E-162 ++56E-161 +-21E-303 ++34E-276 +-68E-276 ++85E-277 +-87E-274 ++829E102 +-623E100 ++723E-162 +-457E-102 ++914E-102 +-323E-135 ++151E176 +-302E176 ++921E90 +-604E176 ++823E-206 +-463E-114 ++348E-274 ++9968E100 +-6230E99 ++1246E100 ++6676E-296 +-8345E-297 ++1669E-296 +-3338E-296 ++3257E58 +-6514E58 ++2416E176 ++8085E-63 +-3234E-62 ++1617E-62 +-6468E-62 ++53418E111 +-60513E160 ++26709E111 +-99447E166 ++12549E48 +-25098E48 ++50196E48 +-62745E47 ++83771E-73 +-97451E-167 ++86637E-203 +-75569E-254 ++473806E83 +-947612E83 ++292369E76 +-584738E76 ++933587E-140 +-720919E-14 ++535001E-149 +-890521E-235 ++548057E81 +-706181E88 ++820997E106 +-320681E63 ++928609E-261 +-302276E-254 ++151138E-254 ++4691773E45 +-9383546E45 ++3059949E-243 +-6119898E-243 ++5356626E-213 +-4877378E-199 ++7716693E223 +-5452869E109 ++4590831E156 +-9181662E156 +-3714436E-261 ++4643045E-262 +-7428872E-261 ++52942146E130 +-27966061E145 ++26471073E130 +-55932122E145 ++95412548E-99 +-47706274E-99 ++23853137E-99 +-78493654E-301 ++65346417E29 +-51083099E167 ++89396333E264 +-84863171E114 ++59540836E-251 +-74426045E-252 ++14885209E-251 +-29770418E-251 ++982161308E122 +-245540327E122 ++491080654E122 ++525452622E-310 +-771837113E-134 ++820858081E-150 +-262726311E-310 ++923091487E209 +-653777767E273 ++842116236E-53 +-741111169E-202 ++839507247E-284 +-951487269E-264 +-9821613080E121 ++6677856011E-31 +-3573796826E-266 ++7147593652E-266 +-9981396317E-181 ++3268888835E272 +-2615111068E273 ++1307555534E273 ++2990671154E-190 +-1495335577E-190 ++5981342308E-190 +-7476677885E-191 ++82259684194E-202 +-93227267727E-49 ++41129842097E-202 +-47584241418E-314 +-79360293406E92 ++57332259349E225 +-57202326162E111 ++86860597053E-206 +-53827010643E-200 ++53587107423E-61 ++635007636765E200 ++508006109412E201 +-254003054706E201 ++561029718715E-72 +-897647549944E-71 ++112205943743E-71 +-873947086081E-236 ++809184709177E116 +-573112917422E81 ++286556458711E81 ++952805821491E-259 +-132189992873E-44 +-173696038493E-144 ++1831132757599E-107 +-9155663787995E-108 ++7324531030396E-107 +-9277338894969E-200 ++8188292423973E287 +-5672557437938E59 ++2836278718969E59 +-9995153153494E54 ++9224786422069E-291 +-3142213164987E-294 ++6284426329974E-294 +-8340483752889E-301 ++67039371486466E89 +-62150786615239E197 ++33519685743233E89 +-52563419496999E156 ++32599460466991E-65 +-41010988798007E-133 ++65198920933982E-65 +-82021977596014E-133 ++80527976643809E61 +-74712611505209E158 ++53390939710959E261 +-69277302659155E225 ++46202199371337E-72 +-23438635467783E-179 ++41921560615349E-67 +-92404398742674E-72 ++738545606647197E124 +-972708181182949E117 +-837992143580825E87 ++609610927149051E-255 +-475603213226859E-41 ++563002800671023E-177 +-951206426453718E-41 ++805416432656519E202 +-530658674694337E159 ++946574173863918E208 +-318329953318553E113 +-462021993713370E-73 ++369617594970696E-72 ++3666156212014994E233 +-1833078106007497E233 ++8301790508624232E174 +-1037723813578029E174 ++7297662880581139E-286 +-5106185698912191E-276 ++7487252720986826E-165 +-3743626360493413E-165 ++3773057430100257E230 +-7546114860200514E230 ++4321222892463822E58 +-7793560217139653E51 ++26525993941010681E112 +-53051987882021362E112 ++72844871414247907E77 +-88839359596763261E105 ++18718131802467065E-166 +-14974505441973652E-165 ++73429396004640239E106 +-58483921078398283E57 ++41391519190645203E165 +-82783038381290406E165 ++58767043776702677E-163 +-90506231831231999E-129 ++64409240769861689E-159 +-77305427432277771E-190 ++476592356619258326E273 +-953184713238516652E273 ++899810892172646163E283 +-929167076892018333E187 ++647761278967534239E-312 +-644290479820542942E-180 ++926145344610700019E-225 +-958507931896511964E-246 ++272104041512242479E200 +-792644927852378159E79 ++544208083024484958E200 +-929963218616126365E290 ++305574339166810102E-219 +-152787169583405051E-219 ++611148678333620204E-219 +-763935847917025255E-220 ++7439550220920798612E158 +-3719775110460399306E158 ++9299437776150998265E157 +-7120190517612959703E120 ++3507665085003296281E-73 +-7015330170006592562E-73 +-6684428762278255956E-294 +-1088416166048969916E200 +-8707329328391759328E200 ++4439021781608558002E-65 +-8878043563217116004E-65 ++2219510890804279001E-65 ++33051223951904955802E55 +-56961524140903677624E120 ++71201905176129597030E119 ++14030660340013185124E-73 +-17538325425016481405E-74 ++67536228609141569109E-133 +-35620497849450218807E-306 ++66550376797582521751E-126 +-71240995698900437614E-306 ++3E24 +-6E24 ++6E26 +-7E25 ++1E-14 +-2E-14 ++4E-14 +-8E-14 ++5E26 +-8E27 ++1E27 +-4E27 ++9E-13 +-7E-20 ++56E25 +-70E24 ++51E26 ++71E-17 +-31E-5 ++62E-5 +-94E-8 ++67E27 +-81E24 ++54E23 +-54E25 ++63E-22 +-63E-23 ++43E-4 +-86E-4 ++942E26 +-471E25 ++803E24 +-471E26 +-409E-21 ++818E-21 +-867E-8 ++538E27 +-857E24 ++269E27 +-403E26 ++959E-7 +-959E-6 ++373E-27 +-746E-27 ++4069E24 +-4069E23 +-8138E24 ++8294E-15 +-4147E-14 ++4147E-15 +-8294E-14 ++538E27 +-2690E26 ++269E27 +-2152E27 ++1721E-17 +-7979E-27 ++6884E-17 +-8605E-18 ++82854E27 +-55684E24 ++27842E24 +-48959E25 ++81921E-17 +-76207E-8 ++4147E-15 +-41470E-16 ++89309E24 ++75859E26 +-75859E25 ++14257E-23 +-28514E-23 ++57028E-23 +-71285E-24 ++344863E27 +-951735E27 ++200677E23 +-401354E24 ++839604E-11 +-209901E-11 ++419802E-11 +-537734E-24 ++910308E26 +-227577E26 ++455154E26 +-531013E25 ++963019E-21 +-519827E-13 ++623402E-27 +-311701E-27 ++9613651E26 +-9191316E23 ++4595658E23 +-2297829E23 +-1679208E-11 ++3379223E27 +-6758446E27 ++5444097E-21 +-8399969E-27 ++8366487E-16 +-8366487E-15 ++65060671E25 ++65212389E23 ++55544957E-13 +-51040905E-20 ++99585767E-22 +-99585767E-23 ++40978393E26 +-67488159E24 ++69005339E23 +-81956786E26 +-87105552E-21 ++10888194E-21 +-21776388E-21 ++635806667E27 +-670026614E25 ++335013307E26 +-335013307E25 ++371790617E-24 +-371790617E-25 ++743581234E-24 +-743581234E-25 ++202464477E24 +-404928954E24 ++997853758E27 +-997853758E26 ++405498418E-17 +-582579084E-14 ++608247627E-18 +-291289542E-14 +-9537100005E26 ++6358066670E27 +-1271613334E27 ++5229646999E-16 ++5229646999E-17 ++4429943614E24 +-8859887228E24 ++2214971807E24 +-4176887093E26 ++4003495257E-20 +-4361901637E-23 ++8723803274E-23 +-8006990514E-20 ++72835110098E27 +-36417555049E27 ++84279630104E25 +-84279630104E24 ++21206176437E-27 +-66461566917E-22 ++64808355539E-16 +-84932679673E-19 ++65205430094E26 +-68384463429E25 ++32602715047E26 +-62662203426E27 ++58784444678E-18 +-50980203373E-21 ++29392222339E-18 +-75529940323E-27 +-937495906299E26 ++842642485799E-20 +-387824150699E-23 ++924948814726E-27 +-775648301398E-23 ++547075707432E25 ++683844634290E24 +-136768926858E25 ++509802033730E-22 ++101960406746E-21 +-815683253968E-21 ++7344124123524E24 +-9180155154405E23 ++6479463327323E27 +-1836031030881E24 ++4337269293039E-19 +-4599163554373E-23 ++9198327108746E-23 ++4812803938347E27 +-8412030890011E23 ++9625607876694E27 +-4739968828249E24 ++9697183891673E-23 +-7368108517543E-20 ++51461358161422E25 +-77192037242133E26 ++77192037242133E25 +-51461358161422E27 ++43999661561541E-21 +-87999323123082E-21 ++48374886826137E-26 +-57684246567111E-23 ++87192805957686E23 +-75108713005913E24 ++64233110587487E27 +-77577471133384E-23 ++48485919458365E-24 +-56908598265713E-26 ++589722294620133E23 ++652835804449289E-22 +-656415363936202E-23 ++579336749585745E-25 +-381292764980839E-26 ++965265859649698E23 +-848925235434882E27 ++536177612222491E23 +-424462617717441E27 ++276009279888989E-27 +-608927158043691E-26 ++552018559777978E-27 +-425678377667758E-22 ++8013702726927119E26 ++8862627962362001E27 +-5068007907757162E26 +-7379714799828406E-23 ++4114538064016107E-27 +-3689857399914203E-23 ++5575954851815478E23 ++3395700941739528E27 ++4115535777581961E-23 +-8231071555163922E-23 ++6550246696190871E-26 +-68083046403986701E27 ++43566388595783643E27 +-87132777191567286E27 ++59644881059342141E25 +-83852770718576667E23 ++99482967418206961E-25 +-99482967418206961E-26 ++87446669969994614E-27 +-43723334984997307E-27 ++5E24 +-8E25 ++1E25 +-4E25 ++2E-5 +-5E-6 ++4E-5 +-3E-20 ++3E27 +-9E26 ++7E25 +-6E27 ++2E-21 +-5E-22 +-4E-21 ++87E25 +-97E24 ++82E-24 +-41E-24 ++76E-23 ++83E25 +-50E27 ++25E27 +-99E27 ++97E-10 +-57E-20 ++997E23 ++776E24 +-388E24 ++521E-10 +-506E-26 ++739E-10 +-867E-7 +-415E24 ++332E25 +-664E25 ++291E-13 +-982E-8 ++582E-13 +-491E-8 ++4574E26 +-8609E26 ++2287E26 +-4818E24 ++6529E-8 +-8151E-21 ++1557E-12 +-2573E-18 ++4929E-16 +-3053E-22 ++9858E-16 +-7767E-11 ++54339E26 +-62409E25 ++32819E27 +-89849E27 ++63876E-20 +-15969E-20 ++31938E-20 +-79845E-21 ++89306E27 +-25487E24 ++79889E24 +-97379E26 ++81002E-8 +-43149E-25 ++40501E-8 +-60318E-10 +-648299E27 ++780649E24 ++720919E-14 +-629703E-11 ++557913E24 +-847899E23 ++565445E27 +-736531E24 ++680013E-19 +-529981E-10 ++382923E-23 +-633614E-18 ++2165479E27 +-8661916E27 ++4330958E27 +-9391993E22 +-5767352E-14 ++7209190E-15 +-1441838E-14 ++8478990E22 ++1473062E24 ++8366487E-14 +-8399969E-25 ++9366737E-12 +-9406141E-13 ++65970979E24 +-65060671E26 ++54923002E27 +-63846927E25 ++99585767E-21 ++67488159E25 +-69005339E24 ++81956786E27 +-40978393E27 ++77505754E-12 +-38752877E-12 ++82772981E-15 +-95593517E-25 ++200036989E25 +-772686455E27 ++859139907E23 +-400073978E25 ++569014327E-14 +-794263862E-15 ++397131931E-15 +-380398957E-16 ++567366773E27 +-337440795E24 ++134976318E25 +-269952636E25 ++932080597E-20 +-331091924E-15 +-413864905E-16 ++8539246247E26 +-5859139791E26 ++6105010149E24 +-3090745820E27 ++3470877773E-20 +-6136309089E-27 ++8917758713E-19 +-6941755546E-20 ++9194900535E25 +-1838980107E26 ++7355920428E26 +-3677960214E26 ++8473634343E-17 +-8870766274E-16 ++4435383137E-16 +-9598990129E-15 ++71563496764E26 +-89454370955E25 ++17890874191E26 +-35781748382E26 ++57973447842E-19 +-28986723921E-19 ++76822711313E-19 +-97699466874E-20 ++67748656762E27 +-19394840991E24 ++38789681982E24 +-33874328381E27 ++54323763886E-27 +-58987193887E-20 ++27161881943E-27 +-93042648033E-19 ++520831059055E27 +-768124264394E25 ++384062132197E25 ++765337749889E-25 ++794368912771E25 +-994162090146E23 ++781652779431E26 ++910077190046E-26 +-455038595023E-26 ++471897551096E-20 +-906698409911E-21 ++8854128003935E25 +-8146122716299E27 ++7083302403148E26 +-3541651201574E26 ++8394920649291E-25 +-7657975756753E-22 ++5473834002228E-20 +-6842292502785E-21 +-2109568884597E25 ++8438275538388E25 +-4219137769194E25 ++3200141789841E-25 +-8655689322607E-22 ++6400283579682E-25 +-8837719634493E-21 ++19428217075297E24 +-38856434150594E24 ++77712868301188E24 +-77192037242133E27 ++76579757567530E-23 ++15315951513506E-22 +-38289878783765E-23 ++49378033925202E25 +-50940527102367E24 ++98756067850404E25 +-99589397544892E26 +-56908598265713E-25 ++97470695699657E-22 +-35851901247343E-25 ++154384074484266E27 +-308768148968532E27 ++910990389005985E23 ++271742424169201E-27 +-543484848338402E-27 ++162192083357563E-26 +-869254552770081E-23 ++664831007626046E24 +-332415503813023E24 ++943701829041427E24 +-101881054204734E24 ++828027839666967E-27 +-280276135608777E-27 ++212839188833879E-21 +-113817196531426E-25 ++9711553197796883E27 +-2739849386524269E26 ++5479698773048538E26 ++6124568318523113E-25 +-1139777988171071E-24 ++6322612303128019E-27 +-2955864564844617E-25 +-9994029144998961E25 +-2971238324022087E27 +-1656055679333934E-27 +-1445488709150234E-26 ++55824717499885172E27 +-69780896874856465E26 ++84161538867545199E25 +-27912358749942586E27 ++24711112462926331E-25 +-12645224606256038E-27 +-12249136637046226E-25 ++74874448287465757E27 +-35642836832753303E24 +-71285673665506606E24 ++43723334984997307E-26 ++10182419849537963E-24 +-93501703572661982E-26 + +# A value that caused a crash in debug builds for Python >= 2.7, 3.1 +# See http://bugs.python.org/issue7632 +2183167012312112312312.23538020374420446192e-370 + +# Another value designed to test a corner case of Python's strtod code. +0.99999999999999999999999999999999999999999e+23 diff --git a/test/dynamo/cpython/3_13/mathdata/formatfloat_testcases.txt b/test/dynamo/cpython/3_13/mathdata/formatfloat_testcases.txt new file mode 100644 index 00000000000000..25c07ba2939b01 --- /dev/null +++ b/test/dynamo/cpython/3_13/mathdata/formatfloat_testcases.txt @@ -0,0 +1,355 @@ +-- 'f' code formatting, with explicit precision (>= 0). Output always +-- has the given number of places after the point; zeros are added if +-- necessary to make this true. + +-- zeros +%.0f 0 -> 0 +%.1f 0 -> 0.0 +%.2f 0 -> 0.00 +%.3f 0 -> 0.000 +%.50f 0 -> 0.00000000000000000000000000000000000000000000000000 + +-- precision 0; result should never include a . +%.0f 1.5 -> 2 +%.0f 2.5 -> 2 +%.0f 3.5 -> 4 +%.0f 0.0 -> 0 +%.0f 0.1 -> 0 +%.0f 0.001 -> 0 +%.0f 10.0 -> 10 +%.0f 10.1 -> 10 +%.0f 10.01 -> 10 +%.0f 123.456 -> 123 +%.0f 1234.56 -> 1235 +%.0f 1e49 -> 9999999999999999464902769475481793196872414789632 +%.0f 9.9999999999999987e+49 -> 99999999999999986860582406952576489172979654066176 +%.0f 1e50 -> 100000000000000007629769841091887003294964970946560 + +-- precision 1 +%.1f 0.0001 -> 0.0 +%.1f 0.001 -> 0.0 +%.1f 0.01 -> 0.0 +%.1f 0.04 -> 0.0 +%.1f 0.06 -> 0.1 +%.1f 0.25 -> 0.2 +%.1f 0.75 -> 0.8 +%.1f 1.4 -> 1.4 +%.1f 1.5 -> 1.5 +%.1f 10.0 -> 10.0 +%.1f 1000.03 -> 1000.0 +%.1f 1234.5678 -> 1234.6 +%.1f 1234.7499 -> 1234.7 +%.1f 1234.75 -> 1234.8 + +-- precision 2 +%.2f 0.0001 -> 0.00 +%.2f 0.001 -> 0.00 +%.2f 0.004999 -> 0.00 +%.2f 0.005001 -> 0.01 +%.2f 0.01 -> 0.01 +%.2f 0.125 -> 0.12 +%.2f 0.375 -> 0.38 +%.2f 1234500 -> 1234500.00 +%.2f 1234560 -> 1234560.00 +%.2f 1234567 -> 1234567.00 +%.2f 1234567.8 -> 1234567.80 +%.2f 1234567.89 -> 1234567.89 +%.2f 1234567.891 -> 1234567.89 +%.2f 1234567.8912 -> 1234567.89 + +-- alternate form always includes a decimal point. This only +-- makes a difference when the precision is 0. +%#.0f 0 -> 0. +%#.1f 0 -> 0.0 +%#.0f 1.5 -> 2. +%#.0f 2.5 -> 2. +%#.0f 10.1 -> 10. +%#.0f 1234.56 -> 1235. +%#.1f 1.4 -> 1.4 +%#.2f 0.375 -> 0.38 + +-- if precision is omitted it defaults to 6 +%f 0 -> 0.000000 +%f 1230000 -> 1230000.000000 +%f 1234567 -> 1234567.000000 +%f 123.4567 -> 123.456700 +%f 1.23456789 -> 1.234568 +%f 0.00012 -> 0.000120 +%f 0.000123 -> 0.000123 +%f 0.00012345 -> 0.000123 +%f 0.000001 -> 0.000001 +%f 0.0000005001 -> 0.000001 +%f 0.0000004999 -> 0.000000 + +-- 'e' code formatting with explicit precision (>= 0). Output should +-- always have exactly the number of places after the point that were +-- requested. + +-- zeros +%.0e 0 -> 0e+00 +%.1e 0 -> 0.0e+00 +%.2e 0 -> 0.00e+00 +%.10e 0 -> 0.0000000000e+00 +%.50e 0 -> 0.00000000000000000000000000000000000000000000000000e+00 + +-- precision 0. no decimal point in the output +%.0e 0.01 -> 1e-02 +%.0e 0.1 -> 1e-01 +%.0e 1 -> 1e+00 +%.0e 10 -> 1e+01 +%.0e 100 -> 1e+02 +%.0e 0.012 -> 1e-02 +%.0e 0.12 -> 1e-01 +%.0e 1.2 -> 1e+00 +%.0e 12 -> 1e+01 +%.0e 120 -> 1e+02 +%.0e 123.456 -> 1e+02 +%.0e 0.000123456 -> 1e-04 +%.0e 123456000 -> 1e+08 +%.0e 0.5 -> 5e-01 +%.0e 1.4 -> 1e+00 +%.0e 1.5 -> 2e+00 +%.0e 1.6 -> 2e+00 +%.0e 2.4999999 -> 2e+00 +%.0e 2.5 -> 2e+00 +%.0e 2.5000001 -> 3e+00 +%.0e 3.499999999999 -> 3e+00 +%.0e 3.5 -> 4e+00 +%.0e 4.5 -> 4e+00 +%.0e 5.5 -> 6e+00 +%.0e 6.5 -> 6e+00 +%.0e 7.5 -> 8e+00 +%.0e 8.5 -> 8e+00 +%.0e 9.4999 -> 9e+00 +%.0e 9.5 -> 1e+01 +%.0e 10.5 -> 1e+01 +%.0e 14.999 -> 1e+01 +%.0e 15 -> 2e+01 + +-- precision 1 +%.1e 0.0001 -> 1.0e-04 +%.1e 0.001 -> 1.0e-03 +%.1e 0.01 -> 1.0e-02 +%.1e 0.1 -> 1.0e-01 +%.1e 1 -> 1.0e+00 +%.1e 10 -> 1.0e+01 +%.1e 100 -> 1.0e+02 +%.1e 120 -> 1.2e+02 +%.1e 123 -> 1.2e+02 +%.1e 123.4 -> 1.2e+02 + +-- precision 2 +%.2e 0.00013 -> 1.30e-04 +%.2e 0.000135 -> 1.35e-04 +%.2e 0.0001357 -> 1.36e-04 +%.2e 0.0001 -> 1.00e-04 +%.2e 0.001 -> 1.00e-03 +%.2e 0.01 -> 1.00e-02 +%.2e 0.1 -> 1.00e-01 +%.2e 1 -> 1.00e+00 +%.2e 10 -> 1.00e+01 +%.2e 100 -> 1.00e+02 +%.2e 1000 -> 1.00e+03 +%.2e 1500 -> 1.50e+03 +%.2e 1590 -> 1.59e+03 +%.2e 1598 -> 1.60e+03 +%.2e 1598.7 -> 1.60e+03 +%.2e 1598.76 -> 1.60e+03 +%.2e 9999 -> 1.00e+04 + +-- omitted precision defaults to 6 +%e 0 -> 0.000000e+00 +%e 165 -> 1.650000e+02 +%e 1234567 -> 1.234567e+06 +%e 12345678 -> 1.234568e+07 +%e 1.1 -> 1.100000e+00 + +-- alternate form always contains a decimal point. This only makes +-- a difference when precision is 0. + +%#.0e 0.01 -> 1.e-02 +%#.0e 0.1 -> 1.e-01 +%#.0e 1 -> 1.e+00 +%#.0e 10 -> 1.e+01 +%#.0e 100 -> 1.e+02 +%#.0e 0.012 -> 1.e-02 +%#.0e 0.12 -> 1.e-01 +%#.0e 1.2 -> 1.e+00 +%#.0e 12 -> 1.e+01 +%#.0e 120 -> 1.e+02 +%#.0e 123.456 -> 1.e+02 +%#.0e 0.000123456 -> 1.e-04 +%#.0e 123456000 -> 1.e+08 +%#.0e 0.5 -> 5.e-01 +%#.0e 1.4 -> 1.e+00 +%#.0e 1.5 -> 2.e+00 +%#.0e 1.6 -> 2.e+00 +%#.0e 2.4999999 -> 2.e+00 +%#.0e 2.5 -> 2.e+00 +%#.0e 2.5000001 -> 3.e+00 +%#.0e 3.499999999999 -> 3.e+00 +%#.0e 3.5 -> 4.e+00 +%#.0e 4.5 -> 4.e+00 +%#.0e 5.5 -> 6.e+00 +%#.0e 6.5 -> 6.e+00 +%#.0e 7.5 -> 8.e+00 +%#.0e 8.5 -> 8.e+00 +%#.0e 9.4999 -> 9.e+00 +%#.0e 9.5 -> 1.e+01 +%#.0e 10.5 -> 1.e+01 +%#.0e 14.999 -> 1.e+01 +%#.0e 15 -> 2.e+01 +%#.1e 123.4 -> 1.2e+02 +%#.2e 0.0001357 -> 1.36e-04 + +-- 'g' code formatting. + +-- zeros +%.0g 0 -> 0 +%.1g 0 -> 0 +%.2g 0 -> 0 +%.3g 0 -> 0 +%.4g 0 -> 0 +%.10g 0 -> 0 +%.50g 0 -> 0 +%.100g 0 -> 0 + +-- precision 0 doesn't make a lot of sense for the 'g' code (what does +-- it mean to have no significant digits?); in practice, it's interpreted +-- as identical to precision 1 +%.0g 1000 -> 1e+03 +%.0g 100 -> 1e+02 +%.0g 10 -> 1e+01 +%.0g 1 -> 1 +%.0g 0.1 -> 0.1 +%.0g 0.01 -> 0.01 +%.0g 1e-3 -> 0.001 +%.0g 1e-4 -> 0.0001 +%.0g 1e-5 -> 1e-05 +%.0g 1e-6 -> 1e-06 +%.0g 12 -> 1e+01 +%.0g 120 -> 1e+02 +%.0g 1.2 -> 1 +%.0g 0.12 -> 0.1 +%.0g 0.012 -> 0.01 +%.0g 0.0012 -> 0.001 +%.0g 0.00012 -> 0.0001 +%.0g 0.000012 -> 1e-05 +%.0g 0.0000012 -> 1e-06 + +-- precision 1 identical to precision 0 +%.1g 1000 -> 1e+03 +%.1g 100 -> 1e+02 +%.1g 10 -> 1e+01 +%.1g 1 -> 1 +%.1g 0.1 -> 0.1 +%.1g 0.01 -> 0.01 +%.1g 1e-3 -> 0.001 +%.1g 1e-4 -> 0.0001 +%.1g 1e-5 -> 1e-05 +%.1g 1e-6 -> 1e-06 +%.1g 12 -> 1e+01 +%.1g 120 -> 1e+02 +%.1g 1.2 -> 1 +%.1g 0.12 -> 0.1 +%.1g 0.012 -> 0.01 +%.1g 0.0012 -> 0.001 +%.1g 0.00012 -> 0.0001 +%.1g 0.000012 -> 1e-05 +%.1g 0.0000012 -> 1e-06 + +-- precision 2 +%.2g 1000 -> 1e+03 +%.2g 100 -> 1e+02 +%.2g 10 -> 10 +%.2g 1 -> 1 +%.2g 0.1 -> 0.1 +%.2g 0.01 -> 0.01 +%.2g 0.001 -> 0.001 +%.2g 1e-4 -> 0.0001 +%.2g 1e-5 -> 1e-05 +%.2g 1e-6 -> 1e-06 +%.2g 1234 -> 1.2e+03 +%.2g 123 -> 1.2e+02 +%.2g 12.3 -> 12 +%.2g 1.23 -> 1.2 +%.2g 0.123 -> 0.12 +%.2g 0.0123 -> 0.012 +%.2g 0.00123 -> 0.0012 +%.2g 0.000123 -> 0.00012 +%.2g 0.0000123 -> 1.2e-05 + +-- bad cases from http://bugs.python.org/issue9980 +%.12g 38210.0 -> 38210 +%.12g 37210.0 -> 37210 +%.12g 36210.0 -> 36210 + +-- alternate g formatting: always include decimal point and +-- exactly significant digits. +%#.0g 0 -> 0. +%#.1g 0 -> 0. +%#.2g 0 -> 0.0 +%#.3g 0 -> 0.00 +%#.4g 0 -> 0.000 + +%#.0g 0.2 -> 0.2 +%#.1g 0.2 -> 0.2 +%#.2g 0.2 -> 0.20 +%#.3g 0.2 -> 0.200 +%#.4g 0.2 -> 0.2000 +%#.10g 0.2 -> 0.2000000000 + +%#.0g 2 -> 2. +%#.1g 2 -> 2. +%#.2g 2 -> 2.0 +%#.3g 2 -> 2.00 +%#.4g 2 -> 2.000 + +%#.0g 20 -> 2.e+01 +%#.1g 20 -> 2.e+01 +%#.2g 20 -> 20. +%#.3g 20 -> 20.0 +%#.4g 20 -> 20.00 + +%#.0g 234.56 -> 2.e+02 +%#.1g 234.56 -> 2.e+02 +%#.2g 234.56 -> 2.3e+02 +%#.3g 234.56 -> 235. +%#.4g 234.56 -> 234.6 +%#.5g 234.56 -> 234.56 +%#.6g 234.56 -> 234.560 + +-- repr formatting. Result always includes decimal point and at +-- least one digit after the point, or an exponent. +%r 0 -> 0.0 +%r 1 -> 1.0 + +%r 0.01 -> 0.01 +%r 0.02 -> 0.02 +%r 0.03 -> 0.03 +%r 0.04 -> 0.04 +%r 0.05 -> 0.05 + +-- values >= 1e16 get an exponent +%r 10 -> 10.0 +%r 100 -> 100.0 +%r 1e15 -> 1000000000000000.0 +%r 9.999e15 -> 9999000000000000.0 +%r 9999999999999998 -> 9999999999999998.0 +%r 9999999999999999 -> 1e+16 +%r 1e16 -> 1e+16 +%r 1e17 -> 1e+17 + +-- as do values < 1e-4 +%r 1e-3 -> 0.001 +%r 1.001e-4 -> 0.0001001 +%r 1.0000000000000001e-4 -> 0.0001 +%r 1.000000000000001e-4 -> 0.0001000000000000001 +%r 1.00000000001e-4 -> 0.000100000000001 +%r 1.0000000001e-4 -> 0.00010000000001 +%r 1e-4 -> 0.0001 +%r 0.99999999999999999e-4 -> 0.0001 +%r 0.9999999999999999e-4 -> 9.999999999999999e-05 +%r 0.999999999999e-4 -> 9.99999999999e-05 +%r 0.999e-4 -> 9.99e-05 +%r 1e-5 -> 1e-05 diff --git a/test/dynamo/cpython/3_13/mathdata/ieee754.txt b/test/dynamo/cpython/3_13/mathdata/ieee754.txt new file mode 100644 index 00000000000000..a8b8a0a2148f00 --- /dev/null +++ b/test/dynamo/cpython/3_13/mathdata/ieee754.txt @@ -0,0 +1,183 @@ +====================================== +Python IEEE 754 floating point support +====================================== + +>>> from sys import float_info as FI +>>> from math import * +>>> PI = pi +>>> E = e + +You must never compare two floats with == because you are not going to get +what you expect. We treat two floats as equal if the difference between them +is small than epsilon. +>>> EPS = 1E-15 +>>> def equal(x, y): +... """Almost equal helper for floats""" +... return abs(x - y) < EPS + + +NaNs and INFs +============= + +In Python 2.6 and newer NaNs (not a number) and infinity can be constructed +from the strings 'inf' and 'nan'. + +>>> INF = float('inf') +>>> NINF = float('-inf') +>>> NAN = float('nan') + +>>> INF +inf +>>> NINF +-inf +>>> NAN +nan + +The math module's ``isnan`` and ``isinf`` functions can be used to detect INF +and NAN: +>>> isinf(INF), isinf(NINF), isnan(NAN) +(True, True, True) +>>> INF == -NINF +True + +Infinity +-------- + +Ambiguous operations like ``0 * inf`` or ``inf - inf`` result in NaN. +>>> INF * 0 +nan +>>> INF - INF +nan +>>> INF / INF +nan + +However unambigous operations with inf return inf: +>>> INF * INF +inf +>>> 1.5 * INF +inf +>>> 0.5 * INF +inf +>>> INF / 1000 +inf + +Not a Number +------------ + +NaNs are never equal to another number, even itself +>>> NAN == NAN +False +>>> NAN < 0 +False +>>> NAN >= 0 +False + +All operations involving a NaN return a NaN except for nan**0 and 1**nan. +>>> 1 + NAN +nan +>>> 1 * NAN +nan +>>> 0 * NAN +nan +>>> 1 ** NAN +1.0 +>>> NAN ** 0 +1.0 +>>> 0 ** NAN +nan +>>> (1.0 + FI.epsilon) * NAN +nan + +Misc Functions +============== + +The power of 1 raised to x is always 1.0, even for special values like 0, +infinity and NaN. + +>>> pow(1, 0) +1.0 +>>> pow(1, INF) +1.0 +>>> pow(1, -INF) +1.0 +>>> pow(1, NAN) +1.0 + +The power of 0 raised to x is defined as 0, if x is positive. Negative +finite values are a domain error or zero division error and NaN result in a +silent NaN. + +>>> pow(0, 0) +1.0 +>>> pow(0, INF) +0.0 +>>> pow(0, -INF) +inf +>>> 0 ** -1 +Traceback (most recent call last): +... +ZeroDivisionError: 0.0 cannot be raised to a negative power +>>> pow(0, NAN) +nan + + +Trigonometric Functions +======================= + +>>> sin(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> sin(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> sin(NAN) +nan +>>> cos(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> cos(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> cos(NAN) +nan +>>> tan(INF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> tan(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> tan(NAN) +nan + +Neither pi nor tan are exact, but you can assume that tan(pi/2) is a large value +and tan(pi) is a very small value: +>>> tan(PI/2) > 1E10 +True +>>> -tan(-PI/2) > 1E10 +True +>>> tan(PI) < 1E-15 +True + +>>> asin(NAN), acos(NAN), atan(NAN) +(nan, nan, nan) +>>> asin(INF), asin(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> acos(INF), acos(NINF) +Traceback (most recent call last): +... +ValueError: math domain error +>>> equal(atan(INF), PI/2), equal(atan(NINF), -PI/2) +(True, True) + + +Hyberbolic Functions +==================== + diff --git a/test/dynamo/cpython/3_13/mathdata/math_testcases.txt b/test/dynamo/cpython/3_13/mathdata/math_testcases.txt new file mode 100644 index 00000000000000..958518824376f8 --- /dev/null +++ b/test/dynamo/cpython/3_13/mathdata/math_testcases.txt @@ -0,0 +1,633 @@ +-- Testcases for functions in math. +-- +-- Each line takes the form: +-- +-- -> +-- +-- where: +-- +-- is a short name identifying the test, +-- +-- is the function to be tested (exp, cos, asinh, ...), +-- +-- is a string representing a floating-point value +-- +-- is the expected (ideal) output value, again +-- represented as a string. +-- +-- is a list of the floating-point flags required by C99 +-- +-- The possible flags are: +-- +-- divide-by-zero : raised when a finite input gives a +-- mathematically infinite result. +-- +-- overflow : raised when a finite input gives a finite result that +-- is too large to fit in the usual range of an IEEE 754 double. +-- +-- invalid : raised for invalid inputs (e.g., sqrt(-1)) +-- +-- ignore-sign : indicates that the sign of the result is +-- unspecified; e.g., if the result is given as inf, +-- then both -inf and inf should be accepted as correct. +-- +-- Flags may appear in any order. +-- +-- Lines beginning with '--' (like this one) start a comment, and are +-- ignored. Blank lines, or lines containing only whitespace, are also +-- ignored. + +-- Many of the values below were computed with the help of +-- version 2.4 of the MPFR library for multiple-precision +-- floating-point computations with correct rounding. All output +-- values in this file are (modulo yet-to-be-discovered bugs) +-- correctly rounded, provided that each input and output decimal +-- floating-point value below is interpreted as a representation of +-- the corresponding nearest IEEE 754 double-precision value. See the +-- MPFR homepage at http://www.mpfr.org for more information about the +-- MPFR project. + + +------------------------- +-- erf: error function -- +------------------------- + +erf0000 erf 0.0 -> 0.0 +erf0001 erf -0.0 -> -0.0 +erf0002 erf inf -> 1.0 +erf0003 erf -inf -> -1.0 +erf0004 erf nan -> nan + +-- tiny values +erf0010 erf 1e-308 -> 1.1283791670955125e-308 +erf0011 erf 5e-324 -> 4.9406564584124654e-324 +erf0012 erf 1e-10 -> 1.1283791670955126e-10 + +-- small integers +erf0020 erf 1 -> 0.84270079294971489 +erf0021 erf 2 -> 0.99532226501895271 +erf0022 erf 3 -> 0.99997790950300136 +erf0023 erf 4 -> 0.99999998458274209 +erf0024 erf 5 -> 0.99999999999846256 +erf0025 erf 6 -> 1.0 + +erf0030 erf -1 -> -0.84270079294971489 +erf0031 erf -2 -> -0.99532226501895271 +erf0032 erf -3 -> -0.99997790950300136 +erf0033 erf -4 -> -0.99999998458274209 +erf0034 erf -5 -> -0.99999999999846256 +erf0035 erf -6 -> -1.0 + +-- huge values should all go to +/-1, depending on sign +erf0040 erf -40 -> -1.0 +erf0041 erf 1e16 -> 1.0 +erf0042 erf -1e150 -> -1.0 +erf0043 erf 1.7e308 -> 1.0 + +-- Issue 8986: inputs x with exp(-x*x) near the underflow threshold +-- incorrectly signalled overflow on some platforms. +erf0100 erf 26.2 -> 1.0 +erf0101 erf 26.4 -> 1.0 +erf0102 erf 26.6 -> 1.0 +erf0103 erf 26.8 -> 1.0 +erf0104 erf 27.0 -> 1.0 +erf0105 erf 27.2 -> 1.0 +erf0106 erf 27.4 -> 1.0 +erf0107 erf 27.6 -> 1.0 + +erf0110 erf -26.2 -> -1.0 +erf0111 erf -26.4 -> -1.0 +erf0112 erf -26.6 -> -1.0 +erf0113 erf -26.8 -> -1.0 +erf0114 erf -27.0 -> -1.0 +erf0115 erf -27.2 -> -1.0 +erf0116 erf -27.4 -> -1.0 +erf0117 erf -27.6 -> -1.0 + +---------------------------------------- +-- erfc: complementary error function -- +---------------------------------------- + +erfc0000 erfc 0.0 -> 1.0 +erfc0001 erfc -0.0 -> 1.0 +erfc0002 erfc inf -> 0.0 +erfc0003 erfc -inf -> 2.0 +erfc0004 erfc nan -> nan + +-- tiny values +erfc0010 erfc 1e-308 -> 1.0 +erfc0011 erfc 5e-324 -> 1.0 +erfc0012 erfc 1e-10 -> 0.99999999988716204 + +-- small integers +erfc0020 erfc 1 -> 0.15729920705028513 +erfc0021 erfc 2 -> 0.0046777349810472662 +erfc0022 erfc 3 -> 2.2090496998585441e-05 +erfc0023 erfc 4 -> 1.541725790028002e-08 +erfc0024 erfc 5 -> 1.5374597944280349e-12 +erfc0025 erfc 6 -> 2.1519736712498913e-17 + +erfc0030 erfc -1 -> 1.8427007929497148 +erfc0031 erfc -2 -> 1.9953222650189528 +erfc0032 erfc -3 -> 1.9999779095030015 +erfc0033 erfc -4 -> 1.9999999845827421 +erfc0034 erfc -5 -> 1.9999999999984626 +erfc0035 erfc -6 -> 2.0 + +-- as x -> infinity, erfc(x) behaves like exp(-x*x)/x/sqrt(pi) +erfc0040 erfc 20 -> 5.3958656116079012e-176 +erfc0041 erfc 25 -> 8.3001725711965228e-274 +erfc0042 erfc 27 -> 5.2370464393526292e-319 +erfc0043 erfc 28 -> 0.0 + +-- huge values +erfc0050 erfc -40 -> 2.0 +erfc0051 erfc 1e16 -> 0.0 +erfc0052 erfc -1e150 -> 2.0 +erfc0053 erfc 1.7e308 -> 0.0 + +-- Issue 8986: inputs x with exp(-x*x) near the underflow threshold +-- incorrectly signalled overflow on some platforms. +erfc0100 erfc 26.2 -> 1.6432507924389461e-300 +erfc0101 erfc 26.4 -> 4.4017768588035426e-305 +erfc0102 erfc 26.6 -> 1.0885125885442269e-309 +erfc0103 erfc 26.8 -> 2.4849621571966629e-314 +erfc0104 erfc 27.0 -> 5.2370464393526292e-319 +erfc0105 erfc 27.2 -> 9.8813129168249309e-324 +erfc0106 erfc 27.4 -> 0.0 +erfc0107 erfc 27.6 -> 0.0 + +erfc0110 erfc -26.2 -> 2.0 +erfc0111 erfc -26.4 -> 2.0 +erfc0112 erfc -26.6 -> 2.0 +erfc0113 erfc -26.8 -> 2.0 +erfc0114 erfc -27.0 -> 2.0 +erfc0115 erfc -27.2 -> 2.0 +erfc0116 erfc -27.4 -> 2.0 +erfc0117 erfc -27.6 -> 2.0 + +--------------------------------------------------------- +-- lgamma: log of absolute value of the gamma function -- +--------------------------------------------------------- + +-- special values +lgam0000 lgamma 0.0 -> inf divide-by-zero +lgam0001 lgamma -0.0 -> inf divide-by-zero +lgam0002 lgamma inf -> inf +lgam0003 lgamma -inf -> inf +lgam0004 lgamma nan -> nan + +-- negative integers +lgam0010 lgamma -1 -> inf divide-by-zero +lgam0011 lgamma -2 -> inf divide-by-zero +lgam0012 lgamma -1e16 -> inf divide-by-zero +lgam0013 lgamma -1e300 -> inf divide-by-zero +lgam0014 lgamma -1.79e308 -> inf divide-by-zero + +-- small positive integers give factorials +lgam0020 lgamma 1 -> 0.0 +lgam0021 lgamma 2 -> 0.0 +lgam0022 lgamma 3 -> 0.69314718055994529 +lgam0023 lgamma 4 -> 1.791759469228055 +lgam0024 lgamma 5 -> 3.1780538303479458 +lgam0025 lgamma 6 -> 4.7874917427820458 + +-- half integers +lgam0030 lgamma 0.5 -> 0.57236494292470008 +lgam0031 lgamma 1.5 -> -0.12078223763524522 +lgam0032 lgamma 2.5 -> 0.28468287047291918 +lgam0033 lgamma 3.5 -> 1.2009736023470743 +lgam0034 lgamma -0.5 -> 1.2655121234846454 +lgam0035 lgamma -1.5 -> 0.86004701537648098 +lgam0036 lgamma -2.5 -> -0.056243716497674054 +lgam0037 lgamma -3.5 -> -1.309006684993042 + +-- values near 0 +lgam0040 lgamma 0.1 -> 2.252712651734206 +lgam0041 lgamma 0.01 -> 4.5994798780420219 +lgam0042 lgamma 1e-8 -> 18.420680738180209 +lgam0043 lgamma 1e-16 -> 36.841361487904734 +lgam0044 lgamma 1e-30 -> 69.077552789821368 +lgam0045 lgamma 1e-160 -> 368.41361487904732 +lgam0046 lgamma 1e-308 -> 709.19620864216608 +lgam0047 lgamma 5.6e-309 -> 709.77602713741896 +lgam0048 lgamma 5.5e-309 -> 709.79404564292167 +lgam0049 lgamma 1e-309 -> 711.49879373516012 +lgam0050 lgamma 1e-323 -> 743.74692474082133 +lgam0051 lgamma 5e-324 -> 744.44007192138122 +lgam0060 lgamma -0.1 -> 2.3689613327287886 +lgam0061 lgamma -0.01 -> 4.6110249927528013 +lgam0062 lgamma -1e-8 -> 18.420680749724522 +lgam0063 lgamma -1e-16 -> 36.841361487904734 +lgam0064 lgamma -1e-30 -> 69.077552789821368 +lgam0065 lgamma -1e-160 -> 368.41361487904732 +lgam0066 lgamma -1e-308 -> 709.19620864216608 +lgam0067 lgamma -5.6e-309 -> 709.77602713741896 +lgam0068 lgamma -5.5e-309 -> 709.79404564292167 +lgam0069 lgamma -1e-309 -> 711.49879373516012 +lgam0070 lgamma -1e-323 -> 743.74692474082133 +lgam0071 lgamma -5e-324 -> 744.44007192138122 + +-- values near negative integers +lgam0080 lgamma -0.99999999999999989 -> 36.736800569677101 +lgam0081 lgamma -1.0000000000000002 -> 36.043653389117154 +lgam0082 lgamma -1.9999999999999998 -> 35.350506208557213 +lgam0083 lgamma -2.0000000000000004 -> 34.657359027997266 +lgam0084 lgamma -100.00000000000001 -> -331.85460524980607 +lgam0085 lgamma -99.999999999999986 -> -331.85460524980596 + +-- large inputs +lgam0100 lgamma 170 -> 701.43726380873704 +lgam0101 lgamma 171 -> 706.57306224578736 +lgam0102 lgamma 171.624 -> 709.78077443669895 +lgam0103 lgamma 171.625 -> 709.78591682948365 +lgam0104 lgamma 172 -> 711.71472580228999 +lgam0105 lgamma 2000 -> 13198.923448054265 +lgam0106 lgamma 2.55998332785163e305 -> 1.7976931348623099e+308 +lgam0107 lgamma 2.55998332785164e305 -> inf overflow +lgam0108 lgamma 1.7e308 -> inf overflow + +-- inputs for which gamma(x) is tiny +lgam0120 lgamma -100.5 -> -364.90096830942736 +lgam0121 lgamma -160.5 -> -656.88005261126432 +lgam0122 lgamma -170.5 -> -707.99843314507882 +lgam0123 lgamma -171.5 -> -713.14301641168481 +lgam0124 lgamma -176.5 -> -738.95247590846486 +lgam0125 lgamma -177.5 -> -744.13144651738037 +lgam0126 lgamma -178.5 -> -749.3160351186001 + +lgam0130 lgamma -1000.5 -> -5914.4377011168517 +lgam0131 lgamma -30000.5 -> -279278.6629959144 +lgam0132 lgamma -4503599627370495.5 -> -1.5782258434492883e+17 + +-- results close to 0: positive argument ... +lgam0150 lgamma 0.99999999999999989 -> 6.4083812134800075e-17 +lgam0151 lgamma 1.0000000000000002 -> -1.2816762426960008e-16 +lgam0152 lgamma 1.9999999999999998 -> -9.3876980655431170e-17 +lgam0153 lgamma 2.0000000000000004 -> 1.8775396131086244e-16 + +-- ... and negative argument +lgam0160 lgamma -2.7476826467 -> -5.2477408147689136e-11 +lgam0161 lgamma -2.457024738 -> 3.3464637541912932e-10 + + +--------------------------- +-- gamma: Gamma function -- +--------------------------- + +-- special values +gam0000 gamma 0.0 -> inf divide-by-zero +gam0001 gamma -0.0 -> -inf divide-by-zero +gam0002 gamma inf -> inf +gam0003 gamma -inf -> nan invalid +gam0004 gamma nan -> nan + +-- negative integers inputs are invalid +gam0010 gamma -1 -> nan invalid +gam0011 gamma -2 -> nan invalid +gam0012 gamma -1e16 -> nan invalid +gam0013 gamma -1e300 -> nan invalid + +-- small positive integers give factorials +gam0020 gamma 1 -> 1 +gam0021 gamma 2 -> 1 +gam0022 gamma 3 -> 2 +gam0023 gamma 4 -> 6 +gam0024 gamma 5 -> 24 +gam0025 gamma 6 -> 120 + +-- half integers +gam0030 gamma 0.5 -> 1.7724538509055161 +gam0031 gamma 1.5 -> 0.88622692545275805 +gam0032 gamma 2.5 -> 1.3293403881791370 +gam0033 gamma 3.5 -> 3.3233509704478426 +gam0034 gamma -0.5 -> -3.5449077018110322 +gam0035 gamma -1.5 -> 2.3632718012073548 +gam0036 gamma -2.5 -> -0.94530872048294190 +gam0037 gamma -3.5 -> 0.27008820585226911 + +-- values near 0 +gam0040 gamma 0.1 -> 9.5135076986687306 +gam0041 gamma 0.01 -> 99.432585119150602 +gam0042 gamma 1e-8 -> 99999999.422784343 +gam0043 gamma 1e-16 -> 10000000000000000 +gam0044 gamma 1e-30 -> 9.9999999999999988e+29 +gam0045 gamma 1e-160 -> 1.0000000000000000e+160 +gam0046 gamma 1e-308 -> 1.0000000000000000e+308 +gam0047 gamma 5.6e-309 -> 1.7857142857142848e+308 +gam0048 gamma 5.5e-309 -> inf overflow +gam0049 gamma 1e-309 -> inf overflow +gam0050 gamma 1e-323 -> inf overflow +gam0051 gamma 5e-324 -> inf overflow +gam0060 gamma -0.1 -> -10.686287021193193 +gam0061 gamma -0.01 -> -100.58719796441078 +gam0062 gamma -1e-8 -> -100000000.57721567 +gam0063 gamma -1e-16 -> -10000000000000000 +gam0064 gamma -1e-30 -> -9.9999999999999988e+29 +gam0065 gamma -1e-160 -> -1.0000000000000000e+160 +gam0066 gamma -1e-308 -> -1.0000000000000000e+308 +gam0067 gamma -5.6e-309 -> -1.7857142857142848e+308 +gam0068 gamma -5.5e-309 -> -inf overflow +gam0069 gamma -1e-309 -> -inf overflow +gam0070 gamma -1e-323 -> -inf overflow +gam0071 gamma -5e-324 -> -inf overflow + +-- values near negative integers +gam0080 gamma -0.99999999999999989 -> -9007199254740992.0 +gam0081 gamma -1.0000000000000002 -> 4503599627370495.5 +gam0082 gamma -1.9999999999999998 -> 2251799813685248.5 +gam0083 gamma -2.0000000000000004 -> -1125899906842623.5 +gam0084 gamma -100.00000000000001 -> -7.5400833348831090e-145 +gam0085 gamma -99.999999999999986 -> 7.5400833348840962e-145 + +-- large inputs +gam0100 gamma 170 -> 4.2690680090047051e+304 +gam0101 gamma 171 -> 7.2574156153079990e+306 +gam0102 gamma 171.624 -> 1.7942117599248104e+308 +gam0103 gamma 171.625 -> inf overflow +gam0104 gamma 172 -> inf overflow +gam0105 gamma 2000 -> inf overflow +gam0106 gamma 1.7e308 -> inf overflow + +-- inputs for which gamma(x) is tiny +gam0120 gamma -100.5 -> -3.3536908198076787e-159 +gam0121 gamma -160.5 -> -5.2555464470078293e-286 +gam0122 gamma -170.5 -> -3.3127395215386074e-308 +gam0123 gamma -171.5 -> 1.9316265431711902e-310 +gam0124 gamma -176.5 -> -1.1956388629358166e-321 +gam0125 gamma -177.5 -> 4.9406564584124654e-324 +gam0126 gamma -178.5 -> -0.0 +gam0127 gamma -179.5 -> 0.0 +gam0128 gamma -201.0001 -> 0.0 +gam0129 gamma -202.9999 -> -0.0 +gam0130 gamma -1000.5 -> -0.0 +gam0131 gamma -1000000000.3 -> -0.0 +gam0132 gamma -4503599627370495.5 -> 0.0 + +-- inputs that cause problems for the standard reflection formula, +-- thanks to loss of accuracy in 1-x +gam0140 gamma -63.349078729022985 -> 4.1777971677761880e-88 +gam0141 gamma -127.45117632943295 -> 1.1831110896236810e-214 + + +----------------------------------------------------------- +-- log1p: log(1 + x), without precision loss for small x -- +----------------------------------------------------------- + +-- special values +log1p0000 log1p 0.0 -> 0.0 +log1p0001 log1p -0.0 -> -0.0 +log1p0002 log1p inf -> inf +log1p0003 log1p -inf -> nan invalid +log1p0004 log1p nan -> nan + +-- singularity at -1.0 +log1p0010 log1p -1.0 -> -inf divide-by-zero +log1p0011 log1p -0.9999999999999999 -> -36.736800569677101 + +-- finite values < 1.0 are invalid +log1p0020 log1p -1.0000000000000002 -> nan invalid +log1p0021 log1p -1.1 -> nan invalid +log1p0022 log1p -2.0 -> nan invalid +log1p0023 log1p -1e300 -> nan invalid + +-- tiny x: log1p(x) ~ x +log1p0110 log1p 5e-324 -> 5e-324 +log1p0111 log1p 1e-320 -> 1e-320 +log1p0112 log1p 1e-300 -> 1e-300 +log1p0113 log1p 1e-150 -> 1e-150 +log1p0114 log1p 1e-20 -> 1e-20 + +log1p0120 log1p -5e-324 -> -5e-324 +log1p0121 log1p -1e-320 -> -1e-320 +log1p0122 log1p -1e-300 -> -1e-300 +log1p0123 log1p -1e-150 -> -1e-150 +log1p0124 log1p -1e-20 -> -1e-20 + +-- some (mostly) random small and moderate-sized values +log1p0200 log1p -0.89156889782277482 -> -2.2216403106762863 +log1p0201 log1p -0.23858496047770464 -> -0.27257668276980057 +log1p0202 log1p -0.011641726191307515 -> -0.011710021654495657 +log1p0203 log1p -0.0090126398571693817 -> -0.0090534993825007650 +log1p0204 log1p -0.00023442805985712781 -> -0.00023445554240995693 +log1p0205 log1p -1.5672870980936349e-5 -> -1.5672993801662046e-5 +log1p0206 log1p -7.9650013274825295e-6 -> -7.9650330482740401e-6 +log1p0207 log1p -2.5202948343227410e-7 -> -2.5202951519170971e-7 +log1p0208 log1p -8.2446372820745855e-11 -> -8.2446372824144559e-11 +log1p0209 log1p -8.1663670046490789e-12 -> -8.1663670046824230e-12 +log1p0210 log1p 7.0351735084656292e-18 -> 7.0351735084656292e-18 +log1p0211 log1p 5.2732161907375226e-12 -> 5.2732161907236188e-12 +log1p0212 log1p 1.0000000000000000e-10 -> 9.9999999995000007e-11 +log1p0213 log1p 2.1401273266000197e-9 -> 2.1401273243099470e-9 +log1p0214 log1p 1.2668914653979560e-8 -> 1.2668914573728861e-8 +log1p0215 log1p 1.6250007816299069e-6 -> 1.6249994613175672e-6 +log1p0216 log1p 8.3740495645839399e-6 -> 8.3740145024266269e-6 +log1p0217 log1p 3.0000000000000001e-5 -> 2.9999550008999799e-5 +log1p0218 log1p 0.0070000000000000001 -> 0.0069756137364252423 +log1p0219 log1p 0.013026235315053002 -> 0.012942123564008787 +log1p0220 log1p 0.013497160797236184 -> 0.013406885521915038 +log1p0221 log1p 0.027625599078135284 -> 0.027250897463483054 +log1p0222 log1p 0.14179687245544870 -> 0.13260322540908789 + +-- large values +log1p0300 log1p 1.7976931348623157e+308 -> 709.78271289338397 +log1p0301 log1p 1.0000000000000001e+300 -> 690.77552789821368 +log1p0302 log1p 1.0000000000000001e+70 -> 161.18095650958321 +log1p0303 log1p 10000000000.000000 -> 23.025850930040455 + +-- other values transferred from testLog1p in test_math +log1p0400 log1p -0.63212055882855767 -> -1.0000000000000000 +log1p0401 log1p 1.7182818284590451 -> 1.0000000000000000 +log1p0402 log1p 1.0000000000000000 -> 0.69314718055994529 +log1p0403 log1p 1.2379400392853803e+27 -> 62.383246250395075 + + +----------------------------------------------------------- +-- expm1: exp(x) - 1, without precision loss for small x -- +----------------------------------------------------------- + +-- special values +expm10000 expm1 0.0 -> 0.0 +expm10001 expm1 -0.0 -> -0.0 +expm10002 expm1 inf -> inf +expm10003 expm1 -inf -> -1.0 +expm10004 expm1 nan -> nan + +-- expm1(x) ~ x for tiny x +expm10010 expm1 5e-324 -> 5e-324 +expm10011 expm1 1e-320 -> 1e-320 +expm10012 expm1 1e-300 -> 1e-300 +expm10013 expm1 1e-150 -> 1e-150 +expm10014 expm1 1e-20 -> 1e-20 + +expm10020 expm1 -5e-324 -> -5e-324 +expm10021 expm1 -1e-320 -> -1e-320 +expm10022 expm1 -1e-300 -> -1e-300 +expm10023 expm1 -1e-150 -> -1e-150 +expm10024 expm1 -1e-20 -> -1e-20 + +-- moderate sized values, where direct evaluation runs into trouble +expm10100 expm1 1e-10 -> 1.0000000000500000e-10 +expm10101 expm1 -9.9999999999999995e-08 -> -9.9999995000000163e-8 +expm10102 expm1 3.0000000000000001e-05 -> 3.0000450004500034e-5 +expm10103 expm1 -0.0070000000000000001 -> -0.0069755570667648951 +expm10104 expm1 -0.071499208740094633 -> -0.069002985744820250 +expm10105 expm1 -0.063296004180116799 -> -0.061334416373633009 +expm10106 expm1 0.02390954035597756 -> 0.024197665143819942 +expm10107 expm1 0.085637352649044901 -> 0.089411184580357767 +expm10108 expm1 0.5966174947411006 -> 0.81596588596501485 +expm10109 expm1 0.30247206212075139 -> 0.35319987035848677 +expm10110 expm1 0.74574727375889516 -> 1.1080161116737459 +expm10111 expm1 0.97767512926555711 -> 1.6582689207372185 +expm10112 expm1 0.8450154566787712 -> 1.3280137976535897 +expm10113 expm1 -0.13979260323125264 -> -0.13046144381396060 +expm10114 expm1 -0.52899322039643271 -> -0.41080213643695923 +expm10115 expm1 -0.74083261478900631 -> -0.52328317124797097 +expm10116 expm1 -0.93847766984546055 -> -0.60877704724085946 +expm10117 expm1 10.0 -> 22025.465794806718 +expm10118 expm1 27.0 -> 532048240600.79865 +expm10119 expm1 123 -> 2.6195173187490626e+53 +expm10120 expm1 -12.0 -> -0.99999385578764666 +expm10121 expm1 -35.100000000000001 -> -0.99999999999999944 + +-- extreme negative values +expm10201 expm1 -37.0 -> -0.99999999999999989 +expm10200 expm1 -38.0 -> -1.0 +expm10210 expm1 -710.0 -> -1.0 +-- the formula expm1(x) = 2 * sinh(x/2) * exp(x/2) doesn't work so +-- well when exp(x/2) is subnormal or underflows to zero; check we're +-- not using it! +expm10211 expm1 -1420.0 -> -1.0 +expm10212 expm1 -1450.0 -> -1.0 +expm10213 expm1 -1500.0 -> -1.0 +expm10214 expm1 -1e50 -> -1.0 +expm10215 expm1 -1.79e308 -> -1.0 + +-- extreme positive values +expm10300 expm1 300 -> 1.9424263952412558e+130 +expm10301 expm1 700 -> 1.0142320547350045e+304 +-- the next test (expm10302) is disabled because it causes failure on +-- OS X 10.4/Intel: apparently all values over 709.78 produce an +-- overflow on that platform. See issue #7575. +-- expm10302 expm1 709.78271289328393 -> 1.7976931346824240e+308 +expm10303 expm1 709.78271289348402 -> inf overflow +expm10304 expm1 1000 -> inf overflow +expm10305 expm1 1e50 -> inf overflow +expm10306 expm1 1.79e308 -> inf overflow + +-- weaker version of expm10302 +expm10307 expm1 709.5 -> 1.3549863193146328e+308 + +------------------------- +-- log2: log to base 2 -- +------------------------- + +-- special values +log20000 log2 0.0 -> -inf divide-by-zero +log20001 log2 -0.0 -> -inf divide-by-zero +log20002 log2 inf -> inf +log20003 log2 -inf -> nan invalid +log20004 log2 nan -> nan + +-- exact value at 1.0 +log20010 log2 1.0 -> 0.0 + +-- negatives +log20020 log2 -5e-324 -> nan invalid +log20021 log2 -1.0 -> nan invalid +log20022 log2 -1.7e-308 -> nan invalid + +-- exact values at powers of 2 +log20100 log2 2.0 -> 1.0 +log20101 log2 4.0 -> 2.0 +log20102 log2 8.0 -> 3.0 +log20103 log2 16.0 -> 4.0 +log20104 log2 32.0 -> 5.0 +log20105 log2 64.0 -> 6.0 +log20106 log2 128.0 -> 7.0 +log20107 log2 256.0 -> 8.0 +log20108 log2 512.0 -> 9.0 +log20109 log2 1024.0 -> 10.0 +log20110 log2 2048.0 -> 11.0 + +log20200 log2 0.5 -> -1.0 +log20201 log2 0.25 -> -2.0 +log20202 log2 0.125 -> -3.0 +log20203 log2 0.0625 -> -4.0 + +-- values close to 1.0 +log20300 log2 1.0000000000000002 -> 3.2034265038149171e-16 +log20301 log2 1.0000000001 -> 1.4426951601859516e-10 +log20302 log2 1.00001 -> 1.4426878274712997e-5 + +log20310 log2 0.9999999999999999 -> -1.6017132519074588e-16 +log20311 log2 0.9999999999 -> -1.4426951603302210e-10 +log20312 log2 0.99999 -> -1.4427022544056922e-5 + +-- tiny values +log20400 log2 5e-324 -> -1074.0 +log20401 log2 1e-323 -> -1073.0 +log20402 log2 1.5e-323 -> -1072.4150374992789 +log20403 log2 2e-323 -> -1072.0 + +log20410 log2 1e-308 -> -1023.1538532253076 +log20411 log2 2.2250738585072014e-308 -> -1022.0 +log20412 log2 4.4501477170144028e-308 -> -1021.0 +log20413 log2 1e-307 -> -1019.8319251304202 + +-- huge values +log20500 log2 1.7976931348623157e+308 -> 1024.0 +log20501 log2 1.7e+308 -> 1023.9193879716706 +log20502 log2 8.9884656743115795e+307 -> 1023.0 + +-- selection of random values +log20600 log2 -7.2174324841039838e+289 -> nan invalid +log20601 log2 -2.861319734089617e+265 -> nan invalid +log20602 log2 -4.3507646894008962e+257 -> nan invalid +log20603 log2 -6.6717265307520224e+234 -> nan invalid +log20604 log2 -3.9118023786619294e+229 -> nan invalid +log20605 log2 -1.5478221302505161e+206 -> nan invalid +log20606 log2 -1.4380485131364602e+200 -> nan invalid +log20607 log2 -3.7235198730382645e+185 -> nan invalid +log20608 log2 -1.0472242235095724e+184 -> nan invalid +log20609 log2 -5.0141781956163884e+160 -> nan invalid +log20610 log2 -2.1157958031160324e+124 -> nan invalid +log20611 log2 -7.9677558612567718e+90 -> nan invalid +log20612 log2 -5.5553906194063732e+45 -> nan invalid +log20613 log2 -16573900952607.953 -> nan invalid +log20614 log2 -37198371019.888618 -> nan invalid +log20615 log2 -6.0727115121422674e-32 -> nan invalid +log20616 log2 -2.5406841656526057e-38 -> nan invalid +log20617 log2 -4.9056766703267657e-43 -> nan invalid +log20618 log2 -2.1646786075228305e-71 -> nan invalid +log20619 log2 -2.470826790488573e-78 -> nan invalid +log20620 log2 -3.8661709303489064e-165 -> nan invalid +log20621 log2 -1.0516496976649986e-182 -> nan invalid +log20622 log2 -1.5935458614317996e-255 -> nan invalid +log20623 log2 -2.8750977267336654e-293 -> nan invalid +log20624 log2 -7.6079466794732585e-296 -> nan invalid +log20625 log2 3.2073253539988545e-307 -> -1018.1505544209213 +log20626 log2 1.674937885472249e-244 -> -809.80634755783126 +log20627 log2 1.0911259044931283e-214 -> -710.76679472274213 +log20628 log2 2.0275372624809709e-154 -> -510.55719818383272 +log20629 log2 7.3926087369631841e-115 -> -379.13564735312292 +log20630 log2 1.3480198206342423e-86 -> -285.25497445094436 +log20631 log2 8.9927384655719947e-83 -> -272.55127136401637 +log20632 log2 3.1452398713597487e-60 -> -197.66251564496875 +log20633 log2 7.0706573215457351e-55 -> -179.88420087782217 +log20634 log2 3.1258285390731669e-49 -> -161.13023800505653 +log20635 log2 8.2253046627829942e-41 -> -133.15898277355879 +log20636 log2 7.8691367397519897e+49 -> 165.75068202732419 +log20637 log2 2.9920561983925013e+64 -> 214.18453534573757 +log20638 log2 4.7827254553946841e+77 -> 258.04629628445673 +log20639 log2 3.1903566496481868e+105 -> 350.47616767491166 +log20640 log2 5.6195082449502419e+113 -> 377.86831861008250 +log20641 log2 9.9625658250651047e+125 -> 418.55752921228753 +log20642 log2 2.7358945220961532e+145 -> 483.13158636923413 +log20643 log2 2.785842387926931e+174 -> 579.49360214860280 +log20644 log2 2.4169172507252751e+193 -> 642.40529039289652 +log20645 log2 3.1689091206395632e+205 -> 682.65924573798395 +log20646 log2 2.535995592365391e+208 -> 692.30359597460460 +log20647 log2 6.2011236566089916e+233 -> 776.64177576730913 +log20648 log2 2.1843274820677632e+253 -> 841.57499717289647 +log20649 log2 8.7493931063474791e+297 -> 989.74182713073981 diff --git a/test/dynamo/cpython/3_13/seq_tests.diff b/test/dynamo/cpython/3_13/seq_tests.diff new file mode 100644 index 00000000000000..03c7021e4f96ae --- /dev/null +++ b/test/dynamo/cpython/3_13/seq_tests.diff @@ -0,0 +1,68 @@ +diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py +index 719c9434a16..4325892276d 100644 +--- a/test/dynamo/cpython/3_13/seq_tests.py ++++ b/test/dynamo/cpython/3_13/seq_tests.py +@@ -1,3 +1,54 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + """ + Tests common to tuple, list and UserList.UserList + """ +@@ -95,7 +146,7 @@ class LyingList(list): + def __iter__(self): + yield 1 + +-class CommonTest(unittest.TestCase): ++class CommonTest(__TestCase): + # The type to be tested + type2test = None + diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py new file mode 100644 index 00000000000000..4325892276d4cb --- /dev/null +++ b/test/dynamo/cpython/3_13/seq_tests.py @@ -0,0 +1,483 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +""" +Tests common to tuple, list and UserList.UserList +""" + +import unittest +import sys +import pickle +from test import support +from test.support import ALWAYS_EQ, NEVER_EQ + +# Various iterables +# This is used for checking the constructor (here and in test_deque.py) +def iterfunc(seqn): + 'Regular generator' + for i in seqn: + yield i + +class Sequence: + 'Sequence using __getitem__' + def __init__(self, seqn): + self.seqn = seqn + def __getitem__(self, i): + return self.seqn[i] + +class IterFunc: + 'Sequence using iterator protocol' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def __next__(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class IterGen: + 'Sequence using iterator protocol defined with a generator' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + for val in self.seqn: + yield val + +class IterNextOnly: + 'Missing __getitem__ and __iter__' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __next__(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class IterNoNext: + 'Iterator missing __next__()' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + +class IterGenExc: + 'Test propagation of exceptions' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def __next__(self): + 3 // 0 + +class IterFuncStop: + 'Test immediate stop' + def __init__(self, seqn): + pass + def __iter__(self): + return self + def __next__(self): + raise StopIteration + +from itertools import chain +def itermulti(seqn): + 'Test multiple tiers of iterators' + return chain(map(lambda x:x, iterfunc(IterGen(Sequence(seqn))))) + +class LyingTuple(tuple): + def __iter__(self): + yield 1 + +class LyingList(list): + def __iter__(self): + yield 1 + +class CommonTest(__TestCase): + # The type to be tested + type2test = None + + def test_constructors(self): + l0 = [] + l1 = [0] + l2 = [0, 1] + + u = self.type2test() + u0 = self.type2test(l0) + u1 = self.type2test(l1) + u2 = self.type2test(l2) + + uu = self.type2test(u) + uu0 = self.type2test(u0) + uu1 = self.type2test(u1) + uu2 = self.type2test(u2) + + v = self.type2test(tuple(u)) + class OtherSeq: + def __init__(self, initseq): + self.__data = initseq + def __len__(self): + return len(self.__data) + def __getitem__(self, i): + return self.__data[i] + s = OtherSeq(u0) + v0 = self.type2test(s) + self.assertEqual(len(v0), len(s)) + + s = "this is also a sequence" + vv = self.type2test(s) + self.assertEqual(len(vv), len(s)) + + # Create from various iteratables + for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): + for g in (Sequence, IterFunc, IterGen, + itermulti, iterfunc): + self.assertEqual(self.type2test(g(s)), self.type2test(s)) + self.assertEqual(self.type2test(IterFuncStop(s)), self.type2test()) + self.assertEqual(self.type2test(c for c in "123"), self.type2test("123")) + self.assertRaises(TypeError, self.type2test, IterNextOnly(s)) + self.assertRaises(TypeError, self.type2test, IterNoNext(s)) + self.assertRaises(ZeroDivisionError, self.type2test, IterGenExc(s)) + + # Issue #23757 + self.assertEqual(self.type2test(LyingTuple((2,))), self.type2test((1,))) + self.assertEqual(self.type2test(LyingList([2])), self.type2test([1])) + + with self.assertRaises(TypeError): + self.type2test(unsupported_arg=[]) + + def test_truth(self): + self.assertFalse(self.type2test()) + self.assertTrue(self.type2test([42])) + + def test_getitem(self): + u = self.type2test([0, 1, 2, 3, 4]) + for i in range(len(u)): + self.assertEqual(u[i], i) + self.assertEqual(u[int(i)], i) + for i in range(-len(u), -1): + self.assertEqual(u[i], len(u)+i) + self.assertEqual(u[int(i)], len(u)+i) + self.assertRaises(IndexError, u.__getitem__, -len(u)-1) + self.assertRaises(IndexError, u.__getitem__, len(u)) + self.assertRaises(ValueError, u.__getitem__, slice(0,10,0)) + + u = self.type2test() + self.assertRaises(IndexError, u.__getitem__, 0) + self.assertRaises(IndexError, u.__getitem__, -1) + + self.assertRaises(TypeError, u.__getitem__) + + a = self.type2test([10, 11]) + self.assertEqual(a[0], 10) + self.assertEqual(a[1], 11) + self.assertEqual(a[-2], 10) + self.assertEqual(a[-1], 11) + self.assertRaises(IndexError, a.__getitem__, -3) + self.assertRaises(IndexError, a.__getitem__, 3) + + def test_getslice(self): + l = [0, 1, 2, 3, 4] + u = self.type2test(l) + + self.assertEqual(u[0:0], self.type2test()) + self.assertEqual(u[1:2], self.type2test([1])) + self.assertEqual(u[-2:-1], self.type2test([3])) + self.assertEqual(u[-1000:1000], u) + self.assertEqual(u[1000:-1000], self.type2test([])) + self.assertEqual(u[:], u) + self.assertEqual(u[1:None], self.type2test([1, 2, 3, 4])) + self.assertEqual(u[None:3], self.type2test([0, 1, 2])) + + # Extended slices + self.assertEqual(u[::], u) + self.assertEqual(u[::2], self.type2test([0, 2, 4])) + self.assertEqual(u[1::2], self.type2test([1, 3])) + self.assertEqual(u[::-1], self.type2test([4, 3, 2, 1, 0])) + self.assertEqual(u[::-2], self.type2test([4, 2, 0])) + self.assertEqual(u[3::-2], self.type2test([3, 1])) + self.assertEqual(u[3:3:-2], self.type2test([])) + self.assertEqual(u[3:2:-2], self.type2test([3])) + self.assertEqual(u[3:1:-2], self.type2test([3])) + self.assertEqual(u[3:0:-2], self.type2test([3, 1])) + self.assertEqual(u[::-100], self.type2test([4])) + self.assertEqual(u[100:-100:], self.type2test([])) + self.assertEqual(u[-100:100:], u) + self.assertEqual(u[100:-100:-1], u[::-1]) + self.assertEqual(u[-100:100:-1], self.type2test([])) + self.assertEqual(u[-100:100:2], self.type2test([0, 2, 4])) + + # Test extreme cases with long ints + a = self.type2test([0,1,2,3,4]) + self.assertEqual(a[ -pow(2,128): 3 ], self.type2test([0,1,2])) + self.assertEqual(a[ 3: pow(2,145) ], self.type2test([3,4])) + self.assertEqual(a[3::sys.maxsize], self.type2test([3])) + + def test_contains(self): + u = self.type2test([0, 1, 2]) + for i in u: + self.assertIn(i, u) + for i in min(u)-1, max(u)+1: + self.assertNotIn(i, u) + + self.assertRaises(TypeError, u.__contains__) + + def test_contains_fake(self): + # Sequences must use rich comparison against each item + # (unless "is" is true, or an earlier item answered) + # So ALWAYS_EQ must be found in all non-empty sequences. + self.assertNotIn(ALWAYS_EQ, self.type2test([])) + self.assertIn(ALWAYS_EQ, self.type2test([1])) + self.assertIn(1, self.type2test([ALWAYS_EQ])) + self.assertNotIn(NEVER_EQ, self.type2test([])) + self.assertNotIn(ALWAYS_EQ, self.type2test([NEVER_EQ])) + self.assertIn(NEVER_EQ, self.type2test([ALWAYS_EQ])) + + def test_contains_order(self): + # Sequences must test in-order. If a rich comparison has side + # effects, these will be visible to tests against later members. + # In this test, the "side effect" is a short-circuiting raise. + class DoNotTestEq(Exception): + pass + class StopCompares: + def __eq__(self, other): + raise DoNotTestEq + + checkfirst = self.type2test([1, StopCompares()]) + self.assertIn(1, checkfirst) + checklast = self.type2test([StopCompares(), 1]) + self.assertRaises(DoNotTestEq, checklast.__contains__, 1) + + def test_len(self): + self.assertEqual(len(self.type2test()), 0) + self.assertEqual(len(self.type2test([])), 0) + self.assertEqual(len(self.type2test([0])), 1) + self.assertEqual(len(self.type2test([0, 1, 2])), 3) + + def test_minmax(self): + u = self.type2test([0, 1, 2]) + self.assertEqual(min(u), 0) + self.assertEqual(max(u), 2) + + def test_addmul(self): + u1 = self.type2test([0]) + u2 = self.type2test([0, 1]) + self.assertEqual(u1, u1 + self.type2test()) + self.assertEqual(u1, self.type2test() + u1) + self.assertEqual(u1 + self.type2test([1]), u2) + self.assertEqual(self.type2test([-1]) + u1, self.type2test([-1, 0])) + self.assertEqual(self.type2test(), u2*0) + self.assertEqual(self.type2test(), 0*u2) + self.assertEqual(self.type2test(), u2*0) + self.assertEqual(self.type2test(), 0*u2) + self.assertEqual(u2, u2*1) + self.assertEqual(u2, 1*u2) + self.assertEqual(u2, u2*1) + self.assertEqual(u2, 1*u2) + self.assertEqual(u2+u2, u2*2) + self.assertEqual(u2+u2, 2*u2) + self.assertEqual(u2+u2, u2*2) + self.assertEqual(u2+u2, 2*u2) + self.assertEqual(u2+u2+u2, u2*3) + self.assertEqual(u2+u2+u2, 3*u2) + + class subclass(self.type2test): + pass + u3 = subclass([0, 1]) + self.assertEqual(u3, u3*1) + self.assertIsNot(u3, u3*1) + + def test_iadd(self): + u = self.type2test([0, 1]) + u += self.type2test() + self.assertEqual(u, self.type2test([0, 1])) + u += self.type2test([2, 3]) + self.assertEqual(u, self.type2test([0, 1, 2, 3])) + u += self.type2test([4, 5]) + self.assertEqual(u, self.type2test([0, 1, 2, 3, 4, 5])) + + u = self.type2test("spam") + u += self.type2test("eggs") + self.assertEqual(u, self.type2test("spameggs")) + + def test_imul(self): + u = self.type2test([0, 1]) + u *= 3 + self.assertEqual(u, self.type2test([0, 1, 0, 1, 0, 1])) + u *= 0 + self.assertEqual(u, self.type2test([])) + + def test_getitemoverwriteiter(self): + # Verify that __getitem__ overrides are not recognized by __iter__ + class T(self.type2test): + def __getitem__(self, key): + return str(key) + '!!!' + self.assertEqual(next(iter(T((1,2)))), 1) + + def test_repeat(self): + for m in range(4): + s = tuple(range(m)) + for n in range(-3, 5): + self.assertEqual(self.type2test(s*n), self.type2test(s)*n) + self.assertEqual(self.type2test(s)*(-4), self.type2test([])) + self.assertEqual(id(s), id(s*1)) + + def test_bigrepeat(self): + if sys.maxsize <= 2147483647: + x = self.type2test([0]) + x *= 2**16 + self.assertRaises(MemoryError, x.__mul__, 2**16) + if hasattr(x, '__imul__'): + self.assertRaises(MemoryError, x.__imul__, 2**16) + + def test_subscript(self): + a = self.type2test([10, 11]) + self.assertEqual(a.__getitem__(0), 10) + self.assertEqual(a.__getitem__(1), 11) + self.assertEqual(a.__getitem__(-2), 10) + self.assertEqual(a.__getitem__(-1), 11) + self.assertRaises(IndexError, a.__getitem__, -3) + self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.__getitem__(slice(0,1)), self.type2test([10])) + self.assertEqual(a.__getitem__(slice(1,2)), self.type2test([11])) + self.assertEqual(a.__getitem__(slice(0,2)), self.type2test([10, 11])) + self.assertEqual(a.__getitem__(slice(0,3)), self.type2test([10, 11])) + self.assertEqual(a.__getitem__(slice(3,5)), self.type2test([])) + self.assertRaises(ValueError, a.__getitem__, slice(0, 10, 0)) + self.assertRaises(TypeError, a.__getitem__, 'x') + + def test_count(self): + a = self.type2test([0, 1, 2])*3 + self.assertEqual(a.count(0), 3) + self.assertEqual(a.count(1), 3) + self.assertEqual(a.count(3), 0) + + self.assertEqual(a.count(ALWAYS_EQ), 9) + self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(1), 2) + self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(NEVER_EQ), 2) + self.assertEqual(self.type2test([NEVER_EQ, NEVER_EQ]).count(ALWAYS_EQ), 0) + + self.assertRaises(TypeError, a.count) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + self.assertRaises(BadExc, a.count, BadCmp()) + + def test_index(self): + u = self.type2test([0, 1]) + self.assertEqual(u.index(0), 0) + self.assertEqual(u.index(1), 1) + self.assertRaises(ValueError, u.index, 2) + + u = self.type2test([-2, -1, 0, 0, 1, 2]) + self.assertEqual(u.count(0), 2) + self.assertEqual(u.index(0), 2) + self.assertEqual(u.index(0, 2), 2) + self.assertEqual(u.index(-2, -10), 0) + self.assertEqual(u.index(0, 3), 3) + self.assertEqual(u.index(0, 3, 4), 3) + self.assertRaises(ValueError, u.index, 2, 0, -10) + + self.assertEqual(u.index(ALWAYS_EQ), 0) + self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(1), 0) + self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(NEVER_EQ), 0) + self.assertRaises(ValueError, self.type2test([NEVER_EQ, NEVER_EQ]).index, ALWAYS_EQ) + + self.assertRaises(TypeError, u.index) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + a = self.type2test([0, 1, 2, 3]) + self.assertRaises(BadExc, a.index, BadCmp()) + + a = self.type2test([-2, -1, 0, 0, 1, 2]) + self.assertEqual(a.index(0), 2) + self.assertEqual(a.index(0, 2), 2) + self.assertEqual(a.index(0, -4), 2) + self.assertEqual(a.index(-2, -10), 0) + self.assertEqual(a.index(0, 3), 3) + self.assertEqual(a.index(0, -3), 3) + self.assertEqual(a.index(0, 3, 4), 3) + self.assertEqual(a.index(0, -3, -2), 3) + self.assertEqual(a.index(0, -4*sys.maxsize, 4*sys.maxsize), 2) + self.assertRaises(ValueError, a.index, 0, 4*sys.maxsize,-4*sys.maxsize) + self.assertRaises(ValueError, a.index, 2, 0, -10) + + def test_pickle(self): + lst = self.type2test([4, 5, 6, 7]) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + lst2 = pickle.loads(pickle.dumps(lst, proto)) + self.assertEqual(lst2, lst) + self.assertNotEqual(id(lst2), id(lst)) + + @support.suppress_immortalization() + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.type2test) + support.check_free_after_iterating(self, reversed, self.type2test) diff --git a/test/dynamo/cpython/3_13/test_baseexception.diff b/test/dynamo/cpython/3_13/test_baseexception.diff new file mode 100644 index 00000000000000..b25d72d0f65dd2 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_baseexception.diff @@ -0,0 +1,94 @@ +diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py +index e599b02c17d..3dc102e3b8a 100644 +--- a/test/dynamo/cpython/3_13/test_baseexception.py ++++ b/test/dynamo/cpython/3_13/test_baseexception.py +@@ -1,10 +1,61 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import unittest + import builtins + import os + from platform import system as platform_system + + +-class ExceptionClassTests(unittest.TestCase): ++class ExceptionClassTests(__TestCase): + + """Tests for anything relating to exception objects themselves (e.g., + inheritance hierarchy)""" +@@ -78,9 +129,6 @@ class ExceptionClassTests(unittest.TestCase): + last_depth = depth + finally: + inheritance_tree.close() +- +- # Underscore-prefixed (private) exceptions don't need to be documented +- exc_set = set(e for e in exc_set if not e.startswith('_')) + self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set) + + interface_tests = ("length", "args", "str", "repr") +@@ -142,7 +190,7 @@ class ExceptionClassTests(unittest.TestCase): + gc.collect() + + +-class UsageTests(unittest.TestCase): ++class UsageTests(__TestCase): + + """Test usage of exceptions""" + +@@ -208,5 +256,5 @@ class UsageTests(unittest.TestCase): + self.catch_fails("spam") + + +-if __name__ == '__main__': +- unittest.main() ++if __name__ == "__main__": ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py new file mode 100644 index 00000000000000..3dc102e3b8a2e0 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_baseexception.py @@ -0,0 +1,260 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import unittest +import builtins +import os +from platform import system as platform_system + + +class ExceptionClassTests(__TestCase): + + """Tests for anything relating to exception objects themselves (e.g., + inheritance hierarchy)""" + + def test_builtins_new_style(self): + self.assertTrue(issubclass(Exception, object)) + + def verify_instance_interface(self, ins): + for attr in ("args", "__str__", "__repr__"): + self.assertTrue(hasattr(ins, attr), + "%s missing %s attribute" % + (ins.__class__.__name__, attr)) + + def test_inheritance(self): + # Make sure the inheritance hierarchy matches the documentation + exc_set = set() + for object_ in builtins.__dict__.values(): + try: + if issubclass(object_, BaseException): + exc_set.add(object_.__name__) + except TypeError: + pass + + inheritance_tree = open( + os.path.join(os.path.split(__file__)[0], 'exception_hierarchy.txt'), + encoding="utf-8") + try: + superclass_name = inheritance_tree.readline().rstrip() + try: + last_exc = getattr(builtins, superclass_name) + except AttributeError: + self.fail("base class %s not a built-in" % superclass_name) + self.assertIn(superclass_name, exc_set, + '%s not found' % superclass_name) + exc_set.discard(superclass_name) + superclasses = [] # Loop will insert base exception + last_depth = 0 + for exc_line in inheritance_tree: + exc_line = exc_line.rstrip() + depth = exc_line.rindex('─') + exc_name = exc_line[depth+2:] # Slice past space + if '(' in exc_name: + paren_index = exc_name.index('(') + platform_name = exc_name[paren_index+1:-1] + exc_name = exc_name[:paren_index-1] # Slice off space + if platform_system() != platform_name: + exc_set.discard(exc_name) + continue + if '[' in exc_name: + left_bracket = exc_name.index('[') + exc_name = exc_name[:left_bracket-1] # cover space + try: + exc = getattr(builtins, exc_name) + except AttributeError: + self.fail("%s not a built-in exception" % exc_name) + if last_depth < depth: + superclasses.append((last_depth, last_exc)) + elif last_depth > depth: + while superclasses[-1][0] >= depth: + superclasses.pop() + self.assertTrue(issubclass(exc, superclasses[-1][1]), + "%s is not a subclass of %s" % (exc.__name__, + superclasses[-1][1].__name__)) + try: # Some exceptions require arguments; just skip them + self.verify_instance_interface(exc()) + except TypeError: + pass + self.assertIn(exc_name, exc_set) + exc_set.discard(exc_name) + last_exc = exc + last_depth = depth + finally: + inheritance_tree.close() + self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set) + + interface_tests = ("length", "args", "str", "repr") + + def interface_test_driver(self, results): + for test_name, (given, expected) in zip(self.interface_tests, results): + self.assertEqual(given, expected, "%s: %s != %s" % (test_name, + given, expected)) + + def test_interface_single_arg(self): + # Make sure interface works properly when given a single argument + arg = "spam" + exc = Exception(arg) + results = ([len(exc.args), 1], [exc.args[0], arg], + [str(exc), str(arg)], + [repr(exc), '%s(%r)' % (exc.__class__.__name__, arg)]) + self.interface_test_driver(results) + + def test_interface_multi_arg(self): + # Make sure interface correct when multiple arguments given + arg_count = 3 + args = tuple(range(arg_count)) + exc = Exception(*args) + results = ([len(exc.args), arg_count], [exc.args, args], + [str(exc), str(args)], + [repr(exc), exc.__class__.__name__ + repr(exc.args)]) + self.interface_test_driver(results) + + def test_interface_no_arg(self): + # Make sure that with no args that interface is correct + exc = Exception() + results = ([len(exc.args), 0], [exc.args, tuple()], + [str(exc), ''], + [repr(exc), exc.__class__.__name__ + '()']) + self.interface_test_driver(results) + + def test_setstate_refcount_no_crash(self): + # gh-97591: Acquire strong reference before calling tp_hash slot + # in PyObject_SetAttr. + import gc + d = {} + class HashThisKeyWillClearTheDict(str): + def __hash__(self) -> int: + d.clear() + return super().__hash__() + class Value(str): + pass + exc = Exception() + + d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now + + # Exception.__setstate__ should acquire a strong reference of key and + # value in the dict. Otherwise, Value()'s refcount would go below + # zero in the tp_hash call in PyObject_SetAttr(), and it would cause + # crash in GC. + exc.__setstate__(d) # __hash__() is called again here, clearing the dict. + + # This GC would crash if the refcount of Value() goes below zero. + gc.collect() + + +class UsageTests(__TestCase): + + """Test usage of exceptions""" + + def raise_fails(self, object_): + """Make sure that raising 'object_' triggers a TypeError.""" + try: + raise object_ + except TypeError: + return # What is expected. + self.fail("TypeError expected for raising %s" % type(object_)) + + def catch_fails(self, object_): + """Catching 'object_' should raise a TypeError.""" + try: + try: + raise Exception + except object_: + pass + except TypeError: + pass + except Exception: + self.fail("TypeError expected when catching %s" % type(object_)) + + try: + try: + raise Exception + except (object_,): + pass + except TypeError: + return + except Exception: + self.fail("TypeError expected when catching %s as specified in a " + "tuple" % type(object_)) + + def test_raise_new_style_non_exception(self): + # You cannot raise a new-style class that does not inherit from + # BaseException; the ability was not possible until BaseException's + # introduction so no need to support new-style objects that do not + # inherit from it. + class NewStyleClass(object): + pass + self.raise_fails(NewStyleClass) + self.raise_fails(NewStyleClass()) + + def test_raise_string(self): + # Raising a string raises TypeError. + self.raise_fails("spam") + + def test_catch_non_BaseException(self): + # Trying to catch an object that does not inherit from BaseException + # is not allowed. + class NonBaseException(object): + pass + self.catch_fails(NonBaseException) + self.catch_fails(NonBaseException()) + + def test_catch_BaseException_instance(self): + # Catching an instance of a BaseException subclass won't work. + self.catch_fails(BaseException()) + + def test_catch_string(self): + # Catching a string is bad. + self.catch_fails("spam") + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_cmath.diff b/test/dynamo/cpython/3_13/test_cmath.diff new file mode 100644 index 00000000000000..7157e8c0498f68 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_cmath.diff @@ -0,0 +1,116 @@ +diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py +index a96a5780b31..883e87a0733 100644 +--- a/test/dynamo/cpython/3_13/test_cmath.py ++++ b/test/dynamo/cpython/3_13/test_cmath.py +@@ -1,5 +1,55 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + from test.support import requires_IEEE_754, cpython_only, import_helper +-from test.support.testcase import ComplexesAreIdenticalMixin + from test.test_math import parse_testfile, test_file + import test.test_math as test_math + import unittest +@@ -50,7 +100,7 @@ complex_nans = [complex(x, y) for x, y in [ + (INF, NAN) + ]] + +-class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): ++class CMathTests(__TestCase): + # list of all functions in cmath + test_functions = [getattr(cmath, fname) for fname in [ + 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', +@@ -66,6 +116,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): + def tearDown(self): + self.test_values.close() + ++ def assertFloatIdentical(self, x, y): ++ """Fail unless floats x and y are identical, in the sense that: ++ (1) both x and y are nans, or ++ (2) both x and y are infinities, with the same sign, or ++ (3) both x and y are zeros, with the same sign, or ++ (4) x and y are both finite and nonzero, and x == y ++ ++ """ ++ msg = 'floats {!r} and {!r} are not identical' ++ ++ if math.isnan(x) or math.isnan(y): ++ if math.isnan(x) and math.isnan(y): ++ return ++ elif x == y: ++ if x != 0.0: ++ return ++ # both zero; check that signs match ++ elif math.copysign(1.0, x) == math.copysign(1.0, y): ++ return ++ else: ++ msg += ': zeros have different signs' ++ self.fail(msg.format(x, y)) ++ ++ def assertComplexesAreIdentical(self, x, y): ++ """Fail unless complex numbers x and y have equal values and signs. ++ ++ In particular, if x and y both have real (or imaginary) part ++ zero, but the zeros have different signs, this test will fail. ++ ++ """ ++ self.assertFloatIdentical(x.real, y.real) ++ self.assertFloatIdentical(x.imag, y.imag) ++ + def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323, + msg=None): + """Fail if the two floating-point numbers are not almost equal. +@@ -590,4 +673,4 @@ class IsCloseTests(test_math.IsCloseTests): + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py new file mode 100644 index 00000000000000..883e87a07337aa --- /dev/null +++ b/test/dynamo/cpython/3_13/test_cmath.py @@ -0,0 +1,676 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +from test.support import requires_IEEE_754, cpython_only, import_helper +from test.test_math import parse_testfile, test_file +import test.test_math as test_math +import unittest +import cmath, math +from cmath import phase, polar, rect, pi +import platform +import sys + + +INF = float('inf') +NAN = float('nan') + +complex_zeros = [complex(x, y) for x in [0.0, -0.0] for y in [0.0, -0.0]] +complex_infinities = [complex(x, y) for x, y in [ + (INF, 0.0), # 1st quadrant + (INF, 2.3), + (INF, INF), + (2.3, INF), + (0.0, INF), + (-0.0, INF), # 2nd quadrant + (-2.3, INF), + (-INF, INF), + (-INF, 2.3), + (-INF, 0.0), + (-INF, -0.0), # 3rd quadrant + (-INF, -2.3), + (-INF, -INF), + (-2.3, -INF), + (-0.0, -INF), + (0.0, -INF), # 4th quadrant + (2.3, -INF), + (INF, -INF), + (INF, -2.3), + (INF, -0.0) + ]] +complex_nans = [complex(x, y) for x, y in [ + (NAN, -INF), + (NAN, -2.3), + (NAN, -0.0), + (NAN, 0.0), + (NAN, 2.3), + (NAN, INF), + (-INF, NAN), + (-2.3, NAN), + (-0.0, NAN), + (0.0, NAN), + (2.3, NAN), + (INF, NAN) + ]] + +class CMathTests(__TestCase): + # list of all functions in cmath + test_functions = [getattr(cmath, fname) for fname in [ + 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', + 'cos', 'cosh', 'exp', 'log', 'log10', 'sin', 'sinh', + 'sqrt', 'tan', 'tanh']] + # test first and second arguments independently for 2-argument log + test_functions.append(lambda x : cmath.log(x, 1729. + 0j)) + test_functions.append(lambda x : cmath.log(14.-27j, x)) + + def setUp(self): + self.test_values = open(test_file, encoding="utf-8") + + def tearDown(self): + self.test_values.close() + + def assertFloatIdentical(self, x, y): + """Fail unless floats x and y are identical, in the sense that: + (1) both x and y are nans, or + (2) both x and y are infinities, with the same sign, or + (3) both x and y are zeros, with the same sign, or + (4) x and y are both finite and nonzero, and x == y + + """ + msg = 'floats {!r} and {!r} are not identical' + + if math.isnan(x) or math.isnan(y): + if math.isnan(x) and math.isnan(y): + return + elif x == y: + if x != 0.0: + return + # both zero; check that signs match + elif math.copysign(1.0, x) == math.copysign(1.0, y): + return + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) + + def assertComplexesAreIdentical(self, x, y): + """Fail unless complex numbers x and y have equal values and signs. + + In particular, if x and y both have real (or imaginary) part + zero, but the zeros have different signs, this test will fail. + + """ + self.assertFloatIdentical(x.real, y.real) + self.assertFloatIdentical(x.imag, y.imag) + + def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323, + msg=None): + """Fail if the two floating-point numbers are not almost equal. + + Determine whether floating-point values a and b are equal to within + a (small) rounding error. The default values for rel_err and + abs_err are chosen to be suitable for platforms where a float is + represented by an IEEE 754 double. They allow an error of between + 9 and 19 ulps. + """ + + # special values testing + if math.isnan(a): + if math.isnan(b): + return + self.fail(msg or '{!r} should be nan'.format(b)) + + if math.isinf(a): + if a == b: + return + self.fail(msg or 'finite result where infinity expected: ' + 'expected {!r}, got {!r}'.format(a, b)) + + # if both a and b are zero, check whether they have the same sign + # (in theory there are examples where it would be legitimate for a + # and b to have opposite signs; in practice these hardly ever + # occur). + if not a and not b: + if math.copysign(1., a) != math.copysign(1., b): + self.fail(msg or 'zero has wrong sign: expected {!r}, ' + 'got {!r}'.format(a, b)) + + # if a-b overflows, or b is infinite, return False. Again, in + # theory there are examples where a is within a few ulps of the + # max representable float, and then b could legitimately be + # infinite. In practice these examples are rare. + try: + absolute_error = abs(b-a) + except OverflowError: + pass + else: + # test passes if either the absolute error or the relative + # error is sufficiently small. The defaults amount to an + # error of between 9 ulps and 19 ulps on an IEEE-754 compliant + # machine. + if absolute_error <= max(abs_err, rel_err * abs(a)): + return + self.fail(msg or + '{!r} and {!r} are not sufficiently close'.format(a, b)) + + def test_constants(self): + e_expected = 2.71828182845904523536 + pi_expected = 3.14159265358979323846 + self.assertAlmostEqual(cmath.pi, pi_expected, places=9, + msg="cmath.pi is {}; should be {}".format(cmath.pi, pi_expected)) + self.assertAlmostEqual(cmath.e, e_expected, places=9, + msg="cmath.e is {}; should be {}".format(cmath.e, e_expected)) + + def test_infinity_and_nan_constants(self): + self.assertEqual(cmath.inf.real, math.inf) + self.assertEqual(cmath.inf.imag, 0.0) + self.assertEqual(cmath.infj.real, 0.0) + self.assertEqual(cmath.infj.imag, math.inf) + + self.assertTrue(math.isnan(cmath.nan.real)) + self.assertEqual(cmath.nan.imag, 0.0) + self.assertEqual(cmath.nanj.real, 0.0) + self.assertTrue(math.isnan(cmath.nanj.imag)) + # Also check that the sign of all of these is positive: + self.assertEqual(math.copysign(1., cmath.nan.real), 1.) + self.assertEqual(math.copysign(1., cmath.nan.imag), 1.) + self.assertEqual(math.copysign(1., cmath.nanj.real), 1.) + self.assertEqual(math.copysign(1., cmath.nanj.imag), 1.) + + # Check consistency with reprs. + self.assertEqual(repr(cmath.inf), "inf") + self.assertEqual(repr(cmath.infj), "infj") + self.assertEqual(repr(cmath.nan), "nan") + self.assertEqual(repr(cmath.nanj), "nanj") + + def test_user_object(self): + # Test automatic calling of __complex__ and __float__ by cmath + # functions + + # some random values to use as test values; we avoid values + # for which any of the functions in cmath is undefined + # (i.e. 0., 1., -1., 1j, -1j) or would cause overflow + cx_arg = 4.419414439 + 1.497100113j + flt_arg = -6.131677725 + + # a variety of non-complex numbers, used to check that + # non-complex return values from __complex__ give an error + non_complexes = ["not complex", 1, 5, 2., None, + object(), NotImplemented] + + # Now we introduce a variety of classes whose instances might + # end up being passed to the cmath functions + + # usual case: new-style class implementing __complex__ + class MyComplex: + def __init__(self, value): + self.value = value + def __complex__(self): + return self.value + + # classes for which __complex__ raises an exception + class SomeException(Exception): + pass + class MyComplexException: + def __complex__(self): + raise SomeException + + # some classes not providing __float__ or __complex__ + class NeitherComplexNorFloat(object): + pass + class Index: + def __int__(self): return 2 + def __index__(self): return 2 + class MyInt: + def __int__(self): return 2 + + # other possible combinations of __float__ and __complex__ + # that should work + class FloatAndComplex: + def __float__(self): + return flt_arg + def __complex__(self): + return cx_arg + class JustFloat: + def __float__(self): + return flt_arg + + for f in self.test_functions: + # usual usage + self.assertEqual(f(MyComplex(cx_arg)), f(cx_arg)) + # other combinations of __float__ and __complex__ + self.assertEqual(f(FloatAndComplex()), f(cx_arg)) + self.assertEqual(f(JustFloat()), f(flt_arg)) + self.assertEqual(f(Index()), f(int(Index()))) + # TypeError should be raised for classes not providing + # either __complex__ or __float__, even if they provide + # __int__ or __index__: + self.assertRaises(TypeError, f, NeitherComplexNorFloat()) + self.assertRaises(TypeError, f, MyInt()) + # non-complex return value from __complex__ -> TypeError + for bad_complex in non_complexes: + self.assertRaises(TypeError, f, MyComplex(bad_complex)) + # exceptions in __complex__ should be propagated correctly + self.assertRaises(SomeException, f, MyComplexException()) + + def test_input_type(self): + # ints should be acceptable inputs to all cmath + # functions, by virtue of providing a __float__ method + for f in self.test_functions: + for arg in [2, 2.]: + self.assertEqual(f(arg), f(arg.__float__())) + + # but strings should give a TypeError + for f in self.test_functions: + for arg in ["a", "long_string", "0", "1j", ""]: + self.assertRaises(TypeError, f, arg) + + def test_cmath_matches_math(self): + # check that corresponding cmath and math functions are equal + # for floats in the appropriate range + + # test_values in (0, 1) + test_values = [0.01, 0.1, 0.2, 0.5, 0.9, 0.99] + + # test_values for functions defined on [-1., 1.] + unit_interval = test_values + [-x for x in test_values] + \ + [0., 1., -1.] + + # test_values for log, log10, sqrt + positive = test_values + [1.] + [1./x for x in test_values] + nonnegative = [0.] + positive + + # test_values for functions defined on the whole real line + real_line = [0.] + positive + [-x for x in positive] + + test_functions = { + 'acos' : unit_interval, + 'asin' : unit_interval, + 'atan' : real_line, + 'cos' : real_line, + 'cosh' : real_line, + 'exp' : real_line, + 'log' : positive, + 'log10' : positive, + 'sin' : real_line, + 'sinh' : real_line, + 'sqrt' : nonnegative, + 'tan' : real_line, + 'tanh' : real_line} + + for fn, values in test_functions.items(): + float_fn = getattr(math, fn) + complex_fn = getattr(cmath, fn) + for v in values: + z = complex_fn(v) + self.rAssertAlmostEqual(float_fn(v), z.real) + self.assertEqual(0., z.imag) + + # test two-argument version of log with various bases + for base in [0.5, 2., 10.]: + for v in positive: + z = cmath.log(v, base) + self.rAssertAlmostEqual(math.log(v, base), z.real) + self.assertEqual(0., z.imag) + + @requires_IEEE_754 + def test_specific_values(self): + # Some tests need to be skipped on ancient OS X versions. + # See issue #27953. + SKIP_ON_TIGER = {'tan0064'} + + osx_version = None + if sys.platform == 'darwin': + version_txt = platform.mac_ver()[0] + try: + osx_version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + + def rect_complex(z): + """Wrapped version of rect that accepts a complex number instead of + two float arguments.""" + return cmath.rect(z.real, z.imag) + + def polar_complex(z): + """Wrapped version of polar that returns a complex number instead of + two floats.""" + return complex(*polar(z)) + + for id, fn, ar, ai, er, ei, flags in parse_testfile(test_file): + arg = complex(ar, ai) + expected = complex(er, ei) + + # Skip certain tests on OS X 10.4. + if osx_version is not None and osx_version < (10, 5): + if id in SKIP_ON_TIGER: + continue + + if fn == 'rect': + function = rect_complex + elif fn == 'polar': + function = polar_complex + else: + function = getattr(cmath, fn) + if 'divide-by-zero' in flags or 'invalid' in flags: + try: + actual = function(arg) + except ValueError: + continue + else: + self.fail('ValueError not raised in test ' + '{}: {}(complex({!r}, {!r}))'.format(id, fn, ar, ai)) + + if 'overflow' in flags: + try: + actual = function(arg) + except OverflowError: + continue + else: + self.fail('OverflowError not raised in test ' + '{}: {}(complex({!r}, {!r}))'.format(id, fn, ar, ai)) + + actual = function(arg) + + if 'ignore-real-sign' in flags: + actual = complex(abs(actual.real), actual.imag) + expected = complex(abs(expected.real), expected.imag) + if 'ignore-imag-sign' in flags: + actual = complex(actual.real, abs(actual.imag)) + expected = complex(expected.real, abs(expected.imag)) + + # for the real part of the log function, we allow an + # absolute error of up to 2e-15. + if fn in ('log', 'log10'): + real_abs_err = 2e-15 + else: + real_abs_err = 5e-323 + + error_message = ( + '{}: {}(complex({!r}, {!r}))\n' + 'Expected: complex({!r}, {!r})\n' + 'Received: complex({!r}, {!r})\n' + 'Received value insufficiently close to expected value.' + ).format(id, fn, ar, ai, + expected.real, expected.imag, + actual.real, actual.imag) + self.rAssertAlmostEqual(expected.real, actual.real, + abs_err=real_abs_err, + msg=error_message) + self.rAssertAlmostEqual(expected.imag, actual.imag, + msg=error_message) + + def check_polar(self, func): + def check(arg, expected): + got = func(arg) + for e, g in zip(expected, got): + self.rAssertAlmostEqual(e, g) + check(0, (0., 0.)) + check(1, (1., 0.)) + check(-1, (1., pi)) + check(1j, (1., pi / 2)) + check(-3j, (3., -pi / 2)) + inf = float('inf') + check(complex(inf, 0), (inf, 0.)) + check(complex(-inf, 0), (inf, pi)) + check(complex(3, inf), (inf, pi / 2)) + check(complex(5, -inf), (inf, -pi / 2)) + check(complex(inf, inf), (inf, pi / 4)) + check(complex(inf, -inf), (inf, -pi / 4)) + check(complex(-inf, inf), (inf, 3 * pi / 4)) + check(complex(-inf, -inf), (inf, -3 * pi / 4)) + nan = float('nan') + check(complex(nan, 0), (nan, nan)) + check(complex(0, nan), (nan, nan)) + check(complex(nan, nan), (nan, nan)) + check(complex(inf, nan), (inf, nan)) + check(complex(-inf, nan), (inf, nan)) + check(complex(nan, inf), (inf, nan)) + check(complex(nan, -inf), (inf, nan)) + + def test_polar(self): + self.check_polar(polar) + + @cpython_only + def test_polar_errno(self): + # Issue #24489: check a previously set C errno doesn't disturb polar() + _testcapi = import_helper.import_module('_testcapi') + def polar_with_errno_set(z): + _testcapi.set_errno(11) + try: + return polar(z) + finally: + _testcapi.set_errno(0) + self.check_polar(polar_with_errno_set) + + def test_phase(self): + self.assertAlmostEqual(phase(0), 0.) + self.assertAlmostEqual(phase(1.), 0.) + self.assertAlmostEqual(phase(-1.), pi) + self.assertAlmostEqual(phase(-1.+1E-300j), pi) + self.assertAlmostEqual(phase(-1.-1E-300j), -pi) + self.assertAlmostEqual(phase(1j), pi/2) + self.assertAlmostEqual(phase(-1j), -pi/2) + + # zeros + self.assertEqual(phase(complex(0.0, 0.0)), 0.0) + self.assertEqual(phase(complex(0.0, -0.0)), -0.0) + self.assertEqual(phase(complex(-0.0, 0.0)), pi) + self.assertEqual(phase(complex(-0.0, -0.0)), -pi) + + # infinities + self.assertAlmostEqual(phase(complex(-INF, -0.0)), -pi) + self.assertAlmostEqual(phase(complex(-INF, -2.3)), -pi) + self.assertAlmostEqual(phase(complex(-INF, -INF)), -0.75*pi) + self.assertAlmostEqual(phase(complex(-2.3, -INF)), -pi/2) + self.assertAlmostEqual(phase(complex(-0.0, -INF)), -pi/2) + self.assertAlmostEqual(phase(complex(0.0, -INF)), -pi/2) + self.assertAlmostEqual(phase(complex(2.3, -INF)), -pi/2) + self.assertAlmostEqual(phase(complex(INF, -INF)), -pi/4) + self.assertEqual(phase(complex(INF, -2.3)), -0.0) + self.assertEqual(phase(complex(INF, -0.0)), -0.0) + self.assertEqual(phase(complex(INF, 0.0)), 0.0) + self.assertEqual(phase(complex(INF, 2.3)), 0.0) + self.assertAlmostEqual(phase(complex(INF, INF)), pi/4) + self.assertAlmostEqual(phase(complex(2.3, INF)), pi/2) + self.assertAlmostEqual(phase(complex(0.0, INF)), pi/2) + self.assertAlmostEqual(phase(complex(-0.0, INF)), pi/2) + self.assertAlmostEqual(phase(complex(-2.3, INF)), pi/2) + self.assertAlmostEqual(phase(complex(-INF, INF)), 0.75*pi) + self.assertAlmostEqual(phase(complex(-INF, 2.3)), pi) + self.assertAlmostEqual(phase(complex(-INF, 0.0)), pi) + + # real or imaginary part NaN + for z in complex_nans: + self.assertTrue(math.isnan(phase(z))) + + def test_abs(self): + # zeros + for z in complex_zeros: + self.assertEqual(abs(z), 0.0) + + # infinities + for z in complex_infinities: + self.assertEqual(abs(z), INF) + + # real or imaginary part NaN + self.assertEqual(abs(complex(NAN, -INF)), INF) + self.assertTrue(math.isnan(abs(complex(NAN, -2.3)))) + self.assertTrue(math.isnan(abs(complex(NAN, -0.0)))) + self.assertTrue(math.isnan(abs(complex(NAN, 0.0)))) + self.assertTrue(math.isnan(abs(complex(NAN, 2.3)))) + self.assertEqual(abs(complex(NAN, INF)), INF) + self.assertEqual(abs(complex(-INF, NAN)), INF) + self.assertTrue(math.isnan(abs(complex(-2.3, NAN)))) + self.assertTrue(math.isnan(abs(complex(-0.0, NAN)))) + self.assertTrue(math.isnan(abs(complex(0.0, NAN)))) + self.assertTrue(math.isnan(abs(complex(2.3, NAN)))) + self.assertEqual(abs(complex(INF, NAN)), INF) + self.assertTrue(math.isnan(abs(complex(NAN, NAN)))) + + + @requires_IEEE_754 + def test_abs_overflows(self): + # result overflows + self.assertRaises(OverflowError, abs, complex(1.4e308, 1.4e308)) + + def assertCEqual(self, a, b): + eps = 1E-7 + if abs(a.real - b[0]) > eps or abs(a.imag - b[1]) > eps: + self.fail((a ,b)) + + def test_rect(self): + self.assertCEqual(rect(0, 0), (0, 0)) + self.assertCEqual(rect(1, 0), (1., 0)) + self.assertCEqual(rect(1, -pi), (-1., 0)) + self.assertCEqual(rect(1, pi/2), (0, 1.)) + self.assertCEqual(rect(1, -pi/2), (0, -1.)) + + def test_isfinite(self): + real_vals = [float('-inf'), -2.3, -0.0, + 0.0, 2.3, float('inf'), float('nan')] + for x in real_vals: + for y in real_vals: + z = complex(x, y) + self.assertEqual(cmath.isfinite(z), + math.isfinite(x) and math.isfinite(y)) + + def test_isnan(self): + self.assertFalse(cmath.isnan(1)) + self.assertFalse(cmath.isnan(1j)) + self.assertFalse(cmath.isnan(INF)) + self.assertTrue(cmath.isnan(NAN)) + self.assertTrue(cmath.isnan(complex(NAN, 0))) + self.assertTrue(cmath.isnan(complex(0, NAN))) + self.assertTrue(cmath.isnan(complex(NAN, NAN))) + self.assertTrue(cmath.isnan(complex(NAN, INF))) + self.assertTrue(cmath.isnan(complex(INF, NAN))) + + def test_isinf(self): + self.assertFalse(cmath.isinf(1)) + self.assertFalse(cmath.isinf(1j)) + self.assertFalse(cmath.isinf(NAN)) + self.assertTrue(cmath.isinf(INF)) + self.assertTrue(cmath.isinf(complex(INF, 0))) + self.assertTrue(cmath.isinf(complex(0, INF))) + self.assertTrue(cmath.isinf(complex(INF, INF))) + self.assertTrue(cmath.isinf(complex(NAN, INF))) + self.assertTrue(cmath.isinf(complex(INF, NAN))) + + @requires_IEEE_754 + def testTanhSign(self): + for z in complex_zeros: + self.assertComplexesAreIdentical(cmath.tanh(z), z) + + # The algorithm used for atan and atanh makes use of the system + # log1p function; If that system function doesn't respect the sign + # of zero, then atan and atanh will also have difficulties with + # the sign of complex zeros. + @requires_IEEE_754 + def testAtanSign(self): + for z in complex_zeros: + self.assertComplexesAreIdentical(cmath.atan(z), z) + + @requires_IEEE_754 + def testAtanhSign(self): + for z in complex_zeros: + self.assertComplexesAreIdentical(cmath.atanh(z), z) + + +class IsCloseTests(test_math.IsCloseTests): + isclose = cmath.isclose + + def test_reject_complex_tolerances(self): + with self.assertRaises(TypeError): + self.isclose(1j, 1j, rel_tol=1j) + + with self.assertRaises(TypeError): + self.isclose(1j, 1j, abs_tol=1j) + + with self.assertRaises(TypeError): + self.isclose(1j, 1j, rel_tol=1j, abs_tol=1j) + + def test_complex_values(self): + # test complex values that are close to within 12 decimal places + complex_examples = [(1.0+1.0j, 1.000000000001+1.0j), + (1.0+1.0j, 1.0+1.000000000001j), + (-1.0+1.0j, -1.000000000001+1.0j), + (1.0-1.0j, 1.0-0.999999999999j), + ] + + self.assertAllClose(complex_examples, rel_tol=1e-12) + self.assertAllNotClose(complex_examples, rel_tol=1e-13) + + def test_complex_near_zero(self): + # test values near zero that are near to within three decimal places + near_zero_examples = [(0.001j, 0), + (0.001, 0), + (0.001+0.001j, 0), + (-0.001+0.001j, 0), + (0.001-0.001j, 0), + (-0.001-0.001j, 0), + ] + + self.assertAllClose(near_zero_examples, abs_tol=1.5e-03) + self.assertAllNotClose(near_zero_examples, abs_tol=0.5e-03) + + self.assertIsClose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) + self.assertIsNotClose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) + + def test_complex_special(self): + self.assertIsNotClose(INF, INF*1j) + self.assertIsNotClose(INF*1j, INF) + self.assertIsNotClose(INF, -INF) + self.assertIsNotClose(-INF, INF) + self.assertIsNotClose(0, INF) + self.assertIsNotClose(0, INF*1j) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_contextlib.diff b/test/dynamo/cpython/3_13/test_contextlib.diff new file mode 100644 index 00000000000000..f3314f590c105a --- /dev/null +++ b/test/dynamo/cpython/3_13/test_contextlib.diff @@ -0,0 +1,195 @@ +diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py +index cf651959803..6a17bc719eb 100644 +--- a/test/dynamo/cpython/3_13/test_contextlib.py ++++ b/test/dynamo/cpython/3_13/test_contextlib.py +@@ -1,3 +1,54 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + """Unit tests for contextlib.py, and other context managers.""" + + import io +@@ -14,7 +65,7 @@ from test.support.testcase import ExceptionIsLikeMixin + import weakref + + +-class TestAbstractContextManager(unittest.TestCase): ++class TestAbstractContextManager(__TestCase): + + def test_enter(self): + class DefaultEnter(AbstractContextManager): +@@ -67,7 +118,7 @@ class TestAbstractContextManager(unittest.TestCase): + self.assertFalse(issubclass(NoExit, AbstractContextManager)) + + +-class ContextManagerTestCase(unittest.TestCase): ++class ContextManagerTestCase(__TestCase): + + def test_contextmanager_plain(self): + state = [] +@@ -396,7 +447,7 @@ def woohoo(): + self.assertEqual(depth, 0) + + +-class ClosingTestCase(unittest.TestCase): ++class ClosingTestCase(__TestCase): + + @support.requires_docstrings + def test_instance_docs(self): +@@ -430,7 +481,7 @@ class ClosingTestCase(unittest.TestCase): + self.assertEqual(state, [1]) + + +-class NullcontextTestCase(unittest.TestCase): ++class NullcontextTestCase(__TestCase): + def test_nullcontext(self): + class C: + pass +@@ -439,7 +490,7 @@ class NullcontextTestCase(unittest.TestCase): + self.assertIs(c_in, c) + + +-class FileContextTestCase(unittest.TestCase): ++class FileContextTestCase(__TestCase): + + def testWithOpen(self): + tfn = tempfile.mktemp() +@@ -457,7 +508,7 @@ class FileContextTestCase(unittest.TestCase): + finally: + os_helper.unlink(tfn) + +-class LockContextTestCase(unittest.TestCase): ++class LockContextTestCase(__TestCase): + + def boilerPlate(self, lock, locked): + self.assertFalse(locked()) +@@ -520,7 +571,7 @@ class mycontext(ContextDecorator): + return self.catch + + +-class TestContextDecorator(unittest.TestCase): ++class TestContextDecorator(__TestCase): + + @support.requires_docstrings + def test_instance_docs(self): +@@ -680,7 +731,7 @@ class TestContextDecorator(unittest.TestCase): + self.assertEqual(state, [1, 'something else', 999]) + + +-class TestBaseExitStack: ++class _TestBaseExitStack: + exit_stack = None + + @support.requires_docstrings +@@ -1141,7 +1192,7 @@ class TestBaseExitStack: + self.assertIs(exc.__cause__, exc.__context__) + + +-class TestExitStack(TestBaseExitStack, unittest.TestCase): ++class TestExitStack(_TestBaseExitStack, __TestCase): + exit_stack = ExitStack + callback_error_internal_frames = [ + ('__exit__', 'raise exc'), +@@ -1149,7 +1200,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase): + ] + + +-class TestRedirectStream: ++class _TestRedirectStream: + + redirect_stream = None + orig_stream = None +@@ -1206,19 +1257,19 @@ class TestRedirectStream: + self.assertEqual(s, "Hello World!\n") + + +-class TestRedirectStdout(TestRedirectStream, unittest.TestCase): ++class TestRedirectStdout(_TestRedirectStream, __TestCase): + + redirect_stream = redirect_stdout + orig_stream = "stdout" + + +-class TestRedirectStderr(TestRedirectStream, unittest.TestCase): ++class TestRedirectStderr(_TestRedirectStream, __TestCase): + + redirect_stream = redirect_stderr + orig_stream = "stderr" + + +-class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): ++class TestSuppress(ExceptionIsLikeMixin, __TestCase): + + @support.requires_docstrings + def test_instance_docs(self): +@@ -1315,7 +1366,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): + ) + + +-class TestChdir(unittest.TestCase): ++class TestChdir(__TestCase): + def make_relative_path(self, *parts): + return os.path.join( + os.path.dirname(os.path.realpath(__file__)), +@@ -1331,6 +1382,7 @@ class TestChdir(unittest.TestCase): + self.assertEqual(os.getcwd(), target) + self.assertEqual(os.getcwd(), old_cwd) + ++ @unittest.skip("Missing archivetestdata") + def test_reentrant(self): + old_cwd = os.getcwd() + target1 = self.make_relative_path('data') +@@ -1363,4 +1415,4 @@ class TestChdir(unittest.TestCase): + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py new file mode 100644 index 00000000000000..6a17bc719eb947 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_contextlib.py @@ -0,0 +1,1418 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +"""Unit tests for contextlib.py, and other context managers.""" + +import io +import os +import sys +import tempfile +import threading +import traceback +import unittest +from contextlib import * # Tests __all__ +from test import support +from test.support import os_helper +from test.support.testcase import ExceptionIsLikeMixin +import weakref + + +class TestAbstractContextManager(__TestCase): + + def test_enter(self): + class DefaultEnter(AbstractContextManager): + def __exit__(self, *args): + super().__exit__(*args) + + manager = DefaultEnter() + self.assertIs(manager.__enter__(), manager) + + def test_slots(self): + class DefaultContextManager(AbstractContextManager): + __slots__ = () + + def __exit__(self, *args): + super().__exit__(*args) + + with self.assertRaises(AttributeError): + DefaultContextManager().var = 42 + + def test_exit_is_abstract(self): + class MissingExit(AbstractContextManager): + pass + + with self.assertRaises(TypeError): + MissingExit() + + def test_structural_subclassing(self): + class ManagerFromScratch: + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return None + + self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager)) + + class DefaultEnter(AbstractContextManager): + def __exit__(self, *args): + super().__exit__(*args) + + self.assertTrue(issubclass(DefaultEnter, AbstractContextManager)) + + class NoEnter(ManagerFromScratch): + __enter__ = None + + self.assertFalse(issubclass(NoEnter, AbstractContextManager)) + + class NoExit(ManagerFromScratch): + __exit__ = None + + self.assertFalse(issubclass(NoExit, AbstractContextManager)) + + +class ContextManagerTestCase(__TestCase): + + def test_contextmanager_plain(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + yield 42 + state.append(999) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_finally(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + with self.assertRaises(ZeroDivisionError): + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_traceback(self): + @contextmanager + def f(): + yield + + try: + with f(): + 1/0 + except ZeroDivisionError as e: + frames = traceback.extract_tb(e.__traceback__) + + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') + self.assertEqual(frames[0].line, '1/0') + + # Repeat with RuntimeError (which goes through a different code path) + class RuntimeErrorSubclass(RuntimeError): + pass + + try: + with f(): + raise RuntimeErrorSubclass(42) + except RuntimeErrorSubclass as e: + frames = traceback.extract_tb(e.__traceback__) + + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') + self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)') + + class StopIterationSubclass(StopIteration): + pass + + for stop_exc in ( + StopIteration('spam'), + StopIterationSubclass('spam'), + ): + with self.subTest(type=type(stop_exc)): + try: + with f(): + raise stop_exc + except type(stop_exc) as e: + self.assertIs(e, stop_exc) + frames = traceback.extract_tb(e.__traceback__) + else: + self.fail(f'{stop_exc} was suppressed') + + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].name, 'test_contextmanager_traceback') + self.assertEqual(frames[0].line, 'raise stop_exc') + + def test_contextmanager_no_reraise(self): + @contextmanager + def whee(): + yield + ctx = whee() + ctx.__enter__() + # Calling __exit__ should not result in an exception + self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) + + def test_contextmanager_trap_yield_after_throw(self): + @contextmanager + def whoo(): + try: + yield + except: + yield + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(TypeError, TypeError("foo"), None) + if support.check_impl_detail(cpython=True): + # The "gen" attribute is an implementation detail. + self.assertFalse(ctx.gen.gi_suspended) + + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + + def test_contextmanager_trap_second_yield(self): + @contextmanager + def whoo(): + yield + yield + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(None, None, None) + if support.check_impl_detail(cpython=True): + # The "gen" attribute is an implementation detail. + self.assertFalse(ctx.gen.gi_suspended) + + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) + + def test_contextmanager_except(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError as e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_except_stopiter(self): + @contextmanager + def woohoo(): + yield + + class StopIterationSubclass(StopIteration): + pass + + for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')): + with self.subTest(type=type(stop_exc)): + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail(f'{stop_exc} was suppressed') + + def test_contextmanager_except_pep479(self): + code = """\ +from __future__ import generator_stop +from contextlib import contextmanager +@contextmanager +def woohoo(): + yield +""" + locals = {} + exec(code, locals, locals) + woohoo = locals['woohoo'] + + stop_exc = StopIteration('spam') + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail('StopIteration was suppressed') + + def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self): + @contextmanager + def test_issue29692(): + try: + yield + except Exception as exc: + raise RuntimeError('issue29692:Chained') from exc + try: + with test_issue29692(): + raise ZeroDivisionError + except Exception as ex: + self.assertIs(type(ex), RuntimeError) + self.assertEqual(ex.args[0], 'issue29692:Chained') + self.assertIsInstance(ex.__cause__, ZeroDivisionError) + + try: + with test_issue29692(): + raise StopIteration('issue29692:Unchained') + except Exception as ex: + self.assertIs(type(ex), StopIteration) + self.assertEqual(ex.args[0], 'issue29692:Unchained') + self.assertIsNone(ex.__cause__) + + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + + def _create_contextmanager_attribs(self): + def attribs(**kw): + def decorate(func): + for k,v in kw.items(): + setattr(func,k,v) + return func + return decorate + @contextmanager + @attribs(foo='bar') + def baz(spam): + """Whee!""" + yield + return baz + + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__name__,'baz') + self.assertEqual(baz.foo, 'bar') + + @support.requires_docstrings + def test_contextmanager_doc_attrib(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__doc__, "Whee!") + + @support.requires_docstrings + def test_instance_docstring_given_cm_docstring(self): + baz = self._create_contextmanager_attribs()(None) + self.assertEqual(baz.__doc__, "Whee!") + + def test_keywords(self): + # Ensure no keyword arguments are inhibited + @contextmanager + def woohoo(self, func, args, kwds): + yield (self, func, args, kwds) + with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) + + def test_nokeepref(self): + class A: + pass + + @contextmanager + def woohoo(a, b): + a = weakref.ref(a) + b = weakref.ref(b) + # Allow test to work with a non-refcounted GC + support.gc_collect() + self.assertIsNone(a()) + self.assertIsNone(b()) + yield + + with woohoo(A(), b=A()): + pass + + def test_param_errors(self): + @contextmanager + def woohoo(a, *, b): + yield + + with self.assertRaises(TypeError): + woohoo() + with self.assertRaises(TypeError): + woohoo(3, 5) + with self.assertRaises(TypeError): + woohoo(b=3) + + def test_recursive(self): + depth = 0 + ncols = 0 + @contextmanager + def woohoo(): + nonlocal ncols + ncols += 1 + nonlocal depth + before = depth + depth += 1 + yield + depth -= 1 + self.assertEqual(depth, before) + + @woohoo() + def recursive(): + if depth < 10: + recursive() + + recursive() + self.assertEqual(ncols, 10) + self.assertEqual(depth, 0) + + +class ClosingTestCase(__TestCase): + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = closing.__doc__ + obj = closing(None) + self.assertEqual(obj.__doc__, cm_docstring) + + def test_closing(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + with closing(x) as y: + self.assertEqual(x, y) + self.assertEqual(state, [1]) + + def test_closing_error(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + with self.assertRaises(ZeroDivisionError): + with closing(x) as y: + self.assertEqual(x, y) + 1 / 0 + self.assertEqual(state, [1]) + + +class NullcontextTestCase(__TestCase): + def test_nullcontext(self): + class C: + pass + c = C() + with nullcontext(c) as c_in: + self.assertIs(c_in, c) + + +class FileContextTestCase(__TestCase): + + def testWithOpen(self): + tfn = tempfile.mktemp() + try: + with open(tfn, "w", encoding="utf-8") as f: + self.assertFalse(f.closed) + f.write("Booh\n") + self.assertTrue(f.closed) + with self.assertRaises(ZeroDivisionError): + with open(tfn, "r", encoding="utf-8") as f: + self.assertFalse(f.closed) + self.assertEqual(f.read(), "Booh\n") + 1 / 0 + self.assertTrue(f.closed) + finally: + os_helper.unlink(tfn) + +class LockContextTestCase(__TestCase): + + def boilerPlate(self, lock, locked): + self.assertFalse(locked()) + with lock: + self.assertTrue(locked()) + self.assertFalse(locked()) + with self.assertRaises(ZeroDivisionError): + with lock: + self.assertTrue(locked()) + 1 / 0 + self.assertFalse(locked()) + + def testWithLock(self): + lock = threading.Lock() + self.boilerPlate(lock, lock.locked) + + def testWithRLock(self): + lock = threading.RLock() + self.boilerPlate(lock, lock._is_owned) + + def testWithCondition(self): + lock = threading.Condition() + def locked(): + return lock._is_owned() + self.boilerPlate(lock, locked) + + def testWithSemaphore(self): + lock = threading.Semaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + + def testWithBoundedSemaphore(self): + lock = threading.BoundedSemaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + + +class mycontext(ContextDecorator): + """Example decoration-compatible context manager for testing""" + started = False + exc = None + catch = False + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + return self.catch + + +class TestContextDecorator(__TestCase): + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = mycontext.__doc__ + obj = mycontext() + self.assertEqual(obj.__doc__, cm_docstring) + + def test_contextdecorator(self): + context = mycontext() + with context as result: + self.assertIs(result, context) + self.assertTrue(context.started) + + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextdecorator_with_exception(self): + context = mycontext() + + with self.assertRaisesRegex(NameError, 'foo'): + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + context = mycontext() + context.catch = True + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorator(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_decorator_with_exception(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + raise NameError('foo') + + with self.assertRaisesRegex(NameError, 'foo'): + test() + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorating_method(self): + context = mycontext() + + class Test(object): + + @context + def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests are for argument passing when used as a decorator + test = Test() + test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + + + def test_typo_enter(self): + class mycontext(ContextDecorator): + def __unter__(self): + pass + def __exit__(self, *exc): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + with mycontext(): + pass + + + def test_typo_exit(self): + class mycontext(ContextDecorator): + def __enter__(self): + pass + def __uxit__(self, *exc): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'): + with mycontext(): + pass + + + def test_contextdecorator_as_mixin(self): + class somecontext(object): + started = False + exc = None + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + + class mycontext(somecontext, ContextDecorator): + pass + + context = mycontext() + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextmanager_as_decorator(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + test('something') + self.assertEqual(state, [1, 'something', 999]) + + # Issue #11647: Ensure the decorated function is 'reusable' + state = [] + test('something else') + self.assertEqual(state, [1, 'something else', 999]) + + +class _TestBaseExitStack: + exit_stack = None + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = self.exit_stack.__doc__ + obj = self.exit_stack() + self.assertEqual(obj.__doc__, cm_docstring) + + def test_no_resources(self): + with self.exit_stack(): + pass + + def test_callback(self): + expected = [ + ((), {}), + ((1,), {}), + ((1,2), {}), + ((), dict(example=1)), + ((1,), dict(example=1)), + ((1,2), dict(example=1)), + ((1,2), dict(self=3, callback=4)), + ] + result = [] + def _exit(*args, **kwds): + """Test metadata propagation""" + result.append((args, kwds)) + with self.exit_stack() as stack: + for args, kwds in reversed(expected): + if args and kwds: + f = stack.callback(_exit, *args, **kwds) + elif args: + f = stack.callback(_exit, *args) + elif kwds: + f = stack.callback(_exit, **kwds) + else: + f = stack.callback(_exit) + self.assertIs(f, _exit) + for wrapper in stack._exit_callbacks: + self.assertIs(wrapper[1].__wrapped__, _exit) + self.assertNotEqual(wrapper[1].__name__, _exit.__name__) + self.assertIsNone(wrapper[1].__doc__, _exit.__doc__) + self.assertEqual(result, expected) + + result = [] + with self.exit_stack() as stack: + with self.assertRaises(TypeError): + stack.callback(arg=1) + with self.assertRaises(TypeError): + self.exit_stack.callback(arg=2) + with self.assertRaises(TypeError): + stack.callback(callback=_exit, arg=3) + self.assertEqual(result, []) + + def test_push(self): + exc_raised = ZeroDivisionError + def _expect_exc(exc_type, exc, exc_tb): + self.assertIs(exc_type, exc_raised) + def _suppress_exc(*exc_details): + return True + def _expect_ok(exc_type, exc, exc_tb): + self.assertIsNone(exc_type) + self.assertIsNone(exc) + self.assertIsNone(exc_tb) + class ExitCM(object): + def __init__(self, check_exc): + self.check_exc = check_exc + def __enter__(self): + self.fail("Should not be called!") + def __exit__(self, *exc_details): + self.check_exc(*exc_details) + with self.exit_stack() as stack: + stack.push(_expect_ok) + self.assertIs(stack._exit_callbacks[-1][1], _expect_ok) + cm = ExitCM(_expect_ok) + stack.push(cm) + self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) + stack.push(_suppress_exc) + self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc) + cm = ExitCM(_expect_exc) + stack.push(cm) + self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) + stack.push(_expect_exc) + self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) + stack.push(_expect_exc) + self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) + 1/0 + + def test_enter_context(self): + class TestCM(object): + def __enter__(self): + result.append(1) + def __exit__(self, *exc_details): + result.append(3) + + result = [] + cm = TestCM() + with self.exit_stack() as stack: + @stack.callback # Registered first => cleaned up last + def _exit(): + result.append(4) + self.assertIsNotNone(_exit) + stack.enter_context(cm) + self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) + result.append(2) + self.assertEqual(result, [1, 2, 3, 4]) + + def test_enter_context_errors(self): + class LacksEnterAndExit: + pass + class LacksEnter: + def __exit__(self, *exc_info): + pass + class LacksExit: + def __enter__(self): + pass + + with self.exit_stack() as stack: + with self.assertRaisesRegex(TypeError, 'the context manager'): + stack.enter_context(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + stack.enter_context(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + stack.enter_context(LacksExit()) + self.assertFalse(stack._exit_callbacks) + + def test_close(self): + result = [] + with self.exit_stack() as stack: + @stack.callback + def _exit(): + result.append(1) + self.assertIsNotNone(_exit) + stack.close() + result.append(2) + self.assertEqual(result, [1, 2]) + + def test_pop_all(self): + result = [] + with self.exit_stack() as stack: + @stack.callback + def _exit(): + result.append(3) + self.assertIsNotNone(_exit) + new_stack = stack.pop_all() + result.append(1) + result.append(2) + new_stack.close() + self.assertEqual(result, [1, 2, 3]) + + def test_exit_raise(self): + with self.assertRaises(ZeroDivisionError): + with self.exit_stack() as stack: + stack.push(lambda *exc: False) + 1/0 + + def test_exit_suppress(self): + with self.exit_stack() as stack: + stack.push(lambda *exc: True) + 1/0 + + def test_exit_exception_traceback(self): + # This test captures the current behavior of ExitStack so that we know + # if we ever unintendedly change it. It is not a statement of what the + # desired behavior is (for instance, we may want to remove some of the + # internal contextlib frames). + + def raise_exc(exc): + raise exc + + try: + with self.exit_stack() as stack: + stack.callback(raise_exc, ValueError) + 1/0 + except ValueError as e: + exc = e + + self.assertIsInstance(exc, ValueError) + ve_frames = traceback.extract_tb(exc.__traceback__) + expected = \ + [('test_exit_exception_traceback', 'with self.exit_stack() as stack:')] + \ + self.callback_error_internal_frames + \ + [('_exit_wrapper', 'callback(*args, **kwds)'), + ('raise_exc', 'raise exc')] + + self.assertEqual( + [(f.name, f.line) for f in ve_frames], expected) + + self.assertIsInstance(exc.__context__, ZeroDivisionError) + zde_frames = traceback.extract_tb(exc.__context__.__traceback__) + self.assertEqual([(f.name, f.line) for f in zde_frames], + [('test_exit_exception_traceback', '1/0')]) + + def test_exit_exception_chaining_reference(self): + # Sanity check to make sure that ExitStack chaining matches + # actual nested with statements + class RaiseExc: + def __init__(self, exc): + self.exc = exc + def __enter__(self): + return self + def __exit__(self, *exc_details): + raise self.exc + + class RaiseExcWithContext: + def __init__(self, outer, inner): + self.outer = outer + self.inner = inner + def __enter__(self): + return self + def __exit__(self, *exc_details): + try: + raise self.inner + except: + raise self.outer + + class SuppressExc: + def __enter__(self): + return self + def __exit__(self, *exc_details): + type(self).saved_details = exc_details + return True + + try: + with RaiseExc(IndexError): + with RaiseExcWithContext(KeyError, AttributeError): + with SuppressExc(): + with RaiseExc(ValueError): + 1 / 0 + except IndexError as exc: + self.assertIsInstance(exc.__context__, KeyError) + self.assertIsInstance(exc.__context__.__context__, AttributeError) + # Inner exceptions were suppressed + self.assertIsNone(exc.__context__.__context__.__context__) + else: + self.fail("Expected IndexError, but no exception was raised") + # Check the inner exceptions + inner_exc = SuppressExc.saved_details[1] + self.assertIsInstance(inner_exc, ValueError) + self.assertIsInstance(inner_exc.__context__, ZeroDivisionError) + + def test_exit_exception_chaining(self): + # Ensure exception chaining matches the reference behaviour + def raise_exc(exc): + raise exc + + saved_details = None + def suppress_exc(*exc_details): + nonlocal saved_details + saved_details = exc_details + return True + + try: + with self.exit_stack() as stack: + stack.callback(raise_exc, IndexError) + stack.callback(raise_exc, KeyError) + stack.callback(raise_exc, AttributeError) + stack.push(suppress_exc) + stack.callback(raise_exc, ValueError) + 1 / 0 + except IndexError as exc: + self.assertIsInstance(exc.__context__, KeyError) + self.assertIsInstance(exc.__context__.__context__, AttributeError) + # Inner exceptions were suppressed + self.assertIsNone(exc.__context__.__context__.__context__) + else: + self.fail("Expected IndexError, but no exception was raised") + # Check the inner exceptions + inner_exc = saved_details[1] + self.assertIsInstance(inner_exc, ValueError) + self.assertIsInstance(inner_exc.__context__, ZeroDivisionError) + + def test_exit_exception_explicit_none_context(self): + # Ensure ExitStack chaining matches actual nested `with` statements + # regarding explicit __context__ = None. + + class MyException(Exception): + pass + + @contextmanager + def my_cm(): + try: + yield + except BaseException: + exc = MyException() + try: + raise exc + finally: + exc.__context__ = None + + @contextmanager + def my_cm_with_exit_stack(): + with self.exit_stack() as stack: + stack.enter_context(my_cm()) + yield stack + + for cm in (my_cm, my_cm_with_exit_stack): + with self.subTest(): + try: + with cm(): + raise IndexError() + except MyException as exc: + self.assertIsNone(exc.__context__) + else: + self.fail("Expected IndexError, but no exception was raised") + + def test_exit_exception_non_suppressing(self): + # http://bugs.python.org/issue19092 + def raise_exc(exc): + raise exc + + def suppress_exc(*exc_details): + return True + + try: + with self.exit_stack() as stack: + stack.callback(lambda: None) + stack.callback(raise_exc, IndexError) + except Exception as exc: + self.assertIsInstance(exc, IndexError) + else: + self.fail("Expected IndexError, but no exception was raised") + + try: + with self.exit_stack() as stack: + stack.callback(raise_exc, KeyError) + stack.push(suppress_exc) + stack.callback(raise_exc, IndexError) + except Exception as exc: + self.assertIsInstance(exc, KeyError) + else: + self.fail("Expected KeyError, but no exception was raised") + + def test_exit_exception_with_correct_context(self): + # http://bugs.python.org/issue20317 + @contextmanager + def gets_the_context_right(exc): + try: + yield + finally: + raise exc + + exc1 = Exception(1) + exc2 = Exception(2) + exc3 = Exception(3) + exc4 = Exception(4) + + # The contextmanager already fixes the context, so prior to the + # fix, ExitStack would try to fix it *again* and get into an + # infinite self-referential loop + try: + with self.exit_stack() as stack: + stack.enter_context(gets_the_context_right(exc4)) + stack.enter_context(gets_the_context_right(exc3)) + stack.enter_context(gets_the_context_right(exc2)) + raise exc1 + except Exception as exc: + self.assertIs(exc, exc4) + self.assertIs(exc.__context__, exc3) + self.assertIs(exc.__context__.__context__, exc2) + self.assertIs(exc.__context__.__context__.__context__, exc1) + self.assertIsNone( + exc.__context__.__context__.__context__.__context__) + + def test_exit_exception_with_existing_context(self): + # Addresses a lack of test coverage discovered after checking in a + # fix for issue 20317 that still contained debugging code. + def raise_nested(inner_exc, outer_exc): + try: + raise inner_exc + finally: + raise outer_exc + exc1 = Exception(1) + exc2 = Exception(2) + exc3 = Exception(3) + exc4 = Exception(4) + exc5 = Exception(5) + try: + with self.exit_stack() as stack: + stack.callback(raise_nested, exc4, exc5) + stack.callback(raise_nested, exc2, exc3) + raise exc1 + except Exception as exc: + self.assertIs(exc, exc5) + self.assertIs(exc.__context__, exc4) + self.assertIs(exc.__context__.__context__, exc3) + self.assertIs(exc.__context__.__context__.__context__, exc2) + self.assertIs( + exc.__context__.__context__.__context__.__context__, exc1) + self.assertIsNone( + exc.__context__.__context__.__context__.__context__.__context__) + + def test_body_exception_suppress(self): + def suppress_exc(*exc_details): + return True + try: + with self.exit_stack() as stack: + stack.push(suppress_exc) + 1/0 + except IndexError as exc: + self.fail("Expected no exception, got IndexError") + + def test_exit_exception_chaining_suppress(self): + with self.exit_stack() as stack: + stack.push(lambda *exc: True) + stack.push(lambda *exc: 1/0) + stack.push(lambda *exc: {}[1]) + + def test_excessive_nesting(self): + # The original implementation would die with RecursionError here + with self.exit_stack() as stack: + for i in range(10000): + stack.callback(int) + + def test_instance_bypass(self): + class Example(object): pass + cm = Example() + cm.__enter__ = object() + cm.__exit__ = object() + stack = self.exit_stack() + with self.assertRaisesRegex(TypeError, 'the context manager'): + stack.enter_context(cm) + stack.push(cm) + self.assertIs(stack._exit_callbacks[-1][1], cm) + + def test_dont_reraise_RuntimeError(self): + # https://bugs.python.org/issue27122 + class UniqueException(Exception): pass + class UniqueRuntimeError(RuntimeError): pass + + @contextmanager + def second(): + try: + yield 1 + except Exception as exc: + raise UniqueException("new exception") from exc + + @contextmanager + def first(): + try: + yield 1 + except Exception as exc: + raise exc + + # The UniqueRuntimeError should be caught by second()'s exception + # handler which chain raised a new UniqueException. + with self.assertRaises(UniqueException) as err_ctx: + with self.exit_stack() as es_ctx: + es_ctx.enter_context(second()) + es_ctx.enter_context(first()) + raise UniqueRuntimeError("please no infinite loop.") + + exc = err_ctx.exception + self.assertIsInstance(exc, UniqueException) + self.assertIsInstance(exc.__context__, UniqueRuntimeError) + self.assertIsNone(exc.__context__.__context__) + self.assertIsNone(exc.__context__.__cause__) + self.assertIs(exc.__cause__, exc.__context__) + + +class TestExitStack(_TestBaseExitStack, __TestCase): + exit_stack = ExitStack + callback_error_internal_frames = [ + ('__exit__', 'raise exc'), + ('__exit__', 'if cb(*exc_details):'), + ] + + +class _TestRedirectStream: + + redirect_stream = None + orig_stream = None + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = self.redirect_stream.__doc__ + obj = self.redirect_stream(None) + self.assertEqual(obj.__doc__, cm_docstring) + + def test_no_redirect_in_init(self): + orig_stdout = getattr(sys, self.orig_stream) + self.redirect_stream(None) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) + + def test_redirect_to_string_io(self): + f = io.StringIO() + msg = "Consider an API like help(), which prints directly to stdout" + orig_stdout = getattr(sys, self.orig_stream) + with self.redirect_stream(f): + print(msg, file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) + s = f.getvalue().strip() + self.assertEqual(s, msg) + + def test_enter_result_is_target(self): + f = io.StringIO() + with self.redirect_stream(f) as enter_result: + self.assertIs(enter_result, f) + + def test_cm_is_reusable(self): + f = io.StringIO() + write_to_f = self.redirect_stream(f) + orig_stdout = getattr(sys, self.orig_stream) + with write_to_f: + print("Hello", end=" ", file=getattr(sys, self.orig_stream)) + with write_to_f: + print("World!", file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) + s = f.getvalue() + self.assertEqual(s, "Hello World!\n") + + def test_cm_is_reentrant(self): + f = io.StringIO() + write_to_f = self.redirect_stream(f) + orig_stdout = getattr(sys, self.orig_stream) + with write_to_f: + print("Hello", end=" ", file=getattr(sys, self.orig_stream)) + with write_to_f: + print("World!", file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) + s = f.getvalue() + self.assertEqual(s, "Hello World!\n") + + +class TestRedirectStdout(_TestRedirectStream, __TestCase): + + redirect_stream = redirect_stdout + orig_stream = "stdout" + + +class TestRedirectStderr(_TestRedirectStream, __TestCase): + + redirect_stream = redirect_stderr + orig_stream = "stderr" + + +class TestSuppress(ExceptionIsLikeMixin, __TestCase): + + @support.requires_docstrings + def test_instance_docs(self): + # Issue 19330: ensure context manager instances have good docstrings + cm_docstring = suppress.__doc__ + obj = suppress() + self.assertEqual(obj.__doc__, cm_docstring) + + def test_no_result_from_enter(self): + with suppress(ValueError) as enter_result: + self.assertIsNone(enter_result) + + def test_no_exception(self): + with suppress(ValueError): + self.assertEqual(pow(2, 5), 32) + + def test_exact_exception(self): + with suppress(TypeError): + len(5) + + def test_exception_hierarchy(self): + with suppress(LookupError): + 'Hello'[50] + + def test_other_exception(self): + with self.assertRaises(ZeroDivisionError): + with suppress(TypeError): + 1/0 + + def test_no_args(self): + with self.assertRaises(ZeroDivisionError): + with suppress(): + 1/0 + + def test_multiple_exception_args(self): + with suppress(ZeroDivisionError, TypeError): + 1/0 + with suppress(ZeroDivisionError, TypeError): + len(5) + + def test_cm_is_reentrant(self): + ignore_exceptions = suppress(Exception) + with ignore_exceptions: + pass + with ignore_exceptions: + len(5) + with ignore_exceptions: + with ignore_exceptions: # Check nested usage + len(5) + outer_continued = True + 1/0 + self.assertTrue(outer_continued) + + def test_exception_groups(self): + eg_ve = lambda: ExceptionGroup( + "EG with ValueErrors only", + [ValueError("ve1"), ValueError("ve2"), ValueError("ve3")], + ) + eg_all = lambda: ExceptionGroup( + "EG with many types of exceptions", + [ValueError("ve1"), KeyError("ke1"), ValueError("ve2"), KeyError("ke2")], + ) + with suppress(ValueError): + raise eg_ve() + with suppress(ValueError, KeyError): + raise eg_all() + with self.assertRaises(ExceptionGroup) as eg1: + with suppress(ValueError): + raise eg_all() + self.assertExceptionIsLike( + eg1.exception, + ExceptionGroup( + "EG with many types of exceptions", + [KeyError("ke1"), KeyError("ke2")], + ), + ) + # Check handling of BaseExceptionGroup, using GeneratorExit so that + # we don't accidentally discard a ctrl-c with KeyboardInterrupt. + with suppress(GeneratorExit): + raise BaseExceptionGroup("message", [GeneratorExit()]) + # If we raise a BaseException group, we can still suppress parts + with self.assertRaises(BaseExceptionGroup) as eg1: + with suppress(KeyError): + raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")]) + self.assertExceptionIsLike( + eg1.exception, BaseExceptionGroup("message", [GeneratorExit("g")]), + ) + # If we suppress all the leaf BaseExceptions, we get a non-base ExceptionGroup + with self.assertRaises(ExceptionGroup) as eg1: + with suppress(GeneratorExit): + raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")]) + self.assertExceptionIsLike( + eg1.exception, ExceptionGroup("message", [KeyError("k")]), + ) + + +class TestChdir(__TestCase): + def make_relative_path(self, *parts): + return os.path.join( + os.path.dirname(os.path.realpath(__file__)), + *parts, + ) + + def test_simple(self): + old_cwd = os.getcwd() + target = self.make_relative_path('data') + self.assertNotEqual(old_cwd, target) + + with chdir(target): + self.assertEqual(os.getcwd(), target) + self.assertEqual(os.getcwd(), old_cwd) + + @unittest.skip("Missing archivetestdata") + def test_reentrant(self): + old_cwd = os.getcwd() + target1 = self.make_relative_path('data') + target2 = self.make_relative_path('archivetestdata') + self.assertNotIn(old_cwd, (target1, target2)) + chdir1, chdir2 = chdir(target1), chdir(target2) + + with chdir1: + self.assertEqual(os.getcwd(), target1) + with chdir2: + self.assertEqual(os.getcwd(), target2) + with chdir1: + self.assertEqual(os.getcwd(), target1) + self.assertEqual(os.getcwd(), target2) + self.assertEqual(os.getcwd(), target1) + self.assertEqual(os.getcwd(), old_cwd) + + def test_exception(self): + old_cwd = os.getcwd() + target = self.make_relative_path('data') + self.assertNotEqual(old_cwd, target) + + try: + with chdir(target): + self.assertEqual(os.getcwd(), target) + raise RuntimeError("boom") + except RuntimeError as re: + self.assertEqual(str(re), "boom") + self.assertEqual(os.getcwd(), old_cwd) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_dict.diff b/test/dynamo/cpython/3_13/test_dict.diff new file mode 100644 index 00000000000000..9589bcf797bd92 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_dict.diff @@ -0,0 +1,122 @@ +diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py +index 4729132c5a5..14f829c1715 100644 +--- a/test/dynamo/cpython/3_13/test_dict.py ++++ b/test/dynamo/cpython/3_13/test_dict.py +@@ -1,3 +1,57 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++ xfailIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import collections + import collections.abc + import gc +@@ -11,7 +65,7 @@ from test import support + from test.support import import_helper, get_c_recursion_limit + + +-class DictTest(unittest.TestCase): ++class DictTest(__TestCase): + + def test_invalid_keyword_arguments(self): + class Custom(dict): +@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase): + + self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) + ++ @unittest.skip("test hangs") + def test_fromkeys(self): + self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) + d = {} +@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase): + for copymode in -1, +1: + # -1: b has same structure as a + # +1: b is a.copy() +- for log2size in range(12): ++ for log2size in range(4): + size = 2**log2size + a = {} + b = {} +@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase): + pass + self._tracked(MyDict()) + +- @support.cpython_only +- def test_track_lazy_instance_dicts(self): +- class C: +- pass +- o = C() +- d = o.__dict__ +- self._not_tracked(d) +- o.untracked = 42 +- self._not_tracked(d) +- o.tracked = [] +- self._tracked(d) +- + def make_shared_key_dict(self, n): + class C: + pass +@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase): + self.assertGreaterEqual(eq_count, 1) + + +-class CAPITest(unittest.TestCase): ++class CAPITest(__TestCase): + + # Test _PyDict_GetItem_KnownHash() + @support.cpython_only +@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py new file mode 100644 index 00000000000000..14f829c1715c1c --- /dev/null +++ b/test/dynamo/cpython/3_13/test_dict.py @@ -0,0 +1,1712 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, + xfailIfTorchDynamo, +) + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import collections +import collections.abc +import gc +import pickle +import random +import string +import sys +import unittest +import weakref +from test import support +from test.support import import_helper, get_c_recursion_limit + + +class DictTest(__TestCase): + + def test_invalid_keyword_arguments(self): + class Custom(dict): + pass + for invalid in {1 : 2}, Custom({1 : 2}): + with self.assertRaises(TypeError): + dict(**invalid) + with self.assertRaises(TypeError): + {}.update(**invalid) + + def test_constructor(self): + # calling built-in types without argument must return empty + self.assertEqual(dict(), {}) + self.assertIsNot(dict(), {}) + + def test_literal_constructor(self): + # check literal constructor for different sized dicts + # (to exercise the BUILD_MAP oparg). + for n in (0, 1, 6, 256, 400): + items = [(''.join(random.sample(string.ascii_letters, 8)), i) + for i in range(n)] + random.shuffle(items) + formatted_items = ('{!r}: {:d}'.format(k, v) for k, v in items) + dictliteral = '{' + ', '.join(formatted_items) + '}' + self.assertEqual(eval(dictliteral), dict(items)) + + def test_merge_operator(self): + + a = {0: 0, 1: 1, 2: 1} + b = {1: 1, 2: 2, 3: 3} + + c = a.copy() + c |= b + + self.assertEqual(a | b, {0: 0, 1: 1, 2: 2, 3: 3}) + self.assertEqual(c, {0: 0, 1: 1, 2: 2, 3: 3}) + + c = b.copy() + c |= a + + self.assertEqual(b | a, {1: 1, 2: 1, 3: 3, 0: 0}) + self.assertEqual(c, {1: 1, 2: 1, 3: 3, 0: 0}) + + c = a.copy() + c |= [(1, 1), (2, 2), (3, 3)] + + self.assertEqual(c, {0: 0, 1: 1, 2: 2, 3: 3}) + + self.assertIs(a.__or__(None), NotImplemented) + self.assertIs(a.__or__(()), NotImplemented) + self.assertIs(a.__or__("BAD"), NotImplemented) + self.assertIs(a.__or__(""), NotImplemented) + + self.assertRaises(TypeError, a.__ior__, None) + self.assertEqual(a.__ior__(()), {0: 0, 1: 1, 2: 1}) + self.assertRaises(ValueError, a.__ior__, "BAD") + self.assertEqual(a.__ior__(""), {0: 0, 1: 1, 2: 1}) + + def test_bool(self): + self.assertIs(not {}, True) + self.assertTrue({1: 2}) + self.assertIs(bool({}), False) + self.assertIs(bool({1: 2}), True) + + def test_keys(self): + d = {} + self.assertEqual(set(d.keys()), set()) + d = {'a': 1, 'b': 2} + k = d.keys() + self.assertEqual(set(k), {'a', 'b'}) + self.assertIn('a', k) + self.assertIn('b', k) + self.assertIn('a', d) + self.assertIn('b', d) + self.assertRaises(TypeError, d.keys, None) + self.assertEqual(repr(dict(a=1).keys()), "dict_keys(['a'])") + + def test_values(self): + d = {} + self.assertEqual(set(d.values()), set()) + d = {1:2} + self.assertEqual(set(d.values()), {2}) + self.assertRaises(TypeError, d.values, None) + self.assertEqual(repr(dict(a=1).values()), "dict_values([1])") + + def test_items(self): + d = {} + self.assertEqual(set(d.items()), set()) + + d = {1:2} + self.assertEqual(set(d.items()), {(1, 2)}) + self.assertRaises(TypeError, d.items, None) + self.assertEqual(repr(dict(a=1).items()), "dict_items([('a', 1)])") + + def test_views_mapping(self): + mappingproxy = type(type.__dict__) + class Dict(dict): + pass + for cls in [dict, Dict]: + d = cls() + m1 = d.keys().mapping + m2 = d.values().mapping + m3 = d.items().mapping + + for m in [m1, m2, m3]: + self.assertIsInstance(m, mappingproxy) + self.assertEqual(m, d) + + d["foo"] = "bar" + + for m in [m1, m2, m3]: + self.assertIsInstance(m, mappingproxy) + self.assertEqual(m, d) + + def test_contains(self): + d = {} + self.assertNotIn('a', d) + self.assertFalse('a' in d) + self.assertTrue('a' not in d) + d = {'a': 1, 'b': 2} + self.assertIn('a', d) + self.assertIn('b', d) + self.assertNotIn('c', d) + + self.assertRaises(TypeError, d.__contains__) + + def test_len(self): + d = {} + self.assertEqual(len(d), 0) + d = {'a': 1, 'b': 2} + self.assertEqual(len(d), 2) + + def test_getitem(self): + d = {'a': 1, 'b': 2} + self.assertEqual(d['a'], 1) + self.assertEqual(d['b'], 2) + d['c'] = 3 + d['a'] = 4 + self.assertEqual(d['c'], 3) + self.assertEqual(d['a'], 4) + del d['b'] + self.assertEqual(d, {'a': 4, 'c': 3}) + + self.assertRaises(TypeError, d.__getitem__) + + class BadEq(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 24 + + d = {} + d[BadEq()] = 42 + self.assertRaises(KeyError, d.__getitem__, 23) + + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.__getitem__, x) + + def test_clear(self): + d = {1:1, 2:2, 3:3} + d.clear() + self.assertEqual(d, {}) + + self.assertRaises(TypeError, d.clear, None) + + def test_update(self): + d = {} + d.update({1:100}) + d.update({2:20}) + d.update({1:1, 2:2, 3:3}) + self.assertEqual(d, {1:1, 2:2, 3:3}) + + d.update() + self.assertEqual(d, {1:1, 2:2, 3:3}) + + self.assertRaises((TypeError, AttributeError), d.update, None) + + class SimpleUserDict: + def __init__(self): + self.d = {1:1, 2:2, 3:3} + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] + d.clear() + d.update(SimpleUserDict()) + self.assertEqual(d, {1:1, 2:2, 3:3}) + + class Exc(Exception): pass + + d.clear() + class FailingUserDict: + def keys(self): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = 1 + def __iter__(self): + return self + def __next__(self): + if self.i: + self.i = 0 + return 'a' + raise Exc + return BogonIter() + def __getitem__(self, key): + return key + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = ord('a') + def __iter__(self): + return self + def __next__(self): + if self.i <= ord('z'): + rtn = chr(self.i) + self.i += 1 + return rtn + raise StopIteration + return BogonIter() + def __getitem__(self, key): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + class badseq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, {}.update, badseq()) + + self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) + + @unittest.skip("test hangs") + def test_fromkeys(self): + self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) + d = {} + self.assertIsNot(d.fromkeys('abc'), d) + self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) + self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0}) + self.assertEqual(d.fromkeys([]), {}) + def g(): + yield 1 + self.assertEqual(d.fromkeys(g()), {1:None}) + self.assertRaises(TypeError, {}.fromkeys, 3) + class dictlike(dict): pass + self.assertEqual(dictlike.fromkeys('a'), {'a':None}) + self.assertEqual(dictlike().fromkeys('a'), {'a':None}) + self.assertIsInstance(dictlike.fromkeys('a'), dictlike) + self.assertIsInstance(dictlike().fromkeys('a'), dictlike) + class mydict(dict): + def __new__(cls): + return collections.UserDict() + ud = mydict.fromkeys('ab') + self.assertEqual(ud, {'a':None, 'b':None}) + self.assertIsInstance(ud, collections.UserDict) + self.assertRaises(TypeError, dict.fromkeys) + + class Exc(Exception): pass + + class baddict1(dict): + def __init__(self): + raise Exc() + + self.assertRaises(Exc, baddict1.fromkeys, [1]) + + class BadSeq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, dict.fromkeys, BadSeq()) + + class baddict2(dict): + def __setitem__(self, key, value): + raise Exc() + + self.assertRaises(Exc, baddict2.fromkeys, [1]) + + # test fast path for dictionary inputs + res = dict(zip(range(6), [0]*6)) + d = dict(zip(range(6), range(6))) + self.assertEqual(dict.fromkeys(d, 0), res) + # test fast path for set inputs + d = set(range(6)) + self.assertEqual(dict.fromkeys(d, 0), res) + # test slow path for other iterable inputs + d = list(range(6)) + self.assertEqual(dict.fromkeys(d, 0), res) + + # test fast path when object's constructor returns large non-empty dict + class baddict3(dict): + def __new__(cls): + return d + d = {i : i for i in range(1000)} + res = d.copy() + res.update(a=None, b=None, c=None) + self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res) + + # test slow path when object is a proper subclass of dict + class baddict4(dict): + def __init__(self): + dict.__init__(self, d) + d = {i : i for i in range(1000)} + res = d.copy() + res.update(a=None, b=None, c=None) + self.assertEqual(baddict4.fromkeys({"a", "b", "c"}), res) + + def test_copy(self): + d = {1: 1, 2: 2, 3: 3} + self.assertIsNot(d.copy(), d) + self.assertEqual(d.copy(), d) + self.assertEqual(d.copy(), {1: 1, 2: 2, 3: 3}) + + copy = d.copy() + d[4] = 4 + self.assertNotEqual(copy, d) + + self.assertEqual({}.copy(), {}) + self.assertRaises(TypeError, d.copy, None) + + def test_copy_fuzz(self): + for dict_size in [10, 100, 1000, 10000, 100000]: + dict_size = random.randrange( + dict_size // 2, dict_size + dict_size // 2) + with self.subTest(dict_size=dict_size): + d = {} + for i in range(dict_size): + d[i] = i + + d2 = d.copy() + self.assertIsNot(d2, d) + self.assertEqual(d, d2) + d2['key'] = 'value' + self.assertNotEqual(d, d2) + self.assertEqual(len(d2), len(d) + 1) + + def test_copy_maintains_tracking(self): + class A: + pass + + key = A() + + for d in ({}, {'a': 1}, {key: 'val'}): + d2 = d.copy() + self.assertEqual(gc.is_tracked(d), gc.is_tracked(d2)) + + def test_copy_noncompact(self): + # Dicts don't compact themselves on del/pop operations. + # Copy will use a slow merging strategy that produces + # a compacted copy when roughly 33% of dict is a non-used + # keys-space (to optimize memory footprint). + # In this test we want to hit the slow/compacting + # branch of dict.copy() and make sure it works OK. + d = {k: k for k in range(1000)} + for k in range(950): + del d[k] + d2 = d.copy() + self.assertEqual(d2, d) + + def test_get(self): + d = {} + self.assertIs(d.get('c'), None) + self.assertEqual(d.get('c', 3), 3) + d = {'a': 1, 'b': 2} + self.assertIs(d.get('c'), None) + self.assertEqual(d.get('c', 3), 3) + self.assertEqual(d.get('a'), 1) + self.assertEqual(d.get('a', 3), 1) + self.assertRaises(TypeError, d.get) + self.assertRaises(TypeError, d.get, None, None, None) + + def test_setdefault(self): + # dict.setdefault() + d = {} + self.assertIs(d.setdefault('key0'), None) + d.setdefault('key0', []) + self.assertIs(d.setdefault('key0'), None) + d.setdefault('key', []).append(3) + self.assertEqual(d['key'][0], 3) + d.setdefault('key', []).append(4) + self.assertEqual(len(d['key']), 2) + self.assertRaises(TypeError, d.setdefault) + + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.setdefault, x, []) + + def test_setdefault_atomic(self): + # Issue #13521: setdefault() calls __hash__ and __eq__ only once. + class Hashed(object): + def __init__(self): + self.hash_count = 0 + self.eq_count = 0 + def __hash__(self): + self.hash_count += 1 + return 42 + def __eq__(self, other): + self.eq_count += 1 + return id(self) == id(other) + hashed1 = Hashed() + y = {hashed1: 5} + hashed2 = Hashed() + y.setdefault(hashed2, []) + self.assertEqual(hashed1.hash_count, 1) + self.assertEqual(hashed2.hash_count, 1) + self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1) + + def test_setitem_atomic_at_resize(self): + class Hashed(object): + def __init__(self): + self.hash_count = 0 + self.eq_count = 0 + def __hash__(self): + self.hash_count += 1 + return 42 + def __eq__(self, other): + self.eq_count += 1 + return id(self) == id(other) + hashed1 = Hashed() + # 5 items + y = {hashed1: 5, 0: 0, 1: 1, 2: 2, 3: 3} + hashed2 = Hashed() + # 6th item forces a resize + y[hashed2] = [] + self.assertEqual(hashed1.hash_count, 1) + self.assertEqual(hashed2.hash_count, 1) + self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1) + + def test_popitem(self): + # dict.popitem() + for copymode in -1, +1: + # -1: b has same structure as a + # +1: b is a.copy() + for log2size in range(4): + size = 2**log2size + a = {} + b = {} + for i in range(size): + a[repr(i)] = i + if copymode < 0: + b[repr(i)] = i + if copymode > 0: + b = a.copy() + for i in range(size): + ka, va = ta = a.popitem() + self.assertEqual(va, int(ka)) + kb, vb = tb = b.popitem() + self.assertEqual(vb, int(kb)) + self.assertFalse(copymode < 0 and ta != tb) + self.assertFalse(a) + self.assertFalse(b) + + d = {} + self.assertRaises(KeyError, d.popitem) + + def test_pop(self): + # Tests for pop with specified key + d = {} + k, v = 'abc', 'def' + d[k] = v + self.assertRaises(KeyError, d.pop, 'ghi') + + self.assertEqual(d.pop(k), v) + self.assertEqual(len(d), 0) + + self.assertRaises(KeyError, d.pop, k) + + self.assertEqual(d.pop(k, v), v) + d[k] = v + self.assertEqual(d.pop(k, 1), v) + + self.assertRaises(TypeError, d.pop) + + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 + + x = BadHash() + d[x] = 42 + x.fail = True + self.assertRaises(Exc, d.pop, x) + + def test_mutating_iteration(self): + # changing dict size during iteration + d = {} + d[1] = 1 + with self.assertRaises(RuntimeError): + for i in d: + d[i+1] = 1 + + def test_mutating_iteration_delete(self): + # change dict content during iteration + d = {} + d[0] = 0 + with self.assertRaises(RuntimeError): + for i in d: + del d[0] + d[0] = 0 + + def test_mutating_iteration_delete_over_values(self): + # change dict content during iteration + d = {} + d[0] = 0 + with self.assertRaises(RuntimeError): + for i in d.values(): + del d[0] + d[0] = 0 + + def test_mutating_iteration_delete_over_items(self): + # change dict content during iteration + d = {} + d[0] = 0 + with self.assertRaises(RuntimeError): + for i in d.items(): + del d[0] + d[0] = 0 + + def test_mutating_lookup(self): + # changing dict during a lookup (issue #14417) + class NastyKey: + mutate_dict = None + + def __init__(self, value): + self.value = value + + def __hash__(self): + # hash collision! + return 1 + + def __eq__(self, other): + if NastyKey.mutate_dict: + mydict, key = NastyKey.mutate_dict + NastyKey.mutate_dict = None + del mydict[key] + return self.value == other.value + + key1 = NastyKey(1) + key2 = NastyKey(2) + d = {key1: 1} + NastyKey.mutate_dict = (d, key1) + d[key2] = 2 + self.assertEqual(d, {key2: 2}) + + def test_repr(self): + d = {} + self.assertEqual(repr(d), '{}') + d[1] = 2 + self.assertEqual(repr(d), '{1: 2}') + d = {} + d[1] = d + self.assertEqual(repr(d), '{1: {...}}') + + class Exc(Exception): pass + + class BadRepr(object): + def __repr__(self): + raise Exc() + + d = {1: BadRepr()} + self.assertRaises(Exc, repr, d) + + def test_repr_deep(self): + d = {} + for i in range(get_c_recursion_limit() + 1): + d = {1: d} + self.assertRaises(RecursionError, repr, d) + + def test_eq(self): + self.assertEqual({}, {}) + self.assertEqual({1: 2}, {1: 2}) + + class Exc(Exception): pass + + class BadCmp(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 1 + + d1 = {BadCmp(): 1} + d2 = {1: 1} + + with self.assertRaises(Exc): + d1 == d2 + + def test_keys_contained(self): + self.helper_keys_contained(lambda x: x.keys()) + self.helper_keys_contained(lambda x: x.items()) + + def helper_keys_contained(self, fn): + # Test rich comparisons against dict key views, which should behave the + # same as sets. + empty = fn(dict()) + empty2 = fn(dict()) + smaller = fn({1:1, 2:2}) + larger = fn({1:1, 2:2, 3:3}) + larger2 = fn({1:1, 2:2, 3:3}) + larger3 = fn({4:1, 2:2, 3:3}) + + self.assertTrue(smaller < larger) + self.assertTrue(smaller <= larger) + self.assertTrue(larger > smaller) + self.assertTrue(larger >= smaller) + + self.assertFalse(smaller >= larger) + self.assertFalse(smaller > larger) + self.assertFalse(larger <= smaller) + self.assertFalse(larger < smaller) + + self.assertFalse(smaller < larger3) + self.assertFalse(smaller <= larger3) + self.assertFalse(larger3 > smaller) + self.assertFalse(larger3 >= smaller) + + # Inequality strictness + self.assertTrue(larger2 >= larger) + self.assertTrue(larger2 <= larger) + self.assertFalse(larger2 > larger) + self.assertFalse(larger2 < larger) + + self.assertTrue(larger == larger2) + self.assertTrue(smaller != larger) + + # There is an optimization on the zero-element case. + self.assertTrue(empty == empty2) + self.assertFalse(empty != empty2) + self.assertFalse(empty == smaller) + self.assertTrue(empty != smaller) + + # With the same size, an elementwise compare happens + self.assertTrue(larger != larger3) + self.assertFalse(larger == larger3) + + def test_errors_in_view_containment_check(self): + class C: + def __eq__(self, other): + raise RuntimeError + + d1 = {1: C()} + d2 = {1: C()} + with self.assertRaises(RuntimeError): + d1.items() == d2.items() + with self.assertRaises(RuntimeError): + d1.items() != d2.items() + with self.assertRaises(RuntimeError): + d1.items() <= d2.items() + with self.assertRaises(RuntimeError): + d1.items() >= d2.items() + + d3 = {1: C(), 2: C()} + with self.assertRaises(RuntimeError): + d2.items() < d3.items() + with self.assertRaises(RuntimeError): + d3.items() > d2.items() + + def test_dictview_set_operations_on_keys(self): + k1 = {1:1, 2:2}.keys() + k2 = {1:1, 2:2, 3:3}.keys() + k3 = {4:4}.keys() + + self.assertEqual(k1 - k2, set()) + self.assertEqual(k1 - k3, {1,2}) + self.assertEqual(k2 - k1, {3}) + self.assertEqual(k3 - k1, {4}) + self.assertEqual(k1 & k2, {1,2}) + self.assertEqual(k1 & k3, set()) + self.assertEqual(k1 | k2, {1,2,3}) + self.assertEqual(k1 ^ k2, {3}) + self.assertEqual(k1 ^ k3, {1,2,4}) + + def test_dictview_set_operations_on_items(self): + k1 = {1:1, 2:2}.items() + k2 = {1:1, 2:2, 3:3}.items() + k3 = {4:4}.items() + + self.assertEqual(k1 - k2, set()) + self.assertEqual(k1 - k3, {(1,1), (2,2)}) + self.assertEqual(k2 - k1, {(3,3)}) + self.assertEqual(k3 - k1, {(4,4)}) + self.assertEqual(k1 & k2, {(1,1), (2,2)}) + self.assertEqual(k1 & k3, set()) + self.assertEqual(k1 | k2, {(1,1), (2,2), (3,3)}) + self.assertEqual(k1 ^ k2, {(3,3)}) + self.assertEqual(k1 ^ k3, {(1,1), (2,2), (4,4)}) + + def test_items_symmetric_difference(self): + rr = random.randrange + for _ in range(100): + left = {x:rr(3) for x in range(20) if rr(2)} + right = {x:rr(3) for x in range(20) if rr(2)} + with self.subTest(left=left, right=right): + expected = set(left.items()) ^ set(right.items()) + actual = left.items() ^ right.items() + self.assertEqual(actual, expected) + + def test_dictview_mixed_set_operations(self): + # Just a few for .keys() + self.assertTrue({1:1}.keys() == {1}) + self.assertTrue({1} == {1:1}.keys()) + self.assertEqual({1:1}.keys() | {2}, {1, 2}) + self.assertEqual({2} | {1:1}.keys(), {1, 2}) + # And a few for .items() + self.assertTrue({1:1}.items() == {(1,1)}) + self.assertTrue({(1,1)} == {1:1}.items()) + self.assertEqual({1:1}.items() | {2}, {(1,1), 2}) + self.assertEqual({2} | {1:1}.items(), {(1,1), 2}) + + def test_missing(self): + # Make sure dict doesn't have a __missing__ method + self.assertFalse(hasattr(dict, "__missing__")) + self.assertFalse(hasattr({}, "__missing__")) + # Test several cases: + # (D) subclass defines __missing__ method returning a value + # (E) subclass defines __missing__ method raising RuntimeError + # (F) subclass sets __missing__ instance variable (no effect) + # (G) subclass doesn't define __missing__ at all + class D(dict): + def __missing__(self, key): + return 42 + d = D({1: 2, 3: 4}) + self.assertEqual(d[1], 2) + self.assertEqual(d[3], 4) + self.assertNotIn(2, d) + self.assertNotIn(2, d.keys()) + self.assertEqual(d[2], 42) + + class E(dict): + def __missing__(self, key): + raise RuntimeError(key) + e = E() + with self.assertRaises(RuntimeError) as c: + e[42] + self.assertEqual(c.exception.args, (42,)) + + class F(dict): + def __init__(self): + # An instance variable __missing__ should have no effect + self.__missing__ = lambda key: None + f = F() + with self.assertRaises(KeyError) as c: + f[42] + self.assertEqual(c.exception.args, (42,)) + + class G(dict): + pass + g = G() + with self.assertRaises(KeyError) as c: + g[42] + self.assertEqual(c.exception.args, (42,)) + + def test_tuple_keyerror(self): + # SF #1576657 + d = {} + with self.assertRaises(KeyError) as c: + d[(1,)] + self.assertEqual(c.exception.args, ((1,),)) + + def test_bad_key(self): + # Dictionary lookups should fail if __eq__() raises an exception. + class CustomException(Exception): + pass + + class BadDictKey: + def __hash__(self): + return hash(self.__class__) + + def __eq__(self, other): + if isinstance(other, self.__class__): + raise CustomException + return other + + d = {} + x1 = BadDictKey() + x2 = BadDictKey() + d[x1] = 1 + for stmt in ['d[x2] = 2', + 'z = d[x2]', + 'x2 in d', + 'd.get(x2)', + 'd.setdefault(x2, 42)', + 'd.pop(x2)', + 'd.update({x2: 2})']: + with self.assertRaises(CustomException): + exec(stmt, locals()) + + def test_resize1(self): + # Dict resizing bug, found by Jack Jansen in 2.2 CVS development. + # This version got an assert failure in debug build, infinite loop in + # release build. Unfortunately, provoking this kind of stuff requires + # a mix of inserts and deletes hitting exactly the right hash codes in + # exactly the right order, and I can't think of a randomized approach + # that would be *likely* to hit a failing case in reasonable time. + + d = {} + for i in range(5): + d[i] = i + for i in range(5): + del d[i] + for i in range(5, 9): # i==8 was the problem + d[i] = i + + def test_resize2(self): + # Another dict resizing bug (SF bug #1456209). + # This caused Segmentation faults or Illegal instructions. + + class X(object): + def __hash__(self): + return 5 + def __eq__(self, other): + if resizing: + d.clear() + return False + d = {} + resizing = False + d[X()] = 1 + d[X()] = 2 + d[X()] = 3 + d[X()] = 4 + d[X()] = 5 + # now trigger a resize + resizing = True + d[9] = 6 + + def test_empty_presized_dict_in_freelist(self): + # Bug #3537: if an empty but presized dict with a size larger + # than 7 was in the freelist, it triggered an assertion failure + with self.assertRaises(ZeroDivisionError): + d = {'a': 1 // 0, 'b': None, 'c': None, 'd': None, 'e': None, + 'f': None, 'g': None, 'h': None} + d = {} + + def test_container_iterator(self): + # Bug #3680: tp_traverse was not implemented for dictiter and + # dictview objects. + class C(object): + pass + views = (dict.items, dict.values, dict.keys) + for v in views: + obj = C() + ref = weakref.ref(obj) + container = {obj: 1} + obj.v = v(container) + obj.x = iter(obj.v) + del obj, container + gc.collect() + self.assertIs(ref(), None, "Cycle was not collected") + + def _not_tracked(self, t): + # Nested containers can take several collections to untrack + gc.collect() + gc.collect() + self.assertFalse(gc.is_tracked(t), t) + + def _tracked(self, t): + self.assertTrue(gc.is_tracked(t), t) + gc.collect() + gc.collect() + self.assertTrue(gc.is_tracked(t), t) + + def test_string_keys_can_track_values(self): + # Test that this doesn't leak. + for i in range(10): + d = {} + for j in range(10): + d[str(j)] = j + d["foo"] = d + + @support.cpython_only + def test_track_literals(self): + # Test GC-optimization of dict literals + x, y, z, w = 1.5, "a", (1, None), [] + + self._not_tracked({}) + self._not_tracked({x:(), y:x, z:1}) + self._not_tracked({1: "a", "b": 2}) + self._not_tracked({1: 2, (None, True, False, ()): int}) + self._not_tracked({1: object()}) + + # Dicts with mutable elements are always tracked, even if those + # elements are not tracked right now. + self._tracked({1: []}) + self._tracked({1: ([],)}) + self._tracked({1: {}}) + self._tracked({1: set()}) + + @support.cpython_only + def test_track_dynamic(self): + # Test GC-optimization of dynamically-created dicts + class MyObject(object): + pass + x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject() + + d = dict() + self._not_tracked(d) + d[1] = "a" + self._not_tracked(d) + d[y] = 2 + self._not_tracked(d) + d[z] = 3 + self._not_tracked(d) + self._not_tracked(d.copy()) + d[4] = w + self._tracked(d) + self._tracked(d.copy()) + d[4] = None + self._not_tracked(d) + self._not_tracked(d.copy()) + + # dd isn't tracked right now, but it may mutate and therefore d + # which contains it must be tracked. + d = dict() + dd = dict() + d[1] = dd + self._not_tracked(dd) + self._tracked(d) + dd[1] = d + self._tracked(dd) + + d = dict.fromkeys([x, y, z]) + self._not_tracked(d) + dd = dict() + dd.update(d) + self._not_tracked(dd) + d = dict.fromkeys([x, y, z, o]) + self._tracked(d) + dd = dict() + dd.update(d) + self._tracked(dd) + + d = dict(x=x, y=y, z=z) + self._not_tracked(d) + d = dict(x=x, y=y, z=z, w=w) + self._tracked(d) + d = dict() + d.update(x=x, y=y, z=z) + self._not_tracked(d) + d.update(w=w) + self._tracked(d) + + d = dict([(x, y), (z, 1)]) + self._not_tracked(d) + d = dict([(x, y), (z, w)]) + self._tracked(d) + d = dict() + d.update([(x, y), (z, 1)]) + self._not_tracked(d) + d.update([(x, y), (z, w)]) + self._tracked(d) + + @support.cpython_only + def test_track_subtypes(self): + # Dict subtypes are always tracked + class MyDict(dict): + pass + self._tracked(MyDict()) + + def make_shared_key_dict(self, n): + class C: + pass + + dicts = [] + for i in range(n): + a = C() + a.x, a.y, a.z = 1, 2, 3 + dicts.append(a.__dict__) + + return dicts + + @support.cpython_only + def test_splittable_setdefault(self): + """split table must keep correct insertion + order when attributes are adding using setdefault()""" + a, b = self.make_shared_key_dict(2) + + a['a'] = 1 + size_a = sys.getsizeof(a) + a['b'] = 2 + b.setdefault('b', 2) + size_b = sys.getsizeof(b) + b['a'] = 1 + + self.assertEqual(list(a), ['x', 'y', 'z', 'a', 'b']) + self.assertEqual(list(b), ['x', 'y', 'z', 'b', 'a']) + + @support.cpython_only + def test_splittable_del(self): + """split table must be combined when del d[k]""" + a, b = self.make_shared_key_dict(2) + + orig_size = sys.getsizeof(a) + + del a['y'] # split table is combined + with self.assertRaises(KeyError): + del a['y'] + + self.assertEqual(list(a), ['x', 'z']) + self.assertEqual(list(b), ['x', 'y', 'z']) + + # Two dicts have different insertion order. + a['y'] = 42 + self.assertEqual(list(a), ['x', 'z', 'y']) + self.assertEqual(list(b), ['x', 'y', 'z']) + + @support.cpython_only + def test_splittable_pop(self): + a, b = self.make_shared_key_dict(2) + + a.pop('y') + with self.assertRaises(KeyError): + a.pop('y') + + self.assertEqual(list(a), ['x', 'z']) + self.assertEqual(list(b), ['x', 'y', 'z']) + + # Two dicts have different insertion order. + a['y'] = 42 + self.assertEqual(list(a), ['x', 'z', 'y']) + self.assertEqual(list(b), ['x', 'y', 'z']) + + @support.cpython_only + def test_splittable_pop_pending(self): + """pop a pending key in a split table should not crash""" + a, b = self.make_shared_key_dict(2) + + a['a'] = 4 + with self.assertRaises(KeyError): + b.pop('a') + + @support.cpython_only + def test_splittable_popitem(self): + """split table must be combined when d.popitem()""" + a, b = self.make_shared_key_dict(2) + + orig_size = sys.getsizeof(a) + + item = a.popitem() # split table is combined + self.assertEqual(item, ('z', 3)) + with self.assertRaises(KeyError): + del a['z'] + + self.assertGreater(sys.getsizeof(a), orig_size) + self.assertEqual(list(a), ['x', 'y']) + self.assertEqual(list(b), ['x', 'y', 'z']) + + @support.cpython_only + def test_splittable_update(self): + """dict.update(other) must preserve order in other.""" + class C: + def __init__(self, order): + if order: + self.a, self.b, self.c = 1, 2, 3 + else: + self.c, self.b, self.a = 1, 2, 3 + o = C(True) + o = C(False) # o.__dict__ has reversed order. + self.assertEqual(list(o.__dict__), ["c", "b", "a"]) + + d = {} + d.update(o.__dict__) + self.assertEqual(list(d), ["c", "b", "a"]) + + @support.cpython_only + def test_splittable_to_generic_combinedtable(self): + """split table must be correctly resized and converted to generic combined table""" + class C: + pass + + a = C() + a.x = 1 + d = a.__dict__ + d[2] = 2 # split table is resized to a generic combined table + + self.assertEqual(list(d), ['x', 2]) + + def test_iterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + it = iter(data) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + self.assertEqual(list(it), list(data)) + + it = pickle.loads(d) + try: + drop = next(it) + except StopIteration: + continue + d = pickle.dumps(it, proto) + it = pickle.loads(d) + del data[drop] + self.assertEqual(list(it), list(data)) + + def test_itemiterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + # dictviews aren't picklable, only their iterators + itorg = iter(data.items()) + d = pickle.dumps(itorg, proto) + it = pickle.loads(d) + # note that the type of the unpickled iterator + # is not necessarily the same as the original. It is + # merely an object supporting the iterator protocol, yielding + # the same objects as the original one. + # self.assertEqual(type(itorg), type(it)) + self.assertIsInstance(it, collections.abc.Iterator) + self.assertEqual(dict(it), data) + + it = pickle.loads(d) + drop = next(it) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + del data[drop[0]] + self.assertEqual(dict(it), data) + + def test_valuesiterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + # data.values() isn't picklable, only its iterator + it = iter(data.values()) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + self.assertEqual(list(it), list(data.values())) + + it = pickle.loads(d) + drop = next(it) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + values = list(it) + [drop] + self.assertEqual(sorted(values), sorted(list(data.values()))) + + def test_reverseiterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + it = reversed(data) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + self.assertEqual(list(it), list(reversed(data))) + + it = pickle.loads(d) + try: + drop = next(it) + except StopIteration: + continue + d = pickle.dumps(it, proto) + it = pickle.loads(d) + del data[drop] + self.assertEqual(list(it), list(reversed(data))) + + def test_reverseitemiterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + # dictviews aren't picklable, only their iterators + itorg = reversed(data.items()) + d = pickle.dumps(itorg, proto) + it = pickle.loads(d) + # note that the type of the unpickled iterator + # is not necessarily the same as the original. It is + # merely an object supporting the iterator protocol, yielding + # the same objects as the original one. + # self.assertEqual(type(itorg), type(it)) + self.assertIsInstance(it, collections.abc.Iterator) + self.assertEqual(dict(it), data) + + it = pickle.loads(d) + drop = next(it) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + del data[drop[0]] + self.assertEqual(dict(it), data) + + def test_reversevaluesiterator_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + data = {1:"a", 2:"b", 3:"c"} + # data.values() isn't picklable, only its iterator + it = reversed(data.values()) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + self.assertEqual(list(it), list(reversed(data.values()))) + + it = pickle.loads(d) + drop = next(it) + d = pickle.dumps(it, proto) + it = pickle.loads(d) + values = list(it) + [drop] + self.assertEqual(sorted(values), sorted(data.values())) + + def test_instance_dict_getattr_str_subclass(self): + class Foo: + def __init__(self, msg): + self.msg = msg + f = Foo('123') + class _str(str): + pass + self.assertEqual(f.msg, getattr(f, _str('msg'))) + self.assertEqual(f.msg, f.__dict__[_str('msg')]) + + def test_object_set_item_single_instance_non_str_key(self): + class Foo: pass + f = Foo() + f.__dict__[1] = 1 + f.a = 'a' + self.assertEqual(f.__dict__, {1:1, 'a':'a'}) + + def check_reentrant_insertion(self, mutate): + # This object will trigger mutation of the dict when replaced + # by another value. Note this relies on refcounting: the test + # won't achieve its purpose on fully-GCed Python implementations. + class Mutating: + def __del__(self): + mutate(d) + + d = {k: Mutating() for k in 'abcdefghijklmnopqr'} + for k in list(d): + d[k] = k + + def test_reentrant_insertion(self): + # Reentrant insertion shouldn't crash (see issue #22653) + def mutate(d): + d['b'] = 5 + self.check_reentrant_insertion(mutate) + + def mutate(d): + d.update(self.__dict__) + d.clear() + self.check_reentrant_insertion(mutate) + + def mutate(d): + while d: + d.popitem() + self.check_reentrant_insertion(mutate) + + def test_merge_and_mutate(self): + class X: + def __hash__(self): + return 0 + + def __eq__(self, o): + other.clear() + return False + + l = [(i,0) for i in range(1, 1337)] + other = dict(l) + other[X()] = 0 + d = {X(): 0, 1: 1} + self.assertRaises(RuntimeError, d.update, other) + + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, dict) + support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict) + support.check_free_after_iterating(self, lambda d: iter(d.values()), dict) + support.check_free_after_iterating(self, lambda d: iter(d.items()), dict) + + def test_equal_operator_modifying_operand(self): + # test fix for seg fault reported in bpo-27945 part 3. + class X(): + def __del__(self): + dict_b.clear() + + def __eq__(self, other): + dict_a.clear() + return True + + def __hash__(self): + return 13 + + dict_a = {X(): 0} + dict_b = {X(): X()} + self.assertTrue(dict_a == dict_b) + + # test fix for seg fault reported in bpo-38588 part 1. + class Y: + def __eq__(self, other): + dict_d.clear() + return True + + dict_c = {0: Y()} + dict_d = {0: set()} + self.assertTrue(dict_c == dict_d) + + def test_fromkeys_operator_modifying_dict_operand(self): + # test fix for seg fault reported in issue 27945 part 4a. + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False + + d = {} # this is required to exist so that d can be constructed! + d = {X(1): 1, X(2): 2} + try: + dict.fromkeys(d) # shouldn't crash + except RuntimeError: # implementation defined + pass + + def test_fromkeys_operator_modifying_set_operand(self): + # test fix for seg fault reported in issue 27945 part 4b. + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False + + d = {} # this is required to exist so that d can be constructed! + d = {X(1), X(2)} + try: + dict.fromkeys(d) # shouldn't crash + except RuntimeError: # implementation defined + pass + + def test_dictitems_contains_use_after_free(self): + class X: + def __eq__(self, other): + d.clear() + return NotImplemented + + d = {0: set()} + (0, X()) in d.items() + + def test_dict_contain_use_after_free(self): + # bpo-40489 + class S(str): + def __eq__(self, other): + d.clear() + return NotImplemented + + def __hash__(self): + return hash('test') + + d = {S(): 'value'} + self.assertFalse('test' in d) + + def test_init_use_after_free(self): + class X: + def __hash__(self): + pair[:] = [] + return 13 + + pair = [X(), 123] + dict([pair]) + + def test_oob_indexing_dictiter_iternextitem(self): + class X(int): + def __del__(self): + d.clear() + + d = {i: X(i) for i in range(8)} + + def iter_and_mutate(): + for result in d.items(): + if result[0] == 2: + d[2] = None # free d[2] --> X(2).__del__ was called + + self.assertRaises(RuntimeError, iter_and_mutate) + + def test_reversed(self): + d = {"a": 1, "b": 2, "foo": 0, "c": 3, "d": 4} + del d["foo"] + r = reversed(d) + self.assertEqual(list(r), list('dcba')) + self.assertRaises(StopIteration, next, r) + + def test_reverse_iterator_for_empty_dict(self): + # bpo-38525: reversed iterator should work properly + + # empty dict is directly used for reference count test + self.assertEqual(list(reversed({})), []) + self.assertEqual(list(reversed({}.items())), []) + self.assertEqual(list(reversed({}.values())), []) + self.assertEqual(list(reversed({}.keys())), []) + + # dict() and {} don't trigger the same code path + self.assertEqual(list(reversed(dict())), []) + self.assertEqual(list(reversed(dict().items())), []) + self.assertEqual(list(reversed(dict().values())), []) + self.assertEqual(list(reversed(dict().keys())), []) + + def test_reverse_iterator_for_shared_shared_dicts(self): + class A: + def __init__(self, x, y): + if x: self.x = x + if y: self.y = y + + self.assertEqual(list(reversed(A(1, 2).__dict__)), ['y', 'x']) + self.assertEqual(list(reversed(A(1, 0).__dict__)), ['x']) + self.assertEqual(list(reversed(A(0, 1).__dict__)), ['y']) + + def test_dict_copy_order(self): + # bpo-34320 + od = collections.OrderedDict([('a', 1), ('b', 2)]) + od.move_to_end('a') + expected = list(od.items()) + + copy = dict(od) + self.assertEqual(list(copy.items()), expected) + + # dict subclass doesn't override __iter__ + class CustomDict(dict): + pass + + pairs = [('a', 1), ('b', 2), ('c', 3)] + + d = CustomDict(pairs) + self.assertEqual(pairs, list(dict(d).items())) + + class CustomReversedDict(dict): + def keys(self): + return reversed(list(dict.keys(self))) + + __iter__ = keys + + def items(self): + return reversed(dict.items(self)) + + d = CustomReversedDict(pairs) + self.assertEqual(pairs[::-1], list(dict(d).items())) + + @support.cpython_only + def test_dict_items_result_gc(self): + # bpo-42536: dict.items's tuple-reuse speed trick breaks the GC's + # assumptions about what can be untracked. Make sure we re-track result + # tuples whenever we reuse them. + it = iter({None: []}.items()) + gc.collect() + # That GC collection probably untracked the recycled internal result + # tuple, which is initialized to (None, None). Make sure it's re-tracked + # when it's mutated and returned from __next__: + self.assertTrue(gc.is_tracked(next(it))) + + @support.cpython_only + def test_dict_items_result_gc_reversed(self): + # Same as test_dict_items_result_gc above, but reversed. + it = reversed({None: []}.items()) + gc.collect() + self.assertTrue(gc.is_tracked(next(it))) + + def test_store_evilattr(self): + class EvilAttr: + def __init__(self, d): + self.d = d + + def __del__(self): + if 'attr' in self.d: + del self.d['attr'] + gc.collect() + + class Obj: + pass + + obj = Obj() + obj.__dict__ = {} + for _ in range(10): + obj.attr = EvilAttr(obj.__dict__) + + def test_str_nonstr(self): + # cpython uses a different lookup function if the dict only contains + # `str` keys. Make sure the unoptimized path is used when a non-`str` + # key appears. + + class StrSub(str): + pass + + eq_count = 0 + # This class compares equal to the string 'key3' + class Key3: + def __hash__(self): + return hash('key3') + + def __eq__(self, other): + nonlocal eq_count + if isinstance(other, Key3) or isinstance(other, str) and other == 'key3': + eq_count += 1 + return True + return False + + key3_1 = StrSub('key3') + key3_2 = Key3() + key3_3 = Key3() + + dicts = [] + + # Create dicts of the form `{'key1': 42, 'key2': 43, key3: 44}` in a + # bunch of different ways. In all cases, `key3` is not of type `str`. + # `key3_1` is a `str` subclass and `key3_2` is a completely unrelated + # type. + for key3 in (key3_1, key3_2): + # A literal + dicts.append({'key1': 42, 'key2': 43, key3: 44}) + + # key3 inserted via `dict.__setitem__` + d = {'key1': 42, 'key2': 43} + d[key3] = 44 + dicts.append(d) + + # key3 inserted via `dict.setdefault` + d = {'key1': 42, 'key2': 43} + self.assertEqual(d.setdefault(key3, 44), 44) + dicts.append(d) + + # key3 inserted via `dict.update` + d = {'key1': 42, 'key2': 43} + d.update({key3: 44}) + dicts.append(d) + + # key3 inserted via `dict.__ior__` + d = {'key1': 42, 'key2': 43} + d |= {key3: 44} + dicts.append(d) + + # `dict(iterable)` + def make_pairs(): + yield ('key1', 42) + yield ('key2', 43) + yield (key3, 44) + d = dict(make_pairs()) + dicts.append(d) + + # `dict.copy` + d = d.copy() + dicts.append(d) + + # dict comprehension + d = {key: 42 + i for i,key in enumerate(['key1', 'key2', key3])} + dicts.append(d) + + for d in dicts: + with self.subTest(d=d): + self.assertEqual(d.get('key1'), 42) + + # Try to make an object that is of type `str` and is equal to + # `'key1'`, but (at least on cpython) is a different object. + noninterned_key1 = 'ke' + noninterned_key1 += 'y1' + if support.check_impl_detail(cpython=True): + # suppress a SyntaxWarning + interned_key1 = 'key1' + self.assertFalse(noninterned_key1 is interned_key1) + self.assertEqual(d.get(noninterned_key1), 42) + + self.assertEqual(d.get('key3'), 44) + self.assertEqual(d.get(key3_1), 44) + self.assertEqual(d.get(key3_2), 44) + + # `key3_3` itself is definitely not a dict key, so make sure + # that `__eq__` gets called. + # + # Note that this might not hold for `key3_1` and `key3_2` + # because they might be the same object as one of the dict keys, + # in which case implementations are allowed to skip the call to + # `__eq__`. + eq_count = 0 + self.assertEqual(d.get(key3_3), 44) + self.assertGreaterEqual(eq_count, 1) + + +class CAPITest(__TestCase): + + # Test _PyDict_GetItem_KnownHash() + @support.cpython_only + def test_getitem_knownhash(self): + _testinternalcapi = import_helper.import_module('_testinternalcapi') + dict_getitem_knownhash = _testinternalcapi.dict_getitem_knownhash + + d = {'x': 1, 'y': 2, 'z': 3} + self.assertEqual(dict_getitem_knownhash(d, 'x', hash('x')), 1) + self.assertEqual(dict_getitem_knownhash(d, 'y', hash('y')), 2) + self.assertEqual(dict_getitem_knownhash(d, 'z', hash('z')), 3) + + # not a dict + self.assertRaises(SystemError, dict_getitem_knownhash, [], 1, hash(1)) + # key does not exist + self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1)) + + class Exc(Exception): pass + class BadEq: + def __eq__(self, other): + raise Exc + def __hash__(self): + return 7 + + k1, k2 = BadEq(), BadEq() + d = {k1: 1} + self.assertEqual(dict_getitem_knownhash(d, k1, hash(k1)), 1) + self.assertRaises(Exc, dict_getitem_knownhash, d, k2, hash(k2)) + + +from test import mapping_tests + +class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + type2test = dict + +class Dict(dict): + pass + +class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + type2test = Dict + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_exception_variations.diff b/test/dynamo/cpython/3_13/test_exception_variations.diff new file mode 100644 index 00000000000000..45424e087b5a1f --- /dev/null +++ b/test/dynamo/cpython/3_13/test_exception_variations.diff @@ -0,0 +1,349 @@ +diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py +index a83a41d2975..be432089e3a 100644 +--- a/test/dynamo/cpython/3_13/test_exception_variations.py ++++ b/test/dynamo/cpython/3_13/test_exception_variations.py +@@ -1,7 +1,59 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] + ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case + import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++ xfailIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= + +-class ExceptTestCases(unittest.TestCase): ++import unittest ++ ++class ExceptTestCases(__TestCase): + def test_try_except_else_finally(self): + hit_except = False + hit_else = False +@@ -294,282 +346,5 @@ class ExceptTestCases(unittest.TestCase): + self.assertTrue(hit_except) + + +-class ExceptStarTestCases(unittest.TestCase): +- def test_try_except_else_finally(self): +- hit_except = False +- hit_else = False +- hit_finally = False +- +- try: +- raise Exception('nyaa!') +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- finally: +- hit_finally = True +- +- self.assertTrue(hit_except) +- self.assertTrue(hit_finally) +- self.assertFalse(hit_else) +- +- def test_try_except_else_finally_no_exception(self): +- hit_except = False +- hit_else = False +- hit_finally = False +- +- try: +- pass +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- finally: +- hit_finally = True +- +- self.assertFalse(hit_except) +- self.assertTrue(hit_finally) +- self.assertTrue(hit_else) +- +- def test_try_except_finally(self): +- hit_except = False +- hit_finally = False +- +- try: +- raise Exception('yarr!') +- except* BaseException: +- hit_except = True +- finally: +- hit_finally = True +- +- self.assertTrue(hit_except) +- self.assertTrue(hit_finally) +- +- def test_try_except_finally_no_exception(self): +- hit_except = False +- hit_finally = False +- +- try: +- pass +- except* BaseException: +- hit_except = True +- finally: +- hit_finally = True +- +- self.assertFalse(hit_except) +- self.assertTrue(hit_finally) +- +- def test_try_except(self): +- hit_except = False +- +- try: +- raise Exception('ahoy!') +- except* BaseException: +- hit_except = True +- +- self.assertTrue(hit_except) +- +- def test_try_except_no_exception(self): +- hit_except = False +- +- try: +- pass +- except* BaseException: +- hit_except = True +- +- self.assertFalse(hit_except) +- +- def test_try_except_else(self): +- hit_except = False +- hit_else = False +- +- try: +- raise Exception('foo!') +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- +- self.assertFalse(hit_else) +- self.assertTrue(hit_except) +- +- def test_try_except_else_no_exception(self): +- hit_except = False +- hit_else = False +- +- try: +- pass +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- +- self.assertFalse(hit_except) +- self.assertTrue(hit_else) +- +- def test_try_finally_no_exception(self): +- hit_finally = False +- +- try: +- pass +- finally: +- hit_finally = True +- +- self.assertTrue(hit_finally) +- +- def test_nested(self): +- hit_finally = False +- hit_inner_except = False +- hit_inner_finally = False +- +- try: +- try: +- raise Exception('inner exception') +- except* BaseException: +- hit_inner_except = True +- finally: +- hit_inner_finally = True +- finally: +- hit_finally = True +- +- self.assertTrue(hit_inner_except) +- self.assertTrue(hit_inner_finally) +- self.assertTrue(hit_finally) +- +- def test_nested_else(self): +- hit_else = False +- hit_finally = False +- hit_except = False +- hit_inner_except = False +- hit_inner_else = False +- +- try: +- try: +- pass +- except* BaseException: +- hit_inner_except = True +- else: +- hit_inner_else = True +- +- raise Exception('outer exception') +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- finally: +- hit_finally = True +- +- self.assertFalse(hit_inner_except) +- self.assertTrue(hit_inner_else) +- self.assertFalse(hit_else) +- self.assertTrue(hit_finally) +- self.assertTrue(hit_except) +- +- def test_nested_mixed1(self): +- hit_except = False +- hit_finally = False +- hit_inner_except = False +- hit_inner_finally = False +- +- try: +- try: +- raise Exception('inner exception') +- except* BaseException: +- hit_inner_except = True +- finally: +- hit_inner_finally = True +- except: +- hit_except = True +- finally: +- hit_finally = True +- +- self.assertTrue(hit_inner_except) +- self.assertTrue(hit_inner_finally) +- self.assertFalse(hit_except) +- self.assertTrue(hit_finally) +- +- def test_nested_mixed2(self): +- hit_except = False +- hit_finally = False +- hit_inner_except = False +- hit_inner_finally = False +- +- try: +- try: +- raise Exception('inner exception') +- except: +- hit_inner_except = True +- finally: +- hit_inner_finally = True +- except* BaseException: +- hit_except = True +- finally: +- hit_finally = True +- +- self.assertTrue(hit_inner_except) +- self.assertTrue(hit_inner_finally) +- self.assertFalse(hit_except) +- self.assertTrue(hit_finally) +- +- +- def test_nested_else_mixed1(self): +- hit_else = False +- hit_finally = False +- hit_except = False +- hit_inner_except = False +- hit_inner_else = False +- +- try: +- try: +- pass +- except* BaseException: +- hit_inner_except = True +- else: +- hit_inner_else = True +- +- raise Exception('outer exception') +- except: +- hit_except = True +- else: +- hit_else = True +- finally: +- hit_finally = True +- +- self.assertFalse(hit_inner_except) +- self.assertTrue(hit_inner_else) +- self.assertFalse(hit_else) +- self.assertTrue(hit_finally) +- self.assertTrue(hit_except) +- +- def test_nested_else_mixed2(self): +- hit_else = False +- hit_finally = False +- hit_except = False +- hit_inner_except = False +- hit_inner_else = False +- +- try: +- try: +- pass +- except: +- hit_inner_except = True +- else: +- hit_inner_else = True +- +- raise Exception('outer exception') +- except* BaseException: +- hit_except = True +- else: +- hit_else = True +- finally: +- hit_finally = True +- +- self.assertFalse(hit_inner_except) +- self.assertTrue(hit_inner_else) +- self.assertFalse(hit_else) +- self.assertTrue(hit_finally) +- self.assertTrue(hit_except) +- +- + if __name__ == '__main__': +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py new file mode 100644 index 00000000000000..be432089e3a33d --- /dev/null +++ b/test/dynamo/cpython/3_13/test_exception_variations.py @@ -0,0 +1,350 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, + xfailIfTorchDynamo, +) + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import unittest + +class ExceptTestCases(__TestCase): + def test_try_except_else_finally(self): + hit_except = False + hit_else = False + hit_finally = False + + try: + raise Exception('nyaa!') + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertTrue(hit_except) + self.assertTrue(hit_finally) + self.assertFalse(hit_else) + + def test_try_except_else_finally_no_exception(self): + hit_except = False + hit_else = False + hit_finally = False + + try: + pass + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertFalse(hit_except) + self.assertTrue(hit_finally) + self.assertTrue(hit_else) + + def test_try_except_finally(self): + hit_except = False + hit_finally = False + + try: + raise Exception('yarr!') + except: + hit_except = True + finally: + hit_finally = True + + self.assertTrue(hit_except) + self.assertTrue(hit_finally) + + def test_try_except_finally_no_exception(self): + hit_except = False + hit_finally = False + + try: + pass + except: + hit_except = True + finally: + hit_finally = True + + self.assertFalse(hit_except) + self.assertTrue(hit_finally) + + def test_try_except(self): + hit_except = False + + try: + raise Exception('ahoy!') + except: + hit_except = True + + self.assertTrue(hit_except) + + def test_try_except_no_exception(self): + hit_except = False + + try: + pass + except: + hit_except = True + + self.assertFalse(hit_except) + + def test_try_except_else(self): + hit_except = False + hit_else = False + + try: + raise Exception('foo!') + except: + hit_except = True + else: + hit_else = True + + self.assertFalse(hit_else) + self.assertTrue(hit_except) + + def test_try_except_else_no_exception(self): + hit_except = False + hit_else = False + + try: + pass + except: + hit_except = True + else: + hit_else = True + + self.assertFalse(hit_except) + self.assertTrue(hit_else) + + def test_try_finally_no_exception(self): + hit_finally = False + + try: + pass + finally: + hit_finally = True + + self.assertTrue(hit_finally) + + def test_nested(self): + hit_finally = False + hit_inner_except = False + hit_inner_finally = False + + try: + try: + raise Exception('inner exception') + except: + hit_inner_except = True + finally: + hit_inner_finally = True + finally: + hit_finally = True + + self.assertTrue(hit_inner_except) + self.assertTrue(hit_inner_finally) + self.assertTrue(hit_finally) + + def test_nested_else(self): + hit_else = False + hit_finally = False + hit_except = False + hit_inner_except = False + hit_inner_else = False + + try: + try: + pass + except: + hit_inner_except = True + else: + hit_inner_else = True + + raise Exception('outer exception') + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertFalse(hit_inner_except) + self.assertTrue(hit_inner_else) + self.assertFalse(hit_else) + self.assertTrue(hit_finally) + self.assertTrue(hit_except) + + def test_nested_exception_in_except(self): + hit_else = False + hit_finally = False + hit_except = False + hit_inner_except = False + hit_inner_else = False + + try: + try: + raise Exception('inner exception') + except: + hit_inner_except = True + raise Exception('outer exception') + else: + hit_inner_else = True + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertTrue(hit_inner_except) + self.assertFalse(hit_inner_else) + self.assertFalse(hit_else) + self.assertTrue(hit_finally) + self.assertTrue(hit_except) + + def test_nested_exception_in_else(self): + hit_else = False + hit_finally = False + hit_except = False + hit_inner_except = False + hit_inner_else = False + + try: + try: + pass + except: + hit_inner_except = True + else: + hit_inner_else = True + raise Exception('outer exception') + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertFalse(hit_inner_except) + self.assertTrue(hit_inner_else) + self.assertFalse(hit_else) + self.assertTrue(hit_finally) + self.assertTrue(hit_except) + + def test_nested_exception_in_finally_no_exception(self): + hit_else = False + hit_finally = False + hit_except = False + hit_inner_except = False + hit_inner_else = False + hit_inner_finally = False + + try: + try: + pass + except: + hit_inner_except = True + else: + hit_inner_else = True + finally: + hit_inner_finally = True + raise Exception('outer exception') + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + self.assertFalse(hit_inner_except) + self.assertTrue(hit_inner_else) + self.assertTrue(hit_inner_finally) + self.assertFalse(hit_else) + self.assertTrue(hit_finally) + self.assertTrue(hit_except) + + def test_nested_exception_in_finally_with_exception(self): + hit_else = False + hit_finally = False + hit_except = False + hit_inner_except = False + hit_inner_else = False + hit_inner_finally = False + + try: + try: + raise Exception('inner exception') + except: + hit_inner_except = True + else: + hit_inner_else = True + finally: + hit_inner_finally = True + raise Exception('outer exception') + except: + hit_except = True + else: + hit_else = True + finally: + hit_finally = True + + + self.assertTrue(hit_inner_except) + self.assertFalse(hit_inner_else) + self.assertTrue(hit_inner_finally) + self.assertFalse(hit_else) + self.assertTrue(hit_finally) + self.assertTrue(hit_except) + + +if __name__ == '__main__': + run_tests() diff --git a/test/dynamo/cpython/3_13/test_exceptions.diff b/test/dynamo/cpython/3_13/test_exceptions.diff new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py new file mode 100644 index 00000000000000..e6a9a2676bc000 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_exceptions.py @@ -0,0 +1,2587 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, + xfailIfTorchDynamo, +) + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# Python test set -- part 5, built-in exceptions + +import copy +import os +import sys +import unittest +import pickle +import weakref +import errno +from codecs import BOM_UTF8 +from itertools import product +from textwrap import dedent + +from test.support import (captured_stderr, check_impl_detail, + cpython_only, gc_collect, + no_tracing, script_helper, + SuppressCrashReport, + force_not_colorized) +from test.support.import_helper import import_module +from test.support.os_helper import TESTFN, unlink +from test.support.warnings_helper import check_warnings +from test import support + +try: + import _testcapi + from _testcapi import INT_MAX +except ImportError: + _testcapi = None + INT_MAX = 2**31 - 1 + + +class NaiveException(Exception): + def __init__(self, x): + self.x = x + +class SlottedNaiveException(Exception): + __slots__ = ('x',) + def __init__(self, x): + self.x = x + +class BrokenStrException(Exception): + def __str__(self): + raise Exception("str() is broken") + +# XXX This is not really enough, each *operation* should be tested! + + +class ExceptionTests(__TestCase): + + def raise_catch(self, exc, excname): + with self.subTest(exc=exc, excname=excname): + try: + raise exc("spam") + except exc as err: + buf1 = str(err) + try: + raise exc("spam") + except exc as err: + buf2 = str(err) + self.assertEqual(buf1, buf2) + self.assertEqual(exc.__name__, excname) + + def testRaising(self): + self.raise_catch(AttributeError, "AttributeError") + self.assertRaises(AttributeError, getattr, sys, "undefined_attribute") + + self.raise_catch(EOFError, "EOFError") + fp = open(TESTFN, 'w', encoding="utf-8") + fp.close() + fp = open(TESTFN, 'r', encoding="utf-8") + savestdin = sys.stdin + try: + try: + import marshal + marshal.loads(b'') + except EOFError: + pass + finally: + sys.stdin = savestdin + fp.close() + unlink(TESTFN) + + self.raise_catch(OSError, "OSError") + self.assertRaises(OSError, open, 'this file does not exist', 'r') + + self.raise_catch(ImportError, "ImportError") + self.assertRaises(ImportError, __import__, "undefined_module") + + self.raise_catch(IndexError, "IndexError") + x = [] + self.assertRaises(IndexError, x.__getitem__, 10) + + self.raise_catch(KeyError, "KeyError") + x = {} + self.assertRaises(KeyError, x.__getitem__, 'key') + + self.raise_catch(KeyboardInterrupt, "KeyboardInterrupt") + + self.raise_catch(MemoryError, "MemoryError") + + self.raise_catch(NameError, "NameError") + try: x = undefined_variable + except NameError: pass + + self.raise_catch(OverflowError, "OverflowError") + x = 1 + for dummy in range(128): + x += x # this simply shouldn't blow up + + self.raise_catch(RuntimeError, "RuntimeError") + self.raise_catch(RecursionError, "RecursionError") + + self.raise_catch(SyntaxError, "SyntaxError") + try: exec('/\n') + except SyntaxError: pass + + self.raise_catch(IndentationError, "IndentationError") + + self.raise_catch(TabError, "TabError") + try: compile("try:\n\t1/0\n \t1/0\nfinally:\n pass\n", + '', 'exec') + except TabError: pass + else: self.fail("TabError not raised") + + self.raise_catch(SystemError, "SystemError") + + self.raise_catch(SystemExit, "SystemExit") + self.assertRaises(SystemExit, sys.exit, 0) + + self.raise_catch(TypeError, "TypeError") + try: [] + () + except TypeError: pass + + self.raise_catch(ValueError, "ValueError") + self.assertRaises(ValueError, chr, 17<<16) + + self.raise_catch(ZeroDivisionError, "ZeroDivisionError") + try: x = 1/0 + except ZeroDivisionError: pass + + self.raise_catch(Exception, "Exception") + try: x = 1/0 + except Exception as e: pass + + self.raise_catch(StopAsyncIteration, "StopAsyncIteration") + + def testSyntaxErrorMessage(self): + # make sure the right exception message is raised for each of + # these code fragments + + def ckmsg(src, msg): + with self.subTest(src=src, msg=msg): + try: + compile(src, '', 'exec') + except SyntaxError as e: + if e.msg != msg: + self.fail("expected %s, got %s" % (msg, e.msg)) + else: + self.fail("failed to get expected SyntaxError") + + s = '''if 1: + try: + continue + except: + pass''' + + ckmsg(s, "'continue' not properly in loop") + ckmsg("continue\n", "'continue' not properly in loop") + ckmsg("f'{6 0}'", "invalid syntax. Perhaps you forgot a comma?") + + def testSyntaxErrorMissingParens(self): + def ckmsg(src, msg, exception=SyntaxError): + try: + compile(src, '', 'exec') + except exception as e: + if e.msg != msg: + self.fail("expected %s, got %s" % (msg, e.msg)) + else: + self.fail("failed to get expected SyntaxError") + + s = '''print "old style"''' + ckmsg(s, "Missing parentheses in call to 'print'. Did you mean print(...)?") + + s = '''print "old style",''' + ckmsg(s, "Missing parentheses in call to 'print'. Did you mean print(...)?") + + s = 'print f(a+b,c)' + ckmsg(s, "Missing parentheses in call to 'print'. Did you mean print(...)?") + + s = '''exec "old style"''' + ckmsg(s, "Missing parentheses in call to 'exec'. Did you mean exec(...)?") + + s = 'exec f(a+b,c)' + ckmsg(s, "Missing parentheses in call to 'exec'. Did you mean exec(...)?") + + # Check that we don't incorrectly identify '(...)' as an expression to the right + # of 'print' + + s = 'print (a+b,c) $ 42' + ckmsg(s, "invalid syntax") + + s = 'exec (a+b,c) $ 42' + ckmsg(s, "invalid syntax") + + # should not apply to subclasses, see issue #31161 + s = '''if True:\nprint "No indent"''' + ckmsg(s, "expected an indented block after 'if' statement on line 1", IndentationError) + + s = '''if True:\n print()\n\texec "mixed tabs and spaces"''' + ckmsg(s, "inconsistent use of tabs and spaces in indentation", TabError) + + def check(self, src, lineno, offset, end_lineno=None, end_offset=None, encoding='utf-8'): + with self.subTest(source=src, lineno=lineno, offset=offset): + with self.assertRaises(SyntaxError) as cm: + compile(src, '', 'exec') + self.assertEqual(cm.exception.lineno, lineno) + self.assertEqual(cm.exception.offset, offset) + if end_lineno is not None: + self.assertEqual(cm.exception.end_lineno, end_lineno) + if end_offset is not None: + self.assertEqual(cm.exception.end_offset, end_offset) + + if cm.exception.text is not None: + if not isinstance(src, str): + src = src.decode(encoding, 'replace') + line = src.split('\n')[lineno-1] + self.assertIn(line, cm.exception.text) + + def test_error_offset_continuation_characters(self): + check = self.check + check('"\\\n"(1 for c in I,\\\n\\', 2, 2) + + def testSyntaxErrorOffset(self): + check = self.check + check('def fact(x):\n\treturn x!\n', 2, 10) + check('1 +\n', 1, 4) + check('def spam():\n print(1)\n print(2)', 3, 10) + check('Python = "Python" +', 1, 20) + check('Python = "\u1e54\xfd\u0163\u0125\xf2\xf1" +', 1, 20) + check(b'# -*- coding: cp1251 -*-\nPython = "\xcf\xb3\xf2\xee\xed" +', + 2, 19, encoding='cp1251') + check(b'Python = "\xcf\xb3\xf2\xee\xed" +', 1, 10) + check('x = "a', 1, 5) + check('lambda x: x = 2', 1, 1) + check('f{a + b + c}', 1, 2) + check('[file for str(file) in []\n]', 1, 11) + check('a = « hello » « world »', 1, 5) + check('[\nfile\nfor str(file)\nin\n[]\n]', 3, 5) + check('[file for\n str(file) in []]', 2, 2) + check("ages = {'Alice'=22, 'Bob'=23}", 1, 9) + check('match ...:\n case {**rest, "key": value}:\n ...', 2, 19) + check("[a b c d e f]", 1, 2) + check("for x yfff:", 1, 7) + check("f(a for a in b, c)", 1, 3, 1, 15) + check("f(a for a in b if a, c)", 1, 3, 1, 20) + check("f(a, b for b in c)", 1, 6, 1, 18) + check("f(a, b for b in c, d)", 1, 6, 1, 18) + + # Errors thrown by compile.c + check('class foo:return 1', 1, 11) + check('def f():\n continue', 2, 3) + check('def f():\n break', 2, 3) + check('try:\n pass\nexcept:\n pass\nexcept ValueError:\n pass', 3, 1) + check('try:\n pass\nexcept*:\n pass', 3, 8) + check('try:\n pass\nexcept*:\n pass\nexcept* ValueError:\n pass', 3, 8) + + # Errors thrown by the tokenizer + check('(0x+1)', 1, 3) + check('x = 0xI', 1, 6) + check('0010 + 2', 1, 1) + check('x = 32e-+4', 1, 8) + check('x = 0o9', 1, 7) + check('\u03b1 = 0xI', 1, 6) + check(b'\xce\xb1 = 0xI', 1, 6) + check(b'# -*- coding: iso8859-7 -*-\n\xe1 = 0xI', 2, 6, + encoding='iso8859-7') + check(b"""if 1: + def foo(): + ''' + + def bar(): + pass + + def baz(): + '''quux''' + """, 9, 24) + check("pass\npass\npass\n(1+)\npass\npass\npass", 4, 4) + check("(1+)", 1, 4) + check("[interesting\nfoo()\n", 1, 1) + check(b"\xef\xbb\xbf#coding: utf8\nprint('\xe6\x88\x91')\n", 0, -1) + check("""f''' + { + (123_a) + }'''""", 3, 17) + check("""f''' + { + f\"\"\" + { + (123_a) + } + \"\"\" + }'''""", 5, 17) + check('''f""" + + + { + 6 + 0="""''', 5, 13) + check('b"fooжжж"'.encode(), 1, 1, 1, 10) + + # Errors thrown by symtable.c + check('x = [(yield i) for i in range(3)]', 1, 7) + check('def f():\n from _ import *', 2, 17) + check('def f(x, x):\n pass', 1, 10) + check('{i for i in range(5) if (j := 0) for j in range(5)}', 1, 38) + check('def f(x):\n nonlocal x', 2, 3) + check('def f(x):\n x = 1\n global x', 3, 3) + check('nonlocal x', 1, 1) + check('def f():\n global x\n nonlocal x', 2, 3) + + # Errors thrown by future.c + check('from __future__ import doesnt_exist', 1, 24) + check('from __future__ import braces', 1, 24) + check('x=1\nfrom __future__ import division', 2, 1) + check('foo(1=2)', 1, 5) + check('def f():\n x, y: int', 2, 3) + check('[*x for x in xs]', 1, 2) + check('foo(x for x in range(10), 100)', 1, 5) + check('for 1 in []: pass', 1, 5) + check('(yield i) = 2', 1, 2) + check('def f(*):\n pass', 1, 7) + + @unittest.skipIf(INT_MAX >= sys.maxsize, "Downcasting to int is safe for col_offset") + @support.requires_resource('cpu') + @support.bigmemtest(INT_MAX, memuse=2, dry_run=False) + def testMemoryErrorBigSource(self, size): + src = b"if True:\n%*s" % (size, b"pass") + with self.assertRaisesRegex(OverflowError, "Parser column offset overflow"): + compile(src, '', 'exec') + + @cpython_only + def testSettingException(self): + # test that setting an exception at the C level works even if the + # exception object can't be constructed. + + class BadException(Exception): + def __init__(self_): + raise RuntimeError("can't instantiate BadException") + + class InvalidException: + pass + + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_capi1(): + try: + _testcapi.raise_exception(BadException, 1) + except TypeError as err: + co = err.__traceback__.tb_frame.f_code + self.assertEqual(co.co_name, "test_capi1") + self.assertTrue(co.co_filename.endswith('test_exceptions.py')) + else: + self.fail("Expected exception") + + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_capi2(): + try: + _testcapi.raise_exception(BadException, 0) + except RuntimeError as err: + tb = err.__traceback__.tb_next + co = tb.tb_frame.f_code + self.assertEqual(co.co_name, "__init__") + self.assertTrue(co.co_filename.endswith('test_exceptions.py')) + co2 = tb.tb_frame.f_back.f_code + self.assertEqual(co2.co_name, "test_capi2") + else: + self.fail("Expected exception") + + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_capi3(): + self.assertRaises(SystemError, _testcapi.raise_exception, + InvalidException, 1) + + test_capi1() + test_capi2() + test_capi3() + + def test_WindowsError(self): + try: + WindowsError + except NameError: + pass + else: + self.assertIs(WindowsError, OSError) + self.assertEqual(str(OSError(1001)), "1001") + self.assertEqual(str(OSError(1001, "message")), + "[Errno 1001] message") + # POSIX errno (9 aka EBADF) is untranslated + w = OSError(9, 'foo', 'bar') + self.assertEqual(w.errno, 9) + self.assertEqual(w.winerror, None) + self.assertEqual(str(w), "[Errno 9] foo: 'bar'") + # ERROR_PATH_NOT_FOUND (win error 3) becomes ENOENT (2) + w = OSError(0, 'foo', 'bar', 3) + self.assertEqual(w.errno, 2) + self.assertEqual(w.winerror, 3) + self.assertEqual(w.strerror, 'foo') + self.assertEqual(w.filename, 'bar') + self.assertEqual(w.filename2, None) + self.assertEqual(str(w), "[WinError 3] foo: 'bar'") + # Unknown win error becomes EINVAL (22) + w = OSError(0, 'foo', None, 1001) + self.assertEqual(w.errno, 22) + self.assertEqual(w.winerror, 1001) + self.assertEqual(w.strerror, 'foo') + self.assertEqual(w.filename, None) + self.assertEqual(w.filename2, None) + self.assertEqual(str(w), "[WinError 1001] foo") + # Non-numeric "errno" + w = OSError('bar', 'foo') + self.assertEqual(w.errno, 'bar') + self.assertEqual(w.winerror, None) + self.assertEqual(w.strerror, 'foo') + self.assertEqual(w.filename, None) + self.assertEqual(w.filename2, None) + + @unittest.skipUnless(sys.platform == 'win32', + 'test specific to Windows') + def test_windows_message(self): + """Should fill in unknown error code in Windows error message""" + ctypes = import_module('ctypes') + # this error code has no message, Python formats it as hexadecimal + code = 3765269347 + with self.assertRaisesRegex(OSError, 'Windows Error 0x%x' % code): + ctypes.pythonapi.PyErr_SetFromWindowsErr(code) + + def testAttributes(self): + # test that exception attributes are happy + + exceptionList = [ + (BaseException, (), {}, {'args' : ()}), + (BaseException, (1, ), {}, {'args' : (1,)}), + (BaseException, ('foo',), {}, + {'args' : ('foo',)}), + (BaseException, ('foo', 1), {}, + {'args' : ('foo', 1)}), + (SystemExit, ('foo',), {}, + {'args' : ('foo',), 'code' : 'foo'}), + (OSError, ('foo',), {}, + {'args' : ('foo',), 'filename' : None, 'filename2' : None, + 'errno' : None, 'strerror' : None}), + (OSError, ('foo', 'bar'), {}, + {'args' : ('foo', 'bar'), + 'filename' : None, 'filename2' : None, + 'errno' : 'foo', 'strerror' : 'bar'}), + (OSError, ('foo', 'bar', 'baz'), {}, + {'args' : ('foo', 'bar'), + 'filename' : 'baz', 'filename2' : None, + 'errno' : 'foo', 'strerror' : 'bar'}), + (OSError, ('foo', 'bar', 'baz', None, 'quux'), {}, + {'args' : ('foo', 'bar'), 'filename' : 'baz', 'filename2': 'quux'}), + (OSError, ('errnoStr', 'strErrorStr', 'filenameStr'), {}, + {'args' : ('errnoStr', 'strErrorStr'), + 'strerror' : 'strErrorStr', 'errno' : 'errnoStr', + 'filename' : 'filenameStr'}), + (OSError, (1, 'strErrorStr', 'filenameStr'), {}, + {'args' : (1, 'strErrorStr'), 'errno' : 1, + 'strerror' : 'strErrorStr', + 'filename' : 'filenameStr', 'filename2' : None}), + (SyntaxError, (), {}, {'msg' : None, 'text' : None, + 'filename' : None, 'lineno' : None, 'offset' : None, + 'end_offset': None, 'print_file_and_line' : None}), + (SyntaxError, ('msgStr',), {}, + {'args' : ('msgStr',), 'text' : None, + 'print_file_and_line' : None, 'msg' : 'msgStr', + 'filename' : None, 'lineno' : None, 'offset' : None, + 'end_offset': None}), + (SyntaxError, ('msgStr', ('filenameStr', 'linenoStr', 'offsetStr', + 'textStr', 'endLinenoStr', 'endOffsetStr')), {}, + {'offset' : 'offsetStr', 'text' : 'textStr', + 'args' : ('msgStr', ('filenameStr', 'linenoStr', + 'offsetStr', 'textStr', + 'endLinenoStr', 'endOffsetStr')), + 'print_file_and_line' : None, 'msg' : 'msgStr', + 'filename' : 'filenameStr', 'lineno' : 'linenoStr', + 'end_lineno': 'endLinenoStr', 'end_offset': 'endOffsetStr'}), + (SyntaxError, ('msgStr', 'filenameStr', 'linenoStr', 'offsetStr', + 'textStr', 'endLinenoStr', 'endOffsetStr', + 'print_file_and_lineStr'), {}, + {'text' : None, + 'args' : ('msgStr', 'filenameStr', 'linenoStr', 'offsetStr', + 'textStr', 'endLinenoStr', 'endOffsetStr', + 'print_file_and_lineStr'), + 'print_file_and_line' : None, 'msg' : 'msgStr', + 'filename' : None, 'lineno' : None, 'offset' : None, + 'end_lineno': None, 'end_offset': None}), + (UnicodeError, (), {}, {'args' : (),}), + (UnicodeEncodeError, ('ascii', 'a', 0, 1, + 'ordinal not in range'), {}, + {'args' : ('ascii', 'a', 0, 1, + 'ordinal not in range'), + 'encoding' : 'ascii', 'object' : 'a', + 'start' : 0, 'reason' : 'ordinal not in range'}), + (UnicodeDecodeError, ('ascii', bytearray(b'\xff'), 0, 1, + 'ordinal not in range'), {}, + {'args' : ('ascii', bytearray(b'\xff'), 0, 1, + 'ordinal not in range'), + 'encoding' : 'ascii', 'object' : b'\xff', + 'start' : 0, 'reason' : 'ordinal not in range'}), + (UnicodeDecodeError, ('ascii', b'\xff', 0, 1, + 'ordinal not in range'), {}, + {'args' : ('ascii', b'\xff', 0, 1, + 'ordinal not in range'), + 'encoding' : 'ascii', 'object' : b'\xff', + 'start' : 0, 'reason' : 'ordinal not in range'}), + (UnicodeTranslateError, ("\u3042", 0, 1, "ouch"), {}, + {'args' : ('\u3042', 0, 1, 'ouch'), + 'object' : '\u3042', 'reason' : 'ouch', + 'start' : 0, 'end' : 1}), + (NaiveException, ('foo',), {}, + {'args': ('foo',), 'x': 'foo'}), + (SlottedNaiveException, ('foo',), {}, + {'args': ('foo',), 'x': 'foo'}), + (AttributeError, ('foo',), dict(name='name', obj='obj'), + dict(args=('foo',), name='name', obj='obj')), + ] + try: + # More tests are in test_WindowsError + exceptionList.append( + (WindowsError, (1, 'strErrorStr', 'filenameStr'), {}, + {'args' : (1, 'strErrorStr'), + 'strerror' : 'strErrorStr', 'winerror' : None, + 'errno' : 1, + 'filename' : 'filenameStr', 'filename2' : None}) + ) + except NameError: + pass + + for exc, args, kwargs, expected in exceptionList: + try: + e = exc(*args, **kwargs) + except: + print(f"\nexc={exc!r}, args={args!r}", file=sys.stderr) + # raise + else: + # Verify module name + if not type(e).__name__.endswith('NaiveException'): + self.assertEqual(type(e).__module__, 'builtins') + # Verify no ref leaks in Exc_str() + s = str(e) + for checkArgName in expected: + value = getattr(e, checkArgName) + self.assertEqual(repr(value), + repr(expected[checkArgName]), + '%r.%s == %r, expected %r' % ( + e, checkArgName, + value, expected[checkArgName])) + + # test for pickling support + for p in [pickle]: + for protocol in range(p.HIGHEST_PROTOCOL + 1): + s = p.dumps(e, protocol) + new = p.loads(s) + for checkArgName in expected: + got = repr(getattr(new, checkArgName)) + if exc == AttributeError and checkArgName == 'obj': + # See GH-103352, we're not pickling + # obj at this point. So verify it's None. + want = repr(None) + else: + want = repr(expected[checkArgName]) + self.assertEqual(got, want, + 'pickled "%r", attribute "%s' % + (e, checkArgName)) + + def test_setstate(self): + e = Exception(42) + e.blah = 53 + self.assertEqual(e.args, (42,)) + self.assertEqual(e.blah, 53) + self.assertRaises(AttributeError, getattr, e, 'a') + self.assertRaises(AttributeError, getattr, e, 'b') + e.__setstate__({'a': 1 , 'b': 2}) + self.assertEqual(e.args, (42,)) + self.assertEqual(e.blah, 53) + self.assertEqual(e.a, 1) + self.assertEqual(e.b, 2) + e.__setstate__({'a': 11, 'args': (1,2,3), 'blah': 35}) + self.assertEqual(e.args, (1,2,3)) + self.assertEqual(e.blah, 35) + self.assertEqual(e.a, 11) + self.assertEqual(e.b, 2) + + def test_invalid_setstate(self): + e = Exception(42) + with self.assertRaisesRegex(TypeError, "state is not a dictionary"): + e.__setstate__(42) + + def test_notes(self): + for e in [BaseException(1), Exception(2), ValueError(3)]: + with self.subTest(e=e): + self.assertFalse(hasattr(e, '__notes__')) + e.add_note("My Note") + self.assertEqual(e.__notes__, ["My Note"]) + + with self.assertRaises(TypeError): + e.add_note(42) + self.assertEqual(e.__notes__, ["My Note"]) + + e.add_note("Your Note") + self.assertEqual(e.__notes__, ["My Note", "Your Note"]) + + del e.__notes__ + self.assertFalse(hasattr(e, '__notes__')) + + e.add_note("Our Note") + self.assertEqual(e.__notes__, ["Our Note"]) + + e.__notes__ = 42 + self.assertEqual(e.__notes__, 42) + + with self.assertRaises(TypeError): + e.add_note("will not work") + self.assertEqual(e.__notes__, 42) + + def testWithTraceback(self): + try: + raise IndexError(4) + except Exception as e: + tb = e.__traceback__ + + e = BaseException().with_traceback(tb) + self.assertIsInstance(e, BaseException) + self.assertEqual(e.__traceback__, tb) + + e = IndexError(5).with_traceback(tb) + self.assertIsInstance(e, IndexError) + self.assertEqual(e.__traceback__, tb) + + class MyException(Exception): + pass + + e = MyException().with_traceback(tb) + self.assertIsInstance(e, MyException) + self.assertEqual(e.__traceback__, tb) + + def testInvalidTraceback(self): + try: + Exception().__traceback__ = 5 + except TypeError as e: + self.assertIn("__traceback__ must be a traceback", str(e)) + else: + self.fail("No exception raised") + + def test_invalid_setattr(self): + TE = TypeError + exc = Exception() + msg = "'int' object is not iterable" + self.assertRaisesRegex(TE, msg, setattr, exc, 'args', 1) + msg = "__traceback__ must be a traceback or None" + self.assertRaisesRegex(TE, msg, setattr, exc, '__traceback__', 1) + msg = "exception cause must be None or derive from BaseException" + self.assertRaisesRegex(TE, msg, setattr, exc, '__cause__', 1) + msg = "exception context must be None or derive from BaseException" + self.assertRaisesRegex(TE, msg, setattr, exc, '__context__', 1) + + def test_invalid_delattr(self): + TE = TypeError + try: + raise IndexError(4) + except Exception as e: + exc = e + + msg = "may not be deleted" + self.assertRaisesRegex(TE, msg, delattr, exc, 'args') + self.assertRaisesRegex(TE, msg, delattr, exc, '__traceback__') + self.assertRaisesRegex(TE, msg, delattr, exc, '__cause__') + self.assertRaisesRegex(TE, msg, delattr, exc, '__context__') + + def testNoneClearsTracebackAttr(self): + try: + raise IndexError(4) + except Exception as e: + tb = e.__traceback__ + + e = Exception() + e.__traceback__ = tb + e.__traceback__ = None + self.assertEqual(e.__traceback__, None) + + def testChainingAttrs(self): + e = Exception() + self.assertIsNone(e.__context__) + self.assertIsNone(e.__cause__) + + e = TypeError() + self.assertIsNone(e.__context__) + self.assertIsNone(e.__cause__) + + class MyException(OSError): + pass + + e = MyException() + self.assertIsNone(e.__context__) + self.assertIsNone(e.__cause__) + + def testChainingDescriptors(self): + try: + raise Exception() + except Exception as exc: + e = exc + + self.assertIsNone(e.__context__) + self.assertIsNone(e.__cause__) + self.assertFalse(e.__suppress_context__) + + e.__context__ = NameError() + e.__cause__ = None + self.assertIsInstance(e.__context__, NameError) + self.assertIsNone(e.__cause__) + self.assertTrue(e.__suppress_context__) + e.__suppress_context__ = False + self.assertFalse(e.__suppress_context__) + + def testKeywordArgs(self): + # test that builtin exception don't take keyword args, + # but user-defined subclasses can if they want + self.assertRaises(TypeError, BaseException, a=1) + + class DerivedException(BaseException): + def __init__(self, fancy_arg): + BaseException.__init__(self) + self.fancy_arg = fancy_arg + + x = DerivedException(fancy_arg=42) + self.assertEqual(x.fancy_arg, 42) + + @no_tracing + def testInfiniteRecursion(self): + def f(): + return f() + self.assertRaises(RecursionError, f) + + def g(): + try: + return g() + except ValueError: + return -1 + self.assertRaises(RecursionError, g) + + def test_str(self): + # Make sure both instances and classes have a str representation. + self.assertTrue(str(Exception)) + self.assertTrue(str(Exception('a'))) + self.assertTrue(str(Exception('a', 'b'))) + + def test_exception_cleanup_names(self): + # Make sure the local variable bound to the exception instance by + # an "except" statement is only visible inside the except block. + try: + raise Exception() + except Exception as e: + self.assertIsInstance(e, Exception) + self.assertNotIn('e', locals()) + with self.assertRaises(UnboundLocalError): + e + + def test_exception_cleanup_names2(self): + # Make sure the cleanup doesn't break if the variable is explicitly deleted. + try: + raise Exception() + except Exception as e: + self.assertIsInstance(e, Exception) + del e + self.assertNotIn('e', locals()) + with self.assertRaises(UnboundLocalError): + e + + def testExceptionCleanupState(self): + # Make sure exception state is cleaned up as soon as the except + # block is left. See #2507 + + class MyException(Exception): + def __init__(self, obj): + self.obj = obj + class MyObj: + pass + + def inner_raising_func(): + # Create some references in exception value and traceback + local_ref = obj + raise MyException(obj) + + # Qualified "except" with "as" + obj = MyObj() + wr = weakref.ref(obj) + try: + inner_raising_func() + except MyException as e: + pass + obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + self.assertIsNone(obj) + + # Qualified "except" without "as" + obj = MyObj() + wr = weakref.ref(obj) + try: + inner_raising_func() + except MyException: + pass + obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + self.assertIsNone(obj) + + # Bare "except" + obj = MyObj() + wr = weakref.ref(obj) + try: + inner_raising_func() + except: + pass + obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + self.assertIsNone(obj) + + # "except" with premature block leave + obj = MyObj() + wr = weakref.ref(obj) + for i in [0]: + try: + inner_raising_func() + except: + break + obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + self.assertIsNone(obj) + + # "except" block raising another exception + obj = MyObj() + wr = weakref.ref(obj) + try: + try: + inner_raising_func() + except: + raise KeyError + except KeyError as e: + # We want to test that the except block above got rid of + # the exception raised in inner_raising_func(), but it + # also ends up in the __context__ of the KeyError, so we + # must clear the latter manually for our test to succeed. + e.__context__ = None + obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + # guarantee no ref cycles on CPython (don't gc_collect) + if check_impl_detail(cpython=False): + gc_collect() + self.assertIsNone(obj) + + # Some complicated construct + obj = MyObj() + wr = weakref.ref(obj) + try: + inner_raising_func() + except MyException: + try: + try: + raise + finally: + raise + except MyException: + pass + obj = None + if check_impl_detail(cpython=False): + gc_collect() + obj = wr() + self.assertIsNone(obj) + + # Inside an exception-silencing "with" block + class Context: + def __enter__(self): + return self + def __exit__ (self, exc_type, exc_value, exc_tb): + return True + obj = MyObj() + wr = weakref.ref(obj) + with Context(): + inner_raising_func() + obj = None + if check_impl_detail(cpython=False): + gc_collect() + obj = wr() + self.assertIsNone(obj) + + def test_exception_target_in_nested_scope(self): + # issue 4617: This used to raise a SyntaxError + # "can not delete variable 'e' referenced in nested scope" + def print_error(): + e + try: + something + except Exception as e: + print_error() + # implicit "del e" here + + def test_generator_leaking(self): + # Test that generator exception state doesn't leak into the calling + # frame + def yield_raise(): + try: + raise KeyError("caught") + except KeyError: + yield sys.exception() + yield sys.exception() + yield sys.exception() + g = yield_raise() + self.assertIsInstance(next(g), KeyError) + self.assertIsNone(sys.exception()) + self.assertIsInstance(next(g), KeyError) + self.assertIsNone(sys.exception()) + self.assertIsNone(next(g)) + + # Same test, but inside an exception handler + try: + raise TypeError("foo") + except TypeError: + g = yield_raise() + self.assertIsInstance(next(g), KeyError) + self.assertIsInstance(sys.exception(), TypeError) + self.assertIsInstance(next(g), KeyError) + self.assertIsInstance(sys.exception(), TypeError) + self.assertIsInstance(next(g), TypeError) + del g + self.assertIsInstance(sys.exception(), TypeError) + + def test_generator_leaking2(self): + # See issue 12475. + def g(): + yield + try: + raise RuntimeError + except RuntimeError: + it = g() + next(it) + try: + next(it) + except StopIteration: + pass + self.assertIsNone(sys.exception()) + + def test_generator_leaking3(self): + # See issue #23353. When gen.throw() is called, the caller's + # exception state should be save and restored. + def g(): + try: + yield + except ZeroDivisionError: + yield sys.exception() + it = g() + next(it) + try: + 1/0 + except ZeroDivisionError as e: + self.assertIs(sys.exception(), e) + gen_exc = it.throw(e) + self.assertIs(sys.exception(), e) + self.assertIs(gen_exc, e) + self.assertIsNone(sys.exception()) + + def test_generator_leaking4(self): + # See issue #23353. When an exception is raised by a generator, + # the caller's exception state should still be restored. + def g(): + try: + 1/0 + except ZeroDivisionError: + yield sys.exception() + raise + it = g() + try: + raise TypeError + except TypeError: + # The caller's exception state (TypeError) is temporarily + # saved in the generator. + tp = type(next(it)) + self.assertIs(tp, ZeroDivisionError) + try: + next(it) + # We can't check it immediately, but while next() returns + # with an exception, it shouldn't have restored the old + # exception state (TypeError). + except ZeroDivisionError as e: + self.assertIs(sys.exception(), e) + # We used to find TypeError here. + self.assertIsNone(sys.exception()) + + def test_generator_doesnt_retain_old_exc(self): + def g(): + self.assertIsInstance(sys.exception(), RuntimeError) + yield + self.assertIsNone(sys.exception()) + it = g() + try: + raise RuntimeError + except RuntimeError: + next(it) + self.assertRaises(StopIteration, next, it) + + def test_generator_finalizing_and_sys_exception(self): + # See #7173 + def simple_gen(): + yield 1 + def run_gen(): + gen = simple_gen() + try: + raise RuntimeError + except RuntimeError: + return next(gen) + run_gen() + gc_collect() + self.assertIsNone(sys.exception()) + + def _check_generator_cleanup_exc_state(self, testfunc): + # Issue #12791: exception state is cleaned up as soon as a generator + # is closed (reference cycles are broken). + class MyException(Exception): + def __init__(self, obj): + self.obj = obj + class MyObj: + pass + + def raising_gen(): + try: + raise MyException(obj) + except MyException: + yield + + obj = MyObj() + wr = weakref.ref(obj) + g = raising_gen() + next(g) + testfunc(g) + g = obj = None + gc_collect() # For PyPy or other GCs. + obj = wr() + self.assertIsNone(obj) + + def test_generator_throw_cleanup_exc_state(self): + def do_throw(g): + try: + g.throw(RuntimeError()) + except RuntimeError: + pass + self._check_generator_cleanup_exc_state(do_throw) + + def test_generator_close_cleanup_exc_state(self): + def do_close(g): + g.close() + self._check_generator_cleanup_exc_state(do_close) + + def test_generator_del_cleanup_exc_state(self): + def do_del(g): + g = None + self._check_generator_cleanup_exc_state(do_del) + + def test_generator_next_cleanup_exc_state(self): + def do_next(g): + try: + next(g) + except StopIteration: + pass + else: + self.fail("should have raised StopIteration") + self._check_generator_cleanup_exc_state(do_next) + + def test_generator_send_cleanup_exc_state(self): + def do_send(g): + try: + g.send(None) + except StopIteration: + pass + else: + self.fail("should have raised StopIteration") + self._check_generator_cleanup_exc_state(do_send) + + def test_3114(self): + # Bug #3114: in its destructor, MyObject retrieves a pointer to + # obsolete and/or deallocated objects. + class MyObject: + def __del__(self): + nonlocal e + e = sys.exception() + e = () + try: + raise Exception(MyObject()) + except: + pass + gc_collect() # For PyPy or other GCs. + self.assertIsNone(e) + + def test_raise_does_not_create_context_chain_cycle(self): + class A(Exception): + pass + class B(Exception): + pass + class C(Exception): + pass + + # Create a context chain: + # C -> B -> A + # Then raise A in context of C. + try: + try: + raise A + except A as a_: + a = a_ + try: + raise B + except B as b_: + b = b_ + try: + raise C + except C as c_: + c = c_ + self.assertIsInstance(a, A) + self.assertIsInstance(b, B) + self.assertIsInstance(c, C) + self.assertIsNone(a.__context__) + self.assertIs(b.__context__, a) + self.assertIs(c.__context__, b) + raise a + except A as e: + exc = e + + # Expect A -> C -> B, without cycle + self.assertIs(exc, a) + self.assertIs(a.__context__, c) + self.assertIs(c.__context__, b) + self.assertIsNone(b.__context__) + + def test_no_hang_on_context_chain_cycle1(self): + # See issue 25782. Cycle in context chain. + + def cycle(): + try: + raise ValueError(1) + except ValueError as ex: + ex.__context__ = ex + raise TypeError(2) + + try: + cycle() + except Exception as e: + exc = e + + self.assertIsInstance(exc, TypeError) + self.assertIsInstance(exc.__context__, ValueError) + self.assertIs(exc.__context__.__context__, exc.__context__) + + def test_no_hang_on_context_chain_cycle2(self): + # See issue 25782. Cycle at head of context chain. + + class A(Exception): + pass + class B(Exception): + pass + class C(Exception): + pass + + # Context cycle: + # +-----------+ + # V | + # C --> B --> A + with self.assertRaises(C) as cm: + try: + raise A() + except A as _a: + a = _a + try: + raise B() + except B as _b: + b = _b + try: + raise C() + except C as _c: + c = _c + a.__context__ = c + raise c + + self.assertIs(cm.exception, c) + # Verify the expected context chain cycle + self.assertIs(c.__context__, b) + self.assertIs(b.__context__, a) + self.assertIs(a.__context__, c) + + def test_no_hang_on_context_chain_cycle3(self): + # See issue 25782. Longer context chain with cycle. + + class A(Exception): + pass + class B(Exception): + pass + class C(Exception): + pass + class D(Exception): + pass + class E(Exception): + pass + + # Context cycle: + # +-----------+ + # V | + # E --> D --> C --> B --> A + with self.assertRaises(E) as cm: + try: + raise A() + except A as _a: + a = _a + try: + raise B() + except B as _b: + b = _b + try: + raise C() + except C as _c: + c = _c + a.__context__ = c + try: + raise D() + except D as _d: + d = _d + e = E() + raise e + + self.assertIs(cm.exception, e) + # Verify the expected context chain cycle + self.assertIs(e.__context__, d) + self.assertIs(d.__context__, c) + self.assertIs(c.__context__, b) + self.assertIs(b.__context__, a) + self.assertIs(a.__context__, c) + + def test_context_of_exception_in_try_and_finally(self): + try: + try: + te = TypeError(1) + raise te + finally: + ve = ValueError(2) + raise ve + except Exception as e: + exc = e + + self.assertIs(exc, ve) + self.assertIs(exc.__context__, te) + + def test_context_of_exception_in_except_and_finally(self): + try: + try: + te = TypeError(1) + raise te + except: + ve = ValueError(2) + raise ve + finally: + oe = OSError(3) + raise oe + except Exception as e: + exc = e + + self.assertIs(exc, oe) + self.assertIs(exc.__context__, ve) + self.assertIs(exc.__context__.__context__, te) + + def test_context_of_exception_in_else_and_finally(self): + try: + try: + pass + except: + pass + else: + ve = ValueError(1) + raise ve + finally: + oe = OSError(2) + raise oe + except Exception as e: + exc = e + + self.assertIs(exc, oe) + self.assertIs(exc.__context__, ve) + + def test_unicode_change_attributes(self): + # See issue 7309. This was a crasher. + + u = UnicodeEncodeError('baz', 'xxxxx', 1, 5, 'foo') + self.assertEqual(str(u), "'baz' codec can't encode characters in position 1-4: foo") + u.end = 2 + self.assertEqual(str(u), "'baz' codec can't encode character '\\x78' in position 1: foo") + u.end = 5 + u.reason = 0x345345345345345345 + self.assertEqual(str(u), "'baz' codec can't encode characters in position 1-4: 965230951443685724997") + u.encoding = 4000 + self.assertEqual(str(u), "'4000' codec can't encode characters in position 1-4: 965230951443685724997") + u.start = 1000 + self.assertEqual(str(u), "'4000' codec can't encode characters in position 1000-4: 965230951443685724997") + + u = UnicodeDecodeError('baz', b'xxxxx', 1, 5, 'foo') + self.assertEqual(str(u), "'baz' codec can't decode bytes in position 1-4: foo") + u.end = 2 + self.assertEqual(str(u), "'baz' codec can't decode byte 0x78 in position 1: foo") + u.end = 5 + u.reason = 0x345345345345345345 + self.assertEqual(str(u), "'baz' codec can't decode bytes in position 1-4: 965230951443685724997") + u.encoding = 4000 + self.assertEqual(str(u), "'4000' codec can't decode bytes in position 1-4: 965230951443685724997") + u.start = 1000 + self.assertEqual(str(u), "'4000' codec can't decode bytes in position 1000-4: 965230951443685724997") + + u = UnicodeTranslateError('xxxx', 1, 5, 'foo') + self.assertEqual(str(u), "can't translate characters in position 1-4: foo") + u.end = 2 + self.assertEqual(str(u), "can't translate character '\\x78' in position 1: foo") + u.end = 5 + u.reason = 0x345345345345345345 + self.assertEqual(str(u), "can't translate characters in position 1-4: 965230951443685724997") + u.start = 1000 + self.assertEqual(str(u), "can't translate characters in position 1000-4: 965230951443685724997") + + def test_unicode_errors_no_object(self): + # See issue #21134. + klasses = UnicodeEncodeError, UnicodeDecodeError, UnicodeTranslateError + for klass in klasses: + self.assertEqual(str(klass.__new__(klass)), "") + + def test_unicode_error_str_does_not_crash(self): + # Test that str(UnicodeError(...)) does not crash. + # See https://github.com/python/cpython/issues/123378. + + for start, end, objlen in product( + range(-5, 5), + range(-5, 5), + range(7), + ): + obj = 'a' * objlen + with self.subTest('encode', objlen=objlen, start=start, end=end): + exc = UnicodeEncodeError('utf-8', obj, start, end, '') + self.assertIsInstance(str(exc), str) + + with self.subTest('translate', objlen=objlen, start=start, end=end): + exc = UnicodeTranslateError(obj, start, end, '') + self.assertIsInstance(str(exc), str) + + encoded = obj.encode() + with self.subTest('decode', objlen=objlen, start=start, end=end): + exc = UnicodeDecodeError('utf-8', encoded, start, end, '') + self.assertIsInstance(str(exc), str) + + @no_tracing + def test_badisinstance(self): + # Bug #2542: if issubclass(e, MyException) raises an exception, + # it should be ignored + class Meta(type): + def __subclasscheck__(cls, subclass): + raise ValueError() + class MyException(Exception, metaclass=Meta): + pass + + with captured_stderr() as stderr: + try: + raise KeyError() + except MyException as e: + self.fail("exception should not be a MyException") + except KeyError: + pass + except: + self.fail("Should have raised KeyError") + else: + self.fail("Should have raised KeyError") + + def g(): + try: + return g() + except RecursionError as e: + return e + exc = g() + self.assertIsInstance(exc, RecursionError, type(exc)) + self.assertIn("maximum recursion depth exceeded", str(exc)) + + + @cpython_only + @support.requires_resource('cpu') + def test_trashcan_recursion(self): + # See bpo-33930 + + def foo(): + o = object() + for x in range(1_000_000): + # Create a big chain of method objects that will trigger + # a deep chain of calls when they need to be destructed. + o = o.__dir__ + + foo() + support.gc_collect() + + @cpython_only + def test_recursion_normalizing_exception(self): + import_module("_testinternalcapi") + # Issue #22898. + # Test that a RecursionError is raised when tstate->recursion_depth is + # equal to recursion_limit in PyErr_NormalizeException() and check + # that a ResourceWarning is printed. + # Prior to #22898, the recursivity of PyErr_NormalizeException() was + # controlled by tstate->recursion_depth and a PyExc_RecursionErrorInst + # singleton was being used in that case, that held traceback data and + # locals indefinitely and would cause a segfault in _PyExc_Fini() upon + # finalization of these locals. + code = """if 1: + import sys + from _testinternalcapi import get_recursion_depth + from test import support + + class MyException(Exception): pass + + def setrecursionlimit(depth): + while 1: + try: + sys.setrecursionlimit(depth) + return depth + except RecursionError: + # sys.setrecursionlimit() raises a RecursionError if + # the new recursion limit is too low (issue #25274). + depth += 1 + + def recurse(cnt): + cnt -= 1 + if cnt: + recurse(cnt) + else: + generator.throw(MyException) + + def gen(): + f = open(%a, mode='rb', buffering=0) + yield + + generator = gen() + next(generator) + recursionlimit = sys.getrecursionlimit() + try: + recurse(support.exceeds_recursion_limit()) + finally: + sys.setrecursionlimit(recursionlimit) + print('Done.') + """ % __file__ + rc, out, err = script_helper.assert_python_failure("-Wd", "-c", code) + # Check that the program does not fail with SIGABRT. + self.assertEqual(rc, 1) + self.assertIn(b'RecursionError', err) + self.assertIn(b'ResourceWarning', err) + self.assertIn(b'Done.', out) + + @cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") + @force_not_colorized + def test_recursion_normalizing_infinite_exception(self): + # Issue #30697. Test that a RecursionError is raised when + # maximum recursion depth has been exceeded when creating + # an exception + code = """if 1: + import _testcapi + try: + raise _testcapi.RecursingInfinitelyError + finally: + print('Done.') + """ + rc, out, err = script_helper.assert_python_failure("-c", code) + self.assertEqual(rc, 1) + expected = b'RecursionError: maximum recursion depth exceeded' + self.assertTrue(expected in err, msg=f"{expected!r} not found in {err[:3_000]!r}... (truncated)") + self.assertIn(b'Done.', out) + + + def test_recursion_in_except_handler(self): + + def set_relative_recursion_limit(n): + depth = 1 + while True: + try: + sys.setrecursionlimit(depth) + except RecursionError: + depth += 1 + else: + break + sys.setrecursionlimit(depth+n) + + def recurse_in_except(): + try: + 1/0 + except: + recurse_in_except() + + def recurse_after_except(): + try: + 1/0 + except: + pass + recurse_after_except() + + def recurse_in_body_and_except(): + try: + recurse_in_body_and_except() + except: + recurse_in_body_and_except() + + recursionlimit = sys.getrecursionlimit() + try: + set_relative_recursion_limit(10) + for func in (recurse_in_except, recurse_after_except, recurse_in_body_and_except): + with self.subTest(func=func): + try: + func() + except RecursionError: + pass + else: + self.fail("Should have raised a RecursionError") + finally: + sys.setrecursionlimit(recursionlimit) + + + @cpython_only + # Python built with Py_TRACE_REFS fail with a fatal error in + # _PyRefchain_Trace() on memory allocation error. + @unittest.skipIf(support.Py_TRACE_REFS, 'cannot test Py_TRACE_REFS build') + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_recursion_normalizing_with_no_memory(self): + # Issue #30697. Test that in the abort that occurs when there is no + # memory left and the size of the Python frames stack is greater than + # the size of the list of preallocated MemoryError instances, the + # Fatal Python error message mentions MemoryError. + code = """if 1: + import _testcapi + class C(): pass + def recurse(cnt): + cnt -= 1 + if cnt: + recurse(cnt) + else: + _testcapi.set_nomemory(0) + C() + recurse(16) + """ + with SuppressCrashReport(): + rc, out, err = script_helper.assert_python_failure("-c", code) + self.assertIn(b'MemoryError', err) + + @cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_MemoryError(self): + # PyErr_NoMemory always raises the same exception instance. + # Check that the traceback is not doubled. + import traceback + from _testcapi import raise_memoryerror + def raiseMemError(): + try: + raise_memoryerror() + except MemoryError as e: + tb = e.__traceback__ + else: + self.fail("Should have raised a MemoryError") + return traceback.format_tb(tb) + + tb1 = raiseMemError() + tb2 = raiseMemError() + self.assertEqual(tb1, tb2) + + @cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_exception_with_doc(self): + doc2 = "This is a test docstring." + doc4 = "This is another test docstring." + + self.assertRaises(SystemError, _testcapi.make_exception_with_doc, + "error1") + + # test basic usage of PyErr_NewException + error1 = _testcapi.make_exception_with_doc("_testcapi.error1") + self.assertIs(type(error1), type) + self.assertTrue(issubclass(error1, Exception)) + self.assertIsNone(error1.__doc__) + + # test with given docstring + error2 = _testcapi.make_exception_with_doc("_testcapi.error2", doc2) + self.assertEqual(error2.__doc__, doc2) + + # test with explicit base (without docstring) + error3 = _testcapi.make_exception_with_doc("_testcapi.error3", + base=error2) + self.assertTrue(issubclass(error3, error2)) + + # test with explicit base tuple + class C(object): + pass + error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4, + (error3, C)) + self.assertTrue(issubclass(error4, error3)) + self.assertTrue(issubclass(error4, C)) + self.assertEqual(error4.__doc__, doc4) + + # test with explicit dictionary + error5 = _testcapi.make_exception_with_doc("_testcapi.error5", "", + error4, {'a': 1}) + self.assertTrue(issubclass(error5, error4)) + self.assertEqual(error5.a, 1) + self.assertEqual(error5.__doc__, "") + + @cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_memory_error_cleanup(self): + # Issue #5437: preallocated MemoryError instances should not keep + # traceback objects alive. + from _testcapi import raise_memoryerror + class C: + pass + wr = None + def inner(): + nonlocal wr + c = C() + wr = weakref.ref(c) + raise_memoryerror() + # We cannot use assertRaises since it manually deletes the traceback + try: + inner() + except MemoryError as e: + self.assertNotEqual(wr(), None) + else: + self.fail("MemoryError not raised") + gc_collect() # For PyPy or other GCs. + self.assertEqual(wr(), None) + + @no_tracing + def test_recursion_error_cleanup(self): + # Same test as above, but with "recursion exceeded" errors + class C: + pass + wr = None + def inner(): + nonlocal wr + c = C() + wr = weakref.ref(c) + inner() + # We cannot use assertRaises since it manually deletes the traceback + try: + inner() + except RecursionError as e: + self.assertNotEqual(wr(), None) + else: + self.fail("RecursionError not raised") + gc_collect() # For PyPy or other GCs. + self.assertEqual(wr(), None) + + def test_errno_ENOTDIR(self): + # Issue #12802: "not a directory" errors are ENOTDIR even on Windows + with self.assertRaises(OSError) as cm: + os.listdir(__file__) + self.assertEqual(cm.exception.errno, errno.ENOTDIR, cm.exception) + + def test_unraisable(self): + # Issue #22836: PyErr_WriteUnraisable() should give sensible reports + class BrokenDel: + def __del__(self): + exc = ValueError("del is broken") + # The following line is included in the traceback report: + raise exc + + obj = BrokenDel() + with support.catch_unraisable_exception() as cm: + del obj + + gc_collect() # For PyPy or other GCs. + self.assertEqual(cm.unraisable.object, BrokenDel.__del__) + self.assertIsNotNone(cm.unraisable.exc_traceback) + + def test_unhandled(self): + # Check for sensible reporting of unhandled exceptions + for exc_type in (ValueError, BrokenStrException): + with self.subTest(exc_type): + try: + exc = exc_type("test message") + # The following line is included in the traceback report: + raise exc + except exc_type: + with captured_stderr() as stderr: + sys.__excepthook__(*sys.exc_info()) + report = stderr.getvalue() + self.assertIn("test_exceptions.py", report) + self.assertIn("raise exc", report) + self.assertIn(exc_type.__name__, report) + if exc_type is BrokenStrException: + self.assertIn("", report) + else: + self.assertIn("test message", report) + self.assertTrue(report.endswith("\n")) + + @cpython_only + # Python built with Py_TRACE_REFS fail with a fatal error in + # _PyRefchain_Trace() on memory allocation error. + @unittest.skipIf(support.Py_TRACE_REFS, 'cannot test Py_TRACE_REFS build') + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_memory_error_in_PyErr_PrintEx(self): + code = """if 1: + import _testcapi + class C(): pass + _testcapi.set_nomemory(0, %d) + C() + """ + + # Issue #30817: Abort in PyErr_PrintEx() when no memory. + # Span a large range of tests as the CPython code always evolves with + # changes that add or remove memory allocations. + for i in range(1, 20): + rc, out, err = script_helper.assert_python_failure("-c", code % i) + self.assertIn(rc, (1, 120)) + self.assertIn(b'MemoryError', err) + + def test_yield_in_nested_try_excepts(self): + #Issue #25612 + class MainError(Exception): + pass + + class SubError(Exception): + pass + + def main(): + try: + raise MainError() + except MainError: + try: + yield + except SubError: + pass + raise + + coro = main() + coro.send(None) + with self.assertRaises(MainError): + coro.throw(SubError()) + + def test_generator_doesnt_retain_old_exc2(self): + #Issue 28884#msg282532 + def g(): + try: + raise ValueError + except ValueError: + yield 1 + self.assertIsNone(sys.exception()) + yield 2 + + gen = g() + + try: + raise IndexError + except IndexError: + self.assertEqual(next(gen), 1) + self.assertEqual(next(gen), 2) + + def test_raise_in_generator(self): + #Issue 25612#msg304117 + def g(): + yield 1 + raise + yield 2 + + with self.assertRaises(ZeroDivisionError): + i = g() + try: + 1/0 + except: + next(i) + next(i) + + @unittest.skipUnless(__debug__, "Won't work if __debug__ is False") + def test_assert_shadowing(self): + # Shadowing AssertionError would cause the assert statement to + # misbehave. + global AssertionError + AssertionError = TypeError + try: + assert False, 'hello' + except BaseException as e: + del AssertionError + self.assertIsInstance(e, AssertionError) + self.assertEqual(str(e), 'hello') + else: + del AssertionError + self.fail('Expected exception') + + def test_memory_error_subclasses(self): + # bpo-41654: MemoryError instances use a freelist of objects that are + # linked using the 'dict' attribute when they are inactive/dead. + # Subclasses of MemoryError should not participate in the freelist + # schema. This test creates a MemoryError object and keeps it alive + # (therefore advancing the freelist) and then it creates and destroys a + # subclass object. Finally, it checks that creating a new MemoryError + # succeeds, proving that the freelist is not corrupted. + + class TestException(MemoryError): + pass + + try: + raise MemoryError + except MemoryError as exc: + inst = exc + + try: + raise TestException + except Exception: + pass + + for _ in range(10): + try: + raise MemoryError + except MemoryError as exc: + pass + + gc_collect() + + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def test_memory_error_in_subinterp(self): + # gh-109894: subinterpreters shouldn't count on last resort memory error + # when MemoryError is raised through PyErr_NoMemory() call, + # and should preallocate memory errors as does the main interpreter. + # interp.static_objects.last_resort_memory_error.args + # should be initialized to empty tuple to avoid crash on attempt to print it. + code = f"""if 1: + import _testcapi + _testcapi.run_in_subinterp(\"[0]*{sys.maxsize}\") + exit(0) + """ + rc, _, err = script_helper.assert_python_ok("-c", code) + self.assertIn(b'MemoryError', err) + + +class NameErrorTests(__TestCase): + def test_name_error_has_name(self): + try: + bluch + except NameError as exc: + self.assertEqual("bluch", exc.name) + + def test_issue45826(self): + # regression test for bpo-45826 + def f(): + with self.assertRaisesRegex(NameError, 'aaa'): + aab + + try: + f() + except self.failureException: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + else: + self.fail("assertRaisesRegex should have failed.") + + self.assertIn("aab", err.getvalue()) + + def test_issue45826_focused(self): + def f(): + try: + nonsense + except BaseException as E: + E.with_traceback(None) + raise ZeroDivisionError() + + try: + f() + except ZeroDivisionError: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + + self.assertIn("nonsense", err.getvalue()) + self.assertIn("ZeroDivisionError", err.getvalue()) + + def test_gh_111654(self): + def f(): + class TestClass: + TestClass + + self.assertRaises(NameError, f) + + # Note: name suggestion tests live in `test_traceback`. + + +class AttributeErrorTests(__TestCase): + def test_attributes(self): + # Setting 'attr' should not be a problem. + exc = AttributeError('Ouch!') + self.assertIsNone(exc.name) + self.assertIsNone(exc.obj) + + sentinel = object() + exc = AttributeError('Ouch', name='carry', obj=sentinel) + self.assertEqual(exc.name, 'carry') + self.assertIs(exc.obj, sentinel) + + def test_getattr_has_name_and_obj(self): + class A: + blech = None + + obj = A() + try: + obj.bluch + except AttributeError as exc: + self.assertEqual("bluch", exc.name) + self.assertEqual(obj, exc.obj) + try: + object.__getattribute__(obj, "bluch") + except AttributeError as exc: + self.assertEqual("bluch", exc.name) + self.assertEqual(obj, exc.obj) + + def test_getattr_has_name_and_obj_for_method(self): + class A: + def blech(self): + return + + obj = A() + try: + obj.bluch() + except AttributeError as exc: + self.assertEqual("bluch", exc.name) + self.assertEqual(obj, exc.obj) + + # Note: name suggestion tests live in `test_traceback`. + + +class ImportErrorTests(__TestCase): + + def test_attributes(self): + # Setting 'name' and 'path' should not be a problem. + exc = ImportError('test') + self.assertIsNone(exc.name) + self.assertIsNone(exc.path) + + exc = ImportError('test', name='somemodule') + self.assertEqual(exc.name, 'somemodule') + self.assertIsNone(exc.path) + + exc = ImportError('test', path='somepath') + self.assertEqual(exc.path, 'somepath') + self.assertIsNone(exc.name) + + exc = ImportError('test', path='somepath', name='somename') + self.assertEqual(exc.name, 'somename') + self.assertEqual(exc.path, 'somepath') + + msg = r"ImportError\(\) got an unexpected keyword argument 'invalid'" + with self.assertRaisesRegex(TypeError, msg): + ImportError('test', invalid='keyword') + + with self.assertRaisesRegex(TypeError, msg): + ImportError('test', name='name', invalid='keyword') + + with self.assertRaisesRegex(TypeError, msg): + ImportError('test', path='path', invalid='keyword') + + with self.assertRaisesRegex(TypeError, msg): + ImportError(invalid='keyword') + + with self.assertRaisesRegex(TypeError, msg): + ImportError('test', invalid='keyword', another=True) + + def test_reset_attributes(self): + exc = ImportError('test', name='name', path='path') + self.assertEqual(exc.args, ('test',)) + self.assertEqual(exc.msg, 'test') + self.assertEqual(exc.name, 'name') + self.assertEqual(exc.path, 'path') + + # Reset not specified attributes + exc.__init__() + self.assertEqual(exc.args, ()) + self.assertEqual(exc.msg, None) + self.assertEqual(exc.name, None) + self.assertEqual(exc.path, None) + + def test_non_str_argument(self): + # Issue #15778 + with check_warnings(('', BytesWarning), quiet=True): + arg = b'abc' + exc = ImportError(arg) + self.assertEqual(str(arg), str(exc)) + + def test_copy_pickle(self): + for kwargs in (dict(), + dict(name='somename'), + dict(path='somepath'), + dict(name='somename', path='somepath')): + orig = ImportError('test', **kwargs) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + exc = pickle.loads(pickle.dumps(orig, proto)) + self.assertEqual(exc.args, ('test',)) + self.assertEqual(exc.msg, 'test') + self.assertEqual(exc.name, orig.name) + self.assertEqual(exc.path, orig.path) + for c in copy.copy, copy.deepcopy: + exc = c(orig) + self.assertEqual(exc.args, ('test',)) + self.assertEqual(exc.msg, 'test') + self.assertEqual(exc.name, orig.name) + self.assertEqual(exc.path, orig.path) + + +def run_script(source): + if isinstance(source, str): + with open(TESTFN, 'w', encoding='utf-8') as testfile: + testfile.write(dedent(source)) + else: + with open(TESTFN, 'wb') as testfile: + testfile.write(source) + _rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN) + return err.decode('utf-8').splitlines() + +class AssertionErrorTests(__TestCase): + def tearDown(self): + unlink(TESTFN) + + @force_not_colorized + def test_assertion_error_location(self): + cases = [ + ('assert None', + [ + ' assert None', + ' ^^^^', + 'AssertionError', + ], + ), + ('assert 0', + [ + ' assert 0', + ' ^', + 'AssertionError', + ], + ), + ('assert 1 > 2', + [ + ' assert 1 > 2', + ' ^^^^^', + 'AssertionError', + ], + ), + ('assert 1 > 2 and 3 > 2', + [ + ' assert 1 > 2 and 3 > 2', + ' ^^^^^^^^^^^^^^^', + 'AssertionError', + ], + ), + ('assert 1 > 2, "messäge"', + [ + ' assert 1 > 2, "messäge"', + ' ^^^^^', + 'AssertionError: messäge', + ], + ), + ('assert 1 > 2, "messäge"'.encode(), + [ + ' assert 1 > 2, "messäge"', + ' ^^^^^', + 'AssertionError: messäge', + ], + ), + ('# coding: latin1\nassert 1 > 2, "messäge"'.encode('latin1'), + [ + ' assert 1 > 2, "messäge"', + ' ^^^^^', + 'AssertionError: messäge', + ], + ), + (BOM_UTF8 + 'assert 1 > 2, "messäge"'.encode(), + [ + ' assert 1 > 2, "messäge"', + ' ^^^^^', + 'AssertionError: messäge', + ], + ), + + # Multiline: + (""" + assert ( + 1 > 2) + """, + [ + ' 1 > 2)', + ' ^^^^^', + 'AssertionError', + ], + ), + (""" + assert ( + 1 > 2), "Message" + """, + [ + ' 1 > 2), "Message"', + ' ^^^^^', + 'AssertionError: Message', + ], + ), + (""" + assert ( + 1 > 2), \\ + "Message" + """, + [ + ' 1 > 2), \\', + ' ^^^^^', + 'AssertionError: Message', + ], + ), + ] + for source, expected in cases: + with self.subTest(source=source): + result = run_script(source) + self.assertEqual(result[-3:], expected) + + @force_not_colorized + def test_multiline_not_highlighted(self): + cases = [ + (""" + assert ( + 1 > 2 + ) + """, + [ + ' 1 > 2', + 'AssertionError', + ], + ), + (""" + assert ( + 1 < 2 and + 3 > 4 + ) + """, + [ + ' 1 < 2 and', + ' 3 > 4', + 'AssertionError', + ], + ), + ] + for source, expected in cases: + with self.subTest(source=source): + result = run_script(source) + self.assertEqual(result[-len(expected):], expected) + + +@support.force_not_colorized_test_class +class SyntaxErrorTests(__TestCase): + maxDiff = None + + @force_not_colorized + def test_range_of_offsets(self): + cases = [ + # Basic range from 2->7 + (("bad.py", 1, 2, "abcdefg", 1, 7), + dedent( + """ + File "bad.py", line 1 + abcdefg + ^^^^^ + SyntaxError: bad bad + """)), + # end_offset = start_offset + 1 + (("bad.py", 1, 2, "abcdefg", 1, 3), + dedent( + """ + File "bad.py", line 1 + abcdefg + ^ + SyntaxError: bad bad + """)), + # Negative end offset + (("bad.py", 1, 2, "abcdefg", 1, -2), + dedent( + """ + File "bad.py", line 1 + abcdefg + ^ + SyntaxError: bad bad + """)), + # end offset before starting offset + (("bad.py", 1, 4, "abcdefg", 1, 2), + dedent( + """ + File "bad.py", line 1 + abcdefg + ^ + SyntaxError: bad bad + """)), + # Both offsets negative + (("bad.py", 1, -4, "abcdefg", 1, -2), + dedent( + """ + File "bad.py", line 1 + abcdefg + SyntaxError: bad bad + """)), + # Both offsets negative and the end more negative + (("bad.py", 1, -4, "abcdefg", 1, -5), + dedent( + """ + File "bad.py", line 1 + abcdefg + SyntaxError: bad bad + """)), + # Both offsets 0 + (("bad.py", 1, 0, "abcdefg", 1, 0), + dedent( + """ + File "bad.py", line 1 + abcdefg + SyntaxError: bad bad + """)), + # Start offset 0 and end offset not 0 + (("bad.py", 1, 0, "abcdefg", 1, 5), + dedent( + """ + File "bad.py", line 1 + abcdefg + SyntaxError: bad bad + """)), + # End offset pass the source length + (("bad.py", 1, 2, "abcdefg", 1, 100), + dedent( + """ + File "bad.py", line 1 + abcdefg + ^^^^^^ + SyntaxError: bad bad + """)), + ] + for args, expected in cases: + with self.subTest(args=args): + try: + raise SyntaxError("bad bad", args) + except SyntaxError as exc: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + self.assertIn(expected, err.getvalue()) + the_exception = exc + + def test_subclass(self): + class MySyntaxError(SyntaxError): + pass + + try: + raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7)) + except SyntaxError as exc: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + self.assertIn(""" + File "bad.py", line 1 + abcdefg + ^^^^^ +""", err.getvalue()) + + def test_encodings(self): + self.addCleanup(unlink, TESTFN) + source = ( + '# -*- coding: cp437 -*-\n' + '"┬ó┬ó┬ó┬ó┬ó┬ó" + f(4, x for x in range(1))\n' + ) + err = run_script(source.encode('cp437')) + self.assertEqual(err[-3], ' "┬ó┬ó┬ó┬ó┬ó┬ó" + f(4, x for x in range(1))') + self.assertEqual(err[-2], ' ^^^^^^^^^^^^^^^^^^^') + + # Check backwards tokenizer errors + source = '# -*- coding: ascii -*-\n\n(\n' + err = run_script(source) + self.assertEqual(err[-3], ' (') + self.assertEqual(err[-2], ' ^') + + def test_non_utf8(self): + # Check non utf-8 characters + self.addCleanup(unlink, TESTFN) + err = run_script(b"\x89") + self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1]) + + + def test_string_source(self): + def try_compile(source): + with self.assertRaises(SyntaxError) as cm: + compile(source, '', 'exec') + return cm.exception + + exc = try_compile('return "ä"') + self.assertEqual(str(exc), "'return' outside function (, line 1)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + exc = try_compile('return "ä"'.encode()) + self.assertEqual(str(exc), "'return' outside function (, line 1)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + exc = try_compile(BOM_UTF8 + 'return "ä"'.encode()) + self.assertEqual(str(exc), "'return' outside function (, line 1)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + exc = try_compile('# coding: latin1\nreturn "ä"'.encode('latin1')) + self.assertEqual(str(exc), "'return' outside function (, line 2)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + exc = try_compile('return "ä" #' + 'ä'*1000) + self.assertEqual(str(exc), "'return' outside function (, line 1)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + exc = try_compile('return "ä" # ' + 'ä'*1000) + self.assertEqual(str(exc), "'return' outside function (, line 1)") + self.assertIsNone(exc.text) + self.assertEqual(exc.offset, 1) + self.assertEqual(exc.end_offset, 12) + + def test_file_source(self): + self.addCleanup(unlink, TESTFN) + err = run_script('return "ä"') + self.assertEqual(err[-3:], [ + ' return "ä"', + ' ^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + + err = run_script('return "ä"'.encode()) + self.assertEqual(err[-3:], [ + ' return "ä"', + ' ^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + + err = run_script(BOM_UTF8 + 'return "ä"'.encode()) + self.assertEqual(err[-3:], [ + ' return "ä"', + ' ^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + + err = run_script('# coding: latin1\nreturn "ä"'.encode('latin1')) + self.assertEqual(err[-3:], [ + ' return "ä"', + ' ^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + + err = run_script('return "ä" #' + 'ä'*1000) + self.assertEqual(err[-2:], [ + ' ^^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + self.assertEqual(err[-3][:100], ' return "ä" #' + 'ä'*84) + + err = run_script('return "ä" # ' + 'ä'*1000) + self.assertEqual(err[-2:], [ + ' ^^^^^^^^^^^', + "SyntaxError: 'return' outside function"]) + self.assertEqual(err[-3][:100], ' return "ä" # ' + 'ä'*83) + + def test_attributes_new_constructor(self): + args = ("bad.py", 1, 2, "abcdefg", 1, 100) + the_exception = SyntaxError("bad bad", args) + filename, lineno, offset, error, end_lineno, end_offset = args + self.assertEqual(filename, the_exception.filename) + self.assertEqual(lineno, the_exception.lineno) + self.assertEqual(end_lineno, the_exception.end_lineno) + self.assertEqual(offset, the_exception.offset) + self.assertEqual(end_offset, the_exception.end_offset) + self.assertEqual(error, the_exception.text) + self.assertEqual("bad bad", the_exception.msg) + + def test_attributes_old_constructor(self): + args = ("bad.py", 1, 2, "abcdefg") + the_exception = SyntaxError("bad bad", args) + filename, lineno, offset, error = args + self.assertEqual(filename, the_exception.filename) + self.assertEqual(lineno, the_exception.lineno) + self.assertEqual(None, the_exception.end_lineno) + self.assertEqual(offset, the_exception.offset) + self.assertEqual(None, the_exception.end_offset) + self.assertEqual(error, the_exception.text) + self.assertEqual("bad bad", the_exception.msg) + + def test_incorrect_constructor(self): + args = ("bad.py", 1, 2) + self.assertRaises(TypeError, SyntaxError, "bad bad", args) + + args = ("bad.py", 1, 2, 4, 5, 6, 7) + self.assertRaises(TypeError, SyntaxError, "bad bad", args) + + args = ("bad.py", 1, 2, "abcdefg", 1) + self.assertRaises(TypeError, SyntaxError, "bad bad", args) + + +class TestInvalidExceptionMatcher(__TestCase): + def test_except_star_invalid_exception_type(self): + with self.assertRaises(TypeError): + try: + raise ValueError + except 42: + pass + + with self.assertRaises(TypeError): + try: + raise ValueError + except (ValueError, 42): + pass + + +class PEP626Tests(__TestCase): + + def lineno_after_raise(self, f, *expected): + try: + f() + except Exception as ex: + t = ex.__traceback__ + else: + self.fail("No exception raised") + lines = [] + t = t.tb_next # Skip this function + while t: + frame = t.tb_frame + lines.append( + None if frame.f_lineno is None else + frame.f_lineno-frame.f_code.co_firstlineno + ) + t = t.tb_next + self.assertEqual(tuple(lines), expected) + + def test_lineno_after_raise_simple(self): + def simple(): + 1/0 + pass + self.lineno_after_raise(simple, 1) + + def test_lineno_after_raise_in_except(self): + def in_except(): + try: + 1/0 + except: + 1/0 + pass + self.lineno_after_raise(in_except, 4) + + def test_lineno_after_other_except(self): + def other_except(): + try: + 1/0 + except TypeError as ex: + pass + self.lineno_after_raise(other_except, 3) + + def test_lineno_in_named_except(self): + def in_named_except(): + try: + 1/0 + except Exception as ex: + 1/0 + pass + self.lineno_after_raise(in_named_except, 4) + + def test_lineno_in_try(self): + def in_try(): + try: + 1/0 + finally: + pass + self.lineno_after_raise(in_try, 4) + + def test_lineno_in_finally_normal(self): + def in_finally_normal(): + try: + pass + finally: + 1/0 + pass + self.lineno_after_raise(in_finally_normal, 4) + + def test_lineno_in_finally_except(self): + def in_finally_except(): + try: + 1/0 + finally: + 1/0 + pass + self.lineno_after_raise(in_finally_except, 4) + + def test_lineno_after_with(self): + class Noop: + def __enter__(self): + return self + def __exit__(self, *args): + pass + def after_with(): + with Noop(): + 1/0 + pass + self.lineno_after_raise(after_with, 2) + + def test_missing_lineno_shows_as_none(self): + def f(): + 1/0 + self.lineno_after_raise(f, 1) + f.__code__ = f.__code__.replace(co_linetable=b'\xf8\xf8\xf8\xf9\xf8\xf8\xf8') + self.lineno_after_raise(f, None) + + def test_lineno_after_raise_in_with_exit(self): + class ExitFails: + def __enter__(self): + return self + def __exit__(self, *args): + raise ValueError + + def after_with(): + with ExitFails(): + 1/0 + self.lineno_after_raise(after_with, 1, 1) + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_float.diff b/test/dynamo/cpython/3_13/test_float.diff new file mode 100644 index 00000000000000..6b8586b1c6639d --- /dev/null +++ b/test/dynamo/cpython/3_13/test_float.diff @@ -0,0 +1,279 @@ +diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py +index 97f951f1299..ce2c46777e0 100644 +--- a/test/dynamo/cpython/3_13/test_float.py ++++ b/test/dynamo/cpython/3_13/test_float.py +@@ -1,3 +1,54 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import fractions + import operator + import os +@@ -8,11 +59,84 @@ import time + import unittest + + from test import support +-from test.support.testcase import FloatsAreIdenticalMixin +-from test.support.numbers import ( +- VALID_UNDERSCORE_LITERALS, +- INVALID_UNDERSCORE_LITERALS, +-) ++ ++VALID_UNDERSCORE_LITERALS = [ ++ '0_0_0', ++ '4_2', ++ '1_0000_0000', ++ '0b1001_0100', ++ '0xffff_ffff', ++ '0o5_7_7', ++ '1_00_00.5', ++ '1_00_00.5e5', ++ '1_00_00e5_1', ++ '1e1_0', ++ '.1_4', ++ '.1_4e1', ++ '0b_0', ++ '0x_f', ++ '0o_5', ++ '1_00_00j', ++ '1_00_00.5j', ++ '1_00_00e5_1j', ++ '.1_4j', ++ '(1_2.5+3_3j)', ++ '(.5_6j)', ++] ++INVALID_UNDERSCORE_LITERALS = [ ++ # Trailing underscores: ++ '0_', ++ '42_', ++ '1.4j_', ++ '0x_', ++ '0b1_', ++ '0xf_', ++ '0o5_', ++ '0 if 1_Else 1', ++ # Underscores in the base selector: ++ '0_b0', ++ '0_xf', ++ '0_o5', ++ # Old-style octal, still disallowed: ++ '0_7', ++ '09_99', ++ # Multiple consecutive underscores: ++ '4_______2', ++ '0.1__4', ++ '0.1__4j', ++ '0b1001__0100', ++ '0xffff__ffff', ++ '0x___', ++ '0o5__77', ++ '1e1__0', ++ '1e1__0j', ++ # Underscore right before a dot: ++ '1_.4', ++ '1_.4j', ++ # Underscore right after a dot: ++ '1._4', ++ '1._4j', ++ '._5', ++ '._5j', ++ # Underscore right after a sign: ++ '1.0e+_1', ++ '1.0e+_1j', ++ # Underscore right before j: ++ '1.4_j', ++ '1.4e5_j', ++ # Underscore right before e: ++ '1_e1', ++ '1.4_e1', ++ '1.4_e1j', ++ # Underscore right after e: ++ '1e_1', ++ '1.4e_1', ++ '1.4e_1j', ++ # Complex cases with parens: ++ '(1+1.5_j_)', ++ '(1+1.5_j)', ++] ++ + from math import isinf, isnan, copysign, ldexp + import math + +@@ -35,7 +159,7 @@ class FloatSubclass(float): + class OtherFloatSubclass(float): + pass + +-class GeneralFloatCases(unittest.TestCase): ++class GeneralFloatCases(__TestCase): + + def test_float(self): + self.assertEqual(float(3.14), 3.14) +@@ -620,7 +744,7 @@ class GeneralFloatCases(unittest.TestCase): + + + @unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__") +-class FormatFunctionsTestCase(unittest.TestCase): ++class FormatFunctionsTestCase(__TestCase): + def test_getformat(self): + self.assertIn(float.__getformat__('double'), + ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) +@@ -645,7 +769,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) + # is accident (today). + # let's also try to guarantee that -0.0 and 0.0 don't get confused. + +-class IEEEFormatTestCase(unittest.TestCase): ++class IEEEFormatTestCase(__TestCase): + + @support.requires_IEEE_754 + def test_double_specials_do_unpack(self): +@@ -670,7 +794,7 @@ class IEEEFormatTestCase(unittest.TestCase): + self.assertEqual(struct.pack(" 1 + self.assertEqualAndEqualSign(pow_op(-INF, -INF), 0.0) + self.assertEqualAndEqualSign(pow_op(-2.0, -INF), 0.0) + self.assertEqualAndEqualSign(pow_op(2.0, -INF), 0.0) + self.assertEqualAndEqualSign(pow_op(INF, -INF), 0.0) + + # x**INF is 0 for abs(x) < 1 + self.assertEqualAndEqualSign(pow_op(-0.5, INF), 0.0) + self.assertEqualAndEqualSign(pow_op(-0.0, INF), 0.0) + self.assertEqualAndEqualSign(pow_op(0.0, INF), 0.0) + self.assertEqualAndEqualSign(pow_op(0.5, INF), 0.0) + + # x**INF is INF for abs(x) > 1 + self.assertEqualAndEqualSign(pow_op(-INF, INF), INF) + self.assertEqualAndEqualSign(pow_op(-2.0, INF), INF) + self.assertEqualAndEqualSign(pow_op(2.0, INF), INF) + self.assertEqualAndEqualSign(pow_op(INF, INF), INF) + + # (-INF)**y is -0.0 for y a negative odd integer + self.assertEqualAndEqualSign(pow_op(-INF, -1.0), -0.0) + + # (-INF)**y is 0.0 for y negative but not an odd integer + self.assertEqualAndEqualSign(pow_op(-INF, -0.5), 0.0) + self.assertEqualAndEqualSign(pow_op(-INF, -2.0), 0.0) + + # (-INF)**y is -INF for y a positive odd integer + self.assertEqualAndEqualSign(pow_op(-INF, 1.0), -INF) + + # (-INF)**y is INF for y positive but not an odd integer + self.assertEqualAndEqualSign(pow_op(-INF, 0.5), INF) + self.assertEqualAndEqualSign(pow_op(-INF, 2.0), INF) + + # INF**y is INF for y positive + self.assertEqualAndEqualSign(pow_op(INF, 0.5), INF) + self.assertEqualAndEqualSign(pow_op(INF, 1.0), INF) + self.assertEqualAndEqualSign(pow_op(INF, 2.0), INF) + + # INF**y is 0.0 for y negative + self.assertEqualAndEqualSign(pow_op(INF, -2.0), 0.0) + self.assertEqualAndEqualSign(pow_op(INF, -1.0), 0.0) + self.assertEqualAndEqualSign(pow_op(INF, -0.5), 0.0) + + # basic checks not covered by the special cases above + self.assertEqualAndEqualSign(pow_op(-2.0, -2.0), 0.25) + self.assertEqualAndEqualSign(pow_op(-2.0, -1.0), -0.5) + self.assertEqualAndEqualSign(pow_op(-2.0, -0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(-2.0, 0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(-2.0, 1.0), -2.0) + self.assertEqualAndEqualSign(pow_op(-2.0, 2.0), 4.0) + self.assertEqualAndEqualSign(pow_op(-1.0, -2.0), 1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, -1.0), -1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, -0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, 0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, 1.0), -1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, 2.0), 1.0) + self.assertEqualAndEqualSign(pow_op(2.0, -2.0), 0.25) + self.assertEqualAndEqualSign(pow_op(2.0, -1.0), 0.5) + self.assertEqualAndEqualSign(pow_op(2.0, -0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(2.0, 0.0), 1.0) + self.assertEqualAndEqualSign(pow_op(2.0, 1.0), 2.0) + self.assertEqualAndEqualSign(pow_op(2.0, 2.0), 4.0) + + # 1 ** large and -1 ** large; some libms apparently + # have problems with these + self.assertEqualAndEqualSign(pow_op(1.0, -1e100), 1.0) + self.assertEqualAndEqualSign(pow_op(1.0, 1e100), 1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, -1e100), 1.0) + self.assertEqualAndEqualSign(pow_op(-1.0, 1e100), 1.0) + + # check sign for results that underflow to 0 + self.assertEqualAndEqualSign(pow_op(-2.0, -2000.0), 0.0) + self.assertEqual(type(pow_op(-2.0, -2000.5)), complex) + self.assertEqualAndEqualSign(pow_op(-2.0, -2001.0), -0.0) + self.assertEqualAndEqualSign(pow_op(2.0, -2000.0), 0.0) + self.assertEqualAndEqualSign(pow_op(2.0, -2000.5), 0.0) + self.assertEqualAndEqualSign(pow_op(2.0, -2001.0), 0.0) + self.assertEqualAndEqualSign(pow_op(-0.5, 2000.0), 0.0) + self.assertEqual(type(pow_op(-0.5, 2000.5)), complex) + self.assertEqualAndEqualSign(pow_op(-0.5, 2001.0), -0.0) + self.assertEqualAndEqualSign(pow_op(0.5, 2000.0), 0.0) + self.assertEqualAndEqualSign(pow_op(0.5, 2000.5), 0.0) + self.assertEqualAndEqualSign(pow_op(0.5, 2001.0), 0.0) + + # check we don't raise an exception for subnormal results, + # and validate signs. Tests currently disabled, since + # they fail on systems where a subnormal result from pow + # is flushed to zero (e.g. Debian/ia64.) + #self.assertTrue(0.0 < pow_op(0.5, 1048) < 1e-315) + #self.assertTrue(0.0 < pow_op(-0.5, 1048) < 1e-315) + #self.assertTrue(0.0 < pow_op(0.5, 1047) < 1e-315) + #self.assertTrue(0.0 > pow_op(-0.5, 1047) > -1e-315) + #self.assertTrue(0.0 < pow_op(2.0, -1048) < 1e-315) + #self.assertTrue(0.0 < pow_op(-2.0, -1048) < 1e-315) + #self.assertTrue(0.0 < pow_op(2.0, -1047) < 1e-315) + #self.assertTrue(0.0 > pow_op(-2.0, -1047) > -1e-315) + + def test_hash(self): + for x in range(-30, 30): + self.assertEqual(hash(float(x)), hash(x)) + self.assertEqual(hash(float(sys.float_info.max)), + hash(int(sys.float_info.max))) + self.assertEqual(hash(float('inf')), sys.hash_info.inf) + self.assertEqual(hash(float('-inf')), -sys.hash_info.inf) + + def test_hash_nan(self): + value = float('nan') + self.assertEqual(hash(value), object.__hash__(value)) + class H: + def __hash__(self): + return 42 + class F(float, H): + pass + value = F('nan') + self.assertEqual(hash(value), object.__hash__(value)) + + +@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__") +class FormatFunctionsTestCase(__TestCase): + def test_getformat(self): + self.assertIn(float.__getformat__('double'), + ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) + self.assertIn(float.__getformat__('float'), + ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) + self.assertRaises(ValueError, float.__getformat__, 'chicken') + self.assertRaises(TypeError, float.__getformat__, 1) + + +BE_DOUBLE_INF = b'\x7f\xf0\x00\x00\x00\x00\x00\x00' +LE_DOUBLE_INF = bytes(reversed(BE_DOUBLE_INF)) +BE_DOUBLE_NAN = b'\x7f\xf8\x00\x00\x00\x00\x00\x00' +LE_DOUBLE_NAN = bytes(reversed(BE_DOUBLE_NAN)) + +BE_FLOAT_INF = b'\x7f\x80\x00\x00' +LE_FLOAT_INF = bytes(reversed(BE_FLOAT_INF)) +BE_FLOAT_NAN = b'\x7f\xc0\x00\x00' +LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) + +# on an IEEE platform, all we guarantee is that bit patterns +# representing infinities or NaNs do not raise an exception; all else +# is accident (today). +# let's also try to guarantee that -0.0 and 0.0 don't get confused. + +class IEEEFormatTestCase(__TestCase): + + @support.requires_IEEE_754 + def test_double_specials_do_unpack(self): + for fmt, data in [('>d', BE_DOUBLE_INF), + ('>d', BE_DOUBLE_NAN), + ('f', BE_FLOAT_INF), + ('>f', BE_FLOAT_NAN), + (''), str(x)) + self.assertEqual(format(x, '2'), str(x)) + + self.assertEqual(format(1.0, 'f'), '1.000000') + + self.assertEqual(format(-1.0, 'f'), '-1.000000') + + self.assertEqual(format( 1.0, ' f'), ' 1.000000') + self.assertEqual(format(-1.0, ' f'), '-1.000000') + self.assertEqual(format( 1.0, '+f'), '+1.000000') + self.assertEqual(format(-1.0, '+f'), '-1.000000') + + # % formatting + self.assertEqual(format(-1.0, '%'), '-100.000000%') + + # conversion to string should fail + self.assertRaises(ValueError, format, 3.0, "s") + + # confirm format options expected to fail on floats, such as integer + # presentation types + for format_spec in 'sbcdoxX': + self.assertRaises(ValueError, format, 0.0, format_spec) + self.assertRaises(ValueError, format, 1.0, format_spec) + self.assertRaises(ValueError, format, -1.0, format_spec) + self.assertRaises(ValueError, format, 1e100, format_spec) + self.assertRaises(ValueError, format, -1e100, format_spec) + self.assertRaises(ValueError, format, 1e-100, format_spec) + self.assertRaises(ValueError, format, -1e-100, format_spec) + + # issue 3382 + self.assertEqual(format(NAN, 'f'), 'nan') + self.assertEqual(format(NAN, 'F'), 'NAN') + self.assertEqual(format(INF, 'f'), 'inf') + self.assertEqual(format(INF, 'F'), 'INF') + + @support.requires_IEEE_754 + def test_format_testfile(self): + with open(format_testfile, encoding="utf-8") as testfile: + for line in testfile: + if line.startswith('--'): + continue + line = line.strip() + if not line: + continue + + lhs, rhs = map(str.strip, line.split('->')) + fmt, arg = lhs.split() + f = float(arg) + self.assertEqual(fmt % f, rhs) + self.assertEqual(fmt % -f, '-' + rhs) + if fmt != '%r': + fmt2 = fmt[1:] + self.assertEqual(format(f, fmt2), rhs) + self.assertEqual(format(-f, fmt2), '-' + rhs) + + def test_issue5864(self): + self.assertEqual(format(123.456, '.4'), '123.5') + self.assertEqual(format(1234.56, '.4'), '1.235e+03') + self.assertEqual(format(12345.6, '.4'), '1.235e+04') + + def test_issue35560(self): + self.assertEqual(format(123.0, '00'), '123.0') + self.assertEqual(format(123.34, '00f'), '123.340000') + self.assertEqual(format(123.34, '00e'), '1.233400e+02') + self.assertEqual(format(123.34, '00g'), '123.34') + self.assertEqual(format(123.34, '00.10f'), '123.3400000000') + self.assertEqual(format(123.34, '00.10e'), '1.2334000000e+02') + self.assertEqual(format(123.34, '00.10g'), '123.34') + self.assertEqual(format(123.34, '01f'), '123.340000') + + self.assertEqual(format(-123.0, '00'), '-123.0') + self.assertEqual(format(-123.34, '00f'), '-123.340000') + self.assertEqual(format(-123.34, '00e'), '-1.233400e+02') + self.assertEqual(format(-123.34, '00g'), '-123.34') + self.assertEqual(format(-123.34, '00.10f'), '-123.3400000000') + self.assertEqual(format(-123.34, '00.10f'), '-123.3400000000') + self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02') + self.assertEqual(format(-123.34, '00.10g'), '-123.34') + +class ReprTestCase(__TestCase): + def test_repr(self): + with open(os.path.join(os.path.split(__file__)[0], + 'mathdata', + 'floating_points.txt'), encoding="utf-8") as floats_file: + for line in floats_file: + line = line.strip() + if not line or line.startswith('#'): + continue + v = eval(line) + self.assertEqual(v, eval(repr(v))) + + @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', + "applies only when using short float repr style") + def test_short_repr(self): + # test short float repr introduced in Python 3.1. One aspect + # of this repr is that we get some degree of str -> float -> + # str roundtripping. In particular, for any numeric string + # containing 15 or fewer significant digits, those exact same + # digits (modulo trailing zeros) should appear in the output. + # No more repr(0.03) -> "0.029999999999999999"! + + test_strings = [ + # output always includes *either* a decimal point and at + # least one digit after that point, or an exponent. + '0.0', + '1.0', + '0.01', + '0.02', + '0.03', + '0.04', + '0.05', + '1.23456789', + '10.0', + '100.0', + # values >= 1e16 get an exponent... + '1000000000000000.0', + '9999999999999990.0', + '1e+16', + '1e+17', + # ... and so do values < 1e-4 + '0.001', + '0.001001', + '0.00010000000000001', + '0.0001', + '9.999999999999e-05', + '1e-05', + # values designed to provoke failure if the FPU rounding + # precision isn't set correctly + '8.72293771110361e+25', + '7.47005307342313e+26', + '2.86438000439698e+28', + '8.89142905246179e+28', + '3.08578087079232e+35', + ] + + for s in test_strings: + negs = '-'+s + self.assertEqual(s, repr(float(s))) + self.assertEqual(negs, repr(float(negs))) + # Since Python 3.2, repr and str are identical + self.assertEqual(repr(float(s)), str(float(s))) + self.assertEqual(repr(float(negs)), str(float(negs))) + +@support.requires_IEEE_754 +class RoundTestCase(__TestCase): + def assertFloatsAreIdentical(self, x, y): + """assert that floats x and y are identical, in the sense that: + (1) both x and y are nans, or + (2) both x and y are infinities, with the same sign, or + (3) both x and y are zeros, with the same sign, or + (4) x and y are both finite and nonzero, and x == y + + """ + msg = 'floats {!r} and {!r} are not identical' + + if isnan(x) or isnan(y): + if isnan(x) and isnan(y): + return + elif x == y: + if x != 0.0: + return + # both zero; check that signs match + elif copysign(1.0, x) == copysign(1.0, y): + return + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) + + def test_inf_nan(self): + self.assertRaises(OverflowError, round, INF) + self.assertRaises(OverflowError, round, -INF) + self.assertRaises(ValueError, round, NAN) + self.assertRaises(TypeError, round, INF, 0.0) + self.assertRaises(TypeError, round, -INF, 1.0) + self.assertRaises(TypeError, round, NAN, "ceci n'est pas un integer") + self.assertRaises(TypeError, round, -0.0, 1j) + + def test_inf_nan_ndigits(self): + self.assertEqual(round(INF, 0), INF) + self.assertEqual(round(-INF, 0), -INF) + self.assertTrue(math.isnan(round(NAN, 0))) + + def test_large_n(self): + for n in [324, 325, 400, 2**31-1, 2**31, 2**32, 2**100]: + self.assertEqual(round(123.456, n), 123.456) + self.assertEqual(round(-123.456, n), -123.456) + self.assertEqual(round(1e300, n), 1e300) + self.assertEqual(round(1e-320, n), 1e-320) + self.assertEqual(round(1e150, 300), 1e150) + self.assertEqual(round(1e300, 307), 1e300) + self.assertEqual(round(-3.1415, 308), -3.1415) + self.assertEqual(round(1e150, 309), 1e150) + self.assertEqual(round(1.4e-315, 315), 1e-315) + + def test_small_n(self): + for n in [-308, -309, -400, 1-2**31, -2**31, -2**31-1, -2**100]: + self.assertFloatsAreIdentical(round(123.456, n), 0.0) + self.assertFloatsAreIdentical(round(-123.456, n), -0.0) + self.assertFloatsAreIdentical(round(1e300, n), 0.0) + self.assertFloatsAreIdentical(round(1e-320, n), 0.0) + + def test_overflow(self): + self.assertRaises(OverflowError, round, 1.6e308, -308) + self.assertRaises(OverflowError, round, -1.7e308, -308) + + @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', + "applies only when using short float repr style") + def test_previous_round_bugs(self): + # particular cases that have occurred in bug reports + self.assertEqual(round(562949953421312.5, 1), + 562949953421312.5) + self.assertEqual(round(56294995342131.5, 3), + 56294995342131.5) + # round-half-even + self.assertEqual(round(25.0, -1), 20.0) + self.assertEqual(round(35.0, -1), 40.0) + self.assertEqual(round(45.0, -1), 40.0) + self.assertEqual(round(55.0, -1), 60.0) + self.assertEqual(round(65.0, -1), 60.0) + self.assertEqual(round(75.0, -1), 80.0) + self.assertEqual(round(85.0, -1), 80.0) + self.assertEqual(round(95.0, -1), 100.0) + + @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', + "applies only when using short float repr style") + def test_matches_float_format(self): + # round should give the same results as float formatting + for i in range(500): + x = i/1000. + self.assertEqual(float(format(x, '.0f')), round(x, 0)) + self.assertEqual(float(format(x, '.1f')), round(x, 1)) + self.assertEqual(float(format(x, '.2f')), round(x, 2)) + self.assertEqual(float(format(x, '.3f')), round(x, 3)) + + for i in range(5, 5000, 10): + x = i/1000. + self.assertEqual(float(format(x, '.0f')), round(x, 0)) + self.assertEqual(float(format(x, '.1f')), round(x, 1)) + self.assertEqual(float(format(x, '.2f')), round(x, 2)) + self.assertEqual(float(format(x, '.3f')), round(x, 3)) + + for i in range(500): + x = random.random() + self.assertEqual(float(format(x, '.0f')), round(x, 0)) + self.assertEqual(float(format(x, '.1f')), round(x, 1)) + self.assertEqual(float(format(x, '.2f')), round(x, 2)) + self.assertEqual(float(format(x, '.3f')), round(x, 3)) + + def test_format_specials(self): + # Test formatting of nans and infs. + + def test(fmt, value, expected): + # Test with both % and format(). + self.assertEqual(fmt % value, expected, fmt) + fmt = fmt[1:] # strip off the % + self.assertEqual(format(value, fmt), expected, fmt) + + for fmt in ['%e', '%f', '%g', '%.0e', '%.6f', '%.20g', + '%#e', '%#f', '%#g', '%#.20e', '%#.15f', '%#.3g']: + pfmt = '%+' + fmt[1:] + sfmt = '% ' + fmt[1:] + test(fmt, INF, 'inf') + test(fmt, -INF, '-inf') + test(fmt, NAN, 'nan') + test(fmt, -NAN, 'nan') + # When asking for a sign, it's always provided. nans are + # always positive. + test(pfmt, INF, '+inf') + test(pfmt, -INF, '-inf') + test(pfmt, NAN, '+nan') + test(pfmt, -NAN, '+nan') + # When using ' ' for a sign code, only infs can be negative. + # Others have a space. + test(sfmt, INF, ' inf') + test(sfmt, -INF, '-inf') + test(sfmt, NAN, ' nan') + test(sfmt, -NAN, ' nan') + + def test_None_ndigits(self): + for x in round(1.23), round(1.23, None), round(1.23, ndigits=None): + self.assertEqual(x, 1) + self.assertIsInstance(x, int) + for x in round(1.78), round(1.78, None), round(1.78, ndigits=None): + self.assertEqual(x, 2) + self.assertIsInstance(x, int) + + +# Beginning with Python 2.6 float has cross platform compatible +# ways to create and represent inf and nan +class InfNanTest(__TestCase): + def test_inf_from_str(self): + self.assertTrue(isinf(float("inf"))) + self.assertTrue(isinf(float("+inf"))) + self.assertTrue(isinf(float("-inf"))) + self.assertTrue(isinf(float("infinity"))) + self.assertTrue(isinf(float("+infinity"))) + self.assertTrue(isinf(float("-infinity"))) + + self.assertEqual(repr(float("inf")), "inf") + self.assertEqual(repr(float("+inf")), "inf") + self.assertEqual(repr(float("-inf")), "-inf") + self.assertEqual(repr(float("infinity")), "inf") + self.assertEqual(repr(float("+infinity")), "inf") + self.assertEqual(repr(float("-infinity")), "-inf") + + self.assertEqual(repr(float("INF")), "inf") + self.assertEqual(repr(float("+Inf")), "inf") + self.assertEqual(repr(float("-iNF")), "-inf") + self.assertEqual(repr(float("Infinity")), "inf") + self.assertEqual(repr(float("+iNfInItY")), "inf") + self.assertEqual(repr(float("-INFINITY")), "-inf") + + self.assertEqual(str(float("inf")), "inf") + self.assertEqual(str(float("+inf")), "inf") + self.assertEqual(str(float("-inf")), "-inf") + self.assertEqual(str(float("infinity")), "inf") + self.assertEqual(str(float("+infinity")), "inf") + self.assertEqual(str(float("-infinity")), "-inf") + + self.assertRaises(ValueError, float, "info") + self.assertRaises(ValueError, float, "+info") + self.assertRaises(ValueError, float, "-info") + self.assertRaises(ValueError, float, "in") + self.assertRaises(ValueError, float, "+in") + self.assertRaises(ValueError, float, "-in") + self.assertRaises(ValueError, float, "infinit") + self.assertRaises(ValueError, float, "+Infin") + self.assertRaises(ValueError, float, "-INFI") + self.assertRaises(ValueError, float, "infinitys") + + self.assertRaises(ValueError, float, "++Inf") + self.assertRaises(ValueError, float, "-+inf") + self.assertRaises(ValueError, float, "+-infinity") + self.assertRaises(ValueError, float, "--Infinity") + + def test_inf_as_str(self): + self.assertEqual(repr(1e300 * 1e300), "inf") + self.assertEqual(repr(-1e300 * 1e300), "-inf") + + self.assertEqual(str(1e300 * 1e300), "inf") + self.assertEqual(str(-1e300 * 1e300), "-inf") + + def test_nan_from_str(self): + self.assertTrue(isnan(float("nan"))) + self.assertTrue(isnan(float("+nan"))) + self.assertTrue(isnan(float("-nan"))) + + self.assertEqual(repr(float("nan")), "nan") + self.assertEqual(repr(float("+nan")), "nan") + self.assertEqual(repr(float("-nan")), "nan") + + self.assertEqual(repr(float("NAN")), "nan") + self.assertEqual(repr(float("+NAn")), "nan") + self.assertEqual(repr(float("-NaN")), "nan") + + self.assertEqual(str(float("nan")), "nan") + self.assertEqual(str(float("+nan")), "nan") + self.assertEqual(str(float("-nan")), "nan") + + self.assertRaises(ValueError, float, "nana") + self.assertRaises(ValueError, float, "+nana") + self.assertRaises(ValueError, float, "-nana") + self.assertRaises(ValueError, float, "na") + self.assertRaises(ValueError, float, "+na") + self.assertRaises(ValueError, float, "-na") + + self.assertRaises(ValueError, float, "++nan") + self.assertRaises(ValueError, float, "-+NAN") + self.assertRaises(ValueError, float, "+-NaN") + self.assertRaises(ValueError, float, "--nAn") + + def test_nan_as_str(self): + self.assertEqual(repr(1e300 * 1e300 * 0), "nan") + self.assertEqual(repr(-1e300 * 1e300 * 0), "nan") + + self.assertEqual(str(1e300 * 1e300 * 0), "nan") + self.assertEqual(str(-1e300 * 1e300 * 0), "nan") + + def test_inf_signs(self): + self.assertEqual(copysign(1.0, float('inf')), 1.0) + self.assertEqual(copysign(1.0, float('-inf')), -1.0) + + def test_nan_signs(self): + # The sign of float('nan') should be predictable. + self.assertEqual(copysign(1.0, float('nan')), 1.0) + self.assertEqual(copysign(1.0, float('-nan')), -1.0) + + +fromHex = float.fromhex +toHex = float.hex +class HexFloatTestCase(__TestCase): + MAX = fromHex('0x.fffffffffffff8p+1024') # max normal + MIN = fromHex('0x1p-1022') # min normal + TINY = fromHex('0x0.0000000000001p-1022') # min subnormal + EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up + + def assertFloatsAreIdentical(self, x, y): + """assert that floats x and y are identical, in the sense that: + (1) both x and y are nans, or + (2) both x and y are infinities, with the same sign, or + (3) both x and y are zeros, with the same sign, or + (4) x and y are both finite and nonzero, and x == y + + """ + msg = 'floats {!r} and {!r} are not identical' + + if isnan(x) or isnan(y): + if isnan(x) and isnan(y): + return + elif x == y: + if x != 0.0: + return + # both zero; check that signs match + elif copysign(1.0, x) == copysign(1.0, y): + return + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) + + def identical(self, x, y): + self.assertFloatsAreIdentical(x, y) + + def test_ends(self): + self.identical(self.MIN, ldexp(1.0, -1022)) + self.identical(self.TINY, ldexp(1.0, -1074)) + self.identical(self.EPS, ldexp(1.0, -52)) + self.identical(self.MAX, 2.*(ldexp(1.0, 1023) - ldexp(1.0, 970))) + + def test_invalid_inputs(self): + invalid_inputs = [ + 'infi', # misspelt infinities and nans + '-Infinit', + '++inf', + '-+Inf', + '--nan', + '+-NaN', + 'snan', + 'NaNs', + 'nna', + 'an', + 'nf', + 'nfinity', + 'inity', + 'iinity', + '0xnan', + '', + ' ', + 'x1.0p0', + '0xX1.0p0', + '+ 0x1.0p0', # internal whitespace + '- 0x1.0p0', + '0 x1.0p0', + '0x 1.0p0', + '0x1 2.0p0', + '+0x1 .0p0', + '0x1. 0p0', + '-0x1.0 1p0', + '-0x1.0 p0', + '+0x1.0p +0', + '0x1.0p -0', + '0x1.0p 0', + '+0x1.0p+ 0', + '-0x1.0p- 0', + '++0x1.0p-0', # double signs + '--0x1.0p0', + '+-0x1.0p+0', + '-+0x1.0p0', + '0x1.0p++0', + '+0x1.0p+-0', + '-0x1.0p-+0', + '0x1.0p--0', + '0x1.0.p0', + '0x.p0', # no hex digits before or after point + '0x1,p0', # wrong decimal point character + '0x1pa', + '0x1p\uff10', # fullwidth Unicode digits + '\uff10x1p0', + '0x\uff11p0', + '0x1.\uff10p0', + '0x1p0 \n 0x2p0', + '0x1p0\0 0x1p0', # embedded null byte is not end of string + ] + for x in invalid_inputs: + try: + result = fromHex(x) + except ValueError: + pass + else: + self.fail('Expected float.fromhex(%r) to raise ValueError; ' + 'got %r instead' % (x, result)) + + + def test_whitespace(self): + value_pairs = [ + ('inf', INF), + ('-Infinity', -INF), + ('nan', NAN), + ('1.0', 1.0), + ('-0x.2', -0.125), + ('-0.0', -0.0) + ] + whitespace = [ + '', + ' ', + '\t', + '\n', + '\n \t', + '\f', + '\v', + '\r' + ] + for inp, expected in value_pairs: + for lead in whitespace: + for trail in whitespace: + got = fromHex(lead + inp + trail) + self.identical(got, expected) + + + def test_from_hex(self): + MIN = self.MIN + MAX = self.MAX + TINY = self.TINY + EPS = self.EPS + + # two spellings of infinity, with optional signs; case-insensitive + self.identical(fromHex('inf'), INF) + self.identical(fromHex('+Inf'), INF) + self.identical(fromHex('-INF'), -INF) + self.identical(fromHex('iNf'), INF) + self.identical(fromHex('Infinity'), INF) + self.identical(fromHex('+INFINITY'), INF) + self.identical(fromHex('-infinity'), -INF) + self.identical(fromHex('-iNFiNitY'), -INF) + + # nans with optional sign; case insensitive + self.identical(fromHex('nan'), NAN) + self.identical(fromHex('+NaN'), NAN) + self.identical(fromHex('-NaN'), NAN) + self.identical(fromHex('-nAN'), NAN) + + # variations in input format + self.identical(fromHex('1'), 1.0) + self.identical(fromHex('+1'), 1.0) + self.identical(fromHex('1.'), 1.0) + self.identical(fromHex('1.0'), 1.0) + self.identical(fromHex('1.0p0'), 1.0) + self.identical(fromHex('01'), 1.0) + self.identical(fromHex('01.'), 1.0) + self.identical(fromHex('0x1'), 1.0) + self.identical(fromHex('0x1.'), 1.0) + self.identical(fromHex('0x1.0'), 1.0) + self.identical(fromHex('+0x1.0'), 1.0) + self.identical(fromHex('0x1p0'), 1.0) + self.identical(fromHex('0X1p0'), 1.0) + self.identical(fromHex('0X1P0'), 1.0) + self.identical(fromHex('0x1P0'), 1.0) + self.identical(fromHex('0x1.p0'), 1.0) + self.identical(fromHex('0x1.0p0'), 1.0) + self.identical(fromHex('0x.1p4'), 1.0) + self.identical(fromHex('0x.1p04'), 1.0) + self.identical(fromHex('0x.1p004'), 1.0) + self.identical(fromHex('0x1p+0'), 1.0) + self.identical(fromHex('0x1P-0'), 1.0) + self.identical(fromHex('+0x1p0'), 1.0) + self.identical(fromHex('0x01p0'), 1.0) + self.identical(fromHex('0x1p00'), 1.0) + self.identical(fromHex(' 0x1p0 '), 1.0) + self.identical(fromHex('\n 0x1p0'), 1.0) + self.identical(fromHex('0x1p0 \t'), 1.0) + self.identical(fromHex('0xap0'), 10.0) + self.identical(fromHex('0xAp0'), 10.0) + self.identical(fromHex('0xaP0'), 10.0) + self.identical(fromHex('0xAP0'), 10.0) + self.identical(fromHex('0xbep0'), 190.0) + self.identical(fromHex('0xBep0'), 190.0) + self.identical(fromHex('0xbEp0'), 190.0) + self.identical(fromHex('0XBE0P-4'), 190.0) + self.identical(fromHex('0xBEp0'), 190.0) + self.identical(fromHex('0xB.Ep4'), 190.0) + self.identical(fromHex('0x.BEp8'), 190.0) + self.identical(fromHex('0x.0BEp12'), 190.0) + + # moving the point around + pi = fromHex('0x1.921fb54442d18p1') + self.identical(fromHex('0x.006487ed5110b46p11'), pi) + self.identical(fromHex('0x.00c90fdaa22168cp10'), pi) + self.identical(fromHex('0x.01921fb54442d18p9'), pi) + self.identical(fromHex('0x.03243f6a8885a3p8'), pi) + self.identical(fromHex('0x.06487ed5110b46p7'), pi) + self.identical(fromHex('0x.0c90fdaa22168cp6'), pi) + self.identical(fromHex('0x.1921fb54442d18p5'), pi) + self.identical(fromHex('0x.3243f6a8885a3p4'), pi) + self.identical(fromHex('0x.6487ed5110b46p3'), pi) + self.identical(fromHex('0x.c90fdaa22168cp2'), pi) + self.identical(fromHex('0x1.921fb54442d18p1'), pi) + self.identical(fromHex('0x3.243f6a8885a3p0'), pi) + self.identical(fromHex('0x6.487ed5110b46p-1'), pi) + self.identical(fromHex('0xc.90fdaa22168cp-2'), pi) + self.identical(fromHex('0x19.21fb54442d18p-3'), pi) + self.identical(fromHex('0x32.43f6a8885a3p-4'), pi) + self.identical(fromHex('0x64.87ed5110b46p-5'), pi) + self.identical(fromHex('0xc9.0fdaa22168cp-6'), pi) + self.identical(fromHex('0x192.1fb54442d18p-7'), pi) + self.identical(fromHex('0x324.3f6a8885a3p-8'), pi) + self.identical(fromHex('0x648.7ed5110b46p-9'), pi) + self.identical(fromHex('0xc90.fdaa22168cp-10'), pi) + self.identical(fromHex('0x1921.fb54442d18p-11'), pi) + # ... + self.identical(fromHex('0x1921fb54442d1.8p-47'), pi) + self.identical(fromHex('0x3243f6a8885a3p-48'), pi) + self.identical(fromHex('0x6487ed5110b46p-49'), pi) + self.identical(fromHex('0xc90fdaa22168cp-50'), pi) + self.identical(fromHex('0x1921fb54442d18p-51'), pi) + self.identical(fromHex('0x3243f6a8885a30p-52'), pi) + self.identical(fromHex('0x6487ed5110b460p-53'), pi) + self.identical(fromHex('0xc90fdaa22168c0p-54'), pi) + self.identical(fromHex('0x1921fb54442d180p-55'), pi) + + + # results that should overflow... + self.assertRaises(OverflowError, fromHex, '-0x1p1024') + self.assertRaises(OverflowError, fromHex, '0x1p+1025') + self.assertRaises(OverflowError, fromHex, '+0X1p1030') + self.assertRaises(OverflowError, fromHex, '-0x1p+1100') + self.assertRaises(OverflowError, fromHex, '0X1p123456789123456789') + self.assertRaises(OverflowError, fromHex, '+0X.8p+1025') + self.assertRaises(OverflowError, fromHex, '+0x0.8p1025') + self.assertRaises(OverflowError, fromHex, '-0x0.4p1026') + self.assertRaises(OverflowError, fromHex, '0X2p+1023') + self.assertRaises(OverflowError, fromHex, '0x2.p1023') + self.assertRaises(OverflowError, fromHex, '-0x2.0p+1023') + self.assertRaises(OverflowError, fromHex, '+0X4p+1022') + self.assertRaises(OverflowError, fromHex, '0x1.ffffffffffffffp+1023') + self.assertRaises(OverflowError, fromHex, '-0X1.fffffffffffff9p1023') + self.assertRaises(OverflowError, fromHex, '0X1.fffffffffffff8p1023') + self.assertRaises(OverflowError, fromHex, '+0x3.fffffffffffffp1022') + self.assertRaises(OverflowError, fromHex, '0x3fffffffffffffp+970') + self.assertRaises(OverflowError, fromHex, '0x10000000000000000p960') + self.assertRaises(OverflowError, fromHex, '-0Xffffffffffffffffp960') + + # ...and those that round to +-max float + self.identical(fromHex('+0x1.fffffffffffffp+1023'), MAX) + self.identical(fromHex('-0X1.fffffffffffff7p1023'), -MAX) + self.identical(fromHex('0X1.fffffffffffff7fffffffffffffp1023'), MAX) + + # zeros + self.identical(fromHex('0x0p0'), 0.0) + self.identical(fromHex('0x0p1000'), 0.0) + self.identical(fromHex('-0x0p1023'), -0.0) + self.identical(fromHex('0X0p1024'), 0.0) + self.identical(fromHex('-0x0p1025'), -0.0) + self.identical(fromHex('0X0p2000'), 0.0) + self.identical(fromHex('0x0p123456789123456789'), 0.0) + self.identical(fromHex('-0X0p-0'), -0.0) + self.identical(fromHex('-0X0p-1000'), -0.0) + self.identical(fromHex('0x0p-1023'), 0.0) + self.identical(fromHex('-0X0p-1024'), -0.0) + self.identical(fromHex('-0x0p-1025'), -0.0) + self.identical(fromHex('-0x0p-1072'), -0.0) + self.identical(fromHex('0X0p-1073'), 0.0) + self.identical(fromHex('-0x0p-1074'), -0.0) + self.identical(fromHex('0x0p-1075'), 0.0) + self.identical(fromHex('0X0p-1076'), 0.0) + self.identical(fromHex('-0X0p-2000'), -0.0) + self.identical(fromHex('-0x0p-123456789123456789'), -0.0) + + # values that should underflow to 0 + self.identical(fromHex('0X1p-1075'), 0.0) + self.identical(fromHex('-0X1p-1075'), -0.0) + self.identical(fromHex('-0x1p-123456789123456789'), -0.0) + self.identical(fromHex('0x1.00000000000000001p-1075'), TINY) + self.identical(fromHex('-0x1.1p-1075'), -TINY) + self.identical(fromHex('0x1.fffffffffffffffffp-1075'), TINY) + + # check round-half-even is working correctly near 0 ... + self.identical(fromHex('0x1p-1076'), 0.0) + self.identical(fromHex('0X2p-1076'), 0.0) + self.identical(fromHex('0X3p-1076'), TINY) + self.identical(fromHex('0x4p-1076'), TINY) + self.identical(fromHex('0X5p-1076'), TINY) + self.identical(fromHex('0X6p-1076'), 2*TINY) + self.identical(fromHex('0x7p-1076'), 2*TINY) + self.identical(fromHex('0X8p-1076'), 2*TINY) + self.identical(fromHex('0X9p-1076'), 2*TINY) + self.identical(fromHex('0xap-1076'), 2*TINY) + self.identical(fromHex('0Xbp-1076'), 3*TINY) + self.identical(fromHex('0xcp-1076'), 3*TINY) + self.identical(fromHex('0Xdp-1076'), 3*TINY) + self.identical(fromHex('0Xep-1076'), 4*TINY) + self.identical(fromHex('0xfp-1076'), 4*TINY) + self.identical(fromHex('0x10p-1076'), 4*TINY) + self.identical(fromHex('-0x1p-1076'), -0.0) + self.identical(fromHex('-0X2p-1076'), -0.0) + self.identical(fromHex('-0x3p-1076'), -TINY) + self.identical(fromHex('-0X4p-1076'), -TINY) + self.identical(fromHex('-0x5p-1076'), -TINY) + self.identical(fromHex('-0x6p-1076'), -2*TINY) + self.identical(fromHex('-0X7p-1076'), -2*TINY) + self.identical(fromHex('-0X8p-1076'), -2*TINY) + self.identical(fromHex('-0X9p-1076'), -2*TINY) + self.identical(fromHex('-0Xap-1076'), -2*TINY) + self.identical(fromHex('-0xbp-1076'), -3*TINY) + self.identical(fromHex('-0xcp-1076'), -3*TINY) + self.identical(fromHex('-0Xdp-1076'), -3*TINY) + self.identical(fromHex('-0xep-1076'), -4*TINY) + self.identical(fromHex('-0Xfp-1076'), -4*TINY) + self.identical(fromHex('-0X10p-1076'), -4*TINY) + + # ... and near MIN ... + self.identical(fromHex('0x0.ffffffffffffd6p-1022'), MIN-3*TINY) + self.identical(fromHex('0x0.ffffffffffffd8p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffdap-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffdcp-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffdep-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffe0p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffe2p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffe4p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffe6p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffe8p-1022'), MIN-2*TINY) + self.identical(fromHex('0x0.ffffffffffffeap-1022'), MIN-TINY) + self.identical(fromHex('0x0.ffffffffffffecp-1022'), MIN-TINY) + self.identical(fromHex('0x0.ffffffffffffeep-1022'), MIN-TINY) + self.identical(fromHex('0x0.fffffffffffff0p-1022'), MIN-TINY) + self.identical(fromHex('0x0.fffffffffffff2p-1022'), MIN-TINY) + self.identical(fromHex('0x0.fffffffffffff4p-1022'), MIN-TINY) + self.identical(fromHex('0x0.fffffffffffff6p-1022'), MIN-TINY) + self.identical(fromHex('0x0.fffffffffffff8p-1022'), MIN) + self.identical(fromHex('0x0.fffffffffffffap-1022'), MIN) + self.identical(fromHex('0x0.fffffffffffffcp-1022'), MIN) + self.identical(fromHex('0x0.fffffffffffffep-1022'), MIN) + self.identical(fromHex('0x1.00000000000000p-1022'), MIN) + self.identical(fromHex('0x1.00000000000002p-1022'), MIN) + self.identical(fromHex('0x1.00000000000004p-1022'), MIN) + self.identical(fromHex('0x1.00000000000006p-1022'), MIN) + self.identical(fromHex('0x1.00000000000008p-1022'), MIN) + self.identical(fromHex('0x1.0000000000000ap-1022'), MIN+TINY) + self.identical(fromHex('0x1.0000000000000cp-1022'), MIN+TINY) + self.identical(fromHex('0x1.0000000000000ep-1022'), MIN+TINY) + self.identical(fromHex('0x1.00000000000010p-1022'), MIN+TINY) + self.identical(fromHex('0x1.00000000000012p-1022'), MIN+TINY) + self.identical(fromHex('0x1.00000000000014p-1022'), MIN+TINY) + self.identical(fromHex('0x1.00000000000016p-1022'), MIN+TINY) + self.identical(fromHex('0x1.00000000000018p-1022'), MIN+2*TINY) + + # ... and near 1.0. + self.identical(fromHex('0x0.fffffffffffff0p0'), 1.0-EPS) + self.identical(fromHex('0x0.fffffffffffff1p0'), 1.0-EPS) + self.identical(fromHex('0X0.fffffffffffff2p0'), 1.0-EPS) + self.identical(fromHex('0x0.fffffffffffff3p0'), 1.0-EPS) + self.identical(fromHex('0X0.fffffffffffff4p0'), 1.0-EPS) + self.identical(fromHex('0X0.fffffffffffff5p0'), 1.0-EPS/2) + self.identical(fromHex('0X0.fffffffffffff6p0'), 1.0-EPS/2) + self.identical(fromHex('0x0.fffffffffffff7p0'), 1.0-EPS/2) + self.identical(fromHex('0x0.fffffffffffff8p0'), 1.0-EPS/2) + self.identical(fromHex('0X0.fffffffffffff9p0'), 1.0-EPS/2) + self.identical(fromHex('0X0.fffffffffffffap0'), 1.0-EPS/2) + self.identical(fromHex('0x0.fffffffffffffbp0'), 1.0-EPS/2) + self.identical(fromHex('0X0.fffffffffffffcp0'), 1.0) + self.identical(fromHex('0x0.fffffffffffffdp0'), 1.0) + self.identical(fromHex('0X0.fffffffffffffep0'), 1.0) + self.identical(fromHex('0x0.ffffffffffffffp0'), 1.0) + self.identical(fromHex('0X1.00000000000000p0'), 1.0) + self.identical(fromHex('0X1.00000000000001p0'), 1.0) + self.identical(fromHex('0x1.00000000000002p0'), 1.0) + self.identical(fromHex('0X1.00000000000003p0'), 1.0) + self.identical(fromHex('0x1.00000000000004p0'), 1.0) + self.identical(fromHex('0X1.00000000000005p0'), 1.0) + self.identical(fromHex('0X1.00000000000006p0'), 1.0) + self.identical(fromHex('0X1.00000000000007p0'), 1.0) + self.identical(fromHex('0x1.00000000000007ffffffffffffffffffffp0'), + 1.0) + self.identical(fromHex('0x1.00000000000008p0'), 1.0) + self.identical(fromHex('0x1.00000000000008000000000000000001p0'), + 1+EPS) + self.identical(fromHex('0X1.00000000000009p0'), 1.0+EPS) + self.identical(fromHex('0x1.0000000000000ap0'), 1.0+EPS) + self.identical(fromHex('0x1.0000000000000bp0'), 1.0+EPS) + self.identical(fromHex('0X1.0000000000000cp0'), 1.0+EPS) + self.identical(fromHex('0x1.0000000000000dp0'), 1.0+EPS) + self.identical(fromHex('0x1.0000000000000ep0'), 1.0+EPS) + self.identical(fromHex('0X1.0000000000000fp0'), 1.0+EPS) + self.identical(fromHex('0x1.00000000000010p0'), 1.0+EPS) + self.identical(fromHex('0X1.00000000000011p0'), 1.0+EPS) + self.identical(fromHex('0x1.00000000000012p0'), 1.0+EPS) + self.identical(fromHex('0X1.00000000000013p0'), 1.0+EPS) + self.identical(fromHex('0X1.00000000000014p0'), 1.0+EPS) + self.identical(fromHex('0x1.00000000000015p0'), 1.0+EPS) + self.identical(fromHex('0x1.00000000000016p0'), 1.0+EPS) + self.identical(fromHex('0X1.00000000000017p0'), 1.0+EPS) + self.identical(fromHex('0x1.00000000000017ffffffffffffffffffffp0'), + 1.0+EPS) + self.identical(fromHex('0x1.00000000000018p0'), 1.0+2*EPS) + self.identical(fromHex('0X1.00000000000018000000000000000001p0'), + 1.0+2*EPS) + self.identical(fromHex('0x1.00000000000019p0'), 1.0+2*EPS) + self.identical(fromHex('0X1.0000000000001ap0'), 1.0+2*EPS) + self.identical(fromHex('0X1.0000000000001bp0'), 1.0+2*EPS) + self.identical(fromHex('0x1.0000000000001cp0'), 1.0+2*EPS) + self.identical(fromHex('0x1.0000000000001dp0'), 1.0+2*EPS) + self.identical(fromHex('0x1.0000000000001ep0'), 1.0+2*EPS) + self.identical(fromHex('0X1.0000000000001fp0'), 1.0+2*EPS) + self.identical(fromHex('0x1.00000000000020p0'), 1.0+2*EPS) + + # Regression test for a corner-case bug reported in b.p.o. 44954 + self.identical(fromHex('0x.8p-1074'), 0.0) + self.identical(fromHex('0x.80p-1074'), 0.0) + self.identical(fromHex('0x.81p-1074'), TINY) + self.identical(fromHex('0x8p-1078'), 0.0) + self.identical(fromHex('0x8.0p-1078'), 0.0) + self.identical(fromHex('0x8.1p-1078'), TINY) + self.identical(fromHex('0x80p-1082'), 0.0) + self.identical(fromHex('0x81p-1082'), TINY) + self.identical(fromHex('.8p-1074'), 0.0) + self.identical(fromHex('8p-1078'), 0.0) + self.identical(fromHex('-.8p-1074'), -0.0) + self.identical(fromHex('+8p-1078'), 0.0) + + def test_roundtrip(self): + def roundtrip(x): + return fromHex(toHex(x)) + + for x in [NAN, INF, self.MAX, self.MIN, self.MIN-self.TINY, self.TINY, 0.0]: + self.identical(x, roundtrip(x)) + self.identical(-x, roundtrip(-x)) + + # fromHex(toHex(x)) should exactly recover x, for any non-NaN float x. + import random + for i in range(10000): + e = random.randrange(-1200, 1200) + m = random.random() + s = random.choice([1.0, -1.0]) + try: + x = s*ldexp(m, e) + except OverflowError: + pass + else: + self.identical(x, fromHex(toHex(x))) + + def test_subclass(self): + class F(float): + def __new__(cls, value): + return float.__new__(cls, value + 1) + + f = F.fromhex((1.5).hex()) + self.assertIs(type(f), F) + self.assertEqual(f, 2.5) + + class F2(float): + def __init__(self, value): + self.foo = 'bar' + + f = F2.fromhex((1.5).hex()) + self.assertIs(type(f), F2) + self.assertEqual(f, 1.5) + self.assertEqual(getattr(f, 'foo', 'none'), 'bar') + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_generator_stop.diff b/test/dynamo/cpython/3_13/test_generator_stop.diff new file mode 100644 index 00000000000000..4f6450a86e5636 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_generator_stop.diff @@ -0,0 +1,74 @@ +diff --git a/test/dynamo/cpython/3_13/test_generator_stop.py b/test/dynamo/cpython/3_13/test_generator_stop.py +index bc235ceb00e..cb2a85255cb 100644 +--- a/test/dynamo/cpython/3_13/test_generator_stop.py ++++ b/test/dynamo/cpython/3_13/test_generator_stop.py +@@ -1,9 +1,60 @@ + from __future__ import generator_stop + ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import unittest + + +-class TestPEP479(unittest.TestCase): ++class TestPEP479(__TestCase): + def test_stopiteration_wrapping(self): + def f(): + raise StopIteration +@@ -30,5 +81,5 @@ class TestPEP479(unittest.TestCase): + 'were not properly set') + + +-if __name__ == '__main__': +- unittest.main() ++if __name__ == "__main__": ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_generator_stop.py b/test/dynamo/cpython/3_13/test_generator_stop.py new file mode 100644 index 00000000000000..cb2a85255cb42f --- /dev/null +++ b/test/dynamo/cpython/3_13/test_generator_stop.py @@ -0,0 +1,85 @@ +from __future__ import generator_stop + +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import unittest + + +class TestPEP479(__TestCase): + def test_stopiteration_wrapping(self): + def f(): + raise StopIteration + def g(): + yield f() + with self.assertRaisesRegex(RuntimeError, + "generator raised StopIteration"): + next(g()) + + def test_stopiteration_wrapping_context(self): + def f(): + raise StopIteration + def g(): + yield f() + + try: + next(g()) + except RuntimeError as exc: + self.assertIs(type(exc.__cause__), StopIteration) + self.assertIs(type(exc.__context__), StopIteration) + self.assertTrue(exc.__suppress_context__) + else: + self.fail('__cause__, __context__, or __suppress_context__ ' + 'were not properly set') + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_generators.diff b/test/dynamo/cpython/3_13/test_generators.diff new file mode 100644 index 00000000000000..49a2d664cf17d3 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_generators.diff @@ -0,0 +1,289 @@ +diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py +index e48d79d34f4..40a02d644a9 100644 +--- a/test/dynamo/cpython/3_13/test_generators.py ++++ b/test/dynamo/cpython/3_13/test_generators.py +@@ -1,3 +1,53 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import copy + import gc + import pickle +@@ -22,7 +72,7 @@ except ImportError: + @unittest.skipUnless(_testcapi is not None and + hasattr(_testcapi, "raise_SIGINT_then_send_None"), + "needs _testcapi.raise_SIGINT_then_send_None") +-class SignalAndYieldFromTest(unittest.TestCase): ++class SignalAndYieldFromTest(__TestCase): + + def generator1(self): + return (yield from self.generator2()) +@@ -46,7 +96,7 @@ class SignalAndYieldFromTest(unittest.TestCase): + self.assertEqual(exc.value, "PASSED") + + +-class FinalizationTest(unittest.TestCase): ++class FinalizationTest(__TestCase): + + def test_frame_resurrect(self): + # A generator frame can be resurrected by a generator's finalization. +@@ -113,7 +163,7 @@ class FinalizationTest(unittest.TestCase): + self.assertEqual(cm.exception.value, 2) + + +-class GeneratorTest(unittest.TestCase): ++class GeneratorTest(__TestCase): + + def test_name(self): + def func(): +@@ -246,8 +296,31 @@ class GeneratorTest(unittest.TestCase): + #This should not raise + loop() + ++ @unittest.expectedFailure ++ def test_genexpr_only_calls_dunder_iter_once(self): ++ ++ class Iterator: ++ ++ def __init__(self): ++ self.val = 0 ++ ++ def __next__(self): ++ if self.val == 2: ++ raise StopIteration ++ self.val += 1 ++ return self.val ++ ++ # No __iter__ method ++ ++ class C: ++ ++ def __iter__(self): ++ return Iterator() ++ ++ self.assertEqual([1,2], list(i for i in C())) ++ + +-class ModifyUnderlyingIterableTest(unittest.TestCase): ++class ModifyUnderlyingIterableTest(__TestCase): + iterables = [ + range(0), + range(20), +@@ -319,7 +392,7 @@ class ModifyUnderlyingIterableTest(unittest.TestCase): + self.process_tests(get_generator_genfunc) + + +-class ExceptionTest(unittest.TestCase): ++class ExceptionTest(__TestCase): + # Tests for the issue #23353: check that the currently handled exception + # is correctly saved/restored in PyEval_EvalFrameEx(). + +@@ -528,7 +601,7 @@ class ExceptionTest(unittest.TestCase): + self.assertEqual(cm.exception.value.value, 2) + + +-class GeneratorCloseTest(unittest.TestCase): ++class GeneratorCloseTest(__TestCase): + + def test_close_no_return_value(self): + def f(): +@@ -630,90 +703,7 @@ class GeneratorCloseTest(unittest.TestCase): + self.assertIsNone(f_wr()) + + +-# See https://github.com/python/cpython/issues/125723 +-class GeneratorDeallocTest(unittest.TestCase): +- def test_frame_outlives_generator(self): +- def g1(): +- a = 42 +- yield sys._getframe() +- +- def g2(): +- a = 42 +- yield +- +- def g3(obj): +- a = 42 +- obj.frame = sys._getframe() +- yield +- +- class ObjectWithFrame(): +- def __init__(self): +- self.frame = None +- +- def get_frame(index): +- if index == 1: +- return next(g1()) +- elif index == 2: +- gen = g2() +- next(gen) +- return gen.gi_frame +- elif index == 3: +- obj = ObjectWithFrame() +- next(g3(obj)) +- return obj.frame +- else: +- return None +- +- for index in (1, 2, 3): +- with self.subTest(index=index): +- frame = get_frame(index) +- frame_locals = frame.f_locals +- self.assertIn('a', frame_locals) +- self.assertEqual(frame_locals['a'], 42) +- +- def test_frame_locals_outlive_generator(self): +- frame_locals1 = None +- +- def g1(): +- nonlocal frame_locals1 +- frame_locals1 = sys._getframe().f_locals +- a = 42 +- yield +- +- def g2(): +- a = 42 +- yield sys._getframe().f_locals +- +- def get_frame_locals(index): +- if index == 1: +- nonlocal frame_locals1 +- next(g1()) +- return frame_locals1 +- if index == 2: +- return next(g2()) +- else: +- return None +- +- for index in (1, 2): +- with self.subTest(index=index): +- frame_locals = get_frame_locals(index) +- self.assertIn('a', frame_locals) +- self.assertEqual(frame_locals['a'], 42) +- +- def test_frame_locals_outlive_generator_with_exec(self): +- def g(): +- a = 42 +- yield locals(), sys._getframe().f_locals +- +- locals_ = {'g': g} +- for i in range(10): +- exec("snapshot, live_locals = next(g())", locals=locals_) +- for l in (locals_['snapshot'], locals_['live_locals']): +- self.assertIn('a', l) +- self.assertEqual(l['a'], 42) +- +- +-class GeneratorThrowTest(unittest.TestCase): ++class GeneratorThrowTest(__TestCase): + + def test_exception_context_with_yield(self): + def f(): +@@ -812,7 +802,7 @@ class GeneratorThrowTest(unittest.TestCase): + gen.throw(ValueError) + + +-class GeneratorStackTraceTest(unittest.TestCase): ++class GeneratorStackTraceTest(__TestCase): + + def check_stack_names(self, frame, expected): + names = [] +@@ -861,7 +851,7 @@ class GeneratorStackTraceTest(unittest.TestCase): + self.check_yield_from_example(call_throw) + + +-class YieldFromTests(unittest.TestCase): ++class YieldFromTests(__TestCase): + def test_generator_gi_yieldfrom(self): + def a(): + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) +@@ -2752,21 +2742,27 @@ test_generators just happened to be the test that drew these out. + + """ + +-__test__ = {"tut": tutorial_tests, +- "pep": pep_tests, +- "email": email_tests, +- "fun": fun_tests, +- "syntax": syntax_tests, +- "conjoin": conjoin_tests, +- "weakref": weakref_tests, +- "coroutine": coroutine_tests, +- "refleaks": refleaks_tests, +- } +- +-def load_tests(loader, tests, pattern): +- tests.addTest(doctest.DocTestSuite()) +- return tests ++# __test__ = {"tut": tutorial_tests, ++# "pep": pep_tests, ++# "email": email_tests, ++# "fun": fun_tests, ++# "syntax": syntax_tests, ++# "conjoin": conjoin_tests, ++# "weakref": weakref_tests, ++# "coroutine": coroutine_tests, ++# "refleaks": refleaks_tests, ++# } ++ ++# def load_tests(loader, tests, pattern): ++# # ======= BEGIN Dynamo patch ======= ++# suite = doctest.DocTestSuite() ++# for test in suite: ++# # Dynamically change base class ++# test.__class__ = type(test.__class__.__name__, (__TestCase, test.__class__), {}) ++# tests.addTests(suite) ++# # ======= END DYNAMO PATCH ======= ++# return tests + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_generators.py b/test/dynamo/cpython/3_13/test_generators.py new file mode 100644 index 00000000000000..40a02d644a99f2 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_generators.py @@ -0,0 +1,2768 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import copy +import gc +import pickle +import sys +import doctest +import unittest +import weakref +import inspect +import types + +from test import support + +try: + import _testcapi +except ImportError: + _testcapi = None + + +# This tests to make sure that if a SIGINT arrives just before we send into a +# yield from chain, the KeyboardInterrupt is raised in the innermost +# generator (see bpo-30039). +@unittest.skipUnless(_testcapi is not None and + hasattr(_testcapi, "raise_SIGINT_then_send_None"), + "needs _testcapi.raise_SIGINT_then_send_None") +class SignalAndYieldFromTest(__TestCase): + + def generator1(self): + return (yield from self.generator2()) + + def generator2(self): + try: + yield + except KeyboardInterrupt: + return "PASSED" + else: + return "FAILED" + + def test_raise_and_yield_from(self): + gen = self.generator1() + gen.send(None) + try: + _testcapi.raise_SIGINT_then_send_None(gen) + except BaseException as _exc: + exc = _exc + self.assertIs(type(exc), StopIteration) + self.assertEqual(exc.value, "PASSED") + + +class FinalizationTest(__TestCase): + + def test_frame_resurrect(self): + # A generator frame can be resurrected by a generator's finalization. + def gen(): + nonlocal frame + try: + yield + finally: + frame = sys._getframe() + + g = gen() + wr = weakref.ref(g) + next(g) + del g + support.gc_collect() + self.assertIs(wr(), None) + self.assertTrue(frame) + del frame + support.gc_collect() + + def test_refcycle(self): + # A generator caught in a refcycle gets finalized anyway. + old_garbage = gc.garbage[:] + finalized = False + def gen(): + nonlocal finalized + try: + g = yield + yield 1 + finally: + finalized = True + + g = gen() + next(g) + g.send(g) + self.assertGreater(sys.getrefcount(g), 2) + self.assertFalse(finalized) + del g + support.gc_collect() + self.assertTrue(finalized) + self.assertEqual(gc.garbage, old_garbage) + + def test_lambda_generator(self): + # bpo-23192, gh-119897: Test that a lambda returning a generator behaves + # like the equivalent function + f = lambda: (yield 1) + self.assertIsInstance(f(), types.GeneratorType) + self.assertEqual(next(f()), 1) + + def g(): return (yield 1) + + # test 'yield from' + f2 = lambda: (yield from g()) + def g2(): return (yield from g()) + + f3 = lambda: (yield from f()) + def g3(): return (yield from f()) + + for gen_fun in (f, g, f2, g2, f3, g3): + gen = gen_fun() + self.assertEqual(next(gen), 1) + with self.assertRaises(StopIteration) as cm: + gen.send(2) + self.assertEqual(cm.exception.value, 2) + + +class GeneratorTest(__TestCase): + + def test_name(self): + def func(): + yield 1 + + # check generator names + gen = func() + self.assertEqual(gen.__name__, "func") + self.assertEqual(gen.__qualname__, + "GeneratorTest.test_name..func") + + # modify generator names + gen.__name__ = "name" + gen.__qualname__ = "qualname" + self.assertEqual(gen.__name__, "name") + self.assertEqual(gen.__qualname__, "qualname") + + # generator names must be a string and cannot be deleted + self.assertRaises(TypeError, setattr, gen, '__name__', 123) + self.assertRaises(TypeError, setattr, gen, '__qualname__', 123) + self.assertRaises(TypeError, delattr, gen, '__name__') + self.assertRaises(TypeError, delattr, gen, '__qualname__') + + # modify names of the function creating the generator + func.__qualname__ = "func_qualname" + func.__name__ = "func_name" + gen = func() + self.assertEqual(gen.__name__, "func_name") + self.assertEqual(gen.__qualname__, "func_qualname") + + # unnamed generator + gen = (x for x in range(10)) + self.assertEqual(gen.__name__, + "") + self.assertEqual(gen.__qualname__, + "GeneratorTest.test_name..") + + def test_copy(self): + def f(): + yield 1 + g = f() + with self.assertRaises(TypeError): + copy.copy(g) + + def test_pickle(self): + def f(): + yield 1 + g = f() + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((TypeError, pickle.PicklingError)): + pickle.dumps(g, proto) + + def test_send_non_none_to_new_gen(self): + def f(): + yield 1 + g = f() + with self.assertRaises(TypeError): + g.send(0) + self.assertEqual(next(g), 1) + + def test_handle_frame_object_in_creation(self): + + #Attempt to expose partially constructed frames + #See https://github.com/python/cpython/issues/94262 + + def cb(*args): + inspect.stack() + + def gen(): + yield 1 + + thresholds = gc.get_threshold() + + gc.callbacks.append(cb) + gc.set_threshold(1, 0, 0) + try: + gen() + finally: + gc.set_threshold(*thresholds) + gc.callbacks.pop() + + class Sneaky: + def __del__(self): + inspect.stack() + + sneaky = Sneaky() + sneaky._s = Sneaky() + sneaky._s._s = sneaky + + gc.set_threshold(1, 0, 0) + try: + del sneaky + gen() + finally: + gc.set_threshold(*thresholds) + + def test_ag_frame_f_back(self): + async def f(): + yield + ag = f() + self.assertIsNone(ag.ag_frame.f_back) + + def test_cr_frame_f_back(self): + async def f(): + pass + cr = f() + self.assertIsNone(cr.cr_frame.f_back) + cr.close() # Suppress RuntimeWarning. + + def test_gi_frame_f_back(self): + def f(): + yield + gi = f() + self.assertIsNone(gi.gi_frame.f_back) + + def test_issue103488(self): + + def gen_raises(): + yield + raise ValueError() + + def loop(): + try: + for _ in gen_raises(): + if True is False: + return + except ValueError: + pass + + #This should not raise + loop() + + @unittest.expectedFailure + def test_genexpr_only_calls_dunder_iter_once(self): + + class Iterator: + + def __init__(self): + self.val = 0 + + def __next__(self): + if self.val == 2: + raise StopIteration + self.val += 1 + return self.val + + # No __iter__ method + + class C: + + def __iter__(self): + return Iterator() + + self.assertEqual([1,2], list(i for i in C())) + + +class ModifyUnderlyingIterableTest(__TestCase): + iterables = [ + range(0), + range(20), + [1, 2, 3], + (2,), + {13, 48, 211}, + frozenset((15, 8, 6)), + {1: 2, 3: 4}, + ] + + non_iterables = [ + None, + 42, + 3.0, + 2j, + ] + + def genexpr(self): + return (x for x in range(10)) + + def genfunc(self): + def gen(it): + for x in it: + yield x + return gen(range(10)) + + def process_tests(self, get_generator): + for obj in self.iterables: + g_obj = get_generator(obj) + with self.subTest(g_obj=g_obj, obj=obj): + self.assertListEqual(list(g_obj), list(obj)) + + g_iter = get_generator(iter(obj)) + with self.subTest(g_iter=g_iter, obj=obj): + self.assertListEqual(list(g_iter), list(obj)) + + err_regex = "'.*' object is not iterable" + for obj in self.non_iterables: + g_obj = get_generator(obj) + with self.subTest(g_obj=g_obj): + self.assertRaisesRegex(TypeError, err_regex, list, g_obj) + + def test_modify_f_locals(self): + def modify_f_locals(g, local, obj): + g.gi_frame.f_locals[local] = obj + return g + + def get_generator_genexpr(obj): + return modify_f_locals(self.genexpr(), '.0', obj) + + def get_generator_genfunc(obj): + return modify_f_locals(self.genfunc(), 'it', obj) + + self.process_tests(get_generator_genexpr) + self.process_tests(get_generator_genfunc) + + def test_new_gen_from_gi_code(self): + def new_gen_from_gi_code(g, obj): + generator_func = types.FunctionType(g.gi_code, {}) + return generator_func(obj) + + def get_generator_genexpr(obj): + return new_gen_from_gi_code(self.genexpr(), obj) + + def get_generator_genfunc(obj): + return new_gen_from_gi_code(self.genfunc(), obj) + + self.process_tests(get_generator_genexpr) + self.process_tests(get_generator_genfunc) + + +class ExceptionTest(__TestCase): + # Tests for the issue #23353: check that the currently handled exception + # is correctly saved/restored in PyEval_EvalFrameEx(). + + def test_except_throw(self): + def store_raise_exc_generator(): + try: + self.assertIsNone(sys.exception()) + yield + except Exception as exc: + # exception raised by gen.throw(exc) + self.assertIsInstance(sys.exception(), ValueError) + self.assertIsNone(exc.__context__) + yield + + # ensure that the exception is not lost + self.assertIsInstance(sys.exception(), ValueError) + yield + + # we should be able to raise back the ValueError + raise + + make = store_raise_exc_generator() + next(make) + + try: + raise ValueError() + except Exception as exc: + try: + make.throw(exc) + except Exception: + pass + + next(make) + with self.assertRaises(ValueError) as cm: + next(make) + self.assertIsNone(cm.exception.__context__) + + self.assertIsNone(sys.exception()) + + def test_except_next(self): + def gen(): + self.assertIsInstance(sys.exception(), ValueError) + yield "done" + + g = gen() + try: + raise ValueError + except Exception: + self.assertEqual(next(g), "done") + self.assertIsNone(sys.exception()) + + def test_except_gen_except(self): + def gen(): + try: + self.assertIsNone(sys.exception()) + yield + # we are called from "except ValueError:", TypeError must + # inherit ValueError in its context + raise TypeError() + except TypeError as exc: + self.assertIsInstance(sys.exception(), TypeError) + self.assertEqual(type(exc.__context__), ValueError) + # here we are still called from the "except ValueError:" + self.assertIsInstance(sys.exception(), ValueError) + yield + self.assertIsNone(sys.exception()) + yield "done" + + g = gen() + next(g) + try: + raise ValueError + except Exception: + next(g) + + self.assertEqual(next(g), "done") + self.assertIsNone(sys.exception()) + + def test_nested_gen_except_loop(self): + def gen(): + for i in range(100): + self.assertIsInstance(sys.exception(), TypeError) + yield "doing" + + def outer(): + try: + raise TypeError + except: + for x in gen(): + yield x + + try: + raise ValueError + except Exception: + for x in outer(): + self.assertEqual(x, "doing") + self.assertEqual(sys.exception(), None) + + def test_except_throw_exception_context(self): + def gen(): + try: + try: + self.assertIsNone(sys.exception()) + yield + except ValueError: + # we are called from "except ValueError:" + self.assertIsInstance(sys.exception(), ValueError) + raise TypeError() + except Exception as exc: + self.assertIsInstance(sys.exception(), TypeError) + self.assertEqual(type(exc.__context__), ValueError) + # we are still called from "except ValueError:" + self.assertIsInstance(sys.exception(), ValueError) + yield + self.assertIsNone(sys.exception()) + yield "done" + + g = gen() + next(g) + try: + raise ValueError + except Exception as exc: + g.throw(exc) + + self.assertEqual(next(g), "done") + self.assertIsNone(sys.exception()) + + def test_except_throw_bad_exception(self): + class E(Exception): + def __new__(cls, *args, **kwargs): + return cls + + def boring_generator(): + yield + + gen = boring_generator() + + err_msg = 'should have returned an instance of BaseException' + + with self.assertRaisesRegex(TypeError, err_msg): + gen.throw(E) + + self.assertRaises(StopIteration, next, gen) + + def generator(): + with self.assertRaisesRegex(TypeError, err_msg): + yield + + gen = generator() + next(gen) + with self.assertRaises(StopIteration): + gen.throw(E) + + def test_gen_3_arg_deprecation_warning(self): + def g(): + yield 42 + + gen = g() + with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): + gen.throw(TypeError, TypeError(24), None) + + def test_stopiteration_error(self): + # See also PEP 479. + + def gen(): + raise StopIteration + yield + + with self.assertRaisesRegex(RuntimeError, 'raised StopIteration'): + next(gen()) + + def test_tutorial_stopiteration(self): + # Raise StopIteration" stops the generator too: + + def f(): + yield 1 + raise StopIteration + yield 2 # never reached + + g = f() + self.assertEqual(next(g), 1) + + with self.assertRaisesRegex(RuntimeError, 'raised StopIteration'): + next(g) + + def test_return_tuple(self): + def g(): + return (yield 1) + + gen = g() + self.assertEqual(next(gen), 1) + with self.assertRaises(StopIteration) as cm: + gen.send((2,)) + self.assertEqual(cm.exception.value, (2,)) + + def test_return_stopiteration(self): + def g(): + return (yield 1) + + gen = g() + self.assertEqual(next(gen), 1) + with self.assertRaises(StopIteration) as cm: + gen.send(StopIteration(2)) + self.assertIsInstance(cm.exception.value, StopIteration) + self.assertEqual(cm.exception.value.value, 2) + + +class GeneratorCloseTest(__TestCase): + + def test_close_no_return_value(self): + def f(): + yield + + gen = f() + gen.send(None) + self.assertIsNone(gen.close()) + + def test_close_return_value(self): + def f(): + try: + yield + # close() raises GeneratorExit here, which is caught + except GeneratorExit: + return 0 + + gen = f() + gen.send(None) + self.assertEqual(gen.close(), 0) + + def test_close_not_catching_exit(self): + def f(): + yield + # close() raises GeneratorExit here, which isn't caught and + # therefore propagates -- no return value + return 0 + + gen = f() + gen.send(None) + self.assertIsNone(gen.close()) + + def test_close_not_started(self): + def f(): + try: + yield + except GeneratorExit: + return 0 + + gen = f() + self.assertIsNone(gen.close()) + + def test_close_exhausted(self): + def f(): + try: + yield + except GeneratorExit: + return 0 + + gen = f() + next(gen) + with self.assertRaises(StopIteration): + next(gen) + self.assertIsNone(gen.close()) + + def test_close_closed(self): + def f(): + try: + yield + except GeneratorExit: + return 0 + + gen = f() + gen.send(None) + self.assertEqual(gen.close(), 0) + self.assertIsNone(gen.close()) + + def test_close_raises(self): + def f(): + try: + yield + except GeneratorExit: + pass + raise RuntimeError + + gen = f() + gen.send(None) + with self.assertRaises(RuntimeError): + gen.close() + + def test_close_releases_frame_locals(self): + # See gh-118272 + + class Foo: + pass + + f = Foo() + f_wr = weakref.ref(f) + + def genfn(): + a = f + yield + + g = genfn() + next(g) + del f + g.close() + support.gc_collect() + self.assertIsNone(f_wr()) + + +class GeneratorThrowTest(__TestCase): + + def test_exception_context_with_yield(self): + def f(): + try: + raise KeyError('a') + except Exception: + yield + + gen = f() + gen.send(None) + with self.assertRaises(ValueError) as cm: + gen.throw(ValueError) + context = cm.exception.__context__ + self.assertEqual((type(context), context.args), (KeyError, ('a',))) + + def test_exception_context_with_yield_inside_generator(self): + # Check that the context is also available from inside the generator + # with yield, as opposed to outside. + def f(): + try: + raise KeyError('a') + except Exception: + try: + yield + except Exception as exc: + self.assertEqual(type(exc), ValueError) + context = exc.__context__ + self.assertEqual((type(context), context.args), + (KeyError, ('a',))) + yield 'b' + + gen = f() + gen.send(None) + actual = gen.throw(ValueError) + # This ensures that the assertions inside were executed. + self.assertEqual(actual, 'b') + + def test_exception_context_with_yield_from(self): + def f(): + yield + + def g(): + try: + raise KeyError('a') + except Exception: + yield from f() + + gen = g() + gen.send(None) + with self.assertRaises(ValueError) as cm: + gen.throw(ValueError) + context = cm.exception.__context__ + self.assertEqual((type(context), context.args), (KeyError, ('a',))) + + def test_exception_context_with_yield_from_with_context_cycle(self): + # Check trying to create an exception context cycle: + # https://bugs.python.org/issue40696 + has_cycle = None + + def f(): + yield + + def g(exc): + nonlocal has_cycle + try: + raise exc + except Exception: + try: + yield from f() + except Exception as exc: + has_cycle = (exc is exc.__context__) + yield + + exc = KeyError('a') + gen = g(exc) + gen.send(None) + gen.throw(exc) + # This also distinguishes from the initial has_cycle=None. + self.assertEqual(has_cycle, False) + + def test_throw_after_none_exc_type(self): + def g(): + try: + raise KeyError + except KeyError: + pass + + try: + yield + except Exception: + raise RuntimeError + + gen = g() + gen.send(None) + with self.assertRaises(RuntimeError) as cm: + gen.throw(ValueError) + + +class GeneratorStackTraceTest(__TestCase): + + def check_stack_names(self, frame, expected): + names = [] + while frame: + name = frame.f_code.co_name + # Stop checking frames when we get to our test helper. + if name.startswith('check_') or name.startswith('call_'): + break + + names.append(name) + frame = frame.f_back + + self.assertEqual(names, expected) + + def check_yield_from_example(self, call_method): + def f(): + self.check_stack_names(sys._getframe(), ['f', 'g']) + try: + yield + except Exception: + pass + self.check_stack_names(sys._getframe(), ['f', 'g']) + + def g(): + self.check_stack_names(sys._getframe(), ['g']) + yield from f() + self.check_stack_names(sys._getframe(), ['g']) + + gen = g() + gen.send(None) + try: + call_method(gen) + except StopIteration: + pass + + def test_send_with_yield_from(self): + def call_send(gen): + gen.send(None) + + self.check_yield_from_example(call_send) + + def test_throw_with_yield_from(self): + def call_throw(gen): + gen.throw(RuntimeError) + + self.check_yield_from_example(call_throw) + + +class YieldFromTests(__TestCase): + def test_generator_gi_yieldfrom(self): + def a(): + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) + self.assertIsNone(gen_b.gi_yieldfrom) + yield + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) + self.assertIsNone(gen_b.gi_yieldfrom) + + def b(): + self.assertIsNone(gen_b.gi_yieldfrom) + yield from a() + self.assertIsNone(gen_b.gi_yieldfrom) + yield + self.assertIsNone(gen_b.gi_yieldfrom) + + gen_b = b() + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_CREATED) + self.assertIsNone(gen_b.gi_yieldfrom) + + gen_b.send(None) + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_SUSPENDED) + self.assertEqual(gen_b.gi_yieldfrom.gi_code.co_name, 'a') + + gen_b.send(None) + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_SUSPENDED) + self.assertIsNone(gen_b.gi_yieldfrom) + + [] = gen_b # Exhaust generator + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_CLOSED) + self.assertIsNone(gen_b.gi_yieldfrom) + + +tutorial_tests = """ +Let's try a simple generator: + + >>> def f(): + ... yield 1 + ... yield 2 + + >>> for i in f(): + ... print(i) + 1 + 2 + >>> g = f() + >>> next(g) + 1 + >>> next(g) + 2 + +"Falling off the end" stops the generator: + + >>> next(g) + Traceback (most recent call last): + File "", line 1, in ? + File "", line 2, in g + StopIteration + +"return" also stops the generator: + + >>> def f(): + ... yield 1 + ... return + ... yield 2 # never reached + ... + >>> g = f() + >>> next(g) + 1 + >>> next(g) + Traceback (most recent call last): + File "", line 1, in ? + File "", line 3, in f + StopIteration + >>> next(g) # once stopped, can't be resumed + Traceback (most recent call last): + File "", line 1, in ? + StopIteration + +However, "return" and StopIteration are not exactly equivalent: + + >>> def g1(): + ... try: + ... return + ... except: + ... yield 1 + ... + >>> list(g1()) + [] + + >>> def g2(): + ... try: + ... raise StopIteration + ... except: + ... yield 42 + >>> print(list(g2())) + [42] + +This may be surprising at first: + + >>> def g3(): + ... try: + ... return + ... finally: + ... yield 1 + ... + >>> list(g3()) + [1] + +Let's create an alternate range() function implemented as a generator: + + >>> def yrange(n): + ... for i in range(n): + ... yield i + ... + >>> list(yrange(5)) + [0, 1, 2, 3, 4] + +Generators always return to the most recent caller: + + >>> def creator(): + ... r = yrange(5) + ... print("creator", next(r)) + ... return r + ... + >>> def caller(): + ... r = creator() + ... for i in r: + ... print("caller", i) + ... + >>> caller() + creator 0 + caller 1 + caller 2 + caller 3 + caller 4 + +Generators can call other generators: + + >>> def zrange(n): + ... for i in yrange(n): + ... yield i + ... + >>> list(zrange(5)) + [0, 1, 2, 3, 4] + +""" + +# The examples from PEP 255. + +pep_tests = """ + +Specification: Yield + + Restriction: A generator cannot be resumed while it is actively + running: + + >>> def g(): + ... i = next(me) + ... yield i + >>> me = g() + >>> next(me) + Traceback (most recent call last): + ... + File "", line 2, in g + ValueError: generator already executing + +Specification: Return + + Note that return isn't always equivalent to raising StopIteration: the + difference lies in how enclosing try/except constructs are treated. + For example, + + >>> def f1(): + ... try: + ... return + ... except: + ... yield 1 + >>> print(list(f1())) + [] + + because, as in any function, return simply exits, but + + >>> def f2(): + ... try: + ... raise StopIteration + ... except: + ... yield 42 + >>> print(list(f2())) + [42] + + because StopIteration is captured by a bare "except", as is any + exception. + +Specification: Generators and Exception Propagation + + >>> def f(): + ... return 1//0 + >>> def g(): + ... yield f() # the zero division exception propagates + ... yield 42 # and we'll never get here + >>> k = g() + >>> next(k) + Traceback (most recent call last): + File "", line 1, in ? + File "", line 2, in g + File "", line 2, in f + ZeroDivisionError: integer division or modulo by zero + >>> next(k) # and the generator cannot be resumed + Traceback (most recent call last): + File "", line 1, in ? + StopIteration + >>> + +Specification: Try/Except/Finally + + >>> def f(): + ... try: + ... yield 1 + ... try: + ... yield 2 + ... 1//0 + ... yield 3 # never get here + ... except ZeroDivisionError: + ... yield 4 + ... yield 5 + ... raise + ... except: + ... yield 6 + ... yield 7 # the "raise" above stops this + ... except: + ... yield 8 + ... yield 9 + ... try: + ... x = 12 + ... finally: + ... yield 10 + ... yield 11 + >>> print(list(f())) + [1, 2, 4, 5, 8, 9, 10, 11] + >>> + +Guido's binary tree example. + + >>> # A binary tree class. + >>> class Tree: + ... + ... def __init__(self, label, left=None, right=None): + ... self.label = label + ... self.left = left + ... self.right = right + ... + ... def __repr__(self, level=0, indent=" "): + ... s = level*indent + repr(self.label) + ... if self.left: + ... s = s + "\\n" + self.left.__repr__(level+1, indent) + ... if self.right: + ... s = s + "\\n" + self.right.__repr__(level+1, indent) + ... return s + ... + ... def __iter__(self): + ... return inorder(self) + + >>> # Create a Tree from a list. + >>> def tree(list): + ... n = len(list) + ... if n == 0: + ... return [] + ... i = n // 2 + ... return Tree(list[i], tree(list[:i]), tree(list[i+1:])) + + >>> # Show it off: create a tree. + >>> t = tree("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + + >>> # A recursive generator that generates Tree labels in in-order. + >>> def inorder(t): + ... if t: + ... for x in inorder(t.left): + ... yield x + ... yield t.label + ... for x in inorder(t.right): + ... yield x + + >>> # Show it off: create a tree. + >>> t = tree("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + >>> # Print the nodes of the tree in in-order. + >>> for x in t: + ... print(' '+x, end='') + A B C D E F G H I J K L M N O P Q R S T U V W X Y Z + + >>> # A non-recursive generator. + >>> def inorder(node): + ... stack = [] + ... while node: + ... while node.left: + ... stack.append(node) + ... node = node.left + ... yield node.label + ... while not node.right: + ... try: + ... node = stack.pop() + ... except IndexError: + ... return + ... yield node.label + ... node = node.right + + >>> # Exercise the non-recursive generator. + >>> for x in t: + ... print(' '+x, end='') + A B C D E F G H I J K L M N O P Q R S T U V W X Y Z + +""" + +# Examples from Iterator-List and Python-Dev and c.l.py. + +email_tests = """ + +The difference between yielding None and returning it. + +>>> def g(): +... for i in range(3): +... yield None +... yield None +... return +>>> list(g()) +[None, None, None, None] + +Ensure that explicitly raising StopIteration acts like any other exception +in try/except, not like a return. + +>>> def g(): +... yield 1 +... try: +... raise StopIteration +... except: +... yield 2 +... yield 3 +>>> list(g()) +[1, 2, 3] + +Next one was posted to c.l.py. + +>>> def gcomb(x, k): +... "Generate all combinations of k elements from list x." +... +... if k > len(x): +... return +... if k == 0: +... yield [] +... else: +... first, rest = x[0], x[1:] +... # A combination does or doesn't contain first. +... # If it does, the remainder is a k-1 comb of rest. +... for c in gcomb(rest, k-1): +... c.insert(0, first) +... yield c +... # If it doesn't contain first, it's a k comb of rest. +... for c in gcomb(rest, k): +... yield c + +>>> seq = list(range(1, 5)) +>>> for k in range(len(seq) + 2): +... print("%d-combs of %s:" % (k, seq)) +... for c in gcomb(seq, k): +... print(" ", c) +0-combs of [1, 2, 3, 4]: + [] +1-combs of [1, 2, 3, 4]: + [1] + [2] + [3] + [4] +2-combs of [1, 2, 3, 4]: + [1, 2] + [1, 3] + [1, 4] + [2, 3] + [2, 4] + [3, 4] +3-combs of [1, 2, 3, 4]: + [1, 2, 3] + [1, 2, 4] + [1, 3, 4] + [2, 3, 4] +4-combs of [1, 2, 3, 4]: + [1, 2, 3, 4] +5-combs of [1, 2, 3, 4]: + +From the Iterators list, about the types of these things. + +>>> def g(): +... yield 1 +... +>>> type(g) + +>>> i = g() +>>> type(i) + +>>> [s for s in dir(i) if not s.startswith('_')] +['close', 'gi_code', 'gi_frame', 'gi_running', 'gi_suspended', 'gi_yieldfrom', 'send', 'throw'] +>>> from test.support import HAVE_DOCSTRINGS +>>> print(i.__next__.__doc__ if HAVE_DOCSTRINGS else 'Implement next(self).') +Implement next(self). +>>> iter(i) is i +True +>>> import types +>>> isinstance(i, types.GeneratorType) +True + +And more, added later. + +>>> i.gi_running +0 +>>> type(i.gi_frame) + +>>> i.gi_running = 42 +Traceback (most recent call last): + ... +AttributeError: attribute 'gi_running' of 'generator' objects is not writable +>>> def g(): +... yield me.gi_running +>>> me = g() +>>> me.gi_running +0 +>>> next(me) +1 +>>> me.gi_running +0 + +A clever union-find implementation from c.l.py, due to David Eppstein. +Sent: Friday, June 29, 2001 12:16 PM +To: python-list@python.org +Subject: Re: PEP 255: Simple Generators + +>>> class disjointSet: +... def __init__(self, name): +... self.name = name +... self.parent = None +... self.generator = self.generate() +... +... def generate(self): +... while not self.parent: +... yield self +... for x in self.parent.generator: +... yield x +... +... def find(self): +... return next(self.generator) +... +... def union(self, parent): +... if self.parent: +... raise ValueError("Sorry, I'm not a root!") +... self.parent = parent +... +... def __str__(self): +... return self.name + +>>> names = "ABCDEFGHIJKLM" +>>> sets = [disjointSet(name) for name in names] +>>> roots = sets[:] + +>>> import random +>>> gen = random.Random(42) +>>> while 1: +... for s in sets: +... print(" %s->%s" % (s, s.find()), end='') +... print() +... if len(roots) > 1: +... s1 = gen.choice(roots) +... roots.remove(s1) +... s2 = gen.choice(roots) +... s1.union(s2) +... print("merged", s1, "into", s2) +... else: +... break + A->A B->B C->C D->D E->E F->F G->G H->H I->I J->J K->K L->L M->M +merged K into B + A->A B->B C->C D->D E->E F->F G->G H->H I->I J->J K->B L->L M->M +merged A into F + A->F B->B C->C D->D E->E F->F G->G H->H I->I J->J K->B L->L M->M +merged E into F + A->F B->B C->C D->D E->F F->F G->G H->H I->I J->J K->B L->L M->M +merged D into C + A->F B->B C->C D->C E->F F->F G->G H->H I->I J->J K->B L->L M->M +merged M into C + A->F B->B C->C D->C E->F F->F G->G H->H I->I J->J K->B L->L M->C +merged J into B + A->F B->B C->C D->C E->F F->F G->G H->H I->I J->B K->B L->L M->C +merged B into C + A->F B->C C->C D->C E->F F->F G->G H->H I->I J->C K->C L->L M->C +merged F into G + A->G B->C C->C D->C E->G F->G G->G H->H I->I J->C K->C L->L M->C +merged L into C + A->G B->C C->C D->C E->G F->G G->G H->H I->I J->C K->C L->C M->C +merged G into I + A->I B->C C->C D->C E->I F->I G->I H->H I->I J->C K->C L->C M->C +merged I into H + A->H B->C C->C D->C E->H F->H G->H H->H I->H J->C K->C L->C M->C +merged C into H + A->H B->H C->H D->H E->H F->H G->H H->H I->H J->H K->H L->H M->H + +""" +# Emacs turd ' + +# Fun tests (for sufficiently warped notions of "fun"). + +fun_tests = """ + +Build up to a recursive Sieve of Eratosthenes generator. + +>>> def firstn(g, n): +... return [next(g) for i in range(n)] + +>>> def intsfrom(i): +... while 1: +... yield i +... i += 1 + +>>> firstn(intsfrom(5), 7) +[5, 6, 7, 8, 9, 10, 11] + +>>> def exclude_multiples(n, ints): +... for i in ints: +... if i % n: +... yield i + +>>> firstn(exclude_multiples(3, intsfrom(1)), 6) +[1, 2, 4, 5, 7, 8] + +>>> def sieve(ints): +... prime = next(ints) +... yield prime +... not_divisible_by_prime = exclude_multiples(prime, ints) +... for p in sieve(not_divisible_by_prime): +... yield p + +>>> primes = sieve(intsfrom(2)) +>>> firstn(primes, 20) +[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71] + + +Another famous problem: generate all integers of the form + 2**i * 3**j * 5**k +in increasing order, where i,j,k >= 0. Trickier than it may look at first! +Try writing it without generators, and correctly, and without generating +3 internal results for each result output. + +>>> def times(n, g): +... for i in g: +... yield n * i +>>> firstn(times(10, intsfrom(1)), 10) +[10, 20, 30, 40, 50, 60, 70, 80, 90, 100] + +>>> def merge(g, h): +... ng = next(g) +... nh = next(h) +... while 1: +... if ng < nh: +... yield ng +... ng = next(g) +... elif ng > nh: +... yield nh +... nh = next(h) +... else: +... yield ng +... ng = next(g) +... nh = next(h) + +The following works, but is doing a whale of a lot of redundant work -- +it's not clear how to get the internal uses of m235 to share a single +generator. Note that me_times2 (etc) each need to see every element in the +result sequence. So this is an example where lazy lists are more natural +(you can look at the head of a lazy list any number of times). + +>>> def m235(): +... yield 1 +... me_times2 = times(2, m235()) +... me_times3 = times(3, m235()) +... me_times5 = times(5, m235()) +... for i in merge(merge(me_times2, +... me_times3), +... me_times5): +... yield i + +Don't print "too many" of these -- the implementation above is extremely +inefficient: each call of m235() leads to 3 recursive calls, and in +turn each of those 3 more, and so on, and so on, until we've descended +enough levels to satisfy the print stmts. Very odd: when I printed 5 +lines of results below, this managed to screw up Win98's malloc in "the +usual" way, i.e. the heap grew over 4Mb so Win98 started fragmenting +address space, and it *looked* like a very slow leak. + +>>> result = m235() +>>> for i in range(3): +... print(firstn(result, 15)) +[1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24] +[25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80] +[81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192] + +Heh. Here's one way to get a shared list, complete with an excruciating +namespace renaming trick. The *pretty* part is that the times() and merge() +functions can be reused as-is, because they only assume their stream +arguments are iterable -- a LazyList is the same as a generator to times(). + +>>> class LazyList: +... def __init__(self, g): +... self.sofar = [] +... self.fetch = g.__next__ +... +... def __getitem__(self, i): +... sofar, fetch = self.sofar, self.fetch +... while i >= len(sofar): +... sofar.append(fetch()) +... return sofar[i] + +>>> def m235(): +... yield 1 +... # Gack: m235 below actually refers to a LazyList. +... me_times2 = times(2, m235) +... me_times3 = times(3, m235) +... me_times5 = times(5, m235) +... for i in merge(merge(me_times2, +... me_times3), +... me_times5): +... yield i + +Print as many of these as you like -- *this* implementation is memory- +efficient. + +>>> m235 = LazyList(m235()) +>>> for i in range(5): +... print([m235[j] for j in range(15*i, 15*(i+1))]) +[1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24] +[25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80] +[81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192] +[200, 216, 225, 240, 243, 250, 256, 270, 288, 300, 320, 324, 360, 375, 384] +[400, 405, 432, 450, 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675] + +Ye olde Fibonacci generator, LazyList style. + +>>> def fibgen(a, b): +... +... def sum(g, h): +... while 1: +... yield next(g) + next(h) +... +... def tail(g): +... next(g) # throw first away +... for x in g: +... yield x +... +... yield a +... yield b +... for s in sum(iter(fib), +... tail(iter(fib))): +... yield s + +>>> fib = LazyList(fibgen(1, 2)) +>>> firstn(iter(fib), 17) +[1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584] + + +Running after your tail with itertools.tee (new in version 2.4) + +The algorithms "m235" (Hamming) and Fibonacci presented above are both +examples of a whole family of FP (functional programming) algorithms +where a function produces and returns a list while the production algorithm +suppose the list as already produced by recursively calling itself. +For these algorithms to work, they must: + +- produce at least a first element without presupposing the existence of + the rest of the list +- produce their elements in a lazy manner + +To work efficiently, the beginning of the list must not be recomputed over +and over again. This is ensured in most FP languages as a built-in feature. +In python, we have to explicitly maintain a list of already computed results +and abandon genuine recursivity. + +This is what had been attempted above with the LazyList class. One problem +with that class is that it keeps a list of all of the generated results and +therefore continually grows. This partially defeats the goal of the generator +concept, viz. produce the results only as needed instead of producing them +all and thereby wasting memory. + +Thanks to itertools.tee, it is now clear "how to get the internal uses of +m235 to share a single generator". + +>>> from itertools import tee +>>> def m235(): +... def _m235(): +... yield 1 +... for n in merge(times(2, m2), +... merge(times(3, m3), +... times(5, m5))): +... yield n +... m1 = _m235() +... m2, m3, m5, mRes = tee(m1, 4) +... return mRes + +>>> it = m235() +>>> for i in range(5): +... print(firstn(it, 15)) +[1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24] +[25, 27, 30, 32, 36, 40, 45, 48, 50, 54, 60, 64, 72, 75, 80] +[81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192] +[200, 216, 225, 240, 243, 250, 256, 270, 288, 300, 320, 324, 360, 375, 384] +[400, 405, 432, 450, 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675] + +The "tee" function does just what we want. It internally keeps a generated +result for as long as it has not been "consumed" from all of the duplicated +iterators, whereupon it is deleted. You can therefore print the hamming +sequence during hours without increasing memory usage, or very little. + +The beauty of it is that recursive running-after-their-tail FP algorithms +are quite straightforwardly expressed with this Python idiom. + +Ye olde Fibonacci generator, tee style. + +>>> def fib(): +... +... def _isum(g, h): +... while 1: +... yield next(g) + next(h) +... +... def _fib(): +... yield 1 +... yield 2 +... next(fibTail) # throw first away +... for res in _isum(fibHead, fibTail): +... yield res +... +... realfib = _fib() +... fibHead, fibTail, fibRes = tee(realfib, 3) +... return fibRes + +>>> firstn(fib(), 17) +[1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584] + +""" + +# syntax_tests mostly provokes SyntaxErrors. Also fiddling with #if 0 +# hackery. + +syntax_tests = """ + +These are fine: + +>>> def f(): +... yield 1 +... return + +>>> def f(): +... try: +... yield 1 +... finally: +... pass + +>>> def f(): +... try: +... try: +... 1//0 +... except ZeroDivisionError: +... yield 666 +... except: +... pass +... finally: +... pass + +>>> def f(): +... try: +... try: +... yield 12 +... 1//0 +... except ZeroDivisionError: +... yield 666 +... except: +... try: +... x = 12 +... finally: +... yield 12 +... except: +... return +>>> list(f()) +[12, 666] + +>>> def f(): +... yield +>>> type(f()) + + + +>>> def f(): +... if 0: +... yield +>>> type(f()) + + + +>>> def f(): +... if 0: +... yield 1 +>>> type(f()) + + +>>> def f(): +... if "": +... yield None +>>> type(f()) + + +>>> def f(): +... return +... try: +... if x==4: +... pass +... elif 0: +... try: +... 1//0 +... except SyntaxError: +... pass +... else: +... if 0: +... while 12: +... x += 1 +... yield 2 # don't blink +... f(a, b, c, d, e) +... else: +... pass +... except: +... x = 1 +... return +>>> type(f()) + + +>>> def f(): +... if 0: +... def g(): +... yield 1 +... +>>> type(f()) + + +>>> def f(): +... if 0: +... class C: +... def __init__(self): +... yield 1 +... def f(self): +... yield 2 +>>> type(f()) + + +>>> def f(): +... if 0: +... return +... if 0: +... yield 2 +>>> type(f()) + + +This one caused a crash (see SF bug 567538): + +>>> def f(): +... for i in range(3): +... try: +... continue +... finally: +... yield i +... +>>> g = f() +>>> print(next(g)) +0 +>>> print(next(g)) +1 +>>> print(next(g)) +2 +>>> print(next(g)) +Traceback (most recent call last): +StopIteration + + +Test the gi_code attribute + +>>> def f(): +... yield 5 +... +>>> g = f() +>>> g.gi_code is f.__code__ +True +>>> next(g) +5 +>>> next(g) +Traceback (most recent call last): +StopIteration +>>> g.gi_code is f.__code__ +True + + +Test the __name__ attribute and the repr() + +>>> def f(): +... yield 5 +... +>>> g = f() +>>> g.__name__ +'f' +>>> repr(g) # doctest: +ELLIPSIS +'' + +Lambdas shouldn't have their usual return behavior. + +>>> x = lambda: (yield 1) +>>> list(x()) +[1] + +>>> x = lambda: ((yield 1), (yield 2)) +>>> list(x()) +[1, 2] +""" + +# conjoin is a simple backtracking generator, named in honor of Icon's +# "conjunction" control structure. Pass a list of no-argument functions +# that return iterable objects. Easiest to explain by example: assume the +# function list [x, y, z] is passed. Then conjoin acts like: +# +# def g(): +# values = [None] * 3 +# for values[0] in x(): +# for values[1] in y(): +# for values[2] in z(): +# yield values +# +# So some 3-lists of values *may* be generated, each time we successfully +# get into the innermost loop. If an iterator fails (is exhausted) before +# then, it "backtracks" to get the next value from the nearest enclosing +# iterator (the one "to the left"), and starts all over again at the next +# slot (pumps a fresh iterator). Of course this is most useful when the +# iterators have side-effects, so that which values *can* be generated at +# each slot depend on the values iterated at previous slots. + +def simple_conjoin(gs): + + values = [None] * len(gs) + + def gen(i): + if i >= len(gs): + yield values + else: + for values[i] in gs[i](): + for x in gen(i+1): + yield x + + for x in gen(0): + yield x + +# That works fine, but recursing a level and checking i against len(gs) for +# each item produced is inefficient. By doing manual loop unrolling across +# generator boundaries, it's possible to eliminate most of that overhead. +# This isn't worth the bother *in general* for generators, but conjoin() is +# a core building block for some CPU-intensive generator applications. + +def conjoin(gs): + + n = len(gs) + values = [None] * n + + # Do one loop nest at time recursively, until the # of loop nests + # remaining is divisible by 3. + + def gen(i): + if i >= n: + yield values + + elif (n-i) % 3: + ip1 = i+1 + for values[i] in gs[i](): + for x in gen(ip1): + yield x + + else: + for x in _gen3(i): + yield x + + # Do three loop nests at a time, recursing only if at least three more + # remain. Don't call directly: this is an internal optimization for + # gen's use. + + def _gen3(i): + assert i < n and (n-i) % 3 == 0 + ip1, ip2, ip3 = i+1, i+2, i+3 + g, g1, g2 = gs[i : ip3] + + if ip3 >= n: + # These are the last three, so we can yield values directly. + for values[i] in g(): + for values[ip1] in g1(): + for values[ip2] in g2(): + yield values + + else: + # At least 6 loop nests remain; peel off 3 and recurse for the + # rest. + for values[i] in g(): + for values[ip1] in g1(): + for values[ip2] in g2(): + for x in _gen3(ip3): + yield x + + for x in gen(0): + yield x + +# And one more approach: For backtracking apps like the Knight's Tour +# solver below, the number of backtracking levels can be enormous (one +# level per square, for the Knight's Tour, so that e.g. a 100x100 board +# needs 10,000 levels). In such cases Python is likely to run out of +# stack space due to recursion. So here's a recursion-free version of +# conjoin too. +# NOTE WELL: This allows large problems to be solved with only trivial +# demands on stack space. Without explicitly resumable generators, this is +# much harder to achieve. OTOH, this is much slower (up to a factor of 2) +# than the fancy unrolled recursive conjoin. + +def flat_conjoin(gs): # rename to conjoin to run tests with this instead + n = len(gs) + values = [None] * n + iters = [None] * n + _StopIteration = StopIteration # make local because caught a *lot* + i = 0 + while 1: + # Descend. + try: + while i < n: + it = iters[i] = gs[i]().__next__ + values[i] = it() + i += 1 + except _StopIteration: + pass + else: + assert i == n + yield values + + # Backtrack until an older iterator can be resumed. + i -= 1 + while i >= 0: + try: + values[i] = iters[i]() + # Success! Start fresh at next level. + i += 1 + break + except _StopIteration: + # Continue backtracking. + i -= 1 + else: + assert i < 0 + break + +# A conjoin-based N-Queens solver. + +class Queens: + def __init__(self, n): + self.n = n + rangen = range(n) + + # Assign a unique int to each column and diagonal. + # columns: n of those, range(n). + # NW-SE diagonals: 2n-1 of these, i-j unique and invariant along + # each, smallest i-j is 0-(n-1) = 1-n, so add n-1 to shift to 0- + # based. + # NE-SW diagonals: 2n-1 of these, i+j unique and invariant along + # each, smallest i+j is 0, largest is 2n-2. + + # For each square, compute a bit vector of the columns and + # diagonals it covers, and for each row compute a function that + # generates the possibilities for the columns in that row. + self.rowgenerators = [] + for i in rangen: + rowuses = [(1 << j) | # column ordinal + (1 << (n + i-j + n-1)) | # NW-SE ordinal + (1 << (n + 2*n-1 + i+j)) # NE-SW ordinal + for j in rangen] + + def rowgen(rowuses=rowuses): + for j in rangen: + uses = rowuses[j] + if uses & self.used == 0: + self.used |= uses + yield j + self.used &= ~uses + + self.rowgenerators.append(rowgen) + + # Generate solutions. + def solve(self): + self.used = 0 + for row2col in conjoin(self.rowgenerators): + yield row2col + + def printsolution(self, row2col): + n = self.n + assert n == len(row2col) + sep = "+" + "-+" * n + print(sep) + for i in range(n): + squares = [" " for j in range(n)] + squares[row2col[i]] = "Q" + print("|" + "|".join(squares) + "|") + print(sep) + +# A conjoin-based Knight's Tour solver. This is pretty sophisticated +# (e.g., when used with flat_conjoin above, and passing hard=1 to the +# constructor, a 200x200 Knight's Tour was found quickly -- note that we're +# creating 10s of thousands of generators then!), and is lengthy. + +class Knights: + def __init__(self, m, n, hard=0): + self.m, self.n = m, n + + # solve() will set up succs[i] to be a list of square #i's + # successors. + succs = self.succs = [] + + # Remove i0 from each of its successor's successor lists, i.e. + # successors can't go back to i0 again. Return 0 if we can + # detect this makes a solution impossible, else return 1. + + def remove_from_successors(i0, len=len): + # If we remove all exits from a free square, we're dead: + # even if we move to it next, we can't leave it again. + # If we create a square with one exit, we must visit it next; + # else somebody else will have to visit it, and since there's + # only one adjacent, there won't be a way to leave it again. + # Finally, if we create more than one free square with a + # single exit, we can only move to one of them next, leaving + # the other one a dead end. + ne0 = ne1 = 0 + for i in succs[i0]: + s = succs[i] + s.remove(i0) + e = len(s) + if e == 0: + ne0 += 1 + elif e == 1: + ne1 += 1 + return ne0 == 0 and ne1 < 2 + + # Put i0 back in each of its successor's successor lists. + + def add_to_successors(i0): + for i in succs[i0]: + succs[i].append(i0) + + # Generate the first move. + def first(): + if m < 1 or n < 1: + return + + # Since we're looking for a cycle, it doesn't matter where we + # start. Starting in a corner makes the 2nd move easy. + corner = self.coords2index(0, 0) + remove_from_successors(corner) + self.lastij = corner + yield corner + add_to_successors(corner) + + # Generate the second moves. + def second(): + corner = self.coords2index(0, 0) + assert self.lastij == corner # i.e., we started in the corner + if m < 3 or n < 3: + return + assert len(succs[corner]) == 2 + assert self.coords2index(1, 2) in succs[corner] + assert self.coords2index(2, 1) in succs[corner] + # Only two choices. Whichever we pick, the other must be the + # square picked on move m*n, as it's the only way to get back + # to (0, 0). Save its index in self.final so that moves before + # the last know it must be kept free. + for i, j in (1, 2), (2, 1): + this = self.coords2index(i, j) + final = self.coords2index(3-i, 3-j) + self.final = final + + remove_from_successors(this) + succs[final].append(corner) + self.lastij = this + yield this + succs[final].remove(corner) + add_to_successors(this) + + # Generate moves 3 through m*n-1. + def advance(len=len): + # If some successor has only one exit, must take it. + # Else favor successors with fewer exits. + candidates = [] + for i in succs[self.lastij]: + e = len(succs[i]) + assert e > 0, "else remove_from_successors() pruning flawed" + if e == 1: + candidates = [(e, i)] + break + candidates.append((e, i)) + else: + candidates.sort() + + for e, i in candidates: + if i != self.final: + if remove_from_successors(i): + self.lastij = i + yield i + add_to_successors(i) + + # Generate moves 3 through m*n-1. Alternative version using a + # stronger (but more expensive) heuristic to order successors. + # Since the # of backtracking levels is m*n, a poor move early on + # can take eons to undo. Smallest square board for which this + # matters a lot is 52x52. + def advance_hard(vmid=(m-1)/2.0, hmid=(n-1)/2.0, len=len): + # If some successor has only one exit, must take it. + # Else favor successors with fewer exits. + # Break ties via max distance from board centerpoint (favor + # corners and edges whenever possible). + candidates = [] + for i in succs[self.lastij]: + e = len(succs[i]) + assert e > 0, "else remove_from_successors() pruning flawed" + if e == 1: + candidates = [(e, 0, i)] + break + i1, j1 = self.index2coords(i) + d = (i1 - vmid)**2 + (j1 - hmid)**2 + candidates.append((e, -d, i)) + else: + candidates.sort() + + for e, d, i in candidates: + if i != self.final: + if remove_from_successors(i): + self.lastij = i + yield i + add_to_successors(i) + + # Generate the last move. + def last(): + assert self.final in succs[self.lastij] + yield self.final + + if m*n < 4: + self.squaregenerators = [first] + else: + self.squaregenerators = [first, second] + \ + [hard and advance_hard or advance] * (m*n - 3) + \ + [last] + + def coords2index(self, i, j): + assert 0 <= i < self.m + assert 0 <= j < self.n + return i * self.n + j + + def index2coords(self, index): + assert 0 <= index < self.m * self.n + return divmod(index, self.n) + + def _init_board(self): + succs = self.succs + del succs[:] + m, n = self.m, self.n + c2i = self.coords2index + + offsets = [( 1, 2), ( 2, 1), ( 2, -1), ( 1, -2), + (-1, -2), (-2, -1), (-2, 1), (-1, 2)] + rangen = range(n) + for i in range(m): + for j in rangen: + s = [c2i(i+io, j+jo) for io, jo in offsets + if 0 <= i+io < m and + 0 <= j+jo < n] + succs.append(s) + + # Generate solutions. + def solve(self): + self._init_board() + for x in conjoin(self.squaregenerators): + yield x + + def printsolution(self, x): + m, n = self.m, self.n + assert len(x) == m*n + w = len(str(m*n)) + format = "%" + str(w) + "d" + + squares = [[None] * n for i in range(m)] + k = 1 + for i in x: + i1, j1 = self.index2coords(i) + squares[i1][j1] = format % k + k += 1 + + sep = "+" + ("-" * w + "+") * n + print(sep) + for i in range(m): + row = squares[i] + print("|" + "|".join(row) + "|") + print(sep) + +conjoin_tests = """ + +Generate the 3-bit binary numbers in order. This illustrates dumbest- +possible use of conjoin, just to generate the full cross-product. + +>>> for c in conjoin([lambda: iter((0, 1))] * 3): +... print(c) +[0, 0, 0] +[0, 0, 1] +[0, 1, 0] +[0, 1, 1] +[1, 0, 0] +[1, 0, 1] +[1, 1, 0] +[1, 1, 1] + +For efficiency in typical backtracking apps, conjoin() yields the same list +object each time. So if you want to save away a full account of its +generated sequence, you need to copy its results. + +>>> def gencopy(iterator): +... for x in iterator: +... yield x[:] + +>>> for n in range(10): +... all = list(gencopy(conjoin([lambda: iter((0, 1))] * n))) +... print(n, len(all), all[0] == [0] * n, all[-1] == [1] * n) +0 1 True True +1 2 True True +2 4 True True +3 8 True True +4 16 True True +5 32 True True +6 64 True True +7 128 True True +8 256 True True +9 512 True True + +And run an 8-queens solver. + +>>> q = Queens(8) +>>> LIMIT = 2 +>>> count = 0 +>>> for row2col in q.solve(): +... count += 1 +... if count <= LIMIT: +... print("Solution", count) +... q.printsolution(row2col) +Solution 1 ++-+-+-+-+-+-+-+-+ +|Q| | | | | | | | ++-+-+-+-+-+-+-+-+ +| | | | |Q| | | | ++-+-+-+-+-+-+-+-+ +| | | | | | | |Q| ++-+-+-+-+-+-+-+-+ +| | | | | |Q| | | ++-+-+-+-+-+-+-+-+ +| | |Q| | | | | | ++-+-+-+-+-+-+-+-+ +| | | | | | |Q| | ++-+-+-+-+-+-+-+-+ +| |Q| | | | | | | ++-+-+-+-+-+-+-+-+ +| | | |Q| | | | | ++-+-+-+-+-+-+-+-+ +Solution 2 ++-+-+-+-+-+-+-+-+ +|Q| | | | | | | | ++-+-+-+-+-+-+-+-+ +| | | | | |Q| | | ++-+-+-+-+-+-+-+-+ +| | | | | | | |Q| ++-+-+-+-+-+-+-+-+ +| | |Q| | | | | | ++-+-+-+-+-+-+-+-+ +| | | | | | |Q| | ++-+-+-+-+-+-+-+-+ +| | | |Q| | | | | ++-+-+-+-+-+-+-+-+ +| |Q| | | | | | | ++-+-+-+-+-+-+-+-+ +| | | | |Q| | | | ++-+-+-+-+-+-+-+-+ + +>>> print(count, "solutions in all.") +92 solutions in all. + +And run a Knight's Tour on a 10x10 board. Note that there are about +20,000 solutions even on a 6x6 board, so don't dare run this to exhaustion. + +>>> k = Knights(10, 10) +>>> LIMIT = 2 +>>> count = 0 +>>> for x in k.solve(): +... count += 1 +... if count <= LIMIT: +... print("Solution", count) +... k.printsolution(x) +... else: +... break +Solution 1 ++---+---+---+---+---+---+---+---+---+---+ +| 1| 58| 27| 34| 3| 40| 29| 10| 5| 8| ++---+---+---+---+---+---+---+---+---+---+ +| 26| 35| 2| 57| 28| 33| 4| 7| 30| 11| ++---+---+---+---+---+---+---+---+---+---+ +| 59|100| 73| 36| 41| 56| 39| 32| 9| 6| ++---+---+---+---+---+---+---+---+---+---+ +| 74| 25| 60| 55| 72| 37| 42| 49| 12| 31| ++---+---+---+---+---+---+---+---+---+---+ +| 61| 86| 99| 76| 63| 52| 47| 38| 43| 50| ++---+---+---+---+---+---+---+---+---+---+ +| 24| 75| 62| 85| 54| 71| 64| 51| 48| 13| ++---+---+---+---+---+---+---+---+---+---+ +| 87| 98| 91| 80| 77| 84| 53| 46| 65| 44| ++---+---+---+---+---+---+---+---+---+---+ +| 90| 23| 88| 95| 70| 79| 68| 83| 14| 17| ++---+---+---+---+---+---+---+---+---+---+ +| 97| 92| 21| 78| 81| 94| 19| 16| 45| 66| ++---+---+---+---+---+---+---+---+---+---+ +| 22| 89| 96| 93| 20| 69| 82| 67| 18| 15| ++---+---+---+---+---+---+---+---+---+---+ +Solution 2 ++---+---+---+---+---+---+---+---+---+---+ +| 1| 58| 27| 34| 3| 40| 29| 10| 5| 8| ++---+---+---+---+---+---+---+---+---+---+ +| 26| 35| 2| 57| 28| 33| 4| 7| 30| 11| ++---+---+---+---+---+---+---+---+---+---+ +| 59|100| 73| 36| 41| 56| 39| 32| 9| 6| ++---+---+---+---+---+---+---+---+---+---+ +| 74| 25| 60| 55| 72| 37| 42| 49| 12| 31| ++---+---+---+---+---+---+---+---+---+---+ +| 61| 86| 99| 76| 63| 52| 47| 38| 43| 50| ++---+---+---+---+---+---+---+---+---+---+ +| 24| 75| 62| 85| 54| 71| 64| 51| 48| 13| ++---+---+---+---+---+---+---+---+---+---+ +| 87| 98| 89| 80| 77| 84| 53| 46| 65| 44| ++---+---+---+---+---+---+---+---+---+---+ +| 90| 23| 92| 95| 70| 79| 68| 83| 14| 17| ++---+---+---+---+---+---+---+---+---+---+ +| 97| 88| 21| 78| 81| 94| 19| 16| 45| 66| ++---+---+---+---+---+---+---+---+---+---+ +| 22| 91| 96| 93| 20| 69| 82| 67| 18| 15| ++---+---+---+---+---+---+---+---+---+---+ +""" + +weakref_tests = """\ +Generators are weakly referencable: + +>>> import weakref +>>> def gen(): +... yield 'foo!' +... +>>> wr = weakref.ref(gen) +>>> wr() is gen +True +>>> p = weakref.proxy(gen) + +Generator-iterators are weakly referencable as well: + +>>> gi = gen() +>>> wr = weakref.ref(gi) +>>> wr() is gi +True +>>> p = weakref.proxy(gi) +>>> list(p) +['foo!'] + +""" + +coroutine_tests = """\ +>>> from test.support import gc_collect + +Sending a value into a started generator: + +>>> def f(): +... print((yield 1)) +... yield 2 +>>> g = f() +>>> next(g) +1 +>>> g.send(42) +42 +2 + +Sending a value into a new generator produces a TypeError: + +>>> f().send("foo") +Traceback (most recent call last): +... +TypeError: can't send non-None value to a just-started generator + + +Yield by itself yields None: + +>>> def f(): yield +>>> list(f()) +[None] + + +Yield is allowed only in the outermost iterable in generator expression: + +>>> def f(): list(i for i in [(yield 26)]) +>>> type(f()) + + + +A yield expression with augmented assignment. + +>>> def coroutine(seq): +... count = 0 +... while count < 200: +... count += yield +... seq.append(count) +>>> seq = [] +>>> c = coroutine(seq) +>>> next(c) +>>> print(seq) +[] +>>> c.send(10) +>>> print(seq) +[10] +>>> c.send(10) +>>> print(seq) +[10, 20] +>>> c.send(10) +>>> print(seq) +[10, 20, 30] + + +Check some syntax errors for yield expressions: + +>>> f=lambda: (yield 1),(yield 2) +Traceback (most recent call last): + ... +SyntaxError: 'yield' outside function + +>>> f=lambda: (yield from (1,2)), (yield from (3,4)) +Traceback (most recent call last): + ... +SyntaxError: 'yield from' outside function + +>>> yield from [1,2] +Traceback (most recent call last): + ... +SyntaxError: 'yield from' outside function + +>>> def f(): x = yield = y +Traceback (most recent call last): + ... +SyntaxError: assignment to yield expression not possible + +>>> def f(): (yield bar) = y +Traceback (most recent call last): + ... +SyntaxError: cannot assign to yield expression here. Maybe you meant '==' instead of '='? + +>>> def f(): (yield bar) += y +Traceback (most recent call last): + ... +SyntaxError: 'yield expression' is an illegal expression for augmented assignment + + +Now check some throw() conditions: + +>>> def f(): +... while True: +... try: +... print((yield)) +... except ValueError as v: +... print("caught ValueError (%s)" % (v)) +>>> import sys +>>> g = f() +>>> next(g) + +>>> g.throw(ValueError) # type only +caught ValueError () + +>>> g.throw(ValueError("xyz")) # value only +caught ValueError (xyz) + +>>> import warnings +>>> old_filters = warnings.filters.copy() +>>> warnings.filterwarnings("ignore", category=DeprecationWarning) + +# Filter DeprecationWarning: regarding the (type, val, tb) signature of throw(). +# Deprecation warnings are re-enabled below. + +>>> g.throw(ValueError, ValueError(1)) # value+matching type +caught ValueError (1) + +>>> g.throw(ValueError, TypeError(1)) # mismatched type, rewrapped +caught ValueError (1) + +>>> g.throw(ValueError, ValueError(1), None) # explicit None traceback +caught ValueError (1) + +>>> g.throw(ValueError(1), "foo") # bad args +Traceback (most recent call last): + ... +TypeError: instance exception may not have a separate value + +>>> g.throw(ValueError, "foo", 23) # bad args +Traceback (most recent call last): + ... +TypeError: throw() third argument must be a traceback object + +>>> g.throw("abc") +Traceback (most recent call last): + ... +TypeError: exceptions must be classes or instances deriving from BaseException, not str + +>>> g.throw(0) +Traceback (most recent call last): + ... +TypeError: exceptions must be classes or instances deriving from BaseException, not int + +>>> g.throw(list) +Traceback (most recent call last): + ... +TypeError: exceptions must be classes or instances deriving from BaseException, not type + +>>> def throw(g,exc): +... try: +... raise exc +... except: +... g.throw(*sys.exc_info()) +>>> throw(g,ValueError) # do it with traceback included +caught ValueError () + +>>> g.send(1) +1 + +>>> throw(g,TypeError) # terminate the generator +Traceback (most recent call last): + ... +TypeError + +>>> print(g.gi_frame) +None + +>>> g.send(2) +Traceback (most recent call last): + ... +StopIteration + +>>> g.throw(ValueError,6) # throw on closed generator +Traceback (most recent call last): + ... +ValueError: 6 + +>>> f().throw(ValueError,7) # throw on just-opened generator +Traceback (most recent call last): + ... +ValueError: 7 + +>>> warnings.filters[:] = old_filters + +# Re-enable DeprecationWarning: the (type, val, tb) exception representation is deprecated, +# and may be removed in a future version of Python. + +Plain "raise" inside a generator should preserve the traceback (#13188). +The traceback should have 3 levels: +- g.throw() +- f() +- 1/0 + +>>> def f(): +... try: +... yield +... except: +... raise +>>> g = f() +>>> try: +... 1/0 +... except ZeroDivisionError as v: +... try: +... g.throw(v) +... except Exception as w: +... tb = w.__traceback__ +>>> levels = 0 +>>> while tb: +... levels += 1 +... tb = tb.tb_next +>>> levels +3 + +Now let's try closing a generator: + +>>> def f(): +... try: yield +... except GeneratorExit: +... print("exiting") + +>>> g = f() +>>> next(g) +>>> g.close() +exiting +>>> g.close() # should be no-op now + +>>> f().close() # close on just-opened generator should be fine + +>>> def f(): yield # an even simpler generator +>>> f().close() # close before opening +>>> g = f() +>>> next(g) +>>> g.close() # close normally + +And finalization: + +>>> def f(): +... try: yield +... finally: +... print("exiting") + +>>> g = f() +>>> next(g) +>>> del g; gc_collect() # For PyPy or other GCs. +exiting + + +GeneratorExit is not caught by except Exception: + +>>> def f(): +... try: yield +... except Exception: +... print('except') +... finally: +... print('finally') + +>>> g = f() +>>> next(g) +>>> del g; gc_collect() # For PyPy or other GCs. +finally + + +Now let's try some ill-behaved generators: + +>>> def f(): +... try: yield +... except GeneratorExit: +... yield "foo!" +>>> g = f() +>>> next(g) +>>> g.close() +Traceback (most recent call last): + ... +RuntimeError: generator ignored GeneratorExit +>>> g.close() + + +Our ill-behaved code should be invoked during GC: + +>>> with support.catch_unraisable_exception() as cm: +... g = f() +... next(g) +... del g +... +... cm.unraisable.exc_type == RuntimeError +... "generator ignored GeneratorExit" in str(cm.unraisable.exc_value) +... cm.unraisable.exc_traceback is not None +True +True +True + +And errors thrown during closing should propagate: + +>>> def f(): +... try: yield +... except GeneratorExit: +... raise TypeError("fie!") +>>> g = f() +>>> next(g) +>>> g.close() +Traceback (most recent call last): + ... +TypeError: fie! + + +Ensure that various yield expression constructs make their +enclosing function a generator: + +>>> def f(): x += yield +>>> type(f()) + + +>>> def f(): x = yield +>>> type(f()) + + +>>> def f(): lambda x=(yield): 1 +>>> type(f()) + + +>>> def f(d): d[(yield "a")] = d[(yield "b")] = 27 +>>> data = [1,2] +>>> g = f(data) +>>> type(g) + +>>> g.send(None) +'a' +>>> data +[1, 2] +>>> g.send(0) +'b' +>>> data +[27, 2] +>>> try: g.send(1) +... except StopIteration: pass +>>> data +[27, 27] + +""" + +refleaks_tests = """ +Prior to adding cycle-GC support to itertools.tee, this code would leak +references. We add it to the standard suite so the routine refleak-tests +would trigger if it starts being uncleanable again. + +>>> import itertools +>>> def leak(): +... class gen: +... def __iter__(self): +... return self +... def __next__(self): +... return self.item +... g = gen() +... head, tail = itertools.tee(g) +... g.item = head +... return head +>>> it = leak() + +Make sure to also test the involvement of the tee-internal teedataobject, +which stores returned items. + +>>> item = next(it) + + + +This test leaked at one point due to generator finalization/destruction. +It was copied from Lib/test/leakers/test_generator_cycle.py before the file +was removed. + +>>> def leak(): +... def gen(): +... while True: +... yield g +... g = gen() + +>>> leak() + + + +This test isn't really generator related, but rather exception-in-cleanup +related. The coroutine tests (above) just happen to cause an exception in +the generator's __del__ (tp_del) method. We can also test for this +explicitly, without generators. We do have to redirect stderr to avoid +printing warnings and to doublecheck that we actually tested what we wanted +to test. + +>>> from test import support +>>> class Leaker: +... def __del__(self): +... def invoke(message): +... raise RuntimeError(message) +... invoke("del failed") +... +>>> with support.catch_unraisable_exception() as cm: +... l = Leaker() +... del l +... +... cm.unraisable.object == Leaker.__del__ +... cm.unraisable.exc_type == RuntimeError +... str(cm.unraisable.exc_value) == "del failed" +... cm.unraisable.exc_traceback is not None +True +True +True +True + + +These refleak tests should perhaps be in a testfile of their own, +test_generators just happened to be the test that drew these out. + +""" + +# __test__ = {"tut": tutorial_tests, +# "pep": pep_tests, +# "email": email_tests, +# "fun": fun_tests, +# "syntax": syntax_tests, +# "conjoin": conjoin_tests, +# "weakref": weakref_tests, +# "coroutine": coroutine_tests, +# "refleaks": refleaks_tests, +# } + +# def load_tests(loader, tests, pattern): +# # ======= BEGIN Dynamo patch ======= +# suite = doctest.DocTestSuite() +# for test in suite: +# # Dynamically change base class +# test.__class__ = type(test.__class__.__name__, (__TestCase, test.__class__), {}) +# tests.addTests(suite) +# # ======= END DYNAMO PATCH ======= +# return tests + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_int.diff b/test/dynamo/cpython/3_13/test_int.diff new file mode 100644 index 00000000000000..257cbbd3768665 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_int.diff @@ -0,0 +1,187 @@ +diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py +index 48825f46911..4ab200372ea 100644 +--- a/test/dynamo/cpython/3_13/test_int.py ++++ b/test/dynamo/cpython/3_13/test_int.py +@@ -1,13 +1,137 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import sys + import time + + import unittest + from unittest import mock + from test import support +-from test.support.numbers import ( +- VALID_UNDERSCORE_LITERALS, +- INVALID_UNDERSCORE_LITERALS, +-) ++ ++VALID_UNDERSCORE_LITERALS = [ ++ '0_0_0', ++ '4_2', ++ '1_0000_0000', ++ '0b1001_0100', ++ '0xffff_ffff', ++ '0o5_7_7', ++ '1_00_00.5', ++ '1_00_00.5e5', ++ '1_00_00e5_1', ++ '1e1_0', ++ '.1_4', ++ '.1_4e1', ++ '0b_0', ++ '0x_f', ++ '0o_5', ++ '1_00_00j', ++ '1_00_00.5j', ++ '1_00_00e5_1j', ++ '.1_4j', ++ '(1_2.5+3_3j)', ++ '(.5_6j)', ++] ++INVALID_UNDERSCORE_LITERALS = [ ++ # Trailing underscores: ++ '0_', ++ '42_', ++ '1.4j_', ++ '0x_', ++ '0b1_', ++ '0xf_', ++ '0o5_', ++ '0 if 1_Else 1', ++ # Underscores in the base selector: ++ '0_b0', ++ '0_xf', ++ '0_o5', ++ # Old-style octal, still disallowed: ++ '0_7', ++ '09_99', ++ # Multiple consecutive underscores: ++ '4_______2', ++ '0.1__4', ++ '0.1__4j', ++ '0b1001__0100', ++ '0xffff__ffff', ++ '0x___', ++ '0o5__77', ++ '1e1__0', ++ '1e1__0j', ++ # Underscore right before a dot: ++ '1_.4', ++ '1_.4j', ++ # Underscore right after a dot: ++ '1._4', ++ '1._4j', ++ '._5', ++ '._5j', ++ # Underscore right after a sign: ++ '1.0e+_1', ++ '1.0e+_1j', ++ # Underscore right before j: ++ '1.4_j', ++ '1.4e5_j', ++ # Underscore right before e: ++ '1_e1', ++ '1.4_e1', ++ '1.4_e1j', ++ # Underscore right after e: ++ '1e_1', ++ '1.4e_1', ++ '1.4e_1j', ++ # Complex cases with parens: ++ '(1+1.5_j_)', ++ '(1+1.5_j)', ++] + + try: + import _pylong +@@ -38,7 +162,7 @@ L = [ + class IntSubclass(int): + pass + +-class IntTestCases(unittest.TestCase): ++class IntTestCases(__TestCase): + + def test_basic(self): + self.assertEqual(int(314), 314) +@@ -566,6 +690,7 @@ class IntTestCases(unittest.TestCase): + self.assertEqual(n, 1) + self.assertIs(type(n), IntSubclass) + ++ @skipIfTorchDynamo("flaky under dynamo") + def test_error_message(self): + def check(s, base=None): + with self.assertRaises(ValueError, +@@ -607,7 +732,7 @@ class IntTestCases(unittest.TestCase): + self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807) + + +-class IntStrDigitLimitsTests(unittest.TestCase): ++class IntStrDigitLimitsTests(__TestCase): + + int_class = int # Override this in subclasses to reuse the suite. + +@@ -818,7 +943,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests): + int_class = IntSubclass + + +-class PyLongModuleTests(unittest.TestCase): ++class PyLongModuleTests(__TestCase): + # Tests of the functions in _pylong.py. Those get used when the + # number of digits in the input values are large enough. + +@@ -922,4 +1047,4 @@ class PyLongModuleTests(unittest.TestCase): + bits <<= 1 + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py new file mode 100644 index 00000000000000..4ab200372ea222 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_int.py @@ -0,0 +1,1050 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import sys +import time + +import unittest +from unittest import mock +from test import support + +VALID_UNDERSCORE_LITERALS = [ + '0_0_0', + '4_2', + '1_0000_0000', + '0b1001_0100', + '0xffff_ffff', + '0o5_7_7', + '1_00_00.5', + '1_00_00.5e5', + '1_00_00e5_1', + '1e1_0', + '.1_4', + '.1_4e1', + '0b_0', + '0x_f', + '0o_5', + '1_00_00j', + '1_00_00.5j', + '1_00_00e5_1j', + '.1_4j', + '(1_2.5+3_3j)', + '(.5_6j)', +] +INVALID_UNDERSCORE_LITERALS = [ + # Trailing underscores: + '0_', + '42_', + '1.4j_', + '0x_', + '0b1_', + '0xf_', + '0o5_', + '0 if 1_Else 1', + # Underscores in the base selector: + '0_b0', + '0_xf', + '0_o5', + # Old-style octal, still disallowed: + '0_7', + '09_99', + # Multiple consecutive underscores: + '4_______2', + '0.1__4', + '0.1__4j', + '0b1001__0100', + '0xffff__ffff', + '0x___', + '0o5__77', + '1e1__0', + '1e1__0j', + # Underscore right before a dot: + '1_.4', + '1_.4j', + # Underscore right after a dot: + '1._4', + '1._4j', + '._5', + '._5j', + # Underscore right after a sign: + '1.0e+_1', + '1.0e+_1j', + # Underscore right before j: + '1.4_j', + '1.4e5_j', + # Underscore right before e: + '1_e1', + '1.4_e1', + '1.4_e1j', + # Underscore right after e: + '1e_1', + '1.4e_1', + '1.4e_1j', + # Complex cases with parens: + '(1+1.5_j_)', + '(1+1.5_j)', +] + +try: + import _pylong +except ImportError: + _pylong = None + +L = [ + ('0', 0), + ('1', 1), + ('9', 9), + ('10', 10), + ('99', 99), + ('100', 100), + ('314', 314), + (' 314', 314), + ('314 ', 314), + (' \t\t 314 \t\t ', 314), + (repr(sys.maxsize), sys.maxsize), + (' 1x', ValueError), + (' 1 ', 1), + (' 1\02 ', ValueError), + ('', ValueError), + (' ', ValueError), + (' \t\t ', ValueError), + ("\u0200", ValueError) +] + +class IntSubclass(int): + pass + +class IntTestCases(__TestCase): + + def test_basic(self): + self.assertEqual(int(314), 314) + self.assertEqual(int(3.14), 3) + # Check that conversion from float truncates towards zero + self.assertEqual(int(-3.14), -3) + self.assertEqual(int(3.9), 3) + self.assertEqual(int(-3.9), -3) + self.assertEqual(int(3.5), 3) + self.assertEqual(int(-3.5), -3) + self.assertEqual(int("-3"), -3) + self.assertEqual(int(" -3 "), -3) + self.assertEqual(int("\N{EM SPACE}-3\N{EN SPACE}"), -3) + # Different base: + self.assertEqual(int("10",16), 16) + # Test conversion from strings and various anomalies + for s, v in L: + for sign in "", "+", "-": + for prefix in "", " ", "\t", " \t\t ": + ss = prefix + sign + s + vv = v + if sign == "-" and v is not ValueError: + vv = -v + try: + self.assertEqual(int(ss), vv) + except ValueError: + pass + + s = repr(-1-sys.maxsize) + x = int(s) + self.assertEqual(x+1, -sys.maxsize) + self.assertIsInstance(x, int) + # should return int + self.assertEqual(int(s[1:]), sys.maxsize+1) + + # should return int + x = int(1e100) + self.assertIsInstance(x, int) + x = int(-1e100) + self.assertIsInstance(x, int) + + + # SF bug 434186: 0x80000000/2 != 0x80000000>>1. + # Worked by accident in Windows release build, but failed in debug build. + # Failed in all Linux builds. + x = -1-sys.maxsize + self.assertEqual(x >> 1, x//2) + + x = int('1' * 600) + self.assertIsInstance(x, int) + + + self.assertRaises(TypeError, int, 1, 12) + self.assertRaises(TypeError, int, "10", 2, 1) + + self.assertEqual(int('0o123', 0), 83) + self.assertEqual(int('0x123', 16), 291) + + # Bug 1679: "0x" is not a valid hex literal + self.assertRaises(ValueError, int, "0x", 16) + self.assertRaises(ValueError, int, "0x", 0) + + self.assertRaises(ValueError, int, "0o", 8) + self.assertRaises(ValueError, int, "0o", 0) + + self.assertRaises(ValueError, int, "0b", 2) + self.assertRaises(ValueError, int, "0b", 0) + + # SF bug 1334662: int(string, base) wrong answers + # Various representations of 2**32 evaluated to 0 + # rather than 2**32 in previous versions + + self.assertEqual(int('100000000000000000000000000000000', 2), 4294967296) + self.assertEqual(int('102002022201221111211', 3), 4294967296) + self.assertEqual(int('10000000000000000', 4), 4294967296) + self.assertEqual(int('32244002423141', 5), 4294967296) + self.assertEqual(int('1550104015504', 6), 4294967296) + self.assertEqual(int('211301422354', 7), 4294967296) + self.assertEqual(int('40000000000', 8), 4294967296) + self.assertEqual(int('12068657454', 9), 4294967296) + self.assertEqual(int('4294967296', 10), 4294967296) + self.assertEqual(int('1904440554', 11), 4294967296) + self.assertEqual(int('9ba461594', 12), 4294967296) + self.assertEqual(int('535a79889', 13), 4294967296) + self.assertEqual(int('2ca5b7464', 14), 4294967296) + self.assertEqual(int('1a20dcd81', 15), 4294967296) + self.assertEqual(int('100000000', 16), 4294967296) + self.assertEqual(int('a7ffda91', 17), 4294967296) + self.assertEqual(int('704he7g4', 18), 4294967296) + self.assertEqual(int('4f5aff66', 19), 4294967296) + self.assertEqual(int('3723ai4g', 20), 4294967296) + self.assertEqual(int('281d55i4', 21), 4294967296) + self.assertEqual(int('1fj8b184', 22), 4294967296) + self.assertEqual(int('1606k7ic', 23), 4294967296) + self.assertEqual(int('mb994ag', 24), 4294967296) + self.assertEqual(int('hek2mgl', 25), 4294967296) + self.assertEqual(int('dnchbnm', 26), 4294967296) + self.assertEqual(int('b28jpdm', 27), 4294967296) + self.assertEqual(int('8pfgih4', 28), 4294967296) + self.assertEqual(int('76beigg', 29), 4294967296) + self.assertEqual(int('5qmcpqg', 30), 4294967296) + self.assertEqual(int('4q0jto4', 31), 4294967296) + self.assertEqual(int('4000000', 32), 4294967296) + self.assertEqual(int('3aokq94', 33), 4294967296) + self.assertEqual(int('2qhxjli', 34), 4294967296) + self.assertEqual(int('2br45qb', 35), 4294967296) + self.assertEqual(int('1z141z4', 36), 4294967296) + + # tests with base 0 + # this fails on 3.0, but in 2.x the old octal syntax is allowed + self.assertEqual(int(' 0o123 ', 0), 83) + self.assertEqual(int(' 0o123 ', 0), 83) + self.assertEqual(int('000', 0), 0) + self.assertEqual(int('0o123', 0), 83) + self.assertEqual(int('0x123', 0), 291) + self.assertEqual(int('0b100', 0), 4) + self.assertEqual(int(' 0O123 ', 0), 83) + self.assertEqual(int(' 0X123 ', 0), 291) + self.assertEqual(int(' 0B100 ', 0), 4) + with self.assertRaises(ValueError): + int('010', 0) + + # without base still base 10 + self.assertEqual(int('0123'), 123) + self.assertEqual(int('0123', 10), 123) + + # tests with prefix and base != 0 + self.assertEqual(int('0x123', 16), 291) + self.assertEqual(int('0o123', 8), 83) + self.assertEqual(int('0b100', 2), 4) + self.assertEqual(int('0X123', 16), 291) + self.assertEqual(int('0O123', 8), 83) + self.assertEqual(int('0B100', 2), 4) + + # the code has special checks for the first character after the + # type prefix + self.assertRaises(ValueError, int, '0b2', 2) + self.assertRaises(ValueError, int, '0b02', 2) + self.assertRaises(ValueError, int, '0B2', 2) + self.assertRaises(ValueError, int, '0B02', 2) + self.assertRaises(ValueError, int, '0o8', 8) + self.assertRaises(ValueError, int, '0o08', 8) + self.assertRaises(ValueError, int, '0O8', 8) + self.assertRaises(ValueError, int, '0O08', 8) + self.assertRaises(ValueError, int, '0xg', 16) + self.assertRaises(ValueError, int, '0x0g', 16) + self.assertRaises(ValueError, int, '0Xg', 16) + self.assertRaises(ValueError, int, '0X0g', 16) + + # SF bug 1334662: int(string, base) wrong answers + # Checks for proper evaluation of 2**32 + 1 + self.assertEqual(int('100000000000000000000000000000001', 2), 4294967297) + self.assertEqual(int('102002022201221111212', 3), 4294967297) + self.assertEqual(int('10000000000000001', 4), 4294967297) + self.assertEqual(int('32244002423142', 5), 4294967297) + self.assertEqual(int('1550104015505', 6), 4294967297) + self.assertEqual(int('211301422355', 7), 4294967297) + self.assertEqual(int('40000000001', 8), 4294967297) + self.assertEqual(int('12068657455', 9), 4294967297) + self.assertEqual(int('4294967297', 10), 4294967297) + self.assertEqual(int('1904440555', 11), 4294967297) + self.assertEqual(int('9ba461595', 12), 4294967297) + self.assertEqual(int('535a7988a', 13), 4294967297) + self.assertEqual(int('2ca5b7465', 14), 4294967297) + self.assertEqual(int('1a20dcd82', 15), 4294967297) + self.assertEqual(int('100000001', 16), 4294967297) + self.assertEqual(int('a7ffda92', 17), 4294967297) + self.assertEqual(int('704he7g5', 18), 4294967297) + self.assertEqual(int('4f5aff67', 19), 4294967297) + self.assertEqual(int('3723ai4h', 20), 4294967297) + self.assertEqual(int('281d55i5', 21), 4294967297) + self.assertEqual(int('1fj8b185', 22), 4294967297) + self.assertEqual(int('1606k7id', 23), 4294967297) + self.assertEqual(int('mb994ah', 24), 4294967297) + self.assertEqual(int('hek2mgm', 25), 4294967297) + self.assertEqual(int('dnchbnn', 26), 4294967297) + self.assertEqual(int('b28jpdn', 27), 4294967297) + self.assertEqual(int('8pfgih5', 28), 4294967297) + self.assertEqual(int('76beigh', 29), 4294967297) + self.assertEqual(int('5qmcpqh', 30), 4294967297) + self.assertEqual(int('4q0jto5', 31), 4294967297) + self.assertEqual(int('4000001', 32), 4294967297) + self.assertEqual(int('3aokq95', 33), 4294967297) + self.assertEqual(int('2qhxjlj', 34), 4294967297) + self.assertEqual(int('2br45qc', 35), 4294967297) + self.assertEqual(int('1z141z5', 36), 4294967297) + + def test_invalid_signs(self): + with self.assertRaises(ValueError): + int('+') + with self.assertRaises(ValueError): + int('-') + with self.assertRaises(ValueError): + int('- 1') + with self.assertRaises(ValueError): + int('+ 1') + with self.assertRaises(ValueError): + int(' + 1 ') + + def test_unicode(self): + self.assertEqual(int("१२३४५६७८९०1234567890"), 12345678901234567890) + self.assertEqual(int('١٢٣٤٥٦٧٨٩٠'), 1234567890) + self.assertEqual(int("१२३४५६७८९०1234567890", 0), 12345678901234567890) + self.assertEqual(int('١٢٣٤٥٦٧٨٩٠', 0), 1234567890) + + def test_underscores(self): + for lit in VALID_UNDERSCORE_LITERALS: + if any(ch in lit for ch in '.eEjJ'): + continue + self.assertEqual(int(lit, 0), eval(lit)) + self.assertEqual(int(lit, 0), int(lit.replace('_', ''), 0)) + for lit in INVALID_UNDERSCORE_LITERALS: + if any(ch in lit for ch in '.eEjJ'): + continue + self.assertRaises(ValueError, int, lit, 0) + # Additional test cases with bases != 0, only for the constructor: + self.assertEqual(int("1_00", 3), 9) + self.assertEqual(int("0_100"), 100) # not valid as a literal! + self.assertEqual(int(b"1_00"), 100) # byte underscore + self.assertRaises(ValueError, int, "_100") + self.assertRaises(ValueError, int, "+_100") + self.assertRaises(ValueError, int, "1__00") + self.assertRaises(ValueError, int, "100_") + + @support.cpython_only + def test_small_ints(self): + # Bug #3236: Return small longs from PyLong_FromString + self.assertIs(int('10'), 10) + self.assertIs(int('-1'), -1) + self.assertIs(int(b'10'), 10) + self.assertIs(int(b'-1'), -1) + + def test_no_args(self): + self.assertEqual(int(), 0) + + def test_keyword_args(self): + # Test invoking int() using keyword arguments. + self.assertEqual(int('100', base=2), 4) + with self.assertRaisesRegex(TypeError, 'keyword argument'): + int(x=1.2) + with self.assertRaisesRegex(TypeError, 'keyword argument'): + int(x='100', base=2) + self.assertRaises(TypeError, int, base=10) + self.assertRaises(TypeError, int, base=0) + + def test_int_base_limits(self): + """Testing the supported limits of the int() base parameter.""" + self.assertEqual(int('0', 5), 0) + with self.assertRaises(ValueError): + int('0', 1) + with self.assertRaises(ValueError): + int('0', 37) + with self.assertRaises(ValueError): + int('0', -909) # An old magic value base from Python 2. + with self.assertRaises(ValueError): + int('0', base=0-(2**234)) + with self.assertRaises(ValueError): + int('0', base=2**234) + # Bases 2 through 36 are supported. + for base in range(2,37): + self.assertEqual(int('0', base=base), 0) + + def test_int_base_bad_types(self): + """Not integer types are not valid bases; issue16772.""" + with self.assertRaises(TypeError): + int('0', 5.5) + with self.assertRaises(TypeError): + int('0', 5.0) + + def test_int_base_indexable(self): + class MyIndexable(object): + def __init__(self, value): + self.value = value + def __index__(self): + return self.value + + # Check out of range bases. + for base in 2**100, -2**100, 1, 37: + with self.assertRaises(ValueError): + int('43', base) + + # Check in-range bases. + self.assertEqual(int('101', base=MyIndexable(2)), 5) + self.assertEqual(int('101', base=MyIndexable(10)), 101) + self.assertEqual(int('101', base=MyIndexable(36)), 1 + 36**2) + + def test_non_numeric_input_types(self): + # Test possible non-numeric types for the argument x, including + # subclasses of the explicitly documented accepted types. + class CustomStr(str): pass + class CustomBytes(bytes): pass + class CustomByteArray(bytearray): pass + + factories = [ + bytes, + bytearray, + lambda b: CustomStr(b.decode()), + CustomBytes, + CustomByteArray, + memoryview, + ] + try: + from array import array + except ImportError: + pass + else: + factories.append(lambda b: array('B', b)) + + for f in factories: + x = f(b'100') + with self.subTest(type(x)): + self.assertEqual(int(x), 100) + if isinstance(x, (str, bytes, bytearray)): + self.assertEqual(int(x, 2), 4) + else: + msg = "can't convert non-string" + with self.assertRaisesRegex(TypeError, msg): + int(x, 2) + with self.assertRaisesRegex(ValueError, 'invalid literal'): + int(f(b'A' * 0x10)) + + def test_int_memoryview(self): + self.assertEqual(int(memoryview(b'123')[1:3]), 23) + self.assertEqual(int(memoryview(b'123\x00')[1:3]), 23) + self.assertEqual(int(memoryview(b'123 ')[1:3]), 23) + self.assertEqual(int(memoryview(b'123A')[1:3]), 23) + self.assertEqual(int(memoryview(b'1234')[1:3]), 23) + + def test_string_float(self): + self.assertRaises(ValueError, int, '1.2') + + def test_intconversion(self): + # Test __int__() + class ClassicMissingMethods: + pass + self.assertRaises(TypeError, int, ClassicMissingMethods()) + + class MissingMethods(object): + pass + self.assertRaises(TypeError, int, MissingMethods()) + + class Foo0: + def __int__(self): + return 42 + + self.assertEqual(int(Foo0()), 42) + + class Classic: + pass + for base in (object, Classic): + class IntOverridesTrunc(base): + def __int__(self): + return 42 + def __trunc__(self): + return -12 + self.assertEqual(int(IntOverridesTrunc()), 42) + + class JustTrunc(base): + def __trunc__(self): + return 42 + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(JustTrunc()), 42) + + class ExceptionalTrunc(base): + def __trunc__(self): + 1 / 0 + with self.assertRaises(ZeroDivisionError), \ + self.assertWarns(DeprecationWarning): + int(ExceptionalTrunc()) + + for trunc_result_base in (object, Classic): + class Index(trunc_result_base): + def __index__(self): + return 42 + + class TruncReturnsNonInt(base): + def __trunc__(self): + return Index() + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(TruncReturnsNonInt()), 42) + + class Intable(trunc_result_base): + def __int__(self): + return 42 + + class TruncReturnsNonIndex(base): + def __trunc__(self): + return Intable() + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(TruncReturnsNonInt()), 42) + + class NonIntegral(trunc_result_base): + def __trunc__(self): + # Check that we avoid infinite recursion. + return NonIntegral() + + class TruncReturnsNonIntegral(base): + def __trunc__(self): + return NonIntegral() + try: + with self.assertWarns(DeprecationWarning): + int(TruncReturnsNonIntegral()) + except TypeError as e: + self.assertEqual(str(e), + "__trunc__ returned non-Integral" + " (type NonIntegral)") + else: + self.fail("Failed to raise TypeError with %s" % + ((base, trunc_result_base),)) + + # Regression test for bugs.python.org/issue16060. + class BadInt(trunc_result_base): + def __int__(self): + return 42.0 + + class TruncReturnsBadInt(base): + def __trunc__(self): + return BadInt() + + with self.assertRaises(TypeError), \ + self.assertWarns(DeprecationWarning): + int(TruncReturnsBadInt()) + + def test_int_subclass_with_index(self): + class MyIndex(int): + def __index__(self): + return 42 + + class BadIndex(int): + def __index__(self): + return 42.0 + + my_int = MyIndex(7) + self.assertEqual(my_int, 7) + self.assertEqual(int(my_int), 7) + + self.assertEqual(int(BadIndex()), 0) + + def test_int_subclass_with_int(self): + class MyInt(int): + def __int__(self): + return 42 + + class BadInt(int): + def __int__(self): + return 42.0 + + my_int = MyInt(7) + self.assertEqual(my_int, 7) + self.assertEqual(int(my_int), 42) + + my_int = BadInt(7) + self.assertEqual(my_int, 7) + self.assertRaises(TypeError, int, my_int) + + def test_int_returns_int_subclass(self): + class BadIndex: + def __index__(self): + return True + + class BadIndex2(int): + def __index__(self): + return True + + class BadInt: + def __int__(self): + return True + + class BadInt2(int): + def __int__(self): + return True + + class TruncReturnsBadIndex: + def __trunc__(self): + return BadIndex() + + class TruncReturnsBadInt: + def __trunc__(self): + return BadInt() + + class TruncReturnsIntSubclass: + def __trunc__(self): + return True + + bad_int = BadIndex() + with self.assertWarns(DeprecationWarning): + n = int(bad_int) + self.assertEqual(n, 1) + self.assertIs(type(n), int) + + bad_int = BadIndex2() + n = int(bad_int) + self.assertEqual(n, 0) + self.assertIs(type(n), int) + + bad_int = BadInt() + with self.assertWarns(DeprecationWarning): + n = int(bad_int) + self.assertEqual(n, 1) + self.assertIs(type(n), int) + + bad_int = BadInt2() + with self.assertWarns(DeprecationWarning): + n = int(bad_int) + self.assertEqual(n, 1) + self.assertIs(type(n), int) + + bad_int = TruncReturnsBadIndex() + with self.assertWarns(DeprecationWarning): + n = int(bad_int) + self.assertEqual(n, 1) + self.assertIs(type(n), int) + + bad_int = TruncReturnsBadInt() + with self.assertWarns(DeprecationWarning): + self.assertRaises(TypeError, int, bad_int) + + good_int = TruncReturnsIntSubclass() + with self.assertWarns(DeprecationWarning): + n = int(good_int) + self.assertEqual(n, 1) + self.assertIs(type(n), int) + with self.assertWarns(DeprecationWarning): + n = IntSubclass(good_int) + self.assertEqual(n, 1) + self.assertIs(type(n), IntSubclass) + + @skipIfTorchDynamo("flaky under dynamo") + def test_error_message(self): + def check(s, base=None): + with self.assertRaises(ValueError, + msg="int(%r, %r)" % (s, base)) as cm: + if base is None: + int(s) + else: + int(s, base) + self.assertEqual(cm.exception.args[0], + "invalid literal for int() with base %d: %r" % + (10 if base is None else base, s)) + + check('\xbd') + check('123\xbd') + check(' 123 456 ') + + check('123\x00') + # SF bug 1545497: embedded NULs were not detected with explicit base + check('123\x00', 10) + check('123\x00 245', 20) + check('123\x00 245', 16) + check('123\x00245', 20) + check('123\x00245', 16) + # byte string with embedded NUL + check(b'123\x00') + check(b'123\x00', 10) + # non-UTF-8 byte string + check(b'123\xbd') + check(b'123\xbd', 10) + # lone surrogate in Unicode string + check('123\ud800') + check('123\ud800', 10) + + def test_issue31619(self): + self.assertEqual(int('1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1_0_1', 2), + 0b1010101010101010101010101010101) + self.assertEqual(int('1_2_3_4_5_6_7_0_1_2_3', 8), 0o12345670123) + self.assertEqual(int('1_2_3_4_5_6_7_8_9', 16), 0x123456789) + self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807) + + +class IntStrDigitLimitsTests(__TestCase): + + int_class = int # Override this in subclasses to reuse the suite. + + def setUp(self): + super().setUp() + self._previous_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(2048) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_limit) + super().tearDown() + + def test_disabled_limit(self): + self.assertGreater(sys.get_int_max_str_digits(), 0) + self.assertLess(sys.get_int_max_str_digits(), 20_000) + with support.adjust_int_max_str_digits(0): + self.assertEqual(sys.get_int_max_str_digits(), 0) + i = self.int_class('1' * 20_000) + str(i) + self.assertGreater(sys.get_int_max_str_digits(), 0) + + def test_max_str_digits_edge_cases(self): + """Ignore the +/- sign and space padding.""" + int_class = self.int_class + maxdigits = sys.get_int_max_str_digits() + + int_class('1' * maxdigits) + int_class(' ' + '1' * maxdigits) + int_class('1' * maxdigits + ' ') + int_class('+' + '1' * maxdigits) + int_class('-' + '1' * maxdigits) + self.assertEqual(len(str(10 ** (maxdigits - 1))), maxdigits) + + def check(self, i, base=None): + with self.assertRaises(ValueError): + if base is None: + self.int_class(i) + else: + self.int_class(i, base) + + def test_max_str_digits(self): + maxdigits = sys.get_int_max_str_digits() + + self.check('1' * (maxdigits + 1)) + self.check(' ' + '1' * (maxdigits + 1)) + self.check('1' * (maxdigits + 1) + ' ') + self.check('+' + '1' * (maxdigits + 1)) + self.check('-' + '1' * (maxdigits + 1)) + self.check('1' * (maxdigits + 1)) + + i = 10 ** maxdigits + with self.assertRaises(ValueError): + str(i) + + def test_denial_of_service_prevented_int_to_str(self): + """Regression test: ensure we fail before performing O(N**2) work.""" + maxdigits = sys.get_int_max_str_digits() + assert maxdigits < 50_000, maxdigits # A test prerequisite. + + huge_int = int(f'0x{"c"*65_000}', base=16) # 78268 decimal digits. + digits = 78_268 + with ( + support.adjust_int_max_str_digits(digits), + support.CPUStopwatch() as sw_convert): + huge_decimal = str(huge_int) + self.assertEqual(len(huge_decimal), digits) + # Ensuring that we chose a slow enough conversion to measure. + # It takes 0.1 seconds on a Zen based cloud VM in an opt build. + # Some OSes have a low res 1/64s timer, skip if hard to measure. + if sw_convert.seconds < sw_convert.clock_info.resolution * 2: + raise unittest.SkipTest('"slow" conversion took only ' + f'{sw_convert.seconds} seconds.') + + # We test with the limit almost at the size needed to check performance. + # The performant limit check is slightly fuzzy, give it a some room. + with support.adjust_int_max_str_digits(int(.995 * digits)): + with ( + self.assertRaises(ValueError) as err, + support.CPUStopwatch() as sw_fail_huge): + str(huge_int) + self.assertIn('conversion', str(err.exception)) + self.assertLessEqual(sw_fail_huge.seconds, sw_convert.seconds/2) + + # Now we test that a conversion that would take 30x as long also fails + # in a similarly fast fashion. + extra_huge_int = int(f'0x{"c"*500_000}', base=16) # 602060 digits. + with ( + self.assertRaises(ValueError) as err, + support.CPUStopwatch() as sw_fail_extra_huge): + # If not limited, 8 seconds said Zen based cloud VM. + str(extra_huge_int) + self.assertIn('conversion', str(err.exception)) + self.assertLess(sw_fail_extra_huge.seconds, sw_convert.seconds/2) + + def test_denial_of_service_prevented_str_to_int(self): + """Regression test: ensure we fail before performing O(N**2) work.""" + maxdigits = sys.get_int_max_str_digits() + assert maxdigits < 100_000, maxdigits # A test prerequisite. + + digits = 133700 + huge = '8'*digits + with ( + support.adjust_int_max_str_digits(digits), + support.CPUStopwatch() as sw_convert): + int(huge) + # Ensuring that we chose a slow enough conversion to measure. + # It takes 0.1 seconds on a Zen based cloud VM in an opt build. + # Some OSes have a low res 1/64s timer, skip if hard to measure. + if sw_convert.seconds < sw_convert.clock_info.resolution * 2: + raise unittest.SkipTest('"slow" conversion took only ' + f'{sw_convert.seconds} seconds.') + + with support.adjust_int_max_str_digits(digits - 1): + with ( + self.assertRaises(ValueError) as err, + support.CPUStopwatch() as sw_fail_huge): + int(huge) + self.assertIn('conversion', str(err.exception)) + self.assertLessEqual(sw_fail_huge.seconds, sw_convert.seconds/2) + + # Now we test that a conversion that would take 30x as long also fails + # in a similarly fast fashion. + extra_huge = '7'*1_200_000 + with ( + self.assertRaises(ValueError) as err, + support.CPUStopwatch() as sw_fail_extra_huge): + # If not limited, 8 seconds in the Zen based cloud VM. + int(extra_huge) + self.assertIn('conversion', str(err.exception)) + self.assertLessEqual(sw_fail_extra_huge.seconds, sw_convert.seconds/2) + + def test_power_of_two_bases_unlimited(self): + """The limit does not apply to power of 2 bases.""" + maxdigits = sys.get_int_max_str_digits() + + for base in (2, 4, 8, 16, 32): + with self.subTest(base=base): + self.int_class('1' * (maxdigits + 1), base) + assert maxdigits < 100_000 + self.int_class('1' * 100_000, base) + + def test_underscores_ignored(self): + maxdigits = sys.get_int_max_str_digits() + + triples = maxdigits // 3 + s = '111' * triples + s_ = '1_11' * triples + self.int_class(s) # succeeds + self.int_class(s_) # succeeds + self.check(f'{s}111') + self.check(f'{s_}_111') + + def test_sign_not_counted(self): + int_class = self.int_class + max_digits = sys.get_int_max_str_digits() + s = '5' * max_digits + i = int_class(s) + pos_i = int_class(f'+{s}') + assert i == pos_i + neg_i = int_class(f'-{s}') + assert -pos_i == neg_i + str(pos_i) + str(neg_i) + + def _other_base_helper(self, base): + int_class = self.int_class + max_digits = sys.get_int_max_str_digits() + s = '2' * max_digits + i = int_class(s, base) + if base > 10: + with self.assertRaises(ValueError): + str(i) + elif base < 10: + str(i) + with self.assertRaises(ValueError) as err: + int_class(f'{s}1', base) + + def test_int_from_other_bases(self): + base = 3 + with self.subTest(base=base): + self._other_base_helper(base) + base = 36 + with self.subTest(base=base): + self._other_base_helper(base) + + def test_int_max_str_digits_is_per_interpreter(self): + # Changing the limit in one interpreter does not change others. + code = """if 1: + # Subinterpreters maintain and enforce their own limit + import sys + sys.set_int_max_str_digits(2323) + try: + int('3'*3333) + except ValueError: + pass + else: + raise AssertionError('Expected a int max str digits ValueError.') + """ + with support.adjust_int_max_str_digits(4000): + before_value = sys.get_int_max_str_digits() + self.assertEqual(support.run_in_subinterp(code), 0, + 'subinterp code failure, check stderr.') + after_value = sys.get_int_max_str_digits() + self.assertEqual(before_value, after_value) + + +class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests): + int_class = IntSubclass + + +class PyLongModuleTests(__TestCase): + # Tests of the functions in _pylong.py. Those get used when the + # number of digits in the input values are large enough. + + def setUp(self): + super().setUp() + self._previous_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(0) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_limit) + super().tearDown() + + def _test_pylong_int_to_decimal(self, n, suffix): + s = str(n) + self.assertEqual(s[-10:], suffix) + s2 = str(-n) + self.assertEqual(s2, '-' + s) + s3 = '%d' % n + self.assertEqual(s3, s) + s4 = b'%d' % n + self.assertEqual(s4, s.encode('ascii')) + + def test_pylong_int_to_decimal(self): + self._test_pylong_int_to_decimal((1 << 100_000), '9883109376') + self._test_pylong_int_to_decimal((1 << 100_000) - 1, '9883109375') + self._test_pylong_int_to_decimal(10**30_000, '0000000000') + self._test_pylong_int_to_decimal(10**30_000 - 1, '9999999999') + self._test_pylong_int_to_decimal(3**60_000, '9313200001') + + @support.requires_resource('cpu') + def test_pylong_int_to_decimal_2(self): + self._test_pylong_int_to_decimal(2**1_000_000, '2747109376') + self._test_pylong_int_to_decimal(10**300_000, '0000000000') + self._test_pylong_int_to_decimal(3**600_000, '3132000001') + + def test_pylong_int_divmod(self): + n = (1 << 100_000) + a, b = divmod(n*3 + 1, n) + assert a == 3 and b == 1 + + def test_pylong_str_to_int(self): + v1 = 1 << 100_000 + s = str(v1) + v2 = int(s) + assert v1 == v2 + v3 = int(' -' + s) + assert -v1 == v3 + v4 = int(' +' + s + ' ') + assert v1 == v4 + with self.assertRaises(ValueError) as err: + int(s + 'z') + with self.assertRaises(ValueError) as err: + int(s + '_') + with self.assertRaises(ValueError) as err: + int('_' + s) + + @support.cpython_only # tests implementation details of CPython. + @unittest.skipUnless(_pylong, "_pylong module required") + @mock.patch.object(_pylong, "int_to_decimal_string") + def test_pylong_misbehavior_error_path_to_str( + self, mock_int_to_str): + with support.adjust_int_max_str_digits(20_000): + big_value = int('7'*19_999) + mock_int_to_str.return_value = None # not a str + with self.assertRaises(TypeError) as ctx: + str(big_value) + self.assertIn('_pylong.int_to_decimal_string did not', + str(ctx.exception)) + mock_int_to_str.side_effect = RuntimeError("testABC") + with self.assertRaises(RuntimeError): + str(big_value) + + @support.cpython_only # tests implementation details of CPython. + @unittest.skipUnless(_pylong, "_pylong module required") + @mock.patch.object(_pylong, "int_from_string") + def test_pylong_misbehavior_error_path_from_str( + self, mock_int_from_str): + big_value = '7'*19_999 + with support.adjust_int_max_str_digits(20_000): + mock_int_from_str.return_value = b'not an int' + with self.assertRaises(TypeError) as ctx: + int(big_value) + self.assertIn('_pylong.int_from_string did not', + str(ctx.exception)) + + mock_int_from_str.side_effect = RuntimeError("test123") + with self.assertRaises(RuntimeError): + int(big_value) + + def test_pylong_roundtrip(self): + from random import randrange, getrandbits + bits = 5000 + while bits <= 1_000_000: + bits += randrange(-100, 101) # break bitlength patterns + hibit = 1 << (bits - 1) + n = hibit | getrandbits(bits - 1) + assert n.bit_length() == bits + sn = str(n) + self.assertFalse(sn.startswith('0')) + self.assertEqual(n, int(sn)) + bits <<= 1 + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_int_literal.diff b/test/dynamo/cpython/3_13/test_int_literal.diff new file mode 100644 index 00000000000000..2f25367ff9fae2 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_int_literal.diff @@ -0,0 +1,74 @@ +diff --git a/test/dynamo/cpython/3_13/test_int_literal.py b/test/dynamo/cpython/3_13/test_int_literal.py +index bf725710d55..831d03666fb 100644 +--- a/test/dynamo/cpython/3_13/test_int_literal.py ++++ b/test/dynamo/cpython/3_13/test_int_literal.py +@@ -1,3 +1,54 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + """Test correct treatment of hex/oct constants. + + This is complex because of changes due to PEP 237. +@@ -5,7 +56,7 @@ This is complex because of changes due to PEP 237. + + import unittest + +-class TestHexOctBin(unittest.TestCase): ++class TestHexOctBin(__TestCase): + + def test_hex_baseline(self): + # A few upper/lowercase tests +@@ -140,4 +191,4 @@ class TestHexOctBin(unittest.TestCase): + self.assertEqual(-0b1111111111111111111111111111111111111111111111111111111111111111, -18446744073709551615) + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_int_literal.py b/test/dynamo/cpython/3_13/test_int_literal.py new file mode 100644 index 00000000000000..831d03666fb911 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_int_literal.py @@ -0,0 +1,194 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +"""Test correct treatment of hex/oct constants. + +This is complex because of changes due to PEP 237. +""" + +import unittest + +class TestHexOctBin(__TestCase): + + def test_hex_baseline(self): + # A few upper/lowercase tests + self.assertEqual(0x0, 0X0) + self.assertEqual(0x1, 0X1) + self.assertEqual(0x123456789abcdef, 0X123456789abcdef) + # Baseline tests + self.assertEqual(0x0, 0) + self.assertEqual(0x10, 16) + self.assertEqual(0x7fffffff, 2147483647) + self.assertEqual(0x7fffffffffffffff, 9223372036854775807) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0x0), 0) + self.assertEqual(-(0x10), -16) + self.assertEqual(-(0x7fffffff), -2147483647) + self.assertEqual(-(0x7fffffffffffffff), -9223372036854775807) + # Ditto with a minus sign and NO parentheses + self.assertEqual(-0x0, 0) + self.assertEqual(-0x10, -16) + self.assertEqual(-0x7fffffff, -2147483647) + self.assertEqual(-0x7fffffffffffffff, -9223372036854775807) + + def test_hex_unsigned(self): + # Positive constants + self.assertEqual(0x80000000, 2147483648) + self.assertEqual(0xffffffff, 4294967295) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0x80000000), -2147483648) + self.assertEqual(-(0xffffffff), -4294967295) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0x80000000, -2147483648) + self.assertEqual(-0xffffffff, -4294967295) + + # Positive constants + self.assertEqual(0x8000000000000000, 9223372036854775808) + self.assertEqual(0xffffffffffffffff, 18446744073709551615) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0x8000000000000000), -9223372036854775808) + self.assertEqual(-(0xffffffffffffffff), -18446744073709551615) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0x8000000000000000, -9223372036854775808) + self.assertEqual(-0xffffffffffffffff, -18446744073709551615) + + def test_oct_baseline(self): + # A few upper/lowercase tests + self.assertEqual(0o0, 0O0) + self.assertEqual(0o1, 0O1) + self.assertEqual(0o1234567, 0O1234567) + # Baseline tests + self.assertEqual(0o0, 0) + self.assertEqual(0o20, 16) + self.assertEqual(0o17777777777, 2147483647) + self.assertEqual(0o777777777777777777777, 9223372036854775807) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0o0), 0) + self.assertEqual(-(0o20), -16) + self.assertEqual(-(0o17777777777), -2147483647) + self.assertEqual(-(0o777777777777777777777), -9223372036854775807) + # Ditto with a minus sign and NO parentheses + self.assertEqual(-0o0, 0) + self.assertEqual(-0o20, -16) + self.assertEqual(-0o17777777777, -2147483647) + self.assertEqual(-0o777777777777777777777, -9223372036854775807) + + def test_oct_unsigned(self): + # Positive constants + self.assertEqual(0o20000000000, 2147483648) + self.assertEqual(0o37777777777, 4294967295) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0o20000000000), -2147483648) + self.assertEqual(-(0o37777777777), -4294967295) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0o20000000000, -2147483648) + self.assertEqual(-0o37777777777, -4294967295) + + # Positive constants + self.assertEqual(0o1000000000000000000000, 9223372036854775808) + self.assertEqual(0o1777777777777777777777, 18446744073709551615) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0o1000000000000000000000), -9223372036854775808) + self.assertEqual(-(0o1777777777777777777777), -18446744073709551615) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0o1000000000000000000000, -9223372036854775808) + self.assertEqual(-0o1777777777777777777777, -18446744073709551615) + + def test_bin_baseline(self): + # A few upper/lowercase tests + self.assertEqual(0b0, 0B0) + self.assertEqual(0b1, 0B1) + self.assertEqual(0b10101010101, 0B10101010101) + # Baseline tests + self.assertEqual(0b0, 0) + self.assertEqual(0b10000, 16) + self.assertEqual(0b1111111111111111111111111111111, 2147483647) + self.assertEqual(0b111111111111111111111111111111111111111111111111111111111111111, 9223372036854775807) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0b0), 0) + self.assertEqual(-(0b10000), -16) + self.assertEqual(-(0b1111111111111111111111111111111), -2147483647) + self.assertEqual(-(0b111111111111111111111111111111111111111111111111111111111111111), -9223372036854775807) + # Ditto with a minus sign and NO parentheses + self.assertEqual(-0b0, 0) + self.assertEqual(-0b10000, -16) + self.assertEqual(-0b1111111111111111111111111111111, -2147483647) + self.assertEqual(-0b111111111111111111111111111111111111111111111111111111111111111, -9223372036854775807) + + def test_bin_unsigned(self): + # Positive constants + self.assertEqual(0b10000000000000000000000000000000, 2147483648) + self.assertEqual(0b11111111111111111111111111111111, 4294967295) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0b10000000000000000000000000000000), -2147483648) + self.assertEqual(-(0b11111111111111111111111111111111), -4294967295) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0b10000000000000000000000000000000, -2147483648) + self.assertEqual(-0b11111111111111111111111111111111, -4294967295) + + # Positive constants + self.assertEqual(0b1000000000000000000000000000000000000000000000000000000000000000, 9223372036854775808) + self.assertEqual(0b1111111111111111111111111111111111111111111111111111111111111111, 18446744073709551615) + # Ditto with a minus sign and parentheses + self.assertEqual(-(0b1000000000000000000000000000000000000000000000000000000000000000), -9223372036854775808) + self.assertEqual(-(0b1111111111111111111111111111111111111111111111111111111111111111), -18446744073709551615) + # Ditto with a minus sign and NO parentheses + # This failed in Python 2.2 through 2.2.2 and in 2.3a1 + self.assertEqual(-0b1000000000000000000000000000000000000000000000000000000000000000, -9223372036854775808) + self.assertEqual(-0b1111111111111111111111111111111111111111111111111111111111111111, -18446744073709551615) + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_list.diff b/test/dynamo/cpython/3_13/test_list.diff new file mode 100644 index 00000000000000..943f67dd4a00a7 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_list.diff @@ -0,0 +1,77 @@ +diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py +index 23ef902aa0b..30e69ff75bd 100644 +--- a/test/dynamo/cpython/3_13/test_list.py ++++ b/test/dynamo/cpython/3_13/test_list.py +@@ -1,6 +1,57 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import sys + import textwrap +-from test import list_tests ++import list_tests + from test.support import cpython_only + from test.support.script_helper import assert_python_ok + import pickle +@@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest): + a.append(4) + self.assertEqual(list(it), []) + ++ @unittest.skip("Fails on python <=3.13.2 and passes on >=3.13.3") + def test_deopt_from_append_list(self): + # gh-132011: it used to crash, because + # of `CALL_LIST_APPEND` specialization failure. +@@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest): + self.assertEqual(rc, 0) + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py new file mode 100644 index 00000000000000..6e4c6d99d169ab --- /dev/null +++ b/test/dynamo/cpython/3_13/test_list.py @@ -0,0 +1,398 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import sys +import textwrap +import list_tests +from test.support import cpython_only +from test.support.script_helper import assert_python_ok +import pickle +import unittest + +class ListTest(list_tests.CommonTest): + type2test = list + + def test_basic(self): + self.assertEqual(list([]), []) + l0_3 = [0, 1, 2, 3] + l0_3_bis = list(l0_3) + self.assertEqual(l0_3, l0_3_bis) + self.assertTrue(l0_3 is not l0_3_bis) + self.assertEqual(list(()), []) + self.assertEqual(list((0, 1, 2, 3)), [0, 1, 2, 3]) + self.assertEqual(list(''), []) + self.assertEqual(list('spam'), ['s', 'p', 'a', 'm']) + self.assertEqual(list(x for x in range(10) if x % 2), + [1, 3, 5, 7, 9]) + + if sys.maxsize == 0x7fffffff: + # This test can currently only work on 32-bit machines. + # XXX If/when PySequence_Length() returns a ssize_t, it should be + # XXX re-enabled. + # Verify clearing of bug #556025. + # This assumes that the max data size (sys.maxint) == max + # address size this also assumes that the address size is at + # least 4 bytes with 8 byte addresses, the bug is not well + # tested + # + # Note: This test is expected to SEGV under Cygwin 1.3.12 or + # earlier due to a newlib bug. See the following mailing list + # thread for the details: + self.assertRaises(MemoryError, list, range(sys.maxsize // 2)) + + # This code used to segfault in Py2.4a3 + x = [] + x.extend(-y for y in x) + self.assertEqual(x, []) + + def test_keyword_args(self): + with self.assertRaisesRegex(TypeError, 'keyword argument'): + list(sequence=[]) + + def test_keywords_in_subclass(self): + class subclass(list): + pass + u = subclass([1, 2]) + self.assertIs(type(u), subclass) + self.assertEqual(list(u), [1, 2]) + with self.assertRaises(TypeError): + subclass(sequence=()) + + class subclass_with_init(list): + def __init__(self, seq, newarg=None): + super().__init__(seq) + self.newarg = newarg + u = subclass_with_init([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_init) + self.assertEqual(list(u), [1, 2]) + self.assertEqual(u.newarg, 3) + + class subclass_with_new(list): + def __new__(cls, seq, newarg=None): + self = super().__new__(cls, seq) + self.newarg = newarg + return self + u = subclass_with_new([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_new) + self.assertEqual(list(u), [1, 2]) + self.assertEqual(u.newarg, 3) + + def test_truth(self): + super().test_truth() + self.assertTrue(not []) + self.assertTrue([42]) + + def test_identity(self): + self.assertTrue([] is not []) + + def test_len(self): + super().test_len() + self.assertEqual(len([]), 0) + self.assertEqual(len([0]), 1) + self.assertEqual(len([0, 1, 2]), 3) + + def test_overflow(self): + lst = [4, 5, 6, 7] + n = int((sys.maxsize*2+2) // len(lst)) + def mul(a, b): return a * b + def imul(a, b): a *= b + self.assertRaises((MemoryError, OverflowError), mul, lst, n) + self.assertRaises((MemoryError, OverflowError), imul, lst, n) + + def test_empty_slice(self): + x = [] + x[:] = x + self.assertEqual(x, []) + + def test_list_resize_overflow(self): + # gh-97616: test new_allocated * sizeof(PyObject*) overflow + # check in list_resize() + lst = [0] * 65 + del lst[1:] + self.assertEqual(len(lst), 1) + + size = sys.maxsize + with self.assertRaises((MemoryError, OverflowError)): + lst * size + with self.assertRaises((MemoryError, OverflowError)): + lst *= size + + def test_repr_mutate(self): + class Obj: + @staticmethod + def __repr__(): + try: + mylist.pop() + except IndexError: + pass + return 'obj' + + mylist = [Obj() for _ in range(5)] + self.assertEqual(repr(mylist), '[obj, obj, obj]') + + def test_repr_large(self): + # Check the repr of large list objects + def check(n): + l = [0] * n + s = repr(l) + self.assertEqual(s, + '[' + ', '.join(['0'] * n) + ']') + check(10) # check our checking code + check(1000000) + + def test_iterator_pickle(self): + orig = self.type2test([4, 5, 6, 7]) + data = [10, 11, 12, 13, 14, 15] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # initial iterator + itorig = iter(orig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), data) + + # running iterator + next(itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), data[1:]) + + # empty iterator + for i in range(1, len(orig)): + next(itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), data[len(orig):]) + + # exhausted iterator + self.assertRaises(StopIteration, next, itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(list(it), []) + + def test_reversed_pickle(self): + orig = self.type2test([4, 5, 6, 7]) + data = [10, 11, 12, 13, 14, 15] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # initial iterator + itorig = reversed(orig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), data[len(orig)-1::-1]) + + # running iterator + next(itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), data[len(orig)-2::-1]) + + # empty iterator + for i in range(1, len(orig)): + next(itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(type(it), type(itorig)) + self.assertEqual(list(it), []) + + # exhausted iterator + self.assertRaises(StopIteration, next, itorig) + d = pickle.dumps((itorig, orig), proto) + it, a = pickle.loads(d) + a[:] = data + self.assertEqual(list(it), []) + + def test_step_overflow(self): + a = [0, 1, 2, 3, 4] + a[1::sys.maxsize] = [0] + self.assertEqual(a[3::sys.maxsize], [3]) + + def test_no_comdat_folding(self): + # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding + # optimization causes failures in code that relies on distinct + # function addresses. + class L(list): pass + with self.assertRaises(TypeError): + (3,) + L([1,2]) + + def test_equal_operator_modifying_operand(self): + # test fix for seg fault reported in bpo-38588 part 2. + class X: + def __eq__(self,other) : + list2.clear() + return NotImplemented + + class Y: + def __eq__(self, other): + list1.clear() + return NotImplemented + + class Z: + def __eq__(self, other): + list3.clear() + return NotImplemented + + list1 = [X()] + list2 = [Y()] + self.assertTrue(list1 == list2) + + list3 = [Z()] + list4 = [1] + self.assertFalse(list3 == list4) + + def test_lt_operator_modifying_operand(self): + # See gh-120298 + class evil: + def __lt__(self, other): + other.clear() + return NotImplemented + + a = [[evil()]] + with self.assertRaises(TypeError): + a[0] < a + + def test_list_index_modifing_operand(self): + # See gh-120384 + class evil: + def __init__(self, lst): + self.lst = lst + def __iter__(self): + yield from self.lst + self.lst.clear() + + lst = list(range(5)) + operand = evil(lst) + with self.assertRaises(ValueError): + lst[::-1] = operand + + @cpython_only + def test_preallocation(self): + iterable = [0] * 10 + iter_size = sys.getsizeof(iterable) + + self.assertEqual(iter_size, sys.getsizeof(list([0] * 10))) + self.assertEqual(iter_size, sys.getsizeof(list(range(10)))) + + def test_count_index_remove_crashes(self): + # bpo-38610: The count(), index(), and remove() methods were not + # holding strong references to list elements while calling + # PyObject_RichCompareBool(). + class X: + def __eq__(self, other): + lst.clear() + return NotImplemented + + lst = [X()] + with self.assertRaises(ValueError): + lst.index(lst) + + class L(list): + def __eq__(self, other): + str(other) + return NotImplemented + + lst = L([X()]) + lst.count(lst) + + lst = L([X()]) + with self.assertRaises(ValueError): + lst.remove(lst) + + # bpo-39453: list.__contains__ was not holding strong references + # to list elements while calling PyObject_RichCompareBool(). + lst = [X(), X()] + 3 in lst + lst = [X(), X()] + X() in lst + + def test_tier2_invalidates_iterator(self): + # GH-121012 + for _ in range(100): + a = [1, 2, 3] + it = iter(a) + for _ in it: + pass + a.append(4) + self.assertEqual(list(it), []) + + @unittest.skip("Fails on python <=3.13.2 and passes on >=3.13.3") + def test_deopt_from_append_list(self): + # gh-132011: it used to crash, because + # of `CALL_LIST_APPEND` specialization failure. + code = textwrap.dedent(""" + l = [] + def lappend(l, x, y): + l.append((x, y)) + for x in range(3): + lappend(l, None, None) + try: + lappend(list, None, None) + except TypeError: + pass + else: + raise AssertionError + """) + + rc, _, _ = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_math.diff b/test/dynamo/cpython/3_13/test_math.diff new file mode 100644 index 00000000000000..4192addeca5c99 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_math.diff @@ -0,0 +1,191 @@ +diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py +index 5ee3055c871..51773d5f478 100644 +--- a/test/dynamo/cpython/3_13/test_math.py ++++ b/test/dynamo/cpython/3_13/test_math.py +@@ -1,3 +1,58 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ slowTest, ++ run_tests, ++ skipIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Python test set -- math module + # XXXX Should not do tests around zero only + +@@ -242,7 +297,7 @@ class BadDescr: + def __get__(self, obj, objtype=None): + raise ValueError + +-class MathTests(unittest.TestCase): ++class MathTests(__TestCase): + + def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0): + """Compare arguments expected and got, as floats, if either +@@ -533,6 +588,7 @@ class MathTests(unittest.TestCase): + self.ftest('fabs(0)', math.fabs(0), 0) + self.ftest('fabs(1)', math.fabs(1), 1) + ++ @skipIfTorchDynamo("infinite loop") + def testFactorial(self): + self.assertEqual(math.factorial(0), 1) + total = 1 +@@ -1072,6 +1128,7 @@ class MathTests(unittest.TestCase): + with self.assertRaises(ValueError): + math.dist([1, 2], [3, 4, 5]) + ++ @slowTest + def testIsqrt(self): + # Test a variety of inputs, large and small. + test_values = ( +@@ -1202,12 +1259,6 @@ class MathTests(unittest.TestCase): + self.assertEqual(math.ldexp(NINF, n), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, n))) + +- @requires_IEEE_754 +- def testLdexp_denormal(self): +- # Denormal output incorrectly rounded (truncated) +- # on some Windows. +- self.assertEqual(math.ldexp(6993274598585239, -1126), 1e-323) +- + def testLog(self): + self.assertRaises(TypeError, math.log) + self.assertRaises(TypeError, math.log, 1, 2, 3) +@@ -1233,6 +1284,7 @@ class MathTests(unittest.TestCase): + self.assertRaises(ValueError, math.log1p, -1) + self.assertEqual(math.log1p(INF), INF) + ++ @skipIfTorchDynamo("Infinite loop") + @requires_IEEE_754 + def testLog2(self): + self.assertRaises(TypeError, math.log2) +@@ -1251,6 +1303,7 @@ class MathTests(unittest.TestCase): + self.assertRaises(ValueError, math.log2, NINF) + self.assertTrue(math.isnan(math.log2(NAN))) + ++ @skipIfTorchDynamo("Infinite loop") + @requires_IEEE_754 + # log2() is not accurate enough on Mac OS X Tiger (10.4) + @support.requires_mac_ver(10, 5) +@@ -1332,7 +1385,7 @@ class MathTests(unittest.TestCase): + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + +- from test.test_iter import BasicIterClass ++ from test_iter import BasicIterClass + + self.assertEqual(sumprod(BasicIterClass(1), [1]), 0) + self.assertEqual(sumprod([1], BasicIterClass(1)), 0) +@@ -2252,6 +2305,7 @@ class MathTests(unittest.TestCase): + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + ++ @skipIfTorchDynamo("Infinite loop") + def testPerm(self): + perm = math.perm + factorial = math.factorial +@@ -2316,6 +2370,7 @@ class MathTests(unittest.TestCase): + self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) + ++ @skipIfTorchDynamo("infinite loop") + def testComb(self): + comb = math.comb + factorial = math.factorial +@@ -2446,6 +2501,7 @@ class MathTests(unittest.TestCase): + math.nextafter(1.0, INF, steps=-1) + + ++ @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest + @requires_IEEE_754 + def test_ulp(self): + self.assertEqual(math.ulp(1.0), sys.float_info.epsilon) +@@ -2508,7 +2564,7 @@ class MathTests(unittest.TestCase): + self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y)) + + +-class IsCloseTests(unittest.TestCase): ++class IsCloseTests(__TestCase): + isclose = math.isclose # subclasses should override this + + def assertIsClose(self, a, b, *args, **kwargs): +@@ -2631,7 +2687,7 @@ class IsCloseTests(unittest.TestCase): + self.assertAllNotClose(fraction_examples, rel_tol=1e-9) + + +-class FMATests(unittest.TestCase): ++class FMATests(__TestCase): + """ Tests for math.fma. """ + + def test_fma_nan_results(self): +@@ -2719,8 +2775,7 @@ class FMATests(unittest.TestCase): + # properly: it doesn't use the right sign when the result is zero. + @unittest.skipIf( + sys.platform.startswith(("freebsd", "wasi", "netbsd", "emscripten")) +- or (sys.platform == "android" and platform.machine() == "x86_64") +- or support.linked_to_musl(), # gh-131032 ++ or (sys.platform == "android" and platform.machine() == "x86_64"), + f"this platform doesn't implement IEE 754-2008 properly") + def test_fma_zero_result(self): + nonnegative_finites = [0.0, 1e-300, 2.3, 1e300] +@@ -2879,10 +2934,5 @@ class FMATests(unittest.TestCase): + ) + + +-def load_tests(loader, tests, pattern): +- from doctest import DocFileSuite +- tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt"))) +- return tests +- +-if __name__ == '__main__': +- unittest.main() ++if __name__ == "__main__": ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py new file mode 100644 index 00000000000000..51773d5f4783d0 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_math.py @@ -0,0 +1,2938 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + slowTest, + run_tests, + skipIfTorchDynamo, +) + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# Python test set -- math module +# XXXX Should not do tests around zero only + +from test.support import verbose, requires_IEEE_754 +from test import support +import unittest +import fractions +import itertools +import decimal +import math +import os +import platform +import random +import struct +import sys + + +eps = 1E-05 +NAN = float('nan') +INF = float('inf') +NINF = float('-inf') +FLOAT_MAX = sys.float_info.max +FLOAT_MIN = sys.float_info.min + +# detect evidence of double-rounding: fsum is not always correctly +# rounded on machines that suffer from double rounding. +x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer +HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4) + +# locate file with test values +if __name__ == '__main__': + file = sys.argv[0] +else: + file = __file__ +test_dir = os.path.dirname(file) or os.curdir +math_testcases = os.path.join(test_dir, 'mathdata', 'math_testcases.txt') +test_file = os.path.join(test_dir, 'mathdata', 'cmath_testcases.txt') + + +def to_ulps(x): + """Convert a non-NaN float x to an integer, in such a way that + adjacent floats are converted to adjacent integers. Then + abs(ulps(x) - ulps(y)) gives the difference in ulps between two + floats. + + The results from this function will only make sense on platforms + where native doubles are represented in IEEE 754 binary64 format. + + Note: 0.0 and -0.0 are converted to 0 and -1, respectively. + """ + n = struct.unpack('= 0} product_{0 < j <= n >> i; j odd} j +# +# The outer product above is an infinite product, but once i >= n.bit_length, +# (n >> i) < 1 and the corresponding term of the product is empty. So only the +# finitely many terms for 0 <= i < n.bit_length() contribute anything. +# +# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner +# product in the formula above starts at 1 for i == n.bit_length(); for each i +# < n.bit_length() we get the inner product for i from that for i + 1 by +# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms, +# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2). + +def count_set_bits(n): + """Number of '1' bits in binary expansion of a nonnnegative integer.""" + return 1 + count_set_bits(n & n - 1) if n else 0 + +def partial_product(start, stop): + """Product of integers in range(start, stop, 2), computed recursively. + start and stop should both be odd, with start <= stop. + + """ + numfactors = (stop - start) >> 1 + if not numfactors: + return 1 + elif numfactors == 1: + return start + else: + mid = (start + numfactors) | 1 + return partial_product(start, mid) * partial_product(mid, stop) + +def py_factorial(n): + """Factorial of nonnegative integer n, via "Binary Split Factorial Formula" + described at http://www.luschny.de/math/factorial/binarysplitfact.html + + """ + inner = outer = 1 + for i in reversed(range(n.bit_length())): + inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1) + outer *= inner + return outer << (n - count_set_bits(n)) + +def ulp_abs_check(expected, got, ulp_tol, abs_tol): + """Given finite floats `expected` and `got`, check that they're + approximately equal to within the given number of ulps or the + given absolute tolerance, whichever is bigger. + + Returns None on success and an error message on failure. + """ + ulp_error = abs(to_ulps(expected) - to_ulps(got)) + abs_error = abs(expected - got) + + # Succeed if either abs_error <= abs_tol or ulp_error <= ulp_tol. + if abs_error <= abs_tol or ulp_error <= ulp_tol: + return None + else: + fmt = ("error = {:.3g} ({:d} ulps); " + "permitted error = {:.3g} or {:d} ulps") + return fmt.format(abs_error, ulp_error, abs_tol, ulp_tol) + +def parse_mtestfile(fname): + """Parse a file with test values + + -- starts a comment + blank lines, or lines containing only a comment, are ignored + other lines are expected to have the form + id fn arg -> expected [flag]* + + """ + with open(fname, encoding="utf-8") as fp: + for line in fp: + # strip comments, and skip blank lines + if '--' in line: + line = line[:line.index('--')] + if not line.strip(): + continue + + lhs, rhs = line.split('->') + id, fn, arg = lhs.split() + rhs_pieces = rhs.split() + exp = rhs_pieces[0] + flags = rhs_pieces[1:] + + yield (id, fn, float(arg), float(exp), flags) + + +def parse_testfile(fname): + """Parse a file with test values + + Empty lines or lines starting with -- are ignored + yields id, fn, arg_real, arg_imag, exp_real, exp_imag + """ + with open(fname, encoding="utf-8") as fp: + for line in fp: + # skip comment lines and blank lines + if line.startswith('--') or not line.strip(): + continue + + lhs, rhs = line.split('->') + id, fn, arg_real, arg_imag = lhs.split() + rhs_pieces = rhs.split() + exp_real, exp_imag = rhs_pieces[0], rhs_pieces[1] + flags = rhs_pieces[2:] + + yield (id, fn, + float(arg_real), float(arg_imag), + float(exp_real), float(exp_imag), + flags) + + +def result_check(expected, got, ulp_tol=5, abs_tol=0.0): + # Common logic of MathTests.(ftest, test_testcases, test_mtestcases) + """Compare arguments expected and got, as floats, if either + is a float, using a tolerance expressed in multiples of + ulp(expected) or absolutely (if given and greater). + + As a convenience, when neither argument is a float, and for + non-finite floats, exact equality is demanded. Also, nan==nan + as far as this function is concerned. + + Returns None on success and an error message on failure. + """ + + # Check exactly equal (applies also to strings representing exceptions) + if got == expected: + if not got and not expected: + if math.copysign(1, got) != math.copysign(1, expected): + return f"expected {expected}, got {got} (zero has wrong sign)" + return None + + failure = "not equal" + + # Turn mixed float and int comparison (e.g. floor()) to all-float + if isinstance(expected, float) and isinstance(got, int): + got = float(got) + elif isinstance(got, float) and isinstance(expected, int): + expected = float(expected) + + if isinstance(expected, float) and isinstance(got, float): + if math.isnan(expected) and math.isnan(got): + # Pass, since both nan + failure = None + elif math.isinf(expected) or math.isinf(got): + # We already know they're not equal, drop through to failure + pass + else: + # Both are finite floats (now). Are they close enough? + failure = ulp_abs_check(expected, got, ulp_tol, abs_tol) + + # arguments are not equal, and if numeric, are too far apart + if failure is not None: + fail_fmt = "expected {!r}, got {!r}" + fail_msg = fail_fmt.format(expected, got) + fail_msg += ' ({})'.format(failure) + return fail_msg + else: + return None + +class FloatLike: + def __init__(self, value): + self.value = value + + def __float__(self): + return self.value + +class IntSubclass(int): + pass + +# Class providing an __index__ method. +class MyIndexable(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + +class BadDescr: + def __get__(self, obj, objtype=None): + raise ValueError + +class MathTests(__TestCase): + + def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0): + """Compare arguments expected and got, as floats, if either + is a float, using a tolerance expressed in multiples of + ulp(expected) or absolutely, whichever is greater. + + As a convenience, when neither argument is a float, and for + non-finite floats, exact equality is demanded. Also, nan==nan + in this function. + """ + failure = result_check(expected, got, ulp_tol, abs_tol) + if failure is not None: + self.fail("{}: {}".format(name, failure)) + + def testConstants(self): + # Ref: Abramowitz & Stegun (Dover, 1965) + self.ftest('pi', math.pi, 3.141592653589793238462643) + self.ftest('e', math.e, 2.718281828459045235360287) + self.assertEqual(math.tau, 2*math.pi) + + def testAcos(self): + self.assertRaises(TypeError, math.acos) + self.ftest('acos(-1)', math.acos(-1), math.pi) + self.ftest('acos(0)', math.acos(0), math.pi/2) + self.ftest('acos(1)', math.acos(1), 0) + self.assertRaises(ValueError, math.acos, INF) + self.assertRaises(ValueError, math.acos, NINF) + self.assertRaises(ValueError, math.acos, 1 + eps) + self.assertRaises(ValueError, math.acos, -1 - eps) + self.assertTrue(math.isnan(math.acos(NAN))) + + def testAcosh(self): + self.assertRaises(TypeError, math.acosh) + self.ftest('acosh(1)', math.acosh(1), 0) + self.ftest('acosh(2)', math.acosh(2), 1.3169578969248168) + self.assertRaises(ValueError, math.acosh, 0) + self.assertRaises(ValueError, math.acosh, -1) + self.assertEqual(math.acosh(INF), INF) + self.assertRaises(ValueError, math.acosh, NINF) + self.assertTrue(math.isnan(math.acosh(NAN))) + + def testAsin(self): + self.assertRaises(TypeError, math.asin) + self.ftest('asin(-1)', math.asin(-1), -math.pi/2) + self.ftest('asin(0)', math.asin(0), 0) + self.ftest('asin(1)', math.asin(1), math.pi/2) + self.assertRaises(ValueError, math.asin, INF) + self.assertRaises(ValueError, math.asin, NINF) + self.assertRaises(ValueError, math.asin, 1 + eps) + self.assertRaises(ValueError, math.asin, -1 - eps) + self.assertTrue(math.isnan(math.asin(NAN))) + + def testAsinh(self): + self.assertRaises(TypeError, math.asinh) + self.ftest('asinh(0)', math.asinh(0), 0) + self.ftest('asinh(1)', math.asinh(1), 0.88137358701954305) + self.ftest('asinh(-1)', math.asinh(-1), -0.88137358701954305) + self.assertEqual(math.asinh(INF), INF) + self.assertEqual(math.asinh(NINF), NINF) + self.assertTrue(math.isnan(math.asinh(NAN))) + + def testAtan(self): + self.assertRaises(TypeError, math.atan) + self.ftest('atan(-1)', math.atan(-1), -math.pi/4) + self.ftest('atan(0)', math.atan(0), 0) + self.ftest('atan(1)', math.atan(1), math.pi/4) + self.ftest('atan(inf)', math.atan(INF), math.pi/2) + self.ftest('atan(-inf)', math.atan(NINF), -math.pi/2) + self.assertTrue(math.isnan(math.atan(NAN))) + + def testAtanh(self): + self.assertRaises(TypeError, math.atan) + self.ftest('atanh(0)', math.atanh(0), 0) + self.ftest('atanh(0.5)', math.atanh(0.5), 0.54930614433405489) + self.ftest('atanh(-0.5)', math.atanh(-0.5), -0.54930614433405489) + self.assertRaises(ValueError, math.atanh, 1) + self.assertRaises(ValueError, math.atanh, -1) + self.assertRaises(ValueError, math.atanh, INF) + self.assertRaises(ValueError, math.atanh, NINF) + self.assertTrue(math.isnan(math.atanh(NAN))) + + def testAtan2(self): + self.assertRaises(TypeError, math.atan2) + self.ftest('atan2(-1, 0)', math.atan2(-1, 0), -math.pi/2) + self.ftest('atan2(-1, 1)', math.atan2(-1, 1), -math.pi/4) + self.ftest('atan2(0, 1)', math.atan2(0, 1), 0) + self.ftest('atan2(1, 1)', math.atan2(1, 1), math.pi/4) + self.ftest('atan2(1, 0)', math.atan2(1, 0), math.pi/2) + self.ftest('atan2(1, -1)', math.atan2(1, -1), 3*math.pi/4) + + # math.atan2(0, x) + self.ftest('atan2(0., -inf)', math.atan2(0., NINF), math.pi) + self.ftest('atan2(0., -2.3)', math.atan2(0., -2.3), math.pi) + self.ftest('atan2(0., -0.)', math.atan2(0., -0.), math.pi) + self.assertEqual(math.atan2(0., 0.), 0.) + self.assertEqual(math.atan2(0., 2.3), 0.) + self.assertEqual(math.atan2(0., INF), 0.) + self.assertTrue(math.isnan(math.atan2(0., NAN))) + # math.atan2(-0, x) + self.ftest('atan2(-0., -inf)', math.atan2(-0., NINF), -math.pi) + self.ftest('atan2(-0., -2.3)', math.atan2(-0., -2.3), -math.pi) + self.ftest('atan2(-0., -0.)', math.atan2(-0., -0.), -math.pi) + self.assertEqual(math.atan2(-0., 0.), -0.) + self.assertEqual(math.atan2(-0., 2.3), -0.) + self.assertEqual(math.atan2(-0., INF), -0.) + self.assertTrue(math.isnan(math.atan2(-0., NAN))) + # math.atan2(INF, x) + self.ftest('atan2(inf, -inf)', math.atan2(INF, NINF), math.pi*3/4) + self.ftest('atan2(inf, -2.3)', math.atan2(INF, -2.3), math.pi/2) + self.ftest('atan2(inf, -0.)', math.atan2(INF, -0.0), math.pi/2) + self.ftest('atan2(inf, 0.)', math.atan2(INF, 0.0), math.pi/2) + self.ftest('atan2(inf, 2.3)', math.atan2(INF, 2.3), math.pi/2) + self.ftest('atan2(inf, inf)', math.atan2(INF, INF), math.pi/4) + self.assertTrue(math.isnan(math.atan2(INF, NAN))) + # math.atan2(NINF, x) + self.ftest('atan2(-inf, -inf)', math.atan2(NINF, NINF), -math.pi*3/4) + self.ftest('atan2(-inf, -2.3)', math.atan2(NINF, -2.3), -math.pi/2) + self.ftest('atan2(-inf, -0.)', math.atan2(NINF, -0.0), -math.pi/2) + self.ftest('atan2(-inf, 0.)', math.atan2(NINF, 0.0), -math.pi/2) + self.ftest('atan2(-inf, 2.3)', math.atan2(NINF, 2.3), -math.pi/2) + self.ftest('atan2(-inf, inf)', math.atan2(NINF, INF), -math.pi/4) + self.assertTrue(math.isnan(math.atan2(NINF, NAN))) + # math.atan2(+finite, x) + self.ftest('atan2(2.3, -inf)', math.atan2(2.3, NINF), math.pi) + self.ftest('atan2(2.3, -0.)', math.atan2(2.3, -0.), math.pi/2) + self.ftest('atan2(2.3, 0.)', math.atan2(2.3, 0.), math.pi/2) + self.assertEqual(math.atan2(2.3, INF), 0.) + self.assertTrue(math.isnan(math.atan2(2.3, NAN))) + # math.atan2(-finite, x) + self.ftest('atan2(-2.3, -inf)', math.atan2(-2.3, NINF), -math.pi) + self.ftest('atan2(-2.3, -0.)', math.atan2(-2.3, -0.), -math.pi/2) + self.ftest('atan2(-2.3, 0.)', math.atan2(-2.3, 0.), -math.pi/2) + self.assertEqual(math.atan2(-2.3, INF), -0.) + self.assertTrue(math.isnan(math.atan2(-2.3, NAN))) + # math.atan2(NAN, x) + self.assertTrue(math.isnan(math.atan2(NAN, NINF))) + self.assertTrue(math.isnan(math.atan2(NAN, -2.3))) + self.assertTrue(math.isnan(math.atan2(NAN, -0.))) + self.assertTrue(math.isnan(math.atan2(NAN, 0.))) + self.assertTrue(math.isnan(math.atan2(NAN, 2.3))) + self.assertTrue(math.isnan(math.atan2(NAN, INF))) + self.assertTrue(math.isnan(math.atan2(NAN, NAN))) + + def testCbrt(self): + self.assertRaises(TypeError, math.cbrt) + self.ftest('cbrt(0)', math.cbrt(0), 0) + self.ftest('cbrt(1)', math.cbrt(1), 1) + self.ftest('cbrt(8)', math.cbrt(8), 2) + self.ftest('cbrt(0.0)', math.cbrt(0.0), 0.0) + self.ftest('cbrt(-0.0)', math.cbrt(-0.0), -0.0) + self.ftest('cbrt(1.2)', math.cbrt(1.2), 1.062658569182611) + self.ftest('cbrt(-2.6)', math.cbrt(-2.6), -1.375068867074141) + self.ftest('cbrt(27)', math.cbrt(27), 3) + self.ftest('cbrt(-1)', math.cbrt(-1), -1) + self.ftest('cbrt(-27)', math.cbrt(-27), -3) + self.assertEqual(math.cbrt(INF), INF) + self.assertEqual(math.cbrt(NINF), NINF) + self.assertTrue(math.isnan(math.cbrt(NAN))) + + def testCeil(self): + self.assertRaises(TypeError, math.ceil) + self.assertEqual(int, type(math.ceil(0.5))) + self.assertEqual(math.ceil(0.5), 1) + self.assertEqual(math.ceil(1.0), 1) + self.assertEqual(math.ceil(1.5), 2) + self.assertEqual(math.ceil(-0.5), 0) + self.assertEqual(math.ceil(-1.0), -1) + self.assertEqual(math.ceil(-1.5), -1) + self.assertEqual(math.ceil(0.0), 0) + self.assertEqual(math.ceil(-0.0), 0) + #self.assertEqual(math.ceil(INF), INF) + #self.assertEqual(math.ceil(NINF), NINF) + #self.assertTrue(math.isnan(math.ceil(NAN))) + + class TestCeil: + def __ceil__(self): + return 42 + class FloatCeil(float): + def __ceil__(self): + return 42 + class TestNoCeil: + pass + class TestBadCeil: + __ceil__ = BadDescr() + self.assertEqual(math.ceil(TestCeil()), 42) + self.assertEqual(math.ceil(FloatCeil()), 42) + self.assertEqual(math.ceil(FloatLike(42.5)), 43) + self.assertRaises(TypeError, math.ceil, TestNoCeil()) + self.assertRaises(ValueError, math.ceil, TestBadCeil()) + + t = TestNoCeil() + t.__ceil__ = lambda *args: args + self.assertRaises(TypeError, math.ceil, t) + self.assertRaises(TypeError, math.ceil, t, 0) + + self.assertEqual(math.ceil(FloatLike(+1.0)), +1.0) + self.assertEqual(math.ceil(FloatLike(-1.0)), -1.0) + + @requires_IEEE_754 + def testCopysign(self): + self.assertEqual(math.copysign(1, 42), 1.0) + self.assertEqual(math.copysign(0., 42), 0.0) + self.assertEqual(math.copysign(1., -42), -1.0) + self.assertEqual(math.copysign(3, 0.), 3.0) + self.assertEqual(math.copysign(4., -0.), -4.0) + + self.assertRaises(TypeError, math.copysign) + # copysign should let us distinguish signs of zeros + self.assertEqual(math.copysign(1., 0.), 1.) + self.assertEqual(math.copysign(1., -0.), -1.) + self.assertEqual(math.copysign(INF, 0.), INF) + self.assertEqual(math.copysign(INF, -0.), NINF) + self.assertEqual(math.copysign(NINF, 0.), INF) + self.assertEqual(math.copysign(NINF, -0.), NINF) + # and of infinities + self.assertEqual(math.copysign(1., INF), 1.) + self.assertEqual(math.copysign(1., NINF), -1.) + self.assertEqual(math.copysign(INF, INF), INF) + self.assertEqual(math.copysign(INF, NINF), NINF) + self.assertEqual(math.copysign(NINF, INF), INF) + self.assertEqual(math.copysign(NINF, NINF), NINF) + self.assertTrue(math.isnan(math.copysign(NAN, 1.))) + self.assertTrue(math.isnan(math.copysign(NAN, INF))) + self.assertTrue(math.isnan(math.copysign(NAN, NINF))) + self.assertTrue(math.isnan(math.copysign(NAN, NAN))) + # copysign(INF, NAN) may be INF or it may be NINF, since + # we don't know whether the sign bit of NAN is set on any + # given platform. + self.assertTrue(math.isinf(math.copysign(INF, NAN))) + # similarly, copysign(2., NAN) could be 2. or -2. + self.assertEqual(abs(math.copysign(2., NAN)), 2.) + + def testCos(self): + self.assertRaises(TypeError, math.cos) + self.ftest('cos(-pi/2)', math.cos(-math.pi/2), 0, abs_tol=math.ulp(1)) + self.ftest('cos(0)', math.cos(0), 1) + self.ftest('cos(pi/2)', math.cos(math.pi/2), 0, abs_tol=math.ulp(1)) + self.ftest('cos(pi)', math.cos(math.pi), -1) + try: + self.assertTrue(math.isnan(math.cos(INF))) + self.assertTrue(math.isnan(math.cos(NINF))) + except ValueError: + self.assertRaises(ValueError, math.cos, INF) + self.assertRaises(ValueError, math.cos, NINF) + self.assertTrue(math.isnan(math.cos(NAN))) + + @unittest.skipIf(sys.platform == 'win32' and platform.machine() in ('ARM', 'ARM64'), + "Windows UCRT is off by 2 ULP this test requires accuracy within 1 ULP") + def testCosh(self): + self.assertRaises(TypeError, math.cosh) + self.ftest('cosh(0)', math.cosh(0), 1) + self.ftest('cosh(2)-2*cosh(1)**2', math.cosh(2)-2*math.cosh(1)**2, -1) # Thanks to Lambert + self.assertEqual(math.cosh(INF), INF) + self.assertEqual(math.cosh(NINF), INF) + self.assertTrue(math.isnan(math.cosh(NAN))) + + def testDegrees(self): + self.assertRaises(TypeError, math.degrees) + self.ftest('degrees(pi)', math.degrees(math.pi), 180.0) + self.ftest('degrees(pi/2)', math.degrees(math.pi/2), 90.0) + self.ftest('degrees(-pi/4)', math.degrees(-math.pi/4), -45.0) + self.ftest('degrees(0)', math.degrees(0), 0) + + def testExp(self): + self.assertRaises(TypeError, math.exp) + self.ftest('exp(-1)', math.exp(-1), 1/math.e) + self.ftest('exp(0)', math.exp(0), 1) + self.ftest('exp(1)', math.exp(1), math.e) + self.assertEqual(math.exp(INF), INF) + self.assertEqual(math.exp(NINF), 0.) + self.assertTrue(math.isnan(math.exp(NAN))) + self.assertRaises(OverflowError, math.exp, 1000000) + + def testExp2(self): + self.assertRaises(TypeError, math.exp2) + self.ftest('exp2(-1)', math.exp2(-1), 0.5) + self.ftest('exp2(0)', math.exp2(0), 1) + self.ftest('exp2(1)', math.exp2(1), 2) + self.ftest('exp2(2.3)', math.exp2(2.3), 4.924577653379665) + self.assertEqual(math.exp2(INF), INF) + self.assertEqual(math.exp2(NINF), 0.) + self.assertTrue(math.isnan(math.exp2(NAN))) + self.assertRaises(OverflowError, math.exp2, 1000000) + + def testFabs(self): + self.assertRaises(TypeError, math.fabs) + self.ftest('fabs(-1)', math.fabs(-1), 1) + self.ftest('fabs(0)', math.fabs(0), 0) + self.ftest('fabs(1)', math.fabs(1), 1) + + @skipIfTorchDynamo("infinite loop") + def testFactorial(self): + self.assertEqual(math.factorial(0), 1) + total = 1 + for i in range(1, 1000): + total *= i + self.assertEqual(math.factorial(i), total) + self.assertEqual(math.factorial(i), py_factorial(i)) + self.assertRaises(ValueError, math.factorial, -1) + self.assertRaises(ValueError, math.factorial, -10**100) + + def testFactorialNonIntegers(self): + self.assertRaises(TypeError, math.factorial, 5.0) + self.assertRaises(TypeError, math.factorial, 5.2) + self.assertRaises(TypeError, math.factorial, -1.0) + self.assertRaises(TypeError, math.factorial, -1e100) + self.assertRaises(TypeError, math.factorial, decimal.Decimal('5')) + self.assertRaises(TypeError, math.factorial, decimal.Decimal('5.2')) + self.assertRaises(TypeError, math.factorial, "5") + + # Other implementations may place different upper bounds. + @support.cpython_only + def testFactorialHugeInputs(self): + # Currently raises OverflowError for inputs that are too large + # to fit into a C long. + self.assertRaises(OverflowError, math.factorial, 10**100) + self.assertRaises(TypeError, math.factorial, 1e100) + + def testFloor(self): + self.assertRaises(TypeError, math.floor) + self.assertEqual(int, type(math.floor(0.5))) + self.assertEqual(math.floor(0.5), 0) + self.assertEqual(math.floor(1.0), 1) + self.assertEqual(math.floor(1.5), 1) + self.assertEqual(math.floor(-0.5), -1) + self.assertEqual(math.floor(-1.0), -1) + self.assertEqual(math.floor(-1.5), -2) + #self.assertEqual(math.ceil(INF), INF) + #self.assertEqual(math.ceil(NINF), NINF) + #self.assertTrue(math.isnan(math.floor(NAN))) + + class TestFloor: + def __floor__(self): + return 42 + class FloatFloor(float): + def __floor__(self): + return 42 + class TestNoFloor: + pass + class TestBadFloor: + __floor__ = BadDescr() + self.assertEqual(math.floor(TestFloor()), 42) + self.assertEqual(math.floor(FloatFloor()), 42) + self.assertEqual(math.floor(FloatLike(41.9)), 41) + self.assertRaises(TypeError, math.floor, TestNoFloor()) + self.assertRaises(ValueError, math.floor, TestBadFloor()) + + t = TestNoFloor() + t.__floor__ = lambda *args: args + self.assertRaises(TypeError, math.floor, t) + self.assertRaises(TypeError, math.floor, t, 0) + + self.assertEqual(math.floor(FloatLike(+1.0)), +1.0) + self.assertEqual(math.floor(FloatLike(-1.0)), -1.0) + + def testFmod(self): + self.assertRaises(TypeError, math.fmod) + self.ftest('fmod(10, 1)', math.fmod(10, 1), 0.0) + self.ftest('fmod(10, 0.5)', math.fmod(10, 0.5), 0.0) + self.ftest('fmod(10, 1.5)', math.fmod(10, 1.5), 1.0) + self.ftest('fmod(-10, 1)', math.fmod(-10, 1), -0.0) + self.ftest('fmod(-10, 0.5)', math.fmod(-10, 0.5), -0.0) + self.ftest('fmod(-10, 1.5)', math.fmod(-10, 1.5), -1.0) + self.assertTrue(math.isnan(math.fmod(NAN, 1.))) + self.assertTrue(math.isnan(math.fmod(1., NAN))) + self.assertTrue(math.isnan(math.fmod(NAN, NAN))) + self.assertRaises(ValueError, math.fmod, 1., 0.) + self.assertRaises(ValueError, math.fmod, INF, 1.) + self.assertRaises(ValueError, math.fmod, NINF, 1.) + self.assertRaises(ValueError, math.fmod, INF, 0.) + self.assertEqual(math.fmod(3.0, INF), 3.0) + self.assertEqual(math.fmod(-3.0, INF), -3.0) + self.assertEqual(math.fmod(3.0, NINF), 3.0) + self.assertEqual(math.fmod(-3.0, NINF), -3.0) + self.assertEqual(math.fmod(0.0, 3.0), 0.0) + self.assertEqual(math.fmod(0.0, NINF), 0.0) + self.assertRaises(ValueError, math.fmod, INF, INF) + + def testFrexp(self): + self.assertRaises(TypeError, math.frexp) + + def testfrexp(name, result, expected): + (mant, exp), (emant, eexp) = result, expected + if abs(mant-emant) > eps or exp != eexp: + self.fail('%s returned %r, expected %r'%\ + (name, result, expected)) + + testfrexp('frexp(-1)', math.frexp(-1), (-0.5, 1)) + testfrexp('frexp(0)', math.frexp(0), (0, 0)) + testfrexp('frexp(1)', math.frexp(1), (0.5, 1)) + testfrexp('frexp(2)', math.frexp(2), (0.5, 2)) + + self.assertEqual(math.frexp(INF)[0], INF) + self.assertEqual(math.frexp(NINF)[0], NINF) + self.assertTrue(math.isnan(math.frexp(NAN)[0])) + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "fsum is not exact on machines with double rounding") + def testFsum(self): + # math.fsum relies on exact rounding for correct operation. + # There's a known problem with IA32 floating-point that causes + # inexact rounding in some situations, and will cause the + # math.fsum tests below to fail; see issue #2937. On non IEEE + # 754 platforms, and on IEEE 754 platforms that exhibit the + # problem described in issue #2937, we simply skip the whole + # test. + + # Python version of math.fsum, for comparison. Uses a + # different algorithm based on frexp, ldexp and integer + # arithmetic. + from sys import float_info + mant_dig = float_info.mant_dig + etiny = float_info.min_exp - mant_dig + + def msum(iterable): + """Full precision summation. Compute sum(iterable) without any + intermediate accumulation of error. Based on the 'lsum' function + at https://code.activestate.com/recipes/393090-binary-floating-point-summation-accurate-to-full-p/ + + """ + tmant, texp = 0, 0 + for x in iterable: + mant, exp = math.frexp(x) + mant, exp = int(math.ldexp(mant, mant_dig)), exp - mant_dig + if texp > exp: + tmant <<= texp-exp + texp = exp + else: + mant <<= exp-texp + tmant += mant + # Round tmant * 2**texp to a float. The original recipe + # used float(str(tmant)) * 2.0**texp for this, but that's + # a little unsafe because str -> float conversion can't be + # relied upon to do correct rounding on all platforms. + tail = max(len(bin(abs(tmant)))-2 - mant_dig, etiny - texp) + if tail > 0: + h = 1 << (tail-1) + tmant = tmant // (2*h) + bool(tmant & h and tmant & 3*h-1) + texp += tail + return math.ldexp(tmant, texp) + + test_values = [ + ([], 0.0), + ([0.0], 0.0), + ([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100), + ([1e100, 1.0, -1e100, 1e-100, 1e50, -1, -1e50], 1e-100), + ([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0), + ([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0), + ([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0), + ([2.0**53-4.0, 0.5, 2.0**-54], 2.0**53-3.0), + ([1./n for n in range(1, 1001)], + float.fromhex('0x1.df11f45f4e61ap+2')), + ([(-1.)**n/n for n in range(1, 1001)], + float.fromhex('-0x1.62a2af1bd3624p-1')), + ([1e16, 1., 1e-16], 10000000000000002.0), + ([1e16-2., 1.-2.**-53, -(1e16-2.), -(1.-2.**-53)], 0.0), + # exercise code for resizing partials array + ([2.**n - 2.**(n+50) + 2.**(n+52) for n in range(-1074, 972, 2)] + + [-2.**1022], + float.fromhex('0x1.5555555555555p+970')), + ] + + # Telescoping sum, with exact differences (due to Sterbenz) + terms = [1.7**i for i in range(1001)] + test_values.append(( + [terms[i+1] - terms[i] for i in range(1000)] + [-terms[1000]], + -terms[0] + )) + + for i, (vals, expected) in enumerate(test_values): + try: + actual = math.fsum(vals) + except OverflowError: + self.fail("test %d failed: got OverflowError, expected %r " + "for math.fsum(%.100r)" % (i, expected, vals)) + except ValueError: + self.fail("test %d failed: got ValueError, expected %r " + "for math.fsum(%.100r)" % (i, expected, vals)) + self.assertEqual(actual, expected) + + from random import random, gauss, shuffle + for j in range(1000): + vals = [7, 1e100, -7, -1e100, -9e-20, 8e-20] * 10 + s = 0 + for i in range(200): + v = gauss(0, random()) ** 7 - s + s += v + vals.append(v) + shuffle(vals) + + s = msum(vals) + self.assertEqual(msum(vals), math.fsum(vals)) + + self.assertEqual(math.fsum([1.0, math.inf]), math.inf) + self.assertTrue(math.isnan(math.fsum([math.nan, 1.0]))) + self.assertEqual(math.fsum([1e100, FloatLike(1.0), -1e100, 1e-100, + 1e50, FloatLike(-1.0), -1e50]), 1e-100) + self.assertRaises(OverflowError, math.fsum, [1e+308, 1e+308]) + self.assertRaises(ValueError, math.fsum, [math.inf, -math.inf]) + self.assertRaises(TypeError, math.fsum, ['spam']) + self.assertRaises(TypeError, math.fsum, 1) + self.assertRaises(OverflowError, math.fsum, [10**1000]) + + def bad_iter(): + yield 1.0 + raise ZeroDivisionError + + self.assertRaises(ZeroDivisionError, math.fsum, bad_iter()) + + def testGcd(self): + gcd = math.gcd + self.assertEqual(gcd(0, 0), 0) + self.assertEqual(gcd(1, 0), 1) + self.assertEqual(gcd(-1, 0), 1) + self.assertEqual(gcd(0, 1), 1) + self.assertEqual(gcd(0, -1), 1) + self.assertEqual(gcd(7, 1), 1) + self.assertEqual(gcd(7, -1), 1) + self.assertEqual(gcd(-23, 15), 1) + self.assertEqual(gcd(120, 84), 12) + self.assertEqual(gcd(84, -120), 12) + self.assertEqual(gcd(1216342683557601535506311712, + 436522681849110124616458784), 32) + + x = 434610456570399902378880679233098819019853229470286994367836600566 + y = 1064502245825115327754847244914921553977 + for c in (652560, + 576559230871654959816130551884856912003141446781646602790216406874): + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + + self.assertEqual(gcd(), 0) + self.assertEqual(gcd(120), 120) + self.assertEqual(gcd(-120), 120) + self.assertEqual(gcd(120, 84, 102), 6) + self.assertEqual(gcd(120, 1, 84), 1) + + self.assertRaises(TypeError, gcd, 120.0) + self.assertRaises(TypeError, gcd, 120.0, 84) + self.assertRaises(TypeError, gcd, 120, 84.0) + self.assertRaises(TypeError, gcd, 120, 1, 84.0) + self.assertEqual(gcd(MyIndexable(120), MyIndexable(84)), 12) + + def testHypot(self): + from decimal import Decimal + from fractions import Fraction + + hypot = math.hypot + + # Test different numbers of arguments (from zero to five) + # against a straightforward pure python implementation + args = math.e, math.pi, math.sqrt(2.0), math.gamma(3.5), math.sin(2.1) + for i in range(len(args)+1): + self.assertAlmostEqual( + hypot(*args[:i]), + math.sqrt(sum(s**2 for s in args[:i])) + ) + + # Test allowable types (those with __float__) + self.assertEqual(hypot(12.0, 5.0), 13.0) + self.assertEqual(hypot(12, 5), 13) + self.assertEqual(hypot(0.75, -1), 1.25) + self.assertEqual(hypot(-1, 0.75), 1.25) + self.assertEqual(hypot(0.75, FloatLike(-1.)), 1.25) + self.assertEqual(hypot(FloatLike(-1.), 0.75), 1.25) + self.assertEqual(hypot(Decimal(12), Decimal(5)), 13) + self.assertEqual(hypot(Fraction(12, 32), Fraction(5, 32)), Fraction(13, 32)) + self.assertEqual(hypot(True, False, True, True, True), 2.0) + + # Test corner cases + self.assertEqual(hypot(0.0, 0.0), 0.0) # Max input is zero + self.assertEqual(hypot(-10.5), 10.5) # Negative input + self.assertEqual(hypot(), 0.0) # Negative input + self.assertEqual(1.0, + math.copysign(1.0, hypot(-0.0)) # Convert negative zero to positive zero + ) + self.assertEqual( # Handling of moving max to the end + hypot(1.5, 1.5, 0.5), + hypot(1.5, 0.5, 1.5), + ) + + # Test handling of bad arguments + with self.assertRaises(TypeError): # Reject keyword args + hypot(x=1) + with self.assertRaises(TypeError): # Reject values without __float__ + hypot(1.1, 'string', 2.2) + int_too_big_for_float = 10 ** (sys.float_info.max_10_exp + 5) + with self.assertRaises((ValueError, OverflowError)): + hypot(1, int_too_big_for_float) + + # Any infinity gives positive infinity. + self.assertEqual(hypot(INF), INF) + self.assertEqual(hypot(0, INF), INF) + self.assertEqual(hypot(10, INF), INF) + self.assertEqual(hypot(-10, INF), INF) + self.assertEqual(hypot(NAN, INF), INF) + self.assertEqual(hypot(INF, NAN), INF) + self.assertEqual(hypot(NINF, NAN), INF) + self.assertEqual(hypot(NAN, NINF), INF) + self.assertEqual(hypot(-INF, INF), INF) + self.assertEqual(hypot(-INF, -INF), INF) + self.assertEqual(hypot(10, -INF), INF) + + # If no infinity, any NaN gives a NaN. + self.assertTrue(math.isnan(hypot(NAN))) + self.assertTrue(math.isnan(hypot(0, NAN))) + self.assertTrue(math.isnan(hypot(NAN, 10))) + self.assertTrue(math.isnan(hypot(10, NAN))) + self.assertTrue(math.isnan(hypot(NAN, NAN))) + self.assertTrue(math.isnan(hypot(NAN))) + + # Verify scaling for extremely large values + fourthmax = FLOAT_MAX / 4.0 + for n in range(32): + self.assertTrue(math.isclose(hypot(*([fourthmax]*n)), + fourthmax * math.sqrt(n))) + + # Verify scaling for extremely small values + for exp in range(32): + scale = FLOAT_MIN / 2.0 ** exp + self.assertEqual(math.hypot(4*scale, 3*scale), 5*scale) + + self.assertRaises(TypeError, math.hypot, *([1.0]*18), 'spam') + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "hypot() loses accuracy on machines with double rounding") + def testHypotAccuracy(self): + # Verify improved accuracy in cases that were known to be inaccurate. + # + # The new algorithm's accuracy depends on IEEE 754 arithmetic + # guarantees, on having the usual ROUND HALF EVEN rounding mode, on + # the system not having double rounding due to extended precision, + # and on the compiler maintaining the specified order of operations. + # + # This test is known to succeed on most of our builds. If it fails + # some build, we either need to add another skipIf if the cause is + # identifiable; otherwise, we can remove this test entirely. + + hypot = math.hypot + Decimal = decimal.Decimal + high_precision = decimal.Context(prec=500) + + for hx, hy in [ + # Cases with a 1 ulp error in Python 3.7 compiled with Clang + ('0x1.10e89518dca48p+29', '0x1.1970f7565b7efp+30'), + ('0x1.10106eb4b44a2p+29', '0x1.ef0596cdc97f8p+29'), + ('0x1.459c058e20bb7p+30', '0x1.993ca009b9178p+29'), + ('0x1.378371ae67c0cp+30', '0x1.fbe6619854b4cp+29'), + ('0x1.f4cd0574fb97ap+29', '0x1.50fe31669340ep+30'), + ('0x1.494b2cdd3d446p+29', '0x1.212a5367b4c7cp+29'), + ('0x1.f84e649f1e46dp+29', '0x1.1fa56bef8eec4p+30'), + ('0x1.2e817edd3d6fap+30', '0x1.eb0814f1e9602p+29'), + ('0x1.0d3a6e3d04245p+29', '0x1.32a62fea52352p+30'), + ('0x1.888e19611bfc5p+29', '0x1.52b8e70b24353p+29'), + + # Cases with 2 ulp error in Python 3.8 + ('0x1.538816d48a13fp+29', '0x1.7967c5ca43e16p+29'), + ('0x1.57b47b7234530p+29', '0x1.74e2c7040e772p+29'), + ('0x1.821b685e9b168p+30', '0x1.677dc1c1e3dc6p+29'), + ('0x1.9e8247f67097bp+29', '0x1.24bd2dc4f4baep+29'), + ('0x1.b73b59e0cb5f9p+29', '0x1.da899ab784a97p+28'), + ('0x1.94a8d2842a7cfp+30', '0x1.326a51d4d8d8ap+30'), + ('0x1.e930b9cd99035p+29', '0x1.5a1030e18dff9p+30'), + ('0x1.1592bbb0e4690p+29', '0x1.a9c337b33fb9ap+29'), + ('0x1.1243a50751fd4p+29', '0x1.a5a10175622d9p+29'), + ('0x1.57a8596e74722p+30', '0x1.42d1af9d04da9p+30'), + + # Cases with 1 ulp error in version fff3c28052e6b0 + ('0x1.ee7dbd9565899p+29', '0x1.7ab4d6fc6e4b4p+29'), + ('0x1.5c6bfbec5c4dcp+30', '0x1.02511184b4970p+30'), + ('0x1.59dcebba995cap+30', '0x1.50ca7e7c38854p+29'), + ('0x1.768cdd94cf5aap+29', '0x1.9cfdc5571d38ep+29'), + ('0x1.dcf137d60262ep+29', '0x1.1101621990b3ep+30'), + ('0x1.3a2d006e288b0p+30', '0x1.e9a240914326cp+29'), + ('0x1.62a32f7f53c61p+29', '0x1.47eb6cd72684fp+29'), + ('0x1.d3bcb60748ef2p+29', '0x1.3f13c4056312cp+30'), + ('0x1.282bdb82f17f3p+30', '0x1.640ba4c4eed3ap+30'), + ('0x1.89d8c423ea0c6p+29', '0x1.d35dcfe902bc3p+29'), + ]: + x = float.fromhex(hx) + y = float.fromhex(hy) + with self.subTest(hx=hx, hy=hy, x=x, y=y): + with decimal.localcontext(high_precision): + z = float((Decimal(x)**2 + Decimal(y)**2).sqrt()) + self.assertEqual(hypot(x, y), z) + + def testDist(self): + from decimal import Decimal as D + from fractions import Fraction as F + + dist = math.dist + sqrt = math.sqrt + + # Simple exact cases + self.assertEqual(dist((1.0, 2.0, 3.0), (4.0, 2.0, -1.0)), 5.0) + self.assertEqual(dist((1, 2, 3), (4, 2, -1)), 5.0) + + # Test different numbers of arguments (from zero to nine) + # against a straightforward pure python implementation + for i in range(9): + for j in range(5): + p = tuple(random.uniform(-5, 5) for k in range(i)) + q = tuple(random.uniform(-5, 5) for k in range(i)) + self.assertAlmostEqual( + dist(p, q), + sqrt(sum((px - qx) ** 2.0 for px, qx in zip(p, q))) + ) + + # Test non-tuple inputs + self.assertEqual(dist([1.0, 2.0, 3.0], [4.0, 2.0, -1.0]), 5.0) + self.assertEqual(dist(iter([1.0, 2.0, 3.0]), iter([4.0, 2.0, -1.0])), 5.0) + + # Test allowable types (those with __float__) + self.assertEqual(dist((14.0, 1.0), (2.0, -4.0)), 13.0) + self.assertEqual(dist((14, 1), (2, -4)), 13) + self.assertEqual(dist((FloatLike(14.), 1), (2, -4)), 13) + self.assertEqual(dist((11, 1), (FloatLike(-1.), -4)), 13) + self.assertEqual(dist((14, FloatLike(-1.)), (2, -6)), 13) + self.assertEqual(dist((14, -1), (2, -6)), 13) + self.assertEqual(dist((D(14), D(1)), (D(2), D(-4))), D(13)) + self.assertEqual(dist((F(14, 32), F(1, 32)), (F(2, 32), F(-4, 32))), + F(13, 32)) + self.assertEqual(dist((True, True, False, False, True, True), + (True, False, True, False, False, False)), + 2.0) + + # Test corner cases + self.assertEqual(dist((13.25, 12.5, -3.25), + (13.25, 12.5, -3.25)), + 0.0) # Distance with self is zero + self.assertEqual(dist((), ()), 0.0) # Zero-dimensional case + self.assertEqual(1.0, # Convert negative zero to positive zero + math.copysign(1.0, dist((-0.0,), (0.0,))) + ) + self.assertEqual(1.0, # Convert negative zero to positive zero + math.copysign(1.0, dist((0.0,), (-0.0,))) + ) + self.assertEqual( # Handling of moving max to the end + dist((1.5, 1.5, 0.5), (0, 0, 0)), + dist((1.5, 0.5, 1.5), (0, 0, 0)) + ) + + # Verify tuple subclasses are allowed + class T(tuple): + pass + self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0) + + # Test handling of bad arguments + with self.assertRaises(TypeError): # Reject keyword args + dist(p=(1, 2, 3), q=(4, 5, 6)) + with self.assertRaises(TypeError): # Too few args + dist((1, 2, 3)) + with self.assertRaises(TypeError): # Too many args + dist((1, 2, 3), (4, 5, 6), (7, 8, 9)) + with self.assertRaises(TypeError): # Scalars not allowed + dist(1, 2) + with self.assertRaises(TypeError): # Reject values without __float__ + dist((1.1, 'string', 2.2), (1, 2, 3)) + with self.assertRaises(ValueError): # Check dimension agree + dist((1, 2, 3, 4), (5, 6, 7)) + with self.assertRaises(ValueError): # Check dimension agree + dist((1, 2, 3), (4, 5, 6, 7)) + with self.assertRaises(TypeError): + dist((1,)*17 + ("spam",), (1,)*18) + with self.assertRaises(TypeError): # Rejects invalid types + dist("abc", "xyz") + int_too_big_for_float = 10 ** (sys.float_info.max_10_exp + 5) + with self.assertRaises((ValueError, OverflowError)): + dist((1, int_too_big_for_float), (2, 3)) + with self.assertRaises((ValueError, OverflowError)): + dist((2, 3), (1, int_too_big_for_float)) + with self.assertRaises(TypeError): + dist((1,), 2) + with self.assertRaises(TypeError): + dist([1], 2) + + class BadFloat: + __float__ = BadDescr() + + with self.assertRaises(ValueError): + dist([1], [BadFloat()]) + + # Verify that the one dimensional case is equivalent to abs() + for i in range(20): + p, q = random.random(), random.random() + self.assertEqual(dist((p,), (q,)), abs(p - q)) + + # Test special values + values = [NINF, -10.5, -0.0, 0.0, 10.5, INF, NAN] + for p in itertools.product(values, repeat=3): + for q in itertools.product(values, repeat=3): + diffs = [px - qx for px, qx in zip(p, q)] + if any(map(math.isinf, diffs)): + # Any infinite difference gives positive infinity. + self.assertEqual(dist(p, q), INF) + elif any(map(math.isnan, diffs)): + # If no infinity, any NaN gives a NaN. + self.assertTrue(math.isnan(dist(p, q))) + + # Verify scaling for extremely large values + fourthmax = FLOAT_MAX / 4.0 + for n in range(32): + p = (fourthmax,) * n + q = (0.0,) * n + self.assertTrue(math.isclose(dist(p, q), fourthmax * math.sqrt(n))) + self.assertTrue(math.isclose(dist(q, p), fourthmax * math.sqrt(n))) + + # Verify scaling for extremely small values + for exp in range(32): + scale = FLOAT_MIN / 2.0 ** exp + p = (4*scale, 3*scale) + q = (0.0, 0.0) + self.assertEqual(math.dist(p, q), 5*scale) + self.assertEqual(math.dist(q, p), 5*scale) + + def test_math_dist_leak(self): + # gh-98897: Check for error handling does not leak memory + with self.assertRaises(ValueError): + math.dist([1, 2], [3, 4, 5]) + + @slowTest + def testIsqrt(self): + # Test a variety of inputs, large and small. + test_values = ( + list(range(1000)) + + list(range(10**6 - 1000, 10**6 + 1000)) + + [2**e + i for e in range(60, 200) for i in range(-40, 40)] + + [3**9999, 10**5001] + ) + + for value in test_values: + with self.subTest(value=value): + s = math.isqrt(value) + self.assertIs(type(s), int) + self.assertLessEqual(s*s, value) + self.assertLess(value, (s+1)*(s+1)) + + # Negative values + with self.assertRaises(ValueError): + math.isqrt(-1) + + # Integer-like things + s = math.isqrt(True) + self.assertIs(type(s), int) + self.assertEqual(s, 1) + + s = math.isqrt(False) + self.assertIs(type(s), int) + self.assertEqual(s, 0) + + class IntegerLike(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + + s = math.isqrt(IntegerLike(1729)) + self.assertIs(type(s), int) + self.assertEqual(s, 41) + + with self.assertRaises(ValueError): + math.isqrt(IntegerLike(-3)) + + # Non-integer-like things + bad_values = [ + 3.5, "a string", decimal.Decimal("3.5"), 3.5j, + 100.0, -4.0, + ] + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(TypeError): + math.isqrt(value) + + def test_lcm(self): + lcm = math.lcm + self.assertEqual(lcm(0, 0), 0) + self.assertEqual(lcm(1, 0), 0) + self.assertEqual(lcm(-1, 0), 0) + self.assertEqual(lcm(0, 1), 0) + self.assertEqual(lcm(0, -1), 0) + self.assertEqual(lcm(7, 1), 7) + self.assertEqual(lcm(7, -1), 7) + self.assertEqual(lcm(-23, 15), 345) + self.assertEqual(lcm(120, 84), 840) + self.assertEqual(lcm(84, -120), 840) + self.assertEqual(lcm(1216342683557601535506311712, + 436522681849110124616458784), + 16592536571065866494401400422922201534178938447014944) + + x = 43461045657039990237 + y = 10645022458251153277 + for c in (652560, + 57655923087165495981): + a = x * c + b = y * c + d = x * y * c + self.assertEqual(lcm(a, b), d) + self.assertEqual(lcm(b, a), d) + self.assertEqual(lcm(-a, b), d) + self.assertEqual(lcm(b, -a), d) + self.assertEqual(lcm(a, -b), d) + self.assertEqual(lcm(-b, a), d) + self.assertEqual(lcm(-a, -b), d) + self.assertEqual(lcm(-b, -a), d) + + self.assertEqual(lcm(), 1) + self.assertEqual(lcm(120), 120) + self.assertEqual(lcm(-120), 120) + self.assertEqual(lcm(120, 84, 102), 14280) + self.assertEqual(lcm(120, 0, 84), 0) + + self.assertRaises(TypeError, lcm, 120.0) + self.assertRaises(TypeError, lcm, 120.0, 84) + self.assertRaises(TypeError, lcm, 120, 84.0) + self.assertRaises(TypeError, lcm, 120, 0, 84.0) + self.assertEqual(lcm(MyIndexable(120), MyIndexable(84)), 840) + + def testLdexp(self): + self.assertRaises(TypeError, math.ldexp) + self.assertRaises(TypeError, math.ldexp, 2.0, 1.1) + self.ftest('ldexp(0,1)', math.ldexp(0,1), 0) + self.ftest('ldexp(1,1)', math.ldexp(1,1), 2) + self.ftest('ldexp(1,-1)', math.ldexp(1,-1), 0.5) + self.ftest('ldexp(-1,1)', math.ldexp(-1,1), -2) + self.assertRaises(OverflowError, math.ldexp, 1., 1000000) + self.assertRaises(OverflowError, math.ldexp, -1., 1000000) + self.assertEqual(math.ldexp(1., -1000000), 0.) + self.assertEqual(math.ldexp(-1., -1000000), -0.) + self.assertEqual(math.ldexp(INF, 30), INF) + self.assertEqual(math.ldexp(NINF, -213), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, 0))) + + # large second argument + for n in [10**5, 10**10, 10**20, 10**40]: + self.assertEqual(math.ldexp(INF, -n), INF) + self.assertEqual(math.ldexp(NINF, -n), NINF) + self.assertEqual(math.ldexp(1., -n), 0.) + self.assertEqual(math.ldexp(-1., -n), -0.) + self.assertEqual(math.ldexp(0., -n), 0.) + self.assertEqual(math.ldexp(-0., -n), -0.) + self.assertTrue(math.isnan(math.ldexp(NAN, -n))) + + self.assertRaises(OverflowError, math.ldexp, 1., n) + self.assertRaises(OverflowError, math.ldexp, -1., n) + self.assertEqual(math.ldexp(0., n), 0.) + self.assertEqual(math.ldexp(-0., n), -0.) + self.assertEqual(math.ldexp(INF, n), INF) + self.assertEqual(math.ldexp(NINF, n), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, n))) + + def testLog(self): + self.assertRaises(TypeError, math.log) + self.assertRaises(TypeError, math.log, 1, 2, 3) + self.ftest('log(1/e)', math.log(1/math.e), -1) + self.ftest('log(1)', math.log(1), 0) + self.ftest('log(e)', math.log(math.e), 1) + self.ftest('log(32,2)', math.log(32,2), 5) + self.ftest('log(10**40, 10)', math.log(10**40, 10), 40) + self.ftest('log(10**40, 10**20)', math.log(10**40, 10**20), 2) + self.ftest('log(10**1000)', math.log(10**1000), + 2302.5850929940457) + self.assertRaises(ValueError, math.log, -1.5) + self.assertRaises(ValueError, math.log, -10**1000) + self.assertRaises(ValueError, math.log, 10, -10) + self.assertRaises(ValueError, math.log, NINF) + self.assertEqual(math.log(INF), INF) + self.assertTrue(math.isnan(math.log(NAN))) + + def testLog1p(self): + self.assertRaises(TypeError, math.log1p) + for n in [2, 2**90, 2**300]: + self.assertAlmostEqual(math.log1p(n), math.log1p(float(n))) + self.assertRaises(ValueError, math.log1p, -1) + self.assertEqual(math.log1p(INF), INF) + + @skipIfTorchDynamo("Infinite loop") + @requires_IEEE_754 + def testLog2(self): + self.assertRaises(TypeError, math.log2) + + # Check some integer values + self.assertEqual(math.log2(1), 0.0) + self.assertEqual(math.log2(2), 1.0) + self.assertEqual(math.log2(4), 2.0) + + # Large integer values + self.assertEqual(math.log2(2**1023), 1023.0) + self.assertEqual(math.log2(2**1024), 1024.0) + self.assertEqual(math.log2(2**2000), 2000.0) + + self.assertRaises(ValueError, math.log2, -1.5) + self.assertRaises(ValueError, math.log2, NINF) + self.assertTrue(math.isnan(math.log2(NAN))) + + @skipIfTorchDynamo("Infinite loop") + @requires_IEEE_754 + # log2() is not accurate enough on Mac OS X Tiger (10.4) + @support.requires_mac_ver(10, 5) + def testLog2Exact(self): + # Check that we get exact equality for log2 of powers of 2. + actual = [math.log2(math.ldexp(1.0, n)) for n in range(-1074, 1024)] + expected = [float(n) for n in range(-1074, 1024)] + self.assertEqual(actual, expected) + + def testLog10(self): + self.assertRaises(TypeError, math.log10) + self.ftest('log10(0.1)', math.log10(0.1), -1) + self.ftest('log10(1)', math.log10(1), 0) + self.ftest('log10(10)', math.log10(10), 1) + self.ftest('log10(10**1000)', math.log10(10**1000), 1000.0) + self.assertRaises(ValueError, math.log10, -1.5) + self.assertRaises(ValueError, math.log10, -10**1000) + self.assertRaises(ValueError, math.log10, NINF) + self.assertEqual(math.log(INF), INF) + self.assertTrue(math.isnan(math.log10(NAN))) + + def testSumProd(self): + sumprod = math.sumprod + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + # Core functionality + self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140) + self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5) + self.assertEqual(sumprod([], []), 0) + self.assertEqual(sumprod([-1], [1.]), -1) + self.assertEqual(sumprod([1.], [-1]), -1) + + # Type preservation and coercion + for v in [ + (10, 20, 30), + (1.5, -2.5), + (Fraction(3, 5), Fraction(4, 5)), + (Decimal(3.5), Decimal(4.5)), + (2.5, 10), # float/int + (2.5, Fraction(3, 5)), # float/fraction + (25, Fraction(3, 5)), # int/fraction + (25, Decimal(4.5)), # int/decimal + ]: + for p, q in [(v, v), (v, v[::-1])]: + with self.subTest(p=p, q=q): + expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True)) + actual = sumprod(p, q) + self.assertEqual(expected, actual) + self.assertEqual(type(expected), type(actual)) + + # Bad arguments + self.assertRaises(TypeError, sumprod) # No args + self.assertRaises(TypeError, sumprod, []) # One arg + self.assertRaises(TypeError, sumprod, [], [], []) # Three args + self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable + self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable + self.assertRaises(TypeError, sumprod, ['x'], [1.0]) + + # Uneven lengths + self.assertRaises(ValueError, sumprod, [10, 20], [30]) + self.assertRaises(ValueError, sumprod, [10], [20, 30]) + + # Overflows + self.assertEqual(sumprod([10**20], [1]), 10**20) + self.assertEqual(sumprod([1], [10**20]), 10**20) + self.assertEqual(sumprod([10**10], [10**10]), 10**20) + self.assertEqual(sumprod([10**7]*10**5, [10**7]*10**5), 10**19) + self.assertRaises(OverflowError, sumprod, [10**1000], [1.0]) + self.assertRaises(OverflowError, sumprod, [1.0], [10**1000]) + + # Error in iterator + def raise_after(n): + for i in range(n): + yield i + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod(range(10), raise_after(5)) + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + + from test_iter import BasicIterClass + + self.assertEqual(sumprod(BasicIterClass(1), [1]), 0) + self.assertEqual(sumprod([1], BasicIterClass(1)), 0) + + # Error in multiplication + class BadMultiply: + def __mul__(self, other): + raise RuntimeError + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod([10, BadMultiply(), 30], [1, 2, 3]) + with self.assertRaises(RuntimeError): + sumprod([1, 2, 3], [10, BadMultiply(), 30]) + + # Error in addition + with self.assertRaises(TypeError): + sumprod(['abc', 3], [5, 10]) + with self.assertRaises(TypeError): + sumprod([5, 10], ['abc', 3]) + + # Special values should give the same as the pure python recipe + self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf) + self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf]))) + self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan]))) + + # Error cases that arose during development + args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952)) + self.assertEqual(sumprod(*args), 0.0) + + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + def test_sumprod_accuracy(self): + sumprod = math.sumprod + self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0) + self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0) + self.assertEqual(sumprod([True, False] * 10, [0.1] * 20), 1.0) + self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) + + @support.requires_resource('cpu') + def test_sumprod_stress(self): + sumprod = math.sumprod + product = itertools.product + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + class Int(int): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Int({int(self)})' + + class Flt(float): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Flt({int(self)})' + + def baseline_sumprod(p, q): + """This defines the target behavior including exceptions and special values. + However, it is subject to rounding errors, so float inputs should be exactly + representable with only a few bits. + """ + total = 0 + for p_i, q_i in zip(p, q, strict=True): + total += p_i * q_i + return total + + def run(func, *args): + "Make comparing functions easier. Returns error status, type, and result." + try: + result = func(*args) + except (AssertionError, NameError): + raise + except Exception as e: + return type(e), None, 'None' + return None, type(result), repr(result) + + pools = [ + (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)), + (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125), + (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333, + 5.25, -3.25, -3.0*2**(-333), 3, 2**513), + (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14, + 9, 3+4j, Flt(13), 0.0), + (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8), + Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)), + (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0), + Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5), + (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538, + 2*2**-513), + (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25), + (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)), + ] + + for pool in pools: + for size in range(4): + for args1 in product(pool, repeat=size): + for args2 in product(pool, repeat=size): + args = (args1, args2) + self.assertEqual( + run(baseline_sumprod, *args), + run(sumprod, *args), + args, + ) + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + @support.requires_resource('cpu') + def test_sumprod_extended_precision_accuracy(self): + import operator + from fractions import Fraction + from itertools import starmap + from collections import namedtuple + from math import log2, exp2, fabs + from random import choices, uniform, shuffle + from statistics import median + + DotExample = namedtuple('DotExample', ('x', 'y', 'target_sumprod', 'condition')) + + def DotExact(x, y): + vec1 = map(Fraction, x) + vec2 = map(Fraction, y) + return sum(starmap(operator.mul, zip(vec1, vec2, strict=True))) + + def Condition(x, y): + return 2.0 * DotExact(map(abs, x), map(abs, y)) / abs(DotExact(x, y)) + + def linspace(lo, hi, n): + width = (hi - lo) / (n - 1) + return [lo + width * i for i in range(n)] + + def GenDot(n, c): + """ Algorithm 6.1 (GenDot) works as follows. The condition number (5.7) of + the dot product xT y is proportional to the degree of cancellation. In + order to achieve a prescribed cancellation, we generate the first half of + the vectors x and y randomly within a large exponent range. This range is + chosen according to the anticipated condition number. The second half of x + and y is then constructed choosing xi randomly with decreasing exponent, + and calculating yi such that some cancellation occurs. Finally, we permute + the vectors x, y randomly and calculate the achieved condition number. + """ + + assert n >= 6 + n2 = n // 2 + x = [0.0] * n + y = [0.0] * n + b = log2(c) + + # First half with exponents from 0 to |_b/2_| and random ints in between + e = choices(range(int(b/2)), k=n2) + e[0] = int(b / 2) + 1 + e[-1] = 0.0 + + x[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e] + y[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e] + + # Second half + e = list(map(round, linspace(b/2, 0.0 , n-n2))) + for i in range(n2, n): + x[i] = uniform(-1.0, 1.0) * exp2(e[i - n2]) + y[i] = (uniform(-1.0, 1.0) * exp2(e[i - n2]) - DotExact(x, y)) / x[i] + + # Shuffle + pairs = list(zip(x, y)) + shuffle(pairs) + x, y = zip(*pairs) + + return DotExample(x, y, DotExact(x, y), Condition(x, y)) + + def RelativeError(res, ex): + x, y, target_sumprod, condition = ex + n = DotExact(list(x) + [-res], list(y) + [1]) + return fabs(n / target_sumprod) + + def Trial(dotfunc, c, n): + ex = GenDot(10, c) + res = dotfunc(ex.x, ex.y) + return RelativeError(res, ex) + + times = 1000 # Number of trials + n = 20 # Length of vectors + c = 1e30 # Target condition number + + # If the following test fails, it means that the C math library + # implementation of fma() is not compliant with the C99 standard + # and is inaccurate. To solve this problem, make a new build + # with the symbol UNRELIABLE_FMA defined. That will enable a + # slower but accurate code path that avoids the fma() call. + relative_err = median(Trial(math.sumprod, c, n) for i in range(times)) + self.assertLess(relative_err, 1e-16) + + def testModf(self): + self.assertRaises(TypeError, math.modf) + + def testmodf(name, result, expected): + (v1, v2), (e1, e2) = result, expected + if abs(v1-e1) > eps or abs(v2-e2): + self.fail('%s returned %r, expected %r'%\ + (name, result, expected)) + + testmodf('modf(1.5)', math.modf(1.5), (0.5, 1.0)) + testmodf('modf(-1.5)', math.modf(-1.5), (-0.5, -1.0)) + + self.assertEqual(math.modf(INF), (0.0, INF)) + self.assertEqual(math.modf(NINF), (-0.0, NINF)) + + modf_nan = math.modf(NAN) + self.assertTrue(math.isnan(modf_nan[0])) + self.assertTrue(math.isnan(modf_nan[1])) + + def testPow(self): + self.assertRaises(TypeError, math.pow) + self.ftest('pow(0,1)', math.pow(0,1), 0) + self.ftest('pow(1,0)', math.pow(1,0), 1) + self.ftest('pow(2,1)', math.pow(2,1), 2) + self.ftest('pow(2,-1)', math.pow(2,-1), 0.5) + self.assertEqual(math.pow(INF, 1), INF) + self.assertEqual(math.pow(NINF, 1), NINF) + self.assertEqual((math.pow(1, INF)), 1.) + self.assertEqual((math.pow(1, NINF)), 1.) + self.assertTrue(math.isnan(math.pow(NAN, 1))) + self.assertTrue(math.isnan(math.pow(2, NAN))) + self.assertTrue(math.isnan(math.pow(0, NAN))) + self.assertEqual(math.pow(1, NAN), 1) + self.assertRaises(OverflowError, math.pow, 1e+100, 1e+100) + + # pow(0., x) + self.assertEqual(math.pow(0., INF), 0.) + self.assertEqual(math.pow(0., 3.), 0.) + self.assertEqual(math.pow(0., 2.3), 0.) + self.assertEqual(math.pow(0., 2.), 0.) + self.assertEqual(math.pow(0., 0.), 1.) + self.assertEqual(math.pow(0., -0.), 1.) + self.assertRaises(ValueError, math.pow, 0., -2.) + self.assertRaises(ValueError, math.pow, 0., -2.3) + self.assertRaises(ValueError, math.pow, 0., -3.) + self.assertEqual(math.pow(0., NINF), INF) + self.assertTrue(math.isnan(math.pow(0., NAN))) + + # pow(INF, x) + self.assertEqual(math.pow(INF, INF), INF) + self.assertEqual(math.pow(INF, 3.), INF) + self.assertEqual(math.pow(INF, 2.3), INF) + self.assertEqual(math.pow(INF, 2.), INF) + self.assertEqual(math.pow(INF, 0.), 1.) + self.assertEqual(math.pow(INF, -0.), 1.) + self.assertEqual(math.pow(INF, -2.), 0.) + self.assertEqual(math.pow(INF, -2.3), 0.) + self.assertEqual(math.pow(INF, -3.), 0.) + self.assertEqual(math.pow(INF, NINF), 0.) + self.assertTrue(math.isnan(math.pow(INF, NAN))) + + # pow(-0., x) + self.assertEqual(math.pow(-0., INF), 0.) + self.assertEqual(math.pow(-0., 3.), -0.) + self.assertEqual(math.pow(-0., 2.3), 0.) + self.assertEqual(math.pow(-0., 2.), 0.) + self.assertEqual(math.pow(-0., 0.), 1.) + self.assertEqual(math.pow(-0., -0.), 1.) + self.assertRaises(ValueError, math.pow, -0., -2.) + self.assertRaises(ValueError, math.pow, -0., -2.3) + self.assertRaises(ValueError, math.pow, -0., -3.) + self.assertEqual(math.pow(-0., NINF), INF) + self.assertTrue(math.isnan(math.pow(-0., NAN))) + + # pow(NINF, x) + self.assertEqual(math.pow(NINF, INF), INF) + self.assertEqual(math.pow(NINF, 3.), NINF) + self.assertEqual(math.pow(NINF, 2.3), INF) + self.assertEqual(math.pow(NINF, 2.), INF) + self.assertEqual(math.pow(NINF, 0.), 1.) + self.assertEqual(math.pow(NINF, -0.), 1.) + self.assertEqual(math.pow(NINF, -2.), 0.) + self.assertEqual(math.pow(NINF, -2.3), 0.) + self.assertEqual(math.pow(NINF, -3.), -0.) + self.assertEqual(math.pow(NINF, NINF), 0.) + self.assertTrue(math.isnan(math.pow(NINF, NAN))) + + # pow(-1, x) + self.assertEqual(math.pow(-1., INF), 1.) + self.assertEqual(math.pow(-1., 3.), -1.) + self.assertRaises(ValueError, math.pow, -1., 2.3) + self.assertEqual(math.pow(-1., 2.), 1.) + self.assertEqual(math.pow(-1., 0.), 1.) + self.assertEqual(math.pow(-1., -0.), 1.) + self.assertEqual(math.pow(-1., -2.), 1.) + self.assertRaises(ValueError, math.pow, -1., -2.3) + self.assertEqual(math.pow(-1., -3.), -1.) + self.assertEqual(math.pow(-1., NINF), 1.) + self.assertTrue(math.isnan(math.pow(-1., NAN))) + + # pow(1, x) + self.assertEqual(math.pow(1., INF), 1.) + self.assertEqual(math.pow(1., 3.), 1.) + self.assertEqual(math.pow(1., 2.3), 1.) + self.assertEqual(math.pow(1., 2.), 1.) + self.assertEqual(math.pow(1., 0.), 1.) + self.assertEqual(math.pow(1., -0.), 1.) + self.assertEqual(math.pow(1., -2.), 1.) + self.assertEqual(math.pow(1., -2.3), 1.) + self.assertEqual(math.pow(1., -3.), 1.) + self.assertEqual(math.pow(1., NINF), 1.) + self.assertEqual(math.pow(1., NAN), 1.) + + # pow(x, 0) should be 1 for any x + self.assertEqual(math.pow(2.3, 0.), 1.) + self.assertEqual(math.pow(-2.3, 0.), 1.) + self.assertEqual(math.pow(NAN, 0.), 1.) + self.assertEqual(math.pow(2.3, -0.), 1.) + self.assertEqual(math.pow(-2.3, -0.), 1.) + self.assertEqual(math.pow(NAN, -0.), 1.) + + # pow(x, y) is invalid if x is negative and y is not integral + self.assertRaises(ValueError, math.pow, -1., 2.3) + self.assertRaises(ValueError, math.pow, -15., -3.1) + + # pow(x, NINF) + self.assertEqual(math.pow(1.9, NINF), 0.) + self.assertEqual(math.pow(1.1, NINF), 0.) + self.assertEqual(math.pow(0.9, NINF), INF) + self.assertEqual(math.pow(0.1, NINF), INF) + self.assertEqual(math.pow(-0.1, NINF), INF) + self.assertEqual(math.pow(-0.9, NINF), INF) + self.assertEqual(math.pow(-1.1, NINF), 0.) + self.assertEqual(math.pow(-1.9, NINF), 0.) + + # pow(x, INF) + self.assertEqual(math.pow(1.9, INF), INF) + self.assertEqual(math.pow(1.1, INF), INF) + self.assertEqual(math.pow(0.9, INF), 0.) + self.assertEqual(math.pow(0.1, INF), 0.) + self.assertEqual(math.pow(-0.1, INF), 0.) + self.assertEqual(math.pow(-0.9, INF), 0.) + self.assertEqual(math.pow(-1.1, INF), INF) + self.assertEqual(math.pow(-1.9, INF), INF) + + # pow(x, y) should work for x negative, y an integer + self.ftest('(-2.)**3.', math.pow(-2.0, 3.0), -8.0) + self.ftest('(-2.)**2.', math.pow(-2.0, 2.0), 4.0) + self.ftest('(-2.)**1.', math.pow(-2.0, 1.0), -2.0) + self.ftest('(-2.)**0.', math.pow(-2.0, 0.0), 1.0) + self.ftest('(-2.)**-0.', math.pow(-2.0, -0.0), 1.0) + self.ftest('(-2.)**-1.', math.pow(-2.0, -1.0), -0.5) + self.ftest('(-2.)**-2.', math.pow(-2.0, -2.0), 0.25) + self.ftest('(-2.)**-3.', math.pow(-2.0, -3.0), -0.125) + self.assertRaises(ValueError, math.pow, -2.0, -0.5) + self.assertRaises(ValueError, math.pow, -2.0, 0.5) + + # the following tests have been commented out since they don't + # really belong here: the implementation of ** for floats is + # independent of the implementation of math.pow + #self.assertEqual(1**NAN, 1) + #self.assertEqual(1**INF, 1) + #self.assertEqual(1**NINF, 1) + #self.assertEqual(1**0, 1) + #self.assertEqual(1.**NAN, 1) + #self.assertEqual(1.**INF, 1) + #self.assertEqual(1.**NINF, 1) + #self.assertEqual(1.**0, 1) + + def testRadians(self): + self.assertRaises(TypeError, math.radians) + self.ftest('radians(180)', math.radians(180), math.pi) + self.ftest('radians(90)', math.radians(90), math.pi/2) + self.ftest('radians(-45)', math.radians(-45), -math.pi/4) + self.ftest('radians(0)', math.radians(0), 0) + + @requires_IEEE_754 + def testRemainder(self): + from fractions import Fraction + + def validate_spec(x, y, r): + """ + Check that r matches remainder(x, y) according to the IEEE 754 + specification. Assumes that x, y and r are finite and y is nonzero. + """ + fx, fy, fr = Fraction(x), Fraction(y), Fraction(r) + # r should not exceed y/2 in absolute value + self.assertLessEqual(abs(fr), abs(fy/2)) + # x - r should be an exact integer multiple of y + n = (fx - fr) / fy + self.assertEqual(n, int(n)) + if abs(fr) == abs(fy/2): + # If |r| == |y/2|, n should be even. + self.assertEqual(n/2, int(n/2)) + + # triples (x, y, remainder(x, y)) in hexadecimal form. + testcases = [ + # Remainders modulo 1, showing the ties-to-even behaviour. + '-4.0 1 -0.0', + '-3.8 1 0.8', + '-3.0 1 -0.0', + '-2.8 1 -0.8', + '-2.0 1 -0.0', + '-1.8 1 0.8', + '-1.0 1 -0.0', + '-0.8 1 -0.8', + '-0.0 1 -0.0', + ' 0.0 1 0.0', + ' 0.8 1 0.8', + ' 1.0 1 0.0', + ' 1.8 1 -0.8', + ' 2.0 1 0.0', + ' 2.8 1 0.8', + ' 3.0 1 0.0', + ' 3.8 1 -0.8', + ' 4.0 1 0.0', + + # Reductions modulo 2*pi + '0x0.0p+0 0x1.921fb54442d18p+2 0x0.0p+0', + '0x1.921fb54442d18p+0 0x1.921fb54442d18p+2 0x1.921fb54442d18p+0', + '0x1.921fb54442d17p+1 0x1.921fb54442d18p+2 0x1.921fb54442d17p+1', + '0x1.921fb54442d18p+1 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1', + '0x1.921fb54442d19p+1 0x1.921fb54442d18p+2 -0x1.921fb54442d17p+1', + '0x1.921fb54442d17p+2 0x1.921fb54442d18p+2 -0x0.0000000000001p+2', + '0x1.921fb54442d18p+2 0x1.921fb54442d18p+2 0x0p0', + '0x1.921fb54442d19p+2 0x1.921fb54442d18p+2 0x0.0000000000001p+2', + '0x1.2d97c7f3321d1p+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1', + '0x1.2d97c7f3321d2p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d18p+1', + '0x1.2d97c7f3321d3p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1', + '0x1.921fb54442d17p+3 0x1.921fb54442d18p+2 -0x0.0000000000001p+3', + '0x1.921fb54442d18p+3 0x1.921fb54442d18p+2 0x0p0', + '0x1.921fb54442d19p+3 0x1.921fb54442d18p+2 0x0.0000000000001p+3', + '0x1.f6a7a2955385dp+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1', + '0x1.f6a7a2955385ep+3 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1', + '0x1.f6a7a2955385fp+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1', + '0x1.1475cc9eedf00p+5 0x1.921fb54442d18p+2 0x1.921fb54442d10p+1', + '0x1.1475cc9eedf01p+5 0x1.921fb54442d18p+2 -0x1.921fb54442d10p+1', + + # Symmetry with respect to signs. + ' 1 0.c 0.4', + '-1 0.c -0.4', + ' 1 -0.c 0.4', + '-1 -0.c -0.4', + ' 1.4 0.c -0.4', + '-1.4 0.c 0.4', + ' 1.4 -0.c -0.4', + '-1.4 -0.c 0.4', + + # Huge modulus, to check that the underlying algorithm doesn't + # rely on 2.0 * modulus being representable. + '0x1.dp+1023 0x1.4p+1023 0x0.9p+1023', + '0x1.ep+1023 0x1.4p+1023 -0x0.ap+1023', + '0x1.fp+1023 0x1.4p+1023 -0x0.9p+1023', + ] + + for case in testcases: + with self.subTest(case=case): + x_hex, y_hex, expected_hex = case.split() + x = float.fromhex(x_hex) + y = float.fromhex(y_hex) + expected = float.fromhex(expected_hex) + validate_spec(x, y, expected) + actual = math.remainder(x, y) + # Cheap way of checking that the floats are + # as identical as we need them to be. + self.assertEqual(actual.hex(), expected.hex()) + + # Test tiny subnormal modulus: there's potential for + # getting the implementation wrong here (for example, + # by assuming that modulus/2 is exactly representable). + tiny = float.fromhex('1p-1074') # min +ve subnormal + for n in range(-25, 25): + if n == 0: + continue + y = n * tiny + for m in range(100): + x = m * tiny + actual = math.remainder(x, y) + validate_spec(x, y, actual) + actual = math.remainder(-x, y) + validate_spec(-x, y, actual) + + # Special values. + # NaNs should propagate as usual. + for value in [NAN, 0.0, -0.0, 2.0, -2.3, NINF, INF]: + self.assertIsNaN(math.remainder(NAN, value)) + self.assertIsNaN(math.remainder(value, NAN)) + + # remainder(x, inf) is x, for non-nan non-infinite x. + for value in [-2.3, -0.0, 0.0, 2.3]: + self.assertEqual(math.remainder(value, INF), value) + self.assertEqual(math.remainder(value, NINF), value) + + # remainder(x, 0) and remainder(infinity, x) for non-NaN x are invalid + # operations according to IEEE 754-2008 7.2(f), and should raise. + for value in [NINF, -2.3, -0.0, 0.0, 2.3, INF]: + with self.assertRaises(ValueError): + math.remainder(INF, value) + with self.assertRaises(ValueError): + math.remainder(NINF, value) + with self.assertRaises(ValueError): + math.remainder(value, 0.0) + with self.assertRaises(ValueError): + math.remainder(value, -0.0) + + def testSin(self): + self.assertRaises(TypeError, math.sin) + self.ftest('sin(0)', math.sin(0), 0) + self.ftest('sin(pi/2)', math.sin(math.pi/2), 1) + self.ftest('sin(-pi/2)', math.sin(-math.pi/2), -1) + try: + self.assertTrue(math.isnan(math.sin(INF))) + self.assertTrue(math.isnan(math.sin(NINF))) + except ValueError: + self.assertRaises(ValueError, math.sin, INF) + self.assertRaises(ValueError, math.sin, NINF) + self.assertTrue(math.isnan(math.sin(NAN))) + + def testSinh(self): + self.assertRaises(TypeError, math.sinh) + self.ftest('sinh(0)', math.sinh(0), 0) + self.ftest('sinh(1)**2-cosh(1)**2', math.sinh(1)**2-math.cosh(1)**2, -1) + self.ftest('sinh(1)+sinh(-1)', math.sinh(1)+math.sinh(-1), 0) + self.assertEqual(math.sinh(INF), INF) + self.assertEqual(math.sinh(NINF), NINF) + self.assertTrue(math.isnan(math.sinh(NAN))) + + def testSqrt(self): + self.assertRaises(TypeError, math.sqrt) + self.ftest('sqrt(0)', math.sqrt(0), 0) + self.ftest('sqrt(0)', math.sqrt(0.0), 0.0) + self.ftest('sqrt(2.5)', math.sqrt(2.5), 1.5811388300841898) + self.ftest('sqrt(0.25)', math.sqrt(0.25), 0.5) + self.ftest('sqrt(25.25)', math.sqrt(25.25), 5.024937810560445) + self.ftest('sqrt(1)', math.sqrt(1), 1) + self.ftest('sqrt(4)', math.sqrt(4), 2) + self.assertEqual(math.sqrt(INF), INF) + self.assertRaises(ValueError, math.sqrt, -1) + self.assertRaises(ValueError, math.sqrt, NINF) + self.assertTrue(math.isnan(math.sqrt(NAN))) + + def testTan(self): + self.assertRaises(TypeError, math.tan) + self.ftest('tan(0)', math.tan(0), 0) + self.ftest('tan(pi/4)', math.tan(math.pi/4), 1) + self.ftest('tan(-pi/4)', math.tan(-math.pi/4), -1) + try: + self.assertTrue(math.isnan(math.tan(INF))) + self.assertTrue(math.isnan(math.tan(NINF))) + except ValueError: + self.assertRaises(ValueError, math.tan, INF) + self.assertRaises(ValueError, math.tan, NINF) + self.assertTrue(math.isnan(math.tan(NAN))) + + def testTanh(self): + self.assertRaises(TypeError, math.tanh) + self.ftest('tanh(0)', math.tanh(0), 0) + self.ftest('tanh(1)+tanh(-1)', math.tanh(1)+math.tanh(-1), 0, + abs_tol=math.ulp(1)) + self.ftest('tanh(inf)', math.tanh(INF), 1) + self.ftest('tanh(-inf)', math.tanh(NINF), -1) + self.assertTrue(math.isnan(math.tanh(NAN))) + + @requires_IEEE_754 + def testTanhSign(self): + # check that tanh(-0.) == -0. on IEEE 754 systems + self.assertEqual(math.tanh(-0.), -0.) + self.assertEqual(math.copysign(1., math.tanh(-0.)), + math.copysign(1., -0.)) + + def test_trunc(self): + self.assertEqual(math.trunc(1), 1) + self.assertEqual(math.trunc(-1), -1) + self.assertEqual(type(math.trunc(1)), int) + self.assertEqual(type(math.trunc(1.5)), int) + self.assertEqual(math.trunc(1.5), 1) + self.assertEqual(math.trunc(-1.5), -1) + self.assertEqual(math.trunc(1.999999), 1) + self.assertEqual(math.trunc(-1.999999), -1) + self.assertEqual(math.trunc(-0.999999), -0) + self.assertEqual(math.trunc(-100.999), -100) + + class TestTrunc: + def __trunc__(self): + return 23 + class FloatTrunc(float): + def __trunc__(self): + return 23 + class TestNoTrunc: + pass + class TestBadTrunc: + __trunc__ = BadDescr() + + self.assertEqual(math.trunc(TestTrunc()), 23) + self.assertEqual(math.trunc(FloatTrunc()), 23) + + self.assertRaises(TypeError, math.trunc) + self.assertRaises(TypeError, math.trunc, 1, 2) + self.assertRaises(TypeError, math.trunc, FloatLike(23.5)) + self.assertRaises(TypeError, math.trunc, TestNoTrunc()) + self.assertRaises(ValueError, math.trunc, TestBadTrunc()) + + def testIsfinite(self): + self.assertTrue(math.isfinite(0.0)) + self.assertTrue(math.isfinite(-0.0)) + self.assertTrue(math.isfinite(1.0)) + self.assertTrue(math.isfinite(-1.0)) + self.assertFalse(math.isfinite(float("nan"))) + self.assertFalse(math.isfinite(float("inf"))) + self.assertFalse(math.isfinite(float("-inf"))) + + def testIsnan(self): + self.assertTrue(math.isnan(float("nan"))) + self.assertTrue(math.isnan(float("-nan"))) + self.assertTrue(math.isnan(float("inf") * 0.)) + self.assertFalse(math.isnan(float("inf"))) + self.assertFalse(math.isnan(0.)) + self.assertFalse(math.isnan(1.)) + + def testIsinf(self): + self.assertTrue(math.isinf(float("inf"))) + self.assertTrue(math.isinf(float("-inf"))) + self.assertTrue(math.isinf(1E400)) + self.assertTrue(math.isinf(-1E400)) + self.assertFalse(math.isinf(float("nan"))) + self.assertFalse(math.isinf(0.)) + self.assertFalse(math.isinf(1.)) + + def test_nan_constant(self): + # `math.nan` must be a quiet NaN with positive sign bit + self.assertTrue(math.isnan(math.nan)) + self.assertEqual(math.copysign(1., math.nan), 1.) + + def test_inf_constant(self): + self.assertTrue(math.isinf(math.inf)) + self.assertGreater(math.inf, 0.0) + self.assertEqual(math.inf, float("inf")) + self.assertEqual(-math.inf, float("-inf")) + + # RED_FLAG 16-Oct-2000 Tim + # While 2.0 is more consistent about exceptions than previous releases, it + # still fails this part of the test on some platforms. For now, we only + # *run* test_exceptions() in verbose mode, so that this isn't normally + # tested. + @unittest.skipUnless(verbose, 'requires verbose mode') + def test_exceptions(self): + try: + x = math.exp(-1000000000) + except: + # mathmodule.c is failing to weed out underflows from libm, or + # we've got an fp format with huge dynamic range + self.fail("underflowing exp() should not have raised " + "an exception") + if x != 0: + self.fail("underflowing exp() should have returned 0") + + # If this fails, probably using a strict IEEE-754 conforming libm, and x + # is +Inf afterwards. But Python wants overflows detected by default. + try: + x = math.exp(1000000000) + except OverflowError: + pass + else: + self.fail("overflowing exp() didn't trigger OverflowError") + + # If this fails, it could be a puzzle. One odd possibility is that + # mathmodule.c's macros are getting confused while comparing + # Inf (HUGE_VAL) to a NaN, and artificially setting errno to ERANGE + # as a result (and so raising OverflowError instead). + try: + x = math.sqrt(-1.0) + except ValueError: + pass + else: + self.fail("sqrt(-1) didn't raise ValueError") + + @requires_IEEE_754 + def test_testfile(self): + # Some tests need to be skipped on ancient OS X versions. + # See issue #27953. + SKIP_ON_TIGER = {'tan0064'} + + osx_version = None + if sys.platform == 'darwin': + version_txt = platform.mac_ver()[0] + try: + osx_version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + + fail_fmt = "{}: {}({!r}): {}" + + failures = [] + for id, fn, ar, ai, er, ei, flags in parse_testfile(test_file): + # Skip if either the input or result is complex + if ai != 0.0 or ei != 0.0: + continue + if fn in ['rect', 'polar']: + # no real versions of rect, polar + continue + # Skip certain tests on OS X 10.4. + if osx_version is not None and osx_version < (10, 5): + if id in SKIP_ON_TIGER: + continue + + func = getattr(math, fn) + + if 'invalid' in flags or 'divide-by-zero' in flags: + er = 'ValueError' + elif 'overflow' in flags: + er = 'OverflowError' + + try: + result = func(ar) + except ValueError: + result = 'ValueError' + except OverflowError: + result = 'OverflowError' + + # C99+ says for math.h's sqrt: If the argument is +∞ or ±0, it is + # returned, unmodified. On another hand, for csqrt: If z is ±0+0i, + # the result is +0+0i. Lets correct zero sign of er to follow + # first convention. + if id in ['sqrt0002', 'sqrt0003', 'sqrt1001', 'sqrt1023']: + er = math.copysign(er, ar) + + # Default tolerances + ulp_tol, abs_tol = 5, 0.0 + + failure = result_check(er, result, ulp_tol, abs_tol) + if failure is None: + continue + + msg = fail_fmt.format(id, fn, ar, failure) + failures.append(msg) + + if failures: + self.fail('Failures in test_testfile:\n ' + + '\n '.join(failures)) + + @requires_IEEE_754 + def test_mtestfile(self): + fail_fmt = "{}: {}({!r}): {}" + + failures = [] + for id, fn, arg, expected, flags in parse_mtestfile(math_testcases): + func = getattr(math, fn) + + if 'invalid' in flags or 'divide-by-zero' in flags: + expected = 'ValueError' + elif 'overflow' in flags: + expected = 'OverflowError' + + try: + got = func(arg) + except ValueError: + got = 'ValueError' + except OverflowError: + got = 'OverflowError' + + # Default tolerances + ulp_tol, abs_tol = 5, 0.0 + + # Exceptions to the defaults + if fn == 'gamma': + # Experimental results on one platform gave + # an accuracy of <= 10 ulps across the entire float + # domain. We weaken that to require 20 ulp accuracy. + ulp_tol = 20 + + elif fn == 'lgamma': + # we use a weaker accuracy test for lgamma; + # lgamma only achieves an absolute error of + # a few multiples of the machine accuracy, in + # general. + abs_tol = 1e-15 + + elif fn == 'erfc' and arg >= 0.0: + # erfc has less-than-ideal accuracy for large + # arguments (x ~ 25 or so), mainly due to the + # error involved in computing exp(-x*x). + # + # Observed between CPython and mpmath at 25 dp: + # x < 0 : err <= 2 ulp + # 0 <= x < 1 : err <= 10 ulp + # 1 <= x < 10 : err <= 100 ulp + # 10 <= x < 20 : err <= 300 ulp + # 20 <= x : < 600 ulp + # + if arg < 1.0: + ulp_tol = 10 + elif arg < 10.0: + ulp_tol = 100 + else: + ulp_tol = 1000 + + failure = result_check(expected, got, ulp_tol, abs_tol) + if failure is None: + continue + + msg = fail_fmt.format(id, fn, arg, failure) + failures.append(msg) + + if failures: + self.fail('Failures in test_mtestfile:\n ' + + '\n '.join(failures)) + + def test_prod(self): + from fractions import Fraction as F + + prod = math.prod + self.assertEqual(prod([]), 1) + self.assertEqual(prod([], start=5), 5) + self.assertEqual(prod(list(range(2,8))), 5040) + self.assertEqual(prod(iter(list(range(2,8)))), 5040) + self.assertEqual(prod(range(1, 10), start=10), 3628800) + + self.assertEqual(prod([1, 2, 3, 4, 5]), 120) + self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) + self.assertEqual(prod([1., F(3, 2)]), 1.5) + + # Error in multiplication + class BadMultiply: + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + prod([10., BadMultiply()]) + + # Test overflow in fast-path for integers + self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) + # Test overflow in fast-path for floats + self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) + + self.assertRaises(TypeError, prod) + self.assertRaises(TypeError, prod, 42) + self.assertRaises(TypeError, prod, ['a', 'b', 'c']) + self.assertRaises(TypeError, prod, ['a', 'b', 'c'], start='') + self.assertRaises(TypeError, prod, [b'a', b'c'], start=b'') + values = [bytearray(b'a'), bytearray(b'b')] + self.assertRaises(TypeError, prod, values, start=bytearray(b'')) + self.assertRaises(TypeError, prod, [[1], [2], [3]]) + self.assertRaises(TypeError, prod, [{2:3}]) + self.assertRaises(TypeError, prod, [{2:3}]*2, start={2:3}) + self.assertRaises(TypeError, prod, [[1], [2], [3]], start=[]) + + # Some odd cases + self.assertEqual(prod([2, 3], start='ab'), 'abababababab') + self.assertEqual(prod([2, 3], start=[1, 2]), [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) + self.assertEqual(prod([], start={2: 3}), {2:3}) + + with self.assertRaises(TypeError): + prod([10, 20], 1) # start is a keyword-only argument + + self.assertEqual(prod([0, 1, 2, 3]), 0) + self.assertEqual(prod([1, 0, 2, 3]), 0) + self.assertEqual(prod([1, 2, 3, 0]), 0) + + def _naive_prod(iterable, start=1): + for elem in iterable: + start *= elem + return start + + # Big integers + + iterable = range(1, 10000) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-10000, -1) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-1000, 1000) + self.assertEqual(prod(iterable), 0) + + # Big floats + + iterable = [float(x) for x in range(1, 1000)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, -1)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, 1000)] + self.assertIsNaN(prod(iterable)) + + # Float tests + + self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, 0, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, float("nan"), 0, 3])) + self.assertIsNaN(prod([1, float("inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("-inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("nan"), float("inf"),3])) + self.assertIsNaN(prod([1, float("nan"), float("-inf"),3])) + + self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf')) + self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf')) + + self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4])) + self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4])) + self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3])) + self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2])) + + # Type preservation + + self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int) + self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float) + self.assertEqual(type(prod(range(1, 10000))), int) + self.assertEqual(type(prod(range(1, 10000), start=1.0)), float) + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + + @skipIfTorchDynamo("Infinite loop") + def testPerm(self): + perm = math.perm + factorial = math.factorial + # Test if factorial definition is satisfied + for n in range(500): + for k in (range(n + 1) if n < 100 else range(30) if n < 200 else range(10)): + self.assertEqual(perm(n, k), + factorial(n) // factorial(n - k)) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k)) + + # Test corner cases + for n in range(1, 100): + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, n), factorial(n)) + + # Test one argument form + for n in range(20): + self.assertEqual(perm(n), factorial(n)) + self.assertEqual(perm(n, None), factorial(n)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 1 or 2 + self.assertRaises(TypeError, perm, 10, 1.0) + self.assertRaises(TypeError, perm, 10, decimal.Decimal(1.0)) + self.assertRaises(TypeError, perm, 10, "1") + self.assertRaises(TypeError, perm, 10.0, 1) + self.assertRaises(TypeError, perm, decimal.Decimal(10.0), 1) + self.assertRaises(TypeError, perm, "10", 1) + + self.assertRaises(TypeError, perm) + self.assertRaises(TypeError, perm, 10, 1, 3) + self.assertRaises(TypeError, perm) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, perm, -1, 1) + self.assertRaises(ValueError, perm, -2**1000, 1) + self.assertRaises(ValueError, perm, 1, -1) + self.assertRaises(ValueError, perm, 1, -2**1000) + + # Returns zero if k is greater than n + self.assertEqual(perm(1, 2), 0) + self.assertEqual(perm(1, 2**1000), 0) + + n = 2**1000 + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, 2), n * (n-1)) + if support.check_impl_detail(cpython=True): + self.assertRaises(OverflowError, perm, n, n) + + for n, k in (True, True), (True, False), (False, False): + self.assertEqual(perm(n, k), 1) + self.assertIs(type(perm(n, k)), int) + self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20) + self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20) + for k in range(3): + self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) + + @skipIfTorchDynamo("infinite loop") + def testComb(self): + comb = math.comb + factorial = math.factorial + # Test if factorial definition is satisfied + for n in range(500): + for k in (range(n + 1) if n < 100 else range(30) if n < 200 else range(10)): + self.assertEqual(comb(n, k), factorial(n) + // (factorial(k) * factorial(n - k))) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k)) + + # Test corner cases + for n in range(100): + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, n), 1) + + for n in range(1, 100): + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, n - 1), n) + + # Test Symmetry + for n in range(100): + for k in range(n // 2): + self.assertEqual(comb(n, k), comb(n, n - k)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 2 + self.assertRaises(TypeError, comb, 10, 1.0) + self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0)) + self.assertRaises(TypeError, comb, 10, "1") + self.assertRaises(TypeError, comb, 10.0, 1) + self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1) + self.assertRaises(TypeError, comb, "10", 1) + + self.assertRaises(TypeError, comb, 10) + self.assertRaises(TypeError, comb, 10, 1, 3) + self.assertRaises(TypeError, comb) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, comb, -1, 1) + self.assertRaises(ValueError, comb, -2**1000, 1) + self.assertRaises(ValueError, comb, 1, -1) + self.assertRaises(ValueError, comb, 1, -2**1000) + + # Returns zero if k is greater than n + self.assertEqual(comb(1, 2), 0) + self.assertEqual(comb(1, 2**1000), 0) + + n = 2**1000 + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, 2), n * (n-1) // 2) + self.assertEqual(comb(n, n), 1) + self.assertEqual(comb(n, n-1), n) + self.assertEqual(comb(n, n-2), n * (n-1) // 2) + if support.check_impl_detail(cpython=True): + self.assertRaises(OverflowError, comb, n, n//2) + + for n, k in (True, True), (True, False), (False, False): + self.assertEqual(comb(n, k), 1) + self.assertIs(type(comb(n, k)), int) + self.assertEqual(comb(IntSubclass(5), IntSubclass(2)), 10) + self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10) + for k in range(3): + self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int) + + @requires_IEEE_754 + def test_nextafter(self): + # around 2^52 and 2^63 + self.assertEqual(math.nextafter(4503599627370496.0, -INF), + 4503599627370495.5) + self.assertEqual(math.nextafter(4503599627370496.0, INF), + 4503599627370497.0) + self.assertEqual(math.nextafter(9223372036854775808.0, 0.0), + 9223372036854774784.0) + self.assertEqual(math.nextafter(-9223372036854775808.0, 0.0), + -9223372036854774784.0) + + # around 1.0 + self.assertEqual(math.nextafter(1.0, -INF), + float.fromhex('0x1.fffffffffffffp-1')) + self.assertEqual(math.nextafter(1.0, INF), + float.fromhex('0x1.0000000000001p+0')) + self.assertEqual(math.nextafter(1.0, -INF, steps=1), + float.fromhex('0x1.fffffffffffffp-1')) + self.assertEqual(math.nextafter(1.0, INF, steps=1), + float.fromhex('0x1.0000000000001p+0')) + self.assertEqual(math.nextafter(1.0, -INF, steps=3), + float.fromhex('0x1.ffffffffffffdp-1')) + self.assertEqual(math.nextafter(1.0, INF, steps=3), + float.fromhex('0x1.0000000000003p+0')) + + # x == y: y is returned + for steps in range(1, 5): + self.assertEqual(math.nextafter(2.0, 2.0, steps=steps), 2.0) + self.assertEqualSign(math.nextafter(-0.0, +0.0, steps=steps), +0.0) + self.assertEqualSign(math.nextafter(+0.0, -0.0, steps=steps), -0.0) + + # around 0.0 + smallest_subnormal = sys.float_info.min * sys.float_info.epsilon + self.assertEqual(math.nextafter(+0.0, INF), smallest_subnormal) + self.assertEqual(math.nextafter(-0.0, INF), smallest_subnormal) + self.assertEqual(math.nextafter(+0.0, -INF), -smallest_subnormal) + self.assertEqual(math.nextafter(-0.0, -INF), -smallest_subnormal) + self.assertEqualSign(math.nextafter(smallest_subnormal, +0.0), +0.0) + self.assertEqualSign(math.nextafter(-smallest_subnormal, +0.0), -0.0) + self.assertEqualSign(math.nextafter(smallest_subnormal, -0.0), +0.0) + self.assertEqualSign(math.nextafter(-smallest_subnormal, -0.0), -0.0) + + # around infinity + largest_normal = sys.float_info.max + self.assertEqual(math.nextafter(INF, 0.0), largest_normal) + self.assertEqual(math.nextafter(-INF, 0.0), -largest_normal) + self.assertEqual(math.nextafter(largest_normal, INF), INF) + self.assertEqual(math.nextafter(-largest_normal, -INF), -INF) + + # NaN + self.assertIsNaN(math.nextafter(NAN, 1.0)) + self.assertIsNaN(math.nextafter(1.0, NAN)) + self.assertIsNaN(math.nextafter(NAN, NAN)) + + self.assertEqual(1.0, math.nextafter(1.0, INF, steps=0)) + with self.assertRaises(ValueError): + math.nextafter(1.0, INF, steps=-1) + + + @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest + @requires_IEEE_754 + def test_ulp(self): + self.assertEqual(math.ulp(1.0), sys.float_info.epsilon) + # use int ** int rather than float ** int to not rely on pow() accuracy + self.assertEqual(math.ulp(2 ** 52), 1.0) + self.assertEqual(math.ulp(2 ** 53), 2.0) + self.assertEqual(math.ulp(2 ** 64), 4096.0) + + # min and max + self.assertEqual(math.ulp(0.0), + sys.float_info.min * sys.float_info.epsilon) + self.assertEqual(math.ulp(FLOAT_MAX), + FLOAT_MAX - math.nextafter(FLOAT_MAX, -INF)) + + # special cases + self.assertEqual(math.ulp(INF), INF) + self.assertIsNaN(math.ulp(math.nan)) + + # negative number: ulp(-x) == ulp(x) + for x in (0.0, 1.0, 2 ** 52, 2 ** 64, INF): + with self.subTest(x=x): + self.assertEqual(math.ulp(-x), math.ulp(x)) + + def test_issue39871(self): + # A SystemError should not be raised if the first arg to atan2(), + # copysign(), or remainder() cannot be converted to a float. + class F: + def __float__(self): + self.converted = True + 1/0 + for func in math.atan2, math.copysign, math.remainder: + y = F() + with self.assertRaises(TypeError): + func("not a number", y) + + # There should not have been any attempt to convert the second + # argument to a float. + self.assertFalse(getattr(y, "converted", False)) + + def test_input_exceptions(self): + self.assertRaises(TypeError, math.exp, "spam") + self.assertRaises(TypeError, math.erf, "spam") + self.assertRaises(TypeError, math.atan2, "spam", 1.0) + self.assertRaises(TypeError, math.atan2, 1.0, "spam") + self.assertRaises(TypeError, math.atan2, 1.0) + self.assertRaises(TypeError, math.atan2, 1.0, 2.0, 3.0) + + # Custom assertions. + + def assertIsNaN(self, value): + if not math.isnan(value): + self.fail("Expected a NaN, got {!r}.".format(value)) + + def assertEqualSign(self, x, y): + """Similar to assertEqual(), but compare also the sign with copysign(). + + Function useful to compare signed zeros. + """ + self.assertEqual(x, y) + self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y)) + + +class IsCloseTests(__TestCase): + isclose = math.isclose # subclasses should override this + + def assertIsClose(self, a, b, *args, **kwargs): + self.assertTrue(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should be close!" % (a, b)) + + def assertIsNotClose(self, a, b, *args, **kwargs): + self.assertFalse(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should not be close!" % (a, b)) + + def assertAllClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsClose(a, b, *args, **kwargs) + + def assertAllNotClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsNotClose(a, b, *args, **kwargs) + + def test_negative_tolerances(self): + # ValueError should be raised if either tolerance is less than zero + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=-1e-100) + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=1e-100, abs_tol=-1e10) + + def test_identical(self): + # identical values must test as close + identical_examples = [(2.0, 2.0), + (0.1e200, 0.1e200), + (1.123e-300, 1.123e-300), + (12345, 12345.0), + (0.0, -0.0), + (345678, 345678)] + self.assertAllClose(identical_examples, rel_tol=0.0, abs_tol=0.0) + + def test_eight_decimal_places(self): + # examples that are close to 1e-8, but not 1e-9 + eight_decimal_places_examples = [(1e8, 1e8 + 1), + (-1e-8, -1.000000009e-8), + (1.12345678, 1.12345679)] + self.assertAllClose(eight_decimal_places_examples, rel_tol=1e-8) + self.assertAllNotClose(eight_decimal_places_examples, rel_tol=1e-9) + + def test_near_zero(self): + # values close to zero + near_zero_examples = [(1e-9, 0.0), + (-1e-9, 0.0), + (-1e-150, 0.0)] + # these should not be close to any rel_tol + self.assertAllNotClose(near_zero_examples, rel_tol=0.9) + # these should be close to abs_tol=1e-8 + self.assertAllClose(near_zero_examples, abs_tol=1e-8) + + def test_identical_infinite(self): + # these are close regardless of tolerance -- i.e. they are equal + self.assertIsClose(INF, INF) + self.assertIsClose(INF, INF, abs_tol=0.0) + self.assertIsClose(NINF, NINF) + self.assertIsClose(NINF, NINF, abs_tol=0.0) + + def test_inf_ninf_nan(self): + # these should never be close (following IEEE 754 rules for equality) + not_close_examples = [(NAN, NAN), + (NAN, 1e-100), + (1e-100, NAN), + (INF, NAN), + (NAN, INF), + (INF, NINF), + (INF, 1.0), + (1.0, INF), + (INF, 1e308), + (1e308, INF)] + # use largest reasonable tolerance + self.assertAllNotClose(not_close_examples, abs_tol=0.999999999999999) + + def test_zero_tolerance(self): + # test with zero tolerance + zero_tolerance_close_examples = [(1.0, 1.0), + (-3.4, -3.4), + (-1e-300, -1e-300)] + self.assertAllClose(zero_tolerance_close_examples, rel_tol=0.0) + + zero_tolerance_not_close_examples = [(1.0, 1.000000000000001), + (0.99999999999999, 1.0), + (1.0e200, .999999999999999e200)] + self.assertAllNotClose(zero_tolerance_not_close_examples, rel_tol=0.0) + + def test_asymmetry(self): + # test the asymmetry example from PEP 485 + self.assertAllClose([(9, 10), (10, 9)], rel_tol=0.1) + + def test_integers(self): + # test with integer values + integer_examples = [(100000001, 100000000), + (123456789, 123456788)] + + self.assertAllClose(integer_examples, rel_tol=1e-8) + self.assertAllNotClose(integer_examples, rel_tol=1e-9) + + def test_decimals(self): + # test with Decimal values + from decimal import Decimal + + decimal_examples = [(Decimal('1.00000001'), Decimal('1.0')), + (Decimal('1.00000001e-20'), Decimal('1.0e-20')), + (Decimal('1.00000001e-100'), Decimal('1.0e-100')), + (Decimal('1.00000001e20'), Decimal('1.0e20'))] + self.assertAllClose(decimal_examples, rel_tol=1e-8) + self.assertAllNotClose(decimal_examples, rel_tol=1e-9) + + def test_fractions(self): + # test with Fraction values + from fractions import Fraction + + fraction_examples = [ + (Fraction(1, 100000000) + 1, Fraction(1)), + (Fraction(100000001), Fraction(100000000)), + (Fraction(10**8 + 1, 10**28), Fraction(1, 10**20))] + self.assertAllClose(fraction_examples, rel_tol=1e-8) + self.assertAllNotClose(fraction_examples, rel_tol=1e-9) + + +class FMATests(__TestCase): + """ Tests for math.fma. """ + + def test_fma_nan_results(self): + # Selected representative values. + values = [ + -math.inf, -1e300, -2.3, -1e-300, -0.0, + 0.0, 1e-300, 2.3, 1e300, math.inf, math.nan + ] + + # If any input is a NaN, the result should be a NaN, too. + for a, b in itertools.product(values, repeat=2): + with self.subTest(a=a, b=b): + self.assertIsNaN(math.fma(math.nan, a, b)) + self.assertIsNaN(math.fma(a, math.nan, b)) + self.assertIsNaN(math.fma(a, b, math.nan)) + + def test_fma_infinities(self): + # Cases involving infinite inputs or results. + positives = [1e-300, 2.3, 1e300, math.inf] + finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300] + non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf] + + # ValueError due to inf * 0 computation. + for c in non_nans: + for infinity in [math.inf, -math.inf]: + for zero in [0.0, -0.0]: + with self.subTest(c=c, infinity=infinity, zero=zero): + with self.assertRaises(ValueError): + math.fma(infinity, zero, c) + with self.assertRaises(ValueError): + math.fma(zero, infinity, c) + + # ValueError when a*b and c both infinite of opposite signs. + for b in positives: + with self.subTest(b=b): + with self.assertRaises(ValueError): + math.fma(math.inf, b, -math.inf) + with self.assertRaises(ValueError): + math.fma(math.inf, -b, math.inf) + with self.assertRaises(ValueError): + math.fma(-math.inf, -b, -math.inf) + with self.assertRaises(ValueError): + math.fma(-math.inf, b, math.inf) + with self.assertRaises(ValueError): + math.fma(b, math.inf, -math.inf) + with self.assertRaises(ValueError): + math.fma(-b, math.inf, math.inf) + with self.assertRaises(ValueError): + math.fma(-b, -math.inf, -math.inf) + with self.assertRaises(ValueError): + math.fma(b, -math.inf, math.inf) + + # Infinite result when a*b and c both infinite of the same sign. + for b in positives: + with self.subTest(b=b): + self.assertEqual(math.fma(math.inf, b, math.inf), math.inf) + self.assertEqual(math.fma(math.inf, -b, -math.inf), -math.inf) + self.assertEqual(math.fma(-math.inf, -b, math.inf), math.inf) + self.assertEqual(math.fma(-math.inf, b, -math.inf), -math.inf) + self.assertEqual(math.fma(b, math.inf, math.inf), math.inf) + self.assertEqual(math.fma(-b, math.inf, -math.inf), -math.inf) + self.assertEqual(math.fma(-b, -math.inf, math.inf), math.inf) + self.assertEqual(math.fma(b, -math.inf, -math.inf), -math.inf) + + # Infinite result when a*b finite, c infinite. + for a, b in itertools.product(finites, finites): + with self.subTest(b=b): + self.assertEqual(math.fma(a, b, math.inf), math.inf) + self.assertEqual(math.fma(a, b, -math.inf), -math.inf) + + # Infinite result when a*b infinite, c finite. + for b, c in itertools.product(positives, finites): + with self.subTest(b=b, c=c): + self.assertEqual(math.fma(math.inf, b, c), math.inf) + self.assertEqual(math.fma(-math.inf, b, c), -math.inf) + self.assertEqual(math.fma(-math.inf, -b, c), math.inf) + self.assertEqual(math.fma(math.inf, -b, c), -math.inf) + + self.assertEqual(math.fma(b, math.inf, c), math.inf) + self.assertEqual(math.fma(b, -math.inf, c), -math.inf) + self.assertEqual(math.fma(-b, -math.inf, c), math.inf) + self.assertEqual(math.fma(-b, math.inf, c), -math.inf) + + # gh-73468: On some platforms, libc fma() doesn't implement IEE 754-2008 + # properly: it doesn't use the right sign when the result is zero. + @unittest.skipIf( + sys.platform.startswith(("freebsd", "wasi", "netbsd", "emscripten")) + or (sys.platform == "android" and platform.machine() == "x86_64"), + f"this platform doesn't implement IEE 754-2008 properly") + def test_fma_zero_result(self): + nonnegative_finites = [0.0, 1e-300, 2.3, 1e300] + + # Zero results from exact zero inputs. + for b in nonnegative_finites: + with self.subTest(b=b): + self.assertIsPositiveZero(math.fma(0.0, b, 0.0)) + self.assertIsPositiveZero(math.fma(0.0, b, -0.0)) + self.assertIsNegativeZero(math.fma(0.0, -b, -0.0)) + self.assertIsPositiveZero(math.fma(0.0, -b, 0.0)) + self.assertIsPositiveZero(math.fma(-0.0, -b, 0.0)) + self.assertIsPositiveZero(math.fma(-0.0, -b, -0.0)) + self.assertIsNegativeZero(math.fma(-0.0, b, -0.0)) + self.assertIsPositiveZero(math.fma(-0.0, b, 0.0)) + + self.assertIsPositiveZero(math.fma(b, 0.0, 0.0)) + self.assertIsPositiveZero(math.fma(b, 0.0, -0.0)) + self.assertIsNegativeZero(math.fma(-b, 0.0, -0.0)) + self.assertIsPositiveZero(math.fma(-b, 0.0, 0.0)) + self.assertIsPositiveZero(math.fma(-b, -0.0, 0.0)) + self.assertIsPositiveZero(math.fma(-b, -0.0, -0.0)) + self.assertIsNegativeZero(math.fma(b, -0.0, -0.0)) + self.assertIsPositiveZero(math.fma(b, -0.0, 0.0)) + + # Exact zero result from nonzero inputs. + self.assertIsPositiveZero(math.fma(2.0, 2.0, -4.0)) + self.assertIsPositiveZero(math.fma(2.0, -2.0, 4.0)) + self.assertIsPositiveZero(math.fma(-2.0, -2.0, -4.0)) + self.assertIsPositiveZero(math.fma(-2.0, 2.0, 4.0)) + + # Underflow to zero. + tiny = 1e-300 + self.assertIsPositiveZero(math.fma(tiny, tiny, 0.0)) + self.assertIsNegativeZero(math.fma(tiny, -tiny, 0.0)) + self.assertIsPositiveZero(math.fma(-tiny, -tiny, 0.0)) + self.assertIsNegativeZero(math.fma(-tiny, tiny, 0.0)) + self.assertIsPositiveZero(math.fma(tiny, tiny, -0.0)) + self.assertIsNegativeZero(math.fma(tiny, -tiny, -0.0)) + self.assertIsPositiveZero(math.fma(-tiny, -tiny, -0.0)) + self.assertIsNegativeZero(math.fma(-tiny, tiny, -0.0)) + + # Corner case where rounding the multiplication would + # give the wrong result. + x = float.fromhex('0x1p-500') + y = float.fromhex('0x1p-550') + z = float.fromhex('0x1p-1000') + self.assertIsNegativeZero(math.fma(x-y, x+y, -z)) + self.assertIsPositiveZero(math.fma(y-x, x+y, z)) + self.assertIsNegativeZero(math.fma(y-x, -(x+y), -z)) + self.assertIsPositiveZero(math.fma(x-y, -(x+y), z)) + + def test_fma_overflow(self): + a = b = float.fromhex('0x1p512') + c = float.fromhex('0x1p1023') + # Overflow from multiplication. + with self.assertRaises(OverflowError): + math.fma(a, b, 0.0) + self.assertEqual(math.fma(a, b/2.0, 0.0), c) + # Overflow from the addition. + with self.assertRaises(OverflowError): + math.fma(a, b/2.0, c) + # No overflow, even though a*b overflows a float. + self.assertEqual(math.fma(a, b, -c), c) + + # Extreme case: a * b is exactly at the overflow boundary, so the + # tiniest offset makes a difference between overflow and a finite + # result. + a = float.fromhex('0x1.ffffffc000000p+511') + b = float.fromhex('0x1.0000002000000p+512') + c = float.fromhex('0x0.0000000000001p-1022') + with self.assertRaises(OverflowError): + math.fma(a, b, 0.0) + with self.assertRaises(OverflowError): + math.fma(a, b, c) + self.assertEqual(math.fma(a, b, -c), + float.fromhex('0x1.fffffffffffffp+1023')) + + # Another extreme case: here a*b is about as large as possible subject + # to math.fma(a, b, c) being finite. + a = float.fromhex('0x1.ae565943785f9p+512') + b = float.fromhex('0x1.3094665de9db8p+512') + c = float.fromhex('0x1.fffffffffffffp+1023') + self.assertEqual(math.fma(a, b, -c), c) + + def test_fma_single_round(self): + a = float.fromhex('0x1p-50') + self.assertEqual(math.fma(a - 1.0, a + 1.0, 1.0), a*a) + + def test_random(self): + # A collection of randomly generated inputs for which the naive FMA + # (with two rounds) gives a different result from a singly-rounded FMA. + + # tuples (a, b, c, expected) + test_values = [ + ('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1', + '0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'), + ('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2', + '0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'), + ('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1', + '0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'), + ('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1', + '0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'), + ('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1', + '0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'), + ('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1', + '0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'), + ('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2', + '0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'), + ('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1', + '0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'), + ('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1', + '0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'), + ('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1', + '0x1.122b59317ebdfp-1', '0x1.0bf644b340cc5p+0'), + ('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1', + '0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'), + ('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1', + '0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'), + ('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1', + '0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'), + ('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1', + '0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'), + ('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2', + '0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'), + ('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2', + '0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'), + ] + for a_hex, b_hex, c_hex, expected_hex in test_values: + with self.subTest(a_hex=a_hex, b_hex=b_hex, c_hex=c_hex, + expected_hex=expected_hex): + a = float.fromhex(a_hex) + b = float.fromhex(b_hex) + c = float.fromhex(c_hex) + expected = float.fromhex(expected_hex) + self.assertEqual(math.fma(a, b, c), expected) + self.assertEqual(math.fma(b, a, c), expected) + + # Custom assertions. + def assertIsNaN(self, value): + self.assertTrue( + math.isnan(value), + msg="Expected a NaN, got {!r}".format(value) + ) + + def assertIsPositiveZero(self, value): + self.assertTrue( + value == 0 and math.copysign(1, value) > 0, + msg="Expected a positive zero, got {!r}".format(value) + ) + + def assertIsNegativeZero(self, value): + self.assertTrue( + value == 0 and math.copysign(1, value) < 0, + msg="Expected a negative zero, got {!r}".format(value) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.diff b/test/dynamo/cpython/3_13/test_ordered_dict.diff new file mode 100644 index 00000000000000..c55fee2f7daf0d --- /dev/null +++ b/test/dynamo/cpython/3_13/test_ordered_dict.diff @@ -0,0 +1,173 @@ +diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py +index a9b6a84996e..b77eff70414 100644 +--- a/test/dynamo/cpython/3_13/test_ordered_dict.py ++++ b/test/dynamo/cpython/3_13/test_ordered_dict.py +@@ -1,3 +1,57 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++ xfailIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import builtins + import contextlib + import copy +@@ -760,7 +814,7 @@ class _TriggerSideEffectOnEqual: + def side_effect(self): + raise NotImplementedError + +-class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): ++class PurePythonOrderedDictTests(OrderedDictTests, __TestCase): + + module = py_coll + OrderedDict = py_coll.OrderedDict +@@ -781,7 +835,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + +-class CPythonBuiltinDictTests(unittest.TestCase): ++class CPythonBuiltinDictTests(__TestCase): + """Builtin dict preserves insertion order. + + Reuse some of tests in OrderedDict selectively. +@@ -800,6 +854,7 @@ for method in ( + del method + + ++ + class CPythonOrderedDictSideEffects: + + def check_runtime_error_issue119004(self, dict1, dict2): +@@ -878,7 +933,7 @@ class CPythonOrderedDictSideEffects: + @unittest.skipUnless(c_coll, 'requires the C version of the collections module') + class CPythonOrderedDictTests(OrderedDictTests, + CPythonOrderedDictSideEffects, +- unittest.TestCase): ++ __TestCase): + + module = c_coll + OrderedDict = c_coll.OrderedDict +@@ -986,7 +1041,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests): + pass + + +-class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): ++class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = py_coll + class OrderedDict(py_coll.OrderedDict): +@@ -995,7 +1050,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): + + + @unittest.skipUnless(c_coll, 'requires the C version of the collections module') +-class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): ++class CPythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = c_coll + class OrderedDict(c_coll.OrderedDict): +@@ -1008,6 +1063,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + @classmethod + def setUpClass(cls): + cls.type2test = py_coll.OrderedDict ++ super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() +@@ -1020,6 +1076,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + @classmethod + def setUpClass(cls): + cls.type2test = c_coll.OrderedDict ++ super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() +@@ -1033,6 +1090,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + class MyOrderedDict(py_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict ++ super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() +@@ -1047,6 +1105,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + class MyOrderedDict(c_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict ++ super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() +@@ -1120,21 +1179,22 @@ class SimpleLRUCacheTests: + self.assertEqual(list(c), [1, 3, 2]) + + +-class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): ++class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + + class type2test(SimpleLRUCache, py_coll.OrderedDict): + pass + + + @unittest.skipUnless(c_coll, 'requires the C version of the collections module') +-class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): ++class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + + @classmethod + def setUpClass(cls): + class type2test(SimpleLRUCache, c_coll.OrderedDict): + pass + cls.type2test = type2test ++ super().setUpClass() + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py new file mode 100644 index 00000000000000..b77eff704149f4 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_ordered_dict.py @@ -0,0 +1,1200 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, + xfailIfTorchDynamo, +) + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import builtins +import contextlib +import copy +import gc +import operator +import pickle +import re +from random import randrange, shuffle +import struct +import sys +import unittest +import weakref +from collections.abc import MutableMapping +from test import mapping_tests, support +from test.support import import_helper, suppress_immortalization + + +py_coll = import_helper.import_fresh_module('collections', + blocked=['_collections']) +c_coll = import_helper.import_fresh_module('collections', + fresh=['_collections']) + + +@contextlib.contextmanager +def replaced_module(name, replacement): + original_module = sys.modules[name] + sys.modules[name] = replacement + try: + yield + finally: + sys.modules[name] = original_module + + +class OrderedDictTests: + + def test_init(self): + OrderedDict = self.OrderedDict + with self.assertRaises(TypeError): + OrderedDict([('a', 1), ('b', 2)], None) # too many args + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + self.assertEqual(sorted(OrderedDict(dict(pairs)).items()), pairs) # dict input + self.assertEqual(sorted(OrderedDict(**dict(pairs)).items()), pairs) # kwds input + self.assertEqual(list(OrderedDict(pairs).items()), pairs) # pairs input + self.assertEqual(list(OrderedDict([('a', 1), ('b', 2), ('c', 9), ('d', 4)], + c=3, e=5).items()), pairs) # mixed input + + # make sure no positional args conflict with possible kwdargs + self.assertEqual(list(OrderedDict(self=42).items()), [('self', 42)]) + self.assertEqual(list(OrderedDict(other=42).items()), [('other', 42)]) + self.assertRaises(TypeError, OrderedDict, 42) + self.assertRaises(TypeError, OrderedDict, (), ()) + self.assertRaises(TypeError, OrderedDict.__init__) + + # Make sure that direct calls to __init__ do not clear previous contents + d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)]) + d.__init__([('e', 5), ('f', 6)], g=7, d=4) + self.assertEqual(list(d.items()), + [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) + + def test_468(self): + OrderedDict = self.OrderedDict + items = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)] + shuffle(items) + argdict = OrderedDict(items) + d = OrderedDict(**argdict) + self.assertEqual(list(d.items()), items) + + def test_update(self): + OrderedDict = self.OrderedDict + with self.assertRaises(TypeError): + OrderedDict().update([('a', 1), ('b', 2)], None) # too many args + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + od = OrderedDict() + od.update(dict(pairs)) + self.assertEqual(sorted(od.items()), pairs) # dict input + od = OrderedDict() + od.update(**dict(pairs)) + self.assertEqual(sorted(od.items()), pairs) # kwds input + od = OrderedDict() + od.update(pairs) + self.assertEqual(list(od.items()), pairs) # pairs input + od = OrderedDict() + od.update([('a', 1), ('b', 2), ('c', 9), ('d', 4)], c=3, e=5) + self.assertEqual(list(od.items()), pairs) # mixed input + + # Issue 9137: Named argument called 'other' or 'self' + # shouldn't be treated specially. + od = OrderedDict() + od.update(self=23) + self.assertEqual(list(od.items()), [('self', 23)]) + od = OrderedDict() + od.update(other={}) + self.assertEqual(list(od.items()), [('other', {})]) + od = OrderedDict() + od.update(red=5, blue=6, other=7, self=8) + self.assertEqual(sorted(list(od.items())), + [('blue', 6), ('other', 7), ('red', 5), ('self', 8)]) + + # Make sure that direct calls to update do not clear previous contents + # add that updates items are not moved to the end + d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)]) + d.update([('e', 5), ('f', 6)], g=7, d=4) + self.assertEqual(list(d.items()), + [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) + + self.assertRaises(TypeError, OrderedDict().update, 42) + self.assertRaises(TypeError, OrderedDict().update, (), ()) + self.assertRaises(TypeError, OrderedDict.update) + + self.assertRaises(TypeError, OrderedDict().update, 42) + self.assertRaises(TypeError, OrderedDict().update, (), ()) + self.assertRaises(TypeError, OrderedDict.update) + + def test_init_calls(self): + calls = [] + class Spam: + def keys(self): + calls.append('keys') + return () + def items(self): + calls.append('items') + return () + + self.OrderedDict(Spam()) + self.assertEqual(calls, ['keys']) + + def test_overridden_init(self): + # Sync-up pure Python OD class with C class where + # a consistent internal state is created in __new__ + # rather than __init__. + OrderedDict = self.OrderedDict + class ODNI(OrderedDict): + def __init__(*args, **kwargs): + pass + od = ODNI() + od['a'] = 1 # This used to fail because __init__ was bypassed + + def test_fromkeys(self): + OrderedDict = self.OrderedDict + od = OrderedDict.fromkeys('abc') + self.assertEqual(list(od.items()), [(c, None) for c in 'abc']) + od = OrderedDict.fromkeys('abc', value=None) + self.assertEqual(list(od.items()), [(c, None) for c in 'abc']) + od = OrderedDict.fromkeys('abc', value=0) + self.assertEqual(list(od.items()), [(c, 0) for c in 'abc']) + + def test_abc(self): + OrderedDict = self.OrderedDict + self.assertIsInstance(OrderedDict(), MutableMapping) + self.assertTrue(issubclass(OrderedDict, MutableMapping)) + + def test_clear(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + self.assertEqual(len(od), len(pairs)) + od.clear() + self.assertEqual(len(od), 0) + + def test_delitem(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + del od['a'] + self.assertNotIn('a', od) + with self.assertRaises(KeyError): + del od['a'] + self.assertEqual(list(od.items()), pairs[:2] + pairs[3:]) + + def test_setitem(self): + OrderedDict = self.OrderedDict + od = OrderedDict([('d', 1), ('b', 2), ('c', 3), ('a', 4), ('e', 5)]) + od['c'] = 10 # existing element + od['f'] = 20 # new element + self.assertEqual(list(od.items()), + [('d', 1), ('b', 2), ('c', 10), ('a', 4), ('e', 5), ('f', 20)]) + + def test_iterators(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + self.assertEqual(list(od), [t[0] for t in pairs]) + self.assertEqual(list(od.keys()), [t[0] for t in pairs]) + self.assertEqual(list(od.values()), [t[1] for t in pairs]) + self.assertEqual(list(od.items()), pairs) + self.assertEqual(list(reversed(od)), + [t[0] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.keys())), + [t[0] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.values())), + [t[1] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.items())), list(reversed(pairs))) + + def test_detect_deletion_during_iteration(self): + OrderedDict = self.OrderedDict + od = OrderedDict.fromkeys('abc') + it = iter(od) + key = next(it) + del od[key] + with self.assertRaises(Exception): + # Note, the exact exception raised is not guaranteed + # The only guarantee that the next() will not succeed + next(it) + + def test_sorted_iterators(self): + OrderedDict = self.OrderedDict + with self.assertRaises(TypeError): + OrderedDict([('a', 1), ('b', 2)], None) + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + od = OrderedDict(pairs) + self.assertEqual(sorted(od), [t[0] for t in pairs]) + self.assertEqual(sorted(od.keys()), [t[0] for t in pairs]) + self.assertEqual(sorted(od.values()), [t[1] for t in pairs]) + self.assertEqual(sorted(od.items()), pairs) + self.assertEqual(sorted(reversed(od)), + sorted([t[0] for t in reversed(pairs)])) + + def test_iterators_empty(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + empty = [] + self.assertEqual(list(od), empty) + self.assertEqual(list(od.keys()), empty) + self.assertEqual(list(od.values()), empty) + self.assertEqual(list(od.items()), empty) + self.assertEqual(list(reversed(od)), empty) + self.assertEqual(list(reversed(od.keys())), empty) + self.assertEqual(list(reversed(od.values())), empty) + self.assertEqual(list(reversed(od.items())), empty) + + def test_popitem(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + while pairs: + self.assertEqual(od.popitem(), pairs.pop()) + with self.assertRaises(KeyError): + od.popitem() + self.assertEqual(len(od), 0) + + def test_popitem_last(self): + OrderedDict = self.OrderedDict + pairs = [(i, i) for i in range(30)] + + obj = OrderedDict(pairs) + for i in range(8): + obj.popitem(True) + obj.popitem(True) + obj.popitem(last=True) + self.assertEqual(len(obj), 20) + + def test_pop(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + shuffle(pairs) + while pairs: + k, v = pairs.pop() + self.assertEqual(od.pop(k), v) + with self.assertRaises(KeyError): + od.pop('xyz') + self.assertEqual(len(od), 0) + self.assertEqual(od.pop(k, 12345), 12345) + + # make sure pop still works when __missing__ is defined + class Missing(OrderedDict): + def __missing__(self, key): + return 0 + m = Missing(a=1) + self.assertEqual(m.pop('b', 5), 5) + self.assertEqual(m.pop('a', 6), 1) + self.assertEqual(m.pop('a', 6), 6) + self.assertEqual(m.pop('a', default=6), 6) + with self.assertRaises(KeyError): + m.pop('a') + + def test_equality(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od1 = OrderedDict(pairs) + od2 = OrderedDict(pairs) + self.assertEqual(od1, od2) # same order implies equality + pairs = pairs[2:] + pairs[:2] + od2 = OrderedDict(pairs) + self.assertNotEqual(od1, od2) # different order implies inequality + # comparison to regular dict is not order sensitive + self.assertEqual(od1, dict(od2)) + self.assertEqual(dict(od2), od1) + # different length implied inequality + self.assertNotEqual(od1, OrderedDict(pairs[:-1])) + + def test_copying(self): + OrderedDict = self.OrderedDict + # Check that ordered dicts are copyable, deepcopyable, picklable, + # and have a repr/eval round-trip + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + od.x = ['x'] + od.z = ['z'] + def check(dup): + msg = "\ncopy: %s\nod: %s" % (dup, od) + self.assertIsNot(dup, od, msg) + self.assertEqual(dup, od) + self.assertEqual(list(dup.items()), list(od.items())) + self.assertEqual(len(dup), len(od)) + self.assertEqual(type(dup), type(od)) + check(od.copy()) + dup = copy.copy(od) + check(dup) + self.assertIs(dup.x, od.x) + self.assertIs(dup.z, od.z) + self.assertFalse(hasattr(dup, 'y')) + dup = copy.deepcopy(od) + check(dup) + self.assertEqual(dup.x, od.x) + self.assertIsNot(dup.x, od.x) + self.assertEqual(dup.z, od.z) + self.assertIsNot(dup.z, od.z) + self.assertFalse(hasattr(dup, 'y')) + # pickle directly pulls the module, so we have to fake it + with replaced_module('collections', self.module): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + dup = pickle.loads(pickle.dumps(od, proto)) + check(dup) + self.assertEqual(dup.x, od.x) + self.assertEqual(dup.z, od.z) + self.assertFalse(hasattr(dup, 'y')) + check(eval(repr(od))) + update_test = OrderedDict() + update_test.update(od) + check(update_test) + check(OrderedDict(od)) + + def test_yaml_linkage(self): + OrderedDict = self.OrderedDict + # Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature. + # In yaml, lists are native but tuples are not. + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + # yaml.dump(od) --> + # '!!python/object/apply:__main__.OrderedDict\n- - [a, 1]\n - [b, 2]\n' + self.assertTrue(all(type(pair)==list for pair in od.__reduce__()[1])) + + def test_reduce_not_too_fat(self): + OrderedDict = self.OrderedDict + # do not save instance dictionary if not needed + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + self.assertIsInstance(od.__dict__, dict) + self.assertIsNone(od.__reduce__()[2]) + od.x = 10 + self.assertEqual(od.__dict__['x'], 10) + self.assertEqual(od.__reduce__()[2], {'x': 10}) + + def test_pickle_recursive(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od[1] = od + + # pickle directly pulls the module, so we have to fake it + with replaced_module('collections', self.module): + for proto in range(-1, pickle.HIGHEST_PROTOCOL + 1): + dup = pickle.loads(pickle.dumps(od, proto)) + self.assertIsNot(dup, od) + self.assertEqual(list(dup.keys()), [1]) + self.assertIs(dup[1], dup) + + def test_repr(self): + OrderedDict = self.OrderedDict + od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]) + self.assertEqual(repr(od), + "OrderedDict({'c': 1, 'b': 2, 'a': 3, 'd': 4, 'e': 5, 'f': 6})") + self.assertEqual(eval(repr(od)), od) + self.assertEqual(repr(OrderedDict()), "OrderedDict()") + + def test_repr_recursive(self): + OrderedDict = self.OrderedDict + # See issue #9826 + od = OrderedDict.fromkeys('abc') + od['x'] = od + self.assertEqual(repr(od), + "OrderedDict({'a': None, 'b': None, 'c': None, 'x': ...})") + + def test_repr_recursive_values(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od[42] = od.values() + r = repr(od) + # Cannot perform a stronger test, as the contents of the repr + # are implementation-dependent. All we can say is that we + # want a str result, not an exception of any sort. + self.assertIsInstance(r, str) + od[42] = od.items() + r = repr(od) + # Again. + self.assertIsInstance(r, str) + + def test_setdefault(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + pair_order = list(od.items()) + self.assertEqual(od.setdefault('a', 10), 3) + # make sure order didn't change + self.assertEqual(list(od.items()), pair_order) + self.assertEqual(od.setdefault('x', 10), 10) + # make sure 'x' is added to the end + self.assertEqual(list(od.items())[-1], ('x', 10)) + self.assertEqual(od.setdefault('g', default=9), 9) + + # make sure setdefault still works when __missing__ is defined + class Missing(OrderedDict): + def __missing__(self, key): + return 0 + self.assertEqual(Missing().setdefault(5, 9), 9) + + def test_reinsert(self): + OrderedDict = self.OrderedDict + # Given insert a, insert b, delete a, re-insert a, + # verify that a is now later than b. + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + del od['a'] + self.assertEqual(list(od.items()), [('b', 2)]) + od['a'] = 1 + self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) + + def test_move_to_end(self): + OrderedDict = self.OrderedDict + od = OrderedDict.fromkeys('abcde') + self.assertEqual(list(od), list('abcde')) + od.move_to_end('c') + self.assertEqual(list(od), list('abdec')) + od.move_to_end('c', False) + self.assertEqual(list(od), list('cabde')) + od.move_to_end('c', False) + self.assertEqual(list(od), list('cabde')) + od.move_to_end('e') + self.assertEqual(list(od), list('cabde')) + od.move_to_end('b', last=False) + self.assertEqual(list(od), list('bcade')) + with self.assertRaises(KeyError): + od.move_to_end('x') + with self.assertRaises(KeyError): + od.move_to_end('x', False) + + def test_move_to_end_issue25406(self): + OrderedDict = self.OrderedDict + od = OrderedDict.fromkeys('abc') + od.move_to_end('c', last=False) + self.assertEqual(list(od), list('cab')) + od.move_to_end('a', last=False) + self.assertEqual(list(od), list('acb')) + + od = OrderedDict.fromkeys('abc') + od.move_to_end('a') + self.assertEqual(list(od), list('bca')) + od.move_to_end('c') + self.assertEqual(list(od), list('bac')) + + def test_sizeof(self): + OrderedDict = self.OrderedDict + # Wimpy test: Just verify the reported size is larger than a regular dict + d = dict(a=1) + od = OrderedDict(**d) + self.assertGreater(sys.getsizeof(od), sys.getsizeof(d)) + + def test_views(self): + OrderedDict = self.OrderedDict + # See http://bugs.python.org/issue24286 + s = 'the quick brown fox jumped over a lazy dog yesterday before dawn'.split() + od = OrderedDict.fromkeys(s) + self.assertEqual(od.keys(), dict(od).keys()) + self.assertEqual(od.items(), dict(od).items()) + + def test_override_update(self): + OrderedDict = self.OrderedDict + # Verify that subclasses can override update() without breaking __init__() + class MyOD(OrderedDict): + def update(self, *args, **kwds): + raise Exception() + items = [('a', 1), ('c', 3), ('b', 2)] + self.assertEqual(list(MyOD(items).items()), items) + + def test_highly_nested(self): + # Issues 25395 and 35983: test that the trashcan mechanism works + # correctly for OrderedDict: deleting a highly nested OrderDict + # should not crash Python. + OrderedDict = self.OrderedDict + obj = None + for _ in range(1000): + obj = OrderedDict([(None, obj)]) + del obj + support.gc_collect() + + def test_highly_nested_subclass(self): + # Issues 25395 and 35983: test that the trashcan mechanism works + # correctly for OrderedDict: deleting a highly nested OrderDict + # should not crash Python. + OrderedDict = self.OrderedDict + deleted = [] + class MyOD(OrderedDict): + def __del__(self): + deleted.append(self.i) + obj = None + for i in range(100): + obj = MyOD([(None, obj)]) + obj.i = i + del obj + support.gc_collect() + self.assertEqual(deleted, list(reversed(range(100)))) + + def test_delitem_hash_collision(self): + OrderedDict = self.OrderedDict + + class Key: + def __init__(self, hash): + self._hash = hash + self.value = str(id(self)) + def __hash__(self): + return self._hash + def __eq__(self, other): + try: + return self.value == other.value + except AttributeError: + return False + def __repr__(self): + return self.value + + def blocking_hash(hash): + # See the collision-handling in lookdict (in Objects/dictobject.c). + MINSIZE = 8 + i = (hash & MINSIZE-1) + return (i << 2) + i + hash + 1 + + COLLIDING = 1 + + key = Key(COLLIDING) + colliding = Key(COLLIDING) + blocking = Key(blocking_hash(COLLIDING)) + + od = OrderedDict() + od[key] = ... + od[blocking] = ... + od[colliding] = ... + od['after'] = ... + + del od[blocking] + del od[colliding] + self.assertEqual(list(od.items()), [(key, ...), ('after', ...)]) + + def test_issue24347(self): + OrderedDict = self.OrderedDict + + class Key: + def __hash__(self): + return randrange(100000) + + od = OrderedDict() + for i in range(100): + key = Key() + od[key] = i + + # These should not crash. + with self.assertRaises(KeyError): + list(od.values()) + with self.assertRaises(KeyError): + list(od.items()) + with self.assertRaises(KeyError): + repr(od) + with self.assertRaises(KeyError): + od.copy() + + def test_issue24348(self): + OrderedDict = self.OrderedDict + + class Key: + def __hash__(self): + return 1 + + od = OrderedDict() + od[Key()] = 0 + # This should not crash. + od.popitem() + + def test_issue24667(self): + """ + dict resizes after a certain number of insertion operations, + whether or not there were deletions that freed up slots in the + hash table. During fast node lookup, OrderedDict must correctly + respond to all resizes, even if the current "size" is the same + as the old one. We verify that here by forcing a dict resize + on a sparse odict and then perform an operation that should + trigger an odict resize (e.g. popitem). One key aspect here is + that we will keep the size of the odict the same at each popitem + call. This verifies that we handled the dict resize properly. + """ + OrderedDict = self.OrderedDict + + od = OrderedDict() + for c0 in '0123456789ABCDEF': + for c1 in '0123456789ABCDEF': + if len(od) == 4: + # This should not raise a KeyError. + od.popitem(last=False) + key = c0 + c1 + od[key] = key + + # Direct use of dict methods + + def test_dict_setitem(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + dict.__setitem__(od, 'spam', 1) + self.assertNotIn('NULL', repr(od)) + + def test_dict_delitem(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od['spam'] = 1 + od['ham'] = 2 + dict.__delitem__(od, 'spam') + with self.assertRaises(KeyError): + repr(od) + + def test_dict_clear(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od['spam'] = 1 + od['ham'] = 2 + dict.clear(od) + self.assertNotIn('NULL', repr(od)) + + def test_dict_pop(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od['spam'] = 1 + od['ham'] = 2 + dict.pop(od, 'spam') + with self.assertRaises(KeyError): + repr(od) + + def test_dict_popitem(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + od['spam'] = 1 + od['ham'] = 2 + dict.popitem(od) + with self.assertRaises(KeyError): + repr(od) + + def test_dict_setdefault(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + dict.setdefault(od, 'spam', 1) + self.assertNotIn('NULL', repr(od)) + + def test_dict_update(self): + OrderedDict = self.OrderedDict + od = OrderedDict() + dict.update(od, [('spam', 1)]) + self.assertNotIn('NULL', repr(od)) + + @suppress_immortalization() + def test_reference_loop(self): + # Issue 25935 + OrderedDict = self.OrderedDict + class A: + od = OrderedDict() + A.od[A] = None + r = weakref.ref(A) + del A + gc.collect() + self.assertIsNone(r()) + + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.keys()), self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.values()), self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.items()), self.OrderedDict) + + def test_merge_operator(self): + OrderedDict = self.OrderedDict + + a = OrderedDict({0: 0, 1: 1, 2: 1}) + b = OrderedDict({1: 1, 2: 2, 3: 3}) + + c = a.copy() + d = a.copy() + c |= b + d |= list(b.items()) + expected = OrderedDict({0: 0, 1: 1, 2: 2, 3: 3}) + self.assertEqual(a | dict(b), expected) + self.assertEqual(a | b, expected) + self.assertEqual(c, expected) + self.assertEqual(d, expected) + + c = b.copy() + c |= a + expected = OrderedDict({1: 1, 2: 1, 3: 3, 0: 0}) + self.assertEqual(dict(b) | a, expected) + self.assertEqual(b | a, expected) + self.assertEqual(c, expected) + + self.assertIs(type(a | b), OrderedDict) + self.assertIs(type(dict(a) | b), OrderedDict) + self.assertIs(type(a | dict(b)), OrderedDict) + + expected = a.copy() + a |= () + a |= "" + self.assertEqual(a, expected) + + with self.assertRaises(TypeError): + a | None + with self.assertRaises(TypeError): + a | () + with self.assertRaises(TypeError): + a | "BAD" + with self.assertRaises(TypeError): + a | "" + with self.assertRaises(ValueError): + a |= "BAD" + + @support.cpython_only + def test_ordered_dict_items_result_gc(self): + # bpo-42536: OrderedDict.items's tuple-reuse speed trick breaks the GC's + # assumptions about what can be untracked. Make sure we re-track result + # tuples whenever we reuse them. + it = iter(self.OrderedDict({None: []}).items()) + gc.collect() + # That GC collection probably untracked the recycled internal result + # tuple, which is initialized to (None, None). Make sure it's re-tracked + # when it's mutated and returned from __next__: + self.assertTrue(gc.is_tracked(next(it))) + + +class _TriggerSideEffectOnEqual: + count = 0 # number of calls to __eq__ + trigger = 1 # count value when to trigger side effect + + def __eq__(self, other): + if self.__class__.count == self.__class__.trigger: + self.side_effect() + self.__class__.count += 1 + return True + + def __hash__(self): + # all instances represent the same key + return -1 + + def side_effect(self): + raise NotImplementedError + +class PurePythonOrderedDictTests(OrderedDictTests, __TestCase): + + module = py_coll + OrderedDict = py_coll.OrderedDict + + def test_issue119004_attribute_error(self): + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] + + TODEL = Key() + dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + # This causes an AttributeError due to the linked list being changed + msg = re.escape("'NoneType' object has no attribute 'key'") + self.assertRaisesRegex(AttributeError, msg, operator.eq, dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, dict.fromkeys((0, 4.2))) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + +class CPythonBuiltinDictTests(__TestCase): + """Builtin dict preserves insertion order. + + Reuse some of tests in OrderedDict selectively. + """ + + module = builtins + OrderedDict = dict + +for method in ( + "test_init test_update test_abc test_clear test_delitem " + + "test_setitem test_detect_deletion_during_iteration " + + "test_popitem test_reinsert test_override_update " + + "test_highly_nested test_highly_nested_subclass " + + "test_delitem_hash_collision ").split(): + setattr(CPythonBuiltinDictTests, method, getattr(OrderedDictTests, method)) +del method + + + +class CPythonOrderedDictSideEffects: + + def check_runtime_error_issue119004(self, dict1, dict2): + msg = re.escape("OrderedDict mutated during iteration") + self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2) + + def test_issue119004_change_size_by_clear(self): + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + dict1.clear() + + dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + self.check_runtime_error_issue119004(dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, {}) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + def test_issue119004_change_size_by_delete_key(self): + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] + + TODEL = Key() + dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + self.check_runtime_error_issue119004(dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, dict.fromkeys((0, 4.2))) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + def test_issue119004_change_linked_list_by_clear(self): + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + dict1.clear() + dict1['a'] = dict1['b'] = 'c' + + dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + self.check_runtime_error_issue119004(dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, dict.fromkeys(('a', 'b'), 'c')) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + def test_issue119004_change_linked_list_by_delete_key(self): + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] + dict1['a'] = 'c' + + TODEL = Key() + dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + self.check_runtime_error_issue119004(dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, {0: None, 'a': 'c', 4.2: None}) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + def test_issue119004_change_size_by_delete_key_in_dict_eq(self): + class Key(_TriggerSideEffectOnEqual): + trigger = 0 + def side_effect(self): + del dict1[TODEL] + + TODEL = Key() + dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) + dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) + self.assertEqual(Key.count, 0) + # the side effect is in dict.__eq__ and modifies the length + self.assertNotEqual(dict1, dict2) + self.assertEqual(Key.count, 2) + self.assertDictEqual(dict1, dict.fromkeys((0, 4.2))) + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonOrderedDictTests(OrderedDictTests, + CPythonOrderedDictSideEffects, + __TestCase): + + module = c_coll + OrderedDict = c_coll.OrderedDict + check_sizeof = support.check_sizeof + + @support.cpython_only + def test_sizeof_exact(self): + OrderedDict = self.OrderedDict + calcsize = struct.calcsize + size = support.calcobjsize + check = self.check_sizeof + + basicsize = size('nQ2P' + '3PnPn2P') + keysize = calcsize('n2BI2n') + + entrysize = calcsize('n2P') + p = calcsize('P') + nodesize = calcsize('Pn2P') + + od = OrderedDict() + check(od, basicsize) # 8byte indices + 8*2//3 * entry table + od.x = 1 + check(od, basicsize) + od.update([(i, i) for i in range(3)]) + check(od, basicsize + keysize + 8*p + 8 + 5*entrysize + 3*nodesize) + od.update([(i, i) for i in range(3, 10)]) + check(od, basicsize + keysize + 16*p + 16 + 10*entrysize + 10*nodesize) + + check(od.keys(), size('P')) + check(od.items(), size('P')) + check(od.values(), size('P')) + + itersize = size('iP2n2P') + check(iter(od), itersize) + check(iter(od.keys()), itersize) + check(iter(od.items()), itersize) + check(iter(od.values()), itersize) + + def test_key_change_during_iteration(self): + OrderedDict = self.OrderedDict + + od = OrderedDict.fromkeys('abcde') + self.assertEqual(list(od), list('abcde')) + with self.assertRaises(RuntimeError): + for i, k in enumerate(od): + od.move_to_end(k) + self.assertLess(i, 5) + with self.assertRaises(RuntimeError): + for k in od: + od['f'] = None + with self.assertRaises(RuntimeError): + for k in od: + del od['c'] + self.assertEqual(list(od), list('bdeaf')) + + def test_iterators_pickling(self): + OrderedDict = self.OrderedDict + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + + for method_name in ('keys', 'values', 'items'): + meth = getattr(od, method_name) + expected = list(meth())[1:] + for i in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(method_name=method_name, protocol=i): + it = iter(meth()) + next(it) + p = pickle.dumps(it, i) + unpickled = pickle.loads(p) + self.assertEqual(list(unpickled), expected) + self.assertEqual(list(it), expected) + + @support.cpython_only + def test_weakref_list_is_not_traversed(self): + # Check that the weakref list is not traversed when collecting + # OrderedDict objects. See bpo-39778 for more information. + + gc.collect() + + x = self.OrderedDict() + x.cycle = x + + cycle = [] + cycle.append(cycle) + + x_ref = weakref.ref(x) + cycle.append(x_ref) + + del x, cycle, x_ref + + gc.collect() + + +class PurePythonOrderedDictSubclassTests(PurePythonOrderedDictTests): + + module = py_coll + class OrderedDict(py_coll.OrderedDict): + pass + + +class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests): + + module = c_coll + class OrderedDict(c_coll.OrderedDict): + pass + + +class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = py_coll + class OrderedDict(py_coll.OrderedDict): + __slots__ = ('x', 'y') + test_copying = OrderedDictTests.test_copying + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = c_coll + class OrderedDict(c_coll.OrderedDict): + __slots__ = ('x', 'y') + test_copying = OrderedDictTests.test_copying + + +class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + cls.type2test = py_coll.OrderedDict + super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + cls.type2test = c_coll.OrderedDict + super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + class MyOrderedDict(py_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict + super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + class MyOrderedDict(c_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict + super().setUpClass() + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +class SimpleLRUCache: + + def __init__(self, size): + super().__init__() + self.size = size + self.counts = dict.fromkeys(('get', 'set', 'del'), 0) + + def __getitem__(self, item): + self.counts['get'] += 1 + value = super().__getitem__(item) + self.move_to_end(item) + return value + + def __setitem__(self, key, value): + self.counts['set'] += 1 + while key not in self and len(self) >= self.size: + self.popitem(last=False) + super().__setitem__(key, value) + self.move_to_end(key) + + def __delitem__(self, key): + self.counts['del'] += 1 + super().__delitem__(key) + + +class SimpleLRUCacheTests: + + def test_add_after_full(self): + c = self.type2test(2) + c['t1'] = 1 + c['t2'] = 2 + c['t3'] = 3 + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + self.assertEqual(list(c), ['t2', 't3']) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + + def test_popitem(self): + c = self.type2test(3) + for i in range(1, 4): + c[i] = i + self.assertEqual(c.popitem(last=False), (1, 1)) + self.assertEqual(c.popitem(last=True), (3, 3)) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + + def test_pop(self): + c = self.type2test(3) + for i in range(1, 4): + c[i] = i + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + self.assertEqual(c.pop(2), 2) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + self.assertEqual(c.pop(4, 0), 0) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + self.assertRaises(KeyError, c.pop, 4) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + + def test_change_order_on_get(self): + c = self.type2test(3) + for i in range(1, 4): + c[i] = i + self.assertEqual(list(c), list(range(1, 4))) + self.assertEqual(c.counts, {'get': 0, 'set': 3, 'del': 0}) + self.assertEqual(c[2], 2) + self.assertEqual(c.counts, {'get': 1, 'set': 3, 'del': 0}) + self.assertEqual(list(c), [1, 3, 2]) + + +class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + + class type2test(SimpleLRUCache, py_coll.OrderedDict): + pass + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + + @classmethod + def setUpClass(cls): + class type2test(SimpleLRUCache, c_coll.OrderedDict): + pass + cls.type2test = type2test + super().setUpClass() + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_raise.diff b/test/dynamo/cpython/3_13/test_raise.diff new file mode 100644 index 00000000000000..4aaab8e8b66a21 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_raise.diff @@ -0,0 +1,120 @@ +diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py +index 6d26a61bee4..8a52b9bfc82 100644 +--- a/test/dynamo/cpython/3_13/test_raise.py ++++ b/test/dynamo/cpython/3_13/test_raise.py +@@ -1,3 +1,55 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++) ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Copyright 2007 Google, Inc. All Rights Reserved. + # Licensed to PSF under a Contributor Agreement. + +@@ -23,7 +75,7 @@ class Context: + return True + + +-class TestRaise(unittest.TestCase): ++class TestRaise(__TestCase): + def test_invalid_reraise(self): + try: + raise +@@ -148,7 +200,7 @@ class TestRaise(unittest.TestCase): + + + +-class TestCause(unittest.TestCase): ++class TestCause(__TestCase): + + def testCauseSyntax(self): + try: +@@ -221,7 +273,7 @@ class TestCause(unittest.TestCase): + self.fail("No exception raised") + + +-class TestTraceback(unittest.TestCase): ++class TestTraceback(__TestCase): + + def test_sets_traceback(self): + try: +@@ -242,7 +294,7 @@ class TestTraceback(unittest.TestCase): + self.fail("No exception raised") + + +-class TestTracebackType(unittest.TestCase): ++class TestTracebackType(__TestCase): + + def raiser(self): + raise ValueError +@@ -308,7 +360,7 @@ class TestTracebackType(unittest.TestCase): + types.TracebackType(other_tb, frame, 1, "nuh-uh") + + +-class TestContext(unittest.TestCase): ++class TestContext(__TestCase): + def test_instance_context_instance_raise(self): + context = IndexError() + try: +@@ -498,7 +550,7 @@ class TestContext(unittest.TestCase): + self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) + + +-class TestRemovedFunctionality(unittest.TestCase): ++class TestRemovedFunctionality(__TestCase): + def test_tuples(self): + try: + raise (IndexError, KeyError) # This should be a tuple! +@@ -517,4 +569,4 @@ class TestRemovedFunctionality(unittest.TestCase): + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py new file mode 100644 index 00000000000000..8a52b9bfc82f2b --- /dev/null +++ b/test/dynamo/cpython/3_13/test_raise.py @@ -0,0 +1,572 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, +) + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# Copyright 2007 Google, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement. + +"""Tests for the raise statement.""" + +from test import support +import sys +import types +import unittest + + +def get_tb(): + try: + raise OSError() + except OSError as e: + return e.__traceback__ + + +class Context: + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, exc_tb): + return True + + +class TestRaise(__TestCase): + def test_invalid_reraise(self): + try: + raise + except RuntimeError as e: + self.assertIn("No active exception", str(e)) + else: + self.fail("No exception raised") + + def test_reraise(self): + try: + try: + raise IndexError() + except IndexError as e: + exc1 = e + raise + except IndexError as exc2: + self.assertIs(exc1, exc2) + else: + self.fail("No exception raised") + + def test_except_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + except KeyError: + pass + raise + self.assertRaises(TypeError, reraise) + + def test_finally_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + finally: + raise + self.assertRaises(KeyError, reraise) + + def test_nested_reraise(self): + def nested_reraise(): + raise + def reraise(): + try: + raise TypeError("foo") + except: + nested_reraise() + self.assertRaises(TypeError, reraise) + + def test_raise_from_None(self): + try: + try: + raise TypeError("foo") + except: + raise ValueError() from None + except ValueError as e: + self.assertIsInstance(e.__context__, TypeError) + self.assertIsNone(e.__cause__) + + def test_with_reraise1(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + pass + raise + self.assertRaises(TypeError, reraise) + + def test_with_reraise2(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + raise KeyError("caught") + raise + self.assertRaises(TypeError, reraise) + + def test_yield_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + yield 1 + raise + g = reraise() + next(g) + self.assertRaises(TypeError, lambda: next(g)) + self.assertRaises(StopIteration, lambda: next(g)) + + def test_erroneous_exception(self): + class MyException(Exception): + def __init__(self): + raise RuntimeError() + + try: + raise MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + def test_new_returns_invalid_instance(self): + # See issue #11627. + class MyException(Exception): + def __new__(cls, *args): + return object() + + with self.assertRaises(TypeError): + raise MyException + + def test_assert_with_tuple_arg(self): + try: + assert False, (3,) + except AssertionError as e: + self.assertEqual(str(e), "(3,)") + + + +class TestCause(__TestCase): + + def testCauseSyntax(self): + try: + try: + try: + raise TypeError + except Exception: + raise ValueError from None + except ValueError as exc: + self.assertIsNone(exc.__cause__) + self.assertTrue(exc.__suppress_context__) + exc.__suppress_context__ = False + raise exc + except ValueError as exc: + e = exc + + self.assertIsNone(e.__cause__) + self.assertFalse(e.__suppress_context__) + self.assertIsInstance(e.__context__, TypeError) + + def test_invalid_cause(self): + try: + raise IndexError from 5 + except TypeError as e: + self.assertIn("exception cause", str(e)) + else: + self.fail("No exception raised") + + def test_class_cause(self): + try: + raise IndexError from KeyError + except IndexError as e: + self.assertIsInstance(e.__cause__, KeyError) + else: + self.fail("No exception raised") + + def test_class_cause_nonexception_result(self): + class ConstructsNone(BaseException): + @classmethod + def __new__(*args, **kwargs): + return None + try: + raise IndexError from ConstructsNone + except TypeError as e: + self.assertIn("should have returned an instance of BaseException", str(e)) + except IndexError: + self.fail("Wrong kind of exception raised") + else: + self.fail("No exception raised") + + def test_instance_cause(self): + cause = KeyError() + try: + raise IndexError from cause + except IndexError as e: + self.assertIs(e.__cause__, cause) + else: + self.fail("No exception raised") + + def test_erroneous_cause(self): + class MyException(Exception): + def __init__(self): + raise RuntimeError() + + try: + raise IndexError from MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + +class TestTraceback(__TestCase): + + def test_sets_traceback(self): + try: + raise IndexError() + except IndexError as e: + self.assertIsInstance(e.__traceback__, types.TracebackType) + else: + self.fail("No exception raised") + + def test_accepts_traceback(self): + tb = get_tb() + try: + raise IndexError().with_traceback(tb) + except IndexError as e: + self.assertNotEqual(e.__traceback__, tb) + self.assertEqual(e.__traceback__.tb_next, tb) + else: + self.fail("No exception raised") + + +class TestTracebackType(__TestCase): + + def raiser(self): + raise ValueError + + def test_attrs(self): + try: + self.raiser() + except Exception as exc: + tb = exc.__traceback__ + + self.assertIsInstance(tb.tb_next, types.TracebackType) + self.assertIs(tb.tb_frame, sys._getframe()) + self.assertIsInstance(tb.tb_lasti, int) + self.assertIsInstance(tb.tb_lineno, int) + + self.assertIs(tb.tb_next.tb_next, None) + + # Invalid assignments + with self.assertRaises(TypeError): + del tb.tb_next + + with self.assertRaises(TypeError): + tb.tb_next = "asdf" + + # Loops + with self.assertRaises(ValueError): + tb.tb_next = tb + + with self.assertRaises(ValueError): + tb.tb_next.tb_next = tb + + # Valid assignments + tb.tb_next = None + self.assertIs(tb.tb_next, None) + + new_tb = get_tb() + tb.tb_next = new_tb + self.assertIs(tb.tb_next, new_tb) + + def test_constructor(self): + other_tb = get_tb() + frame = sys._getframe() + + tb = types.TracebackType(other_tb, frame, 1, 2) + self.assertEqual(tb.tb_next, other_tb) + self.assertEqual(tb.tb_frame, frame) + self.assertEqual(tb.tb_lasti, 1) + self.assertEqual(tb.tb_lineno, 2) + + tb = types.TracebackType(None, frame, 1, 2) + self.assertEqual(tb.tb_next, None) + + with self.assertRaises(TypeError): + types.TracebackType("no", frame, 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, "no", 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, "no", 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, 1, "nuh-uh") + + +class TestContext(__TestCase): + def test_instance_context_instance_raise(self): + context = IndexError() + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertIs(e.__context__, context) + else: + self.fail("No exception raised") + + def test_class_context_instance_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertIsNot(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + def test_class_context_class_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError + except OSError as e: + self.assertIsNot(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + def test_c_exception_context(self): + try: + try: + 1/0 + except: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + def test_c_exception_raise(self): + try: + try: + 1/0 + except: + xyzzy + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + def test_noraise_finally(self): + try: + try: + pass + finally: + raise OSError + except OSError as e: + self.assertIsNone(e.__context__) + else: + self.fail("No exception raised") + + def test_raise_finally(self): + try: + try: + 1/0 + finally: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + def test_context_manager(self): + class ContextManager: + def __enter__(self): + pass + def __exit__(self, t, v, tb): + xyzzy + try: + with ContextManager(): + 1/0 + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + def test_cycle_broken(self): + # Self-cycles (when re-raising a caught exception) are broken + try: + try: + 1/0 + except ZeroDivisionError as e: + raise e + except ZeroDivisionError as e: + self.assertIsNone(e.__context__) + + def test_reraise_cycle_broken(self): + # Non-trivial context cycles (through re-raising a previous exception) + # are broken too. + try: + try: + xyzzy + except NameError as a: + try: + 1/0 + except ZeroDivisionError: + raise a + except NameError as e: + self.assertIsNone(e.__context__.__context__) + + def test_not_last(self): + # Context is not necessarily the last exception + context = Exception("context") + try: + raise context + except Exception: + try: + raise Exception("caught") + except Exception: + pass + try: + raise Exception("new") + except Exception as exc: + raised = exc + self.assertIs(raised.__context__, context) + + def test_3118(self): + # deleting the generator caused the __context__ to be cleared + def gen(): + try: + yield 1 + finally: + pass + + def f(): + g = gen() + next(g) + try: + try: + raise ValueError + except: + del g + raise KeyError + except Exception as e: + self.assertIsInstance(e.__context__, ValueError) + + f() + + def test_3611(self): + import gc + # A re-raised exception in a __del__ caused the __context__ + # to be cleared + class C: + def __del__(self): + try: + 1/0 + except: + raise + + def f(): + x = C() + try: + try: + f.x + except AttributeError: + # make x.__del__ trigger + del x + gc.collect() # For PyPy or other GCs. + raise TypeError + except Exception as e: + self.assertNotEqual(e.__context__, None) + self.assertIsInstance(e.__context__, AttributeError) + + with support.catch_unraisable_exception() as cm: + f() + + self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) + + +class TestRemovedFunctionality(__TestCase): + def test_tuples(self): + try: + raise (IndexError, KeyError) # This should be a tuple! + except TypeError: + pass + else: + self.fail("No exception raised") + + def test_strings(self): + try: + raise "foo" + except TypeError: + pass + else: + self.fail("No exception raised") + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_sys.diff b/test/dynamo/cpython/3_13/test_sys.diff new file mode 100644 index 00000000000000..7fd0241560565a --- /dev/null +++ b/test/dynamo/cpython/3_13/test_sys.diff @@ -0,0 +1,157 @@ +diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py +index 72d51361e0b..0b4c6882e62 100644 +--- a/test/dynamo/cpython/3_13/test_sys.py ++++ b/test/dynamo/cpython/3_13/test_sys.py +@@ -1,3 +1,55 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++) ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + import builtins + import codecs + import _datetime +@@ -35,7 +87,7 @@ def requires_subinterpreters(meth): + + DICT_KEY_STRUCT_FORMAT = 'n2BI2n' + +-class DisplayHookTest(unittest.TestCase): ++class DisplayHookTest(__TestCase): + + def test_original_displayhook(self): + dh = sys.__displayhook__ +@@ -81,19 +133,8 @@ class DisplayHookTest(unittest.TestCase): + code = compile("42", "", "single") + self.assertRaises(ValueError, eval, code) + +- def test_gh130163(self): +- class X: +- def __repr__(self): +- sys.stdout = io.StringIO() +- support.gc_collect() +- return 'foo' +- +- with support.swap_attr(sys, 'stdout', None): +- sys.stdout = io.StringIO() # the only reference +- sys.displayhook(X()) # should not crash + +- +-class ActiveExceptionTests(unittest.TestCase): ++class ActiveExceptionTests(__TestCase): + def test_exc_info_no_exception(self): + self.assertEqual(sys.exc_info(), (None, None, None)) + +@@ -157,7 +198,7 @@ class ActiveExceptionTests(unittest.TestCase): + self.assertIs(exc, e) + + +-class ExceptHookTest(unittest.TestCase): ++class ExceptHookTest(__TestCase): + + @force_not_colorized + def test_original_excepthook(self): +@@ -200,7 +241,7 @@ class ExceptHookTest(unittest.TestCase): + # Python/pythonrun.c::PyErr_PrintEx() is tricky. + + +-class SysModuleTest(unittest.TestCase): ++class SysModuleTest(__TestCase): + + def tearDown(self): + test.support.reap_children() +@@ -500,6 +541,7 @@ class SysModuleTest(unittest.TestCase): + is sys._getframe().f_code + ) + ++ @unittest.expectedFailure + def test_getframemodulename(self): + # Default depth gets ourselves + self.assertEqual(__name__, sys._getframemodulename()) +@@ -808,7 +850,7 @@ class SysModuleTest(unittest.TestCase): + self.assertRaises(TypeError, sys.intern, S("abc")) + if has_is_interned: + self.assertIs(sys._is_interned(S("abc")), False) +- ++ + @support.cpython_only + @requires_subinterpreters + def test_subinterp_intern_dynamically_allocated(self): +@@ -1359,7 +1401,7 @@ class SysModuleTest(unittest.TestCase): + + + @test.support.cpython_only +-class UnraisableHookTest(unittest.TestCase): ++class UnraisableHookTest(__TestCase): + def test_original_unraisablehook(self): + _testcapi = import_helper.import_module('_testcapi') + from _testcapi import err_writeunraisable, err_formatunraisable +@@ -1516,7 +1558,7 @@ class UnraisableHookTest(unittest.TestCase): + + + @test.support.cpython_only +-class SizeofTest(unittest.TestCase): ++class SizeofTest(__TestCase): + + def setUp(self): + self.P = struct.calcsize('P') +@@ -1524,6 +1566,7 @@ class SizeofTest(unittest.TestCase): + _testinternalcapi = import_helper.import_module("_testinternalcapi") + self.gc_headsize = _testinternalcapi.SIZEOF_PYGC_HEAD + self.managed_pre_header_size = _testinternalcapi.SIZEOF_MANAGED_PRE_HEADER ++ super().setUp() + + check_sizeof = test.support.check_sizeof + +@@ -1960,4 +2003,4 @@ class SizeofTest(unittest.TestCase): + self.assertEqual(err, b"") + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py new file mode 100644 index 00000000000000..f2d782127a4852 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_sys.py @@ -0,0 +1,2006 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, +) + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import builtins +import codecs +import _datetime +import gc +import io +import locale +import operator +import os +import random +import struct +import subprocess +import sys +import sysconfig +import test.support +from test import support +from test.support import os_helper +from test.support.script_helper import assert_python_ok, assert_python_failure +from test.support import threading_helper +from test.support import import_helper +from test.support import force_not_colorized +try: + from test.support import interpreters +except ImportError: + interpreters = None +import textwrap +import unittest +import warnings + + +def requires_subinterpreters(meth): + """Decorator to skip a test if subinterpreters are not supported.""" + return unittest.skipIf(interpreters is None, + 'subinterpreters required')(meth) + + +DICT_KEY_STRUCT_FORMAT = 'n2BI2n' + +class DisplayHookTest(__TestCase): + + def test_original_displayhook(self): + dh = sys.__displayhook__ + + with support.captured_stdout() as out: + dh(42) + + self.assertEqual(out.getvalue(), "42\n") + self.assertEqual(builtins._, 42) + + del builtins._ + + with support.captured_stdout() as out: + dh(None) + + self.assertEqual(out.getvalue(), "") + self.assertTrue(not hasattr(builtins, "_")) + + # sys.displayhook() requires arguments + self.assertRaises(TypeError, dh) + + stdout = sys.stdout + try: + del sys.stdout + self.assertRaises(RuntimeError, dh, 42) + finally: + sys.stdout = stdout + + def test_lost_displayhook(self): + displayhook = sys.displayhook + try: + del sys.displayhook + code = compile("42", "", "single") + self.assertRaises(RuntimeError, eval, code) + finally: + sys.displayhook = displayhook + + def test_custom_displayhook(self): + def baddisplayhook(obj): + raise ValueError + + with support.swap_attr(sys, 'displayhook', baddisplayhook): + code = compile("42", "", "single") + self.assertRaises(ValueError, eval, code) + + +class ActiveExceptionTests(__TestCase): + def test_exc_info_no_exception(self): + self.assertEqual(sys.exc_info(), (None, None, None)) + + def test_sys_exception_no_exception(self): + self.assertEqual(sys.exception(), None) + + def test_exc_info_with_exception_instance(self): + def f(): + raise ValueError(42) + + try: + f() + except Exception as e_: + e = e_ + exc_info = sys.exc_info() + + self.assertIsInstance(e, ValueError) + self.assertIs(exc_info[0], ValueError) + self.assertIs(exc_info[1], e) + self.assertIs(exc_info[2], e.__traceback__) + + def test_exc_info_with_exception_type(self): + def f(): + raise ValueError + + try: + f() + except Exception as e_: + e = e_ + exc_info = sys.exc_info() + + self.assertIsInstance(e, ValueError) + self.assertIs(exc_info[0], ValueError) + self.assertIs(exc_info[1], e) + self.assertIs(exc_info[2], e.__traceback__) + + def test_sys_exception_with_exception_instance(self): + def f(): + raise ValueError(42) + + try: + f() + except Exception as e_: + e = e_ + exc = sys.exception() + + self.assertIsInstance(e, ValueError) + self.assertIs(exc, e) + + def test_sys_exception_with_exception_type(self): + def f(): + raise ValueError + + try: + f() + except Exception as e_: + e = e_ + exc = sys.exception() + + self.assertIsInstance(e, ValueError) + self.assertIs(exc, e) + + +class ExceptHookTest(__TestCase): + + @force_not_colorized + def test_original_excepthook(self): + try: + raise ValueError(42) + except ValueError as exc: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + + self.assertTrue(err.getvalue().endswith("ValueError: 42\n")) + + self.assertRaises(TypeError, sys.__excepthook__) + + @force_not_colorized + def test_excepthook_bytes_filename(self): + # bpo-37467: sys.excepthook() must not crash if a filename + # is a bytes string + with warnings.catch_warnings(): + warnings.simplefilter('ignore', BytesWarning) + + try: + raise SyntaxError("msg", (b"bytes_filename", 123, 0, "text")) + except SyntaxError as exc: + with support.captured_stderr() as err: + sys.__excepthook__(*sys.exc_info()) + + err = err.getvalue() + self.assertIn(""" File "b'bytes_filename'", line 123\n""", err) + self.assertIn(""" text\n""", err) + self.assertTrue(err.endswith("SyntaxError: msg\n")) + + def test_excepthook(self): + with test.support.captured_output("stderr") as stderr: + with test.support.catch_unraisable_exception(): + sys.excepthook(1, '1', 1) + self.assertTrue("TypeError: print_exception(): Exception expected for " \ + "value, str found" in stderr.getvalue()) + + # FIXME: testing the code for a lost or replaced excepthook in + # Python/pythonrun.c::PyErr_PrintEx() is tricky. + + +class SysModuleTest(__TestCase): + + def tearDown(self): + test.support.reap_children() + + def test_exit(self): + # call with two arguments + self.assertRaises(TypeError, sys.exit, 42, 42) + + # call without argument + with self.assertRaises(SystemExit) as cm: + sys.exit() + self.assertIsNone(cm.exception.code) + + rc, out, err = assert_python_ok('-c', 'import sys; sys.exit()') + self.assertEqual(rc, 0) + self.assertEqual(out, b'') + self.assertEqual(err, b'') + + # gh-125842: Windows uses 32-bit unsigned integers for exit codes + # so a -1 exit code is sometimes interpreted as 0xffff_ffff. + rc, out, err = assert_python_failure('-c', 'import sys; sys.exit(0xffff_ffff)') + self.assertIn(rc, (-1, 0xff, 0xffff_ffff)) + self.assertEqual(out, b'') + self.assertEqual(err, b'') + + # Overflow results in a -1 exit code, which may be converted to 0xff + # or 0xffff_ffff. + rc, out, err = assert_python_failure('-c', 'import sys; sys.exit(2**128)') + self.assertIn(rc, (-1, 0xff, 0xffff_ffff)) + self.assertEqual(out, b'') + self.assertEqual(err, b'') + + # call with integer argument + with self.assertRaises(SystemExit) as cm: + sys.exit(42) + self.assertEqual(cm.exception.code, 42) + + # call with tuple argument with one entry + # entry will be unpacked + with self.assertRaises(SystemExit) as cm: + sys.exit((42,)) + self.assertEqual(cm.exception.code, 42) + + # call with string argument + with self.assertRaises(SystemExit) as cm: + sys.exit("exit") + self.assertEqual(cm.exception.code, "exit") + + # call with tuple argument with two entries + with self.assertRaises(SystemExit) as cm: + sys.exit((17, 23)) + self.assertEqual(cm.exception.code, (17, 23)) + + # test that the exit machinery handles SystemExits properly + rc, out, err = assert_python_failure('-c', 'raise SystemExit(47)') + self.assertEqual(rc, 47) + self.assertEqual(out, b'') + self.assertEqual(err, b'') + + def check_exit_message(code, expected, **env_vars): + rc, out, err = assert_python_failure('-c', code, **env_vars) + self.assertEqual(rc, 1) + self.assertEqual(out, b'') + self.assertTrue(err.startswith(expected), + "%s doesn't start with %s" % (ascii(err), ascii(expected))) + + # test that stderr buffer is flushed before the exit message is written + # into stderr + check_exit_message( + r'import sys; sys.stderr.write("unflushed,"); sys.exit("message")', + b"unflushed,message") + + # test that the exit message is written with backslashreplace error + # handler to stderr + check_exit_message( + r'import sys; sys.exit("surrogates:\uDCFF")', + b"surrogates:\\udcff") + + # test that the unicode message is encoded to the stderr encoding + # instead of the default encoding (utf8) + check_exit_message( + r'import sys; sys.exit("h\xe9")', + b"h\xe9", PYTHONIOENCODING='latin-1') + + @support.requires_subprocess() + def test_exit_codes_under_repl(self): + # GH-129900: SystemExit, or things that raised it, didn't + # get their return code propagated by the REPL + import tempfile + + exit_ways = [ + "exit", + "__import__('sys').exit", + "raise SystemExit" + ] + + for exitfunc in exit_ways: + for return_code in (0, 123): + with self.subTest(exitfunc=exitfunc, return_code=return_code): + with tempfile.TemporaryFile("w+") as stdin: + stdin.write(f"{exitfunc}({return_code})\n") + stdin.seek(0) + proc = subprocess.run([sys.executable], stdin=stdin) + self.assertEqual(proc.returncode, return_code) + + def test_getdefaultencoding(self): + self.assertRaises(TypeError, sys.getdefaultencoding, 42) + # can't check more than the type, as the user might have changed it + self.assertIsInstance(sys.getdefaultencoding(), str) + + # testing sys.settrace() is done in test_sys_settrace.py + # testing sys.setprofile() is done in test_sys_setprofile.py + + def test_switchinterval(self): + self.assertRaises(TypeError, sys.setswitchinterval) + self.assertRaises(TypeError, sys.setswitchinterval, "a") + self.assertRaises(ValueError, sys.setswitchinterval, -1.0) + self.assertRaises(ValueError, sys.setswitchinterval, 0.0) + orig = sys.getswitchinterval() + # sanity check + self.assertTrue(orig < 0.5, orig) + try: + for n in 0.00001, 0.05, 3.0, orig: + sys.setswitchinterval(n) + self.assertAlmostEqual(sys.getswitchinterval(), n) + finally: + sys.setswitchinterval(orig) + + def test_getrecursionlimit(self): + limit = sys.getrecursionlimit() + self.assertIsInstance(limit, int) + self.assertGreater(limit, 1) + + self.assertRaises(TypeError, sys.getrecursionlimit, 42) + + def test_setrecursionlimit(self): + old_limit = sys.getrecursionlimit() + try: + sys.setrecursionlimit(10_005) + self.assertEqual(sys.getrecursionlimit(), 10_005) + + self.assertRaises(TypeError, sys.setrecursionlimit) + self.assertRaises(ValueError, sys.setrecursionlimit, -42) + finally: + sys.setrecursionlimit(old_limit) + + def test_recursionlimit_recovery(self): + if hasattr(sys, 'gettrace') and sys.gettrace(): + self.skipTest('fatal error if run with a trace function') + + old_limit = sys.getrecursionlimit() + def f(): + f() + try: + for depth in (50, 75, 100, 250, 1000): + try: + sys.setrecursionlimit(depth) + except RecursionError: + # Issue #25274: The recursion limit is too low at the + # current recursion depth + continue + + # Issue #5392: test stack overflow after hitting recursion + # limit twice + with self.assertRaises(RecursionError): + f() + with self.assertRaises(RecursionError): + f() + finally: + sys.setrecursionlimit(old_limit) + + @test.support.cpython_only + def test_setrecursionlimit_to_depth(self): + # Issue #25274: Setting a low recursion limit must be blocked if the + # current recursion depth is already higher than limit. + + old_limit = sys.getrecursionlimit() + try: + depth = support.get_recursion_depth() + with self.subTest(limit=sys.getrecursionlimit(), depth=depth): + # depth + 1 is OK + sys.setrecursionlimit(depth + 1) + + # reset the limit to be able to call self.assertRaises() + # context manager + sys.setrecursionlimit(old_limit) + with self.assertRaises(RecursionError) as cm: + sys.setrecursionlimit(depth) + self.assertRegex(str(cm.exception), + "cannot set the recursion limit to [0-9]+ " + "at the recursion depth [0-9]+: " + "the limit is too low") + finally: + sys.setrecursionlimit(old_limit) + + @unittest.skipUnless(support.Py_GIL_DISABLED, "only meaningful if the GIL is disabled") + @threading_helper.requires_working_threading() + def test_racing_recursion_limit(self): + from threading import Thread + def something_recursive(): + def count(n): + if n > 0: + return count(n - 1) + 1 + return 0 + + count(50) + + def set_recursion_limit(): + for limit in range(100, 200): + sys.setrecursionlimit(limit) + + threads = [] + for _ in range(5): + threads.append(Thread(target=set_recursion_limit)) + + for _ in range(5): + threads.append(Thread(target=something_recursive)) + + with threading_helper.catch_threading_exception() as cm: + with threading_helper.start_threads(threads): + pass + + if cm.exc_value: + raise cm.exc_value + + def test_getwindowsversion(self): + # Raise SkipTest if sys doesn't have getwindowsversion attribute + test.support.get_attribute(sys, "getwindowsversion") + v = sys.getwindowsversion() + self.assertEqual(len(v), 5) + self.assertIsInstance(v[0], int) + self.assertIsInstance(v[1], int) + self.assertIsInstance(v[2], int) + self.assertIsInstance(v[3], int) + self.assertIsInstance(v[4], str) + self.assertRaises(IndexError, operator.getitem, v, 5) + self.assertIsInstance(v.major, int) + self.assertIsInstance(v.minor, int) + self.assertIsInstance(v.build, int) + self.assertIsInstance(v.platform, int) + self.assertIsInstance(v.service_pack, str) + self.assertIsInstance(v.service_pack_minor, int) + self.assertIsInstance(v.service_pack_major, int) + self.assertIsInstance(v.suite_mask, int) + self.assertIsInstance(v.product_type, int) + self.assertEqual(v[0], v.major) + self.assertEqual(v[1], v.minor) + self.assertEqual(v[2], v.build) + self.assertEqual(v[3], v.platform) + self.assertEqual(v[4], v.service_pack) + + # This is how platform.py calls it. Make sure tuple + # still has 5 elements + maj, min, buildno, plat, csd = sys.getwindowsversion() + + def test_call_tracing(self): + self.assertRaises(TypeError, sys.call_tracing, type, 2) + + @unittest.skipUnless(hasattr(sys, "setdlopenflags"), + 'test needs sys.setdlopenflags()') + def test_dlopenflags(self): + self.assertTrue(hasattr(sys, "getdlopenflags")) + self.assertRaises(TypeError, sys.getdlopenflags, 42) + oldflags = sys.getdlopenflags() + self.assertRaises(TypeError, sys.setdlopenflags) + sys.setdlopenflags(oldflags+1) + self.assertEqual(sys.getdlopenflags(), oldflags+1) + sys.setdlopenflags(oldflags) + + @test.support.refcount_test + def test_refcount(self): + # n here originally had to be a global in order for this test to pass + # while tracing with a python function. Tracing used to call + # PyFrame_FastToLocals, which would add a copy of any locals to the + # frame object, causing the ref count to increase by 2 instead of 1. + # While that no longer happens (due to PEP 667), this test case retains + # its original global-based implementation + # PEP 683's immortal objects also made this point moot, since the + # refcount for None doesn't change anyway. Maybe this test should be + # using a different constant value? (e.g. an integer) + global n + self.assertRaises(TypeError, sys.getrefcount) + c = sys.getrefcount(None) + n = None + # Singleton refcnts don't change + self.assertEqual(sys.getrefcount(None), c) + del n + self.assertEqual(sys.getrefcount(None), c) + if hasattr(sys, "gettotalrefcount"): + self.assertIsInstance(sys.gettotalrefcount(), int) + + def test_getframe(self): + self.assertRaises(TypeError, sys._getframe, 42, 42) + self.assertRaises(ValueError, sys._getframe, 2000000000) + self.assertTrue( + SysModuleTest.test_getframe.__code__ \ + is sys._getframe().f_code + ) + + @unittest.expectedFailure + def test_getframemodulename(self): + # Default depth gets ourselves + self.assertEqual(__name__, sys._getframemodulename()) + self.assertEqual("unittest.case", sys._getframemodulename(1)) + i = 0 + f = sys._getframe(i) + while f: + self.assertEqual( + f.f_globals['__name__'], + sys._getframemodulename(i) or '__main__' + ) + i += 1 + f2 = f.f_back + try: + f = sys._getframe(i) + except ValueError: + break + self.assertIs(f, f2) + self.assertIsNone(sys._getframemodulename(i)) + + # sys._current_frames() is a CPython-only gimmick. + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_current_frames(self): + import threading + import traceback + + # Spawn a thread that blocks at a known place. Then the main + # thread does sys._current_frames(), and verifies that the frames + # returned make sense. + entered_g = threading.Event() + leave_g = threading.Event() + thread_info = [] # the thread's id + + def f123(): + g456() + + def g456(): + thread_info.append(threading.get_ident()) + entered_g.set() + leave_g.wait() + + t = threading.Thread(target=f123) + t.start() + entered_g.wait() + + try: + # At this point, t has finished its entered_g.set(), although it's + # impossible to guess whether it's still on that line or has moved on + # to its leave_g.wait(). + self.assertEqual(len(thread_info), 1) + thread_id = thread_info[0] + + d = sys._current_frames() + for tid in d: + self.assertIsInstance(tid, int) + self.assertGreater(tid, 0) + + main_id = threading.get_ident() + self.assertIn(main_id, d) + self.assertIn(thread_id, d) + + # Verify that the captured main-thread frame is _this_ frame. + frame = d.pop(main_id) + self.assertTrue(frame is sys._getframe()) + + # Verify that the captured thread frame is blocked in g456, called + # from f123. This is a little tricky, since various bits of + # threading.py are also in the thread's call stack. + frame = d.pop(thread_id) + stack = traceback.extract_stack(frame) + for i, (filename, lineno, funcname, sourceline) in enumerate(stack): + if funcname == "f123": + break + else: + self.fail("didn't find f123() on thread's call stack") + + self.assertEqual(sourceline, "g456()") + + # And the next record must be for g456(). + filename, lineno, funcname, sourceline = stack[i+1] + self.assertEqual(funcname, "g456") + self.assertIn(sourceline, ["leave_g.wait()", "entered_g.set()"]) + finally: + # Reap the spawned thread. + leave_g.set() + t.join() + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_current_exceptions(self): + import threading + import traceback + + # Spawn a thread that blocks at a known place. Then the main + # thread does sys._current_frames(), and verifies that the frames + # returned make sense. + g_raised = threading.Event() + leave_g = threading.Event() + thread_info = [] # the thread's id + + def f123(): + g456() + + def g456(): + thread_info.append(threading.get_ident()) + while True: + try: + raise ValueError("oops") + except ValueError: + g_raised.set() + if leave_g.wait(timeout=support.LONG_TIMEOUT): + break + + t = threading.Thread(target=f123) + t.start() + g_raised.wait(timeout=support.LONG_TIMEOUT) + + try: + self.assertEqual(len(thread_info), 1) + thread_id = thread_info[0] + + d = sys._current_exceptions() + for tid in d: + self.assertIsInstance(tid, int) + self.assertGreater(tid, 0) + + main_id = threading.get_ident() + self.assertIn(main_id, d) + self.assertIn(thread_id, d) + self.assertEqual(None, d.pop(main_id)) + + # Verify that the captured thread frame is blocked in g456, called + # from f123. This is a little tricky, since various bits of + # threading.py are also in the thread's call stack. + exc_value = d.pop(thread_id) + stack = traceback.extract_stack(exc_value.__traceback__.tb_frame) + for i, (filename, lineno, funcname, sourceline) in enumerate(stack): + if funcname == "f123": + break + else: + self.fail("didn't find f123() on thread's call stack") + + self.assertEqual(sourceline, "g456()") + + # And the next record must be for g456(). + filename, lineno, funcname, sourceline = stack[i+1] + self.assertEqual(funcname, "g456") + self.assertTrue((sourceline.startswith("if leave_g.wait(") or + sourceline.startswith("g_raised.set()"))) + finally: + # Reap the spawned thread. + leave_g.set() + t.join() + + def test_attributes(self): + self.assertIsInstance(sys.api_version, int) + self.assertIsInstance(sys.argv, list) + for arg in sys.argv: + self.assertIsInstance(arg, str) + self.assertIsInstance(sys.orig_argv, list) + for arg in sys.orig_argv: + self.assertIsInstance(arg, str) + self.assertIn(sys.byteorder, ("little", "big")) + self.assertIsInstance(sys.builtin_module_names, tuple) + self.assertIsInstance(sys.copyright, str) + self.assertIsInstance(sys.exec_prefix, str) + self.assertIsInstance(sys.base_exec_prefix, str) + self.assertIsInstance(sys.executable, str) + self.assertEqual(len(sys.float_info), 11) + self.assertEqual(sys.float_info.radix, 2) + self.assertEqual(len(sys.int_info), 4) + self.assertTrue(sys.int_info.bits_per_digit % 5 == 0) + self.assertTrue(sys.int_info.sizeof_digit >= 1) + self.assertGreaterEqual(sys.int_info.default_max_str_digits, 500) + self.assertGreaterEqual(sys.int_info.str_digits_check_threshold, 100) + self.assertGreater(sys.int_info.default_max_str_digits, + sys.int_info.str_digits_check_threshold) + self.assertEqual(type(sys.int_info.bits_per_digit), int) + self.assertEqual(type(sys.int_info.sizeof_digit), int) + self.assertIsInstance(sys.int_info.default_max_str_digits, int) + self.assertIsInstance(sys.int_info.str_digits_check_threshold, int) + self.assertIsInstance(sys.hexversion, int) + + self.assertEqual(len(sys.hash_info), 9) + self.assertLess(sys.hash_info.modulus, 2**sys.hash_info.width) + # sys.hash_info.modulus should be a prime; we do a quick + # probable primality test (doesn't exclude the possibility of + # a Carmichael number) + for x in range(1, 100): + self.assertEqual( + pow(x, sys.hash_info.modulus-1, sys.hash_info.modulus), + 1, + "sys.hash_info.modulus {} is a non-prime".format( + sys.hash_info.modulus) + ) + self.assertIsInstance(sys.hash_info.inf, int) + self.assertIsInstance(sys.hash_info.nan, int) + self.assertIsInstance(sys.hash_info.imag, int) + algo = sysconfig.get_config_var("Py_HASH_ALGORITHM") + if sys.hash_info.algorithm in {"fnv", "siphash13", "siphash24"}: + self.assertIn(sys.hash_info.hash_bits, {32, 64}) + self.assertIn(sys.hash_info.seed_bits, {32, 64, 128}) + + if algo == 1: + self.assertEqual(sys.hash_info.algorithm, "siphash24") + elif algo == 2: + self.assertEqual(sys.hash_info.algorithm, "fnv") + elif algo == 3: + self.assertEqual(sys.hash_info.algorithm, "siphash13") + else: + self.assertIn(sys.hash_info.algorithm, {"fnv", "siphash13", "siphash24"}) + else: + # PY_HASH_EXTERNAL + self.assertEqual(algo, 0) + self.assertGreaterEqual(sys.hash_info.cutoff, 0) + self.assertLess(sys.hash_info.cutoff, 8) + + self.assertIsInstance(sys.maxsize, int) + self.assertIsInstance(sys.maxunicode, int) + self.assertEqual(sys.maxunicode, 0x10FFFF) + self.assertIsInstance(sys.platform, str) + self.assertIsInstance(sys.prefix, str) + self.assertIsInstance(sys.base_prefix, str) + self.assertIsInstance(sys.platlibdir, str) + self.assertIsInstance(sys.version, str) + vi = sys.version_info + self.assertIsInstance(vi[:], tuple) + self.assertEqual(len(vi), 5) + self.assertIsInstance(vi[0], int) + self.assertIsInstance(vi[1], int) + self.assertIsInstance(vi[2], int) + self.assertIn(vi[3], ("alpha", "beta", "candidate", "final")) + self.assertIsInstance(vi[4], int) + self.assertIsInstance(vi.major, int) + self.assertIsInstance(vi.minor, int) + self.assertIsInstance(vi.micro, int) + self.assertIn(vi.releaselevel, ("alpha", "beta", "candidate", "final")) + self.assertIsInstance(vi.serial, int) + self.assertEqual(vi[0], vi.major) + self.assertEqual(vi[1], vi.minor) + self.assertEqual(vi[2], vi.micro) + self.assertEqual(vi[3], vi.releaselevel) + self.assertEqual(vi[4], vi.serial) + self.assertTrue(vi > (1,0,0)) + self.assertIsInstance(sys.float_repr_style, str) + self.assertIn(sys.float_repr_style, ('short', 'legacy')) + if not sys.platform.startswith('win'): + self.assertIsInstance(sys.abiflags, str) + + def test_thread_info(self): + info = sys.thread_info + self.assertEqual(len(info), 3) + self.assertIn(info.name, ('nt', 'pthread', 'pthread-stubs', 'solaris', None)) + self.assertIn(info.lock, ('semaphore', 'mutex+cond', None)) + if sys.platform.startswith(("linux", "android", "freebsd")): + self.assertEqual(info.name, "pthread") + elif sys.platform == "win32": + self.assertEqual(info.name, "nt") + elif sys.platform == "emscripten": + self.assertIn(info.name, {"pthread", "pthread-stubs"}) + elif sys.platform == "wasi": + self.assertEqual(info.name, "pthread-stubs") + + @unittest.skipUnless(support.is_emscripten, "only available on Emscripten") + def test_emscripten_info(self): + self.assertEqual(len(sys._emscripten_info), 4) + self.assertIsInstance(sys._emscripten_info.emscripten_version, tuple) + self.assertIsInstance(sys._emscripten_info.runtime, (str, type(None))) + self.assertIsInstance(sys._emscripten_info.pthreads, bool) + self.assertIsInstance(sys._emscripten_info.shared_memory, bool) + + def test_43581(self): + # Can't use sys.stdout, as this is a StringIO object when + # the test runs under regrtest. + self.assertEqual(sys.__stdout__.encoding, sys.__stderr__.encoding) + + def test_intern(self): + has_is_interned = (test.support.check_impl_detail(cpython=True) + or hasattr(sys, '_is_interned')) + self.assertRaises(TypeError, sys.intern) + self.assertRaises(TypeError, sys.intern, b'abc') + if has_is_interned: + self.assertRaises(TypeError, sys._is_interned) + self.assertRaises(TypeError, sys._is_interned, b'abc') + s = "never interned before" + str(random.randrange(0, 10**9)) + self.assertTrue(sys.intern(s) is s) + if has_is_interned: + self.assertIs(sys._is_interned(s), True) + s2 = s.swapcase().swapcase() + if has_is_interned: + self.assertIs(sys._is_interned(s2), False) + self.assertTrue(sys.intern(s2) is s) + if has_is_interned: + self.assertIs(sys._is_interned(s2), False) + + # Subclasses of string can't be interned, because they + # provide too much opportunity for insane things to happen. + # We don't want them in the interned dict and if they aren't + # actually interned, we don't want to create the appearance + # that they are by allowing intern() to succeed. + class S(str): + def __hash__(self): + return 123 + + self.assertRaises(TypeError, sys.intern, S("abc")) + if has_is_interned: + self.assertIs(sys._is_interned(S("abc")), False) + + @support.cpython_only + @requires_subinterpreters + def test_subinterp_intern_dynamically_allocated(self): + # Implementation detail: Dynamically allocated strings + # are distinct between interpreters + s = "never interned before" + str(random.randrange(0, 10**9)) + t = sys.intern(s) + self.assertIs(t, s) + + interp = interpreters.create() + interp.exec(textwrap.dedent(f''' + import sys + + # set `s`, avoid parser interning & constant folding + s = str({s.encode()!r}, 'utf-8') + + t = sys.intern(s) + + assert id(t) != {id(s)}, (id(t), {id(s)}) + assert id(t) != {id(t)}, (id(t), {id(t)}) + ''')) + + @support.cpython_only + @requires_subinterpreters + def test_subinterp_intern_statically_allocated(self): + # Implementation detail: Statically allocated strings are shared + # between interpreters. + # See Tools/build/generate_global_objects.py for the list + # of strings that are always statically allocated. + for s in ('__init__', 'CANCELLED', '', 'utf-8', + '{{', '', '\n', '_', 'x', '\0', '\N{CEDILLA}', '\xff', + ): + with self.subTest(s=s): + t = sys.intern(s) + + interp = interpreters.create() + interp.exec(textwrap.dedent(f''' + import sys + + # set `s`, avoid parser interning & constant folding + s = str({s.encode()!r}, 'utf-8') + + t = sys.intern(s) + assert id(t) == {id(t)}, (id(t), {id(t)}) + ''')) + + @support.cpython_only + @requires_subinterpreters + def test_subinterp_intern_singleton(self): + # Implementation detail: singletons are used for 0- and 1-character + # latin1 strings. + for s in '', '\n', '_', 'x', '\0', '\N{CEDILLA}', '\xff': + with self.subTest(s=s): + interp = interpreters.create() + interp.exec(textwrap.dedent(f''' + import sys + + # set `s`, avoid parser interning & constant folding + s = str({s.encode()!r}, 'utf-8') + + assert id(s) == {id(s)} + t = sys.intern(s) + ''')) + self.assertTrue(sys._is_interned(s)) + + def test_sys_flags(self): + self.assertTrue(sys.flags) + attrs = ("debug", + "inspect", "interactive", "optimize", + "dont_write_bytecode", "no_user_site", "no_site", + "ignore_environment", "verbose", "bytes_warning", "quiet", + "hash_randomization", "isolated", "dev_mode", "utf8_mode", + "warn_default_encoding", "safe_path", "int_max_str_digits") + for attr in attrs: + self.assertTrue(hasattr(sys.flags, attr), attr) + attr_type = bool if attr in ("dev_mode", "safe_path") else int + self.assertEqual(type(getattr(sys.flags, attr)), attr_type, attr) + self.assertTrue(repr(sys.flags)) + self.assertEqual(len(sys.flags), len(attrs)) + + self.assertIn(sys.flags.utf8_mode, {0, 1, 2}) + + def assert_raise_on_new_sys_type(self, sys_attr): + # Users are intentionally prevented from creating new instances of + # sys.flags, sys.version_info, and sys.getwindowsversion. + arg = sys_attr + attr_type = type(sys_attr) + with self.assertRaises(TypeError): + attr_type(arg) + with self.assertRaises(TypeError): + attr_type.__new__(attr_type, arg) + + def test_sys_flags_no_instantiation(self): + self.assert_raise_on_new_sys_type(sys.flags) + + def test_sys_version_info_no_instantiation(self): + self.assert_raise_on_new_sys_type(sys.version_info) + + def test_sys_getwindowsversion_no_instantiation(self): + # Skip if not being run on Windows. + test.support.get_attribute(sys, "getwindowsversion") + self.assert_raise_on_new_sys_type(sys.getwindowsversion()) + + @test.support.cpython_only + def test_clear_type_cache(self): + sys._clear_type_cache() + + @force_not_colorized + @support.requires_subprocess() + def test_ioencoding(self): + env = dict(os.environ) + + # Test character: cent sign, encoded as 0x4A (ASCII J) in CP424, + # not representable in ASCII. + + env["PYTHONIOENCODING"] = "cp424" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout = subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + expected = ("\xa2" + os.linesep).encode("cp424") + self.assertEqual(out, expected) + + env["PYTHONIOENCODING"] = "ascii:replace" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout = subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + self.assertEqual(out, b'?') + + env["PYTHONIOENCODING"] = "ascii" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + out, err = p.communicate() + self.assertEqual(out, b'') + self.assertIn(b'UnicodeEncodeError:', err) + self.assertIn(rb"'\xa2'", err) + + env["PYTHONIOENCODING"] = "ascii:" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + out, err = p.communicate() + self.assertEqual(out, b'') + self.assertIn(b'UnicodeEncodeError:', err) + self.assertIn(rb"'\xa2'", err) + + env["PYTHONIOENCODING"] = ":surrogateescape" + p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xdcbd))'], + stdout=subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + self.assertEqual(out, b'\xbd') + + @unittest.skipUnless(os_helper.FS_NONASCII, + 'requires OS support of non-ASCII encodings') + @unittest.skipUnless(sys.getfilesystemencoding() == locale.getpreferredencoding(False), + 'requires FS encoding to match locale') + @support.requires_subprocess() + def test_ioencoding_nonascii(self): + env = dict(os.environ) + + env["PYTHONIOENCODING"] = "" + p = subprocess.Popen([sys.executable, "-c", + 'print(%a)' % os_helper.FS_NONASCII], + stdout=subprocess.PIPE, env=env) + out = p.communicate()[0].strip() + self.assertEqual(out, os.fsencode(os_helper.FS_NONASCII)) + + @unittest.skipIf(sys.base_prefix != sys.prefix, + 'Test is not venv-compatible') + @support.requires_subprocess() + def test_executable(self): + # sys.executable should be absolute + self.assertEqual(os.path.abspath(sys.executable), sys.executable) + + # Issue #7774: Ensure that sys.executable is an empty string if argv[0] + # has been set to a non existent program name and Python is unable to + # retrieve the real program name + + # For a normal installation, it should work without 'cwd' + # argument. For test runs in the build directory, see #7774. + python_dir = os.path.dirname(os.path.realpath(sys.executable)) + p = subprocess.Popen( + ["nonexistent", "-c", + 'import sys; print(sys.executable.encode("ascii", "backslashreplace"))'], + executable=sys.executable, stdout=subprocess.PIPE, cwd=python_dir) + stdout = p.communicate()[0] + executable = stdout.strip().decode("ASCII") + p.wait() + self.assertIn(executable, ["b''", repr(sys.executable.encode("ascii", "backslashreplace"))]) + + def check_fsencoding(self, fs_encoding, expected=None): + self.assertIsNotNone(fs_encoding) + codecs.lookup(fs_encoding) + if expected: + self.assertEqual(fs_encoding, expected) + + def test_getfilesystemencoding(self): + fs_encoding = sys.getfilesystemencoding() + if sys.platform == 'darwin': + expected = 'utf-8' + else: + expected = None + self.check_fsencoding(fs_encoding, expected) + + def c_locale_get_error_handler(self, locale, isolated=False, encoding=None): + # Force the POSIX locale + env = os.environ.copy() + env["LC_ALL"] = locale + env["PYTHONCOERCECLOCALE"] = "0" + code = '\n'.join(( + 'import sys', + 'def dump(name):', + ' std = getattr(sys, name)', + ' print("%s: %s" % (name, std.errors))', + 'dump("stdin")', + 'dump("stdout")', + 'dump("stderr")', + )) + args = [sys.executable, "-X", "utf8=0", "-c", code] + if isolated: + args.append("-I") + if encoding is not None: + env['PYTHONIOENCODING'] = encoding + else: + env.pop('PYTHONIOENCODING', None) + p = subprocess.Popen(args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + universal_newlines=True) + stdout, stderr = p.communicate() + return stdout + + def check_locale_surrogateescape(self, locale): + out = self.c_locale_get_error_handler(locale, isolated=True) + self.assertEqual(out, + 'stdin: surrogateescape\n' + 'stdout: surrogateescape\n' + 'stderr: backslashreplace\n') + + # replace the default error handler + out = self.c_locale_get_error_handler(locale, encoding=':ignore') + self.assertEqual(out, + 'stdin: ignore\n' + 'stdout: ignore\n' + 'stderr: backslashreplace\n') + + # force the encoding + out = self.c_locale_get_error_handler(locale, encoding='iso8859-1') + self.assertEqual(out, + 'stdin: strict\n' + 'stdout: strict\n' + 'stderr: backslashreplace\n') + out = self.c_locale_get_error_handler(locale, encoding='iso8859-1:') + self.assertEqual(out, + 'stdin: strict\n' + 'stdout: strict\n' + 'stderr: backslashreplace\n') + + # have no any effect + out = self.c_locale_get_error_handler(locale, encoding=':') + self.assertEqual(out, + 'stdin: surrogateescape\n' + 'stdout: surrogateescape\n' + 'stderr: backslashreplace\n') + out = self.c_locale_get_error_handler(locale, encoding='') + self.assertEqual(out, + 'stdin: surrogateescape\n' + 'stdout: surrogateescape\n' + 'stderr: backslashreplace\n') + + @support.requires_subprocess() + def test_c_locale_surrogateescape(self): + self.check_locale_surrogateescape('C') + + @support.requires_subprocess() + def test_posix_locale_surrogateescape(self): + self.check_locale_surrogateescape('POSIX') + + def test_implementation(self): + # This test applies to all implementations equally. + + levels = {'alpha': 0xA, 'beta': 0xB, 'candidate': 0xC, 'final': 0xF} + + self.assertTrue(hasattr(sys.implementation, 'name')) + self.assertTrue(hasattr(sys.implementation, 'version')) + self.assertTrue(hasattr(sys.implementation, 'hexversion')) + self.assertTrue(hasattr(sys.implementation, 'cache_tag')) + + version = sys.implementation.version + self.assertEqual(version[:2], (version.major, version.minor)) + + hexversion = (version.major << 24 | version.minor << 16 | + version.micro << 8 | levels[version.releaselevel] << 4 | + version.serial << 0) + self.assertEqual(sys.implementation.hexversion, hexversion) + + # PEP 421 requires that .name be lower case. + self.assertEqual(sys.implementation.name, + sys.implementation.name.lower()) + + @test.support.cpython_only + def test_debugmallocstats(self): + # Test sys._debugmallocstats() + from test.support.script_helper import assert_python_ok + args = ['-c', 'import sys; sys._debugmallocstats()'] + ret, out, err = assert_python_ok(*args) + + # Output of sys._debugmallocstats() depends on configure flags. + # The sysconfig vars are not available on Windows. + if sys.platform != "win32": + with_freelists = sysconfig.get_config_var("WITH_FREELISTS") + with_pymalloc = sysconfig.get_config_var("WITH_PYMALLOC") + if with_freelists: + self.assertIn(b"free PyDictObjects", err) + if with_pymalloc: + self.assertIn(b'Small block threshold', err) + if not with_freelists and not with_pymalloc: + self.assertFalse(err) + + # The function has no parameter + self.assertRaises(TypeError, sys._debugmallocstats, True) + + @unittest.skipUnless(hasattr(sys, "getallocatedblocks"), + "sys.getallocatedblocks unavailable on this build") + def test_getallocatedblocks(self): + try: + import _testinternalcapi + except ImportError: + with_pymalloc = support.with_pymalloc() + else: + try: + alloc_name = _testinternalcapi.pymem_getallocatorsname() + except RuntimeError as exc: + # "cannot get allocators name" (ex: tracemalloc is used) + with_pymalloc = True + else: + with_pymalloc = (alloc_name in ('pymalloc', 'pymalloc_debug')) + + # Some sanity checks + a = sys.getallocatedblocks() + self.assertIs(type(a), int) + if with_pymalloc: + self.assertGreater(a, 0) + else: + # When WITH_PYMALLOC isn't available, we don't know anything + # about the underlying implementation: the function might + # return 0 or something greater. + self.assertGreaterEqual(a, 0) + try: + # While we could imagine a Python session where the number of + # multiple buffer objects would exceed the sharing of references, + # it is unlikely to happen in a normal test run. + self.assertLess(a, sys.gettotalrefcount()) + except AttributeError: + # gettotalrefcount() not available + pass + gc.collect() + b = sys.getallocatedblocks() + self.assertLessEqual(b, a) + gc.collect() + c = sys.getallocatedblocks() + self.assertIn(c, range(b - 50, b + 50)) + + def test_is_gil_enabled(self): + if support.Py_GIL_DISABLED: + self.assertIs(type(sys._is_gil_enabled()), bool) + else: + self.assertTrue(sys._is_gil_enabled()) + + def test_is_finalizing(self): + self.assertIs(sys.is_finalizing(), False) + # Don't use the atexit module because _Py_Finalizing is only set + # after calling atexit callbacks + code = """if 1: + import sys + + class AtExit: + is_finalizing = sys.is_finalizing + print = print + + def __del__(self): + self.print(self.is_finalizing(), flush=True) + + # Keep a reference in the __main__ module namespace, so the + # AtExit destructor will be called at Python exit + ref = AtExit() + """ + rc, stdout, stderr = assert_python_ok('-c', code) + self.assertEqual(stdout.rstrip(), b'True') + + def test_issue20602(self): + # sys.flags and sys.float_info were wiped during shutdown. + code = """if 1: + import sys + class A: + def __del__(self, sys=sys): + print(sys.flags) + print(sys.float_info) + a = A() + """ + rc, out, err = assert_python_ok('-c', code) + out = out.splitlines() + self.assertIn(b'sys.flags', out[0]) + self.assertIn(b'sys.float_info', out[1]) + + def test_sys_ignores_cleaning_up_user_data(self): + code = """if 1: + import struct, sys + + class C: + def __init__(self): + self.pack = struct.pack + def __del__(self): + self.pack('I', -42) + + sys.x = C() + """ + rc, stdout, stderr = assert_python_ok('-c', code) + self.assertEqual(rc, 0) + self.assertEqual(stdout.rstrip(), b"") + self.assertEqual(stderr.rstrip(), b"") + + @unittest.skipUnless(sys.platform == "android", "Android only") + def test_getandroidapilevel(self): + level = sys.getandroidapilevel() + self.assertIsInstance(level, int) + self.assertGreater(level, 0) + + @force_not_colorized + @support.requires_subprocess() + def test_sys_tracebacklimit(self): + code = """if 1: + import sys + def f1(): + 1 / 0 + def f2(): + f1() + sys.tracebacklimit = %r + f2() + """ + def check(tracebacklimit, expected): + p = subprocess.Popen([sys.executable, '-c', code % tracebacklimit], + stderr=subprocess.PIPE) + out = p.communicate()[1] + self.assertEqual(out.splitlines(), expected) + + traceback = [ + b'Traceback (most recent call last):', + b' File "", line 8, in ', + b' f2()', + b' ~~^^', + b' File "", line 6, in f2', + b' f1()', + b' ~~^^', + b' File "", line 4, in f1', + b' 1 / 0', + b' ~~^~~', + b'ZeroDivisionError: division by zero' + ] + check(10, traceback) + check(3, traceback) + check(2, traceback[:1] + traceback[4:]) + check(1, traceback[:1] + traceback[7:]) + check(0, [traceback[-1]]) + check(-1, [traceback[-1]]) + check(1<<1000, traceback) + check(-1<<1000, [traceback[-1]]) + check(None, traceback) + + def test_no_duplicates_in_meta_path(self): + self.assertEqual(len(sys.meta_path), len(set(sys.meta_path))) + + @unittest.skipUnless(hasattr(sys, "_enablelegacywindowsfsencoding"), + 'needs sys._enablelegacywindowsfsencoding()') + def test__enablelegacywindowsfsencoding(self): + code = ('import sys', + 'sys._enablelegacywindowsfsencoding()', + 'print(sys.getfilesystemencoding(), sys.getfilesystemencodeerrors())') + rc, out, err = assert_python_ok('-c', '; '.join(code)) + out = out.decode('ascii', 'replace').rstrip() + self.assertEqual(out, 'mbcs replace') + + @support.requires_subprocess() + def test_orig_argv(self): + code = textwrap.dedent(''' + import sys + print(sys.argv) + print(sys.orig_argv) + ''') + args = [sys.executable, '-I', '-X', 'utf8', '-c', code, 'arg'] + proc = subprocess.run(args, check=True, capture_output=True, text=True) + expected = [ + repr(['-c', 'arg']), # sys.argv + repr(args), # sys.orig_argv + ] + self.assertEqual(proc.stdout.rstrip().splitlines(), expected, + proc) + + def test_module_names(self): + self.assertIsInstance(sys.stdlib_module_names, frozenset) + for name in sys.stdlib_module_names: + self.assertIsInstance(name, str) + + def test_stdlib_dir(self): + os = import_helper.import_fresh_module('os') + marker = getattr(os, '__file__', None) + if marker and not os.path.exists(marker): + marker = None + expected = os.path.dirname(marker) if marker else None + self.assertEqual(os.path.normpath(sys._stdlib_dir), + os.path.normpath(expected)) + + @unittest.skipUnless(hasattr(sys, 'getobjects'), 'need sys.getobjects()') + def test_getobjects(self): + # sys.getobjects(0) + all_objects = sys.getobjects(0) + self.assertIsInstance(all_objects, list) + self.assertGreater(len(all_objects), 0) + + # sys.getobjects(0, MyType) + class MyType: + pass + size = 100 + my_objects = [MyType() for _ in range(size)] + get_objects = sys.getobjects(0, MyType) + self.assertEqual(len(get_objects), size) + for obj in get_objects: + self.assertIsInstance(obj, MyType) + + # sys.getobjects(3, MyType) + get_objects = sys.getobjects(3, MyType) + self.assertEqual(len(get_objects), 3) + + @unittest.skipUnless(hasattr(sys, '_stats_on'), 'need Py_STATS build') + def test_pystats(self): + # Call the functions, just check that they don't crash + # Cannot save/restore state. + sys._stats_on() + sys._stats_off() + sys._stats_clear() + sys._stats_dump() + + @test.support.cpython_only + @unittest.skipUnless(hasattr(sys, 'abiflags'), 'need sys.abiflags') + def test_disable_gil_abi(self): + self.assertEqual('t' in sys.abiflags, support.Py_GIL_DISABLED) + + +@test.support.cpython_only +class UnraisableHookTest(__TestCase): + def test_original_unraisablehook(self): + _testcapi = import_helper.import_module('_testcapi') + from _testcapi import err_writeunraisable, err_formatunraisable + obj = hex + + with support.swap_attr(sys, 'unraisablehook', + sys.__unraisablehook__): + with support.captured_stderr() as stderr: + err_writeunraisable(ValueError(42), obj) + lines = stderr.getvalue().splitlines() + self.assertEqual(lines[0], f'Exception ignored in: {obj!r}') + self.assertEqual(lines[1], 'Traceback (most recent call last):') + self.assertEqual(lines[-1], 'ValueError: 42') + + with support.captured_stderr() as stderr: + err_writeunraisable(ValueError(42), None) + lines = stderr.getvalue().splitlines() + self.assertEqual(lines[0], 'Traceback (most recent call last):') + self.assertEqual(lines[-1], 'ValueError: 42') + + with support.captured_stderr() as stderr: + err_formatunraisable(ValueError(42), 'Error in %R', obj) + lines = stderr.getvalue().splitlines() + self.assertEqual(lines[0], f'Error in {obj!r}:') + self.assertEqual(lines[1], 'Traceback (most recent call last):') + self.assertEqual(lines[-1], 'ValueError: 42') + + with support.captured_stderr() as stderr: + err_formatunraisable(ValueError(42), None) + lines = stderr.getvalue().splitlines() + self.assertEqual(lines[0], 'Traceback (most recent call last):') + self.assertEqual(lines[-1], 'ValueError: 42') + + def test_original_unraisablehook_err(self): + # bpo-22836: PyErr_WriteUnraisable() should give sensible reports + class BrokenDel: + def __del__(self): + exc = ValueError("del is broken") + # The following line is included in the traceback report: + raise exc + + class BrokenStrException(Exception): + def __str__(self): + raise Exception("str() is broken") + + class BrokenExceptionDel: + def __del__(self): + exc = BrokenStrException() + # The following line is included in the traceback report: + raise exc + + for test_class in (BrokenDel, BrokenExceptionDel): + with self.subTest(test_class): + obj = test_class() + with test.support.captured_stderr() as stderr, \ + test.support.swap_attr(sys, 'unraisablehook', + sys.__unraisablehook__): + # Trigger obj.__del__() + del obj + + report = stderr.getvalue() + self.assertIn("Exception ignored", report) + self.assertIn(test_class.__del__.__qualname__, report) + self.assertIn("test_sys.py", report) + self.assertIn("raise exc", report) + if test_class is BrokenExceptionDel: + self.assertIn("BrokenStrException", report) + self.assertIn("", report) + else: + self.assertIn("ValueError", report) + self.assertIn("del is broken", report) + self.assertTrue(report.endswith("\n")) + + def test_original_unraisablehook_exception_qualname(self): + # See bpo-41031, bpo-45083. + # Check that the exception is printed with its qualified name + # rather than just classname, and the module names appears + # unless it is one of the hard-coded exclusions. + _testcapi = import_helper.import_module('_testcapi') + from _testcapi import err_writeunraisable + class A: + class B: + class X(Exception): + pass + + for moduleName in 'builtins', '__main__', 'some_module': + with self.subTest(moduleName=moduleName): + A.B.X.__module__ = moduleName + with test.support.captured_stderr() as stderr, test.support.swap_attr( + sys, 'unraisablehook', sys.__unraisablehook__ + ): + err_writeunraisable(A.B.X(), "obj") + report = stderr.getvalue() + self.assertIn(A.B.X.__qualname__, report) + if moduleName in ['builtins', '__main__']: + self.assertNotIn(moduleName + '.', report) + else: + self.assertIn(moduleName + '.', report) + + def test_original_unraisablehook_wrong_type(self): + exc = ValueError(42) + with test.support.swap_attr(sys, 'unraisablehook', + sys.__unraisablehook__): + with self.assertRaises(TypeError): + sys.unraisablehook(exc) + + def test_custom_unraisablehook(self): + _testcapi = import_helper.import_module('_testcapi') + from _testcapi import err_writeunraisable, err_formatunraisable + hook_args = None + + def hook_func(args): + nonlocal hook_args + hook_args = args + + obj = hex + try: + with test.support.swap_attr(sys, 'unraisablehook', hook_func): + exc = ValueError(42) + err_writeunraisable(exc, obj) + self.assertIs(hook_args.exc_type, type(exc)) + self.assertIs(hook_args.exc_value, exc) + self.assertIs(hook_args.exc_traceback, exc.__traceback__) + self.assertIsNone(hook_args.err_msg) + self.assertEqual(hook_args.object, obj) + + err_formatunraisable(exc, "custom hook %R", obj) + self.assertIs(hook_args.exc_type, type(exc)) + self.assertIs(hook_args.exc_value, exc) + self.assertIs(hook_args.exc_traceback, exc.__traceback__) + self.assertEqual(hook_args.err_msg, f'custom hook {obj!r}') + self.assertIsNone(hook_args.object) + finally: + # expected and hook_args contain an exception: break reference cycle + expected = None + hook_args = None + + def test_custom_unraisablehook_fail(self): + _testcapi = import_helper.import_module('_testcapi') + from _testcapi import err_writeunraisable + def hook_func(*args): + raise Exception("hook_func failed") + + with test.support.captured_output("stderr") as stderr: + with test.support.swap_attr(sys, 'unraisablehook', hook_func): + err_writeunraisable(ValueError(42), "custom hook fail") + + err = stderr.getvalue() + self.assertIn(f'Exception ignored in sys.unraisablehook: ' + f'{hook_func!r}\n', + err) + self.assertIn('Traceback (most recent call last):\n', err) + self.assertIn('Exception: hook_func failed\n', err) + + +@test.support.cpython_only +class SizeofTest(__TestCase): + + def setUp(self): + self.P = struct.calcsize('P') + self.longdigit = sys.int_info.sizeof_digit + _testinternalcapi = import_helper.import_module("_testinternalcapi") + self.gc_headsize = _testinternalcapi.SIZEOF_PYGC_HEAD + self.managed_pre_header_size = _testinternalcapi.SIZEOF_MANAGED_PRE_HEADER + super().setUp() + + check_sizeof = test.support.check_sizeof + + def test_gc_head_size(self): + # Check that the gc header size is added to objects tracked by the gc. + vsize = test.support.calcvobjsize + gc_header_size = self.gc_headsize + # bool objects are not gc tracked + self.assertEqual(sys.getsizeof(True), vsize('') + self.longdigit) + # but lists are + self.assertEqual(sys.getsizeof([]), vsize('Pn') + gc_header_size) + + def test_errors(self): + class BadSizeof: + def __sizeof__(self): + raise ValueError + self.assertRaises(ValueError, sys.getsizeof, BadSizeof()) + + class InvalidSizeof: + def __sizeof__(self): + return None + self.assertRaises(TypeError, sys.getsizeof, InvalidSizeof()) + sentinel = ["sentinel"] + self.assertIs(sys.getsizeof(InvalidSizeof(), sentinel), sentinel) + + class FloatSizeof: + def __sizeof__(self): + return 4.5 + self.assertRaises(TypeError, sys.getsizeof, FloatSizeof()) + self.assertIs(sys.getsizeof(FloatSizeof(), sentinel), sentinel) + + class OverflowSizeof(int): + def __sizeof__(self): + return int(self) + self.assertEqual(sys.getsizeof(OverflowSizeof(sys.maxsize)), + sys.maxsize + self.gc_headsize + self.managed_pre_header_size) + with self.assertRaises(OverflowError): + sys.getsizeof(OverflowSizeof(sys.maxsize + 1)) + with self.assertRaises(ValueError): + sys.getsizeof(OverflowSizeof(-1)) + with self.assertRaises((ValueError, OverflowError)): + sys.getsizeof(OverflowSizeof(-sys.maxsize - 1)) + + def test_default(self): + size = test.support.calcvobjsize + self.assertEqual(sys.getsizeof(True), size('') + self.longdigit) + self.assertEqual(sys.getsizeof(True, -1), size('') + self.longdigit) + + def test_objecttypes(self): + # check all types defined in Objects/ + calcsize = struct.calcsize + size = test.support.calcobjsize + vsize = test.support.calcvobjsize + check = self.check_sizeof + # bool + check(True, vsize('') + self.longdigit) + check(False, vsize('') + self.longdigit) + # buffer + # XXX + # builtin_function_or_method + check(len, size('5P')) + # bytearray + samples = [b'', b'u'*100000] + for sample in samples: + x = bytearray(sample) + check(x, vsize('n2Pi') + x.__alloc__()) + # bytearray_iterator + check(iter(bytearray()), size('nP')) + # bytes + check(b'', vsize('n') + 1) + check(b'x' * 10, vsize('n') + 11) + # cell + def get_cell(): + x = 42 + def inner(): + return x + return inner + check(get_cell().__closure__[0], size('P')) + # code + def check_code_size(a, expected_size): + self.assertGreaterEqual(sys.getsizeof(a), expected_size) + check_code_size(get_cell().__code__, size('6i13P')) + check_code_size(get_cell.__code__, size('6i13P')) + def get_cell2(x): + def inner(): + return x + return inner + check_code_size(get_cell2.__code__, size('6i13P') + calcsize('n')) + # complex + check(complex(0,1), size('2d')) + # method_descriptor (descriptor object) + check(str.lower, size('3PPP')) + # classmethod_descriptor (descriptor object) + # XXX + # member_descriptor (descriptor object) + import datetime + check(datetime.timedelta.days, size('3PP')) + # getset_descriptor (descriptor object) + import collections + check(collections.defaultdict.default_factory, size('3PP')) + # wrapper_descriptor (descriptor object) + check(int.__add__, size('3P2P')) + # method-wrapper (descriptor object) + check({}.__iter__, size('2P')) + # empty dict + check({}, size('nQ2P')) + # dict (string key) + check({"a": 1}, size('nQ2P') + calcsize(DICT_KEY_STRUCT_FORMAT) + 8 + (8*2//3)*calcsize('2P')) + longdict = {str(i): i for i in range(8)} + check(longdict, size('nQ2P') + calcsize(DICT_KEY_STRUCT_FORMAT) + 16 + (16*2//3)*calcsize('2P')) + # dict (non-string key) + check({1: 1}, size('nQ2P') + calcsize(DICT_KEY_STRUCT_FORMAT) + 8 + (8*2//3)*calcsize('n2P')) + longdict = {1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8} + check(longdict, size('nQ2P') + calcsize(DICT_KEY_STRUCT_FORMAT) + 16 + (16*2//3)*calcsize('n2P')) + # dictionary-keyview + check({}.keys(), size('P')) + # dictionary-valueview + check({}.values(), size('P')) + # dictionary-itemview + check({}.items(), size('P')) + # dictionary iterator + check(iter({}), size('P2nPn')) + # dictionary-keyiterator + check(iter({}.keys()), size('P2nPn')) + # dictionary-valueiterator + check(iter({}.values()), size('P2nPn')) + # dictionary-itemiterator + check(iter({}.items()), size('P2nPn')) + # dictproxy + class C(object): pass + check(C.__dict__, size('P')) + # BaseException + check(BaseException(), size('6Pb')) + # UnicodeEncodeError + check(UnicodeEncodeError("", "", 0, 0, ""), size('6Pb 2P2nP')) + # UnicodeDecodeError + check(UnicodeDecodeError("", b"", 0, 0, ""), size('6Pb 2P2nP')) + # UnicodeTranslateError + check(UnicodeTranslateError("", 0, 1, ""), size('6Pb 2P2nP')) + # ellipses + check(Ellipsis, size('')) + # EncodingMap + import codecs, encodings.iso8859_3 + x = codecs.charmap_build(encodings.iso8859_3.decoding_table) + check(x, size('32B2iB')) + # enumerate + check(enumerate([]), size('n4P')) + # reverse + check(reversed(''), size('nP')) + # float + check(float(0), size('d')) + # sys.floatinfo + check(sys.float_info, vsize('') + self.P * len(sys.float_info)) + # frame + def func(): + return sys._getframe() + x = func() + check(x, size('3Pi2c2P7P2ic??2P')) + # function + def func(): pass + check(func, size('15Pi')) + class c(): + @staticmethod + def foo(): + pass + @classmethod + def bar(cls): + pass + # staticmethod + check(foo, size('PP')) + # classmethod + check(bar, size('PP')) + # generator + def get_gen(): yield 1 + check(get_gen(), size('PP4P4c7P2ic??2P')) + # iterator + check(iter('abc'), size('lP')) + # callable-iterator + import re + check(re.finditer('',''), size('2P')) + # list + check(list([]), vsize('Pn')) + check(list([1]), vsize('Pn') + 2*self.P) + check(list([1, 2]), vsize('Pn') + 2*self.P) + check(list([1, 2, 3]), vsize('Pn') + 4*self.P) + # sortwrapper (list) + # XXX + # cmpwrapper (list) + # XXX + # listiterator (list) + check(iter([]), size('lP')) + # listreverseiterator (list) + check(reversed([]), size('nP')) + # int + check(0, vsize('') + self.longdigit) + check(1, vsize('') + self.longdigit) + check(-1, vsize('') + self.longdigit) + PyLong_BASE = 2**sys.int_info.bits_per_digit + check(int(PyLong_BASE), vsize('') + 2*self.longdigit) + check(int(PyLong_BASE**2-1), vsize('') + 2*self.longdigit) + check(int(PyLong_BASE**2), vsize('') + 3*self.longdigit) + # module + if support.Py_GIL_DISABLED: + check(unittest, size('PPPPPP')) + else: + check(unittest, size('PPPPP')) + # None + check(None, size('')) + # NotImplementedType + check(NotImplemented, size('')) + # object + check(object(), size('')) + # property (descriptor object) + class C(object): + def getx(self): return self.__x + def setx(self, value): self.__x = value + def delx(self): del self.__x + x = property(getx, setx, delx, "") + check(x, size('5Pi')) + # PyCapsule + check(_datetime.datetime_CAPI, size('6P')) + # rangeiterator + check(iter(range(1)), size('3l')) + check(iter(range(2**65)), size('3P')) + # reverse + check(reversed(''), size('nP')) + # range + check(range(1), size('4P')) + check(range(66000), size('4P')) + # set + # frozenset + PySet_MINSIZE = 8 + samples = [[], range(10), range(50)] + s = size('3nP' + PySet_MINSIZE*'nP' + '2nP') + for sample in samples: + minused = len(sample) + if minused == 0: tmp = 1 + # the computation of minused is actually a bit more complicated + # but this suffices for the sizeof test + minused = minused*2 + newsize = PySet_MINSIZE + while newsize <= minused: + newsize = newsize << 1 + if newsize <= 8: + check(set(sample), s) + check(frozenset(sample), s) + else: + check(set(sample), s + newsize*calcsize('nP')) + check(frozenset(sample), s + newsize*calcsize('nP')) + # setiterator + check(iter(set()), size('P3n')) + # slice + check(slice(0), size('3P')) + # super + check(super(int), size('3P')) + # tuple + check((), vsize('')) + check((1,2,3), vsize('') + 3*self.P) + # type + # static type: PyTypeObject + fmt = 'P2nPI13Pl4Pn9Pn12PIPc' + s = vsize(fmt) + check(int, s) + # class + s = vsize(fmt + # PyTypeObject + '4P' # PyAsyncMethods + '36P' # PyNumberMethods + '3P' # PyMappingMethods + '10P' # PySequenceMethods + '2P' # PyBufferProcs + '6P' + '1PIP' # Specializer cache + ) + class newstyleclass(object): pass + # Separate block for PyDictKeysObject with 8 keys and 5 entries + check(newstyleclass, s + calcsize(DICT_KEY_STRUCT_FORMAT) + 64 + 42*calcsize("2P")) + # dict with shared keys + [newstyleclass() for _ in range(100)] + check(newstyleclass().__dict__, size('nQ2P') + self.P) + o = newstyleclass() + o.a = o.b = o.c = o.d = o.e = o.f = o.g = o.h = 1 + # Separate block for PyDictKeysObject with 16 keys and 10 entries + check(newstyleclass, s + calcsize(DICT_KEY_STRUCT_FORMAT) + 64 + 42*calcsize("2P")) + # dict with shared keys + check(newstyleclass().__dict__, size('nQ2P') + self.P) + # unicode + # each tuple contains a string and its expected character size + # don't put any static strings here, as they may contain + # wchar_t or UTF-8 representations + samples = ['1'*100, '\xff'*50, + '\u0100'*40, '\uffff'*100, + '\U00010000'*30, '\U0010ffff'*100] + # also update field definitions in test_unicode.test_raiseMemError + asciifields = "nnb" + compactfields = asciifields + "nP" + unicodefields = compactfields + "P" + for s in samples: + maxchar = ord(max(s)) + if maxchar < 128: + L = size(asciifields) + len(s) + 1 + elif maxchar < 256: + L = size(compactfields) + len(s) + 1 + elif maxchar < 65536: + L = size(compactfields) + 2*(len(s) + 1) + else: + L = size(compactfields) + 4*(len(s) + 1) + check(s, L) + # verify that the UTF-8 size is accounted for + s = chr(0x4000) # 4 bytes canonical representation + check(s, size(compactfields) + 4) + # compile() will trigger the generation of the UTF-8 + # representation as a side effect + compile(s, "", "eval") + check(s, size(compactfields) + 4 + 4) + # TODO: add check that forces the presence of wchar_t representation + # TODO: add check that forces layout of unicodefields + # weakref + import weakref + if support.Py_GIL_DISABLED: + expected = size('2Pn4P') + else: + expected = size('2Pn3P') + check(weakref.ref(int), expected) + # weakproxy + # XXX + # weakcallableproxy + check(weakref.proxy(int), expected) + + def check_slots(self, obj, base, extra): + expected = sys.getsizeof(base) + struct.calcsize(extra) + if gc.is_tracked(obj) and not gc.is_tracked(base): + expected += self.gc_headsize + self.assertEqual(sys.getsizeof(obj), expected) + + def test_slots(self): + # check all subclassable types defined in Objects/ that allow + # non-empty __slots__ + check = self.check_slots + class BA(bytearray): + __slots__ = 'a', 'b', 'c' + check(BA(), bytearray(), '3P') + class D(dict): + __slots__ = 'a', 'b', 'c' + check(D(x=[]), {'x': []}, '3P') + class L(list): + __slots__ = 'a', 'b', 'c' + check(L(), [], '3P') + class S(set): + __slots__ = 'a', 'b', 'c' + check(S(), set(), '3P') + class FS(frozenset): + __slots__ = 'a', 'b', 'c' + check(FS(), frozenset(), '3P') + from collections import OrderedDict + class OD(OrderedDict): + __slots__ = 'a', 'b', 'c' + check(OD(x=[]), OrderedDict(x=[]), '3P') + + def test_pythontypes(self): + # check all types defined in Python/ + size = test.support.calcobjsize + vsize = test.support.calcvobjsize + check = self.check_sizeof + # _ast.AST + import _ast + check(_ast.AST(), size('P')) + try: + raise TypeError + except TypeError as e: + tb = e.__traceback__ + # traceback + if tb is not None: + check(tb, size('2P2i')) + # symtable entry + # XXX + # sys.flags + # FIXME: The +1 will not be necessary once gh-122575 is fixed + check(sys.flags, vsize('') + self.P * (1 + len(sys.flags))) + + def test_asyncgen_hooks(self): + old = sys.get_asyncgen_hooks() + self.assertIsNone(old.firstiter) + self.assertIsNone(old.finalizer) + + firstiter = lambda *a: None + finalizer = lambda *a: None + + with self.assertRaises(TypeError): + sys.set_asyncgen_hooks(firstiter=firstiter, finalizer="invalid") + cur = sys.get_asyncgen_hooks() + self.assertIsNone(cur.firstiter) + self.assertIsNone(cur.finalizer) + + # gh-118473 + with self.assertRaises(TypeError): + sys.set_asyncgen_hooks(firstiter="invalid", finalizer=finalizer) + cur = sys.get_asyncgen_hooks() + self.assertIsNone(cur.firstiter) + self.assertIsNone(cur.finalizer) + + sys.set_asyncgen_hooks(firstiter=firstiter) + hooks = sys.get_asyncgen_hooks() + self.assertIs(hooks.firstiter, firstiter) + self.assertIs(hooks[0], firstiter) + self.assertIs(hooks.finalizer, None) + self.assertIs(hooks[1], None) + + sys.set_asyncgen_hooks(finalizer=finalizer) + hooks = sys.get_asyncgen_hooks() + self.assertIs(hooks.firstiter, firstiter) + self.assertIs(hooks[0], firstiter) + self.assertIs(hooks.finalizer, finalizer) + self.assertIs(hooks[1], finalizer) + + sys.set_asyncgen_hooks(*old) + cur = sys.get_asyncgen_hooks() + self.assertIsNone(cur.firstiter) + self.assertIsNone(cur.finalizer) + + def test_changing_sys_stderr_and_removing_reference(self): + # If the default displayhook doesn't take a strong reference + # to sys.stderr the following code can crash. See bpo-43660 + # for more details. + code = textwrap.dedent(''' + import sys + class MyStderr: + def write(self, s): + sys.stderr = None + sys.stderr = MyStderr() + 1/0 + ''') + rc, out, err = assert_python_failure('-c', code) + self.assertEqual(out, b"") + self.assertEqual(err, b"") + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_tuple.diff b/test/dynamo/cpython/3_13/test_tuple.diff new file mode 100644 index 00000000000000..46d4bb32d9efd3 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_tuple.diff @@ -0,0 +1,67 @@ +diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py +index 9ce80c5e8ea..e52c0cbc140 100644 +--- a/test/dynamo/cpython/3_13/test_tuple.py ++++ b/test/dynamo/cpython/3_13/test_tuple.py +@@ -1,4 +1,55 @@ +-from test import support, seq_tests ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ ++from test import support ++import seq_tests + import unittest + + import gc +@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest): + # pileup 262,143 mean 8.0 coll 262,143 z +92683.6 + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py new file mode 100644 index 00000000000000..e52c0cbc140307 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_tuple.py @@ -0,0 +1,564 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +from test import support +import seq_tests +import unittest + +import gc +import pickle + +# For tuple hashes, we normally only run a test to ensure that we get +# the same results across platforms in a handful of cases. If that's +# so, there's no real point to running more. Set RUN_ALL_HASH_TESTS to +# run more anyway. That's usually of real interest only when analyzing, +# or changing, the hash algorithm. In which case it's usually also +# most useful to set JUST_SHOW_HASH_RESULTS, to see all the results +# instead of wrestling with test "failures". See the bottom of the +# file for extensive notes on what we're testing here and why. +RUN_ALL_HASH_TESTS = False +JUST_SHOW_HASH_RESULTS = False # if RUN_ALL_HASH_TESTS, just display + +class TupleTest(seq_tests.CommonTest): + type2test = tuple + + def test_getitem_error(self): + t = () + msg = "tuple indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + t['a'] + + def test_constructors(self): + super().test_constructors() + # calling built-in types without argument must return empty + self.assertEqual(tuple(), ()) + t0_3 = (0, 1, 2, 3) + t0_3_bis = tuple(t0_3) + self.assertTrue(t0_3 is t0_3_bis) + self.assertEqual(tuple([]), ()) + self.assertEqual(tuple([0, 1, 2, 3]), (0, 1, 2, 3)) + self.assertEqual(tuple(''), ()) + self.assertEqual(tuple('spam'), ('s', 'p', 'a', 'm')) + self.assertEqual(tuple(x for x in range(10) if x % 2), + (1, 3, 5, 7, 9)) + + def test_keyword_args(self): + with self.assertRaisesRegex(TypeError, 'keyword argument'): + tuple(sequence=()) + + def test_keywords_in_subclass(self): + class subclass(tuple): + pass + u = subclass([1, 2]) + self.assertIs(type(u), subclass) + self.assertEqual(list(u), [1, 2]) + with self.assertRaises(TypeError): + subclass(sequence=()) + + class subclass_with_init(tuple): + def __init__(self, arg, newarg=None): + self.newarg = newarg + u = subclass_with_init([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_init) + self.assertEqual(list(u), [1, 2]) + self.assertEqual(u.newarg, 3) + + class subclass_with_new(tuple): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self + u = subclass_with_new([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_new) + self.assertEqual(list(u), [1, 2]) + self.assertEqual(u.newarg, 3) + + def test_truth(self): + super().test_truth() + self.assertTrue(not ()) + self.assertTrue((42, )) + + def test_len(self): + super().test_len() + self.assertEqual(len(()), 0) + self.assertEqual(len((0,)), 1) + self.assertEqual(len((0, 1, 2)), 3) + + def test_iadd(self): + super().test_iadd() + u = (0, 1) + u2 = u + u += (2, 3) + self.assertTrue(u is not u2) + + def test_imul(self): + super().test_imul() + u = (0, 1) + u2 = u + u *= 3 + self.assertTrue(u is not u2) + + def test_tupleresizebug(self): + # Check that a specific bug in _PyTuple_Resize() is squashed. + def f(): + for i in range(1000): + yield i + self.assertEqual(list(tuple(f())), list(range(1000))) + + # We expect tuples whose base components have deterministic hashes to + # have deterministic hashes too - and, indeed, the same hashes across + # platforms with hash codes of the same bit width. + def test_hash_exact(self): + def check_one_exact(t, e32, e64): + got = hash(t) + expected = e32 if support.NHASHBITS == 32 else e64 + if got != expected: + msg = f"FAIL hash({t!r}) == {got} != {expected}" + self.fail(msg) + + check_one_exact((), 750394483, 5740354900026072187) + check_one_exact((0,), 1214856301, -8753497827991233192) + check_one_exact((0, 0), -168982784, -8458139203682520985) + check_one_exact((0.5,), 2077348973, -408149959306781352) + check_one_exact((0.5, (), (-2, 3, (4, 6))), 714642271, + -1845940830829704396) + + # Various tests for hashing of tuples to check that we get few collisions. + # Does something only if RUN_ALL_HASH_TESTS is true. + # + # Earlier versions of the tuple hash algorithm had massive collisions + # reported at: + # - https://bugs.python.org/issue942952 + # - https://bugs.python.org/issue34751 + def test_hash_optional(self): + from itertools import product + + if not RUN_ALL_HASH_TESTS: + return + + # If specified, `expected` is a 2-tuple of expected + # (number_of_collisions, pileup) values, and the test fails if + # those aren't the values we get. Also if specified, the test + # fails if z > `zlimit`. + def tryone_inner(tag, nbins, hashes, expected=None, zlimit=None): + from collections import Counter + + nballs = len(hashes) + mean, sdev = support.collision_stats(nbins, nballs) + c = Counter(hashes) + collisions = nballs - len(c) + z = (collisions - mean) / sdev + pileup = max(c.values()) - 1 + del c + got = (collisions, pileup) + failed = False + prefix = "" + if zlimit is not None and z > zlimit: + failed = True + prefix = f"FAIL z > {zlimit}; " + if expected is not None and got != expected: + failed = True + prefix += f"FAIL {got} != {expected}; " + if failed or JUST_SHOW_HASH_RESULTS: + msg = f"{prefix}{tag}; pileup {pileup:,} mean {mean:.1f} " + msg += f"coll {collisions:,} z {z:+.1f}" + if JUST_SHOW_HASH_RESULTS: + import sys + print(msg, file=sys.__stdout__) + else: + self.fail(msg) + + def tryone(tag, xs, + native32=None, native64=None, hi32=None, lo32=None, + zlimit=None): + NHASHBITS = support.NHASHBITS + hashes = list(map(hash, xs)) + tryone_inner(tag + f"; {NHASHBITS}-bit hash codes", + 1 << NHASHBITS, + hashes, + native32 if NHASHBITS == 32 else native64, + zlimit) + + if NHASHBITS > 32: + shift = NHASHBITS - 32 + tryone_inner(tag + "; 32-bit upper hash codes", + 1 << 32, + [h >> shift for h in hashes], + hi32, + zlimit) + + mask = (1 << 32) - 1 + tryone_inner(tag + "; 32-bit lower hash codes", + 1 << 32, + [h & mask for h in hashes], + lo32, + zlimit) + + # Tuples of smallish positive integers are common - nice if we + # get "better than random" for these. + tryone("range(100) by 3", list(product(range(100), repeat=3)), + (0, 0), (0, 0), (4, 1), (0, 0)) + + # A previous hash had systematic problems when mixing integers of + # similar magnitude but opposite sign, obscurely related to that + # j ^ -2 == -j when j is odd. + cands = list(range(-10, -1)) + list(range(9)) + + # Note: -1 is omitted because hash(-1) == hash(-2) == -2, and + # there's nothing the tuple hash can do to avoid collisions + # inherited from collisions in the tuple components' hashes. + tryone("-10 .. 8 by 4", list(product(cands, repeat=4)), + (0, 0), (0, 0), (0, 0), (0, 0)) + del cands + + # The hashes here are a weird mix of values where all the + # variation is in the lowest bits and across a single high-order + # bit - the middle bits are all zeroes. A decent hash has to + # both propagate low bits to the left and high bits to the + # right. This is also complicated a bit in that there are + # collisions among the hashes of the integers in L alone. + L = [n << 60 for n in range(100)] + tryone("0..99 << 60 by 3", list(product(L, repeat=3)), + (0, 0), (0, 0), (0, 0), (324, 1)) + del L + + # Used to suffer a massive number of collisions. + tryone("[-3, 3] by 18", list(product([-3, 3], repeat=18)), + (7, 1), (0, 0), (7, 1), (6, 1)) + + # And even worse. hash(0.5) has only a single bit set, at the + # high end. A decent hash needs to propagate high bits right. + tryone("[0, 0.5] by 18", list(product([0, 0.5], repeat=18)), + (5, 1), (0, 0), (9, 1), (12, 1)) + + # Hashes of ints and floats are the same across platforms. + # String hashes vary even on a single platform across runs, due + # to hash randomization for strings. So we can't say exactly + # what this should do. Instead we insist that the # of + # collisions is no more than 4 sdevs above the theoretically + # random mean. Even if the tuple hash can't achieve that on its + # own, the string hash is trying to be decently pseudo-random + # (in all bit positions) on _its_ own. We can at least test + # that the tuple hash doesn't systematically ruin that. + tryone("4-char tuples", + list(product("abcdefghijklmnopqrstuvwxyz", repeat=4)), + zlimit=4.0) + + # The "old tuple test". See https://bugs.python.org/issue942952. + # Ensures, for example, that the hash: + # is non-commutative + # spreads closely spaced values + # doesn't exhibit cancellation in tuples like (x,(x,y)) + N = 50 + base = list(range(N)) + xp = list(product(base, repeat=2)) + inps = base + list(product(base, xp)) + \ + list(product(xp, base)) + xp + list(zip(base)) + tryone("old tuple test", inps, + (2, 1), (0, 0), (52, 49), (7, 1)) + del base, xp, inps + + # The "new tuple test". See https://bugs.python.org/issue34751. + # Even more tortured nesting, and a mix of signed ints of very + # small magnitude. + n = 5 + A = [x for x in range(-n, n+1) if x != -1] + B = A + [(a,) for a in A] + L2 = list(product(A, repeat=2)) + L3 = L2 + list(product(A, repeat=3)) + L4 = L3 + list(product(A, repeat=4)) + # T = list of testcases. These consist of all (possibly nested + # at most 2 levels deep) tuples containing at most 4 items from + # the set A. + T = A + T += [(a,) for a in B + L4] + T += product(L3, B) + T += product(L2, repeat=2) + T += product(B, L3) + T += product(B, B, L2) + T += product(B, L2, B) + T += product(L2, B, B) + T += product(B, repeat=4) + assert len(T) == 345130 + tryone("new tuple test", T, + (9, 1), (0, 0), (21, 5), (6, 1)) + + def test_repr(self): + l0 = tuple() + l2 = (0, 1, 2) + a0 = self.type2test(l0) + a2 = self.type2test(l2) + + self.assertEqual(str(a0), repr(l0)) + self.assertEqual(str(a2), repr(l2)) + self.assertEqual(repr(a0), "()") + self.assertEqual(repr(a2), "(0, 1, 2)") + + def _not_tracked(self, t): + # Nested tuples can take several collections to untrack + gc.collect() + gc.collect() + self.assertFalse(gc.is_tracked(t), t) + + def _tracked(self, t): + self.assertTrue(gc.is_tracked(t), t) + gc.collect() + gc.collect() + self.assertTrue(gc.is_tracked(t), t) + + @support.cpython_only + def test_track_literals(self): + # Test GC-optimization of tuple literals + x, y, z = 1.5, "a", [] + + self._not_tracked(()) + self._not_tracked((1,)) + self._not_tracked((1, 2)) + self._not_tracked((1, 2, "a")) + self._not_tracked((1, 2, (None, True, False, ()), int)) + self._not_tracked((object(),)) + self._not_tracked(((1, x), y, (2, 3))) + + # Tuples with mutable elements are always tracked, even if those + # elements are not tracked right now. + self._tracked(([],)) + self._tracked(([1],)) + self._tracked(({},)) + self._tracked((set(),)) + self._tracked((x, y, z)) + + def check_track_dynamic(self, tp, always_track): + x, y, z = 1.5, "a", [] + + check = self._tracked if always_track else self._not_tracked + check(tp()) + check(tp([])) + check(tp(set())) + check(tp([1, x, y])) + check(tp(obj for obj in [1, x, y])) + check(tp(set([1, x, y]))) + check(tp(tuple([obj]) for obj in [1, x, y])) + check(tuple(tp([obj]) for obj in [1, x, y])) + + self._tracked(tp([z])) + self._tracked(tp([[x, y]])) + self._tracked(tp([{x: y}])) + self._tracked(tp(obj for obj in [x, y, z])) + self._tracked(tp(tuple([obj]) for obj in [x, y, z])) + self._tracked(tuple(tp([obj]) for obj in [x, y, z])) + + @support.cpython_only + def test_track_dynamic(self): + # Test GC-optimization of dynamically constructed tuples. + self.check_track_dynamic(tuple, False) + + @support.cpython_only + def test_track_subtypes(self): + # Tuple subtypes must always be tracked + class MyTuple(tuple): + pass + self.check_track_dynamic(MyTuple, True) + + @support.cpython_only + def test_bug7466(self): + # Trying to untrack an unfinished tuple could crash Python + self._not_tracked(tuple(gc.collect() for i in range(101))) + + def test_repr_large(self): + # Check the repr of large list objects + def check(n): + l = (0,) * n + s = repr(l) + self.assertEqual(s, + '(' + ', '.join(['0'] * n) + ')') + check(10) # check our checking code + check(1000000) + + def test_iterator_pickle(self): + # Userlist iterators don't support pickling yet since + # they are based on generators. + data = self.type2test([4, 5, 6, 7]) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + itorg = iter(data) + d = pickle.dumps(itorg, proto) + it = pickle.loads(d) + self.assertEqual(type(itorg), type(it)) + self.assertEqual(self.type2test(it), self.type2test(data)) + + it = pickle.loads(d) + next(it) + d = pickle.dumps(it, proto) + self.assertEqual(self.type2test(it), self.type2test(data)[1:]) + + def test_reversed_pickle(self): + data = self.type2test([4, 5, 6, 7]) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + itorg = reversed(data) + d = pickle.dumps(itorg, proto) + it = pickle.loads(d) + self.assertEqual(type(itorg), type(it)) + self.assertEqual(self.type2test(it), self.type2test(reversed(data))) + + it = pickle.loads(d) + next(it) + d = pickle.dumps(it, proto) + self.assertEqual(self.type2test(it), self.type2test(reversed(data))[1:]) + + def test_no_comdat_folding(self): + # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding + # optimization causes failures in code that relies on distinct + # function addresses. + class T(tuple): pass + with self.assertRaises(TypeError): + [3,] + T((1,2)) + + def test_lexicographic_ordering(self): + # Issue 21100 + a = self.type2test([1, 2]) + b = self.type2test([1, 2, 0]) + c = self.type2test([1, 3]) + self.assertLess(a, b) + self.assertLess(b, c) + +# Notes on testing hash codes. The primary thing is that Python doesn't +# care about "random" hash codes. To the contrary, we like them to be +# very regular when possible, so that the low-order bits are as evenly +# distributed as possible. For integers this is easy: hash(i) == i for +# all not-huge i except i==-1. +# +# For tuples of mixed type there's really no hope of that, so we want +# "randomish" here instead. But getting close to pseudo-random in all +# bit positions is more expensive than we've been willing to pay for. +# +# We can tolerate large deviations from random - what we don't want is +# catastrophic pileups on a relative handful of hash codes. The dict +# and set lookup routines remain effective provided that full-width hash +# codes for not-equal objects are distinct. +# +# So we compute various statistics here based on what a "truly random" +# hash would do, but don't automate "pass or fail" based on those +# results. Instead those are viewed as inputs to human judgment, and the +# automated tests merely ensure we get the _same_ results across +# platforms. In fact, we normally don't bother to run them at all - +# set RUN_ALL_HASH_TESTS to force it. +# +# When global JUST_SHOW_HASH_RESULTS is True, the tuple hash statistics +# are just displayed to stdout. A typical output line looks like: +# +# old tuple test; 32-bit upper hash codes; \ +# pileup 49 mean 7.4 coll 52 z +16.4 +# +# "old tuple test" is just a string name for the test being run. +# +# "32-bit upper hash codes" means this was run under a 64-bit build and +# we've shifted away the lower 32 bits of the hash codes. +# +# "pileup" is 0 if there were no collisions across those hash codes. +# It's 1 less than the maximum number of times any single hash code was +# seen. So in this case, there was (at least) one hash code that was +# seen 50 times: that hash code "piled up" 49 more times than ideal. +# +# "mean" is the number of collisions a perfectly random hash function +# would have yielded, on average. +# +# "coll" is the number of collisions actually seen. +# +# "z" is "coll - mean" divided by the standard deviation of the number +# of collisions a perfectly random hash function would suffer. A +# positive value is "worse than random", and negative value "better than +# random". Anything of magnitude greater than 3 would be highly suspect +# for a hash function that claimed to be random. It's essentially +# impossible that a truly random function would deliver a result 16.4 +# sdevs "worse than random". +# +# But we don't care here! That's why the test isn't coded to fail. +# Knowing something about how the high-order hash code bits behave +# provides insight, but is irrelevant to how the dict and set lookup +# code performs. The low-order bits are much more important to that, +# and on the same test those did "just like random": +# +# old tuple test; 32-bit lower hash codes; \ +# pileup 1 mean 7.4 coll 7 z -0.2 +# +# So there are always tradeoffs to consider. For another: +# +# 0..99 << 60 by 3; 32-bit hash codes; \ +# pileup 0 mean 116.4 coll 0 z -10.8 +# +# That was run under a 32-bit build, and is spectacularly "better than +# random". On a 64-bit build the wider hash codes are fine too: +# +# 0..99 << 60 by 3; 64-bit hash codes; \ +# pileup 0 mean 0.0 coll 0 z -0.0 +# +# but their lower 32 bits are poor: +# +# 0..99 << 60 by 3; 32-bit lower hash codes; \ +# pileup 1 mean 116.4 coll 324 z +19.2 +# +# In a statistical sense that's waaaaay too many collisions, but (a) 324 +# collisions out of a million hash codes isn't anywhere near being a +# real problem; and, (b) the worst pileup on a single hash code is a measly +# 1 extra. It's a relatively poor case for the tuple hash, but still +# fine for practical use. +# +# This isn't, which is what Python 3.7.1 produced for the hashes of +# itertools.product([0, 0.5], repeat=18). Even with a fat 64-bit +# hashcode, the highest pileup was over 16,000 - making a dict/set +# lookup on one of the colliding values thousands of times slower (on +# average) than we expect. +# +# [0, 0.5] by 18; 64-bit hash codes; \ +# pileup 16,383 mean 0.0 coll 262,128 z +6073641856.9 +# [0, 0.5] by 18; 32-bit lower hash codes; \ +# pileup 262,143 mean 8.0 coll 262,143 z +92683.6 + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_userdict.diff b/test/dynamo/cpython/3_13/test_userdict.diff new file mode 100644 index 00000000000000..1c01574892067f --- /dev/null +++ b/test/dynamo/cpython/3_13/test_userdict.diff @@ -0,0 +1,74 @@ +diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py +index 61e79f553e8..c953390355e 100644 +--- a/test/dynamo/cpython/3_13/test_userdict.py ++++ b/test/dynamo/cpython/3_13/test_userdict.py +@@ -1,3 +1,54 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Check every path through every method of UserDict + + from test import mapping_tests, support +@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): + + # Decorate existing test with recursion limit, because + # the test is for C structure, but `UserDict` is a Python structure. +- test_repr_deep = support.infinite_recursion(25)( +- mapping_tests.TestHashMappingProtocol.test_repr_deep, +- ) ++ # test_repr_deep = support.infinite_recursion(25)( ++ # mapping_tests.TestHashMappingProtocol.test_repr_deep, ++ # ) + + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py new file mode 100644 index 00000000000000..c953390355e678 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_userdict.py @@ -0,0 +1,275 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# Check every path through every method of UserDict + +from test import mapping_tests, support +import unittest +import collections + +d0 = {} +d1 = {"one": 1} +d2 = {"one": 1, "two": 2} +d3 = {"one": 1, "two": 3, "three": 5} +d4 = {"one": None, "two": None} +d5 = {"one": 1, "two": 1} + +class UserDictTest(mapping_tests.TestHashMappingProtocol): + type2test = collections.UserDict + + def test_all(self): + # Test constructors + u = collections.UserDict() + u0 = collections.UserDict(d0) + u1 = collections.UserDict(d1) + u2 = collections.UserDict(d2) + + uu = collections.UserDict(u) + uu0 = collections.UserDict(u0) + uu1 = collections.UserDict(u1) + uu2 = collections.UserDict(u2) + + # keyword arg constructor + self.assertEqual(collections.UserDict(one=1, two=2), d2) + # item sequence constructor + self.assertEqual(collections.UserDict([('one',1), ('two',2)]), d2) + self.assertEqual(collections.UserDict(dict=[('one',1), ('two',2)]), + {'dict': [('one', 1), ('two', 2)]}) + # both together + self.assertEqual(collections.UserDict([('one',1), ('two',2)], two=3, three=5), d3) + + # alternate constructor + self.assertEqual(collections.UserDict.fromkeys('one two'.split()), d4) + self.assertEqual(collections.UserDict().fromkeys('one two'.split()), d4) + self.assertEqual(collections.UserDict.fromkeys('one two'.split(), 1), d5) + self.assertEqual(collections.UserDict().fromkeys('one two'.split(), 1), d5) + self.assertTrue(u1.fromkeys('one two'.split()) is not u1) + self.assertIsInstance(u1.fromkeys('one two'.split()), collections.UserDict) + self.assertIsInstance(u2.fromkeys('one two'.split()), collections.UserDict) + + # Test __repr__ + self.assertEqual(str(u0), str(d0)) + self.assertEqual(repr(u1), repr(d1)) + self.assertIn(repr(u2), ("{'one': 1, 'two': 2}", + "{'two': 2, 'one': 1}")) + + # Test rich comparison and __len__ + all = [d0, d1, d2, u, u0, u1, u2, uu, uu0, uu1, uu2] + for a in all: + for b in all: + self.assertEqual(a == b, len(a) == len(b)) + + # Test __getitem__ + self.assertEqual(u2["one"], 1) + self.assertRaises(KeyError, u1.__getitem__, "two") + + # Test __setitem__ + u3 = collections.UserDict(u2) + u3["two"] = 2 + u3["three"] = 3 + + # Test __delitem__ + del u3["three"] + self.assertRaises(KeyError, u3.__delitem__, "three") + + # Test clear + u3.clear() + self.assertEqual(u3, {}) + + # Test copy() + u2a = u2.copy() + self.assertEqual(u2a, u2) + u2b = collections.UserDict(x=42, y=23) + u2c = u2b.copy() # making a copy of a UserDict is special cased + self.assertEqual(u2b, u2c) + + class MyUserDict(collections.UserDict): + def display(self): print(self) + + m2 = MyUserDict(u2) + m2a = m2.copy() + self.assertEqual(m2a, m2) + + # SF bug #476616 -- copy() of UserDict subclass shared data + m2['foo'] = 'bar' + self.assertNotEqual(m2a, m2) + + # Test keys, items, values + self.assertEqual(sorted(u2.keys()), sorted(d2.keys())) + self.assertEqual(sorted(u2.items()), sorted(d2.items())) + self.assertEqual(sorted(u2.values()), sorted(d2.values())) + + # Test "in". + for i in u2.keys(): + self.assertIn(i, u2) + self.assertEqual(i in u1, i in d1) + self.assertEqual(i in u0, i in d0) + + # Test update + t = collections.UserDict() + t.update(u2) + self.assertEqual(t, u2) + + # Test get + for i in u2.keys(): + self.assertEqual(u2.get(i), u2[i]) + self.assertEqual(u1.get(i), d1.get(i)) + self.assertEqual(u0.get(i), d0.get(i)) + + # Test "in" iteration. + for i in range(20): + u2[i] = str(i) + ikeys = [] + for k in u2: + ikeys.append(k) + keys = u2.keys() + self.assertEqual(set(ikeys), set(keys)) + + # Test setdefault + t = collections.UserDict() + self.assertEqual(t.setdefault("x", 42), 42) + self.assertIn("x", t) + self.assertEqual(t.setdefault("x", 23), 42) + + # Test pop + t = collections.UserDict(x=42) + self.assertEqual(t.pop("x"), 42) + self.assertRaises(KeyError, t.pop, "x") + self.assertEqual(t.pop("x", 1), 1) + t["x"] = 42 + self.assertEqual(t.pop("x", 1), 42) + + # Test popitem + t = collections.UserDict(x=42) + self.assertEqual(t.popitem(), ("x", 42)) + self.assertRaises(KeyError, t.popitem) + + def test_init(self): + for kw in 'self', 'other', 'iterable': + self.assertEqual(list(collections.UserDict(**{kw: 42}).items()), + [(kw, 42)]) + self.assertEqual(list(collections.UserDict({}, dict=42).items()), + [('dict', 42)]) + self.assertEqual(list(collections.UserDict({}, dict=None).items()), + [('dict', None)]) + self.assertEqual(list(collections.UserDict(dict={'a': 42}).items()), + [('dict', {'a': 42})]) + self.assertRaises(TypeError, collections.UserDict, 42) + self.assertRaises(TypeError, collections.UserDict, (), ()) + self.assertRaises(TypeError, collections.UserDict.__init__) + + def test_update(self): + for kw in 'self', 'dict', 'other', 'iterable': + d = collections.UserDict() + d.update(**{kw: 42}) + self.assertEqual(list(d.items()), [(kw, 42)]) + self.assertRaises(TypeError, collections.UserDict().update, 42) + self.assertRaises(TypeError, collections.UserDict().update, {}, {}) + self.assertRaises(TypeError, collections.UserDict.update) + + def test_missing(self): + # Make sure UserDict doesn't have a __missing__ method + self.assertEqual(hasattr(collections.UserDict, "__missing__"), False) + # Test several cases: + # (D) subclass defines __missing__ method returning a value + # (E) subclass defines __missing__ method raising RuntimeError + # (F) subclass sets __missing__ instance variable (no effect) + # (G) subclass doesn't define __missing__ at all + class D(collections.UserDict): + def __missing__(self, key): + return 42 + d = D({1: 2, 3: 4}) + self.assertEqual(d[1], 2) + self.assertEqual(d[3], 4) + self.assertNotIn(2, d) + self.assertNotIn(2, d.keys()) + self.assertEqual(d[2], 42) + class E(collections.UserDict): + def __missing__(self, key): + raise RuntimeError(key) + e = E() + try: + e[42] + except RuntimeError as err: + self.assertEqual(err.args, (42,)) + else: + self.fail("e[42] didn't raise RuntimeError") + class F(collections.UserDict): + def __init__(self): + # An instance variable __missing__ should have no effect + self.__missing__ = lambda key: None + collections.UserDict.__init__(self) + f = F() + try: + f[42] + except KeyError as err: + self.assertEqual(err.args, (42,)) + else: + self.fail("f[42] didn't raise KeyError") + class G(collections.UserDict): + pass + g = G() + try: + g[42] + except KeyError as err: + self.assertEqual(err.args, (42,)) + else: + self.fail("g[42] didn't raise KeyError") + + # Decorate existing test with recursion limit, because + # the test is for C structure, but `UserDict` is a Python structure. + # test_repr_deep = support.infinite_recursion(25)( + # mapping_tests.TestHashMappingProtocol.test_repr_deep, + # ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/test_userlist.diff b/test/dynamo/cpython/3_13/test_userlist.diff new file mode 100644 index 00000000000000..299a8abeb99ac5 --- /dev/null +++ b/test/dynamo/cpython/3_13/test_userlist.diff @@ -0,0 +1,78 @@ +diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py +index 312702c8e39..a4532922f5d 100644 +--- a/test/dynamo/cpython/3_13/test_userlist.py ++++ b/test/dynamo/cpython/3_13/test_userlist.py +@@ -1,7 +1,58 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import run_tests ++ ++__TestCase = CPythonTestCase ++ ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Check every path through every method of UserList + + from collections import UserList +-from test import list_tests ++import list_tests + import unittest + from test import support + +@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest): + + # Decorate existing test with recursion limit, because + # the test is for C structure, but `UserList` is a Python structure. +- test_repr_deep = support.infinite_recursion(25)( +- list_tests.CommonTest.test_repr_deep, +- ) ++ # test_repr_deep = support.infinite_recursion(25)( ++ # list_tests.CommonTest.test_repr_deep, ++ # ) + + if __name__ == "__main__": +- unittest.main() ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py new file mode 100644 index 00000000000000..a4532922f5d42a --- /dev/null +++ b/test/dynamo/cpython/3_13/test_userlist.py @@ -0,0 +1,128 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +# Check every path through every method of UserList + +from collections import UserList +import list_tests +import unittest +from test import support + + +class UserListTest(list_tests.CommonTest): + type2test = UserList + + def test_getslice(self): + super().test_getslice() + l = [0, 1, 2, 3, 4] + u = self.type2test(l) + for i in range(-3, 6): + self.assertEqual(u[:i], l[:i]) + self.assertEqual(u[i:], l[i:]) + for j in range(-3, 6): + self.assertEqual(u[i:j], l[i:j]) + + def test_slice_type(self): + l = [0, 1, 2, 3, 4] + u = UserList(l) + self.assertIsInstance(u[:], u.__class__) + self.assertEqual(u[:],u) + + def test_add_specials(self): + u = UserList("spam") + u2 = u + "eggs" + self.assertEqual(u2, list("spameggs")) + + def test_radd_specials(self): + u = UserList("eggs") + u2 = "spam" + u + self.assertEqual(u2, list("spameggs")) + u2 = u.__radd__(UserList("spam")) + self.assertEqual(u2, list("spameggs")) + + def test_iadd(self): + super().test_iadd() + u = [0, 1] + u += UserList([0, 1]) + self.assertEqual(u, [0, 1, 0, 1]) + + def test_mixedcmp(self): + u = self.type2test([0, 1]) + self.assertEqual(u, [0, 1]) + self.assertNotEqual(u, [0]) + self.assertNotEqual(u, [0, 2]) + + def test_mixedadd(self): + u = self.type2test([0, 1]) + self.assertEqual(u + [], u) + self.assertEqual(u + [2], [0, 1, 2]) + + def test_getitemoverwriteiter(self): + # Verify that __getitem__ overrides *are* recognized by __iter__ + class T(self.type2test): + def __getitem__(self, key): + return str(key) + '!!!' + self.assertEqual(next(iter(T((1,2)))), "0!!!") + + def test_userlist_copy(self): + u = self.type2test([6, 8, 1, 9, 1]) + v = u.copy() + self.assertEqual(u, v) + self.assertEqual(type(u), type(v)) + + # Decorate existing test with recursion limit, because + # the test is for C structure, but `UserList` is a Python structure. + # test_repr_deep = support.infinite_recursion(25)( + # list_tests.CommonTest.test_repr_deep, + # ) + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/cpython/3_13/typinganndata/__init__.py b/test/dynamo/cpython/3_13/typinganndata/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo/cpython/3_13/typinganndata/_typed_dict_helper.py b/test/dynamo/cpython/3_13/typinganndata/_typed_dict_helper.py new file mode 100644 index 00000000000000..9df0ede7d40ee5 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/_typed_dict_helper.py @@ -0,0 +1,30 @@ +"""Used to test `get_type_hints()` on a cross-module inherited `TypedDict` class + +This script uses future annotations to postpone a type that won't be available +on the module inheriting from to `Foo`. The subclass in the other module should +look something like this: + + class Bar(_typed_dict_helper.Foo, total=False): + b: int + +In addition, it uses multiple levels of Annotated to test the interaction +between the __future__ import, Annotated, and Required. +""" + +from __future__ import annotations + +from typing import Annotated, Generic, Optional, Required, TypedDict, TypeVar + + +OptionalIntType = Optional[int] + +class Foo(TypedDict): + a: OptionalIntType + +T = TypeVar("T") + +class FooGeneric(TypedDict, Generic[T]): + a: Optional[T] + +class VeryAnnotated(TypedDict, total=False): + a: Annotated[Annotated[Annotated[Required[int], "a"], "b"], "c"] diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module.py b/test/dynamo/cpython/3_13/typinganndata/ann_module.py new file mode 100644 index 00000000000000..5081e6b58345a9 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module.py @@ -0,0 +1,62 @@ + + +""" +The module for testing variable annotations. +Empty lines above are for good reason (testing for correct line numbers) +""" + +from typing import Optional +from functools import wraps + +__annotations__[1] = 2 + +class C: + + x = 5; y: Optional['C'] = None + +from typing import Tuple +x: int = 5; y: str = x; f: Tuple[int, int] + +class M(type): + + __annotations__['123'] = 123 + o: type = object + +(pars): bool = True + +class D(C): + j: str = 'hi'; k: str= 'bye' + +from types import new_class +h_class = new_class('H', (C,)) +j_class = new_class('J') + +class F(): + z: int = 5 + def __init__(self, x): + pass + +class Y(F): + def __init__(self): + super(F, self).__init__(123) + +class Meta(type): + def __new__(meta, name, bases, namespace): + return super().__new__(meta, name, bases, namespace) + +class S(metaclass = Meta): + x: str = 'something' + y: str = 'something else' + +def foo(x: int = 10): + def bar(y: List[str]): + x: str = 'yes' + bar() + +def dec(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + +u: int | float diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module2.py b/test/dynamo/cpython/3_13/typinganndata/ann_module2.py new file mode 100644 index 00000000000000..76cf5b3ad97e62 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module2.py @@ -0,0 +1,36 @@ +""" +Some correct syntax for variable annotation here. +More examples are in test_grammar and test_parser. +""" + +from typing import no_type_check, ClassVar + +i: int = 1 +j: int +x: float = i/10 + +def f(): + class C: ... + return C() + +f().new_attr: object = object() + +class C: + def __init__(self, x: int) -> None: + self.x = x + +c = C(5) +c.new_attr: int = 10 + +__annotations__ = {} + + +@no_type_check +class NTC: + def meth(self, param: complex) -> None: + ... + +class CV: + var: ClassVar['CV'] + +CV.var = CV() diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module3.py b/test/dynamo/cpython/3_13/typinganndata/ann_module3.py new file mode 100644 index 00000000000000..eccd7be22dd894 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module3.py @@ -0,0 +1,18 @@ +""" +Correct syntax for variable annotation that should fail at runtime +in a certain manner. More examples are in test_grammar and test_parser. +""" + +def f_bad_ann(): + __annotations__[1] = 2 + +class C_OK: + def __init__(self, x: int) -> None: + self.x: no_such_name = x # This one is OK as proposed by Guido + +class D_bad_ann: + def __init__(self, x: int) -> None: + sfel.y: int = 0 + +def g_bad_ann(): + no_such_name.attr: int = 0 diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module4.py b/test/dynamo/cpython/3_13/typinganndata/ann_module4.py new file mode 100644 index 00000000000000..13e9aee54c98b6 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module4.py @@ -0,0 +1,5 @@ +# This ann_module isn't for test_typing, +# it's for test_module + +a:int=3 +b:str=4 diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module5.py b/test/dynamo/cpython/3_13/typinganndata/ann_module5.py new file mode 100644 index 00000000000000..837041e121f652 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module5.py @@ -0,0 +1,10 @@ +# Used by test_typing to verify that Final wrapped in ForwardRef works. + +from __future__ import annotations + +from typing import Final + +name: Final[str] = "final" + +class MyClass: + value: Final = 3000 diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module6.py b/test/dynamo/cpython/3_13/typinganndata/ann_module6.py new file mode 100644 index 00000000000000..679175669bc3ac --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module6.py @@ -0,0 +1,7 @@ +# Tests that top-level ClassVar is not allowed + +from __future__ import annotations + +from typing import ClassVar + +wrong: ClassVar[int] = 1 diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module695.py b/test/dynamo/cpython/3_13/typinganndata/ann_module695.py new file mode 100644 index 00000000000000..2ede9fe382564f --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module695.py @@ -0,0 +1,22 @@ +from __future__ import annotations +from typing import Callable + + +class A[T, *Ts, **P]: + x: T + y: tuple[*Ts] + z: Callable[P, str] + + +class B[T, *Ts, **P]: + T = int + Ts = str + P = bytes + x: T + y: Ts + z: P + + +def generic_function[T, *Ts, **P]( + x: T, *y: *Ts, z: P.args, zz: P.kwargs +) -> None: ... diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module7.py b/test/dynamo/cpython/3_13/typinganndata/ann_module7.py new file mode 100644 index 00000000000000..8f890cd28025be --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module7.py @@ -0,0 +1,11 @@ +# Tests class have ``__text_signature__`` + +from __future__ import annotations + +DEFAULT_BUFFER_SIZE = 8192 + +class BufferedReader(object): + """BufferedReader(raw, buffer_size=DEFAULT_BUFFER_SIZE)\n--\n\n + Create a new buffered reader using the given readable raw IO object. + """ + pass diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module8.py b/test/dynamo/cpython/3_13/typinganndata/ann_module8.py new file mode 100644 index 00000000000000..bd031481378415 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module8.py @@ -0,0 +1,10 @@ +# Test `@no_type_check`, +# see https://bugs.python.org/issue46571 + +class NoTypeCheck_Outer: + class Inner: + x: int + + +def NoTypeCheck_function(arg: int) -> int: + ... diff --git a/test/dynamo/cpython/3_13/typinganndata/ann_module9.py b/test/dynamo/cpython/3_13/typinganndata/ann_module9.py new file mode 100644 index 00000000000000..952217393e1ff7 --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/ann_module9.py @@ -0,0 +1,14 @@ +# Test ``inspect.formatannotation`` +# https://github.com/python/cpython/issues/96073 + +from typing import Union, List + +ann = Union[List[str], int] + +# mock typing._type_repr behaviour +class A: ... + +A.__module__ = 'testModule.typing' +A.__qualname__ = 'A' + +ann1 = Union[List[A], int] diff --git a/test/dynamo/cpython/3_13/typinganndata/mod_generics_cache.py b/test/dynamo/cpython/3_13/typinganndata/mod_generics_cache.py new file mode 100644 index 00000000000000..6c1ee2fec8374d --- /dev/null +++ b/test/dynamo/cpython/3_13/typinganndata/mod_generics_cache.py @@ -0,0 +1,24 @@ +"""Module for testing the behavior of generics across different modules.""" + +from typing import TypeVar, Generic, Optional, TypeAliasType + +default_a: Optional['A'] = None +default_b: Optional['B'] = None + +T = TypeVar('T') + + +class A(Generic[T]): + some_b: 'B' + + +class B(Generic[T]): + class A(Generic[T]): + pass + + my_inner_a1: 'B.A' + my_inner_a2: A + my_outer_a: 'A' # unless somebody calls get_type_hints with localns=B.__dict__ + +type Alias = int +OldStyle = TypeAliasType("OldStyle", int) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index c1ab329a137def..6699c973052bbd 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -81,9 +81,9 @@ def match_rng_op(node, op): for node in gm.graph.nodes: if match_rng_op(node, op) or node.target == op: actual_count += 1 - assert ( - actual_count >= freq_ge - ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + assert actual_count >= freq_ge, ( + f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + ) return gm diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index d400ac259cf7d8..af162b41ccd76c 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1235,7 +1235,7 @@ def relu(x): def test_donated_buffer2(self): logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" - # we will re-use the graph for g across f1 and f2 + # we will reuse the graph for g across f1 and f2 @torch.compile() def g(activation, param2): return torch.matmul(activation, param2) @@ -1257,7 +1257,7 @@ def f(inp, param1, param2): def test_donated_buffer3(self): logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" - # we will re-use the graph for g across f1 and f2 + # we will reuse the graph for g across f1 and f2 @torch.compile() def g(activation, param2): return torch.matmul(activation, param2) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 60e8c348cbc915..0d4a1f01f9a303 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] +import copy import os import shutil import unittest @@ -24,7 +25,7 @@ from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.runtime.triton_compat import tl, triton from torch._inductor.test_case import TestCase as InductorTestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch._subclasses import FakeTensorMode from torch.compiler._cache import CacheArtifactManager from torch.fx.experimental.symbolic_shapes import ShapeEnv @@ -164,7 +165,7 @@ def fn(x, y): b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True) # Record artifacts - with fresh_inductor_cache(): + with fresh_cache(): compiled_fn = torch.compile(fn, dynamic=dynamic) # A first call should miss in the cache. @@ -174,9 +175,14 @@ def fn(x, y): if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices self.assertEqual(eager_result, compiled_result) - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + if functorch_config.bundled_autograd_cache: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + else: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) @@ -189,7 +195,10 @@ def fn(x, y): autotune_expect = 2 if device == GPU_TYPE else 0 - self.assertEqual(len(cache_info.inductor_artifacts), 2) + if functorch_config.bundled_autograd_cache: + self.assertEqual(len(cache_info.inductor_artifacts), 0) + else: + self.assertEqual(len(cache_info.inductor_artifacts), 2) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -200,16 +209,21 @@ def fn(x, y): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # We did not load anything so dont hit yet - with fresh_inductor_cache(): + with fresh_cache(): eager_result = fn(a, b) compiled_result = compiled_fn(a, b) self.assertEqual(eager_result, compiled_result) compiled_result.sum().backward() if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + if functorch_config.bundled_autograd_cache: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + else: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) @@ -220,10 +234,12 @@ def fn(x, y): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # Hot load and hit - with fresh_inductor_cache(): + with fresh_cache(): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) - - self.assertEqual(len(cache_info.inductor_artifacts), 2) + if functorch_config.bundled_autograd_cache: + self.assertEqual(len(cache_info.inductor_artifacts), 0) + else: + self.assertEqual(len(cache_info.inductor_artifacts), 2) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -234,8 +250,12 @@ def fn(x, y): if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices self.assertEqual(eager_result, compiled_result) - self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) - self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) + if functorch_config.bundled_autograd_cache: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + else: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) @@ -272,6 +292,49 @@ def fn(x, y): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + def test_multi_graph_specialization(self): + """ + Verify multi graph specializations all cache hit + """ + + def fn(x): + return x * 5 + + a = torch.randn(5) + a8 = torch.randn(8) + a16 = torch.randn(16) + torch._dynamo.mark_dynamic( + a, + 0, + specialize_on=[ + lambda x: x == 8, + lambda x: x == 16, + ], + ) + + compiled_fn = torch.compile(fn, backend="inductor") + + # A first call should miss in the cache. + compiled_fn(a) + compiled_fn(a8) + compiled_fn(a16) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3) + + self._clear_dynamo_and_codecache() + + # A second call should hit on all 3 graphs + compiled_fn(a) + compiled_fn(a8) + compiled_fn(a16) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 3) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -371,6 +434,36 @@ def fn(x, y): # We save again into the cache self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"strict_autograd_cache": True}) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + @requires_triton() + def test_non_bundled_to_bundled_config_change(self): + if functorch_config.bundled_autograd_cache: + raise unittest.SkipTest("BundledAutogradCache is already enabled") + + def fn(x, y): + return (x * 2, y @ y) + + a = torch.rand(25, device="cuda") + b = torch.rand(5, 5, device="cuda") + + compiled_fn = torch.compile(fn, backend="inductor") + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Now turn on bundled autograd cache, see that we successfully save again + with functorch_config.patch({"bundled_autograd_cache": True}): + torch._dynamo.reset() + self.assertEqual(fn(a, b), compiled_fn(a, b)) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch( @@ -378,7 +471,7 @@ def fn(x, y): ) def test_view_replay_bypass(self): """ - Shoud bypass when view replay is turned on + Should bypass when view replay is turned on """ def fn(a): @@ -393,6 +486,27 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch( + {"enable_autograd_cache": True, "strict_autograd_cache": True} + ) + def test_invoke_subgraph(self): + from torch._higher_order_ops.invoke_subgraph import mark_compile_region + + @mark_compile_region + def gn(x, y): + return x + y + + @torch.compile + def fn(x, y): + return gn(x, y) + gn(x, y) + + a = torch.randn(25) + b = torch.randn(25) + + fn(a, b) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch( @@ -779,6 +893,44 @@ def fn(a, b): self.assertEqual(a.grad, a2.grad) self.assertEqual(b.grad, b2.grad) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch({"fx_graph_cache": True}) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"strict_autograd_cache": True}) + def test_autograd_no_dynamo_trace_backward(self): + """ + Test that dynamo does not trace into the backward compiled function, + even on cache hit. + """ + torch._dynamo.eval_frame.clear_dynamo_tls() + + @torch.compile + def fn(x): + # Calls x.sum().backward() during forward execution of fn + (x_grad,) = torch.autograd.grad(x.sum(), x) + return x_grad + + a = torch.randn(10, 10, requires_grad=True, device="cpu") + result = fn(a) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + # Backward of `sum` will run during execution of graph break + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + traced_frame_infos = copy.deepcopy( + torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos + ) + + torch._dynamo.reset() + torch._dynamo.eval_frame.clear_dynamo_tls() + result2 = fn(a) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + new_traced_frame_infos = torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos + self.assertEqual(result, result2) + # Dynamo should trace exactly the same frames on cache hit + self.assertEqual(traced_frame_infos, new_traced_frame_infos) + @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -1351,6 +1503,133 @@ def inp_fn(): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 3) + @functorch_config.patch({"enable_autograd_cache": True}) + @inductor_config.patch( + { + "fx_graph_cache": True, + "fx_graph_remote_cache": False, + "autotune_local_cache": True, + } + ) + def test_cache_lazy_backward_for_compiled_autograd(self): + device = "cpu" + dtype = torch.float32 + dynamic = True + """ + Verify that we can populate and hot load functions from the cache. + """ + if device == GPU_TYPE and not HAS_GPU: + raise unittest.SkipTest(f"requires {GPU_TYPE}") + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + raise unittest.SkipTest("requires SM80 or later") + + def fn(x, y): + return x.sin() @ y + + a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True) + b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True) + + # Record artifacts + with fresh_cache(): + compiled_fn = torch.compile(fn, dynamic=dynamic) + + # A first call should miss in the cache. + eager_result = fn(a, b) + expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b)) + compiled_result = compiled_fn(a, b) + with torch._dynamo.compiled_autograd._enable( + torch.compile(dynamic=dynamic) + ): + actual_grads = torch.autograd.grad(compiled_result.sum(), inputs=(a, b)) + if hasattr(a, "_dynamo_weak_dynamic_indices"): + del a._dynamo_weak_dynamic_indices + self.assertEqual(eager_result, compiled_result) + self.assertEqual(expected_grads[0], actual_grads[0]) + self.assertEqual(expected_grads[1], actual_grads[1]) + if functorch_config.bundled_autograd_cache: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + else: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 3) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + artifacts = torch.compiler.save_cache_artifacts() + + self.assertIsNotNone(artifacts) + + artifact_bytes, cache_info = artifacts + + autotune_expect = 2 if device == GPU_TYPE else 0 + + if functorch_config.bundled_autograd_cache: + self.assertEqual(len(cache_info.inductor_artifacts), 0) + else: + self.assertEqual(len(cache_info.inductor_artifacts), 3) + self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) + self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) + self.assertEqual(len(cache_info.pgo_artifacts), 0) + + self._clear_all_caches() + + # Clean triton kernels + shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) + + # Hot load and hit, should not recompile + with fresh_cache(): + cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) + + if functorch_config.bundled_autograd_cache: + self.assertEqual(len(cache_info.inductor_artifacts), 0) + else: + self.assertEqual(len(cache_info.inductor_artifacts), 3) + self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) + self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) + self.assertEqual(len(cache_info.pgo_artifacts), 0) + + for i in range(3): + counters.clear() + eager_result = fn(a, b) + expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b)) + compiled_result = compiled_fn(a, b) + with torch._dynamo.compiled_autograd._enable( + torch.compile(dynamic=dynamic) + ): + actual_grads = torch.autograd.grad( + compiled_result.sum(), inputs=(a, b) + ) + if hasattr(a, "_dynamo_weak_dynamic_indices"): + del a._dynamo_weak_dynamic_indices + self.assertEqual(eager_result, compiled_result) + self.assertEqual(expected_grads[0], actual_grads[0]) + self.assertEqual(expected_grads[1], actual_grads[1]) + + if i == 0: + # initial compile + if functorch_config.bundled_autograd_cache: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + else: + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 3) + self.assertEqual( + counters["inductor"]["fxgraph_lookup_write_file"], 3 + ) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_saved"], 0 + ) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + else: + # no recompiles + self.assertFalse(counters) + @functorch_config.patch({"bundled_autograd_cache": True}) class AOTAutogradCacheBundledTests(AOTAutogradCacheTests): @@ -1378,6 +1657,7 @@ def default_config(self): is_export=False, no_tangents=False, enable_log=False, + precompile_backend_id=None, ) def _get_dynamo_output(self, fn, *args, **kwargs): @@ -1406,7 +1686,8 @@ def gen_cache_key(self, f, config, inputs=None): # Needs a shape env for FxGraphCache.check_can_cache to pass. # Not needed for actual key calculation. with torch._guards.tracing(ctx): - return autograd_cache_key(fx_g, example_inputs, config, {}) + with sanitize_gm_for_cache(fx_g): + return autograd_cache_key(fx_g, example_inputs, config, {}) def test_basic_hash_key(self): def fn(x): diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 80aa5c1025fe3c..6f460b402404fd 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -455,9 +455,9 @@ def backward(ctx, grad_output): # Modify gradient using .data (Dangerous: Breaks autograd tracking!) modified_grad = grad_output.clone() - modified_grad.data[ - input_tensor.data < 0 - ] = 0 # Zero-out gradients for negative inputs + modified_grad.data[input_tensor.data < 0] = ( + 0 # Zero-out gradients for negative inputs + ) return modified_grad * 3 @@ -1429,7 +1429,7 @@ def backward(ctx, grad_output, grad_dx): result = grad_output * dx + grad_dx * 6 * x # Intentionally return a wrong value to test if the backward is triggered twice. # Since if the first MyCube.apply returns values w/o requires_grad=True, - # this backward would be only triggered once (the first MyCube.appy call), + # this backward would be only triggered once (the first MyCube.apply call), # as the second MyCube.apply is inlined by Dynamo and the corresponding backward # would be generated by autograd engine. return result * 0.5 diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 4686fb9240c5b2..2c60d6ba4cf599 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -146,16 +146,14 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"): call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None aot1_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad_strided: "f32[s21]" = torch.empty_like(getitem_1); getitem_1 = None - - copy_: "f32[s21]" = new_grad_strided.copy_(aot1_tangents_1); copy_ = None + accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot1_tangents_1], getitem_1, None, False); getitem_1 = None + getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None result: "f32[s21]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None - new_grad_strided_1: "f32[s21]" = torch.empty_like(getitem_2); getitem_2 = None - - copy__1: "f32[s21]" = new_grad_strided_1.copy_(result); result = copy__1 = None - return (new_grad_strided, new_grad_strided_1) + accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None + getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None + return (getitem_11, getitem_12) """, ) elif backend == "inductor": @@ -179,16 +177,14 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"): call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None aot3_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad_strided: "f32[s21]" = torch.empty_like(getitem_1); getitem_1 = None - - copy_: "f32[s21]" = new_grad_strided.copy_(aot3_tangents_1); copy_ = None + accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot3_tangents_1], getitem_1, None, False); getitem_1 = None + getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None result: "f32[s21]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None - new_grad_strided_1: "f32[s21]" = torch.empty_like(getitem_2); getitem_2 = None - - copy__1: "f32[s21]" = new_grad_strided_1.copy_(result); result = copy__1 = None - return (new_grad_strided, new_grad_strided_1) + accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None + getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None + return (getitem_11, getitem_12) """, ) @@ -265,18 +261,16 @@ def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]", call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None aot0_tangents_1: "f32[s21]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None - new_grad_strided: "f32[s21]" = torch.empty_like(getitem_1); getitem_1 = None - - copy_: "f32[s21]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None + accumulate_grad = torch__dynamo_compiled_autograd_ops_AccumulateGrad([aot0_tangents_1], getitem_1, None, False); getitem_1 = None + getitem_11: "f32[s21]" = accumulate_grad[0]; accumulate_grad = None add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None result: "f32[s21]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None - new_grad_strided_1: "f32[s21]" = torch.empty_like(getitem_2); getitem_2 = None - - copy__1: "f32[s21]" = new_grad_strided_1.copy_(result); result = copy__1 = None - return (new_grad_strided, new_grad_strided_1, add) + accumulate_grad_1 = torch__dynamo_compiled_autograd_ops_AccumulateGrad([result], getitem_2, None, False); result = getitem_2 = None + getitem_12: "f32[s21]" = accumulate_grad_1[0]; accumulate_grad_1 = None + return (getitem_11, getitem_12, add) """, ) diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 5d14642bb52dd9..18cdf78c61f27c 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -12,10 +12,7 @@ normalize_gm, ) from torch._higher_order_ops.schema import find_hop_schema -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) +from torch.testing._internal.common_utils import instantiate_parametrized_tests from torch.testing._internal.inductor_utils import HAS_CUDA @@ -179,8 +176,7 @@ def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"): """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor(a2!) arg1, *, str scheme="nf4") -> ((Tensor))""", ) - @parametrize("backend", ["eager", "aot_eager"]) - def test_schema_gen_pytree_in_out_with_mutation(self, backend): + def test_schema_gen_pytree_in_out_with_mutation(self): def inner(x_y): x, y = x_y x.add_(1) @@ -194,26 +190,24 @@ def inner(x_y): x = torch.randn(3, 3, requires_grad=False) y = torch.randn(3, 3, requires_grad=True) - if backend == "eager": - bk = EagerAndRecordGraphs() - else: - assert backend == "aot_eager" - bk = AotEagerAndRecordGraphs() + bk = EagerAndRecordGraphs() def f(x, y): return invoke_quant_test(inner, [x, y], scheme="nf4") - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), ): torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y) - if backend == "eager": - self.assertEqual(len(bk.graphs), 1) - self.assertExpectedInline( - normalize_graph(bk.graphs[0]), - """\ + self.assertEqual(len(bk.graphs), 1) + self.assertExpectedInline( + normalize_graph(bk.graphs[0]), + """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): l_x_ = L_x_ @@ -241,41 +235,11 @@ def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"): child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None return (child, child_1, child_2, child_3) """, # noqa: B950 - ) - self.assertExpectedInline( - str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]), - """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950 - ) - elif backend == "aot_eager": - self.assertEqual(len(bk.fw_graphs), 1) - self.assertExpectedInline( - normalize_graph(bk.fw_graphs[0]), - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"): - auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0 - _tree_spec_constant0 = self._tree_spec_constant0 - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None - getitem: "f32[3, 3]" = auto_functionalized_v2[0] - getitem_1: "f32[3, 3]" = auto_functionalized_v2[1] - getitem_2: "f32[3, 3]" = auto_functionalized_v2[2] - getitem_3: "f32[3, 3]" = auto_functionalized_v2[3] - getitem_4: "f32[3, 3]" = auto_functionalized_v2[4]; auto_functionalized_v2 = None - return (getitem, getitem_1, getitem_2, getitem_3, primals_1, primals_2, getitem_4) - - class auto_functionalized_subgraph_0(torch.nn.Module): - def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): - add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1) - mm: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1) - sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm); mm = None - cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin); sin = None - add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1) - sub: "f32[3, 3]" = torch.ops.aten.sub.Tensor(add, arg1_1) - mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(add, arg1_1); arg1_1 = None - copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None - return (cos, add_1, sub, mm_1) -""", # noqa: B950 - ) + ) + self.assertExpectedInline( + str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]), + """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950 + ) def test_none_input(self): def inner(x, y): @@ -358,26 +322,29 @@ def f(x, y): x = torch.randn(3, 3, requires_grad=False) x_clone = x.clone() y = torch.randn(3, 3, requires_grad=True) - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), ): compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y) - # assert x is not mutated - self.assertEqual(x, x_clone) - self.assertEqual(compiled_out, x + y + 1) + self.assertEqual(x, x_clone + 1) + self.assertEqual(compiled_out, x_clone + y + 1) self.assertEqual(len(backend.fw_graphs), 1) self.assertExpectedInline( normalize_graph(backend.fw_graphs[0]), """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"): +class (torch.nn.Module): + def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0 _tree_spec_constant0 = self._tree_spec_constant0 - auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = primals_2, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [primals_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = _tree_spec_constant0 = None + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = arg1_1, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [arg0_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = arg1_1 = _tree_spec_constant0 = None getitem: "f32[3, 3]" = auto_functionalized_v2[0] getitem_1: "f32[3, 3]" = auto_functionalized_v2[1]; auto_functionalized_v2 = None - return (getitem, primals_1, primals_2, getitem_1) + copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + return (getitem,) class auto_functionalized_subgraph_0(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): diff --git a/test/dynamo/test_buffers_override.py b/test/dynamo/test_buffers_override.py index 946283dc4f19df..3ceba631423d1d 100644 --- a/test/dynamo/test_buffers_override.py +++ b/test/dynamo/test_buffers_override.py @@ -30,7 +30,7 @@ def __init__(self): super().__init__() # Override buffers; should not cause breakage # but skip the marking static here since - # named_buffers is overriden + # named_buffers is overridden self.register_buffer("B", torch.ones(3, 3)) self.named_buffers = [] diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index fa906a2ac162e7..b91b8156ec1814 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -53,8 +53,8 @@ def fn(): fn_str = f"""\ def fn(): foo.bar(1, 2, 3) -{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} - l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] +{str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))} + l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}] """ locals = {} exec(fn_str, {}, locals) diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index c2caa000e4c6ca..8112a2e89e9578 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -1,9 +1,14 @@ # Owner(s): ["module: dynamo"] +import unittest from unittest.mock import Mock -from torch._dynamo.callback import callback_handler +import torch +from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger from torch._dynamo.test_case import run_tests, TestCase +from torch._guards import CompileId +from torch.testing._internal.common_utils import TEST_WITH_ROCM +from torch.testing._internal.inductor_utils import HAS_CUDA class CallbackTests(TestCase): @@ -15,16 +20,23 @@ def setUp(self) -> None: callback_handler.register_end_callback(self._on_compile_end) def tearDown(self) -> None: - return super().tearDown() callback_handler.clear() + return super().tearDown() def test_callbacks_with_duplicate_prevention(self) -> None: - with callback_handler.install_callbacks(), callback_handler.install_callbacks(): + trigger = CallbackTrigger.DYNAMO + compile_id = CompileId(0, 0) + with ( + callback_handler.install_callbacks(trigger, compile_id), + callback_handler.install_callbacks(trigger, compile_id), + ): self._on_compile_start.assert_called_once() self._on_compile_end.assert_called_once() def test_counter(self) -> None: - with callback_handler.install_callbacks(): + trigger = CallbackTrigger.DYNAMO + compile_id = CompileId(0, 0) + with callback_handler.install_callbacks(trigger, compile_id): self.assertEqual( callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 1, @@ -35,18 +47,91 @@ def test_counter(self) -> None: def test_counter_assertion(self) -> None: callback_handler._CompilationCallbackHandler__pending_callbacks_counter -= 1 + with self.assertRaisesRegex( + AssertionError, "Pending callbacks counter cannot become negative." + ): + trigger = CallbackTrigger.DYNAMO + compile_id = CompileId(0, 0) + with callback_handler.install_callbacks(trigger, str(compile_id)): + pass + self.assertEqual( + callback_handler._CompilationCallbackHandler__pending_callbacks_counter, 0 + ) + + @unittest.skipIf( + TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs" + ) + @unittest.skipIf(not HAS_CUDA, "requires triton") + @torch._inductor.config.patch(force_disable_caches=True) + def test_triggers(self) -> None: + torch._dynamo.reset() + order = [] + + def on_start(args: CallbackArgs): + nonlocal order + order.append(f"start={args}") - with self.assertRaises( - AssertionError - ) as e, callback_handler.install_callbacks(): - pass + def on_end(args: CallbackArgs): + nonlocal order + order.append(f"end={args}") + + torch._dynamo.callback.on_compile_start(on_start) + torch._dynamo.callback.on_compile_start(on_end) + + class TinyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(10, 10) + + def forward(self, x): + temp = self.fc1(x) + temp = self.relu(temp) + torch._dynamo.graph_break() + return self.fc2(temp) + + model = TinyModel().to("cuda") + compiled_model = torch.compile(model, mode="max-autotune") + x = torch.randn(10, 10, device="cuda") + + loss = compiled_model(x).sum() + loss.backward() + self.assertExpectedInline( + "\n".join(order), + """\ +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0') +start=CallbackArgs(callback_trigger=, compile_id='1/0') +end=CallbackArgs(callback_trigger=, compile_id='1/0') +start=CallbackArgs(callback_trigger=, compile_id='1/0') +end=CallbackArgs(callback_trigger=, compile_id='1/0') +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0')""", # noqa: B950 + ) + order.clear() - self.assertIn( - "Pending callbacks counter cannot become negative.", - str(e.exception), + compiled_model.zero_grad() + loss = compiled_model(x).sum() + loss.backward() + self.assertExpectedInline( + "\n".join(order), + """\ +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0') +start=CallbackArgs(callback_trigger=, compile_id='1/0') +end=CallbackArgs(callback_trigger=, compile_id='1/0') +start=CallbackArgs(callback_trigger=, compile_id='1/0') +end=CallbackArgs(callback_trigger=, compile_id='1/0') +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0')""", # noqa: B950 ) + order.clear() - callback_handler._CompilationCallbackHandler__pending_callbacks_counter += 1 + compiled_model.zero_grad() + loss = compiled_model(x).sum() + loss.backward() + self.assertEqual(len(order), 0) if __name__ == "__main__": diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index 7161ee7f751966..1f7290c51dd8d9 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -81,11 +81,11 @@ def test_compilation_callback(self): torch._dynamo.reset() @torch._dynamo.on_compile_start - def start_callback(): + def start_callback(_): print("Compilation started.") @torch._dynamo.on_compile_end - def end_callback(): + def end_callback(_): print("Compilation ended.") mod = ToyModel() @@ -116,13 +116,13 @@ def test_compilation_callback_with_graph_break(self): counter = 0 @torch._dynamo.on_compile_start - def start_callback(): + def start_callback(_): nonlocal counter counter += 1 print(f"Counter = {counter}") @torch._dynamo.on_compile_end - def end_callback(): + def end_callback(_): nonlocal counter counter += 1 print(f"Counter = {counter}") @@ -225,6 +225,15 @@ def fn(x): c_output = c_fn(x) self.assertEqual(output, c_output) + def test_list_bad_access(self): + @torch.compile(backend="eager") + def fn(x, y): + a = [x] + return a[y] + + with self.assertRaises(IndexError): + fn(torch.randn(10), 99) + # The private variants of the below functions are extensively tested # So as long as the signatures match we're good @@ -236,9 +245,17 @@ def check_signature(self, public_fn_name, private_fn_name, private_namespace): public_sig = inspect.signature(public_fn) private_sig = inspect.signature(private_fn) + matching = public_sig == private_sig + matching |= len(public_sig.parameters) < len(private_sig.parameters) and all( + public == private + for public, private in zip( + public_sig.parameters.items(), private_sig.parameters.items() + ) + ) + self.assertEqual( - public_sig, - private_sig, + matching, + True, f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}", ) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index d0216ed5903850..f2c781379ef5c1 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -388,7 +388,7 @@ def fn(x, s0, s1): ref1 = fn(x, s1, s1) res1 = opt_fn(x, s1, s1) - # We have a re-compilation because of chaning inputs + # We have a re-compilation because of changing inputs self.assertEqual(cnts.frame_count, 2) self.assertEqual(ref1, res1) @@ -403,7 +403,7 @@ def fn(x, s0, s1): ref0 = fn(x, s0, s1) res0 = opt_fn(x, s0, s1) - # We have a re-compilation because of chaning inputs + # We have a re-compilation because of changing inputs self.assertEqual(cnts.frame_count, 2) self.assertEqual(ref0, res0) @@ -1252,10 +1252,13 @@ def fn(z): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1289,10 +1292,13 @@ def fn(z): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1335,10 +1341,13 @@ def inner_fn(x, y): return inner_fn(x, y) + x - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 98a0267f4a6282..6cd8fafa2342d2 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -8,7 +8,7 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.exc import IncorrectUsage +from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters @@ -1040,11 +1040,11 @@ def fn3(x): self.assertEqual(cnts.frame_count, 2) self.assertEqual(cnts.op_count, 4) - try: - fn3(torch.randn(4, 5)) - self.assertFalse(True) - except torch._dynamo.exc.Unsupported as e: - self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e)) + cnts.clear() + torch._dynamo.reset() + fn3(torch.randn(4, 5)) + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.op_count, 4) def test_disable_optimize(self): cnt = torch._dynamo.testing.CompileCounter() @@ -1694,6 +1694,266 @@ def f4(x): ): f4(torch.randn(3)) + def test_set_fullgraph(self): + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def f1(x): + x = x + 1 + with torch._dynamo.set_fullgraph(False): + torch._dynamo.graph_break() + return x + 2 + + inp = torch.ones(3) + self.assertEqual(f1(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + @torch.compile(backend=cnts) + def f2(x): + x = x + 1 + with torch._dynamo.set_fullgraph(True): + torch._dynamo.graph_break() + return x + 2 + + with self.assertRaises(Unsupported): + f2(inp) + + @torch.compile(backend=cnts, fullgraph=True) + def f3(x): + x = x + 1 + with torch._dynamo.set_fullgraph(False): + torch._dynamo.graph_break() + x = x + 2 + torch._dynamo.graph_break() + return x + 4 + + cnts.clear() + self.assertEqual(f3(inp), inp + 7) + self.assertEqual(cnts.frame_count, 3) + + def inner_f4(x): + x = x + 2 + torch._dynamo.graph_break() + return x + 4 + + @torch.compile(backend=cnts, fullgraph=True) + def f4(x): + x = x + 1 + with torch._dynamo.set_fullgraph(False): + torch._dynamo.skip_frame() + return inner_f4(x) + + cnts.clear() + self.assertEqual(f4(inp), inp + 7) + self.assertEqual(cnts.frame_count, 2) + + def test_set_fullgraph_nested(self): + # set_fullgraph in a nested frame + cnts = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.set_fullgraph(False) + def inner_f5(x): + x = x + 2 + torch._dynamo.graph_break() + return x + 4 + + @torch.compile(backend=cnts, fullgraph=True) + def f5(x): + x = x + 1 + return inner_f5(x) + + inp = torch.ones(3) + self.assertEqual(f5(inp), inp + 7) + self.assertEqual(cnts.frame_count, 4) + + def inner_f6(x): + x = x + 2 + with torch._dynamo.set_fullgraph(False): + torch._dynamo.graph_break() + return x + 4 + + @torch.compile(backend=cnts, fullgraph=True) + def f6(x): + x = x + 1 + return inner_f6(x) + + cnts.clear() + self.assertEqual(f6(inp), inp + 7) + self.assertEqual(cnts.frame_count, 3) + + def inner_f7(x): + x = x + 2 + with torch._dynamo.set_fullgraph(True): + torch._dynamo.graph_break() + return x + 4 + + @torch.compile(backend=cnts, fullgraph=False) + def f7(x): + x = x + 1 + return inner_f7(x) + + with self.assertRaises(Unsupported): + f7(inp) + + def test_set_fullgraph_nested_with_skip(self): + # set_fullgraph in a nested frame with a skipped frame in between + cnts = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.set_fullgraph(False) + def inner2_f8(x): + x = x + 2 + torch._dynamo.graph_break() + return x + 4 + + def inner1_f8(x): + with torch._dynamo.set_fullgraph(False): + torch._dynamo.skip_frame() + return inner2_f8(x) + + @torch.compile(backend=cnts, fullgraph=True) + def f8(x): + x = x + 1 + return inner1_f8(x) + + inp = torch.ones(3) + self.assertEqual(f8(inp), inp + 7) + self.assertEqual(cnts.frame_count, 4) + + def inner2_f9(x): + x = x + 2 + with torch._dynamo.set_fullgraph(True): + torch._dynamo.graph_break() + return x + 4 + + @torch._dynamo.disable(recursive=False) + def inner1_f9(x): + return inner2_f9(x) + + @torch.compile(backend=cnts, fullgraph=False) + def f9(x): + x = x + 1 + return inner1_f9(x) + + with self.assertRaises(Unsupported): + f9(inp) + + # test export with set_fullgraph(False) still errors + + def test_set_fullgraph_export(self): + @torch._dynamo.set_fullgraph(False) + def inner(x): + x = x + 2 + torch._dynamo.graph_break() + return x + 4 + + def f(x): + x = x + 1 + return inner(x) + + with self.assertRaises(Unsupported): + torch._dynamo.export(f)(torch.ones(3)) + + def test_set_fullgraph_nested_deep(self): + cnts = torch._dynamo.testing.CompileCounter() + + def inner1_f1(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + def inner2_f1(x): + return inner1_f1(x) + + def inner3_f1(x): + with torch._dynamo.set_fullgraph(False): + return inner2_f1(x) + + def inner4_f1(x): + return inner3_f1(x) + + @torch.compile(backend=cnts, fullgraph=True) + def f1(x): + x = x + 4 + return inner4_f1(x) + + inp = torch.ones(3) + self.assertEqual(f1(inp), inp + 7) + self.assertEqual(cnts.frame_count, 4) + + def inner1_f2(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + def inner2_f2(x): + return inner1_f2(x) + + def inner3_f2(x): + with torch._dynamo.set_fullgraph(True): + return inner2_f2(x) + + def inner4_f2(x): + return inner3_f2(x) + + @torch.compile(backend=cnts, fullgraph=False) + def f2(x): + x = x + 4 + return inner4_f2(x) + + with self.assertRaises(Unsupported): + f2(inp) + + def test_set_fullgraph_error(self): + @torch.compile(backend="eager") + def f1(): + with torch._dynamo.set_fullgraph(foo="bar"): + pass + + @torch.compile(backend="eager") + def f2(): + with torch._dynamo.set_fullgraph(): + pass + + @torch.compile(backend="eager") + def f3(): + with torch._dynamo.set_fullgraph("foo"): + pass + + with self.assertRaises(Exception): + f1() + with self.assertRaises(Exception): + f2() + with self.assertRaises(Exception): + f3() + + def test_nested_compile_fullgraph(self): + inp = torch.ones(3) + + @torch.compile(backend="eager", fullgraph=True) + def inner_f1(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + @torch.compile(backend="eager", fullgraph=False) + def f1(x): + return inner_f1(x) + + with self.assertRaises(Unsupported): + f1(inp) + + @torch.compile(backend="eager", fullgraph=False) + def inner_f2(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + @torch.compile(backend="eager", fullgraph=True) + def f2(x): + return inner_f2(x) + + self.assertEqual(f2(inp), inp + 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index b4f9d658fc8f2c..1812971034f908 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -3,6 +3,7 @@ # ruff: noqa: TRY002 import itertools +import operator import types import unittest import weakref @@ -10,7 +11,6 @@ from typing import Any import torch -import torch._dynamo.config import torch._dynamo.test_case import torch._dynamo.testing import torch._functorch.config @@ -18,6 +18,10 @@ import torch.utils.checkpoint from torch._dynamo.testing import same from torch._dynamo.utils import dict_items +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) class SimpleDict(dict): @@ -840,6 +844,8 @@ def fn(x, mp): y = torch.sin(x * mp["a"]) for k, v in mp.items(): # noqa: PERF102 y += torch.cos(x * v) + if isinstance(mp, types.MappingProxyType): + y *= 2 return y opt_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -859,6 +865,21 @@ def fn(x, mp): res = opt_fn(x, mp) self.assertEqual(ref, res) + def test_dict_construction_from_mapping_proxy(self): + d = {"a": 2, "b": 3, "c": 5} + + def fn(x, mp): + d = dict(mp) + y = torch.sin(x * d["a"]) + return y + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + mp = types.MappingProxyType(d) + ref = fn(x, mp) + res = opt_fn(x, mp) + self.assertEqual(ref, res) + def test_mapping_proxy_existing_mutation(self): d = {"a": 2, "b": 3, "c": 5} @@ -883,7 +904,7 @@ def test_mapping_proxy_existing_local_mutation(self): def fn(x): # Dynamo should not cause a graph break here because it knows that - # the existing proxy cant point to this new dict + # the existing proxy can't point to this new dict other_dict = {} other_dict["d"] = 4 y = torch.sin(x * mp["c"]) @@ -956,12 +977,10 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b - from torch._dynamo.exc import InternalTorchDynamoError + from torch._dynamo.exc import Unsupported for arg in args: - with self.assertRaisesRegex( - InternalTorchDynamoError, "unsupported operand type" - ): + with self.assertRaises(Unsupported): _ = fn(arg) def test_builtin_or_with_diff_keys(self): @@ -992,6 +1011,48 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) + def test_newly_constructed_default_dict(self): + def f(x): + d = defaultdict(list) + d[0] = 42 + return x + 1, d + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + + @parametrize("op", ["or_", "and_", "xor", "sub"]) + def test_dict_keys_binop(self, op): + op = getattr(operator, op) + + def f(): + a = {"one": torch.ones(1), "two": torch.ones(2)} + b = {"one": torch.ones(1), "three": torch.ones(3)} + return op(a.keys(), b.keys()), op(b.keys(), a.keys()) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["ior", "iand", "ixor", "isub"]) + def test_dict_keys_inplace_binop(self, op): + op = getattr(operator, op) + + def f(): + a = {"one": torch.ones(1), "two": torch.ones(2)}.keys() + b = {"one": torch.ones(1), "three": torch.ones(3)}.keys() + c = {"one": torch.ones(1), "two": torch.ones(2)}.keys() + a = op(a, b) + b = op(b, c) + return a, b + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + +instantiate_parametrized_tests(DictTests) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_einops.py b/test/dynamo/test_einops.py new file mode 100644 index 00000000000000..af15b91434c16b --- /dev/null +++ b/test/dynamo/test_einops.py @@ -0,0 +1,158 @@ +# Owner(s): ["module: dynamo"] +import importlib +import subprocess +import sys +import unittest + +import torch +import torch._dynamo.config +import torch._dynamo.test_case +from torch import nn +from torch._dynamo.test_case import TestCase +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +HAS_EINOPS = importlib.util.find_spec("einops") + +if HAS_EINOPS: + import einops + + einops_version = einops.__version__ +else: + einops_version = "none" +einops_version_sanitized = einops_version.replace(".", "_") + + +@unittest.skipIf(not HAS_EINOPS, "these tests require einops") +class TestEinops(TestCase): + """ + These tests adapted from similar tests in the einops repo. + https://github.com/arogozhnikov/einops/blob/main/einops/tests/test_other.py#L254 + + The goal of this test suite is to test torch.compile x einops for multiple + versions of einops. Our goal is to prevent regressions in einops from changes + in PyTorch. + """ + + @unittest.skipIf( + einops_version == "0.6.1", "https://github.com/pytorch/pytorch/issues/157417" + ) + @parametrize("version", [einops_version_sanitized]) + def test_functions(self, version): + from einops import einsum, pack, rearrange, reduce, repeat, unpack + + class TorchModuleWithOperations(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x_abc, suffix=""): + a, b, c = x_abc.shape + + def suf(pattern): + parts = pattern.split() + return " ".join( + [p if p[-1] not in "acd" else p + suffix for p in parts] + ) + + # patterns look a bit strange because names a, c, d will be modified on every run + # by suf function + x_abcd = repeat(x_abc, suf("a b c -> a b c 4")) + x_abc = reduce(x_abcd, suf("a b c d -> a b c"), "min") + x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf("a b * c")) + x_array = unpack( + rearrange(x_abdc, suf("a b d c -> (a b ) 1 c d")), ps, "ab one1 c *" + ) + x1 = x_array[0] + len(x_array) + x1 = rearrange(x1, suf("(a b ) 1 c -> a b c"), b=b) + addition = einsum(x_abc, x_abcd, suf("a b c , a b c d -> d"))[0] + return x1 + addition + + original = TorchModuleWithOperations() + # Einops only interacts with Dynamo but we test backend="inductor" just in case + compiled = torch.compile(original, backend="inductor", fullgraph=True) + for size in [10, 20, 40]: + x = torch.rand([size, size + 1, size + 2]) + for suffix in ["", "suf1", "other_suffix"]: + result1 = compiled(x, suffix) + result2 = original(x.double(), suffix).float() + self.assertEqual(result1, result2) + + @parametrize("version", [einops_version_sanitized]) + def test_layers(self, version): + from einops.layers.torch import EinMix, Rearrange, Reduce + + original = nn.Sequential( + Rearrange("b (t c) -> b t c", c=16), + EinMix( + "b t c -> qkv b t cout", + weight_shape="qkv c cout", + bias_shape="qkv cout", + qkv=3, + c=16, + cout=8, + ), + Reduce("qkv b t cout -> b t qkv", "min", cout=8), + ) + + # Einops only interacts with Dynamo but we test backend="inductor" just in case + compiled = torch.compile(original, backend="inductor", fullgraph=True) + + for size in [16, 32, 64]: + x = torch.rand([size, size]) + result1 = original(x) + result2 = compiled(x.double()).float() + self.assertEqual(result1, result2) + + @parametrize("version", [einops_version_sanitized]) + def test_no_recompile_on_lazy_state(self, version): + """einops has some lazy state that gets initialized the first time an API + is called. This should not trigger a recompile.""" + script = """\ +import torch +import torch.nn as nn +from einops import einsum, pack, reduce, repeat, unpack, rearrange + +class TorchModuleWithOperations(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x_abc, suffix=""): + a, b, c = x_abc.shape + + def suf(pattern): + parts = pattern.split() + return " ".join([p if p[-1] not in "acd" else p + suffix for p in parts]) + + # patterns look a bit strange because names a, c, d will be modified on every run + # by suf function + x_abcd = repeat(x_abc, suf("a b c -> a b c 4")) + x_abc = reduce(x_abcd, suf("a b c d -> a b c"), "min") + x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf("a b * c")) + x_array = unpack(rearrange(x_abdc, suf("a b d c -> (a b ) 1 c d")), ps, "ab one1 c *") + x1 = x_array[0] + len(x_array) + x1 = rearrange(x1, suf("(a b ) 1 c -> a b c"), b=b) + addition = einsum(x_abc, x_abcd, suf("a b c , a b c d -> d"))[0] + return x1 + addition + +compiled_fn = torch.compile(TorchModuleWithOperations(), fullgraph=True) +x = torch.arange(2 * 3 * 5).view(2, 3, 5) +y = compiled_fn(x) + +# Should not recompile! +with torch.compiler.set_stance("fail_on_recompile"): + z = compiled_fn(x) +""" + subprocess.check_output([sys.executable, "-c", script]) + + +instantiate_parametrized_tests( + TestEinops, +) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 0a54ddad19ab00..69a822d7b1c091 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1,16 +1,19 @@ # Owner(s): ["module: dynamo"] +import logging import re import traceback import unittest +import unittest.mock import warnings +from functools import lru_cache import torch import torch._dynamo import torch._dynamo.config import torch._dynamo.test_case import torch.utils._pytree as python_pytree -from torch._dynamo.exc import Unsupported +from torch._dynamo.exc import ResumePrologueTracingError, Unsupported from torch._dynamo.testing import skipIfNotPy312 from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( @@ -96,7 +99,12 @@ def fn(x): torch.Tensor([1]) ), """\ -Tensor.item +Unsupported Tensor.item() call with capture_scalar_outputs=False + Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False. + Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph. + + Developer debug context: call_method TensorVariable() item () {} + from user code: File "test_error_messages.py", line N, in fn @@ -116,7 +124,7 @@ def fn(x): """\ Data dependent operator Explanation: Operator `aten.equal.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs. - Hint: Consider wrapping the operator into a PyTorch-understood custom operator (see https:/pytorch.org/tutorials/advanced/custom_ops_landing_page.html) + Hint: Consider wrapping the operator into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) Developer debug context: aten.equal.default @@ -168,6 +176,7 @@ def fn(it): Hint: Avoid calling `zip.__iter__` in your code. Hint: Please report an issue to PyTorch. Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope. + Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+. Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ () {} @@ -195,6 +204,7 @@ def fn(x, items): Hint: Please report an issue to PyTorch. Hint: Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) to the compiled region, instead of passing it as an input to the compiled region. Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope. + Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+. Developer debug context: call_method UserDefinedObjectVariable(dict_items) __iter__ () {} @@ -717,7 +727,7 @@ def post_munge(s): """\ Missing bytecode handler Explanation: Dynamo does not know how to handle the bytecode instruction `GET_AITER`. - Hint: Do not trace code that produces the `GET_AITER` bytecode instruction (see https:/docs.python.org/3/library/dis.html for bytecode semantics). + Hint: Do not trace code that produces the `GET_AITER` bytecode instruction (see https://docs.python.org/3/library/dis.html for bytecode semantics). Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. Developer debug context: GET_AITER with args (, Instruction(GET_AITER) @@ -916,7 +926,7 @@ def fn(x): Data-dependent assertion failed (cannot compile partial graph) Explanation: Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. - Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo wil not trace any exception handling). + Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo will not trace any exception handling). Hint: Remove the assert statement. Hint: Move the assert statement outside of any context managers in order to graph break with partial graph compilation (if fullgraph=False). @@ -1177,6 +1187,61 @@ def f3(x): """, ) + @make_logging_test(dynamo=logging.DEBUG) + def test_lru_cache_warning_logs_user_stack_trace(self, records): + @lru_cache + def foo(x): + return x + 1 + + torch.compile(foo, backend="eager")(torch.randn(4)) + + lru_cache_log = None + for record in records: + if "call to a lru_cache wrapped function at:" in record.getMessage(): + lru_cache_log = record.getMessage() + break + + self.assertIsNotNone(lru_cache_log, "No lru_cache warning was logged") + + self.assertExpectedInline( + munge_exc(lru_cache_log), + """\ +call to a lru_cache wrapped function at: _dynamo/external_utils.py:N + File "test_error_messages.py", line N, in test_lru_cache_warning_logs_user_stack_trace + torch.compile(foo, backend="eager")(torch.randn(4)) +""", + ) + + @make_logging_test(dynamo=logging.DEBUG) + def test_lru_cache_warning_logs_nested_call(self, records): + @lru_cache + def foo(x): + return x + 1 + + def nested(x): + return foo(x) + + torch.compile(nested, backend="eager")(torch.randn(4)) + + lru_cache_log = None + for record in records: + if "call to a lru_cache wrapped function at:" in record.getMessage(): + lru_cache_log = record.getMessage() + break + + self.assertIsNotNone(lru_cache_log, "No lru_cache warning was logged") + + self.assertExpectedInline( + munge_exc(lru_cache_log), + """\ +call to a lru_cache wrapped function at: test_error_messages.py:N + File "test_error_messages.py", line N, in test_lru_cache_warning_logs_nested_call + torch.compile(nested, backend="eager")(torch.randn(4)) + File "test_error_messages.py", line N, in nested + return foo(x) +""", + ) + def test_disable_message(self): @torch.compile(backend="eager", fullgraph=True) def outer(fn, x): @@ -1252,6 +1317,48 @@ def forward(self, x): post_munge=post_munge, ) + # Test that errors while tracing resume function prologues do not get suppressed + def test_graph_break_in_buggy_resume_prologue(self): + import torch._dynamo.bytecode_transformation as bt + import torch._dynamo.resume_execution as rex + + # NOTE: do not define non_global as a global in this file! + @torch.compile(backend="eager") + def fn(non_global): + non_global = non_global + 1 + torch._dynamo.graph_break() + return non_global + 1 + + orig_clean_and_assemble_instructions = bt.clean_and_assemble_instructions + + def bad_clean_and_assemble_instructions(instructions, *args): + # Inject an invalid LOAD_GLOBAL after the first STORE_FAST IS_TRACING_RESUME_PROLOGUE_VARNAME + for i, inst in enumerate(instructions): + if ( + inst.opname == "STORE_FAST" + and inst.argval == rex.IS_TRACING_RESUME_PROLOGUE_VARNAME + ): + instructions[:] = ( + instructions[: i + 1] + + [ + # this should cause a graph break + bt.create_instruction("LOAD_GLOBAL", argval="non_global"), + ] + + instructions[i + 1 :] + ) + break + return orig_clean_and_assemble_instructions(instructions, *args) + + with unittest.mock.patch( + "torch._dynamo.bytecode_transformation.clean_and_assemble_instructions", + bad_clean_and_assemble_instructions, + ): + with self.assertRaisesRegex( + ResumePrologueTracingError, + "Error while tracing through a Dynamo-generated resume function prologue.", + ): + fn(torch.randn(3)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 07bb5760326007..94ce690ed5b96f 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -19,7 +19,7 @@ class CustomException(Exception): - ... + pass class CustomExceptionMeta(type): @@ -28,7 +28,7 @@ def __instancecheck__(cls, instance): class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta): - ... + pass class CustomExceptionWithArgs(Exception): @@ -292,7 +292,7 @@ def fn(x): x = torch.randn(4) fn(x) - # Cant use fullgraph=True because RERAISE is not supported + # Can't use fullgraph=True because RERAISE is not supported opt_fn = torch.compile(fn, backend="eager") opt_fn(x) @@ -358,7 +358,7 @@ def fn(x): def test_raise_custom_exception(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): @@ -375,7 +375,7 @@ def fn(t): def test_raise_custom_exception_with_args(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 94bb39a4fceb3c..53c9e2b79f3815 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -368,6 +368,25 @@ def func(x): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + def test_immutable_list_dict(self): + class M(torch.nn.Module): + def forward(self, x1, x2): + return [x1 + x2], {"moo1": x1 * x1, "moo2": x2 * x2} + + x1 = torch.randn(2, 3) + x2 = torch.randn(2, 3) + model = M() + + fx_model = make_fx( + model, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + _error_on_data_dependent_ops=True, + )(*[x1, x2]) + ep = torch.export.export(fx_model, (x1, x2)) + res = torch.compile(ep.module(), dynamic=True, fullgraph=True)(x1, x2) + self.assertTrue(torch._dynamo.utils.same(res, M()(x1, x2))) + def test_dupes(self): inp = torch.tensor([0.1, 0.1]) @@ -2517,7 +2536,8 @@ def forward(self, x): dynamic_shapes = {"x": (dim0,)} with self.assertRaisesRegex( torch._dynamo.exc.UserError, - "You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.", + "You marked.*but your code specialized it to be a constant.*" + "If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ): torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True) @@ -3516,7 +3536,7 @@ def forward(self, pred, x): [3, 3, 4, 5], [true_graph, true_graph, false_graph, false_graph], [true_guard_code, true_guard_code, false_guard_code, false_guard_code], - # Outter shape env should have no guards in it because we never specialize on the outter symbool. + # Outer shape env should have no guards in it because we never specialize on the outer symbool. [[], [], [], []], ) @@ -4577,6 +4597,20 @@ def forward(self, x): out = graph(x) self.assertEqual(ref_out, out) + def test_strict_fake_tensor_prop_real_tensors(self): + class Foo(torch.nn.Module): + def forward(self, x): + return bool(x.eq(0.1).any().item()) + + model = Foo() + inputs = (torch.randn(64),) + ref = model(*inputs) + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + ep = torch.export.export(model, inputs, strict=True) + res = ep.module()(*inputs) + + self.assertEqual(ref, res) + class ExportTestsDevice(torch._dynamo.test_case.TestCase): def test_export_with_parameters(self, device): diff --git a/test/dynamo/test_flat_apply.py b/test/dynamo/test_flat_apply.py index 8e5d945299186f..aad5d6b2815688 100644 --- a/test/dynamo/test_flat_apply.py +++ b/test/dynamo/test_flat_apply.py @@ -175,7 +175,7 @@ def fn(x, y): class (torch.nn.Module): def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"): mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - _tensor_constant0 = self._tensor_constant0 + _tensor_constant0: "f32[1]" = self._tensor_constant0 add: "f32[10]" = torch.ops.aten.add.Tensor(mul, _tensor_constant0); mul = _tensor_constant0 = None return (add,) """, # NOQA: B950 diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index d6deb3fc45c797..aa28e3da51812b 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -166,7 +166,7 @@ def foo(x): for warning in w: warning_message = str(warning.message) if ( - "Dynamo detected a call to a `functools.lru_cache` wrapped function" + "Dynamo detected a call to a `functools.lru_cache`-wrapped" in warning_message ): break @@ -519,6 +519,17 @@ def test_tuple2(a, b): args = [a, b] return sub(*args) + def test_size_tuple_add(self): + def fn(): + size = torch.Size([]) + assert isinstance(size + size, torch.Size) + assert isinstance(size + (), tuple) + assert isinstance(size + (), torch.Size) + + fn() + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + compiled_fn() + @make_test def test_is_in_onnx_export(x, y): if torch.onnx.is_in_onnx_export(): @@ -1229,7 +1240,7 @@ def test_module_constant(x, y): @make_test def test_inline_softmax(x, y): - # This is common in sme huggingface models + # This is common in some huggingface models return torch.nn.Softmax(dim=-1)(x + y * 2) @make_test @@ -1691,47 +1702,15 @@ def test_tuple_contains(a, b): return a + b return a - b + @unittest.expectedFailure @make_test - def test_set_invalid_ConstantVariable_op(a, b): - s = set({"banana", "apple", "orange"}) - try: - s - 1 - except TypeError: - return a + b - except Exception: - return a - b - else: - return a * b - - @make_test - def test_set_pop_raise_KeyError(a, b): - s = set() - try: - s.pop() - except KeyError: - return a + b - except Exception: - return a - b + def test_set_in_frozenset(x): + var = set("abc") + other = set([frozenset("abc")]) + if var in other: + return x + 1 else: - return a * b - - @make_test - def test_set_issubset(a, b): - vals1 = {"a", "b", "c"} - vals2 = {"b", "c"} - vals3 = {"b", "e", "f"} - if vals2.issubset(vals1) and not vals2.issubset(vals3): - return a + b - return a - b - - @make_test - def test_set_issuperset(a, b): - vals1 = {"a", "b", "c"} - vals2 = {"b", "c"} - vals3 = {"b", "e", "f"} - if vals1.issuperset(vals2) and not vals1.issuperset(vals3): - return a + b - return a - b + return x - 1 @make_test def test_set_update_bytecode(x): @@ -1751,181 +1730,6 @@ def test_set_update_list_with_duplicated_items(x): else: return x - 1 - @make_test - def test_set_contains(a, b): - vals = set(["a", "b", "c"]) - if "a" in vals: - x = a + b - else: - x = a - b - if "d" in vals: - y = a + b - else: - y = a - b - return x, y - - def test_set_isdisjoint(self): - x = {"apple", "banana", "cherry"} - y = {"google", "microsoft", "apple"} - - def fn(a): - if x.isdisjoint(y): - return a + 1 - else: - return a - 1 - - test = make_test(fn) - test(self) - - @make_test - def test_set_intersection(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - set3 = {"shoes", "flipflops", "apple"} - intersection_set = set1.intersection(set2, set3) - if "apple" in intersection_set: - x = a + b - else: - x = a - b - if "banana" in intersection_set: - y = a + b - else: - y = a - b - if "shoes" in intersection_set: - z = a + b - else: - z = a - b - return x, y, z - - @make_test - def test_set_intersection_update(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - set3 = {"shoes", "flipflops", "apple"} - set1.intersection_update(set2, set3) - if "apple" in set1: - x = a + b - else: - x = a - b - if "banana" in set1: - y = a + b - else: - y = a - b - if "shoes" in set1: - z = a + b - else: - z = a - b - return x, y, z - - @parametrize("_type", [set]) - def test_set_union(self, _type): - @make_test - def fn(a, b): - set1 = _type({"apple", "banana", "cherry"}) - set2 = _type({"google", "microsoft", "apple"}) - set3 = _type({"shoes", "flipflops", "sneakers"}) - union_set = set1.union(set2, set3) - if "apple" in union_set: - x = a + b - else: - x = a - b - if "banana" in union_set: - y = a + b - else: - y = a - b - if "shoes" in union_set: - z = a + b - else: - z = a - b - return x, y, z - - fn(self) - - @parametrize( - "fn_name", ["add", "symmetric_difference", "symmetric_difference_update"] - ) - def test_set_raise_TypeError(self, fn_name): - @make_test - def fn(a, b): - set1 = {"apple", "banana", "cherry"} - try: - getattr(set1, fn_name)() - except TypeError: - return a + b - return a - b - - fn(self) - - @make_test - def test_set_difference(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - set3 = {"shoes", "flipflops", "sneakers"} - difference_set = set1.difference(set2, set3) - if "apple" in difference_set: - x = a + b - else: - x = a - b - if "banana" in difference_set: - y = a + b - else: - y = a - b - if "shoes" in difference_set: - z = a + b - else: - z = a - b - return x, y, z - - @make_test - def test_set_difference_update(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - set3 = {"shoes", "flipflops", "sneakers"} - set1.difference_update(set2, set3) - if "apple" in set1: - x = a + b - else: - x = a - b - if "banana" in set1: - y = a + b - else: - y = a - b - if "shoes" in set1: - z = a + b - else: - z = a - b - return x, y, z - - @make_test - def test_set_symmetric_difference(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - symmetric_diff_set = set1.difference(set2) - if "apple" in symmetric_diff_set: - x = a + b - else: - x = a - b - if "banana" in symmetric_diff_set: - y = a + b - else: - y = a - b - return x, y - - @make_test - def test_set_symmetric_difference_update(a, b): - set1 = {"apple", "banana", "cherry"} - set2 = {"google", "microsoft", "apple"} - set1.difference(set2) - if "apple" in set1: - x = a + b - else: - x = a - b - if "banana" in set1: - y = a + b - else: - y = a - b - return x, y - def test_set_keys_view(self): from collections.abc import KeysView @@ -1958,23 +1762,6 @@ def fn(x): x = torch.rand(4) self.assertEqual(fn(x), opt_fn(x)) - @parametrize("method", ["add", "__contains__"]) - def test_set_raise_TypeError_on_unshashable_obj(self, method): - @make_test - def fn(a, b): - s = set({1, 2, 3, 4}) - try: - m = getattr(s, method) - m([[]]) - except TypeError: - return a + b - except Exception: - return a - b - else: - return a * b - - fn(self) - def test_constant_set(self): s = set([1, 2]) @@ -2967,6 +2754,26 @@ def fn(x, a, b): opt_fn = torch.compile(fullgraph=True, backend="eager")(fn) self.assertEqual(opt_fn(x, a, b), fn(x, a, b)) + def test_list_setitem(self): + def fn(a: int): + some_array = [1, 2, 3] + some_array[a] = 5 + return torch.ones(some_array) + + opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn) + self.assertEqual(opt_fn(0), fn(0)) + self.assertEqual(opt_fn(1), fn(1)) + + def test_list_setitem_slice(self): + def fn(a: int): + some_array = [1, 2, 3] + some_array[a : a + 1] = [5] + return torch.ones(some_array) + + opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn) + self.assertEqual(opt_fn(0), fn(0)) + self.assertEqual(opt_fn(1), fn(1)) + def test_pow_int(self): def fn(a, b): return torch.pow(a, b) @@ -4083,6 +3890,39 @@ def f(): self.assertTrue(same(res, torch.ones(1))) self.assertTrue(f is f()) + def test_functools_partial_binding(self): + class Foo: + def __init__(self, x): + self.x = x + + @functools.lru_cache # noqa: B019 + def incr(self, val): + self.x += val + + def fn(x): + f = Foo(4) + f.incr(3) + return x + f.x + + x = torch.randn(2) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + + def test_functools_cache_guard(self): + class Foo: + @functools.lru_cache # noqa: B019 + def run(self, val, c=1.0): + return val * c * 2 + + f = Foo() + + def fn(x): + return f.run(x) + + x = torch.randn(2) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + def udf_mul(x, y): return x * y @@ -4376,6 +4216,33 @@ def fn(a, b): fn(self) + @parametrize( + "method_name", + [ + "copy", + "difference", + "intersection", + "symmetric_difference", + "union", + ], + ) + def test_frozenset_return_type(self, method_name): + @make_test + def fn(a, b): + set1 = frozenset({"apple", "banana", "cherry"}) + set2 = frozenset({"google", "microsoft", "apple"}) + if method_name == "copy": + result = set1.copy() + else: + result = getattr(set1, method_name)(set2) + if type(result) is frozenset: + x = a + b + else: + x = a - b + return x + + fn(self) + def test_frozenset_construction(self): def fn(x): s = frozenset({x}) @@ -4842,6 +4709,68 @@ def fn(x, tup): self.assertTrue(ref_tup.checked) self.assertTrue(res_tup.checked) + def test_udf_tuple_construction(self): + class MyTuple(tuple): # noqa: SLOT001 + pass + + def fn(x): + tup = MyTuple([1, 2, 3]) + if 3 in tup: + x = torch.cos(x) + else: + x = torch.sin(x) + return x, tup + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref_x, ref_tup = fn(x) + res_x, res_tup = opt_fn(x) + self.assertEqual(ref_x, res_x) + self.assertEqual(ref_tup, res_tup) + + def test_udf_tuple_construction_custom_new(self): + class MyTuple(tuple): # noqa: SLOT001 + def __new__(cls, *args, **kwargs): + return super().__new__(cls, [1, 2, 3]) + + def fn(x): + tup = MyTuple() + if 3 in tup: + x = torch.cos(x) + else: + x = torch.sin(x) + return x, tup + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref_x, ref_tup = fn(x) + res_x, res_tup = opt_fn(x) + self.assertEqual(ref_x, res_x) + self.assertEqual(ref_tup, res_tup) + + def test_udf_namedtuple(self): + class MyTuple(NamedTuple): + a: torch.Tensor + b: torch.Tensor + + class PairTensor(MyTuple): + def __new__(cls, a, b): + return super().__new__(cls, a, b) + + def __add__(self, other): + return PairTensor(self.a + other.a, self.b + other.b) + + def fn(pair1, pair2): + return pair1 + pair2 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + pair1 = PairTensor(torch.randn(4), torch.randn(2, 8)) + pair2 = PairTensor(torch.randn(1), torch.randn(2, 1)) + ref = fn(pair1, pair2) + res = opt_fn(pair1, pair2) + self.assertEqual(ref.a, res.a) + self.assertEqual(ref.b, res.b) + def test_udf_tuple_reconstruction(self): class MyTuple(tuple): # noqa: SLOT001 pass diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py new file mode 100644 index 00000000000000..2954d2aa296932 --- /dev/null +++ b/test/dynamo/test_fx_graph_runnable.py @@ -0,0 +1,183 @@ +# Owner(s): ["module: dynamo"] +import io +import logging +import subprocess +import sys +import tempfile +import unittest + +import torch +import torch._logging.structured +from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE + + +class FxGraphRunnableArtifactFilter(logging.Filter): + def filter(self, record): + return ( + "artifact" in record.metadata + and record.metadata["artifact"]["name"] == "fx_graph_runnable" + ) + + +class StructuredTracePayloadFormatter(logging.Formatter): + def format(self, record): + return record.payload.strip() + + +trace_log = logging.getLogger("torch.__trace") + + +class ToyModel(torch.nn.Module): + def __init__(self, input_size=10, hidden_size=20, output_size=5): + super().__init__() + self.linear1 = torch.nn.Linear(input_size, hidden_size) + self.linear2 = torch.nn.Linear(hidden_size, output_size) + self.relu = torch.nn.ReLU() + self.dropout = torch.nn.Dropout(0.1) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.dropout(x) + x = self.linear2(x) + return x + + +class FxGraphRunnableTest(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + torch._logging.structured.INTERN_TABLE.clear() + self.old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + + # Create a custom filter specifically for fx_graph_runnable entries + self.filter = FxGraphRunnableArtifactFilter() + + # Create a separate buffer and handler for capturing fx_graph_runnable entries + self.buffer = io.StringIO() + self.handler = logging.StreamHandler(self.buffer) + self.handler.setFormatter(StructuredTracePayloadFormatter()) + self.handler.addFilter(self.filter) + trace_log.addHandler(self.handler) + + def tearDown(self): + trace_log.removeHandler(self.handler) + trace_log.setLevel(self.old_level) + + def _exec_and_verify_payload(self): + # Write captured payload & run it in a fresh Python process + payload = self.buffer.getvalue().strip() + self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing") + self.assertIn("def forward", payload) # sanity-check for actual FX code + + with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp: + tmp.write(payload) + tmp.flush() + res = subprocess.run( + [sys.executable, tmp.name], capture_output=True, text=True, timeout=30 + ) + self.assertEqual( + res.returncode, + 0, + f"Standalone fx_graph_runnable failed:\nSTDERR:\n{res.stderr}", + ) + + # basic tests + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_basic_tensor_add(self): + def f(x): + return x + 1 + + torch.compile(f)(torch.randn(4)) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_two_inputs_matmul(self): + def f(a, b): + return (a @ b).relu() + + a, b = torch.randn(2, 3), torch.randn(3, 4) + torch.compile(f)(a, b) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_scalar_multiply(self): + def f(x): + return x * 2 + + torch.compile(f)(torch.randn(5)) + self._exec_and_verify_payload() + + # testing dynamic shapes + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_dynamic_shapes_run(self): + torch._dynamo.reset() + torch._dynamo.config.dynamic_shapes = True + + def f(x): + return (x @ x.transpose(0, 1)).relu() + + a = torch.randn(10, 12) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(a, 1) + + torch.compile(f)(a) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_broadcast_add_dynamic(self): + torch._dynamo.reset() + torch._dynamo.config.dynamic_shapes = True + + def f(x, y): + return x + y * 2 + + x = torch.randn(5, 1) + y = torch.randn(1, 8) + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(y, 1) + + torch.compile(f)(x, y) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_toy_model_basic(self): + model = ToyModel(input_size=8, hidden_size=16, output_size=4) + model.eval() # Set to eval mode to avoid dropout randomness + + x = torch.randn(3, 8) + torch.compile(model)(x) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_toy_model_batch_processing(self): + model = ToyModel(input_size=12, hidden_size=24, output_size=6) + model.eval() + + x = torch.randn(16, 12) + torch.compile(model)(x) + self._exec_and_verify_payload() + + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") + def test_toy_model_dynamic_batch(self): + torch._dynamo.reset() + torch._dynamo.config.dynamic_shapes = True + + model = ToyModel(input_size=10, hidden_size=20, output_size=5) + model.eval() + + x = torch.randn(7, 10) + torch._dynamo.mark_dynamic(x, 0) + + torch.compile(model)(x) + self._exec_and_verify_payload() + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if not (IS_FBCODE or IS_SANDCASTLE): + # fbcode complains about not being able to find torch in subprocess + run_tests() diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index adf1e5aff0d398..80ec15499147f5 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -22,10 +22,13 @@ def setUp(self): super().setUp() self._old = torch._dynamo.config.enable_faithful_generator_behavior torch._dynamo.config.enable_faithful_generator_behavior = True + self._unittest_old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True def tearDown(self): super().tearDown() torch._dynamo.config.enable_faithful_generator_behavior = self._old + torch._dynamo.config.enable_trace_unittest = self._unittest_old def _compile_check(self, fn, args=None, fullgraph=True): eager = EagerAndRecordGraphs() @@ -886,6 +889,37 @@ def f(x): torch.compile(f, backend="eager", fullgraph=True)(torch.ones(3)), ) + @make_dynamo_test + def test_generator___contains__(self): + def whoo(): + yield 1 + yield 2 + + g = whoo() + self.assertTrue(1 in g) + self.assertTrue(2 in g) + self.assertRaises(StopIteration, next, g) + self.assertFalse(3 in whoo()) + + @make_dynamo_test + def test_generator___contains___side_effects(self): + n = 0 + + def whoo(): + nonlocal n + n = 1 + yield 1 + n = 2 + yield 2 + + g = whoo() + self.assertTrue(1 in g) + self.assertEqual(n, 1) + self.assertTrue(2 in g) + self.assertEqual(n, 2) + self.assertRaises(StopIteration, next, g) + self.assertFalse(3 in whoo()) + class TestGeneratorSend(GeneratorTestsBase): def test_send(self): diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py index a7bd75f48e689c..dfc452020957f8 100644 --- a/test/dynamo/test_graph_region_tracker.py +++ b/test/dynamo/test_graph_region_tracker.py @@ -317,6 +317,17 @@ def fn(x, y, z): """{sin_: OrderedSet([0]), add_: OrderedSet([0])}""", ) + def test_mutation_tracking_setitem(self): + def fn(x): + y = x + 1 + y[0] = 3 + return y + + self.assertExpectedInline( + self.get_mutation_tracking(fn, torch.rand(10, 10)), + """{setitem: OrderedSet([0])}""", + ) + def test_mutation_tracking_allow_in_graph(self): @torch._dynamo.allow_in_graph def fn_mut(x, y): @@ -338,6 +349,27 @@ def fn(x, y): """{o0: OrderedSet([0]), sin_: OrderedSet([0])}""", ) + def test_non_tensor_arg_hashing(self): + def inner(x, w, t): + y = x + x + return torch.conv2d(y, w, None, *t) + + def fn(x, y): + o1 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1)) + o2 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1)) + o3 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1)) + o4 = inner(x, y, ((2, 2), (0, 0), (1, 1), 1)) + return o1.sum() + o2.sum() + o3.sum() + o4.sum() + + self.assertExpectedInline( + self.get_result( + fn, + torch.rand(32, 256, 56, 56), + torch.nn.Parameter(torch.rand(512, 256, 1, 1)), + ), + """[[['y', 'o1'], ['y_1', 'o2'], ['y_2', 'o3']]]""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 1cfd6d77ece047..1aeafaf5dd33cd 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -68,7 +68,8 @@ def less_match_verbose_code_parts(expected): class GuardManagerTests(torch._dynamo.test_case.TestCase): def test_global_state_guard(self): - guard = guards.GLOBAL_STATE(["global_state_check"]) + root = RootGuardManager() + guard = guards.GLOBAL_STATE(root, ["global_state_check"]) self.assertTrue(guard(None)) with set_default_dtype(torch.double): self.assertFalse(guard(None)) @@ -109,7 +110,9 @@ def test_global_state_reason(self): self.assertEqual(guards.reason(), "grad_mode ") def test_python_lambda_leaf_guard(self): + root = RootGuardManager() const_guard = guards.LAMBDA_GUARD( + root, functools.partial(equals_match, expected=5), equals_match_verbose_code_parts(5), ) @@ -118,15 +121,16 @@ def test_python_lambda_leaf_guard(self): self.assertFalse(const_guard("foo")) def test_type_guard(self): + root = RootGuardManager() foo = 4 - guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"]) + guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == int"]) self.assertTrue(guard(5)) self.assertTrue(guard(4)) self.assertFalse(guard("foo")) foo = {"a": 1} - guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"]) + guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == dict"]) self.assertTrue(guard(foo)) self.assertTrue(guard({})) self.assertFalse(guard(5)) @@ -139,30 +143,32 @@ def __init__(self, x, y): foo = Foo(1, 2) - guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"]) + guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == Foo"]) self.assertTrue(guard(foo)) self.assertFalse(guard({})) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) def test_id_guard(self): + root = RootGuardManager() foo = 4 - guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) + guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) foo = {"a": 1} - guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) + guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard({"a": 1})) self.assertFalse(guard({})) self.assertFalse(guard(5)) def test_equals_guard(self): + root = RootGuardManager() foo = 4 - guard = guards.EQUALS_MATCH(foo, ["x == 4"]) + guard = guards.EQUALS_MATCH(root, foo, ["x == 4"]) self.assertTrue(guard(4)) self.assertFalse(guard(5)) @@ -170,7 +176,7 @@ def test_equals_guard(self): # tuple foo = (1, 2, 3) - guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard((1, 2, 3))) self.assertFalse(guard((1, 2, 3, 4))) @@ -178,21 +184,22 @@ def test_equals_guard(self): # list foo = [1, 2, 3] - guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard([1, 2, 3])) self.assertFalse(guard([1, 2, 3, 4])) # type foo = int - guard = guards.EQUALS_MATCH(foo, ["x == foo"]) + guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) self.assertTrue(guard(foo)) self.assertTrue(guard(int)) self.assertFalse(guard(float)) def test_default_device_guard(self): + root = RootGuardManager() foo = 1 - guard = guards.DEFAULT_DEVICE(["cpu device"]) + guard = guards.DEFAULT_DEVICE(root, ["cpu device"]) self.assertTrue(guard(foo)) try: @@ -202,12 +209,15 @@ def test_default_device_guard(self): torch.set_default_device(None) def test_length_check_guard(self): + root = RootGuardManager() foo = [1, 2, 3] - guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"]) + guard = guards.LENGTH_CHECK(root, len(foo), ["len(x) == len(foo)"]) self.assertTrue(guard(foo)) self.assertFalse(guard([])) def test_no_hasattr_guard(self): + root = RootGuardManager() + class Bar: def __init__(self) -> None: self.bar = 2 @@ -220,7 +230,7 @@ def __init__(self) -> None: foo = Foo() - guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"]) + guard = guards.NO_HASATTR(root, "foo", ["hasattr(x, 'foo') == False"]) self.assertTrue(guard(bar)) self.assertFalse(guard(foo)) @@ -258,8 +268,9 @@ def __init__(self, x, y): self.assertFalse(guard_manager.check(f_locals_unaliased)) def test_dict_version_guard(self): + root = RootGuardManager() foo = {"a": 1, "b": 2} - guard = guards.DICT_VERSION(foo, ["x.version == foo.version"]) + guard = guards.DICT_VERSION(root, foo, ["x.version == foo.version"]) self.assertTrue(guard(foo)) self.assertFalse(guard(dict(foo))) @@ -269,8 +280,9 @@ def test_dict_version_guard(self): self.assertFalse(guard({})) def test_dynamic_indices_guard(self): - guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"]) - guard2 = guards.DYNAMIC_INDICES(set({0, 1}), ["x.size(0) == y.size(0)"]) + root = RootGuardManager() + guard1 = guards.DYNAMIC_INDICES(root, set(), ["x.size(0) == y.size(0)"]) + guard2 = guards.DYNAMIC_INDICES(root, set({0, 1}), ["x.size(0) == y.size(0)"]) x = torch.randn(4) self.assertTrue(guard1(x)) @@ -368,18 +380,20 @@ def __init__(self, x, y, z): self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result) def test_weakref_alive_guard(self): + root = RootGuardManager() x = torch.rand(3, 4) weakref_x = weakref.ref(x) - guard = guards.NOT_NONE(["weakref_x is not None"]) + guard = guards.NOT_NONE(root, ["weakref_x is not None"]) self.assertTrue(guard(weakref_x())) del x self.assertFalse(guard(weakref_x())) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_call_function_no_args_guard(self): + root = RootGuardManager() x = torch.cuda.current_device() - guard = guards.EQUALS_MATCH(x, [0]) + guard = guards.EQUALS_MATCH(root, x, [0]) self.assertTrue(guard(0)) self.assertFalse(guard(1)) self.assertFalse(guard(2)) @@ -697,15 +711,16 @@ def fn(x): self.assertTrue("Test" in debug_info.verbose_code_parts[0]) def test_dict_contains_guard(self): + root = RootGuardManager() foo = {"a": 1, "b": 2} - guard = guards.DICT_CONTAINS(True, "a", ["has a"]) + guard = guards.DICT_CONTAINS(root, True, "a", ["has a"]) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) self.assertFalse(guard({})) - guard = guards.DICT_CONTAINS(False, "c", ["not has c"]) + guard = guards.DICT_CONTAINS(root, False, "c", ["not has c"]) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index f421597b5eaa13..8e5f12894711e8 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -254,7 +254,7 @@ def _tracefunc(self, frame, event, arg): self._frame_state = _FrameState( f_locals=dict(frame.f_locals), - f_globals=dict(frame.f_globals), + f_globals=frame.f_globals, f_code=frame.f_code, f_builtins=frame.f_builtins, ) @@ -305,6 +305,9 @@ def transform(instructions: list, code_options: dict[str, object]): nonlocal ref_gm nonlocal loaded_gm + torch._dynamo.convert_frame.initial_global_state = ( + torch._C._dynamo.guards.GlobalStateGuard() + ) tracer = InstructionTranslator( instructions, self._frame_state.f_code, @@ -322,20 +325,32 @@ def transform(instructions: list, code_options: dict[str, object]): speculation_log=SpeculationLog(), exn_vt_stack=ExceptionStack(), distributed_state=None, + package=None, ) - with compile_context(CompileContext(CompileId(0, 0))), tracing( - tracer.output.tracing_context - ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""): + with ( + compile_context(CompileContext(CompileId(0, 0))), + tracing(tracer.output.tracing_context), + tracer.set_current_tx(), + get_metrics_context(), + dynamo_timed(""), + ): tracer.run() + ref_gm = CheckFunctionManager( + self._frame_state.f_code, + tracer.output, + guard_filter_fn=guard_filter_fn, + ).guard_manager + check_fn_manager = CheckFunctionManager( self._frame_state.f_code, tracer.output, guard_filter_fn=guard_filter_fn, guards_serialization_mode="save", ) - ref_gm = check_fn_manager.guard_manager guards_state = check_fn_manager.guards_state + self._cached_guards_state = guards_state + self._cached_f_code = self._frame_state.f_code self.assertIsNotNone(guards_state) guards_state = pickle.loads(guards_state) @@ -344,12 +359,14 @@ def transform(instructions: list, code_options: dict[str, object]): guards_state.output_graph, guards_serialization_mode="load", shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=self._frame_state.f_globals, ) loaded_gm = check_fn_manager.guard_manager try: transform_code_object(self._frame_state.f_code, transform) finally: + torch._dynamo.convert_frame.initial_global_state = None self._frame_state = None self.assertIsNotNone(ref_gm) @@ -1032,10 +1049,10 @@ def fn(x, x_): return x + x_ x = torch.randn(3, 2) - with self.assertRaisesRegex( - PackageError, "DUPLICATE_INPUT guard cannot be serialized" - ): - self._test_serialization("DUPLICATE_INPUT", fn, x, x) + ref, loaded = self._test_serialization("DUPLICATE_INPUT", fn, x, x) + + self._test_check_fn(ref, loaded, {"x": x, "x_": x}, True) + self._test_check_fn(ref, loaded, {"x": x, "x_": torch.randn(3, 2)}, False) def test_weakref_alive(self): mod = torch.nn.Linear(10, 10, bias=False) @@ -1133,6 +1150,25 @@ def fn(x): with torch.enable_grad(): self._test_check_fn(ref, loaded, {"x": x}, True) + def test_grad_mode_loading(self): + def fn(x): + return x + 1 + + x = torch.randn(3, 2) + with torch.enable_grad(): + ref, _ = self._test_serialization("GRAD_MODE", fn, x) + with torch.no_grad(): + # Ensure guards state loading is not affected by the current global grad mode. + guards_state = pickle.loads(self._cached_guards_state) + check_fn_manager = CheckFunctionManager( + self._cached_f_code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + ) + loaded = check_fn_manager.guard_manager + self._test_check_fn(ref, loaded, {"x": x}, False) + def test_deterministic_algorithms(self): def fn(x): return x + 1 @@ -1248,6 +1284,30 @@ def fn(x): self._test_check_fn(ref, loaded, {"x": torch.randn(3, 11, 2)}, False) self._test_check_fn(ref, loaded, {"x": torch.randn(3, 2, 2)}, False) + def test_builtin_match(self): + def fn(x): + # usage of getattr() here installs a BUILTIN_MATCH guard + s = getattr(x, "shape") # noqa: B009 + return x + s[0] + + x = torch.randn(3) + + ref, loaded = self._test_serialization("BUILTIN_MATCH", fn, x) + self._test_check_fn(ref, loaded, {"x": x}, True) + getattr_original = getattr + + def getattr_new(*args, **kwargs): + return getattr_original(*args, **kwargs) + + builtins_dict = ( + __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__ + ) + builtins_dict["getattr"] = getattr_new + try: + self._test_check_fn(ref, loaded, {"x": x}, False) + finally: + builtins_dict["getattr"] = getattr_original + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 6aaf59582de96e..42d84c8b79e319 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -296,6 +296,58 @@ def f(x): arg_count = ifdynstaticdefault(3, 4) self._test_wrap_simple(f, default_args_generator((x,)), arg_count) + def test_allow_python_side_effects_utility(self): + from torch._dynamo.utils import ( + _disable_side_effect_safety_checks_for_current_subtracer, + ) + from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper + + def wrapper(fn): + return fn + + count = 0 + + def does_side_effect(x): + nonlocal count + count += 1 + return x.sin() + + def does_side_effect_wrapped(*args, **kwargs): + return _disable_side_effect_safety_checks_for_current_subtracer( + does_side_effect, *args, **kwargs + ) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + return dynamo_bypassing_wrapper(wrapper, does_side_effect_wrapped, x) + + x = torch.tensor(1.0) + fn(x) + + def inner_does_side_effect(x): + nonlocal count + count += 1 + return x + + # Test that any nested HOPs are unaffected + def outer(x): + return dynamo_bypassing_wrapper(wrapper, inner_does_side_effect, x) + + def outer_wrapped(*args, **kwargs): + return _disable_side_effect_safety_checks_for_current_subtracer( + outer, *args, **kwargs + ) + + @torch.compile(backend="eager", fullgraph=True) + def fn_nested(x): + return dynamo_bypassing_wrapper(wrapper, outer_wrapped, x) + + x = torch.tensor(1.0) + with self.assertRaisesRegex( + RuntimeError, "Mutating a variable not in the current scope" + ): + fn_nested(x) + def test_symint_input(self): def f(x): i = x.size(0) @@ -2081,7 +2133,7 @@ def false_fn(x): and node.target == torch.ops.higher_order.cond ): _, _, _, operands = node.args - # Since we compile wit dynamic, each branch takes 4 inputs (buffer, x, z, s1) + # Since we compile with dynamic, each branch takes 4 inputs (buffer, x, z, s1) self.assertEqual(len(operands), 4) if node.op == "get_attr": if str(node.target) in ("cond_true_0, cond_false_0"): diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 3908178f18db1a..3f3a3bd7f65378 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -746,7 +746,7 @@ def test_fn(fn): if cnts: self.assertEqual(cnts.frame_count, 1) # These same exact assertions run on both eager and compiled - # X goes to x*2 becaue of mul_ + # X goes to x*2 because of mul_ self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2) # This test proves grad aliasing works - self.assertEqual(x.grad, b * 5) diff --git a/test/dynamo/test_inline_and_install.py b/test/dynamo/test_inline_and_install.py index 6a94b9f9ea85e7..b38b96ccc3e9ea 100644 --- a/test/dynamo/test_inline_and_install.py +++ b/test/dynamo/test_inline_and_install.py @@ -57,7 +57,7 @@ def make_dynamic_cls(cls): ) -# These tests do string comparisson on the graphs, and since buffers are now inlined, they +# These tests do string comparison on the graphs, and since buffers are now inlined, they # are named different, resulting in failure unittest.expectedFailure( InlineAndInstallExportTests.test_param_buffer_safe_from_mutation_simple_inline_and_install # noqa: F821 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 4e7d19ca259ae9..2b120349ea01aa 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -40,33 +40,18 @@ def munge_shape_guards(s: str) -> str: - SHAPE_GUARD = ( - "SYMBOLIC_SHAPE_GUARD" - if torch._dynamo.config.enable_cpp_symbolic_shape_guards - else "LAMBDA_GUARD" - ) SHAPE_GUARD_REGEX = ( - r"[| ]* \+- SYMBOLIC_SHAPE_GUARD" + r"[| ]* \+- SYMBOLIC_SHAPE_GUARD:" if torch._dynamo.config.enable_cpp_symbolic_shape_guards - else r"\+- LAMBDA_GUARD" + else r"^\+- LAMBDA_GUARD:" ) def munge(s): - return re.sub( - SHAPE_GUARD_REGEX, - "+- __SHAPE_GUARD__", - re.sub(r"[^ ]+:\d+ in [^ ]+", "#:# in #", s), - ) - - lines = [munge(l) for l in s.splitlines() if SHAPE_GUARD in l] + s = re.sub(r"[^ ]+:\d+ in [^ ]+", "#:# in #", s) + return re.subn(SHAPE_GUARD_REGEX, "+- __SHAPE_GUARD__:", s) - if torch._dynamo.config.enable_cpp_symbolic_shape_guards: - # Since we can have multiple guard accessors for one guard, the shape guard - # printing will have just SYMBOLIC_SHAPE_GUARD in one line for the second - # guard accessor and onwards. We remove those lines - lines = [line for line in lines if "__SHAPE_GUARD__:" in line] - - return "\n".join(lines) + lines = [munge(l) for l in s.splitlines()] + return "\n".join([line for line, nsubs in lines if nsubs > 0]) def example_fn(a): @@ -169,6 +154,22 @@ def test_dynamo_debug_default_off_artifacts(self, records): self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0) self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0) + @make_logging_test(hierarchical_compile=True) + def test_hierarchical_compile(self, records): + from torch._higher_order_ops.invoke_subgraph import mark_compile_region + + @mark_compile_region + def gn(x): + return x * 2 + + def fn(x): + return gn(x) + + fn_opt = torch.compile(fn, backend="inductor") + fn_opt(torch.ones(1000, 1000)) + fn_opt(torch.ones(1000, 1000)) + self.assertGreater(len(records), 0) + @make_logging_test() def test_dynamo_error(self, records): try: @@ -710,7 +711,6 @@ def f(x, y, z): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ -| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in # +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # +- __SHAPE_GUARD__: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) +- __SHAPE_GUARD__: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # @@ -729,7 +729,6 @@ def f(x, y): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ -| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in # +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # #:# in # +- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ) @@ -749,7 +748,6 @@ def f(x, y): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ -| +- __SHAPE_GUARD__: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # #:# in # +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # torch._check(x.size(0) == y.size(0) * 2) # #:# in # #:# in # +- __SHAPE_GUARD__: 3 <= L['y'].size()[0] <= 14 # torch._check(x.size(0) > 5) # #:# in # #:# in # and torch._check(x.size(0) < 30) # #:# in # #:# in #""", # noqa: B950 ) @@ -782,7 +780,7 @@ def test_optimizer_non_static_param(self, records): @requires_cuda @unittest.skipIf(not SM90OrLater, "requires H100+ GPU") def test_autotuning(self, records): - with torch._inductor.utils.fresh_inductor_cache(): + with torch._inductor.utils.fresh_cache(): def f(a, b): return torch.mm(a, b) @@ -957,9 +955,10 @@ def bar(): "cudagraph_static_inputs", "benchmarking", "loop_ordering", + "loop_tiling", "autotuning", "graph_region_expansion", - "codecache", + "hierarchical_compile", } for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/test/dynamo/test_metrics_context.py b/test/dynamo/test_metrics_context.py index 01016eea471548..3a8657003cd19b 100644 --- a/test/dynamo/test_metrics_context.py +++ b/test/dynamo/test_metrics_context.py @@ -64,7 +64,7 @@ def test_set_disallow_overwrite(self): def test_update_disallow_overwrite(self): """ - Validate update won't overwite. + Validate update won't overwrite. """ with MetricsContext(self._on_exit) as context: context.update({"m1": 1, "m2": 2}) @@ -73,7 +73,7 @@ def test_update_disallow_overwrite(self): def test_update_allow_overwrite(self): """ - Validate update will overwite when given param. + Validate update will overwrite when given param. """ with MetricsContext(self._on_exit) as context: context.update({"m1": 1, "m2": 2}) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index ec3aa2c6067214..9ea0f287edc50d 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -3,14 +3,17 @@ import torch._dynamo from torch._dynamo.test_minifier_common import MinifierTestBase +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import skipIfNNModuleInlined -requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") +requires_gpu = unittest.skipUnless( + torch.cuda.is_available() or torch.xpu.is_available(), "requires cuda or xpu" +) class MinifierTests(MinifierTestBase): - # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) + # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA/XPU) def _test_after_dynamo(self, device, backend, expected_error): run_code = f"""\ @torch.compile(backend={backend!r}) @@ -41,22 +44,22 @@ def test_after_dynamo_cpu_accuracy_error(self): "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError" ) - @requires_cuda - def test_after_dynamo_cuda_compile_error(self): + @requires_gpu + def test_after_dynamo_cuda_compile_error(self, device): self._test_after_dynamo( - "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError" + device, "relu_compile_error_TESTING_ONLY", "ReluCompileError" ) - @requires_cuda - def test_after_dynamo_cuda_runtime_error(self): + @requires_gpu + def test_after_dynamo_cuda_runtime_error(self, device): self._test_after_dynamo( - "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError" + device, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError" ) - @requires_cuda - def test_after_dynamo_cuda_accuracy_error(self): + @requires_gpu + def test_after_dynamo_cuda_accuracy_error(self, device): self._test_after_dynamo( - "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError" + device, "relu_accuracy_error_TESTING_ONLY", "AccuracyError" ) def test_after_dynamo_non_leaf_compile_error(self): @@ -94,38 +97,38 @@ def test_after_dynamo_cpu_accuracy_backend_passes(self): "cpu", "relu_accuracy_error_TESTING_ONLY" ) - @requires_cuda - def test_after_dynamo_cuda_compile_backend_passes(self): + @requires_gpu + def test_after_dynamo_cuda_compile_backend_passes(self, device): self._test_after_dynamo_backend_passes( - "cuda", "relu_compile_error_TESTING_ONLY" + device, "relu_compile_error_TESTING_ONLY" ) - @requires_cuda - def test_after_dynamo_cuda_runtime_backend_passes(self): + @requires_gpu + def test_after_dynamo_cuda_runtime_backend_passes(self, device): self._test_after_dynamo_backend_passes( - "cuda", "relu_runtime_error_TESTING_ONLY" + device, "relu_runtime_error_TESTING_ONLY" ) - @requires_cuda - def test_after_dynamo_cuda_accuracy_backend_passes(self): + @requires_gpu + def test_after_dynamo_cuda_accuracy_backend_passes(self, device): self._test_after_dynamo_backend_passes( - "cuda", "relu_accuracy_error_TESTING_ONLY" + device, "relu_accuracy_error_TESTING_ONLY" ) - # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + # Test that a module with mixed cpu/(cuda|xpu) parts with an error after dynamo can be repro'd @skipIfNNModuleInlined() - @requires_cuda - def test_cpu_cuda_module_after_dynamo(self): + @requires_gpu + def test_cpu_cuda_module_after_dynamo(self, device): backend_name = "relu_compile_error_TESTING_ONLY" run_code = f"""\ class CpuCudaModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.m_x = torch.nn.Linear(20, 20).cuda() + self.m_x = torch.nn.Linear(20, 20).to(device) self.m_y = torch.nn.Linear(20, 20) - self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) + self.p_x = torch.nn.Parameter(torch.randn(20, 20).to(device)) self.p_y = torch.nn.Parameter(torch.randn(20, 20)) - self.b_x = torch.nn.Buffer(torch.ones(20, 20).cuda()) + self.b_x = torch.nn.Buffer(torch.ones(20, 20).to(device)) self.b_y = torch.nn.Buffer(torch.ones(20, 20)) def forward(self, x, y): @@ -135,12 +138,12 @@ def forward(self, x, y): @torch.compile(backend={backend_name!r}) def inner(x1, y1): - x2 = torch.randn(20, 20).cuda() + x2 = torch.randn(20, 20).to(device) y2 = torch.randn(20, 20) x3, y3 = mod(x1 + x2, y1 + y2) return torch.relu(x3.cpu() + y3) -inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) +inner(torch.randn(20, 20).to(device), torch.randn(20, 20)) """ res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False) @@ -151,18 +154,18 @@ def inner(x1, y1): class Repro(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda() + self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).to(device) self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True) - self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda()) + self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).to(device)) self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32)) - self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda")) + self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device=device)) self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32)) def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor): l_x1_ = L_x1_ l_y1_ = L_y1_ randn = torch.randn(20, 20) - x2 = randn.cuda(); randn = None + x2 = randn.to(device); randn = None y2 = torch.randn(20, 20) add = l_x1_ + x2; l_x1_ = x2 = None add_1 = l_y1_ + y2; l_y1_ = y2 = None @@ -213,6 +216,11 @@ def forward(self, x_19): ) +devices = ["cuda", "xpu", "cpu"] +instantiate_device_type_tests( + MinifierTests, globals(), only_for=devices, allow_xpu=True +) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 2179abf2582eab..49896ae6ae2c56 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -52,9 +52,9 @@ skipIfNotPy311, unsupported, ) -from torch._dynamo.utils import counters, ifdynstaticdefault +from torch._dynamo.utils import call_size, counters, ifdynstaticdefault from torch._dynamo.variables import builder -from torch._inductor.utils import fresh_inductor_cache, run_and_get_code +from torch._inductor.utils import fresh_cache, run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize from torch.ao.quantization.qconfig import QConfig @@ -90,6 +90,7 @@ skipIfNNModuleInlined, skipIfWindows, TEST_HPU, + TEST_XPU, wrapDeterministicFlagAPITest, ) from torch.testing._internal.jit_utils import JitTestCase @@ -579,6 +580,16 @@ def fn(x): self.assertEqual(obj.y, x + 1) self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"}) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_unbacked_repeat_cat(self): + def f(x, n): + m = x.item() + x = torch.empty(x).repeat(n) # s0*u0 + return torch.cat([x, x], dim=0) + + fn = torch.compile(f, backend="eager", dynamic=True, fullgraph=True) + fn(torch.tensor([5]), 5) + def test_tensor_setattr_getset_descriptor(self): # Tensor attribute `real` has special getter/setter for complex dtype. def f(x): @@ -1082,9 +1093,7 @@ def fn(x, y): not ___dict_contains('cccccccc', G['sys'].modules) str(L['x'].device) == 'cpu' str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split( - "\n" - ): +utils_device.CURRENT_DEVICE == None""".split("\n"): self.assertIn( line, guard_code_str, @@ -2003,6 +2012,30 @@ def fn(g, x): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): res = opt_fn(g, torch.ones(2, 2)) + def test_set_descriptor(self): + class Field: + def __set__(self, obj, value): + obj.__dict__["field"] += value * 2 + + class Foo: + field = Field() + + def __init__(self): + self.__dict__["field"] = 0 + + def fn(x, foo): + foo.field = 10 + return x + foo.field + + opt_fn = torch.compile(fn, fullgraph=True, backend="eager") + x = torch.zeros(2) + foo1, foo2 = Foo(), Foo() + + ref = fn(x, foo1) + res = opt_fn(x, foo2) + self.assertEqual(ref, res) + self.assertEqual(foo1.field, foo2.field) + def test_get_attr_function(self): def fn(g, x): return g(x) @@ -2772,7 +2805,7 @@ def test_dtypes_no_graphbreaks(self): "int", np.intp, np.int32, - np.uint8 + np.uint8, # np.dtype('int') # XXX: as above ] @@ -3242,7 +3275,7 @@ def fn(m, x): def test_global_state_guard_serialization(self): GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard guards = GlobalStateGuard() - serialized_guards = guards.dump() + serialized_guards = guards.__getstate__() json_guards = json.loads(serialized_guards) samples = [] @@ -3264,17 +3297,17 @@ def test_global_state_guard_serialization(self): samples.append(new_dict) for sample in samples: - guards.load(json.dumps(sample)) + guards.__setstate__(json.dumps(sample)) self.assertFalse(guards.check()) - guards.load(json.dumps(json_guards)) + guards.__setstate__(json.dumps(json_guards)) self.assertTrue(guards.check()) # Test on autocast states. def _test_autocast(dtype): with torch.autocast("cpu", dtype): guards = GlobalStateGuard() - serialized_guards = guards.dump() + serialized_guards = guards.__getstate__() json_guards = json.loads(serialized_guards) for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]): @@ -3283,7 +3316,7 @@ def _test_autocast(dtype): type(json_guards["autocast_state"]["dtype"][i]), int ) json_guards["autocast_state"]["dtype"][i] += 1 - guards.load(json.dumps(json_guards)) + guards.__setstate__(json.dumps(json_guards)) self.assertFalse(guards.check()) _test_autocast(torch.float16) @@ -4019,7 +4052,7 @@ def test_write_to_cells_with_name_shadowing(self): y = x def make_x_get_set(): - # NOTE: this `x` is a different cell object than the outter `x`. + # NOTE: this `x` is a different cell object than the outer `x`. x = y def set_x(v): @@ -4811,7 +4844,7 @@ def fn(x, y): self.assertEqual(cnts.frame_count, 2) def test_id_guarded_object(self): - class UDO: + class UserDefinedObject: @torch.compile(backend="eager") def call(self, x, ref_id): self_id = id(self) @@ -4824,11 +4857,11 @@ def call(self, x, ref_id): # Make sure we do recompile when id(self) is executed on # different self objects. x = torch.ones(2) - obj1 = UDO() + obj1 = UserDefinedObject() obj1_id = id(obj1) self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) - obj2 = UDO() + obj2 = UserDefinedObject() # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) @@ -5493,9 +5526,9 @@ def __init__(self) -> None: def forward(self, idx, targets=None): b, t = idx.size() - assert ( - t <= self.block_size - ), "Cannot forward, model block size is exhausted." + assert t <= self.block_size, ( + "Cannot forward, model block size is exhausted." + ) # forward the GPT model token_embeddings = self.tok_emb( @@ -6041,15 +6074,17 @@ def g2(a, b): def count_graph_break_msgs(msgs): return sum("Graph break in user code" in msg for msg in msgs) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=True): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=True), + ): f1(torch.randn(10), torch.randn(10)) self.assertGreater(count_graph_break_msgs(log.output), 1) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=False): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=False), + ): g1(torch.randn(10), torch.randn(10)) self.assertEqual(count_graph_break_msgs(log.output), 1) @@ -6870,7 +6905,7 @@ def guard_failures(failure): self.assertTrue(guard_failure is not None) self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0]) - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.") def test_symint_as_device_kwarg_non_strict_export(self): class Mod(torch.nn.Module): def forward(self, x): @@ -7031,6 +7066,184 @@ def test_torch_package_working_with_trace(self): optimized_loaded_model = torch.compile(loaded_model, backend="eager")(*inputs) + def test_precompile_entry_hit(self): + from torch._C._dynamo.eval_frame import ( + _load_precompile_entry, + _reset_precompile_entries, + ) + + def fn(x): + return x + 1 + + def injected(x): + return x + 42 + + args = (torch.randn(3, 2),) + + compiled_fn = torch.compile(fn) + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(), + injected.__code__, + ) + self.assertEqual(compiled_fn(*args), injected(*args)) + _reset_precompile_entries(fn.__code__) + + self.assertEqual(compiled_fn(*args), fn(*args)) + + def test_precompile_entry_miss(self): + from torch._C._dynamo.eval_frame import _load_precompile_entry + + def fn(x): + return x + 1 + + guard_manager = torch._dynamo.guards.RootGuardManager() + guard_manager.add_lambda_guard(lambda L: isinstance(L["x"], int), []) + + def injected(x): + return x + 42 + + args = (torch.randn(3, 2),) + + compiled_fn = torch.compile(fn) + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager), + injected.__code__, + ) + self.assertEqual(compiled_fn(*args), fn(*args)) + + def test_precompile_entries(self): + from torch._C._dynamo.eval_frame import ( + _load_precompile_entry, + _reset_precompile_entries, + ) + + def fn(x): + return x + 1 + + guard_manager_bool = torch._dynamo.guards.RootGuardManager() + guard_manager_bool.add_lambda_guard(lambda L: isinstance(L["x"], bool), []) + + def injected_bool(x: bool): + return x + 102 + + guard_manager_int = torch._dynamo.guards.RootGuardManager() + guard_manager_int.add_lambda_guard(lambda L: isinstance(L["x"], int), []) + + def injected_int(x: int): + return x + 42 + + guard_manager_tensor = torch._dynamo.guards.RootGuardManager() + guard_manager_tensor.add_lambda_guard( + lambda L: isinstance(L["x"], torch.Tensor), [] + ) + + def injected_tensor(x: torch.Tensor): + return x + 100 + + guard_manager_str = torch._dynamo.guards.RootGuardManager() + guard_manager_str.add_lambda_guard(lambda L: isinstance(L["x"], str), []) + + def injected_str(x: str): + return x + "1" + + args = (torch.randn(3, 2),) + + compiled_fn = torch.compile(fn) + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager_bool), + injected_bool.__code__, + ) + + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager_int), + injected_int.__code__, + ) + + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager_tensor), + injected_tensor.__code__, + ) + + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager_str), + injected_str.__code__, + ) + + self.assertEqual(compiled_fn(*args), injected_tensor(*args)) + self.assertEqual(compiled_fn(True), injected_bool(True)) + self.assertEqual(compiled_fn(10), injected_int(10)) + self.assertEqual(compiled_fn("10"), injected_str("10")) + _reset_precompile_entries(fn.__code__) + + self.assertEqual(compiled_fn(*args), fn(*args)) + + def test_precompile_fail_on_recompile(self): + from torch._C._dynamo.eval_frame import _load_precompile_entry + + @torch.compiler.disable + def graph(x, s0): + return x + s0 + + def fn(x): + nonlocal graph # Forcing fn and injected to have the same closure. + return x - 1 + + def injected(x): + s0 = call_size(x, 0) + return graph(x, s0) + + args = (torch.randn(3, 2),) + + compiled_fn = torch.compile(fn) + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(), + injected.__code__, + ) + with torch.compiler.set_stance("fail_on_recompile"): + self.assertEqual(compiled_fn(*args), injected(*args)) + + def test_fail_on_recompile_error_message(self): + from torch._C._dynamo.eval_frame import ( + _load_precompile_entry, + _reset_precompile_entries, + ) + + def fn(x): + return x + 1 + + guard_manager_bool = torch._dynamo.guards.RootGuardManager() + guard_manager_bool.add_lambda_guard( + lambda L: isinstance(L["x"], bool), ["isinstance(L['x'], bool)"] + ) + + def injected_bool(x: bool): + return x + 102 + + args = (torch.randn(3, 2),) + + compiled_fn = torch.compile(fn) + _load_precompile_entry( + fn.__code__, + torch._dynamo.guards.GuardManagerWrapper(guard_manager_bool), + injected_bool.__code__, + ) + + try: + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, "Failed on the following precompiled guards:" + ): + compiled_fn(*args) + finally: + _reset_precompile_entries(fn.__code__) + def test_shape_and_tuple_equality(self): def fn(x, y, t): z = x * y @@ -7646,6 +7859,29 @@ def fn(x): self.assertEqual(fn(torch.tensor([4])).size(0), 1) self.assertEqual(fn(torch.tensor([1])).size(0), 0) + def test_sym_and_terms(self): + from torch.fx.experimental.symbolic_shapes import sym_and + + @torch.compile(fullgraph=True, dynamic=True, backend="eager") + def fn(xs): + u0, u1 = xs.tolist() + torch._check(sym_and(u0 >= 3, u0 <= 10, u1 >= 2)) + + # test individual checks + n = 0 + if u0 >= 3: + n += 1 + if u0 <= 11: + n += 1 + if u1 >= 1: + n += 1 + return u0 + u1 + n + + fn(torch.tensor([5, 6])) + fn(torch.tensor([8, 7])) + with self.assertRaises(RuntimeError): + fn(torch.tensor([9, 0])) + def test_unbacked_2d_expand(self): @torch.compile(fullgraph=True, dynamic=True, backend="inductor") def func(a, b): @@ -7887,7 +8123,7 @@ def forward(self, a): m1 = Model(50) m2 = Model(60) - with fresh_inductor_cache(): + with fresh_cache(): m1(torch.rand(1, 2, 3)) m2(torch.rand(1, 2, 3)) @@ -8035,8 +8271,9 @@ def h(a): def f(a): return h(a) - with warnings.catch_warnings(record=True) as w, self.assertRaises( - torch._dynamo.exc.BackendCompilerFailed + with ( + warnings.catch_warnings(record=True) as w, + self.assertRaises(torch._dynamo.exc.BackendCompilerFailed), ): f(torch.randn(2, 2, requires_grad=True)) @@ -8229,8 +8466,7 @@ def run_fn(): def test_torch_compile_ctx_on_forward_and_training_step(self): class MyModel(torch.nn.Module): - def forward(self): - ... + def forward(self): ... def training_step(self): self() @@ -8462,7 +8698,7 @@ def test_guards_cse_pass_single(self): ), testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"), testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"), - # The whole expressiong gets CSE-d, as well as all of its sub-expressions. + # The whole expression gets CSE-d, as well as all of its sub-expressions. testcase( expr="self.g(a, b).k", preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], @@ -9085,31 +9321,6 @@ def foo(x, y): self.assertEqual(counter.frame_count, 1) self.assertEqual(result, eager_result) - def test_input_set_graph_break(self): - def foo(x): - return x.pop() * x.pop() - - x = torch.randn(10, 10) - y = torch.randn(10, 10) - - counter = CompileCounter() - - inp = {x, x, x, x, y, y} - foo = torch.compile(foo, backend=counter, fullgraph=True) - - # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. - # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) - # and so the guard story for the objects passed into input just isn't there atm. - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Unsupported method call", - ): - foo(inp) - - foo = torch.compile(foo, backend=counter, fullgraph=False) - foo(inp) - self.assertEqual(counter.frame_count, 1) - def test_reconstruct_set_across_graph_break(self): def foo(x, y): setty = set() @@ -10282,11 +10493,11 @@ def fn(x, y): self.assertEqual(actual, expected) def test_pytree_tree_leaves(self): - implemtations = [("python", python_pytree)] + implementations = [("python", python_pytree)] if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) + implementations.append(("cxx", cxx_pytree)) - for name, module in implemtations: + for name, module in implementations: with self.subTest(f"pytree implement: {name}"): def fn(x): @@ -10316,11 +10527,11 @@ def fn(x): self.assertEqual(actual, expected) def test_pytree_tree_flatten_unflatten(self): - implemtations = [("python", python_pytree)] + implementations = [("python", python_pytree)] if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) + implementations.append(("cxx", cxx_pytree)) - for name, module in implemtations: + for name, module in implementations: with self.subTest(f"pytree implement: {name}"): def fn(x, y): @@ -10367,11 +10578,11 @@ def fn(x, y): self.assertEqual(actual, expected) def test_pytree_tree_map(self): - implemtations = [("python", python_pytree)] + implementations = [("python", python_pytree)] if cxx_pytree is not None: - implemtations.append(("cxx", cxx_pytree)) + implementations.append(("cxx", cxx_pytree)) - for name, module in implemtations: + for name, module in implementations: with self.subTest(f"pytree implement: {name}"): def fn(x, y): @@ -11520,7 +11731,7 @@ def fn(x, y): # Ensure that the generated graph returns only one output. We want the # add_ on the grad to be part of the graph itself, so that inductor can - # theoretically move the add_ and resutling copy_ nodes at the right + # theoretically move the add_ and resulting copy_ nodes at the right # place to free memory. self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1) self.assertEqual(z, ref_y) @@ -11682,6 +11893,30 @@ def fn(x): res = fn(x) self.assertEqual(ref, res) + def test_descriptor_side_effect(self): + # This pattern (readonly descriptor but writable value in `__dict__`) is + # from scipy `_make_tuple_bunch`: + # https://github.com/scipy/scipy/blob/maintenance/1.9.x/scipy/_lib/_bunch.py#L32-L226 + def fget(obj): + return obj.__dict__["field"] + + class MyClass: + def __init__(self, n): + self.__dict__["field"] = n + + field = property(fget) + + def fn(x): + obj = MyClass(42) + return x + obj.field, obj + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref_t, ref_obj = fn(x) + res_t, res_obj = opt_fn(x) + self.assertEqual(ref_t, res_t) + self.assertEqual(ref_obj.field, res_obj.field) + def test_assert_size_stride(self): x = torch.randn(2, 3, 4) with self.assertRaisesRegex( @@ -12058,7 +12293,7 @@ def __init__(self, x): self.ne_called = False def __ne__(self, other): - # ne_called attr is later checked to ensure that overrideen + # ne_called attr is later checked to ensure that overridden # `__ne__` is traced self.ne_called = True return not self.__eq__(other) @@ -12238,6 +12473,147 @@ def fn(x, y): with torch.compiler.set_stance("fail_on_recompile"): self.assertEqual(fn(*inputs), inputs[0]) + def test_guard_filter_inbuilt_nn_modules(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.LayerNorm(8) + + def forward(self, x): + return self.norm(x) + + mod = Mod() + opt_mod = torch.compile( + mod, + options={ + "guard_filter_fn": torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe + }, + ) + + x = torch.rand(4, 8) + opt_mod(x) + + mod.norm.eps = 1e-02 + # Since the guards are skipped on inbuilt nn modules, we should not recompile + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + opt_mod(x) + + def test_guard_filter_nn_modules(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 2 + self.norm = torch.nn.LayerNorm(8) + + def forward(self, x): + return self.norm(x) + self.c + + mod = Mod() + opt_mod = torch.compile( + mod, + options={ + "guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe + }, + ) + + x = torch.rand(4, 8) + opt_mod(x) + + mod.c = 3 + # Since the guards are skipped on all nn modules, we should not recompile + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + opt_mod(x) + + def test_guard_filter_tensors(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 2.0 + self.norm = torch.nn.LayerNorm(8) + + def forward(self, x): + return self.norm(x) + self.c + + mod = Mod() + opt_mod = torch.compile( + mod, + options={ + "guard_filter_fn": torch.compiler.keep_tensor_guards_unsafe, + }, + ) + + x = torch.rand(4, 8) + opt_mod(x) + + mod.c = 3.0 + # Since the guards are skipped on all tensors + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + opt_mod(x) + + def test_guard_filter_globals(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 2 + self.norm = torch.nn.LayerNorm(8) + + def forward(self, x): + return self.norm(x) + self.c + GLOBAL_INT + + mod = Mod() + opt_mod = torch.compile( + mod, + options={ + "guard_filter_fn": torch.compiler.skip_guard_on_globals_unsafe, + }, + ) + + global GLOBAL_INT + GLOBAL_INT = 1 + x = torch.rand(4, 8) + opt_mod(x) + + GLOBAL_INT = 2 + # Since the guards are skipped on globals, we should not recompile + with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): + opt_mod(x) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_builtin_bool_on_symint(self): + def f(x): + return bool(x.item()) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randint(10, (1,)) + + ref = f(x) + res = opt_f(x) + self.assertEqual(ref, res) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_builtin_bool_on_symfloat(self): + def f(x): + return bool(x.item()) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(1) + + ref = f(x) + res = opt_f(x) + self.assertEqual(ref, res) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_builtin_bool_on_symbool(self): + def f(x): + return bool(x.item()) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(1) == 1 + + ref = f(x) + res = opt_f(x) + self.assertEqual(ref, res) + class TestTracer(JitTestCase): def test_jit_save(self): @@ -12406,7 +12782,7 @@ def forward(self, query, key, value): def test_torch_device_is_available(self, device): def fn(x): - if TEST_HPU or TEST_CUDA: + if torch.accelerator.is_available(): return x + 1 else: return x - 1 @@ -12458,6 +12834,30 @@ def fn2(x): opt_fn2 = torch.compile(fn2, backend="eager", fullgraph=True) res = opt_fn2(x2) + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @torch._dynamo.config.patch(recompile_limit=999) + def test_legacy_cuda_tensor(self): + typs = [ + torch.cuda.FloatTensor, + torch.cuda.DoubleTensor, + torch.cuda.HalfTensor, + torch.cuda.BFloat16Tensor, + torch.cuda.ByteTensor, + torch.cuda.CharTensor, + torch.cuda.IntTensor, + torch.cuda.ShortTensor, + torch.cuda.LongTensor, + ] + + def f2(typ): + return typ([1, 2, 3]) + + compiled_f2 = torch.compile(f2, backend="eager", fullgraph=True) + for typ in typs: + output = compiled_f2(typ) + expected = f2(typ) + self.assertEqual(output, expected) + def test_get_device(self, device): def fn(x, y): x = x + 1 @@ -12485,27 +12885,23 @@ def f(rank): def test_cuda_set_device(self, device): def fn(): a = torch.ones(2, device=device) - torch.cuda.set_device(1) + torch.get_device_module(device).set_device(1) return a + 1 - with torch.cuda.device(0): + with torch.get_device_module(device).device(0): counter = CompileCounter() opt_fn = torch.compile(fn, backend=counter) res = opt_fn() - self.assertEqual(res.device.type, "cuda") + self.assertEqual(res.device.type, device) self.assertEqual(res.device.index, 0) self.assertEqual(counter.frame_count, 2) - def test_torch_device_python_type(self): + def test_torch_device_python_type(self, device): + device_type = torch.device(device).type for device, device_type, index in [ ("cpu", "cpu", None), - ("cuda:0", "cuda", 0), - ("hpu:0", "hpu", 0), + (device, device_type, 0), ]: - if (device == "cuda:0" and not TEST_CUDA) or ( - device == "hpu:0" and not TEST_HPU - ): - continue def fn(target): target_device = target.device @@ -12555,9 +12951,22 @@ def f(): res = opt_f() self.assertEqual(ref, res) + def test_randint_no_graphbreak(self): + @torch.compile(backend="aot_eager", fullgraph=True) + def f(actions, n_act, epsilon=0.1): + actions_random = torch.randint_like(actions, n_act) + + return actions_random -devices = ("cuda", "hpu") -instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices) + x = torch.ones([1], dtype=torch.int64) + y = torch.tensor(5) + f(x, y) + + +devices = ("cuda", "hpu", "xpu") +instantiate_device_type_tests( + MiscTestsDevice, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index 8e91d1bb0644bb..d2833e1a7195a2 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -7,7 +7,7 @@ import torch._dynamo.testing from torch._dynamo.testing import same from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import TEST_HPU, TestCase +from torch.testing._internal.common_utils import TestCase try: @@ -359,11 +359,11 @@ def forward( ) -devices = ["cpu", "cuda"] -if TEST_HPU: - devices.append("hpu") +devices = ["cpu", "cuda", "xpu", "hpu"] -instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices) +instantiate_device_type_tests( + TestModelOutputBert, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 57b75eca84002e..304893026453e2 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -12,11 +12,16 @@ _push_on_torch_function_stack, ) from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.triton_utils import requires_gpu from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) + + class TestMode(BaseTorchFunctionMode): def __torch_function__(self, func, types, args, kwargs=None): if not kwargs: @@ -613,12 +618,12 @@ def func(a): func(torch.randn(3)) - @requires_cuda + @requires_gpu def test_flex_attention(self): import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention - torch.set_default_device("cuda") + torch.set_default_device(device_type) flex_attention = torch.compile(flex_attention, dynamic=False) @@ -628,7 +633,9 @@ def prefix_lm(b, h, q, kv): return prefix_lengths[b] >= kv # This runs in fullgraph already - create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True) + create_block_mask( + prefix_lm, 8, None, 512, 512, _compile=True, device=device_type + ) def test_register_hook(self): import functools diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 9b3b2c40ffad82..39454876a15f29 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1299,6 +1299,7 @@ def test_unsupportedmodule(self): self.assertTrue(torch._dynamo.testing.same(r, m(i))) self.assertEqual(cnt.op_count, 6) + @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True) def test_self_mutating1(self): m1 = torch.nn.Linear(10, 10) m2 = SelfMutatingModule(m1) @@ -1986,7 +1987,7 @@ def forward(self, x): # Check order of _modules def fn(x): for idx, p in enumerate(mod.modules()): - # Something silly to force depedency on the order + # Something silly to force dependency on the order x += coeffs_for_mod[p] * coeffs[idx] for idx, p in enumerate(mod.named_modules()): x += coeffs_for_mod[p[1]] * coeffs[idx] @@ -2093,11 +2094,12 @@ def forward(self, x): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) @@ -2159,11 +2161,12 @@ def forward(self, x): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 614baec1e3dcee..e74ebc22587100 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -3,6 +3,7 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_adam in OptimizerTests) """ + import functools import torch diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py new file mode 100644 index 00000000000000..64331adcac7a10 --- /dev/null +++ b/test/dynamo/test_package.py @@ -0,0 +1,312 @@ +# Owner(s): ["module: dynamo"] + +import importlib +import os +import sys +import tempfile +import unittest + +import torch +import torch._dynamo.testing +import torch._inductor.config +import torch._inductor.test_case +import torch.onnx.operators +import torch.utils.cpp_extension +from torch._dynamo.package import CompilePackage, DiskDynamoStore +from torch._functorch import config as functorch_config +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU + + +@functorch_config.patch("bundled_autograd_cache", True) +@instantiate_parametrized_tests +class TestPackage(torch._inductor.test_case.TestCase): + def path(self): + path = os.path.join(cache_dir(), f"package_{self.id()}") + os.makedirs(path, exist_ok=True) + return path + + @parametrize("backend", ("eager", "inductor")) + @parametrize("device", ("cpu", "cuda", "xpu")) + def test_basic_fn(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + ctx = DiskDynamoStore() + + def fn(x): + return x + 1 + + args = ( + torch.randn( + 3, + 2, + device=device, + ), + ) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend, package=package)(fn) + expected = compiled_fn(*args) + if backend == "eager": + for backend_id, backend in package.cached_backends.items(): + ctx.record_eager_backend(backend_id, backend) + + ctx.save_package(package, self.path()) + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + + package, backends = ctx.load_package(fn, self.path()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(backends) + self.assertEqual(expected, compiled_fn(*args)) + + @parametrize("backend", ("eager", "inductor")) + @parametrize("device", ("cpu", "cuda", "xpu")) + def test_lazy_backward(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + ctx = DiskDynamoStore() + + def fn(x): + return x.sin() + x.cos() + + args = ( + torch.zeros( + 3, + 2, + device=device, + requires_grad=True, + ), + ) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend, package=package)(fn) + expected = compiled_fn(*args) + expected.sum().backward() + + if backend == "eager": + for backend_id, backend in package.cached_backends.items(): + ctx.record_eager_backend(backend_id, backend) + + ctx.save_package(package, self.path()) + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + + package, backends = ctx.load_package(fn, self.path()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(backends) + self.assertEqual(expected, compiled_fn(*args)) + + @parametrize("backend", ("eager", "inductor")) + @parametrize("device", ("cpu", "cuda", "xpu")) + def test_graph_break_bomb(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + ctx = DiskDynamoStore() + + def fn(x, l, r): + if l > r: + return x.sum() + mid = (l + r) // 2 + if x.sum() == mid: + return x.sum() + elif x.sum() < mid: + return fn(x, l, mid) + else: + return fn(x, mid + 1, r) + + def guard_filter_fn(guards): + return [ + guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH") + for guard in guards + ] + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize( + backend=backend, package=package, guard_filter_fn=guard_filter_fn + )(fn) + N = 10 + args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)] + for args in args_list: + compiled_fn(*args) + if backend == "eager": + for backend_id, backend in package.cached_backends.items(): + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + for args in args_list: + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + package, backends = ctx.load_package(fn, self.path()) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + package.install(backends) + for args in args_list: + self.assertEqual(compiled_fn(*args), args[0].sum()) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(torch.tensor(N), 0, N - 1) + + @parametrize("backend", ("eager", "inductor")) + @parametrize("device", ("cpu", "cuda", "xpu")) + def test_dynamic_shape(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + ctx = DiskDynamoStore() + + def fn(x): + return x + x.shape[0] + + args = (torch.randn(3, 2, device=device),) + args1 = (torch.randn(5, 2, device=device),) + args2 = (torch.randn(7, 2, device=device),) + expected1 = fn(*args1) + + torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn) + compiled_fn(*args) + if backend == "eager": + for backend_id, backend in package.cached_backends.items(): + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args1) + + package, backends = ctx.load_package(fn, self.path()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(backends) + + self.assertEqual(expected1, compiled_fn(*args1)) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args2) + + def test_file_change(self): + ctx = DiskDynamoStore() + + def import_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + mock_module_add_original = """ +def add(x, y): + return x + y +""" + + mock_module_add_modified = """ +def add(x, y): + return x - y +""" + with tempfile.TemporaryDirectory() as tmp_dir: + mock_module_add_original_path = os.path.join( + tmp_dir, "mock_module_add_original.py" + ) + mock_module_add_modified_path = os.path.join( + tmp_dir, "mock_module_add_modified.py" + ) + with open(mock_module_add_original_path, "w") as f: + f.write(mock_module_add_original) + with open(mock_module_add_modified_path, "w") as f: + f.write(mock_module_add_modified) + + module = import_from_path( + "torch.test_package_helper", + mock_module_add_original_path, + ) + + def fn(x): + return module.add(x, 1) + + args = (torch.randn(3, 2),) + + def guard_filter_fn(guards): + return [ + guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH") + for guard in guards + ] + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) + + module = import_from_path( + "torch.test_package_helper", + mock_module_add_modified_path, + ) + with self.assertRaisesRegex(RuntimeError, "Source code changes detected"): + ctx.load_package(fn, self.path()) + + module = import_from_path( + "torch.test_package_helper", + mock_module_add_original_path, + ) + ctx.load_package(fn, self.path()) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index 25c630e5a59fb5..0e6d05964c85ec 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -12,7 +12,7 @@ import torch.compiler.config import torch.nested from torch._dynamo.testing import CompileCounter -from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache +from torch._inductor.utils import clear_caches, fresh_cache class PgoTest(torch._dynamo.test_case.TestCase): @@ -24,7 +24,7 @@ def setUp(self): torch._dynamo.config.patch(automatic_dynamic_local_pgo=True) ) if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1": - self._test_stack.enter_context(fresh_inductor_cache()) + self._test_stack.enter_context(fresh_cache()) mock_cache.PatchCaches.setUp() def tearDown(self): @@ -35,7 +35,7 @@ def tearDown(self): def reset(self): torch._dynamo.reset() - clear_inductor_caches() + clear_caches() def test_basic(self): cnts = CompileCounter() @@ -116,6 +116,51 @@ def check_whitelist(sources_): f(torch.randn(8, 8), torch.randn(8)) self.assertEqual(cnts.frame_count, 1) + def test_pgo_dynamic_false(self): + @torch.compile(backend="eager", dynamic=False) + class Foo(torch.nn.Module): + def forward(self, x, y): + x += 2 + y += 2 + torch._dynamo.graph_break() + x -= 2 + y *= 2 + return x, y + + self.reset() + f = Foo() + f(torch.randn(2, 4), torch.randn(2, 4)) + f(torch.randn(4, 4), torch.randn(6, 8)) + + # check PGO code state is overwritten with static value, both before/after graph break + for code_state in torch._dynamo.pgo.get_code_state().values(): + self.assertEqual(code_state.automatic_dynamic["L['x']"].size, (4, 4)) + self.assertEqual(code_state.automatic_dynamic["L['y']"].size, (6, 8)) + + def test_whitelist_ints_floats(self): + @torch.compile(backend="eager", fullgraph=True) + class Bar(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x, y, z): + if self.c == 1.0: + return x + y + torch.tensor([z]) + + f = Bar(1.0) + f(2, 1.0, 2.0) + f(3, 1.2, 2.0) + state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) + whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1) + self.assertTrue("L['x']" in whitelist) + self.assertTrue("L['y']" in whitelist) + self.assertTrue( + "___as_tensor(L['y'])" not in whitelist + ) # ephemeral FloatTensor source + self.assertTrue("L['z']" not in whitelist) # static float + self.assertTrue("L['self'].c" not in whitelist) # static float property + def test_pgo_dynamic_params(self): cnts = CompileCounter() @@ -209,7 +254,7 @@ def f(x): self.assertEqual(cnts.frame_count, 2) torch._dynamo.reset() - clear_inductor_caches() + clear_caches() cnts.clear() with torch.compiler.config.patch(job_id="foo"): diff --git a/test/dynamo/test_precompile_context.py b/test/dynamo/test_precompile_context.py new file mode 100644 index 00000000000000..d3a5140cbe821f --- /dev/null +++ b/test/dynamo/test_precompile_context.py @@ -0,0 +1,99 @@ +# Owner(s): ["module: dynamo"] + +import torch +import torch._dynamo +import torch._dynamo.test_case +import torch._functorch +from torch._dynamo.precompile_context import PrecompileContext +from torch._functorch import config as functorch_config +from torch._functorch._aot_autograd.autograd_cache import ( + BundledAOTAutogradCacheArtifact, +) +from torch._inductor.test_case import TestCase as InductorTestCase +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton + + +@functorch_config.patch({"enable_autograd_cache": True}) +@functorch_config.patch( + {"bundled_autograd_cache": True} +) # Requires bundledaotautograd cache for now +class PrecompileContextTests(InductorTestCase): + def setUp(self): + """ + Reset all counters and caches before each unit test + """ + super().setUp() + # Clear PrecompileContext cache artifacts + PrecompileContext.clear() + + @requires_triton() + def test_basic(self): + """ + Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1 + """ + + def simple_function(x): + return x.sin() + x.cos() + + compiled_fn = torch.compile(simple_function) + + # Run the compiled function + x = torch.randn(10, device=GPU_TYPE, requires_grad=True) + result = compiled_fn(x) + result.sum().backward() + # Check that PrecompileContext._new_cache_artifacts_by_key has length 1 + self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1) + + self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0) + result = PrecompileContext.serialize() + assert result is not None + serialized, cache_info = result + self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1) + + artifacts = PrecompileContext.deserialize(serialized) + assert artifacts is not None + deserialized = artifacts["precompile_aot_autograd"] + assert len(deserialized) == 1 + entry = deserialized[0] + assert isinstance(entry, BundledAOTAutogradCacheArtifact) + entry = entry.after_deserialization() + # Now that we've serialized, there should be no new cache artifacts + self.assertEqual( + len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0 + ) + + @requires_triton() + def test_serialize_by_key(self): + """ + Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1 + """ + + def simple_function(x): + return x.sin() + x.cos() + + compiled_fn = torch.compile(simple_function) + + # Run the compiled function + x = torch.randn(10, device=GPU_TYPE, requires_grad=True) + result = compiled_fn(x) + result.sum().backward() + # Check that PrecompileContext._new_cache_artifacts_by_key has length 1 + # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once + # we have torch._dynamo.package implemented + self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1) + key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys())) + result = PrecompileContext.serialize_artifact_by_key(key) + assert isinstance(result, BundledAOTAutogradCacheArtifact) + self.assertEqual(result.key, key) + + self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0) + result = PrecompileContext.serialize() + assert result is not None + _, cache_info = result + self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index 4507d33946205c..e69c23c952438f 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -238,9 +238,7 @@ def f(x): tensor 'x' size mismatch at index 0. expected 11, actual 12 tensor 'x' size mismatch at index 0. expected 10, actual 12 tensor 'x' size mismatch at index 0. expected 9, actual 12 -tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( - "\n" - ): +tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"): self.assertIn( line, failure_str, @@ -276,9 +274,7 @@ def filter_reasons(): opt_f([7, 8]) for line in """\ -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) failure_reasons.clear() @@ -286,9 +282,7 @@ def filter_reasons(): for line in """\ len(x) == 2 -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) @torch._dynamo.config.patch(recompile_limit=1) diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 662f5420bfcb13..7185682df70ec7 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,6 +7,11 @@ import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE +from torch.testing._internal.inductor_utils import requires_triton +from torch.utils._triton import ( + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, +) def _filter_instructions(instructions, opname): @@ -397,6 +402,52 @@ def gn(x): inp = torch.randn(3) self.assertEqual(gn(inp), inp + 3) + @requires_triton() + @unittest.skipIf( + not has_triton_experimental_host_tma(), + "Test requires triton.tools.experimental_descriptor API", + ) + def test_tma_experimental_reconstruct(self): + import triton + + def create_tma(tensor): + tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + tensor.size(1), + 32, + 32, + tensor.element_size(), + ) + return tensor + 1, tma + + x = torch.randn(128, 128, device="cuda") + + ref = create_tma(x) + res = torch.compile(create_tma, backend="eager")(x) + self.assertEqual(ref[1].desc, res[1].desc) + + @requires_triton() + @unittest.skipIf( + not has_triton_tensor_descriptor_host_tma(), + "Test requires triton.tools.tensor_descriptor API", + ) + def test_tma_stable_reconstruct(self): + import triton + + def create_tma(tensor): + tma = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + tensor, + [32, 32], + ) + return tensor + 1, tma + + x = torch.randn(128, 128, device="cuda") + + ref = create_tma(x) + res = torch.compile(create_tma, backend="eager")(x) + self.assertEqual(ref, res) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index 0b22ca50c18ce1..84b4f00dc9d110 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -202,7 +202,16 @@ def f(x): graph_break_key = counters["graph_break"].keys() self.assertEqual(len(graph_break_key), 1) - self.assertEqual(next(iter(graph_break_key)), "Tensor.item") + self.assertExpectedInline( + next(iter(graph_break_key)), + """\ +Unsupported Tensor.item() call with capture_scalar_outputs=False + Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False. + Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph. + + Developer debug context: call_method TensorVariable() item () {} +""", # noqa: B950 + ) if __name__ == "__main__": diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 602a0a48a1b3da..d58638573978de 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -13,6 +13,7 @@ import importlib import inspect import itertools +import logging import os import random import sys @@ -43,8 +44,9 @@ from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.nn import functional as F +from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_FP8, @@ -59,6 +61,7 @@ skipIfWindows, TEST_WITH_ROCM, ) +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.testing._internal.two_tensor import TwoTensor from torch.utils._python_dispatch import TorchDispatchMode @@ -176,9 +179,9 @@ def shapes_to_tensor(x, device=None): if torch.jit.is_scripting(): return torch.as_tensor(x, device=device) if torch.jit.is_tracing(): - assert all( - isinstance(t, torch.Tensor) for t in x - ), "Shape should be tensor during tracing!" + assert all(isinstance(t, torch.Tensor) for t in x), ( + "Shape should be tensor during tracing!" + ) # as_tensor should not be used in tracing because it records a constant ret = torch.stack(x) if ret.device != device: # avoid recording a hard-coded device if not necessary @@ -477,9 +480,9 @@ def forward( real_seq_length = seq_length if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + assert len(past_key_value) == 2, ( + f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + ) real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -945,6 +948,26 @@ def __init__(self, x): self.x = x + 2 +class LRUCacheWarningTests(LoggingTestCase): + @requires_cuda + @make_logging_test(dynamo=logging.DEBUG) + def test_lru_cache_warning_issued_during_tracing(self, records): + torch.set_default_device("cuda") + + @torch.compile(backend="eager") + def f(x): + torch.get_device_module() + x = x.cos().sin() + return x + + result = f(torch.randn(1024)) + self.assertIsInstance(result, torch.Tensor) + + for record in records: + if "call to a lru_cache wrapped function at:" in record.getMessage(): + self.fail("lru_cache warning was incorrectly logged") + + class ReproTests(torch._dynamo.test_case.TestCase): def setUp(self) -> None: try: @@ -3461,6 +3484,40 @@ def fn(x, obj): self.assertEqual(obj1.b.item(), 0) self.assertEqual(obj2.a.item(), 2) + def test_delattr_return(self): + class MyObject: + def __init__(self, val): + self.val = val + self.deletion_attempted = False + + def __delattr__(self, attr): + if attr == "val": + self.deletion_attempted = True + else: + super().__delattr__(attr) + + @torch.compile(fullgraph=True, backend="eager") + def test_delattr(input_tensor): + instance_a = MyObject(1) + instance_b = MyObject(2) + del instance_a.val + del instance_b.val + exists_a = hasattr(instance_a, "val") + exists_b = hasattr(instance_b, "val") + deletion_attempted_a = instance_a.deletion_attempted + deletion_attempted_b = instance_b.deletion_attempted + return ( + input_tensor + 1, + exists_a, + exists_b, + deletion_attempted_a, + deletion_attempted_b, + ) + + result = test_delattr(torch.ones(1)) + self.assertEqual(result[0], torch.tensor([2.0])) + self.assertEqual(result[1:], (True, True, True, True)) + def test_delattr_raises(self): class MyObj: def __init__(self, a, b): @@ -3883,7 +3940,7 @@ def randint_fn(high, size, out): opt_model(17, (12,), out2) @requires_cuda - @serialTest + @serialTest() def test_mem_leak_guards(self): def gn(x0, x): return x0 * x @@ -4404,6 +4461,20 @@ def func3(x, y): # frame_count should stay at 1. self.assertEqual(cnt.frame_count, 1) + def test_tensor_set_data_mismatched_dtype(self): + def func(x, y): + x.data = y.to(dtype=torch.bfloat16) + + x1 = torch.tensor([], dtype=torch.float32) + x2 = torch.tensor([], dtype=torch.float32) + y1 = torch.tensor([1, 2, 3], dtype=torch.float32) + y2 = torch.tensor([1, 2, 3], dtype=torch.float32) + func(x1, y1) + torch.compile(func, backend="eager")(x2, y2) + self.assertEqual(x1, x2) + self.assertEqual(x1.data, x2.data) + self.assertEqual(y1, y2) + def test_user_ctor_ctx_manager(self): class UserCtxManager: def __enter__(self): @@ -4821,9 +4892,9 @@ def set(self, name: str, value: Any) -> None: with warnings.catch_warnings(record=True): data_len = len(value) if len(self._fields): - assert ( - len(self) == data_len - ), f"Adding a field of length {data_len} to a Instances of length {len(self)}" + assert len(self) == data_len, ( + f"Adding a field of length {data_len} to a Instances of length {len(self)}" + ) self._fields[name] = value def get(self, name: str) -> Any: @@ -4960,6 +5031,66 @@ def fn(x_weak, y): res = opt_fn(x_weak, y) self.assertEqual(ref, res) + # The programming model around (weak)references is that we DO NOT guarantee + # any behavior that depends on deallocation order. We do guarantee "eventual consistency", + # that is, after the torch.compile'd function is finished running (including any graph breaks), + # refcount semantics will match eager's. + def test_weakref_callback(self): + called1 = False + + def callback1(ref): + nonlocal called1 + called1 = True + if not torch.compiler.is_compiling(): + raise RuntimeError("callback1 expected to be compiled") + + # weakref callbacks that should be called in the compiled region will be compiled. + # But the exact place in the compiled code that the callback is made is undefined. + @torch.compile(backend="eager") + def fn(x): + y = x + 1 + ref = weakref.ref(y, callback1) + torch._dynamo.graph_break() + return ref + + fn(torch.ones(3)) + self.assertTrue(called1) + + called2 = False + + def callback2(ref): + nonlocal called2 + called2 = True + if torch.compiler.is_compiling(): + raise RuntimeError("callback2 expected to not be compiled") + + # weakref callbacks that fire outside the compiled region work + @torch.compile(backend="eager") + def gn(x): + y = x + 1 + ref = weakref.ref(y, callback2) + torch._dynamo.graph_break() + return y, ref + + y, _ = gn(torch.ones(3)) + del y + self.assertTrue(called2) + + def callback3(ref): + raise RuntimeError("callback3 should not be called") + + # The callback will NOT be called if both the weakref and the referrent are + # deleted in the same compiled region (graph breaks act like a "memory sync" + # and thus make things tricky - the callback is actually expected to be called). + # This test does NOT mean that this behavior is part of the (weak)ref programming + # model, but rather reminds us that this is an intentionally allowed weakref-Dynamo behavior. + @torch.compile(backend="eager") + def hn(x): + y = x + 1 + _ = weakref.ref(y, callback3) + + hn(torch.ones(3)) + # @torch._functorch.config.patch( # recompute_views=True, # ) @@ -5433,7 +5564,7 @@ def fn2(): y = torch.randn(100, 10) return torch.mm(x, y).sum() - with fresh_inductor_cache(): + with fresh_cache(): torch.compile(fn)() torch.compile(fn2)() @@ -5589,25 +5720,27 @@ def forward(self): # https://github.com/pytorch/pytorch/issues/121621 def test_tensor_random(self): - def random_op(tensor, params): - res = tensor.random_(**params) + def random_op(tensor, args, kwargs): + res = tensor.random_(*args, **kwargs) return res random_op = torch.compile(random_op) - params = {"from": -10, "to": 10} tensor = torch.randn([2, 3]) - random_op(tensor, params) + random_op(tensor, [], {"from": -10, "to": 10}) + random_op(tensor, [-10], {"to": 10}) + random_op(tensor, [-10, 10], {}) # https://github.com/pytorch/pytorch/issues/131019 def test_tensor_uniform(self): - def uniform_op(tensor, params): - res = tensor.uniform_(**params) + def uniform_op(tensor, args, kwargs): + res = tensor.uniform_(*args, **kwargs) return res uniform_op = torch.compile(uniform_op) - params = {"from": -10, "to": 10} tensor = torch.randn([2, 3]) - uniform_op(tensor, params) + uniform_op(tensor, [], {"from": -10, "to": 10}) + uniform_op(tensor, [-10], {"to": 10}) + uniform_op(tensor, [-10, 10], {}) def test_data_attr_mutation_after_saved_for_bw(self): def f(x): @@ -5706,6 +5839,17 @@ def f(x): torch.view_as_real(out_test).sum().backward() self.assertEqual(x_ref.grad, x_test.grad) + def test_add_complex_conj(self): + def f(x): + return x + x.conj() + + x = torch.randn(4, dtype=torch.complex64, requires_grad=True) + out = torch.compile(f)(x) + expected_complex = (2 * x.real).to(dtype=out.dtype) + + self.assertTrue(out.dtype == torch.complex64) + self.assertEqual(out, expected_complex) + # https://github.com/pytorch/pytorch/issues/132200 def test_partitioner_cse_respects_mutation_boundaries(self): set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_") @@ -5765,6 +5909,43 @@ def f(x, l): self.assertEqual(result, result_test) self.assertEqual(x, x_test) + def test_aot_autograd_runtime_wrapper_prologue_profiled(self): + # Names for prologue profiling event + prologue_name = "AOTDispatcher Runtime Wrapper Prologue" + + # Simple linear op to compile + mod = torch.nn.Linear(4, 4) + opt_mod = torch.compile(mod) + x = torch.randn(4, 4) + + # Run this test with grad and no-grad to test both boolean cases trace_joint + for c in [contextlib.nullcontext, torch.no_grad]: + # Run compiled op with profiling + with c(): + # warmup before profiling + opt_mod(x) + with profile(activities=[ProfilerActivity.CPU]) as prof: + opt_mod(x) + + # Make sure events are populated then find prologue event and last start time + events = prof.events() + self.assertTrue(events is not None) + + prologue_event = None + last_start_time = 0 + for event in events: + if hasattr(event, "name") and prologue_name in event.name: + prologue_event = event + if event.time_range.start > last_start_time: + last_start_time = event.time_range.start + + # Make sure prologue event exist + self.assertTrue(prologue_event is not None) + + # Make sure there is at least one other event (compiled function) that starts + # after prologue starts + self.assertLess(prologue_event.time_range.end, last_start_time) + def test_changing_stride(self): cnt = torch._dynamo.testing.CompileCounter() @@ -5978,7 +6159,7 @@ def f(x, param): self.assertEqual(out_ref, out_test) @requires_cuda - # This test will fail as flip in combination with particular input lenghts + # This test will fail as flip in combination with particular input lengths # produces weird results. # This is under investigations in # https://github.com/pytorch/pytorch/issues/131805 @@ -6779,9 +6960,12 @@ def test_incompatible_configs(self): ): torch.compile(lambda: None) - with torch._dynamo.config.patch( - suppress_errors=True, fail_on_recompile_limit_hit=True - ), self.assertRaises(AssertionError): + with ( + torch._dynamo.config.patch( + suppress_errors=True, fail_on_recompile_limit_hit=True + ), + self.assertRaises(AssertionError), + ): torch.compile(lambda: None) def test_str_isalnum(self): @@ -6794,6 +6978,100 @@ def f(x, c): c = "foobar" self.assertEqual(f(x, c), opt_f(x, c)) + def test_nn_param_freevar_codegen(self): + class Model2(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3) + self.batchnorm = nn.BatchNorm2d(num_features=5) + self.conv_weight = torch.randn(5, 3, 3, 3) + self.conv_bias = torch.randn(5) + + def forward(self, x): + self.conv.weight = nn.Parameter(self.conv_weight) + self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False) + self.conv.eval() + x = self.conv(x) + x = self.batchnorm(x) + x = F.relu(x) + return x + + input_tensor = torch.randn(1, 3, 10, 10) + func = Model2().to("cpu") + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + with torch.no_grad(): + func.train(False) + v1 = func(input_tensor) + jit_func = torch.compile(wrapper, backend="eager", fullgraph=True) + v2 = jit_func(input_tensor) + self.assertEqual(v1, v2) + + def test_amp_foreach_fake_impl(self): + inv_scale = torch.full((1,), 0.25) + found_inf = torch.full((1,), 0.0) + grads = [torch.ones(10), torch.ones(10)] + + def f(): + res = torch._amp_foreach_non_finite_check_and_unscale_( + grads, found_inf, inv_scale + ) + return res + + ref = f() + res = torch.compile(f, backend="aot_eager")() + self.assertEqual(ref, res) + + def test_deleted_compile_wrapper_segfault(self): + def fn(x): + return x + 1 + + opt_fn = torch.compile(fn, backend="eager") + # This calls cached_backend.clear() which removes any strong references + # to the callback + torch._dynamo.reset() + opt_fn(torch.randn(3)) + opt_fn = torch.compile(fn, backend="eager") + opt_fn(torch.randn(3)) # possible segfault due to first opt_fn deletion + + def test_delete_local_error(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + y = x + 1 + del y + z = y + 1 # noqa: F821 + return z + + with self.assertRaises(torch._dynamo.exc.Unsupported): + fn(torch.ones(3)) + + def test_nanmean_out(self): + def f(x, out): + torch.nanmean(x, out=out) + + x = torch.randn(4) + out_ref = torch.tensor(0.0) + out_res = torch.tensor(0.0) + + f(x, out_ref) + torch.compile(f, backend="eager", fullgraph=True)(x, out_res) + self.assertEqual(out_ref, out_res) + + def test_unbind_copy_out(self): + def f(eye, out): + torch.unbind_copy(eye, out=out) + + eye = torch.eye(3) + out_ref = (torch.zeros(3), torch.zeros(3), torch.zeros(3)) + out_res = (torch.zeros(3), torch.zeros(3), torch.zeros(3)) + + f(eye, out_ref) + torch.compile(f, backend="eager", fullgraph=True)(eye, out_res) + self.assertEqual(out_ref, out_res) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): @@ -7171,7 +7449,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # *are* saved for backward, and become back inputs. # The easier-to-test thing I'm checking for here is that the recompute # on primals_2 happens in the backward. With the recompute, - # there are 5 _to_copy ops in the backwrad. Without it, there are 4 + # there are 5 _to_copy ops in the backward. Without it, there are 4 # (aka if you set torch._functorch.config.treat_parameters_as_free_to_save = False) self.assertEqual(mode.ops_counter[torch.ops.aten._to_copy.default], 5) @@ -7271,6 +7549,19 @@ def f2(x): self.assertEqual(f2(torch.ones(3)), torch.ones(3) + 1) + def test_torch_cuda_is_initialized(self): + @torch.compile(fullgraph=True, backend="eager") + def f(x): + if torch.cuda.is_initialized(): + return x + 1 + return x + 2 + + inp = torch.randn(3) + self.assertEqual(f(inp), inp + 1) + + with mock.patch("torch.cuda.is_initialized", lambda: False): + self.assertEqual(f(inp), inp + 2) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_sets.py b/test/dynamo/test_sets.py new file mode 100644 index 00000000000000..bcb59ca4cd54a2 --- /dev/null +++ b/test/dynamo/test_sets.py @@ -0,0 +1,693 @@ +# Owner(s): ["module: dynamo"] + +# TODO: move set tests from test_functions.py/test_misc.py to this file + +import math +import unittest +from collections.abc import Iterable + +import torch +import torch._dynamo.test_case +from torch._dynamo.exc import Unsupported +from torch._dynamo.testing import CompileCounter +from torch.testing._internal.common_utils import make_dynamo_test, munge_exc +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test + + +class SetSubclass(set): + pass + + +class FrozenstSubclass(frozenset): + pass + + +class _BaseSetTests(torch._dynamo.test_case.TestCase): + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + def assertEqual(self, a, b): + return self.assertTrue(a == b, f"{a} != {b}") + + def assertNotEqual(self, a, b): + return self.assertTrue(a != b, f"{a} == {b}") + + +class CustomSetTests(_BaseSetTests): + class CustomSet(set): + def add(self, item): + return super().add(item + 1) + + def contains(self, item): + return True + + thetype = CustomSet + + @make_dynamo_test + def test_custom_add(self): + s = self.thetype([1, 2]) + s.add(3) + self.assertTrue(s == {1, 2, 4}) + + @make_dynamo_test + def test_custom_contains(self): + s = self.thetype([1, 2]) + self.assertTrue(s.contains(3)) + + +class MiscTests(torch._dynamo.test_case.TestCase): + def test_isdisjoint_with_generator(self): + n = 0 + + def gen(): + nonlocal n + n += 1 + yield 1 + n += 2 + yield 2 + n += 3 + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + nonlocal n + s = {2, 4, 5} + s.isdisjoint(gen()) + if n == 3: + return x.sin() + return x.cos() + + x = torch.randn(1) + y = fn(x) + self.assertEqual(y, x.sin()) + + +class TestSetGuards(LoggingTestCase): + def test_set_with_function(self): + s = { + torch._C._set_grad_enabled, + "hello", + torch.amp._exit_autocast, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(2) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove(torch.amp._exit_autocast) + s.add(torch._C._set_fwd_grad_enabled) + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + + @make_logging_test(recompiles=True) + def test_in_guard(self, records): + s = { + "Dynamo", + "Inductor", + "PyTorch", + torch.sin, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if "PyTorch" in s: + return x.sin() + return x.cos() + + x = torch.randn(2) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove("PyTorch") + s.add("Cuda") + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "set.__contains__") + self.assertIn( + """set.__contains__(s, 'PyTorch')""", + munge_exc(record.getMessage()), + ) + + def test_set_with_tensors(self): + s = { + torch.ones(1), + torch.tensor([1.0]), + torch.zeros(1), + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + z = torch.zeros(1) + for i in s: + z += i + return x + z + + x = torch.tensor([1.0]) + self.assertExpectedInlineMunged( + Unsupported, + lambda: fn(x, s), + """\ +Attempted to wrap a set with tensors + Explanation: Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors. + Hint: Use a dictionary where the keys are tensors. + Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. + + Developer debug context: Python set containing torch.Tensor elements + + +from user code: + File "test_sets.py", line N, in fn + for i in s:""", # noqa: B950 + ) + + def test_set_multiple_types(self): + s = { + "PyTorch", + 3.3, + 1j, + math.nan, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if "PyTorch" in s: + return x.sin() + return x.cos() + + x = torch.tensor(1.0) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove("PyTorch") + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + + def test_set_recompile_on_key_pop(self): + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn(x, s) + opt_fn(x, s) + self.assertEqual(res, fn(x, s)) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # Pop a value + s.remove(torch.amp._exit_autocast) + + res = opt_fn(x, s) + # Check recompilation + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(res, fn(x, s)) + + def test_set_recompile_on_key_change(self): + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn(x, s) + opt_fn(x, s) + self.assertEqual(res, fn(x, s)) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # Pop a value + s.remove(torch.amp._exit_autocast) + # Add a different value + s.add(torch._C._set_autograd_fallback_mode) + + res = opt_fn(x, s) + # Check recompilation + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(res, fn(x, s)) + + @unittest.skip("random failures on Python 3.9") + def test_set_guard_on_keys_change(self): + # This test guarantee that we're not triggering any of the dict guards + # on sets + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + for e in s: + x = x * len(str(e)) + return x + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn(torch.randn(4), s) + opt_fn(torch.randn(4), s) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # pop and add the same item + s.remove(torch.amp._exit_autocast) + # It is not guaranteed that _exit_autocast will be in a specific order + s.add(torch.amp._exit_autocast) + + x = torch.randn(4) + res = opt_fn(x, s) + # Check Dynamo don't recompile + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(res, fn(x, s)) + + +class _FrozensetBase: + # Frozenset methods + # + copy + # + difference + # + intersection + # + isdisjoint + # + issubset + # + issuperset + # + symmetric_difference + # + union + # BinOps: + # +, -, |, &, ^, <, >, <=, >=, ==, != + + @make_dynamo_test + def test_binop_sub(self): + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(p - p, self.thetype()) + self.assertEqual(p - q, self.thetype("ac")) + self.assertEqual(q - p, self.thetype("ef")) + self.assertRaises(TypeError, lambda: p - 1) + self.assertEqual(self.thetype.__sub__(p, q), set("ac")) + + @make_dynamo_test + def test_binop_or(self): + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(p | p, self.thetype("abc")) + self.assertEqual(p | q, self.thetype("abcef")) + self.assertEqual(self.thetype.__or__(p, q), set("abcef")) + + @make_dynamo_test + def test_binop_and(self): + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(p & p, self.thetype("abc")) + self.assertEqual(p & q, self.thetype("b")) + self.assertEqual(self.thetype.__and__(p, q), set("b")) + + @make_dynamo_test + def test_binop_xor(self): + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(p ^ p, self.thetype()) + self.assertEqual(p ^ q, self.thetype("acef")) + self.assertEqual(self.thetype.__xor__(p, q), set("acef")) + + @make_dynamo_test + def test_cmp_eq(self): + p = self.thetype("abc") + self.assertEqual(p, p) + for C in set, frozenset, SetSubclass: + self.assertEqual(p, C("abc")) + self.assertEqual(p, C(p)) + self.assertTrue(self.thetype.__eq__(p, p)) + + @make_dynamo_test + def test_cmp_ne(self): + p, q = map(self.thetype, ["abc", "bef"]) + self.assertNotEqual(p, q) + self.assertNotEqual(q, p) + for C in set, frozenset, SetSubclass, dict.fromkeys, str, list, tuple: + self.assertNotEqual(p, C("abe")) + self.assertNotEqual(p, 1) + self.assertTrue(self.thetype.__ne__(p, q)) + + @make_dynamo_test + def test_cmp_less_than(self): + p, q, r = map(self.thetype, ["abc", "bef", "ab"]) + self.assertFalse(p < p) + self.assertFalse(p < q) + self.assertTrue(r < p) + self.assertFalse(r < q) + self.assertFalse(self.thetype.__lt__(p, p)) + + @make_dynamo_test + def test_cmp_greater_than(self): + p, q, r = map(self.thetype, ["abc", "bef", "ab"]) + self.assertFalse(p > p) + self.assertFalse(p > q) + self.assertTrue(p > r) + self.assertFalse(q > r) + self.assertFalse(self.thetype.__gt__(p, p)) + + @make_dynamo_test + def test_cmp_less_than_or_equal(self): + p, q, r = map(self.thetype, ["abc", "bef", "ab"]) + self.assertTrue(p <= p) + self.assertFalse(p <= q) + self.assertTrue(r <= p) + self.assertFalse(r <= q) + self.assertTrue(self.thetype.__le__(p, p)) + + @make_dynamo_test + def test_cmp_greater_than_or_equal(self): + p, q, r = map(self.thetype, ["abc", "bef", "ab"]) + self.assertTrue(p >= p) + self.assertFalse(p >= q) + self.assertTrue(p >= r) + self.assertFalse(q >= r) + self.assertTrue(self.thetype.__ge__(p, p)) + + @make_dynamo_test + def test_copy(self): + p = self.thetype("abc") + q = p.copy() + self.assertEqual(p, q) + self.assertRaises(TypeError, p.copy, 1) + self.assertEqual(self.thetype.copy(p), p) + + @make_dynamo_test + def test_issubset(self): + p, q, r = map(self.thetype, ["abc", "bc", "bef"]) + self.assertTrue(q.issubset(p)) + self.assertFalse(r.issubset(p)) + self.assertRaises(TypeError, p.issubset) + self.assertRaises(TypeError, p.issubset, 1) + self.assertRaises(TypeError, p.issubset, [[]]) + self.assertTrue(self.thetype.issubset(q, p)) + + @make_dynamo_test + def test_issuperset(self): + p, q, r = map(self.thetype, ["abc", "bc", "bef"]) + self.assertTrue(p.issuperset(q)) + self.assertFalse(p.issuperset(r)) + self.assertRaises(TypeError, p.issuperset) + self.assertRaises(TypeError, p.issuperset, 1) + self.assertRaises(TypeError, p.issuperset, [[]]) + self.assertTrue(self.thetype.issuperset(p, q)) + + @make_dynamo_test + def test_constructor_iterable(self): + p = self.thetype("abc") + self.assertIsInstance(p, self.thetype) + self.assertIsInstance(p, Iterable) + + @make_dynamo_test + def test_equality(self): + a = self.thetype("abc") + for typ in (self.thetype, set, frozenset): + self.assertEqual(a, typ(a)) + self.assertTrue(a == typ(a)) + self.assertTrue(a.__eq__(typ(a))) + self.assertTrue(self.thetype.__eq__(a, typ(a))) + + @make_dynamo_test + def test_in_frozenset(self): + item = self.thetype("abc") + container = self.thetype([frozenset("abc")]) # noqa: C405 + self.assertIn(item, container) + + @make_dynamo_test + def test_contains(self): + s = self.thetype(["a", "b", "c"]) + self.assertIn("a", s) + self.assertNotIn("d", s) + self.assertTrue(s.__contains__("a")) + self.assertTrue(self.thetype.__contains__(s, "b")) + + @make_dynamo_test + def test_isdisjoint(self): + x = self.thetype({"apple", "banana", "cherry"}) + y = self.thetype({"google", "microsoft", "apple"}) + z = self.thetype({"shoes", "flipflops", "sneakers"}) + self.assertFalse(x.isdisjoint(y)) + self.assertTrue(x.isdisjoint(z)) + self.assertRaises(TypeError, x.isdisjoint) + self.assertRaises(TypeError, x.isdisjoint, 1) + self.assertRaises(TypeError, x.isdisjoint, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.assertFalse(self.thetype.isdisjoint(p, q)) + + @make_dynamo_test + def test_intersection(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + set3 = self.thetype({"shoes", "flipflops", "apple"}) + intersection_set = set1.intersection(set2, set3) + self.assertEqual(intersection_set, {"apple"}) + self.assertRaises(TypeError, set1.intersection, 1) + self.assertRaises(TypeError, set1.intersection, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(self.thetype.intersection(p, q), {"b"}) + + @make_dynamo_test + def test_union(self): + p, q, r = map(self.thetype, ["abc", "bc", "bef"]) + union_set = p.union(q, r) + self.assertEqual(union_set, {"a", "b", "c", "e", "f"}) + self.assertRaises(TypeError, p.union, 1) + self.assertRaises(TypeError, p.union, [[]]) + s = self.thetype.union(q, r) + self.assertEqual(s, {"b", "c", "e", "f"}) + + @make_dynamo_test + def test_difference(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + set3 = self.thetype({"shoes", "flipflops", "sneakers"}) + difference_set = set1.difference(set2, set3) + self.assertEqual(difference_set, {"banana", "cherry"}) + self.assertRaises(TypeError, set1.difference, 1) + self.assertRaises(TypeError, set1.difference, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.assertEqual(self.thetype.difference(p, q), {"a", "c"}) + + @make_dynamo_test + def test_symmetric_difference(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + symmetric_diff_set = set1.difference(set2) + self.assertEqual(symmetric_diff_set, {"banana", "cherry"}) + self.assertRaises(TypeError, set1.symmetric_difference) + self.assertRaises(TypeError, set1.symmetric_difference, 1) + self.assertRaises(TypeError, set1.symmetric_difference, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + symmetric_diff_set = self.thetype.symmetric_difference(p, q) + self.assertEqual(symmetric_diff_set, {"a", "c", "e", "f"}) + + @make_dynamo_test + def test_to_frozenset(self): + set1 = frozenset(self.thetype({"apple", "banana", "cherry"})) + self.assertIsInstance(set1, frozenset) + self.assertEqual(len(set1), 3) + + @make_dynamo_test + def test_to_set(self): + set1 = frozenset(self.thetype({"apple", "banana", "cherry"})) + self.assertIsInstance(set1, frozenset) + self.assertEqual(len(set1), 3) + + +class _SetBase(_FrozensetBase): + # Set Methods + # + add + # + clear + # - copy (inherited from frozenset) + # - difference (inherited from frozenset) + # + difference_update + # + discard + # - intersection (inherited from frozenset) + # + intersection_update + # - isdisjoint (inherited from frozenset) + # - issubset (inherited from frozenset) + # - issuperset (inherited from frozenset) + # + pop + # + remove + # - symmetric_difference (inherited from frozenset) + # + symmetric_difference_update + # - union (inherited from frozenset) + # + update + + @make_dynamo_test + def test_add(self): + p = self.thetype("abc") + p.add("d") + self.assertEqual(p, {"a", "b", "c", "d"}) + p.add("a") + self.assertEqual(p, {"a", "b", "c", "d"}) + self.assertRaises(TypeError, p.add, ["ab"]) + self.assertRaises(TypeError, p.add) + set.add(p, "e") + self.assertEqual(p, {"a", "b", "c", "d", "e"}) + + @make_dynamo_test + def test_clear(self): + p = self.thetype("abc") + p.clear() + self.assertEqual(p, set()) + p = self.thetype("abc") + self.thetype.clear(p) + self.assertEqual(len(p), 0) + + @make_dynamo_test + def test_remove(self): + p = self.thetype("abc") + self.assertEqual(p.remove("a"), None) + self.assertEqual(p, {"b", "c"}) + self.assertRaises(KeyError, p.remove, "a") + p = self.thetype("abc") + self.thetype.remove(p, "b") + self.assertEqual(p, self.thetype({"a", "c"})) + + @make_dynamo_test + def test_intersection_update(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + set3 = self.thetype({"shoes", "flipflops", "apple"}) + self.assertIsNone(set1.intersection_update(set2, set3)) + self.assertEqual(set1, {"apple"}) + self.assertRaises(TypeError, set1.intersection_update, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.thetype.intersection_update(p, q) + self.assertEqual(p, {"b"}) + + @make_dynamo_test + def test_difference_update(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + set3 = self.thetype({"shoes", "flipflops", "sneakers"}) + self.assertIsNone(set1.difference_update(set2, set3)) + self.assertEqual(set1, {"banana", "cherry"}) + self.assertRaises(TypeError, set1.difference_update, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.thetype.difference_update(p, q) + self.assertEqual(p, {"a", "c"}) + + @make_dynamo_test + def test_symmetric_difference_update(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + self.assertIsNone(set1.symmetric_difference_update(set2)) + self.assertEqual(set1, {"banana", "cherry", "google", "microsoft"}) + self.assertRaises(TypeError, set1.symmetric_difference_update) + self.assertRaises(TypeError, set1.symmetric_difference_update, [[]]) + p, q = map(self.thetype, ["abc", "bef"]) + self.thetype.symmetric_difference_update(p, q) + self.assertEqual(p, {"a", "c", "e", "f"}) + + @make_dynamo_test + def test_pop(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + e = set1.pop() + self.assertNotIn(e, set1) + s = self.thetype() + self.assertRaises(KeyError, s.pop) + p = self.thetype("a") + self.assertEqual(self.thetype.pop(p), "a") + + @make_dynamo_test + def test_update(self): + p, q, r = map(self.thetype, ["abc", "bc", "bef"]) + p.update(q, r) + self.assertEqual(p, {"a", "b", "c", "e", "f"}) + self.assertRaises(TypeError, p.update, [[]]) + self.thetype.update(q, r) + self.assertEqual(q, {"b", "c", "e", "f"}) + + @make_dynamo_test + def test_discard(self): + set1 = self.thetype({"apple", "banana", "cherry"}) + set2 = self.thetype({"google", "microsoft", "apple"}) + set1.discard("banana") + set2.discard("cherry") + self.assertEqual(set1, {"apple", "cherry"}) + self.assertEqual(set2, {"google", "microsoft", "apple"}) + p = self.thetype("abc") + self.thetype.discard(p, "a") + self.assertEqual(p, {"b", "c"}) + + +class FrozensetTests(_FrozensetBase, _BaseSetTests): + thetype = frozenset + + +class SetTests(_SetBase, _BaseSetTests): + thetype = set + + @unittest.expectedFailure + def test_in_frozenset(self): + super().test_in_frozenset() + + +class UserDefinedSetTests(_SetBase, _BaseSetTests): + class CustomSet(set): + pass + + thetype = CustomSet + + @unittest.expectedFailure + def test_in_frozenset(self): + super().test_in_frozenset() + + @unittest.expectedFailure + def test_equality(self): + super().test_in_frozenset() + + +class UserDefinedFrozensetTests(_FrozensetBase, _BaseSetTests): + class CustomFrozenset(frozenset): + pass + + thetype = CustomFrozenset + + @unittest.expectedFailure + def test_in_frozenset(self): + super().test_in_frozenset() + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 975f6b6a6c57a2..5452258ed5674f 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -108,9 +108,10 @@ def mk_obscure(base_is_nt): for requires_grad_1, requires_grad_2 in itertools.product( [True, False], repeat=2 ): - yield partial( - mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 - ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" + yield ( + partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2), + f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}", + ) # (3) obscure case: # view is not a leaf (implies requires_grad True) @@ -118,9 +119,10 @@ def mk_obscure(base_is_nt): yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" # Subclass -> Dense - yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ - 0 - ].clone(), "subclass_dense" + yield ( + lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(), + "subclass_dense", + ) # Dense -> Subclass -> Dense -> Subclass def mk_dense_subclass_dense_subclass(): @@ -735,7 +737,7 @@ def fn(a, w): self.assertEqual(res_exp, res_act) - def test_user_overidden_method_unsupported(self): + def test_user_overridden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -755,7 +757,7 @@ def fn(x): self.assertEqual(res_exp, res_act) - def test_user_overidden_attr_unsupported(self): + def test_user_overridden_attr_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -769,12 +771,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def fn(x): return x.ndim - msg = "`torch.compile` only support tracing certain types of overriden tensor subclass attributes" + msg = "`torch.compile` only support tracing certain types of overridden tensor subclass attributes" with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): x = torch.ones(2, 2).as_subclass(LocalSubclass) fn(x) - def test_user_overidden_property_unsupported(self): + def test_user_overridden_property_unsupported(self): class LocalSubclass(torch.Tensor): def __init__(self, *args, **kwargs) -> None: self._ndim = 10 @@ -988,8 +990,8 @@ def fn(x): self.assertEqual(x0, x1) self.assertEqual(x0.tensor_shape, x1.tensor_shape) - def test_subclass_dont_invoke_torch_function_on_overriden_method(self): - # We shouldn't fire `__torch_function__` for overriden tensor methods. + def test_subclass_dont_invoke_torch_function_on_overridden_method(self): + # We shouldn't fire `__torch_function__` for overridden tensor methods. class MySubclass(torch.Tensor): def to(self, device): return self * len(device) @@ -1011,10 +1013,10 @@ def fn(x): res_act = fn_opt(x) self.assertEqual(res_exp, res_act) - def test_subclass_dont_invoke_torch_function_on_overriden_attr(self): + def test_subclass_dont_invoke_torch_function_on_overridden_attr(self): from types import MethodWrapperType - # We shouldn't fire `__torch_function__` for overriden tensor attrs. + # We shouldn't fire `__torch_function__` for overridden tensor attrs. class MySubclass(torch.Tensor): def ndim(self): return 42 @@ -1113,9 +1115,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): elif isinstance(result, (tuple, list)): # Preserve the original type (tuple or list) wrapped = [ - cls(x, quant_type=quant_type) - if isinstance(x, torch.Tensor) - else x + ( + cls(x, quant_type=quant_type) + if isinstance(x, torch.Tensor) + else x + ) for x in result ] return type(result)(wrapped) @@ -1202,7 +1206,7 @@ def f(t): def test_nontraceable_tensor_subclass(self): # This will error if Dynamo tries to wrap it as a tensor variable, # because that involves calling certain methods to inspect the tensor - # property, which will blow up in the overriden `__torch_function__`. + # property, which will blow up in the overridden `__torch_function__`. class MySubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1364,7 +1368,7 @@ def forward(self, L_x_: "f32[3, 4]"): ) self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) - # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. + # Cannot reuse the version from AOTAutograd, since that uses python functional tensors. def to_fun(x): x_functional = torch._to_functional_tensor(x) torch._mirror_autograd_meta_to(x, x_functional) @@ -2013,7 +2017,7 @@ def forward(self): exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], - # s0 is specialized and guarded in outter shape_env when dynamo checks the guards + # s0 is specialized and guarded in outer shape_env when dynamo checks the guards ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], [ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", @@ -2035,7 +2039,7 @@ def forward(self): exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], - # s0 is specialized and guarded in outter shape_env when dynamo checks the guards + # s0 is specialized and guarded in outer shape_env when dynamo checks the guards ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], [ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", @@ -2535,9 +2539,9 @@ def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) - sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone) + sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) - return (clone, view, view_1, sym_numel_default, clone_1, primals_5) + return (clone, view, view_1, sym_size_int_2, clone_1, primals_5) """, # noqa: B950 ) @@ -2591,9 +2595,9 @@ def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) - sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone) + sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) - return (clone, view, view_1, sym_numel_default, clone_1, primals_5) + return (clone, view, view_1, sym_size_int_2, clone_1, primals_5) """, # noqa: B950 ) @@ -2856,7 +2860,7 @@ def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", ar cat: "f64[9, 5]" = torch.ops.aten.cat.default([randn, randn_1, randn_2]); randn = randn_1 = randn_2 = None zeros: "i64[1]" = torch.ops.aten.zeros.default([1], dtype = torch.int64, device = device(type='cpu'), pin_memory = False) - _tensor_constant0 = self._tensor_constant0 + _tensor_constant0: "i64[3]" = self._tensor_constant0 lift_fresh_copy: "i64[3]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None cumsum: "i64[3]" = torch.ops.aten.cumsum.default(lift_fresh_copy, 0); lift_fresh_copy = None cat_1: "i64[4]" = torch.ops.aten.cat.default([zeros, cumsum]); zeros = cumsum = None @@ -3081,7 +3085,7 @@ def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"): # triggers the eager logic to run, updating the counter and registry. # # Notably however, compile differs in two ways from eager: - # (1) The order in which the offsets are assigned ids is differnet + # (1) The order in which the offsets are assigned ids is different # the registry would be set in the order the offsets are returned # which is not necessarily the same order as they were constructed. # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 0cac9499b9d065..35036fd1de3fa9 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -401,7 +401,7 @@ def fn(a, b): y = torch.randn(3) self.assertEqual(opt_fn(x, y), fn(x, y)) self.assertEqual(opt_fn(x, x), fn(x, x)) - # NB: This COULD validly be 2, but we don't test disjointness in the + # NB: This COULD validly be 2, but we don't test disjointedness in the # guards for when x and y didn't duck size together, so we end up # with a generic graph that also works when x and y happen to duck # size together. diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 90aa18caee484b..9bfccd94b1f7e6 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -126,7 +126,7 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject torch_name_rule_map = {} # In some platforms, these functions were loaded as classes instead of functions. - # To mitigate these weired cases, we need this special check. + # To mitigate these weird cases, we need this special check. def is_special_functions(obj): return hashable(obj) and obj in { torch._C._cuda_isCurrentStreamCapturing, @@ -151,9 +151,9 @@ def heuristic_record_if_in_graph_function(obj, module, name): types.WrapperDescriptorType, ), ) or is_special_functions(obj): - torch_name_rule_map[ - f"{module.__name__}.{name}" - ] = TorchInGraphFunctionVariable + torch_name_rule_map[f"{module.__name__}.{name}"] = ( + TorchInGraphFunctionVariable + ) if c_binding_only: if not hasattr(obj, "__code__"): c_binding_in_graph_functions.add(obj) @@ -398,12 +398,15 @@ def fn(x): ) self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + ), ): x = torch.rand(3) opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) @@ -419,9 +422,9 @@ def fn(x): _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() # Force inline `mod.func` by setting trace rule. - _manual_torch_name_rule_map[ - f"{mod.__name__}.{func.__name__}" - ] = UserFunctionVariable + _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = ( + UserFunctionVariable + ) _torch_name_rule_map = [ _manual_torch_name_rule_map, @@ -429,12 +432,15 @@ def fn(x): torch_non_c_binding_in_graph_functions, ] - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + ), ): # First adding the module to SKIP_DIRS so that it will be skipped by default. torch._dynamo.trace_rules.add(mod.__name__) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index ecba213ebddb68..8ae4c9e58343ce 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -266,7 +266,7 @@ def fn(x, rand1, rand2, rand3): self.assertEqual(rand2_1.getstate(), rand2_2.getstate()) self.assertEqual(rand3_1.getstate(), rand3_2.getstate()) - def test_random_object_overriden_methods(self): + def test_random_object_overridden_methods(self): # these will result in graph breaks, but we shouldn't crash def get_rng(): rand1 = random.Random(1) @@ -883,8 +883,10 @@ def fn(x, scaler): self.assertEqual(ref.device, res.device) -devices = ["cuda", "hpu"] -instantiate_device_type_tests(UnspecTestsDevice, globals(), only_for=devices) +devices = ["cuda", "hpu", "xpu"] +instantiate_device_type_tests( + UnspecTestsDevice, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index ee74fba237f792..c9ab3b781887dd 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -142,7 +142,7 @@ def break_it2(x): compilation_events = [arg[0][0] for arg in log_event.call_args_list] self.assertEqual(compilation_events[-1].num_graph_breaks, 2) - def test_frame_traced_hook(self): + def test_traced_code_query(self): try: from .utils import add, break_it except ImportError: @@ -150,31 +150,33 @@ def test_frame_traced_hook(self): traced_code_lists = [] - def get_traced_code(s): - nonlocal traced_code_lists - traced_code_lists.append(s) - def get_filenames(traced_code_lists): return [ [code.co_filename for code in code_list] for code_list in traced_code_lists ] + def my_backend(gm, example_inputs): + from torch._dynamo.utils import get_traced_code + + nonlocal traced_code_lists + traced_code_lists.append(get_traced_code()) + return gm.forward + utils_path = os.path.join(os.path.dirname(__file__), "utils.py") # === no inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return x * 2 x = torch.randn(3) traced_code_lists = [] fn(x) - # expect hook to be called once with this file self.assertEqual(get_filenames(traced_code_lists), [[__file__]]) # === successful inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return add(x) * 2 @@ -182,30 +184,28 @@ def fn(x): traced_code_lists = [] fn(x) utils_path = os.path.join(os.path.dirname(__file__), "utils.py") - # expect hook to be called once with both this file and file of inlined func - self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]]) + self.assertEqual(get_filenames(traced_code_lists), [[__file__, utils_path]]) # === graph break occurs during inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): - y = break_it(x) + z = x + 1 + y = break_it(z) return y * 2 x = torch.randn(3) traced_code_lists = [] fn(x) - # expect hook to be called twice; once for this file one for file of inlined func self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]]) # === empty graph === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return x x = torch.randn(3) traced_code_lists = [] fn(x) - # hook is not expected to be called at all for an empty graph self.assertEqual(traced_code_lists, []) @@ -593,9 +593,10 @@ def f(x): ) compilation_events = [] - with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch( - "torch._dynamo.utils.log_compilation_event" - ) as log_event: + with ( + dynamo_config.patch({"automatic_dynamic_shapes": False}), + mock.patch("torch._dynamo.utils.log_compilation_event") as log_event, + ): @torch.compile() def f(x): diff --git a/test/dynamo/test_view.py b/test/dynamo/test_view.py index 61b80f7bd8b09e..03b9ac5a9f81a4 100644 --- a/test/dynamo/test_view.py +++ b/test/dynamo/test_view.py @@ -33,6 +33,86 @@ def f(t, _n): t = torch.tensor([2, 4], dtype=torch.int32) f(t, 8) + def test_view_with_tensor_shape_params(self): + # Test for issue #156720: aten.view.default with tensor shape parameters + class TestModel(torch.nn.Module): + def forward(self, x, shape_params): + return torch.ops.aten.view.default(x, shape_params) + + x = torch.randn(24) + shape_params = [ + torch.tensor(2, dtype=torch.int32), + torch.tensor(3, dtype=torch.int32), + torch.tensor(4, dtype=torch.int32), + ] + + model = TestModel() + expected = model(x, shape_params) + + compiled_model = torch.compile(model, backend="eager") + result = compiled_model(x, shape_params) + + torch.testing.assert_close(result, expected) + + def test_tensor_view_with_tensor_shape_params(self): + # Test tensor.view() method with tensor shape parameters (list version) + class TestModel(torch.nn.Module): + def forward(self, x, shape_params): + return x.view(shape_params) + + x = torch.randn(24) + shape_params = ( + torch.tensor(2, dtype=torch.int32), + torch.tensor(3, dtype=torch.int32), + torch.tensor(4, dtype=torch.int32), + ) + + model = TestModel() + expected = model(x, shape_params) + + compiled_model = torch.compile(model, backend="eager") + result = compiled_model(x, shape_params) + + torch.testing.assert_close(result, expected) + + def test_tensor_view_with_tensor_args(self): + # Test tensor.view() method with individual tensor arguments + class TestModel(torch.nn.Module): + def forward(self, x, dim1, dim2, dim3): + return x.view(dim1, dim2, dim3) + + x = torch.randn(24) + dim1 = torch.tensor(2, dtype=torch.int32) + dim2 = torch.tensor(3, dtype=torch.int32) + dim3 = torch.tensor(4, dtype=torch.int32) + + model = TestModel() + expected = model(x, dim1, dim2, dim3) + + compiled_model = torch.compile(model, backend="eager") + result = compiled_model(x, dim1, dim2, dim3) + + torch.testing.assert_close(result, expected) + + def test_torch_reshape_with_tensor_shape_params(self): + # Test torch.reshape() function with tensor shape parameters + def test_fn(x, shape_params): + return torch.reshape(x, shape_params) + + x = torch.randn(24) + shape_params = [ + torch.tensor(2, dtype=torch.int32), + torch.tensor(3, dtype=torch.int32), + torch.tensor(4, dtype=torch.int32), + ] + + expected = test_fn(x, shape_params) + + compiled_fn = torch.compile(test_fn, backend="eager") + result = compiled_fn(x, shape_params) + + torch.testing.assert_close(result, expected) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_inheritance b/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_inheritance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_multi_arg b/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_multi_arg new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_no_arg b/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_no_arg new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_single_arg b/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_single_arg new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_setstate_refcount_no_crash b/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_setstate_refcount_no_crash new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_BaseException_instance b/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_BaseException_instance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_non_BaseException b/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_non_BaseException new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_string b/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_catch_string new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_raise_new_style_non_exception b/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_raise_new_style_non_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_raise_string b/test/dynamo_expected_failures/CPython313-test_baseexception-UsageTests.test_raise_string new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testAtanSign b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testAtanSign new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testAtanhSign b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testAtanhSign new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testTanhSign b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.testTanhSign new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs_overflows b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs_overflows new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_cmath_matches_math b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_cmath_matches_math new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_input_type b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_input_type new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isfinite b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isfinite new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isinf b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isinf new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isnan b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_isnan new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_phase b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_phase new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_polar b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_polar new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_polar_errno b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_polar_errno new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_rect b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_rect new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_specific_values b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_specific_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_user_object b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_user_object new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_asymmetry b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_asymmetry new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_near_zero b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_near_zero new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_special b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_special new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_values b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_complex_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_decimals b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_decimals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_eight_decimal_places b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_eight_decimal_places new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_fractions b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_fractions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_identical b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_identical new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_identical_infinite b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_identical_infinite new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_inf_ninf_nan b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_inf_ninf_nan new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_integers b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_integers new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_near_zero b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_near_zero new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_negative_tolerances b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_negative_tolerances new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_reject_complex_tolerances b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_reject_complex_tolerances new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_zero_tolerance b/test/dynamo_expected_failures/CPython313-test_cmath-IsCloseTests.test_zero_tolerance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_closing b/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_closing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_closing_error b/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_closing_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_instance_docs b/test/dynamo_expected_failures/CPython313-test_contextlib-ClosingTestCase.test_instance_docs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_attribs b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_attribs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_do_not_unchain_non_stopiteration_exceptions b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_do_not_unchain_non_stopiteration_exceptions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_doc_attrib b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_doc_attrib new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_except_pep479 b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_except_pep479 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_except_stopiter b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_except_stopiter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_traceback b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_traceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_trap_second_yield b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_trap_second_yield new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_trap_yield_after_throw b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_trap_yield_after_throw new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_wrap_runtimeerror b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_contextmanager_wrap_runtimeerror new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_instance_docstring_given_cm_docstring b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_instance_docstring_given_cm_docstring new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_nokeepref b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_nokeepref new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_param_errors b/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_param_errors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-FileContextTestCase.testWithOpen b/test/dynamo_expected_failures/CPython313-test_contextlib-FileContextTestCase.testWithOpen new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithBoundedSemaphore b/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithBoundedSemaphore new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithCondition b/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithCondition new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithLock b/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithLock new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithRLock b/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithRLock new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithSemaphore b/test/dynamo_expected_failures/CPython313-test_contextlib-LockContextTestCase.testWithSemaphore new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-NullcontextTestCase.test_nullcontext b/test/dynamo_expected_failures/CPython313-test_contextlib-NullcontextTestCase.test_nullcontext new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_enter b/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_enter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_exit_is_abstract b/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_exit_is_abstract new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_slots b/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_slots new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_structural_subclassing b/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_structural_subclassing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_reentrant b/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_reentrant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_simple b/test/dynamo_expected_failures/CPython313-test_contextlib-TestChdir.test_simple new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_contextdecorator_as_mixin b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_contextdecorator_as_mixin new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_contextdecorator_with_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_contextdecorator_with_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_decorating_method b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_decorating_method new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_decorator_with_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_decorator_with_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_typo_enter b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_typo_enter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_typo_exit b/test/dynamo_expected_failures/CPython313-test_contextlib-TestContextDecorator.test_typo_exit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_body_exception_suppress b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_body_exception_suppress new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_callback b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_callback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_close b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_close new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_dont_reraise_RuntimeError b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_dont_reraise_RuntimeError new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_enter_context b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_enter_context new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_enter_context_errors b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_enter_context_errors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_excessive_nesting b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_excessive_nesting new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining_reference b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining_reference new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining_suppress b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_chaining_suppress new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_explicit_none_context b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_explicit_none_context new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_non_suppressing b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_non_suppressing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_traceback b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_traceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_with_correct_context b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_with_correct_context new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_with_existing_context b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_exception_with_existing_context new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_raise b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_raise new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_suppress b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_exit_suppress new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_instance_bypass b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_instance_bypass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_instance_docs b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_instance_docs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_no_resources b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_no_resources new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_pop_all b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_pop_all new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_push b/test/dynamo_expected_failures/CPython313-test_contextlib-TestExitStack.test_push new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_cm_is_reentrant b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_cm_is_reentrant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_cm_is_reusable b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_cm_is_reusable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_enter_result_is_target b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_enter_result_is_target new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_instance_docs b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_instance_docs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_no_redirect_in_init b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_no_redirect_in_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_redirect_to_string_io b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStderr.test_redirect_to_string_io new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_cm_is_reentrant b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_cm_is_reentrant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_cm_is_reusable b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_cm_is_reusable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_enter_result_is_target b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_enter_result_is_target new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_instance_docs b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_instance_docs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_no_redirect_in_init b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_no_redirect_in_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_redirect_to_string_io b/test/dynamo_expected_failures/CPython313-test_contextlib-TestRedirectStdout.test_redirect_to_string_io new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_cm_is_reentrant b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_cm_is_reentrant new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exact_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exact_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exception_groups b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exception_groups new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exception_hierarchy b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_exception_hierarchy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_instance_docs b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_instance_docs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_multiple_exception_args b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_multiple_exception_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_args b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_result_from_enter b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_no_result_from_enter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_other_exception b/test/dynamo_expected_failures/CPython313-test_contextlib-TestSuppress.test_other_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-CAPITest.test_getitem_knownhash b/test/dynamo_expected_failures/CPython313-test_dict-CAPITest.test_getitem_knownhash new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_bad_key b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_bad_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_clear b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_contains b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_contains new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_fuzz b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_fuzz new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_maintains_tracking b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_maintains_tracking new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_noncompact b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_noncompact new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_contain_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_contain_use_after_free new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictitems_contains_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictitems_contains_use_after_free new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_mixed_set_operations b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_mixed_set_operations new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_items b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_eq b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_eq new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_equal_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_equal_operator_modifying_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_errors_in_view_containment_check b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_errors_in_view_containment_check new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_dict_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_dict_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_set_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_set_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_get b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_init_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_init_use_after_free new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_invalid_keyword_arguments b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_invalid_keyword_arguments new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_itemiterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_itemiterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_items b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_items_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_items_symmetric_difference new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_iterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_iterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys_contained b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys_contained new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_literal_constructor b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_literal_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_and_mutate b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_and_mutate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_operator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete_over_items b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete_over_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete_over_values b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_iteration_delete_over_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_lookup b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_lookup new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_object_set_item_single_instance_non_str_key b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_object_set_item_single_instance_non_str_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_oob_indexing_dictiter_iternextitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_oob_indexing_dictiter_iternextitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_pop b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reentrant_insertion b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reentrant_insertion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_repr b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_repr_deep b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_repr_deep new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_resize2 b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_resize2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_shared_shared_dicts b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_shared_shared_dicts new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reversed b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverseitemiterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverseitemiterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverseiterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverseiterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reversevaluesiterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reversevaluesiterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault_atomic b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault_atomic new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setitem_atomic_at_resize b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setitem_atomic_at_resize new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_del b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_del new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop_pending b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop_pending new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_popitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_to_generic_combinedtable b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_to_generic_combinedtable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_update b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_store_evilattr b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_store_evilattr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_str_nonstr b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_str_nonstr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_tuple_keyerror b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_tuple_keyerror new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_update b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_values b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_valuesiterator_pickling b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_valuesiterator_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_views_mapping b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_views_mapping new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else_mixed2 b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else_mixed2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_mixed1 b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_mixed1 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else_finally b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else_finally new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_finally b/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_finally new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-AssertionErrorTests.test_assertion_error_location b/test/dynamo_expected_failures/CPython313-test_exceptions-AssertionErrorTests.test_assertion_error_location new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-AssertionErrorTests.test_multiline_not_highlighted b/test/dynamo_expected_failures/CPython313-test_exceptions-AssertionErrorTests.test_multiline_not_highlighted new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_attributes b/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_attributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_getattr_has_name_and_obj b/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_getattr_has_name_and_obj new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_getattr_has_name_and_obj_for_method b/test/dynamo_expected_failures/CPython313-test_exceptions-AttributeErrorTests.test_getattr_has_name_and_obj_for_method new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testAttributes b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testAttributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testChainingAttrs b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testChainingAttrs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testExceptionCleanupState b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testExceptionCleanupState new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testInfiniteRecursion b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testInfiniteRecursion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testInvalidTraceback b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testInvalidTraceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testKeywordArgs b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testKeywordArgs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testMemoryErrorBigSource b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testMemoryErrorBigSource new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testRaising b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testRaising new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSettingException b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSettingException new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorMessage b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorMessage new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorMissingParens b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorMissingParens new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorOffset b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testSyntaxErrorOffset new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testWithTraceback b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.testWithTraceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_3114 b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_3114 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_MemoryError b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_MemoryError new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_WindowsError b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_WindowsError new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_assert_shadowing b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_assert_shadowing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_badisinstance b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_badisinstance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_errno_ENOTDIR b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_errno_ENOTDIR new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_error_offset_continuation_characters b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_error_offset_continuation_characters new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_cleanup_names b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_cleanup_names new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_cleanup_names2 b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_cleanup_names2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_target_in_nested_scope b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_target_in_nested_scope new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_with_doc b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_exception_with_doc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_close_cleanup_exc_state b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_close_cleanup_exc_state new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_del_cleanup_exc_state b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_del_cleanup_exc_state new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_finalizing_and_sys_exception b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_finalizing_and_sys_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking4 b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking4 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_next_cleanup_exc_state b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_next_cleanup_exc_state new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_send_cleanup_exc_state b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_send_cleanup_exc_state new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_throw_cleanup_exc_state b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_throw_cleanup_exc_state new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_delattr b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_delattr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_setattr b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_setattr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_setstate b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_invalid_setstate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_cleanup b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_cleanup new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_in_PyErr_PrintEx b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_in_PyErr_PrintEx new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_in_subinterp b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_in_subinterp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_subclasses b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_memory_error_subclasses new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_no_hang_on_context_chain_cycle2 b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_no_hang_on_context_chain_cycle2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_no_hang_on_context_chain_cycle3 b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_no_hang_on_context_chain_cycle3 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_notes b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_notes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_raise_does_not_create_context_chain_cycle b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_raise_does_not_create_context_chain_cycle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_error_cleanup b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_error_cleanup new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_in_except_handler b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_in_except_handler new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_exception b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_infinite_exception b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_infinite_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_with_no_memory b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_recursion_normalizing_with_no_memory new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_setstate b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_setstate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_trashcan_recursion b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_trashcan_recursion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unhandled b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unhandled new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_change_attributes b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_change_attributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_error_str_does_not_crash b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_error_str_does_not_crash new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_errors_no_object b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_errors_no_object new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unraisable b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unraisable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_yield_in_nested_try_excepts b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_yield_in_nested_try_excepts new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_attributes b/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_attributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_copy_pickle b/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_copy_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_non_str_argument b/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_non_str_argument new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_reset_attributes b/test/dynamo_expected_failures/CPython313-test_exceptions-ImportErrorTests.test_reset_attributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_gh_111654 b/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_gh_111654 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_issue45826 b/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_issue45826 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_issue45826_focused b/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_issue45826_focused new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_name_error_has_name b/test/dynamo_expected_failures/CPython313-test_exceptions-NameErrorTests.test_name_error_has_name new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_other_except b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_other_except new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_in_except b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_in_except new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_in_with_exit b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_in_with_exit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_simple b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_raise_simple new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_with b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_after_with new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_finally_except b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_finally_except new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_finally_normal b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_finally_normal new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_named_except b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_named_except new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_try b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_lineno_in_try new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_missing_lineno_shows_as_none b/test/dynamo_expected_failures/CPython313-test_exceptions-PEP626Tests.test_missing_lineno_shows_as_none new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_attributes_new_constructor b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_attributes_new_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_attributes_old_constructor b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_attributes_old_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_encodings b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_encodings new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_file_source b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_file_source new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_incorrect_constructor b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_incorrect_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_non_utf8 b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_non_utf8 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_range_of_offsets b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_range_of_offsets new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_string_source b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_string_source new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_subclass b/test/dynamo_expected_failures/CPython313-test_exceptions-SyntaxErrorTests.test_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-TestInvalidExceptionMatcher.test_except_star_invalid_exception_type b/test/dynamo_expected_failures/CPython313-test_exceptions-TestInvalidExceptionMatcher.test_except_star_invalid_exception_type new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-FormatFunctionsTestCase.test_getformat b/test/dynamo_expected_failures/CPython313-test_float-FormatFunctionsTestCase.test_getformat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_format b/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_format new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_format_testfile b/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_format_testfile new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_issue35560 b/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_issue35560 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_issue5864 b/test/dynamo_expected_failures/CPython313-test_float-FormatTestCase.test_issue5864 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_error_message b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_error_message new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_containment b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_containment new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_memoryview b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_memoryview new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_with_comma b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_with_comma new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatasratio b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatasratio new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatconversion b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatconversion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_hash b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_hash new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_hash_nan b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_hash_nan new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keyword_args b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keyword_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keywords_in_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_non_numeric_input_types b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_non_numeric_input_types new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_underscores b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_underscores new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_roundtrip b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_roundtrip new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_subclass b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-ReprTestCase.test_repr b/test/dynamo_expected_failures/CPython313-test_float-ReprTestCase.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_format_specials b/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_format_specials new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_matches_float_format b/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_matches_float_format new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generator_stop-TestPEP479.test_stopiteration_wrapping b/test/dynamo_expected_failures/CPython313-test_generator_stop-TestPEP479.test_stopiteration_wrapping new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_bad_exception b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_bad_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_gen_3_arg_deprecation_warning b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_gen_3_arg_deprecation_warning new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_return_stopiteration b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_return_stopiteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_return_tuple b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_return_tuple new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_stopiteration_error b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_stopiteration_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_tutorial_stopiteration b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_tutorial_stopiteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_frame_resurrect b/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_frame_resurrect new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_lambda_generator b/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_lambda_generator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_refcycle b/test/dynamo_expected_failures/CPython313-test_generators-FinalizationTest.test_refcycle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorCloseTest.test_close_releases_frame_locals b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorCloseTest.test_close_releases_frame_locals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorStackTraceTest.test_send_with_yield_from b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorStackTraceTest.test_send_with_yield_from new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorStackTraceTest.test_throw_with_yield_from b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorStackTraceTest.test_throw_with_yield_from new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_ag_frame_f_back b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_ag_frame_f_back new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_copy b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_cr_frame_f_back b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_cr_frame_f_back new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_gi_frame_f_back b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_gi_frame_f_back new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_handle_frame_object_in_creation b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_handle_frame_object_in_creation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_name b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_name new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_pickle b/test/dynamo_expected_failures/CPython313-test_generators-GeneratorTest.test_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ModifyUnderlyingIterableTest.test_modify_f_locals b/test/dynamo_expected_failures/CPython313-test_generators-ModifyUnderlyingIterableTest.test_modify_f_locals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ModifyUnderlyingIterableTest.test_new_gen_from_gi_code b/test/dynamo_expected_failures/CPython313-test_generators-ModifyUnderlyingIterableTest.test_new_gen_from_gi_code new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-SignalAndYieldFromTest.test_raise_and_yield_from b/test/dynamo_expected_failures/CPython313-test_generators-SignalAndYieldFromTest.test_raise_and_yield_from new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-YieldFromTests.test_generator_gi_yieldfrom b/test/dynamo_expected_failures/CPython313-test_generators-YieldFromTests.test_generator_gi_yieldfrom new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_denial_of_service_prevented_int_to_str b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_denial_of_service_prevented_int_to_str new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_denial_of_service_prevented_str_to_int b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_denial_of_service_prevented_str_to_int new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_disabled_limit b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_disabled_limit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_from_other_bases b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_from_other_bases new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_max_str_digits_is_per_interpreter b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_max_str_digits_is_per_interpreter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_power_of_two_bases_unlimited b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_power_of_two_bases_unlimited new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_denial_of_service_prevented_int_to_str b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_denial_of_service_prevented_int_to_str new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_denial_of_service_prevented_str_to_int b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_denial_of_service_prevented_str_to_int new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_disabled_limit b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_disabled_limit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_int_from_other_bases b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_int_from_other_bases new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_int_max_str_digits_is_per_interpreter b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_int_max_str_digits_is_per_interpreter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits_edge_cases b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits_edge_cases new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_power_of_two_bases_unlimited b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_power_of_two_bases_unlimited new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_sign_not_counted b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_sign_not_counted new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_underscores_ignored b/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_underscores_ignored new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_error_message b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_error_message new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_indexable b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_indexable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_memoryview b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_memoryview new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_returns_int_subclass b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_returns_int_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_int b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_int new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_keyword_args b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_keyword_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_non_numeric_input_types b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_non_numeric_input_types new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_underscores b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_underscores new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_misbehavior_error_path_from_str b/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_misbehavior_error_path_from_str new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_misbehavior_error_path_to_str b/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_misbehavior_error_path_to_str new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_roundtrip b/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_roundtrip new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_addmul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_append b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_append new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_clear b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructor_exception_handling b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructor_exception_handling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructors b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_fake b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_fake new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_order new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_copy b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count_index_remove_crashes b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count_index_remove_crashes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_deopt_from_append_list b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_deopt_from_append_list new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_equal_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_equal_operator_modifying_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_exhausted_iterator b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_exhausted_iterator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extend b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extend new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem_error b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitemoverwriteiter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iadd b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iadd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_imul b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_imul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_index b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_index new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_init b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_insert b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_insert new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iterator_pickle b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iterator_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keyword_args b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keyword_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keywords_in_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_index_modifing_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_index_modifing_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_lt_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_lt_operator_modifying_operand new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_no_comdat_folding b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_no_comdat_folding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_overflow new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pickle b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pop b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_preallocation b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_preallocation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_remove b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_remove new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repeat b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repeat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_deep b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_deep new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_mutate b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_mutate new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reverse b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reverse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reversed b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reversed_pickle b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reversed_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_set_subscript b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_set_subscript new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem_error b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setslice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_sort b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_sort new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_subscript b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_subscript new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_infinities b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_infinities new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_nan_results b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_nan_results new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_overflow b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_overflow new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_single_round b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_single_round new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_zero_result b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_fma_zero_result new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_random b/test/dynamo_expected_failures/CPython313-test_math-FMATests.test_random new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_decimals b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_decimals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_fractions b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_fractions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCeil b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCeil new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testComb b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testComb new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDist b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDist new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorial b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorial new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialNonIntegers b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialNonIntegers new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFloor b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFloor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFsum b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFsum new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testGcd b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testGcd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testHypot b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testHypot new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testHypotAccuracy b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testHypotAccuracy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testIsqrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testIsqrt new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog2 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog2Exact b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog2Exact new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPerm b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPerm new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRemainder b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRemainder new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSumProd b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSumProd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_issue39871 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_issue39871 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_lcm b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_lcm new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_mtestfile b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_mtestfile new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_prod b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_prod new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_accuracy b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_accuracy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_extended_precision_accuracy b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_extended_precision_accuracy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_stress b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_stress new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_testfile b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_testfile new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_trunc b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_trunc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem_hash_collision new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_detect_deletion_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_detect_deletion_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_override_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_468 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_468 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem_hash_collision new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_detect_deletion_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_detect_deletion_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_equality b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_equality new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_fromkeys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init_calls new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_delete_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key_in_dict_eq b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key_in_dict_eq new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24347 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24348 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24667 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators_pickling b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_key_change_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_key_change_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_merge_operator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_move_to_end new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_move_to_end_issue25406 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_ordered_dict_items_result_gc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_ordered_dict_items_result_gc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_overridden_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_override_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_pickle_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_pickle_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem_last new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reduce_not_too_fat b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reduce_not_too_fat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reference_loop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reference_loop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr_recursive_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_repr_recursive_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sizeof b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sizeof new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sizeof_exact b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sizeof_exact new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_views b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_views new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_weakref_list_is_not_traversed b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_weakref_list_is_not_traversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_yaml_linkage b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_yaml_linkage new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_468 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_468 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem_hash_collision new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_detect_deletion_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_detect_deletion_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_dict_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_equality b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_equality new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init_calls new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_delete_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key_in_dict_eq b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key_in_dict_eq new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24347 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24348 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24667 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators_pickling b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators_pickling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_key_change_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_key_change_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_merge_operator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_ordered_dict_items_result_gc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_ordered_dict_items_result_gc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_overridden_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_override_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_pickle_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_pickle_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem_last new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reduce_not_too_fat b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reduce_not_too_fat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reference_loop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reference_loop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr_recursive_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_repr_recursive_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sizeof b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sizeof new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sizeof_exact b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sizeof_exact new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_views b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_views new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_weakref_list_is_not_traversed b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_weakref_list_is_not_traversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_yaml_linkage b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_yaml_linkage new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictWithSlotsCopyingTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictWithSlotsCopyingTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_add_after_full b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_add_after_full new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_change_order_on_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_change_order_on_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_bool new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_len b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_len new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_468 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_468 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_abc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_abc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem_hash_collision new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_detect_deletion_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_detect_deletion_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_dict_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_equality b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_equality new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_fromkeys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init_calls new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue119004_attribute_error b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue119004_attribute_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24347 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24348 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24667 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_iterators_empty b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_iterators_empty new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_merge_operator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_move_to_end new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_move_to_end_issue25406 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_ordered_dict_items_result_gc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_ordered_dict_items_result_gc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_overridden_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_override_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_pickle_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_pickle_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_popitem_last new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reduce_not_too_fat b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reduce_not_too_fat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reference_loop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reference_loop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reinsert b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_reinsert new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr_recursive_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_repr_recursive_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_sizeof b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_sizeof new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_sorted_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_views b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_views new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_yaml_linkage b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_yaml_linkage new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_468 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_468 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_abc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_abc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem_hash_collision new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_detect_deletion_during_iteration b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_detect_deletion_during_iteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_delitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_dict_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_equality b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_equality new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_fromkeys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init_calls new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue119004_attribute_error b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue119004_attribute_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24347 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24348 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24667 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_iterators_empty b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_iterators_empty new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_merge_operator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_move_to_end new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_move_to_end_issue25406 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_ordered_dict_items_result_gc b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_ordered_dict_items_result_gc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_overridden_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_override_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_pickle_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_pickle_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_popitem_last new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reduce_not_too_fat b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reduce_not_too_fat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reference_loop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reference_loop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reinsert b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_reinsert new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr_recursive b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr_recursive new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr_recursive_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_repr_recursive_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_sizeof b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_sizeof new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_sorted_iterators new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_views b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_views new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_yaml_linkage b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_yaml_linkage new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictWithSlotsCopyingTests.test_copying b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictWithSlotsCopyingTests.test_copying new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_bool new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_len b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_len new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_add_after_full b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_add_after_full new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_change_order_on_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_change_order_on_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PySimpleLRUCacheTests.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestCause.test_class_cause_nonexception_result b/test/dynamo_expected_failures/CPython313-test_raise-TestCause.test_class_cause_nonexception_result new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestCause.test_erroneous_cause b/test/dynamo_expected_failures/CPython313-test_raise-TestCause.test_erroneous_cause new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_3611 b/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_3611 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_c_exception_raise b/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_c_exception_raise new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_context_manager b/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_context_manager new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_reraise_cycle_broken b/test/dynamo_expected_failures/CPython313-test_raise-TestContext.test_reraise_cycle_broken new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_assert_with_tuple_arg b/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_assert_with_tuple_arg new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_erroneous_exception b/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_erroneous_exception new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_new_returns_invalid_instance b/test/dynamo_expected_failures/CPython313-test_raise-TestRaise.test_new_returns_invalid_instance new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestRemovedFunctionality.test_strings b/test/dynamo_expected_failures/CPython313-test_raise-TestRemovedFunctionality.test_strings new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestRemovedFunctionality.test_tuples b/test/dynamo_expected_failures/CPython313-test_raise-TestRemovedFunctionality.test_tuples new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestTraceback.test_accepts_traceback b/test/dynamo_expected_failures/CPython313-test_raise-TestTraceback.test_accepts_traceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestTraceback.test_sets_traceback b/test/dynamo_expected_failures/CPython313-test_raise-TestTraceback.test_sets_traceback new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestTracebackType.test_attrs b/test/dynamo_expected_failures/CPython313-test_raise-TestTracebackType.test_attrs new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_raise-TestTracebackType.test_constructor b/test/dynamo_expected_failures/CPython313-test_raise-TestTracebackType.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_custom_displayhook b/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_custom_displayhook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_lost_displayhook b/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_lost_displayhook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_original_displayhook b/test/dynamo_expected_failures/CPython313-test_sys-DisplayHookTest.test_original_displayhook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_excepthook b/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_excepthook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_excepthook_bytes_filename b/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_excepthook_bytes_filename new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_original_excepthook b/test/dynamo_expected_failures/CPython313-test_sys-ExceptHookTest.test_original_excepthook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_asyncgen_hooks b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_asyncgen_hooks new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_changing_sys_stderr_and_removing_reference b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_changing_sys_stderr_and_removing_reference new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_default b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_default new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_errors b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_errors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_gc_head_size b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_gc_head_size new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_objecttypes b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_objecttypes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_pythontypes b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_pythontypes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_slots b/test/dynamo_expected_failures/CPython313-test_sys-SizeofTest.test_slots new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_attributes b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_attributes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_c_locale_surrogateescape b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_c_locale_surrogateescape new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_call_tracing b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_call_tracing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_clear_type_cache b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_clear_type_cache new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_current_exceptions b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_current_exceptions new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_current_frames b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_current_frames new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_debugmallocstats b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_debugmallocstats new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_dlopenflags b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_dlopenflags new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_executable b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_executable new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_exit b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_exit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_exit_codes_under_repl b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_exit_codes_under_repl new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getallocatedblocks b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getallocatedblocks new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getdefaultencoding b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getdefaultencoding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getfilesystemencoding b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getfilesystemencoding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getframe b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getframe new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getframemodulename b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getframemodulename new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getrecursionlimit b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getrecursionlimit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getwindowsversion b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_getwindowsversion new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_intern b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_intern new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_ioencoding b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_ioencoding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_is_finalizing b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_is_finalizing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_is_gil_enabled b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_is_gil_enabled new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_issue20602 b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_issue20602 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_no_duplicates_in_meta_path b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_no_duplicates_in_meta_path new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_orig_argv b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_orig_argv new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_posix_locale_surrogateescape b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_posix_locale_surrogateescape new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_recursionlimit_recovery b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_recursionlimit_recovery new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_refcount b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_refcount new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_setrecursionlimit b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_setrecursionlimit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_setrecursionlimit_to_depth b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_setrecursionlimit_to_depth new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_stdlib_dir b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_stdlib_dir new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_dynamically_allocated b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_dynamically_allocated new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_singleton b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_singleton new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_statically_allocated b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_subinterp_intern_statically_allocated new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_switchinterval b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_switchinterval new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_flags b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_flags new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_flags_no_instantiation b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_flags_no_instantiation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_getwindowsversion_no_instantiation b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_getwindowsversion_no_instantiation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_ignores_cleaning_up_user_data b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_ignores_cleaning_up_user_data new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_tracebacklimit b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_tracebacklimit new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_version_info_no_instantiation b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_sys_version_info_no_instantiation new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_custom_unraisablehook b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_custom_unraisablehook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_custom_unraisablehook_fail b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_custom_unraisablehook_fail new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_err b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_err new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_exception_qualname b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_exception_qualname new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_wrong_type b/test/dynamo_expected_failures/CPython313-test_sys-UnraisableHookTest.test_original_unraisablehook_wrong_type new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_addmul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_bug7466 b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_bug7466 new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_constructors b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_constructors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_fake b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_fake new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_order new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_count b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_count new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem_error b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitemoverwriteiter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_hash_exact b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_hash_exact new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_index b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_index new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_iterator_pickle b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_iterator_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keyword_args b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keyword_args new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keywords_in_subclass new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_no_comdat_folding b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_no_comdat_folding new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_pickle b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_repeat b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_repeat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_reversed_pickle b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_reversed_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_subscript b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_subscript new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_dynamic b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_dynamic new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_literals b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_literals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_subtypes b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_subtypes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_all b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_all new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_bool b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_bool new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_clear b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_constructor b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_constructor new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_contains b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_contains new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_copy b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_eq b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_eq new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_fromkeys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_get b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_get new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_init b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_items b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_items new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_keys b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_keys new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_len b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_len new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_missing b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_missing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_mutatingiteration b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_mutatingiteration new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_pop b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_popitem b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_popitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_read b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_read new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_repr b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_repr_deep b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_repr_deep new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_setdefault b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_setdefault new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_update b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_update new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_values b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_values new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_write b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_write new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_add_specials b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_add_specials new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_addmul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_append b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_append new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_clear b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_clear new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_constructor_exception_handling b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_constructor_exception_handling new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_constructors b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_constructors new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_fake b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_fake new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_order new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_copy b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_count b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_count new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delitem b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delslice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delslice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_exhausted_iterator b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_exhausted_iterator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extend b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extend new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extendedslicing b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extendedslicing new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_free_after_iterating b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitem_error b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitem_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitemoverwriteiter new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getslice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getslice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_iadd b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_iadd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_imul b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_imul new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_index b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_index new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_init b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_init new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_insert b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_insert new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_len b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_len new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_minmax b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_minmax new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedadd b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedadd new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedcmp b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedcmp new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_pickle b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_pickle new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_pop b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_pop new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_radd_specials b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_radd_specials new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_remove b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_remove new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repeat b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repeat new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repr b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repr new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repr_deep b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_repr_deep new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_reverse b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_reverse new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_reversed b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_reversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_set_subscript b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_set_subscript new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setitem b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setitem new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setitem_error b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setitem_error new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setslice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_setslice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_assign_iterator b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_assign_iterator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_type b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_type new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_sort b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_sort new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_subscript b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_subscript new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_truth b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_truth new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_userlist_copy b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_userlist_copy new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/TestSerialization.test_linear_relu_package_quantization_transforms b/test/dynamo_expected_failures/TestSerialization.test_linear_relu_package_quantization_transforms new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/TestTorch.test_print b/test/dynamo_expected_failures/TestTorch.test_print new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_container_iterator b/test/dynamo_skips/CPython313-test_dict-DictTest.test_container_iterator new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_dict_items_result_gc b/test/dynamo_skips/CPython313-test_dict-DictTest.test_dict_items_result_gc new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_dict_items_result_gc_reversed b/test/dynamo_skips/CPython313-test_dict-DictTest.test_dict_items_result_gc_reversed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_free_after_iterating b/test/dynamo_skips/CPython313-test_dict-DictTest.test_free_after_iterating new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_dynamic b/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_dynamic new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_literals b/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_literals new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_subtypes b/test/dynamo_skips/CPython313-test_dict-DictTest.test_track_subtypes new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/edge/CMakeLists.txt b/test/edge/CMakeLists.txt deleted file mode 100644 index 985b77202485d9..00000000000000 --- a/test/edge/CMakeLists.txt +++ /dev/null @@ -1,74 +0,0 @@ -cmake_minimum_required(VERSION 3.15) - -set(TORCH_ROOT ${CMAKE_CURRENT_LIST_DIR}/../..) -set(TEST_ROOT ${TORCH_ROOT}/test/edge) -set(OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/out) -file(GLOB_RECURSE all_python "${TORCH_ROOT}/torchgen/*.py") -include(${TORCH_ROOT}/cmake/public/utils.cmake) -append_cxx_flag_if_supported("-Wno-unused-private-field" CMAKE_CXX_FLAGS) - -# Generate unboxing kernels -set(GEN_COMMAND - Python::Interpreter -m torchgen.gen_executorch - --source-path=${TEST_ROOT} - --install-dir=${OUTPUT_DIRECTORY} - --tags-path=${TORCH_ROOT}/aten/src/ATen/native/tags.yaml - --aten-yaml-path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml - --use-aten-lib - --op-selection-yaml-path=${TEST_ROOT}/selected_operators.yaml - --custom-ops-yaml-path=${TEST_ROOT}/custom_ops.yaml - ) -set(GEN_COMMAND_sources - ${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp - ${OUTPUT_DIRECTORY}/RegisterCPUCustomOps.cpp - ${OUTPUT_DIRECTORY}/Functions.h - ${OUTPUT_DIRECTORY}/NativeFunctions.h - ${OUTPUT_DIRECTORY}/CustomOpsNativeFunctions.h - ) -message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}") -add_custom_command( - COMMENT "Generating sources" - OUTPUT ${GEN_COMMAND_sources} - COMMAND ${GEN_COMMAND} - DEPENDS - ${all_python} - ${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml - ${TORCH_ROOT}/aten/src/ATen/native/tags.yaml - ${TEST_ROOT}/templates/Functions.h - ${TEST_ROOT}/templates/NativeFunctions.h - ${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp - ${TEST_ROOT}/templates/RegisterDispatchKeyCustomOps.cpp - WORKING_DIRECTORY ${TORCH_ROOT} -) -add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources}) - -add_library(unbox_lib STATIC - ${GEN_COMMAND_sources} - ${TEST_ROOT}/operator_registry.cpp - ${TEST_ROOT}/custom_ops.cpp - ) -target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE}) -target_link_libraries(unbox_lib PUBLIC torch_cpu) -target_compile_definitions(unbox_lib PUBLIC USE_ATEN_LIB) - -add_executable(test_edge_op_registration - ${TEST_ROOT}/test_operator_registration.cpp - ${TEST_ROOT}/test_main.cpp - ) - -target_compile_definitions(test_edge_op_registration PRIVATE USE_GTEST) - -target_link_libraries(test_edge_op_registration PRIVATE gtest_main unbox_lib) -if((CMAKE_CXX_COMPILER_ID MATCHES "AppleClang") OR (APPLE AND CMAKE_CXX_COMPILER_ID MATCHES "Clang")) - target_link_options(test_edge_op_registration PRIVATE - "-Wl,-force_load,$" - ) -elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") - target_link_options(test_edge_op_registration PRIVATE - "-Wl,--whole-archive,$,--no-whole-archive" - ) -endif() -if(INSTALL_TEST) - set_target_properties(test_edge_op_registration PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS test_edge_op_registration DESTINATION bin) -endif() diff --git a/test/edge/Evalue.h b/test/edge/Evalue.h deleted file mode 100644 index 7038a7bdaa6701..00000000000000 --- a/test/edge/Evalue.h +++ /dev/null @@ -1,479 +0,0 @@ -#pragma once - -#include -/** - * WARNING: EValue is a class used by Executorch, for its boxed operators. It - * contains similar logic as `IValue` in PyTorch, by providing APIs to convert - * boxed values to unboxed values. - * - * It's mirroring a fbcode internal source file - * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h). - * - * The reason why we are mirroring this class, is to make sure we have CI job - * coverage on torchgen logic, given that torchgen is used for both Executorch - * and PyTorch. - * - * If any of the logic here needs to be changed, please update fbcode version of - * `Evalue.h` as well. These two versions will be merged as soon as Executorch - * is in OSS (hopefully by Q2 2023). - */ -namespace torch { -namespace executor { - -#define ET_CHECK_MSG TORCH_CHECK_MSG -#define EXECUTORCH_FORALL_TAGS(_) \ - _(None) \ - _(Tensor) \ - _(String) \ - _(Double) \ - _(Int) \ - _(Bool) \ - _(ListBool) \ - _(ListDouble) \ - _(ListInt) \ - _(ListTensor) \ - _(ListScalar) \ - _(ListOptionalTensor) - -enum class Tag : uint32_t { -#define DEFINE_TAG(x) x, - EXECUTORCH_FORALL_TAGS(DEFINE_TAG) -#undef DEFINE_TAG -}; - -struct EValue; - -template -struct evalue_to_const_ref_overload_return { - using type = T; -}; - -template <> -struct evalue_to_const_ref_overload_return { - using type = const at::Tensor&; -}; - -template -struct evalue_to_ref_overload_return { - using type = T; -}; - -template <> -struct evalue_to_ref_overload_return { - using type = at::Tensor&; -}; - -/* - * Helper class used to correlate EValues in the executor table, with the - * unwrapped list of the proper type. Because values in the runtime's values - * table can change during execution, we cannot statically allocate list of - * objects at deserialization. Imagine the serialized list says index 0 in the - * value table is element 2 in the list, but during execution the value in - * element 2 changes (in the case of tensor this means the TensorImpl* stored in - * the tensor changes). To solve this instead they must be created dynamically - * whenever they are used. - */ -template -class EValObjectList { - public: - EValObjectList() = default; - /* - * Wrapped_vals is a list of pointers into the values table of the runtime - * whose destinations correlate with the elements of the list, unwrapped_vals - * is a container of the same size whose serves as memory to construct the - * unwrapped vals. - */ - EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size) - : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} - /* - * Constructs and returns the list of T specified by the EValue pointers - */ - at::ArrayRef get() const; - - private: - // Source of truth for the list - at::ArrayRef wrapped_vals_; - // Same size as wrapped_vals - mutable T* unwrapped_vals_; -}; - -// Aggregate typing system similar to IValue only slimmed down with less -// functionality, no dependencies on atomic, and fewer supported types to better -// suit embedded systems (ie no intrusive ptr) -struct EValue { - union Payload { - // When in ATen mode at::Tensor is not trivially copyable, this nested union - // lets us handle tensor as a special case while leaving the rest of the - // fields in a simple state instead of requiring a switch on tag everywhere. - union TriviallyCopyablePayload { - TriviallyCopyablePayload() : as_int(0) {} - // Scalar supported through these 3 types - int64_t as_int; - double as_double; - bool as_bool; - // TODO(jakeszwe): convert back to pointers to optimize size of this - // struct - at::ArrayRef as_string; - at::ArrayRef as_int_list; - at::ArrayRef as_double_list; - at::ArrayRef as_bool_list; - EValObjectList as_tensor_list; - EValObjectList> as_list_optional_tensor; - } copyable_union; - - // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor* - // here. - at::Tensor as_tensor; - - Payload() {} - ~Payload() {} - }; - - // Data storage and type tag - Payload payload; - Tag tag; - - // Basic ctors and assignments - EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {} - - EValue(EValue&& rhs) noexcept : tag(rhs.tag) { - moveFrom(std::move(rhs)); - } - - EValue& operator=(EValue&& rhs) & noexcept { - if (&rhs == this) { - return *this; - } - - destroy(); - moveFrom(std::move(rhs)); - return *this; - } - - EValue& operator=(EValue const& rhs) & { - // Define copy assignment through copy ctor and move assignment - *this = EValue(rhs); - return *this; - } - - ~EValue() { - destroy(); - } - - /****** None Type ******/ - EValue() : tag(Tag::None) { - payload.copyable_union.as_int = 0; - } - - bool isNone() const { - return tag == Tag::None; - } - - /****** Int Type ******/ - /*implicit*/ EValue(int64_t i) : tag(Tag::Int) { - payload.copyable_union.as_int = i; - } - - bool isInt() const { - return tag == Tag::Int; - } - - int64_t toInt() const { - ET_CHECK_MSG(isInt(), "EValue is not an int."); - return payload.copyable_union.as_int; - } - - /****** Double Type ******/ - /*implicit*/ EValue(double d) : tag(Tag::Double) { - payload.copyable_union.as_double = d; - } - - bool isDouble() const { - return tag == Tag::Double; - } - - double toDouble() const { - ET_CHECK_MSG(isDouble(), "EValue is not a Double."); - return payload.copyable_union.as_double; - } - - /****** Bool Type ******/ - /*implicit*/ EValue(bool b) : tag(Tag::Bool) { - payload.copyable_union.as_bool = b; - } - - bool isBool() const { - return tag == Tag::Bool; - } - - bool toBool() const { - ET_CHECK_MSG(isBool(), "EValue is not a Bool."); - return payload.copyable_union.as_bool; - } - - /****** Scalar Type ******/ - /// Construct an EValue using the implicit value of a Scalar. - /*implicit*/ EValue(at::Scalar s) { - if (s.isIntegral(false)) { - tag = Tag::Int; - payload.copyable_union.as_int = s.to(); - } else if (s.isFloatingPoint()) { - tag = Tag::Double; - payload.copyable_union.as_double = s.to(); - } else if (s.isBoolean()) { - tag = Tag::Bool; - payload.copyable_union.as_bool = s.to(); - } else { - ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized."); - } - } - - bool isScalar() const { - return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool; - } - - at::Scalar toScalar() const { - // Convert from implicit value to Scalar using implicit constructors. - - if (isDouble()) { - return toDouble(); - } else if (isInt()) { - return toInt(); - } else if (isBool()) { - return toBool(); - } else { - ET_CHECK_MSG(false, "EValue is not a Scalar."); - return c10::Scalar(); - } - } - - /****** Tensor Type ******/ - /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) { - // When built in aten mode, at::Tensor has a non trivial constructor - // destructor, so regular assignment to a union field is UB. Instead we must - // go through placement new (which causes a refcount bump). - new (&payload.as_tensor) at::Tensor(t); - } - - bool isTensor() const { - return tag == Tag::Tensor; - } - - at::Tensor toTensor() && { - ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); - return std::move(payload.as_tensor); - } - - at::Tensor& toTensor() & { - ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); - return payload.as_tensor; - } - - const at::Tensor& toTensor() const& { - ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); - return payload.as_tensor; - } - - /****** String Type ******/ - /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) { - payload.copyable_union.as_string = at::ArrayRef(s, size); - } - - bool isString() const { - return tag == Tag::String; - } - - std::string_view toString() const { - ET_CHECK_MSG(isString(), "EValue is not a String."); - return std::string_view( - payload.copyable_union.as_string.data(), - payload.copyable_union.as_string.size()); - } - - /****** Int List Type ******/ - /*implicit*/ EValue(at::ArrayRef i) : tag(Tag::ListInt) { - payload.copyable_union.as_int_list = i; - } - - bool isIntList() const { - return tag == Tag::ListInt; - } - - at::ArrayRef toIntList() const { - ET_CHECK_MSG(isIntList(), "EValue is not an Int List."); - return payload.copyable_union.as_int_list; - } - - /****** Bool List Type ******/ - /*implicit*/ EValue(at::ArrayRef b) : tag(Tag::ListBool) { - payload.copyable_union.as_bool_list = b; - } - - bool isBoolList() const { - return tag == Tag::ListBool; - } - - at::ArrayRef toBoolList() const { - ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List."); - return payload.copyable_union.as_bool_list; - } - - /****** Double List Type ******/ - /*implicit*/ EValue(at::ArrayRef d) : tag(Tag::ListDouble) { - payload.copyable_union.as_double_list = d; - } - - bool isDoubleList() const { - return tag == Tag::ListDouble; - } - - at::ArrayRef toDoubleList() const { - ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List."); - return payload.copyable_union.as_double_list; - } - - /****** Tensor List Type ******/ - /*implicit*/ EValue(EValObjectList t) : tag(Tag::ListTensor) { - payload.copyable_union.as_tensor_list = t; - } - - bool isTensorList() const { - return tag == Tag::ListTensor; - } - - at::ArrayRef toTensorList() const { - ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List."); - return payload.copyable_union.as_tensor_list.get(); - } - - /****** List Optional Tensor Type ******/ - /*implicit*/ EValue(EValObjectList> t) - : tag(Tag::ListOptionalTensor) { - payload.copyable_union.as_list_optional_tensor = t; - } - - bool isListOptionalTensor() const { - return tag == Tag::ListOptionalTensor; - } - - at::ArrayRef> toListOptionalTensor() { - return payload.copyable_union.as_list_optional_tensor.get(); - } - - /****** ScalarType Type ******/ - at::ScalarType toScalarType() const { - ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); - return static_cast(payload.copyable_union.as_int); - } - - /****** MemoryFormat Type ******/ - at::MemoryFormat toMemoryFormat() const { - ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); - return static_cast(payload.copyable_union.as_int); - } - - template - T to() &&; - - template - typename evalue_to_ref_overload_return::type to() &; - - /** - * Converts the EValue to an optional object that can represent both T and - * an uninitialized state. - */ - template - inline std::optional toOptional() { - if (this->isNone()) { - return std::nullopt; - } - return this->to(); - } - - private: - // Pre cond: the payload value has had its destructor called - void clearToNone() noexcept { - payload.copyable_union.as_int = 0; - tag = Tag::None; - } - - // Shared move logic - void moveFrom(EValue&& rhs) noexcept { - if (rhs.isTensor()) { - new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); - rhs.payload.as_tensor.~Tensor(); - } else { - payload.copyable_union = rhs.payload.copyable_union; - } - tag = rhs.tag; - rhs.clearToNone(); - } - - // Destructs stored tensor if there is one - void destroy() { - // Necessary for ATen tensor to refcount decrement the intrusive_ptr to - // tensorimpl that got a refcount increment when we placed it in the evalue, - // no-op if executorch tensor #ifdef could have a - // minor performance bump for a code maintainability hit - if (isTensor()) { - payload.as_tensor.~Tensor(); - } else if (isTensorList()) { - for (auto& tensor : toTensorList()) { - tensor.~Tensor(); - } - } else if (isListOptionalTensor()) { - for (auto& optional_tensor : toListOptionalTensor()) { - optional_tensor.~optional(); - } - } - } - - EValue(const Payload& p, Tag t) : tag(t) { - if (isTensor()) { - new (&payload.as_tensor) at::Tensor(p.as_tensor); - } else { - payload.copyable_union = p.copyable_union; - } - } -}; - -#define EVALUE_DEFINE_TO(T, method_name) \ - template <> \ - inline evalue_to_ref_overload_return::type EValue::to()& { \ - return static_cast(this->method_name()); \ - } - -template <> -inline at::Tensor& EValue::to() & { - return this->toTensor(); -} - -EVALUE_DEFINE_TO(at::Scalar, toScalar) -EVALUE_DEFINE_TO(int64_t, toInt) -EVALUE_DEFINE_TO(bool, toBool) -EVALUE_DEFINE_TO(double, toDouble) -EVALUE_DEFINE_TO(std::string_view, toString) -EVALUE_DEFINE_TO(at::ScalarType, toScalarType) -EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat) -EVALUE_DEFINE_TO(std::optional, toOptional) -EVALUE_DEFINE_TO(at::ArrayRef, toIntList) -EVALUE_DEFINE_TO( - std::optional>, - toOptional>) -EVALUE_DEFINE_TO( - std::optional>, - toOptional>) -EVALUE_DEFINE_TO(at::ArrayRef>, toListOptionalTensor) -EVALUE_DEFINE_TO(at::ArrayRef, toDoubleList) -#undef EVALUE_DEFINE_TO - -template -at::ArrayRef EValObjectList::get() const { - for (size_t i = 0; i < wrapped_vals_.size(); i++) { - unwrapped_vals_[i] = wrapped_vals_[i]->template to(); - } - return at::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; -} - -} // namespace executor -} // namespace torch diff --git a/test/edge/custom_ops.cpp b/test/edge/custom_ops.cpp deleted file mode 100644 index cce09841127a55..00000000000000 --- a/test/edge/custom_ops.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include - -namespace custom { -namespace native { -at::Tensor& add_3_out(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c, at::Tensor& out) { - out = a.add(b).add(c); - return out; -} -} -} diff --git a/test/edge/custom_ops.yaml b/test/edge/custom_ops.yaml deleted file mode 100644 index 2ff2db88f97372..00000000000000 --- a/test/edge/custom_ops.yaml +++ /dev/null @@ -1,4 +0,0 @@ -- func: custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: custom::add_3_out diff --git a/test/edge/event_tracer.h b/test/edge/event_tracer.h deleted file mode 100644 index 9a62df3f522c3f..00000000000000 --- a/test/edge/event_tracer.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -#pragma once - -namespace torch { -namespace executor { - -typedef uint32_t AllocatorID; -typedef int32_t ChainID; -typedef uint32_t DebugHandle; - -/** - * EventTracer is a class that users can inherit and implement to - * log/serialize/stream etc. the profiling and debugging events that are - * generated at runtime for a model. An example of this is the ETDump - * implementation in the SDK codebase that serializes these events to a - * flatbuffer. - */ -class EventTracer {}; - -struct EventTracerEntry {}; - -} // namespace executor -} // namespace torch diff --git a/test/edge/event_tracer_hooks.h b/test/edge/event_tracer_hooks.h deleted file mode 100644 index 086eae36ac88d2..00000000000000 --- a/test/edge/event_tracer_hooks.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -/** - * @file - * - * This file contains the hooks that are inserted across various parts of the - * core runtime code to call into the EventTracer class for logging of profiling - * and debugging events. Any calls made to the EventTracer from the runtime must - * be made via these hooks. - * Users shouldn't directly add these hooks in their code and it's meant only - * for usage in ExecuTorch internal code. - * - * The benefit of defining these hooks is that we can easily control whether or - * not we want to compile in the EventTracer code based on the status of the - * ET_EVENT_TRACER_ENABLED flag. - */ - -namespace torch { -namespace executor { -namespace internal { - -/** - * This class enables scope based profiling where needed using RAII. - * Profiling will be started when the object is created and will end - * when the object goes out of scope. - */ -class EventTracerProfileScope final { - public: - EventTracerProfileScope(EventTracer* event_tracer, const char* name) {}; - - ~EventTracerProfileScope() {}; - - private: - EventTracer* event_tracer_; - EventTracerEntry event_entry_; -}; - -/** - * This class enables scope based profiling where needed using RAII. - * Profiling will be started when the object is created and will end - * when the object goes out of scope. - */ -class EventTracerProfileOpScope final { - public: - EventTracerProfileOpScope(EventTracer* event_tracer, const char* name) {}; - - ~EventTracerProfileOpScope() {}; - - private: - EventTracer* event_tracer_; - EventTracerEntry event_entry_; -}; - -/** - * This class helps us set and then clear out the chain id and debug handle - * values stored in the event tracer class using RAII. This is typically called - * in the executor loop before entering the codegen layer to configure the chain - * id and debug handle of the current instruction being executed. - * After we return from the kernel execution we can then reset the chain id and - * debug handle to defaults when this object goes out of scope. - */ -class EventTracerProfileInstructionScope final { - public: - EventTracerProfileInstructionScope( - EventTracer* event_tracer, - ChainID chain_idx, - DebugHandle debug_handle) {}; - - ~EventTracerProfileInstructionScope() {}; - - private: - EventTracer* event_tracer_; -}; - -void event_tracer_log_evalue(EventTracer* event_tracer, EValue& evalue) { - (void)evalue; -} - -} // namespace internal -} // namespace executor -} // namespace torch diff --git a/test/edge/kernel_runtime_context.h b/test/edge/kernel_runtime_context.h deleted file mode 100644 index 74f6914a871840..00000000000000 --- a/test/edge/kernel_runtime_context.h +++ /dev/null @@ -1,44 +0,0 @@ -#pragma once - -#include "event_tracer.h" - -namespace torch { -namespace executor { - -/** - * Bucket type abstraction that contains many elements of runtime state that - * a kernel author may want available, but would otherwise be unable to access. - * - * Forwarded along to all operators when running in lean mode. NOTE: Will not be - * forwarded to operators if running in ATen mode as those operators do not - * expect to receive a KernelRuntimeContext and would not use it. - * - * This includes things like setting an error state, a scratch allocator for - * operators that need more then constant space, and a TensorResizer for dynamic - * shape tensors allowing programs to be more flexible with Tensor shape. - */ -class KernelRuntimeContext { - public: - /** - * Construct a new kernel runtime context along with an optional event tracer. - */ - KernelRuntimeContext(EventTracer* event_tracer = nullptr) - : event_tracer_(event_tracer) {} - - /** - * INTERNAL ONLY - * - * Returns a pointer to an instance of EventTracer to do profiling/debugging - * logging inside the codegen layer. This is only for internal usage inside - * the codegen layer and users should not be accessing this. - */ - EventTracer* internal_event_tracer() { - return event_tracer_; - } - - private: - EventTracer* event_tracer_; -}; - -} // namespace executor -} // namespace torch diff --git a/test/edge/operator_registry.cpp b/test/edge/operator_registry.cpp deleted file mode 100644 index 765afa66e7a19d..00000000000000 --- a/test/edge/operator_registry.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include -#include - -namespace torch { -namespace executor { - -KernelRegistry& getKernelRegistry() { - static KernelRegistry kernel_registry; - return kernel_registry; -} - -bool register_kernels(const ArrayRef& kernels) { - return getKernelRegistry().register_kernels(kernels); -} - -bool KernelRegistry::register_kernels( - const ArrayRef& kernels) { - for (const auto& kernel : kernels) { - this->kernels_map_[kernel.name_] = kernel.kernel_; - } - return true; -} - -bool hasKernelFn(const char* name) { - return getKernelRegistry().hasKernelFn(name); -} - -bool KernelRegistry::hasKernelFn(const char* name) { - auto kernel = this->kernels_map_.find(name); - return kernel != this->kernels_map_.end(); -} - -KernelFunction& getKernelFn(const char* name) { - return getKernelRegistry().getKernelFn(name); -} - -KernelFunction& KernelRegistry::getKernelFn(const char* name) { - auto kernel = this->kernels_map_.find(name); - TORCH_CHECK_MSG(kernel != this->kernels_map_.end(), "Kernel not found!"); - return kernel->second; -} - - -} // namespace executor -} // namespace torch diff --git a/test/edge/operator_registry.h b/test/edge/operator_registry.h deleted file mode 100644 index 3fd1708a8b715f..00000000000000 --- a/test/edge/operator_registry.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "Evalue.h" -#include "kernel_runtime_context.h" - -#include - -namespace torch { -namespace executor { - -using KernelFunction = std::function; - -template -using ArrayRef = at::ArrayRef; - -#define EXECUTORCH_SCOPE_PROF(x) - -struct Kernel { - const char* name_; - KernelFunction kernel_; - - Kernel() = default; - - /** - * We are doing a copy of the string pointer instead of duplicating the string - * itself, we require the lifetime of the kernel name to be at least as long - * as the kernel registry. - */ - explicit Kernel(const char* name, KernelFunction func) - : name_(name), kernel_(func) {} -}; - -/** - * See KernelRegistry::hasKernelFn() - */ -bool hasKernelFn(const char* name); - -/** - * See KernelRegistry::getKernelFn() - */ -KernelFunction& getKernelFn(const char* name); - - -[[nodiscard]] bool register_kernels(const ArrayRef&); - -struct KernelRegistry { - public: - KernelRegistry() : kernelRegSize_(0) {} - - bool register_kernels(const ArrayRef&); - - /** - * Checks whether an kernel with a given name is registered - */ - bool hasKernelFn(const char* name); - - /** - * Checks whether an kernel with a given name is registered - */ - KernelFunction& getKernelFn(const char* name); - - private: - std::map kernels_map_; - uint32_t kernelRegSize_; -}; - -} // namespace executor -} // namespace torch diff --git a/test/edge/selected_operators.yaml b/test/edge/selected_operators.yaml deleted file mode 100644 index 70545ae216f65c..00000000000000 --- a/test/edge/selected_operators.yaml +++ /dev/null @@ -1,463 +0,0 @@ -build_features: [] -custom_classes: [] -include_all_non_op_selectives: false -include_all_operators: false -kernel_metadata: {} -et_kernel_metadata: - custom::add_3.out: - - v1/6;0,1,2,3|6;0,1,2,3|6;0,1,2,3 - - v1/3;0,1,2,3|3;0,1,2,3|3;0,1,2,3 - aten::add.out: - - v1/6;0,1,2,3|6;0,1,2,3|6;0,1,2,3 - - v1/3;0,1,2,3|3;0,1,2,3|3;0,1,2,3 -operators: - aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::_reshape_alias_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::_softmax.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::_to_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::_unique2.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::add.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::addmm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::avg_pool2d.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::baddbmm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::bitwise_and.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::bmm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::cat.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::clamp.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::clone.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::constant_pad_nd.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::conv1d.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::convolution.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::cumsum.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::detach_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::div.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::embedding.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::eq.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::eq.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::exp.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::expand_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::floor_divide.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::gelu.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::grid_sampler_2d.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::gt.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::index.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::index_put.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::index_select.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::leaky_relu.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::linalg_inv_ex.inverse: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::logit.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::masked_fill.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::max.unary_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::max_pool2d_with_indices.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::mean.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::minimum.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::mm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::mul.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::native_batch_norm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::native_layer_norm.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::ne.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::nonzero.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::permute_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::pixel_shuffle.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::relu.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::remainder.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::repeat.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::round.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::rsub.Scalar_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::select_copy.int_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::sigmoid.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::slice_copy.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::softplus.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::sort.values: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::split_copy.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::split_with_sizes_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::stack.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::sub.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::sum.IntList_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::tanh.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::topk.values: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::transpose_copy.int_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::unbind_copy.int_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::unsafe_split.Tensor_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::unsqueeze_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::upsample_bilinear2d.vec_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::upsample_nearest2d.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::upsample_nearest2d.vec_out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::view_copy.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - aten::zeros_like.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true - custom::add_3.out: - debug_info: - - functions.yaml - include_all_overloads: false - is_root_operator: true - is_used_for_training: true diff --git a/test/edge/templates/Functions.h b/test/edge/templates/Functions.h deleted file mode 100644 index 5355b3890f8cfa..00000000000000 --- a/test/edge/templates/Functions.h +++ /dev/null @@ -1,25 +0,0 @@ -// clang-format off -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -// ${generated_comment} - -${static_dispatch_extra_headers} - -namespace torch { -namespace executor { - -${Functions_declarations} - -} // namespace executor -} // namespace torch diff --git a/test/edge/templates/NativeFunctions.h b/test/edge/templates/NativeFunctions.h deleted file mode 100644 index c71a4ea2ec0c82..00000000000000 --- a/test/edge/templates/NativeFunctions.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -// ${generated_comment} - -#ifdef TORCH_ASSERT_NO_OPERATORS -#error This change adds a dependency on native_functions.yaml, \ - meaning the file will need to be re-compiled every time an operator \ - is changed or added. Consider if your change would be better placed in \ - another file, or if a more specific header might achieve the same goal. \ - See NOTE: [Tensor vs. TensorBase] -#endif - -#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) -#error This change adds a dependency on all pytorch operators, meaning the \ - file will need to be re-compiled every time an operator is changed or added. \ - Consider including a specific operator from \ - and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -${nativeFunctions_declarations} diff --git a/test/edge/templates/RegisterCodegenUnboxedKernels.cpp b/test/edge/templates/RegisterCodegenUnboxedKernels.cpp deleted file mode 100644 index 40c6779e93951d..00000000000000 --- a/test/edge/templates/RegisterCodegenUnboxedKernels.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include -#include -#include "${fn_header}" // Generated Function import headers - -namespace torch { -namespace executor { - -using namespace internal; - -namespace { -using KernelArrayRef = ::at::ArrayRef<::torch::executor::Kernel>; - -static Kernel kernels_to_register[] = { - ${unboxed_kernels} // Generated operators -}; - -// Explicitly convert to ArrayRef, so that the API can take an empty C array of -// Kernels. -static KernelArrayRef kernel_array_ref( - kernels_to_register, - kernels_to_register + sizeof(kernels_to_register) / sizeof(Kernel)); - -// Return value not used. Keep the static variable assignment to register -// operators in static initialization time. -static auto success_with_kernel_reg = register_kernels(kernel_array_ref); -} // namespace -} // namespace executor -} // namespace torch diff --git a/test/edge/templates/RegisterDispatchKeyCustomOps.cpp b/test/edge/templates/RegisterDispatchKeyCustomOps.cpp deleted file mode 100644 index 14c3d085f9323f..00000000000000 --- a/test/edge/templates/RegisterDispatchKeyCustomOps.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// clang-format off -// Generated code for registering custom operators into the dispatcher. - -#include -#include - -$ops_headers - -namespace torch { -namespace executor { -namespace function { - - -${dispatch_anonymous_definitions} - -// All out variants ops -${static_init_dispatch_registrations} - -namespace ${dispatch_namespace} -{ - ${dispatch_namespaced_definitions} - -} // namespace ${dispatch_namespace} - -} // namespace function -} // namespace executor -} // namespace torch diff --git a/test/edge/templates/RegisterKernels.h b/test/edge/templates/RegisterKernels.h deleted file mode 100644 index 3c7ecff50b517b..00000000000000 --- a/test/edge/templates/RegisterKernels.h +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// ${generated_comment} -// Exposing an API for registering all kernels at once. -#include -#include -#include -#include - -namespace torch { -namespace executor { - -Error register_all_kernels(); - -} // namespace executor -} // namespace torch diff --git a/test/edge/templates/RegisterSchema.cpp b/test/edge/templates/RegisterSchema.cpp deleted file mode 100644 index f2ba92a4305fc8..00000000000000 --- a/test/edge/templates/RegisterSchema.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// ${generated_comment} -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include - -namespace at { -TORCH_LIBRARY_FRAGMENT(aten, m) { - ${aten_schema_registrations}; -} -$schema_registrations -} // namespace at diff --git a/test/edge/test_main.cpp b/test/edge/test_main.cpp deleted file mode 100644 index 5c5cabccaaaa0c..00000000000000 --- a/test/edge/test_main.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -std::string add_negative_flag(const std::string& flag) { - std::string filter = ::testing::GTEST_FLAG(filter); - if (filter.find('-') == std::string::npos) { - filter.push_back('-'); - } else { - filter.push_back(':'); - } - filter += flag; - return filter; -} -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); - - return RUN_ALL_TESTS(); -} diff --git a/test/edge/test_operator_registration.cpp b/test/edge/test_operator_registration.cpp deleted file mode 100644 index eed4c31c70c8a4..00000000000000 --- a/test/edge/test_operator_registration.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "kernel_runtime_context.h" -#include "operator_registry.h" - -#include - -namespace torch { -namespace executor { - -// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) -TEST(OperatorRegistrationTest, Add) { - EValue values[4]; - values[0] = EValue(at::ones({2, 3})); - values[1] = EValue(at::ones({2, 3})); - values[2] = EValue(int64_t(1)); - values[3] = EValue(at::zeros({2, 3})); - ASSERT_TRUE(hasKernelFn("aten::add.out")); - auto op = getKernelFn("aten::add.out"); - - EValue* kernel_values[4]; - for (size_t i = 0; i < 4; i++) { - kernel_values[i] = &values[i]; - } - KernelRuntimeContext context{}; - op(context, kernel_values); - at::Tensor expected = at::ones({2, 3}); - expected = at::fill(expected, 2); - ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); - -} - -// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!) -TEST(OperatorRegistrationTest, CustomAdd3) { - EValue values[4]; - values[0] = EValue(at::ones({2, 3})); - values[1] = EValue(at::ones({2, 3})); - values[2] = EValue(at::ones({2, 3})); - values[3] = EValue(at::zeros({2, 3})); - ASSERT_TRUE(hasKernelFn("custom::add_3.out")); - auto op = getKernelFn("custom::add_3.out"); - - EValue* kernel_values[4]; - for (size_t i = 0; i < 4; i++) { - kernel_values[i] = &values[i]; - } - KernelRuntimeContext context{}; - op(context, kernel_values); - at::Tensor expected = at::ones({2, 3}); - expected = at::fill(expected, 3); - ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); - -} -} // namespace executor -} // namespace torch diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index e98eb91de8b398..042959c22cd4a7 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1094,6 +1094,8 @@ aten::randint.low_generator_out aten::randint.low_out aten::randint.out aten::randint_like +aten::randint_like.Tensor +aten::randint_like.Tensor_out aten::randint_like.low_dtype aten::randint_like.low_dtype_out aten::randint_like.out diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 34e82d8e9e1076..6cf819958fccf5 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -181,9 +181,12 @@ def forward(self, a, b): self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) - with torch._library.fake_profile.unsafe_generate_fake_kernels( - report.op_profiles - ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): + with ( + torch._library.fake_profile.unsafe_generate_fake_kernels( + report.op_profiles + ), + FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()), + ): torch.ops.mylib.foo8(*new_inp) # Existing registration has been updated to match the new @@ -319,11 +322,7 @@ def forward(self, x): ep = draft_export(M(), (torch.tensor([938]),)) report = ep._report - self.assertEqual(len(report.failures), 1) - self.assertEqual( - report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR - ) - self.assertEqual(report.failures[0].data["expr"], "Eq(Mod(10, 2*u1), 0)") + self.assertEqual(len(report.failures), 0) def test_dedup_data_dependent_failure(self): class M(torch.nn.Module): @@ -668,6 +667,36 @@ def forward(self, x, y): package_path=f.name, ) + @unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(0).total_memory < 2**28, + "Requires 16 MB GPU memory to pass the test; setting it higher to catch violations", + ) + def test_cuda_memory_usage(self): + # This used to OOM + class Foo(torch.nn.Module): + def forward(self, x): + for _ in range(100): + x = x + 1e-3 + return x + + # measure base usage + device = torch.device("cuda:0") + torch.cuda.reset_peak_memory_stats() + base_usage = torch.cuda.memory_allocated(device) + + # usage with input tensor allocated + x = torch.randn(2**10, 2**10).to(device) + x_usage = torch.cuda.memory_allocated(device) + + # draft export peak memory usage + draft_export(Foo(), (x,), strict=False) + peak_mem_usage = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"] + + # right now it's actually exactly 4x; + # I guess original tensor, 2 tensors per add op, 1 for clone stored in node.meta["val"] + self.assertTrue((peak_mem_usage - base_usage) <= (x_usage - base_usage) * 4.0) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index e843e049efe6a6..641dd586edb596 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -9,7 +9,6 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._functorch.aot_autograd import aot_export_module from torch.export import export, export_for_training -from torch.export._trace import _convert_ts_to_export_experimental from torch.export.experimental import _export_forward_backward, _sticky_export from torch.export.graph_signature import OutputKind from torch.testing import FileCheck @@ -17,93 +16,6 @@ @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") class TestExperiment(TestCase): - def test_torchscript_module_export(self): - class M(torch.nn.Module): - def forward(self, x): - return x.cos() + x.sin() - - model_to_trace = M() - inps = (torch.randn(4, 4),) - traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) - - exported_module = _convert_ts_to_export_experimental( - traced_module_by_torchscript, inps - ) - - self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps))) - - def test_torchscript_module_export_single_input(self): - class M(torch.nn.Module): - def forward(self, x): - return x.cos() + x.sin() - - model_to_trace = M() - inps = torch.randn(4, 4) - traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) - - exported_module = _convert_ts_to_export_experimental( - traced_module_by_torchscript, inps - ) - - self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps))) - - def test_torchscript_module_export_various_inputs_with_annotated_input_names(self): - def _check_equality_and_annotations(m_func, inps): - # Original module. - model_to_trace = m_func() - - # ExportedProgram from TorchScript module. - traced_module_by_torchscript = torch.jit.trace( - m_func(), example_inputs=inps - ) - exported_module = _convert_ts_to_export_experimental( - traced_module_by_torchscript, inps - ) - - # ExportedProgram from original module. - original_exported_module = torch.export.export_for_training( - m_func(), inps, strict=True - ) - - # Check whether input annotations are the same as tracing the original module. - orig_ph_name_list = [ - n.name - for n in original_exported_module.graph.nodes - if n.op == "placeholder" - ] - ph_name_list = [ - n.name for n in exported_module.graph.nodes if n.op == "placeholder" - ] - self.assertEqual(orig_ph_name_list, ph_name_list) - - # Check results equality. - self.assertTrue( - torch.allclose(exported_module(*inps), model_to_trace(*inps)) - ) - - # Tuple - class MTuple(torch.nn.Module): - def forward(self, x: Tuple[torch.Tensor]): - return x[0] + x[1] - - _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),)) - - # List - class MList(torch.nn.Module): - def forward(self, x: List[torch.Tensor]): - return x[0] + x[1] - - _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],)) - - # Dict - class MDict(torch.nn.Module): - def forward(self, x: Dict[str, torch.Tensor]): - return x["0"] + x["1"] - - _check_equality_and_annotations( - MDict, ({"0": torch.randn(4), "1": torch.randn(4)},) - ) - def test_joint_basic(self) -> None: class Module(torch.nn.Module): def __init__(self) -> None: @@ -409,7 +321,7 @@ def forward(self, x): sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1) linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None eq = sym_size_int_2 == 4; sym_size_int_2 = None - _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16, 4) on node 'eq'"); eq = _assert_scalar_default = None + _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s27, 4) on node 'eq'"); eq = _assert_scalar_default = None return pytree.tree_unflatten((linear,), self._out_spec)""", ) diff --git a/test/export/test_export.py b/test/export/test_export.py index 846ba858a35b66..33c432d310431c 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -22,7 +22,7 @@ import torch.utils._pytree as pytree from functorch.experimental.control_flow import cond, map from torch import Tensor -from torch._decomp import decomposition_table +from torch._decomp import decomposition_table, get_decompositions from torch._dynamo.test_case import TestCase from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse @@ -213,8 +213,6 @@ class Inp3: PREDISPATCH_SUFFIX = "_pre_dispatch" TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict" TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict" -LEGACY_EXPORT_STRICT_SUFFIX = "_legacy_export_strict" -LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_nonstrict" CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict" CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict" @@ -225,16 +223,6 @@ def is_non_strict_test(test_name): return not test_name.endswith(STRICT_SUFFIX) -def is_non_strict_legacy_test(test_name): - return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) - - -def is_legacy_test(test_name): - return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) or test_name.endswith( - LEGACY_EXPORT_STRICT_SUFFIX - ) - - def is_inline_and_install_strict_test(test_name: str) -> bool: return test_name.endswith(INLINE_AND_INSTALL_STRICT_SUFFIX) @@ -417,11 +405,12 @@ def forward(self, x, p): inputs = (torch.arange(10), torch.tensor(2)) - # Without transforming the unbacked int expression, we can't export. - with self.assertRaisesRegex( - RuntimeError, escape("Could not guard on data-dependent expression") - ): - export(Module(identity), inputs, strict=True) + # See https://github.com/pytorch/pytorch/issues/154574 + # # Without transforming the unbacked int expression, we can't export. + # with self.assertRaisesRegex( + # RuntimeError, escape("Could not guard on data-dependent expression") + # ): + # export(Module(identity), inputs, strict=True) # It works if we transform the whole unbacked int expression into # an unbacked int. @@ -570,6 +559,70 @@ def forward(self, x): self.assertEqual(counter, 1) + def test_from_node_metadata_export(self): + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1d = torch.nn.Conv1d(3, 3, 3) + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv2d(x) + x = x.squeeze(0) + x = self.conv1d(x) + return x + + def example_inputs(self): + return + + f = Foo() + inputs = (torch.randn(1, 3, 5, 5),) + gm = export(f, inputs).module() + from torch.fx.traceback import NodeSourceAction + + for node in gm.graph.nodes: + if node.op in ("placeholder", "output"): + continue + if "weight" in node.name or "bias" in node.name: + self.assertTrue( + node.meta["from_node"][-1].pass_name + == "ExportedProgram.module().unlift()" + ) + self.assertTrue( + node.meta["from_node"][-1].action + == [NodeSourceAction.CREATE, NodeSourceAction.REPLACE] + ) + else: + self.assertTrue( + node.meta["from_node"][-1].pass_name == "ExportedProgram.module()" + ) + self.assertTrue( + node.meta["from_node"][-1].action == [NodeSourceAction.CREATE] + ) + + ## re-export + gm2 = export(gm, inputs).module() + + for node in gm2.graph.nodes: + if node.op in ("placeholder", "output"): + continue + if "weight" in node.name or "bias" in node.name: + self.assertTrue( + node.meta["from_node"][-1].pass_name + == "ExportedProgram.module().unlift()" + ) + self.assertTrue( + node.meta["from_node"][-1].action + == [NodeSourceAction.CREATE, NodeSourceAction.REPLACE] + ) + else: + self.assertTrue( + node.meta["from_node"][-1].pass_name == "ExportedProgram.module()" + ) + self.assertTrue( + node.meta["from_node"][-1].action == [NodeSourceAction.CREATE] + ) + def test_bincount(self): class M(torch.nn.Module): def __init__(self): @@ -615,6 +668,120 @@ def forward(self, x, y): return (1,)""", ) + def test_inline_script_function(self): + @torch.jit.script + def _forward(x: torch.Tensor): + if torch.jit.is_scripting(): + return x.cos() + return x.sin() + + class M(torch.nn.Module): + def forward(self, x: torch.Tensor): + return _forward(x) + + x = torch.randn(3, 4) + ep = torch.export.export(M(), (x,)) + FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run( + str(ep.graph) + ) + FileCheck().check_count("torch.ops.aten.cos", 0, exactly=True).run( + str(ep.graph) + ) + res = ep.module()(x) + # We're inlining the original _forward function + # instead of the scripted function, so we get x.sin() + self.assertEqual(res, x.sin()) + + def test_inline_script_class_method(self): + class M(torch.nn.Module): + @staticmethod + @torch.jit.script + def _forward(x: torch.Tensor): + if torch.jit.is_scripting(): + return x.cos() + return x.sin() + + def forward(self, x: torch.Tensor): + return M._forward(x) + + x = torch.randn(3, 4) + ep = torch.export.export(M(), (x,)) + FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run( + str(ep.graph) + ) + FileCheck().check_count("torch.ops.aten.cos", 0, exactly=True).run( + str(ep.graph) + ) + res = ep.module()(x) + # We're inlining the original _forward function + # instead of the scripted function, so we get x.sin() + self.assertEqual(res, x.sin()) + + def test_inline_script_class_method_recursive(self): + f = 0.4 + i = 2 + s = "foo" + + @torch.jit.script + def _inner(x: torch.Tensor, y: torch.Tensor, f: float, i: int, s_len: int): + return x * y * f * i * s_len + + class M(torch.nn.Module): + @staticmethod + @torch.jit.script + def _forward(x: torch.Tensor, y: torch.Tensor, f: float, i: int, s: str): + if torch.jit.is_scripting(): + return _inner(x.cos(), y.cos(), f, i, len(s)) + return _inner(x.sin(), y.sin(), f, i, len(s)) + + def forward(self, x: torch.Tensor): + return M._forward(x, y=x, f=f, i=i, s=s) + + x = torch.randn(3, 4) + ep = torch.export.export(M(), (x,)) + FileCheck().check_count("torch.ops.aten.sin", 2, exactly=True).run( + str(ep.graph) + ) + FileCheck().check_count("torch.ops.aten.cos", 0, exactly=True).run( + str(ep.graph) + ) + res = ep.module()(x) + # We're inlining the original _forward function + # instead of the scripted function, so we get x.sin() + self.assertEqual(res, _inner(x.sin(), x.sin(), f, i, len(s))) + + def test_inline_script_method(self): + class M(torch.jit.ScriptModule): + @torch.jit.script_method + def _forward(self, x: torch.Tensor): + if torch.jit.is_scripting(): + return x.cos() + return x.sin() + + def forward(self, x): + return self._forward(x) + + class Wrapped(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + return self.mod(x) + + x = torch.randn(3, 4) + ep = torch.export.export(Wrapped(M()), (x,)) + FileCheck().check_count("torch.ops.aten.sin", 1, exactly=True).run( + str(ep.graph) + ) + FileCheck().check_count("torch.ops.aten.cos", 0, exactly=True).run( + str(ep.graph) + ) + res = ep.module()(x) + # We're inlining the original _forward function + # instead of the scripted function, so we get x.sin() + self.assertEqual(res, x.sin()) + def test_no_tensor_computation_2(self): class Module(torch.nn.Module): def forward(self, x, y): @@ -906,8 +1073,6 @@ def forward(self, x, y): self.assertEqual(exp_out, ep.module()(*args)) @requires_gpu - @testing.expectedFailureLegacyExportNonStrict # Old export graph contains auto_functionalize not Triton wrapper - @testing.expectedFailureLegacyExportStrict # Old export graph contains auto_functionalize not Triton wrapper def test_export_custom_triton_kernel_mutable(self): @triton.jit def add_kernel( @@ -2026,8 +2191,6 @@ def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): ep = export(model, inputs) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclasses_parameterization(self): class Foo(torch.nn.Module): def __init__(self): @@ -2081,8 +2244,6 @@ def forward(self, x): self.assertEqual(res, ref_out) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclasses_parameterization_nested(self): class Foo(torch.nn.Module): def __init__(self): @@ -2155,8 +2316,6 @@ def forward(self, x): res = ep.module()(ref_x) self.assertEqual(res, ref_out) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access(self): class Foo(torch.nn.Module): def __init__(self): @@ -2204,8 +2363,6 @@ def forward(self, x): ep = export(m, (ref_x,)) self.assertTrue(torch.allclose(ep.module()(ref_x), ref_out)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access_submodule(self): class Bar(torch.nn.Module): def __init__(self): @@ -2260,8 +2417,6 @@ def forward(self, x): ep = export(m, (ref_x,)) self.assertTrue(torch.allclose(ep.module()(ref_x), ref_out)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access_const_metadata(self): class Foo(torch.nn.Module): def __init__(self): @@ -2300,8 +2455,6 @@ def forward(self, x): ep = export(m, (ref_x,)) self.assertTrue(torch.allclose(ep.module()(ref_x), ref_out)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access_const_metadata_not_top_level(self): class Foo(torch.nn.Module): def __init__(self): @@ -2340,8 +2493,6 @@ def forward(self, x): ep = export(m, (ref_x,)) self.assertTrue(torch.allclose(ep.module()(ref_x), ref_out)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access_const_metadata_not_top_level(self): class Foo(torch.nn.Module): def __init__(self): @@ -2384,8 +2535,6 @@ def forward(self, x): ep = export(m, (ref_x,)) self.assertTrue(torch.allclose(ep.module()(ref_x), ref_out)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses def test_subclass_nested_attr_access_complicated_metadata(self): class Foo(torch.nn.Module): def __init__(self): @@ -2488,8 +2637,6 @@ def _bool_tensor(nz): self.assertEqual(ep.module()(sample_input), nz) print(ep) - @testing.expectedFailureLegacyExportNonStrict # Trivial error, just need to move the error check earlier, for real users it wont matter - @testing.expectedFailureLegacyExportStrict # Trivial error, just need to move the error check earlier, for real users it wont matter def test_export_script_module(self): class Foo(torch.nn.Module): def forward(self, rv: torch.Tensor, t: torch.Tensor): @@ -2556,7 +2703,8 @@ def forward(self, x, y, z): } with self.assertRaisesRegex( torch._dynamo.exc.UserError, - r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.", + r"You marked.*but your code specialized it to be a constant.*" + r"If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ): export(Foo(), inputs, dynamic_shapes=shapes) @@ -4047,6 +4195,34 @@ def forward(self, x, y): vr = next(iter(ep.range_constraints.values())) self.assertEqual(vr.lower, 3) + def test_unbacked_linear_layer_norm_input(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(387, 128, bias=True) + self.layer_norm = torch.nn.LayerNorm(387) + + def forward(self, x, mask): + masked_select = x.masked_select(mask) + view = masked_select.view(-1, 387) + + linear = self.linear(view) + layer_norm = self.layer_norm(view) + return linear, layer_norm + + inputs = ( + torch.randn((256, 387), dtype=torch.float), + torch.randint(low=0, high=1, size=(256, 1), dtype=torch.bool), + ) + + model = MyModel() + ep = export(model, inputs) + + ref = model(*inputs) + actual = ep.module()(*inputs) + self.assertTrue(torch.allclose(ref[0], actual[0])) + self.assertTrue(torch.allclose(ref[1], actual[1])) + def test_dynamic_shapes_builder_basic(self): class M(torch.nn.Module): def forward(self, x, y, z): @@ -4180,7 +4356,7 @@ def expect_error(bad_args, run_time_msg, compile_time_msg): # 4->5, 4->5, 3->4 bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}), run_time_msg="Expected input.*to be equal to 3, but got 4", - compile_time_msg=r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.", + compile_time_msg=r"You marked.*but your code specialized it to be a constant.*If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ) def test_additional_inputs_constants(self): @@ -4828,8 +5004,6 @@ def forward(self, x, y): y2 = torch.arange(9).reshape((3, 3)) self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2))) - @testing.expectedFailureLegacyExportNonStrict # Some small change due to unbacked values getting regenerated - @testing.expectedFailureLegacyExportStrict # Some small change due to unbacked values getting regenerated def test_export_max_nonstrict(self): class FooMax(torch.nn.Module): def forward(self, x): @@ -5148,8 +5322,6 @@ def forward(self, arg1, arg2, *args): self._test_export_same_as_eager(kw_func, args) @testing.expectedFailureCppRuntime - @testing.expectedFailureLegacyExportNonStrict - @testing.expectedFailureLegacyExportStrict def test_export_module(self): class Foo(torch.nn.Module): def __init__(self): @@ -5681,7 +5853,8 @@ def forward(self, x, y): with self.assertRaisesRegex( torch._dynamo.exc.UserError, ( - "You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.(.*\n)*.*" + "You marked.*but your code specialized it to be a constant.*" + "If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO(.*\n)*.*" "Suggested fixes:(.*\n)*.*" "batch = 10" ), @@ -5847,7 +6020,8 @@ def forward(self, x, y): with self.assertRaisesRegex( torch._dynamo.exc.UserError, ( - "You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO(.*\n)*" + "You marked.*but your code specialized it to be a constant.*" + "If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO(.*\n)*" "Suggested fixes:(.*\n)*.*" "K1 = 3" ), @@ -6095,7 +6269,7 @@ def forward(self, inputs): if node.op == "placeholder" ] self.assertEqual(len(input_shapes), 9) - self.assertTrue(all(shape == "torch.Size([s3])" for shape in input_shapes)) + self.assertEqual(len(set(input_shapes)), 1) def test_error_does_not_reference_eager_fallback(self): class Module(torch.nn.Module): @@ -6298,9 +6472,7 @@ def forward(self, x): if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # aten.to will just specialize by decomposing to a no-op self.assertEqual( ops, @@ -6339,9 +6511,7 @@ def forward(self, x): if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # aten.to will just specialize by decomposing to a no-op self.assertEqual( ops, @@ -6377,9 +6547,7 @@ def forward(self, x): for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # aten.to decomposes to no-op, add_ decomposes to functional variant self.assertEqual( ops, @@ -6427,9 +6595,7 @@ def forward(self, x): for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # aten.to decomposes to _to_copy self.assertEqual( ops, @@ -6490,6 +6656,109 @@ def forward(self, x): self.assertEqual(ep.module()(*inputs), model(*inputs)) + def test_export_aten_to_unflatten(self): + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.sum() + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.bar = Bar() + + def forward(self, x): + to = x.to(torch.float) + return self.bar(to).sum() + + inp = torch.randn(4, 4) + + ep = export( + Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",) + ) + mod = ep.module() + self.assertTrue(torch.allclose(mod(inp), Foo()(inp))) + + @testing.expectedFailureLegacyExportNonStrict + @testing.expectedFailureLegacyExportStrict + @testing.expectedFailureRetraceabilityNonStrict # when we retrace, ep.module() is hierarchical + @testing.expectedFailureRetraceability # when we retrace, ep.module() is hierarchical + def test_export_aten_to_unflatten_subclass(self): + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.sum() + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.bar = Bar() + self.param = torch.nn.Parameter( + TwoTensor(torch.ones(4, 4), torch.ones(4, 4)) + ) + + def forward(self, x): + to = self.param.to(torch.float) + return (self.bar(to).sum() + x.sum()).get_elem_a() + + inp = torch.randn(4, 4) + + with self.assertRaisesRegex( + ValueError, "It looks like p_param is a tensor subclass." + ): + export( + Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",) + ).run_decompositions({}) + + def test_export_aten_to_unflatten_subclass_pre_dispatch(self): + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.sum() + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.bar = Bar() + self.param = torch.nn.Parameter( + TwoTensor(torch.ones(4, 4), torch.ones(4, 4)) + ) + + def forward(self, x): + to = self.param.to(torch.float) + return (self.bar(to).sum() + x.sum()).get_elem_a() + + inp = torch.randn(4, 4) + + ep = export_for_training( + Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",) + ) + unflat = unflatten(ep).bar + self.assertExpectedInline( + str(unflat.graph).strip(), + """\ +graph(): + %_positional_arg_0 : [num_users=1] = placeholder[target=_positional_arg_0] + %_spec_0 : [num_users=1] = get_attr[target=_spec_0] + %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (((%_positional_arg_0,), {}), %_spec_0), kwargs = {}) + %to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {}) + %_spec_1 : [num_users=1] = get_attr[target=_spec_1] + %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {}) + return tree_unflatten""", + ) + + with self.assertRaisesRegex( + ValueError, "It looks like p_param is a tensor subclass." + ): + ep.run_decompositions() + def test_float_conversion(self): class Module(torch.nn.Module): def forward(self, x): @@ -6500,9 +6769,7 @@ def forward(self, x): for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # .float() decomposes to no-op self.assertEqual( ops, @@ -6541,9 +6808,7 @@ def forward(self, x): for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # .float() decomposes to _to_copy() self.assertEqual( ops, @@ -6597,9 +6862,7 @@ def forward(self, x): for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - if is_legacy_test(self._testMethodName) or is_training_ir_test( - self._testMethodName - ): + if is_training_ir_test(self._testMethodName): # aten.to decomposes to no-op, add_ decomposes to functional variant self.assertEqual( ops, @@ -7331,6 +7594,24 @@ def forward(self, x, m): ep = torch.export.export_for_training(f, (torch.randn(2, 2), mod), strict=False) self.assertEqual(ref_out, ep.module()(ref_x, mod)) + def test_unbacked_noncontig_lin(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(32, 64) + + def forward(self, x): + n = x.item() + y = torch.empty(x).view(1, -1, 32) + return self.lin(y) + + mod = Foo() + x = torch.tensor([128]) + ep = export(mod, (x,)) + self.assertEqual(mod(x).shape, ep.module()(x).shape) + x = torch.tensor([512]) + self.assertEqual(mod(x).shape, ep.module()(x).shape) + def test_runtime_assert_for_prim(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -7395,9 +7676,7 @@ def forward(self, xs): RuntimeError, r".* expression Eq\(u0, 2\) \| Eq\(u0, 4\) \| Eq\(u0, 6\) .*" ): ep.module()(torch.tensor([3, 6, 5])) - with self.assertRaisesRegex( - RuntimeError, r".* expression Eq\(u2, 5\) & \(4 <= u1\) & \(u1 <= 8\) .*" - ): + with self.assertRaisesRegex(RuntimeError, r".* expression u[\d]+ <= 5 .*"): ep.module()(torch.tensor([6, 6, 6])) def test_redundant_assert_max_upper_bound(self): @@ -10442,7 +10721,8 @@ def forward(self, x, y): inp = (3, torch.randn(4, 4)) with self.assertRaisesRegex( torch._dynamo.exc.UserError, - r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.", + r"You marked.*but your code specialized it to be a constant.*" + r"If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ): ep = export( M(), @@ -10564,6 +10844,36 @@ def forward(self, x, y): self.assertEqual(ep.module()(3, 5), 8) self.assertEqual(ep.module()(5, 4), 9) + def test_dynamic_shapes_bounds(self): + class M(torch.nn.Module): + """ + Example: bounds on dynamic shapes + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor, zs: list[torch.Tensor]): + return x[:3] + y @ torch.cat(zs) + + m = M() + x = torch.randn(7, 5) + y = torch.randn(3, 6) + zs = [torch.randn(2, 5), torch.randn(4, 5)] + + from torch.export import Dim, ShapesCollection + + dynamic_shapes = ShapesCollection() + dynamic_shapes[x] = (Dim.DYNAMIC, Dim.DYNAMIC) + dynamic_shapes[y] = (Dim.DYNAMIC, Dim.DYNAMIC) + for z in zs: + dynamic_shapes[z] = (Dim.DYNAMIC, Dim.DYNAMIC) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + r"Constraints violated.*\n.*" + r"You marked L\['y'\].size\(\)\[0\] as dynamic but your code specialized it to be a constant \(3\).*" + r"If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.", + ): + export(m, (x, y, zs), dynamic_shapes=dynamic_shapes) + def test_unflatten_random_dag_const_preserving_3_1(self): class N2(torch.nn.Module): def __init__(self): @@ -10960,8 +11270,6 @@ def forward(self, x): ep2_result = ep2.module()(inp) self.assertTrue(torch.allclose(ep2_result, orig_result)) - @testing.expectedFailureLegacyExportNonStrict - @testing.expectedFailureLegacyExportStrict def test_constant_tensor_with_non_functional(self): class TestModel(torch.nn.Module): def __init__(self): @@ -10998,8 +11306,6 @@ def forward(self, c_params, x): return (add_2,)""", ) - @testing.expectedFailureLegacyExportNonStrict - @testing.expectedFailureLegacyExportStrict def test_constant_tensor_with_non_functional_nested(self): class SubMod(torch.nn.Module): def __init__(self): @@ -11268,6 +11574,61 @@ def forward(self, x): ) ) + def test_stack_trace_make_fx(self): + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.linear(x) + x *= 2.0 + return x + + inp = torch.randn(4, 4) + gm = torch.fx.experimental.proxy_tensor.make_fx( + Foo(), record_stack_traces=True + )( + inp, + ) + + # check correct lines are in stack trace + trace_mul = [node for node in gm.graph.nodes if node.name == "mul_"][ + 0 + ].meta.get("stack_trace", "") + self.assertTrue( + re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul) + ) + trace_addmm = [node for node in gm.graph.nodes if node.name in ["addmm", "t"]][ + 0 + ].meta.get("stack_trace", "") + self.assertTrue( + re.search( + r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm + ) + ) + + # check correct lines are still in stack trace after export + ep = export( + gm, + (torch.randn(4, 4),), + ).run_decompositions({}) + # check correct lines are in stack trace + trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get( + "stack_trace", "" + ) + self.assertTrue( + re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul) + ) + trace_addmm = [ + node for node in ep.graph.nodes if node.name in ["addmm", "linear"] + ][0].meta.get("stack_trace", "") + self.assertTrue( + re.search( + r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm + ) + ) + @testing.expectedFailureSerDerNonStrict # register_constant needs to handle serialization @testing.expectedFailureSerDer # register_constant needs to handle serialization def test_register_constant(self): @@ -11913,6 +12274,26 @@ def forward(self, x): self.assertEqual(mod.foo, ep.module().foo) self.assertEqual(mod(torch.ones(4, 4)), ep.module()(torch.ones(4, 4))) + def test_unbacked_scalar_constructor(self): + class Foo(torch.nn.Module): + def forward(self, u, zuf, b): + return ( + torch.tensor([u.item()]), + torch.tensor([zuf.item()]), + torch.tensor([b.item()]), + ) + + mod = Foo() + inps = (torch.tensor([3]), torch.tensor([3.14]), torch.tensor([True])) + ep = torch.export.export(mod, inps) + for eager_out, ep_out in zip(mod(*inps), ep.module()(*inps)): + self.assertTrue(torch.allclose(eager_out, ep_out)) + + # test with other inputs + inps = (torch.tensor([5]), torch.tensor([-1.2]), torch.tensor([False])) + for eager_out, ep_out in zip(mod(*inps), ep.module()(*inps)): + self.assertTrue(torch.allclose(eager_out, ep_out)) + def test_symint_tensor_return(self): class Module(torch.nn.Module): def forward(self, x): @@ -12455,12 +12836,14 @@ def forward(self, w, x, y, z): "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "z": [Dim("dz")], # suggested fix should be to match these up. } - with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. - torch._dynamo.exc.UserError, - r".*Constraints violated(.*\n)*" - r"Suggested fixes:(.*\n)*" - r".*dz = dy(.*\n)*", - ) as msg: + with ( + self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. + torch._dynamo.exc.UserError, + r".*Constraints violated(.*\n)*" + r"Suggested fixes:(.*\n)*" + r".*dz = dy(.*\n)*", + ) as msg + ): export( Foo(), inputs, @@ -12697,8 +13080,6 @@ def forward(self, x): self.assertTrue(torch.allclose(a, torch.ones(4, 4))) self.assertTrue(torch.allclose(b, torch.ones(4, 4))) - @testing.expectedFailureLegacyExportNonStrict - @testing.expectedFailureLegacyExportStrict def test_constant_tensor_mutation(self): class M(torch.nn.Module): def __init__(self): @@ -12892,8 +13273,33 @@ def forward(self, x, y): ): ep.module()(torch.randn(10), torch.tensor(2)) - @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses - @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses + @testing.expectedFailureCppSerDes # TODO: When we deserialize we somehow hardcode sympy.lower to 2 + @testing.expectedFailureSerDerNonStrict + @testing.expectedFailureSerDer + @torch.fx.experimental._config.patch(backed_size_oblivious=True) + def test_baddbmm(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(64, 64, 192, dtype=torch.float16) + ) + self.bias = torch.nn.Parameter( + torch.randn(64, 1, 192, dtype=torch.float16) + ) + + def forward(self, x): + return torch.ops.aten.baddbmm.default(self.bias, x, self.weight) + + x1 = torch.randn(64, 2048, 64, dtype=torch.float16) + x2 = torch.randn(64, 1, 64, dtype=torch.float16) + m = M() + + ep = export(m, (x2,), dynamic_shapes=({1: Dim("batch")},)) + + self.assertTrue(torch.allclose(m(x2), ep.module()(x2))) + self.assertTrue(torch.allclose(m(x1), ep.module()(x1))) + @testing.expectedFailureSerDerNonStrict # construtor is not serialized today @testing.expectedFailureSerDer # constructor is not serialized today @testing.expectedFailureRetraceability # dynamo doesn't work with FlatApply op @@ -13273,8 +13679,7 @@ def test_run_decompositions_keep_metadata(self): """Make sure the metadata is kept after exported program run_decompositions.""" @torch.library.custom_op("mylib::add", mutates_args=()) - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - ... + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... @torch.library.register_fake("mylib::add") def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -13352,6 +13757,25 @@ def forward(self, x): self.assertTrue(torch.allclose(comp_mod(inp1), mod(inp1))) self.assertTrue(torch.allclose(comp_mod(inp2), mod(inp2))) + @torch.fx.experimental._config.patch(backed_size_oblivious=True) + def test_repeat_interleave(self): + class M(torch.nn.Module): + def forward(self, values, batch_sizes): + return torch.repeat_interleave( + torch.arange( + values.shape[0], + ), + batch_sizes, + ) + + inp = (torch.randint(0, 10, (1, 3)), torch.randint(0, 10, (1,))) + ep = torch.export.export( + M(), inp, dynamic_shapes=({0: Dim("dim")}, {0: Dim("dim")}) + ) + self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp))) + inp = (torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (2,))) + self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp))) + def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked @@ -13819,7 +14243,8 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, - r"You marked.*but your code specialized it to be a constant.*less strict API such as maybe_mark_dynamic or Dim.AUTO.", + r"You marked.*but your code specialized it to be a constant.*" + r"If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ): ep = export( Specialize(), @@ -14861,6 +15286,33 @@ def forward(self, x): len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs) ) + def test_input_output_no_stacktrace(self): + class M(torch.nn.Module): + def forward(self, x): + return x + x + + pyt_model = M() + example_inputs = (torch.ones(3, 3),) + + class Wrapper: + def __init__(self, model, example_inputs): + self.model = model + self.example_inputs = example_inputs + + def compile(self): + self.exp_program = torch.export.export( + self.model, args=self.example_inputs + ) + self.exp_program = self.exp_program.run_decompositions( + get_decompositions([torch.ops.aten.new_full]) + ) + + def forward(self, *args, **kwargs): + self.compile() + + wrapper = Wrapper(pyt_model, example_inputs) + wrapper.forward() + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestExportCustomClass(TorchTestCase): @@ -15015,6 +15467,48 @@ def forward(self, x, ranks): MyModel(), inps, dynamic_shapes=spec, strict=True ).run_decompositions({}) + def test_unbacked_contiguous(self): + class MyModel(torch.nn.Module): + def forward(self, x, mask): + masked_select = x.masked_select(mask) + view = masked_select.view(-1, 1548) + contig = view.contiguous() + return contig + 1 + + example_inputs = ( + torch.randn((768, 1548), dtype=torch.bfloat16), + torch.randint(low=0, high=1, size=(768, 1), dtype=torch.bool), + ) + spec = { + "x": [Dim.STATIC, Dim.STATIC], + "mask": [Dim.STATIC, Dim.STATIC], + } + + traced = export(MyModel(), example_inputs, strict=True) + self.assertExpectedInline( + traced.graph_module.code, + """\ +def forward(self, x, mask): + masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None + sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0) + sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None + ge = sym_size_int_1 >= 0 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None + le = sym_size_int_1 <= 1188864 + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 1188864 on node 'le'"); le = _assert_scalar_default_1 = None + mod = sym_size_int_1 % 1548 + eq_2 = mod == 0; mod = None + _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(u0, 1548), 0) on node 'eq_2'"); eq_2 = _assert_scalar_default_2 = None + floordiv = sym_size_int_1 // 1548 + mul_2 = 1548 * floordiv; floordiv = None + eq_3 = sym_size_int_1 == mul_2; sym_size_int_1 = mul_2 = None + _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(u0, 1548*((u0//1548))) on node 'eq_3'"); eq_3 = _assert_scalar_default_3 = None + view = torch.ops.aten.view.default(masked_select, [-1, 1548]); masked_select = None + add = torch.ops.aten.add.Tensor(view, 1); view = None + return (add,)""", + ignore_empty_lines=True, + ) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_export_legacy.py b/test/export/test_export_legacy.py deleted file mode 100644 index 01c98e85b6d744..00000000000000 --- a/test/export/test_export_legacy.py +++ /dev/null @@ -1,80 +0,0 @@ -# Owner(s): ["oncall: export"] -import torch - - -try: - from . import test_export, testing -except ImportError: - import test_export # @manual=fbcode//caffe2/test:test_export-library - - import testing # @manual=fbcode//caffe2/test:test_export-library - -from torch.testing._internal.common_utils import IS_FBCODE - - -if IS_FBCODE: - from pyjk import PyPatchJustKnobs - - -test_classes = {} - - -def mocked_legacy_export(*args, **kwargs): - with PyPatchJustKnobs().patch( - "pytorch/export:export_training_ir_rollout_check", False - ): - return torch.export._trace._export(*args, **kwargs, pre_dispatch=True) - - -def mocked_legacy_export_non_strict(*args, **kwargs): - with PyPatchJustKnobs().patch( - "pytorch/export:export_training_ir_rollout_check", False - ): - if "strict" in kwargs: - return torch.export._trace._export(*args, **kwargs, pre_dispatch=True) - return torch.export._trace._export( - *args, **kwargs, pre_dispatch=True, strict=False - ) - - -def make_dynamic_cls(cls, strict): - if strict: - test_class = testing.make_test_cls_with_mocked_export( - cls, - "LegacyExport", - test_export.LEGACY_EXPORT_STRICT_SUFFIX, - mocked_legacy_export, - xfail_prop="_expected_failure_legacy_export", - ) - else: - test_class = testing.make_test_cls_with_mocked_export( - cls, - "LegacyExportNonStrict", - test_export.LEGACY_EXPORT_NONSTRICT_SUFFIX, - mocked_legacy_export_non_strict, - xfail_prop="_expected_failure_legacy_export_non_strict", - ) - - test_classes[test_class.__name__] = test_class - # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING - globals()[test_class.__name__] = test_class - test_class.__module__ = __name__ - return test_class - - -tests = [ - test_export.TestDynamismExpression, - test_export.TestExport, -] - -if IS_FBCODE: - for test in tests: - make_dynamic_cls(test, True) - make_dynamic_cls(test, False) - del test - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - if IS_FBCODE: - run_tests() diff --git a/test/export/test_schema.py b/test/export/test_schema.py index 27e8cd59f2da61..f184fead8b4130 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -404,6 +404,62 @@ def test_schema_check(self): next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) + def test_schema_comparison(self): + import torch._export.serde.schema as schema + + sig = schema.ModuleCallSignature( + inputs=[ + schema.Argument.create(as_none=True), + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s0") + ), + ], + outputs=[ + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s1") + ) + ], + in_spec="foo", + out_spec="bar", + forward_arg_names=["None", "symint"], + ) + # same content as sig + sig_same = schema.ModuleCallSignature( + inputs=[ + schema.Argument.create(as_none=True), + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s0") + ), + ], + outputs=[ + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s1") + ) + ], + in_spec="foo", + out_spec="bar", + forward_arg_names=["None", "symint"], + ) + # as_name of symint is different + sig_diff = schema.ModuleCallSignature( + inputs=[ + schema.Argument.create(as_none=True), + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s0") + ), + ], + outputs=[ + schema.Argument.create( + as_sym_int=schema.SymIntArgument.create(as_name="s2") + ) + ], + in_spec="foo", + out_spec="bar", + forward_arg_names=["None", "symint"], + ) + self.assertEqual(sig, sig_same) + self.assertNotEqual(sig, sig_diff) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 9dd21674dbc172..75a30ccf3da964 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -34,9 +34,12 @@ from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.export import Dim, export_for_training, load, save, unflatten +from torch.export.pt2_archive.constants import ARCHIVE_VERSION_PATH from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + IS_FBCODE, + IS_MACOS, IS_WINDOWS, parametrize, run_tests, @@ -944,10 +947,12 @@ def forward(self, x, y, z): ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) - serialized_program.exported_program.graph_module.signature.input_specs[ - 1 - ] = schema.InputSpec.create( - user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) + serialized_program.exported_program.graph_module.signature.input_specs[1] = ( + schema.InputSpec.create( + user_input=schema.UserInputSpec( + arg=schema.Argument.create(as_none=True) + ) + ) ) ep = ExportedProgramDeserializer(None).deserialize( serialized_program.exported_program, {}, {}, {} @@ -1491,6 +1496,7 @@ def forward(self, x): self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) + @unittest.skipIf(IS_WINDOWS, "Cannot modify file in windows") def test_save_file(self): class Foo(torch.nn.Module): def forward(self, x): @@ -1501,10 +1507,10 @@ def forward(self, x): inp = (torch.randn(2, 2),) ep = export_for_training(f, inp, strict=True) - with tempfile.NamedTemporaryFile() as f: - save(ep, f) + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + save(ep, f.name) f.seek(0) - loaded_ep = load(f) + loaded_ep = load(f.name) self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) @@ -1518,7 +1524,7 @@ def forward(self, x, y): inp = (torch.tensor([6]), torch.tensor([7])) ep = export_for_training(f, inp, strict=True) - with TemporaryFileName() as fname: + with TemporaryFileName(suffix=".pt2") as fname: path = Path(fname) save(ep, path) loaded_ep = load(path) @@ -1545,6 +1551,9 @@ def forward(self, x): self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) self.assertEqual(extra_files["extra.txt"], "moo") + @unittest.skipIf( + IS_FBCODE or IS_MACOS or IS_WINDOWS, "The file path is different in fbcode CI" + ) def test_version_error(self): class Foo(torch.nn.Module): def forward(self, x): @@ -1555,18 +1564,19 @@ def forward(self, x): ep = export_for_training(f, (torch.randn(1, 3),), strict=True) with self.assertRaisesRegex( - RuntimeError, r"Serialized version .* does not match our current" + ValueError, r"Saved archive version -1 does not match our current" ): - with tempfile.NamedTemporaryFile() as f: - save(ep, f) + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + save(ep, f.name) f.seek(0) + file_prefix = f.name.split("/")[2].split(".")[0] # Modify the version with zipfile.ZipFile(f, "a") as zipf: - zipf.writestr("version", "-1.1") + zipf.writestr(f"{file_prefix}/{ARCHIVE_VERSION_PATH}", "-1") f.seek(0) - load(f) + load(f.name) def test_save_constants(self): class Foo(torch.nn.Module): diff --git a/test/export/test_swap.py b/test/export/test_swap.py index 8833c3c94ae7be..d9b2269dc3243c 100644 --- a/test/export/test_swap.py +++ b/test/export/test_swap.py @@ -22,7 +22,9 @@ {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", ) class TestSwap(TestCase): def test_unflatten_preserve_signature(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 2719ff661b9bcb..3f8f11aca0e524 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1303,7 +1303,7 @@ def forward(self, tq, x): tq1 = _empty_tensor_queue() tq1.push(x) - with self.assertRaisesRegex(RuntimeError, "is alising"): + with self.assertRaisesRegex(RuntimeError, "is aliasing"): torch.compile(mod, backend=backend)(tq1, x) @parametrize("backend", ["eager", "aot_eager"]) diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index b6d19ada81380f..adf74dc62b7000 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -233,7 +233,7 @@ def forward(self, x, y): new_inps = *inps, torch.rand(2, 3) with self.assertRaisesRegex( TypeError, - "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?", + "There is no flat args adapter specified. Are you sure you are calling this with the right arguments?", ): unflattened(new_inps) diff --git a/test/export/test_upgrader.py b/test/export/test_upgrader.py new file mode 100644 index 00000000000000..0c36b28750f90b --- /dev/null +++ b/test/export/test_upgrader.py @@ -0,0 +1,284 @@ +# Owner(s): ["oncall: export"] + +import json + +import torch +from torch.testing._internal.common_utils import TestCase + + +class TestUpgrader(TestCase): + def setUp(self) -> None: + # Register example upgraders dynamically + torch._C._export.register_example_upgraders() + + def tearDown(self) -> None: + # Clean up registered upgraders + torch._C._export.deregister_example_upgraders() + + def test_nn_module_stack_transformation_from_v0(self): + """Test that nn_module_stack strings are prepended with 'test_upgrader_' when upgrading from version 0""" + + # Create a mock JSON object that simulates version 0 schema + # with nn_module_stack as a string that needs to be upgraded + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "nn_module_stack": "original_stack_info", + "other_field": "some_value", + }, + }, + { + "target": "aten.mul.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "nn_module_stack": "another_stack", + "stack_trace": "some trace", + }, + }, + ] + } + }, + } + + # Test the upgrader using the Python binding + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify nn_module_stack was prepended with "test_upgrader_" + nodes = upgraded_json["graph_module"]["graph"]["nodes"] + + # Check first node + first_node_metadata = nodes[0]["metadata"] + nn_stack = first_node_metadata["nn_module_stack"] + self.assertIsInstance(nn_stack, str) + self.assertEqual(nn_stack, "test_upgrader_original_stack_info") + # Other metadata should be unchanged + self.assertEqual(first_node_metadata["other_field"], "some_value") + + # Check second node + second_node_metadata = nodes[1]["metadata"] + nn_stack2 = second_node_metadata["nn_module_stack"] + self.assertIsInstance(nn_stack2, str) + self.assertEqual(nn_stack2, "test_upgrader_another_stack") + # Other metadata should be unchanged + self.assertEqual(second_node_metadata["stack_trace"], "some trace") + + def test_nn_module_stack_error_handling_invalid_type(self): + """Test error handling when nn_module_stack is not a string""" + + # Test case: nn_module_stack is not a string + mock_json_invalid_type = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "nn_module_stack": 42 # Invalid: should be string + }, + } + ] + } + }, + } + + with self.assertRaisesRegex( + RuntimeError, + "Error in upgrader 'version_0_upgrader_registered'", + ): + serialized_json = json.dumps(mock_json_invalid_type) + torch._C._export.upgrade(serialized_json, 2) + + def test_nodes_without_metadata_handled_gracefully(self): + """Test that nodes without metadata or nn_module_stack are handled gracefully""" + + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + # No metadata field + }, + { + "target": "aten.mul.Tensor", + "inputs": [], + "outputs": [], + "metadata": { + "stack_trace": "some trace" + # No nn_module_stack field + }, + }, + ] + } + }, + } + + # Should not raise an error + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify nodes are unchanged + nodes = upgraded_json["graph_module"]["graph"]["nodes"] + self.assertEqual(len(nodes), 2) + + # First node should have no metadata + self.assertNotIn("metadata", nodes[0]) + + # Second node should have unchanged metadata + self.assertEqual(nodes[1]["metadata"]["stack_trace"], "some trace") + self.assertNotIn("nn_module_stack", nodes[1]["metadata"]) + + def test_field_renaming_chain_from_v0_complete(self): + """Test complete field renaming chain from v0: old_test_field -> new_test_field -> new_test_field2""" + + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "inputs": [], + "outputs": [], + "nodes": [ + { + "target": "aten.add.Tensor", + "inputs": [], + "outputs": [], + "metadata": {"nn_module_stack": "test_stack"}, + } + ], + "old_test_field": "original_value", + "existing_field": "existing_value", + } + }, + } + + # Test the upgrader using the Python binding + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify complete field transformation: old_test_field -> new_test_field -> new_test_field2 + graph = upgraded_json["graph_module"]["graph"] + self.assertIn("new_test_field2", graph) + self.assertEqual(graph["new_test_field2"], "original_value") + self.assertNotIn("old_test_field", graph) + self.assertNotIn("new_test_field", graph) + + # Verify existing fields are preserved + self.assertEqual(graph["existing_field"], "existing_value") + self.assertIn("inputs", graph) + self.assertIn("outputs", graph) + self.assertIn("nodes", graph) + + # Verify the nn_module_stack was also upgraded by the other upgrader + nodes = graph["nodes"] + self.assertEqual( + nodes[0]["metadata"]["nn_module_stack"], "test_upgrader_test_stack" + ) + + def test_field_renaming_chain_from_v0_missing_field(self): + """Test that upgraders work gracefully when old_test_field doesn't exist""" + + mock_json = { + "schema_version": {"major": 0, "minor": 0}, + "graph_module": { + "graph": { + "inputs": [], + "outputs": [], + "nodes": [], + "existing_field": "existing_value", + } + }, + } + + # Test the upgrader using the Python binding + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 0 -> version 2 due to both v0 and v1 upgraders) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify no field transformations occurred since old_test_field didn't exist + graph = upgraded_json["graph_module"]["graph"] + self.assertNotIn("new_test_field2", graph) + self.assertNotIn("new_test_field", graph) + self.assertNotIn("old_test_field", graph) + + # Verify existing fields are preserved + self.assertEqual(graph["existing_field"], "existing_value") + self.assertIn("inputs", graph) + self.assertIn("outputs", graph) + self.assertIn("nodes", graph) + + def test_field_renaming_from_v1_partial_chain(self): + """Test partial upgrade chain starting from v1: new_test_field -> new_test_field2""" + + mock_json = { + "schema_version": {"major": 1, "minor": 0}, + "graph_module": { + "graph": { + "inputs": [], + "outputs": [], + "nodes": [], + "new_test_field": "test_value", + "existing_field": "existing_value", + } + }, + } + + # Test the upgrader using the Python binding + serialized_json = json.dumps(mock_json) + upgraded_json_str = torch._C._export.upgrade(serialized_json, 2) + upgraded_json = json.loads(upgraded_json_str) + + # Verify the schema version was updated (version 1 -> version 2 due to v1 upgrader only) + self.assertEqual(upgraded_json["schema_version"]["major"], 2) + self.assertEqual(upgraded_json["schema_version"]["minor"], 0) + + # Verify new_test_field was renamed to new_test_field2 + graph = upgraded_json["graph_module"]["graph"] + self.assertIn("new_test_field2", graph) + self.assertEqual(graph["new_test_field2"], "test_value") + self.assertNotIn("new_test_field", graph) + + # Verify existing fields are preserved + self.assertEqual(graph["existing_field"], "existing_value") + self.assertIn("inputs", graph) + self.assertIn("outputs", graph) + self.assertIn("nodes", graph) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 3fe6b66039ca55..d6cf2df4343ffe 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -348,8 +348,7 @@ def check_fc(existing_schemas): "\n\t".join(str(s) for s in matching_new_schemas), ) log.warning( - "Refer to following reasons for failure " - "to find FC schema:\n[\n%s\n]", + "Refer to following reasons for failure to find FC schema:\n[\n%s\n]", "\n\t".join(str(r) for r in possible_failure_reasons), ) broken_ops.append(str(existing_schema)) diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 72a41dad777f65..4fa17b89f19ee0 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -523,15 +523,15 @@ def wrapped( dtypes=dtypes, ): name_parts = fn.__qualname__.split(".") - assert ( - len(name_parts) == 2 - ), "Decorator only applies to a test function of a test class" + assert len(name_parts) == 2, ( + "Decorator only applies to a test function of a test class" + ) test_case_name, base_test_name = name_parts for module_cls in module_classes: matching_module_infos = [m for m in module_db if m.module_cls == module_cls] - assert ( - len(matching_module_infos) == 1 - ), f"Couldn't find single ModuleInfo for {module_cls}" + assert len(matching_module_infos) == 1, ( + f"Couldn't find single ModuleInfo for {module_cls}" + ) module_info = matching_module_infos[0] decorators = list(module_info.decorators) new_decorator = DecorateInfo( diff --git a/test/functorch/functorch_additional_op_db.py b/test/functorch/functorch_additional_op_db.py index 03e83615a189b3..01539bf2eefb0b 100644 --- a/test/functorch/functorch_additional_op_db.py +++ b/test/functorch/functorch_additional_op_db.py @@ -322,13 +322,13 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): test_args = [ (3, ([1, 2],)), (3, (slice(0, 3),)), - (3, ([slice(0, 3), 1],)), - (3, ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],)), - (3, ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],)), - (3, ([slice(None), slice(None), [0, 3]],)), - (3, ([slice(None), [0, 3], slice(None)],)), - (3, ([[0, 3], slice(None), slice(None)],)), - (3, ([[0, 3], [1, 2], slice(None)],)), + (3, ((slice(0, 3), 1),)), + (3, (([0, 2, 3], [1, 3, 3], [0, 0, 2]),)), + (3, (([0, 0, 3], [1, 1, 3], [0, 0, 2]),)), + (3, ((slice(None), slice(None), [0, 3]),)), + (3, ((slice(None), [0, 3], slice(None)),)), + (3, (([0, 3], slice(None), slice(None)),)), + (3, (([0, 3], [1, 2], slice(None)),)), ( 3, ( @@ -337,20 +337,20 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): ], ), ), - (3, ([[0, 3], slice(None)],)), - (3, ([[0, 3], Ellipsis],)), - (3, ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],)), - (4, ([slice(None), adv_idx, adv_idx, slice(None)],)), - (4, ([slice(None), adv_idx, slice(None), adv_idx],)), - (4, ([adv_idx, slice(None), slice(None), adv_idx],)), - (4, ([slice(None), slice(None), adv_idx, adv_idx],)), - (4, ([Ellipsis, adv_idx, adv_idx],)), - (5, ([slice(None), slice(None), adv_idx, slice(None), adv_idx],)), - (5, ([slice(None), slice(None), adv_idx, adv_idx, slice(None)],)), - (5, ([slice(None), slice(None), adv_idx, None, adv_idx, slice(None)],)), - (6, ([slice(None), slice(None), slice(None), adv_idx, adv_idx],)), - (6, ([slice(None), slice(None), adv_idx, adv_idx, adv_idx],)), - (6, ([slice(None), slice(None), None, adv_idx, adv_idx, adv_idx],)), + (3, (([0, 3], slice(None)),)), + (3, (([0, 3], Ellipsis),)), + (3, (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),)), + (4, ((slice(None), adv_idx, adv_idx, slice(None)),)), + (4, ((slice(None), adv_idx, slice(None), adv_idx),)), + (4, ((adv_idx, slice(None), slice(None), adv_idx),)), + (4, ((slice(None), slice(None), adv_idx, adv_idx),)), + (4, ((Ellipsis, adv_idx, adv_idx),)), + (5, ((slice(None), slice(None), adv_idx, slice(None), adv_idx),)), + (5, ((slice(None), slice(None), adv_idx, adv_idx, slice(None)),)), + (5, ((slice(None), slice(None), adv_idx, None, adv_idx, slice(None)),)), + (6, ((slice(None), slice(None), slice(None), adv_idx, adv_idx),)), + (6, ((slice(None), slice(None), adv_idx, adv_idx, adv_idx),)), + (6, ((slice(None), slice(None), None, adv_idx, adv_idx, adv_idx),)), ] def get_shape(dim): @@ -400,20 +400,22 @@ def sample_inputs_aten_index_put(op_info, device, dtype, requires_grad, **kwargs adv_idx = torch.LongTensor([[0, 1], [2, 3]]) # self_shape, indices additional = [ - ((5, 6, 7, 8), [None, adv_idx, adv_idx, None]), - ((5, 6, 7, 8), [None, adv_idx, None, adv_idx]), - ((5, 6, 7, 8), [adv_idx, None, None, adv_idx]), - ((5, 6, 7, 8), [None, None, adv_idx, adv_idx]), - ((5, 6, 7, 8, 9), [None, None, adv_idx, None, adv_idx]), - ((5, 6, 7, 8, 9), [None, None, adv_idx, adv_idx, None]), - ((5, 6, 7, 8, 9, 10), [None, None, None, adv_idx, adv_idx]), - ((5, 6, 7, 8, 9, 10), [None, None, adv_idx, adv_idx, adv_idx]), + ((5, 6, 7, 8), (None, adv_idx, adv_idx, None)), + ((5, 6, 7, 8), (None, adv_idx, None, adv_idx)), + ((5, 6, 7, 8), (adv_idx, None, None, adv_idx)), + ((5, 6, 7, 8), (None, None, adv_idx, adv_idx)), + ((5, 6, 7, 8, 9), (None, None, adv_idx, None, adv_idx)), + ((5, 6, 7, 8, 9), (None, None, adv_idx, adv_idx, None)), + ((5, 6, 7, 8, 9, 10), (None, None, None, adv_idx, adv_idx)), + ((5, 6, 7, 8, 9, 10), (None, None, adv_idx, adv_idx, adv_idx)), ] for self_shape, indices in additional: for broadcast_value in [False, True]: inp = make_arg(self_shape) - tmp_indices = [slice(None) if idx is None else idx for idx in indices] + tmp_indices = tuple( + [slice(None) if idx is None else idx for idx in indices] + ) values_shape = inp[tmp_indices].shape if broadcast_value: values_shape = values_shape[3:] diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index f0a3c3916e6b9b..751a4c4d21859c 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -124,9 +124,7 @@ def test_recomputable_node_only_graph(self): ) def test_recomputable_node_only_graph_with_larger_graph_context(self): - recomputable_node_only_graph_with_larger_graph_context = ( - self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context - ) + recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950 expected_nodes = self.all_recomputable_banned_nodes # node1 does not have an indirect path to node5 because of node2 # node2 has an indirect path to node5 diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 33c1da771fe5bb..cf6ee336b86b61 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -690,6 +690,80 @@ def f(a, b): ] self.verify_aot_autograd(f, inp, keep_inp_mutations=True) + def _compile_autocast(self, device, *, forward_autocast): + with torch.library._scoped_library("mylib", "FRAGMENT") as m: + m.define("foo(Tensor x) -> Tensor") + m.impl("foo", torch.clone, "CompositeExplicitAutograd") + + def autocast(x): + return x + 1 + + m.impl("foo", autocast, "AutocastCPU") + m.impl("foo", autocast, "AutocastCUDA") + + foo = torch.ops.mylib.foo.default + + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return foo(x) + + @staticmethod + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return grad * foo(x) + + def fn(x): + with torch.amp.autocast(device, enabled=False): + return Foo.apply(x) + + x = torch.tensor(0.0, device=device, requires_grad=True) + if forward_autocast: + with ( + torch.amp.autocast(device), + torch._dynamo.config.patch(recompile_limit=999), + ): + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(x) + else: + with torch._dynamo.config.patch(recompile_limit=999): + out = torch.compile(fn, fullgraph=True, backend="aot_eager")(x) + (grad,) = torch.autograd.grad(out, x) + return out, grad + + @torch._functorch.config.patch(backward_pass_autocast="same_as_forward") + def test_backward_pass_autocast_on(self): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + for device in devices: + out, grad = self._compile_autocast(device, forward_autocast=True) + self.assertEqual(out, torch.zeros_like(out)) + self.assertEqual(grad, torch.ones_like(grad)) + + @torch._functorch.config.patch(backward_pass_autocast="off") + def test_backward_pass_autocast_off(self): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + for device in devices: + out, grad = self._compile_autocast(device, forward_autocast=True) + self.assertEqual(out, torch.zeros_like(out)) + self.assertEqual(grad, torch.zeros_like(grad)) + + @torch._functorch.config.patch(backward_pass_autocast="off") + def test_backward_pass_autocast_custom(self): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + for device in devices: + with torch._functorch.config.patch( + backward_pass_autocast=[{"device_type": device}] + ): + out, grad = self._compile_autocast(device, forward_autocast=False) + self.assertEqual(out, torch.zeros_like(out)) + self.assertEqual(grad, torch.ones_like(grad)) + @skipIfDynamoInput( "Test doesn't make sense with dynamo, which changes order of mutations" ) @@ -2568,8 +2642,9 @@ def setup_context(ctx, inputs, output): def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return torch.ops._test._clone_create_graph(x, x1) - inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( - 3, requires_grad=True + inp_x, inp_x1 = ( + torch.randn(3, requires_grad=True), + torch.randn(3, requires_grad=True), ) ref_x, ref_x1 = inp_x.clone(), inp_x1.clone() @@ -5283,11 +5358,12 @@ def fn(p, x): mod = TestMod(fn) inp = torch.randn(2) - with patch( - "functorch.compile.config.functionalize_rng_ops", True - ), self.assertRaisesRegex( - RuntimeError, - "Functionalized RNG is not currently supported in the aot_export", + with ( + patch("functorch.compile.config.functionalize_rng_ops", True), + self.assertRaisesRegex( + RuntimeError, + "Functionalized RNG is not currently supported in the aot_export", + ), ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) @@ -5390,6 +5466,32 @@ def f(x, mod_weight, mod_bias): self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) + @unittest.skipIf(not USE_NETWORKX, "networkx not available") + def test_min_cut_partitioner_raise_getitems(self): + def f(x): + y = torch.split(x, x.size(0) // 2, dim=0) + a = y[0].sin() + b = y[1].cos() + return a + b + + _, bw_graph = get_fw_bw_graph(f, [torch.randn(4, 4, requires_grad=True)]) + + self.assertExpectedInline( + bw_graph.code.strip(), + """\ +def forward(self, primals_1, tangents_1): + split = torch.ops.aten.split.Tensor(primals_1, 2); primals_1 = None + getitem_1 = split[1] + getitem = split[0]; split = None + sin_1 = torch.ops.aten.sin.default(getitem_1); getitem_1 = None + neg = torch.ops.aten.neg.default(sin_1); sin_1 = None + mul = torch.ops.aten.mul.Tensor(tangents_1, neg); neg = None + cos_1 = torch.ops.aten.cos.default(getitem); getitem = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos_1); tangents_1 = cos_1 = None + cat = torch.ops.aten.cat.default([mul_1, mul]); mul_1 = mul = None + return (cat,)""", + ) + @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner_save_shape(self): def f(x): @@ -7779,6 +7881,90 @@ def run(f): self.assertEqual(out, optout) + def test_mutations_in_bw_detached_from_tangent(self): + class AF(torch.autograd.Function): + @staticmethod + def forward(ctx, dummy, inplace_tensor): + ctx.inplace_tensor = inplace_tensor + return dummy.clone() + + @staticmethod + def backward(ctx, grad_output): + inplace_tensor = ctx.inplace_tensor + gradient_attachment = grad_output * 0 + 1 + inplace_tensor.add_(1 * gradient_attachment) + return grad_output, None, None + + def fn(dummy, inplace_tensor): + return AF.apply(dummy, inplace_tensor) + + def _inps(): + dummy = torch.zeros((2,), requires_grad=True) + inplace_tensor = torch.zeros((2,), requires_grad=False) + return dummy, inplace_tensor + + inps = _inps() + out = fn(*inps) + ref_inps_after_fw = [x.clone().detach() for x in inps] + out.sum().backward() + ref_inps_after_bw = [x.clone().detach() for x in inps] + + inps = _inps() + out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inps) + inps_after_fw = [x.clone().detach() for x in inps] + out.sum().backward() + inps_after_bw = [x.clone().detach() for x in inps] + + self.assertEqual(ref_inps_after_fw, inps_after_fw) + self.assertEqual(ref_inps_after_bw, inps_after_bw) + + def test_mutation_of_input_in_fw_and_bw(self): + class AF(torch.autograd.Function): + @staticmethod + def forward(ctx, dummy, inplace_tensor): + inplace_tensor.add_(1) + + ctx.inplace_tensor = inplace_tensor + return dummy.clone() + + @staticmethod + def backward(ctx, grad_output): + inplace_tensor = ctx.inplace_tensor + inplace_tensor.add_(1) + return grad_output, None, None + + def fn(dummy, inplace_tensor): + return AF.apply(dummy, inplace_tensor) + + def inps(): + dummy = torch.randn((2,), requires_grad=True) + inplace_tensor = torch.zeros((2,), requires_grad=False) + return dummy, inplace_tensor + + def sc_inps(): + dummy = TwoTensor( + torch.randn((2,), requires_grad=True), + torch.randn((2,), requires_grad=True), + ) + inplace_tensor = TwoTensor( + torch.zeros((2,), requires_grad=False), + torch.zeros((2,), requires_grad=False), + ) + return dummy, inplace_tensor + + for _inps in [inps, sc_inps]: + dummy, inplace = _inps() + y = fn(dummy, inplace) + ref0 = inplace.clone().detach() + y.sum().backward() + ref = inplace.clone().detach() + + dummy, inplace = _inps() + y = torch.compile(fn, backend="aot_eager", fullgraph=True)(dummy, inplace) + self.assertEqual(ref0, inplace) + y.sum().backward() + self.assertEqual(ref, inplace) + class MockFXGraphCache: """ diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 0d7e810c1ef5a1..1508997384d2f4 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -7,13 +7,14 @@ import torch.utils._pytree as pytree from functorch.experimental import control_flow from functorch.experimental.control_flow import cond -from torch._dynamo.testing import normalize_gm +from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm from torch._higher_order_ops.associative_scan import ( _fake_associative_scan, associative_scan, ) from torch._higher_order_ops.map import _fake_map from torch._higher_order_ops.scan import _fake_scan, scan +from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.while_loop import while_loop from torch._subclasses.functional_tensor import ( CppFunctionalizeAPI, @@ -873,17 +874,15 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1): add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = add = None + clone = torch.ops.aten.clone.default(arg6_1) + clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None - return [arg6_1, arg6_1, None, None, zeros_like, None]""", + return [clone, clone_1, None, None, zeros_like, None]""", ) def test_cond_autograd_pytree_input(self): - # TODO: This is an unexpected behavior for cond - # Without this additional multiplication, - # the output of the backward graph would alias the - # inputs, as the gradients are just 1s and thus get optimized def true_fn(x): - return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0] + return x["t"][0] + x["t"][1]["b"] * x["t"][2][0] def false_fn(x): return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"]) @@ -1429,7 +1428,11 @@ def f2(x, y): control_flow.map(f, x, y) with self.assertRaisesRegex( - RuntimeError, "Expect outputs of map only contains tensors" + # Should be + # torch._dynamo.exc.UncapturedHigherOrderOpError, + # "Expected all leaves to be of torch.Tensor type.*", + torch._dynamo.exc.UncapturedHigherOrderOpError, + "map doesn't work unless it is captured completely with torch.compile.*", ): control_flow.map(f1, x, y) @@ -1533,6 +1536,40 @@ def fwbw(map_op, f, x, y): fake_outs = fwbw(_fake_map, f, x, y) self.assertEqual(true_outs, fake_outs) + def test_map_autograd_higher_order(self): + from torch.autograd.functional import hessian as hes, jacobian as jac + + def f(x, y): + return x.sin().cos() + y + + def wrapper_jac(x, y): + return control_flow.map(f, x, y) + + def wrapper_jac_fake(x, y): + return _fake_map(f, x, y) + + def wrapper_hes(x, y): + return control_flow.map(f, x, y).sum() + + def wrapper_hes_fake(x, y): + return _fake_map(f, x, y).sum() + + for g_fct, (wrap, wrap_fake) in [ + (jac, [wrapper_jac, wrapper_jac_fake]), + (hes, [wrapper_hes, wrapper_hes_fake]), + ]: + xs = torch.ones(3, 2, 2, requires_grad=True) + # Disable the gradient computation for y + y = torch.ones(2, requires_grad=False) + res = control_flow.map(f, xs, y) + expected_res = _fake_map(f, xs, y) + self.assertEqual(expected_res, res) + + expected_grads = g_fct(wrap_fake, (xs, y)) + grads = g_fct(wrap, (xs, y)) + self.assertEqual(expected_res, res) + self.assertEqual(expected_grads, grads) + def test_scan_y_less_ndim_then_dim(self): def combine_fn(carry, x): return carry @ x, (carry @ x).sum() @@ -2709,8 +2746,6 @@ def fct_pointwise_different_carry(x, y): @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_scan_pytree_output(self): - from torch._dynamo.testing import EagerAndRecordGraphs - x = torch.randn(3, 10, 2, device=torch.device("cpu")) init = torch.randn(1, 10, 2, device=torch.device("cpu")) @@ -3258,8 +3293,6 @@ def f(fct, init, xs): @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): - from torch._dynamo.testing import EagerAndRecordGraphs - x = torch.randn(3, 10, 2, device=torch.device("cpu")) init = torch.randn(1, 10, 2, device=torch.device("cpu")) @@ -3827,8 +3860,6 @@ def test_associative_scan_complex_pytree( @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_associative_scan_pytree_output(self): - from torch._dynamo.testing import EagerAndRecordGraphs - x = ( ( torch.randn(3, 10, 2, device=torch.device("cpu")), @@ -4987,8 +5018,6 @@ def f(x, y): @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): - from torch._dynamo.testing import EagerAndRecordGraphs - def true_fn(x): return x.sin() @@ -5036,7 +5065,7 @@ def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1 return (getitem,) class cond_true_0(torch.nn.Module): - def forward(self, l_args_1_, l_ctx_saved_tensors_0_): + def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"): l_args_1__1 = l_args_1_ l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_ @@ -5048,7 +5077,7 @@ def forward(self, l_args_1_, l_ctx_saved_tensors_0_): return (mul,) class cond_false_0(torch.nn.Module): - def forward(self, l_args_1_, l_ctx_saved_tensors_0_): + def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"): l_args_1__1 = l_args_1_ l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_ @@ -5133,8 +5162,6 @@ def forward(self, arg0_1, arg1_1, arg2_1): def test_while_loop_pytree_carry(self): fn, inp = WHILE_LOOP_TESTS["simple_with_pytree_carry"] - from torch._dynamo.testing import EagerAndRecordGraphs - backend = EagerAndRecordGraphs() expected_res = fn(*inp) compiled_res = torch.compile(fn, backend=backend)(*inp) @@ -5341,8 +5368,6 @@ def test_while_loop_compile(self, backend, while_loop_test): @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] - from torch._dynamo.testing import EagerAndRecordGraphs - backend = EagerAndRecordGraphs() torch.compile(fn, backend=backend)(*inp) self.assertEqual(len(backend.graphs), 1) @@ -7404,8 +7429,6 @@ def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor: ): out = torch.compile(Mod(), backend="inductor")(inp, tmp) - from torch._dynamo.testing import EagerAndRecordGraphs - backend = EagerAndRecordGraphs() out = torch.compile(Mod(), backend=backend)(inp, tmp) self.assertExpectedInline( @@ -7429,8 +7452,6 @@ def forward(self, l_inp_, l_tmp_): @parametrize("requires_grad", [True, False]) def test_cond_symint_operands(self, requires_grad): - from torch._dynamo.testing import EagerAndRecordGraphs - backend = EagerAndRecordGraphs() class Mod(torch.nn.Module): @@ -7454,14 +7475,13 @@ def forward(self, a, b): self.assertExpectedInline( backend.graphs[0].code.strip(), """\ -def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt): +def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor): l_a_ = L_a_ l_b_ = L_b_ - l_self_num = L_self_num tensor = torch.tensor([True]) cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 - cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None + cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None getitem = cond[0]; cond = None return (getitem,)""", # noqa: B950 ) @@ -7595,8 +7615,6 @@ def f(init, xs): @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") def test_scan_pytree_closure(self): - from torch._dynamo.testing import EagerAndRecordGraphs - param_buffer = ({"param": torch.randn(3, 3)}, (torch.randn(3),)) def add(carry, x): @@ -7668,10 +7686,10 @@ def test_while_loop_op_int_carry_export(self, strict, dynamic): """\ class GraphModule(torch.nn.Module): def forward(self, x): - x: "f32[s35, 3]"; + x: "f32[s77, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) - sym_size_int_1: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0) + sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) while_loop_cond_graph_0 = self.while_loop_cond_graph_0 while_loop_body_graph_0 = self.while_loop_body_graph_0 @@ -7685,27 +7703,27 @@ def forward(self, x): gt_1: "Sym(u1 > 0)" = getitem_2 > 0 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None - getitem_1: "f32[s35, 3]" = while_loop[1]; while_loop = None + getitem_1: "f32[s77, 3]" = while_loop[1]; while_loop = None add: "Sym(u1 + 1)" = getitem_2 + 1 - add_1: "f32[s35, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None + add_1: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None - lt: "Sym(u1 < s35)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None + lt: "Sym(u1 < s77)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): - def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"): - sym_size_int: "Sym(s35)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None - lt: "Sym(u0 < s35)" = it_1 < sym_size_int; it_1 = sym_size_int = None + def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): + sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None + lt: "Sym(u0 < s77)" = it_1 < sym_size_int; it_1 = sym_size_int = None return lt class while_loop_body_graph_0(torch.nn.Module): - def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"): - clone: "f32[s35, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None + def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): + clone: "f32[s77, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1) add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None @@ -7719,8 +7737,6 @@ def forward(self, it_1: "Sym(u0)", x_1: "f32[s35, 3]"): @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) def test_while_loop_op_int_carry_compile(self, dynamic, backend): - from torch._dynamo.testing import EagerAndRecordGraphs - m, args = WHILE_LOOP_TESTS["int_carry"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -7766,7 +7782,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): return (add, add_1, lt, ones) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -7777,7 +7793,7 @@ def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -7880,8 +7896,6 @@ def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4 @parametrize("backend", ["eager", "aot_eager"]) @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_constant_and_symint_output_compile(self, dynamic, backend): - from torch._dynamo.testing import EagerAndRecordGraphs - m, args = WHILE_LOOP_TESTS["const_and_symint_output"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -8022,8 +8036,6 @@ def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", ar @parametrize("backend", ["eager", "aot_eager"]) @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend): - from torch._dynamo.testing import EagerAndRecordGraphs - m, args = WHILE_LOOP_TESTS["pytree_int_carry"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -8065,7 +8077,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8076,7 +8088,7 @@ def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unba return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8160,8 +8172,6 @@ def mutate_f(x): @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") def test_while_loop_unbacked_bindings(self): - from torch._dynamo.testing import EagerAndRecordGraphs - m, args = WHILE_LOOP_TESTS["pytree_int_carry"] backend = EagerAndRecordGraphs() self._check_compile(m, args, dynamic=True, backend=backend) @@ -8241,6 +8251,31 @@ def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"): """, # noqa: B950 ) + def test_cond_merge_graph_preserves_ph_meta(self): + class M(torch.nn.Module): + def forward(self, x, y, z): + a = y.shape[0] + b = z.shape[0] + + def true_fn(x): + return x + a + + def false_fn(x): + return x + b * z + + return torch.cond(x.sum() > 5, true_fn, false_fn, (x,)) + + backend = EagerAndRecordGraphs() + _ = torch.compile(M(), backend=backend)( + torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4) + ) + self.assertEqual(len(backend.graphs), 1) + gm = backend.graphs[0] + subgraph_attr = gm.graph.find_nodes(op="get_attr")[0] + subgm = getattr(gm, subgraph_attr.target) + for ph in subgm.graph.find_nodes(op="placeholder"): + self.assertTrue("example_value" in ph.meta) + @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") def test_cond_symint_closure(self): from torch.export import Dim @@ -8271,30 +8306,30 @@ def false_fn(x): """\ class GraphModule(torch.nn.Module): def forward(self, x, y, z): - x: "f32[s35, 3]"; y: "f32[s58]"; z: "f32[s35, 3]"; + x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]"; x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec) - sym_size_int_3: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0) - sym_size_int_4: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0); y = None + sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None + sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0) - gt: "Sym(s35 > 5)" = sym_size_int_3 > 5 + gt: "Sym(s68 > 5)" = sym_size_int_5 > 5 true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 - cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_3, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None - getitem: "f32[s35, 3]" = cond[0]; cond = None + cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_5, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_5 = z = None + getitem: "f32[s68, 3]" = cond[0]; cond = None return pytree.tree_unflatten((getitem,), self._out_spec) class true_graph_0(torch.nn.Module): - def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"): - add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None + def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"): + add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None return (add,) class false_graph_0(torch.nn.Module): - def forward(self, x: "f32[s35, 3]", sym_size_int_4: "Sym(s58)", sym_size_int_3: "Sym(s35)", z: "f32[s35, 3]"): - mul: "f32[s35, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None + def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"): + mul: "f32[s68, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_5); z = sym_size_int_5 = None - add: "f32[s35, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None + add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None return (add,) """, # noqa: B950 ) @@ -8421,8 +8456,6 @@ def _inner(case): @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) def test_cond_mismatched_branch_output(self, dynamic, backend): - from torch._dynamo.testing import EagerAndRecordGraphs - class M(torch.nn.Module): def forward(self, x, y, z): a = y.shape[0] @@ -8482,7 +8515,7 @@ def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: return (sub,) class cond_true_0(torch.nn.Module): - def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): + def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"): l_x__1 = l_x_ s94_1 = s94 @@ -8492,7 +8525,7 @@ def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false return (clone,) class cond_false_0(torch.nn.Module): - def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): + def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"): l_x__1 = l_x_ s94_1 = s94 @@ -8671,6 +8704,22 @@ def test_while_loop_schema_gen(self): ) self.assertEqual(schema.parse(str(schema)), schema) + def test_schema_tree_spec(self): + schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond) + args = (torch.randn(3, 4), torch.randn(2, 3)) + with self.assertRaisesRegex( + RuntimeError, "Please only add flattened inputs to the hop schema" + ): + schema_gen.add_arg("tuple_args", args) + + for i, arg in enumerate(args): + schema_gen.add_arg(f"tuple_args{i}", arg) + schema_gen.add_schema_tree_spec(pytree.tree_flatten(args)[1]) + flat_schema = schema_gen.gen_schema() + self.assertExpectedInline( + str(flat_schema), """cond(Tensor tuple_args0, Tensor tuple_args1) -> ()""" + ) + instantiate_parametrized_tests(TestHopSchema) instantiate_parametrized_tests(TestControlFlowTraced) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 1c53063b7a7cf8..bd8abbc3ea8563 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4324,9 +4324,7 @@ def lennard_jones(r): def lennard_jones_force(r): """Get magnitude of LJ force""" - return -epsilon * ( - (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7) - ) + return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)) r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) @@ -4495,8 +4493,9 @@ def test_find_learning_rate_ensembling(self, device, dropout_layer, mechanism): # This example mimics what a user might do when trying to find the optimal learning rate. They would # want to run a bunch of models with the same behavior (including the same dropout!) and have them # each run with different learning rates. Specifically, this is an example of using same randomness with vmap - points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint( - 0, 2, (100,), device=device + points, labels = ( + torch.randn(100, 2, 2, 2, 2, device=device), + torch.randint(0, 2, (100,), device=device), ) class MLPClassifier(nn.Module): diff --git a/test/functorch/test_memory_efficient_fusion.py b/test/functorch/test_memory_efficient_fusion.py index 7bf263431ad08a..4926781d7f6598 100644 --- a/test/functorch/test_memory_efficient_fusion.py +++ b/test/functorch/test_memory_efficient_fusion.py @@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) if delta == -1: - assert ( - old_num_nodes >= new_num_nodes - ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + assert old_num_nodes >= new_num_nodes, ( + f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + ) else: - assert ( - old_num_nodes == new_num_nodes + delta - ), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + assert old_num_nodes == new_num_nodes + delta, ( + f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + ) # a second pass should not reduce more nodes pass_2_graph = fx_graph_cse(new_graph) pass_2_num_nodes = len(pass_2_graph.nodes) - assert ( - pass_2_num_nodes == new_num_nodes - ), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + assert pass_2_num_nodes == new_num_nodes, ( + f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + ) # check correctness if check_val: true_result = fx_g(t) our_result = new_g(t) if true_result is None: # both return None - assert ( - our_result is None - ), f"true result is None, CSE result is {our_result}" + assert our_result is None, ( + f"true result is None, CSE result is {our_result}" + ) else: # results returned are the same - assert torch.all( - true_result == our_result - ), f"results are different {true_result}, {our_result}" # check results are the same + assert torch.all(true_result == our_result), ( + f"results are different {true_result}, {our_result}" + ) # check results are the same class NoChangeTestCase(TestCase): diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 8a0bf6ad40f50a..cef00f83eb72db 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -2154,9 +2154,9 @@ def test_extremal_numerics_nll_loss(self, device): else: weight = torch.randn(weight_shape, device=device) target = torch.randint(0, C, target_shape, device=device) - target[ - 0 - ] = 1 # since we're ignoring index 0, at least one element must be non-zero + target[0] = ( + 1 # since we're ignoring index 0, at least one element must be non-zero + ) fn = functools.partial( torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs diff --git a/test/functorch/test_parsing.py b/test/functorch/test_parsing.py index 46c9b340c59447..8183755ebd4dec 100644 --- a/test/functorch/test_parsing.py +++ b/test/functorch/test_parsing.py @@ -24,6 +24,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from typing import Any from unittest import mock @@ -107,7 +108,7 @@ def test_invalid_expressions(self) -> None: ParsedExpression("(a) ((b c) (d ...))") # invalid identifiers - ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...") + ParsedExpression("camelCase under_scored cApiTaLs \u00df ...") with self.assertRaises(ValueError): ParsedExpression("1a") with self.assertRaises(ValueError): diff --git a/test/functorch/test_rearrange.py b/test/functorch/test_rearrange.py index d5f55d7e7a3b26..b3c8f7753687fc 100644 --- a/test/functorch/test_rearrange.py +++ b/test/functorch/test_rearrange.py @@ -25,7 +25,6 @@ SOFTWARE. """ - import numpy as np import torch diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 2f1d1416b634ba..6ba61a6c1d0d3c 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -734,6 +734,7 @@ def test_fallback_does_not_warn_by_default(self): # warning, not a warning from the vmap fallback path. self.assertEqual(len(wa), 1) + @skipIfTorchDynamo("Flaky test") @unittest.expectedFailure def test_fallback_warns_when_warnings_are_enabled(self): # NB: One day we will implement a batching rule for torch.atan2. @@ -1532,7 +1533,7 @@ def test_unary_pointwise(self, case): self._test_unary(op, getter, "cpu") # test in-place - method = getattr(Tensor, f'{op.__name__ + "_"}') + method = getattr(Tensor, f"{op.__name__ + '_'}") self._test_unary(method, getter, "cpu", check_propagates_grad=False) def test_clone(self): @@ -1718,11 +1719,6 @@ def test_silu_backward(self): test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0)) test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None)) - @skipIf( - TEST_WITH_TORCHDYNAMO - and os.getenv("BUILD_ENVIRONMENT", "") == "linux-focal-py3.8-clang10", - "Segfaults with dynamo on focal, see https://github.com/pytorch/pytorch/issues/107173", - ) @parametrize( "case", [ @@ -4479,7 +4475,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints xfail("resize_"), xfail("view_as_complex"), - xfail("matrix_exp"), xfail("fft.ihfft2"), xfail("fft.ihfftn"), xfail("allclose"), @@ -4539,13 +4534,21 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("clamp_min", ""), xfail("sparse.sampled_addmm"), xfail("sparse.mm", "reduce"), + xfail("special.chebyshev_polynomial_t"), + xfail("special.chebyshev_polynomial_v"), xfail("special.chebyshev_polynomial_u"), + xfail("special.chebyshev_polynomial_w"), + xfail("special.shifted_chebyshev_polynomial_t"), + xfail("special.shifted_chebyshev_polynomial_v"), + xfail("special.shifted_chebyshev_polynomial_u"), + xfail("special.shifted_chebyshev_polynomial_w"), xfail("_segment_reduce", "offsets"), xfail("index_reduce", "prod"), xfail("index_reduce", "mean"), xfail("index_reduce", "amin"), xfail("index_reduce", "amax"), xfail("special.laguerre_polynomial_l"), + xfail("special.legendre_polynomial_p"), xfail("special.hermite_polynomial_h"), xfail("jiterator_binary", device_type="cuda"), xfail("jiterator_4inputs_with_extra_args", device_type="cuda"), @@ -4553,7 +4556,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("lu_solve", ""), xfail("special.hermite_polynomial_he"), xfail("nn.functional.dropout3d", ""), - xfail("special.chebyshev_polynomial_t"), xfail("as_strided_scatter", ""), xfail("equal", ""), xfail("linalg.lu", ""), diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 1bff959e3c4f82..bf738207a41b4b 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -120,7 +120,6 @@ "aten::lu_solve", "aten::margin_ranking_loss", "aten::masked_select_backward", - "aten::matrix_exp", "aten::matrix_exp_backward", "aten::max.names_dim", "aten::max.names_dim_max", diff --git a/test/fx/quantization.py b/test/fx/quantization.py index 3daa4da479ecdb..33550702ca6c70 100644 --- a/test/fx/quantization.py +++ b/test/fx/quantization.py @@ -2,6 +2,7 @@ **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not rely on it for anything!** """ + import operator import sys diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py index 642da7255b681b..58ecfb58faecec 100644 --- a/test/fx/test_common_passes.py +++ b/test/fx/test_common_passes.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - run_tests, + raise_on_run_directly, TestCase, ) @@ -128,4 +128,4 @@ def test_correctness_factory(self, common_pass, f, device): if __name__ == "__main__": - run_tests() + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py index 16aa9e70a029de..74eb2ca3af4269 100644 --- a/test/fx/test_cse_pass.py +++ b/test/fx/test_cse_pass.py @@ -6,7 +6,7 @@ from torch.fx import symbolic_trace from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase banned_ops = get_CSE_banned_ops() @@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) - assert ( - new_num_nodes < old_num_nodes - ) == modified, "modified should be True if the number of nodes decrease" + assert (new_num_nodes < old_num_nodes) == modified, ( + "modified should be True if the number of nodes decrease" + ) if delta == -1: self.assertTrue( @@ -259,4 +259,4 @@ def f(x): if __name__ == "__main__": - run_tests() + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index 4e11ed562254b2..0a0852898a6b2d 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -5,7 +5,11 @@ import torch import torch.fx -from torch.testing._internal.common_utils import IS_MACOS, TestCase +from torch.testing._internal.common_utils import ( + IS_MACOS, + raise_on_run_directly, + TestCase, +) class TestDCE(TestCase): @@ -328,3 +332,7 @@ def forward( # collective nodes should not be removed because they have side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False) torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_dynamism.py b/test/fx/test_dynamism.py index 185a5956048b0b..37db8912b4571a 100644 --- a/test/fx/test_dynamism.py +++ b/test/fx/test_dynamism.py @@ -2,7 +2,7 @@ import torch from torch.fx.experimental._dynamism import track_dynamism_across_examples -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import TestCase class TestDynamism(TestCase): @@ -148,4 +148,7 @@ def not_implemented_property(self): if __name__ == "__main__": - run_tests() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index c1e5929ca3019e..8ff9fb438619ab 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -6,7 +6,7 @@ import torch.fx from torch.fx.experimental import const_fold from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class TestConstFold(TestCase): @@ -706,3 +706,7 @@ def forward(self, x, y): base_result = mod(in_x, in_y) fold_result = mod_folded(in_x, in_y) self.assertTrue(torch.equal(fold_result, base_result)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_fx_node_hook.py b/test/fx/test_fx_node_hook.py index 43cd0e0722e4a7..4d681ab35a3257 100644 --- a/test/fx/test_fx_node_hook.py +++ b/test/fx/test_fx_node_hook.py @@ -89,3 +89,10 @@ def replace_node_hook2(old, new, user): assert gm._create_node_hooks == [] assert gm._erase_node_hooks == [] assert gm._replace_hooks == [] + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_fx_param_shape_control_flow.py b/test/fx/test_fx_param_shape_control_flow.py index d943e1e0b368d4..8972540076f54d 100644 --- a/test/fx/test_fx_param_shape_control_flow.py +++ b/test/fx/test_fx_param_shape_control_flow.py @@ -1,10 +1,8 @@ # Owner(s): ["module: fx"] -import unittest - import torch import torch.fx -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class MyModuleBase(torch.nn.Module): @@ -158,4 +156,4 @@ def test_param_nelement_const(self): if __name__ == "__main__": - unittest.main() + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index 58f2fea08217b4..6d95592fd290e3 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -223,3 +223,10 @@ def test_split_by_tags(self) -> None: self.assertTrue(type(gm_output) == type(split_gm_output)) self.assertTrue(torch.equal(gm_output, split_gm_output)) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 9fb98a525fe2e1..6306daa571bd00 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -102,7 +102,7 @@ def forward(self, x): ) # Check node "linear" is created from node "x" in PropagateUnbackedSymInts - key_provenance = provenance["linear"] + key_provenance = provenance["linear"][0]["from_node"] self.assertEqual(len(key_provenance), 1) key_provenance = key_provenance[0] check_node_source( @@ -151,7 +151,9 @@ def forward(self, x): ) # Check node "linear" is then created from node "x" in PropagateUnbackedSymInts - key_provenance = get_first_node_source_and_check(key_provenance) + key_provenance = get_first_node_source_and_check(key_provenance)[ + "from_node" + ][0] check_node_source( key_provenance, "x", @@ -167,3 +169,10 @@ def forward(self, x): "Interpreter_FlattenInputOutputSignature", CREATE_STR, ) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index 7d4370b5dcf2c6..2517439d9fe36b 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -11,14 +11,6 @@ from torch.testing._internal.common_utils import TestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_fx.py TESTNAME\n\n" - "instead." - ) - - class TestGraphTransformObserver(TestCase): def test_graph_transform_observer(self): class M(torch.nn.Module): @@ -186,3 +178,10 @@ def forward(self, x): self.assertEqual(len(gm2._create_node_hooks), 0) self.assertEqual(len(gm2._erase_node_hooks), 0) self.assertEqual(len(gm2._deepcopy_hooks), 0) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py index fcf50dad99ead1..9fa01fbf149c7e 100644 --- a/test/fx/test_gradual_type.py +++ b/test/fx/test_gradual_type.py @@ -17,7 +17,7 @@ from torch.fx.experimental.unify_refinements import infer_symbolic_types from torch.fx.passes.shape_prop import ShapeProp from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase try: @@ -1168,4 +1168,4 @@ def forward(self, x: TensorType((4, 3, Dyn, Dyn))): if __name__ == "__main__": - unittest.main() + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_graph_pickler.py b/test/fx/test_graph_pickler.py index b593e0adaf280f..ae299140d48a70 100644 --- a/test/fx/test_graph_pickler.py +++ b/test/fx/test_graph_pickler.py @@ -14,8 +14,7 @@ import torch.library from torch._dynamo.testing import make_test_cls_with_patches from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import TEST_WITH_ASAN -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU +from torch.testing._internal.inductor_utils import HAS_CPU # Make the helper files in test/ importable @@ -93,8 +92,7 @@ def fn(a, b): if __name__ == "__main__": - from torch._inductor.test_case import run_tests - - # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 - if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: - run_tests(needs="filelock") + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_lazy_graph_module.py b/test/fx/test_lazy_graph_module.py index ac9d404d67dc84..6404b587d8707e 100644 --- a/test/fx/test_lazy_graph_module.py +++ b/test/fx/test_lazy_graph_module.py @@ -15,7 +15,7 @@ ) from torch.fx.experimental.proxy_tensor import make_fx from torch.package import PackageExporter, PackageImporter -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import TestCase class TestLazyGraphModule(TestCase): @@ -276,4 +276,7 @@ def f(x): if __name__ == "__main__": - run_tests() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py index 578e0ab07a6a13..604de73fcd880a 100644 --- a/test/fx/test_matcher_utils.py +++ b/test/fx/test_matcher_utils.py @@ -19,7 +19,7 @@ from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, ) -from torch.testing._internal.common_utils import IS_WINDOWS, run_tests +from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.jit_utils import JitTestCase @@ -269,4 +269,7 @@ def forward(self, x): if __name__ == "__main__": - run_tests() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_net_min_base.py b/test/fx/test_net_min_base.py index 79f25af3bfaeaa..75382304e1950c 100644 --- a/test/fx/test_net_min_base.py +++ b/test/fx/test_net_min_base.py @@ -100,3 +100,10 @@ def test_contiguous_partial_discrepancy_end(self) -> None: def test_continugous_partial_discrepancy_beginning(self) -> None: self.assert_problematic_nodes(["linear", "linear2"]) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index a646ec1bc776cd..ebe40f471e62bb 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -1,6 +1,5 @@ # Owner(s): ["module: fx"] -import unittest from collections.abc import Mapping import torch @@ -49,4 +48,7 @@ def test_partitioner_order(self): if __name__ == "__main__": - unittest.main() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py index d0449b4c313e92..195a4fad2ba33f 100644 --- a/test/fx/test_pass_infra.py +++ b/test/fx/test_pass_infra.py @@ -9,7 +9,7 @@ PassManager, this_before_that_pass_constraint, ) -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase # Pass that uses PassBase and returns a PassResult (best scenario) @@ -228,3 +228,7 @@ def pass_fail(graph_module): error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass" with self.assertRaisesRegex(Exception, error_msg): pm(traced_m) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_shape_inference.py b/test/fx/test_shape_inference.py index 1caa4847bc24c0..77c69d065dde44 100644 --- a/test/fx/test_shape_inference.py +++ b/test/fx/test_shape_inference.py @@ -108,3 +108,10 @@ def generate_graph_module(model): gm = generate_graph_module(m) input_tensors = [torch.randn(1, 1)] infer_shape(gm, input_tensors) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/fx/test_source_matcher_utils.py b/test/fx/test_source_matcher_utils.py index 544e676efcf2ee..b7453272a83e71 100644 --- a/test/fx/test_source_matcher_utils.py +++ b/test/fx/test_source_matcher_utils.py @@ -18,6 +18,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + raise_on_run_directly, skipIfTorchDynamo, ) from torch.testing._internal.jit_utils import JitTestCase @@ -481,3 +482,6 @@ def forward(self, x): instantiate_parametrized_tests(TestSourceMatcher) + +if __name__ == "__main__": + raise_on_run_directly("test/test_fx.py") diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py index 70430e03c3a5fe..9b1a3878ed6a2f 100644 --- a/test/fx/test_z3_gradual_types.py +++ b/test/fx/test_z3_gradual_types.py @@ -1783,8 +1783,9 @@ def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn])): self.assertEqual(s.check(), z3.sat) add_result = z3.Const(3, tensor_type) - broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const( - 5, tensor_type + broadcast_res1, broadcast_res2 = ( + z3.Const(4, tensor_type), + z3.Const(5, tensor_type), ) # print(s.model()) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 4ee5332bcd847d..052baebce337e7 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -12,6 +12,7 @@ import torch._functorch import torch._inductor import torch._inductor.decomposition +import torch.utils._pytree as pytree from functorch.compile import aot_function, nop from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -19,8 +20,12 @@ InductorAndRecordGraphs, normalize_gm, ) -from torch._higher_order_ops.invoke_subgraph import mark_compile_region from torch._higher_order_ops.schema import find_hop_schema +from torch._inductor.pattern_matcher import ( + CallFunctionVarArgs, + PatternMatcherPass, + register_graph_pattern, +) from torch.testing._internal.common_utils import ( run_tests, skipIfTorchDynamo, @@ -31,6 +36,8 @@ from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +nested_compile_region = torch.compiler.nested_compile_region + if HAS_GPU: import triton @@ -42,7 +49,7 @@ def gn(x, y): return torch.mul(x, y) def fn(x, y): - return mark_compile_region(gn)(x, y) + return nested_compile_region(gn)(x, y) x = torch.randn(8, requires_grad=True) y = torch.randn(8, requires_grad=True) @@ -65,7 +72,7 @@ def gn(x, y): return torch.mul(x, y) def fn(x, y): - return mark_compile_region(gn)(x, y) + return nested_compile_region(gn)(x, y) x = torch.randn(8, requires_grad=True) y = torch.randn(8, requires_grad=True) @@ -85,11 +92,11 @@ def fn(x, y): self.assertEqual(y.grad, y_clone.grad) def test_multiple(self): - @mark_compile_region + @nested_compile_region def cos(x): return torch.cos(x) - @mark_compile_region + @nested_compile_region def sin(x): return torch.sin(x) @@ -116,7 +123,7 @@ def count_unique_get_attr_nodes(self, gm, args, expected): self.assertEqual(len(subgraph_attr_names), expected) def test_simple(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return torch.mul(x, y) @@ -145,7 +152,7 @@ def __init__(self): super().__init__() self.c = 5 - @mark_compile_region + @nested_compile_region def forward(self, x, y): return torch.mul(x, y).sin() + self.c @@ -176,7 +183,7 @@ def __init__(self): super().__init__() self.c = 5 - @mark_compile_region + @nested_compile_region def forward(self, x, y): return torch.mul(x, y).sin() + self.c @@ -226,7 +233,7 @@ def __init__(self): self.c = 5 self.register_buffer("buf", torch.ones(8, requires_grad=False)) - @mark_compile_region + @nested_compile_region def forward(self, x, y): self.buf.add_(1) return torch.mul(x, y).sin() + self.c + self.buf @@ -244,9 +251,8 @@ def fn(mod, x, y): x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) backend = EagerAndRecordGraphs() - with mock.patch( - "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", - True, + with ( + torch.no_grad(), ): res = torch.compile(fn, backend=backend, fullgraph=True)( mod, x_clone, y_clone @@ -298,8 +304,171 @@ def forward(self, l_mod_buffers_buf_: "f32[8]", l_x_: "f32[8]", l_y_: "f32[8]"): self.assertEqual(res, ref) self.assertEqual(mod.buf, mod_ref.buf) + def test_auto_functionalize(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + self.register_buffer("buf", torch.ones(8, requires_grad=False)) + + @nested_compile_region + def forward(self, x, y): + return torch.mul(x, y).sin() * self.c * self.buf + + mod_ref = Mod() + mod = Mod() + + def fn(mod, x, y): + return mod(x, y) + mod(x, y) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(mod_ref, x, y) + + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + backend = AotEagerAndRecordGraphs() + res = torch.compile(fn, backend=backend, fullgraph=True)(mod, x_clone, y_clone) + res.sum().backward() + self.assertEqual(len(backend.fw_graphs), 1) + self.assertEqual(len(backend.bw_graphs), 1) + self.assertEqual(ref, res) + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[8]", primals_2: "f32[8]", primals_3: "f32[8]"): + partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0 + + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_0 = None + getitem_12: "f32[8]" = invoke_subgraph_4[3] + getitem_11: "f32[8]" = invoke_subgraph_4[2] + getitem_10: "f32[8]" = invoke_subgraph_4[1] + getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + + partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0 + + invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, primals_2, primals_3); partitioned_fw_subgraph_0_1 = primals_1 = primals_2 = primals_3 = None + getitem_15: "f32[8]" = invoke_subgraph_6[3] + getitem_14: "f32[8]" = invoke_subgraph_6[2] + getitem_13: "f32[8]" = invoke_subgraph_6[1] + getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None + + add: "f32[8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None + return (add, getitem_12, getitem_11, getitem_10, getitem_15, getitem_14, getitem_13) + + class partitioned_fw_subgraph_0_0(torch.nn.Module): + def forward(self, primals_0: "f32[8]", primals_1: "f32[8]", primals_2: "f32[8]"): + mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1) + sin: "f32[8]" = torch.ops.aten.sin.default(mul); mul = None + mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(sin, 5); sin = None + mul_2: "f32[8]" = torch.ops.aten.mul.Tensor(mul_1, primals_2); mul_1 = None + return (mul_2, primals_0, primals_1, primals_2) +""", + ) + self.assertExpectedInline( + normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, getitem_12: "f32[8]", getitem_11: "f32[8]", getitem_10: "f32[8]", getitem_15: "f32[8]", getitem_14: "f32[8]", getitem_13: "f32[8]", tangents_1: "f32[8]"): + partitioned_bw_subgraph_0_1 = self.partitioned_bw_subgraph_0_0 + + invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_1, 'partitioned_bw_subgraph_0_0', getitem_13, getitem_14, getitem_15, tangents_1); partitioned_bw_subgraph_0_1 = getitem_13 = getitem_14 = getitem_15 = None + getitem_2: "f32[8]" = invoke_subgraph_7[0] + getitem_3: "f32[8]" = invoke_subgraph_7[1]; invoke_subgraph_7 = None + + partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0 + + invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_10, getitem_11, getitem_12, tangents_1); partitioned_bw_subgraph_0_0 = getitem_10 = getitem_11 = getitem_12 = tangents_1 = None + getitem_6: "f32[8]" = invoke_subgraph_5[0] + getitem_7: "f32[8]" = invoke_subgraph_5[1]; invoke_subgraph_5 = None + + add_1: "f32[8]" = torch.ops.aten.add.Tensor(getitem_2, getitem_6); getitem_2 = getitem_6 = None + add_2: "f32[8]" = torch.ops.aten.add.Tensor(getitem_3, getitem_7); getitem_3 = getitem_7 = None + return (add_1, add_2, None) + + class partitioned_bw_subgraph_0_0(torch.nn.Module): + def forward(self, primals_0: "f32[8]", primals_1: "f32[8]", primals_2: "f32[8]", tangents_0: "f32[8]"): + mul_3: "f32[8]" = torch.ops.aten.mul.Tensor(tangents_0, primals_2); tangents_0 = primals_2 = None + mul_4: "f32[8]" = torch.ops.aten.mul.Tensor(mul_3, 5); mul_3 = None + mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1) + cos: "f32[8]" = torch.ops.aten.cos.default(mul); mul = None + mul_5: "f32[8]" = torch.ops.aten.mul.Tensor(mul_4, cos); mul_4 = cos = None + mul_6: "f32[8]" = torch.ops.aten.mul.Tensor(mul_5, primals_0); primals_0 = None + mul_7: "f32[8]" = torch.ops.aten.mul.Tensor(mul_5, primals_1); mul_5 = primals_1 = None + return (mul_7, mul_6, None) +""", + ) + + def test_buffer_mutation_works_under_no_grad(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.ones(8, requires_grad=False)) + + @nested_compile_region + def forward(self, x, y): + self.buf.add_(1) + return torch.mul(x, y).sin() * self.buf + + mod_ref = Mod() + mod = Mod() + + def fn(mod, x, y): + return mod(x, y) + mod(x, y) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(mod_ref, x, y) + + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + with torch.no_grad(): + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) + self.assertEqual(ref, res) + self.assertEqual(mod_ref.buf, mod.buf) + + mod = Mod() + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + with torch.inference_mode(): + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) + self.assertEqual(ref, res) + self.assertEqual(mod_ref.buf, mod.buf) + + mod = Mod() + x_clone = x.detach().clone().requires_grad_(False) + y_clone = y.detach().clone().requires_grad_(False) + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) + self.assertEqual(ref, res) + self.assertEqual(mod_ref.buf, mod.buf) + + def test_buffer_mutation_errors_under_training(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.ones(8, requires_grad=False)) + + @nested_compile_region + def forward(self, x, y): + self.buf.add_(1) + return torch.mul(x, y).sin() * self.buf + + mod = Mod() + + def fn(mod, x, y): + return mod(x, y) + mod(x, y) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + with self.assertRaisesRegex( + RuntimeError, + "does not currently support training with in-place input or buffer mutations", + ): + torch.compile(fn, backend="inductor", fullgraph=True)(mod, x, y) + def test_list(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return [torch.mul(x, y), torch.add(x, y)] @@ -325,7 +494,7 @@ def fn(x, y): self.assertEqual(y.grad, y_clone.grad) def test_tuple_of_tuple(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return ((torch.mul(x, y),), torch.add(x, y)) @@ -361,7 +530,7 @@ def backward(ctx, grad_out): a = grad_out.view(12, 5) return torch.cos(torch.reshape(a, (3, 4, 5))) - @mark_compile_region + @nested_compile_region def gn(x): return CustomOp.apply(x) @@ -388,7 +557,7 @@ def fn(x): @requires_cuda def test_sdpa(self): - @mark_compile_region + @nested_compile_region def gn(q, k, v): return torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True @@ -420,7 +589,7 @@ def fn(q, k, v): res.sum().backward() def test_symint_from_fwd_to_bwd(self): - @mark_compile_region + @nested_compile_region def gn(x, y): a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0]) return torch.matmul(a, y) @@ -456,11 +625,11 @@ def test_dropout_checks_joint_graph(self): # graph passes. Without running joint graph passes, we would get an # error like AssertionError: should have been handled in # replace_random.py - @mark_compile_region + @nested_compile_region def gn(x): return torch.nn.functional.dropout(torch.sin(x), p=0.5) - @mark_compile_region + @nested_compile_region def hn(x): return torch.sin(x) @@ -523,7 +692,7 @@ def forward(self, primals_0: "f32[8]"): def test_dropout_checks_joint_graph_inference(self): # Checks that joint graph results in inductor seeds for just the inference graph - @mark_compile_region + @nested_compile_region def gn(x): return torch.nn.functional.dropout(torch.sin(x), p=0.5) @@ -562,7 +731,7 @@ def forward(self, arg0_1: "f32[8]"): ) def test_dedupe(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return torch.mul(x, y) @@ -648,7 +817,7 @@ def forward(self, primals_0: "f32[8]", primals_1: "f32[8]"): ) def test_dce(self): - @mark_compile_region + @nested_compile_region def gn(x): x = torch.sin(x) # should be dce'd @@ -684,7 +853,7 @@ def forward(self, arg0_1: "f32[4]"): def test_nonlocal_update(self): counter = 2 - @mark_compile_region + @nested_compile_region def gn(x, y): nonlocal counter return (torch.mul(x, y) * counter,) @@ -749,7 +918,7 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): ) def test_view_to_reshape(self): - @mark_compile_region + @nested_compile_region def gn(x): x = torch.sin(x) x = x.view(1, 8) @@ -789,7 +958,7 @@ def forward(self, arg0_1: "f32[8]"): ) def test_normalize_gm(self): - @mark_compile_region + @nested_compile_region def gn(x, y): # Different graph give different names to intermediate nodes for _ in range(5): @@ -847,7 +1016,7 @@ def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): ) def test_input_mutation(self): - @mark_compile_region + @nested_compile_region def gn(x, y): x.add_(1) return torch.mul(x, y) @@ -859,14 +1028,92 @@ def fn(x, y): y = torch.randn(8, requires_grad=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered input mutation during higher order op tracing", + + x_clone = x.clone() + self.assertEqual(opt_fn(x, y), fn(x_clone, y)) + + def test_input_mutation_mutiple_times(self): + @nested_compile_region + def gn(x, y): + x.add_(1) + return torch.mul(x, y) + + def fn(x, y): + z = gn(x, y) + for _ in range(16): + z += gn(x, y) + return z + + x = torch.randn(8, requires_grad=False) + x_clone = x.clone() + y = torch.randn(8, requires_grad=False) + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + + with ( + torch.no_grad(), ): - opt_fn(x, y) + out = opt_fn(x, y) + exp_out = fn(x_clone, y) + self.assertEqual(exp_out, out) + self.assertEqual(x_clone, x) + + def test_input_mutation_mutiple_times_fake_tensor_cahche_hit(self): + @nested_compile_region + def gn(x, y): + x.add_(1) + return torch.mul(x, y) + + def fn(x, y): + z = gn(x, y) + for _ in range(16): + z += gn(x, y) + return z + + x = torch.randn(8, requires_grad=False) + x_clone = x.clone() + y = torch.randn(8, requires_grad=False) + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + + fake_prop_count = 0 + + def _mock_invoke_subgraph(mode, subgraph, identifer, *operands): + nonlocal fake_prop_count + fake_prop_count += 1 + return (operands[0].clone(),) + + with ( + mock.patch( + "torch._higher_order_ops.utils.registered_hop_fake_fns", + {torch.ops.higher_order.invoke_subgraph: _mock_invoke_subgraph}, + ), + torch.no_grad(), + ): + out = opt_fn(x, y) + + # Fake propagation occurs only twice, with subsequent calls using cached results. + # + # First fake propagation (in collect_metadata_analysis of AOT): + # - Uses the original Dynamo graph + # - Flow: functionalization -> fake tensor + # + # Second fake propagation (in _create_graph of AOT): + # - Uses a materialized graph that includes epilogue operations + # - Flow: functionalization -> proxy -> fake tensor + # + # The key difference: the second time we materialize the graph with epilogue + # operations included in the proxy key. Since the dynamo graph module is not + # in the functional + epilogue format, the cache key should be different, + # preventing cache reuse between these two phases. + self.assertEqual(fake_prop_count, 2) + exp_out = fn(x_clone, y) + self.assertEqual(exp_out, out) + self.assertEqual(x_clone, x) def test_input_mutation_inference_mode(self): - @mark_compile_region + @nested_compile_region def gn(x, y): x.add_(1) return torch.mul(x, y) @@ -881,15 +1128,15 @@ def fn(x, y): y = torch.randn(8, requires_grad=False) with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered input mutation during higher order op tracing", + RuntimeError, + "Inplace update to inference tensor outside InferenceMode is not allowed", ): opt_fn(x, y) def test_simple_module(self): mod = torch.nn.Linear(8, 8) - @mark_compile_region + @nested_compile_region def gn(x): return torch.cos(x), mod(x) @@ -930,7 +1177,7 @@ def fn(x): opt_fn(x) def test_input_output_aliasing(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return (x, torch.mul(x, y)) @@ -942,14 +1189,21 @@ def fn(x, y): y = torch.randn(8, requires_grad=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered aliasing during higher order op tracing", - ): + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: opt_fn(x, y) + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + def test_input_input_aliasing(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return torch.mul(x, y) @@ -959,14 +1213,21 @@ def fn(x): x = torch.randn(8, requires_grad=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered aliasing during higher order op tracing", - ): + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: opt_fn(x) + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + def test_output_output_aliasing(self): - @mark_compile_region + @nested_compile_region def gn(x): z = torch.cos(x) return z, z.view(1, 8) @@ -977,12 +1238,19 @@ def fn(x): x = torch.randn(8, requires_grad=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered aliasing during higher order op tracing", - ): + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: opt_fn(x) + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + def test_mod_attr_aliasing(self): class MutateParam(torch.nn.Module): def __init__(self): @@ -993,7 +1261,7 @@ def forward(self, x): self.a.add_(1) return torch.mul(x, self.a) - @mark_compile_region + @nested_compile_region def gn(x): return mod(x) @@ -1004,17 +1272,58 @@ def fn(x, y): x = torch.randn(8, requires_grad=False) y = torch.randn(8, requires_grad=False) - fn(x, y) - opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Encountered input mutation during higher order op tracing", - ): - opt_fn(x, y) + + compiled_out = opt_fn(x, y) + # reset constant attr + mod.a = torch.ones(8) + self.assertEqual(compiled_out, fn(x, y)) + + def test_redundant_compile_region(self): + @nested_compile_region + @nested_compile_region + def gn(x): + return torch.sin(x) + + def fn(x): + return gn(x) + gn(x) + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + + x = torch.randn(8, 8, requires_grad=True) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8, 8]"): + l_x_ = L_x_ + + subgraph_0 = self.subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = None + getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None + subgraph_1 = self.subgraph_0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', l_x_); subgraph_1 = l_x_ = None + getitem_1: "f32[8, 8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + + add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None + return (add,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8, 8]"): + sin: "f32[8, 8]" = torch.sin(l_x_); l_x_ = None + return (sin,) +""", + ) def test_kwargs_only(self): - @mark_compile_region + @nested_compile_region def gn(x, *, y): return x * y @@ -1035,7 +1344,7 @@ def __init__(self): super().__init__() self.linear = torch.nn.Linear(8, 8) - @mark_compile_region + @nested_compile_region def helper(self, x): return self.linear(x) @@ -1092,7 +1401,7 @@ def forward(self, x): class Mod(torch.nn.Module): def __init__(self): super().__init__() - self.submod = mark_compile_region(SubMod()) + self.submod = nested_compile_region(SubMod()) def forward(self, x): return x + self.submod(x) * self.submod(x) + x @@ -1143,7 +1452,7 @@ def test_return_none(self): ) ones = torch.ones(1000, device="cuda:0", dtype=torch.float32) - @mark_compile_region + @nested_compile_region def fn(x, train): return F.dropout(x * weight, 0.33, train) @@ -1156,7 +1465,7 @@ def run(x, train=True): weight.grad.clone() def test_return_none_from_fwd(self): - @mark_compile_region + @nested_compile_region def gn(x): return x * 2, None, x * 3 @@ -1260,7 +1569,7 @@ def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): ) def test_dynamic(self): - @mark_compile_region + @nested_compile_region def gn(x): return torch.sin(x) @@ -1274,9 +1583,24 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_complex(self): + # Observed in Wan2.1 + @nested_compile_region + def gn(x): + return torch.sin(x) + + def fn(x): + return gn(x) + gn(x) + + x = torch.randn(2, 2, dtype=torch.complex64) + ref = fn(x) + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_pending_unbacked(self): - @mark_compile_region + @nested_compile_region def gn(x): u = x[0].item() return x * u @@ -1295,7 +1619,7 @@ def fn(x): @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_unbacked(self): - @mark_compile_region + @nested_compile_region def gn(x, y): b = x.item() torch._check_is_size(b) @@ -1315,7 +1639,7 @@ def fn(x, y): self.assertEqual(ref, res) def test_bwd_partitioning(self): - @mark_compile_region + @nested_compile_region def gn(x, y): z = torch.matmul(x, y) return torch.sin(z) @@ -1398,7 +1722,7 @@ def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: ) def test_const_tensor(self): - @mark_compile_region + @nested_compile_region def gn(x): return torch.tensor(64, dtype=torch.float32) * x @@ -1417,14 +1741,14 @@ def test_ac(self): def fn1(x): return torch.cos(x) - @mark_compile_region + @nested_compile_region def fn1_checkpoint(x): return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False) def fn2(x): return torch.sin(x) - @mark_compile_region + @nested_compile_region def fn2_checkpoint(x): return torch.utils.checkpoint.checkpoint(fn2, x, use_reentrant=False) @@ -1462,8 +1786,45 @@ def fn(x): res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone) self.assertEqual(ref, res) + @torch._inductor.config.patch(fallback_random=True) + def test_ac_rng(self): + def fn1(x): + return torch.cos(torch.nn.functional.dropout(x, p=0.5)) + + @nested_compile_region + def fn1_checkpoint(x): + return torch.utils.checkpoint.checkpoint(fn1, x, use_reentrant=False) + + def fn(x): + return fn1_checkpoint(x) + fn1_checkpoint(x) + + x = torch.randn(8, requires_grad=True) + torch.manual_seed(0) + ref = fn(x) + ref.sum().backward() + + x_clone = x.clone().detach().requires_grad_(True) + backend = AotEagerAndRecordGraphs() + + torch.manual_seed(0) + res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone) + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + + # Check that the Dynamo and AOT graphs have just one subgraph module + self.assertEqual(len(backend.graphs), 1) + self.assertEqual(len(backend.fw_graphs), 1) + self.assertEqual(len(backend.bw_graphs), 1) + + torch.manual_seed(0) + res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone) + self.assertEqual(ref, res) + res.sum().backward() + def test_fake_tensor_checking(self): - @mark_compile_region + @nested_compile_region def gn(x): return torch.sin(x) @@ -1517,7 +1878,7 @@ def test_different_symint(self): Tests check that the same subgraph called with different symints use different graphs """ - @mark_compile_region + @nested_compile_region def gn(x): return torch.sin(x) @@ -1592,7 +1953,7 @@ def backward(ctx, grad_out): (x,) = ctx.saved_tensors return x * torch.cos(grad_out) - @mark_compile_region + @nested_compile_region def gn(x): return CustomOp.apply(x) @@ -1692,7 +2053,7 @@ def grid_fn(meta): return output - @mark_compile_region + @nested_compile_region def gn(x, y): o = torch.zeros_like(x) call_triton_add(x, y, o, 0) @@ -1762,7 +2123,7 @@ def forward(self, z: "f32[5]", y: "f32[5]"): @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) def test_unbacked_symbol(self): - @mark_compile_region + @nested_compile_region def gn(x): return torch.sin(torch.nonzero(x)) @@ -1779,7 +2140,7 @@ def fn(x): self.assertEqual(ref, res) def test_different_strides_in_backward(self): - @mark_compile_region + @nested_compile_region def gn(x): return torch.cos(x) @@ -1925,12 +2286,11 @@ def forward(self, primals_0: "Sym(s77)", primals_1: "f32[s77, 16]", tangents_0: """, ) - @unittest.skip("Repro for an issue which is not fixed yet") def test_div(self): - @mark_compile_region + @nested_compile_region def gn(x): div = torch.div(1024, 256, rounding_mode="trunc") - return div * torch.ones(64, div) + return div * torch.ones(64, div) * x def fn(x): return gn(x) @@ -1943,17 +2303,179 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + @requires_gpu + def test_preserves_strides(self): + class _CustomPass(PatternMatcherPass): + def __init__(self) -> None: + super().__init__() + + def __call__(self, g: torch.fx.Graph): + self.apply(g) + + g = _CustomPass() + called = False + + x = torch.randn(4, 4, 2, 2, device=GPU_TYPE) + other = torch.randn(4, 4, 2, 2, device=GPU_TYPE) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.permute), + pass_dict=g, + ) + def _(match, *args, **kwargs): + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return torch.ops.mylib.force_channels_last( + torch.ops.aten.permute(*args, **kwargs) + ) + + nonlocal called + called = True + match.replace_by_example(decomp, flat_args) + + from torch._inductor import config + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define( + "force_channels_last(Tensor x) -> Tensor", + tags=[torch._C.Tag.flexible_layout], + ) + + def impl2(x): + return x.clone(memory_format=torch.channels_last) + + lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd") + + lib.define( + "add_op(Tensor x, Tensor y) -> Tensor", + ) + + def impl(x, y): + out = y.clone() # contiguous with strides (16, 4, 2, 1) + out.add_(x.transpose(-1, -2)) + return out + + def meta(x, y): + return torch.empty_like(y, memory_format=torch.contiguous_format) + + lib.impl("add_op", impl, "CompositeExplicitAutograd") + lib.impl("add_op", meta, "Meta") + @nested_compile_region + def gn(y, z): + return torch.ops.mylib.add_op.default(y, z) + + def f(x, other): + y = x.transpose(2, 3).contiguous().transpose(2, 3) + z = y.sin().transpose(2, 3) + return gn(y, z) + + with config.patch( + post_grad_custom_post_pass=g, + ): + f_compile = torch.compile(f, fullgraph=True) + self.assertEqual(f(x, other), f_compile(x, other)) + self.assertTrue(called) + + @requires_gpu + def test_preserves_output_strides(self): + # Have a graph pass that changes strides for the output op of the + # invoke_subgraph, and check if the output strides are preserved + x = torch.randn(4, 4, 2, 2, device=GPU_TYPE) + other = torch.randn(4, 4, 2, 2, device=GPU_TYPE) + + class _CustomPass(PatternMatcherPass): + def __init__(self) -> None: + super().__init__() + + def __call__(self, g: torch.fx.Graph): + self.apply(g) + + g = _CustomPass() + called = False + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.permute), + pass_dict=g, + ) + def _(match, *args, **kwargs): + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return torch.ops.mylib.force_channels_last( + torch.ops.aten.permute(*args, **kwargs) + ) + + nonlocal called + called = True + match.replace_by_example(decomp, flat_args) + + from torch._inductor import config + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define( + "force_channels_last(Tensor x) -> Tensor", + tags=[torch._C.Tag.flexible_layout], + ) + + def impl2(x): + return x.clone(memory_format=torch.channels_last) + + lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd") + + lib.define( + "add_op(Tensor x, Tensor y) -> Tensor", + ) + + def impl(x, y): + # Check that the input strides are preserved. This helps in + # testing that the HOP preserves the output strides. + assert x.stride() == (16, 4, 1, 2) + assert y.stride() == (16, 4, 2, 1) + out = y.clone() # contiguous with strides (16, 4, 2, 1) + out.add_(x.transpose(-1, -2)) + return out + + def meta(x, y): + return torch.empty_like(y, memory_format=torch.contiguous_format) + + lib.impl("add_op", impl, "CompositeExplicitAutograd") + lib.impl("add_op", meta, "Meta") + + @nested_compile_region + def gn(x, other): + y = x.transpose(2, 3).contiguous().transpose(2, 3) + z = y.sin().transpose(2, 3) + return y, z + + def f(x, other): + y, z = gn(x, other) + return torch.ops.mylib.add_op.default(y, z) + + with config.patch( + post_grad_custom_post_pass=g, + ): + f_compile = torch.compile(f, fullgraph=True) + self.assertEqual(f(x, other), f_compile(x, other)) + self.assertTrue(called) + + +@skipIfTorchDynamo("Not a torch._dynamo test") @parameterized_class( [ {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", ) class TestInvokeSubgraphExport(TestCase): def test_simple_func(self): - @mark_compile_region + @nested_compile_region def gn(x, y): return torch.mul(x, y) @@ -1992,7 +2514,7 @@ def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"): ) def test_unbacked(self): - @mark_compile_region + @nested_compile_region def gn(x, y): b = x.item() torch._check_is_size(b) @@ -2017,7 +2539,7 @@ def forward(self, x, y): def test_pending_unbacked(self): class M(torch.nn.Module): - @mark_compile_region + @nested_compile_region def gn(self, x): u = x[0].item() return x * u @@ -2049,7 +2571,7 @@ def forward(self, x): def test_simple_method(self): class M(torch.nn.Module): - @mark_compile_region + @nested_compile_region def gn(self, x, y): return torch.mul(x, y) @@ -2073,7 +2595,7 @@ def __init__(self): super().__init__() self.register_buffer("buf", b) - @mark_compile_region + @nested_compile_region def forward(self, x, y): return x * y + self.buf @@ -2095,5 +2617,24 @@ def forward(self, x, y): self.assertEqual(len(list(ep.graph_module.named_modules())), 2) +class NegativeTesting(TestCase): + def test_graph_break(self): + @nested_compile_region + def gn(x): + torch._dynamo.graph_break() + return torch.cos(x) + + def fn(x): + return gn(x) + + x = torch.randn(8, 8, requires_grad=True) + + with self.assertRaisesRegex( + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ): + torch.compile(fn, backend="eager")(x) + + if __name__ == "__main__": run_tests() diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index 65e83dfd177163..ae1d00c5b6346c 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -27,6 +27,53 @@ std::tuple, std::optional> fn_with_optiona return {t3, t4, t5}; } +std::tuple, std::optional> fn_with_optional_tensor_output_2_impl(Tensor t1, Tensor t2) { + Tensor t3 = t1 + t2; + Tensor t4; + Tensor t5 = t1 - t2; + return {t3, t4, t5}; +} + +std::tuple, std::optional> fn_with_optional_tensor_output_2_meta(Tensor t1, Tensor t2) { + Tensor t3 = t1.clone(); + Tensor t4; + Tensor t5 = t1.clone(); + return {t3, t4, t5}; +} + +std::tuple, std::optional, std::optional> fn_with_optional_tensor_nullopt_output_impl(Tensor t1, Tensor t2) { + Tensor t3 = t1 + t2; + Tensor t4; + Tensor t5 = t1 - t2; + return {t3, t4, t5, std::nullopt}; +} + + +std::tuple, std::optional, std::optional> fn_with_optional_tensor_nullopt_output_meta(Tensor t1, Tensor t2) { + Tensor t3 = t1.clone(); + Tensor t4; + Tensor t5 = t1.clone(); + return {t3, t4, t5, std::nullopt}; +} + +std::tuple, std::optional, int64_t, int64_t> fn_with_int_output_impl(Tensor t1, Tensor t2, int64_t i1) { + Tensor t3 = t1 + t2; + Tensor t4 = t1 - t2; + Tensor t5; + int64_t i2 = 0; + int64_t i3 = 0; + return {t3, t4, t5, i2, i3}; +} + +std::tuple, std::optional, int64_t, int64_t> fn_with_int_output_meta(Tensor t1, Tensor t2, int64_t i1) { + Tensor t3 = t1.clone(); + Tensor t4 = t1.clone(); + Tensor t5; + int64_t i2 = 0; + int64_t i3 = 0; + return {t3, t4, t5, i2, i3}; +} + Tensor fn_with_all_inputs_impl( const Tensor& tensor, const c10::List& tensors, @@ -364,6 +411,9 @@ extern "C" { TORCH_LIBRARY(aoti_custom_ops, m) { m.def("custom_add(Tensor t1, Tensor t2) -> Tensor"); m.def("fn_with_optional_tensor_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)"); + m.def("fn_with_optional_tensor_output_2(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)"); + m.def("fn_with_optional_tensor_nullopt_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?, Tensor?)"); + m.def("fn_with_int_output(Tensor t1, Tensor t2, int i) -> (Tensor, Tensor?, Tensor?, int, int)"); m.def( "fn_with_all_inputs(Tensor tensor, " "Tensor[] tensors, " @@ -410,6 +460,9 @@ TORCH_LIBRARY(aoti_custom_ops, m) { TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) { m.impl("custom_add", at::custom_add_impl); m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_impl); + m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_impl); + m.impl("fn_with_optional_tensor_nullopt_output", at::fn_with_optional_tensor_nullopt_output_impl); + m.impl("fn_with_int_output", at::fn_with_int_output_impl); m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl); m.impl("fn_with_default_input", at::fn_with_default_input_impl); m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl); @@ -422,6 +475,9 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) { TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) { m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_meta); + m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_meta); + m.impl("fn_with_optional_tensor_nullopt_output", at::fn_with_optional_tensor_nullopt_output_meta); + m.impl("fn_with_int_output", at::fn_with_int_output_meta); m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta); m.impl("fn_with_default_input", at::fn_with_default_input_meta); m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta); diff --git a/test/inductor/extension_backends/triton/device_interface.py b/test/inductor/extension_backends/triton/device_interface.py index 7c26640e216186..857885a662d5a9 100644 --- a/test/inductor/extension_backends/triton/device_interface.py +++ b/test/inductor/extension_backends/triton/device_interface.py @@ -80,16 +80,16 @@ def device_count() -> int: @staticmethod def maybe_exchange_device(device: int) -> int: - assert ( - device == 0 - ), f"Only device index 0 is supported, tried to set index to {device}" + assert device == 0, ( + f"Only device index 0 is supported, tried to set index to {device}" + ) return 0 # previous device is always 0 @staticmethod def exchange_device(device: int) -> int: - assert ( - device == 0 - ), f"Only device index 0 is supported, tried to set index to {device}" + assert device == 0, ( + f"Only device index 0 is supported, tried to set index to {device}" + ) return 0 # previous device is always 0 @staticmethod diff --git a/test/inductor/indirect_assert_helper.py b/test/inductor/indirect_assert_helper.py index ca9819ac1ec91a..33f74f44e52b6e 100644 --- a/test/inductor/indirect_assert_helper.py +++ b/test/inductor/indirect_assert_helper.py @@ -58,9 +58,9 @@ def lower2(x): assert dims in ("2", "3") shape_x = [3, 2, 4] if dims == "3" else [3, 2] if one_size: - assert ( - fn_name == "first_arg" - ), "only first_arg can be tested for a special case of 1-size tensor" + assert fn_name == "first_arg", ( + "only first_arg can be tested for a special case of 1-size tensor" + ) shape_x[0] = 1 assert dyn_shape in ("True", "False") dynamic_shapes = dyn_shape == "True" diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index d24d3d26974e31..28bd4bf0f22a60 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -5,6 +5,7 @@ import sys import tempfile import unittest +import zipfile from unittest import skip from unittest.mock import patch @@ -19,6 +20,7 @@ from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters from torch._inductor import config +from torch._inductor.package import package_aoti from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.test_case import TestCase from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code @@ -26,6 +28,7 @@ from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.export import Dim, export, export_for_training +from torch.export.pt2_archive._package import load_pt2 from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater @@ -50,11 +53,14 @@ TEST_WITH_ROCM, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut -from torch.testing._internal.inductor_utils import GPU_TYPE, IS_BIG_GPU +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test -from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu +from torch.testing._internal.triton_utils import requires_gpu from torch.utils import _pytree as pytree -from torch.utils._triton import has_triton_tma +from torch.utils._triton import ( + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, +) if HAS_GPU: @@ -66,11 +72,16 @@ add_kernel_2d_autotuned, add_kernel_autotuned, add_kernel_autotuned_weird_param_order, + add_kernel_on_device_tma_new_api, + add_kernel_on_device_tma_old_api, add_kernel_with_none_param_and_equal_to_1_arg, add_kernel_with_optional_param, add_kernel_with_scaling, - add_kernel_with_tma_1d, - add_kernel_with_tma_2d, + add_kernel_with_tma_1d_new_api, + add_kernel_with_tma_1d_old_api, + add_kernel_with_tma_2d_new_api, + add_kernel_with_tma_2d_old_api, + create_tensor_descriptor_shim, mul2_inplace_kernel, strange_config_matmul_kernel, sub_kernel_autotuned, @@ -125,7 +136,11 @@ class AOTInductorTestsTemplate: @common_utils.parametrize("embed_kernel_binary", [False, True]) - def test_simple(self, embed_kernel_binary): + @common_utils.parametrize("max_autotune", [False, True]) + def test_simple(self, embed_kernel_binary, max_autotune): + if self.device == "cpu" and IS_MACOS and max_autotune: + raise unittest.SkipTest("max_autotune not supported on macos") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -139,7 +154,12 @@ def forward(self, x, y): torch.randn(10, 10, device=self.device), ) model = Model() - with config.patch({"aot_inductor.embed_kernel_binary": embed_kernel_binary}): + with config.patch( + { + "aot_inductor.embed_kernel_binary": embed_kernel_binary, + "max_autotune": max_autotune, + } + ): self.check_model(model, example_inputs) _, code = run_and_get_cpp_code( @@ -161,10 +181,8 @@ def forward(self, x, y): "toolchain doesn't support ptx to fatbin", ) @skipIfRocm - @skipIfXpu @common_utils.parametrize("embed_kernel_binary", [True, False]) - @common_utils.parametrize("emit_current_arch_binary", [True, False]) - def test_simple_multi_arch(self, embed_kernel_binary, emit_current_arch_binary): + def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU_TYPE") @@ -185,7 +203,6 @@ def forward(self, x, y): { "aot_inductor.embed_kernel_binary": embed_kernel_binary, "aot_inductor.emit_multi_arch_kernel": True, - "aot_inductor.emit_current_arch_binary": emit_current_arch_binary, } ): self.check_model(model, example_inputs) @@ -193,7 +210,8 @@ def forward(self, x, y): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) - FileCheck().check(".fatbin").run(code) + file_extension = ".spv" if self.device == "xpu" else ".fatbin" + FileCheck().check(file_extension).run(code) def test_small_constant(self): class Model(torch.nn.Module): @@ -296,11 +314,14 @@ def forward(self, x): return torch.matmul(x, w) example_inputs = (torch.randn(4, 4, device=self.device),) - with torch.no_grad(), config.patch( - { - "always_keep_tensor_constants": True, - "aot_inductor.use_runtime_constant_folding": True, - } + with ( + torch.no_grad(), + config.patch( + { + "always_keep_tensor_constants": True, + "aot_inductor.use_runtime_constant_folding": True, + } + ), ): model = Model(self.device) so_path = AOTIRunnerUtil.legacy_compile( @@ -398,6 +419,77 @@ def forward(self, y): ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True} ) + @common_utils.parametrize("dynamic", [False, True]) + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_triton_kernel_on_device_tma(self, dynamic, tma_version): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_on_device_tma_new_api + if tma_version == "new" + else add_kernel_on_device_tma_old_api + ) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + BLOCK_SIZE = 32 + out = torch.zeros_like(a) + m, n = out.size() + + # Allocate workspace for on-device TMA descriptors + # Need 128 bytes per descriptor, 3 descriptors total + if tma_version == "old": + workspace = torch.zeros(3 * 128, dtype=torch.uint8, device=a.device) + else: + workspace = None + + grid = (triton.cdiv(m, BLOCK_SIZE), triton.cdiv(n, BLOCK_SIZE)) + + kernel[grid]( + a, + b, + out, + m, + n, + workspace, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn((32 * 4, 32 * 8), device=self.device) + b = torch.randn((32 * 4, 32 * 8), device=self.device) + example_inputs = (a, b) + + triton.set_allocator( + lambda size, align, stream: torch.empty( + size, dtype=torch.int8, device="cuda" + ) + ) + + dynamic_shapes = None + if dynamic: + dim0 = Dim("s0", min=2, max=1024) + dim1 = Dim("s1", min=2, max=1024) + dynamic_shapes = { + "a": {0: dim0, 1: None}, + "b": {0: dim1, 1: None}, + } + + self.check_model( + Model(), + example_inputs=example_inputs, + dynamic_shapes=dynamic_shapes, + ) + @requires_gpu def test_multi_device(self): if self.device == "cpu" and GPU_TYPE == "xpu": @@ -976,6 +1068,23 @@ def forward(self, x, y): example_inputs = (x, y) self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + def test_large_dynamic_dim(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + add_0 = x + y + return torch.nn.functional.relu(input=add_0, inplace=False) + + x = torch.randn(128, 2048, device=self.device) + y = torch.randn(128, 2048, device=self.device) + # Use a dimension that exceeds the maximum value of a C long long (2^63 - 1) + dim0_x = Dim("dim0_x", min=1, max=1171368248680556527362) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}} + example_inputs = (x, y) + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + @unittest.skipIf( not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", @@ -1830,7 +1939,7 @@ def test_cond_mismatched_branch_output(self, dynamic): # Note the minimum has to be 4 because the model # is slicing over the first dim with [2:], if first # dim is 2 or 3, the slicing will be 0/1 specialized, - # causing a constraint violation eror. + # causing a constraint violation error. dim0_a = Dim("s0", min=4, max=1024) dim0_b = Dim("s1", min=4, max=1024) dynamic_shapes = { @@ -2265,9 +2374,10 @@ def forward(self, x): example_inputs = (torch.randn(32, 16),) model = Model().eval() - with config.patch( - {"freezing": True, "aot_inductor.force_mmap_weights": True} - ), torch.no_grad(): + with ( + config.patch({"freezing": True, "aot_inductor.force_mmap_weights": True}), + torch.no_grad(), + ): exported_model = export_for_training( model, example_inputs, strict=True ).module() @@ -2282,6 +2392,49 @@ def forward(self, x): self.check_model(converted_model, example_inputs) + def test_fallback_mem_leak_fix(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y, idx): + tmp = x + y + w = torch.ops.aten.as_strided(tmp, x.shape, x.stride()) + out = torch.ops.aten.index.Tensor(w, [idx]) + return w, out + + example_inputs = ( + torch.randn(4, 1, 4, device=GPU_TYPE), + torch.randn(4, 1, 4, device=GPU_TYPE), + torch.randn(4, device=GPU_TYPE) > 0, + ) + + dim0 = Dim("dim0", min=1, max=2048) + dynamic_shapes = { + "x": {0: dim0}, + "y": {0: dim0}, + "idx": {0: dim0}, + } + package_path: str = AOTIRunnerUtil.compile( + Model(), + example_inputs, + dynamic_shapes=dynamic_shapes, + ) + aot_inductor_module = torch._inductor.aoti_load_package(package_path) + device_interface = get_interface_for_device(GPU_TYPE) + device: int = device_interface.current_device() + mem_before = device_interface.memory_allocated(device) + aot_inductor_module(*example_inputs) + mem_after = device_interface.memory_allocated(device) + self.assertEqual(mem_before, mem_after) + + actual = aot_inductor_module(*example_inputs) + expected = Model()(*example_inputs) + torch.testing.assert_close(actual, expected) + @requires_multigpu() def test_replicate_on_devices(self): if self.device != GPU_TYPE: @@ -2415,6 +2568,41 @@ def forward(self, x, y): self.assertTrue(same(result_cpu, result_gpu_0.cpu())) self.assertTrue(same(result_cpu, result_gpu_1.cpu())) + @requires_multigpu() + def test_load_package_multiple_gpus(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Model(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x, y): + return x + torch.nn.functional.linear(y, self.weight) + + weight = torch.randn(10, 10, device=self.device) + inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + model = Model(weight).to(device=self.device) + result_ref = model(*inputs) + + package_path = AOTIRunnerUtil.compile(model, inputs) + + # Load AOT package on gpu:N + device_interface = get_interface_for_device(GPU_TYPE) + for i in range(device_interface.device_count()): + device = torch.device(GPU_TYPE, i) + with device_interface.device(i), torch.no_grad(): + model_package = torch._inductor.aoti_load_package( + package_path, device_index=i + ) + inputs_on_device = [input.to(device=device) for input in inputs] + result_package = model_package(*inputs_on_device) + self.assertTrue(same(result_ref.cpu(), result_package.cpu())) + def test_reuse_kernel(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -2589,6 +2777,29 @@ def forward(self, x): example_inputs = (torch.randn(8, 4, 4, device=self.device),) self.check_model(Model(), example_inputs) + @patch("torch._dynamo.utils.CompileEventLogger.log_instant_event") + def test_backward_no_op_logging(self, mock_log_instant_event): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x + + model = Model() + dummy_input = torch.randn(1, 5) + + from torch._dynamo.utils import CompileEventLogLevel + from torch._inductor import compile_fx + + graph_module = torch.fx.symbolic_trace(model) + compile_fx._compile_fx_inner(graph_module, (dummy_input,)) + mock_log_instant_event.assert_called_once_with( + "backward no-op", + metadata={"compile_id": None}, + log_level=CompileEventLogLevel.PT2_COMPILE, + ) + @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_dup_unbacked_sym_decl(self): class Model(torch.nn.Module): @@ -3003,11 +3214,20 @@ def forward(self, x): self.check_model(Model(), example_inputs) @common_utils.parametrize("dynamic", [False, True]) - def test_triton_kernel_tma_descriptor_1d(self, dynamic): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_triton_kernel_tma_descriptor_1d(self, dynamic, tma_version): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") - if not has_triton_tma(): - raise unittest.SkipTest("requires Triton TMA") + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_1d_new_api + if tma_version == "new" + else add_kernel_with_tma_1d_old_api + ) class Model(torch.nn.Module): def __init__(self) -> None: @@ -3019,11 +3239,8 @@ def forward(self, a, b): n_elements = out.numel() desc_a, desc_b, desc_out = ( - triton.tools.experimental_descriptor.create_1d_tma_descriptor( - t.data_ptr(), - n_elements, - BLOCK_SIZE, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE], new_api=(tma_version == "new") ) for t in (a, b, out) ) @@ -3031,7 +3248,7 @@ def forward(self, a, b): grid = lambda meta: ( # noqa: E731 triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) - add_kernel_with_tma_1d[grid]( + kernel[grid]( desc_a, desc_b, desc_out, @@ -3059,11 +3276,20 @@ def forward(self, a, b): ) @common_utils.parametrize("dynamic", [False, True]) - def test_triton_kernel_tma_descriptor_2d(self, dynamic): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_triton_kernel_tma_descriptor_2d(self, dynamic, tma_version): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") - if not has_triton_tma(): - raise unittest.SkipTest("requires Triton TMA") + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_2d_new_api + if tma_version == "new" + else add_kernel_with_tma_2d_old_api + ) class Model(torch.nn.Module): def __init__(self) -> None: @@ -3076,13 +3302,10 @@ def forward(self, a, b): x_size, y_size = out.size() desc_a, desc_b, desc_out = ( - triton.tools.experimental_descriptor.create_2d_tma_descriptor( - t.data_ptr(), - x_size, - y_size, - BLOCK_SIZE_X, - BLOCK_SIZE_Y, - t.element_size(), + create_tensor_descriptor_shim( + t, + [BLOCK_SIZE_X, BLOCK_SIZE_Y], + new_api=(tma_version == "new"), ) for t in (a, b, out) ) @@ -3091,7 +3314,7 @@ def forward(self, a, b): triton.cdiv(x_size, meta["BLOCK_SIZE_X"]), triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]), ) - add_kernel_with_tma_2d[grid]( + kernel[grid]( desc_a, desc_b, desc_out, @@ -3288,8 +3511,8 @@ def forward(self, x, y): if dynamic: dim0_xy = Dim("s0", min=2, max=1024) dynamic_shapes = { - "x": {0: dim0_xy, 1: None}, - "y": {0: dim0_xy, 1: None}, + "x": {0: dim0_xy}, + "y": {0: dim0_xy}, } example_inputs = ( torch.randn(2, device=self.device), @@ -3322,6 +3545,42 @@ def forward(self, x): x = torch.randn(16, 16, device=self.device) self.check_model(Model(), (x,)) + def test_triton_kernel_dynamic_grid(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + import math + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y, n_elements_tensor): + output = torch.zeros_like(x) + n_elements_symint = n_elements_tensor.item() + n_elements = x.numel() + + def grid(meta): + n_elements_complicated = n_elements_symint // 1.0 + return (math.trunc(n_elements_complicated / meta["BLOCK_SIZE"]),) + + add_kernel_autotuned[grid]( + x, + y, + output, + n_elements, + ) + + return output + + x = torch.randn(128, device=self.device) + y = torch.randn(128, device=self.device) + n_elem = torch.tensor(128) + dim0_x = Dim("dim0_x", min=8, max=256) + dim0_y = Dim("dim0_y", min=8, max=256) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}, "n_elements_tensor": {}} + self.check_model(Model(), (x, y, n_elem), dynamic_shapes=dynamic_shapes) + def test_shifted_constraint_ranges(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -3414,6 +3673,46 @@ def forward( self.check_model(Model(), inputs) + def test_narrow_fallback(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, inp: torch.Tensor, dim: int, start: int, length: int): + return torch.ops.aten.narrow(inp, dim, start, length) + + inputs = (torch.rand((3, 4), device=self.device), 0, 0, 2) + + self.check_model(Model(), inputs) + + def test_pad_fallback(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, + inp: torch.Tensor, + pad: tuple[int, ...], + ): + return torch.ops.aten.pad(inp, pad) + + inputs = (torch.rand((3, 3, 4, 2), device=self.device), (0, 1, 2, 1, 3, 3)) + + self.check_model(Model(), inputs) + + def test_fill__fallback(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, inp: torch.Tensor, scalar: float): + torch.ops.aten.fill_(inp, scalar) + return inp + + inputs = (torch.rand((3, 3, 4, 2), device=self.device), 0.5) + self.check_model(Model(), inputs) + @common_utils.parametrize("embed_kernel_binary", [False, True]) def test_repeated_user_defined_triton_kernel(self, embed_kernel_binary): if self.device != GPU_TYPE: @@ -4247,7 +4546,7 @@ def forward(self, x): self.assertTrue(result[0].data_ptr() != result[1].data_ptr()) def test_multiple_output_alias(self): - # Test when mutliple outputs alias the same tensor + # Test when multiple outputs alias the same tensor class Model(torch.nn.Module): def forward(self, x): squared = x * x @@ -5077,22 +5376,28 @@ def forward(self, a): model = Model(N, K, self.device) a = torch.randn(M, K, device=self.device) example_inputs = (a,) - with torch.no_grad(), config.patch( - { - "always_keep_tensor_constants": True, - "aot_inductor.package_constants_in_so": True, - } + with ( + torch.no_grad(), + config.patch( + { + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": True, + } + ), ): so_path = AOTIRunnerUtil.legacy_compile( model=model, example_inputs=example_inputs, ) - with torch.no_grad(), config.patch( - { - "always_keep_tensor_constants": True, - "aot_inductor.package_constants_in_so": False, - } + with ( + torch.no_grad(), + config.patch( + { + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": False, + } + ), ): so_path_weightless = AOTIRunnerUtil.legacy_compile( model=model, @@ -5133,6 +5438,47 @@ def runner_call(*args, **kwargs): output = runner_call(test_inputs) self.assertEqual(expected, output) + def test_weight_on_disk_legacy(self): + class Model(torch.nn.Module): + def __init__(self, n, k, device): + super().__init__() + self.weight = torch.randn(n, k, device=device) + self.bias = torch.randn(n, device=device) + + def forward(self, a): + return torch.nn.functional.linear(a, self.weight, self.bias) + + M, N, K = 128, 2048, 4096 + model = Model(N, K, self.device) + a = torch.randn(M, K, device=self.device) + example_inputs = (a,) + + with ( + torch.no_grad(), + config.patch( + { + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": False, + "aot_inductor.package_constants_on_disk": True, + "aot_inductor.package": True, + } + ), + ): + aoti_files = AOTIRunnerUtil.legacy_compile( + model=model, + example_inputs=example_inputs, + ) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + package_path = package_aoti( + f.name, + {"model": aoti_files}, + ) + pt2_contents = load_pt2(package_path, load_weights_from_disk=True) + loaded1 = pt2_contents.aoti_runners["model"] + + self.assertEqual(loaded1(a), model(a)) + def test_extract_constants_map(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -5727,6 +6073,82 @@ def convert_weight_to_int4pack(b): model = Model(b_int4pack, b_scales_and_zeros_f32) self.check_model(model, (a,)) + @parametrize("m", [32]) + @parametrize("n", [64]) + @parametrize("q_group", [32, 64]) + @parametrize("num_groups", [1, 2]) + def test__weight_int4pack_mm_with_scales_and_zeros(self, m, n, q_group, num_groups): + if "xpu" not in self.device: + raise unittest.SkipTest("requires Intel GPU") + + class Model(torch.nn.Module): + def __init__(self, weight, scale, zeros) -> None: + super().__init__() + self.weight = weight + self.scale = scale + self.zeros = zeros + + def forward(self, a): + return torch._weight_int4pack_mm_with_scales_and_zeros( + a, self.weight, q_group, self.scale, self.zeros + ) + + def _group_quantize_tensor_xpu(w, n_bit=4, q_group_size=16): + # w [k, n] = [32, 48] + assert w.dim() == 2 + # w [n, k] = [48, 32] + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + # to_quant: [n * k / group_size, group_size] + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_int - min_val.div(scales).round() + zeros = torch.clamp(zeros, min_int, max_int) + zeros = zeros.to(torch.int8) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + # [n, k] + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device != torch.device("cpu"): + out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous() + zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous() + + return out, scales, zeros + + def convert_weight_to_int4pack(b): + # b_uint8 [n, k //2] + b_uint8, scales, zeros = _group_quantize_tensor_xpu( + b, n_bit=4, q_group_size=q_group + ) + # b_int4pack [k//8, n] + b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=2) + + return b_int4pack, scales, zeros + + k = q_group * num_groups + a = torch.rand((m, k), device=self.device, dtype=torch.bfloat16) + b = torch.rand((k, n), device=self.device, dtype=torch.bfloat16) + b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b) + model = Model(b_int4pack, b_scales, zeros_int8) + self.check_model(model, (a,)) + def test_assert_tensor_meta(self): class Module(torch.nn.Module): def forward(self, x): @@ -5834,6 +6256,99 @@ def forward(self, x, y, m): Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1 ) + @skipIfRocm + @patch.dict(os.environ, {"TRITON_DEBUG": "1"}) + def test_triton_dynamic_launcher_grid(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=["numel"], + ) + @triton.jit + def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + tl.device_assert(block_start < numel) + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + x = tl.load(X + offsets) + y = x + 1 + tl.store(Y + offsets, y) + + class Model(torch.nn.Module): + def forward(self, x, value): + numel = value.item() + out = torch.zeros_like(x, dtype=torch.float16) + + grid = lambda META: ( # noqa: E731 + triton.cdiv(numel, META["BLOCK_SIZE"]), + ) + add_one_kernel[grid](x, out, numel) + + return out + + example_inputs = ( + torch.randn(1024, device=self.device), + torch.tensor([1024], dtype=torch.int32, device=self.device), + ) + + with config.patch("triton.autotune_with_sample_inputs", True): + dim0_x = Dim("dim0_x", min=2, max=8192) + dynamic_shapes = {"x": {0: dim0_x}, "value": {0: Dim.AUTO}} + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + + @skipIfRocm + @patch.dict(os.environ, {"TRITON_DEBUG": "1"}) + def test_triton_dynamic_launcher_grid_infer_from_tensor(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=["numel"], + ) + @triton.jit + def add_one_kernel(X, Y, numel, BLOCK_SIZE: "tl.constexpr"): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + tl.device_assert(block_start < numel) + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offsets) + y = x + 1 + tl.store(Y + offsets, y) + + class Model(torch.nn.Module): + def forward(self, x, dim_D): + numel = x.shape[1] * dim_D.item() + x = x.repeat(dim_D, 1) + out = torch.zeros_like(x, dtype=torch.float16) + + grid = lambda META: ( # noqa: E731 + triton.cdiv(numel, META["BLOCK_SIZE"]), + ) + add_one_kernel[grid](x, out, numel) + + return out + + example_inputs = ( + torch.randn(1, 1024, device=self.device), + torch.tensor([2], dtype=torch.int32, device=self.device), + ) + + with config.patch("triton.autotune_with_sample_inputs", True): + dim1_x = Dim("dim1_x", min=2, max=8192) + dynamic_shapes = {"x": {0: Dim.AUTO, 1: dim1_x}, "dim_D": {0: Dim.AUTO}} + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + def test_composed_dynamic_size(self): class Model(torch.nn.Module): def forward(self, x): @@ -5847,6 +6362,35 @@ def forward(self, x): } self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + def test_boolean_indexing(self): + class Model(torch.nn.Module): + def forward(self, x, y, z, x1, z1): + a = x[y] + a1 = x1[y] + b = torch.cat([a, z], dim=1) + b1 = torch.cat([a1, z1], dim=1) + return b, b1 + + x = torch.randn(3, 5, device=self.device) + y = torch.tensor([0, 1, 1], dtype=torch.bool, device=self.device) + z = torch.randn(2, 4, device=self.device) + x1 = torch.randn(3, 5, device=self.device) + z1 = torch.randn(2, 4, device=self.device) + + example_inputs = (x, y, z, x1, z1) + s0 = Dim("s0", min=0, max=10240) + s1 = Dim("s1", min=0, max=10240) + s2 = Dim("s2", min=0, max=10240) + s3 = Dim("s3", min=0, max=10240) + dynamic_shapes = { + "x": {0: s0, 1: s1}, + "y": {0: s0}, + "z": {0: s2, 1: s3}, + "x1": {0: s0, 1: s1}, + "z1": {0: s2, 1: s3}, + } + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + def test_with_cudagraphs(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -5935,6 +6479,39 @@ def forward(self, x): # the output should have int type self.check_model(Model2(), (x,)) + def test_using_model_name_for_files(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x, y): + return x + self.linear(y) + + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + model = Model().to(self.device) + with torch.no_grad(): + package_path: str = AOTIRunnerUtil.compile( + model, + example_inputs, + inductor_configs={ + "aot_inductor.model_name_for_generated_files": "test_model" + }, + ) + + with zipfile.ZipFile(package_path, "r") as zip_ref: + all_files = zip_ref.namelist() + base_dir = "test_model.wrapper/data/aotinductor/model/test_model" + self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files) + self.assertTrue(f"{base_dir}.kernel.cpp" in all_files) + self.assertTrue(f"{base_dir}.wrapper.so" in all_files) + + aot_inductor_module = torch._inductor.aoti_load_package(package_path) + self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs)) + class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) @@ -5986,7 +6563,6 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)), # No fft implementation for XPU yet. "test_fft_c2c": fail_gpu(("xpu",), is_skip=True), - "test_stft": fail_gpu(("xpu",)), } diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py index 2cdac977ba1fcc..16f4e4f20c8ddc 100644 --- a/test/inductor/test_aot_inductor_arrayref.py +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -93,7 +93,7 @@ def fail_minimal_arrayref_interface(is_skip=False): ), # https://github.com/pytorch/pytorch/issues/129550 # https://github.com/pytorch/pytorch/issues/123691 - "test_dynamic_scalar": fail_minimal_arrayref_interface(is_skip=True), + "test_dynamic_scalar": fail_stack_allocation(is_skip=True), # https://github.com/pytorch/pytorch/issues/122980 "test_fft_c2c": fail_stack_allocation(is_skip=True), "test_freezing": fail_minimal_arrayref_interface(is_skip=True), @@ -169,6 +169,29 @@ def fail_minimal_arrayref_interface(is_skip=False): "test_symbool_item": fail_minimal_arrayref_interface(is_skip=True), # TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype' "test_symfloat_item": fail_minimal_arrayref_interface(is_skip=True), + # Causes a segfault when the process exits + "test_view_outputs": fail_stack_allocation(is_skip=True), + "test_pytree_inputs": fail_stack_allocation(is_skip=True), + "test_duplicated_params": fail_stack_allocation(is_skip=True), + "test_output_misaligned": fail_stack_allocation(is_skip=True), + "test_no_args": fail_stack_allocation(is_skip=True), + "test_fqn": fail_stack_allocation(is_skip=True), + "test_assert_tensor_meta": fail_stack_allocation(is_skip=True), + "test_clamp_decomposition": fail_stack_allocation(is_skip=True), + "test_aoti_constant_tensor_name_collision": fail_stack_allocation(is_skip=True), + "test_cond_unbacked_symint_closure_dynamic_False": fail_stack_allocation( + is_skip=True + ), + "test_empty_cat_dtype_promotion": fail_stack_allocation(is_skip=True), + "test_pad_fallback": fail_stack_allocation(is_skip=True), + "test_simple_embed_kernel_binary_False_max_autotune_True": fail_stack_allocation( + is_skip=True + ), + "test_simple_embed_kernel_binary_True_max_autotune_True": fail_stack_allocation( + is_skip=True + ), + # When running test_seq with test_issue_140766, the process segfaults + "test_seq": fail_stack_allocation(is_skip=True), } diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index 2c0c34b8f8b7d8..31de9ac4c71d0e 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -1,5 +1,5 @@ # Owner(s): ["module: inductor"] -# This test requires libaoti_custom_ops.so to be built, which happnes when BUILD_TEST = 1 +# This test requires libaoti_custom_ops.so to be built, which happens when BUILD_TEST = 1 import logging import os import sys @@ -113,14 +113,18 @@ def _(x): class AOTInductorTestsTemplate: def test_custom_op_add(self) -> None: class M(torch.nn.Module): - def forward(self, x, y): - return torch.ops.aoti_custom_ops.custom_add(x, y) + def __init__(self, device): + super().__init__() + self.device = device + self.w = torch.randn(3, 3, device=device) - m = M().to(device=self.device) - args = ( - torch.randn(3, 3, device=self.device), - torch.randn(3, 3, device=self.device), - ) + def forward(self, x): + const = torch.tensor([1], device=self.device) + x = torch.ops.aoti_custom_ops.custom_add(x, const) + return torch.ops.aoti_custom_ops.custom_add(x, self.w) + + m = M(self.device).to(device=self.device) + args = (torch.randn(3, 3, device=self.device),) self.check_model(m, args) def test_custom_op_add_output_path(self) -> None: @@ -149,6 +153,46 @@ def forward(self, x, y): ) self.check_model(m, args) + def test_fn_with_optional_tensor_output_2(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aoti_custom_ops.fn_with_optional_tensor_output_2(x, y) + + m = M().to(device=self.device) + args = ( + torch.randn(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + self.check_model(m, args) + + def test_fn_with_optional_tensor_nullopt_output(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.ops.aoti_custom_ops.fn_with_optional_tensor_nullopt_output( + x, y + ) + + m = M().to(device=self.device) + args = ( + torch.randn(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + self.check_model(m, args) + + def test_fn_with_int_output(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + i = x.shape[0] + z, _, _, i1, i2 = torch.ops.aoti_custom_ops.fn_with_int_output(x, y, i) + return z, z * (i1 + i2 + i) + + m = M().to(device=self.device) + args = ( + torch.randn(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + self.check_model(m, args) + def test_custom_op_all_inputs(self) -> None: class MyModel(torch.nn.Module): # pyre-fixme[3]: Return type must be annotated. @@ -379,25 +423,28 @@ def forward(self, x): m = Model().to(device=self.device) args = (torch.randn(2, 3, device=self.device),) - with config.patch( - "aot_inductor.custom_ops_to_c_shims", - { - torch.ops.aoti_custom_ops.fn_square.default: [ - """ + with ( + config.patch( + "aot_inductor.custom_ops_to_c_shims", + { + torch.ops.aoti_custom_ops.fn_square.default: [ + """ AOTITorchError aoti_torch_cpu_fn_square( AtenTensorHandle input, AtenTensorHandle* ret)""", - """ + """ AOTITorchError aoti_torch_cuda_fn_square( AtenTensorHandle input, AtenTensorHandle* ret)""", - ], - }, - ), config.patch( - "aot_inductor.custom_op_libs", - ["aoti_custom_ops"], + ], + }, + ), + config.patch( + "aot_inductor.custom_op_libs", + ["aoti_custom_ops"], + ), ): self.check_model(m, args) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index a9f8a29daf8928..daf2e340efa2a1 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -18,8 +18,9 @@ from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import AOTICompiledModel, load_package, package_aoti from torch._inductor.test_case import TestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.export import Dim +from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents from torch.testing._internal.common_utils import ( IS_FBCODE, skipIfRocm, @@ -86,7 +87,9 @@ def compile( if sys.platform != "darwin" else [] ), - class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}", ) class TestAOTInductorPackage(TestCase): def check_model( @@ -136,8 +139,8 @@ def forward(self, x, y): def test_remove_intermediate_files(self): # For CUDA, generated cpp files contain absolute path to the generated cubin files. - # With the package artifact, that cubin path should be overriden at the run time, - # so removing those intermeidate files in this test to verify that. + # With the package artifact, that cubin path should be overridden at the run time, + # so removing those intermediate files in this test to verify that. class Model(torch.nn.Module): def forward(self, x, y): return x + y @@ -157,7 +160,7 @@ def forward(self, x, y): torch.manual_seed(0) with tempfile.NamedTemporaryFile(suffix=".pt2") as f: ep = torch.export.export(model, example_inputs, strict=True) - with fresh_inductor_cache(): + with fresh_cache(): # cubin files are removed when exiting this context package_path = torch._inductor.aoti_compile_and_package( ep, @@ -218,9 +221,10 @@ def forward(self, x, y): package_path = torch._inductor.aoti_compile_and_package( ep, inductor_configs=options ) - with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile( - package_path, "r" - ) as zip_ref: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + zipfile.ZipFile(package_path, "r") as zip_ref, + ): filenames = zip_ref.namelist() prefix = filenames[0].split("/")[0] zip_ref.extractall(tmp_dir) @@ -284,7 +288,7 @@ def forward(self, x, y): options = { "aot_inductor.package_cpp_only": self.package_cpp_only, - # Expect kernel to be embeded in the final binary. + # Expect kernel to be embedded in the final binary. # We will make it the default behavior for the standalone mode. "aot_inductor.emit_multi_arch_kernel": True, "aot_inductor.embed_kernel_binary": True, @@ -293,9 +297,10 @@ def forward(self, x, y): package_path = torch._inductor.aoti_compile_and_package( ep, inductor_configs=options ) - with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile( - package_path, "r" - ) as zip_ref: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + zipfile.ZipFile(package_path, "r") as zip_ref, + ): filenames = zip_ref.namelist() prefix = filenames[0].split("/")[0] zip_ref.extractall(tmp_dir) @@ -517,8 +522,8 @@ def forward(self, a, b): ) @skipif( - lambda device, package_cpp_only: device == "cpu" or package_cpp_only, - "No support for cpp only and cpu", + lambda device, package_cpp_only: package_cpp_only, + "No support for cpp only", ) def test_package_without_weight(self): class Model(torch.nn.Module): @@ -551,8 +556,8 @@ def forward(self, a): self.assertEqual(expected, output) @skipif( - lambda device, package_cpp_only: device == "cpu" or package_cpp_only, - "No support for cpp only and cpu", + lambda device, package_cpp_only: package_cpp_only, + "No support for cpp only", ) def test_package_user_managed_weight(self): class Model(torch.nn.Module): @@ -630,8 +635,8 @@ def forward(self, x, y): self.assertEqual(expected, output_copy) @skipif( - lambda device, package_cpp_only: device == "cpu" or package_cpp_only, - "No support for cpp only and cpu", + lambda device, package_cpp_only: package_cpp_only, + "No support for cpp only", ) def test_update_weights(self): class Model(torch.nn.Module): @@ -661,6 +666,140 @@ def forward(self, a): output = compiled(test_inputs) self.assertEqual(expected, output) + @skipif( + lambda device, package_cpp_only: package_cpp_only, + "No support for cpp only", + ) + def test_package_shared_weights(self): + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": self.package_cpp_only, + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": False, + "aot_inductor.package_constants_on_disk": True, + } + + class Bar(torch.nn.Module): + def __init__(self, p1, p2): + super().__init__() + self.p1 = p1 + self.register_buffer("p2", p2) + + def forward(self): + self.p1 += 1 + self.p2 += 1 + return self.p1, self.p2 + + class Bar2(torch.nn.Module): + def __init__(self, p1, p2): + super().__init__() + self.p1 = p1 + self.register_buffer("p2", p2[2:3]) + + def forward(self): + self.p1 += 3 + self.p2 += 3 + return self.p1, self.p2 + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + buffer = torch.nn.Buffer(x.clone()) + buffer2 = torch.nn.Buffer(y.clone()) + bar1 = Bar(buffer, buffer2) + bar2 = Bar2(buffer, buffer2) + ep1 = torch.export.export(bar1, ()) + ep2 = torch.export.export(bar2, ()) + aoti_files1 = torch._inductor.aot_compile(ep1.module(), (), options=options) + aoti_files2 = torch._inductor.aot_compile(ep2.module(), (), options=options) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + package_path = package_aoti( + f.name, + {"model1": aoti_files1, "model2": aoti_files2}, + ) + pt2_contents = load_pt2(package_path, load_weights_from_disk=True) + loaded1 = pt2_contents.aoti_runners["model1"] + loaded2 = pt2_contents.aoti_runners["model2"] + + # note that loading like below doesn't work, because new weights will be loaded + # for each load_package call. + # loaded1 = load_package(package_path, "model1") + # loaded2 = load_package(package_path, "model2") + + result_1_p1, result_1_p2 = loaded1() + self.assertEqual(result_1_p1, x + 1) + self.assertEqual(result_1_p2, y + 1) + + result_2_p1, result_2_p2 = loaded2() + # the result already incremented by 1 from the run above + self.assertEqual(result_2_p1, x + 4) + self.assertEqual(result_2_p2, y[2:3] + 4) + + # note that the returned result will not change though p2 changed + self.assertEqual(result_1_p2, y + 1) + + # test shared weights but user managed + gm1 = ep1.module() + gm2 = ep2.module() + load_weights_to_pt2_contents( + pt2_contents, {"model1": gm1.state_dict(), "model2": gm2.state_dict()} + ) + result_1_p1, result_1_p2 = loaded1() + self.assertEqual(result_1_p1, x + 1) + self.assertEqual(result_1_p2, y + 1) + self.assertEqual(gm1.p1, x + 1) + self.assertEqual(gm1.p2, y + 1) + + @skipif( + lambda device, package_cpp_only: package_cpp_only, + "No support for cpp only", + ) + def test_package_weights_on_disk_nested_module(self): + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": self.package_cpp_only, + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": False, + "aot_inductor.package_constants_on_disk": True, + } + + # linear.weight's node name is linear_weight. + # This unit test tests that we package the right weight name + # `liear.weight`, but not `linear_weight` + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + + x = torch.randn(3, 3).to(self.device) + bar1 = Bar().to(self.device) + ep = torch.export.export(bar1, (x,)) + package_path = torch._inductor.aoti_compile_and_package( + ep, inductor_configs=options + ) + pt2_contents = load_pt2(package_path, load_weights_from_disk=True) + loaded1 = pt2_contents.aoti_runners["model"] + self.assertEqual(loaded1(x), bar1(x)) + + def test_loading_wrong_model(self): + class Model(torch.nn.Module): + def forward(self, x): + return x + 1 + + example_inputs = (torch.randn(10, 10, device=self.device),) + model = Model() + ep = torch.export.export(model, example_inputs) + package_path = torch._inductor.aoti_compile_and_package(ep) + + with self.assertRaisesRegex( + RuntimeError, + "Failed to find a generated cpp file or so file for model 'forward' in the zip archive.", + ): + load_package(package_path, model_name="forward") + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 6868928957a210..9d25aa47560183 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -5,7 +5,7 @@ import shutil import tempfile import types -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch._export @@ -16,11 +16,15 @@ from torch._inductor import config from torch._inductor.test_case import TestCase from torch.testing import FileCheck -from torch.testing._internal.common_utils import IS_FBCODE +from torch.testing._internal.common_utils import IS_FBCODE, run_tests from torch.testing._internal.inductor_utils import clone_preserve_strides_offset from torch.utils import _pytree as pytree +if TYPE_CHECKING: + from torch._C._aoti import AOTIModelContainerRunner + + class WrapperModule(torch.nn.Module): def __init__(self, model): super().__init__() @@ -73,7 +77,7 @@ def legacy_compile( return so_path @staticmethod - def legacy_load_runner(device, so_path): + def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner": if IS_FBCODE: from .fb import test_aot_inductor_model_runner_pybind # @manual @@ -205,17 +209,20 @@ def check_model( atol=None, rtol=None, ): - with torch.no_grad(), config.patch( - { - "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, - "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, - } + with ( + torch.no_grad(), + config.patch( + { + "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, + "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, + } + ), ): torch.manual_seed(0) if not isinstance(model, types.FunctionType): model = model.to(self.device) - # For non mixed device inputs with default "cpu",set the device manully. + # For non mixed device inputs with default "cpu",set the device manually. if all( t.device.type == "cpu" for t in example_inputs @@ -248,11 +255,14 @@ def check_model_with_multiple_inputs( options=None, dynamic_shapes=None, ): - with torch.no_grad(), config.patch( - { - "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, - "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, - } + with ( + torch.no_grad(), + config.patch( + { + "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, + "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, + } + ), ): torch.manual_seed(0) model = model.to(self.device) @@ -275,11 +285,14 @@ def code_check_count( target_str: str, target_count: int, ): - with torch.no_grad(), config.patch( - { - "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, - "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, - } + with ( + torch.no_grad(), + config.patch( + { + "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, + "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, + } + ), ): package_path = torch._export.aot_compile(model, example_inputs) @@ -290,3 +303,7 @@ def code_check_count( target_count, exactly=True, ).run(src_code) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py index d05fa4748667ef..67d9c1493070f8 100644 --- a/test/inductor/test_async_compile.py +++ b/test/inductor/test_async_compile.py @@ -3,7 +3,7 @@ from torch._inductor import config from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -32,7 +32,7 @@ def fn(x, y): pool = AsyncCompile.process_pool() pool.ready_future.result(timeout=120) - with fresh_inductor_cache(): + with fresh_cache(): compiled_fn = torch.compile(fn) self.assertEqual(fn(x, y), compiled_fn(x, y)) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index de0405bb20e1f5..a0453fab40e597 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -255,9 +255,10 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 def test_auto_functionalize_on_view(self): for value in [True, False]: - with torch.library._scoped_library( - "mylib", "FRAGMENT" - ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}): + with ( + torch.library._scoped_library("mylib", "FRAGMENT") as lib, + inductor_config.patch({"enable_auto_functionalized_v2": value}), + ): torch.library.define( "mylib::foo", "(Tensor(a!) x) -> ()", @@ -1398,7 +1399,7 @@ def test_round_trip(base, tensor): test_round_trip(t, f[1]) test_round_trip(t, f[2]) - # example where slice wont work + # example where slice won't work # selection t = torch.ones(10) @@ -1560,7 +1561,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): def test_alias2_dynamic(self): self.test_alias2(_dynamic=True) - # Test that the view regenration optimizations do not result in recompilations. By comparing re-compilation in eager backend + # Test that the view regeneration optimizations do not result in recompilations. By comparing re-compilation in eager backend # with recompilation in inductor backend. @torch.fx.experimental._config.patch(use_duck_shape=False) def test_recompile(self): diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 86cf79e496caa3..b3afba7d6843fa 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -7,7 +7,7 @@ from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.test_operators import realize -from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code +from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import slowTest from torch.testing._internal.inductor_utils import ( @@ -96,9 +96,12 @@ def new_benchmark_fn(scheduler, nodes): # Disable dynamic_scale_rblock to make it easier to trigger register # spilling. - with unittest.mock.patch.object( - Scheduler, "benchmark_fused_nodes", new_benchmark_fn - ), config.patch("dynamic_scale_rblock", False): + with ( + unittest.mock.patch.object( + Scheduler, "benchmark_fused_nodes", new_benchmark_fn + ), + config.patch("dynamic_scale_rblock", False), + ): S = 512 def f(*inputs): @@ -170,9 +173,10 @@ def foo(m, inp): ".run", 2, exactly=True ).run(out_code[0]) - with config.patch( - {"benchmark_fusion": False, "epilogue_fusion": False} - ), torch.no_grad(): + with ( + config.patch({"benchmark_fusion": False, "epilogue_fusion": False}), + torch.no_grad(), + ): torch._dynamo.reset() foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) @@ -283,7 +287,7 @@ def foo(m, inp): self.assertEqual(res, res2, atol=1e-4, rtol=1.1) return code, code2 - @fresh_inductor_cache() + @fresh_cache() @config.patch(max_autotune_gemm_backends="TRITON") def test_equivalent_template_code(self): code, code2 = self._equivalent_output_code_impl(256) @@ -292,13 +296,9 @@ def test_equivalent_template_code(self): "empty_strided", 1, exactly=True ).check("triton_tem_fused_addmm_relu_0").check_count( ".reset()" if config.cpp_wrapper else "del", 3, exactly=True - ).check( - "" if config.cpp_wrapper else "return" - ).run( - out_code[0] - ) + ).check("" if config.cpp_wrapper else "return").run(out_code[0]) - @fresh_inductor_cache() + @fresh_cache() @config.patch(max_autotune_gemm_backends="ATEN") def test_equivalent_extern_code(self): torch._dynamo.reset() @@ -310,11 +310,7 @@ def test_equivalent_extern_code(self): "empty_strided", 1, exactly=True ).check("" if config.cpp_wrapper else "extern_kernels.").check_count( ".reset()" if config.cpp_wrapper else "del", 3, exactly=True - ).check( - "" if config.cpp_wrapper else "return" - ).run( - out_code[0] - ) + ).check("" if config.cpp_wrapper else "return").run(out_code[0]) def test_changed_layout(self): # cat addmm planning will change layout - make sure propagated diff --git a/test/inductor/test_best_config.py b/test/inductor/test_best_config.py new file mode 100644 index 00000000000000..7a2ce535a406bc --- /dev/null +++ b/test/inductor/test_best_config.py @@ -0,0 +1,96 @@ +# Owner(s): ["module: inductor"] + +import glob +import json +import os +import sys +import tempfile +import unittest + +import torch +from torch._inductor import config +from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +try: + import triton # noqa: F401 +except ImportError as e: + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires triton") from e + +from torch._inductor.test_case import run_tests, TestCase + + +def trivial_kernel(x): + return torch.sin(x) + torch.cos(x) + + +class TestKernelBestConfig(TestCase): + device_type = GPU_TYPE + + @classmethod + def setUpClass(cls): + # Save the original configuration and environment variables. + cls.original_compile_threads = config.compile_threads + cls.original_max_autotune = config.max_autotune + cls.original_inductor_env = os.environ.get("TORCHINDUCTOR_CACHE_DIR", "") + cls.original_triton_env = os.environ.get("TRITON_CACHE_DIR", "") + super().setUpClass() + + @classmethod + def tearDownClass(cls): + # Restore the original configuration and environment variables. + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cls.original_inductor_env + os.environ["TRITON_CACHE_DIR"] = cls.original_triton_env + config.compile_threads = cls.original_compile_threads + config.max_autotune = cls.original_max_autotune + super().tearDownClass() + + @skipIfXpu + def test_best_config_has_triton_cache_key(self): + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = tmpdir + triton_cache_dir = os.path.join(tmpdir, "triton_cache") + os.environ["TRITON_CACHE_DIR"] = triton_cache_dir + + config.compile_threads = 0 + config.max_autotune = True + + compiled_fn = torch.compile(trivial_kernel) + + x = torch.randn(32, 10, device=GPU_TYPE) + compiled_fn(x) + + # Search for .best_config files in the inductor cache directory. + best_config_files = glob.glob( + os.path.join(tmpdir, "**", "*.best_config"), recursive=True + ) + self.assertGreater( + len(best_config_files), + 0, + f"No best_config files found in {tmpdir}. Directory contents: {os.listdir(tmpdir)}", + ) + + # Validate that each best_config file contains a real triton_cache_hash, + # and that a corresponding Triton cache directory exists. + for file_path in best_config_files: + with open(file_path) as f: + data = json.load(f) + self.assertIn( + "triton_cache_hash", + data, + f"Missing triton_cache_hash in {os.path.basename(file_path)}", + ) + cache_hash = data["triton_cache_hash"] + expected_path = os.path.join(triton_cache_dir, cache_hash) + self.assertTrue( + os.path.exists(expected_path), + f"Triton cache directory missing: {expected_path}", + ) + + +if __name__ == "__main__": + if IS_LINUX and HAS_GPU: + run_tests() diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 539c598b2235f4..2652c94264d6ac 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -66,13 +66,13 @@ def setUp(self): os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1" super().setUp() finally: - os.environ[ - "INDUCTOR_TEST_DISABLE_FRESH_CACHE" - ] = old_disable_fresh_cache_envvar + os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = ( + old_disable_fresh_cache_envvar + ) @unittest.skipIf(not torch.version.hip, "ROCM only") @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) + @parametrize("max_autotune_gemm_backends", ("CK", "CKTILE", "ATen,Triton,CK")) @parametrize("autotune_in_subproc", (True, False)) @parametrize("use_aoti", (True, False)) def test_max_autotune_precompile_matmul( diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 4a042b6b41ae2b..5466a26ffdb9a0 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import functools +import logging import os import pickle import shutil @@ -7,6 +8,7 @@ import sys import tempfile import unittest +from contextlib import contextmanager from typing import Optional, Union from typing_extensions import override from unittest import mock @@ -27,12 +29,16 @@ TensorMetadata, TensorMetadataAndValues, ) -from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, + get_hash_for_files, +) from torch._inductor.graph import GraphLowering from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache +from torch._inductor.utils import clear_caches, fresh_cache from torch._library import capture_triton from torch.compiler._cache import ( CacheArtifact, @@ -53,10 +59,10 @@ HAS_GPU, HAS_MULTIGPU, HAS_TRITON, + patch_inductor_backend, requires_gpu, requires_triton, ) -from torch.testing._internal.logging_utils import multiple_logs_to_string from torch.testing._internal.triton_utils import requires_cuda @@ -69,6 +75,36 @@ torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True +class LogCaptureHandler(logging.Handler): + def __init__(self, level): + super().__init__(level) + self.records = [] + + def emit(self, record): + self.records.append(record) + + +@contextmanager +def capture_logs(log_name, log_level): + try: + logger = logging.getLogger(log_name) + old_level = logger.level + handler = logging.Handler() + logger.setLevel(log_level) + log_records = [] + + def emit(record): + log_records.append(record) + + handler.emit = emit + logger.addHandler(handler) + + yield log_records + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + class MyModelConv2d(torch.nn.Module): def __init__(self, dim=512): super().__init__() @@ -82,6 +118,16 @@ def forward(self, x): return x +class TestPyCodeCache(TestCase): + def test_linemaps_empty(self): + src = """import torch""" + (key, path) = PyCodeCache.write(src, "") + # Load with an empty linemap + PyCodeCache.load_by_key_path(key, path, linemap=[]) + stack_frames = PyCodeCache.stack_frames_for_code(path, 0) + self.assertEqual(stack_frames, None) + + @instantiate_parametrized_tests class TestFxGraphCache(TestCase): device_type = GPU_TYPE @@ -100,7 +146,7 @@ def reset(self): AOTAutogradCache.clear() PyCodeCache.cache_clear(purge=True) torch._dynamo.reset() - clear_inductor_caches() + clear_caches() @requires_triton() @config.patch({"fx_graph_cache": True}) @@ -163,13 +209,25 @@ def fn(x, y): ) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) - # "cuda" has .ptx and .cubin file, but xpu only has .spv file - save_kernel_count = 6 if device == "xpu" else 7 - read_and_emit_kernel_count = 6 if device == "xpu" else 7 + + # we expect: + # .ttir + # .ttgir + # .llir + # .ptx (cuda) or .spv (xpu) + # .json + # __grp__.*.json + # optionally, we can also get + # .cubin (CUDA only) + # .source (new versions of triton only, triton-lang/triton#6992) + + # to avoid depending on the device and triton version, just assert that + # we have at least 6 kernels. + save_and_read_min_artifact_count = 6 if bundle_triton and device != "cpu": - self.assertEqual( + self.assertGreaterEqual( counters["inductor"]["triton_bundler_save_kernel"], - grad_multiplier * save_kernel_count, + grad_multiplier * save_and_read_min_artifact_count, ) self.assertEqual( counters["inductor"]["triton_bundler_read_and_emit_kernel"], 0 @@ -214,13 +272,13 @@ def fn(x, y): ) if bundle_triton and device != "cpu": - self.assertEqual( + self.assertGreaterEqual( counters["inductor"]["triton_bundler_save_kernel"], - grad_multiplier * save_kernel_count, + grad_multiplier * save_and_read_min_artifact_count, ) - self.assertEqual( + self.assertGreaterEqual( counters["inductor"]["triton_bundler_read_and_emit_kernel"], - grad_multiplier * read_and_emit_kernel_count, + grad_multiplier * save_and_read_min_artifact_count, ) if use_static_cuda_launcher: self.assertEqual( @@ -262,13 +320,13 @@ def fn(x, y): ) if bundle_triton and device != "cpu": - self.assertEqual( + self.assertGreaterEqual( counters["inductor"]["triton_bundler_save_kernel"], - grad_multiplier * save_kernel_count * 2, + grad_multiplier * save_and_read_min_artifact_count * 2, ) - self.assertEqual( + self.assertGreaterEqual( counters["inductor"]["triton_bundler_read_and_emit_kernel"], - grad_multiplier * read_and_emit_kernel_count, + grad_multiplier * save_and_read_min_artifact_count, ) if use_static_cuda_launcher: self.assertEqual( @@ -312,25 +370,27 @@ def fn(x, y): a = torch.rand(25, dtype=dtype, device=device) b = torch.rand(5, 5, dtype=dtype, device=device) - with config.patch( - { - "fx_graph_remote_cache": True, - "bundle_triton_into_fx_graph_cache": bundle_triton, - "use_static_cuda_launcher": use_static_cuda_launcher, - } - ), patch.dict(os.environ), PatchCaches(): + with ( + config.patch( + { + "fx_graph_remote_cache": True, + "bundle_triton_into_fx_graph_cache": bundle_triton, + "use_static_cuda_launcher": use_static_cuda_launcher, + } + ), + patch.dict(os.environ), + PatchCaches(), + ): os.environ.pop("TRITON_CACHE_MANAGER", None) for _ in range(4): - with fresh_inductor_cache(): + with fresh_cache(): compiled_fn = torch.compile(fn, dynamic=dynamic) self.assertEqual(fn(a, b), compiled_fn(a, b)) reset() self.assertEqual(global_stats.fx_graph, Stats(1, 3, 1)) - with torch.compiler.config.patch( - {"cache_key_tag": "test"} - ), fresh_inductor_cache(): + with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache(): compiled_fn = torch.compile(fn, dynamic=dynamic) self.assertEqual(fn(a, b), compiled_fn(a, b)) @@ -368,7 +428,7 @@ def fn(x, y): b = torch.rand(100, 100, dtype=dtype, device=device) # Record artifacts - with fresh_inductor_cache(): + with fresh_cache(): compiled_fn = torch.compile(fn, dynamic=dynamic) # A first call should miss in the cache. @@ -398,7 +458,7 @@ def fn(x, y): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # We did not load anything so dont hit yet - with fresh_inductor_cache(): + with fresh_cache(): eager_result = fn(a, b) compiled_result = compiled_fn(a, b) self.assertEqual(eager_result, compiled_result) @@ -412,7 +472,7 @@ def fn(x, y): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # Hot load and hit - with fresh_inductor_cache(): + with fresh_cache(): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) self.assertEqual(len(cache_info.inductor_artifacts), 1) @@ -445,7 +505,7 @@ def fn(x, y): a2 = torch.randn(4, 8) b2 = torch.randn(8, 4) - with fresh_inductor_cache(): + with fresh_cache(): eager_result = fn(a, b) compiled_result = compiled_fn(a, b) self.assertEqual(eager_result, compiled_result) @@ -461,7 +521,7 @@ def fn(x, y): self.reset() - with fresh_inductor_cache(): + with fresh_cache(): torch.compiler.load_cache_artifacts(artifact_bytes) eager_result = fn(a, b) compiled_result = compiled_fn(a, b) @@ -473,7 +533,7 @@ def fn(x, y): self.reset() - with fresh_inductor_cache(): + with fresh_cache(): eager_result = fn(a2, b2) compiled_result = compiled_fn(a2, b2) self.assertEqual(eager_result, compiled_result) @@ -497,7 +557,7 @@ def f(x): return x * 2 # Record artifacts - with torch.compiler.config.patch(job_id=self.id()), fresh_inductor_cache(): + with torch.compiler.config.patch(job_id=self.id()), fresh_cache(): f(torch.randn(2, 3)) f(torch.randn(2, 4)) self.assertEqual(backend.frame_count, 2) @@ -524,7 +584,7 @@ def f(x): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # Hot load and hit - with torch.compiler.config.patch({"job_id": self.id()}), fresh_inductor_cache(): + with torch.compiler.config.patch({"job_id": self.id()}), fresh_cache(): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) self.assertEqual(len(cache_info.inductor_artifacts), 2) @@ -559,7 +619,7 @@ def f(x): with mock.patch( "torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5) ): - with fresh_inductor_cache(): + with fresh_cache(): f(torch.randn(2, 3)) f(torch.randn(2, 4)) self.assertEqual(backend.frame_count, 2) @@ -579,9 +639,13 @@ def f(x): shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True) # Hot load and hit - with mock.patch( - "torch._utils_internal.get_mast_job_name_version", return_value=("bar", 10) - ), fresh_inductor_cache(): + with ( + mock.patch( + "torch._utils_internal.get_mast_job_name_version", + return_value=("bar", 10), + ), + fresh_cache(), + ): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) self.assertEqual(len(cache_info.pgo_artifacts), 2) @@ -1524,7 +1588,7 @@ def reset(self): AOTAutogradCache.clear() PyCodeCache.cache_clear(purge=True) torch._dynamo.reset() - clear_inductor_caches() + clear_caches() def capture(self, fn, dynamic=None): def inner(*args): @@ -1554,7 +1618,10 @@ def backend(gm_, args_, **kwargs_): @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("format", ("binary", "unpacked")) @parametrize("dynamic", (False, True)) - def test_basic(self, device: str, format: str, dynamic: bool) -> None: + @parametrize("graph_partition", (False, True)) + def test_basic( + self, device: str, format: str, dynamic: bool, graph_partition: bool + ) -> None: if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") @@ -1569,13 +1636,16 @@ def f(x): eager_out = f(x) - with tempfile.TemporaryDirectory() as temp_dir: + with ( + tempfile.TemporaryDirectory() as temp_dir, + config.patch(graph_partition=graph_partition), + ): path = ( temp_dir if format == "unpacked" else os.path.join(temp_dir, "compiled_artifact.bin") ) - with fresh_inductor_cache(): + with fresh_cache(): gm, args, kwargs = self.capture(f)(x) assert not kwargs @@ -1584,7 +1654,7 @@ def f(x): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - with fresh_inductor_cache(): + with fresh_cache(): loaded = torch._inductor.CompiledArtifact.load(path=path, format=format) if dynamic: concrete_args = [ @@ -1616,7 +1686,7 @@ def f(x): def backend(gm, args, **kwargs): return torch._inductor.standalone_compile(gm, args) - with fresh_inductor_cache(): + with fresh_cache(): compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) self.assertEqual(eager_out, compiled_out) @@ -1635,7 +1705,7 @@ def f(x): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "new_dir") - with fresh_inductor_cache(): + with fresh_cache(): gm, args, kwargs = self.capture(f)(x) assert not kwargs @@ -1644,7 +1714,7 @@ def f(x): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - with fresh_inductor_cache(): + with fresh_cache(): loaded = torch._inductor.CompiledArtifact.load( path=path, format="unpacked" ) @@ -1668,7 +1738,7 @@ def f(x): eager_out = f(x) with tempfile.TemporaryDirectory() as temp_dir: - with fresh_inductor_cache(): + with fresh_cache(): gm, args, kwargs = self.capture(f)(x) assert not kwargs @@ -1680,7 +1750,7 @@ def f(x): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) - with fresh_inductor_cache(): + with fresh_cache(): # Now modify the output file and expect to see the changes for subdir in os.listdir(temp_dir): if subdir in ["aotautograd", "fxgraph"]: @@ -1728,16 +1798,16 @@ def f(x): with tempfile.TemporaryDirectory() as temp_dir: path = os.path.join(temp_dir, "compiled_artifact.bin") - with fresh_inductor_cache(): + with fresh_cache(): compiled_artifact = torch._inductor.standalone_compile(gm, args) compiled_artifact.save(path=path) script = f""" import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache arg = torch.ones(4, 1) -with fresh_inductor_cache(): +with fresh_cache(): loaded = torch._inductor.CompiledArtifact.load(path="{path}") compiled_result = loaded(arg)[0] @@ -1769,7 +1839,7 @@ def f(x): x = torch.ones(3) torch._dynamo.mark_dynamic(x, 0) - with fresh_inductor_cache(): + with fresh_cache(): # captured graph is lambda s0, x: x * s0 gm, args, kwargs = self.capture(f)(x) assert not kwargs @@ -1791,7 +1861,7 @@ def f(x): x = torch.ones(3) torch._dynamo.mark_dynamic(x, 0) - with fresh_inductor_cache(): + with fresh_cache(): # captured graph is lambda s0, x: x * s0 gm, args, kwargs = self.capture(f)(x) assert not kwargs @@ -1827,7 +1897,7 @@ def f(x): return x.shape[0] * x static_x = torch.randn(3) - with fresh_inductor_cache(): + with fresh_cache(): # static_gm is lambda x: x * 3 static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x) assert not kwargs @@ -2186,11 +2256,9 @@ def forward(self, x): torch._dynamo.reset() counters.clear() - (codecache_stream,), ctx = multiple_logs_to_string( - "torch._inductor.codecache", "codecache" - ) - with ctx(), config.patch( - {"_fuse_ddp_communication_passes": [lambda *args: None]} + with ( + capture_logs("torch._inductor.codecache", logging.INFO) as logs, + config.patch({"_fuse_ddp_communication_passes": [lambda *args: None]}), ): # bypass (custom pass is not serializable) mod_compiled(x) @@ -2199,10 +2267,12 @@ def forward(self, x): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) counters.clear() # assert that our bypass is explicit - codecache_logs = codecache_stream.getvalue().strip() self.assertTrue( - "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'" - in codecache_logs + any( + x.getMessage() + == "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'" + for x in logs + ) ) def test_hash_custom_passes(self): @@ -2241,6 +2311,42 @@ def uuid(self) -> Optional[Union[bytes, str]]: pickler.dumps(details3), ) + def test_hash_custom_backend_pass(self): + """ + Test CustomGraphModulePass usage. + """ + + class TestCustomGraphModulePass(CustomGraphModulePass): + def __init__(self): + self._uuid = None + + def __call__(self, gm: torch.fx.GraphModule) -> None: + return None + + def uuid(self) -> Optional[Union[bytes, str]]: + return self._uuid + + custom_pass = TestCustomGraphModulePass() + with patch_inductor_backend("cpu", custom_pass=custom_pass): + custom_pass._uuid = "1" + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) + + custom_pass._uuid = "2" + details3 = FxGraphHashDetails(None, [], {}, []) + + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + pickler = FxGraphCachePickler(gm) + + self.assertEqual( + pickler.dumps(details1), + pickler.dumps(details2), + ) + self.assertNotEqual( + pickler.dumps(details1), + pickler.dumps(details3), + ) + def test_bypass_unsupported(self): """ Test _reduce_unsupported @@ -2257,7 +2363,7 @@ def test_stable_strings(self): even if they are not the same id. """ s1 = "string" - s2 = "strin" + s2 = "strin" # codespell:ignore s2 += "g" self.assertNotEqual(id(s1), id(s2)) @@ -2340,7 +2446,7 @@ def tearDown(self): def reset(self): PyCodeCache.cache_clear(purge=True) torch._dynamo.reset() - clear_inductor_caches() + clear_caches() @unittest.skipIf(not HAS_CUDA, "Requires CUDA") @unittest.skipIf(not SM80OrLater, "Requires SM80+") @@ -2650,20 +2756,20 @@ def fn(a, b): class TestUtils(TestCase): @config.patch({"fx_graph_remote_cache": False}) - def test_fresh_inductor_cache(self): + def test_fresh_cache(self): def fn(x, y): return x + y a = torch.rand(10) b = torch.rand(10) - with fresh_inductor_cache(): + with fresh_cache(): self.assertEqual(len(PyCodeCache.modules), 0) res1 = torch.compile(fn)(a, b) cache_dir1 = cache_dir() torch._dynamo.reset() - with fresh_inductor_cache(): + with fresh_cache(): self.assertEqual(len(PyCodeCache.modules), 0) res2 = torch.compile(fn)(a, b) cache_dir2 = cache_dir() diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index bccdacab2a679a..a054464bf6689f 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -475,6 +475,25 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + def test_dynamic_shapes_persistent_reduction_no_x_dim_2(self): + def fn(x, y): + return x.sum(2), y.sum(2) + + inps = ( + torch.rand(8, 16, 256, device="cuda"), + torch.rand(8, 32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], (0, 1), min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], (0, 1), min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + @requires_cuda @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 9ec1432c8369ed..6eba88ecae970d 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -9,14 +9,24 @@ import os import sys import time +import unittest +from unittest import mock from unittest.mock import patch import torch import torch.library from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode +from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import TEST_WITH_ASAN -from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + IS_BIG_GPU, + requires_gpu, + requires_triton, + RUN_CPU, + RUN_GPU, +) # Make the helper files in test/ importable @@ -42,6 +52,9 @@ "test_remove_noop_slice_scatter": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_default": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_dtype": TestFailure(("xpu"), is_skip=True), + # TODO:remove test_upsample_bicubic2d after the following issue resolved: + # https://github.com/intel/intel-xpu-backend-for-triton/issues/4184 + "test_upsample_bicubic2d": TestFailure(("xpu"), is_skip=False), } @@ -72,6 +85,88 @@ def tearDown(self): TestCase.tearDown(self) torch._dynamo.reset() + @requires_gpu() + @requires_triton() + @unittest.skipIf( + not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" + ) + def test_progressive(self): + from triton.testing import do_bench + + from torch._inductor.compile_fx_async import _ProgressiveFxCompile + + torch._inductor.compile_fx.fx_compile_progressive = True + + x = torch.randn(1152, 1024, device=GPU_TYPE, dtype=torch.bfloat16) + y = torch.randn(1024, 1024, device=GPU_TYPE, dtype=torch.bfloat16) + + @torch.compile(fullgraph=True, backend="inductor") + def optimized(x, y): + return (x @ y).relu() + + _ProgressiveFxCompile._reset_stats() + + source_codes: list[str] = [] + + def save_output_code(code: str) -> None: + source_codes.append(code) + + with contextlib.ExitStack() as stack: + # When this bug is fixed, remove the cache disabling below + assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC + stack.enter_context( + torch._inductor.config.patch( + autotune_local_cache=False, fx_graph_cache=False + ) + ) + stack.enter_context( + mock.patch.object(GraphLowering, "save_output_code", save_output_code) + ) + stack.enter_context( + torch._functorch.config.patch(enable_autograd_cache=False) + ) + + # How long to wait (in seconds) before giving up. + TIMEOUT = 300 + # If non-None then how often (in seconds) to print a TICK message. + TICK_REPORT = None + + start = time.time() + last_report = start + while _ProgressiveFxCompile._stat_optimized_runs < 4: + time.sleep(0.25) + + optimized(x, y) + + now = time.time() + if TICK_REPORT is not None and (now - last_report > TICK_REPORT): + print(f"*** TICK {int(now - start)}") + last_report = now + + if now - start > TIMEOUT: + raise RuntimeError( + "Test timed out before producing a progressively optimized compiled artifact." + ) + + self.assertEqual(_ProgressiveFxCompile._stat_optimized_runs, 4) + self.assertGreater(_ProgressiveFxCompile._stat_fast_runs, 0) + self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_started, 1) + self.assertGreaterEqual(_ProgressiveFxCompile._stat_bg_finished, 1) + + torch._inductor.compile_fx.fx_compile_progressive = False + + @torch.compile(fullgraph=True, backend="inductor") + def baseline(x, y): + return (x @ y).relu() + + # Warmup + baseline(x, y) + + self.assertGreater( + do_bench(lambda: baseline(x, y)), do_bench(lambda: optimized(x, y)) + ) + self.assertTrue("'max_autotune': True" in source_codes[-1]) + @patch("torch._inductor.compile_fx.fx_compile_async", True) def test_async(self): # Test that async+subprocess works. @@ -87,7 +182,7 @@ def model_add(x, y): _AsyncFxCompile._reset_stats() with contextlib.ExitStack() as stack: - # TODO: Turn off local caches - they don't play nice w/ async currently. + assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC stack.enter_context( torch._inductor.config.patch( autotune_local_cache=False, fx_graph_cache=False @@ -108,9 +203,14 @@ def model_add(x, y): # Sleep a bit so we don't drive the CPU unnecessarily. time.sleep(0.25) - x = torch.randn(100, 100) - y = torch.randn(100, 100) - model_add(x, y) + x = torch.randn(100, 100, requires_grad=True) + y = torch.randn(100, 100, requires_grad=True) + + # Forward pass + output = model_add(x, y) + + # Backward pass + output.sum().backward() # DEBUGGING: Print a periodic message so we know we're still # running... diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 9d4c970323dab0..c0d304290b0569 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -32,6 +32,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch.nn.attention.flex_attention import flex_attention from torch.nn.parallel import DistributedDataParallel as DDP +from torch.overrides import BaseTorchFunctionMode from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, @@ -46,6 +47,7 @@ from torch.testing._internal.hop_db import hop_db from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string +from torch.utils._python_dispatch import TorchDispatchMode # note: these tests are not run on windows due to inductor_utils.HAS_CPU @@ -99,14 +101,28 @@ def reset(): torch._logging.set_logs(compiled_autograd_verbose=False) config.compiled_autograd = False compiled_autograd.reset() + torch._dynamo.utils.counters.clear() + + +class BaseCustomOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x * 2 + + @staticmethod + def backward(ctx, grad_output): + raise NotImplementedError("must override") class TestCompiledAutograd(TestCase): def setUp(self) -> None: + self.exit_stack = contextlib.ExitStack() + self.exit_stack.enter_context(config.patch("record_runtime_overhead", False)) super().setUp() reset() def tearDown(self) -> None: + self.exit_stack.close() super().tearDown() reset() @@ -123,9 +139,12 @@ def check_output_and_recompiles( torch.manual_seed(123) expected = list(fn()) torch.manual_seed(123) - with compiled_autograd._enable(compiler_fn), mock.patch( - "torch._functorch.aot_autograd.AOT_COUNTER", - new_callable=itertools.count, + with ( + compiled_autograd._enable(compiler_fn), + mock.patch( + "torch._functorch.aot_autograd.AOT_COUNTER", + new_callable=itertools.count, + ), ): opt_fn = torch.compile(fn) if compile_fn else fn actual = list(opt_fn()) @@ -701,6 +720,44 @@ def fn(model, inputs): self.assertEqual(expected, actual) self.assertEqual(counters["compiled_autograd"]["captures"], 2) + @parametrize("api", ("compile", "optimize")) + @parametrize("backend", ("eager", "aot_eager", "inductor")) + def test_compile_api_disable(self, api, backend): + def wrap(fn, backend): + if api == "compile": + return torch.compile(fn, backend=backend) + elif api == "optimize": + return torch._dynamo.optimize(backend)(fn) + + def fn(model, inputs): + res = [] + for inp in inputs: + result = model(inp).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + inputs = [ + torch.randn([1, 4]), + torch.randn([2, 4]), + torch.randn([3, 4]), + ] + + expected = fn(model, inputs) + with config.patch(compiled_autograd=True): + compiled_fn = wrap(fn, backend) + with torch._dynamo.compiled_autograd._disable(): + actual = compiled_fn(model, inputs) + self.assertEqual(expected, actual) + self.assertTrue("compiled_autograd" not in counters) + @parametrize("backend", ("eager", "aot_eager", "inductor")) def test_optimize_assert(self, backend): # can be merged into the test above once we support @@ -731,6 +788,88 @@ def fn(model, inp): self.assertEqual(expected, actual) self.assertEqual(counters["compiled_autograd"]["captures"], 0) + @config.patch(compiled_autograd=True) + def test_nested_context_manager(self): + def ctx(): + return compiled_autograd._enable(torch.compile) + + # ok + outer = ctx() + inner = ctx() + outer.__enter__() + inner.__enter__() + inner.__exit__(None, None, None) + outer.__exit__(None, None, None) + + # not ok + outer = ctx() + inner = ctx() + outer.__enter__() + inner.__enter__() + with self.assertRaisesRegex( + AssertionError, + "Nested Compiled Autograd Contexts must return before their parent context", + ): + outer.__exit__(None, None, None) + + @config.patch(compiled_autograd=True) + def test_nested_compile(self): + with torch.library._scoped_library("testlib", "FRAGMENT") as lib: + lib.define("square(Tensor x) -> Tensor") + + @torch.library.impl("testlib::square", "CPU") + def square_impl(x: torch.Tensor) -> torch.Tensor: + # nested inference graph compile + @torch.compile(backend="eager") + def fn(x): + return x**2 + + return fn(x) + + class MyFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, x): + return torch.ops.testlib.square(x) + + x = torch.tensor([2.0, 3.0], requires_grad=True) + + @torch.compile + def fn(x): + return MyFn.apply(x) + + fn(x).sum().backward() + + @config.patch(compiled_autograd=True) + def test_no_nested_compiled_autograd(self): + # We disable CA before entering the CA graph + # So re-entrants should be running with the eager autograd engine + + def unrelated_autograd_call(): + x = torch.randn(20, 20, requires_grad=True) + y = torch.randn(20, 20, requires_grad=True) + loss = torch.matmul(x, y).sum() + loss.backward() + + class MyFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + unrelated_autograd_call() + return gO + + x = torch.randn(10, 10, requires_grad=True) + loss = MyFn.apply(x).sum() + + torch.compile(lambda: loss.backward(create_graph=True))() + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + def test_multiple_torch_compile(self): model = torch.nn.Sequential( torch.nn.Linear(4, 4), @@ -869,8 +1008,8 @@ def test_inputs_aliasing_bytecode_attr_mutations(self): # Freeze compiled autograd graph compiler = torch._dynamo.compiled_autograd.AutogradCompilerInstance(compiler_fn) param = torch.ones(100) - activ = torch.ones(100) * 2 - inputs = [param, activ] + active = torch.ones(100) * 2 + inputs = [param, active] _, proxies, _, _ = compiler.begin_capture( inputs=inputs, sizes=[], @@ -921,7 +1060,7 @@ def bytecode_hook(code, out_code): try: runtime_wrapper( compiled_fn=compiled_fn, - inputs=[param, activ], + inputs=[param, active], sizes=(), scalars=(), hooks=[], @@ -1642,9 +1781,9 @@ def test_custom_fn_output_metadata(self): def my_compiler_fn(gm): for node in gm.graph.nodes: if isinstance(node.target, torch._ops.OpOverload): - assert ( - node.target._name != "aten::_to_copy" - ), "there should be no implicit copies (e.g. dtype casting)" + assert node.target._name != "aten::_to_copy", ( + "there should be no implicit copies (e.g. dtype casting)" + ) def inner_compiler(gm_, example_inputs_): counters["compiled_autograd"]["compiles"] += 1 @@ -2846,8 +2985,9 @@ def test_cudagraphs_cpu_division(self): loss = reduce_to_scalar_loss(out) stderr_msgs = io.StringIO() - with mock.patch("sys.stderr", stderr_msgs), compiled_autograd._enable( - compiler_fn + with ( + mock.patch("sys.stderr", stderr_msgs), + compiled_autograd._enable(compiler_fn), ): torch._inductor.config.triton.cudagraphs = True loss.backward() @@ -2884,8 +3024,9 @@ def test_cudagraphs_sdpa(self): value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") out = torch.nn.functional.scaled_dot_product_attention(query, key, value) - with config.patch(compiled_autograd=True), inductor_config.patch( - "triton.cudagraphs", True + with ( + config.patch(compiled_autograd=True), + inductor_config.patch("triton.cudagraphs", True), ): opt_bwd = torch.compile(lambda: out.sum().backward()) opt_bwd() @@ -2914,8 +3055,9 @@ def backward(ctx, gO): x = torch.randn(10, requires_grad=True, device="cuda") out = MyFn.apply(x) - with config.patch(compiled_autograd=True), inductor_config.patch( - "triton.cudagraphs", True + with ( + config.patch(compiled_autograd=True), + inductor_config.patch("triton.cudagraphs", True), ): opt_bwd = torch.compile(lambda: out.backward()) opt_bwd() @@ -2973,8 +3115,9 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): ) x = torch.randn(2, 2, requires_grad=True, device="cuda") - with config.patch(compiled_autograd=True), inductor_config.patch( - "triton.cudagraphs", True + with ( + config.patch(compiled_autograd=True), + inductor_config.patch("triton.cudagraphs", True), ): out = torch.ops.test_cudagraphs_cpu_scalar_used_in_cpp_custom_op.custom_op_backed_by_autograd_fn( x @@ -3364,9 +3507,12 @@ def compiler_fn(gm): graphs.append(gm) return inner_compiler_fn(gm) - with compiled_autograd._enable(compiler_fn), mock.patch( - "torch._functorch.aot_autograd.AOT_COUNTER", - new_callable=itertools.count, + with ( + compiled_autograd._enable(compiler_fn), + mock.patch( + "torch._functorch.aot_autograd.AOT_COUNTER", + new_callable=itertools.count, + ), ): res = fn(x) res.sum().backward() @@ -3434,11 +3580,11 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_36, getitem_37], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_16, unwrap_maybe_dynamic_int_17], False), ((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_18, unwrap_maybe_dynamic_int_19], False)]); getitem_36 = getitem_37 = unwrap_maybe_dynamic_int_16 = unwrap_maybe_dynamic_int_17 = unwrap_maybe_dynamic_int_18 = unwrap_maybe_dynamic_int_19 = None getitem_39 = validate_outputs_2[0] - accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_39); getitem_4 = getitem_39 = accumulate_grad__default_1 = None + call_accumulate_grad_1 = torch__dynamo_external_utils_call_accumulate_grad(getitem_4, getitem_39, False); getitem_4 = getitem_39 = call_accumulate_grad_1 = None getitem_40 = validate_outputs_2[1]; validate_outputs_2 = None - accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_40); getitem_3 = getitem_40 = accumulate_grad__default = None + call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_3, getitem_40, False); getitem_3 = getitem_40 = call_accumulate_grad = None _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] @@ -3446,6 +3592,10 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): ) # https://github.com/pytorch/pytorch/issues/138920 + # Inductor has a joint graph pattern to remove pointless view pairs. + # That will remove the no-op view pairs this test is checking. Disable + # pattern matcher for this test. + @inductor_config.patch(pattern_matcher=False) def test_compiled_autograd_does_not_specialize_on_bw_symints(self): class Mod(torch.nn.Module): def __init__(self, a, b, c): @@ -3594,9 +3744,10 @@ def tensor_hook(_): x = torch.ones(4, requires_grad=True) y = torch.ones(4, requires_grad=False) - with torch.autograd.graph.saved_tensors_hooks( - pack_hook, unpack_hook - ), compiled_autograd._enable(make_compiler_fn(fullgraph=False)): + with ( + torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook), + compiled_autograd._enable(make_compiler_fn(fullgraph=False)), + ): out_test = f(x, y) self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 0) @@ -3699,7 +3850,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_31], [((None, None, device(type='cpu'), 6, 0, None), [unwrap_maybe_dynamic_int_10, unwrap_maybe_dynamic_int_11], False)]); getitem_31 = unwrap_maybe_dynamic_int_10 = unwrap_maybe_dynamic_int_11 = None getitem_32 = validate_outputs_4[0]; validate_outputs_4 = None - accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_32); getitem_1 = getitem_32 = accumulate_grad__default = None + call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] """, # noqa: B950 @@ -4177,11 +4328,14 @@ def test_ddp_cpp_reducer_error(self): model = DDP(model) inputs = torch.randn(10, 10) loss = model(inputs).sum() - with compiled_autograd._enable(compiler_fn), self.assertRaisesRegex( - RuntimeError, - ( - r"Compiled autograd is not compatible with C\+\+ DDP Reducer, " - r'please use torch._dynamo.config.optimize_ddp="python_reducer"' + with ( + compiled_autograd._enable(compiler_fn), + self.assertRaisesRegex( + RuntimeError, + ( + r"Compiled autograd is not compatible with C\+\+ DDP Reducer, " + r'please use torch._dynamo.config.optimize_ddp="python_reducer"' + ), ), ): loss.backward() @@ -4211,6 +4365,622 @@ def test_ddp_python_reducer(self): finally: dist.destroy_process_group() + # Case 1.1: Stealable dense new_grad + # if (!GradMode::is_enabled() && !new_grad.is_sparse() && + # !new_grad.is_sparse_csr() && + # !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) && + # at::caching::adjusted_use_count(new_grad) <= num_expected_refs && + # (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) { + @unittest.expectedFailure + def test_accumulate_grad_polyfill_case_1_1(self): + def fn(): + class StealableDenseOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output, requires_grad=False) * 5 + + pre_hook_storage_id = None + + def check(grad): + nonlocal pre_hook_storage_id + assert pre_hook_storage_id is None + pre_hook_storage_id = id(grad.untyped_storage()) + + var = torch.randn(2, 2, requires_grad=True) + var.register_hook(check) + output = StealableDenseOp.apply(var) + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert torch.equal(var.grad, torch.ones_like(var) * 5), ( + "Grad content should be as returned by backward" + ) + assert var.grad.requires_grad is False, ( + "Detached grad should not require grad" + ) + assert id(var.grad.untyped_storage()) == pre_hook_storage_id, ( + "Should be stolen" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=make_compiler_fn(fullgraph=False), + count=[1, 2], + ) + + # Case 1.2: Stealable sparse new_grad + # } else if (!GradMode::is_enabled() && new_grad.is_sparse() && + # new_grad._indices().is_contiguous() && + # new_grad._values().is_contiguous() && + # new_grad._indices().use_count() <= 1 && + # new_grad._values().use_count() <= 1 && + # new_grad.use_count() <= num_expected_refs) { + @unittest.expectedFailure + def test_accumulate_grad_polyfill_case_1_2(self): + def fn(): + class StealableSparseOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + size = grad_output.size() + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor([5.0, 5.0]) + return torch.sparse_coo_tensor( + indices, values, size, requires_grad=False + ) + + pre_hook_storages_id = None + + def check(grad): + nonlocal pre_hook_storages_id + assert pre_hook_storages_id is None + pre_hook_storages_id = [ + id(grad._indices().untyped_storage()), + id(grad._values().untyped_storage()), + ] + + var = torch.randn(2, 2, requires_grad=True) + var.register_hook(check) + output = StealableSparseOp.apply(var) + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert var.grad.is_sparse, "Grad should be sparse" + expected_dense_grad = torch.tensor([[5.0, 0.0], [0.0, 5.0]]) + assert torch.equal(var.grad.to_dense(), expected_dense_grad), ( + "Content should be equal after shallow copy" + ) + assert var.grad.requires_grad is False, ( + "Detached grad should not require grad" + ) + assert ( + id(var.grad._indices().untyped_storage()) == pre_hook_storages_id[0] + ), "Should be stolen" + assert ( + id(var.grad._values().untyped_storage()) == pre_hook_storages_id[1] + ), "Should be stolen" + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=make_compiler_fn(fullgraph=False), + count=[1, 2], + ) + + # Case 1.3: Cloning sparse/nested new_grad + # else { + # if (new_grad.is_sparse() || new_grad.is_sparse_csr() || + # new_grad.is_nested()) { + def test_accumulate_grad_polyfill_case_1_3(self): + def fn(): + class CloneSparseGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + size = grad_output.size() + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor( + [5.0, 5.0], requires_grad=True + ) # Requires grad + return torch.sparse_coo_tensor( + indices, values, size, requires_grad=True + ) + + pre_hook_storages_id = None + + def check(grad): + nonlocal pre_hook_storages_id + assert pre_hook_storages_id is None + pre_hook_storages_id = [ + id(grad._indices().untyped_storage()), + id(grad._values().untyped_storage()), + ] + + var = torch.randn(2, 2, requires_grad=True) + var.register_hook(check) + output = CloneSparseGradOp.apply(var) + output.backward( + torch.ones_like(output), create_graph=True + ) # grad mode == create_graph + + assert var.grad is not None, "Grad should be defined" + assert var.grad.is_sparse, "Grad should be sparse" + expected_dense_grad = torch.tensor([[5.0, 0.0], [0.0, 5.0]]) + assert torch.equal(var.grad.to_dense(), expected_dense_grad), ( + "Content should be equal after clone" + ) + assert var.grad.requires_grad, ( + "Grad should require grad for double backward" + ) + assert ( + id(var.grad._indices().untyped_storage()) != pre_hook_storages_id[0] + ), "Should be copied" + assert ( + id(var.grad._values().untyped_storage()) != pre_hook_storages_id[1] + ), "Should be copied" + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=make_compiler_fn(fullgraph=False), + count=[1, 2], + ) + + # Case 1.5.1: Dense variable gradient layout contract + # else { // Covers various deep copy scenarios not covered by specific stealable paths + # ... + # if (new_grad.is_mkldnn()) { + # ... + # } else { + # // Deep copies new_grad according to the "Gradient Layout Contract." + # update_grad(utils::clone_obey_contract(new_grad, variable)); + # } + # } + def test_accumulate_grad_polyfill_case_1_5_1(self): + def fn(): + class NotStealableRefsOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output, requires_grad=False) * 10.0 + + var = torch.randn(2, 2, requires_grad=True) + grad_ref_holder = [None] + + def check(grad): + # forces a clone due to refcount + grad_ref_holder[0] = grad + return grad + + var.register_hook(check) + output = NotStealableRefsOp.apply(var) + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert torch.equal(var.grad, torch.ones_like(var) * 10.0), ( + "Grad content should be as returned by backward" + ) + assert ( + grad_ref_holder[0].untyped_storage() is not var.grad.untyped_storage() + ), "Should be copied" + yield var.grad + + self.check_output_and_recompiles(fn) + + # Case 1.5.2: Non-dense variable gradient layout contract + # else { // Covers various deep copy scenarios not covered by specific stealable paths + # ... + # if (new_grad.is_mkldnn()) { + # ... + # } else { + # // Deep copies new_grad according to the "Gradient Layout Contract." + # update_grad(utils::clone_obey_contract(new_grad, variable)); + # } + # } + def test_accumulate_grad_polyfill_case_1_5_2(self): + def fn(): + class SimpleDenseGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output, requires_grad=False) * 7.0 + + # Create a non-contiguous variable + base_tensor = torch.randn(4, 4) + var = base_tensor[::2, ::2] + assert not var.is_contiguous(), ( + "Variable should be non-contiguous for this test" + ) + var.requires_grad_(True) + + grad_ref_holder = [None] + + def check(grad): + # forces a clone due to refcount + grad_ref_holder[0] = grad + return grad + + var.register_hook(check) + output = SimpleDenseGradOp.apply(var) + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + # The `clone_obey_contract` branch 2 (`new_grad.clone(at::MemoryFormat::Contiguous)`) + # will make the resulting grad contiguous. + assert var.grad.is_contiguous(), ( + "Resulting grad should be contiguous due to branch 2 of clone_obey_contract" + ) + assert torch.equal(var.grad, torch.ones_like(var) * 7.0), ( + "Grad content should be as returned by backward" + ) + assert ( + grad_ref_holder[0].untyped_storage() is not var.grad.untyped_storage() + ), "Should be copied" + yield var.grad + + self.check_output_and_recompiles( + fn, + ) + + # Case 2.1: Sparse variable_grad + Dense new_grad + # } else if (!GradMode::is_enabled()) { + # if (variable_grad.is_sparse() && !new_grad.is_sparse()) { + # auto result = new_grad + variable_grad; + def test_accumulate_grad_polyfill_case_2_1(self): + def fn(): + class SparseVarGradDenseNewGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output) * 3.0 + + var = torch.randn(2, 2, requires_grad=True) + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor([1.0, 1.0]) + var.grad = torch.sparse_coo_tensor( + indices, values, var.size(), requires_grad=False + ) + initial_grad_ref = var.grad + output = SparseVarGradDenseNewGradOp.apply(var) + + expected_sum = (torch.ones_like(var) * 3.0) + initial_grad_ref.to_dense() + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert not var.grad.is_sparse, "Resulting grad should be dense" + assert torch.equal(var.grad, expected_sum), "Grad content should be the sum" + assert var.grad is not initial_grad_ref, ( + "Grad object should be replaced (out-of-place)" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=lambda gm: gm, # https://github.com/pytorch/pytorch/issues/154161 + count=[1, 0], + ) + + # Case 2.3.1: Dense/Dense in-place addition + # } else if (!GradMode::is_enabled()) { + # ... + # } else { + # variable_grad += new_grad; + def test_accumulate_grad_polyfill_case_2_3_1(self): + def fn(): + class DenseVarGradDenseNewGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output) * 3.0 + + var = torch.randn(2, 2, requires_grad=True) + var.grad = torch.ones_like(var) * 1.0 + initial_grad_ref = var.grad + output = DenseVarGradDenseNewGradOp.apply(var) + expected_sum = initial_grad_ref + (torch.ones_like(var) * 3.0) + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert not var.grad.is_sparse, "Resulting grad should be dense" + assert torch.equal(var.grad, expected_sum), "Grad content should be the sum" + assert var.grad is initial_grad_ref, ( + "Grad object should be modified in-place (same object)" + ) + yield var.grad + + self.check_output_and_recompiles(fn) + + # Case 2.3.2: Sparse/Sparse in-place addition + # } else if (!GradMode::is_enabled()) { + # ... + # } else { + # variable_grad += new_grad; + def test_accumulate_grad_polyfill_case_2_3_2(self): + def fn(): + class SparseVarGradSparseNewGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + size = grad_output.size() + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor([3.0, 3.0]) + return torch.sparse_coo_tensor( + indices, values, size, requires_grad=False + ) + + var = torch.randn(2, 2, requires_grad=True) + indices_v = torch.tensor([[0, 0], [0, 1]], dtype=torch.int64) + values_v = torch.tensor([1.0, 2.0]) + var.grad = torch.sparse_coo_tensor( + indices_v, values_v, var.size(), requires_grad=False + ) + initial_grad_ref = var.grad + + output = SparseVarGradSparseNewGradOp.apply(var) + + new_grad_for_sum = torch.sparse_coo_tensor( + torch.tensor([[0, 1], [0, 1]], dtype=torch.int64), + torch.tensor([3.0, 3.0]), + var.size(), + ) + expected_sum_dense = ( + initial_grad_ref.to_dense() + new_grad_for_sum.to_dense() + ) + + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert var.grad.is_sparse, "Resulting grad should remain sparse" + assert torch.equal(var.grad.to_dense(), expected_sum_dense), ( + "Grad content should be the sum of sparse grads" + ) + assert var.grad is initial_grad_ref, ( + "Grad object should be modified in-place (same object)" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=lambda gm: gm, # https://github.com/pytorch/pytorch/issues/154161 + count=[1, 0], + ) + + # Case 2.3.3: Dense/Sparse in-place addition + # } else if (!GradMode::is_enabled()) { + # ... + # } else { + # variable_grad += new_grad; + def test_accumulate_grad_polyfill_case_2_3_3(self): + def fn(): + class DenseVarGradSparseNewGradOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + size = grad_output.size() + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor([3.0, 3.0]) # New sparse values + return torch.sparse_coo_tensor( + indices, values, size, requires_grad=False + ) + + var = torch.randn(2, 2, requires_grad=True) + var.grad = torch.ones_like(var) * 1.0 # Initial value + initial_grad_ref = var.grad + output = DenseVarGradSparseNewGradOp.apply(var) + + new_grad_for_sum = torch.sparse_coo_tensor( + torch.tensor([[0, 1], [0, 1]], dtype=torch.int64), + torch.tensor([3.0, 3.0]), + var.size(), + ).to_dense() + expected_sum = initial_grad_ref + new_grad_for_sum + + output.backward(torch.ones_like(output)) + + assert var.grad is not None, "Grad should be defined" + assert not var.grad.is_sparse, "Resulting grad should be dense" + assert torch.equal(var.grad, expected_sum), "Grad content should be the sum" + assert var.grad is initial_grad_ref, ( + "Grad object should be modified in-place (same object)" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=make_compiler_fn(fullgraph=False), + count=[1, 2], + ) + + # Case 3.1: Sparse variable_grad + Dense new_grad (reorder into Dense + Sparse) + # } else { // if GradMode::is_enabled() + # at::Tensor result; + # if (variable_grad.is_sparse() && !new_grad.is_sparse()) { + # result = new_grad + variable_grad; + # } + # } + def test_accumulate_grad_polyfill_case_3_1(self): + def fn(): + class SparseVarGradDenseNewGradDoubleBackwardOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output, requires_grad=True) * 3.0 + + var = torch.randn(2, 2, requires_grad=True) + indices = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + values = torch.tensor([1.0, 1.0], requires_grad=True) + var.grad = torch.sparse_coo_tensor( + indices, values, var.size(), requires_grad=True + ) + initial_grad_ref = var.grad + + output = SparseVarGradDenseNewGradDoubleBackwardOp.apply(var) + + expected_sum = ( + torch.ones_like(var, requires_grad=True) * 3.0 + ) + initial_grad_ref.to_dense() + + output.backward(torch.ones_like(output), create_graph=True) + + assert var.grad is not None, "Grad should be defined" + assert not var.grad.is_sparse, "Resulting grad should be dense" + assert torch.equal(var.grad, expected_sum), "Grad content should be the sum" + assert var.grad is not initial_grad_ref, ( + "Grad object should be replaced (out-of-place)" + ) + assert var.grad.requires_grad, ( + "Resulting grad should track history for double backward" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=lambda gm: gm, # https://github.com/pytorch/pytorch/issues/154161 + count=[1, 0], + ) + + # Case 3.2: variable_grad.defined() & GradMode::is_enabled() - Double backward (dense variable_grad + dense new_grad) + # } else { // if GradMode::is_enabled() + # at::Tensor result; + # ... + # } else { + # result = variable_grad + new_grad; + # } + # } + def test_accumulate_grad_polyfill_case_3_2(self): + def fn(): + class DenseVarGradDenseNewGradDoubleBackwardOp(BaseCustomOp): + @staticmethod + def backward(ctx, grad_output): + return torch.ones_like(grad_output, requires_grad=True) * 3.0 + + var = torch.randn(2, 2, requires_grad=True) + var.grad = torch.ones_like(var) * 1.0 + initial_grad_ref = var.grad + + output = DenseVarGradDenseNewGradDoubleBackwardOp.apply(var) + + expected_sum = initial_grad_ref + ( + torch.ones_like(var, requires_grad=True) * 3.0 + ) + + output.backward(torch.ones_like(output), create_graph=True) + + assert var.grad is not None, "Grad should be defined" + assert not var.grad.is_sparse, "Resulting grad should be dense" + assert torch.equal(var.grad, expected_sum), "Grad content should be the sum" + assert var.grad is not initial_grad_ref, ( + "Grad object should be replaced (out-of-place)" + ) + assert var.grad.requires_grad, ( + "Resulting grad should track history for double backward" + ) + yield var.grad + + self.check_output_and_recompiles( + fn, + compiler_fn=make_compiler_fn(fullgraph=False), + count=[1, 3], + ) + + def test_torch_function_mode(self): + called_funcs = [] + + class LoggingTorchFunctionMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + called_funcs.append(str(func.__name__)) + return super().__torch_function__(func, types, args, kwargs) + + class MyLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, out): + ctx.save_for_backward(out) + return out.sum() + + @staticmethod + def backward(ctx, grad_output): + (saved,) = ctx.saved_tensors + return torch.ones_like(saved) * grad_output + + x = torch.randn(2, 2, requires_grad=True) + y = torch.randn(2, 2) + z = torch.randn(2, 2) + + def fwd(x, y, z): + out = x * y * z + loss = MyLoss.apply(out) + return loss + + with LoggingTorchFunctionMode(): + called_funcs.append("Forward") + loss = fwd(x, y, z) + called_funcs.append("Backward") + with torch._dynamo.compiled_autograd._enable(torch.compile): + loss.backward() + + self.assertExpectedInline( + "\n".join(called_funcs), + """\ +Forward +mul +mul +sum +Backward +_set_multithreading_enabled +backward +_set_multithreading_enabled""", + ) # noqa: B950 + + def test_torch_dispatch_mode(self): + called_funcs = [] + + class LoggingTorchDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + called_funcs.append(str(func.__name__)) + return func(*args, **kwargs) + + class MyLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, out): + ctx.save_for_backward(out) + return out.sum() + + @staticmethod + def backward(ctx, grad_output): + (saved,) = ctx.saved_tensors + return torch.ones_like(saved) * grad_output + + x = torch.randn(2, 2, requires_grad=True) + y = torch.randn(2, 2) + z = torch.randn(2, 2) + + def fwd(x, y, z): + out = x * y * z + loss = MyLoss.apply(out) + return loss + + with LoggingTorchDispatchMode(): + called_funcs.append("Forward") + loss = fwd(x, y, z) + called_funcs.append("Backward") + with torch._dynamo.compiled_autograd._enable(lambda gm: gm): + loss.backward() + + self.assertExpectedInline( + "\n".join(called_funcs), + """\ +Forward +mul.Tensor +mul.Tensor +sum.default +Backward +ones_like.default +empty.memory_format +empty.memory_format +empty.memory_format +empty.memory_format +empty.memory_format +empty.memory_format +ones_like.default +mul.Tensor +mul.Tensor +mul.Tensor +new_empty_strided.default +copy_.default""", + ) # noqa: B950 + def load_test_module(name): testdir = Path(__file__).absolute().parent.parent @@ -4400,8 +5170,6 @@ def wrap_test_class(orig_cls): "test_reentrant_with_callbacks_depth_1", # queue_callback "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd - "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd - "test_nested_checkpoint_set_early_stop_no_recompution_needed", # TorchDispatchMode not yet implemented "test_post_accumulate_grad_hook_ordering", # accuracy error "test_current_graph_task_id", # autograd state already cleared once dynamo is called "test_custom_function_forward_mode_forward_is_no_op", # forward AD @@ -4493,6 +5261,8 @@ def wrap_test_class(orig_cls): "test_grad_call_compiled_backward_fn", # different functorch error "test_vjp_call_compiled_backward_fn", # different functorch error "test_vmap_call_compiled_backward_fn", # different functorch error + "test_accumulate_grad", # always out of place add for compiled autograd + "test_current_node", # slightly different dispatched ops } skipped_tests = set() diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index db8c169cfe018c..9751b3ca8f554f 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -1,12 +1,15 @@ # Owner(s): ["module: inductor"] import sys +import types import unittest import weakref from contextlib import ExitStack from copy import deepcopy from typing import NamedTuple +from expecttest import assert_expected_inline + import torch import torch._inductor import torch._inductor.cudagraph_trees @@ -186,69 +189,73 @@ class KernelCounts(NamedTuple): # tests you can get different kernel counts # This maps the test name to the # expected kernel count + +# fmt: off +# expecttest got error after PYFMT add line break for the triple quotes KERNEL_COUNT_OVERRIDES = { - "test_rmsprop_foreach_weight_decay_cpu": 12, - "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20, - "test_adamw_amsgrad_capturable_foreach_cuda": 3, - "test_adamw_amsgrad_capturable_foreach_xpu": 3, - "test_adamw_amsgrad_capturable_cuda": 6, - "test_adamw_amsgrad_capturable_xpu": 6, - "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6, - "test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6, - "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6, - "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6, - "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6, - "test_adam_tensor_lr_amsgrad_capturable_cuda": 6, - "test_adam_tensor_lr_amsgrad_capturable_xpu": 6, - "test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6, - "test_adam_tensor_lr_tensor_betas_capturable_cuda": 6, - "test_adam_amsgrad_capturable_cuda": 6, - "test_adam_amsgrad_capturable_xpu": 6, - "test_adadelta_tensor_lr_capturable_cuda": 6, - "test_adadelta_tensor_lr_capturable_xpu": 6, - "test_rmsprop_tensor_lr_capturable_cuda": 6, - "test_rmsprop_tensor_lr_capturable_xpu": 6, - "test_adadelta_foreach_weight_decay_maximize_cpu": 12, - "test_adadelta_foreach_rho_weight_decay_cpu": 12, - "test_adadelta_foreach_weight_decay_cpu": 12, - "test_sgd_foreach_momentum_weight_decay_cpu": 16, - "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16, - "test_sgd_momentum_dampening_foreach_cuda": 5, - "test_sgd_momentum_dampening_foreach_xpu": 5, - "test_sgd_momentum_foreach_cuda": 5, - "test_sgd_momentum_foreach_xpu": 5, - "test_sgd_weight_decay_maximize_cuda": 4, - "test_sgd_weight_decay_maximize_xpu": 4, - "test_sgd_weight_decay_maximize_cpu": 4, - "test_sgd_weight_decay_cpu": 4, - "test_sgd_weight_decay_cuda": 4, - "test_sgd_weight_decay_xpu": 4, - "test_sgd_momentum_weight_decay_foreach_cuda": 2, - "test_sgd_momentum_weight_decay_foreach_xpu": 2, - "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, - "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2, - "test_sgd_cuda": 4, - "test_sgd_cpu": 4, - "test_sgd_xpu": 4, - "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2, - "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2, - "test_adagrad_weight_decay_foreach_xpu": 2, - "test_adagrad_weight_decay_maximize_foreach_xpu": 2, - "test_adagrad_tensor_lr_cpu": 6, - "test_adagrad_tensor_lr_cuda": 6, - "test_adagrad_tensor_lr_xpu": 6, - "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6, - "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6, - "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5, - "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8, - "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6, - "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9, - "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6, - "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6, - "test_sgd_tensor_lr_cpu": 2, - "test_sgd_tensor_lr_cuda": 2, - "test_sgd_tensor_lr_xpu": 2, + "test_rmsprop_foreach_weight_decay_cpu": lambda x: assert_expected_inline(x, """12""") , + "test_nadam_foreach_weight_decay_momentum_decay_cpu": lambda x: assert_expected_inline(x, """20"""), + "test_adamw_amsgrad_capturable_foreach_cuda": lambda x: assert_expected_inline(x, """3"""), + "test_adamw_amsgrad_capturable_foreach_xpu": lambda x: assert_expected_inline(x, """3"""), + "test_adamw_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_amsgrad_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_tensor_lr_tensor_betas_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_tensor_lr_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adamw_tensor_lr_amsgrad_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adam_tensor_lr_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adam_tensor_lr_amsgrad_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adam_tensor_lr_tensor_betas_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adam_amsgrad_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adam_amsgrad_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adadelta_tensor_lr_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adadelta_tensor_lr_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_rmsprop_tensor_lr_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_rmsprop_tensor_lr_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adadelta_foreach_weight_decay_maximize_cpu": lambda x: assert_expected_inline(x, """12"""), + "test_adadelta_foreach_rho_weight_decay_cpu": lambda x: assert_expected_inline(x, """12"""), + "test_adadelta_foreach_weight_decay_cpu": lambda x: assert_expected_inline(x, """12"""), + "test_sgd_foreach_momentum_weight_decay_cpu": lambda x: assert_expected_inline(x, """16"""), + "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": lambda x: assert_expected_inline(x, """16"""), + "test_sgd_momentum_dampening_foreach_cuda": lambda x: assert_expected_inline(x, """5"""), + "test_sgd_momentum_dampening_foreach_xpu": lambda x: assert_expected_inline(x, """5"""), + "test_sgd_momentum_foreach_cuda": lambda x: assert_expected_inline(x, """5"""), + "test_sgd_momentum_foreach_xpu": lambda x: assert_expected_inline(x, """5"""), + "test_sgd_weight_decay_maximize_cuda": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_weight_decay_maximize_xpu": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_weight_decay_maximize_cpu": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_weight_decay_cpu": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_weight_decay_cuda": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_weight_decay_xpu": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_momentum_weight_decay_foreach_cuda": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_momentum_weight_decay_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_cuda": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_cpu": lambda x: assert_expected_inline(x, """4"""), + "test_sgd_xpu": lambda x: assert_expected_inline(x, """4"""), + "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_adagrad_lr_decay_weight_decay_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_adagrad_weight_decay_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_adagrad_weight_decay_maximize_foreach_xpu": lambda x: assert_expected_inline(x, """2"""), + "test_adagrad_tensor_lr_cpu": lambda x: assert_expected_inline(x, """6"""), + "test_adagrad_tensor_lr_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adagrad_tensor_lr_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_adamax_tensor_lr_weight_decay_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_adamax_tensor_lr_weight_decay_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": lambda x: assert_expected_inline(x, """5"""), + "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": lambda x: assert_expected_inline(x, """8"""), + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), # noqa: B950 + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": lambda x: assert_expected_inline(x, """9"""), # noqa: B950 + "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": lambda x: assert_expected_inline(x, """6"""), + "test_sgd_tensor_lr_cpu": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_tensor_lr_cuda": lambda x: assert_expected_inline(x, """2"""), + "test_sgd_tensor_lr_xpu": lambda x: assert_expected_inline(x, """2"""), } +# fmt: on # also tracks currently supported optimizers KERNEL_COUNTS = { @@ -503,9 +510,12 @@ def test_fn(self): # currently, we compile the step and the rest of the computation # separately because the step is a single element tensor # hence, the usual kernel count is 2 - self.assertEqual( - torch._inductor.metrics.generated_kernel_count, kernel_count - ) + if isinstance(kernel_count, types.LambdaType): + kernel_count(str(torch._inductor.metrics.generated_kernel_count)) + else: + self.assertEqual( + torch._inductor.metrics.generated_kernel_count, kernel_count + ) finally: stack.close() @@ -919,9 +929,9 @@ def test_S429861(self): import torch._dynamo import torch._inductor from torch._dynamo.debug_utils import aot_graph_input_parser - from torch._inductor.utils import fresh_inductor_cache + from torch._inductor.utils import fresh_cache - with fresh_inductor_cache(): + with fresh_cache(): kwargs = aot_graph_input_parser(forward) torch.compile(forward)(**kwargs) diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 8a675af8469cac..107a65d6fa1dfc 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -13,6 +13,7 @@ decorateIf, instantiate_parametrized_tests, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.testing._internal.triton_utils import requires_gpu @@ -212,6 +213,40 @@ def false_fn(x): return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,)) + class FunctionalCall(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, p, x): + true_new_weight = torch.ones(x.size(0), x.size(0), device=x.device) + false_new_weight = torch.zeros(x.size(0), x.size(0), device=x.device) + true_new_bias = torch.ones(x.size(0), device=x.device) + false_new_bias = torch.zeros(x.size(0), device=x.device) + x = x.reshape(-1, x.size(0)) + + def true_fn(x): + return torch.func.functional_call( + self.linear, + { + "weight": true_new_weight, + "bias": true_new_bias, + }, + x, + ) + + def false_fn(x): + return torch.func.functional_call( + self.linear, + { + "weight": false_new_weight, + "bias": false_new_bias, + }, + x, + ) + + return torch.cond(p, true_fn, false_fn, (x,)) + class CondTests(TestCase): def _run_test( @@ -293,6 +328,7 @@ def test_cond_unbacked_symint_closure(self, device, dynamic): dynamic=dynamic, ) + @skipIfXpu(msg="Remove this skip after issue #154949 resolved.") @requires_gpu def test_cond_control_flow_with_precomputed_size(self): class TestModel(torch.nn.Module): @@ -673,6 +709,17 @@ def test_cond_mismatched_branch_output_size(self, device, dynamic): dynamic=dynamic, ) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) + @parametrize("dynamic", [True, False]) + def test_cond_functional_call(self, device, dynamic): + self._run_test( + model=CondModels.FunctionalCall(), + inputs=(torch.randn(10, 20),), + device=device, + dynamic=dynamic, + ) + class WhileLoopModels: class Simple(torch.nn.Module): @@ -1302,7 +1349,7 @@ class AssociativeScanTests(TestCase): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("backend", ["inductor"]) @parametrize("device", [torch.device("cpu"), GPU_TYPE]) - # This test will fail as flip in combination with particular input lenghts + # This test will fail as flip in combination with particular input lengths # produces weird results. # This is under investigations in # https://github.com/pytorch/pytorch/issues/131805 @@ -1326,7 +1373,7 @@ def fct(x: torch.Tensor, y: torch.Tensor): fct, x, 0, reverse=False, combine_mode=combine_mode ) - # Skipping test because combine_mode currently only suppors CUDA tensors + # Skipping test because combine_mode currently only supports CUDA tensors return result1 = associative_scan1( @@ -1563,12 +1610,13 @@ def combine_fn(carry, xs): grad_weight, grad_bias, loss_acc = carry input_chunk, target_chunk = xs ( - chunk_grad_input, - chunk_grad_weight, - chunk_grad_bias, - ), chunk_loss = torch.func.grad_and_value( - compute_loss, argnums=(0, 1, 2) - )( + ( + chunk_grad_input, + chunk_grad_weight, + chunk_grad_bias, + ), + chunk_loss, + ) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))( input_chunk, weight, bias, target_chunk ) return ( @@ -1616,12 +1664,13 @@ def compute_loss(input_chunk, weight, bias, target): def accumulate_chunk(input_chunk, target_chunk): ( - chunk_grad_input, - chunk_grad_weight, - chunk_grad_bias, - ), chunk_loss = torch.func.grad_and_value( - compute_loss, argnums=(0, 1, 2) - )( + ( + chunk_grad_input, + chunk_grad_weight, + chunk_grad_bias, + ), + chunk_loss, + ) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))( input_chunk, weight, bias, target_chunk ) grad_weight.add_(chunk_grad_weight) @@ -1642,6 +1691,18 @@ def accumulate_chunk(input_chunk, target_chunk): torch.cat(grad_inputs, dim=0) / chunks, ) + class ScanWithClamp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, scan_op, initial, xs): + def step(h_prev, x_t): + h_next = (h_prev + x_t).clamp(min=0.1) + return h_next, h_next.clone() + + final, ys = scan_op(step, initial, xs) + return final, ys + class ScanTests(TestCase): def _run_test( @@ -1824,6 +1885,24 @@ def test_scan_compare_chunked_ce_with_no_scan(self, device, dynamic): device=device, ) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) + @parametrize("dynamic", [True, False]) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_scan_with_clamp(self, device, dynamic): + B = 4 + T = 8 + H = 16 + self._run_test( + model=ScanModels.ScanWithClamp(), + inputs=( + torch.randn((B, H)), + torch.randn((T, B, H), requires_grad=True), + ), + device=device, + dynamic=dynamic, + ) + class MapModels: class Simple(torch.nn.Module): diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 469ceec2e1b2b6..fc296b12a9d707 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -12,6 +12,7 @@ from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_code +from torch.testing import assert_close from torch.testing._internal.common_cuda import IS_SM89 from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -57,11 +58,90 @@ def setUp(self): torch._inductor.metrics.generated_kernel_count = 0 torch._dynamo.reset() - def run_and_check(self, fn, args, *, expect_kernel_count=1): - expected = fn(*args) - fn = torch.compile(fn, fullgraph=True) - result, (source_code,) = run_and_get_code(fn, *args) - self.assertEqual(result, expected) + def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1): + # Define fixed tolerances + RTOL = 1e-5 + ATOL = 1e-6 + + # calculate reference value in higher precision when input dtype is float16 + ref_dtype = dtype + if dtype == torch.float16: + ref_dtype = torch.float64 + + # Cast to the determined reference dtype + args_ref = [tensor.to(ref_dtype) for tensor in args] + + # Calculate expected output + raw_expected = fn(*args_ref) + + if isinstance(raw_expected, (tuple, list)): + # If it's a tuple or list, apply .to(dtype) to each tensor within it + # Also, handle cases where dtype might not be provided (e.g., for bool reductions) + if dtype is not None: + expected = type(raw_expected)( + [ + t.to(dtype) if isinstance(t, torch.Tensor) else t + for t in raw_expected + ] + ) + else: + expected = type(raw_expected)( + [ + t.to(torch.float64) if isinstance(t, torch.Tensor) else t + for t in raw_expected + ] + ) + else: + # If it's a single tensor + if dtype is not None: + expected = raw_expected.to(dtype) + else: + expected = raw_expected.to(torch.float64) + + fn_compiled = torch.compile(fn, fullgraph=True) + result, (source_code,) = run_and_get_code(fn_compiled, *args) + + # For comparison, ensure result is also a tuple/list if expected is + if isinstance(expected, (tuple, list)): + if isinstance(result, torch.Tensor): + result = (result,) + elif not isinstance(result, type(expected)): + result = type(expected)(result) + + if dtype is not None: + result = type(result)( + [t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result] + ) + else: + result = type(result)( + [ + t.to(torch.float64) if isinstance(t, torch.Tensor) else t + for t in result + ] + ) + else: + if dtype is not None and isinstance(result, torch.Tensor): + result = result.to(dtype) + elif isinstance(result, torch.Tensor): + result = result.to(torch.float64) + + # Apply assert_close with fixed tolerances for tensor comparisons + if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor): + assert_close(result, expected, rtol=RTOL, atol=ATOL) + elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)): + # Iterate through elements for comparison + for r_item, e_item in zip(result, expected): + if isinstance(r_item, torch.Tensor) and isinstance( + e_item, torch.Tensor + ): + assert_close(r_item, e_item, rtol=RTOL, atol=ATOL) + else: + # Fallback to assertEqual for non-tensor elements (e.g., bool, int) + self.assertEqual(r_item, e_item) + else: + # Fallback to assertEqual for other types not handled by assert_close + self.assertEqual(result, expected) + if "@triton_heuristics.fixed_config" in source_code: self.assertIn("cooperative_reduction_grid", source_code) else: @@ -97,7 +177,7 @@ def fn(x, y): reduction_fn = getattr(torch, name) args = [torch.randn(1, 1024**2, device="cuda", dtype=dtype) for _ in range(2)] - self.run_and_check(fn, args) + self.run_and_check(fn, args, dtype) def test_bool_reduction_fns(self): def fn(x, y): diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 7716898c542434..4b4daaef5c4385 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -97,6 +97,7 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, + test_build_separate=False, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -105,8 +106,12 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func + new_test_name = f"{test_name}_separate" if test_build_separate else test_name - @config.patch(cpp_wrapper=True, search_autotune_cache=False) + @config.patch( + cpp_wrapper=True, + cpp_wrapper_build_separate=test_build_separate, + ) def fn(self): tests.setUpClass() tests.setUp() @@ -123,6 +128,8 @@ def fn(self): # happen for tests validating build-dependent features (e.g. datatypes # that are available on some platforms and not others). if code: + if test_build_separate: + self.assertIn("kernel_src", code) self.assertIn("CppWrapperCodeCache", code) self.assertTrue( all( @@ -134,14 +141,14 @@ def fn(self): tests.tearDown() tests.tearDownClass() - fn.__name__ = test_name + fn.__name__ = new_test_name import copy fn.__dict__ = copy.deepcopy(func.__dict__) if condition: setattr( CppWrapperTemplate, - test_name, + new_test_name, fn, ) @@ -156,14 +163,18 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} + test_build_separate: bool = False for item in [ BaseTest("test_add_complex"), + BaseTest("test_add_complex", test_build_separate=True), BaseTest("test_add_complex4"), + BaseTest("test_add_complex4", test_build_separate=True), BaseTest("test_as_strided"), # buffer reuse BaseTest("test_bernoulli1"), BaseTest("test_bitwise"), # int32 BaseTest("test_bmm1"), + BaseTest("test_bmm1", test_build_separate=True), BaseTest("test_bmm2"), BaseTest("test_cat"), # alias BaseTest( @@ -220,9 +231,9 @@ class BaseTest(NamedTuple): ], BaseTest("test_polar"), BaseTest( - "test_linear_binary_cpu", + "test_linear_binary", "", - test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(), + test_mkldnn_pattern_matcher.TestPatternMatcher(), torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), @@ -372,6 +383,7 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, + item.test_build_separate, ) test_torchinductor.copy_tests( diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 11a6c8739bfc2f..b6a46176c27c64 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -152,7 +152,7 @@ class RecordFunctions(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} if func == torch.ops.aten.convolution.default: - # For CPU and mkldnn enable, we always using channles last + # For CPU and mkldnn enable, we always using channels last nonlocal fmt if ( torch.backends.mkldnn.enabled @@ -232,9 +232,11 @@ def forward(self, x): metrics.reset() v = torch.randn(*input_size) mod = Model(output_size, kernel_size, stride).eval() - with contextlib.nullcontext() if ( - num_threads != 1 - ) else set_num_threads(1): + with ( + contextlib.nullcontext() + if (num_threads != 1) + else set_num_threads(1) + ): with torch.no_grad(): self.common( mod, @@ -760,7 +762,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else "aten.set_.source_Tensor", code, ) - self.assertEqual(model(inp), result) + expected = model(inp) + self.assertEqual(expected, result) + + # test cpp_wrapper_build_separate + with config.patch(cpp_wrapper=True, cpp_wrapper_build_separate=True): + result, code = run_and_get_cpp_code(fn_opt, inp) + self.assertIn("kernel_src", code) + self.assertEqual(expected, result) + + with config.patch(cpp_wrapper=True, cpp_wrapper_build_separate=False): + result, code = run_and_get_cpp_code(fn_opt, inp) + self.assertNotIn("kernel_src", code) + self.assertEqual(expected, result) @torch._dynamo.config.patch(dynamic_shapes=True) @torch._dynamo.config.patch(assume_static_by_default=False) @@ -982,7 +996,7 @@ def fn(x): v = torch.randn(10) # TODO: OMP parallel reduction order is not deterministic. - # Hence, the accurarcy might vary up and down. For short term, + # Hence, the accuracy might vary up and down. For short term, # we increase the tolerance and will fix it later by using # aten parallel. self.common(fn, (v,), atol=5e-1, rtol=5e-1) @@ -990,7 +1004,7 @@ def fn(x): def test_parallel_reduction_vectorization(self): # Fix issue: https://github.com/pytorch/pytorch/issues/151523 class Model(torch.nn.Module): - def __init__(self): + def __init__(self, enable_masked_tail_vec): super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, @@ -999,20 +1013,23 @@ def __init__(self): stride=(2, 1), padding=0, ) + self.enable_masked_tail_vec = enable_masked_tail_vec def forward(self, x, weight): x = self.conv(x) - x = F.hardshrink(x, lambd=0) + if not self.enable_masked_tail_vec: + x = F.hardshrink(x, lambd=0) x = x.view(x.size(0), -1) x = torch.mv(weight, x[0]) return x - mod = Model().eval() - x = torch.randn(2, 3, 127, 255) - weight = torch.randn(10, 254976) - # Use same criterion as test_inplace_squeeze_needed - # for parallel reduction. - self.common(mod, (x, weight), atol=5e-1, rtol=5e-1) + for enable_masked_tail_vec in [True, False]: + mod = Model(enable_masked_tail_vec).eval() + x = torch.randn(2, 3, 127, 255) + weight = torch.randn(10, 254976) + # Use same criterion as test_inplace_squeeze_needed + # for parallel reduction. + self.common(mod, (x, weight), atol=5e-1, rtol=5e-1) def test_cat_mul(self): # https://github.com/pytorch/pytorch/issues/93365 @@ -1273,9 +1290,9 @@ def test_slice_scatter_default_end_value(self): # From HF AllenaiLongformerBase. def fn(query, key, window_overlap): batch_size, seq_len, num_heads, head_dim = query.size() - assert ( - seq_len % (window_overlap * 2) == 0 - ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" + assert seq_len % (window_overlap * 2) == 0, ( + f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" + ) chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 diagonal_chunked_attention_scores = key @@ -1287,11 +1304,11 @@ def fn(query, key, window_overlap): window_overlap * 2 + 1, ) ) - diagonal_attention_scores[ - :, :3, :, window_overlap: - ] = diagonal_chunked_attention_scores[ - :, :, :window_overlap, : window_overlap + 1 - ] + diagonal_attention_scores[:, :3, :, window_overlap:] = ( + diagonal_chunked_attention_scores[ + :, :, :window_overlap, : window_overlap + 1 + ] + ) return diagonal_attention_scores self.common( @@ -2618,8 +2635,9 @@ def fn(a, dim, index, b): self.common(fn, inps) assert metrics.generated_cpp_vec_kernel_count == 2 - with set_num_threads(1), config.patch( - {"fx_graph_cache": False, "fx_graph_remote_cache": False} + with ( + set_num_threads(1), + config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}), ): torch._dynamo.reset() metrics.reset() @@ -3512,7 +3530,7 @@ def forward(self, x): metrics.reset() m = Model().eval() if eval_mode else Model() self.common(m, (x,)) - check_metrics_vec_kernel_count(8) + check_metrics_vec_kernel_count(6) @requires_vectorization @config.patch("cpp.enable_tiling_heuristics", False) @@ -4816,9 +4834,12 @@ def forward(self, q, k, v): from torch.nn.attention import sdpa_kernel, SDPBackend context = contextlib.nullcontext if not is_inference else torch.no_grad - with config.patch( - {"fallback_random": True} - ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH): + with ( + config.patch({"fallback_random": True}), + torch.cpu.amp.autocast(), + context(), + sdpa_kernel(SDPBackend.MATH), + ): torch.manual_seed(0) eager = mod(*inputs) torch.manual_seed(0) @@ -5346,6 +5367,44 @@ def test_vector_norm_compile(self): res = compiled_vector_norm(x, ord=2, dim=[], keepdim=False, dtype=None) self.assertEqual(ref, res) + def test_fractional_max_pool2d_3d_input(self): + """Test for https://github.com/pytorch/pytorch/issues/156682 - 3D input causing assertion error""" + + # Test various 3D input shapes to ensure the compilation crash is fixed + test_shapes = [ + (1, 8, 8), # Original failing case + (3, 16, 16), # Different channel count + (2, 12, 10), # Non-square input + (5, 20, 20), # Larger input + ] + + for shape in test_shapes: + with self.subTest(shape=shape): + torch.manual_seed(42) + x = torch.randn(shape) + + # Generate explicit samples to ensure deterministic, correct results + n_batch = 1 if x.dim() == 3 else x.size(0) + torch.manual_seed(42) + samples = torch.rand( + n_batch, x.size(-3), 2, dtype=x.dtype, device=x.device + ) + + def fn(x, samples): + return F.fractional_max_pool2d( + x, kernel_size=3, output_size=(4, 4), _random_samples=samples + ) + + # Test that eager mode works + expected = fn(x, samples) + + # Test that compiled mode works (was failing with AssertionError before fix) + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(x, samples) + + # Verify correctness with explicit samples (should match exactly) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index d168a161be9514..e23a8285088991 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -45,6 +45,7 @@ check_model = test_torchinductor.check_model set_num_threads = test_cpu_repro.set_num_threads +run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code aten = torch.ops.aten @@ -649,8 +650,9 @@ def forward(self, mul_239, view_425, add_184): view_425 = torch.randn(flatten_BS, in_features) add_184 = torch.randn(batch_size, img_size_0, img_size_1, in_features) mod = M(bias=bias).eval() - with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast( - enabled=dtype == torch.bfloat16 + with ( + verify(dtype) as (atol, rtol), + torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), ): self.common( mod, @@ -1352,10 +1354,10 @@ def forward(self, x): if dtype == torch.bfloat16: atol, rtol = 5e-2, 5e-2 - with patch.object( - select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol) - ), torch.no_grad(), torch.autocast( - "cpu", enabled=(dtype == torch.bfloat16), dtype=dtype + with ( + patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)), + torch.no_grad(), + torch.autocast("cpu", enabled=(dtype == torch.bfloat16), dtype=dtype), ): ref_res = ref_quantized_mod(input) cfn = torch.compile(ref_quantized_mod) @@ -1434,6 +1436,90 @@ def forward(self, x, scale): vec_amx = VecAMX() self._check_amx_counter(vec_amx) + @inductor_config.patch({"freezing": True, "cpp.enable_concat_linear": True}) + @patches + @torch.no_grad + @dtypes(torch.bfloat16) + @parametrize( + "batch_size", + ( + 1, + 32, + ), + ) + @parametrize( + "mid_dim", + ( + 1, + 8, + ), + ) + @parametrize("in_features", (128,)) + @parametrize("out_features", (64,)) + def test_int8_woq_mm_concat( + self, dtype, batch_size, mid_dim, in_features, out_features + ): + def _convert_weight_to_int8pack(w): + scale, zp = _calculate_dynamic_per_channel_qparams( + w.to(torch.float), torch.int8 + ) + scale = torch.from_numpy(scale) + zp = torch.from_numpy(zp) + w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel( + input=w, + scales=scale, + zero_points=zp, + axis=0, + quant_min=-128, + quant_max=127, + dtype=torch.int8, + ) + return w_int8, scale.to(torch.bfloat16) + + class M(torch.nn.Module): + def __init__(self, w1, w2, w3): + super().__init__() + self.w1 = torch.nn.Parameter(w1, requires_grad=False) + self.w2 = torch.nn.Parameter(w2, requires_grad=False) + self.w3 = torch.nn.Parameter(w3, requires_grad=False) + + def forward(self, x, scale1, scale2, scale3): + # Ref: _linear_fp_act_int8_weight_impl in torchao/dtypes/uintx/plain_layout.py + y1 = ( + torch.mm(x.reshape(-1, x.shape[-1]), self.w1.t().to(x.dtype)) + * scale1 + ) + y2 = ( + torch.mm(x.reshape(-1, x.shape[-1]), self.w2.t().to(x.dtype)) + * scale2 + ) + y3 = ( + torch.mm(x.reshape(-1, x.shape[-1]), self.w3.t().to(x.dtype)) + * scale3 + ) + return ( + y1.reshape(*x.shape[:-1], y1.shape[-1]), + y2.reshape(*x.shape[:-1], y2.shape[-1]), + y3.reshape(*x.shape[:-1], y3.shape[-1]), + ) + + counters.clear() + # Currently, the corresponding torch.fx pattern only supports 3D x + # Add 2D X case once the corresponding pattern-matcher pattern is added + x = torch.rand((batch_size, mid_dim, in_features), dtype=dtype) + w1 = torch.rand((out_features, in_features), dtype=dtype) + w2 = torch.rand((out_features, in_features), dtype=dtype) + w3 = torch.rand((out_features, in_features), dtype=dtype) + w1_int8pack, w1_scales = _convert_weight_to_int8pack(w1) + w2_int8pack, w2_scales = _convert_weight_to_int8pack(w2) + w3_int8pack, w3_scales = _convert_weight_to_int8pack(w3) + mod = M(w1_int8pack, w2_int8pack, w3_int8pack).eval() + self.common(mod, (x, w1_scales, w2_scales, w3_scales)) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + if batch_size * mid_dim >= 16: + vec_amx = VecAMX() + self._check_amx_counter(vec_amx) + @unittest.skipIf( not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" ) @@ -1552,7 +1638,7 @@ def forward(self, x): return y.reshape(*x_shape[:-1], out_features) counters.clear() - seq_len = 8 + seq_len = 4 x = torch.rand((batch_size, seq_len, in_features), dtype=dtype) mod = M(in_features, out_features, group_size).eval() self.common(mod, (x,), reference_in_float=False) @@ -1570,7 +1656,122 @@ def forward(self, x): @patches @torch.no_grad @dtypes(torch.bfloat16) - @parametrize("batch_size", (4, 6)) + @parametrize("batch_size", (64,)) + @parametrize("in_features", (14336,)) + @parametrize("out_features", (96,)) + @parametrize("group_size", (128,)) + @set_num_threads(1) + def test_int4_woq_mm_amx_Nc_larger_than_one( + self, dtype, batch_size, in_features, out_features, group_size + ): + """ + Note: + `torch._weight_int4pack_mm_for_cpu` computes with float32, while the AMX-based GEMM + template computes with bfloat16. So, the difference of computation results may be big. + But we need `_weight_int4pack_mm_for_cpu` for its pattern. + Therefore, we define module M1 for its pattern and parameters and define module M2 for + the reference computation. M2's forward function gets the dequantized and unpacked weight + in bfloat16 then computes GEMM with bfloat16. + Besides, we need to skip the VERIFY patch and cannot use self.common for testing. + """ + + class M1(torch.nn.Module): + def __init__(self, K, N, group_size): + super().__init__() + self.linear_weight = torch.randint( + 0, 255, (N, K // 2), dtype=torch.uint8 + ) + self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype) + self.group_size = group_size + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + y = torch._weight_int4pack_mm_for_cpu( + x, self.linear_weight, self.group_size, self.qscale_and_zeros + ) + return y.reshape(*x_shape[:-1], out_features) + + class M2(torch.nn.Module): + def __init__(self, mod: M1): + super().__init__() + self.mod = mod + + def forward(self, x): + x_eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype) + dq_w = self.mod(x_eye).T.contiguous() + return torch.nn.functional.linear(x, dq_w) + + counters.clear() + seq_len = 8 + x = torch.rand((batch_size, seq_len, in_features), dtype=dtype) + mod = M1(in_features, out_features, group_size).eval() + mod2 = M2(mod) + # Skip VERIFY during torch.compile and don't use self.common. See explanation above. + with patch.object(select_algorithm, "VERIFY", None): + m = torch.compile(mod) + y_ref = mod2(x) + y = m(x) + self.assertEqual( + y, + y_ref, + atol=1e-2, + rtol=1e-2, + ) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) + @inductor_config.patch({"freezing": True}) + @inductor_config.patch({"cpp.use_small_dequant_buffer": True}) + @patches + @torch.no_grad + @dtypes(torch.bfloat16) + @parametrize("batch_size", (16,)) + @parametrize("in_features", (14336,)) + @parametrize("out_features", (96,)) + @parametrize("group_size", (128,)) + @set_num_threads(1) + def test_int4_woq_mm_with_small_buffer_config( + self, dtype, batch_size, in_features, out_features, group_size + ): + class M1(torch.nn.Module): + def __init__(self, K, N, group_size): + super().__init__() + self.linear_weight = torch.randint( + 0, 255, (N, K // 2), dtype=torch.uint8 + ) + self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype) + self.group_size = group_size + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + y = torch._weight_int4pack_mm_for_cpu( + x, self.linear_weight, self.group_size, self.qscale_and_zeros + ) + return y.reshape(*x_shape[:-1], out_features) + + counters.clear() + seq_len = 1 + x = torch.rand((batch_size, seq_len, in_features), dtype=dtype) + mod = M1(in_features, out_features, group_size).eval() + with patch.object(select_algorithm, "VERIFY", None): + m = torch.compile(mod) + _, code = run_and_get_cpp_code(m, x) + kr = 32 # only kr=32 supported in woq int4 amx kernel + _target_code_check = f"constexpr int64_t Kc_blocks = {group_size // kr};" + torch._C.FileCheck().check(_target_code_check).run(code) + + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @dtypes(torch.bfloat16) + @parametrize("batch_size", (1, 4, 6)) @parametrize("in_features", (128, 1024)) @parametrize("out_features", (128, 1024)) @parametrize("group_size", (32, 64, 128)) @@ -1633,6 +1834,94 @@ def forward(self, x): ) self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @unittest.skipIf( + not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" + ) + @inductor_config.patch({"freezing": True}) + @inductor_config.patch({"cpp.enable_concat_linear": True}) + @patches + @torch.no_grad + @dtypes(torch.bfloat16) + @parametrize("batch_size", (4,)) + @parametrize("in_features", (256,)) + @parametrize("out_features", ((512, 256, 256), (512, 512))) + @parametrize("group_size", (32, 128)) + def test_int4_concat_woq_mm( + self, dtype, batch_size, in_features, out_features, group_size + ): + class M1(torch.nn.Module): + def __init__(self, K, out_features, group_size): + super().__init__() + self.linear_weight = [ + torch.randint(0, 255, (N, K // 2), dtype=torch.uint8) + for N in out_features + ] + self.qscale_and_zeros = [ + torch.rand(K // group_size, N, 2, dtype=dtype) for N in out_features + ] + self.group_size = group_size + self.out_features = out_features + + def forward(self, x): + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + y = [ + torch._weight_int4pack_mm_for_cpu( + x, + self.linear_weight[idx], + self.group_size, + self.qscale_and_zeros[idx], + ) + for idx in range(len(self.out_features)) + ] + return [ + y[idx].reshape(*x_shape[:-1], self.out_features[idx]) + for idx in range(len(self.out_features)) + ] + + class M2(torch.nn.Module): + def __init__(self, mod: M1): + super().__init__() + self.mod = mod + + def forward(self, x): + x_eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype) + dq_w_list = [] + for idx in range(len(self.mod.out_features)): + x_shape = x_eye.shape + dq_w = torch._weight_int4pack_mm_for_cpu( + x_eye, + self.mod.linear_weight[idx], + self.mod.group_size, + self.mod.qscale_and_zeros[idx], + ) + dq_w_list.append( + dq_w.reshape( + *x_shape[:-1], self.mod.out_features[idx] + ).T.contiguous() + ) + + return [torch.nn.functional.linear(x, dq_w) for dq_w in dq_w_list] + + counters.clear() + seq_len = 8 + x = torch.rand((batch_size, seq_len, in_features), dtype=dtype) + mod = M1(in_features, out_features, group_size).eval() + mod2 = M2(mod) + # Skip VERIFY during torch.compile and don't use self.common. See explanation above. + with patch.object(select_algorithm, "VERIFY", None): + y_ref = mod2(x) + m = torch.compile(mod) + y = m(x) + self.assertEqual( + y, + y_ref, + atol=1e-2, + rtol=1e-2, + ) + # Only do once tuning, since the wgt has been concat + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @inductor_config.patch({"freezing": True}) @patches @torch.no_grad @@ -1669,7 +1958,7 @@ def test_quantized_linear_with_pointwise_binary( input = torch.randn(*B, in_features).to(dtype=torch.float32) other = torch.randn(*B, out_features).to(dtype=dtype) - # Avoid hiting qlinear inplace sum fusion + # Avoid hitting qlinear inplace sum fusion if input_3d: other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype) else: @@ -1686,7 +1975,7 @@ def __init__(self, bias, input_3d): def forward(self, x, other, other2): res = self.epilogue(self.linear(x) + other) - # Avoid hiting qlinear inplace sum fusion + # Avoid hitting qlinear inplace sum fusion if self.input_3d: other2 = other2.view(2, other2.size(0) // 2, other2.size(1)) else: @@ -1700,10 +1989,10 @@ def forward(self, x, other, other2): (input, other, other2), ) atol, rtol = 5e-2, 5e-2 - with patch.object( - select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol) - ), torch.no_grad(), torch.autocast( - "cpu", enabled=int8_mixed_bf16, dtype=torch.bfloat16 + with ( + patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)), + torch.no_grad(), + torch.autocast("cpu", enabled=int8_mixed_bf16, dtype=torch.bfloat16), ): ref_res = ref_quantized_mod(input, other, other2) cfn = torch.compile(ref_quantized_mod) @@ -1926,7 +2215,7 @@ def __init__(self, in_feature, out_feature, gemm_num): def forward(self, x): return [linear(x) for linear in self.linears] - # each linear has different num of out features, thus invaild grouped gemm + # each linear has different num of out features, thus invalid grouped gemm dtypes = [] if torch.ops.mkldnn._is_mkldnn_bf16_supported(): dtypes.append(torch.bfloat16) @@ -1938,9 +2227,11 @@ def forward(self, x): counters.clear() mod = M(in_features, out_features, gemm_num).eval() v = torch.randn(batch_size, in_features).to(dtype) - with verify(dtype) as (atol, rtol), torch.autocast( - device_type="cpu", dtype=dtype - ), torch.no_grad(): + with ( + verify(dtype) as (atol, rtol), + torch.autocast(device_type="cpu", dtype=dtype), + torch.no_grad(), + ): self.common(mod, (v,), atol=atol, rtol=rtol) # gemm_num independent template instead of grouped gemm template self.assertEqual( @@ -1992,9 +2283,11 @@ def forward(self, x): mod = M(in_features, out_features, gemm_num).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype) - with verify(dtype) as (atol, rtol), torch.autocast( - device_type="cpu", dtype=dtype - ), torch.no_grad(): + with ( + verify(dtype) as (atol, rtol), + torch.autocast(device_type="cpu", dtype=dtype), + torch.no_grad(), + ): self.common(mod, (v,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1) @@ -2070,9 +2363,11 @@ def forward(self, x): mod = M(in_features, out_features, bias, epilogue).eval() B = (2, batch_size) if input_3d else (batch_size,) v = torch.randn(*B, in_features).to(dtype) - with verify(dtype) as (atol, rtol), torch.autocast( - device_type="cpu", dtype=dtype - ), torch.no_grad(): + with ( + verify(dtype) as (atol, rtol), + torch.autocast(device_type="cpu", dtype=dtype), + torch.no_grad(), + ): self.common(mod, (v,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1) if any(e != "none" for e in epilogue): @@ -2684,6 +2979,7 @@ def forward(self, x, other, noise): return self.epilogue(result) + noise counters.clear() + u = torch.randn(bs, 8, Mdim, Kdim).to(dtype=dtype) v = torch.randn(bs, 8, Kdim, Ndim).to(dtype=dtype) noise = torch.randn(bs * 8, Mdim, Ndim).to(dtype=dtype) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 6c285213af1277..74c5f8468840e3 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -649,9 +649,9 @@ def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr): kernel.run(inout1, in0, xnumel, stream=stream0) kernel.run(inout2, in0, xnumel, stream=stream0) - assert same( - inout1, inout2, tol=0.001, equal_nan=True - ), "failed autotune with inplace kernel" + assert same(inout1, inout2, tol=0.001, equal_nan=True), ( + "failed autotune with inplace kernel" + ) def test_sort_stride_issue(self): # This minified testcase comes from detectron2_maskrcnn_r_50_fpn @@ -842,9 +842,9 @@ def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor): ] for dec_inp in dec_inputs: - assert same_two_models( - mod, opt_mod, [enc_out, dec_inp], only_fwd=True - ), "Inductor with dynamic shapes failed" + assert same_two_models(mod, opt_mod, [enc_out, dec_inp], only_fwd=True), ( + "Inductor with dynamic shapes failed" + ) def test_issue97695_1input(self): def fn(arg3_1, relu, permute_1): @@ -1569,6 +1569,19 @@ def test_multi_output_layout_fallback(self): self.assertEqual(o1, o2) + def test_sorted_masks(self): + @torch.compile() + def foo(x, y): + return (x + y).sum(dim=1) + + x = torch.rand([255, 255], device="cuda") + y = torch.rand([255, 255], device="cuda") + + _, code = run_and_get_code(foo, x, y) + FileCheck().check("tl.load").check_same("r0_mask").check_same("xmask").run( + code[0] + ) + def test_cat_int8_one_kernel(self): @torch.compile() def cat(inps): @@ -1801,9 +1814,9 @@ def forward(self, x): m = ToyModel().to(device="cuda:0") input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0") - from torch._inductor.utils import fresh_inductor_cache + from torch._inductor.utils import fresh_cache - with fresh_inductor_cache(): + with fresh_cache(): cm = torch.compile(m, mode="max-autotune") out = cm(input_tensor) out2 = m(input_tensor) @@ -1842,7 +1855,7 @@ def test_triton_interpret(self): def foo(x): return x + 1 -# somehow gives different results.. still, check that it doesnt error +# somehow gives different results.. still, check that it doesn't error foo(torch.rand([256], device="cuda")) """ subprocess.run([sys.executable, "-c", script], check=True) @@ -1910,11 +1923,9 @@ def foo(inp): getitem_24, ] ) - getitem_17 = ( - getitem_18 - ) = ( - getitem_19 - ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None + getitem_17 = getitem_18 = getitem_19 = getitem_20 = getitem_21 = ( + getitem_22 + ) = getitem_23 = getitem_24 = None return cat_1 for mark_dynamic in [False, True]: @@ -2111,6 +2122,95 @@ def interpolate_chunked(x): out_compiled = torch.compile(interpolate_chunked)(x) self.assertEqual(out_eager, out_compiled) + def test_max_autotune_nograd(self): + """ + https://github.com/pytorch/pytorch/issues/155688 + Smallest repro for max-autotune not working with no_grad + Before adding __int__ function to torch.utils._sympy.functions.Identity, + running the max_autotune mode would raise an error: + TypeError: Expected a number but got Identity + """ + + class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self.linear_layers = nn.ModuleList( + [ + nn.Linear(4, 1, bias=True), + nn.Linear(5, 1, bias=True), + nn.Linear(6, 1, bias=True), + nn.Linear(7, 1, bias=True), + nn.Linear(8, 1, bias=True), + ] + ) + + def forward(self, x): + for layer in self.linear_layers: + x2 = layer(x) + x2 = F.relu(x2) + x = torch.cat((x, x2), dim=1) + + return x + + model = ToyModel().to("cuda") + input_tensor = torch.randn((2, 4)).to("cuda") + + compile_default = torch.compile(model, mode="default") + compile_max_autotune = torch.compile(model, mode="max-autotune") + + with torch.no_grad(): + default_output = compile_default(input_tensor) + max_autotune_output = compile_max_autotune(input_tensor) + + self.assertEqual(default_output, max_autotune_output) + + def test_adaptive_avg_pool3d_issue_157248(self): + """Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) + self.adaptive_pool = torch.nn.AdaptiveAvgPool3d((4, 4, 4)) + + def forward(self, x): + x = self.conv(x) + # This specific unsqueeze position was problematic due to zero strides + x = x.unsqueeze(1) + x = self.adaptive_pool(x) + return x + + model = Model().cuda() + model.eval() + test_cases = [ + (1, 3, 8, 8), + (2, 3, 16, 16), + (1, 3, 32, 32), + (1, 3, 15, 15), + (2, 3, 13, 13), + ] + + for batch, channels, h, w in test_cases: + with self.subTest(input_shape=(batch, channels, h, w)): + input_tensor = torch.randn(batch, channels, h, w, device="cuda") + + # Test eager mode + with torch.no_grad(): + eager_output = model(input_tensor) + + # Test compiled mode with inductor + compiled_model = torch.compile(model, backend="inductor") + with torch.no_grad(): + compiled_output = compiled_model(input_tensor) + + # They should be identical (or very close) + self.assertTrue( + torch.allclose(eager_output, compiled_output, rtol=1e-5, atol=1e-5), + f"Results differ for input shape {(batch, channels, h, w)}. " + f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}", + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 2054c9abb50d4c..970fe64a758d81 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,16 +1,14 @@ # Owner(s): ["module: inductor"] import ctypes -import unittest import torch -from torch._inductor import config from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache _SOURCE_CODE = r""" @@ -37,10 +35,9 @@ """ -@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup") class TestCUDACodeCache(InductorTestCase): def test_cuda_load(self): - with fresh_inductor_cache(): + with fresh_cache(): # Test both .o and .so compilation. ( object_file_path, @@ -50,8 +47,8 @@ def test_cuda_load(self): dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load( _SOURCE_CODE, "so" ) - self.assertNotEqual(source_code_path0, source_code_path1) - self.assertNotEqual(object_hash_key, so_hash_key) + self.assertEqual(source_code_path0, source_code_path1) + self.assertEqual(object_hash_key, so_hash_key) # Test load and call functions in .so. x = torch.rand(10).float().cuda() @@ -67,13 +64,13 @@ def test_cuda_load(self): torch.testing.assert_close(y, expected_y) def test_compilation_error(self): - with fresh_inductor_cache(): + with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") def test_async_compile(self): - with fresh_inductor_cache(): + with fresh_cache(): async_compile = AsyncCompile() compiled_res = async_compile.cuda(_SOURCE_CODE, "so") async_compile.wait(globals()) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index d6e10fec83f843..7819cee39a7304 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -332,7 +332,7 @@ def inp(): ).check(".add_(2)").run(captured_output[0]) self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - # mutation on inp doesnt hit cudagraphs + # mutation on inp doesn't hit cudagraphs self.assertEqual(len(self.get_manager().roots), 0) # mutation on parameters/buffers hits cudagraphs @@ -564,8 +564,8 @@ def foo2(x): del out # when I tried inducing separate recordings via graph break, - # the frame kept interferring by keeping outputs alive - # this isnt great by simulates the logic. + # the frame kept interfering by keeping outputs alive + # this isn't great by simulates the logic. from torch._dynamo.mutation_guard import GenerationTracker GenerationTracker.generation -= 1 @@ -575,7 +575,7 @@ def foo2(x): foo_opt(torch.ones([4, 4], device="cuda")) - # Two separate traces - one has a child, one doesnt + # Two separate traces - one has a child, one doesn't self.assertEqual(self.get_root_children(), [1, 0]) def test_execution_into_recording(self): @@ -1325,7 +1325,7 @@ def test_multiple_insert_removal_caching(self): torch._C._set_cached_tensors_enabled(False) def test_accumulate_grad(self): - # cudagraph trees shouldnt interfere with accumulation logic + # cudagraph trees shouldn't interfere with accumulation logic def compute_grad(grad_output, create_graph): x = torch.randn(5, 5, requires_grad=True, device="cuda") @@ -1366,7 +1366,7 @@ def foo(x): for _ in range(3): out = frozen(torch.rand([10, 10], device="cuda")) - # didnt do additional recordings + # didn't do additional recordings self.assertTrue(self.get_manager().new_graph_id().id == 2) def test_empty_cpu_tensor(self): @@ -2403,9 +2403,7 @@ def forward(self, x): "on cudagraph node None due to static input data pointer changed.", 1, exactly=True, - ).run( - captured_output[0] - ) + ).run(captured_output[0]) self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) @@ -2712,6 +2710,129 @@ def f(x, y): # 2 graph partitions lead to 2 cudagraph self.assertEqual(self.get_manager().new_graph_id().id, 2) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar1(self): + def f(x, y): + return x + y + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = (torch.ones(2, 2, device="cuda"), torch.ones((), device="cpu")) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar2(self): + def f(x, y, z): + return x + y, x + z + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = ( + torch.ones((), device="cpu"), + torch.ones(2, 2, device="cuda"), + torch.ones(2, 2, device="cuda"), + ) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar3(self): + def f(x, y, cpu_scalar_tensor): + z = x + y + z = z + cpu_scalar_tensor + return z + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = ( + torch.randn(2, 2, device="cuda"), + torch.randn(2, 2, device="cuda"), + torch.tensor(1, device="cpu"), + ) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar4(self): + # cpu_scalar_tensor is accessed by cpu_scalar2 which is + # added with a gpu tensor z. This test checks the cpu + # scalar tensors are still moved in this case. + def f(x, y, cpu_scalar_tensor): + cpu_scalar2 = cpu_scalar_tensor + 1 + z = x + y + z = z + cpu_scalar2 + return z + + compiled_f = torch.compile(f, mode="reduce-overhead") + + inputs = ( + torch.randn(2, 2, device="cuda"), + torch.randn(2, 2, device="cuda"), + torch.tensor(1, device="cpu"), + ) + for i in range(3): + if i == 0: + _, code = run_and_get_code(compiled_f, *inputs) + FileCheck().check_count(".copy_", 1, exactly=True).run(code[0]) + else: + compiled_f(*inputs) + self.assertEqual(compiled_f(*inputs), f(*inputs)) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + + @torch._inductor.config.patch("graph_partition", True) + # turn on input mutation support to avoid skipping cudagraph at dynamo level + @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) + def test_graph_partition_cpu_scalar_mutation(self): + # tests that input mutation on a cpu scalar tensor x is correctly + # handled when moving x to gpu at the beginning of the graph. + + @torch.compile(mode="reduce-overhead") + def foo(x, y): + return x.copy_(y) + + x = torch.tensor(1) + y = torch.tensor(2, device="cuda") + + for _ in range(3): + foo(x, y) + + self.assertEqual(x, torch.tensor(2, device="cpu")) + self.assertEqual(y, torch.tensor(2, device="cuda")) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar_device_put(self): + @torch.compile(mode="reduce-overhead") + def foo(x): + y = x.to("cuda") + z = y.to("cpu") + return z + + x = torch.tensor(1) + for _ in range(3): + foo(x) + + self.assertEqual(x, torch.tensor(1, device="cpu")) + @torch._inductor.config.patch("graph_partition", True) @torch._inductor.config.patch("triton.cudagraphs", False) def test_graph_partition_reduce_overhead_mode_effectiveness(self): @@ -3411,6 +3532,79 @@ def foobar(x, y): self.assertEqual(eager_out, compiled_out) self.assertEqual(self.get_manager().new_graph_id().id, 1) + def test_cudagraph_capture_sizes(self): + torch._inductor.config.triton.cudagraph_capture_sizes = (2, 5, 7) + + def f(x): + return x + 1 + + f = torch.compile(f, mode="reduce-overhead") + + def run(shape): + x = torch.randn((shape, 5), device="cuda") + torch._dynamo.mark_dynamic(x, 0) + for _ in range(3): + f(x) + + for i in range(1, 10): + run(i) + + self.assertEqual(self.get_manager().new_graph_id().id, 3) + + def test_cudagraph_capture_sizes1(self): + torch._inductor.config.triton.cudagraph_capture_sizes = ( + (2, 3), + (4, 5), + (6, 2), + (7, 3), + ) + + def f(x): + return x + 1 + + f = torch.compile(f, mode="reduce-overhead") + + def run(batch_size, seq_len, d): + x = torch.randn((batch_size, seq_len, d), device="cuda") + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(x, 1) + for _ in range(3): + f(x) + + for i in range(2, 10): + for j in range(2, 10): + run(i, j, 8) + + self.assertEqual(self.get_manager().new_graph_id().id, 4) + + def test_cudagraph_capture_sizes2(self): + torch._inductor.config.triton.cudagraph_capture_sizes = ( + (2, 3, 4), + (4, 4, 3), + (3, 4, 4), + (4, 2, 3), + ) + + def f(x): + return x + 1 + + f = torch.compile(f, mode="reduce-overhead") + + def run(batch_size, seq_len, d): + x = torch.randn((batch_size, seq_len, d), device="cuda") + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(x, 1) + torch._dynamo.mark_dynamic(x, 2) + for _ in range(3): + f(x) + + for i in range(2, 5): + for j in range(2, 5): + for k in range(2, 5): + run(i, j, k) + + self.assertEqual(self.get_manager().new_graph_id().id, 4) + class TestSAC(TestCase): def _make_observer_mode(self): class ObserverMode(TorchDispatchMode): @@ -3760,7 +3954,7 @@ def multi_fn(x, y, a, b): a = torch.randn(4, 4, device="cuda:1", requires_grad=True) b = torch.randn(4, 4, device="cuda:1", requires_grad=True) - # No errors. TODO - get graphs from logging, couldnt figure out how + # No errors. TODO - get graphs from logging, couldn't figure out how multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition") out = multi_fn_c(x, y, a, b) diff --git a/test/inductor/test_custom_post_grad_passes.py b/test/inductor/test_custom_post_grad_passes.py index 2994b4109e6009..653b94d6ce0525 100644 --- a/test/inductor/test_custom_post_grad_passes.py +++ b/test/inductor/test_custom_post_grad_passes.py @@ -8,12 +8,17 @@ import torch.fx as fx from torch._dynamo.utils import counters from torch._inductor import config -from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files +from torch._inductor.codegen.common import get_custom_backend_pass_for_device +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, + get_hash_for_files, +) from torch._inductor.lowering import lowerings as L from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CPU +from torch.testing._internal.inductor_utils import HAS_CPU, patch_inductor_backend @config.patch({"freezing": True}) @@ -264,6 +269,35 @@ def f(W, nested_seqs): inner_test() + def test_custom_backend_pass(self): + class CustomBackendPass(CustomGraphModulePass): + def __init__(self, existing_pass: CustomGraphModulePass = None): + super().__init__() + self.existing_pass = existing_pass + + def __call__(self, gm: fx.GraphModule) -> None: + if self.existing_pass: + self.existing_pass(gm) + + change_cos_pass(gm.graph) + + def uuid(self) -> bytes: + return get_hash_for_files((__file__,)) + + custom_backend_pass = CustomBackendPass( + get_custom_backend_pass_for_device("cpu") + ) + with patch_inductor_backend("cpu", custom_pass=custom_backend_pass): + + def g(x): + return x.sin().sin().sin() + + def f(x): + return x.cos().cos().cos() + + x = torch.randn(8, dtype=torch.float32) + torch.testing.assert_close(torch.compile(f)(x), g(x)) + if __name__ == "__main__": if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index bafce4fe0ef520..5861fa2856b335 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -13,7 +13,7 @@ from typing import Callable, Optional from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer -from torch._inductor.utils import clear_inductor_caches +from torch._inductor.utils import clear_caches from torch.export import Dim from torch.testing._internal.logging_utils import log_settings @@ -38,7 +38,7 @@ from torch._inductor.ir import FixedLayout from torch._inductor.select_algorithm import NoValidChoicesError from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( @@ -166,14 +166,14 @@ def setUp(self): os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1" super().setUp() finally: - os.environ[ - "INDUCTOR_TEST_DISABLE_FRESH_CACHE" - ] = old_disable_fresh_cache_envvar + os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = ( + old_disable_fresh_cache_envvar + ) torch.random.manual_seed(1234) def tearDown(self): super().tearDown() - clear_inductor_caches() + clear_caches() def run_evt_test(self, model, op, shape, num_fusions=1): M, N = shape @@ -234,7 +234,10 @@ def test_import_cutlass(self): self.assertTrue(try_import_cutlass()) - import cutlass # noqa: F401 + if config.is_fbcode(): + import python_cutlass + else: + import cutlass as python_cutlass # noqa: F401 import cutlass_library # noqa: F401 def test_cutlass_key(self): @@ -446,21 +449,26 @@ def test_max_autotune_cutlass_backend_regular_mm( Main test for mm. """ - class MyModel(torch.nn.Module): - def forward(self, a, b): - return a @ b - - model = MyModel().cuda() # M, N, K shapes = [ (128, 128, 16), (1024, 1024, 256), ] - shapes = shapes[0:1] if not dynamic else shapes + + # M, N, K + shapes = shapes if dynamic else shapes[0:1] + + class MyModel(torch.nn.Module): + def forward(self, a, b): + return a @ b + + model = MyModel().cuda() + inputs = [ (torch.randn(M, K).cuda().to(dtype), torch.randn(K, N).cuda().to(dtype)) for (M, N, K) in shapes ] + dynamic_shapes = ( { "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, @@ -470,23 +478,118 @@ def forward(self, a, b): else None ) - with config.patch( + with ( + config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_max_profiling_configs": 2, + } + ), + dynamo_config.patch({"error_on_recompile": dynamic}), + ): + expected = [model(*input) for input in inputs] + if use_aoti: + actual = AOTIRunnerUtil.run_multiple( + model, inputs, dynamic_shapes=dynamic_shapes + ) + else: + compiled_model = torch.compile(model, dynamic=True) + actual = [compiled_model(*input) for input in inputs] + + torch.testing.assert_close(actual, expected) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @parametrize("dynamic", (False, True)) + @parametrize("use_aoti", (False, True)) + @parametrize("dtype", (torch.float8_e4m3fn,)) + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_max_autotune_cutlass_backend_fp8_scaled_mm( + self, + dynamic: bool, + max_autotune_gemm_backends: str = "CUTLASS", + use_aoti: bool = False, + dtype: torch.dtype = torch.float16, + ): + """ + Main test for mm. + """ + + # M, N, K + shapes = [ + (128, 128, 16), + (1024, 1024, 256), + ] + + # M, N, K + shapes = shapes if dynamic else shapes[0:1] + + inputs = [] + for shape in shapes: + M, N, K = shape + output_dtype = torch.bfloat16 + device = "cuda" + + x = torch.randn(M, K, dtype=output_dtype, device=device) + w = torch.randn(N, K, dtype=output_dtype, device=device) + + # quantize weight (prior to inference) + w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype) + w_t_fp8 = w_fp8.t() + w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) + + # quantize input x + x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype) + + inputs.append((x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale)) + + class MyModel(torch.nn.Module): + def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale): + y = torch._scaled_mm( + x_fp8, + w_t_fp8, + x_inverse_scale, + w_inverse_scale, + None, + out_dtype=torch.bfloat16, + use_fast_accum=False, + ) + return y + + dynamic_shapes = ( { - "max_autotune": True, - "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "x_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + "x_inverse_scale": {0: Dim.DYNAMIC, 1: 1}, + "w_t_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + "w_inverse_scale": {0: 1, 1: Dim.DYNAMIC}, } - ), dynamo_config.patch({"error_on_recompile": dynamic}): + if dynamic + else None + ) + model = MyModel().cuda() + + with ( + config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_max_profiling_configs": 2, + "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet + "cuda.cutlass_tma_only": True, + } + ), + dynamo_config.patch({"error_on_recompile": dynamic}), + ): expected = [model(*input) for input in inputs] if use_aoti: actual = AOTIRunnerUtil.run_multiple( model, inputs, dynamic_shapes=dynamic_shapes ) else: - compiled_model = torch.compile(model, dynamic=dynamic) + compiled_model = torch.compile(model, dynamic=True) actual = [compiled_model(*input) for input in inputs] - torch.testing.assert_close(actual, expected) + torch.testing.assert_close(actual, expected, rtol=1e-2, atol=0.05) @unittest.skipIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @@ -524,7 +627,7 @@ def forward(self, x, a, b): ] for x_shape in x_shapes: torch._dynamo.reset() - clear_inductor_caches() + clear_caches() inputs = [ ( @@ -547,13 +650,16 @@ def forward(self, x, a, b): if dynamic else None ) - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, - } - ), dynamo_config.patch({"error_on_recompile": dynamic}): + with ( + config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_max_profiling_configs": 2, + } + ), + dynamo_config.patch({"error_on_recompile": dynamic}), + ): expected = [model(*input) for input in inputs] if use_aoti: actual = AOTIRunnerUtil.run_multiple( @@ -709,9 +815,9 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) Y = mm(a, b) actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"] - assert ( - actual_count == expected_fuse_count - ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}" + assert actual_count == expected_fuse_count, ( + f"Expected fuse count of {expected_fuse_count} but got {actual_count}" + ) torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) @unittest.skipIf(not SM90OrLater, "need sm_90") @@ -971,7 +1077,7 @@ def my_addmm(x, a, b, alpha, beta): def select_no_algorithm(*args, **kwargs): raise NoValidChoicesError - with fresh_inductor_cache(): + with fresh_cache(): with config.patch( { "max_autotune": True, @@ -998,9 +1104,9 @@ def select_no_algorithm(*args, **kwargs): choice_info = choice.info_dict() op_conf_name = choice_info.get("op_conf_name", "") assert isinstance(op_conf_name, str) - assert ( - "pingpong" not in op_conf_name - ), "All pingpong Kernels should have been filtered" + assert "pingpong" not in op_conf_name, ( + "All pingpong Kernels should have been filtered" + ) cuda_template_count += 1 assert cuda_template_count > 0, "No CUDATemplateCaller choices" @@ -1019,7 +1125,7 @@ def addmm(x, a, b, alpha, beta): def select_no_algorithm(*args, **kwargs): raise NoValidChoicesError - with fresh_inductor_cache(): + with fresh_cache(): with config.patch( { "max_autotune": True, @@ -1046,12 +1152,99 @@ def select_no_algorithm(*args, **kwargs): choice_info = choice.info_dict() op_conf_name = choice_info.get("op_conf_name", "") assert isinstance(op_conf_name, str) - assert ( - "pingpong" in op_conf_name - ), "Only pingpong Kernels should have been allowed" + assert "pingpong" in op_conf_name, ( + "Only pingpong Kernels should have been allowed" + ) cuda_template_count += 1 assert cuda_template_count > 0, "No CUDATemplateCaller choices" + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_cutlass_backend_fp8_scaled_mm_fast_accum_filtering( + self, + ): + float8_dtype = torch.float8_e4m3fn + # Only bf16 output type is supported for row-wise scaling, not fp32 + output_dtype: torch.dtype = torch.bfloat16 + device = "cuda" + M, K, N = 128, 128, 128 # Matmul Y = X [M, K] x W [N, K] + x = torch.randn(M, K, dtype=output_dtype, device=device) + w = torch.randn(N, K, dtype=output_dtype, device=device) + bias = None + # quantize weight (prior to inference) + w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype) + w_t_fp8 = w_fp8.t() + w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) + + # quantize input x + x_fp8, x_inverse_scale = _quantize_rowwise(x, float8_dtype) + + def linear( + x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias, use_fast_accum + ): + y = torch._scaled_mm( + x_fp8, + w_t_fp8, + x_inverse_scale, + w_inverse_scale, + bias, + out_dtype=output_dtype, + use_fast_accum=use_fast_accum, + ) + return y + + linear_compiled = torch.compile(linear, backend="inductor") + + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError + + def run_test(use_fast_accum): + with fresh_cache(): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + } + ): + with mock.patch( + "torch._inductor.kernel.mm.autotune_select_algorithm", + wraps=select_no_algorithm, + ) as sa: + with self.assertRaisesRegex( + InductorError, r".*NoValidChoicesError.*" + ): + linear_compiled( + x_fp8, + x_inverse_scale, + w_t_fp8, + w_inverse_scale, + bias, + use_fast_accum, + ) + + args, _ = sa.call_args + _, choices, _, _ = args + cuda_template_count = 0 + for choice in choices: + if isinstance(choice, CUDATemplateCaller): + choice_info = choice.info_dict() + op_conf_name = choice_info.get("op_conf_name", "") + assert isinstance(op_conf_name, str) + if use_fast_accum: + assert "fastaccum" in op_conf_name, ( + "Only fastaccum Kernels should have been allowed" + ) + else: + assert "fastaccum" not in op_conf_name, ( + "fastaccum Kernels should have been filtered" + ) + cuda_template_count += 1 + assert cuda_template_count > 0, "No CUDATemplateCaller choices" + + run_test(True) + run_test(False) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_shape_coverage_mm( @@ -1085,16 +1278,20 @@ def test_cutlass_backend_shape_coverage_mm( def select_no_algorithm(*args, **kwargs): raise NoValidChoicesError - with fresh_inductor_cache(), config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - } - ), mock.patch( - "torch._inductor.kernel.mm.autotune_select_algorithm", - wraps=select_no_algorithm, - ) as sa: + with ( + fresh_cache(), + config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + } + ), + mock.patch( + "torch._inductor.kernel.mm.autotune_select_algorithm", + wraps=select_no_algorithm, + ) as sa, + ): for input in inputs: A, B = input M, K = A.shape @@ -1143,17 +1340,21 @@ def test_cutlass_presets( def select_no_algorithm(*args, **kwargs): raise NoValidChoicesError - with fresh_inductor_cache(), config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_presets": presets, - } - ), mock.patch( - "torch._inductor.kernel.mm.autotune_select_algorithm", - wraps=select_no_algorithm, - ) as sa: + with ( + fresh_cache(), + config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + "cuda.cutlass_presets": presets, + } + ), + mock.patch( + "torch._inductor.kernel.mm.autotune_select_algorithm", + wraps=select_no_algorithm, + ) as sa, + ): with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): torch.compile(torch.mm)(A, B) @@ -1340,9 +1541,12 @@ def mm(a, b): "force_disable_caches": True, } ): - with log_settings("+inductor"), self.assertLogs( - logger="torch._inductor.codegen.cuda", level=logging.DEBUG - ) as test_log: + with ( + log_settings("+inductor"), + self.assertLogs( + logger="torch._inductor.codegen.cuda", level=logging.DEBUG + ) as test_log, + ): Y_compiled = torch.compile(mm, dynamic=False)(a, b) Y = mm(a, b) torch.testing.assert_close(Y_compiled, Y) @@ -1449,10 +1653,18 @@ def forward(self, B): @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - def test_compilation_time(self): + @parametrize("use_aoti", (False, True)) + def test_compilation_time(self, use_aoti): M = 1024 A = torch.randn(M, M).cuda().half() - B = torch.randn(M, M).cuda().half() + B = torch.randn(M, M).cuda().half().t() + + class MyModel(torch.nn.Module): + def forward(self, a, b): + return a @ b + + model = MyModel().cuda() + expected = model(A, B) start_time = time.time() with config.patch( @@ -1462,7 +1674,15 @@ def test_compilation_time(self): "cuda.cutlass_max_profiling_configs": 1, } ): - _ = torch.compile(torch.mm)(A, B) + if use_aoti: + actual = AOTIRunnerUtil.run( + model, + (A, B), + ) + else: + actual = torch.compile(model, fullgraph=True)(A, B) + + torch.testing.assert_close(actual, expected) self.assertTrue(time.time() - start_time < 50) @unittest.skipIf(not SM90OrLater, "need sm_90") @@ -1566,7 +1786,10 @@ def forward(self, a, b, extra_args): @unittest.skipIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops - def test_evt_multi_output(self, op): + @parametrize( + "dynamic", (False, True) + ) # To not drastically increase test time we only test dynamic on this test + def test_evt_multi_output(self, op, dynamic): class TestModel(torch.nn.Module): def forward(self, a, b, extra_args): acc = a @ b @@ -1577,18 +1800,24 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half() - extra_args = gen_args(op, (M, N)) - model = TestModel().cuda() + shapes = [(512, 512)] if not dynamic else [(1024, 64), (128, 256)] + for i, shape in enumerate(shapes): + M, N = shape + a = torch.ones(M, N).cuda().half() + b = torch.ones(N, N).cuda().half() + extra_args = gen_args(op, (M, N)) + model = TestModel().cuda() - result = torch.compile(model)(a, b, extra_args) - ref_result = model(a, b, extra_args) + result = torch.compile(model)(a, b, extra_args) + ref_result = model(a, b, extra_args) - self.assertEqual( - torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2 - ) - torch.testing.assert_close(result, ref_result) + self.assertEqual( + torch._dynamo.utils.counters["inductor"][ + "cuda_epilogue_fusion_counter" + ], + 2 * (i + 1), + ) + torch.testing.assert_close(result, ref_result) @unittest.skipIf(not SM90OrLater, "need sm_90") @use_evt_config @@ -1648,13 +1877,13 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): "shape", ( ( - 16, - 16, - 32, + 512, + 128, + 64, ), ), ) - @parametrize("has_bias", (False,)) + @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False,)) def test_fp8_rowwise_scaling( self, @@ -1720,13 +1949,13 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): "shape", ( ( - 16, - 16, - 32, + 512, + 128, + 64, ), ), ) - @parametrize("has_bias", (False,)) + @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False,)) def test_fp8_tensorwise_scaling( self, diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index baefa62878fa7c..9a8db3d2169dcc 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -4,6 +4,7 @@ import sympy import torch +import torch._inductor.config as config from torch._dynamo.test_case import TestCase from torch._inductor.codegen.cuda.cutlass_utils import ( torch_dtype_to_cutlass_type, @@ -26,10 +27,15 @@ from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( _render_argument_type, _trace, - CutlassTensor, trace, ) + if config.is_fbcode(): + import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 + else: + import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 + CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor + BIAS_CODE = """def example_epilogue(accum, C, aux, bias): F = accum + C + aux E = relu(F) + bias @@ -107,6 +113,7 @@ def __init__(self, name_to_buffer): self.name_to_buffer = name_to_buffer self.graph_inputs = dict() self.mutated_buffers = OrderedSet() + self.constants = dict() class TestCutlassEVT(TestCase): @@ -347,7 +354,9 @@ def test_example_tensor_creation(self): ) buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"} name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1} - result = create_example_tensors(buffer_renames, name_to_buffer) + result = create_example_tensors( + buffer_renames, name_to_buffer, lambda x: int(x) + ) self.assertEqual(result["acc"].shape, (3, 4, 1)) self.assertEqual(result["acc"].stride, (4, 1, 0)) self.assertEqual( @@ -370,7 +379,9 @@ def test_evt_argument_codegen(self): self.assertExpectedInline( _render_argument_type( - epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS) + epilogue_functor, + _create_mock_buffer_name_map(EXAMPLE_TENSORS), + lambda x: int(x), ), """\ { /* thread */ @@ -425,7 +436,9 @@ def fn(accum, bias): self.assertExpectedInline( _render_argument_type( - epilogue_functor, _create_mock_buffer_name_map(example_tensors) + epilogue_functor, + _create_mock_buffer_name_map(example_tensors), + lambda x: int(x), ), """\ { /* thread */ @@ -450,6 +463,7 @@ def test_evt_codegen(self): MockTileDescription(), EpilogueScheduleType.ScheduleAuto, _create_mock_buffer_name_map(EXAMPLE_TENSORS), + lambda x: x, # static shapes ) self.assertExpectedInline( code, diff --git a/test/inductor/test_debug_trace.py b/test/inductor/test_debug_trace.py index 145932a72a1653..7b999a1e692d72 100644 --- a/test/inductor/test_debug_trace.py +++ b/test/inductor/test_debug_trace.py @@ -10,7 +10,7 @@ import torch from torch._inductor import config, test_operators -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.logging_utils import multiple_logs_to_string @@ -44,7 +44,7 @@ def fn(a, b): "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion" ) - # TODO(aakhundov): make this work with fresh_inductor_cache + # TODO(aakhundov): make this work with fresh_cache # instead of force_disable_caches. currently, with the latter # enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in # the counters: so the cache is actually hit and the test fails. @@ -54,9 +54,12 @@ def fn(a, b): "force_disable_caches": True, } ): - with self.assertLogs( - logging.getLogger("torch._inductor.debug"), level=logging.WARNING - ) as cm, ctx(): + with ( + self.assertLogs( + logging.getLogger("torch._inductor.debug"), level=logging.WARNING + ) as cm, + ctx(), + ): fn(torch.randn(16, 16), torch.randn(16, 16)) m = None @@ -261,9 +264,13 @@ def forward(self, x): return self.relu(self.l(x)) # no failure - with self.assertLogs( - logging.getLogger("torch._inductor.debug"), level=logging.WARNING - ), fresh_inductor_cache(): + with ( + self.assertLogs( + logging.getLogger("torch._inductor.debug"), + level=logging.WARNING, + ), + fresh_cache(), + ): m = ToyModel().to(device=GPU_TYPE) m = torch.compile(m, mode="max-autotune") input_tensor = torch.randn(100).to(device=GPU_TYPE) diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index d21de3178cf1e0..dd67a1f806c731 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -278,11 +278,9 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): ) counters.clear() - # (1, 64, 32, False) vesrion fails - @unittest.skip @parametrize( "m,k,n, should_decompose", - [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, False)], + [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, True)], ) def test_decompose_mm_cpu(self, m, n, k, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 66bd548b7a987c..a3bbe177803082 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -29,8 +29,9 @@ def fw_pre_hook(mod, inp): mod.unsharded_weight.untyped_storage().resize_( mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight + with ( + torch.no_grad(), + torch.autograd._unsafe_preserve_version_counter(mod.unsharded_weight), ): torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) mod._parameters["weight"] = mod.unsharded_weight @@ -52,8 +53,9 @@ def bw_pre_hook(mod, gO): mod.unsharded_weight.untyped_storage().resize_( mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight + with ( + torch.no_grad(), + torch.autograd._unsafe_preserve_version_counter(mod.unsharded_weight), ): torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) mod._parameters["weight"] = mod.unsharded_weight @@ -338,7 +340,7 @@ def test_module_backward_hooks_eager(self): self.assertEqual(fw_cnt.op_count, 5) self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None self.assertEqual( - bw_cnt.op_count, 114 + bw_cnt.op_count, 111 ) # Number of ops in the Dynamo-produced graphs def test_module_backward_hooks_aot(self): diff --git a/test/inductor/test_external_callables.py b/test/inductor/test_external_callables.py index eadf00df50e03c..a8aab1c00d80b4 100644 --- a/test/inductor/test_external_callables.py +++ b/test/inductor/test_external_callables.py @@ -16,7 +16,7 @@ def forward(self, x): return torch.matmul(x, self.matrix) -# torch.add performs better than torch.mm and got choosed during tuning +# torch.add performs better than torch.mm and got chosen during tuning def matmul_cpu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: torch.add(a, b, out=out) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index bb04e970604898..4d14555800c8c4 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -14,6 +14,7 @@ from unittest.mock import patch import torch +import torch.nn as nn from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._inductor import metrics from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC @@ -30,6 +31,7 @@ BlockMask, create_block_mask, flex_attention, + flex_attention_hop, noop_mask, or_masks, ) @@ -2716,7 +2718,7 @@ def mask(b, h, q, kv): ) q, k, v = make_tensor2(), make_tensor2(), make_tensor2() - # Compile 2st version with q/k/v(seqlen=2048) and block_mask(seqlen=4096), + # Compile 2nd version with q/k/v(seqlen=2048) and block_mask(seqlen=4096), # The graph includes the BlockMask._adjust part. out = torch.compile(flex_attention, dynamic=True, fullgraph=True)( q, k, v, block_mask=block_mask @@ -3455,9 +3457,9 @@ def forward(self, x, input_lengths): l = torch.randint(0, T, (B,), device=device) model(x, l) - assert ( - counter.frame_count == 1 - ), f"Expected 1 graph, but got {counter.frame_count} graphs" + assert counter.frame_count == 1, ( + f"Expected 1 graph, but got {counter.frame_count} graphs" + ) @supported_platform @skip_on_cpu @@ -3803,6 +3805,145 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 expected_joint_graph, ) + @supported_platform + def test_tensor_subclass_dispatch_order(self, device): + """Test that tensor subclasses get proper dispatch priority over modes. + + This test verifies the fix that allows tensor subclasses' pyimpl to run before + FakeTensorMode/FunctionalTensorMode implementations, preventing issues + where subclasses that error on as_strided would fail in flex_attention. + """ + import torch.utils._pytree as pytree + from torch.utils._python_dispatch import return_and_correct_aliasing + + class AsStridedErrorTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + assert isinstance(elem, torch.Tensor) + return torch.Tensor._make_wrapper_subclass( + cls, + elem.shape, + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + ) + + def __init__(self, elem): + self.elem = elem + + def __repr__(self): + return f"AsStridedErrorTensor({self.elem})" + + def __tensor_flatten__(self): + return ["elem"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + elem = inner_tensors["elem"] + return AsStridedErrorTensor(elem) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Error if as_strided is called + if func is torch.ops.aten.as_strided.default: + raise RuntimeError("as_strided was called on AsStridedErrorTensor!") + + if kwargs is None: + kwargs = {} + args_elem = pytree.tree_map_only( + AsStridedErrorTensor, lambda x: x.elem, args + ) + kwargs_elem = pytree.tree_map_only( + AsStridedErrorTensor, lambda x: x.elem, kwargs + ) + + out = func(*args_elem, **kwargs_elem) + + def wrap_output(x): + if isinstance(x, torch.Tensor): + return AsStridedErrorTensor(x) + return x + + out_wrapped = pytree.tree_map(wrap_output, out) + return return_and_correct_aliasing(func, args, kwargs, out_wrapped) + + from torch._higher_order_ops.flex_attention import ( + flex_attention as flex_attention_hop, + ) + + @flex_attention_hop.py_impl(AsStridedErrorTensor) + def flex_attention_as_strided_error_tensor( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers=(), + mask_mod_other_buffers=(), + ): + inner_q, inner_k, inner_v = query.elem, key.elem, value.elem + out, lse = flex_attention_hop( + inner_q, + inner_k, + inner_v, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + return AsStridedErrorTensor(out), AsStridedErrorTensor(lse) + + # Test setup + B, H, S, D = 2, 1, 128, 16 + dtype = torch.float32 + + # Create regular tensors + query_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) + key_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) + value_elem = torch.randn(B, H, S, D, device=device, dtype=dtype) + + # Test 1: Verify as_strided raises error when called directly on AsStridedErrorTensor + test_tensor = AsStridedErrorTensor(query_elem) + with self.assertRaisesRegex( + RuntimeError, "as_strided was called on AsStridedErrorTensor!" + ): + torch.as_strided( + test_tensor, size=(B, H, S, D), stride=test_tensor.stride() + ) + + # Test 2: Run flex_attention with normal tensors first + compiled_fn = torch.compile(flex_attention, backend="aot_eager", fullgraph=True) + normal_out, normal_lse = compiled_fn( + query_elem, key_elem, value_elem, return_lse=True + ) + + # Test 3: Wrap in our subclass + query = AsStridedErrorTensor(query_elem) + key = AsStridedErrorTensor(key_elem) + value = AsStridedErrorTensor(value_elem) + + # This should NOT error with as_strided after the fix + # Before the fix, it would error because FakeTensorMode would directly + # call flex_attention_fake_impl which uses as_strided + out, lse = compiled_fn(query, key, value, return_lse=True) + # Verify we got valid output + self.assertIsInstance(out, AsStridedErrorTensor) + self.assertIsInstance(lse, AsStridedErrorTensor) + self.assertEqual(out.shape, (B, H, S, D)) + self.assertEqual(lse.shape, (B, H, S)) + + # Test 4: Compare outputs between normal tensors and subclassed tensors + torch.testing.assert_close(out.elem, normal_out, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(lse.elem, normal_lse, rtol=1e-5, atol=1e-5) + @supported_platform @skip_on_cuda def test_cpu_error_message_return_lse(self, device): @@ -3842,6 +3983,132 @@ def forward(self, q, k, v, block_mask): attn_output = mod(q, k, v, mask) self.assertEqual(attn_output.device, torch.device("cuda:1")) + @supported_platform + @skip_on_cpu + @common_utils.parametrize( + "ops_to_save", + [ + [ + torch.ops.aten.mm.default, + ], + [ + flex_attention_hop, + ], + [torch.ops.aten.mm.default, flex_attention_hop], + ], + ) + def test_selective_ac(self, device, ops_to_save): + class FlexAttentionModule(nn.Module): + def __init__(self, hidden_size, num_heads): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + # In-projections (query, key, value) + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.k_proj = nn.Linear(hidden_size, hidden_size) + self.v_proj = nn.Linear(hidden_size, hidden_size) + + # Out-projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + batch_size, seq_len, _ = x.size() + + # Project queries, keys, and values + q = ( + self.q_proj(x) + .view(batch_size, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + k = ( + self.k_proj(x) + .view(batch_size, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + v = ( + self.v_proj(x) + .view(batch_size, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + # Apply flex attention + attn_output = flex_attention( + q, + k, + v, + ) + + # Reshape output + attn_output = ( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, seq_len, self.hidden_size) + ) + + # Out projection + output = self.out_proj(attn_output) + + return output + + from torch.utils.checkpoint import ( + checkpoint, + create_selective_checkpoint_contexts, + ) + + context_fn = functools.partial( + create_selective_checkpoint_contexts, ops_to_save + ) + + # Define a model that uses FlexAttention with selective activation checkpointing + class SacModule(nn.Module): + def __init__(self, hidden_size, num_heads, context_fn): + super().__init__() + self.flex_attn = FlexAttentionModule(hidden_size, num_heads) + self.context_fn = context_fn + + def forward(self, x): + def flex_attn_fn(x): + return self.flex_attn(x) + + output = checkpoint( + flex_attn_fn, + x, + use_reentrant=False, + context_fn=self.context_fn, + ) + + return output + + flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to( + "cuda", dtype=torch.bfloat16 + ) + x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16) + + # Run without compilation + output_module = flex_module(x) + compiled_module = torch.compile(flex_module) + output_compiled = compiled_module(x) + + torch.testing.assert_close(output_module, output_compiled, rtol=1e-2, atol=1e-2) + + # Calculate gradients and compare them + x.requires_grad_(True) + output_module = flex_module(x) + output_compiled = compiled_module(x) + grad_output = torch.ones_like(output_module) + + grad_module = torch.autograd.grad( + outputs=output_module, inputs=x, grad_outputs=grad_output, retain_graph=True + )[0] + + grad_compiled = torch.autograd.grad( + outputs=output_compiled, inputs=x, grad_outputs=grad_output + )[0] + + torch.testing.assert_close(grad_module, grad_compiled, rtol=1e-2, atol=1e-2) + @supported_platform @skip_on_cpu def test_validate_small_embedding_size_error_message(self, device): @@ -3895,20 +4162,20 @@ def make_tensor(): **keyword_args, ) assert kernel_code is not None, "Failed to retrieve compiled kernel code" - assert ( - "num_consumer_groups" in kernel_code[0] - ), "num_consumer_groups missing in kernel definition" - assert ( - "num_buffers_warp_spec" in kernel_code[0] - ), "num_buffers_warp_spec missing in kernel definition" + assert "num_consumer_groups" in kernel_code[0], ( + "num_consumer_groups missing in kernel definition" + ) + assert "num_buffers_warp_spec" in kernel_code[0], ( + "num_buffers_warp_spec missing in kernel definition" + ) # Validate correctness C1 = flex_compiled(q, k, v) C2 = flex_attention(q, k, v) - assert torch.allclose( - C1, C2, atol=1e-2, rtol=1e-2 - ), "Warp specialized kernel result differs from reference" + assert torch.allclose(C1, C2, atol=1e-2, rtol=1e-2), ( + "Warp specialized kernel result differs from reference" + ) @supported_platform @skip_on_cpu diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index c82e75ac98a380..3b4905fc356168 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -30,7 +30,13 @@ Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) -torch.set_float32_matmul_precision("high") +# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. +# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the +# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. +if torch.version.hip: + torch.set_float32_matmul_precision("highest") +else: + torch.set_float32_matmul_precision("high") index = torch.ops.aten.index Tensor = torch.Tensor @@ -348,9 +354,9 @@ def run_test( block_mask: Optional[BlockMask] = None, device="cuda", ): - assert ( - score_mod is not None or block_mask is not None - ), "Must provide score_mod or block_mask" + assert score_mod is not None or block_mask is not None, ( + "Must provide score_mod or block_mask" + ) assert Q_H % KV_H == 0 if device == "cpu" and dtype is torch.float16: dtype = torch.float32 @@ -1814,9 +1820,9 @@ def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self, device): ) # Ensure no more re-compilation after the second automatic dynamic shape version. if i == 0: - self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) - else: self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) + else: + self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 4) @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @@ -1869,7 +1875,7 @@ def causal_mask(b, h, q, kv): # init 4 requests with different prefill length prefill_length = [5, 98, 47, 194] - querys, keys, values = [], [], [] + queries, keys, values = [], [], [] for seq_len in prefill_length: q = torch.randn( 1, @@ -1898,13 +1904,13 @@ def causal_mask(b, h, q, kv): dtype=dtype, requires_grad=False, ) - querys.append(q) + queries.append(q) keys.append(k) values.append(v) # get ground truth output ref_outs, golden_outs = [], [] - for q, k, v in zip(querys, keys, values): + for q, k, v in zip(queries, keys, values): q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) @@ -1972,7 +1978,7 @@ def causal_mask(b, h, q, kv): ) ) paged_out = compiled_sdpa( - torch.cat(querys, 0), k_cache, v_cache, block_mask=new_block_mask + torch.cat(queries, 0), k_cache, v_cache, block_mask=new_block_mask ) with torch.no_grad(): diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index d5557253af0486..50044b2c1943a4 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -40,7 +40,7 @@ def _fix_fp8_dtype_for_rocm( # with MI300 supported FP8 types if device is GPU: # e4m3fn -> e4m3fnuz # e5m2 -> e5m2fnuz - # Supports single, typle and list of dtypes + # Supports single, tuple and list of dtypes # Keeps the same test name for CUDA and ROCm # Also it allows to enable FP8 inductor tests for CPU if ( diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 23c71f7cbbabe5..4d52775ccbade6 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -977,6 +977,94 @@ def dot_prod_attention( self._check_common(dot_prod_attention, check_train=False, has_dropout=True) + def _test_sdpa_rewriter_21(self): + def dot_prod_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + scores = torch.matmul(query, key.permute(0, 1, 3, 2)) + scores += attn_mask + attn_weights = scores.float().softmax(dim=-1).type(value.dtype) + return attn_weights.matmul(value) + + tensor_shape = (4, 2, 16, 32) + attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + + def _test_sdpa_rewriter_22(self): + def dot_prod_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + scores = torch.matmul(query, key.permute(0, 1, 3, 2)) + scores += attn_mask + attn_weights = scores.float().softmax(dim=-1).type(value.dtype) + return attn_weights.matmul(value), key, value + + tensor_shape = (4, 2, 16, 32) + attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + + def _test_sdpa_rewriter_23(self): + def dot_prod_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + attn_mask = torch.full((1, 1, 1, 2), 0.0, device=query.device) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + scores = torch.matmul(query, key.permute(0, 1, 3, 2)) + scores += attn_mask + attn_weights = scores.float().softmax(dim=-1).type(value.dtype) + return attn_weights.matmul(value), key, value + + tensor_shape = (4, 2, 16, 32) + args = [ + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + torch.randn(tensor_shape, device=self.device), + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): @@ -1036,6 +1124,15 @@ class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_20_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20 ) + test_sdpa_rewriter_21_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_21 + ) + test_sdpa_rewriter_22_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_22 + ) + test_sdpa_rewriter_23_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 + ) class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests): use_static_shapes = False @@ -1093,6 +1190,15 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_20_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20 ) + test_sdpa_rewriter_21_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_21 + ) + test_sdpa_rewriter_22_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_22 + ) + test_sdpa_rewriter_23_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 + ) class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): use_static_shapes = False diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index cfd34fad50d7d5..02a1d59627c9f7 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -23,6 +23,10 @@ from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen from torch._inductor.select_algorithm import extern_kernels from torch._inductor.test_case import TestCase as InductorTestCase +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, @@ -39,6 +43,7 @@ scalar_asserts=False, nan_asserts=False, ) +@instantiate_parametrized_tests class FxirTestCase(InductorTestCase): device = GPU_TYPE @@ -182,6 +187,33 @@ def foo(x, y): args = [torch.randn(8, device=self.device) for _ in range(2)] self._compile_and_check(foo, args, expected_num_triton_kernels=1) + def test_cat_views(self): + """ + Test concatenation with multiple kernels writing to the same buffer. + """ + + def foo(x, y): + a = x - 2 + b = y.sum(0, keepdim=True) + c = torch.cat((a, b)).clone() + return a, b, c + + args = [torch.randn(8, device=self.device) for _ in range(2)] + (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2) + + def get_offset(node: torch.fx.Node) -> int: + (input_, shape, stride, offset) = node.args + assert isinstance(offset, int) + return offset + + # Check for 2 views, one of which is offset. + as_strided_nodes = list( + gm.graph.find_nodes(op="call_function", target=torch.as_strided) + ) + self.assertEqual(len(as_strided_nodes), 2) + num_offset_views = sum(get_offset(node) > 0 for node in as_strided_nodes) + self.assertEqual(num_offset_views, 1) + def test_cat_to_alloc(self): """ Test concatenation that's optimized out to an allocation. @@ -375,6 +407,32 @@ def test_debug(self, mock_output_code): output_code = f.read() self.assertIn("triton_kernel_wrapper_mutation", output_code) + @parametrize( + "const", + (1, 1.5), + ) + def test_export_const_placeholder(self, const): + """ + Test that we can compile a graph coming from torch.export with a constant input. + """ + + class TestModule(torch.nn.Module): + def forward(self, x, y): + return x - y + + args = (torch.randn(8, device=self.device), const) + mod = TestModule() + export_gm = torch.export.export(mod, args).module() + + def compile_module(*inps): + torch._inductor.compile(export_gm, inps) + + (inductor_gm,) = self._run_and_capture_graphs(compile_module, args) + result = inductor_gm(*args) + ref = mod(*args) + + self.assertTrue(same(ref, result)) + @torch._inductor.config.patch("graph_partition", True) def test_subgraph_raises(self): """ @@ -404,11 +462,48 @@ def foo(x, y): args = [torch.randn(5, device=device) for _ in range(2)] cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None) - with unittest.mock.patch.dict( - common.device_codegens, {device.type: cpp_backend} - ), self.assertRaisesRegex(BackendCompilerFailed, "Triton"): + with ( + unittest.mock.patch.dict( + common.device_codegens, {device.type: cpp_backend} + ), + self.assertRaisesRegex(BackendCompilerFailed, "Triton"), + ): self._compile_and_check(foo, args) + @parametrize("enable_tuning", (False, True)) + @parametrize("use_dynamic_shapes", (False, True)) + def test_autotune(self, use_dynamic_shapes: bool, enable_tuning: bool): + orig_run = torch._inductor.runtime.triton_heuristics.CachingAutotuner.run + called = False + + def run(*args, **kwargs): + nonlocal called + called = True + return orig_run(*args, **kwargs) + + args = [torch.randn(8, device=self.device) for _ in range(2)] + + with ( + config.patch("triton.autotune_at_compile_time", enable_tuning), + unittest.mock.patch.object( + torch._inductor.runtime.triton_heuristics.CachingAutotuner, "run", run + ), + ): + # Compile and check that the tuner was called. + self.assertFalse(called) + (gm,) = self._compile_and_check( + torch.mul, args, compile_kwargs={"dynamic": use_dynamic_shapes} + ) + self.assertEqual(called, enable_tuning) + + # Check for a symbolic output shape. + (empty_strided,) = gm.graph.find_nodes( + op="call_function", target=torch.empty_strided + ) + (shape, stride) = empty_strided.args + output_is_symbolic = any(isinstance(dim, torch.SymInt) for dim in shape) + self.assertEqual(output_is_symbolic, use_dynamic_shapes) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_gpu_cpp_wrapper.py b/test/inductor/test_gpu_cpp_wrapper.py index 14f6c52f1a613b..24163ece1f919e 100644 --- a/test/inductor/test_gpu_cpp_wrapper.py +++ b/test/inductor/test_gpu_cpp_wrapper.py @@ -125,7 +125,7 @@ def make_test_case( assert callable(func), "not a callable" func = slowTest(func) if slow else func - @config.patch(cpp_wrapper=True, search_autotune_cache=False) + @config.patch(cpp_wrapper=True) def fn(self): tests.setUpClass() tests.setUp() @@ -302,12 +302,12 @@ class BaseTest(NamedTuple): skip_list = ["test_addmm", "test_linear_relu"] # need to skip instead of omit, otherwise fbcode ci can be flaky for test_name in skip_list: - test_failures_gpu_wrapper[ - f"{test_name}_cuda" - ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True) - test_failures_gpu_wrapper[ - f"{test_name}_gpu_dynamic_shapes" - ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True) + test_failures_gpu_wrapper[f"{test_name}_cuda"] = ( + test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True) + ) + test_failures_gpu_wrapper[f"{test_name}_gpu_dynamic_shapes"] = ( + test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True) + ) test_torchinductor.copy_tests( GpuWrapperTemplate, TestGpuWrapper, "gpu_wrapper", test_failures_gpu_wrapper diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 58a356c63e81a4..516120da298655 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -484,7 +484,7 @@ def test_pointwise_op_fusion_post_grad(self): self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1) self.assertEqual(counters["inductor"]["batch_aten_relu"], 1) self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) - self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2) ref.sum().backward() res.sum().backward() self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) @@ -1214,7 +1214,7 @@ def test_find_independent_subset_greedy_fuse(self): ) self.assertEqual(next(i), [lookup[n] for n in ["n2", "n3", "n5"]]) - # fuse n2 and n3 which makes n4 now dependant on n1. + # fuse n2 and n3 which makes n4 now dependent on n1. args = tuple(lookup[n] for n in ["n0", "n1"]) fused = g.create_node("placeholder", "target", name="n2+n3", args=args) lookup["n2"].replace_all_uses_with(fused) diff --git a/test/inductor/test_helion_kernels.py b/test/inductor/test_helion_kernels.py new file mode 100644 index 00000000000000..7690c13f63bdab --- /dev/null +++ b/test/inductor/test_helion_kernels.py @@ -0,0 +1,69 @@ +# Owner(s): ["module: inductor"] +import torch +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_HELION, requires_helion + + +if HAS_HELION: + import helion + import helion.language as hl + + +class HelionTests(TestCase): + @requires_helion() + def test_add_kernel(self): + @helion.kernel(config=helion.Config(block_sizes=[1, 2])) + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # match pytorch broadcasting rules + x, y = torch.broadcast_tensors(x, y) + out = torch.empty( + x.shape, + # match type promotion of torch.add + dtype=torch.promote_types(x.dtype, y.dtype), + device=x.device, + ) + # tile will be a tuple of blocks + for tile in hl.tile(out.size()): + out[tile] = x[tile] + y[tile] + return out + + def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return add(x, y) + + x = torch.randn(4, 8, device=GPU_TYPE, dtype=torch.float16) + y = torch.randn(4, 8, device=GPU_TYPE, dtype=torch.float16) + + out = add(x, y) + compiled_add = torch.compile(f, fullgraph=True, backend="inductor") + compiled_out = compiled_add(x, y) + + self.assertEqual(out, x + y) + self.assertEqual(compiled_out, x + y) + + @requires_helion() + def test_softmax_view_reshape(self): + @helion.kernel(config={"block_size": 1}) + def softmax(x: torch.Tensor) -> torch.Tensor: + n, _m = x.size() + out = torch.empty_like(x) + for tile_n in hl.tile(n): + values = x[tile_n, :] + amax = torch.amax(values, dim=1).view(tile_n, 1) + exp = torch.exp(values - amax) + sum_exp = torch.reshape(torch.sum(exp, dim=1), [tile_n, 1]) + out[tile_n, :] = exp / sum_exp + return out + + x = torch.randn([1024, 1024], device=GPU_TYPE, dtype=torch.float16) + result = softmax(x) + self.assertEqual( + result, torch.nn.functional.softmax(x, dim=1), rtol=1e-2, atol=1e-1 + ) + + +instantiate_parametrized_tests(HelionTests) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 973335fdba27d2..3359b237904fe6 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -55,7 +55,7 @@ def test_indexing_simplification(self): sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2), ) - # all the modular indexing should be removed when the body cant be larger than the modulus + # all the modular indexing should be removed when the body can't be larger than the modulus var_ranges[r3] = 2 self.assertEqual( sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3 @@ -247,7 +247,7 @@ def f(x): x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device=GPU_TYPE) triton_code = run_and_get_triton_code(f, x) - # Make sure the 2 load uses simpified indexing rather than something like + # Make sure the 2 load uses simplified indexing rather than something like # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),")) if DO_PERF_TEST: diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 377a87a673a0e4..45045a3c41893c 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -17,7 +17,12 @@ from torch._inductor.utils import override_lowering, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM80OrLater, tf32_on_and_off -from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm, skipIfXpu +from torch.testing._internal.common_utils import ( + IS_FBCODE, + skipIfRocm, + skipIfXpu, + TEST_WITH_SLOW_GRADCHECK, +) # Make the helper files in test/ importable @@ -510,7 +515,7 @@ def foo(mod, x): out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) # we unfuse the conv bias, but it should only have one constant in the kernel - if self.device == GPU_TYPE: + if self.device == "cuda": FileCheck().check_not(".run(").check("conv").check(".run(").check_same( "frozen_param" ).check_not("frozen_param").check_next("return").run(code[0]) @@ -555,7 +560,7 @@ def foo(mod, x): out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) # we unfuse the conv bias, but it should only have one constant in the kernel - if self.device == GPU_TYPE: + if self.device == "cuda": FileCheck().check_not(".run(").check("conv").check(".run(").check_same( "frozen_param" ).check_not("frozen_param").check_next("return").run(code[0]) @@ -785,6 +790,10 @@ def foo(mod, inp): @skipIfXpu @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") + @unittest.skipIf( + TEST_WITH_SLOW_GRADCHECK, + "Failing in slow gradcheck on cuda12.8, see https://github.com/pytorch/pytorch/pull/156731 for example", + ) def test_cpp_wrapper(self): mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device) @@ -879,7 +888,7 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): # in the joint graph rather than torch.ops.aten.convolution.default. # Currently we only handle aten.convolution.default in layout # optimization. That's why the count may be 0 here for CPU. - if self.device == GPU_TYPE: + if self.device == "cuda": self.assertTrue(nconv == 1) def test_unequal_bias_horizontal_addmm_fusion(self): diff --git a/test/inductor/test_inductor_scheduler.py b/test/inductor/test_inductor_scheduler.py index 1be3b1c5b42539..b981d2849de140 100644 --- a/test/inductor/test_inductor_scheduler.py +++ b/test/inductor/test_inductor_scheduler.py @@ -5,7 +5,7 @@ import torch.utils.flop_counter from torch._dynamo.utils import counters from torch._inductor.ir import FixedLayout -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.testing._internal.common_cuda import SM70OrLater from torch.testing._internal.common_device_type import ( dtypes, @@ -77,7 +77,7 @@ def test_disable_get_estimated_runtime_logging(self, device, dtype): for op, example_inputs, kwargs in tc: comp = torch.compile(op) torch._dynamo.reset() - with fresh_inductor_cache(): + with fresh_cache(): comp(*example_inputs, **kwargs) self.assertEqual(metrics.num_bytes_accessed, 0) self.assertEqual(any(m[1] for m in metrics.node_runtimes), False) @@ -108,7 +108,7 @@ def test_get_estimated_runtime_logging(self, device, dtype): comp = torch.compile(op) torch._dynamo.reset() - with fresh_inductor_cache(): + with fresh_cache(): comp(*example_inputs, **kwargs) self.assertEqual(enba, metrics.num_bytes_accessed) nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0) @@ -150,9 +150,9 @@ def test_flop_counter_op(self, device, dtype, options): torch._logging.set_logs(inductor_metrics=True) for op, example_inputs, kwargs in tc: comp = torch.compile(op, options=options) - # next two lines are required, otherwise the flops will be cached from pervious runs of this function. + # next two lines are required, otherwise the flops will be cached from previous runs of this function. torch._dynamo.reset() - with fresh_inductor_cache(): + with fresh_cache(): # actually run to set the counters comp(*example_inputs, **kwargs) with FlopCounterMode() as mode: diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index 0d4f72d4ffec6f..dd592f8c4e823e 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -413,6 +413,31 @@ def f(b): # Both list inputs failed to reinplace. So we should have emitted clones for them. self.assertEqual(post_grad_graphs.count("aten.clone"), 2) + def test_generalized_scatter(self): + # This is an integration test for the reinplacing pass. + def fn(x_1): + a = torch.ones([2, 3]) + c = torch.ones(2) + a[:, 0].copy_(c) + + d = a.clone() + e = torch.ops.aten.as_strided.default(d, [2], [3], 0) + f = e.clone() + + g = torch.zeros(2) + e.copy_(g) + + h = torch.zeros(2, 3) + h[:, 0].copy_(f) + + add_1 = d + h + return add_1 + + x = torch.randn(2, 3) + expected = fn(x) + result = torch.compile(fn, fullgraph=True, backend="inductor")(x) + self.assertEqual(result, expected) + @parametrize( "factory_op", [ diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 88d73e59b4c9a2..6e438cdeab9141 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -13,7 +13,7 @@ from torch._inductor import config from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache from torch.testing import FileCheck from torch.testing._internal.common_cuda import xfailIfSM89 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU @@ -152,7 +152,7 @@ def f(x): @unittest.skipIf( not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" ) - @fresh_inductor_cache() + @fresh_cache() def test_matmul_triton_kernel_benchmark(self): M = 12544 N = 256 @@ -170,7 +170,7 @@ def f(a, b): @config.patch( max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False ) - @fresh_inductor_cache() + @fresh_cache() def test_mm_triton_kernel_benchmark(self): M = 2048 N = 2432 @@ -199,7 +199,7 @@ def test_matmul_bandwidth_computation(self): def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr): Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's - inplace udpated, so when computing the bandwidth, we should count + inplace updated, so when computing the bandwidth, we should count the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is what this test asserts. """ diff --git a/test/inductor/test_layout_optim.py b/test/inductor/test_layout_optim.py index 52203caddab6f6..8962e6bb18b5f1 100644 --- a/test/inductor/test_layout_optim.py +++ b/test/inductor/test_layout_optim.py @@ -300,7 +300,7 @@ def test_nll_loss_backward(self): The CUDA implementation of aten.nll_loss2d_backward.default requires the self tensor (whose layout will be used to create grad_input) to be contiguous. Layout optimization may change the self tensor's layout - and cause failure. We fix that by adding layout constaints to the + and cause failure. We fix that by adding layout constraints to the fallback of aten.nll_loss2d_backward.default . """ diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index d68e377926199e..e2317486366403 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -5,6 +5,7 @@ import unittest import numpy as np +import sympy import torch from torch import nn @@ -16,13 +17,19 @@ from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize -from torch._inductor.utils import sympy_index_symbol +from torch._inductor.utils import run_and_get_code, sympy_index_symbol from torch._inductor.virtualized import ops, V +from torch.testing import FileCheck from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 -from torch.testing._internal.common_utils import skipIfRocm +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + skipIfRocm, +) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map -from torch.utils._sympy.functions import ModularIndexing +from torch.utils._sympy.functions import FloorDiv, ModularIndexing # set so that metrics appear @@ -41,9 +48,11 @@ class MockScheduler: def get_backend(cls, *args): return TritonScheduling(cls) + def can_buffer_be_removed_through_fusion(self, *args, **kwargs): + return False -@inductor_config.patch(loop_ordering_after_fusion=True) -class ImplDetailTest(TestCase): + +class MockSchedulerTest(TestCase): _exit_stack = None @classmethod @@ -61,6 +70,9 @@ def tearDownClass(cls): super().tearDownClass() cls._exit_stack.close() + +@inductor_config.patch(loop_ordering_after_fusion=True) +class ImplDetailTest(MockSchedulerTest): @staticmethod def _get_snode_body_sym_prefix(snode): body = snode._body @@ -435,9 +447,9 @@ def test_fp8_pattern_2(self): M, K = 4096, 4096 input_tensor = torch.randn( - M, K, device="cuda", dtype=ref_dtype, requires_grad=False + M, K, device=GPU_TYPE, dtype=ref_dtype, requires_grad=False ) - scale = torch.Tensor([10.0]).to("cuda") + scale = torch.Tensor([10.0]).to(GPU_TYPE) E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max @@ -509,6 +521,531 @@ def f(x): print(f"{ms=:.3f}") +@inductor_config.patch( + { + "triton.unique_kernel_names": True, + "loop_ordering_after_fusion": True, + "triton.max_tiles": 3, + "triton.coalesce_tiling_analysis": True, + } +) +@instantiate_parametrized_tests +class MemoryCoalescingTest(MockSchedulerTest): + """Tests for memory coalescing analysis with specific tensor sizes.""" + + device = GPU_TYPE + _exit_stack = None + + def setUp(self): + super().setUp() + metrics.reset() + + def _create_buffer(self, name, sizes): + """Create a buffer with specified sizes""" + + strides = ir.FlexibleLayout.contiguous_strides(sizes) + + box = ir.TensorBox.create( + ir.Buffer( + name=name, + layout=ir.FixedLayout( + torch.device(self.device), + dtype=torch.float32, + size=sizes, + stride=strides, + ), + ) + ) + box_loader = box.make_loader() + + def inner_fn(index): + return box_loader(index) * 2 + + buf = ir.Pointwise.create( + device=box.get_device(), + dtype=box.get_dtype(), + inner_fn=inner_fn, + ranges=box.get_size(), + ) + buf.realize() + computed_buf = buf.data.data + computed_buf.decide_layout() + + return computed_buf + + def _create_scheduler_node(self, buf): + s = SchedulerNode(V.graph.scheduler, buf) + s.min_order = 0 + s.max_order = 100 + return s + + @parametrize( + "inps", + ( + ((128, 384, 196), (768, 64, 196), (128, 6, 64, 196)), + ((64,), (16, 4), (16, 4)), + ((5, 6), (3, 10), (30,)), + ((5, 6, 20), (3, 10, 20), (30, 20)), + ), + ) + def test_inferred_splits(self, inps): + """ + Test memory coalescing analysis with the specified tensor sizes. + Using direct SchedulerNode creation with sizes (128, 384, 196) and (768, 64, 196). + """ + + s1, s2, expected_size = inps + + # Create buffers with the specified sizes + buf1 = self._create_buffer("buffer1", s1) + buf2 = self._create_buffer("buffer2", s2) + + # Create scheduler nodes + snode1 = self._create_scheduler_node(buf1) + snode2 = self._create_scheduler_node(buf2) + + # Create a fused node + fused_node = torch._inductor.scheduler.FusedSchedulerNode.fuse(snode1, snode2) + + from torch._inductor import tiling_utils + + fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(fused_node) + + var_ranges = fused_norm_read_writes.var_ranges + self.assertEqual(list(var_ranges.values()), list(expected_size)) + + def test_remapped_reads(self): + from torch._inductor import tiling_utils + + def fn(nodes): + assert len(nodes) == 1 + fused_norm_read_writes = tiling_utils.extract_normalized_read_writes( + nodes[0] + ) + + self.assertTrue(len(fused_norm_read_writes.var_ranges) == 2) + + # both reads remapped correctly + FileCheck().check("4*n0 + n1").run( + repr(fused_norm_read_writes.reads.keys()) + ) + FileCheck().check("n0 + 4*n1").run( + repr(fused_norm_read_writes.reads.keys()) + ) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn): + + @torch.compile() + def foo(x, y): + return x + y + + foo( + torch.rand([4, 4], device=GPU_TYPE), + torch.rand([4, 4], device=GPU_TYPE).T, + ) + + def test_remapped_reads_split(self): + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + fused_norm_read_writes = tiling_utils.extract_normalized_read_writes( + nodes[0] + ) + + inp_node_reads = nodes[0].get_nodes()[1]._body.get_read_exprs() + node_ranges = nodes[0].get_nodes()[1]._body.var_ranges + self.assertTrue(len(node_ranges) == 1) + self.assertTrue(next(iter(node_ranges.values())) == 36) + var = next(iter(node_ranges.keys())) + + r = FloorDiv(var, 6) + 6 * ModularIndexing(var, 1, 6) + self.assertTrue(r in inp_node_reads) + + # mapped reads + self.assertTrue(list(fused_norm_read_writes.var_ranges.values()) == [6, 6]) + n0, n1 = list(fused_norm_read_writes.var_ranges.keys()) + + # translation of above is n0 + 6 * n1 + self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads.keys()) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn): + + @torch.compile() + def foo(x, y): + return ( + x + y + ).contiguous().flatten() + torch.ops._inductor_test.realize( + (y.T + 1).flatten() + ) + + foo( + torch.rand([6, 6], device=GPU_TYPE), + torch.rand([6, 6], device=GPU_TYPE).T, + ) + + def test_reduction_pointwise(self): + # test one pw var, one red var + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0]) + + i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars + self.assertTrue(len(i_vars) == 1) + self.assertTrue(len(r_vars) == 1) + + # single write to index var + self.assertTrue( + fused_rw.index_vars[0] == next(iter(fused_rw.writes.keys())) + ) + + # the write to the fused intermediary node should be removed + self.assertTrue(len(fused_rw.writes) == 1) + + # single read + self.assertTrue(len(fused_rw.reads) == 1) + # that is applied to two bufs + self.assertTrue(len(next(iter(fused_rw.reads.values()))) == 2) + + # and the read should be in terms of the index + reduce var, + # even though node is pointwise + self.assertTrue(256 * i_vars[0] + r_vars[0] in fused_rw.reads) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): + + @torch.compile() + def foo(x, y): + out = torch.ops._inductor_test.realize(x + y) + return out.sum(dim=1) + + foo( + torch.rand(256, 256, device=GPU_TYPE), + torch.rand(256, 256, device=GPU_TYPE), + ) + + def test_reduction_no_pointwise(self): + # test one pw var, one red var + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0]) + + i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars + self.assertTrue(len(i_vars) == 0) + self.assertTrue(len(r_vars) == 1) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): + + @torch.compile() + def foo(x): + return x.sum() + + foo(torch.rand(1024, device=GPU_TYPE)) + + def test_coalescing(self): + from torch._inductor import tiling_utils + + # Define symbolic variables + i, j, n, m = sympy.symbols("i j n m", integer=True) + + # Test cases: (expression, var_ranges, expected_result) + test_cases = [ + # Simple direct case + (i + j * 5, {i: 10, j: 8}, i), + # Floor division case + (i + FloorDiv(j, 2), {i: 4, j: 8}, i), + # Modular indexing + (i * 10 + ModularIndexing(j, 1, 3), {i: 5, j: 10}, j), + # Case with no coalescing variable + (i * 2 + j * 3, {i: 8, j: 5}, None), + # Division case + (i / 2, {i: 10}, None), + # More complex floor division + (j + FloorDiv(i, 3), {i: 6, j: 12}, j), + # Addition inside modular indexing + (ModularIndexing(i + 3, 1, 6), {i: 8, j: 12}, i), + ] + + for expr, var_ranges, expected in test_cases: + # Test the function + result = tiling_utils.find_coalesced_var(expr, var_ranges) + self.assertEqual(result, expected) + + @parametrize("downcast_transposed_v", (False, True)) + def test_tiled_coalesce_analysis(self, downcast_transposed_v): + # test one pw var, one red var + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + + coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0]) + + i_vars = coalesce_analysis.norm_read_writes.index_vars + + # because output is contiguous, second dimension should + # coalesce twice as many bytes as first dimension + # if not downcasted + # if downcasted, should be equal, bc larger dtype size + # we also weight writes x 2 + cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]] + t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]] + + if not downcast_transposed_v: + self.assertEqual(cont_reads, t_reads * 3) + else: + self.assertEqual(cont_reads, t_reads * 1.5) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): + + @torch.compile() + def foo(x, y): + return x + y.to(x.dtype) + + y_dtype = torch.float if not downcast_transposed_v else torch.float64 + foo( + torch.rand(256, 256, device=GPU_TYPE), + torch.rand(256, 256, device=GPU_TYPE, dtype=y_dtype).T, + ) + + def test_solve_for_zero(self): + from torch._inductor import tiling_utils + + x, y = sympy.symbols("x y", integer=True) + # Test cases: (expression, expected_result) + test_cases = [ + # Simple linear expressions + (x + 5, (-5)), + (2 * x - 10, (5)), + # Constant expressions (should return None) + (sympy.Integer(7), None), + (sympy.Integer(0), None), + # FloorDiv cases (should return None per function) + (FloorDiv(x, 2), None), + (FloorDiv(x, 2) + 5, None), + # ModularIndexing cases + (ModularIndexing(x, 1, 5), (5)), + (ModularIndexing(x, 1, 3), (3)), + # Expressions with no constant solution + (x**2 + 1, None), # No real solution + ] + for expr, expected in test_cases: + result = tiling_utils.solve_for_zero(expr) + self.assertEqual(result, expected) + + def test_solve_for_tiling(self): + from torch._inductor import tiling_utils + + x = sympy.Symbol("x", integer=True) + + test_cases = [ + # Simple linear cases that coalesce + (3 * x, None), + # # # # Expression with no free symbols + # (sympy.Integer(5), None), + (x / 3, 3), + (FloorDiv(x * 2, 6), 3), + # # ModularIndexing expressions + (ModularIndexing(FloorDiv(x, 4), 1, 64), 4), + (x + ModularIndexing(x, 1, 5), None), + (x**2, None), # Non-linear, diff is not constant + (4096 * (ModularIndexing(32 * x, 1, 2048)) + FloorDiv(x, 64), 64), + (4096 * (ModularIndexing(x, 1, 2048)) + FloorDiv(x, 2048), 2048), + ] + + for expr, expected in test_cases: + result = tiling_utils.solve_for_tiling(expr) + self.assertEqual(result, expected) + + def test_induced_fused_tiling(self): + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + + coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0]) + self.assertEqual(coalesce_analysis.suggested_split.tiling_factor, 64) + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): + + def forward(permute): + clone = torch.ops.aten.clone.default( + permute, memory_format=torch.contiguous_format + ) + view_2 = torch.ops.aten.view.default(clone, [-1, 32]) + amax_1 = torch.ops.aten.amax.default(view_2, [1]) + return amax_1 + + XDIM = 2048 + YDIM = 4096 + + arg0_1 = torch.randn([XDIM, YDIM], device=GPU_TYPE, dtype=torch.bfloat16) + permute = torch.ops.aten.permute.default(arg0_1, [1, 0]) + + out, code = run_and_get_code(torch.compile(forward), (permute)) + + self.assertEqual(out, forward(permute)) + FileCheck().check("YBLOCK").check("XBLOCK").run(code[0]) + + +layouts = ("cont", "NHWC", "T") + + +@inductor_config.patch( + { + "triton.unique_kernel_names": True, + "loop_ordering_after_fusion": True, + "triton.coalesce_tiling_analysis": True, + } +) +@instantiate_parametrized_tests +class TestTiling(TestCase): + def T(self, layout: str): + SIZE_A = 128 + SIZE_B = 256 + SIZE_C = 512 + + if layout == "cont": + return torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE).unsqueeze(0) + elif layout == "T": + return ( + torch.rand(SIZE_A, SIZE_B, SIZE_C, device=GPU_TYPE) + .transpose(1, 2) + .contiguous() + .transpose(1, 2) + .unsqueeze(0) + ) + else: + assert layout == "NHWC" + return torch.rand([1, SIZE_A, SIZE_B, SIZE_C], device=GPU_TYPE).to( + memory_format=torch.channels_last + ) + + @parametrize("a", layouts) + @parametrize("b", layouts) + def test_pointwise(self, a, b): + def foo(x, y): + return x + y + + x, y = self.T(a), self.T(b) + res, code = run_and_get_code(torch.compile(foo), x, y) + + if a != b: + FileCheck().check("ynumel").run(code[0]) + else: + FileCheck().check_not("ynumel").run(code[0]) + + self.assertEqual(res, foo(x, y)) + + def test_tiled_reduction(self): + def f(a, b): + return (a * b).sum(dim=-1) + + N = 512 + inps = ( + torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0), + torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0), + ) + f_c = torch.compile(f) + out, code = run_and_get_code(f_c, *inps) + + FileCheck().check_dag("xnumel = 512").check_dag("ynumel = 512").check_dag( + "rnumel" + ).run(code[0]) + self.assertEqual(out, f(*inps), atol=0.001, rtol=0.04) + + def test_3d_pointwise(self): + inps = (self.T("cont"), self.T("T"), self.T("NHWC")) + + def f(x, y, z): + return x + y + z + + f_c = torch.compile(f) + out, code = run_and_get_code(f_c, *inps) + + FileCheck().check_dag("znumel").check_dag("ynumel").check_dag("xnumel").run( + code[0] + ) + self.assertEqual(out, f(*inps)) + + def test_cat(self): + # test unwrapping Identity + + def f(x, y): + return torch.cat((x, y)) + 1 + + x = self.T("cont") + y = self.T("T") + + inps = (x, y) + + f_c = torch.compile(f) + out, code = run_and_get_code(f_c, *inps) + FileCheck().check_dag("ynumel").check_dag("xnumel").run(code[0]) + self.assertEqual(out, f(*inps)) + + def test_penalized_small_dim(self): + x = torch.rand([2000, 1], device=GPU_TYPE) + y = torch.rand([4, 1], device=GPU_TYPE).T + + # don't tile when it doesn't affect total coalesced mem accesses much + def f(x, y): + return x + y + + inps = (x, y) + + f_c = torch.compile(f) + out, code = run_and_get_code(f_c, *inps) + FileCheck().check_not("ynumel").check_dag("xnumel").run(code[0]) + self.assertEqual(out, f(*inps)) + + def test_mutation_deps(self): + def f(x): + return x.add_(1) + + x = self.T("cont") + + from torch._inductor import tiling_utils + + def fn(nodes): + self.assertTrue(len(nodes) == 1) + + coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0]) + assert coalesce_analysis is not None + + reads = coalesce_analysis.norm_read_writes.reads + writes = coalesce_analysis.norm_read_writes.writes + + self.assertTrue(len(reads) == 1 and len(writes) == 1) + self.assertEqual( + list(coalesce_analysis.norm_read_writes.reads.values()), + [OrderedSet(("arg0_1",))], + ) + self.assertEqual( + list(coalesce_analysis.norm_read_writes.writes.values()), + [OrderedSet(("buf1",))], + ) + + return nodes + + with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad(): + torch.compile(f)(x) + + if __name__ == "__main__": if HAS_GPU: run_tests() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 53b0479ef4ed18..6fc81566f071c9 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -10,6 +10,7 @@ import re import tempfile import unittest +from functools import partial from typing import Callable, Optional from unittest import mock from unittest.mock import MagicMock @@ -35,6 +36,11 @@ TritonTemplate, TritonTemplateCaller, ) +from torch._inductor.template_heuristics import ( + BaseConfigHeuristic, + CUDAConfigHeuristic, + GemmConfig, +) from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -49,7 +55,7 @@ aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_inductor_cache, run_and_get_code +from torch._inductor.utils import fresh_cache, run_and_get_code from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck @@ -160,12 +166,15 @@ def mm(a, b): a = torch.randn(M, K).to(torch.float16).cuda() b = torch.randn(K, N).to(torch.float16).cuda() - with self.assertRaises(BackendCompilerFailed) as context, config.patch( - { - "max_autotune": True, - "triton.enable_persistent_tma_matmul": "1", - "test_configs.autotune_choice_name_regex": "mm_persistent_tma", - } + with ( + self.assertRaises(BackendCompilerFailed) as context, + config.patch( + { + "max_autotune": True, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ), ): torch.compile(mm, dynamic=dynamic)(a, b) @@ -278,12 +287,15 @@ def addmm(x, a, b): b = torch.randn(K, N).to(torch.float16).cuda() x = torch.randn(N).to(torch.float16).cuda() - with self.assertRaises(BackendCompilerFailed) as context, config.patch( - { - "max_autotune": True, - "triton.enable_persistent_tma_matmul": "1", - "test_configs.autotune_choice_name_regex": "mm_persistent_tma", - } + with ( + self.assertRaises(BackendCompilerFailed) as context, + config.patch( + { + "max_autotune": True, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ), ): torch.compile(addmm, dynamic=dynamic)(x, a, b) @@ -325,7 +337,7 @@ def addmm(x, a, b): torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) - @fresh_inductor_cache() + @fresh_cache() @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA") @unittest.skipIf( @@ -470,7 +482,7 @@ def foo(mod, x): FileCheck().check_not("extern_kernels.convolution").run(code[0]) self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0) - @fresh_inductor_cache() + @fresh_cache() @config.patch(max_autotune=True, max_fusion_size=2) def test_jit_fusion_matches_aot_fusion(self): # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due @@ -563,7 +575,7 @@ def test_autotune_device_guard(self): def f(x, y): return x @ y - with fresh_inductor_cache(): + with fresh_cache(): act = torch.compile(f)(x, y) ref = f(x, y) self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3)) @@ -704,16 +716,22 @@ def f(x, y, z, other): inps = (t(3, 3), t(3, 3), t(3, 3), t(3)) fn = torch.compile(f, mode="max-autotune-no-cudagraphs") ( - pre_fusion_tream, - post_fusion_stream, - ), ctx = multiple_logs_to_string( + ( + pre_fusion_tream, + post_fusion_stream, + ), + ctx, + ) = multiple_logs_to_string( "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion" ) with config.patch({"trace.debug_dir": tempfile.mkdtemp()}): - with self.assertLogs( - logging.getLogger("torch._inductor.debug"), level=logging.INFO - ) as cm, ctx(): + with ( + self.assertLogs( + logging.getLogger("torch._inductor.debug"), level=logging.INFO + ) as cm, + ctx(), + ): out = fn(*inps) self.assertEqual(f(*inps), out) @@ -737,7 +755,7 @@ def test_cat_max_autotune_extern(self): self._test_cat_max_autotune_impl(using_triton_mm=False) @skipIfXpu( - msg="The fusion not happend because it do not speedup on XPU, see issue #146568" + msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) @config.patch(max_autotune_gemm_backends="TRITON") def test_cat_max_autotune_triton(self): @@ -765,7 +783,7 @@ def forward(self, x): self.assertEqual(out, m(input_tensor)) if not TEST_WITH_ROCM: - FileCheck().check("triton_poi_fused_cat_2.run").run(code[0]) + FileCheck().check("def triton_poi_fused_cat_").run(code[0]) def test_conv3d(self): fn = torch.nn.functional.conv3d @@ -885,8 +903,9 @@ def mock_lookup(self, *args, **kwargs): a = torch.zeros([16, 16], device=GPU_TYPE) b = torch.zeros([16, 16], device=GPU_TYPE) - with patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), config.patch( - benchmark_epilogue_fusion=multi_template + with ( + patch.object(AlgorithmSelectorCache, "lookup", mock_lookup), + config.patch(benchmark_epilogue_fusion=multi_template), ): with self.assertRaises(BackendCompilerFailed) as context: torch.compile(lambda a, b: a.matmul(b))(a, b) @@ -936,7 +955,6 @@ def f(x, y): assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}" @skipIfXpu - @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -984,7 +1002,7 @@ def check_divisors(code): # We assume with the large k dim relative to m, n, decompose_k will be most performant out, code = run_and_get_code(compiled_func, a, b) - if dynamic: + if dynamic or torch.version.hip: FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -998,7 +1016,7 @@ def check_divisors(code): # Test adding epilogue also equivalent to eager compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic) out, code = run_and_get_code(compiled_func, a, b) - if dynamic: + if dynamic or torch.version.hip: FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -1017,7 +1035,9 @@ def check_divisors(code): lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic ) out, code = run_and_get_code(compiled_func, a, b) - if dynamic: + + # DecomposeK is not enabled for AMD yet + if dynamic or torch.version.hip: FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -1073,13 +1093,9 @@ def f(a, b): out, code = run_and_get_code(compiled_func, a, b) FileCheck().check("extern_kernels.bmm_dtype").check_regex( "triton_.*_fused_0.run" - ).check("decompose_k").check_regex("s[0-9]+ = primals_1").check_regex( - "2*s[0-9]+" - ).check( - "primals_1 = 32" - ).run( - code[0] - ) + ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex( + r"2\*s[0-9]+" + ).check_regex("s[0-9]+ = 32").run(code[0]) torch.testing.assert_close( out, f(a, b), @@ -1087,6 +1103,49 @@ def f(a, b): rtol=1e-2, ) + @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") + @unittest.skipIf( + config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" + ) + @config.patch( + max_autotune=True, + max_autotune_gemm_backends="TRITON", + ) + def test_max_autotune_decompose_k_dynamic_input_bwd(self): + def f(a, b): + # 256 * s0 + a_in = torch.cat([a for _ in range(256)], dim=0) + return (a_in @ b).relu().sum() + + a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + b = torch.randn( + 64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + + torch._dynamo.reset() + torch._dynamo.maybe_mark_dynamic(a, 0) + compiled_func = torch.compile(f) + res = compiled_func(a, b) + res.backward() + + with mock.patch( + "torch._inductor.kernel.mm.use_decompose_k_choice" + ) as decomp_mock: + decomp_mock.return_value = True + + out, code = run_and_get_code(compiled_func, a, b) + out.backward() + + FileCheck().check("extern_kernels.bmm_dtype").check_regex( + "triton_.*_fused_0.run" + ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex( + r"256\*s[0-9]+" + ).check_regex("s[0-9]+ = 8").run( + # code[1] in this case given backwards + code[1] + ) + @skipIfXpu @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( @@ -1107,11 +1166,14 @@ def f(a, b): b = b[:, :1096] # Force only decomposeK choice - with mock.patch( - "torch._inductor.kernel.mm.V.choices.get_base_mm_configs" - ) as base_mm_mock, mock.patch( - "torch._inductor.kernel.mm.use_decompose_k_choice" - ) as decompose_mock: + with ( + mock.patch( + "torch._inductor.kernel.mm.V.choices.get_base_mm_configs" + ) as base_mm_mock, + mock.patch( + "torch._inductor.kernel.mm.use_decompose_k_choice" + ) as decompose_mock, + ): mm_configs_mock = MagicMock() mm_configs_mock.return_value = [] base_mm_mock.return_value = mm_configs_mock @@ -1144,9 +1206,9 @@ def test_triton_template_generated_code_cache_key(self): # Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching) # update this function each time new arg added to generate_and_load and make sure arg is added to make_key self.assertEqual(generate_and_load_args - 1, make_key_args) - self.assertEqual(generate_and_load_args, 15) + self.assertEqual(generate_and_load_args, 16) - @fresh_inductor_cache() + @fresh_cache() @config.patch( { "max_autotune": True, @@ -1213,7 +1275,7 @@ def func_test1(x, y, z, m): b = torch.rand(22, 30, device=GPU_TYPE) # Valid cache hit. - with fresh_inductor_cache(): + with fresh_cache(): reset_counters() compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) eager_results = func_test1(a, b, a, b) @@ -1224,18 +1286,20 @@ def func_test1(x, y, z, m): cache_key, events = get_cache_key_and_events() if not TEST_WITH_ROCM: + expected = """{ + 'input_nodes':[ + "[[10,22],[22,1],torch.float32,device(type='cuda',index=0),0]", + "[[22,30],[30,1],torch.float32,device(type='cuda',index=0),0]"], + 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[10,30], + 'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]", + 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity', + 'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', + 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}""" + + expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( remove_white_space(cache_key), - remove_white_space( - f""" - {{'input_nodes': ["[[10, 22], [22, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]", - "[[22, 30], [30, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]"], - 'num_stages': 1, 'num_warps': 2, 'prefix_args': 0, 'suffix_args': 0, - 'call_sizes': [10, 30], 'layout': "[[10, 30], [30, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]", - 'num_consumer_groups': 0, 'num_buffers_warp_spec': 0, - 'kwargs': {{'EVEN_K': False, 'ALLOW_TF32': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', - 'BLOCK_M': 16, 'BLOCK_N': 32, 'BLOCK_K': 16, 'GROUP_M': 8}}}}""" - ), + remove_white_space(expected), ) self.assertEqual( @@ -1244,7 +1308,7 @@ def func_test1(x, y, z, m): ) # Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs. - with fresh_inductor_cache(): + with fresh_cache(): a = torch.rand(10, 22, device=GPU_TYPE) b = torch.rand(22, 30, device=GPU_TYPE) @@ -1262,17 +1326,18 @@ def func_test1(x, y, z, m): cache_key, events = get_cache_key_and_events() if not TEST_WITH_ROCM: + expected = """{ + 'input_nodes':[ + "[[s77,s17],[s17,1],torch.float32,device(type='cuda',index=0),0]", + "[[s17,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], + 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], + 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, + 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True, + 'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}""" + expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( remove_white_space(cache_key), - remove_white_space( - f"""{{'input_nodes': ["[[s77, s17], [s17, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]", - "[[s17, s94], [s94, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]"], - 'num_stages': 1, 'num_warps': 2, 'prefix_args': 0, 'suffix_args': 0, 'call_sizes': [s77, s94], - 'layout': "[[s77, s94], [s94, 1], torch.float32, device(type='{GPU_TYPE}', index=0), 0]", - 'num_consumer_groups': 0, 'num_buffers_warp_spec': 0, 'kwargs': {{'EVEN_K': False, - 'ALLOW_TF32': True, 'USE_FAST_ACCUM': False, - 'ACC_TYPE': 'tl.float32', 'BLOCK_M': 16, 'BLOCK_N': 32, 'BLOCK_K': 16, 'GROUP_M': 8}}}}""" - ), + remove_white_space(expected), ) self.assertExpectedInline( @@ -1294,7 +1359,7 @@ def func_test1(x, y, z, m): ) # Test duck typing. - with fresh_inductor_cache(): + with fresh_cache(): reset_counters() compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b) @@ -1310,7 +1375,7 @@ def test_func2(x): x = torch.matmul(x, x) return x - with fresh_inductor_cache(): + with fresh_cache(): reset_counters() input = torch.rand(10, 10, device=GPU_TYPE) @@ -1321,7 +1386,7 @@ def test_func2(x): self.assertEqual(hits(), 36) self.assertEqual(misses(), 4) - with fresh_inductor_cache(): + with fresh_cache(): reset_counters() input = torch.rand(10, 10, device=GPU_TYPE) @@ -1340,7 +1405,7 @@ def test_func3(x, y, z, m, l): b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0)) return a, b - with fresh_inductor_cache(): + with fresh_cache(): a = torch.rand(10, 22, device=GPU_TYPE) b = torch.rand(22, 30, device=GPU_TYPE) c = torch.rand(10, 11, device=GPU_TYPE) @@ -1354,8 +1419,82 @@ def test_func3(x, y, z, m, l): self.assertEqual(hits(), 0) self.assertEqual(misses(), 7) + @config.patch( + { + "max_autotune": True, + "test_configs.max_mm_configs": 4, + "max_autotune_gemm_backends": "TRITON", + } + ) + def test_triton_template_generated_code_caching_bmm(self): + def func_test1(x, y, z, m): + a = torch.bmm(x, y) + b = torch.bmm(z, m) + return a, b + + a = torch.rand(10, 10, 22, device=GPU_TYPE) + b = torch.rand(10, 22, 30, device=GPU_TYPE) + + def hits(): + return torch._dynamo.utils.counters["inductor"][ + "generated_module_cache_hit" + ] + + def misses(): + return torch._dynamo.utils.counters["inductor"][ + "generated_module_cache_miss" + ] + + # Valid cache hit. + with fresh_cache(): + torch._dynamo.utils.counters.clear() + compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) + eager_results = func_test1(a, b, a, b) + self.assertEqual(compile_results, eager_results, atol=0.05, rtol=0.05) + self.assertEqual(hits(), 4) + self.assertEqual(misses(), 4) + + @config.patch( + { + "max_autotune": True, + "test_configs.max_mm_configs": 4, + "max_autotune_gemm_backends": "ATEN, TRITON", + } + ) + def test_triton_template_generated_code_caching_mm_plus_mm(self): + def func_test1(x, y, z, m): + a = torch.mm(x, y) + b = torch.mm(z, m) + sum1 = a + b + + c = torch.mm(x, y) + d = torch.mm(z, m) + sum2 = c + d + return sum1, sum2 + + a = torch.rand(10, 40, device=GPU_TYPE) + b = torch.rand(40, 30, device=GPU_TYPE) + + def hits(): + return torch._dynamo.utils.counters["inductor"][ + "generated_module_cache_hit" + ] + + def misses(): + return torch._dynamo.utils.counters["inductor"][ + "generated_module_cache_miss" + ] + + # Valid cache hit. + with fresh_cache(): + torch._dynamo.utils.counters.clear() + compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) + eager_results = func_test1(a, b, a, b) + self.assertEqual(compile_results, eager_results, atol=0.05, rtol=0.05) + self.assertEqual(hits(), 4) + self.assertEqual(misses(), 4) + @skipIfXpu - @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -1377,6 +1516,96 @@ def test_max_autotune_disable_decompose_K(self): for codegen in code: FileCheck().check_not("decompose_k").run(codegen) + @skipIfXpu + @unittest.skipIf( + TEST_WITH_ROCM, "exhaustive currently only thoroughly tested on NVIDIA" + ) + @config.patch(max_autotune=True, max_autotune_gemm_search_space="EXHAUSTIVE") + def test_max_autotune_exhaustive(self): + def f(a, b): + return a @ b + + M, N, K = (1024, 1024, 1024) + + a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + + with mock.patch( + "torch._inductor.kernel.mm.V.choices.get_config_heuristics" + ) as config_mock: + config_heuristics = CUDAConfigHeuristic() + + # Traditionally, this would be set of all possible configs + # We mock out the code path for the sake of the unit test + config_heuristics.exhaustive_configs = [GemmConfig(32, 32, 32, 1, 8, 8)] + config_mock.return_value = config_heuristics + + from torch._dynamo.utils import counters + + compiled_func = torch.compile(f) + compiled_func(a, b) + + # Only benchmarks 2 choices, aten and the exhaustive triton config + # Counter can be InductorBenchmarker or TritonBenchmarker + for counter in counters["inductor"]: + if "benchmark_gpu" in counter: + self.assertEqual(counters["inductor"][counter], 2) + + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + @config.patch( + max_autotune=True, + max_autotune_gemm_backends="TRITON", + autotune_fallback_to_aten=False, + ) + def test_one_triton_choice_epilogue_fusion(self): + """ + Here we test the fusion case with only 1 Triton choice for mm lowering. + The hardcoded config itself is valid, but when fused with the torch.float32 + case, the shared memory requirements is higher than the amount available on H100. + + This test checks that the fusion does not occur in this edge case. This is important + for future work on lookup table for autotuned gemm configs. + """ + + def f(a, b): + return (a @ b).to(torch.float32) + + a = torch.randn(512, 1152, device="cuda", dtype=torch.bfloat16) + b = torch.randn(1152, 7680, device="cuda", dtype=torch.bfloat16) + + config_heuristic = BaseConfigHeuristic() + with config.patch( + { + "triton.enable_persistent_tma_matmul": "1", + } + ): + with ( + mock.patch( + "torch._inductor.kernel.mm.V.choices.get_base_mm_configs" + ) as base_mm_mock, + mock.patch( + "torch._inductor.kernel.mm.V.choices.get_persistent_mm_configs" + ) as persistent_mm_mock, + ): + base_mm_mock.return_value = partial( + config_heuristic.preprocess_mm_configs, configs=[] + ) + persistent_mm_mock.return_value = partial( + config_heuristic.preprocess_mm_configs, + configs=[GemmConfig(256, 128, 64, 4, 8, 8)], + ) + + compiled_f = torch.compile(f) + out, code = run_and_get_code(compiled_f, a, b) + + FileCheck().check("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + + torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2) + class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): @@ -1438,12 +1667,12 @@ def fake_benchmark_fn(*args, **kwargs): ): asc("test_call", fake_choices, [], Mock()) for fake_choice in fake_choices: - assert ( - fake_choice.thread_id is not None - ), "Expected all ChoiceCaller's precompile method to have been called" - assert ( - fake_choice.thread_id != main_thread_id - ), "Expected all ChoiceCaller's precompile method to have been called on separate thread" + assert fake_choice.thread_id is not None, ( + "Expected all ChoiceCaller's precompile method to have been called" + ) + assert fake_choice.thread_id != main_thread_id, ( + "Expected all ChoiceCaller's precompile method to have been called on separate thread" + ) finally: V.set_debug_handler(old_debug_handler) @@ -1465,21 +1694,6 @@ def fn(a, b, c): fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) - @fresh_inductor_cache() - @config.patch(search_autotune_cache=True) - def test_search_autotune_cache(self): - def fn(a, b, c): - a = (a @ b) @ c - a, b, c = (t.to(torch.float16) for t in [a, b, c]) - return (a @ b) @ c - - fn_c = torch.compile()(fn) - inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)] - from torch._dynamo.utils import counters - - self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) - self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) - @config.patch(autotune_local_cache=False, autotune_remote_cache=False) @runOnRocmArch(MI300_ARCH) def test_precompilations(self): @@ -1516,7 +1730,7 @@ def test_benchmark_choice_in_subproc(self): )() # a dummy graph to construct the GraphLowering graph = GraphLowering(gm) - # the graph handler is neede to create benchmark example value below + # the graph handler is needed to create benchmark example value below with V.set_graph_handler(graph): buf1 = self._create_buffer("mat1", (2, 3)) buf2 = self._create_buffer("mat2", (3, 2)) @@ -1556,7 +1770,7 @@ def test_benchmark_choice_fail_in_subproc(self): )() # a dummy graph to construct the GraphLowering graph = GraphLowering(gm) - # the graph handler is neede to create benchmark example value below + # the graph handler is needed to create benchmark example value below with V.set_graph_handler(graph): buf1 = self._create_buffer("mat1", (2, 3)) buf2 = self._create_buffer("mat2", (3, 2)) @@ -1724,21 +1938,26 @@ def f(x, y): x = torch.randn(100, 100).to(GPU_TYPE) y = torch.randn(100, 100).to(GPU_TYPE) - with config.patch( - { - "autotune_local_cache": False, - "autotune_remote_cache": True, - } - ), patch.dict(os.environ), PatchCaches(): + with ( + config.patch( + { + "autotune_local_cache": False, + "autotune_remote_cache": True, + } + ), + patch.dict(os.environ), + PatchCaches(), + ): os.environ.pop("TRITON_CACHE_MANAGER", None) with config.patch({"max_autotune": True}): for _ in range(4): - with fresh_inductor_cache(): + with fresh_cache(): torch.compile(mm, dynamic=dynamic)(a, b) reset() - with torch.compiler.config.patch( - {"cache_key_tag": "test"} - ), fresh_inductor_cache(): + with ( + torch.compiler.config.patch({"cache_key_tag": "test"}), + fresh_cache(), + ): torch.compile(mm, dynamic=dynamic)(a, b) reset() @@ -1747,12 +1966,10 @@ def f(x, y): global_stats.reset() for _ in range(4): - with fresh_inductor_cache(): + with fresh_cache(): torch.compile(f, dynamic=dynamic)(x, y) reset() - with torch.compiler.config.patch( - {"cache_key_tag": "test"} - ), fresh_inductor_cache(): + with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache(): torch.compile(mm, dynamic=dynamic)(a, b) reset() global_stats.report() @@ -2056,7 +2273,7 @@ def foo(x, y): } ) @skipIfXpu( - msg="The fusion not happend because it do not speedup on XPU, see issue #146568" + msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) def test_pending_fusions_multiple(self): def multi_use(x, y): @@ -2090,7 +2307,7 @@ def resolve_pending(x): } ) @skipIfXpu( - msg="The fusion not happend because it do not speedup on XPU, see issue #146568" + msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) def test_pending_fusion_pro_and_epi(self): def test_multiple_fusions(x): @@ -2235,8 +2452,8 @@ def foo(x, y, z): out, code = run_and_get_code(torch.compile(foo), x, y, z) self.assertEqual(out, foo(x, y, z), atol=0.05, rtol=0.05) - # theres one more dealloc than there should be because of a buffer reuse. TODO: - # not sure why disabling buffer reuse doesnt stop + # there's one more dealloc than there should be because of a buffer reuse. TODO: + # not sure why disabling buffer reuse doesn't stop self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4) # XPU have not enabled pad_mm in fx_passes, so there is always one kernel. diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 3116ea1a8a5084..d5f90e662697dc 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -58,11 +58,7 @@ def test_python_wrapper(self): + "((4*s27*s77 + align(4*s77*s77), ), (1, )" ).check_next( "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s77, s77), (s77, 1))" - ).check( - "buf1 = alloc_from_pool(pool1, align(4*s77*s77)," - ).run( - code - ) + ).check("buf1 = alloc_from_pool(pool1, align(4*s77*s77),").run(code) self.assertTrue(same(f(*args), result)) def test_cpp_wrapper(self): @@ -75,9 +71,7 @@ def test_cpp_wrapper(self): "aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_2, int_array_3, &tmp_tensor_handle_0)" ).check_next("auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_0);").check( "auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_1);" - ).run( - code - ) + ).run(code) self.assertTrue(same(f(*args), result)) @skipIfXpu(msg="aoti doesn't work on XPU") @@ -102,19 +96,11 @@ def test_aoti(self): "AtenTensorHandle pool1_handle;" ).check_next( "aoti_torch_empty_strided(1, int_array_0, int_array_1," - ).check_next( - "RAIIAtenTensorHandle pool1(pool1_handle);" - ).check_next( + ).check_next("RAIIAtenTensorHandle pool1(pool1_handle);").check_next( "int64_t int_array_2[] = {s77, 3L};" - ).check_next( - "int64_t int_array_3[] = {3L, 1L};" - ).check_next( + ).check_next("int64_t int_array_3[] = {3L, 1L};").check_next( "AtenTensorHandle tmp_tensor_handle_0;" - ).check_next( - "aoti_torch__alloc_from_pool(pool1, 0" - ).run( - code - ) + ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code) self.assertTrue(same(f(*args), result)) diff --git a/test/inductor/test_metrics.py b/test/inductor/test_metrics.py index cf8c9414394358..1517c945187d90 100644 --- a/test/inductor/test_metrics.py +++ b/test/inductor/test_metrics.py @@ -50,9 +50,7 @@ def triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK : tl tmp5 = tmp4 + tmp2 tl.debug_barrier() tl.store(in_out_ptr0 + (x0), tmp5, xmask) -""".replace( - "GPU_TYPE", GPU_TYPE -) +""".replace("GPU_TYPE", GPU_TYPE) class TestMetrics(TestCase): diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index 702ade28a61a92..0fe17a6e526d48 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -75,9 +75,7 @@ def inner(x): return x - torch.tensor(655, dtype=torch.half, device='GPU_TYPE') * 100 inner(torch.tensor(655 * 100, dtype=torch.half, device='GPU_TYPE')) -""".replace( - "GPU_TYPE", GPU_TYPE - ) +""".replace("GPU_TYPE", GPU_TYPE) # If we disable RMSE against fp64, this triggers accuracy error, # as the increased precision from torch.compile changes the result diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 66855f85868c73..7760bfd834efdc 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -10,10 +10,15 @@ from torch._dynamo.utils import counters from torch._inductor import config, metrics from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import run_and_get_code +from torch._inductor.utils import ( + is_mkldnn_bf16_supported, + is_mkldnn_fp16_supported, + run_and_get_code, +) from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -30,6 +35,7 @@ skipIfNoXPU, skipIfRocm, skipIfRocmArch, + skipIfXpu, TEST_ACL, TEST_MKL, xfailIfACL, @@ -129,11 +135,20 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): TEST_ACL and dtype == torch.bfloat16 ): output_kernel = 1 + return input_kernel + output_kernel -@config.patch({"freezing": True}) class TestPatternMatcherBase(TestCase): + def setUp(self): + TestCase.setUp(self) + self.ctx_stack = contextlib.ExitStack() + self.ctx_stack.enter_context(config.patch({"freezing": True})) + + def tearDown(self): + TestCase.tearDown(self) + self.ctx_stack.close() + def _check_unary_is_decomposed(self, unary_fn): return not any( isinstance(unary_fn, fn) @@ -179,16 +194,12 @@ def _test_common( ) counters.clear() torch._dynamo.reset() - if check_autocast == torch.bfloat16 and ( - torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu" - ): + if check_autocast == torch.bfloat16 and is_mkldnn_bf16_supported(device): maybe_autocast = torch.amp.autocast( device_type=device, dtype=torch.bfloat16 ) atol, rtol = 1e-2, 1e-2 - elif check_autocast == torch.float16 and ( - torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu" - ): + elif check_autocast == torch.float16 and (is_mkldnn_fp16_supported(device)): maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16) atol, rtol = 1e-2, 1e-2 else: @@ -206,7 +217,12 @@ def _test_common( clone_inputs = self._clone_inputs(inputs) expected = mod(*inputs) actual = torch.compile(mod, **compile_options)(*clone_inputs) - torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + if self.precision != 0: + torch.testing.assert_close( + actual, expected, atol=self.precision, rtol=self.precision + ) + else: + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) matcher_check_fn() def _test_code_common( @@ -231,6 +247,14 @@ def _test_code_common( torch.compile(mod, fullgraph=True, dynamic=check_dynamic), *clone_inputs, ) + assert_keywords = ["assert_size_stride", "assert_alignment"] + filtered_lines = [ + line + for line in source_code.splitlines() + if not any(assert_key in line for assert_key in assert_keywords) + ] + source_code = "\n".join(filtered_lines) + for op in include_ops: self.assertIn(op, source_code) if num_include_ops is not None: @@ -272,9 +296,9 @@ def forward(self, x): dtypes = [ torch.float, ] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d options = itertools.product( @@ -326,6 +350,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @bf32_on_and_off() def test_conv2d_unary(self, device): self.device = device self._test_conv_unary_base(dim=4) @@ -333,204 +358,11 @@ def test_conv2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @bf32_on_and_off() def test_conv3d_unary(self, device): self.device = device self._test_conv_unary_base(dim=5) - def test_linear_unary(self, device): - self.device = device - - class M(torch.nn.Module): - def __init__( - self, - unary_fn, - in_features, - out_features, - bias, - **kwargs, - ): - super().__init__() - self.linear = torch.nn.Linear( - in_features, - out_features, - bias, - **kwargs, - ) - self.unary_fn = unary_fn - - def forward(self, x): - x = self.linear(x) - return self.unary_fn(x) - - dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): - dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): - dtypes.append(torch.float16) - options = itertools.product(unary_list, [True, False], dtypes) - for unary_fn, bias, dtype in options: - metrics.reset() - mod = M(unary_fn, 10, 30, bias=bias).eval() - # only fuse for linear when the dtype is bf16 - mod = mod - v = torch.randn(2, 10) - - def matcher_check_fn(): - match_nodes = unary_list[unary_fn] - if self._check_unary_is_decomposed(unary_fn): - # Has extra dtype conversion nodes for autocast. - match_nodes += 2 - self.assertEqual( - counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], - 0 if TEST_ACL else match_nodes, - ) - self.assertEqual( - counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1 - ) - - self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) - # only generated 1 kernel for "to" - self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) - - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_fp32(self, device): - self.device = device - - class M(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.linear = torch.nn.Linear(10, 30, bias) - - def forward(self, x): - return self.linear(x) - - for bias in [True, False]: - mod = M(bias=bias).eval() - v = torch.randn(2, 10) - - # packing pass. - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1 - ) - - self._test_common(mod, (v,), matcher_check_fn) - - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_input_non_contiguous_3D_wo_bias(self, device): - self.device = device - - # Activation is 3D, non-contiguous and without Bias - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4096, 1024, bias=False) - - def forward(self, x): - x = torch.ops.aten.permute.default(x, [0, 2, 1, 3]) - x = torch.ops.aten.reshape.default(x, [4, 1, 4096]) - return self.linear(x) - - mod = M().eval() - v = torch.randn(4, 32, 1, 128) - - dtypes = [torch.float] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): - dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): - dtypes.append(torch.float16) - - for dtype in dtypes: - torch._dynamo.reset() - autocast_enabled = ( - True if dtype in [torch.bfloat16, torch.float16] else False - ) - with torch.no_grad(), torch.autocast( - device_type="cpu", enabled=autocast_enabled, dtype=dtype - ): - expected = mod(v) - actual, (source_code,) = run_and_get_code( - torch.compile(mod, fullgraph=True), - v, - ) - self.assertIn( - "torch.ops.mkldnn._linear_pointwise.default" - if autocast_enabled - else "torch.ops.mkl._mkl_linear.default", - source_code, - ) - torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) - - def test_linear_add_bias(self, device): - self.device = device - - class M(torch.nn.Module): - def __init__(self, device, dtype, unary_fn, cast_bias): - super().__init__() - self.linear1 = torch.nn.Linear(10, 64, bias=False) - self.bias1 = torch.randn(64, device=device) - self.linear2 = torch.nn.Linear(10, 64, bias=False) - self.bias2 = torch.randn(64, device=device) - if cast_bias: - self.bias1 = self.bias1.to(dtype=dtype, device=device) - self.bias2 = self.bias2.to(dtype=dtype, device=device) - self.unary_fn = unary_fn - - def forward(self, x): - a = self.linear1(x) + self.bias1 - b = self.linear2(x) + self.bias2 - return self.unary_fn(a), self.unary_fn(b) - - dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): - dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): - dtypes.append(torch.float16) - options = itertools.product(unary_list, dtypes) - for unary_fn, dtype in options: - metrics.reset() - fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval() - v = torch.randn(2, 10) - - def folder_matcher_check_fn(): - match_nodes = unary_list[unary_fn] - if self._check_unary_is_decomposed(unary_fn): - # Has extra dtype conversion nodes for autocast. - match_nodes += 2 - # we have 2 linears, so we double the matcher_count/nodes - self.assertEqual( - counters["inductor"]["mkldnn_unary_fusion_matcher_count"], - 0 if TEST_ACL else 2, - ) - self.assertEqual( - counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], - 0 if TEST_ACL else match_nodes * 2, - ) - self.assertEqual( - counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2 - ) - - self._test_common( - fold_mod, - (v,), - folder_matcher_check_fn, - check_autocast=dtype, - ) - self.assertEqual(metrics.generated_kernel_count, 3 if TEST_ACL else 1) - # we won't fold the bias if bias is not same dtype with weight - # https://github.com/pytorch/pytorch/pull/129138 - metrics.reset() - mod = M(self.device, dtype, unary_fn, cast_bias=False).eval() - - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2 - ) - - self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) - # 1 kernel for "to_lowp", 2 kernels for unary ops - self.assertEqual(metrics.generated_kernel_count, 3) - def _test_conv_transpose_unary_base(self, dim=4): assert dim == 4 or dim == 5 @@ -558,9 +390,9 @@ def forward(self, x): dtypes = [ torch.float, ] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d @@ -607,6 +439,10 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @skipIfXpu( + msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + @bf32_on_and_off() def test_conv_transpose2d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=4) @@ -614,6 +450,10 @@ def test_conv_transpose2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @skipIfXpu( + msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + @bf32_on_and_off() def test_conv_transpose3d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=5) @@ -649,9 +489,9 @@ def forward(self, x): dtypes = [ torch.float, ] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] @@ -703,6 +543,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @bf32_on_and_off(0.02) def test_conv2d_binary(self, device): self.device = device self._test_conv_binary_base(dim=4) @@ -710,6 +551,7 @@ def test_conv2d_binary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm + @bf32_on_and_off(0.02) def test_conv3d_binary(self, device): self.device = device self._test_conv_binary_base(dim=5) @@ -742,9 +584,9 @@ def forward(self, x, x2): dtypes = [ torch.float, ] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] @@ -808,16 +650,267 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv2d_binary_broadcast_shapes_cpu(self): + @bf32_on_and_off() + def test_conv2d_binary_broadcast_shapes(self, device): + self.device = device self._test_conv_binary_broadcast_shapes_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv3d_binary_broadcast_shapes_cpu(self): + @bf32_on_and_off() + def test_conv3d_binary_broadcast_shapes(self, device): + self.device = device self._test_conv_binary_broadcast_shapes_base(dim=5) - def test_linear_binary(self, device): + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + @unittest.skipIf(IS_FBCODE, "Failing in fbcode") + @bf32_on_and_off() + def test_conv2d_linear_add_broadcast_shapes(self, device): + self.device = device + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1) + self.linear = torch.nn.Linear(3, 16) + + def forward(self, x1, x2): + return self.conv(x1) + self.linear(x2)[:, :, None, None] + + metrics.reset() + mod = M().eval() + x1 = torch.randn(2, 3, 56, 56) + x2 = torch.randn(2, 3) + + def matcher_check_fn(): + match_nodes = 0 if TEST_ACL else 2 + self.assertEqual( + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"], + match_nodes, + ) + self.assertEqual( + counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"], 1 + ) + + self._test_common(mod, (x1, x2), matcher_check_fn) + + +class TestPatternMatcher(TestPatternMatcherBase): + @bf32_on_and_off() + def test_linear_unary(self, device="cpu"): + self.device = device + + class M(torch.nn.Module): + def __init__( + self, + unary_fn, + in_features, + out_features, + bias, + **kwargs, + ): + super().__init__() + self.linear = torch.nn.Linear( + in_features, + out_features, + bias, + **kwargs, + ) + self.unary_fn = unary_fn + + def forward(self, x): + x = self.linear(x) + return self.unary_fn(x) + + dtypes = [] + if is_mkldnn_bf16_supported(self.device): + dtypes.append(torch.bfloat16) + if is_mkldnn_fp16_supported(self.device): + dtypes.append(torch.float16) + if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + dtypes.append(torch.float32) + options = itertools.product(unary_list, [True, False], dtypes) + for unary_fn, bias, dtype in options: + metrics.reset() + mod = M(unary_fn, 10, 30, bias=bias).eval() + # only fuse for linear when the dtype is bf16 + mod = mod + v = torch.randn(2, 10) + + def matcher_check_fn(): + match_nodes = unary_list[unary_fn] + if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn): + # Has extra dtype conversion nodes for autocast. + match_nodes += 2 + self.assertEqual( + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], + 0 if TEST_ACL else match_nodes, + ) + self.assertEqual( + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1 + ) + + self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) + # only generated 1 kernel for "to_dtype" + expected_kernel_count = 2 if TEST_ACL else 1 + if dtype == torch.float32: + # In BF32, input is float32, will not generate kernel for "to_dtype" + expected_kernel_count -= 1 + self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) + + @bf32_on_and_off() + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + def test_linear_fp32(self, device="cpu"): + self.device = device + + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(10, 30, bias) + + def forward(self, x): + return self.linear(x) + + for bias in [True, False]: + mod = M(bias=bias).eval() + v = torch.randn(2, 10) + + # packing pass. + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1 + ) + + self._test_common(mod, (v,), matcher_check_fn) + + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + def test_linear_input_non_contiguous_3D_wo_bias(self, device="cpu"): + self.device = device + + # Activation is 3D, non-contiguous and without Bias + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4096, 1024, bias=False) + + def forward(self, x): + x = torch.ops.aten.permute.default(x, [0, 2, 1, 3]) + x = torch.ops.aten.reshape.default(x, [4, 1, 4096]) + return self.linear(x) + + mod = M().eval() + v = torch.randn(4, 32, 1, 128) + + dtypes = [torch.float] + if is_mkldnn_bf16_supported(self.device): + dtypes.append(torch.bfloat16) + if is_mkldnn_fp16_supported(self.device): + dtypes.append(torch.float16) + + for dtype in dtypes: + torch._dynamo.reset() + autocast_enabled = ( + True if dtype in [torch.bfloat16, torch.float16] else False + ) + with ( + torch.no_grad(), + torch.autocast( + device_type="cpu", + enabled=autocast_enabled, + dtype=dtype, + ), + ): + expected = mod(v) + actual, (source_code,) = run_and_get_code( + torch.compile(mod, fullgraph=True), + v, + ) + self.assertIn( + "torch.ops.mkldnn._linear_pointwise.default" + if autocast_enabled + else "torch.ops.mkl._mkl_linear.default", + source_code, + ) + torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) + + @skipIfXpu( + msg="Different with CPU, two linears will be concat on XPU for better performance" + ) + def test_linear_add_bias(self, device="cpu"): + self.device = device + + class M(torch.nn.Module): + def __init__(self, device, dtype, unary_fn, cast_bias): + super().__init__() + self.linear1 = torch.nn.Linear(10, 64, bias=False) + self.bias1 = torch.randn(64, device=device) + self.linear2 = torch.nn.Linear(10, 64, bias=False) + self.bias2 = torch.randn(64, device=device) + if cast_bias: + self.bias1 = self.bias1.to(dtype=dtype, device=device) + self.bias2 = self.bias2.to(dtype=dtype, device=device) + self.unary_fn = unary_fn + + def forward(self, x): + a = self.linear1(x) + self.bias1 + b = self.linear2(x) + self.bias2 + return self.unary_fn(a), self.unary_fn(b) + + dtypes = [] + if is_mkldnn_bf16_supported(self.device): + dtypes.append(torch.bfloat16) + if is_mkldnn_fp16_supported(self.device): + dtypes.append(torch.float16) + options = itertools.product(unary_list, dtypes) + for unary_fn, dtype in options: + metrics.reset() + fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval() + v = torch.randn(2, 10) + + def folder_matcher_check_fn(): + match_nodes = unary_list[unary_fn] + if self._check_unary_is_decomposed(unary_fn): + # Has extra dtype conversion nodes for autocast. + match_nodes += 2 + # we have 2 linears, so we double the matcher_count/nodes + self.assertEqual( + counters["inductor"]["mkldnn_unary_fusion_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"], + 0 if TEST_ACL else match_nodes * 2, + ) + self.assertEqual( + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2 + ) + + self._test_common( + fold_mod, + (v,), + folder_matcher_check_fn, + check_autocast=dtype, + ) + self.assertEqual(metrics.generated_kernel_count, 3 if TEST_ACL else 1) + # we won't fold the bias if bias is not same dtype with weight + # https://github.com/pytorch/pytorch/pull/129138 + metrics.reset() + mod = M(self.device, dtype, unary_fn, cast_bias=False).eval() + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2 + ) + + self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) + # 1 kernel for "to_lowp", 2 kernels for unary ops + self.assertEqual(metrics.generated_kernel_count, 3) + + @bf32_on_and_off() + def test_linear_binary(self, device="cpu"): self.device = device class M(torch.nn.Module): @@ -834,10 +927,12 @@ def forward(self, x, y): return x dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) + if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + dtypes.append(torch.float32) options = itertools.product( binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes ) @@ -874,9 +969,16 @@ def matcher_check_fn(): matcher_check_fn, check_autocast=dtype, ) - self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) + # only generated 1 kernel for "to_dtype" + expected_kernel_count = 2 if TEST_ACL else 1 + if dtype == torch.float32: + # In BF32, input is float32, will not generate kernel for "to_dtype" + expected_kernel_count -= 1 + self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) + + def test_linear_binary_broadcast_shapes(self, device="cpu"): + self.device = device - def test_linear_binary_broadcast_shapes_cpu(self): class M(torch.nn.Module): def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): super().__init__() @@ -891,9 +993,9 @@ def forward(self, x, y): return x dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) options = itertools.product( binary_list, @@ -939,38 +1041,10 @@ def matcher_check_fn(): ) self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfRocm - @unittest.skipIf(IS_FBCODE, "Failing in fbcode") - def test_conv2d_linear_add_broadcast_shapes_cpu(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1) - self.linear = torch.nn.Linear(3, 16) - - def forward(self, x1, x2): - return self.conv(x1) + self.linear(x2)[:, :, None, None] - - metrics.reset() - mod = M().eval() - x1 = torch.randn(2, 3, 56, 56) - x2 = torch.randn(2, 3) - - def matcher_check_fn(): - match_nodes = 0 if TEST_ACL else 2 - self.assertEqual( - counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"], - match_nodes, - ) - self.assertEqual( - counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"], 1 - ) - - self._test_common(mod, (x1, x2), matcher_check_fn) - - def test_multi_linear_share_same_input(self, device): + @skipIfXpu( + msg="Different with CPU, two linears will be concat on XPU for better performance" + ) + def test_multi_linear_share_same_input(self, device="cpu"): self.device = device # llama pattern. @@ -986,9 +1060,9 @@ def forward(self, x): return F.silu(self.w1(x)) * F.relu(self.w2(x)) dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) def matcher_check_fn(): @@ -1012,8 +1086,6 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) - -class TestPatternMatcher(TestPatternMatcherBase): def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): def __init__( @@ -2282,7 +2354,7 @@ def _default_matcher_check_fn(): @skipIfNoONEDNN def test_qlinear_cpu(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) @@ -2292,7 +2364,7 @@ def test_qlinear_cpu(self): @skipIfNoXPU def test_qlinear_xpu(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2303,7 +2375,7 @@ def test_qlinear_xpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_cpu(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2314,7 +2386,7 @@ def test_dynamic_qlinear_cpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_qat_cpu(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2325,7 +2397,7 @@ def test_dynamic_qlinear_qat_cpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_input_dim_exceeds_2(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2337,7 +2409,7 @@ def test_dynamic_qlinear_input_dim_exceeds_2(self): @skipIfNoONEDNN def test_qlinear_int8_mixed_bf16(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2349,7 +2421,7 @@ def test_qlinear_int8_mixed_bf16(self): @skipIfNoXPU def test_qlinear_int8_mixed_bf16_xpu(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2363,7 +2435,7 @@ def test_qlinear_int8_mixed_bf16_xpu(self): @skipIfNoONEDNN def test_qlinear_input_dim_exceeds_2(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) @@ -2373,7 +2445,7 @@ def test_qlinear_input_dim_exceeds_2(self): @skipIfNoXPU def test_qlinear_input_dim_exceeds_2_xpu(self): r""" - This testcase will quantize a single Linear Moduel. + This testcase will quantize a single Linear Module. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2385,7 +2457,7 @@ def test_qlinear_input_dim_exceeds_2_xpu(self): @skipIfNoONEDNN def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( @@ -2398,7 +2470,7 @@ def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): @skipIfNoXPU def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_xpu(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( @@ -4091,10 +4163,13 @@ def matcher_check_fn(): nodes_count = 10 if has_bias else 7 else: nodes_count = 7 if has_bias else 6 - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - nodes_count, - ) + if counters["inductor"]["removed_pointless_view_pair"] == 0: + # Removing pointless view pairs affect how the pattern + # for this test is matched. + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + nodes_count, + ) self._test_common( mod, @@ -4210,8 +4285,7 @@ def matcher_check_fn(): class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase): def setUp(self): - TestCase.setUp(self) - self.ctx_stack = contextlib.ExitStack() + super().setUp() self.ctx_stack.enter_context( # When testing kernel counts, unspecializing float causes wobbling of our tests because # we end up reusing the same compiled region across tests. Thus we purposely specialize floats @@ -4226,20 +4300,12 @@ def setUp(self): ) ) - def tearDown(self): - TestCase.tearDown(self) - self.ctx_stack.close() - _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary - test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary - test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = ( - TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias - ) def test_conv_transpose2d_dynamic_shapes(self, device): self.device = device @@ -4264,6 +4330,9 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn) + @skipIfXpu( + msg="Different with CPU, two linears will be concat on XPU for better performance" + ) def test_multi_linear_share_same_input_dynamic(self, device): self.device = device @@ -4280,9 +4349,9 @@ def forward(self, x): return F.silu(self.w1(x)) * F.relu(self.w2(x)) dtypes = [] - if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + if is_mkldnn_bf16_supported(self.device): dtypes.append(torch.bfloat16) - if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) def matcher_check_fn(): @@ -4310,14 +4379,28 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) -@dynamo_config.patch( - { - "dynamic_shapes": True, - "assume_static_by_default": False, - "specialize_float": True, - } -) class TestDynamicPatternMatcher(TestPatternMatcherBase): + test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary + test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = ( + TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias + ) + + def setUp(self): + super().setUp() + self.ctx_stack.enter_context( + # When testing kernel counts, unspecializing float causes wobbling of our tests because + # we end up reusing the same compiled region across tests. Thus we purposely specialize floats + # here since we primarily care about number of kernels generated in the absence of compile + # caching. + dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } + ) + ) + @xfailIfACL def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" @@ -4480,10 +4563,10 @@ def matcher_check_fn(): instantiate_device_type_tests( - TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") + TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu", "xpu") ) instantiate_device_type_tests( - TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") + TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu", "xpu") ) instantiate_parametrized_tests(TestPatternMatcher) if __name__ == "__main__": diff --git a/test/inductor/test_mmdecomp.py b/test/inductor/test_mmdecomp.py index 71c81e6083cc2c..05b7afe0d2ccc2 100644 --- a/test/inductor/test_mmdecomp.py +++ b/test/inductor/test_mmdecomp.py @@ -152,7 +152,7 @@ def test_bmm_batch2_last_dim_size_is_one(self, device): @parametrize("dtype", [torch.float, torch.bfloat16, torch.int]) def test_some(self, device, dtype): # this Pytorch data type is not fully supported on cuda today - # - unfortunately we can't skipIf because we don't see the actual parms in skipIf + # - unfortunately we can't skipIf because we don't see the actual params in skipIf if device.startswith(GPU_TYPE) and dtype == torch.int: return @@ -172,7 +172,7 @@ def test_some(self, device, dtype): @parametrize("bs", [1, 2, 4, 10]) def test_some_batched(self, device, dtype, bs): # this Pytorch data type is not fully supported on cuda today - # - unfortunately we can't skipIf because we don't see the actual parms in skipIf + # - unfortunately we can't skipIf because we don't see the actual params in skipIf if device.startswith(GPU_TYPE) and dtype == torch.int: return diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index c9b6ece0c221aa..219cea49e11d37 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -6,7 +6,7 @@ import numpy as np import torch -from torch.testing import make_tensor +from torch.testing import FileCheck, make_tensor from torch.testing._internal.common_dtype import get_all_dtypes from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -85,65 +85,6 @@ def foo(x): def test_cast(self, dtype): self.common(lambda a: a.to(dtype), (torch.rand(1024),)) - pointwise_unary_ops = [ - "i0", - "i0e", - "i1", - "i1e", - "erf", - "digamma", - "sinc", - "spherical_bessel_j0", - "bessel_j0", - "bessel_j1", - "bessel_y0", - "bessel_y1", - "modified_bessel_i0", - "modified_bessel_i1", - "modified_bessel_k0", - "modified_bessel_k1", - "scaled_modified_bessel_k0", - "scaled_modified_bessel_k1", - "entr", - ] - - @parametrize("op_name", pointwise_unary_ops) - def test_pointwise_unary_op(self, op_name): - self.common( - lambda x: getattr(torch.special, op_name)(x), - (torch.rand(128, 128),), - check_lowp=False, - ) - - def test_pointwise_polygamma(self): - self.common( - torch.special.polygamma, - ( - 1, - torch.rand(128, 128), - ), - check_lowp=False, - ) - - @parametrize( - "op_name", - [ - "zeta", - "xlog1py", - "chebyshev_polynomial_t", - "chebyshev_polynomial_u", - "chebyshev_polynomial_v", - "chebyshev_polynomial_w", - "hermite_polynomial_he", - ], - ) - def test_pointwise_binary_op(self, op_name): - self.common( - lambda x, y: getattr(torch.special, op_name)(x, y), - (torch.rand(128, 128), torch.rand(128, 128)), - check_lowp=False, - ) - def test_broadcast(self): self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024))) @@ -180,8 +121,25 @@ def fn(x, y): ), ) + def test_cholesky(self): + def fn(x): + return ( + torch.linalg.cholesky(x, upper=False), + torch.linalg.cholesky(x, upper=True), + ) + + self.common(fn, (torch.eye(64),), check_lowp=False) + class MPSBasicTestsAOTI(TestCase): + def check_model(self, m, inp, dynamic_shapes=None): + res2 = m(*inp) + ep = torch.export.export(m, inp, dynamic_shapes=dynamic_shapes) + path = torch._inductor.aoti_compile_and_package(ep) + m = torch._inductor.aoti_load_package(path) + res = m(*inp) + assert torch.allclose(res, res2) + def test_add_mps(self): class M(torch.nn.Module): def forward(self, x, y): @@ -189,12 +147,114 @@ def forward(self, x, y): inp = (torch.ones(3, 3, device="mps"), torch.ones(3, 3, device="mps")) m = M().to("mps") - res2 = m(*inp) - ep = torch.export.export(m, inp) - path = torch._inductor.aoti_compile_and_package(ep, "here.pt2") - m = torch._inductor.aoti_load_package(path) - res = m(*inp) - assert torch.allclose(res, res2) + self.check_model(m, inp) + + def test_fallback_mps(self): + class M(torch.nn.Module): + def forward(self, x, y): + return torch.nn.functional.linear(x, y) + + inp = ( + torch.randn(10, 10, device="mps"), + torch.randn(10, 10, device="mps"), + ) + m = M().to("mps") + self.check_model(m, inp) + + def test_c10(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2) + + inp = (torch.randn(2, 8, device="mps"),) + m = M().to("mps") + self.check_model(m, inp) + + def test_two_const(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.y = torch.ones(3, 3, device="mps") + self.z = torch.full((3, 3), 2, device="mps") + + def forward(self, x): + return x + self.y + self.z + + inp = (torch.ones(3, 3, device="mps"),) + m = Model().to(device="mps") + self.check_model(m, inp) + + def test_simple_dynamic(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + add_0 = x + y + return torch.nn.functional.relu(input=add_0, inplace=False) + + x = torch.randn(128, 2048, device="mps") + y = torch.randn(128, 2048, device="mps") + inp = (x, y) + + m = Model().to(device="mps") + dim0_x = torch.export.Dim("dim0_x", min=1, max=2048) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}} + + self.check_model(m, inp, dynamic_shapes) + + def test_dynamic_cat(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + return torch.cat([a, b], dim=0) + + a = torch.randn(2, 4, device="mps") + b = torch.randn(3, 4, device="mps") + inp = (a, b) + m = Model().to(device="mps") + + dim0_a = torch.export.Dim("dim0_a", min=1, max=10) + dim0_b = torch.export.Dim("dim0_b", min=1, max=20) + dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}} + self.check_model(m, inp, dynamic_shapes) + + def test_reuse_kernel(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + a = torch.sin(x) + b = torch.mm(a, y) + c = torch.sin(b) + d = torch.mm(b, c) + return d + + example_inputs = ( + torch.randn(87, 87, device="mps"), + torch.randn(87, 87, device="mps"), + ) + model = Model() + + ep = torch.export.export(model, example_inputs) + package_path = torch._export.aot_compile(ep.module(), example_inputs) + + target_str = 'mps_lib_0.getKernelFunction("generated_kernel")' + target_count = 1 + + with open(os.path.splitext(package_path)[0] + ".cpp") as cpp: + src_code = cpp.read() + FileCheck().check_count( + target_str, + target_count, + exactly=True, + ).run(src_code) if __name__ == "__main__": diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index 78c8f7b5ea0ae8..9adc68a2ac763b 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -115,12 +115,13 @@ def mock_run(self, *args, **kwargs): picked_kernel = self.picked_kernel return out - with unittest.mock.patch.object( - MultiKernelCall, "run", mock_run - ), unittest.mock.patch.object( - MultiKernelCall, - "benchmark_sub_kernels", - lambda *args, **kwargs: mock_latency, + with ( + unittest.mock.patch.object(MultiKernelCall, "run", mock_run), + unittest.mock.patch.object( + MultiKernelCall, + "benchmark_sub_kernels", + lambda *args, **kwargs: mock_latency, + ), ): torch.compile(f)(x) self.assertEqual(picked_kernel, force_kernel) @@ -190,8 +191,8 @@ def test_batchnorm_training(self): once for input and once for output. They are ruled out as in-out argument because they are considered as graph inputs. - Multi-kernel previously assumes that we never pass the same argument mutli times - for a kernel. No mater if we change inductor behavior to assure that, it's better + Multi-kernel previously assumes that we never pass the same argument multi times + for a kernel. No matter if we change inductor behavior to assure that, it's better to make multi-kernel being able to handle those cases. """ bn = nn.BatchNorm2d(3).to(GPU_TYPE) @@ -231,7 +232,7 @@ def f(x, y): def test_reduction_scratch_buffer(self, force_multi_kernel=1): """ - The explicited realized buffer in the test function will be passed in + The explicitly realized buffer in the test function will be passed in as a scratch buffer for the non-persistent reduction kernel but can be skipped for the persistent reduction kernel. diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 8668e7ee1b7671..458d64aa41d5b9 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -78,7 +78,18 @@ def run(op, args, kwargs): args = (sample_input.input,) + sample_input.args kwargs = sample_input.kwargs out = run(op.get_op(), args, kwargs) - out_c = torch.compile(run)(op.get_op(), args, kwargs) + + # test_configs.runtime_triton_dtype_assert does not work well with dynamic shape so far. + # Consider the following cases for torch.add: + # both lhs/rhs are int32 tensor, there is also a integer alpha argument. + # In dynamic shape case, alpha is passed in as an ks0 argument. To be safe, + # we use tl.int64 for ks0's dtype. + # But the dtype for alpha is also decided as tl.int32 during lowering when + # we promote alpha to a ir.Constant. + # Ideally to resolve this problem, we should track assignment like + # alpha = ks0 + # so that we know alpha is actually tl.int64 rather than tl.int32. + out_c = torch.compile(run, dynamic=False)(op.get_op(), args, kwargs) self.assertEqual(out, out_c) @requires_gpu() diff --git a/test/inductor/test_ordered_set.py b/test/inductor/test_ordered_set.py index 305f6efbc9ab10..cbbb9bda56c731 100644 --- a/test/inductor/test_ordered_set.py +++ b/test/inductor/test_ordered_set.py @@ -156,8 +156,8 @@ def f(s1, s2): "Pure python equivalent of isdisjoint()" return not OrderedSet(s1).intersection(s2) - for larg in "", "a", "ab", "abc", "ababac", "cdc", "cc", "efgfe", "ccb", "ef": - s1 = self.thetype(larg) + for large in "", "a", "ab", "abc", "ababac", "cdc", "cc", "efgfe", "ccb", "ef": + s1 = self.thetype(large) for rarg in ( "", "a", @@ -235,7 +235,8 @@ def test_symmetric_difference(self): self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) for C in OrderedSet, frozenset, dict.fromkeys, str, list, tuple: self.assertEqual( - self.thetype("abcba").symmetric_difference(C("cdc")), OrderedSet("abd") + self.thetype("abcba").symmetric_difference(C("cdc")), + OrderedSet("abd"), # codespell:ignore ) self.assertEqual( self.thetype("abcba").symmetric_difference(C("efgfe")), @@ -651,7 +652,7 @@ def test_symmetric_difference_update(self): ) self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) for p, q in ( - ("cdc", "abd"), + ("cdc", "abd"), # codespell:ignore ("efgfe", "abcefg"), ("ccb", "a"), ("ef", "abcef"), diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index 8d71fd74d613c5..bcd1519c59350f 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -13,7 +13,7 @@ should_pad_mm_bf16, ) from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code +from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA @@ -362,7 +362,7 @@ def foo(x, y): self.assertEqual(out, inps[0] @ inps[1]) @inductor_config.patch(force_shape_pad=True) - @fresh_inductor_cache() + @fresh_cache() def test_pad_addmm_2d_bias(self): @torch.compile() def foo(input, x, y): @@ -400,9 +400,9 @@ def test_pad_batch(self): expected_alignment = get_alignment_size(mat1) assert expected_alignment == 8, "Alignment for float16 should be 8" - assert should_pad_common( - mat1, mat2 - ), "This should pass the common padding criteria" + assert should_pad_common(mat1, mat2), ( + "This should pass the common padding criteria" + ) @torch.compile() def bmm(mat1, mat2): @@ -415,11 +415,11 @@ def bmm(mat1, mat2): ".run(", 2, exactly=True ).check("empty_strided_cuda((3, 8, 16)").run(code) - assert torch.allclose( - res2, bmm_expected_result - ), "BMM results are not identical" + assert torch.allclose(res2, bmm_expected_result), ( + "BMM results are not identical" + ) - @fresh_inductor_cache() + @fresh_cache() def test_exclude_padding(self): @torch.compile() def mm(a, b): @@ -448,7 +448,7 @@ def mm(a, b): repr(local_cache) ) - @fresh_inductor_cache() + @fresh_cache() @inductor_config.patch(max_pointwise_cat_inputs=2) def test_exclude_cat_padding(self): @torch.compile() @@ -475,7 +475,7 @@ def mm(inps, b): "No perf regression on H100+ with BF16", ) @skipIfRocm - @fresh_inductor_cache() + @fresh_cache() @inductor_config.patch( post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} ) @@ -488,12 +488,12 @@ def test_pad_mm_bf16(self): expected_alignment = get_alignment_size(mat1) assert expected_alignment == 8, "Alignment for bfloat16 should be 8" - assert should_pad_common( - mat1, mat2 - ), "This should pass the common padding criteria" - assert should_pad_mm_bf16( - mat1.dtype, m, n, k - ), "This should pass the should_pad_mm_bf16 padding criteria" + assert should_pad_common(mat1, mat2), ( + "This should pass the common padding criteria" + ) + assert should_pad_mm_bf16(mat1.dtype, m, n, k), ( + "This should pass the should_pad_mm_bf16 padding criteria" + ) @torch.compile() def mm(mat1, mat2): @@ -508,7 +508,7 @@ def mm(mat1, mat2): assert torch.allclose(res2, mm_expected_result), "MM results are not identical" - @fresh_inductor_cache() + @fresh_cache() @inductor_config.patch( { "triton.unique_kernel_names": "original_aten", diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 74eb018ca806c3..15c1abdf32db22 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -8,11 +8,11 @@ from torch import nn, Tensor from torch._dynamo.convert_frame import maybe_cprofile from torch._dynamo.device_interface import get_interface_for_device -from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import ceildiv, run_and_get_code from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -97,7 +97,13 @@ def setUpClass(cls): if HAS_GPU: cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() cls.prior_default_device = torch.get_default_device() - torch.set_float32_matmul_precision("high") + # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. + # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the + # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. + if torch.version.hip: + torch.set_float32_matmul_precision("highest") + else: + torch.set_float32_matmul_precision("high") torch.set_default_device(GPU_TYPE) @classmethod @@ -363,7 +369,7 @@ def test_longformer(self, bs=4): @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled") def test_longformer_small_bs(self): """ - The model exists in both HF and TB. In TB it uses a samller batch size. + The model exists in both HF and TB. In TB it uses a smaller batch size. """ self.test_longformer(bs=2) @@ -404,7 +410,7 @@ def pad_mm(a, b, align=16): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_padmm(self): """ - Latency between origional matmul and padded matmul: 2.717 v.s. 2.356 + Latency between original matmul and padded matmul: 2.717 v.s. 2.356 """ mat1_pad = torch.randn(8192, 30522, dtype=torch.float16) mat2_pad = torch.randn(30522, 768, dtype=torch.float16) @@ -428,7 +434,7 @@ def g(): pad_time = benchmarker.benchmark_gpu(g) print( - f"Latency between origional matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}" + f"Latency between original matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}" ) self.do_profiling(f, g, "No MM Padding", "With mm padding") @@ -481,7 +487,7 @@ def test_LinearAndSoftmax_codegen(self, bias=True): self.assertEqual( m_bad_shape.linear.weight.grad, m_bad_shape_opt.linear.weight.grad ) - self.assertTrue(len(wrapper_codes) == 2) # one for forward and oen for backward + self.assertTrue(len(wrapper_codes) == 2) # one for forward and one for backward forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 636794f47f11b0..ded14ce5ba870e 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -775,6 +775,32 @@ def f(x): joint_graph.joint_graph_passes(gm) self.assertEqual(count_calls(gm.graph), 2) + # handle negative 1 in size argument of view + def f(x): + x = aten.view.default(x, [3, 5, 7]) + x = aten.view.default(x, [-1, 7]) + return x + + gm = make_fx(f)(x) + self.assertEqual(count_calls(gm.graph), 2) + joint_graph.joint_graph_passes(gm) + self.assertEqual(count_calls(gm.graph), 0) + + def test_pointless_view_pair_dynamic_shapes(self): + def f(x): + s1, s2 = x.shape + x = aten.view.default(x, [-1]) + x = aten.view.default(x, [s1, s2]) + return x + + x = torch.randn(15, 7, device=GPU_TYPE) + torch._dynamo.decorators.mark_unbacked(x, 0) + + out = torch.compile(f, dynamic=True)(x) + self.assertTrue(torch.equal(x, out)) + + self.assertEqual(counters["inductor"]["removed_pointless_view_pair"], 1) + def test_pointless_permute_pair(self): def f(x): x = aten.permute.default(x, [1, 0]) @@ -978,7 +1004,7 @@ def fn(a, b, c): ] self.common(fn, args, 0, 0) - # cat and split lenghts are different + # cat and split lengths are different def fn(a, b, c): cat = torch.ops.aten.cat.default([a, b, c], 1) split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [5, 5], 1) @@ -1132,15 +1158,19 @@ def fn5(x, y): torch.randn(5, 5, device=GPU_TYPE), ] - with unittest.mock.patch( - "torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options", - {"test": {}}, - ), unittest.mock.patch( - "torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS", - [], - ), unittest.mock.patch( - "torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS", - {"test": test_pass}, + with ( + unittest.mock.patch( + "torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options", + {"test": {}}, + ), + unittest.mock.patch( + "torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS", + [], + ), + unittest.mock.patch( + "torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS", + {"test": test_pass}, + ), ): for fn in (fn0, fn1, fn2, fn3, fn4, fn5): counter = 0 diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index d83aae1e675709..0ca54257250f67 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -646,7 +646,7 @@ def f(a): @patch.object(config, "pattern_matcher", False) def test_fusion_choice4_cpu(self): - # Fuse nodes with same number of elements and compatible orginal var ranges + # Fuse nodes with same number of elements and compatible original var ranges # [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1 def f(x, w): o1 = x * w diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index ff557874881e44..3d54c378de4a2f 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -185,6 +185,8 @@ def fn(x, y): def test_inductor_profiling_triton_hooks(self): from triton.compiler import CompiledKernel # @manual + from torch._inductor.runtime.triton_compat import knobs + hooks_called = {"enter": False, "exit": False} def launch_enter_hook(lazy_dict): @@ -193,8 +195,12 @@ def launch_enter_hook(lazy_dict): def launch_exit_hook(lazy_dict): hooks_called["exit"] = True - CompiledKernel.launch_enter_hook = launch_enter_hook - CompiledKernel.launch_exit_hook = launch_exit_hook + if knobs: + knobs.runtime.launch_enter_hook = launch_enter_hook + knobs.runtime.launch_exit_hook = launch_exit_hook + else: + CompiledKernel.launch_enter_hook = launch_enter_hook + CompiledKernel.launch_exit_hook = launch_exit_hook def fn(x, y): return torch._foreach_add(x, y) diff --git a/test/inductor/test_scatter_optimization.py b/test/inductor/test_scatter_optimization.py index 2fa1e9db14469b..3e8561020fe68f 100644 --- a/test/inductor/test_scatter_optimization.py +++ b/test/inductor/test_scatter_optimization.py @@ -180,9 +180,9 @@ def f(m, x, label): ref_grad = ref_model.weight.grad opt_f(opt_model, x, label) act_grad = opt_model.weight.grad - assert torch.allclose( - ref_grad, act_grad, atol=1e-3, rtol=1e-3 - ), f"{ref_grad=}\n{act_grad=}" + assert torch.allclose(ref_grad, act_grad, atol=1e-3, rtol=1e-3), ( + f"{ref_grad=}\n{act_grad=}" + ) self.check_metric() diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index 44c101864fcbda..fe897eaded3129 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -38,9 +38,9 @@ def skip_cache(self, choices, name, key, benchmark): def wrapped(*args, **kwargs): counters.clear() torch.manual_seed(12345) - assert ( - not torch.backends.cuda.matmul.allow_tf32 - ), "correctness testing is allergic to tf32" + assert not torch.backends.cuda.matmul.allow_tf32, ( + "correctness testing is allergic to tf32" + ) return fn(*args, **kwargs) return wrapped @@ -51,6 +51,8 @@ def setUp(self): super().setUp() if not is_big_gpu(): return self.skipTest("Need a big GPU to run max_autotune=True") + # Clear preprocessing functions to ensure clean state + select_algorithm.clear_preprocessing_fns() @patches def test_linear_relu(self): @@ -83,6 +85,37 @@ def foo(input, weight, bias): foo(*inps) self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @patches + def test_preprocessing_single_choice(self): + # pass a list to the preprocessing function to assert that it was + # actually called + func_called = [False] + + # Register a preprocessing function that returns only the first choice + # This in turn will lead to autotuning being skipped as it's a single + # choice, and the counter itself will not be bumped + def return_first_choice_only(choices): + func_called[0] = True + return choices[:1] if choices else [] + + select_algorithm.add_preprocessing_fn(return_first_choice_only) + + @torch.compile + def foo(input, weight, bias): + return torch.addmm(bias, input, weight) + + inps = ( + torch.randn(20, 33, device=GPU_TYPE), + torch.randn(33, 16, device=GPU_TYPE), + torch.randn(20, 16, device=GPU_TYPE), + ) + + foo(*inps) + # Since we only have one choice, autotuning should be skipped + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) + # The preprocessing function should have been called + self.assertTrue(func_called[0]) + @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2)) @patches def test_addmm_fp16(self): diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index 571c0ec9f0a1cb..c57393d993eabf 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -56,7 +56,7 @@ class TestCase(InductorTestCase): """ Helper methods to compare runtime estimate against 0. Since this estimate is hardware dependent, - stronger comparisons may fail dependending on the host's specs. + stronger comparisons may fail depending on the host's specs. atol/rtol must be provided explicitly with each call, since precision/rel_tol overrides are not always utilized """ diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 238e5ae9272523..4286bdfda7cd9c 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -115,6 +115,33 @@ def normalize_reshape_with_dynamic_shape(x): ) counters.clear() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "normalization_pass": {}, + }, + post_grad_fusion_options={}, + ) + def test_cat_normalization(self): + def caoncat_only(x): + return torch.concat(list(torch.split(x, 2, 1)), dim=1) + + args = [ + torch.randn(2, 32), + ] + for fn, dynamic, expected_cat_norm_count in [ + (caoncat_only, False, 2), + ]: + expected = fn(*args) + actual = torch.compile(fn, dynamic=dynamic)(*args) + + torch.testing.assert_close(actual, expected) + self.assertEqual( + counters["inductor"]["normalization_pass"], + expected_cat_norm_count, + msg=f"for {fn}", + ) + counters.clear() + @patch def test_consecutive_split_merge(self): def multi_split(x): diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index a2e59989b66b33..477d5ac2e6c207 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -14,6 +14,7 @@ from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda +from torch.torch_version import TorchVersion @requires_cuda @@ -156,10 +157,13 @@ def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64): compiled_kernel = floats[1,](*args) launcher = self._make_launcher(compiled_kernel) - # TODO: in Pytorch's pinned version of triton, arg3 is typed as regular float - # but in triton 3.3.0, this is fixed and it's 0ffd. We'll need to update later. - self.assertEqual(launcher.arg_tys, "Offf") - self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda")) + if TorchVersion(triton.__version__) >= TorchVersion("3.4.0"): + self.assertEqual(launcher.arg_tys, "Offd") + else: + self.assertEqual(launcher.arg_tys, "Offf") + # TODO this line fails on Triton 3.4.0 (https://github.com/triton-lang/triton/issues/6176) + # Add the check back when this is fixed in Triton + # self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda")) new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") device_interface = get_interface_for_device("cuda") stream = device_interface.get_raw_stream(device_interface.current_device()) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 749df214a169f5..ea7fa654583478 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -30,6 +30,7 @@ import torch._dynamo.config as dynamo_config import torch._inductor.aoti_eager import torch.nn as nn +from torch._C._dynamo.guards import assert_alignment, assert_size_stride from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.debug_utils import aot_graph_input_parser from torch._dynamo.device_interface import get_interface_for_device @@ -71,6 +72,7 @@ PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, + SM90OrLater, TEST_CUDNN, tf32_on_and_off, with_tf32_off, @@ -189,7 +191,7 @@ def _large_cumprod_input(shape, dim, dtype, device): - # Construct a cumprod input which guaruntees not to overflow or underflow + # Construct a cumprod input which guarantees not to overflow or underflow if is_integer_dtype(dtype): # Large products don't fit in integers, the best we can do # is random +/-1 values to test the sign of the result @@ -526,9 +528,9 @@ def reference_to_expect(actual_flat, correct_flat): if reference_in_float and exact_dtype: for expect_dtype, actual_result in zip(expect_dtypes, actual_flat): if expect_dtype is not None: - assert ( - actual_result.dtype == expect_dtype - ), f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}" + assert actual_result.dtype == expect_dtype, ( + f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}" + ) if reference_in_float: correct_flat = reference_to_expect(actual_flat, correct_flat) @@ -1409,7 +1411,14 @@ def fn(a, b): ) _, code = run_and_get_code(fn, x, y) code = " ".join(code) - self.assertEqual( + assert_keywords = ["assert_size_stride", "assert_alignment"] + filtered_lines = [ + line + for line in code.splitlines() + if not any(assert_key in line for assert_key in assert_keywords) + ] + code = "\n".join(filtered_lines) + self.assertGreaterEqual( code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3 ) @@ -1724,6 +1733,12 @@ def fn(a): self.common(fn, (torch.randn(1024),)) + def test_index_remainder(self): + def fn(x, y): + return x[y % 12] + + self.common(fn, (torch.rand(1024), torch.randint(50, (50,)))) + @xfailIfS390X @config.patch(debug_index_asserts=False) @config.patch("cpp.enable_tiling_heuristics", False) @@ -2471,17 +2486,8 @@ def fn(x, y): z = x * y return z.sum((0, 1)) - atol = None - rtol = None - - # By default, inductor generate non-persistent reduction kernels in this - # case. But when multi-kernel is enabled, inductor will pick the faster - # of persistent reduction and non-persistent-reduction kernel. - # In this case, inductor picked the persistent-reduction kernel. - # The persistent reduction kernel happens to need looser tolerance. - if config.triton.multi_kernel: - atol = 1e-5 - rtol = 1e-5 + atol = 1e-3 + rtol = 1e-3 self.common( fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)), atol=atol, rtol=rtol ) @@ -2558,7 +2564,6 @@ def fn(x): self.common(fn, (torch.ones(32, 32) * 70,)) @skip_if_halide - @xfail_if_mps_unimplemented # aten::_cummin_helper is not implemented for MPS def test_cummin(self): def fn(x): return x.cummin(0) @@ -2645,7 +2650,6 @@ def make_tensor(shape): inp = torch.full((2, n), float("inf"), device=self.device, dtype=_dtype) self.assertEqual(cfn(inp), fn(inp)) - @xfail_if_mps_unimplemented @xfail_if_triton_cpu def test_logcumsumexp(self): def fn(x): @@ -2672,7 +2676,6 @@ def fn(x): rtol=1e-5, ) - @xfail_if_mps_unimplemented def test_logcumsumexp_zero_dim(self): def fn(x): return x.logcumsumexp(0), x.logcumsumexp(-1) @@ -3864,9 +3867,7 @@ def forward(self, x): with self.assertRaisesRegex(RuntimeError, msg): with torch.no_grad(): torch.compile(fn)(t) - # TODO: Autograd internal assertion - msg = r".*isDifferentiableType\(variable.scalar_type\(\)\) INTERNAL ASSERT FAILED.*" - with self.assertRaisesRegex(RuntimeError, msg): + with self.assertRaisesRegex(RuntimeError, "Autograd not support dtype:.*"): torch.compile(fn)(t) @unittest.skipIf( @@ -4105,15 +4106,16 @@ def test_conv_inference_heuristics(self): def foo(m, inp): return m(inp) - with torch.no_grad(): - _, code = run_and_get_code(foo, grouped_conv, input_tensor) - # no to channels last permuting before kernel - if config.cpp_wrapper: - FileCheck().check_not(" call_triton").check("_convolution(").run( - code[0] - ) - else: - FileCheck().check_not(".run(").check(".convolution(").run(code[0]) + if self.device != "xpu": + with torch.no_grad(): + _, code = run_and_get_code(foo, grouped_conv, input_tensor) + # no to channels last permuting before kernel + if config.cpp_wrapper: + FileCheck().check_not(" call_triton").check("_convolution(").run( + code[0] + ) + else: + FileCheck().check_not(".run(").check(".convolution(").run(code[0]) # in out should do channels last in inference in_channels = 8 @@ -4283,7 +4285,7 @@ def fn(a): (torch.randn([2, 20, 2]),), ) - # It's a view so it doens't generate a kernel + # It's a view so it doesn't generate a kernel @expectedFailureCodegenDynamic def test_slice3(self): def fn(a, b): @@ -4376,9 +4378,7 @@ def fn2(a): ) @parametrize("dilation", (1, 2)) - @parametrize( - "dim", (subtest(2), subtest(3, decorators=[xfail_if_mps_unimplemented])) - ) + @parametrize("dim", (subtest(2), subtest(3))) def test_low_memory_max_pool(self, dilation: int, dim: int): prims = torch.ops.prims @@ -5526,6 +5526,13 @@ def fn(x): (torch.randn([1, 2, 4, 8]),), ) + def test_var_mean_div_by(self): + def fn(x): + var, mean = torch.var_mean(x, dim=2, keepdim=True) + return x / var, var, mean + + self.common(fn, (torch.rand([1, 17, 2048]),)) + def test_var_correction(self): def fn(x): dim = -1 @@ -6317,22 +6324,10 @@ def matmul_with_op(x, y, fn): # test no-op fns = ( - lambda x: x - + torch.zeros( - [256, 256], dtype=torch.float32, device=x.device - ), # noqa: E731 - lambda x: x - - torch.zeros( - [256, 256], dtype=torch.float32, device=x.device - ), # noqa: E731 - lambda x: x - * torch.ones( - [256, 256], dtype=torch.float32, device=x.device - ), # noqa: E731 - lambda x: x - / torch.ones( - [256, 256], dtype=torch.float32, device=x.device - ), # noqa: E731 + lambda x: x + torch.zeros([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 + lambda x: x - torch.zeros([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 + lambda x: x * torch.ones([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 + lambda x: x / torch.ones([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 ) inps = [torch.rand([256, 256], device=self.device) for _ in range(2)] @@ -7192,7 +7187,6 @@ def fn(a): self.common(fn, (torch.randn([2, 4, 37, 38]),)) - @xfail_if_mps_unimplemented def test_upsample_nearest3d(self): def fn(a): return ( @@ -8268,7 +8262,6 @@ def fn(x): actual_out = compiled_fn(view) self.assertEqual(reference_out.stride(), actual_out.stride()) - @xfail_if_triton_cpu def test_like_channels_last(self): def foo(): randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) @@ -10236,19 +10229,19 @@ def test_zero_dim_reductions(self): for kd in [True, False]: inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] - for fo in failed_ops: + for op in failed_ops: with self.assertRaisesRegex( IndexError, "Expected reduction dim 1 to have non-zero size" ): - mod = make_fx(fo)(*inps0) + mod = make_fx(op)(*inps0) _ = compile_fx_inner(mod, inps0) pass_ops = [ lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] ] - for po in pass_ops: - compiled = torch.compile(po, backend="inductor") - expected = po(*inps0) + for op in pass_ops: + compiled = torch.compile(op, backend="inductor") + expected = op(*inps0) actual = compiled(*inps0) self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) @@ -10389,7 +10382,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with TestRefMode(): fn_compiled(inps) - # for some reason, TorchDispatch doesnt capture the + # for some reason, TorchDispatch doesn't capture the # cuda mm call (even without cudagraphs) if self.device == "cpu": self.assertTrue(matmul_seen) @@ -10505,6 +10498,50 @@ def f(x): self.assertEqual(out_ref.stride(), out_test.stride()) self.assertEqual(x_ref, x_test) + @requires_gpu() + @skip_if_not_triton + @unittest.skipIf( + not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" + ) + def test_inductor_multiple_specializations(self): + from triton.testing import do_bench + + @torch.compile( + options={ + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, + dynamic=False, + ) + def inductor_matmul(a, b): + torch._check(a.shape[0] == b.shape[1]) + return (m, torch.mm(a, b)) + + m = 16 + k = 1280 + dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16) + dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16) + torch._dynamo.decorators.mark_dynamic( + dynamic_a, + 0, + ) + torch._dynamo.decorators.mark_dynamic( + dynamic_specialized_a, + 0, + specialize_on=[lambda x0: x0 == 16], + ) + torch._dynamo.decorators.mark_dynamic( + b, + 1, + ) + dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b)) + torch._dynamo.reset() + dynamic_specialized = do_bench( + lambda: inductor_matmul(dynamic_specialized_a, b) + ) + self.assertGreaterEqual(dynamic, dynamic_specialized) + @requires_gpu() def test_stride_preservation_with_stride_modifying_fx_pass(self): def f(x): @@ -10512,7 +10549,7 @@ def f(x): def custom_pass(g: torch.fx.Graph) -> None: """ - Applies `lamda x: x.t().contiguous().t()` to the output. + Applies `lambda x: x.t().contiguous().t()` to the output. """ output_node = g.find_nodes(op="output")[0] assert len(output_node.args) == 1 @@ -10667,7 +10704,6 @@ def fn(arg0_1): @skip_if_halide # log2 not yet implemented @skip_if_triton_cpu # log2 implemented only in Dec 2024 - @expectedFailureXPU # Remmove this after the known issue of Intel Triton #3871 resolved. def test_pow_by_natural_log2_dynamic_shapes(self): @torch.compile(dynamic=True) def fn(x): @@ -11401,7 +11437,7 @@ def fn(x, size, memory_format): @staticmethod def _cases_resize_as_common(): for x, y_size, memory_format in CommonTemplate._cases_resize_common(): - # each sizes /memory_format combintation tested in 2 ways: + # each sizes /memory_format combination tested in 2 ways: # 1. y is contiguous fn gets memory_format kwargs # 2. y has memory_format contiguity and fn gets preserve kwarg # 3. y has some other strides (not contiguous or channels last) and fn gets preserve @@ -11448,7 +11484,6 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="inductor") same(fn(x, y), opt_fn(x_clone, y)) - @xfail_if_mps_unimplemented @xfail_if_triton_cpu def test_erfc(self): def fn(x): @@ -11879,6 +11914,98 @@ def fn(x): check_lowp=False, ) + @requires_gpu() + @skip_if_not_triton + @skip_if_cpp_wrapper("skip cpp_wrapper tests") + @config.patch(implicit_fallbacks=True) + def test_generated_code_has_size_stride_assert(self): + def foo(x): + return 3 * x + + def foo_meta(x): + return torch.empty_like(x) + + define_custom_op_for_test("foo", foo, foo_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.foo(a) + return b + + a = torch.randn((16, 32), device=self.device) + + _, code = run_and_get_code( + torch.compile(fn), + a, + ) + if not is_dynamic_shape_enabled(): + if code and len(code) > 0 and "assert_size_stride(" in code[0]: + try: + FileCheck().check_regex( + r"assert_size_stride\s*\(\s*[^,]+,\s*\([^\)]*\),\s*\([^\)]*\),\s*'[^']+'\s*\)" + ).run(code[0]) + except Exception as e: + print(f"Failed regex match for assert_size_stride: {e}") + print(code[0]) + raise e + else: + print("Skipping: No assert_size_stride found.") + + @requires_gpu() + @skip_if_not_triton + @skip_if_cpp_wrapper("skip cpp_wrapper tests") + @config.patch(implicit_fallbacks=True) + def test_generated_code_has_alignment_assert(self): + def foo(x): + return 3 * x + + def foo_meta(x): + return torch.empty_like(x) + + define_custom_op_for_test("foo", foo, foo_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.foo(a) + return b + + a = torch.randn((16, 32), device=self.device) + + _, code = run_and_get_code( + torch.compile(fn), + a, + ) + if not is_dynamic_shape_enabled(): + if code and len(code) > 0 and "assert_alignment(" in code[0]: + try: + FileCheck().check_regex( + r"assert_alignment\s*\(\s*[^,]+,\s*[^,]+,\s*'[^']+'\s*\)" + ).run(code[0]) + except Exception as e: + print(f"Failed regex match for assert_alignment: {e}") + print(code[0]) + raise e + else: + print("Skipping: No assert_alignment found.") + + def test_assert_size_stride_op_name_pass(self): + tensor = torch.empty((16, 32)) + assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name") + + def test_assert_size_stride_op_name_fail(self): + tensor = torch.empty((16, 32)) + with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"): + assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name") + + def test_assert_alignment_op_name_pass(self): + tensor = torch.empty((16, 32)) + assert_alignment(tensor, 16, "torch.ops.dummy.op_name") + + def test_assert_alignment_op_name_fail(self): + tensor = torch.empty((16, 32)) + with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"): + assert_alignment(tensor, 0, "torch.ops.dummy.op_name") + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) @torch._inductor.config.patch(implicit_fallbacks=True) def test_custom_op_unbacked_symints(self): @@ -12007,7 +12134,7 @@ def fn(x): # a new test case. self.assertEqual(len(bar_strides), 1) if self.device == "mps" and MACOS_VERSION < 15.0: - # Before MacOS15 contigous output were returned regardless of input + # Before MacOS15 contiguous output were returned regardless of input self.assertEqual(bar_strides[0], expected_stride) else: self.assertNotEqual(bar_strides[0], expected_stride) @@ -12639,7 +12766,6 @@ def fn(x): or name not in [ "airy_ai", - "erfc", "erfcx", "gammainc", "gammaincc", @@ -12647,10 +12773,6 @@ def fn(x): "legendre_polynomial_p", "log_ndtr", "ndtri", - "shifted_chebyshev_polynomial_t", - "shifted_chebyshev_polynomial_u", - "shifted_chebyshev_polynomial_v", - "shifted_chebyshev_polynomial_w", ] else self.assertRaises(NotImplementedError) ) @@ -12726,7 +12848,11 @@ def test_generate_rand_fp8(self): self.assertTrue(t.dtype is torch.float8_e4m3fn) @largeTensorTest("1GB", inductor=True) - def test_large_grid(self): + @parametrize( + "use_block_ptr", + [subtest(False), subtest(True, decorators=[skip_if_not_triton])], + ) + def test_large_grid(self, use_block_ptr): # https://github.com/pytorch/pytorch/issues/123210 def fn(primals_5): view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) @@ -12739,9 +12865,11 @@ def fn(primals_5): s0 = 16777472 s1 = 8 - compiled_fn = torch.compile(fn) - actual = compiled_fn(torch.ones(s0, s1, device=self.device)) - self.assertTrue((actual == 1).all()) + + with config.patch({"triton.use_block_ptr": use_block_ptr}): + compiled_fn = torch.compile(fn) + actual = compiled_fn(torch.ones(s0, s1, device=self.device)) + self.assertTrue((actual == 1).all()) @skip_if_gpu_halide def test_pattern_matcher_multi_user(self): @@ -12937,7 +13065,7 @@ def __init__(self, dim): def forward(self, x): x = self.conv_t(x) - x = torch.sigmoid(x) # tigger condition + x = torch.sigmoid(x) # trigger condition return x for dim in (1, 2, 3): @@ -13012,12 +13140,12 @@ def f(x): code = run_and_get_triton_code(f, x) if is_dynamic_shape_enabled(): - FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check( - "assert_size_stride(buf2, (s77, s27), (s27, 1))" + FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check( + "assert_size_stride(buf2, (s77, s27), (s27, 1)" ).run(code) else: - FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check( - "assert_size_stride(buf2, (16, 32), (32, 1))" + FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check( + "assert_size_stride(buf2, (16, 32), (32, 1)" ).run(code) @requires_cuda @@ -13318,8 +13446,9 @@ def get_same_padding(x: int, k: int, s: int, d: int): def pad_same(x, k, s, d=(1, 1), value=0): ih, iw = x.size()[-2:] - pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding( - iw, k[1], s[1], d[1] + pad_h, pad_w = ( + get_same_padding(ih, k[0], s[0], d[0]), + get_same_padding(iw, k[1], s[1], d[1]), ) if pad_h > 0 or pad_w > 0: x = torch.nn.functional.pad( @@ -13335,6 +13464,35 @@ def pad_same(x, k, s, d=(1, 1), value=0): ref = pad_same(x, (5, 5), (2, 2)) self.assertEqual(res, ref, atol=0, rtol=0) + @skip_if_halide # only 32-bit indexing + @largeTensorTest("16GB", inductor=True) + def test_split_reduction_with_int64_size(self): + if torch._inductor.config.cpu_backend == "triton": + raise unittest.SkipTest( + "Fail for triton cpu backend with error: https://gist.github.com/shunting314/a873fb32b6b7b5a437f44280ae86839f" + ) + + if self.device == "cpu": + raise unittest.SkipTest( + "The test fails some times on CI: " + "https://github.com/pytorch/pytorch/actions/runs/15333913377/job/43153170162. " + "Skip for now." + ) + + size = (30000, 100000) + + # rand rather than randn since the mean for the latter is close to 0 + # which happens to be close to the value generated by the bug. + t = torch.rand(size, dtype=torch.float, device=self.device) + op = torch.mean + expected = op(t) + actual = torch.compile(op)(t) + # self.common takes more GPU memory. Do the check directly + self.assertTrue( + torch.allclose(expected, actual, atol=1e-2, rtol=1e-2), + f"{expected=} {actual=}", + ) + def test_remove_noop_view_default(self): def f(x): batch_size = x.shape[0] @@ -13398,6 +13556,32 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar ignore_empty_lines=True, ) + @config.patch("min_num_split", 256) + @xfail_if_mps # TypeError: cannot determine truth value of Relational + def test_split_reduction_dynamic_shape(self): + from torch._dynamo.decorators import mark_dynamic + + def f(x): + # outer reduction + return x.sum(dim=0) + + N = 512 + x_small = torch.randn(4096, N, device=self.device) + + mark_dynamic(x_small, 0) + expect = f(x_small) + opt_f = torch.compile(f, dynamic=True) + actual = opt_f(x_small) + self.assertTrue(torch.allclose(expect, actual, atol=1e-3, rtol=1e-3)) + + if DO_PERF_TEST: + from triton.testing import do_bench + + # benchmark for a much larger input + x_large = torch.randn(4096 * 1000, N, device=self.device) + ms = do_bench(lambda: opt_f(x_large)) + print(f"{ms=:.3f}") + @expectedFailureCodegenDynamic def test_special_polygamma(self): fn = torch.special.polygamma @@ -13406,6 +13590,42 @@ def test_special_polygamma(self): self.common(fn, (1, x)) self.common(fn, (2, x)) + @skip_if_triton + @skip_if_halide + @config.patch({"freezing": True}) + def test_dont_constant_fold(self): + from torch._inductor.constant_folding import ( + add_dont_constant_fold, + clear_dont_constant_fold, + ) + + m = 5 + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.randn(m) + self.s = torch.randn(m) + + def forward(self, x): + return self.w * self.s + x + + x = torch.rand(m) + mod = M() + for dont_constant_fold in [True, False]: + clear_dont_constant_fold() + if dont_constant_fold: + add_dont_constant_fold(torch.ops.aten.mul.Tensor) + with torch.no_grad(): + refe_out = mod(x) + mod = torch.compile(mod) + test_out, (code,) = run_and_get_code(mod, x) + if dont_constant_fold: + FileCheck().check("cpp_fused_add_mul").run(code) + else: + FileCheck().check("cpp_fused_add_0").run(code) + self.assertEqual(refe_out, test_out) + @dataclasses.dataclass class TestFailure: @@ -13414,9 +13634,7 @@ class TestFailure: __test__: bool = False -def copy_tests( - my_cls, other_cls, suffix, test_failures=None, xfail_prop=None -): # noqa: B902 +def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): if name.startswith("test_"): # You cannot copy functions in Python, so we use closures here to @@ -13762,6 +13980,51 @@ def forward( ) torch._inductor.aot_compile(traced, inputs) + @skipCUDAIf(not SM90OrLater, "Requires sm90") + @requires_cuda + @unittest.skipIf(TEST_WITH_ROCM, "no grouped_mm support") + @config.patch(implicit_fallbacks=True) + def test_grouped_mm(self): + @torch.compile(fullgraph=True) + def f(a, b, offs, out_dtype): + return torch._grouped_mm( + a, b.transpose(-2, -1), offs=offs, out_dtype=out_dtype + ) + + device = "cuda" + dtype = torch.bfloat16 + + m, n, k, n_groups = 16, 32, 16, 4 + a_ref = torch.randn(m * n_groups, k, device=device, dtype=dtype)[:, :k] + + b_ref = torch.randn( + n_groups, + n, + k, + device=device, + dtype=dtype, + )[::1, :, :k] + + offs = torch.arange( + m, n_groups * m + 1, m, device=device, dtype=torch.int32 + ) + + a_ref.requires_grad_(True) + b_ref.requires_grad_(True) + + a_test = a_ref.clone().detach().requires_grad_() + b_test = b_ref.clone().detach().requires_grad_() + + out_ref = f(a_ref, b_ref, offs, out_dtype=torch.bfloat16) + out_ref.sum().backward() + + out_test = f(a_test, b_test, offs=offs, out_dtype=torch.bfloat16) + out_test.sum().backward() + + self.assertEqual(out_ref, out_test) + self.assertEqual(a_ref.grad, a_test.grad) + self.assertEqual(b_ref.grad, b_test.grad) + def test_optimize_indexing_assert(self): def has_indirect(code, tl_fn: str): self.assertTrue( @@ -13779,7 +14042,7 @@ def has_indirect(code, tl_fn: str): def has_assert(code, lower: bool, upper: bool): self.assertIn( - "device_assert", code, msg=f"No device asert found:\n{code}" + "device_assert", code, msg=f"No device assert found:\n{code}" ) for line in code.split("\n"): if "device_assert" in line: @@ -13989,6 +14252,8 @@ def f(x, mask): # it does not move the tensor constructor to cuda and keeps it on CPU. self.assertFalse("empty_strided_cuda(()" in code) + # only uncoalesced without this :) + @config.patch("triton.coalesce_tiling_analysis", False) @config.patch("triton.use_block_ptr", False) def test_evict_last_non_coalesced_loads(self): @torch.compile @@ -14039,6 +14304,7 @@ def f(a, b): ) @config.patch("triton.use_block_ptr", True) + @config.patch("triton.coalesce_tiling_analysis", False) def test_evict_last_non_coalesced_loads_block_ptr(self): @torch.compile def f(a, b): @@ -14534,9 +14800,7 @@ def forward(self, x): B, T, C, - ) = ( - x.size() - ) # batch size, sequence length, embedding dimensionality (n_embd) + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) q, k, v = qkv.split(self.n_embd, dim=2) @@ -14619,9 +14883,9 @@ def __init__(self, config): def forward(self, idx, targets): device = idx.device b, t = idx.size() - assert ( - t <= self.config.block_size - ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + assert t <= self.config.block_size, ( + f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + ) pos = torch.arange( 0, t, dtype=torch.long, device=device ) # shape (t) @@ -14812,18 +15076,41 @@ def f(x, y): return x1 + y1 + z + y_cpu.to(GPU_TYPE) f_compiled = torch.compile(f) - x, y = torch.ones(3, 3, device=self.device), torch.randn( - 3, 3, device=self.device + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), ) compiled_out = f_compiled(x, y) self.assertEqual(compiled_out, f(x, y)) - x, y = torch.ones(4, 4, device=self.device), torch.randn( - 4, 4, device=self.device + x, y = ( + torch.ones(4, 4, device=self.device), + torch.randn(4, 4, device=self.device), ) compiled_out = f_compiled(x, y) self.assertEqual(compiled_out, f(x, y)) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device=self.device) + eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device=self.device) + compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) @config.patch(implicit_fallbacks=True) @torch._inductor.config.patch("graph_partition", True) @@ -14874,8 +15161,9 @@ def f(x, y): return x1 + y1 + z + y_cpu.to(GPU_TYPE) f_compiled = torch.compile(f) - x, y = torch.ones(3, 3, device=self.device), torch.randn( - 3, 3, device=self.device + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), ) torch._dynamo.decorators.mark_unbacked(x, 0) @@ -14896,8 +15184,9 @@ def f(x, y, integer): return x1 + y1 + z + y_cpu.to(GPU_TYPE) f_compiled = torch.compile(f) - x, y = torch.ones(3, 3, device=self.device), torch.randn( - 3, 3, device=self.device + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), ) torch._dynamo.decorators.mark_unbacked(x, 0) @@ -14995,6 +15284,9 @@ def f(x): self.assertIn("aoti_torch_check_inf_and_nan", code) else: self.assertIn("# make sure graph inputs are not nan/inf", code) + self.assertRegex(code, r"return_vars = (.*)") + self.assertIn("for var in return_vars:", code) + self.assertIn("if isinstance(var, torch.Tensor):", code) self.assertRegex(code, r"assert not .*\.isnan\(\)\.any\(\).item\(\)") self.assertRegex(code, r"assert not .*\.isinf\(\)\.any\(\).item\(\)") diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 6e9dce76a2b8b7..2c66765d9aaf97 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -242,7 +242,6 @@ def run(*ex, **kwargs): "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), - "test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)), "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_single_elem_dynamic_shapes": TestFailure(("cpu",)), @@ -389,7 +388,7 @@ def run(*ex, **kwargs): test_failures.update( { "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( - ("cuda", "xpu") if IS_LINUX else ("cpu", "cuda", "xpu") + ("cuda") if IS_LINUX else ("cpu", "cuda", "xpu") ), } ) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index a6851ea5f9927d..b75907894f63f7 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -12,8 +12,6 @@ import torch.library from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches from torch._inductor import metrics -from torch._inductor.codegen.common import device_codegens, register_backend_for_device -from torch._inductor.codegen.cpp import CppScheduling from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_code @@ -34,7 +32,12 @@ TEST_WITH_ASAN, TEST_WITH_ROCM, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_GPU, + patch_inductor_backend, +) # Make the helper files in test/ importable @@ -58,11 +61,10 @@ "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu") ), - "test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")), } if not torch._inductor.config.cpp_wrapper: test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure( - ("cuda", "xpu") + ("cuda",) ) if TEST_WITH_ROCM: @@ -367,18 +369,20 @@ def f(x): @torch._dynamo.config.patch(capture_scalar_outputs=True) @torch._inductor.config.patch(implicit_fallbacks=True) def test_item_to_inputs_kernel_nobreak(self, device): - @torch.library.custom_op("test::foo", mutates_args=()) - def foo(x: torch.Tensor, y: int) -> torch.Tensor: + @torch.library.custom_op( + "test_inductor_dynamic_shapes::nobreak_test", mutates_args=() + ) + def nobreak_test(x: torch.Tensor, y: int) -> torch.Tensor: return x.clone() - @foo.register_fake + @nobreak_test.register_fake def _(x: torch.Tensor, y: int) -> torch.Tensor: return x.clone() @torch.compile(fullgraph=True) def f(x, r): y = x.item() - return torch.ops.test.foo(r, y) + return torch.ops.test_inductor_dynamic_shapes.nobreak_test(r, y) f(torch.tensor([3], device=device), torch.randn(10, device=device)) @@ -591,11 +595,13 @@ def f(x): ) @torch._inductor.config.patch(implicit_fallbacks=True) def test_multi_output_unbacked_custom_op(self, device): - @torch.library.custom_op("test::foo", mutates_args=()) - def foo(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + @torch.library.custom_op( + "test_inductor_dynamic_shapes::unbacked_test", mutates_args=() + ) + def unbacked_test(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty(2, device=x.device), torch.empty(3, device=x.device) - @foo.register_fake + @unbacked_test.register_fake def _(x: torch.Tensor) -> torch.Tensor: ctx = torch.library.get_ctx() u0 = ctx.new_dynamic_size() @@ -603,7 +609,7 @@ def _(x: torch.Tensor) -> torch.Tensor: @torch.compile(fullgraph=True) def f(x): - a, b = torch.ops.test.foo(x) + a, b = torch.ops.test_inductor_dynamic_shapes.unbacked_test(x) return a.sum() + b.sum() f(torch.tensor([3], device=device)) @@ -643,8 +649,9 @@ def get_same_padding(x: int, k: int, s: int, d: int): def pad_same(x, k, s, d=(1, 1), value=0): ih, iw = x.size()[-2:] - pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding( - iw, k[1], s[1], d[1] + pad_h, pad_w = ( + get_same_padding(ih, k[0], s[0], d[0]), + get_same_padding(iw, k[1], s[1], d[1]), ) if pad_h > 0 or pad_w > 0: x = torch.nn.functional.pad( @@ -913,7 +920,9 @@ def _test_wrapper_codegen_statically_known_int_or_none_in_context(): # testing fn_2 assert ( PythonWrapperCodegen.statically_known_int_or_none(batch_dim) == 5 - ), "Should be limited to exactly 5 on second call due to multiple constraints" + ), ( + "Should be limited to exactly 5 on second call due to multiple constraints" + ) elif call_count == 2: # testing fn_3 assert ( @@ -928,23 +937,13 @@ def generate(self, is_inference, *args, **kwargs): _test_wrapper_codegen_statically_known_int_or_none_in_context() return super().generate(is_inference, *args, **kwargs) - if "cpu" not in device_codegens: - register_backend_for_device("cpu", CppScheduling, PythonWrapperCodegen) - orig_cpu_codegens = device_codegens["cpu"] - try: - register_backend_for_device( - "cpu", orig_cpu_codegens.scheduling, TestWrapperCodegen - ) + with patch_inductor_backend("cpu", python_wrapper_codegen=TestWrapperCodegen): # Compile each of the functions above, with an example input # that has 5 in the first dimension, but is marked as dynamic torch.compile(backend="inductor", dynamic=None)(fn_1)(_x) torch.compile(backend="inductor", dynamic=None)(fn_2)(_x) torch.compile(backend="inductor", dynamic=None)(fn_3)(_x) - finally: - register_backend_for_device( - "cpu", orig_cpu_codegens.scheduling, orig_cpu_codegens.wrapper_codegen - ) @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_item_unbacked_stride_nobreak(self, device): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index f764bea717b368..2b541968455a65 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -136,7 +136,7 @@ def maybe_truncate(x, length=80): if idx >= 0: x = f"{x[:idx]}..." if len(x) > length: - return f"{x[:length - 3]}..." + return f"{x[: length - 3]}..." return x reasons = sorted(set(map(maybe_truncate, failed_reasons[key]))) @@ -391,7 +391,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "rtol": 1e-4, }, ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, - # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors + # Following tests are failing with strict comparison but atol=1 is acceptable due roundings errors ("nn.functional.interpolate.bilinear", u8): {"atol": 1, "rtol": 0}, ("nn.functional.upsample_bilinear", u8): {"atol": 1, "rtol": 0}, ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, @@ -420,6 +420,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("cumsum", f16): {"reference_in_float": True}, "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, + ("logcumsumexp", f16): {"grad_atol": 3e-3, "grad_rtol": 0.01}, "exponential": {"reference_in_float": True}, "geometric": {"reference_in_float": True}, ("kron", f16): {"reference_in_float": True}, @@ -429,6 +430,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, ("nn.functional.cosine_similarity", f16): {"reference_in_float": True}, ("nn.functional.instance_norm", f16): {"reference_in_float": True}, + ("nn.functional.linear", f16): {"atol": 3e-4, "rtol": 0.01}, ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, ("nn.functional.rms_norm", f16): {"reference_in_float": True}, @@ -521,6 +523,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy ("angle", f64): {"reference_in_float": True}, ("asin", f16): {"reference_in_float": True}, + ("asin", f32): {"reference_in_float": True, "atol": 1e-4, "rtol": 1e-4}, ("atanh", f16): {"reference_in_float": True}, "cauchy": {"reference_in_float": True}, ("cummax", f16): {"atol": 5e-4, "rtol": 0.002}, @@ -605,7 +608,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("var_mean", f16): {"atol": 1e-5, "rtol": 2e-3}, ("var_mean.unbiased", f16): {"atol": 1e-5, "rtol": 2e-3}, ("vdot", f16): {"atol": 1e-5, "rtol": 2e-3}, - # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors + # Following tests are failing with strict comparison but atol=1 is acceptable due roundings errors # High atol due to precision loss ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, @@ -850,6 +853,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.max_pool2d": {f16, f32, f64}, "nn.functional.bilinear": {f16}, "nn.functional.conv_transpose1d": {f16}, "nn.functional.conv_transpose2d": {f16}, @@ -1034,9 +1038,7 @@ def test_comprehensive(self, device, dtype, op): op_name, set() ) or dtype in inductor_gradient_expected_failures_single_sample[ device_type - ].get( - op_name, set() - ): + ].get(op_name, set()): test_expect = ExpectedTestResult.XFAILURE # noqa: F841 else: test_expect = ExpectedTestResult.SUCCESS # noqa: F841 diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index d029b9b6ee7bc0..940bc24dbd12e4 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -1,17 +1,22 @@ # Owner(s): ["module: inductor"] # ruff: noqa: F841 import contextlib +import dataclasses import importlib +import math import unittest from typing import Any, Callable, Optional, Union import torch import torch.utils._pytree as pytree from torch._inductor import config +from torch._inductor.choices import InductorChoices +from torch._inductor.codegen.triton import FixedTritonConfig from torch._inductor.runtime.hints import TRITON_MAX_BLOCK -from torch._inductor.runtime.runtime_utils import is_power_of_2 +from torch._inductor.runtime.runtime_utils import get_max_y_grid, is_power_of_2 from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import run_and_get_code +from torch._inductor.virtualized import V from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -1141,6 +1146,148 @@ def foo(x, y, z): # Singleton splits should be discarded. self._assert_pointwise_ndims(triton_code, 2) + # Integration test to ensure that matched dims & strides from match_mod_div_expr + # are unsigned and signed integers respectively. This test case has the following + # index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2)) + # and the match below is a candidate that is invalid: + # match={ + # dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16, + # dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0 + # } + # This is now fixed by ensuring that that wild symbols only match integers + def test_ensure_integral_dims_and_strides(self): + def model(data, *args): + return torch.nn.functional.unfold(data, *args) + + data = torch.zeros( + [2, 3, 5, 5], dtype=torch.float16, requires_grad=True, device=self.device + ) + args = [2, 1, 0, 1] + run_and_compare( + self, + model, + data, + *args, + expected_num_triton_kernels=2, + expected_num_block_pointers=4, + compile_kwargs={"fullgraph": True}, + ) + + # Integration test to test block analysis with index expressions using + # negative strides. + # This test case has the following index: + # index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) + # - 16*(ModularIndexing(xindex, 8, 8)) + 1911 + # subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8)) + # Block analysis should produce the following: + # BlockParameters( + # shape=[8, 8, 8], + # block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK) ], + # strides=[-256, -16, -1], + # offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)] + # ) + # constant_offset = 1911 + def test_negative_strides(self): + def model(x, y): + # Slice in reverse order via a negative stride + return torch.flip(x, [0, 1, 2]) + y + + x, y = ( + self._discontiguous_tensor((8, 8, 8), device=self.device) for _ in range(2) + ) + run_and_compare( + self, + model, + x, + y, + expected_num_triton_kernels=1, + expected_num_block_pointers=3, + ) + + @config.patch("triton.prefer_nd_tiling", True) + @config.patch("triton.max_tiles", 3) + @parametrize( + "block_multiple, ynumel_exceed_ygrid_size, include_z", + [ + # No boundary check in all dimensions + [True, False, True], + # No xdim boundary check, ydim is checked since > max_ygrid + # z dim can be used since its not included + [True, True, False], + # Boundary check in all dimensions + # skip triton_cpu very slow test > 1000s + subtest( + [False, False, True], decorators=[test_torchinductor.skip_if_triton_cpu] + ), + ], + ) + def test_boundary_check(self, block_multiple, ynumel_exceed_ygrid_size, include_z): + @dataclasses.dataclass + class InputShape: + x: int + y: int + z: Optional[int] = None + + def to_list(self): + out = [self.y, self.x] + if self.z is not None: + out.insert(0, self.z) + return out + + BLOCK_SIZE = 8 + DIM_SIZE = BLOCK_SIZE if block_multiple else BLOCK_SIZE + 1 + shape = InputShape(DIM_SIZE, DIM_SIZE, DIM_SIZE if include_z else None) + if ynumel_exceed_ygrid_size: + shape.y = math.ceil(get_max_y_grid()) * shape.y + shape.y + + # Use fixed block sizes to avoid having to generate very large input tensors + class FixedBlockSizeChoices(InductorChoices): + def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs): + block_sizes = { + f"{prefix.upper()}BLOCK": BLOCK_SIZE + for prefix, size in dataclasses.asdict(shape).items() + if size is not None + } + kernel_kwargs["fixed_config"] = FixedTritonConfig(block_sizes) + return kernel_kwargs + + a = self._discontiguous_tensor(shape.to_list(), device=self.device) + b_shape = shape.to_list() + b_shape[-1] = 1 + b = self._discontiguous_tensor(b_shape, device=self.device) + + def func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + with V.set_choices_handler(FixedBlockSizeChoices()): + result, code = run_and_compare( + self, + func, + a, + b, + expected_num_triton_kernels=1, + expected_num_block_pointers=3, + ) + + code = code[0] + if block_multiple: + if ynumel_exceed_ygrid_size: + self.assertIn( + "yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK", + code, + ) + # Only the y dimension should be boundary checked + # a, b, and output + self.assertEqual(code.count("boundary_check=[0]"), 3) + else: + # No boundary checking + self.assertNotIn("boundary_check", code) + else: + # Loading a + self.assertTrue("boundary_check=[0, 1, 2]" in code) + # Loading b + self.assertTrue("boundary_check=[0, 1]" in code) + @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend") @config.patch(cpu_backend="triton") diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index a9b085f9995f46..a9f898a36af557 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] +import functools import sys import unittest from unittest.mock import MagicMock, patch @@ -8,7 +9,7 @@ from torch._dynamo.testing import rand_strided from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC from torch._inductor.utils import clone_preserve_strides -from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu +from torch.testing._internal.common_utils import IS_LINUX, runOnRocm, skipIfXpu from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, @@ -42,6 +43,30 @@ from torch._inductor.test_case import run_tests, TestCase +@triton.jit +def amd_sqr_kernel(in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + data = tl.load(in_ptr + offsets, mask=offsets < numel) + sqr = data * data + tl.store(out_ptr + offsets, sqr, mask=offsets < numel) + + +@functools.lru_cache +def get_autotuned_amd_sqr_kernel(): + return triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE": 64, + "waves_per_eu": 3, + } + ) + ], + key=[], + )(amd_sqr_kernel) + + class TestTritonHeuristics(TestCase): device_type = GPU_TYPE @@ -211,6 +236,32 @@ def test_template_function_ws(self): self.assertEqual(configs[0].num_consumer_groups, num_consumer_groups) self.assertEqual(configs[0].num_buffers_warp_spec, num_buffers_warp_spec) + @runOnRocm + def test_amd_special_config_args(self): + """ + waves_per_eu is an example of a special config arg on AMD; if it is explicitly specified + in a config, the kwarg will exist in the kwargs but not in the function signature. + """ + + @torch.library.triton_op("test_triton_heuristics::triton_sqr", mutates_args=()) + def triton_sqr(x: torch.Tensor) -> torch.Tensor: + y = torch.empty_like(x) + + def grid(meta): + return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + + torch.library.wrap_triton(get_autotuned_amd_sqr_kernel())[grid]( + x, y, x.numel() + ) + + def fn(x): + return triton_sqr(x) + + x = torch.randn(32, device="cuda") + ref = fn(x) + res = torch.compile(fn)(x) + self.assertEqual(ref, res) + class TestArgumentCloneAndRestore(TestCase): # Our tensor is large enough. If a unexpected copy happens, the diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 68e5bcb3a61fa8..689cf218b2bcd0 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -25,13 +25,22 @@ from torch._library import capture_triton from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import parametrize, skipIfWindows, skipIfXpu +from torch.testing._internal.common_utils import ( + parametrize, + skipIfRocm, + skipIfWindows, + skipIfXpu, +) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU from torch.testing._internal.logging_utils import log_settings, logs_to_string # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 -from torch.utils._triton import has_triton_package, has_triton_tma +from torch.utils._triton import ( + has_triton_experimental_host_tma, + has_triton_package, + has_triton_tensor_descriptor_host_tma, +) if HAS_GPU: @@ -996,7 +1005,7 @@ def _mul2(x): def f(x): for _ in range(4): # The output of one kernel is the input to the next kernel, but - # at some point we should re-use buffers not allocate new ones. + # at some point we should reuse buffers not allocate new ones. x = _mul2(x) return x + 1 @@ -1014,7 +1023,7 @@ def f(x): num_bufs_allocated = code.count(code_string) self.assertEqual(num_bufs_allocated, 2) - # Check we're re-using buffers if not allocating. + # Check we're reusing buffers if not allocating. num_bufs_reused = code.count( "// reuse" if inductor_config.cpp_wrapper else "# reuse" ) @@ -1299,10 +1308,10 @@ def f(x, y): else: if dynamic: # when half_n_elements passed to the Triton kernel is - # dynamic, equal_to_1 specializaiton can't be enforced + # dynamic, equal_to_1 specialization can't be enforced # also, equal_to_1 specialization doesn't occur (or appear in the signature) - # for newer versions ofo triton (i.e. the ones where triton_version_uses_attrs_dict() == True) + # for newer versions of triton (i.e. the ones where triton_version_uses_attrs_dict() == True) self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0]) else: self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0]) @@ -1641,6 +1650,65 @@ def f(x): self.assertEqual(eager_out, expected_out) self.assertEqual(compiled_out, expected_out) + @requires_gpu + @common_utils.parametrize("dynamic", [False, True]) + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_on_device_tma(self, dynamic, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_on_device_tma_new_api + if tma_version == "new" + else add_kernel_on_device_tma_old_api + ) + + def f(a, b): + BLOCK_SIZE = 32 + out = torch.zeros_like(a) + m, n = out.size() + + # Allocate workspace for on-device TMA descriptors + # Need 128 bytes per descriptor, 3 descriptors total + if tma_version == "old": + workspace = torch.zeros(3 * 128, dtype=torch.uint8, device=a.device) + else: + workspace = None + + grid = lambda meta: ( + triton.cdiv(m, meta["BLOCK_SIZE"]), + triton.cdiv(n, meta["BLOCK_SIZE"]), + ) + + kernel[grid]( + a, + b, + out, + m, + n, + workspace, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn((32, 32), device=GPU_TYPE) + b = torch.randn((32, 32), device=GPU_TYPE) + + expected_out = a + b + triton.set_allocator( + lambda size, align, stream: torch.empty( + size, dtype=torch.int8, device=GPU_TYPE + ) + ) + eager_out = f(a, b) + compiled_out = torch.compile(f, fullgraph=True, dynamic=dynamic)(a, b) + + self.assertEqual(eager_out, expected_out) + self.assertEqual(compiled_out, expected_out) + @requires_gpu @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -1683,12 +1751,22 @@ def f(x, y, z): self.assertEqual(out3, z**2) @requires_gpu - @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") @common_utils.parametrize("dynamic", [False, True]) - def test_tma_capture_and_functionalize(self, dynamic): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_tma_capture_and_functionalize(self, dynamic, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table kernel_side_table.reset_table() + kernel = ( + add_kernel_with_tma_1d_new_api + if tma_version == "new" + else add_kernel_with_tma_1d_old_api + ) def f(a, b): BLOCK_SIZE = 256 @@ -1696,17 +1774,14 @@ def f(a, b): n_elements = out.numel() desc_a, desc_b, desc_out = ( - triton.tools.experimental_descriptor.create_1d_tma_descriptor( - t.data_ptr(), - n_elements, - BLOCK_SIZE, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE], new_api=(tma_version == "new") ) for t in (a, b, out) ) grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_with_tma_1d[grid]( + kernel[grid]( desc_a, desc_b, desc_out, @@ -1719,6 +1794,7 @@ def f(a, b): b = torch.randn(301, device=GPU_TYPE) backend = torch._dynamo.testing.AotEagerAndRecordGraphs() + _ = f(a, b) torch.compile( f, fullgraph=True, @@ -1727,34 +1803,70 @@ def f(a, b): )(a, b) if dynamic: - self.assertExpectedInline( - backend.fw_graphs[0].code.strip(), - """\ + if tma_version == "new": + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1): + zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False) + add_2 = arg0_1 + 256; arg0_1 = None + sub_1 = add_2 - 1; add_2 = None + floordiv = sub_1 // 256; sub_1 = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ('stable', ([256],)), 'in_desc_ptr1': ('stable', ([256],)), 'out_desc_ptr': ('stable', ([256],))}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg1_1 = arg2_1 = zeros_like = None + getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None + return (getitem,)""", + ) + elif tma_version == "old": + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ def forward(self, arg0_1, arg1_1, arg2_1): zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False) add_2 = arg0_1 + 256 sub_1 = add_2 - 1; add_2 = None floordiv = sub_1 // 256; sub_1 = None - triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([arg0_1], [256], 4), 'in_desc_ptr1': ([arg0_1], [256], 4), 'out_desc_ptr': ([arg0_1], [256], 4)}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ('experimental', ([arg0_1], [256], 4)), 'in_desc_ptr1': ('experimental', ([arg0_1], [256], 4)), 'out_desc_ptr': ('experimental', ([arg0_1], [256], 4))}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None return (getitem,)""", - ) + ) else: - self.assertExpectedInline( - backend.fw_graphs[0].code.strip(), - """\ + if tma_version == "new": + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ def forward(self, arg0_1, arg1_1): zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False) - triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([301], [256], 4), 'in_desc_ptr1': ([301], [256], 4), 'out_desc_ptr': ([301], [256], 4)}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ('stable', ([256],)), 'in_desc_ptr1': ('stable', ([256],)), 'out_desc_ptr': ('stable', ([256],))}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None return (getitem,)""", - ) + ) + elif tma_version == "old": + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False) + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ('experimental', ([301], [256], 4)), 'in_desc_ptr1': ('experimental', ([301], [256], 4)), 'out_desc_ptr': ('experimental', ([301], [256], 4))}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None + getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None + return (getitem,)""", + ) @requires_gpu - @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") @common_utils.parametrize("after_data_ptr", [False, True]) @common_utils.parametrize("after_create_desc", [False, True]) - def test_tma_graph_breaks(self, after_data_ptr, after_create_desc): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_tma_graph_breaks(self, after_data_ptr, after_create_desc, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_1d_new_api + if tma_version == "new" + else add_kernel_with_tma_1d_old_api + ) + def f(a, b): BLOCK_SIZE = 256 out = torch.zeros_like(a) @@ -1764,11 +1876,8 @@ def f(a, b): torch._dynamo.graph_break() descs = [ - triton.tools.experimental_descriptor.create_1d_tma_descriptor( - t.data_ptr(), - n_elements, - BLOCK_SIZE, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE], new_api=(tma_version == "new") ) for t in (a, b, out) ] @@ -1777,7 +1886,7 @@ def f(a, b): torch._dynamo.graph_break() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_with_tma_1d[grid]( + kernel[grid]( *descs, BLOCK_SIZE=BLOCK_SIZE, ) @@ -1800,27 +1909,35 @@ def f(a, b): self.assertEqual(compiled_out, expected_out) @requires_gpu - @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) - def test_tma_descriptor_1d(self, dynamic, backend): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_tma_descriptor_1d(self, dynamic, backend, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_1d_new_api + if tma_version == "new" + else add_kernel_with_tma_1d_old_api + ) + def f(a, b): BLOCK_SIZE = 256 out = torch.zeros_like(a) n_elements = out.numel() desc_a, desc_b, desc_out = ( - triton.tools.experimental_descriptor.create_1d_tma_descriptor( - t.data_ptr(), - n_elements, - BLOCK_SIZE, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE], new_api=(tma_version == "new") ) for t in (a, b, out) ) grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_with_tma_1d[grid]( + kernel[grid]( desc_a, desc_b, desc_out, @@ -1845,25 +1962,33 @@ def f(a, b): self.assertEqual(compiled_out, expected_out) @requires_gpu - @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") - def test_tma_descriptor_dedup(self): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_tma_descriptor_dedup(self, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_1d_new_api + if tma_version == "new" + else add_kernel_with_tma_1d_old_api + ) + def f(a): BLOCK_SIZE = 256 out = torch.zeros_like(a) n_elements = out.numel() desc_a, desc_out = ( - triton.tools.experimental_descriptor.create_1d_tma_descriptor( - t.data_ptr(), - n_elements, - BLOCK_SIZE, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE], new_api=(tma_version == "new") ) for t in (a, out) ) grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - add_kernel_with_tma_1d[grid]( + kernel[grid]( desc_a, desc_a, desc_out, @@ -1890,13 +2015,27 @@ def f(a): self.assertEqual(compiled_out, expected_out) # 2 calls: one for two inputs (dedupped), one for the output - self.assertEqual(code.count("create_1d_tma_descriptor("), 2) + if tma_version == "new": + self.assertEqual(code.count("TensorDescriptor.from_tensor("), 2) + else: + self.assertEqual(code.count("create_1d_tma_descriptor("), 2) @requires_gpu - @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager"]) - def test_tma_descriptor_2d(self, dynamic, backend): + @common_utils.parametrize("tma_version", ["new", "old"]) + def test_tma_descriptor_2d(self, dynamic, backend, tma_version): + if tma_version == "new" and not has_triton_tensor_descriptor_host_tma(): + self.skipTest("requires triton.tools.tensor_descriptor TMA support") + if tma_version == "old" and not has_triton_experimental_host_tma(): + self.skipTest("requires triton.tools.experimental_descriptor TMA support") + + kernel = ( + add_kernel_with_tma_2d_new_api + if tma_version == "new" + else add_kernel_with_tma_2d_old_api + ) + def f(a, b): BLOCK_SIZE_X = 16 BLOCK_SIZE_Y = 32 @@ -1904,13 +2043,8 @@ def f(a, b): x_size, y_size = out.size() desc_a, desc_b, desc_out = ( - triton.tools.experimental_descriptor.create_2d_tma_descriptor( - t.data_ptr(), - x_size, - y_size, - BLOCK_SIZE_X, - BLOCK_SIZE_Y, - t.element_size(), + create_tensor_descriptor_shim( + t, [BLOCK_SIZE_X, BLOCK_SIZE_Y], new_api=(tma_version == "new") ) for t in (a, b, out) ) @@ -1919,7 +2053,7 @@ def f(a, b): triton.cdiv(x_size, meta["BLOCK_SIZE_X"]), triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]), ) - add_kernel_with_tma_2d[grid]( + kernel[grid]( desc_a, desc_b, desc_out, @@ -2337,15 +2471,63 @@ def fn(x): "'BLOCK_SIZE': 'constexpr'" ).run(code[0]) + @requires_gpu + @inductor_config.patch({"triton.autotune_at_compile_time": True}) + @parametrize("quotes", ["single", "double"]) + def test_kernel_with_docstring(self, quotes): + kernel = ( + kernel_with_docstring_single_quotes + if quotes == "single" + else kernel_with_docstring_double_quotes + ) + + # https://github.com/pytorch/pytorch/issues/155006 + def fn(sz): + x = torch.empty(sz, device=GPU_TYPE) + BLOCK_SIZE = 32 + grid = (triton.cdiv(sz, BLOCK_SIZE),) + kernel[grid](x, sz, BLOCK_SIZE) + return x + + actual = fn(345) + expected = torch.compile(fn, fullgraph=True)(345) + self.assertEqual(actual, expected) + + @requires_gpu + @skipIfRocm + @skipIfXpu + @inductor_config.patch({"triton.autotune_at_compile_time": True}) + @parametrize("quotes", ["single", "double"]) + def test_kernel_inline_asm(self, quotes): + kernel = ( + kernel_inline_asm_single_quotes + if quotes == "single" + else kernel_inline_asm_double_quotes + ) + + # https://github.com/pytorch/pytorch/issues/155006 + def fn(inp): + sz = inp.size(0) + x = torch.empty(sz, device=GPU_TYPE) + BLOCK_SIZE = 32 + grid = (triton.cdiv(sz, BLOCK_SIZE),) + kernel[grid](inp, x, sz, BLOCK_SIZE) + return x + + inp = torch.randn(345, device=GPU_TYPE) + actual = fn(inp) + expected = torch.compile(fn, fullgraph=True)(inp) + self.assertEqual(actual, expected) + def make_mutation_test(fn): @requires_gpu def test_fn(self): from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors - kernel, inputs, outputs = fn() + kernel, inputs, tma_descriptor_metadata, outputs = fn() self.assertListEqual( - identify_mutated_tensors(kernel, inputs), + identify_mutated_tensors(kernel, inputs, tma_descriptor_metadata), outputs, ) @@ -2397,6 +2579,7 @@ def add_kernel_out_of_order( "out_ptr": t, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2428,6 +2611,7 @@ def add_kernel_out_of_order_fn1( "out_ptr": t, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2458,7 +2642,7 @@ def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an): # old TTIR string parsing-based one). remove this gating # and use ["c_ptr"] as `expected` after the new Triton # pin lands both in OSS and internally. - ttir_module, _ = generate_ttir(kernel, kwargs) + ttir_module, _ = generate_ttir(kernel, kwargs, tma_descriptor_metadata={}) if hasattr(ttir_module, "walk"): # with MLIR-based Triton analysis pass expected = ["c_ptr"] @@ -2469,6 +2653,7 @@ def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an): return ( kernel, kwargs, + {}, expected, ) @@ -2499,7 +2684,7 @@ def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): # old TTIR string parsing-based one). remove this gating # and use ["c_ptr"] as `expected` after the new Triton # pin lands both in OSS and internally. - ttir_module, _ = generate_ttir(kernel, kwargs) + ttir_module, _ = generate_ttir(kernel, kwargs, tma_descriptor_metadata={}) if hasattr(ttir_module, "walk"): # with MLIR-based Triton analysis pass expected = ["c_ptr"] @@ -2510,6 +2695,7 @@ def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): return ( kernel, kwargs, + {}, expected, ) @@ -2555,7 +2741,7 @@ def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): # old TTIR string parsing-based one). remove this gating # and use ["out_ptr"] as `expected` after the new Triton # pin lands both in OSS and internally. - ttir_module, _ = generate_ttir(kernel, kwargs) + ttir_module, _ = generate_ttir(kernel, kwargs, tma_descriptor_metadata={}) if hasattr(ttir_module, "walk"): # with MLIR-based Triton analysis pass expected = ["out_ptr"] @@ -2566,6 +2752,7 @@ def cumsum_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): return ( kernel, kwargs, + {}, expected, ) @@ -2599,6 +2786,7 @@ def add_kernel_with_fn_call( "out_ptr": t, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2631,6 +2819,7 @@ def add_kernel_with_fn_call( "out_ptr": t, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2667,6 +2856,7 @@ def nested_cond_op_kernel( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2701,6 +2891,7 @@ def add_4_times_kernel( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2735,6 +2926,7 @@ def add_1_time_kernel( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2770,6 +2962,7 @@ def add_4_times_kernel( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2808,6 +3001,7 @@ def add_4_times_kernel( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2842,6 +3036,7 @@ def kernel_with_label( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ) @@ -2901,6 +3096,7 @@ def fwd_kernel( "BLOCK_SIZE_M": 64, "BLOCK_SIZE_C2": 64, }, + {}, ["O_ptr"], ) @@ -2960,6 +3156,7 @@ def fwd_kernel( "BLOCK_M": M, "BLOCK_N": N, }, + {}, ["o_ptr"], ) @@ -3021,6 +3218,7 @@ def fwd_kernel( "BLOCK_M": M, "BLOCK_N": N, }, + {}, ["o_ptr"], ) @@ -3066,9 +3264,111 @@ def branch_with_multiple_yield_args( "n_elements": 14, "BLOCK_SIZE": 16, }, + {}, ["out_ptr"], ) + def test_get_tma_stores(self): + from torch._higher_order_ops.triton_kernel_wrap import ( + get_tma_stores, + Intermediate, + Op, + Param, + ) + + functions = { + "helper": { + Intermediate(idx=0): [ + Op( + "tt.reinterpret_tensor_descriptor", + None, + [Param(idx=0)], + Intermediate(idx=0), + ) + ], + }, + "main": { + Intermediate(idx=-1): [ + Op( + "tt.call", + "helper", + [Param(idx=0), Param(idx=1)], + Intermediate(idx=-1), + ) + ], + }, + } + + self.assertEqual(get_tma_stores(functions, "helper"), set()) + self.assertEqual(get_tma_stores(functions, "main"), set()) + + functions["helper"][Intermediate(idx=-1)] = [ + Op( + "tt.experimental_descriptor_store", + None, + [Intermediate(idx=0), Param(idx=1)], + Intermediate(idx=-1), + ) + ] + get_tma_stores.reset() + + self.assertEqual( + get_tma_stores(functions, "helper"), {Param(idx=0), Intermediate(idx=0)} + ) + self.assertEqual(get_tma_stores(functions, "main"), {Param(idx=0)}) + + @unittest.skipIf( + not has_triton_experimental_host_tma(), + "requires experimental TMA descriptor API", + ) + @make_mutation_test + def test_add_kernel_on_device_tma_old_api(): + a = torch.randn(1024, 1024) + b = torch.randn(1024, 1024) + c = torch.empty(1024, 1024) + workspace = torch.empty(128 * 3, dtype=torch.int8) + return ( + add_kernel_on_device_tma_old_api, + { + "a_ptr": a, + "b_ptr": b, + "c_ptr": c, + "m": 1024, + "n": 1024, + "workspace": workspace, + "BLOCK_SIZE": 32, + }, + {}, + ["c_ptr", "workspace"], + ) + + @unittest.skipIf( + not has_triton_tensor_descriptor_host_tma(), + "requires TensorDescriptor API in Triton", + ) + @make_mutation_test + def test_add_kernel_on_device_tma_new_api(): + a = torch.randn(1024, 1024) + b = torch.randn(1024, 1024) + c = torch.empty(1024, 1024) + workspace = torch.empty( + 128 * 3, dtype=torch.int8 + ) # Not used by the new API but kept for consistency + return ( + add_kernel_on_device_tma_new_api, + { + "a_ptr": a, + "b_ptr": b, + "c_ptr": c, + "m": 1024, + "n": 1024, + "workspace": workspace, + "BLOCK_SIZE": 32, + }, + {}, + ["c_ptr"], + ) + if HAS_GPU: t = torch.randn(4) @@ -3083,6 +3383,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ], [ @@ -3094,6 +3395,7 @@ def branch_with_multiple_yield_args( "x_elements": 4, "y_elements": 4, }, + {}, ["out_ptr"], ], [ @@ -3105,6 +3407,7 @@ def branch_with_multiple_yield_args( "BLOCK_SIZE": 4, "ACTIVATION": "mul2_inplace_kernel", }, + {}, ["in_ptr0", "out_ptr"], ], [ @@ -3116,21 +3419,25 @@ def branch_with_multiple_yield_args( "BLOCK_SIZE": 4, "ACTIVATION": "add_kernel", }, + {}, ["out_ptr"], ], [ mul2_inplace_kernel, {"ptr": t, "n_elements": 4, "BLOCK_SIZE": 4}, + {}, ["ptr"], ], [ inline_asm_kernel_is_pure_true, {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4}, + {}, ["Z"], ], [ inline_asm_kernel_is_pure_false, {"X": t, "Y": t, "Z": t, "n": 4, "BLOCK": 4}, + {}, ["X", "Y", "Z"], ], [ @@ -3142,6 +3449,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["output_ptr"], ], [ @@ -3152,6 +3460,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["output_ptr"], ], [ @@ -3163,6 +3472,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ], [ @@ -3174,6 +3484,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ], [ @@ -3185,6 +3496,7 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ], [ @@ -3196,16 +3508,21 @@ def branch_with_multiple_yield_args( "n_elements": 4, "BLOCK_SIZE": 4, }, + {}, ["out_ptr"], ], ] - for kernel, inputs, outputs in tests: + for kernel, inputs, tma_descriptor_metadata, outputs in tests: fn = make_mutation_test( # Add default arguments to avoid Python lambda capture pitfall # This forces the capture at lambda creation - lambda kernel=kernel, inputs=inputs, outputs=outputs: ( + lambda kernel=kernel, + inputs=inputs, + tma_descriptor_metadata=tma_descriptor_metadata, + outputs=outputs: ( kernel, inputs, + tma_descriptor_metadata, outputs, ) ) @@ -3673,9 +3990,10 @@ def grid(META): torch._dynamo.decorators.mark_unbacked(x, 0) - with log_settings("+output_code"), self.assertLogs( - logger="torch._inductor", level=logging.DEBUG - ) as log: + with ( + log_settings("+output_code"), + self.assertLogs(logger="torch._inductor", level=logging.DEBUG) as log, + ): foo(x, w) output = "\n".join(record.getMessage() for record in log.records) diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py index e06d85a6d1bb39..379f40e6e13de4 100644 --- a/test/inductor/test_utils.py +++ b/test/inductor/test_utils.py @@ -59,7 +59,7 @@ def testSympySubs(self): result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")}) self.assertEqual(result.name, "x") - # replaced cant be string + # replaced can't be string self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"}) # replaced can be an expression diff --git a/test/jit/test_alias_analysis.py b/test/jit/test_alias_analysis.py index 222140dc560209..8905872c5c3cc8 100644 --- a/test/jit/test_alias_analysis.py +++ b/test/jit/test_alias_analysis.py @@ -2,18 +2,13 @@ import torch from torch._C import parse_ir -from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + TemporaryFileName, +) from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestAliasAnalysis(JitTestCase): def test_becomes_wildcard_annotations(self): graph_str = """ @@ -28,7 +23,7 @@ def test_becomes_wildcard_annotations(self): graph = parse_ir(graph_str) alias_db = graph.alias_db() split_node = graph.findNode("aten::split") - # split input enters wildcard set, list initalized as containing wildcard set + # split input enters wildcard set, list initialized as containing wildcard set self.assertTrue( alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()) ) @@ -154,3 +149,7 @@ def forward(self, x): mod = ModuleWrapper(module_list) mod = torch.jit.script(mod) mod(torch.zeros((2, 2))) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_async.py b/test/jit/test_async.py index e5d5de52bc1318..b739963ad5ea1a 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -16,6 +16,7 @@ from torch import Tensor from torch.jit import Future +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import _inline_everything, JitTestCase @@ -547,8 +548,4 @@ def fn_float(x: int) -> Any: if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_aten_pow.py b/test/jit/test_aten_pow.py index d227f252504912..754970263c57ea 100644 --- a/test/jit/test_aten_pow.py +++ b/test/jit/test_aten_pow.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: jit"] import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase class TestAtenPow(TestCase): @@ -99,3 +99,7 @@ def fn_float_float(a: float, b: float): self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0)) # zero base and negative exponent case that should trigger RunTimeError self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_attr.py b/test/jit/test_attr.py index 2e641b91d66ec1..d9d5fab1615aee 100644 --- a/test/jit/test_attr.py +++ b/test/jit/test_attr.py @@ -4,17 +4,10 @@ import torch from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestGetDefaultAttr(JitTestCase): def test_getattr_with_default(self): class A(torch.nn.Module): @@ -66,3 +59,7 @@ def fn(x: Tuple[str, int]) -> int: with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"): torch.jit.script(fn) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_autodiff.py b/test/jit/test_autodiff.py index 0594efd6ea51ad..798f382968fe92 100644 --- a/test/jit/test_autodiff.py +++ b/test/jit/test_autodiff.py @@ -4,7 +4,10 @@ from typing import List import torch -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, +) from torch.testing._internal.jit_utils import JitTestCase @@ -145,3 +148,7 @@ def fn(a, b, c): self.assertEqual(x_s.requires_grad, x.requires_grad) self.assertEqual(y_s.requires_grad, y.requires_grad) self.assertEqual(z_s.requires_grad, z.requires_grad) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index ea367108788b41..f42aa7f8f43657 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -20,20 +20,13 @@ from typing import List, Optional, Tuple from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import ( disable_autodiff_subgraph_inlining, JitTestCase, ) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) @@ -589,3 +582,7 @@ def test_has_profiled_info_aliasing_outputs(self): FileCheck().check("= prim::DifferentiableGraph").check( "with prim::DifferentiableGraph" ).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_await.py b/test/jit/test_await.py index 7a65beb9bdbde8..564b61d03c621d 100644 --- a/test/jit/test_await.py +++ b/test/jit/test_await.py @@ -6,6 +6,7 @@ import torch from torch import Tensor from torch._awaits import _Await as Await +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -390,3 +391,7 @@ def main(x: Tensor) -> Tensor: sm = torch.jit.load(iofile) script_out_load = sm(inp) self.assertTrue(torch.allclose(expected, script_out_load)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_backend_nnapi.py b/test/jit/test_backend_nnapi.py index 9f4771665020f6..042c82eca803ae 100644 --- a/test/jit/test_backend_nnapi.py +++ b/test/jit/test_backend_nnapi.py @@ -7,7 +7,11 @@ import torch import torch._C -from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + IS_FBCODE, + raise_on_run_directly, + skipIfTorchDynamo, +) # hacky way to skip these tests in fbcode: @@ -28,13 +32,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - """ Unit Tests for Nnapi backend with delegate Inherits most tests from TestNNAPI, which loads Android NNAPI models @@ -139,3 +136,7 @@ def test_compile_spec_santiy(self): def tearDown(self): # Change dtype back to default (Otherwise, other unit tests will complain) torch.set_default_dtype(self.default_dtype) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index 8453f59cfdbe74..4c593b9c865d20 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -15,6 +15,7 @@ IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, + raise_on_run_directly, skipIfRocm, TEST_WITH_ROCM, ) @@ -25,13 +26,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - def to_test_backend(module, method_compile_spec): return torch._C._jit_to_backend( @@ -822,3 +816,7 @@ def test_attribute(self): ) self.assertEqual(pre_bundled, post_bundled) self.assertEqual(post_bundled, post_load) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_batch_mm.py b/test/jit/test_batch_mm.py index a0f0eb76bbba20..e0b2c640898fe2 100644 --- a/test/jit/test_batch_mm.py +++ b/test/jit/test_batch_mm.py @@ -2,24 +2,19 @@ import torch from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestBatchMM(JitTestCase): @staticmethod def _get_test_tensors(n: int): return [ - torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]]) - if x % 2 == 0 - else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]]) + ( + torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]]) + if x % 2 == 0 + else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]]) + ) for x in range(n) ] @@ -288,3 +283,7 @@ def test_batch_mm(n: int): FileCheck().check_count("aten::mm", 10, exactly=True).check_not( "prim::MMBatchSide" ).run(test_batch_mm.graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index 510b911e463397..b84bc96519cbc9 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -13,17 +13,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestBuiltins(JitTestCase): """ Tests for TorchScript support of Python builtin functions. @@ -299,3 +292,7 @@ def test_func(func, x, tensor): self.assertEqual( test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor) ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 02182b3b2fbfc5..0ae1c3dcfd307a 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -18,18 +18,14 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch.testing._internal.jit_utils -from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + IS_SANDCASTLE, + raise_on_run_directly, + skipIfTorchDynamo, +) from torch.testing._internal.jit_utils import JitTestCase, make_global -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestClassType(JitTestCase): def test_reference_semantics(self): """ @@ -1667,3 +1663,7 @@ def fn_e(): for fn in (fn_a, fn_b, fn_c, fn_d, fn_e): with self.assertRaisesRegex(RuntimeError, error_message_regex): torch.jit.script(fn) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_complex.py b/test/jit/test_complex.py index 0cbec5605c08e7..388a93c4a04ea0 100644 --- a/test/jit/test_complex.py +++ b/test/jit/test_complex.py @@ -8,7 +8,7 @@ from typing import Dict, List import torch -from torch.testing._internal.common_utils import IS_MACOS +from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly from torch.testing._internal.jit_utils import execWrapper, JitTestCase @@ -617,3 +617,7 @@ def div(x: complex, y: torch.Tensor): scripted = torch.jit.script(op) jit_result = scripted(x, y) self.assertEqual(eager_result, jit_result) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_complexity.py b/test/jit/test_complexity.py index cd022eb52244a5..2fa038d149688d 100644 --- a/test/jit/test_complexity.py +++ b/test/jit/test_complexity.py @@ -13,7 +13,6 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.common_utils import ( IS_FBCODE, - run_tests, set_default_dtype, suppress_warnings, ) @@ -105,4 +104,7 @@ def test_nn_module_tests(self): if __name__ == "__main__": - run_tests() + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/jit/test_convert_activation.py b/test/jit/test_convert_activation.py index 56826193ce7862..90cb26ce2633e0 100644 --- a/test/jit/test_convert_activation.py +++ b/test/jit/test_convert_activation.py @@ -22,16 +22,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - activations = [ F.celu, F.elu, @@ -204,3 +198,7 @@ def test_resnet18_correctness(self): inp = torch.randn(N, C, H, W) self.run_pass("inplace_to_functional_activation", frozen_model.graph) self.assertEqual(model(inp), frozen_model(inp)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py index fb7e5cd325d4a5..f026fa0188b2a0 100644 --- a/test/jit/test_cuda.py +++ b/test/jit/test_cuda.py @@ -12,6 +12,7 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( NoTest, + raise_on_run_directly, skipCUDANonDefaultStreamIf, skipIfRocm, TEST_CUDA, @@ -36,13 +37,6 @@ torch.ones(1).cuda() # initialize cuda context TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9 -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestCUDA(JitTestCase): """ @@ -698,3 +692,7 @@ def fn(device: int, tensor): FileCheck().check("cuda::_maybe_exchange_device(").run(g) torch._C._jit_pass_inline(g) FileCheck().check("cuda::_maybe_exchange_device(").run(g) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_custom_operators.py b/test/jit/test_custom_operators.py index 498179f910646c..02fb5d28519ede 100644 --- a/test/jit/test_custom_operators.py +++ b/test/jit/test_custom_operators.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - def canonical(graph): return torch._C._jit_pass_canonicalize(graph).str(False) @@ -151,3 +144,7 @@ def test_generic_list(self): def test_where_no_scalar(self): x = torch.rand(1, 3, 224, 224) torch.ops.aten.where(x > 0.5, -1.5, 1.5) # does not raise + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_data_parallel.py b/test/jit/test_data_parallel.py index fc69e86a8cf4e3..6f9351a0766a0b 100644 --- a/test/jit/test_data_parallel.py +++ b/test/jit/test_data_parallel.py @@ -12,17 +12,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestDataParallel(JitTestCase): class Mpy(torch.nn.Module): def __init__(self) -> None: @@ -158,3 +151,7 @@ def test_tensor_sharing_with_forward(self): x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1) r1_forward = replica[1](x1) self.assertEqual(first_forward, r1_forward) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_dataclasses.py b/test/jit/test_dataclasses.py index e678937d153d84..6c04ecfae6dccc 100644 --- a/test/jit/test_dataclasses.py +++ b/test/jit/test_dataclasses.py @@ -7,6 +7,7 @@ from hypothesis import given, settings, strategies as st import torch +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -168,3 +169,7 @@ def f(a: MixupParams3): with self.assertRaises(OSError): torch.jit.script(f) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_dce.py b/test/jit/test_dce.py index dddfcb1c8b7de1..e89862b085aa36 100644 --- a/test/jit/test_dce.py +++ b/test/jit/test_dce.py @@ -2,6 +2,7 @@ import torch from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -73,3 +74,7 @@ def fn(x: torch.Tensor): torch._C._jit_pass_dce_graph(fn_s.graph) FileCheck().check("aten::add_").run(fn_s.graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_decorator.py b/test/jit/test_decorator.py index c0d95d9f039b21..793b406a2f6530 100644 --- a/test/jit/test_decorator.py +++ b/test/jit/test_decorator.py @@ -19,3 +19,10 @@ def test_decorator(self): fn = my_function_a fx = torch.jit.script(fn) self.assertEqual(fn(1.0), fx(1.0)) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/jit/test_device_analysis.py b/test/jit/test_device_analysis.py index e2cb461924d89a..67dce398b58f1a 100644 --- a/test/jit/test_device_analysis.py +++ b/test/jit/test_device_analysis.py @@ -5,7 +5,7 @@ import torch from torch.jit._passes._property_propagation import apply_input_props_using_example -from torch.testing._internal.common_utils import TEST_CUDA +from torch.testing._internal.common_utils import raise_on_run_directly, TEST_CUDA from torch.testing._internal.jit_utils import JitTestCase @@ -14,13 +14,6 @@ except ImportError: models = None -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestDeviceAnalysis(JitTestCase): @classmethod @@ -336,3 +329,7 @@ def test_fn(x, y, z: bool, a: bool): test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn ) self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_dtype_analysis.py b/test/jit/test_dtype_analysis.py index 1a5fd2038bdb10..f2acd3fe5087f5 100644 --- a/test/jit/test_dtype_analysis.py +++ b/test/jit/test_dtype_analysis.py @@ -17,7 +17,11 @@ sample_inputs_conv2d, SampleInput, ) -from torch.testing._internal.common_utils import first_sample, set_default_dtype +from torch.testing._internal.common_utils import ( + first_sample, + raise_on_run_directly, + set_default_dtype, +) from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing._internal.jit_utils import JitTestCase @@ -27,14 +31,6 @@ """ -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - custom_rules_works_list = { "nn.functional.adaptive_avg_pool1d", "nn.functional.adaptive_avg_pool2d", @@ -386,3 +382,6 @@ def test_custom_rules_expected_failure(self, device, dtype, op): TestDtypeCustomRulesCPU = None # This creates TestDtypeCustomRulesCPU instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",)) + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_enum.py b/test/jit/test_enum.py index d7b703eb1b0323..2308ebb4f4ef10 100644 --- a/test/jit/test_enum.py +++ b/test/jit/test_enum.py @@ -12,17 +12,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestEnum(JitTestCase): def test_enum_value_types(self): class IntEnum(Enum): @@ -358,3 +351,7 @@ class Color(int, Enum): @torch.jit.script def is_red(x: Color) -> bool: return x == Color.RED + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_exception.py b/test/jit/test_exception.py index 38d9424d3b7418..f46b6bc1f3275e 100644 --- a/test/jit/test_exception.py +++ b/test/jit/test_exception.py @@ -197,3 +197,10 @@ def fn(): "jit.myexception.MyKeyError: This is a user defined key error", ): fn() + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 7da41f0cc71389..8258124680b470 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -15,6 +15,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import ( + raise_on_run_directly, set_default_dtype, skipCUDAMemoryLeakCheckIf, skipIfTorchDynamo, @@ -32,13 +33,6 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None @@ -55,7 +49,7 @@ def __init__(self) -> None: self.a = 1 # folded self.b = 1.2 # folded self.c = "hello" # folded - self.c2 = "hi\xA1" # not folded + self.c2 = "hi\xa1" # not folded self.d = [1, 1] # folded self.e = [1.0, 1.1] # folded self.f = ["hello", "world"] # folded @@ -67,7 +61,7 @@ def __init__(self) -> None: torch.tensor([5.5], requires_grad=True), ) # folded self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]} - self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]} + self.h2 = {"layer\xb1": [torch.tensor([8.8], requires_grad=True)]} self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded self.ts = [ torch.tensor([1.0, 2.0], requires_grad=True), @@ -3461,3 +3455,7 @@ def forward(self, x): mod = self.freezeAndConvert(mod_eager) FileCheck().check("aten::add_").run(mod.graph) self.checkResults(mod_eager, mod) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_functional_blocks.py b/test/jit/test_functional_blocks.py index 29f180b66fff98..40dff3765fe05d 100644 --- a/test/jit/test_functional_blocks.py +++ b/test/jit/test_functional_blocks.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestFunctionalBlocks(JitTestCase): def test_subgraph_creation(self): def fn(x, y, z): @@ -54,3 +47,7 @@ def fn(x, y, z): FileCheck().check("add").check("add_").check_not("mul").check( "FunctionalGraph" ).run(graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_fuser_common.py b/test/jit/test_fuser_common.py index 9b0921d22b1f2f..81cf534b74eaf2 100644 --- a/test/jit/test_fuser_common.py +++ b/test/jit/test_fuser_common.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: jit"] import torch +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -19,3 +20,7 @@ def fn(x): # test fallback when optimization is not applicable y = fn(torch.randn(5, requires_grad=rq)) self.assertEqual(y.requires_grad, rq) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit_fuser_te.py") diff --git a/test/jit/test_generator.py b/test/jit/test_generator.py index 5f6e15cf8fe364..6fe35582063972 100644 --- a/test/jit/test_generator.py +++ b/test/jit/test_generator.py @@ -6,18 +6,13 @@ import torch from torch.nn import init -from torch.testing._internal.common_utils import skipIfLegacyJitExecutor +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfLegacyJitExecutor, +) from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestGenerator(JitTestCase): # torch.jit.trace does not properly capture the generator manual seed # and thus is non deterministic even if the generator is manually seeded @@ -193,3 +188,7 @@ def forward(self, x): except: # noqa: B001, E722 print(loaded_module.forward.code) raise + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_graph_rewrite_passes.py b/test/jit/test_graph_rewrite_passes.py index 061ef66aa1907f..f9b30704fd951c 100644 --- a/test/jit/test_graph_rewrite_passes.py +++ b/test/jit/test_graph_rewrite_passes.py @@ -3,6 +3,7 @@ import torch import torch._C from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -59,3 +60,7 @@ def forward(self, x): FileCheck().check_not("aten::linear").run(model.graph) # make sure it runs model(x) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_hash.py b/test/jit/test_hash.py index 439cd62a5bb1a0..21c99a8a426c88 100644 --- a/test/jit/test_hash.py +++ b/test/jit/test_hash.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestHash(JitTestCase): def test_hash_tuple(self): def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool: @@ -115,3 +108,7 @@ def fn(d1: torch.device, d2: torch.device): self.checkScript(fn, (gpu0, gpu1)) self.checkScript(fn, (gpu0, cpu)) self.checkScript(fn, (cpu, cpu)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_hooks.py b/test/jit/test_hooks.py index 33e84440bbe626..157dd588aa78c4 100644 --- a/test/jit/test_hooks.py +++ b/test/jit/test_hooks.py @@ -33,17 +33,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # Tests for JIT forward hooks and pre-hooks class TestHooks(JitTestCase): def test_module_no_forward_input(self): @@ -393,3 +386,7 @@ def forward_hook_wrong_output_from_prev_hook( r"Received type: 'str'. Expected type: 'Tuple\[str\]'", ): torch.jit.script(m) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_hooks_modules.py b/test/jit/test_hooks_modules.py index ffcd6fea37fd79..a4062c0dd889ce 100644 --- a/test/jit/test_hooks_modules.py +++ b/test/jit/test_hooks_modules.py @@ -528,3 +528,9 @@ def forward_hook(self, input: Tuple[str], output: str): m.submodule.register_forward_hook(forward_hook) return m + + +if __name__ == "__main__": + raise RuntimeError( + "This file is a collection of utils, it should be imported not executed directly" + ) diff --git a/test/jit/test_ignorable_args.py b/test/jit/test_ignorable_args.py index 07968319caf331..9dea0e30a85b39 100644 --- a/test/jit/test_ignorable_args.py +++ b/test/jit/test_ignorable_args.py @@ -11,17 +11,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # Tests that Python slice class is supported in TorchScript class TestIgnorableArgs(JitTestCase): def test_slice_ignorable_args_for_slice(self): @@ -61,3 +54,7 @@ def fn(x: torch.Tensor, y: torch.Tensor): torch.add(x, y, out=y) FileCheck().check("torch.add(x, y, out=y)").run(fn.code) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_ignore_context_manager.py b/test/jit/test_ignore_context_manager.py index b0d5bf457000d0..59b27cba52a783 100644 --- a/test/jit/test_ignore_context_manager.py +++ b/test/jit/test_ignore_context_manager.py @@ -11,17 +11,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestIgnoreContextManager(JitTestCase): @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required") def test_with_ignore_context_manager_with_inp_out(self): @@ -103,3 +96,7 @@ def forward(self): s = torch.jit.script(model) self.assertEqual(s(), 5) self.assertEqual(s(), model()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_isinstance.py b/test/jit/test_isinstance.py index 53b701590f78d6..0781ad9c747652 100644 --- a/test/jit/test_isinstance.py +++ b/test/jit/test_isinstance.py @@ -11,17 +11,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # Tests for torch.jit.isinstance class TestIsinstance(JitTestCase): def test_int(self): @@ -354,3 +347,7 @@ def test_empty_container_special_cases(self): # Should not throw "Boolean value of Tensor with more than # one value is ambiguous" error torch._jit_internal.check_empty_containers(torch.rand(2, 3)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_jit_utils.py b/test/jit/test_jit_utils.py index 4e2e2898f09389..b6eb2e5901cc13 100644 --- a/test/jit/test_jit_utils.py +++ b/test/jit/test_jit_utils.py @@ -11,17 +11,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # Tests various JIT-related utility functions. class TestJitUtils(JitTestCase): # Tests that POSITIONAL_OR_KEYWORD arguments are captured. @@ -116,3 +109,7 @@ def test_no_tracer_warn_context_manager(self): with jit_utils.NoTracerWarnContextManager(): self.assertEqual(False, torch._C._jit_get_tracer_state_warn()) self.assertEqual(True, torch._C._jit_get_tracer_state_warn()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 53245e811ec4cc..58bd66e7df165c 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -19,18 +19,14 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, + TEST_CUDA, +) from torch.testing._internal.jit_utils import JitTestCase, make_global -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestList(JitTestCase): def test_list_bool_conversion(self): def if_predicate(l: List[int]): @@ -1825,7 +1821,7 @@ def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: def test_popitem(self): @torch.jit.script def popitem( - x: Dict[str, Tensor] + x: Dict[str, Tensor], ) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: item = x.popitem() return item, x @@ -2996,3 +2992,7 @@ def forward(self): for i in range(300): test = Test() test_script = torch.jit.script(test) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_logging.py b/test/jit/test_logging.py index 366a6b93442c5f..e03ffa9e0a1376 100644 --- a/test/jit/test_logging.py +++ b/test/jit/test_logging.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestLogging(JitTestCase): def test_bump_numeric_counter(self): class ModuleThatLogs(torch.jit.ScriptModule): @@ -122,3 +115,7 @@ def foo(x): def test_logging_levels_set(self): torch._C._jit_set_logging_option("foo") self.assertEqual("foo", torch._C._jit_get_logging_option()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py index 8c63e61a8daac8..8c584d1f3102fb 100644 --- a/test/jit/test_misc.py +++ b/test/jit/test_misc.py @@ -12,7 +12,7 @@ from jit.test_module_interface import TestModuleInterface # noqa: F401 from torch import jit from torch.testing import FileCheck -from torch.testing._internal.common_utils import freeze_rng_state +from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF @@ -20,13 +20,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestMisc(JitTestCase): def test_joined_str(self): @@ -129,7 +122,7 @@ def forward(x: Any) -> str: def test_subexpression_Tuple_int_int_Future(self): @torch.jit.script def fn( - x: Tuple[int, int, torch.jit.Future[int]] + x: Tuple[int, int, torch.jit.Future[int]], ) -> Tuple[int, torch.jit.Future[int]]: return x[0], x[2] @@ -147,7 +140,7 @@ def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]: def test_subexpression_Optional(self): @torch.jit.script def fn( - x: Optional[Dict[int, torch.jit.Future[int]]] + x: Optional[Dict[int, torch.jit.Future[int]]], ) -> Optional[torch.jit.Future[int]]: if x is not None: return x[0] @@ -504,3 +497,7 @@ def test_jit_get_operation_order(self): self.assertTrue(len(complex_indices) > 0) self.assertTrue(len(Scalar_indices) > 0) self.assertTrue(complex_indices[0] > Scalar_indices[0]) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_models.py b/test/jit/test_models.py index 7ee9ef365eb4e2..c6364f10197d1a 100644 --- a/test/jit/test_models.py +++ b/test/jit/test_models.py @@ -11,24 +11,19 @@ enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, ProfilingMode, + raise_on_run_directly, set_default_dtype, + slowTest, + suppress_warnings, ) # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import slowTest, suppress_warnings from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - try: import torchvision @@ -84,7 +79,7 @@ def __init__(self, nz, ngf, nc): nn.ReLU(True), # state size. (ngf) x 32 x 32 nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), - nn.Tanh() + nn.Tanh(), # state size. (nc) x 64 x 64 ) @@ -754,3 +749,7 @@ def test_alexnet(self): m = self.createFunctionFromGraph(g) with torch.random.fork_rng(devices=[]): self.assertEqual(outputs, m(*inputs)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_module_apis.py b/test/jit/test_module_apis.py index 24a50b0164276c..d7d0c022ccf900 100644 --- a/test/jit/test_module_apis.py +++ b/test/jit/test_module_apis.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List import torch +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -12,13 +13,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestModuleAPIs(JitTestCase): def test_default_state_dict_methods(self): @@ -141,3 +135,7 @@ def forward(self, x): self.assertFalse(m2.sub.customized_load_state_dict_called) m2.load_state_dict(state_dict) self.assertTrue(m2.sub.customized_load_state_dict_called) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index e8200eb2c09b3a..a6564c4b4c4ded 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -14,13 +15,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestModuleContainers(JitTestCase): def test_sequential_intermediary_types(self): @@ -756,3 +750,7 @@ def forward(self, x): ) self.checkModule(MyModule(), (torch.ones(1),)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index ad30ea3492d3eb..c9765b4e282ffe 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn from torch import Tensor +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -15,13 +16,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class OrigModule(nn.Module): def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: @@ -701,3 +695,7 @@ def method(self, input): with self.assertRaisesRegex(Exception, "Could not compile"): scripted_mod = torch.jit.script(TestModule()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_modules.py b/test/jit/test_modules.py index 3602887133d957..ff4ca58e557e43 100644 --- a/test/jit/test_modules.py +++ b/test/jit/test_modules.py @@ -4,6 +4,7 @@ import sys import torch +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -11,13 +12,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestModules(JitTestCase): def test_script_module_with_constants_list(self): @@ -27,7 +21,7 @@ def test_script_module_with_constants_list(self): """ # torch.nn.Linear has a __constants__ attribute defined - # and intialized to a list. + # and initialized to a list. class Net(torch.nn.Linear): x: torch.jit.Final[int] @@ -36,3 +30,7 @@ def __init__(self) -> None: self.x = 0 self.checkModule(Net(), (torch.randn(5),)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_op_decompositions.py b/test/jit/test_op_decompositions.py index bd9ced8daa855c..dacd829e7939ae 100644 --- a/test/jit/test_op_decompositions.py +++ b/test/jit/test_op_decompositions.py @@ -2,17 +2,10 @@ import torch from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestOpDecompositions(JitTestCase): def test_op_decomposition(self): def foo(x): @@ -42,3 +35,7 @@ def square_decomp(x): FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph) x = torch.rand([4]) self.assertEqual(foo(x), torch.square(x)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_optimize_for_mobile_preserve_debug_info.py b/test/jit/test_optimize_for_mobile_preserve_debug_info.py index d405b2764e6fa2..d643a670be356d 100644 --- a/test/jit/test_optimize_for_mobile_preserve_debug_info.py +++ b/test/jit/test_optimize_for_mobile_preserve_debug_info.py @@ -3,7 +3,7 @@ import torch import torch._C import torch.nn.functional as F -from torch.testing._internal.common_utils import skipIfNoXNNPACK +from torch.testing._internal.common_utils import raise_on_run_directly, skipIfNoXNNPACK from torch.testing._internal.jit_utils import JitTestCase @@ -263,3 +263,7 @@ def test_fuse_activation_with_pack_ops_linear_conv2d_4(self): conv2d_activation=F.relu, conv2d_activation_kind="aten::relu", ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_parametrization.py b/test/jit/test_parametrization.py index 1372885e5db684..3be2fc526f545a 100644 --- a/test/jit/test_parametrization.py +++ b/test/jit/test_parametrization.py @@ -4,17 +4,10 @@ import torch import torch.nn.utils.parametrize as parametrize from torch import nn +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestParametrization(JitTestCase): # Define some parametrization class Symmetric(nn.Module): @@ -68,3 +61,7 @@ def test_scriptable(self): # Check the scripting process throws an error when caching with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"): scripted_model = torch.jit.trace_module(model) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index 7e77d93cdcd6af..0ac620b368b6e5 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -6,7 +6,7 @@ import torch from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED -from torch.testing._internal.common_utils import NoTest +from torch.testing._internal.common_utils import NoTest, raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -21,13 +21,6 @@ ) JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestPDT(JitTestCase): """ @@ -896,3 +889,7 @@ def test_none(a) -> Any: torch.ones(1), ), ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_peephole.py b/test/jit/test_peephole.py index ac2f54bfe260ba..914d423a5196da 100644 --- a/test/jit/test_peephole.py +++ b/test/jit/test_peephole.py @@ -6,17 +6,10 @@ import torch from torch import nn from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestPeephole(JitTestCase): def test_peephole_with_writes(self): def test_write(x): @@ -890,3 +883,7 @@ def foo(x: int, y: int): self.run_pass("peephole", foo.graph) FileCheck().check("aten::slice").run(foo.graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py index 29f3cc9be4cdac..fba01ee67082ba 100644 --- a/test/jit/test_profiler.py +++ b/test/jit/test_profiler.py @@ -4,7 +4,10 @@ import sys import torch -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, +) # Make the helper files in test/ importable @@ -13,14 +16,6 @@ from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - @skipIfTorchDynamo() class TestProfiler(JitTestCase): def setUp(self): @@ -284,3 +279,7 @@ def foo(a, b, c, d): g = torch.jit.last_executed_optimized_graph() self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_python_bindings.py b/test/jit/test_python_bindings.py index 9de3bf02b7ff8b..2af4552dcebaa5 100644 --- a/test/jit/test_python_bindings.py +++ b/test/jit/test_python_bindings.py @@ -2,17 +2,10 @@ import torch from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TestPythonBindings\n\n" - "instead." - ) - - class TestPythonBindings(JitTestCase): def test_cu_get_functions(self): @torch.jit.script @@ -114,3 +107,7 @@ def test_canonicalize(self): graph3 = torch._C.parse_ir(ir) graph3 = torch._C._jit_pass_canonicalize(graph3, False) FileCheck().check_not("%p207").run(graph3) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_python_builtins.py b/test/jit/test_python_builtins.py index c84e4edff23325..3dafc89ac7fda3 100644 --- a/test/jit/test_python_builtins.py +++ b/test/jit/test_python_builtins.py @@ -7,6 +7,7 @@ from textwrap import dedent import torch +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import execWrapper, JitTestCase @@ -14,13 +15,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - def get_fn(file_name, script_path): import importlib.util @@ -473,3 +467,7 @@ def foo(a): s = torch.rand(1) self.assertTrue(foo(s)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_python_ir.py b/test/jit/test_python_ir.py index e5e98ac9fb80f9..593d98193b1ce3 100644 --- a/test/jit/test_python_ir.py +++ b/test/jit/test_python_ir.py @@ -6,18 +6,10 @@ import torch from torch.testing import FileCheck -from torch.testing._internal.common_utils import IS_MACOS +from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestPythonIr(JitTestCase): def test_param_strides(self): def trace_me(arg): @@ -100,3 +92,7 @@ def foo(x): FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph) self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 33fd38c2b9c7ee..d595c793e79b6c 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -20,20 +20,13 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import ( _tmp_donotuse_dont_inline_everything, JitTestCase, ) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestRecursiveScript(JitTestCase): def test_inferred_nonetype(self): class M(nn.Module): @@ -799,3 +792,7 @@ def i_am_ignored(self): # ScriptModule should correctly reflect the override. s = torch.jit.script(m) self.assertEqual(s.i_am_ignored(), "new") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py index 8048d406ab33f9..3250a86f804536 100644 --- a/test/jit/test_remove_mutation.py +++ b/test/jit/test_remove_mutation.py @@ -11,17 +11,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestRemoveMutation(JitTestCase): def test_aten_inplace(self): def test_not_new_alias(x): @@ -318,3 +311,7 @@ def test_multiple_uses(): self.run_pass("remove_mutation", mod_script.forward.graph) FileCheck().check("aten::add_").run(test_multiple_uses.graph) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 4c83c40e1aa9c4..f697e74ae9ac12 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -8,7 +8,11 @@ import torch from torch import Tensor -from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, + TemporaryFileName, +) # Make the helper files in test/ importable @@ -17,14 +21,6 @@ from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestSaveLoad(JitTestCase): def test_different_modules(self): """ @@ -1197,3 +1193,7 @@ def forward(self, x: Tensor): torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files) self.assertEqual(extra_files, re_extra_files) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_save_load_for_op_version.py b/test/jit/test_save_load_for_op_version.py index 1b62e4043eb89d..37df00ab8ef47c 100644 --- a/test/jit/test_save_load_for_op_version.py +++ b/test/jit/test_save_load_for_op_version.py @@ -17,17 +17,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestSaveLoadForOpVersion(JitTestCase): # Helper that returns the module after saving and loading def _save_load_module(self, m): @@ -617,3 +610,7 @@ def forward( self.assertTrue(output.size(dim=0) == 100) # "Upgraded" model should match the new version output self.assertEqual(output, output_current) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_script_profile.py b/test/jit/test_script_profile.py index c3977c2314de05..4bc8008d1aa2b6 100644 --- a/test/jit/test_script_profile.py +++ b/test/jit/test_script_profile.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class Sequence(nn.Module): def __init__(self) -> None: super().__init__() @@ -115,3 +108,7 @@ def test_empty(self): p.enable() p.disable() self.assertEqual(p.dump_string(), "") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_scriptmod_ann.py b/test/jit/test_scriptmod_ann.py index 60d8d434b3aef7..f5cd655c2e2c78 100644 --- a/test/jit/test_scriptmod_ann.py +++ b/test/jit/test_scriptmod_ann.py @@ -11,17 +11,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact, # reassigning a non-empty Tuple to an attribute previously typed @@ -363,3 +356,7 @@ def forward(self, x: Optional[str]): "empty non-base types", ): torch.jit.script(M()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_slice.py b/test/jit/test_slice.py index f14dc68358c1ea..e1aca2839aba9f 100644 --- a/test/jit/test_slice.py +++ b/test/jit/test_slice.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # Tests that Python slice class is supported in TorchScript class TestSlice(JitTestCase): def test_slice_kwarg(self): @@ -178,3 +171,7 @@ def forward(self): self.assertEqual(result2[0].identifier, "B") self.assertEqual(result2[1].identifier, "C") self.assertEqual(result2[2].identifier, "D") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_sparse.py b/test/jit/test_sparse.py index 97ce0a32b6c43d..78e292b62d74b8 100644 --- a/test/jit/test_sparse.py +++ b/test/jit/test_sparse.py @@ -4,7 +4,11 @@ import unittest import torch -from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + raise_on_run_directly, + TEST_MKL, +) from torch.testing._internal.jit_utils import JitTestCase @@ -118,3 +122,7 @@ def forward(self, x): loaded_result = loaded_model.forward(x) self.assertEqual(expected_result.to_dense(), loaded_result.to_dense()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_string_formatting.py b/test/jit/test_string_formatting.py index e90c3cd9eebc2f..295ae85e3fb981 100644 --- a/test/jit/test_string_formatting.py +++ b/test/jit/test_string_formatting.py @@ -10,17 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestStringFormatting(JitTestCase): def test_modulo_operator(self): def fn(dividend: int, divisor: int) -> int: @@ -199,3 +192,7 @@ def fn(arg1: str) -> str: '"%a in template" % arg1', ): fn("foo") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index f43105093d7448..702fdd851954c8 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -9,18 +9,10 @@ from torch import nn, Tensor from torch.testing import FileCheck from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat -from torch.testing._internal.common_utils import make_tensor +from torch.testing._internal.common_utils import make_tensor, raise_on_run_directly from torch.testing._internal.jit_utils import execWrapper, JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - # XXX: still in prototype class TestSymbolicShapeAnalysis(JitTestCase): def setUp(self): @@ -819,3 +811,7 @@ def foo(x): input.setType(input.type().with_sizes([1, 5, 8])) torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) self.assertEqual(next(foo.graph.outputs()).type().symbolic_sizes(), [5, 8]) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_tensor_creation_ops.py b/test/jit/test_tensor_creation_ops.py index a51cd2bd3f388b..3933984179b5f5 100644 --- a/test/jit/test_tensor_creation_ops.py +++ b/test/jit/test_tensor_creation_ops.py @@ -9,17 +9,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestTensorCreationOps(JitTestCase): """ A suite of tests for ops that create tensors. @@ -78,3 +71,7 @@ def tril_indices(rows: int, cols: int): assert indices.dtype == torch.int32 self.checkScript(tril_indices, (3, 3)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_tensor_methods.py b/test/jit/test_tensor_methods.py index 8e75a96e260da7..05526341c9f5b2 100644 --- a/test/jit/test_tensor_methods.py +++ b/test/jit/test_tensor_methods.py @@ -10,17 +10,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestTensorMethods(JitTestCase): def test_getitem(self): def tensor_getitem(inp: torch.Tensor): @@ -41,3 +34,7 @@ def tensor_getitem_invalid(inp: torch.Tensor): RuntimeError, "expected exactly 1 argument", "inp.__getitem__" ): torch.jit.script(tensor_getitem_invalid) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index 813e5ba0f9e561..aab330972b01e2 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -8,7 +8,10 @@ from typing import Optional import torch -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, +) # Make the helper files in test/ importable @@ -19,14 +22,6 @@ from torch.testing._internal.torchbind_impls import load_torchbind_test_lib -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - @skipIfTorchDynamo("skipping as a precaution") class TestTorchbind(JitTestCase): def setUp(self): @@ -463,3 +458,7 @@ def gn() -> int: return obj.decrement() self.checkScript(gn, ()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 2aee35a60e9b00..b32bb4722aa287 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -29,6 +29,7 @@ from torch.testing._internal.common_utils import ( enable_profiling_mode_for_profiling_tests, IS_SANDCASTLE, + raise_on_run_directly, skipIfCompiledWithoutNumpy, skipIfCrossRef, skipIfTorchDynamo, @@ -46,14 +47,6 @@ ) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - @skipIfTorchDynamo("Not a suitable test for TorchDynamo") class TestTracer(JitTestCase): @unittest.skipIf(not RUN_CUDA, "requires CUDA") @@ -2826,3 +2819,7 @@ def outer_fn(x, y): for n in fn_t.graph.nodes(): if n.kind() == "prim::CallFunction": self.assertTrue(n.output().isCompleteTensor()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_type_sharing.py b/test/jit/test_type_sharing.py index 747222ad2649d4..a6313a94244b6e 100644 --- a/test/jit/test_type_sharing.py +++ b/test/jit/test_type_sharing.py @@ -10,18 +10,13 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import suppress_warnings +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + suppress_warnings, +) from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestTypeSharing(JitTestCase): def assertSameType(self, m1, m2): if not isinstance(m1, torch.jit.ScriptModule): @@ -626,3 +621,7 @@ def forward(self, x): # of A, __jit_ignored_attributes__ was modified before scripting s2, # so the set of ignored attributes is different between s1 and s2. self.assertDifferentType(s1, s2) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_types.py b/test/jit/test_types.py index c0e56bb47c867c..a7b0752ab75004 100644 --- a/test/jit/test_types.py +++ b/test/jit/test_types.py @@ -12,6 +12,7 @@ import torch.testing._internal.jit_utils from jit.test_module_interface import TestModuleInterface # noqa: F401 from torch.testing import FileCheck +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase @@ -19,13 +20,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestTypesAndAnnotation(JitTestCase): def test_pep585_type(self): @@ -370,3 +364,7 @@ def test_inferred_type_error_message(self): with self.assertRaisesRegex(RuntimeError, "ErrorReason"): t = inferred_type.type() + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index bf5e53b9e9f04f..8f34a1c75b6d7a 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -7,7 +7,7 @@ from typing import Dict, List, NamedTuple, Tuple import torch -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -15,13 +15,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestTyping(JitTestCase): def test_dict_in_not_in(self): @@ -140,7 +133,7 @@ def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): # Check for invalid key and value type annotation def wrong_key_value_type( - dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule] + dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule], ): return @@ -688,3 +681,7 @@ def __init__(self) -> None: mod2 = LowestModule() mod_s = torch.jit.script(mod) mod2_s = torch.jit.script(mod2) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_union.py b/test/jit/test_union.py index c3810117dad39b..43bf5e67e6cb93 100644 --- a/test/jit/test_union.py +++ b/test/jit/test_union.py @@ -15,17 +15,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestUnion(JitTestCase): """ This class tests the functionality of `Union`. @@ -1066,3 +1059,7 @@ def fn(): # "Union[Dict[str, torch.Tensor], int]", # lhs["dict_comprehension_of_mixed"], # "foobar") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_union_pep604.py b/test/jit/test_union_pep604.py index 871af5aa75a068..c3b1a5d8f24085 100644 --- a/test/jit/test_union_pep604.py +++ b/test/jit/test_union_pep604.py @@ -16,17 +16,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase, make_global -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - @unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10") class TestUnion(JitTestCase): """ @@ -1064,3 +1057,7 @@ def fn(): # "Dict[str, torch.Tensor] | int", # lhs["dict_comprehension_of_mixed"], # "foobar") + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_unsupported_ops.py b/test/jit/test_unsupported_ops.py index cf07b9485ab220..47d57bd7461ef2 100644 --- a/test/jit/test_unsupported_ops.py +++ b/test/jit/test_unsupported_ops.py @@ -10,16 +10,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - # NOTE: FIXING FAILING TESTS # If you are seeing a test failure from this file, congrats, you improved # parity between JIT and Python API. Before you fix the test, you must also update @@ -90,3 +84,7 @@ def sparse(): func() with self.assertRaisesRegex(Exception, ""): torch.jit.script(func) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_upgraders.py b/test/jit/test_upgraders.py index 6a7b294164bca1..22d05052b4f0b3 100644 --- a/test/jit/test_upgraders.py +++ b/test/jit/test_upgraders.py @@ -13,17 +13,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestUpgraders(JitTestCase): def _load_model_version(self, loaded_model): buffer = io.BytesIO() @@ -346,3 +339,7 @@ def test_aten_full_out_at_4(self): FileCheck().check_count("aten::full", 5).run(loaded_model.graph) version = self._load_model_version(loaded_model) self.assertTrue(version == 5) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_warn.py b/test/jit/test_warn.py index e72ab71b30a018..b8e85607c579fe 100644 --- a/test/jit/test_warn.py +++ b/test/jit/test_warn.py @@ -13,17 +13,10 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - - class TestWarn(JitTestCase): def test_warn(self): @torch.jit.script @@ -148,3 +141,7 @@ def bar(): ).run( f.getvalue() ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/test_with.py b/test/jit/test_with.py index c03085efd326af..5afb9459c2df22 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -6,7 +6,10 @@ from typing import Any, List import torch -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, +) from torch.testing._internal.jit_utils import JitTestCase, make_global @@ -14,13 +17,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -if __name__ == "__main__": - raise RuntimeError( - "This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead." - ) - class TestWith(JitTestCase): """ @@ -647,3 +643,7 @@ def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Nested record function should have child "aten::add" nested_child_events = nested_function_event.cpu_children self.assertTrue("aten::add" in (child.name for child in nested_child_events)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/test/jit/xnnpack/test_xnnpack_delegate.py b/test/jit/xnnpack/test_xnnpack_delegate.py index 6996ee7e4d4685..b97765ed5bb0ba 100644 --- a/test/jit/xnnpack/test_xnnpack_delegate.py +++ b/test/jit/xnnpack/test_xnnpack_delegate.py @@ -184,3 +184,10 @@ def forward(self, x, y): } }, ) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/lazy/test_bindings.py b/test/lazy/test_bindings.py index 39466b33a168dc..f84763f695941a 100644 --- a/test/lazy/test_bindings.py +++ b/test/lazy/test_bindings.py @@ -1,8 +1,13 @@ # Owner(s): ["oncall: jit"] import torch._lazy.metrics +from torch.testing._internal.common_utils import run_tests def test_metrics(): names = torch._lazy.metrics.counter_names() assert len(names) == 0, f"Expected no counter names, but got {names}" + + +if __name__ == "__main__": + run_tests() diff --git a/test/lazy/test_extract_compiled_graph.py b/test/lazy/test_extract_compiled_graph.py index 79359ddb769ab4..1ea0219066d4a7 100644 --- a/test/lazy/test_extract_compiled_graph.py +++ b/test/lazy/test_extract_compiled_graph.py @@ -206,3 +206,10 @@ class OptimizeTest(unittest.TestCase): test_return_multi = maketest(ModuleReturnMulti) test_return_dup_tensor = maketest(ModuleReturnDupTensor) test_inplace_update = maketest(ModuleInplaceUpdate) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/lazy/test_meta_kernel.py b/test/lazy/test_meta_kernel.py index e212fca89ba414..e0922b88fc28a3 100644 --- a/test/lazy/test_meta_kernel.py +++ b/test/lazy/test_meta_kernel.py @@ -37,3 +37,10 @@ def test_addmm(self): def test_add_invalid_device(self): with self.assertRaisesRegex(RuntimeError, ".*not a lazy tensor.*"): _ = torch.tensor([1], device="cpu") + torch.tensor([1], device="lazy") + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/mobile/model_test/gen_test_model.py b/test/mobile/model_test/gen_test_model.py index f234210d28905c..5e760a739cec7f 100644 --- a/test/mobile/model_test/gen_test_model.py +++ b/test/mobile/model_test/gen_test_model.py @@ -118,16 +118,16 @@ def calcOpsCoverage(ops): uncovered_ops = production_ops - covered_ops coverage = round(100 * len(covered_ops) / len(production_ops), 2) - # weighted coverage (take op occurances into account) - total_occurances = sum(production_ops_dict["root_operators"].values()) + # weighted coverage (take op occurrences into account) + total_occurrences = sum(production_ops_dict["root_operators"].values()) covered_ops_dict = { op: production_ops_dict["root_operators"][op] for op in covered_ops } uncovered_ops_dict = { op: production_ops_dict["root_operators"][op] for op in uncovered_ops } - covered_occurances = sum(covered_ops_dict.values()) - occurances_coverage = round(100 * covered_occurances / total_occurances, 2) + covered_occurrences = sum(covered_ops_dict.values()) + occurrences_coverage = round(100 * covered_occurrences / total_occurrences, 2) print(f"\n{len(uncovered_ops)} uncovered ops: {uncovered_ops}\n") print(f"Generated {len(all_generated_ops)} ops") @@ -135,7 +135,7 @@ def calcOpsCoverage(ops): f"Covered {len(covered_ops)}/{len(production_ops)} ({coverage}%) production ops" ) print( - f"Covered {covered_occurances}/{total_occurances} ({occurances_coverage}%) occurances" + f"Covered {covered_occurrences}/{total_occurrences} ({occurrences_coverage}%) occurrences" ) print(f"pytorch ver {torch.__version__}\n") diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 38238405d7bd83..858e72416c890c 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -1293,11 +1293,12 @@ def func(*inputs): torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [] ) ) - def test_Conv2d_deterministic_cudnn(self, device, dtype): - inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True) + @parametrize_test("dilation", [1, 2, 3]) + def test_Conv2d_deterministic_cudnn(self, device, dtype, dilation): + inputs = torch.randn(2, 3, 7, 7, device=device, dtype=dtype, requires_grad=True) with cudnn.flags(enabled=True, benchmark=True, deterministic=True): - conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) - conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) + conv1 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype) + conv2 = torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype) conv2.bias.data.copy_(conv1.bias.data) conv2.weight.data.copy_(conv1.weight.data) out1 = conv1(inputs) @@ -1797,7 +1798,9 @@ def test_conv3d_valid_padding(self, device, dtype): self.assertEqual(expect, actual) @dtypes(torch.float, torch.cfloat) - @dtypesIfMPS(torch.float) + @dtypesIfMPS( + *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat]) + ) # Complex not supported on MacOS13 def test_conv1d_same_padding_backward(self, device, dtype): # Test F.conv1d gradients work with padding='same' x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True) @@ -2132,9 +2135,7 @@ def test_conv3d_valid_padding_backward(self, device, dtype): arg_str="N", arg_values=[ subtest(arg_values=(2), name="ConvTranspose2d"), - subtest( - arg_values=(3), name="ConvTranspose3d", decorators=[expectedFailureMPS] - ), + subtest(arg_values=(3), name="ConvTranspose3d"), ], ) def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): @@ -3091,7 +3092,6 @@ def test_conv_large_nosplit(self, device): input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device) conv2(input_large) - @expectedFailureMPS # ConvTranspose 3D is not supported on MPS def test_conv_noncontig_weights(self, device): for dim in (1, 2, 3): for grouped in (False, True): diff --git a/test/nn/test_lazy_modules.py b/test/nn/test_lazy_modules.py index d64020c2dcc637..6cc78cbfc51a1d 100644 --- a/test/nn/test_lazy_modules.py +++ b/test/nn/test_lazy_modules.py @@ -33,7 +33,7 @@ def test_lazy_module_parameter(self): new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5))) with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): new_module.load_state_dict(state_dict) - # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one + # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one new_module = LazyModule() new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5))) module.load_state_dict(new_module.state_dict()) @@ -62,7 +62,7 @@ def test_lazy_module_buffer(self): new_module.test_buffer = Buffer(torch.ones(5, 5)) with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): new_module.load_state_dict(state_dict) - # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one + # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one new_module = LazyModule() new_module.test_buffer = Buffer(torch.ones(5, 5)) module.load_state_dict(new_module.state_dict()) diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 641017284c63e1..8ce1f03c0a841f 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -470,9 +470,9 @@ def module_load(dest, src, assign=False): return cls(src._data) return cls(src) else: - assert isinstance( - src, cls - ), f"Expected isinstance(src, {cls}) but got {type(src)}" + assert isinstance(src, cls), ( + f"Expected isinstance(src, {cls}) but got {type(src)}" + ) assert ( type(dest) == torch.Tensor or type(dest) == torch.nn.Parameter diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index c9c29f0ba4a3d0..72e3665cfdd5d8 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -1445,7 +1445,14 @@ def hook(mod, grad_input, grad_output): mod.register_full_backward_hook(hook) # This should run and trigger the hook properly - mod(inp).sum().backward() + with self.assertWarnsRegex( + UserWarning, + ( + "Full backward hook is firing when gradients are computed with " + "respect to module outputs since no inputs require gradients" + ), + ): + mod(inp).sum().backward() self.assertEqual(hook_called[0], 1) return_val = "grad_input" diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index cbc2a143ec4065..eb1f7c982b7cad 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -1475,9 +1475,9 @@ def test_new_spectral_norm_load_state_dict(self): snm.load_state_dict(non_strict_state_dict, strict=False) del non_strict_state_dict["parametrizations.weight.0._v"] snm.load_state_dict(non_strict_state_dict, strict=False) - non_strict_state_dict[ - "weight" - ] = snm.weight.detach().clone() # set W as a buffer + non_strict_state_dict["weight"] = ( + snm.weight.detach().clone() + ) # set W as a buffer snm.load_state_dict(non_strict_state_dict, strict=False) del non_strict_state_dict._metadata[ "parametrizations.weight.0" @@ -1652,7 +1652,7 @@ def assert_weight_allclose_Q(weight, W): if can_initialize: assert_weight_allclose_Q(m.weight, w_init) - # Intializing with a given orthogonal matrix works + # Initializing with a given orthogonal matrix works X = torch.randn_like(m.weight) if wide_matrix: X = X.mT @@ -1669,7 +1669,7 @@ def assert_weight_allclose_Q(weight, W): with self.assertRaisesRegex(NotImplementedError, msg): m.weight = w_new - # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix + # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix w_new = torch.randn_like(m.weight) if can_initialize: m.weight = w_new diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 537d2a6f8ec313..c4722741a46624 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -50,6 +50,11 @@ def forward( return x - w, x - y, c +class SampleModelForDimOne(torch.nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y), axis=1) + z + + class TestExportAPIDynamo(common_utils.TestCase): """Tests for the ONNX exporter API when dynamo=True.""" @@ -288,6 +293,17 @@ def forward(self, x): input = torch.randn(2) self.assert_export(Model(), (input)) + def test_export_successful_when_dynamic_dimension_is_one(self): + self.assert_export( + SampleModelForDimOne(), + (torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 8)), + dynamic_shapes=( + {0: "batch", 1: "sequence"}, + {0: "batch", 1: "sequence"}, + {0: "batch", 1: "sequence"}, + ), + ) + class TestCustomTranslationTable(common_utils.TestCase): def test_custom_translation_table_overrides_ops(self): diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index 9b90e2f878459f..8bdca45920f032 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -5,8 +5,13 @@ import logging +import onnx.reference as onnx_ref + +import onnxruntime +import pytest import transformers from onnxscript import ir +from packaging import version import torch from torch.onnx._internal.exporter import _testing as onnx_testing @@ -14,8 +19,11 @@ from torch.utils import _pytree as torch_pytree -@common_utils.instantiate_parametrized_tests -class DynamoExporterTest(common_utils.TestCase): +def has_onnxruntime_opset_23() -> bool: + return version.parse(onnxruntime.__version__) >= version.parse("1.22") + + +class _WithExport: def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram: onnx_program = torch.onnx.export( model, @@ -29,6 +37,9 @@ def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgr assert onnx_program is not None return onnx_program + +@common_utils.instantiate_parametrized_tests +class DynamoExporterTest(common_utils.TestCase, _WithExport): def test_insert_contiguous_between_transpose_and_view(self): class Model(torch.nn.Module): def forward(self, query, key, value): @@ -227,8 +238,8 @@ def forward(self): onnx_program = self.export(Float4Module(), optimize=False) output = onnx_program.model.graph.outputs[0] self.assertEqual(output.dtype, ir.DataType.FLOAT4E2M1) - # The shape is [*shape, 2] because ONNX stores the shape of the unpacked tensor - self.assertEqual(output.shape.dims, [1, 2]) + # The shape is [*shape[:-1], shape[-1]*2] because ONNX stores the shape of the unpacked tensor + self.assertEqual(output.shape.numpy(), [2]) def test_bfloat16_support(self): class BfloatModel(torch.nn.Module): @@ -306,7 +317,7 @@ def forward(self, x, y): return x + y dim0_x = torch.export.Dim("dim0_x", min=6) - dynamic_shapes = {"x": {0: dim0_x}, "y": None} + dynamic_shapes = {"x": {0: dim0_x}, "y": torch.export.Dim.STATIC} # specialized input y to 5 during tracing onnx_program = self.export( Model(), @@ -547,11 +558,11 @@ def forward(self, x, y, z): # all of these should be fine dynamic_shapes = ( {0: dx, 1: torch.export.Dim.AUTO}, - {0: dy, 1: None}, + {0: dy, 1: torch.export.Dim.STATIC}, {0: dz, 1: 3}, ) onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes) - onnx_testing.assert_onnx_program(onnx_program, args=inputs) + onnx_testing.assert_onnx_program(onnx_program) # make sre the naming is working self.assertEqual(onnx_program.model.graph.inputs[0].shape[0], "dx") @@ -563,7 +574,7 @@ def forward(self, x): inputs = (torch.zeros((2, 3)),) dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},) onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes) - onnx_testing.assert_onnx_program(onnx_program, args=inputs) + onnx_testing.assert_onnx_program(onnx_program) self.assertIn( "Max", [node.op_type for node in onnx_program.model.graph], @@ -577,7 +588,7 @@ def forward(self, x): inputs = (torch.zeros((2, 3)),) dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},) onnx_program = self.export(Model(), inputs, dynamic_shapes=dynamic_shapes) - onnx_testing.assert_onnx_program(onnx_program, args=inputs) + onnx_testing.assert_onnx_program(onnx_program) self.assertIn( "Min", [node.op_type for node in onnx_program.model.graph], @@ -592,7 +603,7 @@ def forward(self, x): inputs = (torch.zeros((2, 2)),) dynamic_shapes = ({0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},) onnx_program = self.export(SymNotModel(), inputs, dynamic_shapes=dynamic_shapes) - onnx_testing.assert_onnx_program(onnx_program, args=inputs) + onnx_testing.assert_onnx_program(onnx_program) self.assertIn( "Not", [node.op_type for node in onnx_program.model.graph], @@ -609,12 +620,106 @@ def forward(self, x): onnx_program = self.export( SymFloatModel(), inputs, dynamic_shapes=dynamic_shapes ) - onnx_testing.assert_onnx_program(onnx_program, args=inputs) + onnx_testing.assert_onnx_program(onnx_program) self.assertIn( "Cast", [node.op_type for node in onnx_program.model.graph], ) + def test_scan_cdist_add(self): + def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor): + sub = samex - x.reshape((1, -1)) + sq = sub * sub + rd = torch.sqrt(sq.sum(axis=1)) + return [unused.clone(), rd] + + class ScanModel(torch.nn.Module): + def forward(self, x): + z = torch.tensor([0], dtype=torch.float32) + y = x.clone() + out = torch.ops.higher_order.scan(dist, [z], [x], additional_inputs=[y]) + return out[1] + + inputs = ( + torch.tensor( + [[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32 + ), + ) + onnx_program = self.export(ScanModel(), inputs) + onnx_testing.assert_onnx_program(onnx_program) + + def test_scan_cdist_dynamic_shapes(self): + def dist(y: torch.Tensor, scanned_x: torch.Tensor): + sub = y - scanned_x.reshape((1, -1)) + sq = sub * sub + rd = torch.sqrt(sq.sum(axis=1)) + return [y.clone(), rd] + + class ScanModel(torch.nn.Module): + def forward(self, x, y): + carry, out = torch.ops.higher_order.scan( + dist, [y], [x], additional_inputs=[] + ) + return out + + x_rows = torch.export.Dim("x_rows") + y_rows = torch.export.Dim("y_rows") + dim = torch.export.Dim("dim") + inputs = (torch.randn(3, 4), torch.randn(5, 4)) + onnx_program = self.export( + ScanModel(), + inputs, + dynamic_shapes=({0: x_rows, 1: dim}, {0: y_rows, 1: dim}), + ) + onnx_testing.assert_onnx_program(onnx_program) + + @pytest.mark.xfail(reason="Data dependent error.") + def test_scan_loop_inplace(self): + def dummy_loop(padded: torch.Tensor, pos: torch.Tensor): + copy = torch.zeros(padded.shape) + for i in range(pos.shape[0]): + p = pos[i] + copy[i, :p] = padded[i, :p] + return copy + + def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor): + def pad_row(padded, p): + row = torch.zeros((padded.shape[0],)) + torch._check(p.item() > 0) + torch._check(p.item() < padded.shape[0]) + # this check is not always true, we add it anyway to make this dimension >= 2 + # and avoid raising an exception about dynamic dimension in {0, 1} + if torch.compiler.is_exporting(): + torch._check(p.item() > 1) + row[: p.item()] = padded[: p.item()] + return (row,) + + return torch.ops.higher_order.scan(pad_row, [], [padded, pos], []) + + def select_when_exporting(f, f_scan): + return f_scan if torch.compiler.is_exporting() else f + + class ScanModel(torch.nn.Module): + def forward(self, images, position): + return select_when_exporting(dummy_loop, dummy_loop_with_scan)( + images, position + ) + + DYN = torch.export.Dim.DYNAMIC + x = torch.randn((5, 6)) + y = torch.arange(5, dtype=torch.int64) + 1 + ep = torch.export.export( + ScanModel(), + (x, y), + dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}}, + strict=False, + ) + onnx_program = self.export(ep) + onnx_testing.assert_onnx_program(onnx_program) + + +@common_utils.instantiate_parametrized_tests +class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport): def test_group_norm_opset_21(self): class Model(torch.nn.Module): def forward(self, x): @@ -629,6 +734,46 @@ def forward(self, x): [node.op_type for node in onnx_program.model.graph], ) + def test_graph_attention_opset_23(self): + class Model(torch.nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value + ) + + query = torch.rand(32, 8, 128, 64, dtype=torch.float16) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16) + value = torch.rand(32, 8, 128, 64, dtype=torch.float16) + expected = Model()(query, key, value) + + onnx_program = self.export(Model(), (query, key, value), opset_version=23) + self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph]) + + ref = onnx_ref.ReferenceEvaluator(onnx_program.model_proto) + got = ref.run( + None, dict(query=query.numpy(), key=key.numpy(), value=value.numpy()) + )[0] + torch.testing.assert_close(torch.from_numpy(got), expected, atol=1e-2, rtol=1) + + def test_graph_accuracy_attention_opset_23(self): + class Model(torch.nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value + ) + + query = torch.rand(32, 8, 128, 64, dtype=torch.float16) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16) + value = torch.rand(32, 8, 128, 64, dtype=torch.float16) + + onnx_program = self.export( + Model(), (query, key, value), opset_version=23, optimize=True + ) + self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph]) + # onnxruntime inlines any op defined as a function and without any implemented kernel + if has_onnxruntime_opset_23(): + onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/internal/test_registraion.py b/test/onnx/internal/test_registraion.py index 39afcc24ee65e0..e357dbff713a8e 100644 --- a/test/onnx/internal/test_registraion.py +++ b/test/onnx/internal/test_registraion.py @@ -144,7 +144,7 @@ def test_remove_override_removes_overridden_key(self): self.assertEqual(len(self.override_dict), 0) self.assertNotIn("a", self.override_dict) - def test_overriden_key_precededs_base_key_regardless_of_insert_order(self): + def test_overridden_key_precedes_base_key_regardless_of_insert_order(self): self.override_dict.set_base("a", 42) self.override_dict.override("a", 100) self.override_dict.set_base("a", 0) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index a138d023c5f928..ef0c3ba3ebcaf1 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -236,12 +236,6 @@ def run_ort( MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1) -# The min onnx opset version to test for -FX_MIN_ONNX_OPSET_VERSION = 18 -# The max onnx opset version to test for -FX_MAX_ONNX_OPSET_VERSION = 18 -FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1) - BOOL_TYPES = (torch.bool,) INT_TYPES = ( diff --git a/test/onnx/ops/test_ops.py b/test/onnx/ops/test_ops.py index 0125d5b08dd127..437c74e9bfbfd9 100644 --- a/test/onnx/ops/test_ops.py +++ b/test/onnx/ops/test_ops.py @@ -3,10 +3,11 @@ from __future__ import annotations +import onnx_ir.passes.common as common_passes from onnxscript import ir import torch -from torch.onnx.ops import _symbolic_impl +from torch.onnx.ops import _impl, _symbolic_impl from torch.testing._internal import common_utils @@ -414,5 +415,1039 @@ def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self): ) +class NativeOnnxOpsTest(common_utils.TestCase): + def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram: + onnx_program = torch.onnx.export( + model, + args, + kwargs=kwargs, + dynamo=True, + fallback=False, + verbose=False, + **options, + ) + assert onnx_program is not None + common_passes.CheckerPass()(onnx_program.model) + return onnx_program + + def test_onnx_ops_can_be_decomposed_to_aten(self): + input_data = torch.rand(2, 3, 4, 8) + position_ids_data = torch.randint(0, 50, (2, 3)).long() + sin_cache_data = torch.rand(50, 4) + cos_cache_data = torch.rand(50, 4) + + class Model(torch.nn.Module): + def forward( + self, input_data, cos_cache_data, sin_cache_data, position_ids_data + ): + return torch.onnx.ops.rotary_embedding( + input_data, + cos_cache_data, + sin_cache_data, + position_ids_data, + interleaved=True, + ) + + model = Model() + + ep = torch.export.export( + model, + (input_data, cos_cache_data, sin_cache_data, position_ids_data), + ) + self.assertIn( + "onnx.RotaryEmbedding.opset23", + [str(node.target) for node in ep.graph.nodes], + ) + # The program can be decomposed into aten ops so it is fully compatible with the PyTorch ecosystem + aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions()) + self.assertNotIn( + "onnx.RotaryEmbedding.opset23", + [str(node.target) for node in aten_decomped.graph.nodes], + ) + torch.testing.assert_close( + aten_decomped.module()( + input_data, cos_cache_data, sin_cache_data, position_ids_data + ), + model(input_data, cos_cache_data, sin_cache_data, position_ids_data), + ) + + def test_rotary_embedding_opcheck(self): + input_data = torch.rand(2, 3, 4, 8) + position_ids_data = torch.randint(0, 50, (2, 3)).long() + sin_cache_data = torch.rand(50, 4) + cos_cache_data = torch.rand(50, 4) + + torch.library.opcheck( + _impl.rotary_embedding_23, + (input_data, cos_cache_data, sin_cache_data, position_ids_data), + ) + + def test_rotary_embedding(self): + input_data = torch.rand(2, 3, 4, 8) + position_ids_data = torch.randint(0, 50, (2, 3)).long() + sin_cache_data = torch.rand(50, 4) + cos_cache_data = torch.rand(50, 4) + + # Eager mode is supported. Autograd is also supported so users can choose to use the op + # in development and production + result = torch.onnx.ops.rotary_embedding( + input_data, cos_cache_data, sin_cache_data, position_ids_data + ) + self.assertEqual(result.shape, input_data.shape) + + class Model(torch.nn.Module): + def forward( + self, input_data, cos_cache_data, sin_cache_data, position_ids_data + ): + return torch.onnx.ops.rotary_embedding( + input_data, + cos_cache_data, + sin_cache_data, + position_ids_data, + interleaved=True, + ) + + model = Model() + + # Dynamic shapes are supported + dynamic_shapes = { + "input_data": {0: torch.export.Dim.DYNAMIC}, + "cos_cache_data": None, + "sin_cache_data": None, + "position_ids_data": {0: torch.export.Dim.DYNAMIC}, + } + + onnx_program = self.export( + model, + (input_data, cos_cache_data, sin_cache_data, position_ids_data), + dynamic_shapes=dynamic_shapes, + opset_version=23, + ) + self.assertEqual(onnx_program.model.opset_imports[""], 23) + self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type) + + def test_attention_basic(self): + """Test basic attention functionality.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test eager mode + torch.library.opcheck(_impl.attention_23, (Q, K, V)) + output, present_key, present_value, qk_output = torch.onnx.ops.attention( + Q, K, V + ) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + self.assertEqual(present_key.shape, K.shape) + self.assertEqual(present_value.shape, V.shape) + self.assertEqual( + qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len) + ) + + def test_attention_3d_inputs(self): + """Test attention with 3D inputs (requires num_heads parameters).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size) + K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + + torch.library.opcheck( + _impl.attention_23, + (Q, K, V), + dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads), + ) + output, present_key, present_value, qk_output = torch.onnx.ops.attention( + Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads + ) + + # Output should be reshaped back to 3D + self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size)) + self.assertEqual( + present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size) + ) + self.assertEqual( + present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size) + ) + + def test_attention_gqa(self): + """Test Group Query Attention (GQA).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 4 # GQA: q_num_heads % kv_num_heads = 0 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + torch.library.opcheck(_impl.attention_23, (Q, K, V)) + output, present_key, present_value, qk_output = torch.onnx.ops.attention( + Q, K, V + ) + expected = torch.nn.functional.scaled_dot_product_attention( + Q, K, V, None, enable_gqa=True + ) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + self.assertEqual(present_key.shape, K.shape) + self.assertEqual(present_value.shape, V.shape) + torch.testing.assert_close(output, expected) + + def test_attention_mqa(self): + """Test Multi-Query Attention (MQA).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 1 # MQA: kv_num_heads = 1 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + torch.library.opcheck(_impl.attention_23, (Q, K, V)) + output, present_key, present_value, qk_output = torch.onnx.ops.attention( + Q, K, V + ) + expected = torch.nn.functional.scaled_dot_product_attention( + Q, K, V, None, enable_gqa=True + ) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + torch.testing.assert_close(output, expected) + + def test_attention_with_2d_mask(self): + """Test attention with 2D attention mask (q_seq_len, kv_seq_len).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test with boolean mask + bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask)) + output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask) + + # Test with float mask + float_mask = torch.randn(q_seq_len, kv_seq_len) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask)) + output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask) + + self.assertEqual( + output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + self.assertEqual( + output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + + def test_attention_with_4d_mask(self): + """Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test with boolean mask + bool_mask = torch.randint( + 0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool + ) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask)) + output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask) + + # Test with float mask + float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask)) + output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask) + + self.assertEqual( + output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + self.assertEqual( + output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + + def test_attention_with_zero_float_mask(self): + """Test attention with zero float mask.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + zero_mask = torch.zeros(q_seq_len, kv_seq_len) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask)) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_with_causal_mask_pattern(self): + """Test attention with lower triangular causal mask pattern.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Create a lower triangular causal mask + causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool)) + torch.library.opcheck( + _impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask) + ) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_with_gqa_and_mask(self): + """Test attention with GQA and different mask shapes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 4 # GQA + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test 2D mask with GQA + mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d)) + output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d) + + # Test 4D mask with GQA (note: using q_num_heads for mask heads) + mask_4d = torch.randint( + 0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool + ) + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d)) + output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d) + + self.assertEqual( + output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + self.assertEqual( + output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + + def test_attention_with_large_negative_float_mask(self): + """Test attention with large negative values in float mask.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Create mask with large negative values (similar to -inf masking) + float_mask = torch.full((q_seq_len, kv_seq_len), -1e9) + # Allow some positions + float_mask[:, :3] = 0.0 + + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask)) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_causal(self): + """Test causal attention.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True)) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_with_past_kv(self): + """Test attention with past key/value caches.""" + batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + + torch.library.opcheck( + _impl.attention_23, + (Q, K, V), + dict(past_key=past_key, past_value=past_value), + ) + output, present_key, present_value, _ = torch.onnx.ops.attention( + Q, K, V, past_key=past_key, past_value=past_value + ) + + # Present key/value should include past + current + expected_total_seq_len = past_seq_len + kv_seq_len + self.assertEqual( + present_key.shape, + (batch_size, kv_num_heads, expected_total_seq_len, head_size), + ) + self.assertEqual( + present_value.shape, + (batch_size, kv_num_heads, expected_total_seq_len, head_size), + ) + + def test_attention_with_softcap(self): + """Test attention with softcap.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(softcap=30.0)) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, softcap=30.0) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_qk_output_modes(self): + """Test different QK matmul output modes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + for mode in [0, 1, 2, 3]: + torch.library.opcheck( + _impl.attention_23, + (Q, K, V), + dict(qk_matmul_output_mode=mode), + ) + output, _, _, qk_output = torch.onnx.ops.attention( + Q, K, V, qk_matmul_output_mode=mode + ) + + self.assertEqual( + output.shape, (batch_size, q_num_heads, q_seq_len, head_size) + ) + self.assertEqual( + qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len) + ) + + def test_attention_custom_scale(self): + """Test attention with custom scale factor.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + custom_scale = 0.25 + torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(scale=custom_scale)) + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, scale=custom_scale) + + self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size)) + + def test_attention_export(self): + """Test that attention can be exported to ONNX.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class AttentionModel(torch.nn.Module): + def forward(self, Q, K, V): + output, present_key, present_value, qk_output = ( + torch.onnx.ops.attention(Q, K, V) + ) + return output + + model = AttentionModel() + + onnx_program = self.export( + model, + (Q, K, V), + opset_version=23, + ) + + self.assertEqual(onnx_program.model.opset_imports[""], 23) + self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type) + + def test_attention_export_with_dynamic_shapes(self): + """Test attention export with dynamic shapes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class AttentionModel(torch.nn.Module): + def forward(self, Q, K, V): + output, present_key, present_value, qk_output = ( + torch.onnx.ops.attention(Q, K, V) + ) + return output + + model = AttentionModel() + + dynamic_shapes = { + "Q": {0: "batch", 2: "q_seq_len"}, + "K": {0: "batch", 2: "kv_seq_len"}, + "V": {0: "batch", 2: "kv_seq_len"}, + } + + onnx_program = self.export( + model, + (Q, K, V), + dynamic_shapes=dynamic_shapes, + opset_version=23, + ) + + self.assertEqual(onnx_program.model.opset_imports[""], 23) + self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type) + node = onnx_program.model.graph.node(0) + # Verify inputs + self.assertEqual(len(node.inputs), 3) # Q, K, V (no optional inputs) + self.assertEqual( + node.inputs[0].shape, ["batch", q_num_heads, "q_seq_len", head_size] + ) + self.assertEqual( + node.inputs[1].shape, ["batch", kv_num_heads, "kv_seq_len", head_size] + ) + self.assertEqual( + node.inputs[2].shape, ["batch", kv_num_heads, "kv_seq_len", head_size] + ) + + # Verify default attributes (should be minimal) + self.assertEqual(len(node.attributes), 0) + + def test_attention_3d_export(self): + """Test attention export with 3D inputs.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size) + K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + + class AttentionModel(torch.nn.Module): + def forward(self, Q, K, V): + output, _, _, _ = torch.onnx.ops.attention( + Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads + ) + return output + + model = AttentionModel() + + onnx_program = self.export( + model, + (Q, K, V), + opset_version=23, + ) + + self.assertEqual(onnx_program.model.opset_imports[""], 23) + self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type) + + def test_attention_decomposition(self): + """Test that attention can be decomposed to aten ops.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class AttentionModel(torch.nn.Module): + def forward(self, Q, K, V): + output, present_key, present_value, qk_output = ( + torch.onnx.ops.attention(Q, K, V) + ) + return output + + model = AttentionModel() + + ep = torch.export.export(model, (Q, K, V)) + self.assertIn( + "onnx.Attention.opset23", + [str(node.target) for node in ep.graph.nodes], + ) + + # The program can be decomposed into aten ops + aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions()) + self.assertNotIn( + "onnx.Attention.opset23", + [str(node.target) for node in aten_decomped.graph.nodes], + ) + + # Results should match + torch.testing.assert_close( + aten_decomped.module()(Q, K, V), + model(Q, K, V), + ) + + def test_attention_export_with_past_key_value(self): + """Test export with past_key, past_value to ensure the optional input order is correct.""" + batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + + class Model(torch.nn.Module): + def forward(self, Q, K, V, past_key, past_value): + output, _, _, _ = torch.onnx.ops.attention( + Q, + K, + V, + past_key=past_key, + attn_mask=None, + # Switched argument order + past_value=past_value, + ) + return output + + model = Model() + onnx_program = self.export( + model, (Q, K, V, past_key, past_value), opset_version=23 + ) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify all 6 inputs are present + self.assertEqual( + len(node.inputs), 6 + ) # Q, K, V, attn_mask, past_key, past_value + self.assertEqual( + node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size] + ) + self.assertEqual( + node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + self.assertEqual( + node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + self.assertIsNone(node.inputs[3]) + self.assertEqual( + node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size] + ) + self.assertEqual( + node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size] + ) + + def test_attention_export_with_all_optional_inputs(self): + """Test export with all optional inputs: mask, past_key, past_value.""" + batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + attn_mask = torch.randint( + 0, 2, (1, 1, q_seq_len, kv_seq_len + past_seq_len), dtype=torch.bool + ) + past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size) + + class FullAttentionModel(torch.nn.Module): + def forward(self, Q, K, V, attn_mask, past_key, past_value): + output, _, _, _ = torch.onnx.ops.attention( + Q, + K, + V, + attn_mask=attn_mask, + past_key=past_key, + past_value=past_value, + ) + return output + + model = FullAttentionModel() + onnx_program = self.export( + model, (Q, K, V, attn_mask, past_key, past_value), opset_version=23 + ) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify all 6 inputs are present + self.assertEqual( + len(node.inputs), 6 + ) # Q, K, V, attn_mask, past_key, past_value + self.assertEqual( + node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size] + ) + self.assertEqual( + node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + self.assertEqual( + node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + self.assertEqual( + node.inputs[3].shape, [1, 1, q_seq_len, kv_seq_len + past_seq_len] + ) + self.assertEqual( + node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size] + ) + self.assertEqual( + node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size] + ) + + def test_attention_export_3d_with_num_heads_attributes(self): + """Test export with 3D inputs and explicit num_heads attributes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 4 # GQA + head_size = 64 + + Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size) + K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size) + + class Attention3DModel(torch.nn.Module): + def forward(self, Q, K, V): + output, _, _, _ = torch.onnx.ops.attention( + Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads + ) + return output + + model = Attention3DModel() + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify 3D input shapes + self.assertEqual( + node.inputs[0].shape, [batch_size, q_seq_len, q_num_heads * head_size] + ) + self.assertEqual( + node.inputs[1].shape, [batch_size, kv_seq_len, kv_num_heads * head_size] + ) + self.assertEqual( + node.inputs[2].shape, [batch_size, kv_seq_len, kv_num_heads * head_size] + ) + + # Verify num_heads attributes are set + attrs = node.attributes + self.assertIn("q_num_heads", attrs) + self.assertIn("kv_num_heads", attrs) + self.assertEqual(attrs["q_num_heads"].value, q_num_heads) + self.assertEqual(attrs["kv_num_heads"].value, kv_num_heads) + + def test_attention_export_with_all_attributes(self): + """Test export with all possible attributes set.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class FullAttributesModel(torch.nn.Module): + def forward(self, Q, K, V): + output, _, _, _ = torch.onnx.ops.attention( + Q, + K, + V, + is_causal=True, + qk_matmul_output_mode=2, + scale=0.25, + softcap=30.0, + softmax_precision=1, # FLOAT + ) + return output + + model = FullAttributesModel() + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify all attributes are set correctly + attrs = node.attributes + self.assertIn("is_causal", attrs) + self.assertIn("qk_matmul_output_mode", attrs) + self.assertIn("scale", attrs) + self.assertIn("softcap", attrs) + self.assertIn("softmax_precision", attrs) + + self.assertEqual(attrs["is_causal"].value, 1) # True as int + self.assertEqual(attrs["qk_matmul_output_mode"].value, 2) + self.assertAlmostEqual(attrs["scale"].value, 0.25, places=6) + self.assertAlmostEqual(attrs["softcap"].value, 30.0, places=6) + self.assertEqual(attrs["softmax_precision"].value, 1) + + def test_attention_export_with_different_mask_shapes(self): + """Test export with different attention mask shapes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test 2D mask + mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool) + + class Mask2DModel(torch.nn.Module): + def forward(self, Q, K, V, mask): + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask) + return output + + model_2d = Mask2DModel() + onnx_program_2d = self.export(model_2d, (Q, K, V, mask_2d), opset_version=23) + + node_2d = onnx_program_2d.model.graph.node(0) + self.assertEqual(node_2d.inputs[3].shape, [q_seq_len, kv_seq_len]) + + # Test 3D mask + mask_3d = torch.randint( + 0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool + ) + + class Mask3DModel(torch.nn.Module): + def forward(self, Q, K, V, mask): + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask) + return output + + model_3d = Mask3DModel() + onnx_program_3d = self.export(model_3d, (Q, K, V, mask_3d), opset_version=23) + + node_3d = onnx_program_3d.model.graph.node(0) + self.assertEqual( + node_3d.inputs[3].shape, [batch_size, 1, q_seq_len, kv_seq_len] + ) + + # Test 4D mask + mask_4d = torch.randint( + 0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool + ) + + class Mask4DModel(torch.nn.Module): + def forward(self, Q, K, V, mask): + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask) + return output + + model_4d = Mask4DModel() + onnx_program_4d = self.export(model_4d, (Q, K, V, mask_4d), opset_version=23) + + node_4d = onnx_program_4d.model.graph.node(0) + self.assertEqual( + node_4d.inputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len] + ) + + def test_attention_export_with_float_mask(self): + """Test export with float attention mask.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + float_mask = torch.randn(q_seq_len, kv_seq_len) + + class FloatMaskModel(torch.nn.Module): + def forward(self, Q, K, V, mask): + output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask) + return output + + model = FloatMaskModel() + onnx_program = self.export(model, (Q, K, V, float_mask), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + self.assertEqual(node.inputs[3].shape, [q_seq_len, kv_seq_len]) + # Verify the mask input has float dtype in the ONNX model + self.assertEqual(node.inputs[3].dtype, ir.DataType.FLOAT) + + def test_attention_export_qk_output_modes(self): + """Test export with different QK output modes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + for mode in [0, 1, 2, 3]: + + class QKOutputModel(torch.nn.Module): + def __init__(self, qk_mode): + super().__init__() + self.qk_mode = qk_mode + + def forward(self, Q, K, V): + output, _, _, qk_output = torch.onnx.ops.attention( + Q, K, V, qk_matmul_output_mode=self.qk_mode + ) + return output, qk_output + + model = QKOutputModel(mode) + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify qk_matmul_output_mode attribute + attrs = node.attributes + if mode != 0: + self.assertIn("qk_matmul_output_mode", attrs) + self.assertEqual(attrs["qk_matmul_output_mode"].value, mode) + + # Verify 4 outputs (output, present_key, present_value, qk_output) + self.assertEqual(len(node.outputs), 4) + + def test_attention_export_mqa(self): + """Test export with Multi-Query Attention (MQA).""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 1 # MQA + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class MQAModel(torch.nn.Module): + def forward(self, Q, K, V): + output, _, _, _ = torch.onnx.ops.attention(Q, K, V) + return output + + model = MQAModel() + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify MQA tensor shapes + self.assertEqual( + node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size] + ) + self.assertEqual( + node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) # kv_num_heads = 1 + self.assertEqual( + node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + + def test_attention_export_with_softmax_precision(self): + """Test export with different softmax precision values.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 8 + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + # Test different ONNX precision types + precision_types = [ + (1, "FLOAT"), + (10, "FLOAT16"), + (11, "DOUBLE"), + (16, "BFLOAT16"), + ] + + for precision_val, precision_name in precision_types: + + class SoftmaxPrecisionModel(torch.nn.Module): + def __init__(self, precision): + super().__init__() + self.precision = precision + + def forward(self, Q, K, V): + output, _, _, _ = torch.onnx.ops.attention( + Q, K, V, softmax_precision=self.precision + ) + return output + + model = SoftmaxPrecisionModel(precision_val) + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify softmax_precision attribute + attrs = node.attributes + self.assertIn("softmax_precision", attrs) + self.assertEqual(attrs["softmax_precision"].value, precision_val) + + def test_attention_export_gqa(self): + """Test export and verify output tensor shapes.""" + batch_size, q_seq_len, kv_seq_len = 2, 4, 6 + q_num_heads, kv_num_heads = 8, 4 # GQA + head_size = 64 + + Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size) + K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size) + + class AttentionOutputsModel(torch.nn.Module): + def forward(self, Q, K, V): + return torch.onnx.ops.attention(Q, K, V) + + model = AttentionOutputsModel() + onnx_program = self.export(model, (Q, K, V), opset_version=23) + + node = onnx_program.model.graph.node(0) + self.assertEqual(node.op_type, "Attention") + + # Verify all 4 outputs have correct shapes + outputs = node.outputs + self.assertEqual(len(outputs), 4) + + # output: (batch_size, q_num_heads, q_seq_len, head_size) + self.assertEqual( + outputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size] + ) + + # present_key: (batch_size, kv_num_heads, kv_seq_len, head_size) + self.assertEqual( + outputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + + # present_value: (batch_size, kv_num_heads, kv_seq_len, head_size) + self.assertEqual( + outputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size] + ) + + # qk_output: (batch_size, q_num_heads, q_seq_len, kv_seq_len) + self.assertEqual( + outputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len] + ) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_models_quantized_onnxruntime.py b/test/onnx/test_models_quantized_onnxruntime.py index 81a180ea01fd8f..991bb878df22a4 100644 --- a/test/onnx/test_models_quantized_onnxruntime.py +++ b/test/onnx/test_models_quantized_onnxruntime.py @@ -10,6 +10,7 @@ import torch from torch import nn +from torch.testing._internal import common_utils def _get_test_image_tensor(): @@ -95,3 +96,7 @@ def test_resnext101_32x8d(self): pretrained=True, quantize=True ) self.run_test(model, _get_test_image_tensor()) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_onnxscript_no_runtime.py b/test/onnx/test_onnxscript_no_runtime.py index fcac54d948d8cd..17e92f0e0117e0 100644 --- a/test/onnx/test_onnxscript_no_runtime.py +++ b/test/onnx/test_onnxscript_no_runtime.py @@ -160,3 +160,10 @@ def custom_selu(g, X): ) loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue())) self.assertEqual(len(loop_selu_proto.functions), 1) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index fbfbb1e85d3741..a6e448173f9e5d 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -107,9 +107,7 @@ def get_lr(self, step): [0] + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m] )[-1] - return [ - init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr - ] + return [init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr] optimizer = SGD([torch.rand(1)], lr=1) diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index 00f5db1478c904..8e060907fe5aa9 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -62,9 +62,9 @@ def _multistep_backprop_diff_hyperparams_fn( kwargs: dict[str, Any], *ignored: Any, ) -> tuple[Tensor, ...]: - assert ( - kwargs["differentiable"] is True - ), "Only call this test function when differentiable=True" + assert kwargs["differentiable"] is True, ( + "Only call this test function when differentiable=True" + ) params = params.clone() params.grad = grad @@ -81,9 +81,9 @@ def _multistep_backprop_diff_hyperparams_fn( # so they're passed in as Tensors (not a tuple) and recognized by gradcheck if "beta1" in kwargs or "beta2" in kwargs: # Prevent just one beta kwarg from being passed in - assert ( - "beta1" in kwargs and "beta2" in kwargs - ), "Both betas should be defined in kwargs" + assert "beta1" in kwargs and "beta2" in kwargs, ( + "Both betas should be defined in kwargs" + ) kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))}) kwargs.update( diff --git a/test/package/test_package_fx.py b/test/package/test_package_fx.py index 9976766f47f3df..ffbcb7a511ccd3 100644 --- a/test/package/test_package_fx.py +++ b/test/package/test_package_fx.py @@ -187,6 +187,27 @@ def forward(self, a): input = torch.rand(2, 3) self.assertEqual(loaded_traced(input), traced(input)) + def test_package_gm_preserve_stack_trace(self): + class SimpleTest(torch.nn.Module): + def forward(self, x): + return torch.relu(x + 3.0) + + st = SimpleTest() + traced = symbolic_trace(st) + + for node in traced.graph.nodes: + node.meta["stack_trace"] = f"test_{node.name}" + + f = BytesIO() + with PackageExporter(f) as pe: + pe.save_pickle("model", "model.pkl", traced) + + f.seek(0) + pi = PackageImporter(f) + loaded_traced = pi.load_pickle("model", "model.pkl") + for node in loaded_traced.graph.nodes: + self.assertEqual(f"test_{node.name}", node.meta["stack_trace"]) + if __name__ == "__main__": run_tests() diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 96e0466a039dae..267e36ef8f2c22 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -15,7 +15,6 @@ import json import os -import sys import tempfile import unittest from typing import Any @@ -365,9 +364,6 @@ def test_execution_trace_env_disabled(self, device): self.assertTrue(p.execution_trace_observer is None) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", @@ -419,9 +415,6 @@ def fn(a, b, c): assert found_captured_triton_kernel_node @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 03e922f84b2063..433c9596d05c9c 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -53,7 +53,6 @@ SynchronizedDataLoaderPattern, ) from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_device_type import skipCUDAVersionIn from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_ARM64, @@ -66,8 +65,10 @@ skipIfTorchDynamo, TemporaryDirectoryName, TemporaryFileName, + TEST_CUDA, TEST_WITH_CROSSREF, TEST_WITH_ROCM, + TEST_XPU, TestCase, ) @@ -102,7 +103,6 @@ @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") class TestProfilerCUDA(TestCase): - @skipCUDAVersionIn([(11, 5)]) # https://github.com/pytorch/pytorch/issues/69023 def test_mem_leak(self): """Checks that there's no memory leak when using profiler with CUDA""" t = torch.rand(1, 1).cuda() @@ -2033,38 +2033,55 @@ def test_user_annotation(self): else: self.assertFalse(evt.is_user_annotation) + @unittest.skipUnless(TEST_CUDA or TEST_XPU, "requires gpu") + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") + def test_basic_profile(self): + # test a really basic profile to make sure no erroneous aten ops are run + x = torch.randn(4, device="cuda") + with torch.profiler.profile(with_stack=True) as p: + x *= 2 + names = [e.name for e in p.events()] + for name in names: + if name.startswith("aten") and name != "aten::mul_": + self.assertTrue(False, "Found unexpected event: " + name) + self.assertTrue("aten::mul_" in names) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_dynamic_toggle(self): - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p: + acc = torch.accelerator.current_accelerator() + self.assertIsNotNone(acc) + device = acc.type + gpu_activity = getattr(ProfilerActivity, device.upper(), None) + self.assertIsNotNone(gpu_activity) + activities = [ProfilerActivity.CPU, gpu_activity] + with profile(activities=activities) as p: with torch.profiler.record_function("test_user_annotation"): - x, y = (torch.rand(4, 4).to("cuda") for _ in range(2)) + x, y = (torch.rand(4, 4).to(device) for _ in range(2)) torch.add(x, y) self.assertTrue(any("aten" in e.name for e in p.events())) - self.assertTrue(any("cuda" in e.name for e in p.events())) + self.assertTrue(any(device in e.name for e in p.events())) - self.assertTrue(any("kernel" in e.name for e in p.events())) + self.assertTrue(any("kernel" in e.name.lower() for e in p.events())) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p1: - p1.toggle_collection_dynamic(False, [ProfilerActivity.CUDA]) + with profile(activities=activities) as p1: + p1.toggle_collection_dynamic(False, [gpu_activity]) with torch.profiler.record_function("test_user_annotation"): - x, y = (torch.rand(4, 4).to("cuda") for _ in range(2)) + x, y = (torch.rand(4, 4).to(device) for _ in range(2)) torch.add(x, y) self.assertTrue(any("aten" in e.name for e in p1.events())) - self.assertTrue(all("cuda" not in e.name for e in p1.events())) + self.assertTrue(all(device not in e.name for e in p1.events())) - self.assertTrue(all("kernel" not in e.name for e in p1.events())) + self.assertTrue(all("kernel" not in e.name.lower() for e in p1.events())) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p2: - p2.toggle_collection_dynamic( - False, [ProfilerActivity.CUDA, ProfilerActivity.CPU] - ) + with profile(activities=activities) as p2: + p2.toggle_collection_dynamic(False, activities) with torch.profiler.record_function("test_user_annotation"): - x, y = (torch.rand(4, 4).to("cuda") for _ in range(2)) + x, y = (torch.rand(4, 4).to(device) for _ in range(2)) torch.add(x, y) self.assertTrue(len(p2.events()) == 0) diff --git a/test/quantization/ao_migration/test_ao_migration.py b/test/quantization/ao_migration/test_ao_migration.py index 020dc6d56d8d4e..84fb6b569b78b9 100644 --- a/test/quantization/ao_migration/test_ao_migration.py +++ b/test/quantization/ao_migration/test_ao_migration.py @@ -1,5 +1,7 @@ # Owner(s): ["oncall: quantization"] +from torch.testing._internal.common_utils import raise_on_run_directly + from .common import AOMigrationTestCase @@ -359,3 +361,7 @@ def test_modules_no_import_nn_intrinsic_quantized_dynamic(self): _ = torch.ao.nn.intrinsic.quantized.dynamic _ = torch.nn.intrinsic.quantized.dynamic + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 3d416f3b67a284..c5f186f837c25f 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -1,5 +1,7 @@ # Owner(s): ["oncall: quantization"] +from torch.testing._internal.common_utils import raise_on_run_directly + from .common import AOMigrationTestCase @@ -219,3 +221,7 @@ def test_function_import_utils(self): "weight_is_statically_quantized", ] self._test_function_import("utils", function_list) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/ao_migration/test_quantization_fx.py b/test/quantization/ao_migration/test_quantization_fx.py index 25b3328c8f44d6..78a45e162ee2ee 100644 --- a/test/quantization/ao_migration/test_quantization_fx.py +++ b/test/quantization/ao_migration/test_quantization_fx.py @@ -1,5 +1,7 @@ # Owner(s): ["oncall: quantization"] +from torch.testing._internal.common_utils import raise_on_run_directly + from .common import AOMigrationTestCase @@ -150,3 +152,7 @@ def test_function_import_fx_utils(self): "maybe_get_next_module", ] self._test_function_import("fx.utils", function_list) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py index c7eabd629f4892..911c26defe2820 100644 --- a/test/quantization/bc/test_backward_compatibility.py +++ b/test/quantization/bc/test_backward_compatibility.py @@ -20,7 +20,11 @@ ) # Testing utils -from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase +from torch.testing._internal.common_utils import ( + IS_AVX512_VNNI_SUPPORTED, + raise_on_run_directly, + TestCase, +) from torch.testing._internal.quantization_torch_package_models import ( LinearReluFunctional, ) @@ -565,3 +569,7 @@ def forward(self, x): def test_linear_relu_package_quantization_transforms(self): m = LinearReluFunctional(4).eval() self._test_package(m, input_size=(1, 1, 4, 4), generate=False) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py index 53f943398c4e01..0ef8523836b5d1 100644 --- a/test/quantization/core/experimental/test_adaround_eager.py +++ b/test/quantization/core/experimental/test_adaround_eager.py @@ -134,3 +134,10 @@ def forward(self, x): ada_loss = F.mse_loss(ada_out, float_out) fq_loss = F.mse_loss(fq_out, float_out) self.assertTrue(ada_loss.item() < fq_loss.item()) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/quantization/core/test_docs.py b/test/quantization/core/test_docs.py index 2222ef64b62e58..ab1689cccab2d1 100644 --- a/test/quantization/core/test_docs.py +++ b/test/quantization/core/test_docs.py @@ -11,7 +11,7 @@ SingleLayerLinearModel, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import IS_ARM64, IS_FBCODE +from torch.testing._internal.common_utils import raise_on_run_directly, IS_ARM64, IS_FBCODE import unittest @@ -141,3 +141,6 @@ def test_quantization_doc_custom(self): code = self._get_code(path_from_pytorch, unique_identifier) self._test_code(code, global_inputs) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_quantized_functional.py b/test/quantization/core/test_quantized_functional.py index e593b113b27b23..a890c6358e087d 100644 --- a/test/quantization/core/test_quantized_functional.py +++ b/test/quantization/core/test_quantized_functional.py @@ -16,7 +16,7 @@ _make_conv_test_input, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import IS_PPC +from torch.testing._internal.common_utils import raise_on_run_directly, IS_PPC class TestQuantizedFunctionalOps(QuantizationTestCase): def test_relu_api(self): @@ -235,3 +235,6 @@ def test_grid_sample(self, N, C, H, H_out, W, W_out, scale, zero_point): out_exp = torch.quantize_per_tensor(F.grid_sample(X, grid), scale=scale, zero_point=zero_point, dtype=torch.quint8) np.testing.assert_array_almost_equal( out.int_repr().numpy(), out_exp.int_repr().numpy(), decimal=0) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 8918696078abd4..b2b2b402327ade 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -31,6 +31,7 @@ qengine_is_qnnpack, qengine_is_onednn, ) +from torch.testing._internal.common_utils import raise_on_run_directly import torch.fx from hypothesis import assume, given from hypothesis import strategies as st @@ -2095,3 +2096,6 @@ def test_linear_decomposed_weight_custom_qmin_qmax(self): self.assertTrue(qmax == 127) found += 1 self.assertTrue(found == 2) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d200609c8f8d54..c01e3c318335b3 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -24,8 +24,15 @@ hu.assert_deadline_disabled() from torch.testing._internal.common_cuda import SM80OrLater -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_utils import IS_PPC, IS_MACOS, IS_SANDCASTLE, IS_FBCODE, IS_ARM64 +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + TestCase, + IS_PPC, + IS_MACOS, + IS_SANDCASTLE, + IS_FBCODE, + IS_ARM64 +) from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr @@ -4504,7 +4511,7 @@ def _test_qlinear_pt2e_helper( qlinear_op, post_op="none", unary_post_op_args=(), - post_op_algorithms=("none"), + post_op_algorithms=("none",), ): qlinear_prepack = torch.ops.onednn.qlinear_prepack linear_op = F.linear @@ -4671,6 +4678,184 @@ def test_qlinear_add_relu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "add_relu") + def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None): + quant_max = torch.finfo(torch.float8_e4m3fn).max + eps = torch.Tensor([torch.finfo(torch.float32).eps]) + if channelwise: + scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max + scale = torch.max(scale, eps) + scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1)) + qt = t / scale_reshape + else: + scale = scale or t.abs().max().reshape([1]) / quant_max + scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item()) + qt = t / scale + qt = qt.to(torch.float8_e4m3fn) + return qt, scale + + def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor): + dqt = qt.float() + if scale.numel() == 1: + # per tensor + dqt = dqt * scale + else: + # per channel + scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1)) + dqt = dqt * scale_reshape + return dqt + + def _test_qlinear_fp8_helper( + self, + qlinear_op, + post_op="none", + unary_post_op_args=(), + post_op_algorithms=("none",), + ): + qlinear_prepack = torch.ops.onednn.qlinear_prepack + linear_op = F.linear + in_channels_list = [4, 8] + out_channels_list = [16, 32] + batch_size = 1 + use_bias_list = [True, False] + weight_quant_per_channel_list = [True, False] + output_dtype_list = [None, torch.float32, torch.bfloat16] + y_scale, y_zp = 0.07, 0 + input_dim_list = [2, 3] + cases = itertools.product( + in_channels_list, out_channels_list, use_bias_list, + weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list) + with override_quantized_engine('onednn'): + for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases: + used_y_scale = y_scale + used_y_zp = y_zp + fp32_out = output_dtype == torch.float32 + bfloat16_out = output_dtype == torch.bfloat16 + if fp32_out or bfloat16_out: + used_y_scale = 1.0 + x2_scale, x2_zp = 1.0, 0 + else: + x2_scale, x2_zp = 0.3, 0 + x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10 + w = torch.rand(oc, ic) * 10 + qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False) + qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel) + if use_bias: + b = torch.rand(oc) * 10 + else: + b = None + + # compute reference result + x_ref = self._dequantize_fp8e4m3(qx, x_scale) + w_ref = self._dequantize_fp8e4m3(qw, w_scales) + y_ref = linear_op(x_ref, w_ref, b) + + # compute fp8 linear + qw_packed = qlinear_prepack(qw, x.shape) + x_zp = 0 + w_zps = torch.zeros_like(w_scales, dtype=torch.int) + + if post_op in ("none", "relu", "gelu"): + qy = qlinear_op( + qx, x_scale, x_zp, qw_packed, w_scales, w_zps, + b, used_y_scale, used_y_zp, output_dtype, + post_op, unary_post_op_args, post_op_algo + ) + if post_op == "relu": + y_ref = F.relu(y_ref) + elif post_op == "gelu": + y_ref = F.gelu(y_ref, approximate=post_op_algo) + elif post_op in ("sum", "sum_relu"): + x2 = torch.rand_like(y_ref) + x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False) + x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale) + unary_post_op = "relu" if post_op == "sum_relu" else "none" + binary_alpha = 1.0 # we only support alpha=1.0 now + # if output_dtype is fp32 or bf16, accumulate on x2 + # if output_dtype is None (fp8), accumulate on x2_dq + accum = x2_q if output_dtype is None else x2 + accum_ref = x2_dq if output_dtype is None else x2.clone() + x2_scale = x2_scale if output_dtype is None else 1.0 + if bfloat16_out: + accum = accum.bfloat16() + accum_ref = accum_ref.bfloat16() + qy = qlinear_op( + qx, x_scale, x_zp, qw_packed, w_scales, w_zps, + accum, b, used_y_scale, used_y_zp, output_dtype, + x2_scale, x2_zp, "sum", binary_alpha, + unary_post_op, unary_post_op_args, post_op_algo + ) + y_ref = y_ref + accum_ref * binary_alpha + if unary_post_op == "relu": + y_ref = F.relu(y_ref) + elif post_op in ("add", "add_relu"): + if output_dtype is not None: + # Only support fp8 output + continue + x2 = torch.rand_like(y_ref) + unary_post_op = "relu" if post_op == "add_relu" else "none" + binary_alpha = 1.0 # we only support alpha=1.0 now + qy = qlinear_op( + qx, x_scale, x_zp, qw_packed, w_scales, w_zps, + x2, b, used_y_scale, used_y_zp, output_dtype, + 1.0, 0, "add", binary_alpha, + unary_post_op, unary_post_op_args, post_op_algo + ) + y_ref = y_ref + x2 * binary_alpha + if unary_post_op == "relu": + y_ref = F.relu(y_ref) + + # Compare results + if output_dtype is None: + y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0] + else: + y_ref = y_ref.to(output_dtype) + + self.assertEqual(x.dim(), qy.dim()) + self.assertEqual(y_ref.float(), qy.float()) + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise + self._test_qlinear_fp8_helper(qlinear, "none") + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_relu_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise + self._test_qlinear_fp8_helper(qlinear, "relu") + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_gelu_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise + post_op_algorithms = ['none', 'tanh'] + self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms) + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_sum_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_helper(qlinear, "sum") + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_sum_relu_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_helper(qlinear, "sum_relu") + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_add_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_helper(qlinear, "add") + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qlinear_add_relu_fp8(self): + qlinear = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_helper(qlinear, "add_relu") + @unittest.skipIf(IS_MACOS, "Known test failure on Mac.") class TestQuantizedEmbeddingOps(TestCase): @@ -8265,3 +8450,6 @@ def test_compare_tensor_scalar(self, A, b): note(f"result 3: {result}") self.assertEqual(result_ref, result, msg=f"'tensor.{op}(scalar)'' failed") + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_top_level_apis.py b/test/quantization/core/test_top_level_apis.py index f76db1cd4139b7..86a4a30af7baa6 100644 --- a/test/quantization/core/test_top_level_apis.py +++ b/test/quantization/core/test_top_level_apis.py @@ -91,3 +91,9 @@ def test_reduce_range(self) -> None: fake_quantize_weight = qconfig.weight() self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1]) + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/quantization/core/test_utils.py b/test/quantization/core/test_utils.py index e4a3d3079c4ecd..7b4d415303225e 100644 --- a/test/quantization/core/test_utils.py +++ b/test/quantization/core/test_utils.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: quantization"] import torch -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import raise_on_run_directly, TestCase from torch.ao.quantization.utils import get_fqn_to_example_inputs from torch.ao.nn.quantized.modules.utils import _quantize_weight from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver @@ -220,3 +220,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], ], dtype=torch.uint8)) assert x.dtype == dtype + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index 4b7a6587d86cc5..4cf34ac8c6c843 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -29,7 +29,7 @@ from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() -from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_WITH_ROCM from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo # Reference method for fake quantize @@ -818,6 +818,9 @@ def test_learnable_forward_per_channel_cpu(self, X): @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,), qparams=hu.qparams(dtypes=torch.quint8))) @unittest.skipIf(not TEST_CUDA, "No gpu is not available.") + @unittest.skip( + "this is broken without changes to any relevant code, " + "we need to remove hypothesis testing in CI") def test_learnable_forward_per_channel_cuda(self, X): torch.random.manual_seed(NP_RANDOM_SEED) X, (_, _, axis, _) = X @@ -954,6 +957,9 @@ def test_learnable_backward_per_channel_cpu(self, X): @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,), qparams=hu.qparams(dtypes=torch.quint8))) @unittest.skipIf(not TEST_CUDA, "No gpu is not available.") + @unittest.skip( + "this is broken without changes to any relevant code, " + "we need to remove hypothesis testing in CI") def test_learnable_backward_per_channel_cuda(self, X): torch.random.manual_seed(NP_RANDOM_SEED) X, (scale, zero_point, axis, torch_type) = X @@ -1032,12 +1038,28 @@ def test_fake_quantize_per_channel_affine_scale_dtypes(self): input, scale, zero_point, axis, quant_min, quant_max ) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") + @unittest.skipIf(TEST_WITH_ROCM, "Not a suitable test for ROCM") + @given(dtype=st.sampled_from([torch.float, torch.float64, torch.half, torch.bfloat16]), + device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])) + def test_fake_quantize_per_tensor_affine_inf(self, dtype, device) -> None: + # https://github.com/pytorch/pytorch/issues/154328 + input_tensor = torch.tensor([torch.inf], dtype=dtype).to(device) + scale = 0.01 + zero_point = 0 + quant_min = 0 + quant_max = 255 + result = torch.fake_quantize_per_tensor_affine(input_tensor, scale, zero_point, quant_min, quant_max) + ref_result = (min(quant_max, max(quant_min, torch.round(input_tensor / scale) + zero_point)) - zero_point) * scale + ref_result = torch.Tensor([ref_result]).to(dtype).to(device) + self.assertEqual(result, ref_result) + class TestFusedObsFakeQuant(TestCase): @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), - symmetric_quant=st.booleans()) + symmetric_quant=st.booleans(), use_bool=st.booleans()) @settings(deadline=None) - def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None: + def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant, use_bool) -> None: """ Tests the case where we call the fused_obs_fake_quant op multiple times and update the running_min and max of the activation tensors. @@ -1049,15 +1071,15 @@ def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None: avg_const = 0.01 scale = torch.tensor([1.0], device=device) zero_point = torch.tensor([0], dtype=torch.int, device=device) - observer_on = fake_quant_on = 0 + observer_on = fake_quant_on = False if use_bool else 0 pt_op = torch.fused_moving_avg_obs_fake_quant # enable observer after 2 iterations and fake_quant after 4 iterations for i in range(10): if i > 2: - observer_on = 1 + observer_on = True if use_bool else 1 if i > 4: - fake_quant_on = 1 + fake_quant_on = True if use_bool else 1 x = torch.randn(5, 5, device=device) out = pt_op( @@ -1126,9 +1148,9 @@ def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None: self.assertEqual(out.shape, output_shape) @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), - symmetric_quant=st.booleans()) + symmetric_quant=st.booleans(), use_bool=st.booleans()) @settings(deadline=None) - def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None: + def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant, use_bool) -> None: """ Tests the case where we call the fused_obs_fake_quant op multiple times and update the running_min and max of the activation tensors. @@ -1145,15 +1167,15 @@ def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_qua scale = torch.empty(m, device=device).fill_(0.1) zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0) - observer_on = fake_quant_on = 0 + observer_on = fake_quant_on = False if use_bool else 0 pt_op = torch.fused_moving_avg_obs_fake_quant # enable observer after 2 iterations and fake_quant after 4 iterations for i in range(10): if i > 2: - observer_on = 1 + observer_on = True if use_bool else 1 if i > 4: - fake_quant_on = 1 + fake_quant_on = True if use_bool else 1 x = torch.randn(size, device=device) out = pt_op( diff --git a/test/quantization/eager/test_bias_correction_eager.py b/test/quantization/eager/test_bias_correction_eager.py index a8e80ad763ae7c..5f0c475f934dd7 100644 --- a/test/quantization/eager/test_bias_correction_eager.py +++ b/test/quantization/eager/test_bias_correction_eager.py @@ -18,6 +18,7 @@ QuantizationTestCase, skipIfNoFBGEMM, ) +from torch.testing._internal.common_utils import raise_on_run_directly class TestBiasCorrectionEager(QuantizationTestCase): @@ -119,3 +120,7 @@ def forward(self, x): for _ in range(50) ] self.correct_artificial_bias_quantize(float_model, img_data) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/eager/test_equalize_eager.py b/test/quantization/eager/test_equalize_eager.py index 985a5d67f3df98..d2ea10f334c541 100644 --- a/test/quantization/eager/test_equalize_eager.py +++ b/test/quantization/eager/test_equalize_eager.py @@ -7,6 +7,7 @@ import torch.nn as nn from torch.ao.quantization.fuse_modules import fuse_modules from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import raise_on_run_directly class TestEqualizeEager(QuantizationTestCase): @@ -203,3 +204,7 @@ def forward(self, x): input = torch.randn(20, 3) self.assertEqual(fused_model1(input), fused_model2(input)) self.assertEqual(fused_model1(input), model(input)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/eager/test_numeric_suite_eager.py b/test/quantization/eager/test_numeric_suite_eager.py index a933c7226b73ca..cd11e968599375 100644 --- a/test/quantization/eager/test_numeric_suite_eager.py +++ b/test/quantization/eager/test_numeric_suite_eager.py @@ -38,7 +38,7 @@ test_only_eval_fn, ) from torch.testing._internal.common_quantized import override_qengines -from torch.testing._internal.common_utils import IS_ARM64 +from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly class SubModule(torch.nn.Module): @@ -612,3 +612,7 @@ def test_mobilenet_v3(self): from torchvision.models.quantization import mobilenet_v3_large self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False)) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py index 648afa81b5ae97..4a1e6dbdf5c32e 100644 --- a/test/quantization/fx/test_equalize_fx.py +++ b/test/quantization/fx/test_equalize_fx.py @@ -43,6 +43,7 @@ FunctionalConvReluModel, FunctionalConvReluConvModel, ) +from torch.testing._internal.common_utils import raise_on_run_directly # Standard Libraries import copy @@ -894,3 +895,6 @@ def forward(self, x): # Check the order of nodes in the graph self.checkGraphModuleNodes(equalized_model, expected_node_list=node_list) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 3b0ff7a5ececae..80ab0f1e8618e5 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -30,6 +30,7 @@ skipIfNoQNNPACK, override_quantized_engine, ) +from torch.testing._internal.common_utils import raise_on_run_directly """ @@ -1944,7 +1945,7 @@ def _get_prepped_for_calibration_model_helper(model, detector_set, example_input example_input = example_input.to(torch.float) q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping() - # if they passed in fusion paramter, make sure to test that + # if they passed in fusion parameter, make sure to test that if fused: model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules()) @@ -1956,3 +1957,6 @@ def _get_prepped_for_calibration_model_helper(model, detector_set, example_input prepared_for_callibrate_model = model_report.prepare_detailed_calibration() return (prepared_for_callibrate_model, model_report) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 84c4f84fa355a7..b53b9b0193e077 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -37,7 +37,7 @@ skip_if_no_torchvision, TwoLayerLinearModel ) -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import raise_on_run_directly, skipIfTorchDynamo from torch.ao.quantization.quantization_mappings import ( get_default_static_quant_module_mappings, get_default_dynamic_quant_module_mappings, @@ -2915,3 +2915,6 @@ def test_mobilenet_v2(self): m, (torch.randn(1, 3, 224, 224),), qconfig_dict=qconfig_dict, should_log_inputs=False) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/jit/test_fusion_passes.py b/test/quantization/jit/test_fusion_passes.py index a1a9eceadb53d9..f4580c891e8fb5 100644 --- a/test/quantization/jit/test_fusion_passes.py +++ b/test/quantization/jit/test_fusion_passes.py @@ -4,6 +4,7 @@ import torch from torch.testing import FileCheck from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import raise_on_run_directly class TestFusionPasses(QuantizationTestCase): @@ -104,3 +105,7 @@ def forward(self, x, y: float, z): ).check("quantized::add_scalar_relu_out").run(scripted_m.graph) output = scripted_m(qA, 3.0, qC) self.assertEqual(ref_output, output) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py index f9d43f183a4a1e..ed6434f4943202 100644 --- a/test/quantization/jit/test_ondevice_quantization.py +++ b/test/quantization/jit/test_ondevice_quantization.py @@ -528,3 +528,10 @@ def test_serialization_deserialization(self): def test_device_side_api(self): model = MyConvLinearModule() self._check_device_side_api(model) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 264a11cb863136..c634f8ad39704b 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -71,7 +71,10 @@ qengine_is_fbgemm, qengine_is_qnnpack, ) -from torch.testing._internal.common_utils import set_default_dtype +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + set_default_dtype, +) from torch.testing._internal.jit_utils import ( attrs_with_prefix, get_forward, @@ -3880,3 +3883,7 @@ def test_linear_dynamic_fp16(self): ) # compare result with eager mode self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 4a5cb6edaeb626..b830763d4ba86f 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -26,7 +26,7 @@ ) from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly class TestHelperModules: @@ -307,3 +307,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: example_inputs, BackendAQuantizer(), ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_graph_utils.py b/test/quantization/pt2e/test_graph_utils.py index 43d1d55140fb9f..2a26ff682b93fe 100644 --- a/test/quantization/pt2e/test_graph_utils.py +++ b/test/quantization/pt2e/test_graph_utils.py @@ -9,7 +9,11 @@ get_equivalent_types, update_equivalent_types_dict, ) -from torch.testing._internal.common_utils import IS_WINDOWS, TestCase +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + raise_on_run_directly, + TestCase, +) class TestGraphUtils(TestCase): @@ -121,3 +125,7 @@ def forward(self, x): [torch.nn.Conv2d, torch.nn.ReLU6], ) self.assertEqual(len(fused_partitions), 1) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 96eff3a789f282..fe9e1b295561b1 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -12,7 +12,11 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + raise_on_run_directly, + skipIfCrossRef, +) class TestHelperModules: @@ -513,3 +517,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: BackendAQuantizer(), node_tags, ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index deff8e4987e50d..53c7939411631e 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -21,7 +21,12 @@ ) from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + raise_on_run_directly, + skipIfCrossRef, + TestCase, +) @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") @@ -346,3 +351,7 @@ def test_added_node_gets_unique_id(self) -> None: # may change with future node ordering changes. self.assertNotEqual(handles_after_modification["relu_default"], 0) self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 3f9fde0444a0d7..9ffb63028ff257 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2552,7 +2552,7 @@ def prepare_obs_or_fq_callback( torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ): - # Entire graph share the same qspec which was overriden by FixedQParamsObserver + # Entire graph share the same qspec which was overridden by FixedQParamsObserver self.assertEqual(n.args[1], 0.125) self.assertEqual(n.args[2], 42) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index b8b58ea006c3f0..dee1a840a09d42 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -43,6 +43,7 @@ skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import raise_on_run_directly class PT2EQATTestCase(QuantizationTestCase): @@ -1177,3 +1178,7 @@ def test_mixing_qat_ptq(self): self.checkGraphModuleNodes( exported_model.graph_module, expected_node_occurrence=node_occurrence ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index 3648ac352dc4ab..5c5a7cce505b6d 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -17,6 +17,7 @@ skipIfNoQNNPACK, TestHelperModules, ) +from torch.testing._internal.common_utils import raise_on_run_directly @skipIfNoQNNPACK @@ -306,3 +307,7 @@ def forward(self, x, y): ref_node_occurrence, non_ref_node_occurrence, ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index e0fcbbc9b515d2..1f1020b9bd41c0 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -25,7 +25,10 @@ skipIfNoX86, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + raise_on_run_directly, + skipIfTorchDynamo, +) class NodePosType(Enum): @@ -2858,3 +2861,7 @@ def test_lowering_to_x86(self): node_list, lower=True, ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 4e14dfd27ae292..37bac5c8f51f97 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -38,6 +38,7 @@ TestHelperModules, ) from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import raise_on_run_directly @skipIfNoQNNPACK @@ -1080,3 +1081,7 @@ def test_resnet18(self): self.assertTrue( compute_sqnr(after_quant_result, after_quant_result_fx) > 35 ) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/run_test.py b/test/run_test.py index 23b5cad7e7544a..04e05b7b71d3c7 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -213,6 +213,7 @@ def __contains__(self, item): "test_unary_ufuncs", # these tests fail when cuda is not available "inductor/test_aot_inductor", + "inductor/test_best_config", "inductor/test_cudacodecache", "inductor/test_inductor_utils", "inductor/test_inplacing_pass", @@ -240,6 +241,16 @@ def __contains__(self, item): "test_fx", # some false errors "doctests", + # new failures to investigate and fix + "cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic", + "test_tensorboard", + # onnx + protobuf failure, see + # https://github.com/protocolbuffers/protobuf/issues/22104 + "dynamo/test_backends", + "dynamo/test_modules", + "inductor/test_config", + "test_public_bindings", + "test_testing", ] XPU_BLOCKLIST = [ @@ -422,6 +433,7 @@ def _is_cpp_test(test): "test_decomp", "test_cpp_extensions_jit", "test_jit", + "test_matmul_cuda", "test_ops", "test_ops_jit", "dynamo/test_recompile_ux", @@ -619,6 +631,7 @@ def run_test( stepcurrent_key, output, options.continue_through_error, + test_file, ) else: command.extend([f"--sc={stepcurrent_key}", "--print-items"]) @@ -697,6 +710,7 @@ def run_test_retries( stepcurrent_key, output, continue_through_error, + test_file, ): # Run the test with -x to stop at first failure. Rerun the test by itself. # If it succeeds, move on to the rest of the tests in a new process. If it @@ -772,6 +786,8 @@ def print_to_file(s): print_to_file("Retrying single test...") print_items = [] # do not continue printing them, massive waste of space + if "null" in num_failures: + num_failures[f"'{test_file}'"] = num_failures.pop("null") consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3] flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3] if len(flaky_failures) > 0: @@ -1279,6 +1295,16 @@ def parse_args(): "(including functorch tests)." ), ) + parser.add_argument( + "--einops", + "--einops", + action="store_true", + help=( + "If this flag is present, we will only run einops tests. " + "If this flag is not present, we will run all tests " + "(including einops tests)." + ), + ) parser.add_argument( "--mps", "--mps", @@ -1530,6 +1556,15 @@ def get_selected_tests(options) -> list[str]: filter(lambda test_name: test_name in FUNCTORCH_TESTS, selected_tests) ) + # Filter to only run einops tests when --einops option is specified + if options.einops: + selected_tests = list( + filter( + lambda test_name: test_name.startswith("test/dynamo/test_einops"), + selected_tests, + ) + ) + if options.cpp: selected_tests = list( filter(lambda test_name: test_name in CPP_TESTS, selected_tests) @@ -1792,11 +1827,6 @@ def run_tests( x for x in selected_tests if x not in selected_tests_parallel ] - # See Note [ROCm parallel CI testing] - pool = get_context("spawn").Pool( - NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1 - ) - # NB: This is a hack to make conftest.py and files it depends on available # on CPP_TESTS_DIR. We should see if the file could be turned into a # full-fledge ptest plugin instead @@ -1814,24 +1844,16 @@ def run_tests( ): shutil.copy(os.path.join(test_directory, conftest_file), cpp_file) - def handle_error_messages(failure: Optional[TestFailure]): - if failure is None: + def handle_complete(failure: Optional[TestFailure]): + failed = failure is not None + if IS_CI and options.upload_artifacts_while_running: + zip_and_upload_artifacts(failed) + if not failed: return False failures.append(failure) print_to_stderr(failure.message) return True - def parallel_test_completion_callback(failure): - test_failed = handle_error_messages(failure) - if IS_CI and options.upload_artifacts_while_running: - zip_and_upload_artifacts(test_failed) - if ( - test_failed - and not options.continue_through_error - and not RERUN_DISABLED_TESTS - ): - pool.terminate() - keep_going_message = ( "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n" "If running on CI, add the 'keep-going' label to your PR and rerun your jobs." @@ -1843,7 +1865,7 @@ def parallel_test_completion_callback(failure): if can_run_in_pytest(test): options_clone.pytest = True failure = run_test_module(test, test_directory, options_clone) - test_failed = handle_error_messages(failure) + test_failed = handle_complete(failure) if ( test_failed and not options.continue_through_error @@ -1858,7 +1880,7 @@ def parallel_test_completion_callback(failure): options_clone.pytest = True options_clone.additional_args.extend(["-m", "serial"]) failure = run_test_module(test, test_directory, options_clone) - test_failed = handle_error_messages(failure) + test_failed = handle_complete(failure) if ( test_failed and not options.continue_through_error @@ -1866,7 +1888,24 @@ def parallel_test_completion_callback(failure): ): raise RuntimeError(failure.message + keep_going_message) - os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) + # This is used later to constrain memory per proc on the GPU. On ROCm + # the number of procs is the number of GPUs, so we don't need to do this + os.environ["NUM_PARALLEL_PROCS"] = str(1 if torch.version.hip else NUM_PROCS) + + # See Note [ROCm parallel CI testing] + pool = get_context("spawn").Pool( + NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1 + ) + + def parallel_test_completion_callback(failure): + test_failed = handle_complete(failure) + if ( + test_failed + and not options.continue_through_error + and not RERUN_DISABLED_TESTS + ): + pool.terminate() + for test in selected_tests_parallel: options_clone = copy.deepcopy(options) if can_run_in_pytest(test): diff --git a/test/slow_tests.json b/test/slow_tests.json index 80788d45b10832..65d9f1c5b7fc7a 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,241 +1,259 @@ { - "EndToEndLSTM (__main__.RNNTest)": 192.89666239420572, - "MultiheadAttention (__main__.ModulesTest)": 136.05533345540366, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 86.2237777709961, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 63.802555084228516, - "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.73824987411498, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 90.2943344116211, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 189.56100463867188, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 104.09633127848308, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 97.2173360188802, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 774.8972473144531, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 130.95370025634764, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 497.9846666124132, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 499.8869934082031, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 127.44683583577473, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 78.60400136311848, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 87.86199951171875, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 273.40610758463544, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 151.22699991861978, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 338.79077487521704, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 394.75244479709204, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 310.34222242567273, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 88.13049952189128, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 87.3980000813802, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 61.15833346048991, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 60.6113338470459, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 87.60466766357422, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 82.47533416748047, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 62.53499857584635, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 381.16466267903644, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 60.08166631062826, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 382.3089904785156, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 61.6903330485026, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 250.1481679280599, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 290.079340616862, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1157.0956624348958, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.68783378601074, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 984.5683288574219, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.99383290608723, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.60000101725261, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 62.93949953715006, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.45450019836426, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 62.865333557128906, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 61.61699888441298, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 78.07533264160156, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 80.1913350423177, - "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 83.79266866048177, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 170.5373331705729, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 164.96800231933594, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 151.26199340820312, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 159.55667114257812, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 190.70066833496094, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 184.95733133951822, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 190.6016642252604, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 107.89249928792317, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 112.23483403523763, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 115.43733469645183, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 96.89950052897136, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 89.77849960327148, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 97.53133392333984, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 98.43199920654297, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 214.80700174967447, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 274.826665242513, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 75.38883463541667, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 78.37283325195312, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 142.29466756184897, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 134.29466756184897, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 137.10233052571616, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1140.7083333333333, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1023.3370157877604, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1029.234354654948, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 925.2996622721354, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 869.8800048828125, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1064.2745056152344, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1091.0223286946614, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1137.6966959635417, - "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 843.7843424479166, - "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 678.9483439127604, - "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 682.2549845377604, - "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 655.1159871419271, - "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 688.5863444010416, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 452.19276064918154, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 494.6151631673177, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 106.57633209228516, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 102.02733612060547, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 104.28433481852214, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 60.384334564208984, - "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 63.31666819254557, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 106.76466623942058, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 100.48400115966797, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 113.65899912516277, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 122.09866587320964, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 112.16566721598308, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.88233311971028, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 62.29449907938639, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 66.21600023905437, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 70.06133270263672, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 425.8736686706543, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 122.4875005086263, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 92.10933494567871, - "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 68.79866600036621, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 335.8283437093099, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 499.17123074001734, - "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 70.24400075276692, - "test_conv3d_binary_broadcast_shapes_cpu_cpu (__main__.TestPatternMatcherGenericCPU)": 63.91080042521159, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.03383127848308, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 75.30299886067708, - "test_count_nonzero_all (__main__.TestBool)": 613.8112182617188, - "test_create_rand_mask_from_inputs_dynamic_shapes (__main__.DynamicShapesReproTests)": 108.1828633221713, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 683.9032355414497, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 377.1489969889323, - "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 60.342000325520836, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.97183227539062, - "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 313.06516313552856, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 67.29399998982747, - "test_fail_random.py (__main__.TestTyping)": 77.66903102397919, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 70.51049995422363, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 87.53116607666016, - "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 79.7923355102539, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 502.4463297526042, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 354.87266540527344, - "test_fuse_large_params_cpu (__main__.CpuTests)": 77.44819946289063, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 173.53422376844617, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 172.43211195203992, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 91.95266723632812, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 87.97433344523112, - "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 66.18542540073395, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 88.75333404541016, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 129.9683380126953, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 115.12666829427083, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 225.0261662801107, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 134.22716649373373, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 153.7135009765625, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 589.8596700032552, - "test_group_norm (__main__.TestQuantizedOps)": 269.3797738817003, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 252.2490030924479, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 68.43844350179036, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 130.6469980875651, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 127.96811082628038, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 117.09833272298177, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 612.490000406901, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 70.4946657816569, - "test_linear (__main__.TestStaticQuantizedModule)": 236.15133497450086, - "test_linear_relu (__main__.TestStaticQuantizedModule)": 83.33777703179254, - "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 652.8833414713541, - "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 543.8513387044271, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 66.57833099365234, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 110.17455546061198, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 61.217555575900604, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.5684445699056, - "test_out_variant_custom_op_dynamic_shapes (__main__.DynamicShapesMiscTests)": 81.44568417289041, - "test_proper_exit (__main__.TestDataLoader)": 232.6415023803711, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 217.0334955851237, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 61.368499755859375, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 151.07800123426648, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 121.81777699788411, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 70.08533477783203, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 83.84000142415364, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 78.2229995727539, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 68.46966552734375, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 84.02066802978516, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 81.13233184814453, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 64.79833348592122, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 79.65733591715495, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 79.89266713460286, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 67.25400034586589, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 88.87366739908855, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.77400207519531, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 67.29300181070964, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 81.40399932861328, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 81.21766662597656, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 69.02466583251953, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.9990005493164, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.08866882324219, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 66.32600021362305, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 83.8606669108073, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 79.64900207519531, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 72.0116678873698, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 88.13433329264323, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 83.44666544596355, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 437.65899658203125, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 776.3168334960938, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 651.9120076497396, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1365.6099853515625, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 77.25249989827473, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 67.37566757202148, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 251.60799916585287, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 90.72666676839192, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 169.8308308919271, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 67.62333424886067, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 116.40383275349934, - "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 64.57591004805131, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 76.36966705322266, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 142.90583165486655, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 88.87016677856445, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 111.07800165812175, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.357666015625, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 60.83366584777832, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 72.8857773674859, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.97100067138672, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 115.08066813151042, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 129.90166558159723, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 114.59166463216145, - "test_sum_all_cpu_float64 (__main__.TestReductionsCPU)": 227.3716271975136, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 158.47800040245056, - "test_terminate_handler_on_crash (__main__.TestTorch)": 100.17410944567786, - "test_terminate_signal (__main__.ForkTest)": 132.80577541804976, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.27144405080213, - "test_terminate_signal (__main__.SpawnTest)": 135.38566891352335, - "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 83.95989015367296, - "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 118.00528522602482, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 86.77699788411458, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 90.67566808064778, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 90.42966715494792, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 123.54350090026855, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 120.91033109029134, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 105.34733390808105, - "test_unary_ops (__main__.TestTEFuserDynamic)": 179.04810841878256, - "test_unary_ops (__main__.TestTEFuserStatic)": 166.6141096750895, - "test_unwaited (__main__.CommTest)": 60.188666025797524, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 79.1989974975586, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 76.45166651407878, - "test_vmapjvpvjp_diff_cuda_float32 (__main__.TestOperatorsCUDA)": 60.167444441053604, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.72733561197917, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 347.67066701253253, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 68.70466550191243, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.99283345540364, - "test_vmapjvpvjp_linalg_pinv_singular_cuda_float32 (__main__.TestOperatorsCUDA)": 66.16983350118001, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 72.80399894714355, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 83.73800150553386, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 69.54799906412761, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 65.50566673278809, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 91.0239995320638, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 93.5385004679362, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.33033307393392, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 61.884665171305336, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.04450098673503, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 73.37666575113933, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 75.89116541544597, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 137.67250188191733 + "EndToEndLSTM (__main__.RNNTest)": 204.19766235351562, + "MultiheadAttention (__main__.ModulesTest)": 144.12199910481772, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 89.76433372497559, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 60.78266716003418, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 71.0364990234375, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.0049991607666, + "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 93.56200218200684, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.12249755859375, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 202.69849395751953, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 178.81350326538086, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.96700286865234, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 153.12700271606445, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1025.2469787597656, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 103.65142822265625, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 489.04433186848956, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 487.5743357340495, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 123.70524787902832, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 73.78200149536133, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.94600009918213, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 184.28466669718424, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 129.23699951171875, + "test_cat_2k_args (__main__.TestTEFuserDynamic)": 105.84207906299515, + "test_cat_2k_args (__main__.TestTEFuserStatic)": 118.78279071262008, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 385.1773325602214, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 415.5623270670573, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 302.11150614420575, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 89.50349998474121, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 88.17774963378906, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 63.955399703979495, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 62.5629997253418, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 86.9015007019043, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 91.01150131225586, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 65.91899871826172, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 404.79449462890625, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 64.88150024414062, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 390.1374969482422, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 65.5984992980957, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 242.27249908447266, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 288.8112487792969, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1021.2769927978516, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.61349868774414, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1078.281997680664, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.2790002822876, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.57924842834473, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.62350082397461, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 70.42674827575684, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.98200035095215, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.36449909210205, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 61.57929992675781, + "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 83.75249862670898, + "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 82.4640007019043, + "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 97.88249969482422, + "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 181.32449340820312, + "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 171.81600189208984, + "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 169.01850128173828, + "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 162.26849365234375, + "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 213.85850524902344, + "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 206.47949981689453, + "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 202.62000274658203, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 116.22050094604492, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 116.08074951171875, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 117.2509994506836, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 128.12999725341797, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 126.67150115966797, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 118.62150192260742, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 91.70499992370605, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 96.74850082397461, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 99.50400161743164, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 117.61600112915039, + "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 61.68622292412652, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 285.966251373291, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 286.9002494812012, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 75.1487263766202, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.88652204430622, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 72.14090042114258, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 75.99790482293992, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 90.85624885559082, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 94.66775131225586, + "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 169.1240005493164, + "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 162.36299896240234, + "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 169.9939956665039, + "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1291.2069702148438, + "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1166.8740234375, + "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1116.0714721679688, + "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 994.1804809570312, + "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 982.9049987792969, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1249.2317504882812, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1255.1132507324219, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.1112365722656, + "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 929.2744750976562, + "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 740.7665100097656, + "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 753.1840209960938, + "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 709.1789855957031, + "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 723.7825012207031, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 505.14124298095703, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 504.5137481689453, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 124.45050048828125, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 129.04349899291992, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 124.9415054321289, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 61.10431827198375, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.84139135609502, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.411044245180875, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 74.14450073242188, + "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 70.26900100708008, + "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 113.28900146484375, + "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 113.8120002746582, + "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 118.50249862670898, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 110.55724716186523, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 109.60800170898438, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 62.25924873352051, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.93900108337402, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 64.4350004196167, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 67.63899898529053, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 64.41949939727783, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 88.83725166320801, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 94.19975090026855, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 70.39225006103516, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 110.43149948120117, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 267.1796696980794, + "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 83.31399917602539, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 73.52913335164388, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 87.52300071716309, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 70.49650128682454, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59375190734863, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.47550010681152, + "test_count_nonzero_all (__main__.TestBool)": 671.6698404947916, + "test_create_rand_mask_from_inputs_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.40256757321565, + "test_cusparse_multiple_threads_same_device (__main__.TestCuda)": 61.286570753370015, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 322.2589975992839, + "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 63.88200124104818, + "test_diff_hyperparams_sharding_strategy_str_no_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 62.30999883015951, + "test_diff_hyperparams_sharding_strategy_str_shard_grad_op (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 60.74166615804037, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 94.44025039672852, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 62.17200152079264, + "test_fail_random.py (__main__.TestTyping)": 69.17674193843719, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 89.89649963378906, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 89.71349906921387, + "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 211.4114990234375, + "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 132.2224998474121, + "test_forward_ad_svd_lowrank_cpu_float32 (__main__.TestCompositeComplianceCPU)": 97.98550033569336, + "test_fuse_large_params_cpu (__main__.CpuTests)": 85.25499979654948, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 153.6696662902832, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 152.69783401489258, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 125.90250015258789, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 92.34375190734863, + "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 64.21676149822417, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 123.23225212097168, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 128.2329978942871, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 125.50249862670898, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 269.924503326416, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 125.14425086975098, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 176.18125534057617, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 571.6790008544922, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 271.3869934082031, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 70.60233497619629, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 94.59475135803223, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 156.87916564941406, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.82549667358398, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 122.43350219726562, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.80849838256836, + "test_linear (__main__.TestStaticQuantizedModule)": 91.20350011189778, + "test_lobpcg_basic_cuda_float64 (__main__.TestLinalgCUDA)": 80.49750232696533, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.5243326822916, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 516.3270060221354, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 62.10499954223633, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 132.2573331197103, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 60.5918337504069, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.55583381652832, + "test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 76.3048454110439, + "test_out_variant_custom_op_dynamic_shapes (__main__.DynamicShapesMiscTests)": 79.89417275138523, + "test_proper_exit (__main__.TestDataLoader)": 246.03874969482422, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 238.50450134277344, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 60.665499210357666, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 144.47583134969076, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 64.56902594315379, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 62.75823718623111, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 78.22150039672852, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 91.45999908447266, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.36700057983398, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 75.94200134277344, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 90.85449981689453, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.81800079345703, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.60699844360352, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 86.91350173950195, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 86.44449996948242, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 75.63949966430664, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.4010009765625, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.33300018310547, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.61149978637695, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 87.62849807739258, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.85200119018555, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 80.0374984741211, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 91.739501953125, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.72999954223633, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.18899917602539, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 86.14900207519531, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 86.40299987792969, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 91.67950057983398, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.89799880981445, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 426.1199951171875, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 788.5797576904297, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 645.885986328125, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1150.8487548828125, + "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 63.335999488830566, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 83.89899826049805, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 219.58025360107422, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 94.60549926757812, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 175.31949615478516, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 80.4379997253418, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 117.42150115966797, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 85.20699691772461, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 140.9692497253418, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 96.70874977111816, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 98.57174968719482, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.52949905395508, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 67.56499767303467, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 74.22599983215332, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 155.21199798583984, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 107.1515007019043, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 147.1798324584961, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 113.45633443196614, + "test_sum_all_cpu_float64 (__main__.TestReductionsCPU)": 247.22645892538694, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 157.11124801635742, + "test_terminate_handler_on_crash (__main__.TestTorch)": 112.97849977016449, + "test_terminate_signal (__main__.ForkTest)": 138.1518301591277, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 138.05183368052045, + "test_terminate_signal (__main__.SpawnTest)": 141.89416662851968, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 174.1542510986328, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 90.13175010681152, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 76.49149990081787, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 94.42874908447266, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 121.6265001296997, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 125.12349891662598, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 107.74774932861328, + "test_unary_ops (__main__.TestTEFuserDynamic)": 165.78299776713052, + "test_unary_ops (__main__.TestTEFuserStatic)": 147.64583269755045, + "test_unwaited (__main__.CommTest)": 60.16999944051107, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 83.02149963378906, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 78.58975028991699, + "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.85185841151646, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 90.26150131225586, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 147.89199574788412, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 61.03449821472168, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 70.4332504272461, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 75.01950168609619, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 75.30324840545654, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 71.50250053405762, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 75.63800144195557, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 62.03466642470587, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 168.41749572753906, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 68.871750831604, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 77.00125026702881, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 94.14950180053711, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 90.05274963378906, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 110.32675170898438, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 92.06575202941895, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 158.6577491760254 } \ No newline at end of file diff --git a/test/test_ao_sparsity.py b/test/test_ao_sparsity.py index db6b2222a3bee7..5ae5a0874318e5 100644 --- a/test/test_ao_sparsity.py +++ b/test/test_ao_sparsity.py @@ -1,4 +1,5 @@ # Owner(s): ["module: unknown"] +import logging # Kernels from ao.sparsity.test_kernels import ( # noqa: F401 @@ -56,4 +57,9 @@ if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.INFO, + ) + run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index e7dbc99ce25d4e..6c9241192fc0b6 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -3028,8 +3028,8 @@ def check_index(x, y, idx): check_index(x, y, ([1, 2, 3], [0])) check_index(x, y, ([1, 2], [2, 1])) check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]])) - check_index(x, y, ([slice(None), [2, 3]])) - check_index(x, y, ([[2, 3], slice(None)])) + check_index(x, y, ((slice(None), [2, 3]))) + check_index(x, y, (([2, 3], slice(None)))) # advanced indexing, with less dim, or ellipsis check_index(x, y, ([0])) @@ -3061,8 +3061,8 @@ def check_index(x, y, idx): # advanced indexing, with a tensor wrapped in a variable z = torch.LongTensor([0, 1]) zv = Variable(z, requires_grad=False) - seq = [z, Ellipsis] - seqv = [zv, Ellipsis] + seq = (z, Ellipsis) + seqv = (zv, Ellipsis) if y.grad is not None: with torch.no_grad(): @@ -3086,7 +3086,7 @@ def test_indexing_duplicates(self): x = torch.arange(1.0, 17).view(4, 4) y = Variable(x, requires_grad=True) - idx = [[1, 1, 3, 2, 1, 2], [0]] + idx = ([1, 1, 3, 2, 1, 2], [0]) y[idx].sum().backward() expected_grad = torch.zeros(4, 4) for i in idx[0]: @@ -3097,7 +3097,7 @@ def test_indexing_duplicates(self): x = torch.arange(1.0, 17).view(4, 4) y = Variable(x, requires_grad=True) - idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]] + idx = ([[1, 2], [0, 0]], [[0, 1], [1, 1]]) y[idx].sum().backward() expected_grad = torch.tensor( [ @@ -3112,7 +3112,7 @@ def test_indexing_duplicates(self): x = torch.arange(1.0, 65).view(4, 4, 4) y = Variable(x, requires_grad=True) - idx = [[1, 1, 1], slice(None), slice(None)] + idx = ([1, 1, 1], slice(None), slice(None)) y[idx].sum().backward() expected_grad = torch.empty(4, 4, 4).zero_() expected_grad[1].fill_(3) @@ -3541,32 +3541,32 @@ def test_setitem(self): self._test_setitem((5, 5), 1) self._test_setitem((5,), 1) self._test_setitem((1,), 0) - self._test_setitem((10,), [[0, 4, 2]]) - self._test_setitem((5, 5), [[0, 4], [2, 2]]) - self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]]) - self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)]) - self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)]) - self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]]) - self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)]) + self._test_setitem((10,), ([0, 4, 2])) + self._test_setitem((5, 5), ([0, 4], [2, 2])) + self._test_setitem((5, 5, 5), (slice(None), slice(None), [1, 3])) + self._test_setitem((5, 5, 5), (slice(None), [1, 3], slice(None))) + self._test_setitem((5, 5, 5), ([1, 3], slice(None), slice(None))) + self._test_setitem((5, 5, 5), (slice(None), [2, 4], [1, 3])) + self._test_setitem((5, 5, 5), ([1, 3], [2, 4], slice(None))) self._test_setitem_tensor((5, 5), 3) - self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]]) + self._test_setitem_tensor((5, 5), ([0, 1], [1, 0])) self._test_setitem_tensor((5,), 3) self._test_setitem_tensor( (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum() ) self._test_setitem_tensor((5,), [[0, 1, 2, 3]]) - self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]]) - self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)]) - self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)]) - self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]]) - self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)]) + self._test_setitem_tensor((5, 5, 5), (slice(None), slice(None), [1, 3])) + self._test_setitem_tensor((5, 5, 5), (slice(None), [1, 3], slice(None))) + self._test_setitem_tensor((5, 5, 5), ([1, 3], slice(None), slice(None))) + self._test_setitem_tensor((5, 5, 5), (slice(None), [2, 4], [1, 3])) + self._test_setitem_tensor((5, 5, 5), ([1, 3], [2, 4], slice(None))) self._test_setitem_tensor( (5, 5, 5), - [ + ( Variable(torch.LongTensor([1, 3]), requires_grad=False), [2, 4], slice(None), - ], + ), ) def test_setitem_mask(self): @@ -3725,6 +3725,18 @@ def backward(ctx, grad_x): f.next_functions with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"): f.name() + with self.assertRaisesRegex( + RuntimeError, "Attribute '_sequence_nr' is invalid" + ): + f._sequence_nr() + with self.assertRaisesRegex( + RuntimeError, "Attribute '_set_sequence_nr' is invalid" + ): + f._set_sequence_nr(2) + with self.assertRaisesRegex( + RuntimeError, "Attribute '_input_metadata' is invalid" + ): + f._input_metadata with self.assertRaisesRegex( RuntimeError, "underlying PyNode has already been deallocated" ): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index bdc0d7329df594..05226def3b4310 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1688,12 +1688,11 @@ def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): @onlyCUDA @dtypes(torch.complex64, torch.complex128) - def test_pow_cuda_complex_extremal_failing(self, device, dtype): + def test_pow_cuda_complex_extremal_passing(self, device, dtype): t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device) - with self.assertRaises(AssertionError): - cuda_out = t.pow(2) - cpu_out = t.cpu().pow(2) - self.assertEqual(cpu_out, cuda_out) + cuda_out = t.pow(2) + cpu_out = t.cpu().pow(2) + self.assertEqual(cpu_out, cuda_out) @skipIfTorchDynamo() @onlyNativeDeviceTypes diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index 1919e1cd4fe34a..c8b6a6140025e8 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -92,3 +92,10 @@ def forward(self, arg): im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias) self.assertEqual(raw_data.shape, im2_tensor.shape) self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01) + + +if __name__ == "__main__": + raise RuntimeError( + "This test is not currently used and should be " + "enabled in discover_tests.py if required." + ) diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index a8a93f14996b5b..8bca2264d00277 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -317,13 +317,13 @@ def test_conv_backend_override(self): weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True) bias = torch.empty(6, device="maia") - # Make sure forward is overriden + # Make sure forward is overridden out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1) self.assertEqual(maia_extension.get_test_int(), 2) self.assertEqual(out.shape[0], input.shape[0]) self.assertEqual(out.shape[1], weight.shape[0]) - # Make sure backward is overriden + # Make sure backward is overridden # Double backward is dispatched to _convolution_double_backward. # It is not tested here as it involves more computation/overrides. grad = torch.autograd.grad(out, input, out, create_graph=True) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index dc7269c865f9cd..c7e104963fa663 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -21,6 +21,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN from torch.testing._internal.common_utils import gradcheck, TEST_XPU from torch.utils.cpp_extension import ( + _get_cuda_arch_flags, _TORCH_PATH, check_compiler_is_gcc, CUDA_HOME, @@ -118,47 +119,81 @@ def test_jit_cuda_extension(self): # 2 * sigmoid(0) = 2 * 0.5 = 1 self.assertEqual(z, torch.ones_like(z)) - def _test_jit_xpu_extension(self): - name = "torch_test_xpu_extension_" - # randomizing name for the case when we test building few extensions - # in a row using this function - name += "".join(random.sample(string.ascii_letters, 5)) - module = torch.utils.cpp_extension.load( - name=name, - sources=[ - "cpp_extensions/xpu_extension.sycl", - ], - verbose=True, - keep_intermediates=False, - ) + def _test_jit_xpu_extension(self, extra_sycl_cflags): + # randomizing extension name and names of extension methods + # for the case when we test building few extensions in a row + # using this function + rand = "".join(random.sample(string.ascii_letters, 5)) + name = f"torch_test_xpu_extension_{rand}" + temp_dir = tempfile.mkdtemp() + try: + with open("cpp_extensions/xpu_extension.sycl") as f: + text = f.read() + for fn in ["sigmoid_add", "SigmoidAddKernel"]: + text = text.replace(fn, f"{fn}_{rand}") + + sycl_file = f"{temp_dir}/xpu_extension.sycl" + with open(sycl_file, "w") as f: + f.write(text) + + module = torch.utils.cpp_extension.load( + name=name, + sources=[sycl_file], + extra_sycl_cflags=extra_sycl_cflags, + verbose=True, + keep_intermediates=True, + build_directory=temp_dir, + ) - x = torch.zeros(100, device="xpu", dtype=torch.float32) - y = torch.zeros(100, device="xpu", dtype=torch.float32) + x = torch.zeros(100, device="xpu", dtype=torch.float32) + y = torch.zeros(100, device="xpu", dtype=torch.float32) - z = module.sigmoid_add(x, y).cpu() + method = f"sigmoid_add_{rand}" + self.assertTrue(hasattr(module, method)) + z = getattr(module, method)(x, y).cpu() - # 2 * sigmoid(0) = 2 * 0.5 = 1 - self.assertEqual(z, torch.ones_like(z)) + # 2 * sigmoid(0) = 2 * 0.5 = 1 + self.assertEqual(z, torch.ones_like(z)) + finally: + shutil.rmtree(temp_dir) @unittest.skipIf(not (TEST_XPU), "XPU not found") def test_jit_xpu_extension(self): # NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST - self._test_jit_xpu_extension() + self._test_jit_xpu_extension(extra_sycl_cflags=[]) @unittest.skipIf(not (TEST_XPU), "XPU not found") def test_jit_xpu_archlists(self): # NOTE: in this test we explicitly test few different options # for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the # environment before the test won't affect it. - archlists = [ - "", # expecting JIT compilation - ",".join(torch.xpu.get_arch_list()), + cases = [ + { + # Testing JIT compilation + "archlist": "", + "extra_sycl_cflags": [], + }, + { + # Testing JIT + AOT (full torch AOT arch list) + # NOTE: default cpp extension AOT arch list might be reduced + # from the full list + "archlist": ",".join(torch.xpu.get_arch_list()), + "extra_sycl_cflags": [], + }, + { + # Testing AOT (full torch AOT arch list) + # NOTE: default cpp extension AOT arch list might be reduced + # from the full list + "archlist": ",".join(torch.xpu.get_arch_list()), + # below excludes spir64 target responsible for JIT + "extra_sycl_cflags": ["-fsycl-targets=spir64_gen"], + }, ] old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None) try: - for al in archlists: - os.environ["TORCH_XPU_ARCH_LIST"] = al - self._test_jit_xpu_extension() + for c in cases: + os.environ["TORCH_XPU_ARCH_LIST"] = c["archlist"] + self._test_jit_xpu_extension(extra_sycl_cflags=c["extra_sycl_cflags"]) finally: if old_envvar is None: os.environ.pop("TORCH_XPU_ARCH_LIST") @@ -313,6 +348,35 @@ def test_jit_cuda_archflags(self): # to avoid errors from here leaking into other tests pass + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + def test_cuda_arch_flags_non_default_gencode(self): + user_arch_flags = ["-gencode=arch=compute_86,code=sm_86"] + result = _get_cuda_arch_flags(user_arch_flags) + + self.assertEqual( + len(result), + 0, + f"User arch flags should prevent default generation. " + f"Expected: [], Got: {result}", + ) + + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + def test_cuda_arch_flags_default_gencode(self): + default_flags = _get_cuda_arch_flags() + self.assertGreater( + len(default_flags), 0, "No args should generate default flags" + ) + + non_arch_flags = _get_cuda_arch_flags(["-O2", "--use-fast-math"]) + self.assertGreater( + len(non_arch_flags), 0, "Non-arch flags should still generate defaults" + ) + + empty_flags = _get_cuda_arch_flags([]) + self.assertGreater( + len(empty_flags), 0, "Empty list should generate default flags" + ) + @unittest.skipIf(not TEST_CUDNN, "CuDNN not found") @unittest.skipIf(TEST_ROCM, "Not supported on ROCm") def test_jit_cudnn_extension(self): diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 7779a006bda88a..dc44f66bcebc9e 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -1,45 +1,18 @@ # Owner(s): ["module: cpp-extensions"] -import _codecs -import io import os -import sys -import tempfile import unittest -from unittest.mock import patch -import numpy as np import pytorch_openreg # noqa: F401 import torch import torch.testing._internal.common_utils as common import torch.utils.cpp_extension -from torch.serialization import safe_globals -from torch.testing._internal.common_utils import ( - IS_ARM64, - skipIfTorchDynamo, - TemporaryFileName, - TEST_CUDA, - TEST_XPU, -) -from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME -TEST_CUDA = TEST_CUDA and CUDA_HOME is not None -TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None - - -def generate_faked_module(): - class _OpenRegMod: - pass - - return _OpenRegMod() - - -@unittest.skipIf(IS_ARM64, "Does not work on arm") -@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently") -@torch.testing._internal.common_utils.markDynamoStrictTest -class TestCppExtensionOpenRgistration(common.TestCase): +@unittest.skipIf(common.TEST_XPU, "XPU does not support cppextension currently") +@common.markDynamoStrictTest +class TestCppExtensionOpenRegistration(common.TestCase): """Tests Open Device Registration with C++ extensions.""" module = None @@ -62,7 +35,7 @@ def tearDown(self): @classmethod def setUpClass(cls): - torch.testing._internal.common_utils.remove_cpp_extensions_build_root() + common.remove_cpp_extensions_build_root() cls.module = torch.utils.cpp_extension.load( name="custom_device_extension", @@ -74,373 +47,6 @@ def setUpClass(cls): verbose=True, ) - torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) - - def test_base_device_registration(self): - self.assertFalse(self.module.custom_add_called()) - # create a tensor using our custom device object - device = self.module.custom_device() - x = torch.empty(4, 4, device=device) - y = torch.empty(4, 4, device=device) - # Check that our device is correct. - self.assertTrue(x.device == device) - self.assertFalse(x.is_cpu) - self.assertFalse(self.module.custom_add_called()) - # calls out custom add kernel, registered to the dispatcher - z = x + y - # check that it was called - self.assertTrue(self.module.custom_add_called()) - z_cpu = z.to(device="cpu") - # Check that our cross-device copy correctly copied the data to cpu - self.assertTrue(z_cpu.is_cpu) - self.assertFalse(z.is_cpu) - self.assertTrue(z.device == device) - self.assertEqual(z, z_cpu) - - def test_common_registration(self): - # check unsupported device and duplicated registration - with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"): - torch._register_device_module("dev", generate_faked_module()) - with self.assertRaisesRegex(RuntimeError, "The runtime module of"): - torch._register_device_module("openreg", generate_faked_module()) - - # backend name can be renamed to the same name multiple times - torch.utils.rename_privateuse1_backend("openreg") - - # backend name can't be renamed multiple times to different names. - with self.assertRaisesRegex( - RuntimeError, "torch.register_privateuse1_backend()" - ): - torch.utils.rename_privateuse1_backend("dev") - - # generator tensor and module can be registered only once - with self.assertRaisesRegex(RuntimeError, "The custom device module of"): - torch.utils.generate_methods_for_privateuse1_backend() - - # check whether torch.openreg have been registered correctly - self.assertTrue( - torch.utils.backend_registration._get_custom_mod_func("device_count")() == 2 - ) - with self.assertRaisesRegex(RuntimeError, "Try to call torch.openreg"): - torch.utils.backend_registration._get_custom_mod_func("func_name_") - - # check attributes after registered - self.assertTrue(hasattr(torch.Tensor, "is_openreg")) - self.assertTrue(hasattr(torch.Tensor, "openreg")) - self.assertTrue(hasattr(torch.TypedStorage, "is_openreg")) - self.assertTrue(hasattr(torch.TypedStorage, "openreg")) - self.assertTrue(hasattr(torch.UntypedStorage, "is_openreg")) - self.assertTrue(hasattr(torch.UntypedStorage, "openreg")) - self.assertTrue(hasattr(torch.nn.Module, "openreg")) - self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_openreg")) - self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "openreg")) - - def test_open_device_generator_registration_and_hooks(self): - device = self.module.custom_device() - # None of our CPU operations should call the custom add function. - self.assertFalse(self.module.custom_add_called()) - - gen = torch.Generator(device=device) - self.assertTrue(gen.device == device) - - default_gen = self.module.default_generator(0) - self.assertTrue( - default_gen.device.type == torch._C._get_privateuse1_backend_name() - ) - - def test_open_device_dispatchstub(self): - # test kernels could be reused by privateuse1 backend through dispatchstub - input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu") - openreg_input_data = input_data.to("openreg") - output_data = torch.abs(input_data) - openreg_output_data = torch.abs(openreg_input_data) - self.assertEqual(output_data, openreg_output_data.cpu()) - - output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu") - # output operand will resize flag is True in TensorIterator. - openreg_input_data = input_data.to("openreg") - openreg_output_data = output_data.to("openreg") - # output operand will resize flag is False in TensorIterator. - torch.abs(input_data, out=output_data[:, :, 0:6:2]) - torch.abs(openreg_input_data, out=openreg_output_data[:, :, 0:6:2]) - self.assertEqual(output_data, openreg_output_data.cpu()) - - # output operand will resize flag is True in TensorIterator. - # and convert output to contiguous tensor in TensorIterator. - output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu") - openreg_input_data = input_data.to("openreg") - openreg_output_data = output_data.to("openreg") - torch.abs(input_data, out=output_data[:, :, 0:6:3]) - torch.abs(openreg_input_data, out=openreg_output_data[:, :, 0:6:3]) - self.assertEqual(output_data, openreg_output_data.cpu()) - - def test_open_device_quantized(self): - input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to( - "openreg" - ) - quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8) - self.assertEqual(quantized_tensor.device, torch.device("openreg:0")) - self.assertEqual(quantized_tensor.dtype, torch.qint8) - - def test_open_device_random(self): - # check if torch.openreg have implemented get_rng_state - with torch.random.fork_rng(device_type="openreg"): - pass - - def test_open_device_tensor(self): - device = self.module.custom_device() - - # check whether print tensor.type() meets the expectation - dtypes = { - torch.bool: "torch.openreg.BoolTensor", - torch.double: "torch.openreg.DoubleTensor", - torch.float32: "torch.openreg.FloatTensor", - torch.half: "torch.openreg.HalfTensor", - torch.int32: "torch.openreg.IntTensor", - torch.int64: "torch.openreg.LongTensor", - torch.int8: "torch.openreg.CharTensor", - torch.short: "torch.openreg.ShortTensor", - torch.uint8: "torch.openreg.ByteTensor", - } - for tt, dt in dtypes.items(): - test_tensor = torch.empty(4, 4, dtype=tt, device=device) - self.assertTrue(test_tensor.type() == dt) - - # check whether the attributes and methods of the corresponding custom backend are generated correctly - x = torch.empty(4, 4) - self.assertFalse(x.is_openreg) - - x = x.openreg(torch.device("openreg")) - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(x.is_openreg) - - # test different device type input - y = torch.empty(4, 4) - self.assertFalse(y.is_openreg) - - y = y.openreg(torch.device("openreg:0")) - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(y.is_openreg) - - # test different device type input - z = torch.empty(4, 4) - self.assertFalse(z.is_openreg) - - z = z.openreg(0) - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(z.is_openreg) - - def test_open_device_packed_sequence(self): - device = self.module.custom_device() # noqa: F841 - a = torch.rand(5, 3) - b = torch.tensor([1, 1, 1, 1, 1]) - input = torch.nn.utils.rnn.PackedSequence(a, b) - self.assertFalse(input.is_openreg) - input_openreg = input.openreg() - self.assertTrue(input_openreg.is_openreg) - - def test_open_device_storage(self): - # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly - x = torch.empty(4, 4) - z1 = x.storage() - self.assertFalse(z1.is_openreg) - - z1 = z1.openreg() - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(z1.is_openreg) - - with self.assertRaisesRegex(RuntimeError, "Invalid device"): - z1.openreg(torch.device("cpu")) - - z1 = z1.cpu() - self.assertFalse(self.module.custom_add_called()) - self.assertFalse(z1.is_openreg) - - z1 = z1.openreg(device="openreg:0", non_blocking=False) - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(z1.is_openreg) - - with self.assertRaisesRegex(RuntimeError, "Invalid device"): - z1.openreg(device="cuda:0", non_blocking=False) - - # check UntypedStorage - y = torch.empty(4, 4) - z2 = y.untyped_storage() - self.assertFalse(z2.is_openreg) - - z2 = z2.openreg() - self.assertFalse(self.module.custom_add_called()) - self.assertTrue(z2.is_openreg) - - # check custom StorageImpl create - self.module.custom_storage_registry() - - z3 = y.untyped_storage() - self.assertFalse(self.module.custom_storageImpl_called()) - - z3 = z3.openreg() - self.assertTrue(self.module.custom_storageImpl_called()) - self.assertFalse(self.module.custom_storageImpl_called()) - - z3 = z3[0:3] - self.assertTrue(self.module.custom_storageImpl_called()) - - @unittest.skipIf( - sys.version_info >= (3, 13), - "Error: Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.", - ) - @skipIfTorchDynamo("unsupported aten.is_pinned.default") - def test_open_device_storage_pin_memory(self): - # Check if the pin_memory is functioning properly on custom device - cpu_tensor = torch.empty(3) - self.assertFalse(cpu_tensor.is_openreg) - self.assertFalse(cpu_tensor.is_pinned()) - - cpu_tensor_pin = cpu_tensor.pin_memory() - self.assertTrue(cpu_tensor_pin.is_pinned()) - - # Test storage pin_memory and is_pin - cpu_storage = cpu_tensor.storage() - self.assertFalse(cpu_storage.is_pinned("openreg")) - - cpu_storage_pinned = cpu_storage.pin_memory("openreg") - self.assertTrue(cpu_storage_pinned.is_pinned("openreg")) - - # Test untyped storage pin_memory and is_pin - cpu_tensor = torch.randn([3, 2, 1, 4]) - cpu_untyped_storage = cpu_tensor.untyped_storage() - self.assertFalse(cpu_untyped_storage.is_pinned("openreg")) - - cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg") - self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg")) - - @unittest.skip( - "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function" - ) - def test_open_device_serialization(self): - self.module.set_custom_device_index(-1) - storage = torch.UntypedStorage(4, device=torch.device("openreg")) - self.assertEqual(torch.serialization.location_tag(storage), "openreg") - - self.module.set_custom_device_index(0) - storage = torch.UntypedStorage(4, device=torch.device("openreg")) - self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - - cpu_storage = torch.empty(4, 4).storage() - openreg_storage = torch.serialization.default_restore_location( - cpu_storage, "openreg:0" - ) - self.assertTrue(openreg_storage.is_openreg) - - # test tensor MetaData serialization - x = torch.empty(4, 4).long() - y = x.openreg() - self.assertFalse(self.module.check_backend_meta(y)) - self.module.custom_set_backend_meta(y) - self.assertTrue(self.module.check_backend_meta(y)) - - self.module.custom_serialization_registry() - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "data.pt") - torch.save(y, path) - z1 = torch.load(path) - # loads correctly onto the openreg backend device - self.assertTrue(z1.is_openreg) - # loads BackendMeta data correctly - self.assertTrue(self.module.check_backend_meta(z1)) - - # cross-backend - z2 = torch.load(path, map_location="cpu") - # loads correctly onto the cpu backend device - self.assertFalse(z2.is_openreg) - # loads BackendMeta data correctly - self.assertFalse(self.module.check_backend_meta(z2)) - - def test_open_device_storage_resize(self): - cpu_tensor = torch.randn([8]) - openreg_tensor = cpu_tensor.openreg() - openreg_storage = openreg_tensor.storage() - self.assertTrue(openreg_storage.size() == 8) - - # Only register tensor resize_ function. - openreg_tensor.resize_(8) - self.assertTrue(openreg_storage.size() == 8) - - with self.assertRaisesRegex(TypeError, "Overflow"): - openreg_tensor.resize_(8**29) - - def test_open_device_storage_type(self): - # test cpu float storage - cpu_tensor = torch.randn([8]).float() - cpu_storage = cpu_tensor.storage() - self.assertEqual(cpu_storage.type(), "torch.FloatStorage") - - # test custom float storage before defining FloatStorage - openreg_tensor = cpu_tensor.openreg() - openreg_storage = openreg_tensor.storage() - self.assertEqual(openreg_storage.type(), "torch.storage.TypedStorage") - - class CustomFloatStorage: - @property - def __module__(self): - return "torch." + torch._C._get_privateuse1_backend_name() - - @property - def __name__(self): - return "FloatStorage" - - # test custom float storage after defining FloatStorage - try: - torch.openreg.FloatStorage = CustomFloatStorage() - self.assertEqual(openreg_storage.type(), "torch.openreg.FloatStorage") - - # test custom int storage after defining FloatStorage - openreg_tensor2 = torch.randn([8]).int().openreg() - openreg_storage2 = openreg_tensor2.storage() - self.assertEqual(openreg_storage2.type(), "torch.storage.TypedStorage") - finally: - torch.openreg.FloatStorage = None - - def test_open_device_faketensor(self): - with torch._subclasses.fake_tensor.FakeTensorMode.push(): - a = torch.empty(1, device="openreg") - b = torch.empty(1, device="openreg:0") - result = a + b # noqa: F841 - - def test_open_device_named_tensor(self): - torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"]) - - # Not an open registration test - this file is just very convenient - # for testing torch.compile on custom C++ operators - def test_compile_autograd_function_returns_self(self): - x_ref = torch.randn(4, requires_grad=True) - out_ref = self.module.custom_autograd_fn_returns_self(x_ref) - out_ref.sum().backward() - - x_test = x_ref.detach().clone().requires_grad_(True) - f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self) - out_test = f_compiled(x_test) - out_test.sum().backward() - - self.assertEqual(out_ref, out_test) - self.assertEqual(x_ref.grad, x_test.grad) - - # Not an open registration test - this file is just very convenient - # for testing torch.compile on custom C++ operators - @skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket") - def test_compile_autograd_function_aliasing(self): - x_ref = torch.randn(4, requires_grad=True) - out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref) - out_ref.sum().backward() - - x_test = x_ref.detach().clone().requires_grad_(True) - f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing) - out_test = f_compiled(x_test) - out_test.sum().backward() - - self.assertEqual(out_ref, out_test) - self.assertEqual(x_ref.grad, x_test.grad) - def test_open_device_scalar_type_fallback(self): z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64) z = torch.triu_indices(3, 3, device="openreg") @@ -488,130 +94,6 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - @skipIfTorchDynamo() - @unittest.skipIf( - np.__version__ < "1.25", - "versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy", - ) - def test_open_device_numpy_serialization(self): - """ - This tests the legacy _rebuild_device_tensor_from_numpy serialization path - """ - device = self.module.custom_device() - - # Legacy data saved with _rebuild_device_tensor_from_numpy on f80ed0b8 via - - # with patch.object(torch._C, "_has_storage", return_value=False): - # x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device) - # x_foo = x.to(device) - # sd = {"x": x_foo} - # rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0] - # self.assertTrue( - # rebuild_func is torch._utils._rebuild_device_tensor_from_numpy - # ) - # with open("foo.pt", "wb") as f: - # torch.save(sd, f) - - data_legacy_numpy = ( - b"PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - b"\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02}q\x00X\x01" - b"\x00\x00\x00xq\x01ctorch._utils\n_rebuild_device_tensor_from_numpy\nq\x02(cnumpy.core.m" - b"ultiarray\n_reconstruct\nq\x03cnumpy\nndarray\nq\x04K\x00\x85q\x05c_codecs\nencode\nq\x06" - b"X\x01\x00\x00\x00bq\x07X\x06\x00\x00\x00latin1q\x08\x86q\tRq\n\x87q\x0bRq\x0c(K\x01K\x02K" - b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01" - b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00" - b"PK\x05\x06\x00\x00\x00\x00\x04\x00\x04\x00\x06\x01\x00\x008\x03\x00\x00\x00\x00" - ) - buf_data_legacy_numpy = io.BytesIO(data_legacy_numpy) - - with safe_globals( - [ - (np.core.multiarray._reconstruct, "numpy.core.multiarray._reconstruct") - if np.__version__ >= "2.1" - else np.core.multiarray._reconstruct, - np.ndarray, - np.dtype, - _codecs.encode, - np.dtypes.Float32DType, - ] - ): - sd_loaded = torch.load(buf_data_legacy_numpy, weights_only=True) - buf_data_legacy_numpy.seek(0) - # Test map_location - sd_loaded_cpu = torch.load( - buf_data_legacy_numpy, weights_only=True, map_location="cpu" - ) - expected = torch.tensor( - [[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device - ) - self.assertEqual(sd_loaded["x"].cpu(), expected.cpu()) - self.assertFalse(sd_loaded["x"].is_cpu) - self.assertTrue(sd_loaded_cpu["x"].is_cpu) - - def test_open_device_cpu_serialization(self): - torch.utils.rename_privateuse1_backend("openreg") - device = self.module.custom_device() - default_protocol = torch.serialization.DEFAULT_PROTOCOL - - with patch.object(torch._C, "_has_storage", return_value=False): - x = torch.randn(2, 3) - x_openreg = x.to(device) - sd = {"x": x_openreg} - rebuild_func = x_openreg._reduce_ex_internal(default_protocol)[0] - self.assertTrue( - rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor - ) - # Test map_location - with TemporaryFileName() as f: - torch.save(sd, f) - sd_loaded = torch.load(f, weights_only=True) - # Test map_location - sd_loaded_cpu = torch.load(f, weights_only=True, map_location="cpu") - self.assertFalse(sd_loaded["x"].is_cpu) - self.assertEqual(sd_loaded["x"].cpu(), x) - self.assertTrue(sd_loaded_cpu["x"].is_cpu) - - # Test metadata_only - with TemporaryFileName() as f: - with self.assertRaisesRegex( - RuntimeError, - "Cannot serialize tensors on backends with no storage under skip_data context manager", - ): - with torch.serialization.skip_data(): - torch.save(sd, f) - - def test_open_device_dlpack(self): - t = torch.randn(2, 3).to("openreg") - capsule = torch.utils.dlpack.to_dlpack(t) - t1 = torch.from_dlpack(capsule) - self.assertTrue(t1.device == t.device) - t = t.to("cpu") - t1 = t1.to("cpu") - self.assertEqual(t, t1) - if __name__ == "__main__": common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index c77c724d3c9043..d8375fe9429bc4 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -12,6 +12,7 @@ import sys import tempfile import threading +import time import unittest import warnings from collections import defaultdict @@ -34,6 +35,7 @@ from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_cuda import ( _create_scaling_case, + SM70OrLater, TEST_CUDNN, TEST_MULTIGPU, tf32_on_and_off, @@ -43,6 +45,7 @@ largeTensorTest, onlyCUDA, onlyNativeDeviceTypes, + skipCUDAIf, ) from torch.testing._internal.common_optimizers import ( _get_optim_inputs_including_global_cliquey_kwargs, @@ -51,6 +54,7 @@ TensorTracker, ) from torch.testing._internal.common_utils import ( + cuda_python_error_check, EXPANDABLE_SEGMENTS, freeze_rng_state, gcIfJetson, @@ -65,6 +69,7 @@ load_tests, MI300_ARCH, parametrize, + recover_orig_fp32_precision, run_tests, serialTest, setBlasBackendsToDefaultFinally, @@ -77,6 +82,7 @@ TemporaryFileName, TEST_CUDA, TEST_CUDA_GRAPH, + TEST_CUDA_PYTHON_BINDINGS, TEST_NUMPY, TEST_WITH_ROCM, TestCase, @@ -168,7 +174,7 @@ def test_pinned_memory_with_cudaregister_multithread(self): for thread in threads: thread.join() - @serialTest + @serialTest() def test_host_memory_stats(self): # Helper functions def empty_stats(): @@ -844,6 +850,55 @@ def test_cudnn_allow_tf32_get_set(self): ): self.assertTrue(torch.backends.cudnn.allow_tf32) + @recover_orig_fp32_precision + def test_fp32_precision_with_tf32(self): + with torch.backends.cudnn.flags( + enabled=None, + benchmark=None, + benchmark_limit=None, + deterministic=None, + allow_tf32=True, + fp32_precision="none", + ): + self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "tf32") + self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "tf32") + + with torch.backends.cudnn.flags( + enabled=None, + benchmark=None, + benchmark_limit=None, + deterministic=None, + allow_tf32=False, + fp32_precision="none", + ): + self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "none") + self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "none") + + @recover_orig_fp32_precision + def test_fp32_precision_with_float32_matmul_precision(self): + torch.set_float32_matmul_precision("highest") + self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "ieee") + torch.set_float32_matmul_precision("high") + self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32") + torch.set_float32_matmul_precision("medium") + self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32") + + @recover_orig_fp32_precision + def test_invalid_status_for_legacy_api(self): + torch.backends.cudnn.conv.fp32_precision = "none" + torch.backends.cudnn.rnn.fp32_precision = "tf32" + with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"): + print(torch.backends.cudnn.allow_tf32) + + torch.set_float32_matmul_precision("highest") + torch.backends.cuda.matmul.fp32_precision = "tf32" + with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"): + print(torch.get_float32_matmul_precision()) + + if not TEST_WITH_ROCM: + with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"): + print(torch.backends.cuda.matmul.allow_tf32) + def test_type_conversions(self): x = torch.randn(5, 5) self.assertIsInstance(x.float(), torch.FloatTensor) @@ -1061,7 +1116,7 @@ def test_stream_compatibility(self): torch.accelerator.set_stream(s2) self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id) with self.assertRaisesRegex( - RuntimeError, "device_index >= 0 && device_index < num_gpus" + RuntimeError, "Device index value .* is out of index range" ): torch.accelerator.current_stream(torch.accelerator.device_count()) @@ -1387,6 +1442,8 @@ def _spawn_method(self, method, arg): for e in errors: if "device-side assert triggered" not in str(e): self.fail(e) + if e.error_code != 710: # cudaErrorAssert == 710 + self.fail(e) @staticmethod def _test_index_bounds_cuda(idx): @@ -2251,6 +2308,26 @@ def test_graph_debugdump(self): with tempfile.TemporaryDirectory() as tempdir: g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) + @unittest.skipIf( + not TEST_CUDA_GRAPH or TEST_WITH_ROCM, + "CUDA >= 11.0 required for external events in cuda graphs. rocm does not support external events", + ) + def test_graph_timing(self): + torch.cuda.empty_cache() + x = torch.randn(10240000, device="cuda") + y = torch.rand_like(x) + g = torch.cuda.CUDAGraph() + start_event = torch.cuda.Event(enable_timing=True, external=True) + end_event = torch.cuda.Event(enable_timing=True, external=True) + with torch.cuda.graph(g): + start_event.record() + z = x + y + end_event.record() + torch.cuda.synchronize() + g.replay() + torch.cuda.synchronize() + self.assertTrue(start_event.elapsed_time(end_event) > 0) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3502,6 +3579,95 @@ def throws_on_cuda_event(capture_error_mode): # Exception would Corrupt Process and make other tests fail # self.assertTrue(throws_on_cuda_event("global")) + @unittest.skipIf( + not TEST_CUDA_GRAPH, + "CUDA >= 11.0 or ROCM >= 5.3 required for graphs, cuda-python must be installed", + ) + def test_cuda_graph_raw_graph_keep_graph_false(self): + graph = torch.cuda.CUDAGraph(keep_graph=False) + x = torch.zeros([2000], device="cuda") + y = torch.ones([2000], device="cuda") + with torch.cuda.graph(graph, capture_error_mode="relaxed"): + z = x + y + + with self.assertRaisesRegex( + RuntimeError, + r"instantiate\(\) is intended to be called by the user only when keep_graph=true", + ): + raw_pointer = graph.instantiate() + + with self.assertRaisesRegex( + RuntimeError, + r"You cannot access the raw (cuda|hip)Graph_t instance unless CUDAGraph was initialized with keep_graph=true", + ): + raw_pointer = graph.raw_cuda_graph() + + @unittest.skipIf( + not TEST_CUDA_GRAPH or not TEST_CUDA_PYTHON_BINDINGS, + "CUDA >= 11.0 or ROCM >= 5.3 required for graphs, cuda-bindings must be installed", + ) + def test_cuda_graph_raw_graph(self): + import cuda.bindings.runtime as cudart + + graph = torch.cuda.CUDAGraph(keep_graph=True) + x = torch.zeros([2000], device="cuda") + y = torch.ones([2000], device="cuda") + with torch.cuda.graph(graph, capture_error_mode="relaxed"): + z = x + y + + raw_pointer = graph.raw_cuda_graph() + + cudart_cuda_graph = cudart.cudaGraph_t(init_value=raw_pointer) + _, num_nodes = cuda_python_error_check( + cudart.cudaGraphGetNodes(cudart_cuda_graph) + ) + nodes, _ = cuda_python_error_check( + cudart.cudaGraphGetNodes(cudart_cuda_graph, num_nodes) + ) + for node in nodes: + cuda_python_error_check(cudart.cudaGraphNodeGetType(node)) + + graph.replay() + + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_cuda_graph_raw_graph_reset_and_recapture(self): + graph = torch.cuda.CUDAGraph(keep_graph=True) + x = torch.zeros([2000], device="cuda") + with torch.cuda.graph(graph, capture_error_mode="relaxed"): + x += 1.0 + + graph.instantiate() + graph.replay() + self.assertTrue(torch.all(x == 1.0)) + # Exercise the code path where you reinstantiate the cuda graph twice. + graph.instantiate() + graph.replay() + self.assertTrue(torch.all(x == 2.0)) + graph.replay() + self.assertTrue(torch.all(x == 3.0)) + + # Check that graph capture can succeed after reseting. + graph.reset() + + # Don't do x[:] = 0.0 because we want to capture a new address + # in the next cuda graph, to make sure we are running a new + # cuda graph. + x = torch.zeros([2000], device="cuda") + with torch.cuda.graph(graph, capture_error_mode="relaxed"): + x += 2.0 + + graph.instantiate() + graph.replay() + self.assertTrue(torch.all(x == 2.0)) + # Exercise the code path where you reinstantiate the cuda graph twice. + graph.instantiate() + graph.replay() + self.assertTrue(torch.all(x == 4.0)) + graph.replay() + self.assertTrue(torch.all(x == 6.0)) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3673,10 +3839,16 @@ def test_hip_device_count(self): {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, - {"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"}, {"ROCR_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, ] + if torch.cuda.device_count() >= 2: + custom_envs.extend( + [ + {"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"}, + ] + ) + for env_config in custom_envs: env = os.environ.copy() for key, value in env_config.items(): @@ -4120,7 +4292,7 @@ def foo(): finally: torch.cuda.memory._record_memory_history(None) - @serialTest + @serialTest() def test_max_split_expandable(self): try: torch.cuda.memory.empty_cache() @@ -4156,7 +4328,7 @@ def alloc(n): finally: torch.cuda.memory.set_per_process_memory_fraction(orig) - @serialTest + @serialTest() def test_garbage_collect_expandable(self): try: torch.cuda.memory.empty_cache() @@ -5112,6 +5284,47 @@ def get_dummy_allocator(self, check_vars): ) return allocator, dummy_allocator + def test_mempool_empty_cache(self): + torch.cuda.empty_cache() + pool = torch.cuda.MemPool() + x = torch.empty(1024, 1024, device="cuda") + + with torch.cuda.use_mem_pool(pool): + y = torch.empty(1024, 1024, device="cuda") + + del y + del x + del pool + segments = torch.cuda.memory._snapshot()["segments"] + self.assertTrue(len(segments) > 0, "expected more than one segment") + + def test_mempool_empty_cache_inactive(self): + torch.cuda.empty_cache() + allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True) + alloc_lib = ctypes.CDLL(dummy_allocator) + called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") + called_dummy_free = ctypes.c_int.in_dll(alloc_lib, "called_dummy_free") + self.assertEqual(called_dummy_alloc.value, 0) + self.assertEqual(called_dummy_free.value, 0) + + def f(): + pool = torch.cuda.MemPool(allocator.allocator()) + + # allocate memory with ncclMemAlloc + with torch.cuda.use_mem_pool(pool): + x = torch.arange(1024 * 1024 * 2, device="cuda") + # Note: pool will be destroyed upon function return, but x, which + # was allocated via the pool is still alive. + return x + + x = f() + self.assertEqual(called_dummy_alloc.value, 123) + self.assertEqual(called_dummy_free.value, 0) + + del x + torch.cuda.empty_cache() + self.assertEqual(called_dummy_free.value, 321) + def test_mempool_with_allocator(self): pool = torch.cuda.MemPool() @@ -5172,8 +5385,7 @@ def test_mempool_with_allocator(self): # to make a new 2 MB buffer to accomodate out_2 self.assertEqual(len(pool.snapshot()), 2) - all_segments = torch.cuda.memory._snapshot()["segments"] - self.assertEqual(len(all_segments), 3) + self.assertEqual(len(pool.snapshot()), 2) del out_0, out_1, out_2 @@ -6435,11 +6647,222 @@ def test_compile_kernel_advanced(self): # Verify results self.assertEqual(C_explicit, expected) + @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc") + @unittest.skipIf(not TEST_CUDA, "No CUDA") + def test_compile_kernel_as_custom_op(self): + # Define a simple vector addition kernel + kernel_source = """ + __global__ void vector_add(const float* a, const float* b, float* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } + } + """ + + @torch.library.custom_op("test_compile_kernel::vector_add", mutates_args=()) + def vector_add_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + from torch.cuda import _compile_kernel + + # Validate that tensors are 1-dimensional and have the same size + torch._check( + a.dim() == 1, + lambda: f"Expected tensor 'a' to be 1-dimensional, but got {a.dim()} dimensions", + ) + torch._check( + b.dim() == 1, + lambda: f"Expected tensor 'b' to be 1-dimensional, but got {b.dim()} dimensions", + ) + torch._check( + a.size() == b.size(), + lambda: f"Expected tensors to have the same size, but got a.size()={a.size()} and b.size()={b.size()}", + ) + compiled_kernel = _compile_kernel(kernel_source, "vector_add") + + c = torch.empty_like(a) + n = a.numel() + + threads_per_block = 256 + blocks_per_grid = (n + threads_per_block - 1) // threads_per_block + compiled_kernel( + grid=(blocks_per_grid, 1, 1), + block=(threads_per_block, 1, 1), + args=[a, b, c, n], + ) + + return c + + @vector_add_op.register_fake + def _(a, b): + return torch.empty_like(a) + + device = torch.device("cuda:0") + size = (1024,) + + a = torch.randn(size, device=device, dtype=torch.float32) + b = torch.randn(size, device=device, dtype=torch.float32) + + result = vector_add_op(a, b) + + expected = a + b + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc") + @unittest.skipIf(not TEST_CUDA, "No CUDA") + def test_compile_kernel_custom_op_validation(self): + kernel_source = """ + __global__ void add_scalar(const float* input, float* output, float scalar, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = input[idx] + scalar; + } + } + """ + + @torch.library.custom_op("test_compile_kernel::add_scalar", mutates_args=()) + def add_scalar_op(input_tensor: torch.Tensor, scalar: float) -> torch.Tensor: + from torch.cuda import _compile_kernel + + compiled_kernel = _compile_kernel(kernel_source, "add_scalar") + + output = torch.empty_like(input_tensor) + n = input_tensor.numel() + + threads_per_block = 256 + blocks_per_grid = (n + threads_per_block - 1) // threads_per_block + compiled_kernel( + grid=(blocks_per_grid, 1, 1), + block=(threads_per_block, 1, 1), + args=[input_tensor, output, scalar, n], + ) + + return output + + @add_scalar_op.register_fake + def _(input_tensor, scalar): + return torch.empty_like(input_tensor) + + # Test with opcheck + device = torch.device("cuda:0") + input_data = torch.randn((64,), device=device, dtype=torch.float32) + scalar_val = 3.14 + + # Run opcheck validation + torch.library.opcheck(add_scalar_op, (input_data, scalar_val), {}) + + # Also test the actual functionality + result = add_scalar_op(input_data, scalar_val) + expected = input_data + scalar_val + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +class TestCudaDeviceParametrized(TestCase): + @unittest.skipIf( + TEST_WITH_ROCM, "ROCM does not support nvrtc or external cuda graph events" + ) + @skipCUDAIf( + not SM70OrLater, "Compute capability >= SM70 required for relaxed ptx flag" + ) + def test_graph_external_wait_and_record(self): + torch.cuda.empty_cache() + + kernel_source = r""" + __global__ void wait_for_cpu(int *pinned_cpu_flag) { + int flag = 0; + do { + asm volatile("ld.relaxed.sys.global.s32 %0, [%1];" : "=r"(flag) : "l"(pinned_cpu_flag) : "memory"); + } while (flag == 0); + } + """ + from torch.cuda import _compile_kernel + + spin_wait_kernel = _compile_kernel( + kernel_source, "wait_for_cpu", compute_capability="70" + ) + + x = torch.ones(4, device="cuda") + x_cpu = torch.zeros(x.shape, device="cpu").pin_memory() + flag_cpu = torch.zeros(1, dtype=torch.int32, device="cpu").pin_memory() + start_event = torch.cuda.Event(external=True) + end_event = torch.cuda.Event(external=True) + + signalling_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(signalling_graph, capture_error_mode="relaxed"): + spin_wait_kernel(grid=(1, 1, 1), block=(1, 1, 1), args=[flag_cpu]) + start_event.record() + + # This is counter-intuitive, but a cudaEventRecord() during + # stream capture does not count as the first call to + # cudaEventRecord(). Rather, cudaGraphLaunch() counts as the + # first call to cudaEventRecord(). Therefore, all calls to + # cudaEventQuery() will succeed before that happens. + + # See: + # "Before the first call to cudaEventRecord(), an event represents an empty set of work, so for example cudaEventQuery() would return cudaSuccess." # noqa: B950 + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + self.assertTrue(start_event.query(), "Start event's work should be empty") + + work_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(work_graph, capture_error_mode="relaxed"): + start_event.wait() + x_cpu.copy_(x, non_blocking=True) + end_event.record() + + self.assertTrue( + torch.all(x_cpu == 0.0), "Copy cannot occur until start_event is recorded" + ) + + try: + signalling_stream = torch.cuda.Stream() + with torch.cuda.stream(signalling_stream): + signalling_graph.replay() + + work_stream = torch.cuda.Stream() + with torch.cuda.stream(work_stream): + work_graph.replay() + + self.assertFalse( + end_event.query(), "Event record node cannot run until flag_cpu[0]=1" + ) + + # Sleep for a little to make sure that work doesn't proceed until we set flag_cpu[0]=1 + time.sleep(1) + self.assertTrue( + torch.all(x_cpu == 0.0), + "Copy cannot occur until start_event is recorded", + ) + finally: + # In case an assertion fails, we still need to empty out + # the GPU queue of work. Therefore, we do this write + # unconditionally, even if an exception is thrown. + + # This writes allows wait_for_cpu to proceed + # This is an atomic store at system scope according to this rule: + # "the scope is thread_scope_system and and it is a load or store that affects a naturally-aligned object of sizes 1, 2, 4, 8, or 16 bytes on mapped memory" # noqa: B950 + # https://nvidia.github.io/cccl/libcudacxx/extended_api/memory_model.html#atomicity + + # Note that every CPU store is implicitly system scope, + # even if we don't use C++ atomics like this: + # std::atomic_ref::store(1); + flag_cpu[0] = 1 + + end_event.synchronize() + self.assertTrue( + torch.all(x_cpu == 1.0), + "Copy should be done once end_event is synchronized", + ) + self.assertTrue( + work_stream.query(), + "end_event.synchronize() completing should imply that work_stream is done", + ) + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_parametrized_tests(TestCompileKernel) instantiate_device_type_tests(TestCudaOptims, globals()) +instantiate_device_type_tests(TestCudaDeviceParametrized, globals()) if __name__ == "__main__": run_tests() diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 37a9c0aedbdd95..9b044458b3e6fa 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -4,7 +4,11 @@ import unittest import torch -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + TEST_CUDA, + TEST_MULTIGPU, +) from torch.testing._internal.common_utils import NoTest, run_tests, TestCase @@ -31,6 +35,19 @@ def setUp(self): TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG, ) + def test_set_device_0(self): + # In CUDA 12 the behavior of cudaSetDevice has changed. It eagerly creates context on target. + # The behavior of `torch.cuda.set_device(0)` should also create context on the device 0. + # Initially, we should not have any context on device 0. + self.assertFalse(torch._C._cuda_hasPrimaryContext(0)) + torch.cuda.set_device(0) + if _get_torch_cuda_version() >= (12, 0): + # Now after the device was set, the contex should present in CUDA 12. + self.assertTrue(torch._C._cuda_hasPrimaryContext(0)) + else: + # In CUDA 11 the context should not be created. + self.assertFalse(torch._C._cuda_hasPrimaryContext(0)) + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_str_repr(self): x = torch.randn(1, device="cuda:1") diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 4aa4d8ccde4cc4..f9d231a7df8513 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -225,6 +225,31 @@ def _(x: torch.Tensor) -> torch.Tensor: example = torch.zeros([10, 20], device=device) torch.library.opcheck(f, args=[example]) + # https://github.com/pytorch/pytorch/issues/150472 + def test_single_element_tuple_output(self, device): + # Helper function to register id_tuple custom and the fake tensor implementation + # so that Dynamo has the fake tensor implementation + def get_id_tuple(): + @torch.library.custom_op("test::id_tuple", mutates_args=[]) + def id_tuple(x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x.clone(),) + + @id_tuple.register_fake + def _( + x: torch.Tensor, + ) -> Tuple[torch.Tensor]: + return (x.clone(),) + + return id_tuple + + id_tuple = get_id_tuple() + x = torch.randn(3, device=device) + ret = id_tuple(x) + # Check if ret is a tuple and has exactly one and the same element + self.assertIsInstance(ret, tuple) + self.assertEqual(len(ret), 1) + self.assertEqual(x, ret[0]) + def test_missing_abstract_impl(self, device): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") @@ -4473,9 +4498,9 @@ def test_mixed_types(self): class TestOpProfiles(TestCase): - def get_sample_op_profile(self) -> dict[str, set[OpProfile]]: + def get_sample_op_profile(self, opname) -> dict[str, set[OpProfile]]: return { - "mylib.foo.default": { + opname: { OpProfile( args_profile=( TensorMetadata( @@ -4508,46 +4533,46 @@ def test_fake_registration(self): t1 = fm.from_tensor(torch.ones(3, 3)) t2 = fm.from_tensor(torch.ones(3, 3)) - op_profiles = self.get_sample_op_profile() + op_profiles = self.get_sample_op_profile("mylib.foo2.default") with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( - "mylib::foo", + "mylib::foo2", "(Tensor a, Tensor b) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) - @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch.library.impl("mylib::foo2", "cpu", lib=lib) def foo_impl(a, b): return a + b with ( self.assertRaisesRegex( torch._subclasses.fake_tensor.UnsupportedOperatorException, - "mylib.foo.default", + "mylib.foo2.default", ), fm, ): - torch.ops.mylib.foo(t1, t2) + torch.ops.mylib.foo2(t1, t2) with ( torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles), fm, ): - torch.ops.mylib.foo(t1, t2) + torch.ops.mylib.foo2(t1, t2) - with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"): - torch.ops.mylib.foo(torch.ones(3, 3, 3), torch.ones(3, 3, 3)) + with self.assertRaisesRegex(MissingOpProfile, "mylib::foo2"): + torch.ops.mylib.foo2(torch.ones(3, 3, 3), torch.ones(3, 3, 3)) with ( self.assertRaisesRegex( torch._subclasses.fake_tensor.UnsupportedOperatorException, - "mylib.foo.default", + "mylib.foo2.default", ), fm, ): - torch.ops.mylib.foo(t1, t2) + torch.ops.mylib.foo2(t1, t2) def test_duplicate_registration_impl(self): fm = torch._subclasses.FakeTensorMode( @@ -4556,33 +4581,33 @@ def test_duplicate_registration_impl(self): t1 = fm.from_tensor(torch.ones(3, 3)) t2 = fm.from_tensor(torch.ones(3, 3)) - op_profiles = self.get_sample_op_profile() + op_profiles = self.get_sample_op_profile("mylib.foo3.default") with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( - "mylib::foo", + "mylib::foo3", "(Tensor a, Tensor b) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) - @torch.library.impl("mylib::foo", "cpu", lib=lib) - def foo_impl(a, b): + @torch.library.impl("mylib::foo3", "cpu", lib=lib) + def foo3_impl(a, b): return a + b - @torch.library.register_fake("mylib::foo", lib=lib) - def foo_impl_fake(a, b): + @torch.library.register_fake("mylib::foo3", lib=lib) + def foo3_impl_fake(a, b): return (a + b).to(dtype=torch.bfloat16) with fm: - self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16) + self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.bfloat16) with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles): with fm: - self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32) + self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.float32) with fm: - self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16) + self.assertEqual(torch.ops.mylib.foo3(t1, t2).dtype, torch.bfloat16) def test_duplicate_registration_custom_op(self): fm = torch._subclasses.FakeTensorMode( @@ -4591,7 +4616,7 @@ def test_duplicate_registration_custom_op(self): t1 = fm.from_tensor(torch.ones(3, 3)) t2 = fm.from_tensor(torch.ones(3, 3)) - op_profiles = self.get_sample_op_profile() + op_profiles = self.get_sample_op_profile("mylib.foo1.default") @torch.library.custom_op("mylib::foo1", mutates_args=()) def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: @@ -4604,10 +4629,6 @@ def foo_impl_fake(a, b): with fm: self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16) - op_profiles = { - "mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"] - } - with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles): with fm: self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32) @@ -4616,14 +4637,14 @@ def foo_impl_fake(a, b): self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16) def test_yaml(self): - op_profiles = self.get_sample_op_profile() + op_profiles = self.get_sample_op_profile("mylib.foo.default") yaml_str = generate_yaml_from_profiles(op_profiles) loaded = read_profiles_from_yaml(yaml_str) self.assertEqual(op_profiles, loaded) @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") def test_save_to_file(self): - op_profile = self.get_sample_op_profile() + op_profile = self.get_sample_op_profile("mylib.foo.default") # Saving with buffer buffer = io.BytesIO() @@ -4647,7 +4668,7 @@ def test_save_to_file(self): self.assertEqual(op_profile, loaded) def test_version(self): - op_profiles = self.get_sample_op_profile() + op_profiles = self.get_sample_op_profile("mylib.foo.default") yaml_str = generate_yaml_from_profiles(op_profiles) loaded = yaml.safe_load(yaml_str) loaded["torch_version"] = "2.7" diff --git a/test/test_dataloader.py b/test/test_dataloader.py index ef92b4f1b82d1a..a0745deae9875e 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -133,11 +133,28 @@ ) -# collate_fn that returns the batch cloned; defined globally here for pickle purposes. +# The following collate functions are defined globally here for pickle purposes. + + +# collate_fn that returns the batch cloned def _clone_collate(b): return [x.clone() for x in b] +# collate_fn that returns the batch of sparse coo tensors cloned +def _sparse_coo_collate(b): + lst = [] + for x in b: + t = x.clone() + lst.append(t) + # Force sparse tensor invariants checks. check_pinning=True + # reproduces gh-153143. + torch._validate_sparse_coo_tensor_args( + t._indices(), t._values(), t.size(), t.is_coalesced(), check_pinning=False + ) + return lst + + @unittest.skipIf( TEST_WITH_TSAN, "Fails with TSAN with the following error: starting new threads after multi-threaded " @@ -2893,8 +2910,9 @@ class TestDataLoaderDeviceType(TestCase): def test_nested_tensor_multiprocessing(self, device, context): # The 'fork' multiprocessing context doesn't work for CUDA so skip it if "cuda" in device and context == "fork": - # TODO: Skip this better in a better way when the test framework allows - return + self.skipTest( + f"{context} multiprocessing context not supported for {device}" + ) dataset = [ torch.nested.nested_tensor([torch.randn(5)], device=device) @@ -2932,6 +2950,37 @@ def test_nested_tensor_multiprocessing(self, device, context): next(iter(loader)) + @parametrize( + "context", + [ctx for ctx in supported_multiprocessing_contexts if ctx is not None], + ) + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") + def test_sparse_tensor_multiprocessing(self, device, context): + # The 'fork' multiprocessing context doesn't work for CUDA so skip it + if "cuda" in device and context == "fork": + self.skipTest( + f"{context} multiprocessing context not supported for {device}" + ) + + dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)] + + pin_memory_settings = [False] + if device == "cpu" and torch.cuda.is_available(): + pin_memory_settings.append(True) + + for pin_memory in pin_memory_settings: + loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + num_workers=4, + collate_fn=_sparse_coo_collate, + pin_memory=pin_memory, + multiprocessing_context=context, + ) + + for i, batch in enumerate(loader): + self.assertEqual(batch[0], dataset[i]) + class IntegrationTestDataLoaderDataPipe(TestCase): r""" diff --git a/test/test_dispatch.py b/test/test_dispatch.py index 0e77c31915e533..046faea9c4843f 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -1118,7 +1118,7 @@ def test_autogradother(self): def test_duplicate_registrations(self): dispatcher = PythonDispatcher() - with self.assertRaisesRegex(RuntimeError, r"Overriden is not allowed"): + with self.assertRaisesRegex(RuntimeError, r"Overridden is not allowed"): dispatcher.register(["CPU", "CPU"]) def test_defaultbackend_math(self): diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 2ee4e64b9f3219..389f63efa687f3 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -11,7 +11,12 @@ skipMeta, ) from torch.testing._internal.common_dtype import all_types_and_complex_and -from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase +from torch.testing._internal.common_utils import ( + IS_JETSON, + run_tests, + skipIfTorchDynamo, + TestCase, +) from torch.utils.dlpack import from_dlpack, to_dlpack @@ -164,7 +169,7 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype): # in the current stream to make sure that it was correctly populated. with torch.cuda.stream(stream_a): x = make_tensor((5,), dtype=dtype, device=device) + 1 - z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream)) + z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream)) stream_a.synchronize() stream_b.synchronize() self.assertEqual(z, x) @@ -201,7 +206,7 @@ def __dlpack__(self, stream=None): assert stream == 1 else: assert stream == 0 - capsule = self.tensor.__dlpack__(stream) + capsule = self.tensor.__dlpack__(stream=stream) return capsule # CUDA-based tests runs on non-default streams @@ -224,7 +229,7 @@ def test_dlpack_convert_default_stream(self, device): x = torch.zeros(1, device=device) torch.cuda._sleep(2**20) self.assertTrue(torch.cuda.default_stream().query()) - x.__dlpack__(1) + x.__dlpack__(stream=1) # check that the default stream has work (a pending cudaStreamWaitEvent) self.assertFalse(torch.cuda.default_stream().query()) @@ -281,6 +286,37 @@ def test_automatically_select_in_creation(self, device): new_tensor = torch.tensor(wrap) self.assertEqual(tensor, new_tensor) + @skipMeta + @skipIfTorchDynamo("__dlpack__ doesn't work with dynamo") + @onlyNativeDeviceTypes + def test_max_version(self, device): + def capsule_name(kwargs): + is_versioned = "max_version" in kwargs and kwargs["max_version"][0] >= 1 + return "dltensor_versioned" if is_versioned else "dltensor" + + def test(device, **kwargs): + inp = make_tensor((5,), dtype=torch.float32, device=device) + + # Make sure we are actually using the (un)versioned DLPack tensor, based on the + # informed keyword arguments. + capsule = inp.__dlpack__(**kwargs) + self.assertRegex( + str(capsule), f"""capsule object "{capsule_name(kwargs)}" at""" + ) + + out = torch.from_dlpack(capsule) + self.assertEqual(inp, out) + + # Use the DLPack 0.X version implementation, since max_version=None. + test(device) + # Use the DLPack 0.X version implementation. + test(device, max_version=(0, 8)) + # Current highest DLPack version implemented. + test(device, max_version=(1, 0)) + # Newer DLPack version. + # Consumer should still be able to process a smaller version capsule. + test(device, max_version=(2, 0)) + instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d13a61146aa2a0..f9fc61af81d402 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -15,8 +15,8 @@ import torch.nn.functional as F from torch import sym_int, SymBool, SymFloat, SymInt from torch._C import _disabled_torch_function_impl -from torch._dynamo.testing import CompileCounterWithBackend -from torch._inductor.utils import fresh_inductor_cache +from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend +from torch._inductor.utils import fresh_cache from torch.fx.experimental import sym_node from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node @@ -949,7 +949,7 @@ def test_floor_clean_div_axioms(self): shape_env = ShapeEnv() a = shape_env.create_unbacked_symint() - shape_env.defer_runtime_assert((a // 3 == 1).node.expr, " test") + shape_env.guard_or_defer_runtime_assert((a // 3 == 1).node.expr, " test") from sympy import Eq @@ -960,7 +960,7 @@ def test_floor_clean_div_axioms(self): self.assertEqual(shape_env._maybe_evaluate_static(test2), None) # After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3) - shape_env.defer_runtime_assert(Eq(Mod(a, 3), 0), " test") + shape_env.guard_or_defer_runtime_assert(Eq(Mod(a, 3), 0), " test") self.assertEqual(test2, shape_env.simplify(test1)) self.assertTrue(shape_env.evaluate_expr(test1)) @@ -1857,6 +1857,28 @@ def is_complex(x): class TestDimConstraints(TestCase): + @skipIfTorchDynamo("mark_dynamic not supported") + def test_simplify_max_1_0(self): + x = torch.rand(10) + torch._dynamo.mark_dynamic(x, 0, max=20, min=5) + + @torch.compile(fullgraph=True) + def func(x, v): + # test that statically_known_true + if (v == 0 or v == 1) and not statically_known_true( + max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2) + ): + raise AssertionError("error") + + if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2): + return x * 400 + else: + return (x * 10) * 100 + + # testing that this does not throw constraint violation error. + self.assertEqual(func(x, 1), x * 400) + self.assertEqual(func(x, 0), x * 400) + def test_dim_constraints_reduce_congruences_simple(self): from sympy import Symbol @@ -3050,6 +3072,7 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: class TestUnbacked(TestCase): + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_neq_assert(self, backend): @@ -3097,6 +3120,7 @@ def func(x, y): with self.assertRaises(RuntimeError): func(torch.rand(2, 50), torch.tensor([51])) + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_or_assert(self, backend): @@ -3118,6 +3142,7 @@ def test_has_free_symbols(self): self.assertTrue(has_free_symbols(sympy.sympify("a*2"))) self.assertTrue(has_free_symbols(sympy.sympify("a+b"))) + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_eq_assert(self, backend): @@ -3150,7 +3175,7 @@ def func(a, b): class TestUbackedOps(TestCase): - @fresh_inductor_cache() + @fresh_cache() @skipIfTorchDynamo("not allowed to trace mark_unbacked") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_reshape1(self): @@ -3242,8 +3267,8 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None - view: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) - view_1: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None + view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) + view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None return (mul_4, mul_7)""", # noqa: B950 @@ -3311,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None - mul_18: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None - return (mul_18,)""", # noqa: B950 + mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None + return (mul_21,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) @@ -3334,7 +3359,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", ) with ctx(): # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. - # but not anymore since we use definitely_contiguous . + # but not anymore since we use contiguous_or_false . # We need a way to mark strides unbacked to avoid the recompilation here. x = torch.randn(10, 10) torch._dynamo.decorators.mark_unbacked(x, 0) @@ -3417,6 +3442,93 @@ def func(x, y): # throws a data dependent error. compiled_func(x, torch.tensor([5, 20])) + @skipIfTorchDynamo() + def test_unbind_not_dynamic(self): + cnt = CompileCounter() + + @torch.compile(fullgraph=True, dynamic=True, backend=cnt) + def func(y): + return y.unbind(dim=2), y * 10 + + func(torch.ones(5, 6, 7, 8)) + self.assertEqual(cnt.frame_count, 1) + # it can be dynamic in all dimentions except dim=2 + func(torch.ones(4, 9, 7, 10)) + self.assertEqual(cnt.frame_count, 1) + + func(torch.ones(5, 6, 8, 8)) + func(torch.ones(5, 6, 9, 8)) + self.assertEqual(cnt.frame_count, 3) + + @skipIfTorchDynamo("not allowed to trace mark_unbacked") + @fresh_cache() + def test_unbacked_contiguous(self): + cnt = CompileCounterWithBackend("inductor") + + def func(x): + contig = x.contiguous() + return (contig + 1) * 100 + + compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) + + x = torch.randn(10, 10) + # make x not contiguous. + x = x.t_() + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(x, 1) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_func(x) + self.assertEqual(compiled_func(x), func(x)) + y = torch.rand(20, 20).t() + self.assertEqual(compiled_func(y), func(y)) + self.assertEqual(cnt.frame_count, 1) + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + self.assertExpectedInline( + output, + """\ + ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None + add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None + mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None + return (mul_6,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + # recompilation will happen due to stride specialization. + y = torch.rand(20, 20) + torch._dynamo.decorators.mark_unbacked(y, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + self.assertEqual(compiled_func(y), func(y)) + self.assertEqual(cnt.frame_count, 2) + + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + + # No clone this time since input is contiguous. + self.assertExpectedInline( + output, + """\ + ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None + mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None + return (mul_5,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index 7f210bf79a29f8..02bf6d776568c3 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -679,7 +679,7 @@ def _do_test( expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] if not batch_first: expected_grads[-1] = expected_grads[-1].transpose(0, 1) - self.assertEqual(actual_res, expected_res) + self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol) [ self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads) @@ -776,7 +776,7 @@ def _do_test_rnn_packed_sequence( expected_grads.append(out_grads) expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] - self.assertEqual(actual_res, expected_res) + self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol) [ self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads) @@ -807,11 +807,7 @@ def batch_hidden(h): return h.unsqueeze(1).repeat(new_h_shape) module_cls = module_info.module_cls - atol, rtol = ( - (1e-4, 1e-5) - if module_cls == torch.nn.GRU and dtype == torch.float32 - else (None, None) - ) + atol, rtol = (1e-3, 1e-4) if dtype == torch.float32 else (None, None) module_inputs = module_info.module_inputs_func( module_info, device=device, diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index b4e74ec15864c1..017e41c114ee62 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -64,6 +64,7 @@ skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, + skipIfWindows, TemporaryFileName, TEST_WITH_TORCHDYNAMO, TestCase, @@ -345,6 +346,42 @@ def test_fake_mode_error(self): with FakeTensorMode(): y = x[0] + def test_no_tag_func(self): + import functools + + from torch.nn.attention.flex_attention import _identity, flex_attention + + def create_attention(score_mod, block_mask, enable_gqa=False): + return functools.partial( + flex_attention, + score_mod=score_mod, + block_mask=block_mask, + enable_gqa=enable_gqa, + ) + + input_shape = (4, 16, 128, 64) + q = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + k = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + v = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + sdpa_partial = create_attention(_identity, None) + with FakeTensorMode(allow_non_fake_inputs=True): + sdpa_partial(q, k, v, return_lse=False) + @unittest.skipIf( TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" ) @@ -982,6 +1019,26 @@ def test_fast_div(self): y = fast_div(mode, x, 2) self.assertEqual(y.dtype, torch.float32) + def test_nanmean_out(self): + # Regression test to ensure we don't error out. + with torch._subclasses.fake_tensor.FakeTensorMode() as mode: + x = torch.randn(10) + out = torch.empty(()) + torch.nanmean(x, out=out) + + self.assertEqual(out.dtype, x.dtype) + + def test_unbind_copy_out(self): + # Regression test to ensure we don't error out. + with torch._subclasses.fake_tensor.FakeTensorMode() as mode: + eye = torch.eye(3) + out = (torch.zeros(3), torch.zeros(3), torch.zeros(3)) + torch.unbind_copy(eye, out=out) + + self.assertEqual(out[0].dtype, eye.dtype) + self.assertEqual(out[1].dtype, eye.dtype) + self.assertEqual(out[2].dtype, eye.dtype) + instantiate_parametrized_tests(FakeTensorTest) @@ -1489,6 +1546,20 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertEqual(mode.count, 0) + # PropagateRealTensors installs weakrefs + @expectedFailurePropagateRealTensors + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_module_to(self): + def _check_device(sd, device_type): + for v in sd.values(): + self.assertEqual(v.device.type, device_type) + + with FakeTensorMode(): + m = torch.nn.Linear(2, 2) + _check_device(m.state_dict(), "cpu") + m.to("cuda") + _check_device(m.state_dict(), "cuda") + make_propagate_real_tensors_cls(FakeTensorOperatorInvariants) @@ -2212,6 +2283,9 @@ def test_cache_aten_index(self): lambda: torch.ops.aten.index(x, [None, idx_tensor1]), ) + @skipIfWindows( + msg="weird bug - cache may not be cleared after https://github.com/pytorch/pytorch/pull/154283" + ) @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching") def test_invoke_subgraph(self): """ diff --git a/test/test_foreach.py b/test/test_foreach.py index 760eeb9c86e549..a5ca220dcb5253 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -16,6 +16,7 @@ from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, + largeTensorTest, onlyCUDA, OpDTypes, ops, @@ -78,8 +79,13 @@ def __init__(self, func): def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): actual = None zero_size = kwargs.pop("zero_size", False) + + # Skip profiler check for CUDA 12.6, 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681. TODO: ADD IT BACK!!! + skip_profiler_check = _get_torch_cuda_version() in [(12, 6), (12, 8)] if ( is_cuda + and not skip_profiler_check and torch.autograd.kineto_available() and torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities() @@ -90,6 +96,7 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): torch.cuda.synchronize() keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) + assert mta_called == (expect_fastpath and (not zero_size)), ( f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" ) @@ -358,8 +365,6 @@ def clone(arg): @ops(foreach_pointwise_op_db) @parametrize("is_fastpath", (True, False)) - # TODO: Remove skip CUDA 12.6 once resolved: https://github.com/pytorch/pytorch/issues/148681 - @unittest.skipIf(_get_torch_cuda_version() >= (12, 6), "Failure on CUDA 12.6") def test_pointwise_op_with_tensor_of_scalarlist_overload( self, device, dtype, op, is_fastpath ): @@ -1340,6 +1345,7 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_ foreach_copy_ = ForeachFuncWrapper(op.inplace_variant) + for sample in op.sample_inputs( device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True ): @@ -1358,6 +1364,17 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): for t, ref_t in zip(out, ref_out): self.assertTrue(torch.equal(t, ref_t)) + @onlyCUDA + @largeTensorTest("40GB", device="cuda") + def test_foreach_copy_with_multi_dtypes_large_input(self): + # see https://github.com/pytorch/pytorch/issues/156261 + self_tensor = torch.empty(2**31 + 1, device="cuda", dtype=torch.float32) + src_tensor = torch.ones(2**31 + 1, device="cuda", dtype=torch.bfloat16) + + torch._foreach_copy_([self_tensor], [src_tensor]) + ref_out = torch.empty_like(self_tensor).copy_(src_tensor) + self.assertEqual(self_tensor, ref_out) + @requires_cuda @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) def test_foreach_copy_with_different_device_inputs(self, device, dtype, op): diff --git a/test/test_hub.py b/test/test_hub.py index 1447b3dc4a761f..2add5926d2c409 100644 --- a/test/test_hub.py +++ b/test/test_hub.py @@ -8,7 +8,12 @@ import torch import torch.hub as hub -from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase +from torch.testing._internal.common_utils import ( + IS_SANDCASTLE, + retry, + run_tests, + TestCase, +) def sum_of_state_dict(state_dict): @@ -307,3 +312,7 @@ def test_trust_repo_legacy(self): torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check") self._assert_trusted_list_is_empty() + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_indexing.py b/test/test_indexing.py index 1a07e64717b728..987b3caa810875 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -250,7 +250,10 @@ def validate_setting(x): reference = consec((10,)) strided = torch.tensor((), dtype=dtype, device=device) strided.set_( - reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2] + reference.untyped_storage(), + storage_offset=0, + size=torch.Size([4]), + stride=[2], ) self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device)) @@ -274,7 +277,10 @@ def validate_setting(x): # stride is [4, 8] strided = torch.tensor((), dtype=dtype, device=device) strided.set_( - reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4] + reference.untyped_storage(), + storage_offset=4, + size=torch.Size([2]), + stride=[4], ) self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device)) self.assertEqual( @@ -309,15 +315,15 @@ def validate_setting(x): self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) self.assertEqual( - reference[[ri([0, 0]), ri([0, 1])]], + reference[(ri([0, 0]), ri([0, 1]))], torch.tensor([1, 2], dtype=dtype, device=device), ) self.assertEqual( - reference[[ri([0, 1, 1, 0, 2]), ri([1])]], + reference[(ri([0, 1, 1, 0, 2]), ri([1]))], torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device), ) self.assertEqual( - reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))], torch.tensor([1, 2, 3, 3], dtype=dtype, device=device), ) @@ -387,15 +393,15 @@ def validate_setting(x): reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device) ) self.assertEqual( - reference[[ri([0, 0]), ri([0, 1])]], + reference[(ri([0, 0]), ri([0, 1]))], torch.tensor([0, 4], dtype=dtype, device=device), ) self.assertEqual( - reference[[ri([0, 1, 1, 0, 3]), ri([1])]], + reference[(ri([0, 1, 1, 0, 3]), ri([1]))], torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device), ) self.assertEqual( - reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + reference[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))], torch.tensor([0, 4, 1, 1], dtype=dtype, device=device), ) @@ -446,7 +452,9 @@ def validate_setting(x): reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]) + strided.set_( + reference.untyped_storage(), 1, size=torch.Size([2, 4]), stride=[8, 2] + ) self.assertEqual( strided[ri([0, 1]), ri([0])], @@ -463,15 +471,15 @@ def validate_setting(x): strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device) ) self.assertEqual( - strided[[ri([0, 0]), ri([0, 3])]], + strided[(ri([0, 0]), ri([0, 3]))], torch.tensor([1, 7], dtype=dtype, device=device), ) self.assertEqual( - strided[[ri([1]), ri([0, 1, 1, 0, 3])]], + strided[(ri([1]), ri([0, 1, 1, 0, 3]))], torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device), ) self.assertEqual( - strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + strided[(ri([0, 0, 1, 1]), ri([0, 1, 0, 0]))], torch.tensor([1, 3, 9, 9], dtype=dtype, device=device), ) @@ -502,7 +510,9 @@ def validate_setting(x): reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) + strided.set_( + reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1] + ) self.assertEqual( strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device) ) @@ -513,7 +523,9 @@ def validate_setting(x): reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) + strided.set_( + reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1] + ) self.assertEqual( strided[ri([0, 1]), ri([1, 0])], torch.tensor([11, 17], dtype=dtype, device=device), @@ -528,7 +540,9 @@ def validate_setting(x): reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8) strided = torch.tensor((), dtype=dtype, device=device) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) + strided.set_( + reference.untyped_storage(), 10, size=torch.Size([2, 2]), stride=[7, 1] + ) rows = ri([[0], [1]]) columns = ri([[0, 1], [0, 1]]) @@ -642,19 +656,19 @@ def get_set_tensor(indexed, indexer): indices_to_test = [ # grab the second, fourth columns - [slice(None), [1, 3]], + (slice(None), [1, 3]), # first, third rows, - [[0, 2], slice(None)], + ([0, 2], slice(None)), # weird shape - [slice(None), [[0, 1], [2, 3]]], + (slice(None), [[0, 1], [2, 3]]), # negatives - [[-1], [0]], - [[0, 2], [-1]], - [slice(None), [-1]], + ([-1], [0]), + ([0, 2], [-1]), + (slice(None), [-1]), ] # only test dupes on gets - get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] + get_indices_to_test = indices_to_test + [(slice(None), [0, 1, 1, 2, 2])] for indexer in get_indices_to_test: assert_get_eq(reference, indexer) @@ -668,46 +682,46 @@ def get_set_tensor(indexed, indexer): reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5) indices_to_test = [ - [slice(None), slice(None), [0, 3, 4]], - [slice(None), [2, 4, 5, 7], slice(None)], - [[2, 3], slice(None), slice(None)], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0], [1, 2, 4]], - [slice(None), [0, 1, 3], [4]], - [slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [[0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4], slice(None)], - [[0, 1, 3], [4], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 3]], slice(None)], - [[[0, 1], [2, 3]], [[0]], slice(None)], - [[[2, 1]], [[0, 3], [4, 4]], slice(None)], - [[[2]], [[0, 3], [4, 1]], slice(None)], + (slice(None), slice(None), (0, 3, 4)), + (slice(None), (2, 4, 5, 7), slice(None)), + ((2, 3), slice(None), slice(None)), + (slice(None), (0, 2, 3), (1, 3, 4)), + (slice(None), (0,), (1, 2, 4)), + (slice(None), (0, 1, 3), (4,)), + (slice(None), ((0, 1), (1, 0)), ((2, 3),)), + (slice(None), ((0, 1), (2, 3)), ((0,),)), + (slice(None), ((5, 6),), ((0, 3), (4, 4))), + ((0, 2, 3), (1, 3, 4), slice(None)), + ((0,), (1, 2, 4), slice(None)), + ((0, 1, 3), (4,), slice(None)), + (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)), + (((0, 1), (1, 0)), ((2, 3),), slice(None)), + (((0, 1), (2, 3)), ((0,),), slice(None)), + (((2, 1),), ((0, 3), (4, 4)), slice(None)), + (((2,),), ((0, 3), (4, 1)), slice(None)), # non-contiguous indexing subspace - [[0, 2, 3], slice(None), [1, 3, 4]], + ((0, 2, 3), slice(None), (1, 3, 4)), # [...] # less dim, ellipsis - [[0, 2]], - [[0, 2], slice(None)], - [[0, 2], Ellipsis], - [[0, 2], slice(None), Ellipsis], - [[0, 2], Ellipsis, slice(None)], - [[0, 2], [1, 3]], - [[0, 2], [1, 3], Ellipsis], - [Ellipsis, [1, 3], [2, 3]], - [Ellipsis, [2, 3, 4]], - [Ellipsis, slice(None), [2, 3, 4]], - [slice(None), Ellipsis, [2, 3, 4]], + ((0, 2),), + ((0, 2), slice(None)), + ((0, 2), Ellipsis), + ((0, 2), slice(None), Ellipsis), + ((0, 2), Ellipsis, slice(None)), + ((0, 2), (1, 3)), + ((0, 2), (1, 3), Ellipsis), + (Ellipsis, (1, 3), (2, 3)), + (Ellipsis, (2, 3, 4)), + (Ellipsis, slice(None), (2, 3, 4)), + (slice(None), Ellipsis, (2, 3, 4)), # ellipsis counts for nothing - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, slice(None), [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), [0, 3, 4], Ellipsis], - [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], + (Ellipsis, slice(None), slice(None), (0, 3, 4)), + (slice(None), Ellipsis, slice(None), (0, 3, 4)), + (slice(None), slice(None), Ellipsis, (0, 3, 4)), + (slice(None), slice(None), (0, 3, 4), Ellipsis), + (Ellipsis, ((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None)), + (((0, 1), (1, 0)), ((2, 1), (3, 5)), Ellipsis, slice(None)), + (((0, 1), (1, 0)), ((2, 1), (3, 5)), slice(None), Ellipsis), ] for indexer in indices_to_test: @@ -720,65 +734,65 @@ def get_set_tensor(indexed, indexer): reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6) indices_to_test = [ - [slice(None), slice(None), slice(None), [0, 3, 4]], - [slice(None), slice(None), [2, 4, 5, 7], slice(None)], - [slice(None), [2, 3], slice(None), slice(None)], - [[1, 2], slice(None), slice(None), slice(None)], - [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), slice(None), [0], [1, 2, 4]], - [slice(None), slice(None), [0, 1, 3], [4]], - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], - [slice(None), [0], [1, 2, 4], slice(None)], - [slice(None), [0, 1, 3], [4], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], - [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], - [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], - [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], - [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], - [[0], [1, 2, 4], slice(None), slice(None)], - [[0, 1, 2], [4], slice(None), slice(None)], - [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], - [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], - [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 4], [1, 3, 4], [4]], - [slice(None), [0, 1, 3], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 5], [3], [4]], - [slice(None), [0], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1]], - [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], - [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], - [[2, 0, 1], [1, 2, 3], [4], slice(None)], - [[0, 1, 2], [4], [1, 3, 4], slice(None)], - [[0], [0, 2, 3], [1, 3, 4], slice(None)], - [[0, 2, 1], [3], [4], slice(None)], - [[0], [4], [1, 3, 4], slice(None)], - [[1], [0, 2, 3], [1], slice(None)], - [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], + (slice(None), slice(None), slice(None), (0, 3, 4)), + (slice(None), slice(None), (2, 4, 5, 7), slice(None)), + (slice(None), (2, 3), slice(None), slice(None)), + ((1, 2), slice(None), slice(None), slice(None)), + (slice(None), slice(None), (0, 2, 3), (1, 3, 4)), + (slice(None), slice(None), (0,), (1, 2, 4)), + (slice(None), slice(None), (0, 1, 3), (4,)), + (slice(None), slice(None), ((0, 1), (1, 0)), ((2, 3),)), + (slice(None), slice(None), ((0, 1), (2, 3)), ((0,),)), + (slice(None), slice(None), ((5, 6),), ((0, 3), (4, 4))), + (slice(None), (0, 2, 3), (1, 3, 4), slice(None)), + (slice(None), (0,), (1, 2, 4), slice(None)), + (slice(None), (0, 1, 3), (4,), slice(None)), + (slice(None), ((0, 1), (3, 4)), ((2, 3), (0, 1)), slice(None)), + (slice(None), ((0, 1), (3, 4)), ((2, 3),), slice(None)), + (slice(None), ((0, 1), (3, 2)), ((0,),), slice(None)), + (slice(None), ((2, 1),), ((0, 3), (6, 4)), slice(None)), + (slice(None), ((2,),), ((0, 3), (4, 2)), slice(None)), + ((0, 1, 2), (1, 3, 4), slice(None), slice(None)), + ((0,), (1, 2, 4), slice(None), slice(None)), + ((0, 1, 2), (4,), slice(None), slice(None)), + (((0, 1), (0, 2)), ((2, 4), (1, 5)), slice(None), slice(None)), + (((0, 1), (1, 2)), ((2, 0),), slice(None), slice(None)), + (((2, 2),), ((0, 3), (4, 5)), slice(None), slice(None)), + (((2,),), ((0, 3), (4, 5)), slice(None), slice(None)), + (slice(None), (3, 4, 6), (0, 2, 3), (1, 3, 4)), + (slice(None), (2, 3, 4), (1, 3, 4), (4,)), + (slice(None), (0, 1, 3), (4,), (1, 3, 4)), + (slice(None), (6,), (0, 2, 3), (1, 3, 4)), + (slice(None), (2, 3, 5), (3,), (4,)), + (slice(None), (0,), (4,), (1, 3, 4)), + (slice(None), (6,), (0, 2, 3), (1,)), + (slice(None), ((0, 3), (3, 6)), ((0, 1), (1, 3)), ((5, 3), (1, 2))), + ((2, 2, 1), (0, 2, 3), (1, 3, 4), slice(None)), + ((2, 0, 1), (1, 2, 3), (4,), slice(None)), + ((0, 1, 2), (4,), (1, 3, 4), slice(None)), + ((0,), (0, 2, 3), (1, 3, 4), slice(None)), + ((0, 2, 1), (3,), (4,), slice(None)), + ((0,), (4,), (1, 3, 4), slice(None)), + ((1,), (0, 2, 3), (1,), slice(None)), + (((1, 2), (1, 2)), ((0, 1), (2, 3)), ((2, 3), (3, 5)), slice(None)), # less dim, ellipsis - [Ellipsis, [0, 3, 4]], - [Ellipsis, slice(None), [0, 3, 4]], - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], - [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4]], - [[0], [1, 2, 4], slice(None)], - [[0], [1, 2, 4], Ellipsis], - [[0], [1, 2, 4], Ellipsis, slice(None)], - [[1]], - [[0, 2, 1], [3], [4]], - [[0, 2, 1], [3], [4], slice(None)], - [[0, 2, 1], [3], [4], Ellipsis], - [Ellipsis, [0, 2, 1], [3], [4]], + (Ellipsis, (0, 3, 4)), + (Ellipsis, slice(None), (0, 3, 4)), + (Ellipsis, slice(None), slice(None), (0, 3, 4)), + (slice(None), Ellipsis, (0, 3, 4)), + (slice(None), slice(None), Ellipsis, (0, 3, 4)), + (slice(None), (0, 2, 3), (1, 3, 4)), + (slice(None), (0, 2, 3), (1, 3, 4), Ellipsis), + (Ellipsis, (0, 2, 3), (1, 3, 4), slice(None)), + ((0,), (1, 2, 4)), + ((0,), (1, 2, 4), slice(None)), + ((0,), (1, 2, 4), Ellipsis), + ((0,), (1, 2, 4), Ellipsis, slice(None)), + ((1,),), + ((0, 2, 1), (3,), (4,)), + ((0, 2, 1), (3,), (4,), slice(None)), + ((0, 2, 1), (3,), (4,), Ellipsis), + (Ellipsis, (0, 2, 1), (3,), (4,)), ] for indexer in indices_to_test: @@ -786,8 +800,8 @@ def get_set_tensor(indexed, indexer): assert_set_eq(reference, indexer, 1333) assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) indices_to_test += [ - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], - [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], + (slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]), + (slice(None), slice(None), [[2]], [[0, 3], [4, 4]]), ] for indexer in indices_to_test: assert_get_eq(reference, indexer) @@ -866,6 +880,21 @@ def test_bool_indices(self, device): ) self.assertEqual(len(w), 1) + def test_list_indices(self, device): + N = 1000 + t = torch.randn(N, device=device) + # Set window size + W = 10 + # Generate a list of lists, containing overlapping window indices + indices = [range(i, i + W) for i in range(0, N - W)] + + for i in [len(indices), 100, 32]: + windowed_data = t[indices[:i]] + self.assertEqual(windowed_data.shape, (i, W)) + + with self.assertRaisesRegex(IndexError, "too many indices"): + windowed_data = t[indices[:31]] + def test_bool_indices_accumulate(self, device): mask = torch.zeros(size=(10,), dtype=torch.bool, device=device) y = torch.ones(size=(10, 10), device=device) diff --git a/test/test_linalg.py b/test/test_linalg.py index 5ffc1a11a8a092..108a5f590079c2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -24,12 +24,12 @@ make_fullrank_matrices_with_distinct_singular_values, freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo, setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest, - runOnRocmArch, MI300_ARCH) + runOnRocmArch, MI300_ARCH, TEST_CUDA) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA, - onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm, skipCUDAIfRocmVersionLessThan, + onlyCUDA, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm, skipCUDAIfRocmVersionLessThan, dtypesIfMPS, largeTensorTest) from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( @@ -2818,6 +2818,9 @@ def test_invariance_error_spectral_decompositions(self, device, dtype): Q = torch.linalg.eigh(A).eigenvectors Q.sum().abs().backward() + # I don't know how much memory this test uses but on complex64 it needs at least 4GB + @largeTensorTest("4GB", device="cuda") + @serialTest(TEST_CUDA) @skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) @skipCPUIfNoLapack @@ -3713,7 +3716,6 @@ def test_matrix_rank_atol_rtol(self, device, dtype): @skipCUDAIfNoMagma @skipCPUIfNoLapack - @skipCUDAVersionIn([(11, 6), (11, 7)]) # https://github.com/pytorch/pytorch/issues/75391 @dtypes(*floating_and_complex_types()) def test_matrix_rank_empty(self, device, dtype): matrix_rank = torch.linalg.matrix_rank @@ -5915,6 +5917,56 @@ def test_rowwise_scaled_gemm_numerics_tunableop(self, device, dtype): delta = tuned_default_scaled_mm - ref_scaled_mm self.assertTrue(torch.all(delta == 0)) + @onlyCUDA + @skipCUDAIfNotRocm + @dtypes(torch.float) + def test_call_count_tunableop(self, device, dtype): + # Test that after tuning a GEMM in TunableOp, we only call the GEMM kernel once + # per PyTorch API invocation. + # We use the torch profiler to get the call counts on the kernels + + # Supported only for: MM, batch MM, and GEMM with bias (linear) + from torch.profiler import profile, ProfilerActivity + + with self._tunableop_ctx(): + # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_max_tuning_iterations(1) + + b = 2 + M = 10 + + # MM + A = torch.rand(M, M, device=device) + C = torch.mm(A, A) + + # Linear - GEMM BIAS + X = torch.rand(M, M, device='cuda') + bias = torch.rand(M, device='cuda') + Y = torch.nn.functional.linear(X, A, bias) + + # BMM + batch_A = torch.rand((b, M, M), device='cuda') + batch_C = torch.bmm(batch_A, batch_A) + + kernel_count = 0 + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + C = torch.mm(A, A) + Y = torch.nn.functional.linear(X, A, bias) + batch_C = torch.bmm(batch_A, batch_A) + + # Check that after tuning, there was only one kernel + # launched per PyTorch API. The kernels have string + # that always starts with `Cijk*` + mm_key = 'Cijk' + events = prof.key_averages() + for evt in events: + if mm_key in evt.key: + self.assertEqual(evt.count, 1) + kernel_count = kernel_count + 1 + + # There must be exactly three kernels only + self.assertEqual(kernel_count, 3) + @dtypes(torch.float, torch.complex64) def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0) @@ -6022,6 +6074,18 @@ def test_linalg_cross_with_and_without_dim(self, device, dtype): self.assertEqual(res1, res2) self.assertEqual(res1, res3) + def test_cross_error(self, device): + x = torch.randn(4, 3, device=device) + y = torch.randn(4, 3, device=device) + with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"): + torch.cross(x, y, out=x) + with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"): + torch.cross(y, x, out=x) + with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"): + torch.linalg.cross(x, y, out=x) + with self.assertRaisesRegex(RuntimeError, "input tensor and the written-to tensor refer to a single memory location"): + torch.linalg.cross(y, x, out=x) + def test_renorm(self, device): m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path res1 = torch.tensor((), device=device) @@ -9101,7 +9165,8 @@ def dims_full_for_fn(): r1 = fntorch(t0_full, t1, t2) self.assertEqual(r0, r1) - @tf32_on_and_off(0.001) + # ROCm 6.4 passes with tf32=on, but 6.4.1 needed tolerance reduced slightly + @tf32_on_and_off(0.002 if torch.version.hip else 0.001) @bf32_on_and_off(0.001) def test_broadcast_batched_matmul(self, device): n_dim = random.randint(1, 8) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 593c78f74d41cc..c7c75cdb7927bb 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -307,15 +307,15 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) self.assertEqual(bgrad, b.grad) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @xfailIfSM120OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): device = "cuda" dtype = torch.bfloat16 - m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 + m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] else: @@ -329,8 +329,9 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): a.requires_grad_(True) b.requires_grad_(True) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) - out = torch._grouped_mm(a, b.t(), offs=offs, - out_dtype=torch.bfloat16) + + f = torch._grouped_mm + out = f(a, b.t(), offs=offs, out_dtype=torch.bfloat16) gO = torch.rand_like(out) out.backward(gO) offs_cpu = offs.cpu() @@ -345,8 +346,8 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @xfailIfSM120OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) @@ -354,7 +355,7 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): device = "cuda" dtype = torch.bfloat16 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 + m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] else: @@ -374,13 +375,17 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + a.grad = None b.grad = None offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] - out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, - out_dtype=torch.bfloat16) + + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16) gO = torch.rand_like(out) if not check_zero_size: out.backward(gO) @@ -398,8 +403,8 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @xfailIfSM120OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) @@ -407,7 +412,7 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major): device = "cuda" dtype = torch.bfloat16 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 + m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] else: @@ -426,14 +431,15 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) - out = torch._grouped_mm(a, b.transpose(-2, -1), out_dtype=torch.bfloat16) + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), out_dtype=torch.bfloat16) gO = torch.rand_like(out) out.backward(gO) self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - @xfailIfSM100OrLater - @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @xfailIfSM120OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) @@ -441,7 +447,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): device = "cuda" dtype = torch.bfloat16 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 + m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] else: @@ -460,11 +466,15 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] - out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, - out_dtype=torch.bfloat16) + + f = torch._grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16) gO = torch.rand_like(out) if not check_zero_size: out.backward(gO) @@ -480,6 +490,117 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): start = offs_cpu[i] self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): + torch._dynamo.reset() + + device = "cuda" + dtype_AB = torch.bfloat16 + dtype_offset = torch.int32 + + align = 16 // dtype_AB.itemsize + + f_ref = torch._grouped_mm + f = torch.compile( + f_ref, + options={ + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, + ) + + if op == "2d/2d": + m, n = 3, 7 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + if not a_row_major and not b_row_major: + offs = torch.tensor([1, 3, 4, 6, 7], device=device, dtype=dtype_offset) + else: + offs = torch.tensor([8, 16, 32, 37], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + k = offs[-1] + k_align = (k + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "2d/3d": + n, k = 7, 13 + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + offs = torch.tensor([0, 1, 3, 3, 5], device=device, dtype=dtype_offset) + else: + offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + m = offs[-1] + m_align = (m + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :n, :] + elif op == "3d/2d": + m, k = 3, 13 + m_align = (m + align - 1) // align * align + k_align = (k + align - 1) // align * align + offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + n = offs[-1] + n_align = (n + align - 1) // align * align + + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :m, :] + if b_row_major: + B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "3d/3d": + offs = None + ngroups = 5 + m, n, k = 3, 7, 13 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k] + else: + B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose( + -2, -1 + )[:, :n, :] + else: + raise AssertionError(f"Invaild op: {op}") + + C_ref = f_ref(A, B.transpose(-2, -1), offs=offs) + C = f(A, B.transpose(-2, -1), offs=offs) + torch.testing.assert_close(C, C_ref) + @onlyCUDA @skipIfRocm @@ -1070,7 +1191,6 @@ def test_float8_scale_fast_accum(self, device) -> None: self.assertEqual(out_fp8, out_fp8_s) @onlyCUDA - @xfailIfSM120OrLater @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @parametrize("use_fast_accum", [True, False]) @@ -1177,7 +1297,6 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) - @xfailIfSM120OrLater @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @parametrize("base_dtype", [torch.bfloat16]) @@ -1239,7 +1358,6 @@ def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: self.assertEqual(out_dtype, out_fp8.dtype) self.assertEqual(out_fp32, out_fp8.to(torch.float)) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet") @@ -1268,15 +1386,31 @@ def test_honor_sm_carveout(self) -> None: torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16) prof.export_chrome_trace(f.name) - no_carveout, carveout_0, carveout_66, no_carveout_again = [ - math.prod(evt.get("args", {}).get("grid", [])) - for evt in json.load(open(f.name))["traceEvents"] - if evt.get("cat", "") == "kernel" - ] - - self.assertEqual(no_carveout, no_carveout_again) - self.assertNotEqual(no_carveout, carveout_66) - self.assertNotEqual(carveout_66, carveout_0) + if torch.version.hip: + events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"] + # events were returned out of order; need to be sorted on "ts" timestamp + events = sorted(events, key=lambda x: x['ts']) + # ROCm carveout is invisible except for kernels running slower on fewer CUs + no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events] + self.assertTrue(no_carveout < carveout_66) + self.assertTrue(carveout_0 < carveout_66) + self.assertTrue(no_carveout_again < carveout_66) + # ROCm carveout will create new streams when enabled, and go back to the original stream when disabled + no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events] + self.assertTrue(no_carveout == no_carveout_again) + self.assertTrue(no_carveout != carveout_0) + self.assertTrue(no_carveout != carveout_66) + self.assertTrue(carveout_0 != carveout_66) + else: + no_carveout, carveout_0, carveout_66, no_carveout_again = [ + math.prod(evt.get("args", {}).get("grid", [])) + for evt in json.load(open(f.name))["traceEvents"] + if evt.get("cat", "") == "kernel" + ] + + self.assertEqual(no_carveout, no_carveout_again) + self.assertNotEqual(no_carveout, carveout_66) + self.assertNotEqual(carveout_66, carveout_0) def test_pack_uint4(self): """ @@ -1616,24 +1750,27 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist): out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1), out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum) - self.assertEqual(out, out_ref) + self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4) + + # Testing only _scaled_grouped_mm() with multiple shapes, as + # _scaled_mm() already has more combinations of parameters than + # _scaled_grouped_mm(), for supporing more than one inputs layout + # combinations. @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): device = "cuda" - m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 + m, n, k, n_groups = 16, 32, 64, 4 a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] - scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4 - scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4 + scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) + scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) f = torch._scaled_grouped_mm - f = torch.compile(f) if use_torch_compile else f out = f(a, b.t(), scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) offs_cpu = offs.cpu() @@ -1653,24 +1790,25 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile) @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) for check_zero_size in (True, False): + if check_zero_size and n_groups <= 1: + continue + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] - scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32) - scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) + scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32) + scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) f = torch._scaled_grouped_mm - f = torch.compile(f, dynamic=False) if use_torch_compile else f out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) @@ -1682,7 +1820,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile) ascalelist.append(scale_a[start:offs_cpu[i]]) outlist.append(out[start:offs_cpu[i]]) start = offs_cpu[i] - self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) + self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @@ -1690,20 +1828,18 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile) @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) - scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) - scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) + scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) + scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) f = torch._scaled_grouped_mm - f = torch.compile(f) if use_torch_compile else f out = f(a, b.transpose(-2, -1), scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) @@ -1715,24 +1851,25 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile) @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - @parametrize("use_torch_compile", [False, True]) - def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile): + def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 s_int = int(strided) - m, n, k, n_groups = 16, 32, 16, 4 a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) - scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) - scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32) + scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) + scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32) for check_zero_size in (True, False): + if check_zero_size and n_groups <= 1: + continue + offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) if check_zero_size: offs[0] = offs[1] f = torch._scaled_grouped_mm - f = torch.compile(f) if use_torch_compile else f out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) offs_cpu = offs.cpu() @@ -1743,7 +1880,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile) bscalelist.append(scale_b[start:offs_cpu[i]]) outlist.append(out[:, start:offs_cpu[i]]) start = offs_cpu[i] - self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) + self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index ba592874b17075..0f73a71c182a49 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -22,11 +22,12 @@ from torch.utils import mkldnn as mkldnn_utils from torch.testing._internal.common_utils import TestCase, \ run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \ - skipIfTorchDynamo, xfailIfTorchDynamo + skipIfTorchDynamo, xfailIfTorchDynamo, recover_orig_fp32_precision from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, dtypes, ) +from torch.testing._internal.common_mkldnn import bf32_on_and_off # batched grad doesn't support mkldnn gradcheck = functools.partial(gradcheck, check_batched_grad=False) @@ -264,7 +265,10 @@ def _test_conv_base(self, dim): loss1.backward() if not train or (train and dim != 1): y_mkldnn = mkldnn_conv(x2).to_dense() - self.assertEqual(y_aten, y_mkldnn) + if self.precision != 0: + self.assertEqual(y_aten, y_mkldnn, atol=self.precision, rtol=self.precision) + else: + self.assertEqual(y_aten, y_mkldnn) if not train: self._test_serialization(mkldnn_conv, (x.to_mkldnn(),)) self._test_tracing(mkldnn_conv, (x.to_mkldnn(),)) @@ -280,12 +284,15 @@ def _test_conv_base(self, dim): if bias: self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) + @bf32_on_and_off() def test_conv1d(self): self._test_conv_base(dim=1) + @bf32_on_and_off() def test_conv2d(self): self._test_conv_base(dim=2) + @bf32_on_and_off() def test_conv3d(self): self._test_conv_base(dim=3) @@ -400,6 +407,7 @@ def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) + @bf32_on_and_off() def test_conv_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) @@ -435,6 +443,7 @@ def test_conv_nhwc_lower_precision(self, dtype): self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) + @bf32_on_and_off() def test_conv_transpose_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) @@ -509,7 +518,11 @@ def _test_conv_transpose_base(self, dim): if train: y.sum().backward() - self.assertEqual(y, y_ref) + if self.precision != 0: + self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision) + else: + self.assertEqual(y, y_ref) + if train: self.assertEqual(x.grad, x_ref.grad) self.assertEqual(conv.weight.grad, @@ -519,12 +532,15 @@ def _test_conv_transpose_base(self, dim): if bias: self.assertEqual(conv.bias.grad, conv_ref.bias.grad) + @bf32_on_and_off() def test_conv_transpose1d(self): self._test_conv_transpose_base(dim=1) + @bf32_on_and_off() def test_conv_transpose2d(self): self._test_conv_transpose_base(dim=2) + @bf32_on_and_off() def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) @@ -1659,6 +1675,53 @@ def test_mkldnn_scaled_mm(self, device) -> None: self.assertEqual(out_emulated.float(), out.float(), atol=5e-2, rtol=5e-2) + @recover_orig_fp32_precision + def test_mlkdnn_get_set(self): + # get/set mkldnn ops + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): + self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): + self.assertEqual(torch.backends.mkldnn.fp32_precision, "none") + # get/set matmul + torch.backends.mkldnn.matmul.fp32_precision = "bf16" + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + torch.backends.mkldnn.matmul.fp32_precision = "none" + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") + # get/set conv + torch.backends.mkldnn.conv.fp32_precision = "bf16" + self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16") + torch.backends.mkldnn.conv.fp32_precision = "none" + self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none") + # get/set rnn + torch.backends.mkldnn.rnn.fp32_precision = "bf16" + self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16") + torch.backends.mkldnn.rnn.fp32_precision = "none" + self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none") + + @recover_orig_fp32_precision + def test_generic_precision(self): + with torch.backends.flags(fp32_precision="none"): + self.assertEqual(torch.backends.fp32_precision, "none") + with torch.backends.flags(fp32_precision="tf32"): + self.assertEqual(torch.backends.fp32_precision, "tf32") + + @recover_orig_fp32_precision + def test_default_use_parent(self): + torch.backends.mkldnn.matmul.fp32_precision = "none" + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): + with torch.backends.flags(fp32_precision="bf16"): + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + with torch.backends.flags(fp32_precision="tf32"): + # when parent is a not supported precision, use default + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") + + @recover_orig_fp32_precision + def test_invalid(self): + # use default if user set a not supported precision + torch.backends.mkldnn.matmul.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/test/test_mps.py b/test/test_mps.py index 0e8e83677244ff..f26250eaefd312 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -192,21 +192,21 @@ def test_matmul_autocast(self): f"Autocast & non-autocast tensors did not match, \ got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}") - # Regression test for https://github.com/pytorch/pytorch/issues/141774 - def test_scaled_dot_product_attention_autocast(self): - # TODO(hvaara): Parameterize the dtypes for cleaner code and better failure debugability - dtypes = [torch.float16] if MACOS_VERSION < 14.0 else [torch.bfloat16, torch.float16] + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_scaled_dot_product_attention_autocast(self, dtype): + # Regression test for https://github.com/pytorch/pytorch/issues/141774 + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") - for dtype in dtypes: - query = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") - key = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") - value = torch.rand(4, 1, 16, 8, dtype=dtype, device="mps") + query = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") + key = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") + value = torch.rand(4, 1, 16, 8, dtype=dtype, device="mps") - with torch.amp.autocast(device_type="mps"): - y_autocast = F.scaled_dot_product_attention(query, key, value) + with torch.amp.autocast(device_type="mps"): + y_autocast = F.scaled_dot_product_attention(query, key, value) - y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) - self.assertEqual(y.to(y_autocast.dtype), y_autocast) + y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) + self.assertEqual(y.to(y_autocast.dtype), y_autocast) def test_gradscaler_mps(self): # big model to force chunking/depth in the gradscaler dispatch @@ -1735,8 +1735,8 @@ def test_batch_norm_backward(self): # This used to crash, see https://github.com/pytorch/pytorch/issues/98602 outputs.sum().backward() - # Regression test for https://github.com/pytorch/pytorch/issues/133520 def test_batch_norm_slices(self): + # Regression test for https://github.com/pytorch/pytorch/issues/133520 bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu') bn_mps = nn.BatchNorm2d(100, affine=False, device='mps') @@ -1748,6 +1748,26 @@ def test_batch_norm_slices(self): self.assertEqual(res_cpu, res_mps) + def test_batch_norm_backward_weight_bias_gradients(self): + # See issue: https://github.com/pytorch/pytorch/issues/156555 + N, C, L = 4, 3, 5 + x = torch.randn(N, C, L) + y = torch.randn(N, C, L) + bn_cpu = nn.BatchNorm1d(C, affine=True).cpu().train() + bn_mps = nn.BatchNorm1d(C, affine=True).to('mps').train() + bn_mps.load_state_dict(bn_cpu.state_dict()) + + out_cpu = bn_cpu(x) + out_mps = bn_mps(x.to('mps')) + + loss_cpu = ((out_cpu - y) ** 2).mean() + loss_mps = ((out_mps - y.to('mps')) ** 2).mean() + loss_cpu.backward() + loss_mps.backward() + + self.assertEqual(bn_cpu.weight.grad, bn_mps.weight.grad, atol=1e-5, rtol=1e-5) + self.assertEqual(bn_cpu.bias.grad, bn_mps.bias.grad, atol=1e-5, rtol=1e-5) + def test_layer_norm_backward(self): inputs = torch.rand(4, 4, device="mps", requires_grad=True) x = torch.nn.LayerNorm(4).to("mps") @@ -2033,8 +2053,8 @@ def test_ifft(self): # Expecting the inverted to yield the original signal self.assertEqual(ifft_result, signal) - # Regression test for https://github.com/pytorch/pytorch/issues/135223 def test_fftfreq(self): + # Regression test for https://github.com/pytorch/pytorch/issues/135223 freq_cpu = torch.fft.fftfreq(10**4, device='cpu') freq_mps = torch.fft.fftfreq(10**4, device='mps') self.assertEqual(freq_cpu, freq_mps) @@ -5994,17 +6014,25 @@ def helper(shape): helper((2, 8, 4, 5)) - def test_log1p(self): - def helper(shape): - cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) - x = cpu_x.detach().clone().to('mps') + @parametrize("dtype", {torch.float, torch.half} if MACOS_VERSION < 14 else {torch.float, torch.half, torch.bfloat16}) + def test_log1p(self, dtype): + eps = torch.finfo(dtype).eps + # Small values + cpu_x = torch.arange(-10.0 * eps, 10.0 * eps, 1e-2 * eps, dtype=dtype, requires_grad=False) + x = cpu_x.detach().clone().to('mps') - log_result = torch.log1p(x) - log_result_cpu = torch.log1p(cpu_x) + log_result = torch.log1p(x) + log_result_cpu = torch.log1p(cpu_x) + self.assertEqual(log_result, log_result_cpu, atol=0, rtol=2e-7) - self.assertEqual(log_result, log_result_cpu) + # Fallback to log + cpu_x = torch.arange(-1.0, 2.0, 1e-4, dtype=dtype, requires_grad=False) + x = cpu_x.detach().clone().to('mps') - helper((2, 8, 4, 5)) + log_result = torch.log1p(x) + log_result_cpu = torch.log1p(cpu_x) + + self.assertEqual(log_result, log_result_cpu, atol=0, rtol=2e-7) def test_logaddexp(self): def helper(shape): @@ -7138,6 +7166,11 @@ def helper(shape, diag=0): helper((2, 8, 4, 5), diag=-1) helper((2, 8, 4, 5), diag=-2) helper((2, 8, 4, 5), diag=-3) + # Test inplace + x_mps = torch.arange(9.0, device='mps').reshape(3, 3).t().triu() + x_cpu = torch.arange(9.0, device='cpu').reshape(3, 3).t().triu() + self.assertEqual(x_cpu, x_mps) + self.assertEqual(x_cpu.stride(), x_mps.stride()) # Test inverse def test_inverse(self): @@ -7947,6 +7980,21 @@ def test_inplace_bitwise_not(self, dtype): x[::2].bitwise_not_() self.assertEqual(x_mps.cpu(), x_cpu) + +class TestLargeTensors(TestCaseMPS): + @serialTest() + def test_64bit_binops(self): + if torch.mps.recommended_max_memory() < 16_000_000_000: + raise unittest.SkipTest("Needs at least 16Gb of RAM") + a = torch.rand(1, 1024, 1024, dtype=torch.float16, device='mps') + b = torch.rand(5000, 1, 1, dtype=torch.float16, device='mps') + rc = (a + b).sin() + slice_idx = -2 + rc_slice = rc[slice_idx:] + rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin() + self.assertEqual(rc_slice, rc_slice_cpu) + + class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad) @@ -8393,6 +8441,14 @@ def test_topk(self): with self.subTest(shape=shape, largest_val=largest_val): self._test_topk(shape, largest_val) + def test_topk_gt_4d(self): + a = torch.ones(5, 4, 3, 2, 1, dtype=torch.float).to('mps') + try: + t_mps = torch.ops.aten.topk(a, k=5, dim=0) + except Exception as e: + e_string = str(e) + self.assertEqual(e_string, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890") + class TestNNMPS(NNTestCase): def _create_basic_net(self): @@ -9280,7 +9336,7 @@ def test_sdpa_enable_gqa(self, dtype, is_causal): ) self._compare_tensors(y.cpu(), y_ref) - @serialTest + @serialTest() def test_sdpa_fp32_no_memory_leak(self): def get_mps_memory_usage(): return (torch.mps.current_allocated_memory() / (1024 * 1024), @@ -11991,10 +12047,11 @@ class TestConsistency(TestCaseMPS): } FP32_LOW_PRECISION_LIST = { - # conv2d and conv_transpose2d results have a very small + # conv2d, conv_transpose2d and conv_transpose3d results have a very small # difference compared to CPU/CUDA, so we use lower precision on FP32 'nn.functional.conv2d', 'nn.functional.conv_transpose2d', + 'nn.functional.conv_transpose3d', 'matmul', '__rmatmul__', 'linalg.multi_dot', 'addbmm', @@ -12185,6 +12242,8 @@ def req_grad(t): # which leads to larger errors if op.name == "_unsafe_masked_index" and dtype == torch.float16: atol, rtol = 3e-3, 3e-3 + if op.name == "logcumsumexp": + atol, rtol = 4e-3, 1e-3 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) def test_fmax_mixed_dtypes(self, device): @@ -12491,6 +12550,58 @@ def test_metal_capture(self): f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}") + +class TestSparseMPS(TestCaseMPS): + def _get_basic_sparse_coo(self, device="mps"): + indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device=device) + values = torch.tensor([1, 2], dtype=torch.float32, device=device) + size = (2, 3) + return torch.sparse_coo_tensor(indices, values, size, device=device) + + def test_sparse_coo_tensor_with_dims(self): + indices = torch.zeros((2, 0), dtype=torch.int64, device="mps") + values = torch.tensor([], dtype=torch.float32, device="mps") + size = (2, 3) + t = torch.sparse_coo_tensor(indices, values, size, device="mps") + self.assertEqual(t.device.type, "mps") + self.assertEqual(t.layout, torch.sparse_coo) + + def test_sparse_coo_tensor_with_dims_and_tensors(self): + indices = torch.tensor([[0, 1], [2, 0]], device="mps") + values = torch.tensor([1., 2.], device="mps") + size = (2, 3) + t = torch.sparse_coo_tensor(indices, values, size, device="mps") + self.assertEqual(t.device.type, "mps") + self.assertEqual(t.layout, torch.sparse_coo) + self.assertEqual(t._indices().cpu(), indices.cpu()) + self.assertEqual(t._values().cpu(), values.cpu()) + + def test_nnz(self): + t = self._get_basic_sparse_coo() + self.assertEqual(t._nnz(), 2) + + def test_sparse_dim(self): + t = self._get_basic_sparse_coo() + self.assertEqual(t.sparse_dim(), 2) + + def test_to_sparse(self): + t = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="mps") + x = t.to_sparse() + t_cpu = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="mps") + x_cpu = t.to_sparse() + self.assertEqual(x.cpu(), x_cpu) + + def test_resize(self): + indices = torch.tensor([[0, 1], [2, 0]]) + values = torch.tensor([3.0, 4.0]) + size = torch.Size([2, 3]) + sparse = torch.sparse_coo_tensor(indices, values, size, device="mps") + sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu") + sparse = sparse.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0) + sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0) + self.assertEqual(sparse, sparse_cpu) + + # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. # This requires mps to be properly registered in the device generic test framework which is not the # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342 @@ -12499,6 +12610,7 @@ def test_metal_capture(self): instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps") instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps") instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps") +instantiate_parametrized_tests(TestAutocastMPS) instantiate_parametrized_tests(TestLogical) instantiate_parametrized_tests(TestMPS) instantiate_parametrized_tests(TestSDPA) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 9f58f14143a570..85c3b4d2cb3cc5 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -98,6 +98,44 @@ def send_and_delete_tensors(queue, event, device, dtype, count, size=5): event.wait() +def send_tensor_with_untyped_storage(queue, event): + tensors = torch.ones(2, device="cuda").chunk(2, dim=0) + specs = [] + for tensor in tensors: + storage = tensor.untyped_storage() + ( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + specs.append( + { + "tensor_cls": type(tensor), + "tensor_size": tensor.shape, + "tensor_stride": tensor.stride(), + "tensor_offset": tensor.storage_offset(), + "dtype": tensor.dtype, + "requires_grad": tensor.requires_grad, + "storage_cls": type(storage), + "storage_device": storage_device, + "storage_handle": storage_handle, + "storage_size_bytes": storage_size_bytes, + "storage_offset_bytes": storage_offset_bytes, + "ref_counter_handle": ref_counter_handle, + "ref_counter_offset": ref_counter_offset, + "event_handle": event_handle, + "event_sync_required": event_sync_required, + } + ) + queue.put(specs) + event.wait() + + def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5): s = torch.full([size], 0, device=device, dtype=dtype) for i in range(count): @@ -630,6 +668,27 @@ def run(rank): ) self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.") + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") + def test_rebuild_cuda_tensor(self): + ctx = mp.get_context("spawn") + queue = ctx.Queue() + event = ctx.Event() + + proc = ctx.Process( + target=send_tensor_with_untyped_storage, + args=(queue, event), + ) + proc.start() + + specs = queue.get() + tensors = [] + for spec in specs: + tensors.append(mp.reductions.rebuild_cuda_tensor(**spec)) + self.assertEqual(tensors, [1, 1]) + + del tensors, spec + event.set() + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") def test_event(self): ctx = mp.get_context("spawn") diff --git a/test/test_nn.py b/test/test_nn.py index 353e0d6abd8040..b9a56698edc6b5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -30,13 +30,15 @@ from torch.nn import Buffer, Parameter from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types -from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ +from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, run_tests, TestCase, \ + skipIfNoLapack, skipIfRocm, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ download_file, get_function_arglist, load_tests, skipIfMPS, \ IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ skipIfTorchDynamo, gcIfJetson, set_default_dtype -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ + PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -5126,6 +5128,175 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): self.assertTrue(torch.equal(running_mean, bn.running_mean)) self.assertTrue(torch.equal(running_var, bn.running_var)) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D") + @parametrize_test("mode", ["train", "inference"], name_fn=lambda x: x) + @parametrize_test( + # test verifies cudnn/miopen batchnorm with the reference backend or memory format + # memory_format - one of ("NCHW", NHWC") + # ref_backend - one of ("cpu", "native", "NCHW", "NHWC") + # "cpu" - cpu backend with the same memory_format will be used as reference + # "native" - native backend (`with torch.backends.cudnn.flags(enabled=False)`) + # with the same memory_format will be used + # "NCHW" or "NHWC" - the same backend will be used but another memory format + # mixed - True or False. Mixed batchnorm mode where inputs are 16-bit and batchnorm is fp32 + # + "memory_format,ref_backend,mixed,dtype", + [ + ("NCHW", "cpu", False, torch.float), + ("NCHW", "cpu", True, torch.half), + ("NCHW", "cpu", True, torch.bfloat16), + + ("NCHW", "native", False, torch.float), + ("NCHW", "native", True, torch.half), + ("NCHW", "native", True, torch.bfloat16), + ], + name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}" + ) + def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype): + if torch.version.cuda: + if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16", + "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"): + self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue " + "https://github.com/pytorch/pytorch/issues/156513") + if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16": + self.skipTest("Batchnorm 3D NHWC train failed on CUDA") + + if torch.version.hip: + if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16", + "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \ + and _get_torch_rocm_version() < (6, 4): + # NCHW bfloat16 path uses native kernels for rocm<=6.3 + # train failed on rocm<=6.3 due to native tolerance issue + # https://github.com/pytorch/pytorch/issues/156513 + self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3") + + if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16", + "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16") \ + and _get_torch_rocm_version() >= (6, 4): + # https://github.com/pytorch/pytorch/issues/156513 + self.skipTest("bfloat16 NCHW train failed due to native tolerance issue") + + if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \ + and _get_torch_rocm_version() < (7, 0): + self.skipTest("3D float16 NCHW train failed on ROCm<7.0") + + if dims == 3 and memory_format in ("NHWC", "NCHW"): + memory_format = memory_format + "3D" + + def _create_tensor(size, memory_format, dtype, device): + t = torch.empty(size=size, memory_format=memory_format, dtype=dtype, device=device) + t = t.random_(1, 10) + return t + + def _get_ref_device(backend: str , device: str): + # If 'backend' specifies the memory format, return 'device' arg, otherwise return a device matches the backend + if backend in ("NHWC", "NHWC3D", "NCHW", "NCHW3D"): + return device + if backend == "native": + return "cuda" + if backend == "cpu": + return "cpu" + else: + raise ValueError("Unknown backend") + + def _get_backend_memory_format(backend: str, memory_format: torch.memory_format) -> torch.memory_format: + # If 'backend' specifies the memory format, return it, otherwise look at 'memory_format' arg + if backend == "NHWC": + return torch.channels_last + if backend == "NHWC3D": + return torch.channels_last_3d + if backend in ("NCHW", "NCHW3D"): + return torch.contiguous_format + if memory_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d): + return memory_format + raise ValueError("Unable to detect memory format for backend={backend} and memory_format={memory_format}") + + def _get_memory_format(t: torch.Tensor) -> torch.memory_format: + if t.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + if t.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + if t.is_contiguous(memory_format=torch.channels_last_3d): + return torch.channels_last_3d + return ValueError("Unsupported memory_format") + + def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format: + if memory_format_name == "NHWC": + return torch.channels_last + elif memory_format_name == "NHWC3D": + return torch.channels_last_3d + elif memory_format_name in ("NCHW", "NCHW3D"): + return torch.contiguous_format + return ValueError("Unsupported memory_format") + + def _create_backend(inp: torch.Tensor, mixed: bool = False): + if inp.dim() == 4: + return nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) + else: + return nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) + + def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend): + mod = _create_backend(inp, mixed).train() + mod.weight.data.uniform_() + mod.bias.data.uniform_() + + ref_mod = _create_backend(ref_inp, mixed).train() + ref_mod.load_state_dict(mod.state_dict()) + + out = mod(inp) + out.backward(grad) + + with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext(): + ref_out = ref_mod(ref_inp) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=_get_memory_format(inp))) + self.assertTrue(ref_out.is_contiguous(memory_format=_get_memory_format(ref_inp))) + self.assertEqual(out, ref_out) + self.assertEqual(mod.weight.grad, ref_mod.weight.grad) + self.assertEqual(mod.bias.grad, ref_mod.bias.grad) + self.assertEqual(mod.running_mean, ref_mod.running_mean) + self.assertEqual(mod.running_var, ref_mod.running_var) + self.assertEqual(inp.grad, ref_inp.grad) + + def _train(memory_format_name, ref_backend, mixed, dtype): + memory_format = _get_memory_format_from_name(memory_format_name) + + ref_memory_format = _get_backend_memory_format(ref_backend, memory_format) + ref_device = _get_ref_device(ref_backend, device="cuda") + + size = (4, 8, 2, 2, 2) if memory_format_name in ("NCHW3D", "NHWC3D") else (4, 8, 2, 2) + inp = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_() + grad = _create_tensor(size, memory_format, dtype, device="cuda") + ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_() + ref_grad = grad.detach().clone(memory_format=ref_memory_format).to(device=ref_device) + + _test_batchnorm_train(inp=inp, grad=grad, mixed=mixed, + ref_inp=ref_inp, ref_grad=ref_grad, ref_backend=ref_backend) + + def _inference(memory_format_name, ref_backend, mixed, dtype): + memory_format = _get_memory_format_from_name(memory_format_name) + ref_memory_format = _get_backend_memory_format(ref_backend, memory_format) + ref_device = _get_ref_device(ref_backend, device="cuda") + + size = (2, 64, 50, 50, 50) if memory_format_name in ("NCHW3D", "NHWC3D") else (2, 64, 50, 50) + inp = _create_tensor(size, memory_format, dtype, device="cuda") + ref_inp = inp.detach().clone(memory_format=ref_memory_format).to(device=ref_device) + mod = _create_backend(inp, mixed).eval() + ref_mod = _create_backend(ref_inp, mixed).eval() + + out = mod(inp) + with torch.backends.cudnn.flags(enabled=False) if ref_backend == "native" else contextlib.nullcontext(): + ref_out = ref_mod(ref_inp) + self.assertEqual(out, ref_out) + + if mode == "train": + _train(memory_format, ref_backend, mixed, dtype) + else: + _inference(memory_format, ref_backend, mixed, dtype) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_batchnorm_nhwc_cuda(self): for dtype in (torch.half, torch.float): @@ -7212,25 +7383,37 @@ def test_layer_norm_eps(self): ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") def test_layer_norm_backwards_eps(self): dtype = torch.float m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), (32, 32), (1024, 32), (1024, 1024), - (33, 33), (1025, 33), (1025, 1025)] - for m, n in m_x_n_list: - x = torch.randn((m, n), dtype=dtype, requires_grad=True) - grad_output = torch.rand_like(x) - x_cuda = x.clone().detach().to("cuda").requires_grad_() - grad_output_cuda = grad_output.clone().detach().to("cuda") - ln = nn.LayerNorm(n, dtype=dtype) - ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) - ln_out = ln(x) - ln_out_cuda = ln_cuda(x_cuda) - ln_out.backward(grad_output) - ln_out_cuda.backward(grad_output_cuda) - self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) - self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + (33, 33), (1025, 33), (1025, 1025), + (128 * 1024, 32), (32, 128 * 1024)] + boolean = [True, False] + combinations = itertools.product(boolean, repeat=2) + for elementwise_affine, bias in combinations: + for m, n in m_x_n_list: + x = torch.randn((m, n), dtype=dtype, requires_grad=True) + grad_output = torch.rand_like(x) + x_cuda = x.clone().detach().to("cuda").requires_grad_() + grad_output_cuda = grad_output.clone().detach().to("cuda") + ln = nn.LayerNorm(n, dtype=dtype, elementwise_affine=elementwise_affine, bias=bias) + ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype, elementwise_affine=elementwise_affine, bias=bias) + ln_out = ln(x) + ln_out_cuda = ln_cuda(x_cuda) + ln_out.backward(grad_output) + ln_out_cuda.backward(grad_output_cuda) + atol = 1e-4 + rtol = 1e-5 + if m > 64 * 1024: + atol = 1e-3 + rtol = 1e-3 + if elementwise_affine: + self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=rtol, atol=atol) + if bias and elementwise_affine: + self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=rtol, atol=atol) @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): @@ -9121,6 +9304,13 @@ def test_MarginLoss_warnings(self, device): l.backward() self.assertTrue(len(f.getvalue()) == 0) + @onlyCUDA + def test_mse_loss_error(self, device): + i = torch.randn((10, 1), device=device) + t = torch.randn((10,)) + with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + F.mse_loss(i, t) + @onlyNativeDeviceTypes def test_Unfold_empty(self, device): inp = torch.randn(0, 3, 3, 4, device=device) @@ -9799,7 +9989,6 @@ def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osi expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) - @expectedFailureMPS # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize): @@ -10498,6 +10687,17 @@ def test_softmax_double(self, device, dtype): expected_ones = F.log_softmax(logits, dim=1).exp().sum(dim=1) self.assertEqual(expected_ones, torch.ones_like(expected_ones)) + # backward + logits = torch.randn(5, 513, dtype=dtype, device=device, requires_grad=True) + out = F.log_softmax(logits, dim=1) + grad = torch.randn_like(out) + out.backward(grad) + logits_cpu = logits.detach().cpu() + logits_cpu.requires_grad = True + out_cpu = F.log_softmax(logits_cpu, dim=1) + out_cpu.backward(grad.detach().cpu()) + self.assertEqual(logits.grad, logits_cpu.grad) + @onlyCUDA @dtypes(torch.half) @largeTensorTest("20GB") @@ -11662,6 +11862,8 @@ def check_lengths(lengths, enforce_sorted, use_default_hiddens, proj_size): prec = dtype2prec_DONTUSE[dtype] if dtype == torch.float16: prec = 4e-2 + elif dtype == torch.float32: + prec = 2e-4 self.assertEqual(p1.grad, p2.grad, atol=prec, rtol=0) tests = [ diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 15864a056041bd..927ec303bbea2e 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -286,6 +286,14 @@ def test_from_numpy_no_leak_on_invalid_dtype(self): pass self.assertTrue(sys.getrefcount(x) == 2) + @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.") + @onlyCPU + def test_from_numpy_zero_element_type(self): + # This tests that dtype check happens before strides check + # which results in div-by-zero on-x86 + x = np.ndarray((3, 3), dtype=str) + self.assertRaises(TypeError, lambda: torch.from_numpy(x)) + @skipMeta def test_from_list_of_ndarray_warning(self, device): warning_msg = ( diff --git a/test/test_openreg.py b/test/test_openreg.py index 4635e36f1b1498..59e1c4bfac4194 100644 --- a/test/test_openreg.py +++ b/test/test_openreg.py @@ -1,56 +1,262 @@ -# Owner(s): ["module: cpp"] +# Owner(s): ["module: PrivateUse1"] +import _codecs +import io import os +import tempfile +import types import unittest +from unittest.mock import patch +import numpy as np import psutil import pytorch_openreg # noqa: F401 import torch +from torch.serialization import safe_globals from torch.testing._internal.common_utils import ( IS_LINUX, run_tests, skipIfTorchDynamo, + skipIfXpu, + TemporaryFileName, TestCase, ) -class TestOpenReg(TestCase): - def test_initializes(self): +class TestPrivateUse1(TestCase): + """Tests of third-parth device integration mechinasm based PrivateUse1""" + + def test_backend_name(self): self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg") + # backend can be renamed to the same name multiple times + torch.utils.rename_privateuse1_backend("openreg") + with self.assertRaisesRegex(RuntimeError, "has already been set"): # type: ignore[misc] + torch.utils.rename_privateuse1_backend("dev") + + def test_backend_module_registration(self): + def generate_faked_module(): + return types.ModuleType("fake_module") + + with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"): # type: ignore[misc] + torch._register_device_module("dev", generate_faked_module()) + with self.assertRaisesRegex(RuntimeError, "The runtime module of"): # type: ignore[misc] + torch._register_device_module("openreg", generate_faked_module()) + + def test_backend_generate_methods(self): + with self.assertRaisesRegex(RuntimeError, "The custom device module of"): # type: ignore[misc] + torch.utils.generate_methods_for_privateuse1_backend() # type: ignore[misc] + + self.assertTrue(hasattr(torch.Tensor, "is_openreg")) + self.assertTrue(hasattr(torch.Tensor, "openreg")) + self.assertTrue(hasattr(torch.TypedStorage, "is_openreg")) + self.assertTrue(hasattr(torch.TypedStorage, "openreg")) + self.assertTrue(hasattr(torch.UntypedStorage, "is_openreg")) + self.assertTrue(hasattr(torch.UntypedStorage, "openreg")) + self.assertTrue(hasattr(torch.nn.Module, "openreg")) + self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_openreg")) + self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "openreg")) + + def test_backend_module_function(self): + with self.assertRaisesRegex(RuntimeError, "Try to call torch.openreg"): # type: ignore[misc] + torch.utils.backend_registration._get_custom_mod_func("func_name_") # type: ignore[misc] + self.assertTrue( + torch.utils.backend_registration._get_custom_mod_func("device_count")() == 2 # type: ignore[misc] + ) - @unittest.skipIf(not IS_LINUX, "Only works on linux") - def test_autograd_init(self): - # Make sure autograd is initialized - torch.ones(2, requires_grad=True, device="openreg").sum().backward() + @skipIfTorchDynamo() + def test_backend_operator_registration(self): + self.assertTrue( + torch._C._dispatch_has_kernel_for_dispatch_key( + "aten::empty.memory_format", torch.DispatchKey.PrivateUse1 + ) + ) + x = torch.empty(3, 3, device="openreg") + self.assertTrue(x.device.type, "openreg") + self.assertTrue(x.shape, torch.Size([3, 3])) + + def test_backend_dispatchstub(self): + x_cpu = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu") + x_openreg = x_cpu.to("openreg") + + y_cpu = torch.abs(x_cpu) + y_openreg = torch.abs(x_openreg) + self.assertEqual(y_cpu, y_openreg.cpu()) + + o_cpu = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu") + o_openreg = o_cpu.to("openreg") + # output operand with resize flag is False in TensorIterator. + torch.abs(x_cpu, out=o_cpu[:, :, 0:6:2]) + torch.abs(x_openreg, out=o_openreg[:, :, 0:6:2]) + self.assertEqual(o_cpu, o_openreg.cpu()) + + # output operand with resize flag is True in TensorIterator and + # convert output to contiguous tensor in TensorIterator. + torch.abs(x_cpu, out=o_cpu[:, :, 0:6:3]) + torch.abs(x_openreg, out=o_openreg[:, :, 0:6:3]) + self.assertEqual(o_cpu, o_openreg.cpu()) + + def test_backend_tensor_type(self): + dtypes_map = { + torch.bool: "torch.openreg.BoolTensor", + torch.double: "torch.openreg.DoubleTensor", + torch.float32: "torch.openreg.FloatTensor", + torch.half: "torch.openreg.HalfTensor", + torch.int32: "torch.openreg.IntTensor", + torch.int64: "torch.openreg.LongTensor", + torch.int8: "torch.openreg.CharTensor", + torch.short: "torch.openreg.ShortTensor", + torch.uint8: "torch.openreg.ByteTensor", + } + + for dtype, str in dtypes_map.items(): + x = torch.empty(4, 4, dtype=dtype, device="openreg") + self.assertTrue(x.type() == str) + + # Note that all dtype-d Tensor objects here are only for legacy reasons + # and should NOT be used. + def test_backend_type_methods(self): + # Tensor + tensor_cpu = torch.randn([8]).float() + self.assertEqual(tensor_cpu.type(), "torch.FloatTensor") + + tensor_openreg = tensor_cpu.openreg() + self.assertEqual(tensor_openreg.type(), "torch.openreg.FloatTensor") + + # Storage + storage_cpu = tensor_cpu.storage() + self.assertEqual(storage_cpu.type(), "torch.FloatStorage") + + tensor_openreg = tensor_cpu.openreg() + storage_openreg = tensor_openreg.storage() + self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage") + + class CustomFloatStorage: + @property + def __module__(self): + return "torch." + torch._C._get_privateuse1_backend_name() + + @property + def __name__(self): + return "FloatStorage" + + try: + torch.openreg.FloatStorage = CustomFloatStorage() + self.assertEqual(storage_openreg.type(), "torch.openreg.FloatStorage") + + # test custom int storage after defining FloatStorage + tensor_openreg = tensor_cpu.int().openreg() + storage_openreg = tensor_openreg.storage() + self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage") + finally: + torch.openreg.FloatStorage = None + + def test_backend_tensor_methods(self): + x = torch.empty(4, 4) + self.assertFalse(x.is_openreg) # type: ignore[misc] + + y = x.openreg(torch.device("openreg")) # type: ignore[misc] + self.assertTrue(y.is_openreg) # type: ignore[misc] + z = x.openreg(torch.device("openreg:0")) # type: ignore[misc] + self.assertTrue(z.is_openreg) # type: ignore[misc] + n = x.openreg(0) # type: ignore[misc] + self.assertTrue(n.is_openreg) # type: ignore[misc] + + @unittest.skip("Need to support Parameter in openreg") + def test_backend_module_methods(self): + class FakeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(3, 3)) + + def forward(self): + pass + + module = FakeModule() + self.assertEqual(module.x.device.type, "cpu") + module.openreg() # type: ignore[misc] + self.assertEqual(module.x.device.type, "openreg") + + @unittest.skip("Need to support untyped_storage in openreg") + def test_backend_storage_methods(self): + x = torch.empty(4, 4) + + x_cpu = x.storage() + self.assertFalse(x_cpu.is_openreg) # type: ignore[misc] + x_openreg = x_cpu.openreg() # type: ignore[misc] + self.assertTrue(x_openreg.is_openreg) # type: ignore[misc] + + y = torch.empty(4, 4) + + y_cpu = y.untyped_storage() + self.assertFalse(y_cpu.is_openreg) # type: ignore[misc] + y_openreg = y_cpu.openreg() # type: ignore[misc] + self.assertTrue(y_openreg.is_openreg) # type: ignore[misc] + + def test_backend_packed_sequence_methods(self): + x = torch.rand(5, 3) + y = torch.tensor([1, 1, 1, 1, 1]) + + z_cpu = torch.nn.utils.rnn.PackedSequence(x, y) + self.assertFalse(z_cpu.is_openreg) # type: ignore[misc] + + z_openreg = z_cpu.openreg() # type: ignore[misc] + self.assertTrue(z_openreg.is_openreg) # type: ignore[misc] - pid = os.getpid() - task_path = f"/proc/{pid}/task" - all_threads = psutil.Process(pid).threads() - all_thread_names = set() +class TestOpenReg(TestCase): + """Tests of mimick accelerator named OpenReg based on PrivateUse1""" - for t in all_threads: - with open(f"{task_path}/{t.id}/comm") as file: - thread_name = file.read().strip() - all_thread_names.add(thread_name) + # Stream & Event + def test_stream_synchronize(self): + stream = torch.Stream(device="openreg:1") + stream.synchronize() + self.assertEqual(True, stream.query()) - for i in range(torch.accelerator.device_count()): - self.assertIn(f"pt_autograd_{i}", all_thread_names) + def test_stream_wait_stream(self): + stream_1 = torch.Stream(device="openreg:0") + stream_2 = torch.Stream(device="openreg:1") + # Does not crash! + stream_2.wait_stream(stream_1) - def test_factory(self): - a = torch.empty(50, device="openreg") - self.assertEqual(a.device.type, "openreg") + @skipIfTorchDynamo() + def test_record_event(self): + stream = torch.Stream(device="openreg:1") + event1 = stream.record_event() + self.assertNotEqual(0, event1.event_id) + event2 = stream.record_event() + self.assertNotEqual(0, event2.event_id) + self.assertNotEqual(event1.event_id, event2.event_id) + + @skipIfTorchDynamo() + def test_event_elapsed_time(self): + stream = torch.Stream(device="openreg:1") + e1 = torch.Event(device="openreg:1", enable_timing=True) + e1.record(stream) + e2 = torch.Event(device="openreg:1", enable_timing=True) + e2.record(stream) - a.fill_(3.5) + e2.synchronize() + self.assertTrue(e2.query()) - self.assertTrue(a.eq(3.5).all()) + ms = e1.elapsed_time(e2) + self.assertTrue(ms > 0) - def test_printing(self): - a = torch.ones(20, device="openreg") - # Does not crash! - str(a) + @skipIfTorchDynamo() + def test_stream_wait_event(self): + s1 = torch.Stream(device="openreg") + s2 = torch.Stream(device="openreg") + e = s1.record_event() + s2.wait_event(e) + + @skipIfTorchDynamo() + def test_event_wait_stream(self): + s1 = torch.Stream(device="openreg") + s2 = torch.Stream(device="openreg") + e1 = s1.record_event() + e1.wait(s2) + # Copy def test_cross_device_copy(self): a = torch.rand(10) b = a.to(device="openreg").add(2).to(device="cpu") @@ -64,32 +270,61 @@ def test_cross_diff_devices_copy(self): a = torch.ones(10, device="openreg:0").to(device="openreg:1").to(device="cpu") self.assertEqual(a, torch.ones(10)) - def test_data_dependent_output(self): - cpu_a = torch.randn(10) - a = cpu_a.to(device="openreg") - mask = a.gt(0) - out = torch.masked_select(a, mask) - - self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0))) - + # RNG def test_generator(self): generator = torch.Generator(device="openreg:1") self.assertEqual(generator.device.type, "openreg") self.assertEqual(generator.device.index, 1) - # TODO(FFFrog): Add more check for rng_state def test_rng_state(self): - state = torch.openreg.get_rng_state(0) - torch.openreg.set_rng_state(state, 0) + state = torch.openreg.get_rng_state(0) # type: ignore[misc] + torch.openreg.set_rng_state(state, 0) # type: ignore[misc] + + def test_manual_seed(self): + torch.openreg.manual_seed_all(2024) # type: ignore[misc] + self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] + + # Autograd + @unittest.skipIf(not IS_LINUX, "Only works on linux") + def test_autograd_init(self): + # Make sure autograd is initialized + torch.ones(2, requires_grad=True, device="openreg").sum().backward() + + pid = os.getpid() + task_path = f"/proc/{pid}/task" + all_threads = psutil.Process(pid).threads() + + all_thread_names = set() + + for t in all_threads: + with open(f"{task_path}/{t.id}/comm") as file: + thread_name = file.read().strip() + all_thread_names.add(thread_name) + + for i in range(torch.accelerator.device_count()): + self.assertIn(f"pt_autograd_{i}", all_thread_names) + # Storage & Pin Memory @skipIfTorchDynamo("unsupported aten.is_pinned.default") def test_pin_memory(self): - cpu_a = torch.randn(10) - self.assertFalse(cpu_a.is_pinned()) - pinned_a = cpu_a.pin_memory() - self.assertTrue(pinned_a.is_pinned()) - slice_a = pinned_a[2:5] - self.assertTrue(slice_a.is_pinned()) + tensor = torch.randn(10) + self.assertFalse(tensor.is_pinned()) + pinned_tensor = tensor.pin_memory() + self.assertTrue(pinned_tensor.is_pinned()) + slice_tensor = pinned_tensor[2:5] + self.assertTrue(slice_tensor.is_pinned()) + + tensor = torch.randn(10) + storage = tensor.storage() + self.assertFalse(storage.is_pinned("openreg")) + pinned_storage = storage.pin_memory("openreg") + self.assertTrue(pinned_storage.is_pinned("openreg")) + + tensor = torch.randn(10) + untyped_storage = tensor.untyped_storage() + self.assertFalse(untyped_storage.is_pinned("openreg")) + pinned_untyped_storage = untyped_storage.pin_memory("openreg") + self.assertTrue(pinned_untyped_storage.is_pinned("openreg")) @skipIfTorchDynamo("unsupported aten.is_pinned.default") def test_rewrapped_storage(self): @@ -103,53 +338,178 @@ def test_rewrapped_storage(self): self.assertTrue(rewrapped_a.is_pinned()) self.assertNotEqual(pinned_a.data_ptr(), rewrapped_a.data_ptr()) - def test_stream_synchronize(self): - stream = torch.Stream(device="openreg:1") - stream.synchronize() - self.assertEqual(True, stream.query()) + # Serialization + def test_serialization(self): + storage = torch.UntypedStorage(4, device=torch.device("openreg")) + self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - def test_stream_wait_stream(self): - stream_1 = torch.Stream(device="openreg:0") - stream_2 = torch.Stream(device="openreg:1") - # Does not crash! - stream_2.wait_stream(stream_1) + storage = torch.UntypedStorage(4, device=torch.device("openreg:0")) + self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - @skipIfTorchDynamo() - def test_record_event(self): - stream = torch.Stream(device="openreg:1") - event1 = stream.record_event() - self.assertNotEqual(0, event1.event_id) - event2 = stream.record_event() - self.assertNotEqual(0, event2.event_id) - self.assertNotEqual(event1.event_id, event2.event_id) + storage_cpu = torch.empty(4, 4).storage() + storage_openreg = torch.serialization.default_restore_location( + storage_cpu, "openreg:0" + ) + self.assertTrue(storage_openreg.is_openreg) # type: ignore[misc] - @skipIfTorchDynamo() - def test_event_elapsed_time(self): - stream = torch.Stream(device="openreg:1") - e1 = torch.Event(device="openreg:1", enable_timing=True) - e1.record(stream) - e2 = torch.Event(device="openreg:1", enable_timing=True) - e2.record(stream) + tensor = torch.empty(3, 3, device="openreg") + self.assertEqual(torch._utils.get_tensor_metadata(tensor), {}) # type: ignore[misc] + metadata = {"version_number": True, "format_number": True} + torch._utils.set_tensor_metadata(tensor, metadata) # type: ignore[misc] + self.assertEqual(torch._utils.get_tensor_metadata(tensor), metadata) # type: ignore[misc] - e2.synchronize() - self.assertTrue(e2.query()) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.pt") + torch.save(tensor, path) - ms = e1.elapsed_time(e2) - self.assertTrue(ms > 0) + tensor_openreg = torch.load(path) + self.assertTrue(tensor_openreg.is_openreg) + self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata) # type: ignore[misc] - @skipIfTorchDynamo() - def test_stream_wait_event(self): - s1 = torch.Stream(device="openreg") - s2 = torch.Stream(device="openreg") - e = s1.record_event() - s2.wait_event(e) + tensor_cpu = torch.load(path, map_location="cpu") + self.assertFalse(tensor_cpu.is_openreg) + self.assertEqual(torch._utils.get_tensor_metadata(tensor_cpu), {}) # type: ignore[misc] @skipIfTorchDynamo() - def test_event_wait_stream(self): - s1 = torch.Stream(device="openreg") - s2 = torch.Stream(device="openreg") - e1 = s1.record_event() - e1.wait(s2) + @unittest.skipIf( + np.__version__ < "1.25", + "versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy", + ) + def test_open_device_numpy_serialization(self): + """ + This tests the legacy _rebuild_device_tensor_from_numpy serialization path + """ + data_legacy_numpy = ( + b"PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02}q\x00X\x01" + b"\x00\x00\x00xq\x01ctorch._utils\n_rebuild_device_tensor_from_numpy\nq\x02(cnumpy.core.m" + b"ultiarray\n_reconstruct\nq\x03cnumpy\nndarray\nq\x04K\x00\x85q\x05c_codecs\nencode\nq\x06" + b"X\x01\x00\x00\x00bq\x07X\x06\x00\x00\x00latin1q\x08\x86q\tRq\n\x87q\x0bRq\x0c(K\x01K\x02K" + b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01" + b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00" + b"PK\x05\x06\x00\x00\x00\x00\x04\x00\x04\x00\x06\x01\x00\x008\x03\x00\x00\x00\x00" + ) + buf_data_legacy_numpy = io.BytesIO(data_legacy_numpy) + + with safe_globals( + [ + ( + ( + np.core.multiarray._reconstruct, + "numpy.core.multiarray._reconstruct", + ) + if np.__version__ >= "2.1" + else np.core.multiarray._reconstruct + ), + np.ndarray, + np.dtype, + _codecs.encode, + np.dtypes.Float32DType, + ] + ): + sd_loaded = torch.load(buf_data_legacy_numpy, weights_only=True) + buf_data_legacy_numpy.seek(0) + # Test map_location + sd_loaded_cpu = torch.load( + buf_data_legacy_numpy, weights_only=True, map_location="cpu" + ) + + expected = torch.tensor( + [[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device="openreg" + ) + self.assertEqual(sd_loaded["x"].cpu(), expected.cpu()) + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) + + def test_open_device_cpu_serialization(self): + default_protocol = torch.serialization.DEFAULT_PROTOCOL + + with patch.object(torch._C, "_has_storage", return_value=False): + x = torch.randn(2, 3) + x_openreg = x.to("openreg") + sd = {"x": x_openreg} + rebuild_func = x_openreg._reduce_ex_internal(default_protocol)[0] + self.assertTrue( + rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor + ) + + # Test map_location + with TemporaryFileName() as f: + torch.save(sd, f) + sd_loaded = torch.load(f, weights_only=True) + # Test map_location + sd_loaded_cpu = torch.load(f, weights_only=True, map_location="cpu") + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertEqual(sd_loaded["x"].cpu(), x) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) + + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + + # Opeartors + def test_factory(self): + x = torch.empty(3, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.shape, torch.Size([3])) + + y = torch.zeros(3, device="openreg") + self.assertEqual(y.device.type, "openreg") + self.assertEqual(y.shape, torch.Size([3])) + + z = torch.tensor((), device="openreg") + self.assertEqual(z.device.type, "openreg") + self.assertEqual(z.shape, torch.Size([0])) + + def test_fake_tensor(self): + with torch._subclasses.fake_tensor.FakeTensorMode(): + a = torch.empty(1, device="openreg") + b = torch.empty(1, device="openreg:0") + result = a + b # noqa: F841 + + def test_named_tensor(self): + return torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"]) + + def test_printing(self): + a = torch.ones(20, device="openreg") + # Does not crash! + str(a) + + def test_data_dependent_output(self): + cpu_a = torch.randn(10) + a = cpu_a.to(device="openreg") + mask = a.gt(0) + out = torch.masked_select(a, mask) + + self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0))) def test_expand(self): x = torch.tensor([[1], [2], [3]], device="openreg") @@ -157,9 +517,70 @@ def test_expand(self): self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]])) self.assertEqual(x.data_ptr(), y.data_ptr()) - def test_empty_tensor(self): - empty_tensor = torch.tensor((), device="openreg") - self.assertEqual(empty_tensor.to(device="cpu"), torch.tensor(())) + def test_resize(self): + tensor_cpu = torch.randn([4, 4]) + + tensor_openreg = tensor_cpu.openreg() + self.assertTrue(tensor_openreg.size() == torch.Size([4, 4])) + + storage_openreg = tensor_openreg.storage() + self.assertTrue(storage_openreg.size() == 16) + + tensor_openreg.resize_(2, 2, 2, 2) + self.assertTrue(tensor_openreg.size() == torch.Size([2, 2, 2, 2])) + + storage_openreg = tensor_openreg.storage() + self.assertTrue(storage_openreg.size() == 16) + + # Quantize + @skipIfXpu(msg="missing kernel for openreg") + def test_quantize(self): + x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg") + quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8) + self.assertEqual(quantized_tensor.device, torch.device("openreg:0")) + self.assertEqual(quantized_tensor.dtype, torch.qint8) + + # custom autograd + def test_compile_autograd_function_returns_self(self): + in_ref = torch.randn(4, device="openreg", requires_grad=True) + out_ref = torch.ops.openreg.custom_autograd_fn_returns_self(in_ref) + out_ref.sum().backward() + + in_test = in_ref.detach().clone().requires_grad_(True) + # TODO(FFFrog): Need to support inductor for OpenReg first. + out_test = torch.compile(backend="aot_eager")( + torch.ops.openreg.custom_autograd_fn_returns_self + )(in_test) + out_test.sum().backward() + + self.assertEqual(out_ref, out_test) + self.assertEqual(in_ref.grad, in_test.grad) + + @skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket") + def test_compile_autograd_function_aliasing(self): + in_ref = torch.randn(4, device="openreg", requires_grad=True) + out_ref = torch.ops.openreg.custom_autograd_fn_aliasing(in_ref) + out_ref.sum().backward() + + in_test = in_ref.detach().clone().requires_grad_(True) + # TODO(FFFrog): Need to support inductor for OpenReg first. + out_test = torch.compile(backend="aot_eager")( + torch.ops.openreg.custom_autograd_fn_aliasing + )(in_test) + out_test.sum().backward() + + self.assertEqual(out_ref, out_test) + self.assertEqual(in_ref.grad, in_test.grad) + + def test_open_device_dlpack(self): + x_in = torch.randn(2, 3).to("openreg") + capsule = torch.utils.dlpack.to_dlpack(x_in) + x_out = torch.from_dlpack(capsule) + self.assertTrue(x_out.device == x_in.device) + + x_in = x_in.to("cpu") + x_out = x_out.to("cpu") + self.assertEqual(x_in, x_out) if __name__ == "__main__": diff --git a/test/test_ops.py b/test/test_ops.py index c8079ea71255d7..f5d848532a13a8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -118,21 +118,17 @@ def reduction_dtype_filter(op): aten = torch.ops.aten meta_consistency_out_dtype_mismatch_xfails = { - xfail("alias_copy"), xfail("all"), xfail("amax"), xfail("amin"), xfail("aminmax"), xfail("any"), - xfail("as_strided_copy"), xfail("bucketize"), xfail("conj_physical"), xfail("cross"), xfail("cummax"), xfail("cummin"), xfail("diag"), - xfail("diagonal_copy"), - xfail("expand_copy"), xfail("fft.ihfft2"), xfail("fft.ihfftn"), xfail("frexp"), @@ -167,8 +163,6 @@ def reduction_dtype_filter(op): xfail("msort"), xfail("multinomial"), xfail("nan_to_num"), - xfail("nanmean"), - xfail("narrow_copy"), xfail("native_batch_norm"), xfail("neg"), xfail("nn.functional.avg_pool3d"), @@ -178,7 +172,6 @@ def reduction_dtype_filter(op): xfail("nn.functional.softplus"), xfail("nn.functional.softshrink"), xfail("ormqr"), - xfail("permute_copy"), xfail("qr"), xfail("renorm"), xfail("round"), @@ -193,15 +186,10 @@ def reduction_dtype_filter(op): xfail("softmax"), xfail("sort"), xfail("sparse.sampled_addmm"), - xfail("squeeze_copy"), - xfail("t_copy"), xfail("take"), - xfail("transpose_copy"), xfail("tril"), xfail("triu"), xfail("unfold_copy"), - xfail("unsqueeze_copy"), - xfail("view_copy"), xfail("where"), # Output has dynamic shape. # Does not have a meta kernel implementation. @@ -2498,7 +2486,6 @@ def test_refs_are_in_decomp_table(self, op): "mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend "mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend "mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend - "nanmean", # logical_not() got an unexpected keyword argument 'out' "quantile", # quantile() q values must be in the range [0, 1] "nanquantile", # quantile() q values must be in the range [0, 1] "nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet @@ -2583,6 +2570,7 @@ def test_refs_are_in_decomp_table(self, op): @unMarkDynamoStrictTest class TestFakeTensor(TestCase): def setUp(self): + super().setUp() # Turn on FakeTensor caching and cross-checking for these tests: cache_enabled = unittest.mock.patch( "torch._dynamo.config.fake_tensor_cache_enabled", True diff --git a/test/test_overrides.py b/test/test_overrides.py index fc47b72bfbce9b..8575bb90271ccb 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -368,7 +368,7 @@ class TensorLike: """A class that overrides the full torch API This class is used to explicitly test that the full torch.tensor API - can be overriden with a class that defines __torch_function__. + can be overridden with a class that defines __torch_function__. """ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 4704c9992d5ea4..6d36b36996c4b8 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1): view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None - mul_6 = sym_size_int * 3 - view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None + mul_9 = sym_size_int * 3 + view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]); view_2 = mul_9 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 74282260553b8f..039898cc16007c 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -59,6 +59,7 @@ def test_no_new_bindings(self): # # {elem for elem in dir(torch._C) if not elem.startswith("_")} torch_C_allowlist_superset = { + "AcceleratorError", "AggregationType", "AliasDb", "AnyType", diff --git a/test/test_pytree.py b/test/test_pytree.py index 82665854c2b135..228dec85bff69f 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -859,6 +859,21 @@ class DirectNamedTuple2(NamedTuple): self.assertFalse(pytree.is_namedtuple(cls)) self.assertFalse(pytree.is_namedtuple_class(cls)) + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_enum_treespec_roundtrip(self, pytree): + data = {TestEnum.A: 5} + spec = pytree.tree_structure(data) + + serialized = pytree.treespec_dumps(spec) + deserialized_spec = pytree.treespec_loads(serialized) + self.assertEqual(spec, deserialized_spec) + class TestPythonPytree(TestCase): def test_deprecated_register_pytree_node(self): diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index 744b65d44b5696..4acff8fab3bdfa 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -380,6 +380,22 @@ def helper(input_size, idx_size): helper([50, 8, 7], 100) helper([50, 3, 4, 5], 100) + @dtypes(torch.float32) + def test_scatter_add_broadcasted_index_deterministic(self, device, dtype): + for d in (0, 1): + inp = torch.randn(3, 4, device=device, dtype=dtype) + idx_1d = torch.randint(3, (10,), device=device) + src_shape = list(inp.shape) + src_shape[d] = 10 + src = torch.randn(src_shape, device=device, dtype=dtype) + idx = idx_1d.unsqueeze(1 - d).expand(src_shape) + print(idx.stride()) + ref = inp.clone().scatter_add_(d, idx, src) + with DeterministicGuard(True): + res = inp.clone().scatter_add_(d, idx, src) + self.assertEqual(res, ref) + + @onlyCPU @dtypes(torch.float32, torch.float64, torch.bfloat16) def test_gather_expanded_index(self, device, dtype): diff --git a/test/test_schema_check.py b/test/test_schema_check.py index 9e1d6a6f1250be..29ea36fd8a5f5c 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -232,7 +232,7 @@ def test_schema_check_mode_functionality(self): actual = x.relu().sin() self.assertEqual(expected, actual) - # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overriden + # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden def test_schema_check_mode_functionality_default_replaced(self): x = torch.rand((3, 3), requires_grad=True) expected = x.add(x, alpha=2) diff --git a/test/test_serialization.py b/test/test_serialization.py index 94632b1a0ffc09..e92fc4018b0ce4 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -413,6 +413,7 @@ def test_serialization_sparse(self): def test_serialization_sparse_safe(self): self._test_serialization(True) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_invalid(self): x = torch.zeros(3, 3) x[1][1] = 1 @@ -438,11 +439,12 @@ def __reduce_ex__(self, proto): torch.save({"spoofed": TensorSerializationSpoofer(x)}, f) for weights_only in (False, True): f.seek(0) - with self.assertRaisesRegex( + with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex( RuntimeError, "size is inconsistent with indices"): y = torch.load(f, weights_only=weights_only) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_invalid_legacy_ctor(self): # This is set in test class setup but would not be check when running user code prev_invariant_check_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled() @@ -469,14 +471,15 @@ def __reduce_ex__(self, proto): torch.save(sd, f) for weights_only in (True,): f.seek(0) - with self.assertRaisesRegex( + with torch.sparse.check_sparse_tensor_invariants(), self.assertRaisesRegex( RuntimeError, - "size is inconsistent with indices"): + "size is inconsistent with indices|found negative index"): y = torch.load(f, weights_only=weights_only) finally: if prev_invariant_check_enabled: torch.sparse.check_sparse_tensor_invariants.enable() + @torch.sparse.check_sparse_tensor_invariants(enable=True) def _test_serialization_sparse_compressed_invalid(self, conversion, get_compressed_indices, @@ -515,18 +518,22 @@ def __reduce_ex__(self, proto): f"`{compressed_indices_name}[[]..., 0[]] == 0` is not satisfied."): y = torch.load(f) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_csr_invalid(self): self._test_serialization_sparse_compressed_invalid( torch.Tensor.to_sparse_csr, torch.Tensor.crow_indices, torch.Tensor.col_indices) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_csc_invalid(self): self._test_serialization_sparse_compressed_invalid( torch.Tensor.to_sparse_csc, torch.Tensor.ccol_indices, torch.Tensor.row_indices) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_bsr_invalid(self): self._test_serialization_sparse_compressed_invalid( lambda x: x.to_sparse_bsr((1, 1)), torch.Tensor.crow_indices, torch.Tensor.col_indices) + @unittest.skipIf(True, "Temporary skip due to gh-153143") def test_serialization_sparse_bsc_invalid(self): self._test_serialization_sparse_compressed_invalid( lambda x: x.to_sparse_bsc((1, 1)), torch.Tensor.ccol_indices, torch.Tensor.row_indices) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index daa3996437498f..360dc058212a02 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -222,14 +222,14 @@ def test_sort_large(self, device, dtype): t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous() v, i = t.sort() del t - iv, im = i.var_mean(dim=0) + iv, im = torch.var_mean(i.to(dtype), dim=0) del i - vv, vm = v.var_mean(dim=0) + vv, vm = torch.var_mean(v.to(dtype), dim=0) del v self.assertEqual(vv, torch.zeros_like(vv)) self.assertEqual(iv, torch.zeros_like(iv)) - self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device)) - self.assertEqual(im, t0.sort().indices) + self.assertEqual(vm, torch.arange(8192, dtype=dtype, device=device)) + self.assertEqual(im, t0.sort().indices, exact_dtype=False) @dtypes(torch.float32) def test_sort_restride(self, device, dtype): diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 142eff2b3ae480..5078649bb0065a 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -47,17 +47,18 @@ _IS_SM8X = False _IS_SM9X = False +_IS_HIPSPARSELT_AVAILABLE = False if torch.cuda.is_available(): _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9 - + _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4) # CUTLASS kernels only work for Ampere if _IS_SM8X: SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS # add cuSPASRELt tests if available - if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X): + if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE): SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8) @@ -223,6 +224,7 @@ def forward(self, x): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cusparselt(self): """ test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile @@ -233,6 +235,7 @@ def test_mlp_contiguous_relu_compile_cusparselt(self): @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine") @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cutlass(self): """ test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile @@ -243,6 +246,7 @@ def test_mlp_contiguous_relu_compile_cutlass(self): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_compile(self) -> None: x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) @@ -571,6 +575,7 @@ def setUp(self): @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_prune_dense_static_sort(self, dtype) -> None: # Ideally we would like to clone and compare, but that won't work because the sorting order will be different # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern. @@ -615,6 +620,7 @@ def test_prune_dense_static_sort(self, dtype) -> None: @training_dtypes @parametrize_backends + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None: inp = torch.tensor( [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]], @@ -651,6 +657,7 @@ def test_gemm(self, dtype) -> None: @training_dtypes @parametrize_backends + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: M, N = 128, 256 # Construct x to make sure we always have exactly 8 elements per 4x4 tile @@ -684,6 +691,7 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype]) @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_id(self, dtype) -> None: N = 512 torch.manual_seed(0) @@ -718,6 +726,7 @@ def test_pack_both_ways_id(self, dtype) -> None: ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})" @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_edge_case1(self, dtype) -> None: # In this case, the heuristic will keep 7 values out of 16 # instead of 8. let's see how the kernel handles this @@ -742,6 +751,7 @@ def test_pack_both_ways_edge_case1(self, dtype) -> None: assert packed_t[0, 1].item() == 0 @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_apply(self, dtype) -> None: M, N = 256, 1024 x = torch.randn([M, N], dtype=dtype, device="cuda") @@ -757,6 +767,7 @@ def test_sp24_apply(self, dtype) -> None: torch.testing.assert_close(packed_t, packed_t2) @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_apply_dense(self, dtype) -> None: M, N = 256, 1024 x = torch.randn([M, N], dtype=dtype, device="cuda") @@ -794,6 +805,7 @@ def test_sp24_apply_dense(self, dtype) -> None: @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls(self, dtype) -> None: M, N, K = 64, 256, 1024 a = torch.randn([M, K], device="cuda", dtype=dtype) @@ -828,6 +840,7 @@ def test_sp24_matmuls(self, dtype) -> None: a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1 ) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls_mat_vec(self) -> None: a = torch.randn([64, 128], device="cuda", dtype=torch.float16) b = torch.randn([128], device="cuda", dtype=torch.float16) @@ -837,7 +850,7 @@ def test_sp24_matmuls_mat_vec(self) -> None: with pytest.raises(NotImplementedError): torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) - + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls_bmm(self) -> None: a = torch.randn([64, 128], device="cuda", dtype=torch.float16) b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16) @@ -988,6 +1001,7 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_conversions(self, device, dtype): def run_test(r, c, device, dtype): @@ -1016,6 +1030,7 @@ def run_test(r, c, device, dtype): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_conversions_all_patterns(self, device, dtype): r, c = 32, 128 @@ -1135,6 +1150,7 @@ def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device): @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling") @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha(self, dtype, device): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda() B = torch.ones((256, 128), device=device).to(dtype) @@ -1151,6 +1167,7 @@ def test_cslt_sparse_mm_alpha(self, dtype, device): torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device) B = torch.ones((128, 256), device=device, dtype=torch.int8).t() @@ -1172,6 +1189,7 @@ def get_dense_result(): torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() B = torch.ones((128, 256), device=device).to(torch.int8).t() diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 154c36832f72ad..03c62a272286d8 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -60,20 +60,22 @@ def _hermitian_conj(x, dim): """ out = torch.empty_like(x) mid = (x.size(dim) - 1) // 2 - idx = [slice(None)] * out.dim() - idx_center = list(idx) - idx_center[dim] = 0 + idx = tuple([slice(None)] * out.dim()) out[idx] = x[idx] idx_neg = list(idx) idx_neg[dim] = slice(-mid, None) - idx_pos = idx + idx_neg = tuple(idx_neg) + idx_pos = list(idx) idx_pos[dim] = slice(1, mid + 1) + idx_pos = tuple(idx_pos) out[idx_pos] = x[idx_neg].flip(dim) out[idx_neg] = x[idx_pos].flip(dim) if (2 * mid + 1 < x.size(dim)): + idx = list(idx) idx[dim] = mid + 1 + idx = tuple(idx) out[idx] = x[idx] return out.conj() @@ -518,6 +520,7 @@ def test_hfftn(self, device, dtype): lastdim_size = input.size(lastdim) // 2 + 1 idx = [slice(None)] * input_ndim idx[lastdim] = slice(0, lastdim_size) + idx = tuple(idx) input = input[idx] s = [shape[dim] for dim in actual_dims] @@ -558,6 +561,7 @@ def test_ihfftn(self, device, dtype): lastdim_size = expect.size(lastdim) // 2 + 1 idx = [slice(None)] * input_ndim idx[lastdim] = slice(0, lastdim_size) + idx = tuple(idx) expect = expect[idx] actual = torch.fft.ihfftn(input, dim=dim, norm="ortho") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index e804e289c1c20c..220ad2c1c2f3b7 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -37,6 +37,7 @@ from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.bounds import ValueRangeAnalysis +from torch._inductor.index_propagation import TypedExpr UNARY_OPS = [ @@ -968,6 +969,33 @@ def test_expand_identity(self): self.assertEqual(expanded.count(Identity), 0) self.assertEqual(expanded, arg) + def test_cast_identity_int(self): + num = 1 + expr = Identity(num) + self.assertEqual(num, int(expr)) + + def test_cast_identity_float(self): + num = 1.1 + expr = Identity(num) + self.assertEqual(num, float(expr)) + + def test_cast_identity_illegal(self): + sym = Identity(sympy.Symbol("x")) + self.assertRaises(TypeError, int, sym) + self.assertRaises(TypeError, float, sym) + + tup = (0, 1, 2) + tup_I = Identity(tup) + self.assertRaises(TypeError, int, tup_I) + self.assertRaises(TypeError, float, tup_I) + +class TestTypedExpr(TestCase): + def test_typed_expr(self): + I = Identity(1) + typed_I = TypedExpr(I, torch.int32) + self.assertEqual(typed_I.expr, 1) + + instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) instantiate_parametrized_tests(TestSympySolve) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 35db309a0bdf2b..2108b13c0be302 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -33,7 +33,6 @@ IS_S390X, IS_ARM64, parametrize, - skipIfTorchDynamo, xfailIfTorchDynamo, ) from torch.testing._internal.common_device_type import ( @@ -196,7 +195,6 @@ def test_fill_all_dtypes_and_devices(self, device): self.assertEqual(x, torch.tensor([n] * numel, dtype=dt, device=device)) self.assertEqual(dt, x.dtype) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_roll(self, device): numbers = torch.arange(1, 9, device=device) @@ -535,6 +533,14 @@ def test_cat_empty(self, device): res1 = torch.cat([empty, empty], dim=1) self.assertEqual(res1, empty) + def test_concat_empty_list_error(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/155306 + msg = "expected a non-empty list of Tensors" + with self.assertRaisesRegex(ValueError, msg): + torch.concat([], dim='N') + with self.assertRaisesRegex(ValueError, msg): + torch.concatenate([], dim='N') + def test_cat_out(self, device): x = torch.zeros((0), device=device) y = torch.randn((4, 6), device=device) @@ -775,7 +781,6 @@ def test_device_rounding(self, device, dtype): # Note: This test failed on XLA since its test cases are created by empty_strided which # doesn't support overlapping sizes/strides in XLA impl - @skipIfTorchDynamo("TorchDynamo fails on this test for unknown reasons") @onlyNativeDeviceTypes def test_like_fn_stride_proparation_vs_tensoriterator_unary_op(self, device): # Test like functions against tensoriterator based unary operator (exp) to @@ -1015,7 +1020,6 @@ def test_dstack(self, device, dtype): expected = np.dstack(np_input) self.assertEqual(actual, expected) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @dtypes(torch.int32, torch.int64) def test_large_linspace(self, device, dtype): start = torch.iinfo(dtype).min @@ -1284,7 +1288,6 @@ def __getitem__(self, item): torch.tensor(bad_mock_seq, device=device) self.assertEqual(torch.tensor([1.0, 2.0, 3.0], device=device), torch.tensor(good_mock_seq, device=device)) - @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_simple_scalar_cast(self, device): ok = [torch.tensor([1.5], device=device), torch.zeros(1, 1, 1, 1, device=device)] ok_values = [1.5, 0] @@ -1536,7 +1539,6 @@ def test_combinations(self, device): self.assertEqual(c1, expected) self.assertEqual(c2, expected) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @skipMeta def test_linlogspace_mem_overlap(self, device): x = torch.rand(1, device=device).expand(10) @@ -1612,8 +1614,6 @@ def test_random_bool(self, device): self.assertEqual(t.max(), True) self.assertTrue(0.4 < (t.eq(True)).to(torch.int).sum().item() / size < 0.6) - # https://github.com/pytorch/pytorch/issues/126834 - @xfailIfTorchDynamo def test_random_from_to_bool(self, device): size = 2000 @@ -1693,8 +1693,6 @@ def test_random_full_range(self, device, dtype): # NB: uint64 is broken because its max value is not representable in # int64_t, but this is what random expects - # https://github.com/pytorch/pytorch/issues/126834 - @xfailIfTorchDynamo @dtypes(*all_types_and(torch.bfloat16, torch.half, torch .uint16, torch.uint32)) def test_random_from_to(self, device, dtype): size = 2000 @@ -1784,8 +1782,6 @@ def test_random_from_to(self, device, dtype): lambda: t.random_(from_, to_) ) - # https://github.com/pytorch/pytorch/issues/126834 - @xfailIfTorchDynamo @dtypes(*all_types_and(torch.bfloat16, torch.half, torch.uint16, torch.uint32)) def test_random_to(self, device, dtype): size = 2000 @@ -2333,7 +2329,6 @@ def test_as_tensor(self, device): self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor) # TODO: this test should be updated - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @suppress_warnings @dtypesIfCPU(torch.float, torch.bfloat16, torch.float16) @dtypes(torch.float) @@ -2364,7 +2359,6 @@ def test_range(self, device, dtype): self.assertEqual(res1, res2, atol=0, rtol=0) # TODO: this test should be updated - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_range_warning(self, device): with warnings.catch_warnings(record=True) as w: torch.range(0, 10, device=device) @@ -2684,13 +2678,11 @@ def test_fn(torch_fn, numpy_fn, steps): for steps in [1, 2, 3, 5, 11, 256, 257, 2**22]: test_fn(torch.linspace, np.linspace, steps) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @dtypes(torch.complex64) def test_linspace_vs_numpy_complex(self, device, dtype): self._test_linspace_logspace_complex_helper(torch.linspace, np.linspace, device, dtype) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @dtypes(torch.complex64) def test_logspace_vs_numpy_complex(self, device, dtype): self._test_linspace_logspace_complex_helper(torch.logspace, np.logspace, @@ -2811,7 +2803,6 @@ def _test_signal_window_functions(self, name, dtype, device, **kwargs): @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) - @skipIfTorchDynamo("Not a TorchDynamo suitable test") @dtypes(torch.float, torch.double, torch.long) @parametrize("window", ['hann', 'hamming', 'bartlett', 'blackman']) def test_signal_window_functions(self, device, dtype, window): @@ -2820,7 +2811,6 @@ def test_signal_window_functions(self, device, dtype, window): @onlyNativeDeviceTypes @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) @dtypes(torch.float, torch.double, torch.long, torch.bfloat16, torch.float16) def test_kaiser_window(self, device, dtype): @@ -2847,7 +2837,6 @@ def _test_signal_windows_functions(self, name, dtype, device, **kwargs): # torch.signal.windows functions (except any with extra parameters) @onlyNativeDeviceTypes @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - @skipIfTorchDynamo("Not a TorchDynamo suitable test") @dtypes(torch.float, torch.double) @parametrize("window", ['bartlett', 'blackman', 'cosine', 'hamming', 'hann', 'nuttall']) def test_signal_windows_functions(self, device, dtype, window): @@ -2856,7 +2845,6 @@ def test_signal_windows_functions(self, device, dtype, window): # torch.signal.windows.kaiser @onlyNativeDeviceTypes @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @dtypes(torch.float, torch.double) def test_kaiser(self, device, dtype): for num_test in range(50): @@ -3040,12 +3028,10 @@ def _test_linspace_logspace_deduction_helper(self, fn, device): self.assertEqual(fn(start, end, steps=100, device=device).dtype, dtype) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_linspace_deduction(self, device): # Test deduction from input parameters. self._test_linspace_logspace_deduction_helper(torch.linspace, device) - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") def test_logspace_deduction(self, device): # Test deduction from input parameters. self._test_linspace_logspace_deduction_helper(torch.logspace, device) @@ -3227,7 +3213,7 @@ def test_from_file(self, device, shared): self.assertTrue(t_mapped.untyped_storage().filename == expected_filename) self.assertEqual(torch.flatten(t), t_mapped) - s = torch.UntypedStorage.from_file(f.name, shared, t.numel() * dtype.itemsize) + s = torch.UntypedStorage.from_file(f.name, shared, nbytes=t.numel() * dtype.itemsize) self.assertTrue(s.filename == expected_filename) @onlyCPU @@ -4185,7 +4171,8 @@ def test_astensor_consistency(self, device): t = torch.asarray(e) self.assertEqual(t, original) - @skipIfTorchDynamo() + # Dynamo changes numpy scalar to array, thus skips the asserted error. + @xfailIfTorchDynamo @onlyCPU def test_numpy_scalars(self, device): scalar = np.float64(0.5) diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index a7f99f5c459472..c6982d319d8100 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -122,6 +122,10 @@ def assertImageProto(self, actual_proto): from torch.utils.tensorboard._pytorch_graph import graph from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC from torch.utils.tensorboard.summary import int_to_half, tensor_proto +else: + # Dummy for parametrization + class DataType: + DT_FLOAT, DT_HALF, DT_BFLOAT16, DT_INT32 = [None] * 4 class TestTensorBoardPyTorchNumpy(BaseTestCase): diff --git a/test/test_torch.py b/test/test_torch.py index c73d34688bdab4..cb118d2b10b63f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -248,6 +248,19 @@ def test_storage_setitem(self, device, dtype): s[2:7] = 1 self.assertEqual(s, storage_type(l)) + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") + @onlyNativeDeviceTypes + @unittest.skipIf( + "RelWithAssert" in torch.__config__.show(), + "failing in debug build, see https://github.com/pytorch/pytorch/pull/156731 for example", + ) + def test_storage_use_count(self, device): + a = torch.randn(10, device=device) + prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) + self.assertEqual(prev_cf, 1) + b = a.view(2, 5) + self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1) + @xfailIfTorchDynamo @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @@ -1536,6 +1549,40 @@ def test_nondeterministic_alert_interpolate_bilinear(self, device): 'upsample_bilinear2d_backward_out_cuda', torch.device(device).type == 'cuda') + def test_no_nondeterministic_alert_interpolate_bilinear(self, device): + input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True) + + def fn(): + res = torch.nn.functional.interpolate( + input, + size=12, + mode='bilinear', + align_corners=False) + grad = torch.ones_like(res) + return res.backward(grad) + + self.check_nondeterministic_alert( + fn, + 'upsample_bilinear2d_backward_out_cuda', + False) + + def test_no_nondeterministic_alert_interpolate_trilinear(self, device): + input = torch.randn(1, 2, 4, 4, 4, device=device, requires_grad=True) + + def fn(): + res = torch.nn.functional.interpolate( + input, + size=12, + mode='trilinear', + align_corners=False) + grad = torch.ones_like(res) + return res.backward(grad) + + self.check_nondeterministic_alert( + fn, + 'upsample_trilinear3d_backward_out_cuda', + False) + @skipIfTorchInductor("aot-autograd issue") def test_deterministic_replication_pad2d(self, device): test_cases = [ @@ -2103,8 +2150,11 @@ def _cond_fn(x): ind_cpu = ind.cpu() repeats = torch.full((1,), 2, device=device) mask = torch.randint(2, (size,), device=device, dtype=bool) + mask_cpu = mask.cpu() expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.), + lambda: _ind_put_fn(x, mask_cpu, y), lambda: _ind_put_fn(x, ind, y), + lambda: _ind_get_fn(x, mask_cpu), lambda: _ind_get_fn(x, ind), lambda: torch.nn.functional.one_hot(ind, num_classes=size), lambda: torch.randperm(20000, device=device), @@ -2274,7 +2324,7 @@ def test_corrcoef(self, device, dtype): for x in self._generate_correlation_tensors(device, dtype): res = torch.corrcoef(x) ref = np.corrcoef(x.cpu().numpy()) - self.assertEqual(res, ref, exact_dtype=False) + self.assertEqual(res, ref, atol=1e-04, rtol=1e-03, exact_dtype=False) @skipRocmIfTorchInductor @dtypes(torch.int, torch.float, torch.cfloat) @@ -8589,17 +8639,96 @@ def test_map2(self): lambda: res.map2_(y, z, lambda a, b, c: a + b * c)) def test_Size(self): - x = torch.Size([1, 2, 3]) - self.assertIsInstance(x, tuple) - self.assertEqual(x[0], 1) - self.assertEqual(x[1], 2) - self.assertEqual(x[2], 3) - self.assertEqual(len(x), 3) + # expects iterable of int, not Tensor self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3))) - - self.assertIsInstance(x * 2, torch.Size) - self.assertIsInstance(x[:-1], torch.Size) - self.assertIsInstance(x + x, torch.Size) + # initialization + empty_size = torch.Size([]) + size = torch.Size([1, 2, 3]) + self.assertIsInstance(empty_size, tuple) + self.assertIsInstance(size, tuple) + # value check __len__ + self.assertEqual(len(empty_size), 0) + self.assertEqual(len(size), 3) + # type check __getitem__[int] + self.assertIsInstance(size[0], int) + self.assertIsInstance(size[1], int) + self.assertIsInstance(size[2], int) + # value check __getitem__[int] + self.assertEqual(size[0], 1) + self.assertEqual(size[1], 2) + self.assertEqual(size[2], 3) + # type check __getitem__[slice] + self.assertIsInstance(size[:], torch.Size) + self.assertIsInstance(size[:-1], torch.Size) + self.assertIsInstance(size[0:0], torch.Size) + # value check __getitem__[slice] + self.assertEqual(size[:], (1, 2, 3)) + self.assertEqual(size[:-1], (1, 2)) + self.assertEqual(size[0:0], ()) + # type check __add__ + self.assertIsInstance(empty_size + (), torch.Size) + self.assertIsInstance(size + (), torch.Size) + self.assertIsInstance(size + (4, 5), torch.Size) + self.assertIsInstance(size + size, torch.Size) + # value check __add__ + self.assertEqual(empty_size + (), ()) + self.assertEqual(size + (), (1, 2, 3)) + self.assertEqual(size + (4, 5), (1, 2, 3, 4, 5)) + self.assertEqual(size + size, (1, 2, 3, 1, 2, 3)) + # type check __radd__ + self.assertIsInstance(() + empty_size, torch.Size) + self.assertIsInstance((4, 5) + size, torch.Size) + # value check __radd__ + self.assertEqual(() + size, (1, 2, 3)) + self.assertEqual((4, 5) + size, (4, 5, 1, 2, 3)) + # type check __mul__ + self.assertIsInstance(empty_size * 0, torch.Size) + self.assertIsInstance(size * 0, torch.Size) + self.assertIsInstance(size * 1, torch.Size) + self.assertIsInstance(size * 2, torch.Size) + # value check __mul__ + self.assertEqual(empty_size * 0, ()) + self.assertEqual(size * 0, ()) + self.assertEqual(size * 1, (1, 2, 3)) + self.assertEqual(size * 2, (1, 2, 3, 1, 2, 3)) + # type check __rmul__ + self.assertIsInstance(0 * empty_size, torch.Size) + self.assertIsInstance(0 * size, torch.Size) + self.assertIsInstance(1 * size, torch.Size) + self.assertIsInstance(2 * size, torch.Size) + # value check __rmul__ + self.assertEqual(0 * empty_size, ()) + self.assertEqual(0 * size, ()) + self.assertEqual(1 * size, (1, 2, 3)) + self.assertEqual(2 * size, (1, 2, 3, 1, 2, 3)) + + def test_Size_concat_non_tuple_sequence(self): + # check that TypeError get's raised on adding non-tuple sequences. + from collections.abc import Sequence + + class DummySequence(Sequence): + vals = list(range(5)) + def __len__(self): return len(self.vals) + def __getitem__(self, i): return self.vals[i] + def __iter__(self): return iter(self.vals) + + size = torch.Size([1, 2, 3]) + seq = DummySequence() + msg = r"can only concatenate tuple \(not \w+\) to torch.Size" + self.assertRaisesRegex(TypeError, msg, lambda: size + seq) + msg = r"unsupported operand type" + self.assertRaisesRegex(TypeError, msg, lambda: seq + size) + + def test_Size_concat_wildcard(self): + # check that 3rd party classes can support addition with torch.Size + class Wildcard: + def __add__(self, other): return 42 + def __radd__(self, other): return 42 + + size = torch.Size([1, 2, 3]) + wildcard = Wildcard() + self.assertEqual(wildcard + size, 42) + self.assertEqual(size + wildcard, 42) def test_Size_scalar(self): three = torch.tensor(3) diff --git a/test/test_transformers.py b/test/test_transformers.py index 34eaca2390d480..0a273c87be847e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -17,7 +17,7 @@ import math import itertools import torch.optim as optim -from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU +from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU, largeTensorTest from typing import Optional import torch.utils.cpp_extension from torch.testing._internal.common_nn import NNTestCase @@ -1900,23 +1900,71 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): @onlyCUDA def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): - query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) - key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) - value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) + batch_size = 2**16 + query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) + key = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) + value = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) + q_cpu, k_cpu, v_cpu = (query.detach().cpu().requires_grad_(True), + key.detach().cpu().requires_grad_(True), + value.detach().cpu().requires_grad_(True)) with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION): out = F.scaled_dot_product_attention(query, key, value) - out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu()) - self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4) + out_cpu = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu) + grad_out = torch.rand_like(out) + out.backward(grad_out) + out_cpu.backward(grad_out.cpu()) + + self.assertEqual(out, out_cpu, atol=2e-3, rtol=1e-4) + self.assertEqual(query.grad, q_cpu.grad, atol=2e-3, rtol=1e-4) + self.assertEqual(key.grad, k_cpu.grad, atol=2e-3, rtol=1e-4) + self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4) @onlyCUDA def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) - error_str = (r"Efficient attention cannot produce valid seed, " - r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.") + error_str = (r"Efficient attention cannot produce valid seed and offset outputs when " + r"the batch size exceeds \(65535\)\.") with self.assertRaisesRegex(RuntimeError, error_str): - torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True) + torch._scaled_dot_product_efficient_attention(query, key, value, + attn_bias=None, compute_log_sumexp=True, + dropout_p=0.01) + + @largeTensorTest("15GB", "cuda") + @onlyCUDA + def test_mem_eff_attention_large_seq_len_uniform_attention(self): + device = torch.device("cuda") + dtype = torch.bfloat16 + + num_queries = 49999 + num_heads = 2 + feature_dim = 16 + + # Q and K are all zeros -> uniform attention + query = torch.zeros(1, num_heads, num_queries, feature_dim, device=device, dtype=dtype, requires_grad=True) + key = torch.zeros(1, num_heads, num_queries, feature_dim, device=device, dtype=dtype, requires_grad=True) + value = torch.ones(1, num_heads, num_queries, feature_dim, device=device, dtype=dtype, requires_grad=True) + mask = torch.ones((num_queries, num_queries), dtype=torch.bool, device=device) + + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=mask, + ) + expected = torch.ones_like(output) + grad_output = torch.ones_like(output) + output.backward(grad_output) + + self.assertTrue(torch.allclose(output, expected)) + self.assertTrue(torch.allclose(query.grad, torch.zeros_like(query))) + self.assertTrue(torch.allclose(key.grad, torch.zeros_like(key))) + # For value, since each input position contributed 1/num_queries to each output, the grad should sum accordingly + # for all ones grad_output, each value position receives grad of 1 (because sum of all softmax weights per row is 1) + self.assertTrue(torch.allclose(value.grad, torch.ones_like(value))) + def _get_block_size_n(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel @@ -3145,7 +3193,7 @@ def test_mem_eff_backwards_determinism(self, device): out = F.scaled_dot_product_attention(query, key, value) upward_grad = torch.rand_like(out) out.backward(upward_grad) - intial_query_grad = query.grad + initial_query_grad = query.grad # Re-run the op with the same upward grad and check that the backward is # not deterministic @@ -3154,7 +3202,7 @@ def test_mem_eff_backwards_determinism(self, device): query.grad = None out = F.scaled_dot_product_attention(query, key, value) out.backward(upward_grad) - if not torch.equal(intial_query_grad, query.grad): + if not torch.equal(initial_query_grad, query.grad): diff_anwser_once = True break self.assertTrue(diff_anwser_once) @@ -3164,7 +3212,7 @@ def test_mem_eff_backwards_determinism(self, device): out = F.scaled_dot_product_attention(query, key, value) upward_grad = torch.rand_like(out) out.backward(upward_grad) - intial_query_grad = query.grad + initial_query_grad = query.grad # Re-run the op with the same upward grad and check that the backward is # deterministic now that we have enforced it @@ -3173,7 +3221,7 @@ def test_mem_eff_backwards_determinism(self, device): query.grad = None out = F.scaled_dot_product_attention(query, key, value) out.backward(upward_grad) - if not torch.equal(intial_query_grad, query.grad): + if not torch.equal(initial_query_grad, query.grad): diff_anwser_once = True break self.assertFalse(diff_anwser_once) @@ -3281,7 +3329,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, fudge_factors['grad_key'] = 70.0 if seq_len_k >= 2048: fudge_factors['grad_key'] = 160.0 - fudge_factors['grad_query'] = 650.0 + fudge_factors['grad_query'] = 670.0 if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 @@ -3402,7 +3450,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, fudge_factors['grad_key'] = 70.0 if seq_len_k >= 2048: fudge_factors['grad_key'] = 160.0 - fudge_factors['grad_query'] = 650.0 + fudge_factors['grad_query'] = 670.0 # gfx90a if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 @@ -3994,7 +4042,7 @@ def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: to def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 - make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 q_shape = SdpaShape(batch, num_heads, 1, head_dim_k) k_shape = SdpaShape(batch, num_heads, 2, head_dim_k) @@ -4010,11 +4058,49 @@ def test_fused_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - def test_onednn_attention_fail_d256(self, device): - # Test that onednn graph attention dispatching correctly bails out on d > 256 + @parametrize("dtype", [torch.half, torch.bfloat16]) + @parametrize("batch_size,n_head,n_head_kv,q_size,kv_size,head_dim", [ + (2, 64, 16, 9216, 77, 64), + (2, 32, 4, 2304, 2304, 64), + (2, 32, 2, 2304, 77, 64), + (2, 20, 2, 576, 576, 64), + (2, 20, 2, 576, 77, 64), + (2, 20, 2, 144, 144, 64), + (2, 20, 2, 144, 77, 64), + (1, 32, 2, 1, 32, 128), + (4, 32, 4, 1, 32, 128), + (1, 32, 2, 32, 32, 128), + (4, 32, 4, 32, 32, 128), + (1, 32, 2, 2016, 2016, 128), + (4, 32, 4, 2016, 2016, 128), + ]) + @parametrize("is_causal", [True, False]) + def test_fused_attention_gqa(self, device, dtype, batch_size, n_head, n_head_kv, q_size, kv_size, head_dim, is_causal): + tol = Tolerances(1e-5, 5e-6) + if dtype is torch.bfloat16: + tol = Tolerances(5e-2, 5e-2) + if dtype is torch.float16: + tol = Tolerances(1e-2, 1e-2) + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + q_shape = SdpaShape(batch_size, n_head, q_size, head_dim) + k_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) + v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) + query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + # test that we do not dispatch to onednn for an unsupported case + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) + + math_ref = torch.ops.aten._scaled_dot_product_attention_math( + query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True)[0] + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=tol.atol, rtol=tol.rtol) + + def test_onednn_attention_fail_d576(self, device): + # Test that onednn graph attention dispatching correctly bails out on d > 576 b, h = 1, 2 s_q, s_kv = 128, 128 - d_qk, d_v = 512, 512 + d_qk, d_v = 1024, 1024 q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16) k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16) @@ -4046,6 +4132,50 @@ def test_fused_attention_broadcasted_input(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + def test_attention_preserves_query_layout(self, device): + + def test_attention(permute_order: list[list[int]]): + BHSqD = [4, 16, 256, 64] + BHSkvD = [4, 16, 512, 64] + + shape_q = [BHSqD[idx] for idx in permute_order] + shape_kv = [BHSkvD[idx] for idx in permute_order] + reverse = [permute_order.index(idx) for idx in range(4)] + q = torch.randn(*shape_q, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse) + k = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse) + v = torch.randn(*shape_kv, dtype=torch.bfloat16, device=device, requires_grad=False).permute(reverse) + self.assertEqual(q.shape, BHSqD) + self.assertEqual(k.shape, BHSkvD) + self.assertEqual(v.shape, BHSkvD) + + out = F.scaled_dot_product_attention(q, k, v) + self.assertTrue(out.permute(permute_order).is_contiguous()) + + permutable = [0, 1, 2] + permute_orders = itertools.permutations(permutable) + + for permute_order in permute_orders: + test_attention(list(permute_order) + [3]) + + def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) + batch, num_heads, seqlen, head_dim = 32, 16, 32, 64 + q_shape = SdpaShape(batch, num_heads, seqlen, head_dim) + k_shape = SdpaShape(batch, num_heads, seqlen, head_dim) + v_shape = SdpaShape(batch, num_heads, seqlen, head_dim) + query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + attn_mask = torch.full((seqlen, seqlen), float('-inf'), device=device, dtype=torch.bfloat16) + + actual = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + + math_ref = torch.ops.aten._scaled_dot_product_attention_math( + query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0] + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + @parametrize("type", ["dense"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): diff --git a/test/test_utils.py b/test/test_utils.py index 5f69ecdfe35a57..080afe76159133 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,6 +20,7 @@ import torch.utils.cpp_extension import torch.utils.data from torch._utils import try_import +from torch._utils_internal import deprecated from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -61,6 +62,9 @@ from torch.testing._internal.common_utils import run_tests, TestCase +# mypy: disable-error-code="name-defined" + + class RandomDatasetMock(torch.utils.data.Dataset): def __getitem__(self, index): return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) @@ -1060,18 +1064,33 @@ def test_get_default_device(self): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_get_default_device_more(self): - torch.set_default_device("cuda") - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) + try: + torch.set_default_device("cuda") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) - torch.set_default_device("cuda") - torch.cuda.set_device("cuda:1") - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) + torch.set_default_device("cuda") + torch.cuda.set_device("cuda:1") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) - torch.set_default_device("cuda:1") - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) + torch.set_default_device("cuda:1") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) + + torch.set_default_device("cuda:1") + with torch.device("cuda:0"): + self.assertEqual(torch.get_default_device(), torch.device("cuda", 0)) + + torch.set_default_device("cpu") + self.assertEqual(torch.get_default_device(), torch.device("cpu")) + with torch.device("cuda:0"): + self.assertEqual(torch.get_default_device(), torch.device("cuda", 0)) + + self.assertEqual(torch.get_default_device(), torch.device("cpu")) + finally: + # Reset the device at the end. + torch.set_default_device(None) @onlyCPU @ops(op_db) @@ -1182,5 +1201,20 @@ def test_import_missing(self): self.assertIsNone(missing_module) +@deprecated() +def _deprecated_api(x, y=15): + return x + y + + +class TestDeprecate(TestCase): + def test_deprecated(self): + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, 2) # noqa: F821 + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, y=2) # noqa: F821 + _deprecated_api(1, 2) + _deprecated_api(1, y=2) + + if __name__ == "__main__": run_tests() diff --git a/test/test_xpu.py b/test/test_xpu.py index cf7415ead2a89c..cd5275418c4402 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -101,6 +101,7 @@ def test_get_device_properties(self): self.assertEqual(device_name, torch.xpu.get_device_name()) device_capability = torch.xpu.get_device_capability(current_device) + self.assertTrue(device_capability["device_id"] > 0) self.assertTrue(device_capability["max_work_group_size"] > 0) self.assertTrue(device_capability["max_num_sub_groups"] > 0) self.assertEqual( @@ -551,15 +552,10 @@ def test_torch_version_xpu(self): library = find_library_location("libtorch_xpu.so") cmd = f"ldd {library} | grep libsycl" results = subprocess.check_output(cmd, shell=True).strip().split(b"\n") - # There should be only one libsycl.so or libsycl-preview.so + # There should be only one libsycl.so self.assertEqual(len(results), 1) for result in results: - if b"libsycl.so" in result: - self.assertGreaterEqual(compiler_version, 20250000) - elif b"libsycl-preview.so" in result: - self.assertLess(compiler_version, 20250000) - else: - self.fail("Unexpected libsycl library") + self.assertTrue(b"libsycl.so" in result) def test_dlpack_conversion(self): x = make_tensor((5,), dtype=torch.float32, device="xpu") diff --git a/test/torch_np/numpy_tests/core/test_getlimits.py b/test/torch_np/numpy_tests/core/test_getlimits.py index 738b272d40a3ba..04ca9e207da3a7 100644 --- a/test/torch_np/numpy_tests/core/test_getlimits.py +++ b/test/torch_np/numpy_tests/core/test_getlimits.py @@ -135,9 +135,9 @@ def test_basic(self): [ np.uint8, # xfail: unsupported add (uint[16,32,64]) - subtest(np.uint16, decorators=[xfail]), - subtest(np.uint32, decorators=[xfail]), - subtest(np.uint64, decorators=[xfail]), + subtest(np.uint16, decorators=[] if TEST_WITH_TORCHDYNAMO else [xfail]), + subtest(np.uint32, decorators=[] if TEST_WITH_TORCHDYNAMO else [xfail]), + subtest(np.uint64, decorators=[] if TEST_WITH_TORCHDYNAMO else [xfail]), ], ) def test_unsigned_max(self, T): diff --git a/test/typing/fail/arithmetic_ops.py b/test/typing/fail/arithmetic_ops.py index 3108d4b1379e6c..b3f816329445aa 100644 --- a/test/typing/fail/arithmetic_ops.py +++ b/test/typing/fail/arithmetic_ops.py @@ -7,33 +7,18 @@ # See ../pass/arithmetic_ops.py for more information -TENSOR, INT, FLOAT = randn(3), 2, 1.5 +TENSOR, FLOAT = randn(3), 1.5 -assert_type( - INT & TENSOR, # E: Unsupported operand types for & ("int" and "Tensor") [operator] - Any, -) -assert_type( - INT | TENSOR, # E: Unsupported operand types for | ("int" and "Tensor") [operator] - Any, -) -assert_type( - INT ^ TENSOR, # E: Unsupported operand types for ^ ("int" and "Tensor") [operator] - Any, -) - -assert_type( - FLOAT # E: Unsupported operand types for & ("float" and "Tensor") [operator] - & TENSOR, - Tensor, -) -assert_type( - FLOAT # E: Unsupported operand types for | ("float" and "Tensor") [operator] - | TENSOR, - Tensor, -) -assert_type( - FLOAT # E: Unsupported operand types for ^ ("float" and "Tensor") [operator] - ^ TENSOR, - Tensor, -) +FLOAT & TENSOR # E: Unsupported operand types for & ("float" and "Tensor") +FLOAT | TENSOR # E: Unsupported operand types for | ("float" and "Tensor") +FLOAT ^ TENSOR # E: Unsupported operand types for ^ ("float" and "Tensor") +# FIXME: false negatives (https://github.com/pytorch/pytorch/issues/155701) +# +# FLOAT << TENSOR # E: Unsupported operand types for & ("float" and "Tensor") +# FLOAT >> TENSOR # E: Unsupported operand types for & ("float" and "Tensor") +# +# TENSOR & FLOAT # E: Unsupported operand types for & ("Tensor" and "float" ) +# TENSOR | FLOAT # E: Unsupported operand types for | ("Tensor" and "float" ) +# TENSOR ^ FLOAT # E: Unsupported operand types for ^ ("Tensor" and "float" ) +# TENSOR << FLOAT # E: Unsupported operand types for & ("Tensor" and "float") +# TENSOR >> FLOAT # E: Unsupported operand types for & ("Tensor" and "float") diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index 4edfb73e73594f..556ef90523e947 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,154 +1,208 @@ -from typing import Any, Union +from typing import Union from typing_extensions import assert_type, TypeAlias from torch import randn, Tensor -TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True - # Test deduced types of arithmetic operations between tensors, ints, floats and bools -# The expected type should always be `Tensor`: `Any` and `bool` below are wrong. +# The expected type should always be `Tensor`, but isn't. # See https://github.com/pytorch/pytorch/issues/145838 +TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True + +# # Unary ops +# assert_type(+TENSOR, Tensor) assert_type(-TENSOR, Tensor) assert_type(~TENSOR, Tensor) -# Binary ops +# +# Binary ops that return a bolean +# +# Operator == assert_type(TENSOR == TENSOR, Tensor) -assert_type(TENSOR != TENSOR, Tensor) -assert_type(TENSOR < TENSOR, Tensor) -assert_type(TENSOR > TENSOR, Tensor) -assert_type(TENSOR <= TENSOR, Tensor) -assert_type(TENSOR >= TENSOR, Tensor) -assert_type(TENSOR + TENSOR, Tensor) -assert_type(TENSOR - TENSOR, Tensor) -assert_type(TENSOR * TENSOR, Tensor) -assert_type(TENSOR // TENSOR, Any) -assert_type(TENSOR / TENSOR, Tensor) -assert_type(TENSOR % TENSOR, Tensor) -assert_type(TENSOR**TENSOR, Tensor) -assert_type(TENSOR << TENSOR, Tensor) -assert_type(TENSOR >> TENSOR, Tensor) -assert_type(TENSOR & TENSOR, Tensor) -assert_type(TENSOR | TENSOR, Tensor) -assert_type(TENSOR ^ TENSOR, Tensor) - assert_type(TENSOR == BOOL, Tensor) +assert_type(BOOL == TENSOR, bool) # Should be Tensor +assert_type(TENSOR == INT, Tensor) +assert_type(INT == TENSOR, bool) # Should be Tensor +assert_type(TENSOR == FLOAT, Tensor) +assert_type(FLOAT == TENSOR, bool) # Should be Tensor + +# Operator != +assert_type(TENSOR != TENSOR, Tensor) assert_type(TENSOR != BOOL, Tensor) -assert_type(TENSOR < BOOL, Tensor) -assert_type(TENSOR > BOOL, Tensor) -assert_type(TENSOR <= BOOL, Tensor) -assert_type(TENSOR >= BOOL, Tensor) -assert_type(TENSOR + BOOL, Tensor) -assert_type(TENSOR - BOOL, Tensor) -assert_type(TENSOR * BOOL, Tensor) -assert_type(TENSOR // BOOL, Any) -assert_type(TENSOR / BOOL, Tensor) -assert_type(TENSOR % BOOL, Tensor) -assert_type(TENSOR**BOOL, Tensor) -assert_type(TENSOR << BOOL, Tensor) -assert_type(TENSOR >> BOOL, Tensor) -assert_type(TENSOR & BOOL, Tensor) -assert_type(TENSOR | BOOL, Tensor) -assert_type(TENSOR ^ BOOL, Tensor) +assert_type(BOOL != TENSOR, bool) # Should be Tensor +assert_type(TENSOR != INT, Tensor) +assert_type(INT != TENSOR, bool) # Should be Tensor +assert_type(TENSOR != FLOAT, Tensor) +assert_type(FLOAT != TENSOR, bool) # Should be Tensor -assert_type(BOOL == TENSOR, bool) -assert_type(BOOL != TENSOR, bool) +# Operator < +assert_type(TENSOR < TENSOR, Tensor) +assert_type(TENSOR < BOOL, Tensor) assert_type(BOOL < TENSOR, Tensor) -assert_type(BOOL > TENSOR, Tensor) -assert_type(BOOL <= TENSOR, Tensor) -assert_type(BOOL >= TENSOR, Tensor) -assert_type(BOOL + TENSOR, Tensor) -assert_type(BOOL - TENSOR, Any) -assert_type(BOOL * TENSOR, Tensor) -assert_type(BOOL // TENSOR, Any) -assert_type(BOOL / TENSOR, Any) -assert_type(BOOL % TENSOR, Any) -assert_type(BOOL**TENSOR, Any) -assert_type(BOOL << TENSOR, Any) -assert_type(BOOL >> TENSOR, Any) -assert_type(BOOL & TENSOR, Tensor) -assert_type(BOOL | TENSOR, Tensor) -assert_type(BOOL ^ TENSOR, Tensor) - -assert_type(TENSOR == INT, Tensor) -assert_type(TENSOR != INT, Tensor) assert_type(TENSOR < INT, Tensor) +assert_type(INT < TENSOR, Tensor) +assert_type(TENSOR < FLOAT, Tensor) +assert_type(FLOAT < TENSOR, Tensor) + +# Operator > +assert_type(TENSOR > TENSOR, Tensor) +assert_type(TENSOR > BOOL, Tensor) +assert_type(BOOL > TENSOR, Tensor) assert_type(TENSOR > INT, Tensor) +assert_type(INT > TENSOR, Tensor) +assert_type(TENSOR > FLOAT, Tensor) +assert_type(FLOAT > TENSOR, Tensor) + +# Operator <= +assert_type(TENSOR <= TENSOR, Tensor) +assert_type(TENSOR <= BOOL, Tensor) +assert_type(BOOL <= TENSOR, Tensor) assert_type(TENSOR <= INT, Tensor) +assert_type(INT <= TENSOR, Tensor) +assert_type(TENSOR <= FLOAT, Tensor) +assert_type(FLOAT <= TENSOR, Tensor) + +# Operator >= +assert_type(TENSOR >= TENSOR, Tensor) +assert_type(TENSOR >= BOOL, Tensor) +assert_type(BOOL >= TENSOR, Tensor) assert_type(TENSOR >= INT, Tensor) +assert_type(INT >= TENSOR, Tensor) +assert_type(TENSOR >= FLOAT, Tensor) +assert_type(FLOAT >= TENSOR, Tensor) + +# +# Binary ops that take and return ints or floats +# + +# Operator + +assert_type(TENSOR + TENSOR, Tensor) +assert_type(TENSOR + BOOL, Tensor) +assert_type(BOOL + TENSOR, Tensor) assert_type(TENSOR + INT, Tensor) +assert_type(INT + TENSOR, Tensor) +assert_type(TENSOR + FLOAT, Tensor) +assert_type(FLOAT + TENSOR, Tensor) + +# Operator - +assert_type(TENSOR - TENSOR, Tensor) +assert_type(TENSOR - BOOL, Tensor) +assert_type(BOOL - TENSOR, Tensor) assert_type(TENSOR - INT, Tensor) +assert_type(INT - TENSOR, Tensor) +assert_type(TENSOR - FLOAT, Tensor) +assert_type(FLOAT - TENSOR, Tensor) + +# Operator * +assert_type(TENSOR * TENSOR, Tensor) +assert_type(TENSOR * BOOL, Tensor) +assert_type(BOOL * TENSOR, Tensor) assert_type(TENSOR * INT, Tensor) -assert_type(TENSOR // INT, Any) +assert_type(INT * TENSOR, Tensor) +assert_type(TENSOR * FLOAT, Tensor) +assert_type(FLOAT * TENSOR, Tensor) + +# Operator // +assert_type(TENSOR // TENSOR, Tensor) +assert_type(TENSOR // BOOL, Tensor) +assert_type(BOOL // TENSOR, Tensor) +assert_type(TENSOR // INT, Tensor) +assert_type(INT // TENSOR, Tensor) +assert_type(TENSOR // FLOAT, Tensor) +assert_type(FLOAT // TENSOR, Tensor) + +# Operator / +assert_type(TENSOR / TENSOR, Tensor) +assert_type(TENSOR / BOOL, Tensor) +assert_type(BOOL / TENSOR, Tensor) assert_type(TENSOR / INT, Tensor) +assert_type(INT / TENSOR, Tensor) +assert_type(TENSOR / FLOAT, Tensor) +assert_type(FLOAT / TENSOR, Tensor) + +# Operator % +assert_type(TENSOR % TENSOR, Tensor) +assert_type(TENSOR % BOOL, Tensor) +assert_type(BOOL % TENSOR, Tensor) assert_type(TENSOR % INT, Tensor) +assert_type(INT % TENSOR, Tensor) +assert_type(TENSOR % FLOAT, Tensor) +assert_type(FLOAT % TENSOR, Tensor) + +# Operator ** +assert_type(TENSOR**TENSOR, Tensor) +assert_type(TENSOR**BOOL, Tensor) +assert_type(BOOL**TENSOR, Tensor) assert_type(TENSOR**INT, Tensor) +assert_type(INT**TENSOR, Tensor) +assert_type(TENSOR**FLOAT, Tensor) +assert_type(FLOAT**TENSOR, Tensor) + +# +# Matrix multiplication +# + +# Operator @ +assert_type(TENSOR @ TENSOR, Tensor) +assert_type(TENSOR @ BOOL, Tensor) # Should fail type checking +assert_type(BOOL @ TENSOR, Tensor) # type: ignore[operator] +assert_type(TENSOR @ INT, Tensor) # Should fail type checking +assert_type(INT @ TENSOR, Tensor) # type: ignore[operator] +assert_type(TENSOR @ FLOAT, Tensor) # Should fail type checking +assert_type(FLOAT @ TENSOR, Tensor) # type: ignore[operator] + +# +# Binary ops that take and return ints only +# + +# Operator << +assert_type(TENSOR << TENSOR, Tensor) +assert_type(TENSOR << BOOL, Tensor) +assert_type(BOOL << TENSOR, Tensor) assert_type(TENSOR << INT, Tensor) -assert_type(TENSOR >> INT, Tensor) -assert_type(TENSOR & INT, Tensor) -assert_type(TENSOR | INT, Tensor) -assert_type(TENSOR ^ INT, Tensor) +assert_type(INT << TENSOR, Tensor) +assert_type(TENSOR << FLOAT, Tensor) # Should fail type checking +assert_type(FLOAT << TENSOR, Tensor) # Should fail type checking -assert_type(INT == TENSOR, bool) -assert_type(INT != TENSOR, bool) -assert_type(INT < TENSOR, Tensor) -assert_type(INT > TENSOR, Tensor) -assert_type(INT <= TENSOR, Tensor) -assert_type(INT >= TENSOR, Tensor) -assert_type(INT + TENSOR, Tensor) -assert_type(INT - TENSOR, Any) -assert_type(INT * TENSOR, Tensor) -assert_type(INT // TENSOR, Any) -assert_type(INT / TENSOR, Any) -assert_type(INT % TENSOR, Any) -assert_type(INT**TENSOR, Any) -assert_type(INT << TENSOR, Any) -assert_type(INT >> TENSOR, Any) -assert_type(INT & TENSOR, Any) # type: ignore[operator] -assert_type(INT | TENSOR, Any) # type: ignore[operator] -assert_type(INT ^ TENSOR, Any) # type: ignore[operator] +# Operator >> +assert_type(TENSOR >> TENSOR, Tensor) +assert_type(TENSOR >> BOOL, Tensor) +assert_type(BOOL >> TENSOR, Tensor) +assert_type(TENSOR >> INT, Tensor) +assert_type(INT >> TENSOR, Tensor) +assert_type(TENSOR >> FLOAT, Tensor) # Should fail type checking +assert_type(FLOAT >> TENSOR, Tensor) # Should fail type checking -assert_type(TENSOR == FLOAT, Tensor) -assert_type(TENSOR != FLOAT, Tensor) -assert_type(TENSOR < FLOAT, Tensor) -assert_type(TENSOR > FLOAT, Tensor) -assert_type(TENSOR <= FLOAT, Tensor) -assert_type(TENSOR >= FLOAT, Tensor) -assert_type(TENSOR + FLOAT, Tensor) -assert_type(TENSOR - FLOAT, Tensor) -assert_type(TENSOR * FLOAT, Tensor) -assert_type(TENSOR // FLOAT, Any) -assert_type(TENSOR / FLOAT, Tensor) -assert_type(TENSOR % FLOAT, Tensor) -assert_type(TENSOR**FLOAT, Tensor) -assert_type(TENSOR << FLOAT, Tensor) -assert_type(TENSOR >> FLOAT, Tensor) -assert_type(TENSOR & FLOAT, Tensor) -assert_type(TENSOR | FLOAT, Tensor) -assert_type(TENSOR ^ FLOAT, Tensor) - -assert_type(FLOAT == TENSOR, bool) -assert_type(FLOAT != TENSOR, bool) -assert_type(FLOAT < TENSOR, Tensor) -assert_type(FLOAT > TENSOR, Tensor) -assert_type(FLOAT <= TENSOR, Tensor) -assert_type(FLOAT >= TENSOR, Tensor) -assert_type(FLOAT + TENSOR, Tensor) -assert_type(FLOAT - TENSOR, Any) -assert_type(FLOAT * TENSOR, Tensor) -assert_type(FLOAT // TENSOR, Any) -assert_type(FLOAT / TENSOR, Any) -assert_type(FLOAT % TENSOR, Any) -assert_type(FLOAT**TENSOR, Any) -assert_type(FLOAT << TENSOR, Any) -assert_type(FLOAT >> TENSOR, Any) +# Operator & +assert_type(TENSOR & TENSOR, Tensor) +assert_type(TENSOR & BOOL, Tensor) +assert_type(BOOL & TENSOR, Tensor) +assert_type(TENSOR & INT, Tensor) +assert_type(INT & TENSOR, Tensor) +assert_type(TENSOR & FLOAT, Tensor) # Should fail type checking assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator] + +# Operator | +assert_type(TENSOR | TENSOR, Tensor) +assert_type(TENSOR | BOOL, Tensor) +assert_type(BOOL | TENSOR, Tensor) +assert_type(TENSOR | INT, Tensor) +assert_type(INT | TENSOR, Tensor) +assert_type(TENSOR | FLOAT, Tensor) # Should fail type checking assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator] + +# Operator ^ +assert_type(TENSOR ^ TENSOR, Tensor) +assert_type(TENSOR ^ BOOL, Tensor) +assert_type(BOOL ^ TENSOR, Tensor) +assert_type(TENSOR ^ INT, Tensor) +assert_type(INT ^ TENSOR, Tensor) +assert_type(TENSOR ^ FLOAT, Tensor) # Should fail type checking assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator] @@ -373,47 +427,3 @@ def __xor__(self, other: NUMBER) -> "Binary": # type: ignore[override] assert_type(BOOL >> BINARY, Binary) assert_type(BOOL - BINARY, Binary) assert_type(BOOL ^ BINARY, Binary) - -# Tensor operators whose types could be improved -# This is the "diff" of the first and second sections. - -assert_type(BOOL // TENSOR, Any) -assert_type(FLOAT // TENSOR, Any) -assert_type(INT // TENSOR, Any) -assert_type(TENSOR // BOOL, Any) -assert_type(TENSOR // FLOAT, Any) -assert_type(TENSOR // INT, Any) -assert_type(TENSOR // TENSOR, Any) - -assert_type(BOOL**TENSOR, Any) -assert_type(FLOAT**TENSOR, Any) -assert_type(INT**TENSOR, Any) - -assert_type(BOOL - TENSOR, Any) -assert_type(FLOAT - TENSOR, Any) -assert_type(INT - TENSOR, Any) - -assert_type(BOOL / TENSOR, Any) -assert_type(FLOAT / TENSOR, Any) -assert_type(INT / TENSOR, Any) - -assert_type(BOOL % TENSOR, Any) -assert_type(FLOAT % TENSOR, Any) -assert_type(INT % TENSOR, Any) - -assert_type(BOOL << TENSOR, Any) -assert_type(FLOAT << TENSOR, Any) -assert_type(INT << TENSOR, Any) - -assert_type(BOOL >> TENSOR, Any) -assert_type(FLOAT >> TENSOR, Any) -assert_type(INT >> TENSOR, Any) - -assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator] -assert_type(INT & TENSOR, Any) # type: ignore[operator] - -assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator] -assert_type(INT | TENSOR, Any) # type: ignore[operator] - -assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator] -assert_type(INT ^ TENSOR, Any) # type: ignore[operator] diff --git a/test/typing/pass/torch_size.py b/test/typing/pass/torch_size.py index 2ea2088a52cb36..7368f9d9eed1d7 100644 --- a/test/typing/pass/torch_size.py +++ b/test/typing/pass/torch_size.py @@ -1,10 +1,8 @@ -from typing_extensions import assert_type - -from torch import Size +# mypy: enable-error-code=unused-ignore +from typing_extensions import assert_type, Never -s1 = Size([1, 2, 3]) -s2 = Size([1, 2, 3]) +from torch import Size class ZeroIndex: @@ -12,17 +10,36 @@ def __index__(self) -> int: return 0 +tup0: tuple[()] = () +tup1: tuple[int] = (1,) +tup2: tuple[int, int] = (1, 2) +tupN: tuple[int, int, int] = (1, 2, 3) +tupX: tuple[Never, ...] = tuple() +s = Size([1, 2, 3]) + +# assignability to tuple +t: tuple[int, ...] = s + # __getitem__ -assert_type(s1[0], int) -assert_type(s1[ZeroIndex()], int) -assert_type(s1[:2], Size) +assert_type(s[0], int) +assert_type(s[ZeroIndex()], int) +assert_type(s[:2], Size) # __add__ -assert_type(s1 + s2, Size) -assert_type(s1 + (1, 2), Size) -# Size has no __radd__, so tuple.__add__(right, left) is called -assert_type((1, 2) + s1, tuple[int, ...]) +assert_type(s + s, Size) +assert_type(s + tup0, Size) +assert_type(s + tup1, Size) +assert_type(s + tup2, Size) +assert_type(s + tupN, Size) +assert_type(s + tupX, Size) +# __radd__ +# NOTE: currently incorrect inference, see: https://github.com/python/mypy/issues/19006 +assert_type(tup0 + s, Size) # type: ignore[assert-type] +assert_type(tup1 + s, Size) # type: ignore[assert-type] +assert_type(tup2 + s, Size) # type: ignore[assert-type] +assert_type(tupN + s, Size) # type: ignore[assert-type] +assert_type(tupX + s, Size) # type: ignore[assert-type] # __mul__ -assert_type(s1 * 3, Size) -assert_type(s1 * ZeroIndex(), Size) -assert_type(3 * s1, Size) -assert_type(ZeroIndex() * s1, Size) +assert_type(s * 3, Size) +assert_type(s * ZeroIndex(), Size) +assert_type(3 * s, Size) +assert_type(ZeroIndex() * s, Size) diff --git a/test/typing/test_python_operators.py b/test/typing/test_python_operators.py new file mode 100644 index 00000000000000..d7146b7e580f6b --- /dev/null +++ b/test/typing/test_python_operators.py @@ -0,0 +1,172 @@ +# mypy: ignore-errors +# Owner(s): ["module: unknown"] +import token +from itertools import product +from pathlib import Path + +import torch +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) + + +MM = "@" + +BINARY_RETURNS_BOOL = "!=", "<", "<=", "==", ">", ">=" +BINARY_ACCEPTS_FLOAT_OR_INT = "%", "*", "**", "+", "-", "/", "//" +BINARY_ACCEPTS_INT_ONLY = "&", "<<", ">>", "^", "|" +BINARY_OPS = ( + *BINARY_RETURNS_BOOL, + *BINARY_ACCEPTS_FLOAT_OR_INT, + *BINARY_ACCEPTS_INT_ONLY, + MM, +) + +BINARY_RETURNS_FLOAT = ("/",) + +UNARY_ACCEPTS_FLOAT_OR_INT = "+", "-" +UNARY_ACCEPTS_INT_ONLY = ("~",) +UNARY_OPS = *UNARY_ACCEPTS_FLOAT_OR_INT, *UNARY_ACCEPTS_INT_ONLY + +PUNCTUATION = ",", ";" + +OPERATORS = *UNARY_OPS, *BINARY_OPS, *PUNCTUATION + +FLOATS = 1.5, torch.tensor((2.5, 3.5)) +INTS = 3, torch.tensor((1, 2)) +ALL = *FLOATS, *INTS + +TYPE_TEST_FILE = Path(__file__).parent / "pass/arithmetic_ops.py" + + +class TestPythonOperators(TestCase): + # Prove that UNARY_OPS, BINARY_OPS, and OPERATORS are correct and complete + def test_operators_are_correct_and_complete(self): + self.assertFalse(set(OPERATORS).difference(token.EXACT_TOKEN_TYPES)) + + unary, binary, punctuation = {}, {}, {} + + for op in token.EXACT_TOKEN_TYPES: + if op in PUNCTUATION: + punctuation[op] = True + else: + try: + unary[op] = compile(f"{op}1 ; {op}a", op, "single") + except SyntaxError: + pass + try: + binary[op] = compile(f"2 {op} 3 ; a {op} b", op, "single") + except SyntaxError: + pass + + self.assertEqual(sorted(unary), sorted(UNARY_OPS)) + self.assertEqual(sorted(binary), sorted(BINARY_OPS)) + self.assertEqual(sorted(punctuation), sorted(PUNCTUATION)) + + def test_type_tests_are_complete(self): + binary, unary = {}, [] + + with TYPE_TEST_FILE.open() as fp: + # Looking for lines like: assert_type(TENSOR ^ BOOL, Tensor) + # But not: assert_type(BOOL ^ BINARY, Binary) + lines = (i for i in fp if "TENSOR" in i) + for line in lines: + if expr := line.partition("assert_type(")[2].partition(",")[0]: + if expr[0].isalpha(): + # ** formats differently from all other operators + a, op, b = expr.replace("**", " ** ").split() + binary.setdefault(op, []).append((a, b)) + else: + unary.append(expr[0]) + + self.assertEqual(sorted(unary), sorted(UNARY_OPS)) + self.assertEqual(sorted(binary), sorted(BINARY_OPS)) + value, *values = binary.values() + self.assertEqual(values, [value] * len(values)) + + @parametrize("a, op, b", product(ALL, BINARY_OPS, ALL)) + def test_binary(self, a, op, b): + try: + r = eval(f"a {op} b") + except Exception as e: + r = e + + any_tensor = isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor) + any_float = _any_float(a, b) + returns_float = any_float or op in BINARY_RETURNS_FLOAT + + if op == MM: + if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)): + self.assertIsInstance(r, TypeError) + elif a is b: + self.assertIsInstance(r, torch.Tensor) + else: + self.assertIsInstance(r, RuntimeError) + + elif any_tensor: + if op in BINARY_ACCEPTS_INT_ONLY and any_float: + # See https://github.com/pytorch/pytorch/issues/15754 + self.assertIsInstance(r, NotImplementedError) + else: + self.assertIsInstance(r, torch.Tensor) + + if op in BINARY_RETURNS_BOOL: + self.assertEqual(r.dtype, torch.bool) + elif op in BINARY_ACCEPTS_INT_ONLY: + self.assertFalse(r.dtype.is_floating_point) + elif op in BINARY_ACCEPTS_FLOAT_OR_INT: + self.assertEqual(r.dtype.is_floating_point, returns_float) + else: + self.assertFalse("Logic error") + + elif op in BINARY_RETURNS_BOOL: + self.assertIsInstance(r, bool) + + elif op in BINARY_ACCEPTS_INT_ONLY: + if any_float: + self.assertIsInstance(r, TypeError) + else: + self.assertIsInstance(r, int) + + elif returns_float: + self.assertIsInstance(r, float) + + else: + self.assertIsInstance(r, int) + + @parametrize("op, a", product(UNARY_OPS, ALL)) + def test_unary(self, op, a): + try: + r = eval(f"{op} a") + except Exception as e: + r = e + + if op in UNARY_ACCEPTS_INT_ONLY and _any_float(a): + self.assertIsInstance(r, TypeError) + elif isinstance(a, torch.Tensor): + self.assertIsInstance(r, torch.Tensor) + elif op in UNARY_ACCEPTS_INT_ONLY: + self.assertIsInstance(r, int) + elif isinstance(a, float): + self.assertIsInstance(r, float) + else: + self.assertIsInstance(r, int) + + +def _any_float(*x): + for i in x: + if isinstance(i, float) or ( + isinstance(i, torch.Tensor) and i.dtype.is_floating_point + ): + return True + return False + + +instantiate_parametrized_tests(TestPythonOperators) + + +if __name__ == "__main__": + run_tests() diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index 7207e1aed24ece..103a71d5debb6c 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -1344,6 +1344,26 @@ def weight_int4pack_mm(a, b_int4pack, qscale, qzeros): mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) + def test_mm_with_offset(self, device): + from torch._dynamo.testing import rand_strided + + offset = 997 + a = rand_strided( + (2, 4, 128, 64), + (65536, 16384, 64, 1), + dtype=torch.float16, + device=device, + extra_size=offset, + ) + a = a.as_strided((2, 4, 128, 64), (65536, 16384, 64, 1), storage_offset=offset) + b = rand_strided( + (2, 4, 64, 256), (65536, 16384, 1, 64), dtype=torch.float16, device=device + ) + + gpu_out = torch.matmul(a, b) + cpu_out = torch.matmul(a.cpu(), b.cpu()) + self.assertEqual(gpu_out.cpu(), cpu_out) + instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) diff --git a/third_party/NVTX b/third_party/NVTX index e170594ac7cf1d..2942f167cc30c5 160000 --- a/third_party/NVTX +++ b/third_party/NVTX @@ -1 +1 @@ -Subproject commit e170594ac7cf1dac584da473d4ca9301087090c1 +Subproject commit 2942f167cc30c5e3a44a2aecd5b0d9c07ff61a07 diff --git a/third_party/VulkanMemoryAllocator b/third_party/VulkanMemoryAllocator index a6bfc237255a6b..1d8f600fd42427 160000 --- a/third_party/VulkanMemoryAllocator +++ b/third_party/VulkanMemoryAllocator @@ -1 +1 @@ -Subproject commit a6bfc237255a6bac1513f7c1ebde6d8aed6b5191 +Subproject commit 1d8f600fd424278486eade7ed3e877c99f0846b1 diff --git a/third_party/aiter b/third_party/aiter new file mode 160000 index 00000000000000..01aae101b9e5e9 --- /dev/null +++ b/third_party/aiter @@ -0,0 +1 @@ +Subproject commit 01aae101b9e5e94d6c16a9514c9fb8df99c93150 diff --git a/third_party/composable_kernel b/third_party/composable_kernel index 8086bbe3a78d93..434d19f696da62 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit 8086bbe3a78d931eb96fe12fdc014082e18d18d3 +Subproject commit 434d19f696da62c12b5372b32cbc9ba968588d7e diff --git a/third_party/cpp-httplib b/third_party/cpp-httplib index 3b6597bba913d5..3af7f2c16147f3 160000 --- a/third_party/cpp-httplib +++ b/third_party/cpp-httplib @@ -1 +1 @@ -Subproject commit 3b6597bba913d51161383657829b7e644e59c006 +Subproject commit 3af7f2c16147f3fbc6e4d717032daf505dc1652c diff --git a/third_party/eigen b/third_party/eigen deleted file mode 160000 index 3147391d946bb4..00000000000000 --- a/third_party/eigen +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3147391d946bb4b6c68edd901f2add6ac1f31f8c diff --git a/third_party/eigen_pin.txt b/third_party/eigen_pin.txt new file mode 100644 index 00000000000000..18091983f59ddd --- /dev/null +++ b/third_party/eigen_pin.txt @@ -0,0 +1 @@ +3.4.0 diff --git a/third_party/flatbuffers b/third_party/flatbuffers index 01834de25e4bf3..a2cd1ea3b6d3fe 160000 --- a/third_party/flatbuffers +++ b/third_party/flatbuffers @@ -1 +1 @@ -Subproject commit 01834de25e4bf3975a9a00e816292b1ad0fe184b +Subproject commit a2cd1ea3b6d3fee220106b5fed3f7ce8da9eb757 diff --git a/third_party/gloo b/third_party/gloo index fe67c4bea940a1..c7b7b022c124d9 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit fe67c4bea940a117ff539d23f4110efc19404edb +Subproject commit c7b7b022c124d9643957d9bd55f57ac59fce8fa2 diff --git a/third_party/kineto b/third_party/kineto index 20f652846f651f..5e7501833f1021 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 20f652846f651fcae287f667d34bcf164c99f383 +Subproject commit 5e7501833f1021ce6f618572d3baf657b6319658 diff --git a/third_party/mimalloc b/third_party/mimalloc index b66e3214d8a104..94036de6fe20bf 160000 --- a/third_party/mimalloc +++ b/third_party/mimalloc @@ -1 +1 @@ -Subproject commit b66e3214d8a104669c2ec05ae91ebc26a8f5ab78 +Subproject commit 94036de6fe20bfd8a73d4a6d142fcf532ea604d9 diff --git a/third_party/nlohmann b/third_party/nlohmann index 87cda1d6646592..55f93686c01528 160000 --- a/third_party/nlohmann +++ b/third_party/nlohmann @@ -1 +1 @@ -Subproject commit 87cda1d6646592ac5866dc703c8e1839046a6806 +Subproject commit 55f93686c01528224f448c19128836e7df245f72 diff --git a/third_party/pocketfft b/third_party/pocketfft index 9d3ab05a7fffbc..0fa0ef591e38c2 160000 --- a/third_party/pocketfft +++ b/third_party/pocketfft @@ -1 +1 @@ -Subproject commit 9d3ab05a7fffbc71a492bc6a17be034e83e8f0fe +Subproject commit 0fa0ef591e38c2758e3184c6c23e497b9f732ffa diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 16ca2b7d6cfc4c..f3cfe7166aa779 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -4e027f1e1c560d7dc7db7eb41e48bdee5fc00707 +3a9419c8bb6a98dd3e3cd473c36691fb4abeae40 diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 86a48bbc246288..5b25c5caeb1062 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -287,18 +287,3 @@ def define_tools_targets( ":autograd", ], ) - - python_test( - name = "test_torchgen_executorch", - srcs = [ - "test/test_executorch_gen.py", - "test/test_executorch_signatures.py", - "test/test_executorch_types.py", - "test/test_executorch_unboxing.py", - ], - contacts = contacts, - visibility = ["PUBLIC"], - deps = [ - torchgen_deps, - ], - ) diff --git a/tools/alerts/create_alerts.py b/tools/alerts/create_alerts.py index 97607e07fa0a6f..6b679a0306824b 100644 --- a/tools/alerts/create_alerts.py +++ b/tools/alerts/create_alerts.py @@ -11,7 +11,7 @@ from typing import Any import requests -from setuptools import distutils # type: ignore[import] +from setuptools import distutils # type: ignore[import,attr-defined] ALL_SKIPPED_THRESHOLD = 100 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fe4dd72b247d35..e2419aab268b16 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -113,7 +113,7 @@ # - `wrap_opt_if`, is a 2-argument function that accepts a tensor # variable and a boolean condition that dictates whether to save that # variable in a graph. The result of this function is `std::optional`, -# and it is `::std::nullopt` when the condition evalutes to `false`, +# and it is `::std::nullopt` when the condition evaluates to `false`, # otherwise it is the variable wrapped in `std::optional`. # For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2]) # would mean that `var_0` is saved as long as the second (grad_input_mask[1]) @@ -200,7 +200,7 @@ # Undefined Tensors are created with the default constructor `at::Tensor()`. # It is an efficient way to represent a Tensor filled with zeros because # the Tensor holds no sizing information and no Storage data is allocated. -# But consequentially, Tensor operations cannot be performed on them. +# But consequently, Tensor operations cannot be performed on them. # Therefore, your backward function should treat an undefined output grad as # a zero, and it needs to be a special case. # diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index 0fd882d00cf1ad..684290da0a7260 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -2,7 +2,7 @@ # # NOTE: If any changes are being made to the ADInplaceOrView codegen please also check # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp -# The fallback is expected to mimick this codegen, so we should keep the two in sync. +# The fallback is expected to mimic this codegen, so we should keep the two in sync. from __future__ import annotations diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index f1e0140a415546..995243a9e6b4fc 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -148,7 +148,7 @@ "mH", # these need to be an attributes in Python, not functions "nonzero(_(out|numpy))?", "set_data", - ".*_overrideable", # overrideable functions for backend extension + ".*_overrideable", # overridable functions for backend extension "data", "is_leaf", "output_nr", @@ -617,7 +617,7 @@ def load_deprecated_signatures( schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} for name in call_args: assert name in schema_args_by_name or name in known_constants, ( - f"deprecation definiton: Unrecognized value {name}" + f"deprecation definition: Unrecognized value {name}" ) # Map deprecated signature arguments to their aten signature and test diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 67f71d2df5034c..21069b4671e24e 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -17,7 +17,7 @@ # Note [Manual Backend kernels] # For these ops, we want to manually register to dispatch key Backend and -# skip codegen-ed registeration to all keys before Backend. +# skip codegen-ed registration to all keys before Backend. # For codegen this means: # - op set below must match ops with manual_kernel_registration=True in native_functions.yaml # where we skip codegen backend kernels diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 6df4d389fa55d8..f61226f25fb90f 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -336,7 +336,7 @@ def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: # This transformation is based on the observation that for element-wise functions, the Jacobian # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) # For the complex case, we use hermitian transpose and get (v.conj() J).conj() - # So here we are going to re-use the backward formula and replace two things: + # So here we are going to reuse the backward formula and replace two things: # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. # 2) all usage of an original input "foo" with its primal value "foo_p". # 3) conjugate the final result diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index cd7bc0281981f3..bfc5b80835c4b2 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -264,7 +264,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec auto& self_ = THPVariable_Unpack(self); auto memory_format = r.memoryformat(0); // avoids touching the GIL or current device if self is already contiguous - if (self_.is_contiguous(memory_format)) { + if (self_.is_contiguous_or_false(memory_format)) { // NOTE: this logic is duplicated from VariableType.cpp. Since we need to // record this call to contiguous() in the trace regardless of whether // we actually call contiguous here, we need to record this information diff --git a/tools/bazel_tools/shellwrap.sh b/tools/bazel_tools/shellwrap.sh index 1ebab29a6a73c8..712788ae09e06d 100755 --- a/tools/bazel_tools/shellwrap.sh +++ b/tools/bazel_tools/shellwrap.sh @@ -54,5 +54,5 @@ echo "Entering interactive shell at the execution root:" # quote escape all the arguments to use as a single input string cmd="'$shell' --noprofile --rcfile '$rcfile'" -# run the command in a script psuedo terminal and dump to null +# run the command in a script pseudo terminal and dump to null /usr/bin/script -c "$cmd" -q /dev/null diff --git a/tools/build_defs/buck_helpers.bzl b/tools/build_defs/buck_helpers.bzl index 2353fae91101da..aced2308ba24c7 100644 --- a/tools/build_defs/buck_helpers.bzl +++ b/tools/build_defs/buck_helpers.bzl @@ -24,7 +24,7 @@ ONLY_AVAILABLE_IN_BUCK2 = [ def filter_attributes(kwgs): keys = list(kwgs.keys()) - # drop unncessary attributes + # drop unnecessary attributes for key in keys: if key in IGNORED_ATTRIBUTES or key in ONLY_AVAILABLE_IN_BUCK2: kwgs.pop(key) diff --git a/tools/build_defs/type_defs.bzl b/tools/build_defs/type_defs.bzl index 7a905e7d6cc072..6dc5ffe42d1756 100644 --- a/tools/build_defs/type_defs.bzl +++ b/tools/build_defs/type_defs.bzl @@ -83,7 +83,7 @@ def is_bool(arg): """Checks if provided instance is a boolean value. Args: - arg: An instance ot check. type: Any + arg: An instance to check. type: Any Returns: True for boolean values, False otherwise. rtype: bool @@ -96,7 +96,7 @@ def is_number(arg): """Checks if provided instance is a number value. Args: - arg: An instance ot check. type: Any + arg: An instance to check. type: Any Returns: True for number values, False otherwise. rtype: bool @@ -109,7 +109,7 @@ def is_struct(arg): """Checks if provided instance is a struct value. Args: - arg: An instance ot check. type: Any + arg: An instance to check. type: Any Returns: True for struct values, False otherwise. rtype: bool diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 9b078696cc2b82..457b224354fb2d 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -3,8 +3,8 @@ import os import platform import subprocess -from pathlib import Path +from .optional_submodules import checkout_nccl from .setup_helpers.cmake import CMake, USE_NINJA from .setup_helpers.env import ( check_env_flag, @@ -14,19 +14,17 @@ ) -repo_root = Path(__file__).absolute().parent.parent -third_party_path = os.path.join(repo_root, "third_party") - - def _get_vc_env(vc_arch: str) -> dict[str, str]: try: - from setuptools import distutils # type: ignore[import] + from setuptools import distutils # type: ignore[import,attr-defined] return distutils._msvccompiler._get_vc_env(vc_arch) # type: ignore[no-any-return] except AttributeError: - from setuptools._distutils import _msvccompiler # type: ignore[import] + from setuptools._distutils import ( + _msvccompiler, # type: ignore[import,attr-defined] + ) - return _msvccompiler._get_vc_env(vc_arch) # type: ignore[no-any-return] + return _msvccompiler._get_vc_env(vc_arch) # type: ignore[no-any-return,attr-defined] def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]: @@ -80,39 +78,6 @@ def _create_build_env() -> dict[str, str]: return my_env -def read_nccl_pin() -> str: - nccl_file = "nccl-cu12.txt" - if os.getenv("DESIRED_CUDA", "").startswith("11") or os.getenv( - "CUDA_VERSION", "" - ).startswith("11"): - nccl_file = "nccl-cu11.txt" - nccl_pin_path = os.path.join( - repo_root, ".ci", "docker", "ci_commit_pins", nccl_file - ) - with open(nccl_pin_path) as f: - return f.read().strip() - - -def checkout_nccl() -> None: - release_tag = read_nccl_pin() - print(f"-- Checkout nccl release tag: {release_tag}") - nccl_basedir = os.path.join(third_party_path, "nccl") - if not os.path.exists(nccl_basedir): - subprocess.check_call( - [ - "git", - "clone", - "--depth", - "1", - "--branch", - release_tag, - "https://github.com/NVIDIA/nccl.git", - "nccl", - ], - cwd=third_party_path, - ) - - def build_pytorch( version: str | None, cmake_python_library: str | None, @@ -134,4 +99,20 @@ def build_pytorch( ) if cmake_only: return + build_custom_step = os.getenv("BUILD_CUSTOM_STEP") + if build_custom_step: + try: + output = subprocess.check_output( + build_custom_step, + shell=True, + stderr=subprocess.STDOUT, + text=True, + ) + print("Command output:") + print(output) + except subprocess.CalledProcessError as e: + print("Command failed with return code:", e.returncode) + print("Output (stdout and stderr):") + print(e.output) + raise cmake.build(my_env) diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index 73c9dba0090bd7..8742966aabe854 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -38,7 +38,7 @@ def get_lib_extension() -> str: return "so" if sys.platform == "darwin": return "dylib" - raise RuntimeError(f"Usupported platform {sys.platform}") + raise RuntimeError(f"Unsupported platform {sys.platform}") def create_symlinks() -> None: @@ -78,7 +78,7 @@ def create_build_plan() -> list[tuple[str, str]]: if line.startswith(": &&") and line.endswith("&& :"): line = line[4:-4] line = line.replace("-O2", "-g").replace("-O3", "-g") - # Build Metal shaders with debug infomation + # Build Metal shaders with debug information if "xcrun metal " in line and "-frecord-sources" not in line: line += " -frecord-sources -gline-tables-only" try: diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index ede6516798479e..a5a80c1c66f59b 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -68,13 +68,13 @@ # used by training, and not just the root operators. All Training ops are # also considered for inference, so these are merged into inference ops. # -# 3. Operator Depencency Graph (--dep-graph-yaml-path): A path to the +# 3. Operator Dependency Graph (--dep-graph-yaml-path): A path to the # operator dependency graph used to determine which operators depend on # which other operators for correct functioning. This is used for # generating the transitive closure of all the operators used by the # model based on the root operators when static selective build is used. # For tracing based selective build, we don't need to perform this -# transitive cloure. +# transitive closure. # # 4. Model Metadata (--model-name, --model-versions, --model-assets, # --model-backends): Self-descriptive. These are used to tell this diff --git a/tools/code_coverage/package/tool/print_report.py b/tools/code_coverage/package/tool/print_report.py index 26c20aca231aae..ea099751d74044 100644 --- a/tools/code_coverage/package/tool/print_report.py +++ b/tools/code_coverage/package/tool/print_report.py @@ -133,7 +133,7 @@ def print_file_oriented_report( coverage_percentage = print_file_summary( covered_summary, total_summary, summary_file ) - # print test condition (interested folder / tests that are successsful or failed) + # print test condition (interested folder / tests that are successful or failed) print_test_condition( tests, tests_type, @@ -204,7 +204,7 @@ def html_oriented_report() -> None: # use lcov to generate the coverage report build_folder = os.path.join(get_pytorch_folder(), "build") coverage_info_file = os.path.join(SUMMARY_FOLDER_DIR, "coverage.info") - # generage coverage report -- coverage.info in build folder + # generate coverage report -- coverage.info in build folder subprocess.check_call( [ "lcov", diff --git a/tools/config/defs.bzl b/tools/config/defs.bzl index 6ddd0e991561aa..f8a1e9dc16f265 100644 --- a/tools/config/defs.bzl +++ b/tools/config/defs.bzl @@ -27,7 +27,7 @@ def if_rocm(if_true, if_false = []): def if_sycl(if_true, if_false = []): """Helper for selecting based on the whether SYCL/ComputeCPP is configured.""" - # NOTE: Tensorflow expects some stange behavior (see their if_sycl) if we + # NOTE: Tensorflow expects some strange behavior (see their if_sycl) if we # actually plan on supporting this at some point. return select({ "//conditions:default": if_false, diff --git a/tools/coverage_plugins_package/pyproject.toml b/tools/coverage_plugins_package/pyproject.toml index 374b58cbf4636f..5f271d56ffaf24 100644 --- a/tools/coverage_plugins_package/pyproject.toml +++ b/tools/coverage_plugins_package/pyproject.toml @@ -1,6 +1,22 @@ [build-system] -requires = [ - "setuptools>=42", - "wheel" -] +requires = ["setuptools>=77"] build-backend = "setuptools.build_meta" + +[project] +name = "coverage-plugins" +version = "0.0.1" +description = "Plug-in to coverage for PyTorch JIT" +readme = "README.md" +license = "MIT" +authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] +classifiers = [ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", +] + +[project.urls] +Repository = "https://github.com/pytorch/pytorch" +"Issue Tracker" = "https://github.com/pytorch/pytorch/issues" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/tools/coverage_plugins_package/setup.py b/tools/coverage_plugins_package/setup.py deleted file mode 100644 index 10214ec72f9f9b..00000000000000 --- a/tools/coverage_plugins_package/setup.py +++ /dev/null @@ -1,27 +0,0 @@ -import setuptools # type: ignore[import] - - -with open("README.md", encoding="utf-8") as fh: - long_description = fh.read() - -setuptools.setup( - name="coverage-plugins", - version="0.0.1", - author="PyTorch Team", - author_email="packages@pytorch.org", - description="plug-in to coverage for PyTorch JIT", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/pytorch/pytorch", - project_urls={ - "Bug Tracker": "https://github.com/pytorch/pytorch/issues", - }, - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - package_dir={"": "src"}, - packages=setuptools.find_packages(where="src"), - python_requires=">=3.6", -) diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py new file mode 100644 index 00000000000000..80e72842b807f0 --- /dev/null +++ b/tools/dynamo/gb_id_mapping.py @@ -0,0 +1,468 @@ +# mypy: ignore-errors + +import argparse +import ast +import json +import re +import sys +from pathlib import Path + + +def get_source_segment(source, node): + return ast.get_source_segment(source, node) + + +def load_registry(path): + if path.exists(): + with path.open() as f: + return json.load(f) + return {} + + +def save_registry(reg, path): + with path.open("w") as f: + json.dump(reg, f, indent=2) + + +def next_gb_id(reg): + ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()] + return f"GB{(max(ids, default=0) + 1):04d}" + + +def clean_string(s): + """ + Normalizes string literals by removing formatting artifacts and escape sequences. + Handles f-strings, quotes, newlines, and other syntax elements for cleaner output. + """ + if isinstance(s, str): + # Convert f-string prefix to regular string prefix (e.g., f"hello" -> "hello") + s = re.sub(r'^f["\']', r'"', s) + # Replace quoted strings with f-prefix in the middle with a space (e.g., " f"" -> " ") + s = re.sub(r'["\'] f["\']', " ", s) + # Remove surrounding quotes, keeping only the content (e.g., "hello" -> hello) + s = re.sub(r'^["\'](.*)["\']$', r"\1", s) + # Replace any whitespace + s = " ".join(s.splitlines()) + # Replace escaped quotes with their unescaped versions + s = s.encode().decode("unicode_escape") + # Replace adjacent quoted strings with a space (e.g., " "" -> " ") + s = re.sub(r'" "', " ", s) + return s + + +def expand_hints(hints): + # Expands hint references to their actual values from graph_break_hints. + from torch._dynamo import graph_break_hints + + hint_constants = { + name: value + for name, value in graph_break_hints.__dict__.items() + if isinstance(value, list) and name.isupper() + } + + expanded_hints = [] + for hint in hints: + for name, value in hint_constants.items(): + if f"*graph_break_hints.{name}" in hint: + expanded_hints.extend(value) + break + return expanded_hints + + +def extract_info_from_keyword(source, kw): + """ + Extracts and returns the value of a keyword argument from an AST node. + + This function handles different types of AST nodes: + - If the node is a constant, it returns the constant value. + - If the node is an f-string, it reconstructs the string by + evaluating formatted values and concatenating them with string literals. + - For other types, it cleans the source segment to remove formatting artifacts. + + """ + param_source = get_source_segment(source, kw.value) + if isinstance(kw.value, ast.Constant): + return kw.value.value + elif isinstance(kw.value, ast.JoinedStr): + evaluated_context = [] + for value in kw.value.values: + if isinstance(value, ast.FormattedValue): + evaluated_context.append(f"{{{ast.unparse(value.value)}}}") + elif isinstance(value, ast.Constant): + evaluated_context.append(value.value) + return "".join(evaluated_context) + else: + return clean_string(param_source) + + +def find_unimplemented_v2_calls(path): + results = [] + path = Path(path) + + if path.is_dir(): + file_paths = path.glob("**/*.py") + else: + file_paths = [path] + + for file_path in file_paths: + with open(file_path) as f: + source = f.read() + try: + tree = ast.parse(source) + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + if node.name in ( + "unimplemented_v2", + "unimplemented_v2_with_warning", + ): + continue + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id + in ("unimplemented_v2", "unimplemented_v2_with_warning") + ): + info = { + "gb_type": None, + "context": None, + "explanation": None, + "hints": [], + } + + for kw in node.keywords: + if kw.arg in info: + info[kw.arg] = extract_info_from_keyword(source, kw) + + if info["gb_type"] is None: + continue + + if info["hints"]: + hints = info["hints"] + expanded_hints = [] + items = re.findall(r'"([^"]*)"', hints) + if items: + expanded_hints.extend(items) + + if "*graph_break_hints." in hints: + expanded_hints.extend(expand_hints([hints])) + + info["hints"] = expanded_hints + + results.append(info) + except SyntaxError: + print(f"Syntax error in {file_path}") + + return results + + +def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None): + """ + Add a new graph break type to the registry. + + Args: + gb_type: The graph break type to add + file_path: Path to the file containing the unimplemented_v2 call + registry_path: Path to the registry JSON file + """ + registry_path = Path(registry_path) + reg = load_registry(registry_path) + + existing_gb_types = {entry[0]["Gb_type"] for entry in reg.values()} + if gb_type in existing_gb_types: + print( + f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique." + ) + return False + + calls = find_unimplemented_v2_calls(Path(file_path)) + matching_call = next((call for call in calls if call["gb_type"] == gb_type), None) + + if not matching_call: + print( + f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}" + ) + return False + + gb_id = next_gb_id(reg) + reg[gb_id] = [ + { + "Gb_type": gb_type, + "Context": matching_call["context"], + "Explanation": matching_call["explanation"], + "Hints": matching_call["hints"] or [], + **({"Additional_Info": [additional_info]} if additional_info else {}), + } + ] + + save_registry(reg, registry_path) + print(f"Added {gb_type} to registry with ID {gb_id}") + return True + + +def cmd_update_gb_type( + old_gb_type, file_path, registry_path, new_gb_type=None, additional_info=None +): + """ + Update an existing graph break type in the registry by adding a new version + to the version history list. + + Args: + old_gb_type: The current graph break type to update + file_path: Path to the file containing the updated unimplemented_v2 call + registry_path: Path to the registry JSON file + new_gb_type: Optional new gb_type name to replace the old one + """ + registry_path = Path(registry_path) + reg = load_registry(registry_path) + + gb_id_map = {entry[0]["Gb_type"]: id for id, entry in reg.items()} + gb_id = gb_id_map.get(old_gb_type) + + if gb_id is None: + print(f"Error: gb_type '{old_gb_type}' not found in registry.") + return False + + search_gb_type = new_gb_type if new_gb_type else old_gb_type + calls = find_unimplemented_v2_calls(Path(file_path)) + matching_call = next( + (call for call in calls if call["gb_type"] == search_gb_type), None + ) + + if not matching_call: + print( + f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}" + ) + return False + + if ( + matching_call["gb_type"] != old_gb_type + and matching_call["gb_type"] in gb_id_map + ): + print( + f"Error: New gb_type '{matching_call['gb_type']}' already exists in registry. Please use a unique gb_type." + ) + return False + + new_entry = { + "Gb_type": matching_call["gb_type"], + "Context": matching_call["context"], + "Explanation": matching_call["explanation"], + "Hints": matching_call["hints"] or [], + } + + if additional_info: + additional_info_list = reg[gb_id][0].get("Additional_Info", []) + new_entry["Additional_Info"] = ( + additional_info_list + [additional_info] + if additional_info_list + else [additional_info] + ) + elif "Additional_Info" in reg[gb_id][0]: + new_entry["Additional_Info"] = reg[gb_id][0]["Additional_Info"] + + reg[gb_id].insert(0, new_entry) + + save_registry(reg, registry_path) + print( + f"Updated {old_gb_type} to {matching_call['gb_type']} in registry with ID {gb_id}" + ) + return True + + +def test_verify_gb_id_mapping(dynamo_dir, registry_path): + """ + Verifies that all unimplemented_v2 calls in torch/_dynamo match entries in the registry. + """ + script_dir = Path(__file__).resolve().parent + dynamo_dir = script_dir.parent.parent / "torch" / "_dynamo" + registry_path = ( + script_dir.parent.parent / "torch" / "_dynamo" / "graph_break_registry.json" + ) + + python_files = list(dynamo_dir.glob("**/*.py")) + + reg = load_registry(registry_path) + gb_type_to_entry = {entries[0]["Gb_type"]: entries[0] for _, entries in reg.items()} + + mismatches = [] + for file_path in python_files: + calls = find_unimplemented_v2_calls(file_path) + for call in calls: + gb_type = call["gb_type"] + if gb_type not in gb_type_to_entry: + mismatches.append((gb_type, file_path, "Not found in registry")) + continue + + entry = gb_type_to_entry[gb_type] + if call["context"] != entry["Context"]: + mismatches.append((gb_type, file_path, "Context mismatch")) + elif call["explanation"] != entry["Explanation"]: + mismatches.append((gb_type, file_path, "Explanation mismatch")) + elif sorted(call["hints"]) != sorted(entry["Hints"]): + mismatches.append((gb_type, file_path, "Hints mismatch")) + + if mismatches: + print( + "Found the unimplemented_v2 or unimplemented_v2_with_warning calls below that " + "don't match the registry in graph_break_registry.json." + ) + for gb_type, file_path, reason in mismatches: + print(f" - {gb_type} in {file_path}: {reason}") + + print("Please update the registry using one of these commands:") + + print( + "- If you added a new callsite: python tools/dynamo/gb_id_mapping.py add " + '"GB_TYPE" PATH_TO_FILE --additional-info "INFO"' + ) + + print( + " • GB_TYPE: The graph break type string used in your unimplemented_v2 call" + " • PATH_TO_FILE: Path to the file containing your new unimplemented_v2 call" + " • --additional-info: Optional extra information to include in the registry entry" + ) + + print( + '- If you updated an existing callsite: python tools/dynamo/gb_id_mapping.py update "GB_TYPE" PATH_TO_FILE ' + '--new_gb_type "NEW_NAME" --additional-info "INFO"' + ) + print(" • GB_TYPE: The original graph break type to update") + print(" • PATH_TO_FILE: Path to the file containing the updated call") + print(" • --new_gb_type: New name if you changed the graph break type") + print(" • --additional-info: Optional extra information to add") + print( + "- Recreate registry (Only do this if a complete reset is needed): python tools/dynamo/gb_id_mapping.py create" + ) + print( + "If you have also wrote a test for the new graph break, please update the test as well " + "using EXPECTTEST_ACCEPT=1 so the message includes the respective webpage " + ) + print( + "Note: If you've reset the entire registry file, you can force push to bypass this check." + ) + return False + + print("All unimplemented_v2 calls match the registry.") + return True + + +def create_registry(dynamo_dir, registry_path): + calls = find_unimplemented_v2_calls(dynamo_dir) + registry = {} + + gb_types = {} + for info in calls: + gb_types[info["gb_type"]] = info + + GB_ID_INDEX = 0000 + for i, (gb_type, info) in enumerate(sorted(gb_types.items()), GB_ID_INDEX): + gb_id = f"GB{i:04d}" + hints = info["hints"] + + registry[gb_id] = [ + { + "Gb_type": gb_type, + "Context": info["context"], + "Explanation": info["explanation"], + "Hints": hints if hints else [], + } + ] + + with open(registry_path, "w") as f: + json.dump(registry, f, indent=2) + + +def main(): + repo_root = Path(__file__).resolve().parent.parent.parent + registry_path = repo_root / "torch" / "_dynamo" / "graph_break_registry.json" + + try: + import torch._dynamo + + default_dynamo_dir = str(Path(torch._dynamo.__file__).parent) + except ImportError: + default_dynamo_dir = str(repo_root / "torch" / "_dynamo") + + parser = argparse.ArgumentParser(description="Manage graph break registry.") + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + + create_parser = subparsers.add_parser("create", help="Create registry from scratch") + create_parser.add_argument( + "--dynamo_dir", + type=str, + default=default_dynamo_dir, + help="Directory to search for unimplemented_v2 calls.", + ) + + add_parser = subparsers.add_parser("add", help="Add a gb_type to registry") + add_parser.add_argument("gb_type", help="The gb_type to add") + add_parser.add_argument( + "file_path", help="Path to the file containing the unimplemented_v2 call" + ) + add_parser.add_argument( + "--additional-info", help="Optional additional information to include" + ) + + update_parser = subparsers.add_parser( + "update", help="Update an existing gb_type in registry" + ) + update_parser.add_argument("gb_type", help="The gb_type to update") + update_parser.add_argument( + "file_path", + help="Path to the file containing the updated unimplemented_v2 call", + ) + update_parser.add_argument( + "--new_gb_type", help="New gb_type name if it has changed", default=None + ) + update_parser.add_argument( + "--additional-info", help="Optional additional information to include" + ) + + verify_parser = subparsers.add_parser( + "verify", help="Verify all unimplemented_v2 calls match registry entries" + ) + verify_parser.add_argument( + "--dynamo_dir", + type=str, + default=default_dynamo_dir, + help="Directory to search for unimplemented_v2 calls.", + ) + + parser.add_argument( + "--registry-path", + type=str, + default=str(registry_path), + help="Path to save the registry JSON file", + ) + + args = parser.parse_args() + + if args.command == "create": + create_registry(args.dynamo_dir, args.registry_path) + elif args.command == "add": + success = cmd_add_new_gb_type( + args.gb_type, args.file_path, args.registry_path, args.additional_info + ) + if not success: + sys.exit(1) + elif args.command == "update": + success = cmd_update_gb_type( + args.gb_type, + args.file_path, + args.registry_path, + args.new_gb_type, + args.additional_info, + ) + if not success: + sys.exit(1) + elif args.command == "verify": + success = test_verify_gb_id_mapping(args.dynamo_dir, args.registry_path) + if not success: + sys.exit(1) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/tools/extract_scripts.py b/tools/extract_scripts.py index ab64424348f02d..06bc5f744c14da 100755 --- a/tools/extract_scripts.py +++ b/tools/extract_scripts.py @@ -42,7 +42,7 @@ def extract(step: Step) -> Script | None: "bash": f"#!/usr/bin/env bash\nset -eo pipefail\n{run}", "sh": f"#!/usr/bin/env sh\nset -e\n{run}", }.get(shell, run) - return {"extension": extension, "script": script} + return {"extension": extension, "script": script} # type: ignore[typeddict-item] elif is_gh_script and gh_script is not None: return {"extension": ".js", "script": gh_script} else: diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index 9ff3f3c68d45c4..2a9cee36f7bc8c 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -300,7 +300,7 @@ def build_collectives( for _ in range(1, num_coalesced_entries): all_entries[i].pop(k) else: - # Iterate through all the ranks and check if there is a mis-match for the current entry. + # Iterate through all the ranks and check if there is a mismatch for the current entry. check_current_entry_match( all_entries, _pg_guids, diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index d836779b585f5c..dd2eb109aa563f 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -46,7 +46,7 @@ def read_dump(prefix: str, filename: str) -> dict[str, Union[str, int, list[Any] def _determine_prefix(files: list[str]) -> str: """If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to - infer the common prefix most of the time. But if we can't confidently infer, just fall back to requring the user + infer the common prefix most of the time. But if we can't confidently infer, just fall back to requiring the user to specify it """ possible_prefixes: defaultdict[str, set[int]] = defaultdict(set) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 5a5063a15978ed..73ec2a13d3be0b 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -312,7 +312,7 @@ def visualize_ops( {first_rank}, ) - # Iterate through all the ranks and check if there is a mis-match for the current entry. + # Iterate through all the ranks and check if there is a mismatch for the current entry. check_current_entry_match( all_coalesced_entries, _pg_guids, @@ -463,10 +463,10 @@ def error_analysis( match_record.candidate_idx.update(match_record.found_idx) match_record.found_idx.clear() match_record.found_ranks.clear() - elif ( - len(match_record.candidate_ranks) == 1 - and dumps_ranks == match_record.expected_ranks - ): + # We didn't see any mismatch and all expected ranks are in the dump. + elif len( + match_record.candidate_ranks + ) == 1 and match_record.expected_ranks.issubset(dumps_ranks): # case two: alltoall or alltoall_base case. if match_record.has_undecided_case: alltoall_cases = [current_entry] + [ @@ -527,6 +527,7 @@ def error_analysis( match_record.candidate_idx.update(match_record.found_idx) match_record.found_idx.clear() match_record.found_ranks.clear() + # if any element in expected_ranks not in dumps_ranks. if match_record.expected_ranks - dumps_ranks: mismatch[pg_name] += 1 logger.info( diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index aebd914eb46760..1d8abcefabfac5 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -14,7 +14,7 @@ - TODO- tracebacks aren't implemented Known Issues -- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalseced collectives +- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives unless we have the trace data from the beginning of the program. To enable confident analysis of trace buffers that do not start from zero (and to simplify the script's matching logic) we need to add more information to the recorder. - Currently, the script omits checking the 'status' of collectives. We can look for the first 'non completed' diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index a10d87faf938f4..ce92638c859e5f 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -6,7 +6,7 @@ import subprocess from pathlib import Path -from setuptools import distutils # type: ignore[import] +from setuptools import distutils # type: ignore[import,attr-defined] UNKNOWN = "Unknown" diff --git a/tools/linter/adapters/_linter.py b/tools/linter/adapters/_linter.py deleted file mode 100644 index 0b767d5fb817a0..00000000000000 --- a/tools/linter/adapters/_linter.py +++ /dev/null @@ -1,512 +0,0 @@ -from __future__ import annotations - -import argparse -import dataclasses as dc -import json -import logging -import sys -import token -from abc import ABC, abstractmethod -from argparse import Namespace -from enum import Enum -from functools import cached_property -from pathlib import Path -from tokenize import generate_tokens, TokenInfo -from typing import Any, Generic, get_args, TYPE_CHECKING -from typing_extensions import Never, Self, TypeVar - - -if TYPE_CHECKING: - from collections.abc import Iterator, Sequence - - -# Python 3.12 and up have two new token types, FSTRING_START and FSTRING_END -NO_TOKEN = -1 -FSTRING_START: int = getattr(token, "FSTRING_START", NO_TOKEN) -FSTRING_END: int = getattr(token, "FSTRING_END", NO_TOKEN) - -START_OF_LINE_TOKENS = dict.fromkeys((token.DEDENT, token.INDENT, token.NEWLINE)) -IGNORED_TOKENS = dict.fromkeys( - (token.COMMENT, token.ENDMARKER, token.ENCODING, token.NL) -) -EMPTY_TOKENS = START_OF_LINE_TOKENS | IGNORED_TOKENS - -BRACKETS = {"{": "}", "(": ")", "[": "]"} -BRACKETS_INV = {j: i for i, j in BRACKETS.items()} - -ROOT = Path(__file__).absolute().parents[3] - - -def is_name(t: TokenInfo, *names: str) -> bool: - return t.type == token.NAME and not names or t.string in names - - -def is_op(t: TokenInfo, *names: str) -> bool: - return t.type == token.OP and not names or t.string in names - - -class LintSeverity(str, Enum): - ERROR = "error" - WARNING = "warning" - ADVICE = "advice" - DISABLED = "disabled" - - -@dc.dataclass -class LintMessage: - """This is a datatype representation of the JSON that gets sent to lintrunner - as described here: - https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html - """ - - code: str - name: str - severity: LintSeverity - - char: int | None = None - description: str | None = None - line: int | None = None - original: str | None = None - path: str | None = None - replacement: str | None = None - - asdict = dc.asdict - - -@dc.dataclass -class LintResult: - """LintResult is a single result from a linter. - - Like LintMessage but the .length member allows you to make specific edits to - one location within a file, not just replace the whole file. - - Linters can generate recursive results - results that contain other results. - - For example, the annotation linter would find two results in this code sample: - - index = Union[Optional[str], int] - - And the first result, `Union[Optional[str], int]`, contains the second one, - `Optional[str]`, so the first result is recursive but the second is not. - - If --fix is selected, the linter does a cycle of tokenizing and fixing all - the non-recursive edits until no edits remain. - """ - - name: str - - line: int | None = None - char: int | None = None - replacement: str | None = None - length: int | None = None # Not in LintMessage - description: str | None = None - original: str | None = None - - is_recursive: bool = False # Not in LintMessage - - @property - def is_edit(self) -> bool: - return None not in (self.char, self.length, self.line, self.replacement) - - def apply(self, lines: list[str]) -> None: - if not ( - self.char is None - or self.length is None - or self.line is None - or self.replacement is None - ): - line = lines[self.line - 1] - before = line[: self.char] - after = line[self.char + self.length :] - lines[self.line - 1] = f"{before}{self.replacement}{after}" - - def contains(self, r: LintResult) -> bool: - assert self.char is not None and self.line is not None - assert r.char is not None and r.line is not None - return self.line == r.line and self.char <= r.char and self.end >= r.end - - @property - def end(self) -> int: - assert self.char is not None and self.length is not None - return self.char + self.length - - def as_message(self, code: str, path: str) -> LintMessage: - d = dc.asdict(self) - d.pop("is_recursive") - d.pop("length") - if self.is_edit: - # This is one of our , which we don't want to - # send to lintrunner as a replacement - d["replacement"] = None - - return LintMessage(code=code, path=path, severity=LintSeverity.ERROR, **d) - - def sort_key(self) -> tuple[int, int, str]: - line = -1 if self.line is None else self.line - char = -1 if self.char is None else self.char - return line, char, self.name - - -class ParseError(ValueError): - def __init__(self, token: TokenInfo, *args: str) -> None: - super().__init__(*args) - self.token = token - - -class ArgumentParser(argparse.ArgumentParser): - """ - Adds better help formatting and default arguments to argparse.ArgumentParser - """ - - def __init__( - self, - prog: str | None = None, - usage: str | None = None, - description: str | None = None, - epilog: str | None = None, - is_fixer: bool = False, - **kwargs: Any, - ) -> None: - super().__init__(prog, usage, description, None, **kwargs) - self._epilog = epilog - - help = "A list of files or directories to lint" - self.add_argument("files", nargs="*", help=help) - # TODO(rec): get fromfile_prefix_chars="@", type=argparse.FileType to work - - help = "Fix lint errors if possible" if is_fixer else argparse.SUPPRESS - self.add_argument("-f", "--fix", action="store_true", help=help) - - help = "Run for lintrunner and print LintMessages which aren't edits" - self.add_argument("-l", "--lintrunner", action="store_true", help=help) - - help = "Print more debug info" - self.add_argument("-v", "--verbose", action="store_true", help=help) - - def exit(self, status: int = 0, message: str | None = None) -> Never: - """ - Overriding this method is a workaround for argparse throwing away all - line breaks when printing the `epilog` section of the help message. - """ - argv = sys.argv[1:] - if self._epilog and not status and "-h" in argv or "--help" in argv: - print(self._epilog) - super().exit(status, message) - - -class OmittedLines: - """Read lines textually and find comment lines that end in 'noqa {linter_name}'""" - - omitted: set[int] - - def __init__(self, lines: Sequence[str], linter_name: str) -> None: - self.lines = lines - suffix = f"# noqa: {linter_name}" - omitted = ((i, s.rstrip()) for i, s in enumerate(lines)) - self.omitted = {i + 1 for i, s in omitted if s.endswith(suffix)} - - def __call__( - self, tokens: Sequence[TokenInfo], begin: int = 0, end: int = NO_TOKEN - ) -> bool: - if end == NO_TOKEN: - end = len(tokens) - # A token_line might span multiple physical lines - start = min((tokens[i].start[0] for i in range(begin, end)), default=0) - end = max((tokens[i].end[0] for i in range(begin, end)), default=-1) - return self.contains_lines(start, end) - - def contains_lines(self, begin: int, end: int) -> bool: - return bool(self.omitted.intersection(range(begin, end + 1))) - - -class PythonFile: - contents: str - lines: list[str] - path: Path | None - linter_name: str - - def __init__( - self, - linter_name: str, - path: Path | None = None, - contents: str | None = None, - ) -> None: - self.linter_name = linter_name - self.path = path and (path.relative_to(ROOT) if path.is_absolute() else path) - if contents is None and path is not None: - contents = path.read_text() - - self.contents = contents or "" - self.lines = self.contents.splitlines(keepends=True) - - @classmethod - def make(cls, linter_name: str, pc: Path | str | None = None) -> Self: - if isinstance(pc, Path): - return cls(linter_name, path=pc) - return cls(linter_name, contents=pc) - - def with_contents(self, contents: str) -> Self: - return self.__class__(self.linter_name, self.path, contents) - - @cached_property - def omitted(self) -> OmittedLines: - assert self.linter_name is not None - return OmittedLines(self.lines, self.linter_name) - - @cached_property - def tokens(self) -> list[TokenInfo]: - # Might raise IndentationError if the code is mal-indented - return list(generate_tokens(iter(self.lines).__next__)) - - @cached_property - def token_lines(self) -> list[list[TokenInfo]]: - """Returns lists of TokenInfo segmented by token.NEWLINE""" - token_lines: list[list[TokenInfo]] = [[]] - - for t in self.tokens: - if t.type not in (token.COMMENT, token.ENDMARKER, token.NL): - token_lines[-1].append(t) - if t.type == token.NEWLINE: - token_lines.append([]) - if token_lines and not token_lines[-1]: - token_lines.pop() - return token_lines - - @cached_property - def import_lines(self) -> list[list[int]]: - froms, imports = [], [] - for i, (t, *_) in enumerate(self.token_lines): - if t.type == token.INDENT: - break - if t.type == token.NAME: - if t.string == "from": - froms.append(i) - elif t.string == "import": - imports.append(i) - - return [froms, imports] - - @cached_property - def opening_comment_lines(self) -> int: - """The number of comments at the very top of the file.""" - it = (i for i, s in enumerate(self.lines) if not s.startswith("#")) - return next(it, 0) - - -def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: - """Returns a dictionary mapping opening to closing brackets""" - braces: dict[int, int] = {} - stack: list[int] = [] - - for i, t in enumerate(tokens): - if t.type == token.OP: - if t.string in BRACKETS: - stack.append(i) - elif inv := BRACKETS_INV.get(t.string): - if not stack: - raise ParseError(t, "Never opened") - begin = stack.pop() - - if not (stack and stack[-1] == FSTRING_START): - braces[begin] = i - - b = tokens[begin].string - if b != inv: - raise ParseError(t, f"Mismatched braces '{b}' at {begin}") - elif t.type == FSTRING_START: - stack.append(FSTRING_START) - elif t.type == FSTRING_END: - if stack.pop() != FSTRING_START: - raise ParseError(t, "Mismatched FSTRING_START/FSTRING_END") - if stack: - raise ParseError(t, "Left open") - return braces - - -class ErrorLines: - """How many lines to display before and after an error""" - - WINDOW = 5 - BEFORE = 2 - AFTER = WINDOW - BEFORE - 1 - - -PythonFileT = TypeVar("PythonFileT", bound=PythonFile) - - -class FileLinter(Generic[PythonFileT], ABC): - """The base class that all token-based linters inherit from""" - - description: str - linter_name: str - - epilog: str | None = None - is_fixer: bool = True - report_column_numbers: bool = False - - @abstractmethod - def _lint(self, python_file: PythonFileT) -> Iterator[LintResult]: - raise NotImplementedError - - def __init__(self, argv: Sequence[str] | None = None) -> None: - self.argv = argv - self.parser = ArgumentParser( - is_fixer=self.is_fixer, - description=self.description, - epilog=self.epilog, - ) - self.result_shown = False - - @classmethod - def run(cls) -> Never: - sys.exit(not cls().lint_all()) - - def lint_all(self) -> bool: - if self.args.fix and self.args.lintrunner: - raise ValueError("--fix and --lintrunner are incompatible") - - success = True - for p in self.paths: - success = self._lint_file(p) and success - return self.args.lintrunner or success - - @classmethod - def make_file(cls, pc: Path | str | None = None) -> PythonFileT: - c = cls.__orig_bases__[0] # type: ignore[attr-defined] - # See https://github.com/microsoft/pyright/issues/3442 - actual_python_file_type: PythonFileT = get_args(c)[0] - return actual_python_file_type.make(cls.linter_name, pc) - - @cached_property - def args(self) -> Namespace: - args = self.parser.parse_args(self.argv) - - return args - - @cached_property - def code(self) -> str: - return self.linter_name.upper() - - @cached_property - def paths(self) -> list[Path]: - files = [] - file_parts = (f for fp in self.args.files for f in fp.split(":")) - for f in file_parts: - if f.startswith("@"): - files.extend(Path(f[1:]).read_text().splitlines()) - elif f != "--": - files.append(f) - return sorted(Path(f) for f in files) - - def _lint_file(self, p: Path) -> bool: - if self.args.verbose: - print(p, "Reading", file=sys.stderr) - - pf = self.make_file(p) - replacement, results = self._replace(pf) - - if display := list(self._display(pf, results)): - print(*display, sep="\n") - if results and self.args.fix and pf.path and pf.contents != replacement: - pf.path.write_text(replacement) - - return not results or self.args.fix and all(r.is_edit for r in results) - - def _error(self, pf: PythonFileT, result: LintResult) -> None: - """Called on files that are unparseable""" - - def _replace(self, pf: PythonFileT) -> tuple[str, list[LintResult]]: - # Because of recursive replacements, we need to repeat replacing and reparsing - # from the inside out until all possible replacements are complete - previous_result_count = float("inf") - first_results = None - original = replacement = pf.contents - - while True: - try: - results = sorted(self._lint(pf), key=LintResult.sort_key) - except IndentationError as e: - error, (_name, lineno, column, _line) = e.args - - results = [LintResult(error, lineno, column)] - self._error(pf, *results) - - except ParseError as e: - results = [LintResult(str(e), *e.token.start)] - self._error(pf, *results) - - for i, ri in enumerate(results): - if not ri.is_recursive: - for rj in results[i + 1 :]: - if ri.contains(rj): - rj.is_recursive = True - else: - break - - first_results = first_results or results - if not results or len(results) >= previous_result_count: - break - previous_result_count = len(results) - - lines = pf.lines[:] - for r in reversed(results): - if r.is_edit and not r.is_recursive: - r.apply(lines) - replacement = "".join(lines) - - if not any(r.is_recursive for r in results): - break - pf = pf.with_contents(replacement) - - if first_results and self.args.lintrunner: - name = f"Suggested fixes for {self.linter_name}" - msg = LintResult(name=name, original=original, replacement=replacement) - first_results.append(msg) - - return replacement, first_results - - def _display(self, pf: PythonFileT, results: list[LintResult]) -> Iterator[str]: - """Emit a series of human-readable strings representing the results""" - for r in results: - if self.args.lintrunner: - msg = r.as_message(code=self.code, path=str(pf.path)) - yield json.dumps(msg.asdict(), sort_keys=True) - else: - if self.result_shown: - yield "" - else: - self.result_shown = True - if r.line is None: - yield f"{pf.path}: {r.name}" - else: - yield from (i.rstrip() for i in self._display_window(pf, r)) - - def _display_window(self, pf: PythonFileT, r: LintResult) -> Iterator[str]: - """Display a window onto the code with an error""" - if r.char is None or not self.report_column_numbers: - yield f"{pf.path}:{r.line}: {r.name}" - else: - yield f"{pf.path}:{r.line}:{r.char + 1}: {r.name}" - - begin = max((r.line or 0) - ErrorLines.BEFORE, 1) - end = min(begin + ErrorLines.WINDOW, 1 + len(pf.lines)) - - for lineno in range(begin, end): - source_line = pf.lines[lineno - 1].rstrip() - yield f"{lineno:5} | {source_line}" - if lineno == r.line: - spaces = 8 + (r.char or 0) - carets = len(source_line) if r.char is None else (r.length or 1) - yield spaces * " " + carets * "^" - - -def set_logging_level(args: argparse.Namespace, paths: Sequence[Path | str]) -> None: - if args.verbose: - level = logging.NOTSET - elif len(paths) < 1000: - level = logging.DEBUG - else: - level = logging.INFO - - fmt = "<%(threadName)s:%(levelname)s> %(message)s" - logging.basicConfig(format=fmt, level=level, stream=sys.stderr) diff --git a/tools/linter/adapters/_linter/__init__.py b/tools/linter/adapters/_linter/__init__.py new file mode 100644 index 00000000000000..136b168de3f3fc --- /dev/null +++ b/tools/linter/adapters/_linter/__init__.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import token +from pathlib import Path +from typing import Any, TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Sequence + from tokenize import TokenInfo + + +__all__ = ( + "Block", + "EMPTY_TOKENS", + "FileLinter", + "LineWithSets", + "LintResult", + "ParseError", + "PythonFile", + "ROOT", +) + +NO_TOKEN = -1 + +# Python 3.12 and up have two new token types, FSTRING_START and FSTRING_END +_START_OF_LINE_TOKENS = token.DEDENT, token.INDENT, token.NEWLINE +_IGNORED_TOKENS = token.COMMENT, token.ENDMARKER, token.ENCODING, token.NL +EMPTY_TOKENS = dict.fromkeys(_START_OF_LINE_TOKENS + _IGNORED_TOKENS) + +_LINTER = Path(__file__).absolute().parents[0] +ROOT = _LINTER.parents[3] + + +class ParseError(ValueError): + def __init__(self, token: TokenInfo, *args: str) -> None: + super().__init__(*args) + self.token = token + + +from .block import Block +from .file_linter import FileLinter +from .messages import LintResult +from .python_file import PythonFile diff --git a/tools/linter/adapters/_linter/argument_parser.py b/tools/linter/adapters/_linter/argument_parser.py new file mode 100644 index 00000000000000..29a1c18dcd4f82 --- /dev/null +++ b/tools/linter/adapters/_linter/argument_parser.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import argparse +import sys +from typing import Any +from typing_extensions import Never + + +class ArgumentParser(argparse.ArgumentParser): + """ + Adds better help formatting and default arguments to argparse.ArgumentParser + """ + + def __init__( + self, + prog: str | None = None, + usage: str | None = None, + description: str | None = None, + epilog: str | None = None, + is_fixer: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(prog, usage, description, None, **kwargs) + self._epilog = epilog + + help = "A list of files or directories to lint" + self.add_argument("files", nargs="*", help=help) + # TODO(rec): get fromfile_prefix_chars="@", type=argparse.FileType to work + + help = "Fix lint errors if possible" if is_fixer else argparse.SUPPRESS + self.add_argument("-f", "--fix", action="store_true", help=help) + + help = "Run for lintrunner and print LintMessages which aren't edits" + self.add_argument("-l", "--lintrunner", action="store_true", help=help) + + help = "Print more debug info" + self.add_argument("-v", "--verbose", action="store_true", help=help) + + def exit(self, status: int = 0, message: str | None = None) -> Never: + """ + Overriding this method is a workaround for argparse throwing away all + line breaks when printing the `epilog` section of the help message. + """ + argv = sys.argv[1:] + if self._epilog and not status and "-h" in argv or "--help" in argv: + print(self._epilog) + super().exit(status, message) diff --git a/tools/linter/adapters/_linter/block.py b/tools/linter/adapters/_linter/block.py new file mode 100644 index 00000000000000..f0417a5ff47daa --- /dev/null +++ b/tools/linter/adapters/_linter/block.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import dataclasses as dc +import itertools +import token +from enum import Enum +from functools import cached_property, total_ordering +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import Self + + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + from tokenize import TokenInfo + + +@total_ordering +@dc.dataclass +class Block: + """A block of Python code starting with either `def` or `class`""" + + class Category(str, Enum): + CLASS = "class" + DEF = "def" + + category: Category + + # The sequence of tokens that contains this Block. + # Tokens are represented in `Block` as indexes into `self.tokens` + tokens: Sequence[TokenInfo] = dc.field(repr=False) + + # The name of the function or class being defined + name: str + + # The index of the very first token in the block (the "class" or "def" keyword) + begin: int + + # The index of the first INDENT token for this block + indent: int + + # The index of the DEDENT token for this end of this block + dedent: int + + # The docstring for the block + docstring: str + + # These next members only get filled in after all blocks have been constructed + # and figure out family ties + + # The full qualified name of the block within the file. + # This is the name of this block and all its parents, joined with `.`. + full_name: str = "" + + # The index of this block within the full list of blocks in the file + index: int = 0 + + # Is this block contained within a function definition? + is_local: bool = dc.field(default=False, repr=False) + + # Is this block a function definition in a class definition? + is_method: bool = dc.field(default=False, repr=False) + + # A block index to the parent of this block, or None for a top-level block. + parent: Optional[int] = None + + # A list of block indexes for the children + children: list[int] = dc.field(default_factory=list) + + @property + def start_line(self) -> int: + return self.tokens[max(self.indent, self.index)].start[0] + + @property + def end_line(self) -> int: + return self.tokens[max(self.dedent, self.index)].start[0] + + @property + def line_count(self) -> int: + return self.end_line - self.start_line + + @property + def is_class(self) -> bool: + return self.category == Block.Category.CLASS + + @property + def display_name(self) -> str: + """A user-friendly name like 'class One' or 'def One.method()'""" + ending = "" if self.is_class else "()" + return f"{self.category.value} {self.full_name}{ending}" + + @cached_property + def decorators(self) -> list[str]: + """A list of decorators for this function or method. + + Each decorator both the @ symbol and any arguments to the decorator + but no extra whitespace. + """ + return _get_decorators(self.tokens, self.begin) + + @cached_property + def is_override(self) -> bool: + return not self.is_class and any( + d.rpartition(".")[2] == "override" for d in self.decorators + ) + + DATA_FIELDS = ( + "category", + "children", + "decorators", + "display_name", + "docstring", + "full_name", + "index", + "is_local", + "is_method", + "line_count", + "parent", + "start_line", + ) + + def as_data(self) -> dict[str, Any]: + d = {i: getattr(self, i) for i in self.DATA_FIELDS} + d["category"] = d["category"].value + return d + + @property + def is_init(self) -> bool: + return not self.is_class and self.name == "__init__" + + def contains(self, b: Block) -> bool: + return self.start_line < b.start_line and self.end_line >= b.end_line + + def __eq__(self, o: object) -> bool: + assert isinstance(o, Block) + return o.tokens is self.tokens and o.index == self.index + + def __hash__(self) -> int: + return super().__hash__() + + def __lt__(self, o: Self) -> bool: + assert isinstance(o, Block) and o.tokens is self.tokens + return o.index < self.index + + +_IGNORE = {token.COMMENT, token.DEDENT, token.INDENT, token.NL} + + +def _get_decorators(tokens: Sequence[TokenInfo], block_start: int) -> list[str]: + def decorators() -> Iterator[str]: + rev = reversed(range(block_start)) + newlines = (i for i in rev if tokens[i].type == token.NEWLINE) + newlines = itertools.chain(newlines, [-1]) # To account for the first line + + it = iter(newlines) + end = next(it, -1) # Like itertools.pairwise in Python 3.10 + for begin in it: + for i in range(begin + 1, end): + t = tokens[i] + if t.type == token.OP and t.string == "@": + useful = (t for t in tokens[i:end] if t.type not in _IGNORE) + yield "".join(s.string.strip("\n") for s in useful) + break + elif t.type not in _IGNORE: + return # A statement means no more decorators + end = begin + + out = list(decorators()) + out.reverse() + return out diff --git a/tools/linter/adapters/_linter/blocks.py b/tools/linter/adapters/_linter/blocks.py new file mode 100644 index 00000000000000..7f511d06cc9877 --- /dev/null +++ b/tools/linter/adapters/_linter/blocks.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import token +from typing import NamedTuple, TYPE_CHECKING + +from . import EMPTY_TOKENS, ParseError +from .block import Block + + +if TYPE_CHECKING: + from collections.abc import Sequence + from tokenize import TokenInfo + + +class BlocksResult(NamedTuple): + blocks: list[Block] + errors: dict[str, str] + + +def blocks(tokens: Sequence[TokenInfo]) -> BlocksResult: + blocks: list[Block] = [] + indent_to_dedent = _make_indent_dict(tokens) + errors: dict[str, str] = {} + + def starts_block(t: TokenInfo) -> bool: + return t.type == token.NAME and t.string in ("class", "def") + + it = (i for i, t in enumerate(tokens) if starts_block(t)) + blocks = [_make_block(tokens, i, indent_to_dedent, errors) for i in it] + + for i, parent in enumerate(blocks): + for j in range(i + 1, len(blocks)): + if parent.contains(child := blocks[j]): + child.parent = i + parent.children.append(j) + else: + break + + for i, b in enumerate(blocks): + b.index = i + parents = [b] + while (p := parents[-1].parent) is not None: + parents.append(blocks[p]) + parents = parents[1:] + + b.is_local = not all(p.is_class for p in parents) + b.is_method = not b.is_class and bool(parents) and parents[0].is_class + + _add_full_names(blocks, [b for b in blocks if b.parent is None]) + return BlocksResult(blocks, errors) + + +def _make_indent_dict(tokens: Sequence[TokenInfo]) -> dict[int, int]: + dedents = dict[int, int]() + stack = list[int]() + + for i, t in enumerate(tokens): + if t.type == token.INDENT: + stack.append(i) + elif t.type == token.DEDENT: + dedents[stack.pop()] = i + + return dedents + + +def _docstring(tokens: Sequence[TokenInfo], start: int) -> str: + for i in range(start + 1, len(tokens)): + tk = tokens[i] + if tk.type == token.STRING: + return tk.string + if tk.type not in EMPTY_TOKENS: + return "" + return "" + + +def _add_full_names( + blocks: Sequence[Block], children: Sequence[Block], prefix: str = "" +) -> None: + # Would be trivial except that there can be duplicate names at any level + dupes: dict[str, list[Block]] = {} + for b in children: + dupes.setdefault(b.name, []).append(b) + + for dl in dupes.values(): + for i, b in enumerate(dl): + suffix = f"[{i + 1}]" if len(dl) > 1 else "" + b.full_name = prefix + b.name + suffix + + for b in children: + if kids := [blocks[i] for i in b.children]: + _add_full_names(blocks, kids, b.full_name + ".") + + +def _make_block( + tokens: Sequence[TokenInfo], + begin: int, + indent_to_dedent: dict[int, int], + errors: dict[str, str], +) -> Block: + def next_token(start: int, token_type: int, error: str) -> int: + for i in range(start, len(tokens)): + if tokens[i].type == token_type: + return i + raise ParseError(tokens[-1], error) + + t = tokens[begin] + category = Block.Category[t.string.upper()] + indent = -1 + dedent = -1 + docstring = "" + name = "(not found)" + try: + ni = next_token(begin + 1, token.NAME, "Definition but no name") + name = tokens[ni].string + indent = next_token(ni + 1, token.INDENT, "Definition but no indent") + dedent = indent_to_dedent[indent] + docstring = _docstring(tokens, indent) + except ParseError as e: + errors[t.line] = " ".join(e.args) + + return Block( + begin=begin, + category=category, + dedent=dedent, + docstring=docstring, + indent=indent, + name=name, + tokens=tokens, + ) diff --git a/tools/linter/adapters/_linter/bracket_pairs.py b/tools/linter/adapters/_linter/bracket_pairs.py new file mode 100644 index 00000000000000..23f08c9ff7391f --- /dev/null +++ b/tools/linter/adapters/_linter/bracket_pairs.py @@ -0,0 +1,42 @@ +import token +from collections.abc import Sequence +from tokenize import TokenInfo + +from . import NO_TOKEN, ParseError + + +FSTRING_START: int = getattr(token, "FSTRING_START", NO_TOKEN) +FSTRING_END: int = getattr(token, "FSTRING_END", NO_TOKEN) + +BRACKETS = {"{": "}", "(": ")", "[": "]"} +BRACKETS_INV = {j: i for i, j in BRACKETS.items()} + + +def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: + """Returns a dictionary mapping opening to closing brackets""" + braces: dict[int, int] = {} + stack: list[int] = [] + + for i, t in enumerate(tokens): + if t.type == token.OP: + if t.string in BRACKETS: + stack.append(i) + elif inv := BRACKETS_INV.get(t.string): + if not stack: + raise ParseError(t, "Never opened") + begin = stack.pop() + + if not (stack and stack[-1] == FSTRING_START): + braces[begin] = i + + b = tokens[begin].string + if b != inv: + raise ParseError(t, f"Mismatched braces '{b}' at {begin}") + elif t.type == FSTRING_START: + stack.append(FSTRING_START) + elif t.type == FSTRING_END: + if stack.pop() != FSTRING_START: + raise ParseError(t, "Mismatched FSTRING_START/FSTRING_END") + if stack: + raise ParseError(t, "Left open") + return braces diff --git a/tools/linter/adapters/_linter/file_linter.py b/tools/linter/adapters/_linter/file_linter.py new file mode 100644 index 00000000000000..7f9c0890fbf640 --- /dev/null +++ b/tools/linter/adapters/_linter/file_linter.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import json +import sys +from abc import abstractmethod +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING +from typing_extensions import Never + +from . import ParseError +from .argument_parser import ArgumentParser +from .messages import LintResult +from .python_file import PythonFile + + +if TYPE_CHECKING: + from argparse import Namespace + from collections.abc import Iterator, Sequence + + +class ErrorLines: + """How many lines to display before and after an error""" + + WINDOW = 5 + BEFORE = 2 + AFTER = WINDOW - BEFORE - 1 + + +class FileLinter: + """The base class that all token-based linters inherit from""" + + description: str + linter_name: str + + epilog: str | None = None + is_fixer: bool = True + report_column_numbers: bool = False + + @abstractmethod + def _lint(self, python_file: PythonFile) -> Iterator[LintResult]: + raise NotImplementedError + + def __init__(self, argv: Sequence[str] | None = None) -> None: + self.argv = argv + self.parser = ArgumentParser( + is_fixer=self.is_fixer, + description=self.description, + epilog=self.epilog, + ) + self.result_shown = False + + @classmethod + def run(cls) -> Never: + sys.exit(not cls().lint_all()) + + def lint_all(self) -> bool: + if self.args.fix and self.args.lintrunner: + raise ValueError("--fix and --lintrunner are incompatible") + + success = True + for p in self.paths: + success = self._lint_file(p) and success + return self.args.lintrunner or success + + @classmethod + def make_file(cls, pc: Path | str | None = None) -> PythonFile: + return PythonFile.make(cls.linter_name, pc) + + @cached_property + def args(self) -> Namespace: + args = self.parser.parse_args(self.argv) + + return args + + @cached_property + def code(self) -> str: + return self.linter_name.upper() + + @cached_property + def paths(self) -> list[Path]: + files = [] + file_parts = (f for fp in self.args.files for f in fp.split(":")) + for f in file_parts: + if f.startswith("@"): + files.extend(Path(f[1:]).read_text().splitlines()) + elif f != "--": + files.append(f) + return sorted(Path(f) for f in files) + + def _lint_file(self, p: Path) -> bool: + if self.args.verbose: + print(p, "Reading", file=sys.stderr) + + pf = self.make_file(p) + replacement, results = self._replace(pf) + + if display := list(self._display(pf, results)): + print(*display, sep="\n") + if results and self.args.fix and pf.path and pf.contents != replacement: + pf.path.write_text(replacement) + + return not results or self.args.fix and all(r.is_edit for r in results) + + def _error(self, pf: PythonFile, result: LintResult) -> None: + """Called on files that are unparsable""" + + def _replace(self, pf: PythonFile) -> tuple[str, list[LintResult]]: + # Because of recursive replacements, we need to repeat replacing and reparsing + # from the inside out until all possible replacements are complete + previous_result_count = float("inf") + first_results = None + original = replacement = pf.contents + + while True: + try: + results = sorted(self._lint(pf), key=LintResult.sort_key) + except IndentationError as e: + error, (_name, lineno, column, _line) = e.args + + results = [LintResult(error, lineno, column)] + self._error(pf, *results) + + except ParseError as e: + results = [LintResult(str(e), *e.token.start)] + self._error(pf, *results) + + for i, ri in enumerate(results): + if not ri.is_recursive: + for rj in results[i + 1 :]: + if ri.contains(rj): + rj.is_recursive = True + else: + break + + first_results = first_results or results + if not results or len(results) >= previous_result_count: + break + previous_result_count = len(results) + + lines = pf.lines[:] + for r in reversed(results): + if r.is_edit and not r.is_recursive: + r.apply(lines) + replacement = "".join(lines) + + if not any(r.is_recursive for r in results): + break + pf = pf.with_contents(replacement) + + if first_results and self.args.lintrunner: + name = f"Suggested fixes for {self.linter_name}" + msg = LintResult(name=name, original=original, replacement=replacement) + first_results.append(msg) + + return replacement, first_results + + def _display(self, pf: PythonFile, results: list[LintResult]) -> Iterator[str]: + """Emit a series of human-readable strings representing the results""" + for r in results: + if self.args.lintrunner: + msg = r.as_message(code=self.code, path=str(pf.path)) + yield json.dumps(msg.asdict(), sort_keys=True) + else: + if self.result_shown: + yield "" + else: + self.result_shown = True + if r.line is None: + yield f"{pf.path}: {r.name}" + else: + yield from (i.rstrip() for i in self._display_window(pf, r)) + + def _display_window(self, pf: PythonFile, r: LintResult) -> Iterator[str]: + """Display a window onto the code with an error""" + if r.char is None or not self.report_column_numbers: + yield f"{pf.path}:{r.line}: {r.name}" + else: + yield f"{pf.path}:{r.line}:{r.char + 1}: {r.name}" + + begin = max((r.line or 0) - ErrorLines.BEFORE, 1) + end = min(begin + ErrorLines.WINDOW, 1 + len(pf.lines)) + + for lineno in range(begin, end): + source_line = pf.lines[lineno - 1].rstrip() + yield f"{lineno:5} | {source_line}" + if lineno == r.line: + spaces = 8 + (r.char or 0) + carets = len(source_line) if r.char is None else (r.length or 1) + yield spaces * " " + carets * "^" diff --git a/tools/linter/adapters/_linter/messages.py b/tools/linter/adapters/_linter/messages.py new file mode 100644 index 00000000000000..5408d3600185a5 --- /dev/null +++ b/tools/linter/adapters/_linter/messages.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import dataclasses as dc +from enum import Enum + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +@dc.dataclass +class LintMessage: + """This is a datatype representation of the JSON that gets sent to lintrunner + as described here: + https://docs.rs/lintrunner/latest/lintrunner/lint_message/struct.LintMessage.html + """ + + code: str + name: str + severity: LintSeverity + + char: int | None = None + description: str | None = None + line: int | None = None + original: str | None = None + path: str | None = None + replacement: str | None = None + + asdict = dc.asdict + + +@dc.dataclass +class LintResult: + """LintResult is a single result from a linter. + + Like LintMessage but the .length member allows you to make specific edits to + one location within a file, not just replace the whole file. + + Linters can generate recursive results - results that contain other results. + + For example, the annotation linter would find two results in this code sample: + + index = Union[Optional[str], int] + + And the first result, `Union[Optional[str], int]`, contains the second one, + `Optional[str]`, so the first result is recursive but the second is not. + + If --fix is selected, the linter does a cycle of tokenizing and fixing all + the non-recursive edits until no edits remain. + """ + + name: str + + line: int | None = None + char: int | None = None + replacement: str | None = None + length: int | None = None # Not in LintMessage + description: str | None = None + original: str | None = None + + is_recursive: bool = False # Not in LintMessage + + @property + def is_edit(self) -> bool: + return None not in (self.char, self.length, self.line, self.replacement) + + def apply(self, lines: list[str]) -> None: + if not ( + self.char is None + or self.length is None + or self.line is None + or self.replacement is None + ): + line = lines[self.line - 1] + before = line[: self.char] + after = line[self.char + self.length :] + lines[self.line - 1] = f"{before}{self.replacement}{after}" + + def contains(self, r: LintResult) -> bool: + assert self.char is not None and self.line is not None + assert r.char is not None and r.line is not None + return self.line == r.line and self.char <= r.char and self.end >= r.end + + @property + def end(self) -> int: + assert self.char is not None and self.length is not None + return self.char + self.length + + def as_message(self, code: str, path: str) -> LintMessage: + d = dc.asdict(self) + d.pop("is_recursive") + d.pop("length") + if self.is_edit: + # This is one of our , which we don't want to + # send to lintrunner as a replacement + d["replacement"] = None + + return LintMessage(code=code, path=path, severity=LintSeverity.ERROR, **d) + + def sort_key(self) -> tuple[int, int, str]: + line = -1 if self.line is None else self.line + char = -1 if self.char is None else self.char + return line, char, self.name diff --git a/tools/linter/adapters/_linter/python_file.py b/tools/linter/adapters/_linter/python_file.py new file mode 100644 index 00000000000000..41ebfba6ea4722 --- /dev/null +++ b/tools/linter/adapters/_linter/python_file.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import token +from functools import cached_property +from pathlib import Path +from tokenize import generate_tokens, TokenInfo +from typing import TYPE_CHECKING +from typing_extensions import Self + +from . import EMPTY_TOKENS, NO_TOKEN, ParseError, ROOT +from .blocks import blocks +from .sets import LineWithSets + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from .block import Block + + +class PythonFile: + contents: str + lines: list[str] + path: Path | None + linter_name: str + + def __init__( + self, + linter_name: str, + path: Path | None = None, + contents: str | None = None, + ) -> None: + self.linter_name = linter_name + self.path = path and (path.relative_to(ROOT) if path.is_absolute() else path) + if contents is None and path is not None: + contents = path.read_text() + + self.contents = contents or "" + self.lines = self.contents.splitlines(keepends=True) + + @classmethod + def make(cls, linter_name: str, pc: Path | str | None = None) -> Self: + if isinstance(pc, Path): + return cls(linter_name, path=pc) + return cls(linter_name, contents=pc) + + def with_contents(self, contents: str) -> Self: + return self.__class__(self.linter_name, self.path, contents) + + @cached_property + def omitted(self) -> OmittedLines: + assert self.linter_name is not None + return OmittedLines(self.lines, self.linter_name) + + @cached_property + def tokens(self) -> list[TokenInfo]: + # Might raise IndentationError if the code is mal-indented + return list(generate_tokens(iter(self.lines).__next__)) + + @cached_property + def token_lines(self) -> list[list[TokenInfo]]: + """Returns lists of TokenInfo segmented by token.NEWLINE""" + token_lines: list[list[TokenInfo]] = [[]] + + for t in self.tokens: + if t.type not in (token.COMMENT, token.ENDMARKER, token.NL): + token_lines[-1].append(t) + if t.type == token.NEWLINE: + token_lines.append([]) + if token_lines and not token_lines[-1]: + token_lines.pop() + return token_lines + + @cached_property + def import_lines(self) -> list[list[int]]: + froms, imports = [], [] + for i, (t, *_) in enumerate(self.token_lines): + if t.type == token.INDENT: + break + if t.type == token.NAME: + if t.string == "from": + froms.append(i) + elif t.string == "import": + imports.append(i) + + return [froms, imports] + + @cached_property + def opening_comment_lines(self) -> int: + """The number of comments at the very top of the file.""" + it = (i for i, s in enumerate(self.lines) if not s.startswith("#")) + return next(it, 0) + + def __getitem__(self, i: int | slice) -> TokenInfo | Sequence[TokenInfo]: + return self.tokens[i] + + def next_token(self, start: int, token_type: int, error: str) -> int: + for i in range(start, len(self.tokens)): + if self.tokens[i].type == token_type: + return i + raise ParseError(self.tokens[-1], error) + + def docstring(self, start: int) -> str: + for i in range(start + 1, len(self.tokens)): + tk = self.tokens[i] + if tk.type == token.STRING: + return tk.string + if tk.type not in EMPTY_TOKENS: + return "" + return "" + + @cached_property + def indent_to_dedent(self) -> dict[int, int]: + dedents = dict[int, int]() + stack = list[int]() + + for i, t in enumerate(self.tokens): + if t.type == token.INDENT: + stack.append(i) + elif t.type == token.DEDENT: + dedents[stack.pop()] = i + + return dedents + + @cached_property + def errors(self) -> dict[str, str]: + return {} + + @cached_property + def braced_sets(self) -> list[Sequence[TokenInfo]]: + lines = [t for tl in self._lines_with_sets for t in tl.braced_sets] + return [s for s in lines if not self.omitted(s)] + + @cached_property + def sets(self) -> list[TokenInfo]: + tokens = [t for tl in self._lines_with_sets for t in tl.sets] + return [t for t in tokens if not self.omitted([t])] + + @cached_property + def insert_import_line(self) -> int | None: + froms, imports = self.import_lines + for i in froms + imports: + tl = self.token_lines[i] + if any(i.type == token.NAME and i.string == "OrderedSet" for i in tl): + return None + if section := froms or imports: + return self._lines_with_sets[section[-1]].tokens[-1].start[0] + 1 + return self.opening_comment_lines + 1 + + @cached_property + def _lines_with_sets(self) -> list[LineWithSets]: + return [LineWithSets(tl) for tl in self.token_lines] + + @cached_property + def blocks(self) -> list[Block]: + res = blocks(self.tokens) + self.errors.update(res.errors) + return res.blocks + + +class OmittedLines: + """Read lines textually and find comment lines that end in 'noqa {linter_name}'""" + + omitted: set[int] + + def __init__(self, lines: Sequence[str], linter_name: str) -> None: + self.lines = lines + suffix = f"# noqa: {linter_name}" + omitted = ((i, s.rstrip()) for i, s in enumerate(lines)) + self.omitted = {i + 1 for i, s in omitted if s.endswith(suffix)} + + def __call__( + self, tokens: Sequence[TokenInfo], begin: int = 0, end: int = NO_TOKEN + ) -> bool: + if end == NO_TOKEN: + end = len(tokens) + # A token_line might span multiple physical lines + start = min((tokens[i].start[0] for i in range(begin, end)), default=0) + end = max((tokens[i].end[0] for i in range(begin, end)), default=-1) + return self.contains_lines(start, end) + + def contains_lines(self, begin: int, end: int) -> bool: + return bool(self.omitted.intersection(range(begin, end + 1))) diff --git a/tools/linter/adapters/_linter/sets.py b/tools/linter/adapters/_linter/sets.py new file mode 100644 index 00000000000000..0aab76876acff8 --- /dev/null +++ b/tools/linter/adapters/_linter/sets.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import dataclasses as dc +import token +from functools import cached_property +from typing import TYPE_CHECKING + +from . import EMPTY_TOKENS +from .bracket_pairs import bracket_pairs + + +if TYPE_CHECKING: + from tokenize import TokenInfo + + +@dc.dataclass +class LineWithSets: + """A logical line of Python tokens, terminated by a NEWLINE or the end of file""" + + tokens: list[TokenInfo] + + @cached_property + def sets(self) -> list[TokenInfo]: + """A list of tokens which use the built-in set symbol""" + return [t for i, t in enumerate(self.tokens) if self.is_set(i)] + + @cached_property + def braced_sets(self) -> list[list[TokenInfo]]: + """A list of lists of tokens, each representing a braced set, like {1}""" + return [ + self.tokens[b : e + 1] + for b, e in self.bracket_pairs.items() + if self.is_braced_set(b, e) + ] + + @cached_property + def bracket_pairs(self) -> dict[int, int]: + return bracket_pairs(self.tokens) + + def is_set(self, i: int) -> bool: + t = self.tokens[i] + after = i < len(self.tokens) - 1 and self.tokens[i + 1] + if t.string == "Set" and t.type == token.NAME: + return after and after.string == "[" and after.type == token.OP + return ( + (t.string == "set" and t.type == token.NAME) + and not (i and self.tokens[i - 1].string in ("def", ".")) + and not (after and after.string == "=" and after.type == token.OP) + ) + + def is_braced_set(self, begin: int, end: int) -> bool: + if ( + begin + 1 == end + or self.tokens[begin].string != "{" + or begin + and self.tokens[begin - 1].string == "in" # skip `x in {1, 2, 3}` + ): + return False + + i = begin + 1 + empty = True + while i < end: + t = self.tokens[i] + if t.type == token.OP and t.string in (":", "**"): + return False + if brace_end := self.bracket_pairs.get(i): + # Skip to the end of a subexpression + i = brace_end + elif t.type not in EMPTY_TOKENS: + empty = False + i += 1 + return not empty diff --git a/tools/linter/adapters/codespell_linter.py b/tools/linter/adapters/codespell_linter.py new file mode 100644 index 00000000000000..eb7c55081e5319 --- /dev/null +++ b/tools/linter/adapters/codespell_linter.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import argparse +import concurrent.futures +import json +import logging +import os +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import NamedTuple + + +REPO_ROOT = Path(__file__).absolute().parents[3] +PYPROJECT = REPO_ROOT / "pyproject.toml" +DICTIONARY = REPO_ROOT / "tools" / "linter" / "dictionary.txt" + +FORBIDDEN_WORDS = { + "multipy", # project pytorch/multipy is dead # codespell:ignore multipy +} + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + +def format_error_message( + filename: str, + error: Exception | None = None, + *, + message: str | None = None, +) -> LintMessage: + if message is None and error is not None: + message = ( + f"Failed due to {error.__class__.__name__}:\n{error}\n" + "Please either fix the error or " + "add the word(s) to the dictionary file (lowercase is preferred)." + ) + return LintMessage( + path=filename, + line=None, + char=None, + code="CODESPELL", + severity=LintSeverity.ERROR, + name="spelling error", + original=None, + replacement=None, + description=message, + ) + + +def run_codespell(path: Path) -> str: + try: + return subprocess.check_output( + [ + sys.executable, + "-m", + "codespell_lib", + "--toml", + str(PYPROJECT), + str(path), + ], + stderr=subprocess.STDOUT, + text=True, + encoding="utf-8", + ) + except subprocess.CalledProcessError as exc: + raise ValueError(exc.output) from exc + + +def check_file(filename: str) -> list[LintMessage]: + path = Path(filename).absolute() + try: + run_codespell(path) + except Exception as err: + return [format_error_message(filename, err)] + return [] + + +def check_dictionary(filename: str) -> list[LintMessage]: + """Check the dictionary file for duplicates.""" + path = Path(filename).absolute() + try: + words = path.read_text(encoding="utf-8").splitlines() + words_set = set(words) + if len(words) != len(words_set): + raise ValueError("The dictionary file contains duplicate entries.") + uncased_words = list(map(str.lower, words)) + if uncased_words != sorted(uncased_words): + raise ValueError( + "The dictionary file is not sorted alphabetically (case-insensitive)." + ) + for forbidden_word in sorted( + FORBIDDEN_WORDS & (words_set | set(uncased_words)) + ): + raise ValueError( + f"The dictionary file contains a forbidden word: {forbidden_word!r}. " + "Please remove it from the dictionary file and use 'codespell:ignore' " + "inline comment instead." + ) + except Exception as err: + return [format_error_message(str(filename), err)] + return [] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Check files for spelling mistakes using codespell.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(processName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + with concurrent.futures.ProcessPoolExecutor( + max_workers=os.cpu_count(), + ) as executor: + futures = {executor.submit(check_file, x): x for x in args.filenames} + futures[executor.submit(check_dictionary, str(DICTIONARY))] = str(DICTIONARY) + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/tools/linter/adapters/docstring_linter.py b/tools/linter/adapters/docstring_linter.py index f24119f06bfe2b..cc27e6be72d950 100644 --- a/tools/linter/adapters/docstring_linter.py +++ b/tools/linter/adapters/docstring_linter.py @@ -1,15 +1,11 @@ from __future__ import annotations -import dataclasses as dc import itertools import json import sys -import token -from enum import Enum -from functools import cached_property, total_ordering +from functools import cached_property from pathlib import Path -from typing import Any, Callable, Optional, TYPE_CHECKING -from typing_extensions import Self +from typing import Any, Callable, TYPE_CHECKING _FILE = Path(__file__).absolute() @@ -22,10 +18,9 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence - from tokenize import TokenInfo -GRANDFATHER_LIST = Path(str(_FILE).replace(".py", "-grandfather.json")) +GRANDFATHER_LIST = _FILE.parent / "docstring_linter-grandfather.json" # We tolerate a 10% increase in block size before demanding a docstring TOLERANCE_PERCENT = 10 @@ -45,274 +40,7 @@ ) -@total_ordering -@dc.dataclass -class Block: - """A block of Python code starting with either `def` or `class`""" - - class Category(str, Enum): - CLASS = "class" - DEF = "def" - - category: Category - - # The sequence of tokens that contains this Block. - # Tokens are represented in `Block` as indexes into `self.tokens` - tokens: Sequence[TokenInfo] = dc.field(repr=False) - - # The name of the function or class being defined - name: str - - # The index of the very first token in the block (the "class" or "def" keyword) - begin: int - - # The index of the first INDENT token for this block - indent: int - - # The index of the DEDENT token for this end of this block - dedent: int - - # The docstring for the block - docstring: str - - # These next members only get filled in after all blocks have been constructed - # and figure out family ties - - # The full qualified name of the block within the file. - # This is the name of this block and all its parents, joined with `.`. - full_name: str = "" - - # The index of this block within the full list of blocks in the file - index: int = 0 - - # Is this block contained within a function definition? - is_local: bool = dc.field(default=False, repr=False) - - # Is this block a function definition in a class definition? - is_method: bool = dc.field(default=False, repr=False) - - # A block index to the parent of this block, or None for a top-level block. - parent: Optional[int] = None - - # A list of block indexes for the children - children: list[int] = dc.field(default_factory=list) - - @property - def start_line(self) -> int: - return self.tokens[max(self.indent, self.index)].start[0] - - @property - def end_line(self) -> int: - return self.tokens[max(self.dedent, self.index)].start[0] - - @property - def line_count(self) -> int: - return self.end_line - self.start_line - - @property - def is_class(self) -> bool: - return self.category == Block.Category.CLASS - - @property - def display_name(self) -> str: - """A user-friendly name like 'class One' or 'def One.method()'""" - ending = "" if self.is_class else "()" - return f"{self.category.value} {self.full_name}{ending}" - - @cached_property - def decorators(self) -> list[str]: - """A list of decorators for this function or method. - - Each decorator both the @ symbol and any arguments to the decorator - but no extra whitespace. - """ - return _get_decorators(self.tokens, self.begin) - - @cached_property - def is_override(self) -> bool: - return not self.is_class and any( - d.rpartition(".")[2] == "override" for d in self.decorators - ) - - DATA_FIELDS = ( - "category", - "children", - "decorators", - "display_name", - "docstring", - "full_name", - "index", - "is_local", - "is_method", - "line_count", - "parent", - "start_line", - ) - - def as_data(self) -> dict[str, Any]: - d = {i: getattr(self, i) for i in self.DATA_FIELDS} - d["category"] = d["category"].value - return d - - @property - def is_init(self) -> bool: - return not self.is_class and self.name == "__init__" - - def contains(self, b: Block) -> bool: - return self.start_line < b.start_line and self.end_line >= b.end_line - - def __eq__(self, o: object) -> bool: - assert isinstance(o, Block) - return o.tokens is self.tokens and o.index == self.index - - def __hash__(self) -> int: - return super().__hash__() - - def __lt__(self, o: Self) -> bool: - assert isinstance(o, Block) and o.tokens is self.tokens - return o.index < self.index - - -_IGNORE = {token.COMMENT, token.DEDENT, token.INDENT, token.NL} - - -def _get_decorators(tokens: Sequence[TokenInfo], block_start: int) -> list[str]: - def decorators() -> Iterator[str]: - rev = reversed(range(block_start)) - newlines = (i for i in rev if tokens[i].type == token.NEWLINE) - newlines = itertools.chain(newlines, [-1]) # To account for the first line - - it = iter(newlines) - end = next(it, -1) # Like itertools.pairwise in Python 3.10 - for begin in it: - for i in range(begin + 1, end): - t = tokens[i] - if t.type == token.OP and t.string == "@": - useful = (t for t in tokens[i:end] if t.type not in _IGNORE) - yield "".join(s.string.strip("\n") for s in useful) - break - elif t.type not in _IGNORE: - return # A statement means no more decorators - end = begin - - out = list(decorators()) - out.reverse() - return out - - -class DocstringFile(_linter.PythonFile): - def __getitem__(self, i: int | slice) -> TokenInfo | Sequence[TokenInfo]: - return self.tokens[i] - - def next_token(self, start: int, token_type: int, error: str) -> int: - for i in range(start, len(self.tokens)): - if self.tokens[i].type == token_type: - return i - raise _linter.ParseError(self.tokens[-1], error) - - def docstring(self, start: int) -> str: - for i in range(start + 1, len(self.tokens)): - tk = self.tokens[i] - if tk.type == token.STRING: - return tk.string - if tk.type not in _linter.EMPTY_TOKENS: - return "" - return "" - - @cached_property - def indent_to_dedent(self) -> dict[int, int]: - dedents = dict[int, int]() - stack = list[int]() - - for i, t in enumerate(self.tokens): - if t.type == token.INDENT: - stack.append(i) - elif t.type == token.DEDENT: - dedents[stack.pop()] = i - - return dedents - - @cached_property - def errors(self) -> dict[str, str]: - return {} - - @cached_property - def blocks(self) -> list[Block]: - blocks: list[Block] = [] - - for i in range(len(self.tokens)): - try: - if (b := self.block(i)) is not None: - blocks.append(b) - except _linter.ParseError as e: - self.errors[e.token.line] = " ".join(e.args) - - for i, parent in enumerate(blocks): - for j in range(i + 1, len(blocks)): - if parent.contains(child := blocks[j]): - child.parent = i - parent.children.append(j) - else: - break - - for i, b in enumerate(blocks): - b.index = i - - parents = [b] - while (p := parents[-1].parent) is not None: - parents.append(blocks[p]) - parents = parents[1:] - - b.is_local = not all(p.is_class for p in parents) - b.is_method = not b.is_class and bool(parents) and parents[0].is_class - - def add_full_names(children: Sequence[Block], prefix: str = "") -> None: - dupes: dict[str, list[Block]] = {} - for b in children: - dupes.setdefault(b.name, []).append(b) - - for dl in dupes.values(): - for i, b in enumerate(dl): - suffix = f"[{i + 1}]" if len(dl) > 1 else "" - b.full_name = prefix + b.name + suffix - - for b in children: - if kids := [blocks[i] for i in b.children]: - add_full_names(kids, b.full_name + ".") - - add_full_names([b for b in blocks if b.parent is None]) - return blocks - - def block(self, begin: int) -> Block | None: - t = self.tokens[begin] - if not (t.type == token.NAME and t.string in ("class", "def")): - return None - - category = Block.Category[t.string.upper()] - try: - ni = self.next_token(begin + 1, token.NAME, "Definition but no name") - name = self.tokens[ni].string - indent = self.next_token(ni + 1, token.INDENT, "Definition but no indent") - dedent = self.indent_to_dedent[indent] - docstring = self.docstring(indent) - except _linter.ParseError: - name = "(ParseError)" - indent = -1 - dedent = -1 - docstring = "" - - return Block( - begin=begin, - category=category, - dedent=dedent, - docstring=docstring, - indent=indent, - name=name, - tokens=self.tokens, - ) - - -class DocstringLinter(_linter.FileLinter[DocstringFile]): +class DocstringLinter(_linter.FileLinter): linter_name = "docstring_linter" description = DESCRIPTION is_fixer = False @@ -332,26 +60,26 @@ def lint_all(self) -> bool: self._write_grandfather() return success - def _lint(self, df: DocstringFile) -> Iterator[_linter.LintResult]: - if (p := str(df.path)) in self.path_to_blocks: + def _lint(self, pf: _linter.PythonFile) -> Iterator[_linter.LintResult]: + if (p := str(pf.path)) in self.path_to_blocks: print("Repeated file", p, file=sys.stderr) return - blocks = df.blocks - bad = {b for b in blocks if self._is_bad_block(b, df)} + blocks = pf.blocks + bad = {b for b in blocks if self._is_bad_block(b, pf)} bad = self._dont_require_constructor_and_class_docs(blocks, bad) - gf = self._grandfathered(df.path, bad) + gf = self._grandfathered(pf.path, bad) - yield from (self._block_result(b, df) for b in sorted(bad - gf)) + yield from (self._block_result(b, pf) for b in sorted(bad - gf)) - def as_data(b: Block) -> dict[str, Any]: + def as_data(b: _linter.Block) -> dict[str, Any]: status = "grandfather" if b in gf else "bad" if b in bad else "good" return {"status": status, **b.as_data()} self.path_to_blocks[p] = [as_data(b) for b in blocks] - def _error(self, df: DocstringFile, result: _linter.LintResult) -> None: - self.path_to_errors[str(df.path)] = [{str(result.line): result.name}] + def _error(self, pf: _linter.PythonFile, result: _linter.LintResult) -> None: + self.path_to_errors[str(pf.path)] = [{str(result.line): result.name}] @cached_property def _grandfather(self) -> dict[str, dict[str, Any]]: @@ -368,20 +96,24 @@ def _grandfather(self) -> dict[str, dict[str, Any]]: def _max_lines(self) -> dict[str, int]: return {"class": self.args.max_class, "def": self.args.max_def} - def _grandfathered(self, path: Path | None, bad: set[Block]) -> set[Block]: + def _grandfathered( + self, path: Path | None, bad: set[_linter.Block] + ) -> set[_linter.Block]: if path is None or self.args.no_grandfather or self.args.write_grandfather: return set() grand: dict[str, int] = self._grandfather.get(str(path), {}) tolerance_ratio = 1 + self.args.grandfather_tolerance / 100.0 - def grandfathered(b: Block) -> bool: + def grandfathered(b: _linter.Block) -> bool: lines = int(grand.get(b.display_name, 0) * tolerance_ratio) return b.line_count <= lines return {b for b in bad if grandfathered(b)} - def _block_result(self, b: Block, df: DocstringFile) -> _linter.LintResult: + def _block_result( + self, b: _linter.Block, pf: _linter.PythonFile + ) -> _linter.LintResult: def_name = "function" if b.category == "def" else "class" msg = f"docstring found for {def_name} '{b.name}' ({b.line_count} lines)" if len(b.docstring): @@ -392,23 +124,23 @@ def _block_result(self, b: Block, df: DocstringFile) -> _linter.LintResult: msg = f"No {msg}" if b.is_method: msg = f"{msg}. {METHOD_OVERRIDE_HINT}" - return _linter.LintResult(msg, *df.tokens[b.begin].start) + return _linter.LintResult(msg, *pf.tokens[b.begin].start) def _display( - self, df: DocstringFile, results: list[_linter.LintResult] + self, pf: _linter.PythonFile, results: list[_linter.LintResult] ) -> Iterator[str]: if not self.args.report: - yield from super()._display(df, results) + yield from super()._display(pf, results) def _dont_require_constructor_and_class_docs( - self, blocks: Sequence[Block], bad: set[Block] - ) -> set[Block]: + self, blocks: Sequence[_linter.Block], bad: set[_linter.Block] + ) -> set[_linter.Block]: if self.args.lint_init: return bad good = {b for b in blocks if len(b.docstring) >= self.args.min_docstring} - def has_class_init_doc(b: Block) -> bool: + def has_class_init_doc(b: _linter.Block) -> bool: if b.is_class: # Is it a class whose constructor is documented? children = (blocks[i] for i in b.children) @@ -419,10 +151,10 @@ def has_class_init_doc(b: Block) -> bool: return {b for b in bad if not has_class_init_doc(b)} - def _is_bad_block(self, b: Block, df: DocstringFile) -> bool: + def _is_bad_block(self, b: _linter.Block, pf: _linter.PythonFile) -> bool: max_lines = self._max_lines[b.category] return ( - not df.omitted(df.tokens, b.begin, b.dedent) + not pf.omitted(pf.tokens, b.begin, b.dedent) and b.line_count > max_lines and len(b.docstring) < self.args.min_docstring and (self.args.lint_local or not b.is_local) diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index 764d4613a8f41d..fa2c43d6285c48 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -13,17 +13,20 @@ import time -def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: +def run_command( + args: list[str], + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[str]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: - return subprocess.run(args, check=True) + return subprocess.run(args, env=env, text=True, encoding="utf-8", check=True) finally: end_time = time.monotonic() logging.debug("took %dms", (end_time - start_time) * 1000) -if __name__ == "__main__": +def main() -> None: parser = argparse.ArgumentParser(description="pip initializer") parser.add_argument( "packages", @@ -52,10 +55,20 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: stream=sys.stderr, ) - uv_available = shutil.which("uv") is not None + env: dict[str, str] = { + **os.environ, + "UV_PYTHON": sys.executable, + "UV_PYTHON_DOWNLOADS": "never", + "FORCE_COLOR": "1", + "CLICOLOR_FORCE": "1", + } + uv_index_url = env.get("UV_INDEX_URL", env.get("PIP_EXTRA_INDEX_URL")) + if uv_index_url: + env["UV_INDEX_URL"] = uv_index_url - if uv_available: - pip_args = ["uv", "pip", "install"] + uv: str | None = shutil.which("uv") + if uv: + pip_args = [uv, "pip", "install"] elif sys.executable: pip_args = [sys.executable, "-mpip", "install"] else: @@ -89,4 +102,8 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: print(f"Would have run: {pip_args}") sys.exit(0) - run_command(pip_args) + run_command(pip_args, env=env) + + +if __name__ == "__main__": + main() diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index e714d70a31d5df..3b0f5d60a1e5f9 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -38,30 +38,22 @@ # torchgen/** # test/** # test/[a-h]*/** - "test/[a-h]*/**", # test/[i-j]*/** - "test/[i-j]*/**", - # test/[k-n]*/** - "test/[k-n]*/**", + "test/j*/**", + # test/[k-m]*/** + "test/[k-m]*/**", # test/optim/** - "test/optim/**", # "test/[p-z]*/**", "test/[p-z]*/**", # torch/** # torch/_[a-c]*/** - "torch/_[a-c]*/**", # torch/_[e-h]*/** - "torch/_[e-h]*/**", # torch/_i*/** # torch/_[j-z]*/** - "torch/_[j-z]*/**", # torch/[a-c]*/** - "torch/[a-c]*/**", # torch/d*/** - # torch/[e-n]*/** - "torch/[e-n]*/**", + # torch/[e-m]*/** # torch/optim/** - "torch/optim/**", # torch/[p-z]*/** "torch/[p-z]*/**", ], diff --git a/tools/linter/adapters/set_linter.py b/tools/linter/adapters/set_linter.py index 3497684410c6ce..2a9331f4c90872 100644 --- a/tools/linter/adapters/set_linter.py +++ b/tools/linter/adapters/set_linter.py @@ -1,9 +1,6 @@ from __future__ import annotations -import dataclasses as dc import sys -import token -from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING @@ -17,8 +14,7 @@ import _linter if TYPE_CHECKING: - from collections.abc import Iterator, Sequence - from tokenize import TokenInfo + from collections.abc import Iterator ERROR = "Builtin `set` is deprecated" @@ -73,104 +69,22 @@ """ -@dc.dataclass -class LineWithSets: - """A logical line of Python tokens, terminated by a NEWLINE or the end of file""" - - tokens: list[TokenInfo] - - @cached_property - def sets(self) -> list[TokenInfo]: - """A list of tokens which use the built-in set symbol""" - return [t for i, t in enumerate(self.tokens) if self.is_set(i)] - - @cached_property - def braced_sets(self) -> list[list[TokenInfo]]: - """A list of lists of tokens, each representing a braced set, like {1}""" - return [ - self.tokens[b : e + 1] - for b, e in self.bracket_pairs.items() - if self.is_braced_set(b, e) - ] - - @cached_property - def bracket_pairs(self) -> dict[int, int]: - return _linter.bracket_pairs(self.tokens) - - def is_set(self, i: int) -> bool: - t = self.tokens[i] - after = i < len(self.tokens) - 1 and self.tokens[i + 1] - if t.string == "Set" and t.type == token.NAME: - return after and after.string == "[" and after.type == token.OP - return ( - (t.string == "set" and t.type == token.NAME) - and not (i and self.tokens[i - 1].string in ("def", ".")) - and not (after and after.string == "=" and after.type == token.OP) - ) - - def is_braced_set(self, begin: int, end: int) -> bool: - if begin + 1 == end or self.tokens[begin].string != "{": - return False - if begin and self.tokens[begin - 1].string == "in": - return False # skip `x in {1, 2, 3}` - i = begin + 1 - empty = True - while i < end: - t = self.tokens[i] - if t.type == token.OP and t.string in (":", "**"): - return False - if brace_end := self.bracket_pairs.get(i): - # Skip to the end of a subexpression - i = brace_end - elif t.type not in _linter.EMPTY_TOKENS: - empty = False - i += 1 - return not empty - - -class SetFile(_linter.PythonFile): - @cached_property - def braced_sets(self) -> list[Sequence[TokenInfo]]: - lines = [t for tl in self._lines_with_sets for t in tl.braced_sets] - return [s for s in lines if not self.omitted(s)] - - @cached_property - def sets(self) -> list[TokenInfo]: - tokens = [t for tl in self._lines_with_sets for t in tl.sets] - return [t for t in tokens if not self.omitted([t])] - - @cached_property - def insert_import_line(self) -> int | None: - froms, imports = self.import_lines - for i in froms + imports: - tl = self.token_lines[i] - if any(i.type == token.NAME and i.string == "OrderedSet" for i in tl): - return None - if section := froms or imports: - return self._lines_with_sets[section[-1]].tokens[-1].start[0] + 1 - return self.opening_comment_lines + 1 - - @cached_property - def _lines_with_sets(self) -> list[LineWithSets]: - return [LineWithSets(tl) for tl in self.token_lines] - - -class SetLinter(_linter.FileLinter[SetFile]): +class SetLinter(_linter.FileLinter): linter_name = "set_linter" description = DESCRIPTION epilog = EPILOG report_column_numbers = True - def _lint(self, sf: SetFile) -> Iterator[_linter.LintResult]: - if (sf.sets or sf.braced_sets) and (ins := sf.insert_import_line) is not None: + def _lint(self, pf: _linter.PythonFile) -> Iterator[_linter.LintResult]: + if (pf.sets or pf.braced_sets) and (ins := pf.insert_import_line) is not None: yield _linter.LintResult( "Add import for OrderedSet", ins, 0, IMPORT_LINE, 0 ) - for b in sf.braced_sets: + for b in pf.braced_sets: yield _linter.LintResult(ERROR, *b[0].start, "OrderedSet([", 1) yield _linter.LintResult(ERROR, *b[-1].start, "])", 1) - for s in sf.sets: + for s in pf.sets: yield _linter.LintResult(ERROR, *s.start, "OrderedSet", 3) diff --git a/tools/linter/clang_tidy/generate_build_files.py b/tools/linter/clang_tidy/generate_build_files.py index af322e754b877e..157b36d2257aea 100644 --- a/tools/linter/clang_tidy/generate_build_files.py +++ b/tools/linter/clang_tidy/generate_build_files.py @@ -31,7 +31,8 @@ def gen_compile_commands() -> None: os.environ["USE_PRECOMPILED_HEADERS"] = "1" os.environ["CC"] = "clang" os.environ["CXX"] = "clang++" - run_cmd([sys.executable, "setup.py", "--cmake-only", "build"]) + os.environ["CMAKE_ONLY"] = "1" + run_cmd([sys.executable, "setup.py", "build"]) def run_autogen() -> None: diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt new file mode 100644 index 00000000000000..89e6c8645a1bf5 --- /dev/null +++ b/tools/linter/dictionary.txt @@ -0,0 +1,42 @@ +ans +belows +BU +contiguities +contiguity +coo +Din +Dout +dOut +ElementE +followings +fro +froms +Halfs +hsa +nd +nin +NotIn +nout +NowNs +numer +oH +optins +ot +overrideable +oW +padD +ptd +rebuild +rebuilt +reenable +reenabled +requestor +serde +serder +serdes +statics +strat +supercede +supercedes +te +WONT diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index d62d622326a4cd..f90d33c5ba4529 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -25,8 +25,8 @@ selected_kernel_dtypes_h_template_str = """ #include -#include #include +#include namespace at { inline constexpr bool should_include_kernel_dtype( diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py index 135a6167e3a4d5..7a28c72a6caf21 100644 --- a/tools/lldb/deploy_debugger.py +++ b/tools/lldb/deploy_debugger.py @@ -25,7 +25,7 @@ stem = Path(name).stem with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf: tf.write(r) - print("torch_deploy registering debug inforation for ", tf.name) + print("torch_deploy registering debug information for ", tf.name) cmd1 = f"target modules add {tf.name}" # print(cmd1) lldb.debugger.HandleCommand(cmd1) diff --git a/tools/nightly.py b/tools/nightly.py index 7de5d8c4c58f1b..0ed8cfe165aa95 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -9,7 +9,12 @@ $ ./tools/nightly.py checkout -b my-nightly-branch $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows -Or if you would like to re-use an existing virtual environment, you can pass in +Or if you would like to check out the nightly commit in detached HEAD mode:: + + $ ./tools/nightly.py checkout + $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows + +Or if you would like to reuse an existing virtual environment, you can pass in the prefix argument (--prefix):: $ ./tools/nightly.py checkout -b my-nightly-branch -p my-env @@ -92,6 +97,7 @@ LOGGER: logging.Logger | None = None +VERBOSE: bool = False DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss" SHA1_RE = re.compile(r"(?P[0-9a-fA-F]{40})") USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@") @@ -123,23 +129,25 @@ class PipSource(NamedTuple): supported_platforms={"Linux", "macOS", "Windows"}, accelerator="cpu", ), - "cuda-11.8": PipSource( - name="cuda-11.8", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu118", + # NOTE: Sync with CUDA_ARCHES in .github/scripts/generate_binary_build_matrix.py + "cuda-12.6": PipSource( + name="cuda-12.6", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu126", supported_platforms={"Linux", "Windows"}, accelerator="cuda", ), - "cuda-12.4": PipSource( - name="cuda-12.4", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu124", + "cuda-12.8": PipSource( + name="cuda-12.8", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu128", supported_platforms={"Linux", "Windows"}, accelerator="cuda", ), - "cuda-12.6": PipSource( - name="cuda-12.6", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu126", - supported_platforms={"Linux", "Windows"}, - accelerator="cuda", + # NOTE: Sync with ROCM_ARCHES in .github/scripts/generate_binary_build_matrix.py + "rocm-6.3": PipSource( + name="rocm-6.3", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm6.3", + supported_platforms={"Linux"}, + accelerator="rocm", ), "rocm-6.4": PipSource( name="rocm-6.4", @@ -218,7 +226,14 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: class Venv: """Virtual environment manager""" - AGGRESSIVE_UPDATE_PACKAGES = ("pip", "setuptools", "packaging", "wheel") + AGGRESSIVE_UPDATE_PACKAGES = ( + "uv", + "pip", + "setuptools", + "packaging", + "wheel", + "build[uv]", + ) def __init__( self, @@ -231,21 +246,36 @@ def __init__( self.pip_source = pip_source self.base_executable = Path(base_executable or sys.executable).absolute() self._executable: Path | None = None - self._env = {"PIP_EXTRA_INDEX_URL": self.pip_source.index_url} + self._bindir: Path | None = None + self._env = { + "PIP_EXTRA_INDEX_URL": self.pip_source.index_url, + "UV_INDEX": self.pip_source.index_url, + "UV_PYTHON_DOWNLOADS": "never", + "FORCE_COLOR": "1", + "CLICOLOR_FORCE": "1", + } def is_venv(self) -> bool: """Check if the prefix is a virtual environment.""" return self.prefix.is_dir() and (self.prefix / "pyvenv.cfg").is_file() + @property + def bindir(self) -> Path: + """Get the bin directory for the virtual environment.""" + assert self.is_venv() + if self._bindir is None: + if WINDOWS: + self._bindir = self.prefix / "Scripts" + else: + self._bindir = self.prefix / "bin" + return self._bindir + @property def executable(self) -> Path: """Get the Python executable for the virtual environment.""" assert self.is_venv() if self._executable is None: - if WINDOWS: - executable = self.prefix / "Scripts" / "python.exe" - else: - executable = self.prefix / "bin" / "python" + executable = self.bindir / ("python.exe" if WINDOWS else "python") assert executable.is_file() or executable.is_symlink() assert os.access(executable, os.X_OK), f"{executable} is not executable" self._executable = executable @@ -400,9 +430,10 @@ def python( python = self.executable cmd = [str(python), *args] env = popen_kwargs.pop("env", None) or {} + check = popen_kwargs.pop("check", True) return subprocess.run( cmd, - check=True, + check=check, text=True, encoding="utf-8", env={**self._env, **env}, @@ -433,6 +464,54 @@ def base_python_version(self) -> str: """Get the Python version for the base environment.""" return self.python_version(python=self.base_executable) + def uv( + self, + *args: str, + python: Path | str | None = None, + **popen_kwargs: Any, + ) -> subprocess.CompletedProcess[str]: + """Run a uv command in the virtual environment.""" + if python is None: + python = self.executable + cmd = [str(self.bindir / "uv"), *args] + env = popen_kwargs.pop("env", None) or {} + check = popen_kwargs.pop("check", True) + return subprocess.run( + cmd, + check=check, + text=True, + encoding="utf-8", + env={**self._env, **env, "UV_PYTHON": str(python)}, + **popen_kwargs, + ) + + @timed("Installing packages") + def uv_pip_install( + self, + *packages: str, + prerelease: bool = False, + upgrade: bool = False, + no_deps: bool = False, + **popen_kwargs: Any, + ) -> subprocess.CompletedProcess[str]: + """Run a pip install command in the virtual environment.""" + uv_pip_args = [] + if VERBOSE: + uv_pip_args.append("-v") + if prerelease: + uv_pip_args.append("--prerelease") + if upgrade: + uv_pip_args.append("--upgrade") + verb = "Upgrading" + else: + verb = "Installing" + if no_deps: + uv_pip_args.append("--no-deps") + print(f"{verb} package(s) ({self.pip_source.index_url}):") + for package in packages: + print(f" - {os.path.basename(package)}") + return self.uv("pip", "install", *uv_pip_args, *packages, **popen_kwargs) + def pip(self, *args: str, **popen_kwargs: Any) -> subprocess.CompletedProcess[str]: """Run a pip command in the virtual environment.""" return self.python("-m", "pip", *args, **popen_kwargs) @@ -443,28 +522,33 @@ def pip_install( *packages: str, prerelease: bool = False, upgrade: bool = False, + no_deps: bool = False, **popen_kwargs: Any, ) -> subprocess.CompletedProcess[str]: """Run a pip install command in the virtual environment.""" + pip_args = [] + if VERBOSE: + pip_args.append("-v") + if prerelease: + pip_args.append("--pre") if upgrade: - args = ["--upgrade", *packages] + pip_args.append("--upgrade") verb = "Upgrading" else: - args = list(packages) verb = "Installing" - if prerelease: - args = ["--pre", *args] - print( - f"{verb} package(s) ({self.pip_source.index_url}): " - f"{', '.join(map(os.path.basename, packages))}" - ) - return self.pip("install", *args, **popen_kwargs) + if no_deps: + pip_args.append("--no-deps") + print(f"{verb} package(s) ({self.pip_source.index_url}):") + for package in packages: + print(f" - {os.path.basename(package)}") + return self.pip("install", *pip_args, *packages, **popen_kwargs) @timed("Downloading packages") def pip_download( self, *packages: str, prerelease: bool = False, + no_deps: bool = False, **popen_kwargs: Any, ) -> list[Path]: """Download a package in the virtual environment.""" @@ -475,11 +559,14 @@ def pip_download( f"Downloading package(s) ({self.pip_source.index_url}): " f"{', '.join(packages)}" ) + pip_args = [] + if VERBOSE: + pip_args.append("-v") if prerelease: - args = ["--pre", *packages] - else: - args = list(packages) - self.pip("download", "--dest", str(tempdir), *args, **popen_kwargs) + pip_args.append("--pre") + if no_deps: + pip_args.append("--no-deps") + self.pip("download", f"--dest={tempdir}", *pip_args, *packages, **popen_kwargs) files = list(tempdir.iterdir()) print(f"Downloaded {len(files)} file(s) to {tempdir}:") for file in files: @@ -505,7 +592,7 @@ def wheel_unpack( wheel = Path(wheel).absolute() dest = Path(dest).absolute() assert wheel.is_file() and wheel.suffix.lower() == ".whl" - return self.wheel("unpack", "--dest", str(dest), str(wheel), **popen_kwargs) + return self.wheel("unpack", f"--dest={dest}", str(wheel), **popen_kwargs) @contextlib.contextmanager def extracted_wheel(self, wheel: Path | str) -> Generator[Path]: @@ -613,19 +700,17 @@ def check_branch(subcommand: str, branch: str | None) -> str | None: """Checks that the branch name can be checked out.""" if subcommand != "checkout": return None - # first make sure actual branch name was given - if branch is None: - return "Branch name to checkout must be supplied with '-b' option" # next check that the local repo is clean cmd = git("status", "--untracked-files=no", "--porcelain") stdout = subprocess.check_output(cmd, text=True, encoding="utf-8") if stdout.strip(): return "Need to have clean working tree to checkout!\n\n" + stdout - # next check that the branch name doesn't already exist - cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}") - p = subprocess.run(cmd, capture_output=True, check=False) # type: ignore[assignment] - if not p.returncode: - return f"Branch {branch!r} already exists" + # next check that the branch name doesn't already exist (if a branch name is provided) + if branch is not None: + cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}") + p = subprocess.run(cmd, capture_output=True, check=False) # type: ignore[assignment] + if not p.returncode: + return f"Branch {branch!r} already exists" return None @@ -635,7 +720,7 @@ def install_packages(venv: Venv, packages: Iterable[str]) -> None: # install packages packages = list(dict.fromkeys(packages)) if packages: - venv.pip_install(*packages) + venv.uv_pip_install(*packages) def _ensure_commit(git_sha1: str) -> None: @@ -680,10 +765,15 @@ def _nightly_version(site_dir: Path) -> str: @timed("Checking out nightly PyTorch") -def checkout_nightly_version(branch: str, site_dir: Path) -> None: - """Get's the nightly version and then checks it out.""" +def checkout_nightly_version(branch: str | None, site_dir: Path) -> None: + """Gets the nightly version and then checks it out.""" nightly_version = _nightly_version(site_dir) - cmd = git("checkout", "-b", branch, nightly_version) + if branch is None: + # Detached mode - explicitly use --detach flag + cmd = git("checkout", "--detach", nightly_version) + else: + # Branch mode + cmd = git("checkout", "-b", branch, nightly_version) subprocess.check_call(cmd) @@ -780,12 +870,14 @@ def _move_single( relname = relroot / name s = src / relname t = trg / relname - print(f"{verb} {s} -> {t}") + if VERBOSE: + print(f"{verb} {s} -> {t}") mover(s, t) for name in dirs: (trg / relroot / name).mkdir(parents=True, exist_ok=True) else: - print(f"{verb} {src} -> {trg}") + if VERBOSE: + print(f"{verb} {src} -> {trg}") mover(src, trg) @@ -828,13 +920,54 @@ def write_pth(venv: Venv) -> None: ) +def parse_dependencies( + venv: Venv, + wheel_site_dir: Path, +) -> list[str]: + """Parse dependencies from the torch wheel's metadata.""" + dist_info_dirs = list(wheel_site_dir.glob("*.dist-info")) + if len(dist_info_dirs) != 1: + raise RuntimeError( + f"Expected exactly one .dist-info directory in {wheel_site_dir}, " + f"got {dist_info_dirs}" + ) + dist_info_dir = dist_info_dirs[0] + if not (dist_info_dir / "METADATA").is_file(): + raise RuntimeError( + f"Expected METADATA file in {dist_info_dir}, but it does not exist." + ) + + # Use the Python interpreter in the virtual environment instead of the interpreter + # running this script, so that we can evaluate markers correctly. + dependencies = ( + venv.python( + "-c", + textwrap.dedent( + """ + from packaging.metadata import Metadata + + with open("METADATA", encoding="utf-8") as f: + metadata = Metadata.from_email(f.read()) + for req in metadata.requires_dist: + if req.marker is None or req.marker.evaluate(): + print(req) + """ + ).strip(), + cwd=dist_info_dir, + capture_output=True, + ) + .stdout.strip() + .splitlines() + ) + return [dep.strip() for dep in dependencies] + + def install( *, venv: Venv, packages: Iterable[str], subcommand: str = "checkout", branch: str | None = None, - logger: logging.Logger, ) -> None: """Development install of PyTorch""" use_existing = subcommand == "checkout" @@ -845,22 +978,21 @@ def install( packages = [p for p in packages if p != "torch"] - dependencies = venv.pip_download("torch", prerelease=True) - torch_wheel = [ - dep - for dep in dependencies - if dep.name.startswith("torch-") and dep.name.endswith(".whl") - ] - if len(torch_wheel) != 1: + downloaded_files = venv.pip_download("torch", prerelease=True, no_deps=True) + if len(downloaded_files) != 1: + raise RuntimeError(f"Expected exactly one torch wheel, got {downloaded_files}") + torch_wheel = downloaded_files[0] + if not ( + torch_wheel.name.startswith("torch-") and torch_wheel.name.endswith(".whl") + ): raise RuntimeError(f"Expected exactly one torch wheel, got {torch_wheel}") - torch_wheel = torch_wheel[0] - dependencies = [deps for deps in dependencies if deps != torch_wheel] - - install_packages(venv, [*packages, *map(str, dependencies)]) with venv.extracted_wheel(torch_wheel) as wheel_site_dir: + dependencies = parse_dependencies(venv, wheel_site_dir) + install_packages(venv, [*dependencies, *packages]) + if subcommand == "checkout": - checkout_nightly_version(cast(str, branch), wheel_site_dir) + checkout_nightly_version(branch, wheel_site_dir) elif subcommand == "pull": pull_nightly_version(wheel_site_dir) else: @@ -868,11 +1000,11 @@ def install( move_nightly_files(wheel_site_dir) write_pth(venv) - logger.info( + cast(logging.Logger, LOGGER).info( "-------\n" "PyTorch Development Environment set up!\n" "Please activate to enable this environment:\n\n" - " $ %s", + " $ %s\n", venv.activate_command, ) @@ -893,7 +1025,7 @@ def find_executable(name: str) -> Path: checkout.add_argument( "-b", "--branch", - help="Branch name to checkout", + help="Branch name to checkout (if omitted, checks out in detached HEAD mode)", dest="branch", default=None, metavar="NAME", @@ -969,14 +1101,15 @@ def parse_arguments() -> argparse.Namespace: def main() -> None: """Main entry point""" - global LOGGER + global LOGGER, VERBOSE args = parse_arguments() + VERBOSE = args.verbose + status = check_branch(args.subcmd, args.branch) if status: sys.exit(status) pip_source = None - for toolkit in ("CUDA", "ROCm"): accel = toolkit.lower() if hasattr(args, accel): @@ -1016,7 +1149,6 @@ def main() -> None: packages=PACKAGES_TO_INSTALL, subcommand=args.subcmd, branch=args.branch, - logger=logger, ) diff --git a/tools/nvcc_fix_deps.py b/tools/nvcc_fix_deps.py index 0c0c9db66693a3..a4a3b536eeae87 100644 --- a/tools/nvcc_fix_deps.py +++ b/tools/nvcc_fix_deps.py @@ -1,4 +1,4 @@ -"""Tool to fix the nvcc's dependecy file output +"""Tool to fix the nvcc's dependency file output Usage: python nvcc_fix_deps.py nvcc [nvcc args]... diff --git a/tools/onnx/update_default_opset_version.py b/tools/onnx/update_default_opset_version.py deleted file mode 100755 index 88a98e5b27c0b9..00000000000000 --- a/tools/onnx/update_default_opset_version.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 - -"""Updates the default value of opset_version. - -The current policy is that the default should be set to the -latest released version as of 18 months ago. - -Usage: -Run with no arguments. -""" - -import argparse -import datetime -import os -import re -import subprocess -import sys -from pathlib import Path -from subprocess import DEVNULL -from typing import Any - - -def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None: - with open(path, encoding="utf-8") as f: - content_str = f.read() - content_str = re.sub(prefix_pat, rf"\g<1>{new_default}", content_str) - with open(path, "w", encoding="utf-8") as f: - f.write(content_str) - print("modified", path) - - -def main(args: Any) -> None: - pytorch_dir = Path(__file__).parents[2].resolve() - onnx_dir = pytorch_dir / "third_party" / "onnx" - os.chdir(onnx_dir) - - date = datetime.datetime.now() - datetime.timedelta(days=18 * 30) - onnx_commit = subprocess.check_output( - ("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), - encoding="utf-8", - ).strip() - onnx_tags = subprocess.check_output( - ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8" - ) - tag_tups = [] - semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)") - for tag in onnx_tags.splitlines(): - match = semver_pat.match(tag) - if match: - tag_tups.append(tuple(int(x) for x in match.groups())) - - # Take the release 18 months ago - version_str = "{}.{}.{}".format(*min(tag_tups)) - - print("Using ONNX release", version_str) - - head_commit = subprocess.check_output( - ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8" - ).strip() - - new_default = None - - subprocess.check_call( - ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL - ) - try: - from onnx import helper # type: ignore[import] - - for version in helper.VERSION_TABLE: - if version[0] == version_str: - new_default = version[2] - print("found new default opset_version", new_default) - break - if not new_default: - sys.exit( - f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}" - ) - finally: - subprocess.check_call( - ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL - ) - - os.chdir(pytorch_dir) - - read_sub_write( - os.path.join("torch", "onnx", "_constants.py"), - r"(ONNX_DEFAULT_OPSET = )\d+", - new_default, - ) - read_sub_write( - os.path.join("torch", "onnx", "utils.py"), - r"(opset_version \(int, default )\d+", - new_default, - ) - - if not args.skip_build: - print("Building PyTorch...") - subprocess.check_call( - ("python", "setup.py", "develop"), - ) - print("Updating operator .expect files") - subprocess.check_call( - ("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"), - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--skip-build", - "--skip_build", - action="store_true", - help="Skip building pytorch", - ) - main(parser.parse_args()) diff --git a/tools/optional_submodules.py b/tools/optional_submodules.py new file mode 100644 index 00000000000000..1e7589edf2fb99 --- /dev/null +++ b/tools/optional_submodules.py @@ -0,0 +1,63 @@ +import os +from pathlib import Path +from subprocess import check_call + + +repo_root = Path(__file__).absolute().parent.parent +third_party_path = repo_root / "third_party" + + +def _read_file(path: Path) -> str: + with path.open(encoding="utf-8") as f: + return f.read().strip() + + +def _checkout_by_tag(repo: str, tag: str) -> None: + check_call( + [ + "git", + "clone", + "--depth", + "1", + "--branch", + tag, + repo, + ], + cwd=third_party_path, + ) + + +def read_nccl_pin() -> str: + nccl_file = "nccl-cu12.txt" + if os.getenv("DESIRED_CUDA", os.getenv("CUDA_VERSION", "")).startswith("11"): + nccl_file = "nccl-cu11.txt" + nccl_pin_path = repo_root / ".ci" / "docker" / "ci_commit_pins" / nccl_file + return _read_file(nccl_pin_path) + + +def checkout_nccl() -> None: + release_tag = read_nccl_pin() + print(f"-- Checkout nccl release tag: {release_tag}") + nccl_basedir = third_party_path / "nccl" + if not nccl_basedir.exists(): + _checkout_by_tag("https://github.com/NVIDIA/nccl", release_tag) + + +def checkout_eigen() -> None: + eigen_tag = _read_file(third_party_path / "eigen_pin.txt") + print(f"-- Checkout Eigen release tag: {eigen_tag}") + eigen_basedir = third_party_path / "eigen" + if not eigen_basedir.exists(): + _checkout_by_tag("https://gitlab.com/libeigen/eigen", eigen_tag) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) == 1: + # If no arguments are given checkout all optional dependency + checkout_nccl() + checkout_eigen() + else: + # Otherwise just call top-level function of choice + globals()[sys.argv[1]]() diff --git a/tools/packaging/build_wheel.py b/tools/packaging/build_wheel.py index 96e4978c7fcdbd..16e9a87bd9638f 100644 --- a/tools/packaging/build_wheel.py +++ b/tools/packaging/build_wheel.py @@ -62,7 +62,7 @@ def venv(interpreter: str) -> Iterator[str]: class Builder: - # The python interpeter that we should be using + # The python interpreter that we should be using interpreter: str def __init__(self, interpreter: str) -> None: @@ -124,7 +124,7 @@ def main() -> None: with venv(interpreter) as venv_interpreter: builder = Builder(venv_interpreter) # clean actually requires setuptools so we need to ensure we - # install requriements before + # install requirements before builder.install_requirements() builder.clean() diff --git a/tools/packaging/split_wheel.py b/tools/packaging/split_wheel.py index 1aa77aa5c69491..fd52c39a22b029 100644 --- a/tools/packaging/split_wheel.py +++ b/tools/packaging/split_wheel.py @@ -76,11 +76,15 @@ def split_build(cmd: str) -> None: extra_env={"BUILD_LIBTORCH_WHL": "1", "BUILD_PYTHON_ONLY": "0"}, ) logger.info("Running %s for torch wheel", cmd) - # NOTE: Passing --cmake is necessary here since the torch frontend has it's + # NOTE: Passing CMAKE_FRESH=1 is necessary here since the torch frontend has it's # own cmake files that it needs to generate setup_py( - [cmd, "--cmake"], - extra_env={"BUILD_LIBTORCH_WHL": "0", "BUILD_PYTHON_ONLY": "1"}, + [cmd], + extra_env={ + "BUILD_LIBTORCH_WHL": "0", + "BUILD_PYTHON_ONLY": "1", + "CMAKE_FRESH": "1", + }, ) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 629dadc00c2865..81fadb855b004e 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -261,7 +261,7 @@ def sig_for_ops(opname: str) -> list[str]: ] return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."] elif name in logic_ops: - return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."] + return [f"def {opname}(self, other: Tensor | _int) -> Tensor: ..."] elif name in shift_ops: return [f"def {opname}(self, other: Tensor | _int) -> Tensor: ..."] elif name in symmetric_comparison_ops: @@ -412,6 +412,16 @@ def gen_nn_functional(fm: FileManager) -> None: "tuple[Tensor, Tensor]", ) ], + f"adaptive_avg_pool{d}d": [ + defs( + f"adaptive_avg_pool{d}d", + [ + INPUT, + "output_size: _int | _size", + ], + "Tensor", + ) + ], } ) @@ -516,6 +526,31 @@ def gen_nn_functional(fm: FileManager) -> None: "Tensor", ) ], + "binary_cross_entropy": [ + defs( + "binary_cross_entropy", + [ + INPUT, + "target: Tensor", + "weight: Tensor | None = None", + "reduction: str = ...", + ], + "Tensor", + ) + ], + "col2im": [ + defs( + "col2im", + [ + INPUT, + "output_size: _int | _size", + KERNEL_SIZE, + "dilation: _int | _size", + *STRIDE_PADDING, + ], + "Tensor", + ) + ], } ) @@ -877,6 +912,27 @@ def gen_pyi( "None", ) ], + "_functionalize_mutation_counter": [ + defs( + "_functionalize_mutation_counter", + ["t: Tensor"], + "_int", + ) + ], + "_functionalize_storage_changed_counter": [ + defs( + "_functionalize_storage_changed_counter", + ["t: Tensor"], + "_int", + ) + ], + "_functionalize_inductor_storage_resized_counter": [ + defs( + "_functionalize_inductor_storage_resized_counter", + ["t: Tensor"], + "_int", + ) + ], "_functionalize_are_all_mutations_hidden_from_autograd": [ defs( "_functionalize_are_all_mutations_hidden_from_autograd", @@ -902,8 +958,8 @@ def gen_pyi( "_functionalize_was_storage_changed": [ defs("_functionalize_was_storage_changed", ["tensor: Tensor"], "_bool") ], - "_functionalize_set_storage_changed": [ - "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..." + "_functionalize_mark_storage_changed": [ + "def _functionalize_mark_storage_changed(tensor: Tensor) -> _bool: ..." ], "_functionalize_has_metadata_mutation": [ defs( @@ -1681,10 +1737,10 @@ def replace_special_case(hint: str) -> str: # Include only the functions that contain hints, to prevent undefined # symbols to be included in the `__all__` directive. - hinted_function_names = [ + hinted_function_names = { name for name, hint in unsorted_function_hints.items() if hint - ] - all_symbols = sorted(list(structseqs) + hinted_function_names) + } + all_symbols = sorted(hinted_function_names.union(structseqs)) all_directive = [ "__all__ = [", *(f' "{name}",' for name in all_symbols), diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py index 32731175f18037..e227fd2ac0d954 100644 --- a/tools/setup_helpers/__init__.py +++ b/tools/setup_helpers/__init__.py @@ -2,9 +2,17 @@ import os import sys +import warnings def which(thefile: str) -> str | None: + warnings.warn( + "tools.setup_helpers.which is deprecated and will be removed in a future version. " + "Use shutil.which instead.", + FutureWarning, + stacklevel=2, + ) + path = os.environ.get("PATH", os.defpath).split(os.pathsep) for d in path: fname = os.path.join(d, thefile) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 1fd7377cb07d74..4d131846ea1feb 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -1,22 +1,34 @@ -"Manages CMake." +"""Manages CMake.""" from __future__ import annotations +import functools +import json import multiprocessing import os import platform +import shutil import sys import sysconfig -from distutils.version import LooseVersion from pathlib import Path -from subprocess import CalledProcessError, check_call, check_output -from typing import Any, cast +from subprocess import CalledProcessError, check_call, check_output, DEVNULL +from typing import cast -from . import which from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file from .env import BUILD_DIR, check_negative_env_flag, IS_64BIT, IS_DARWIN, IS_WINDOWS +try: + from packaging.version import Version +except ImportError: + try: + from setuptools.dist import Version # type: ignore[attr-defined,no-redef] + except ImportError: + from distutils.version import ( # type: ignore[assignment,no-redef] + LooseVersion as Version, + ) + + def _mkdir_p(d: str) -> None: try: os.makedirs(d, exist_ok=True) @@ -26,10 +38,14 @@ def _mkdir_p(d: str) -> None: ) from e +# Print to stderr +eprint = functools.partial(print, file=sys.stderr, flush=True) + + # Ninja # Use ninja if it is on the PATH. Previous version of PyTorch required the # ninja python package, but we no longer use it, so we do not have to import it -USE_NINJA = not check_negative_env_flag("USE_NINJA") and which("ninja") is not None +USE_NINJA = bool(not check_negative_env_flag("USE_NINJA") and shutil.which("ninja")) if "CMAKE_GENERATOR" in os.environ: USE_NINJA = os.environ["CMAKE_GENERATOR"].lower() == "ninja" @@ -50,22 +66,34 @@ def _cmake_cache_file(self) -> str: """ return os.path.join(self.build_dir, "CMakeCache.txt") + @property + def _ninja_build_file(self) -> str: + r"""Returns the path to build.ninja. + + Returns: + string: The path to build.ninja. + """ + return os.path.join(self.build_dir, "build.ninja") + @staticmethod def _get_cmake_command() -> str: - "Returns cmake command." + """Returns cmake command.""" cmake_command = "cmake" if IS_WINDOWS: return cmake_command - cmake3_version = CMake._get_version(which("cmake3")) - cmake_version = CMake._get_version(which("cmake")) + cmake3_version = CMake._get_version(shutil.which("cmake3")) + cmake_version = CMake._get_version(shutil.which("cmake")) - _cmake_min_version = LooseVersion("3.18.0") + _cmake_min_version = Version("3.27.0") if all( ver is None or ver < _cmake_min_version for ver in [cmake_version, cmake3_version] ): - raise RuntimeError("no cmake or cmake3 with version >= 3.18.0 found") + raise RuntimeError( + "no cmake or cmake3 with version >= 3.27.0 found:" + + str([cmake_version, cmake3_version]) + ) if cmake3_version is None: cmake_command = "cmake" @@ -79,21 +107,32 @@ def _get_cmake_command() -> str: return cmake_command @staticmethod - def _get_version(cmd: str | None) -> Any: - "Returns cmake version." + def _get_version(cmd: str | None) -> Version | None: + """Returns cmake version.""" if cmd is None: return None - for line in check_output([cmd, "--version"]).decode("utf-8").split("\n"): - if "version" in line: - return LooseVersion(line.strip().split(" ")[2]) - raise RuntimeError("no version found") + + try: + cmake_capabilities = json.loads( + check_output( + [cmd, "-E", "capabilities"], + stderr=DEVNULL, + text=True, + ), + ) + except (OSError, CalledProcessError, json.JSONDecodeError): + cmake_capabilities = {} + cmake_version = cmake_capabilities.get("version", {}).get("string") + if cmake_version is not None: + return Version(cmake_version) + raise RuntimeError(f"Failed to get CMake version from command: {cmd}") def run(self, args: list[str], env: dict[str, str]) -> None: - "Executes cmake with arguments and an environment." + """Executes cmake with arguments and an environment.""" command = [self._cmake_command] + args - print(" ".join(command)) + eprint(" ".join(command)) try: check_call(command, cwd=self.build_dir, env=env) except (CalledProcessError, KeyboardInterrupt): @@ -104,7 +143,7 @@ def run(self, args: list[str], env: dict[str, str]) -> None: @staticmethod def defines(args: list[str], **kwargs: CMakeValue) -> None: - "Adds definitions to a cmake argument list." + """Adds definitions to a cmake argument list.""" for key, value in sorted(kwargs.items()): if value is not None: args.append(f"-D{key}={value}") @@ -126,14 +165,31 @@ def generate( my_env: dict[str, str], rerun: bool, ) -> None: - "Runs cmake to generate native build files." + """Runs cmake to generate native build files.""" if rerun and os.path.isfile(self._cmake_cache_file): os.remove(self._cmake_cache_file) - ninja_build_file = os.path.join(self.build_dir, "build.ninja") - if os.path.exists(self._cmake_cache_file) and not ( - USE_NINJA and not os.path.exists(ninja_build_file) + cmake_cache_file_available = os.path.exists(self._cmake_cache_file) + if cmake_cache_file_available: + cmake_cache_variables = self.get_cmake_cache_variables() + make_program: str | None = cmake_cache_variables.get("CMAKE_MAKE_PROGRAM") # type: ignore[assignment] + if make_program and not shutil.which(make_program): + # CMakeCache.txt exists, but the make program (e.g., ninja) does not. + # See also: https://github.com/astral-sh/uv/issues/14269 + # This can happen if building with PEP-517 build isolation, where `ninja` was + # installed in the isolated environment of the previous build run, but it has been + # removed. The `ninja` executable with an old absolute path not available anymore. + eprint( + "!!!WARNING!!!: CMakeCache.txt exists, " + f"but CMAKE_MAKE_PROGRAM ({make_program!r}) does not exist. " + "Clearing CMake cache." + ) + self.clear_cache() + cmake_cache_file_available = False + + if cmake_cache_file_available and ( + not USE_NINJA or os.path.exists(self._ninja_build_file) ): # Everything's in place. Do not rerun. return @@ -147,9 +203,9 @@ def generate( generator = os.getenv("CMAKE_GENERATOR", "Visual Studio 16 2019") supported = ["Visual Studio 16 2019", "Visual Studio 17 2022"] if generator not in supported: - print("Unsupported `CMAKE_GENERATOR`: " + generator) - print("Please set it to one of the following values: ") - print("\n".join(supported)) + eprint("Unsupported `CMAKE_GENERATOR`: " + generator) + eprint("Please set it to one of the following values: ") + eprint("\n".join(supported)) sys.exit(1) args.append("-G" + generator) toolset_dict = {} @@ -158,7 +214,7 @@ def generate( toolset_dict["version"] = toolset_version curr_toolset = os.getenv("VCToolsVersion") if curr_toolset is None: - print( + eprint( "When you specify `CMAKE_GENERATOR_TOOLSET_VERSION`, you must also " "activate the vs environment of this version. Please read the notes " "in the build steps carefully." @@ -290,9 +346,9 @@ def generate( # Detect build dependencies from python lib path (in order to set *_HOME variables) # NVSHMEM - nvshmem_home = py_lib_path + "/nvidia/nvshmem" - if os.path.exists(nvshmem_home): - build_options["NVSHMEM_HOME"] = nvshmem_home + nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem" + if os.path.exists(nvshmem_py_dir): + build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir # Options starting with CMAKE_ cmake__options = { @@ -303,7 +359,7 @@ def generate( # error if the user also attempts to set these CMAKE options directly. specified_cmake__options = set(build_options).intersection(cmake__options) if len(specified_cmake__options) > 0: - print( + eprint( ", ".join(specified_cmake__options) + " should not be specified in the environment variable. They are directly set by PyTorch build script." ) @@ -332,11 +388,8 @@ def generate( my_env[env_var_name] = str(my_env[env_var_name].encode("utf-8")) except UnicodeDecodeError as e: shex = ":".join(f"{ord(c):02x}" for c in my_env[env_var_name]) - print( - f"Invalid ENV[{env_var_name}] = {shex}", - file=sys.stderr, - ) - print(e, file=sys.stderr) + eprint(f"Invalid ENV[{env_var_name}] = {shex}") + eprint(e) # According to the CMake manual, we should pass the arguments first, # and put the directory as the last element. Otherwise, these flags # may not be passed correctly. @@ -347,7 +400,7 @@ def generate( self.run(args, env=my_env) def build(self, my_env: dict[str, str]) -> None: - "Runs cmake to build binaries." + """Runs cmake to build binaries.""" from .env import build_type @@ -385,3 +438,10 @@ def build(self, my_env: dict[str, str]) -> None: # CMake 3.12 provides a '-j' option. build_args += ["-j", max_jobs] self.run(build_args, my_env) + + def clear_cache(self) -> None: + """Clears the CMake cache.""" + if os.path.isfile(self._cmake_cache_file): + os.remove(self._cmake_cache_file) + if os.path.isfile(self._ninja_build_file): + os.remove(self._ninja_build_file) diff --git a/tools/stats/check_disabled_tests.py b/tools/stats/check_disabled_tests.py index 5505dc26592901..f1f8e2f99ee809 100644 --- a/tools/stats/check_disabled_tests.py +++ b/tools/stats/check_disabled_tests.py @@ -173,7 +173,7 @@ def save_results( all_tests: dict[str, dict[str, int]], ) -> None: """ - Save the result to S3, which then gets put into the HUD backened database + Save the result to S3, which then gets put into the HUD backend database """ should_be_enabled_tests = { name: stats diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index a79f50bc141ec7..a5affc2510b776 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ A Python script that logging the system-level utilization usage in json format. -Data collected: CPU, memory, GPU memeory utilzation, and GPU utilization if available. +Data collected: CPU, memory, GPU memory utilization, and GPU utilization if available. Usage: - To run the script with default data collect time setting, use the following command: diff --git a/tools/test/test_docstring_linter.py b/tools/test/test_docstring_linter.py index 2573058c36f0e9..e16e086cf606fd 100644 --- a/tools/test/test_docstring_linter.py +++ b/tools/test/test_docstring_linter.py @@ -8,8 +8,8 @@ from pathlib import Path from unittest import mock +from tools.linter.adapters._linter.block import _get_decorators from tools.linter.adapters.docstring_linter import ( - _get_decorators, DocstringLinter, file_summary, make_recursive, @@ -54,7 +54,7 @@ def run(name, *argv): grandfather_file = f"{td}/grandfather.json" grandfather = f"--grandfather={grandfather_file}" - # Find some faiures + # Find some failures run("before.txt", grandfather) # Rewrite grandfather file diff --git a/tools/test/test_executorch_custom_ops.py b/tools/test/test_executorch_custom_ops.py deleted file mode 100644 index 767fe0580b17c8..00000000000000 --- a/tools/test/test_executorch_custom_ops.py +++ /dev/null @@ -1,147 +0,0 @@ -from __future__ import annotations - -import tempfile -import unittest -from typing import Any -from unittest.mock import ANY, Mock, patch - -import expecttest - -import torchgen -from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub -from torchgen.executorch.model import ETKernelIndex -from torchgen.gen_executorch import gen_headers -from torchgen.model import Location, NativeFunction -from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import FileManager - - -SPACES = " " - - -def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction: - native_function, _ = NativeFunction.from_yaml( - yaml_obj, - loc=Location(__file__, 1), - valid_tags=set(), - ) - return native_function - - -class TestComputeNativeFunctionStub(expecttest.TestCase): - """ - Could use torch.testing._internal.common_utils to reduce boilerplate. - GH CI job doesn't build torch before running tools unit tests, hence - manually adding these parametrized tests. - """ - - def _test_function_schema_generates_correct_kernel( - self, obj: dict[str, Any], expected: str - ) -> None: - func = _get_native_function_from_yaml(obj) - - gen = ComputeNativeFunctionStub() - res = gen(func) - self.assertIsNotNone(res) - self.assertExpectedInline( - str(res), - expected, - ) - - def test_function_schema_generates_correct_kernel_tensor_out(self) -> None: - obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"} - expected = """ -at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) { - return out; -} - """ - self._test_function_schema_generates_correct_kernel(obj, expected) - - def test_function_schema_generates_correct_kernel_no_out(self) -> None: - obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"} - expected = """ -at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) { - return self; -} - """ - self._test_function_schema_generates_correct_kernel(obj, expected) - - def test_function_schema_generates_correct_kernel_no_return(self) -> None: - obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"} - expected = f""" -void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{ -{SPACES} -}} - """ - self._test_function_schema_generates_correct_kernel(obj, expected) - - def test_function_schema_generates_correct_kernel_3_returns(self) -> None: - obj = { - "func": "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)" - } - expected = """ -::std::tuple wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) { - return ::std::tuple( - at::Tensor(), at::Tensor(), at::Tensor() - ); -} - """ - self._test_function_schema_generates_correct_kernel(obj, expected) - - def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None: - obj = {"func": "custom::foo(Tensor[] a) -> Tensor"} - expected = """ -at::Tensor wrapper_CPU__foo(at::TensorList a) { - return at::Tensor(); -} - """ - self._test_function_schema_generates_correct_kernel(obj, expected) - - def test_schema_has_no_return_type_argument_throws(self) -> None: - func = _get_native_function_from_yaml( - {"func": "custom::foo.bool(Tensor self) -> bool"} - ) - - gen = ComputeNativeFunctionStub() - with self.assertRaisesRegex(Exception, "Can't handle this return type"): - gen(func) - - -class TestGenCustomOpsHeader(unittest.TestCase): - @patch.object(torchgen.utils.FileManager, "write_with_template") - @patch.object(torchgen.utils.FileManager, "write") - def test_fm_writes_custom_ops_header_when_boolean_is_true( - self, unused: Mock, mock_method: Mock - ) -> None: - with tempfile.TemporaryDirectory() as tempdir: - fm = FileManager(tempdir, tempdir, False) - gen_headers( - native_functions=[], - gen_custom_ops_header=True, - custom_ops_native_functions=[], - selector=SelectiveBuilder.get_nop_selector(), - kernel_index=ETKernelIndex(index={}), - cpu_fm=fm, - use_aten_lib=False, - ) - mock_method.assert_called_once_with( - "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY - ) - - @patch.object(torchgen.utils.FileManager, "write_with_template") - @patch.object(torchgen.utils.FileManager, "write") - def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false( - self, unused: Mock, mock_method: Mock - ) -> None: - with tempfile.TemporaryDirectory() as tempdir: - fm = FileManager(tempdir, tempdir, False) - gen_headers( - native_functions=[], - gen_custom_ops_header=False, - custom_ops_native_functions=[], - selector=SelectiveBuilder.get_nop_selector(), - kernel_index=ETKernelIndex(index={}), - cpu_fm=fm, - use_aten_lib=False, - ) - mock_method.assert_not_called() diff --git a/tools/test/test_executorch_gen.py b/tools/test/test_executorch_gen.py deleted file mode 100644 index 9448265aa7177b..00000000000000 --- a/tools/test/test_executorch_gen.py +++ /dev/null @@ -1,689 +0,0 @@ -from __future__ import annotations - -import os -import tempfile -import unittest - -import yaml - -from torchgen.executorch.model import ETKernelIndex, ETKernelKey -from torchgen.gen import LineLoader -from torchgen.gen_executorch import ( - ComputeCodegenUnboxedKernels, - gen_functions_declarations, - parse_yaml_files, - translate_native_yaml, -) -from torchgen.model import ( - BackendIndex, - BackendMetadata, - DispatchKey, - Location, - NativeFunction, - OperatorName, -) -from torchgen.selective_build.selector import SelectiveBuilder - - -TEST_YAML = """ -- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - device_check: NoCheck # TensorIterator - structured: True - structured_inherits: TensorIteratorBase - ufunc_inner_loop: - Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) - ScalarOnly: add (Bool) - dispatch: - SparseCPU: add_out_sparse_cpu - SparseCUDA: add_out_sparse_cuda - SparseCsrCPU: add_out_sparse_csr_cpu - SparseCsrCUDA: add_out_sparse_csr_cuda - MkldnnCPU: mkldnn_add_out - MPS: add_out_mps - -- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - device_check: NoCheck # TensorIterator - structured_delegate: add.out - variants: function, method - dispatch: - SparseCPU, SparseCUDA: add_sparse - SparseCsrCPU, SparseCsrCUDA: add_sparse_csr - MkldnnCPU: mkldnn_add - ZeroTensor: add_zerotensor - NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor - tags: core - -- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - device_check: NoCheck # TensorIterator - structured: True - structured_inherits: TensorIteratorBase - dispatch: - CPU, CUDA: mul_out - MPS: mul_out_mps - SparseCPU: mul_out_sparse_cpu - SparseCUDA: mul_out_sparse_cuda - SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr - MkldnnCPU: mkldnn_mul_out - -- func: mul.Tensor(Tensor self, Tensor other) -> Tensor - device_check: NoCheck # TensorIterator - structured_delegate: mul.out - variants: function, method - dispatch: - SparseCPU, SparseCUDA: mul_sparse - SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr - MkldnnCPU: mkldnn_mul - ZeroTensor: mul_zerotensor - NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor - tags: core - -""" - - -TEST_KERNEL_YAML = """ -- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - device_check: NoCheck # TensorIterator - structured: True - structured_inherits: TensorIteratorBase - ufunc_inner_loop: - Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) - ScalarOnly: add (Bool) - type_alias: - T0: [Float, Double] - T1: [Double, Int] - dim_order_alias: - D0: [0, 1, 2, 3] - D1: [0, 3, 2, 1] - kernels: - - arg_meta: null - kernel_name: default_impl - - arg_meta: - self: [T0, D0] - other: [T1, D0] - out: [T0, D0] - kernel_name: test_impl - - arg_meta: - self: [T1, D0] - other: [T1, D1] - out: [T0, D1] - kernel_name: test_impl_2 - -- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - device_check: NoCheck # TensorIterator - structured_delegate: add.out - variants: function, method - tags: core - -- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - device_check: NoCheck # TensorIterator - structured: True - structured_inherits: TensorIteratorBase - type_alias: - T0: [Float] - T1: [Double] - dim_order_alias: - D0: [0, 1, 2, 3] - kernels: - - arg_meta: null - kernel_name: default_impl - - arg_meta: - self: [T0, D0] - other: [T1, D0] - out: [T0, D0] - kernel_name: test_impl - -- func: mul.Tensor(Tensor self, Tensor other) -> Tensor - device_check: NoCheck # TensorIterator - structured_delegate: mul.out - variants: function, method - tags: core - -""" - - -class TestParseNativeYaml(unittest.TestCase): - def setUp(self) -> None: - self.temp_dir = tempfile.mkdtemp() - - self.aten_yaml_path = os.path.join(self.temp_dir, "test_native_functions.yaml") - with open(self.aten_yaml_path, "w") as f: - f.write(TEST_YAML) - self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") - self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") - with open(self.tags_yaml_path, "w") as f: - f.write( - """ -- tag: core - desc: test - """ - ) - with open(self.ops_yaml_path, "w") as f: - f.write( - """ -- op: add.out - device_check: NoCheck # TensorIterator - dispatch: - CPU: torch::executor::add_out_kernel - -- op: mul.out - device_check: NoCheck # TensorIterator - dispatch: - CPU: torch::executor::mul_out_kernel - """ - ) - - def test_translate_native_yaml_writes_correct_data(self) -> None: - out_yaml_path = os.path.join(self.temp_dir, "out.yaml") - with open(out_yaml_path, "w") as out_file: - translate_native_yaml( - tags_yaml_path=self.tags_yaml_path, - aten_yaml_path=self.aten_yaml_path, - native_yaml_path=self.ops_yaml_path, - use_aten_lib=False, - out_file=out_file, - ) - with open(out_yaml_path) as out_file: - es = yaml.load(out_file, Loader=LineLoader) - self.assertTrue(all("func" in e for e in es)) - self.assertTrue(all(e.get("variants") == "function" for e in es)) - - # Check that kernel fields aren't introduced in yaml - for e in es: - self.assertFalse({"kernels", "type_alias", "dim_order_alias"} < e.keys()) - - def test_parse_yaml_files(self) -> None: - custom_ops_yaml_path = None - selector = SelectiveBuilder.get_nop_selector() - use_aten_lib = False - - parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( - aten_yaml_path=self.aten_yaml_path, - tags_yaml_path=self.tags_yaml_path, - native_yaml_path=self.ops_yaml_path, - custom_ops_yaml_path=custom_ops_yaml_path, - selector=selector, - use_aten_lib=use_aten_lib, - ) - - # Just the default kernel entry - expected_kernel_entry = {"add.out": 1, "mul.out": 1} - self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) - - op_entries = parsed_yaml.kernel_index.index - for op_name, kernel_mapping in op_entries.items(): - self.assertTrue( - len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) - ) - - self.assertTrue(len(expected_kernel_entry) == 0) - - def tearDown(self) -> None: - import shutil - - try: - shutil.rmtree(self.temp_dir) - except OSError: - pass - - -class TestParseKernelYamlFiles(unittest.TestCase): - def setUp(self) -> None: - self.temp_dir = tempfile.mkdtemp() - - self.aten_kernel_yaml_path = os.path.join( - self.temp_dir, "test_kernel_native_functions.yaml" - ) - with open(self.aten_kernel_yaml_path, "w") as f: - f.write(TEST_KERNEL_YAML) - self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") - self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") - with open(self.tags_yaml_path, "w") as f: - f.write( - """ -- tag: core - desc: test - """ - ) - with open(self.ops_yaml_path, "w") as f: - f.write( - """ -- op: add.out - device_check: NoCheck # TensorIterator - dispatch: - CPU: torch::executor::add_out_kernel - -- op: mul.out - device_check: NoCheck # TensorIterator - dispatch: - CPU: torch::executor::mul_out_kernel - """ - ) - - def test_translate_kernel_native_yaml_writes_correct_data(self) -> None: - out_yaml_path = os.path.join(self.temp_dir, "out2.yaml") - with open(out_yaml_path, "w") as out_file: - translate_native_yaml( - tags_yaml_path=self.tags_yaml_path, - aten_yaml_path=self.aten_kernel_yaml_path, - native_yaml_path=self.ops_yaml_path, - use_aten_lib=False, - out_file=out_file, - ) - with open(out_yaml_path) as out_file: - es = yaml.load(out_file, Loader=LineLoader) - self.assertTrue(all("func" in e for e in es)) - self.assertTrue(all(e.get("variants") == "function" for e in es)) - - # Check persistence of kernel fields in yaml - for e in es: - self.assertTrue({"kernels", "type_alias", "dim_order_alias"} < e.keys()) - - def test_parse_yaml_files(self) -> None: - custom_ops_yaml_path = None - selector = SelectiveBuilder.get_nop_selector() - use_aten_lib = False - - parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( - aten_yaml_path=self.aten_kernel_yaml_path, - tags_yaml_path=self.tags_yaml_path, - native_yaml_path=self.ops_yaml_path, - custom_ops_yaml_path=custom_ops_yaml_path, - selector=selector, - use_aten_lib=use_aten_lib, - ) - - expected_kernel_entry = {"add.out": 9, "mul.out": 2} - self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) - - op_entries = parsed_yaml.kernel_index.index - for op_name, kernel_mapping in op_entries.items(): - self.assertTrue( - len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) - ) - - self.assertTrue(len(expected_kernel_entry) == 0) - - def tearDown(self) -> None: - import shutil - - try: - shutil.rmtree(self.temp_dir) - except OSError: - pass - - -class TestGenFunctionsDeclarations(unittest.TestCase): - def setUp(self) -> None: - ( - self.custom_1_native_function, - custom_1_backend_index, - ) = NativeFunction.from_yaml( - {"func": "custom_1::op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, - loc=Location(__file__, 1), - valid_tags=set(), - ) - ( - self.custom_2_native_function, - custom_2_backend_index, - ) = NativeFunction.from_yaml( - { - "func": "custom_2::op_2() -> bool", - "dispatch": {"CPU": "kernel_2"}, - }, - loc=Location(__file__, 1), - valid_tags=set(), - ) - ( - self.custom_3_native_function, - custom_3_backend_index, - ) = NativeFunction.from_yaml( - { - "func": "custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!)", - "dispatch": {"CPU": "kernel_3"}, - "variants": "method", - }, - loc=Location(__file__, 1), - valid_tags=set(), - ) - - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { - DispatchKey.CPU: {}, - DispatchKey.QuantizedCPU: {}, - } - BackendIndex.grow_index(backend_indices, custom_1_backend_index) - BackendIndex.grow_index(backend_indices, custom_2_backend_index) - self.static_dispatch_idx = [ - BackendIndex( - dispatch_key=k, - use_out_as_primary=True, - external=False, - device_guard=False, - index=backend_indices[k], - ) - for k in backend_indices - ] - self.kernel_index = ETKernelIndex.from_backend_indices(backend_indices) - - def test_operators_with_different_namespaces_are_grouped_correctly(self) -> None: - declarations = gen_functions_declarations( - native_functions=[ - self.custom_1_native_function, - self.custom_2_native_function, - ], - kernel_index=self.kernel_index, - selector=SelectiveBuilder.get_nop_selector(), - use_aten_lib=False, - ) - self.assertTrue( - """ -namespace custom_1 { - -// custom_1::op_1() -> bool -TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { - return ::at::native::kernel_1(context); -} - -} // namespace custom_1 -""" - in declarations - ) - - self.assertTrue( - """ -namespace custom_2 { - -// custom_2::op_2() -> bool -TORCH_API inline bool op_2(torch::executor::KernelRuntimeContext & context) { - return ::at::native::kernel_2(context); -} - -} // namespace custom_2 - """ - in declarations - ) - - def test_aten_lib_has_context_arg(self) -> None: - declarations = gen_functions_declarations( - native_functions=[ - self.custom_1_native_function, - ], - kernel_index=self.kernel_index, - selector=SelectiveBuilder.get_nop_selector(), - use_aten_lib=True, - ) - self.assertTrue( - """ -namespace custom_1 { - -// custom_1::op_1() -> bool -TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { - return at::op_1(); -} - -} // namespace custom_1 - """ - in declarations - ) - - def test_aten_lib_method_variant(self) -> None: - declarations = gen_functions_declarations( - native_functions=[ - self.custom_3_native_function, - ], - kernel_index=self.kernel_index, - selector=SelectiveBuilder.get_nop_selector(), - use_aten_lib=True, - ) - self.assertTrue( - """ -namespace custom_3 { - -// custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!) -TORCH_API inline at::Tensor & op_3(torch::executor::KernelRuntimeContext & context, at::Tensor & self, const at::Tensor & x) { - return self.op_3(x); -} - -} // namespace custom_3 - """ - in declarations - ) - - -class TestComputeCodegenUnboxedKernels(unittest.TestCase): - def setUp(self) -> None: - ( - self.native_function_no_kern, - _, - ) = NativeFunction.from_yaml( - { - "func": "custom_1::op_1() -> bool", - "dispatch": {"CPU": "unused_kernel_1"}, - }, - loc=Location(__file__, 1), - valid_tags=set(), - ) - - self.default_kernel_key = ETKernelKey(default=True) - self.default_backend_metadata = BackendMetadata( - "default_kernel", False, "at::native" - ) - self.default_kernel_entry = ( - [self.default_kernel_key], - self.default_backend_metadata, - ) - - def test_codegen_unboxed_specialized(self) -> None: - specialized_kernel_key = ETKernelKey.gen_from_yaml( - {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, - {"T0": ["Double"]}, - {"D0": [0, 1, 2, 3]}, - ) - selector = SelectiveBuilder.from_yaml_dict( - { - "include_all_operators": True, - "et_kernel_metadata": { - "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] - }, - } - ) - use_aten_lib = False - entry = ( - self.native_function_no_kern, - (specialized_kernel_key, self.default_backend_metadata), - ) - - result = ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary=False - )(entry) - # Concat used to prevent whitespace stripping - expected_str = ( - """ -Kernel( - "custom_1::op_1", - "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3", - [](torch::executor::KernelRuntimeContext & context, EValue** stack) { - """ - + """ - - - internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); - EXECUTORCH_SCOPE_PROF("native_call_op_1"); - bool result_ = at::native::default_kernel(context, ); - internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - - *stack[0] = EValue(result_); - - } -), -""" - ) - - self.assertEqual(expected_str, result) - - def test_codegen_unboxed_specialized_not_matching(self) -> None: - specialized_kernel_key = ETKernelKey.gen_from_yaml( - {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, - {"T0": ["Double"]}, - {"D0": [0, 1, 2, 3]}, - ) - selector = SelectiveBuilder.from_yaml_dict( - { - "include_all_operators": True, - "et_kernel_metadata": { - "custom_1::op_1": ["v1/8;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] - }, - } - ) - use_aten_lib = False - entry = ( - self.native_function_no_kern, - (specialized_kernel_key, self.default_backend_metadata), - ) - - self.assertRaises( - Exception, - ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary=False - ), - entry, - ) - - def test_codegen_unboxed_specialized_missing_root_op(self) -> None: - specialized_kernel_key = ETKernelKey.gen_from_yaml( - {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, - {"T0": ["Double"]}, - {"D0": [0, 1, 2, 3]}, - ) - selector = SelectiveBuilder.from_yaml_dict( - { - "et_kernel_metadata": { - "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] - } - } - ) - use_aten_lib = False - entry = ( - self.native_function_no_kern, - (specialized_kernel_key, self.default_backend_metadata), - ) - - for add_exception_boundary in (True, False): - result = ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary - )(entry) - # Concat used to prevent whitespace stripping - expected_str = """""" - - self.assertEqual(expected_str, result) - - def test_codegen_unboxed_default(self) -> None: - """ - This test checks that if there is no specialized kernel, the default kernel is used. - """ - selector = SelectiveBuilder.from_yaml_dict( - { - "include_all_operators": True, - "et_kernel_metadata": { - "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] - }, - } - ) - use_aten_lib = False - entry = (self.native_function_no_kern, self.default_kernel_entry) - - result = ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary=False - )(entry) - # Concat used to prevent whitespace stripping - expected_str = ( - """ -Kernel( - "custom_1::op_1", - [](torch::executor::KernelRuntimeContext & context, EValue** stack) { - """ - + """ - - - internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); - EXECUTORCH_SCOPE_PROF("native_call_op_1"); - bool result_ = at::native::default_kernel(context, ); - internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - - *stack[0] = EValue(result_); - - } -), -""" - ) - - self.assertEqual(expected_str, result) - - result = ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary=True - )(entry) - # Concat used to prevent whitespace stripping - expected_str = ( - """ -Kernel( - "custom_1::op_1", - [](torch::executor::KernelRuntimeContext & context, EValue** stack) { - """ - + """ - - try { - internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); - EXECUTORCH_SCOPE_PROF("native_call_op_1"); - bool result_ = at::native::default_kernel(context, ); - internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - - *stack[0] = EValue(result_); - } catch (const std::exception& ex) { - ET_LOG(Error, "Kernel threw an exception: %s", ex.what()); - context.fail(torch::executor::Error::Internal); - } - } -), -""" - ) - self.maxDiff = None - self.assertEqual(expected_str, result) - - def test_codegen_unboxed_default_kernel_key_selected(self) -> None: - """ - This test checks that if there is no specialized kernel, the default kernel is used, when the selector only has default key. - """ - selector = SelectiveBuilder.from_yaml_dict( - { - "include_all_operators": True, - "et_kernel_metadata": {"custom_1::op_1": ["default"]}, - } - ) - use_aten_lib = False - entry = (self.native_function_no_kern, self.default_kernel_entry) - - result = ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary=False - )(entry) - # Concat used to prevent whitespace stripping - expected_str = ( - """ -Kernel( - "custom_1::op_1", - [](torch::executor::KernelRuntimeContext & context, EValue** stack) { - """ - + """ - - - internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); - EXECUTORCH_SCOPE_PROF("native_call_op_1"); - bool result_ = at::native::default_kernel(context, ); - internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - - *stack[0] = EValue(result_); - - } -), -""" - ) - - self.assertEqual(expected_str, result) diff --git a/tools/test/test_executorch_signatures.py b/tools/test/test_executorch_signatures.py deleted file mode 100644 index 79f291aba3d27b..00000000000000 --- a/tools/test/test_executorch_signatures.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest - -from torchgen.executorch.api.types import ExecutorchCppSignature -from torchgen.local import parametrize -from torchgen.model import Location, NativeFunction - - -DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( - {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, - loc=Location(__file__, 1), - valid_tags=set(), -) - - -class ExecutorchCppSignatureTest(unittest.TestCase): - def setUp(self) -> None: - self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION) - - def test_runtime_signature_contains_runtime_context(self) -> None: - # test if `KernelRuntimeContext` argument exists in `RuntimeSignature` - with parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ): - args = self.sig.arguments(include_context=True) - self.assertEqual(len(args), 3) - self.assertTrue(any(a.name == "context" for a in args)) - - def test_runtime_signature_does_not_contain_runtime_context(self) -> None: - # test if `KernelRuntimeContext` argument is missing in `RuntimeSignature` - with parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ): - args = self.sig.arguments(include_context=False) - self.assertEqual(len(args), 2) - self.assertFalse(any(a.name == "context" for a in args)) - - def test_runtime_signature_declaration_correct(self) -> None: - with parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ): - decl = self.sig.decl(include_context=True) - self.assertEqual( - decl, - ( - "torch::executor::Tensor & foo_outf(" - "torch::executor::KernelRuntimeContext & context, " - "const torch::executor::Tensor & input, " - "torch::executor::Tensor & out)" - ), - ) - no_context_decl = self.sig.decl(include_context=False) - self.assertEqual( - no_context_decl, - ( - "torch::executor::Tensor & foo_outf(" - "const torch::executor::Tensor & input, " - "torch::executor::Tensor & out)" - ), - ) diff --git a/tools/test/test_executorch_types.py b/tools/test/test_executorch_types.py deleted file mode 100644 index dedb19e21f3e6d..00000000000000 --- a/tools/test/test_executorch_types.py +++ /dev/null @@ -1,114 +0,0 @@ -import unittest - -from torchgen import local -from torchgen.api.types import ( - BaseCType, - boolT, - ConstRefCType, - CType, - longT, - MutRefCType, - NamedCType, - OptionalCType, - TupleCType, - VectorCType, - voidT, -) -from torchgen.executorch.api.et_cpp import argument_type, return_type, returns_type -from torchgen.executorch.api.types import ArrayRefCType, scalarT, tensorListT, tensorT -from torchgen.model import Argument, FunctionSchema, Return - - -class ExecutorchCppTest(unittest.TestCase): - """ - Test torchgen.executorch.api.cpp - """ - - def _test_argumenttype_type(self, arg_str: str, expected: NamedCType) -> None: - arg = Argument.parse(arg_str) - self.assertEqual(str(argument_type(arg, binds=arg.name)), str(expected)) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_argumenttype_type(self) -> None: - data = [ - ("Tensor self", NamedCType("self", ConstRefCType(BaseCType(tensorT)))), - ("Tensor(a!) out", NamedCType("out", MutRefCType(BaseCType(tensorT)))), - ( - "Tensor? opt", - NamedCType("opt", ConstRefCType(OptionalCType(BaseCType(tensorT)))), - ), - ("Scalar scalar", NamedCType("scalar", ConstRefCType(BaseCType(scalarT)))), - ( - "Scalar? scalar", - NamedCType("scalar", ConstRefCType(OptionalCType(BaseCType(scalarT)))), - ), - ("int[] size", NamedCType("size", ArrayRefCType(BaseCType(longT)))), - ("int? dim", NamedCType("dim", OptionalCType(BaseCType(longT)))), - ("Tensor[] weight", NamedCType("weight", BaseCType(tensorListT))), - ( - "Scalar[] spacing", - NamedCType("spacing", ArrayRefCType(ConstRefCType(BaseCType(scalarT)))), - ), - ( - "Tensor?[] weight", - NamedCType("weight", ArrayRefCType(OptionalCType(BaseCType(tensorT)))), - ), - ( - "SymInt[]? output_size", - NamedCType( - "output_size", OptionalCType(ArrayRefCType(BaseCType(longT))) - ), - ), - ( - "int[]? dims", - NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))), - ), - ( - "bool[3] output_mask", - NamedCType("output_mask", ArrayRefCType(BaseCType(boolT))), - ), - ] - for d in data: - self._test_argumenttype_type(*d) - - def _test_returntype_type(self, ret_str: str, expected: CType) -> None: - ret = Return.parse(ret_str) - self.assertEqual(str(return_type(ret)), str(expected)) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_returntype_type(self) -> None: - data = [ - ("Tensor", BaseCType(tensorT)), - ("Tensor(a!)", MutRefCType(BaseCType(tensorT))), - ("Tensor[]", VectorCType(BaseCType(tensorT))), - ] - for d in data: - self._test_returntype_type(*d) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_returns_type(self) -> None: - func = FunctionSchema.parse( - "min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" - ) - expected = TupleCType([BaseCType(tensorT), BaseCType(tensorT)]) - self.assertEqual(str(returns_type(func.returns)), str(expected)) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_void_return_type(self) -> None: - func = FunctionSchema.parse( - "_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()" - ) - expected = BaseCType(voidT) - self.assertEqual(str(returns_type(func.returns)), str(expected)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tools/test/test_executorch_unboxing.py b/tools/test/test_executorch_unboxing.py deleted file mode 100644 index eff0145856dd7b..00000000000000 --- a/tools/test/test_executorch_unboxing.py +++ /dev/null @@ -1,176 +0,0 @@ -import unittest -from types import ModuleType - -from torchgen import local -from torchgen.api import cpp as aten_cpp, types as aten_types -from torchgen.api.types import ( - ArgName, - BaseCType, - ConstRefCType, - MutRefCType, - NamedCType, -) -from torchgen.executorch.api import et_cpp as et_cpp, types as et_types -from torchgen.executorch.api.unboxing import Unboxing -from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type - - -def aten_argumenttype_type_wrapper( - t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False -) -> NamedCType: - return aten_cpp.argumenttype_type( - t, - mutable=mutable, - binds=binds, - remove_non_owning_ref_types=remove_non_owning_ref_types, - ) - - -ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper) -ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type) - - -class TestUnboxing(unittest.TestCase): - """ - Could use torch.testing._internal.common_utils to reduce boilerplate. - GH CI job doesn't build torch before running tools unit tests, hence - manually adding these parametrized tests. - """ - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_symint_argument_translate_ctype_aten(self) -> None: - # test if `SymInt[]` JIT argument can be translated into C++ argument correctly. - # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig. - - # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None) - - out_name, ctype, _, _ = ATEN_UNBOXING.argumenttype_evalue_convert( - t=symint_list_type, arg_name="size", mutable=False - ) - - self.assertEqual(out_name, "size_list_out") - self.assertIsInstance(ctype, BaseCType) - # pyre-fixme[16]: - self.assertEqual(ctype, aten_types.BaseCType(aten_types.intArrayRefT)) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def test_symint_argument_translate_ctype_executorch(self) -> None: - # test if `SymInt[]` JIT argument can be translated into C++ argument correctly. - # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig. - - # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None) - - out_name, ctype, _, _ = ET_UNBOXING.argumenttype_evalue_convert( - t=symint_list_type, arg_name="size", mutable=False - ) - - self.assertEqual(out_name, "size_list_out") - self.assertIsInstance(ctype, et_types.ArrayRefCType) - # pyre-fixme[16]: - self.assertEqual( - ctype, et_types.ArrayRefCType(elem=BaseCType(aten_types.longT)) - ) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def _test_const_tensor_argument_translate_ctype( - self, unboxing: Unboxing, types: ModuleType - ) -> None: - # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - tensor_type = BaseType(BaseTy.Tensor) - - out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( - t=tensor_type, arg_name="self", mutable=False - ) - - self.assertEqual(out_name, "self_base") - # pyre-fixme[16]: - self.assertEqual(ctype, ConstRefCType(BaseCType(types.tensorT))) - - def test_const_tensor_argument_translate_ctype_aten(self) -> None: - self._test_const_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types) - - def test_const_tensor_argument_translate_ctype_executorch(self) -> None: - self._test_const_tensor_argument_translate_ctype(ET_UNBOXING, et_types) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def _test_mutable_tensor_argument_translate_ctype( - self, unboxing: Unboxing, types: ModuleType - ) -> None: - # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - tensor_type = BaseType(BaseTy.Tensor) - - out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( - t=tensor_type, arg_name="out", mutable=True - ) - - self.assertEqual(out_name, "out_base") - # pyre-fixme[16]: - self.assertEqual(ctype, MutRefCType(BaseCType(types.tensorT))) - - def test_mutable_tensor_argument_translate_ctype_aten(self) -> None: - self._test_mutable_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types) - - def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None: - self._test_mutable_tensor_argument_translate_ctype(ET_UNBOXING, et_types) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def _test_tensor_list_argument_translate_ctype( - self, unboxing: Unboxing, types: ModuleType - ) -> None: - # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - tensor_list_type = ListType(elem=BaseType(BaseTy.Tensor), size=None) - - out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( - t=tensor_list_type, arg_name="out", mutable=True - ) - - self.assertEqual(out_name, "out_list_out") - # pyre-fixme[16]: - self.assertEqual(ctype, BaseCType(types.tensorListT)) - - def test_tensor_list_argument_translate_ctype_aten(self) -> None: - self._test_tensor_list_argument_translate_ctype(ATEN_UNBOXING, aten_types) - - def test_tensor_list_argument_translate_ctype_executorch(self) -> None: - self._test_tensor_list_argument_translate_ctype(ET_UNBOXING, et_types) - - @local.parametrize( - use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False - ) - def _test_optional_int_argument_translate_ctype( - self, unboxing: Unboxing, types: ModuleType - ) -> None: - # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor` - # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided. - optional_int_type = OptionalType(elem=BaseType(BaseTy.int)) - - out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( - t=optional_int_type, arg_name="something", mutable=True - ) - - self.assertEqual(out_name, "something_opt_out") - # pyre-fixme[16]: - self.assertEqual(ctype, types.OptionalCType(BaseCType(types.longT))) - - def test_optional_int_argument_translate_ctype_aten(self) -> None: - self._test_optional_int_argument_translate_ctype(ATEN_UNBOXING, aten_types) - - def test_optional_int_argument_translate_ctype_executorch(self) -> None: - self._test_optional_int_argument_translate_ctype(ET_UNBOXING, et_types) diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py index 59e6e617072e1d..fac6ca6c8b50ba 100644 --- a/tools/test/test_selective_build.py +++ b/tools/test/test_selective_build.py @@ -298,45 +298,3 @@ def test_custom_namespace_selected_correctly(self) -> None: valid_tags=set(), ) self.assertTrue(selector.is_native_function_selected(native_function)) - - -class TestExecuTorchSelectiveBuild(unittest.TestCase): - def test_et_kernel_selected(self) -> None: - yaml_config = """ -et_kernel_metadata: - aten::add.out: - - "v1/6;0,1|6;0,1|6;0,1|6;0,1" - aten::sub.out: - - "v1/6;0,1|6;0,1|6;0,1|6;0,1" -""" - selector = SelectiveBuilder.from_yaml_str(yaml_config) - self.assertListEqual( - ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], - selector.et_get_selected_kernels( - "aten::add.out", - [ - "v1/6;0,1|6;0,1|6;0,1|6;0,1", - "v1/3;0,1|3;0,1|3;0,1|3;0,1", - "v1/6;1,0|6;0,1|6;0,1|6;0,1", - ], - ), - ) - self.assertListEqual( - ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], - selector.et_get_selected_kernels( - "aten::sub.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] - ), - ) - self.assertListEqual( - [], - selector.et_get_selected_kernels( - "aten::mul.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] - ), - ) - # We don't use version for now. - self.assertListEqual( - ["v2/6;0,1|6;0,1|6;0,1|6;0,1"], - selector.et_get_selected_kernels( - "aten::add.out", ["v2/6;0,1|6;0,1|6;0,1|6;0,1"] - ), - ) diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 614d036b45a9a8..28ff5bc3ff2929 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -104,7 +104,10 @@ def skip_test_p(name: str) -> bool: "distributed/test_c10d_spawn", "distributions/test_transforms", "distributions/test_utils", + "lazy/test_meta_kernel", + "lazy/test_extract_compiled_graph", "test/inductor/test_aot_inductor_utils", + "onnx/test_onnxscript_no_runtime", "onnx/test_pytorch_onnx_onnxruntime_cuda", "onnx/test_models", # These are not C++ tests diff --git a/tools/testing/target_determination/heuristics/__init__.py b/tools/testing/target_determination/heuristics/__init__.py index 1bd5940abbb1db..388b72425457d7 100644 --- a/tools/testing/target_determination/heuristics/__init__.py +++ b/tools/testing/target_determination/heuristics/__init__.py @@ -33,7 +33,7 @@ # All currently running heuristics. -# To add a heurstic in trial mode, specify the keywork argument `trial_mode=True`. +# To add a heurstic in trial mode, specify the keyword argument `trial_mode=True`. HEURISTICS: list[HeuristicInterface] = [ PreviouslyFailedInPR(), EditedByPR(), diff --git a/tools/testing/test_run.py b/tools/testing/test_run.py index 81bdfc4d7088e6..aa4efa6d890c67 100644 --- a/tools/testing/test_run.py +++ b/tools/testing/test_run.py @@ -285,7 +285,7 @@ def __lt__(self, other: object) -> bool: if not isinstance(other, ShardedTest): raise NotImplementedError - # This is how the list was implicity sorted when it was a NamedTuple + # This is how the list was implicitly sorted when it was a NamedTuple if self.name != other.name: return self.name < other.name if self.shard != other.shard: diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index a8b6d15fb39a4a..bcc5b221f30aa6 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -1,4 +1,5 @@ import glob +import gzip import os import time import zipfile @@ -9,6 +10,7 @@ REPO_ROOT = Path(__file__).resolve().parent.parent.parent LAST_UPDATED = 0.0 +LOG_BUCKET_PREFIX = "temp_logs" @lru_cache(maxsize=1) @@ -28,11 +30,27 @@ def zip_artifact(file_name: str, paths: list[str]) -> None: f.write(file, os.path.relpath(file, REPO_ROOT)) -def upload_to_s3_artifacts() -> None: +def concated_logs() -> str: + """Concatenate all the logs in the test-reports directory into a single string.""" + logs = [] + for log_file in glob.glob( + f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True + ): + logs.append(f"=== {log_file} ===") + with open(log_file) as f: + # For every line, prefix with fake timestamp for log classifier + for line in f: + line = line.rstrip("\n") # Remove any trailing newline + logs.append(f"2020-01-01T00:00:00.0000000Z {line}") + return "\n".join(logs) + + +def upload_to_s3_artifacts(failed: bool) -> None: """Upload the file to S3.""" workflow_id = os.environ.get("GITHUB_RUN_ID") workflow_run_attempt = os.environ.get("GITHUB_RUN_ATTEMPT") file_suffix = os.environ.get("ARTIFACTS_FILE_SUFFIX") + job_id = os.environ.get("JOB_ID") if not workflow_id or not workflow_run_attempt or not file_suffix: print( "GITHUB_RUN_ID, GITHUB_RUN_ATTEMPT, or ARTIFACTS_FILE_SUFFIX not set, not uploading" @@ -70,6 +88,18 @@ def upload_to_s3_artifacts() -> None: Bucket="gha-artifacts", Key=f"workflows_failing_pending_upload/{workflow_id}.txt", ) + if job_id and failed: + logs = concated_logs() + # Put logs into bucket so log classifier can access them. We cannot get + # the actual GH logs so this will have to be a proxy. + print(f"Uploading logs for {job_id} to S3") + get_s3_resource().put_object( + Body=gzip.compress(logs.encode("utf-8")), + Bucket="gha-artifacts", + Key=f"{LOG_BUCKET_PREFIX}/{job_id}", + ContentType="text/plain", + ContentEncoding="gzip", + ) def zip_and_upload_artifacts(failed: bool) -> None: @@ -81,7 +111,7 @@ def zip_and_upload_artifacts(failed: bool) -> None: if failed or time.time() - LAST_UPDATED > 20 * 60: start = time.time() try: - upload_to_s3_artifacts() + upload_to_s3_artifacts(failed=failed) LAST_UPDATED = time.time() except Exception as e: print(f"Failed to upload artifacts: {e}") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 0da483c09743f9..55bd03122eee42 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -2,7 +2,7 @@ # Now it only builds the Torch python bindings. if(NOT CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) - cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + cmake_minimum_required(VERSION 3.27 FATAL_ERROR) project(torch CXX C) find_package(torch REQUIRED) option(USE_CUDA "Use CUDA" ON) @@ -74,6 +74,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include ${TORCH_SRC_DIR}/lib + ${TORCH_SRC_DIR}/standalone ) list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${LIBSHM_SRCDIR}) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 35cbbeda39245c..a9810251d1f4e1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -16,12 +16,10 @@ from typing import ( Literal, NamedTuple, overload, - Protocol, - runtime_checkable, SupportsIndex, TypeVar, ) -from typing_extensions import ParamSpec, Self, TypeAlias +from typing_extensions import ParamSpec, Protocol, runtime_checkable, Self, TypeAlias import numpy @@ -178,7 +176,7 @@ class Size(tuple[_int, ...]): def __getitem__(self: Size, key: slice, /) -> Size: ... # Note: torch.Size does not support adding non-integer tuples. def __add__(self, other: tuple[_int, ...], /) -> Size: ... # type: ignore[override] - # Note: tuple[int, ...] + Size results in tuple[int, ...], not Size! + def __radd__(self: Size, other: tuple[_int, ...], /) -> Size: ... def __mul__(self, other: SupportsIndex, /) -> Size: ... def __rmul__(self, other: SupportsIndex, /) -> Size: ... def numel(self: Size, /) -> _int: ... @@ -452,13 +450,13 @@ ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]] # and torch/csrc/jit/python/init.cpp def _maybe_call_torch_function_for_op_packet( op_overload_packet: Any, - args: Any, - kwargs: Any, + *args: Any, + **kwargs: Any, ) -> Any: ... def _check_schema_allow_fake_script_object( schema: FunctionSchema, - args: Any, - kwargs: Any, + *args: Any, + **kwargs: Any, ) -> _bool: ... def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ... def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... @@ -1272,6 +1270,7 @@ def _set_sm_carveout_experimental(arg: _int | None) -> None: ... def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... +def _autocast_supported_devices() -> list[str]: ... def _meta_in_tls_dispatch_include() -> _bool: ... def _stash_obj_in_tls(key: str, arg: Any) -> None: ... def _get_obj_in_tls(key: str) -> Any: ... @@ -1297,10 +1296,12 @@ def _group_tensors_by_device_and_dtype( tuple[torch.device, torch.dtype], tuple[list[list[Tensor | None]], list[_int]], ]: ... +def _initCrashHandler() -> None: ... # NB: There is no Capsule type in typing, see # https://github.com/python/cpython/issues/109562 def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack +def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack def _get_cpp_backtrace( frames_to_skip: _int, @@ -1361,6 +1362,8 @@ def _disabled_torch_dispatch_impl( ) -> Any: ... # THPModule_disable_dispatch_function def _get_linalg_preferred_backend() -> _LinalgBackend: ... def _set_linalg_preferred_backend(arg: _LinalgBackend): ... +def _get_fp32_precision_getter(backend: str, op: str) -> str: ... +def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ... class _LinalgBackend: Default: _LinalgBackend @@ -1561,6 +1564,7 @@ class PyTorchFileReader: @overload def __init__(self, buffer: IO[bytes]) -> None: ... def get_record(self, name: str) -> bytes: ... + def get_all_records(self) -> list[str]: ... def serialization_id(self) -> str: ... class PyTorchFileWriter: @@ -1631,6 +1635,7 @@ class Generator: class _DispatchOperatorHandle: def schema(self) -> FunctionSchema: ... def debug(self) -> str: ... + def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ... class _DispatchModule: def reset(self) -> None: ... @@ -1827,7 +1832,7 @@ class _SetExcludeDispatchKeyGuard: # Defined in torch/csrc/utils/schema_info.h class _SchemaInfo: - def __init__(self, schema: _int) -> None: ... + def __init__(self, schema: FunctionSchema) -> None: ... @overload def is_mutable(self) -> _bool: ... @overload @@ -1921,6 +1926,7 @@ def _accelerator_hooks_get_current_device() -> _int: ... def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ... def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ... def _get_accelerator(check: _bool = False) -> _device: ... +def _storage_Use_Count(storage_ptr: _int) -> _int: ... # Defined in torch/csrc/mtia/Module.cpp def _mtia_init() -> None: ... @@ -2066,7 +2072,6 @@ def _construct_CUDA_Tensor_From_Storage_And_Metadata( metadata: dict, storage: Storage, ) -> Tensor: ... -def _storage_Use_Count(storage_ptr: _int) -> _int: ... def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ... def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ... def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ... @@ -2265,6 +2270,7 @@ class _CudaEventBase: enable_timing: _bool = False, blocking: _bool = False, interprocess: _bool = False, + external: _bool = False, ) -> Self: ... @classmethod def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ... @@ -2277,18 +2283,21 @@ class _CudaEventBase: # Defined in torch/csrc/cuda/Graph.cpp class _CUDAGraph: + def __new__(cls, keep_graph: _bool = ...) -> Self: ... def capture_begin( self, pool: tuple[_int, _int] | None = ..., capture_error_mode: str = "global", ) -> None: ... def capture_end(self) -> None: ... + def instantiate(self) -> None: ... def register_generator_state(self, Generator) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... def pool(self) -> tuple[_int, _int]: ... def enable_debug_mode(self) -> None: ... def debug_dump(self, debug_path: str) -> None: ... + def raw_cuda_graph(self) -> _int: ... # Defined in torch/csrc/cuda/MemPool.cpp class _MemPool: @@ -2297,10 +2306,13 @@ class _MemPool: allocator: _cuda_CUDAAllocator | None = None, is_user_created: _bool = True, use_on_oom: _bool = False, + symmetric: _bool = False, ) -> None: ... @property def id(self) -> tuple[_int, _int]: ... @property + def is_symmetric(self) -> _bool: ... + @property def allocator(self) -> _cuda_CUDAAllocator | None: ... def use_count(self) -> _int: ... @@ -2330,6 +2342,7 @@ class _XpuDeviceProperties: name: str platform_name: str vendor: str + device_id: _int driver_version: str version: str max_compute_units: _int @@ -2628,6 +2641,7 @@ def _will_engine_execute_node(node: _Node) -> _bool: ... def _dispatch_key_set(tensor) -> str: ... # Defined in torch/csrc/Exceptions.cpp +class AcceleratorError(RuntimeError): ... class OutOfMemoryError(RuntimeError): ... class _DistError(RuntimeError): ... class _DistBackendError(RuntimeError): ... diff --git a/torch/_C/_aoti.pyi b/torch/_C/_aoti.pyi index 2b60fd237b703d..2f57b5e5e72b1e 100644 --- a/torch/_C/_aoti.pyi +++ b/torch/_C/_aoti.pyi @@ -1,4 +1,5 @@ from ctypes import c_void_p +from typing import overload, Protocol from torch import Tensor @@ -16,10 +17,148 @@ def alloc_tensor_by_stealing_from_void_ptr( handle: c_void_p, ) -> Tensor: ... -class AOTIModelContainerRunnerCpu: ... -class AOTIModelContainerRunnerCuda: ... -class AOTIModelContainerRunnerXpu: ... -class AOTIModelContainerRunnerMps: ... +class AOTIModelContainerRunner(Protocol): + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerCpu: + def __init__(self, model_so_path: str, num_models: int) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerCuda: + @overload + def __init__(self, model_so_path: str, num_models: int) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str + ) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str + ) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerXpu: + @overload + def __init__(self, model_so_path: str, num_models: int) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str + ) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str + ) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerMps: + def __init__(self, model_so_path: str, num_models: int) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... # Defined in torch/csrc/inductor/aoti_package/pybind.cpp -class AOTIModelPackageLoader: ... +class AOTIModelPackageLoader: + def __init__( + self, + model_package_path: str, + model_name: str, + run_single_threaded: bool, + num_runners: int, + device_index: int, + ) -> None: ... + def get_metadata(self) -> dict[str, str]: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def boxed_run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_fqns(self) -> list[str]: ... + def load_constants( + self, + constants_map: dict[str, Tensor], + use_inactive: bool, + check_full_update: bool, + user_managed: bool = ..., + ) -> None: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 8c07061a8f89cc..b166b280df9da3 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -77,6 +77,7 @@ class _KinetoEvent: def cuda_elapsed_us(self) -> int: ... def privateuse1_elapsed_us(self) -> int: ... def is_user_annotation(self) -> bool: ... + def is_hidden_event(self) -> bool: ... class _ProfilerResult: def events(self) -> list[_KinetoEvent]: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 018229401f3fac..d145ed7ce653a7 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -186,6 +186,7 @@ class Store: def set(self, key: str, value: str): ... def get(self, key: str) -> bytes: ... def add(self, key: str, value: int) -> int: ... + def check(self, keys: list[str]) -> bool: ... def compare_set( self, key: str, @@ -318,6 +319,14 @@ class Backend: def _set_sequence_number_for_group(self) -> None: ... def _set_default_timeout(self, timeout: timedelta) -> None: ... def get_error(self) -> ErrorType: ... + def supports_tensor_alloc(self, device: torch.device) -> bool: ... + def allocate_tensor( + self, + size: int, + *, + dtype: torch.dtype, + device: torch.device, + ) -> Tensor: ... @property def mem_allocator(self) -> Any: ... @@ -564,6 +573,8 @@ class ProcessGroupGloo(Backend): class Options(Backend.Options): devices: list[ProcessGroupGloo.Device] threads: int + global_ranks_in_group: list[int] + group_name: str def __init__(self): ... @@ -579,6 +590,8 @@ class ProcessGroupGloo(Backend): @staticmethod def create_default_device(lazy_init=None) -> Device: ... def _set_default_timeout(self, timeout) -> None: ... + @property + def options(self) -> Options: ... # type: ignore[override] class _ProcessGroupWrapper(Backend): def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ... @@ -630,6 +643,10 @@ class ProcessGroupNCCL(Backend): def uid(self) -> int: ... @property def options(self) -> Options: ... # type: ignore[override] + @staticmethod + def get_build_nccl_version(self) -> tuple[int, int, int]: ... + @staticmethod + def get_runtime_nccl_version(self) -> tuple[int, int, int]: ... class ProcessGroupUCC(Backend): def __init__( @@ -683,6 +700,14 @@ def _allow_inflight_collective_as_graph_input() -> bool: ... def _unregister_all_process_groups() -> None: ... def _unregister_process_group(group_name: str) -> None: ... +# Initializes the device state in CUmodule so that it’s able to perform NVSHMEM +# operations. CUmodule is a pointer to a CUDA module, carried by a int64 in +# Python. At C++ interface, it is converted to a uintptr_t. +def _nvshmemx_cumodule_init(module: int) -> None: ... + +# Check if NVSHMEM is available on current system. +def _is_nvshmem_available() -> bool: ... + class _SymmetricMemory: @staticmethod def set_group_info( @@ -705,6 +730,11 @@ class _SymmetricMemory: device_type: DeviceType, device_idx: int, ) -> bool: ... + # Set Symmetric Memory allocation backend. + @staticmethod + def set_backend(name: str) -> None: ... + @staticmethod + def get_backend(device: torch.device) -> Optional[str]: ... @property def rank(self) -> int: ... @property diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index da0b3263775983..129984e6c10d3a 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,8 +1,13 @@ import enum import types -from typing import overload +from typing import Optional, overload -from torch._dynamo.types import DynamoCallback, DynamoGuardHook +from torch._dynamo.types import ( + DynamoCallback, + DynamoGuardCompleteHook, + DynamoGuardHook, + GuardFn, +) def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -13,6 +18,9 @@ def set_code_exec_strategy( code: types.CodeType, strategy: _FrameExecStrategy ) -> None: ... def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_guard_complete_hook( + hook: Optional[DynamoGuardCompleteHook], +) -> Optional[DynamoGuardCompleteHook]: ... def raise_sigtrap() -> None: ... class _CacheEntry: @@ -20,6 +28,9 @@ class _CacheEntry: code: types.CodeType next: _CacheEntry | None +class _PrecompileEntry: + guard_manager: GuardFn + class _ExtraState: def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... @@ -57,3 +68,8 @@ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... +def _load_precompile_entry( + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... +def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 5059968df49d79..c05345497e173d 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -100,7 +100,9 @@ class GuardManager: equals_val, verbose_code_parts: list[str], ) -> None: ... - def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... + def add_global_state_guard( + self, initial_state, verbose_code_parts: list[str] + ) -> None: ... def add_torch_function_mode_stack_guard( self, initial_stack, verbose_code_parts: list[str] ) -> None: ... @@ -176,6 +178,12 @@ def assert_size_stride( item: torch.Tensor, size: torch.types._size, stride: torch.types._size, + op_name: str | None = None, +): ... +def assert_alignment( + item: torch.Tensor, + alignment: int, + op_name: str | None = None, ): ... def check_obj_id(obj: object, expected: int) -> bool: ... def check_type_id(obj: object, expected: int) -> bool: ... diff --git a/torch/_C/_jit_tree_views.pyi b/torch/_C/_jit_tree_views.pyi new file mode 100644 index 00000000000000..cf4cffc05a9c34 --- /dev/null +++ b/torch/_C/_jit_tree_views.pyi @@ -0,0 +1,202 @@ +from typing import Any, Optional + +# Defined in torch/csrc/jit/python/python_tree_views.cpp + +class SourceRange: + def highlight(self) -> str: ... + @property + def start(self) -> int: ... + @property + def end(self) -> int: ... + +class SourceRangeFactory: + def __init__( + self, + text: str, + filename: Any, + file_lineno: int, + leading_whitespace_chars: int, + ) -> None: ... + def make_range(self, line: int, start_col: int, end_col: int) -> SourceRange: ... + def make_raw_range(self, start: int, end: int) -> SourceRange: ... + @property + def source(self) -> str: ... + +class TreeView: + def range(self) -> SourceRange: ... + def dump(self) -> None: ... + +class Ident(TreeView): + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @property + def name(self) -> str: ... + +class Param(TreeView): + def __init__(self, type: Optional[Any], name: Ident, kwarg_only: bool) -> None: ... + +class Attribute(TreeView): + def __init__(self, name: Ident, value: Any) -> None: ... + +# Literals +def TrueLiteral(range: SourceRange) -> Any: ... +def FalseLiteral(range: SourceRange) -> Any: ... +def NoneLiteral(range: SourceRange) -> Any: ... + +# Tree nodes +class Stmt(TreeView): + def __init__(self, thing: TreeView) -> None: ... + +class Expr(TreeView): ... + +class Def(TreeView): + def __init__(self, name: Ident, decl: Any, body: list[Stmt]) -> None: ... + def decl(self) -> Any: ... + def name(self) -> Ident: ... + +class Property(TreeView): + def __init__( + self, r: SourceRange, name: Ident, getter: Def, setter: Optional[Def] + ) -> None: ... + def name(self) -> Ident: ... + def getter_name(self) -> str: ... + def setter_name(self) -> Optional[Ident]: ... + +class ClassDef(TreeView): + def __init__( + self, name: Ident, body: list[Stmt], props: list[Property], assigns: list[Any] + ) -> None: ... + +class Decl(TreeView): + def __init__( + self, r: SourceRange, params: list[Param], return_type: Optional[Expr] + ) -> None: ... + +class Delete(Stmt): + def __init__(self, range: SourceRange, targets: list[Expr]) -> None: ... + +class WithItem(Expr): + def __init__( + self, range: SourceRange, target: Expr, var: Optional[Any] + ) -> None: ... + +class Assign(Stmt): + def __init__( + self, lhs: list[Expr], rhs: Expr, type: Optional[Expr] = None + ) -> None: ... + +class AugAssign(Stmt): + def __init__(self, lhs: Expr, kind_str: str, rhs: Expr) -> None: ... + +class Return(Stmt): + def __init__(self, range: SourceRange, value: Optional[Expr]) -> None: ... + +class Raise(Stmt): + def __init__(self, range: SourceRange, expr: Expr) -> None: ... + +class Assert(Stmt): + def __init__(self, range: SourceRange, test: Expr, msg: Optional[Expr]) -> None: ... + +class Pass(Stmt): + def __init__(self, range: SourceRange) -> None: ... + +class Break(Stmt): ... +class Continue(Stmt): ... + +class Dots(Expr, TreeView): + def __init__(self, range: SourceRange) -> None: ... + +class If(Stmt): + def __init__( + self, + range: SourceRange, + cond: Expr, + true_branch: list[Stmt], + false_branch: list[Stmt], + ) -> None: ... + +class While(Stmt): + def __init__(self, range: SourceRange, cond: Expr, body: list[Stmt]) -> None: ... + +class With(Stmt): + def __init__( + self, range: SourceRange, targets: list[WithItem], body: list[Stmt] + ) -> None: ... + +class For(Stmt): + def __init__( + self, + range: SourceRange, + targets: list[Expr], + itrs: list[Expr], + body: list[Stmt], + ) -> None: ... + +class ExprStmt(Stmt): + def __init__(self, expr: Expr) -> None: ... + +class Var(Expr): + def __init__(self, name: Ident) -> None: ... + @property + def name(self) -> str: ... + +class BinOp(Expr): + def __init__(self, kind: str, lhs: Expr, rhs: Expr) -> None: ... + +class UnaryOp(Expr): + def __init__(self, range: SourceRange, kind: str, expr: Expr) -> None: ... + +class Const(Expr): + def __init__(self, range: SourceRange, value: str) -> None: ... + +class StringLiteral(Expr): + def __init__(self, range: SourceRange, value: str) -> None: ... + +class Apply(Expr): + def __init__( + self, expr: Expr, args: list[Expr], kwargs: list[Attribute] + ) -> None: ... + +class Select(Expr): + def __init__(self, expr: Expr, field: Ident) -> None: ... + +class TernaryIf(Expr): + def __init__(self, cond: Expr, true_expr: Expr, false_expr: Expr) -> None: ... + +class ListComp(Expr): + def __init__( + self, range: SourceRange, elt: Expr, target: Expr, iter: Expr + ) -> None: ... + +class DictComp(Expr): + def __init__( + self, range: SourceRange, key: Expr, value: Expr, target: Expr, iter: Expr + ) -> None: ... + +class ListLiteral(Expr): + def __init__(self, range: SourceRange, args: list[Expr]) -> None: ... + +class TupleLiteral(Expr): + def __init__(self, range: SourceRange, args: list[Expr]) -> None: ... + +class DictLiteral(Expr): + def __init__( + self, range: SourceRange, keys: list[Expr], values: list[Expr] + ) -> None: ... + +class Subscript(Expr): + def __init__(self, base: Expr, subscript_exprs: list[Expr]) -> None: ... + +class SliceExpr(Expr): + def __init__( + self, + range: SourceRange, + lower: Optional[Expr], + upper: Optional[Expr], + step: Optional[Expr], + ) -> None: ... + +class Starred(Expr): + def __init__(self, range: SourceRange, expr: Expr) -> None: ... + +class EmptyTypeAnnotation(TreeView): + def __init__(self, range: SourceRange) -> None: ... diff --git a/torch/__init__.py b/torch/__init__.py index c66c1f139278a9..3f2a2c0d963125 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -32,7 +32,7 @@ TypeVar as _TypeVar, Union as _Union, ) -from typing_extensions import ParamSpec as _ParamSpec +from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs if TYPE_CHECKING: @@ -63,15 +63,7 @@ def _running_with_deploy() -> builtins.bool: # TODO(torch_deploy) figure out how to freeze version.py in fbcode build if _running_with_deploy(): __version__ = "torch-deploy-1.8" - # TODO: Remove this ugly hack when deploy typing extensions are updated to 4.10+ - if not TYPE_CHECKING: - import typing_extensions - - _TypeIs = typing_extensions.TypeGuard - typing_extensions.TypeIs = _TypeIs else: - from typing_extensions import TypeIs as _TypeIs - from torch.torch_version import __version__ as __version__ __all__ = [ @@ -155,6 +147,21 @@ def _running_with_deploy() -> builtins.bool: # Load the extension module ################################################################################ +# If PyTorch was built against the ROCm runtime wheels, then there will be +# a _rocm_init module and it will define an initialize() function which can +# prepare ROCm for use. See general documentation on ROCm runtime wheels: +# https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md +# Since this module is only ever added to the wheel if built for such a +# deployment, it is always safe to attempt. +try: + from . import _rocm_init # type: ignore[attr-defined] +except ImportError: + pass +else: + _rocm_init.initialize() + del _rocm_init + + if sys.platform == "win32": def _load_dll_libraries() -> None: @@ -1144,14 +1151,32 @@ def get_default_device() -> "torch.device": r"""Gets the default ``torch.Tensor`` to be allocated on ``device``""" global _GLOBAL_DEVICE_CONTEXT - if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): - device = _GLOBAL_DEVICE_CONTEXT.device_context.device + from torch.overrides import _get_current_function_mode_stack + from torch.utils._device import DeviceContext + + def _get_device_with_index(device): if device.index is not None: return device else: # TODO: Call like get_device_index() method corresponding to # each device type return torch.tensor([]).device + + # Get device from any active DeviceContext. + device_mode = next( + filter( + lambda mode: isinstance(mode, DeviceContext), + reversed(_get_current_function_mode_stack()), + ), + None, + ) + if device_mode: + device = device_mode.device + return _get_device_with_index(device) + + if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): + device = _GLOBAL_DEVICE_CONTEXT.device_context.device + return _get_device_with_index(device) else: return torch.device("cpu") @@ -2538,6 +2563,14 @@ def compile( - `trace.graph_diagram` which will show you a picture of your graph after fusion + - `guard_filter_fn` that controls which dynamo guards are saved with compilations. + This is an unsafe feature and there is no backward compatibility guarantee provided + for dynamo guards as data types. + For stable helper functions to use, see the documentations in `torch.compiler`, for example: + - `torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe` + - `torch.compiler.skip_guard_on_all_nn_modules_unsafe` + - `torch.compiler.keep_tensor_guards_unsafe` + - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()` disable (bool): Turn torch.compile() into a no-op for testing @@ -2597,10 +2630,6 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: if options and isinstance(options, dict): guard_filter_fn = options.pop("guard_filter_fn", None) - frame_traced_fn = None - if options and isinstance(options, dict): - frame_traced_fn = options.pop("frame_traced_fn", None) - if backend == "inductor": backend = _TorchCompileInductorWrapper(mode, options, dynamic) else: @@ -2612,7 +2641,6 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: dynamic=dynamic, disable=disable, guard_filter_fn=guard_filter_fn, - frame_traced_fn=frame_traced_fn, )(model) # type: ignore[return-value] diff --git a/torch/_compile.py b/torch/_compile.py index 05e63baa7c9628..33855b44b70572 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -45,7 +45,9 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T: if disable_fn is None: import torch._dynamo - disable_fn = torch._dynamo.disable(fn, recursive) + # We can safely turn off functools.wraps here because the inner + # already wraps fn in the outer scope. + disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False) fn.__dynamo_disable = disable_fn # type: ignore[attr-defined] return disable_fn(*args, **kwargs) diff --git a/torch/_custom_op/autograd.py b/torch/_custom_op/autograd.py index 35727197d03c1c..4f688164a001dc 100644 --- a/torch/_custom_op/autograd.py +++ b/torch/_custom_op/autograd.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs +import functools +from collections import namedtuple + import torch import torch.utils._pytree as pytree -from collections import namedtuple -import functools # NOTE [CustomOp autograd kernel indirection] @@ -19,19 +20,18 @@ def autograd_kernel_indirection(custom_op): autograd_fallback = autograd_not_implemented(custom_op) def inner(*args, **kwargs): - if custom_op._has_impl('autograd'): - kernel = custom_op._get_impl('autograd').func + if custom_op._has_impl("autograd"): + kernel = custom_op._get_impl("autograd").func return kernel(*args, **kwargs) # As explained in NOTE ["backward", "save_for_backward", and "autograd"], # after the user gives us "backward" and "save_for_backward", we generate # the "autograd" impl. If the user only provided one, then we tell # the user they've done something wrong. - if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'): + if custom_op._has_impl("save_for_backward") or custom_op._has_impl("backward"): missing = ( - 'save_for_backward' if custom_op._has_impl('backward') - else 'backward' + "save_for_backward" if custom_op._has_impl("backward") else "backward" ) - found = 'save_for_backward' if missing == 'backward' else 'backward' + found = "save_for_backward" if missing == "backward" else "backward" loc = custom_op._get_impl(found).location raise RuntimeError( f"We found a '{found}' registration for {custom_op} at " @@ -39,8 +39,10 @@ def inner(*args, **kwargs): f"To use the CustomOp API to register a backward formula, " f"please provide us both a backward function and a " f"'save for backward' function via `impl_backward` and " - f"`impl_save_for_backward` respectively.") + f"`impl_save_for_backward` respectively." + ) return autograd_fallback(*args, **kwargs) + return inner @@ -54,6 +56,7 @@ def kernel(*args, **kwargs): raise RuntimeError("Autograd has not been implemented for operator") with torch._C._AutoDispatchBelowAutograd(): return custom_op(*args, **kwargs) + return kernel @@ -70,7 +73,9 @@ def mark_non_differentiable(ctx, output, output_differentiability): tuple_output = output # type: ignore[assignment] assert len(output_differentiability) == len(tuple_output) non_differentiable_tensors = [] - for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)): + for idx, (differentiable, out) in enumerate( + zip(output_differentiability, tuple_output) + ): if isinstance(out, torch.Tensor): if not differentiable: non_differentiable_tensors.append(out) @@ -84,19 +89,20 @@ def mark_non_differentiable(ctx, output, output_differentiability): f"With output_differentiability={output_differentiability}. " f"At idx {idx}, we received an object of type {type(out)} that " f"is not a Tensor, so it cannot have be marked as differentiable in " - f"output_differentiability.") + f"output_differentiability." + ) if non_differentiable_tensors: ctx.mark_non_differentiable(*non_differentiable_tensors) def construct_autograd_kernel( - schema, - output_differentiability, - custom_op, - op_overload, - save_for_backward_fn, - backward_fn): - + schema, + output_differentiability, + custom_op, + op_overload, + save_for_backward_fn, + backward_fn, +): def apply(*args): flat_args, spec = pytree.tree_flatten(args) out_spec = None @@ -108,8 +114,7 @@ def forward(ctx, *flat_args): output = op_overload(*args) # We use the info about args to give better error messages in backward - args_info = namedtuple_args( - schema, pytree.tree_map(type, args)) + args_info = namedtuple_args(schema, pytree.tree_map(type, args)) save_for_backward_fn_inputs = namedtuple_args(schema, args) to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) @@ -138,11 +143,13 @@ def backward(ctx, *flat_grad_output): return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) generated_cls = gen_autograd_function( - custom_op._opname + '_customop', forward, backward) + custom_op._opname + "_customop", forward, backward + ) flat_output = generated_cls.apply(*flat_args) assert out_spec is not None return pytree.tree_unflatten(list(flat_output), out_spec) + return apply @@ -151,9 +158,9 @@ def gen_autograd_function(name, forward, backward): name, (torch.autograd.Function,), { - 'forward': staticmethod(forward), - 'backward': staticmethod(backward), - } + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, ) return generated_cls @@ -175,62 +182,82 @@ def namedtuple_args(schema, args): def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): def error(what): - backward = forward_op._get_impl('backward') + backward = forward_op._get_impl("backward") raise RuntimeError( f"In the backward function defined for {forward_op} at " - f"{backward.location} using the CustomOp API, {what}") + f"{backward.location} using the CustomOp API, {what}" + ) if not isinstance(grad_inputs_dict, dict): - error(f"expected the output of the backward function to be a dict but " - f"got {type(grad_inputs_dict)}") - - expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all - if arg.type.is_tensor_like()} + error( + f"expected the output of the backward function to be a dict but " + f"got {type(grad_inputs_dict)}" + ) + + expected_keys = { + arg.name + for arg in forward_op._schema.arguments.flat_all + if arg.type.is_tensor_like() + } actual_keys = grad_inputs_dict.keys() if expected_keys != actual_keys: - error(f"expected the returned grad_input dict to have keys " - f"{expected_keys} but got {actual_keys}. The backward " - f"function must return a gradient (can be None) for each arg " - f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " - f"Args declared to be non-Tensor-like types should not appear " - f"in the grad_input dict") + error( + f"expected the returned grad_input dict to have keys " + f"{expected_keys} but got {actual_keys}. The backward " + f"function must return a gradient (can be None) for each arg " + f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " + f"Args declared to be non-Tensor-like types should not appear " + f"in the grad_input dict" + ) for name, grad in grad_inputs_dict.items(): arg_info = getattr(args_info, name) if isinstance(arg_info, list): if not isinstance(grad, (tuple, list)): - error(f"for input '{name}' expected the grad_input dict to " - f"hold a list of gradients but got object of type " - f"{type(grad)}.") + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of gradients but got object of type " + f"{type(grad)}." + ) if not len(grad) == len(arg_info): - error(f"for input '{name}' expected the grad_input dict to " - f"hold a list of {len(arg_info)} gradients but got " - f"{len(grad)}") + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of {len(arg_info)} gradients but got " + f"{len(grad)}" + ) for idx, (g, info) in enumerate(zip(grad, arg_info)): if g is None: continue if not isinstance(g, torch.Tensor): - error(f"for input '{name}' expected the grad_input dict to " - f"hold a list of None or Tensor gradients but got " - f"object of {type(g)} at index {idx}") + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of None or Tensor gradients but got " + f"object of {type(g)} at index {idx}" + ) if not issubclass(info, torch.Tensor): - error(f"for input '{name}', got a Tensor as the gradient " - f"for the {idx}-th value but expected None because " - f"the {idx}-th value was not a Tensor (it was " - f"type {arg_info}") + error( + f"for input '{name}', got a Tensor as the gradient " + f"for the {idx}-th value but expected None because " + f"the {idx}-th value was not a Tensor (it was " + f"type {arg_info}" + ) continue if grad is None: continue if not isinstance(grad, torch.Tensor): - error(f"got object of type {type(grad)} as the gradient for input " - f"'{name}', " - f"but expected the gradient to be either None or a Tensor") + error( + f"got object of type {type(grad)} as the gradient for input " + f"'{name}', " + f"but expected the gradient to be either None or a Tensor" + ) if not issubclass(arg_info, torch.Tensor): - error(f"got a Tensor as the gradient for input '{name}' but " - f"expected None as the gradient because input '{name}' " - f"was not a Tensor (it was type {arg_info}).") + error( + f"got a Tensor as the gradient for input '{name}' but " + f"expected None as the gradient because input '{name}' " + f"was not a Tensor (it was type {arg_info})." + ) def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): @@ -242,6 +269,7 @@ def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): result.append(grad_inputs_dict[name]) return tuple(pytree.tree_leaves(result)) + # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. # autograd.Function prefers that users use ctx.save_for_backward to # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the @@ -249,10 +277,14 @@ def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): def save_pytree_for_backward(ctx, stuff): flat_stuff, spec = pytree.tree_flatten(stuff) num_elts = len(flat_stuff) - tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) - if isinstance(thing, torch.Tensor)] - non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) - if not isinstance(thing, torch.Tensor)] + tensor_idxs = [ + idx for idx, thing in enumerate(flat_stuff) if isinstance(thing, torch.Tensor) + ] + non_tensor_idxs = [ + idx + for idx, thing in enumerate(flat_stuff) + if not isinstance(thing, torch.Tensor) + ] tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index ffa7ded27dbc8e..dd3e9e8fa2dd1b 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -4,19 +4,26 @@ import inspect import sys import typing -import weakref import warnings - -from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy +import weakref import torch import torch._C as _C +import torch._library.infer_schema import torch.library as library +from torch._library.infer_schema import infer_schema from torch.library import get_ctx +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + SchemaKind, +) from .autograd import autograd_kernel_indirection, construct_autograd_kernel -import torch._library.infer_schema -from torch._library.infer_schema import infer_schema + """ torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library. @@ -42,10 +49,13 @@ "pytorch", } + def warn_deprecated(): warnings.warn( "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please " - "use the equivalent torch.library API instead.", DeprecationWarning) + "use the equivalent torch.library API instead.", + DeprecationWarning, + ) def custom_op( @@ -73,7 +83,11 @@ def inner(func): f"is passed to `custom_op`" ) - schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema + schema = ( + infer_schema(func, mutates_args=()) + if manual_schema is None + else manual_schema + ) schema_str = f"{name}{schema}" function_schema = FunctionSchema.parse(schema_str) validate_schema(function_schema) @@ -83,7 +97,9 @@ def inner(func): lib = library.Library(ns, "FRAGMENT") lib.define(schema_str) ophandle = find_ophandle_or_throw(ns, function_schema.name) - result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + result = CustomOp( + lib, ns, function_schema, name, ophandle, _private_access=True + ) result.__name__ = func.__name__ result.__module__ = func.__module__ @@ -116,7 +132,9 @@ class CustomOp: This API is deprecated, please use torch.library.custom_op instead """ - def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False): + def __init__( + self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False + ): super().__init__() warn_deprecated() if not _private_access: @@ -144,7 +162,9 @@ def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_acc def _register_autograd_kernel_indirection(self): assert not self._registered_autograd_kernel_indirection - self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd") + self._lib.impl( + self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd" + ) self._registered_autograd_kernel_indirection = True # Records the impl and the source location in self._impls @@ -196,7 +216,9 @@ def __call__(self, *args, **kwargs): return result def impl( - self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2, + self, + device_types: typing.Union[str, typing.Iterable[str]], + _stacklevel=2, ) -> typing.Callable: r""" This API is deprecated, please use torch.library.custom_op instead @@ -224,7 +246,8 @@ def _check_doesnt_have_library_impl(self, device_type): raise RuntimeError( f"impl(..., device_types={device_type}): the operator {self._qualname} " f"already has an implementation for this device type via a " - f"pre-existing torch.library or TORCH_LIBRARY registration.") + f"pre-existing torch.library or TORCH_LIBRARY registration." + ) def impl_factory(self) -> typing.Callable: r"""Register an implementation for a factory function.""" @@ -306,20 +329,25 @@ def error(detail): for ret in schema.returns: if ret.type in allowed_return_types: continue - error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})") + error( + f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})" + ) def _check_doesnt_have_library_autograd_impl(self): if self._registered_autograd_kernel_indirection: return - if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeImplicitAutograd" + ): raise RuntimeError( f"impl_backward/impl_save_for_backward: the operator {self._qualname} " f"already has an implementation for this device type via a " f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." f"CompositeImplicitAutograd operators do not need an autograd formula; " f"instead, the operator will decompose into its constituents and those " - f"can have autograd formulas defined on them.") + f"can have autograd formulas defined on them." + ) # We can improve this by adding "all Autograd keys", but # realistically people will just be using this API for CPU/CUDA for now. @@ -330,7 +358,8 @@ def _check_doesnt_have_library_autograd_impl(self): f"the operator {self._qualname} already has an Autograd kernel " f"registered to DispatchKey::{key} vi a pre-existing " f"torch.library or TORCH_LIBRARY registration. Please either " - f"remove those registrations or don't use the torch._custom_ops APIs") + f"remove those registrations or don't use the torch._custom_ops APIs" + ) def _check_doesnt_have_library_meta_impl(self): if self._has_impl("abstract"): @@ -341,10 +370,9 @@ def _check_doesnt_have_library_meta_impl(self): # (existing custom ops may have CompositeExplicitAutograd # registration that don't work with Meta kernels, so this # gives them an escape hatch). - if ( - _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd") - and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta") - ): + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeExplicitAutograd" + ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): return # Otherwise, if the user's already has a Meta kernel or their @@ -352,21 +380,25 @@ def _check_doesnt_have_library_meta_impl(self): # raise. # Special case for CompositeImplicitAutograd - if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeImplicitAutograd" + ): raise RuntimeError( f"impl_abstract(...): the operator {self._qualname} " f"already has an implementation for this device type via a " f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." f"CompositeImplicitAutograd operators do not need an abstract impl; " f"instead, the operator will decompose into its constituents and those " - f"can have abstract impls defined on them.") + f"can have abstract impls defined on them." + ) if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): raise RuntimeError( f"impl_abstract(...): the operator {self._qualname} " f"already has an DispatchKey::Meta implementation via a " f"pre-existing torch.library or TORCH_LIBRARY registration. " - f"Please either remove that registration or don't call impl_abstract.") + f"Please either remove that registration or don't call impl_abstract." + ) # NOTE ["backward", "save_for_backward", and "autograd"] # As a part of the explicit autograd API, a user must provide us @@ -382,7 +414,8 @@ def _register_autograd_kernel(self): self, get_op(self._qualname), self._get_impl("save_for_backward").func, - self._get_impl("backward").func) + self._get_impl("backward").func, + ) self._register_impl("autograd", kernel) def impl_save_for_backward(self, _stacklevel=2): @@ -390,6 +423,7 @@ def impl_save_for_backward(self, _stacklevel=2): Please see impl_backward for more details. """ + def inner(f): self._check_can_register_backward() self._check_doesnt_have_library_autograd_impl() @@ -398,6 +432,7 @@ def inner(f): self._register_impl("save_for_backward", f, stacklevel=_stacklevel) if self._has_impl("backward"): self._register_autograd_kernel() + return inner def impl_backward(self, output_differentiability=None, _stacklevel=2): @@ -405,12 +440,14 @@ def impl_backward(self, output_differentiability=None, _stacklevel=2): This API is deprecated, please use torch.library.custom_op instead """ if output_differentiability is not None: + def yell(): raise RuntimeError( f"impl_backward(output_differentiability): expected " f"output_differentiability to be a list of bools with " f"length equal to the number of outputs of this CustomOp " - f"got: {output_differentiability}") + f"got: {output_differentiability}" + ) if not isinstance(output_differentiability, list): yell() @@ -429,6 +466,7 @@ def inner(f): self._output_differentiability = output_differentiability if self._has_impl("save_for_backward"): self._register_autograd_kernel() + return inner @@ -459,6 +497,7 @@ def validate_namespace(ns: str) -> None: f"please choose something else. " ) + def validate_schema(schema: FunctionSchema) -> None: if not torch._library.utils.is_functional_schema(schema): raise ValueError( @@ -479,13 +518,17 @@ def validate_schema(schema: FunctionSchema) -> None: def parse_qualname(qualname: str) -> tuple[str, str]: names = qualname.split("::", 1) if len(names) != 2: - raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The " - f"operator name should look something like ns::foo") - if '.' in names[1]: - raise ValueError(f"The torch.custom_ops APIs do not handle overloads, " - f"i.e. operator names with '.' in them. " - f"Please name your operator something like ns::foo. " - f"Got: {qualname}") + raise ValueError( + f"Expected there to be a namespace in {qualname}, i.e. The " + f"operator name should look something like ns::foo" + ) + if "." in names[1]: + raise ValueError( + f"The torch.custom_ops APIs do not handle overloads, " + f"i.e. operator names with '.' in them. " + f"Please name your operator something like ns::foo. " + f"Got: {qualname}" + ) return names[0], names[1] @@ -615,7 +658,8 @@ def error_not_found(): raise ValueError( f"Could not find the operator {qualname}. Please make sure you have " f"already registered the operator and (if registered from C++) " - f"loaded it via torch.ops.load_library.") + f"loaded it via torch.ops.load_library." + ) ns, name = parse_qualname(qualname) if not hasattr(torch.ops, ns): @@ -624,7 +668,7 @@ def error_not_found(): if not hasattr(opnamespace, name): error_not_found() packet = getattr(opnamespace, name) - if not hasattr(packet, 'default'): + if not hasattr(packet, "default"): error_not_found() return packet.default @@ -635,7 +679,8 @@ def _find_custom_op(qualname, also_check_torch_library=False): if not also_check_torch_library: raise RuntimeError( f'Could not find custom op "{qualname}". Did you register it via ' - f"the torch._custom_ops API?") + f"the torch._custom_ops API?" + ) overload = get_op(qualname) result = custom_op_from_existing(overload) return result diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 04abcc6c2d4cef..0ff7e46f839ba0 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -750,11 +750,11 @@ def slice_forward( elif guard_size_oblivious(start_val > sizes[dim]): start_val = sizes[dim] - if guard_size_oblivious(end_val < start_val): + if statically_known_true(end_val == sys.maxsize): + end_val = sizes[dim] + elif guard_size_oblivious(end_val < start_val): end_val = start_val - elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious( - end_val > sizes[dim] - ): + elif guard_size_oblivious(end_val > sizes[dim]): end_val = sizes[dim] storage_offset = self.storage_offset() + start_val * strides[dim] @@ -814,7 +814,7 @@ def slice_scatter( if start == 0 and end == dim_size and step == 1: return src.clone() - indices = [None] * input.dim() + indices: list[Optional[Tensor]] = [None] * input.dim() idx = torch.arange(dim_size, device=input.device) indices[dim] = (idx - start) // step @@ -1677,6 +1677,7 @@ def native_layer_norm_backward( ) mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + assert input_cast is not None x_hat = (input_cast - mean) * rstd if weight_cast is not None: grad_x_hat = grad_out_cast * weight_cast diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index 1e03146bbcc362..a4103eb8387dcf 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -3,12 +3,15 @@ import unittest.mock from collections.abc import Iterator from contextlib import contextmanager +from typing import Callable, TypeVar, Union +from typing_extensions import ParamSpec import torch import torch._C import torch._ops import torch.utils._python_dispatch import torch.utils._pytree as pytree +from torch._C import DispatchKey __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] @@ -19,6 +22,9 @@ CROSSREF_FUNCTIONALIZE = False +_P = ParamSpec("_P") +_T = TypeVar("_T") + def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: """ @@ -103,14 +109,16 @@ def _fmt(a: object) -> object: return a -def make_crossref_functionalize(op, final_key): +def make_crossref_functionalize( + op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey +) -> Union[Callable[_P, _T], DispatchKey]: from torch._subclasses.fake_tensor import FakeTensorMode # This case is pretty weird, suppress it for now if op == torch.ops.aten.lift_fresh.default: return final_key - def handler(*args, **kwargs): + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: fake_mode = FakeTensorMode() def fakeify_defun(t): diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 578a36a0b0160d..ec6d522ae1d435 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -30,6 +30,7 @@ nonstrict_trace, patch_dynamo_config, run, + set_fullgraph, set_stance, skip_frame, substitute_in_graph, @@ -59,11 +60,17 @@ __all__ = [ "allow_in_graph", "assume_constant_result", + "config", + "disable", "disallow_in_graph", "dont_skip_tracing", + "export", + "explain", "forbid_in_graph", - "substitute_in_graph", "graph_break", + "is_compiling", + "list_backends", + "lookup_backend", "mark_dynamic", "maybe_mark_dynamic", "mark_static", @@ -71,21 +78,16 @@ "nonstrict_trace", "optimize", "optimize_assert", + "OptimizedModule", "patch_dynamo_config", - "skip_frame", - "export", - "explain", - "run", + "register_backend", "replay", - "disable", - "set_stance", "reset", - "OptimizedModule", - "is_compiling", - "register_backend", - "list_backends", - "lookup_backend", - "config", + "run", + "set_fullgraph", + "set_stance", + "skip_frame", + "substitute_in_graph", ] # allowlist this for weights_only load of NJTs diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 278c89c83a03c9..8fab0b20054911 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -50,7 +50,7 @@ if not torch._running_with_deploy(): - # torch.library.custom_op does not work with torch.deploy/multipy + # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] def zeros_and_scatter( diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index 246596bcbcabed..167f678b6a208a 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -68,7 +68,10 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs): def wrap_bw_compiler(bw_compiler_fn): def _wrapped_bw_compiler(*args, **kwargs): - # stop TorchDynamo from trying to compile our generated backwards pass + # Note [Wrapping bw_compiler in disable] + # The two disables here: + # - stop TorchDynamo from trying to compile the bw_compiler function itself + # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces return disable( disable( bw_compiler_fn, reason="do not trace backward compiler function" diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 490185b5d42639..cded5b005ee3c4 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -56,7 +56,7 @@ def make_eager_backend_with_torch_function_mode(mode): def make_eager_backend_with_torch_function_modes(modes): - """Used to trace HOPs (cond and while) for eager exectution, the metadata + """Used to trace HOPs (cond and while) for eager execution, the metadata TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks in the HOP, so we need to externally run this mode and not trace it.""" from contextlib import ExitStack diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index df36dd7d0efe55..6e54fae7e089e6 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -146,6 +146,26 @@ def has_higher_order_op(gm): return False +def propagate_metadata(orig_gm, split_gm) -> None: + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384 + module.meta = orig_gm.meta + module._param_name_to_source = orig_gm._param_name_to_source + + +def propagate_dynamo_source(orig_gm, split_gm) -> None: + name_to_dynamo_source = {} + for node in orig_gm.graph.find_nodes(op="placeholder"): + name_to_dynamo_source[node.name] = node._dynamo_source + + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + for node in module.graph.find_nodes(op="placeholder"): + # non-placeholder in original_gm may become placeholder in submodules + node._dynamo_source = name_to_dynamo_source.get(node.name, None) + + # compile each of the partitioned submodules using the user-provided compiler class SubmodCompiler(torch.fx.interpreter.Interpreter): def __init__(self, module, compiler, fake_mode) -> None: @@ -516,6 +536,10 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]): gm, None, lambda node: partition_map[node] ) + # See note [Assumption on Dynamo Metadata] + propagate_dynamo_source(gm, split_gm) + propagate_metadata(gm, split_gm) + debug_str = ( f"\n---orig graph---\n{gm.graph}\n" + f"\n---split graph---\n{split_gm.graph}\n" diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 01381aa66b80e8..79376b0e460bf8 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -154,7 +154,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: return sorted(backends) -@functools.lru_cache(None) +@functools.cache def _lazy_import(): from .. import backends from ..utils import import_submodule @@ -168,7 +168,7 @@ def _lazy_import(): _discover_entrypoint_backends() -@functools.lru_cache(None) +@functools.cache def _discover_entrypoint_backends(): # importing here so it will pick up the mocked version in test_backends.py from importlib.metadata import entry_points diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index 3a5b239183f3da..ab0097e314ca95 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -201,7 +201,7 @@ def has_tvm(): return False -@functools.lru_cache(None) +@functools.cache def llvm_target(): if sys.platform == "linux": cpuinfo = open("/proc/cpuinfo").read() diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 0164287b5734c0..9226a61577d87f 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1649,7 +1649,7 @@ def template(): # replace returns with jumps for inst in returns: # don't replace inst with new instruction - # due to targetting/exn table/etc. + # due to targeting/exn table/etc. jump_inst = create_jump_absolute(insts[-1]) inst.opname = jump_inst.opname inst.opcode = jump_inst.opcode diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index f88f8dab7d62e4..cff7ea3fef334b 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -86,7 +86,7 @@ class CacheSizeRelevantForFrame: num_cache_entries_with_same_id_matched_objs: int = 0 def will_compilation_exceed(self, limit: int) -> bool: - # Checks if a compilation will exceed the given limit (thats why >=). + # Checks if a compilation will exceed the given limit (that's why >=). return ( self.will_compilation_exceed_accumulated_limit() or self.will_compilation_exceed_specific_limit(limit) diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index fa936d4012718a..58cfe66baee7ab 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -25,6 +25,7 @@ def my_end_callback(): print("Compilation complete") """ +import enum import threading from collections.abc import Generator from contextlib import contextmanager @@ -32,10 +33,27 @@ def my_end_callback(): from typing import Any, Callable +class CallbackTrigger(enum.Enum): + # most common case, dynamo attempts to trace a new frame + DYNAMO = 1 + # backward compilation can be deferred to runtime + LAZY_BACKWARD = 2 + # some backends autotune at runtime + TRITON_AUTOTUNING = 3 # Temporarily disabled due to spam + # cudagraphs record at runtime + CUDAGRAPH_RECORDING = 4 + + +@dataclass +class CallbackArgs: + callback_trigger: CallbackTrigger + compile_id: str + + @dataclass class CompilationCallbackHandler: - start_callbacks: list[Callable[[], None]] = field(default_factory=list) - end_callbacks: list[Callable[[], None]] = field(default_factory=list) + start_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list) + end_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list) __pending_callbacks_counter: int = field(default=0, init=False, repr=False) __pending_callbacks_counter_lock: threading.Lock = field( @@ -43,8 +61,8 @@ class CompilationCallbackHandler: ) def register_start_callback( - self, callback: Callable[[], None] - ) -> Callable[[], None]: + self, callback: Callable[[CallbackArgs], None] + ) -> Callable[[CallbackArgs], None]: """ Register a callback function to be called when the compilation starts. @@ -54,7 +72,9 @@ def register_start_callback( self.start_callbacks.append(callback) return callback - def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], None]: + def register_end_callback( + self, callback: Callable[[CallbackArgs], None] + ) -> Callable[[CallbackArgs], None]: """ Register a callback function to be called when the compilation ends. @@ -64,7 +84,7 @@ def register_end_callback(self, callback: Callable[[], None]) -> Callable[[], No self.end_callbacks.append(callback) return callback - def remove_start_callback(self, callback: Callable[[], None]) -> None: + def remove_start_callback(self, callback: Callable[[CallbackArgs], None]) -> None: """ Remove a registered start callback function. @@ -73,7 +93,7 @@ def remove_start_callback(self, callback: Callable[[], None]) -> None: """ self.start_callbacks.remove(callback) - def remove_end_callback(self, callback: Callable[[], None]) -> None: + def remove_end_callback(self, callback: Callable[[CallbackArgs], None]) -> None: """ Remove a registered end callback function. @@ -82,30 +102,33 @@ def remove_end_callback(self, callback: Callable[[], None]) -> None: """ self.end_callbacks.remove(callback) - def run_start_callbacks(self) -> None: + def run_start_callbacks(self, args: CallbackArgs) -> None: """ Execute all registered start callbacks. """ for callback in self.start_callbacks: - callback() + callback(args) - def run_end_callbacks(self) -> None: + def run_end_callbacks(self, args: CallbackArgs) -> None: """ Execute all registered end callbacks. """ for callback in self.end_callbacks: - callback() + callback(args) @contextmanager - def install_callbacks(self) -> Generator[None, Any, Any]: + def install_callbacks( + self, trigger: CallbackTrigger, compile_id: str + ) -> Generator[None, Any, Any]: """ Context manager to install the callbacks and run them when the context is exited. """ + args = CallbackArgs(trigger, compile_id) try: with self.__pending_callbacks_counter_lock: - if self.__pending_callbacks_counter == 0: - self.run_start_callbacks() self.__pending_callbacks_counter += 1 + if self.__pending_callbacks_counter == 1: + self.run_start_callbacks(args) yield finally: with self.__pending_callbacks_counter_lock: @@ -113,7 +136,7 @@ def install_callbacks(self) -> Generator[None, Any, Any]: "Pending callbacks counter cannot become negative." ) if self.__pending_callbacks_counter == 1: - self.run_end_callbacks() + self.run_end_callbacks(args) self.__pending_callbacks_counter -= 1 def clear(self) -> None: @@ -122,12 +145,15 @@ def clear(self) -> None: """ self.start_callbacks.clear() self.end_callbacks.clear() + assert self.__pending_callbacks_counter == 0 callback_handler = CompilationCallbackHandler() -def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]: +def on_compile_start( + callback: Callable[[CallbackArgs], None], +) -> Callable[[CallbackArgs], None]: """ Decorator to register a callback function for the start of the compilation. """ @@ -135,7 +161,9 @@ def on_compile_start(callback: Callable[[], None]) -> Callable[[], None]: return callback -def on_compile_end(callback: Callable[[], None]) -> Callable[[], None]: +def on_compile_end( + callback: Callable[[CallbackArgs], None], +) -> Callable[[CallbackArgs], None]: """ Decorator to register a callback function for the end of the compilation. """ diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index ec29cba0dfc400..a0bebb053d6e03 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -23,7 +23,7 @@ import torch.nn from torch.utils._ordered_set import OrderedSet -from . import graph_break_hints, utils +from . import config, graph_break_hints, utils from .bytecode_transformation import ( add_push_null, add_push_null_call_function_ex, @@ -79,7 +79,7 @@ def __init__( ) -> None: self.root = root self.top_of_stack: Optional[Union[VariableTracker, Source]] = None - self.uses: Counter[VariableTracker] = collections.Counter() + self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter() self.graph_outputs: dict[int, GraphOutputEntry] = {} self._output: list[Instruction] = [] # This determines which VariableTracker/Source should be stored as @@ -181,9 +181,9 @@ def __call__(self, value, allow_cache=True): Notable effects: 1. `self.top_of_stack` will be set to `value`, if we don't codegen `value` based on source. - 2. `self.uses[value]` will increment, if we don't codegen `value` based - on source or cache/top-of-stack reuse; in other words, if we codegen - as if `value` is modelling some brand new python value. + 2. `self.uses[value]` will increment, unless (a). we codegen via + `top_of_stack` or cached `tempvars`, or (b). `value` has special VT + types like `NNModuleVariable`, etc. """ if isinstance(value, Source): # If the source needs to be overridden, use the new one. @@ -198,6 +198,7 @@ def __call__(self, value, allow_cache=True): self.top_of_stack = source return + self.uses[source] += 1 try: self.call_reconstruct(source) except NotImplementedError: @@ -207,9 +208,9 @@ def __call__(self, value, allow_cache=True): explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.", hints=[*graph_break_hints.DYNAMO_BUG], ) - - self._output.append(create_dup_top()) - self.add_cache(source) + if source in self.tempvars: + self._output.append(create_dup_top()) + self.add_cache(source) self.top_of_stack = source return @@ -252,7 +253,7 @@ def __call__(self, value, allow_cache=True): # above, export _wants to_ obtain an identity FX graph (despite it # appears unnecessarily expensive for `torch.compile`), so we have # the following option to override Dynamo's preference for codegen - # from source. Morever, this option applies recursively, for cases + # from source. Moreover, this option applies recursively, for cases # like input tensor being returned in a new dictionary. # # And why the `ValueMutationExisting` check? Not sure, so leaving it @@ -590,7 +591,7 @@ def make_call_generated_code(self, fn_name: str) -> None: def collect_temp_source(source): if source in seen_sources: - # This source is used atleast twice, so it can be reused + # This source is used at least twice, so it can be reused self.mark_source_temp(source) # Dont trace source further. This prevents us from marking too # many nodes as temp sources. @@ -613,6 +614,18 @@ def collect_temp_source(source): if arg.source is not None: collect_temp_source(arg.source) + cm_var = None + if config.record_runtime_overhead: + # Record the pregraph bytecode start + self.add_push_null( + lambda: self.load_import_from( + utils.__name__, "record_pregraph_bytecode_enter" + ) + ) + self.extend_output(create_call_function(0, False)) + cm_var = self.new_var() + self.store(cm_var) + for arg in graphargs: if arg.pass_arg_as_tensor: self.add_push_null( @@ -628,6 +641,18 @@ def collect_temp_source(source): else: self.call_reconstruct(arg) + if config.record_runtime_overhead: + # Record the pregraph bytecode end + self.add_push_null( + lambda: self.load_import_from( + utils.__name__, "record_pregraph_bytecode_exit" + ) + ) + assert cm_var is not None + self.extend_output([self.create_load(cm_var)]) + self.extend_output(create_call_function(1, False)) + self.pop_top() + self.extend_output(create_call_function(len(graphargs), False)) def load_import_from(self, module_name, object_name) -> None: diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index ccad684bbd0374..e52fb5026cb985 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -27,6 +27,7 @@ import torch import torch.utils._pytree as pytree from torch._dynamo.external_utils import ( + call_accumulate_grad, call_backward, call_hook, FakeCompiledAutogradEngine, @@ -39,6 +40,10 @@ lazy_format_graph_code, set_locals_to_steal, ) +from torch._functorch._aot_autograd.runtime_wrappers import ( + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, +) from torch._guards import compile_context, CompileContext, CompileId from torch._logging import getArtifactLogger, trace_structured from torch._prims_common import clone_preserve_strides @@ -64,6 +69,12 @@ from torch.fx.proxy import Proxy +TURN_OFF_MSG = """You can turn off compiled autograd by either: +1. Moving the unsupported autograd call outside of the torch.compile'd region. +2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager. +3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call. +4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program.""" + compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") @@ -84,6 +95,22 @@ def maybe_clone(x): return x +def extract_bw_module(CompiledFunction): + if isinstance( + CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo + ): + return CompiledFunction._lazy_backward_info.bw_module + elif isinstance( + CompiledFunction._lazy_backward_info, CachedAutogradLazyBackwardCompileInfo + ): + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + return CompiledFunction._lazy_backward_info.bw_module_fn() + else: + raise AssertionError( + "Unexpected Lazy Backward Compilation Info Type. Please file an issue." + ) + + # Note: [Anomaly Mode Semantics in Compiled Autograd] # In the eager autograd engine, anomaly mode is able to detect NaNs # after each node. This is useful, because the executed code with @@ -105,7 +132,7 @@ def __init__(self, accumulate_grad: bool): def prep_with_graph(self, graph: torch.fx.Graph): inputs_node = next(iter(graph.nodes)) acc_grad_nodes = graph.find_nodes( - op="call_function", target=torch.ops.inductor.accumulate_grad_.default + op="call_function", target=call_accumulate_grad ) output_nodes = graph.find_nodes(op="output")[0].args[0] assert self.accumulate_grad == bool( @@ -137,7 +164,7 @@ def prep_with_inputs(self, inputs: tuple[torch.Tensor]): if grad is not None: assert not torch.isnan(grad).any(), ( f"Compiled autograd running under anomaly mode with inputs[{idx}] already " - "having NaN gradient. This is not supported." + "having NaN gradient. This is not supported. {TURN_OFF_MSG}" ) self.params_to_check[f"inputs[{idx}]"] = inputs[idx] @@ -226,7 +253,7 @@ def __repr__(self): call_hook, call_backward, FakeCompiledAutogradEngine._exec_final_callbacks_stub, - torch.ops.inductor.accumulate_grad_.default, + call_accumulate_grad, ] ) @@ -307,11 +334,16 @@ def begin_capture( self.stack.enter_context(preserve_node_meta()) inputs_origins, sizes_origins, scalars_origins = origins + # tensor inputs to fake tensors - inputs = [ - self.wrap_fake(x, self.source("inputs", idx)) - for idx, x in enumerate(inputs) - ] + x = inputs[0] # mypy will complain about unbound x + try: + for idx, x in enumerate(inputs): + inputs[idx] = self.wrap_fake(x, self.source("inputs", idx)) + except Exception as e: + raise NotImplementedError( + f"Found tensor of type {type(x)}, which is not supported by FakeTensorMode. {TURN_OFF_MSG}" + ) from e self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins) # size inputs to symints @@ -418,6 +450,7 @@ def proxy_call_aot_backward( # NOTE: we should only close over constants CompiledFunction = ctx._forward_cls + bw_module = extract_bw_module(CompiledFunction) metadata = CompiledFunction.metadata maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata aot_id = CompiledFunction._aot_id @@ -468,9 +501,9 @@ def num_inputs(graph): break return num_args - # set up the proxy inputs to ctx._bw_module + # set up the proxy inputs to bw_module # the calling convention is: [*symints, *args (primals and tangents), backward_state] - num_args = num_inputs(ctx._bw_module.graph) + num_args = num_inputs(bw_module.graph) pall_args = [ pgrads[i] for i in range(num_args - int(pbackward_state is not None)) ] @@ -500,7 +533,7 @@ def make_unique(node_name): # make it both informative and unique return f"aot{deduped_aot_id}_{node_name}" - for node in ctx._bw_module.graph.nodes: + for node in bw_module.graph.nodes: if node.op == "placeholder": ph = pall_args[args_idx].node ph.name = make_unique(node.name) @@ -517,9 +550,7 @@ def make_unique(node_name): elif node.op == "get_attr": name = node.target qualname = self.fx_tracer.get_fresh_qualname(name) - setattr( - self.fx_tracer.root, qualname, getattr(ctx._bw_module, name) - ) + setattr(self.fx_tracer.root, qualname, getattr(bw_module, name)) result = self.fx_tracer.create_node("get_attr", qualname, (), {}) result.name = make_unique(node.name) value_remap[node] = result @@ -537,9 +568,7 @@ def make_unique(node_name): elif node.op == "call_module": name = node.target qualname = self.fx_tracer.get_fresh_qualname(name) - setattr( - self.fx_tracer.root, qualname, getattr(ctx._bw_module, name) - ) + setattr(self.fx_tracer.root, qualname, getattr(bw_module, name)) result = self.fx_tracer.graph.node_copy( node, lambda n: value_remap[n] ) @@ -716,13 +745,14 @@ def accumulate(self, old_var, new_var): self.bind_objects_to_proxies([result], [proxy_out]) return result - def accumulate_grad(self, variable, grad): + def accumulate_grad(self, variable, grad, has_post_hooks): self.fx_tracer.create_proxy( "call_function", - torch.ops.inductor.accumulate_grad_.default, + call_accumulate_grad, args=( self.to_proxy(variable), self.to_proxy(grad), + has_post_hooks, ), kwargs={}, ) @@ -1080,7 +1110,7 @@ def reorder_accumulate_grad_nodes(self): pass attempts to reorder the graph to mimic eager behavior. """ for node in self.fx_tracer.graph.find_nodes( - op="call_function", target=torch.ops.inductor.accumulate_grad_.default + op="call_function", target=call_accumulate_grad ): param_node, grad_node = node.args[0], node.args[1] getitem_node = None @@ -1222,10 +1252,7 @@ def reorder_post_acc_grad_hook_nodes(self): # find the corresponding acc_grad node acc_grad_node = None for n in list(param_node.users.keys()): - if ( - n.op == "call_function" - and n.target == torch.ops.inductor.accumulate_grad_.default - ): + if n.op == "call_function" and n.target == call_accumulate_grad: acc_grad_node = n break @@ -1274,10 +1301,7 @@ def reorder_post_hook_nodes(self): ) arg = max(input_nodes_and_users) # last input users - if ( - arg.op == "call_function" - and arg.target == torch.ops.inductor.accumulate_grad_.default - ): + if arg.op == "call_function" and arg.target == call_accumulate_grad: param_node = arg.args[0] post_acc_grad_hook_node = None for n in list(param_node.users.keys()): @@ -1373,9 +1397,13 @@ def set_node_origin( # global flag to check if we are processing graphs produced from a compiled autograd graph in_compiled_autograd_region = False +active_disable_ctx = False + +depth = 0 + @contextlib.contextmanager -def _enable(compiler_fn, dynamic: bool = True): +def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True): # The entrypoint to enable CA. # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather # than using this context manager directly. If you are torch.compiling the corresponding @@ -1396,44 +1424,54 @@ def _enable(compiler_fn, dynamic: bool = True): # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic. # This doesn't affect the dynamic configuration of the compilation wrapper. - if dynamic: - assert type(dynamic) is bool - - from torch._dynamo import eval_frame - - if eval_frame._stance.stance == "force_eager": - # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd - # to fall back to eager as well. - global compiled_autograd_enabled_force_eager - compiled_autograd_enabled_force_eager = True - try: - yield - finally: - compiled_autograd_enabled_force_eager = False + if not ignore_active_disable_ctx and active_disable_ctx: + yield else: - # we need to import this, because user might not have imported it if they directly use this context manager - # we need to lazily import it, because of circular dependencies - import torch._inductor.cudagraph_trees + if dynamic: + assert type(dynamic) is bool - ( - prior_compiler, - prior_dynamic, - ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler( - functools.partial(AutogradCompilerInstance, compiler_fn), dynamic - ) - if snapshot_verbose_logging_enabled(): - torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) - global compiled_autograd_enabled - compiled_autograd_enabled = True - try: - with torch.autograd.set_multithreading_enabled(False): + from torch._dynamo import eval_frame + + if eval_frame._stance.stance == "force_eager": + # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd + # to fall back to eager as well. + global compiled_autograd_enabled_force_eager + compiled_autograd_enabled_force_eager = True + try: yield - finally: - if not prior_compiler: - compiled_autograd_enabled = False - torch._C._dynamo.compiled_autograd.set_autograd_compiler( - prior_compiler, prior_dynamic + finally: + compiled_autograd_enabled_force_eager = False + else: + # we need to import this, because user might not have imported it if they directly use this context manager + # we need to lazily import it, because of circular dependencies + import torch._inductor.cudagraph_trees + + ( + prior_compiler, + prior_dynamic, + ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler( + functools.partial(AutogradCompilerInstance, compiler_fn), dynamic ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) # type:ignore[arg-type] + global compiled_autograd_enabled + compiled_autograd_enabled = True + global depth + prior_depth = depth + depth += 1 + try: + with torch.autograd.set_multithreading_enabled(False): + yield + finally: + if not prior_compiler: + compiled_autograd_enabled = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler( + prior_compiler, prior_dynamic + ) + depth -= 1 + assert depth == prior_depth, ( + "Nested Compiled Autograd Contexts must return before their parent context" + ) @contextlib.contextmanager @@ -1444,11 +1482,15 @@ def _disable(): ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) global compiled_autograd_enabled compiled_autograd_enabled = False + global active_disable_ctx + if not active_disable_ctx: + active_disable_ctx = True try: yield finally: if prior_compiler: compiled_autograd_enabled = True + active_disable_ctx = False torch._C._dynamo.compiled_autograd.set_autograd_compiler( prior_compiler, prior_dynamic ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index a6818b4e9a91b9..b3708ff8493f8f 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -143,7 +143,7 @@ # guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required # but kept around for debugging and discussing unspecializing nn module # variables. -# TODO(janimesh, voz): Remove both of these flags (or atleast guard_nn_modules) +# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules) # once we have reached stability for the guard_nn_modules_using_dict_tags. guard_nn_modules_using_dict_tags = True @@ -397,7 +397,7 @@ # Use C++ guard manager (deprecated: always true) enable_cpp_guard_manager = True -# Use C++ guard manger for symbolic shapes +# Use C++ guard manager for symbolic shapes enable_cpp_symbolic_shape_guards = not is_fbcode() # Enable tracing through contextlib.contextmanager @@ -418,7 +418,7 @@ # Install "free" tensor variables (globals, non-locals, nn module attributes) # as graph attributes. This is useful for export, as it -# produces a consitent number of inputs to the graph. +# produces a consistent number of inputs to the graph. install_free_tensors = False # Use C++ FrameLocalsMapping (raw array view of Python frame fastlocals) (deprecated: always True) @@ -493,14 +493,14 @@ def default_debug_dir_root(): # This flag is ignored and maintained for backwards compatibility. capture_autograd_function = True -# This flag is ignored and maintained for backwards compatbility. +# This flag is ignored and maintained for backwards compatibility. capture_func_transforms = True # If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode). log_compilation_metrics = True # A set of logging functions which will be reordered to the end of graph breaks, -# allowing dynamo to construct larget graph. Note that there are some +# allowing dynamo to construct large graph. Note that there are some # limitations to this, such as how it does not correctly print objects that were # mutated after the print statement. reorderable_logging_functions: set[Callable[[Any], None]] = set() @@ -615,6 +615,10 @@ def default_debug_dir_root(): # wrapper. This ensures that nn.module hooks are also compiled in the same frame. wrap_top_frame = False +# Flag to record runtime overhead in profile traces. Used for pre-graph bytecode +# and AOTAutograd runtime wrapper. +record_runtime_overhead = True + # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 0f8a212f127617..b5a33a0740b1ba 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -17,6 +17,10 @@ The conversion process preserves program semantics while enabling optimizations through torch.compile() and related systems. + +NOTE: _torchdynamo_orig_backend is used for convert frame wrappers to identify the inner wrapped function. +By going down the _torchdynamo_orig_backend chain, one can recover the original unwrapped backend, +which is checked for during the Dynamo cache lookup. """ from __future__ import annotations @@ -39,6 +43,7 @@ import traceback import typing import weakref +from dataclasses import dataclass from pathlib import Path from types import CellType, CodeType, FunctionType, ModuleType from typing import Any, Callable, Optional, TypeVar, Union @@ -48,6 +53,7 @@ import torch import torch._logging from torch._C._dynamo.guards import GlobalStateGuard +from torch._dynamo.callback import CallbackTrigger from torch._dynamo.distributed import get_compile_pg from torch._dynamo.symbolic_convert import TensorifyState from torch._guards import compile_context, CompileContext, CompileId, tracing @@ -101,6 +107,7 @@ InternalTorchDynamoError, PackageError, RecompileLimitExceeded, + ResumePrologueTracingError, ShortenTraceback, SkipCodeRecursiveException, TorchRuntimeError, @@ -114,7 +121,7 @@ GuardedCode, ) from .hooks import Hooks -from .pgo import put_code_state +from .pgo import log_frame_dynamic_whitelist, put_code_state from .replay_record import ExecutionRecord from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX from .symbolic_convert import ( @@ -127,6 +134,7 @@ from .trace_rules import is_numpy from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code from .utils import ( + _get_error_on_graph_break, chromium_event_timed, CleanupManager, CompileTimeInstructionCounter, @@ -160,6 +168,7 @@ if typing.TYPE_CHECKING: from .backends.registry import CompilerFn + from .package import CompilePackage from .repro.after_dynamo import WrapBackendDebug from .types import BytecodeHook, CacheEntry, DynamoFrameType from .variables.builder import FrameStateSizeEntry @@ -252,7 +261,9 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: cuda_rng_state = None if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() - allow_tf32 = torch._C._get_cublas_allow_tf32() + cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter( + "cuda", "matmul" + ) prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() @@ -284,13 +295,15 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: torch._C._unset_default_mobile_cpu_allocator() if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - torch._C._set_cublas_allow_tf32(allow_tf32) + torch._C._set_fp32_precision_setter( + "cuda", "matmul", cuda_matmul_fp32_prec + ) torch.fx.graph_module._forward_from_src = prior_fwd_from_src assert guards.check(), ( f"Global {guards.reason()}state changed while dynamo tracing, please report a bug" ) - _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + _fn._torchdynamo_orig_backend = fn # type: ignore[attr-defined] return _fn @@ -470,6 +483,17 @@ def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: return profile_wrapper +@dataclass +class ConvertFrameBox: + error_on_graph_break: Optional[bool] = None + + +def _is_error_on_graph_break(tx: Optional[InstructionTranslator]) -> bool: + if tx is None: + return _get_error_on_graph_break() + return tx.error_on_graph_break + + class ConvertFrameAssert: def __init__( self, @@ -477,13 +501,16 @@ def __init__( one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> None: # assert export_constraints is None reset_graph_break_dup_checker() - self._torchdynamo_orig_callable = compiler_fn + self._torchdynamo_orig_backend = compiler_fn self._one_graph = one_graph self._export = export self._export_constraints = export_constraints + self._package = package + self._box = ConvertFrameBox() @property def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: @@ -628,7 +655,7 @@ def __call__( frame.f_locals, frame.f_builtins, frame.closure, - self._torchdynamo_orig_callable, + self._torchdynamo_orig_backend, self._one_graph, self._export, self._export_constraints, @@ -639,6 +666,8 @@ def __call__( frame_state=frame_state, compile_id=compile_id, skip=skip + 1, + package=self._package, + convert_frame_box=self._box, ) @@ -647,9 +676,12 @@ def convert_frame_assert( one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> ConvertFrameAssert: - """Fully convert a frame into an FX graph""" - return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) + """Fully convert a frame into an FX graph, raising an exception if we fail.""" + return ConvertFrameAssert( + compiler_fn, one_graph, export, export_constraints, package + ) from collections import OrderedDict @@ -692,6 +724,10 @@ def _compile( *, compile_id: CompileId, skip: int = 0, + package: Optional[CompilePackage] = None, + # Can be used to record things for the caller, both + # in the case of normal and exception code paths + convert_frame_box: Optional[ConvertFrameBox] = None, ) -> ConvertFrameReturn: from torch.fx.experimental.validator import ( bisect, @@ -716,7 +752,7 @@ def transform( ) -> None: nonlocal output nonlocal tracer - speculation_log.restart() + speculation_log.restart() # type: ignore[has-type] exn_vt_stack = ExceptionStack() tracer = InstructionTranslator( instructions, @@ -732,9 +768,10 @@ def transform( export, export_constraints, frame_state=frame_state, - speculation_log=speculation_log, + speculation_log=speculation_log, # type: ignore[has-type] exn_vt_stack=exn_vt_stack, - distributed_state=distributed_state, + distributed_state=distributed_state, # type: ignore[has-type] + package=package, ) try: @@ -742,7 +779,7 @@ def transform( with tracing(tracer.output.tracing_context), tracer.set_current_tx(): tracer.run() except exc.UnspecializeRestartAnalysis: - speculation_log.clear() + speculation_log.clear() # type: ignore[has-type] raise except ( exc.SpeculationRestartAnalysis, @@ -774,7 +811,11 @@ def compile_inner( transform: Callable[[list[Instruction], dict[str, Any]], Any], ) -> ConvertFrameReturn: with contextlib.ExitStack() as stack: - stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) + stack.enter_context( + torch._dynamo.callback_handler.install_callbacks( + CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id()) + ) + ) stack.enter_context(CompileTimeInstructionCounter.record()) return _compile_inner(code, one_graph, hooks, transform) @@ -848,11 +889,13 @@ def log_bytecode( code.co_filename, code.co_firstlineno, ) - if one_graph: - log.debug("No graph captured with one_graph=True") + if one_graph or _is_error_on_graph_break(tracer): + log.debug( + "No graph captured with one_graph=True or error_on_graph_break=True" + ) return ConvertFrameReturn() - assert distributed_state is None or distributed_state.all_states is not None, ( + assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type] "compiler collective wasn't run before compilation completed" ) @@ -865,10 +908,11 @@ def log_bytecode( out_code, ) - for hook in _bytecode_hooks.values(): - hook_output = hook(code, out_code) - if hook_output is not None: - out_code = hook_output + for idx, hook in enumerate(_bytecode_hooks.values()): + with dynamo_timed(f"bytecode_hooks_{idx}", log_pt2_compile_event=True): + hook_output = hook(code, out_code) + if hook_output is not None: + out_code = hook_output orig_code_map[out_code] = code output_codes.add(out_code) @@ -931,8 +975,14 @@ def count_args(code: CodeType) -> int: cache_entry, hooks.guard_fail_fn if hooks else None, hooks.guard_filter_fn if hooks else None, + guards_serialization_mode="save" if package else None, ) + if package is not None: + assert check_fn.guards_state is not None + package.add_guarded_code(check_fn.guards_state, out_code) + package.add_inlined_source(output.tracing_context.traced_code) + compile_id_str = str(compile_id) if compile_id is not None else "Unknown" annotation_str = "Torch-Compiled Region: " + compile_id_str guarded_code = GuardedCode( @@ -942,21 +992,20 @@ def count_args(code: CodeType) -> int: annotation_str, ) - if not output.is_empty_graph(): - if hooks.guard_export_fn is not None: - # We should not run the guard_export_fn when Dynamo does not - # generate any graph. This can happen in export when TorchDynamo - # generated bytecode has some reconstruction logic for mutated - # variables which can trigger TorchDynamo on the children frames but - # they are benign and do not generate any new graphs. - hooks.guard_export_fn(output.guards) - if hooks.frame_traced_fn is not None: - output.tracing_context.traced_code.append(output.f_code) - hooks.frame_traced_fn(output.tracing_context.traced_code) + if not output.is_empty_graph() and hooks.guard_export_fn is not None: + # We should not run the guard_export_fn when Dynamo does not + # generate any graph. This can happen in export when TorchDynamo + # generated bytecode has some reconstruction logic for mutated + # variables which can trigger TorchDynamo on the children frames but + # they are benign and do not generate any new graphs. + hooks.guard_export_fn(output.guards) return wrap_guarded_code(guarded_code) metrics_context = get_metrics_context() + code_context = ( + package.code_context(code) if package is not None else contextlib.nullcontext() + ) with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), @@ -970,6 +1019,7 @@ def count_args(code: CodeType) -> int: phase_name="entire_frame_compile", dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ), + code_context, ): restart_reasons: set[str] = set() # This is shared across restarts @@ -1010,9 +1060,10 @@ def format_func_info(code: CodeType) -> str: raise FailOnRecompileLimitHit( f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure" ) - elif one_graph: + elif one_graph or _is_error_on_graph_break(tracer): raise FailOnRecompileLimitHit( - f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade " + f"{limit_type} reached with one_graph=True or error_on_graph_break=True. " + "Excessive recompilations can degrade " "performance due to the compilation overhead of each recompilation. To monitor " "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider " "increasing torch._dynamo.config.cache_size_limit to an appropriate value." @@ -1101,6 +1152,7 @@ def format_func_info(code: CodeType) -> str: # to upload for graph break though, because this can prevent # extra graph break compilations.) put_code_state() + log_frame_dynamic_whitelist(code) return guarded_code except Exception as e: @@ -1122,7 +1174,15 @@ def format_func_info(code: CodeType) -> str: fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id ) - if isinstance( + if tracer and tracer.is_tracing_resume_prologue: + # Do not allow any errors to be suppressed if tracer is currently tracing + # through resume function. + raise ResumePrologueTracingError( + "Error while tracing through a Dynamo-generated resume function prologue. " + "Errors are not allowed when tracing resume function prologues.\n" + f"{type(e).__qualname__}: {str(e)}" + ).with_traceback(e.__traceback__) from None + elif isinstance( e, ( Unsupported, @@ -1157,6 +1217,7 @@ def format_func_info(code: CodeType) -> str: if tracer: tracer.output.local_scope = {} + tracer.f_locals = {} from .utils import curr_frame @@ -1219,20 +1280,37 @@ def format_func_info(code: CodeType) -> str: metrics_context.update_outer(metrics) # === END WARNING WARNING WARNING === + # If tracer is available, then tracer.error_on_graph_break reflects value of + # global symbolic_convert.error_on_graph_break at the time of the graph break - + # symbolic_convert.error_on_graph_break may have been (correctly) changed during cleanup. + # If tracer is unavailable, then fallback to symbolic_convert.error_on_graph_break. + if convert_frame_box: + convert_frame_box.error_on_graph_break = ( + tracer.error_on_graph_break + if tracer + else _get_error_on_graph_break() + ) + class ConvertFrame: def __init__( self, compiler_fn: CompilerFn, hooks: Hooks, + package: Optional[CompilePackage] = None, ) -> None: - self._torchdynamo_orig_callable = compiler_fn - self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._torchdynamo_orig_backend = compiler_fn + self._inner_convert = convert_frame_assert( + compiler_fn, one_graph=False, package=package + ) self._hooks = hooks @property def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]: - return lambda backend: convert_frame(backend, self._hooks) + return lambda backend: convert_frame( + backend, + self._hooks, + ) def __call__( self, @@ -1251,6 +1329,22 @@ def __call__( counters["frames"]["ok"] += 1 return result except Exception as e: + # Do not allow errors to be suppressed if we're tracing a resume function prologue + if isinstance(e, ResumePrologueTracingError): + raise + + error_on_graph_break = ( + self._inner_convert._box.error_on_graph_break is not None + ) + assert error_on_graph_break is not None + if self._inner_convert._box.error_on_graph_break: + # NOTE we _might_ have to wrap the current in a custom exception + # in order to correctly bubble up to the top-level compile wrapper in + # eval_frame.py. But re-raising seems to work for now because exceptions from tracing + # a nested call that results in a top-level frame compile will be handled by the caller + # as an observed exception - we don't expect that exception to be suppressed. + raise + # These two exception types are "soft" failure, in the sense that # we know this is due to something we didn't implement all the # way, scare the user less about it. That being said, if you @@ -1330,9 +1424,13 @@ def __call__( return ConvertFrameReturn() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame: +def convert_frame( + compiler_fn: CompilerFn, + hooks: Hooks, + package: Optional[CompilePackage] = None, +) -> ConvertFrame: """Try to convert a frame into an FX graph, if error leave frame unmodified""" - return ConvertFrame(compiler_fn, hooks) + return ConvertFrame(compiler_fn, hooks, package=package) # TODO mlazos: add support for same args, or record them @@ -1345,26 +1443,27 @@ def replay(filename: str) -> None: record = ExecutionRecord.load(in_file) record.globals = dict(itertools.chain(record.globals.items(), globals().items())) - try: - _compile( - record.code, - record.globals, - record.locals, - record.builtins, - record.closure, - compiler_fn=eager, - one_graph=False, - export=False, - export_constraints=None, - hooks=Hooks(), - cache_size=CacheSizeRelevantForFrame(0, 0), - cache_entry=None, - frame=None, - frame_state={}, - compile_id=CompileId(frame_id=42, frame_compile_id=999), - ) - finally: - config.replay_record_enabled = original_replay_val + with decorators.set_fullgraph(fullgraph=False): + try: + _compile( + record.code, + record.globals, + record.locals, + record.builtins, + record.closure, + compiler_fn=eager, + one_graph=False, + export=False, + export_constraints=None, + hooks=Hooks(), + cache_size=CacheSizeRelevantForFrame(0, 0), + cache_entry=None, + frame=None, + frame_state={}, + compile_id=CompileId(frame_id=42, frame_compile_id=999), + ) + finally: + config.replay_record_enabled = original_replay_val def first_real_inst_idx(code: CodeType) -> int: @@ -1391,7 +1490,7 @@ def __call__( class CatchErrorsWrapper: def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: functools.wraps(callback)(self) - self._torchdynamo_orig_callable = callback + self._torchdynamo_orig_backend = callback self.hooks = hooks def __call__( @@ -1416,7 +1515,7 @@ def __call__( or config.disable or ( is_in_torch_dispatch_mode(include_infra_modes=False) - and not getattr(self._torchdynamo_orig_callable, "_export", False) + and not getattr(self._torchdynamo_orig_backend, "_export", False) ) ): if log.isEnabledFor(logging.DEBUG): @@ -1448,15 +1547,15 @@ def __call__( ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_module.bucket_bytes_cap, - backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, # type: ignore[attr-defined] + backend_compile_fn=self._torchdynamo_orig_backend._torchdynamo_orig_backend, # type: ignore[attr-defined] ) assert hasattr( - self._torchdynamo_orig_callable, "_clone_with_backend" + self._torchdynamo_orig_backend, "_clone_with_backend" ), ( "DDPOptimizer only supports callback fns that know how to clone themselves." ) hijacked_callback = ( - self._torchdynamo_orig_callable._clone_with_backend( + self._torchdynamo_orig_backend._clone_with_backend( ddp_optimizer.compile_fn, ) ) @@ -1466,7 +1565,7 @@ def __call__( with compile_lock, _disable_current_modes(): # skip=1: skip this frame - return self._torchdynamo_orig_callable( + return self._torchdynamo_orig_backend( frame, cache_entry, self.hooks, frame_state, skip=1 ) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 2e793bb4c7dace..a23b58cedf2267 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -247,7 +247,7 @@ def __init__(self) -> None: return model_str -@functools.lru_cache(None) # subprocess is expensive +@functools.cache # subprocess is expensive def _cuda_system_info_comment(): if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index ff1156b04b616f..ab66edd6b909a9 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -10,7 +10,7 @@ import sys import weakref from dataclasses import dataclass -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -31,11 +31,11 @@ ) from .exc import IncorrectUsage from .external_utils import ( - _dynamo_config_patch_proxy_dunder_call, get_nonrecursive_disable_wrapper, is_compiling, + wrap_dunder_call_ctx_manager, ) -from .utils import is_function +from .utils import _get_error_on_graph_break, _set_error_on_graph_break, is_function if TYPE_CHECKING: @@ -44,6 +44,7 @@ from torch._C._dynamo.eval_frame import ( # noqa: F401 reset_code, set_eval_frame, + set_guard_complete_hook, set_guard_error_hook, unsupported, ) @@ -69,7 +70,7 @@ def run(fn=None): return RunOnlyContext() -def disable(fn=None, recursive=True, *, reason=None): +def disable(fn=None, recursive=True, *, reason=None, wrapping=True): """ Decorator to disable TorchDynamo @@ -85,8 +86,8 @@ def disable(fn=None, recursive=True, *, reason=None): if fn is not None: fn = innermost_fn(fn) assert callable(fn) - return DisableContext(msg=reason)(fn) - return DisableContext(msg=reason) + return DisableContext(msg=reason, wrapping=wrapping)(fn) + return DisableContext(msg=reason, wrapping=wrapping) else: def wrap(fn): @@ -553,7 +554,11 @@ def mark_unbacked(t, index, strict=False, specialize_on=None): if not hasattr(t, "_dynamo_unbacked_indices"): t._dynamo_unbacked_indices = set() - t._specialize_on[index] = specialize_on if specialize_on is not None else [] + # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: + # TypeError: 'Attribute' object does not support item assignment + if isinstance(t._specialize_on, dict): + t._specialize_on[index] = specialize_on if specialize_on is not None else [] + t._dynamo_unbacked_indices.add(index) return @@ -619,7 +624,12 @@ def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): # TODO(voz): Should we bounds check? t._dynamo_dynamic_indices.add(index) t._dynamo_dynamic_range.add(_DimRange(index, min, max)) - t._specialize_on[index] = specialize_on if specialize_on is not None else [] + + # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: + # TypeError: 'Attribute' object does not support item assignment + if isinstance(t._specialize_on, dict): + t._specialize_on[index] = specialize_on if specialize_on is not None else [] + return assert isinstance(index, (list, tuple)) @@ -698,7 +708,7 @@ def mark_static(t, index=None): if not isinstance(t, torch.Tensor): raise TypeError( - f"mark_static expects a tensor/nn.Module class but recieved {type(t)}" + f"mark_static expects a tensor/nn.Module class but received {type(t)}" ) if isinstance(index, int): @@ -724,7 +734,7 @@ def mark_static_address(t, guard=True): Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called. """ if not isinstance(t, torch.Tensor): - raise TypeError(f"mark_static_address expects a tensor but recieved {type(t)}") + raise TypeError(f"mark_static_address expects a tensor but received {type(t)}") if guard: t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined] @@ -781,7 +791,7 @@ def changes(self): # Decorator implementation that simply sets up `self` as a context manager. # Placed in external_utils so that we can trace through it. - __call__ = _dynamo_config_patch_proxy_dunder_call + __call__ = wrap_dunder_call_ctx_manager def __enter__(self): return self.config_patch.__enter__() @@ -845,7 +855,7 @@ def patch_dynamo_config( See _allowed_config_patches for the list of allowed config patches. - Arguments are the same as with torch._dynamo.confing.patch. + Arguments are the same as with torch._dynamo.config.patch. Can be used as a decorator or a context manager. @@ -862,6 +872,14 @@ def patch_dynamo_config( return DynamoConfigPatchProxy(config_patch) +@overload +def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ... + + +@overload +def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ... + + def dont_skip_tracing(fn=None): """ Context manager/decorator to trace into functions intentionally marked by developers to be skipped @@ -873,3 +891,27 @@ def dont_skip_tracing(fn=None): if fn: return ctx(fn) return ctx + + +class SetFullgraphDecoratorContextManager: + def __init__(self, fullgraph): + self.fullgraph = fullgraph + + __call__ = wrap_dunder_call_ctx_manager + + def __enter__(self): + self.prev_fullgraph = _get_error_on_graph_break() + _set_error_on_graph_break(self.fullgraph) + + def __exit__(self, exc_type, exc_val, exc_tb): + _set_error_on_graph_break(self.prev_fullgraph) + + +def set_fullgraph(fullgraph: bool) -> SetFullgraphDecoratorContextManager: + """ + Context manager/decorator to toggle fullgraph setting. + + More precisely, when encountering a graph break, we will decide to resume (fullgraph=False) + or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break. + """ + return SetFullgraphDecoratorContextManager(fullgraph) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 263da3417c4318..2ec7c5f7259f1e 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -199,12 +199,12 @@ def __exit__(self, type: Any, value: Any, traceback: Any): class CudaInterface(DeviceInterface): - device = torch.cuda.device + device = torch.cuda.device # type: ignore[assignment] # register Event and Stream class into the backend interface # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream - Event = torch.cuda.Event - Stream = torch.cuda.Stream + Event = torch.cuda.Event # type: ignore[assignment] + Stream = torch.cuda.Stream # type: ignore[assignment] class Worker: @staticmethod @@ -297,9 +297,9 @@ def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: class XpuInterface(DeviceInterface): - device = torch.xpu.device - Event = torch.xpu.Event - Stream = torch.xpu.Stream + device = torch.xpu.device # type: ignore[assignment] + Event = torch.xpu.Event # type: ignore[assignment] + Stream = torch.xpu.Stream # type: ignore[assignment] class Worker: @staticmethod diff --git a/torch/_dynamo/distributed.py b/torch/_dynamo/distributed.py index aa60b325844b69..490b6330fafa45 100644 --- a/torch/_dynamo/distributed.py +++ b/torch/_dynamo/distributed.py @@ -22,6 +22,7 @@ _COMPILE_PG: Optional[dist.ProcessGroup] = None +_GUARD_PG: Optional[dist.ProcessGroup] = None def get_compile_pg() -> Optional[dist.ProcessGroup]: @@ -39,3 +40,15 @@ def get_compile_pg() -> Optional[dist.ProcessGroup]: return _COMPILE_PG return None + + +# NB: Unlike get_compile_pg, this is only called when guard collectives were +# explicitly requested +def get_guard_pg() -> Optional[dist.ProcessGroup]: + if dist.is_available() and dist.is_initialized(): + global _GUARD_PG + if _GUARD_PG is None: + _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg") + return _GUARD_PG + + return None diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 73f36ba571991e..de2b44953817f1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -58,6 +58,7 @@ reset_code, set_code_exec_strategy, set_eval_frame, + set_guard_complete_hook, set_guard_error_hook, set_skip_guard_eval_unsafe, unsupported, @@ -90,7 +91,7 @@ ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo -from . import config, convert_frame, external_utils, trace_rules, utils +from . import config, convert_frame, distributed, external_utils, trace_rules, utils from .backends.registry import CompilerFn, lookup_backend from .code_context import code_context from .exc import ( @@ -102,7 +103,12 @@ ) from .hooks import Hooks from .mutation_guard import install_generation_tagging_init -from .utils import common_constant_types, compile_times +from .utils import ( + _get_error_on_graph_break, + _set_error_on_graph_break, + common_constant_types, + compile_times, +) if TYPE_CHECKING: @@ -206,13 +212,27 @@ def _callback_from_stance(callback): if callback in (False, None): return callback - def fail_callback(*args, **kwargs): - raise RuntimeError( - "Detected recompile when torch.compile stance is 'fail_on_recompile'" + def fail_callback(frame, *args, **kwargs): + if trace_rules.check(frame.f_code): + return ConvertFrameReturn() + + from torch._C._dynamo.eval_frame import _debug_get_precompile_entries + + message = ( + "Detected recompile when torch.compile stance is 'fail_on_recompile'. " + + f"filename: '{frame.f_code.co_filename}', " + + f"function name: '{frame.f_code.co_name}', " + + f"line number: {frame.f_lineno}" ) + precompile_entries = _debug_get_precompile_entries(frame.f_code) + if len(precompile_entries) > 0: + message += "\nFailed on the following precompiled guards: " + for entry in precompile_entries: + message += f"\n{entry.guard_manager}{entry.guard_manager.check_verbose(frame.f_locals)}" # type: ignore[attr-defined] + raise RuntimeError(message) - # to prevent cache miss due to different callback - fail_callback._torchdynamo_orig_callable = callback # type: ignore[attr-defined] + # to prevent cache miss due to different backend + fail_callback._torchdynamo_orig_backend = callback # type: ignore[attr-defined] return fail_callback else: @@ -258,9 +278,12 @@ def callback_fn(*args, **kwargs): dynamism = track_dynamism_across_examples(example_inputs) code_context.get_context(frame.f_code)["dynamism"] = dynamism - compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable + compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend return _create_wrapped_callback(compiler_fn)(*args, **kwargs) + # to prevent cache miss due to different backend + callback_fn._torchdynamo_orig_backend = callback # type: ignore[attr-defined] + return callback_fn @@ -464,15 +487,15 @@ def always_false(): return False -def innermost_fn(fn): +def innermost_fn(fn, unaltered_fn_attr="_torchdynamo_orig_callable"): """ In case of nesting of _TorchDynamoContext calls, find the innermost function. TorchDynamo caches on fn.__code__ object, so its necessary to find the innermost function to pass on the optimize, run, disable etc. """ unaltered_fn = fn - while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): - unaltered_fn = unaltered_fn._torchdynamo_orig_callable + while hasattr(unaltered_fn, unaltered_fn_attr): + unaltered_fn = getattr(unaltered_fn, unaltered_fn_attr) assert callable(unaltered_fn), ( f"A callable function is expected, but {type(unaltered_fn)} is provided." ) @@ -517,6 +540,38 @@ def _log_traced_frames(): log.info(msg) +def guard_collectives_hook(guard_eval_result): + import torch.distributed as dist + from torch._dynamo.utils import dynamo_timed + + # guard_eval_result == True ==> cache hit + if pg := distributed.get_guard_pg(): + with dynamo_timed( + "guard_collective", log_pt2_compile_event=True, log_waitcounter=True + ): + log.info("guard_collective %s", guard_eval_result) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "guard_collective", + "encoding": "string", + }, + payload_fn=lambda: str(guard_eval_result), + ) + # TODO: a bit awkward to time, this isn't inside of the dynamo compile region + all_results = [None] * pg.size() + dist.all_gather_object(all_results, guard_eval_result, group=pg) + # True = everyone hit, OK to run + # False = someone missed, force recompile everywhere + res = all(all_results) + log.info("guard_collective %s -> %s", guard_eval_result, res) + return res + return guard_eval_result + + +_not_set = object() + + class _TorchDynamoContext: def __init__( self, @@ -526,9 +581,11 @@ def __init__( patch_fn=nothing, first_ctx=False, *, + error_on_graph_break=False, export=False, dynamic=None, compiler_config=None, + package=None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -536,15 +593,17 @@ def __init__( self._backend_ctx_ctor = backend_ctx_ctor self.prior: Union[Unset, DynamoCallback] = unset self.first_ctx = first_ctx + self.error_on_graph_break = error_on_graph_break self.export = export self._dynamic = dynamic self.compiler_config = compiler_config self.cleanup_fns: list[Callable[[], Any]] = [] self.enter_exit_hooks = [] + self._package = package patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset - backend = innermost_fn(callback) + backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") cached_backends.setdefault(id(backend), backend) if dynamic is not None: @@ -683,16 +742,21 @@ def compile_wrapper(*args, **kwargs): prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( _is_skip_guard_eval_unsafe_stance() ) + prior_error_on_graph_break = None + if self.error_on_graph_break is not None: + prior_error_on_graph_break = _get_error_on_graph_break() + _set_error_on_graph_break(self.error_on_graph_break) # Ensure that if an assertion occurs after graph pushes # something onto the DynamicLayerStack then we pop it off (the # constructed graph code isn't guarded with try/finally). # - # This used to be a context but putting a `with` here is a noticible + # This used to be a context but putting a `with` here is a noticeable # perf regression (#126293) saved_dynamic_layer_stack_depth = ( torch._C._functorch.get_dynamic_layer_stack_depth() ) + _maybe_set_eval_frame(_callback_from_stance(callback)) try: @@ -713,6 +777,8 @@ def compile_wrapper(*args, **kwargs): finally: # Restore the dynamic layer stack depth if necessary. set_eval_frame(None) + if prior_error_on_graph_break is not None: + _set_error_on_graph_break(prior_error_on_graph_break) torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( saved_dynamic_layer_stack_depth ) @@ -724,7 +790,9 @@ def compile_wrapper(*args, **kwargs): _maybe_set_eval_frame(prior) # hooks to properly handle inlining - compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined] + compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined] + external_utils.wrap_inline_with_set_fullgraph(fn, self.error_on_graph_break) + ) # Save the function pointer to find the original callable while nesting # of decorators. @@ -784,12 +852,14 @@ def __init__( backend_ctx_ctor, first_ctx=False, *, + error_on_graph_break=False, export=False, dynamic=None, compiler_config=None, rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, + package=None, ) -> None: def on_enter(): install_generation_tagging_init() @@ -800,9 +870,11 @@ def on_enter(): backend_ctx_ctor=backend_ctx_ctor, patch_fn=TorchPatcher.patch, first_ctx=first_ctx, + error_on_graph_break=error_on_graph_break, export=export, dynamic=dynamic, compiler_config=compiler_config, + package=package, ) if config.compiled_autograd: @@ -814,7 +886,7 @@ def call_compiled_autograd(): assert rebuild_ctx is not None compiler_fn = rebuild_ctx() ctx = torch._dynamo.compiled_autograd._enable( - compiler_fn, dynamic=_dynamic + compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False ) ctx.__enter__() return functools.partial(ctx.__exit__, None, None, None) @@ -846,9 +918,10 @@ def __reduce__(self): class DisableContext(_TorchDynamoContext): - def __init__(self, msg: Optional[str] = None) -> None: + def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None: super().__init__(callback=None) self.msg = msg + self.wrapping = wrapping def __call__(self, fn): # Earlier this code was in the base class _TorchDynamoContext. But we @@ -863,14 +936,14 @@ def __call__(self, fn): new_mod._torchdynamo_orig_callable = mod.forward return new_mod - if inspect.isclass(fn): + if isinstance(fn, type): # User has wrapped the class with compile/disable decorator. Apply # disable to init/call method. cls_obj = fn # Disable on init is useful for reconstruction of bytecodes where we # want to prevent Dynamo from tracing into the init function. Check # test_reconstruction in test_model_output.py. - cls_obj.__init__ = self(cls_obj.__init__) + cls_obj.__init__ = self(cls_obj.__init__) # type: ignore[misc] cls_obj.__call__ = self(cls_obj.__call__) if issubclass(cls_obj, torch.nn.Module): # NN module variable tracker directly inlines the _call_impl. Disable it. @@ -881,22 +954,25 @@ def __call__(self, fn): f"A callable function is expected, but {type(fn)} is provided." ) - @functools.wraps(fn) def _fn(*args, **kwargs): prior = set_eval_frame(None) try: - prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( - _is_skip_guard_eval_unsafe_stance() - ) _maybe_set_eval_frame(_callback_from_stance(self.callback)) try: return fn(*args, **kwargs) finally: set_eval_frame(None) - set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe) finally: _maybe_set_eval_frame(prior) + # Under some circumstances (e.g. precompile) we can end up calling @disable + # decorator in generated bytecode and trigger recompile. This is due to the + # fact that the old callback from torch.compile() is still active and under + # this circumstance we will trigger a failure with set_stance("fail_on_recompile"). + # Therefore we want to skip calling into any frame in this case. + if self.wrapping: + _fn = functools.wraps(fn)(_fn) + _fn._torchdynamo_disable = True # type: ignore[attr-defined] _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined] @@ -914,19 +990,23 @@ def _optimize_catch_errors( compile_fn, hooks: Hooks, backend_ctx_ctor=null_context, + error_on_graph_break=False, export=False, dynamic=None, compiler_config=None, rebuild_ctx=None, + package=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, first_ctx=True, + error_on_graph_break=error_on_graph_break, export=export, dynamic=dynamic, compiler_config=compiler_config, rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1015,9 +1095,9 @@ def _optimize( guard_export_fn=None, guard_fail_fn=None, guard_filter_fn=None, - frame_traced_fn=None, disable=False, dynamic=None, + package=None, ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1055,7 +1135,6 @@ def toy_example(a, b): ... guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn, guard_filter_fn=guard_filter_fn, - frame_traced_fn=frame_traced_fn, ) torch._C._log_api_usage_once("torch._dynamo.optimize") if ( @@ -1065,26 +1144,19 @@ def toy_example(a, b): ... ): return _NullDecorator() - if nopython: - return optimize_assert( - backend, - dynamic=dynamic, - hooks=hooks, - rebuild_ctx=rebuild_ctx, - ) - backend = get_compiler_fn(backend) # Find if backend has any extra context manager backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) # The backend function is stashed in the callable returned by - # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can + # _optimize_catch_errors in the field _torchdynamo_orig_backend. This can # be used by eval_frame.c to insert a guard on the backend. return _optimize_catch_errors( - convert_frame.convert_frame(backend, hooks=hooks), + convert_frame.convert_frame(backend, hooks, package=package), hooks, backend_ctx_ctor, + error_on_graph_break=nopython, dynamic=dynamic, compiler_config=( backend.get_compiler_config() @@ -1092,6 +1164,7 @@ def toy_example(a, b): ... else None ), rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1196,7 +1269,7 @@ def __init__( if i in matched_input_elements_to_fake: arg.node.meta["val"] = matched_input_elements_to_fake[i] else: - # Fill node.mata["val"] with faketensor from the input, + # Fill node.meta["val"] with faketensor from the input, # if it's not found in matched_input_elements_positions if fake_mode is not None and isinstance(flat_args[i], torch.Tensor): # TODO(zhxchen17) Also preserve all the user constraints here. @@ -1839,7 +1912,7 @@ def fakify_with_ambient(path, t): "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False." ) # If the module does not contain any tensor computation, we would create a graph with inputs and outputs. - # To be consitant with the graph traced by dynano, `graph` will have only tensor inputs as placeholders + # To be consistent with the graph traced by dynano, `graph` will have only tensor inputs as placeholders # and tensor outputs as output nodes. non-tensor inputs and outputs will be added when rewriting signature. # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding # to `graph`. @@ -1982,9 +2055,14 @@ def _optimize_assert( export=False, export_constraints=None, dynamic=None, + package=None, ): """ - The same as `torch._dynamo.optimize(backend, nopython=True)` + The same as `torch._dynamo.optimize(backend, nopython=True)`, + but ignores symbolic_convert.error_on_graph_break setting. + + Used for export, since we must always error on graph breaks and ignore + symbolic_convert.error_on_graph_break. Can also be used for testing. """ backend = get_compiler_fn(backend) @@ -1993,19 +2071,23 @@ def _optimize_assert( return _optimize_catch_errors( convert_frame.convert_frame_assert( - backend, export=export, export_constraints=export_constraints + backend, + export=export, + export_constraints=export_constraints, + package=package, ), hooks, backend_ctx_ctor, export=export, dynamic=dynamic, rebuild_ctx=rebuild_ctx, + package=package, ) class TorchPatcher: @staticmethod - @functools.lru_cache(None) + @functools.cache def patch(): # A better way to disable the following would be decorate the source # functions with @torch._disable_dynamo. However, this causes issues diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index ae9c6eed879296..8352cae57efed2 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -70,6 +70,10 @@ class InternalTorchDynamoError(TorchDynamoException): pass +class ResumePrologueTracingError(TorchDynamoException): + pass + + class RestartAnalysis(TorchDynamoException): restart_reason: Optional[str] @@ -367,7 +371,7 @@ def raise_observed_exception( # stack and raise the exception. exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] tx.exn_vt_stack.set_current_exception(exception_vt) - raise observed_exception_map[exc_type] + raise get_dynamo_observed_exception(exc_type) def handle_observed_exception(tx: Any) -> None: @@ -404,6 +408,7 @@ def handle_observed_exception(tx: Any) -> None: torch._subclasses.fake_tensor.DynamicOutputShapeException, torch._subclasses.fake_tensor.UnsupportedOperatorException, torch._subclasses.fake_tensor.UnsupportedFakeTensorException, + torch._subclasses.fake_tensor.UnsupportedMutationAliasingException, ) @@ -493,6 +498,46 @@ def format_graph_break_message( return msg +''' +@lru_cache(maxsize=1) +def _load_graph_break_registry() -> dict[str, Any]: + """ + Loads the graph break registry from JSON file with caching. + """ + try: + script_dir = Path(__file__).resolve().parent + registry_path = script_dir / "graph_break_registry.json" + with registry_path.open() as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError) as e: + log.error("Error accessing the registry file: %s", e) + return {} + +''' + +''' +def get_gbid_documentation_link(gb_type: str) -> Optional[str]: + """ + Retrieves the GBID documentation link for a given graph break type. + + Args: + gb_type: The graph break type to look up. + + Returns: + A string containing the documentation URL if found, otherwise None. + """ + GRAPH_BREAK_SITE_URL = "https://compile-graph-break-site.vercel.app/gb/" + + registry = _load_graph_break_registry() + + for k, v in registry.items(): + if v and v[0].get("Gb_type") == gb_type: + return f"{GRAPH_BREAK_SITE_URL}{k}" + + return "None" +''' + + # TODO replace old unimplemented later def unimplemented_v2( gb_type: str, @@ -514,6 +559,12 @@ def unimplemented_v2( """ msg = format_graph_break_message(gb_type, context, explanation, hints) + + # Temporarily disabling the generation of the weblinks in error message + + # documentation_link = get_gbid_documentation_link(gb_type) + # msg += f"\n For more details about this graph break, please visit: {documentation_link}" + if log_warning: log.warning(msg) if from_exc is not _NOTHING: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 752680b6fd8bfb..c4fbc62ea5db2a 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -200,9 +200,11 @@ def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: return nonrecursive_disable_wrapper -def _dynamo_config_patch_proxy_dunder_call( - self: Any, func: Callable[_P, _R] -) -> Callable[_P, _R]: +def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Apply self as a ctx manager around a call to func + """ + @functools.wraps(func) def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: with self: @@ -218,3 +220,34 @@ def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int: # x.size() is expected to be [0, dynamic_int] return x.size(1) return x + + +def call_accumulate_grad( + variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool +) -> None: + updated_grad = torch._dynamo.compiled_autograd.ops.AccumulateGrad( # type: ignore[attr-defined] + [grad], variable, variable.grad, has_post_hooks + ) + variable.grad = updated_grad[0] + + +def wrap_inline_with_set_fullgraph( + fn: Callable[_P, _R], fullgraph: bool +) -> Callable[_P, _R]: + # NB: need multiple definitions in order to prevent `fullgraph` from + # being a freevar of wrapper + if fullgraph: + + @functools.wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + with torch._dynamo.set_fullgraph(True): + return fn(*args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + with torch._dynamo.set_fullgraph(False): + return fn(*args, **kwargs) + + return wrapper diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json new file mode 100644 index 00000000000000..153b7132bf19e8 --- /dev/null +++ b/torch/_dynamo/graph_break_registry.json @@ -0,0 +1,2203 @@ +{ + "GB0000": [ + { + "Gb_type": "All __torch_function__ overrides returned NotImplemented due to TypeError from user code", + "Context": "fn={fn}, args={args}, kwargs={kwargs}", + "Explanation": "All __torch_function__ overrides for for function {fn} returned NotImplemented", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0001": [ + { + "Gb_type": "Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + "Context": "{self}.as_subclass({cls})", + "Explanation": "Currently not supported", + "Hints": [ + "Avoid this call or move it outside `torch.compile` regione", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0002": [ + { + "Gb_type": "Assertion failed on symbolic shapes", + "Context": "str(sym_expr)", + "Explanation": "", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0003": [ + { + "Gb_type": "Attempt to trace generator", + "Context": "", + "Explanation": "Generators cannot be compiled directly with `torch.compile`.", + "Hints": [ + "Call a generator from inside of a non-generator Python function and ", + "compile that function instead.", + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0004": [ + { + "Gb_type": "Attempted super().__delattr__() on an object without mutation tracking", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo needs to track mutations on an object before `super().__delattr__` can be used on it. But the object ({self.objvar}) doesn't have attribute mutation tracking enabled.", + "Hints": [ + "Ensure the object is tracked by Dynamo's side effect system.", + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0005": [ + { + "Gb_type": "Attempted to a str() method implemented in C/C++", + "Context": "", + "Explanation": "{type(arg.value)} has a C/C++ based str method. This is not supported.", + "Hints": [ + "Write the str method in Python" + ] + } + ], + "GB0006": [ + { + "Gb_type": "Attempted to call a super() attribute that is not a function or method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not know how to trace the call `super().{name}()` because `super().{name}` is not a function or method attribute.", + "Hints": [ + "Ensure the attribute accessed via `super()` is a standard method or function." + ] + } + ], + "GB0007": [ + { + "Gb_type": "Attempted to call function marked as skipped", + "Context": "module: {module_name}, qualname: {qualname}, skip reason: {reason}", + "Explanation": "explanation", + "Hints": [] + } + ], + "GB0008": [ + { + "Gb_type": "Attempted to inline function marked as skipped", + "Context": "qualname: {fn_qualname}, name: {func.get_name()}, filename: `{func.get_filename()}`, skip reason: {result.reason}", + "Explanation": "Dynamo developers have intentionally marked that the function `{fn_qualname}` should not be traced.", + "Hints": [] + } + ], + "GB0009": [ + { + "Gb_type": "Attempted to inline function marked as skipped (SkipFunctionVariable)", + "Context": "Attempted to inline a SkipFunctionVariable {func}", + "Explanation": "Attempted to inline a function that was previously determined to be marked as intentionally skipped.", + "Hints": [] + } + ], + "GB0010": [ + { + "Gb_type": "Attempted to read a deleted variable", + "Context": "item: {item}, name: {name}", + "Explanation": "", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0011": [ + { + "Gb_type": "Attempted to read undefined local variable", + "Context": "LOAD_FAST {name}", + "Explanation": "Could not find a local variable with name `{name}`", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0012": [ + { + "Gb_type": "Attempted to read undefined local variable (implicit)", + "Context": "LOAD_FAST {name}", + "Explanation": "Could not find an implicit local variable with name `{name}`", + "Hints": [ + "This happens in dict/list comprehensions", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0013": [ + { + "Gb_type": "Attempted to represent unregistered RemovableHandle", + "Context": "", + "Explanation": "Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, which is not supported. This happens because the RemovableHandle was created in another frame.", + "Hints": [] + } + ], + "GB0014": [ + { + "Gb_type": "Attempted to wrap RNN, GRU, or LSTM", + "Context": "str(value)", + "Explanation": "Dynamo does not support RNN, GRU, or LSTM.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0015": [ + { + "Gb_type": "Attempted to wrap sparse Tensor", + "Context": "", + "Explanation": "torch.compile does not support sparse Tensors", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0016": [ + { + "Gb_type": "Attempted to wrap strided NestedTensor", + "Context": "", + "Explanation": "torch.compile does not support strided NestedTensor", + "Hints": [] + } + ], + "GB0017": [ + { + "Gb_type": "Attempted to wrap torch._higher_order_ops.invoke_subgraph", + "Context": "", + "Explanation": "Directly using invoke_subgraph is not supported. Use nested_compile_region", + "Hints": [] + } + ], + "GB0018": [ + { + "Gb_type": "Attempted to wrap unbacked SymInt", + "Context": "", + "Explanation": "Unbacked SymInt input is not supported yet.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0019": [ + { + "Gb_type": "AutogradFunctionContextVariable escaped Dynamo-traced region", + "Context": "", + "Explanation": "We cannot reconstruct a torch.autograd.Function's context object.", + "Hints": [] + } + ], + "GB0020": [ + { + "Gb_type": "BUILD_STRING key conflict", + "Context": "format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}", + "Explanation": "Failed to build format string due to key conflict", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0021": [ + { + "Gb_type": "BUILD_STRING type error", + "Context": "str(part)", + "Explanation": "Format string part type is not correct - expected constant or format string.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0022": [ + { + "Gb_type": "Bad import result", + "Context": "typestr(value)", + "Explanation": "Import result is not a Python module.", + "Hints": [] + } + ], + "GB0023": [ + { + "Gb_type": "Builtin `operator.*` comparison with constant `self` failed", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "\"Failed to compare {self} with {other}, \" + f\"because {other} is not a Python constant or its mutation check fails.\"", + "Hints": [] + } + ], + "GB0024": [ + { + "Gb_type": "CLEANUP_THROW with StopIteration", + "Context": "", + "Explanation": "Received StopIteration when handling generator.throw/close. This is not supported.", + "Hints": [] + } + ], + "GB0025": [ + { + "Gb_type": "Call to `torch._dynamo.graph_break()`", + "Context": "Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + "Explanation": "User-inserted graph break. Message: {graph_break_msg}", + "Hints": [ + "Remove the `torch._dynamo.graph_break()` call." + ] + } + ], + "GB0026": [ + { + "Gb_type": "Calling subclass default constructor with more than tensor argument", + "Context": "{self.value}(args={args}, kwargs={kwargs})", + "Explanation": "Currently not supported", + "Hints": [ + "Avoid this constructor call or move it outside ", + "`torch.compile` regione", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0027": [ + { + "Gb_type": "Cannot check Tensor object identity without its fake value", + "Context": "str(fake_tensor)", + "Explanation": "TensorVariable is missing a fake example_value.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0028": [ + { + "Gb_type": "Caught non-Exception value", + "Context": "str(exc_instance)", + "Explanation": "Except expects to receive an object of Exception type but received {exc_instance}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0029": [ + { + "Gb_type": "Compilation of intermediate hooks requires compiled autograd", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo must be in compiled_autograd to register hooks.", + "Hints": [] + } + ], + "GB0030": [ + { + "Gb_type": "ComptimeContext graph break", + "Context": "msg", + "Explanation": "Manually triggered ComptimeContext graph break with message {msg}.", + "Hints": [] + } + ], + "GB0031": [ + { + "Gb_type": "Custom __getattribute__ in nn.Module attribute access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo does not support checking key existence on `nn.Module` instances that have a custom `__getattribute__` method defined.", + "Hints": [ + "Avoid defining `__getattribute__` in your module.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0032": [ + { + "Gb_type": "Custom __getattribute__ in nn.Module dict key check", + "Context": "has_key_in_generic_dict {self} {key}", + "Explanation": "Dynamo does not support checking key existence on `nn.Module` instances that have a custom `__getattribute__` method defined.", + "Hints": [ + "Avoid defining `__getattribute__` in your module.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0033": [ + { + "Gb_type": "Data dependent operator", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}` has a non-Tensor output whose value is dependent on the data of Tensor inputs.", + "Hints": [] + } + ], + "GB0034": [ + { + "Gb_type": "Data-dependent assertion failed (cannot compile partial graph)", + "Context": "value: {value}", + "Explanation": "Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph.", + "Hints": [ + "Use `torch._assert()` to raise a hard AssertionError when the check fails. ", + "This error will propagate back the user code ", + "that called the compiled function (i.e. Dynamo will not trace any exception handling).", + "Remove the assert statement.", + "Move the assert statement outside of any context managers in order to graph break with ", + "partial graph compilation (if fullgraph=False).", + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0035": [ + { + "Gb_type": "Data-dependent branching with non-constant __bool__", + "Context": "method: {x}, result: {result}", + "Explanation": "Attempted to perform data-dependent branching on a user-defined object with a __bool__ method that did not return a constant.", + "Hints": [] + } + ], + "GB0036": [ + { + "Gb_type": "Dynamic shape operator", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}`'s output shape depends on input Tensor data.", + "Hints": [ + "Enable tracing of dynamic shape operators with ", + "`torch._dynamo.config.capture_dynamic_output_shape_ops = True`" + ] + } + ], + "GB0037": [ + { + "Gb_type": "Dynamic shape operator (no meta kernel)", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes", + "Hints": [ + "Please report an issue to PyTorch" + ] + } + ], + "GB0038": [ + { + "Gb_type": "Dynamic slicing with Tensor arguments", + "Context": "SliceVariable start: {start}, stop: {stop}, step: {step}", + "Explanation": "Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0039": [ + { + "Gb_type": "Dynamo cache limit exceeded", + "Context": "Limit type: {limit_type}", + "Explanation": "Dynamo attempted to recompile the code object too many times, exceeding the {limit_type} cache size limit.Giving up on compiling as the compile time tradeoff is likely not worth the performance gain.", + "Hints": [] + } + ], + "GB0040": [ + { + "Gb_type": "Encountered aliasing during higher order op tracing", + "Context": "context", + "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name()}", + "Hints": [ + "Replace `return input` with `return input.clone()` to avoid aliasing.", + "Consider using the debug context to change user code to avoid aliasing.", + "Please open an issue." + ] + } + ], + "GB0041": [ + { + "Gb_type": "Encountered input mutation during higher order op tracing", + "Context": "context", + "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name()}", + "Hints": [ + "Consider using the debug context to change user code to avoid mutation.", + "Please open an issue." + ] + } + ], + "GB0042": [ + { + "Gb_type": "Encountered non user function variable during invoke_subgraph HOP tracing", + "Context": "str(fn_vt)", + "Explanation": "invoke_subgraph does not support non user function variable", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0043": [ + { + "Gb_type": "Encountered non-PT2-compliant op", + "Context": "", + "Explanation": "msg + + err_epilogue", + "Hints": [] + } + ], + "GB0044": [ + { + "Gb_type": "Encountered strided NestedTensor in automatic dynamic dim determination", + "Context": "", + "Explanation": "torch.compile does not support strided NestedTensor", + "Hints": [] + } + ], + "GB0045": [ + { + "Gb_type": "Encountered tensor.is_inference() during tracing", + "Context": "", + "Explanation": "tensor.is_inference() is not supported", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0046": [ + { + "Gb_type": "Encountered torch.is_inference_mode_enabled during tracing", + "Context": "", + "Explanation": "torch.is_inference_mode_enabled() is not supported", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0047": [ + { + "Gb_type": "Encountered unconverted argument when attempting to inline", + "Context": "func: {func}, arg: {v}", + "Explanation": "An argument to an inlined function was not successfully converted to a VariableTracker.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0048": [ + { + "Gb_type": "Error getting associated real value", + "Context": "call_id {self}", + "Explanation": "Dynamo encountered an error while trying to get the associated real value.", + "Hints": [] + } + ], + "GB0049": [ + { + "Gb_type": "Error when attempting to resolve op packet", + "Context": "", + "Explanation": "str(e)", + "Hints": [] + } + ], + "GB0050": [ + { + "Gb_type": "Exception with bad expected type", + "Context": "str(expected_exc_types)", + "Explanation": "`except ...` has unsupported type {expected_exc_types}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0051": [ + { + "Gb_type": "Exception with non-type expectation", + "Context": "str(expected_type)", + "Explanation": "`except ...` expects a non-type: {expected_type}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0052": [ + { + "Gb_type": "Excessive RestartAnalysis() calls", + "Context": "", + "Explanation": "Dynamo attempted to trace the same frame 100+ times. Giving up on compiling as the compile time tradeoff is likely not worth the performance gain.", + "Hints": [] + } + ], + "GB0053": [ + { + "Gb_type": "FSDP with use_orig_params=False", + "Context": "", + "Explanation": "Dynamo only supports FSDP with use_orig_params=True", + "Hints": [] + } + ], + "GB0054": [ + { + "Gb_type": "Failed to construct Enum variable", + "Context": "value: {value_vt}, allowed enum values: {list(cls_type)}", + "Explanation": "Attempted to construct an Enum value that is non-constant (e.g. int, string) or is not an acceptable value for the Enum. Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0055": [ + { + "Gb_type": "Failed to convert args/kwargs to proxy", + "Context": "call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", + "Explanation": "Missing `as_proxy()` implementation for some arg/kwarg.", + "Hints": [] + } + ], + "GB0056": [ + { + "Gb_type": "Failed to mutate tensor data attribute", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "Dyanmo only supports mutating `.data` of tensor created outside `torch.compile` region", + "Hints": [ + "Don't mutate `.data` on this tensor, or move ", + "the mutation out of `torch.compile` region" + ] + } + ], + "GB0057": [ + { + "Gb_type": "Failed to raise exception", + "Context": "str(exc)", + "Explanation": "Attempted to raise a non-Exception type/value.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0058": [ + { + "Gb_type": "Failed to set tensor attribute", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "Dyanmo doesn't support setting these tensor attributes", + "Hints": [ + "Don't mutate attribute '{name}' on tensors, or ", + "move the mutation out of `torch.compile` region" + ] + } + ], + "GB0059": [ + { + "Gb_type": "Failed to trace builtin operator", + "Context": "builtin {fn.__name__} {arg_types} {has_kwargs}", + "Explanation": "Dynamo does not know how to trace builtin operator `{fn.__name__}` with argument types {real_arg_types} (has_kwargs {has_kwargs})", + "Hints": [ + "Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. ", + "Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), ", + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch." + ] + } + ], + "GB0060": [ + { + "Gb_type": "Failed to trace unittest method", + "Context": "function: unittest.TestCase.{name}", + "Explanation": "Dynamo does not know how to trace unittest method `{name}` ", + "Hints": [ + "Avoid calling `TestCase.{name}`. ", + "Please report an issue to PyTorch." + ] + } + ], + "GB0061": [ + { + "Gb_type": "Failed to unpack object for BUILD_LIST_UNPACK", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK bytecode (`[*x, *y, ...]`).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0062": [ + { + "Gb_type": "Failed to unpack object for UNPACK_EX", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0063": [ + { + "Gb_type": "Failed to unpack object for UNPACK_SEQUENCE", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode (i.e. `a, b, c = d`).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0064": [ + { + "Gb_type": "Fake tensor propagation exception", + "Context": "str(e.reason)", + "Explanation": "msg", + "Hints": [] + } + ], + "GB0065": [ + { + "Gb_type": "Graph break in inlined function", + "Context": "", + "Explanation": "Graph breaks in an inlined call are not supported.", + "Hints": [] + } + ], + "GB0066": [ + { + "Gb_type": "Graph break under GenericContextWrappingVariable", + "Context": "Active generic context managers: {self.active_generic_context_managers}", + "Explanation": "Attempted to graph break in an active context manager(s) that doesn't support graph breaking.", + "Hints": [ + "Move the offending context manager(s) to outside the compiled region.", + "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one." + ] + } + ], + "GB0067": [ + { + "Gb_type": "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0068": [ + { + "Gb_type": "Illegal method invocation in strict mode", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo currently does not support this method ({name}) invocation in strict mode.", + "Hints": [] + } + ], + "GB0069": [ + { + "Gb_type": "Import failure", + "Context": "module_name: {module_name}, fromlist: {fromlist}, level={level}", + "Explanation": "Failure when attempting to import.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0070": [ + { + "Gb_type": "Indexing list with non-scalar tensor", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Attempted to index list-like object with tensor with > 1 element.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0071": [ + { + "Gb_type": "Inline attempt with __self__", + "Context": "str(func)", + "Explanation": "Attempted to inline a function with the `__self__` attribute. Dynamo is expected to decompose method calls into function calls with a `self` argument.", + "Hints": [] + } + ], + "GB0072": [ + { + "Gb_type": "Inplace op on input tensor", + "Context": "", + "Explanation": "Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + "Hints": [ + "Ensure you do not modify input tensor in place.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0073": [ + { + "Gb_type": "Invoking an nn.Module inside a HigherOrderOperator", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0074": [ + { + "Gb_type": "Invoking an nn.Module inside a higher order operator", + "Context": "Higher order op name: {self.source_target}", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0075": [ + { + "Gb_type": "LOAD_BUILD_CLASS bytecode not supported", + "Context": "", + "Explanation": "Dynamo does not support tracing classes that are defined in the compiled region.", + "Hints": [ + "Move the class definition out of the compiled region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0076": [ + { + "Gb_type": "LOAD_FAST_CHECK on uninitialized variable", + "Context": "inst.argval", + "Explanation": "Attempted to load uninitialized local variable {inst.argval}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0077": [ + { + "Gb_type": "Length mismatch when unpacking object for UNPACK_SEQUENCE", + "Context": "expected length: {inst.argval}, actual: {len(val)}", + "Explanation": "{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode (i.e. `a, b, c = d`) with unexpected length.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0078": [ + { + "Gb_type": "Limitation of `nonstrict_trace", + "Context": "{self}", + "Explanation": "msg", + "Hints": [ + "make sure definition of {fn_name} is outside ", + "`torch.compile` region" + ] + } + ], + "GB0079": [ + { + "Gb_type": "Missing CALL_INTRINSIC_1 handler", + "Context": "CALL_INTRINSIC_1 operand: {inst.argval}", + "Explanation": "No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0080": [ + { + "Gb_type": "Missing FakeTensor example value", + "Context": "str(node)", + "Explanation": "`FakeTensor` example value was required for {node} but not available.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0081": [ + { + "Gb_type": "Missing attribute when running call_method node", + "Context": "", + "Explanation": "make_error_message(\"attribute not defined\")", + "Hints": [] + } + ], + "GB0082": [ + { + "Gb_type": "Missing bytecode handler", + "Context": "{opname} with args {args}", + "Explanation": "Dynamo does not know how to handle the bytecode instruction `{opname}`.", + "Hints": [ + "Do not trace code that produces the `{opname}` bytecode instruction ", + "(see https://docs.python.org/3/library/dis.html for bytecode semantics).", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0083": [ + { + "Gb_type": "Module-level backwards hooks require compiled autograd.", + "Context": "", + "Explanation": "", + "Hints": [ + "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True." + ] + } + ], + "GB0084": [ + { + "Gb_type": "Non-constant attribute given to `super().__delattr__()`", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo requires the attribute name passed to `super().__delattr__(...)` to be a constant (string).", + "Hints": [ + "Ensure the attribute name is a string literal or a constant variable." + ] + } + ], + "GB0085": [ + { + "Gb_type": "Non-function or method in subclass of torch.autograd.Function", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo requires the `forward` attribute of a `torch.autograd.Function` subclass to be a standard Python function or method. Found type `{type(fn).__name__}` instead.", + "Hints": [ + "Ensure the `forward` method is defined as a regular ", + "function or instance method." + ] + } + ], + "GB0086": [ + { + "Gb_type": "Not a Python constant", + "Context": "guard_as_python_constant {self}", + "Explanation": "Failed to convert {self} into a Python constant.", + "Hints": [] + } + ], + "GB0087": [ + { + "Gb_type": "NotImplementedError/UnsupportedFakeTensorException when running FX node", + "Context": "", + "Explanation": "make_error_message(e)", + "Hints": [] + } + ], + "GB0088": [ + { + "Gb_type": "Observed exception", + "Context": "str(raised_exception)", + "Explanation": "observed_exn_gb_explanation", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0089": [ + { + "Gb_type": "Observed exception (EXCEPT_HANDLER)", + "Context": "str(raised_exception)", + "Explanation": "observed_exn_gb_explanation + \" This graph break is unexpected.\"", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0090": [ + { + "Gb_type": "Operator does not support running with fake tensors", + "Context": "unsupported operator: {cause.func}", + "Explanation": "", + "Hints": [ + "{import_suggestion}see ", + "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0", + " for how to fix" + ] + } + ], + "GB0091": [ + { + "Gb_type": "Read uninitialized cell", + "Context": "str(cellvar)", + "Explanation": "Attempted to read a cell variable that has not been populated yet.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0092": [ + { + "Gb_type": "Reconstruction failure", + "Context": "str(value)", + "Explanation": "Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", + "Hints": [ + "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable ", + "that Dynamo cannot reconstruct, then remove it from the return statement.", + "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have ", + "reconstruction rules may be fundamentally unreconstructable.", + "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one." + ] + } + ], + "GB0093": [ + { + "Gb_type": "Reconstruction failure: source.reconstruct not implemented", + "Context": "str(source)", + "Explanation": "Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0094": [ + { + "Gb_type": "SEND with bad type", + "Context": "TOS type: {typestr(tos)}", + "Explanation": "Attempted to SEND with unsupported type {typestr(tos)}.", + "Hints": [] + } + ], + "GB0095": [ + { + "Gb_type": "Set Exception object `__traceback__` attribute to not-`None`", + "Context": "call_setattr {self} {name}", + "Explanation": "Dynamo does not support setting the attribute '__traceback__' on tracked exception objects to anything other than None.", + "Hints": [ + "Avoid setting '__traceback__' on exception objects ", + "within traced code, or set it to None." + ] + } + ], + "GB0096": [ + { + "Gb_type": "Should not compile partial graph (STORE_ATTR)", + "Context": "", + "Explanation": "Dynamo has determined when encountering an unsupported STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.", + "Hints": [] + } + ], + "GB0097": [ + { + "Gb_type": "Side effect on existing deque with limited maxlen", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [ + "Don't use a deque with `maxlen` specified." + ] + } + ], + "GB0098": [ + { + "Gb_type": "Skip calling `torch.compiler.disable()`d function", + "Context": "str(self.value)", + "Explanation": "Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable` (reason: {msg})", + "Hints": [ + "Remove the `torch.compiler.disable` call" + ] + } + ], + "GB0099": [ + { + "Gb_type": "Skip inlining `torch.compiler.disable()`d function", + "Context": "str(func.get_function())", + "Explanation": "Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable` (reason: {msg})", + "Hints": [ + "Remove the `torch.compiler.disable` call" + ] + } + ], + "GB0100": [ + { + "Gb_type": "Storing Tensor hook handle in globals", + "Context": "name", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0101": [ + { + "Gb_type": "Storing Tensor hook handle in globals (inline call)", + "Context": "inst.argval", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0102": [ + { + "Gb_type": "Strict mode banned op", + "Context": "var_getattr {self} {name}", + "Explanation": "Getattr invocation '{name}' in strict mode is not supported.", + "Hints": [ + "Remove `{name}` from the list of banned ops by ", + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`." + ] + } + ], + "GB0103": [ + { + "Gb_type": "Tensor subclass overridden method call", + "Context": "{name}", + "Explanation": "`torch.compile` currently can't trace this", + "Hints": [ + "Avoid calling {name} of tensor subclass in torch.compile region", + "Renaming method `{name}` of type {self.class_type}", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0104": [ + { + "Gb_type": "Tensor with grad_fn()", + "Context": "var_getattr {self} grad_fn", + "Explanation": "Dynamo does not support tracing tensors with a grad_fn directly.", + "Hints": [] + } + ], + "GB0105": [ + { + "Gb_type": "Tensor.numpy() with trace_numpy=False", + "Context": "call_method {self} numpy", + "Explanation": "`Tensor.numpy()` was called, but the `trace_numpy` configuration was manually disabled.", + "Hints": [ + "Set `torch._dynamo.config.trace_numpy = True` to allow ", + "Dynamo to trace through NumPy." + ] + } + ], + "GB0106": [ + { + "Gb_type": "Tensor.numpy() without NumPy installed", + "Context": "call_method {self} numpy", + "Explanation": "`Tensor.numpy()` was called, but the NumPy library is not available in the current environment.", + "Hints": [ + "Ensure NumPy is installed in your Python environment." + ] + } + ], + "GB0107": [ + { + "Gb_type": "Tensor.random_ op", + "Context": "Tensor.{name}(args={args}, kwargs={kwargs})", + "Explanation": "This is currently not supported.", + "Hints": [ + "Use the out-of-place version of this op", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0108": [ + { + "Gb_type": "Tensor.retain_grad() with AOTDispatcher", + "Context": "var_getattr {self} retain_grad", + "Explanation": "`Tensor.retain_grad()` does not work with AOTDispatcher.", + "Hints": [] + } + ], + "GB0109": [ + { + "Gb_type": "Tensor.tolist() with non-integer tensor", + "Context": "call_method {self} to_list", + "Explanation": "Dynamo currently does not support tracing `tolist()` on non-integer tensors.", + "Hints": [ + "Ensure the input tensor to `tolist()` is an integer ", + "type (e.g., int8, int16, int32, int64)." + ] + } + ], + "GB0110": [ + { + "Gb_type": "Tensor.uniform_ op called with `from` keyword", + "Context": "Tensor.{name}(args={args}, kwargs={kwargs})", + "Explanation": "This is currently not supported.", + "Hints": [ + "Avoid using the `from` keyword.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0111": [ + { + "Gb_type": "TypeError from user code", + "Context": "call_function({self.value}, {args}, {kwargs})", + "Explanation": "msg", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0112": [ + { + "Gb_type": "TypeError when making fake tensor call", + "Context": "TypeError {node.target}: {cause}", + "Explanation": "", + "Hints": [] + } + ], + "GB0113": [ + { + "Gb_type": "Unable to resolve super getattr", + "Context": "", + "Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because the resolved attribute type is not supported.", + "Hints": [ + "Ensure the attribute exists in the parent class.", + "Check the arguments passed to `super()`." + ] + } + ], + "GB0114": [ + { + "Gb_type": "Unexpected failure during itertools.accumulate() iteration", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0115": [ + { + "Gb_type": "Unexpected failure during itertools.groupby() iteration", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Unexpected failure in invoking function during groupby", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0116": [ + { + "Gb_type": "Unexpected type in sourceless builder", + "Context": "{value_type.__module__}.{value_type.__qualname__}", + "Explanation": "SourcelessBuilder.create does not know how to wrap {value_type}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0117": [ + { + "Gb_type": "Unhandled args for method", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo encountered an error while calling the method `{name}`.", + "Hints": [] + } + ], + "GB0118": [ + { + "Gb_type": "Unimplemented next() call", + "Context": "next({self})", + "Explanation": "This abstract method must be implemented", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0119": [ + { + "Gb_type": "Uninitialized nn.Module", + "Context": "typestr(value)", + "Explanation": "Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", + "Hints": [ + "Ensure your nn.Module instance has called `super().__init__()`.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0120": [ + { + "Gb_type": "Unreachable sub-generator code", + "Context": "", + "Explanation": "Should only be encountered while implementing generator support.", + "Hints": [] + } + ], + "GB0121": [ + { + "Gb_type": "UnspecializedNNModuleVariable missing method", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing method {name} of nn.Module {self.value}", + "Hints": [ + "Dynamo does not really define unspecialized nn.Module very well.", + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0122": [ + { + "Gb_type": "Unsupported SourceType", + "Context": "MutationType.__init__ {self} {typ}", + "Explanation": "Dynamo does not support the type `{typ}`", + "Hints": [ + "This branch is not supposed to be reachable.", + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0123": [ + { + "Gb_type": "Unsupported Tensor.backward() call", + "Context": "call_method {self} backward {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.backward()`.", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0124": [ + { + "Gb_type": "Unsupported Tensor.item() call with capture_scalar_outputs=False", + "Context": "call_method {self} item {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.", + "Hints": [ + "Set `torch._dynamo.config.capture_scalar_outputs = True` ", + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` ", + "to include these operations in the captured graph." + ] + } + ], + "GB0125": [ + { + "Gb_type": "Unsupported Tensor.requires_grad_() call", + "Context": "call_method {self} requires_grad_", + "Explanation": "Dynamo does not support changes to a Tensor's `requires_grad` through calling `requires_grad_()`.", + "Hints": [] + } + ], + "GB0126": [ + { + "Gb_type": "Unsupported Tensor.resize_() call", + "Context": "call_method {self} resize_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.resize_()`.", + "Hints": [] + } + ], + "GB0127": [ + { + "Gb_type": "Unsupported Tensor.resize_as_() call", + "Context": "call_method {self} resize_as_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.resize_as_()`.", + "Hints": [] + } + ], + "GB0128": [ + { + "Gb_type": "Unsupported Tensor.set_() call", + "Context": "call_method {self} set_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.set_()` overloads that include more than one argument.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0129": [ + { + "Gb_type": "Unsupported Tensor.sparse_resize_() call", + "Context": "call_method {self} sparse_resize_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + "Hints": [] + } + ], + "GB0130": [ + { + "Gb_type": "Unsupported Tensor.sparse_resize_and_clear_() call", + "Context": "call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + "Hints": [] + } + ], + "GB0131": [ + { + "Gb_type": "Unsupported __setitem__/__setattr__ inline attempt", + "Context": "code name: {code.co_name}, args: {args}", + "Explanation": "Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.", + "Hints": [] + } + ], + "GB0132": [ + { + "Gb_type": "Unsupported `func` in itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to get the function to use for itertools.accumulate. itertools.accumulate expects the `func` as the second argument or as a keyword argument.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0133": [ + { + "Gb_type": "Unsupported arguments for itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.accumulate with args: {args} and kwargs: {kwargs}. itertools.accumulate expects an iterable, an optional binary function for accumulation, and an optional initial value to set the starting state.", + "Hints": [ + "Make sure the arguments to itertools.accumulate are correct.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0134": [ + { + "Gb_type": "Unsupported arguments for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.groupby with args: {args} and kwargs: {kwargs}. itertools.groupby expects an iterable to group and an optional key function to determine groupings.", + "Hints": [ + "Make sure the arguments to itertools.groupby are correct.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0135": [ + { + "Gb_type": "Unsupported attribute assignment on Exception object", + "Context": "call_setattr {self} {name}", + "Explanation": "Dynamo does not support setting the attribute '{name}' on tracked exception objects. Only `__context__`, `__cause__`, `__suppress_context__`, and `__traceback__` are supported.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0136": [ + { + "Gb_type": "Unsupported attribute for range() object", + "Context": "var_getattr {self} {name}", + "Explanation": "Expected attribute to be one of {','.join(fields)} but got {name}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0137": [ + { + "Gb_type": "Unsupported attribute for slice() object", + "Context": "var_getattr {self} {name}", + "Explanation": "Expected attribute to be one of {','.join(fields)} but got {name}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0138": [ + { + "Gb_type": "Unsupported autograd.Function context `save_for_backward`", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo requires the `saved_tensors` attribute to be initialized on the `autograd.Function` context object.", + "Hints": [ + "Ensure that the `saved_tensors` attribute is properly ", + "initialized before calling `save_for_backward`. ", + "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`." + ] + } + ], + "GB0139": [ + { + "Gb_type": "Unsupported autograd.Function context method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not support calling the method `{name}` on `autograd.Function` context objects. Supported methods are `__setattr__`, `save_for_backward` and `mark_non_differentiable`.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0140": [ + { + "Gb_type": "Unsupported autograd.Function method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not support calling the method `{name}` directly on the `torch.autograd.Function` instance. Supported methods include `apply`, `backward`, static methods, and class methods.", + "Hints": [ + "Ensure the method is decorated with `@staticmethod` ", + "or `@classmethod` if it's meant to be called on the class." + ] + } + ], + "GB0141": [ + { + "Gb_type": "Unsupported call_id() without source", + "Context": "call_id {self}", + "Explanation": "call_id() not supported for sourceless TensorVariable.", + "Hints": [] + } + ], + "GB0142": [ + { + "Gb_type": "Unsupported context manager", + "Context": "Attempted SETUP_WITH/BEFORE_WITH on {ctx}", + "Explanation": "Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.", + "Hints": [ + "Avoid using the unsupported context manager.", + "If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then ", + "it may be the case that it was created outside the compiled region, which Dynamo does not support. ", + "Supported context managers can cross graph break boundaries only if they are local non-closure ", + "variables, or are intermediate values.", + "File an issue to PyTorch. Simple context managers can potentially be supported, ", + "but note that context managers can't be supported in general" + ] + } + ], + "GB0143": [ + { + "Gb_type": "Unsupported conversion for slice assignment", + "Context": "call_method {self} {name} {args}", + "Explanation": "Missing dynamo support for converting {value} into a list for slice assignment.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0144": [ + { + "Gb_type": "Unsupported custom jvp", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `torch.autograd.Function` subclasses that define a custom `jvp` method.", + "Hints": [ + "Remove the custom `jvp` method if possible.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0145": [ + { + "Gb_type": "Unsupported custom vjp", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `torch.autograd.Function` subclasses that define a custom `vjp` method.", + "Hints": [ + "Remove the custom `vjp` method if possible.", + "Use standard `backward` instead if applicable.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0146": [ + { + "Gb_type": "Unsupported event method", + "Context": "str(name)", + "Explanation": "Dynamo doesn't support tracing the {method_name} method. We currently support wait, record, synchronize, and query.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0147": [ + { + "Gb_type": "Unsupported function call", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [ + "Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch." + ] + } + ], + "GB0148": [ + { + "Gb_type": "Unsupported function call (delayed)", + "Context": "source: {self.source}", + "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name()}`. Reason: {self.msg}", + "Hints": [] + } + ], + "GB0149": [ + { + "Gb_type": "Unsupported functorch tracing attempt", + "Context": "", + "Explanation": "msg", + "Hints": [] + } + ], + "GB0150": [ + { + "Gb_type": "Unsupported hasattr call", + "Context": "call_obj_hasattr {self} {name}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [ + "Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0151": [ + { + "Gb_type": "Unsupported inspect call", + "Context": "inspect_parameter_names {self}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [] + } + ], + "GB0152": [ + { + "Gb_type": "Unsupported key type for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.groupby with key type: {str(type(key))}. We only support grouping keys that are constants (int, float, str, etc.)", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0153": [ + { + "Gb_type": "Unsupported key type for nn.Module.__getitem__", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support getitem on `nn.Module` with non-constant key.", + "Hints": [] + } + ], + "GB0154": [ + { + "Gb_type": "Unsupported kwargs for itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'initial', 'func', but got {','.join(set(kwargs.keys()) - {'initial', 'func'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0155": [ + { + "Gb_type": "Unsupported kwargs for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'key', but got {','.join(set(kwargs.keys()) - {'key'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0156": [ + { + "Gb_type": "Unsupported method call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + "Hints": [] + } + ], + "GB0157": [ + { + "Gb_type": "Unsupported ndarray attribute access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo currently does not support tracing `ndarray.{name}`.", + "Hints": [] + } + ], + "GB0158": [ + { + "Gb_type": "Unsupported ndarray method call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "`ndarray.{name}()` is not modelled in `torch._numpy`.", + "Hints": [] + } + ], + "GB0159": [ + { + "Gb_type": "Unsupported ndarray.__version__ access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo currently does not support tracing `ndarray.{name}`.", + "Hints": [] + } + ], + "GB0160": [ + { + "Gb_type": "Unsupported next() call", + "Context": "next({self})", + "Explanation": "Dynamo does not know how to trace calling `next()` on variable `{self}`.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0161": [ + { + "Gb_type": "Unsupported nn.Module attribute type", + "Context": "nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", + "Explanation": "Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", + "Hints": [ + "Refactor your code so that `{name}` (type `{typestr(subobj)}`) is not an attribute of `{typestr(base)}`", + "Currently supported attribute types are methods, classmethods, staticmethods, ", + "properties, constants, and tensors.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0162": [ + { + "Gb_type": "Unsupported super().__init__() call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo encountered a super().__init__() call on {objvar} that resolved to a `torch.nn.Module.__init__()` call that we cannot trace.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0163": [ + { + "Gb_type": "Unsupported tensor subclass attribute access", + "Context": "{name}", + "Explanation": "`torch.compile` currently can't trace this", + "Hints": [ + "Avoid accessing {name} of tensor subclass in torch.compile region", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0164": [ + { + "Gb_type": "Unsupported tensor subclass overridden attribute access", + "Context": "{name}", + "Explanation": "`torch.compile` only support tracing certain types of overridden tensor subclass attributes", + "Hints": [ + "Avoid accessing {name} of tensor subclass in torch.compile region", + "Renaming attribute `{name}` of type {self.class_type}", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0165": [ + { + "Gb_type": "Unsupported torch._C._ImperativeEngine method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo only supports the `queue_callback` method on a torch._C._ImperativeEngine instance, but found: `{name}`.", + "Hints": [] + } + ], + "GB0166": [ + { + "Gb_type": "Unsupported torch._C._ImperativeEngine.queue_callback()", + "Context": "call_method {self} {name}", + "Explanation": "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True.", + "Hints": [] + } + ], + "GB0167": [ + { + "Gb_type": "Variadic function call with bad args/kwargs type", + "Context": "args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", + "Explanation": "Expected args to be a list and kwargs to be a dict", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0168": [ + { + "Gb_type": "Variadic function call with bad flags", + "Context": "flags: {inst.argval}", + "Explanation": "Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0169": [ + { + "Gb_type": "Write to immutable cell", + "Context": "cellvar: {cellvar}, value: {value}", + "Explanation": "Dynamo doesn't support writing to immutable/sourceless cell variables.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0170": [ + { + "Gb_type": "Data-dependent branching", + "Context": "attempted to jump with {value}", + "Explanation": "_explanation", + "Hints": [ + "Use `torch.cond` to express dynamic control flow.", + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + }, + { + "Gb_type": "Data-dependent branching", + "Context": "attempted to jump with {value}", + "Explanation": "_explanation", + "Hints": [] + }, + { + "Gb_type": "_gb_type", + "Context": "attempted to jump with {value}", + "Explanation": "_explanation", + "Hints": [] + } + ], + "GB0171": [ + { + "Gb_type": "assert with non-string message", + "Context": "str(args)", + "Explanation": "Dynamo only supports asserts with string messages", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0172": [ + { + "Gb_type": "async_op=True for distributed collectives", + "Context": "{self.fn}, args={args}, kwargs={kwargs}", + "Explanation": "`torch.compile` doesn't support `async_op=True for {self.fn}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0173": [ + { + "Gb_type": "backward_state does not support export", + "Context": "", + "Explanation": "Compiled autograd doesn't work with `torch.export`.", + "Hints": [] + } + ], + "GB0174": [ + { + "Gb_type": "bad args to builtin cast()", + "Context": "got args {args} {kwargs}", + "Explanation": "Dynamo expects exactly 2 args to builtin cast().", + "Hints": [ + "Ensure your call to cast() has exactly 2 arguments." + ] + } + ], + "GB0175": [ + { + "Gb_type": "builtin isinstance() cannot determine type of argument", + "Context": "isinstance({arg}, {isinstance_type})", + "Explanation": "Dynamo doesn't have a rule to determine the type of argument {arg}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0176": [ + { + "Gb_type": "call_id() without associated real value", + "Context": "call_id {self}", + "Explanation": "Dynamo could not find an associated real value for the tensor.", + "Hints": [] + } + ], + "GB0177": [ + { + "Gb_type": "can't handle functions not implemented in python ", + "Context": "{fn}", + "Explanation": "Dynamo can only handle functions defined in python", + "Hints": [ + "Move usage of this function out of `torch.compile` region", + "Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used." + ] + } + ], + "GB0178": [ + { + "Gb_type": "constant fold exception", + "Context": "attempted to run function {fn} with arguments {args}", + "Explanation": "Encountered exception when attempting to constant fold.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0179": [ + { + "Gb_type": "copy.deepcopy()", + "Context": "copy.deepcopy({x})", + "Explanation": "Dynamo does not support copy.deepcopy()", + "Hints": [ + "Avoid calling copy.deepcopy()", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0180": [ + { + "Gb_type": "dataclass fields failure", + "Context": "obj: {obj}; variable type: {type(obj)}", + "Explanation": "Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.", + "Hints": [] + } + ], + "GB0181": [ + { + "Gb_type": "dtype mismatch between tensor and its gradient", + "Context": "tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", + "Explanation": "Inconsistent dtype between tensor and its gradient. This can happen in FSDP and crashes meta tensor creation.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0182": [ + { + "Gb_type": "failed to broadcast when attempting Tensor comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo was unable to broad cast the arguments {left}, {right} when attempting to trace the comparison op {op.__name__}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0183": [ + { + "Gb_type": "failed to call dict.fromkeys()", + "Context": "{user_cls.__name__}.fromkeys(): {args} {kwargs}", + "Explanation": "Failed to call {user_cls.__name__}.fromkeys() because arguments could not be automatically converted to a list, or some dict key is not hashable.", + "Hints": [ + "Manually convert the argument to a list.", + "Ensure all keys are hashable." + ] + } + ], + "GB0184": [ + { + "Gb_type": "failed to call str() on user defined object", + "Context": "str(arg)", + "Explanation": "User defined object has no __str__ or __repr__ method", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0185": [ + { + "Gb_type": "failed to convert numpy.ndarray to Tensor", + "Context": "str(value)", + "Explanation": "Exception encountered when attempting to convert numpy.ndarray to Tensor", + "Hints": [] + } + ], + "GB0186": [ + { + "Gb_type": "functools.partial() with non-literal keyword", + "Context": "non-literal keyword: {k}", + "Explanation": "functools.partial() expects literal/string keywords", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0187": [ + { + "Gb_type": "functools.wraps", + "Context": "{fn}", + "Explanation": "`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0188": [ + { + "Gb_type": "getattr with no source", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo does not know how to access an attribute on an `nn.Module` instance that lacks a source. This is usually an internal error in Dynamo.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0189": [ + { + "Gb_type": "getattr() on nn.Module with pending mutation", + "Context": "getattr({obj}, {name}, {default})", + "Explanation": "Intentionally graph breaking on getattr() on a nn.Module with a pending mutation", + "Hints": [] + } + ], + "GB0190": [ + { + "Gb_type": "getattr() with non-constant name argument", + "Context": "getattr({obj}, {name_var}, {default})", + "Explanation": "getattr() with non-constant name argument is not supported", + "Hints": [ + "Ensure the name argument of getattr() is a string" + ] + } + ], + "GB0191": [ + { + "Gb_type": "id() with unsupported args", + "Context": "str(args)", + "Explanation": "Dynamo doesn't know how to trace id() call with args {args}", + "Hints": [ + "Supported args are Tensors, and functions/nn.Modules/user-defined objects ", + "from outside the compiled region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0192": [ + { + "Gb_type": "input iterator to itertools.cycle has too many items", + "Context": "next({self})", + "Explanation": "Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}", + "Hints": [] + } + ], + "GB0193": [ + { + "Gb_type": "invalid call to builtin op handler", + "Context": "invalid args to {self_handler}: {args} {kwargs}", + "Explanation": "Encountered TypeError when trying to handle op {fn.__name__}", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0194": [ + { + "Gb_type": "isinstance() called on user defined object with C extensions", + "Context": "isinstance({arg}, {isinstance_type})", + "Explanation": "User-defined object with C extensions can have torch.Tensor attributes; intentionally graph breaking.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0195": [ + { + "Gb_type": "issubclass() with non-constant arguments", + "Context": "issubclass({left_ty}, {right_ty})", + "Explanation": "issubclass() with non-constant arguments not supported.", + "Hints": [ + "Make sure your arguments are types.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0196": [ + { + "Gb_type": "key not found in dict", + "Context": "Key {arg.value}", + "Explanation": "msg", + "Hints": [ + "Check if the key exists in the dictionary before accessing it.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0197": [ + { + "Gb_type": "list elements are pointing to the list itself", + "Context": "", + "Explanation": "Dynamo does not support lists whose items reference to itself", + "Hints": [ + "Avoid using self referential list" + ] + } + ], + "GB0198": [ + { + "Gb_type": "mapping proxy affected by dictionary mutation", + "Context": "Source: {self.source}, Dict mutation detected", + "Explanation": "msg", + "Hints": [ + "Avoid modifying dictionaries that might be referenced by mapping proxy objects", + "Or avoid using the mapping proxy objects after modifying its underlying dictionary" + ] + } + ], + "GB0199": [ + { + "Gb_type": "mapping proxy cannot be reconstructed", + "Context": "Source: {self.source}", + "Explanation": "msg", + "Hints": [ + "Use a mapping proxy constructed in the same `torch.compile` region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0200": [ + { + "Gb_type": "missing BUILD_SET handler", + "Context": "", + "Explanation": "Missing BUILD_SET bytecode handler (for testing purposes).", + "Hints": [] + } + ], + "GB0201": [ + { + "Gb_type": "namedtuple construction", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "`torch.compile` only support certain input types for namedtuple", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0202": [ + { + "Gb_type": "non-const argument in nn.Module method", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with non-constant arguments.", + "Hints": [] + } + ], + "GB0203": [ + { + "Gb_type": "non-const keys in dict_keys", + "Context": "non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", + "Explanation": "Dynamo expects dict_keys keys to be constants.", + "Hints": [ + "Ensure your dict_keys keys are constants (e.g. int, float, strings)" + ] + } + ], + "GB0204": [ + { + "Gb_type": "non-const keys in mappingproxy", + "Context": "non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", + "Explanation": "Dynamo expects mappingproxy keys to be constants.", + "Hints": [ + "Ensure your mappingproxy keys are constants (e.g. int, float, strings)" + ] + } + ], + "GB0205": [ + { + "Gb_type": "proxy not set", + "Context": "as_proxy {self}", + "Explanation": "Dynamo requires the autograd.Function context to be initialized with a proxy.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0206": [ + { + "Gb_type": "setattr() on Tensor.requires_grad", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "setattr() on Tensor.requires_grad not supported. Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in the middle of the graph, which AOTAutograd does not currently know how to handle.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0207": [ + { + "Gb_type": "sort with non-constant keys", + "Context": "str(first_non_constant_key)", + "Explanation": "Cannot perform sort with non-constant key. First non-constant key type: {python_type}. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.", + "Hints": [ + "Use something else as the key." + ] + } + ], + "GB0208": [ + { + "Gb_type": "torch.* op returned non-Tensor", + "Context": "example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", + "Explanation": "torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", + "Hints": [] + } + ], + "GB0209": [ + { + "Gb_type": "torch.autograd._unsafe_preserve_version_counter escaped from compiled region", + "Context": "str(self)", + "Explanation": "Dynamo doesn't support compiling a region that returns a torch.autograd._unsafe_preserve_version_counter context manager.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0210": [ + { + "Gb_type": "torch.distributed package is not available!", + "Context": "", + "Explanation": "The PyTorch package doesn't include torch.distributed when building from source.", + "Hints": [ + "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." + ] + } + ], + "GB0211": [ + { + "Gb_type": "torch.nn.Module with a non-function custom __getattr__", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo detected a nn.Module object with a custom `__getattr__` method, but this method is not a standard Python function (e.g., it might be implemented in C/C++). Dynamo cannot currently trace into such non-standard `__getattr__` methods.", + "Hints": [ + "Avoid using objects with non-standard __getattr__ methods ", + "within the compiled region. If possible, implement ", + "__getattr__ as a standard Python function.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0212": [ + { + "Gb_type": "torch.profiler object escaped from compiled region", + "Context": "str(self)", + "Explanation": "Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0213": [ + { + "Gb_type": "unimplemented builtin op on tensor arguments", + "Context": "partial tensor op: {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0214": [ + { + "Gb_type": "unsupported SymNode comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo does not support the comparison op {op.__name__} with SymNode arguments {left}, {right}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0215": [ + { + "Gb_type": "unsupported Tensor comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo does not support the comparison op {op.__name__} with Tensor arguments {left}, {right}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0216": [ + { + "Gb_type": "unsupported grid type for triton hop check_grid", + "Context": "grid type = {type(grid)}", + "Explanation": "`torch.compile` only supports list-like grid for check_grid", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0217": [ + { + "Gb_type": "unsupported hasattr operation", + "Context": "Class {self.user_cls}", + "Explanation": "msg", + "Hints": [ + "Consider using a regular dictionary instead", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0218": [ + { + "Gb_type": "unsupported index(Tensor)", + "Context": "", + "Explanation": "Dynamo does not support tracing builtin index() on a Tensor", + "Hints": [] + } + ], + "GB0219": [ + { + "Gb_type": "Backend compiler exception", + "Context": "Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}", + "Explanation": "Backend compiler `{name}` failed with {str(e)}. Adding a graph break.", + "Hints": [ + "Report an issue to the backend compiler repo." + ] + } + ], + "GB0220": [ + { + "Gb_type": "Failed to mutate tensor data attribute to different dtype", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "Dyanmo only supports mutating `.data` of tensor to a new one with the same dtype", + "Hints": [ + "Don't mutate `.data` on this tensor, or move ", + "the mutation out of `torch.compile` region" + ], + "Additional_Info": [ + "INFO" + ] + } + ], + "GB0221": [ + { + "Gb_type": "non-generator contextlib.contextmanager", + "Context": "str(self.vt.get_code())", + "Explanation": "Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator, i.e. does not use `yield`", + "Hints": [ + "Use `yield` in the function body instead of `return`.", + "Remove the `@contextlib.contextmanager` decorator." + ] + } + ], + "GB0222": [ + { + "Gb_type": "Attempted to wrap a set with tensors", + "Context": "Python set containing torch.Tensor elements", + "Explanation": "Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.", + "Hints": [ + "Use a dictionary where the keys are tensors.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ] +} \ No newline at end of file diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 52496638dd144d..a16a9f45a9b549 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -17,6 +17,7 @@ import io import logging import math +import operator import pickle from collections import defaultdict, deque from dataclasses import fields @@ -102,9 +103,11 @@ def dumps(self, obj: Any) -> bytes: self._stream.truncate(0) -def _extract_tensor_arg(arg: Any) -> Any: +def _extract_args(arg: Any) -> Any: if isinstance(arg, Node): return arg.meta.get("example_value") + elif isinstance(arg, (torch.Tensor, int)): + return arg else: return None @@ -113,11 +116,11 @@ def _normalize_args( node: Node, ) -> tuple[tuple[str, ...], tuple[Optional[Any], ...]]: flat_args, _ = tree_flatten(node.args) - sorted_kwargs = sorted(node.kwargs.items(), key=lambda x: x[0]) + sorted_kwargs = sorted(node.kwargs.items(), key=operator.itemgetter(0)) sorted_keys = tuple(sorted(node.kwargs.keys())) flat_kwargs, _ = tree_flatten(sorted_kwargs) all_args = flat_args + flat_kwargs - return (sorted_keys, tuple(_extract_tensor_arg(arg) for arg in all_args)) + return (sorted_keys, tuple(_extract_args(arg) for arg in all_args)) def get_global_state_key() -> GlobalStateKey: @@ -262,6 +265,16 @@ def track_node_mutations( if mutated_arg_positions: self.node_to_mutated_arg_positions[node] = mutated_arg_positions + def add_node_mutation( + self, + node: Node, + arg_pos: int, + ) -> None: + if node in self.node_to_mutated_arg_positions: + self.node_to_mutated_arg_positions[node].add(arg_pos) + else: + self.node_to_mutated_arg_positions[node] = OrderedSet([arg_pos]) + def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: """ This function is responsible for extracting the largest regions of identical nodes from the given graph. @@ -419,7 +432,7 @@ def fully_expand_region_group( if add_to_all_regions: assert len(region_wrappers) == len(nodes_to_add), ( - "Numer of nodes to add must equal the number of regions" + "Number of nodes to add must equal the number of regions" ) for region_wrapper, node in zip(region_wrappers, nodes_to_add): region_wrapper.add(node) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 575f61c87f80be..bf790ba126adf2 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -106,6 +106,7 @@ ConstDictKeySource, DefaultsSource, DictGetItemSource, + DictSubclassGetItemSource, FlattenScriptObjectSource, FloatTensorSource, FSDPNNModuleSource, @@ -118,6 +119,7 @@ ListGetItemSource, LocalSource, NNModuleSource, + NonSerializableSetGetItemSource, NumpyTensorSource, OptimizerSource, ScriptObjectQualifiedNameSource, @@ -384,7 +386,7 @@ def check_verbose(self, x): def populate_code_parts_for_debugging(self): # This should be called when the guard manager is fully populated - tensor_aliasing_guard_seen = False + relational_guards_seen = set() def get_code_parts(leaf_guard): code_parts = [] @@ -394,12 +396,12 @@ def get_code_parts(leaf_guard): return code_parts def visit(mgr): - nonlocal tensor_aliasing_guard_seen + nonlocal relational_guards_seen for guard in mgr.get_leaf_guards(): - if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): # type: ignore[attr-defined] - if not tensor_aliasing_guard_seen: + if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if guard not in relational_guards_seen: self.code_parts.extend(get_code_parts(guard)) - tensor_aliasing_guard_seen = True + relational_guards_seen.add(guard) else: self.code_parts.extend(get_code_parts(guard)) @@ -418,7 +420,7 @@ def from_numpy(a): # For user stack printing -@functools.lru_cache(None) +@functools.cache def uninteresting_files(): import torch._dynamo.external_utils import torch._dynamo.polyfills @@ -496,6 +498,8 @@ def get_verbose_code_parts( def convert_int_to_concrete_values(dim) -> Optional[int]: + if dim is None: + return None if not is_symbolic(dim): return dim else: @@ -552,7 +556,7 @@ class NNModuleAttrAccessorInfo: # Either the actual name or _parameters/_buffers/_modules l1_key: Optional[str] = None - # Actual paramter/buffer/submodule name + # Actual parameter/buffer/submodule name l2_key: Optional[str] = None @@ -600,7 +604,7 @@ def getitem_on_dict_manager( def match_on_id_for_tensor(guard): source = guard.originating_source # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads - # to a new tensor everytime and therefore id differs. + # to a new tensor every time and therefore id differs. if isinstance(source, NumpyTensorSource): return False @@ -623,7 +627,7 @@ class GuardManagerType(enum.Enum): DICT_GUARD_MANAGER = 2 -@functools.lru_cache(None) +@functools.cache def code_framelocals_names_reversed_cached(code: types.CodeType): return list(reversed(code_framelocals_names(code))) @@ -640,12 +644,14 @@ def __init__( guard_manager: GuardManagerWrapper, check_fn_manager: CheckFunctionManager, serialization_mode: Optional[str] = None, + runtime_global_scope: Optional[dict[str, Any]] = None, ): self.f_code = f_code self.id_ref = id_ref self.source_ref = source_ref self.lookup_weakrefs = lookup_weakrefs self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope} + self.runtime_global_scope = runtime_global_scope or global_scope self.scope["__builtins__"] = builtins.__dict__.copy() for ( name, @@ -940,6 +946,11 @@ def get_guard_manager_type(self, source, example_value): # Fix this if condition if isinstance(example_value, dict_keys): guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER + elif isinstance(example_value, (set, frozenset)): + # we don't need to guard on key order for set/frozenset + # but the if above will be true for these types as set is + # implemented using a dict in Dynamo + guard_manager_enum = GuardManagerType.GUARD_MANAGER else: assert isinstance(example_value, dict) guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER @@ -950,7 +961,7 @@ def manager_guards_on_keys(self, mgr_enum): def get_global_guard_manager(self): return self.guard_manager.root.globals_dict_manager( - f_globals=self.scope["G"], + f_globals=self.runtime_global_scope, source="G", example_value=self.scope["G"], guard_manager_enum=GuardManagerType.GUARD_MANAGER, @@ -1095,7 +1106,7 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) - elif istype(source, DictGetItemSource): + elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): assert base_guard_manager # to make mypy happy assert isinstance(base_example_value, (dict, collections.OrderedDict)) if isinstance(base_guard_manager, DictGuardManager): @@ -1284,6 +1295,14 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, NonSerializableSetGetItemSource): + assert base_guard_manager + out = base_guard_manager.set_getitem_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, WeakRefCallSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.weakref_call_manager( @@ -1456,7 +1475,11 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` value = self.get(guard.name) - t = type(value) + if isinstance(value, torch._subclasses.FakeTensor) and value.pytype: + t = value.pytype + else: + t = type(value) + if self.serialization_mode == "save": if t.__qualname__ != t.__name__: raise_local_type_error(value) @@ -1498,6 +1521,19 @@ def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): not invert, key, get_verbose_code_parts(code, guard) ) + def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): + set_ref = self.arg_ref(guard) + item = key + contains = not invert # install_dict_contains_guard inverts "contains" + + code = f"set.__contains__({set_ref}, {item!r})" + + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_set_contains_guard( + contains, item, get_verbose_code_parts(code, guard) + ) + def BOOL_MATCH(self, guard: Guard): # checks val == True or val == False ref = self.arg_ref(guard) @@ -1530,6 +1566,9 @@ def NONE_MATCH(self, guard: Guard): def ID_MATCH(self, guard: Guard): if self.serialization_mode == "save": raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.") + return self.id_match_unchecked(guard) + + def id_match_unchecked(self, guard: Guard): # ___check_obj_id is same as `id(x) == y` if isinstance(guard.originating_source, TypeSource): # optional optimization to produce cleaner/faster guard code @@ -1665,7 +1704,7 @@ def metadata_checker(x): metadata_checker, get_verbose_code_parts(global_name, guard) ) - def EQUALS_MATCH(self, guard: Guard): + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): ref = self.arg_ref(guard) val = self.get(guard.name) if np: @@ -1762,9 +1801,14 @@ def EQUALS_MATCH(self, guard: Guard): # is immutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the # pointer equality check. val = deepcopy(val) - self.get_guard_manager(guard).add_equals_match_guard( - val, get_verbose_code_parts(code, guard) - ) + + verbose_code_parts = get_verbose_code_parts(code, guard) + if recompile_hint: + verbose_code_parts = [ + f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts + ] + + self.get_guard_manager(guard).add_equals_match_guard(val, verbose_code_parts) self._set_guard_export_info(guard, code) return @@ -1826,7 +1870,15 @@ def CLOSURE_MATCH(self, guard: Guard): self.FUNCTION_MATCH(guard) def BUILTIN_MATCH(self, guard: Guard): - return self.FUNCTION_MATCH(guard) + if self.serialization_mode == "save": + # Record which builtin variables are used for pruning later. + if isinstance(guard.originating_source, DictGetItemSource): + self.check_fn_manager.used_builtin_vars.add( + guard.originating_source.index + ) + return self.id_match_unchecked(guard) + + return self.ID_MATCH(guard) def SEQUENCE_LENGTH(self, guard): # This guard is used to check length of PySequence objects like list, @@ -1891,9 +1943,9 @@ def RANGE_ITERATOR_MATCH(self, guard): # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards def DUPLICATE_INPUT(self, guard, source_b): if self.serialization_mode == "save": - raise torch._dynamo.exc.PackageError( - "DUPLICATE_INPUT guard cannot be serialized yet." - ) + if name := get_local_source_name(source_b): + self.check_fn_manager.additional_used_local_vars.add(name) + ref_a = self.arg_ref(guard) ref_b = self.arg_ref(source_b.name()) @@ -2100,11 +2152,17 @@ def _get_code_parts(langs): assert maybe_cpp_code_parts is None or isinstance( maybe_cpp_code_parts, _CppShapeGuardsHelper ) + maybe_shape_env_sources = ( + [] + if maybe_cpp_code_parts is None + else list(maybe_cpp_code_parts.source_to_symbol.keys()) + ) self.check_fn_manager.shape_code_parts = ShapeCodeParts( python_code_parts=python_code_parts, verbose_code_parts=verbose_code_parts, cpp_code_parts=maybe_cpp_code_parts, python_fallback=python_fallback, + shape_env_sources=maybe_shape_env_sources, ) for code in python_code_parts.exprs: @@ -2178,6 +2236,7 @@ def _get_code_parts(langs): func_str = textwrap.dedent( f""" + #include #include #include #include @@ -2576,6 +2635,7 @@ class ShapeCodeParts: verbose_code_parts: _ShapeGuardsHelper cpp_code_parts: Optional[_CppShapeGuardsHelper] python_fallback: bool + shape_env_sources: list[Source] @dataclasses.dataclass @@ -2669,7 +2729,7 @@ def reducer_override(self, obj): ) return type(self)._unpickle_tensor, ( - torch.empty_like(obj, device="meta"), + torch.empty_like(obj, device="meta", requires_grad=obj.requires_grad), obj.device, type(obj), torch._C._dispatch_keys(obj).raw_repr(), @@ -2741,6 +2801,7 @@ def __init__( ] = None, guards_serialization_mode: Optional[str] = None, shape_code_parts: Optional[ShapeCodeParts] = None, + runtime_global_scope: Optional[dict[str, Any]] = None, ): guards = output_graph.guards if output_graph else None self._weakrefs: dict[int, ReferenceType[object]] = {} @@ -2759,6 +2820,11 @@ def __init__( output_graph.torch_function_mode_stack if output_graph else None ) self.guards_serialization_mode = guards_serialization_mode + self.used_builtin_vars: OrderedSet[str] = OrderedSet() + self.additional_used_local_vars: OrderedSet[str] = OrderedSet() + if runtime_global_scope: + assert self.guards_serialization_mode == "load" + self.runtime_global_scope = runtime_global_scope if not justknobs_check("pytorch/compiler:guard_nn_modules"): log.warning("guard_nn_modules is turned off using justknobs killswitch") @@ -2851,12 +2917,14 @@ def make_guard_filter_entry(guard): self.guard_manager, output_graph.local_scope ) - # NB for developers: n_iters is chosen to be 50 to achieve - # statistical significance. If you are working on a guard - # optimization, it might be a good idea to increase this number for - # more stabiilty during development. + # NB for developers: n_iters is chosen to be 1 to prevent excessive + # increase in compile time. We first do a cache flush to measure the + # guard latency more accurately. This cache flush is expensive. + # Note - If you are working on a guard optimization, it might be a + # good idea to increase this number for more stabiilty during + # development. latency = profile_guard_manager( - self.guard_manager.root, output_graph.local_scope, 50 + self.guard_manager.root, output_graph.local_scope, 1 ) guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}") # Note: We use `increment_toplevel` instead of `compilation_metric` @@ -2871,6 +2939,7 @@ def make_guard_filter_entry(guard): CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) self.guards_state: Optional[bytes] = None + builtins_dict_name = self.output_graph.name_of_builtins_dict_key_in_fglobals if self.guards_serialization_mode == "save": used_global_vars = set() used_local_vars = set() @@ -2878,7 +2947,11 @@ def make_guard_filter_entry(guard): def prune_variable(source): if name := get_global_source_name(source): assert isinstance(name, str) - used_global_vars.add(name) + # Leave out the builtins dict key, as we will special handle + # it later because the guarded code rarely use the entire + # builtin dict in the common case. + if name not in (builtins_dict_name,): + used_global_vars.add(name) elif name := get_local_source_name(source): assert isinstance(name, str) used_local_vars.add(name) @@ -2886,7 +2959,13 @@ def prune_variable(source): output_graph_guards_state = self.output_graph.dump_guards_state() # Only serialize the global variables that are actually used in guards. for guard in sorted_guards: - prune_variable(guard.originating_source) + if isinstance(guard.originating_source, ShapeEnvSource): + assert self.shape_code_parts + for source in self.shape_code_parts.shape_env_sources: + prune_variable(source) + else: + prune_variable(guard.originating_source) + for source in self.output_graph.guard_on_key_order: prune_variable(source) @@ -2904,18 +2983,26 @@ def _ref(x): return x + global_scope_state = { + k: v + for k, v in output_graph_guards_state.global_scope.items() + if k in used_global_vars + } + global_scope_state[builtins_dict_name] = { + k: v + for k, v in output_graph_guards_state.global_scope[ + builtins_dict_name + ].items() + if k in self.used_builtin_vars + } output_graph_guards_state = dataclasses.replace( output_graph_guards_state, local_scope={ k: v for k, v in output_graph_guards_state.local_scope.items() - if k in used_local_vars - }, - global_scope={ - k: v - for k, v in output_graph_guards_state.global_scope.items() - if k in used_global_vars + if k in used_local_vars or k in self.additional_used_local_vars }, + global_scope=global_scope_state, _guards=torch._guards.GuardsSet( { dataclasses.replace( @@ -2987,6 +3074,7 @@ def source_ref(source): guard_manager, self, serialization_mode, + runtime_global_scope=self.runtime_global_scope, ) # Break retain cycle. See test_release_scope_memory @@ -3033,7 +3121,11 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) # Insert the global_state guard - self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) + assert self.output_graph is not None + global_state = self.output_graph.global_state_guard + self.guard_manager.root.add_global_state_guard( + global_state, ["___check_global_state()"] + ) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, @@ -3160,8 +3252,7 @@ def add_code_part(code_part, guard, log_only=False): "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns] ) - global_state = convert_frame.initial_global_state - if global_state is None: + if convert_frame.initial_global_state is None: # we should only hit this case in NopTests() global_state = convert_frame.GlobalStateGuard() closure_vars = { @@ -3560,7 +3651,7 @@ def make_dupe_guard(obj_source, dupe_source): dupe_source ) or is_from_flatten_script_object_source(obj_source): raise exc.UnsafeScriptObjectError( - f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported." + f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported." f" Please do a clone for corresponding input." ) diff --git a/torch/_dynamo/hooks.py b/torch/_dynamo/hooks.py index c362f04ebe7060..e180ad6dedf041 100644 --- a/torch/_dynamo/hooks.py +++ b/torch/_dynamo/hooks.py @@ -6,13 +6,10 @@ The Hooks class manages two types of hook functions: - guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input - guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input -- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph. - This hook will be passed the set of filenames containing traced code These hooks enable customization of guard export and failure handling behaviors. """ import dataclasses -from types import CodeType from typing import Callable, Optional from torch._guards import GuardsSet @@ -25,4 +22,3 @@ class Hooks: guard_export_fn: Optional[Callable[[GuardsSet], None]] = None guard_fail_fn: Optional[Callable[[GuardFail], None]] = None guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None - frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 2d67665f5e9d7a..18febf1377cc34 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -33,7 +33,7 @@ def get_loggers() -> list[logging.Logger]: # get_step_logger should be lazily called (i.e. at runtime, not at module-load time) # so that step numbers are initialized properly. e.g.: -# @functools.lru_cache(None) +# @functools.cache # def _step_logger(): # return get_step_logger(logging.getLogger(...)) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 374f3f9cdbe5df..61a1447fdd8da4 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -43,13 +43,15 @@ import torch.distributed as dist import torch.nn import torch.utils._pytree as pytree -from torch import fx +from torch import fx, Tensor +from torch._C._dynamo import guards from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis from torch._guards import ( CompileContext, CompileId, GlobalContextCheckpointState, Source, + tracing, TracingContext, ) from torch._subclasses.fake_tensor import FakeTensor @@ -61,6 +63,7 @@ guard_scalar, is_symbolic, ShapeEnv, + Specialization, ) from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch.multiprocessing.reductions import StorageWeakRef @@ -122,6 +125,7 @@ get_unique_name_wrt, graph_break_reasons, increment_op_count, + istype, lazy_format_graph_code, LazyString, nn_module_proxy, @@ -158,6 +162,8 @@ graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") +RootGuardManager = guards.RootGuardManager + @dataclass(frozen=True) class VariableTrackerCacheKey: @@ -204,7 +210,7 @@ def clear(self): self.cache.clear() -@functools.lru_cache(None) +@functools.cache def _step_logger(): return torchdynamo_logging.get_step_logger(log) @@ -304,6 +310,8 @@ class OutputGraphGuardsState: dual_level: int functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter] current_device: Optional[torch.device] + global_state_guard: torch._C._dynamo.guards.GlobalStateGuard + name_of_builtins_dict_key_in_fglobals: Optional[str] = None export: bool = False export_constraints: bool = False @@ -337,6 +345,26 @@ class StackLocalsMetadata: locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list) +def get_builtins_dict(global_scope): + # f_globals["__builtins__"] can be a dict or a module. This is an + # implemenation detail - + # https://docs.python.org/3/library/builtins.html. + + # This makes guarding on any builtin messy because the guard check_fn + # has to check if the __builtins__ is a module or dict, and then access + # by either using getattr or getitem respectively. + + # To solve this problem, we insert a new entry in f_globals which points + # to the builtins __dict__ and then we guard any builtin on this dict. + # To avoid any collision with the pre-existing keys, we use the + # install_global to give us a unique dict key. + + f_builtins = global_scope["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + return f_builtins + + class OutputGraph(OutputGraphGuardsState): """ Wrapper class to hold outputs of InstructionTranslator. Mainly the @@ -362,6 +390,7 @@ def __init__( global_scope: Scope, f_code, torch_function_mode_stack, + package, ): super().__init__( local_scope, @@ -372,6 +401,9 @@ def __init__( dual_level=torch.autograd.forward_ad._current_level, functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(), current_device=torch.utils._device.CURRENT_DEVICE, + # initial_global_state is only None during NopTest. + global_state_guard=torch._dynamo.convert_frame.initial_global_state + or torch._C._dynamo.guards.GlobalStateGuard(), ) self.tracers = [SubgraphTracer(self, is_export=export)] # Map from graph input's `Source` to its `VariableTracker` to @@ -386,8 +418,8 @@ def __init__( # Set of globals installed via install_global* APIs self.installed_globals: set[str] = set() - self.f_code = f_code - # TODO: maybe should only store the entire f_code + # TODO: maybe should just pass the entire f_code in here? Not + # sure... self.co_fields = { "co_name": f_code.co_name, "co_filename": f_code.co_filename, @@ -431,6 +463,7 @@ def __init__( export=self.export, ) self.tracing_context: TracingContext = TracingContext(fake_mode) + self.tracing_context.traced_code.append(f_code) self.dynamo_compile_id: Optional[CompileId] = ( CompileContext.current_compile_id() ) @@ -464,6 +497,7 @@ def __init__( self.compiler_fn: Optional[CompilerFn] = compiler_fn self.root_tx = root_tx + self.package = package # Given a source, what are the user stacks of all locations that # accessed it? # @@ -553,22 +587,7 @@ def mark_bytecode_tracing_stop(self): self.compiler_trace_stack.close() def install_builtins_dict_in_fglobals(self): - # f_globals["__builtins__"] can be a dict or a module. This is an - # implemenation detail - - # https://docs.python.org/3/library/builtins.html. - - # This makes guarding on any builtin messy because the guard check_fn - # has to check if the __builtins__ is a module or dict, and then access - # by either using getattr or getitem respectively. - - # To solve this problem, we insert a new entry in f_globals which points - # to the builtins __dict__ and then we guard any builtin on this dict. - # To avoid any collision with the pre-existing keys, we use the - # install_global to give us a unique dict key. - - f_builtins = self.global_scope["__builtins__"] - if not isinstance(f_builtins, dict): - f_builtins = f_builtins.__dict__ + f_builtins = get_builtins_dict(self.global_scope) return self.install_global("__builtins_dict__", f_builtins) def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): @@ -666,6 +685,8 @@ def dump_guards_state(self): dual_level=self.dual_level, functorch_layers=self.functorch_layers, current_device=self.current_device, + global_state_guard=self.global_state_guard, + name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals, export=self.export, export_constraints=self.export_constraints, _guards=self.guards, @@ -1104,7 +1125,7 @@ def handle_aliases_for_stolen_lists(self, tx): # A small codegen optimization because we might have different # VariableTrackers that share the same source. - list_idx = x.source.index + list_idx = x.source.index # type: ignore[attr-defined] if list_idx not in visited: alias_name = self.new_var( f"{list_name}_ref" @@ -1389,12 +1410,20 @@ def compile_subgraph( ) self.codegen_suffix(tx, stack_values_flat, pass1) - # one more time now that we have established tempvars + # Use `pass1.uses` to selectively cache multi-user variables into a + # temporary local source. This (a). speeds up loading VTs with long + # chained source, and (b). avoids redundantly saving single-user VT + # into a temporary local. + tempvars = {} # type: ignore[var-annotated] + for val, count in pass1.uses.items(): + # If it's already a local source, no need to cache it + if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)): + tempvars[val] = None pass2 = PyCodegen( self.root_tx, root, graph_output_var, - tempvars={val: None for val, count in pass1.uses.items() if count > 1}, + tempvars=tempvars, overridden_sources=overridden_sources, ) self.codegen_suffix(tx, stack_values_flat, pass2) @@ -1640,6 +1669,7 @@ def compile_and_call_fx_graph(self, tx, rv, root): for register_finalizer in self.register_finalizer_fns: register_finalizer(gm) + gm._backend_id = name gm.compile_subgraph_reason = self.compile_subgraph_reason gm.meta["dynamo_flat_name_to_original_fqn"] = ( self.dynamo_flat_name_to_original_fqn.copy() @@ -1675,7 +1705,7 @@ def compile_and_call_fx_graph(self, tx, rv, root): self.tracing_context.fake_mode = backend_fake_mode with self.restore_global_state(): - compiled_fn = self.call_user_compiler(gm) + compiled_fn = self.call_user_compiler(gm, self.example_inputs()) from torch.fx._lazy_graph_module import _LazyGraphModule @@ -1700,13 +1730,74 @@ def compile_and_call_fx_graph(self, tx, rv, root): # replace compiled_fn with the real forward method compiled_fn = lazy_gm.forward + if self.package is not None: + self.package.add_backend_id(name, compiled_fn) + compiled_fn = disable( compiled_fn, reason="do not trace Dynamo-compiled graph" ) counters["stats"]["unique_graphs"] += 1 - # This is safe because we pre-process name to be unique - self.install_global_unsafe(name, compiled_fn) + if specializations := old_fake_mode.shape_env.specializations: + specialization_guards = [] + specialization_cache: dict[Specialization, Callable[[Any], Any]] = {} + sources = [a.source for a in self.graphargs] + for specialization in specializations: + source_index = sources.index(specialization.source) + check_fn_source = inspect.getsource(specialization.check_fn).strip() + # Required because the LABDA_GUARD API requires a root guard manager + unused_root_guard_manager = RootGuardManager() + check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined] + unused_root_guard_manager, + specialization.check_fn, + [check_fn_source], + ) + + log.debug( + "Compiling backend specialized graph with specialization=%s", + check_fn_source, + ) + + specialization_guards.append( + ( + functools.partial( + lambda idx, args, check_fn=check_fn: check_fn( + args[idx] + ), + source_index, + ), + specialization, + ) + ) + + @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") + def specialized_dispatch(*args, **kwargs): + for check_fn, specialization in specialization_guards: + if check_fn(args): + if specialization in specialization_cache: + return specialization_cache[specialization]( + *args, **kwargs + ) + + with self.shape_env.patch_source_specialization( + specialization.source, specialization.check_fn + ): + # Modify gm so AOTAutogradCache key changes per specialization + gm.meta["specialization"] = specialization + example_inputs: list[Tensor] = list(args) + with tracing(self.tracing_context): + specialization_cache[specialization] = ( + self.call_user_compiler(gm, example_inputs) + ) + + return specialization_cache[specialization](*args, **kwargs) + return compiled_fn(*args, **kwargs) + + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, specialized_dispatch) + else: + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, compiled_fn) assert self.root_tx is not None cg = PyCodegen(self.root_tx) @@ -1721,7 +1812,9 @@ def placeholders(self) -> list[fx.Node]: def graphargs(self) -> list[GraphArg]: return [node.meta["grapharg"] for node in self.placeholders] - def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: + def call_user_compiler( + self, gm: fx.GraphModule, example_inputs: list[Tensor] + ) -> CompiledFn: with dynamo_timed( "OutputGraph.call_user_compiler", phase_name="backend_compile", @@ -1730,9 +1823,11 @@ def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: waitcounter_name_override="compile_aot_autograd", dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", ): - return self._call_user_compiler(gm) + return self._call_user_compiler(gm, example_inputs) - def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: + def _call_user_compiler( + self, gm: fx.GraphModule, example_inputs: list[Tensor] + ) -> CompiledFn: assert self.compiler_fn is not None tot = 0 placeholders = [] @@ -1743,10 +1838,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: placeholders.append(node) increment_op_count(tot) for pl in placeholders: - arg = pl.meta["grapharg"] - # TODO: Why isn't this stored in meta :think: - # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640 - pl._dynamo_source = arg.source + if not hasattr(pl, "_dynamo_source"): + arg = pl.meta["grapharg"] + # TODO: Why isn't this stored in meta :think: + # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640 + pl._dynamo_source = arg.source # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640 gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] @@ -1762,7 +1858,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: compiler_fn = self.compiler_fn if config.verify_correctness: compiler_fn = WrapperBackend(compiler_fn) - compiled_fn = compiler_fn(gm, self.example_inputs()) + compiled_fn = compiler_fn(gm, example_inputs) _step_logger()(logging.INFO, f"done compiler function {name}") assert callable(compiled_fn), "compiler_fn did not return callable" except (TensorifyScalarRestartAnalysis, ShortenTraceback): @@ -2238,11 +2334,16 @@ def __init__(self, output_graph, parent=None, is_export=False, source_target=Non # True if this tracer is currently tracing into torch.utils.checkpoint # as part of speculate_subgraph. self.under_activation_checkpoint = False - # True if we want to allow side-effects (doesn't throw error on their existence) + # True if we want to allow externally visible side-effects (doesn't throw error on their existence) # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph). # Only safe if we know for sure that *NOT* replaying these side-effects during # backward recomputation of the checkpoint region doesn't affect its correctness. self.allow_side_effects_under_checkpoint = False + # True if we want to allow externally visible side-effects (doesn't throw error on their existence) + # during this tracer's tracing. This is currently only used by experimental AC out-of-tree + # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer. + # Note: Externally visible side-effects are allowed if this flag OR the above flag is True. + self.unsafe_allow_externally_visible_side_effects = False # True if this tracer is currently tracing (reconstructing) into a Python generator self.is_reconstructing_generator = False @@ -2651,7 +2752,7 @@ def lift_tracked_freevar_to_input(self, proxy): ): return self.bound_symbols[example_value.node.expr] - # Proxys are associated with VariableTracker. + # Proxies are associated with VariableTracker. # It is possible that we've already lifted the Proxy to be an input. # If that is the case, just return the already lifted Proxy. if proxy in self.lifted_freevars: @@ -2705,7 +2806,7 @@ def track_unbacked_symbols( self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy] ): # When binding the symbols in an exmaple_value, we bind the symbols - # to the proxy's associatied Tracer instead of current tracer. + # to the proxy's associated Tracer instead of current tracer. # This is because: # 1. We may be calling wrap_tensors during speculate_subgraph because # the variables are lazily realized. The proxy are top-level phs but diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py new file mode 100644 index 00000000000000..61d4b2821c36d9 --- /dev/null +++ b/torch/_dynamo/package.py @@ -0,0 +1,627 @@ +""" +This module provides the infrastructure for creating and managing compile package +for torch.compile. We mainly have two abstractions here: + - CompilePackage: Overarching data structure for store and lookup a list of compiled codes. + - CodeCacheEntry: Data structure for a single code being compiled by torch.compile. +The caching behavior is always under user control explicitly so that a stronger guarantee can +be provided about cache hit for a specific compiled model. Users can load the compile package +from a different process or host. +""" + +import abc +import contextlib +import dataclasses +import functools +import hashlib +import importlib +import inspect +import logging +import os +import pickle +import platform +import sys +import types +from collections.abc import Generator +from typing import Any, NewType, Optional + +import torch +import torch._inductor.package +from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext +from torch.compiler._cache import CacheArtifactFactory + +from .bytecode_transformation import get_code_keys + + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class SerializedCode: + co_argcount: int + co_posonlyargcount: int + co_kwonlyargcount: int + co_nlocals: int + co_stacksize: int + co_flags: int + co_code: bytes + co_consts: tuple[Any, ...] + co_names: tuple[str, ...] + co_varnames: tuple[str, ...] + co_filename: str + co_name: str + co_firstlineno: int + co_cellvars: tuple[str, ...] + co_freevars: tuple[str, ...] + co_linetable: Optional[bytes] = None + co_qualname: Optional[str] = None + co_exceptiontable: Optional[bytes] = None + co_lnotab: Optional[str] = None + + @classmethod + @functools.cache + def from_code_object(cls, code: types.CodeType) -> "SerializedCode": + kwargs = {key: getattr(code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.from_code_object(c) if isinstance(c, types.CodeType) else c + for c in kwargs["co_consts"] + ) + return cls(**kwargs) + + @classmethod + @functools.cache + def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType: + kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.to_code_object(c) if isinstance(c, SerializedCode) else c + for c in kwargs["co_consts"] + ) + return types.CodeType( + *kwargs.values(), + ) + + +@dataclasses.dataclass +class _GuardedCodeCacheEntry: + """ + Contains the serializable information associated with a single compilation in dynamo. + To restore an execution of compiled code, we will need to serialize the following data: + - Dynamo bytecode for mapping Python inputs/outputs. + - Dynamo guards. + """ + + guards_state: bytes + dynamo_code: SerializedCode + + +_BackendId = NewType("_BackendId", str) # __compiled_fn +_FunctionId = NewType("_FunctionId", str) # __resume_at + + +@dataclasses.dataclass(frozen=True) +class InlinedSource: + module: str + firstlineno: int + lastlineno: int + checksum: str + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry: + """ + Contains the serializable information associated with a single code object + in dynamo. To restore an execution of compiled code, we will need the following + ingredients: + 1. The "original" code object, which serves as the entry point for eager + execution, i.e. the code only executed when there's no cache entry hit. + 2. The python module name this code object belongs to, for identifying the + enclosing global scope to inject compiled and resume functions. + 3. A list of function names that pointing to this code object. There could be + multiple function objects pointing to the same code such as recursive functions. + 4. A list of guarded code that eval frame dispatches to. + 5. A list of imported module objects unioned from all compiled branches. + 6. A list of "backends" (compiled fx graph) unioned from all compield branches. + """ + + python_code: SerializedCode + python_module: str + function_names: list[_FunctionId] + guarded_codes: list[_GuardedCodeCacheEntry] + import_sources: dict[str, str] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCacheEntry: + codes: list[_DynamoCodeCacheEntry] + inlined_sources: set[InlinedSource] + python_version: str = platform.python_version() + torch_version: str = torch.__version__ + + @property + def backend_ids(self) -> set[_BackendId]: + return {backend_id for code in self.codes for backend_id in code.backend_ids} + + +@CacheArtifactFactory.register +class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]): + @staticmethod + def type() -> str: + return "precompile_dynamo" + + def after_deserialization(self) -> _DynamoCacheEntry: + return pickle.loads(self.content) + + +def _hash_source(source: str) -> str: + sha256_hash = hashlib.sha256() + sha256_hash.update(source.encode()) + return sha256_hash.hexdigest() + + +def _get_sourcelines( + m: types.ModuleType, firstlineno: int, lastlineno: int +) -> list[str]: + return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1] + + +def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str: + return _hash_source("".join(_get_sourcelines(m, firstlineno, lastlineno))) + + +class CompilePackage: + """ + CompilePackage is considered a low level component and should not be directly exposed to + end users. It has the following interface: + + 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. + a. when `dynamo` argument is None, it will construct a brand new CompilePackage object. + b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. + 2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object. + 3. `package.install(backends) which will handle all the side-effectful global scope + updates with compiled functions and resume functions. + """ + + def __init__( + self, + fn: Any, + dynamo: Optional[_DynamoCacheEntry] = None, + ignore_inlined_sources: bool = False, + ) -> None: + self._innermost_fn = None + self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} + + self._current_entry: Optional[_DynamoCodeCacheEntry] = None + self._installed_globals: dict[types.ModuleType, list[str]] = {} + + # For debugging/testing purpose only. + self._cached_backends: dict[_BackendId, Any] = {} + self._inlined_sources: set[InlinedSource] = set() + self._resume_codes: set[types.CodeType] = set() + + self._initialize(fn, dynamo, ignore_inlined_sources) + self.uninstall() + self.validate() + + def _initialize( + self, + fn: Any, + dynamo: Optional[_DynamoCacheEntry] = None, + ignore_inlined_sources: bool = False, + ) -> None: + from .eval_frame import innermost_fn + + self._inlined_sources = set() + self._innermost_fn = innermost_fn(fn) + assert self._innermost_fn is not None + if dynamo is not None: + assert isinstance(dynamo, _DynamoCacheEntry) + if dynamo.python_version != platform.python_version(): + raise RuntimeError( + f"Compile package was created with a different Python version: {dynamo.python_version}" + ) + if dynamo.torch_version != torch.__version__: + raise RuntimeError( + f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" + ) + if not ignore_inlined_sources: + for code in dynamo.inlined_sources: + m = importlib.import_module(code.module) + checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno) + if checksum != code.checksum: + raise RuntimeError( + f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})" + ) + + self._inlined_sources = dynamo.inlined_sources + + main, *codes = dynamo.codes + self._codes = {self._innermost_fn.__code__: main} + for code in codes: + self._codes[SerializedCode.to_code_object(code.python_code)] = code + else: + self._add_function( + self._innermost_fn.__code__, self._innermost_fn.__module__ + ) + + def _add_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[_FunctionId] = None, + ) -> None: + if python_code not in self._codes: + code = _DynamoCodeCacheEntry( + python_code=SerializedCode.from_code_object(python_code), + python_module=python_module, + function_names=[], + guarded_codes=[], + import_sources={}, + backend_ids=[], + ) + self._codes[python_code] = code + else: + code = self._codes[python_code] + assert code.python_module == python_module + + if name is not None: + code.function_names.append(name) + + @property + def cached_backends(self) -> dict[_BackendId, Any]: + return self._cached_backends + + @functools.cached_property + def source_id(self) -> str: + assert self._innermost_fn is not None + sha256_hash = hashlib.sha256() + sha256_hash.update(self._innermost_fn.__qualname__.encode()) + sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + @contextlib.contextmanager + def code_context(self, code: types.CodeType) -> Generator[None, None, None]: + assert self._current_entry is None + + entry = self._codes[code] + self._current_entry = entry + try: + yield + finally: + self._current_entry = None + + def add_guarded_code( + self, + guards_state: bytes, + dynamo_code: types.CodeType, + ) -> None: + assert self._current_entry is not None + guarded_code_entry = _GuardedCodeCacheEntry( + guards_state=guards_state, + dynamo_code=SerializedCode.from_code_object(dynamo_code), + ) + self._current_entry.guarded_codes.append(guarded_code_entry) + + def add_inlined_source(self, sources: list[types.CodeType]) -> None: + for code in sources: + if code in self._resume_codes: + continue + module = inspect.getmodule(code) + if module is None: + continue + source = inspect.getsource(code) + lastlineno = code.co_firstlineno + len(inspect.getsourcelines(code)[0]) + assert source == "".join( + _get_sourcelines(module, code.co_firstlineno, lastlineno) + ) + self._inlined_sources.add( + InlinedSource( + module=module.__name__, + firstlineno=code.co_firstlineno, + lastlineno=lastlineno, + checksum=_hash_source(source), + ) + ) + + def add_resume_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[str], + ) -> None: + self._add_function( + python_code, python_module, _FunctionId(name) if name else None + ) + self._resume_codes.add(python_code) + + def add_import_source(self, alias: str, module_name: str) -> None: + assert self._current_entry is not None + self._current_entry.import_sources[alias] = module_name + + def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None: + assert self._current_entry is not None + assert backend_id.startswith("__compiled_fn_") # sanity check + backend_id = _BackendId(backend_id) + self._current_entry.backend_ids.append(backend_id) + if backend is not None: + self._cached_backends[backend_id] = backend + + def validate(self) -> None: + assert self._current_entry is None + assert self._innermost_fn is not None + assert next(iter(self._codes)) is self._innermost_fn.__code__ + + def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: + module.__dict__[name] = value + self._installed_globals.setdefault(module, []).append(name) + + def uninstall(self) -> None: + from torch._C._dynamo.eval_frame import _reset_precompile_entries + + assert self._innermost_fn is not None + for module, names in self._installed_globals.items(): + for name in names: + module.__dict__.pop(name) + + self._installed_globals = {} + + _reset_precompile_entries(self._innermost_fn.__code__) + + def install(self, backends: dict[_BackendId, Any]) -> None: + """ + Sync the package states to the compiled function. This includes the following actions: + 1. Clean up the previously installed states. + 2. Install the compiled functions to global scopes. + 3. Install the precompiled cache entries to ExtraStates on the code object. + """ + from torch._C._dynamo.eval_frame import _load_precompile_entry + + from .output_graph import get_builtins_dict + + self.uninstall() + + for code, entry in self._codes.items(): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) + ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + backend = backends[backend_id] + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) + + for code, entry in self._codes.items(): + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + runtime_global_scope = sys.modules[entry.python_module].__dict__ + # The installed builtins dict might be absent from the runtime + # while loading guards. Populate it if it's missing. + if ( + builtin_dict_name + := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals + ): + builtins_dict = get_builtins_dict(runtime_global_scope) + if builtin_dict_name in runtime_global_scope: + assert runtime_global_scope[builtin_dict_name] is builtins_dict + else: + runtime_global_scope[builtin_dict_name] = builtins_dict + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=runtime_global_scope, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) + + def cache_entry(self) -> _DynamoCacheEntry: + self.validate() + return _DynamoCacheEntry( + codes=list(self._codes.values()), inlined_sources=self._inlined_sources + ) + + +@CacheArtifactFactory.register +class EagerCacheArtifact(PrecompileCacheArtifact[Any]): + @staticmethod + def type() -> str: + return "precompile_eager" + + def after_deserialization(self) -> Any: + return pickle.loads(self.content) + + +_Backends = dict[_BackendId, PrecompileCacheArtifact[Any]] + + +class DynamoStore(abc.ABC): + """ + A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them. + + This is an abstract base class for different storage implementations. + """ + + def record_package(self, package: CompilePackage) -> None: + """ + Records a package to PrecompileContext, so that it can be serialized later. + """ + cache_entry = package.cache_entry() + pickled_result = pickle.dumps(cache_entry) + PrecompileContext.record_artifact( + _DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result + ) + + def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None: + """ + Records eager fx graphs to PrecompileContext for testing purposes. + """ + pickled_result = pickle.dumps(backend) + PrecompileContext.record_artifact( + EagerCacheArtifact.type(), key=backend_id, content=pickled_result + ) + + @abc.abstractmethod + def write( + self, + dynamo: _DynamoCacheEntry, + backends: _Backends, + path: str, + ) -> None: + """ + Abstract method to write dynamo cache entry and backends to storage. + + Args: + dynamo: The dynamo cache entry to write + backends: Dictionary of backend content to write + path: Path or key to identify where to write the data + """ + ... + + def save_package(self, package: CompilePackage, key: str) -> None: + """ + Saves a package to a given path. Grabs backends from PrecompileContext. + """ + backend_content: _Backends = {} + cache_entry = package.cache_entry() + for backend_id in cache_entry.backend_ids: + serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id) + if serialized_backend is None: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + assert isinstance(serialized_backend, PrecompileCacheArtifact) + backend_content[backend_id] = serialized_backend + + self.write(cache_entry, backend_content, key) + + @abc.abstractmethod + def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: + """ + Abstract method to read dynamo cache entry and backends from storage. + + Args: + path: Path or key to identify where to read the data from + + Returns: + A tuple containing (dynamo_cache_entry, backend_content) + """ + ... + + def load_package( + self, fn: Any, key: str + ) -> tuple[CompilePackage, dict[_BackendId, Any]]: + """ + Loads a package from a given path and returns it plus a list of deserialized backends + """ + cache_entry, backend_content = self.read(key) + + for backend_id, backend in backend_content.items(): + backend_content[backend_id] = backend.after_deserialization() + + package = CompilePackage(fn, cache_entry) + return package, backend_content + + +class InMemoryDynamoStore(DynamoStore): + """ + A DynamoStore implementation that keeps state about CompilePackages in memory. + """ + + def __init__(self) -> None: + self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {} + + def write( + self, + dynamo: _DynamoCacheEntry, + backends: _Backends, + path: str, + ) -> None: + """ + Store the dynamo cache entry and backends in memory instead of writing to disk. + """ + self.packages[path] = (dynamo, backends) + + def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: + """ + Read dynamo cache entry and backends from memory. + """ + if path not in self.packages: + raise RuntimeError(f"No package found with key {path}") + + return self.packages[path] + + +class DiskDynamoStore(DynamoStore): + """ + A DynamoStore implementation that keeps state about CompilePackages on disk. + """ + + def __init__(self, path_prefix: str = ""): + """ + Initialize a DiskDynamoStore with a path prefix. + + Args: + path_prefix: Prefix directory for where to put CompilePackages on disk + """ + self.path_prefix = path_prefix + + def write( + self, + dynamo: _DynamoCacheEntry, + backends: _Backends, + path: str, + ) -> None: + """ + Write dynamo cache entry and backends to disk. + """ + try: + with open(os.path.join(path, "dynamo"), "wb") as dynamo_path: + pickle.dump(dynamo, dynamo_path) + with open(os.path.join(path, "backends"), "wb") as backend_path: + pickle.dump(backends, backend_path) + except Exception as e: + raise RuntimeError(f"Failed to save package to {path}: {e}") from e + + def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: + """ + Read dynamo cache entry and backends from disk. + """ + try: + with open(os.path.join(path, "dynamo"), "rb") as dynamo_path: + cache_entry = pickle.load(dynamo_path) + with open(os.path.join(path, "backends"), "rb") as backend_path: + backend_content = pickle.load(backend_path) + return cache_entry, backend_content + except Exception as e: + raise RuntimeError(f"Failed to load package from path {path}: {e}") from e + + def save_package(self, package: CompilePackage, key: str) -> None: + """ + Save a package to disk using the path_prefix + key as the file path. + """ + full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key + super().save_package(package, full_path) + + def load_package( + self, fn: Any, key: str + ) -> tuple[CompilePackage, dict[_BackendId, Any]]: + """ + Load a package from disk using the path_prefix + key as the file path. + """ + full_path = os.path.join(self.path_prefix, key) if self.path_prefix else key + return super().load_package(fn, full_path) diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 999f0e3c1951ec..9bdec2df05c26f 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -374,90 +374,101 @@ def update_automatic_dynamic( ) -> FrameStateSizeEntry: code_id = CodeId.make(tx.f_code) frame_state = get_code_state()[code_id] - is_update = name in frame_state.automatic_dynamic - mut_entry = frame_state.automatic_dynamic[name] - old_entry = copy.copy(mut_entry) - mut_entry |= entry - - # Do some logs (damn, I spend more code logging than I do actually doing - # the updates lol) - if is_update and old_entry.scalar != mut_entry.scalar: - log.debug( - "automatic dynamic int %s val %s != %s", - name, - entry.scalar, - old_entry.scalar, - ) - CompileEventLogger.instant( - "automatic_dynamic", - { - "name": name, - "dim_changed": "scalar", - "reason": "scalar change", - "cached": str(old_entry.scalar), - "new": str(entry.scalar), - }, - ) - if is_unspecialized_nn_module: - log.info( - "%s is converted to a symbolic integer. It is an attribute of a " - "user defined nn module class. If you wish to keep it static, you can " - "mark the nn module class as `torch._dynamo.mark_static`.", + if torch._dynamo.config.automatic_dynamic_shapes: + is_update = name in frame_state.automatic_dynamic + mut_entry = frame_state.automatic_dynamic[name] + old_entry = copy.copy(mut_entry) + mut_entry |= entry + + # Do some logs (damn, I spend more code logging than I do actually doing + # the updates lol) + if is_update and old_entry.scalar != mut_entry.scalar: + log.debug( + "automatic dynamic int %s val %s != %s", name, + entry.scalar, + old_entry.scalar, + ) + CompileEventLogger.instant( + "automatic_dynamic", + { + "name": name, + "dim_changed": "scalar", + "reason": "scalar change", + "cached": str(old_entry.scalar), + "new": str(entry.scalar), + }, ) + if is_unspecialized_nn_module: + log.info( + "%s is converted to a symbolic integer. It is an attribute of a " + "user defined nn module class. If you wish to keep it static, you can " + "mark the nn module class as `torch._dynamo.mark_static`.", + name, + ) - def log_tup( - tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None - ) -> None: - entry_tup = ( - getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i] - ) - old_entry_tup = ( - getattr(old_entry, tup_name) - if i is None - else getattr(old_entry, tup_name)[i] - ) - log.debug( - "automatic dynamic %s %s %s %s != %s", - tup_name, - name, - short_reason, - # NB: We used to only report len(...) here for dim mismatch - entry_tup, - old_entry_tup, - ) - CompileEventLogger.instant( - "automatic_dynamic", - { - "name": name, - "dim_changed": "all" if i is None else i, - "reason": long_reason, - "cached": str(old_entry_tup), - "new": str(entry_tup), - }, - ) + def log_tup( + tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None + ) -> None: + entry_tup = ( + getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i] + ) + old_entry_tup = ( + getattr(old_entry, tup_name) + if i is None + else getattr(old_entry, tup_name)[i] + ) + log.debug( + "automatic dynamic %s %s %s %s != %s", + tup_name, + name, + short_reason, + # NB: We used to only report len(...) here for dim mismatch + entry_tup, + old_entry_tup, + ) + CompileEventLogger.instant( + "automatic_dynamic", + { + "name": name, + "dim_changed": "all" if i is None else i, + "reason": long_reason, + "cached": str(old_entry_tup), + "new": str(entry_tup), + }, + ) - if is_update and old_entry.size != mut_entry.size: - if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple): - if len(old_entry.size) != len(entry.size): - log_tup("size", "dim", "dimensionality change") + if is_update and old_entry.size != mut_entry.size: + if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple): + if len(old_entry.size) != len(entry.size): + log_tup("size", "dim", "dimensionality change") + else: + for i in range(len(entry.size)): + if old_entry.size[i] != entry.size[i]: + log_tup("size", f"size({i})", "size change", i) else: - for i in range(len(entry.size)): - if old_entry.size[i] != entry.size[i]: - log_tup("size", f"size({i})", "size change", i) - else: - log_tup("size", "other", "other") - - if is_update and old_entry.stride != mut_entry.stride: - if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple): - if len(old_entry.stride) != len(entry.stride): - log_tup("stride", "dim", "dimensionality change") + log_tup("size", "other", "other") + + if is_update and old_entry.stride != mut_entry.stride: + if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple): + if len(old_entry.stride) != len(entry.stride): + log_tup("stride", "dim", "dimensionality change") + else: + for i in range(len(entry.stride)): + if old_entry.stride[i] != entry.stride[i]: + log_tup("stride", f"stride({i})", "stride change", i) else: - for i in range(len(entry.stride)): - if old_entry.stride[i] != entry.stride[i]: - log_tup("stride", f"stride({i})", "stride change", i) - else: - log_tup("stride", "other", "other") + log_tup("stride", "other", "other") + else: + old_entry = frame_state.automatic_dynamic[name] + log.debug( + "automatic dynamic is off, overwriting int %s val %s -> %s", + name, + old_entry.scalar, + entry.scalar, + ) + frame_state.automatic_dynamic[name] = entry + mut_entry = entry return mut_entry @@ -535,9 +546,6 @@ def get_cache_key() -> Optional[str]: ) return f"{r}:{rank}:{tag}" - if r := torch.compiler.config.sticky_pgo_key: - return f"sticky:{r}:{rank}:{tag}" - if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: mast_job_name, mast_job_version = name_version return f"mast:{mast_job_name}:{mast_job_version}:{rank}:{tag}" @@ -594,21 +602,45 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: ) +def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: + dynamic_sources: OrderedSet[str] = OrderedSet() + for src, fs in code_state.automatic_dynamic.items(): + dynamic = False + if isinstance(fs.size, tuple): + dynamic = auto_dynamic in fs.size # type: ignore[operator] + elif fs.scalar == auto_dynamic: + dynamic = True + if dynamic: + dynamic_sources.add(src) + return dynamic_sources + + +def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: + code_id = CodeId.make(f_code) + frame_state = get_code_state()[code_id] + frame_whitelist = ",".join(_collect_dynamic_sources(frame_state)) + if frame_whitelist: + with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile( + name, recompile_dynamic_whitelist=frame_whitelist + ) + + def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: - terms: list[str] = [] + code_state_str = "\n".join( + f"{k}:\n" + + "\n".join( + f" {src}: {fs.render()}" for src, fs in v.automatic_dynamic.items() + ) + for k, v in cs.items() + ) dynamic_sources: OrderedSet[str] = OrderedSet() - for k, v in cs.items(): - cs_terms: list[str] = [] - for src, fs in v.automatic_dynamic.items(): - cs_terms.append(f" {src}: {fs.render()}") - if isinstance(fs.size, tuple) and auto_dynamic in fs.size: # type: ignore[operator] - dynamic_sources.add(src) - terms.append(f"{k}:\n" + "\n".join(cs_terms)) - code_state_str = "\n".join(terms) + for state in cs.values(): + dynamic_sources.update(_collect_dynamic_sources(state)) if dynamic_sources: code_state_str += ( - "\n\nPGO detected changes a recompilation due to tensor sizes. " - "To potentially avoid thisTo reduce shape recompilations by compiling dynamically to start, " + "\n\nPGO detected a recompilation due to dynamic shapes. " + "To reduce shape recompilations by compiling dynamically to start, " f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"' ) return code_state_str @@ -636,7 +668,7 @@ def _rewrite_cache_key_for_mega_cache(original_key: str) -> str: update the key to use the new MAST job's name and version. """ if not original_key.startswith("mast:"): - # if original_key is overriden, then dont change it + # if original_key is overridden, then dont change it return original_key if (new_key := get_cache_key()) is not None: return new_key @@ -662,7 +694,7 @@ def hit(ty: str) -> defaultdict[CodeId, CodeState]: trace_structured_artifact( f"get_{ty}_code_state", "string", - lambda: render_code_state(_CODE_STATE), + lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type] ) set_feature_use("pgo", True) _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index c0722999fbec42..93a8c27b80cac6 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -9,7 +9,7 @@ # mypy: allow-untyped-defs import types -from collections.abc import Iterable, MutableMapping, Sequence +from collections.abc import Hashable, Iterable, MutableMapping, Sequence from itertools import repeat as _repeat from typing import Any, Callable, TYPE_CHECKING @@ -119,9 +119,15 @@ def set_symmetric_difference_update(set1, set2): def set_isdisjoint(set1, set2): + if not isinstance(set2, Iterable): + raise TypeError(f"'{type(set2)}' object is not iterable") + for x in set1: - if x in set2: - return False + for y in set2: + if not isinstance(y, Hashable): + raise TypeError(f"unhashable type: '{type(y)}'") + if x == y: + return False return True @@ -129,10 +135,18 @@ def set_intersection(set1, *others): if len(others) == 0: return set1.copy() + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + # return a new set with elements common in all sets intersection_set = set() for x in set1: for set2 in others: - if x not in set2: + if not any(x == y for y in set2): break else: intersection_set.add(x) @@ -147,9 +161,21 @@ def set_intersection_update(set1, *others): def set_union(set1, *others): # frozenset also uses this function + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.union expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + union_set = set(set1.copy()) for set2 in others: set_update(union_set, set2) + + # frozenset also uses this function return type(set1)(union_set) @@ -170,6 +196,10 @@ def set_difference(set1, *others): if not all(isinstance(s, Iterable) for s in others): raise TypeError(f"set.difference expected an iterable, got {type(others)}") + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + difference_set = set() for x in set1: for set2 in others: @@ -186,6 +216,21 @@ def set_difference_update(set1, *others): set1.update(result) +def assert_multi_line_equal(self_, first, second, msg=None): + return self_.assertTrue(first == second, msg) + + +# The original impl. uses difflib +def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None): + return self_.assertTrue(seq1 == seq2, msg) + + +def generator___contains__(gen, item): + # "any" lazily consumes the generator, which is important to prevent + # unintended side effects. + return any(e == item for e in gen) + + def getattr_and_trace(*args, **kwargs): wrapper_obj = args[0] attr_name = args[1] @@ -218,7 +263,7 @@ def construct_dict(cls, /, *args, **kwargs): src = args[0] # Ensure that the overridden __iter__ method is invoked - if isinstance(src, (dict, MutableMapping)): + if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): for key in src: # This will inline the __getitem__ of the src object dst[key] = src[key] diff --git a/torch/_dynamo/polyfills/sys.py b/torch/_dynamo/polyfills/sys.py index 2504d2b6fcab85..ab666c385806f9 100644 --- a/torch/_dynamo/polyfills/sys.py +++ b/torch/_dynamo/polyfills/sys.py @@ -23,3 +23,12 @@ def intern(string: str, /) -> str: @substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True) def getrecursionlimit() -> int: return sys.getrecursionlimit() + + +if hasattr(sys, "get_int_max_str_digits"): + + @substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True) + def get_int_max_str_digits() -> int: + return sys.get_int_max_str_digits() + + __all__ += ["get_int_max_str_digits"] diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py new file mode 100644 index 00000000000000..1ed28347066766 --- /dev/null +++ b/torch/_dynamo/precompile_context.py @@ -0,0 +1,152 @@ +from abc import abstractmethod +from collections import defaultdict +from typing import Any, Generic, Optional, TypeVar +from typing_extensions import override + +from torch.compiler._cache import ( + _serialize_single_cache, + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, + CacheArtifactsResult, + CacheInfo, +) +from torch.utils._appending_byte_serializer import AppendingByteSerializer +from torch.utils._ordered_set import OrderedSet + + +""" +Classes and implementations related to precompile +""" + +T = TypeVar("T") + + +class PrecompileCacheArtifact(CacheArtifact, Generic[T]): + """ + Data for each cache artifact that will be serialized and deserialized by + PrecompileContext, rather than CacheArtifactManager. + T represents the deserialized type of the artifact, i.e. the return type of after_deserialization + + PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts + as needed, and use them in after_deserialization. + + Example implementation: + + class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]): + my_field: int + + def after_deserialization(self) -> MySerializableType: + result = pickle.loads(self.content) + # Do some extra work post deserialization + result.my_post_deserialization_function(self.my_field) + return result + """ + + @override + def populate_cache(self) -> None: + raise RuntimeError("Precompile cache artifacts do not populate caches") + + @override + def precompile_compatible(self) -> bool: + return True + + @abstractmethod + def after_deserialization(self) -> T: + """ + Code to be run after reading raw byte contents from disk. + Generally converts self.content from raw bytes back into its original form. + """ + ... + + +class PrecompileContext(CacheArtifactManager): + """ + PrecompileContext is a special CacheArtifactManager for handling precompilation + It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead + of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key + together and place it into a global Precompile Cache. + + The following artifact types are supported by PrecompileContext: + - BundledAOTAutogradCacheArtifact + - CodeStateArtifact (from torch._dynamo.package once available) + """ + + # Protected by the compile_lock + # _new_cache_artifacts_by_key organizes results by the key of each artifact. + # This allows us to implement serialize_by_key easily. + # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key + # are transferred to _new_cache_artifacts before serialization. + _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {} + _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) + # Keep a separate seen artifacts list to make avoid unnecessary duplicates + # This list will not be cleared between serialize() calls + _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() + # When serialize() is called, artifacts are transferred from _cache_artifacts to + # internal data structure of the _serializer + # This allows us to only pay the cost of serialization if serialize() is called + _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = ( + AppendingByteSerializer(serialize_fn=_serialize_single_cache) + ) + _cache_info: CacheInfo = CacheInfo() + + @classmethod + def clear(cls) -> None: + cls._new_cache_artifacts_by_key.clear() + super().clear() + + @override + @classmethod + def record_artifact( + cls, + artifact_type: str, + key: str, + content: Any, + ) -> None: + """ + Called from each caching operation to record the artifact in this + "mega" list + """ + artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) + # TODO: although this covers completely same artifacts, it's possible + # with AOTAutogradCacheEntries to have multiple artifacts whose keys + # (i.e. backend_ids) are different, but whose contents are equal. + # In those cases, it would be much better if we only serialize once instead + # of N times. + if artifact in cls._seen_artifacts: + return + + cls._new_cache_artifacts_by_key[key] = artifact + cls._seen_artifacts.add(artifact) + + @classmethod + def _save_artifacts_by_type(cls) -> None: + """ + We normally record artifacts by key, but serialization expects them to be organized + by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts + """ + for artifact in cls._new_cache_artifacts_by_key.values(): + cls._new_cache_artifacts[artifact.__class__.type()].append(artifact) + cls._new_cache_artifacts_by_key.clear() + + @classmethod + def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: + """ + Serialize all artifacts with the given key returned in a list. + """ + return cls._new_cache_artifacts_by_key.get(key, None) + + @classmethod + def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: + cls._save_artifacts_by_type() + return super().serialize() + + @staticmethod + def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: + raise NotImplementedError("TODO") + + @classmethod + def _ensure_cache_artifacts_registered(cls) -> None: + from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 + BundledAOTAutogradCacheArtifact, + ) diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 00e412c7cc5835..f064dcde5a6bc3 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -301,7 +301,7 @@ def generate_compiler_repro_string( {extra_imports} {maybe_fbcode_instructions()} - """ + """ ) if not stable_output: model_str += f"# torch version: {torch.version.__version__}\n" @@ -313,12 +313,12 @@ def generate_compiler_repro_string( model_str += NNModuleToString.convert(gm) - # get hint shape/stride when dynamic shape enabled - def hint_if_symint(x): - return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x) - writer = InputWriter(save_dir, stable_hash=stable_hash) - for placeholder, arg in zip(fx_placeholder_targets(gm), args): + used_syms = {} + + # Extract from graph placeholders and their corresponding arguments + placeholder_targets = fx_placeholder_targets(gm) + for placeholder, arg in zip(placeholder_targets, args): if isinstance(arg, (int, torch.SymInt)): writer.symint(placeholder, arg) elif isinstance(arg, torch.Tensor): @@ -327,11 +327,32 @@ def hint_if_symint(x): elif arg is None: writer.const(placeholder) else: - # It's better to produce a slightly wrong repro string than none - # at all writer.unsupported(placeholder, arg) - model_str += "\n".join(writer.lines()) + "\n" + # Extract symbolic variables from the same arguments + if isinstance(arg, torch.SymInt): + sym_name = str(arg.node) + if arg.node.hint is not None: + used_syms[sym_name] = arg.node.hint + elif isinstance(arg, torch.Tensor): + # Extract symbolic variables from tensor shapes and strides + for dim in arg.shape: + if isinstance(dim, torch.SymInt) and dim.node.hint is not None: + used_syms[str(dim.node)] = dim.node.hint + for stride in arg.stride(): + if isinstance(stride, torch.SymInt) and stride.node.hint is not None: + used_syms[str(stride.node)] = stride.node.hint + + # Add symbolic variable definitions to the top of the generated code + if used_syms: + hint_lines = "\n".join( + f"{name} = {hint}" for name, hint in sorted(used_syms.items()) + ) + model_str = f"{hint_lines}\n\n{model_str}" + + load_args_lines = writer.lines() + load_args_code = "\n".join(load_args_lines) + model_str += load_args_code + "\n" model_str += "mod = Repro()\n" return model_str @@ -465,7 +486,7 @@ def isolate_fails( if use_buck: cmd = BuckTargetWriter(file_name).write(print_msg=False) else: - cmd = ["python", file_name] + cmd = [sys.executable, file_name] p = subprocess.Popen( cmd, diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 2738ae92391557..80191f2d6cefc4 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -81,7 +81,7 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): class WrapBackendDebug: def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: functools.wraps(unconfigured_compiler_fn)(self) - self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + self._torchdynamo_orig_backend = unconfigured_compiler_fn # type: ignore[attr-defined] self._compiler_name = compiler_name if hasattr(unconfigured_compiler_fn, "__name__"): self.__name__ = unconfigured_compiler_fn.__name__ @@ -91,7 +91,7 @@ def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] def __call__(self, gm, example_inputs, **kwargs): - compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) + compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index cd050b8da7f766..beaaa77671e1c1 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -49,6 +49,7 @@ # trace_rules.py import this constant for consistency TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in" +IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue" def _initial_push_null(insts): @@ -356,6 +357,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): for v in code_options["co_varnames"] if v not in args and v not in freevars ] + + [IS_TRACING_RESUME_PROLOGUE_VARNAME] ) code_options["co_flags"] = code_options["co_flags"] & ~( CO_VARARGS | CO_VARKEYWORDS @@ -370,6 +372,18 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): ) prefix.append(create_instruction("RESUME", arg=0)) + # Set is_tracing_resume_prologue to prevent graph breaks. + # This doesn't really do anything at runtime, but dynamo will trace this + # and will know that we're in a resume function prologue. + prefix.extend( + [ + create_instruction("LOAD_CONST", argval=True), + create_instruction( + "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME + ), + ] + ) + cleanup: list[Instruction] = [] hooks = {fn.stack_index: fn for fn in setup_fns} hook_target_offsets = { @@ -431,6 +445,16 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): ] ) + # Set is_tracing_resume_prologue back to allow graph breaks. + prefix.extend( + [ + create_instruction("LOAD_CONST", argval=False), + create_instruction( + "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME + ), + ] + ) + prefix.append(create_jump_absolute(target)) # because the line number table monotonically increases from co_firstlineno diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 1deb09e2cc1ec1..a109d11e473de7 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,5 +1,28 @@ # mypy: allow-untyped-defs +""" +Side effect tracking and management for TorchDynamo's compilation system. + +This module provides infrastructure for tracking and managing side effects that occur +during symbolic execution, including: + +- Tracking mutations to objects, attributes, and variables +- Managing context changes (cell variables, global namespace modifications) +- Handling aliasing and object identity preservation +- Managing stack frame state and local variable changes +- Tracking function calls with side effects + +Key classes: +- SideEffects: Main container for tracking all side effects during execution +- MutableSideEffects: Specialization for mutable object tracking +- AttributeMutation/ValueMutation: Track specific types of mutations +- Various specialized side effect classes for different scenarios + +The side effect system ensures that mutations performed during symbolic execution +are properly replayed during runtime, maintaining the correctness of compiled code +while enabling optimizations where safe. +""" + import collections import contextlib import inspect @@ -101,6 +124,25 @@ def __init__( # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. self.ca_final_callbacks_var = None + # Tracks VariableTracker objects whose mutations can be skipped. + # For normal mutated variables, Dynamo generates code to replay/reconstruct + # the mutations after graph execution. However, variables in this set have + # their mutations ignored - the mutations happen during + # execution but don't need to be replayed in the generated code. + # Used for temporary mutations in contexts like torch.func.functional_call, + # where module parameters/buffers are modified but later restored. + self.ignore_mutation_on_these_variables = set() + + def ignore_mutations_on(self, var): + """Mutations to this variable will be executed but not not tracked, + typically used for temporary mutations that are later restored.""" + self.ignore_mutation_on_these_variables.add(var) + + def stop_ignoring_mutations_on(self, var): + """Remove a variable from the skip mutation set, restoring normal mutation tracking.""" + if var in self.ignore_mutation_on_these_variables: + self.ignore_mutation_on_these_variables.remove(var) + def __eq__(self, other: object) -> bool: assert isinstance(other, SideEffects) # NB: do NOT test keepalive @@ -160,6 +202,13 @@ def should_allow_side_effects_under_checkpoint(self): and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) + def should_allow_externally_visible_side_effects_in_subtracer(self): + output_graph = self.output_graph_weakref() + return ( + output_graph + and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + ) + def is_reconstructing_generator(self): output_graph = self.output_graph_weakref() @@ -168,13 +217,15 @@ def is_reconstructing_generator(self): and output_graph.current_tx.output.current_tracer.is_reconstructing_generator ) - def check_allowed_side_effect(self, item): + def check_allowed_side_effect(self, item: VariableTracker): from torch._dynamo.variables.misc import AutogradFunctionContextVariable # People do things like self.dim = dim inside autograd.Function. # These are benign. if isinstance(item, AutogradFunctionContextVariable): return True + if self.should_allow_externally_visible_side_effects_in_subtracer(): + return True if self.should_allow_side_effects_under_checkpoint(): return True if self.is_reconstructing_generator(): @@ -253,9 +304,12 @@ def cls_supports_mutation_side_effects(cls): return inspect.getattr_static(cls, "__getattribute__", None) in ( object.__getattribute__, dict.__getattribute__, + set.__getattribute__, + frozenset.__getattribute__, int.__getattribute__, str.__getattribute__, list.__getattribute__, + tuple.__getattribute__, BaseException.__getattribute__, ) @@ -368,6 +422,8 @@ def get_variable_cls(self, user_cls): variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, (dict, collections.OrderedDict)): variable_cls = variables.UserDefinedDictVariable + elif issubclass(user_cls, (set, frozenset)): + variable_cls = variables.UserDefinedSetVariable elif issubclass(user_cls, tuple): variable_cls = variables.UserDefinedTupleVariable elif issubclass(user_cls, list): @@ -557,6 +613,9 @@ def is_live(var: VariableTracker): } def mutation(self, var): + if var in self.ignore_mutation_on_these_variables: + return + self.check_allowed_side_effect(var) if isinstance(var.mutation_type, ValueMutationExisting): var.mutation_type.is_modified = True @@ -997,6 +1056,23 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("DELETE_ATTR", argval=name)] ) + elif isinstance( + var, variables.UserDefinedObjectVariable + ) and var.should_skip_descriptor_setter(name): + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "object_setattr_ignore_descriptor" + ) + ) + cg(var.source) # type: ignore[attr-defined] + cg(variables.ConstantVariable(name)) + cg(value) + suffixes.append( + [ + *create_call_function(3, False), + create_instruction("POP_TOP"), + ] + ) elif ( isinstance(var, variables.UserDefinedObjectVariable) and var.needs_slow_setattr() @@ -1069,6 +1145,16 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val +@contextlib.contextmanager +def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + try: + tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True + yield + finally: + tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = orig_val + + @contextlib.contextmanager def disallow_side_effects_in_generator(tx: "InstructionTranslator"): orig_val = tx.output.current_tracer.is_reconstructing_generator diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 106131859d8a25..4c18c1f47ce525 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -21,6 +21,7 @@ import dataclasses import enum +import functools from typing import Any, Optional, TYPE_CHECKING, Union from torch._guards import ChainedSource, GuardSource, Source @@ -586,6 +587,34 @@ def is_dict_key(self): return True +@dataclasses.dataclass(frozen=True) +class NonSerializableSetGetItemSource(ChainedSource): + index: int + + def __post_init__(self): + from .variables import ConstantVariable + + assert ConstantVariable.is_literal(self.index) + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "set_getitem") + ) + codegen(self.base) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + # set ordering might not be stable + return f"list({self.base.name()})[{self.index!r}]" + + def is_dict_key(self): + return False + + # Used to access an item from the dictionary @dataclasses.dataclass(frozen=True) class DictGetItemSource(ChainedSource): @@ -594,6 +623,43 @@ class DictGetItemSource(ChainedSource): # 2) constant - like string, integer index: Any + def __post_init__(self): + from .variables import ConstantVariable + + assert isinstance( + self.index, ConstDictKeySource + ) or ConstantVariable.is_literal(self.index) + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + # Load dict + codegen(self.base) + + # Load key + if isinstance(self.index, Source): + codegen(self.index) + else: + codegen.append_output(codegen.create_load_const(self.index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def name(self): + if isinstance(self.index, ConstDictKeySource): + return f"{self.base.name()}[{self.index.name()}]" + else: + return f"{self.base.name()}[{self.index!r}]" + + +# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that +# torch.compile does not run the overridden __getitem__ method +@dataclasses.dataclass(frozen=True) +class DictSubclassGetItemSource(ChainedSource): + # Key to access in the dictionary. It can be one of the the following types + # 1) ConstDictKeySource + # 2) constant - like string, integer + index: Any + def __post_init__(self): from .variables import ConstantVariable @@ -901,6 +967,7 @@ def is_from_source(source: Source, target: Source): return source == target +@functools.lru_cache def is_from_unspecialized_nn_module_source(source: Source): if isinstance(source, UnspecializedNNModuleSource): return True @@ -909,6 +976,16 @@ def is_from_unspecialized_nn_module_source(source: Source): return False +@functools.lru_cache +def is_from_unspecialized_builtin_nn_module_source(source: Source): + if isinstance(source, UnspecializedBuiltinNNModuleSource): + return True + if isinstance(source, ChainedSource): + return is_from_unspecialized_builtin_nn_module_source(source.base) + return False + + +@functools.lru_cache def is_from_unspecialized_param_buffer_source(source: Source): if isinstance(source, UnspecializedParamBufferSource): return True @@ -917,6 +994,7 @@ def is_from_unspecialized_param_buffer_source(source: Source): return False +@functools.lru_cache def is_from_flatten_script_object_source(source: Source): if isinstance(source, FlattenScriptObjectSource): return True @@ -925,6 +1003,7 @@ def is_from_flatten_script_object_source(source: Source): return False +@functools.lru_cache def is_from_optimizer_source(source: Source): if isinstance(source, OptimizerSource): return True @@ -935,6 +1014,7 @@ def is_from_optimizer_source(source: Source): # TODO: can probably write a generic "test this on everything in the chain" # helper +@functools.lru_cache def is_from_defaults(source: Source): if isinstance(source, DefaultsSource): return True diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 5b977bb61b18a9..be0d0ebbcebc7f 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -44,7 +44,7 @@ import types import typing import weakref -from typing import Any, Callable, cast, NoReturn, Optional, Union +from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union from unittest.mock import patch import torch @@ -95,7 +95,11 @@ from .guards import GuardBuilder, install_guard from .output_graph import GraphCompileReason, OutputGraph from .replay_record import DummyModule, ExecutionRecorder -from .resume_execution import ContinueExecutionCache, ReenterWith +from .resume_execution import ( + ContinueExecutionCache, + IS_TRACING_RESUME_PROLOGUE_VARNAME, + ReenterWith, +) from .source import ( AttrSource, DictGetItemSource, @@ -107,6 +111,7 @@ ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( + _get_error_on_graph_break, counters, get_fake_value, get_instruction_source_311, @@ -167,6 +172,9 @@ ) +if TYPE_CHECKING: + from .package import CompilePackage + log = logging.getLogger(__name__) graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") @@ -205,20 +213,29 @@ class SpeculationEntry: lineno: int instruction_pointer: int inst: Instruction # for debugging only - failed: bool = False + _failed: bool = False + error_on_graph_break: Optional[bool] = None reason: Optional[GraphCompileReason] = None - def fail_and_restart_analysis(self): + def fail_and_restart_analysis(self, error_on_graph_break: bool): """ Start tracing of the current frame over again, and don't take this branch. """ - self.failed = True + self._failed = True + self.error_on_graph_break = error_on_graph_break if self.reason is not None: restart_reason = self.reason.reason else: restart_reason = "Unknown fail_and_restart_analysis" raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason) + def failed(self, tx): + if self._failed: + assert self.error_on_graph_break is not None + tx.error_on_graph_break = self.error_on_graph_break + return True + return False + @dataclasses.dataclass class SpeculationLog: @@ -270,7 +287,7 @@ def next( - Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer}) - Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer}) {prev_entry_msg} -There are two usual reasons why this may have occured: +There are two usual reasons why this may have occurred: - When Dynamo analysis restarted, the second run took a different path than the first. If this occurred, the previous instruction is the critical instruction that behaved differently. @@ -331,7 +348,7 @@ def empty(cls) -> bool: return len(cls.force_specializations) == 0 -@functools.lru_cache(None) +@functools.cache def _step_logger(): return torchdynamo_logging.get_step_logger(log) @@ -650,7 +667,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): *graph_break_hints.FUNDAMENTAL, "Use `torch._assert()` to raise a hard AssertionError when the check fails. " "This error will propagate back the user code " - "that called the compiled function (i.e. Dynamo wil not trace any exception handling).", + "that called the compiled function (i.e. Dynamo will not trace any exception handling).", "Remove the assert statement.", "Move the assert statement outside of any context managers in order to graph break with " "partial graph compilation (if fullgraph=False).", @@ -810,10 +827,13 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): self.jump(inst) else: unimplemented_v2( - gb_type=_gb_type, + gb_type="Data-dependent branching", context=f"attempted to jump with {value}", explanation=_explanation, - hints=_hints, + hints=[ + *graph_break_hints.FUNDAMENTAL, + "Use `torch.cond` to express dynamic control flow.", + ], ) return inner @@ -824,7 +844,7 @@ def decorator(inner_fn): @functools.wraps(inner_fn) def wrapper(self: "InstructionTranslatorBase", inst: Instruction): speculation = self.speculate() - if speculation.failed: + if speculation.failed(self): assert speculation.reason is not None return handle_graph_break(self, inst, speculation.reason) try: @@ -869,7 +889,7 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): excp.remove_from_stats() excp.add_to_stats("graph_break") speculation.reason = GraphCompileReason(excp.msg, excp.real_stack) - speculation.fail_and_restart_analysis() + speculation.fail_and_restart_analysis(self.error_on_graph_break) def handle_graph_break( self: "InstructionTranslatorBase", @@ -1094,6 +1114,7 @@ class InstructionTranslatorBase( is_leaf_tracer: bool parent: Optional["InstructionTranslatorBase"] debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] + package: Optional["CompilePackage"] def mark_inconsistent_side_effects(self): """ @@ -1135,22 +1156,10 @@ def maybe_has_backedge(self): return False def cellvars(self): - if not hasattr(self, "_cellvars"): - self._cellvars = tuple(self.code_options["co_cellvars"] or []) - # An inlined function might depend on the cellvar of the parent - # function. So, recursively obtain parent cellvars. - if isinstance(self, InliningInstructionTranslator): - self._cellvars += self.parent.cellvars() - return self._cellvars + return self.code_options["co_cellvars"] def freevars(self): - if not hasattr(self, "_freevars"): - self._freevars = tuple(self.code_options["co_freevars"] or []) - # An inlined function might depend on the freevar of the parent - # function. So, recursively obtain parent freevars. - if isinstance(self, InliningInstructionTranslator): - self._freevars += self.parent.freevars() - return self._freevars + return self.code_options["co_freevars"] def cell_and_freevars(self): if not hasattr(self, "_cell_and_freevars"): @@ -1234,6 +1243,8 @@ def starts_line(self, lineno): def step(self): """Process exactly one instruction, return False we should exit""" + self.error_on_graph_break = _get_error_on_graph_break() + ip = self.instruction_pointer if ip is None: return False @@ -1249,7 +1260,7 @@ def step(self): and self.is_non_empty_graph() ): self.current_speculation = self.speculate() - if self.current_speculation.failed: + if self.current_speculation.failed(self): return self.step_graph_break(inst) if self.is_trace_bytecode_log_enabled: @@ -1275,7 +1286,7 @@ def step(self): raise log.debug("step triggered compile", exc_info=True) - self.current_speculation.fail_and_restart_analysis() + self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break) if sys.version_info >= (3, 11): @@ -1458,6 +1469,10 @@ def STORE_FAST(self, inst): loaded_vt = self.pop() loaded_vt.set_name_hint(name) self.symbolic_locals[name] = loaded_vt + if name == IS_TRACING_RESUME_PROLOGUE_VARNAME: + val = loaded_vt.as_python_constant() + assert type(val) is bool + self.is_tracing_resume_prologue = val def DELETE_FAST(self, inst): del self.symbolic_locals[inst.argval] @@ -1554,6 +1569,9 @@ def import_source(self, module_name): else: value = _import_module(module_name) alias = f"__import_{module_name.replace('.', '_dot_')}" + + if self.package is not None: + self.package.add_import_source(alias, module_name) f_globals = self.output.global_scope assert alias not in f_globals or f_globals[alias] is value f_globals[alias] = value @@ -1765,7 +1783,7 @@ def _create_exception_type(self, val): def _raise_exception_variable(self, val) -> NoReturn: # User can raise exception in 2 ways # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") + # 2) raise exception instance - raise NotImplemetedError("foo") # 1) when user raises exception type val = self._create_exception_type(val) @@ -1922,7 +1940,7 @@ def exception_handler(self, raised_exception): self.jump(exn_tab_entry) else: # No handler found. Bubble the exception to the parent - # instruction translater. We use special exception for this. + # instruction translator. We use special exception for this. self.stack.clear() if type(self) is InstructionTranslator: unimplemented_v2( @@ -1948,7 +1966,7 @@ def exception_handler(self, raised_exception): self.exn_vt_stack.pop() if len(self.block_stack) == 0: # No handler found in this frame. Bubble the exception to the parent - # instruction translater. + # instruction translator. self.stack.clear() if type(self) is InstructionTranslator: unimplemented_v2( @@ -2002,7 +2020,7 @@ def exception_handler(self, raised_exception): self.jump(block_stack_entry) else: # No handler found. Bubble the exception to the parent - # instruction translater. We use special exception for this. + # instruction translator. We use special exception for this. self.stack.clear() if type(self) is InstructionTranslator: unimplemented_v2( @@ -2109,7 +2127,7 @@ def check_if_exc_matches(self): unimplemented_v2( gb_type="Caught non-Exception value", context=str(exc_instance), - explanation=f"Except expects to recieve an object of Exception type but received {exc_instance}.", + explanation=f"Except expects to receive an object of Exception type but received {exc_instance}.", hints=[*graph_break_hints.USER_ERROR], ) @@ -2197,53 +2215,6 @@ def CALL_FUNCTION_EX(self, inst): null = self.pop() assert isinstance(null, NullVariable) - if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable): - # realize is requires for Python 3.8 - kwargsvars = kwargsvars.realize() - if fn.name == "view" and isinstance( - argsvars, (ConstantVariable, TensorVariable) - ): - # Hack to handle special case in some bert models. Converts - # x.view(*shape) into x.view(shape), which is correct for view() - # but not generally. See test_transpose_for_scores(). - argsvars = TupleVariable([argsvars]) - elif ( - fn.name == "random_" - and isinstance(argsvars, TupleVariable) - and len(argsvars.items) == 0 - and isinstance(kwargsvars, ConstDictVariable) - and ConstantVariable.create("from") in kwargsvars - ): - # `from`` is python keyword. Adding random_ with `from` in the - # Fx graph causes syntax error. Even if we convert the kwargs to - # args, aot_autograd/inductor while lowering generates - # aten.random.from, again causing syntax errors. Since this - # usecase is uncommon, graph break. - unimplemented_v2( - gb_type="Tensor.random_ op called with `from` keyword", - context="", - explanation="This is not supported.", - hints=[], - ) - elif ( - fn.name == "uniform_" - and isinstance(argsvars, TupleVariable) - and len(argsvars.items) == 0 - and isinstance(kwargsvars, ConstDictVariable) - and ConstantVariable.create("from") in kwargsvars - ): - # `from`` is python keyword. Adding uniform_ with `from` in the - # Fx graph causes syntax error. Even if we convert the kwargs to - # args, aot_autograd/inductor while lowering generates - # aten.uniform.from, again causing syntax errors. Since this - # usecase is uncommon, graph break. - unimplemented_v2( - gb_type="Tensor.uniform_ op called with `from` keyword", - context="", - explanation="This is not supported.", - hints=[], - ) - if not isinstance( argsvars, BaseListVariable ) and argsvars.has_force_unpack_var_sequence(self): @@ -2335,7 +2306,7 @@ def LOAD_ATTR(self, inst): def STORE_ATTR(self, inst): speculation = self.speculate() - if speculation.failed: + if speculation.failed(self): return self.store_attr_graph_break(inst) val, obj = self.popn(2) @@ -2359,7 +2330,7 @@ def STORE_ATTR(self, inst): log.debug("STORE_ATTR triggered compile", exc_info=True) e.remove_from_stats() e.add_to_stats("graph_break") - speculation.fail_and_restart_analysis() + speculation.fail_and_restart_analysis(self.error_on_graph_break) def store_attr_graph_break(self, inst): log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break") @@ -3064,7 +3035,7 @@ def END_FOR(self, inst): self.popn(2) def LOAD_FAST_CHECK(self, inst): - if isinstance(self.symbolic_locals[inst.argval], NullVariable): + if isinstance(self.symbolic_locals.get(inst.argval, None), NullVariable): unimplemented_v2( gb_type="LOAD_FAST_CHECK on uninitialized variable", context=inst.argval, @@ -3234,6 +3205,7 @@ def __init__( distributed_state: Optional[DistributedState], # This determines whether to use the execution recorder. closure: Optional[tuple[types.CellType]] = None, + package: Optional["CompilePackage"] = None, ) -> None: super().__init__() self.speculation_log = speculation_log @@ -3282,7 +3254,16 @@ def __init__( self.num_calls: dict[str, int] = {} # Flag to indicate whether tracing is used for export. self.export = export + # NOTE: one_graph is used for export/debugging to always force errors on graph breaks. + # To toggle fullgraph during normal compile, self.error_on_graph_break + # is used instead. Every step(), its value is updated to the global tls.error_on_graph_break. + # We mirror this value since cleanup may (correctly) inadvertently change tls.error_on_graph_break. + # This assumes that we cannot both trace a change to tls.error_on_graph_break and graph break on + # the same instruction. self.one_graph = False + self.error_on_graph_break = False + # Also do not graph break when tracing resume function prologues + self.is_tracing_resume_prologue = False self.current_speculation = None @@ -3292,6 +3273,8 @@ def __init__( self.parent = None self.debug_locals = [] + self.package = package + if sys.version_info >= (3, 10): from .resume_execution import ( CO_ASYNC_GENERATOR, @@ -3352,6 +3335,7 @@ def __init__( speculation_log: SpeculationLog, exn_vt_stack: ExceptionStack, distributed_state: Optional[DistributedState], + package: Optional["CompilePackage"], ) -> None: _step_logger()( logging.INFO, @@ -3369,6 +3353,7 @@ def __init__( global_scope=f_globals, f_code=f_code, torch_function_mode_stack=torch_function_mode_stack, + package=package, ), instructions=instructions, f_locals=f_locals, @@ -3386,6 +3371,7 @@ def __init__( speculation_log=speculation_log, exn_vt_stack=exn_vt_stack, distributed_state=distributed_state, + package=package, ) self._throw_if_in_functorch() @@ -3541,6 +3527,8 @@ def should_compile_partial_graph(self): return ( all(b.can_restore() for b in self.block_stack) and not self.one_graph + and not self.error_on_graph_break + and not self.is_tracing_resume_prologue and not self.active_generic_context_managers ) @@ -3622,12 +3610,19 @@ def create_call_resume_at(self, inst, all_stack_locals_metadata): # expose code object for debugging purposes self.output.install_global_unsafe(name, new_code) cg.make_function_with_closure(name, new_code, True, stack_len) + package_name = None else: # This is safe: we pre-generate a unique name self.output.install_global_unsafe( name, types.FunctionType(new_code, self.f_globals, name) ) cg.extend_output(cg.load_function_name(name, True, stack_len)) + package_name = name + + if self.package is not None: + self.package.add_resume_function( + new_code, self.f_globals["__name__"], package_name + ) cg.extend_output([cg.create_load(k) for k in argnames]) cg.extend_output(create_call_function(nargs, False)) @@ -3668,6 +3663,8 @@ def _return(self, inst): and not self.symbolic_locals_contain_module_class() and not self.export and not self.one_graph + and not self.error_on_graph_break + and not self.is_tracing_resume_prologue ): raise exc.SkipFrame("because no content in function call") @@ -3938,6 +3935,9 @@ def inline_call_(self): except Exception: log.debug("FAILED INLINING %s", code) raise + finally: + parent.error_on_graph_break = self.error_on_graph_break + assert self.symbolic_result is not None if self.f_globals is parent.f_globals: @@ -4022,6 +4022,7 @@ def __init__( speculation_log=parent.speculation_log, exn_vt_stack=parent.exn_vt_stack, distributed_state=parent.distributed_state, + package=parent.package, ) self.funcvar = funcvar self.parent = parent @@ -4059,7 +4060,12 @@ def RETURN_CONST(self, inst): raise ReturnValueOp def get_globals_source_and_value(self, name): - if "__name__" in self.f_globals: + # NamedTuple's `__new__` has a fake global scope that's not an actual + # module. TODO generalize the check for other non-importable cases. + # https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447 + if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith( + "namedtuple_" + ): module_name = self.f_globals["__name__"] module_source = self.import_source(module_name) if "torch_package" in module_name: @@ -4122,7 +4128,7 @@ def STORE_GLOBAL(self, inst): class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): generated_items: list[VariableTracker] - # Flag wether or not the InlineGenerator should consume the entire iterator + # Flag whether or not the InlineGenerator should consume the entire iterator def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 7cd5dbc90d6cbc..dc7a4468405197 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -22,6 +22,7 @@ import torch import torch.testing +from torch._dynamo import polyfills from torch._logging._internal import trace_log from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] IS_WINDOWS, @@ -136,8 +137,8 @@ class CPythonTestCase(TestCase): assertRegex = unittest.TestCase.assertRegex assertNotRegex = unittest.TestCase.assertNotRegex assertCountEqual = unittest.TestCase.assertCountEqual - assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual - assertSequenceEqual = unittest.TestCase.assertSequenceEqual + assertMultiLineEqual = polyfills.assert_multi_line_equal + assertSequenceEqual = polyfills.assert_sequence_equal assertListEqual = unittest.TestCase.assertListEqual assertTupleEqual = unittest.TestCase.assertTupleEqual assertSetEqual = unittest.TestCase.assertSetEqual diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index 674157699884a0..32d10b53da99d1 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -73,6 +73,8 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() + if not os.path.exists(cls.DEBUG_DIR): + cls.DEBUG_DIR = tempfile.mkdtemp() cls._exit_stack.enter_context( # type: ignore[attr-defined] torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR) ) diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index e1b32b289abc84..85e44b7c7e4894 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -213,6 +213,7 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: global_scope=globals(), f_code=frame.f_code, torch_function_mode_stack=[], + package=None, ) return wrap_guarded_code( diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 9361e5c31efa68..de854807ba3b60 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,4 +1,26 @@ # mypy: allow-untyped-defs + +""" +Tracing rules and policies for TorchDynamo compilation decisions. + +This module defines the rules that govern what code TorchDynamo should trace and compile +versus what should be executed eagerly. It contains functions and classes that determine: + +- Which modules, functions, and objects should be skipped during tracing +- Which parts of the code should cause graph breaks +- How to handle different Python libraries and third-party packages +- Rules for determining when to inline functions vs calling them eagerly + +Key components: +- Skip rules: Functions that return True if an object should be skipped during tracing +- Inlining rules: Policies for when to inline function calls during compilation +- Library-specific handling: Special cases for popular Python packages +- Performance heuristics: Rules that balance compilation overhead vs runtime benefits + +These rules are critical for TorchDynamo's ability to automatically determine +compilation boundaries and optimize PyTorch programs effectively. +""" + import abc import builtins import collections @@ -39,6 +61,7 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, UserFunctionVariable, @@ -143,6 +166,7 @@ "torch._utils.is_compiling": TorchInGraphFunctionVariable, "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable, "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable, + "torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer": UserFunctionVariable, "torch.compiler.is_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_exporting": TorchInGraphFunctionVariable, @@ -177,6 +201,10 @@ "torch.fx.node.map_aggregate": UserFunctionVariable, "torch.fx.node.map_arg": UserFunctionVariable, "torch.fx.immutable_collections._no_mutation": UserFunctionVariable, + "torch.fx.immutable_collections._immutable_list_flatten": UserFunctionVariable, + "torch.fx.immutable_collections._immutable_list_unflatten": UserFunctionVariable, + "torch.fx.immutable_collections._immutable_dict_flatten": UserFunctionVariable, + "torch.fx.immutable_collections._immutable_dict_unflatten": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -280,6 +308,7 @@ # functional_call "torch._functorch.functional_call.functional_call": FunctionalCallVariable, "torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable, + "torch.nn.utils.stateless._reparametrize_module": ReparametrizeModuleCallVariable, # functorch/deprecated "torch._functorch.deprecated.jvp": UserFunctionVariable, "torch._functorch.deprecated.hessian": UserFunctionVariable, @@ -311,11 +340,15 @@ "torch._dynamo.mark_static": UserFunctionVariable, "torch._dynamo.nonstrict_trace": UserFunctionVariable, "torch._dynamo.patch_dynamo_config": UserFunctionVariable, + "torch._dynamo.set_fullgraph": UserFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.statically_known_false": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.sym_and": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.sym_or": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_scalar": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.has_static_value": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, @@ -326,6 +359,8 @@ "torch.sparse_csr_tensor": SkipFunctionVariable, "torch.sparse_compressed_tensor": SkipFunctionVariable, "torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable, + "torch.xpu.get_rng_state": SkipFunctionVariable, + "torch.xpu.set_rng_state": SkipFunctionVariable, # avoid skipping user defined modules in distributed unit tests "torch/testing/_internal/common_fsdp.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_fsdp.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, @@ -403,6 +438,11 @@ "torch._assert_async", "torch._assert_tensor_metadata", "torch._batch_norm_impl_index", + "torch._C._accelerator_getAccelerator", + "torch._C._accelerator_getDeviceIndex", + "torch._C._accelerator_getStream", + "torch._C._accelerator_setStream", + "torch._C._accelerator_synchronizeDevice", "torch._C._activate_gpu_trace", "torch._C._add_cached_tensor", "torch._C._add_docstr", @@ -1319,6 +1359,21 @@ "torch._C._warn", "torch._C._will_engine_execute_node", "torch._C._wrap_tensor_impl", + "torch._C._xpu_emptyCache", + "torch._C._xpu_getArchFlags", + "torch._C._xpu_getCurrentStream", + "torch._C._xpu_getCurrentRawStream", + "torch._C._xpu_getDeviceCount", + "torch._C._xpu_getDevice", + "torch._C._xpu_getMemoryInfo", + "torch._C._xpu_getStreamFromExternal", + "torch._C._xpu_isInBadFork", + "torch._C._xpu_init", + "torch._C._xpu_memoryStats", + "torch._C._xpu_resetAccumulatedMemoryStats", + "torch._C._xpu_resetPeakMemoryStats", + "torch._C._xpu_setStream", + "torch._C._xpu_synchronize", "torch._C.fork", "torch._C.get_autocast_cpu_dtype", "torch._C.get_autocast_dtype", @@ -1515,6 +1570,7 @@ "torch._fused_sdp_choice", "torch._fw_primal_copy", "torch._grid_sampler_2d_cpu_fallback", + "torch._grouped_mm", "torch._has_compatible_shallow_copy_type", "torch._histogramdd_bin_edges", "torch._histogramdd_from_bin_cts", @@ -2240,6 +2296,7 @@ "torch.slice_inverse", "torch._assert_scalar", "torch._functional_assert_scalar", + "torch.xpu._get_device_properties", ], TorchInGraphFunctionVariable, ) @@ -2351,6 +2408,13 @@ "torch._utils._unflatten_dense_tensors", "torch._weights_only_unpickler._get_allowed_globals", "torch._weights_only_unpickler.load", + "torch.accelerator.current_accelerator", + "torch.accelerator.current_device_index", + "torch.accelerator.current_stream", + "torch.accelerator.device_count", + "torch.accelerator.is_available", + "torch.accelerator.set_stream", + "torch.accelerator.synchronize", "torch.align_tensors", "torch.amp.autocast_mode._enter_autocast", "torch.amp.autocast_mode._exit_autocast", @@ -2855,6 +2919,43 @@ "torch.tensordot", "torch.unique_consecutive", "torch.use_deterministic_algorithms", + "torch.xpu._get_device", + "torch.xpu._get_generator", + "torch.xpu._get_rng_state_offset", + "torch.xpu._is_compiled", + "torch.xpu._lazy_call", + "torch.xpu._lazy_init", + "torch.xpu._set_rng_state_offset", + "torch.xpu._set_stream_by_id", + "torch.xpu._utils._get_device_index", + "torch.xpu.current_device", + "torch.xpu.current_stream", + "torch.xpu.device_count", + "torch.xpu.get_arch_list", + "torch.xpu.get_device_capability", + "torch.xpu.get_device_name", + "torch.xpu.get_device_properties", + "torch.xpu.get_gencode_flags", + "torch.xpu.get_stream_from_external", + "torch.xpu.init", + "torch.xpu.is_available", + "torch.xpu.is_bf16_supported", + "torch.xpu.is_initialized", + "torch.xpu.memory.empty_cache", + "torch.xpu.memory.max_memory_allocated", + "torch.xpu.memory.max_memory_reserved", + "torch.xpu.memory.mem_get_info", + "torch.xpu.memory.memory_allocated", + "torch.xpu.memory.memory_reserved", + "torch.xpu.memory.memory_stats_as_nested_dict", + "torch.xpu.memory.memory_stats", + "torch.xpu.memory.reset_accumulated_memory_stats", + "torch.xpu.memory.reset_peak_memory_stats", + "torch.xpu.random.initial_seed", + "torch.xpu.random.seed_all", + "torch.xpu.random.seed", + "torch.xpu.set_stream", + "torch.xpu.synchronize", ], TorchInGraphFunctionVariable, ) @@ -2872,7 +2973,7 @@ """ -@functools.lru_cache(None) +@functools.cache def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: d: dict[Any, type[VariableTracker]] = {} for m in torch_name_rule_map: @@ -2921,7 +3022,7 @@ def load_object(name): """ -@functools.lru_cache(None) +@functools.cache def get_tensor_method(): disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} s = set() @@ -3443,7 +3544,7 @@ def _module_dir(m: types.ModuleType): MOD_SKIPLIST = set(MOD_SKIPLIST) -@functools.lru_cache(None) +@functools.cache def get_legacy_mod_inlinelist(): inlinelist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3452,7 +3553,7 @@ def get_legacy_mod_inlinelist(): return inlinelist -@functools.lru_cache(None) +@functools.cache def get_mod_inlinelist(): inlinelist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3461,7 +3562,7 @@ def get_mod_inlinelist(): return inlinelist -@functools.lru_cache(None) +@functools.cache def get_mod_skiplist(): skiplist = { _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) @@ -3628,9 +3729,9 @@ def f3(x, y): ...... There are mainly three call sites of check/check_verbose: -* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py. +* The compile region entrance (like function f1), the corresponding code is located at eval_frame.py. * When tracing the recursively called functions (like function f2 and f3). - * Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site + * Dynamo decides inline/skip every time it encounters a new recursively function call, and the call site is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py. * If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again and the call site is in catch_errors_wrapper.catch_errors of convert_frame.py. @@ -3715,7 +3816,7 @@ def is_torch_inline_allowed(filename): return any(filename.startswith(d) for d in get_mod_inlinelist()) -@functools.lru_cache(None) +@functools.cache def dynamo_dir(): import torch._dynamo diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py index 2fd9c580181606..fc9bc601fd6358 100644 --- a/torch/_dynamo/types.py +++ b/torch/_dynamo/types.py @@ -75,7 +75,7 @@ class ConvertFrameReturn: # default return is no compiled code (i.e. `return None`): # strategy is to skip non-recursively, for all future intercepted frames too - # eval fram execution strategy for this frame + # eval frame execution strategy for this frame frame_exec_strategy: FrameExecStrategy = dataclasses.field( default_factory=lambda: FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) @@ -114,6 +114,13 @@ def __call__( ) -> None: ... +class DynamoGuardCompleteHook(Protocol): + def __call__( + self, + cache_hit: bool, + ) -> bool: ... + + class ProfilerStartHook(Protocol): def __call__( self, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0df0d51d0cf665..3c169333302ab5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -47,10 +47,10 @@ import warnings import weakref from collections import Counter, OrderedDict -from contextlib import contextmanager +from contextlib import AbstractContextManager, contextmanager from dataclasses import is_dataclass from functools import lru_cache -from types import MethodWrapperType +from types import CodeType, MethodWrapperType from typing import ( Any, Callable, @@ -62,7 +62,7 @@ TypeVar, Union, ) -from typing_extensions import Literal, TypeIs +from typing_extensions import Literal, TypeAlias, TypeGuard, TypeIs import torch import torch._functorch.config @@ -1046,19 +1046,45 @@ def is_numpy_float_type(value): ) -def is_lru_cache_wrapped_function(value): +@overload +def is_lru_cache_wrapped_function( + value: Callable[..., T], +) -> TypeGuard[functools._lru_cache_wrapper[T]]: ... + + +@overload +def is_lru_cache_wrapped_function( + value: Any, +) -> TypeGuard[functools._lru_cache_wrapper[Any]]: ... + + +def is_lru_cache_wrapped_function( + value: Any, +) -> bool: return isinstance(value, functools._lru_cache_wrapper) and is_function( inspect.getattr_static(value, "__wrapped__") ) -def is_function_or_wrapper(value): +_FuncTypes: TypeAlias = Union[ + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, +] + + +def is_function_or_wrapper( + value: Any, +) -> TypeIs[Union[_FuncTypes, torch._ops.OpOverloadPacket, torch._ops.OpOverload]]: return is_function(value) or isinstance( value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) ) -def is_function(value): +def is_function( + value: Any, +) -> TypeIs[_FuncTypes]: return isinstance( value, ( @@ -1090,7 +1116,17 @@ def is_function(value): } -def is_wrapper_or_member_descriptor(value): +def is_wrapper_or_member_descriptor( + value: Any, +) -> TypeIs[ + Union[ + types.GetSetDescriptorType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + types.MemberDescriptorType, + types.MethodWrapperType, + ] +]: return isinstance( value, ( @@ -2127,7 +2163,9 @@ def preserve_rng_state(): torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] -def is_jit_model(model0): +def is_jit_model( + model0, +): return isinstance( model0, ( @@ -2331,7 +2369,7 @@ def is_safe_constant(v): ) -@functools.lru_cache(None) +@functools.cache def common_constants(): return { # We zero-one specialize shapes, so specialize these constants @@ -2341,7 +2379,7 @@ def common_constants(): } -def is_torch_sym(value): +def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]: return isinstance(value, (torch.SymBool, torch.SymInt)) and not isinstance( value.node, torch.nested._internal.nested_int.NestedIntNode ) @@ -2364,6 +2402,10 @@ def is_int_specialization_case(value, source): source.guard_source().is_unspecialized_builtin_nn_module() and not config.allow_unspec_int_on_nn_module ) + or ( + source.guard_source().is_unspecialized_nn_module() + and not config.allow_unspec_int_on_nn_module + ) or is_from_defaults(source) # TODO: Delete this condition when rollout is done. NB: this # condition never evaluates True in open source @@ -2461,6 +2503,10 @@ def check_numpy_ndarray_args(args, kwargs): for method in itertools.chain(dict.__dict__.values(), OrderedDict.__dict__.values()) if callable(method) } +set_methods = {method for method in set.__dict__.values() if callable(method)} +frozenset_methods = { + method for method in frozenset.__dict__.values() if callable(method) +} tuple_new = tuple.__new__ tuple_methods = {method for method in tuple.__dict__.values() if callable(method)} @@ -2530,6 +2576,11 @@ def dict_keys_getitem(d, n): return next(itertools.islice(dict_class.keys(d), n, n + 1)) +def set_getitem(s, n): + # Set ordering might not be stable + return list(s)[n] + + def enum_repr(value, local): # enum class can override __str__ method. Use __class__ and name attribute # to extract the class name and key name. @@ -2614,7 +2665,9 @@ def iter_contains(items, search, tx, check_tensor_identity=False): return found -def key_is_id(k): +def key_is_id( + k: Any, +) -> TypeIs[Union[torch.Tensor, torch.nn.Module, MethodWrapperType]]: """Returns whether it indexes dictionaries using its id""" return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) @@ -2670,7 +2723,7 @@ def dict_keys_repr(const_keys, *, local) -> str: def get_safe_global_name(tx, root, obj): # The global_mangled_class_name should be different for different # invocations of torch.compile. Otherwise, we can run into a situation - # where multiple torch.compile invocations re-use the same global name, + # where multiple torch.compile invocations reuse the same global name, # but the global's lifetime is tied to the first invocation (and # may be deleted when the first torch.compile invocation is deleted) # We mangle it based off of the output_graph's id. @@ -2933,7 +2986,7 @@ def get_multiplier(): ): # In the presence of noise, noise might dominate our error # metric for smaller tensors. - # Similary, for 1x1 kernels, there seems to be high noise with amp. + # Similarly, for 1x1 kernels, there seems to be high noise with amp. multiplier = 3.0 return multiplier @@ -3067,7 +3120,7 @@ def disable_cache_limit(): # return same dir unless user changes config between calls -@functools.lru_cache(None) +@functools.cache def _get_debug_dir(root_dir): dir_name = ( "run_" @@ -3505,6 +3558,12 @@ def object_has_getattribute(value: Any): return class_has_getattribute(type(value)) +def object_setattr_ignore_descriptor(obj, name, value): + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335 + d = object.__getattribute__(obj, "__dict__") + d[name] = value + + def class_has_getattribute(cls: type): try: if isinstance( @@ -3793,6 +3852,10 @@ def defake(x): return y +def _disable_side_effect_safety_checks_for_current_subtracer(fn, *args, **kwargs): + return fn(*args, **kwargs) + + def is_utils_checkpoint(obj): # Lazy import to avoid circular dependencies import torch.utils.checkpoint @@ -4445,38 +4508,22 @@ def does_not_override_dict_iter_methods(user_cls): ) -# Helper functions below are to prevent __torch_function__ -# calls from happening in the middle of __torch_function__ -# compiled bytecode -# They will be skipped which is the desired result +# Helper functions below are to prevent TorchDynamo to prevent tracing of +# __torch_function__ calls triggered on tensor properties in the pre graph +# bytecode. +@torch._disable_dynamo def call_size(x, i): - @torch._dynamo.disable( - recursive=True, reason="__torch_function__ tracing helper function" - ) - def fn(x, i): - return x.size(i) - - return fn(x, i) + return x.size(i) +@torch._disable_dynamo def call_stride(x, i): - @torch._dynamo.disable( - recursive=True, reason="__torch_function__ tracing helper function" - ) - def fn(x, i): - return x.stride(i) - - return fn(x, i) + return x.stride(i) +@torch._disable_dynamo def call_storage_offset(x): - @torch._dynamo.disable( - recursive=True, reason="__torch_function__ tracing helper function" - ) - def fn(x): - return x.storage_offset() - - return fn(x) + return x.storage_offset() # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. @@ -4624,3 +4671,40 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]: def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool: return node is None or "example_value" in node.meta or "val" in node.meta + + +# If True, enforce fullgraph=True - raise errors on graph break +_error_on_graph_break = False + + +def _get_error_on_graph_break() -> bool: + return _error_on_graph_break + + +def _set_error_on_graph_break(value: bool) -> None: + global _error_on_graph_break + _error_on_graph_break = value + + +@torch._disable_dynamo +def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: + cm: AbstractContextManager[None] = ( + torch._C._profiler._RecordFunctionFast("Pregraph bytecode") + if torch.autograd.profiler._is_profiler_enabled + else contextlib.nullcontext() + ) + cm.__enter__() + return cm + + +@torch._disable_dynamo +def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: + cm.__exit__(None, None, None) + + +# Returns a set of code objects present traced in the current TracingContext, or None +# if there is no current TracingContext. +def get_traced_code() -> list[CodeType]: + from torch._guards import TracingContext + + return TracingContext.get_traced_code() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 45ca28fd645ee1..73209fdfff629f 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -34,6 +34,7 @@ InferenceModeVariable, JvpIncrementNestingCtxManagerVariable, SDPAKernelVariable, + SetFullgraphVariable, SetFwdGradEnabledContextManager, StreamContextVariable, StreamVariable, @@ -54,7 +55,8 @@ from .functions import ( BuiltinMethodVariable, CollectionsNamedTupleFunction, - CreateTMADescriptorVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, FunctionDecoratedByContextlibContextManagerVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, @@ -63,13 +65,17 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, - TMADescriptorVariable, + TMADescriptorExperimentalVariable, + TMADescriptorStableVariable, UserFunctionVariable, UserMethodVariable, + WrapperUserFunctionVariable, + WrapperUserMethodVariable, ) from .higher_order_ops import ( FunctionalCallVariable, FunctorchHigherOrderVariable, + ReparametrizeModuleCallVariable, TorchHigherOrderOperatorVariable, ) from .iter import ( @@ -142,6 +148,7 @@ UserDefinedExceptionObjectVariable, UserDefinedListVariable, UserDefinedObjectVariable, + UserDefinedSetVariable, UserDefinedTupleVariable, ) @@ -157,7 +164,8 @@ "ConstDictVariable", "ContextWrappingVariable", "CountIteratorVariable", - "CreateTMADescriptorVariable", + "CreateTMADescriptorExperimentalVariable", + "CreateTMADescriptorStableVariable", "CUDADeviceVariable", "CycleIteratorVariable", "DataPtrVariable", @@ -192,13 +200,15 @@ "RemovableHandleVariable", "RepeatIteratorVariable", "SDPAParamsVariable", + "SetFullgraphVariable", "SkipFunctionVariable", "SliceVariable", "StringFormatVariable", "SuperVariable", "TemporarilyPopInterpreterStackCtxManagerVariable", "TensorVariable", - "TMADescriptorVariable", + "TMADescriptorExperimentalVariable", + "TMADescriptorStableVariable", "TorchCtxManagerClassVariable", "TorchInGraphFunctionVariable", "TorchVersionVariable", diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 90f8a54dd7ef48..e786a0281087ab 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -210,7 +210,7 @@ class AsPythonConstantNotImplementedError(NotImplementedError): vt: "VariableTracker" def __init__(self, vt: "VariableTracker"): - super().__init__(self, f"{vt} is not a constant") + super().__init__(f"{vt} is not a constant") self.vt = vt @@ -547,6 +547,12 @@ def call_method( "This can happen unintentionally if a previous graph break happens with a builtin iterator " "in the local scope." ) + hints.append( + "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo " + "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, " + "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a " + "function, or (4) use Python 3.12+." + ) unimplemented_v2( gb_type="Unsupported method call", context=f"call_method {self} {name} {args} {kwargs}", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index eba4ca66f28bf0..b4ef5331ae6d8c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -99,6 +99,7 @@ ConstDictKeySource, ConvertIntSource, DictGetItemSource, + DictSubclassGetItemSource, FloatTensorSource, GetItemSource, GradSource, @@ -109,12 +110,15 @@ is_from_unspecialized_nn_module_source, ListGetItemSource, LocalSource, + NonSerializableSetGetItemSource, NumpyTensorSource, OptimizerSource, RandomValueSource, Source, SubclassAttrListSource, TupleIteratorGetItemSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, ) from ..utils import ( _extract_tensor_dict, @@ -164,6 +168,7 @@ EventVariable, NullContextVariable, PreserveVersionContextVariable, + SetFullgraphVariable, StreamContextVariable, StreamVariable, ) @@ -186,7 +191,8 @@ BuiltinMethodVariable, CollectionsNamedTupleFunction, CollectiveFunctionRewriteVariable, - CreateTMADescriptorVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, SysFunctionVariable, @@ -272,6 +278,7 @@ UserDefinedExceptionClassVariable, UserDefinedListVariable, UserDefinedObjectVariable, + UserDefinedSetVariable, UserDefinedTupleVariable, ) @@ -434,7 +441,10 @@ def __call__(self, value): return cached_vt vt = self._wrap(value) - vt.source = self.source + + if vt.source is None: + vt.source = self.source + if ( self._can_lift_attrs_to_inputs(vt) and value not in self.tx.output.side_effects @@ -470,7 +480,7 @@ def _type_dispatch(cls): return cls._type_dispatch_impl(config.trace_numpy) @classmethod - @functools.lru_cache(None) + @functools.cache def _type_dispatch_impl(cls, trace_numpy): # NB: Careful not to close over self to avoid ref cycle from lru_cache entries = [ @@ -571,7 +581,7 @@ def build_key_value(k, v): return self.tx.output.side_effects.track_mutable(value, result) @classmethod - @functools.lru_cache(None) + @functools.cache def _id_dispatch( cls, ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: @@ -600,9 +610,16 @@ def _id_dispatch( def _wrap(self, value): # import here to avoid circular dependencies - from torch.utils._triton import has_triton, has_triton_tma + from torch.utils._triton import ( + has_triton, + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, + ) - from ..decorators import DynamoConfigPatchProxy + from ..decorators import ( + DynamoConfigPatchProxy, + SetFullgraphDecoratorContextManager, + ) if has_triton(): from triton.runtime.autotuner import Autotuner @@ -615,19 +632,26 @@ class JITFunction: class Autotuner: pass - if has_triton_tma(): - from triton.tools.experimental_descriptor import ( - create_1d_tma_descriptor, - create_2d_tma_descriptor, - ) - else: + # default implementations, in case we don't have triton (or the wrong triton version) + def create_1d_tma_descriptor(): + pass - def create_1d_tma_descriptor(): - pass + def create_2d_tma_descriptor(): + pass - def create_2d_tma_descriptor(): + class TensorDescriptor: + @staticmethod + def from_tensor(): pass + if has_triton_experimental_host_tma(): + from triton.tools.experimental_descriptor import ( # noqa: F811 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 + # Handle exact type() match type_dispatch = self._type_dispatch().get(type(value)) if type_dispatch is not None: @@ -695,7 +719,7 @@ def create_2d_tma_descriptor(): # 2) For non-constant objects, we also have to guard on the keys # (like TENSOR_MATCH on tensor). We might also have guards on # the attributes of the keys (like tensor.grad). To make this - # work in tree strucutre is complicated. + # work in tree structure is complicated. # # So, instead we guard on the key order. While guarding on key # order, we just save the indices and use it to access keys and @@ -750,6 +774,38 @@ def build_key_value(i, k, v): var = TorchFunctionModeVariable(value, source=self.source) self.tx.output.side_effects.track_object_existing(value, var) return var + elif istype(value, set): + if any(isinstance(x, torch.Tensor) for x in value): + unimplemented_v2( + gb_type="Attempted to wrap a set with tensors", + context="Python set containing torch.Tensor elements", + explanation=( + "Dynamo cannot trace sets of tensors. To get a stable ordering, " + "Dynamo needs to convert the set into a list and the order might not be " + "stable if the set contains tensors." + ), + hints=[ + "Use a dictionary where the keys are tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # The list gives a ordering for the set items. The ordering is based + # on the Python hash and it is not related to object ordering inside + # the set object. The order being incorrect at runtime will lead to + # a recompilation. + L = list(value) + items = [ + LazyVariableTracker.create( + v, source=NonSerializableSetGetItemSource(self.source, i) + ) + for i, v in enumerate(L) + ] + result = SetVariable(items, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) elif istype(value, frozenset) and all( ( # For DBR quantization, we could get a frozenset of torch funcs. @@ -923,6 +979,8 @@ def build_key_value(i, k, v): ) elif isinstance(value, DynamoConfigPatchProxy): return DynamoConfigPatchVariable(value.changes) + elif isinstance(value, SetFullgraphDecoratorContextManager): + return SetFullgraphVariable(value.fullgraph) elif callable(value) and trace_rules.lookup_callable(value) is not None: if trace_rules.is_callable_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True @@ -936,7 +994,7 @@ def build_key_value(i, k, v): unimplemented_v2( gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", context="", - explanation="Directly using invoke_subgraph is not supported. Use mark_compile_region", + explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", hints=[], ) self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH) @@ -1032,7 +1090,7 @@ def build_key_value(i, k, v): return ItertoolsVariable(value, source=self.source) elif is_torch_sym(value): # Note: this doesn't handle nested symints. - # For SymBool input, we re-use the infra for SymInt by simulating SymBool with a SymInt in dynamo. + # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo. # Concretely, # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). @@ -1105,9 +1163,11 @@ def build_key_value(i, k, v): source=self.source, ) elif value is create_1d_tma_descriptor: - return CreateTMADescriptorVariable(rank=1) + return CreateTMADescriptorExperimentalVariable(rank=1) elif value is create_2d_tma_descriptor: - return CreateTMADescriptorVariable(rank=2) + return CreateTMADescriptorExperimentalVariable(rank=2) + elif value is TensorDescriptor.from_tensor: + return CreateTMADescriptorStableVariable() elif isinstance(value, torch.amp.autocast_mode.autocast): self.install_guards(GuardBuilder.ID_MATCH) return AutocastModeVariable( @@ -1277,7 +1337,7 @@ def build_key_value(i, k, v): ) # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default - # seting example to be real value because these example values will be used + # setting example to be real value because these example values will be used # as example_inputs for user compiler. proxy.node.meta["grapharg"] = GraphArg( self.source, value, False, None, False, value @@ -1322,7 +1382,7 @@ def build_key_value(i, k, v): ) # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default - # seting example to be real value because these example values will be used + # setting example to be real value because these example values will be used # as example_inputs for user compiler. proxy.node.meta["grapharg"] = GraphArg( self.source, value, False, None, False, fake_script_obj @@ -1350,7 +1410,7 @@ def build_key_value(i, k, v): source_key = ConstDictKeySource(base, i) key = LazyVariableTracker.create(k, source_key) - source_value = DictGetItemSource(base, source_key) + source_value = DictSubclassGetItemSource(base, source_key) res_value = LazyVariableTracker.create(v, source_value) return key, res_value @@ -1380,7 +1440,7 @@ def build_key_value(i, k, v): result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) - elif isinstance(value, tuple) and type(value).__new__ is tuple.__new__: + elif isinstance(value, tuple): self.install_guards(GuardBuilder.TYPE_MATCH) self.install_guards(GuardBuilder.SEQUENCE_LENGTH) @@ -1397,7 +1457,7 @@ def build_key_value(i, k, v): tuple_vt = TupleVariable( output, source=self.source, mutation_type=ValueMutationExisting() ) - result = UserDefinedTupleVariable.create( + result = UserDefinedTupleVariable( value, tuple_vt=tuple_vt, source=self.source ) return self.tx.output.side_effects.track_object_existing(value, result) @@ -1419,6 +1479,24 @@ def build_key_value(i, k, v): ) result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, (set, frozenset)): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + L = list(dict.fromkeys(value)) + output = [ + LazyVariableTracker.create( + list.__getitem__(L, i), + source=NonSerializableSetGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(L)) + ] + set_vt_cls = SetVariable if isinstance(value, set) else FrozensetVariable + set_vt = set_vt_cls( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) elif issubclass(type(value), MutableMapping): self.install_guards(GuardBuilder.TYPE_MATCH) return MutableMappingVariable(value, source=self.source) @@ -1714,7 +1792,6 @@ def wrap_module(self, value: torch.nn.Module): value = value.get_base() self.source = AttrProxySource(self.source) - self.install_guards(GuardBuilder.TYPE_MATCH) if torch._dynamo.config.inline_inbuilt_nn_modules: freezing = is_parameter_freezing() @@ -1749,12 +1826,27 @@ def wrap_module(self, value: torch.nn.Module): # this will get cleaned up once compile ends self.tx.output.nn_modules[self.name] = value - if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr( - value.__class__, "_dynamo_marked_static", False - ): - result = UnspecializedBuiltinNNModuleVariable(value, source=self.source) + if ( + value.__module__.startswith(("torch.nn.modules", "torch.ao.")) + and not value.__module__.startswith("torch.nn.modules.container") + ) or getattr(value.__class__, "_dynamo_marked_static", False): + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedBuiltinNNModuleSource(self.source) + result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) else: - result = UnspecializedNNModuleVariable(value, source=self.source) + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedNNModuleSource(self.source) + result = UnspecializedNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) if not SideEffects.cls_supports_mutation_side_effects(type(value)): # don't allow STORE_ATTR mutation with custom __setattr__ @@ -1788,7 +1880,28 @@ def wrap_literal(self, value): # unspecializing int by default, but still # specialize for the following conditions if is_int_specialization_case(value, self.source): - self.install_guards(GuardBuilder.CONSTANT_MATCH) + recompile_hint = None + if ( + self.source.guard_source().is_unspecialized_builtin_nn_module() + or self.source.guard_source().is_unspecialized_nn_module() + ): + # This means that it is an integer from a NN module. + # Dynamo considers nn module int attributes to be static + # (a good heuristic). But a user might want to mark the + # int attribute to be a symint, so track this integer + # for recompilation later. + recompile_hint = ( + "torch.compile considers integer attributes of the nn.Module to be static. " + "If you are observing recompilation, you might want to make this integer dynamic " + "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this " + "integer into a tensor." + ) + + self.install_guards( + functools.partial( + GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint + ) + ) return ConstantVariable.create(value=value, source=self.source) return self.wrap_symint(value) @@ -1943,7 +2056,7 @@ def wrap_tensor(self, value: torch.Tensor): ): # A hot fix for sparse tensors + torch.compile. Support for # export + sparsity is being added but we need to create - # SPARSE_TENSOR_GUARDS for guards to work propertly. + # SPARSE_TENSOR_GUARDS for guards to work properly. unimplemented_v2( gb_type="Attempted to wrap sparse Tensor", context="", @@ -2127,6 +2240,10 @@ def wrap_numpy_ndarray(self, value): ) proxy.node.meta["grapharg"] = grapharg + # TODO - Why do we need to set the source of the np ndarray vt back to + # original source. Many tests fails. + numpy_ndarray_variable.source = self.source + return numpy_ndarray_variable def wrap_symint( @@ -3134,6 +3251,7 @@ def update_dim2constraint(dim, constraint_range, name): dynamic_strides = [] constraint_sizes = [] constraint_strides = [] + specialize_on = [] for i in range(e.dim()): # NB: mark dynamic has precedence over static marked_strict_unbacked = i in getattr( @@ -3144,6 +3262,8 @@ def update_dim2constraint(dim, constraint_range, name): marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) marked_static = i in getattr(e, "_dynamo_static_indices", set()) + specialize_on.append(getattr(e, "_specialize_on", {}).get(i, [])) + # Reflect the user directive in the frame_state # For dynamic, apply None always @@ -3271,6 +3391,7 @@ def update_dim2constraint(dim, constraint_range, name): dynamic_strides=dynamic_strides, constraint_sizes=constraint_sizes, constraint_strides=constraint_strides, + specialize_on=specialize_on, view_base_context=view_base_context, tensor_source=source, shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 1356682ec37f8b..ec2603288b5070 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1,5 +1,26 @@ # mypy: allow-untyped-defs +""" +Built-in function and type variable tracking for TorchDynamo's symbolic execution. + +This module contains variable tracker classes for Python built-in functions, types, +and operations during graph compilation. It handles symbolic execution of: + +- Built-in functions (len, getattr, isinstance, etc.) +- Type constructors (int, float, str, list, dict, etc.) +- Built-in operators and methods +- Special Python constructs (super, hasattr, etc.) + +Key classes: +- BuiltinVariable: Tracks built-in functions and handles their execution +- TypeVariable: Manages type constructor calls and type checking +- SuperVariable: Handles super() calls in class hierarchies + +These variable trackers ensure that built-in Python operations are correctly +handled during symbolic execution, either by executing them directly when safe +or by creating appropriate graph nodes when needed. +""" + import contextlib import functools import inspect @@ -47,6 +68,7 @@ cmp_name_to_op_mapping, dict_methods, extract_fake_example_value, + frozenset_methods, get_fake_value, guard_if_dyn, is_tensor_getset_descriptor, @@ -54,6 +76,7 @@ istype, numpy_operator_wrapper, proxy_args_kwargs, + set_methods, str_methods, tensortype_to_dtype, ) @@ -63,6 +86,7 @@ from .dicts import ( ConstDictVariable, DefaultDictVariable, + DictKeysVariable, DictViewVariable, FrozensetVariable, is_hashable, @@ -83,7 +107,11 @@ TensorVariable, UnspecializedPythonVariable, ) -from .user_defined import UserDefinedObjectVariable, UserDefinedVariable +from .user_defined import ( + UserDefinedObjectVariable, + UserDefinedSetVariable, + UserDefinedVariable, +) if TYPE_CHECKING: @@ -148,7 +176,7 @@ def create_with_source(cls, value, source): return cls(value, source=source) @staticmethod - @functools.lru_cache(None) + @functools.cache def _constant_fold_functions(): fns = { abs, @@ -218,7 +246,7 @@ def can_constant_fold_through(self): return self.fn in self._constant_fold_functions() @staticmethod - @functools.lru_cache(None) + @functools.cache def _fx_graph_functions(): fns = { operator.abs, @@ -264,7 +292,7 @@ def _fx_graph_functions(): return fns @staticmethod - @functools.lru_cache(None) + @functools.cache def _binops() -> dict[ Callable[..., object], tuple[list[str], Callable[..., object]] ]: @@ -303,7 +331,7 @@ def _binops() -> dict[ return fns @staticmethod - @functools.lru_cache(None) + @functools.cache def _binop_handlers(): # Multiple dispatch mechanism defining custom binop behavior for certain type # combinations. Handlers are attempted in order, and will be used if the type checks @@ -439,6 +467,14 @@ def size_add_handler(tx: "InstructionTranslator", a, b): (SizeVariable, SizeVariable), size_add_handler, ), + ( + (SizeVariable, TupleVariable), + size_add_handler, + ), + ( + (TupleVariable, SizeVariable), + size_add_handler, + ), ( (TupleVariable, TupleVariable), tuple_add_handler, @@ -795,7 +831,7 @@ def _make_handler(fn, arg_types: list[type], has_kwargs: bool): if inspect.isclass(fn) and ( issubclass(fn, Exception) - # GeneratorExit doens't inherit from Exception + # GeneratorExit doesn't inherit from Exception # >>> issubclass(GeneratorExit, Exception) # False or fn is GeneratorExit @@ -1205,20 +1241,17 @@ def call_method( and args[1].has_unpack_var_sequence(tx) and not kwargs ): - init_args = args[1].unpack_var_sequence(tx) - tuple_vt = variables.TupleVariable( - init_args, mutation_type=ValueMutationNew() - ) if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: - return tuple_vt + init_args = args[1].unpack_var_sequence(tx) + return variables.TupleVariable( + init_args, mutation_type=ValueMutationNew() + ) - result = tx.output.side_effects.track_new_user_defined_object( + return tx.output.side_effects.track_new_user_defined_object( self, args[0], args[1:], ) - result.set_underlying_tuple_vt(tuple_vt) - return result if self.fn is list: list_vt = ListVariable([], mutation_type=ValueMutationNew()) @@ -1245,12 +1278,32 @@ def call_method( elif isinstance(args[0], variables.ConstDictVariable): return args[0].call_method(tx, name, args[1:], kwargs) + if self.fn is set: + resolved_fn = getattr(self.fn, name) + if resolved_fn in set_methods: + if isinstance(args[0], variables.UserDefinedSetVariable): + return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.SetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is frozenset: + resolved_fn = getattr(self.fn, name) + if resolved_fn in frozenset_methods: + if isinstance(args[0], variables.FrozensetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + if self.fn is str and len(args) >= 1: resolved_fn = getattr(self.fn, name) if resolved_fn in str_methods: if isinstance(args[0], ConstantVariable): return args[0].call_method(tx, name, args[1:], kwargs) + if self.fn is float and len(args) >= 1: + if isinstance(args[0], ConstantVariable): + return ConstantVariable.create( + getattr(float, name)(args[0].as_python_constant()) + ) + return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1277,6 +1330,24 @@ def _call_int_float(self, tx: "InstructionTranslator", arg): call_int = _call_int_float call_float = _call_int_float + def call_bool(self, tx: "InstructionTranslator", arg): + # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. + # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 + if isinstance(arg, SymNodeVariable): + # Note that we delay specializing on symbolic values to avoid + # unnecessary guards. Specialization will happen later if, e.g., the + # resulting boolean is used for branching. + if isinstance(arg.sym_num, torch.SymBool): + return arg + + # Emulate `nb_bool` of int/float objects + # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944 + # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882 + assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) + return SymNodeVariable.create(tx, arg.as_proxy() != 0) + + # TODO handle more cases and merge this with this with `generic_jump`. + def call_str(self, tx: "InstructionTranslator", arg): # Handle `str` on a user defined function or object if isinstance(arg, (variables.UserFunctionVariable)): @@ -1538,11 +1609,22 @@ def _call_iter_tuple_list( if ( getattr(obj, "source", False) and isinstance(obj, ConstDictVariable) - and not istype(obj, SetVariable) + and not istype(obj, (SetVariable, FrozensetVariable)) ): tx.output.guard_on_key_order.add(obj.source) - install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + if isinstance(obj, variables.MappingProxyVariable): + # This could be an overguarding, but its rare to iterate + # through a mapping proxy and not use the keys. + install_guard( + obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK) + ) + elif not isinstance(obj, variables.UnspecializedNNModuleVariable): + # Prevent calling __len__ method for guards, the tracing + # of __iter__ will insert the right guards later. + install_guard( + obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) return cls( list(obj.unpack_var_sequence(tx)), @@ -1698,7 +1780,7 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): ], ) arg = args[0] - if isinstance(arg, variables.SetVariable): + if istype(arg, variables.SetVariable): return arg.clone(mutation_type=ValueMutationNew()) elif arg.has_force_unpack_var_sequence(tx): items = arg.force_unpack_var_sequence(tx) @@ -1733,10 +1815,10 @@ def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): ], ) arg = args[0] - if isinstance(arg, variables.FrozensetVariable): + if istype(arg, variables.FrozensetVariable): return FrozensetVariable([x.vt for x in arg.set_items]) - elif arg.has_unpack_var_sequence(tx): - items = arg.unpack_var_sequence(tx) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) return FrozensetVariable(items) raise_observed_exception( TypeError, @@ -1934,6 +2016,10 @@ def call_getattr( name = name_var.as_python_constant() + # See NOTE [Tensor "grad" and "_grad" attr] + if isinstance(obj, TensorVariable) and name == "_grad": + name = "grad" + if tx.output.side_effects.is_attribute_mutation(obj): if isinstance(obj, variables.UnspecializedNNModuleVariable): if ( @@ -2009,7 +2095,6 @@ def call_getattr( "assertNotWarns", "assertWarnsRegex", "assertDictEqual", - "assertSequenceEqual", "assertWarns", ) ): @@ -2122,6 +2207,17 @@ def call_setattr( "the mutation out of `torch.compile` region", ], ) + elif obj.dtype != val.dtype: # type: ignore[attr-defined] + unimplemented_v2( + gb_type="Failed to mutate tensor data attribute to different dtype", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor to a new one with the same dtype", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause @@ -2168,11 +2264,12 @@ def _lower_version_count_by_1(x): # Step 4 - replace all reference to the current object with the new one return out elif name in ("_grad", "grad"): + # NOTE: [Tensor "grad" and "_grad" attr] # _grad and grad share the same setter/getter, see # THPVariable_properties, and here we make sure setting one - # enables reading `val` from the other. - tx.output.side_effects.store_attr(obj, "grad", val) - tx.output.side_effects.store_attr(obj, "_grad", val) + # enables reading `val` from the other, by routing all + # read/write to `grad`. + name = "grad" elif is_tensor_getset_descriptor(name): # Attribute like `torch.Tensor.real` has special setters we # don't yet support; it's not as simple adding an entry to @@ -2208,7 +2305,7 @@ def _lower_version_count_by_1(x): # get_fake_val will get the same fake tensor existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) - # same tensor identiy, setattr is a no-op + # same tensor identity, setattr is a no-op mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") if ( existing_fake_attr is assigning_fake_val @@ -2224,7 +2321,7 @@ def call_delattr( obj: VariableTracker, name_var: VariableTracker, ): - return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) + return obj.call_method(tx, "__delattr__", [name_var], {}) def call_type(self, tx: "InstructionTranslator", obj: VariableTracker): try: @@ -2412,6 +2509,22 @@ def _comparison_with_symnode(self, tx: "InstructionTranslator", left, right): sym_num=None, ) + def call_xor(self, tx: "InstructionTranslator", a, b): + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__xor__", [b], {}) + + def call_ixor(self, tx: "InstructionTranslator", a, b): + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__ixor__", [b], {}) + + def call_sub(self, tx: "InstructionTranslator", a, b): + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__sub__", [b], {}) + + def call_isub(self, tx: "InstructionTranslator", a, b): + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__isub__", [b], {}) + def call_and_(self, tx: "InstructionTranslator", a, b): # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): @@ -2426,11 +2539,26 @@ def call_and_(self, tx: "InstructionTranslator", a, b): ), sym_num=None, ) - if hasattr(a, "set_items") and hasattr(b, "set_items"): - return SetVariable(list(a.set_items & b.set_items)) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__and__", [b], {}) # None no-ops this handler and lets the driving function proceed - call_iand = call_and_ + def call_iand(self, tx: "InstructionTranslator", a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.iand, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedSetVariable)): + return a.call_method(tx, "__iand__", [b], {}) def call_or_(self, tx: "InstructionTranslator", a, b): # Rely on constant_handler @@ -2446,15 +2574,41 @@ def call_or_(self, tx: "InstructionTranslator", a, b): ), sym_num=None, ) - if hasattr(a, "set_items") and hasattr(b, "set_items"): - return SetVariable(list(a.set_items | b.set_items)) + # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. - if isinstance(a, ConstDictVariable): - return a.call_method(tx, "__or__", args=[b], kwargs={}) + if isinstance( + a, + (ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable), + ): + return a.call_method(tx, "__or__", [b], {}) + # None no-ops this handler and lets the driving function proceed return None - call_ior = call_or_ + def call_ior(self, tx: "InstructionTranslator", a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.ior, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + # This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`. + if isinstance( + a, + (ConstDictVariable, DictKeysVariable, SetVariable, UserDefinedSetVariable), + ): + return a.call_method(tx, "__ior__", [b], {}) + + # None no-ops this handler and lets the driving function proceed + return None def call_not_(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 8f3bd9d3965e51..ce375975bed445 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -173,7 +173,14 @@ def call_method( raise_observed_exception(type(e), tx) elif isinstance(self.value, (float, int)): if not (args or kwargs): - return ConstantVariable.create(getattr(self.value, name)()) + try: + return ConstantVariable.create(getattr(self.value, name)()) + except (OverflowError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) if ( hasattr(operator, name) and len(args) == 1 @@ -203,9 +210,14 @@ def call_method( if name == "__len__" and not (args or kwargs): return ConstantVariable.create(len(self.value)) elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): - return ConstantVariable.create( - round(self.value, args[0].as_python_constant()) - ) + try: + return ConstantVariable.create( + round(self.value, args[0].as_python_constant()) + ) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): assert not kwargs search = args[0].as_python_constant() diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 3ed8b1d1a9b1aa..ec5c881c4155ad 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -38,6 +38,7 @@ from ..exc import unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GlobalStateSource +from ..utils import _get_error_on_graph_break, _set_error_on_graph_break from .base import VariableTracker from .functions import ( NestedUserFunctionVariable, @@ -155,7 +156,7 @@ def cleanup_assert(self): class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are - # python contants. Which might not always be the case here. + # python constants. Which might not always be the case here. def __init__(self, cm_obj, **kwargs) -> None: assert cm_obj is not None super().__init__( @@ -196,8 +197,36 @@ def exit_on_graph_break(self): return True +class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): + def __init__(self, ctx_manager_vt, mod): + self.cm_vt = ctx_manager_vt + self.mod = mod + # We don't call super().__init__() because we're delegating most methods to cm_vt + + def enter(self, tx: "InstructionTranslator"): + # Custom enter implementation with side effects + + self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() + self.old_buffer_var = self.mod.var_getattr(tx, "_buffers").realize() + tx.output.side_effects.ignore_mutations_on(self.old_parameters_var) + tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) + return self.cm_vt.enter(tx) + + def exit(self, tx: "InstructionTranslator", *args): + # Custom exit implementation with side effects + x = self.cm_vt.exit(tx, *args) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_parameters_var) + return x + + # Forward all other method calls to self.cm_vt + def __getattr__(self, name): + # This will be called for any attribute not explicitly defined in this class + return getattr(self.cm_vt, name) + + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): - """represents torch grad requries grad""" + """represents torch grad requires grad""" @staticmethod def create(tx: "InstructionTranslator", target_values, **kwargs): @@ -1380,16 +1409,6 @@ def __init__(self, target_values, **kwargs) -> None: self.initial_values[key] = torch._dynamo.config.__getattr__(key) self.initial_values = (tuple(self.initial_values.items()),) - def enter(self, tx): - # resets all config patches at the end of tracing - self.set_cleanup_hook(tx) - self._call_func(tx, self.target_values) - return variables.ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - self._call_func(tx, self.initial_values) - return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 value = values[0] @@ -1408,6 +1427,27 @@ def fn_name(self): return "patch_dynamo_config" +class SetFullgraphVariable(ContextWrappingVariable): + """represents torch._dynamo.set_fullgraph""" + + def __init__(self, fullgraph, **kwargs) -> None: + super().__init__( + target_values=(fullgraph,), + initial_values=(_get_error_on_graph_break(),), + **kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + _set_error_on_graph_break(values[0]) + + def module_name(self): + return "torch._dynamo" + + def fn_name(self): + return "set_fullgraph" + + class WithExitFunctionVariable(VariableTracker): _nonvar_fields = { "target", diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 04bd651d886dfc..1f22ea9e2ea484 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -113,6 +113,7 @@ def is_hashable(x): variables.SymNodeVariable, variables.ConstantVariable, variables.EnumVariable, + variables.FrozensetVariable, variables.UserDefinedClassVariable, variables.UserFunctionVariable, variables.SkipFunctionVariable, @@ -129,6 +130,8 @@ def is_hashable(x): class ConstDictVariable(VariableTracker): + CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS + _nonvar_fields = { "user_cls", *VariableTracker._nonvar_fields, @@ -144,7 +147,7 @@ class _HashableTracker: def __init__(self, vt) -> None: # We specialize SymNodes vt = specialize_symnode(vt) - # TODO Temorarily remove to figure out what keys are we breaking on + # TODO Temporarily remove to figure out what keys are we breaking on # and add proper support for them if not is_hashable(vt): raise_unhashable(vt) @@ -302,18 +305,8 @@ def is_new_item(self, value, other): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct(self, codegen: "PyCodegen"): - # instructions to load collections.OrderedDict if necessary - if self.user_cls is collections.OrderedDict: - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(collections), - codegen.create_load_attr("OrderedDict"), - ] - ) - ) - # instructions to build the dict keys and values + def reconstruct_kvs_into_new_dict(self, codegen): + # Build a dictionary that contains the keys and values. num_args = 0 for key, value in self.items.items(): # We can safely call realize() here as it won't introduce any new guards @@ -322,18 +315,23 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(key.vt) codegen(value) num_args += 1 + codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) - # BUILD_MAP and calling collections.OrderedDict if necessary + def reconstruct(self, codegen: "PyCodegen"): if self.user_cls is collections.OrderedDict: - codegen.extend_output( - [ - create_instruction("BUILD_MAP", arg=num_args), - *create_call_function(1, False), - ] + # emit `OrderedDict(constructed_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("OrderedDict"), + ] + ) ) - # BUILD_MAP only if user_cls is dict + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(1, False)) else: - codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) + self.reconstruct_kvs_into_new_dict(codegen) def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker @@ -404,7 +402,7 @@ def install_dict_contains_guard(self, tx, args): install_guard( self.make_guard( functools.partial( - GuardBuilder.DICT_CONTAINS, + type(self).CONTAINS_GUARD, key=args[0].value, invert=not contains, ) @@ -467,7 +465,8 @@ def call_method( return DictValuesVariable(self) elif name == "copy": self.install_dict_keys_match_guard() - assert not (args or kwargs) + if args or kwargs: + raise_args_mismatch(tx, name) return self.clone( items=self.items.copy(), mutation_type=ValueMutationNew(), source=None ) @@ -571,10 +570,24 @@ def call_method( return ConstantVariable.create(None) elif name == "__or__": assert len(args) == 1 - if not isinstance(args[0], ConstDictVariable): - raise TypeError( - f"unsupported operand type(s) for |: 'dict' and '{args[0].python_type().__name__}'" + # Dicts can only be unioned with other dicts or subclasses of dicts. + # Sets can be unioned with other sets, frozensets or subclasses of sets. + _raise = not ( + (istype(self, ConstDictVariable) and istype(args[0], ConstDictVariable)) + or ( + isinstance(self, SetVariable) + and isinstance( + args[0], (SetVariable, variables.UserDefinedSetVariable) + ) + ) + ) + + if _raise: + msg = ( + f"unsupported operand type(s) for |: '{self.python_type().__name__}'" + f"and '{args[0].python_type().__name__}'" ) + raise_observed_exception(TypeError, tx, args=[msg]) self.install_dict_keys_match_guard() new_dict_vt = self.clone( @@ -586,6 +599,9 @@ def call_method( args[0].install_dict_keys_match_guard() new_dict_vt.items.update(args[0].items) return new_dict_vt + elif name == "__ior__": + self.call_method(tx, "update", args, kwargs) + return self else: return super().call_method(tx, name, args, kwargs) @@ -624,6 +640,9 @@ def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict + def python_type(self): + return types.MappingProxyType + def unpack_var_sequence(self, tx): return self.dv_dict.unpack_var_sequence(tx) @@ -742,12 +761,28 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) + def reconstruct(self, codegen): + # emit `defaultdict(default_factory, new_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("defaultdict"), + ] + ) + ) + codegen(self.default_factory) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(2, False)) + # TODO: Implementing this via inheritance rather than composition is a # footgun, because self method calls in dict will route back to the set # implementation, which is almost assuredly wrong class SetVariable(ConstDictVariable): - """We model a sets as dictonary with None values""" + """We model a sets as dictionary with None values""" + + CONTAINS_GUARD = GuardBuilder.SET_CONTAINS def __init__( self, @@ -769,7 +804,7 @@ def set_items(self): @staticmethod def _default_value(): - # Variable to fill in he keys of the dictinary + # Variable to fill in he keys of the dictionary return ConstantVariable.create(None) def as_proxy(self): @@ -818,8 +853,9 @@ def call_method( super().call_method(tx, name, (result,), kwargs) return result elif name == "isdisjoint": + if len(args) != 1: + raise_args_mismatch(tx, name) assert not kwargs - assert len(args) == 1 return variables.UserFunctionVariable( polyfills.set_isdisjoint ).call_function(tx, [self, args[0]], {}) @@ -881,6 +917,9 @@ def call_method( else: return ConstantVariable.create(value=None) elif name in ("issubset", "issuperset"): + if len(args) != 1: + raise_args_mismatch(tx, name) + op = { "issubset": operator.le, "issuperset": operator.ge, @@ -891,6 +930,44 @@ def call_method( return variables.BuiltinVariable(op.get(name)).call_function( tx, [self, other], {} ) + elif name in ("__and__", "__or__", "__xor__", "__sub__"): + m = { + "__and__": "intersection", + "__or__": "union", + "__xor__": "symmetric_difference", + "__sub__": "difference", + }.get(name) + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + return self.call_method(tx, m, args, kwargs) + elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + m = { + "__iand__": "intersection_update", + "__ior__": "update", + "__ixor__": "symmetric_difference_update", + "__isub__": "difference_update", + }.get(name) + self.call_method(tx, m, args, kwargs) + return self + elif name == "__eq__": + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(False) + r = self.call_method(tx, "symmetric_difference", args, kwargs) + return ConstantVariable.create(len(r.set_items) == 0) + elif name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + ) return super().call_method(tx, name, args, kwargs) def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): @@ -901,8 +978,7 @@ def install_dict_keys_match_guard(self): pass def install_dict_contains_guard(self, tx, args): - # Already EQUALS_MATCH guarded - pass + super().install_dict_contains_guard(tx, args) class FrozensetVariable(SetVariable): @@ -927,7 +1003,7 @@ def python_type(self): return frozenset def as_python_constant(self): - return {k.vt.as_python_constant() for k in self.set_items} + return frozenset({k.vt.as_python_constant() for k in self.set_items}) def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) @@ -958,6 +1034,14 @@ def call_method( # In[3]: s # frozenset({1, 2}) return ConstantVariable.create(None) + elif name in ( + "copy", + "difference", + "intersection", + "symmetric_difference", + ): + r = super().call_method(tx, name, args, kwargs) + return FrozensetVariable(r.items) return super().call_method(tx, name, args, kwargs) @@ -979,6 +1063,14 @@ def debug_repr(self): + "])" ) + def install_dict_keys_match_guard(self): + # Already EQUALS_MATCH guarded + pass + + def install_dict_contains_guard(self, tx, args): + # Already EQUALS_MATCH guarded + pass + @property def set_items(self): return self.items @@ -1072,6 +1164,20 @@ def call_method( ) -> "VariableTracker": if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) + elif name in ( + "__and__", + "__iand__", + "__or__", + "__ior__", + "__sub__", + "__isub__", + "__xor__", + "__ixor__", + ): + # These methods always returns a set + m = getattr(self.set_items, name) + r = m(args[0].set_items) + return SetVariable(r) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 6742254f266f09..39320c423e4e38 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -49,7 +49,7 @@ class DistributedVariable(VariableTracker): Concrete distributed objects could inherit this class and add object specific logic. - i.e. It provides the check on the distributed package existance + i.e. It provides the check on the distributed package existence and hold the tracking value for the corresponding distributed object. """ @@ -59,7 +59,7 @@ def __init__(self, value, **kwargs) -> None: unimplemented_v2( gb_type="torch.distributed package is not available!", context="", - explanation="The PyTorch package doesn't include torch.distributed when builing from source.", + explanation="The PyTorch package doesn't include torch.distributed when building from source.", hints=[ "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." ], diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 6ac803997cac8d..c4996e75182101 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -27,9 +27,10 @@ import functools import inspect import itertools +import logging import sys +import traceback import types -import warnings from collections.abc import Sequence from types import FunctionType from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar @@ -38,6 +39,7 @@ from weakref import WeakKeyDictionary import torch +from torch._dynamo.exc import get_stack_above_dynamo from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -446,7 +448,6 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": # Handle patch_dynamo_config call - if self.fn is torch._dynamo.patch_dynamo_config: try: args_const = [arg.as_python_constant() for arg in args] @@ -463,8 +464,19 @@ def call_function( "Please fix your call to patch_dynamo_config by using simpler inputs. " f"args: {args}, kwargs: {kwargs}" ) from e + elif self.fn is torch._dynamo.set_fullgraph: + try: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + fullgraph = bound.arguments["fullgraph"].as_python_constant() + assert isinstance(fullgraph, bool) + return variables.SetFullgraphVariable(fullgraph) + except Exception as e: + raise RuntimeError( + "Improper set_fullgraph() call. Please fix your call to set_fullgraph(). " + f"args: {args}, kwargs: {kwargs}" + ) from e # Handle a `nonstrict_trace(fn)` call - if self.fn is torch._dynamo.nonstrict_trace: + elif self.fn is torch._dynamo.nonstrict_trace: bound = inspect.signature(self.fn).bind(*args, **kwargs) fn_var = bound.args[0] if not isinstance(fn_var, BaseUserFunctionVariable): @@ -499,6 +511,17 @@ def call_function( return invoke_and_store_as_constant( tx, self.fn, self.get_name(), args, kwargs ) + + if ( + not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + and self.fn + is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer + ): + with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer( + tx + ): + return super().call_function(tx, args, kwargs) + if ( tx.output.current_tracer.under_activation_checkpoint and not tx.output.current_tracer.allow_side_effects_under_checkpoint @@ -862,6 +885,12 @@ def call_method( else: raise_observed_exception(RuntimeError, tracer) return retval + elif name == "__contains__": + # The generator needs to be lazily consumed here to avoid unintended + # side effects + return variables.UserFunctionVariable( + polyfills.generator___contains__ + ).call_function(tx, [self, *args], {}) super().call_method(tx, name, args, kwargs) @@ -919,7 +948,17 @@ def call_function( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - assert is_generator(self.vt.get_code()) + if not is_generator(self.vt.get_code()): + unimplemented_v2( + gb_type="non-generator contextlib.contextmanager", + context=str(self.vt.get_code()), + explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" + ", i.e. does not use `yield`", + hints=[ + "Use `yield` in the function body instead of `return`.", + "Remove the `@contextlib.contextmanager` decorator.", + ], + ) inline_tracer = self._build_inline_tracer(tx, args, kwargs) code = self.vt.get_code() @@ -1537,6 +1576,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) + def self_args(self): + return [] + def call_function( self, tx: "InstructionTranslator", @@ -1544,18 +1586,55 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if hasattr(self.wrapper_obj, "cache_info"): - warnings.warn( - "Dynamo detected a call to a `functools.lru_cache` wrapped function." - "Dynamo currently ignores `functools.lru_cache` and directly traces the wrapped function." - "`functools.lru_cache` wrapped functions that read outside state may not be traced soundly." - ) + target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) + module_name = getattr(target_fn, "__module__", "") or "" + + if module_name.split(".", maxsplit=1)[0] != "torch": + msg = ( + "Dynamo detected a call to a `functools.lru_cache`-wrapped " + "function. Dynamo ignores the cache wrapper and directly " + "traces the wrapped function. Silent incorrectness is only " + "a *potential* risk, not something we have observed. " + 'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.' + ) + + torch._dynamo.utils.warn_once(msg) + + dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo") + if dynamo_logger.isEnabledFor(logging.DEBUG): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack = get_stack_above_dynamo() + user_stack + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n" + user_stack_trace += str(user_stack_formatted) + dynamo_logger.debug(user_stack_trace) + + all_args = self.self_args() + args return variables.UserFunctionVariable( polyfills.getattr_and_trace ).call_function( - tx, [self, variables.ConstantVariable(self.attr_to_trace), *args], kwargs + tx, + [self, variables.ConstantVariable(self.attr_to_trace), *all_args], + kwargs, ) +class WrapperUserMethodVariable(WrapperUserFunctionVariable): + """ + Similar to WrapperUserFunctionVariable, but for methods. The only delta is + saving the vt for `self` object of the method which is then used by + WrapperUserFunctionVariable in `call_function` method. + """ + + def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None: + super().__init__(wrapper_obj, attr_to_trace, **kwargs) + self.obj = self_obj + + def self_args(self): + return [self.obj] + + def _traceable_collective_remaps(): # We can't rely on importing from distributed, since it's not always built if torch.distributed.is_available(): @@ -1805,7 +1884,7 @@ class PolyfilledFunctionVariable(VariableTracker): } @classmethod - @functools.lru_cache(None) + @functools.cache def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} @@ -1964,6 +2043,8 @@ def call_function(self, tx, args, kwargs): from torch._higher_order_ops.triton_kernel_wrap import ( + create_tma_experimental_metadata, + create_tma_stable_metadata, TMADescriptorMetadata, TritonHOPifier, ) @@ -2059,16 +2140,19 @@ def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: from .dicts import ConstDictVariable # as we can only pass tensors as non-const args in fx graph, - # here we replace TMA descriptors (TMADescriptorVariable + # here we replace TMA descriptors + # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable # instances) with the underlying tensors, while moving the # TMA descriptor-related metadata to a separate argument, # so that we can reconstruct the TMA descriptors downstream tma_descriptor_metadata: TMADescriptorMetadata = {} for k in list(combined_args_raw.keys()): v = combined_args_raw[k] - if isinstance(v, TMADescriptorVariable): + if isinstance( + v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable) + ): tma_descriptor_metadata[k] = v.to_metadata() - combined_args_raw[k] = v.data_ptr.from_tensor + combined_args_raw[k] = v.get_tensor() combined_args = { variables.ConstantVariable.create(k): v @@ -2170,7 +2254,7 @@ def specialize_symbolic(self, arg: Any) -> Any: return arg -class TMADescriptorVariable(VariableTracker): +class TMADescriptorExperimentalVariable(VariableTracker): def __init__( self, data_ptr: "variables.DataPtrVariable", @@ -2187,7 +2271,7 @@ def __init__( self.element_size = element_size def to_metadata(self): - return ( + return create_tma_experimental_metadata( [dim.as_proxy() for dim in self.dims], [dim.as_proxy() for dim in self.block_dims], self.element_size.as_proxy(), @@ -2205,8 +2289,44 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.foreach(args) codegen.call_function(len(args) + 1, False) + def get_tensor(self): + return self.data_ptr.from_tensor + -class CreateTMADescriptorVariable(VariableTracker): +class TMADescriptorStableVariable(VariableTracker): + def __init__( + self, + tensor: "variables.TensorVariable", + block_shape: "variables.ListVariable", + **kwargs, + ): + assert isinstance(tensor, variables.TensorVariable) + super().__init__(**kwargs) + self.tensor = tensor + self.block_shape = block_shape + + def to_metadata(self): + return create_tma_stable_metadata( + self.block_shape.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.tensor_descriptor", + "TensorDescriptor", + ) + ) + codegen.load_method("from_tensor") + self.tensor.reconstruct(codegen) + codegen(self.block_shape) + codegen.call_method(2) + + def get_tensor(self) -> "variables.TensorVariable": + return self.tensor + + +class CreateTMADescriptorExperimentalVariable(VariableTracker): def __init__( self, rank: int, @@ -2251,9 +2371,25 @@ def call_function( ] element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] - return TMADescriptorVariable( + return TMADescriptorExperimentalVariable( data_ptr=ptr, dims=dims, block_dims=block_dims, element_size=element_size, ) + + +class CreateTMADescriptorStableVariable(VariableTracker): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] + block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] + + return TMADescriptorStableVariable( + tensor=tensor, + block_shape=block_shape, + ) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index a397f6084c2871..82dd2eb4caea70 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -35,6 +35,7 @@ from torch._dynamo.utils import get_fake_value from torch._dynamo.variables.builtin import BuiltinVariable from torch._dynamo.variables.constant import ConstantVariable +from torch._dynamo.variables.ctx_manager import RepararametrizeModuleContextVariable from torch._dynamo.variables.functions import UserFunctionVariable from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable from torch._dynamo.variables.tensor import SymNodeVariable @@ -65,6 +66,7 @@ log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") def raise_hard_error_if_graph_break(reason): @@ -261,7 +263,7 @@ def _check_supported_callable_arg( ) -def are_same_graph_modules(a_mod, b_mod, fake_mode): +def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode): from torch._subclasses._fake_tensor_utils import _CacheKeyState from torch._subclasses.fake_tensor import extract_tensor_metadata @@ -322,7 +324,11 @@ def check_all_args(a_nodes, b_nodes): a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) if not check_all_args(a_flat, b_flat): - # print("call_function args failed") + hc_log.debug( + "%s: Graph comparison failed at node (call_function): %s", + fn_name, + a_node, + ) return False elif a_node.op == "call_method": if a_node.target != b_node.target: @@ -330,13 +336,17 @@ def check_all_args(a_nodes, b_nodes): a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) if not check_all_args(a_flat, b_flat): - # print("call_method args failed") + hc_log.debug( + "%s: Graph comparison failed at node (call_method) : %s", + fn_name, + a_node, + ) return False elif a_node.op == "output": a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) if not check_all_args(a_flat, b_flat): - # print("output args failed") + hc_log.debug("%s: Graph comparison failed at the output node", fn_name) return False elif a_node.op == "get_attr": a_attr = getattr(a_mod, a_node.target) @@ -345,7 +355,7 @@ def check_all_args(a_nodes, b_nodes): if not isinstance(b_attr, torch.fx.GraphModule): return False # This is an example of a HOP inside a HOP - if not are_same_graph_modules(a_attr, b_attr, fake_mode): + if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode): return False else: # TODO - write an example with tensor as a graph attribute in @@ -500,7 +510,7 @@ def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars): # # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But # true_branch and false_branch belong to two separate tracing contexts, they may register the same - # attribute to top level seperately. This creates two get_attr proxies for the same attribute + # attribute to top level separately. This creates two get_attr proxies for the same attribute # that have different meta data such as stack_trace (one stack trace for the true_branch, # and the other for false_branch). It seems better to discard the proxy explicitly in cond # than make dynamo create a single proxy for the same get_attr target. @@ -567,11 +577,12 @@ def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): def _insert_or_replace_phs(new_args, name_suffix): for arg in new_args: new_ph = graph.placeholder(arg.node.name + name_suffix) + new_ph.meta = arg.node.meta # Override with new_ph if there exists a old placeholder. if arg in lifted_freevars: old_ph = lifted_freevars[arg].node old_ph.replace_all_uses_with(new_ph) - # replace_all_uses_with doesn't clean users. Clean it mannually so that we could erase it. + # replace_all_uses_with doesn't clean users. Clean it manually so that we could erase it. old_ph.users = {} graph.erase_node(old_ph) @@ -743,8 +754,8 @@ def speculate_subgraph( # NOTE: [HigherOrderOperator subgraph input ordering] # The input ordering of the higher order ops is determined by the order of - # the creatation of the placehoder. - # Mannually created inputs are created in validate_args_and_maybe_create_graph_inputs before + # the creation of the placeholder. + # Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before # speculating subgraph. # During subgraph speculation, we may lift closured tensors and free symbols as inputs, # their ordering is determined by the time they are lifted: earlier lifted ones precede later @@ -828,6 +839,7 @@ def move_lifted_freevars_phs_to_end( context=context, explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}", hints=[ + "Replace `return input` with `return input.clone()` to avoid aliasing.", "Consider using the debug context to change user code to avoid aliasing.", "Please open an issue.", ], @@ -874,64 +886,15 @@ def __init__( @staticmethod def make(value, source=None, **kwargs): + variable_class = _hop_name_to_variable_class.get(value.__name__) + if variable_class is not None: + return variable_class(value, source, **kwargs) + from torch._higher_order_ops import BaseHOP - if value.__name__ == "cond": - return CondHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "while_loop": - return WhileLoopHigherOrderVariable(value, source, **kwargs) - elif value.__name__ in ("map", "map_impl"): - return MapHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "executorch_call_delegate": - return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "out_dtype": - return OutDtypeHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "wrap": - return WrapHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "hints_wrapper": - return HintsWrapperHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "flex_attention": - return FlexAttentionHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "flex_attention_backward": - return FlexAttentionBackwardHighOrderVariable(value, source, **kwargs) - elif value.__name__ in ( - "wrap_activation_checkpoint", - "tag_activation_checkpoint", - ): - return CheckpointHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "_export_tracepoint": - return ExportTracepointHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "trace_wrapped": - return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs) - elif value.__name__ == "strict_mode": - return StrictModeHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "run_with_rng_state": - return RunWithRNGStateHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "associative_scan": - return AssociativeScanHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "scan": - return ScanHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "call_torchbind": - return CallTorchbindHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "wrap_with_set_grad_enabled": - return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "wrap_with_autocast": - return WrapWithAutocastHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "dynamo_bypassing_wrapper": - return DynamoBypassingWrapperHigherOrderVariable(value, source, **kwargs) - elif ( - value.__name__ == "auto_functionalized" - or value.__name__ == "auto_functionalized_v2" - ): - return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) - elif value.__name__ == "invoke_subgraph": - return InvokeSubgraphHigherOrderVariable(value, source, **kwargs) - elif isinstance(value, BaseHOP): + if isinstance(value, BaseHOP): return BaseHOPVariable(value, source, **kwargs) - elif value.__name__ == "custom_function_call": - return CustomFunctionHigherOrderOperatorVariable(value, source, **kwargs) - else: - unimplemented(f"HigherOrderOperator {value.__name__}") + unimplemented(f"HigherOrderOperator {value.__name__}") def call_function( self, @@ -941,6 +904,9 @@ def call_function( ) -> VariableTracker: unimplemented(f"HigherOrderOperator {self.value.__name__}") + def as_python_constant(self): + return self.value + class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ @@ -1391,7 +1357,7 @@ def create_unbacked_sym_node_var(tx) -> SymNodeVariable: ) # Note: cond_shared and body_shared refer to the same proxy in parent graph - # so using either of them is OK. Use cond_shared as it doesnt matter. + # so using either of them is OK. Use cond_shared as it doesn't matter. additional_lifted_inputs = cond_shared + cond_unique + body_unique body_nn_modules = dict(tx.output.nn_modules) @@ -1933,6 +1899,17 @@ def call_function( supports_aliasing=self.supports_aliasing, ) + # Check all outputs of map are tensors. + # For map, outputting None is OK, thus ignore None values in the check + body_r_vars = body_r.unpack_var_sequence(tx) + none_mask = [ + type(x.realize()) is ConstantVariable and x.as_python_constant() is None + for x in body_r_vars + ] + _check_all_tensorvariable( + [br for bm, br in zip(none_mask, body_r_vars) if not bm] + ) + body_nn_modules = dict(tx.output.nn_modules) body_name = tx.output.install_subgraph( @@ -2028,6 +2005,17 @@ def call_function( return super().call_function(tx, args, kwargs) +class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + ctx_manager_vt = super().call_function(tx, args, kwargs) + return RepararametrizeModuleContextVariable(ctx_manager_vt, args[0]) + + class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): supports_input_mutation = True supports_aliasing = True @@ -2658,8 +2646,8 @@ def call_function( class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): def proxy_submod(self, tx, arg): - assert isinstance(arg.source, DictGetItemSource) - submod_name = tx.output.install_subgraph(arg.source.index, arg.value) + assert isinstance(arg.source.base, DictGetItemSource) + submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value) p_submod = make_attr(tx, submod_name) set_example_value(p_submod.node, arg.value) return p_submod @@ -2803,8 +2791,6 @@ def call_function( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - from torch._higher_order_ops.flex_attention import flex_attention_fake_impl - from .builder import wrap_fx_proxy ( @@ -2841,12 +2827,6 @@ def call_function( # Proxying user defined functions is not supported. inp_args, _ = proxy_args_kwargs(proxied_args, {}) - query_meta = query.as_proxy().node.meta["example_value"] - value_meta = value.as_proxy().node.meta["example_value"] - with torch._guards.TracingContext.try_get().fake_mode: - out_meta, lse_meta = flex_attention_fake_impl(query_meta, value_meta) - example_value = (out_meta, lse_meta) - # Compose the ordered HOO args: # - inp_args: [query, key, value, block_mask, scale, kernel_options] # - subgraph node: [score_mod, mask_fn_node] @@ -2869,7 +2849,7 @@ def call_function( ), kwargs={}, ), - example_value=example_value, + example_value=None, ) @@ -3214,7 +3194,7 @@ def unwrap_proxy(x): # Store the invocation as a call from torch._functorch.autograd_function import autograd_function_apply - # We use speculate_subgraph to get the fwd graph, but it's alway under no grad mode like what eager mode does. + # We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does. # The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes # (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing. # Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it. @@ -3318,7 +3298,7 @@ def call_function( class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): - supports_input_mutation = False + supports_input_mutation = True supports_aliasing = False def install_subgraph_in_output_graph( @@ -3334,7 +3314,7 @@ def install_subgraph_in_output_graph( gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", context=str(fn_vt), explanation="invoke_subgraph does not support non user function variable", - hints=graph_break_hints.SUPPORTABLE, + hints=[*graph_break_hints.SUPPORTABLE], ) invoke_subgraph_cache = ( @@ -3345,9 +3325,11 @@ def install_subgraph_in_output_graph( if isinstance(fn_vt, UserFunctionVariable): fn_id = id(fn_vt.get_function()) + fn_name = fn_vt.get_function().__name__ else: assert isinstance(fn_vt, UnspecializedNNModuleVariable) fn_id = id(fn_vt.value.forward.__func__) + fn_name = fn_vt.value.forward.__name__ previously_installed_submodules = [] if invoke_subgraph_cache: previously_installed_submodules = ( @@ -3359,17 +3341,29 @@ def install_subgraph_in_output_graph( for submodule_name in reversed(previously_installed_submodules): assert submodule_name in tx.output.nn_modules previous_mod = tx.output.nn_modules[submodule_name] - if are_same_graph_modules(previous_mod, current_mod, tx.fake_mode): + if are_same_graph_modules( + fn_name, previous_mod, current_mod, tx.fake_mode + ): return submodule_name body_name = super().install_subgraph_in_output_graph( tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" ) + hc_log.debug( + "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", + fn_name, + body_name, + fn_name, + len(previously_installed_submodules) + 1, + ) if invoke_subgraph_cache: invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name) return body_name + @raise_hard_error_if_graph_break( + reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) def call_function( self, tx: "InstructionTranslator", @@ -3409,3 +3403,34 @@ def call_function( flat_example_value, treespec, ) + + +# Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() +_hop_name_to_variable_class = { + "cond": CondHigherOrderVariable, + "while_loop": WhileLoopHigherOrderVariable, + "map": MapHigherOrderVariable, + "map_impl": MapHigherOrderVariable, + "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable, + "out_dtype": OutDtypeHigherOrderVariable, + "wrap": WrapHigherOrderVariable, + "hints_wrapper": HintsWrapperHigherOrderVariable, + "flex_attention": FlexAttentionHigherOrderVariable, + "flex_attention_backward": FlexAttentionBackwardHighOrderVariable, + "wrap_activation_checkpoint": CheckpointHigherOrderVariable, + "tag_activation_checkpoint": CheckpointHigherOrderVariable, + "_export_tracepoint": ExportTracepointHigherOrderVariable, + "trace_wrapped": TraceWrappedHigherOrderOperatorVariable, + "strict_mode": StrictModeHigherOrderVariable, + "run_with_rng_state": RunWithRNGStateHigherOrderVariable, + "associative_scan": AssociativeScanHigherOrderVariable, + "scan": ScanHigherOrderVariable, + "call_torchbind": CallTorchbindHigherOrderVariable, + "wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable, + "wrap_with_autocast": WrapWithAutocastHigherOrderVariable, + "dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable, + "auto_functionalized": AutoFunctionalizeHigherOrderVariable, + "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable, + "invoke_subgraph": InvokeSubgraphHigherOrderVariable, + "custom_function_call": CustomFunctionHigherOrderOperatorVariable, +} diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 45f830beda27ad..aa5cb54da192c6 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -116,7 +116,12 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): ) else: assert isinstance(index, (int, torch.SymInt)) - return self.items[index] + try: + return self.items[index] + except IndexError: + raise_observed_exception( + IndexError, tx, args=["list index out of range"] + ) def unpack_var_sequence(self, tx): return list(self.items) @@ -413,13 +418,36 @@ def call_method( name == "__setitem__" and self.is_mutable() and args - and args[0].is_python_constant() + and ( + args[0].is_python_constant() + or isinstance(args[0], SymNodeVariable) + or ( + isinstance(args[0], SliceVariable) + and all( + s.is_python_constant() or isinstance(s, SymNodeVariable) + for s in args[0].items + ) + ) + ) ): assert not kwargs key, value = args tx.output.side_effects.mutation(self) - if isinstance(key, SliceVariable): - self.items[key.as_python_constant()] = list(value.items) + if isinstance(key, SymNodeVariable): + self.items[key.evaluate_expr()] = value + elif isinstance(key, SliceVariable): + if key.is_python_constant(): + self.items[key.as_python_constant()] = list(value.items) + else: + items = slice( + *[ + s.evaluate_expr() + if isinstance(s, SymNodeVariable) + else s.as_python_constant() + for s in key.items + ] + ) + self.items[items] = list(value.items) else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2b9591eff2fa06..923021e63294c2 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -305,6 +305,11 @@ def call_method( and inner_fn in self.objvar._dict_methods ): return self.objvar._dict_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedSetVariable) + and inner_fn in self.objvar._set_methods + ): + return self.objvar._set_vt.call_method(tx, name, args, kwargs) elif ( isinstance(self.objvar, variables.UserDefinedTupleVariable) and inner_fn in tuple_methods @@ -1006,7 +1011,7 @@ def call_method( ) -> "VariableTracker": if name == "queue_callback": if torch._dynamo.compiled_autograd.in_compiled_autograd_region: - assert tx.one_graph, ( + assert tx.one_graph or tx.error_on_graph_break, ( "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" ) return variables.UserFunctionVariable( @@ -1169,7 +1174,9 @@ def call_method( elif name == "__setitem__" and self.name == "__dict__" and not kwargs: if isinstance(self.obj, variables.UserDefinedObjectVariable): # Bypass any custom setattr as we are updating the `__dict__` itself - return self.obj.method_setattr_standard(tx, args[0], args[1]) + return self.obj.method_setattr_standard( + tx, args[0], args[1], directly_update_dict=True + ) if isinstance(self.obj, variables.NNModuleVariable): # This matches how `setattr` is handled for NNModuleVariable self.obj.convert_to_unspecialized(tx) @@ -1439,7 +1446,9 @@ def call_function( and config.use_numpy_random_stream ): msg = f"delegate '{func.__qualname__}' to NumPy itself via " - msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}" + msg += ( + f"config.use_numpy_random_stream={config.use_numpy_random_stream}" + ) unimplemented(msg) args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) @@ -1756,7 +1765,7 @@ class RandomVariable(VariableTracker): """random.Random() Implemented by wrapping a VariableTracker around a random.Random object. - The supported methods for the random.Random object cannot be overriden. + The supported methods for the random.Random object cannot be overridden. Assumes that random objects behave the same given a set seed or state. """ @@ -1901,15 +1910,19 @@ class WeakRefVariable(VariableTracker): @staticmethod def build(tx, weakref_value, **options): source = options.get("source", None) + callback = weakref_value.__callback__ + callback_source = source and AttrSource(source, "__callback__") + callback_vt = VariableTracker.build(tx, callback, callback_source) referent = weakref_value() source = source and WeakRefCallSource(source) referent_vt = VariableTracker.build(tx, referent, source) options["source"] = source - return WeakRefVariable(referent_vt, **options) + return WeakRefVariable(referent_vt, callback_vt, **options) - def __init__(self, referent_vt, **options): + def __init__(self, referent_vt, callback_vt, **options): super().__init__(**options) self.referent_vt = referent_vt + self.callback_vt = callback_vt def call_function( self, @@ -1922,4 +1935,5 @@ def call_function( def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) codegen(self.referent_vt) - codegen.extend_output(create_call_function(1, False)) + codegen(self.callback_vt) + codegen.extend_output(create_call_function(2, False)) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index d8ffd268f7eefd..8ea6f0c701f37d 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -48,7 +48,6 @@ FSDPNNModuleSource, GetItemSource, NNModuleSource, - UnspecializedBuiltinNNModuleSource, UnspecializedNNModuleSource, ) from ..utils import ( @@ -869,7 +868,7 @@ def __init__(self, value, **kwargs) -> None: if type(value) is torch.jit._script.RecursiveScriptModule: raise Unsupported( "ScriptModules aren't supported in UnspecializedNNModuleVariable" - " becuase their .forward function isn't a static member of their type" + " because their .forward function isn't a static member of their type" ) if "value_type" in kwargs: lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) @@ -891,8 +890,7 @@ def __init__(self, value, **kwargs) -> None: self.nn_module_stack_source = self.source def _wrap_source(self, attr_source): - if not isinstance(attr_source, UnspecializedNNModuleSource): - return UnspecializedNNModuleSource(attr_source) + # the vt is already wrapped with UnspecializedNNModuleSource return attr_source def get_nn_module_stack_source(self): @@ -902,7 +900,7 @@ def set_nn_module_stack_source(self, source): self.nn_module_stack_source = source @staticmethod - @functools.lru_cache(None) + @functools.cache def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} @@ -1056,7 +1054,7 @@ def call_method( # Record if mutations happens on parameters/buffers/modules. The # mutations on these are not tracked by base class # UserDefinedObject vt. This will be used later to graph break - # on seeing a paramters() and family calls. + # on seeing a parameters() and family calls. # TODO(anijain2305) - This might not be needed if we let Dynamo # inline both getattr and setattr. In that case, it should see # the lowest level dicts - _parameters and family and @@ -1132,7 +1130,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable. # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for - # differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why + # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand why # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a # NNModuleHooksDictVariable (a subclass of ConstDictVariable) to avoid any guard on the keys. if ( @@ -1193,8 +1191,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable): """ def _wrap_source(self, attr_source): - if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource): - return UnspecializedBuiltinNNModuleSource(attr_source) + # vt is already wrapped with the UnspecializedBuiltinNNModuleSource return attr_source diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 893e514bc2cc90..4e04ca4ef7eaa8 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -45,7 +45,6 @@ from .. import config, graph_break_hints, variables from .._trace_wrapped_higher_order_op import trace_wrapped from ..exc import ( - unimplemented, unimplemented_v2, UnknownPropertiesDuringBackwardTrace, UserError, @@ -376,7 +375,12 @@ def method_attr_is_nested(self, tx): return ConstantVariable.create(self.is_nested) def method_attr_retain_grad(self, tx): - unimplemented("retain_grad does not work with AOTDispatcher") + unimplemented_v2( + gb_type="Tensor.retain_grad() with AOTDispatcher", + context=f"var_getattr {self} retain_grad", + explanation="`Tensor.retain_grad()` does not work with AOTDispatcher.", + hints=[], + ) def method_attr_data(self, tx): return variables.TorchInGraphFunctionVariable( @@ -385,7 +389,12 @@ def method_attr_data(self, tx): def method_attr_grad_fn(self, tx): if self.has_grad_fn: - unimplemented("TensorVariable has a grad_fn") + unimplemented_v2( + gb_type="Tensor with grad_fn()", + context=f"var_getattr {self} grad_fn", + explanation="Dynamo does not support tracing tensors with a grad_fn directly.", + hints=[], + ) else: return variables.ConstantVariable(None) @@ -427,8 +436,14 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): def var_getattr(self, tx: "InstructionTranslator", name): if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): - unimplemented( - f"Getattr invocation {name} in strict mode is not supported" + unimplemented_v2( + gb_type="Strict mode banned op", + context=f"var_getattr {self} {name}", + explanation=f"Getattr invocation '{name}' in strict mode is not supported.", + hints=[ + f"Remove `{name}` from the list of banned ops by " + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`.", + ], ) elif name in self._strict_mode_conditional_banned_ops(): raise UnknownPropertiesDuringBackwardTrace( @@ -511,17 +526,34 @@ def try_generic_attr_handling(): def call_id(self, tx): if not self.source: - unimplemented("call_id not supported for sourceless TensorVariable") + unimplemented_v2( + gb_type="Unsupported call_id() without source", + context=f"call_id {self}", + explanation="call_id() not supported for sourceless TensorVariable.", + hints=[], + ) # For local source, we associate the real value. We use this real value scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} try: _input_associated_real_value = eval(self.source.name(), scope) except Exception as exc: - unimplemented(f"error getting associated real value: {exc}") + unimplemented_v2( + gb_type="Error getting associated real value", + context=f"call_id {self}", + explanation="Dynamo encountered an error while trying to " + "get the associated real value.", + hints=[], + from_exc=exc, + ) if _input_associated_real_value is None: - unimplemented("call_id without associated real value") + unimplemented_v2( + gb_type="call_id() without associated real value", + context=f"call_id {self}", + explanation="Dynamo could not find an associated real value for the tensor.", + hints=[], + ) install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) id_value = id(_input_associated_real_value) @@ -592,7 +624,13 @@ def call_method( from .torch_function import can_dispatch_torch_function, dispatch_torch_function if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): - unimplemented(f"Illegal method invocation {name} in strict mode") + unimplemented_v2( + gb_type="Illegal method invocation in strict mode", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo currently does not support this method " + f"({name}) invocation in strict mode.", + hints=[], + ) # Only override builtin tensor methods # The user can manually add override handling @@ -625,6 +663,31 @@ def call_method( if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): return variables.ConstantVariable(False) + # For historical reasons, these ops decompose down to syntactically + # invalid aten ops because they contain the python keyword `from`, see + # discussions in #151432 for more details. + # We graph break for now since this use case is uncommon. + if name == "random_": + unimplemented_v2( + gb_type="Tensor.random_ op", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Use the out-of-place version of this op", + *graph_break_hints.SUPPORTABLE, + ], + ) + elif name == "uniform_" and "from" in kwargs: + unimplemented_v2( + gb_type="Tensor.uniform_ op called with `from` keyword", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Avoid using the `from` keyword.", + *graph_break_hints.SUPPORTABLE, + ], + ) + try: handler_method = getattr(self, f"method_{name}") except AttributeError: @@ -635,7 +698,14 @@ def call_method( if result: return result except TypeError as e: - unimplemented(f"unhandled args for {name}: {e}") + unimplemented_v2( + gb_type="Unhandled args for method", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered an error while calling " + f"the method `{name}`.", + hints=[], + from_exc=e, + ) from .builder import wrap_fx_proxy @@ -825,9 +895,26 @@ def method_element_size(self): def method_numpy(self, *, force=False): if not config.trace_numpy: - unimplemented("Tensor.numpy(). config.trace_numpy is False") + unimplemented_v2( + gb_type="Tensor.numpy() with trace_numpy=False", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the `trace_numpy` " + "configuration was manually disabled.", + hints=[ + "Set `torch._dynamo.config.trace_numpy = True` to allow " + "Dynamo to trace through NumPy.", + ], + ) if not np: - unimplemented("Tensor.numpy(). NumPy is not available") + unimplemented_v2( + gb_type="Tensor.numpy() without NumPy installed", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the NumPy library " + "is not available in the current environment.", + hints=[ + "Ensure NumPy is installed in your Python environment.", + ], + ) if self.layout != torch.strided: raise TypeError( f"can't convert {self.layout} layout tensor to numpy. Use Tensor.to_dense() first" @@ -873,7 +960,16 @@ def wrap(i, sub_proxy): torch.int32, torch.int64, ]: - unimplemented("Input tensor for tolist must be an integer tensor") + unimplemented_v2( + gb_type="Tensor.tolist() with non-integer tensor", + context=f"call_method {self} to_list", + explanation="Dynamo currently does not support tracing " + "`tolist()` on non-integer tensors.", + hints=[ + "Ensure the input tensor to `tolist()` is an integer " + "type (e.g., int8, int16, int32, int64)." + ], + ) if tensor.dim() == 0: return wrap(tensor, sub_proxy) @@ -891,7 +987,12 @@ def wrap(i, sub_proxy): return VariableTracker.build(tx, out) def method_backward(self, *args, **kwargs): - unimplemented("Tensor.backward") + unimplemented_v2( + gb_type="Unsupported Tensor.backward() call", + context=f"call_method {self} backward {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.backward()`.", + hints=[*graph_break_hints.FUNDAMENTAL], + ) def method_data_ptr(self, *args, **kwargs): return DataPtrVariable(self) @@ -899,7 +1000,17 @@ def method_data_ptr(self, *args, **kwargs): def method_item(self, *args, **kwargs): if not config.capture_scalar_outputs: self._warn_capture_scalar_outputs() - unimplemented("Tensor.item") + unimplemented_v2( + gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False", + context=f"call_method {self} item {args} {kwargs}", + explanation="Dynamo does not support tracing `Tensor.item()` " + "with config.capture_scalar_outputs=False.", + hints=[ + "Set `torch._dynamo.config.capture_scalar_outputs = True` " + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` " + "to include these operations in the captured graph.", + ], + ) def method___getitem__(self, *args, **kwargs): from ..symbolic_convert import InstructionTranslator @@ -929,7 +1040,7 @@ def method___getitem__(self, *args, **kwargs): return wrap_fx_proxy(tx, proxy) @staticmethod - @functools.lru_cache(None) + @functools.cache def _warn_capture_scalar_outputs(): user_stack = torch._guards.TracingContext.extract_stack() user_stack_formatted = "".join(traceback.format_list(user_stack)) @@ -969,35 +1080,51 @@ def method_addcmul_(self, tensor1, tensor2, *, value=None): ) def method___setitem__(self, key, value): - def has_bool_key(v): - if isinstance(v, TensorVariable): - return v.dtype in (torch.bool, torch.int8) - elif isinstance(v, variables.TupleVariable): - return any(has_bool_key(item) for item in v.items) - else: - return False - from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() - tx.output.create_proxy( + proxy = tx.output.create_proxy( "call_function", operator.setitem, *proxy_args_kwargs([self, key, value], {}), ) + + if config.use_graph_deduplication or config.track_nodes_for_deduplication: + tx.output.region_tracker.add_node_mutation(proxy.node, 0) + return ConstantVariable.create(None) def method_resize_(self, *args, **kwargs): - unimplemented("Tensor.resize_") + unimplemented_v2( + gb_type="Unsupported Tensor.resize_() call", + context=f"call_method {self} resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_()`.", + hints=[], + ) def method_resize_as_(self, *args, **kwargs): - unimplemented("Tensor.resize_as_") + unimplemented_v2( + gb_type="Unsupported Tensor.resize_as_() call", + context=f"call_method {self} resize_as_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_as_()`.", + hints=[], + ) def method_sparse_resize_(self, *args, **kwargs): - unimplemented("Tensor.sparse_resize_") + unimplemented_v2( + gb_type="Unsupported Tensor.sparse_resize_() call", + context=f"call_method {self} sparse_resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + hints=[], + ) def method_sparse_resize_and_clear_(self, *args, **kwargs): - unimplemented("Tensor.sparse_resize_and_clear_") + unimplemented_v2( + gb_type="Unsupported Tensor.sparse_resize_and_clear_() call", + context=f"call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + hints=[], + ) def method_set_(self, *args, **kwargs): if len(args) > 1: @@ -1007,7 +1134,13 @@ def method_set_(self, *args, **kwargs): # overload and is used by FSDP. # graph-breaking on aten::set_source_Tensor_storage_offset for now, # unless we find that we need to make it work. - unimplemented("Tensor.set_.source_Tensor_storage_offset") + unimplemented_v2( + gb_type="Unsupported Tensor.set_() call", + context=f"call_method {self} set_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.set_()` " + "overloads that include more than one argument.", + hints=[*graph_break_hints.SUPPORTABLE], + ) def method_add_(self, other, *, alpha=None): if alpha is not None: @@ -1133,8 +1266,11 @@ def _method_register_hook(self, name: str, hook: VariableTracker): # would have no recourse - their forward traces just fine, but will fail at backwards unless # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) # then they have nothing they can do except disable compile. - unimplemented( - "Compilation of intermediate hooks requires compiled autograd" + unimplemented_v2( + gb_type="Compilation of intermediate hooks requires compiled autograd", + context=f"var_getattr {self} {name}", + explanation="Dynamo must be in compiled_autograd to register hooks.", + hints=[], ) hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) @@ -1180,7 +1316,13 @@ def method_requires_grad_(self, requires_grad=True): requires_grad = requires_grad.as_python_constant() if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: - unimplemented("Tensor.requires_grad_") + unimplemented_v2( + gb_type="Unsupported Tensor.requires_grad_() call", + context=f"call_method {self} requires_grad_", + explanation="Dynamo does not support changes to a Tensor's " + "`requires_grad` through calling `requires_grad_()`.", + hints=[], + ) else: return self @@ -1366,9 +1508,19 @@ def insert_into_graph(): return ConstantVariable.create(int(r)) return insert_into_graph() elif name in ["base", "flags", "dtype"]: - unimplemented(f"TODO: add support for ndarray.{name}") + unimplemented_v2( + gb_type="Unsupported ndarray attribute access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) elif name in ["__version__"]: - unimplemented("delegate np.__version__ to NumPy") + unimplemented_v2( + gb_type="Unsupported ndarray.__version__ access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) if result is None: raise NotImplementedError return result @@ -1394,8 +1546,13 @@ def call_method( if name in ["__len__", "size", "tolist"]: # delegate back to TensorVariable return super().call_method(tx, name, args, kwargs) - if name in ("tostring", "tobytes"): - unimplemented(f"{name} is not modelled in torch._numpy") + if name in ("tostring", "tobytes", "__delattr__"): + unimplemented_v2( + gb_type="Unsupported ndarray method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"`ndarray.{name}()` is not modelled in `torch._numpy`.", + hints=[], + ) proxy = tx.output.create_proxy( "call_function", numpy_method_wrapper(name), diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c1aeb512aef5d9..d97d167a6ecfaa 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -132,7 +132,11 @@ ) constant_fold_functions_need_guards = [ + torch.accelerator.current_device_index, torch.cuda.current_device, + torch.cuda.is_initialized, + torch.xpu.current_device, + torch.xpu.is_initialized, ] constant_fold_functions = [ @@ -140,6 +144,7 @@ torch._utils._get_device_index, torch._C._get_cublas_allow_tf32, torch._C._is_any_autocast_enabled, + torch.accelerator.is_available, torch.cuda.get_device_properties, torch.cuda.is_available, torch.distributed.is_available, @@ -155,6 +160,8 @@ torch.promote_types, torch._C._get_privateuse1_backend_name, torch.autograd._is_checkpoint_valid, + torch.xpu.get_device_properties, + torch.xpu.is_available, ] + constant_fold_functions_need_guards if torch.distributed.is_available(): constant_fold_functions.extend( @@ -169,7 +176,7 @@ constant_fold_functions = dict.fromkeys(constant_fold_functions) -@functools.lru_cache(None) +@functools.cache def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: # Defined as a function to avoid circular import like torch.onnx return { @@ -196,7 +203,7 @@ def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: } -@functools.lru_cache(None) +@functools.cache def get_overridable_functions(): from itertools import chain @@ -431,7 +438,7 @@ def get_function(self): return self.value @staticmethod - @functools.lru_cache(None) + @functools.cache def _get_handlers(): """Build a dict from function -> method to handle it so that we are O(1) in terms of the number of function with special handling.""" @@ -934,6 +941,18 @@ def handle_statically_known_false(self, tx: "InstructionTranslator", expr): elif isinstance(expr, ConstantVariable): return expr + @register(torch.fx.experimental.symbolic_shapes.guard_scalar) + def guard_scalar(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif isinstance(expr, ConstantVariable): + val = expr.value + else: + raise torch._dynamo.exc.Unsupported("branch not supported") + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_scalar(val) + ) + @register(torch.fx.experimental.symbolic_shapes.statically_known_true) def handle_statically_known_true(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): @@ -945,6 +964,28 @@ def handle_statically_known_true(self, tx: "InstructionTranslator", expr): elif isinstance(expr, ConstantVariable): return expr + @register(torch.fx.experimental.symbolic_shapes.sym_and) + def handle_sym_and(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_and( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.sym_or) + def handle_sym_or(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_or( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + @register(torch.fx.experimental.symbolic_shapes.has_static_value) def handle_has_static_value(self, tx: "InstructionTranslator", expr): if isinstance(expr, SymNodeVariable): @@ -1223,7 +1264,7 @@ def patched_fn(*args, **kwargs): # Guard against inplace view op on input tensor (not supported) if args and isinstance(args[0], variables.TensorVariable): tensor_var = args[0] - # Check if input tensor and inplace_view op specifcally + # Check if input tensor and inplace_view op specifically if tensor_var.source is not None and hasattr(torch.ops.aten, name): fn = getattr(torch.ops.aten, name) if ( @@ -1284,15 +1325,31 @@ def patched_fn(*args, **kwargs): # variant torch ops, the original function could come from a user # defined `@allow_in_graph` function as well, which doesn't have the # same semantics as the torch ops. - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + saved_out_shapes = None + out_kwarg_vt = None + if "out" in kwargs: + out_kwarg_vt = kwargs["out"] + + # e.g., out=(t1, t2, ...) + if isinstance(out_kwarg_vt, (TupleVariable, ListVariable)): + saved_out_shapes = [] + for vt in out_kwarg_vt.items: + if isinstance(vt, variables.TensorVariable): + shape = vt.proxy.node.meta["example_value"].shape + else: + shape = None + saved_out_shapes.append(shape) + + # e.g., out=output_tensor + if isinstance(out_kwarg_vt, variables.TensorVariable): + saved_out_shapes = out_kwarg_vt.proxy.node.meta["example_value"].shape tensor_variable = wrap_fx_proxy( tx=tx, @@ -1315,10 +1372,7 @@ def patched_fn(*args, **kwargs): ) # Handle e.g., `torch.add(a, b, out=result)` - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): + if saved_out_shapes is not None: # out variants of torch operators like torch.sort and torch.sigmoid # mutate the tensors in the out field. # @@ -1330,26 +1384,34 @@ def patched_fn(*args, **kwargs): # Note that although these tensor variablels would hold different # proxies, the in-place mutation semantics is preserved in the FX # graph, so we won't have correctness issues. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items + if isinstance(saved_out_shapes, list): + for out_tensor_vt, saved_out_shape in zip( + out_kwarg_vt.items, # type: ignore[union-attr] + saved_out_shapes, ): - if ( - isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor._size - != result_tensor._size # we actually want to compare None values here - ): + if saved_out_shape is None: + # This should be extremely rare, but it's kept for now + # until we invest in enforcing the `out=` kwarg for only + # torch methods. + continue + + assert isinstance(out_tensor_vt, TensorVariable) + fake_out = out_tensor_vt.proxy.node.meta["example_value"] + if saved_out_shape != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] - if fake_out_shape != fake_tensor.shape: + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + else: + assert isinstance(out_kwarg_vt, TensorVariable) + assert "example_value" in out_kwarg_vt.proxy.node.meta + fake_out = out_kwarg_vt.proxy.node.meta["example_value"] + if saved_out_shapes != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") @@ -1359,32 +1421,6 @@ def patched_fn(*args, **kwargs): unimplemented( "out= op was called where output tensor was non-contiguous" ) - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" - ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") return tensor_variable @@ -1487,8 +1523,9 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad # Alternate version if we have a .source varname = tx.output.new_var() - # construct the nn.Parmeter before the graph save it to varname - cg = PyCodegen(tx) + # construct the nn.Parameter before the graph save it to varname + assert tx.output.root_tx is not None + cg = PyCodegen(tx.output.root_tx) cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter")) cg(data.source) cg(variables.ConstantVariable(requires_grad)) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index ddacee127f114b..bb3ac286773b61 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -199,7 +199,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): ] -@functools.lru_cache(None) +@functools.cache def get_prev_stack_var_name(): from ..bytecode_transformation import unique_id @@ -368,7 +368,7 @@ def is_supported_torch_function_mode(ty): # We are able to trace custom modes but if there are graph breaks under them # and they have a custom __enter__/__exit__ we don't handle this for the # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the funtion across two frames instead of one + # that are now affected by executing the function across two frames instead of one # Today we support the enter/exit of the default TorchFunctionMode as well as # DeviceContext (which is used for set_default_device) return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( @@ -487,7 +487,7 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var): return VariableTracker.build(tx, var.python_type(), source) -def _is_attr_overidden(tx: "InstructionTranslator", var, name): +def _is_attr_overridden(tx: "InstructionTranslator", var, name): import torch overridden = False @@ -640,11 +640,11 @@ def var_getattr(self, tx: "InstructionTranslator", name): ], ) - # Handle non-overriden attributes inherited from `torch.Tensor`. - attr_is_overriden = _is_attr_overidden(tx, self, name) + # Handle non-overridden attributes inherited from `torch.Tensor`. + attr_is_overridden = _is_attr_overridden(tx, self, name) if ( hasattr(torch.Tensor, name) - and not attr_is_overriden + and not attr_is_overridden and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) ): args, kwargs = [self], {} @@ -694,11 +694,11 @@ def var_getattr(self, tx: "InstructionTranslator", name): attr.__func__, self.class_type_var(tx), source=attr_source ) - elif attr_is_overriden: + elif attr_is_overridden: unimplemented_v2( - gb_type="Unsupported tensor subclass overriden attribute access", + gb_type="Unsupported tensor subclass overridden attribute access", context=f"{name}", - explanation="`torch.compile` only support tracing certain types of overriden tensor subclass attributes", + explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes", hints=[ f"Avoid accessing {name} of tensor subclass in torch.compile region", f"Renaming attribute `{name}` of type {self.class_type}", @@ -735,9 +735,9 @@ def call_method( if can_dispatch_torch_function(tx, tf_args, kwargs): import torch - if _is_attr_overidden(tx, self, name): + if _is_attr_overridden(tx, self, name): unimplemented_v2( - gb_type="Tensor subclass overriden method call", + gb_type="Tensor subclass overridden method call", context=f"{name}", explanation="`torch.compile` currently can't trace this", hints=[ diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index b18a269e0b9faa..6cc56364bcd35d 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -10,7 +10,9 @@ attribute access, and other Python object behaviors. - Specialized subclasses for common patterns: - UserDefinedDictVariable: For dict subclasses + - UserDefinedSetVariable: For set subclasses - UserDefinedTupleVariable: For tuple subclasses + - UserDefinedExceptionObjectVariable: For exception subclasses - FrozenDataClassVariable: Special handling of frozen dataclasses - MutableMappingVariable: For collections.abc.MutableMapping subclasses @@ -21,6 +23,7 @@ maintaining proper semantics while enabling optimizations where possible. """ +import _collections import builtins import collections import contextlib @@ -65,9 +68,11 @@ check_constant_args, cmp_name_to_op_mapping, dict_methods, + frozenset_methods, get_custom_getattr, has_torch_function, is_frozen_dataclass, + is_lru_cache_wrapped_function, is_namedtuple_cls, is_utils_checkpoint, is_wrapper_or_member_descriptor, @@ -76,6 +81,7 @@ namedtuple_fields, object_has_getattribute, proxy_args_kwargs, + set_methods, tensortype_to_dtype, tuple_methods, unpatched_nn_module_getattr, @@ -105,6 +111,10 @@ def is_standard_setattr(val): return val in (object.__setattr__, BaseException.__setattr__) +def is_standard_delattr(val): + return val in (object.__delattr__, BaseException.__delattr__) + + def is_forbidden_context_manager(ctx): f_ctxs = [] @@ -144,7 +154,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value})" @staticmethod - @functools.lru_cache(None) + @functools.cache def _constant_fold_classes(): return { torch.device, @@ -154,12 +164,25 @@ def _constant_fold_classes(): } @staticmethod - @functools.lru_cache(None) + @functools.cache def _in_graph_classes(): _in_graph_class_list = { torch.Tensor, + torch.cuda.FloatTensor, + torch.cuda.DoubleTensor, + torch.cuda.HalfTensor, + torch.cuda.BFloat16Tensor, + torch.cuda.ByteTensor, + torch.cuda.CharTensor, + torch.cuda.IntTensor, + torch.cuda.ShortTensor, + torch.cuda.LongTensor, + torch.Stream, + torch.Event, torch.cuda.Stream, torch.cuda.Event, + torch.xpu.Stream, + torch.xpu.Event, } if hasattr(torch, "hpu"): _in_graph_class_list.update( @@ -172,7 +195,7 @@ def _in_graph_classes(): return set(tensortype_to_dtype.keys()) | _in_graph_class_list @staticmethod - @functools.lru_cache(None) + @functools.cache def supported_c_new_functions(): exceptions = [ getattr(builtins, name).__new__ @@ -183,6 +206,8 @@ def supported_c_new_functions(): return { object.__new__, dict.__new__, + set.__new__, + frozenset.__new__, tuple.__new__, list.__new__, }.union(exceptions) @@ -378,6 +403,9 @@ def call_method( return variables.ConstantVariable(self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value != args[0].value) + elif issubclass(self.value, (set, frozenset)) and name != "__new__": + # __new__ is handled below + return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs) elif ( name == "__new__" and self.value is collections.OrderedDict @@ -472,7 +500,11 @@ def call_function( items, maxlen=maxlen, mutation_type=ValueMutationNew() ) elif self.value is weakref.ref: - return variables.WeakRefVariable(args[0]) + if len(args) > 1: + callback = args[1] + else: + callback = variables.ConstantVariable.create(None) + return variables.WeakRefVariable(args[0], callback) elif self.value is functools.partial: if not args: unimplemented("functools.partial malformed") @@ -735,7 +767,12 @@ class UserDefinedObjectVariable(UserDefinedVariable): Mostly objects of defined type. Catch-all for something where we only know the type. """ - _nonvar_fields = {"value", "value_type", *UserDefinedVariable._nonvar_fields} + _nonvar_fields = { + "value", + "value_type", + "attrs_directly_modifed_on_dict", + *UserDefinedVariable._nonvar_fields, + } def __init__( self, @@ -763,6 +800,13 @@ def __init__( self.base_cls_vt = base_cls_vt self.init_args = init_args + # This records names of the attributes that were modified via instance + # `__dict__` directly, rather than the normal setattr path. + # + # TODO consider emulating `obj.__dict__` as a `ConstDictVariable` to get + # rid of these workarounds here and in `GetAttrVariable`. + self.attrs_directly_modifed_on_dict = set() + def __str__(self) -> str: inner = self.value_type.__name__ if inner in [ @@ -826,7 +870,7 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar ) @staticmethod - @functools.lru_cache(None) + @functools.cache def _supported_random_functions(): fns = { random.random, @@ -861,6 +905,11 @@ def call_method( if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) + if is_standard_delattr(method): + return self.method_setattr_standard( + tx, args[0], variables.DeletedVariable() + ) + if method is object.__eq__ and len(args) == 1 and not kwargs: other = args[0] if not isinstance(other, UserDefinedObjectVariable): @@ -897,7 +946,9 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + def method_setattr_standard( + self, tx: "InstructionTranslator", name, value, directly_update_dict=False + ): try: name = name.as_python_constant() except NotImplementedError: @@ -905,6 +956,28 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): if not tx.output.side_effects.is_attribute_mutation(self): unimplemented(f"setattr({self}, {name}, ...)") + if directly_update_dict: + self.attrs_directly_modifed_on_dict.add(name) + else: + tmp = self.try_get_descritor_and_setter_py_func(name) + if tmp: + descriptor, setter = tmp + # Emulate + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452 + desc_source = None + func_source = None + if self.cls_source: + desc_source = self.get_source_by_walking_mro(name) + # use `type(...)` to ignore instance attrs. + func_source = AttrSource(TypeSource(desc_source), "__set__") + desc_var = VariableTracker.build(tx, descriptor, desc_source) + func_var = VariableTracker.build(tx, setter, func_source) + args = [desc_var, self, value] + return func_var.call_function(tx, args, {}) + # NOTE: else we assume the descriptor (if any) has a + # side-effect-free `__set__` as far as Dynamo tracing is concerned. + + # Emulate the standard setattr on instance dict. tx.output.side_effects.store_attr(self, name, value) return variables.ConstantVariable(None) @@ -1010,26 +1083,20 @@ def _is_c_defined_property(self, subobj): def _getattr_static(self, name): subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) - import _collections # In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup. + # NOTE we assume the following descriptors are side-effect-free as far + # as Dynamo tracing is concerned. if not object_has_getattribute(self.value) and ( subobj is NO_SUCH_SUBOBJ # e.g., threading.local - or isinstance( - subobj, _collections._tuplegetter - ) # namedtuple fields are represented by _tuplegetter - or ( - inspect.ismemberdescriptor(subobj) and name in self.value.__slots__ - ) # handle memberdecriptor and slots + or inspect.ismemberdescriptor(subobj) # e.g., __slots__ + or inspect.isgetsetdescriptor(subobj) # e.g., __dict__ or self._is_c_defined_property(subobj) - or inspect.isgetsetdescriptor( - subobj - ) # handle getsetdescriptor like __dict__ ): # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't # want to call getattr because it can be user-overridden. - subobj = self.value.__getattribute__(name) + subobj = type(self.value).__getattribute__(self.value, name) elif object_has_getattribute(self.value) and subobj is NO_SUCH_SUBOBJ: # If the object has an overridden getattribute method, Dynamo has # already tried tracing it, and encountered an AttributeError. We @@ -1040,6 +1107,27 @@ def _getattr_static(self, name): return subobj + def should_skip_descriptor_setter(self, attr_name): + # Check if `attr_name` corresponds to a descriptor. + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if setter: + # Skip if `__set__` was traceable (no need to redo the side effect). + if inspect.isfunction(setter): + return True + # For untraceable `__set__` we should still skip if the attribute + # was mutated via instance `__dict__`. + elif attr_name in self.attrs_directly_modifed_on_dict: + return True + return False + + def try_get_descritor_and_setter_py_func(self, attr_name): + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if inspect.isfunction(setter): + return (descriptor, setter) + return None + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): if tx.output.side_effects.has_pending_mutation_of_attr(self, key): mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) @@ -1160,6 +1248,14 @@ def var_getattr(self, tx: "InstructionTranslator", name): return variables.UserMethodVariable( subobj.fget, self, source=source ).call_function(tx, [], {}) + elif isinstance(subobj, _collections._tuplegetter): + # namedtuple fields are represented by _tuplegetter, and here we + # emulate its `__get__`, which is implemented in C. + _, (idx, _) = subobj.__reduce__() + # Don't go through the `__getitem__` method anymore, see + # https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690 + assert isinstance(self, UserDefinedTupleVariable) + return self._tuple_vt.items[idx] elif isinstance(subobj, staticmethod): # Safe because `staticmethod.__get__` basically won't trigger user # code and just returns the underlying `__func__`: @@ -1174,9 +1270,20 @@ def var_getattr(self, tx: "InstructionTranslator", name): # e.g.: inspect.getattr_static({}, "fromkeys") func = subobj.__get__(self.value, None) return VariableTracker.build(tx, func, source) - elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor( - subobj.__get__ + elif is_lru_cache_wrapped_function(subobj): + # getattr_static returns the lru_wrapped function, and we cannot + # extract the underlying method from the wrapped function. To handle + # it, manually create a wrapped user method vt. + return variables.WrapperUserMethodVariable( + subobj, "__wrapped__", self, source=source + ) + elif inspect.getattr_static( + type(subobj), "__get__", NO_SUCH_SUBOBJ + ) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor( + type(subobj).__get__ ): + # Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285 + # # Attribute has a __get__ method. Create a user defined object vt # for the subobj, and then trace the __get__ method. descriptor_source = None @@ -1185,7 +1292,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): # To access the method descriptor from the udf object w/o using # inspect.getattr_static, we can look into the class mro descriptor_source = self.get_source_by_walking_mro(name) - descriptor_get_source = AttrSource(descriptor_source, "__get__") + descriptor_get_source = AttrSource( + TypeSource(descriptor_source), "__get__" + ) descriptor_var = VariableTracker.build(tx, subobj, descriptor_source) else: # Sourceless Builder does not support user defined objects @@ -1437,7 +1546,7 @@ def call_method(self, tx, name, args, kwargs): self.exc_vt.args = args self.value.args = args return variables.ConstantVariable(None) - if ( + elif ( name == "__setattr__" and len(args) == 2 and isinstance(args[0], variables.ConstantVariable) @@ -1445,6 +1554,8 @@ def call_method(self, tx, name, args, kwargs): in ("__cause__", "__context__", "__suppress_context__", "__traceback__") ): self.exc_vt.call_setattr(tx, args[0], args[1]) + elif name == "with_traceback": + return self.exc_vt.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) @property @@ -1583,6 +1694,81 @@ def is_underlying_vt_modified(self, side_effects): return side_effects.is_modified(self._dict_vt) +class UserDefinedSetVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of set. + + Internally, it uses a SetVariable to represent the set part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, set_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._set_vt = set_vt + + python_type = set if isinstance(value, set) else frozenset + self._set_methods = set_methods if python_type is set else frozenset_methods + + if self._set_vt is None: + assert self.source is None, ( + "set_vt must be constructed by builder.py when source is present" + ) + if python_type is set: + # set is initialized later + self._set_vt = variables.SetVariable( + {}, mutation_type=ValueMutationNew() + ) + else: + init_args = kwargs.get("init_args", {}) + tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() + self._set_vt = variables.BuiltinVariable(python_type).call_function( + tx, init_args, {} + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._set_methods: + return self._set_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def as_python_constant(self): + return self._set_vt.as_python_constant() + + def unpack_var_sequence(self, tx): + if inspect.getattr_static(self.value, "__iter__") in ( + set.__iter__, + frozenset.__iter__, + ): + return self._set_vt.unpack_var_sequence(tx) + raise NotImplementedError + + @property + def set_items(self): + return self._set_vt.set_items + + @property + def items(self): + return self._set_vt.items + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._set_vt) + + def install_dict_keys_match_guard(self): + return self._set_vt.install_dict_keys_match_guard() + + def install_dict_contains_guard(self): + return self._set_vt.install_dict_contains_guard() + + class UserDefinedListVariable(UserDefinedObjectVariable): """ Represents user defined objects that are subclasses of lists. @@ -1637,18 +1823,24 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable): _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, **kwargs): - super().__init__(value, **kwargs) - self._tuple_vt = None - - def set_underlying_tuple_vt(self, tuple_vt): + def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): + super().__init__(value, init_args=init_args, **kwargs) self._tuple_vt = tuple_vt - - @staticmethod - def create(value, tuple_vt, **kwargs): - result = UserDefinedTupleVariable(value, **kwargs) - result.set_underlying_tuple_vt(tuple_vt) - return result + if self._tuple_vt is None: + assert self.source is None, ( + "tuple_vt must be constructed by builder.py when source is present" + ) + # Emulate `tuple.__new__` + # https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710 + # + # TODO this duplicates the logic in `BuiltinVariable(tuple)` + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + elems = init_args[0].unpack_var_sequence(tx) + self._tuple_vt = variables.TupleVariable( + elems, mutation_type=ValueMutationNew() + ) def call_method( self, diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 13ac881df54ff0..df72642d0cc5af 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import lru_cache -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from unittest.mock import patch import torch @@ -48,6 +48,9 @@ from .wrappers import _wrap_submodules from .utils import _materialize_cpp_cia_ops +if TYPE_CHECKING: + from torch._C._aoti import AOTIModelContainerRunner + log = logging.getLogger(__name__) @dataclasses.dataclass @@ -83,7 +86,7 @@ def aot_compile( remove_runtime_assertions: bool = False, disable_constraint_solver: bool = False, same_signature: bool = True, -) -> Union[list[str], str]: +) -> Union[list[Any], str]: """ Note: this function is not stable yet @@ -160,23 +163,23 @@ def aot_load(so_path: str, device: str) -> Callable: aot_compile_warning() if device == "cpu": - runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + runner: AOTIModelContainerRunner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) elif device == "cuda" or device.startswith("cuda:"): - runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) elif device == "xpu" or device.startswith("xpu:"): - runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) # type: ignore[assignment, call-arg] + runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) elif device == "mps" or device.startswith("mps:"): - runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) # type: ignore[assignment, call-arg] + runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) else: raise RuntimeError("Unsupported device " + device) def optimized(*args, **kwargs): - call_spec = runner.get_call_spec() # type: ignore[attr-defined] + call_spec = runner.get_call_spec() in_spec = pytree.treespec_loads(call_spec[0]) out_spec = pytree.treespec_loads(call_spec[1]) flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] - flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + flat_outputs = runner.run(flat_inputs) return pytree.tree_unflatten(flat_outputs, out_spec) return optimized diff --git a/torch/_export/converter.py b/torch/_export/converter.py index c8b8a6271edd10..bf0cad5a310aa1 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Sequence from contextlib import contextmanager -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.export._trace @@ -229,7 +229,7 @@ def get_dtype_as_int(tensor): # Those operators will be automatically populated to a instance method # of TS2FXGraphConverter with name convert__(). # Please check __init__ for method population implementations. -kind_to_standard_operators = { +kind_to_standard_operators: dict[str, Callable[..., Any]] = { "prim::max": builtins.max, "prim::min": builtins.min, "prim::TupleIndex": operator.getitem, @@ -624,9 +624,9 @@ def convert_graph_inputs(self): self.fx_graph, name, self.is_top_level_graph() ) elif name in self.name_to_constant: - assert isinstance( - self.name_to_constant[name], torch.ScriptObject - ), "Input conversion only handles ScriptObject" + assert isinstance(self.name_to_constant[name], torch.ScriptObject), ( + "Input conversion only handles ScriptObject" + ) normalized_name = normalize_name(name) self.input_specs.append( InputSpec( @@ -661,9 +661,7 @@ def convert_aten_Float(self, node: torch._C.Node): def to_float_tensor(t): return t.to(dtype=torch.float).item() - inp_list = [ - self.get_fx_value_by_ir_value(inp) for inp in node.inputs() - ] # noqa: C416 + inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 fx_node = self.fx_graph.call_function( to_float_tensor, tuple(inp_list), @@ -749,9 +747,7 @@ def convert_prim_Constant(self, node: torch._C.Node): self.name_to_constant[name] = value def convert_prim_CallMethod(self, node: torch._C.Node): - inp_list = [ - self.get_fx_value_by_ir_value(inp) for inp in node.inputs() - ] # noqa: C416 + inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 fx_node = self.fx_graph.call_method( node.s("name"), tuple(inp_list), @@ -783,9 +779,9 @@ def convert_prim_GetAttr(self, node: torch._C.Node): self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn) else: if attr_fqn not in self.name_to_non_tensor_attribute_node: - self.name_to_non_tensor_attribute_node[ - attr_fqn - ] = self.name_to_non_tensor_attribute[attr_fqn] + self.name_to_non_tensor_attribute_node[attr_fqn] = ( + self.name_to_non_tensor_attribute[attr_fqn] + ) self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[ attr_fqn ] @@ -850,15 +846,15 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): k = self.get_fx_value_by_ir_value(inp) else: v = self.get_fx_value_by_ir_value(inp) - assert ( - k is not None and v is not None - ), "DictConstruct has an empty key value pair." + assert k is not None and v is not None, ( + "DictConstruct has an empty key value pair." + ) output_dict[k] = v k, v = None, None - assert ( - k is None and v is None - ), "DictConstruct has an odd number of elements (violating our assumption)." + assert k is None and v is None, ( + "DictConstruct has an odd number of elements (violating our assumption)." + ) output_name = node.output().debugName() self.name_to_node[output_name] = output_dict @@ -1124,9 +1120,9 @@ def convert_prim_Loop(self, node: torch._C.Node): ), # + 1 because the 0th element is the condition. ) global_argument_index = global_arguments.index(name) - fx_block_args[ - i + node.outputsSize() + global_argument_index - ] = self.name_to_node[name] + fx_block_args[i + node.outputsSize() + global_argument_index] = ( + self.name_to_node[name] + ) def _check_set_attr_in_if_block(self, if_node: torch._C.Node): for block in if_node.blocks(): @@ -1545,9 +1541,9 @@ def retrace_as_exported_program( for spec in ep.graph_signature.input_specs: # Mark as constant tensors for erroneously traced buffers. if spec.kind == InputKind.BUFFER and spec.target in name_to_constant: - assert isinstance( - name_to_constant[spec.target], torch.Tensor - ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" + assert isinstance(name_to_constant[spec.target], torch.Tensor), ( + f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" + ) spec.kind = InputKind.CONSTANT_TENSOR spec.persistent = None ep.verifier().check(ep) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 275dde63173a9e..f1d5e79b536a16 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -68,11 +68,77 @@ log = logging.getLogger(__name__) -def key_path_to_source(kp: KeyPath) -> Source: +class _KeyPath: + """ + Wraps `KeyPath` to aid `isinstance` checks. + """ + + def __init__(self, kp: KeyPath): + self.kp = kp + + +class _KeyPathTrie: + """ + Builds a trie of `KeyPath` prefixes mapping to `Source` leaves. + """ + + def __init__(self): + self.root = {} + + def add(self, kp: KeyPath, src: Source): + assert len(kp) > 0 + *path, leaf = kp + node = self.root + for k in path: + if k not in node: + node[k] = {} + node = node[k] + node[leaf] = src + + def get(self, kp: KeyPath) -> tuple[Source, KeyPath]: + node = self.root + while not isinstance(node, Source): + assert len(kp) > 0 + k, *kp = kp # type: ignore[assignment] + node = node[k] + return node, kp + + +def make_sourced_prefixes(nn_module, args, kwargs) -> _KeyPathTrie: + kp_args, kp_kwargs = tree_map_with_path( + lambda kp, _: _KeyPath(kp), + (tuple(None for _ in args), {k: None for k in kwargs}), # noqa: C420 + ) + kp_combined_args = _combine_args(nn_module, kp_args, kp_kwargs) + + sourced_prefixes = _KeyPathTrie() + for name, struct in kp_combined_args.items(): + src = LocalSource(name) + + if isinstance(struct, _KeyPath): + sourced_prefixes.add(struct.kp, src) + elif isinstance(struct, tuple): + for i, prefix in enumerate(struct): + assert isinstance(prefix, _KeyPath) + sourced_prefixes.add(prefix.kp, GetItemSource(src, i)) + elif isinstance(struct, dict): + for k, prefix in struct.items(): + assert isinstance(prefix, _KeyPath) + sourced_prefixes.add(prefix.kp, GetItemSource(src, k)) + + return sourced_prefixes + + +def key_path_to_source( + kp: KeyPath, sourced_prefixes: Optional[_KeyPathTrie] = None +) -> Source: """ Given a key path, return the source for the key path. """ - source: Source = LocalSource("args") + if sourced_prefixes is None: + source: Source = LocalSource("args") + else: + source, kp = sourced_prefixes.get(kp) for k in kp: if isinstance(k, SequenceKey): source = GetItemSource(source, k.idx) @@ -96,13 +162,17 @@ def fakify( t: Any, t_constraints: dict[int, dict[int, Constraint]], sources: dict[tuple[int, int], list[Source]], + sourced_prefixes: Optional[_KeyPathTrie] = None, ): - source = key_path_to_source(kp) + source = key_path_to_source(kp, sourced_prefixes=sourced_prefixes) if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)): return t if isinstance(t, _IntWrapper): - if t.dynamism is not None and t.dynamism.type in (_DimHintType.DYNAMIC, _DimHintType.AUTO): # type: ignore[union-attr] + if t.dynamism is not None and t.dynamism.type in ( # type: ignore[union-attr] + _DimHintType.DYNAMIC, + _DimHintType.AUTO, + ): symint = mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] t.val, source, DimDynamic.DYNAMIC ) @@ -143,9 +213,11 @@ def fakify( constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload] else: dynamic_sizes.append(DimDynamic.STATIC) - symbolic_context = StatelessSymbolicContext( - dynamic_sizes=dynamic_sizes, - constraint_sizes=constraint_sizes, # type: ignore[arg-type] + symbolic_context: StatelessSymbolicContext = ( # make mypy happy + StatelessSymbolicContext( + dynamic_sizes=dynamic_sizes, + constraint_sizes=constraint_sizes, # type: ignore[arg-type] + ) ) t_id = id(t) assert mode.shape_env is not None @@ -258,7 +330,6 @@ def make_fake_inputs( args, kwargs, dynamic_shapes, - _is_torch_jit_trace=False, allow_complex_guards_as_runtime_asserts=False, ): """ @@ -294,7 +365,7 @@ def make_fake_inputs( # a toplevel TracingContext with a fake mode, so we do not want to # create another fake mode. fake_mode = context.fake_mode - elif not _is_torch_jit_trace: + else: if isinstance(nn_module.forward, functools.partial): # functools handles nesting by itself, no need to recurse code = nn_module.forward.func.__code__ @@ -317,17 +388,6 @@ def make_fake_inputs( allow_non_fake_inputs=True, export=True, ) - else: - with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): - fake_mode = FakeTensorMode( - shape_env=ShapeEnv( - tracked_fakes=[], - prefer_deferred_runtime_asserts_over_guards=True, - allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, - trace_asserts=True, - ), - allow_non_fake_inputs=True, - ) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: raise ValueError( "Detected fake_mode does not have a shape_env with tracked fakes. " @@ -336,14 +396,18 @@ def make_fake_inputs( ) with fake_mode: - # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock - if not _is_torch_jit_trace: - original_signature = inspect.signature(nn_module.forward) - else: - original_signature = None + original_signature = inspect.signature(nn_module.forward) sources: dict[tuple[int, int], list[Source]] = defaultdict(list) + sourced_prefixes = make_sourced_prefixes(nn_module, args, kwargs) fake_args, fake_kwargs = tree_map_with_path( - lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), + lambda kp, val: fakify( + fake_mode, + kp, + val, + t_constraints, + sources, + sourced_prefixes=sourced_prefixes, + ), (args, kwargs), ) @@ -413,7 +477,6 @@ def produce_guards_and_solve_constraints( dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], equalities_inputs: EqualityConstraint, original_signature: inspect.Signature, - _is_torch_jit_trace=False, ): """ Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, @@ -454,16 +517,14 @@ def produce_guards_and_solve_constraints( raise constraint_violation_error dim_constraints.solve() forced_specializations = dim_constraints.forced_specializations() - if not _is_torch_jit_trace: - msg = dim_constraints.prettify_results( - original_signature, - dynamic_shapes, # type: ignore[arg-type] - constraint_violation_error, - forced_specializations, # type: ignore[arg-type] - ) - else: - # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod - msg = "dummy constraint violation message" + + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, # type: ignore[arg-type] + constraint_violation_error, + forced_specializations, # type: ignore[arg-type] + ) + if constraint_violation_error: constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) elif forced_specializations: @@ -914,7 +975,8 @@ def _override(self, func, args, kwargs): # because it has some known incompletenesses, e.g., it doesn't support # empty data. See https://github.com/pytorch/pytorch/issues/143216 if any( - isinstance(a, torch.SymInt) for a in pytree.tree_flatten(args[0])[0] + isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)) + for a in pytree.tree_flatten(args[0])[0] ): return torch._refs.tensor, args, kwargs if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor): diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 5fbd97d00f90f9..952e904ca26e0a 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -6,20 +6,23 @@ from typing import Any, Callable, Optional, Union import torch -from torch._higher_order_ops.map import _unstack_pytree from torch import fx from torch._dispatch.python import enable_python_dispatcher from torch._export.pass_infra.node_metadata import NodeMetadata from torch._export.pass_infra.proxy_value import ProxyValue +from torch._higher_order_ops.map import _unstack_pytree from torch._subclasses import FakeTensor, UnsupportedFakeTensorException from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx import traceback as fx_traceback from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + PropagateUnbackedSymInts, +) from torch.fx.graph import CodeGen from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.utils import _pytree as pytree -from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings __all__ = ["_ExportPassBaseDeprecatedDoNotUse"] @@ -56,9 +59,10 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): def _create_dummy_node_metadata(): return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) - class ExportTracer(PythonKeyTracer): - def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: + def __init__( + self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen + ) -> None: super().__init__() self.callback = callback self.root = torch.nn.Module() @@ -92,12 +96,24 @@ def create_arg(self, a: Argument) -> torch.fx.Node: return node def set_metadata( - self, node: torch.fx.Node, value: Argument, + self, + node: torch.fx.Node, + value: Argument, ) -> None: # propagate the fake tensor or sym nodes def make_val( x: Argument, - ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: + ) -> Union[ + FakeTensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + int, + float, + bool, + str, + None, + ]: if isinstance(x, FakeTensor): return x elif isinstance(x, torch.Tensor): @@ -124,7 +140,18 @@ def make_val( ) fake_tensor = None return fake_tensor - elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): + elif isinstance( + x, + ( + torch.SymInt, + torch.SymFloat, + torch.SymBool, + int, + float, + bool, + str, + ), + ): return x else: return None @@ -153,7 +180,9 @@ def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) class ExportInterpreter(fx.Interpreter): - def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: + def __init__( + self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule + ) -> None: super().__init__(gm) self.callback = callback self.node: torch.fx.Node = next(iter(gm.graph.nodes)) @@ -186,13 +215,19 @@ def call_function( if target == operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) - elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}: + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: assert callable(target) return self.callback.call_sym(target, args, meta) elif target in _TORCH_SYM_OPS: assert callable(target) return self.callback.call_sym(target, args, meta) - elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + elif isinstance( + target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) + ): return self.callback.call_operator( target, args, @@ -217,8 +252,11 @@ def call_function( else: raise ExportPassBaseError(f"Unsupported target type: {target}") - def get_attr( - self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] + def get_attr( # type: ignore[override] + self, + target: str, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], ) -> Argument: return super().get_attr(target, args, kwargs) @@ -230,8 +268,11 @@ def call_module( ) -> None: raise ExportPassBaseError("call_module is not supported.") - def call_method( - self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] + def call_method( # type: ignore[override] + self, + target: str, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], ) -> None: raise ExportPassBaseError("call_method is not supported.") @@ -269,7 +310,9 @@ def _fx( if isinstance(target, torch._ops.OpOverload): name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) - res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) + res_proxy = self.tracer.create_proxy( + kind, target, args_proxy, kwargs_proxy, name=name + ) res_proxy.node.meta.update(meta.data) if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): @@ -389,13 +432,17 @@ def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: def call_submodule( self, graph_module: fx.GraphModule, inputs: tuple[Argument, ...] ) -> PassResult: - prev_tracer, self.tracer = self.tracer, self.ExportTracer( - self, graph_module.graph._codegen + prev_tracer, self.tracer = ( + self.tracer, + self.ExportTracer(self, graph_module.graph._codegen), ) self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode interpreter = self.ExportInterpreter(self, graph_module) - prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment] - torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + prev_interpreter, self.interpreter = ( + self.interpreter, + torch.fx.Interpreter( # type: ignore[assignment] + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ), ) inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) with fx_traceback.preserve_node_meta(): @@ -421,9 +468,9 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: fake_tensor_mode = None for i in inputs: if isinstance(i, FakeTensor): - assert ( - fake_tensor_mode is None or fake_tensor_mode is i.fake_mode - ), "Multiple fake tensor mode detected." + assert fake_tensor_mode is None or fake_tensor_mode is i.fake_mode, ( + "Multiple fake tensor mode detected." + ) fake_tensor_mode = i.fake_mode if fake_tensor_mode is None: self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) diff --git a/torch/_export/pass_infra/proxy_value.py b/torch/_export/pass_infra/proxy_value.py index df62c9d0ffe563..40613c1283228b 100644 --- a/torch/_export/pass_infra/proxy_value.py +++ b/torch/_export/pass_infra/proxy_value.py @@ -1,12 +1,13 @@ # pyre-strict -from typing import Union, Generic -from collections.abc import Iterator, Iterable +from collections.abc import Iterable, Iterator +from typing import Generic, TypeVar, Union + import torch -from typing import TypeVar _T = TypeVar("_T") + class ProxyValue(Generic[_T]): # pyre-ignore def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]): diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index 99df6c7fb635f8..bd81f0a92676fa 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -9,10 +9,11 @@ import torch import torch.fx -from torch.utils._sympy.value_ranges import ValueRanges -from torch.utils._sympy.numbers import int_oo from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + __all__ = ["InputDim"] @@ -30,9 +31,7 @@ def _convert_to_int(val): return -math.inf if isinstance(val, sympy.Integer): return int(val) - raise RuntimeError( - "Export constraints cannot be non-integer expressions" - ) + raise RuntimeError("Export constraints cannot be non-integer expressions") def _convert_range_to_int(range: ValueRanges): @@ -55,10 +54,14 @@ def __init__( def _assert_range_constraint(self, node, lower, upper, assert_msg): last_node = node if lower > -math.inf: - last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg) + last_node = self._insert_assert_async( + last_node, operator.ge, node, lower, assert_msg + ) if upper < math.inf: - last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg) + last_node = self._insert_assert_async( + last_node, operator.le, node, upper, assert_msg + ) def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): """ @@ -70,7 +73,9 @@ def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): with graph.inserting_after(last_node): cmp = graph.call_function(op, (lower, upper), {}) with graph.inserting_after(cmp): - cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {}) + cmp_tensor = graph.call_function( + torch.ops.aten.scalar_tensor.default, (cmp,), {} + ) with graph.inserting_after(cmp_tensor): assert_async = graph.call_function( torch.ops.aten._assert_async.msg, @@ -111,7 +116,9 @@ def add_assertions(val): symbol = val.node.expr if symbol in self.existing_inline_assertions: return call_backs, messages - if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols( + symbol + ): if symbol in self._asserts_generated_unbacked_symbols: return call_backs, messages # We only care about unbacked symints for these inline @@ -120,7 +127,11 @@ def add_assertions(val): min_val, max_val = _convert_range_to_int(constraint) assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." call_backs.append( - partial(self._assert_range_constraint, lower=min_val, upper=max_val) + partial( + self._assert_range_constraint, + lower=min_val, + upper=max_val, + ) ) messages.append(assert_msg) self._asserts_generated_unbacked_symbols.add(symbol) @@ -129,6 +140,7 @@ def add_assertions(val): for i, sym in enumerate(val.shape): cbs, msgs = add_assertions(sym) for cb, msg in zip(cbs, msgs): + def sym_size_cb(node, assert_msg, dim): with node.graph.inserting_after(node): dim_node = module.graph.call_function( @@ -137,6 +149,7 @@ def sym_size_cb(node, assert_msg, dim): {}, ) cb(node=dim_node, assert_msg=assert_msg) + call_backs.append(partial(sym_size_cb, dim=i)) messages.append(f".shape[{i}]" + msg) return call_backs, messages @@ -149,12 +162,18 @@ def sym_size_cb(node, assert_msg, dim): # Sometimes this pass would return a wrong graph where we have mismatched # node names in signature. Before we fix it, let's just skip it. - if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: + if ( + self.counter == 0 + and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass + ): return PassResult(graph_module, False) # Populate the stack trace with dummy vals to respect IR for node in graph_module.graph.nodes: - if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]: + if not node.meta.get("stack_trace", None) and node.op not in [ + "placeholder", + "output", + ]: node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) return PassResult(graph_module, True) @@ -179,10 +198,10 @@ def _get_existing_inline_assertions( compare_arg = node.args[0] if not ( - isinstance(compare_arg, torch.fx.Node) and - compare_arg.op == "call_function" and - compare_arg.target in (operator.le, operator.ge) and - len(compare_arg.args) == 2 + isinstance(compare_arg, torch.fx.Node) + and compare_arg.op == "call_function" + and compare_arg.target in (operator.le, operator.ge) + and len(compare_arg.args) == 2 ): continue @@ -191,9 +210,9 @@ def _get_existing_inline_assertions( def maybe_get_symint(x): if ( - isinstance(x, torch.fx.Node) and - "val" in x.meta and - isinstance(x.meta["val"], torch.SymInt) + isinstance(x, torch.fx.Node) + and "val" in x.meta + and isinstance(x.meta["val"], torch.SymInt) ): return x.meta["val"].node.expr return x @@ -214,9 +233,13 @@ def maybe_get_symint(x): continue if symint not in range_constraints: - raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") + raise RuntimeError( + f"Unable to find symint {symint} in {range_constraints}" + ) - previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) + previous_range = existing_inline_assertions.get( + symint, ValueRanges(-math.inf, math.inf) + ) if symint is lhs: bounds = ValueRanges(-math.inf, scalar) diff --git a/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/torch/_export/passes/functionalize_side_effectful_ops_pass.py index c14e859e4ef34a..45dd734c72959c 100644 --- a/torch/_export/passes/functionalize_side_effectful_ops_pass.py +++ b/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -2,15 +2,20 @@ from typing import Optional import torch -from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument +from torch._export.pass_base import ( + _ExportPassBaseDeprecatedDoNotUse, + Argument, + PassResult, +) from torch._export.pass_infra.node_metadata import NodeMetadata from torch._export.pass_infra.proxy_value import ProxyValue from torch._ops import OpOverload + aten = torch.ops.aten _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { - aten.sym_constrain_range.default: aten._functional_sym_constrain_range, + aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default, aten._assert_async.msg: aten._functional_assert_async.msg, } diff --git a/torch/_export/passes/insert_custom_op_guards.py b/torch/_export/passes/insert_custom_op_guards.py index 76ab68dfc08979..4deecdf8182290 100644 --- a/torch/_export/passes/insert_custom_op_guards.py +++ b/torch/_export/passes/insert_custom_op_guards.py @@ -16,12 +16,15 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> """ for node in gm.graph.nodes: if node.op == "call_function" and str(node.target) in ops_to_guard: - with _set_node_metadata_hook( - gm, - functools.partial( - _node_metadata_hook, stack_trace=node.meta.get("stack_trace") + with ( + _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, stack_trace=node.meta.get("stack_trace") + ), ), - ), gm.graph.inserting_before(node): + gm.graph.inserting_before(node), + ): for arg in (*node.args, *node.kwargs.values()): if isinstance(arg, torch.fx.Node) and isinstance( arg.meta.get("val"), torch.Tensor diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index a9a711799f8a54..19d0462cc09dc3 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -373,7 +373,7 @@ def lift_constants_pass( def rewrite_script_object_meta( gm: torch.fx.GraphModule, -) -> dict[str, _ConstantAttributeType,]: +) -> dict[str, _ConstantAttributeType]: """When tracing, we produce a graph with FakeScriptObject in the meta["val"]. diff --git a/torch/_export/serde/aoti_schema.py b/torch/_export/serde/aoti_schema.py deleted file mode 100644 index d19add43705c95..00000000000000 --- a/torch/_export/serde/aoti_schema.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from torch._export.serde.schema import Node - - -@dataclass -class ExternKernelNode: - name: str - node: Node - - -@dataclass -class ExternKernelNodes: - nodes: list[ExternKernelNode] diff --git a/torch/_export/serde/dynamic_shapes.py b/torch/_export/serde/dynamic_shapes.py index 10061ecb901e8d..e6a0295163dcb3 100644 --- a/torch/_export/serde/dynamic_shapes.py +++ b/torch/_export/serde/dynamic_shapes.py @@ -107,20 +107,20 @@ def _dump_dynamic_shapes( would generate the following output: ``` { - 'dynamic_shapes': ( + "dynamic_shapes": ( [ - ['dx', 4], - ['dx + 1', 4], + ["dx", 4], + ["dx + 1", 4], ], - ['_DimHint.STATIC'], - ['_DimHint.STATIC', '_DimHint.STATIC'], + ["_DimHint.STATIC"], + ["_DimHint.STATIC", "_DimHint.STATIC"], None, ), - 'dims': { - 'dx': { - 'min': 4, - 'max': 16, - 'derived': ['dx + 1'], + "dims": { + "dx": { + "min": 4, + "max": 16, + "derived": ["dx + 1"], }, }, } @@ -149,7 +149,7 @@ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] return out def _track_dim_from_dims( - val: Union[None, int, _DimHint, Dim] + val: Union[None, int, _DimHint, Dim], ) -> Union[None, int, str]: """ Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. @@ -295,7 +295,7 @@ def _load_dynamic_shapes( dim_cache[_expr] = ddim # cache derived dims def deserialize_shape( - val: Union[None, int, str] + val: Union[None, int, str], ) -> Union[None, int, Dim, _DimHint]: if val is None or isinstance(val, int): return val diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 949e18e56f24a5..ef1ec3ee49819d 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -5,7 +5,7 @@ from enum import IntEnum from typing import Annotated, Optional -from torch._export.serde.union import _Union +from torch._export.serde.union import _Union, _union_dataclass # NOTE: Please update this value if any modifications are made to the schema @@ -60,7 +60,7 @@ class Device: index: Annotated[Optional[int], 20] = None -@dataclass(repr=False) +@_union_dataclass class SymExprHint(_Union): as_int: Annotated[int, 10] as_bool: Annotated[bool, 20] @@ -77,19 +77,19 @@ class SymExpr: hint: Annotated[Optional[SymExprHint], 20] = None -@dataclass(repr=False) +@_union_dataclass class SymInt(_Union): as_expr: Annotated[SymExpr, 10] as_int: Annotated[int, 20] -@dataclass(repr=False) +@_union_dataclass class SymFloat(_Union): as_expr: Annotated[SymExpr, 10] as_float: Annotated[float, 20] -@dataclass(repr=False) +@_union_dataclass class SymBool(_Union): as_expr: Annotated[SymExpr, 10] as_bool: Annotated[bool, 20] @@ -112,7 +112,7 @@ class TensorMeta: # of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to # be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints # to the "as_int" field. -@dataclass(repr=False) +@_union_dataclass class SymIntArgument(_Union): as_name: Annotated[str, 10] as_int: Annotated[int, 20] @@ -124,7 +124,7 @@ class SymIntArgument(_Union): # of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to # be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints # to the "as_float" field. -@dataclass(repr=False) +@_union_dataclass class SymFloatArgument(_Union): as_name: Annotated[str, 10] as_float: Annotated[float, 20] @@ -136,7 +136,7 @@ class SymFloatArgument(_Union): # of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to # be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools # to the "as_bool" field. -@dataclass(repr=False) +@_union_dataclass class SymBoolArgument(_Union): as_name: Annotated[str, 10] as_bool: Annotated[bool, 20] @@ -156,7 +156,7 @@ class TokenArgument: # (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the # type List[OptionalTensorArgument], with tensor values seiralized to the # "as_tensor" field, and None values serialized to the "as_none" field. -@dataclass(repr=False) +@_union_dataclass class OptionalTensorArgument(_Union): as_tensor: Annotated[TensorArgument, 20] as_none: Annotated[bool, 10] @@ -175,7 +175,7 @@ class CustomObjArgument: # This is actually a union type -@dataclass(repr=False) +@_union_dataclass class Argument(_Union): as_none: Annotated[bool, 10] as_tensor: Annotated[TensorArgument, 20] @@ -253,7 +253,7 @@ class UserInputSpec: arg: Annotated[Argument, 10] -@dataclass(repr=False) +@_union_dataclass class ConstantValue(_Union): as_none: Annotated[bool, 10] as_int: Annotated[int, 20] @@ -298,7 +298,7 @@ class InputTokenSpec: arg: Annotated[TokenArgument, 10] -@dataclass(repr=False) +@_union_dataclass class InputSpec(_Union): user_input: Annotated[UserInputSpec, 10] parameter: Annotated[InputToParameterSpec, 20] @@ -348,7 +348,7 @@ class OutputTokenSpec: arg: Annotated[TokenArgument, 10] -@dataclass(repr=False) +@_union_dataclass class OutputSpec(_Union): user_output: Annotated[UserOutputSpec, 10] loss_output: Annotated[LossOutputSpec, 20] diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index c976b9f13b85e3..0c6c57c648bdc1 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -129,13 +129,13 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: t, cpp_type, thrift_type = dump_type(f.type, 0) ret = {"type": t} cpp_default: Optional[str] = None - assert ( - typing.get_origin(f.type) == Annotated - ), f"Field {f.name} must be annotated with an integer id." + assert typing.get_origin(f.type) == Annotated, ( + f"Field {f.name} must be annotated with an integer id." + ) thrift_id = f.type.__metadata__[0] - assert ( - type(thrift_id) is int - ), f"Field {f.name} must be annotated with an integer id." + assert type(thrift_id) is int, ( + f"Field {f.name} must be annotated with an integer id." + ) value = dataclasses.MISSING if f.default is not dataclasses.MISSING: @@ -173,9 +173,7 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: def _handle_int_enum(name, ty): yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} - cpp_enum_defs[ - name - ] = f""" + cpp_enum_defs[name] = f""" enum class {name} {{ {chr(10).join([f" {x.name} = {x.value}," for x in ty])} }}; @@ -240,14 +238,17 @@ def accessor(name, ty): from_json_def = f"""{{ {name} nlohmann_json_default_obj; -{chr(10).join( - [f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' - for name, f in cpp_fields.items()])} +{ + chr(10).join( + [ + f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items() + ] + ) + } }} """ - cpp_class_defs[ - name - ] = f""" + cpp_class_defs[name] = f""" class {name} {{ private: {field_decls} @@ -262,9 +263,7 @@ class {name} {{ cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") cpp_type_decls.append(f"class {name};") - thrift_type_defs[ - name - ] = f""" + thrift_type_defs[name] = f""" struct {name} {{ {chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} }}""" @@ -307,9 +306,7 @@ def accessor(name, ty, idx): ] ) - cpp_class_defs[ - name - ] = f""" + cpp_class_defs[name] = f""" class {name} {{ struct Void {{}}; @@ -352,9 +349,7 @@ class {name} {{ """ cpp_type_decls.append(f"class {name};") - thrift_type_defs[ - name - ] = f""" + thrift_type_defs[name] = f""" union {name} {{ {chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} }}""" diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index c9e2acca7d701e..9db34168512829 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -322,9 +322,9 @@ def _reconstruct_fake_tensor( json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) # Find the current fake mode - assert ( - _CURRENT_DESERIALIZER is not None - ), "Need access to current deserializer state" + assert _CURRENT_DESERIALIZER is not None, ( + "Need access to current deserializer state" + ) fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) if is_parameter: fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] @@ -337,9 +337,9 @@ def serialize_torch_artifact( if artifact is None: return b"" - assert ( - FakeTensor not in copyreg.dispatch_table - ), "Refusing to stomp on existing FakeTensor reducer" + assert FakeTensor not in copyreg.dispatch_table, ( + "Refusing to stomp on existing FakeTensor reducer" + ) try: copyreg.pickle(FakeTensor, _reduce_fake_tensor) buffer = io.BytesIO() @@ -356,7 +356,7 @@ def serialize_torch_artifact( def deserialize_torch_artifact( - serialized: Union[dict[str, Any], tuple[Any, ...], bytes] + serialized: Union[dict[str, Any], tuple[Any, ...], bytes], ): if isinstance(serialized, (dict, tuple)): return serialized @@ -415,7 +415,7 @@ def _symbol_index(sym: sympy.Symbol, sym_type: SymT): def serialize_range_constraints( - range_constraints: dict[sympy.Symbol, ValueRanges] + range_constraints: dict[sympy.Symbol, ValueRanges], ) -> dict[str, RangeConstraint]: return { str(k): RangeConstraint( @@ -499,9 +499,9 @@ def handle_placeholder(self, node: torch.fx.Node): graph_input = Argument.create( as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) ) - self.graph_state.custom_obj_values[ - node.name - ] = self.serialize_script_obj_meta(val) + self.graph_state.custom_obj_values[node.name] = ( + self.serialize_script_obj_meta(val) + ) else: raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") self.graph_state.inputs.append(graph_input) @@ -627,9 +627,9 @@ def serialize_tensor_list_output(node): ) elif type(node.target) in _serialization_registry: # Sanity check for unhandled serialization. - assert ( - type(node.target) in _serialization_registry - ), f"{type(node.target)} is not supported in export serialization." + assert type(node.target) in _serialization_registry, ( + f"{type(node.target)} is not supported in export serialization." + ) handler = _serialization_registry[type(node.target)] namespace = handler.namespace() @@ -1295,9 +1295,9 @@ def store_namedtuple_fields(ts): f"but somehow previously was found to have field names {field_names}." ) else: - self.treespec_namedtuple_fields[ - serialized_type_name - ] = NamedTupleDef(field_names=ts.context._fields) + self.treespec_namedtuple_fields[serialized_type_name] = ( + NamedTupleDef(field_names=ts.context._fields) + ) for child in ts.children_specs: store_namedtuple_fields(child) @@ -1516,9 +1516,9 @@ def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: idx_to_name = {} for user in node.users: - assert ( - user.target is operator.getitem - ), f"User node {user} of {node} is incorrect" + assert user.target is operator.getitem, ( + f"User node {user} of {node} is incorrect" + ) idx_to_name[user.args[1]] = user.name for idx, _ in enumerate(meta_val): @@ -1713,6 +1713,9 @@ def deserialize_operator(self, serialized_target: str): elif serialized_target.startswith("torch"): module = torch # type: ignore[misc] serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("math"): + module = math # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] elif serialized_target.startswith("#"): return self.deserialize_extension_operator(serialized_target) else: # TODO(zhxchen17) Don't catch all here. @@ -2495,6 +2498,7 @@ def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): len(serialized_node.outputs) == 1 and "torch.ops.higher_order" in serialized_node.target and not getattr(serialized_node, "is_hop_single_tensor_return", True) + and serialized_node.outputs[0].type != "as_none" ): def _deserialize_hop_with_single_return(serialized_node, fx_node): @@ -3525,9 +3529,9 @@ def register_extension( extension_handler: type[ExtensionHandler], ): """Register custom de/serialization method for a node with non-standard type.""" - assert issubclass( - extension_handler, ExtensionHandler - ), f"Expected ExtensionHandler, got {extension_handler}." + assert issubclass(extension_handler, ExtensionHandler), ( + f"Expected ExtensionHandler, got {extension_handler}." + ) assert op_type not in _serialization_registry, f"{op_type} is already registered." assert isinstance(op_type, type) # Maybe a good idea to enforce this first. assert not ( diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index ca8a87951ea965..e0ca90dbad1ad7 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -1,7 +1,12 @@ # mypy: allow-untyped-defs import functools from collections.abc import Hashable -from dataclasses import fields +from dataclasses import dataclass, fields +from typing import TypeVar +from typing_extensions import dataclass_transform + + +T = TypeVar("T", bound="_Union") class _UnionTag(str): @@ -18,9 +23,9 @@ def create(t, cls): def __eq__(self, cmp) -> bool: assert isinstance(cmp, str) other = str(cmp) - assert other in _get_field_names( - self._cls - ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + assert other in _get_field_names(self._cls), ( + f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + ) return str(self) == other def __hash__(self): @@ -32,6 +37,18 @@ def _get_field_names(cls) -> set[str]: return {f.name for f in fields(cls)} +# If you turn a schema class that inherits from union into a dataclass, please use +# this decorator to configure it. It's safe, faster and allows code sharing. +# +# For example, _union_dataclass customizes the __eq__ method to only check the type +# and value property instead of default implmentation of dataclass which goes +# through every field in the dataclass. +@dataclass_transform(eq_default=False) +def _union_dataclass(cls: type[T]) -> type[T]: + assert issubclass(cls, _Union), f"{cls} must inheirt from {_Union}." + return dataclass(repr=False, eq=False)(cls) + + class _Union: _type: _UnionTag @@ -43,7 +60,10 @@ def create(cls, **kwargs): return obj def __post_init__(self): - assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] + assert not any( + f.name in ("type", "_type", "create", "value") + for f in fields(self) # type: ignore[arg-type, misc] + ) @property def type(self) -> str: @@ -64,6 +84,11 @@ def __getattribute__(self, name): raise AttributeError(f"Field {name} is not set.") return attr + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Union): + return False + return self.type == other.type and self.value == other.value + def __str__(self): return self.__repr__() diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 159ff2dabadf8d..3117d7322340f5 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -448,9 +448,9 @@ def register_dataclass_as_pytree_node( from_dumpable_context: Optional[FromDumpableContextFn] = None, return_none_fields: bool = False, ) -> None: - assert dataclasses.is_dataclass( - cls - ), f"Only dataclasses can be registered with this function: {cls}" + assert dataclasses.is_dataclass(cls), ( + f"Only dataclasses can be registered with this function: {cls}" + ) def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: flattened = [] @@ -644,11 +644,14 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: continue if (tensor_val := node.args[0].meta.get("val")) is not None: - with gm.graph.inserting_before(node), _set_node_metadata_hook( - gm, - functools.partial( - _node_metadata_hook, - stack_trace=node.meta.get("stack_trace"), + with ( + gm.graph.inserting_before(node), + _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + ), ), ): gm.graph.call_function( @@ -1342,6 +1345,7 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: import torch + class Module(torch.nn.Module): def __init__(self): super().__init__() @@ -1350,12 +1354,15 @@ def __init__(self): def forward(self, x): return self.linear(x) + torch._export.utils.register_module_as_pytree_node(InputDataClass) + class Mod(torch.nn.Module): def forward(self, x, m): return m(x) + x + ep = torch.export.export(Mod(), (torch.randn(3), Module())) print(ep) diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index 82e3a6a2545840..47f736de303d2e 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -195,9 +195,9 @@ def wrapper(*args, **kwargs): for mode in torch_function_mode_stack if isinstance(mode, PreDispatchTorchFunctionMode) ] - assert ( - len(pre_dispatch_tf_modes) <= 1 - ), f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}" + assert len(pre_dispatch_tf_modes) <= 1, ( + f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}" + ) if len(pre_dispatch_tf_modes) == 0: return diff --git a/torch/_functorch/_activation_checkpointing/graph_info_provider.py b/torch/_functorch/_activation_checkpointing/graph_info_provider.py index d92b3728f543bb..2a5da58fdd6330 100644 --- a/torch/_functorch/_activation_checkpointing/graph_info_provider.py +++ b/torch/_functorch/_activation_checkpointing/graph_info_provider.py @@ -96,9 +96,9 @@ def inialize_from_graph( @property def recomputable_node_only_graph(self) -> nx.DiGraph: if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None: - self._lazily_initialized_graphs[ - self.__RECOMPUTABLE_NODE_ONLY_GRAPH - ] = self._create_recomputable_node_only_graph() + self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] = ( + self._create_recomputable_node_only_graph() + ) return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] @property @@ -119,17 +119,17 @@ def recomputable_node_only_graph_with_larger_graph_context(self) -> nx.DiGraph: @property def full_joint_nx_graph(self) -> nx.DiGraph: if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None: - self._lazily_initialized_graphs[ - self.__FULL_NX_JOINT_GRAPH - ] = self._create_full_joint_graph() + self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] = ( + self._create_full_joint_graph() + ) return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] @property def simplified_fx_joint_graph(self) -> Graph: if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None: - self._lazily_initialized_graphs[ - self.__SIMPLIFIED_FX_JOINT_GRAPH - ] = self._recreate_psuedo_joint_graph() + self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] = ( + self._recreate_psuedo_joint_graph() + ) return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] def get_non_ac_peak_memory(self) -> float: @@ -285,9 +285,7 @@ def _visualize_recomputable_candidate_graph_with_larger_context( float( self.recomputable_node_only_graph_with_larger_graph_context.nodes[ node - ][ - "memory" - ] + ]["memory"] ) ) ) diff --git a/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py b/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py index 666a0d1e1f8fc0..7cc60f6ed54bec 100644 --- a/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py +++ b/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py @@ -1,3 +1,4 @@ +import operator from collections import deque from typing import Callable @@ -165,7 +166,7 @@ def evaluate_knapsack_output( for i in saved_nodes_idxs ), ) - peak_memory = max(memory_list, key=lambda x: x[0])[0] + peak_memory = max(memory_list, key=operator.itemgetter(0))[0] else: peak_memory = sum( self._graph_info_provider.all_node_memories[ diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 68759cbaaab4e6..8a7dd901724c3a 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -2,6 +2,7 @@ """ Utils for caching the outputs of AOTAutograd """ + from __future__ import annotations import base64 @@ -21,6 +22,7 @@ from typing_extensions import override import torch +from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions from torch._dynamo.utils import ( chromium_event_log_active, @@ -63,6 +65,7 @@ from .runtime_wrappers import ( AOTDispatchAutograd, AOTDispatchSubclassWrapper, + CachedAutogradLazyBackwardCompileInfo, CompilerWrapper, FunctionalizedRngRuntimeWrapper, post_compile, @@ -423,16 +426,13 @@ class InductorOutput(Generic[TOut], ABC): """ @abstractmethod - def pre_save(self) -> None: - ... + def pre_save(self) -> None: ... @abstractmethod - def load(self, example_inputs) -> TOut: - ... + def load(self, example_inputs) -> TOut: ... @abstractmethod - def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: - ... + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... @dataclass @@ -471,7 +471,6 @@ def post_compile( }, payload_fn=lambda: json.dumps(cache_info), ) - counters["inductor"]["fxgraph_cache_hit"] += 1 # Run normal post compile graph.post_compile(self.example_inputs, constants, fx_config) return graph @@ -589,6 +588,17 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa def _is_backward(self) -> bool: return True + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + # Forward types don't have any extra parameters, so this is just a TypeAlias, in essence class BundledCompiledForward(CompiledFxGraphLoadable): @@ -599,7 +609,38 @@ class BundledCompiledForward(CompiledFxGraphLoadable): class BundledCompiledBackward( GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable ): - pass + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + + +@dataclass +class SerializedGraphModule: + fn: Callable[[dict[Any, Any], str], torch.nn.Module] + args: tuple[Any, ...] + + def __init__(self, gm: torch.fx.GraphModule): + self.fn, self.args = gm.__reduce__() + + def deserialize(self) -> torch.fx.GraphModule: + gm = self.fn(*self.args) + assert isinstance(gm, torch.fx.GraphModule) + return gm + + +def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: + # NOTE: mutates the graph module + gm.meta = {} + for node in gm.graph.nodes: + node.meta = {} + return SerializedGraphModule(gm) TForward = TypeVar("TForward", bound=InductorOutput) @@ -655,6 +696,9 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]): guards_expr: Optional[str] + # Used by Compiled Autograd + serialized_bw_module: Optional[SerializedGraphModule] + def pre_save(self): """ Perform any preparations to make the cache entry ready for serialization. @@ -796,6 +840,12 @@ def wrap_post_compile( if needs_autograd: assert self.compiled_bw is not None + + cached_lazy_backward = None + if self.serialized_bw_module is not None: + cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( + self.serialized_bw_module.deserialize + ) # This function is run on both cache miss and cache hit, either here # or in aot_dispatch_autograd. On a cache hit, # 1. the bw is already compiled @@ -809,7 +859,7 @@ def wrap_post_compile( self.compiled_bw.backward_state_indices, disable_amp, self.indices_of_inps_to_detach, - None, # lazy_backward_info + cached_lazy_backward, aot_config, fw_metadata=self.runtime_metadata, try_save_cache_entry=None, @@ -874,6 +924,7 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule): "meta", # metadata used by export "compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior "_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id", ) saved_fields = {} for field in IGNORED_FIELDS: @@ -900,6 +951,35 @@ def type(): return "aot_autograd" +@CacheArtifactFactory.register +class BundledAOTAutogradCacheArtifact(PrecompileCacheArtifact[Callable]): + @override + @staticmethod + def type(): + return "precompile_aot_autograd" + + @override + def after_deserialization(self) -> Callable: + entry = pickle.loads(self.content) + # In the precompile use case, guards are already serialized + # by dynamo, so we don't need to add them to the environment + entry.guards_expr = None + # TODO: this isn't exactly right, because cudagraphs needs to be a shared config + # which is set by compile_fx. But in precompile, we never actually call compile_fx + # so we don't have a place to track cudagraphs here. + cudagraphs = torch._inductor.config.triton.cudagraphs + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, {"cudagraphs": cudagraphs} + ) + + # TODO: this ignores flat_params, which can exist + # if inline_builtin_nn_modules=False + def forward(*runtime_args: tuple[Any]): + return compiled_fn(list(runtime_args)) + + return forward + + class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): """ Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas @@ -974,9 +1054,11 @@ def load( cache_key, debug_lines = autograd_cache_key( gm, args, aot_config, fx_config ) - entry: Optional[ - GenericAOTAutogradCacheEntry - ] = AOTAutogradCache._lookup(cache_key, local, remote, args, cache_info) + entry: Optional[GenericAOTAutogradCacheEntry] = ( + AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config + ) + ) if entry is not None: compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) log.info("AOTAutograd cache hit for key %s", cache_key) @@ -1000,9 +1082,8 @@ def load( # FXGraphCache and AOTAutogradCache? # get_metrics_context().increment(...) if ( - ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( - time_saved_ns - ) + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) ) != 0: cache_info["ephemeral_timeout_increase"] = ephemeral_increase @@ -1052,7 +1133,9 @@ def load( symints = AOTAutogradCache._filter_backed_symints(args) if cache_key is not None: aot_config.cache_info = AOTAutogradCacheInfo( - cache_key, time.time_ns(), forward_symints=symints + cache_key, + time.time_ns(), + forward_symints=symints, ) compiled_fn = dispatch_and_compile() @@ -1125,7 +1208,12 @@ def evaluate_guards(guard_expr: str, hints: Union[list[int], list[torch.SymInt]] @staticmethod def _lookup( - key: str, local: bool, remote: bool, args: list[Any], cache_info: dict[str, Any] + key: str, + local: bool, + remote: bool, + args: list[Any], + cache_info: dict[str, Any], + aot_config: Optional[AOTConfig], ) -> Optional[GenericAOTAutogradCacheEntry]: """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" remote_cache: Optional[RemoteCache[JsonDataTy]] = None @@ -1151,6 +1239,19 @@ def _lookup( CacheArtifactManager.record_artifact( AOTAutogradCacheArtifact.type(), key, pickled_content ) + if ( + config.bundled_autograd_cache + and aot_config is not None + and aot_config.precompile_backend_id is not None + ): + # NB: We don't want to use the cached aot_config.precompile_backend_id + # 1. because we set it to None on save 2. even if we didn't, this new run + # that cache hit has a *new* backend id associated with it. + PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact.type(), + aot_config.precompile_backend_id, + pickled_content, + ) except Exception as e: log.info("AOTAutograd cache unable to load compiled graph: %s", e) if config.strict_autograd_cache: @@ -1180,6 +1281,17 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): CacheArtifactManager.record_artifact( AOTAutogradCacheArtifact.type(), key, content ) + if ( + config.bundled_autograd_cache + and entry.sanitized_aot_config.precompile_backend_id is not None + ): + precompile_key = entry.sanitized_aot_config.precompile_backend_id + # Now that we're saving it, the precompile_backend_id field is no longer + # useful, remove it from the entry. + entry.sanitized_aot_config.precompile_backend_id = None + PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact.type(), precompile_key, content + ) AOTAutogradCache._write_to_local_cache(key, content) counters["aot_autograd"]["autograd_cache_saved"] += 1 except BypassAOTAutogradCache as e: @@ -1199,9 +1311,9 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): return None if remote: - remote_cache: Optional[ - RemoteCache[JsonDataTy] - ] = AOTAutogradCache.get_remote_cache() + remote_cache: Optional[RemoteCache[JsonDataTy]] = ( + AOTAutogradCache.get_remote_cache() + ) if remote_cache is not None: time_taken_ms = int( (entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6 @@ -1213,7 +1325,7 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): remote_cache.put(key, cache_data) @staticmethod - @functools.lru_cache(None) + @functools.cache def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: """ Attempts to load the remote cache, returns None on error. @@ -1244,6 +1356,7 @@ def make_entry( guards_expr: Optional[str], backward_state_indices: Optional[list[int]], num_symints_saved_for_bw: Optional[int], + serialized_bw_module: Optional[SerializedGraphModule], ) -> GenericAOTAutogradCacheEntry: if config.bundled_autograd_cache: # Helper function to unwrap all the wrappers we added during aotdispatch @@ -1280,6 +1393,7 @@ def unwrap_compiled_fx_graph(obj): backward_time_taken_ns=backward_time_taken_ns, sanitized_aot_config=sanitized_aot_config, guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, ) else: @@ -1324,4 +1438,5 @@ def unwrap_compiled_fx_graph(obj): backward_time_taken_ns=backward_time_taken_ns, sanitized_aot_config=sanitized_aot_config, guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, ) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index e128901d39adde..db5075c144b61f 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -571,9 +571,9 @@ def inner(*flat_args): output_type = ( OutputType.alias_of_intermediate_save_as_output ) - intermediate_base_tensor_id_to_output_idx[ - id(o._base) - ] = new_out_idx + intermediate_base_tensor_id_to_output_idx[id(o._base)] = ( + new_out_idx + ) intermediate_bases.append(o._base) elif ( # See https://github.com/pytorch/pytorch/issues/100348 for this case. diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 3382cb102dcadb..be3226ca01f577 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -46,11 +46,14 @@ def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule: # FunctionalTensorMode must be enabled here. # See Note [Accessing .grad_fn on FunctionalTensor] - with enable_python_dispatcher(), FunctionalTensorMode( - pre_dispatch=aot_config.pre_dispatch, - export=aot_config.is_export, - # Allow token discovery for joint fn tracing as tokens can be used in backward. - _allow_token_discovery=True, + with ( + enable_python_dispatcher(), + FunctionalTensorMode( + pre_dispatch=aot_config.pre_dispatch, + export=aot_config.is_export, + # Allow token discovery for joint fn tracing as tokens can be used in backward. + _allow_token_discovery=True, + ), ): fx_g = make_fx( f, @@ -238,9 +241,9 @@ def aot_dispatch_base_graph( # TODO: should factor this into a separate function for export that always only returns just the graph. if aot_config.is_export: - assert ( - maybe_subclass_meta is None - ), "aot_export_module does not support tensor subclass inputs for now." + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta @@ -265,6 +268,7 @@ def aot_dispatch_autograd_graph( fw_metadata, ) joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config) + joint_fn_handle = joint_fn_to_trace.handle joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn( joint_fn_to_trace, @@ -272,6 +276,7 @@ def aot_dispatch_autograd_graph( meta=fw_metadata, aot_config=aot_config, trace_joint=True, + joint_fn_handle=joint_fn_handle, ) # TODO: replace with AOTDispatchSubclassWrapper once we refactor @@ -332,7 +337,7 @@ def aot_dispatch_autograd_graph( # when we need to manually detach() some inputs in the forward. # Higher order ops might eventually need to do the same. if aot_config.is_export: - assert ( - maybe_subclass_meta is None - ), "aot_export_module does not support tensor subclass inputs for now." + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) return fx_g, saved_updated_joint_inputs, maybe_subclass_meta diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index b07477a424a107..e208fa4f6a4438 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -6,6 +6,7 @@ 3. regenerating/replaying views from their base 4. checking if a graph is functional i.e. whether it contains any mutation ops """ + from __future__ import annotations from dataclasses import dataclass @@ -452,14 +453,14 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): - assert ( - n.args[0] in placeholders - ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + assert n.args[0] in placeholders, ( + f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + ) mutation_count += 1 else: - assert ( - not n.target._schema.is_mutable - ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + assert not n.target._schema.is_mutable, ( + f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + ) return mutation_count @@ -472,9 +473,9 @@ def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: if n.target is torch.ops.aten.copy_.default: # Can only copy_ into an input, and can only do so once if "set_buffer_donor_" not in str(n.args[0]): - assert ( - n.args[0] in placeholders - ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + assert n.args[0] in placeholders, ( + f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + ) placeholders.remove(n.args[0]) copy_from_node = n.args[1] # Pre-condition: every node has a "stack_trace" field in its meta, diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 62470e3b683a9f..3078c25331026f 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -300,6 +300,21 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices): ] ) + if torch._inductor.config.is_fbcode(): + if symbolic and num_aliases > 400: + from torch._subclasses.fake_tensor import ( + UnsupportedMutationAliasingException, + ) + from torch._utils_internal import justknobs_check + + msg = f"Encountered {num_aliases} dynamic, aliased/mutated inputs, consider setting dynamic=False" + + if justknobs_check( + "pytorch/compiler:aliased_inputs_with_mutation_and_dyn_shapes_killswitch", + False, + ): + raise UnsupportedMutationAliasingException(msg) + with maybe_suppress_guards(): aliased_fwd_inputs = [fwd_inputs[i] for i in aliased_input_indices] actual_aliased_indices = { diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 266581e966e7b6..1f9f9631874ec4 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -40,7 +40,11 @@ from torchgen.utils import dataclass_repr from .. import config -from .autograd_cache import AOTAutogradCache, should_use_remote_autograd_cache +from .autograd_cache import ( + AOTAutogradCache, + serialize_graph_module, + should_use_remote_autograd_cache, +) from .dispatch_and_compile_graph import ( aot_dispatch_autograd_graph, aot_dispatch_base_graph, @@ -157,6 +161,7 @@ def sanitize_aot_config(input: AOTConfig) -> AOTConfig: static_input_indices=input.static_input_indices, pre_dispatch=input.pre_dispatch, cache_info=None, + precompile_backend_id=input.precompile_backend_id, ) @@ -255,8 +260,15 @@ def aot_dispatch_base( compiled_fw, aot_config, runtime_metadata=fw_metadata ) cache_info = aot_config.cache_info + + def should_save_cache(): + if torch._functorch.config.bundled_autograd_cache: + return True + else: + return hasattr(compiled_fw, "_fx_graph_cache_key") + if cache_info is not None: - if hasattr(compiled_fw, "_fx_graph_cache_key"): + if should_save_cache(): time_taken_ns = time.time_ns() - cache_info.start_time_ns guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) entry = AOTAutogradCache.make_entry( @@ -276,6 +288,7 @@ def aot_dispatch_base( guards_expr=guards_expr, backward_state_indices=None, num_symints_saved_for_bw=None, + serialized_bw_module=None, ) AOTAutogradCache.save( cache_info.cache_key, entry, remote=should_use_remote_autograd_cache() @@ -759,6 +772,11 @@ def propagate_meta_info(new_hop_gm, new_call_function_node, old_call_function_no ), ) propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node) + # Since the partitioner is run after the graph passes, we have lost + # the eager information and cannot faithfully extract the eager + # inputs for the new partitioned backward graph. For the forward + # graph, it was fine because the input signature remains same. + new_bw_node.meta.pop("eager_input_vals", None) bw_node.replace_all_uses_with(new_bw_node) joint_gm.graph.erase_node(bw_node) @@ -812,9 +830,9 @@ def create_wrap_fn(fn, args): from .functional_utils import from_fun, has_data_mutation, to_fun def assert_no_mutation(t): - assert not has_data_mutation( - t - ), "Saved tensors hooks with inputs mutations are not allowed" + assert not has_data_mutation(t), ( + "Saved tensors hooks with inputs mutations are not allowed" + ) @wraps(fn) def _wrapper(*args): @@ -1099,9 +1117,11 @@ def find_saved_in_bw_inputs(bw_inputs): # Inserting packed sym scalars before first saved tensor input. # Inserting packed tensors before last saved tensor input. # Saved tensor inputs between them will be removed. - with bw_g.inserting_before( - bw_g_inputs[0] - ) if is_sym else bw_g.inserting_before(bw_g_input): + with ( + bw_g.inserting_before(bw_g_inputs[0]) + if is_sym + else bw_g.inserting_before(bw_g_input) + ): new_n = bw_g.placeholder(new_node_name) assert new_n.name == new_node_name new_n.meta = copy.copy(out_n.meta) @@ -1753,12 +1773,22 @@ def aot_dispatch_autograd( # close over aot_config.cache_info, since aot_config never changes. # But closing over random variables is confusing IMO, so I'm leaving it. def try_save_cache_entry( # noqa: F811 - compiled_bw_func, _fw_metadata, aot_config + compiled_bw_func: Callable, + bw_module: torch.fx.GraphModule, + _fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, ): - fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) - bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) cache_info = aot_config.cache_info - if cache_info is not None and fw_key and bw_key: + + def should_save_cache(): + if torch._functorch.config.bundled_autograd_cache: + return True + else: + return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr( + compiled_bw_func, "_fx_graph_cache_key" + ) + + if cache_info is not None and should_save_cache(): assert forward_time_taken_ns is not None # TODO: technically, AOTAutograd does a *little* bit of post processing work # in the backward that isn't measured here. But it's small enough that it's not worth @@ -1775,7 +1805,7 @@ def try_save_cache_entry( # noqa: F811 entry = AOTAutogradCache.make_entry( compiled_fw_func, # type: ignore[arg-type] - compiled_bw_func, + compiled_bw_func, # type: ignore[arg-type] aot_joint_graph_str, aot_forward_graph_str, aot_backward_graph_str, @@ -1790,13 +1820,14 @@ def try_save_cache_entry( # noqa: F811 guards_expr=guards_expr, backward_state_indices=backward_state_indices, num_symints_saved_for_bw=num_symints_saved_for_bw, + serialized_bw_module=serialize_graph_module(bw_module), ) remote = should_use_remote_autograd_cache() AOTAutogradCache.save(cache_info.cache_key, entry, remote) if compiled_bw_func is not None: - # If we already compiled it we can just run it right now without waiting - try_save_cache_entry(compiled_bw_func, fw_metadata, aot_config) + # If we already compiled the backward, we save its cache entry now + try_save_cache_entry(compiled_bw_func, bw_module, fw_metadata, aot_config) try_save_cache_entry = None compiled_fn = AOTDispatchAutograd.post_compile( diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 8fcbb71ba77d6c..4f3eeb20a5b120 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -6,13 +6,14 @@ 3. handle functionalized randomness 4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) """ + import builtins import collections import contextlib import copy import itertools import pprint -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext from dataclasses import dataclass, field from functools import wraps from typing import Any, Callable, Optional, TYPE_CHECKING, Union @@ -20,6 +21,8 @@ import torch import torch.utils.dlpack from torch import Tensor +from torch._dynamo import config as dynamo_config +from torch._dynamo.callback import callback_handler, CallbackTrigger from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context from torch._guards import ( compile_context, @@ -316,7 +319,30 @@ def _create_runtime_wrapper( for info in runtime_metadata.output_info ) + def record_runtime_wrapper_prologue_enter() -> Optional[ + AbstractContextManager[None] + ]: + if ( + torch.autograd.profiler._is_profiler_enabled + and dynamo_config.record_runtime_overhead + ): + cm = torch._C._profiler._RecordFunctionFast( + "AOTDispatcher Runtime Wrapper Prologue" + ) + cm.__enter__() + return cm + return None + + def record_runtime_wrapper_prologue_exit( + cm: Optional[AbstractContextManager[None]], + ) -> None: + if cm is not None: + cm.__exit__(None, None, None) + def runtime_wrapper(args: list[Any]): + # Create context manager for profiler + cm = record_runtime_wrapper_prologue_enter() + # stash a ref to each input tensor we plan to use after the compiled function orig_inputs = {i: args[i] for i in epilogue_args_idx} @@ -337,9 +363,11 @@ def runtime_wrapper(args: list[Any]): # It's possible to have trace_joint inside user specified with no_grad() region, # if there is a nested with enable_grad(), that forces some outputs to require gradients. # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. - with torch.autograd._force_original_view_tracking( - True - ), torch.enable_grad(): + with ( + torch.autograd._force_original_view_tracking(True), + torch.enable_grad(), + ): + record_runtime_wrapper_prologue_exit(cm) all_outs = call_func_at_runtime_with_args( compiled_fn, args_, disable_amp=disable_amp, steal_args=True ) @@ -353,6 +381,7 @@ def runtime_wrapper(args: list[Any]): try: if grad_enabled: torch._C._set_grad_enabled(False) + record_runtime_wrapper_prologue_exit(cm) all_outs = call_func_at_runtime_with_args( compiled_fn, args, disable_amp=disable_amp, steal_args=True ) @@ -922,9 +951,9 @@ def pre_compile( keep_arg_mask.append(True) add_dupe_map.append(j) j += 1 - assert ( - len(add_dupe_map) == duped_arg_len - ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + assert len(add_dupe_map) == duped_arg_len, ( + f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + ) self.keep_arg_mask = keep_arg_mask self.add_dupe_map = add_dupe_map @@ -968,9 +997,9 @@ def wrapped_flat_fn(*args): keep_input_mutations=fw_metadata.keep_input_mutations, is_train=fw_metadata.is_train, )(*deduped_flat_args) - assert ( - ref_fw_metadata == updated_fw_metadata - ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + assert ref_fw_metadata == updated_fw_metadata, ( + f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + ) return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata @@ -1369,14 +1398,14 @@ def _same_dtype_views(view1, view2): # The "inputs that are aliased but have different differentiable bases" case # is more complicated and hopefully pretty rare. Not currently handled. if not is_inference: - assert _are_differentiable_views( - view1, view2 - ), "aot_autograd() does not yet handle non-differentiable view input mutations." + assert _are_differentiable_views(view1, view2), ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) # Regenerating views when reinterpreting complex / real tensors seems non-trivial, # not handling for now - assert _same_dtype_views( - view1, view2 - ), "aot_autograd() does not yet handle input mutations on views with different dtypes." + assert _same_dtype_views(view1, view2), ( + "aot_autograd() does not yet handle input mutations on views with different dtypes." + ) non_none_bases = [ fwd_inputs[i]._base for i in aliased_input_indices @@ -1423,13 +1452,13 @@ def _same_dtype_views(view1, view2): # Case where all of the aliases require gradients, and have the same _base. synthetic_base = non_none_bases[0] for other_base in non_none_bases[1:]: - assert ( - other_base is synthetic_base - ), "aot_autograd() does not yet handle non-differentiable view input mutations." + assert other_base is synthetic_base, ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) for alias in aliases_with_none_bases: - assert ( - alias is synthetic_base - ), "aot_autograd() does not yet handle non-differentiable view input mutations." + assert alias is synthetic_base, ( + "aot_autograd() does not yet handle non-differentiable view input mutations." + ) base_args.append(synthetic_base) for curr_view_idx in aliased_input_indices: curr_view = fwd_inputs[curr_view_idx] @@ -1500,6 +1529,14 @@ class AutogradLazyBackwardCompileInfo: saved_compile_context: Optional[CompileContext] +# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually +# no need to keep information around for a new lazy compilation. Except for compiled autograd, +# which wants to retrace this backward into a larger graph, and it needs the graph module to do so. +@dataclass +class CachedAutogradLazyBackwardCompileInfo: + bw_module_fn: Callable + + def _raise_if_functorch_active(): # not ideal but prevent the user from seeing a nasty traceback - See #138422 stack = torch._C._functorch.peek_interpreter_stack() @@ -1948,7 +1985,12 @@ def post_compile( backward_state_indices: list[int], disable_amp: bool, indices_of_inps_to_detach: list[int], - lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo], + lazy_backward_info: Optional[ + Union[ + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, + ] + ], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta, # runtime metadata @@ -2245,9 +2287,9 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): # compiled autograd reimplements this function at proxy_call_aot_backward - assert ( - not backward_state_indices - ), "BackwardState requires CompiledAutograd" + assert not backward_state_indices, ( + "BackwardState requires CompiledAutograd" + ) ctx.maybe_clear_saved_tensors() saved_tensors_use_once = ( @@ -2256,6 +2298,9 @@ def _backward_impl(ctx, all_args): if CompiledFunction.compiled_bw is None: assert lazy_backward_info is not None + assert isinstance( + lazy_backward_info, AutogradLazyBackwardCompileInfo + ) if not saved_tensors_use_once: fw_metadata.bw_donated_idxs = [] @@ -2279,17 +2324,24 @@ def _backward_impl(ctx, all_args): context = torch._C._DisableAutocast if disable_amp else nullcontext metrics_context = get_metrics_context() - with tracing(saved_context), compile_context( - saved_compile_context - ), context(), track_graph_compiling( - aot_config, "backward" - ), metrics_context, dynamo_timed( - "backward._backward_impl", - phase_name="entire_backward_compile", - log_pt2_compile_event=True, - dynamo_compile_column_us="backward_cumulative_compile_time_us", - log_waitcounter=True, - waitcounter_name_override="entire_backward_compile", + with ( + tracing(saved_context), + compile_context(saved_compile_context), + context(), + track_graph_compiling(aot_config, "backward"), + metrics_context, + dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + log_pt2_compile_event=True, + dynamo_compile_column_us="backward_cumulative_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="entire_backward_compile", + ), + callback_handler.install_callbacks( + CallbackTrigger.LAZY_BACKWARD, + str(CompileContext.current_compile_id()), + ), ): CompileEventLogger.compilation_metric(is_forward=False) # See Note: [Backward graph lazy lowering] @@ -2300,6 +2352,7 @@ def _backward_impl(ctx, all_args): if try_save_cache_entry is not None: try_save_cache_entry( CompiledFunction.compiled_bw, + bw_module, fw_metadata, aot_config, ) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 090d4f0e606cfa..cfcbaa8cc097f7 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -689,7 +689,7 @@ def __eq__(self, other): and len(self.traced_tangents) == len(other.traced_tangents) and all( x.shape == y.shape and x.dtype == y.dtype - for x, y, in zip(self.traced_tangents, other.traced_tangents) + for x, y in zip(self.traced_tangents, other.traced_tangents) ) and self.num_backward_tokens == other.num_backward_tokens ) @@ -726,9 +726,9 @@ class SubclassMeta: # in case we made incorrect assumptions about the subclass-ness of our grad_outputs # # Optional field because we don't compute for inference graphs - grad_input_metas: Optional[ - list[Union[PlainTensorMeta, SubclassCreationMeta]] - ] = None + grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = ( + None + ) def __init__(self) -> None: # The fields in this class get set after its construction. @@ -955,6 +955,7 @@ class AOTConfig: # specializing on example_inputs. # Used only by standalone_compile. ignore_shape_env: bool = False + precompile_backend_id: Optional[str] = None def __post_init__(self): if self.pre_dispatch: diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index b54075dbc0a149..986e569dfc3d7a 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -383,9 +383,9 @@ def wrap_tensor_subclasses( return wrapped_args + activations return tuple(list(wrapped_args) + list(activations)) else: - assert ( - len(unwrapped_args) == num_args_tallied - ), f"Expected {len(unwrapped_args)} == {num_args_tallied}" + assert len(unwrapped_args) == num_args_tallied, ( + f"Expected {len(unwrapped_args)} == {num_args_tallied}" + ) return tuple(wrapped_args) diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index 8b8b5d11884abd..2d0d75e0d9e80e 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -12,9 +12,10 @@ """ import warnings -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager, ExitStack, nullcontext +from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, TypeVar, Union from unittest.mock import patch import torch @@ -25,6 +26,7 @@ from torch._guards import detect_fake_mode from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( + _proxy_tensor_disable_update_tensor_tracker, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -34,6 +36,7 @@ sym_eq, ) from torch.nn.utils import stateless +from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. import config from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata @@ -91,6 +94,16 @@ def inner_fn(*args): return inner_fn +@contextmanager +def disable_autocast(): + with ExitStack() as stack: + autocast_enabled_devices = torch._C._autocast_supported_devices() + for device_type in autocast_enabled_devices: + if hasattr(torch, device_type): + stack.enter_context(torch.amp.autocast(device_type, enabled=False)) + yield + + # This function takes in a fn with external aliasing and mutation, # and returns a new fn with no external aliasing and mutation, # as needed for autograd. @@ -182,6 +195,11 @@ def inner_fn(*args): return inner_fn +@dataclass +class JointFnHandle: + post_forward: Optional[Callable] = None + + # Given a fn, computes the joint. # NOTE: fn is expects the following behavior: # (1) fn() needs to return a tuple of (outs, mask), @@ -193,9 +211,15 @@ def inner_fn(*args): # otherwise, when we compute autograd.grad(), we will not take those input mutations into account # (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any: + joint_fn_handle = JointFnHandle() + + # post_forward def inner_fn(primals: list[Any], tangents: list[Any]): outs, tangent_mask = fn(*primals) + if joint_fn_handle and joint_fn_handle.post_forward: + joint_fn_handle.post_forward(primals) + assert len(tangent_mask) == len(outs) outs_to_grad = [ o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent @@ -258,7 +282,25 @@ def inner_fn(primals: list[Any], tangents: list[Any]): ) functional_tensor_mode._tokens = {} - with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta(): + with ( + set_partitioner_tag_is_backward(), + fx_traceback.preserve_node_meta(), + ExitStack() as stack, + ): + backward_pass_autocast = torch._functorch.config.backward_pass_autocast + if backward_pass_autocast == "same_as_forward": + # Use the ambient autocast mode(s) + pass + elif backward_pass_autocast == "off": + stack.enter_context(disable_autocast()) + else: + # Disable autocast, then enable anything in `backward_pass_autocast`. + stack.enter_context(disable_autocast()) + assert isinstance(backward_pass_autocast, list) + for kwargs in backward_pass_autocast: + assert isinstance(kwargs, dict) + stack.enter_context(torch.amp.autocast(**kwargs)) + # for full graph export, we always export a joint graph where we assume no tangents are needed. if aot_config.no_tangents: assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 @@ -285,6 +327,8 @@ def inner_fn_with_anomaly(*args): with torch.autograd.detect_anomaly(check_nan=False): return inner_fn(*args) + inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined] + return inner_fn_with_anomaly @@ -320,15 +364,17 @@ def append_rng_offsets(args): def traced_joint( primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset ): - with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( - "torch.cuda.set_rng_state", override_set_rng_state + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), ): return append_rng_offsets(func(primals, tangents)) def traced_forward(*primals_fwd_seed_fwd_base_offset): # The signature is (*primals, seed, offset) - with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( - "torch.cuda.set_rng_state", override_set_rng_state + with ( + patch("torch.cuda.get_rng_state", override_get_rng_state), + patch("torch.cuda.set_rng_state", override_set_rng_state), ): return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2])) @@ -379,6 +425,181 @@ def set_partitioner_tag_must_be_in_backward(): return set_partitioner_tag("must_be_in_backward") +def set_partitioner_tag_must_be_in_forward(): + return set_partitioner_tag("must_be_in_forward") + + +@dataclass +class MutationCounters: + mc_data: int + mc_storage: int + mc_inductor_storage_resized: int + + +T = TypeVar("T") + + +def sc_visit( + t, fn: Callable[[Tensor], T], reduce_fn: Callable[[T, T], T], accum_init: T +) -> T: + if not is_traceable_wrapper_subclass(t): + return fn(t) + + accum = accum_init + + def visit(e): + if not is_traceable_wrapper_subclass(e): + nonlocal accum + accum = reduce_fn(accum, fn(e)) + return + + for a in e.__tensor_flatten__()[0]: + visit(getattr(e, a)) + + visit(t) + return accum + + +def _get_mutation_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_mutation_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_storage_changed_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_storage_changed_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_inductor_storage_resized_counter(t) -> int: + return sc_visit( + t, + lambda t: torch._functionalize_inductor_storage_resized_counter(t.elem), # type: ignore[attr-defined] + lambda l, r: max(l, r), + -1, + ) + + +def _get_mutation_counters(t) -> MutationCounters: + return MutationCounters( + _get_mutation_counter(t), + _get_storage_changed_counter(t), + _get_inductor_storage_resized_counter(t), + ) + + +def apply_in_graph_mutations( + input_info, + inpt_old, + inpt_new, + f_inpt, + input_idx, + mcs: Optional[MutationCounters] = None, + applied_mcs: Optional[MutationCounters] = None, +): + assert input_info.mutation_type == MutationType.MUTATED_IN_GRAPH + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if input_info.mutates_storage_metadata: + if mcs is None or mcs.mc_storage > applied_mcs.mc_storage: # type: ignore[union-attr] + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if input_info.mutation_inductor_storage_resize: + if ( + mcs is None + or mcs.mc_inductor_storage_resized > applied_mcs.mc_inductor_storage_resized # type: ignore[union-attr] + ): + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import FunctionalTensor + + assert isinstance(f_inpt, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + f_inpt.elem, before=False + ) + if old_storage_size != new_storage_size: + assert old_storage_size == 0 or new_storage_size == 0, f"""\ + Encosize during tracing on input {input_idx}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We oresizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (thee for FSDP)""" + torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + return + + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. + if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + return + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + + if not input_info.mutates_data: + return + + if mcs is not None and mcs.mc_data <= applied_mcs.mc_data: # type: ignore[union-attr] + return + + if input_info.mutations_hidden_from_autograd: + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif input_info.mutations_under_no_grad_or_inference_mode: + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + + with torch.no_grad(): + inpt_old.copy_(inpt_new) + else: + inpt_old.copy_(inpt_new) + + # This creates the final function that we want to trace using make_fx(), # in both aot_dispatch_autograd and aot_dispatch_base. # Preconditions: @@ -398,7 +619,16 @@ def create_functionalized_fn( meta: ViewAndMutationMeta, aot_config: AOTConfig, trace_joint: bool, + joint_fn_handle: Optional[JointFnHandle] = None, ) -> Any: + primals_after_forward = None + f_args_after_forward = None + f_args_mutation_counters_after_forward: Optional[list[MutationCounters]] = None + inputs_mutated_in_graph = [ + info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info + ] + has_input_mutated_in_graph = any(inputs_mutated_in_graph) + @wraps(fn) def _functionalized_f_helper(*args): with maybe_enable_thunkify(): @@ -415,6 +645,23 @@ def _functionalized_f_helper(*args): # Wrap inputs into functional wrappers f_args = pytree.tree_map(to_fun, args) + if trace_joint and has_input_mutated_in_graph and joint_fn_handle: + # TODO(ivankobzarev): Support fw and bw mutations for subclasses + def _post_forward(primals): + nonlocal primals_after_forward + primals_after_forward = pytree.tree_map(from_fun, primals) + nonlocal f_args_after_forward + f_args_after_forward = f_args[0] + nonlocal f_args_mutation_counters_after_forward + f_args_mutation_counters_after_forward = [ + MutationCounters(-1, -1, -1) + if not inputs_mutated_in_graph[i] + else _get_mutation_counters(f_arg) + for i, f_arg in enumerate(f_args_after_forward) + ] + + joint_fn_handle.post_forward = _post_forward + # Run the joint f_outs = fn(*f_args) @@ -450,15 +697,15 @@ def _functionalized_f_helper(*args): # Ban metadata mutations on fw inputs during the bw if not inpt_info.mutates_metadata: - assert ( - not joint_mutates_metadata - ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + assert not joint_mutates_metadata, ( + "Found a graph input that had its metadata mutated in the backward. This is not supported" + ) # Ban storage resizing on fw inputs during the bw if not inpt_info.mutation_inductor_storage_resize: - assert not was_inductor_storage_resized( - f_inpt - ), "Found a graph input that had storage resizing in the backward. This is not supported" + assert not was_inductor_storage_resized(f_inpt), ( + "Found a graph input that had storage resizing in the backward. This is not supported" + ) # Allow data mutations on fw inputs during the bw, but only if they do not require grad # So we can guarantee that we can keep the mutations in the graph @@ -470,7 +717,10 @@ def _functionalized_f_helper(*args): # Not banning here mutations on inpt_info.requires_grad - # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) # Add node meta for copy_ for partitioner that this node should be in backward graph. - with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward(): + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): before.copy_(after) meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( idx @@ -485,7 +735,9 @@ def _functionalized_f_helper(*args): ): assert not has_metadata_mutation( f_inpt, before, check_only_storage_mutation=False - ), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + ), ( + "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + ) if has_data_mutation(f_inpt): can_be_in_graph = _check_if_mutation_can_be_in_graph( keep_input_mutations=True, @@ -503,9 +755,9 @@ def _functionalized_f_helper(*args): ), requires_grad=f_inpt.requires_grad, ) - assert ( - can_be_in_graph - ), "a backward input that had data mutated in an autograd-aware way. This is not supported" + assert can_be_in_graph, ( + "a backward input that had data mutated in an autograd-aware way. This is not supported" + ) # Perform the input mutation with torch.fx.traceback.preserve_node_meta(): before.copy_(after) @@ -535,110 +787,87 @@ def _functionalized_f_helper(*args): # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry # about synthetic bases. - for i, (inpt_old, inpt_f) in enumerate( + + # Apply in graph forward mutations only in joint case. + # Note: Mutations of primals in forward AND backward. + # If we have mutations of the same input in forward and in backward, + # we can not fuse them into one copy_ node. As in this case partitioner will put it + # either in forward or in backward. This will lead to incorrect state + # after forward and before backward. + # We have to emit two copy_ nodes, marking with additional meta each node, + # if it must be in forward or backward. + # We memorize mutation counter of the inputs after forward. + # Based on this after joint graph we check if backward also mutated input or not. + # We emit copy_ only in the end of joint tracing, to provide invariant for joint + # graph passes, that our graph is functional, except only some number of copy_ nodes + # in the end. + mcs_applied: list[MutationCounters] = [MutationCounters(0, 0, 0)] * len( + meta.input_info + ) + if f_args_mutation_counters_after_forward is not None: + primals_before = args[0] + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip( + f_args_after_forward, # type: ignore[arg-type] + primals_before, # type: ignore[arg-type] + primals_after_forward, # type: ignore[arg-type] + meta.input_info, + ) + ): + if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH: + continue + + mcs_after_forward = f_args_mutation_counters_after_forward[idx] + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_forward(), + _proxy_tensor_disable_update_tensor_tracker(), + ): + apply_in_graph_mutations( + inpt_info, + before, + after, + f_inpt, + idx, + mcs_after_forward, + mcs_applied[idx], + ) + mcs_applied[idx] = mcs_after_forward + + for idx, (inpt_old, f_inpt) in enumerate( zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) ): - if not isinstance(inpt_f, torch.Tensor): + if not isinstance(f_inpt, torch.Tensor): continue - assert is_fun(inpt_f) - inpt_new = from_fun(inpt_f) + assert is_fun(f_inpt) + inpt_new = from_fun(f_inpt) if ( - meta.input_info[i].mutation_type - == MutationType.MUTATED_IN_GRAPH + meta.input_info[idx].mutation_type + != MutationType.MUTATED_IN_GRAPH ): - # See Note [set_() Input Mutations in AOTAutograd] - # all mutations on the input must be under no_grad, so it is safe to put in the graph - # Here, we're saying that if an input experienced a set call, inp.set_(other), - # then we can effectively not have to worry about whether its data was mutated. - # There are 3 cases: - # (1) We mutate inp *after* the set_() call. other is a graph intermediate. - # In this case, we're not really mutating the input storage of "inp"; - # we're mutating the storage of an intermdiate value (other), - # and slamming that storage into the input tensor. So no data mutation is necessary. - # (2) We mutate inp *after* the set_() call. other is a graph *input*. - # In this case, the data mutation will be properly handled in the runtime - # epilogue during the processing of "other" - # (3) We mutate inp *before* the set_() call. - # This case is *not* currently handled. - if meta.input_info[i].mutates_storage_metadata: - with torch.no_grad(): - inpt_old.set_(inpt_new) - - # Note [Ordering of resize_() and set_()] - # Importantly: the common usage in FSDP is that we have a dummy parameter - # that sees a set_() and **Then** a resize_(). - # We must put those mutations into the graph in the same order, - # Since running them in the opposite order will have different behavior. - # We fully ban resize_() followed by set_() for now, although in principal - # we could support this - if meta.input_info[i].mutation_inductor_storage_resize: - # resizing is not supported on subclasses (we error earlier if this happens) - from torch._subclasses.functional_tensor import ( - FunctionalTensor, - ) - - assert isinstance(inpt_f, FunctionalTensor) - old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] - inpt_f.elem, before=True - ) - new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] - inpt_f.elem, before=False - ) - if old_storage_size != new_storage_size: - assert ( - old_storage_size == 0 or new_storage_size == 0 - ), f"""\ - Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} - We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0 - (the case for FSDP)""" - torch.ops.inductor.resize_storage_bytes_( - inpt_old, new_storage_size - ) - if new_storage_size == 0: - # Even if we marked the input as having a data mutation (thus needing a copy_()), - # We should **ignore** it if our input has no storage - # (this can happen if, e.g. we temporarily resize our input, copy data into it, - # and resize it back down to zero) - continue - # Optimization: if the copy_() is a no-op then don't include it in the graph. - # In theory inductor could optimize this away, however in fsdp, we end up with - # param.copy_(param), where param is a zero-storage-size tensor, - # and running this op in eager mode (using the aot_eager backend) will result in a segfault. - # So we may as well optimize it away here. - if inpt_old is inpt_new: - # (This check needs to be done after putting resize_() in the graph, - # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + continue + mcs: Optional[MutationCounters] = None + if f_args_mutation_counters_after_forward is not None: + # This could happen for subclasses tracing + # Subclasses support for mutations in fw and bw is TBD. + mcs = _get_mutation_counters(f_inpt) + if mcs == mcs_applied[idx]: + # No mutation in backward; mutation was already applied. continue - # We found an input that had a (data-only) mutation. - # Since keep_input_mutations is set, we need to faithfully apply a copy_() - # so the compiler will see the input mutation in the graph. - if ( - meta.input_info[i].mutates_data - and meta.input_info[i].mutations_hidden_from_autograd - ): - # Hidden from autograd = run under no_grad, **and** don't bump VC - # (although if the tensor was created in inference mode, it has no VC) - if inpt_old.is_inference(): - maybe_preserve_vc = nullcontext() - else: - maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( - inpt_old # type: ignore[assignment] - ) - with torch.no_grad(), maybe_preserve_vc: - inpt_old.copy_(inpt_new) - elif ( - meta.input_info[i].mutates_data - and meta.input_info[ - i - ].mutations_under_no_grad_or_inference_mode - ): - # Under no_grad = run under no_grad (we still bump the VC though) - # (inference_mode will also bump the VC, as long as the tensor in question - # was created outside of inference_mode) - with torch.no_grad(): - inpt_old.copy_(inpt_new) - elif meta.input_info[i].mutates_data: - inpt_old.copy_(inpt_new) + + with ( + torch.fx.traceback.preserve_node_meta(), + set_partitioner_tag_must_be_in_backward(), + ): + apply_in_graph_mutations( + meta.input_info[idx], + inpt_old, + inpt_new, + f_inpt, + idx, + mcs, + mcs_applied[idx], + ) # When an output tensor is a functionalized mutated input, and we # were able to move the mutation in to the graph then we can return @@ -889,9 +1118,12 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False): # https://github.com/pytorch/pytorch/issues/103569 def functional_call(*args, **kwargs): - with stateless._reparametrize_module( - mod, pytree.tree_unflatten(args[:params_len], params_spec) - ), maybe_disable_thunkify(): + with ( + stateless._reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ), + maybe_disable_thunkify(), + ): if isinstance(mod, torch.fx.GraphModule): with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): warnings.filterwarnings( diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 52d5722cf483a1..e161b8e7b595b5 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -140,9 +140,9 @@ def call_func_at_runtime_with_args( class PytreeThunk: spec: Optional[pytree.TreeSpec] = None # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. - is_simple: Optional[ - bool - ] = None # if the output spec is a tuple/list, we won't bother unflattening it. + is_simple: Optional[bool] = ( + None # if the output spec is a tuple/list, we won't bother unflattening it. + ) is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec def set(self, spec: pytree.TreeSpec) -> None: @@ -335,12 +335,12 @@ def do(module, subgraph, expected_num_erased): num_erased_inputs = len(input_token_nodes) - assert ( - num_erased_inputs == expected_num_erased - ), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" - assert ( - num_erased_outs == expected_num_erased - ), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" + assert num_erased_inputs == expected_num_erased, ( + f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" + ) + assert num_erased_outs == expected_num_erased, ( + f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" + ) module.recompile() diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 8537353ac5e77c..96ef59bfebcb04 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -454,8 +454,7 @@ def __call__( self, gm: torch.fx.GraphModule, example_inputs: Sequence[InputType], - ) -> Any: - ... + ) -> Any: ... # TODO: bikeshed on this name @@ -637,13 +636,14 @@ def _create_aot_dispatcher_function( # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function # that we generate in torch.compile. - with torch.autograd.set_multithreading_enabled( - False - ), preserve_rng_state(), ( - fake_mode - ), ( - python_dispatcher_mode - ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + with ( + torch.autograd.set_multithreading_enabled(False), + preserve_rng_state(), + fake_mode, + python_dispatcher_mode, + PhiloxStateTracker(), + torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), + ): from torch._library.fake_class_registry import ( FakeScriptObject, maybe_to_fake_obj, @@ -756,7 +756,7 @@ def _dup_fake_script_obj(fake_flat_args): if fw_metadata.num_intermediate_bases > 0: assert not req_subclass_dispatch, f"""\ torch.compile is currently being used with tensor subclass inputs: -{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs +{",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs that alias one another, which is currently unsupported in the subclass use case. If you run into this, please file a github issue""" @@ -899,7 +899,7 @@ def aot_function( A simple example usage of :func:`aot_function` is as follows. This example will print the forward and backward graphs of the function ``fn`` - >>> fn = lambda x : x.sin().cos() + >>> fn = lambda x: x.sin().cos() >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module @@ -1022,6 +1022,12 @@ def _try_get_metadata_from_dynamo( aot_autograd_arg_pos_to_source: used to dedup params and their guards static_input_indices: used to identify static inputs for cudagraphs """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): # graph was not captured by dynamo return None, [] @@ -1055,7 +1061,10 @@ def _try_get_metadata_from_dynamo( for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): assert hasattr(node, "_dynamo_source") source = node._dynamo_source - assert source not in seen_sources, source + # `source`` specifies the source from user code. ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source seen_sources.add(source) aot_autograd_arg_pos_to_source.append(source) source_name = source.name() if source else str(source) @@ -1162,6 +1171,7 @@ def aot_module_simplified( no_tangents=False, cache_info=None, ignore_shape_env=ignore_shape_env, + precompile_backend_id=getattr(mod, "_backend_id", None), ) fake_mode, shape_env = construct_fake_mode(full_args, aot_config) fake_flat_args = process_inputs( @@ -1415,9 +1425,7 @@ def flattened_joint(*args): output_gradients = [] for a, grad in zip(args, gradients): if isinstance(a, torch.Tensor) and a.requires_grad: - assert ( - grad is not None - ), """\ + assert grad is not None, """\ Found a parameter that did not receive a gradient. "This is most likely a bug, but if this needs to be supported please comment on this Github issue: https://github.com/pytorch/pytorch/issues/101192 @@ -1530,7 +1538,9 @@ def aot_export_joint_simple( if config.debug_assert: # Smoke test that after partitioning, we can run the forward without any calling convention changes. fw_module, _bw_module = aot_config.default_partition( # noqa: F821 - fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 + fx_g, + args, + num_fwd_outputs=len(fw_metadata.output_infos), # noqa: F821 ) # Attempt to run the fw_module with the original user inputs fake_mode = detect_fake_mode(args) diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 7d7f3e08a5406e..1faa767d4d05c5 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -92,7 +92,7 @@ def vmap( doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully rummaging through docs, use :func:`vmap` to construct a new function. - >>> torch.dot # [D], [D] -> [] + >>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y) @@ -104,7 +104,7 @@ def vmap( >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): - >>> # Very simple linear model with activation + >>> # Very simple linear model with activation >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) @@ -120,7 +120,7 @@ def vmap( >>> # Setup >>> N = 5 - >>> f = lambda x: x ** 2 + >>> f = lambda x: x**2 >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) @@ -137,43 +137,49 @@ def vmap( :func:`vmap` can also be nested, producing an output with multiple batched dimensions - >>> torch.dot # [D], [D] -> [] - >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap( + ... torch.vmap(torch.dot) + ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) - >>> batched_dot(x, y) # tensor of size [2, 3] + >>> batched_dot(x, y) # tensor of size [2, 3] If the inputs are not batched along the first dimension, ``in_dims`` specifies the dimension that each inputs are batched along as - >>> torch.dot # [N], [N] -> [] + >>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) - >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension + >>> batched_dot( + ... x, y + ... ) # output is [5] instead of [2] if batched along the 0th dimension If there are multiple inputs each of which is batched along different dimensions, ``in_dims`` must be a tuple with the batch dimension for each input as - >>> torch.dot # [D], [D] -> [] + >>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) - >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None + >>> batched_dot( + ... x, y + ... ) # second arg doesn't have a batch dim because in_dim[1] was None If the input is a Python struct, ``in_dims`` must be a tuple containing a struct matching the shape of the input: - >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> f = lambda dict: torch.dot(dict["x"], dict["y"]) >>> x, y = torch.randn(2, 5), torch.randn(5) - >>> input = {'x': x, 'y': y} - >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) + >>> input = {"x": x, "y": y} + >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},)) >>> batched_dot(input) By default, the output is batched along the first dimension. However, it can be batched along any dimension by using ``out_dims`` - >>> f = lambda x: x ** 2 + >>> f = lambda x: x**2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) - >>> batched_pow(x) # [5, 2] + >>> batched_pow(x) # [5, 2] For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs @@ -184,13 +190,13 @@ def vmap( >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) - >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] .. note:: vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - from torch._dynamo import is_compiling + from torch.compiler import is_compiling _check_randomness_arg(randomness) if not (chunk_size is None or chunk_size > 0): @@ -337,7 +343,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla >>> batch_size, feature_size = 3, 5 >>> >>> def model(weights, feature_vec): - >>> # Very simple linear model with activation + >>> # Very simple linear model with activation >>> assert feature_vec.dim() == 1 >>> return feature_vec.dot(weights).relu() >>> @@ -349,7 +355,9 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla >>> examples = torch.randn(batch_size, feature_size) >>> targets = torch.randn(batch_size) >>> inputs = (weights, examples, targets) - >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))( + ... *inputs + ... ) Example of using ``grad`` with ``has_aux`` and ``argnums``: @@ -392,7 +400,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla """ # To avoid cyclical dependency. import torch._functorch.eager_transforms as eager_transforms - from torch._dynamo import is_compiling + from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) @@ -434,8 +442,8 @@ def grad_and_value( See :func:`grad` for examples """ - from torch._dynamo import is_compiling from torch._functorch import eager_transforms + from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_and_value_impl( diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index bc715c44ed85aa..c29f52fe6ba9b3 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -258,7 +258,7 @@ class VmapInfo(NamedTuple): randomness: str -def has_overriden_vmap_rule(autograd_function): +def has_overridden_vmap_rule(autograd_function): return autograd_function.vmap is not torch.autograd.Function.vmap @@ -286,14 +286,14 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwarg ) if autograd_function.generate_vmap_rule: - if has_overriden_vmap_rule(autograd_function): + if has_overridden_vmap_rule(autograd_function): # TODO: Update link to stable once that's out # https://github.com/pytorch/pytorch/issues/92029 raise RuntimeError( f"You tried to vmap over {autograd_function.__name__}, but " - f"it has both generate_vmap_rule=True and an overriden vmap " + f"it has both generate_vmap_rule=True and an overridden vmap " f"staticmethod. Please set generate_vmap_rule=False or delete " - f"the overriden vmap staticmethod to avoid ambiguity. " + f"the overridden vmap staticmethod to avoid ambiguity. " f"For more details, please see " f"https://pytorch.org/docs/main/notes/extending.func.html" ) @@ -301,7 +301,7 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwarg interpreter, autograd_function, *operands ) - if not has_overriden_vmap_rule(autograd_function): + if not has_overridden_vmap_rule(autograd_function): # TODO: Update link to stable once that's out # https://github.com/pytorch/pytorch/issues/92029 raise RuntimeError( diff --git a/torch/_functorch/benchmark_utils.py b/torch/_functorch/benchmark_utils.py index ac69e8bd4744c6..ba0b31c018bd13 100644 --- a/torch/_functorch/benchmark_utils.py +++ b/torch/_functorch/benchmark_utils.py @@ -185,8 +185,12 @@ def benchmark_utilization( ``` def f(a): return a.sum() + + a = torch.rand(2**20, device="cuda") - utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") + utilization, mm_conv_utilization = benchmark_utilization( + f, a, "tmp", trace_file_name="tmp_chrome_trace" + ) ``` Args: diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 974659ed0adecf..39eadaae7ef681 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -1,6 +1,7 @@ # mypy: ignore-errors +import operator from typing import Callable import sympy @@ -170,6 +171,24 @@ def substitute(arg_list): return new_graph +def raise_getitems(gm: fx.GraphModule) -> fx.GraphModule: + # Pre-create a list of nodes to iterate over, as modifying the node order + # during the loop can lead to infinite loops if not handled properly. + getitem_nodes = list( + gm.graph.find_nodes(op="call_function", target=operator.getitem) + ) + + # loop through getitem nodes in the graph and raise them to the parent node + # in reverse order to perserve their original relative order + for node in reversed(getitem_nodes): + assert len(node.all_input_nodes) == 1 + parent = node.all_input_nodes[0] + parent.append(node) + + gm.recompile() + return gm + + def strip_overloads(gm): """ Modifies the target of graph nodes in :attr:`gm` to strip overloads. diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index edb17bfedf5587..65cb80211213a3 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -150,13 +150,13 @@ def check_significant_strides(a, b): def check(nv, rv, desc): assert callable(desc) assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" - assert ( - subst_symint_tuple(nv.size()) == rv.size() - ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + assert subst_symint_tuple(nv.size()) == rv.size(), ( + f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + ) same_strides = check_significant_strides(nv, rv) - assert ( - same_strides - ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" + assert same_strides, ( + f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" + ) r = super().run_node(n) if "val" in n.meta: diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 454892b623eac6..66f1fe88c612b5 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -7,6 +7,7 @@ """ Global flags for aot autograd """ + import os import sys from typing import Literal, Optional, TYPE_CHECKING @@ -229,6 +230,40 @@ def remote_autograd_cache_default() -> Optional[bool]: # of tensors in question. fake_tensor_propagate_real_tensors = False +# AOTDispatcher traces out a backward graph at the time of the forward pass. +# This flags controls whether or not that backward graph gets autocast behavior +# applied to it. +# +# The options are either: +# - "same_as_forward". We assume that the backward of the torch.compile'ed region +# will be run under the same autocast context manager that the region was run +# under. This is equivalent to running the following code in eager: +# +# with torch.amp.autocast(...): +# y = region(x) +# ... +# z.backward() +# +# - "off". We assume that the backward of the torch.compile'd region will +# not be run under any autocast context managers. +# This is equivalent to running the following code in eager: +# +# with torch.amp.autocast(...): +# y = region(x) +# ... +# z.backward() +# +# - or a list of kwargs dicts that represent an autocast context manager to turn +# on during the backward pass. +# +# e.g. [{"device_type": "cuda"}] is equivalent to running the following code in eager: +# +# y = region(x) +# ... +# with torch.amp.autocast(device="cuda"): +# z.backward() +backward_pass_autocast = "same_as_forward" + # This controls whether we collect donated buffer. This flag must be set # False if a user wants to retain_graph=True for backward. donated_buffer = False if is_fbcode() else True @@ -288,7 +323,7 @@ def remote_autograd_cache_default() -> Optional[bool]: # This is a temporary config to ensure all ranks take the same decision in the partitioner # it will untimately be removed once we share size_hints across ranks through compiler collectives -_broadcast_rank0_decision = False +_sync_decision_cross_ranks = False # By default apply inlined saved_tensors_hooks only for "donated" buffers. # "donated" buffers are invisible to the user, they are intermediates of the forward graph. diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index bad04f2803f1ef..d99995b86f2bac 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -233,7 +233,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False): >>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) - >>> grad = vjpfunc(torch.tensor(1.))[0] + >>> grad = vjpfunc(torch.tensor(1.0))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x)) However, :func:`vjp` can support functions with multiple outputs by @@ -248,9 +248,9 @@ def vjp(func: Callable, *primals, has_aux: bool = False): :func:`vjp` can even support outputs being Python structs >>> x = torch.randn([5]) - >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> f = lambda x: {"first": x.sin(), "second": x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) - >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} + >>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) @@ -274,7 +274,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False): >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) - >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0)) .. note:: Using PyTorch ``torch.no_grad`` together with ``vjp``. @@ -930,8 +930,7 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: return if not isinstance(output, tuple): raise RuntimeError( - f"{api}: Expected output of f to be a Tensor or Tensors, got " - f"{type(output)}" + f"{api}: Expected output of f to be a Tensor or Tensors, got {type(output)}" ) if len(output) == 0: raise RuntimeError( @@ -1023,10 +1022,10 @@ def jvp( >>> from torch.func import jvp >>> x = torch.randn([]) - >>> f = lambda x: x * torch.tensor([1., 2., 3]) - >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> f = lambda x: x * torch.tensor([1.0, 2.0, 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.0),)) >>> assert torch.allclose(value, f(x)) - >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) + >>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3])) :func:`jvp` can support functions with multiple inputs by passing in the tangents for each of the inputs diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 62ca24ab3fdb96..8d019871ffee3e 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -60,7 +60,10 @@ def functional_call( .. code-block:: python - a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries + a = ( + {"weight": torch.ones(1, 1)}, + {"buffer": torch.zeros(1)}, + ) # two separate dictionaries mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer print(mod.weight) # tensor(...) print(mod.buffer) # tensor(...) @@ -83,10 +86,12 @@ def functional_call( t = torch.randn(4, 3) model = nn.Linear(3, 3) + def compute_loss(params, x, t): y = functional_call(model, params, x) return nn.functional.mse_loss(y, t) + grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t) .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the @@ -179,9 +184,11 @@ def stack_module_state( models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] data = torch.randn(batch_size, 3) + def wrapper(params, buffers, data): return torch.func.functional_call(models[0], (params, buffers), data) + params, buffers = stack_module_state(models) output = vmap(wrapper, (0, 0, None))(params, buffers, data) @@ -192,6 +199,8 @@ def wrapper(params, buffers, data): .. code-block:: python import torch.nn as nn + + class Foo(nn.Module): def __init__(self, in_features, out_features): super().__init__() @@ -202,6 +211,7 @@ def __init__(self, in_features, out_features): def forward(self, x): return self.l2(self.l1(x)) + num_models = 5 in_features, out_features = 3, 3 models = [Foo(in_features, out_features) for i in range(num_models)] diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py index 576100e2739d78..16988a022a9775 100644 --- a/torch/_functorch/make_functional.py +++ b/torch/_functorch/make_functional.py @@ -374,10 +374,12 @@ def make_functional( model = nn.Linear(3, 3) func, params = make_functional(model) + def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) + grad_weights = grad(compute_loss)(params, x, t) If the model has any buffers, please use :func:`make_functional_with_buffers` instead. @@ -443,10 +445,12 @@ def make_functional_with_buffers( model = nn.Linear(3, 3) func, params, buffers = make_functional_with_buffers(model) + def compute_loss(params, buffers, x, t): y = func(params, buffers, x) return nn.functional.mse_loss(y, t) + grad_weights = grad(compute_loss)(params, buffers, x, t) Args: @@ -469,7 +473,7 @@ def compute_loss(params, buffers, x, t): def transpose_stack( - tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...] + tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...], ) -> tuple[Tensor, ...]: tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) results = tuple( diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 448870ee73d3b3..21218e6068538c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import copy import functools +import hashlib import heapq import itertools import logging @@ -10,7 +11,7 @@ import os.path from collections import defaultdict from dataclasses import dataclass, replace -from typing import Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch._inductor.inductor_prims @@ -49,7 +50,7 @@ from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects -from .compile_utils import fx_graph_cse, get_aten_target +from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems if TYPE_CHECKING: @@ -201,6 +202,10 @@ def _extract_graph_with_inputs_outputs( env[node] = InvalidNode # type: ignore[assignment] continue + if _must_be_in_forward(node) and subgraph != "forward": + env[node] = InvalidNode # type: ignore[assignment] + continue + if node in env: # Node must be one of our inputs. (Any member of env which wasn't an # input to start must have been created by this loop and won't be in @@ -228,9 +233,9 @@ def _extract_graph_with_inputs_outputs( if isinstance(x, fx.Node): if x not in env: raise RuntimeError(f"Node {x} couldn't be found in env") - assert not isinstance( - env[x], InvalidNodeBase - ), f"Node {x} was invalid, but is output" + assert not isinstance(env[x], InvalidNodeBase), ( + f"Node {x} was invalid, but is output" + ) output_values.append(env[x]) else: output_values.append(x) @@ -274,10 +279,18 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +def _has_tag_must_be_in_forward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "must_be_in_forward" + + def _has_tag_must_be_in_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_backward" +def _must_be_in_forward(node: fx.Node) -> bool: + return _has_tag_must_be_in_forward(node) + + def _must_be_in_backward(node: fx.Node) -> bool: return _has_tag_must_be_in_backward(node) or ( _has_tag_is_backward(node) and is_with_effects(node) @@ -448,10 +461,10 @@ def perform_quantization( args=(clamp_max_scaled_node, quant_type), name="fp8_quant_" + str(node.name), ) - quant_activation_node.meta[ - "val" - ] = torch.ops.prims.convert_element_type.default( - clamp_max_scaled_node.meta["val"], quant_type + quant_activation_node.meta["val"] = ( + torch.ops.prims.convert_element_type.default( + clamp_max_scaled_node.meta["val"], quant_type + ) ) quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata( quant_activation_node.meta["val"] @@ -566,10 +579,10 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: args=(node, quant_type), name="fp8_quant_" + str(node.name), ) - quant_node.meta[ - "val" - ] = torch.ops.prims.convert_element_type.default( - node.meta["val"], quant_type + quant_node.meta["val"] = ( + torch.ops.prims.convert_element_type.default( + node.meta["val"], quant_type + ) ) quant_node.meta["tensor_meta"] = extract_tensor_metadata( quant_node.meta["val"] @@ -577,7 +590,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: node_to_quant[node] = quant_node # only update the return node args, and remain all other users unchanged output_updated_args = [ - node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs # type: ignore[union-attr] + node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs ] # add the scale nodes to the ouput find the first sym_node in the output idx = find_first_sym_node(output_updated_args) @@ -616,10 +629,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: torch.ops.prims.convert_element_type.default, args=(node, dequant_type), ) - activation_node.meta[ - "val" - ] = torch.ops.prims.convert_element_type.default( - node.meta["val"], dequant_type + activation_node.meta["val"] = ( + torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type + ) ) activation_node.meta["tensor_meta"] = extract_tensor_metadata( activation_node.meta["val"] @@ -632,18 +645,18 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor( activation_node.meta["val"], scale_node.meta["val"] ) - divided_target_node_32.meta[ - "tensor_meta" - ] = extract_tensor_metadata(divided_target_node_32.meta["val"]) + divided_target_node_32.meta["tensor_meta"] = ( + extract_tensor_metadata(divided_target_node_32.meta["val"]) + ) with graph.inserting_after(divided_target_node_32): dequant_node = graph.call_function( torch.ops.prims.convert_element_type.default, args=(divided_target_node_32, dequant_type), ) - dequant_node.meta[ - "val" - ] = torch.ops.prims.convert_element_type.default( - divided_target_node_32.meta["val"], dequant_type + dequant_node.meta["val"] = ( + torch.ops.prims.convert_element_type.default( + divided_target_node_32.meta["val"], dequant_type + ) ) dequant_node.meta["tensor_meta"] = extract_tensor_metadata( dequant_node.meta["val"] @@ -655,10 +668,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: args=(node, dequant_type), name="dequant_" + str(node.name), ) - dequant_node.meta[ - "val" - ] = torch.ops.prims.convert_element_type.default( - node.meta["val"], dequant_type + dequant_node.meta["val"] = ( + torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type + ) ) dequant_node.meta["tensor_meta"] = extract_tensor_metadata( dequant_node.meta["val"] @@ -1049,10 +1062,10 @@ def _count_ops(graph: fx.Graph): for node in graph.nodes: if node.op == "call_function": cnt[node.target.__name__] += 1 - log.info("%s", sorted(cnt.items(), key=lambda x: x[1], reverse=True)) + log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) -@functools.lru_cache(None) +@functools.cache def pointwise_ops(): ops = [] for attr_name in dir(torch.ops.aten): @@ -1074,7 +1087,7 @@ def sort_depths(args, depth_map: dict[fx.Node, int]) -> list[tuple[fx.Node, int] arg_depths = { arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) } - return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) + return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: @@ -1141,6 +1154,11 @@ def insert_node_in_graph(node): return gm # Build the graph op-by-op by starting from the node all the way to the end + # copy_ can be not using tangents at all, we must copy it. + for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]: + if node.op == "call_function" and node.target == torch.ops.aten.copy_.default: + insert_node_in_graph(node) + for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: insert_node_in_graph(node) @@ -1299,9 +1317,14 @@ def get_device(node) -> Optional[torch.device]: return torch.device("cpu") def get_sample_rng_state(device: Optional[torch.device]): - if device is not None and device.type == "cuda": - return torch.cuda.get_rng_state() - return torch.get_rng_state() + from torch._guards import detect_fake_mode # noqa: F401 + + fake_mode = detect_fake_mode() + assert fake_mode is not None + with fake_mode: + if device is not None and device.type == "cuda": + return fake_mode.from_tensor(torch.cuda.get_rng_state()) + return fake_mode.from_tensor(torch.get_rng_state()) # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. joint_graph_rng_ops = get_rng_ops(joint_module) @@ -1396,6 +1419,8 @@ def get_sample_rng_state(device: Optional[torch.device]): args=(functional_fw_node, 0), kwargs={}, ) + state.meta["val"] = get_sample_rng_state(device) + rng_output = fw_graph.create_node( "call_function", operator.getitem, @@ -1405,6 +1430,9 @@ def get_sample_rng_state(device: Optional[torch.device]): ), kwargs={}, ) + # Copy the meta data from the original node + rng_output.meta = copy.copy(fw_node.meta) + fw_node.replace_all_uses_with(rng_output) fw_graph.erase_node(fw_node) fw_rng_state_outputs.append(state) @@ -1460,6 +1488,29 @@ def force_save_collectives(joint_module: fx.GraphModule) -> None: node.meta["recompute"] = CheckpointPolicy.MUST_SAVE +def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: + # If we have mutations of the same primal in forward and backward, + # We must not recompute the source of mutation to not apply twice. + has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet() + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + continue + + is_copy_ = node.target == torch.ops.aten.copy_.default + if is_copy_: + if _has_tag_must_be_in_backward(node): + has_mutation_in_bw.add(node.args[0]) + + if _has_tag_must_be_in_forward(node) and node.args[0] in has_mutation_in_bw: + node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE + else: + # We use invariant of aotdispatch joint graph, + # That we emit copy_ only in the end of it. + # We do not want to iterate through all the joint graph, + # so break at the first non-output, non-copy_ node. + break + + def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in @@ -2009,6 +2060,7 @@ def get_default_op_list() -> OpTypes: aten.as_strided, aten.permute, aten.select, + aten.split, ] view_ops = recomputable_view_ops default_recomputable_ops += [ @@ -2061,7 +2113,9 @@ def get_default_op_list() -> OpTypes: default_recomputable_ops += [method_to_operator(m) for m in magic_methods] recomputable_ops = OrderedSet(default_recomputable_ops) - random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like]) + random_ops = OrderedSet[Callable[..., Any]]( + [aten.native_dropout, aten.rand_like, aten.randn_like] + ) compute_intensive_ops = [ aten.mm, aten.convolution, @@ -2435,7 +2489,7 @@ def estimate_for_budget(b): )[0] -def _broadcast_rank0_decision( +def _sync_decision_cross_ranks( joint_graph: torch.fx.Graph, saved_values: list[torch.fx.Node] ): # use the same policy across different GPUs @@ -2454,7 +2508,8 @@ def has_same_nodes(joint_graph): # We only consider the name and order of nodes. A more robust way # would be to check the hash of the whole graph (disregarding input shapes), # this is is a reasonable first-order approximation. - inputs = hash(tuple(x.name for x in joint_graph.nodes)) + node_str = "/".join(x.name for x in joint_graph.nodes) + inputs = hashlib.sha256(node_str.encode("utf-8")).hexdigest() all_inputs = [None for _ in range(torch.distributed.get_world_size())] with no_dispatch(), unset_fake_temporarily(): # TODO: maybe use a different process group? @@ -2470,11 +2525,48 @@ def has_same_nodes(joint_graph): ): with no_dispatch(), unset_fake_temporarily(): objects = [[x.name for x in saved_values]] - # TODO: maybe use a different process group for this - torch.distributed.broadcast_object_list(objects, src=0) - saved_values_names = objects[0] + saved_ops_names_all_ranks: list[list[str]] = [ + [] for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather_object(saved_ops_names_all_ranks, objects[0]) name_to_node = get_name_to_node(joint_graph) - saved_values = [name_to_node[n] for n in saved_values_names] + saved_sizes: list[int] = [] + saved_ops_with_sizes: dict[str, int] = {} + + for idx, saved_ops_names in enumerate(saved_ops_names_all_ranks): + saved_nodes = [name_to_node[op_name] for op_name in saved_ops_names] + saved_size = 0 + for node in saved_nodes: + size_of_node = _size_of(node) + saved_size += size_of_node + if idx == torch.distributed.get_rank(): + saved_ops_with_sizes[node.name] = size_of_node + saved_ops_with_sizes["total size"] = saved_size + saved_sizes.append(saved_size) + + saved_sizes_tensor = torch.tensor( + saved_sizes, + device=torch.distributed.distributed_c10d._get_object_coll_device(), + ) + torch.distributed.all_reduce( + saved_sizes_tensor, op=torch.distributed.distributed_c10d.ReduceOp.MAX + ) + + picked_rank_idx = int(torch.argmin(saved_sizes_tensor).item()) + sync_decision_cross_ranks_str = f"picked_rank_idx={picked_rank_idx}, saved_nodes of current rank={saved_ops_with_sizes}" + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_joint_graph_sync_decision_cross_ranks", + "encoding": "string", + }, + payload_fn=lambda: sync_decision_cross_ranks_str, + ) + + saved_values = [ + name_to_node[n] for n in saved_ops_names_all_ranks[picked_rank_idx] + ] + return saved_values @@ -2529,6 +2621,7 @@ def min_cut_rematerialization_partition( joint_module = cleanup_recompute_tags(joint_module) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) + force_save_bw_mutation_src(joint_module) def classify_nodes(joint_module, static_lifetime_input_indices): name_to_node = get_name_to_node(joint_module.graph) @@ -2620,8 +2713,8 @@ def classify_nodes(joint_module, static_lifetime_input_indices): node_info, memory_budget=memory_budget, ) - if config._broadcast_rank0_decision: - saved_values = _broadcast_rank0_decision(joint_graph, saved_values) + if config._sync_decision_cross_ranks: + saved_values = _sync_decision_cross_ranks(joint_graph, saved_values) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) @@ -2641,6 +2734,11 @@ def classify_nodes(joint_module, static_lifetime_input_indices): ) bw_module = reordering_to_mimic_autograd_engine(bw_module) + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + if AOT_PARTITIONER_DEBUG: # Calculate sorted sizes of saved values sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) @@ -2669,7 +2767,9 @@ def classify_nodes(joint_module, static_lifetime_input_indices): len(fw_module_nodes), len(bw_module_nodes), ) - rematerialized_ops = sorted(counts.items(), key=lambda x: x[1], reverse=True) + rematerialized_ops = sorted( + counts.items(), key=operator.itemgetter(1), reverse=True + ) log.info("Count of Ops Rematerialized: %s", rematerialized_ops) return fw_module, bw_module diff --git a/torch/_functorch/top_operators_github_usage.py b/torch/_functorch/top_operators_github_usage.py index 6290a155500d00..171c6fc6c1e018 100644 --- a/torch/_functorch/top_operators_github_usage.py +++ b/torch/_functorch/top_operators_github_usage.py @@ -4,6 +4,7 @@ From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 Try to keep this list in sync with that. """ + import operator diff --git a/torch/_guards.py b/torch/_guards.py index 818696c1f3e794..28becfac586590 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -986,6 +986,13 @@ def set_current_loc(filename, lineno, frame_name): # framesummary. TracingContext.get().loc_in_frame = (filename, lineno, frame_name) + @staticmethod + def get_traced_code(): + tc = TracingContext.try_get() + if tc is None: + return None + return tc.traced_code + @contextmanager def compile_context(context: Optional[CompileContext]): diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index d90586f8950db5..bb2c62de7617aa 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -156,6 +156,9 @@ def call_delegate_functionalize( ) with ctx.redispatch_to_next(): res = aoti_call_delegate( - lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type] + lowered_module, + original_gm, + unwrapped_weight_args, # type: ignore[arg-type] + unwrapped_input_args, # type: ignore[arg-type] ) return ctx.wrap_tensors(res) diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 66e6c48e73bbf3..c9f5dda563369b 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -31,9 +31,9 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): - assert ( - len(args) == 2 * num_leaves - ), f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}" + assert len(args) == 2 * num_leaves, ( + f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}" + ) lhs = pytree.tree_unflatten(args[:num_leaves], spec) rhs = pytree.tree_unflatten(args[num_leaves:], spec) return combine_fn(lhs, rhs) @@ -79,9 +79,9 @@ def __call__(self, combine_fn, xs, additional_inputs): # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 # Once this issue is resolved, the assertion should only allow tuples # and the tuple cast should be removed - assert isinstance( - additional_inputs, (tuple, list) - ), "additional_inputs must be a tuple." + assert isinstance(additional_inputs, (tuple, list)), ( + "additional_inputs must be a tuple." + ) additional_inputs = ( tuple(additional_inputs) if isinstance(additional_inputs, list) @@ -134,6 +134,7 @@ def associative_scan( def add(x: torch.Tensor, y: torch.Tensor): return x + y + cumsum = associative_scan(add, x, dim) """ @@ -377,9 +378,9 @@ def trace_associative_scan( assert outputs is not None outputs = pytree.tree_leaves(outputs) - assert len(outputs) == len( - xs - ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" + assert len(outputs) == len(xs), ( + f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" + ) xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ first_slice_copy(x) for x in xs diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 309b2e5904d46c..ef8cddbae7c114 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, get_args, Optional, Union +from typing import Any, Callable, get_args, Optional, Union import torch import torch._library.utils as library_utils @@ -11,8 +11,10 @@ from torch import Tensor from torch._C import DispatchKey from torch._higher_order_ops.utils import ( + _has_gen_schema, call_op, HopInstance, + HopSchema, materialize_callable_in_args, unique_graph_id, ) @@ -237,7 +239,7 @@ def use_alias(): write_single_view( f"_{arg_name}", kwargs[arg_name], - arg_to_base_index.get(arg_name, None), + arg_to_base_index.get(arg_name, None), # type: ignore[arg-type] ) else: raise RuntimeError(f"Unsupported type {arg_type}") @@ -388,7 +390,7 @@ def __call__( if isinstance(_mutable_op, HigherOrderOperator): _op_to_check = HopInstance( _mutable_op, - SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema, + SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema, # type: ignore[arg-type] ) else: _op_to_check = _mutable_op @@ -411,12 +413,6 @@ def can_auto_functionalize( ) -> bool: if isinstance(op, HopInstance): # HOPs that implement gen_schema and schema is not functional are auto_functionalizable. - def _has_gen_schema(op: HigherOrderOperator): - method = "gen_schema" - return hasattr(type(op), method) and getattr( - type(op), method - ) is not getattr(HigherOrderOperator, method) - if not _has_gen_schema(op._op): return False @@ -526,7 +522,8 @@ def do_auto_functionalize( ) with ctx.redispatch_to_next(): unwrapped_outs = auto_functionalized( - op, **unwrapped_kwargs # type: ignore[arg-type] + op, + **unwrapped_kwargs, # type: ignore[arg-type] ) # List of the name of args that get mutated (according to the schema) @@ -575,9 +572,31 @@ def sync_update(o, orig_arg): return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] +# Wrapper for GraphModule that applies functionalization during execution to enable +# epilogue graph inlining and better fusion opportunities in subgraphs +# When tracing this wrapper, we'll get a graph module with epilogue. +# +# We want to hash it according to the original graph module, so that when we go +# from Functional mode -> fake mode for multiple invoke_subgraph calls that share, +# the same inner graph module, we can hit the cache. +class FunctionalCallableWithEpilogue: + def __init__(self, orig_callable: Callable): + self.orig_callable = orig_callable + + def __call__(self, *args, **kwargs): + # We call torch.func.functionalize. This allows us to inline the epilogue graph. + # Inlining has the benefit of allowing easiser fusion inside subgraph. + # Though the epilogue graph contains copy_, it is OK becuase inductor can handle it + # and this is also how we have been supporting top-level graph input mutation. + return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs)) + + def __hash__(self): + return id(self.orig_callable) + + def do_auto_functionalize_v2( mode: "torch._subclasses.functional_tensor.FunctionalTensorMode", - op: _MutableOpType, + op: Union[OpOverload, HopInstance], args: tuple[Any, ...], kwargs: dict[str, Any], ) -> Any: @@ -589,28 +608,13 @@ def do_auto_functionalize_v2( # args come from the schema. This makes it easier for us to work with them. normalized_kwargs = {} + schema = op._schema + op = op._op if isinstance(op, HopInstance) else op assert isinstance(op, get_args(_MutableOpType)) - schema = ( - op.gen_schema(*args, **kwargs) - if isinstance(op, HigherOrderOperator) - else op._schema - ) def _functionalize_callable(arg: Any): if callable(arg): - - def functional_fn(*args, **kwargs): - # We call torch.func.functionalize. This allows us to inline the epilogue graph. - # Inlining has the benefit of allowing easiser fusion inside subgraph. - # Though the epilogue graph contains copy_, it is OK becuase inductor can handle it - # and this is also how we have been supporting top-level graph input mutation. - return tuple( - pytree.tree_leaves(torch.func.functionalize(arg)(*args, **kwargs)) - ) - - return torch._higher_order_ops.base_hop.FunctionWithNoFreeVars( - functional_fn - ) + return FunctionalCallableWithEpilogue(arg) return arg args, kwargs = pytree.tree_map(_functionalize_callable, (args, kwargs)) @@ -701,7 +705,8 @@ def set_result(base_index): with ctx.redispatch_to_next(): unwrapped_outs = auto_functionalized_v2( - op, **auto_func_kwargs # type: ignore[arg-type] + op, + **auto_func_kwargs, # type: ignore[arg-type] ) unwrapped_actual_out: Union[Any, tuple[Any]] = ( @@ -713,9 +718,9 @@ def set_result(base_index): ) if isinstance(op, HigherOrderOperator): - assert ( - len(schema.returns) > 0 - ), f"hop is expected to return at least one output {schema}." + assert len(schema.returns) > 0, ( + f"hop is expected to return at least one output {schema}." + ) assert len(unwrapped_actual_out) == len(schema.returns) else: if len(schema.returns) == 0: @@ -843,15 +848,15 @@ def auto_functionalized_v2_dense( _only_clone_these_bases = tuple(range(len(_all_bases))) if isinstance(_mutable_op, OpOverload): - schema = _mutable_op._schema + schema: torch._C.FunctionSchema = _mutable_op._schema else: schema = pytree.tree_unflatten([], kwargs.pop("_op_schema")).schema - _mutable_op = ( - _mutable_op - if isinstance(_mutable_op, OpOverload) - else HopInstance(_mutable_op, schema) - ) + if isinstance(_mutable_op, OpOverload): + _callable_op: Union[HopInstance, OpOverload] = _mutable_op + else: + assert isinstance(schema, HopSchema) + _callable_op = HopInstance(_mutable_op, schema) op_kwargs_new, all_bases_new = _generate_new_op_kwargs_from_bases( schema, @@ -861,7 +866,7 @@ def auto_functionalized_v2_dense( ) out = call_op( - _mutable_op, + _callable_op, tuple(), op_kwargs_new, ) @@ -956,7 +961,7 @@ def auto_functionalized_v2_proxy( if _only_clone_these_bases is None: _only_clone_these_bases = tuple(range(len(all_bases))) - schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema + schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema # type: ignore[arg-type] new_kwargs, _ = _generate_new_op_kwargs_from_bases( schema, {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")}, diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index ff18f86885d54f..8898c56ab2275f 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -6,9 +6,11 @@ import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization -from torch._higher_order_ops.cond import materialize_as_graph +from torch._higher_order_ops.auto_functionalize import FunctionalCallableWithEpilogue from torch._higher_order_ops.utils import ( - check_input_alias_and_mutation_return_ouputs, + check_input_alias_and_mutation_return_outputs, + HopInstance, + materialize_as_graph, reenter_make_fx, ) from torch._ops import HigherOrderOperator @@ -38,11 +40,14 @@ class InvokeQuant(BaseHOP): def __init__(self): return super().__init__("invoke_quant") + invoke_quant = InvokeQuant() + def g(x): return x.sin().cos() + @torch.compile(backend="aot_eager") def f(x): return invoke_quant(g, x, scheme="nf4") @@ -66,7 +71,14 @@ def __init__(self, hop_name) -> None: ) def __call__(self, subgraph, *operands, **kwargs): - if not isinstance(subgraph, (torch.fx.GraphModule, FunctionWithNoFreeVars)): + if not isinstance( + subgraph, + ( + torch.fx.GraphModule, + FunctionWithNoFreeVars, + FunctionalCallableWithEpilogue, + ), + ): raise RuntimeError( f"{self._name}: when calling this API without torch.compile, " f"we require that the subgraph be a torch.fx.GraphModule (or " @@ -104,7 +116,10 @@ def _call_ProxyTorchDispatchMode(self, proxy_mode, subgraph, *operands, **kwargs out = self(subgraph, *operands, **kwargs) return track_tensor_tree( - out, out_proxy, constant=None, tracer=proxy_mode.tracer # type: ignore[arg-type] + out, + out_proxy, + constant=None, + tracer=proxy_mode.tracer, # type: ignore[arg-type] ) def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs): @@ -131,15 +146,18 @@ def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs): # copies the mutated inputs to the hop if necessary and call the hop. # After these steps, the rest of the inductor stack knows how to fuse the copy_ in subgraph with other ops. def _call_Functionalize(self, ctx, subgraph, *operands, **kwargs): - from torch._higher_order_ops.auto_functionalize import do_auto_functionalize_v2 + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize_v2, + ) # invoke_quant has non-proxable argument of type InvokeQuant that # we cannot generate schema for. if self is not torch.ops.higher_order.invoke_quant_packed: - hop_schema = self.gen_schema(subgraph, *operands, **kwargs) - if hop_schema.is_mutable: + hop_instance = HopInstance.create(self, subgraph, *operands, **kwargs) + if can_auto_functionalize(hop_instance): return do_auto_functionalize_v2( - ctx.mode, self, (subgraph, *operands), kwargs + ctx.mode, hop_instance, (subgraph, *operands), kwargs ) unwrapped_operands = ctx.unwrap_tensors(operands) @@ -168,7 +186,7 @@ def gen_schema(self, subgraph, *operands, **kwargs): out_out_alias, mutated_inp_idx, output, - ) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args) + ) = check_input_alias_and_mutation_return_outputs(subgraph, fake_args) if not ( len(inp_inp_alias) == 0 @@ -218,7 +236,11 @@ def backward(ctx, *grad_outputs): kwargs = ctx.kwargs # TODO: Something special needs to happen with min cut partitioner - with suspend_functionalization(), disable_functional_mode(), torch.enable_grad(): + with ( + suspend_functionalization(), + disable_functional_mode(), + torch.enable_grad(), + ): with disable_proxy_modes_tracing(): from .invoke_subgraph import create_fw_bw_graph from .utils import _from_fun diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 4ab1709aaec6d5..518c0624cbabe5 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -15,12 +15,11 @@ is_batchedtensor, maybe_get_bdim, ) -from torch._dispatch.python import suspend_functionalization from torch._functorch.utils import exposed_in from torch._higher_order_ops.utils import ( - _maybe_reenter_make_fx, _maybe_run_with_interpreter, _set_compilation_env, + materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, saved_tensors_and_symints, @@ -29,17 +28,15 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, - disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, ) from torch.utils._python_dispatch import _get_current_dispatch_mode -from .utils import _from_fun +from .utils import clone_outputs_aliasing_inputs log = logging.getLogger(__name__) @@ -110,8 +107,12 @@ def cond(pred, true_branch, false_branch, operands): def true_fn(x: torch.Tensor): return x.cos() + + def false_fn(x: torch.Tensor): return x.sin() + + return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) Restrictions: @@ -185,7 +186,11 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with ( + _set_compilation_env(), + torch._dynamo.utils.disable_cache_limit(), + _temp_remove_pre_dispatch_torch_function_mode(), + ): with _temp_remove_metadata_torch_function_mode() as metadata_mode: if metadata_mode: backend = make_eager_backend_with_torch_function_mode(metadata_mode) @@ -196,36 +201,6 @@ def _cond_op_wrapper(*args, **kwargs): ) -def materialize_as_graph( - fn: Callable, - args: tuple[Any], - include_key_set: Optional[torch._C.DispatchKeySet] = None, - exclude_key_set: Optional[torch._C.DispatchKeySet] = None, - force_enable_grad=False, -) -> torch.fx.GraphModule: - if include_key_set is None: - include_key_set = torch._C._dispatch_tls_local_include_set() - if exclude_key_set is None: - exclude_key_set = torch._C._dispatch_tls_local_exclude_set() - - @torch._dynamo.disable(recursive=True, reason=None) - def _materialize_as_graph_inner(): - with suspend_functionalization(), disable_functional_mode(): - with disable_proxy_modes_tracing(): - unfunc_t = [_from_fun(arg) for arg in args] - with contextlib.ExitStack() as stack: - stack.enter_context( - torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), - ) - if force_enable_grad: - stack.enter_context(torch.enable_grad()) - return _maybe_reenter_make_fx(fn)(*unfunc_t) - - gm = _materialize_as_graph_inner() - assert gm is not None - return gm - - def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: """ For a fn that accepts flat inputs and returns flat outputs: @@ -262,11 +237,17 @@ def flat_fn(*args_and_grad_outs): tangents = args_and_grad_outs[n_primals:] grad_args = bw_fn(primals, tangents)[1] assert len(args) == len(grad_args) + # In order to keep HOPs functional where the backward graph, + # would have outputs that are aliasing inputs. + # For example in cases where the backward of the function is simply + # passing the upstream gradients through. + maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) + return [ ( torch.zeros_like(arg) if isinstance(arg, torch.Tensor) and grad is None - else grad + else maybe_clone(grad) ) for grad, arg in zip(grad_args, primals) ] @@ -275,9 +256,9 @@ def flat_fn(*args_and_grad_outs): def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): - assert isinstance( - operands, (list, tuple) - ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}" + assert isinstance(operands, (list, tuple)), ( + f"Cond operands must be a list or tuple of tensors and SymInts {operands}" + ) true_graph = reenter_make_fx(true_fn)(*operands) false_graph = reenter_make_fx(false_fn)(*operands) @@ -324,9 +305,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): @cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) def cond_op_dense(pred, true_fn, false_fn, operands): - assert all( - isinstance(o, (torch.Tensor, int)) for o in operands - ), f"Dense implementation operands must be a list of tensors and ints {operands}" + assert all(isinstance(o, (torch.Tensor, int)) for o in operands), ( + f"Dense implementation operands must be a list of tensors and ints {operands}" + ) mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" if pred: @@ -550,11 +531,11 @@ def _has_unbacked_symbols(s: Union[int, torch.SymInt]) -> bool: """ This follows the logic in symbolic_shapes._compute_symbolic_stride - Step 2: Since tensor stride is an accumulative muliplication of the sizes, which is a permutated - (due to view ops) non-decending sequence. + Step 2: Since tensor stride is an accumulative multiplication of the sizes, which is a permutated + (due to view ops) non-descending sequence. Case 1: No size is 1. In this case, strides have unique values. - For example, suppose we have a tenosr with: + For example, suppose we have a tensor with: size [3, 4, 3, 5, 4, 5], stride (1200, 300, 1, 12, 3, 60), merged_size [u0, u1, u2, u3, u4, u5]. @@ -654,9 +635,9 @@ def _maybe_expr(s: Union[int, torch.SymInt]): if _maybe_expr(a_val) in a_stride_expr: a_expr = a_stride_expr[_maybe_expr(a_val)] - assert ( - b_stride_expr[_maybe_expr(b_val)] == a_expr - ), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}" + assert b_stride_expr[_maybe_expr(b_val)] == a_expr, ( + f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}" + ) merged_strides[i] = a_expr else: if a_val == 1: @@ -713,12 +694,12 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs): @cond_op.py_impl(torch._C._functorch.TransformType.Vmap) def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): - assert isinstance( - inputs, (list, tuple) - ), "Cond inputs must be a list or tuple of tensors" - assert all( - isinstance(i, torch.Tensor) for i in inputs - ), "Cond inputs must be a list of tensors" + assert isinstance(inputs, (list, tuple)), ( + "Cond inputs must be a list or tuple of tensors" + ) + assert all(isinstance(i, torch.Tensor) for i in inputs), ( + "Cond inputs must be a list of tensors" + ) pred_is_batched = isinstance(pred, torch.Tensor) and is_batchedtensor(pred) pred_ = get_unwrapped(pred) if pred_is_batched else pred diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 68092865f70268..23f7a5e474bdf0 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -24,11 +24,11 @@ class _EffectType(Enum): OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] -SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary( - { - torch.ops.aten._print.default: _EffectType.ORDERED, - call_torchbind: _EffectType.ORDERED, - } +SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType]( + [ + (torch.ops.aten._print.default, _EffectType.ORDERED), + (call_torchbind, _EffectType.ORDERED), + ] ) @@ -240,9 +240,9 @@ def handle_effects( key = get_effect_key(op, args, kwargs) assert key is not None if key not in tokens: - assert ( - allow_token_discovery - ), f"Could not find a token for effect {key} which came from the function {op}" + assert allow_token_discovery, ( + f"Could not find a token for effect {key} which came from the function {op}" + ) proxy_tensor_mode = torch._C._get_dispatch_mode( torch._C._TorchDispatchModeKey.PROXY ) diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py index 4221a2b888d3cf..3274502b943cd6 100644 --- a/torch/_higher_order_ops/executorch_call_delegate.py +++ b/torch/_higher_order_ops/executorch_call_delegate.py @@ -49,7 +49,10 @@ def _unwrap_proxy(e): if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): return e return get_proxy_slot( - cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined] + cast(torch.Tensor, e), + proxy_mode.tracer, + e, + lambda e: e.proxy, # type: ignore[attr-defined] ) if not is_lowered_module(lowered_module): diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 5e690f0bb8a9d5..3d0ee36d59b1aa 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -10,20 +10,25 @@ _has_potential_branch_input_mutation, _maybe_reenter_make_fx, autograd_not_implemented, + has_user_subclass, + redirect_to_mode, reenter_make_fx, + register_fake, save_tensors_and_symints_for_backward, saved_tensors_and_symints, UnsupportedAliasMutationException, validate_subgraph_args_types, ) from torch._ops import HigherOrderOperator -from torch._subclasses import FakeTensorMode +from torch._subclasses import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental.proxy_tensor import ( make_fx, ProxyTorchDispatchMode, track_tensor_tree, ) from torch.fx.graph_module import GraphModule +from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode # Duplicate of _inductor/kernel/flex_attention.py to avoid circular import @@ -33,9 +38,9 @@ def _construct_strides( ) -> Sequence[int]: """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" # Initialize strides - assert len(sizes) == len( - fill_order - ), "Length of sizes must match the length of the fill order" + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) strides = [0] * len(sizes) # Start with stride 1 for the innermost dimension @@ -396,6 +401,22 @@ def flex_attention_functionalize( """ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + if has_user_subclass( + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor, FunctionalTensor), + ): + return NotImplemented + query_unwrapped = ctx.unwrap_tensors(query) key_unwrapped = ctx.unwrap_tensors(key) value_unwrapped = ctx.unwrap_tensors(value) @@ -445,9 +466,34 @@ def flex_attention_functionalize( return ctx.wrap_tensors(out) # type: ignore[return-value, arg-type] +@register_fake(flex_attention) def flex_attention_fake_impl( - query: torch.Tensor, value: torch.Tensor + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), ) -> tuple[torch.Tensor, torch.Tensor]: + if has_user_subclass( + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor,), + ): + return NotImplemented + # TODO: Figure out a better way to handle this for NJT than using sum() if query.is_nested: out = torch.empty_like(query, memory_format=torch.contiguous_format) @@ -463,22 +509,9 @@ def flex_attention_fake_impl( return out, logsumexp -@flex_attention.py_impl(FakeTensorMode) -def flex_attention_fake_tensor_mode( - mode: FakeTensorMode, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - score_mod: Callable, - block_mask: tuple, - scale: float, - kernel_options: dict[str, Any], - score_mod_other_buffers: tuple = (), - mask_mod_other_buffers: tuple = (), -) -> tuple[torch.Tensor, torch.Tensor]: - with mode: - out, logsumexp = flex_attention_fake_impl(query, value) - return out, logsumexp +# Registers dispatches for SAC +redirect_to_mode(flex_attention, _CachingTorchDispatchMode) +redirect_to_mode(flex_attention, _CachedTorchDispatchMode) # ---------------------------- Autograd Implementation ---------------------------- @@ -510,7 +543,7 @@ def create_fw_bw_graph( with disable_proxy_modes_tracing(): def _from_fun( - t: Union[Tensor, torch.SymInt, int] + t: Union[Tensor, torch.SymInt, int], ) -> Union[Tensor, torch.SymInt, int]: if isinstance(t, torch.Tensor): return torch.empty_strided( @@ -561,7 +594,7 @@ def joint_f( *other_buffers: tuple[Tensor, ...], ) -> tuple[Tensor, ...]: def fw_with_masks( - *args: tuple[Tensor, ...] + *args: tuple[Tensor, ...], ) -> tuple[tuple[Tensor], tuple[bool]]: fw_out = score_mod(*args) out_requires_grad = fw_out.requires_grad @@ -600,9 +633,9 @@ def forward( for buffer in mask_mod_other_buffers if isinstance(buffer, torch.Tensor) ) - assert ( - not any_buffer_requires_grad - ), "Captured buffers from mask mod that require grad are not supported." + assert not any_buffer_requires_grad, ( + "Captured buffers from mask mod that require grad are not supported." + ) ctx._fw_graph = fw_graph ctx._joint_graph = joint_graph ctx._mask_graph = block_mask[-1] @@ -638,7 +671,11 @@ def forward( return out, logsumexp @staticmethod - def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override] + def backward( # type: ignore[override] + ctx: Any, + grad_out: Tensor, + grad_logsumexp: Tensor, + ) -> tuple[Optional[Tensor], ...]: fw_args = saved_tensors_and_symints(ctx) ( query, @@ -798,7 +835,7 @@ def sdpa_dense_backward( actual_grad_value = _permute_strides(actual_grad_value, value.stride()) def _maybe_new_buffer( - buffer: Union[torch.Tensor, torch.SymInt, int] + buffer: Union[torch.Tensor, torch.SymInt, int], ) -> Optional[Union[torch.Tensor, torch.SymInt, int]]: if isinstance(buffer, torch.Tensor): return ( @@ -906,9 +943,9 @@ def _maybe_new_buffer( actual_grad_value.copy_(grad_value) if Bq != Bkv: - assert ( - Bq > 1 and Bkv == 1 - ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + assert Bq > 1 and Bkv == 1, ( + f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + ) actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True) actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True) @@ -1090,6 +1127,25 @@ def flex_attention_backward_functionalize( since we know that the forward score mod function is assured to be free of mutations to the other_buffers, we skip that mutate check and go straight to redispatching. """ + + if has_user_subclass( + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor, FunctionalTensor), + ): + return NotImplemented query_unwrapped = ctx.unwrap_tensors(query) key_unwrapped = ctx.unwrap_tensors(key) value_unwrapped = ctx.unwrap_tensors(value) @@ -1142,9 +1198,8 @@ def flex_attention_backward_functionalize( return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type] -@flex_attention_backward.py_impl(FakeTensorMode) +@register_fake(flex_attention_backward) def flex_attention_backward_fake_tensor_mode( - mode: FakeTensorMode, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -1162,37 +1217,54 @@ def flex_attention_backward_fake_tensor_mode( ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] ]: - with mode: - Bq, _, _, qk_head_dim = query.shape - Bkv, Hkv, seq_len_kv, v_head_dim = value.shape - - grad_query = torch.empty_like(query) - # zeros_and_scatter creates a contiguous zeros tensor -> contiguous_format - grad_score_mod_captured = tuple( - [ + if has_user_subclass( + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor,), + ): + return NotImplemented + Bq, _, _, qk_head_dim = query.shape + Bkv, Hkv, seq_len_kv, v_head_dim = value.shape + + grad_query = torch.empty_like(query) + # zeros_and_scatter creates a contiguous zeros tensor -> contiguous_format + grad_score_mod_captured = tuple( + [ + ( torch.empty_like(buffer, memory_format=torch.contiguous_format) if isinstance(buffer, torch.Tensor) and buffer.requires_grad else None - for buffer in score_mod_other_buffers - ] - ) + ) + for buffer in score_mod_other_buffers + ] + ) - broadcasted_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim)) - broadcasted_grad_key = _permute_strides(broadcasted_grad_key, key.stride()) + broadcasted_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim)) + broadcasted_grad_key = _permute_strides(broadcasted_grad_key, key.stride()) - broadcasted_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim)) - broadcasted_grad_value = _permute_strides( - broadcasted_grad_value, value.stride() - ) + broadcasted_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim)) + broadcasted_grad_value = _permute_strides(broadcasted_grad_value, value.stride()) - if Bq > 1 and Bkv == 1: - grad_key = torch.sum(broadcasted_grad_key, dim=0, keepdim=True) - grad_value = torch.sum(broadcasted_grad_value, dim=0, keepdim=True) - else: - grad_key = broadcasted_grad_key - grad_value = broadcasted_grad_value + if Bq > 1 and Bkv == 1: + grad_key = torch.sum(broadcasted_grad_key, dim=0, keepdim=True) + grad_value = torch.sum(broadcasted_grad_value, dim=0, keepdim=True) + else: + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value - return grad_query, grad_key, grad_value, grad_score_mod_captured + return grad_query, grad_key, grad_value, grad_score_mod_captured flex_attention_backward.py_autograd_impl( diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py index 7cebc9a4fe92d6..3f21c518cbd741 100644 --- a/torch/_higher_order_ops/hints_wrap.py +++ b/torch/_higher_order_ops/hints_wrap.py @@ -38,8 +38,7 @@ def __call__(self, body_fn, args, kwargs, hints): if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): raise RuntimeError( - "args must be a tuple of tensors, ints, floats, or bools, got " - f"{args}" + f"args must be a tuple of tensors, ints, floats, or bools, got {args}" ) if not isinstance(kwargs, dict): diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index b6920b95db7b1e..0b21b61531003e 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs - import contextlib from contextlib import nullcontext from dataclasses import dataclass, field @@ -17,6 +16,7 @@ clone_outputs_aliasing_inputs, FunctionalizeCtxWrapper, get_dummy_aot_autograd_config, + HopInstance, prepare_fw_with_masks, reenter_make_fx, register_fake, @@ -51,7 +51,9 @@ class OutputMetadata: class InvokeSubgraphHOP(HigherOrderOperator): def __init__(self) -> None: - super().__init__("invoke_subgraph") + # Invoke subgraph does not have any state, it is just a wrapper over a + # subgraph, so we can safely cache the HOP. + super().__init__("invoke_subgraph", cacheable=True) # This is used by the fake tensor cache key validator to extract the # subgraph and iterate over the nodes to find if all nodes are fake # tensor cacheable. @@ -67,9 +69,9 @@ def __call__( identifier: Optional[str], *operands, ): - assert identifier is None or isinstance( - identifier, str - ), "identifier must be a None or a string" + assert identifier is None or isinstance(identifier, str), ( + "identifier must be a None or a string" + ) assert all( isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands @@ -80,19 +82,26 @@ def __call__( def gen_schema(self, subgraph, identifier, *operands): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import ( - check_input_alias_and_mutation_return_ouputs, + check_input_alias_and_mutation_return_outputs, + materialize_as_graph, + ) + + gm: torch.fx.GraphModule = ( + subgraph + if isinstance(subgraph, torch.fx.GraphModule) + else materialize_as_graph(subgraph, operands) ) schema_gen = HopSchemaGenerator(self) - schema_gen.add_arg("subgraph", subgraph) + schema_gen.add_arg("subgraph", gm) schema_gen.add_arg("identifier", identifier) - example_inputs = [ - n.meta["val"] if "val" in n.meta else n.meta["example_value"] - for n in subgraph.graph.find_nodes(op="placeholder") - ] - _, _, _, mutated_inputs, outputs = check_input_alias_and_mutation_return_ouputs( - subgraph, example_inputs - ) + ( + _, + _, + _, + mutated_inputs, + outputs, + ) = check_input_alias_and_mutation_return_outputs(gm, operands) for idx, arg in enumerate(operands): schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inputs) for out in outputs: @@ -118,7 +127,11 @@ def invoke_subgraph_placeholder(func, *args, **kwargs): def _invoke_subgraph_placeholder_wrapper(func, args): return invoke_subgraph_placeholder(func, *args) - with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with ( + _set_compilation_env(), + torch._dynamo.utils.disable_cache_limit(), + _temp_remove_pre_dispatch_torch_function_mode(), + ): with _temp_remove_metadata_torch_function_mode() as metadata_mode: if metadata_mode: backend = make_eager_backend_with_torch_function_mode(metadata_mode) @@ -146,7 +159,13 @@ def mark_compile_region(fn=None): def wrap(func): def inner(*args, **kwargs): - return invoke_subgraph_placeholder(func, *args, **kwargs) + # Get the innermost function to avoid nested compile regions + inner_func = func + while hasattr(inner_func, "__marked_compile_region_fn__"): + inner_func = inner_func.__marked_compile_region_fn__ + return invoke_subgraph_placeholder(inner_func, *args, **kwargs) + + inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined] return inner @@ -529,7 +548,30 @@ def _(subgraph, identifier, *operands): @invoke_subgraph.py_functionalize_impl def _(ctx, subgraph, identifier, *operands): + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize_v2, + ) + unwrapped_operands = ctx.unwrap_tensors(operands) + hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) + if can_auto_functionalize(hop_instance): + # NOTE: [auto_functionalize x invoke_subgraph caching] + # We call auto_functionalized_v2 to support input mutation of invoke_subgraph. + # See NOTE [Support input mutation of hops] for the overall design. + # + # invoke_subgraph is special because of its identifier based caching machanism. + # In invoke_subgraph's functionalization key implementation, we create a new + # identifer because the subgraph is replaced by FunctionWithNoFreeVars in a + # functional + epilogue form. + assert isinstance(identifier, str), identifier + return do_auto_functionalize_v2( + ctx.mode, + hop_instance, + (subgraph, "auto_functionalized_" + identifier, *operands), + {}, + ) + with ctx.redispatch_to_next(): # NB: There is an assumption that subgraph does not mutate inputs and # there is no aliasing. Its Dynamo responsibility to prevent formation @@ -574,13 +616,41 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands): graph.recompile() assert isinstance(proxy_mode.tracer, torch.fx.Tracer) - qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") - proxy_mode.tracer.root.register_module(qualname, graph) if invoke_subgraph_cache: invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph) node_args = (graph, identifier, *operands) - proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[union-attr] + + def _unwrap_proxy(arg): + if isinstance(arg, torch.fx.GraphModule): + # NOTE: [invoke_subgraph proxy_mode x auto_functionalize] + # Previously, we assumed that `invoke_subgraph` would always be traced with the same tracer. + # This allowed us to cache modules by their identifiers, assuming they were already registered. + # + # However, this assumption no longer holds when we auto-functionalize `invoke_subgraph`. + # auto_functionalize functionalizes the subgraph and wrap it with `FunctionWithNoFreeVars`. + # In the proxy mode implementation of `auto_functionalized_v2`, we need to materialize `FunctionWithNoFreeVars` + # input as a graph module. To do this, we re-trace the `invoke_subgraph` hop, which starts a new sub-tracer + # (see NOTE [materialize callable inputs as graph]). # When the new sub-tracer traces the `invoke_subgraph` + # with a previously cached identifier, the corresponding graph module might not + # exist as a submodule in the new tracer's root. Therefore, we register it as a submodule below. + # + # The alternative is to give a new identifer when we re-trace the invoke_subgraph but this will increase + # the compilatoin time, which defeats the purpose of caching. + registered_before = False + for ( + _, + submod, + ) in proxy_mode.tracer.root.named_modules(): # type: ignore[union-attr] + if arg is submod: + registered_before = True + + if not registered_before: + qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") # type: ignore[union-attr] + proxy_mode.tracer.root.register_module(qualname, arg) # type: ignore[union-attr] + return proxy_mode.tracer.unwrap_proxy(arg) # type: ignore[union-attr] + + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) # type: ignore[union-attr] out_proxy = proxy_mode.tracer.create_proxy( "call_function", invoke_subgraph, proxy_args, {} ) diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 191a42afb414f0..ff26c25222db98 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -246,8 +246,7 @@ def map_dense(f, xs, pos_args): return _stack_pytree(pytrees) -# TODO: Rework DispatchKey.Autograd to py_autograd_impl -@map_impl.py_impl(DispatchKey.Autograd) +@map_impl.py_autograd_impl def map_autograd(f, xs, pos_args): num_mapped_args = len(xs) fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 4e68aa7b081e61..fb94bda71d2d80 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -8,11 +8,12 @@ import torch._prims_common as utils import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph +from torch._higher_order_ops.cond import create_bw_fn from torch._higher_order_ops.utils import ( _maybe_compile_and_run_fn, check_meta_consistency, first_slice_copy, + materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, saved_tensors_and_symints, @@ -35,9 +36,9 @@ def wrap_combine_fn_flat( *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves ): - assert len(args) == ( - num_init_leaves + num_inp_leaves - ), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}" + assert len(args) == (num_init_leaves + num_inp_leaves), ( + f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}" + ) carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init) xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs) return combine_fn(carry, xs) @@ -72,13 +73,13 @@ def mask_list( # If other is None, then the elements of the `inp` list where the mask is False are removed # If other is not None, then the elements of the `inp` list where the mask is False are # replaced with the elements of the `other` list - assert len(mask) == len( - inp - ), "The length of the mask needs to be identical to the length of the input" + assert len(mask) == len(inp), ( + "The length of the mask needs to be identical to the length of the input" + ) if other is not None: - assert len(inp) == len( - other - ), "If an input and an other list is provided, they need to have the same length" + assert len(inp) == len(other), ( + "If an input and an other list is provided, they need to have the same length" + ) return [i if m else o for m, i, o in zip(mask, inp, other)] else: return [i for m, i in zip(mask, inp) if m] @@ -96,9 +97,9 @@ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: it = iter(iterable) - assert sum(chunk_sizes) == len( - iterable - ), "the sum of all chunks needs to match the length of the iterable." + assert sum(chunk_sizes) == len(iterable), ( + "the sum of all chunks needs to match the length of the iterable." + ) return [list(itertools.islice(it, size)) for size in chunk_sizes] @@ -149,11 +150,22 @@ def scan( out (torch.Tensor or pytree with tensor leaves), each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration. + Restrictions: + - The combine_fn shouldn't have any aliasing between input-input, input-output, and output-output. E.g. return a view + or the same tensor as input is not supported. As a workaround, can clone the output to avoid aliasing. + + - The combine_fn shoudn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue + if you input mutation support for training is needed. + + - The combine_fn's init carry should match the next_carry in pytree structure and in tensor metadata. + Example:: def add(x: torch.Tensor, y: torch.Tensor): next_carry = y = x + y - return next_carry, y + # clone the output to avoid output-output aliasing + return next_carry, y.clone() + i0 = torch.zeros(1) xs = torch.arange(5) @@ -251,9 +263,9 @@ def __call__(self, combine_fn, init, xs, additional_inputs): # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 # Once this issue is resolved, the assertion should only allow tuples # and the tuple cast should be removed - assert isinstance( - additional_inputs, (tuple, list) - ), "additional_inputs must be a tuple." + assert isinstance(additional_inputs, (tuple, list)), ( + "additional_inputs must be a tuple." + ) additional_inputs = ( tuple(additional_inputs) if isinstance(additional_inputs, list) @@ -805,19 +817,6 @@ def construct_args_single_step_bw(): @scan_op.py_autograd_impl def scan_autograd(combine_fn, init, xs, additional_inputs): - if not any( - el.requires_grad - for el in (tuple(init) + tuple(xs) + additional_inputs) - if isinstance(el, torch.Tensor) - ): - with torch._C._AutoDispatchBelowAutograd(): - return scan_op( - combine_fn, - init, - xs, - additional_inputs, - ) - num_leaves_init = len(init) num_leaves_xs = len(xs) num_additional_inputs = len(additional_inputs) diff --git a/torch/_higher_order_ops/schema.py b/torch/_higher_order_ops/schema.py index 70f8bfebb53bcd..c7378147a205e1 100644 --- a/torch/_higher_order_ops/schema.py +++ b/torch/_higher_order_ops/schema.py @@ -1,3 +1,4 @@ +import copy from dataclasses import dataclass from typing import Any, Optional @@ -34,9 +35,9 @@ def from_example( kw_only: bool = False, ) -> HopArgumentInfo: if default_value is not None: - assert type(example_value) == type( - default_value - ), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}" + assert type(example_value) == type(default_value), ( + f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}" + ) return HopArgumentInfo( name=name, @@ -92,6 +93,7 @@ class HopSchemaGenerator: def __init__(self, hop: torch._ops.HigherOrderOperator): self.arg_infos: list[HopArgumentInfo] = [] self.example_outputs: list[Any] = [] + self.schema_tree_spec: Optional[pytree.TreeSpec] = None self.hop = hop def add_arg( @@ -109,6 +111,16 @@ def add_arg( "Expect callable to be a GraphModule or an. Please call materialize_as_graph first " f"to turn callable arguments {example_value} into a GraphModule." ) + _, flat_spec = pytree.tree_flatten(example_value) + if not flat_spec.is_leaf(): + raise RuntimeError( + f"example_value {example_value} is not a leaf node. " + "Please only add flattened inputs to the hop schema. " + "If you need some structure in the arguments, please" + "add_arg for flattened args one by one then " + "call add_schema_tree_spec to register the original pytree " + " spec of the args." + ) arg_info = HopArgumentInfoGen.from_example( example_value=example_value, @@ -122,11 +134,28 @@ def add_arg( def add_output(self, output: Any) -> None: self.example_outputs.append(output) + def add_schema_tree_spec(self, *args: Any, **kwargs: Any) -> None: + """schema tree spec is the tree spec from flattening all inputs to the hop with pytree.tree_flatten + Since torch.FunctionSchema only have proper mutation/alias support for flattened inputs, we need + to store the tree spec in order to reconstruct the inputs to the hop. + """ + self.schema_tree_spec = pytree.tree_flatten((args, kwargs))[1] + def gen_schema(self) -> torch._C.FunctionSchema: + for i, arg_info in enumerate(self.arg_infos): + arg_spec = pytree.tree_flatten(arg_info.example_value)[1] + if not arg_spec.is_leaf() and self.schema_tree_spec is None: + raise RuntimeError( + f"example_value of arg_infos[{i}] is {arg_info.example_value}, which is not a leaf node. " + "Please call add_schema_tree_spec to add a schema tree spec first. " + "Or consider changing the hop's signature to only take flattened arguments." + ) + return CFunctionSchemaGen.from_hop_argument_info( str(self.hop), self.arg_infos, HopArgumentInfoGen.from_example(tuple(self.example_outputs), name="out"), + self.schema_tree_spec, ) @@ -171,18 +200,19 @@ def from_hop_argument_info( op_name: str, inp_argument_info: list[HopArgumentInfo], out_argument_info: HopArgumentInfo, + schema_tree_spec: Optional[pytree.TreeSpec], ) -> Any: args = [] for i, arg_info in enumerate(inp_argument_info): args.append(CArgumentGen.from_hop_argument_info(i, arg_info)) # NOTE: we want the output to always be a single argument with torch._C.TupleType. - assert isinstance( - out_argument_info.example_value, tuple - ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" - assert ( - not out_argument_info.is_mutated - ), "out_argument_info.is_mutated should always be set to False." + assert isinstance(out_argument_info.example_value, tuple), ( + f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" + ) + assert not out_argument_info.is_mutated, ( + "out_argument_info.is_mutated should always be set to False." + ) rets = None if len(out_argument_info.example_value) == 1: rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)] @@ -201,13 +231,51 @@ def from_hop_argument_info( for i, val in enumerate(out_argument_info.example_value) ] - return torch._C.FunctionSchema( + return HopSchema( op_name, "", args, rets, False, False, + schema_tree_spec, + ) + + +class HopSchema(torch._C.FunctionSchema): + def __init__( + self, + name: str, + overload_name: str, + arguments: list[torch._C.Argument], + returns: list[torch._C.Argument], + is_vararg: bool, + is_varret: bool, + schema_tree_spec: Optional[pytree.TreeSpec], + ): + self.tree_spec = schema_tree_spec + self.is_vararg = is_vararg + self.is_varret = is_varret + super().__init__( + name, + overload_name, + arguments, + returns, + self.is_vararg, + self.is_varret, + ) + + def __deepcopy__(self, memo: Any) -> "HopSchema": + # Need to additionally copy the tree_spec since + # it's not a member of torch._C.FunctionSchema + return HopSchema( + self.name, + self.overload_name, + self.arguments, + self.returns, + self.is_vararg, + self.is_varret, + copy.deepcopy(self.tree_spec), ) diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index 8b5d83485e0a2d..5496276f1ddada 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -81,9 +81,9 @@ def enable_torchbind_tracing(): torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] yield finally: - assert ( - KNOWN_TYPES.pop() is torch.ScriptObject - ), "Someone else messed with KNOWN_TYPES during tracing, exploding." + assert KNOWN_TYPES.pop() is torch.ScriptObject, ( + "Someone else messed with KNOWN_TYPES during tracing, exploding." + ) torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] @@ -127,9 +127,9 @@ def inner(mode, *args, **kwargs): ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) if "val" not in out_proxy.node.meta: - assert out is None or isinstance( - out, (int, float, bool) - ), "Currently, only these constant dtypes are supported to be returned from torchbind methods." + assert out is None or isinstance(out, (int, float, bool)), ( + "Currently, only these constant dtypes are supported to be returned from torchbind methods." + ) out_proxy.node.meta["val"] = out return ret diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 4703b79af93c16..71fb2044103776 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -3,7 +3,9 @@ import dataclasses import functools import inspect +import itertools import logging +import operator import threading from collections import defaultdict from collections.abc import Sequence @@ -63,14 +65,10 @@ class JITFunction: # type: ignore[no-redef] log = logging.getLogger("torch._dynamo") -# TMADescriptorMetadata maps kernel parameter names to the metadata that allows -# reconstructing TMA descriptors from the underlying tensors (passed as kernel -# arguments in the fx graph, instead of the TMA descriptors). Namely: a tuple -# conisting of list of dims, list of block dims, and element size. E.g., for this -# call in host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``, -# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts. -TMADescriptorMetadata = dict[ - str, # kernel parameter name +# e.g. for a host-side Triton TMA API call ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``, +# the metadata will look like ``("experimental", ([50, 60], [32, 15], 4))`` +TMAExperimentalMetadata = tuple[ + str, # type of TMA (should be "experimental") tuple[ list[IntLikeType], # dims list[IntLikeType], # block_dims @@ -78,6 +76,63 @@ class JITFunction: # type: ignore[no-redef] ], ] +# e.g. for host-side Triton TMA API call ``TensorDescriptor.from_tensor(ptr, [32, 64])`` +# the metadata will look like ``("stable", ([32, 64],))`` +TMAStableMetadata = tuple[ + str, # type of TMA ("experimental" or "stable") + tuple[list[IntLikeType],], # block_shape +] + + +def create_tma_experimental_metadata( + dims: list[IntLikeType], + block_dims: list[IntLikeType], + element_size: IntLikeType, +) -> TMAExperimentalMetadata: + return ("experimental", (dims, block_dims, element_size)) + + +def maybe_unpack_tma_experimental_metadata( + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata], +) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]: + if not tma_meta or len(tma_meta) != 2: + return None + if tma_meta[0] == "experimental": + return tma_meta[1] # type: ignore[return-value] + return None + + +def create_tma_stable_metadata( + block_shape: list[IntLikeType], +) -> TMAStableMetadata: + return ("stable", (block_shape,)) + + +def maybe_unpack_tma_stable_metadata( + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata], +) -> Optional[tuple[list[IntLikeType]]]: + if not tma_meta or len(tma_meta) != 2: + return None + if tma_meta[0] == "stable": + return tma_meta[1] # type: ignore[return-value] + return None + + +# TMADescriptorMetadata maps kernel parameter names to the metadata that allows +# reconstructing TMA descriptors from the underlying tensors (passed as kernel +# arguments in the fx graph, instead of the TMA descriptors). +# +# Since there are two TMA APIs (the old "experimental" API and the new "stable" API), +# each entry in the dict is a tuple that starts with a string, either "experimental" +# or "stable". The second entry in the tuple is another tuple, with data that depends +# on the API type (see TMAExperimentalMetadata and TMAStableMetadata above). +# +# These are stored as raw tuples (instead of classes) for ease of serialization. +TMADescriptorMetadata = dict[ + str, # kernel parameter name + Union[TMAExperimentalMetadata, TMAStableMetadata], +] + ############################################################################### # Kernel Side Table @@ -172,7 +227,9 @@ def __post_init__(self) -> None: def generate_ttir( - kernel: "TritonKernelType", kwargs: dict[str, Any] + kernel: "TritonKernelType", + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, ) -> tuple["TritonIRModule", list[str]]: """ Uses Triton's internal code generation to create TTIR @@ -189,6 +246,7 @@ def generate_ttir( triton_version_uses_attrs_dict, TritonAttrsDescriptorVersion, ) + from torch.utils._triton import has_triton_tensor_descriptor_host_tma triton_version = get_triton_attrs_descriptor_version() @@ -230,15 +288,69 @@ def generate_ttir( a = kwargs[name] if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)): ordered_args[name] = 2 + elif ( + stable_meta := maybe_unpack_tma_stable_metadata( + tma_descriptor_metadata.get(name, None) + ) + ) is not None: + from triton.tools.tensor_descriptor import TensorDescriptor + + block_shape = stable_meta[0] + with torch._C._DisableTorchDispatch(): + # need 16-byte aligned strides + elements_per_dim = max(1, 16 // a.dtype.itemsize) + base_tensor = torch.empty( + [elements_per_dim] * len(block_shape), dtype=a.dtype + ) + ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape) elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)): with torch._C._DisableTorchDispatch(): ordered_args[name] = torch.empty(2, dtype=a.dtype) else: ordered_args[name] = a - ordered_tensor_names = [ - name for name, arg in ordered_args.items() if isinstance(arg, Tensor) - ] + def is_stable_tensor_descriptor_arg(arg: Any) -> bool: + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor + + if isinstance(arg, TensorDescriptor): + return True + return False + + def is_tensor_like_arg(arg: Any) -> bool: + if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg): + return True + return False + + # Note: one would expect that each input to the triton kernel maps to + # one input parameter in the TTIR. This is _not_ true for TMA descriptors: + # one TMA descriptor gets converted into: + # * one TMA descriptor input + # * N strides, for a rank-N tensor + # * N sizes, for a rank-N tensor + # To account for this, we inject some fake arg names as placeholders for + # the stride and size parameters. + def get_tensor_names(name: str, arg: Any) -> list[str]: + if isinstance(arg, Tensor): + return [name] + if is_stable_tensor_descriptor_arg(arg): + stable_meta = maybe_unpack_tma_stable_metadata( + tma_descriptor_metadata[name] + ) + assert stable_meta is not None + block_shape = stable_meta[0] + tensor_rank = len(block_shape) + names = [name] + names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank)) + names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank)) + return names + return [] + + ordered_tensor_names = list( + itertools.chain.from_iterable( + get_tensor_names(name, arg) for name, arg in ordered_args.items() + ) + ) def _get_specialization(args): # type: ignore[no-untyped-def] # Support multiple triton versions. @@ -304,7 +416,7 @@ def _get_specialization(args): # type: ignore[no-untyped-def] specialization = _get_specialization(ordered_args.values()) constants = { - name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) + name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg) } if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None: @@ -620,7 +732,7 @@ def mlir_to_functions(op: "TritonIROperation") -> None: class MemoizeWithCycleCheck: fn: Callable[..., Any] - cache: dict[tuple[str, int], Any] + cache: dict[tuple[Any], Any] def __init__(self, fn: Callable[..., Any]) -> None: self.fn = fn @@ -630,12 +742,12 @@ def __call__( self, functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, - num_args: int, + *args: Any, ) -> list[bool]: - key = (fn_name, num_args) + key: tuple[Any, ...] = (fn_name, *args) if key not in self.cache: self.cache[key] = None - self.cache[key] = self.fn(functions, fn_name, num_args) + self.cache[key] = self.fn(functions, fn_name, *args) if self.cache[key] is None: raise RuntimeError("Recursion is not supported") return self.cache[key] @@ -644,6 +756,64 @@ def reset(self) -> None: self.cache = {} +@MemoizeWithCycleCheck +def get_tma_stores( + functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str +) -> set[Union[Intermediate, Param]]: + """ + Identifies all intermediates and parameters that are written to by a + `tt.experimental_descriptor_store`. It tracks only the specific values + written to via experimental_descriptor_store and the input values to + `tt.reinterpret_tensor_descriptor` used to construct the direct inputs + to tt.experimental_descriptor_store - not any recursive values + used to construct those values. + + For example: for + tt.reinterpret_tensor_descriptor(Intermediate(idx=0), ...) + Intermediate(idx=1) = tt.experimental_descriptor_store(Intermediate(idx=0), ...) + this function will return [Intermediate(idx=0), Intermediate(idx=1)], + + However + Intermediate(idx=4) = arith.addptr(Intermediate(idx=2), Intermediate(idx=3)) + Intermediate(idx=5) = tt.experimental_descriptor_store(Intermediate(idx=4), ...) + tt.experimental_descriptor_store(Intermediate(idx=5), ...) + this function will mark only idx=4 and idx=5 (but not idx=2 or idx=3) + + If an intermediate/parameter is passed into a function and is written to + via experimental_descriptor_store within that function, the argument to the + function will also be marked. + """ + + result: set[Union[Intermediate, Param]] = set() + + ops = functions[fn_name] + for op_list in ops.values(): + for op in op_list: + if op.name == "tt.call": + assert op.fn_call_name in functions + tma_stores = get_tma_stores(functions, op.fn_call_name) + for i, inp in enumerate(op.args): + if Param(idx=i) in tma_stores: + result.add(inp) + elif op.name == "tt.experimental_descriptor_store": + assert len(op.args) >= 1 + result.add(op.args[0]) + elif op.name == "tt.descriptor_store": + assert len(op.args) >= 1 + result.add(op.args[0]) + + for val in list(result): + if val in ops: + if not isinstance(val, Intermediate): + continue + for op in ops[val]: + if op.name == "tt.reinterpret_tensor_descriptor": + assert len(op.args) >= 1 + result.add(op.args[0]) + + return result + + @MemoizeWithCycleCheck def analyze_kernel_mutations( functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int @@ -663,6 +833,8 @@ def analyze_kernel_mutations( "tt.atomic_cas": [0], "tt.atomic_rmw": [0], "tt.experimental_descriptor_store": [0], + "tt.experimental_tensormap_create": [0], + "tt.descriptor_store": [0], } # Ops that we want to bail out on UNKNOWN_OPS = {"tt.elementwise_inline_asm"} @@ -670,20 +842,32 @@ def analyze_kernel_mutations( stack: list[Union[Param, Intermediate]] = [] visited = set() ops = functions[fn_name] + tma_stores = get_tma_stores(functions, fn_name) + for op_list in ops.values(): for op in op_list: # If we encounter an operation with effects that cannot be reliably analyzed # (e.g. `tt.elementwise_inline_asm`), we assume it does not mutate any input parameters. if op.name in UNKNOWN_OPS: if op.name == "tt.elementwise_inline_asm" and op.is_pure: - log.warning( - "TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)" - ) continue raise RuntimeError( f"ttir analysis hit an op we do not know how to analyze: {op.name}" ) + if op.name == "tt.experimental_tensormap_create": + # Note: this is how we implement experimental_descriptor_store mutation analysis. + # for on-device TMA. + # experimental_tensormap_store(a, b, ...) stores b to the location specified + # by descriptor in the memory of a. + # To track this, we first find all the intermediates/params to which we store via + # experimental_tensormap_store (get_tma_stores, called above). Then, during this + # analysis we wait to find the corresponding experimental_tensormap_create (if it + # exists), at which point we will mark the global_ptr as mutated (as done below). + assert len(op.args) >= 2 + if op.args[0] in tma_stores: + stack.append(op.args[1]) + if op.name == "tt.call": assert op.fn_call_name in functions mutations = analyze_kernel_mutations( @@ -716,7 +900,9 @@ def analyze_kernel_mutations( def identify_mutated_tensors( - kernel: "TritonKernelType", kwargs: dict[str, Any] + kernel: "TritonKernelType", + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, ) -> list[str]: """ Given a triton kernel and the arguments for this kernel, this function @@ -728,7 +914,9 @@ def identify_mutated_tensors( ttir_module = None functions = None try: - ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs) + ttir_module, ordered_tensor_names = generate_ttir( + kernel, kwargs, tma_descriptor_metadata + ) # extract functions from TTIR using MLIR bindings exposed by Triton code functions = ttir_to_functions(ttir_module) @@ -741,6 +929,7 @@ def identify_mutated_tensors( # The cache for analyze kernel mutations is mainly used for cycle # detection, so each top level invocation needs a clean cache analyze_kernel_mutations.reset() + get_tma_stores.reset() mutations = analyze_kernel_mutations( functions, kernel_name, len(ordered_tensor_names) ) @@ -845,11 +1034,6 @@ def triton_kernel_wrapper_mutation_dense( grid_fn = namespace[fn_name] if tma_descriptor_metadata: - from triton.tools.experimental_descriptor import ( # noqa: F401 - create_1d_tma_descriptor, - create_2d_tma_descriptor, - ) - # as we need to launch the kernel here, we "unwrap" the # tma_descriptor_metadata, create the TMA descriptors # from it, and replace the tensors in the kwargs by the @@ -857,16 +1041,32 @@ def triton_kernel_wrapper_mutation_dense( kwargs = kwargs.copy() for k, v in tma_descriptor_metadata.items(): tensor = kwargs[k] - dims, block_dims, element_size = v - create_tma_descriptor = ( - create_1d_tma_descriptor if len(dims) == 1 else create_2d_tma_descriptor - ) - kwargs[k] = create_tma_descriptor( - tensor.data_ptr(), - *dims, - *block_dims, - element_size, - ) + if (exp_meta := maybe_unpack_tma_experimental_metadata(v)) is not None: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + dims, block_dims, element_size = exp_meta + create_tma_descriptor = ( + create_1d_tma_descriptor + if len(dims) == 1 + else create_2d_tma_descriptor + ) + kwargs[k] = create_tma_descriptor( + tensor.data_ptr(), + *dims, + *block_dims, + element_size, + ) + else: + stable_meta = maybe_unpack_tma_stable_metadata(v) + assert stable_meta is not None + from triton.tools.tensor_descriptor import TensorDescriptor + + block_shape = stable_meta[0] + kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape) + # move as many positional arguments from dicts to args as we # can to circumvent the bug with the kwargs and pre_/post_hook: # https://github.com/triton-lang/triton/issues/5082 @@ -922,7 +1122,8 @@ def trace_triton_kernel_wrapper( out = func_overload(**node_args) proxy_args = pytree.tree_map( - proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr] + proxy_mode.tracer.unwrap_proxy, # type: ignore[union-attr] + node_args, ) out_proxy = proxy_mode.tracer.create_proxy( "call_function", @@ -962,11 +1163,16 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( def get_mutated_tensors( - kernel_idx: int, constant_args_idx: int, kwargs: dict[str, Any] + kernel_idx: int, + constant_args_idx: int, + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, ) -> list[str]: kernel = kernel_side_table.get_kernel(kernel_idx) constant_args = kernel_side_table.get_constant_args(constant_args_idx) - return identify_mutated_tensors(kernel, {**kwargs, **constant_args}) + return identify_mutated_tensors( + kernel, {**kwargs, **constant_args}, tma_descriptor_metadata + ) @triton_kernel_wrapper_mutation.py_functionalize_impl @@ -984,7 +1190,7 @@ def triton_kernel_wrapper_mutation_functionalize( # they are no longer equal. Fix this by graph breaking on this condition # earlier in dynamo. tensors_to_clone = get_mutated_tensors( - kernel_idx, constant_args_idx, unwrapped_kwargs + kernel_idx, constant_args_idx, unwrapped_kwargs, tma_descriptor_metadata ) with ctx.redispatch_to_next(): unwrapped_outputs = triton_kernel_wrapper_functional( @@ -1154,7 +1360,7 @@ class TritonHOPifier: to the HOP (which can then be traced). Because Dynamo has its own calling conventions for e.g. invoking a user-defined function - TritonHOPifier is an abstract class that can be overriden by its subclasses. + TritonHOPifier is an abstract class that can be overridden by its subclasses. """ def raise_unsupported(self, msg: str) -> Never: @@ -1246,7 +1452,7 @@ def do_prune_configs( # type: ignore[no-untyped-def] ] configs = [ config[0] - for config in sorted(est_timing, key=lambda x: x[1])[:top_k] + for config in sorted(est_timing, key=operator.itemgetter(1))[:top_k] ] return configs @@ -1455,9 +1661,9 @@ def call_triton_kernel( # Update the kwargs in each config # maybe_unpack_heuristic_result raises unsupported if the value is non-constant - new_configs[config_idx].__dict__["kwargs"][ - kwarg_key - ] = self.maybe_unpack_heuristic_result(heuristic_result) + new_configs[config_idx].__dict__["kwargs"][kwarg_key] = ( + self.maybe_unpack_heuristic_result(heuristic_result) + ) iter_kernel = iter_kernel.fn assert isinstance(iter_kernel, JITFunction) @@ -1537,9 +1743,9 @@ def call_triton_kernel( for config in new_configs: for name in special_param_names: if name not in config.__dict__["kwargs"]: - assert ( - name in config.__dict__ - ), f"{name} must be in autotuning configs to be used as a kernel parameter" + assert name in config.__dict__, ( + f"{name} must be in autotuning configs to be used as a kernel parameter" + ) config.__dict__["kwargs"][name] = config.__dict__[name] updated = True diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 2525ef6af7340a..9b5293fbe64bd1 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,15 +1,22 @@ # mypy: allow-untyped-defs +import contextlib import functools from contextlib import contextmanager, ExitStack, nullcontext from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, overload, TypeVar, Union import torch import torch.fx.traceback as fx_traceback import torch.utils._pytree as pytree +from torch._dispatch.python import suspend_functionalization from torch._guards import detect_fake_mode +from torch._higher_order_ops.schema import HopSchema from torch._ops import HigherOrderOperator, OperatorBase, OpOverload from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import ( + disable_functional_mode, + FunctionalTensor, +) from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, disable_proxy_modes_tracing, @@ -108,9 +115,9 @@ def reenter_make_fx(fn): @functools.wraps(fn) def wrapped(*args): - assert ( - _CURRENT_MAKE_FX_TRACER is not None - ), "Cannot reenter make_fx when we're not under a make_fx tracing session" + assert _CURRENT_MAKE_FX_TRACER is not None, ( + "Cannot reenter make_fx when we're not under a make_fx tracing session" + ) return _CURRENT_MAKE_FX_TRACER.trace_subgraph( _maybe_run_with_interpreter(fn), *args ) @@ -316,20 +323,22 @@ def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations): def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False): ( - _, - _, - _, - ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) + (_, _, _), + inp_mutation, + ) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) return len(inp_mutation) > 0 def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): ( - inp_inp_alias_map, - inp_out_alias_map, - out_out_alias_map, - ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) + ( + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + ), + inp_mutation, + ) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) return ( any( ( @@ -385,9 +394,7 @@ def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch): graph_module, inputs_fake, pre_dispatch=pre_dispatch ) if aliases: - raise RuntimeError( - f"{name} might be aliasing the input or the output!" - ) # noqa: F541 + raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541 if inp_mutation: raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541 @@ -416,7 +423,6 @@ def unique_graph_name_with_root( def _from_fun(t): from torch._functorch.aot_autograd import from_fun - from torch._subclasses.functional_tensor import FunctionalTensor if isinstance(t, torch.Tensor): if t.dtype != torch.bool: @@ -498,9 +504,9 @@ def fw_with_masks(*args): # replaced with an all-zero tensor for better optimization def unmask_none_gradients(grads, operands): allowed_types = (torch.Tensor, int, torch.SymInt) - assert all( - isinstance(o, allowed_types) for o in operands - ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}" + assert all(isinstance(o, allowed_types) for o in operands), ( + f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}" + ) unmasked_grads = [] for g, o in zip(grads, operands): @@ -528,6 +534,24 @@ def _maybe_fake_prop_ignore_unbacked(fn, args): return fn(*args) +def redirect_to_mode(hop: OperatorBase, mode): + """Utility for redispatching HOP to underlying mode + + Args: + hop: The HOP to redispatch + mode: The mode to redispatch to + + Returns: + A decorated function that implements the HOP for the given mode + """ + + @hop.py_impl(mode) + def impl(mode, *args, **kwargs): + return mode.__torch_dispatch__(hop, [], args, kwargs) + + return impl + + # TODO: The parameter use_output_and_grad_bw is required because some operations # that utilize this function, such as the while_loop, may require (grad, fwd_outputs) def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs): @@ -738,7 +762,9 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]) allowed_types = (torch.Tensor, int, torch.SymInt) assert all( isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args - ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}" + ), ( + f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}" + ) # TODO: Return a more detailed information as to which node @@ -752,11 +778,11 @@ def check_input_alias_and_mutation( inp_out_alias_map, out_out_alias_map, mutated_inputs, - ) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1] + ) = check_input_alias_and_mutation_return_outputs(gm, fake_args)[:-1] return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs -def check_input_alias_and_mutation_return_ouputs( +def check_input_alias_and_mutation_return_outputs( gm: torch.fx.GraphModule, fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]], ) -> tuple[ @@ -766,6 +792,29 @@ def check_input_alias_and_mutation_return_ouputs( list[int], Union[tuple[Any, ...], list[Any]], ]: + # This function can be called under autograd, functional, proxy and fake tensor mode. + # We need to return either a fake tensor or a real tensor depending on the mode. + # to detect the input mutation/aliasing. + with ( + disable_proxy_modes_tracing(), + disable_functional_mode(), + suspend_functionalization(), + ): + + def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t): + return torch.empty_strided( + t.size(), + t.stride(), + dtype=t.dtype, + requires_grad=t.requires_grad, + device=t.device, + ) + return t + + fake_args = pytree.tree_map_only( + torch.Tensor, _from_functional_tensor, fake_args + ) # We want to disable active functional, proxy and fake modes if any. # to create a encapsulated environment for fake tensor prop with torch.utils._python_dispatch._disable_current_modes(): @@ -777,7 +826,8 @@ def check_input_alias_and_mutation_return_ouputs( def _tensor_version(t) -> Optional[int]: if isinstance(t, torch.Tensor): - assert isinstance(t, FakeTensor), "Only fake tensor is allowed" + if not isinstance(t, FakeTensor): + raise RuntimeError("Only fake tensor is allowed") return t._version return None @@ -795,13 +845,10 @@ def _get_shape_env( if len(fake_args) == 0: return torch.fx.experimental.symbolic_shapes.ShapeEnv() - prev_fake_mode = None for arg in fake_args: - if isinstance(arg, torch.Tensor): - assert isinstance(arg, FakeTensor) - prev_fake_mode = arg.fake_mode - assert prev_fake_mode is not None - return prev_fake_mode.shape_env + if isinstance(arg, FakeTensor): + return arg.fake_mode.shape_env + return None # Clone the fake args to avoid mutating the original fake args with ExitStack() as ctx_stack: @@ -883,6 +930,17 @@ def _get_shape_env( registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {} +F = TypeVar("F", bound=Callable) + + +@overload +def register_fake(hop, fn: None = None) -> Callable[[F], F]: ... + + +@overload +def register_fake(hop, fn: F) -> F: ... + + def register_fake(hop, fn=None): """ Register a fake function for a HOP. This is conceptually equivalent of the @@ -891,13 +949,10 @@ def register_fake(hop, fn=None): """ assert hop not in registered_hop_fake_fns - def register(func): + def register(func: F) -> F: from torch._subclasses.fake_tensor import FakeTensorMode - # Redirect the hop to the fake tensor mode implementation. - @hop.py_impl(FakeTensorMode) - def _(mode, *args, **kwargs): - return mode.__torch_dispatch__(hop, [], args, kwargs) + redirect_to_mode(hop, FakeTensorMode) registered_hop_fake_fns[hop] = func return func @@ -944,7 +999,7 @@ def __call__(self, *args, **kwargs): # A wrapper over HigherOrderOperator that also carries its schema class HopInstance: - def __init__(self, op: HigherOrderOperator, schema: torch.FunctionSchema): + def __init__(self, op: HigherOrderOperator, schema: HopSchema): assert isinstance(op, HigherOrderOperator), op self._op = op # Using "_" to be consistent with how we access _schema of OpOverload @@ -953,7 +1008,14 @@ def __init__(self, op: HigherOrderOperator, schema: torch.FunctionSchema): def __call__(self, *args, **kwargs): return self._op(*args, **kwargs) + @staticmethod + def create(hop: HigherOrderOperator, *args, **kwargs): + return HopInstance(hop, hop.gen_schema(*args, **kwargs)) + +# This call_op can be used to call a HopInstance with +# flat args and kwargs. We need to make use of the hop's schema's tree_spec +# to unflatten the args and kwargs before calling the hop. def call_op(op: Union[OpOverload, HopInstance], args, kwargs): if isinstance(op, OpOverload): return op(*args, **kwargs) @@ -969,7 +1031,44 @@ def call_op(op: Union[OpOverload, HopInstance], args, kwargs): bound_args.append(val) else: bound_kwargs[arg.name] = val - return op(*bound_args, **bound_kwargs) + + if schema.tree_spec is not None: + assert len(bound_args) == len(schema.arguments) and len(bound_kwargs) == 0 + args, kwargs = pytree.tree_unflatten(bound_args, schema.tree_spec) + return op(*args, **kwargs) + else: + assert len(bound_args) + len(bound_kwargs) == len(schema.arguments) + return op(*bound_args, **bound_kwargs) + + +def materialize_as_graph( + fn: Callable, + args: tuple[Any], + include_key_set: Optional[torch._C.DispatchKeySet] = None, + exclude_key_set: Optional[torch._C.DispatchKeySet] = None, + force_enable_grad=False, +) -> torch.fx.GraphModule: + if include_key_set is None: + include_key_set = torch._C._dispatch_tls_local_include_set() + if exclude_key_set is None: + exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + @torch._dynamo.disable(recursive=True, reason=None) + def _materialize_as_graph_inner(): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + unfunc_t = [_from_fun(arg) for arg in args] + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), + ) + if force_enable_grad: + stack.enter_context(torch.enable_grad()) + return _maybe_reenter_make_fx(fn)(*unfunc_t) + + gm = _materialize_as_graph_inner() + assert gm is not None + return gm def materialize_callable_in_args(op: HopInstance, args, kwargs): @@ -1003,3 +1102,37 @@ def wrapped_fn(*flat_args): materialized_args.append(flat_args[i]) return pytree.tree_unflatten(materialized_args, flat_spec) + + +def has_user_subclass(args, allowed_subclasses): + """Check if any tensor arguments are user subclasses. + + This is used to determine if tensor subclasses should get a chance to run + their own implementation first before falling back to the default implementation. + + Args: + args: Arguments to check (will be flattened with pytree) + allowed_subclasses: Tuple of allowed subclass types + + Returns: + True if user tensor subclasses are found, False otherwise + """ + flat_args, _ = pytree.tree_flatten(args) + + val = any( + isinstance(a, torch.Tensor) + and type(a) is not torch.Tensor + and not isinstance(a, allowed_subclasses) + for a in flat_args + ) + return val + + +def _has_gen_schema(op: HigherOrderOperator): + # There is an InvokeQuant argument we cannot gen_schema. + if op is torch.ops.higher_order.invoke_quant_packed: + return False + method = "gen_schema" + return hasattr(type(op), method) and getattr(type(op), method) is not getattr( + HigherOrderOperator, method + ) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 440450d5b603bf..e0e57dfad3f3a4 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -202,12 +202,12 @@ def _validate_cond_output(pred): while pred := cond_fn(*carried_vals, *additional_inputs): _validate_cond_output(pred) out = body_fn(*carried_vals, *additional_inputs) - assert isinstance( - out, tuple - ), f"body_fn should return a tuple but got {type(out)}" - assert len(out) == len( - carried_inputs - ), "body_fn should return the same number of elements as carried_inputs" + assert isinstance(out, tuple), ( + f"body_fn should return a tuple but got {type(out)}" + ) + assert len(out) == len(carried_inputs), ( + "body_fn should return the same number of elements as carried_inputs" + ) carried_vals = out return carried_vals @@ -230,9 +230,9 @@ def _find_or_create_fake_mode() -> FakeTensorMode: def _create_unbacked_symint( fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool ) -> torch.SymInt: - assert ( - fake_mode is not None and fake_mode.shape_env is not None - ), "Must provide a fake_mode with shape_env." + assert fake_mode is not None and fake_mode.shape_env is not None, ( + "Must provide a fake_mode with shape_env." + ) ctx = ( contextlib.nullcontext() if not ignore_fresh_unbacked_symbols diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 649ae1168ce5f4..94762a68b3435c 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from torch._inductor.utils import InputType from torch.export import ExportedProgram + from torch.export.pt2_archive._package_weights import Weights from torch.types import FileLike __all__ = [ @@ -197,13 +198,13 @@ def _aoti_compile_and_package_inner( path = [ os.path.splitext(file)[0] for file in aoti_files - if os.path.splitext(file)[1] == ".so" + if isinstance(file, str) and os.path.splitext(file)[1] == ".so" ] if len(path) == 0: path = [ os.path.splitext(file)[0] for file in aoti_files - if os.path.splitext(file)[1] == ".cpp" + if isinstance(file, str) and os.path.splitext(file)[1] == ".cpp" ] package_path = path[0] + ".pt2" @@ -274,7 +275,7 @@ def aot_compile( kwargs: Optional[dict[str, Any]] = None, *, options: Optional[dict[str, Any]] = None, -) -> Union[str, list[str]]: +) -> Union[str, list[Union[str, Weights]]]: """ Ahead-of-time compile a given FX graph with TorchInductor into a shared library. diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 5319481adcde90..0be33474bcf1a4 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -8,7 +8,7 @@ import multiprocessing import os import sys -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool from functools import partial from time import time, time_ns @@ -37,12 +37,15 @@ torch_key, ) from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) from torch._inductor.compile_worker.utils import _async_compile_initializer from torch._inductor.runtime.compile_tasks import ( _set_triton_ptxas_path, _worker_compile_triton, ) -from torch._inductor.utils import clear_on_fresh_inductor_cache +from torch._inductor.utils import clear_on_fresh_cache from torch._inductor.virtualized import V from torch.hub import _Faketqdm, tqdm from torch.utils._ordered_set import OrderedSet @@ -74,13 +77,10 @@ def pre_fork_setup(): # Computing the triton key can be slow. If we call it before fork, # it will be cached for the forked subprocesses. - try: - from triton.compiler.compiler import triton_key + from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + if HAS_TRITON: triton_key() - except ImportError: - # Triton might not be installed or might be an old version. - pass def caching_device_properties(): @@ -162,7 +162,7 @@ def get_compile_threads() -> int: return config.compile_threads -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CompiledTritonKernels: """ In memory cache for storing compiled triton kernels. @@ -251,7 +251,7 @@ def process_pool() -> AnyPool: os.environ["TORCH_WARM_POOL"] = "0" pre_fork_setup() ctx = multiprocessing.get_context(config.worker_start_method) - pool = ProcessPoolExecutor( + pool = TrackedProcessPoolExecutor( get_compile_threads(), mp_context=ctx, initializer=partial(_async_compile_initializer, os.getpid()), @@ -346,6 +346,7 @@ def reload_kernel_in_parent(): else: return future.result() + # Cache miss if is_parallel: # We want to support changing these env vars after (and while) the # process pool is running, so pass them to the subprocess to reset. @@ -441,7 +442,8 @@ def task(): if aot_compile: # We rely on JITInductor to compile the CUDA code, # so that we can load it into AOTInductor. - CUDACodeCache.compile(source_code, "o") + output_path, *_ = CUDACodeCache.compile(source_code, "o") + CUDACodeCache.aot_kernels_o.append(output_path) return CUDACodeCache.load(source_code, dst_file_ext)[0] return self.submit(task) @@ -456,7 +458,8 @@ def rocm( def task(): if aot_compile: - _ = ROCmCodeCache.compile(source_code, dst_file_ext="o") + output_path, *_ = ROCmCodeCache.compile(source_code, dst_file_ext="o") + ROCmCodeCache.aot_kernels_o.append(output_path) if config.rocm.generate_test_runner: _ = ROCmCodeCache.compile(source_code, dst_file_ext="exe") return ROCmCodeCache.load(source_code, dst_file_ext)[0] diff --git a/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/torch/_inductor/autoheuristic/learnedheuristic_interface.py index cb1519c8dd89de..cb2568d8a68019 100644 --- a/torch/_inductor/autoheuristic/learnedheuristic_interface.py +++ b/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -1,3 +1,4 @@ +import operator from typing import Optional from torch._inductor.autoheuristic.autoheuristic_utils import ( @@ -51,7 +52,9 @@ def get_decision( for choice in choices: predicted_feedback = self.get_feedback(context, choice) choice2feedback[choice] = predicted_feedback - sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1]) + sorted_choices_feedback = sorted( + choice2feedback.items(), key=operator.itemgetter(1) + ) highest_feedback = sorted_choices_feedback[-1][1] second_highest_feedback = sorted_choices_feedback[-2][1] if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index d605252def3e35..c936fbe92c6713 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -585,7 +585,7 @@ def __init__( num_buffers_warp_spec: int = 0, matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit - kpack: int = 0, # ROCm specific gemm paramete + kpack: int = 0, # ROCm specific gemm parameter ) -> None: super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) self.module_path = module_path @@ -874,7 +874,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}" -@functools.lru_cache(None) +@functools.cache def get_tuning_process_pool() -> TuningProcessPool: pool = TuningProcessPool() atexit.register(pool.shutdown) diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 90693353c87aa5..b7bab02da5e4b4 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -127,6 +127,25 @@ def get_conv_configs( conv_heuristics = self.get_config_heuristics(device_type) return conv_heuristics.get_conv_configs() + # Flex attention configs + def get_flex_attention_fwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype) + + def get_flex_attention_bwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype) + + def get_flex_decode_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_decode_configs(head_dim, dtype) + def triton_kernel_kwargs( self, kernel_cls: type[TritonKernel], @@ -176,7 +195,9 @@ def should_use_persistent_reduction( if cooperative_reduction: # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements try: - threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) + threshold *= 32 // min( + V.graph.sizevars.size_hint_or_throw(features.numel), 32 + ) except ValueError: pass # unbacked symint diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 130ecaed878bd9..f47e51b4673777 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -10,6 +10,7 @@ import io import itertools import json +import logging import os import pickle import pkgutil @@ -27,7 +28,7 @@ from copy import copy from ctypes import c_void_p, CDLL, cdll from datetime import timedelta -from functools import partial +from functools import lru_cache, partial from pathlib import Path from time import time, time_ns from types import ModuleType @@ -50,6 +51,10 @@ from torch._dynamo.exc import SkipFrame from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed from torch._inductor import config, exc, metrics +from torch._inductor.codegen.common import ( + custom_backend_passes, + init_backend_registration, +) from torch._inductor.codegen.cuda import cuda_env from torch._inductor.codegen.rocm.compile_command import ( rocm_compile_command, @@ -71,13 +76,17 @@ normalize_path_separator, ) from torch._inductor.cpu_vec_isa import pick_vec_isa -from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, + CustomGraphPassType, +) from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param from torch._inductor.runtime.compile_tasks import _reload_python_module from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir from torch._inductor.utils import ( ALIGN_BYTES, - clear_on_fresh_inductor_cache, + clear_on_fresh_cache, is_linux, is_windows, ) @@ -94,6 +103,7 @@ CacheArtifactFactory, CacheArtifactManager, ) +from torch.export.pt2_archive._package_weights import TensorProperties, Weights from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv from torch.utils._ordered_set import OrderedSet @@ -109,26 +119,6 @@ if config.is_fbcode(): from triton.fb.build import build_paths - from torch._inductor.fb.utils import ( - log_global_cache_errors, - log_global_cache_stats, - log_global_cache_vals, - use_global_cache, - ) -else: - - def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: - pass - - def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: - pass - - def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: - pass - - def use_global_cache() -> bool: - return False - T = TypeVar("T") @@ -151,11 +141,14 @@ def use_global_cache() -> bool: LOCK_TIMEOUT = 600 output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") -log = torch._logging.getArtifactLogger(__name__, "codecache") +log = logging.getLogger(__name__) def use_re_build() -> bool: - if config.is_fbcode(): + """ + Use for CUTLASS compilation only right now. + """ + if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()): from triton.fb.re_build_helper import should_build_locally return not should_build_locally() @@ -175,26 +168,17 @@ def get_kernel_bin_format(device: str) -> str: return "" -@functools.lru_cache(None) -def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: - return ( - Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"])) - if global_cache_dir is not None - else None - ) - - class CacheBase: @staticmethod - @functools.lru_cache(None) + @functools.cache def get_system() -> dict[str, Any]: - try: - from triton.compiler.compiler import triton_key + from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + if HAS_TRITON: # Use triton_key instead of triton.__version__ as the version # is not updated with each code change triton_version = triton_key() - except ModuleNotFoundError: + else: triton_version = None try: @@ -224,15 +208,11 @@ def get_system() -> dict[str, Any]: return system @staticmethod - @clear_on_fresh_inductor_cache - @functools.lru_cache(None) + @clear_on_fresh_cache + @functools.cache def get_local_cache_path() -> Path: return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) - @staticmethod - def get_global_cache_path() -> Optional[Path]: - return get_global_cache_path_impl(config.global_cache_dir) - def __init__(self) -> None: self.system = CacheBase.get_system() @@ -279,15 +259,6 @@ def set_value(self, *keys: str, value: Any) -> None: class PersistentCache(CacheBase): - @functools.lru_cache(None) # noqa: B019 - def get_global_cache(self) -> dict[str, Any]: - global_cache_path = self.get_global_cache_path() - if global_cache_path is None or not global_cache_path.is_file(): - return {} - with open(global_cache_path) as global_cache_fp: - global_cache = json.load(global_cache_fp) - return global_cache["cache"] - def lookup( self, choices: list[ChoiceCaller], @@ -299,23 +270,17 @@ def lookup( Check to see if we have benchmarked the given choice callers. For each choice caller: - 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. - 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. - 3. If benchmark is not None: + 1. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 2. If benchmark is not None: a. `max_autotune_gemm=True`: benchmark the choice, update local_cache[op][inputs][choice], and return the benchmark. b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. """ precision = torch.get_float32_matmul_precision() - log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) - log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) - log_errors = partial( - log_global_cache_errors, self.system, op, inputs, precision - ) timings = {} - def check_cache(cache: dict[str, Any], callback: Any = None) -> bool: + def check_cache(cache: dict[str, Any]) -> bool: """Check if `cache` contains data for all the choices""" hit = True for choice in choices: @@ -327,44 +292,19 @@ def check_cache(cache: dict[str, Any], callback: Any = None) -> bool: # cache miss hit = False break - if callback: - callback(cached=hit) return hit - if config.max_autotune or config.max_autotune_gemm: - local_cache = self.get_local_cache() if config.autotune_local_cache else {} - # check local cache first since it is data specific to the current machine - if ( - not check_cache(local_cache) - and not ( - use_global_cache() - and check_cache(self.get_global_cache(), callback=log_stats) - ) - and benchmark is not None - ): - try: - # re-benchmark everything to try to get consistent numbers from the same machine - timings = benchmark(choices) - assert all(choice in timings for choice in choices) - local_cache.setdefault(op, {}) - local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) - for choice, timing in timings.items(): - local_cache[op][inputs][precision][choice.hash_key()] = timing - except RuntimeError as e: - # catch and log autotuning failures - log_errors(e) - raise e - - self.update_local_cache(local_cache) + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + if (not check_cache(local_cache)) and (benchmark is not None): + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing - timings_to_log = { - choice.hash_key(): timings[choice] for choice in choices - } - log_vals(timings_to_log) - elif use_global_cache(): - # only check global cache, not local one - check_cache(self.get_global_cache(), callback=log_stats) - # may have a partial cache hit, where not everything is benchmarked + self.update_local_cache(local_cache) return timings @@ -406,7 +346,7 @@ def get_path( def get_hash( content: Union[str, bytes], extra: str = "", hash_type: str = "code" ) -> str: - if hash_type in {"amdgcn", "code", "ptx"}: + if hash_type in {"amdgcn", "code", "ptx", "spv"}: return code_hash(content, extra) if hash_type in {"cubin", "hsaco", "spv"}: return code_hash(repr(content)) @@ -608,7 +548,7 @@ def _reduce_graph_module( defined triton kernels Essentially what we are doing here is a huge hack where user defined triton kernel contain a dynamo time side table and the arguments to the - call_function are indicies into this side table. These arguments are not + call_function are indices into this side table. These arguments are not for hashing purposes since we included the source code into the cache key and the numbers are prone to give false negatives due to ordering. """ @@ -886,6 +826,8 @@ def __init__( self.post_grad_custom_pre_pass = self._get_custom_pass_detail( config.post_grad_custom_pre_pass ) + # TODO: change to more holistic config rather than bundled_autograd_cache + self.precompile_enabled = torch._functorch.config.bundled_autograd_cache self.post_grad_custom_post_pass = self._get_custom_pass_detail( config.post_grad_custom_post_pass ) @@ -896,6 +838,12 @@ def __init__( config._fuse_ddp_communication_passes ) + # Register indcutor backends and custom passes and get their UUIDs. + init_backend_registration() + self.custom_backend_passes = tuple( + map(self._get_custom_pass_detail, custom_backend_passes.values()) + ) + # This is mainly added to handle these two inductor configs, which are (unfortunately) # sometimes cache safe: # - _pre_fusion_custom_pass @@ -921,11 +869,11 @@ def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]: raise AssertionError(f"unknown config type: {str(type(custom_pass))}") def _get_custom_pass_detail( - self, custom_pass: CustomGraphPassType + self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass] ) -> Optional[Any]: if not custom_pass: return None - assert isinstance(custom_pass, CustomGraphPass) + assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass)) return custom_pass.uuid() @@ -1135,7 +1083,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]): current context to validate that a cached entry can be served. - A given graph could have multiple compiled versions, corresponding to different sets of guards. Therefore, we store cache entries in the form: - // + // - On lookup, we compute the key from the graph details, iterate over all leaf files in the corresponding subdirectory, deserialize the entry, and evaluate its guards expression. If the evaluation succeeds, we have a @@ -1566,7 +1514,7 @@ def clear() -> None: pass -@functools.lru_cache(None) +@functools.cache def split_aot_inductor_output_path(path: str) -> tuple[str, str]: """Returns the path where the AOT Inductor compiled kernels are stored.""" if path.endswith(".so"): @@ -1577,7 +1525,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]: return path, "" -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CudaKernelParamCache: cache: dict[str, dict[str, Any]] = {} cache_clear = staticmethod(cache.clear) @@ -1613,9 +1561,12 @@ def set( basename, _ = get_name_and_dir_from_output_file_path(bin_path) if config.aot_inductor.emit_multi_arch_kernel: - assert bin_type == "cubin", "emit_multi_arch_kernel only supported in CUDA" + bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} + assert bin_type in bin_type_to_ext.keys(), ( + "multi_arch_kernel_binary only supported in CUDA/XPU" + ) base_path, _ = os.path.splitext(bin_path) - bin_path = base_path + ".fatbin" + bin_path = base_path + bin_type_to_ext[bin_type] asm_path: str = "" if ( @@ -1649,6 +1600,10 @@ def get_keys(cls) -> KeysView[str]: class AotCodeCompiler: + """ + Compile AOT Inductor generated code. + """ + @classmethod def compile( cls, @@ -1659,12 +1614,12 @@ def compile( *, device_type: str, additional_files: list[str], - ) -> Union[list[str], str]: + ) -> Union[list[Union[str, Weights]], str]: """ Returns the .so path, or returns a list of files that were generated if config.aot_inductor.package=True. """ - generated_files = additional_files + generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment] if sys.platform == "win32": raise RuntimeError("AotCodeCompiler not yet supported for inductor") @@ -1710,6 +1665,7 @@ def compile( "wrapper.cpp", extra=cpp_command, specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, ) kernel_code = ( f"// Triton kernels are embedded as comments in {wrapper_path}\n" @@ -1720,6 +1676,7 @@ def compile( "kernel.cpp", extra=cpp_command, specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, ) # Log the AOTInductor wrapper and kernel code, if needed. @@ -1808,8 +1765,8 @@ def _compile_consts(consts: bytes, platform: str) -> str: ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( - # Intel compiler failed to compile this manully constructed assembly file. - # it is ok to use gcc to compile the .S to a .o and linked with Intel comiler . + # Intel compiler failed to compile this manually constructed assembly file. + # it is ok to use gcc to compile the .S to a .o and linked with Intel compiler . device_type=device_type if device_type != "xpu" else "cpu", aot_mode=graph.aot_mode, compile_only=True, @@ -1938,6 +1895,20 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: else: serialized_weights = b"" + if config.aot_inductor.package_constants_on_disk: + # We need to return a storage key here because the original value tensor might be a clone + weights_dict = Weights( + { + graph.allocated_constant_name[name]: ( + graph.get_original_value_of_constant(name), + TensorProperties(graph.constants[name]), + ) + for name in graph.constants.keys() + if name not in graph.folded_constants + } + ) + generated_files.append(weights_dict) + consts_size = len(serialized_weights) # TODO: Fix mmap weights with cuda @@ -2073,11 +2044,10 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = ( ROCmCodeCache() if torch.version.hip else CUDACodeCache() ) - gpu_kernels_o = [ - entry.output_path - for entry in gpu_codecache.cache.values() - if entry.output_path.endswith(".o") - ] + gpu_kernels_o = gpu_codecache.aot_kernels_o.copy() + # clear the list of aot kernels after each linking + gpu_codecache.aot_kernels_o.clear() + if gpu_kernels_o: assert not config.aot_inductor.emit_multi_arch_kernel, ( "TODO: add emit_multi_arch_kernel support for cutlass kernels" @@ -2091,21 +2061,28 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: asm_files.append(asm_file) cubin_file = value[get_cpp_wrapper_cubin_path_name()] - if config.aot_inductor.emit_multi_arch_kernel: + if config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda": current_arch = _nvcc_arch_as_compile_option() cmd = ( f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " - # Include PTX with the minimum arch as SM80 - "-gencode arch=compute_80,code=compute_80 " - ) - if config.aot_inductor.emit_current_arch_binary: - # Include SASS for the current specific arch, to avoid - # CUDA JIT compilation overhead. In theory, we could do - # this for all archs that are newer than the current arch. - cmd += f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " - subprocess.run( - cmd.split(), capture_output=True, text=True, check=True + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " ) + try: + subprocess.run( + cmd.split(), + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise if config.aot_inductor.embed_kernel_binary: # Embed cubin files into model.so using objcopy @@ -2168,7 +2145,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: generated_files.append(weight_file) else: - # TODO: unify to alway use mmap_weights + # TODO: unify to always use mmap_weights generated_files.append(consts_o) so_builder.save_src_to_cmake(cmake_path, consts_o) @@ -2268,12 +2245,12 @@ def convert_arg(arg: Any) -> Any: # Precompiled headers are persistent past program runtime, but associated with one # specific compiler version and set of flags. We explicitly use default_cache_dir here -# because these headers need to be global, rather than ignored by fresh_inductor_cache. +# because these headers need to be global, rather than ignored by fresh_cache. _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers") _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks") -@functools.lru_cache(None) +@functools.cache def _precompile_header( header: str, hashable_cmd_line: str, @@ -2355,7 +2332,7 @@ def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: ) -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CppCodeCache: """Compiles and caches C++ libraries. Users of this class supply the source code to be compiled, while compilation flags are set by CppBuilder.""" @@ -2564,7 +2541,7 @@ def _worker_compile_cpp( # Customized Python binding for cpp kernels -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CppPythonBindingsCodeCache(CppCodeCache): cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache_clear = staticmethod(cache.clear) @@ -2745,7 +2722,7 @@ def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: return cls.load_pybinding_async(*args, **kwargs)() -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CppWrapperCodeCache(CppPythonBindingsCodeCache): cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache_clear = staticmethod(cache.clear) @@ -2814,7 +2791,7 @@ def _get_uncompiled_header(cls, device: str) -> str | None: return _get_cpp_wrapper_header(device) -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class HalideCodeCache(CppPythonBindingsCodeCache): cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} cache_clear = staticmethod(cache.clear) @@ -2915,7 +2892,9 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st return [ f"halide_buffer_t {name};", - f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};", + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};" + if len(dims) > 0 + else f"halide_dimension_t * {name}_dims = nullptr;", f"{name}.device = {device};", f"{name}.device_interface = {device_interface};", f"{name}.host = {host};", @@ -2959,7 +2938,7 @@ def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: return glue_code @classmethod - @functools.lru_cache(None) + @functools.cache def config_hash(cls) -> str: command_gen = CppBuilder( name="O", @@ -3003,7 +2982,7 @@ def _search_for_file(suffix: str, errmsg: str) -> str: raise RuntimeError(errmsg) @staticmethod - @functools.lru_cache(None) + @functools.cache def find_libautoschedule(name: str) -> str: sofile = f"libautoschedule_{name.lower()}.so" if "HALIDE_LIB" in os.environ: @@ -3016,7 +2995,7 @@ def find_libautoschedule(name: str) -> str: return HalideCodeCache._search_for_file(sofile, errmsg) @staticmethod - @functools.lru_cache(None) + @functools.cache def find_header(name: str) -> str: if "HALIDE_INCLUDE" in os.environ: path = os.path.join(os.environ["HALIDE_INCLUDE"], name) @@ -3117,40 +3096,40 @@ def build_standalone_runtime(cls) -> str: target = "host-cuda" if device_type == "cuda" else "host" if cls._standalone_runtime_path: assert not os.path.exists(cls._standalone_runtime_path) - # We hit this case in unittests when we run with fresh_inductor_cache() + # We hit this case in unittests when we run with fresh_cache() # Generating a fresh runtime over and over causes errors because we initialize # cuda hundreds of times in the same process and run out of file descriptors. - # Workaround by jail breaking the current fresh_inductor_cache(). + # Workaround by jail breaking the current fresh_cache(). base = default_cache_dir() else: base = cache_dir() dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" os.makedirs(dirpath, exist_ok=True) - donefile = str(dirpath / "done") - lockfile = str(dirpath / "lock") - hookfile = str(dirpath / "hooks.cpp") - afile = str(dirpath / "standalone_halide_runtime.a") - sofile = str(dirpath / libname) - if not os.path.exists(donefile): + done_file = str(dirpath / "done") + lock_file = str(dirpath / "lock") + hook_file = str(dirpath / "hooks.cpp") + a_file = str(dirpath / "standalone_halide_runtime.a") + so_file = str(dirpath / libname) + if not os.path.exists(done_file): import halide as hl # type: ignore[import-untyped,import-not-found] from torch.utils._filelock import FileLock - with FileLock(lockfile, LOCK_TIMEOUT): - if not os.path.exists(donefile): - with open(hookfile, "w") as f: + with FileLock(lock_file, LOCK_TIMEOUT): + if not os.path.exists(done_file): + with open(hook_file, "w") as f: if device_type == "cuda": f.write( cls.standalone_runtime_cuda_init.format( cls.find_header("HalideRuntimeCuda.h") ) ) - hl.compile_standalone_runtime(afile, hl.Target(target)) + hl.compile_standalone_runtime(a_file, hl.Target(target)) - name, output_dir = get_name_and_dir_from_output_file_path(sofile) + name, output_dir = get_name_and_dir_from_output_file_path(so_file) halide_cmd_gen = CppBuilder( name=name, - sources=[hookfile, afile], + sources=[hook_file, a_file], output_dir=output_dir, BuildOption=CppTorchDeviceOptions( device_type=device_type, @@ -3160,10 +3139,10 @@ def build_standalone_runtime(cls) -> str: subprocess.check_call( shlex.split(halide_cmd_gen.get_command_line()) ) - touch(donefile) - assert os.path.exists(sofile) - cls._standalone_runtime_path = sofile - return sofile + touch(done_file) + assert os.path.exists(so_file) + cls._standalone_runtime_path = so_file + return so_file @classmethod def _get_uncompiled_header(cls, device: str) -> str | None: @@ -3216,7 +3195,7 @@ def touch(filename: str) -> None: open(filename, "a").close() -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class PyCodeCache: # Track the loaded modules so we can remove the on-disk artifacts when # clearing the cache. Note also that we may load the same path more @@ -3290,12 +3269,14 @@ def cache_clear(cls, purge: bool = False) -> None: cls.modules_no_attr.clear() @classmethod - @functools.lru_cache(None) + @functools.cache def stack_frames_for_code( cls, path: str, lineno: int ) -> Optional[list[dict[str, Any]]]: if path not in cls.linemaps: return None + if len(cls.linemaps[path]) == 0: + return None # [(starting_line, ), ...] lines, nodes = cls.linemaps[path] p = bisect_right(lines, lineno) @@ -3381,7 +3362,9 @@ def cutlass_key() -> bytes: Note: OSS and fbcode will have different keys. """ if config.is_fbcode(): - with importlib.resources.path("cutlass", "src_hash.txt") as resource_path: + with importlib.resources.path( + "cutlass_library", "src_hash.txt" + ) as resource_path: with open(resource_path) as resource_file: return resource_file.read().encode() @@ -3600,7 +3583,15 @@ def __del__(self) -> None: self.close() -@clear_on_fresh_inductor_cache +@lru_cache +def binary_error_path(output_path: str) -> str: + """ + standard format for the error path + """ + return output_path + ".error" + + +@clear_on_fresh_cache class CUDACodeCache: """ A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS. @@ -3616,9 +3607,47 @@ class CacheEntry: error_json: Optional[str] = None cache: dict[str, CacheEntry] = {} - cache_clear = staticmethod(cache.clear) + aot_kernels_o: list[str] = [] _SOURCE_CODE_SUFFIX = "cu" + @staticmethod + def cache_clear() -> None: + CUDACodeCache.cache.clear() + CUDACodeCache.aot_kernels_o.clear() + + @staticmethod + @lru_cache(maxsize=4) + def get_kernel_binary_remote_cache( + caching_enabled: bool, caching_available: bool + ) -> Optional[Any]: + """ + Get or create the class instance of the CUTLASSKernelBinaryRemoteCache. + + Args: + caching_enabled: Whether binary remote caching is enabled + caching_available: Whether we're in fbcode environment + + Returns: + CUTLASSKernelBinaryRemoteCache: The class instance of the kernel binary remote cache + """ + if not caching_enabled: + log.debug("CUTLASSKernelBinaryRemoteCache not requested, skipping") + return None + if not caching_available: + return None + + try: + from torch._inductor.fb.kernel_binary_remote_cache import ( + CUTLASSKernelBinaryRemoteCache, + ) + + return CUTLASSKernelBinaryRemoteCache() + except ImportError: + log.debug( + "CUTLASSKernelBinaryRemoteCache not available, remote caching disabled" + ) + return None + @classmethod def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: """ @@ -3644,9 +3673,6 @@ def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: cutlass_key(), # hack to deal with AOTI .o compilation ] - + [dst_file_ext] - if dst_file_ext == "o" - else [] ) key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra) return key, input_path @@ -3657,33 +3683,65 @@ def compile( ) -> tuple[str, str, str]: """ Compiles CUDA source_code into a file with dst_file_ext extension. + If dst_file_ext is "so", first compiles to ".o" and then links to ".so". Returns a tuple of dst_file_path, hash_key, source_code_path """ - key, input_path = cls.write(source_code, dst_file_ext) - if key not in cls.cache: + if dst_file_ext == "so": + # Two-step compilation: first compile to .o, then link to .so + obj_path, _, _ = cls.compile(source_code, "o", extra_args) + key, input_path = cls.write(source_code, dst_file_ext) + src_files, operation_name = [obj_path], "Linking" + else: + # Regular compilation for non-.so files + key, input_path = cls.write(source_code, dst_file_ext) + src_files, operation_name = [input_path], "Compilation" + + key_with_ext = key + dst_file_ext + if key_with_ext not in cls.cache: from torch.utils._filelock import FileLock lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext - if os.path.exists(output_path + ".error"): - with open(output_path + ".error", encoding="utf-8") as fh: + error_path = binary_error_path(output_path) + binary_remote_cache = cls.get_kernel_binary_remote_cache( + caching_enabled=config.cuda.use_binary_remote_cache + and not config.force_disable_caches, + caching_available=config.is_fbcode(), + ) + if binary_remote_cache is not None: + # The remote cache implementation will only download if the file does + # not already exist locally + binary_remote_cache.get(output_path, error_path) + + if os.path.exists(error_path): + with open(error_path, encoding="utf-8") as fh: error_json = fh.read() cmd_parts, error_output = json.loads(error_json) - cls.cache[key] = CUDACodeCache.CacheEntry( + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # This ensures that a local error is uploaded to the remote cache, + # as we make no assumptions about the remote cache having the same + # information as the local cache + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, error_json ) raise exc.CUDACompileError(cmd_parts, error_output) if not os.path.exists(output_path): cmd = cuda_compile_command( - [input_path], output_path, dst_file_ext, extra_args + src_files, output_path, dst_file_ext, extra_args ) with open(input_path, "a") as f: f.write("\n") - f.write(f"// CUDA Compile cmd\n// {cmd}\n") + f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n") start_time = time() - log.debug("CUDA Compilation: %s", cmd) + log.debug("CUDA %s: %s", operation_name, cmd) cmd_parts = cmd.split(" ") try: if use_re_build(): @@ -3701,34 +3759,54 @@ def compile( except subprocess.CalledProcessError as error: cls._record_cuda_compile_error( error.output.decode("utf-8"), - key, + key_with_ext, cmd_parts, input_path, output_path, + binary_remote_cache, ) raise exc.CUDACompileError(cmd_parts, error.output) from error except Exception as error: if "COMPILE FAILED WITH" in str(error): cls._record_cuda_compile_error( - str(error), key, cmd_parts, input_path, output_path + str(error), + key_with_ext, + cmd_parts, + input_path, + output_path, + binary_remote_cache, ) raise exc.CUDACompileError(cmd_parts, str(error)) from error raise error end_time = time() - log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}" log.info(log_duration_msg) + else: log.debug( - "CUDA Compilation skipped: %s since output already exists", - input_path, + "CUDA %s skipped: %s since output already exists", + operation_name, + output_path, + ) + # Upload to remote cache if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # will log on errors, but not fail out + binary_remote_cache.put( + output_path, config.cuda.binary_remote_cache_force_write ) - cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, None) - cache_entry: CUDACodeCache.CacheEntry = cls.cache[key] + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( + input_path, output_path, None + ) + + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext] if cache_entry.error_json is not None: # Restore cached Exception and raise it as if we had compiled cmd_parts, error_output = json.loads(cache_entry.error_json) raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) - return (cls.cache[key].output_path, key, input_path) + return (cls.cache[key_with_ext].output_path, key, input_path) @classmethod def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: @@ -3751,18 +3829,33 @@ def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str def _record_cuda_compile_error( cls, error_str: str, - key: str, + key_with_ext: str, cmd_parts: list[str], input_path: str, output_path: str, + # Any here, as the import and type will only work in fbcode + # TODO: Make the typing hint strong here + binary_remote_cache: Any = None, ) -> None: error_json = json.dumps([cmd_parts, error_str]) - cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, error_json) - with open(output_path + ".error", "w", encoding="utf-8") as fh: + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + error_path = binary_error_path(output_path) + with open(error_path, "w", encoding="utf-8") as fh: fh.write(error_json) + # Upload to remote cache directly from memory if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class ROCmCodeCache: @dataclasses.dataclass class CacheEntry: @@ -3770,10 +3863,15 @@ class CacheEntry: output_path: str cache: dict[str, CacheEntry] = {} - cache_clear = staticmethod(cache.clear) + aot_kernels_o: list[str] = [] _SOURCE_CODE_SUFFIX = "cpp" _logged_compiler_version = False + @staticmethod + def cache_clear() -> None: + ROCmCodeCache.cache.clear() + ROCmCodeCache.aot_kernels_o.clear() + @classmethod def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: """ diff --git a/torch/_inductor/codegen/aoti_hipify_utils.py b/torch/_inductor/codegen/aoti_hipify_utils.py index b6ccaab56f8284..eb71d4ee7f3921 100644 --- a/torch/_inductor/codegen/aoti_hipify_utils.py +++ b/torch/_inductor/codegen/aoti_hipify_utils.py @@ -8,7 +8,7 @@ # "... # from ..codecache import CudaKernelParamCache # ..." -# In such cases, we do not need to hipify_torch the orignial class/file name in codegen/codecache +# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index b99f7f786cff29..b47c8325e21545 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,6 +17,13 @@ class BlockPatternMatcher: Matches block indexing expressions. """ + _indexing_wild_signed_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer] + ) + _indexing_wild_unsigned_int = functools.partial( + sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative] + ) + @classmethod def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ @@ -63,9 +70,18 @@ def match_mod_div_block_expr( index = cls._preprocess(index) # Pattern match to find the strides and offset. - wild = functools.partial(sympy.Wild, exclude=[index_var]) - dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)] - strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)] + wild_unsigned_int = functools.partial( + cls._indexing_wild_unsigned_int, exclude=[index_var] + ) + wild_signed_int = functools.partial( + cls._indexing_wild_signed_int, exclude=[index_var] + ) + dims: list[Expr] = [ + wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims) + ] + strides: list[Expr] = [ + wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims) + ] # The first dimension's index is computed by division. # The remaining are computed by modulo. @@ -83,7 +99,8 @@ def match_mod_div_block_expr( # for more details. In short, here we check that each subexpression in sympy.Add contains # only FloorDiv or ModularIndexing expressions. if num_dims >= 5: - stride, denom, other = sympy.symbols("stride denominator other", cls=wild) + stride = sympy.symbols("stride", cls=wild_signed_int) + denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int) mod_div_pattern = stride * ModularIndexing(index_var, denom, other) floor_div_pattern = stride * FloorDiv(index_var, denom) first_dim_floor_div_matched = False @@ -167,7 +184,7 @@ def match_affine_block_expr( stride. """ index = cls._preprocess(index) - stride = sympy.Wild("stride", exclude=[index_var]) + stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var]) m = index.match(index_var * stride) if m is None: return None diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 11aacf0d4379ac..828050d6da1402 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -65,6 +65,7 @@ from torch.fx import GraphModule + from ..custom_graph_pass import CustomGraphModulePass from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode from ..loop_body import LoopBody from ..scheduler import BaseScheduling, Scheduler, SchedulerNode @@ -255,6 +256,15 @@ def get_inputs_that_alias_output(self) -> list[str]: return [] +class TritonScratchWorkspace: + def __init__(self, size: int, generate_dtype_str: Callable[..., str]): + self.size = size + self._generate_dtype_str = generate_dtype_str + + def generate_dtype_str(self) -> str: + return self._generate_dtype_str() + + @dataclasses.dataclass class TensorArg: name: str @@ -282,6 +292,9 @@ class ConstexprArg: @dataclasses.dataclass class TMADescriptorArg: name: str + api_type: str # "experimental" or "stable" + block_shape: Optional[list[sympy.Expr]] # only needed for "stable" + dtype: Optional[torch.dtype] # only needed for "stable" @dataclasses.dataclass @@ -345,12 +358,15 @@ def cpp_device_ptr(self) -> str: def tma_descriptor_helpers(self) -> str: raise NotImplementedError - def cpp_global_scratch(self, idx: int) -> Optional[tuple[str, str]]: + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: # optionally return (scratch definition, arg name) raise NotImplementedError device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} +custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} # The code generated by Inductor consists of two main parts: kernel code and wrapper code. @@ -379,10 +395,12 @@ def register_backend_for_device( device_scheduling: SchedulingConstructor, device_wrapper_codegen: WrapperConstructor, device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, + device_custom_pass: Optional[CustomGraphModulePass] = None, ) -> None: device_codegens[device] = DeviceCodegen( device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen ) + custom_backend_passes[device] = device_custom_pass class BackendFeature(Enum): @@ -441,7 +459,11 @@ def get_wrapper_codegen_for_device( return None -@functools.lru_cache(None) +def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]: + return custom_backend_passes[device] if device in custom_backend_passes else None + + +@functools.cache def init_backend_registration() -> None: from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu @@ -1540,7 +1562,7 @@ def seed_offset(self, name: str, value: int) -> str: def size(self, name: sympy.Symbol) -> str: assert isinstance(name, sympy.Symbol), (type(name), name) if name.name == "seed": - self.sizevars[name] = "seed" # dont' mange the name of seeds + self.sizevars[name] = "seed" # don't manage the name of seeds return "seed" return self._lookup("ks", self.sizevars, name) @@ -1688,7 +1710,7 @@ def is_removed(self, name: str) -> bool: # after you do a call into this kernel, which buffers actually contain # updated data? Modeled off of python_argdefs. def live_output_buffers(self) -> OrderedSet[str]: - live_outs = OrderedSet[str]() + live_outs: OrderedSet[str] = OrderedSet() for inplaced in unique(self.inplace_buffers.values()): if isinstance(inplaced, RemovedArg): continue @@ -1873,7 +1895,7 @@ def generate( line = f"{expr}{self.suffix}" buffer.writeline(line) - # cpp backend cannot determin is_vec at this point + # cpp backend cannot determine is_vec at this point if ( assignment and ( @@ -1948,16 +1970,16 @@ def __init__( self.num_reduction = 0 self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix) - self.must_keep_buffers = OrderedSet[str]() - self.store_buffer_names = OrderedSet[str]() + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.store_buffer_names: OrderedSet[str] = OrderedSet() self._load_mask: Optional[str] = None self._load_other: Union[None, int, float] = None # OrderedSet in set_current_node self.current_node: Optional[SchedulerNode] = None self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None - self.removed_buffers = OrderedSet[str]() - self.inplaced_to_remove = OrderedSet[str]() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() # key: the buffer to write # value: the buffer to read and whose memory can be reused for @@ -2091,7 +2113,7 @@ def indirect_assert( assert upper is None or isinstance(upper, str) if lower and upper: # The conditions need to be in parens because of Python's operator precedence. - # It'd be less error-prone to use and/or/not, which is suported by triton + # It'd be less error-prone to use and/or/not, which is supported by triton cond = f"({lower} <= {var}) & ({var} < {upper})" cond_print = f"{lower} <= {var} < {upper}" elif lower: @@ -2144,7 +2166,7 @@ def remove_kernel_local_buffers(self) -> None: for buf in self.store_buffer_names if buf in scheduler.name_to_buf ) - names_to_remove = OrderedSet[str]() + names_to_remove: OrderedSet[str] = OrderedSet() for name in self.store_buffer_names: if ( name not in self.must_keep_buffers @@ -2223,7 +2245,7 @@ class OptimizationContext: ops_name: str = "" -@functools.lru_cache(None) +@functools.cache def jinja2_env() -> Any: try: import jinja2 @@ -2376,7 +2398,7 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> output_dtype = V.interpreter.current_node.meta.get( OptimizationContext.key, None ).dtype - elif backend in ("triton", "cpp"): + elif backend in ("triton", "cpp", "mps"): dtype_op = getattr(dtype_handler, name) output_dtype = dtype_op(*args, **kwargs) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ee6a728aabe63d..06467f06fc0289 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -86,7 +86,7 @@ _IS_WINDOWS = sys.platform == "win32" -@functools.lru_cache(None) +@functools.cache def get_export_declaration(): return "__declspec(dllexport)" if _IS_WINDOWS else "" @@ -327,7 +327,7 @@ def reduction_prefix_array( Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler MSVC is the only one compiler without VLA. support. Since MSVC can't get good performance here. We just use unique_ptr make it works on MSVC. - For other compilers, we continue to use VLA to get best performence. + For other compilers, we continue to use VLA to get best performance. """ code_buffer = IndentedBuffer() acc_decl = ( @@ -2844,7 +2844,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): # use welford_helper for vec kernel assert self.reduction_depth is not None reduction_size = functools.reduce( - lambda x, y: x * y, self.ranges[self.reduction_depth :] + operator.mul, self.ranges[self.reduction_depth :] ) welford_helper_val = self.welford_helper_cse.generate( self.compute, f"reduction {reduction_key}", write=False @@ -3735,7 +3735,7 @@ def _is_valid_indices( call_ranges[tiling_indice], fallback=0 ) if call_range < factor_lowp: - V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type] + V.graph.sizevars.check_lt(call_range, factor_lowp) # type: ignore[arg-type] tiling_factor = factor_lowp // 2 break elif call_ranges[tiling_indice] < factor_lowp: @@ -3798,6 +3798,16 @@ def _select_tiling_indices( class CppKernelProxy(CppKernel): + # Subclass CppKernel, CppVecKernel, etc., to customize code generation. + # Override CppOverrides or CppVecOverrides to emit custom ops. + # Earlier, this meant copying codegen_functions() to use your subclasses. + # Now, use kernel_cls and vec_kernel_cls class attributes instead. + # This lets CppKernelProxy subclasses inject custom behavior cleanly. + # No need to duplicate codegen_functions() just to swap kernel classes. + kernel_cls: type[CppKernel] = CppKernel + vec_kernel_cls: type[CppVecKernel] = CppVecKernel + tile2d_kernel_cls: type[CppTile2DKernel] = CppTile2DKernel + def __init__(self, kernel_group): super().__init__(kernel_group.args, kernel_group.ws.num_threads) self.kernel_group = kernel_group @@ -4113,7 +4123,7 @@ def run(kernel): with kernel.write_to_suffix(): fn(vars, ()) - scalar_kernel = codegen_kernel(CppKernel) + scalar_kernel = codegen_kernel(self.kernel_cls) V.graph.removed_buffers |= scalar_kernel.removed_buffers V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove self.loop_nest = LoopNest.build(scalar_kernel) @@ -4162,13 +4172,13 @@ def run(kernel): metrics.generated_cpp_vec_kernel_count += 1 loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0]) vec_kernel = codegen_kernel( - CppVecKernel, tiling_factors[0], tiling_indices[0] + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] ) tail_size = loop.size - loop.tiled_size vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)} if config.cpp.enable_loop_tail_vec and could_masked_vec: tail_kernel = codegen_kernel( - CppVecKernel, + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0], tail_size, @@ -4203,7 +4213,7 @@ def run(kernel): } inner_tail_size = inner_loop.size - inner_loop.tiled_size tile2d_kernel = codegen_kernel( - CppTile2DKernel, + self.tile2d_kernel_cls, tiling_factors[0], tiling_indices, ) @@ -4225,7 +4235,7 @@ def run(kernel): outer_tail_size if outer_r == "tail" else None ) kernel = codegen_kernel( - CppTile2DKernel, + self.tile2d_kernel_cls, tiling_factors[0], tiling_indices, _inner_tail_size, @@ -4238,7 +4248,7 @@ def run(kernel): tail_kernel.append(kernel) else: vec_kernel = codegen_kernel( - CppVecKernel, tiling_factors[0], tiling_indices[0] + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] ) vec_kernel.active_ranges = { outer_loop.var: outer_ranges["main"], @@ -4327,10 +4337,10 @@ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): assert len(self.kernels) >= 2 main_loop_kernel = self.kernels[0] tail_loop_kernel = self.kernels[-1] - assert isinstance(main_loop_kernel, CppVecKernel) + assert isinstance(main_loop_kernel, self.vec_kernel_cls) # Prefix - if type(tail_loop_kernel) == CppKernel: + if type(tail_loop_kernel) == self.kernel_cls: # if tail loop kernel is a scalar kernel, we need to extend tmp_acc -> tmp_acc_arr[] to # hold the temporary inner loop acc result for outer tail loop tail_loop_kernel.finalize_reduction_prefix( @@ -4358,7 +4368,7 @@ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): suffix_buf, "C10_UNLIKELY", outer_loop.var ): stack.enter_context(suffix_buf.indent()) - if type(tail_loop_kernel) == CppKernel: + if type(tail_loop_kernel) == self.kernel_cls: reduction_vars = tail_loop_kernel.reduction_var_names for name in reduction_vars: new_name = f"{name}_arr[{outer_loop.var}_tail - {cexpr_index(outer_loop.tiled_size)}]" @@ -4441,6 +4451,10 @@ class ReasonFusedNodes(Enum): class CppScheduling(BaseScheduling): + # Subclass CppKernelProxy to customize codegen without copying codegen_node(). + # Use kernel_proxy_cls to inject custom proxies in CppScheduling subclasses. + # Avoid duplicating codegen_node() just to swap in a custom kernel proxy class. + kernel_proxy_cls: type[CppKernelProxy] = CppKernelProxy # ctypes limits the number of args to 1024, refer to: # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 # We set a conservative threshold here. @@ -4650,7 +4664,7 @@ def can_fuse_multi_outputs_template( isinstance(template_buf.layout, ir.MultiOutputLayout) and isinstance(node2.node, ir.MultiOutput) and len(node2.node.inputs) == 1 - and node2.node.inputs[0].get_name() == template_buf.name + and node2.node.inputs[0].get_name() == template_buf.name # type: ignore[union-attr] ) return False @@ -4865,7 +4879,7 @@ def codegen_outer_loop_node( """ kernel_group = self.kernel_group generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count - cpp_kernel_proxy_list: list[CppKernelProxy] = [] + cpp_kernel_proxy_list: list[self.kernel_proxy_cls] = [] # type: ignore[name-defined] nodes_list: list[list[SchedulerNode]] = [] assert isinstance(node, OuterLoopFusedSchedulerNode) @@ -4897,7 +4911,7 @@ def get_call_ranges(node: BaseSchedulerNode): # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950 # where the buffer is with size of last dim and contiguous. # Only support this typical case at first. - visited_scheduler_nodes = OrderedSet[str]() + visited_scheduler_nodes: OrderedSet[str] = OrderedSet() for scheduler_node in node.get_nodes(): # all users inside same OuterLoopFusedSchedulerNode assert isinstance(scheduler_node, SchedulerNode) @@ -4986,7 +5000,7 @@ def try_share_local_buffer(local_buffer_layout, local_buffers): layout=local_buffer_layout, ) local_buffers.append(local_buffer_used) - local_to_global_buffers[local_buffer_used.name] = [] + local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index] local_to_global_buffers[local_buffer_used.name].append( global_buffer, ) @@ -5000,7 +5014,7 @@ def try_share_local_buffer(local_buffer_layout, local_buffers): ) for _node in node.get_outer_nodes(): assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) - cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] cpp_kernel_proxy_list.append(cpp_kernel_proxy) nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] @@ -5041,7 +5055,7 @@ def try_share_local_buffer(local_buffer_layout, local_buffers): for _node in node.get_outer_nodes(): assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) _nodes: list[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] - cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) cpp_kernel_proxy.codegen_nodes(_nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) @@ -5059,7 +5073,7 @@ def codegen_node( else: nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] nodes = self.try_loop_split(nodes) - cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) cpp_kernel_proxy.codegen_nodes(nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) @@ -5124,7 +5138,7 @@ def template_buffer_has_other_users( flag_template_buffer_has_other_users = template_buffer_has_other_users( ctb, template_node.outputs_by_name, epilogue_ir_nodes ) - kernel, render = ctb.make_kernel_render( + kernel, render = ctb.make_kernel_render( # type: ignore[misc] ctb, flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, epilogue_nodes=epilogue_ir_nodes, @@ -5489,6 +5503,13 @@ def get_simd_vec_depth(loops): simd_vec_depth = get_simd_vec_depth(self.loops) + def has_scalar_kernel(loop_nest: LoopNest): + assert isinstance(loop_nest.kernel, CppKernelProxy) + return any( + not isinstance(kernel, CppVecKernel) + for kernel in loop_nest.kernel.kernels + ) + # When the number of steps of the first inner loop is much larger than the number of steps of # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. if ( @@ -5502,6 +5523,7 @@ def get_simd_vec_depth(loops): simd_vec_depth is not None and max_depth > simd_vec_depth and self.loops[max_depth].is_reduction + and has_scalar_kernel(self) ) ): start_depth = max_depth diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py index 2542acc6108b83..64e11b00fcc04f 100644 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ b/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -311,7 +311,7 @@ } if (need_pack) { // When the number of gemm is greater than the number of pack, - // the pack overhead can be overlaped. + // the pack overhead can be overlapped. int64_t thresh_size = 64; need_pack = kvSize >= thresh_size && qSize >= thresh_size; if (need_pack) { @@ -977,7 +977,8 @@ def render( # type: ignore[override,return] self.input_dtype = query.layout.dtype num_threads = parallel_num_threads() - buf_out = TensorBox.create(self.output_node) + assert isinstance(self.output_node, ir.IRNode) + buf_out: ir.IRNode = TensorBox.create(self.output_node) if template_buffer_node is not None: buf_out = template_buffer_node options = dict( diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 27231ab72688d1..8f04ac9236136f 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -234,14 +234,14 @@ {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} for (int64_t nci = nc; nci < nc_block_end; nci++) { {%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} -{%- if template.should_block_weights %} +{%- if template.should_block_weights and not is_woq_int4 %} {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} {%- else %} {%- if is_woq_int4 %} - {%- set tile_W = kernel.slice_nd(W, [("n_start", "n_start + n_size"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %} + {%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %} {%- set tile_qparam = kernel.slice_nd( - qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("n_start", "n_start + n_size"), ()]) %} + qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %} {%- else %} {%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %} {%- set tile_qparam = None %} @@ -578,6 +578,10 @@ def get_reindexer(epilogue_node, default_reindexer=None): class CppGemmTemplate(CppTemplate): + """ + GEMM Template for Inductor CPP Backend. + """ + def __init__( self, input_nodes, @@ -802,6 +806,17 @@ def get_num_byte(dtype): if size_cache_B > L1: Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + if ( + config.cpp.use_small_dequant_buffer + and dtype_A is torch.bfloat16 + and dtype_B is torch.uint8 + and Mt_blocks == 1 + ): + # Make a small dequant_B buffer for woq int4 [q_group_size, Nr] + # Since when Mt_blocks == 1, L1-reside B block can't be reused by A. + if Kc_blocks * Kr >= self.q_group_size(): + Kc_blocks = self.q_group_size() // Kr + # Step 2: Decide Mc assuming A block is L2-reside. min_Mc_ratio = 2 # TODO(jgong5): something to tune? min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) @@ -1092,7 +1107,7 @@ def prep_weight( """ NOTE Weight prep consists of 2 separate steps: 1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n] - This is always done if the weight tensor is contant, i.e. for all GEMM and some BMM. + This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM. For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway. This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel, and is worth the overhead of reshape and blocking. @@ -1125,9 +1140,12 @@ def prep_weight( new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight) padding = padded_n - n - if should_block_weight: + if should_block_weight and not cls.is_woq_int4(): blocked_w = cls.block_weight(W, new_size, padding) new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size) + elif should_block_weight: + assert cls.is_woq_int4() + new_inputs[1] = cls.block_weight(W, new_size, padding) elif isinstance(W, ir.IRNode): # Require W layout to be fixed & contiguous, happens inplace. ir.ExternKernel.require_contiguous(W) @@ -1209,7 +1227,7 @@ def block_weight(cls, W, new_size, padding): permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2] blocked_w = L.constant_pad_nd(W, (0, padding)) blocked_w = L.permute( - L.view(blocked_w, permute_size), + L.view(blocked_w, permute_size), # type: ignore[arg-type] permute_dims, ) else: @@ -1341,7 +1359,7 @@ def get_options( reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = [] fake_buffers: list[ir.Buffer] = [] - Y_aliases = OrderedSet[str]() + Y_aliases: OrderedSet[str] = OrderedSet() use_local_acc = ( self.layout.dtype != torch.float @@ -1689,7 +1707,68 @@ def q_group_size(cls): @staticmethod def check_if_block_weight(W, micro_gemm): # For WOQ INT4, weight is already packed - return False + # However, for AMX microkernel, we want to change the blocking of weight + from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx + + return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx) + + @classmethod + def block_weight(cls, W, new_size, padding): + # This method is called only if AMX microkernels are used. + # In this case, we unpack and repack weight so that block_n=32 + # the format of packed weight is described here: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 + if isinstance(W, ir.IRNode): + # in this case, we do nothing + ir.ExternKernel.require_contiguous(W) + blocked_w = W + else: + # in this case, we unpack and repack weight + assert isinstance(W, torch.Tensor) + assert W.dim() == 2 + N = W.size(0) + K = W.size(-1) * 2 + G = cls.q_group_size() + # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel + # so that the unpacking process is faster + x = torch.eye(K).bfloat16() + # Here we use scale=1 and qzero=8 because we want to unpack weight + # without dequantizing it. The qzero here is 8 instead of 0 because + # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95 + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .bfloat16() + .expand(K // G, N, 2) + .contiguous() + ) + # shape: [K, N] + unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + W, + G, + qscales_and_zeros, + ).to(torch.uint8) + block_n = 32 + # shape: [N // block_n, K, block_n] + w_blocked = ( + unpacked_w.view(K, N // block_n, block_n) + .permute(1, 0, 2) + .contiguous() + ) + # pack 2 int4 -> 1 int8 + # block_n: [a0, a1, ..., a15, b0, b1, ..., b15] + # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...] + # shape: [N // block_n, K, 2, block_n // 2] + w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2) + # shape: [N // block_n, K, block_n // 2] + w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | ( + w_blocked[:, :, 1, :] << 4 + ) + # shape: [N, K // 2] + blocked_w = w_blocked_packed.view(N, K // 2) + + return blocked_w return CppWoqInt4GemmTemplateInstance diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 8cd595299b0454..c9c54553756fdf 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import dataclasses +import operator import sys from enum import Enum from typing import Callable, Optional @@ -683,7 +684,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm): // Use 2 implementations for the transposed B: // First implementation: // Transpose first and then perform outer product calculation in sub-blocks, - // which introduces an additional tranpose overhead of [K, N] compared to the non-tranpose version. + // which introduces an additional transpose overhead of [K, N] compared to the non-transpose version. // Second implementation: // Directly perform inner product calculation in sub-blocks, // which introduces an additional vector reduction of [M, N] compared to the non-tranpose version. @@ -1000,7 +1001,7 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): ) class CppMicroGemmAMX(CppMicroGemm): """ - This class generates the code for micro gemm using Advanced Matrix eXtention (AMX) + This class generates the code for micro gemm using Advanced Matrix extension (AMX) instructions available in 4th generation Intel Xeon for compute. It supports input types of torch.bfloat16 with fp32 output. """ @@ -1230,10 +1231,6 @@ def codegen_define(self, kernel: CppTemplateKernel) -> str: else: assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" num_columns = block_n // 16 - if self.is_woq_int4(): - # block_n for woq int4 is 64, which is too large for micro kernel - # so we split it into 2x32. Here num_columns = 2. - num_columns //= 2 options = { "declare_kernel": self.get_kernel_declaration(), "use_cached_dequantized_B": ( @@ -1632,8 +1629,8 @@ def is_woq_int4(self): *generate_gemm_config( VecAMX, [ # (block_m, block_n, block_k) - (16, 64, 32), - (32, 64, 32), + (16, 32, 32), + (32, 32, 32), ], input_dtype=torch.bfloat16, input2_dtype=torch.uint8, @@ -1645,8 +1642,8 @@ def is_woq_int4(self): class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): """ This class generates the code for WoQ int4 micro gemm using AMX intrinsics, - which are available on 4th and 5th generation Intel Xeon. - Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2] + which are available on 4th and newer generations of Intel Xeon. + Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2] Shape of packed ScalesAndZeros = [K // group_size, N, 2] Reuse TEMPLATE_KERNEL of CppMicroGemmAMX. """ @@ -1659,7 +1656,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): {{declare_kernel}} { {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); - {{kernel.assert_function}}({{block_n}} == 64, "block_n must be 64 for WOQ int4"); + {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4"); // Create a stack-allocated buffer for tiles of B. // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. @@ -1673,6 +1670,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; const int KB = K / BLOCK_K; + __m512i b32[COLS * 2]; __m512 vb[COLS * 2]; __m512 scale[COLS]; __m512 zero[COLS]; @@ -1758,7 +1756,7 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): // Dequantize a B block of 2 * block_n into bf16 // So, it handles k and k+1 at the same time auto dequantize_B = [&](int n) { - constexpr int64_t ldb_int4 = BLOCK_N / 2; // 32 + constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16 for (int k = 0, kb = 0; k < K; k += 2) { // Since block_k must be 32 for AMX microkernels, k_start may not be // a multiple of q_group_size. In that case, we need to load scales @@ -1768,35 +1766,25 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): } // load 256 bits = 64 elements in int4 - __m256i b4 = _mm256_loadu_si256((__m256i*)(B + n * K + k * ldb_int4)); if (k + PREFETCH_SIZE_K < K) { _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); } - __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); - vb[0] = _mm512_permutexvar_ps(b32, lut); + __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4)); + b32[0] = _mm512_cvtepu8_epi32(b4); + b32[1] = _mm512_srli_epi32(b32[0], 4); + vb[0] = _mm512_permutexvar_ps(b32[0] , lut); vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); - vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); - vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]); - - b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1)); - vb[1] = _mm512_permutexvar_ps(b32, lut); + vb[1] = _mm512_permutexvar_ps(b32[1], lut); vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); - vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); - vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]); - b4 = _mm256_loadu_si256((__m256i*)(B + n * K + (k + 1) * ldb_int4)); - b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); - vb[0 + COLS] = _mm512_permutexvar_ps(b32, lut); + b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); + b32[0 + COLS] = _mm512_cvtepu8_epi32(b4); + b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4); + vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut); vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]); - vb[2 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); - vb[2 + COLS] = _mm512_fmadd_ps(vb[2 + COLS], scale[2], zero[2]); - - b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1)); - vb[1 + COLS] = _mm512_permutexvar_ps(b32, lut); + vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut); vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]); - vb[3 + COLS] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); - vb[3 + COLS] = _mm512_fmadd_ps(vb[3 + COLS], scale[3], zero[3]); for (int i = 0; i < COLS; i++) { // convert to VNNI @@ -1810,57 +1798,52 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): auto v = _mm512_castsi256_si512(v0_bf16); v = _mm512_inserti64x4(v, v1_bf16, 1); // store the VNNI format bfloat16 values - // split block_n into 2x32 - {{input_t}}* addr = dequantized_B_buf + K * 32 * (i / 2) + k * 32 + (i % 2) * 32; + {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32; _mm512_storeu_si512(addr, v); } } }; - const int64_t updated_ldb = {{block_n}} / 2; for (int64_t n = 0; n < N; n += {{block_n}}) { // Dequantize K * block_n int8 B elements into BF16 dequantize_B(n); - // for woq int4, block_n is 64, which is too large for micro kernel - for (int64_t ni = 0; ni < {{block_n}}; ni += 32) { - for (int64_t m = 0; m < M; m += {{block_m}}) { - int64_t block_m = std::min(M - m, {{block_m}}); - int64_t m_tail = m; - {%- for num_rows in range(block_m, 0, -16) %} - {%- if num_rows != block_m %} - else - {%- endif %} - if (block_m >= {{num_rows}}) { - {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( - amx_state, - A + m * lda, - dequantized_B_buf + ni * K, - C + m * ldc + n + ni, - K, - lda, - updated_ldb, - ldc, - 16 - ); - block_m -= {{num_rows}}; - m_tail += {{num_rows}}; - } - {%- endfor %} - if (block_m > 0) { - {{kernel_name}}_amx_kernel_16_{{num_columns}}( - amx_state, - A + m_tail * lda, - dequantized_B_buf + ni * K, - C + m_tail * ldc + n + ni, - K, - lda, - updated_ldb, - ldc, - block_m - ); - } - } // for m - } // for ni + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; + {%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, + dequantized_B_buf + n * K, + C + m * ldc + n, + K, + lda, + {{block_n}}, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } + {%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, + dequantized_B_buf + n * K, + C + m_tail * ldc + n, + K, + lda, + {{block_n}}, + ldc, + block_m + ); + } + } // for m } // for n } """ @@ -1921,6 +1904,22 @@ def create_from_config(cls, config: CppMicroGemmConfig): alpha, ) + def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls): + # For WoQ GEMM, AMX micro-kernel may not perform well if m is small. + # Exception: for dynamic shapes, we consider using the AMX micro-kernel. + if ( + dynamic_M + or input_dtype != torch.bfloat16 + or input2_dtype not in [torch.int8, torch.uint8] + ): + return False + # For WOQ INT8, use AMX for m >= block_m + # For WOQ INT4, use AMX for m >= 5 + block_m, *_ = config.register_blocking + is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx + m_threshold = 5 if is_woq_int4 else block_m + return m < m_threshold + assert isinstance(n, int) or n.is_number, n assert isinstance(k, int) or k.is_number, k from ..utils import has_free_symbols @@ -1963,15 +1962,9 @@ def create_from_config(cls, config: CppMicroGemmConfig): ): continue block_m, block_n, block_k = config.register_blocking - if ( - config.vec_isa_cls == VecAMX - and m < block_m - and not dynamic_M - and input_dtype == torch.bfloat16 - and input2_dtype in [torch.int8, torch.uint8] + if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq( + config, dynamic_M, cls ): - # For WoQ GEMM, AMX micro-kernel may not perform well if m < block_m. - # Exception: for dynamic shapes, we consider using the AMX micro-kernel. continue # Criteria on the ranking of configurations # 1. ISA: AMX > VEC @@ -2015,4 +2008,4 @@ def create_from_config(cls, config: CppMicroGemmConfig): else: return None # TODO(jgong5): allow autotuning on choices of configs - return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) + return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:]) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 4694a4d0d097c4..b7a830a5010514 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -33,7 +33,7 @@ def parse_expr_with_index_symbols(expr): return expr.subs(int_symbols) -def wrap_with_tensorbox(node) -> ir.TensorBox: +def wrap_with_tensorbox(node) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: return ( ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) ) @@ -161,6 +161,7 @@ def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView: assert len(_range) == 2 start, end = parse_expr_with_index_symbols(_range) sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced, ir.TensorBox) assert isinstance(sliced.data, ir.ReinterpretView), sliced.data return sliced.data @@ -173,10 +174,10 @@ def select(self, node, dim: int, idx: int) -> ir.ReinterpretView: assert isinstance(sliced.data, ir.ReinterpretView), sliced.data return sliced.data - def view(self, node, sizes: list[Any]) -> ir.View: + def view(self, node, sizes: list[Any]) -> ir.IRNode: node = wrap_with_tensorbox(node) sizes = parse_expr_with_index_symbols(sizes) - return L.view(node, sizes).data + return L.view(node, sizes).data # type: ignore[arg-type] def permute(self, node, dims): node = wrap_with_tensorbox(node) @@ -585,7 +586,7 @@ def info_dict( ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]: return {"backend": "CPP", "op_type": "unknown"} - def output_node(self) -> ir.TensorBox: + def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: return ir.TensorBox.create( ir.CppTemplateBuffer( layout=self.layout, diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9fc76fcdbac220..647d3ed104dad0 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations +import ctypes import functools import math import os @@ -56,7 +57,7 @@ def __init__(self): if not hasattr(self, "device"): self.device = "cpu" # must be initialized prior to calling super().__init__() - self.included_devices = OrderedSet[str]() + self.included_devices: OrderedSet[str] = OrderedSet() super().__init__() self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " @@ -66,14 +67,14 @@ def __init__(self): self.supports_intermediate_hooks = False self.kernel_callsite_id = count() self.int_array_id = count() # for int array local variable declarations - self.declared_int_array_vars = OrderedSet[str]() + self.declared_int_array_vars: OrderedSet[str] = OrderedSet() self.tmp_tensor_id = count() # for tmp tensor local variable declarations self.arg_var_id = count() - self.used_cached_devices = OrderedSet[str]() - self.used_cached_dtypes = OrderedSet[str]() - self.used_cached_layouts = OrderedSet[str]() - self.used_cached_memory_formats = OrderedSet[str]() - self.used_cond_predicate = OrderedSet[str]() + self.used_cached_devices: OrderedSet[str] = OrderedSet() + self.used_cached_dtypes: OrderedSet[str] = OrderedSet() + self.used_cached_layouts: OrderedSet[str] = OrderedSet() + self.used_cached_memory_formats: OrderedSet[str] = OrderedSet() + self.used_cond_predicate: OrderedSet[str] = OrderedSet() self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False @@ -278,12 +279,12 @@ def codegen_input_symbol_assignment( ): code = self.prefix - @functools.lru_cache(None) + @functools.cache def sizeof(name): self.codegen_input_size_var_decl(code, name) return f"{name}_size" - @functools.lru_cache(None) + @functools.cache def strideof(name): self.codegen_input_stride_var_decl(code, name) return f"{name}_stride" @@ -405,12 +406,15 @@ def gen_check(handle_kind, idx, name, tensor): """ ) if not math.isinf(sym_range.upper): + # Limit upper bound to max C long long value (2^63 - 1) + max_long_long = ctypes.c_longlong(2**63 - 1).value + upper_bound = min(sym_range.upper, max_long_long) self.prefix.splice( f""" - if ({name}_size[{dim_idx}] > {sym_range.upper}) {{ + if ({name}_size[{dim_idx}] > {upper_bound}) {{ std::stringstream ss; ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " - << "expected to be <= {sym_range.upper}, " << "but got: " + << "expected to be <= {upper_bound}, " << "but got: " << {name}_size[{dim_idx}] << "\\n"; throw std::runtime_error(ss.str()); }} @@ -1041,6 +1045,7 @@ def generate_return(self, output_refs: list[str]): output_buffer = V.graph.graph_outputs[idx] if isinstance(output_buffer, ir.BaseView): output_storage = output_buffer.unwrap_view() + assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) if isinstance(output_storage.data, ir.ConstantBuffer): is_constant_buffer = True @@ -1087,20 +1092,28 @@ def generate_end(self, result): result.writeline("} // namespace torch::aot_inductor\n\n\n") return - # Close the wrapper code block, then write any kernel definitions. - result.splice("'''\n)") - if self.kernel_declarations: - result.splice("\nkernel_src = (\nr'''") - result.splice(self.kernel_declarations.getvalue()) + if config.cpp_wrapper_build_separate: + # Close the wrapper code block, then write any kernel definitions. result.splice("'''\n)") + if self.kernel_declarations: + result.splice("\nkernel_src = (\nr'''") + result.splice(self.kernel_declarations.getvalue()) + result.splice("'''\n)") + else: + result.splice( + """ + kernel_src = '' + """ + ) else: - result.splice( - """ - kernel_src = '' - """ - ) + # Merge main code and kernel code + result.splice(self.kernel_declarations.getvalue()) + self.kernel_declarations.clear() + # Close the wrapper code block + result.splice("'''\n)") - # cpp entry function for JIT with cpp wrapper + kernel_code = "kernel_src" if config.cpp_wrapper_build_separate else "None" + # Cpp entry function for JIT with cpp wrapper result.splice( f""" inductor_entry = CppWrapperCodeCache.load_pybinding( @@ -1108,7 +1121,7 @@ def generate_end(self, result): main_code=cpp_wrapper_src, device_type="{self.device}", num_outputs={len(V.graph.graph_outputs)}, - kernel_code=kernel_src, + kernel_code={kernel_code}, ) """ ) @@ -1343,7 +1356,7 @@ def generate_scatter_fallback( def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version # See the comment in codegen_reinterpret_view about why having something like - # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding # tensor prematurely deallocated, thus the temporary array trick here. indices_str = self._generate_temporary_array_pointer( "AtenTensorHandle", indices @@ -1473,7 +1486,7 @@ def codegen_memory_format(self, memory_format): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, @@ -1491,12 +1504,19 @@ def codegen_int_array_var( # This is why writeline needs to explicitly passed in as a parameter. var = f"int_array_{next(self.int_array_id)}" ctype = "int64_t" - if var not in self.declared_int_array_vars: - self.declared_int_array_vars.add(var) + if int_array == "{}": + # An array of unknown bound cannot be initialized with {}. if known_statically: - writeline(f"static constexpr {ctype} {var}[] = {int_array};") + writeline(f"static constexpr {ctype} *{var}=nullptr;") else: - writeline(f"const {ctype} {var}[] = {int_array};") + writeline(f"const {ctype} *{var}=nullptr;") + else: + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writeline(f"const {ctype} {var}[] = {int_array};") return var def make_buffer_allocation(self, buffer): @@ -1717,7 +1737,7 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: # ``` return final_tensor_str - def codegen_device_copy(self, src, dst, non_blocking: bool): + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to handle cases where dst is not an AtenTensorHandle.""" self.writeline( @@ -1780,7 +1800,7 @@ def codegen_conditional(self, conditional): if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): # in ABI-compatible mode, we need to use the ABI shim function - # to extract a C++ bool from the unrelying scalar bool Tensor + # to extract a C++ bool from the underlying scalar bool Tensor predicate = f"{conditional.predicate.get_name()}_scalar" if predicate not in self.used_cond_predicate: self.codegen_tensor_item( @@ -1844,7 +1864,7 @@ def codegen_while_loop(self, while_loop): # in ABI-compatible mode, the carried inputs are codegened # as buffers outside the while loop and set to the initial # values. at the end of each while_loop iteration, they - # will be assined the carried values. + # will be assigned the carried values. out_name = out.get_name() self.writeline(f"AtenTensorHandle {out_name}_handle;") self.writeline( @@ -1853,7 +1873,7 @@ def codegen_while_loop(self, while_loop): self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") cond_outer_inputs.append(out_name) - # additional inputs will be assinged within the while_loop + # additional inputs will be assigned within the while_loop # iteration directly from the corresponding outer graph buffers cond_outer_inputs.extend(outer_additional_inputs) @@ -1889,6 +1909,11 @@ def generate_extern_kernel_args_decl_if_needed( output_args: _OUTPUT_ARGS_TYPE, raw_outputs: Sequence[ir.Buffer], ): + """ + Generates declarations for external kernel arguments if needed, based on the provided + operator and its arguments. It processes both input and output arguments, categorizing + them into tensor and integer arguments for further code generation. + """ schema = None if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind): obj = raw_args[0] @@ -2013,7 +2038,9 @@ def fill_output_arg( # TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet for return_type in return_types: - if isinstance(return_type, (torch.TensorType, torch.NoneType)): + if isinstance( + return_type, (torch.TensorType, torch.NoneType, torch.IntType) + ): pass elif isinstance(return_type, torch.OptionalType): assert isinstance(return_type.getElementType(), torch.TensorType) @@ -2028,6 +2055,8 @@ def fill_output_arg( # None output is supported, but Optional return types are not yet supported if output_arg is None: continue + elif isinstance(raw_output_arg, int): + new_int_args.append(str(raw_output_arg)) elif isinstance(output_arg, list): for out in output_arg: assert out is not None, out @@ -2045,11 +2074,38 @@ def fill_output_arg( return new_tensor_args, new_int_args + @staticmethod + def _compatible_with_stableivalue(op: torch._ops.OpOverload) -> bool: + """Returns true if op_overload._schema only utilizes types supported by the AOT + C-shim *internal* function to_ivalue. to_ivalue is an implementation detail, so + these types are not guaranteed to be supported long-term. When generating code + for cpp_wrapper mode, we don't have to be forward-compatible, so changing this + function's implementation in future is fine.""" + supported_types = ( + torch.BoolType, + torch.DeviceObjType, + torch.FloatType, + # ScalarTypeType, LayoutType, and MemoryFormatType are seen as IntType + # when queried via torch.JitType.type. + torch.IntType, + torch.TensorType, + ) + + def type_supported(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return type_supported(t.getElementType()) + return isinstance(t, supported_types) + + return all( + type_supported(a.type) + for a in chain(op._schema.arguments, op._schema.returns) + ) + def generate_fallback_kernel_with_runtime_lookup( self, buf_name: str, python_kernel_name: str, - codegen_args: Sequence[str], + get_args: Callable[[], Sequence[str]], op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], raw_args: Sequence[Any], outputs: Sequence[ir.Buffer], @@ -2072,6 +2128,8 @@ def extract_output_name( return mutated_buf_names[0] if isinstance(out, (list, tuple)): return [extract_output_name(o) for o in out] # type: ignore[misc] + if isinstance(out, int): + return str(out) raise AssertionError(f"Unexpected output: {type(out)}") if isinstance(op_overload, torch._ops.HigherOrderOperator): @@ -2108,7 +2166,28 @@ def extract_output_name( return assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) - self.generate_fallback_kernel_with_runtime_lookup_jit( + for output in output_args: + assert output is None or isinstance(output, str), ( + "fallback kernels with runtime lookup currently only support tensor " + "returns, not more complicated types (such as list-of-list-of-tensor)" + ) + + # In non-AOT mode, we use aoti_torch_call_dispatcher if all the inputs and + # outputs of the op can be represented with StableIValue. This avoids the + # overhead of calling back into Python, and covers most remaining fallback ops. + if self._compatible_with_stableivalue(op_overload): + self.generate_fallback_kernel_with_runtime_lookup_nopython( + get_args, + op_overload, + output_args, # type: ignore[arg-type] + outputs, + ) + return + + # Otherwise, we call back into Python, which has some extra runtime overhead, + # but handles situations like list[Tensor] (currently unrepresentable via + # StableIValue). + self.generate_fallback_kernel_with_runtime_lookup_python( buf_name, python_kernel_name, op_overload, @@ -2260,7 +2339,123 @@ def handle_scalar(scalar): ) return "".join(lines) - def generate_fallback_kernel_with_runtime_lookup_jit( + def generate_fallback_kernel_with_runtime_lookup_nopython( + self, + get_args: Callable[[], Sequence[str]], + op_overload: torch._ops.OpOverload, + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. + + In the future, we may switch over to directly calling c10::Dispatcher if we need + to support more datatypes.""" + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + + dispatch_lines = IndentedBuffer() + dispatch_lines.writelines(declarations_before_scope) + dispatch_lines.writeline("{") + + with dispatch_lines.indent(): + tmp_var_number = count() + + def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: + # Strip off any temporary references; we're in an indented context, so + # any saved-off variables will be auto-destroyed. + new_codegen_arg = codegen_arg.removeprefix("&temporary_reference(") + if new_codegen_arg != codegen_arg: + # If we removed temporary_reference, there's a good chance the + # variable ends with get() (which would retrieve an ATenTensorHandle + # from a temporary RAII handle). Strip that off too, since we're + # going to save this in a temporary RAII handle. + if codegen_arg.endswith(".get())"): + codegen_arg = new_codegen_arg.removesuffix(".get())") + else: + codegen_arg = new_codegen_arg.removesuffix(")") + + if isinstance(arg_type, torch.OptionalType): + # If we have a pointer to a variable, strip it off and let + # from handle any internal pointers. + codegen_arg = codegen_arg.removeprefix("&") + + if codegen_arg == "nullptr": + return "from(std::nullopt)" + + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline( + f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" + ) + return f"from({var_name})" + + raii_var = self.create_tmp_raii_handle_var_if_needed( + codegen_arg, dispatch_lines + ) + temp_handle = raii_var != codegen_arg + + if isinstance(arg_type, torch.TensorType): + if not temp_handle: + # If the RAII tensor being referenced _isn't_ a temporary, + # scoped to this fallback call, then create a new handle + # referencing it which from can steal. + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline(f"AtenTensorHandle {var_name};") + dispatch_lines.writeline( + f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" + ) + return f"from({var_name})" + # If the RAII tensor _is_ a temporary scoped to this fallback call, + # simply release and steal the handle. + return f"from({raii_var}.release())" + return f"from({codegen_arg})" + + codegen_args = get_args() + ivalue_args = ( + parse_arg(a.type, c) + for a, c in zip(op_overload._schema.arguments, codegen_args) + ) + array_len = max(len(codegen_args), len(output_args)) + dispatch_lines.writeline( + f"std::array dispatch_vars{{{', '.join(ivalue_args)}}};" + ) + dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(") + with dispatch_lines.indent(): + dispatch_lines.writeline( + f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' # noqa: B950 + ) + dispatch_lines.writeline(");") + + if len(output_args) == 1 and (output := output_args[0]) is not None: + # result is a single tensor + dispatch_lines.writeline( + f"{output} = to(dispatch_vars[0]);" + ) + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + dispatch_lines.writeline( + f"{output_arg} = to(dispatch_vars[{idx}]);" + ) + + dispatch_lines.writeline("}") + self.writelines(dispatch_lines.getvalue().splitlines()) + + def generate_fallback_kernel_with_runtime_lookup_python( self, buf_name: str, python_kernel_name: str, @@ -2388,7 +2583,7 @@ def c_type_for_prim_type(self, val, type_) -> str: return "int64_t" elif isinstance( type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) - ) or repr(type_) in ("ScalarType", "Layout"): + ) or repr(type_) in ("Layout", "MemoryFormat", "ScalarType"): return "int32_t" elif isinstance(type_, torch.FloatType): return "double" diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index b4ea5c41b489b7..eb3390cbc39cfb 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -98,10 +98,10 @@ def generate_extern_kernel_out(self, *args, **kwargs): self.allow_stack_allocation = False super().generate_extern_kernel_out(*args, **kwargs) - def generate_fallback_kernel(self, *args, **kwargs): + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: # Disable stack allocation for extern kernels. self.allow_stack_allocation = False - super().generate_fallback_kernel(*args, **kwargs) + super().generate_fallback_kernel(node) def _generate_kernel_call_helper( self, @@ -409,6 +409,7 @@ def use_thread_local_cached_output_tensor(idx, output): output_buffer = V.graph.graph_outputs[idx] if isinstance(output_buffer, ir.BaseView): output_storage = output_buffer.unwrap_view() + assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) if isinstance(output_storage.data, ir.ConstantBuffer): is_constant_buffer = True @@ -728,7 +729,7 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version # See the comment in codegen_reinterpret_view about why having something like - # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding # tensor prematurely deallocated, thus the temporary array trick here. indices_str = self._generate_temporary_array_pointer( "AtenTensorHandle", @@ -750,7 +751,7 @@ def generate_fallback_kernel_with_runtime_lookup( self, buf_name: str, python_kernel_name: str, - codegen_args: Sequence[str], + get_args: Callable[[], Sequence[str]], op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], raw_args: Sequence[Any], outputs: Sequence[ir.Buffer], @@ -758,10 +759,10 @@ def generate_fallback_kernel_with_runtime_lookup( # No stack allocation when there is a fallback op self.allow_stack_allocation = False super().generate_fallback_kernel_with_runtime_lookup( - buf_name, python_kernel_name, codegen_args, op_overload, raw_args, outputs + buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs ) - def codegen_device_copy(self, src, dst, non_blocking: bool): + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, # while stack-allocation results in ArrayRefTensor # so disable stack allocation here diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 852e56ee9f3a89..430511ce4ebf08 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -16,11 +16,16 @@ from .. import config from ..codecache import CudaKernelParamCache -from ..ir import GraphPartitionSignature, TensorBox +from ..ir import ( + GraphPartitionSignature, + TensorBox, + TMADescriptorExperimental, + TMADescriptorStable, +) from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper -from .common import get_device_op_overrides +from .common import get_device_op_overrides, TritonScratchWorkspace from .cpp_utils import cexpr from .cpp_wrapper_cpu import CppWrapperCpu from .multi_kernel import MultiKernelCall @@ -115,6 +120,7 @@ def generate(self, wrapper: CppWrapperGpu): prefix.writeline(f"bool {name},") else: raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline("int32_t device_idx_,") prefix.writeline( maybe_hipify_code_wrapper( f"{wrapper.device_codegen.cpp_stream_type()} stream_," @@ -206,7 +212,11 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): arg_types = [arg_type_loookup[name] for name in call_args] arg_signatures = [triton_meta["signature"][name] for name in call_args] call_args_str = wrapper.generate_args_decl( - prefix, call_args, arg_types, arg_signatures + prefix, + call_args, + arg_types, + arg_signatures, + workspace_size=params.get("global_scratch") or 0, ) prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") launch_kernel_args = [ @@ -342,17 +352,34 @@ def generate(self, is_inference): def finalize_prefix(self): """Define the triton kernels now that autotuning is finished""" old_prefix = self.prefix # new content should go at start of prefix + + # Generating triton kernel callers can modify the prefix (cached dtypes), + # so do this before running finalize_prefix(), but put the generated code + # after the finalize_prefix() code. self.prefix = IndentedBuffer() - super().finalize_prefix() for kernel in self._triton_call_wrappers.values(): self.prefix.writeline("\n") kernel.generate(self) + triton_prefix = self.prefix + + self.prefix = IndentedBuffer() + super().finalize_prefix() + + self.prefix.splice(triton_prefix) + self.prefix.writeline("\n") self.prefix.splice(old_prefix) def generate_tma_descriptor(self, desc): self.write_tma_descriptor_helpers_once() + if isinstance(desc, TMADescriptorExperimental): + self._generate_experimental_tma_descriptor(desc) + else: + assert isinstance(desc, TMADescriptorStable) + self._generate_stable_tma_descriptor(desc) + + def _generate_experimental_tma_descriptor(self, desc): # generate data pointer for the source tensor source = self.generate_args_decl( code=self, @@ -368,7 +395,7 @@ def generate_tma_descriptor(self, desc): # `source` is in the form of `&var_x`, where `var_x` is the data pointer # (CUdeviceptr); we dereference `source` and cast to `void*` to pass to - # the data pointer of the source tensor ot the helper function + # the data pointer of the source tensor to the helper function # `init{1,2}DTMADescriptor` ptr = f"reinterpret_cast(*({source}))" dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims) @@ -378,6 +405,48 @@ def generate_tma_descriptor(self, desc): args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}" self.writeline(f"{fn}({args});") + def _generate_stable_tma_descriptor(self, desc): + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + # Pack the relevant information into a StableTMADescriptor struct. + # See [Note: AOTI TMA Stable handling] for more details. + self.writeline(f"alignas(64) StableTMADescriptor {desc_name};") + + def fill_array(name, values): + for i, val in enumerate(values): + self.writeline(f"{name}[{i}] = {val};") + + ptr = f"reinterpret_cast(*({source}))" + rank = len(desc.tensor.get_size()) + + fill_array(f"{desc_name}.block_shape", desc.block_shape) + fill_array(f"{desc_name}.global_shape", desc.tensor.get_size()) + fill_array(f"{desc_name}.strides", desc.tensor.get_stride()) + + element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize) + fn = "initTMADescriptor" + args = ", ".join( + str(x) + for x in [ + f"&{desc_name}.m", + ptr, + element_size, + rank, + f"{desc_name}.block_shape", + f"{desc_name}.global_shape", + f"{desc_name}.strides", + ] + ) + self.writeline(f"{fn}({args});") + def generate_args_decl( self, code: Union[IndentedBuffer, Self], @@ -385,6 +454,7 @@ def generate_args_decl( arg_types, arg_signatures, is_triton_kernel=True, + workspace_size=0, ): """ Generates any declarations of args to pass into a kernel call, and then returns the arg names. @@ -409,28 +479,74 @@ def generate_args_decl( "fp32": "float", } + def signature_is_tma_desc(sig): + if not sig: + return False + if sig == "nvTmaDesc": + return True + if sig.startswith("tensordesc<"): + return True + return False + + def process_tma_stable_arg(arg, arg_type, arg_signature, var_name): + # [Note: AOTI TMA Stable handling] + # For most args, a single arg passed to the python triton interface + # maps to a single arg in the cubin interface. However, for host-side + # TMA descriptors, a single python arg turns into 1 + 2 * N args in the + # cubin interface (where N is the rank). + # + # To do this: at TMA codegen time (for aoti), we generate a struct + # (StableTMADescriptor) containing the necessary information; and then + # when we call the function (i.e. here), we unpack the struct members. + code.writeline(f"auto {var_name} = {cexpr(arg)};") + + result = [] + result.append(f"&{var_name}.m") + + # from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111 # noqa: B950 + match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature) + assert match is not None + shape = match.group(2) + ndim = shape.count(",") + 1 + + for i in range(ndim): + result.append(f"&{var_name}.block_shape[{i}]") + + for i in range(ndim): + result.append(f"&{var_name}.strides[{i}]") + + return result + def process_args(arg, arg_type, arg_signature=None): var_name = f"var_{next(self.arg_var_id)}" - # ignore nvTmaDesc, as host-side TMA descriptors need + # ignore tma descriptors, as host-side TMA descriptors need # to be passed to the compiled Triton kernel by value - if isinstance(arg_type, UnwrapUnspecArg) and arg_signature != "nvTmaDesc": + if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc( + arg_signature + ): self.codegen_tensor_item( arg_type.dtype, arg, var_name, indented_buffer=code, ) - elif isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": + new_args.append(f"&{var_name}") + elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc( + arg_signature + ): device_ptr_type = self.device_codegen.cpp_device_ptr() code.writeline( maybe_hipify_code_wrapper( f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" ) ) + new_args.append(f"&{var_name}") elif arg_type in (sympy.Integer, int): code.writeline(f"int {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") elif arg_type in (sympy.Float, float): code.writeline(f"float {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") # For symbolic call arguments, examine the arg signatures from triton meta # to explicitly cast to the right type # Reason: `auto` can infer unexpected type against kernel input signature. @@ -442,9 +558,14 @@ def process_args(arg, arg_type, arg_signature=None): code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" ) + new_args.append(f"&{var_name}") + elif arg_signature and arg_signature.startswith("tensordesc<"): + new_args.extend( + process_tma_stable_arg(arg, arg_type, arg_signature, var_name) + ) else: code.writeline(f"auto {var_name} = {cexpr(arg)};") - new_args.append(f"&{var_name}") + new_args.append(f"&{var_name}") for arg, arg_type, arg_signature in zip_longest( call_args, arg_types, arg_signatures @@ -455,13 +576,17 @@ def process_args(arg, arg_type, arg_signature=None): is_triton_kernel and ( global_scratch := self.device_codegen.cpp_global_scratch( - next(self.arg_var_id) + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), + ), ) ) is not None ): global_scratch_def, global_scratch_var = global_scratch - code.writeline(maybe_hipify_code_wrapper(global_scratch_def)) + code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) new_args.append(f"&{global_scratch_var}") return ", ".join(new_args) @@ -537,6 +662,8 @@ def _generate_kernel_call_helper( self._kernel_name_to_body, arg_types, ) + device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index) + call_args.append(device_idx) call_args.append(stream) if V.graph.aot_mode: call_args.append("kernels") diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index dcece1ebc53069..0b87a0f0379537 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -1,5 +1,10 @@ from typing import Any, Optional +import sympy + +import torch +from torch.utils._ordered_set import OrderedSet + from ..ir import GraphPartitionSignature from ..virtualized import V from .cpp_wrapper_gpu import CppWrapperGpu @@ -7,6 +12,10 @@ class CppWrapperMps(CppWrapperGpu): + def __init__(self) -> None: + super().__init__() + self._used_kernel_names: OrderedSet[str] = OrderedSet() + @staticmethod def create( is_subgraph: bool, @@ -20,6 +29,7 @@ def _generate_kernel_call_helper( self, kernel_name: str, call_args: list[str], + arg_types: Optional[list[type]] = None, **kwargs: dict[str, Any], ) -> None: """ @@ -36,11 +46,22 @@ def _generate_kernel_call_helper( }); ``` """ + assert arg_types is not None + new_args = [] - for idx, arg in enumerate(call_args[:-2]): - new_args.append( - f"aoti_torch_mps_set_arg({kernel_name}_handle, {idx}, {arg});\n" - ) + for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): + if isinstance(arg_type, torch.dtype): + new_args.append( + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + ) + elif arg_type in (int, sympy.core.symbol.Symbol): + new_args.append( + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" + ) + else: + raise NotImplementedError( + f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}" + ) threads, group_size = call_args[-2], call_args[-1] if threads is None: @@ -65,14 +86,23 @@ def _generate_kernel_call_helper( def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: lib_name = name[: -len("_func")] calling_args = " ".join(call_args) - return f""" + + kernel_call_str = "" + + # Only add handle definition if the kernel is not already used + if name not in self._used_kernel_names: + self._used_kernel_names.add(name) + kernel_call_str += f""" auto {name} = {lib_name}.getKernelFunction("generated_kernel"); auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); + """ + kernel_call_str += f""" {name}->runCommandBlock([&] {{ {name}->startEncoding(); {calling_args} }}); """ + return kernel_call_str @staticmethod def get_device_include_path(device: str) -> str: diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 889961f2523b2b..67828622fde598 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -144,8 +144,9 @@ def codegen_template( assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( "Epilogue nodes must all be instances of ir.ComputedBuffer" ) - kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes) - + kernel, render = ctb.make_kernel_render( # type: ignore[misc] + ctb, epilogue_nodes=epilogue_nodes + ) with kernel: for node in [template_node, *epilogue_nodes]: node.mark_run() diff --git a/torch/_inductor/codegen/cuda/cuda_env.py b/torch/_inductor/codegen/cuda/cuda_env.py index 95be434e03b71e..a11462fc8a0b8c 100644 --- a/torch/_inductor/codegen/cuda/cuda_env.py +++ b/torch/_inductor/codegen/cuda/cuda_env.py @@ -4,7 +4,7 @@ from typing import Optional import torch -from torch._inductor.utils import clear_on_fresh_inductor_cache +from torch._inductor.utils import clear_on_fresh_cache from ... import config @@ -12,7 +12,7 @@ log = logging.getLogger(__name__) -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache @functools.lru_cache(1) def get_cuda_arch() -> Optional[str]: try: @@ -27,7 +27,7 @@ def get_cuda_arch() -> Optional[str]: return None -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache @functools.lru_cache(1) def get_cuda_version() -> Optional[str]: try: @@ -40,6 +40,6 @@ def get_cuda_version() -> Optional[str]: return None -@functools.lru_cache(None) +@functools.cache def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 4f75588ee99d42..f419ada67e1a38 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -12,7 +12,7 @@ from torch import dtype as torch_dtype from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.scheduler import BaseSchedulerNode -from torch._inductor.utils import do_bench_using_profiling, Placeholder +from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder from torch.utils._sympy.value_ranges import ValueRanges from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE @@ -29,6 +29,7 @@ IRNode, Layout, PrimitiveInfoType, + ShapeAsConstantBuffer, TensorBox, ) from ...utils import sympy_product @@ -81,6 +82,7 @@ class CUDAKernel(Kernel): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list) + self.size_args: list[Union[Expr, int]] = [] # Mapping from arg name to IRNode. self.named_nodes: dict[str, IRNode] = {} @@ -172,6 +174,9 @@ def get_ld(node) -> Union[Expr, int]: LDD = get_ld(Y) return (M, N, K, B, LDA, LDB, LDC, LDD) + def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: + return [*self.get_layout_args(), *self.size_args] + @staticmethod def find_ld_idx(node: IRNode) -> int: strides = node.get_stride() @@ -242,7 +247,6 @@ def def_kernel( self, inputs: list[IRNode], outputs: list[IRNode], - epilogue_inputs: list[IRNode], names_str: str = "", input_reorder: Optional[list[int]] = None, ) -> str: @@ -258,9 +262,10 @@ def def_kernel( e.g. The template might have input argument defined as [X, W, Bias], and the actual input passed into this template could be [Bias, X, W]. In this case, the `input_reorder` would be [2, 0, 1]. + additional_size_args: Additional size arguments for epilogue inputs """ names = [x.strip() for x in names_str.strip().split(",")] - if len(inputs) + len(epilogue_inputs) + len(outputs) != len(names): + if len(inputs) + len(outputs) != len(names): raise RuntimeError( f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" ) @@ -277,24 +282,30 @@ def def_kernel( self.named_nodes[name] = node self.args.input_buffers[node.get_name()] = name - for epilogue_input in epilogue_inputs: - if epilogue_input is not None: - self.named_nodes[epilogue_input.get_name()] = epilogue_input - self.args.input_buffers[epilogue_input.get_name()] = ( - epilogue_input.get_name() - ) - + free_symbols: OrderedSet[Expr] = OrderedSet() for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: self.named_nodes[name] = node self.args.output_buffers[node.get_name()] = name + if name not in ( + "X", + "W", + "Bias", + "Y", + ): # we handle these symbolic shapes explicitly + for expr in itertools.chain(node.get_size(), node.get_stride()): + if isinstance(expr, Expr): + for s in expr.free_symbols: + free_symbols.add(s) # type: ignore[arg-type] + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) self.init_layout_args() - size_args = [ - f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd") - ] + size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] + size_vars.extend(str(s) for s in free_symbols) + self.size_args.extend(free_symbols) + size_args = [f"const int {s}" for s in size_vars] runtime_arg_decls = ",".join( [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] @@ -334,11 +345,11 @@ def call_kernel( else: _, call_args, _, arg_types = self.args.python_argdefs() - layout_args = self.get_layout_args() - call_args.extend(layout_args) # type: ignore[arg-type] + dynamic_shape_args = self.get_dynamic_shape_args() + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] for arg in self.runtime_arg_values: call_args.append(arg) - arg_types.extend("int" for a in layout_args) + arg_types.extend("int" for _ in dynamic_shape_args) for arg in self.runtime_arg_info: arg_types.append(arg.ty) # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -605,11 +616,26 @@ def __str__(self) -> str: def call_name(self) -> str: return f"cuda_template_kernels.{self.name}" + def kernel_hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + def hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ return "-".join( [ self.category, self.bmreq.hash_key, + str(self.info_dict().get("swizzle")), ] ) @@ -635,7 +661,7 @@ def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType else: return {"backend": "CUDA", "op_type": "unknown"} - def output_node(self) -> TensorBox: + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: self.bmreq.update_workspace_size() return TensorBox.create( CUDATemplateBuffer( diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 1b41e8c105ac5f..7ed67b0daa49f6 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -116,7 +116,7 @@ def generate( # type: ignore[override] expected_args, ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) - size_args = V.graph.sizevars.size_hints(kernel.get_layout_args()) + size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py new file mode 100644 index 00000000000000..7afdd654ea74ad --- /dev/null +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import json +import logging +import os +import time +from typing import Any, Optional + +import torch._inductor.config as config +from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +CONFIG_PREFIX: str = "configs" + + +def get_config_request_key( + arch: str, + cuda_version: str, + instantiation_level: str, +) -> str: + """ + Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level. + """ + hash_target = "-".join( + [ + cutlass_key().hex(), + arch, + cuda_version, + instantiation_level, + ] + ) + return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] + + +def _generate_config_filename(request_key: str) -> str: + """ + Generate a filename for the full ops. + """ + return f"{CONFIG_PREFIX}_{request_key}.json" + + +@clear_on_fresh_cache +@functools.cache +def maybe_fetch_ops() -> Optional[list[Any]]: + """ + Fetch ops from databases. + """ + if config.force_disable_caches: + return None + + # setup + arch: str = get_cuda_arch() + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version: str = ".".join(get_cuda_version().split(".")[:2]) + instantiation_level: str = config.cuda.cutlass_instantiation_level + + # filename and filepath + request_key: str = get_config_request_key(arch, version, instantiation_level) + filename: str = _generate_config_filename(request_key) + filepath: str = os.path.join(cache_dir(), filename) + + # try fetch + serialized_ops: Optional[list[str]] = None + start_time = time.time() + if os.path.isfile(filepath): + # locally + try: + with open(filepath) as f: + serialized_ops = json.load(f) + + assert isinstance(serialized_ops, list), ( + f"Expected serialized ops is a list, got {type(serialized_ops)}" + ) + except Exception as e: + log.warning( + "Failed to load CUTLASS config %s from local cache: %s", + filename, + e, + ) + serialized_ops = None + elif config.is_fbcode(): + from torch._inductor.fb.cutlass_remote_cache import ( + maybe_fetch_cutlass_configs_from_remote, + ) + + # from remote + serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath) + + if serialized_ops is None: + return None + + # deserialize + serializer = get_cutlass_operation_serializer() + full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr] + log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time) + return full_ops diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py index 7198358f882e3d..d25bec5981a239 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -1,5 +1,8 @@ -from typing import Any, Union +from typing import Any, Callable, Union +from sympy import Expr + +import torch._inductor.config as config from torch._inductor.ir import ( ComputedBuffer, InputBuffer, @@ -25,27 +28,6 @@ import textwrap from typing import Union - from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found] - EmptyByte, - ) - from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found] - dtype2ctype, - ) - from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found] - EpilogueFunctorVisitor, - ) - from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found] - FusionCallbacks, - ) - from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found] - CollectiveEpilogue, - ) - from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found] - PythonASTFrontend, - ) - from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found] - Tensor as CutlassTensor, - ) from cutlass_library import ( DataType, EpilogueScheduleType, @@ -53,26 +35,28 @@ TileDescription, ) + if config.is_fbcode(): + import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 + else: + import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 + from torch._inductor.codegen.cuda import cuda_env from torch._inductor.utils import IndentedBuffer - _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated] + _CUTLASS_C_DTYPES = OrderedSet(python_cutlass.backend.epilogue.dtype2ctype.values()) # type: ignore[var-annotated] def create_example_tensors( var_name_to_buffer_name: dict[str, str], name_to_buffer: dict[str, Buffer], - ) -> dict[str, CutlassTensor]: - def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]: + def cutlass_tensor_from_buffer( + buffer: Buffer, + ) -> python_cutlass.backend.evt.ir.tensor.Tensor: shape = buffer.get_layout().size stride = buffer.get_layout().stride - assert all(isinstance(x, int) or x.is_integer for x in shape), ( - f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT" - ) - assert all(isinstance(x, int) or x.is_integer for x in stride), ( - f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT" - ) - shape = tuple(int(x) for x in shape) - stride = tuple(int(x) for x in stride) + shape = tuple(size_hint_fn(x) for x in shape) + stride = tuple(size_hint_fn(x) for x in stride) is_row_major = is_contiguous_strides_for_shape(stride, shape) is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1]) @@ -80,14 +64,14 @@ def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: if not is_row_major and not is_column_major: raise RuntimeError( f"Cannot create example tensor for {buffer.get_name()} with \ -non-contiguous layout, recieved stride: {stride} and shape: {shape}" +non-contiguous layout, received stride: {stride} and shape: {shape}" ) - return CutlassTensor( + return python_cutlass.backend.evt.ir.tensor.Tensor( shape=shape, - layout_tag=LayoutType.RowMajor - if is_row_major - else LayoutType.ColumnMajor, + layout_tag=( + LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor + ), element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), ) @@ -98,28 +82,37 @@ def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: def trace( fn_src: str, - example_tensors: dict[str, CutlassTensor], + example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], accum_type: DataType, output_type: DataType, tile_description: TileDescription, epilogue_schedule: EpilogueScheduleType, name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], **kwargs: dict[str, Any], ) -> tuple[str, str, str]: cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] assert cuda_arch >= 90, "Only SM90+ is supported for EVT" epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) - visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) - fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) - collective_epilogue = CollectiveEpilogue( - tile_description, - epilogue_schedule, - accum_type, - output_type, - fusion_callbacks, + visitor = python_cutlass.backend.evt.EpilogueFunctorVisitor( + cuda_arch, epilogue_functor + ) + fusion_callbacks = ( + python_cutlass.backend.evt.backend.emitter_base.FusionCallbacks( + visitor.graph, cuda_arch, emit_CD=False + ) + ) + collective_epilogue = ( + python_cutlass.backend.evt.backend.sm90_emitter.CollectiveEpilogue( + tile_description, + epilogue_schedule, + accum_type, + output_type, + fusion_callbacks, + ) ) evt_name, evt_code = collective_epilogue.emit() - evt_args = _render_argument_type(epilogue_functor, name_to_buffer) + evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn) return evt_name, evt_args, evt_code # Based off of @@ -127,14 +120,20 @@ def trace( # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval def _trace( - fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any + fn_src: str, + example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], + cc: int, + **kwargs: Any, ) -> EpilogueFunctor: - class EpilogueFunctor(PythonASTFrontend): + class EpilogueFunctor(python_cutlass.backend.evt.frontend.PythonASTFrontend): def __init__(self, cc: int, **kwargs: Any): self.source = textwrap.dedent(fn_src) super().__init__(cc, **kwargs) - def parse(self, example_inputs: dict[str, CutlassTensor]) -> None: + def parse( + self, + example_inputs: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], + ) -> None: self.example_inputs = example_inputs self.ast = ast.parse(self.source) self.visit(self.ast) @@ -147,15 +146,16 @@ def parse(self, example_inputs: dict[str, CutlassTensor]) -> None: def _render_argument_type( epilogue_functor: EpilogueFunctor, name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], ) -> str: epilogue_thread_type = epilogue_functor.epilogue_thread_type # Fragile, but this is the only way to guarantee t is expected type because t is a local class def is_nested_visitor_type(t: type) -> bool: - return ( - ".".join([t.__module__, t.__qualname__]) - == "cutlass.backend.c_types.visitor_factory..VisitorType" - ) + return ".".join([t.__module__, t.__qualname__]) in { + "python_cutlass.backend.c_types.visitor_factory..VisitorType", + "cutlass.backend.c_types.visitor_factory..VisitorType", + } buffer = IndentedBuffer() with buffer.set_tabwidth(2): @@ -165,7 +165,10 @@ def render_argument_type(name: str, t: CutlassArgType) -> None: buffer.writeline(f"{{}}, /* {name} */") else: fields = [ - (fname, _get_arg_from_node(ty, name_to_buffer[name])) + ( + fname, + _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn), + ) for fname, ty in t._fields_ ] field_strs = [ @@ -197,19 +200,21 @@ def render_thread_type(name: str, t: CutlassArgType) -> None: return buffer.getvalue() - def _get_arg_from_node(arg_ty: type, node: Buffer) -> str: + def _get_arg_from_node( + arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int] + ) -> str: from ..cuda_template import CUTLASSTemplate # Today, arguments are either a pointer to the # node's memory, a stride tuple, the datatype # Once again, need to check for local class type for stride tuple - if ( - str(arg_ty) - == ".TupleType'>" - ): + if str(arg_ty) in { + ".TupleType'>", + ".TupleType'>", + }: DEFAULT_STRIDE_LEN = 3 assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN - stride = [int(x) for x in node.get_layout().stride] + stride = [size_hint_fn(x) for x in node.get_layout().stride] for _ in range(DEFAULT_STRIDE_LEN - len(stride)): stride.append(0) @@ -230,7 +235,7 @@ def render_stride(x: int) -> str: arg_ty in _CUTLASS_C_DTYPES ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)" - elif issubclass(arg_ty, EmptyByte): + elif issubclass(arg_ty, python_cutlass.backend.c_types.EmptyByte): return "{}" raise NotImplementedError(f"Unsupported arg type: {arg_ty}") diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index a2ab413ccf66e8..95af1a968a97ce 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -396,7 +396,7 @@ def emit(self, operation): "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), "align_c": str(operation.C.alignment), - "align_d": str(operation.C.alignment), + "align_d": str(operation.D.alignment), "transform_a": ComplexTransformTag[operation.A.complex_transform], "transform_b": ComplexTransformTag[operation.B.complex_transform], "math_operation": MathOperationTag[ diff --git a/torch/_inductor/codegen/cuda/cutlass_presets.py b/torch/_inductor/codegen/cuda/cutlass_presets.py index bc97c22e247c6b..346be534e82e6d 100644 --- a/torch/_inductor/codegen/cuda/cutlass_presets.py +++ b/torch/_inductor/codegen/cuda/cutlass_presets.py @@ -5,7 +5,7 @@ from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch -@functools.lru_cache(None) +@functools.cache def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: """ Generate cutlass presets for the given CUDA arch. @@ -20,220 +20,68 @@ def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: if arch == "90": preset = presets[0] preset["0"] = [ - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_64x256x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x256x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_256x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", ] - preset["1111"] = [ - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - ] - preset["2222"] = [ - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - ] - preset["3333"] = [ - r"cutlass3x_sm90_tensorop_s64x48x16gemm_.*_64x48x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_4x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_4x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - ] - preset["4444"] = [ - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_64x192x64_4x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - ] - preset["5555"] = [ - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x8x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_128x192x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", - r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", - r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", - r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + preset["3332"] = [ + r"cutlass3x_sm90_tensorop_.*_64x48x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_4x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x128x64_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x16x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_64x16x64_4x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_256x128x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_256x192x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x16x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_64x32x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", ] return presets diff --git a/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/torch/_inductor/codegen/cuda/cutlass_python_evt.py index 43c1387f2fb28a..ca5e6031b19cd5 100644 --- a/torch/_inductor/codegen/cuda/cutlass_python_evt.py +++ b/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -21,14 +21,20 @@ def scaled_mm_evt( - scale_A_name: str, scale_B_name: str, output_name: str + scale_A_name: str, scale_B_name: str, bias_name: Optional[str], output_name: str ) -> tuple[list[str], dict[str, Any], str]: evt_read_names = [scale_A_name, scale_B_name] var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]} var_name_to_buffer_name["D"] = output_name var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name - evt_py_code = f"def fn(accum, {scale_A_name}, {scale_B_name}):{linesep}\ - D = accum * {scale_A_name} * {scale_B_name}{linesep}\ + expr = f"accum * {scale_A_name} * {scale_B_name}{linesep}" + if bias_name: + expr = f"({expr}) + {bias_name}" + evt_read_names.append(bias_name) + var_name_to_buffer_name[bias_name] = bias_name + + evt_py_code = f"def fn(accum, {','.join(evt_read_names)}):{linesep}\ + D = {expr}{linesep}\ return D{linesep}" return evt_read_names, var_name_to_buffer_name, evt_py_code @@ -158,6 +164,10 @@ def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]) self.removed_buffers: OrderedSet[str] = removed_buffers self.cur_node: Optional[ComputedBuffer] = None self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + self.name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) self.is_D_assigned = False self.D_var_name = None @@ -226,7 +236,7 @@ def get_renames(self) -> dict[str, str]: return dict(self.var_name_to_buffer_name) def get_reads(self) -> list[str]: - return list(self.reads) + return list(self.reads.difference(self.store_name_to_value.keys())) def get_writes(self) -> list[str]: return list(self.store_name_to_value.keys()) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 21aaf26729f536..ae1131c76c0b66 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -13,7 +13,8 @@ import sympy import torch -from torch._inductor.utils import clear_on_fresh_inductor_cache +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.utils import clear_on_fresh_cache from ... import config from ...ir import Layout @@ -31,17 +32,20 @@ @atexit.register def move_cutlass_compiled_cache() -> None: """Move CUTLASS compiled cache file to the cache directory if it exists.""" - if "cutlass" not in sys.modules: + if not try_import_cutlass.cache_info().currsize > 0: return - import cutlass # type: ignore[import-not-found] + if config.is_fbcode(): + import python_cutlass # type: ignore[import-not-found] + else: + import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401 - if not os.path.exists(cutlass.CACHE_FILE): + if not os.path.exists(python_cutlass.CACHE_FILE): return try: - filename = os.path.basename(cutlass.CACHE_FILE) - shutil.move(cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) + filename = os.path.basename(python_cutlass.CACHE_FILE) + shutil.move(python_cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) except OSError as e: log.warning("Failed to move CUTLASS compiled cache file: %s", str(e)) @@ -56,20 +60,18 @@ def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: return content -@functools.lru_cache(None) +@functools.cache def try_import_cutlass() -> bool: """ We want to support three ways of passing in CUTLASS: 1. fbcode, handled by the internal build system. - 2. pip install nvidia-cutlass, which provides the cutlass_library package - and the header files in the cutlass_library/source directory. - 3. User specifies cutlass_dir. The default is ../third_party/cutlass/, + 2. User specifies cutlass_dir. The default is ../third_party/cutlass/, which is the directory when developers build from source. """ if config.is_fbcode(): try: - import cutlass # type: ignore[import-not-found] import cutlass_library # type: ignore[import-not-found] + import python_cutlass # type: ignore[import-not-found] # noqa: F401 except ImportError as e: log.warning( "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", @@ -79,34 +81,6 @@ def try_import_cutlass() -> bool: return True - try: - import cutlass # type: ignore[import-not-found] # noqa: F811 - import cutlass_library # type: ignore[import-not-found] # noqa: F811 - - cutlass_minor_vesion = int(cutlass.__version__.split(".")[1]) - if cutlass_minor_vesion < 7: - log.warning("CUTLASS version < 3.7 is not recommended.") - - log.debug( - "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir" - ) - cutlass_library_dir = os.path.dirname(cutlass_library.__file__) - assert os.path.isdir(cutlass_library_dir), ( - f"{cutlass_library_dir} is not a directory" - ) - config.cuda.cutlass_dir = os.path.abspath( - os.path.join( - cutlass_library_dir, - "source", - ) - ) - - return True - except ModuleNotFoundError: - log.debug( - "cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir" - ) - # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. # This is a temporary hack to avoid CUTLASS module naming conflicts. # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. @@ -174,7 +148,7 @@ def link_and_append(dst_link, src_path, parent_dir): ) try: - import cutlass # noqa: F401 + import cutlass # noqa: F401, F811 import cutlass_library.generator # noqa: F401 import cutlass_library.library # noqa: F401 import cutlass_library.manifest # noqa: F401 @@ -250,8 +224,8 @@ def __post_init__(self): self.architectures = _normalize_cuda_arch(self.architectures) -@clear_on_fresh_inductor_cache -@functools.lru_cache(None) +@clear_on_fresh_cache +@functools.cache def _gen_ops_cached(arch, version) -> dict[Any, Any]: # Note: Cache needs to be specific for cuda architecture and version @@ -305,16 +279,17 @@ def gen_ops() -> dict[Any, Any]: """ Generates all supported CUTLASS operations. """ - arch = get_cuda_arch() - version = get_cuda_version() - return _gen_ops_cached(arch, version) + with dynamo_timed("cutlass_utils.gen_ops"): + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) DTYPE_TO_CUTLASS_TYPE = { **DTYPE_TO_CPP, torch.float16: "__half", torch.bfloat16: "__nv_bfloat16", - torch.float8_e4m3fn: "cutlass::float_e4m3_t", + torch.float8_e4m3fn: "__nv_fp8_e4m3", } @@ -471,7 +446,9 @@ def __enter__(self, *args, **kwargs): _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile - def my_compile(source_code, dst_file_ext): + def my_compile( + source_code, dst_file_ext, extra_args: Optional[list[str]] = None + ): self.sources.append(source_code) return _compile_method_orig(source_code, dst_file_ext) diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 7a99a2f7379f2b..0ba0677422944a 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -5,10 +5,18 @@ import torch from ...utils import triton_version_uses_attrs_dict -from ..common import DeviceOpOverrides, register_device_op_overrides +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) class CUDADeviceOpOverrides(DeviceOpOverrides): + """ + CUDA-specific codegen functions, see DeviceOpOverrides for details + """ + def import_get_raw_stream_as(self, name: str) -> str: return f"from torch._C import _cuda_getCurrentRawStream as {name}" @@ -126,11 +134,17 @@ def kernel_driver(self) -> str: return source_codes def tma_descriptor_helpers(self) -> str: + """ + CUDA helper functions for initializing TMA Descriptors on host side + """ if torch.version.hip is not None: raise RuntimeError("Host-side TMA descriptors not supported on HIP.") # helper functions for initializing 1D and 2D TMA descriptors in C++. borrowed from the Triton code here: + # Old APIs (fill(1|2)DTMADescriptor): # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283 + # New APIs (fillTMADescriptor): + # https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L283 return """ #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 [[maybe_unused]] static void init1DTMADescriptor( @@ -225,6 +239,85 @@ def tma_descriptor_helpers(self) -> str: swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } + + [[maybe_unused]] static void initTMADescriptor( + CUtensorMap* m, + void* globalAddress, + int elemSize, + int rank, + uint32_t* blockSize, + uint64_t* shape, + uint64_t* stride + ) { + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + uint32_t blockSizeInt[5]; + uint64_t shapeInt[5]; + uint64_t stridesLL[5]; + + // Reorder blockSize (reverse the order) + for (int i = 0; i < rank; ++i) { + blockSizeInt[rank - i - 1] = blockSize[i]; + } + + // Reorder shape (reverse the order) + for (int i = 0; i < rank; ++i) { + shapeInt[rank - i - 1] = shape[i]; + } + + // Reorder and calculate strides + for (int i = 0; i + 1 < rank; ++i) { + stridesLL[rank - i - 2] = elemSize * stride[i]; + } + stridesLL[rank - 1] = + shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]); + + CUtensorMapDataType type; + // In Triton this is computed ahead of time; but for simplicity + // in the PyTorch version we copied this code from the old + // TMA API handling (i.e. init2DTMADescriptor) + switch (elemSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elemSize must be 1, 2, or 4"); + } + + // Calculate the size of the most contiguous dimension in bytes + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elemSize * blockSizeInt[0]; + if (rank == 1) { + // rank 1 should not be swizzled + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, + shapeInt, stridesLL, blockSizeInt, elementStrides, + CU_TENSOR_MAP_INTERLEAVE_NONE, (CUtensorMapSwizzle)swizzle, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + struct StableTMADescriptor { + CUtensorMap m; + uint32_t block_shape[5]; + uint64_t global_shape[5]; + uint64_t strides[5]; + }; #endif """ @@ -240,9 +333,33 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "CUdeviceptr" - def cpp_global_scratch(self, idx: int) -> Optional[tuple[str, str]]: + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: if triton_version_uses_attrs_dict(): - return f"CUdeviceptr global_scratch_{idx} = 0;", f"global_scratch_{idx}" + var_name = f"global_scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name return None diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 32eaaeedf2d9ef..ce47fdc810731c 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -10,9 +10,11 @@ import torch import torch.utils._pytree as pytree +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.select_algorithm import create_inputs_key -from torch._inductor.utils import clear_on_fresh_inductor_cache +from torch._inductor.utils import clear_on_fresh_cache from ... import ir from ...config import cuda as inductor_cuda_config @@ -25,7 +27,7 @@ Layout, ReinterpretView, ) -from ...utils import is_dynamic, OrderedSet, Placeholder +from ...utils import is_dynamic, Placeholder from ...virtualized import V from ..common import IndentedBuffer from . import cutlass_utils @@ -292,7 +294,7 @@ }; """ -# Additional includes which are neccessary if the standalone test / debug runner is generated as wel +# Additional includes which are necessary if the standalone test / debug runner is generated as well GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" #ifdef GENERATE_STANDALONE_RUNNER #include "cutlass/util/distribution.h" @@ -374,7 +376,7 @@ std::cout << "Calling once to get workspace size" << std::endl; {{test_call_statement}}; - // Allocate workspace if neccessary + // Allocate workspace if necessary if (workspace_size > 0) { workspace_data.reset(workspace_size); std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; @@ -404,7 +406,7 @@ """ # noqa: B950 -@clear_on_fresh_inductor_cache +@clear_on_fresh_cache class CUTLASSGemmTemplate(CUTLASSTemplate, ABC): """ CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels @@ -421,6 +423,7 @@ def __init__( alpha: float, beta: float, input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, ) -> None: """ Args: @@ -436,7 +439,8 @@ def __init__( ) self.alpha = alpha self.beta = beta - assert len(input_nodes) == 2 or len(input_nodes) == 3 or len(input_nodes) == 4 + self.use_fast_accum = use_fast_accum + assert 2 <= len(input_nodes) <= 5 assert self._are_inputs_layout_compatible( [node.get_layout() for node in input_nodes] ) @@ -452,6 +456,7 @@ def add_cutlass_gemm_choices( alpha: Union[float, int] = 1, beta: Union[float, int] = 0, input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, **extra_kwargs, ) -> None: raise NotImplementedError @@ -552,12 +557,16 @@ def _add_cutlass_gemm_choices( """ ops = self.gen_ops() - for name, op in ops: - for swizzle in inductor_cuda_config.cutlass_max_profiling_swizzle_options: - description = f"{name} swizzle={swizzle}" - self.maybe_append_choice( - choices, description=description, op=op, swizzle=swizzle - ) + with dynamo_timed("CUTLASSGemmTemplate.maybe_append_choice"): + for name, op in ops: + for ( + swizzle + ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, description=description, op=op, swizzle=swizzle + ) + if len(ops) == 0: input_layouts = [node.get_layout() for node in input_nodes] input_strides = [node.get_stride() for node in input_nodes] @@ -679,13 +688,13 @@ def should_swap_XW( ) -> bool: """ Helper method to determine whether we should do an explicit transpose by switching the order of the - matmul operands. This might be neccessary when we can't otherwise arrive at the right memory + matmul operands. This might be necessary when we can't otherwise arrive at the right memory layout for the given Bias operand. Note: This method is a workaround for CUDA Errors that seemingly non-deterministically occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. it might make sense to check on newer Cutlass releases whether it makes sense to keep - returning True in certain cases or whether it becomes unneccessary. + returning True in certain cases or whether it becomes unnecessary. """ # If bias is row major, swap all M and N dimensions if ( @@ -872,6 +881,11 @@ def filter_op( # TODO: update epilogue functor according to epilogues. op.element_epilogue = op.accumulator_type() + if self.use_fast_accum is not None: + is_op_fast_accum = "fastaccum" in op.configuration_name() + if self.use_fast_accum ^ is_op_fast_accum: + return None + # Set bias layout and alignment. status = self._set_bias_layout_and_alignment(op) if not status: @@ -930,8 +944,15 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: log.debug("Using cached ops for %s", self.cache_key) return self.filtered_ops_cache[self.cache_key] - full_ops = cutlass_utils.gen_ops() - ops = pytree.tree_flatten(full_ops)[0] + with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"): + maybe_ops = maybe_fetch_ops() + if maybe_ops is None: + log.debug("Cannot fetch ops from cache, generating ops from scratch") + full_ops = cutlass_utils.gen_ops() + ops = pytree.tree_flatten(full_ops)[0] + else: + log.debug("Using cached ops from cache") + ops = maybe_ops res: dict[str, cutlass_gemm_op.GemmOperation] = {} start_time = time.time() @@ -1042,8 +1063,9 @@ def render( # type: ignore[override] # to make op mutable without affecting others op = copy.deepcopy(op) - if Bias is not None: - assert Bias.get_layout().dtype == X.get_layout().dtype + is_scaled_mm = len(self.input_nodes) in (4, 5) + if Bias is not None and not is_scaled_mm: + assert Bias.get_dtype() == X.get_dtype() # This might have been set to void during filtering, when the assumption was still that there's no C # operand op.C.element = op.A.element @@ -1064,53 +1086,57 @@ def render( # type: ignore[override] op = self.swap_XW(op) should_swap_xw = True - is_scaled_mm = len(self.input_nodes) == 4 if epilogue_nodes or is_scaled_mm: if epilogue_nodes: ( - evt_read_names, - evt_write_names, + input_names, + output_names, var_name_to_buffer_name, evt_py_code, ) = CutlassEVTCodegen.ir_to_evt_python_code( Y.get_name(), epilogue_nodes, V.kernel.removed_buffers ) + D_output_name = var_name_to_buffer_name["D"] name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) D_output_buffer = name_to_buffer[D_output_name] - D_dtype = D_output_buffer.get_dtype() Y = D_output_buffer # type: ignore[assignment] # Interestingly, I don't think the rest of the layout matters here since we # use the properties of the Y buffer to fill in D's properties in the epilogue # args. This is needed though because it defines types expected in the epilogue args. op.D.element = cutlass_utils.torch_dtype_to_cutlass_type( - D_output_buffer.get_layout().dtype + D_output_buffer.get_dtype() ) - read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names) - write_names = OrderedSet(evt_write_names) - assert write_names, "There should be at least one write" + assert output_names, "There should be at least one write" - input_names = list(read_names) - output_names = list(write_names) epilogue_inputs = [name_to_buffer[name] for name in input_names] - epilogue_outputs = [name_to_buffer[name] for name in output_names] - else: # Scaled MM, we read the two scale matrices and write a single output + outputs = [name_to_buffer[name] for name in output_names] + else: # Scaled MM, we read the two scale matrices (and optional bias) and write a single output + bias = None if len(self.input_nodes) < 5 else self.input_nodes[4] + bias_name = bias.get_name() if bias else None + ( evt_read_names, var_name_to_buffer_name, evt_py_code, ) = scaled_mm_evt( - self.input_nodes[2].get_name(), - self.input_nodes[3].get_name(), + self.input_nodes[2].get_name(), # scale_A + self.input_nodes[3].get_name(), # scale_B + bias_name, Y.get_name(), ) input_names = list(evt_read_names) output_names = [] # We only need Y - D_dtype = Y.get_layout().dtype epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]] - epilogue_outputs = [] + if bias: + epilogue_inputs.append(bias) + outputs = [] acc_dtype = cutlass_utils.get_accumulator_dtype( [X.get_dtype(), W.get_dtype()] @@ -1121,7 +1147,7 @@ def render( # type: ignore[override] op, evt_py_code, var_name_to_buffer_name, - D_dtype, + Y.get_dtype(), acc_dtype, ) @@ -1138,15 +1164,13 @@ def render( # type: ignore[override] ) else: evt_name = None - epilogue_inputs = [] - epilogue_outputs = [Y] + outputs = [Y] evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" evt_code = "" kernel_call_signature = kernel.def_kernel( inputs=inputs, # type: ignore[arg-type] - outputs=epilogue_outputs, # type: ignore[arg-type] - epilogue_inputs=[], + outputs=outputs, # type: ignore[arg-type] names_str=names_str, input_reorder=input_reorder, ) @@ -1237,8 +1261,11 @@ def __init__( alpha: float, beta: float, input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, ): - super().__init__(input_nodes, layout, alpha, beta, input_reorder) + super().__init__( + input_nodes, layout, alpha, beta, input_reorder, use_fast_accum + ) @staticmethod def add_cutlass_gemm_choices( @@ -1248,10 +1275,16 @@ def add_cutlass_gemm_choices( alpha: Union[float, int] = 1, beta: Union[float, int] = 0, input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, **extra_kwargs, ) -> None: template = CUTLASS3xGemmTemplate( - input_nodes, layout, alpha, beta, input_reorder + input_nodes, + layout, + alpha, + beta, + input_reorder, + use_fast_accum, ) template._add_cutlass_gemm_choices( choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs @@ -1308,7 +1341,7 @@ def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: Returns: bool: True if layouts are GEMM compatible, otherwise False. """ - assert len(layouts) == 2 or len(layouts) == 3 or len(layouts) == 4 + assert 2 <= len(layouts) <= 5 # Check if A and B are compatible A_layout, B_layout = layouts[:2] if len(A_layout.size) < 1: @@ -1376,6 +1409,12 @@ def _render_evt( from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + # handle the fake output buffer during lowering name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment] @@ -1384,6 +1423,7 @@ def _render_evt( examples = create_example_tensors( var_name_to_buffer_name, name_to_buffer, # type: ignore[arg-type] + V.graph.sizevars.size_hint, ) evt_name, evt_args, evt_code = trace( evt_py_code, @@ -1393,6 +1433,7 @@ def _render_evt( op.tile_description, # type: ignore[attr-defined] op.epilogue_schedule, # type: ignore[attr-defined] {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] + V.graph.sizevars.size_hint, ) return ( @@ -1488,7 +1529,7 @@ def _get_extra_inputs_and_names( self, op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: - Bias = None if len(self.input_nodes) in (2, 4) else self.input_nodes[2] + Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None inputs: list[Optional[Buffer]] = [] names: list[str] = [] return (Bias, inputs, names) @@ -1618,6 +1659,7 @@ def add_cutlass_gemm_choices( alpha: Union[float, int] = 1, beta: Union[float, int] = 0, input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = False, **extra_kwargs, ) -> None: template = CUTLASS2xGemmTemplate( diff --git a/torch/_inductor/codegen/cuda/serialization.py b/torch/_inductor/codegen/cuda/serialization.py index 2d1c639fdff720..82fe188c09e843 100644 --- a/torch/_inductor/codegen/cuda/serialization.py +++ b/torch/_inductor/codegen/cuda/serialization.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import enum +import functools import json from enum import Enum from typing import Optional @@ -458,6 +459,7 @@ def _json_to_enum(cls, json_dict, enum_class): return enum_class[json_dict["name"]] +@functools.lru_cache(1) def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]: if not try_import_cutlass(): return None diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index f51ee70b73bc76..0d979eeed83faf 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -643,12 +643,12 @@ def eq(left, right): if V.graph.sizevars.statically_known_equals(left, right): return True try: - a = V.graph.sizevars.size_hint(left) - b = V.graph.sizevars.size_hint(right) + a = V.graph.sizevars.size_hint_or_throw(left) + b = V.graph.sizevars.size_hint_or_throw(right) except TypeError: # unbacked symints return False if a == b: - V.graph.sizevars.guard_equals(left, right) + V.graph.sizevars.check_equals(left, right) return a == b @@ -656,15 +656,15 @@ def lt(left, right): if V.graph.sizevars.statically_known_lt(left, right): return True try: - a = V.graph.sizevars.size_hint(left) - b = V.graph.sizevars.size_hint(right) + a = V.graph.sizevars.size_hint_or_throw(left) + b = V.graph.sizevars.size_hint_or_throw(right) except TypeError: # unbacked symints gcd = sympy.gcd(left, right) if gcd == left: return left != right return False if a < b: - V.graph.sizevars.guard_lt(left, right) + V.graph.sizevars.check_lt(left, right) return a < b @@ -1447,7 +1447,7 @@ def halide_kernel_meta(self) -> HalideMeta: current_device = V.graph.get_current_device_or_throw() if current_device.type == "cpu": target = [config.halide.cpu_target] - schduler = config.halide.scheduler_cpu + scheduler = config.halide.scheduler_cpu scheduler_flags = { "parallelism": parallel_num_threads(), } @@ -1456,7 +1456,7 @@ def halide_kernel_meta(self) -> HalideMeta: assert current_device.type == "cuda", "only cpu/cuda supported" assert current_device.index <= 0, "only default device supported" target = [config.halide.gpu_target] - schduler = config.halide.scheduler_cuda + scheduler = config.halide.scheduler_cuda capability = torch.cuda.get_device_properties(current_device) if "cuda_capability" not in target[0]: for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: @@ -1490,7 +1490,7 @@ def halide_kernel_meta(self) -> HalideMeta: return HalideMeta( argtypes, target="-".join(target), - scheduler=schduler, + scheduler=scheduler, scheduler_flags=scheduler_flags, # type: ignore[arg-type] cuda_device=cuda_device, ) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 3a3fb24f6cbd32..f8176c191fd485 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -6,14 +6,16 @@ import itertools import logging import math +from pathlib import Path from typing import Any, Optional, TYPE_CHECKING import sympy from sympy.printing.precedence import PRECEDENCE import torch +from torch.utils._cpp_embed_headers import _embed_headers from torch.utils._ordered_set import OrderedSet -from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_ +from torch.utils._sympy.printers import CppPrinter, ExprPrinter as ExprPrinter_ from torch.utils._sympy.value_ranges import ValueRanges from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata @@ -158,7 +160,7 @@ def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: class MetalOverrides(OpOverrides): - """Implements Metal-specific overrids for ops. Base class emits Python-friendly overrides""" + """Implements Metal-specific overrides for ops. Base class emits Python-friendly overrides.""" @staticmethod def to_dtype( @@ -209,24 +211,7 @@ def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str: @staticmethod def remainder(a: OpVarT, b: OpVarT) -> str: - if ( - isinstance(b, CSEVariable) - and b.dtype is not None - and not b.dtype.is_floating_point - ): - return f"{a} % {b}" - # Upcast to float otherwise results of remainder op are wrong for half - float_a = ( - f"static_cast({a})" - if isinstance(a, CSEVariable) and a.dtype != torch.float - else a - ) - float_b = ( - f"static_cast({b})" - if isinstance(b, CSEVariable) and b.dtype != torch.float - else b - ) - return f"{float_a} - {float_b} * metal::floor({float_a} / {float_b})" + return f"c10::metal::remainder({a}, {b})" @staticmethod def maximum(a: CSEVariable, b: CSEVariable) -> str: @@ -448,6 +433,10 @@ def _initialize_special_ops(cls) -> None: "chebyshev_polynomial_w", "hermite_polynomial_h", "hermite_polynomial_he", + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", ]: setattr( cls, @@ -469,6 +458,7 @@ class MetalKernel(SIMDKernel): max_threadgroup_size = 1024 simd_group_size = 32 pexpr = PythonPrinter().doprint + cexpr = CppPrinter().doprint sexpr = MetalExprPrinter().doprint kexpr = sexpr headers: OrderedSet[str] = OrderedSet(["utils"]) @@ -492,9 +482,9 @@ def load(self, name: str, index: sympy.Expr) -> CSEVariable: dtype = V.graph.get_dtype(name) line = f"{var}[{self.index_to_str(index)}]" if dtype in [torch.float16, torch.bfloat16]: - # TODO(NS): Figure out the right balance betwene optype casts + # TODO(NS): Figure out the right balance between optype casts # op_math_t for half-precision floats should be float32 - # Otherwise it can lead to a corretness issues with eager + # Otherwise it can lead to a correctness issues with eager line = f"static_cast({line})" dtype = torch.float32 return self.cse.generate(self.loads, line, dtype=dtype) @@ -580,6 +570,12 @@ def _reduction_nocache( assert self.inside_reduction assert not self._load_mask + def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: + # Uwraps vec3 dtype into individual components + return OpsWrapper._unwrap( + [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"] + ) + # Establish reduction buffer size and index expression reduction_idx = "" acc_buf_size = 1 @@ -683,8 +679,9 @@ def _reduction_nocache( wf_res = self.cse.generate( self.compute, f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, ) - return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")) + return _unwrap_helper(wf_res) acc_buf = self._new_idxvar("float3", acc_buf_size) acc_thread_var = f"{acc_buf}[{reduction_idx}]" self.indexing_code.splice(f"{acc_thread_var} = 0.0;") @@ -694,8 +691,9 @@ def _reduction_nocache( wf_res = self.cse.generate( self.stores, f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})", + dtype=torch.float32, ) - return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")) + return _unwrap_helper(wf_res) if reduction_type == "welford_combine": assert isinstance(value, tuple), "Input to welford combine must be tuple" acc_buf = self._new_idxvar("float3", acc_buf_size) @@ -712,21 +710,22 @@ def _reduction_nocache( wf_res = self.cse.generate( self.stores if self.multistage_reduction_entry else self.compute, f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, ) - return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")) + return _unwrap_helper(wf_res) raise NotImplementedError(reduction_type) def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: index_expr = self.rename_indexing(entry.expr) index_str = self.sexpr(index_expr) # type: ignore[misc] - if not entry.is_reduction or entry.root.numel < self.max_threadgroup_size: + if not entry.is_reduction or entry.root.numel <= self.max_threadgroup_size: self.indexing_code.writeline( f"{self.index_dtype} {entry.name} = {index_str};" ) return self.multistage_reduction_entry.append(entry) - # When reducing the thensor whose size exceeds max threadgroup size + # When reducing the tensor whose size exceeds max threadgroup size # loop over extra indices per reduction thread and perform part of the operation # using values in the shared memory loop_size = ( @@ -759,7 +758,16 @@ def codegen_body(self) -> None: self.body.splice(self.compute) self.body.writeline("}" * len(self.multistage_reduction_entry)) # Invalidate variables instantiated inside loop - self.cse.invalidate(OrderedSet(self.cse.reduction_cache.values())) + # But results of reduction alive. Reduction cache values can be + # either CSEVariable or tuple of CSEVariables, in which case all + # variables in the tuple must be preserved + self.cse.invalidate( + OrderedSet( + v + for item in self.cse.reduction_cache.values() + for v in (item if isinstance(item, tuple) else (item,)) + ) + ) # And loop codegen while self.multistage_reduction_entry: self.multistage_reduction_entry.pop().cache_clear() @@ -786,6 +794,17 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: if not V.graph.cpp_wrapper: for header in self.headers: code.writeline(f"#include ") + else: + headers = [ + f"#include " for header in self.headers + ] + header_contents = _embed_headers( + headers, + [Path(__file__).parent.parent.parent / "include"], + OrderedSet(), # type: ignore[arg-type] + ) + code.writeline(header_contents) + if self.inside_reduction: total_reduction_size = math.prod( t.numel for t in self.range_trees if t.is_reduction @@ -850,18 +869,34 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: def call_kernel(self, name: str, node: Any = None) -> None: """Codegen a call to this kernel""" wrapper = V.graph.wrapper_code - # Make sure sizevarss has been computed + # Make sure sizevars has been computed for v in self.args.sizevars.keys(): wrapper.ensure_size_computed(v) + _, call_args, _, arg_types = self.args.python_argdefs() + arg_name_to_type = { + str(call_arg): arg_type for call_arg, arg_type in zip(call_args, arg_types) + } + args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] args = [arg for arg in args if arg not in self.removed_buffers] args += [str(v) for v in self.args.sizevars.keys()] - # For reduction kernels, limit the maximum size over reduction dimentions to + + arg_types = [arg_name_to_type[arg] for arg in args] + expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr + + def format_threads(threads: list[str], kwarg: str) -> str: + if V.graph.cpp_wrapper: + threads = [f"static_cast({t})" for t in threads] + return f"{{{', '.join(threads)}}}" + else: + return f"{kwarg}=[{', '.join(threads)}]" + + # For reduction kernels, limit the maximum size over reduction dimensions to # a maximum threadgroup size if len(self.active_range_trees()) > 0: threads = [ - self.pexpr( + expr_printer( sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc] if v.is_reduction else v.numel @@ -869,36 +904,34 @@ def call_kernel(self, name: str, node: Any = None) -> None: for v in self.active_range_trees() ] - if V.graph.cpp_wrapper: - args += [f"{', '.join(threads)}"] - else: - args += [f"threads=[{', '.join(threads)}]"] + args.append(format_threads(threads, "threads")) + arg_types.append(list) else: if V.graph.cpp_wrapper: raise RuntimeError("We should always have threads?") if self.inside_reduction: threads = [ - self.pexpr(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc] + expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc] if v.is_reduction else "1" for v in self.active_range_trees() ] - if V.graph.cpp_wrapper: - args += [f"{{{', '.join(threads)}}}"] - else: - args += [f"group_size=[{', '.join(threads)}]"] + args.append(format_threads(threads, "group_size")) + arg_types.append(list) else: if V.graph.cpp_wrapper: # Add a None so that we always have a group_size in the # arguments. We won't use it if the value is None. args += [None] # type: ignore[list-item] + arg_types.append(None) wrapper.generate_kernel_call( name, args, device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now triton=False, + arg_types=arg_types, ) def check_bounds( @@ -943,15 +976,18 @@ def define_kernel( mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" if V.graph.cpp_wrapper: - src_code = ( - f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" - + src_code - ) kernel_name = f"{mps_lib_name}_func" else: kernel_name = f"{mps_lib_name}.generated_kernel" wrapper.src_to_kernel[src_code] = kernel_name + + if V.graph.cpp_wrapper: + src_code = ( + f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" + + src_code + ) + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) metadata_comment = f"{origins}\n{detailed_origins}" wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 27d8c9b2afc18b..7178cf7cc195d2 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -18,65 +18,6 @@ log = logging.getLogger(__name__) -def get_kernel_argdefs(kernel): - arg_defs, _, _, _ = kernel.args.python_argdefs() - return [x.name for x in arg_defs] - - -def _get_all_args(args_list, arg_types_list=None): - all_args = max(args_list, key=len)[:] - arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None - for args in args_list: - assert OrderedSet(args).issubset(OrderedSet(all_args)), ( - f"{args} v.s. {all_args}" - ) - - return all_args, arg_types - - -def get_all_kernel_argdefs(kernels): - """ - The logic here must match with `get_all_call_args`, except no need to get arg_types here - """ - argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels] - - return _get_all_args(argdefs_list)[0] - - -def get_all_call_args(call_args_list, arg_types_list): - """ - Passed in the call_args for each subkernel and return the call_args for the - combined multi-kernel. - - Note an algorithm as follows does not always work: - ``` - all_call_args: Dict[ - Any, None - ] = {} # use a dict rather than set to maintain insertion order - for call_args in call_args_list: - all_call_args.update({arg: None for arg in call_args}) - - all_call_args = list(all_call_args.keys()) - ``` - It will fail if any kernel has the same argument passed in multiple times. - Check test_pass_same_arg_multi_times in test_multi_kernel.py - - Instead, we pick the longest call args and assert that other call args are - a subset of it. - """ - return _get_all_args(call_args_list, arg_types_list) - - -def get_numel_argdefs(kernel): - numel_argdefs = [ - f"{tree.prefix}numel" - for tree in kernel.range_trees - if not tree.is_reduction or kernel.inside_reduction - ] - - return numel_argdefs - - class MultiKernelState: """ Maintain state of multi-kernel compilation so we don't define duplicated @@ -147,7 +88,7 @@ class MultiKernel: ) ``` - Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 """ def __init__(self, kernels): @@ -224,7 +165,7 @@ def call_kernel(self, kernel_name): def codegen_nan_check(self): wrapper = V.graph.wrapper_code - seen = OrderedSet[str]() + seen: OrderedSet[str] = OrderedSet() for k in self.kernels: _, call_args, precompile_args, _ = k.args.python_argdefs() for arg, precompile_arg in zip(call_args, precompile_args): diff --git a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py index deb43c98be8adb..5862534ce6cc19 100644 --- a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -3,12 +3,15 @@ import logging import random from dataclasses import asdict, dataclass +from typing import Any import torch from torch._inductor import config from torch._inductor.codegen.rocm.ck_tile_template import CKTileTemplate from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.codegen.rocm.rocm_template import ArgInfo from torch._inductor.ir import Buffer, Layout +from torch.utils._ordered_set import OrderedSet from ...utils import IndentedBuffer @@ -92,7 +95,7 @@ def dict_items(self): return asdict(self).items() -@functools.lru_cache(None) +@functools.cache def ops(): """ Generate the supported instance dataclasses @@ -216,7 +219,9 @@ def ops(): for epilogue in ["Default", "CShuffle"] ] - return itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) class CKTileGemmTemplate(CKTileTemplate): @@ -231,8 +236,6 @@ class CKTileGemmTemplate(CKTileTemplate): extern "C" { PT_EXPORT {{kernel_definition}} { - constexpr int32_t kBatch = {{k_batch}}; - using {{instance_namespace}}::BaseGemmPipeline; using {{instance_namespace}}::TilePartitioner; @@ -297,7 +300,6 @@ def __init__( input_nodes=input_nodes, layout=layout, ) - self.k_batch = 1 def header(self) -> IndentedBuffer: res = super().header() @@ -437,8 +439,8 @@ def check(dim_size, tile_size, is_padded): return True if op.layout_a == "Row": - if not check(K, op.tile_k * self.k_batch, op.k_is_padded): - return False + # handle in kBatch check + return True elif op.layout_a == "Col": if not check(M, op.tile_m, op.m_is_padded): return False @@ -449,8 +451,8 @@ def check(dim_size, tile_size, is_padded): if not check(N, op.tile_n, op.n_is_padded): return False elif op.layout_b == "Col": - if not check(K, op.tile_k * self.k_batch, op.k_is_padded): - return False + # handle in kBatch check + return True else: raise AssertionError(f"Invalid {op.layout_b=}") @@ -847,7 +849,6 @@ def render_dispatch(pipeline_type, op_name): instance_namespace=op.name(), version_comment=version_comment, rendered_dispatch=render_dispatch(op.pipeline, op.name()), - k_batch=self.k_batch, ) def gen_ops(self): @@ -858,7 +859,13 @@ def gen_ops(self): An instance may invalidate the GEMM configuration at runtime. Such instances will be assigned +inf runtime by the autotune process. """ - filtered_instances = list(filter(self.filter_op, ops())) + instances = ops() + if not instances: + raise AssertionError( + "No Composable Kernel Universal GEMM instances found. " + "Please check if the library is installed." + ) + filtered_instances = list(filter(self.filter_op, instances)) # NB: when using a fixed list order, most likely we will pick the subset of instances # which are very similar to each other. Randomizing the choice seems to solve this. random.seed(-11) @@ -871,7 +878,7 @@ def gen_ops(self): else filtered_instances ) log.debug( - "generated %d ck instances after filter: %s", + "generated %d ck instances after sample: %s", len(chosen_instances), chosen_instances, ) @@ -892,10 +899,43 @@ def add_choices( ) ops = template.gen_ops() for op in ops: - template.maybe_append_choice( - choices, - op=op, + for k_batch in template.k_batch_choices(op): + template.maybe_append_choice( + choices, + op=op, + kBatch=k_batch, + ) + + def k_batch_choices(self, op: "CKTileGemmOperation") -> tuple[int, ...]: + """ + Returns a list of k_batch choices for the template. + """ + default_choices = (1, 2, 4, 8, 16, 32) + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + _, _, K, _, _, _ = self.size_args() + if op.layout_a == "Row" or op.layout_b == "Col": + choices = tuple( + filter( + lambda k_batch: check(K, op.tile_k * k_batch, op.k_is_padded), + default_choices, + ) ) + else: + choices = default_choices + + if op.epilogue == "Default": + choices = (1,) + + return choices def size_args(self): """ @@ -913,3 +953,15 @@ def size_args(self): LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] return M, N, K, LDA, LDB, LDC + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + # maybe_append_choice kwarg for k_batch must match the name of the argument + arg_names = OrderedSet([arg.name for arg in self.get_runtime_arg_info()]) + if not arg_names.issubset(kwargs): + raise ValueError( + "Missing runtime arguments: " + ", ".join(arg_names - kwargs.keys()) + ) + return [kwargs[k] for k in arg_names] diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index 720509f282660b..9288f73954ff3b 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -86,7 +86,7 @@ def codegen_template( _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) - kernel, render = ctb.make_kernel_render(ctb) + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] with kernel: template_node.mark_run() src_code = render() diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index 2e51a96d327921..5b90823b7f41c9 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -7,7 +7,15 @@ from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.utils import do_bench_using_profiling -from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox +from ...ir import ( + Buffer, + ChoiceCaller, + IRNode, + Layout, + PrimitiveInfoType, + ShapeAsConstantBuffer, + TensorBox, +) from ...virtualized import V from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode from ..cpp_utils import CppPrinter @@ -276,7 +284,7 @@ def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] } - def output_node(self) -> TensorBox: + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: self.bmreq.update_workspace_size() return TensorBox.create( ROCmTemplateBuffer( diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 8f4dbda0fdaa03..bb8f0e37604d45 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -18,6 +18,7 @@ import torch import torch._logging +from torch._inductor.tiling_utils import analyze_memory_coalescing from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.immutable_collections import immutable_dict from torch.utils._ordered_set import OrderedSet @@ -62,6 +63,7 @@ from .simd_kernel_features import ( DisableReduction, EnableReduction, + NodeScheduleEntry, NodeScheduleMarker, SIMDKernelFeatures, ) @@ -70,6 +72,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence + from torch._inductor.tiling_utils import CoalesceVarAnalysis + log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -82,6 +86,11 @@ all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"]) +def get_max_tiles(default: int = 2) -> int: + max_tiles = torch._inductor.config.triton.max_tiles + return max_tiles if max_tiles is not None else default + + @dataclasses.dataclass class IterationRanges: """ @@ -142,7 +151,7 @@ def symt(self) -> SymT: class IterationRangesRoot(IterationRanges): """ Root of a iteration range tree that represents a single - tiled dimension in the output kernel. It contains muliple + tiled dimension in the output kernel. It contains multiple sets of iteration represented with IterationRangesEntry. """ @@ -376,6 +385,7 @@ def __init__( pid_cache: Optional[dict[str, str]] = None, override_persistent_reduction: Optional[bool] = None, override_cooperative_reduction: Optional[bool] = None, + tiling_scores: Optional[dict[str, sympy.Expr]] = None, ) -> None: if pid_cache is None: pid_cache = {} @@ -396,6 +406,7 @@ def __init__( if override_cooperative_reduction is not None else self.should_use_cooperative_reduction() ) + self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores self.persistent_reduction: bool = ( override_persistent_reduction if override_persistent_reduction is not None @@ -405,7 +416,7 @@ def __init__( self.code_hash: Optional[str] = None # define this in a closure to make cache local to object - @functools.lru_cache(None) + @functools.cache def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: @@ -683,6 +694,7 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: size, remaining[current_group] ): raise CantSplit + size1 = remaining[current_group] size2 = FloorDiv(size, remaining[current_group]) return_getters.append( @@ -706,21 +718,33 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: return new_ranges, return_getters_groups @classmethod - def is_compatible( + def prepare_split_iteration_lengths( cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]], reduction_numel: sympy.Expr = sympy.S.One, - ) -> bool: - # Fill in the reduction numel, in case the node is missing it. + ) -> Sequence[Sequence[sympy.Expr]]: + "Fill in the reduction numel of lengths if missing" sizevars = V.graph.sizevars if len(lengths[1]) == 0 and ( - sizevars.statically_known_equals( + not sizevars.statically_known_equals(reduction_numel, sympy.S.One) + and sizevars.statically_known_equals( sympy_product(groups), sympy_product(lengths[0]) * reduction_numel, ) ): - lengths = (lengths[0], [reduction_numel]) + return (lengths[0], [reduction_numel]) + + return lengths + + @classmethod + def is_compatible( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> bool: + lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel) try: cls._split_iteration_ranges(groups, lengths) @@ -731,13 +755,35 @@ def is_compatible( def split_and_set_ranges( self, lengths: Sequence[Sequence[sympy.Expr]] ) -> list[list[sympy.Expr]]: + """ + Split and set iteration ranges for the kernel based on the provided lengths. + + This method maps the kernel's tiling structure to the node's iteration space, + handling both pointwise and reduction dimensions appropriately. + + Args: + lengths: A sequence of sequences of symbolic expressions representing + the sizes of different dimensions for each node. + + Returns: + A list of lists of symbolic expressions representing the mapped + iteration variables for each dimension. + """ + # Create a dictionary mapping each range tree prefix to its total number of elements tiling = {rt.prefix: rt.numel for rt in self.range_trees} + + # If we're not inside a reduction loop, set all reduction dimensions to 1 + # This effectively disables reduction dimensions when not needed if not self.inside_reduction: for prefix in tiling: if prefix_is_reduction(prefix): tiling[prefix] = sympy.S.One + # Extract the values from the tiling dictionary to create groups groups = [*tiling.values()] + + # Map the kernel's group structure to the node's sizes and set the ranges + # using the set_ranges method, returning the resulting iteration variables return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) @classmethod @@ -1088,6 +1134,11 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): class SIMDScheduling(BaseScheduling): + """ + Single Instruction Multiple Data parent class used for fusion across + multiple different backends. + """ + kernel_type: type[Any] = SIMDKernel # override in subclass def group_fn(self, sizes): @@ -1230,8 +1281,8 @@ def generate_node_schedule(self, nodes, numel, rnumel): done = OrderedSet[scheduler.BaseSchedulerNode]() # Writes with a reduced shape, meaning they are only present once the # reduction loop has ended - not_ready_yet_nodes = OrderedSet[str]() - current_loop_buffer_usage = OrderedSet[str]() + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() maybe_split_index: Optional[int] = None def fits_in_main_body(n): @@ -1330,13 +1381,17 @@ def codegen_node( nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + if torch._inductor.config.triton.coalesce_tiling_analysis: + coalesce_analysis = analyze_memory_coalescing(node) + else: + coalesce_analysis = None _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group node_schedule = self.generate_node_schedule(nodes, numel, rnumel) schedule_log.debug("Schedule:\n %s", node_schedule) return self.codegen_node_schedule( - SIMDKernelFeatures(node_schedule, numel, rnumel) + SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis) ) @staticmethod @@ -1364,18 +1419,24 @@ def can_use_32bit_indexing( # Only install guards for 32-bit indexing as there is no correctness # issue with using 64-bit for everything - V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + V.graph.sizevars.check_leq(numel, int_max) # type: ignore[arg-type] for size in buf_sizes: - V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + V.graph.sizevars.check_leq(size, int_max) # type: ignore[arg-type] return True def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): node_schedule = kernel_features.node_schedule - tiling = self.select_tiling( - node_schedule, kernel_features.numel, kernel_features.reduction_numel + + tiling, tiling_score = self.get_tiling_and_scores( + node_schedule, + kernel_features.numel, + kernel_features.reduction_numel, + kernel_features.coalesce_analysis, ) kernels = self.create_kernel_choices( - kernel_features, [tiling], {"features": kernel_features} + kernel_features, + [tiling], + {"features": kernel_features, "tiling_scores": tiling_score}, ) for kernel in kernels: self.codegen_node_schedule_with_kernel(node_schedule, kernel) @@ -1416,7 +1477,7 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove if ( - V.graph.wrapper_code.supports_intermediate_hooks + V.graph.wrapper_code.supports_intermediate_hooks # type: ignore[has-type] and config.generate_intermediate_hooks ): # Not every node in the schedule will actually be live on output; @@ -1531,7 +1592,7 @@ def codegen_template( p_n.can_codegen_without_upcasts() for p_n in prologue_group ) - # TODO - this doesnt work with libdevice calls, potentially other bugs + # TODO - this doesn't work with libdevice calls, potentially other bugs # upcasting to fp32 and downcasting gives large slowdown with config.patch( "triton.codegen_upcast_to_fp32", not can_codegen_without_upcast @@ -1869,7 +1930,7 @@ def get_nd_tilings( reduction_numel, ) -> list[dict[str, tuple[sympy.Expr]]]: """ - Creates N-dimensional tiling candidiates, attempting to simplify loads/stores + Creates N-dimensional tiling candidates, attempting to simplify loads/stores by tiling the kernel into higher dimensions. Returns a list of tilings ranked by dimensionality. @@ -1959,7 +2020,7 @@ def get_nd_tilings( # Flatten leading dimensions, assigning labels to each dim. for node_tiling in node_tilings: - num_leading_dims = max(0, len(node_tiling) - config.triton.max_tiles) + num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2)) first_trailing_dim = num_leading_dims + 1 collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim]) collapsed_splits = (collapsed_leading_dim,) + tuple( @@ -1983,10 +2044,265 @@ def get_nd_tilings( return ranked_tilings + @classmethod + def compute_tiling_strategy( + cls, + node_schedule: list[NodeScheduleEntry], + pointwise_numel: sympy.Expr, + reduction_numel: sympy.Expr, + coalesce_analysis: CoalesceVarAnalysis, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses. + """ + tiling_var: Optional[sympy.Expr] = ( + None + if not coalesce_analysis.suggested_split + else coalesce_analysis.suggested_split.var + ) + + all_iter_vars = coalesce_analysis.norm_read_writes.index_vars + all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars + ranges = coalesce_analysis.norm_read_writes.var_ranges + + pw_ranges = [ranges[v] for v in all_iter_vars] + red_ranges = [ranges[v] for v in all_red_vars] + + torch._check( + sympy_product(pw_ranges) == pointwise_numel, + lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}", + ) + torch._check( + sympy_product(red_ranges) == reduction_numel, + lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}", + ) + + # score of a pointwise or reduction split + scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {} + + score_split: list[ + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] + ] = [] + + def process_node_vars( + vars_to_use: tuple[sympy.Expr, ...] = (), + use_split_var: bool = False, + is_pointwise: bool = False, + ) -> tuple[list[int], list[int]]: + """ + Generate a tiling, and a tiling score, given vars to use as splits. + """ + + ranges = pw_ranges if is_pointwise else red_ranges + target_numel = pointwise_numel if is_pointwise else reduction_numel + # Some kernels have no reduction ranges, and a reduction numel of 1 + if not ranges: + if target_numel: + return ([target_numel], []) + else: + return ([], []) + + key = (repr(vars_to_use), use_split_var, is_pointwise) + if out := scored_sub_split.get(key, None): + return out + + splitting_vars = all_iter_vars if is_pointwise else all_red_vars + + splits = [] + split_scores = [] + prod = 1 + prev_var_coalesced_score = 0 + + # iterate from non-dense to dense + for v, v_range in zip(splitting_vars, ranges): + if v not in vars_to_use: + prod *= v_range + prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get( + v, 0 + ) + continue + + if use_split_var and v == tiling_var: + var_tiling = coalesce_analysis.suggested_split + assert var_tiling is not None + + tile = var_tiling.tiling_factor + remainder = FloorDiv(v_range, var_tiling.tiling_factor) + + splits.append(prod * remainder) + split_scores.append(var_tiling.score) + + splits.append(tile) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + + prod = 1 + prev_var_coalesced_score = 0 + + continue + + prod *= v_range + splits.append(prod) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + prod = 1 + + if prod != 1 or (is_pointwise and len(splits) == 0): + splits.append(prod) + split_scores.append(prev_var_coalesced_score) + + # penalize splits that leave small blocks + # where we can't fully utilize full memory transaction + # TODO: incorporate exact bitwidth, and read/write + # coalesced write is 2x more important + for i in range(len(splits)): + s = V.graph.sizevars.size_hint(splits[i], fallback=32) + s = min(s, 8) + split_scores[i] = int(split_scores[i] * s / 8) + + scored_sub_split[key] = (splits, split_scores) + return (splits, split_scores) + + # add the default tiling + score_split.append( + ( + process_node_vars(is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if tiling_var: + score_split.append( + ( + process_node_vars( + (tiling_var,), use_split_var=True, is_pointwise=True + ), + process_node_vars(is_pointwise=False), + ) + ) + + # TODO, add tests, reduction splits if config.triton.tile_reductions + # TODO: we should ignore tiny increases in score for extra splits + overlapping_iter_vars = ( + all_iter_vars & coalesce_analysis.coalesced_by_var.keys() + ) + for v in overlapping_iter_vars: + score_split.append( + ( + process_node_vars((v,), is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if get_max_tiles(default=3) == 3 and reduction_numel == 1: + for vars_to_use in itertools.combinations(overlapping_iter_vars, 2): + score_split.append( + ( + process_node_vars(vars_to_use, is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [] + for (pw_split, pw_score), (red_split, red_score) in score_split: + candidate = CandidateTiling( + cls.create_tiling(pw_split, red_split), + score=sum(pw_score) + sum(red_score), + ) + tiling_score = cls.create_tiling(pw_score, red_score) + tilings.append((candidate, tiling_score)) + + default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel]) + + # add a slight penalty for longer tilings that dont increase score much, + # and are poor sizes + bad_size_additional_tiling_penalty = 1.025 + good_size_tiling_penalty = 1.005 + + def score_mod(t): + score_factor = 1.0 + for tile_size in t[0].tiling.values(): + if not CandidateTiling.is_good_size(tile_size): + score_factor = score_factor / bad_size_additional_tiling_penalty + else: + score_factor = score_factor / good_size_tiling_penalty + + return -t[0].score * score_factor + + # apply penalty for longer tilings that dont increase score much + for cand, tiling_score in sorted(tilings, key=score_mod): + if cls.tiling_is_compatible( + node_schedule, pointwise_numel, reduction_numel, cand.tiling + ): + # we always include default reduction numel == 1, dont include + tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0) + if tiling_len > get_max_tiles(default=3): + perf_hint_log.info( + "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles " + "set to %s. Consider increasing", + tiling_len, + torch._inductor.config.triton.max_tiles, + ) + continue + + return cand.tiling, tiling_score + + # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible` + # TODO - look into, occurs with dynamic shapes often + if cand.tiling == default_tiling: + return cand.tiling, tiling_score + + return default_tiling, None + + @classmethod + def tiling_is_compatible( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + tiling: dict[str, sympy.Expr], + ): + assert isinstance(tiling, dict) + return all( + SIMDKernel.is_compatible( + tiling.values(), node.get_ranges(), reduction_numel=reduction_numel + ) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ) + + @classmethod + def get_first_compatible_tiling( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ranked_tilings: list[dict[str, sympy.Expr]], + ): + for tiling in ranked_tilings: + if cls.tiling_is_compatible(node_schedule, numel, reduction_numel, tiling): + return tiling + + return None + @classmethod def select_tiling( - cls, node_schedule, numel, reduction_numel=sympy.S.One + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, ) -> dict[str, sympy.Expr]: + return cls.get_tiling_and_scores( + node_schedule, numel, reduction_numel, coalesce_analysis + )[0] + + @classmethod + def get_tiling_and_scores( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: """ Heuristics to decide how to tile kernels. Currently, we tile based on stride-1 dimensions. @@ -2000,9 +2316,20 @@ def select_tiling( # Tiled reductions are gated by a config flag. default_tiling = cls.create_tiling([numel], [reduction_numel]) + + # # TODO: enable by default if ( - not is_pointwise and not config.triton.tile_reductions - ) or config.triton.max_tiles <= 1: + torch._inductor.config.triton.coalesce_tiling_analysis + and coalesce_analysis + and not config.triton.prefer_nd_tiling + ): + return cls.compute_tiling_strategy( + node_schedule, numel, reduction_numel, coalesce_analysis + ) + + if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles( + default=2 + ) <= 1: # Emit a perf hint in case we miss an opportunity to tile a reduction. if perf_hint_log.level <= logging.WARNING: for node in EnableReduction.filter(node_schedule): @@ -2019,9 +2346,10 @@ def select_tiling( ) ) break - return default_tiling - seen_names = OrderedSet[str]() + return default_tiling, None + + seen_names: OrderedSet[str] = OrderedSet() candidate_tiles: Counter[CandidateTiling] = collections.Counter() for node in EnableReduction.filter(node_schedule): for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel): @@ -2036,7 +2364,7 @@ def select_tiling( for candidate_tiling, score in candidate_tiles.most_common() ] - if config.triton.max_tiles >= 3 and is_pointwise: + if get_max_tiles(default=2) >= 3 and is_pointwise: # Consider adding a third dimension of tiling, but only # when a1 is a multiple of b1; otherwise, you have a lot # of stragglers which is annoying to generate code for. @@ -2089,18 +2417,12 @@ def convert_tiling_to_3d( + ranked_tilings ) - for tiling in ranked_tilings: - assert isinstance(tiling, dict) - if all( - SIMDKernel.is_compatible( - tiling.values(), node.get_ranges(), reduction_numel=reduction_numel - ) - for node in node_schedule - if isinstance(node, scheduler.SchedulerNode) - ): - return tiling + if tiling := cls.get_first_compatible_tiling( + node_schedule, numel, reduction_numel, ranked_tilings + ): + return tiling, None - return default_tiling + return default_tiling, None def flush(self): pass diff --git a/torch/_inductor/codegen/simd_kernel_features.py b/torch/_inductor/codegen/simd_kernel_features.py index 54dcbfa275f297..77e9dba34eddad 100644 --- a/torch/_inductor/codegen/simd_kernel_features.py +++ b/torch/_inductor/codegen/simd_kernel_features.py @@ -24,6 +24,8 @@ if typing.TYPE_CHECKING: from collections.abc import Iterable, Sequence + from torch._inductor.tiling_utils import CoalesceVarAnalysis + class NodeScheduleMarker: @staticmethod @@ -80,12 +82,14 @@ def __init__( node_schedule: list[NodeScheduleEntry], numel: sympy.Expr, reduction_numel: sympy.Expr = sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, ): self.node_schedule = node_schedule # numel excludes reduction_numel self.numel: sympy.Expr = V.graph.sizevars.simplify(numel) self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel) self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {} + self.coalesce_analysis = coalesce_analysis @cache_on_self def is_reduction(self) -> bool: @@ -119,7 +123,7 @@ def contains_op(self, op_name: str) -> bool: return bool(self.op_counts().get(op_name)) def get_mutations(self) -> OrderedSet[str]: - mutations = OrderedSet[str]() + mutations: OrderedSet[str] = OrderedSet() for node in self.scheduler_nodes(): for buf in node.get_outputs(): mutations.update(buf.get_mutations()) @@ -128,7 +132,7 @@ def get_mutations(self) -> OrderedSet[str]: @cache_on_self def select_index_dtype(self) -> torch.dtype: # Gather all used buffer names - buffer_names = OrderedSet[str]() + buffer_names: OrderedSet[str] = OrderedSet() for node in self.scheduler_nodes(): buffer_names.update(node.get_buffer_names()) buffer_names.update(node.used_buffer_names()) diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 8c6b71f5dc327b..8e34c43cebad5c 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -1,15 +1,15 @@ import itertools import logging -from typing import Any, Callable +from typing import Any, Callable, Union import torch import torch._inductor.config as config from torch._inductor import ir from torch._inductor.codegen.common import KernelTemplate from torch._inductor.ir import ( - add_symbolic_shapes_for_inputs_to_subgraph, Buffer, get_free_symbols, + get_symbolic_inputs, gm_original_output_strides, ir_node_to_tensor, Layout, @@ -49,6 +49,9 @@ def __init__( self.example_inputs.append(ir_node_to_tensor(inp)) self.gm = make_fx_graph(*self.example_inputs) + gm_original_output_strides(self.gm) + + self.sym_inputs = get_symbolic_inputs(self.input_nodes) def __str__(self) -> str: return f"SubgraphCaller({self.name})" @@ -60,7 +63,6 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: import torch._inductor.config as inductor_config from torch._inductor.graph import GraphLowering - gm_original_output_strides(self.gm) bm_graph_lowering = GraphLowering( gm=self.gm, example_inputs=self.example_inputs, @@ -73,12 +75,13 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: name=f"benchmark_{self.name}", ) - sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph( - self.input_nodes, bm_graph_lowering - ) + for sym_inp in self.sym_inputs: + bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp + bm_graph_lowering.graph_input_names.append(sym_inp.name) sym_inputs = [ - int(V.graph.sizevars.shape_env.size_hint(sym_var)) for sym_var in sym_inputs + int(V.graph.sizevars.shape_env.size_hint(sym_var)) + for sym_var in self.sym_inputs ] if len(sym_inputs) == 0: @@ -129,7 +132,7 @@ def hash_key(self) -> str: ] ) - def output_node(self) -> ir.TensorBox: + def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: return ir.TensorBox.create( ir.SubgraphBuffer( layout=self.layout, diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8f3ddb77129091..47d222e1b2f237 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -8,6 +8,7 @@ import itertools import logging import math +import operator import os import textwrap from collections.abc import Iterable, Sequence @@ -226,7 +227,12 @@ def has_rmask(self) -> bool: @property def mask_str(self) -> str: - return " & ".join(map(str, self.mask_vars)) if self.mask_vars else "None" + # The sorted call is added to make sure the order is still + # deterministic if self.mask_vars contains mix of string + # and TritonCSEVariable + return ( + " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None" + ) @dataclasses.dataclass @@ -383,7 +389,7 @@ def remove_dims(it): broadcast_shape=broadcast_shape, broadcasting_dims=broadcasting_dims, ) - result.compute_boundary_check(get_max_block) + result.compute_boundary_check(get_max_block, range_trees) return result def replace_offset( @@ -430,27 +436,50 @@ def remove_roffsets(expr: sympy.Expr) -> sympy.Expr: ] return f"tl.make_block_ptr({', '.join(args)})" - def compute_boundary_check(self, get_max_block: Callable[[str], int]) -> None: + def compute_boundary_check( + self, + get_max_block: Callable[[str], int], + range_trees: list[IterationRangesRoot], + ) -> None: """List of indices to pass to tl.load(boundary_check=...)""" sizevars = V.graph.sizevars # Substitute maximum block sizes in shape expressions. # This works in multiple_of checks because block sizes are powers of 2. block_to_max: dict[sympy.Expr, Any] = { - block_size: get_max_block(prefix_str[symt]) - for symt, block_size in TritonSymbols.block_sizes.items() + TritonSymbols.block_sizes[t.symt]: get_max_block(prefix_str[t.symt]) + for t in range_trees } + # Also see Note: Constant mask optimisation + # if ynumel / YBLOCK > max_ygrid, then the z dimension is used to handle + # the remaining programs that cannot fit into the y dimension. This means + # it's possible that more than the required number of programs are launched, + # possibly leading to out-of-bounds accesses. So even if ynumel divides YBLOCK, + # boundary checking is required in the dimensions that are based on YBLOCK + # e.g. for [YBLOCK // 16, YBLOCK, XBLOCK] dimensions 0 and 1 need boundary + # checks when max_ygrid is exceeded. + needs_overflow_grid = any(map(V.kernel.needs_yz_grid_overflow, range_trees)) self._boundary_check = [ idx for idx in range(len(self.shape)) if ( not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero) - and not sizevars.statically_known_multiple_of( - self.shape[idx], self.block_shape[idx] - ) - and not sizevars.statically_known_multiple_of( - self.shape[idx], sympy_subs(self.block_shape[idx], block_to_max) + and ( + ( + needs_overflow_grid + and TritonSymbols.block_sizes[SymT.YBLOCK] + in self.block_shape[idx].free_symbols + ) + or ( + not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], + sympy_subs(self.block_shape[idx], block_to_max), + ) + ) ) and not ( V.kernel.no_x_dim @@ -618,7 +647,7 @@ def _print_Where(self, expr: sympy.Expr) -> str: def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: """ - Helper for max/min code genereration. + Helper for max/min code generation. cmp: > or < """ if len(expr.args) == 1: @@ -748,7 +777,7 @@ class TritonCSEVariable(CSEVariable): def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: super().__init__(name, bounds, dtype) # We'll use this to track which masks the variable needs when used for indirect indexing - self.mask_vars = OrderedSet[str]() + self.mask_vars: OrderedSet[str] = OrderedSet() assert dtype is not None, "TritonCSEVariable must have dtype" def update_on_args(self, name, args, kwargs): @@ -910,7 +939,7 @@ def _shaped_constant(value, dtype, shape): return triton_val # NOTE: We use a tensor here in order to get the expected type. - # Otherwise, e.g. float64 constants would be trunctated to float32. + # Otherwise, e.g. float64 constants would be truncated to float32. if value < 0 and not dtype.is_signed: triton_signed_type = f"tl.{triton_type[4:]}" return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" @@ -1305,7 +1334,7 @@ def __init__(self, *args, **kwargs): self._setup_libdevice_routing() @classmethod - @functools.lru_cache(None) + @functools.cache def _setup_libdevice_routing(cls): """Set up routing to libdevice implementations for fp64 inputs.""" @@ -1768,8 +1797,8 @@ def indexing( index_vars = index.free_symbols has_rindex = False - mask_vars = OrderedSet[str]() - for var in index_vars: + mask_vars: OrderedSet[str] = OrderedSet() + for var in sorted(index_vars, key=operator.attrgetter("name")): assert isinstance(var, sympy.Symbol) has_rindex = has_rindex or symbol_is_type( var, TritonSymbols.reduction_types @@ -1810,7 +1839,7 @@ def indexing( have_dense = True have_loop_vars = False - dense_mask_vars = OrderedSet[str]() + dense_mask_vars: OrderedSet[str] = OrderedSet() for tree in self.active_range_trees(): if index_vars.intersection(tree.var_list): @@ -1927,7 +1956,7 @@ def match_mod_div_block( # Compute the ND block shape from the linear block size. # Use CielDiv to round leading dimensions up to 1. # Non-leading dimensions are clamped to the size of the iteration range, - # while the leading dimension can exceed this to accomodate a larger + # while the leading dimension can exceed this to accommodate a larger # block size. linear_block_size = TritonSymbols.get_block_size(range_tree) block_shape: list[sympy.Expr] = [ @@ -2736,13 +2765,19 @@ def _mask_value(value, default) -> CSEVariable: assert reduction_type == "welford_reduce" result_mean, result_m2, result_weight = result_var peer_mean = self.codegen_cooperative_reduction_peer_combine( - result_mean, upcast_acc_dtype(src_dtype), default[0] + result_mean, + upcast_acc_dtype(src_dtype), + default[0], # type: ignore[index] ) peer_m2 = self.codegen_cooperative_reduction_peer_combine( - result_m2, upcast_acc_dtype(src_dtype), default[1] + result_m2, + upcast_acc_dtype(src_dtype), + default[1], # type: ignore[index] ) peer_weight = self.codegen_cooperative_reduction_peer_combine( - result_weight, upcast_acc_dtype(src_dtype), default[2] + result_weight, + upcast_acc_dtype(src_dtype), + default[2], # type: ignore[index] ) self.welford_reduce_final_reduction( self.post_loop_store, @@ -2757,6 +2792,7 @@ def _mask_value(value, default) -> CSEVariable: ) elif reduction_type == "online_softmax_reduce": result_max, result_sum = result_var + assert isinstance(default, Sequence) peer_max = self.codegen_cooperative_reduction_peer_combine( result_max, upcast_acc_dtype(src_dtype), default[0] ) @@ -3549,7 +3585,7 @@ def codegen_kernel(self, name=None): arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] ) - mutated_args = OrderedSet[str]() + mutated_args: OrderedSet[str] = OrderedSet() for mutation in self.mutations: if mutation in self.args.input_buffers: mutated_args.add(self.args.input_buffers[mutation]) @@ -3641,6 +3677,9 @@ def add_constexpr_arg(arg_name): "num_reduction": self.num_reduction, **self.inductor_meta_common(), } + if self.tiling_scores: + inductor_meta["tiling_scores"] = self.tiling_scores + if self.cooperative_reduction: inductor_meta["persistent_reduction"] = self.persistent_reduction @@ -3928,6 +3967,7 @@ def _has_constant_mask(self, tree: IterationRangesRoot) -> bool: if tree.is_reduction and self.cooperative_reduction: max_block = max_block * self.max_rsplit() + # [Note: Constant mask optimisation] # Optional optimization: if block divides numel exactly, we will # never need to do a masked load to handle stragglers at the end. # If this tree is for the y dimension, we should only use a constant diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 96562779b0872f..dc2392119cc511 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -51,7 +51,7 @@ def _default_custom_combo_kernel_horizontal_partition( node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], ) -> list[list[BaseSchedulerNode]]: """Horizontally partition the given list of nodes into a list of list of nodes where each sublist - represents a partion. Nodes in different partitions are implemented in different combo kernels. + represents a partition. Nodes in different partitions are implemented in different combo kernels. Nodes in the same partition are likely to be implemented in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. @@ -536,7 +536,7 @@ def select_combo_heuristics( return heuristics_list[0], size_hints_list[0], self.sub_kernels[0] def get_mutated_args_sub_kernels(self) -> list[str]: - mutated_args = OrderedSet[str]() + mutated_args: OrderedSet[str] = OrderedSet() for sub_kernel in self.sub_kernels: for mutation in sub_kernel.mutations: if mutation in sub_kernel.args.input_buffers: diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 7f4d72ee71b0b2..cef1c3ad36bb54 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -36,24 +36,24 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. # Related PR: https://github.com/triton-lang/triton/pull/2279/ if arg.dtype == torch.float8_e4m3fn: - tye = "*fp8e4nv" + typ = "*fp8e4nv" elif arg.dtype == torch.float8_e5m2: - tye = "*fp8e5" + typ = "*fp8e5" elif arg.dtype == torch.float8_e4m3fnuz: - tye = "*fp8e4b8" + typ = "*fp8e4b8" elif arg.dtype == torch.float8_e5m2fnuz: - tye = "*fp8e5b16" + typ = "*fp8e5b16" else: - tye = _type_of(arg.dtype) + typ = _type_of(arg.dtype) if should_unwrap_unspec_arg(arg.buffer): # had unwrapped 0d tensor as scalar - new_tye = tye.lstrip("*") - if new_tye in ["fp16", "bf16"]: + new_typ = typ.lstrip("*") + if new_typ in ["fp16", "bf16"]: return "fp32" else: - return new_tye + return new_typ else: - return tye + return typ if isinstance(arg, SizeArg): if arg.expr is None: if triton_version_uses_attrs_dict(): @@ -81,7 +81,7 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: # no hint: we'll see if we know that this is a 32-bit int, and guard if possible. int_max = torch.iinfo(torch.int32).max if expr_fits_within_32bit(arg.expr): - V.graph.sizevars.guard_leq(arg.expr, int_max) + V.graph.sizevars.check_leq(arg.expr, int_max) return "i32" else: return "i64" @@ -90,7 +90,15 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: if isinstance(arg, WorkspaceArg): return _type_of(arg.dtype) if isinstance(arg, TMADescriptorArg): - return "nvTmaDesc" + if arg.api_type == "experimental": + return "nvTmaDesc" + else: + # https://github.com/triton-lang/triton/blob/9695baed9b46cf957e08b157bb4133f4a4b331c5/python/triton/runtime/jit.py#L360-L363 + assert arg.api_type == "stable" + assert arg.block_shape is not None + assert arg.dtype is not None + inner = _type_of(arg.dtype)[1:] # strip the `*`: *fp32 -> fp32 + return f"tensordesc<{inner}{list(arg.block_shape)}>" if isinstance(arg, ConstexprArg): return "constexpr" raise NotImplementedError(f"unhandled {type(arg)}: {arg}") @@ -111,11 +119,34 @@ def signature_to_meta( size_dtype: Optional[str], argdefs: list[ArgName], indices: Optional[list[int]] = None, + is_template: bool = False, ) -> dict[str, str]: if indices is None: indices = list(range(len(signature))) + + def _decide_tl_dtype(arg): + # Even if the ks0 symbol itself is within tl.int32 range, it's + # risky to use tl.int32 dtype since we may have ks0*ks1 later + # for kernels like torch.mean when dynamic shape is enabled. + # + # Check config.triton.use_block_ptr, since Triton block pointer + # does not support 64bit indexing: + # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6 + # + # If the triton metadata is for a template, don't use tl.int64 index. + # Templates like flex attention/decoding uses block pointers which + # does not support 64 bit indexing. + if ( + not config.triton.use_block_ptr + and not is_template + and isinstance(arg, SizeArg) + and arg.name.startswith("ks") + ): + return "tl.int64" + return size_dtype + return { - argdefs[i].name: signature_of(arg, size_dtype=size_dtype) + argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg)) for i, arg in zip(indices, signature) } diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 2bf7539a395198..a1747190a6c8e4 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -48,6 +48,7 @@ DelayReplaceLine, get_benchmark_name, IndentedBuffer, + is_codegen_graph_partition_subgraph, LineContext, set_kernel_post_grad_provenance_tracing, sympy_product, @@ -214,7 +215,7 @@ def writeline(line: str, example_grid: Optional[str] = None): else: assert len(grids) > 1 assert len(grids) == len(configs) - seen = OrderedSet[str]() + seen: OrderedSet[str] = OrderedSet() # sort the configs from the largest # of kwargs to the smallest to # emit the grids in the order of (approximately) decreasing specificity # TODO(aakhundov): the sorting below is generally not sufficient, so @@ -856,7 +857,7 @@ def __init__(self): self.kernel_autotune_defs = IndentedBuffer() self.kernel_autotune_calls = IndentedBuffer() self.subgraph_definitions = IndentedBuffer() - self.kernel_autotune_names = OrderedSet[str]() + self.kernel_autotune_names: OrderedSet[str] = OrderedSet() # Map key is the kernel argument name; value is a tuple of the resulting example # tensor name with the kernel where that tensor was most recently used. self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} @@ -876,7 +877,9 @@ def __init__(self): self.last_seen_device_guard_index: Optional[int] = None self.supports_intermediate_hooks = True self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {} - self.unbacked_symbol_decls = OrderedSet[str]() # str of sympy.Symbol + self.unbacked_symbol_decls: OrderedSet[str] = ( + OrderedSet() + ) # str of sympy.Symbol self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet() self.launcher_fn_name = None # This function can be overridden to change the launcher name @@ -891,10 +894,7 @@ def __init__(self): self.write_header() - if not ( - isinstance(self, SubgraphPythonWrapperCodegen) - and self.partition_signatures is not None - ): + if not is_codegen_graph_partition_subgraph(self): # See [Note: Removed Graph Partition Arguments] self.write_prefix() @@ -915,7 +915,7 @@ def __init__(self): self.write_get_raw_stream ) - @functools.lru_cache(None) + @functools.cache def add_import_once(line: str) -> None: self.imports.writeline(line) if config.triton.autotune_at_compile_time: @@ -923,9 +923,9 @@ def add_import_once(line: str) -> None: self.add_import_once = add_import_once self._metas: dict[str, str] = {} - self._meta_vars = OrderedSet[str]() + self._meta_vars: OrderedSet[str] = OrderedSet() self.multi_kernel_state = MultiKernelState() - self.already_codegened_subgraphs = OrderedSet[str]() + self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet() self.allocated_workspaces: dict[str, Any] = {} # intermediate tensor value printing utility @@ -983,7 +983,6 @@ def write_header(self) -> None: from torch import device, empty_strided from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels - from torch._inductor.codegen.multi_kernel import MultiKernelCall {aot_inductor_debug_utils} """, strip=True, @@ -1057,8 +1056,7 @@ def write_triton_header_once(self) -> None: V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") ) - @cache_on_self - def write_get_raw_stream_header_once(self) -> None: + def write_get_raw_stream_header(self) -> None: if config.triton.autotune_at_compile_time: self.kernel_autotune_calls.writeline( V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") @@ -1068,6 +1066,10 @@ def write_get_raw_stream_header_once(self) -> None: V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") ) + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + self.write_get_raw_stream_header() + def add_meta_once(self, meta: TritonMetaParams) -> str: meta = repr(meta) if meta not in self._metas: @@ -1248,6 +1250,9 @@ def codegen_device_guard_enter(self, device_idx: int) -> None: self.kernel_autotune_calls.writeline( V.graph.device_ops.set_device(device_idx) ) + if is_codegen_graph_partition_subgraph(self): + # Need get_raw_stream for subgraph + self.write_get_raw_stream_header() self.kernel_autotune_calls.writeline( f"stream{device_idx} = get_raw_stream({device_idx})" ) @@ -1260,6 +1265,18 @@ def codegen_device_guard_exit(self) -> None: def generate_return(self, output_refs: list[str]) -> None: if output_refs: + if config.nan_asserts: + self.wrapper_call.writeline( + "return_vars = (" + ", ".join(output_refs) + ", )" + ) + self.wrapper_call.writeline("for var in return_vars:") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("if isinstance(var, torch.Tensor):") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("assert not var.isnan().any().item()") + self.wrapper_call.writeline("assert not var.isinf().any().item()") + self.wrapper_call.do_unindent(2) + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") else: self.wrapper_call.writeline("return ()") @@ -1284,7 +1301,7 @@ def generate_after_suffix(self, result: IndentedBuffer) -> None: def generate_end(self, result: IndentedBuffer) -> None: return - def generate_fallback_kernel(self, node: ir.FallbackKernel): + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: self.writeline(ExternKernelAllocLine(self, node)) def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc): @@ -1344,7 +1361,7 @@ def _generate_extern_kernel_out_helper( with debug_printer_manager: self.writeline(f"{kernel}({', '.join(args)})") - def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): + def _generate_tma_descriptor_call_experimental(self, desc, apply_size_hints=False): dims = desc.dims block_dims = desc.block_dims if apply_size_hints: @@ -1366,6 +1383,28 @@ def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): call = f"{fn}({args})" return call + def _generate_tma_descriptor_call_stable(self, desc, apply_size_hints=False): + block_shape = desc.block_shape + if apply_size_hints: + block_shape = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_shape + ) + + prefix = "triton.tools.tensor_descriptor.TensorDescriptor" + fn = f"{prefix}.from_tensor" + args = f"{desc.tensor.codegen_reference()}, {block_shape}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): + if isinstance(desc, ir.TMADescriptorExperimental): + return self._generate_tma_descriptor_call_experimental( + desc, apply_size_hints + ) + else: + assert isinstance(desc, ir.TMADescriptorStable) + return self._generate_tma_descriptor_call_stable(desc, apply_size_hints) + def generate_tma_descriptor(self, desc): call = self._generate_tma_descriptor_call(desc) line = f"{desc.name} = {call}{self.ending}" @@ -1399,12 +1438,12 @@ def generate_fallback_kernel_with_runtime_lookup( self, buf_name: str, python_kernel_name: str, - codegen_args: Sequence[str], + get_args: Callable[[], Sequence[str]], op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], raw_args: Sequence[Any], outputs: Sequence[ir.Buffer], ) -> None: - self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})") + self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})") def generate(self, is_inference): with dynamo_timed("PythonWrapperCodegen.generate"): @@ -1606,12 +1645,12 @@ def codegen_input_symbol_assignment( ): code = self.prefix - @functools.lru_cache(None) + @functools.cache def sizeof(name): code.writeline(f"{name}_size = {name}.size()") return f"{name}_size" - @functools.lru_cache(None) + @functools.cache def strideof(name): code.writeline(f"{name}_stride = {name}.stride()") return f"{name}_stride" @@ -1757,12 +1796,12 @@ def codegen_reinterpret_view( f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" ) - def codegen_device_copy(self, src, dst, non_blocking: bool): + def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): self.writeline(f"{dst}.copy_({src}, {non_blocking})") def codegen_multi_output(self, node: ir.MultiOutput): result_name = node.get_name() - arg_name = node.inputs[0].get_name() + arg_name = node.input_name(0) self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) def codegen_dynamic_scalar(self, node): @@ -2033,10 +2072,18 @@ def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False): add_arg(idx, ConstexprArg(name=key), equals_none=True) else: if isinstance(arg, ir.TMADescriptor): + api_type, block_shape, dtype = ( + ("stable", arg.block_shape, arg.tensor.get_dtype()) + if isinstance(arg, ir.TMADescriptorStable) + else ("experimental", None, None) + ) add_arg( idx, TMADescriptorArg( name=key, + api_type=api_type, + block_shape=block_shape, + dtype=dtype, ), ) elif isinstance(arg, ir.Buffer): @@ -2192,6 +2239,7 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: if config.triton.unique_user_kernel_names: # We replace the original_name with the unique name. kernel_src = kernel_src.replace(f"def {original_name}(", f"def {name}(") + kernel_src = kernel_src.replace("'''", "\\'\\'\\'") compile_wrapper.splice(kernel_src) current_device = V.graph.get_current_device_or_throw() @@ -2343,7 +2391,7 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None): if isinstance(arg_type, torch_dtype): if isinstance(raw_arg, ir.TMADescriptor): # first we generate the underlying buffer - buf_name = raw_arg.tensor.get_name() + buf_name = raw_arg.get_tensor().get_name() buf = self.args_to_buffers[arg] elif self.args_to_buffers.get(arg): buf_name = arg @@ -2521,6 +2569,12 @@ def _generate_kernel_call_helper( "call_args and arg_types do not match" ) + autotune_args = None + if original_fxnode_name and V.graph.autotuning_mapping: + autotune_args = V.graph.autotuning_mapping.get( + original_fxnode_name, None + ) + def get_autotune_deletion_call() -> str: """After all the autotune kernel calls have been written (i.e. self.kernel_autotune_example_args is complete), returns a deletion call @@ -2535,6 +2589,39 @@ def get_autotune_deletion_call() -> str: return f"del {', '.join(tensors_to_delete)}\n" return "" + def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): + """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args. + This is particularly useful for jagged cases, where the dimension is often + being passed in as an input.""" + + target_arg = raw_args[idx] + if target_arg in reused_args: + return True + + for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)): + if i == idx or not isinstance(raw_arg, IRNode): + continue + + triton_input = "" + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + if triton_input == "": + continue + + try: + layout = raw_arg.get_layout() + for dim, s in enumerate(layout.size): + if s == target_arg: + reused_args[target_arg] = f"{triton_input}.shape[{dim}]" + return True + except NotImplementedError: + # If layout for this IRNode is not implemented, we could just skip. + # Only raise for other Error cases. + continue + return False + all_args = [] if raw_args is None: # create a dummy raw_args for uniform behavior in the following loop @@ -2546,11 +2633,7 @@ def get_autotune_deletion_call() -> str: "call_args and raw_args do not match" ) - autotune_args = None - if original_fxnode_name and V.graph.autotuning_mapping: - autotune_args = V.graph.autotuning_mapping.get( - original_fxnode_name, None - ) + reused_args = {} for i, (arg, arg_type, raw_key, raw_arg) in enumerate( zip(call_args, arg_types, raw_keys, raw_args) ): @@ -2567,6 +2650,17 @@ def get_autotune_deletion_call() -> str: if triton_input: arg_str = triton_input + if not isinstance(arg_type, torch_dtype) and ( + issubclass(arg_type, sympy.Basic) + or isinstance(arg, SymbolicCallArg) + ): + reused_args[raw_arg] = arg_str + elif raw_key == "" and infer_arg_by_inputs( + raw_keys, raw_args, i, reused_args + ): + # Empty raw_key means this is a arg that's not native to the triton kernel, + # and is being added by inductor. + arg_str = reused_args[raw_arg] elif isinstance(arg_type, torch_dtype): # workspace allocation is already generated by `generate_workspace_allocation()` # in `TritonKernel.call_kernel()`. @@ -3266,7 +3360,9 @@ def get_graph_inputs( self, ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: if signature := self.partition_signatures: - inputs = signature.input_nodes + inputs = signature.input_nodes | { + str(s): s for s in signature.symbol_inputs + } else: inputs = V.graph.graph_inputs return inputs diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index 9d2df32c0f62aa..a2b8954366690f 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -1,4 +1,6 @@ import dataclasses +import functools +import logging import operator import textwrap from collections import Counter @@ -19,9 +21,10 @@ from torch._inductor.virtualized import V from torch._library.triton import wrap_triton from torch.fx import GraphModule +from torch.utils import _pytree as pytree from torch.utils._sympy.functions import FloorDiv -from .. import ir +from .. import config, ir from ..utils import convert_shape_to_symint, convert_to_symint, LineContext from .common import ( CodegenSymbol, @@ -58,6 +61,7 @@ aten = torch.ops.aten +log = logging.getLogger(__name__) @dataclasses.dataclass @@ -194,6 +198,23 @@ def _create_meta_from_buffer( node.name = name node.meta["val"] = buffer.get_example() + def _create_as_strided( + self, + input_node: torch.fx.Node, + size: tuple[Any, ...], + stride: tuple[Any, ...], + offset: Union[int, sympy.Expr], + ) -> torch.fx.Node: + return self.gm.graph.call_function( + torch.as_strided, + args=( + input_node, + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + convert_to_symint(offset), + ), + ) + def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None: """ Updates the symbol table to record that an Inductor buffer maps to the result of @@ -239,8 +260,13 @@ def _generate_graph_inputs(self) -> None: """ Converts graph inputs to FX placeholders. """ - for ir_node in V.graph.graph_inputs.values(): - buffer = self._get_buffer(ir_node) + for name, ir_node in V.graph.graph_inputs.items(): + # Introduce a new symbol for constant inputs. + buffer = ( + SymbolBuffer(sympy.Symbol(name, is_integer=True)) + if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float)) + else self._get_buffer(ir_node) + ) node = self.gm.graph.placeholder(buffer.get_name()) self._create_meta_from_buffer(node, buffer) self._record_allocation(buffer, node) @@ -406,12 +432,17 @@ def _generate_reinterpret_helper( assert name size = tuple(layout.size) stride = tuple(layout.stride) + if isinstance(layout, ir.NonOwningLayout): + # Look up the view's layout. + view = layout.view + assert isinstance(view, ir.ReinterpretView), ( + f"unexpected type: {type(view)}" + ) + layout = view.layout offset = input_buffer.get_offset() + layout.offset # Map ReinterpretView to as_strided. - result_node = self.gm.graph.call_function( - torch.as_strided, args=(input_node, size, stride, offset) - ) + result_node = self._create_as_strided(input_node, size, stride, offset) result_node.name = name result_node.meta["val"] = layout.get_example() self._record_allocation(result_buffer, result_node) @@ -427,17 +458,15 @@ def _generate_reuse(self, line: WrapperLine) -> None: result_node = old_node # Change shape and stride. - size = new.get_size() - stride = new.get_stride() + size = tuple(new.get_size()) + stride = tuple(new.get_stride()) offset = new.get_offset() if ( - old.get_size() != size - or old.get_stride() != stride + tuple(old.get_size()) != size + or tuple(old.get_stride()) != stride or old.get_offset() != offset ): - result_node = self.gm.graph.call_function( - torch.as_strided, args=(old_node, size, stride, offset) - ) + result_node = self._create_as_strided(old_node, size, stride, offset) self._create_meta_from_buffer(result_node, new) self._record_allocation(new, result_node) @@ -483,10 +512,47 @@ def _generate_triton_call(self, line: WrapperLine) -> None: call_args = self._lookup_args(line.call_args) kernel = self.kernels[line.kernel_name] tuner = kernel.tuner - config = tuner.compile_results[0].config - call_args, grid = tuner._interpret_args_grid(call_args, config) + + # Optionally autotune the kernels. + # The FX backend currently only supports compile-time tuning. + kernel_name = tuner.fn.__name__ + if config.triton.autotune_at_compile_time: + from triton.runtime import driver + + log.info("Autotuning Triton kernel %s at compile time.", kernel_name) + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + def node_to_tuning_arg(arg: Any) -> Any: + """ + Create real tensors for autotuning arguments, substituting size hints + for dynamic shapes. + """ + to_size_hint = functools.partial( + pytree.tree_map, V.graph.sizevars.size_hint + ) + if not isinstance(arg, torch.fx.Node): + return to_size_hint(arg) + + fake = arg.meta["val"] + return torch.empty_strided( + to_size_hint(fake.shape), + to_size_hint(fake.stride()), + device=device, + ).zero_() + + arg_values = [node_to_tuning_arg(arg) for arg in call_args] + tuner.run(*arg_values, stream=stream) + else: + log.info( + "Skipping autotuning for kernel %s. Set config.triton.autotune_at_compile_time = True to enable.", + kernel_name, + ) + + kernel_config = tuner.compile_results[0].config + call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args)) - call_kwargs.update(config.kwargs) + call_kwargs.update(kernel_config.kwargs) def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: """ @@ -561,6 +627,7 @@ def _generate_extern_kernel_common( """ # Get FX nodes corresponding to the call args. + assert ir.is_node_sequence(kernel.inputs) tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) args = tensor_nodes + tuple(kernel.constant_args) diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 8678e30d26b085..632cfd29f174fa 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -2,7 +2,11 @@ from typing import Optional -from ..common import DeviceOpOverrides, register_device_op_overrides +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) class XPUDeviceOpOverrides(DeviceOpOverrides): @@ -54,7 +58,9 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "void *" - def cpp_global_scratch(self, idx: int) -> Optional[tuple[str, str]]: + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: return None diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index f8a233a3b9e21e..2a69a053134798 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -67,7 +67,7 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int: def get_collective_group_size(node: ir.IRNode) -> int: - if type(node) == ir._CollectiveKernel: + if isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel): from torch.distributed.distributed_c10d import _get_group_size_by_name return _get_group_size_by_name(node.constant_args[-1]) diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index 408c211b8af600..b748f61f067b9d 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging -from typing import cast import torch import torch.utils._pytree as pytree @@ -114,10 +113,12 @@ def realize_as_comm_buffer( def _get_data(x: ir.TensorBox) -> ir.IRNode: if isinstance(x.data, ir.BaseView): # TensorBox -> *View -> StorageBox -> IRNode - return x.data.unwrap_view().data + node = x.data.unwrap_view() + assert isinstance(node, (ir.BaseView, ir.MutableBox)) + return node.data elif isinstance(x.data, ir.StorageBox): # TensorBox -> StorageBox -> IRNode - return cast(ir.Buffer, x.data.data) + return x.data.data else: raise AssertionError( "Expect the data attr of a `TensorBox` to be either " @@ -208,10 +209,13 @@ def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.Tensor inp.realize() V.graph.no_fuse_buffer_names.add(inp.get_name()) inp = ir.ExternKernel.require_contiguous(inp) - ir._CollectiveKernel.create_inplace( - c10d.all_reduce_.default, inp, reduce_op, group_name + ir._AllReduceKernel.create_inplace( + c10d.all_reduce_.default, + inp, # type: ignore[arg-type] + reduce_op, + group_name, # type: ignore[arg-type] ) - return inp + return inp # type: ignore[return-value] @register_comm_lowering(c10d.all_reduce_) # type: ignore[misc] def _all_reduce_( @@ -227,10 +231,13 @@ def _all_reduce_( # Lower as c10d.all_reduce_ inp = ir.ExternKernel.require_contiguous(inp) - ir._CollectiveKernel.create_inplace( - c10d.all_reduce_.default, inp, reduce_op, group_name + ir._AllReduce_Kernel.create_inplace( + c10d.all_reduce_.default, + inp, # type: ignore[arg-type] + reduce_op, + group_name, # type: ignore[arg-type] ) - return inp + return inp # type: ignore[return-value] @register_comm_lowering(c10d.all_reduce_coalesced) def _all_reduce_coalesced(inputs, reduce_op, group_name): @@ -253,15 +260,18 @@ def _all_reduce_coalesced_(inputs, reduce_op, group_name): ) return inputs + def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode: + node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args) + assert isinstance(node, ir.IRNode) + return ir.TensorBox.create(node) + @register_comm_lowering(c10d.all_gather_into_tensor) def _all_gather_into_tensor(inp, group_size, group_name): - return ir.TensorBox.create( - ir._CollectiveKernel.create_out_of_place( - c10d.all_gather_into_tensor.default, - inp, - group_size, - group_name, - ) + return _create_out_of_place( + c10d.all_gather_into_tensor.default, + inp, + group_size, + group_name, ) @register_comm_lowering(c10d.all_gather_into_tensor_coalesced) @@ -289,14 +299,12 @@ def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): @register_comm_lowering(c10d.reduce_scatter_tensor) def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): - return ir.TensorBox.create( - ir._CollectiveKernel.create_out_of_place( - c10d.reduce_scatter_tensor.default, - inp, - reduce_op, - group_size, - group_name, - ) + return _create_out_of_place( + c10d.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, ) @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced) @@ -314,14 +322,12 @@ def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): @register_comm_lowering(c10d.all_to_all_single) def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): - return ir.TensorBox.create( - ir._CollectiveKernel.create_out_of_place( - c10d.all_to_all_single.default, - inp, - output_split_sizes, - input_split_sizes, - group_name, - ) + return _create_out_of_place( + c10d.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, ) @register_comm_lowering(c10d.broadcast) @@ -341,14 +347,12 @@ def _broadcast_(inp, src, group_name): @register_comm_lowering(torch.ops._dtensor.shard_dim_alltoall) def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): - return ir.TensorBox.create( - ir._CollectiveKernel.create_out_of_place( - torch.ops._dtensor.shard_dim_alltoall.default, - inp, - gather_dim, - shard_dim, - group_name, - ) + return _create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, ) @register_comm_lowering(c10d.wait_tensor) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index de1fc5c42593da..41746976ecaf7d 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -13,6 +13,7 @@ from typing import Any, TYPE_CHECKING import torch +from torch._logging import trace_structured from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet @@ -109,45 +110,6 @@ def reorder_communication_preserving_peak_memory( reordered_snodes, node_stats = ( _reorder_communication_preserving_peak_memory_internal(snodes) ) - improvement = {snode: node_stats[snode].improvement for snode in node_stats} - total_improvement = sum([improvement[snode] for snode in improvement]) - total_moves = sum([node_stats[snode].moves for snode in node_stats]) - overlap_log.info( - "reorder_communication_preserving_peak_memory improved overlap by %f ns after %d reorders", - total_improvement, - total_moves, - ) - - if importlib.util.find_spec("tabulate"): - from tabulate import tabulate - - overlap_log.info( - tabulate( - [ - [ - node_summary(snode), - node_reorder_info.initial_exposed, - node_reorder_info.final_exposed, - node_reorder_info.improvement, - node_reorder_info.limiting_factor, - node_reorder_info.moves, - ] - for snode, node_reorder_info in node_stats.items() - ], - headers=[ - "Collective node", - "initial exposed", - "final exposed", - "improvement", - "limiting factor", - "moves", - ], - ) - ) - else: - overlap_log.info( - "Please `pip install tabulate` to nicely render overlap stats." - ) return reordered_snodes @@ -230,9 +192,7 @@ def exposed_communication_time(collective_snode, remaining_snodes): reorder_info.limiting_factor = "collective ordering" break dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) - if any( - o.get_name() in dep_names for o in prev_snode.get_outputs() - ) and not contains_wait(prev_snode): + if any(o.get_name() in dep_names for o in prev_snode.get_outputs()): reorder_info.limiting_factor = "data dependency" break if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: @@ -254,6 +214,57 @@ def exposed_communication_time(collective_snode, remaining_snodes): snode, snodes[j + 1 :] ) + node_stats = stats + improvement = {snode: node_stats[snode].improvement for snode in node_stats} + total_improvement = sum([improvement[snode] for snode in improvement]) + total_moves = sum([node_stats[snode].moves for snode in node_stats]) + + reorder_log_str = ( + f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns" + f" after {total_moves} reorders.\n" + ) + headers = [ + "Collective node", + "initial exposed", + "final exposed", + "improvement", + "limiting factor", + "moves", + ] + rows = [ + [ + node_summary(snode), + node_reorder_info.initial_exposed, + node_reorder_info.final_exposed, + node_reorder_info.improvement, + node_reorder_info.limiting_factor, + node_reorder_info.moves, + ] + for snode, node_reorder_info in node_stats.items() + ] + if importlib.util.find_spec("tabulate"): + from tabulate import tabulate + + reorder_log_str += tabulate( + rows, + headers=headers, + ) + else: + reorder_log_str += ( + "Please `pip install tabulate` to nicely render overlap stats.\n" + ) + reorder_log_str += str(headers) + "\n" + reorder_log_str += "\n".join(map(str, rows)) + overlap_log.info(reorder_log_str) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "reorder_communication_preserving_peak_memory", + "encoding": "string", + }, + payload_fn=lambda: reorder_log_str, + ) + return snodes, stats @@ -315,8 +326,8 @@ def _schedule_for_comm( for snode in snodes: if raise_comms and contains_collective(snode): scores_0[snode.get_name()] = comm_idx - for anc in snode.ancestors: - anc_fused_name = name_to_fused_node[anc].get_name() + for ancestor in snode.ancestors: + anc_fused_name = name_to_fused_node[ancestor].get_name() scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) comm_idx += 1 elif sink_waits and contains_wait(snode): @@ -475,7 +486,7 @@ def node_summary(snode): def visualize_overlap(order): - # TODO - this function probably doesn't do a very good job estimating the runtime becuase it doesn't carefully model + # TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model # streams and overlap. For now its mostly useful as a debug visualization. total_est_runtime: float = 0.0 @@ -531,7 +542,10 @@ def reorder_compute_and_comm_for_overlap( overlap_log.debug( f"==== Visualize overlap before reordering pass {p}, {peak_memory=} ====" # noqa: G004 ) - visualize_overlap(order) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug("", exc_info=e) t0 = time.time() order = p(order) # type: ignore[operator] t = time.time() - t0 @@ -542,7 +556,7 @@ def reorder_compute_and_comm_for_overlap( try: visualize_overlap(order) except Exception as e: - overlap_log.debug(str(e)) + overlap_log.debug("", exc_info=e) peak_memory, _ = estimate_peak_memory( snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs ) @@ -804,12 +818,10 @@ def remove_unused_getitem(g): CallFunction( torch.ops.fsdp.all_gather_copy_in.default, KeywordArg("all_gather_inputs"), + KeywordArg("all_gather_output"), KeywordArg("inp_split_sizes"), KeywordArg("all_gather_input_numel"), - KeywordArg("world_size"), KeywordArg("rank"), - KeywordArg("dtype"), - KeywordArg("device"), ), KeywordArg("item_idx"), ), @@ -842,12 +854,10 @@ def repl( repl, [ kwargs["all_gather_inputs"], + kwargs["all_gather_output"], kwargs["inp_split_sizes"], kwargs["all_gather_input_numel"], - kwargs["world_size"], kwargs["rank"], - kwargs["dtype"], - kwargs["device"], kwargs["group_size"], kwargs["group_name"], ], diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 87cf80cfe0f702..4c8711db43ef87 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -14,6 +14,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import AbstractContextManager +from dataclasses import dataclass from inspect import currentframe from itertools import count from operator import attrgetter @@ -75,7 +76,8 @@ from torch._inductor.utils import ( BoxedBool, count_tangents, - fresh_inductor_cache, + fresh_cache, + get_all_devices, InputType, is_gpu, should_assume_input_aligned, @@ -127,6 +129,7 @@ from torch._inductor.output_code import _StrideExprStr from torch._ops import OpOverload + from torch.export.pt2_archive._package_weights import Weights from .ir import ExternKernelNode @@ -162,21 +165,32 @@ class FxCompileMode(enum.Enum): SUBPROCESS = 2 -# Return compile mode and use_async flag -def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]: +@dataclass +class FxCompileConfig: + mode: FxCompileMode + use_async: bool + use_progressive: bool + + +def _fx_compile_mode_default() -> FxCompileConfig: name = "TORCHINDUCTOR_FX_COMPILE_MODE" value = os.environ.get(name) if value is None: - return FxCompileMode.NORMAL, False + return FxCompileConfig(FxCompileMode.NORMAL, False, False) use_async = False + use_progressive = False + + if value.lower().startswith("progressive+"): + use_progressive = True + value = value[12:] if value.lower().startswith("async+"): use_async = True value = value[6:] try: value = value.upper() - return FxCompileMode[value], use_async + return FxCompileConfig(FxCompileMode[value], use_async, use_progressive) except KeyError: import logging @@ -189,10 +203,20 @@ def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]: ) # Remove from the environment so subprocesses don't ALSO complain. os.environ.pop(name) - return FxCompileMode.NORMAL, False + return FxCompileConfig(FxCompileMode.NORMAL, False, False) + + +def _get_progression_configs() -> list[dict[str, Any]]: + # TODO make this configurable + return [ + {"max_autotune": True}, + ] -fx_compile_mode, fx_compile_async = _fx_compile_mode_default() +_fx_compile_config = _fx_compile_mode_default() +fx_compile_mode = _fx_compile_config.mode +fx_compile_async = _fx_compile_config.use_async +fx_compile_progressive = _fx_compile_config.use_progressive log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -238,12 +262,39 @@ def record_original_output_strides(gm: GraphModule) -> None: output_node.meta["original_output_strides"] = output_strides +def _recursive_record_original_output_strides(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + _recursive_record_original_output_strides(subgraph) + + record_original_output_strides(gm) + + +def _recursive_record_user_visible_output_idxs(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + + for node in subgraph.graph.find_nodes(op="output"): + node.meta["user_visible_output_idxs"] = [ + idx + for idx in range(len(node.args[0])) + if isinstance(node.args[0][idx], torch.fx.Node) + ] + _recursive_record_user_visible_output_idxs(subgraph) + + @functools.lru_cache(None) def _step_logger() -> Callable[..., None]: return dynamo_logging.get_step_logger(log) -@functools.lru_cache(None) +@functools.cache def _warn_tf32_disabled() -> None: if ( torch.cuda.is_available() @@ -647,7 +698,7 @@ def with_fresh_cache_if_config() -> Generator[None, None, None]: # Don't delete the cache dir because it has to survive beyond the # compile_fx call. Let's put the temp dirs under the default cache # dir so they're easier to locate. - with fresh_inductor_cache(dir=cache_dir(), delete=False): + with fresh_cache(dir=cache_dir(), delete=False): yield else: yield @@ -707,7 +758,6 @@ def compile_fx_inner( dynamo_compile_column_us="inductor_cumulative_compile_time_us", ) ) - stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) stack.enter_context(with_fresh_cache_if_config()) stack.enter_context(DebugContext()) CompileEventLogger.pt2_compile( @@ -742,9 +792,17 @@ def _compile_fx_inner( if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: # trigger the real recompilation for _LazyGraphModule before returning # the forward method. + from torch._dynamo.utils import CompileEventLogLevel from torch.fx._lazy_graph_module import _LazyGraphModule _LazyGraphModule.force_recompile(gm) + compile_id = torch._guards.CompileContext.current_compile_id() + CompileEventLogger.log_instant_event( + "backward no-op", + metadata={"compile_id": compile_id}, + log_level=CompileEventLogLevel.PT2_COMPILE, + ) + return make_boxed_func(gm.forward) static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ()) @@ -787,6 +845,7 @@ def _compile_fx_inner( and (config.fx_graph_cache or fx_graph_remote_cache) and not aot_mode and backends_support_caching + and not torch._functorch.config.bundled_autograd_cache ) local = config.fx_graph_cache remote = fx_graph_remote_cache @@ -992,19 +1051,37 @@ def _compile_fx_inner( if log.isEnabledFor(logging.INFO): mm_table_data = [] for key, value in counters["aten_mm_info"].items(): - m, n, k = key.split("_")[-3:] - name = "_".join(key.split("_")[:-3]) - mm_table_data.append([name, m, n, k, value]) + parts = key.split("_") + if len(parts) < 3: + # Unexpected format, show as-is + mm_table_data.append([key, "-", "?", "?", "?", value]) + continue + + # Determine if this is a batched operation by checking the operation name + name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3]) + is_batched = name.endswith(("bmm", "baddbmm")) + + if is_batched and len(parts) >= 4: + # Batched operation: last 4 parts are batch, m, n, k + batch, m, n, k = parts[-4:] + name = "_".join(parts[:-4]) + mm_table_data.append([name, batch, m, n, k, value]) + else: + # Non-batched operation: last 3 parts are m, n, k + m, n, k = parts[-3:] + name = "_".join(parts[:-3]) + mm_table_data.append([name, "-", m, n, k, value]) + log.info("Overview info of inductor aten mms: ") log.info( - "{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001 - "Name", "M", "N", "K", "Count" + "{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001 + "Name", "B", "M", "N", "K", "Count" ) ) - log.info("-" * 100) + log.info("-" * 130) for row in mm_table_data: - log.info("{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 - log.info("-" * 100) + log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 + log.info("-" * 130) # Not strictly necessary, but good to clean up straggling futures # that are unused to reclaim memory. @@ -1160,7 +1237,7 @@ def codegen_and_compile( with torch.no_grad(): fake_mode = fake_tensor_prop(gm, example_inputs) - record_original_output_strides(gm) + _recursive_record_original_output_strides(gm) # pattern matcher passes might not preserve striding information # on node.meta["val"]. if in the future we rely on these being @@ -1521,6 +1598,21 @@ def fx_codegen_and_compile( ) scheme = _AsyncFxCompile(scheme) + if fx_compile_progressive: + from .compile_fx_async import _ProgressiveFxCompile + from .compile_fx_ext import _OutOfProcessFxCompile + + assert isinstance(scheme, _OutOfProcessFxCompile), ( + "progressive is only valid with an out-of-process compile mode" + ) + + progression_configs = _get_progression_configs() + + # Use in-process compile for the fast version + fast_scheme = _InProcessFxCompile() + + scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) + return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -1597,8 +1689,8 @@ def run(new_inputs: Sequence[InputType]) -> Any: nonlocal compiled_fn if compiled_fn is None: with dynamo_utils.preserve_rng_state(): - compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) - return compiled_fn(new_inputs) + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) # type: ignore[arg-type] + return compiled_fn(new_inputs) # type: ignore[arg-type] return run @@ -1721,7 +1813,7 @@ def compile_fx_aot( example_inputs_: list[InputType], inner_compile: _CompileFxCallable = compile_fx_inner, config_patches: Optional[dict[str, str]] = None, -) -> Union[list[str], str]: +) -> Union[list[Union[str, Weights]], str]: assert isinstance(model_, GraphModule), model_ # [See NOTE] Unwrapping subclasses AOT @@ -1902,22 +1994,6 @@ def get_cpp_wrapper_config() -> dict[str, object]: } -def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: - placeholder_nodes = gm.graph.find_nodes(op="placeholder") - input_devices: OrderedSet[torch.device] = OrderedSet( - node.meta["val"].device - for node in placeholder_nodes - if isinstance(node.meta.get("val"), torch.Tensor) - ) - - out_devices: OrderedSet[torch.device] = OrderedSet( - arg.meta["val"].device - for arg in output_node(gm).args[0] # type: ignore[union-attr] - if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor) - ) - return input_devices | out_devices - - def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]: """ Returns a cuda device context manager if there is a single device in the graph @@ -1943,7 +2019,7 @@ def compile_fx( config_patches: Optional[dict[str, Any]] = None, decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None, ignore_shape_env: bool = False, -) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]: +) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str], Weights]: """ Main entry point for compiling given FX graph. Despite the fact that this lives in :mod:`torch._inductor`, this function is responsible for calling @@ -2204,6 +2280,11 @@ def fw_compiler_base( else: model_outputs_node.meta["user_visible_output_idxs"] = [] + # We also mark the invoke_subgraph outputs as user_visible to + # force the outputs of invoke_subgraph subgraph to follow the + # original strides + _recursive_record_user_visible_output_idxs(gm) + return inner_compile( gm, example_inputs, diff --git a/torch/_inductor/compile_fx_async.py b/torch/_inductor/compile_fx_async.py index 4d412edbb72cbd..05c896ae864484 100644 --- a/torch/_inductor/compile_fx_async.py +++ b/torch/_inductor/compile_fx_async.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import deque from dataclasses import dataclass from typing import Any, Callable, Optional, TYPE_CHECKING from typing_extensions import final, override @@ -11,6 +12,10 @@ from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 +# When async compile works with cache, remove the disabling below +BUG_CACHES_DONT_WORK_WITH_ASYNC = True + + if TYPE_CHECKING: from collections.abc import Sequence from concurrent.futures import Future @@ -28,11 +33,53 @@ class _PostCompileData: graph_kwargs: _CompileFxKwargs +@dataclass +class ProgressiveCompilationState: + progression_futures: deque[Future[_WireProtocolPickledOutput]] + callback: Callable[[_WireProtocolPickledOutput], OutputCode] + post_compile_data: Optional[_PostCompileData] + + def check_and_get_ready_stage(self) -> int: + """Check if any progression stage is ready and return its index, or -1 if none are ready.""" + if not self.progression_futures: + return -1 + + stage_index = -1 + if self.post_compile_data: + for i, future in enumerate(self.progression_futures): + if future.done(): + stage_index = i + + return stage_index + + def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, bool]: + """ + Switch to the specified progression stage and return the optimized output code. + Returns a tuple of (optimized_output_code, should_clear_compilation_state). + """ + future = self.progression_futures[stage_index] + assert future is not None + optimized_output_code = self.callback(future.result()) + + if pcd := self.post_compile_data: + optimized_output_code.post_compile( + pcd.example_inputs, pcd.constants, pcd.graph_kwargs + ) + + # Clear earlier progression futures to free memory + for _ in range(stage_index + 1): + self.progression_futures.popleft() + + # Return whether all compilation state should be cleared + should_clear_state = not self.progression_futures + return optimized_output_code, should_clear_state + + # _AsyncOutputCode handles the actual management of waiting for an # out-of-process compile to finish and then switching over to it. @final class _AsyncOutputCode(OutputCode): - _eager_forward: Optional[Callable[..., Any]] + _eager_fn: Optional[Callable[..., Any]] _output_code: Optional[OutputCode] _future: Optional[Future[_WireProtocolPickledOutput]] _callback: Callable[[_WireProtocolPickledOutput], OutputCode] @@ -41,16 +88,16 @@ class _AsyncOutputCode(OutputCode): def __init__( self, - # eager_forward is run until the future is finished. - eager_forward: Callable[..., Any], + # eager_fn is run until the future is finished. + eager_fn: Callable[..., Any], # this responds with the result of the out-of-process compile when it's # ready. future: Future[_WireProtocolPickledOutput], # this callback gets called to turn the _WireProtocolPickledOutput into an OutputCode callback: Callable[[_WireProtocolPickledOutput], OutputCode], ) -> None: - self._eager_forward = eager_forward - self._boxed_call = getattr(eager_forward, "_boxed_call", False) + self._eager_fn = eager_fn + self._boxed_call = getattr(eager_fn, "_boxed_call", False) self._output_code = None self._future = future @@ -59,11 +106,11 @@ def __init__( @override def __call__(self, *args: Any) -> Any: if self._future is not None and self._future.done(): - args = self._switch_to_compiled_forward(args) + args = self._switch_to_compiled_fn(args) - if eager_forward := self._eager_forward: + if eager_fn := self._eager_fn: _AsyncFxCompile._stat_eager_runs += 1 - return eager_forward(*args) + return eager_fn(*args) else: _AsyncFxCompile._stat_compiled_runs += 1 @@ -71,7 +118,7 @@ def __call__(self, *args: Any) -> Any: return self._output_code.__call__(*args) # Takes and returns the args (converted to the "right" boxed mode) - def _switch_to_compiled_forward(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + def _switch_to_compiled_fn(self, args: tuple[Any, ...]) -> tuple[Any, ...]: assert self._future is not None # TODO: If the future ended in an exception do we want to continue @@ -87,7 +134,7 @@ def _switch_to_compiled_forward(self, args: tuple[Any, ...]) -> tuple[Any, ...]: ) self._output_code = output_code - self._eager_forward = None + self._eager_fn = None boxed_call = getattr(output_code, "_boxed_call", False) if self._boxed_call != boxed_call: @@ -108,7 +155,7 @@ def post_compile( constants: CompiledFxGraphConstants, graph_kwargs: _CompileFxKwargs, ) -> None: - if self._eager_forward is not None: + if self._eager_fn is not None: self._post_compile_data = _PostCompileData( example_inputs, constants, graph_kwargs ) @@ -171,7 +218,7 @@ def codegen_and_compile( _AsyncFxCompile._stat_bg_started += 1 f = self._compile._send_to_child_async(inputs) - # This is called by _switch_to_compiled_forward() when f has a result... + # This is called by _switch_to_compiled_fn() when f has a result... def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: _AsyncFxCompile._stat_bg_finished += 1 output = pickled_output.deserialize(constants) @@ -179,3 +226,173 @@ def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: return output.graph return _AsyncOutputCode(eager_output_code, f, callback) + + +# _ProgressiveOutputCode handles running a fast compile first, then hot-swapping +# to a more optimized version when the expensive compile finishes. +@final +class _ProgressiveOutputCode(OutputCode): + _fast_output_code: Optional[OutputCode] + _optimized_output_code: Optional[OutputCode] + _compilation_state: Optional[ProgressiveCompilationState] + # _boxed_call state is effectively cached (we sometimes wrap unboxed w/ + # lambdas to box them) so we can't change it mid-way. Since _boxed_call=True + # is more common let's default to that and we'll convert if necessary. + _boxed_call: bool = True + + def __init__( + self, + # Fast compile that runs faster than the progressive compiles + fast_output_code: OutputCode, + # Futures for the progressive optimized compiles + progression_futures: Sequence[Future[_WireProtocolPickledOutput]], + # Callback to convert the optimized result to OutputCode + callback: Callable[[_WireProtocolPickledOutput], OutputCode], + ) -> None: + self._fast_output_code = fast_output_code + self._optimized_output_code = None + self._compilation_state = ProgressiveCompilationState( + progression_futures=deque(progression_futures), + callback=callback, + post_compile_data=None, + ) + + @override + def __call__(self, args: Sequence[Any]) -> Any: + # Check if any newer progression stage is ready and switch to it + self._check_and_switch_progression() + + if self._optimized_output_code is not None: + _ProgressiveFxCompile._stat_optimized_runs += 1 + output_code = self._optimized_output_code + else: + _ProgressiveFxCompile._stat_fast_runs += 1 + assert self._fast_output_code is not None + output_code = self._fast_output_code + + boxed_call = getattr(output_code, "_boxed_call", False) + if boxed_call: + res = output_code.__call__(args) + else: + res = output_code.__call__(*args) + return res + + def _check_and_switch_progression(self) -> None: + if not self._compilation_state: + return + + stage_index = self._compilation_state.check_and_get_ready_stage() + if stage_index == -1: + # no futures are ready + return + + self._switch_to_progression_stage(stage_index) + + def _switch_to_progression_stage(self, stage_index: int) -> None: + assert self._compilation_state is not None + optimized_output_code, should_clear_state = ( + self._compilation_state.switch_to_progression_stage(stage_index) + ) + + self._optimized_output_code = optimized_output_code + self._fast_output_code = None + + # Clear all compilation state if no more progression futures are left + if should_clear_state: + self._compilation_state = None + + @override + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + assert self._fast_output_code is not None + self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs) + + assert self._compilation_state is not None + # Store for later when optimized version is ready + self._compilation_state.post_compile_data = _PostCompileData( + example_inputs, constants, graph_kwargs + ) + + +# _ProgressiveFxCompile runs a fast compile immediately, then kicks off +# progressive compiles in the background and hot-swaps when they're ready. +@final +class _ProgressiveFxCompile(FxCompile): + _fast_compile: FxCompile + _optimized_compile: _OutOfProcessFxCompile + _progression_configs: list[dict[str, Any]] + + # Debugging stats + _stat_bg_started: int = 0 + _stat_bg_finished: int = 0 + _stat_fast_runs: int = 0 + _stat_optimized_runs: int = 0 + + def __init__( + self, + fast_compile: FxCompile, + optimized_compile: _OutOfProcessFxCompile, + progression_configs: list[dict[str, Any]], + ) -> None: + self._fast_compile = fast_compile + self._optimized_compile = optimized_compile + self._progression_configs = progression_configs + + @classmethod + def _reset_stats(cls) -> None: + cls._stat_bg_started = 0 + cls._stat_bg_finished = 0 + cls._stat_fast_runs = 0 + cls._stat_optimized_runs = 0 + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + import torch._inductor.config as inductor_config + + progression_futures: list[Future[_WireProtocolPickledOutput]] = [] + + for config in self._progression_configs: + with inductor_config.patch(config): + _ProgressiveFxCompile._stat_bg_started += 1 + + # Start the progressive compiles in the background + serialized = self._optimized_compile.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + if not serialized: + continue + + inputs, constants = serialized + future = self._optimized_compile._send_to_child_async(inputs) + progression_futures.append(future) + + fast_output_code = self._fast_compile.codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + if not progression_futures: + # All async compile attempts failed - just return the fast version + return fast_output_code + + # Callback to handle the optimized result. + # This callback may be called multiple times, once for each progressive level completed, + # but may be skipped if a level either never completes or if a more optimal level + # completes before a less optimal one is switched to. + def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: + _ProgressiveFxCompile._stat_bg_finished += 1 + output = pickled_output.deserialize(constants) + self._optimized_compile._postprocess(output) + return output.graph + + return _ProgressiveOutputCode(fast_output_code, progression_futures, callback) diff --git a/torch/_inductor/compile_fx_ext.py b/torch/_inductor/compile_fx_ext.py index 5b343f17fdccf4..7fd976a05ed9bf 100644 --- a/torch/_inductor/compile_fx_ext.py +++ b/torch/_inductor/compile_fx_ext.py @@ -167,6 +167,7 @@ class _FakeTensorModeSerializer: def __init__(self, fake_mode: FakeTensorMode) -> None: self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs + self.shape_env = fake_mode.shape_env @contextlib.contextmanager def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]: @@ -247,6 +248,7 @@ class _WireProtocolOutput: metrics: CachedMetricsDeltas logs: list[logging.LogRecord] warning_replay: Optional[list[warnings.WarningMessage]] + shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv] def serialize(self) -> _WireProtocolPickledOutput: """ @@ -546,7 +548,11 @@ def _run_in_child( logs = captured_logs.finish() return _WireProtocolOutput( - output_graph, metrics.get_deltas(), logs, warning_replay + output_graph, + metrics.get_deltas(), + logs, + warning_replay, + fake_mode.shape_env, ).serialize() @@ -608,7 +614,7 @@ def _postprocess(self, output: _WireProtocolOutput) -> None: # And forward our collected logs. The cache is cleared when the outer # function exits. - @functools.lru_cache(None) + @functools.cache def getLogger(name: str) -> logging.Logger: return logging.getLogger(name) diff --git a/torch/_inductor/compile_fx_subproc.py b/torch/_inductor/compile_fx_subproc.py index ca5365c23ad8f7..3a1535ec1e2fd8 100644 --- a/torch/_inductor/compile_fx_subproc.py +++ b/torch/_inductor/compile_fx_subproc.py @@ -13,7 +13,7 @@ SubprocKind, SubprocPool, ) -from torch._inductor.utils import clear_inductor_caches +from torch._inductor.utils import clear_caches from .compile_fx_ext import ( _OutOfProcessFxCompile, @@ -77,14 +77,14 @@ def _run_in_child_subprocess( # tmpdir still exists and fails to compile. # # TODO: We probably should be using a separate tmpdir in the worker - # anyway... but we should probably still respect clear_inductor_caches() + # anyway... but we should probably still respect clear_caches() # in the parent... maybe? # # TODO: We could be less aggressive by keeping a clock which gets # incremented when we clear the cache, send the clock to the worker and # only clear caches if the clock changed since last time. # - clear_inductor_caches() + clear_caches() torch._inductor.metrics.reset() # TODO: turn off config.fx_graph_async_compile diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index c761f849bbd1fc..bf213b37e8540f 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -23,6 +23,9 @@ import torch._thread_safe_fork # noqa: F401 from torch._inductor import config from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) from torch._inductor.compile_worker.utils import _async_compile_initializer from torch._inductor.utils import get_ld_library_path @@ -269,7 +272,7 @@ def __init__( self.running = True def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor: - pool = ProcessPoolExecutor( + pool = TrackedProcessPoolExecutor( nprocs, mp_context=multiprocessing.get_context(self.kind.value), initializer=functools.partial(_async_compile_initializer, os.getpid()), diff --git a/torch/_inductor/compile_worker/tracked_process_pool.py b/torch/_inductor/compile_worker/tracked_process_pool.py new file mode 100644 index 00000000000000..36df56b963d69f --- /dev/null +++ b/torch/_inductor/compile_worker/tracked_process_pool.py @@ -0,0 +1,111 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: Optional[int] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable[[], object]] = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. + saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100]) diff --git a/torch/_inductor/compile_worker/utils.py b/torch/_inductor/compile_worker/utils.py index 864dcf9c9682d9..a54fa308d3fd30 100644 --- a/torch/_inductor/compile_worker/utils.py +++ b/torch/_inductor/compile_worker/utils.py @@ -23,6 +23,8 @@ def in_toplevel_process() -> bool: # This function cannot be an inner function since otherwise mp_context="spawn" would # not work for ProcessPoolExecutor since inner functions cannot be pickled. def _async_compile_initializer(orig_ppid: int) -> None: + import torch._C + def run() -> None: while True: sleep(1) @@ -36,6 +38,9 @@ def run() -> None: # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. signal.signal(signal.SIGINT, signal.SIG_IGN) + # Install a crash handler to print out the stacktrace for SEGV + torch._C._initCrashHandler() + # Set a bit to distinguish async_compile subprocesses from the toplevel process. global _IN_TOPLEVEL_PROCESS _IN_TOPLEVEL_PROCESS = False diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py index eebff4b566ce15..5cec2020c9fb07 100644 --- a/torch/_inductor/compiler_bisector.py +++ b/torch/_inductor/compiler_bisector.py @@ -79,7 +79,7 @@ def reset_counters() -> None: call_counter_debug_info.clear() -@functools.lru_cache(None) +@functools.cache def get_env_val(env_str: str) -> Optional[str]: return os.environ.get(env_str, None) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 794d3562828a83..ae88ad8c3106fa 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -40,7 +40,7 @@ def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]: def static_cuda_launcher_default() -> bool: - STATIC_CUDA_LAUNCHER_VERSION = 0 + STATIC_CUDA_LAUNCHER_VERSION = 2 if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ: return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1" @@ -103,7 +103,7 @@ def prologue_fusion_enabled() -> bool: ) non_blocking_remote_cache_write: bool = Config( - justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write", + justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write_v2", env_name_force="TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE", default=True, ) @@ -175,6 +175,15 @@ def prologue_fusion_enabled() -> bool: # incompatible with disable_cpp_codegen cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" +# controls whether to compile entry and kernel separately for cpp_wrapper mode. +# turn on this option to compile entry and kernel separately and minimize compile time of the entry part. +# see https://github.com/pytorch/pytorch/pull/148773 +# Note: compiling entry and kernel separately may have a non-negligible impact on the performance. +# see https://github.com/pytorch/pytorch/issues/156037 +cpp_wrapper_build_separate: bool = ( + os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1" +) + # Controls automatic precompiling of common include files for codecache.CppCodeCache # (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is # controlled by a separate flag. @@ -271,6 +280,16 @@ def prologue_fusion_enabled() -> bool: ] ] = None +# Registers a custom pass to be run right after fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_post_fusion_custom_pass: Optional[ + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + # Deprecated split_cat_fx_passes = True @@ -365,6 +384,10 @@ def prologue_fusion_enabled() -> bool: # enable operator reordering for peak memory optimization reorder_for_peak_memory = True +bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" +# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used +bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None + # runtime estimation function for ops # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle estimate_op_runtime = "default" @@ -409,8 +432,11 @@ def prologue_fusion_enabled() -> bool: # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations # for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure # that triton does not use TF32 wherever cublas would not use TF32 -force_same_precision = ( - True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +# DEPRECATED. cuBLAS no longer has the above alignment requirements. will remove in the future. +force_same_precision: bool = Config( + justknob="pytorch/compiler:force_same_precision", + env_name_force="TORCHINDUCTOR_FORCE_SAME_PRECISION", + default=False, ) # Specify candidate backends for gemm autotune. @@ -424,6 +450,7 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" ).upper() + # As above, specify candidate backends for conv autotune. # NB: in some cases for 1x1 convs we emit as matmul, # which will use the backends of `max_autotune_gemm_backends` @@ -439,6 +466,13 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +# Specify the size of the search space for flex attention autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False @@ -446,8 +480,8 @@ def prologue_fusion_enabled() -> bool: # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 -# enable searching global and local cache regardless of `max_autotune` -search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" +# DEPRECATED. This setting is ignored. +search_autotune_cache = False save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" @@ -477,7 +511,7 @@ def prologue_fusion_enabled() -> bool: ) # AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and -# generate the learned heursitic to code which is shipped with the compiler +# generate the learned heuristic to code which is shipped with the compiler # Specify a list of comma separated optimizations to collect data for autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") # Specify a list of comma separated optimizations to use learned heuristics for @@ -572,7 +606,10 @@ def use_autoheuristic(name: str) -> bool: # how many nodes to allow into a single fusion max_fusion_size = 64 -# max number of inputs to generate cat as a pointwise op with masked laods +# how many nodes to attempt pairwise fusion with in a buffer group +max_fusion_buffer_group_pairwise_attempts = 64 + +# max number of inputs to generate cat as a pointwise op with masked loads max_pointwise_cat_inputs = 8 # force concat to be generated as a pointwise op with masked loads @@ -595,6 +632,12 @@ def use_autoheuristic(name: str) -> bool: # enabling both of these will implicitly disable split_reductions split_reductions = True +# When we do split reduction, this number control the minimum value for +# num_split. Too small num_split make the split reduction less efficient. +# It's a much bigger problem when we compile a dynamic shape kernel with +# non-representative inputs. +min_num_split = int(os.environ.get("TORCHINDUCTOR_MIN_NUM_SPLIT", 0)) + benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" # Enable constant and index_expr folding @@ -684,7 +727,7 @@ def decide_worker_start_method() -> str: default=True, ) -# Flags to turn on all_reduce fusion. These 2 flags should be automaticaly turned +# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned # on by DDP and should not be set by the users. _fuse_ddp_communication = False _fuse_ddp_bucket_size = 25 @@ -829,7 +872,7 @@ def decide_compile_threads() -> int: # Pad too small stride may also cause perf loss. We may result in many tiny data blocks # with gaps in between. That causes less coalesced GPU memory access! # -# Initially we pick 320 as the threshold since for alignement=16, +# Initially we pick 320 as the threshold since for alignment=16, # that results in at most 5% memory cost. # # But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. @@ -929,7 +972,7 @@ def decide_compile_threads() -> int: ) -# Adds NVTX annotations aroung training phases +# Adds NVTX annotations around training phases annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" # Enable caching codegen of triton templates. @@ -1050,6 +1093,9 @@ class cpp: os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1" ) + # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] + use_small_dequant_buffer = False + # config specific to codegen/triton.py class triton: @@ -1063,6 +1109,10 @@ class triton: # If False, we will re-record a graph for each unique set of shape inputs cudagraph_skip_dynamic_graphs = False + # Specify dynamic shapes to capture cudagraphs and skip cudagraph for other shapes. + # Default to None, which means we capture cudagraphs for all shapes. + cudagraph_capture_sizes: Optional[tuple[Union[int, tuple[int, ...]]]] = None + # assertions not on the fast path, steady state slow_path_cudagraph_asserts = True @@ -1105,12 +1155,23 @@ class triton: # Always load full blocks (rather than broadcasting inside the block) dense_indexing = False + # TODO - enable by default + coalesce_tiling_analysis: bool = ( + os.environ.get( + "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0" + ) + == "1" + ) + # limit tiling dimensions # - max_tiles=1 disables tiling - # - max_tiles=2 is the default + # - max_tiles=2 # - max_tiles=3 is experimental and may have bugs # higher values are unsupported - max_tiles = 2 + + # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise. + # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes. + max_tiles: Optional[int] = None # Prefer higher dimensional tilings. This simplifies indexing expressions, making # it easier to identify block pointers. @@ -1221,7 +1282,7 @@ class triton: codegen_upcast_to_fp32 = True # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 - # with a verison of triton new enough to support TMA + # with a version of triton new enough to support TMA enable_persistent_tma_matmul = ( os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" ) @@ -1236,6 +1297,10 @@ class triton: class aot_inductor: + """ + Settings for Ahead-Of-Time Inductor Compilation + """ + # AOTInductor output path # If an absolute path is specified, the generated lib files will be stored under the directory; # If a relative path is specified, it will be used as a subdirectory under the default caching path; @@ -1277,7 +1342,7 @@ class aot_inductor: # flag to decide whether to create a submodule for constant graph. use_runtime_constant_folding: bool = False - # flag to force weight to be appened to the shared library and mmaped by the runtime + # flag to force weight to be appended to the shared library and mapped by the runtime # rather than embedded into the data section. Needed to support 1B+ parameter models force_mmap_weights: bool = False @@ -1324,6 +1389,9 @@ class aot_inductor: # Experimental. Flag to control whether to include weight in .so package_constants_in_so: bool = True + # Experimental. Flag to control whether to package weight separately on disk + package_constants_on_disk: bool = False + # Experimental. Controls automatic precompiling of common AOTI include files. precompile_headers: bool = not is_fbcode() @@ -1331,10 +1399,13 @@ class aot_inductor: embed_kernel_binary: bool = False # Generate kernel files that support multiple archs - # Default it will emit multi arch kernels as asm files, e.g. PTX for CUDA. + # For CUDA, this means generating fatbin files for kernels, and the fatbin files + # contains PTX and SASS for the current architecture. emit_multi_arch_kernel: bool = False - # In addition to emit asm files, also emit binary files for current arch - emit_current_arch_binary: bool = False + + # If not None, the generated files with use this name in file stem. + # If None, we will use a hash to name files. + model_name_for_generated_files: Optional[str] = None # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {} @@ -1456,6 +1527,23 @@ class cuda: os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1" ) + # Specify which operations should use CUTLASS backend + # Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none. + # Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm + cutlass_enabled_ops: str = os.environ.get( + "TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all" + ) + + # Whether to consult the binary remote cache + use_binary_remote_cache: bool = True + + # Whether to upload compiled kernels to remote cache + upload_to_binary_remote_cache: bool = False + + # Whether to force upload if the key already exists + # Use this to overwrite and handle cache pollution + binary_remote_cache_force_write: bool = False + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. @@ -1602,7 +1690,7 @@ class trace: # replace records with HTML-like labels" # and thus fail to generate a graph. So, let's give the user an option # to specify the shape attribute for the dot graph. For example, passing - # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels # to workaround the above failure. dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) @@ -1616,10 +1704,10 @@ class trace: compile_profile = False # Upload the .tar.gz file - # Needs to be overriden based on specific environment needs + # Needs to be overridden based on specific environment needs upload_tar: Optional[Callable[[str], None]] = None - log_autotuning_results: bool = False + log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" # Save mapping info from inductor generated triton kernel to post_grad fx nodes log_inductor_triton_kernel_to_post_grad_node_info: bool = True @@ -1654,6 +1742,11 @@ class trace: "_pre_fusion_custom_pass", # tests assume that changes here don't invalidate cache "always_complex_memory_overlap_TESTING_ONLY", + # cache related options are not relevant to cache results + "fx_graph_cache", + "fx_graph_remote_cache", + "autotune_local_cache", + "autotune_remote_cache", ] # External callable for matmul tuning candidates diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 3b5b2104b402c4..869f2658219a4a 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -16,6 +16,18 @@ MODULE_TAG = "_MAIN_MODULE" CONST_MODULE_TAG = "_CONST_MODULE" +_dont_constant_fold: list[torch.fx.node.Target] = [] + + +def add_dont_constant_fold(op: torch.fx.node.Target) -> None: + global _dont_constant_fold + _dont_constant_fold.append(op) + + +def clear_dont_constant_fold() -> None: + global _dont_constant_fold + _dont_constant_fold.clear() + def replace_node_with_constant( gm: torch.fx.GraphModule, @@ -146,6 +158,9 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: # We only folding fp32_weight -> q # int8_weight and leave dq in graph to be fused return True + + if node.target in _dont_constant_fold: + return True return False def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]: diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 20a7c77c9e9b5f..a734292ce3e69b 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -127,7 +127,7 @@ def install_gcc_via_conda() -> str: return cxx_path -@functools.lru_cache(None) +@functools.cache def check_compiler_exist_windows(compiler: str) -> None: """ Check if compiler is ready, in case end user not activate MSVC environment. @@ -183,7 +183,6 @@ def convert_cubin_to_obj( # Convert .cubin to .o cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}" subprocess.run(cmd.split(), capture_output=True, text=True, check=True) - os.remove(cubin_file) # Rename .data to .rodata cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}" subprocess.run(cmd.split(), capture_output=True, text=True, check=True) @@ -201,13 +200,13 @@ def convert_cubin_to_obj( return obj_file -@functools.lru_cache(None) +@functools.cache def _is_apple_clang(cpp_compiler: str) -> bool: version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") return "Apple" in version_string.splitlines()[0] -@functools.lru_cache(None) +@functools.cache def _is_clang(cpp_compiler: str) -> bool: # Mac OS apple clang maybe named as gcc, need check compiler info. if sys.platform == "darwin": @@ -222,7 +221,7 @@ def _is_clang(cpp_compiler: str) -> bool: return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) -@functools.lru_cache(None) +@functools.cache def _is_gcc(cpp_compiler: str) -> bool: # Since "clang++" ends with "g++", the regex match below would validate on it. if _is_clang(cpp_compiler): @@ -230,7 +229,7 @@ def _is_gcc(cpp_compiler: str) -> bool: return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler)) -@functools.lru_cache(None) +@functools.cache def _is_msvc_cl(cpp_compiler: str) -> bool: if not _IS_WINDOWS: return False @@ -248,7 +247,7 @@ def _is_msvc_cl(cpp_compiler: str) -> bool: return False -@functools.lru_cache(None) +@functools.cache def _is_intel_compiler(cpp_compiler: str) -> bool: def _check_minimal_version(compiler_version: TorchVersion) -> None: """ @@ -292,32 +291,32 @@ def _check_minimal_version(compiler_version: TorchVersion) -> None: return False -@functools.lru_cache(None) +@functools.cache def is_gcc() -> bool: return _is_gcc(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_clang() -> bool: return _is_clang(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_intel_compiler() -> bool: return _is_intel_compiler(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_apple_clang() -> bool: return _is_apple_clang(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def is_msvc_cl() -> bool: return _is_msvc_cl(get_cpp_compiler()) -@functools.lru_cache(None) +@functools.cache def get_compiler_version_info(compiler: str) -> str: env = os.environ.copy() env["LC_ALL"] = "C" # Don't localize output @@ -332,7 +331,7 @@ def get_compiler_version_info(compiler: str) -> str: ).decode(*SUBPROCESS_DECODE_ARGS) except Exception: return "" - # Mutiple lines to one line string. + # Multiple lines to one line string. version_string = version_string.replace("\r", "_") version_string = version_string.replace("\n", "_") return version_string @@ -411,7 +410,7 @@ def normalize_path_separator(orig_path: str) -> str: class BuildOptionsBase: """ This is the Base class for store cxx build options, as a template. - Acturally, to build a cxx shared library. We just need to select a compiler + Actually, to build a cxx shared library. We just need to select a compiler and maintains the suitable args. """ @@ -573,6 +572,10 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: else "Wno-ignored-optimization-argument" ) cflags.append(ignored_optimization_argument) + if _is_gcc(cpp_compiler): + # Issue all the warnings demanded by strict ISO C and ISO C++. + # Ref: https://github.com/pytorch/pytorch/issues/153180#issuecomment-2986676878 + cflags.append("pedantic") return cflags @@ -737,13 +740,6 @@ def __init__( self._finalize_options() -def _get_glibcxx_abi_build_flags() -> list[str]: - if not _IS_WINDOWS: - return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] - else: - return [] - - def _get_torch_cpp_wrapper_definition() -> list[str]: return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"] @@ -886,7 +882,7 @@ def _get_python_related_args() -> tuple[list[str], list[str]]: return python_include_dirs, python_lib_path -@functools.lru_cache(None) +@functools.cache def is_conda_llvm_openmp_installed() -> bool: try: command = "conda list llvm-openmp --json" @@ -896,7 +892,7 @@ def is_conda_llvm_openmp_installed() -> bool: return False -@functools.lru_cache(None) +@functools.cache def homebrew_libomp() -> tuple[bool, str]: try: # check if `brew` is installed @@ -917,7 +913,7 @@ def homebrew_libomp() -> tuple[bool, str]: return False, "" -@functools.lru_cache(None) +@functools.cache def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: try: output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( @@ -931,7 +927,7 @@ def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: pass -@functools.lru_cache(None) +@functools.cache def perload_icx_libomp_win(cpp_compiler: str) -> None: def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: try: @@ -949,7 +945,7 @@ def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: return False """ - Intel Compiler implenmented more math libraries than clang, for performance proposal. + Intel Compiler implemented more math libraries than clang, for performance proposal. We need preload them like openmp library. """ preload_list = [ @@ -1120,7 +1116,6 @@ def get_cpp_torch_options( omp_passthrough_args, ) = _get_openmp_args(cpp_compiler) - cxx_abi_passthrough_args = _get_glibcxx_abi_build_flags() fb_macro_passthrough_args = _use_fb_internal_macros() mmap_self_macros = get_mmap_self_macro(use_mmap_weights) @@ -1143,10 +1138,7 @@ def get_cpp_torch_options( libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths libraries = torch_libraries + omp_lib passthrough_args = ( - sys_libs_passthrough_args - + isa_ps_args_build_flags - + cxx_abi_passthrough_args - + omp_passthrough_args + sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args ) return ( @@ -1314,6 +1306,9 @@ def get_cpp_torch_device_options( "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support." ) + if device_type == "mps": + definitions.append(" USE_MPS") + if config.is_fbcode(): include_dirs.append(build_paths.sdk_include) @@ -1425,7 +1420,7 @@ def get_name_and_dir_from_output_file_path( dir = /tmp/tmpof1n5g7t/5c/ put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'. - CppBuilder --> get_target_file_path will format output path accoding OS: + CppBuilder --> get_target_file_path will format output path according OS: Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll """ @@ -1442,13 +1437,13 @@ class CppBuilder: Args: name: 1. Build target name, the final target file will append extension type automatically. - 2. Due to the CppBuilder is supports mutliple OS, it will maintains ext for OS difference. + 2. Due to the CppBuilder is supports multiple OS, it will maintains ext for OS difference. sources: Source code file list to be built. BuildOption: Build options to the builder. output_dir: - 1. The output_dir the taget file will output to. + 1. The output_dir the target file will output to. 2. The default value is empty string, and then the use current dir as output dir. 3. Final target file: output_dir/name.ext """ @@ -1462,7 +1457,7 @@ def __get_python_module_flags() -> tuple[str, str]: @staticmethod def __get_object_flags() -> tuple[str, str]: extension = ".obj" if _IS_WINDOWS else ".o" - output_flags = "/c /Fo" if _IS_WINDOWS else "-c -o" + output_flags = "/c /Fo" if _IS_WINDOWS else "-c -o" # codespell:ignore return extension, output_flags @staticmethod @@ -1503,7 +1498,7 @@ def __init__( self._name = name - # Code start here, initial self internal veriables firstly. + # Code start here, initial self internal variables firstly. self._build_option = BuildOption self._compiler = BuildOption.get_compiler() self._use_relative_path = BuildOption.get_use_relative_path() @@ -1700,8 +1695,8 @@ def build_fbcode_re( def build(self) -> None: """ - It is must need a temperary directory to store object files in Windows. - After build completed, delete the temperary directory to save disk space. + It is must need a temporary directory to store object files in Windows. + After build completed, delete the temporary directory to save disk space. """ if self._use_relative_path: # remote build uses relative path @@ -1729,7 +1724,7 @@ def save_compile_cmd_to_cmake( definitions = " ".join(self._build_option.get_definitions()) contents = textwrap.dedent( f""" - cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + cmake_minimum_required(VERSION 3.27 FATAL_ERROR) project(aoti_model LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) @@ -1755,7 +1750,8 @@ def save_compile_cmd_to_cmake( current_arch = _nvcc_arch_as_compile_option() contents += textwrap.dedent( f""" - find_package(CUDA REQUIRED) + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -1783,7 +1779,7 @@ def save_compile_cmd_to_cmake( # --- PTX to FATBIN Command & Target --- add_custom_command( OUTPUT ${{FATBIN_FILE}} - COMMAND ${{CUDA_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} + COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} -gencode arch=compute_80,code=compute_80 -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 6ef286797767a8..b077c4da9c28d1 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -146,7 +146,7 @@ def check_build(self, code: str) -> bool: def __bool__(self) -> bool: return self.__bool__impl(config.cpp.vec_isa_ok) - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def __bool__impl(self, vec_isa_ok) -> bool: if vec_isa_ok is not None: return vec_isa_ok @@ -169,7 +169,7 @@ def __str__(self) -> str: return "neon" return "asimd" # detects the presence of advanced SIMD on armv8-a kernels - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @dataclasses.dataclass @@ -191,7 +191,7 @@ def __str__(self) -> str: return "neon" return "asimd" - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @dataclasses.dataclass @@ -208,7 +208,7 @@ class VecAVX512(VecISA): def __str__(self) -> str: return "avx512" - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @dataclasses.dataclass @@ -241,7 +241,7 @@ def __str__(self) -> str: } """ - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def __bool__(self) -> bool: if super().__bool__(): if config.is_fbcode(): @@ -263,7 +263,7 @@ class VecAVX2(VecISA): def __str__(self) -> str: return "avx2" - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @dataclasses.dataclass @@ -280,7 +280,7 @@ class VecZVECTOR(VecISA): def __str__(self) -> str: return "zvector" - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] @dataclasses.dataclass @@ -293,7 +293,7 @@ class VecVSX(VecISA): def __str__(self) -> str: return "vsx" - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] class InvalidVecISA(VecISA): @@ -308,7 +308,7 @@ def __str__(self) -> str: def __bool__(self) -> bool: # type: ignore[override] return False - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] def x86_isa_checker() -> list[str]: @@ -380,7 +380,7 @@ def get_isa_from_cpu_capability( # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. -@functools.lru_cache(None) +@functools.cache def valid_vec_isa_list() -> list[VecISA]: isa_list: list[VecISA] = [] if sys.platform == "darwin" and platform.processor() == "arm": diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index cd0fa0a73d96a2..bdc201803fb605 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -54,6 +54,7 @@ import torch.fx from torch import Tensor +from torch._dynamo.callback import CallbackTrigger from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state from torch._inductor.compile_fx import ( @@ -227,7 +228,7 @@ def _finalize_tensor(self) -> None: self.graph = None # manager was used again after existing cleanup, - # we shouldnt set it to None + # we shouldn't set it to None if self.live_cudagraphify_fns == 0: self.tree_manager = None @@ -345,6 +346,17 @@ def get_manager( return get_container(device_index).tree_manager +def is_cudagraph_capture_sizes(int_key: Union[int, tuple[int, ...]]) -> bool: + """ + Returns true if all dynamic shapes should be captured or the dynamic shape + int_key should be captured. + """ + return ( + config.triton.cudagraph_capture_sizes is None + or int_key in config.triton.cudagraph_capture_sizes + ) + + def cudagraphify_impl( model: ModelType, inputs: list[InputType], @@ -366,6 +378,10 @@ def deferred_cudagraphify(inputs: list[InputType]) -> OutputType: nonlocal has_warn int_key = get_ints(inputs) + + if not is_cudagraph_capture_sizes(int_key): + return model(inputs) + fn = fn_cache.get(int_key) if fn is not None: return fn(inputs) @@ -1230,7 +1246,7 @@ def static_input_iter() -> Generator[torch.Tensor, None, None]: } if config.triton.slow_path_cudagraph_asserts: - # need to use parent live weakrefs because live_indices isnt set yet + # need to use parent live weakrefs because live_indices isn't set yet memory = ( [] if self.parent is None else list(self.parent.path_live_weakrefs()) ) @@ -1606,7 +1622,7 @@ def remove_path_cached_tensors(self) -> None: def clear_path_state(self) -> None: "Clear the path state in this current executing node" - # this doesnt actually do anything right now, leaving it as placeholder + # this doesn't actually do anything right now, leaving it as placeholder @staticmethod def _tensor_metadata( @@ -2190,34 +2206,37 @@ def record_function( self, new_inputs: list[InputType], function_id: FunctionID ) -> OutputType: assert not isinstance(self.current_node, CUDAWarmupNode) - graph_id = self.new_graph_id() - log.debug( - "Recording function %d of graph recording id %d", - function_id.id, - graph_id.id, - ) - torch.cuda.synchronize() - node = CUDAGraphNode( - self.ids_to_funcs[function_id], - graph_id, - self.current_node, - new_inputs, - self.cuda_graphs_thread_pool, - self.device_index, - self.ids_to_stack_traces[function_id], - self.stream, - self.mode, - self.compile_id, - ) - if self.current_node is None: - self.roots[function_id].append(node) - else: - self.current_node.add_child(function_id, node) - self.current_node = node - self.path_state = ExecutionState.RECORDING - self.update_generation() - torch.cuda.synchronize() - return node.run_first_inputs(new_inputs) + with torch._dynamo.callback_handler.install_callbacks( + CallbackTrigger.CUDAGRAPH_RECORDING, str(self.compile_id) + ): + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + self.mode, + self.compile_id, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) def execute_node( self, node: CUDAGraphNode, new_inputs: list[InputType] diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index f6ce7e43ad95f3..2686d1d2ddde23 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -131,7 +131,7 @@ def check_for_mutation( inputs: list[InputType], is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], ) -> Optional[str]: - # doesnt work for non-trees because the warmup run would apply mutation twice + # doesn't work for non-trees because the warmup run would apply mutation twice if torch._inductor.config.triton.cudagraph_trees: # checking if mutation is only on parameters/static inputs mutation_indices: Sequence[int] = [ @@ -222,7 +222,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor( ) -> Optional[str]: default_msg = format_default_skip_message("mutated inputs") - # doesnt work for non-trees because the warmup run would apply mutation twice + # doesn't work for non-trees because the warmup run would apply mutation twice if torch._inductor.config.triton.cudagraph_trees: unique_idxs = OrderedSet(static_input_idxs) # checking if mutation is only on parameters/static inputs diff --git a/torch/_inductor/custom_graph_pass.py b/torch/_inductor/custom_graph_pass.py index 9a22f17896a514..c9a8e33a1145a4 100644 --- a/torch/_inductor/custom_graph_pass.py +++ b/torch/_inductor/custom_graph_pass.py @@ -18,7 +18,7 @@ class CustomGraphPass(ABC): identifies your implementation (and can be pickled). The caching logic includes this identifier in its key calculation, i.e., any new value will effectively invalidate existing entries. We expect custom passes would typically depend purely on the - textual reprensentation of the implementation. In that case, we recommend using the + textual representation of the implementation. In that case, we recommend using the 'get_hash_for_files' helper below to compute a unique hash from the contents of a static list of source files, i.e., the source(s) containing the custom pass implementation. That approach ensures that any change to the implementation will @@ -53,6 +53,38 @@ def uuid(self) -> Optional[Any]: """ +class CustomGraphModulePass(ABC): + """ + Implement this interface for custom Graph passes: + + 1) The __call__() method contains the implementation of the custom pass. + + 2) The uuid() method enables inductor to cache compiled graphs when your custom + passes are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom passes would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom pass + implementation. That approach ensures that any change to the implementation will + mean a new uuid. + """ + + @abstractmethod + def __call__(self, gm: torch.fx.GraphModule) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. Return None + to skip inductor code caching entirely. + """ + + CustomGraphPassType: TypeAlias = Optional[ Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]] ] diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index ee328e8b560475..d3bc89a3d41253 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -50,7 +50,7 @@ GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] -@functools.lru_cache(None) +@functools.cache def has_dot() -> bool: return shutil.which("dot") is not None @@ -583,6 +583,7 @@ def log_autotuning_results( timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 elapse: float, precompile_elapse: float, + prescreening_elapse: Optional[float], ) -> None: from .ir import FixedLayout @@ -653,6 +654,7 @@ def build_node_info(node: ir.IRNode) -> dict[str, str]: "input_nodes": [build_node_info(node) for node in input_nodes], "autotuning_time": elapse, "precompile_time": precompile_elapse, + "prescreening_time": prescreening_elapse, } with self.fopen_context( "autotuning_result_json_list.txt", "at", encoding="utf-8" diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index bc0ee2979aabb9..08c3abc9f23f98 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -6,7 +6,7 @@ import sys import typing from typing import Any, Callable, Optional, TypeVar, Union -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeAlias import torch import torch._decomp as decomp @@ -51,6 +51,10 @@ _T = TypeVar("_T") _P = ParamSpec("_P") +_GenericOperator: TypeAlias = Union[ + torch._ops.OperatorBase, torch._ops.OpOverloadPacket +] + log = logging.getLogger(__name__) aten = torch.ops.aten prims = torch.ops.prims @@ -108,7 +112,7 @@ # Remove unwanted decompositions included via the core ATen decompositions from # the Inductor decomp table. -decomps_to_exclude = [ +decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [ aten._unsafe_index, aten._unsafe_masked_index, aten._unsafe_masked_index_put_accumulate, @@ -132,9 +136,9 @@ def register_decomposition( - ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]], + ops: Union[_GenericOperator, list[_GenericOperator]], ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] + for op in ops if isinstance(ops, list) else [ops]: if op in decompositions: log.warning("duplicate decomp: %s", ops) return decomp.register_decomposition(ops, decompositions) @@ -496,6 +500,10 @@ def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: reshaped_tensor = tensor.view(new_shape) return reshaped_tensor + # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation. + x = x + 0 + z = z + 0 + x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) @@ -504,7 +512,8 @@ def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: @register_decomposition([aten.conj_physical]) def conj_physical(self: torch.Tensor) -> torch.Tensor: - assert not self.is_complex(), "TODO: implement this" + if self.is_complex(): + return NotImplemented return self @@ -856,7 +865,7 @@ def miopen_batch_norm( ) -@functools.lru_cache(None) +@functools.cache def fast_random_decomps() -> dict[Any, Callable[..., Any]]: return {**decompositions, **extra_random_decomps} diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index b1f75372ee4c79..9de52061c6489b 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -125,7 +125,7 @@ def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]: ) return None - # May hanppen if self and other are as follows + # May happen if self and other are as follows # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) if OrderedSet(self_strides) != OrderedSet(other_strides): @@ -581,12 +581,12 @@ def index_vars_no_squeeze( def index_vars_squeeze( *argsizes: Sequence[sympy.Expr], prefix: str = "d" -) -> tuple[list[list[sympy.Expr]], VarRanges]: +) -> tuple[list[Sequence[sympy.Expr]], VarRanges]: from .ir import SqueezeView var_ranges, add_var = var_builder(prefix) - args: list[list[sympy.Expr]] = [] - new_sizes: list[list[sympy.Expr]] = [] + args: list[Sequence[sympy.Expr]] = [] + new_sizes: list[Sequence[sympy.Expr]] = [] for size in argsizes: new_size, reindex = SqueezeView.squeezer(size) new_sizes.append(new_size) @@ -607,7 +607,10 @@ def extract_read_writes( if isinstance(fn, LoopBody): inner = extract_loop_body_with_args( - fn, [*args, *hidden_args], var_ranges, normalize + fn, + [*args, *hidden_args], # type: ignore[list-item] + var_ranges, + normalize, ) else: # Slow path tracing the function @@ -708,7 +711,7 @@ def extract_input_node_reduction_ranges( # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? # The current method still uses reduction ranges from the dependent realized node, which is not ideal. - # Is there a way to check whether there are permutations inbetween? + # Is there a way to check whether there are permutations in between? reads = input_node.get_reads() reduction_size: Optional[list[sympy.Expr]] = None size: Optional[list[sympy.Expr]] = None diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index 811ae9982d2298..5f99d83e07e792 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -29,7 +29,7 @@ def dtype(self) -> torch.dtype: ... # So first decompose CSEVars -> tuple before calling this -@functools.lru_cache(None) +@functools.cache def get_promoted_dtype( *args: Sequence[tuple[torch.dtype, bool]], type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, diff --git a/torch/_inductor/extern_node_serializer.py b/torch/_inductor/extern_node_serializer.py index ffd390152034b9..0e5f42e7309e85 100644 --- a/torch/_inductor/extern_node_serializer.py +++ b/torch/_inductor/extern_node_serializer.py @@ -1,6 +1,6 @@ import json -from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 7fe28a9f4a2f96..05222168095f48 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -150,7 +150,7 @@ def __init__(self, elem, name: Optional[str], mod) -> None: self.owning_mod_ref = weakref.ref(mod) @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] erased_tensors = [ e for e in pytree.arg_tree_leaves(*args, **kwargs) diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index 483099b6aca4c2..ff434ccba09521 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import functools from collections import deque +from typing import Union import torch from torch.utils._ordered_set import OrderedSet @@ -12,6 +13,7 @@ FixedLayout, FlexibleLayout, InputBuffer, + ShapeAsConstantBuffer, StorageBox, Subgraph, TensorBox, @@ -493,10 +495,12 @@ def convert_output_node_to_buffer(output): "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", type(output_buffer), ) + device = output_buffer.data.get_device() + assert device is not None subgraph_buffer = ComputedBuffer( name=None, layout=FlexibleLayout( - device=output_buffer.data.get_device(), + device=device, dtype=output_buffer.data.get_dtype(), size=output_buffer.data.get_size(), ), @@ -512,7 +516,7 @@ def convert_output_node_to_buffer(output): def create_placeholder( name: str, dtype: torch.dtype, device: torch.device -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: """ Creates a placeholder input buffers for producing subgraph_output """ @@ -538,8 +542,11 @@ def tuned_b2b_gemm( A.get_dtype(), [A.shape[0], C.shape[1]], # type: ignore[index] ) + placeholders = [ + create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error()) + ] subgraph_buffer = build_subgraph_buffer( - [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], + placeholders, # type: ignore[arg-type, list-item] subgraph, ) choices: list[TritonTemplateCaller] = [] diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index c64f1309319d32..d2ad3e1c8f9190 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -83,7 +83,7 @@ def recover_original_precision_folded_computation_ops(gm): _binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] -@functools.lru_cache(None) +@functools.cache def binary_folding_init(): _conv_args = [Arg() for _ in range(9)] _addmm_args = [Arg() for _ in range(3)] diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py new file mode 100644 index 00000000000000..7ed815d048cb9b --- /dev/null +++ b/torch/_inductor/fx_passes/bucketing.py @@ -0,0 +1,432 @@ +import logging +import operator +from typing import Any, Callable, Optional, Union + +import torch +from torch._dispatch.python import enable_python_dispatcher +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def bucket_size_determinator(bucket_id: int) -> float: + """ + Determine the size of a bucket based on its ID. + + Args: + bucket_id (int): The ID of the bucket. + + Returns: + float: The size of the bucket. + """ + return 2000.0 + + +def bucket_all_gather( + gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float] +) -> None: + ag_buckets = bucket_all_gather_by_mb(gm, all_gather_bucket_cap_mb_callback) + if len(ag_buckets) == 0: + return + merge_all_gather(gm, ag_buckets) + + +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type] + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool: + return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type] + + +def bucket_all_gather_by_mb( + gm: torch.fx.GraphModule, + all_gather_bucket_cap_mb_callback: Callable[[int], float], + filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> list[list[torch.fx.Node]]: + """ + Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`. + + + Returns a list of buckets, where each bucket is a list of all_gather nodes. + """ + + node_list = gm.graph.nodes + + # Prerequisite: Check if there is any all_gather node + found_all_gather = False + for node in node_list: + if is_all_gather_into_tensor(node): + found_all_gather = True + break + if not found_all_gather: + return [] + + ag_nodes: list[torch.fx.Node] = [] + + # Step 1: Find all all_gather nodes + for node in node_list: + if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): + if (filter_wait_node is None) or filter_wait_node(node): + ag_node = node.args[0] + ag_nodes.append(ag_node) + + # Step 2: Put all_gather nodes into buckets + ag_buckets: list[list[torch.fx.Node]] = [] + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + # Convert MiB to bytes + all_gather_bucket_size_bytes = int( + all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 + ) + for ag_node in ag_nodes: + assert is_all_gather_into_tensor(ag_node) + assert "val" in ag_node.meta + ag_output_size_bytes = ( + ag_node.meta["val"].numel() + * torch.finfo(ag_node.meta["val"].dtype).bits + // 8 + ) + if ( + cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes + and cur_bucket + ): + # Current bucket is full, create new bucket + ag_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_size_bytes += ag_output_size_bytes + cur_bucket.append(ag_node) + if cur_bucket: + # add remaining nodes in the last bucket + ag_buckets.append(cur_bucket) + + return ag_buckets + + +def node_copy( # type: ignore[no-untyped-def] + env, + new_graph, + node: torch.fx.Node, + arg_transform: Callable[[torch.fx.Node], torch.fx.node.Argument], +) -> torch.fx.Node: + if node not in env: + new_node = new_graph.node_copy(node, arg_transform=arg_transform) + env[node] = new_node + else: + new_node = env[node] + return new_node + + +def new_graph_call_function( # type: ignore[no-untyped-def] + new_graph, + target: Callable[..., Any], + args: Optional[tuple[torch.fx.node.Argument, ...]] = None, + kwargs: Optional[dict[str, torch.fx.node.Argument]] = None, + type_expr: Optional[Any] = None, +) -> torch.fx.Node: + from torch.utils._pytree import tree_map_only + + new_node = new_graph.call_function(target, args, kwargs) + args_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], args) + kwargs_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], kwargs) + with V.fake_mode, enable_python_dispatcher(): + new_fake_tensor = target(*args_val, **kwargs_val) + new_node.meta["val"] = new_fake_tensor + return new_node + + +def env_lookup( # type: ignore[no-untyped-def] + env, x: torch.fx.Node, node_user: Union[torch.fx.Node, str] +) -> torch.fx.Node: + assert x in env, ( + f"Dependent node {x} not in env when creating downstream node {node_user}" + ) + return env[x] + + +def merge_all_gather( + gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]] +) -> None: + """ + Transforms the graph to use bucketed all_gather operations based on `ag_buckets`. + """ + assert len(ag_buckets) > 0 + + ag_nodes: list[torch.fx.Node] = [] + cast_nodes: list[torch.fx.Node] = [] + ag_node_to_wait_node: dict[torch.fx.Node, torch.fx.Node] = {} + ag_node_to_bucket_id = {} + cast_node_to_bucket_id = {} + + # Map nodes to buckets and identify wait nodes + for bucket_id, bucket in enumerate(ag_buckets): + for ag_node in bucket: + assert len(ag_node.users) == 1, ( + f"Expect only one user for {ag_node}, but got {ag_node.users}" + ) + wait_node = next(iter(ag_node.users)) + ag_node_to_wait_node[ag_node] = wait_node + ag_nodes.append(ag_node) + ag_node_to_bucket_id[ag_node] = bucket_id + if ( + ag_node.args[0].op == "call_function" # type: ignore[union-attr] + and ag_node.args[0].target # type: ignore[union-attr] + == torch.ops.prims.convert_element_type.default + ): + cast_nodes.append(ag_node.args[0]) # type: ignore[arg-type] + cast_node_to_bucket_id[ag_node.args[0]] = bucket_id # type: ignore[arg-type] + + # Step 3: Create new (bucketed) all_gather nodes + bucket_id_to_bucketed_op_info = {} + bucket_id_is_scheduled = {} + cast_bucket_id_is_scheduled = {} + _, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args + for bucket_id, ag_bucket in enumerate(ag_buckets): + ag_input_nodes = [] + wait_nodes = [] + for ag_node in ag_bucket: + assert ( + ag_node in ag_node_to_wait_node + and ag_node.args[1] == group_size + and ag_node.args[2] == group_name + ) + ag_input_nodes.append(ag_node.args[0]) + wait_nodes.append(ag_node_to_wait_node[ag_node]) + bucket_id_to_bucketed_op_info[bucket_id] = ( + ag_input_nodes, + group_size, + group_name, + wait_nodes, + ) + + ag_wait_nodes = list(ag_node_to_wait_node.values()) + ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes) + cast_nodes = OrderedSet(cast_nodes) + new_graph: torch.fx.Graph = torch.fx.Graph() + env: dict[torch.fx.Node, torch.fx.Node] = {} + + node_list = gm.graph.nodes + for node in node_list: + if node not in ag_and_wait_nodes and node not in cast_nodes: + # not cast-before-all_gather, all_gather or its wait_tensor - schedule it normally + node_copy(env, new_graph, node, lambda x: env_lookup(env, x, node)) + elif node in cast_nodes: + # batch cast nodes together into one foreach_copy node + assert node in cast_node_to_bucket_id + bucket_id = cast_node_to_bucket_id[node] + if bucket_id not in cast_bucket_id_is_scheduled: + ag_input_nodes, group_size, group_name, orig_wait_nodes = ( + bucket_id_to_bucketed_op_info[bucket_id] + ) + # device = ag_input_nodes[0].meta["val"].device + # rank = device.index + # dtype = ag_input_nodes[0].meta["val"].dtype + if all( + n.op == "call_function" # type: ignore[union-attr] + and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] + for n in ag_input_nodes + ): + param_all_gather_inputs = [ + new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + (n.meta["val"].shape,), # type: ignore[union-attr] + { + "dtype": n.args[1], # type: ignore[union-attr] + "device": n.meta["val"].device, # type: ignore[union-attr] + "pin_memory": False, + }, + ) + for n in ag_input_nodes + ] + for pp, n in zip(param_all_gather_inputs, ag_input_nodes): + pp.meta = n.meta.copy() # type: ignore[union-attr] + + cast_input_nodes = [env[n.args[0]] for n in ag_input_nodes] # type: ignore[union-attr, index] + foreach_copy = new_graph_call_function( + new_graph, + torch.ops.aten._foreach_copy.default, + (param_all_gather_inputs, cast_input_nodes), + {}, + ) + foreach_copy.meta["val"] = [n.meta["val"] for n in ag_input_nodes] # type: ignore[union-attr] + getitems = [ + new_graph_call_function( + new_graph, + operator.getitem, + (foreach_copy, i), + {}, + ) + for i in range(len(ag_input_nodes)) + ] + + for new_n, old_n in zip(getitems, ag_input_nodes): + env[old_n] = new_n # type: ignore[index] # noqa: PERF403 + else: + param_all_gather_inputs_orig = [ + node_copy( + env, + new_graph, + ag_input_node, # type: ignore[arg-type] + lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type] + ) + for ag_input_node in ag_input_nodes + ] + cast_bucket_id_is_scheduled[bucket_id] = True + else: + continue + elif node in ag_node_to_wait_node: + assert node in ag_node_to_bucket_id + bucket_id = ag_node_to_bucket_id[node] + if bucket_id not in bucket_id_is_scheduled: + ag_input_nodes, group_size, group_name, orig_wait_nodes = ( + bucket_id_to_bucketed_op_info[bucket_id] + ) + device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr] + rank = device.index + dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr] + # TODO: if we want to support mixed dtype in the same bucket, + # we need to first view all all_gather inputs as uint8 (common denominator), + # then do the all_gather, then view the output back to the original dtype. + # Look at FSDP2 to see how to do this. + assert all(n.meta["val"].dtype == dtype for n in ag_input_nodes), ( # type: ignore[union-attr] + "All all_gather inputs in the same bucket must have the same dtype" + ) + # must schedule all the all_gather input nodes first, before the bucketed all_gather node + param_all_gather_inputs_orig = [ + node_copy( + env, + new_graph, + ag_input_node, # type: ignore[arg-type] + lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type] + ) + for ag_input_node in ag_input_nodes + ] + # schedule the bucketed all_gather node + param_all_gather_inputs_flattened = [ + new_graph_call_function( + new_graph, torch.ops.aten.reshape.default, (n, [-1]), {} + ) + for n in param_all_gather_inputs_orig + ] + inp_split_sizes = [ + n.meta["val"].numel() for n in param_all_gather_inputs_orig + ] + param_all_gather_outputs = [ + new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + ([n.meta["val"].numel() * group_size],), + { + "dtype": n.meta["val"].dtype, + "device": n.meta["val"].device, + "pin_memory": False, + }, + ) + for n in param_all_gather_inputs_orig + ] + # TODO: This assumes dim-0 sharding. + # If we need to support sharding on another dim, we should look at how FSDP2 does it + # (e.g. search for `shard_dim` in FSDP2 codebase) + param_all_gather_outputs_shape_orig = [ + (n.meta["val"].shape[0] * group_size,) + n.meta["val"].shape[1:] + for n in param_all_gather_inputs_orig + ] + all_gather_input_numel = sum(inp_split_sizes) + + all_gather_output = new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + ([all_gather_input_numel * group_size],), + { + "dtype": dtype, + "device": device, + "pin_memory": False, + }, + ) + all_gather_copy_in = new_graph_call_function( + new_graph, + torch.ops.fsdp.all_gather_copy_in.default, + ( + param_all_gather_inputs_flattened, + all_gather_output, + inp_split_sizes, + all_gather_input_numel, + rank, + ), + {}, + ) + all_gather_input = new_graph_call_function( + new_graph, + operator.getitem, + (all_gather_copy_in, 0), + {}, + ) + all_gather_into_tensor_out = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + (all_gather_input, group_size, group_name), + {"out": all_gather_output}, + ) + wait_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.wait_tensor.default, + (all_gather_into_tensor_out,), + {}, + ) + all_gather_output_reshaped = new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (wait_tensor, [group_size, -1]), + {}, + ) + outs_flattened = [ + new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (n, [group_size, -1]), + {}, + ) + for n in param_all_gather_outputs + ] + split_with_sizes_copy = new_graph_call_function( # noqa: F841 + new_graph, + torch.ops.fsdp.split_with_sizes_copy.default, + (all_gather_output_reshaped, inp_split_sizes), + {"dim": 1, "out": outs_flattened}, + ) + outs = [ + new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (n, orig_shape), + {}, + ) + for n, orig_shape in zip( + outs_flattened, param_all_gather_outputs_shape_orig + ) + ] + assert len(orig_wait_nodes) == len(outs) + assert len(orig_wait_nodes) > 0 + for out, orig_wait_node in zip(outs, orig_wait_nodes): + env[orig_wait_node] = out # noqa: PERF403 + bucket_id_is_scheduled[bucket_id] = True + else: + continue + gm.graph = new_graph diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 2d9409523c159f..ccea7d7e70af50 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -73,7 +73,7 @@ class CommBlock: def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]: """ Given a collective node (e.g., allreduce), find out all the nodes belong to - this communcation. + this communication. Args: comm_node(fx.Node): The target communication/collective node. @@ -304,7 +304,7 @@ def _scatter_fused_allreduce_waits( """ # Before we mass up the order, we need to get the index of the last wait node - # in orig_comm_blocks. This index will be later used to determinee what users + # in orig_comm_blocks. This index will be later used to determine what users # nodes need to be move to maintain a correct topological sort order. last_wait_node_idx = 0 for node in graph.nodes: diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index 8b6437fc2582ab..f05048a85e0e72 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -119,12 +119,55 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true): ) -@functools.lru_cache(None) +@functools.cache def addmm_patterns_init(): + """ + addmm related patterns. + To avoid duplication, also includes int8 WoQ GEMM pattern without bias. + """ device = next( (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu" ) val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False) + + def check_int8_woq_concat_linear_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if not is_cpu or not config.cpp.enable_concat_linear: + # Currently, this pattern is only supported on CPU + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + if not all( + match.kwargs[wgt].target == torch.ops.prims.convert_element_type.default + for wgt in weight_inputs + ): + return False + + if not all( + next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype + is torch.int8 + for wgt in weight_inputs + ): + return False + + if not all( + match.kwargs[wgt].meta["val"].dtype is torch.bfloat16 + for wgt in weight_inputs + ): + return False + + equal_shape_inputs = [weight_inputs] + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + if not all( + inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps + ): + return False + return True def check_concat_weights(match): is_cpu = match.kwargs["inp"].meta["val"].is_cpu @@ -153,9 +196,27 @@ def check_concat_weights(match): for inp in inps ): return False - return True + def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3): + return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3) + + def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_s = torch.cat((s1, s2, s3), dim=0) + mm = (inp @ cat_w).mul(cat_s) + return mm.chunk(3, dim=1) + + register_replacement( + int8_woq_fusion_pattern, + int8_woq_fusion_replacement, + [val(), val(), val(), val(), scale(), scale(), scale()], + fwd_only, + pass_patterns[0], + extra_check=check_int8_woq_concat_linear_weights, + exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"), + ) + def matmul_fuse_pattern(inp, w1, w2, w3): return (inp @ w1, inp @ w2, inp @ w3) diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py new file mode 100644 index 00000000000000..162f62a2ea9bc8 --- /dev/null +++ b/torch/_inductor/fx_passes/fsdp.py @@ -0,0 +1,68 @@ +import logging +from typing import Callable + +import torch +from torch._inductor.fx_passes.bucketing import ( + bucket_all_gather_by_mb, + merge_all_gather, +) + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def is_graph_input(node: torch.fx.Node) -> bool: + return node.op == "placeholder" + + +def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool: + # Assume all_gather_into_tensor input is either graph input + # or dtype conversion of graph input + ag_node = wait.args[0] # type: ignore[arg-type, union-attr] + return ( + is_graph_input(ag_node.args[0]) # type: ignore[arg-type, union-attr] + or ( # type: ignore[arg-type, union-attr] + ag_node.args[0].op == "call_function" # type: ignore[arg-type, union-attr] + and ag_node.args[0].target # type: ignore[arg-type, union-attr] + == torch.ops.prims.convert_element_type.default # type: ignore[arg-type, union-attr] + and is_graph_input(ag_node.args[0].args[0]) # type: ignore[arg-type, union-attr] + ) + ) + + +def bucket_fsdp_all_gather( + gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float] +) -> None: + """ + Bucketing pass for SimpleFSDP all_gather ops. + + Attributes: + gm (torch.fx.GraphModule): Graph module of the graph. + all_gather_bucket_cap_mb_callback (Callable[[int], float]): callback function that + takes in bucket id and returns size of a bucket in megabytes. + + Usage: + ``` + from torch._inductor.fx_passes.bucketing import ( + bucket_all_gather, + bucket_size_determinator, + ) + + + def _bucket_all_gather(graph): + return bucket_all_gather(graph.owning_module, bucket_size_determinator) + + + torch._inductor.config.post_grad_custom_post_pass = _bucket_all_gather + ``` + """ + + ag_buckets = bucket_all_gather_by_mb( + gm, + all_gather_bucket_cap_mb_callback, + filter_wait_node=is_fsdp_all_gather_wait, + ) + if len(ag_buckets) == 0: + return + merge_all_gather(gm, ag_buckets) diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index e807e00c9bb518..4ed950afe9a186 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -582,6 +582,112 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): ) +def _sfdp_pattern_21(query, key, value, attn_mask): + # for T5 with inplace add + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value) + + +def _sfdp_replacement_21(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ) + + +def _sfdp_pattern_22(query, key, value, attn_mask): + # for T5 with inplace add and return key and value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_22(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_pattern_23(query, key, value): + # for T5 with inplace add and + # return key and value and + # attn_mask is generated by atem.full(..., 0) + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + fp32_score = score.float() + score = fp32_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_23(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + def _sfdp_params_check(match): assert all(k in match.kwargs for k in ("query", "key", "value")) query = match.kwargs["query"].meta["val"] @@ -848,13 +954,6 @@ def _get_sfdp_patterns(): d, _sfdp_extra_check(aten.div.Tensor), ), - ( - _sfdp_pattern_20, - _sfdp_replacement_20, - [g(), g(), g(), m()], - d, - _sfdp_extra_check(aten.div.Tensor), - ), ( _sfdp_pattern_18, _sfdp_replacement_18, @@ -876,6 +975,34 @@ def _get_sfdp_patterns(): d, _sfdp_params_check, ), + ( + _sfdp_pattern_20, + _sfdp_replacement_20, + [g(), g(), g(), m_2d()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), ] mask_fp32_patterns = ["pattern_16"] if dtype == torch.half: @@ -956,7 +1083,7 @@ def _get_sfdp_patterns(): ) -@functools.lru_cache(None) +@functools.cache def _sfdp_init(): for key, register_replacement_kwargs in _get_sfdp_patterns(): gen_register_replacement(key, **register_replacement_kwargs) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 0d6e7481785445..357a9d66cdad74 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -1052,7 +1052,7 @@ def __init__(self, op, **kwargs): def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): - # check the input has the same shape and its uers have the same target + # check the input has the same shape and its users have the same target # check all clamp operators have the same min and max values, and # nan_to_num operators use the same default value. child = next(iter(node.users.keys())) diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 7e53d7bb8fcc09..c9d7187de0d9bc 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -11,10 +11,15 @@ import torch import torch._guards import torch.utils._pytree as pytree +from torch._dynamo.utils import counters from torch._inductor.constant_folding import ConstantFolder from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict from torch._inductor.utils import get_gpu_type -from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + statically_known_true, +) from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet @@ -201,7 +206,7 @@ def remove_redundant_views(gm: torch.fx.GraphModule): class UniformValueConstantFolder(ConstantFolder): """ - Runs constant folding and replaces tensors that have a unifrom value + Runs constant folding and replaces tensors that have a uniform value with a tensor constructor call: aten.full([shape], value, ...) """ @@ -694,18 +699,47 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp def definitely_equal( - lhs_sizes: Sequence[Union[torch.SymInt, bool]], - rhs_sizes: Sequence[Union[torch.SymInt, bool]], + old_sizes: Sequence[Union[torch.SymInt, int]], + new_sizes: Sequence[Union[torch.SymInt, torch.fx.Node, int]], ) -> bool: """ - Leverage guard_or_false to compare if two lists of int/symint are equal. + Leverage guard_or_true/false to compare if two lists of int/symint are equal. Useful to compare sizes, strides etc. + + Can handle -1 in new_sizes which happens in the size arguments of a + view op. old_sizes is supposed to be the tensor shape and should not + contain -1. + + new_sizes can contains fx.Node when dynamic shape is enabled. In that + case new_sizes[i].meta['val'] contains the real torch.SymInt. """ - return len(lhs_sizes) == len(rhs_sizes) and all( - guard_or_false(lhs_item == rhs_item) - for lhs_item, rhs_item in zip(lhs_sizes, rhs_sizes) - ) + num_neg1 = 0 + + if len(old_sizes) != len(new_sizes): + return False + + for lhs_item, rhs_item in zip(old_sizes, new_sizes): + if isinstance(rhs_item, torch.fx.Node): + rhs_item = rhs_item.meta["val"] + + assert isinstance(lhs_item, (int, torch.SymInt)), type(lhs_item) + assert isinstance(rhs_item, (int, torch.SymInt)), type(rhs_item) + + # It still makes sense to call guard_or_true/false since lhs_item + # rhs_item are torch.SymInt rather than sympy expressions when + # dynamic shape is enabled. + if guard_or_false(lhs_item == rhs_item): + continue + + if guard_or_true(rhs_item != -1): + return False + + num_neg1 += 1 + + if num_neg1 > 1: + return False + return True @register_graph_pattern( @@ -716,7 +750,7 @@ def pointless_view(match: Match, arg, size): """Remove no-op view""" node = match.output_node() arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] - if definitely_equal(size, arg_size): + if definitely_equal(arg_size, size): node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] match.erase_nodes() @@ -738,6 +772,7 @@ def pointless_view_pair(match: Match, arg, size1, size2): if definitely_equal(arg_size, size2): node.replace_all_uses_with(arg) match.erase_nodes() + counters["inductor"]["removed_pointless_view_pair"] += 1 @register_graph_pattern( diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 5eb2dce80dfef9..af40d987f7d18f 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -440,7 +440,7 @@ def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul": A_node=cast("torch.fx.Node", match[0].args[0]), B_node=cast("torch.fx.Node", mm_node.args[1]), # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method. - # TOOO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. + # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. pre_mm_reshape=None, post_mm_reshape=None, ) @@ -906,7 +906,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims. # 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op. # 3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended - # 3D+ shape before applying reduce-scatter, and to prevent shape erros with subsequent ops. + # 3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops. # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape # for the fused matmul reduce scatter implementation to use. diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index b4e0f1f35023df..d2c8068f130c8b 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -12,7 +12,7 @@ aten = torch.ops.aten -@functools.lru_cache(None) +@functools.cache def _misc_patterns_init(): from .joint_graph import patterns as joint_graph_patterns from .post_grad import pass_patterns as post_grad_patterns_all diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 9e69f96d27f06f..a269b17e3a2a98 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any +from typing import Any, Callable import torch from torch._dynamo.utils import counters @@ -19,10 +19,16 @@ KeywordArg, MULTIPLE, ) +from ..utils import ( + is_mkldnn_bf16_supported, + is_mkldnn_fp16_supported, + SUPPORTED_MKLDNN_DEVICES, +) from ..virtualized import ops, V from .freezing_patterns import register_freezing_graph_pattern from .post_grad import register_lowering_pattern from .quantization import ( + _register_int8_woq_concat_linear_pattern, _register_quantization_lowerings, _register_quantization_weight_pack_pass, _register_woq_lowerings, @@ -38,6 +44,126 @@ _linear_args = [Arg() for _ in range(6)] _conv_transpose_args = [Arg() for _ in range(11)] + class MkldnnDeviceOpBase: + def get_linear_transpose_weight(self, weight_node): + raise NotImplementedError + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + raise NotImplementedError + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + raise NotImplementedError + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + raise NotImplementedError + + class CpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def get_linear_transpose_weight(self, weight_node): + packed_weight_node = weight_node + assert packed_weight_node.target == mkldnn._reorder_linear_weight + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.target == aten.permute.default + return transpose_weight_node + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + packed_weight_op = mkldnn._reorder_convolution_weight + if is_transposed: + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + + # mkldnn_reorder_conv_weight(self, padding, stride, dilation, groups, input_size) + packed_weight_inputs = (weight,) + tuple(constant_args) + (input_size,) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode + packed_weight_op = ( + mkldnn._reorder_linear_weight + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation + ) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) + transpose_weight_node = packed_weight_node.args[0] + if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation: + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + + return graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + + class XpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + assert not is_transposed, ( + "'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + return weight + + def _get_mkldnn_device_op(device_type: str) -> MkldnnDeviceOpBase: + """ + Returns the MKLDNN device operation class based on the current device type. + """ + if device_type == "cpu": + return CpuMkldnnDeviceOp() + elif device_type == "xpu": + return XpuMkldnnDeviceOp() + else: + raise RuntimeError(f"MKLDNN is not supported on {device_type} device.") + def _is_valid_grouped_gemm_fusion(computation_nodes): """ Here we check: @@ -61,7 +187,7 @@ def _is_valid_grouped_gemm_fusion(computation_nodes): def grouped_gemm_pass(graph: torch.fx.Graph): """ - Group GEMM has multi output nodes which is compilicated to define a Pattern. + Group GEMM has multi output nodes which is complicated to define a Pattern. Use below way to connect the pattern to the lowering. TODO: Use MultiOutputPattern, current limitation is the pattern requires fixed number of output nodes. Extend to support Group GEMM for pattern matcher. @@ -923,10 +1049,11 @@ def get_val(val): def is_linear_add_bias(match): add_node = match.output_node() linear_node = add_node.args[0] - packed_weight_node = linear_node.args[1] - assert packed_weight_node.target == mkldnn._reorder_linear_weight - transpose_weight_node = packed_weight_node.args[0] - assert transpose_weight_node.target == aten.permute.default + device_type = add_node.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight( + linear_node.args[1] + ) weight_meta = transpose_weight_node.args[0].meta.get("val") bias_node = add_node.args[1] if isinstance(bias_node, int): @@ -935,10 +1062,7 @@ def is_linear_add_bias(match): bias_meta = add_node.args[1].meta.get("val") if weight_meta is None or bias_meta is None: return False - assert weight_meta.dtype in ( - torch.bfloat16, - torch.float16, - ) + if bias_meta.dtype != weight_meta.dtype: return False return ( @@ -997,13 +1121,13 @@ def _is_packable_mkldnn_rnn_layer(match): # Check dtype if any( lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 - and not mkldnn._is_mkldnn_bf16_supported() + and not is_mkldnn_bf16_supported("cpu") for POS_ARG in POS_ARGS ): return False if any( lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 - and not mkldnn._is_mkldnn_fp16_supported() + and not is_mkldnn_fp16_supported("cpu") for POS_ARG in POS_ARGS ): return False @@ -1015,6 +1139,11 @@ def _is_packable_convolution(match): Check if the node is supported for MKLDNN convolution. """ conv_node = match.output_node() + device_type = conv_node.meta.get("val").device.type + # The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device. + if match.kwargs["is_transposed"] and device_type == "xpu": + return False + input_meta_value = conv_node.args[0].meta.get("val") weight_meta_value = conv_node.args[1].meta.get("val") if input_meta_value is None or weight_meta_value is None: @@ -1025,21 +1154,22 @@ def _is_packable_convolution(match): for meta_value in [input_meta_value, weight_meta_value]: if ( meta_value is None - or meta_value.device.type != "cpu" + or meta_value.device.type not in SUPPORTED_MKLDNN_DEVICES or (meta_value.dim() != 4 and meta_value.dim() != 5) ): return False + if ( input_meta_value.dtype == torch.bfloat16 or weight_meta_value.dtype == torch.bfloat16 ): - if not mkldnn._is_mkldnn_bf16_supported(): + if not is_mkldnn_bf16_supported(device_type): return False if ( input_meta_value.dtype == torch.float16 or weight_meta_value.dtype == torch.float16 ): - if not mkldnn._is_mkldnn_fp16_supported(): + if not is_mkldnn_fp16_supported(device_type): return False is_transposed = conv_node.args[-3] if is_transposed: @@ -1098,10 +1228,15 @@ def is_const_or_cat_by_const(weight): torch.bfloat16, torch.float16, ) + bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] + use_bf16_for_fp32_weight = ( + bf32_matmul_enabled and weight_meta_value.dtype == torch.float32 + ) + compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. # on aarch64, use mkldnn op for fp32 as well if acl is enabled if ( - not is_lp_weight + not compute_with_lp and not mkldnn._is_mkldnn_acl_supported() and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) ): @@ -1123,17 +1258,18 @@ def is_const_or_cat_by_const(weight): ): return False + device_type = input_meta_value.device.type if ( input_meta_value.dtype == torch.bfloat16 or weight_meta_value.dtype == torch.bfloat16 ): - if not mkldnn._is_mkldnn_bf16_supported(): + if not is_mkldnn_bf16_supported(device_type): return False if ( input_meta_value.dtype == torch.float16 or weight_meta_value.dtype == torch.float16 ): - if not mkldnn._is_mkldnn_fp16_supported(): + if not is_mkldnn_fp16_supported(device_type): return False return True @@ -1178,26 +1314,29 @@ def convolution(match, *args, **kwargs): assert isinstance(is_transposed, bool) graph = match.graph conv_node = match.output_node() + device_type = conv_node.args[0].meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) input_size = conv_node.args[0].meta.get("val").shape with graph.inserting_before(conv_node): constant_args = [args[4], args[3], args[5], args[-1]] - packed_weight_op = mkldnn._reorder_convolution_weight packed_conv_op = mkldnn._convolution_pointwise.default if is_transposed: constant_args.insert(1, args[-2]) # output_padding - packed_weight_op = mkldnn._reorder_convolution_transpose_weight packed_conv_op = mkldnn._convolution_transpose_pointwise.default + if not has_free_symbols(input_size): - packed_weight_inputs = ( - (args[1],) + tuple(constant_args) + (input_size,) - ) - packed_weight_node = graph.create_node( - "call_function", packed_weight_op, args=packed_weight_inputs + packed_weight_node = mkldnn_device_op.pack_conv_weight( + graph, + is_transposed, + args[1], + constant_args, + input_size, ) else: assert not is_transposed # For dynamic shape case, we need to pack weight in runtime. packed_weight_node = args[1] + packed_conv_inputs = ( (args[0], packed_weight_node, args[2]) + tuple(constant_args) @@ -1299,6 +1438,8 @@ def linear(match, *args, **kwargs): else args[0] ) weight = args[1] if linear_node.target == aten.mm.default else args[2] + device_type = input.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) with graph.inserting_before(linear_node): transpose_weight_node = graph.create_node( "call_function", aten.permute.default, (weight, (1, 0)) @@ -1308,50 +1449,25 @@ def linear(match, *args, **kwargs): torch.bfloat16, torch.float16, ) + bf32_matmul_enabled = ( + torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] + ) + use_bf16_for_fp32_weight = ( + bf32_matmul_enabled and weight_dtype == torch.float32 + ) + compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight batch_size = input.meta.get("val").shape[0] if has_free_symbols(batch_size): - assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), ( + assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), ( f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" ) - # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. - packed_weight_inputs = ( - transpose_weight_node, - batch_size.node.shape_env.size_hint(batch_size.node.expr) - if has_free_symbols(batch_size) - else batch_size, - ) - # MKL packed matrix can't be copied to a different address because the internal implementation - # depends on the alignment of internally-stored metadata. - # In aot mode, we need to firstly save the packed weight, when loading it, - # it will be in a different address which doesn't work. - # Disable MKL prepack linear in AOT mode - packed_weight_op = ( - mkldnn._reorder_linear_weight - if ( - is_lp_weight - or mkldnn._is_mkldnn_acl_supported() - or V.aot_compilation - ) - else torch.ops.mkl._mkl_reorder_linear_weight + packed_weight_node = mkldnn_device_op.pack_linear_weight( + graph, compute_with_lp, transpose_weight_node, batch_size ) - packed_weight_node = graph.create_node( - "call_function", packed_weight_op, args=packed_weight_inputs + packed_linear_node = mkldnn_device_op.pack_linear( + graph, compute_with_lp, batch_size, input, packed_weight_node, bias ) - packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) - if ( - is_lp_weight - or mkldnn._is_mkldnn_acl_supported() - or V.aot_compilation - ): - packed_linear_inputs += (bias, "none", [], "") - packed_linear_op = mkldnn._linear_pointwise.default - else: - packed_linear_inputs += (transpose_weight_node, bias, batch_size) - packed_linear_op = torch.ops.mkl._mkl_linear - packed_linear_node = graph.create_node( - "call_function", packed_linear_op, packed_linear_inputs - ) linear_node.replace_all_uses_with(packed_linear_node) packed_linear_node.meta.update(linear_node.meta) graph.erase_node(linear_node) @@ -1398,7 +1514,7 @@ def forward(self, x): user_node.replace_all_uses_with(node) gm.graph.erase_node(user_node) - @functools.lru_cache(None) + @functools.cache def _mkldnn_fusion_init(): # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. # Otherwise even the matmul or innerproduct can not be accelerated with acl @@ -1414,9 +1530,10 @@ def _mkldnn_fusion_init(): _register_quantization_lowerings() _register_woq_lowerings() - @functools.lru_cache(None) + @functools.cache def _mkldnn_weight_pack_init(): if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): _register_weight_pack_pass() _recover_linear() _register_quantization_weight_pack_pass() + _register_int8_woq_concat_linear_pattern() diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 655a0e44d24da6..d2dfc3d9e4d0d9 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -102,7 +102,7 @@ def valid_shape_and_stride(t: Optional[Tensor]) -> bool: symbolic_cnt += 1 else: return False - # filter out cases where all dimentions are symbolic + # filter out cases where all dimensions are symbolic if symbolic_cnt == len(t.size()): return False return all( @@ -226,7 +226,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: and K > M and K > N and torch.cuda.get_device_capability() < (9, 0) - ): # doesnt repro on h100s: + ): # doesn't repro on h100s: return True # Fails with AMD @@ -239,7 +239,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: # dram_gbps might be underestimating bandwidth because of cache. # if we estimate machine balance too low we might miss some speedups, - # if we extimate too high there will be unnecessary compilation time increase. + # if we estimate too high there will be unnecessary compilation time increase. # TODO - finetune coefficient here. As a reference point, Triton mm model assumes # 80% of reads are in cache and cache is 4x faster than dram_gbps machine_balance = machine_balance * 0.5 @@ -247,7 +247,7 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: return arithmetic_intensity > machine_balance -@functools.lru_cache(None) +@functools.cache def get_pad_cache() -> torch._inductor.codecache.LocalCache: return torch._inductor.codecache.LocalCache() @@ -382,7 +382,7 @@ def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool: and N % 2 == 1 and K >= large_k_threshold_to_pad and torch.cuda.get_device_capability() < (9, 0) - ): # doesnt repro on h100s: + ): # doesn't repro on h100s: return True return False @@ -711,7 +711,7 @@ def fallback() -> str: ah_ori_time = autoheuristic.get_collected_feedback(orig_choice) ah_pad_time = autoheuristic.get_collected_feedback(pad_choice) - # if precondition is not satisifed, autoheuristic does not collect data + # if precondition is not satisfied, autoheuristic does not collect data if ah_ori_time is not None and ah_pad_time is not None: if ori_time is None: set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time) @@ -851,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: ) -@functools.lru_cache(None) +@functools.cache def _pad_mm_init() -> None: from .joint_graph import patterns diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index ac8e5377338446..0bc2ce0b1ab49f 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -22,6 +22,7 @@ from torch.utils._ordered_set import OrderedSet from .. import config, ir, pattern_matcher +from ..codegen.common import custom_backend_passes from ..comms import remove_fsdp2_unsharded_param_graph_input_usage from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage from ..lowering import lowerings as L @@ -48,6 +49,7 @@ ) from ..utils import ( decode_device, + get_all_devices, get_gpu_type, is_gpu, is_pointwise_use, @@ -109,15 +111,21 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): post_grad_custom_pre_pass ) - if ( - config.cpp.enable_grouped_gemm_template - and config.max_autotune - and "CPP" in config.max_autotune_gemm_backends - and torch._C._has_mkldnn - ): - from .mkldnn_fusion import grouped_gemm_pass + if torch._C._has_mkldnn: + if ( + config.cpp.enable_grouped_gemm_template + and config.max_autotune + and "CPP" in config.max_autotune_gemm_backends + ): + from .mkldnn_fusion import grouped_gemm_pass + + grouped_gemm_pass(gm.graph) - grouped_gemm_pass(gm.graph) + if config.cpp.enable_concat_linear: + from .quantization import concat_linear_woq_int4 + + # Concat linear optimization for WOQ int4 + concat_linear_woq_int4(gm) if config.pattern_matcher: lazy_init() @@ -182,10 +190,17 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): fake_tensor_updater.incremental_update() + for device, custom_backend_pass in custom_backend_passes.items(): + if custom_backend_pass is not None: + gm_devices = [d.type for d in get_all_devices(gm)] + if device in gm_devices: + pass_name = "custom_backend_passes_" + device + GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) + # Keep these last, since they introduces mutation. Look at # ./fx_passes/README.md for a discussion of mutation invariants. GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( - reinplace_inplaceable_ops + functools.partial(reinplace_inplaceable_ops, fake_tensor_updater), ) GraphTransformObserver( gm, "decompose_triton_kernel_wrapper_functional" @@ -204,6 +219,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): decompose_map_to_while_loop ) + # Fx all_gather bucketing introduces mutation op + # Keeping it in the end to keep invariant of functional graph for previous passes. + if config.bucket_all_gathers_fx != "none": + from torch._inductor.fx_passes.bucketing import ( + bucket_all_gather, + bucket_size_determinator, + ) + from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather + + p = ( + bucket_fsdp_all_gather + if config.bucket_all_gathers_fx == "fsdp" + else bucket_all_gather + ) + d = ( + config.bucket_all_gathers_fx_bucket_size_determinator + or bucket_size_determinator + ) + GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass( + lambda graph: p(graph.owning_module, d) + ) + gm.recompile() gm.graph.lint() @@ -608,7 +645,7 @@ def visit(other_node): # only reorder nodes before the first copy_ in the graph. # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, - # and this reordering doesnt work well with mutation + # and this reordering doesn't work well with mutation first_copy = next( iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)), None, @@ -644,7 +681,7 @@ def register_lowering_pattern( def is_valid_mm_plus_mm(match: Match): - if not torch._inductor.utils.use_max_autotune(): + if not (config.max_autotune or config.max_autotune_gemm): return False *_b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape @@ -1216,19 +1253,45 @@ def decomp(*flat_args): graph_pass.apply(graph) - # We need to remove the get_attr registered for _constant_schema and the - # auto_functioanlized's graph module (it's replaced with original ) when auto_functionalize a hop. - _to_remove = [] + # Remove unused get_attr nodes and their corresponding attributes from the graph module. + # When auto_functionalizing a hop, we need to clean up get_attr nodes for _constant_schema + # and the auto_functionalized graph module that are no longer referenced. + unused_get_attr_nodes = [] + removable_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet() + protected_attrs: OrderedSet[torch.fx.node.Target] = OrderedSet() + + # First pass: identify unused get_attr nodes and track attribute usage for node in graph.nodes: - if node.op == "get_attr" and len(node.users) == 0: - _to_remove.append(node) - if hasattr(graph.owning_module, node.target) and isinstance( - getattr(graph.owning_module, node.target), torch.fx.GraphModule + if node.op != "get_attr": + continue + + if len(node.users) == 0: + # Node is unused, mark for removal + unused_get_attr_nodes.append(node) + + # Check if the attribute can be removed from the module + if ( + hasattr(graph.owning_module, node.target) + and isinstance( + getattr(graph.owning_module, node.target), torch.fx.GraphModule + ) + and node.target not in protected_attrs ): - delattr(graph.owning_module, node.target) - for node in _to_remove: + removable_attrs.add(node.target) + else: + # Node is used, protect its attribute from removal + if node.target in removable_attrs: + removable_attrs.remove(node.target) + protected_attrs.add(node.target) + + # Second pass: clean up unused nodes and attributes + for node in unused_get_attr_nodes: graph.erase_node(node) + for attr_name in removable_attrs: + assert isinstance(attr_name, str) + delattr(graph.owning_module, attr_name) + graph.lint() for _ in graph.find_nodes( @@ -1427,7 +1490,7 @@ def register_partial_reduction_pattern(): def reuse_partial(match, input, reduced_dims, keepdim): partial_red, full_red = match.output_nodes() - # if theyre small, reuse not worth it + # if they're small, reuse not worth it if not statically_known_true(input.meta["val"].numel() >= 4096): return True @@ -1481,7 +1544,9 @@ def is_index_put_and_requires_h2d_sync_for_gpu_value(node): class ConstructorMoverPass: - def __init__(self, target: str, allow_outputs: bool = False) -> None: + def __init__( + self, target: str, allow_outputs: bool = False, allow_inputs: bool = False + ) -> None: """ Move constructors from cpu to the target_device. @@ -1494,9 +1559,11 @@ def __init__(self, target: str, allow_outputs: bool = False) -> None: - target: target device type - allow_outputs: allow outputs to be moved + - allow_inputs: allow inputs to be moved """ self.target = target + self.allow_inputs = allow_inputs self.allow_outputs = allow_outputs assert isinstance(target, str), ( @@ -1518,6 +1585,38 @@ def allow_cpu_device(self, node: fx.Node) -> bool: torch.ops.aten.slice_scatter.default, ) + def is_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node is on the target device. + """ + node_device = self.get_node_device(node) + return node_device is not None and node_device.type == self.target + + def is_cpu_scalar_tensor(self, node: fx.Node) -> bool: + """ + Returns whether a node is a cpu scalar tensor. + """ + device = self.get_node_device(node) + is_cpu = device is not None and device.type == "cpu" + ten = node.meta.get("val") + is_scalar = isinstance(ten, torch.Tensor) and len(ten.size()) == 0 + return is_cpu and is_scalar + + def all_inputs_are_cpu_scalar_or_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node's inputs are either cpu scalar tensors or + on the target device. + """ + inputs = ( + inp + for inp in itertools.chain(node.args, node.kwargs.values()) + if isinstance(inp, fx.Node) + ) + return all( + self.is_cpu_scalar_tensor(inp) or self.is_on_target_device(inp) + for inp in inputs + ) + def cannot_be_moved(self, node: fx.Node) -> bool: """ Returns whether a node can be moved to the target device. @@ -1533,6 +1632,7 @@ def cannot_be_moved(self, node: fx.Node) -> bool: and node.target.namespace in ("prims", "aten") ): return True + if is_index_put_and_requires_h2d_sync_for_gpu_value(node): return True @@ -1569,12 +1669,22 @@ def add_cpu_inp(node): def __call__(self, graph: fx.Graph) -> None: target_devices = OrderedSet[torch.device]() constructors = [] + cpu_placeholders: OrderedSet[fx.Node] = OrderedSet() for node in graph.nodes: device = self.get_node_device(node) if device and device.type == self.target: target_devices.add(device) + if ( + self.allow_inputs + and node.op == "placeholder" + and self.is_cpu_scalar_tensor(node) + ): + cpu_placeholders.add(node) + constructors.append(node) + continue + if not ( isinstance(node.target, torch._ops.OpOverload) and node.target.namespace in ("prims", "aten") @@ -1595,10 +1705,35 @@ def __call__(self, graph: fx.Graph) -> None: movable_constructors = self.find_movable_constructors(graph, constructors) + target_device = next(iter(target_devices)) for node in movable_constructors: - kwargs = node.kwargs.copy() - kwargs["device"] = next(iter(target_devices)) - node.kwargs = kwargs + if node in cpu_placeholders: + with graph.inserting_after(node): + gpu_node = graph.call_function( + torch.ops.prims.device_put.default, (node, target_device) + ) + node.replace_all_uses_with( + gpu_node, + lambda x: x != gpu_node + and x.target != torch.ops.aten.copy_.default, + ) + + # noop elimination if there are other device_put for gpu_node to + # target device. Alternatively, we could just move the other device_put + # earlier in the graph, but that is not supported in fx graph yet. + noop_device_puts = [ + user + for user in gpu_node.users + if user.target == torch.ops.prims.device_put.default + and user.args[1] == target_device + ] + for noop in noop_device_puts: + noop.replace_all_uses_with(gpu_node) + graph.erase_node(noop) + else: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs def find_movable_constructors( self, graph: fx.Graph, constructors: list[fx.Node] @@ -1649,12 +1784,15 @@ def make_dependencies_equivalent( # this node was used on a op which takes in multiple devices and output a gpu # tensor. we can convert its cpu input to gpu without making further changes - node_device = self.get_node_device(user) - if ( - self.allow_cpu_device(user) - and node_device - and node_device.type == self.target + if self.allow_cpu_device(user) and self.is_on_target_device(user): + del cpu_indeg[user] + elif ( + self.allow_inputs + and self.all_inputs_are_cpu_scalar_or_on_target_device(user) ): + # this node takes only cpu scalar tensors or gpu tensors as inputs + # and outputs a gpu tensor. we can convert its cpu scalar inputs to gpu + # without making further changes del cpu_indeg[user] else: # otherwise, we should continue look at its downstream uses @@ -1683,4 +1821,21 @@ def move_constructors_to_gpu(graph: fx.Graph) -> None: """ Moves intermediary tensors which are constructed on the cpu to gpu when safe """ - ConstructorMoverPass(get_gpu_type())(graph) + + # cudagraph does not support cpu tensors. In this pass, we update the graph + # by explicitly moving cpu scalar tensors to gpu when profitable, relying on + # graph partition to split off this data copy, and cudagraphifying + # the remaining gpu ops. + allow_inputs_outputs = ( + True + if ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + else False + ) + ConstructorMoverPass( + get_gpu_type(), + allow_inputs=allow_inputs_outputs, + allow_outputs=allow_inputs_outputs, + )(graph) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index b51d7bc21a1e88..2d1709962e64bb 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -394,7 +394,7 @@ def fetch_attr(target: str, mod): for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError( - f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" ) attr_itr = getattr(attr_itr, atom) return attr_itr diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 88c5f8497ac172..862df99a41e50d 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -12,8 +12,17 @@ from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.fx.node import map_arg +from .. import config from ..lowering import lowerings as L, require_channels_last -from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + KeywordArg, + ListOf, + Match, + stable_topological_sort, +) from ..utils import pad_listlike from .freezing_patterns import register_freezing_graph_pattern from .post_grad import register_lowering_pattern @@ -1068,6 +1077,53 @@ def _register_quantization_reshape(): ) +def _is_valid_concat_linear_int8_woq_optimization_pattern(): + def fn(match): + if not config.cpp.enable_concat_linear: + return False + assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") + for key in ["x", "w1", "w2", "w3", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + w1 = match.kwargs["w1"].meta["val"] + w2 = match.kwargs["w2"].meta["val"] + w3 = match.kwargs["w3"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + if len(match.kwargs["scales"].meta["val"].size()) > 1: + return False + num_scales = match.kwargs["scales"].meta["val"].numel() + w1_cols = match.kwargs["w1"].meta["val"].size()[0] + w2_cols = match.kwargs["w2"].meta["val"].size()[0] + w3_cols = match.kwargs["w3"].meta["val"].size()[0] + # Technically, the shapes of the three weights need not be equal. + # But currently, we only enable replacement in this case. + if w1_cols != w2_cols or w2_cols != w3_cols: + return False + if 3 * w1_cols != num_scales: + return False + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and w1.dtype == torch.int8 + and w2.dtype == torch.int8 + and w3.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == w1.device + and w1.device == w2.device + and w2.device == w3.device + and x.device == scales.device + ) + + return fn + + def _is_valid_woq_optimization_pattern(): def fn(match): assert all(k in match.kwargs for k in ("x", "weight", "scales")) @@ -1094,6 +1150,73 @@ def fn(match): return fn +def _register_concat_linear_int8_woq_lowering( + pattern, computation_woq, computation_reshape +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(), + pass_number=4, + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + w1 = kwargs["w1"] + w2 = kwargs["w2"] + w3 = kwargs["w3"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = ( + w1.meta["val"].size()[0] + + w2.meta["val"].size()[0] + + w3.meta["val"].size()[0] + ) + origin_x_size = tuple(x.meta["val"].size()) + x_shape = [-1, origin_x_size[-1]] + out_shape = list(origin_x_size[:-1] + (out_features,)) + mm_node_of_x = None + for candidate in iter(x.users.keys()): + if ( + candidate.target == aten.mm.default + and list(candidate._input_nodes)[1].target == aten.cat.default + ): + mm_node_of_x = candidate + break + assert mm_node_of_x is not None, "unable to find mm node" + _, cat_wgt_node = mm_node_of_x._input_nodes + scaling_node = next(iter(mm_node_of_x.users.keys())) + user_of_scaling_node = next(iter(scaling_node.users.keys())) + # Some other pass is making some changes that entails + # adding a node before it's used, but it can only be found when + # lint is run. stable_topological_sort() is being run before lint, + # so that error was not being being discovered. + # We call stable_topological_sort here as a workaround. + stable_topological_sort(match.graph) + with match.graph.inserting_before(user_of_scaling_node): + new_cat_node = match.graph.call_function( + aten.cat.default, + args=([w1, w2, w3], 0), + ) + x_reshape_node = match.graph.call_function( + computation_reshape, args=(x, x_shape) + ) + new_woq_node = match.graph.call_function( + computation_woq, + args=(x_reshape_node, new_cat_node, scales), + ) + new_woq_node.meta = copy.copy(x.meta) + output_reshape_node = match.graph.call_function( + computation_reshape, args=(new_woq_node, out_shape) + ) + scaling_node.replace_all_uses_with(output_reshape_node) + match.graph.erase_node(scaling_node) + match.graph.erase_node(mm_node_of_x) + match.graph.erase_node(cat_wgt_node) + match.graph.lint() + + return woq + + def _register_woq_lowering(pattern, computation_woq, computation_reshape): @register_lowering_pattern( pattern, @@ -1214,6 +1337,32 @@ def _register_woq_mm_int8_pattern4(): _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) +def _register_int8_woq_concat_linear_pattern(): + def _create_wgt_node(wgt_node_name: str): + return CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg(wgt_node_name), + Arg(), + ), + Arg(), + ) + + cat_wgt = CallFunction( + aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1 + ) + + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt), + KeywordArg("scales"), + ) + _register_concat_linear_int8_woq_lowering( + _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape + ) + + def _register_quantization_lowerings(): _register_quantization_unary_lowering() _register_quantization_binary_lowering() @@ -3478,7 +3627,7 @@ def _register_qlinear_binary_fusion(): ) -@functools.lru_cache(None) +@functools.cache def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 _register_dequant_promotion() @@ -3502,6 +3651,155 @@ def _register_quantization_weight_pack_pass(): _register_qlinear_binary_fusion() +def _is_valid_concat_linear_woq_int4_fusion(computation_nodes): + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + in_feature_size = wgt.meta.get("val").size(1) # type: ignore[union-attr] + group_size = computation_nodes[0].args[2] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act # share same activation + and ( + node.args[1].meta.get("val").size(1) == in_feature_size + ) # same in feature size + and (node.args[1] != wgt or gemm_idx == 0) + and node.args[1].op == "get_attr" # wgt are all constants + and node.args[2] == group_size # same group size + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + +def concat_linear_woq_int4(gm: torch.fx.GraphModule): + """ + Concat Linear optimization pass for WOQ int4 + This pass fuses the original pattern: + def ... + return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...) + into a single operation: + def ... + concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp) + return split(concat_res, split_size_list) + """ + + def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype): + # Concat the wgts and scale_zps, and repack the wgt + unpacked_wgts = [] + for packed_wgt in packed_wgts: + # Get the unpacked weight list + # Same as https://github.com/pytorch/pytorch/pull/156174 + K = packed_wgt.size(1) * 2 + N = packed_wgt.size(0) + x = torch.eye(K).to(dtype=act_dtype) + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .to(dtype=act_dtype) + .expand(K // group_size, N, 2) + .contiguous() + ) + unpacked_wgts.append( + torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + packed_wgt, + group_size, + qscales_and_zeros, + ) + .t() + .contiguous() + .to(torch.int32) # N, K + ) + concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0) + repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + concat_unpacked_wgt, 1 + ) + concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous() + return repack_w, concat_scale_zp + + graph = gm.graph + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_concat_linear_woq_int4_fusion(users): + with graph.inserting_before(node): + assert all(user.args[1].op == "get_attr" for user in users) + computation_node_0 = users[0] + packed_wgts = [getattr(gm, user.args[1].target) for user in users] + group_size = computation_node_0.args[2] + scale_zps = [getattr(gm, user.args[3].target) for user in users] + out_feature_size_list = [ + packed_wgt.size(0) for packed_wgt in packed_wgts + ] + repack_w, concat_scale_zp = concat_wgt( + packed_wgts, scale_zps, group_size, act.meta.get("val").dtype + ) + repack_w_node_name = computation_node_0.args[1].target + "_concat" + concat_scale_zp_node_name = ( + computation_node_0.args[3].target + "_concat" + ) + gm.register_buffer(repack_w_node_name, repack_w) + setattr(gm, repack_w_node_name, repack_w) + gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp) + setattr(gm, concat_scale_zp_node_name, concat_scale_zp) + + repack_w_node = graph.create_node( + "get_attr", repack_w_node_name, (), {} + ) + with graph.inserting_after(repack_w_node): + concat_scale_zp_node = graph.create_node( + "get_attr", concat_scale_zp_node_name, (), {} + ) + + with graph.inserting_after(concat_scale_zp_node): + concat_int4_gemm_node = graph.create_node( + "call_function", + computation_op, + ( + act, + repack_w_node, + group_size, + concat_scale_zp_node, + ), + ) + with graph.inserting_after(concat_int4_gemm_node): + split_node = graph.create_node( + "call_function", + torch.ops.aten.split_with_sizes.default, + ( + concat_int4_gemm_node, + out_feature_size_list, + 1, # split dim + ), + ) + with graph.inserting_after(split_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + split_node, + gemm_idx, + ), + ) + with graph.inserting_after(get_item): + clone_node = graph.create_node( + "call_function", + torch.ops.aten.clone.default, + (get_item,), + {"memory_format": torch.contiguous_format}, + ) + user.replace_all_uses_with(clone_node) + graph.erase_node(user) + + def quant_lift_up(graph_module: torch.fx.GraphModule): """ Lift up the quant node before view like nodes. It can benefit performance diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index ee258dfd41589c..b67c0dbb729ade 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,12 @@ import logging import operator from collections import defaultdict +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Callable, Union +from typing import Any, Callable, cast, Union import torch +import torch.fx.node from torch._C._dynamo.guards import compute_overlapping_tensors from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger @@ -176,7 +178,12 @@ def _decompose_scatter_mutating( def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: _, _, view_ops = node.args - return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr] + view_ops = cast(Sequence[torch.fx.node.Argument], view_ops) + return any( + target in _ALWAYS_MUTATING_SCATTER_OPS + for view in view_ops + if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload) + ) def should_reinplace_scatter(node: torch.fx.Node) -> bool: @@ -253,7 +260,7 @@ def scatter(inp, src, views): def handle_views(node: torch.fx.Node): inp = node.args[0] - node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment] node_to_view_op[node] = [ *node_to_view_op[inp], # type: ignore[index] ViewOp( @@ -267,6 +274,7 @@ def handle_view_scatter(node: torch.fx.Node): assert len(node.args) >= 2 inp, src = node.args[:2] + assert isinstance(node.target, torch._ops.OpOverload) scatter_view_op = ViewOp( _SCATTER_OP_TO_VIEW[node.target], args=node.args[2:], @@ -331,7 +339,7 @@ def can_fuse(): handle_view_scatter(node) -inplaceable_ops = { +inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = { aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), _generalized_scatter: InplaceableOp( @@ -343,7 +351,7 @@ def can_fuse(): try: c10d_functional = torch.ops._c10d_functional - inplaceable_collective_ops = { + inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = { c10d_functional.all_reduce.default: InplaceableOp( c10d_functional.all_reduce_.default, 0 ), @@ -751,8 +759,15 @@ def tensor_with_same_storage_already_reinplaced(arg): graph.erase_node(node) -def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: +def reinplace_inplaceable_ops( + fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater, + graph: torch.fx.Graph, +) -> None: with enable_python_dispatcher(): canonicalize_view_scatter_ops(graph) + # canonicalize_view_scatter_ops adds new operations to the graph. + # We run fake_tensor_updater to update the alias information. + # Correct alias information is required for `reinplace_inplaceable_ops_core`. + fake_tensor_updater.incremental_update() reinplace_inplaceable_ops_core(graph) decompose_generalized_scatter(graph) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py index 06f5efb35d4e72..9185aa3b1e3305 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -34,21 +34,22 @@ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) -expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) -view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) -view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) -bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) -view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_2, _users=2) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) @@ -57,42 +58,42 @@ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) -view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) -bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) -view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) neg_default = CallFunction(aten.neg.default, div_Tensor_1) -view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) -permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) -bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) -mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) -view_default_8 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) -permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) -bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) -view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) -permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) -bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) -view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) -permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) -permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) -bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) -view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) -_sfdp_pattern_20_training = MultiOutputPattern([view_default_5, +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_training = MultiOutputPattern([view_default_6, permute_default_6, permute_default_9, permute_default_11, @@ -136,21 +137,22 @@ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) -expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) -view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) -view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) -bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) -view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_2) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) @@ -161,45 +163,45 @@ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) -view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) -bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) -view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) -view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) -permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) -bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) -mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) -view_default_8 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) -permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) -bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) -view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) -permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) -bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) -view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) -permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) -permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) -bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) -view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) -_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_5, +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_6, permute_default_6, permute_default_9, permute_default_11, diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py new file mode 100644 index 00000000000000..ad27e6eb6bb8eb --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -0,0 +1,217 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py new file mode 100644 index 00000000000000..41a433e405433a --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -0,0 +1,229 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py new file mode 100644 index 00000000000000..dc6f27cd284924 --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -0,0 +1,225 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 8f41e788538576..7636d7cde3647e 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -247,7 +247,7 @@ def remove_split_with_size_one(match: Match, *args, **kwargs): return # remove the dummy split whose split sections size is one # theoretically nodes with no users should be removed, but we have seen the corner case - # thus we add its uers check to walk around the StopIteration error. + # thus we add its users check to walk around the StopIteration error. if len(split_sections) == 1 and len(split_node.users.keys()) > 0: # find the grand children of the split_node next_users = find_next_users(split_node) @@ -302,7 +302,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs): @register_graph_pattern( - CallFunctionVarArgs(torch.cat, users=MULTIPLE), + CallFunctionVarArgs([torch.cat, torch.concat], users=MULTIPLE), pass_dict=construct_pattern_matcher_pass("normalization_pass"), ) def normalize_cat_default(match: Match, *args, **kwargs): @@ -347,6 +347,7 @@ def is_empty_tensor(x): cat_node.args == new_args and cat_node.kwargs == new_kwargs and cat_node.op == "call_function" + and cat_node.target == torch.cat ): return @@ -1525,7 +1526,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): # find the index of getitems to be cated/stacked # type: ignore[union-attr] indices = [arg.args[1] for arg in cat_user.args[0]] # type: ignore[union-attr] - # the gettitems to be merged must be consecutive, otherwise + # the getitems to be merged must be consecutive, otherwise # returned sliced tensor could be wrong if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] continue @@ -1627,7 +1628,7 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int): for getitem in cat_user.args[0]: # type: ignore[union-attr] indices.append(getitem.args[1]) # type: ignore[union-attr] idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] - # the gettitems to be merged must be consecutive, otherwise + # the getitems to be merged must be consecutive, otherwise # returned sliced tensor could be wrong if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] continue @@ -2069,7 +2070,7 @@ def update_args_from_split_getitem( threshold_to_cat: int = 2, ): split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1]) - # case 1: the number of getitems is the same as the split size, elimiate the split + # case 1: the number of getitems is the same as the split size, eliminate the split if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( getitem_indices ): @@ -2164,7 +2165,7 @@ def update_args_from_unbind_getitem( unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim cat_dim = get_arg_value(node, 1, "dim") # cat or stack dim - # case 1: the number of getitems is the same as the split size, elimiate the split + # case 1: the number of getitems is the same as the split size, eliminate the split size = list(unbind_input.meta["example_value"].shape)[unbind_dim] if size == len(getitem_indices): cat_shape = torch.cat( diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index 750fd6a0b3bc89..efad708764fe48 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import operator from collections import defaultdict from typing import Any, Callable, Optional @@ -88,6 +89,7 @@ def hash_node(self, node: torch.fx.Node): return (node, node.target, id(node.args), id(node.kwargs)) def incremental_update(self): + """Update FakeTensors on self.graph. We will try to do the minimum amount of work.""" existing_storages: defaultdict[Optional[int], int] = defaultdict(int) for node in self.graph.nodes: existing_storages[get_node_storage(node)] += 1 @@ -95,14 +97,15 @@ def incremental_update(self): def is_intlist_same(new, old): return statically_known_true(sym_eq(new, old)) - def is_fake_tensor_same(new, old): + def is_fake_tensor_same(new, old, *, node): if type(new) != type(old): return False if isinstance(new, (list, tuple)): if len(new) != len(old): return False return all( - is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) + is_fake_tensor_same(new_i, old_i, node=node) + for new_i, old_i in zip(new, old) ) if new is None: return old is None @@ -132,12 +135,61 @@ def is_fake_tensor_same(new, old): if get_storage(new) == get_storage(old): return True + def any_user_may_alias(node): + if not isinstance(node.meta["val"], torch.Tensor): + # analysis too complicated on lists, can support in the future + return True + for user in node.users: + if not ( + isinstance( + user.target, + (torch._ops.OpOverload, torch._ops.HigherOrderOperator), + ) + or user.target + == torch._inductor.fx_passes.reinplace._generalized_scatter + ): + return True + if isinstance(user.target, torch._ops.HigherOrderOperator): + # HOPs that survive until inductor are all non-aliasing HOPs. + # We will likely never support HOPs that are aliasing. + continue + # Strategy: do a FakeTensor prop, see if the storage aliases. + # If Inductor ever gets tighter invariants on OpOverloads + # (that is, we ban things like torch.ops.aten.reshape calls in the graph), + # Then this could just be a fast schema lookup. + is_valid, args, kwargs = get_fake_args_kwargs(user) + if not is_valid: + return True + with ( + V.fake_mode, + enable_python_dispatcher(), + contextlib.ExitStack() as stack, + ): + # Ignore unbacked symbols (if they exist): we're making + # this FakeTensor and then throwing it away. + shape_env = V.fake_mode.shape_env + if shape_env is not None: + stack.enter_context( + shape_env.ignore_fresh_unbacked_symbols() + ) + new_fake_tensor = user.target(*args, **kwargs) + if not isinstance(new_fake_tensor, torch.Tensor): + # analysis too complicated on lists, can support in the future + return True + if get_storage(new_fake_tensor) == get_storage(node.meta["val"]): + return True + return False + # This is the case where it returns a completely fresh storage that's used nowhere else. + # If the FakeTensor's storage is fresh and none of the node's users can alias it, then + # we don't need to update this node. if ( existing_storages[get_storage(old)] == 1 and get_storage(new) not in existing_storages + and not any_user_may_alias(node) ): return True + return False def should_process_node(node): @@ -149,10 +201,16 @@ def should_process_node(node): return node.op == "call_function" and ( isinstance(node.target, torch._ops.OpOverload) or node.target == operator.getitem + or node.target + == torch._inductor.fx_passes.reinplace._generalized_scatter ) to_process = OrderedSet[int]() for node in self.graph.nodes: + # NB: Be very careful about skipping nodes (via continues) here + # and ask for a careful review when changing this code. The + # consequence for incorrect FakeTensor metadata is difficult-to-debug + # silent incorrectness. if ( self.hash_node(node) in self.processed_hashes and id(node) not in to_process @@ -167,8 +225,9 @@ def should_process_node(node): continue with V.fake_mode, enable_python_dispatcher(): new_fake_tensor = node.target(*args, **kwargs) + if "val" in node.meta and is_fake_tensor_same( - new_fake_tensor, node.meta["val"] + new_fake_tensor, node.meta["val"], node=node ): continue diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index ee3fc86afc909b..fc321b31502709 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -74,6 +74,7 @@ InputBuffer, Pointwise, Reduction, + ShapeAsConstantBuffer, StorageBox, TensorBox, TorchBindObject, @@ -106,6 +107,7 @@ maybe_get_suppress_shape_guards_ctx, normalize_name, should_assume_input_aligned, + SUPPORTED_MKLDNN_DEVICES, ValueWithLineMap, ) from .virtualized import NullHandler, V @@ -217,7 +219,6 @@ def mark_nodes_dislike_padding( aten.convolution, aten.convolution_backward, aten._scaled_mm, - aten._scaled_grouped_mm, ] ) # what's a better way to collect the reduction ops? @@ -343,7 +344,7 @@ def __init__( self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} self.graph_inputs_original: dict[str, InputBuffer] = {} self.partition_maps: Optional[list[GraphPartitionMap]] = None - self.zero_dim_cpu_tensor_list = OrderedSet[str]() + self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet() self.device_types: OrderedSet[str] = ( const_module.device_types if const_module else OrderedSet() ) @@ -380,12 +381,12 @@ def __init__( ] = {} self.seen_subgraphs: dict[str, ir.Subgraph] = {} self.constant_reprs: dict[str, str] = {} - self.removed_operations = OrderedSet[str]() - self.removed_buffers = OrderedSet[str]() - self.removed_inplace_buffers = OrderedSet[str]() - self.mutated_buffers = OrderedSet[str]() - self.never_reuse_buffers = OrderedSet[str]() - self.inplaced_to_remove = OrderedSet[str]() + self.removed_operations: OrderedSet[str] = OrderedSet() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.removed_inplace_buffers: OrderedSet[str] = OrderedSet() + self.mutated_buffers: OrderedSet[str] = OrderedSet() + self.never_reuse_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment] # See `ProxyExecutor Design Note` in ir.py for more details @@ -401,7 +402,7 @@ def __init__( self.current_node: torch.fx.Node = None # type: ignore[assignment] self.lists: dict[str, list[str]] = {} - self.mutated_inputs = OrderedSet[str]() + self.mutated_inputs: OrderedSet[str] = OrderedSet() self.mutated_input_idxs: list[int] = [] self.name_to_buffer: dict[str, ir.Buffer] = {} self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list) @@ -466,14 +467,14 @@ def __init__( # This can either be a graph input or the output of fallback # kernels. self.unaligned_buffers: OrderedSet[str] = OrderedSet() - self.no_fuse_buffer_names = OrderedSet[str]() + self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet() self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet() # more aggressive prologue fusion self.invoke_quant_ops: OrderedSet[str] = OrderedSet() # Below field is related to printing debug intermediate tensor values info for debugging - self.all_codegen_kernel_names = OrderedSet[str]() + self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() # state used by for Kernel.workspace self.workspace_id = itertools.count() @@ -615,7 +616,7 @@ def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available() and all( - n.args[idx].meta["val"].device == torch.device("cpu") + n.args[idx].meta["val"].device.type in SUPPORTED_MKLDNN_DEVICES for n in conv_nodes for idx in [0, 1] ) @@ -1030,7 +1031,7 @@ def allocate_non_dup_const_name( def add_tensor_constant( self, data: Tensor, name: Optional[str] = None - ) -> TensorBox: + ) -> Union[TensorBox, ir.ShapeAsConstantBuffer]: new_name = self.allocate_non_dup_const_name(name, data) return TensorBox.create( ir.ConstantBuffer( @@ -1067,7 +1068,12 @@ def placeholder( example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] target = self.qualify_name(target) if isinstance(example, SymTypes): - expr = _get_placeholder_expr(example.node) + # TODO fix partitioning issue and re-enable for backward + # https://github.com/pytorch/pytorch/issues/155468. + if not V.graph.is_backward: + expr = _get_placeholder_expr(example.node) + else: + expr = example.node.expr self.graph_inputs[target] = expr self.graph_input_names.append(target) return expr @@ -1134,7 +1140,7 @@ def placeholder( self.graph_inputs[target] = tensor self.graph_input_names.append(target) - self.graph_inputs_original[target] = tensor.data.data + self.graph_inputs_original[target] = tensor.data.data # type: ignore[union-attr] if self.current_node.users: # cudagraphs should work with an unused CPU input self.add_device_info(example.device) @@ -1276,7 +1282,9 @@ def get_attr( target: str, # type: ignore[override] args: tuple[()], # type: ignore[override] kwargs: dict[str, object], - ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: + ) -> Union[ + Constant, TensorBox, ShapeAsConstantBuffer, ir.Subgraph, TorchBindObject + ]: # this is a constant value = getattr_recursive(self.module, target) # type: ignore[arg-type] @@ -1469,6 +1477,7 @@ def propagate_mutation( k: v.meta["val"] if isinstance(v, torch.fx.Node) else v for k, v in kwargs.items() }, + old_kwargs["tma_descriptor_metadata"], ) for name in mutated: old_arg = old_kwargs["kwargs"][name] @@ -1516,7 +1525,7 @@ def maybe_propagate( def run_node(self, n: torch.fx.Node) -> object: def debug(msg: str) -> None: - log.debug("lowering %s %s", LazyString(n.format_node), msg) + log.debug("lowering %s %s", LazyString(n.format_node), msg) # type: ignore[arg-type] from torch._inductor.compiler_bisector import CompilerBisector @@ -1536,7 +1545,8 @@ def debug(msg: str) -> None: ): if ( n.op == "call_function" - and n.target is not operator.getitem + and n.target + not in (operator.getitem, torch._higher_order_ops.invoke_subgraph) and ( fallback_node_due_to_unsupported_type(n) or CompilerBisector.disable_subsystem( @@ -1878,7 +1888,7 @@ def create_deferred_runtime_asserts( # [NOTE] Codegen runtime asserts in Inductor # # We need to generate runtime asserts directly in Inductor instead - # of just re-using the asserts from input graphs becase we reuse the + # of just reusing the asserts from input graphs because we reuse the # same ShapeEnv as before. In particular, on subsequent graph passes, # we would immediately turn all of these assertions into noops, # because when we evaluated their expressions, we would see that @@ -1894,8 +1904,8 @@ def create_deferred_runtime_asserts( # equals = torch.add(ones, c) # return equals # torch._dynamo.mark_dynamic(c, 0) - # When we re-use the ShapeEnv in Inductor lowering, the check that checks - # a and nonzero have the same shape would be evaluted to True after we resolve + # When we reuse the ShapeEnv in Inductor lowering, the check that checks + # a and nonzero have the same shape would be evaluated to True after we resolve # unbacked bindings using the ShapeEnv. # See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor. # @@ -2043,6 +2053,7 @@ def extract_autotune_inputs( k: v.meta["val"] if isinstance(v, torch.fx.Node) else v for k, v in kwargs.items() }, + node.kwargs["tma_descriptor_metadata"], ) new_kwargs: dict[str, int] = {} @@ -2245,7 +2256,7 @@ def codegen_subgraph(self, parent_graph: GraphLowering) -> None: graph. The parent graph is passed as an argument: the intention is to inline codegening of the subgraph in the parent graph's wrapper code (including the generated - kerenls). The wrapper code is not finalized (via `.generate()` + kernels). The wrapper code is not finalized (via `.generate()` call), as this will be done in the parent graph's `codegen()`. """ with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True): @@ -2320,10 +2331,14 @@ def _compile_to_module_lines( from .codecache import PyCodeCache if config.triton.autotune_at_compile_time: + # sanitize docstrings in kernel defs (#155006) + kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue() + kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"') + tuning_code = ( '"""\n' + "Compile-time auto-tuning block: \n" - + self.wrapper_code.kernel_autotune_defs.getvalue() + + kernel_autotune_defs + self.wrapper_code.kernel_autotune_calls.getvalue() + '"""\n' ) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 16430ced7e6c32..a43925b8d744e1 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -35,7 +35,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from .ops_handler import DefaultHandler -from .sizevars import evaluate_expr +from .sizevars import statically_known_true from .utils import generate_assert from .virtualized import V @@ -311,7 +311,7 @@ def statically_true(self, e): If this is an issue, just use guards in `self.axioms`. The proper way of handling this would be to have a global shape_env that adds - runtime_asserts as they happen in the code. Then, it shuld be used in SimplifyIndexing + runtime_asserts as they happen in the code. Then, it should be used in SimplifyIndexing to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also for indirect_indexing """ @@ -322,7 +322,7 @@ def statically_true(self, e): for k, v in self.indirect_var_ranges.items() ), ) - return evaluate_expr(self.shape_env, e, self.axioms, var_to_range) + return statically_known_true(self.shape_env, e, self.axioms, var_to_range) def indirect_indexing( self, diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c33a02afff4b8a..850974902a4136 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5,25 +5,36 @@ import functools import itertools import logging +import operator import textwrap import traceback -import typing -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Container, Generator, Iterable, Iterator, Sequence from contextlib import AbstractContextManager, nullcontext from enum import Enum from functools import partial from typing import ( Any, Callable, + cast, ClassVar, Literal, Optional, overload, + SupportsFloat, + SupportsInt, TYPE_CHECKING, TypeVar, Union, ) -from typing_extensions import assert_never, Never, TypeAlias +from typing_extensions import ( + assert_never, + Never, + override, + ParamSpec, + Self, + TypeAlias, + TypeIs, +) from unittest.mock import patch import sympy @@ -58,6 +69,7 @@ statically_known_true, SymTypes, ) +from torch.fx.node import Node from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import SymT @@ -68,6 +80,7 @@ CodegenSymbol, get_scheduling_for_device, index_prevent_reordering, + Kernel, ) from .dependencies import ( Dep, @@ -89,6 +102,7 @@ convert_shape_to_symint, developer_warning, do_bench_using_profiling, + dtype_from_size, get_dtype_size, get_kernel_metadata, GPU_ALIGN_BYTES, @@ -107,9 +121,11 @@ if TYPE_CHECKING: from torch._library.fake_class_registry import FakeScriptObject - from torch.fx.node import Node + from torch.fx.experimental.symbolic_shapes import SympyBoolean + from torch.fx.node import Argument from .codegen.cuda.cuda_template import CUDATemplate + from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering from .utils import IndentedBuffer @@ -127,6 +143,7 @@ has_triton = False +_P = ParamSpec("_P") _T = TypeVar("_T") _U = TypeVar("_U") _V = TypeVar("_V") @@ -134,6 +151,8 @@ _IntLike: TypeAlias = Union[int, Expr] _NumLike: TypeAlias = Union[int, float, Expr] +_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + log = logging.getLogger(__name__) indent = functools.partial(textwrap.indent, prefix=" ") aten = torch.ops.aten @@ -192,9 +211,13 @@ ] +def _is_static(x: object) -> bool: + return isinstance(x, (int, Integer)) + + @dataclasses.dataclass(frozen=True) class GraphPartitionSignature: - # symbol inputs that are neccessary for codegen + # symbol inputs that are necessary for codegen symbol_inputs: OrderedSet[sympy.Symbol] # mapping from partition input name to IRNode or Expr. Need the name str since @@ -246,7 +269,7 @@ def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: def ops_wrapper(name: str) -> Callable[..., OpsValue]: - assert isinstance(name, str) + assert isinstance(name, str), type(name) def fn(*args: object, **kwargs: object) -> OpsValue: return getattr(ops, name)(*args, **kwargs) @@ -406,7 +429,7 @@ def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool: return False from .codegen.triton import TritonScheduling - assert isinstance(device_scheduling, type) + assert isinstance(device_scheduling, type), type(device_scheduling) return issubclass(device_scheduling, TritonScheduling) @@ -419,13 +442,13 @@ def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> b return False aligned_strides = all( - (V.graph.sizevars.size_hint(x.get_stride()[i]) % alignment) == 0 + (V.graph.sizevars.size_hint_or_throw(x.get_stride()[i]) % alignment) == 0 for i in range(len(x.get_stride()) - 1) ) - # if the last dim size is <= 1, stride doesnt matter + # if the last dim size is <= 1, stride doesn't matter aligned_last_dim = ( - V.graph.sizevars.size_hint(x.get_stride()[-1]) == 1 - or V.graph.sizevars.size_hint(x.get_size()[-1]) <= 1 + V.graph.sizevars.size_hint_or_throw(x.get_stride()[-1]) == 1 + or V.graph.sizevars.size_hint_or_throw(x.get_size()[-1]) <= 1 ) return aligned_last_dim and aligned_strides @@ -440,7 +463,7 @@ def significant_strides_equal( """ assert len(shape) == len(strides1) and len(strides1) == len(strides2) for dim, s1, s2 in zip(shape, strides1, strides2): - if V.graph.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type] + if V.graph.sizevars.statically_known_leq(dim, 1): continue if not V.graph.sizevars.statically_known_equals( @@ -454,9 +477,9 @@ def significant_strides_equal( def try_match_insignificant_strides( - tensor: Union[TensorBox, BaseView], + tensor: IRNode, strides: Sequence[Union[int, torch.SymInt]], -) -> Union[TensorBox, BaseView]: +) -> IRNode: """ Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant dimensions - size 0 or 1 - will be updated. @@ -470,7 +493,7 @@ def try_match_insignificant_strides( V.graph.sizevars.statically_known_equals(s1, s2) for s1, s2 in zip(strides, tensor.get_stride()) ): - return tensor # type: ignore[arg-type] + return tensor if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()): return tensor @@ -478,7 +501,7 @@ def try_match_insignificant_strides( storage, old_layout = as_storage_and_layout(tensor) new_stride = [*old_layout.stride] for i, s in enumerate(tensor.get_size()): - if V.graph.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type] + if V.graph.sizevars.statically_known_leq(s, 1): new_stride[i] = strides[i] new_layout = FixedLayout( @@ -501,25 +524,13 @@ def gm_original_output_strides(gm: torch.fx.GraphModule) -> None: record_original_output_strides(gm) -def add_symbolic_shapes_for_inputs_to_subgraph( - inputs: list[Buffer], subgraph: GraphLowering -) -> list[Expr]: +def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: sym_vars: OrderedSet[Expr] = OrderedSet() for inp in inputs: sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False) sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False) - sym_inputs = [] - for sym_var in sym_vars: - assert sym_var in V.graph.graph_inputs.values() - - for graph_inp in V.graph.graph_inputs: - if V.graph.graph_inputs[graph_inp] == sym_var: - subgraph.graph_inputs[graph_inp] = sym_var - subgraph.graph_input_names.append(graph_inp) - sym_inputs.append(sym_var) - - return sym_inputs + return list(sym_vars) class IRNode: @@ -631,7 +642,7 @@ def get_numel(self) -> Expr: return sympy_product(self.get_size()) def is_zero_elements(self) -> bool: - return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) def realize(self) -> Optional[str]: """ @@ -711,18 +722,18 @@ def freeze_layout(self) -> None: raise NotImplementedError(type(self).__name__) def freeze_layout_with_stride_order( - self, order: list[int], allow_padding: bool = False + self, order: Sequence[int], allow_padding: bool = False ) -> None: raise NotImplementedError(type(self).__name__) - def freeze_layout_with_fill_order(self, order: list[int]) -> None: + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: raise NotImplementedError(type(self).__name__) - def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: raise NotImplementedError(type(self).__name__) def freeze_layout_with_exact_strides( - self, exact_strides: list[_IntLike], allow_padding: bool = False + self, exact_strides: Sequence[_IntLike], allow_padding: bool = False ) -> None: raise NotImplementedError(type(self).__name__) @@ -746,7 +757,7 @@ def get_free_symbol_uses( def get_reduction_type(self) -> Optional[str]: raise NotImplementedError(type(self).__name__) - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: raise NotImplementedError(type(self).__name__) def is_extern(self) -> bool: @@ -896,7 +907,9 @@ def get_pointwise_size(self) -> Sequence[Expr]: return self.ranges @classmethod - def create(cls, *args: Any, **kwargs: Any) -> TensorBox: + def create( + cls, *args: Any, **kwargs: Any + ) -> Union[TensorBox, ShapeAsConstantBuffer]: origin_node = kwargs.pop("origin_node", None) tb = kwargs.pop("traceback", None) r = cls(*args, **kwargs) @@ -963,7 +976,7 @@ def get_read_names(self) -> OrderedSet[str]: def num_reads(self) -> int: return len(self.inner_fn_opcount().read_buffers) - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: raise NotImplementedError( f"get_reduction_size() is not implemented by {type(self)}!" ) @@ -1042,7 +1055,7 @@ def store_output( output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], - ) -> None: + ) -> Any: loader = self.make_loader() if output_name is None: output_name = "unnamed" @@ -1145,7 +1158,7 @@ def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges) ) - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: return self.reduction_ranges def get_reduction_type(self) -> Optional[str]: @@ -1164,7 +1177,7 @@ def store_reduction( self.reduction_type, self.inner_fn(vars, reduction_vars), ) - return ops.store_reduction(output_name or "unnamed", indexer(vars), value) + ops.store_reduction(output_name or "unnamed", indexer(vars), value) def index_length(self) -> int: return len(self.ranges) + len(self.reduction_ranges) @@ -1201,16 +1214,13 @@ def num_splits( device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, - inner_fn: Callable[..., OpsValue], + inner_fn: Callable[_P, OpsValue], ranges: Sequence[_IntLike], reduction_ranges: Sequence[_IntLike], reduction_type: Union[ReductionType, Literal["scan"]], reduction_numel: Expr, input_node: Optional[IRNode] = None, ) -> tuple[ReductionHint, _IntLike]: - def _is_static(x: object) -> bool: - return isinstance(x, (int, Integer)) - reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) @@ -1296,10 +1306,12 @@ def inner_reduction_splits( ) def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: + device = r.get_device() + assert device is not None cb = ComputedBuffer( name=None, layout=FlexibleLayout( - device=r.get_device(), + device=device, dtype=r.get_dtype(), size=r.get_size(), ), @@ -1368,9 +1380,7 @@ def _unroll_reduction_fn( src_dtype: torch.dtype, ) -> Callable[[Sequence[_IntLike]], OpsValue]: """Convert inner_fn from a reduction to an pointwise""" - reduction_ranges = [ - V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges - ] + reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges) combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) @@ -1387,12 +1397,10 @@ def fn(index: Sequence[_IntLike]) -> Any: value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any] if reduction_type in ("argmin", "argmax"): - flatten_index = FixedLayout( - None, # type: ignore[arg-type] - None, # type: ignore[arg-type] + flatten_index = _fixed_indexer( reduction_ranges, FlexibleLayout.contiguous_strides(reduction_ranges), - ).make_indexer() + ) def value_fn( index: Sequence[_IntLike], rindex: Sequence[_IntLike] @@ -1420,7 +1428,7 @@ def create( reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) if reduction_numel == 0: @@ -1431,10 +1439,10 @@ def py_cnst(val: object) -> Union[bool, float, int]: if dst_dtype == torch.bool: return bool(val) elif dst_dtype.is_floating_point: - assert isinstance(val, typing.SupportsFloat) + assert isinstance(val, SupportsFloat), type(val) return float(val) else: - assert isinstance(val, typing.SupportsInt) + assert isinstance(val, SupportsInt), type(val) return int(val) rtypes_to_inits = { @@ -1478,7 +1486,7 @@ def fn(index: int) -> OpsValue: if ( isinstance(reduction_numel, Integer) - and V.graph.sizevars.size_hint(reduction_numel) + and V.graph.sizevars.size_hint_or_throw(reduction_numel) < config.unroll_reductions_threshold and (sympy_product(ranges) != 1 or is_gpu(device.type)) ): @@ -1505,6 +1513,18 @@ def fn(index: int) -> OpsValue: reduction_numel, input_node, ) + + def _maybe_increase_split(split: int) -> int: + # don't apply min_num_split constraint for static shape case. + if _is_static(reduction_numel): + return split + if split > 1: + return max(split, config.min_num_split) + else: + return split + + split = _maybe_increase_split(split) + # intermediate reduction in split can contain complex indexing, # and num_splits will fail to correctly set the hint # reuse the passed hint if available @@ -1662,7 +1682,7 @@ def _multilayer_wrap_loader( reindex = View.dynamic_reshape_indexer( reduction_ranges, [reduction_numel], dense_index ) - need_mask = not V.graph.sizevars.is_expr_static_and_true( + need_mask = not V.graph.sizevars.statically_known_true( sympy.Eq(reduction_numel % split, 0) ) @@ -1677,9 +1697,10 @@ def body() -> OpsValue: return loader(new_index, reindex([indices])) if need_mask: + index_dtype = dtype_from_size(reduction_numel) mask = ops.lt( - ops.index_expr(indices, torch.int32), - ops.index_expr(reduction_numel, torch.int32), + ops.index_expr(indices, index_dtype), + ops.index_expr(reduction_numel, index_dtype), ) return ops.masked(mask, body, default) else: @@ -1690,7 +1711,7 @@ def body() -> OpsValue: @classmethod def _multilayer_wrap_loader_existing_ranges( cls, - loader: Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue], + loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], original_ranges: Sequence[Expr], original_reduction_ranges: Sequence[Expr], new_ranges: Sequence[Integer], @@ -1704,8 +1725,8 @@ def _multilayer_wrap_loader_existing_ranges( ) def wrapper_fn( - merged_index: Sequence[sympy.Expr], - new_reduction_index: Sequence[sympy.Expr], + merged_index: Sequence[Expr], + new_reduction_index: Sequence[Expr], ) -> OpsValue: original_idx = merged_index[: len(original_ranges)] new_index = merged_index[len(original_ranges) :] @@ -1730,7 +1751,7 @@ def create_multilayer_helper( reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: """ Break a large reduction up into multiple smaller reductions recursively @@ -1793,7 +1814,7 @@ def create_multilayer( split: _IntLike, reduction_hint: ReductionHint, input_node: Optional[IRNode] = None, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: """ Break a large reduction up into multiple smaller reductions recursively @@ -1839,7 +1860,7 @@ def create_multilayer_existing_ranges( new_reduction_ranges: list[Integer], reduction_type: ReductionType, reduction_hint: ReductionHint, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: """ Break a large reduction up into multiple smaller reductions recursively @@ -1866,7 +1887,26 @@ def create_multilayer_existing_ranges( ) -INNER_FN_TY = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] +def _fixed_indexer( + size: Sequence[int], + stride: Optional[Sequence[int]] = None, + offset: Expr = Integer(0), +) -> Callable[[Sequence[Expr]], Expr]: + """A closure containing math to read a given element""" + + def indexer(index: Sequence[int]) -> int: + assert stride is not None and len(index) == len(stride) + assert len(index) == len(size) + result = offset + for idx, st, sz in zip(index, stride, size): + if sz != 1: + result = result + idx * st + return result + + return indexer + + +INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] class MultiOutputReduction(Reduction): @@ -1915,14 +1955,14 @@ def store_reduction( indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], - ) -> None: + ) -> Any: values = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) - assert isinstance(values, (tuple, list)), f"{type(values)}" + assert isinstance(values, (tuple, list)), type(values) value = values[self.output_index] return ops.store_reduction(output_name or "unnamed", indexer(vars), value) @@ -1940,7 +1980,7 @@ def create( # type: ignore[override] num_output: int, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, - ) -> Sequence[TensorBox]: + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: """ Create the reduction disregarding splitting. """ @@ -1952,7 +1992,7 @@ def create( # type: ignore[override] inner_fn, ranges, reduction_ranges, - "online_softmax_reduce", # type: ignore[arg-type] + "online_softmax_reduce", src_dtype, reduction_hint, output_idx, @@ -1976,12 +2016,12 @@ def create( # type: ignore[override] reduction_ranges: list[Integer], reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, - ) -> Sequence[TensorBox]: + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: assert reduction_type in ("welford_reduce", "welford_combine") reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) - def const(val: int) -> TensorBox: + def const(val: int) -> Union[TensorBox, ShapeAsConstantBuffer]: def inner_fn(idx: Sequence[Expr]) -> OpsValue: return ops.constant( val, @@ -2005,7 +2045,7 @@ def inner_fn(idx: Sequence[Expr]) -> OpsValue: def copy( loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: def inner_fn(idx: Sequence[Expr]) -> OpsValue: reduction_index = [sympy.S.Zero for _ in reduction_ranges] return loader(idx, reduction_index) @@ -2104,13 +2144,13 @@ def create_multilayer( # type: ignore[override] reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, - ) -> Sequence[TensorBox]: + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: """ Break a large reduction up into multiple smaller reductions recursively """ reduction_numel = sympy_product(reduction_ranges) - need_mask = not V.graph.sizevars.is_expr_static_and_true( + need_mask = not V.graph.sizevars.statically_known_true( sympy.Eq(reduction_numel % split, 0) ) @@ -2199,7 +2239,7 @@ class Scan(Loops): dtypes: tuple[torch.dtype, ...] inner_fns: tuple[Callable[..., Any], ...] - # HACK we mimick reduction + # HACK we mimic reduction def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we @@ -2225,7 +2265,7 @@ def store_reduction( indexer: Callable[[Sequence[_IntLike]], Never], vars: Sequence[Expr], scan_vars: Sequence[Symbol], - ) -> None: + ) -> Any: idx = self.reindex(vars, scan_vars) values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.scan(self.dtypes, self.combine_fn, values) @@ -2237,7 +2277,7 @@ def get_reduction_type(self) -> Optional[str]: # return self.scan_op return "custom" - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: return self.scan_ranges def get_size(self) -> Sequence[Expr]: @@ -2275,7 +2315,7 @@ def create( # type: ignore[override] # Whether we have the option to fallback to aten can_fallback_to_aten: bool = True, **kwargs: Any, - ) -> Sequence[Optional[TensorBox]]: + ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: pointwise_ranges = [*size[:axis], *size[axis + 1 :]] scan_ranges = [size[axis]] @@ -2293,7 +2333,7 @@ def create( # type: ignore[override] assert len(dtypes) == len(inner_fns) # Scan with a single element is just a copy - if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): + if sizevars.statically_known_true(sympy.Le(scan_numel, 1)): return [ Pointwise.create( device=device, @@ -2408,7 +2448,7 @@ class Sort(Loops): stable: bool descending: bool - # HACK we mimick reduction + # HACK we mimic reduction def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: return ( @@ -2431,7 +2471,7 @@ def store_reduction( indexer: Callable[[Sequence[Expr]], Expr], vars: Sequence[Expr], reduction_vars: Sequence[Expr], - ) -> None: + ) -> Any: idx = self.reindex(vars, reduction_vars) values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.sort(self.dtypes, values, self.stable, self.descending) @@ -2478,7 +2518,7 @@ def create( # type: ignore[override] descending: bool, reduction_hint: ReductionHint = ReductionHint.DEFAULT, **kwargs: Any, - ) -> Sequence[Optional[TensorBox]]: + ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: pointwise_ranges = [*size[:axis], *size[axis + 1 :]] sort_ranges = [size[axis]] @@ -2493,7 +2533,7 @@ def create( # type: ignore[override] max_rblock = 512 is_persistent_kernel = ( config.triton.persistent_reductions - and sizevars.is_expr_static_and_true(sympy.Le(sort_numel, max_rblock)) + and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock)) ) if not is_persistent_kernel: # We only support persistent triton kernels @@ -2502,7 +2542,7 @@ def create( # type: ignore[override] assert len(dtypes) == len(inner_fns) # Sort with a single element is just a copy - if sizevars.is_expr_static_and_true(sympy.Le(sort_numel, 1)): + if sizevars.statically_known_true(sympy.Le(sort_numel, 1)): return [ Pointwise.create( device=device, @@ -2710,17 +2750,18 @@ def has_exceeded_max_reads(self) -> bool: def realize(self) -> Optional[str]: return self.data.realize() - def realize_hint(self): # type: ignore[no-untyped-def] - return self.data.realize_hint() + def realize_hint(self) -> None: + self.data.realize_hint() - def get_storage_numel(self): # type: ignore[no-untyped-def] + def get_storage_numel(self) -> _IntLike: return self.data.get_storage_numel() def is_extern(self) -> bool: - return self.data.is_extern() # type: ignore[attr-defined] + return self.data.is_extern() def is_module_buffer(self) -> bool: - return self.data.is_module_buffer() # type: ignore[attr-defined] + assert isinstance(self.data, BaseView), type(self.data) + return self.data.is_module_buffer() def get_read_names(self) -> OrderedSet[str]: return self.data.get_read_names() @@ -2729,10 +2770,10 @@ def get_reads(self) -> OrderedSet[Dep]: with patch.object(FlexibleLayout, "allow_indexing", True): return extract_read_writes( self.make_loader(), - self.get_size(), # type: ignore[arg-type] + self.get_size(), ).reads - def unwrap_view(self): # type: ignore[no-untyped-def] + def unwrap_view(self) -> IRNode: x: IRNode = self while isinstance(x, BaseView): x = x.data @@ -2752,13 +2793,13 @@ def constant_to_device(self, device: torch.device) -> IRNode: @ir_dataclass class ExpandView(BaseView): - size: list[Expr] + size: Sequence[Expr] @staticmethod - def _normalize_size(x, new_size): # type: ignore[no-untyped-def] + def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLike]: """Replace `-1` with correct sizes""" sizevars = V.graph.sizevars - new_size = list(map(sympy.expand, new_size)) + new_size = [sympy.expand(s) for s in new_size] old_size = x.get_size() old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) assert len(new_size) == len(old_size) @@ -2782,7 +2823,7 @@ def _normalize_size(x, new_size): # type: ignore[no-untyped-def] return new_size @classmethod - def create(cls, x, new_size): # type: ignore[no-untyped-def] + def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: new_size = cls._normalize_size(x, new_size) if is_storage_and_layout(x): @@ -2812,12 +2853,16 @@ def create(cls, x, new_size): # type: ignore[no-untyped-def] def get_size(self) -> Sequence[Expr]: return self.size - def make_reindexer(self): # type: ignore[no-untyped-def] + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: target = self.get_size() actual = self.data.get_size() skip = len(target) - len(actual) - def reindex(index): # type: ignore[no-untyped-def] + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: index = list(index[skip:]) assert len(index) == len(actual) for i in range(len(actual)): @@ -2834,7 +2879,7 @@ class PermuteView(BaseView): dims: list[Expr] @classmethod - def create(cls, x, dims): # type: ignore[no-untyped-def] + def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: dims = cls._map_neg_dims(dims) assert OrderedSet(dims) == OrderedSet(range(len(dims))) @@ -2852,7 +2897,7 @@ def create(cls, x, dims): # type: ignore[no-untyped-def] return PermuteView(data=x, dims=dims) @classmethod - def _map_neg_dims(cls, dims): # type: ignore[no-untyped-def] + def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]: return [dim if dim >= 0 else len(dims) + dim for dim in dims] def get_size(self) -> Sequence[Expr]: @@ -2862,12 +2907,16 @@ def get_size(self) -> Sequence[Expr]: size = self.data.get_size() return [size[i] for i in self.dims] - def make_reindexer(self): # type: ignore[no-untyped-def] + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: inv = {j: i for i, j in enumerate(self.dims)} inv = [inv[i] for i in range(len(self.dims))] assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) - def reindex(index): # type: ignore[no-untyped-def] + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: return [index[i] for i in inv] return reindex @@ -2876,13 +2925,13 @@ def reindex(index): # type: ignore[no-untyped-def] @ir_dataclass class SqueezeView(BaseView): @classmethod - def create(cls, x, *, dim=None): # type: ignore[no-untyped-def] + def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_size = [] new_stride = [] if dim is not None: - assert isinstance(dim, int), "expected integer dim argument" + assert isinstance(dim, int), type(dim) assert 0 <= dim and dim < len(old_layout.size) for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): @@ -2914,12 +2963,14 @@ def create(cls, x, *, dim=None): # type: ignore[no-untyped-def] return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) @staticmethod - def squeezer(size: Sequence[sympy.Expr]): # type: ignore[no-untyped-def] + def squeezer( + size: Sequence[Expr], + ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr]]]: new_size = [s for s in size if s != 1] not_one = [i for i, s in enumerate(size) if s != 1] length = len(size) - def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]: + def reindex(index: Sequence[Expr]) -> tuple[Expr]: assert len(index) == len(not_one), f"{index} {not_one}" new_index = [sympy.S.Zero] * length for idx, s in zip(not_one, index): @@ -2928,16 +2979,18 @@ def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]: return new_size, reindex - def __init__(self, data) -> None: # type: ignore[no-untyped-def] + def __init__(self, data: Any) -> None: raise AssertionError("use SqueezeView.create()") @ir_dataclass class GenericView(BaseView): - size: list[Expr] - reindex: Callable[..., Any] + size: Sequence[Expr] + reindex: Callable[[Sequence[Expr]], Sequence[Expr]] - def make_reindexer(self): # type: ignore[no-untyped-def] + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: return self.reindex def reindex_str(self) -> str: @@ -2955,7 +3008,12 @@ def __str__(self) -> str: __repr__ = __str__ @classmethod - def create(cls, x, new_size, reindex): # type: ignore[no-untyped-def] + def create( + cls, + x: IRNode, + new_size: Sequence[Expr], + reindex: Callable[[Sequence[Expr]], Sequence[Expr]], + ) -> BaseView: return cls(data=x, size=list(new_size), reindex=reindex) def get_size(self) -> Sequence[Expr]: @@ -2965,7 +3023,7 @@ def get_size(self) -> Sequence[Expr]: @ir_dataclass class View(GenericView): @staticmethod - def handle_negative_index(idx, size): # type: ignore[no-untyped-def] + def handle_negative_index(idx: Expr, size: Expr) -> Expr: idx = sympy.expand(idx) size = sympy.expand(size) evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr @@ -2974,8 +3032,8 @@ def handle_negative_index(idx, size): # type: ignore[no-untyped-def] return idx @classmethod - def create(cls, x, new_size): # type: ignore[no-untyped-def] - assert isinstance(new_size, (tuple, list)) + def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[override] + assert isinstance(new_size, Sequence), type(new_size) old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) # Skip pointless views @@ -2991,7 +3049,7 @@ def create(cls, x, new_size): # type: ignore[no-untyped-def] if 0 in new_size: - def fake_reindex(index): # type: ignore[no-untyped-def] + def fake_reindex(index: Any) -> tuple[int, ...]: return tuple([0] * len(old_size)) return cls(data=x, size=list(new_size), reindex=fake_reindex) @@ -3003,9 +3061,7 @@ def fake_reindex(index): # type: ignore[no-untyped-def] # TODO: unbacked should not diverge from backed in determining striding # Need to require contiguous here instead of realize, see: # https://github.com/pytorch/pytorch/issues/145561 - x = ExternKernel.require_exact_strides( - x, FlexibleLayout.contiguous_strides(x.get_size()) - ) + x = ExternKernel.require_contiguous(x) storage, old_layout = as_storage_and_layout(x, want_contiguous=True) new_layout = FixedLayout( @@ -3021,7 +3077,9 @@ def fake_reindex(index): # type: ignore[no-untyped-def] return cls(data=x, size=list(new_size), reindex=reindex) @staticmethod - def resolve_negative_size(old_size, new_size): # type: ignore[no-untyped-def] + def resolve_negative_size( + old_size: Sequence[Expr], new_size: Sequence[Expr] + ) -> tuple[list[Expr], list[Expr]]: new_size = [V.graph.sizevars.simplify(x) for x in new_size] old_size = [V.graph.sizevars.simplify(x) for x in old_size] @@ -3032,7 +3090,7 @@ def resolve_negative_size(old_size, new_size): # type: ignore[no-untyped-def] new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) break - V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size)) return old_size, new_size @classmethod @@ -3040,7 +3098,7 @@ def dynamic_reshape_indexer( cls, old_size: Sequence[_IntLike], new_size: Sequence[_IntLike], - dense_dim: Optional[int] = None, # type: ignore[no-untyped-def] + dense_dim: Optional[int] = None, ) -> Callable[[Sequence[_T]], Sequence[_V]]: try: reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim) @@ -3053,7 +3111,11 @@ def dynamic_reshape_indexer( return reindex @staticmethod - def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None): # type: ignore[no-untyped-def] + def _dynamic_reshape_indexer( + old_size: Sequence[Expr], + new_size: Sequence[Expr], + dense_dim: Optional[int] = None, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: """ Perform a reshape entirely by modifying indexing math """ @@ -3089,14 +3151,14 @@ def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None stack_old.append(size_old) # re-add elif size_hint(size_new) == size_hint(size_old): view_expr.append(var) - V.graph.sizevars.guard_equals(size_new, size_old) + V.graph.sizevars.check_equals(size_new, size_old) elif size_hint(size_new) < size_hint(size_old): while size_hint(size_new) < size_hint(size_old): var2, size_new2 = stack_new.pop() var = var2 * size_new + var size_new = size_new * size_new2 view_expr.append(var) - V.graph.sizevars.guard_equals(size_new, size_old) + V.graph.sizevars.check_equals(size_new, size_old) elif size_hint(size_new) > size_hint(size_old): divisor = sympy.S.One modulus = size_old @@ -3107,18 +3169,18 @@ def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None view_expr.append(ModularIndexing(var, divisor, modulus)) divisor = divisor * modulus size_old = size_old * modulus - V.graph.sizevars.guard_equals(size_new, size_old) + V.graph.sizevars.check_equals(size_new, size_old) else: raise AssertionError while stack_old: size_old = stack_old.pop() - V.graph.sizevars.guard_equals(size_old, 1) + V.graph.sizevars.check_equals(size_old, 1) view_expr.append(sympy.S.Zero) while stack_new: var, size_new = stack_new.pop() - V.graph.sizevars.guard_equals(size_new, 1) + V.graph.sizevars.check_equals(size_new, 1) if dense_dim is not None and len(new_size) == 1: view_expr.reverse() @@ -3130,7 +3192,9 @@ def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None assert len(view_expr) == len(old_size) - def reindex(index): # type: ignore[no-untyped-def] + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: assert len(index) == len(vars), (len(index), len(vars)) replacements = dict(zip(vars, index)) return tuple(sympy_subs(x, replacements) for x in view_expr) @@ -3169,13 +3233,13 @@ def get_origin_node(self) -> Optional[torch.fx.Node]: return None @property - def dtype(self): # type: ignore[no-untyped-def] + def dtype(self) -> torch.dtype: return self.layout.dtype def get_size(self) -> Sequence[Expr]: return list(self.layout.size) - def get_stride(self): # type: ignore[no-untyped-def] + def get_stride(self) -> Sequence[Expr]: return list(self.layout.stride) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: @@ -3195,7 +3259,7 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: def get_layout(self) -> Layout: return self.layout - def freeze_layout(self): # type: ignore[no-untyped-def] + def freeze_layout(self) -> None: pass def get_free_symbol_uses( @@ -3231,7 +3295,7 @@ class DtypeView(BaseView): target_dtype: torch.dtype @classmethod - def create(cls, x, new_dtype): # type: ignore[no-untyped-def] + def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView: if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_layout = FixedLayout( @@ -3250,7 +3314,7 @@ def __str__(self) -> str: __repr__ = __str__ @property - def dtype(self): # type: ignore[no-untyped-def] + def dtype(self) -> torch.dtype: return self.target_dtype def get_size(self) -> Sequence[Expr]: @@ -3259,7 +3323,7 @@ def get_size(self) -> Sequence[Expr]: def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: inner = self.data.make_loader() - def loader(idx): # type: ignore[no-untyped-def] + def loader(idx: Sequence[Expr]) -> OpsValue: return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) return loader @@ -3267,7 +3331,9 @@ def loader(idx): # type: ignore[no-untyped-def] class SliceView(View): @classmethod - def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def] + def normalize_start_end( + cls, x: IRNode, dim: int, start: int, end: int + ) -> tuple[int, int]: """ Normalize start and end such that both are in the range [0, x.get_size()[dim]] and start <= end. @@ -3282,7 +3348,7 @@ def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def min_func = sizevars.evaluate_min max_func = sizevars.evaluate_max - def clamp(x, lower, upper): # type: ignore[no-untyped-def] + def clamp(x: Expr, lower: int, upper: int) -> Expr: clamped_lower = ( x if sizevars.statically_known_geq(x, lower) else max_func(x, lower) ) @@ -3293,8 +3359,11 @@ def clamp(x, lower, upper): # type: ignore[no-untyped-def] ) return clamped_full - def clamp_wrap(val, lower, upper, default): # type: ignore[no-untyped-def] + def clamp_wrap( + val: Union[int, None], lower: int, upper: int, default: Union[Expr, int] + ) -> Union[Expr, int]: if val is None: + # TODO(rec): can this really happen? return default val = cls.handle_negative_index(val, dim_size) return clamp(val, lower, upper) @@ -3304,9 +3373,17 @@ def clamp_wrap(val, lower, upper, default): # type: ignore[no-untyped-def] return start, end @classmethod - def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-untyped-def] + def create( # type: ignore[override] + cls, + x: IRNode, + dim: int, + start: int, + end: int, + step: int = 1, + clamp: bool = True, + ) -> IRNode: step = sympy.expand(step) - assert isinstance(step, sympy.Expr) or step > 0 + assert isinstance(step, Expr) or step > 0, step try: if start == 0 and end >= 2**63 - 1 and step == 1: return x @@ -3337,7 +3414,9 @@ def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-unty ) return ReinterpretView(data=storage, layout=new_layout) - def reindex(index): # type: ignore[no-untyped-def] + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" index = list(index) index[dim] = index[dim] * step + start @@ -3432,8 +3511,8 @@ def __init__( self, device: torch.device, dtype: torch.dtype, - size: list[Expr], - stride: Optional[list[Expr]] = None, + size: Sequence[Expr], + stride: Optional[Sequence[Expr]] = None, offset: Expr = Integer(0), ) -> None: if stride is None: @@ -3442,9 +3521,9 @@ def __init__( self.dtype = dtype assert len(size) == len(stride), f"size={size}, stride={stride}" assert all(isinstance(s, (Expr, int)) for s in size) - self.size: list[Expr] = size - self.stride: list[Expr] = stride - self.offset: Expr = offset + self.size = size + self.stride = stride + self.offset = offset def __str__(self) -> str: offset = "" @@ -3498,7 +3577,7 @@ def is_transposed(self) -> bool: return False return True - def is_stride_ordered(self, order) -> bool: # type: ignore[no-untyped-def] + def is_stride_ordered(self, order: Sequence[int]) -> bool: assert len(self.stride) == len(order) # ignore dimensions of size 1, they dont affect layout @@ -3509,9 +3588,9 @@ def is_stride_ordered(self, order) -> bool: # type: ignore[no-untyped-def] ] stride = [self.stride[i] for i in non_1_indices] - order = [order[i] for i in non_1_indices] + order: Sequence[int] = [order[i] for i in non_1_indices] - def sorted_indices(arr): # type: ignore[no-untyped-def] + def sorted_indices(arr: Sequence[int]) -> Sequence[int]: sorted_arr = sorted(arr) return [sorted_arr.index(element) for element in arr] @@ -3533,14 +3612,16 @@ def sorted_indices(arr): # type: ignore[no-untyped-def] return False return True - def is_channels_last_stride_ordered(self): # type: ignore[no-untyped-def] + def is_channels_last_stride_ordered(self) -> bool: # create channels_last order(NCHW, NCDHW, the C is the first order). order = [0] + list(reversed(range(1, len(self.stride) - 1))) order = [len(order)] + order return self.is_stride_ordered(order) @staticmethod - def _pad_strides(in_strides, size, dtype): # type: ignore[no-untyped-def] + def _pad_strides( + in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype + ) -> Sequence[int]: """ The padding does not change stride order but makes sure all strides larger than the threshold are multiple of align. @@ -3598,15 +3679,15 @@ def _pad_strides(in_strides, size, dtype): # type: ignore[no-untyped-def] metrics.num_comprehensive_padding += 1 return new_strides - def pad_strides(self): # type: ignore[no-untyped-def] - assert isinstance(self, FlexibleLayout) + def pad_strides(self) -> None: + assert isinstance(self, FlexibleLayout), type(self) assert self.stride is not None self.stride = self._pad_strides(self.stride, self.size, self.dtype) - def should_pad_strides(self): # type: ignore[no-untyped-def] + def should_pad_strides(self) -> bool: return config.comprehensive_padding and isinstance(self, FlexibleLayout) - def as_fixed(self): # type: ignore[no-untyped-def] + def as_fixed(self) -> FixedLayout: if isinstance(self, FixedLayout): return self @@ -3626,17 +3707,18 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: ) return self.as_fixed().make_indexer() - def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + def __eq__(self, other: object) -> bool: return ( - self.device == other.device + isinstance(other, Layout) + and self.device == other.device and self.dtype == other.dtype and self.size == other.size and self.stride == other.stride and self.offset == other.offset ) - def storage_size(self) -> sympy.Expr: - return compute_required_storage_length(self.size, self.stride, self.offset) + def storage_size(self) -> Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type] class FixedLayout(Layout): @@ -3644,27 +3726,17 @@ class FixedLayout(Layout): def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: """A closure containing math to read a given element""" - - def indexer(index): # type: ignore[no-untyped-def] - assert len(index) == len(self.stride) - assert len(index) == len(self.size) - result = self.offset - for idx, stride, sz in zip(index, self.stride, self.size): - if sz != 1: - result = result + idx * stride - return result - - return indexer + return _fixed_indexer(self.size, self.stride, self.offset) class FlexibleLayout(Layout): - """A Tensor layout we are allowed to change""" + """A Tensor layout that we are allowed to change""" allow_indexing = False # WARNING! This doesn't handle zero size tensors correctly @staticmethod - def contiguous_strides(sizes): # type: ignore[no-untyped-def] + def contiguous_strides(sizes: Sequence[int]) -> list[Expr]: if len(sizes) == 0: return [] reversed_strides = [sympy.S.One] @@ -3673,7 +3745,7 @@ def contiguous_strides(sizes): # type: ignore[no-untyped-def] return list(reversed(reversed_strides)) @staticmethod - def fill_ordered(sizes, order): # type: ignore[no-untyped-def] + def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]: """ Create a stride based on the order the dimensions should be filled in. @@ -3690,7 +3762,7 @@ def fill_ordered(sizes, order): # type: ignore[no-untyped-def] return strides @staticmethod - def stride_ordered(sizes, order): # type: ignore[no-untyped-def] + def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]: """ Create a stride based on the sorted order of a permuted range. @@ -3702,7 +3774,9 @@ def stride_ordered(sizes, order): # type: ignore[no-untyped-def] return FlexibleLayout.fill_ordered(sizes, fill_order) @staticmethod - def stride_ordered_for_memory_format(sizes, memory_format): # type: ignore[no-untyped-def] + def stride_ordered_for_memory_format( + sizes: Sequence[int], memory_format: torch.memory_format + ) -> Sequence[Expr]: """ Create a stride based on a memory format. @@ -3727,7 +3801,9 @@ def stride_ordered_for_memory_format(sizes, memory_format): # type: ignore[no-u raise NotImplementedError @staticmethod - def same_ordered(sizes, stride): # type: ignore[no-untyped-def] + def same_ordered( + sizes: Sequence[int], stride: Sequence[_IntLike] + ) -> Sequence[Expr]: """ Create a stride that has the same stride order as given stride @@ -3735,11 +3811,13 @@ def same_ordered(sizes, stride): # type: ignore[no-untyped-def] the fill order should be [1, 3, 2, 0] """ assert len(sizes) == len(stride) - stride = [V.graph.sizevars.size_hint(x) for x in stride] + stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride] fill_order = sorted(range(len(stride)), key=stride.__getitem__) return FlexibleLayout.fill_ordered(sizes, fill_order) - def as_stride_order(self, order, allow_padding=False): # type: ignore[no-untyped-def] + def as_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> FixedLayout: new_stride = self.stride_ordered(self.size, order) if self.should_pad_strides() and allow_padding: new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3752,7 +3830,9 @@ def as_stride_order(self, order, allow_padding=False): # type: ignore[no-untype self.offset, ) - def as_exact_strides(self, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] + def as_exact_strides( + self, exact_strides: Sequence[_IntLike], allow_padding: bool = False + ) -> FixedLayout: new_stride = exact_strides if self.should_pad_strides() and allow_padding: new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3765,8 +3845,8 @@ def as_exact_strides(self, exact_strides, allow_padding=False): # type: ignore[ self.offset, ) - def as_fill_order(self, order): # type: ignore[no-untyped-def] - new_stride = self.fill_ordered(self.size, order) + def as_fill_order(self, order: Sequence[int]) -> FixedLayout: + new_stride: Sequence[int] = self.fill_ordered(self.size, order) if self.should_pad_strides(): new_stride = self._pad_strides(new_stride, self.size, self.dtype) return FixedLayout( @@ -3777,7 +3857,7 @@ def as_fill_order(self, order): # type: ignore[no-untyped-def] self.offset, ) - def as_same_order(self, stride): # type: ignore[no-untyped-def] + def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: new_stride = self.same_ordered(self.size, stride) if self.should_pad_strides(): new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3789,7 +3869,13 @@ def as_same_order(self, stride): # type: ignore[no-untyped-def] self.offset, ) - def __init__(self, device, dtype, size, stride_order=None) -> None: # type: ignore[no-untyped-def] + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Sequence[Expr], + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + ) -> None: if stride_order: strides = FlexibleLayout.fill_ordered(size, stride_order) else: @@ -3813,7 +3899,7 @@ def __init__(self, view: Union[BaseView, TensorBox]) -> None: def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: return self.as_fixed().make_indexer() - def maybe_guard_aligned(self): # type: ignore[no-untyped-def] + def maybe_guard_aligned(self) -> bool: offset = self.view.get_layout().offset if offset == 0: return True @@ -3882,7 +3968,7 @@ class NoneLayout(OutputSpec): def storage_size(self) -> int: return 0 - def as_fixed(self): # type: ignore[no-untyped-def] + def as_fixed(self) -> OutputSpec: return self def get_device(self) -> Optional[torch.device]: @@ -3894,7 +3980,7 @@ def __init__(self, target: IRNode) -> None: super().__init__( target.get_device_or_error(), target.get_dtype(), - target.get_size(), # type: ignore[arg-type] + target.get_size(), None, ) self.target = target @@ -3902,18 +3988,18 @@ def __init__(self, target: IRNode) -> None: V.graph.mark_buffer_mutated(name) @property - def stride(self) -> list[Expr]: + def stride(self) -> Sequence[Expr]: # type: ignore[override] return self.real_layout().stride - @stride.setter + @stride.setter # type: ignore[override] def stride(self, value: Never) -> None: pass # ignore setting of stride - def storage_size(self) -> sympy.Expr: + def storage_size(self) -> Expr: return self.real_layout().storage_size() def get_buffer(self) -> Buffer: - def unwrap_views(target): # type: ignore[no-untyped-def] + def unwrap_views(target: Any) -> Any: if isinstance(target, MutationLayoutSHOULDREMOVE): return unwrap_views(target.target) if isinstance(target, BaseView): @@ -3923,16 +4009,18 @@ def unwrap_views(target): # type: ignore[no-untyped-def] return target result = unwrap_views(self.target) - assert isinstance(result, Buffer), ( - "MutationLayoutSHOULDREMOVE must refer to a buffer" - ) + assert isinstance(result, Buffer), type(result) return result - def real_layout(self): # type: ignore[no-untyped-def] - return self.get_buffer().layout + def real_layout(self) -> Layout: + layout = self.get_buffer().layout + assert isinstance(layout, Layout) + return layout @classmethod - def realize_into(cls, src, dst, unsafe_alias=False): # type: ignore[no-untyped-def] + def realize_into( + cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False + ) -> IRNode: dst.realize() # NOTE: We must realize users of `dst` before we realize `src`, since # realization order determines scheduling order. Otherwise, src's @@ -3951,22 +4039,25 @@ def realize_into(cls, src, dst, unsafe_alias=False): # type: ignore[no-untyped- src.realize_hint() if not unsafe_alias: - src = Pointwise.create( + node = Pointwise.create( device=src.get_device(), dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ - V.graph.sizevars.guard_equals(a, b) + V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(src.get_size(), dst.get_size()) ], - ).data + ) + assert isinstance(node, (BaseView, MutableBox)) + src = node.data src.realize() - assert isinstance(src.data.layout, FlexibleLayout) + assert hasattr(src, "data"), src + assert isinstance(src.data.layout, FlexibleLayout), type(src.data.layout) src.data.layout = MutationLayoutSHOULDREMOVE(dst) return src.data - def as_fixed(self): # type: ignore[no-untyped-def] + def as_fixed(self) -> Self: # type: ignore[override] return self def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: @@ -4026,44 +4117,46 @@ def get_layout(self) -> Layout: def get_output_spec(self) -> OutputSpec: return self.layout - def get_storage_numel(self): # type: ignore[no-untyped-def] + def get_storage_numel(self) -> int: return self.get_numel() - def freeze_layout(self): # type: ignore[no-untyped-def] + def freeze_layout(self) -> None: if isinstance(self.layout, Layout) and not isinstance( self.layout, NonOwningLayout ): self.layout = self.layout.as_fixed() - def freeze_layout_with_stride_order(self, order, allow_padding=False) -> None: # type: ignore[no-untyped-def] - assert isinstance(self.layout, FlexibleLayout) + def freeze_layout_with_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) - def freeze_layout_with_fill_order(self, order) -> None: # type: ignore[no-untyped-def] - assert isinstance(self.layout, FlexibleLayout) + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) self.layout = self.layout.as_fill_order(order) - def freeze_layout_with_same_order(self, stride) -> None: # type: ignore[no-untyped-def] - assert isinstance(self.layout, FlexibleLayout) + def freeze_layout_with_same_order(self, stride: Sequence[int]) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) self.layout = self.layout.as_same_order(stride) - def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def] - self, exact_strides, allow_padding=False + def freeze_layout_with_exact_strides( + self, exact_strides: Sequence[int], allow_padding: bool = False ) -> None: - assert isinstance(self.layout, FlexibleLayout) + assert isinstance(self.layout, FlexibleLayout), type(self.layout) self.layout = self.layout.as_exact_strides( exact_strides, allow_padding=allow_padding ) - def is_zero_elements(self): # type: ignore[no-untyped-def] - return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + def is_zero_elements(self) -> bool: + return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: # Loading from a zero-element buffer is a no-op if self.is_zero_elements(): return partial(nop_loader_fn, dtype=self.get_dtype()) - def loader(index): # type: ignore[no-untyped-def] + def loader(index: Sequence[Expr]) -> OpsValue: indexer = self.make_indexer() return ops.load(self.name or "unnamed", indexer(index)) @@ -4072,7 +4165,7 @@ def loader(index): # type: ignore[no-untyped-def] def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: return self.get_name() - def decide_layout(self): # type: ignore[no-untyped-def] + def decide_layout(self) -> None: pass def get_inputs_that_alias_output(self) -> Sequence[str]: @@ -4220,13 +4313,13 @@ def get_read_writes(self) -> dependencies.ReadWrites: if self.data.get_reduction_type(): return extract_read_writes( self.get_store_function(), - self.data.get_pointwise_size(), # type: ignore[arg-type] - self.data.get_reduction_size(), # type: ignore[arg-type] + self.data.get_pointwise_size(), + self.data.get_reduction_size(), ) else: return extract_read_writes( self.get_store_function(), - self.data.get_size(), # type: ignore[arg-type] + self.data.get_size(), ) def get_free_symbol_uses( @@ -4271,7 +4364,7 @@ def get_store_function(self) -> Callable[..., None]: if isinstance(self.data, (Reduction, Scan, Sort)): return partial(self.data.store_reduction, self.name, indexer) else: - assert isinstance(self.data, Pointwise) + assert isinstance(self.data, Pointwise), type(self.data) return partial(self.data.store_output, self.name, indexer) def get_fill_order(self) -> Optional[list[int]]: @@ -4325,9 +4418,9 @@ def decide_layout(self) -> None: def get_default_sizes_body( self, ) -> tuple[ - tuple[list[sympy.Expr], list[sympy.Expr]], + tuple[list[Expr], list[Expr]], LoopBody, - tuple[list[sympy.Expr], list[sympy.Expr]], + tuple[list[Expr], list[Expr]], ]: args, var_ranges = dependencies.index_vars_squeeze( self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" @@ -4358,7 +4451,7 @@ def simplify_and_reorder( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None, - ) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]: + ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]: """ This is a main place where we do loop transformations in a backend-agnostic way. @@ -4398,8 +4491,8 @@ def simplify_and_reorder( and len(extra_indexing_constraints) == 2 ) extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints - assert isinstance(extra_indexing_ranges, dict) - assert isinstance(extra_indexing_expr, list) + assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges) + assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr) assert all(isinstance(f, Expr) for f in extra_indexing_expr) expected_var_ranges = body.var_ranges @@ -4417,7 +4510,16 @@ def simplify_and_reorder( if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): memory_addrs.extend(body.get_read_exprs()) - def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): # type: ignore[no-untyped-def] + def simplify_and_reorder( + x_vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + sizes: Sequence[int], + simplify_loops: bool, + ) -> tuple[ + list[int], + Callable[[Sequence[int]], Sequence[int]], + Callable[[Sequence[int]], Sequence[int]], + ]: sizes, reindex0, reindex1 = self._apply_loop_reordering( x_vars, support_vars, sizes, memory_addrs ) @@ -4470,13 +4572,17 @@ def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): # type: return (iter_ranges, reduce_ranges), body @staticmethod - def _apply_loop_reordering( # type: ignore[no-untyped-def] - index_vars, - support_vars, - sizes, - memory_addrs, - priority_idx=None, - ): + def _apply_loop_reordering( + index_vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + sizes: Sequence[int], + memory_addrs: list[sympy.Expr], + priority_idx: Optional[list[int]] = None, + ) -> tuple[ + list[int], + Callable[[Sequence[int]], Sequence[int]], + Callable[[Sequence[int]], Sequence[int]], + ]: """ Shuffle the order of loops around to hopefully improve performance. """ @@ -4505,7 +4611,7 @@ def _apply_loop_reordering( # type: ignore[no-untyped-def] sizes = [sizes[i] for i in order] return sizes, same_reorder(order), inverse_reorder(order) - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: return self.data.get_reduction_size() def get_reduction_type(self) -> Optional[str]: @@ -4530,9 +4636,9 @@ class TemplateBuffer(OperationBuffer): def __init__( self, - layout: Layout, + layout: OutputSpec, inputs: Sequence[IRNode], - make_kernel_render: Callable[..., Any], + make_kernel_render: Optional[Callable[..., Any]], ) -> None: super().__init__(name=None, layout=layout) self.inputs = InputsKernel.unwrap_storage(inputs) @@ -4543,11 +4649,11 @@ def __init__( def get_read_writes(self) -> dependencies.ReadWrites: return self.extract_read_writes(normalize=True) - def extract_read_writes(self, normalize): # type: ignore[no-untyped-def] + def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites: name = self.get_name() indexer = self.get_layout().make_indexer() - def dummy(index, rindex): # type: ignore[no-untyped-def] + def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: assert len(rindex) == 0 return ops.store(name, indexer(index), "fake") @@ -4556,11 +4662,14 @@ def dummy(index, rindex): # type: ignore[no-untyped-def] ) for inp in self.inputs: + assert isinstance(inp, (ReinterpretView, Buffer)), type(inp) + assert isinstance(inp.layout, Layout), type(inp.layout) + indexer = inp.layout.make_indexer() - def dummy(index, rindex): # type: ignore[no-untyped-def] + def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: assert len(rindex) == 0 - ops.load(inp.get_name(), indexer(index)) + return ops.load(inp.get_name(), indexer(index)) deps.reads |= dependencies.extract_read_writes( dummy, inp.get_size(), (), normalize=True @@ -4568,7 +4677,7 @@ def dummy(index, rindex): # type: ignore[no-untyped-def] return deps - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: return sympy.S.One def get_reduction_type(self) -> Optional[str]: @@ -4577,26 +4686,26 @@ def get_reduction_type(self) -> Optional[str]: def should_allocate(self) -> bool: return True - def simplify_and_reorder( # type: ignore[no-untyped-def] + def simplify_and_reorder( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None, - ): + ) -> tuple[tuple[Sequence[Expr], list[Expr]], Optional[LoopBody]]: return ( ( self.get_size(), - (), + [], ), None, ) class TritonTemplateBuffer(TemplateBuffer): - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - inputs, - make_kernel_render, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Optional[Callable[_P, _T]], mutated_inputs: Optional[Iterable[IRNode]] = None, allowed_prologue_inps: Optional[OrderedSet[str]] = None, ) -> None: @@ -4604,7 +4713,7 @@ def __init__( # type: ignore[no-untyped-def] NOTE:[TritonTemplates with multiple outputs] We want the ability for TritonTemplates to output multiple tensors. Triton kernels have no notion of outputs and this is done by creating tensors that - are then mutated by the kernel. Currenlty our STORE_OUTPUT codegen doesn't + are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't support creating multinode outputs for triton templates. We work around this by creating an extra input buffer during the lowering and we mark them as mutated inputs. @@ -4622,6 +4731,7 @@ def __init__( # type: ignore[no-untyped-def] assert current_node in allowed_set, ( f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" ) + assert isinstance(self.inputs[0], IRNode), type(self.inputs[0]) device = self.inputs[0].get_device() self.outputs += [ MutationOutput(NoneLayout(device=device), buf, self) @@ -4696,7 +4806,7 @@ def __init__( # knowing what autotuning is choosing) self.description = description - def benchmark(self, *args, out) -> float: # type: ignore[no-untyped-def] + def benchmark(self, *args: Any, out: torch.Tensor) -> float: algo = self.to_callable() if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: algo(*args)) @@ -4705,13 +4815,20 @@ def benchmark(self, *args, out) -> float: # type: ignore[no-untyped-def] def call_name(self) -> str: raise NotImplementedError - def to_callable(self): # type: ignore[no-untyped-def] + def to_callable(self) -> Callable[..., Any]: raise NotImplementedError + def kernel_hash_key(self) -> str: + """ + Hash key for the underlying kernel. By default, we assume there are no + runtime params, so kernel hash key defaults to choice caller's hash key. + """ + return self.hash_key() + def hash_key(self) -> str: raise NotImplementedError - def output_node(self) -> TensorBox: + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: raise NotImplementedError def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: @@ -4739,8 +4856,8 @@ class MultiTemplateBuffer(TritonTemplateBuffer): def __init__( self, layout: Layout, - inputs: list[IRNode], - choice_timings: Callable[[], dict[ChoiceCaller, float]], + inputs: Sequence[IRNode], + choice_timings_fn: Callable[[], dict[ChoiceCaller, float]], unfiltered_choices: list[ChoiceCaller], allowed_prologue_inps: OrderedSet[str], ) -> None: @@ -4750,7 +4867,7 @@ def __init__( make_kernel_render=None, allowed_prologue_inps=allowed_prologue_inps, ) - self._choice_timings_fn = choice_timings + self._choice_timings_fn = choice_timings_fn self._choice_timings: Optional[dict[ChoiceCaller, float]] = None self.original_inputs = inputs self._output_plannable = all( @@ -4776,8 +4893,10 @@ def choice_timings(self) -> dict[ChoiceCaller, float]: return self._choice_timings @contextlib.contextmanager - def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): # type: ignore[no-untyped-def] - assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]: + assert isinstance( + caller, torch._inductor.select_algorithm.TritonTemplateCaller + ), type(caller) assert self.layout == caller.layout render = self.make_kernel_render @@ -4788,22 +4907,23 @@ def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): # type: igno self.make_kernel_render = render def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None: - assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert isinstance( + caller, torch._inductor.select_algorithm.TritonTemplateCaller + ), type(caller) assert self.get_size() == caller.layout.size assert self.get_stride() == caller.layout.stride self.make_kernel_render = caller.get_make_kernel_render() def get_min_choice(self) -> tuple[ChoiceCaller, float]: - min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] - return (min_choice, self.choice_timings[min_choice]) + return min(self.choice_timings.items(), key=lambda x: x[1]) class CUDATemplateBuffer(TemplateBuffer): - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - inputs, - make_kernel_render, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], workspace_size: int, template: CUDATemplate, supports_epilogue_fusion: bool, @@ -4814,7 +4934,7 @@ def __init__( # type: ignore[no-untyped-def] self.template = template self.supports_epilogue_fusion = supports_epilogue_fusion - def get_workspace_size(self): # type: ignore[no-untyped-def] + def get_workspace_size(self) -> int: return self.workspace_size if self.workspace_size is not None else 0 def emulate_store_fn(self) -> None: @@ -4823,7 +4943,14 @@ def emulate_store_fn(self) -> None: class CppTemplateBuffer(TemplateBuffer): - def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def] + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], + template: CUDATemplate, + choice: Any, + ) -> None: super().__init__(layout, inputs, make_kernel_render) self.template = template self.choice = choice @@ -4831,28 +4958,39 @@ def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None def get_layout(self) -> Layout: if isinstance(self.layout, MultiOutputLayout): - assert isinstance(self.outputs, Iterable) + assert isinstance(self.outputs, Iterable), type(self.outputs) first_output = self.outputs[0] - assert isinstance(first_output, Buffer) + assert isinstance(first_output, Buffer), type(first_output) layout = first_output.layout - assert isinstance(layout, Layout) + assert isinstance(layout, Layout), type(layout) return layout else: return super().get_layout() +def is_node_sequence( + nodes: Sequence[Union[IRNode, Sequence[IRNode]]], +) -> TypeIs[Sequence[IRNode]]: + return all(isinstance(n, IRNode) for n in nodes) + + @ir_dataclass(frozen=False) class InputsKernel(OperationBuffer): - inputs: list[Buffer] + inputs: Sequence[Union[IRNode, Sequence[IRNode]]] + + def input_name(self, i: int) -> str: + input = self.inputs[i] + assert isinstance(input, IRNode) + return input.get_name() def get_read_writes(self) -> dependencies.ReadWrites: reads = OrderedSet[dependencies.Dep]() StarDep = dependencies.StarDep for input in self.inputs: - if isinstance(input, list): + if isinstance(input, Sequence): reads.update(StarDep(x.get_name()) for x in input) elif isinstance(input, ShapeAsConstantBuffer): - # Skip creating dependncy for symbolics as they're visible globally + # Skip creating dependency for symbolics as they're visible globally continue else: reads.add(StarDep(input.get_name())) @@ -4886,14 +5024,16 @@ def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: return cls.unwrap_storage_for_input(x) if isinstance(x, TorchBindObject): return x - assert isinstance(x, (Buffer, ReinterpretView)), x + assert isinstance(x, (Buffer, ReinterpretView)), type(x) return x @staticmethod - def unwrap_storage(inputs): # type: ignore[no-untyped-def] - inputs_new = [] + def unwrap_storage( + inputs: Sequence[Union[IRNode, Sequence[IRNode]]], + ) -> list[Union[IRNode, Sequence[IRNode]]]: + inputs_new: list[Union[IRNode, Sequence[IRNode]]] = [] for x in inputs: - if isinstance(x, list): + if isinstance(x, Sequence): x = [InputsKernel.unwrap_storage_for_input(i) for i in x] else: x = InputsKernel.unwrap_storage_for_input(x) @@ -4922,7 +5062,7 @@ class ConcatKernel(NopKernel): """ @classmethod - def create(cls, inputs, dim): # type: ignore[no-untyped-def] + def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: device = inputs[0].get_device() dtype = inputs[0].get_dtype() new_size = list(inputs[0].get_size()) @@ -4939,12 +5079,12 @@ def create(cls, inputs, dim): # type: ignore[no-untyped-def] if j == dim: new_size[j] = new_size[j] + input_size[j] else: - new_size[j] = V.graph.sizevars.guard_equals( + new_size[j] = V.graph.sizevars.check_equals_and_simplify( new_size[j], input_size[j] ) offsets_end.append(new_size[dim]) - output_stride = FlexibleLayout.contiguous_strides(new_size) + output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size) if config.comprehensive_padding: # Ensure the output stride matches the alignment requirements output_stride = Layout._pad_strides( @@ -4964,7 +5104,7 @@ def create(cls, inputs, dim): # type: ignore[no-untyped-def] break any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) fx_node_args = V.graph.current_node.args[0] - assert isinstance(fx_node_args, list) + assert isinstance(fx_node_args, list), type(fx_node_args) # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output if any_input_is_storage_and_layout is False and any( "val" in arg.meta @@ -4976,6 +5116,7 @@ def create(cls, inputs, dim): # type: ignore[no-untyped-def] ): output_stride = make_channels_last_strides_for(new_size) + assert device is not None concat_kernel = ConcatKernel( name=None, layout=FixedLayout( @@ -4988,23 +5129,28 @@ def create(cls, inputs, dim): # type: ignore[no-untyped-def] ) kernel = StorageBox(concat_kernel) op_names = [] - for i in range(len(inputs)): + for i, inp in enumerate(inputs): + assert isinstance(inp, (BaseView, MutableBox)), type(inp) input_buffer = cls.realize_into( - inputs[i], + inp, SliceView.create( kernel, dim, offsets_start[i], offsets_end[i], clamp=False ), ) + assert isinstance(input_buffer, Buffer), type(input_buffer) + assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs) concat_kernel.inputs.append(input_buffer) - if isinstance(inputs[i].data, BaseView): - input_unwrapped = inputs[i].data.unwrap_view() + if isinstance(inp.data, BaseView): + input_unwrapped = inp.data.unwrap_view() else: - input_unwrapped = inputs[i].data + input_unwrapped = inp.data if ( - input_unwrapped.is_input_buffer() - and is_gpu(inputs[i].get_device().type) + isinstance(input_unwrapped, StorageBox) + and input_unwrapped.is_input_buffer() + and (dev := inp.get_device()) is not None + and is_gpu(dev.type) and not is_dynamic(input_buffer) ): op_names.append(input_buffer.get_operation_name()) @@ -5019,11 +5165,14 @@ def create(cls, inputs, dim): # type: ignore[no-untyped-def] return kernel @classmethod - def can_realize_into_without_copy(cls, src, dst=None): # type: ignore[no-untyped-def] + def can_realize_into_without_copy( + cls, src: IRNode, dst: Optional[IRNode] = None + ) -> bool: if isinstance(src, TensorBox): # unwrap a TensorBox return cls.can_realize_into_without_copy(src.data, dst) + assert isinstance(src, (BaseView, StorageBox)), type(src) if isinstance(src.data, MultiTemplateBuffer): if ( not isinstance(src.data.layout, FixedLayout) @@ -5045,12 +5194,14 @@ def can_realize_into_without_copy(cls, src, dst=None): # type: ignore[no-untype for s1, s2 in zip(src.get_stride(), dst.get_stride()) ) - return isinstance(src.data.layout, FlexibleLayout) and not isinstance( - src.data, ExternKernelAlloc + return ( + hasattr(src.data, "layout") + and isinstance(src.data.layout, FlexibleLayout) + and not isinstance(src.data, ExternKernelAlloc) ) @classmethod - def realize_into(cls, src, dst): # type: ignore[no-untyped-def] + def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: # Attempt to turn this into a ReinterpretView rather than assert. # This has concessions around layout, as as_storage_and_layout # can cause us to go from flexible to fixed layout. @@ -5058,7 +5209,7 @@ def realize_into(cls, src, dst): # type: ignore[no-untyped-def] if is_storage_and_layout(dst): storage, layout = as_storage_and_layout(dst) dst = ReinterpretView(data=storage, layout=layout) - assert isinstance(dst, ReinterpretView), dst + assert isinstance(dst, ReinterpretView), type(dst) if isinstance(src, TensorBox): # unwrap a TensorBox return cls.realize_into(src.data, dst) @@ -5076,7 +5227,7 @@ def realize_into(cls, src, dst): # type: ignore[no-untyped-def] dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ - V.graph.sizevars.guard_equals(a, b) + V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(src.get_size(), dst.get_size()) ], ) @@ -5088,7 +5239,7 @@ def should_allocate(self) -> bool: @ir_dataclass(frozen=False) class ExternKernel(InputsKernel): - constant_args: tuple[Any, ...] = () + constant_args: Sequence[Any] = () kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) output_view: Optional[ReinterpretView] = None python_kernel_name: Optional[str] = None @@ -5098,28 +5249,29 @@ class ExternKernel(InputsKernel): ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( default_factory=list ) - op_overload: Optional[ - Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] - ] = None + op_overload: Optional[_OpOverloads] = None arg_properties: Optional[list[dict[str, Any]]] = None + allarg_properties: dict[str, dict[str, Any]] = dataclasses.field( + default_factory=dict + ) kwarg_properties: Optional[dict[str, dict[str, Any]]] = None unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( default_factory=dict ) mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list) - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - name, - layout, - inputs, - constant_args=(), - kwargs=None, - output_view=None, - python_kernel_name=None, - cpp_kernel_name=None, - ordered_kwargs_for_cpp_kernel=(), - op_overload=None, + name: Optional[str], + layout: OutputSpec, + inputs: Sequence[Union[IRNode, Sequence[IRNode]]], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + output_view: Optional[ReinterpretView] = None, + python_kernel_name: Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Iterable[str] = (), + op_overload: Optional[_OpOverloads] = None, ) -> None: super().__init__( name=name, @@ -5144,7 +5296,7 @@ def get_outputs(self) -> list[Buffer]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def collect_arg_kwarg_properties(self): # type: ignore[no-untyped-def] + def collect_arg_kwarg_properties(self) -> None: # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen self.arg_properties = ( @@ -5169,7 +5321,7 @@ def collect_arg_kwarg_properties(self): # type: ignore[no-untyped-def] else {} ) # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes - # ordered_kwargs_for_cpp_kernel is explicilty passed in. + # ordered_kwargs_for_cpp_kernel is explicitly passed in. if isinstance(self.op_overload, torch._ops.OpOverload): if not self.ordered_kwargs_for_cpp_kernel: self.ordered_kwargs_for_cpp_kernel = [ @@ -5181,17 +5333,17 @@ def collect_arg_kwarg_properties(self): # type: ignore[no-untyped-def] else: self.schema_kwargs = [] - def decide_layout(self): # type: ignore[no-untyped-def] + def decide_layout(self) -> None: if isinstance(self.layout, FlexibleLayout): self.apply_constraint() self.freeze_layout() - def codegen_comment(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None: origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper) if origin_str: wrapper.make_comment(origin_str) - def codegen(self, wrapper): # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: raise NotImplementedError def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: @@ -5233,16 +5385,24 @@ def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None: f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" ) - def get_kernel_name(self): # type: ignore[no-untyped-def] + def get_kernel_name(self) -> str: + from .codegen.cpp_wrapper_cpu import CppWrapperCpu + device = d.type if (d := self.get_device()) else V.graph.device_type - return ( - V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name, device) # type: ignore[attr-defined] - if V.graph.cpp_wrapper - else self.python_kernel_name - ) + if V.graph.cpp_wrapper: + assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type( + V.graph.wrapper_code + ) + assert self.cpp_kernel_name is not None + return V.graph.wrapper_code.get_c_shim_func_name( + self.cpp_kernel_name, device + ) + else: + assert self.python_kernel_name is not None + return self.python_kernel_name @staticmethod - def copy_input(x): # type: ignore[no-untyped-def] + def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: pw = Pointwise.create( device=x.get_device(), dtype=x.get_dtype(), @@ -5255,8 +5415,8 @@ def copy_input(x): # type: ignore[no-untyped-def] return pw @classmethod - def process_kernel( # type: ignore[no-untyped-def] - cls, kernel, *args, **kwargs + def process_kernel( + cls, kernel: _OpOverloads, *args: Any, **kwargs: Any ) -> tuple[ Any, list[Any], @@ -5279,11 +5439,13 @@ def process_kernel( # type: ignore[no-untyped-def] if is_arg_tensor[-1]: tensor_args.append(arg) else: - if isinstance(arg, sympy.Expr): + if isinstance(arg, Expr): arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) non_tensor_args.append(arg) - def unflatten_args(new_tensor_args, new_non_tensor_args): # type: ignore[no-untyped-def] + def unflatten_args( + new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T] + ) -> tuple[list[_T], dict[str, _T]]: result = [] it_tensors = iter(new_tensor_args) it_non_tensors = iter(new_non_tensor_args) @@ -5342,11 +5504,11 @@ def unflatten_args(new_tensor_args, new_non_tensor_args): # type: ignore[no-unt unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None if shape_env := V.fake_mode.shape_env: node_meta_val = V.current_node.meta.get("val") - ctx = nullcontext() + ctx: AbstractContextManager[None] = nullcontext() if V.current_node.target == torch._higher_order_ops.effects.with_effects: # remove the first effect token in meta["val"] and meta["unbacked_bindings"] node_meta_val = node_meta_val[1] - ctx = _remove_effect_token_unbacked_bindings(V.current_node) # type: ignore[assignment] + ctx = _remove_effect_token_unbacked_bindings(V.current_node) with ctx: rebind_unbacked(shape_env, V.current_node, example_output) @@ -5375,13 +5537,13 @@ def unflatten_args(new_tensor_args, new_non_tensor_args): # type: ignore[no-unt ) @classmethod - def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] + def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: """ In order to pass this to an extern kernel we need a ReinterpretView not a View. This allows us to avoid some unneeded copies. """ - assert isinstance(x, BaseView) + assert isinstance(x, BaseView), type(x) if isinstance(x, ReinterpretView): return x @@ -5395,6 +5557,8 @@ def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] if ( x_unwrap_view_fx_node is not None and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view, (ReinterpretView, Buffer)) + # and hasattr(x_unwrap_view, "layout") and isinstance(x_unwrap_view.layout, FlexibleLayout) and ( x_unwrap_view_fx_node.meta["val"].is_contiguous( @@ -5412,8 +5576,7 @@ def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] x_unwrap_view.freeze_layout() index_args, var_ranges = dependencies.index_vars_squeeze( - x.get_size(), - prefix="r", # type: ignore[arg-type] + x.get_size(), prefix="r" ) range_vars = index_args[0] index = x.make_indexer()(range_vars) @@ -5437,17 +5600,17 @@ def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] layout=FixedLayout( device=x.get_device_or_error(), dtype=x.get_dtype(), - size=x.get_size(), # type: ignore[arg-type] + size=x.get_size(), stride=strides, offset=offset, ), ) @classmethod - def realize_input(cls, x): # type: ignore[no-untyped-def] + def realize_input(cls, x: IRNode) -> IRNode: if x is None: return NoneAsConstantBuffer() - if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)): return ShapeAsConstantBuffer(expr=x) if isinstance(x, Constant): return V.graph.add_tensor_constant( @@ -5477,7 +5640,7 @@ def realize_input(cls, x): # type: ignore[no-untyped-def] return cls.copy_input(x) @classmethod - def require_stride1(cls, x): # type: ignore[no-untyped-def] + def require_stride1(cls, x: IRNode) -> IRNode: if is_storage_and_layout(x): if len(x.get_stride()) == 0: return x @@ -5487,13 +5650,13 @@ def require_stride1(cls, x): # type: ignore[no-untyped-def] return cls.copy_input(x) @classmethod - def require_strides( # type: ignore[no-untyped-def] + def require_strides( cls, - x, + x: IRNode, order: Optional[Sequence[int]] = None, exact_strides: Optional[Sequence[_IntLike]] = None, - allow_padding=False, - ): + allow_padding: bool = False, + ) -> IRNode: assert order is not None or exact_strides is not None # Layout generally doesn't matter, but some consuming external ops might have requirements if x.get_numel() in (0, 1) and not exact_strides: @@ -5552,19 +5715,21 @@ def require_strides( # type: ignore[no-untyped-def] if exact_strides is not None else x ) - elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): - if isinstance(x.get_layout().real_layout(), FlexibleLayout): + elif isinstance( + (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE + ): + if isinstance( + (real_layout := mutation_layout.real_layout()), FlexibleLayout + ): raise AssertionError( "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" ) - elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( - (order and x.get_layout().real_layout().is_stride_ordered(order)) + elif isinstance(real_layout, FixedLayout) and ( + (order and real_layout.is_stride_ordered(order)) or ( exact_strides and significant_strides_equal( - exact_strides, - x.get_layout().real_layout().stride, - x.get_size(), + exact_strides, real_layout.stride, x.get_size() ) ) ): @@ -5585,8 +5750,9 @@ def require_strides( # type: ignore[no-untyped-def] isinstance(x, TensorBox) and isinstance(x.data, BaseView) and not isinstance(x.data, ReinterpretView) - and is_storage_and_layout(x.unwrap_view()) - and not isinstance(x.unwrap_view().data, ExternKernelAlloc) # type: ignore[attr-defined] + and is_storage_and_layout(unwrap_view := x.unwrap_view()) + and hasattr(unwrap_view, "data") + and not isinstance(unwrap_view.data, ExternKernelAlloc) ): try: x.data = cls.convert_to_reinterpret_view(x.data) @@ -5644,29 +5810,47 @@ def require_strides( # type: ignore[no-untyped-def] return x @classmethod - def require_exact_strides(cls, x, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] + def require_exact_strides( + cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False + ) -> IRNode: return cls.require_strides( x, exact_strides=exact_strides, allow_padding=allow_padding ) @classmethod - def require_stride_order(cls, x, order, allow_padding=False): # type: ignore[no-untyped-def] + def require_stride_order( + cls, x: IRNode, order: Sequence[int], allow_padding: bool = False + ) -> IRNode: return cls.require_strides(x, order=order, allow_padding=allow_padding) @classmethod - def require_channels_last(cls, x): # type: ignore[no-untyped-def] + def require_channels_last(cls, x: IRNode) -> IRNode: return cls.require_stride_order(x, NHWC_STRIDE_ORDER) @classmethod - def require_channels_last_3d(cls, x): # type: ignore[no-untyped-def] + def require_channels_last_3d(cls, x: IRNode) -> IRNode: return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) @classmethod - def require_contiguous(cls, x): # type: ignore[no-untyped-def] - return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + def require_contiguous(cls, x: IRNode) -> IRNode: + def is_mkldnn_tensor(x: IRNode) -> bool: + try: + name = x.get_name() + except (AttributeError, NotImplementedError): + return False + + return name in V.graph.constants and V.graph.constants[name].is_mkldnn + + # TODO move this to the more proper places + if is_mkldnn_tensor(x): + return x + else: + return cls.require_exact_strides( + x, FlexibleLayout.contiguous_strides(x.get_size()) + ) @classmethod - def require_contiguous_strides(cls, x): # type: ignore[no-untyped-def] + def require_contiguous_strides(cls, x: IRNode) -> IRNode: # TODO: combine this with require_contiguous after # https://github.com/pytorch/pytorch/pull/148235 lands. return cls.require_exact_strides( @@ -5676,7 +5860,9 @@ def require_contiguous_strides(cls, x): # type: ignore[no-untyped-def] def apply_constraint(self) -> None: pass - def fill_non_provided_args(self, args, kwargs): # type: ignore[no-untyped-def] + def fill_non_provided_args( + self, args: Sequence[Any], kwargs: dict[str, Any] + ) -> Sequence[Any]: # Previously, we want to maintain forward-compatibility by skipping # default args in the serialized artifacts in fbcode. However, # some of our shim interfaces require default values being OrderedSet. @@ -5685,8 +5871,8 @@ def fill_non_provided_args(self, args, kwargs): # type: ignore[no-untyped-def] # part if we see real FC requirement. More details related to FC # can be found at: # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing - assert isinstance(args, (list, tuple)) - if isinstance(args, tuple): + assert isinstance(args, Sequence), type(args) + if not isinstance(args, list): args = list(args) assert self.arg_properties, "ExternKernel.arg_properties should not be empty" @@ -5710,7 +5896,7 @@ def fill_non_provided_args(self, args, kwargs): # type: ignore[no-untyped-def] ) return args - def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def] + def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]: if V.graph.cpp_wrapper: result = [] # Aten ops follow the convention that tensor args are before non-tensor args, @@ -5728,7 +5914,8 @@ def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore for i, x in enumerate(self.constant_args): if name_to_arg_properties is not None: - prop = name_to_arg_properties.get(names[i]) # type: ignore[index] + assert names is not None + prop = name_to_arg_properties.get(names[i]) type_ = prop.get("type") if prop else None else: idx = len(self.inputs) + i @@ -5740,9 +5927,9 @@ def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore result.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) return result else: - return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args] - def codegen_args(self): # type: ignore[no-untyped-def] + def codegen_args(self) -> list[str]: if V.graph.cpp_wrapper and self.op_overload is not None: # cpp wrapper needs special logic to fill in missing args with default values inputs = self.fill_non_provided_args( @@ -5768,7 +5955,7 @@ def codegen_args(self): # type: ignore[no-untyped-def] args.extend(self.codegen_const_args()) return args - def get_kwargs_value(self, arg_name, **kwargs): # type: ignore[no-untyped-def] + def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any: """Given an argument name, queries for values in (in order): 1. any provided kwargs for this function. 2. the class self.kwargs member. @@ -5777,11 +5964,11 @@ def get_kwargs_value(self, arg_name, **kwargs): # type: ignore[no-untyped-def] return kwargs.get(arg_name) if arg_name in self.kwargs: return self.kwargs.get(arg_name) - if self.allarg_properties and arg_name in self.allarg_properties: - return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + if (arg := self.allarg_properties.get(arg_name)) is not None: + return arg.get("default_value") raise AssertionError(f"{arg_name} not in self.allarg_properties") - def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def] + def codegen_kwargs(self, skip_out: bool = False) -> list[str]: if V.graph.cpp_wrapper: if self.op_overload is not None and len(self.schema_kwargs) == 0: # All the args should have been generated by fill_non_provided_args in codegen_args @@ -5794,14 +5981,11 @@ def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def] continue v = self.get_kwargs_value(arg_name) - if isinstance(v, sympy.Expr): + if isinstance(v, Expr): kwargs.append(v) else: - type_ = ( - self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] - if self.allarg_properties and arg_name in self.allarg_properties - else None - ) + assert self.allarg_properties is not None + type_ = self.allarg_properties.get(arg_name, {}).get("type") kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_)) else: kwargs = [ @@ -5810,28 +5994,44 @@ def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def] ] return kwargs - def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] + def get_op_name(self) -> str: + if self.fx_node is not None: + target = self.fx_node.target + op_namespace = getattr(target, "__module__", "unknown_namespace") + op_namespace = op_namespace.replace("._ops.", ".ops.") + op_namespace = op_namespace.rsplit(".", 1)[0] + op_name = f"{op_namespace}.{target}" + else: + op_name = "unknown_op" + return op_name + + def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None: if config.size_asserts and not V.graph.cpp_wrapper: # comparing strides for 0 size tensor is tricky. Ignore them for now. if sympy_product(self.get_size()) == 0: return size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) - + op_name = self.get_op_name() wrapper.writeline( - f"assert_size_stride({self.get_name()}, {size}, {stride})" + f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})" ) def codegen_alignment_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] if config.alignment_asserts and not V.graph.cpp_wrapper: name = self.get_name() aligned = name not in V.graph.unaligned_buffers + op_name = self.get_op_name() if aligned: - wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})") + wrapper.writeline( + f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})" + ) else: - wrapper.writeline(f"# buffer {name} is assumed to be not aligned") + wrapper.writeline( + f"# buffer {name} (op: {op_name}) is assumed to be not aligned" + ) - def get_group_stride(self): # type: ignore[no-untyped-def] + def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: """ get output sizes and strides, for template_codegen """ @@ -5840,7 +6040,7 @@ def get_group_stride(self): # type: ignore[no-untyped-def] # iter_ranges = _size of output tensor, reduce_range = [] because no reduction return [_size, []], _stride - def canonicalize(self): # type: ignore[no-untyped-def] + def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: """ Manually get canonicalization of the output index """ @@ -5903,25 +6103,27 @@ def __str__(self) -> str: @ir_dataclass(frozen=False) class ExternKernelOut(ExternKernel): - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.generate_extern_kernel_out(self) - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - inputs, - constant_args=(), - kwargs=None, - output_view=None, - python_kernel_name=None, - cpp_kernel_name=None, - ordered_kwargs_for_cpp_kernel=(), - op_overload=None, + layout: Layout, + inputs: Sequence[IRNode], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + output_view: Optional[ReinterpretView] = None, + python_kernel_name: Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Sequence[Any] = (), + op_overload: Optional[_OpOverloads] = None, ) -> None: + unwrapped_inputs = self.unwrap_storage(inputs) + assert isinstance(unwrapped_inputs, Sequence), type(unwrapped_inputs) super().__init__( None, layout, - self.unwrap_storage(inputs), + unwrapped_inputs, constant_args, kwargs or {}, None, @@ -5958,24 +6160,26 @@ def __init__(self, count: int, device: torch.device) -> None: class ExternKernelAlloc(ExternKernel): - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.generate_extern_kernel_alloc(self) - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - inputs, - constant_args=(), - kwargs=None, - python_kernel_name=None, - cpp_kernel_name=None, - ordered_kwargs_for_cpp_kernel=(), - op_overload=None, + layout: OutputSpec, + inputs: Sequence[IRNode], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + python_kernel_name: Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Sequence[Any] = (), + op_overload: Optional[_OpOverloads] = None, ) -> None: + unwrapped_inputs = self.unwrap_storage(inputs) + assert all(isinstance(i, IRNode) for i in unwrapped_inputs) super().__init__( None, layout, - self.unwrap_storage(inputs), + cast(Sequence[IRNode], unwrapped_inputs), constant_args, kwargs or {}, None, @@ -5994,7 +6198,7 @@ def __init__( # type: ignore[no-untyped-def] def should_allocate(self) -> bool: return False - def apply_constraint(self): # type: ignore[no-untyped-def] + def apply_constraint(self) -> None: raise NotImplementedError @@ -6003,7 +6207,9 @@ class MutationOutput(Buffer): An output buffer that represents the mutation of a pre-existing buffer """ - def __init__(self, layout, mutated_node, mutating_node: Operation) -> None: # type: ignore[no-untyped-def] + def __init__( + self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation + ) -> None: super().__init__(name=None, layout=layout) mutated_node_name = mutated_node.get_name() V.graph.mark_buffer_mutated(mutated_node_name) @@ -6023,10 +6229,12 @@ def should_allocate(self) -> bool: class TMADescriptor(ExternKernel): """ - An IR node representing a host-side TMA descriptor in the Triton API - (the ones obtained via create_{1d,2d}_tma_descriptor calls). Mostly - useful for user-defined Triton kernels relying on host-side TMA; but - can, in principle, be used for Inductor's Triton templates, too. + An IR node representing a generic host-side TMA descriptor in the Triton API + Mostly useful for user-defined Triton kernels relying on host-side TMA; + but can, in principle, be used for Inductor's Triton templates, too. + + See TMADescriptorExperimental and TMADescriptorStable for the two implementations + (the old API and the new API) """ # as TMA descriptors are immutable, @@ -6034,18 +6242,61 @@ class TMADescriptor(ExternKernel): _CACHE: dict[Any, TMADescriptor] = {} @classmethod - def create( # type: ignore[no-untyped-def] - cls, - tensor: IRNode, - dims: list[Union[int, torch.SymInt]], - block_dims: list[Union[int, torch.SymInt]], - element_size: Optional[int] = None, - ): - key = (id(tensor), dims, block_dims, element_size) + def _create_impl( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + assert len(tma_meta) == 2 + if tma_meta[0] == "experimental": + return TMADescriptorExperimental(tensor, *tma_meta[1]) + else: + assert tma_meta[0] == "stable" + return TMADescriptorStable(tensor, *tma_meta[1]) + + @classmethod + def create( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + key = (id(tensor), tma_meta) if key not in cls._CACHE: - cls._CACHE[key] = TMADescriptor(tensor, dims, block_dims, element_size) + cls._CACHE[key] = cls._create_impl(tensor, tma_meta) return cls._CACHE[key] + def __init__(self, tensor: IRNode, inputs, constant_args): # type: ignore[no-untyped-def] + super().__init__( + None, + # link back to the underlying tensor in terms of ownership + # to avoid getting the underlying tensor deleted *before* + # the TMADescriptor node can be deleted. + NonOwningLayout( + ReinterpretView( + data=tensor, + layout=tensor.get_layout(), + ) + ), + cast(Sequence[Buffer], inputs), + tuple(constant_args), + None, + ) + + self.tensor = tensor + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_tma_descriptor(self) + + def get_tensor(self) -> IRNode: + return self.tensor + + +class TMADescriptorExperimental(TMADescriptor): + """ + the new host-side TMA Descriptor API: + (the ones obtained via create_{1d,2d}_tma_descriptor calls). + + See also TMADescriptorStable for the new API. + """ + def __init__( self, tensor: IRNode, @@ -6059,7 +6310,6 @@ def __init__( if element_size is None: element_size = tensor.get_dtype().itemsize - self.tensor = tensor self.dims = dims self.block_dims = block_dims self.element_size = element_size @@ -6073,26 +6323,28 @@ def __init__( ] super().__init__( - None, - # link back to the underlying tensor in terms of ownership - # to avoid getting the underlying tensor deleted *before* - # the TMADescriptor node can be deleted. - NonOwningLayout( - ReinterpretView( - data=tensor, - layout=tensor.get_layout(), - ) - ), - inputs, - tuple(constant_args), - None, + tensor=tensor, + inputs=inputs, + constant_args=constant_args, ) - self.name = V.graph.register_buffer(self) - V.graph.register_operation(self) - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] - wrapper.generate_tma_descriptor(self) +class TMADescriptorStable(TMADescriptor): + """ + the new host-side TMA descriptor API + (the ones obtained via TensorDescriptor.from_tensor). + + See also TMADescriptorExperimental for the old API. + """ + + def __init__(self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]): + self.block_shape = block_shape + + super().__init__( + tensor=tensor, + inputs=[tensor], + constant_args=block_shape, + ) class SubgraphBuffer(ExternKernel): @@ -6110,12 +6362,15 @@ def __init__( self.name = V.graph.register_buffer(self) V.graph.register_operation(self) - gm_original_output_strides(self.gm) self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name) - sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph( - self.inputs, self.subgraph - ) + assert is_node_sequence(self.inputs) + sym_inputs = get_symbolic_inputs(self.inputs) + + for sym_inp in sym_inputs: + self.subgraph.graph_inputs[sym_inp.name] = sym_inp + self.subgraph.graph_input_names.append(sym_inp.name) + self.sym_inputs = [sym_var.name for sym_var in sym_inputs] import torch._inductor.config as inductor_config @@ -6135,6 +6390,7 @@ def __init__(self, graph: GraphLowering): self.graph = graph self.name = graph.name + assert is_node_sequence(self.inputs) outer_inputs = [t.codegen_reference() for t in self.inputs] wrapper.codegen_subgraph_with_flattened_outputs( CodegenGraph(self.subgraph), @@ -6144,7 +6400,7 @@ def __init__(self, graph: GraphLowering): class UserDefinedTritonKernel(ExternKernel): - def get_kernel_and_metadata(self): # type: ignore[no-untyped-def] + def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]: from triton.runtime.autotuner import Autotuner from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table @@ -6175,7 +6431,11 @@ def get_kernel_and_metadata(self): # type: ignore[no-untyped-def] kernel = kernel.fn return kernel, configs, restore_value_args, reset_to_zero_args - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + @override + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Overrides the parent member. + See https://github.com/pytorch/pytorch/issues/151692""" + from torch._inductor.utils import triton_version_uses_attrs_dict ( @@ -6201,7 +6461,10 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] named_args = { k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel } - constexpr_names = OrderedSet([kernel.arg_names[i] for i in kernel.constexprs]) + assert hasattr(kernel, "arg_names") and hasattr(kernel, "constexprs"), type( + kernel + ) + constexpr_names = OrderedSet(kernel.arg_names[i] for i in kernel.constexprs) args: list[Any] = [] arg_types: list[Any] = [] @@ -6267,17 +6530,23 @@ def get_free_symbol_uses( def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__( # type: ignore[no-untyped-def] - self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args + def __init__( + self, + *, + kernel_idx: int, + grid: Any, + tma_descriptor_metadata: dict[str, Any], + kernel_args: dict[str, Any], ) -> None: - inputs = [] - kwargs = {} - constant_args = [] + inputs: list[IRNode] = [] + kwargs: dict[str, IRNode] = {} + constant_args: list[IRNode] = [] + for k, v in kernel_args.items(): if isinstance(v, TensorBox): t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) if k in tma_descriptor_metadata: - t = TMADescriptor.create(t, *tma_descriptor_metadata[k]) + t = TMADescriptor.create(t, tma_descriptor_metadata[k]) inputs.append(t) kwargs[k] = t else: @@ -6287,6 +6556,7 @@ def __init__( # type: ignore[no-untyped-def] assert len(inputs) != 0 self.device = inputs[0].get_device() + assert isinstance(inputs, Sequence), type(inputs) super().__init__( None, NoneLayout(device=self.device), @@ -6300,6 +6570,7 @@ def __init__( # type: ignore[no-untyped-def] kernel, configs, _, _ = self.get_kernel_and_metadata() # If we are autotuning, not all arguments will be passed + assert hasattr(kernel, "arg_names") self.ordered_kwargs_for_cpp_kernel = [ arg for arg in kernel.arg_names if arg in kernel_args ] @@ -6310,7 +6581,7 @@ def __init__( # type: ignore[no-untyped-def] self.mutable_args = [ kernel_args[key] for key in identify_mutated_tensors( - kernel, {**kernel_args, **autotuned_kwargs} + kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata ) ] @@ -6332,8 +6603,9 @@ class InplaceBernoulliFallback(ExternKernel): This needs to be a custom class to handle mutation properly """ - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] - (x,) = (t.codegen_reference() for t in self.inputs) + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + assert all(isinstance(t, IRNode) for t in self.inputs) + (x,) = (cast(IRNode, t).codegen_reference() for t in self.inputs) if V.graph.cpp_wrapper: # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, @@ -6350,12 +6622,14 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[0].get_name()] + return [self.input_name(0)] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__(self, op_overload, x, *constant_args) -> None: # type: ignore[no-untyped-def] + def __init__( + self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any + ) -> None: super().__init__( None, NoneLayout(device=x.get_device()), @@ -6374,7 +6648,7 @@ class InplaceCopyFallback(ExternKernel): This needs to be a custom class to handle mutation properly """ - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: (dst, src, non_blocking) = self.codegen_args() wrapper.codegen_device_copy(src, dst, non_blocking) @@ -6382,16 +6656,16 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[0].get_name()] + return [self.input_name(0)] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - inputs, - constant_args, + layout: OutputSpec, + inputs: Sequence[IRNode], + constant_args: Sequence[Any], ) -> None: super().__init__( None, @@ -6406,7 +6680,9 @@ def __init__( # type: ignore[no-untyped-def] V.graph.register_operation(self) @classmethod - def create(cls, dst, src, non_blocking: bool = False): # type: ignore[no-untyped-def] + def create( + cls, dst: IRNode, src: IRNode, non_blocking: bool = False + ) -> InplaceCopyFallback: inputs = [cls.realize_input(t) for t in [dst, src]] constant_args = (non_blocking,) result = InplaceCopyFallback( @@ -6422,7 +6698,8 @@ class MutatingFirstArgExternKernel(ExternKernel): This needs to be a custom class to handle mutation properly """ - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + assert is_node_sequence(self.inputs) argrefs = [ *(t.codegen_reference() for t in self.inputs), *map(repr, self.constant_args), @@ -6435,7 +6712,7 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[0].get_name()] + return [self.input_name(0)] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -6445,7 +6722,7 @@ def has_side_effects(self) -> bool: class ResizeStorageBytes(MutatingFirstArgExternKernel): - def __init__(self, variable, new_size) -> None: # type: ignore[no-untyped-def] + def __init__(self, variable: IRNode, new_size: int) -> None: assert isinstance(new_size, int), "TODO: dynamic shapes" super().__init__( None, @@ -6458,11 +6735,12 @@ def __init__(self, variable, new_size) -> None: # type: ignore[no-untyped-def] V.graph.register_operation(self) self.python_kernel_name = "inductor_ops.resize_storage_bytes_" self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable) V.graph.never_reuse_buffers.add(variable.data.get_name()) class SetSourceTensorKernel(ExternKernelAlloc): - def __init__(self, self_tensor, storage_tensor) -> None: # type: ignore[no-untyped-def] + def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None: storage_tensor.freeze_layout() super().__init__( storage_tensor.get_layout(), @@ -6470,6 +6748,9 @@ def __init__(self, self_tensor, storage_tensor) -> None: # type: ignore[no-unty python_kernel_name="torch.ops.aten.set_.source_Tensor", op_overload=torch.ops.aten.set_.source_Tensor, ) + assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type( + self_tensor + ) V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) V.graph.never_reuse_buffers.add(storage_tensor.get_name()) V.graph.never_reuse_buffers.add(self.get_name()) @@ -6480,7 +6761,7 @@ def __init__(self, self_tensor, storage_tensor) -> None: # type: ignore[no-unty ] def get_inputs_that_alias_output(self) -> Sequence[str]: - return [self.inputs[0].get_name(), self.inputs[1].get_name()] + return [self.input_name(0), self.input_name(1)] class ScatterFallback(ExternKernel): @@ -6490,7 +6771,7 @@ class ScatterFallback(ExternKernel): It also handle the case `src` being a scalar properly. """ - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: reduce = self.kwargs["reduce"] if V.graph.cpp_wrapper: # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum @@ -6498,6 +6779,7 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] if reduce in get_operator_enum: reduce = get_operator_enum[reduce] + assert is_node_sequence(self.inputs) if self.src_is_tensor: (x, index, src) = (t.codegen_reference() for t in self.inputs) else: @@ -6516,19 +6798,21 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] def should_allocate(self) -> bool: return False - def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[0].get_name()] + def get_mutation_names(self) -> list[str]: + inp = self.inputs[0] + assert isinstance(inp, IRNode) + return [inp.get_name()] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - op_overload, - x, + op_overload: _OpOverloads, + x: IRNode, dim: int, - index, - src, + index: IRNode, + src: IRNode, *, reduce: Optional[str] = None, include_self: bool = True, @@ -6563,7 +6847,8 @@ class IndexPutFallback(ExternKernel): This needs to be a custom class to handle mutation and indices properly """ - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + assert is_node_sequence(self.inputs) (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) indices = [] iter_valid_indices = iter(valid_indices) @@ -6581,12 +6866,19 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[0].get_name()] + return [self.input_name(0)] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__(self, op_overload, x, indices, values, accumulate) -> None: # type: ignore[no-untyped-def] + def __init__( + self, + op_overload: torch._ops.OpOverload, + x: IRNode, + indices: list[Any], + values: Sequence[Any], + accumulate: Any, + ) -> None: self.indices = indices valid_indices = [i for i in indices if i is not None] tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] @@ -6600,14 +6892,14 @@ def __init__(self, op_overload, x, indices, values, accumulate) -> None: # type cpp_kernel_name=cpp_kernel_name, op_overload=op_overload, ) - V.graph.mark_buffer_mutated(self.inputs[0].get_name()) + V.graph.mark_buffer_mutated(self.input_name(0)) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) class DeviceCopy(ExternKernelOut): @classmethod - def create(cls, x, device, non_blocking): # type: ignore[no-untyped-def] + def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: if ( not x.is_extern() and all(r in V.graph.constants for r in x.get_read_names()) @@ -6616,7 +6908,9 @@ def create(cls, x, device, non_blocking): # type: ignore[no-untyped-def] return x.constant_to_device(device) V.graph.add_device_info(device) - V.graph.add_device_info(x.get_device()) + x_device = x.get_device() + assert x_device is not None + V.graph.add_device_info(x_device) developer_warning("DeviceCopy in input program") constant_args = (non_blocking,) @@ -6630,7 +6924,7 @@ def create(cls, x, device, non_blocking): # type: ignore[no-untyped-def] constant_args, ) - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: args = self.codegen_args() assert len(args) == 2 if self.output_view: @@ -6652,7 +6946,9 @@ def get_reads(self) -> OrderedSet[Dep]: def should_allocate(self) -> bool: return False - def __init__(self, sym, keypath, data) -> None: # type: ignore[no-untyped-def] + def __init__( + self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode + ) -> None: data.realize() super().__init__( None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data]) @@ -6663,7 +6959,7 @@ def __init__(self, sym, keypath, data) -> None: # type: ignore[no-untyped-def] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet([self.sym]) - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_dynamic_scalar(self) @@ -6678,7 +6974,7 @@ def get_reads(self) -> OrderedSet[Dep]: def should_allocate(self) -> bool: return False - def __init__(self, scalar, msg) -> None: # type: ignore[no-untyped-def] + def __init__(self, scalar: SympyBoolean, msg: str) -> None: super().__init__( # Buffer(name, layotu) None, @@ -6692,10 +6988,12 @@ def __init__(self, scalar, msg) -> None: # type: ignore[no-untyped-def] def has_side_effects(self) -> bool: return True - def get_free_symbol_uses(self, unbacked_only: bool = False): # type: ignore[no-untyped-def] + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: return get_free_symbols(self.scalar, unbacked_only) - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: if not config.scalar_asserts: return # NB: It is EXTREMELY important not to simplify the scalar under assertion here, @@ -6731,16 +7029,22 @@ class ExternKernelNode: class FallbackKernel(ExternKernelAlloc): + """ + A class that represents a fallback kernel for handling operators that are not + directly support by inductor. It currently supports functional ops, view ops, + inplace aten ops, and mutating ops that are auto-functionalizable. + """ + def __init__( # type: ignore[no-untyped-def] self, - layout, - kernel, - tensor_args, - nontensor_args, - unflatten_args, - kwargs=None, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, *, - unbacked_bindings=None, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, ) -> None: super().__init__( layout, @@ -6750,19 +7054,16 @@ def __init__( # type: ignore[no-untyped-def] ) self.use_runtime_dispatch = False - self.unbacked_bindings = unbacked_bindings + self.unbacked_bindings = unbacked_bindings or {} assert isinstance( - kernel, - ( - torch._ops.OpOverload, - torch._ops.HigherOrderOperator, - ), + kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" self.op_overload = kernel self.unflatten_args = unflatten_args self.kwargs = {} if kwargs is None else kwargs - V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] + assert self.python_kernel_name is not None + V.graph.warn_fallback(self.python_kernel_name) # args that are aliased self.alias_names: list[str] = [] @@ -6807,10 +7108,10 @@ def __init__( # type: ignore[no-untyped-def] args, kwargs = self.unflatten_args(self.inputs, self.constant_args) - def handle_aliasing_and_mutation(info, arg) -> None: # type: ignore[no-untyped-def] + def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None: # Assertions to make sure we didn't mismatch args if isinstance(info.type, torch.ListType): - assert isinstance(arg, (list, tuple)) + assert isinstance(arg, (list, tuple)), type(arg) if library_utils.is_tensor_like_type(info.type): # PyTorch also accepts None and scalar types for args marked as "Tensor". # We're not going to check all of them here. @@ -6821,8 +7122,9 @@ def handle_aliasing_and_mutation(info, arg) -> None: # type: ignore[no-untyped- if info.alias_info is None: return - def add_alias(t) -> None: # type: ignore[no-untyped-def] + def add_alias(t: IRNode) -> None: self.alias_names.append(t.get_name()) + assert info.alias_info is not None if info.alias_info.is_write: self.mutation_outputs.append( MutationOutput(NoneLayout(device=t.get_device()), t, self) @@ -6851,22 +7153,22 @@ def get_read_writes(self) -> dependencies.ReadWrites: return read_writes - def codegen_unbacked_symbol_defs(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None: return wrapper.codegen_unbacked_symbol_defs_for_outputs( self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None) ) - def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + def get_unbacked_symbol_defs(self) -> Container[sympy.Symbol]: # type: ignore[override] if unbacked_bindings := getattr(self, "unbacked_bindings", None): resolved = resolve_unbacked_bindings( V.graph.sizevars.shape_env, unbacked_bindings ) assert resolved is not None - return resolved.keys() # type: ignore[return-value] + return resolved.keys() else: return OrderedSet() - def codegen_args(self): # type: ignore[no-untyped-def] + def codegen_args(self) -> list[str]: @dataclasses.dataclass class Shim: ref: Any @@ -6874,6 +7176,7 @@ class Shim: def __repr__(self) -> str: return self.ref + assert is_node_sequence(self.inputs) tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] args, kwargs = self.unflatten_args(tensor_args, self.constant_args) if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): @@ -6890,13 +7193,16 @@ def __repr__(self) -> str: return args @staticmethod - def find_device(tensor_args, example_output): # type: ignore[no-untyped-def] + def find_device( + tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any] + ) -> Any: non_torch_bind_tensor_args = ( [t for t in tensor_args if not isinstance(t, TorchBindObject)] if tensor_args else None ) if non_torch_bind_tensor_args: + assert tensor_args devices = [arg.get_device() for arg in tensor_args if arg.get_device()] return devices[0] if isinstance(example_output, torch.Tensor): @@ -6910,17 +7216,18 @@ def find_device(tensor_args, example_output): # type: ignore[no-untyped-def] if len(devices) == 1: return devices[0] for device in devices: + assert isinstance(device, torch.device) if is_gpu(device.type): return device return devices[0] return None - def has_side_effects(self): # type: ignore[no-untyped-def] + def has_side_effects(self) -> bool: if isinstance(self.op_overload, torch._ops.HigherOrderOperator): return False return get_schema_info(self.op_overload).is_mutable() - def get_inputs_that_alias_output(self): # type: ignore[no-untyped-def] + def get_inputs_that_alias_output(self) -> Sequence[str]: return self.alias_names def get_mutation_names(self) -> Sequence[str]: @@ -6942,7 +7249,7 @@ def export_extern_kernel_node(self): # type: ignore[no-untyped-def] self.op_overload, ) - assert isinstance(self, FallbackKernel) + assert isinstance(self, FallbackKernel), type(self) args, kwargs = self.unflatten_args(self.inputs, self.constant_args) args = self.fill_non_provided_args(args, kwargs) ordered_kwargs = [ @@ -6955,11 +7262,14 @@ def export_extern_kernel_node(self): # type: ignore[no-untyped-def] # No need to serialize in the cpp wrapper JIT mode return [*args, *ordered_kwargs] - serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + serializer = GraphModuleSerializer(None, []) # type: ignore[arg-type] named_arguments = serializer.serialize_inputs(target, args, kwargs) # serialize_outputs - def handle_single_output(return_type, output): # type: ignore[no-untyped-def] + def handle_single_output( + return_type: Union[torch.TensorType, torch.ListType, torch.JitType], + output: Union[IRNode, Sequence[IRNode]], + ) -> export_schema.Argument: if isinstance(return_type, (torch.TensorType, torch.NoneType)): # For single Tensor or None out = output @@ -6967,6 +7277,7 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] assert len(output) == 1 out = output[0] if isinstance(return_type, torch.TensorType): + assert isinstance(out, IRNode) return export_schema.Argument.create( as_tensor=export_schema.TensorArgument(name=out.get_name()) ) @@ -6976,6 +7287,7 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] elif isinstance(return_type, torch.ListType) and isinstance( return_type.getElementType(), torch.TensorType ): + assert isinstance(output, Sequence), type(output) # For single TensorList return export_schema.Argument.create( as_tensors=[ @@ -6994,6 +7306,7 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] ) ) else: + assert isinstance(output, IRNode) return export_schema.Argument.create( as_optional_tensor=export_schema.OptionalTensorArgument.create( as_tensor=export_schema.TensorArgument( @@ -7001,11 +7314,13 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] ) ) ) + elif isinstance(return_type, torch.IntType): + return export_schema.Argument.create(as_int=output) else: raise RuntimeError(f"Unsupported return type {type(return_type)}") if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): - returns = target.schema(args[0], args[1]).returns # type: ignore[union-attr] + returns = target.schema(args[0], args[1]).returns else: returns = target._schema.returns # type: ignore[union-attr] if len(returns) == 1: @@ -7018,14 +7333,18 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" # Not generating output args for self.mutation_outputs output_arguments = [ - handle_single_output(return_schema.real_type, output) + handle_single_output( + return_schema.real_type, # type: ignore[attr-defined] + output, + ) for return_schema, output in zip(returns, self.outputs) ] + assert self.op_overload is not None node = ExternKernelNode( name=self.get_name(), node=export_schema.Node( - target=self.op_overload.name(), # type: ignore[union-attr] + target=self.op_overload.name(), inputs=named_arguments, outputs=output_arguments, metadata={}, @@ -7036,11 +7355,15 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def] return [*args, *ordered_kwargs] - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + @override + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Overrides the parent member. + See https://github.com/pytorch/pytorch/issues/151692""" kernel = self.op_overload - if kernel.namespace == "aten": # type: ignore[union-attr] + assert kernel is not None + if kernel.namespace == "aten": # Aten Fallback Ops - assert isinstance(kernel, torch._ops.OpOverload) + assert isinstance(kernel, torch._ops.OpOverload), type(kernel) if V.graph.cpp_wrapper: from torchgen.aoti.fallback_ops import inductor_fallback_ops @@ -7052,9 +7375,9 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] kernel, ) self.use_runtime_dispatch = True - elif kernel.namespace == "_quantized": # type: ignore[union-attr] + elif kernel.namespace == "_quantized": # Internal Quantized Fallback Ops - assert isinstance(kernel, torch._ops.OpOverload) + assert isinstance(kernel, torch._ops.OpOverload), type(kernel) elif V.graph.cpp_wrapper: # For non-aten OpOverload, i.e. custom ops # If the op is in custom_ops_to_c_shims, generate direct function call @@ -7062,31 +7385,48 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] kernel not in config.aot_inductor.custom_ops_to_c_shims ) - def is_number(t: torch.JitType) -> bool: - if isinstance(t, torch.OptionalType): - return is_number(t.getElementType()) - return isinstance(t, torch.NumberType) - - self.codegen_comment(wrapper) - args = [*self.codegen_args(), *self.codegen_kwargs()] - if self.use_runtime_dispatch or ( - # Handle the special case where a complex number is input to a - # cpp_wrapper C-shim kernel. If the corresponding argument is a number, - # the torchgen-created shim API will use type "double", which cannot be - # converted to from a c10::complex. In these cases, fallback to runtime - # dispatch. + # Handle the special case where a complex number is input to a C-shim kernel for + # a scalar input. The torchgen'ed shim API will use type "double", which is + # incompatible with complex numbers, forcing a fallback to runtime dispatch. + if ( V.graph.cpp_wrapper and isinstance(kernel, torch._ops.OpOverload) - and any( - "c10::complex" in arg_str and is_number(op_arg.real_type) - for arg_str, op_arg in zip(args, kernel._schema.arguments) - ) + and not self.use_runtime_dispatch ): + + def is_number(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return is_number(t.getElementType()) + return isinstance(t, torch.NumberType) + + # Using unflatten_args is a bit of a hack, but all the complex arguments we + # care about are in self.constant_args, and calling unflatten_args puts them + # in the correct order without triggering codegen. + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed + # to be set, since this is an OpOverload kernel. + args_iter = itertools.chain( + args, + ( + self.get_kwargs_value(k, **kwargs) + for k in self.ordered_kwargs_for_cpp_kernel + ), + ) + self.use_runtime_dispatch = any( + isinstance(v, complex) and is_number(a.real_type) + for v, a in zip(args_iter, kernel._schema.arguments) + ) + + self.codegen_comment(wrapper) + if self.use_runtime_dispatch: exported_args = self.export_extern_kernel_node() + assert self.python_kernel_name is not None + assert self.op_overload is not None + wrapper.generate_fallback_kernel_with_runtime_lookup( self.get_name(), self.python_kernel_name, - args, + lambda: [*self.codegen_args(), *self.codegen_kwargs()], self.op_overload, exported_args, # NOTE: [special handling of all_reduce_coalesced_'s return value] @@ -7101,7 +7441,7 @@ def is_number(t: torch.JitType) -> bool: self.codegen_unbacked_symbol_defs(wrapper) @staticmethod - def tensor_to_layout(output: torch.Tensor): # type: ignore[no-untyped-def] + def tensor_to_layout(output: torch.Tensor) -> FixedLayout: return FixedLayout( output.device, output.dtype, @@ -7110,11 +7450,14 @@ def tensor_to_layout(output: torch.Tensor): # type: ignore[no-untyped-def] ) @classmethod - def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] + def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel: + """Create an instance of FallbackKernel from an _OpOverloads""" fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) - context: AbstractContextManager[None] = ( - V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] - ) + if kernel not in fake_incorrect_kernels: + context = cast(AbstractContextManager[None], V.graph.fake_mode) + else: + context = nullcontext() + with context: ( example_output, @@ -7157,7 +7500,7 @@ def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] unbacked_bindings=unbacked_bindings, ) - def generate_output(output, indices): # type: ignore[no-untyped-def] + def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any: if isinstance(output, (list, tuple)): return type(output)( generate_output(output[i], indices + [(type(output), i)]) @@ -7192,13 +7535,15 @@ def generate_output(output, indices): # type: ignore[no-untyped-def] return None outputs = generate_output(example_output, []) - if isinstance(outputs, (list, tuple, dict)): - packed.outputs = outputs # type: ignore[assignment] + if isinstance(outputs, (list, tuple)): + packed.outputs = outputs + elif isinstance(outputs, dict): + packed.outputs = tuple(outputs) else: packed.outputs = [outputs] return outputs - def apply_constraint(self): # type: ignore[no-untyped-def] + def apply_constraint(self) -> None: return super().apply_constraint() @@ -7211,17 +7556,17 @@ def should_allocate(self) -> bool: def get_inputs_that_alias_output(self) -> Sequence[str]: # Signal to codegen that our output buffer isn't safe to reuse - return [self.inputs[0].get_name()] + return [self.input_name(0)] - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - kernel, - tensor_args, - nontensor_args, - unflatten_args, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], *, - unbacked_bindings=None, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, ) -> None: super().__init__( layout, @@ -7248,12 +7593,12 @@ def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] self.codegen_size_asserts(wrapper) self.codegen_alignment_asserts(wrapper) - def __init__( # type: ignore[no-untyped-def] + def __init__( self, layout: OutputSpec, - input, + input: IRNode, indices: list[tuple[Any, ...]], - skip_size_stride_alignment_checks=False, + skip_size_stride_alignment_checks: bool = False, ) -> None: super().__init__(None, layout, [input], ()) self.name = V.graph.register_buffer(self) @@ -7264,14 +7609,14 @@ def __init__( # type: ignore[no-untyped-def] def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: - return self.inputs[0].get_free_symbol_uses(unbacked_only) + input_node = self.inputs[0] + assert isinstance(input_node, IRNode), input_node + return input_node.get_free_symbol_uses(unbacked_only) def should_allocate(self) -> bool: - if len(self.inputs) == 1 and ( + return len(self.inputs) == 1 and ( isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM - ): - return True - return False + ) def get_inputs_that_alias_output(self) -> Sequence[str]: return [ @@ -7329,18 +7674,18 @@ def freeze_layout(self) -> None: return self.data.freeze_layout() def freeze_layout_with_stride_order( - self, order: list[int], allow_padding: bool = False + self, order: Sequence[int], allow_padding: bool = False ) -> None: return self.data.freeze_layout_with_stride_order(order, allow_padding) - def freeze_layout_with_fill_order(self, order: list[int]) -> None: + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: return self.data.freeze_layout_with_fill_order(order) - def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: return self.data.freeze_layout_with_same_order(stride) def freeze_layout_with_exact_strides( - self, exact_strides: list[_IntLike], allow_padding: bool = False + self, exact_strides: Sequence[_IntLike], allow_padding: bool = False ) -> None: return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding) @@ -7359,7 +7704,7 @@ def get_storage_numel(self) -> _IntLike: def get_reduction_type(self) -> Optional[str]: return self.data.get_reduction_type() - def get_reduction_size(self) -> Sequence[sympy.Expr]: + def get_reduction_size(self) -> Sequence[Expr]: return self.data.get_reduction_size() def is_extern(self) -> bool: @@ -7412,7 +7757,7 @@ def get_size(self) -> Sequence[Expr]: return self.data.get_size() @property - def dtype(self): # type: ignore[no-untyped-def] + def dtype(self) -> torch.dtype: return self.data.dtype def __str__(self) -> str: @@ -7437,7 +7782,7 @@ def __str__(self) -> str: class TensorBox(MutableBox): @staticmethod - def create(data): # type: ignore[no-untyped-def] + def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: if isinstance(data, ShapeAsConstantBuffer): return data return TensorBox(StorageBox(data)) @@ -7449,7 +7794,7 @@ def is_input_buffer(self) -> bool: return self.data.get_name() in V.graph.graph_inputs return False - def is_module_buffer(self): # type: ignore[no-untyped-def] + def is_module_buffer(self) -> bool: return ( isinstance(self.data, (ConstantBuffer)) and self.data.get_name() in V.graph.constants @@ -7472,10 +7817,13 @@ def realize(self) -> Optional[str]: ) origin_node = self.data.get_origin_node() traceback = self.data.get_traceback() + device = self.data.get_device() + assert device is not None + self.data = ComputedBuffer( name=None, layout=FlexibleLayout( - device=self.data.get_device(), + device=device, dtype=self.data.get_dtype(), size=self.data.get_size(), ), @@ -7504,7 +7852,7 @@ def has_exceeded_max_reads(self) -> bool: or self.has_large_inner_fn() ) - def should_realize_on_reuse(self, users): # type: ignore[no-untyped-def] + def should_realize_on_reuse(self, users: int) -> bool: """ A heuristic to decide if we should realize a tensor that is used multiple times. @@ -7526,7 +7874,7 @@ def mark_reuse(self, users: int) -> None: if self.should_realize_on_reuse(users): self.realize() - def num_reads(self): # type: ignore[no-untyped-def] + def num_reads(self) -> int: return self.data.num_reads() @@ -7548,12 +7896,16 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: @ir_dataclass(frozen=False) class InvokeSubgraph(ExternKernel): + """ + Ir node for the invoke_subgraph HOP. + """ + subgraph: Optional[Subgraph] = None - operands: Optional[list[TensorBox]] = None - outputs: Optional[list[MultiOutput]] = None + operands: Optional[Sequence[IRNode]] = None + outputs: Optional[Sequence[IRNode]] = None def __init__( - self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout + self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout ) -> None: super().__init__( name=None, @@ -7565,26 +7917,39 @@ def __init__( V.graph.register_operation(self) @classmethod - def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def] + def create( + cls, subgraph: Subgraph, *operands: IRNode + ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]: + """For each operand, get a realized input, force it to have the same + strides as the subgraph inputs, then use an InvokeSubgraph""" + from .lowering import constrain_to_fake_tensor + # TODO(anijain2305) - Support sym expr as operands in future. - fx_operands = V.graph.current_node.args[2:] - fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + current_node = V.graph.current_node + + fake_operands = None + if eager_input_vals := current_node.meta.get("eager_input_vals"): + # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph + fake_operands = eager_input_vals[0][2:] + else: + # For the partitioned backward graph, we do not have + # eager_input_vals. Here, we rely on the recorded example values. + fx_operands = current_node.args[2:] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] # Realize the inputs. Also intermediates can have different strides than # the inputs of the subgraph. So, force the intermediates to have same # strides as that of subgraph inputs. - operands = [cls.realize_input(x) for x in operands] + operands: list[IRNode] = [cls.realize_input(x) for x in operands] + new_operands: list[IRNode] = [] - def handle_sym_expr(stride): # type: ignore[no-untyped-def] - return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride] - - new_operands = [] for idx, operand in enumerate(operands): if isinstance(operand, ShapeAsConstantBuffer): new_operands.append(operand) else: - example_stride = handle_sym_expr(fake_operands[idx].stride()) - new_operands.append(cls.require_exact_strides(operand, example_stride)) + new_operands.append( + constrain_to_fake_tensor(operand, fake_operands[idx]) + ) operands = new_operands @@ -7608,50 +7973,54 @@ def handle_sym_expr(stride): # type: ignore[no-untyped-def] device = operand.get_device() break assert device is not None - invoke_subgraph = InvokeSubgraph( subgraph=subgraph, operands=operands, layout=MultiOutputLayout(device=device), ) - def create_output(output: IRNode, ind: int): + def create_output( + output: IRNode, ind: int + ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]: if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): return output else: + device = output.get_device() + assert device is not None + return MultiOutput( FixedLayout( - device=output.get_device(), + device=device, dtype=output.get_dtype(), - size=output.get_size(), # type: ignore[arg-type] + size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, ), - invoke_subgraph, + invoke_subgraph, # type: ignore[has-type] [(list, ind)], skip_size_stride_alignment_checks=True, ) - outputs = [create_output(output, i) for i, output in enumerate(outputs)] - invoke_subgraph.outputs = outputs - return outputs + outs = [create_output(output, i) for i, output in enumerate(outputs)] + invoke_subgraph.outputs = outs # type: ignore[assignment] + return outs - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_invoke_subgraph(self) @ir_dataclass(frozen=False) class Conditional(ExternKernel): predicate: Optional[IRNode] = None - operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + operands: Optional[Sequence[IRNode]] = None true_subgraph: Optional[Subgraph] = None false_subgraph: Optional[Subgraph] = None - outputs: Optional[list[MultiOutput]] = None + outputs: Optional[Sequence[MultiOutput]] = None def __init__( self, predicate: IRNode, - operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + operands: Sequence[IRNode], true_subgraph: Subgraph, false_subgraph: Subgraph, layout: MultiOutputLayout, @@ -7662,7 +8031,7 @@ def __init__( self.true_subgraph = true_subgraph self.false_subgraph = false_subgraph - sym_args, tensor_args = _split_by_sym_type([predicate] + operands) + sym_args, tensor_args = _split_by_sym_type([predicate, *operands]) super().__init__( name=None, @@ -7677,17 +8046,21 @@ def __init__( V.graph.register_operation(self) @classmethod - def create( # type: ignore[no-untyped-def] + def create( cls, predicate: TensorBox, true_fn: Subgraph, false_fn: Subgraph, operands: list[Union[TensorBox, ShapeAsConstantBuffer]], - ): + ) -> Sequence[IRNode]: + """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)""" predicate = cls.realize_input(predicate) operands = [cls.realize_input(x) for x in operands] - fx_operands = V.graph.current_node.args[-1] - fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + fx_operands: Argument = V.graph.current_node.args[-1] + + assert isinstance(fx_operands, Sequence), type(fx_operands) + assert all(isinstance(n, Node) for n in fx_operands) + fake_operands = [cast(Node, x).meta["val"] for x in fx_operands] for subgraph in (true_fn, false_fn): if subgraph.graph is None: @@ -7700,8 +8073,10 @@ def create( # type: ignore[no-untyped-def] with V.set_graph_handler(subgraph.graph): subgraph.graph.run(*fake_operands) - true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] - false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] + assert true_fn.graph is not None + assert false_fn.graph is not None + true_outputs = true_fn.graph.graph_outputs + false_outputs = false_fn.graph.graph_outputs for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): if _has_aliased_buffers(true_outputs): @@ -7712,10 +8087,10 @@ def create( # type: ignore[no-untyped-def] # make sure true and false outputs are structurally equivalent assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) - for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): - assert to.get_device() == fo.get_device(), (i, to, fo) - assert to.get_dtype() == fo.get_dtype(), (i, to, fo) - assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) + for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)): + assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o) + assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o) + assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o) device = next( o.get_device() @@ -7736,7 +8111,7 @@ def create( # type: ignore[no-untyped-def] unbacked_bindings=unbacked_bindings, ) - def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.expr]: + def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: if isinstance(s, int): return s return s.node.expr @@ -7744,7 +8119,7 @@ def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.expr]: outputs = [ MultiOutput( FixedLayout( - device=output.get_device(), + device=device, dtype=output.get_dtype(), size=[_maybe_expr(sz) for sz in merged_output.size()], stride=[_maybe_expr(sz) for sz in merged_output.stride()], @@ -7763,7 +8138,7 @@ def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.expr]: conditional.outputs = outputs # type: ignore[assignment] return outputs - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_conditional(self) wrapper.codegen_unbacked_symbol_defs_for_outputs( self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {}) @@ -7775,7 +8150,7 @@ def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: V.graph.sizevars.shape_env, unbacked_bindings ) assert resolved is not None - return resolved.keys() # type: ignore[return-value] + return OrderedSet(resolved.keys()) else: return OrderedSet() @@ -7796,16 +8171,16 @@ def _split_by_sym_type( @ir_dataclass(frozen=False) class WhileLoop(ExternKernel): - carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None - additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + carried_inputs: Optional[Sequence[IRNode]] = None + additional_inputs: Optional[Sequence[IRNode]] = None cond_subgraph: Optional[Subgraph] = None body_subgraph: Optional[Subgraph] = None - outputs: Optional[list[MultiOutput]] = None + outputs: Optional[Sequence[MultiOutput]] = None def __init__( self, - carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], - additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + carried_inputs: Sequence[IRNode], + additional_inputs: Sequence[IRNode], cond_subgraph: Subgraph, body_subgraph: Subgraph, layout: MultiOutputLayout, @@ -7815,7 +8190,9 @@ def __init__( self.cond_subgraph = cond_subgraph self.body_subgraph = body_subgraph - sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs) + sym_args, tensor_args = _split_by_sym_type( + [*carried_inputs, *additional_inputs] + ) super().__init__( name=None, layout=layout, @@ -7827,19 +8204,19 @@ def __init__( V.graph.register_operation(self) @classmethod - def create( # type: ignore[no-untyped-def] + def create( cls, cond_fn: Subgraph, body_fn: Subgraph, - carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], - additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], - ): + carried_inputs: Sequence[IRNode], + additional_inputs: Sequence[IRNode], + ) -> Union[IRNode, Sequence[IRNode]]: from torch._higher_order_ops.utils import check_input_alias_and_mutation def _require_exact_strides( - tensor_boxes: list[TensorBox | ShapeAsConstantBuffer], + tensor_boxes: Sequence[IRNode], fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]], - ) -> list[TensorBox | ShapeAsConstantBuffer]: + ) -> list[IRNode]: assert len(tensor_boxes) == len(fake_tensors) ret = [] for tb, fk in zip(tensor_boxes, fake_tensors): @@ -7860,17 +8237,18 @@ def _require_exact_strides( fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr] fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr] - carried_inputs = [cls.realize_input(x) for x in carried_inputs] - carried_inputs = _require_exact_strides(carried_inputs, fake_carried_inputs) - additional_inputs = [cls.realize_input(x) for x in additional_inputs] - additional_inputs = _require_exact_strides( - additional_inputs, fake_additional_inputs + carried_inputs_ = [cls.realize_input(x) for x in carried_inputs] + carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs) + additional_inputs_ = [cls.realize_input(x) for x in additional_inputs] + additional_inputs_ = _require_exact_strides( + additional_inputs_, fake_additional_inputs ) - all_inputs = carried_inputs + additional_inputs + all_inputs = carried_inputs_ + additional_inputs_ for subgraph in (cond_fn, body_fn): if subgraph.graph is None: # create and lower subgraphs + assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs) subgraph.graph = V.graph.make_subgraph( gm=subgraph.graph_module, example_inputs=fx_all_inputs, # type: ignore[arg-type] @@ -7889,12 +8267,13 @@ def _require_exact_strides( fake_carried_inputs ) subgraph.graph.graph_outputs = _require_exact_strides( # type: ignore[assignment] - subgraph.graph.graph_outputs, # type: ignore[arg-type] + subgraph.graph.graph_outputs, fake_carried_inputs, ) - cond_outputs = cond_fn.graph.graph_outputs # type: ignore[union-attr] - body_outputs = body_fn.graph.graph_outputs # type: ignore[union-attr] + assert cond_fn.graph and body_fn.graph + cond_outputs = cond_fn.graph.graph_outputs + body_outputs = body_fn.graph.graph_outputs if _has_aliased_buffers(body_outputs): raise AssertionError( @@ -7916,28 +8295,33 @@ def _require_exact_strides( device = all_inputs[0].get_device() assert device is not None # to make linter happy - # make sure carried_inputs and body outputs are structurally equivalent - assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs) - for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): + # make sure carried_inputs_ and body outputs are structurally equivalent + assert len(carried_inputs_) == len(body_outputs), ( + carried_inputs_, + body_outputs, + ) + for i, (op, bo) in enumerate(zip(carried_inputs_, body_outputs)): def _guard_list_equals( - lhs_exprs: Sequence[Union[int, Any]], - rhs_exprs: Sequence[Union[int, Any]], + lhs_exprs: Sequence[Union[int, sympy.Expr]], + rhs_exprs: Sequence[Union[int, sympy.Expr]], ) -> None: + assert len(lhs_exprs) == len(rhs_exprs) for lhs, rhs in zip(lhs_exprs, rhs_exprs): - V.graph.sizevars.guard_equals(lhs, rhs) + V.graph.sizevars.check_equals(lhs, rhs) _guard_list_equals(op.get_size(), bo.get_size()) _guard_list_equals(op.get_stride(), bo.get_stride()) - # assume all carried_inputs and outputs are on the same device + # assume all carried_inputs_ and outputs are on the same device # as the MultiOutputLayout below requires single device assert op.get_device() == bo.get_device(), (i, op, bo, device) assert op.get_dtype() == bo.get_dtype(), (i, op, bo) assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo) + assert device is not None while_loop = WhileLoop( - carried_inputs=carried_inputs, - additional_inputs=additional_inputs, + carried_inputs=carried_inputs_, + additional_inputs=additional_inputs_, cond_subgraph=cond_fn, body_subgraph=body_fn, # asserted above that there is at least one operand @@ -7962,7 +8346,7 @@ def _guard_list_equals( real_outputs = [ MultiOutput( FixedLayout( - device=output.get_device(), + device=output.get_device(), # type: ignore[arg-type] dtype=output.get_dtype(), size=output.get_size(), stride=output.get_stride(), @@ -7975,7 +8359,7 @@ def _guard_list_equals( ] while_loop.outputs = real_outputs while_loop.mutation_outputs = [ - MutationOutput(inp.layout, inp, while_loop) # type: ignore[union-attr] + MutationOutput(inp.layout, inp, while_loop) # type: ignore[attr-defined, union-attr] for inp in mutated_inputs ] @@ -7995,21 +8379,21 @@ def _guard_list_equals( V.graph.never_reuse_buffers.add(out.get_name()) return all_outputs - def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_while_loop(self) class EffectfulKernel(FallbackKernel): - def __init__( # type: ignore[no-untyped-def] + def __init__( self, - layout, - kernel, - tensor_args, - nontensor_args, - unflatten_args, - kwargs=None, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, *, - unbacked_bindings=None, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, ) -> None: super().__init__( layout, @@ -8073,14 +8457,15 @@ def get_real_obj(self) -> torch.ScriptObject: def get_buf_bytes(self) -> int: # Returns the sum of all tensors in the flattened object real_script_obj = self.get_real_obj() - flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] + assert hasattr(real_script_obj, "__obj_flatten__") + flat_dict = dict(real_script_obj.__obj_flatten__()) flat_elems = pytree.tree_flatten(flat_dict)[0] flat_sizes = [ x.element_size() * x.numel() for x in flat_elems if isinstance(x, torch.Tensor) ] - return functools.reduce(lambda x, y: x + y, flat_sizes, 0) + return functools.reduce(operator.add, flat_sizes, 0) @ir_dataclass @@ -8109,7 +8494,10 @@ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: "Setting cpp kernel needs a valid op_overload" ) kernel = self.op_overload - self.cpp_kernel_name = kernel._schema.name + if cpp_kernel_name is not None: + self.cpp_kernel_name = cpp_kernel_name + else: + self.cpp_kernel_name = kernel._schema.name self.ordered_kwargs_for_cpp_kernel = [ x.name for x in kernel._schema.arguments if x.kwarg_only @@ -8122,8 +8510,12 @@ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: # the constraints, we model collective -> wait_tensor as as two-step # mutation of the input buffers. @classmethod - def create_inplace( # type: ignore[no-untyped-def] - cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs + def create_inplace( + cls, + kernel: _OpOverloads, + inputs: Union[IRNode, list[IRNode]], + *args: Any, + **kwargs: Any, ) -> None: with V.graph.fake_mode: ( @@ -8183,9 +8575,13 @@ def create_inplace( # type: ignore[no-untyped-def] # TODO(yifu): add a pre-grad pass to validate the correctness of collective # usage in the user program. @classmethod - def create_out_of_place( # type: ignore[no-untyped-def] - cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs - ): + def create_out_of_place( + cls, + kernel: _OpOverloads, + inputs: Union[TensorBox, list[TensorBox]], + *args: Any, + **kwargs: Any, + ) -> Union[list[MultiOutput], _CollectiveKernel]: with V.graph.fake_mode: ( example_output, @@ -8200,6 +8596,7 @@ def create_out_of_place( # type: ignore[no-untyped-def] if isinstance(example_output, list): device = cls.find_device(tensor_args, example_output) + assert device is not None packed = cls( MultiOutputLayout(device=device), kernel, @@ -8237,12 +8634,106 @@ def create_out_of_place( # type: ignore[no-untyped-def] return packed +class _AllReduce_Kernel(_CollectiveKernel): + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_") + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + +class _AllReduceKernel(_CollectiveKernel): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce") + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + class _WaitKernel(_CollectiveKernel): - def get_volatile_reads(self): # type: ignore[no-untyped-def] + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor") + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_volatile_reads(self) -> Sequence[IRNode]: inp = self.inputs[0] + assert isinstance(inp, IRNode) if isinstance(inp, _CollectiveKernel): # Out-of-place single-output - return [inp.inputs[0]] + i = inp.inputs[0] + assert isinstance(i, IRNode), type(i) + return [i] elif isinstance(inp, MultiOutput): # This can be two things: # 1. Out-of-place multi-output coll @@ -8260,7 +8751,7 @@ def get_volatile_reads(self): # type: ignore[no-untyped-def] return [] @classmethod - def create_wait(cls, kernel, inp: TensorBox) -> None: # type: ignore[no-untyped-def] + def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None: with V.graph.fake_mode: ( _example_output, diff --git a/torch/_inductor/jagged_lowerings.py b/torch/_inductor/jagged_lowerings.py index 9b393b36b42eae..5d4e17ed538a1b 100644 --- a/torch/_inductor/jagged_lowerings.py +++ b/torch/_inductor/jagged_lowerings.py @@ -5,7 +5,7 @@ import torch -from .ir import Pointwise, TensorBox +from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox from .lowering import fallback_handler, is_integer_type, register_lowering from .virtualized import ops @@ -27,7 +27,7 @@ def get_inverse_offsets( offsets: TensorBox, jagged_len: Union[int, sympy.Expr], realize: bool = True, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: """ Returns "inverse_offsets" - the inverse of the offsets array. offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). @@ -116,7 +116,7 @@ def _jagged_to_padded_dense_forward( jagged_offsets: list[TensorBox], max_lengths: list[int], # list of ints/SymInts padding_value: float = 0.0, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: device = jagged_values.get_device_or_error() dtype = jagged_values.get_dtype() @@ -186,7 +186,7 @@ def _dense_to_jagged_forward_impl( dense: TensorBox, jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: device = dense.get_device_or_error() dtype = dense.get_dtype() @@ -259,7 +259,7 @@ def _dense_to_jagged_forward( dense: TensorBox, jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, - ) -> TensorBox: + ) -> Union[TensorBox, ShapeAsConstantBuffer]: return _dense_to_jagged_forward_impl( fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, dense=dense, diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index 14fd1412412ff8..be782bac3a828a 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -13,6 +13,7 @@ TritonTemplate, ) from ..utils import ( + _use_cutlass_for_op, use_aten_gemm_kernels, use_ck_gemm_template, use_cpp_bmm_template, @@ -117,6 +118,7 @@ def _is_large_block_for_cpu(m, n, k): # inductor generates a suffix {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} """, + cache_codegen_enabled_for_template=True, ) aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") @@ -178,9 +180,11 @@ def may_require_contiguous(t, meta_t): ) # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1 + batch_size = mat1.get_size()[0] # Extract batch dimension + counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1 log.info( - "Tuned aten.bmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + batch_size, m, n, k, @@ -201,11 +205,15 @@ def may_require_contiguous(t, meta_t): device_type = ir.get_device_type(mat1) bmm_configs = V.choices.get_base_mm_configs(device_type) + dtype = mat1.get_dtype() if use_triton_template(layout): # TODO: add out_dtype support for Triton Template assert out_dtype is None, "out_dtype is not supported for Triton" for config in bmm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), ): bmm_template.maybe_append_choice( choices, @@ -215,7 +223,12 @@ def may_require_contiguous(t, meta_t): ) _, is_nonzero = _is_static_problem(layout) batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout) - if batch_stride_largest and is_nonzero and use_cutlass_template(layout, m, n, k): + if ( + batch_stride_largest + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("bmm") + ): from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type] @@ -240,9 +253,11 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten.baddbmm_{m}_{n}_{k}"] += 1 + batch_size = mat1.get_size()[0] + counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 log.info( - "Tuned aten.baddbmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s", + "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s", + batch_size, m, n, k, @@ -273,6 +288,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): **mm_options(config, m, n, k, layout), prefix_args=1, epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), ) return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 4b14989c372d11..ba1dc4aa2c2452 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -438,17 +438,17 @@ def convolution( dilation = tuple(dilation) output_padding = tuple(output_padding) if not isinstance(groups, int): - groups = V.graph.sizevars.evaluate_static_shape(groups) + groups = V.graph.sizevars.guard_int(groups) assert isinstance(groups, int) # Need use hint for triton template since the template does not # work with a dynamic shape. # - # No need to evaluate_static_shape for dilation and output_padding + # No need to guard_int for dilation and output_padding # since the template is only used when dilation is 1 and output_padding # is 0. - stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) - padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) + stride = tuple(V.graph.sizevars.guard_int_seq(stride)) + padding = tuple(V.graph.sizevars.guard_int_seq(padding)) kwargs: ConvLayoutParams = { "stride": stride, @@ -468,9 +468,7 @@ def convolution( dim=0, ) - out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( - weight.get_size() - ) + out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size()) # Always convert conv1D to 2D for Intel GPU. # Only conv2D can be converted to channel last layout, @@ -539,18 +537,18 @@ def channels_last_conv(): # apply channels last. if V.graph.layout_opt and ndim == 2: V.graph.num_channels_last_conv += 1 - x = ir.ExternKernel.require_channels_last(x) + x = ir.ExternKernel.require_channels_last(x) # type: ignore[assignment] # TODO maybe we can convert weights to channels last just once before # running the model. - weight = ir.ExternKernel.require_channels_last(weight) + weight = ir.ExternKernel.require_channels_last(weight) # type: ignore[assignment] layout = conv_layout(x, weight, None, **kwargs) else: layout = conv_layout(x, weight, None, **kwargs) req_stride_order = ir.get_stride_order( V.graph.sizevars.size_hints(layout.stride) ) - x = ir.ExternKernel.require_stride_order(x, req_stride_order) - weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment] + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment] ordered_kwargs_for_cpp_kernel = [ "stride", @@ -568,7 +566,7 @@ def channels_last_conv(): args = [x, weight, bias] bias.realize() bias.freeze_layout() - V.graph.sizevars.evaluate_static_shapes(bias.get_size()) + V.graph.sizevars.guard_int_seq(bias.get_size()) choices = [] if torch._inductor.utils._use_conv_autotune_backend("ATEN"): diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index e79683b9e8bb95..99e869dc8fdb71 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -18,7 +18,6 @@ from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.value_ranges import ValueRanges -from .. import config from ..ir import ( Buffer, ComputedBuffer, @@ -30,6 +29,7 @@ IRNode, MutationLayoutSHOULDREMOVE, Scatter, + ShapeAsConstantBuffer, StorageBox, Subgraph, TensorBox, @@ -51,7 +51,6 @@ SymbolicGridFn, TritonTemplate, ) -from ..utils import get_tma_workspace_arg log = logging.getLogger(__name__) @@ -111,7 +110,7 @@ def create_placeholder( dtype: torch.dtype, device: torch.device, size: Optional[list[int]] = None, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: """Creates a placeholder input buffers for producing subgraph_output.""" input_buffer = InputBuffer( name=name, @@ -196,7 +195,8 @@ def zeros_and_scatter_lowering(shape: list[int], indices, values): def build_subgraph_module_buffer( - args: list[TensorBox], graph_module: torch.fx.GraphModule + args: list[Union[TensorBox, ShapeAsConstantBuffer]], + graph_module: torch.fx.GraphModule, ) -> SubgraphResults: """This function's goal is to take in the required args and produce the subgraph buffer The subgraph buffer is a ComputedBuffer that will be inlined into the triton template @@ -238,10 +238,12 @@ def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: "The output node for the flex attention subgraph must be a StorageBox, but got: ", type(output_buffer), ) + device = output_buffer.data.get_device() + assert device is not None subgraph_buffer = ComputedBuffer( name=None, layout=FlexibleLayout( - device=output_buffer.data.get_device(), + device=device, dtype=output_buffer.data.get_dtype(), size=output_buffer.data.get_size(), ), @@ -252,7 +254,9 @@ def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) -def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults: +def build_subgraph_buffer( + args: list[Union[TensorBox, ShapeAsConstantBuffer]], subgraph: Subgraph +) -> SubgraphResults: return build_subgraph_module_buffer(args, subgraph.graph_module) @@ -394,41 +398,26 @@ def load_checked_2d( desc_q = None desc_k = None desc_v = None - if USE_TMA: - TMA_SIZE = 128 - workspace_base = ws_ptr + TMA_SIZE * 3 * ( - tl.program_id(1) + tl.program_id(0) * tl.num_programs(1) - ) - desc_q = workspace_base - desc_v = workspace_base + TMA_SIZE - desc_k = workspace_base + 2 * TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=desc_q, - global_address=Q, - load_size=[BLOCK_M, QK_HEAD_DIM_ROUNDED], - global_size=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], - element_ty=Q.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=desc_v, - global_address=V, - load_size=[BLOCK_N, V_HEAD_DIM_ROUNDED], - global_size=[KV_LEN*ZKV*HQ, V_HEAD_DIM], - element_ty=K.dtype.element_ty, - ) - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=desc_k, - global_address=K, - load_size=[BLOCK_N, QK_HEAD_DIM_ROUNDED], - global_size=[KV_LEN*ZKV*HQ, QK_HEAD_DIM], - element_ty=K.dtype.element_ty, - ) - - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_q) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_k) + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], + strides=[QK_HEAD_DIM, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + desc_k = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. @@ -483,15 +472,14 @@ def load_checked_2d( order=(1, 0) ) - if USE_TMA: - q = tl._experimental_descriptor_load( # load in row major - desc_q, - [(q_start * BLOCK_M).to(tl.int32), 0], - [BLOCK_M, QK_HEAD_DIM_ROUNDED], - Q.dtype.element_ty, - ) - else: + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We don't know anything "special" about these blocks, so we need to apply @@ -709,15 +697,14 @@ def forward_block_mn( # -- load k -- # NB reversed order to since K is transposed - if USE_TMA: - k = tl._experimental_descriptor_load( # load in row major - desc_k, - [start_n.to(tl.int32) , kv_start], - [BLOCK_N, QK_HEAD_DIM_ROUNDED], - MATMUL_PRECISION, - ) - else: - k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( # load in row major + desc_k, + [start_n.to(tl.int32) , kv_start], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} if USE_TMA: k = tl.trans(k) @@ -784,15 +771,14 @@ def forward_block_mn( l_i = l_i * alpha + tl.sum(p, 1) # # -- scale and update acc -- acc = acc * alpha[:, None] - if USE_TMA: - v = tl._experimental_descriptor_load( # load in row major - desc_v, - [kv_start.to(tl.int32) + start_n.to(tl.int32),0], - [BLOCK_N, V_HEAD_DIM_ROUNDED], - MATMUL_PRECISION, - ) - else: - v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start.to(tl.int32) + start_n.to(tl.int32),0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) # -- update m_i @@ -854,144 +840,11 @@ def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): ) -_h100_default_config = { - (torch.float32, 64): (128, 32, 4, 3), - (torch.float32, 128): (32, 64, 4, 3), - (torch.float32, 256): (32, 32, 4, 3), - (torch.bfloat16, 64): (128, 128, 4, 3), - (torch.bfloat16, 128): (128, 64, 8, 3), - (torch.bfloat16, 256): (64, 32, 4, 3), - (torch.float16, 64): (128, 128, 4, 3), - (torch.float16, 128): (128, 128, 8, 3), - (torch.float16, 256): (64, 32, 4, 3), -} - -_a100_default_config = { - (torch.float32, 64): (128, 32, 4, 3), - (torch.float32, 128): (128, 32, 4, 3), - (torch.float32, 256): (64, 16, 4, 3), - (torch.bfloat16, 64): (128, 64, 4, 3), - (torch.bfloat16, 128): (128, 64, 8, 3), - (torch.bfloat16, 256): (32, 64, 4, 3), - (torch.float16, 64): (128, 64, 4, 3), - (torch.float16, 128): (128, 64, 8, 3), - (torch.float16, 256): (32, 64, 4, 3), -} - -_rocm_default_config = { - (torch.float32, 64): (128, 32, 4, 1), - (torch.float32, 128): (128, 32, 4, 1), - (torch.float32, 256): (64, 16, 4, 1), - (torch.bfloat16, 64): (128, 64, 8, 1), - (torch.bfloat16, 128): (128, 64, 8, 1), - (torch.bfloat16, 256): (32, 64, 8, 1), - (torch.float16, 64): (128, 64, 8, 1), - (torch.float16, 128): (128, 64, 8, 1), - (torch.float16, 256): (32, 64, 4, 1), -} - - class Mode(Enum): fwd = auto() bwd = auto() -def _get_rocm_config(query, mode: Mode) -> tuple[int, int, int, int]: - dtype = query.get_dtype() - head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) - fwd_config = None - - if mode == Mode.fwd: - if head_dim <= 256: - if dtype == torch.float32: - fwd_config = (64, 64, 4, 1) - else: - fwd_config = (128, 64, 8, 1) - fwd_config = _rocm_default_config.get((dtype, head_dim), fwd_config) - else: # modest hardware or extremely large head_dim - if dtype == torch.float32: - fwd_config = (32, 16, 4, 1) - else: - fwd_config = (64, 32, 4, 1) - return fwd_config - else: # bwd - assert mode == Mode.bwd - if dtype == torch.float32: - return (16, 16, 4, 1) - elif head_dim <= 256: - if head_dim == 64: - return (64, 64, 4, 1) - elif head_dim == 128: - return (64, 128, 8, 1) - else: - return (64, 64, 4, 1) - else: # modest hardware or extremely large head_dim - return (16, 16, 4, 1) - - -def _get_nv_config(query, mode: Mode) -> tuple[int, int, int, int]: - dtype = query.get_dtype() - head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) - fwd_config = None - bwd_config = None - capability = torch.cuda.get_device_capability() - - if mode == Mode.fwd: - if head_dim <= 256: - if dtype == torch.float32: - fwd_config = (64, 64, 4, 3) - else: - fwd_config = (128, 64, 4, 3) - if capability >= (9, 0): - fwd_config = _h100_default_config.get((dtype, head_dim), fwd_config) - elif capability >= (8, 0): - fwd_config = _a100_default_config.get((dtype, head_dim), fwd_config) - else: # modest hardware or extremely large head_dim - if dtype == torch.float32: - fwd_config = (32, 16, 4, 3) - else: - fwd_config = (64, 32, 4, 3) - return fwd_config - - else: # bwd - assert mode == Mode.bwd - if dtype == torch.float32: - bwd_config = (16, 16, 4, 1) - elif head_dim <= 256 and capability >= (9, 0): # H100 - if head_dim == 64: - bwd_config = (64, 64, 4, 3) - elif head_dim == 128: - bwd_config = (64, 128, 8, 3) - else: - bwd_config = (64, 64, 4, 2) - elif capability >= (8, 0): - if head_dim >= 64: - bwd_config = (32, 128, 4, 3) - elif head_dim == 128: - # SM86/89 have smaller shared memory sizes - num_stages = 3 if capability[-1] == 0 else 2 - bwd_config = (64, 64, 4, num_stages) - else: - bwd_config = (64, 64, 4, 2) - else: # modest hardware or extremely large head_dim - bwd_config = (16, 16, 4, 1) - return bwd_config - - -def _get_default_config_fwd(query) -> tuple[int, int, int, int]: - if torch.version.hip is None: - return _get_nv_config(query, mode=Mode.fwd) - else: - return _get_rocm_config(query, mode=Mode.fwd) - - -def _get_default_config_bwd(query) -> tuple[int, int, int, int]: - if torch.version.hip is None: - return _get_nv_config(query, mode=Mode.bwd) - else: - return _get_rocm_config(query, mode=Mode.bwd) - - def create_num_blocks_fake_generator(sparse_indices): # The idea here is that we need to create a real tensor with real data # that's representative for benchmarking. @@ -1044,7 +897,7 @@ def check_cpu_supported(): def contiguous_last_dim(x): - """Ensure that realized IR node has a contigous stride in the last dimension.""" + """Ensure that realized IR node has a contiguous stride in the last dimension.""" strides = x.maybe_get_stride() if strides and strides[-1] != 1: contiguous_stride_order = list(reversed(range(len(x.get_size())))) @@ -1099,7 +952,7 @@ def lower_cpu( cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr shape_env = V.graph.sizevars.shape_env - # We don't know the concret value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. # Mark symbols > 1 to ensure broadcasting is always applied. # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) @@ -1298,8 +1151,8 @@ def convert_mask_graph_module(mask_graph): skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. - SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) - SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) assert V.graph.sizevars.evaluate_expr( sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) ), ( @@ -1370,18 +1223,18 @@ def set_head_dim_values( kernel_options: Dictionary to populate with options qk_head_dim: Query/Key head dimension v_head_dim: Value head dimension - graph_sizevars: Graph size variables object with evaluate_static_shape method + graph_sizevars: Graph size variables object with guard_int method """ # QK dimensions - qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim) + qk_head_dim_static = graph_sizevars.guard_int(qk_head_dim) kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) kernel_options.setdefault( "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) ) # V dimensions - v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim) + v_head_dim_static = graph_sizevars.guard_int(v_head_dim) kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) kernel_options.setdefault( "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) @@ -1475,9 +1328,7 @@ def flex_attention( kernel_options = dict(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { - k: V.graph.sizevars.evaluate_static_shape(v) - if isinstance(v, sympy.Symbol) - else v + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v for k, v in kernel_options.items() } kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) @@ -1587,28 +1438,14 @@ def flex_attention( set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) choices: list[Any] = [] - configs: list[tuple[int, int, int, int]] = [] - configs.append(_get_default_config_fwd(query)) - if config.max_autotune: - configs += [ - (128, 64, 4, 3), - (128, 128, 4, 3), - (128, 128, 8, 2), - (64, 128, 4, 3), - (64, 64, 4, 3), - ] - # On ROCm convert num_stages to 1 to avoid shmem issues - if torch.version.hip: - configs = [(c[0], c[1], c[2], 1) for c in configs] + dtype = query.get_dtype() + head_dim = V.graph.sizevars.guard_int(query.get_size()[-1]) + configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype) # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. - SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) - SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) - - # ROCm specific considerations - if torch.version.hip: - kernel_options["kpack"] = 2 + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function @@ -1617,8 +1454,11 @@ def flex_attention( # Default config for warp specialization num_consumer_groups, num_buffers_warp_spec = 0, 0 - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: - if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0: + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + ): if len(configs) == 1: raise ValueError( f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " @@ -1636,8 +1476,8 @@ def flex_attention( cur_kernel_options[k[4:]] = v if k.startswith("bwd_"): cur_kernel_options.pop(k) - cur_kernel_options.setdefault("num_stages", num_stages) - cur_kernel_options.setdefault("num_warps", num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) if cur_kernel_options.get("num_consumer_groups", False): cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) cur_kernel_options.setdefault( @@ -1647,26 +1487,17 @@ def flex_attention( # Disabling TMA by default, only explicit kernel_options supported for now cur_kernel_options.setdefault("USE_TMA", False) - cur_kernel_options.setdefault("BLOCK_M", BLOCK_M) - cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) # Blocksparse options cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - workspace_arg = None - if cur_kernel_options.get("USE_TMA", False): - seq_len_q = V.graph.sizevars.evaluate_static_shape(seq_len_q) - - grid = flex_attention_grid( - Bq, Hq, seq_len_q, qk_head_dim, cur_kernel_options - ) + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) - num_programs = grid[0] * grid[1] * grid[2] - workspace_arg = get_tma_workspace_arg( - num_tma_descriptors=3, - device=query.get_device(), - num_programs=num_programs, - ) error = flex_attention_template.maybe_append_choice( choices=choices, input_nodes=[ @@ -1687,7 +1518,6 @@ def flex_attention( mutated_inputs=[ logsumexp, ], - workspace_arg=workspace_arg, call_sizes=query.get_size(), **cur_kernel_options, ) @@ -1860,7 +1690,7 @@ def flex_attention_backward_grid( sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) @@ -1968,7 +1798,7 @@ def flex_attention_backward_grid( for off_g in range(0, GQA_SHARED_HEADS): off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) @@ -2605,9 +2435,7 @@ def flex_attention_backward(*args, **kwargs): kernel_options = dict(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { - k: V.graph.sizevars.evaluate_static_shape(v) - if isinstance(v, sympy.Symbol) - else v + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v for k, v in kernel_options.items() } kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) @@ -2722,39 +2550,25 @@ def flex_attention_backward(*args, **kwargs): set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) - SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) - SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) choices: list[Any] = [] - configs: list[tuple[int, int, int, int]] = [] - configs.append(_get_default_config_bwd(query)) + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.guard_int(query.get_size()[-1]) + configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype) + # Default config for warp specialization num_consumer_groups, num_buffers_warp_spec = 0, 0 - if config.max_autotune: - num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1] - - configs.extend( - [ - (BLOCK1, BLOCK2, w, s) - for BLOCK1 in [32, 64] - for BLOCK2 in [32, 64, 128] - for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) - for s in num_stages_list - if BLOCK2 % BLOCK1 == 0 - ] - ) + original_kernel_options = kernel_options.copy() - for ( - BLOCK1, - BLOCK2, - num_warps, - num_stages, - ) in configs: + for conf in configs: if ( - SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0 - or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0 - or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0 - or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0 + SPARSE_KV_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0 ): continue @@ -2768,8 +2582,8 @@ def flex_attention_backward(*args, **kwargs): cur_kernel_options[k[4:]] = v if k.startswith("fwd_"): cur_kernel_options.pop(k) - cur_kernel_options.setdefault("num_warps", num_warps) - cur_kernel_options.setdefault("num_stages", num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) if cur_kernel_options.get("num_consumer_groups", False): cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) @@ -2777,14 +2591,20 @@ def flex_attention_backward(*args, **kwargs): "num_buffers_warp_spec", num_buffers_warp_spec ) - cur_kernel_options.setdefault("BLOCK_M1", BLOCK1) - cur_kernel_options.setdefault("BLOCK_N1", BLOCK2) - cur_kernel_options.setdefault("BLOCK_M2", BLOCK2) - cur_kernel_options.setdefault("BLOCK_N2", BLOCK1) + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_n) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_m) + # Blocksparse options cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + flex_attention_backward_template.maybe_append_choice( choices=choices, input_nodes=[ diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index ec8fbc0808546c..7e0aef98185603 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -8,7 +8,7 @@ import torch from torch._inductor.virtualized import V -from .. import config, ir +from .. import ir from ..ir import FixedLayout, FlexibleLayout from ..lowering import empty, empty_strided, lowerings from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 @@ -321,21 +321,6 @@ def get_split_k(B: int, H: int, Mk: int) -> int: return split_k -def _get_decoding_default_config(key) -> tuple[int, int, int]: - dtype = key.get_dtype() - head_dim = key.get_size()[-1] - sm_version = torch.cuda.get_device_capability() - default_config = (64, 2, 1) - if sm_version >= (9, 0): - if head_dim > 128 and dtype == torch.float32: - return default_config - if torch.version.hip is None: - return (64, 2, 3) - else: - return (64, 2, 1) - return default_config - - def create_flex_decoding_kernel(*args, **kwargs): from .flex_attention import set_head_dim_values @@ -378,9 +363,7 @@ def create_flex_decoding_kernel(*args, **kwargs): kernel_options = dict(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { - k: V.graph.sizevars.evaluate_static_shape(v) - if isinstance(v, sympy.Symbol) - else v + k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v for k, v in kernel_options.items() } @@ -430,19 +413,9 @@ def create_flex_decoding_kernel(*args, **kwargs): mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) choices: list[Any] = [] - configs: list[tuple[int, int, int]] = [] - configs.append(_get_decoding_default_config(key)) - # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops. - if config.max_autotune: - configs += [ - (64, 2, 2), - (32, 2, 3), - (128, 2, 3), - ] - - # Use num_stages=1 on ROCm to avoid shmem limitation - if torch.version.hip: - configs = [(c[0], c[1], 1) for c in configs] + dtype = key.get_dtype() + head_dim = V.graph.sizevars.guard_int(key.get_size()[-1]) + configs = V.choices.get_flex_decode_configs(head_dim, dtype) # TODO: fix autotuning. @@ -508,7 +481,7 @@ def create_flex_decoding_kernel(*args, **kwargs): ) query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) - V.graph.sizevars.guard_leq( + V.graph.sizevars.check_leq( seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) ) @@ -519,7 +492,7 @@ def create_flex_decoding_kernel(*args, **kwargs): # TODO: This feels sketchy kernel_options.setdefault("SAFE_N_BOUNDARY", True) # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. - SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE) original_kernel_options = kernel_options.copy() # Note, we don't need to pass in the captured buffers explicitly @@ -529,8 +502,8 @@ def create_flex_decoding_kernel(*args, **kwargs): # Default config for warp specialization num_consumer_groups, num_buffers_warp_spec = 0, 0 - for BLOCK_N, num_warps, num_stages in configs: - if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0: + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: continue cur_kernel_options = original_kernel_options.copy() @@ -542,10 +515,10 @@ def create_flex_decoding_kernel(*args, **kwargs): if k.startswith("bwd_"): cur_kernel_options.pop(k) # Performance tuning - cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) - cur_kernel_options.setdefault("num_warps", num_warps) - cur_kernel_options.setdefault("num_stages", num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) if cur_kernel_options.get("num_consumer_groups", False): cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) @@ -556,6 +529,11 @@ def create_flex_decoding_kernel(*args, **kwargs): # Set default to False cur_kernel_options.setdefault("USE_TMA", False) + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + flex_decoding_template.maybe_append_choice( choices=choices, input_nodes=[ diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 20cfa766dc31ea..f1c77afd52fd2a 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -38,14 +38,15 @@ TritonTemplate, ) from ..utils import ( + _use_cutlass_for_op, get_k_splits, get_tma_workspace_arg, use_aten_gemm_kernels, use_ck_gemm_template, + use_ck_tile_gemm_template, use_cpp_gemm_template, use_cutlass_template, use_decompose_k_choice, - use_max_autotune, use_triton_template, use_triton_tma_template, ) @@ -264,6 +265,7 @@ a_desc_ptr = workspace_base b_desc_ptr = workspace_base + TMA_SIZE + {%- if TMA_EXPERIMENTAL_API %} triton.language.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=a_desc_ptr, global_address=A, @@ -282,6 +284,23 @@ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + a_desc = a_desc_ptr + b_desc = b_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[K, 1] if A_ROW_MAJOR else [M, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[N, 1] if B_ROW_MAJOR else [K, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + pid_m = 0 pid_n = 0 rm = 0 @@ -302,18 +321,29 @@ rk = ki * BLOCK_K + {%- if TMA_EXPERIMENTAL_API %} a = tl._experimental_descriptor_load( - a_desc_ptr, + a_desc, [rm, rk] if A_ROW_MAJOR else [rk, rm], [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], A.dtype.element_ty, ) b = tl._experimental_descriptor_load( - b_desc_ptr, + b_desc, [rk, rn] if B_ROW_MAJOR else [rn, rk], [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], B.dtype.element_ty, ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} acc += tl.dot( a if A_ROW_MAJOR else a.T, b if B_ROW_MAJOR else b.T, @@ -415,6 +445,7 @@ def apply_scaling( a_desc_ptr = workspace_base b_desc_ptr = workspace_base + TMA_SIZE + {%- if TMA_EXPERIMENTAL_API %} triton.language.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=a_desc_ptr, global_address=A, @@ -433,6 +464,23 @@ def apply_scaling( tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + a_desc = a_desc_ptr + b_desc = a_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + tiles_per_SM = num_tiles // NUM_SMS if start_pid < num_tiles % NUM_SMS: tiles_per_SM += 1 @@ -464,12 +512,17 @@ def apply_scaling( offs_k = ki * BLOCK_K + {%- if TMA_EXPERIMENTAL_API %} a = tl._experimental_descriptor_load( a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty ) b = tl._experimental_descriptor_load( b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} if USE_FAST_ACCUM: accumulator = tl.dot(a, b.T, accumulator) else: @@ -510,7 +563,7 @@ def apply_scaling( # prevent duplication registration of extern functions -@functools.lru_cache(None) +@functools.cache def lazy_register_extern_choice(fn): return ExternKernelChoice(fn) @@ -629,7 +682,7 @@ def tuned_mm(mat1, mat2, *, layout=None): ) aten_layout = layout - if not use_max_autotune(): + if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm): aten_layout = FlexibleLayout( device=layout.device, dtype=layout.dtype, size=layout.size ) @@ -644,12 +697,13 @@ def tuned_mm(mat1, mat2, *, layout=None): persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) extra_mm_configs = V.choices.get_extra_mm_configs(device_type) + dtype = mat1.get_dtype() if is_nonzero and use_triton_template(layout): for config in mm_configs( m, n, k, - **mm_config_kwargs(device_type, _is_large_block_for_cpu), + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), ): mm_template.maybe_append_choice( choices, @@ -663,7 +717,9 @@ def tuned_mm(mat1, mat2, *, layout=None): m, n, k, - **mm_config_kwargs(device_type, _is_large_block_for_cpu), + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), ): persistent_tma_mm_template.maybe_append_choice( choices, @@ -697,7 +753,7 @@ def tuned_mm(mat1, mat2, *, layout=None): k_splits = get_k_splits(m, n, k) for k_split in k_splits: - if not V.graph.sizevars.is_expr_static_and_true( + if not V.graph.sizevars.statically_known_true( sympy.Eq(sympy.Mod(k, k_split), 0) ): continue @@ -719,11 +775,16 @@ def tuned_mm(mat1, mat2, *, layout=None): layout=layout, ) - if is_nonzero and use_cutlass_template(layout, m, n, k): + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("mm") + ): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2]) if use_cpp_gemm_template(layout, mat1, mat2): @@ -813,7 +874,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None): [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] ) - if use_cutlass: + if use_cutlass and _use_cutlass_for_op("int_mm"): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True ) @@ -852,7 +913,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): layout, ) - if (not is_nonzero) or (not use_max_autotune()): + if (not is_nonzero) or ( + not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) + ): # Use a FlexibleLayout if we are not autotuning. # This allows padding strides for the output. from torch._inductor.ir import FixedLayout, FlexibleLayout @@ -905,9 +968,13 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): mm_configs = V.choices.get_base_mm_configs(device_type) persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + dtype = mat1.get_dtype() if is_nonzero and use_triton_template(layout): for config in mm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), ): mm_template.maybe_append_choice( choices, @@ -916,11 +983,17 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): **mm_options(config, m, n, k, layout), prefix_args=1, epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), ) if use_triton_tma_template(mat1, mat2): for config in persistent_mm_configs( - m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), ): persistent_tma_mm_template.maybe_append_choice( choices, @@ -936,7 +1009,11 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), ) - if is_nonzero and use_cutlass_template(layout, m, n, k): + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("addmm") + ): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, @@ -981,8 +1058,8 @@ def tuned_sparse_semi_structured_mm( m1, k1 = mat1.get_size() m2, _ = mat1_meta.get_size() k2, n = mat2.get_size() - m = V.graph.sizevars.guard_equals(m1, m2) - k = V.graph.sizevars.guard_equals(2 * k1, k2) + m = V.graph.sizevars.check_equals_and_simplify(m1, m2) + k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2) if layout is None: from torch._inductor.ir import FixedLayout @@ -1006,7 +1083,11 @@ def tuned_sparse_semi_structured_mm( else [] ) - if m * n != 0 and use_cutlass_template(layout, m, n, k): + if ( + m * n != 0 + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("sparse_semi_structured_mm") + ): CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True ) @@ -1031,6 +1112,21 @@ def tuned_scaled_mm( use_fast_accum=False, layout=None, ): + """ + Performs an optimized matrix multiplication where scaling factors are applied + to the inputs and/or output. + + Args: + mat1 (Tensor): First input matrix + mat2 (Tensor): Second input matrix + scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting) + scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting) + bias (Tensor, optional): Optional bias tensor to add to the result + layout: Layout hint for optimization + + Returns: + Tensor: The result of the scaled matrix multiplication + """ m, n, k, layout, mat_a, mat_b = mm_args( mat_a, mat_b, layout=layout, out_dtype=out_dtype ) @@ -1135,12 +1231,13 @@ def tuned_scaled_mm( ) for config in scaled_mm_configs(m, n, k): - if k <= 16: - continue # Triton crashes in this case + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + continue # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape - if using_b200() and k < 32: + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): continue kwargs = scaled_mm_options( @@ -1154,15 +1251,20 @@ def tuned_scaled_mm( **kwargs, suffix_args=suffix_args, epilogue_fn=scale_mm_epilogue(), + epilogue_fn_hash="scale_mm_epilogue", ) - if is_nonzero and use_cutlass_template(layout, m, n, k): - if use_fast_accum: - log.warning( - "use_fast_accum=True is not supported by cutlass template, skipping cutlass choices" - ) - else: - CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, input_nodes) # type: ignore[arg-type] + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("scaled_mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + input_nodes, # type: ignore[arg-type] + use_fast_accum=use_fast_accum, # type: ignore[arg-type] + ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) @@ -1170,7 +1272,7 @@ def tuned_scaled_mm( return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) -@functools.lru_cache(None) +@functools.cache def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) return props.major <= 7 diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 17224d17d29fde..9caeffe5013c23 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -38,7 +38,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min): @SymbolicGridFn -def persistent_grouped_mm_grid(m, n, meta): +def persistent_grouped_mm_grid(*args): + meta = args[-1] return (meta["NUM_SMS"], 1, 1) @@ -78,13 +79,21 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): return options_dict +def tma_options() -> dict[str, Any]: + from torch.utils._triton import has_triton_stable_tma_api + + return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()} + + def persistent_mm_options(mat1, mat2): - return dict( + res = dict( A_ROW_MAJOR=not mat1.layout.is_transposed(), B_ROW_MAJOR=not mat2.layout.is_transposed(), NUM_SMS=get_num_sms(), TMA_SIZE=TMA_DESCRIPTOR_SIZE, ) + res.update(tma_options()) + return res def scaled_mm_options( # type: ignore[no-untyped-def] @@ -99,7 +108,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def] device_tma: bool = False, ) -> dict[str, Any]: def are_compatible_scales(size_a, size_b) -> bool: - # Same sized scales are compatable + # Same sized scales are compatible if len(size_a) == len(size_b): return True @@ -125,6 +134,8 @@ def are_compatible_scales(size_a, size_b) -> bool: mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE mm_template_options["NUM_SMS"] = get_num_sms() + mm_template_options.update(tma_options()) + return mm_template_options @@ -146,10 +157,10 @@ def mm_args( *b2, n, k2 = mat2.get_size() else: *b2, k2, n = mat2.get_size() - b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)] if use_4x2_dim: k2 = k2 * 2 - k = V.graph.sizevars.guard_equals(k1, k2) + k = V.graph.sizevars.check_equals_and_simplify(k1, k2) if layout is None: from torch._inductor.ir import FixedLayout @@ -170,12 +181,17 @@ def mm_args( return [m, n, k, layout, mat1, mat2, *others] -def mm_config_kwargs(device, exclude_condition): +def mm_config_kwargs(device, exclude_condition, dtype_size=None): if device == "cpu": return { "scale": 0.5, "exclude": exclude_condition, } + + if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return { + "dtype_size": dtype_size, + } return {} diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py new file mode 100644 index 00000000000000..d311b62950bd2a --- /dev/null +++ b/torch/_inductor/kernel/mm_grouped.py @@ -0,0 +1,741 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters +from torch._inductor.runtime.triton_compat import tl +from torch._inductor.virtualized import V +from torch.utils._triton import has_triton + +from ..ir import ChoiceCaller, Layout, TensorBox +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + get_num_sms, + has_free_symbols, + use_aten_gemm_kernels, +) +from .mm_common import ( + _is_static_problem, + check_supported_striding, + persistent_grouped_mm_grid, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass +class Config: + kwargs: dict[str, int] + num_stages: int + num_warps: int + + +_NV_CONFIGS = [ + Config( + { + "BLOCK_M": block_size_m, + "BLOCK_N": block_size_n, + "BLOCK_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [16, 32, 64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] +] + + +def grouped_mm_configs(): + return _NV_CONFIGS + + +def early_config_prune(g, m, configs, named_args): + dtsize = 1 + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + config.num_warps, + getattr(config, "num_consumer_groups", 0), + ) + + # 1. Prune NV configs depending on g and m. + if not has_free_symbols((g, m)): + a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"] + m_avg = m // g if a_is_2d and not b_is_2d else m + if m_avg <= 16: + if BLOCK_M > 32: + continue + elif m_avg <= 32: + if BLOCK_M > 64: + continue + elif m_avg <= 64: + if BLOCK_M <= 16: + continue + else: + if BLOCK_M <= 32: + continue + + # 2. make sure we have enough smem + max_shared_memory = get_gpu_shared_memory() + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + # 3. make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + pruned_configs.append(config) + + return pruned_configs + + +triton_grouped_mm_source = r""" +{%- if SCALED %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}} +{%- endif %} +{%- else %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr")}} +{%- endif %} +{%- endif %} + tidx = tl.program_id(0) + +{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %} +{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %} +{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %} + +{%- if A_IS_2D %} +{%- if B_IS_2D %} + G = {{size("offsets_ptr", 0)}} +{%- else %} + G = {{size("b_ptr", 0)}} +{%- endif %} +{%- else %} +{%- if B_IS_2D %} + G = {{size("a_ptr", 0)}} +{%- else %} + G = {{size("a_ptr", 0)}} +{%- endif %} +{%- endif %} + + # the b_ptr tensor is given with its last two dims transposed, revert here + + M = {{size("a_ptr", -2)}} + N = {{size("b_ptr", -1)}} + K = {{size("a_ptr", -1)}} + + A_STRIDE_M = {{stride("a_ptr", -2)}} + A_STRIDE_K = {{stride("a_ptr", -1)}} +{%- if not A_IS_2D %} + A_STRIDE_G = {{stride("a_ptr", 0)}} +{%- if SCALED %} + SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}} +{%- endif %} +{%- endif %} + B_STRIDE_N = {{stride("b_ptr", -1)}} + B_STRIDE_K = {{stride("b_ptr", -2)}} +{%- if not B_IS_2D %} + B_STRIDE_G = {{stride("b_ptr", 0)}} +{%- if SCALED %} + SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}} +{%- endif %} +{%- endif %} + +{%- if USE_TMA_LOAD %} +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + a_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + a_desc = tl.make_tensor_descriptor( +{%- endif %} + a_ptr, +{%- if A_IS_2D %} + shape=[M, K], + # fixme: strides=[A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[BLOCK_M, BLOCK_K], +{%- else %} + shape=[G, M, K], + # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[1, BLOCK_M, BLOCK_K], +{%- endif %} + ) + +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + b_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + b_desc = tl.make_tensor_descriptor( +{%- endif %} + b_ptr, +{%- if B_IS_2D %} + shape=[N, K], + # fixme: strides=[B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[BLOCK_N, BLOCK_K], +{%- else %} + shape=[G, N, K], + # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[1, BLOCK_N, BLOCK_K], +{%- endif %} + ) +{%- endif %} + +{%- if M_IS_VARYING %} + m_end_offset = 0 +{%- endif %} +{%- if N_IS_VARYING %} + n_end_offset = 0 +{%- endif %} +{%- if K_IS_VARYING %} + k_end_offset = 0 +{%- endif %} + iterated_tiles = 0 + for g in tl.range(G): +{%- if M_IS_VARYING %} + # Move across groups + m_start_offset = m_end_offset + m_end_offset = tl.load(offsets_ptr + g) + m_size = m_end_offset - m_start_offset +{%- if SCALED %} + m_scale_start_offset = m_start_offset +{%- endif %} +{%- else %} + m_start_offset = 0 + m_size = M +{%- if SCALED %} + m_scale_start_offset = g * M +{%- endif %} +{%- endif %} + +{%- if N_IS_VARYING %} + # Move across groups + n_start_offset = n_end_offset + n_end_offset = tl.load(offsets_ptr + g) + n_size = n_end_offset - n_start_offset +{%- if SCALED %} + n_scale_start_offset = n_start_offset +{%- endif %} +{%- else %} + n_start_offset = 0 + n_size = N +{%- if SCALED %} + n_scale_start_offset = g * N +{%- endif %} +{%- endif %} + + if m_size > 0 and n_size > 0: +{%- if K_IS_VARYING %} + # Move across groups + k_start_offset = k_end_offset + k_end_offset = tl.load(offsets_ptr + g) + k_size = k_end_offset - k_start_offset +{%- else %} + k_start_offset = 0 + k_size = K +{%- endif %} + + num_m_tiles = tl.cdiv(m_size, BLOCK_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_N) + num_tiles = num_m_tiles * num_n_tiles + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. + tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{%- if USE_TMA_LOAD %} + m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32) + n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32) + + for k_offset in range(0, k_size, BLOCK_K): +{%- if A_IS_2D %} + a = a_desc.load([m_offset, k_start_offset + k_offset]) +{%- else %} + a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K) +{%- endif %} +{%- if B_IS_2D %} + b = b_desc.load([n_offset, k_start_offset + k_offset]) +{%- else %} + b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K) +{%- endif %} + +{%- if K_IS_VARYING %} + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- endif %} + +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} +{%- else %} + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = k_start_offset + tl.arange(0, BLOCK_K) + a_ptrs = ( + a_ptr +{%- if not A_IS_2D %} + + g * A_STRIDE_G +{%- endif %} + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr +{%- if not B_IS_2D %} + + g * B_STRIDE_G +{%- endif %} + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) + for k_offset in range(0, k_size, BLOCK_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} + a_ptrs += BLOCK_K + b_ptrs += BLOCK_K +{%- endif %} + + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) +{%- if SCALED %} + scale_a = tl.load( + scale_a_ptr +{%- if A_IS_2D %} + + m_scale_start_offset +{%- else %} + + g * SCALE_A_STRIDE_G +{%- endif %} + + offs_am[:, None], + mask=offs_am[:, None] < m_size, + ) + scale_b = tl.load( + scale_b_ptr +{%- if B_IS_2D %} + + n_scale_start_offset +{%- else %} + + g * SCALE_B_STRIDE_G +{%- endif %} + + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + ) + c = accumulator.to(tl.float32) * scale_a * scale_b +{%- else %} + c = accumulator.to(tl.float32) +{%- endif %} + +{%- if M_IS_VARYING %} + idx_m = (m_start_offset + offs_am[:, None]) +{%- else %} + idx_m = offs_am[:, None] +{%- endif %} +{%- if N_IS_VARYING %} + idx_n = (n_start_offset + offs_bn[None, :]) +{%- else %} + idx_n = offs_bn[None, :] +{%- endif %} + mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size +{%- if M_IS_VARYING or N_IS_VARYING %} + {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- else %} + {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- endif %} + tidx += NUM_SMS + + iterated_tiles += num_tiles +""" + + +triton_grouped_mm_template = TritonTemplate( + name="grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +triton_scaled_grouped_mm_template = TritonTemplate( + name="scaled_grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + + +def grouped_mm_args( + mat1: TensorBox, + mat2: TensorBox, + offs: Optional[TensorBox], + layout=None, + out_dtype=None, +): + mat1, mat2 = realize_inputs(mat1, mat2) + if offs is not None: + realize_inputs(offs) + mat1_size = mat1.get_size() + mat2_size = mat2.get_size() + + m1dim, m2dim = len(mat1_size), len(mat2_size) + + assert m1dim == 2 or m1dim == 3 + assert m2dim == 2 or m2dim == 3 + + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + dims = [] + if m1dim == 2: + if m2dim == 2: + assert offs is not None + dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + else: + dims = [mat1_size[0], mat2_size[-1]] + else: + if m2dim == 2: + dims = [mat1_size[1], mat2_size[1]] + else: + dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + layout = FixedLayout( + mat1.get_device(), + out_dtype, + dims, + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + return (mat1_size, mat2_size, layout, mat1, mat2, offs) + + +aten__grouped_mm = ExternKernelChoice( + torch._grouped_mm, + "at::_grouped_mm", + op_overload=aten._grouped_mm, + has_out_variant=False, +) + + +aten__scaled_grouped_mm = ExternKernelChoice( + torch._scaled_grouped_mm, + "at::_scaled_grouped_mm", + op_overload=aten._scaled_grouped_mm, + has_out_variant=False, +) + + +def can_use_triton_kernel( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox], + bias: Optional[TensorBox], + scale_result: Optional[TensorBox], +) -> bool: + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + return False + if not has_triton(): + return False + + # The _grouped_mm()/_scaled_grouped_mm() operator do not support + # bias nor scale_result yet. + if bias is not None: + return False + if scale_result is not None: + return False + + if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2: + return offs is not None + else: + return offs is None + + +def create_offsets(x, m1_size, m2_size, offs_size): + m1_is_2d = len(m1_size) == 2 + m2_is_2d = len(m2_size) == 2 + if m1_is_2d: + if m2_is_2d: + k = V.graph.sizevars.size_hint(m1_size[1]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = k / noffs + return torch.linspace( + step, k, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + + else: + m = V.graph.sizevars.size_hint(m1_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = m / noffs + return torch.linspace( + step, m, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + if m2_is_2d: + n = V.graph.sizevars.size_hint(m2_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = n / noffs + return torch.linspace( + step, n, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + return None + + +def _tuned_grouped_mm_common( + operator_name: str, + algorithm_name: str, + extern_kernel_choice: ExternKernelChoice, + kernel_template: TritonTemplate, + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: Optional[TensorBox] = None, + scale_b: Optional[TensorBox] = None, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: Optional[bool] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + assert (scale_a is None) == (scale_b is None) + assert scale_result is None or scale_a is not None + + m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( + mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype + ) + counters["aten_mm_info"][operator_name] += 1 + log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s" + log.info( + log_message, + m1_size, + m2_size, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + if scale_a is not None and scale_b is not None: + check_supported_striding(mat_a, mat_b) + + # workaround for Inductor not supporting optional tensor input arguments + input_nodes: list[Any] = [mat_a, mat_b] + if scale_a is not None: + input_nodes.append(realize_inputs(scale_a)) + if scale_b is not None: + input_nodes.append(realize_inputs(scale_b)) + if offs is not None: + input_nodes.append(realize_inputs(offs)) + + if use_fast_accum is None: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + ) + else: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + if use_fast_accum is None: + use_fast_accum = False + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + # Checking only for the equality of corresponding dims of + # multiplicands here, relying on meta function checks for + # everything else. + if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + + triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor") + triton_has_experimental_make_tensor_descriptor = hasattr( + tl, "_experimental_make_tensor_descriptor" + ) + use_tma_load = ( + triton_has_make_tensor_descriptor + or triton_has_experimental_make_tensor_descriptor + ) + # The make_tensor_descriptor imposes this additional limitation. + use_tma_load = use_tma_load and ( + mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1 + ) + + kwargs = { + "SCALED": scaled, + "A_IS_2D": a_is_2d, + "B_IS_2D": b_is_2d, + "USE_FAST_ACCUM": use_fast_accum, + "NUM_SMS": get_num_sms(), + "USE_TMA_LOAD": use_tma_load, + "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor, + } + + for config in early_config_prune(g, m, grouped_mm_configs(), kwargs): + kernel_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + num_stages=config.num_stages, + num_warps=config.num_warps, + **kwargs, + **config.kwargs, + ) + + input_gen_fns = { + 4: lambda x: create_offsets( + x, m1_size, m2_size, offs.get_size() if offs is not None else None + ), + } + return autotune_select_algorithm( + algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns + ) + + +@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +def tuned_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._grouped_mm.default", + "grouped_mm", + aten__grouped_mm, + triton_grouped_mm_template, + mat_a, + mat_b, + None, + None, + offs, + bias, + None, + out_dtype, + None, + layout, + ) + + +@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) +def tuned_scaled_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _scaled_grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._scaled_grouped_mm.default", + "scaled_grouped_mm", + aten__scaled_grouped_mm, + triton_scaled_grouped_mm_template, + mat_a, + mat_b, + scale_a, + scale_b, + offs, + bias, + scale_result, + out_dtype, + use_fast_accum, + layout, + ) diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index 5447c27f4f07a0..64249e6fb57a9e 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -111,6 +111,7 @@ # inductor generates a suffix {{store_output(("idx_m", "idx_n"), "acc", "mask")}} """, + cache_codegen_enabled_for_template=True, ) @@ -148,7 +149,6 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): ) mm_configs = V.choices.get_mm_plus_mm_configs(device_type) - if use_triton_template(layout1): for config in mm_configs(): # see https://github.com/triton-lang/triton/issues/1298 diff --git a/torch/_inductor/kernel/mm_scaled_grouped.py b/torch/_inductor/kernel/mm_scaled_grouped.py deleted file mode 100644 index 7e86516f865da6..00000000000000 --- a/torch/_inductor/kernel/mm_scaled_grouped.py +++ /dev/null @@ -1,479 +0,0 @@ -# mypy: allow-untyped-defs -import logging -from dataclasses import dataclass -from typing import Any, Optional - -import torch -from torch._dynamo.utils import counters -from torch._inductor.virtualized import V -from torch.utils._triton import has_triton_tma_device - -from ..ir import ChoiceCaller, Layout, TensorBox -from ..lowering import register_lowering -from ..runtime.runtime_utils import next_power_of_2 -from ..select_algorithm import ( - autotune_select_algorithm, - ExternKernelChoice, - realize_inputs, - TritonTemplate, -) -from ..utils import ( - get_gpu_shared_memory, - get_num_sms, - get_tma_workspace_arg, - use_aten_gemm_kernels, -) -from .mm_common import ( - _is_static_problem, - check_supported_striding, - persistent_grouped_mm_grid, -) - - -log = logging.getLogger(__name__) -aten = torch.ops.aten - - -@dataclass -class Config: - kwargs: dict[str, int] - num_stages: int - num_warps: int - - -_NV_CONFIGS = [ - Config( - { - "BLOCK_M": block_size_m, - "BLOCK_N": block_size_n, - "BLOCK_K": block_size_k, - "NUM_CONSUMER_GROUPS": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - for block_size_m in [64, 128] - for block_size_n in [64, 128, 256] - for block_size_k in [64, 128, 256] - for num_stages in [3, 4] - for num_warps in [4, 8] -] - -_AMD_CONFIGS = [ - Config( - { - "BLOCK_M": block_size_m, - "BLOCK_N": block_size_n, - "BLOCK_K": block_size_k, - "waves_per_eu": waves_per_cu, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - "NUM_CONSUMER_GROUPS": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - for block_size_m in [32, 64, 128] - for block_size_n in [32, 64, 128, 256] - for block_size_k in [128, 256] - for num_stages in [1, 2] - for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)] - for matrix_instr_nonkdim in [16] -] - - -def scaled_grouped_mm_configs(): - return _AMD_CONFIGS if torch.version.hip else _NV_CONFIGS - - -def early_config_prune(configs, named_args): - dtsize = 1 - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( - kw["BLOCK_M"], - kw["BLOCK_N"], - kw["BLOCK_K"], - config.num_stages, - config.num_warps, - getattr(config, "num_consumer_groups", 0), - ) - G, M, N, K = ( - named_args["G"], - named_args["M_BUCKET"], - named_args["N"], - named_args["K"], - ) - - # 1. make sure we have enough smem - max_shared_memory = get_gpu_shared_memory() - - if torch.version.hip: - required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize - else: - required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize - if required_shared_memory > max_shared_memory: - continue - - use_warp_specialization = num_consumer_groups >= 1 - - M_PER_GROUP = M // G - MIN_M_TILES = 32 if torch.version.hip else 64 - # 2. make sure we don't load M tiles that are too big - if ( - not use_warp_specialization - and BLOCK_M > MIN_M_TILES - and BLOCK_M > (M_PER_GROUP * 2) - ): - continue - # 3. make sure we don't load N tiles that are too small - if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): - continue - - num_sm = get_num_sms() - - N_TILES = N // BLOCK_N - MIN_N_TILES = 32 if torch.version.hip else 64 - # 4. make sure we don't load N tiles that are too big - if ( - not use_warp_specialization - and BLOCK_N > MIN_N_TILES - and M * N_TILES < num_sm - ): - continue - # 5. make sure we don't load N tiles that are too small - if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: - continue - - # 6. make sure K can be evenly divided - if K % BLOCK_K != 0: - continue - - # 7. make sure we can partition for ws - if use_warp_specialization: - if num_warps != 4: - continue - - # "tritongpu-warp-spec-data-partition" - m_slice = BLOCK_M // num_consumer_groups - n_slice = BLOCK_N // num_consumer_groups - if m_slice < 64 and n_slice < 256: - continue - - pruned_configs.append(config) - - return pruned_configs - - -# Copied from fbgemm grouped_gemm.py -triton_scaled_grouped_mm_source = r""" -{{def_kernel("a_ptr", "b_ptr", "a_scale_ptr", "b_scale_ptr", "m_sizes")}} - tidx = tl.program_id(0) - - dtype = tl.float8e4nv - TMA_SIZE: tl.constexpr = tl.constexpr(128) - - workspace_base = ws_ptr + tidx * 2 * TMA_SIZE - c_desc_ptr = None - - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=a_ptr, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=a_ptr.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=b_ptr, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N * G, K], - element_ty=b_ptr.dtype.element_ty, - ) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - M_end_offset = 0 - iterated_tiles = 0 - for g in tl.range(G): - # Move across groups - M_start_offset = M_end_offset - M_end_offset = tl.load(m_sizes + g) - m_size = M_end_offset - M_start_offset - - if m_size > 0: - N_start_offset = g.to(tl.int64) * N - n_size = N - num_m_tiles = tl.cdiv(m_size, BLOCK_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_N) - num_tiles = num_m_tiles * num_n_tiles - - # Move across tiles - while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: - gidx = tidx - iterated_tiles - # Split M first and N second. - tile_m_idx = gidx % num_m_tiles - tile_n_idx = gidx // num_m_tiles - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - tl.static_assert(K % BLOCK_K == 0) - if USE_TMA_LOAD: - m_offset = (M_start_offset + tile_m_idx * BLOCK_M).to(tl.int32) - n_offset = (N_start_offset + tile_n_idx * BLOCK_N).to(tl.int32) - for k_offset in range(0, K, BLOCK_K): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_M, BLOCK_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_N, BLOCK_K], - dtype, - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - else: - offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) - offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_K) - a_ptrs = ( - a_desc_ptr - + (M_start_offset + offs_am[:, None]) * K - + offs_k[None, :] - ) - b_ptrs = ( - b_desc_ptr - + (N_start_offset + offs_bn[:, None]) * K - + offs_k[None, :] - ) - for k_offset in range(0, K, BLOCK_K): - a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) - b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) - accumulator += tl.dot(a, b.T) - a_ptrs += BLOCK_K - b_ptrs += BLOCK_K - - offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) - offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - ) - b_scale = tl.load( - b_scale_ptr + N_start_offset + offs_bn[None, :], - mask=offs_bn[None, :] < n_size, - ) - c = accumulator.to(tl.float32) * a_scale * b_scale - - idx_m = (M_start_offset + offs_am[:, None]) - idx_n = offs_bn[None, :] - mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size - {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} - tidx += NUM_SMS - - iterated_tiles += num_tiles -""" - - -triton_scaled_grouped_mm_template = TritonTemplate( - name="scaled_grouped_mm", - grid=persistent_grouped_mm_grid, - source=triton_scaled_grouped_mm_source, -) - - -def grouped_mm_args( - mat1: TensorBox, - mat2: TensorBox, - offs: Optional[TensorBox], - layout=None, - out_dtype=None, -): - mat1, mat2, offs = realize_inputs(mat1, mat2, offs) - mat1_size = mat1.get_size() - mat2_size = mat2.get_size() - - m1dim, m2dim = len(mat1_size), len(mat2_size) - - assert m1dim == 2 or m1dim == 3 - assert m2dim == 2 or m2dim == 3 - - if layout is None: - from torch._inductor.ir import FixedLayout - - if out_dtype is None: - out_dtype = mat1.get_dtype() - - dims = [] - if m1dim == 2: - if m2dim == 2: - assert offs is not None - dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] - else: - dims = [mat1_size[0], mat2_size[-1]] - else: - if m2dim == 2: - dims = [mat1_size[1], mat2_size[1]] - else: - dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] - layout = FixedLayout( - mat1.get_device(), - out_dtype, - dims, - ) - else: - assert out_dtype is None, "out_dtype is ignored if layout is specified." - - return (mat1_size, mat2_size, layout, mat1, mat2, offs) - - -aten__scaled_grouped_mm = ExternKernelChoice( - torch._scaled_grouped_mm, - "at::_scaled_grouped_mm", - op_overload=aten._scaled_grouped_mm, - has_out_variant=False, -) - - -def can_use_triton_kernel( - mat_a: TensorBox, - mat_b: TensorBox, - offs: Optional[TensorBox], - bias: Optional[TensorBox], -) -> bool: - a_shape = mat_a.get_size() - b_shape = mat_b.get_size() - a_stride = mat_a.get_stride() - b_stride = mat_b.get_stride() - - # A must be contiguous 2d - a_layout_ok = ( - len(a_shape) == 2 - and a_stride[1] == 1 - and a_stride[0] == a_shape[1] - and a_shape[1] >= 32 - ) - - # B must be contiguous 3d with transposed last dimension - b_layout_ok = ( - len(b_shape) == 3 - and b_stride[2] == b_shape[1] - and b_stride[1] == 1 - and b_stride[0] == (b_shape[1] * b_shape[2]) - and b_shape[1] >= 32 - ) - - return ( - offs is not None - and bias is None - and has_triton_tma_device() - and a_layout_ok - and b_layout_ok - ) - - -def create_offsets(x, m1_size, m2_size, offs_size): - assert len(m1_size) == 2 and len(m2_size) == 3, ( - "Autotuning _scaled_grouped_mm is only implemented for 2d-3d tensors" - ) - m = V.graph.sizevars.size_hint(m1_size[0]) - noffs = V.graph.sizevars.size_hint(offs_size[0]) - step = m / noffs - return torch.linspace(step, m, noffs, dtype=x.get_dtype(), device=x.get_device()) - - -@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) -def tuned_scaled_grouped_mm( - mat_a: TensorBox, - mat_b: TensorBox, - scale_a: TensorBox, - scale_b: TensorBox, - offs: Optional[TensorBox] = None, - bias: Optional[TensorBox] = None, - scale_result: Optional[TensorBox] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - layout: Optional[Layout] = None, -) -> TensorBox: - m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( - mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype - ) - - counters["aten_mm_info"]["aten._scaled_grouped_mm.default"] += 1 - log.info( - "Tuned aten._scaled_grouped_mm.default: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", - m1_size, - m2_size, - mat_a.get_dtype(), - mat_b.get_dtype(), - layout, - ) - - check_supported_striding(mat_a, mat_b) - - scale_a, scale_b = realize_inputs(scale_a, scale_b) - - # workaround for Inductor not supporting optional tensor input arguments - input_nodes: list[Any] = [mat_a, mat_b, scale_a, scale_b] - if offs is not None: - input_nodes.append(realize_inputs(offs)) - if bias is not None: - input_nodes.append(realize_inputs(bias)) - - aten_choice = aten__scaled_grouped_mm.bind( - input_nodes, - layout, - out_dtype=out_dtype, - use_fast_accum=use_fast_accum, - ) - - choices: list[ChoiceCaller] = [] - if use_aten_gemm_kernels(): - choices.append(aten_choice) - - _, is_nonzero = _is_static_problem(layout) - - if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias): - m, k1 = m1_size - g, k2, n = m2_size - k = V.graph.sizevars.guard_equals(k1, k2) - kwargs = { - "G": g, - "M": m, - "M_BUCKET": next_power_of_2(m), - "N": n, - "K": k, - "NUM_SMS": get_num_sms(), - "USE_TMA_LOAD": True, - "USE_TMA_STORE": False, - "USE_FAST_ACCUM": use_fast_accum, - } - for config in early_config_prune(scaled_grouped_mm_configs(), kwargs): - triton_scaled_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - workspace_arg=get_tma_workspace_arg( - num_tma_descriptors=2, - device=mat_a.get_device(), - ), - num_stages=config.num_stages, - num_warps=config.num_warps, - **kwargs, - **config.kwargs, - ) - - input_gen_fns = { - 4: lambda x: create_offsets(x, m1_size, m2_size, offs.get_size()), - } - return autotune_select_algorithm( - "scaled_grouped_mm", choices, input_nodes, layout, input_gen_fns=input_gen_fns - ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index fa37748ce81708..ffcf431c0cb307 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -36,7 +36,7 @@ class InterpreterShim(torch.fx.Interpreter): @staticmethod - @functools.lru_cache(None) + @functools.cache def _dummy_gm(): return torch.fx.symbolic_trace(identity) @@ -312,6 +312,14 @@ def get_read_exprs(self): for entry in self.memory_usage[MemoryUsageType.LOAD] ] + def get_all_read_expr(self, buffer_name): + # reversed to match old behavior + out = [] + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return out + def get_write_exprs(self): return [ self.indexing_exprs[entry.index_name] @@ -321,6 +329,16 @@ def get_write_exprs(self): ) ] + def get_all_write_expr(self, buffer_name): + out = [] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return out + def debug_str(self): lines = [f"var_ranges = {dict(self.var_ranges)}"] lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index c0fd365ebdde4c..2d3fec708cba97 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -48,16 +48,19 @@ from . import config, inductor_prims, ir, test_operators # NOQA: F401 from .decomposition import decompositions, get_decompositions from .ir import ( + BaseView, DtypeView, ExpandView, IndexingConstant, IRNode, is_triton, + MutableBox, OnlineSoftmaxReduction, ops_wrapper, PermuteView, Pointwise, Reduction, + ShapeAsConstantBuffer, SqueezeView, TensorBox, validate_ir, @@ -164,7 +167,7 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A def tag_to_layout_constraint(tag): if tag == torch._C.Tag.needs_exact_strides: return constrain_to_fake_tensors - if tag == torch._C.Tag.needs_contiguous_strides: + if tag == torch._C.Tag.needs_contiguous_strides: # type: ignore[attr-defined] return require_contiguous_strides if tag == torch._C.Tag.needs_fixed_stride_order: return constrain_to_fx_strides @@ -494,7 +497,7 @@ def broadcast_symbolic_shapes(a, b): ): output.append(y) else: - V.graph.sizevars.guard_equals(x, y) + V.graph.sizevars.check_equals(x, y) if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): output.append(y) # prefer shorter formula else: @@ -697,7 +700,9 @@ def inner(*inputs: list[list[TensorBox]], alpha=1): return inner -def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): +def to_dtype( + x: Union[TensorBox, ShapeAsConstantBuffer], dtype: torch.dtype, copy: bool = False +): src_dtype = x.get_dtype() if src_dtype == dtype: return clone(x) if copy else x @@ -976,9 +981,9 @@ def squeeze(x, dim=None): return TensorBox(SqueezeView.create(x.data)) dim = ( - V.graph.sizevars.evaluate_static_shape(dim) + V.graph.sizevars.guard_int(dim) if isinstance(dim, (int, sympy.Expr)) - else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + else tuple(V.graph.sizevars.guard_int(d) for d in dim) ) dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim) @@ -1069,7 +1074,9 @@ def expand(x, sizes): return x if not free_unbacked_symbols(x.get_size()): - x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + x_size_product = V.graph.sizevars.size_hint_or_throw( + sympy_product(x.get_size()) + ) # TODO: It would be better to realize the input if any of its sizes # are unbacked, because typically the size will be non-zero. However, # this cannot be done directly as below as we'll choke on the size_hint @@ -1077,7 +1084,8 @@ def expand(x, sizes): if x_size_product > 0 and not free_unbacked_symbols(sizes): # maybe realize input before broadcasting it x.mark_reuse( - V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + V.graph.sizevars.size_hint_or_throw(sympy_product(sizes)) + // x_size_product ) return TensorBox(ExpandView.create(x.data, tuple(sizes))) @@ -1136,12 +1144,13 @@ def inner_fn(index): return x_loader(index) if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size): - old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size)) if old_size_product > 0: # maybe realize the input but skip for unbacked symints since it'll # choke on the size hint. x.mark_reuse( - V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + V.graph.sizevars.size_hint_or_throw(sympy_product(new_size)) + // old_size_product ) x_loader = x.make_loader() @@ -1156,9 +1165,7 @@ def inner_fn(index): @register_lowering(aten._unsafe_view, type_promotion_kind=None) @register_lowering(aten.view, type_promotion_kind=None) @register_lowering(aten.reshape, type_promotion_kind=None) -def view(x, sizes): - assert isinstance(x, TensorBox) - assert isinstance(sizes, (list, tuple)) +def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox: return TensorBox(View.create(x.data, sizes)) @@ -1286,7 +1293,7 @@ def quantized_decomposed_quantize_per_channel( quant_min: int, quant_max: int, dtype: torch.dtype, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" @@ -1341,7 +1348,7 @@ def quantized_decomposed_dequantize_per_channel( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" assert input.get_dtype() == dtype, ( @@ -1391,7 +1398,7 @@ def quantized_decomposed_quantize_per_tensor_default( quant_min: int, quant_max: int, dtype: torch.dtype, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) assert input.get_dtype() == torch.float32, ( @@ -1432,7 +1439,7 @@ def quantized_decomposed_dequantize_per_tensor_default( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: assert input.get_dtype() == dtype, ( f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" ) @@ -1469,7 +1476,7 @@ def quantized_decomposed_quantize_per_tensor_tensor( quant_min: int, quant_max: int, dtype: torch.dtype, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) assert input.get_dtype() == torch.float32, ( @@ -1519,7 +1526,7 @@ def quantized_decomposed_dequantize_per_tensor_tensor( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: assert len(scale.get_size()) == 0 or ( len(scale.get_size()) == 1 and scale.get_size()[0] == 1 ), "expect scale as scalar tensor" @@ -1765,9 +1772,7 @@ def split(x, sizes, dim=0): # by computing what the actual size of each chunk should be. if not isinstance(sizes, (list, tuple)): x_size = x.get_size()[dim] - chunks = V.graph.sizevars.evaluate_static_shape( - FloorDiv(x_size + sizes - 1, sizes) - ) + chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes)) sizes_ = [sizes] * chunks # The last chunk might have a smaller size than the rest. sizes_[-1] = x_size - (chunks - 1) * sizes @@ -1793,7 +1798,7 @@ def split_with_sizes(x, sizes, dim=0): @register_lowering(aten.unbind, type_promotion_kind=None) def unbind(x, dim=0): dim = _validate_dim(x, dim, 0) - x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + x_size = V.graph.sizevars.guard_int(x.get_size()[dim]) result = [select(x, dim, i) for i in range(x_size)] return result @@ -1809,12 +1814,14 @@ def unfold(x, dimension, size, step): dim_size = sizes[dim] sizevars = V.graph.sizevars - sizevars.guard_leq(size, dim_size) - sizevars.guard_lt(0, step) # type: ignore[arg-type] + sizevars.check_leq(size, dim_size) + sizevars.check_lt(0, step) # type: ignore[arg-type] new_dim_size = FloorDiv(dim_size - size, step) + 1 - if sizevars.size_hint(dim_size) > 0: - x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + if sizevars.size_hint_or_throw(dim_size) > 0: + x.mark_reuse( + sizevars.size_hint_or_throw(CeilDiv(new_dim_size * size, dim_size)) + ) out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] @@ -1855,7 +1862,7 @@ def _validate_dim(x, dim, offset=0): def glu(x, dim=-1): dim = _validate_dim(x, dim, 0) # TODO: don't guard on static shape here - new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2 a = slice_(x, dim, 0, new_len) b = slice_(x, dim, new_len, new_len * 2) return mul(a, sigmoid(b)) @@ -1879,7 +1886,7 @@ def wrap_tensors(x): return handler -@functools.lru_cache(None) +@functools.cache def _warn_complex_not_supported(): warnings.warn( "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." @@ -2259,7 +2266,7 @@ def searchsorted( right: bool = False, side: Optional[str] = None, sorter: Optional[TensorBox] = None, -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731 tb, BackendFeature.BUCKETIZE ) @@ -2439,28 +2446,29 @@ def require_channels_last(_, *args, **kwargs): return args, kwargs -def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs): - def apply_constraint(arg, fake_arg): - if isinstance(arg, ir.IRNode): - meta_stride_expr = [ - s.node.expr if isinstance(s, torch.SymInt) else s - for s in fake_arg.stride() - ] - return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr) - if isinstance(arg, dict): - return { - key: apply_constraint(arg[key], fake_arg[key]) for key in arg.keys() - } - elif isinstance(arg, (tuple, list)): - return type(arg)( - apply_constraint(a, f_a) for (a, f_a) in zip(arg, fake_arg) - ) - return arg +def constrain_to_fake_tensor(arg, fake_arg): + if isinstance(arg, ir.IRNode): + meta_stride_expr = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride() + ] + return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr) + if isinstance(arg, dict): + return { + key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg.keys() + } + elif isinstance(arg, (tuple, list)): + return type(arg)( + constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg) + ) + return arg + +def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs): args = tuple( - apply_constraint(arg, fake_arg) for arg, fake_arg in zip(args, fake_args) + constrain_to_fake_tensor(arg, fake_arg) + for arg, fake_arg in zip(args, fake_args) ) - kwargs = {k: apply_constraint(v, fake_kwargs[k]) for k, v in kwargs.items()} + kwargs = {k: constrain_to_fake_tensor(v, fake_kwargs[k]) for k, v in kwargs.items()} return args, kwargs @@ -2659,6 +2667,8 @@ def is_aligned(x): make_fallback(aten.addbmm) make_fallback(aten._addmm_activation, warn=False) +make_fallback(aten._grouped_mm, require_dense) + # Need templated kernel. Probably impossible to write efficiently make_fallback(aten.convolution_backward, constrain_to_fx_strides) make_fallback(aten._cudnn_rnn, require_dense) @@ -2744,7 +2754,7 @@ def is_aligned(x): make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state) -# Implmented / Half implemented +# Implemented / Half implemented # Scans. Implemented for CUDA, missing CPU make_fallback(aten.masked_scatter) make_fallback(aten.masked_scatter_backward) @@ -2903,8 +2913,8 @@ def select_scatter(x, src, dim: int, index: int): dim = _validate_dim(x, dim, 0) if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): index = index + x.get_size()[dim] - V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] - V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type] src = expand(unsqueeze(src, dim), x.get_size()) src_loader = src.make_loader() @@ -3334,7 +3344,7 @@ def new_empty_strided( @register_lowering(prims.copy_strided.default) def copy_strided(x, stride): - stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride = [V.graph.sizevars.size_hint_or_throw(s) for s in stride] stride_order = sorted(range(len(stride)), key=stride.__getitem__) return ir.ExternKernel.require_stride_order(x, stride_order) @@ -3611,6 +3621,7 @@ def index_put_as_masked_fill(self, indices, value, accumulate): def index_put_fallback(self, indices, values, accumulate): + assert isinstance(V.graph.current_node.target, torch._ops.OpOverload) ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) return self @@ -3734,8 +3745,10 @@ def indice_slice_from_randperm(indice): values = expand(values, expected_vals_size) # all guards are set above during broadcast_tensors and expand + device = self.get_device() + assert device is not None scatter = ir.Scatter( - device=self.get_device(), + device=device, dtype=self.get_dtype(), inner_fn=values.make_loader(), ranges=expected_vals_size, # iter_ranges, @@ -3955,10 +3968,13 @@ def backend_reduce_str(reduce): assert reduce is None return None + device = self.get_device() + assert device is not None + if not include_self: # zero out the corresponding elements first zero_out = ir.Scatter( - device=self.get_device(), + device=device, dtype=self.get_dtype(), inner_fn=lambda index: ops.constant(0, self.get_dtype()), ranges=index.get_size(), @@ -3977,7 +3993,7 @@ def backend_reduce_str(reduce): # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 scatter = ir.Scatter( - device=self.get_device(), + device=device, dtype=self.get_dtype(), inner_fn=fn, ranges=index.get_size(), @@ -4008,7 +4024,7 @@ def upsample_nearestnd( x_loader = x.make_loader() i_sizes = x.get_size()[-n:] batch = x.get_size()[:-n] - i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes] assert len(scales_x) == n o_sizes = output_size @@ -4340,10 +4356,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: # Sliding windows must start within the input or left padding x_alt -= 1 # type: ignore[assignment] - V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] if V.graph.sizevars.size_hint(x_out - x_alt) == 0: # ceil mode is actually a no-op, lets guard on that - V.graph.sizevars.guard_equals(x_out, x_alt) + V.graph.sizevars.check_equals(x_out, x_alt) ceil_mode = False else: x_out = x_alt @@ -4450,10 +4466,10 @@ def fn_inner(idx, reduction_idx): ranges=new_size, reduction_ranges=kernel_size, ) - if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined, union-attr] # Only realize if reduction isn't unrolled result.realize() - if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined, union-attr] # Only realize if reduction isn't unrolled offsets.realize() @@ -4503,7 +4519,7 @@ def _pool_offsets_to_indices( [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]], torch._inductor.virtualized.OpsValue, ], -) -> TensorBox: +) -> Union[TensorBox, ShapeAsConstantBuffer]: n_dim = len(kernel_size) offsets_loader = offsets.make_loader() window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size)) @@ -4636,10 +4652,12 @@ def max_pool2d_with_indices_backward( x_stride: Optional[Sequence[Any]] if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] data = x.data.data # type: ignore[attr-defined] + device = data.get_device() + assert device is not None x_buffer = ir.ComputedBuffer( name=None, layout=ir.FlexibleLayout( - device=data.get_device(), + device=device, dtype=data.get_dtype(), size=data.get_size(), ), @@ -4904,8 +4922,8 @@ def _adaptive_avg_pool2d(x, output_size): *batch, h_in, w_in = x.get_size() - h_in = V.graph.sizevars.evaluate_static_shape(h_in) - w_in = V.graph.sizevars.evaluate_static_shape(w_in) + h_in = V.graph.sizevars.guard_int(h_in) + w_in = V.graph.sizevars.guard_int(w_in) h_out, w_out = output_size @@ -4979,8 +4997,8 @@ def adaptive_max_pool2d(x, output_size): *batch, h_in, w_in = x.get_size() - h_in = V.graph.sizevars.evaluate_static_shape(h_in) - w_in = V.graph.sizevars.evaluate_static_shape(w_in) + h_in = V.graph.sizevars.guard_int(h_in) + w_in = V.graph.sizevars.guard_int(w_in) h_out, w_out = output_size @@ -5057,7 +5075,28 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims): samples_loader = samples.make_loader() def load(prefix, i): - sample = samples_loader([*prefix, ndims - 1 - dim]) + # Handle indexing for samples tensor correctly for different input dimensions + # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where: + # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W) + # - C=num_channels + # - 2 for the two spatial dimensions (height, width) + samples_shape = samples.get_size() + + if len(samples_shape) == 3: # Expected: (N, C, 2) + if len(prefix) == 1: + # 3D input case: prefix=(channel,), samples=(1, C, 2) + # Access: samples[0, channel, dim] + sample = samples_loader([0, prefix[0], ndims - 1 - dim]) + elif len(prefix) >= 2: + # 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2) + # Access: samples[batch, channel, dim] + sample = samples_loader([prefix[0], prefix[1], ndims - 1 - dim]) + else: + # Edge case - shouldn't happen for valid fractional pooling + sample = samples_loader([0, 0, ndims - 1 - dim]) + else: + # Fallback for unexpected tensor shapes + sample = samples_loader([*prefix, ndims - 1 - dim]) i_expr = ops.index_expr(i, samples.get_dtype()) diff = ops.index_expr(in_sz - kernel_sz, torch.int64) out_sz_expr = ops.index_expr(out_sz - 1, torch.int64) @@ -5136,9 +5175,11 @@ def increments_to_index(idx, reduction_idx): ranges=new_size, reduction_ranges=kernel_size, ) + assert isinstance(result, TensorBox), result if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] # Only realize if reduction isn't unrolled result.realize() + assert isinstance(offsets, TensorBox), offsets if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] # Only realize if reduction isn't unrolled offsets.realize() @@ -5156,8 +5197,8 @@ def upsample_nearest2d_backward( x.realize_hint() *_batch, inp_h, inp_w = x.get_size() - inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) - inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + inp_h = V.graph.sizevars.guard_int(inp_h) + inp_w = V.graph.sizevars.guard_int(inp_w) *_batch, out_h, out_w = input_size @@ -5823,7 +5864,7 @@ def inner(x, axis=None, keepdims=False, *, dtype=None): ) result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) if isinstance( - result.data.data, # type: ignore[attr-defined] + result.data.data, # type: ignore[attr-defined, attr-type, union-attr] Reduction, ): # Only realize if reduction isn't unrolled result.realize() @@ -6069,12 +6110,14 @@ def mutate_to(changed, val, unsafe_alias=False): if not isinstance(val, ir.StorageBox): # introduce a copy to handle views - val = Pointwise.create( + node = Pointwise.create( device=changed.get_device(), dtype=changed.get_dtype(), inner_fn=val.make_loader(), ranges=changed.get_size(), - ).data + ) + assert isinstance(node, (BaseView, MutableBox)) + val = node.data assert isinstance(val, ir.StorageBox) if isinstance(changed_data, ir.StorageBox) and not ( @@ -6938,13 +6981,14 @@ def _map_output(out: Any): raise RuntimeError(f"NYI unsupported output type: {type(out)}") result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + assert isinstance(result, Sequence) return list(map(_map_output, result)) @register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): result = ir.InvokeSubgraph.create(subgraph_fn, *operands) - return list(map(TensorBox.create, result)) + return list(map(TensorBox.create, result)) # type: ignore[call-overload] @register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) @@ -7074,7 +7118,7 @@ def prepare_softmax_online(x, dim): # Note: [Split online_softmax_reduce] # We don't split reduction for online_softmax_reduce for now. # On one hand, supporting split reduction makes things complex since - # the splitted out reuctions requires 2 inputs rather than one. + # the split out reuctions requires 2 inputs rather than one. # On the other hand, during training the online_softmax_reduce should # usually don't requires a split due to large batch size # (more specifically batch size times sequence length). diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 4f179691645965..b086234769f8eb 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -10,7 +10,7 @@ from torch.utils._ordered_set import OrderedSet from .ir import MultiOutputLayout, NoneLayout -from .utils import get_dtype_size +from .utils import get_dtype_size, is_wait from .virtualized import V @@ -147,8 +147,23 @@ def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False ) -> int: if isinstance(sched_buf.node.layout, NoneLayout): - sched_buf_to_size[sched_buf.get_name()] = (0, 0) - return 0 + _size = 0 + # for a wait tensor op, its schedulerBuffer NoneLayout layout. However, + # the schedulerBuffer is treated as a mutation of the collective output + # so it needs to inherit the size of the collectives + if ( + sched_buf.defining_op + and is_wait(sched_buf.defining_op.node) + and sched_buf.get_mutations() + ): + mutated_buf_name = sched_buf.get_mutations()[0] + _size = ( + sched_buf_to_size[mutated_buf_name][1] + if mutated_buf_name in sched_buf_to_size + else 0 + ) + sched_buf_to_size[sched_buf.get_name()] = (_size, _size) + return _size elif isinstance(sched_buf.node.layout, MultiOutputLayout): size_alloc = 0 for user in sched_buf.users: diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 4892a2b5e3697a..116550be70e795 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -433,7 +433,7 @@ def enabled_metric_tables() -> OrderedSet[str]: @lru_cache def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]: - enabled = OrderedSet[str]() + enabled: OrderedSet[str] = OrderedSet() for name in config_str.split(","): name = name.strip() if not name: diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 1df926bb903f26..db63d880d971e8 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1,11 +1,11 @@ # mypy: allow-untyped-defs from collections.abc import Sequence -from typing import Any, Optional +from typing import Any, Optional, Union import sympy import torch -from torch._prims_common import make_channels_last_strides_for +from torch._prims_common import make_channels_last_strides_for, StrideType from torch.utils._ordered_set import OrderedSet from .ir import ( @@ -14,6 +14,7 @@ FlexibleLayout, get_device_type, ir_node_to_tensor, + IRNode, is_contiguous_storage_and_layout, Layout, may_convert_to_optional, @@ -21,9 +22,10 @@ MultiOutputLayout, MutationOutput, NoneLayout, + ShapeAsConstantBuffer, TensorBox, ) -from .utils import convert_shape_to_inductor, pad_listlike +from .utils import convert_shape_to_inductor, pad_listlike, SUPPORTED_MKLDNN_DEVICES from .virtualized import V @@ -164,7 +166,7 @@ def _original_deconv_weight_size( x = cls.require_stride_order(x, req_stride_order) - # We won't do weight prepack for Conv if dynamic_shapes. + # We won't do weight prepack for Conv if dynamic_shapes or if is xpu. # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order @@ -172,13 +174,23 @@ def _original_deconv_weight_size( # this tensor is considered as channels first and the output will be in contiguous format. # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) - if dynamic_shapes and is_contiguous_storage_and_layout(x): + if ( + dynamic_shapes or get_device_type(x) == "xpu" + ) and is_contiguous_storage_and_layout(x): + output_stride: StrideType = FlexibleLayout.contiguous_strides(output_size) + # Currently we don't support channel last for the situation that stride of input's batch dim is 0, + # eg. input_size = (1, 1280, 64, 64), but input_stride=(0, 1, 81920, 1280). + # So we use NCHW hear instead. + # Different with cpu, cpu conv always use channels_last for convolution when weight is prepacked, + # but xpu does not do the prepack, so the problem exposed here is only for xpu. + # TODO support channels_last for such zero stride input. + elif get_device_type(x) == "xpu" and x.get_stride()[0] == 0: output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) assert get_device_type(x) == get_device_type(weight) - assert get_device_type(x) in ["cpu", "xpu"] + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES inputs = [x] if quantize_args is not None: @@ -242,7 +254,7 @@ def _prepare_linear_fusion_create( x = cls.require_stride_order(x, req_stride_order) assert get_device_type(x) == get_device_type(weight) - assert get_device_type(x) in ["cpu", "xpu"] + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES inputs = [x] if quantize_args is not None: @@ -294,17 +306,20 @@ def __init__( inputs, constant_args=(), ) -> None: + self.device_type = get_device_type(inputs[0]) super().__init__( layout, inputs, constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise.default, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise", + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise", ) def codegen(self, wrapper): - wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) super().codegen(wrapper) @classmethod @@ -351,18 +366,21 @@ def __init__( constant_args=(), cpp_constant_args=(), ) -> None: + self.device_type = get_device_type(inputs[0]) super().__init__( layout, inputs, constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise.binary, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary", + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary", ) self.cpp_constant_args = cpp_constant_args def codegen(self, wrapper): - wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) super().codegen(wrapper) @classmethod @@ -416,6 +434,7 @@ def __init__( constant_args=(), ) -> None: # Due to constrain of op.call, other (Tensor&) should be at input[0] + self.device_type = get_device_type(inputs[0]) reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] super().__init__( @@ -424,7 +443,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary_", + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary_", ) self.mutation_outputs = [ @@ -433,7 +452,9 @@ def __init__( ] def codegen(self, wrapper): - wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) super().codegen(wrapper) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: @@ -589,8 +610,8 @@ def codegen(self, wrapper): def create( cls, qx: "TensorBox", - x_scale: "TensorBox", - x_zero_point: "TensorBox", + x_scale: Union["ShapeAsConstantBuffer", "TensorBox"], + x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"], qw: "TensorBox", # qw w_scale: "TensorBox", w_zero_point: "TensorBox", @@ -625,7 +646,7 @@ def create( groups, transposed, output_padding, - [x_scale, x_zero_point, w_scale, w_zero_point], + [x_scale, x_zero_point, w_scale, w_zero_point], # type: ignore[list-item] ) # swap padding and stride to align with functional conv arg order if bias is None: @@ -667,11 +688,11 @@ def __init__( if bias is not None - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum, b] - const_args = [stride, padding, dilation, groups, o_scale, o_zp, - output_dtype, accum_scale, accum_zp, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] else - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum] - const_args [b, stride, padding, dilation, groups, o_scale, o_zp, - output_dtype, accum_scale, accum_zp, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = len(inputs) == 8 self.idx_for_inplace_sum = 6 @@ -691,7 +712,7 @@ def codegen(self, wrapper): self.codegen_size_asserts(wrapper) def get_mutation_names(self) -> Sequence[str]: - return [self.inputs[self.idx_for_inplace_sum].get_name()] + return [self.input_name(self.idx_for_inplace_sum)] def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -813,10 +834,10 @@ def create(cls, x, packed_w, orig_w, B, batch_size): else: constant_args.insert(0, None) + device = x.get_device() + assert device is not None return MKLPackedLinear( - layout=FixedLayout( - x.get_device(), x.get_dtype(), output_size, output_stride - ), + layout=FixedLayout(device, x.get_dtype(), output_size, output_stride), inputs=inputs, constant_args=constant_args, ) @@ -858,9 +879,12 @@ def create(cls, x, w, B, attr, scalars, algorithm): else: constant_args.insert(0, None) + device = x.get_device() + assert device is not None + packed = LinearUnary( layout=FixedLayout( - device=x.get_device(), + device=device, dtype=x.get_dtype(), size=output_size, ), @@ -912,9 +936,11 @@ def create(cls, x, y, w, B, attr): else: constant_args.insert(0, B) + device = x.get_device() + assert device is not None packed = LinearBinary( layout=FixedLayout( - device=x.get_device(), + device=device, dtype=x.get_dtype(), size=output_size, ), @@ -1022,11 +1048,11 @@ def __init__( if bias is not None - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2, bias] - const_args is: [o_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] else - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2] - const_args is: [bias, o_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = has_bias self.idx_for_inplace_sum = 6 @@ -1048,7 +1074,9 @@ def codegen(self, wrapper): def get_mutation_names(self) -> Sequence[str]: binary_post_op = self.constant_args[-5] if binary_post_op == "sum": - return [self.inputs[self.idx_for_inplace_sum].get_name()] + input = self.inputs[self.idx_for_inplace_sum] + assert isinstance(input, IRNode) + return [input.get_name()] else: return [] @@ -1199,8 +1227,10 @@ def create( train, ] + device = x.get_device() + assert device is not None packed = MkldnnRnnLayer( - MultiOutputLayout(device=x.get_device()), + MultiOutputLayout(device=device), inputs=inputs, constant_args=constant_args, ) @@ -1221,7 +1251,7 @@ def get_strides_of_lstm_output(output_shape, batch_first): output_ir = [ MultiOutput( FixedLayout( - x.get_device(), + x.get_device(), # type: ignore[arg-type] x.get_dtype(), output_size, output_stride, diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 23e5e2410a1728..3d750e7731aa61 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,12 +1,12 @@ # mypy: allow-untyped-defs import functools -from typing import Optional +from typing import Optional, Union import torch import torch.utils._pytree as pytree from torch._inductor.kernel.mm_common import mm_args -from . import ir +from . import config, ir from .codegen.cpp_gemm_template import CppGemmTemplate from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate from .codegen.cpp_utils import create_epilogue_with_attr @@ -25,7 +25,7 @@ ChoiceCaller, ExternKernelChoice, ) -from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotune +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template from .virtualized import ops, OpsValue, V @@ -35,18 +35,20 @@ def create_int8_compensation( x_scale: ir.TensorBox, x_zp: ir.TensorBox, w_scale: ir.TensorBox, -) -> tuple[bool, ir.TensorBox, Optional[ir.TensorBox]]: - use_int8_fast_compensation_path = False - weight_compens = None - x_w_scale = None - if all( +) -> tuple[ + bool, + Union[ir.TensorBox, ir.ShapeAsConstantBuffer], + Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]], +]: + x_w_scale: Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]] = None + use_int8_fast_compensation_path = all( isinstance(item, ir.TensorBox) and item.get_name() in V.graph.constants and hasattr(item.data, "data") and isinstance(item.data.data, ir.ConstantBuffer) for item in [x_scale, x_zp, w_scale] - ): - use_int8_fast_compensation_path = True + ) + if use_int8_fast_compensation_path: x_w_scale_tensor = ( V.graph.constants[x_scale.get_name()] * V.graph.constants[w_scale.get_name()] @@ -68,7 +70,7 @@ def create_int8_compensation( weight_compens_tensor, name=packed_weight.get_name() + "_BMatrixCompens", ) - return ( + return ( # type: ignore[return-type] use_int8_fast_compensation_path, weight_compens, x_w_scale, @@ -139,7 +141,7 @@ def grouped_gemm_lowering( x = view(x, [-1, x_size[-1]]) num_gemm = len(w) - assert use_max_autotune() + assert config.max_autotune or config.max_autotune_gemm b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b] choices: list[ChoiceCaller] = [] @@ -182,7 +184,7 @@ def grouped_gemm_lowering( if len(x_size) > 2: for gemm_idx in range(num_gemm): return_tensors[gemm_idx] = view( - return_tensors[gemm_idx], + return_tensors[gemm_idx], # type: ignore[arg-type] (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]), ) return return_tensors @@ -339,9 +341,9 @@ def linear_unary( # GEMM template needs 2D input, normalize input shape here x = view(x, [-1, x_size[-1]]) if b is not None: - b = ir.ExternKernel.realize_input(b) + b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] choices: list[ChoiceCaller] = [] - if use_max_autotune(): + if config.max_autotune or config.max_autotune_gemm: transposed_w = permute(w, [1, 0]) *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) if use_cpp_gemm_template(layout, x, transposed_w): @@ -402,9 +404,9 @@ def linear_binary( if len(y_size) > 2: y = view(y, [-1, y_size[-1]]) if b is not None: - b = ir.ExternKernel.realize_input(b) + b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] choices: list[ChoiceCaller] = [] - if use_max_autotune(): + if config.max_autotune or config.max_autotune_gemm: transposed_w = permute(w, [1, 0]) *_, layout, x, transposed_w, y = mm_args( x, transposed_w, y, layout=layout @@ -624,13 +626,13 @@ def qconvolution_binary( # For int8-mixed-bf16 quantization and inplace add, # there is case when accum dtype is float32 but output dtype is bfloat16. # Since the accum will be inplaced changed with post op sum, - # we will do accum dtype convertion here. + # we will do accum dtype conversion here. accum = to_dtype(accum, output_dtype) return TensorBox.create( mkldnn_ir.QConvPointWiseBinaryPT2E.create( x, - x_scale, - x_zp, + x_scale, # type: ignore[arg-type] + x_zp, # type: ignore[arg-type] packed_weight, w_scale, w_zp, @@ -727,14 +729,14 @@ def qlinear_unary( ): # W_zp might be a ConstantBuffer with int64, convert it to int32 w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) - w_zp = V.graph.add_tensor_constant( + w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() ) bias_dtype = None if bias is None else bias.get_dtype() choices: list[ChoiceCaller] = [] - if use_max_autotune(): + if config.max_autotune or config.max_autotune_gemm: *_, layout, x, packed_weight = mm_args( x, packed_weight, layout=layout, out_dtype=output_dtype ) @@ -1030,7 +1032,7 @@ def qlinear_binary( ir.ConstantBuffer, ): w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) - w_zp = V.graph.add_tensor_constant( + w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() ) if binary_attr == "sum": @@ -1042,7 +1044,7 @@ def qlinear_binary( # For int8-mixed-bf16 quantization and inplace add, # there is case when accum dtype is float32 but output dtype is bfloat16. # Since the accum will be inplaced changed with post op sum, - # we will do accum dtype convertion here. + # we will do accum dtype conversion here. x2 = to_dtype(x2, output_dtype) else: assert x2.get_dtype() == output_dtype, ( @@ -1052,8 +1054,8 @@ def qlinear_binary( bias_dtype = bias.get_dtype() if bias is not None else None choices: list[ChoiceCaller] = [] if ( - use_max_autotune() and binary_attr == "add" - ): # Support inplace sum fusion + config.max_autotune or config.max_autotune_gemm + ) and binary_attr == "add": # Support inplace sum fusion *_, layout, x, packed_weight, x2 = mm_args( x, packed_weight, x2, layout=layout, out_dtype=output_dtype ) @@ -1302,7 +1304,7 @@ def mkl_packed_linear( layout=None, ): choices: list[ChoiceCaller] = [] - if use_max_autotune(): + if config.max_autotune or config.max_autotune_gemm: transposed_w = permute(orig_w, [1, 0]) *_, layout, x, transposed_w = mm_args( x, transposed_w, layout=layout diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 2f7ad5251fe814..35b5f464dd775a 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -31,6 +31,7 @@ "prod", "sum", "xor_sum", + "online_softmax_reduce", ] diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 2bfc1aefcc8ee3..460fae79fbf58c 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -61,6 +61,7 @@ from torch._inductor import metrics from torch._inductor.graph import GraphLowering from torch._library.fake_class_registry import FakeScriptObject + from torch.export.pt2_archive._package_weights import Weights from .compile_fx import _CompileFxKwargs from .triton_bundler import TritonBundle @@ -718,7 +719,7 @@ class CompiledAOTI(OutputCode): Class holding an AOTInductor compiled so. """ - filename: Union[str, list[str]] + filename: Union[str, list[Union[str, Weights]]] def __call__(self, inputs: Sequence[Any]) -> Any: raise NotImplementedError("NYI") diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index b1d47150b427b6..726b41d9724037 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -3,19 +3,16 @@ import logging import os import tempfile -from typing import Any, IO, Optional, Union +from typing import IO import torch -import torch._inductor -import torch.utils._pytree as pytree from torch._inductor import config from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder -from torch.export._tree_utils import reorder_kwargs -from torch.export.pt2_archive._package import PT2ArchiveWriter -from torch.export.pt2_archive.constants import ( - AOTINDUCTOR_DIR, - CONSTANTS_DIR, - CUSTOM_OBJ_FILENAME_PREFIX, +from torch.export.pt2_archive._package import ( + AOTI_FILES, + AOTICompiledModel, + load_pt2, + package_pt2, ) from torch.types import FileLike @@ -84,7 +81,7 @@ def get_aoti_file_with_suffix(suffix: str) -> str: def package_aoti( archive_file: FileLike, - aoti_files: Union[list[str], dict[str, list[str]]], + aoti_files: AOTI_FILES, ) -> FileLike: """ Saves the AOTInductor generated files to the PT2Archive format. @@ -95,122 +92,11 @@ def package_aoti( the AOTInductor files, or a dictionary mapping the model name to the path to its AOTInductor generated files. """ - if isinstance(aoti_files, list): - aoti_files = {"model": aoti_files} - - assert isinstance(aoti_files, dict), ( - "Please pass a list of AOTI generated files to be packaged or " - "a dictionary mapping model names to their list of AOTI generated " - "files. You can get this list of files through calling " - "`torch._inductor.aot_compile(..., options={aot_inductor.package=True})`" - ) - assert ( - isinstance(archive_file, (io.IOBase, IO)) - and archive_file.writable() - and archive_file.seekable() - ) or ( - isinstance(archive_file, (str, os.PathLike)) - and os.fspath(archive_file).endswith(".pt2") - ), ( - f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}" - ) - - # Save using the PT2 packaging format - # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) - - with PT2ArchiveWriter(archive_file) as archive_writer: - for model_name, files in aoti_files.items(): - num_so_files = 0 - num_cpp_files = 0 - - for file in files: - if file == "": - continue - - if file.endswith(".so"): - num_so_files += 1 - if num_so_files > 1: - raise RuntimeError( - f"Multiple .so files found in {files}. " - "You might need to clear your cache " - "directory before calling aoti_compile again." - ) - if file.endswith(".cpp"): - num_cpp_files += 1 - if num_so_files > 1: - raise RuntimeError( - f"Multiple .cpp files found in {files}. " - "You might need to clear your cache " - "directory before calling aoti_compile again." - ) - - filename = os.path.basename(file) - if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): - new_filepath = os.path.join(CONSTANTS_DIR, filename) - else: - new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) - log.debug( - "Saving AOTI generated file %s to archive in %s", file, new_filepath - ) - archive_writer.write_file( - str(new_filepath), - file, - ) - - if isinstance(archive_file, (io.IOBase, IO)): - archive_file.seek(0) - return archive_file - - -class AOTICompiledModel: - """ - Callable AOT Inductor loaded model from a .pt2 - """ - - def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: - self.loader = loader - - def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] - call_spec = self.loader.get_call_spec() # type: ignore[attr-defined] - in_spec = pytree.treespec_loads(call_spec[0]) - out_spec = pytree.treespec_loads(call_spec[1]) - flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] - flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] - flat_outputs = self.loader.boxed_run(flat_inputs) # type: ignore[attr-defined] - return pytree.tree_unflatten(flat_outputs, out_spec) - - def get_metadata(self) -> dict[str, str]: - return self.loader.get_metadata() # type: ignore[attr-defined] - - def load_constants( - self, - constants_map: dict[str, torch.Tensor], - *, - check_full_update: bool, - user_managed: bool = False, - ) -> None: - """ - Given a mapping of constant fqns to tensors, load the constants into the model. - You can use ``get_constant_fqns`` to get the list of constant fqns that - are needed in the compiled model. - - Args: - constants_map: A mapping of constant fqns to tensors. - check_full_update: Whether to add check to see if all the constants - are updated and have values. - """ - self.loader.load_constants( # type: ignore[attr-defined] - constants_map, False, check_full_update, user_managed - ) - - def get_constant_fqns(self) -> list[str]: - return self.loader.get_constant_fqns() # type: ignore[attr-defined] - def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel": - log.warning( - "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." - ) - return AOTICompiledModel(self.loader) # type: ignore[attr-defined] + return package_pt2( + archive_file, + aoti_files=aoti_files, + ) def load_package( @@ -220,26 +106,33 @@ def load_package( num_runners: int = 1, device_index: int = -1, ) -> AOTICompiledModel: # type: ignore[type-arg] - assert ( - isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable() - ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), ( - f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}" - ) + try: + pt2_contents = load_pt2( + path, + run_single_threaded=run_single_threaded, + num_runners=num_runners, + device_index=device_index, + ) + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in package") + return pt2_contents.aoti_runners[model_name] + except RuntimeError: + log.warning("Loading outdated pt2 file. Please regenerate your package.") if isinstance(path, (io.IOBase, IO)): with tempfile.NamedTemporaryFile(suffix=".pt2") as f: # TODO(angelayi): We shouldn't need to do this -- miniz should # handle reading the buffer. This is just a temporary workaround - f.write(path.read()) path.seek(0) + f.write(path.read()) log.debug("Writing buffer to tmp file located at %s.", f.name) loader = torch._C._aoti.AOTIModelPackageLoader( f.name, model_name, run_single_threaded, num_runners, device_index - ) # type: ignore[call-arg] + ) return AOTICompiledModel(loader) path = os.fspath(path) # AOTIModelPackageLoader expects (str, str) loader = torch._C._aoti.AOTIModelPackageLoader( path, model_name, run_single_threaded, num_runners, device_index - ) # type: ignore[call-arg] + ) return AOTICompiledModel(loader) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 792a6b4385a2e7..1da31586b0a18b 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1010,7 +1010,7 @@ def __init__(self) -> None: self.memoized_objs_pp: dict[PatternExpr, str] = {} @staticmethod - @functools.lru_cache(None) + @functools.cache def run(obj: PatternExpr, output_name: str = "output") -> str: """ Serializes obj to python code with obj written out to `output_name` @@ -1167,12 +1167,18 @@ def run_node(self, node: torch.fx.Node) -> Any: raise NotImplementedError( f"NYI: replacement_graph.{target} is not a graph module. Got {sub_gm}." ) - assert graph.owning_module is not None - _, graph_name = unique_graph_name_with_root( - graph.owning_module, str(target) - ) - graph.owning_module.register_module(graph_name, sub_gm) + graph_name = None + for n, mod in graph.owning_module.named_modules(): + if sub_gm is mod: + graph_name = n + break + if graph_name is None: + assert isinstance(target, str) + _, graph_name = unique_graph_name_with_root( + graph.owning_module, target + ) + graph.owning_module.register_module(graph_name, sub_gm) return graph.get_attr(graph_name) raise NotImplementedError(f"unhandled {node}") @@ -1558,10 +1564,10 @@ def normalize_args(**kwargs: Any) -> list[Any]: normalize_args=normalize_args, ) pattern.register(pass_dicts) - return pattern.pattern + return pattern.pattern # type: ignore[return-value] -_serialized_patterns = OrderedSet[str]() +_serialized_patterns: OrderedSet[str] = OrderedSet() def _serialize_pattern( @@ -2195,7 +2201,7 @@ def stable_topological_sort(graph: torch.fx.Graph) -> None: def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: """Wrapper around lazy init functions in fx_passes/""" - @functools.lru_cache(None) + @functools.cache @functools.wraps(fn) def lazy_init() -> Any: counters_ref = counters["inductor"].copy() @@ -2235,7 +2241,7 @@ def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: # TODO: remove in follow up diff, used internally -_seen_patterns = OrderedSet[str]() +_seen_patterns: OrderedSet[str] = OrderedSet() def get_arg_value( diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index c5d1cd634e9536..c7628314a85cbb 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -4,7 +4,7 @@ import torch from torch._inductor.kernel.mm_common import mm_args -from . import lowering +from . import config, lowering from .codegen.cpp_gemm_template import CppGemmTemplate, CppWoqInt4GemmTemplate from .codegen.cpp_utils import create_epilogue_with_attr from .lowering import expand, register_lowering @@ -14,7 +14,7 @@ ExternKernelChoice, realize_inputs, ) -from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotune +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template from .virtualized import V @@ -126,7 +126,7 @@ def int4pack_mm_cpu( else [] ) if ( - use_max_autotune() + (config.max_autotune or config.max_autotune_gemm) and use_cpp_gemm_template( aten_layout, mat1, diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 2aadc806bf902c..aaa266b60e00b3 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -136,7 +136,7 @@ def decode(self, data: _T) -> _T: # To write (`put`), the RemoteCache takes data, uses the RemoteCacheSerde to # convert it for the backend and passes it to the backend. # -# Conversly when reading (`get`), the RemoteCache takes data from the backend, +# Conversely when reading (`get`), the RemoteCache takes data from the backend, # uses the RemoteCacheSerde to convert it and returns it. # # The RemoteCacheBackend is generic on _U - which is the type of data the diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index f4433743726986..01d038aab8e7bc 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -1,3 +1,28 @@ +""" +PyTorch Inductor Autotuning Cache System + +This module implements a caching system for autotuning configurations in PyTorch's Inductor compiler. +It provides mechanisms to store and retrieve optimal kernel configurations both locally and remotely, +which significantly speeds up compilation by reusing previously discovered optimal parameters. + +The caching system includes: +- Local filesystem caching for individual machine reuse +- Remote caching for sharing optimizations across machines +- Bundled caching to efficiently store multiple related configurations +- Cache invalidation based on PyTorch versions and backend changes +- Serialization/deserialization support for worker processes + +Key components: +- AutotuneCache: Main class for managing cache access and storage +- AutotuneCacheBundler: Bundles multiple cache entries for efficient storage +- LocalAutotuneCache: Handles filesystem-based caching +- _LocalAutotuneCacheBackend: Low-level file operations for cache storage +- AutotuneCacheArtifact: Integration with PyTorch's artifact system + +This caching system is critical for performance as it eliminates the need to re-run +expensive autotuning operations when the same kernels are compiled multiple times. +""" + from __future__ import annotations import dataclasses @@ -242,7 +267,11 @@ def __setstate__(self, state: dict[str, Any]) -> None: # Save the config in the caches def save( - self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False + self, + config: Config, + time_taken_ns: int, + found_by_coordesc: bool = False, + triton_cache_hash: Optional[str] = None, ) -> None: data = { **config.kwargs, @@ -251,6 +280,7 @@ def save( "configs_hash": self.configs_hash, "found_by_coordesc": found_by_coordesc, "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + "triton_cache_hash": triton_cache_hash, } if HAS_WARP_SPEC: data.update( @@ -514,6 +544,8 @@ def _load_cached_autotuning( # Remove time taken for comparison best_config.pop("time_taken_ms", None) + best_config.pop("triton_cache_hash", None) + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( "found_by_coordesc", False ): diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 74df6ed671ef3e..5c9cc60bef87a0 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -230,7 +230,7 @@ def benchmark_gpu( in milliseconds. An estimated duration is calculated based on the values of `memory_warmup_iters` and `benchmark_iters`, along with the estimated runtime of `_callable` and various other factors, and we then shrink - `benchmark_iters` to fit in the alloted maximum duration. + `benchmark_iters` to fit in the allotted maximum duration. - **kwargs: Additional kwargs that may be passed to the fallback. Returns: diff --git a/torch/_inductor/runtime/cache_dir_utils.py b/torch/_inductor/runtime/cache_dir_utils.py index e1939d59162c35..34b84a68f6300c 100644 --- a/torch/_inductor/runtime/cache_dir_utils.py +++ b/torch/_inductor/runtime/cache_dir_utils.py @@ -39,15 +39,15 @@ def triton_cache_dir(device: int) -> str: @contextmanager def temporary_cache_dir(directory: str) -> Generator[None, None, None]: - from torch._inductor.utils import clear_inductor_caches + from torch._inductor.utils import clear_caches original = os.environ.get("TORCHINDUCTOR_CACHE_DIR") os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory try: - clear_inductor_caches() + clear_caches() yield finally: - clear_inductor_caches() + clear_caches() if original is None: del os.environ["TORCHINDUCTOR_CACHE_DIR"] else: diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 78be2e3787cc19..67140369faac46 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -34,7 +34,7 @@ def _reload_python_module( return mod -@functools.lru_cache(None) +@functools.cache def _set_triton_ptxas_path() -> None: if os.environ.get("TRITON_PTXAS_PATH") is not None: return diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index b41ca81ebdfc6b..413dfaf09d061b 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -208,7 +208,7 @@ def compare_config(self, func, candidate_config, best_config, best_timing): """ Check if candidate_config is better than best_config. - Return a touple of (compare_result, candidate_timing). + Return a tuple of (compare_result, candidate_timing). compare_result is true iff candidate_config is better. """ log.debug("Try config %s", candidate_config) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index f224217db22b11..e559eaa1a31d4d 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -136,7 +136,7 @@ class DeviceProperties(typing.NamedTuple): warp_size: Optional[int] = None @classmethod - @functools.lru_cache(None) + @functools.cache def create(cls, device) -> DeviceProperties: import torch from torch._dynamo.device_interface import get_interface_for_device diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 9d57232e299e47..21cd5987f8f435 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -5,6 +5,9 @@ from typing import Any, TYPE_CHECKING import torch + +# NOTE: other files rely on the imports below +from torch._dynamo import callback as compilation_callback # noqa: F401 from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 cache_dir, default_cache_dir, @@ -22,8 +25,8 @@ def conditional_product(*args: int) -> int: return functools.reduce(operator.mul, [x for x in args if x]) -def ceildiv(numer: int, denom: int) -> int: - return -(numer // -denom) +def ceildiv(number: int, denom: int) -> int: + return -(number // -denom) def is_power_of_2(n: int) -> bool: @@ -152,7 +155,7 @@ def get_first_attr(obj: Any, *attrs: str) -> Any: def triton_hash_to_path_key(key: str) -> str: # In early versions of Triton, the hash is directly used in the path name. # Later, the hash is converted to base64 before being used in the path name. - # Later, the base64 convertion was replaced to the base32 + # Later, the base64 conversion was replaced to the base32 # # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable. # diff --git a/torch/_inductor/runtime/triton_compat.py b/torch/_inductor/runtime/triton_compat.py index e244ecbee635f0..877f72b50c5506 100644 --- a/torch/_inductor/runtime/triton_compat.py +++ b/torch/_inductor/runtime/triton_compat.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import Any, Union import torch @@ -74,6 +75,18 @@ def _log2(x: Any) -> Any: from triton import knobs except ImportError: knobs = None + + try: + from triton.runtime.cache import triton_key # type: ignore[attr-defined] + except ImportError: + from triton.compiler.compiler import ( + triton_key, # type: ignore[attr-defined,no-redef] + ) + + builtins_use_semantic_kwarg = ( + "_semantic" in inspect.signature(triton.language.core.view).parameters + ) + HAS_TRITON = True else: def _raise_error(*args: Any, **kwargs: Any) -> Any: @@ -94,6 +107,7 @@ class PTXASError(Exception): # type: ignore[no-redef] libdevice = None math = None knobs = None + builtins_use_semantic_kwarg = False class triton: # type: ignore[no-redef] @staticmethod @@ -109,6 +123,8 @@ def constexpr(val: Any) -> Any: dtype = Any HAS_WARP_SPEC = False + triton_key = _raise_error + HAS_TRITON = False def cc_warp_size(cc: Union[str, int]) -> int: @@ -145,4 +161,5 @@ class autograd_profiler: # type: ignore[no-redef] "triton", "cc_warp_size", "knobs", + "triton_key", ] diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 05a04b4030a516..cfd708bcf4bf81 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -2,9 +2,17 @@ # mypy: allow-untyped-defs import math as pymath import warnings -from typing import Any, TypeVar +from functools import wraps +from typing import Any, Callable, TypeVar -from .triton_compat import _log2, libdevice, math, tl, triton # noqa: F401 +from .triton_compat import ( # noqa: F401 + _log2, + builtins_use_semantic_kwarg, + libdevice, + math, + tl, + triton, +) _T = TypeVar("_T") @@ -194,7 +202,7 @@ def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexp # Should be # out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale - # but since rhs_sum is all 1, we can simpliy it. + # but since rhs_sum is all 1, we can simplify it. out_sum = lhs_sum * lhs_scale + rhs_scale return out_max, out_sum @@ -345,7 +353,7 @@ def pack_value_flag( DTYPE_PACK: tl.constexpr, ): # Workaround for triton bug, tensor.to doesn't unwrap constexpr values - DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) return flag.to(DTYPE_PACK) | (uv << bitwidth) @@ -358,8 +366,8 @@ def unpack_value( DTYPE_VALUE_AS_UINT, ): # Workaround for triton bug, tensor.to doesn't unwrap constexpr values - DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE) - DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) return value_uint.to(DTYPE_VALUE, bitcast=True) @@ -452,7 +460,7 @@ def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combi block_value: Scalar value for this block, must be 64-bits wide index: Scalar index of this block relative to the current scan combine_fn: Function ``(value, value) -> value`` which is scanned over - init: Scalar value equal to the identiy of combine_fn + init: Scalar value equal to the identity of combine_fn """ # Publish block sum so subsequent blocks don't get stuck waiting for us if index > 0: @@ -682,7 +690,7 @@ def x_grid_barrier(sem): tl.debug_barrier() -def triton_builtin(f: _T) -> _T: +def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: """ Decorator to mark a function as a Triton built-in function. These functions are evaluated at compile time. @@ -693,8 +701,18 @@ def triton_builtin(f: _T) -> _T: Returns: function: The same function, marked as a Triton built-in. """ - f.__triton_builtin__ = True # type: ignore[attr-defined] - return f + if builtins_use_semantic_kwarg: + # support Triton before and after https://github.com/triton-lang/triton/pull/7054 + @wraps(f) + def wrapper(*args, **kwargs): + kwargs["_builder"] = kwargs["_semantic"] + del kwargs["_semantic"] + return f(*args, **kwargs) + else: + wrapper = f # type: ignore[assignment] + + wrapper.__triton_builtin__ = True # type: ignore[attr-defined] + return wrapper @triton_builtin diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 172a8620354b52..bee87251d6e904 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -745,7 +745,8 @@ def _get_args_with_constexprs(self, args, launcher): # so we can sort them by index. constexpr_args: list[tuple[int, Any]] = [] for arg_name, arg_val in launcher.config.kwargs.items(): - constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) constexpr_args.sort() new_args = [*args] @@ -913,15 +914,22 @@ def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: return self.maybe_clone_args(OrderedSet(), *args, **kwargs) def benchmark_all_configs(self, *args, **kwargs): - with dynamo_timed( - "CachingAutotuner.benchmark_all_configs", - log_pt2_compile_event=True, - metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, - dynamo_compile_column_us="runtime_triton_autotune_time_us", - compile_id=self.compile_id, - is_backward=self.is_backward, - log_waitcounter=True, - waitcounter_name_override="triton_autotuner", + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + # Temporarily disable due to spam + # compilation_callback.callback_handler.install_callbacks( + # compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, + # str(self.compile_id), + # ), ): timings = { launcher: self.bench(launcher, *args, **kwargs) @@ -969,7 +977,11 @@ def autotune_to_one_config(self, *args, **kwargs): ) if self.save_cache_hook: - self.save_cache_hook(launcher.config, self.autotune_time_taken_ns) + self.save_cache_hook( + launcher.config, + self.autotune_time_taken_ns, + triton_cache_hash=launcher.cache_hash, + ) def save_gpu_kernel(self, stream, launcher): key = self.inductor_meta.get("kernel_name", None) # unique kernel name @@ -997,13 +1009,16 @@ def save_gpu_kernel(self, stream, launcher): "triton_meta": self.triton_meta, "def_args": launcher.def_args, "call_args": launcher.call_args, + "global_scratch": launcher.global_scratch, } from torch._inductor.codecache import CudaKernelParamCache bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") binary = launcher.bin.asm[bin_type] # Also store asm code which can be used for debugging and generating cpp package - asm_type = {"hip": "amdgcn", "cuda": "ptx"}.get(self.device_props.type, None) + asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( + self.device_props.type, None + ) asm = launcher.bin.asm.get(asm_type, None) CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type) @@ -1092,6 +1107,15 @@ def run( benchmark_run=False, **kwargs, ): # type:ignore[override] + if hasattr(triton, "set_allocator"): + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty( + size, dtype=torch.int8, device=self.device_props.type + ) + + triton.set_allocator(alloc_fn) + if self.triton_interpret: args, grid = self._interpret_args_grid(args, self.configs[0]) return self.fn[grid]( @@ -1119,6 +1143,10 @@ def run( if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): self.save_gpu_kernel(stream, launcher) + # PyTorch execution trace replay calls CachingAutotuner::run() instread of calls launcher + # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() + # make a copy here to avoid mutating the original args + args_without_constexprs = tuple(args) args = self._get_args_with_constexprs(args, launcher) if self.dump_launch_params: @@ -1144,7 +1172,7 @@ def run( with torch._C._profiler._RecordFunctionFast( self.inductor_meta.get("kernel_name", "triton kernel"), - args, + args_without_constexprs, profiler_kwargs, ): return launcher( @@ -1432,6 +1460,7 @@ def make_launcher(self) -> LauncherType: launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined] launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined] launcher.shared = self.kernel.shared # type: ignore[attr-defined] + launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined] launcher.store_cubin = False # type: ignore[attr-defined] launcher._is_static = True # type: ignore[attr-defined] return launcher @@ -1541,6 +1570,10 @@ def make_launcher(self) -> LauncherType: launch_enter = knobs.runtime.launch_enter_hook launch_exit = knobs.runtime.launch_exit_hook + import math as math_lib + + import torch as torch_lib + scope = { "grid_meta": cfg.kwargs, "bin": binary, @@ -1571,6 +1604,8 @@ def make_launcher(self) -> LauncherType: ), "function": get_first_attr(binary, "function", "cu_function"), "runner": get_first_attr(binary, "run", "c_wrapper"), + "math": math_lib, + "torch": torch_lib, } if not hasattr(binary, "launch_metadata"): @@ -1620,6 +1655,7 @@ def make_launcher(self) -> LauncherType: launcher.n_regs = getattr(binary, "n_regs", None) launcher.n_spills = getattr(binary, "n_spills", None) launcher.shared = binary_shared + launcher.cache_hash = triton_hash_to_path_key(binary.hash) launcher.store_cubin = self.inductor_meta.get("store_cubin", False) # store this global variable to avoid the high overhead of reading it when calling run if launcher.store_cubin: @@ -1637,6 +1673,10 @@ def make_launcher(self) -> LauncherType: ] launcher.def_args = def_args launcher.call_args = call_args + kernel_metadata = getattr(self.kernel, "metadata", None) + launcher.global_scratch = getattr( + kernel_metadata, "global_scratch_size", None + ) return launcher @@ -2141,7 +2181,9 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} -def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1): +def triton_config_tiled_reduction( + size_hints, x, y, r, num_stages=1, register_intensive=False +): """ Construct a tile reduction triton config with some adjustment heuristics based on size_hints. Size_hints is a tuple of numels in @@ -2167,12 +2209,15 @@ def total_numel() -> int: for prefix in sorted(rnumels): while rnumels[prefix] < size_hints[prefix] and total_numel() < target: rnumels[prefix] *= 2 - while y < size_hints[1] and total_numel() < target: + while y < size_hints["y"] and total_numel() < target: y *= 2 cfg = _get_config({"x": x, "y": y, **rnumels}) num_warps = _num_warps(total_numel() // 256, min_num_warps=1) - check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1]) + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) check_max_block(cfg) return Config(cfg, num_warps=num_warps, num_stages=num_stages) @@ -2295,22 +2340,47 @@ def _reduction_configs( MAX_R0_BLOCK = 1024 register_intensive = True - contiguous_config = triton_config_reduction( - size_hints, + def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): + # For 3D case with tiling scores, create an adapted version + if "y" in size_hints: + assert "tiling_scores" in inductor_meta + return adapt_config_for_tiling( + size_hints, + inductor_meta["tiling_scores"], + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + else: + # For other cases, use the original function + return triton_config_reduction( + size_hints, + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + + contiguous_config = make_config( 1, - rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK, + min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) - outer_config = triton_config_reduction( - size_hints, 64, 8, register_intensive=register_intensive - ) - tiny_config = triton_config_reduction( - size_hints, + outer_config = make_config(64, 8, register_intensive=register_intensive) + tiny_config = make_config( 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) - if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"): + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ): pass # skip all these cases elif reduction_hint == ReductionHint.INNER: return [contiguous_config] @@ -2319,20 +2389,96 @@ def _reduction_configs( elif reduction_hint == ReductionHint.OUTER_TINY: return [tiny_config] if disable_pointwise_autotuning(inductor_meta): - return [triton_config_reduction(size_hints, 32, 128)] + return [make_config(32, 128)] return [ contiguous_config, outer_config, tiny_config, - triton_config_reduction(size_hints, 64, 64), - triton_config_reduction(size_hints, 8, 512), + make_config(64, 64), + make_config(8, 512), # halve the XBLOCK/Rn_BLOCK compared to outer_config # TODO: this may only be beneficial when each iteration of the reduction # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 - triton_config_reduction(size_hints, 64, 4, num_warps=8), + make_config(64, 4, num_warps=8), ] +def match_target_block_product( + size_hints, tiling_scores, target_block_product, min_block_size=1 +): + """ + Distribute block sizes across dimensions according to tiling scores, + aiming to match a target product of block sizes. + """ + total_score = sum(tiling_scores.values()) + if total_score == 0: + # just assume even score with no minimum block size + min_block_size = 1 + tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product) + + # First, give each coalescing dimension at least min_block_size + block_sizes = {} + relative_scores = {} + curr_block_product = 1 + + for dim, score in tiling_scores.items(): + if score == 0: + block_sizes[dim] = 1 + continue + + block_sizes[dim] = min_block_size + curr_block_product *= min_block_size + relative_scores[dim] = score / total_score + + # Scale up dimensions by their relative scores until we reach the target + while curr_block_product < target_block_product and len(relative_scores): + dim, score = max(relative_scores.items(), key=lambda item: item[1]) + + # Check if we've hit the max for this dimension + if ( + block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()] + or block_sizes[dim] >= size_hints[dim] + ): + del relative_scores[dim] + continue + + block_sizes[dim] *= 2 + relative_scores[dim] /= 2 + curr_block_product *= 2 + + return block_sizes + + +def adapt_config_for_tiling( + size_hints, + tiling_scores, + original_x, + original_r, + num_warps=None, + num_stages=1, + register_intensive=False, + persistent_reduction=False, +) -> Config: + """ + Create an adapted configuration based on tiling scores, + redistributing the same total block size (x * r) according to tiling scores. + """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + ) + + def reduction( size_hints, reduction_hint=False, @@ -2413,14 +2559,37 @@ def _persistent_reduction_configs( xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) - configs = [ - triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) - if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel) - ] + MAX_PERSISTENT_BLOCK_NUMEL = 4096 + + if "y" not in size_hints: + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) + if xblock == 1 + or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) + ] + else: + configs = [] + assert "tiling_scores" in inductor_meta + x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} + for target_block_size in (1, 8, 32, 64, 128): + if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: + continue + + block_sizes = match_target_block_product( + size_hints, x_y_scores, target_block_size + ) + configs.append( + triton_config_tiled_reduction( + size_hints, block_sizes["x"], block_sizes["y"], rnumel + ) + ) + # defer to more autotuning, initially + if "y" in size_hints: + pass # TODO(jansel): we should be able to improve these heuristics - if reduction_hint == ReductionHint.INNER and rnumel >= 256: + elif reduction_hint == ReductionHint.INNER and rnumel >= 256: configs = configs[:1] elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index e3f73f3c54d302..61f862f847e12a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -15,6 +15,7 @@ import typing from collections import Counter, defaultdict from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec, TypeAlias if TYPE_CHECKING: @@ -75,7 +76,9 @@ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") -PartitionType = list["BaseSchedulerNode"] +PartitionType: TypeAlias = list["BaseSchedulerNode"] +_T = TypeVar("_T") +_P = ParamSpec("_P") @dataclasses.dataclass @@ -212,7 +215,7 @@ def __init__(self, scheduler: Scheduler) -> None: def _init_from_node(self, node: ir.Operation) -> None: self.node: Optional[ir.Operation] = node - self.ancestors = OrderedSet[str]() + self.ancestors: OrderedSet[str] = OrderedSet() self.last_usage = OrderedSet[ str ]() # buffers that won't be used after this kernel @@ -325,7 +328,7 @@ def used_buffer_names(self) -> OrderedSet[str]: ) def used_or_aliased_buffer_names(self) -> OrderedSet[str]: - used_names = OrderedSet[str]() + used_names: OrderedSet[str] = OrderedSet() deps = [ dep.name @@ -477,7 +480,7 @@ def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: buf_name = buf_to_be_inplaced.get_name() # Dedup read/writes with equivalent indices # TODO - would be nice if we could just cache accesses on ReadWrites, - # and inforce variant that this class & members are functional.. + # and enforce variant that this class & members are functional.. deps: OrderedSet[Dep] = OrderedSet() for user in buf_to_be_inplaced.users: user_node = user.node @@ -1016,13 +1019,14 @@ def __init__( def _compute_attrs( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, - recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + recompute_sizes_body_func: Optional[Callable[_P, _T]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) - self._sizes, self._body = self.node.simplify_and_reorder( + self._sizes, body = self.node.simplify_and_reorder( extra_indexing_constraints=extra_indexing_constraints, recompute_sizes_body_func=recompute_sizes_body_func, ) + self._body = body # type: ignore[assignment] device = self.node.get_device_or_error() group_fn = self.scheduler.get_backend(device).group_fn @@ -1079,7 +1083,7 @@ def refresh_dependencies( # TODO(shunting) if this cause compilation time increase when # enabling LOAF by default, try just clearing the specific cache - # entry by using a customized cache implemetation rather than + # entry by using a customized cache implementation rather than # lru_cache. SIMDScheduling.candidate_tilings.cache_clear() @@ -1187,6 +1191,17 @@ def ranges_from_index_vars( return var_ranges def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: + """ + Generate code for this node using the provided index variables. + + This method sets up the appropriate context for code generation, including + simplifying indexing expressions based on the variable ranges, and then + calls the node's body function with the index variables. + + Args: + index_vars: A sequence of sequences of sympy expressions representing + the index variables for each dimension of the computation. + """ var_ranges = self.ranges_from_index_vars(index_vars) try: with ( @@ -1238,7 +1253,7 @@ def can_inplace(self, read_dep: dependencies.Dep) -> bool: @cache_on_self def _get_atomic_add_buffers(self) -> OrderedSet[str]: - buffers_store_as_atomic_add = OrderedSet[str]() + buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() if isinstance(self._body, LoopBody): for node in self._body.get_nodes(): if ( @@ -1441,7 +1456,7 @@ def set_last_usage( super().set_last_usage(future_used_buffers, mutation_real_name) # Set self.last_usage on the snodes # This will be used for optimisations within the kernel - future_used_buffers = OrderedSet[str]() + future_used_buffers: OrderedSet[str] = OrderedSet() for node in reversed(self.snodes): node.set_last_usage(future_used_buffers, mutation_real_name) future_used_buffers.update(node.last_usage) @@ -1934,7 +1949,7 @@ def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> b def pick_loop_order( stride_lengths: list[list[int]], sizes: Sequence[sympy.Expr], - priority_idx: tuple[int, ...] = (), + priority_idx: Sequence[int] = (), ) -> list[int]: """ A heuristic to decide loop iteration orders. This has not been well @@ -2032,7 +2047,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.post_grad_graph_id = next(_post_grad_graph_counter) self._graph_partition_counter = itertools.count() - self.completed_operations = OrderedSet[str]() + self.completed_operations: OrderedSet[str] = OrderedSet() self.available_buffer_names = OrderedSet( [ *V.graph.graph_inputs.keys(), @@ -2101,6 +2116,8 @@ def _init(self, nodes: list[ir.Operation]) -> None: if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) self.nodes = self.fuse_nodes(self.nodes) + if config._post_fusion_custom_pass is not None: + self.nodes = config._post_fusion_custom_pass(self.nodes) self.merge_loops() self.finalize_multi_template_buffers() if config.combo_kernels: @@ -2132,7 +2149,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.debug_draw_graph() # used during codegen: - self.buffer_names_to_free = OrderedSet[str]() + self.buffer_names_to_free: OrderedSet[str] = OrderedSet() # fx graph node to the position it appears in the graph # for debug attribution @@ -2192,7 +2209,7 @@ def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: raise NotImplementedError(node) def create_foreach_nodes(self) -> None: - removed_node_names = OrderedSet[str]() + removed_node_names: OrderedSet[str] = OrderedSet() fe_nodes = [] kept_node_names = self.name_to_fused_node.keys() @@ -2233,9 +2250,7 @@ def compute_dependencies(self) -> None: mutation properly. """ - T = TypeVar("T") - - class DedupList(Generic[T]): + class DedupList(Generic[_T]): """ This data structure behaves like a list except it makes sure the elements remain unique. @@ -2247,19 +2262,19 @@ class DedupList(Generic[T]): def __init__( self, - items: Optional[list[T]] = None, - membership: Optional[OrderedSet[T]] = None, + items: Optional[list[_T]] = None, + membership: Optional[OrderedSet[_T]] = None, ) -> None: self.items = items or [] self.membership = membership or OrderedSet() - def append(self, node_user: T) -> None: + def append(self, node_user: _T) -> None: if node_user in self.membership: return self.items.append(node_user) self.membership.add(node_user) - def __add__(self, other: DedupList[T]) -> DedupList[T]: + def __add__(self, other: DedupList[_T]) -> DedupList[_T]: new_membership = OrderedSet.union(self.membership, other.membership) new_items = self.items + [ x for x in other.items if x not in self.membership @@ -2513,7 +2528,7 @@ def visit(n: BaseSchedulerNode) -> None: return result def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]: - unmet_deps = OrderedSet[str]() + unmet_deps: OrderedSet[str] = OrderedSet() if isinstance( snode, ( @@ -2565,7 +2580,7 @@ def compute_ancestors(self) -> None: # note self.nodes is topologically sorted name_to_ancestors: dict[str, OrderedSet[str]] = {} for node in self.nodes: - ancestors = OrderedSet[str]() + ancestors: OrderedSet[str] = OrderedSet() for dep in node.unmet_dependencies: dep_node_name = self.name_to_buf[dep.name].defining_op_name() ancestors.add(dep_node_name) @@ -2694,7 +2709,7 @@ def finalize_multi_template_buffers(self) -> None: choice finalized through fusion. In the case of an extern choice, this will result in replacing the SchedulerNode. - If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choie + If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice will force completion of compilation and benchmarking. """ @@ -2752,7 +2767,7 @@ def replace_operation_buffer( continue out_tensorbox = min_node_unfused.output_node() - out_storage = out_tensorbox.data + out_storage = out_tensorbox.data # type: ignore[union-attr] assert isinstance(out_storage, ir.StorageBox) out_buffer = out_storage.data assert isinstance(out_buffer, ir.OperationBuffer) @@ -2802,6 +2817,20 @@ def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: for n in node_list ) + def _template_upcast( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + # Check if fusing an upcast onto a Triton template. If so, we want to benchmark + # the fusion to make sure that shared memory requirements are still met + return ( + isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + and node1.node is not None + and node2.node is not None + and hasattr(node1.node, "get_dtype") + and hasattr(node2.node, "get_dtype") + and node1.node.get_dtype().itemsize < node2.node.get_dtype().itemsize + ) + def speedup_by_fusion( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> Union[bool, Callable[[], bool]]: @@ -2815,7 +2844,12 @@ def speedup_by_fusion( and isinstance(n.get_template_node(), ir.MultiTemplateBuffer) for n in (node1, node2) ) - if not config.benchmark_fusion and not is_multi_template: + + if ( + not self._template_upcast(node1, node2) + and not config.benchmark_fusion + and not is_multi_template + ): return True if ( @@ -2909,7 +2943,7 @@ def compile_kernel( future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] triton_choices = 0 for choice, unfused_time in sorted( - choice_timings.items(), key=lambda x: x[1] + choice_timings.items(), key=operator.itemgetter(1) ): if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): continue @@ -3046,7 +3080,10 @@ def benchmark_when_ready() -> bool: except NoTritonConfigsError: return False - + except RuntimeError as e: + if "out of resource" in str(e): + return False + raise except CompilationError as e: if "Loop-carried variable" in str(e): return True @@ -3227,7 +3264,11 @@ def get_possible_fusions( def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None: for node1_index, node1 in enumerate(nodes): - for node2 in nodes[node1_index + 1 :]: + for node2 in nodes[ + node1_index + 1 : node1_index + + 1 + + config.max_fusion_buffer_group_pairwise_attempts + ]: key = (node1, node2) if key in seen: continue @@ -3319,7 +3360,7 @@ def can_fusion_increase_peak_memory( Return true if fusing the two nodes can potentially increasing peak memory. The implementation is more like a heuristic since we don't really know if we are at peak - or not when trying to fuse these two ndoes. The order of nodes may change later which makes the + or not when trying to fuse these two nodes. The order of nodes may change later which makes the peak memory estimation hard. Here is how we decide the LOWER BOUND of extra memory allocation if we fuse these 2 nodes: @@ -3359,7 +3400,7 @@ def _find_single_user_inputs( try: memory_overhead += int(key[2]) except ValueError: - # not an interger. Fallback is to fuse + # not an integer. Fallback is to fuse return False bw_saving = self.score_fusion_memory(node1, node2) @@ -3464,7 +3505,7 @@ def shared_data_after_reordering_loop( """ Right now just greedily reorder the loop of node1 to be compatible with node2, but ideally we should have some heuristics to reorder the loop for node2 - to be compatibile with node1 if that's more efficient. + to be compatible with node1 if that's more efficient. """ # TODO Don't do loop reordering for CPU for now. @@ -3505,7 +3546,7 @@ def shared_data_after_reordering_loop( return 0 # Pick the largest buffer to guide the loop reordering - _numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0]) + _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0)) if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): return 0 @@ -3563,7 +3604,7 @@ def check_prologue_fusion_heuristics_fusable( # potential bad cache behavior and shared memory use. # we also want to avoid benchmarking reliably unprofitable fusions like downcasts from fp32 -> fp16 inside kernel. # allowing gathers by allowing increasing write_bytes by small factor - # TODO - make configurable per input, for insance, bias can fuse fp32 -> fp16 profitably + # TODO - make configurable per input, for instance, bias can fuse fp32 -> fp16 profitably BYTES_THRESHOLD_MULTIPLIER = 1.1 if read_bytes > (write_bytes * BYTES_THRESHOLD_MULTIPLIER): @@ -3657,7 +3698,7 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: allowed_prologue_inps = template.get_allowed_prologue_inps() unsupported_prologue_args = ( - OrderedSet(inp.get_name() for inp in template.inputs) + OrderedSet(inp.get_name() for inp in template.inputs) # type: ignore[union-attr] - allowed_prologue_inps ) @@ -4249,7 +4290,15 @@ def filter_symbols( *(get_input_node_symbols(node) for _, node in input_nodes.items()) ) - return filter_symbols(candidate_symbols) + candidate_symbols = filter_symbols(candidate_symbols) + + res: OrderedSet[sympy.Symbol] = OrderedSet() + for s in candidate_symbols: + symplified_s = V.graph.sizevars.simplify(s) + # use free_symbols only when s is simplified to an Integer or expr + res.update(symplified_s.free_symbols) + + return OrderedSet(sorted(res, key=operator.attrgetter("name"))) def get_graph_partition_signature( self, partitions: list[PartitionType], skip_cudagraphs: list[bool] @@ -4422,7 +4471,7 @@ def reorder_for_minimizing_partition( ) -> list[BaseSchedulerNode]: """ Reorder nodes to minimize the number of partitions via a bfs - topological sort. This is the optimal reodering such that the + topological sort. This is the optimal reordering such that the number of partitions cannot be reduced further. This may be sub-optimal for other metrics such as peak memory. This does not change relative orders of two cudagraphable nodes, nor the diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index a0d92409eb777d..d8c1ce5adfebec 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import builtins import contextlib import dataclasses import functools @@ -29,7 +28,7 @@ from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state -from torch._inductor.utils import clear_on_fresh_inductor_cache +from torch._inductor.utils import clear_on_fresh_cache from torch.utils._filelock import FileLock from torch.utils._ordered_set import OrderedSet @@ -376,7 +375,7 @@ def __init__( self.template_out: Optional[str] = None self.ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] - # Whe caching is enabled, the generated code is not dependent on the input nodes names, or + # When caching is enabled, the generated code is not dependent on the input nodes names, or # symbolic sizes names. # However, some of the variables returned by generate_and_load that are computed during the # triton template expansions (code generation) are dependent on those. @@ -495,7 +494,10 @@ def jit_lines(self): argdefs, _, signature, _ = self.args.python_argdefs() triton_meta: dict[str, Any] = { "signature": signature_to_meta( - signature, size_dtype=self.index_dtype, argdefs=argdefs + signature, + size_dtype=self.index_dtype, + argdefs=argdefs, + is_template=True, ), "device": DeviceProperties.create(self.output_node.get_device()), "constants": {}, @@ -1122,7 +1124,7 @@ def kernel_benchmark_extra_args(self) -> list[str]: ] -@functools.lru_cache(None) +@functools.cache def _jinja2_env(): try: import jinja2 @@ -1174,10 +1176,11 @@ def make_key( input_nodes: tuple[ir.IRNode], num_stages: int, num_warps: int, - call_sizes: list[sympy.core.symbol.Symbol], + call_sizes: Sequence[sympy.core.symbol.Symbol], prefix_args: int, suffix_args: int, epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], subgraphs: Optional[list[ir.Buffer]], # has to be none to cache workspace_arg: Optional[WorkspaceArg], # has to be none to cache layout: ir.Layout, @@ -1206,12 +1209,16 @@ def has_flexible_layout() -> bool: return True return False + if epilogue_fn is identity: + assert epilogue_fn_hash is None + epilogue_fn_hash = "identity" + # we do not cache under those conditions right now. if ( has_flexible_layout() or subgraphs is not None or workspace_arg is not None - or epilogue_fn is not identity + or epilogue_fn_hash is None ): return None @@ -1228,6 +1235,7 @@ def has_flexible_layout() -> bool: "layout": layout_key(layout), "num_consumer_groups": num_consumer_groups, "num_buffers_warp_spec": num_buffers_warp_spec, + "epilogue_fn_hash": epilogue_fn_hash, "kwargs": kwargs, } ) @@ -1283,7 +1291,7 @@ def __init__( self.debug = debug self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache() - clear_on_fresh_inductor_cache(self._generated_code_cache) + clear_on_fresh_cache(self._generated_code_cache) # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel # by adding all inputs. self.prologue_loads_all_inputs = prologue_loads_all_inputs @@ -1321,10 +1329,11 @@ def generate_and_load( input_nodes: tuple[ir.IRNode], num_stages: int, num_warps: int, - call_sizes: list[sympy.core.symbol.Symbol], + call_sizes: Sequence[sympy.core.symbol.Symbol], prefix_args: int, suffix_args: int, epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], subgraphs: Optional[list[ir.Buffer]], workspace_arg: Optional[WorkspaceArg], num_consumer_groups: int, @@ -1349,6 +1358,7 @@ def generate_and_load( prefix_args, suffix_args, epilogue_fn, + epilogue_fn_hash, subgraphs, workspace_arg, layout, @@ -1516,9 +1526,10 @@ def generate( # type: ignore[override] prefix_args: int = 0, suffix_args: int = 0, epilogue_fn: Optional[Callable[..., Any]] = identity, + epilogue_fn_hash: Optional[str] = None, subgraphs: Optional[list[ir.Buffer]] = None, mutated_inputs: Optional[list[ir.IRNode]] = None, - call_sizes: Optional[list[sympy.core.symbol.Symbol]] = None, + call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None, workspace_arg: Optional[WorkspaceArg] = None, generate_with_caching=False, **kwargs, @@ -1557,13 +1568,14 @@ def generate( # type: ignore[override] prefix_args, suffix_args, epilogue_fn, + epilogue_fn_hash, subgraphs, workspace_arg, num_consumer_groups, num_buffers_warp_spec, layout, kwargs, - generate_with_caching, + generate_with_caching and self._cache_codegen_enabled_for_template, ) # May happen as result of dev by 0. @@ -1714,7 +1726,7 @@ def to_callable(self): def call_name(self): return f"extern_kernels.{self.name}" - @functools.lru_cache(None) # noqa: B019 + @functools.cache # noqa: B019 def hash_key(self): fn = self.to_callable() parts = [ @@ -1891,7 +1903,7 @@ def output_node(self): assert self.choice.op_overload is not None, ( "Please provide an op_overload to use ir.FallbackKernel" ) - inner = ir.FallbackKernel.create( + inner: ir.IRNode = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) elif self.choice.kernel_creator is not None: @@ -1921,7 +1933,7 @@ def autoheuristic_id(self): return f"extern_{self.choice.name}" -@functools.lru_cache(None) +@functools.cache def get_mm_log_filename() -> Optional[str]: mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) if not mm_file_name: @@ -2040,7 +2052,7 @@ class NoValidChoicesError(RuntimeError): pass -@functools.lru_cache(None) +@functools.cache def get_num_workers() -> int: if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) @@ -2051,6 +2063,15 @@ def get_num_workers() -> int: else os.cpu_count() ) assert cpu_count + + # Divide the number of CPUs by the number of GPUs for distributed workloads + if ( + config.is_fbcode() + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + cpu_count = cpu_count // torch.cuda.device_count() + return cpu_count @@ -2067,7 +2088,7 @@ def create_precompile_key( inputs_key, torch.get_float32_matmul_precision(), ] - + [choice.hash_key() for choice in choices] + + [choice.kernel_hash_key() for choice in choices] ) @@ -2088,6 +2109,39 @@ def create_precompile_key( None, ] +# Args to PreprocessingFunctions +# choices: list of ChoiceCaller objects to preprocess +# Returns: modified list of ChoiceCaller objects +PreprocessingFunction = Callable[[list[ChoiceCaller]], list[ChoiceCaller]] + + +def filter_choices_by_name_regex(choices: list[ChoiceCaller]) -> list[ChoiceCaller]: + """Filter choices based on autotune_choice_name_regex config.""" + if config.test_configs.autotune_choice_name_regex is not None: + return [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_name_regex, + c.name, + ) + ] + return choices + + +def filter_choices_by_desc_regex(choices: list[ChoiceCaller]) -> list[ChoiceCaller]: + """Filter choices based on autotune_choice_desc_regex config.""" + if config.test_configs.autotune_choice_desc_regex is not None: + return [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_desc_regex, + c.description, + ) + ] + return choices + class AlgorithmSelectorCache(PersistentCache): """ @@ -2108,13 +2162,28 @@ def __init__(self, *args, **kwargs) -> None: # first to benchmark it. share a single precompilation function for all lowerings # of a particular key self.precompile_cache: dict[str, Callable[[], None]] = {} + # cache for prescreening results to ensure deterministic candidate selection + self.prescreening_cache: dict[str, OrderedSet[str]] = {} # list of callbacks that are called after benchmarking self.feedback_saver_fns: list[FeedbackFunction] = [] + # list of callbacks that are called to preprocess choices + self.preprocessing_fns: list[PreprocessingFunction] = [] - clear_on_fresh_inductor_cache(self) + self._register_default_preprocessing_fns() + + # registers `self.cache_clear(...)` to be called when a fresh Inductor cache is requested + clear_on_fresh_cache(self) + + def _register_default_preprocessing_fns(self): + """Register default preprocessing functions.""" + # Note: broken out into its own function so that we can avoid clearing + # them (i.e. so we can restore them after clearing user provided ones) + self.add_preprocessing_fn(filter_choices_by_name_regex) + self.add_preprocessing_fn(filter_choices_by_desc_regex) def cache_clear(self) -> None: self.precompile_cache.clear() + self.prescreening_cache.clear() def __call__( self, @@ -2133,6 +2202,10 @@ def __call__( ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller + # Run preprocessing functions on choices + for preprocessing_fn in self.preprocessing_fns: + choices = preprocessing_fn(choices) + # Templates selected with input_gen_fns require specific input data to avoid IMA # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection # TODO(jgong5): support multi-template on CPU @@ -2141,25 +2214,6 @@ def __call__( # TODO - assert that we have not mutating kernels here - if config.test_configs.autotune_choice_name_regex is not None: - choices = [ - c - for c in choices - if re.search( - config.test_configs.autotune_choice_name_regex, - c.name, - ) - ] - if config.test_configs.autotune_choice_desc_regex is not None: - choices = [ - c - for c in choices - if re.search( - config.test_configs.autotune_choice_desc_regex, - c.description, - ) - ] - if mm_file_name := get_mm_log_filename(): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] @@ -2182,20 +2236,48 @@ def __call__( # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. return choices[0].output_node() - @functools.lru_cache(None) - def make_benchmark_fn(): - return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) - inputs_key = create_inputs_key(input_nodes) + # TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking + has_autotuned = False + + def benchmark(choices): + nonlocal has_autotuned + # TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking + has_autotuned = True + counters["inductor"]["select_algorithm_autotune"] += 1 + # TODO(nmacchioni): remove this layer of abstraction + # construct `benchmark_fn` which should pick between in-process and sub-process autotuning + benchmark_fn = self.make_benchmark_fn( + choices, input_nodes, layout, input_gen_fns + ) + # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which + # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds + return benchmark_fn(choices) + def autotune(choices): log.debug("Starting autotuning") + with dynamo_timed( f"{name}_template_autotuning", log_pt2_compile_event=True, dynamo_compile_column_us="compile_time_autotune_time_us", + metadata={ + "autotune_strides": ", ".join( + [str(n.get_stride()) for n in input_nodes] + ), + "autotune_dtypes": ", ".join( + [str(n.get_dtype()) for n in input_nodes] + ), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join( + [str(n.get_layout().offset) for n in input_nodes] + ), + }, ): - return make_benchmark_fn()(choices) + return benchmark(choices) if config.autotune_in_subproc: # Initialize the suprocess pool so it will warmup early. @@ -2212,7 +2294,10 @@ def do_autotuning(choices, precompile_fn): precompile_elapse = time.time() - precompile_start_ts log.debug("Precompilation elapsed time: %.02fs", precompile_elapse) - candidates = self.prescreen_choices(choices) + candidates = self.prescreen_choices( + choices, name, inputs_key, self.prescreening_cache + ) + prescreening_elapse: Optional[float] = None if candidates: prescreening_start_ts = time.time() timings = self.lookup( @@ -2221,7 +2306,9 @@ def do_autotuning(choices, precompile_fn): inputs_key, autotune, ) - choices = self.prune_choices_postscreen(choices, timings) + choices = self.prune_choices_postscreen( + choices, timings, name, inputs_key, self.prescreening_cache + ) prescreening_elapse = time.time() - prescreening_start_ts log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) @@ -2241,16 +2328,18 @@ def do_autotuning(choices, precompile_fn): ): raise NoValidChoicesError - if make_benchmark_fn.cache_info().currsize: - counters["inductor"]["select_algorithm_autotune"] += 1 - if ( - make_benchmark_fn.cache_info().currsize + has_autotuned or log.getEffectiveLevel() == logging.DEBUG or config.trace.log_autotuning_results ): self.log_results( - name, input_nodes, timings, autotune_elapse, precompile_elapse + name, + input_nodes, + timings, + autotune_elapse, + precompile_elapse, + prescreening_elapse, ) def profiler_bench_function(): @@ -2262,9 +2351,7 @@ def profiler_bench_function(): profile_bandwidth_with_do_bench_using_profiling=True, autotune_in_subproc=False, ): - return self.make_benchmark_fn( - choices, input_nodes, layout, input_gen_fns - )(choices) + return benchmark(choices) for feedback_fn in self.feedback_saver_fns: # re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk. @@ -2307,7 +2394,7 @@ def get_timings(): # We take the union of allowed prologue inputs from all choices, # and, within benchmark fusion, don't allow prologue fusion for - # choices which dont support the whole union. + # choices which don't support the whole union. allowed_prologue_inps: OrderedSet[str] = OrderedSet() for c in choices: if isinstance(c, TritonTemplateCaller): @@ -2324,13 +2411,34 @@ def get_timings(): ) timings = do_autotuning(choices, precompile_fn) - if timings == {} or choices[0] not in timings: - return choices[0].output_node() - selected_key = builtins.min(timings, key=timings.__getitem__) - selected_choice = selected_key.output_node() - log.debug("selected choice: %s", str(selected_choice)) - return selected_choice + # if timings is empty, we really have no choice but to return a semi-random + # choice. returning the first `ExternKernelCaller` is probably the safest bet + # in this case, since it will generally be the ATen kernel. if there are no + # `ExternKernelCaller`s to return, then returning the 0th kernel is our next + # best option (ideally we'd fail whenever there is no ATen kernel to fallback + # to, but that's not trivial to figure out) + if timings == {}: + for choice in choices: + if isinstance(choice, ExternKernelCaller): + node = choice.output_node() + log.debug( + "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", + node, + ) + return node + node = choices[0].output_node() + log.debug( + "Autotuning returned empty timings, falling back to first choice: %s", + node, + ) + return node + + # if we got any timings at all, pick the best of those + choice = min(timings, key=timings.__getitem__) + node = choice.output_node() + log.debug("Autotuning selected choice: %s", node) + return node def make_precompile_fn( self, @@ -2381,11 +2489,6 @@ def no_op(*args, **kwargs): log.debug("Found all %d timings in cache, returning no_op", len(timings)) return no_op - if config.search_autotune_cache and not ( - config.max_autotune or config.max_autotune_gemm - ): - return no_op - precompile_key = create_precompile_key(name, inputs_key, choices) if precompile_func := self.precompile_cache.get(precompile_key): log.debug("Precompile function found in cache, returning it") @@ -2434,11 +2537,11 @@ def on_complete(future): seen_choices: OrderedSet[str] = OrderedSet() for c in choices: # Skip choices which we have already issued a precompile - if c.hash_key() in seen_choices: + if c.kernel_hash_key() in seen_choices: log.debug("Skipping already seen choice: %s", c) continue else: - seen_choices.add(c.hash_key()) + seen_choices.add(c.kernel_hash_key()) if hasattr(c, "precompile"): triton_cuda_choice = isinstance(c, TritonTemplateCaller) and isinstance( @@ -2458,7 +2561,7 @@ def on_complete(future): future.add_done_callback(on_complete) futures[future] = c - @functools.lru_cache(None) + @functools.cache @restore_stdout_stderr() def wait_on_futures(): log.debug("Waiting on futures") @@ -2564,11 +2667,11 @@ def benchmark_choice( ) -> float: is_extern = isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller)) benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern) - inpts, output = benchmark_tensors.unpack() + inputs, output = benchmark_tensors.unpack() output.zero_() - result = choice.benchmark(*inpts, out=output) + result = choice.benchmark(*inputs, out=output) device_type = next( - (tensor.device.type for tensor in inpts if is_gpu(tensor.device.type)), + (tensor.device.type for tensor in inputs if is_gpu(tensor.device.type)), "cuda", ) device_interface = get_interface_for_device(device_type) @@ -2704,11 +2807,39 @@ def make_benchmark_fn( @staticmethod def prescreen_choices( choices: list[ChoiceCaller], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], ) -> list[ChoiceCaller]: """ - Add prescreening phase. Motivation is to reduce the number of autotuning needed, - for example, when there are runtime params. + Figure out what choices need to be prescreened before autotuning with runtime + params. + + Prescreening is a process of reducing the number of autotuning for choices with + runtime params via a two stage autotuning process. First, we fix a set of runtime + params (here we use swizzle=2) and run autotuning to get a set of candidates. + Then, we run autotuning again with the candidates and the full set of runtime + params. + + Since have the concept of runtime params, we need to differentiate between + choice's hash_key and choice's kernel_hash_key. The former includes information + like runtime params, while the latter does not. prescreen_cache, if exists, stores + the set of hash_key that should win the prescreening. + + Right now, only CUTLASS choices have runtime params. """ + # Create a cache key for prescreening results + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached prescreening results (prescreen_winners) + if prescreen_key in prescreen_cache: + prescreen_winners = [ + choice + for choice in choices + if choice.hash_key() in prescreen_cache[prescreen_key] + ] + return prescreen_winners + # prescreen cutlass from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2737,40 +2868,89 @@ def prescreen_choices( def prune_choices_postscreen( choices: list[ChoiceCaller], candidate_timings: dict[ChoiceCaller, float], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], ) -> list[ChoiceCaller]: """ Prune the choices after prescreening. """ from .codegen.cuda.cuda_kernel import CUDATemplateCaller - if len(candidate_timings) < 10: - return [] + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached postscreen results + if prescreen_key in prescreen_cache: + # candidate_timings are from choices that have won prescreening already + winner_kernel_hashes = [ + candidate.kernel_hash_key() for candidate in candidate_timings + ] + + pruned_choices = [ + choice + for choice in choices + if not isinstance(choice, CUDATemplateCaller) + or choice.kernel_hash_key() in winner_kernel_hashes + ] + return pruned_choices log.debug("Before pruning using prescreening timings, %d choices", len(choices)) sorted_candidates = sorted( candidate_timings.keys(), key=lambda choice: candidate_timings[choice] ) + + # Print prescreening timings + if ( + candidate_timings + and PRINT_AUTOTUNE + and config.autotune_num_choices_displayed != 0 + ): + n = config.autotune_num_choices_displayed + top_k = sorted_candidates[:n] + best = top_k[0] + best_time = candidate_timings[best] + + lines = ["PRESCREENING CANDIDATE TIMINGS"] + for choice in top_k: + result = candidate_timings[choice] + if result: + lines.append( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {choice.description}" + ) + else: + lines.append( + f" {choice.name} {result:.4f} ms " + ) + + log.info("\n".join(lines)) num_to_keep = max(int(math.sqrt(len(choices)) / 4), 8) # prune choices based on prescreening timings candidates_to_prune = OrderedSet( - candidate.hash_key() for candidate in sorted_candidates[num_to_keep:] + candidate.kernel_hash_key() for candidate in sorted_candidates[num_to_keep:] ) + winner_hashes: OrderedSet[str] = OrderedSet() for candidate in sorted_candidates[:num_to_keep]: if candidate_timings[candidate] == float("inf"): - candidates_to_prune.add(candidate.hash_key()) + candidates_to_prune.add(candidate.kernel_hash_key()) else: + winner_hashes.add(candidate.hash_key()) if isinstance(candidate, CUDATemplateCaller): candidate.bmreq.ensure_dll_loaded() - choices = [ + pruned_choices = [ choice for choice in choices - if choice.hash_key() not in candidates_to_prune # type: ignore[attr-defined] + if choice.kernel_hash_key() not in candidates_to_prune # type: ignore[attr-defined] ] - log.debug("After pruning using prescreening timings, %d choices", len(choices)) - return choices + # Cache the hash_key of winners of prescreening + prescreen_cache[prescreen_key] = winner_hashes + + log.debug( + "After pruning using prescreening timings, %d choices", len(pruned_choices) + ) + return pruned_choices @staticmethod def log_results( @@ -2779,6 +2959,7 @@ def log_results( timings: dict[ChoiceCaller, float], elapse: float, precompile_elapse: float, + prescreening_elapse: Optional[float] = None, ): V.debug.log_autotuning_results( name, input_nodes, timings, elapse, precompile_elapse @@ -2867,9 +3048,16 @@ def get_choice_info(choice): autotune_type_str = ( "SubProcess" if config.autotune_in_subproc else "SingleProcess" ) + prescreening_msg = ( + f" and {prescreening_elapse:.4f} seconds prescreening" + if prescreening_elapse is not None + else "" + ) sys.stderr.write( f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" - f" seconds precompiling for {len(timings)} choices\n" + f" seconds precompiling for {len(timings)} choices" + + prescreening_msg + + "\n" ) @staticmethod @@ -2957,14 +3145,34 @@ def key_of(node): def add_feedback_saver(self, fn: FeedbackFunction): self.feedback_saver_fns.append(fn) + def add_preprocessing_fn(self, fn: PreprocessingFunction): + self.preprocessing_fns.append(fn) + + def clear_preprocessing_fns(self, clear_defaults: bool = False): + """Clear preprocessing functions. + + Args: + clear_defaults: If True, clears all functions including defaults. + If False, clears only user-added functions and re-registers defaults. + """ + self.preprocessing_fns.clear() + if not clear_defaults: + self._register_default_preprocessing_fns() + _ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None -def autotune_select_algorithm(*args, **kwargs): +def get_algorithm_selector_cache() -> AlgorithmSelectorCache: + """Get the global algorithm selector cache, creating it if it doesn't exist.""" global _ALGORITHM_SELECTOR_CACHE if _ALGORITHM_SELECTOR_CACHE is None: _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + return _ALGORITHM_SELECTOR_CACHE + + +def autotune_select_algorithm(*args, **kwargs): + cache = get_algorithm_selector_cache() if "return_multi_template" not in kwargs: kwargs["return_multi_template"] = ( @@ -2974,16 +3182,49 @@ def autotune_select_algorithm(*args, **kwargs): if "precompilation_timeout_seconds" not in kwargs: kwargs["precompilation_timeout_seconds"] = config.precompilation_timeout_seconds - return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + return cache(*args, **kwargs) def add_feedback_saver( fn: FeedbackFunction, ): - global _ALGORITHM_SELECTOR_CACHE - if _ALGORITHM_SELECTOR_CACHE is None: - _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() - _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn) + cache = get_algorithm_selector_cache() + cache.add_feedback_saver(fn) + + +def add_preprocessing_fn( + fn: PreprocessingFunction, +): + """Add a preprocessing function to be applied to choices before autotuning. + + Preprocessing functions are called sequentially in the order they were registered, + with each function receiving the output of the previous one. They can filter, + reorder, transform, or modify the list of choices in any way. + + Args: + fn: A function that takes a list of ChoiceCaller objects and returns + a modified list of ChoiceCaller objects. + + Example: + def my_filter(choices): + # Filter out choices with certain names + return [c for c in choices if 'slow' not in c.name.lower()] + + add_preprocessing_fn(my_filter) + """ + cache = get_algorithm_selector_cache() + cache.add_preprocessing_fn(fn) + + +def clear_preprocessing_fns(clear_defaults: bool = False): + """Clear preprocessing functions at module level. + + Args: + clear_defaults: If True, clears all functions including defaults. + If False, clears only user-added functions and re-registers defaults. + """ + cache = get_algorithm_selector_cache() + cache.clear_preprocessing_fns(clear_defaults) def realize_inputs(*args): diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index dac88b82cc39b1..0aa6d30d2f8799 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -8,11 +8,7 @@ import sympy from sympy import Expr -from torch.fx.experimental.symbolic_shapes import ( - free_unbacked_symbols, - has_free_unbacked_symbols, - ShapeEnv, -) +from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols, ShapeEnv from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.symbol import symbol_is_type, SymT @@ -32,7 +28,7 @@ log = logging.getLogger(__name__) -def evaluate_expr( +def statically_known_true( shape_env: ShapeEnv, expr: Union[sympy.Basic, bool], axioms: Optional[tuple[sympy.Expr]] = None, @@ -308,33 +304,16 @@ def prune(index): return [x for x in sizes if x is not None], reindex, prune # Note - [On Statically Known] - # - # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system - # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was - # true, we add a guard and return True, otherwise, False. - # - # def maybe_guard_foo(args): - # if size_hinted_check(args): - # return False # No guard, no optim - # guard(args) # Make a guard - # return True # Safe to apply optimization - # - # The prior system incurred a guard, and green lit an optimization. - # - # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the - # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we - # return False. - # - # def maybe_guard_foo(args): - # if all_static(args): - # return True # Safe to apply optimization - # else: - # return False # No guard, no optim - - # See Note - [On Statically Known] - - def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool: - return evaluate_expr(self.shape_env, expr) + # The statically_known_* family of functions below NEVER guard, they could return True if the + # asked questions can be answered without guarding otherwise they return False. + # Those are similar to statically_known_true in symbolic_shapes.py but operate on sympy + # expressions instead of symnodes. + def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool: + """ + Returns true if an expression is always true (symbolically or via guards), + false otherwise. Never add guards, or throw data dependent errors. + """ + return statically_known_true(self.shape_env, expr) def statically_known_equals( self, left: Union[Expr, int], right: Union[Expr, int] @@ -342,10 +321,11 @@ def statically_known_equals( """ Returns a bool indicating if it is sound to optimize as if left and right are equal. """ - return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] + return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type] - # See Note - [On Statically Known] - def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool: + def statically_known_list_equals( + self, left: Sequence[Expr], right: Sequence[Expr] + ) -> bool: """ Returns a bool indicating if it is sound to optimize as if left and right lists are equal. """ @@ -353,106 +333,109 @@ def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> b self.statically_known_equals(l, r) for l, r in zip(left, right) ) - # See Note - [On Statically Known] def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool: """ Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. """ expr = left <= right - return self.is_expr_static_and_true(expr) + return self.statically_known_true(expr) - # See Note - [On Statically Known] def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool: """ Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. """ expr = left >= right - return self.is_expr_static_and_true(expr) + return self.statically_known_true(expr) - # See Note - [On Statically Known] def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool: """ Returns a bool indicating if it is sound to optimize as if left is less than right. """ expr = left < right - return self.is_expr_static_and_true(expr) + return self.statically_known_true(expr) - # See Note - [On Statically Known] def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool: """ Returns a bool indicating if it is sound to optimize as if left is greater than right. """ expr = left > right - return self.is_expr_static_and_true(expr) + return self.statically_known_true(expr) - # See Note - [On Statically Known] def statically_known_multiple_of( self, numerator: Expr, denominator: Union[Expr, int] ) -> bool: """ Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. """ - if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator): + # The reason we skip unbacked here is that we want to avoid the cost of trying to eval this symbolically. + if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols( + denominator + ): return False expr = sympy.Eq(numerator % denominator, 0) - return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + return self.statically_known_true(expr) # type: ignore[arg-type] - # See Note - [On Statically Known] def statically_known_power_of_2(self, expr: Expr) -> bool: """ Returns a bool indicating if x is known to be a power of 2. """ return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) - # The guard functions require you to ALREADY KNOW that a particular - # condition holds. If you don't know (you want to guard on an expression - # being a particular value, and then get access to that value), use - # the evaluate functions. - - def guard_equals(self, left: Expr, right: Expr) -> Expr: - if isinstance(left, Expr): - left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] - if isinstance(right, Expr): - right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + # The expect/check functions require you to ALREADY KNOW that a particular + # condition holds. They are similar to expect_true in symbolic_shapes.py and + # torch.check but operates on sympy expressions instead of symnodes. + def expect_true(self, expr: Expr) -> bool: + """ + Use it when you already know that expr is true or should be true and want to + ensure that guards/runtime assertions are in place to ensure this in compiled + function. Unlike check, this WON'T raise an error if expr isn't actually true. + check Note [expect_true]. + """ + if not self.statically_known_true(expr): + return self.shape_env.guard_or_defer_runtime_assert( + expr, "sizevars.expect_true" + ) + return True - expr = sympy.Eq(left, right) - static_expr = self.shape_env._maybe_evaluate_static(expr) + def check(self, expr: Expr) -> None: + """ + Use it when you already know that expr is true or should be true and want to + ensure that guards/runtime assertions are in place to ensure this in compiled + function. Unlike expect_true, this WILL raise an error if expr isn't actually true. + check Note [expect_true]. + """ + expr = sympy_subs(expr, self.inv_precomputed_replacements) + assert self.expect_true(expr) - if static_expr is not None: - assert bool(static_expr) - return left + def check_equals(self, left: Expr, right: Expr) -> None: + """ + check(sympy.Eq(left, right)). - assert self.shape_env.defer_runtime_assert(expr, "guard_equals") + """ + self.check(sympy.Eq(left, right)) return left - def guard_leq(self, left: Expr, right: Expr) -> None: - return self.guard_lt(left, right + 1) + def check_equals_and_simplify(self, left: Expr, right: Expr) -> Expr: + """ + check(sympy.Eq(left, right)) and returns left after applying + inv_precomputed_replacements. + """ + self.check(sympy.Eq(left, right)) + return sympy_subs(left, self.inv_precomputed_replacements) - def guard_lt(self, left: Expr, right: Expr) -> None: - expr = sympy.Lt(left, right) - static_expr = self.shape_env._maybe_evaluate_static(expr) + def check_leq(self, left: Expr, right: Expr) -> None: + self.check(sympy.Le(left, right)) - if static_expr is not None: - assert bool(static_expr) - return + def check_lt(self, left: Expr, right: Expr) -> None: + self.check(sympy.Lt(left, right)) - assert self.shape_env.defer_runtime_assert(expr, "guard_lt") + # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py + # but operates on sympy expressions instead of symnodes. see Note [guard_or_]. + def guard_or_false(self, left): + return self.evaluate_expr(left, fallback_value=False) - def guarded_order(self, seq): - """ - Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. - """ - seq = [*map(self.remove_precomputed_replacements, seq)] - seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)] - seq.sort() - order = [-1] * len(seq) - last_var = None - for new_index, (_, orig_index, var) in enumerate(seq): - order[orig_index] = new_index - if last_var is not None: - self.guard_leq(last_var, var) - last_var = var - return order + def guard_or_true(self, left): + return self.evaluate_expr(left, fallback_value=True) # The evaluate functions evaluate some symbolic sympy expression # (NB: not necessarily an Expr) and return what the concrete result @@ -466,10 +449,13 @@ def evaluate_expr( self, left: Union[Expr, sympy.logic.boolalg.Boolean], size_oblivious: bool = False, + fallback_value: Optional[bool] = None, ) -> bool: assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) return self.shape_env.evaluate_expr( - sympy.sympify(left), size_oblivious=size_oblivious + sympy.sympify(left), + size_oblivious=size_oblivious, + fallback_value=fallback_value, ) def evaluate_min(self, left: Expr, right: Expr) -> Expr: @@ -479,8 +465,8 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr: if isinstance(right, Expr): right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] try: - lv = self.size_hint(left) - rv = self.size_hint(right) + lv = self.size_hint_or_throw(left) + rv = self.size_hint_or_throw(right) except TypeError: # unbacked symints if left == right or self.statically_known_leq(left, right): return left @@ -495,10 +481,10 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr: f"evaluate_min({left}, {right}) with unbacked symints" ) from None if lv <= rv: - self.guard_leq(left, right) + self.check_leq(left, right) return left else: - self.guard_leq(right, left) + self.check_leq(right, left) return right def evaluate_max(self, left: Expr, right: Expr) -> Expr: @@ -508,15 +494,24 @@ def evaluate_max(self, left: Expr, right: Expr) -> Expr: min_val = self.evaluate_min(left, right) return right if min_val is left else left - def evaluate_static_shape(self, left: Union[Expr, int]) -> int: - if isinstance(left, int): - return left - right = self.size_hint(left) - self.guard_equals(left, sympy.Integer(right)) - return int(right) + def guard_int(self, expr: Union[Expr, int]) -> int: + """ + Similar to guard_int in symbolic_shapes.py, except this function works with SymPy + expressions instead of SymNodes. It extracts the value represented by expr from shapeEnv + and specialize the compiled graph on it. Raises an error if the result cannot be + determined due to unhinted or unbacked symbols. + """ + if isinstance(expr, int): + return expr + val = self.size_hint_or_throw(expr) + self.check_equals(expr, sympy.Integer(val)) + return int(val) - def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]: - return [self.evaluate_static_shape(x) for x in left] + def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]: + """ + Apply guard_int on a sequence of inputs. + """ + return [self.guard_int(x) for x in left] def remove_precomputed_replacements(self, expr: Expr) -> Expr: if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] @@ -563,9 +558,17 @@ def size_hint( log.debug("failed on: %s", out) raise + def size_hint_or_throw(self, expr: Union[Expr, int]) -> int: + out = self.symbolic_hint(expr) + try: + return int(out) + except Exception: + log.debug("failed on: %s", out, exc_info=True) + raise + def size_hints( self, - exprs: Iterable[Expr], + exprs: Iterable[Union[Expr, int]], *, fallback: Optional[int] = None, ) -> tuple[int, ...]: @@ -709,7 +712,7 @@ def atomically_apply_size_hint( } return expr.subs(size_dict) - def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr: + def offset_var(self, index: Expr, vars: Sequence[sympy.Symbol]) -> Expr: """Extract offset part of an indexing expression""" index = self.simplify(index) return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) @@ -726,7 +729,7 @@ def stride_hints( result = [] for s in self.stride_vars(index, vars, support_vars): try: - result.append(self.size_hint(s)) + result.append(self.size_hint_or_throw(s)) except TypeError: result.append(0) return result @@ -776,11 +779,11 @@ def _check_args(x, div, mod, is_first): return False if is_first: - # first ModularIndexing should conatins a nested ModularIndex + # first ModularIndexing should contains a nested ModularIndex if not isinstance(x, ModularIndexing): return False else: - # second ModularIndexing should constains a non-negative + # second ModularIndexing should contains a non-negative # symbol if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( x, 0 @@ -811,7 +814,7 @@ def expand_floor_div( ) -> Union[bool, tuple[sympy.Expr, sympy.Expr]]: """ Expand the FloorDiv to the entire expression so that the expression may - be simplfied. + be simplified. E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables x1, x2, index expression 'x1 * 2b + x2' can be easily combined. diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index 93af8cc3209d98..1f58f81b17f61c 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -74,7 +74,7 @@ def save( key = cache_info.aot_autograd_artifacts[0] if format == "binary": - # cant assert that it is a file since it might not exist yet + # can't assert that it is a file since it might not exist yet assert not os.path.isdir(path) from torch.utils._appending_byte_serializer import BytesWriter @@ -118,7 +118,7 @@ def load( ) -> CompiledArtifact: with dynamo_timed("CompiledArtifact.load"): if format == "binary": - # cant assert that it is a file since it might not exist yet + # can't assert that it is a file since it might not exist yet assert not os.path.isdir(path) with open(path, "rb") as file: artifacts = file.read() @@ -155,7 +155,12 @@ def load( ) entry = AOTAutogradCache._lookup( - key, local=True, remote=False, args=[], cache_info={} + key, + local=True, + remote=False, + args=[], + cache_info={}, + aot_config=None, ) assert entry is not None diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 84bd26ed1dd3a6..dfd37523a37027 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -2,10 +2,12 @@ import dataclasses import itertools +import math from functools import partial from threading import Lock from typing import Any, Callable, TYPE_CHECKING +import torch from torch.utils._ordered_set import OrderedSet from . import config @@ -19,6 +21,7 @@ from triton import Config as TritonConfig +# Gemm Configs @dataclasses.dataclass class BaseConfig: """ @@ -44,6 +47,36 @@ class GemmConfig(BaseConfig): ConvConfig = BaseConfig +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward, backward and flex decode will use this + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes @dataclasses.dataclass class ROCmGemmConfig(GemmConfig): """ @@ -66,6 +99,28 @@ class ROCmConvConfig(ConvConfig): kpack: int = 2 +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + class BaseHeuristicSingleton(type): """ Thread-safe implementation of single to be used in the config heuristic subclasses @@ -312,6 +367,53 @@ def __init__(self) -> None: ConvConfig(256, 64, 32, 2, 8), ] + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, num_stages, num_warps) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + def _finalize_mm_configs( self, configs: list[BaseConfig], @@ -418,6 +520,40 @@ def _scale_mm_configs( return scaled_configs + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + import torch + + pruned_configs = [] + for gemm_config in configs: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + sm_available = props.shared_memory_per_block_optin # type: ignore[attr-defined] + NUM_REG = 255 + + acc_regs = math.ceil( + gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32) + ) + + shared_mem_accum = dtype_size * ( + gemm_config.block_m * gemm_config.block_k + + gemm_config.block_n * gemm_config.block_k + ) + + # Will use more shared memory than available + if shared_mem_accum * gemm_config.num_stages > sm_available: + continue + # Lower bound for register spillage, if exceeds the kernel will certainly spill + elif acc_regs > NUM_REG: + continue + + pruned_configs.append(gemm_config) + + return pruned_configs + def preprocess_mm_configs( self, m: int, @@ -427,10 +563,15 @@ def preprocess_mm_configs( has_int8_tensor: bool = False, scale: int = 1, exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, + dtype_size: int = 0, ) -> Generator[TritonConfig, None, None]: scaled_configs = self._scale_mm_configs( m, n, k, configs, scale, has_int8_tensor, exclude ) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + assert dtype_size > 0, "dtype_size must be provided for exhaustive search" + scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) return self._finalize_mm_configs(scaled_configs) def triton_config( @@ -461,7 +602,17 @@ def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: return partial(self.preprocess_mm_configs, configs=mm_configs) def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - return partial(self.preprocess_mm_configs, configs=self.persistent_mm_configs) + persistent_mm_configs = ( + self.exhaustive_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.persistent_mm_configs + ) + + # num_warps=2 not safe for TMA + persistent_mm_configs = [ + config for config in persistent_mm_configs if config.num_warps != 2 + ] + return partial(self.preprocess_mm_configs, configs=persistent_mm_configs) def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs) @@ -479,16 +630,202 @@ def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]] def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: return partial(self.preprocess_mm_configs, configs=self.conv_configs) + # Flex attn helpers + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + class CPUConfigHeuristic(BaseConfigHeuristic): pass class CUDAConfigHeuristic(BaseConfigHeuristic): - pass + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + if capability >= (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexConfig(16, 16, 1, 4) + elif head_dim <= 256 and capability >= (9, 0): # H100 + if head_dim == 64: + default_config = FlexConfig(64, 64, 3, 4) + elif head_dim == 128: + default_config = FlexConfig(64, 128, 3, 8) + else: + default_config = FlexConfig(64, 64, 2, 4) + elif capability >= (8, 0): # A100 + if head_dim == 64: + default_config = FlexConfig(32, 128, 3, 4) + elif head_dim == 128: + # SM86/89 have smaller shared memory sizes + num_stages = 3 if capability[1] == 0 else 2 + default_config = FlexConfig(64, 64, num_stages, 4) + else: + default_config = FlexConfig(64, 64, 2, 4) + else: # modest hardware or extremely large head_dim + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability >= (9, 0): # sm_90+ + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. + """ + def __init__(self) -> None: super().__init__() @@ -575,6 +912,73 @@ def __init__(self) -> None: for kpack in [2] ] + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, num_stages, num_warps, mfma, wpeu) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + def _filter_configs( self, configs: list[BaseConfig], new_num_stages: int ) -> list[BaseConfig]: @@ -700,6 +1104,77 @@ def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: ) return partial(self.preprocess_mm_configs, configs=filtered_configs) + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexConfig(16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexConfig(64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexConfig(64, 128, 1, 8) + else: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + class XPUConfigHeuristic(BaseConfigHeuristic): - pass + """ + Placeholder child class for XPU specific overrides. + """ diff --git a/torch/_inductor/test_case.py b/torch/_inductor/test_case.py index e59ba6406773f0..227e369c6ac2bd 100644 --- a/torch/_inductor/test_case.py +++ b/torch/_inductor/test_case.py @@ -8,7 +8,7 @@ ) from torch._functorch import config as functorch_config from torch._inductor import config -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: @@ -41,7 +41,7 @@ def setUp(self) -> None: os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" and os.environ.get("TORCH_COMPILE_DEBUG") != "1" ): - self._inductor_test_stack.enter_context(fresh_inductor_cache()) + self._inductor_test_stack.enter_context(fresh_cache()) def tearDown(self) -> None: super().tearDown() diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py new file mode 100644 index 00000000000000..4a1febe08e993a --- /dev/null +++ b/torch/_inductor/tiling_utils.py @@ -0,0 +1,764 @@ +import dataclasses +import functools +import itertools +import sys +from collections import Counter, defaultdict +from collections.abc import Iterable, Iterator +from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.dependencies import index_vars_no_squeeze +from torch._inductor.utils import sympy_product, sympy_subs +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import Identity +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .virtualized import V + + +T = TypeVar("T") +U = TypeVar("U") + + +Split = tuple[sympy.Expr, ...] +VarsAndRanges = tuple[list[sympy.Symbol], list[sympy.Expr]] + + +loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling") +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + + +if TYPE_CHECKING: + from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode + + +def solve_for_zero(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Given an expr with a single free symbol, solve for a constant relation that would make + this expression 0. + """ + if expr.is_constant(): + return None + elif isinstance(expr, FloorDiv): + return None + + assert len(expr.free_symbols) == 1 + free_symbol = next(iter(expr.free_symbols)) + if isinstance(expr, ModularIndexing): + out = try_solve(sympy.Eq(expr.args[0], expr.args[2]), free_symbol) + else: + out = try_solve(sympy.Eq(expr, 0), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + +def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Giving an expr with a single free symbol, try to find a tiling that would + make the expression coalesced with respect to that symbol. + + Tiling an expression `x` by `y` means that the expression will now be indexed + by both the original (x) and by (x * y). So we are looking for a + multiplicative factor that will make ((x + 1) * y) - (x * y) == 1. + + To simplify things for sympy, we'll try just x * y == 1, check x(1) and x(0). + """ + + if len(expr.free_symbols) == 0: + return None + + free_symbol = next(iter(expr.free_symbols)) + + def _solve_simple_expr(expr: sympy.Expr) -> Optional[sympy.Expr]: + assert not expr.has(ModularIndexing) and not expr.has(FloorDiv) + if len(expr.free_symbols) != 1: + return None + + out = try_solve(sympy.Eq(expr, 1), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + # Sympy solving is very limited with ModularIndexing and FloorDiv, + # but good otherwise. + if not expr.has(ModularIndexing) and not expr.has(FloorDiv): + return _solve_simple_expr(expr) + + required_values = [] + eq_1_expressions = [] + + # very piecemeal solution if ModularIndexing or FloorDiv involved. + # Look for terms we'll try to make 0, and then other terms we'll try to make 1. + # Expand as needed. + for arg in sympy.Add.make_args(expr): + # Try to make mul terms 0 + if isinstance(arg, sympy.Mul): + seen = False + # TODO - only need one of these to be solvable to zero + # + for mul_arg in arg.args: + out = solve_for_zero(mul_arg) + if out is None: + continue + + assert out.is_constant() + seen = True + required_values.append(out) + + if not seen: + return None + else: + eq_1_expressions.append(arg) + + if not eq_1_expressions: + return None + + eq_1_expr = sum(eq_1_expressions) + + def indexing_div_rep( + x: sympy.Expr, + y: sympy.Expr, + z: Optional[sympy.Expr] = None, + ) -> sympy.Expr: + return x / y + + # For the purposes of tiling/coalesced access, approximate ModularIndexing and FloorDiv + # then check later + eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace( + FloorDiv, indexing_div_rep + ) + + out = _solve_simple_expr(eq_1_expr_simplified) + # since we approximated FloorDiv/ModularIndexing, double check here + if not out or not (sympy_subs(eq_1_expr, {free_symbol: out})) == 1: + return None + + required_values.append(out) + + if len(OrderedSet(required_values)) == 1: + return required_values[0] + + return None + + +def find_coalesced_var( + index: sympy.Expr, var_ranges: dict[sympy.Expr, int] +) -> Optional[sympy.Expr]: + """ + Try to find the symbol which coalesces this index + """ + top_level_terms = sympy.Add.make_args(index) + for v in var_ranges: + if v in top_level_terms: + return v + + # Approximate analysis by evaluating at 1 and 0 + variables: dict[sympy.Symbol, int] = {} + for v in index.free_symbols: + if v in var_ranges: + variables[v] = 0 + else: + variables[v] = get_hint(v) + + zero_index = sympy_subs(index, variables) + for v in var_ranges.keys(): + variables[v] = 1 + try: + new_val = sympy_subs(index, variables) + except ZeroDivisionError: + loop_tiling_log.info("zero division error %s %s", index, variables) + continue + if new_val - zero_index == 1: + variables[v] = 2 + # in some more complex expressions, 0->1 will be coalesced, + # but not 1->2 + if (sympy_subs(index, variables) - new_val) == 1: + return v + variables[v] = 0 + + return None + + +@dataclasses.dataclass(frozen=True) +class FusedNormalizedReadsWrites: + """ + Normalized reads and writes for nodes in the same FusedSchedulerNode. + """ + + index_vars: OrderedSet[sympy.Symbol] + reduce_vars: OrderedSet[sympy.Symbol] + reads: dict[sympy.Expr, OrderedSet[str]] + writes: dict[sympy.Expr, OrderedSet[str]] + var_ranges: dict[sympy.Symbol, int] + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[True], +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: ... + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[False] = False, +) -> tuple[VarsAndRanges, VarsAndRanges]: ... + + +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: bool = False, +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: + if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator] + i = len(n._body.sizes[0]) - 1 + prod = 1 + while i >= 0: + prod *= n._body.sizes[0][i] + if prod == red_numel: + break + i -= 1 + + if i >= 0: + pw_splits = n._body.sizes[0][0:i] + iter_vars = n._body.iter_vars[0:i] + + red_splits = n._body.sizes[0][i:] + red_vars = n._body.iter_vars[i:] + return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value] + + if none_if_not_divisible: + return None + else: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + +class NodeSplitGetter: + """ + Finds a Pointwise, Reduction Split that compatible with all nodes in a SchedulerNode. + """ + + def __init__( + self, + node: Union["FusedSchedulerNode", "SchedulerNode"], + ): + self.node = node + self.pointwise_numel: sympy.Expr = node.group[1][0] + self.red_numel: sympy.Expr = node.group[1][1] + + self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet) + + self.reduction_split: Split = () + self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet() + + fused_group = node.group[1] + for n in reversed(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + # if we can't split the pw ranges into a (pw, red) split, + # dont add as a split option, but do make sure we check that this size + # is splittable + maybe_splits = get_pw_red_splits( + n, self.pointwise_numel, self.red_numel, none_if_not_divisible=True + ) + if maybe_splits is None: + self.all_node_sizes.add(n._body.sizes) + continue + + (_, n_pw_splits), (_, n_red_splits) = maybe_splits + + # fill in reduction size + n_pw_splits, n_red_splits = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + fused_group, (n_pw_splits, n_red_splits), self.red_numel + ) + ) + + self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits)) + + # initially, we are just going to do a single reduction split since + # reduction tiling is off by default. even if we miss a reduction split, + # we can recover it in the split var analysis. + # TODO: an earlier version for this code tried to iteratively try the maximum number + # of split vars, by iterating over both pointwise and reduction. but not worth + # the complexity yet. + + if n_red_splits != (): + self.reduction_split = (sympy_product(n_red_splits),) + + n_size = (tuple(n_pw_splits), tuple(n_red_splits)) + self.all_node_sizes.add(n_size) + + self.seen_pw_splits: OrderedSet[Split] = OrderedSet() + + def get_node_splits(self) -> tuple[Split, Split]: + """ + Get a compatible pointwise, reduction split of the node + """ + + if len(self.all_node_sizes) == 1: + return next(iter(self.all_node_sizes)) + + max_pw_split = max(self.pw_split_options.keys()) + for pw_split_len in range(max_pw_split, 0, -1): + for pw_split in self.pw_split_options[pw_split_len]: + if out := self.try_split(pw_split, self.reduction_split): + return out + + # combine dims for next round + for pw_split in self.pw_split_options[pw_split_len]: + for i in range(len(pw_split) - 1): + new_split = tuple( + pw_split[0:i] + + (sympy_product(pw_split[i : i + 2]),) + + pw_split[i + 2 :] + ) + self.pw_split_options[len(new_split)].add(new_split) + + # if for whatever reason we couldn't split above, return default split + return ((self.pointwise_numel,), (self.red_numel,)) + + def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]: + """ + See if this split is compatible, and potentially returning a longer split + than the input. + """ + + from torch._inductor.codegen.simd import CantSplit, SIMDKernel + + if pw in self.seen_pw_splits: + return None + self.seen_pw_splits.add(pw) + + for n_pw, n_red in self.all_node_sizes: + try: + groups = pw + red + lengths = (n_pw, n_red) + splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths) + except CantSplit: + return None + + assert len(getters) == 2 + pw_group_splits = splits[: len(pw)] + # if we had to divide a variable into two to do this split, + # then lets try the larger, induced split. + # e.g. splitting (12, 2) into (2, 12) will split the first var into: + # (2, 6) and produce an overall split of (2, 6, 2) + flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits)) + if flattened_pw_splits != pw: + if out := self.try_split(flattened_pw_splits, red): + return out + + return pw, red + + +if sys.version_info >= (3, 10): + # On Python 3.10+ we can use zip(strict=True) + zip_equal = functools.partial(zip, strict=True) +else: + # Fallback for older versions + def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]: + """ + Zip two iterables, raising ValueError if their lengths differ. + """ + if len(it1) != len(it2): + raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}") + return zip(it1, it2) + + +def apply_var_mapping( + iter_vars: list[sympy.Symbol], + red_vars: list[sympy.Symbol], + norm_pw_vars: list[sympy.Symbol], + norm_red_vars: list[sympy.Symbol], + new_ranges: list[list[sympy.Expr]], + return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]], +) -> dict[sympy.Symbol, sympy.Expr]: + """Maps original variables to expressions using normalized variables.""" + + # the output of split_iteration_range is a new_ranges, return_getters_groups + # new_ranges is a flattened list of ranges corresponding to the new pw and red vars + # for example, taking in pw vars of range (6, 6) to normalized range [36], + # new_ranges would be [[6, 6]] + # There is a return_getter callable for each input iter_var and red_vars. + # if you flatten out all of the ranges, and create a variable for each index, + # then applying the flattening vars to the callables in return_getters_groups + # gives you the mapping from input vars -> flattened vars. + # From there, we can compute the output, normalized variables. + # For instance [6, 6] corresponding to flat vars v0, v1 will be + # v0 + 6 * v1 + + # Create flattened iteration variables + num_vars = sum(len(s) for s in new_ranges) + flat_vars = sympy.symbols(f"v_0:{num_vars}") + count = 0 + + if len(iter_vars) == 0 and len(red_vars) == 0: + return {} + + assert len(new_ranges) == len(norm_pw_vars + norm_red_vars) + apply_groups = [] + for group in return_getters_groups: + apply_groups.append([g(flat_vars) for g in group]) + + iter_vars_to_flat_vars = {} + for i, (group, var_group) in enumerate( + zip_equal(apply_groups, ((iter_vars, red_vars))) + ): + # if the node has sizes (p0, 1) and the fused node is (p0, r0) + # the reduction var gets filled in for split_iteration_range + if len(group) != len(var_group): + assert i == 1 + assert len(var_group) == 0 + continue + + iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)}) + + count = 0 + flat_vars_to_new_vars = {} + for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars): + range_vars = [] + for i in range(len(new_range)): + range_vars.append(flat_vars[count]) + count += 1 + + prod = 1 + for i in range(len(new_range) - 1, -1, -1): + flat_vars_to_new_vars[range_vars[i]] = new_var * prod + prod = new_range[i] * prod + + return { + k: sympy_subs(v, flat_vars_to_new_vars) + for k, v in iter_vars_to_flat_vars.items() + } + + +def extract_normalized_read_writes( + node: Union["FusedSchedulerNode", "SchedulerNode"], +) -> Optional[FusedNormalizedReadsWrites]: + """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node.""" + reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + all_output_names = node.get_buffer_names() + op_names = node.get_operation_names() + outputs: OrderedSet[str] = OrderedSet() + removed_buffers: OrderedSet[str] = OrderedSet() + for buf_name in all_output_names: + if V.graph.scheduler.can_buffer_be_removed_through_fusion(buf_name, op_names): + removed_buffers.add(buf_name) + else: + outputs.add(buf_name) + + inputs = OrderedSet( + dep.name for dep in node.read_writes.reads if dep.name not in removed_buffers + ) + + pointwise_numel: sympy.Expr = node.group[1][0] + red_numel: sympy.Expr = node.group[1][1] + + # TODO - a few dynamic shapes issues to resolve + if any( + (isinstance(var, sympy.Expr) and not var.is_constant()) + for var in (pointwise_numel, red_numel) + ): + return None + + pw_splits, red_splits = NodeSplitGetter(node).get_node_splits() + + # lets use different prefix (`n`) to distinguish + (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze( + pw_splits, red_splits, prefix="n" + ) + node = node + + for n in list(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + body = n._body + + # TODO - not handled well. indirect loads will not be coalesced, + # need to account for that in analysis. + if body.indirect_vars: + return None + + n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + # TODO - will the names for all the inputs/outputs accurately + # reflect mutation, or do I need to remap with mutation_real_name + for inp in inputs: + for expr in body.get_all_read_expr(inp): + n_reads[expr].add(inp) + + for out in outputs: + for expr in body.get_all_write_expr(out): + n_writes[expr].add(out) + + if not n_reads and not n_writes: + continue + + (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits( + n, pointwise_numel, red_numel + ) + + groups = pw_splits + red_splits + lengths = (n_pw_splits, (n_red_splits)) + lengths = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + groups, lengths, red_numel + ) + ) + new_ranges, return_getters_groups = ( + torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges( + groups, lengths + ) + ) + var_map = apply_var_mapping( + iter_vars, + red_vars, + norm_pw_vars, + norm_red_vars, + new_ranges, + return_getters_groups, + ) + + # We create Identity sympy.Functions to prevent expansion to int64, + # unwrap for tiling analysis. + def remove_identity(expr: sympy.Expr) -> sympy.Expr: + return expr.replace(Identity, lambda x: x) + + n_reads_new = { + sympy_subs(remove_identity(read), var_map): v for read, v in n_reads.items() + } + n_writes_new = { + sympy_subs(remove_identity(write), var_map): v + for write, v in n_writes.items() + } + + for expr, buf_names in n_reads_new.items(): + reads[expr] |= buf_names + + for expr, buf_names in n_writes_new.items(): + writes[expr] |= buf_names + + reads = { + V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items() + } + writes = { + V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items() + } + + fused_out = FusedNormalizedReadsWrites( + norm_pw_vars, # type: ignore[arg-type] + norm_red_vars, # type: ignore[arg-type] + reads, + writes, + ranges, + ) + loop_tiling_log.info("Normalized Fused reads: %s", fused_out) + return fused_out + + +def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int: + """ + Score addr according to its approximate size + """ + + # TODO - deduplicate with candidate_tilings + var_sizes = [] + for v in addr.free_symbols: + v_size = var_ranges.get(v, None) + # TODO - reason about indirect vars + if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None: + var_sizes.append(v_size) + from .virtualized import V + + return V.graph.sizevars.atomically_apply_size_hint( + sympy_product(var_sizes), fallback=config.unbacked_symint_fallback + ) + + +def get_hint(v: Union[sympy.Expr, int]) -> int: + if isinstance(v, int): + return v + else: + return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback) + + +@dataclasses.dataclass(frozen=True) +class VarTiling: + """ + Tiling of a var by `tiling_factor` that yields additional coalesced mem accesses by `benefit_score` + """ + + var: sympy.Symbol + tiling_factor: int + score: int + + +@dataclasses.dataclass(frozen=True) +class CoalesceVarAnalysis: + # Var -> Memory Score - not strictly the amount of memory + # because we multiply writes x2 + # TODO: separate into dataclass that olds mem, dtype, is_write + coalesced_by_var: dict[sympy.Expr, int] + + norm_read_writes: FusedNormalizedReadsWrites + + suggested_split: Optional[VarTiling] = None + + +def analyze_memory_coalescing( + fused_node: Union["FusedSchedulerNode", "SchedulerNode"], +) -> Optional[CoalesceVarAnalysis]: + """ + Find variables that coalesce the reads and writes and score the total size. + + If uncoalesced memory expressions are found, look for additionally tiling of variables + which will coalesce memory accesses. + + For instance - for the following expression: + + (32*p0) // 2048 + + Tiling p0 by 64 will make this expression coalesced. + """ + + norm_read_writes = extract_normalized_read_writes(fused_node) + + if norm_read_writes is None: + return None + + reads = norm_read_writes.reads + writes = norm_read_writes.writes + var_ranges = norm_read_writes.var_ranges + + coalesced_by_var: dict[sympy.Symbol, int] = Counter() + uncoalesced_addrs: dict[sympy.Expr, int] = Counter() + + for is_read, (memory_expr, buf_names) in itertools.chain( + ((True, item) for item in reads.items()), + ((False, item) for item in writes.items()), + ): + # skip memory deps with indirect vars - todo: better handling + indirect_expr = bool( + memory_expr.free_symbols - norm_read_writes.var_ranges.keys() + ) + + if indirect_expr: + continue + + size = get_score(memory_expr, var_ranges) + if size == 0: + continue + + maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges) + + byte_multipler = 0 + for buf_name in buf_names: + if buf := V.graph.try_get_buffer(buf_name): + byte_multipler += buf.dtype.itemsize + + # coalesced writes more important + byte_multipler *= 1 if is_read else 2 + + if maybe_coalesced_var: + coalesced_by_var[maybe_coalesced_var] += size * byte_multipler + else: + uncoalesced_addrs[memory_expr] += size * byte_multipler + + if not uncoalesced_addrs: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + # map from var -> tiling -> total_score + tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter) + + for uncoalesced_expr, addr_score in uncoalesced_addrs.items(): + expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0) + for v in uncoalesced_expr.free_symbols: + # skip non iter/reduce var variables + if v not in var_ranges: + continue + # skip small addrs + if addr_score == 0: + continue + del expr_subs[v] + single_var_expr = sympy_subs(uncoalesced_expr, expr_subs) + expr_subs[v] = 0 + tiling_factor = solve_for_tiling(single_var_expr) + if ( + tiling_factor is None + or not tiling_factor.is_constant() + or not tiling_factor.is_integer + ): + continue + + tiling_factor = int(tiling_factor) + if not V.graph.sizevars.statically_known_lt(tiling_factor, var_ranges[v]): + continue + + # TODO - if a var is in the middle, such as [n0, n1, n2] + # n1 can can be split beyond range + + MIN_TILING_BLOCK = 8 + if not all( + V.graph.sizevars.statically_known_lt(MIN_TILING_BLOCK, block) + for block in (tiling_factor, var_ranges[v] // tiling_factor) + ): + continue + + tiling_scores[v][tiling_factor] += addr_score + + if len(tiling_scores) == 0: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + best_tiling: Optional[tuple[sympy.Expr, int]] = None + best_tiling_score = 0 + + for var, tiling_counter in tiling_scores.items(): + for tile, tile_score in tiling_counter.items(): + if tile_score > best_tiling_score: + best_tiling = (var, tile) + best_tiling_score = tile_score + + if best_tiling is None: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + # TODO - for strictly pointwise fusions, + # we can consider just swizzling the var if the var we are going to tile + # does not coalesce a significant portion of global reads + # TODO - could also prefer index var splits to reduction, better tested + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + norm_read_writes=norm_read_writes, + suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score), + ) diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py index 6fb1424776179a..b5ccb873e33f9f 100644 --- a/torch/_inductor/triton_bundler.py +++ b/torch/_inductor/triton_bundler.py @@ -109,7 +109,7 @@ class TritonBundler: _static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None # __grp__kernel_name.json contains metadata with source code paths - # we use this as sentinal value for search and replace + # we use this as sentinel value for search and replace _REPLACE_BYTES: bytes = b"[REPLACE]" @staticmethod diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index fe1c3bdc88d2bc..6931d93a937cbb 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -91,7 +91,7 @@ # defines here before import torch._dynamo is for avoiding circular import # when get_gpu_type is imported from dynamo -@functools.lru_cache(None) +@functools.cache def get_gpu_type() -> str: avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] assert len(avail_gpus) <= 1 @@ -241,7 +241,10 @@ def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float: [ event for event in p.events() - if event.device_type == DeviceType.CUDA and "fused_abs_max_0" in event.name + if ( + event.device_type == DeviceType.CUDA + and re.match(r"fused_abs_max_\d", event.name) is not None + ) ] ) if filtered_events: @@ -288,6 +291,8 @@ def do_bench_using_profiling( for _ in range(n_warmup): fn() + torch.cuda.synchronize() + with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CUDA, @@ -338,7 +343,7 @@ def do_bench_using_profiling( return res -@functools.lru_cache(None) +@functools.cache def has_torchvision_roi_align() -> bool: try: from torchvision.ops import roi_align # noqa: F401 @@ -379,17 +384,17 @@ def unique(it: Iterable[_T]) -> ValuesView[_T]: def ceildiv( - numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] + number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: - if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) + if isinstance(number, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(sympy.sympify(number), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes - assert isinstance(numer, int) and isinstance(denom, int), ( - f"{numer}: {type(numer)}, {denom}: {type(denom)}" + assert isinstance(number, int) and isinstance(denom, int), ( + f"{number}: {type(number)}, {denom}: {type(denom)}" ) - return runtime_ceildiv(numer, denom) + return runtime_ceildiv(number, denom) def _type_of(key: Optional[torch.dtype]) -> str: @@ -411,6 +416,7 @@ def _type_of(key: Optional[torch.dtype]) -> str: # TODO: remove when support is added in triton # https://github.com/triton-lang/triton/issues/6054 "float8_e8m0fnu": "u8", + "float4_e2m1fn_x2": "u8", "float16": "fp16", "bfloat16": "bf16", "float32": "fp32", @@ -974,10 +980,10 @@ def get_first_incompatible_cudagraph_node( if ( not torch._inductor.config.graph_partition and isinstance(node.target, torch._ops.OpOverload) - and torch._C.Tag.cudagraph_unsafe in node.target.tags + and torch._C.Tag.cudagraph_unsafe in node.target.tags # type: ignore[attr-defined] ): # skip cudagraph if a cudagraph_unsafe op is detected. - # graph_partition helps by spliting on this cudagraph_unsafe + # graph_partition helps by splitting on this cudagraph_unsafe # op and cudagraphifying the subgraphs. return node @@ -994,27 +1000,23 @@ def output_node(gm: torch.fx.GraphModule) -> Node: return last_node -_registered_caches: list[Any] = [] - - -def clear_on_fresh_inductor_cache(obj: Any) -> Any: - """ - Use this decorator to register any caches that should be cache_clear'd - with fresh_inductor_cache(). - """ - if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): - raise AttributeError(f"{obj} does not have a cache_clear method") - - _registered_caches.append(obj) - return obj - +def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + input_devices: OrderedSet[torch.device] = OrderedSet( + node.meta["val"].device + for node in placeholder_nodes + if isinstance(node.meta.get("val"), torch.Tensor) + ) -def clear_inductor_caches() -> None: - """ - Clear all registered caches. - """ - for obj in _registered_caches: - obj.cache_clear() + out_arg = output_node(gm).args[0] # type: ignore[union-attr] + out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,) + out_devices: OrderedSet[torch.device] = OrderedSet( + arg.meta["val"].device + for arg in out_args + if isinstance(arg, torch.fx.Node) + and isinstance(arg.meta.get("val"), torch.Tensor) + ) + return input_devices | out_devices import gc @@ -1049,19 +1051,42 @@ def unload_xpu_triton_pyds() -> None: gc.collect() +_registered_caches: list[Any] = [] + + +def clear_on_fresh_cache(obj: Any) -> Any: + """ + Use this decorator to register any caches that should be cache_clear'd + with fresh_cache(). + """ + if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): + raise AttributeError(f"{obj} does not have a cache_clear method") + + _registered_caches.append(obj) + return obj + + +def clear_caches() -> None: + """ + Clear all registered caches. + """ + for obj in _registered_caches: + obj.cache_clear() + + @contextlib.contextmanager -def fresh_inductor_cache( +def fresh_cache( cache_entries: Optional[dict[str, Any]] = None, dir: Optional[str] = None, delete: bool = True, ) -> Iterator[None]: """ - Contextmanager that provides a clean tmp cachedir for inductor. + Contextmanager that provides a clean tmp cachedir for pt2 caches. Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes generated with this cache instance. """ - clear_inductor_caches() + clear_caches() inductor_cache_dir = tempfile.mkdtemp(dir=dir) try: @@ -1102,7 +1127,13 @@ def fresh_inductor_cache( log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) raise finally: - clear_inductor_caches() + clear_caches() + + +# Deprecated functions -- only keeping them for BC reasons +clear_on_fresh_inductor_cache = clear_on_fresh_cache +clear_inductor_caches = clear_caches +fresh_inductor_cache = fresh_cache def argsort(seq: Sequence[Any]) -> list[int]: @@ -1384,7 +1415,7 @@ def _new_line(self, line: str) -> DelayReplaceLine: return DelayReplaceLine(self.key, self.value_fn, line) -@functools.lru_cache(None) +@functools.cache def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: if isinstance(index_or_device, torch.device): device = index_or_device @@ -1445,12 +1476,6 @@ def get_tma_workspace_arg( ) -def use_max_autotune() -> bool: - return ( - config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache - ) - - def _use_template_for_gpu( layout: Layout, allowed_layout_dtypes: list[torch.dtype] ) -> bool: @@ -1497,14 +1522,14 @@ def use_triton_template( ) or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) ) - and use_max_autotune() + and (config.max_autotune or config.max_autotune_gemm) and _use_autotune_backend("TRITON") and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) ) def use_triton_tma_template(*matrices: IRNode) -> bool: - from torch.utils._triton import has_triton_tma_device + from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device from .virtualized import V @@ -1533,6 +1558,10 @@ def _is_tma_compatible(x: IRNode) -> bool: inner_bytes = inner_dim * dtype.itemsize return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + if has_triton_stable_tma_api() and config.cpp_wrapper: + # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047) + return False + return ( config.triton.enable_persistent_tma_matmul and has_triton_tma_device() @@ -1557,7 +1586,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: layout_dtypes = [torch.float16, torch.bfloat16, torch.int32] res = ( _use_template_for_gpu(layout, layout_dtypes) - and use_max_autotune() + and (config.max_autotune or config.max_autotune_gemm) and _use_autotune_backend("CUTLASS") ) @@ -1572,6 +1601,14 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: return res +def _use_cutlass_for_op(op_name: str) -> bool: + """Check if CUTLASS should be used for the given operation.""" + enabled_ops = config.cuda.cutlass_enabled_ops.upper() + if enabled_ops == "ALL": + return True + return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] + + decompose_k_threshold = 32 # To limit compile time @@ -1587,7 +1624,8 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: from torch._inductor.virtualized import V return ( - V.graph.sizevars.is_expr_static_and_true( + not torch.version.hip + and V.graph.sizevars.statically_known_true( sympy.And( sympy.Ge(k, decompose_k_threshold * m), sympy.Ge(k, decompose_k_threshold * n), @@ -1599,7 +1637,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) -@functools.lru_cache(None) +@functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: @@ -1641,6 +1679,8 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: else: rest_of_splits.append(d) + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits # If the # of power of 2 divisors are greater than k_splits_limit, return all # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) # should never be a massive amount @@ -1652,12 +1692,12 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: return best_splits[:k_splits_limit] -@functools.lru_cache(None) +@functools.cache def _rocm_native_device_arch_name(device: str) -> str: return torch.cuda.get_device_properties(device).gcnArchName -@functools.lru_cache(None) +@functools.cache def try_import_ck_lib() -> tuple[ Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any] ]: @@ -1689,7 +1729,7 @@ class CKGemmOperation: # type: ignore[no-redef] def use_ck_template(layout: Layout) -> bool: # config knobs check 1 - if not use_max_autotune(): + if not (config.max_autotune or config.max_autotune_gemm): return False # platform check if not torch.version.hip: @@ -1743,12 +1783,24 @@ def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: ) +def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + return ( + _use_autotune_backend("CKTILE") + and use_ck_template(layout) + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + def use_ck_conv_template(layout: Layout) -> bool: return _use_conv_autotune_backend("CK") and use_ck_template(layout) def _use_template_for_cpu(layout: Layout) -> bool: - return use_max_autotune() and layout.device.type == "cpu" + return ( + config.max_autotune or config.max_autotune_gemm + ) and layout.device.type == "cpu" def use_cpp_bmm_template( @@ -1797,6 +1849,7 @@ def use_cpp_gemm_template( # TODO(jgong5): support dynamic shapes for n or k if has_free_symbols((n, k)): return False + if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() @@ -1828,7 +1881,9 @@ def is_last_dim_stride1(x: IRNode) -> bool: def use_aten_gemm_kernels() -> bool: - return not use_max_autotune() or _use_autotune_backend("ATEN") + return not ( + config.max_autotune or config.max_autotune_gemm + ) or _use_autotune_backend("ATEN") class DebugDirManager: @@ -2096,7 +2151,7 @@ def parallel_num_threads() -> int: return threads -@functools.lru_cache(None) +@functools.cache def get_backend_num_stages() -> int: from .runtime.triton_helpers import get_backend_options @@ -2104,7 +2159,7 @@ def get_backend_num_stages() -> int: return options.get("num_stages", 2 if torch.version.hip else 3) -@functools.lru_cache(None) +@functools.cache def get_device_tflops(dtype: torch.dtype) -> int: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops @@ -2132,7 +2187,7 @@ def get_device_tflops(dtype: torch.dtype) -> int: return get_max_simd_tflops(torch.float32) -@functools.lru_cache(None) +@functools.cache def get_gpu_dram_gbps() -> int: from triton.testing import get_dram_gbps @@ -2258,7 +2313,7 @@ def is_output_of_multi_outputs_template( return ( isinstance(input_buf, ir.MultiOutput) and len(input_buf.inputs) == 1 - and is_multi_outputs_template(input_buf.inputs[0]) + and is_multi_outputs_template(input_buf.inputs[0]) # type: ignore[arg-type] ) @@ -2272,7 +2327,9 @@ def is_collective( from . import ir return ( - type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) + isinstance(node, ir._CollectiveKernel) + and not isinstance(node, ir._WaitKernel) + and (op is None or node.op_overload is op) ) or ( # TODO: this is a temporary solution to ensure that we can identify torchrec's # communication ops. But in order to allow better communication and computation @@ -2703,7 +2760,9 @@ def copy_misaligned_inputs( ret_pair_defined = return_pair_idxs is not None for i in check_inputs_idxs: _inp = new_inputs[i] - assert isinstance(_inp, torch.Tensor) + assert isinstance(_inp, torch.Tensor), ( + f"Expected tensors only, but got: {type(_inp)}" + ) if _inp.data_ptr() % ALIGNMENT: new_inputs[i] = clone_preserve_strides(_inp) @@ -2741,7 +2800,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: # Allow for unhinted e as long as we can still statically prove # (e.g., via ValueRanges) that it is still in bounds - if V.graph.sizevars.is_expr_static_and_true(e <= int_max): + if V.graph.sizevars.statically_known_true(e <= int_max): return True # Otherwise, the hint MUST exist and be in range return has_hint(e) and size_hint(e) <= int_max @@ -2809,6 +2868,7 @@ def normalize_name(name: str) -> str: # TODO: remove when support is added in triton # https://github.com/triton-lang/triton/issues/6054 "tl.float8_e8m0fnu": "tl.uint8", + "tl.float4_e2m1fn_x2": "tl.uint8", } _torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()} @@ -2852,7 +2912,7 @@ def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool: ) -@functools.lru_cache(None) +@functools.cache def boolean_ops() -> tuple[str, ...]: return ( "isinf", @@ -3041,7 +3101,7 @@ class TritonAttrsDescriptorVersion(enum.Enum): V4_DICT = 4 # a raw dict -@functools.lru_cache(None) +@functools.cache def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion: if importlib.util.find_spec("triton") is None: return TritonAttrsDescriptorVersion.V0_NO_TRITON @@ -3083,7 +3143,7 @@ def is_cudagraph_unsafe_op(node: Operation) -> bool: if ( isinstance(node.op_overload, torch._ops.OpOverload) - and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags + and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags # type: ignore[attr-defined] ): return True @@ -3101,3 +3161,50 @@ def get_ld_library_path() -> str: path = os.pathsep.join([lib_path, path]) if path else lib_path return path + + +def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: + from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen + + return ( + isinstance(wrapper, SubgraphPythonWrapperCodegen) + and wrapper.partition_signatures is not None + ) + + +def dtype_from_size(size: int) -> torch.dtype: + from .virtualized import V + + if V.graph.sizevars.statically_known_lt( + size, 2**31 + ) and V.graph.sizevars.statically_known_geq(size, -(2**31)): + return torch.int32 + else: + return torch.int64 + + +SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu") + + +def is_mkldnn_bf16_supported(device_type: str) -> bool: + """ + Returns True if the device supports MKL-DNN BF16. + """ + if device_type == "cpu": + return torch.ops.mkldnn._is_mkldnn_bf16_supported() + elif "xpu" in device_type: + # match "xpu", "xpu:0", "xpu:1", etc. + return True + return False + + +def is_mkldnn_fp16_supported(device_type: str) -> bool: + """ + Returns True if the device supports MKL-DNN FP16. + """ + if device_type == "cpu": + return torch.ops.mkldnn._is_mkldnn_fp16_supported() + elif "xpu" in device_type: + # match "xpu", "xpu:0", "xpu:1", etc. + return True + return False diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 911b719fc607ca..c8cf2826cdbc86 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -397,6 +397,9 @@ def collect_memory_snapshot( print(f"The collect memory snapshot has been written to {snapshot_path}") +# With AOTAutograd cache, we directly call the compiled module. So prevent +# Dynamo from reentering +@torch.compiler.disable # type: ignore[misc] def compiled_module_main( benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType ) -> None: diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 8c46dfabc0580b..547d305c47afd3 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -28,8 +28,7 @@ def custom_op( mutates_args: Union[str, Iterable[str]], device_types: device_types_t = None, schema: Optional[str] = None, -) -> Callable[[Callable[..., object]], "CustomOpDef"]: - ... +) -> Callable[[Callable[..., object]], "CustomOpDef"]: ... @overload @@ -41,8 +40,7 @@ def custom_op( mutates_args: Union[str, Iterable[str]], device_types: device_types_t = None, schema: Optional[str] = None, -) -> "CustomOpDef": - ... +) -> "CustomOpDef": ... @exposed_in("torch.library") @@ -448,10 +446,10 @@ def register_fake(self, fn: Callable, /) -> Callable: >>> >>> @nonzero.register_fake >>> def _(x): - >>> # Number of nonzero-elements is data-dependent. - >>> # Since we cannot peek at the data in an abstract impl, - >>> # we use the ctx object to construct a new symint that - >>> # represents the data-dependent size. + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] @@ -561,7 +559,7 @@ def register_autograd( >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) - >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, x.cos()) >>> >>> # Example with a keyword-only arg @@ -581,7 +579,7 @@ def register_autograd( >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) - >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) """ @@ -919,7 +917,7 @@ def get_library_allowing_overwrite( def _maybe_get_opdef( - op: Union[CustomOpDef, _ops.OpOverload, str] + op: Union[CustomOpDef, _ops.OpOverload, str], ) -> Optional[CustomOpDef]: if isinstance(op, CustomOpDef): return op diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index b9a0061139d67f..512bd5835bd951 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -77,7 +77,7 @@ def convert_type_string(annotation_type: str): ) def unstringify_types( - tys: tuple[Union[type[object], str], ...] + tys: tuple[Union[type[object], str], ...], ) -> tuple[tuple[typing.Any, ...], bool]: res = [] changed = False @@ -150,13 +150,13 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]: "the arguments that are mutated or the string 'unknown'. " ) if schema_type.startswith("Tensor"): - schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" elif name in mutates_args: if not schema_type.startswith("Tensor"): error_fn( f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" ) - schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" seen_args.add(name) if param.default is inspect.Parameter.empty: params.append(f"{schema_type} {name}") @@ -282,8 +282,12 @@ def parse_return(annotation, error_fn): f"Return has unsupported type {annotation}. " f"The valid types are: {SUPPORTED_RETURN_TYPES}." ) + output_ty = ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) - return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" + # use (()) to represent tuple with single element + if len(args) == 1: + output_ty = "(" + output_ty + ")" + return "(" + output_ty + ")" SUPPORTED_PARAM_TYPES = get_supported_param_types() diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 78ac3eb6f304a3..bd78fbfa103ed4 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -391,10 +391,17 @@ def lobpcg( we do the following symmetrization map: `A -> (A + A.t()) / 2`. The map is performed only when the `A` requires gradients. + .. warning:: LOBPCG algorithm is not applicable when the number of `A`'s rows + is smaller than 3x the number of requested eigenpairs `n`. + Args: A (Tensor): the input tensor of size :math:`(*, m, m)` + k (integer, optional): the number of requested + eigenpairs. Default is the number of :math:`X` + columns (when specified) or `1`. + B (Tensor, optional): the input tensor of size :math:`(*, m, m)`. When not specified, `B` is interpreted as identity matrix. @@ -404,19 +411,21 @@ def lobpcg( initial approximation of eigenvectors. X must be a dense tensor. - iK (tensor, optional): the input tensor of size :math:`(*, m, - m)`. When specified, it will be used as preconditioner. - - k (integer, optional): the number of requested - eigenpairs. Default is the number of :math:`X` - columns (when specified) or `1`. - n (integer, optional): if :math:`X` is not specified then `n` specifies the size of the generated random approximation of eigenvectors. Default value for `n` - is `k`. If :math:`X` is specified, the value of `n` - (when specified) must be the number of :math:`X` - columns. + is `k`. If :math:`X` is specified, any provided value of `n` is + ignored and `n` is automatically set to the number of + columns in :math:`X`. + + iK (tensor, optional): the input tensor of size :math:`(*, m, + m)`. When specified, it will be used as preconditioner. + + niter (int, optional): maximum number of iterations. When + reached, the iteration process is hard-stopped and + the current approximation of eigenpairs is returned. + For infinite iteration but until convergence criteria + is met, use `-1`. tol (float, optional): residual tolerance for stopping criterion. Default is `feps ** 0.5` where `feps` is @@ -432,12 +441,6 @@ def lobpcg( description of the function above. Default is "ortho". - niter (int, optional): maximum number of iterations. When - reached, the iteration process is hard-stopped and - the current approximation of eigenpairs is returned. - For infinite iteration but until convergence criteria - is met, use `-1`. - tracker (callable, optional) : a function for tracing the iteration process. When specified, it is called at each iteration step with LOBPCG instance as an diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index a6f358ec52ae1d..2901fdaad43a47 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -14,7 +14,8 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Generic, Optional, Union +from typing_extensions import ParamSpec from weakref import WeakSet import torch._logging.structured @@ -23,6 +24,8 @@ from torch.utils._traceback import CapturedTraceback +_P = ParamSpec("_P") + log = logging.getLogger(__name__) # This is a synthetic logger which doesn't correspond to an actual logger, @@ -96,7 +99,7 @@ def is_log(self, alias): return alias in self.log_alias_to_log_qnames # register a log with an alias - def register_log(self, alias, log_qnames: Union[str, list[str]]): + def register_log(self, alias, log_qnames: Union[str, list[str]]) -> None: if isinstance(log_qnames, str): log_qnames = [log_qnames] self.log_alias_to_log_qnames[alias] = log_qnames @@ -104,7 +107,7 @@ def register_log(self, alias, log_qnames: Union[str, list[str]]): # register an artifact name def register_artifact_name( self, name, description, visible, off_by_default, log_format - ): + ) -> None: self.artifact_names.add(name) if visible: self.visible_artifacts.add(name) @@ -121,10 +124,10 @@ def register_artifact_name( # register the qualified name of an artifact log # this is needed to know which logs need to be reset # whenever the log_state is changed - def register_artifact_log(self, artifact_log_qname): + def register_artifact_log(self, artifact_log_qname) -> None: self.artifact_log_qnames.add(artifact_log_qname) - def register_child_log(self, log_qname): + def register_child_log(self, log_qname) -> None: self.child_log_qnames.add(log_qname) # flattens all the qnames together (TODO: consider memoizing?) @@ -149,13 +152,13 @@ class LogState: # the set of currently enabled artifacts artifact_names: set[str] = field(default_factory=set) - def enable_artifact(self, artifact_name): + def enable_artifact(self, artifact_name) -> None: self.artifact_names.add(artifact_name) def is_artifact_enabled(self, name): return name in self.artifact_names - def enable_log(self, log_qnames, log_level): + def enable_log(self, log_qnames, log_level) -> None: if isinstance(log_qnames, str): log_qnames = [log_qnames] for log_qname in log_qnames: @@ -175,7 +178,7 @@ def get_log_level_pairs(self): """ return self.log_qname_to_level.items() - def clear(self): + def clear(self) -> None: self.log_qname_to_level.clear() self.artifact_names.clear() @@ -248,7 +251,8 @@ def set_logs( autotuning: bool = False, graph_region_expansion: bool = False, inductor_metrics: bool = False, -): + hierarchical_compile: bool = False, +) -> None: """ Sets the log level for individual components and toggles individual log artifact types. @@ -445,6 +449,9 @@ def set_logs( inductor_metrics (:class:`bool`): Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False`` + hierarchical_compile (:class:`bool`): + Whether to emit debug info for hierarchical compilation. Default: ``False`` + Example:: >>> # xdoctest: +SKIP @@ -470,7 +477,7 @@ def set_logs( modules = modules or {} - def _set_logs(**kwargs): + def _set_logs(**kwargs) -> None: for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr] if val is None: continue @@ -557,17 +564,18 @@ def _set_logs(**kwargs): autotuning=autotuning, graph_region_expansion=graph_region_expansion, inductor_metrics=inductor_metrics, + hierarchical_compile=hierarchical_compile, ) -def get_loggers(): +def get_loggers() -> list[logging.Logger]: """ Returns: a list of all registered loggers """ return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()] -def register_log(setting_name, log_name): +def register_log(setting_name, log_name) -> None: """ Enables a log to be controlled by the env var and user API with the setting_name Args: @@ -579,7 +587,7 @@ def register_log(setting_name, log_name): def register_artifact( setting_name, description, visible=False, off_by_default=False, log_format=None -): +) -> None: """ Enables an artifact to be controlled by the env var and user API with name Args: @@ -594,7 +602,7 @@ def register_artifact( ) -def getArtifactLogger(module_qname, artifact_name): +def getArtifactLogger(module_qname, artifact_name) -> logging.Logger: if artifact_name not in log_registry.artifact_names: raise ValueError( f"Artifact name: {repr(artifact_name)} not registered," @@ -617,7 +625,7 @@ def getArtifactLogger(module_qname, artifact_name): ) -def configure_artifact_log(log): +def configure_artifact_log(log) -> None: # If the artifact is off by default, then it should only be logged when explicitly # enabled; set propagate to False so that this artifact is not propagated # to its ancestor logger @@ -773,14 +781,14 @@ def _is_valid_module(qname): return spec is not None -def _update_log_state_from_env(): +def _update_log_state_from_env() -> None: global log_state log_setting = os.environ.get(LOG_ENV_VAR, None) if log_setting is not None: log_state = _parse_log_settings(log_setting) -def _has_registered_parent(log_qname): +def _has_registered_parent(log_qname) -> bool: cur_log = logging.getLogger(log_qname) registered_log_qnames = log_registry.get_log_qnames() @@ -817,7 +825,7 @@ def make_module_path_relative(abs_path): class TorchLogsFormatter(logging.Formatter): def __init__( self, *, trace: bool = False, trace_id_filter: Optional[set[str]] = None - ): + ) -> None: super().__init__() self._is_trace = trace self._trace_id_filter = trace_id_filter @@ -922,7 +930,7 @@ def _default_formatter(): DEFAULT_FORMATTER = _default_formatter() -def _setup_handlers(create_handler_fn, log): +def _setup_handlers(create_handler_fn, log) -> None: debug_handler = _track_handler(create_handler_fn()) debug_handler.setFormatter(DEFAULT_FORMATTER) debug_handler.setLevel(logging.DEBUG) @@ -944,13 +952,13 @@ def _is_torch_handler(handler): # clears all torch handlers on specified loggers -def _clear_handlers(log): +def _clear_handlers(log) -> None: to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)] for handler in to_remove: log.removeHandler(handler) -def _reset_logs(): +def _reset_logs() -> None: # reset all registered logs for log_qname in log_registry.get_log_qnames(): log = logging.getLogger(log_qname) @@ -974,12 +982,12 @@ def _get_log_state(): return log_state -def _set_log_state(state): +def _set_log_state(state) -> None: global log_state log_state = state -def _init_logs(log_file_name=None): +def _init_logs(log_file_name=None) -> None: global GET_DTRACE_STRUCTURED _reset_logs() @@ -1053,7 +1061,7 @@ def _init_logs(log_file_name=None): class LazyTraceHandler(logging.StreamHandler): """Like FileHandler, but the file is allocated lazily only upon the first log message""" - def __init__(self, root_dir: Optional[str]): + def __init__(self, root_dir: Optional[str]) -> None: # This is implemented in the same way that delay is implemented on # FileHandler self.root_dir = root_dir @@ -1062,7 +1070,7 @@ def __init__(self, root_dir: Optional[str]): self._builtin_open = open # cloned from FileHandler in cpython - def close(self): + def close(self) -> None: self.acquire() try: try: @@ -1083,7 +1091,7 @@ def close(self): finally: self.release() - def emit(self, record): + def emit(self, record) -> None: if self.stream is None: if self.root_dir is None: TRACE_LOG_DIR = "/logs" @@ -1135,8 +1143,8 @@ def emit(self, record): super().emit(record) -@functools.lru_cache(None) -def warning_once(logger_obj, *args, **kwargs): +@functools.cache +def warning_once(logger_obj, *args, **kwargs) -> None: """ This function is similar to `logger.warning()`, but will emit the warning with the same message only once Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. @@ -1146,13 +1154,15 @@ def warning_once(logger_obj, *args, **kwargs): logger_obj.warning(*args, **kwargs) -class LazyString: - def __init__(self, func, *args, **kwargs): +class LazyString(Generic[_P]): + def __init__( + self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs + ) -> None: self.func = func self.args = args self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: return self.func(*self.args, **self.kwargs) @@ -1232,12 +1242,12 @@ def trace_structured( "frame_compile_id", "attempt", ] - assert callable( - metadata_fn - ), f"metadata_fn should be callable, but got {type(metadata_fn)}" - assert callable( - payload_fn - ), f"payload_fn should be callable, but got {type(payload_fn)}" + assert callable(metadata_fn), ( + f"metadata_fn should be callable, but got {type(metadata_fn)}" + ) + assert callable(payload_fn), ( + f"payload_fn should be callable, but got {type(payload_fn)}" + ) # trace_log never propagates and is ALWAYS DEBUG, so also check that there # are handlers instead of checking the log level if trace_log.handlers: @@ -1312,7 +1322,7 @@ def dtrace_structured( suppress_context: bool = False, expect_trace_id: bool = False, # Whether or not we expect to have a current trace id record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging -): +) -> None: """ For logging more detailed information used for debugging. This may result in the program becoming slow. diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 663d3e6e7af10b..62e5d9b7064ca9 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -13,6 +13,13 @@ "torch.nn.parallel.distributed", ] +register_log( + "async_compile", + [ + "torch._inductor.async_compile", + "torch._inductor.compile_worker.tracked_process_pool", + ], +) register_log( "cache", ("torch._inductor.remote_cache", "torch._inductor.fb.remote_cache") ) @@ -26,10 +33,6 @@ "cudagraphs", "Logs information from wrapping inductor generated code with cudagraphs.", ) -register_artifact( - "codecache", - "Logs information about inductor's code cache", -) register_log("dynamic", DYNAMIC) register_log("torch", "torch") @@ -190,6 +193,12 @@ "Logs related to loop ordering", off_by_default=True, ) +register_artifact( + "loop_tiling", + "Logs related to loop ordering", + off_by_default=True, +) + register_artifact( "overlap", "Detailed Inductor compute/comm overlap decisions", @@ -231,5 +240,9 @@ "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes", off_by_default=True, ) - +register_artifact( + "hierarchical_compile", + "Logs debug info for hierarchical compilation", + off_by_default=True, +) register_artifact("custom_format_test_artifact", "Testing only", log_format="") diff --git a/torch/_logging/structured.py b/torch/_logging/structured.py index 43ccad0b3e0b5d..4eae33227e618c 100644 --- a/torch/_logging/structured.py +++ b/torch/_logging/structured.py @@ -1,6 +1,7 @@ """ Utilities for converting data types into structured JSON for dumping. """ + import inspect import os import traceback diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index c590e4d0dfbe77..7f2bcfaa4720ee 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import math +import operator from collections.abc import Sequence from enum import Enum -from functools import wraps +from functools import reduce, wraps from typing import Callable, Optional, TypeVar, Union from typing_extensions import ParamSpec @@ -16,7 +17,11 @@ meta_table, ) from torch._ops import OpOverload -from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND +from torch._prims import ( + _prim_elementwise_meta, + ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, + view_of, +) from torch._prims_common import ( BoolLike, corresponding_complex_dtype, @@ -25,6 +30,8 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, FloatLike, IntLike, + is_contiguous, + is_contiguous_or_false, make_contiguous_strides_for, Number, suggest_memory_format, @@ -195,6 +202,168 @@ def linalg_cross(self, other, *, dim=-1): return self.new_empty(out_shape) +# This function is python match of computeStride_impl in TensorUtils.cpp +def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + sym_eq, + ) + + def maybe_guard_or_false(x): + if size_oblivious: + return guard_or_false(x) + + return x + + def maybe_guard_or_true(x): + if size_oblivious: + return guard_or_true(x) + + return x + + if len(old_shape) == 0: + return [1] * len(new_shape) + + numel = reduce(operator.mul, old_shape, 1) + zero_numel = maybe_guard_or_false(numel == 0) + if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)): + return old_stride + + new_stride = [0] * len(new_shape) + + if zero_numel: + for view_d in range(len(new_shape) - 1, -1, -1): + if view_d == len(new_shape) - 1: + new_stride[view_d] = 1 + else: + new_stride[view_d] = ( + max(new_shape[view_d + 1], 1) * new_stride[view_d + 1] + ) + return new_stride + + view_d = len(new_shape) - 1 + chunk_base_stride = old_stride[-1] + tensor_numel = 1 + view_numel = 1 + + for tensor_d in range(len(old_shape) - 1, -1, -1): + tensor_numel *= old_shape[tensor_d] + + if tensor_d == 0 or ( + maybe_guard_or_true(old_shape[tensor_d - 1] != 1) + and maybe_guard_or_true( + old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride + ) + ): + while view_d >= 0 and ( + maybe_guard_or_true(view_numel < tensor_numel) + or maybe_guard_or_false(new_shape[view_d] == 1) + ): + new_stride[view_d] = view_numel * chunk_base_stride + view_numel *= new_shape[view_d] + view_d -= 1 + + if maybe_guard_or_true(view_numel != tensor_numel): + return None + + if tensor_d > 0: + chunk_base_stride = old_stride[tensor_d - 1] + tensor_numel = 1 + view_numel = 1 + if view_d != -1: + return None + return new_stride + + +def _view_has_unbacked_input(a, shape): + from torch.fx.experimental.symbolic_shapes import has_hint + + return ( + any(not has_hint(s) for s in a.size()) + or any(not has_hint(s) for s in a.stride()) + or any(not has_hint(s) for s in shape) + ) + + +def _view_unbacked_meta(a, shape, size_oblivious_enabled=True): + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + torch._check(length == 1) + _a = torch._refs.unsqueeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + torch._check(length == 1) + _a = torch._refs.squeeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + shape_numel = reduce(operator.mul, shape, 1) + + torch._check( + a.numel() == shape_numel, + lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", + ) + + if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)): + return view_of(a) + + if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a): + strides = utils.make_contiguous_strides_for(shape) + return a.as_strided(shape, strides) + + new_strides = _compute_stride( + a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled + ) + + if new_strides is not None: + return a.as_strided(shape, new_strides) + + # If we fail to do size oblivious view, and backed_size_oblivious was on, + # then we redo everything by looking at hints and guarding instead of failing. + # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False + # to throw a data dependent error. + + if size_oblivious_enabled and ( + torch.fx.experimental._config.backed_size_oblivious + or _view_has_unbacked_input(a, shape) + ): + return _view_unbacked_meta(a, shape, size_oblivious_enabled=False) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" + raise ValueError(msg) + + +@register_meta(aten.view.default) +def _view_meta(a, *shape): + if torch.fx.experimental._config.backed_size_oblivious or _view_has_unbacked_input( + a, shape + ): + return _view_unbacked_meta(a, shape) + else: + return torch._refs._reshape_view_helper(a, *shape, allow_copy=False) + + @register_meta(aten.linalg_matrix_exp) @out_wrapper() def linalg_matrix_exp(self): @@ -330,14 +499,15 @@ def meta_fft_r2c(self, dim, normalization, onesided): if onesided: out_sizes[last_dim] = last_dim_halfsize - if device_hint(self) == "cuda": + if device_hint(self) == "cuda" or device_hint(self) == "xpu": # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp + # _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp output = self.new_empty( out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) ) working_tensor = self - if use_optimized_cufft_path(dim): + if device_hint(self) == "cuda" and use_optimized_cufft_path(dim): _exec_fft(output, working_tensor, out_sizes, dim, forward=True) else: # First do the R2C transform on the last dimension @@ -370,12 +540,6 @@ def meta_fft_r2c(self, dim, normalization, onesided): return output - elif device_hint(self) == "xpu": - sorted_dims = _sort_dims(self, dim, exclude_last=True) - out = self.new_empty( - out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) - ) - return _exec_fft(out, self, out_sizes, sorted_dims, forward=True) else: return self.new_empty( out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) @@ -2167,10 +2331,12 @@ def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> @register_meta([aten.baddbmm.default, aten.baddbmm.out]) @out_wrapper(exact_dtype=True) def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): + from torch.fx.experimental.symbolic_shapes import guard_or_true, sym_eq + dim1 = batch1.size(0) dim2 = batch1.size(1) dim3 = batch2.size(2) - if self.shape != (dim1, dim2, dim3): + if guard_or_true(torch.sym_not(sym_eq(self.shape, (dim1, dim2, dim3)))): self = self.expand((dim1, dim2, dim3)) torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") @@ -2376,6 +2542,7 @@ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int ret_shape.append( _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) ) + torch._check( any(x > 0 for x in ret_shape[2:]), lambda: f"Given input size per channel: {list(dims)}. " @@ -3426,9 +3593,9 @@ def _restride_src(self): return self.as_strided(shape, strides) out = self.new_empty(before_shape + replacement_shape + after_shape) - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_false - if guard_size_oblivious(self.numel() == 0): + if guard_or_false(self.numel() == 0): # No need to worry about the output strides if self is empty. return out @@ -3503,6 +3670,11 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): return self.new_empty(self.size()) +@register_meta([aten.randint_like.Tensor]) +def meta_randint_like(self, high, **kwargs): + return self.new_empty(self.size()) + + @register_meta([aten._fused_adam_.default, aten._fused_adamw_.default]) def meta__fused_adam_( self, @@ -4071,6 +4243,11 @@ def meta_repeat(self, repeats): len(repeats) >= self.dim(), lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", ) + for i, rep in enumerate(repeats): + torch._check( + rep >= 0, + lambda: f"Repeats cannot be negative, found {rep} at index {i}", + ) # Add new leading dimensions to the tensor if the # number of target dimensions is larger than the # number of source dimensions. @@ -5429,10 +5606,10 @@ def gather_shape_check(self, dim, index): @register_meta(aten.gather.default) def meta_gather(self, dim, index, sparse_grad=False): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_false wrapped_dim = maybe_wrap_dim(dim, self.dim()) - is_index_empty = guard_size_oblivious(index.numel() == 0) + is_index_empty = guard_or_false(index.numel() == 0) if not is_index_empty: torch._check( index.dtype == torch.long or index.dtype == torch.int, @@ -5471,9 +5648,9 @@ def get_operator_enum(reduce_, use_new_options=False): # From aten/src/ATen/native/ScatterGatherChecks.h def scatter_gather_dtype_check(method_name, self, index, src_opt=None): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_true - if guard_size_oblivious(index.numel() != 0): + if guard_or_true(index.numel() != 0): torch._check( index.dtype == torch.long or index.dtype == torch.int, lambda: f"{method_name}(): Expected dtype int32/int64 for index", @@ -5492,9 +5669,9 @@ def ensure_nonempty_dim(dim): # From aten/src/ATen/native/ScatterGatherChecks.h def scatter_shape_check(self, dim, index, src_opt=None): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_false - if guard_size_oblivious(index.numel() == 0): + if guard_or_false(index.numel() == 0): return torch._check( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), @@ -5655,6 +5832,26 @@ def meta__scaled_dot_product_flash_attention( ) +def alloc_with_matching_layout( + query: Tensor, + res_shape: tuple[int, ...], +): + if tuple(query.shape) == res_shape: + query_t = query.transpose(1, 2) + res = torch.empty_like(query_t).transpose(1, 2) + else: + dim_order = sorted( + [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True + ) + permuted_shape = [res_shape[idx] for idx in dim_order] + final_permute = [dim_order.index(i) for i in range(len(dim_order))] + res = torch.empty( + permuted_shape, dtype=query.dtype, device=query.device + ).permute(final_permute) + + return res + + @register_meta([aten._scaled_dot_product_cudnn_attention]) def meta__scaled_dot_product_cudnn_attention( query: Tensor, @@ -5674,18 +5871,7 @@ def meta__scaled_dot_product_cudnn_attention( D_V = value.size(-1) res_shape = (B, H, S_Q, D_V) - if tuple(query.shape) == res_shape: - query_t = query.transpose(1, 2) - res = torch.empty_like(query_t).transpose(1, 2) - else: - dim_order = sorted( - [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True - ) - permuted_shape = [res_shape[idx] for idx in dim_order] - final_permute = [dim_order.index(i) for i in range(len(dim_order))] - res = torch.empty( - permuted_shape, dtype=query.dtype, device=query.device - ).permute(final_permute) + res = alloc_with_matching_layout(query, res_shape) logsum_exp = torch.empty( (B, H, S_Q), @@ -5722,14 +5908,16 @@ def meta__scaled_dot_product_fused_attention_overrideable( scale: Optional[float] = None, ): B = query.size(0) - H = query.size(1) + H_Q = query.size(1) S_Q = query.size(2) S_KV = key.size(2) D_V = value.size(-1) - res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device) + res_shape = (B, H_Q, S_Q, D_V) + res = alloc_with_matching_layout(query, res_shape) + logsum_exp = torch.empty( - (B, H, S_Q), + (B, H_Q, S_Q), dtype=torch.float, device=query.device, ) @@ -7285,182 +7473,253 @@ def sigmoid(self: Tensor) -> Tensor: return torch.empty_like(self, dtype=result_dtype) -def _compute_grouped_gemm_output_size(mat1, mat2, offs): +def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): mat1_is_2d = mat1.dim() == 2 mat2_is_2d = mat2.dim() == 2 if mat1_is_2d: if mat2_is_2d: - return offs.size(0), mat1.size(0), mat2.size(1) + out_size = [offs.size(0), mat1.size(0), mat2.size(1)] else: torch._check( offs.size(0) == mat2.size(0), "matrix batch sizes have to match" ) - return mat1.size(0), mat2.size(-1) + out_size = [mat1.size(0), mat2.size(-1)] else: if mat2_is_2d: torch._check( offs.size(0) == mat1.size(0), "matrix batch sizes have to match" ) - return mat1.size(1), mat2.size(1) + out_size = [mat1.size(1), mat2.size(1)] else: # regular bmm torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") - return mat1.size(0), mat1.size(1), mat2.size(-1) - - -@register_meta(aten._grouped_mm) -@out_wrapper() -def grouped_mm( - mat1: Tensor, - mat2: Tensor, - offs: Optional[Tensor] = None, - bias: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, -) -> Tensor: - torch._check(mat1.dim() == 2 or mat1.dim() == 3, lambda: "mat1 must be 2d or 3d") - torch._check(mat2.dim() == 2 or mat2.dim() == 3, lambda: "mat2 must be 2d or 3d") - torch._check( - (offs is not None) == (mat1.dim() == 2 or mat2.dim() == 2), - lambda: "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d", - ) - - if offs is not None: - torch._check(offs.dim() == 1, lambda: "offsets must be 1d") + out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] out_dtype = out_dtype or mat1.dtype - torch._check(bias is None, lambda: "bias not supported yet") - - out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs) - out = mat1.new_empty(out_size, dtype=out_dtype) + alignment = 16 // out_dtype.itemsize + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if mat1_is_2d == mat2_is_2d: + out_stride = [out_size[1] * size_padded, size_padded, 1] + else: + out_stride = [size_padded, 1] + out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device) return out -@register_meta([aten._scaled_grouped_mm.default]) -def meta_scaled_grouped_mm( - mat_a: torch.Tensor, - mat_b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - offs: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, +def _meta_grouped_mm_common( + mat_a: Tensor, + mat_b: Tensor, + scale_a: Optional[torch.Tensor], + scale_b: Optional[torch.Tensor], + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, scale_result: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): - # Check dimensions - torch._check( - mat_a.dim() == 2 or mat_a.dim() == 3, lambda: "mat_a has to be 2 or 3d" - ) torch._check( - mat_b.dim() == 2 or mat_b.dim() == 3, lambda: "mat_b has to be 2 or 3d" + (scale_a is None) == (scale_b is None), + lambda: "Either both scale factors are given, or none", ) + scaled = scale_a is not None and scale_b is not None - a_is_2d = mat_a.dim() == 2 - b_is_2d = mat_b.dim() == 2 + # Implementing all the checks from + # _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in + # aten/src/ATen/native/cuda/Blas.cpp. + + if scaled: + torch._check( + mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn, + lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + ) + else: + torch._check( + mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, + lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", + ) - # Check offsets torch._check( - (offs is not None) == (a_is_2d or b_is_2d), - lambda: "Have to provide offsets if there is a 2d matrix", + mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3], + lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", ) - if offs is not None: - torch._check(offs.dim() == 1, lambda: "offs has to be 1D") - torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32") + mat_a_is_2d = mat_a.dim() == 2 + mat_b_is_2d = mat_b.dim() == 2 - # Check matrix sizes - torch._check( - mat_a.size(-1) % 16 == 0, - lambda: f"Expected trailing dimension of mat_a to be divisible by 16 but got mat1 shape: {mat_a.size()}", - ) - torch._check( - mat_b.size(-2) % 16 == 0 and mat_b.size(-1) % 16 == 0, - lambda: f"Expected mat_b shape to be divisible by 16 but got mat_b shape: {mat_b.size()}", - ) + if scaled: - # Check scales - torch._check( - scale_a.dtype == torch.float and scale_b.dtype == torch.float, - lambda: "Both scale_a and scale_b must be float (fp32) tensors.", - ) + def is_row_major(mat): + mat_stride = mat.stride() + return mat_stride[-2] > 1 and mat_stride[-1] == 1 - # Check scale dimensions - scale_multiplier = offs.size(0) if (a_is_2d and b_is_2d) else 1 # type: ignore[union-attr] + def is_col_major(mat): + mat_stride = mat.stride() + return mat_stride[-2] == 1 and mat_stride[-1] > 1 - if a_is_2d: torch._check( - scale_a.dim() == 1, - lambda: f"scale must be a 1D tensor for 2D mat_a, but got {scale_a.dim()}D", + is_row_major(mat_a), + lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", ) - torch._check(scale_a.is_contiguous(), lambda: "scale_a must be contiguous") torch._check( - scale_a.size(0) == mat_a.size(0) * scale_multiplier, - lambda: "scale must have the same length as mat_a", - ) - else: - torch._check( - scale_a.dim() == 2, - lambda: f"scale must be a 2D tensor for 3D mat_a, but got {scale_a.dim()}D", - ) - torch._check( - scale_a.stride(1) == 1, - lambda: "scale_a must be contiguous in the last dimension", + is_col_major(mat_b), + lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", ) + + def check_valid_strides(mat_name, mat): + end_dim = mat.dim() - 1 + alignment = 16 // mat.element_size() + mat_stride = mat.stride() + if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max( + 1, mat.shape[end_dim - 1] + ): + torch._check( + mat_stride[end_dim] % alignment == 0, + lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", + ) + elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max( + 1, mat.shape[end_dim] + ): + torch._check( + mat_stride[end_dim - 1] % alignment == 0, + lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", # noqa: B950 + ) + else: + torch._check( + False, + lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 + ) + + check_valid_strides("mat_a", mat_a) + check_valid_strides("mat_b", mat_b) + + if scale_a is not None and scale_b is not None: torch._check( - scale_a.size(0) == mat_a.size(0), - lambda: "scale must have the same batch dimension as mat_a", + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950 ) - torch._check( - scale_a.size(1) == mat_a.size(1), - lambda: "scale must have the same first dimension as mat_a", + + def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): + if mat.dim() == 2: + torch._check( + scale.dim() == 1, + lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.is_contiguous(), + lambda: f"Expected {scale_name} to be contiguous.", + ) + torch._check( + scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, + lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 + ) + else: + torch._check( + scale.dim() == 2, + lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.stride(1) == 1, + lambda: f"Expected {scale_name} to be contiguous in the last dimension.", + ) + torch._check( + scale.shape[0] == mat.shape[0], + lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.", + ) + torch._check( + scale.shape[1] == mat.shape[1 + scaled_dim], + lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", + ) + + scale_multiplier = ( + offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1 ) + check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier) + check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier) - # Similar checks for scale_b - if b_is_2d: torch._check( - scale_b.dim() == 1, - lambda: f"scale must be a 1D tensor for 2D mat_b, but got {scale_b.dim()}D", + scale_result is None, + lambda: "Scale result tensor provided, but it is not supported yet.", ) - torch._check(scale_b.is_contiguous(), lambda: "scale_b must be contiguous") + + if mat_a_is_2d or mat_b_is_2d: torch._check( - scale_b.size(0) == mat_b.size(1) * scale_multiplier, - lambda: "scale must have the same length as mat_b", + offs is not None, + lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.", ) + if offs is not None: # to silence Mypy + torch._check( + offs.dim() == 1, + lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.", + ) + torch._check( + offs.dtype == torch.int32, + lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.", + ) else: torch._check( - scale_b.dim() == 2, - lambda: f"scale must be a 2D tensor for 3D mat_b, but got {scale_b.dim()}D", - ) - torch._check( - scale_b.stride(1) == 1, - lambda: "scale_b must be contiguous in the last dimension", - ) - torch._check( - scale_b.size(0) == mat_b.size(0), - lambda: "scale must have the same batch dimension as mat_b", - ) - torch._check( - scale_b.size(1) == mat_b.size(2), - lambda: "scale must have the same last dimension as mat_b", + offs is None, + lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.", ) - # Check bias - torch._check(bias is None, lambda: "Bias not supported yet") + torch._check( + bias is None, + lambda: "Bias tensor provided, but it is not supported yet.", + ) - # Check output dtype - out_dtype_ = out_dtype if out_dtype is not None else mat_a.dtype torch._check( - out_dtype_ == torch.bfloat16, - lambda: "Only bf16 high precision output types are supported for grouped gemm", + out_dtype is None or out_dtype == torch.bfloat16, + lambda: "If output dtype provided, it must be torch.bfloat16.", ) - # Compute output size - out_size = _compute_grouped_gemm_output_size(mat_a, mat_b, offs) - out = mat_a.new_empty(out_size, dtype=out_dtype) + return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype) - return out + +@register_meta(aten._grouped_mm) +@out_wrapper() +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + return _meta_grouped_mm_common( + mat_a, + mat_b, + scale_a=None, + scale_b=None, + offs=offs, + bias=bias, + scale_result=None, + out_dtype=out_dtype, + ) + + +@register_meta([aten._scaled_grouped_mm.default]) +def meta_scaled_grouped_mm( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + offs: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +): + return _meta_grouped_mm_common( + mat_a, + mat_b, + scale_a=scale_a, + scale_b=scale_b, + offs=offs, + bias=bias, + scale_result=scale_result, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) @register_meta(aten._softmax) diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index 27799adaf56376..e955a47060fffa 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -1,8 +1,9 @@ # mypy: ignore-errors -""" Define analogs of numpy dtypes supported by pytorch. +"""Define analogs of numpy dtypes supported by pytorch. Define the scalar types and supported dtypes and numpy <--> torch dtype mappings. """ + import builtins import torch diff --git a/torch/_numpy/_dtypes_impl.py b/torch/_numpy/_dtypes_impl.py index 7dfe6d4787bb27..d9eb9cc94c27e0 100644 --- a/torch/_numpy/_dtypes_impl.py +++ b/torch/_numpy/_dtypes_impl.py @@ -5,6 +5,7 @@ Here `dtype` is always a torch.dtype, this module knows nothing about scalar types, wrapper dtypes or anything like that. PyTorch only. """ + from collections import namedtuple import torch diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index 3579cfe83b421f..4030ba97766b46 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -5,6 +5,7 @@ Things imported from here have numpy-compatible signatures but operate on pytorch tensors. """ + # Contents of this module ends up in the main namespace via _funcs.py # where type annotations are used in conjunction with the @normalizer decorator. from __future__ import annotations @@ -941,7 +942,7 @@ def choose( ] idx_list[0] = a - return choices[idx_list].squeeze(0) + return choices[tuple(idx_list)].squeeze(0) # ### unique et al. ### diff --git a/torch/_numpy/_normalizations.py b/torch/_numpy/_normalizations.py index ef60dd18900480..ccdf91c2497a8d 100644 --- a/torch/_numpy/_normalizations.py +++ b/torch/_numpy/_normalizations.py @@ -1,7 +1,7 @@ # mypy: ignore-errors -""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on. -""" +""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.""" + from __future__ import annotations import functools diff --git a/torch/_numpy/_reductions_impl.py b/torch/_numpy/_reductions_impl.py index 57dc8e660cad14..4afc217ebd4b7d 100644 --- a/torch/_numpy/_reductions_impl.py +++ b/torch/_numpy/_reductions_impl.py @@ -1,10 +1,11 @@ # mypy: ignore-errors -""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc +"""Implementation of reduction operations, to be wrapped into arrays, dtypes etc in the 'public' layer. Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc """ + from __future__ import annotations import functools diff --git a/torch/_numpy/_util.py b/torch/_numpy/_util.py index 3005dbd79da112..443623bcc90120 100644 --- a/torch/_numpy/_util.py +++ b/torch/_numpy/_util.py @@ -1,7 +1,6 @@ # mypy: ignore-errors -"""Assorted utilities, which do not need anything other then torch and stdlib. -""" +"""Assorted utilities, which do not need anything other then torch and stdlib.""" import operator diff --git a/torch/_numpy/random.py b/torch/_numpy/random.py index b10a4c667c8c63..a3d4a1c73241f8 100644 --- a/torch/_numpy/random.py +++ b/torch/_numpy/random.py @@ -7,6 +7,7 @@ Q: default dtype is float64 in numpy """ + from __future__ import annotations import functools diff --git a/torch/_numpy/testing/utils.py b/torch/_numpy/testing/utils.py index 29885b917049e3..cd0d33893ac28b 100644 --- a/torch/_numpy/testing/utils.py +++ b/torch/_numpy/testing/utils.py @@ -4,6 +4,7 @@ Utility function to facilitate testing. """ + import contextlib import gc import operator @@ -167,7 +168,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): Examples -------- - >>> np.testing.assert_equal([4,5], [4,6]) + >>> np.testing.assert_equal([4, 5], [4, 6]) Traceback (most recent call last): ... AssertionError: @@ -298,8 +299,12 @@ def print_assert_equal(test_string, actual, desired): Examples -------- - >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP - >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 1] + ... ) # doctest: +SKIP + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 2] + ... ) # doctest: +SKIP Traceback (most recent call last): ... AssertionError: Test XYZ of func xyz failed @@ -377,8 +382,9 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): ACTUAL: 2.3333333333333 DESIRED: 2.33333334 - >>> assert_almost_equal(np.array([1.0,2.3333333333333]), - ... np.array([1.0,2.33333334]), decimal=9) + >>> assert_almost_equal( + ... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9 + ... ) Traceback (most recent call last): ... AssertionError: @@ -487,11 +493,19 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True Examples -------- - >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP - >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP - ... significant=8) - >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP - ... significant=8) + >>> np.testing.assert_approx_equal( + ... 0.12345677777777e-20, 0.1234567e-20 + ... ) # doctest: +SKIP + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345671e-20, # doctest: +SKIP + ... significant=8, + ... ) + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345672e-20, # doctest: +SKIP + ... significant=8, + ... ) Traceback (most recent call last): ... AssertionError: @@ -501,7 +515,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True the evaluated condition that raises the exception is - >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) + >>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1) True """ @@ -776,15 +790,16 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): -------- The first assert does not raise an exception: - >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], - ... [np.exp(0),2.33333, np.nan]) + >>> np.testing.assert_array_equal( + ... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan] + ... ) Use `assert_allclose` or one of the nulp (number of floating point values) functions for these cases instead: - >>> np.testing.assert_allclose([1.0,np.pi,np.nan], - ... [1, np.sqrt(np.pi)**2, np.nan], - ... rtol=1e-10, atol=0) + >>> np.testing.assert_allclose( + ... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0 + ... ) As mentioned in the Notes section, `assert_array_equal` has special handling for scalars. Here the test checks that each value in `x` is 3: @@ -809,7 +824,7 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): The `strict` parameter also ensures that the array data types match: >>> x = np.array([2, 2, 2]) - >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32) >>> np.testing.assert_array_equal(x, y, strict=True) Traceback (most recent call last): ... @@ -881,11 +896,11 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): -------- the first assert does not raise an exception - >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], - ... [1.0,2.333,np.nan]) + >>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan]) - >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], - ... [1.0,2.33339,np.nan], decimal=5) + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5 + ... ) Traceback (most recent call last): ... AssertionError: @@ -897,8 +912,9 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) - >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], - ... [1.0,2.33333, 5], decimal=5) + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5 + ... ) Traceback (most recent call last): ... AssertionError: @@ -1054,8 +1070,8 @@ def assert_string_equal(actual, desired): Examples -------- - >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP - >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP + >>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP + >>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP Traceback (most recent call last): File "", line 1, in ... @@ -1341,11 +1357,11 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): Examples -------- - >>> x = np.array([1., 1e-10, 1e-20]) + >>> x = np.array([1.0, 1e-10, 1e-20]) >>> eps = np.finfo(x.dtype).eps - >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP - >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP Traceback (most recent call last): ... AssertionError: X and Y are not equal to 1 ULP (max is 2) @@ -1404,7 +1420,7 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): Examples -------- - >>> a = np.linspace(0., 1., 100) + >>> a = np.linspace(0.0, 1.0, 100) >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP """ @@ -1562,7 +1578,7 @@ def assert_warns(warning_class, *args, **kwargs): >>> import warnings >>> def deprecated_func(num): ... warnings.warn("Please upgrade", DeprecationWarning) - ... return num*num + ... return num * num >>> with np.testing.assert_warns(DeprecationWarning): ... assert deprecated_func(4) == 16 >>> # or passing a func @@ -1663,19 +1679,29 @@ def inp(): yield out, inp(), ufmt % (o, o, s, dtype, "out of place") d = inp() yield d, d, ufmt % (o, o, s, dtype, "in place") - yield out[1:], inp()[:-1], ufmt % ( - o + 1, - o, - s - 1, - dtype, - "out of place", + yield ( + out[1:], + inp()[:-1], + ufmt + % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ), ) - yield out[:-1], inp()[1:], ufmt % ( - o, - o + 1, - s - 1, - dtype, - "out of place", + yield ( + out[:-1], + inp()[1:], + ufmt + % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ), ) yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") @@ -1691,53 +1717,89 @@ def inp1(): yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") d = inp2() yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") - yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( - o + 1, - o, - o, - s - 1, - dtype, - "out of place", + yield ( + out[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ), ) - yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( - o, - o + 1, - o, - s - 1, - dtype, - "out of place", + yield ( + out[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ), ) - yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % ( - o, - o, - o + 1, - s - 1, - dtype, - "out of place", + yield ( + out[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ), ) - yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( - o + 1, - o, - o, - s - 1, - dtype, - "aliased", + yield ( + inp1()[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ), ) - yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( - o, - o + 1, - o, - s - 1, - dtype, - "aliased", + yield ( + inp1()[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ), ) - yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( - o, - o, - o + 1, - s - 1, - dtype, - "aliased", + yield ( + inp1()[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ), ) @@ -1818,9 +1880,10 @@ class clear_and_catch_warnings(warnings.catch_warnings): -------- >>> import warnings >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP - ... modules=[np.core.fromnumeric]): - ... warnings.simplefilter('always') - ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') + ... modules=[np.core.fromnumeric] + ... ): + ... warnings.simplefilter("always") + ... warnings.filterwarnings("ignore", module="np.core.fromnumeric") ... # do something that raises a warning but ignore those in ... # np.core.fromnumeric """ @@ -1918,6 +1981,8 @@ class suppress_warnings: sup = np.testing.suppress_warnings() sup.filter(module=np.ma.core) # module must match exactly + + @sup def some_function(): # do something which causes a warning in np.ma.core diff --git a/torch/_ops.py b/torch/_ops.py index 4d308dde96544b..337b9a11e6a180 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -6,8 +6,19 @@ import inspect import sys import types -from typing import Any, Callable, final, Optional, TYPE_CHECKING, TypeVar, Union -from typing_extensions import Concatenate, ParamSpec +from collections.abc import Iterator +from functools import cached_property +from typing import ( + Any, + Callable, + ClassVar, + final, + Generic, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import Concatenate, ParamSpec, TypeVar import torch import torch.utils._pytree as pytree @@ -21,8 +32,8 @@ from torch._subclasses.functional_tensor import BaseFunctionalizeAPI -_T = TypeVar("_T") -_P = ParamSpec("_P") +_T = TypeVar("_T", default=Any) +_P = ParamSpec("_P", default=...) # Query `hasattr` only once. @@ -319,6 +330,18 @@ def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T: with torch._C._AutoDispatchBelowAutograd(): return self(*args, **kwargs) + from torch._higher_order_ops.utils import _has_gen_schema + + if _has_gen_schema(self): + schema = self.gen_schema(*args, **kwargs) + if any(arg.is_write for arg in schema.arguments): + raise RuntimeError( + f"The {self.name()} HigherOrderOperator does not currently support training " + "with in-place input or buffer mutations " + "If you require this feature, please submit an issue to PyTorch. " + "Alternatively, consider creating your own custom autograd.Function. " + ) + return fn(*args, **kwargs) self.py_impl(DispatchKey.Autograd)(maybe_run_autograd) @@ -734,8 +757,15 @@ def get_cached_ops(): # Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object. # You can obtain an OpOverload object through attribute query on OpOverloadPacket. -class OpOverload(OperatorBase): - def __init__(self, overloadpacket, op, op_dk, schema, tags): +class OpOverload(OperatorBase, Generic[_P, _T]): + def __init__( + self, + overloadpacket: "OpOverloadPacket", + op: Callable[_P, _T], + op_dk: Callable[Concatenate[DispatchKey, _P], _T], + schema: torch._C.FunctionSchema, + tags: list[Any], + ) -> None: super().__init__() self._op = op self._op_dk = op_dk @@ -755,9 +785,6 @@ def __init__(self, overloadpacket, op, op_dk, schema, tags): op.__module__ = overloadpacket.__module__ self.__qualname__ = self._name self.__annotations__ = {} - # Only compute the OperatorHandle when we need it. Not all OpOverloads have - # OperatorHandles (the TorchScript ones don't...) - self._lazy_handle = None # If the OpOverload was constructed from a Library.def in Python. self._defined_in_python = self.__qualname__ in torch.library._defs @@ -775,40 +802,38 @@ def __init__(self, overloadpacket, op, op_dk, schema, tags): is_write = a.alias_info.is_write or is_write self.is_view = is_write is not None and not is_write - @property - def _namespace(self): - return self._schema.name.split("::")[0] + @cached_property + def _namespace(self) -> str: + return self._schema.name.split("::", maxsplit=1)[0] - @property - def _opname(self): - return self._schema.name.split("::")[1] + @cached_property + def _opname(self) -> str: + return self._schema.name.split("::", maxsplit=1)[1] - @property - def _handle(self): - if self._lazy_handle is None: - self._lazy_handle = torch._C._dispatch_find_schema_or_throw( - self._schema.name, self._schema.overload_name - ) - return self._lazy_handle + @cached_property + def _handle(self) -> torch._C._DispatchOperatorHandle: + return torch._C._dispatch_find_schema_or_throw( + self._schema.name, self._schema.overload_name + ) # it's a no-op since OpOverload object is immutable and must be unique for a given op overload. def __deepcopy__(self, memo=None): return self def __repr__(self): - return "".format( - *self._schema.name.split("::"), self._overloadname - ) + return f"" # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. - def __call__(self, /, *args, **kwargs): + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: return self._op(*args, **kwargs) # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. - def redispatch(self, /, keyset, *args, **kwargs): - return self._handle.redispatch_boxed(keyset, *args, **kwargs) + def redispatch( + self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value] def __hash__(self): return hash(self._op) @@ -817,27 +842,27 @@ def __hash__(self): def __str__(self): return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) - def has_kernel_for_dispatch_key(self, k): + def has_kernel_for_dispatch_key(self, k: DispatchKey) -> bool: return super().has_kernel_for_dispatch_key( k ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) - def has_kernel_for_any_dispatch_key(self, ks): + def has_kernel_for_any_dispatch_key(self, ks: torch._C.DispatchKeySet) -> bool: return torch._C._dispatch_has_kernel_for_any_dispatch_key( self.name(), ks ) or super().has_kernel_for_any_dispatch_key(ks) @property - def namespace(self): - return self._schema.name.split("::")[0] + def namespace(self) -> str: + return self._namespace - def _can_decompose(self): + def _can_decompose(self) -> bool: dk = DispatchKey.CompositeImplicitAutograd return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key( self.name(), dk ) - def decompose(self, *args, **kwargs): + def decompose(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: dk = DispatchKey.CompositeImplicitAutograd if dk in self.py_kernels: # NB: This branch is not too necessary anymore, because we can @@ -858,11 +883,11 @@ def decompose(self, *args, **kwargs): # registering Autograd affects AutogradCPU). del_dispatch is to be used # only if you are specifically modifying how get_dispatch handles a # particular input 'key'. - def _uncache_dispatch(self, key): + def _uncache_dispatch(self, key: DispatchKey) -> None: self._dispatch_cache.pop(key, None) # This implements the pre-computation logic for the Python dispatcher. - def _get_dispatch(self, key): + def _get_dispatch(self, key: DispatchKey) -> Union[DispatchKey, Callable[_P, _T]]: # This is only called upon a cache miss assert key not in self._dispatch_cache, f"{self} {key}" @@ -872,7 +897,7 @@ def _get_dispatch(self, key): add_cached_op(self) return key - def handler(*args, **kwargs): + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: from torch.utils._python_dispatch import _get_current_dispatch_mode # TODO: We also need to handle tensor subclasses here @@ -912,7 +937,7 @@ def handler(*args, **kwargs): ) ): - def handler(*args, **kwargs): + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: @contextlib.contextmanager def _temporarily_pop_modes_from_pre_dispatch(): top_mode = _pop_mode_from_pre_dispatch() @@ -945,7 +970,7 @@ def _temporarily_pop_modes_from_pre_dispatch(): import torch._dispatch.python as pydispatch if pydispatch.CROSSREF_FUNCTIONALIZE: - handler = pydispatch.make_crossref_functionalize(self, final_key) + handler = pydispatch.make_crossref_functionalize(self, final_key) # type: ignore[assignment] if cache_result: self._dispatch_cache[key] = handler add_cached_op(self) @@ -979,7 +1004,7 @@ def tags(self): # schema consists of torch.ScriptObject (i.e. custom class) input. # TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python # when its inputs contain FakeScriptObject in a similar way as higher order ops. -class TorchBindOpOverload(OpOverload): +class TorchBindOpOverload(OpOverload[_P, _T]): def _fallthrough_keys(self) -> list[DispatchKey]: # TODO: we should be calling the fallback for these, but a fallthrough is almost close # enough to the fallback in most cases that we care about. @@ -1028,7 +1053,7 @@ def _register_as_effectful_op_temporarily(self): # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. - def __call__(self, /, *args, **kwargs): + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: if _must_dispatch_in_python(args, kwargs): # When any inputs are FakeScriptObject, we need to # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher @@ -1041,10 +1066,14 @@ def __call__(self, /, *args, **kwargs): # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction. with self._register_as_effectful_op_temporarily(): - return self._dispatch_in_python(args, kwargs, self._fallthrough_keys()) + return self._dispatch_in_python( + self._fallthrough_keys(), *args, **kwargs + ) return self._op(*args, **kwargs) - def _dispatch_in_python(self, args, kwargs, fallthrough_keys): + def _dispatch_in_python( + self, fallthrough_keys: list[DispatchKey], *args: _P.args, **kwargs: _P.kwargs + ) -> _T: non_fallthrough_keys = torch._C._dispatch_keyset_full() for key in fallthrough_keys: non_fallthrough_keys = non_fallthrough_keys.remove(key) @@ -1065,7 +1094,9 @@ def _dispatch_in_python(self, args, kwargs, fallthrough_keys): self.name(), dispatch_key ): return self._dispatch_in_python( - args, kwargs, fallthrough_keys + [dispatch_key] + fallthrough_keys + [dispatch_key], + *args, + **kwargs, ) raise RuntimeError( @@ -1097,15 +1128,23 @@ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool: # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator # You can obtain an OpOverload object through attribute query. -class OpOverloadPacket: - def __init__(self, qualified_op_name, op_name, op, overload_names): +class OpOverloadPacket(Generic[_P, _T]): + __file__: ClassVar[str] = "torch.ops" + + def __init__( + self, + qualified_op_name: str, + op_name: str, + op: Callable[_P, _T], + overload_names: list[str], + ) -> None: # These attributes are accessible on the object through the properties # defined below but are immutable self._qualified_op_name = qualified_op_name self.__name__ = op_name self._op = op self._overload_names = overload_names - self._dir = [] + self._dir: list[str] = [] self._has_torchbind_op_overload = any( _has_script_object_arg(schema) for schema in self._schemas.values() ) @@ -1136,11 +1175,7 @@ def _schemas(self): for overload_name in self._overload_names } - def __getattr__(self, key) -> Any: - # It is not a valid op_name when __file__ is passed in - if key == "__file__": - return "torch.ops" - + def __getattr__(self, key: str) -> OpOverload[_P, _T]: # ensure that query for dunder attributes that does not exist on # opoverloadpacket but instead exists on the self._op object does not unnecessarily call # `_get_operation_overload` (which is an expensive operation). @@ -1175,7 +1210,7 @@ def __getattr__(self, key) -> Any: op_, op_dk_, tags = op_dk_tags schema = torch._C._get_schema(self._qualified_op_name, use_key) - overload = ( + overload: OpOverload[_P, _T] = ( OpOverload(self, op_, op_dk_, schema, tags) if not _has_script_object_arg(schema) else TorchBindOpOverload(self, op_, op_dk_, schema, tags) @@ -1189,12 +1224,12 @@ def __getattr__(self, key) -> Any: f"The underlying op of '{str(self)}' has no overload name '{key}'" ) from None - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._dir) # Use positional-only argument to avoid naming collision with aten ops arguments # that are named "self". This way, all the aten ops can be called by kwargs. - def __call__(self, /, *args, **kwargs): + def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: # overloading __call__ to ensure torch.ops.foo.bar() # is still callable from JIT # We save the function ptr as the `op` attribute on @@ -1204,8 +1239,8 @@ def __call__(self, /, *args, **kwargs): # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we # intercept it here and call TorchBindOpverload instead. if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): - return _call_overload_packet_from_python(self, args, kwargs) - return self._op(*args, **(kwargs or {})) + return _call_overload_packet_from_python(self, *args, **kwargs) + return self._op(*args, **kwargs) # TODO: use this to make a __dir__ def overloads(self): @@ -1214,7 +1249,9 @@ def overloads(self): # Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp # _jit_get_operations, which calls _get_operation_for_overload_or_packet. -def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs): +def _call_overload_packet_from_python( + op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs +) -> _T: # Re-use the torch function handling logic in cpp torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( op, *args, **kwargs @@ -1288,19 +1325,18 @@ class _OpNamespace(types.ModuleType): operation will already exist). """ - def __init__(self, name): + __file__ = "torch.ops" + + def __init__(self, name: str) -> None: super().__init__("torch.ops." + name) self.name = name - self._dir = [] + self._dir: list[str] = [] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._dir) - def __getattr__(self, op_name) -> Any: - # It is not a valid op_name when __file__ is passed in - if op_name == "__file__": - return "torch.ops" - elif op_name in ["__origin__", "__self__"]: + def __getattr__(self, op_name: str) -> OpOverloadPacket: + if op_name in ("__origin__", "__self__"): raise AttributeError( f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'" ) @@ -1353,19 +1389,25 @@ def _refresh_packet(packet): packet._overload_names = overload_names -class _PyOpNamespace(_OpNamespace): - def __init__(self, name, ops): - super().__init__(name) - self._ops = ops +class _HigherOrderNamespace(types.ModuleType): + __file__ = "torch.ops" + + def __init__(self) -> None: + super().__init__("torch.ops.higher_order") + self._dir: list[str] = [] + + def __iter__(self) -> Iterator[str]: + return iter(self._dir) - def __getattr__(self, name): - # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object. - op = self._ops.get(name, None) + def __getattr__(self, name: str) -> HigherOrderOperator: + # Following _OpNamespace.__getattr__, we cache the op on this object. + op = _higher_order_ops.get(name, None) if op is None: raise AttributeError( - f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'" + f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'" ) setattr(self, name, op) + self._dir.append(name) return op @@ -1375,23 +1417,17 @@ class _Ops(types.ModuleType): def __init__(self): super().__init__("torch.ops") self.loaded_libraries = set() - self._higher_order_op_namespace = _PyOpNamespace( - "torch.ops.higher_order", _higher_order_ops - ) + self.higher_order = _HigherOrderNamespace() self._dir = [] - def __getattr__(self, name): - # Check if the name is a HigherOrderOperator - if name == "higher_order": - return self._higher_order_op_namespace - + def __getattr__(self, name: str) -> _OpNamespace: # Here we are creating `torch.ops.my_namespace` namespace = _OpNamespace(name) setattr(self, name, namespace) self._dir.append(name) return namespace - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._dir) def import_module(self, module): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 0fed56ebc3fc16..93c9e5ffb10115 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -2513,7 +2513,11 @@ def _full_aten( ) -> Tensor: # Note that Mypy thinks torch.full can't accept a complex fill_value return torch.full( - shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + shape, + fill_value, + dtype=dtype, + device=device, + requires_grad=requires_grad, # type: ignore[arg-type] ) @@ -2556,7 +2560,11 @@ def _full_like_aten( ) -> Tensor: # Note that Mypy thinks torch.full can't accept a complex fill_value return torch.full_like( - a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + a, + fill_value, + dtype=dtype, + device=device, + requires_grad=requires_grad, # type: ignore[arg-type] ) diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 36cb40e79165c9..30bc1f85c0ee38 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,8 +1,13 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import functools -from collections.abc import Sequence from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec + + +if TYPE_CHECKING: + from collections.abc import Sequence import torch import torch._decomp @@ -15,8 +20,12 @@ from torch._prims_common import torch_function_passthrough -@functools.lru_cache(None) -def torch_to_refs_map(): +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@functools.cache +def torch_to_refs_map() -> dict[Any, Any]: """ Mapping of torch API functions to torch._refs functions. E.g. torch_to_refs_map()[torch.add] == torch._refs.add @@ -70,8 +79,8 @@ def torch_to_refs_map(): return r -@functools.lru_cache(None) -def all_prims(): +@functools.cache +def all_prims() -> set[Any]: """ Set of all prim functions, e.g., torch._prims.add in all_prims() """ @@ -95,21 +104,21 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode): def __init__( self, - strict=False, - should_fallback_fn=lambda *_: False, - prims_mode_cls=nullcontext, - ): + strict: bool = False, + should_fallback_fn: Callable[..., bool] = lambda *_: False, + prims_mode_cls: type = nullcontext, + ) -> None: self.strict = strict self.should_fallback_fn = should_fallback_fn self.prims_mode_cls = prims_mode_cls def __torch_function__( self, - orig_func: Callable, - types: Sequence, + orig_func: Callable[_P, _R], + types: Sequence[type], args: Sequence[Any] = (), - kwargs: Optional[dict] = None, - ): + kwargs: dict[str, Any] | None = None, + ) -> Any: if kwargs is None: kwargs = {} # For primitive operations, run them as is without interception diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py index 6958cbcef283db..d52462815229bd 100644 --- a/torch/_prims/debug_prims.py +++ b/torch/_prims/debug_prims.py @@ -1,5 +1,5 @@ -# mypy: allow-untyped-defs import contextlib +from collections.abc import Generator, Sequence from typing import Optional import torch @@ -10,7 +10,7 @@ @contextlib.contextmanager -def load_tensor_reader(loc): +def load_tensor_reader(loc: str) -> Generator[None, None, None]: global LOAD_TENSOR_READER assert LOAD_TENSOR_READER is None # load_tensor is an "op", and we will play merry hell on @@ -26,14 +26,20 @@ def load_tensor_reader(loc): LOAD_TENSOR_READER = None -def register_debug_prims(): +def register_debug_prims() -> None: torch.library.define( "debugprims::load_tensor", "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor", ) @torch.library.impl("debugprims::load_tensor", "BackendSelect") - def load_tensor_factory(name, size, stride, dtype, device): + def load_tensor_factory( + name: str, + size: Sequence[int], + stride: Sequence[int], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: if LOAD_TENSOR_READER is None: from torch._dynamo.testing import rand_strided @@ -50,5 +56,5 @@ def load_tensor_factory(name, size, stride, dtype, device): # Unlike the other properties, we will do coercions for dtype # mismatch if r.dtype != dtype: - r = clone_input(r, dtype=dtype) + r = clone_input(r, dtype=dtype) # type: ignore[no-untyped-call] return r diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index 70b4bc472358ad..e6ed4a4e3ea6bb 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Optional +from typing import cast, Optional import torch import torch.utils._pytree as pytree @@ -69,12 +69,10 @@ def philox_rand_offset( curand4_engine_calls = 4 device_property = torch.cuda.get_device_properties(torch.cuda.current_device()) blocks_per_sm = device_property.max_threads_per_multi_processor // block_size - grid_size = (numel + block_size - 1) // block_size + num = cast(int, numel) + grid_size = (num + block_size - 1) // block_size grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm) - offset = ( - (numel - 1) // (block_size * grid_size * unroll) + 1 - ) * curand4_engine_calls - return offset + return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls def register_philox_rand(): @@ -342,9 +340,9 @@ def impl_cuda(op, *args, rng_state=None, **kwargs): @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(op, *args, rng_state=None, **kwargs): device = get_device(args, kwargs) - assert ( - device == "cuda" - ), f"GraphSafe RNG operations only supported for CUDA, got {device}" + assert device == "cuda", ( + f"GraphSafe RNG operations only supported for CUDA, got {device}" + ) return impl_cuda(op, *args, rng_state=rng_state, **kwargs) @graphsafe_run_with_rng_state.py_impl(FakeTensorMode) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 79ea646964f07c..1ba3c34d512b1a 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -33,17 +33,13 @@ import sympy class _WorksWithInt(typing.Protocol): - def __add__(self, other: Any) -> typing.Self: - ... + def __add__(self, other: Any) -> typing.Self: ... - def __radd__(self, other: Any) -> typing.Self: - ... + def __radd__(self, other: Any) -> typing.Self: ... - def __mul__(self, other: Any) -> typing.Self: - ... + def __mul__(self, other: Any) -> typing.Self: ... - def __rmul__(self, other: Any) -> typing.Self: - ... + def __rmul__(self, other: Any) -> typing.Self: ... _IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt) @@ -270,6 +266,7 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: guard_or_false, guard_or_true, guard_size_oblivious, + is_nested_int, ) maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious @@ -279,17 +276,25 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: return True expected_stride = 1 + expected_stride_max = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): - # Skips checking strides when a dimension has length 1 + # Skips checking strides when a dimension has length 1. if maybe_guard_or_false(x == 1): continue - if maybe_guard_or_true(y != expected_stride): + if maybe_guard_or_true(y != expected_stride) and maybe_guard_or_true( + y != expected_stride_max + ): return False - # if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can - # can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for. - expected_stride = expected_stride * sym_max(x, 1) + # We symbolically check both paths to maximize the cases where this function + # returns true. This is because make_contiguous_strides_for adds the max + # symbolically, and in some other situations the max might not be there. + # And we want to ensure we return true in both cases. + expected_stride_max *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment] + + expected_stride *= x return True @@ -386,22 +391,22 @@ def is_contiguous_for_memory_format( # type: ignore[return] ) -def definitely_contiguous(a: TensorLikeType) -> bool: +def is_contiguous_or_false(a: TensorLikeType) -> bool: return is_contiguous(a, false_if_dde=True) # similar to is_channels_last_contiguous_2d but return false on data dependency. -def is_known_channels_last_contiguous_2d(a: Tensor) -> bool: +def is_channels_last_contiguous_or_false_2d(a: Tensor) -> bool: return is_channels_last_contiguous_2d(a, false_if_dde=True) # similar to is_channels_last_contiguous_3d but return false on data dependency. -def is_known_channels_last_contiguous_3d(a: Tensor) -> bool: +def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool: return is_channels_last_contiguous_3d(a, false_if_dde=True) # similar to is_contiguous_for_memory_format but return false on data dependency. -def definitely_contiguous_for_memory_format( # type: ignore[return] +def contiguous_for_memory_format_or_false( # type: ignore[return] a: Tensor, *, memory_format: torch.memory_format ) -> bool: return is_contiguous_for_memory_format( @@ -427,10 +432,10 @@ def is_channels_last_contiguous(a: Tensor) -> bool: # similar to is_channels_last_contiguous but return false on data dependency. -def is_known_channels_last_contiguous(a: Tensor) -> bool: - return is_known_channels_last_contiguous_2d( +def is_channels_last_contiguous_or_false(a: Tensor) -> bool: + return is_channels_last_contiguous_or_false_2d( a - ) or is_known_channels_last_contiguous_3d(a) + ) or is_channels_last_contiguous_or_false_3d(a) def is_non_overlapping_and_dense(a: Tensor) -> bool: @@ -447,7 +452,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: return False # Short-circuits if the tensor is already contiguous or channels-last contiguous - if is_contiguous(a) or is_channels_last_contiguous(a): + if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a): return True # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp @@ -542,10 +547,10 @@ def compute_elementwise_output_logical_to_physical_perm( is_contiguous = True is_channels_last = True for t in tensors: - is_contiguous = is_contiguous and definitely_contiguous_for_memory_format( + is_contiguous = is_contiguous and contiguous_for_memory_format_or_false( t, memory_format=torch.contiguous_format ) - is_channels_last = is_channels_last and definitely_contiguous_for_memory_format( + is_channels_last = is_channels_last and contiguous_for_memory_format_or_false( t, memory_format=torch.channels_last ) @@ -907,7 +912,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: # Extracts dimensions that might be passed either as a list/tuple or as varargs. # A typical case is Tensor.permute . def extract_dims_from_varargs( - dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]] + dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]], ) -> DimsSequenceType: if dims and isinstance(dims[0], Sequence): assert len(dims) == 1 @@ -1229,7 +1234,7 @@ def get_higher_dtype( assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) def _extract_dtype( - x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]], ) -> Optional[torch.dtype]: if x is None: return None @@ -1447,7 +1452,7 @@ class RETURN_TYPE(Enum): # TODO: when NumberType contains the sym types, can simplify this def number_type( - x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool] + x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool], ) -> type: if isinstance(x, torch.SymInt): return int @@ -1703,9 +1708,7 @@ def make_contiguous_strides_for( strides = [] for l in reversed(shape): strides.append(multiplier) - multiplier *= ( - l if is_nested_int(l) else sym_max(l, 1) - ) # type:ignore[assignment] + multiplier *= l if is_nested_int(l) else sym_max(l, 1) # type:ignore[assignment] result = tuple(reversed(strides)) @@ -1855,7 +1858,9 @@ def compute_required_storage_length( >>> # xdoctest: +SKIP(failing) >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) - >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size = compute_required_storage_length( + ... t2.shape, t2.stride(), t2.storage_offset() + ... ) >>> size == t.storage().size() True @@ -1865,7 +1870,9 @@ def compute_required_storage_length( >>> slice.storage().size() 100 - >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + >>> compute_required_storage_length( + ... slice.shape, slice.stride(), slice.storage_offset() + ... ) 40 """ diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index e08b5937260491..2ccba3c28c132d 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -316,8 +316,7 @@ def maybe_check_copy_devices(out): and len(result) == len(out_names) # type: ignore[arg-type] ) or ( - fn.__name__ == "unbind" - and isinstance(result, (list, tuple)) # type: ignore[arg-type] + fn.__name__ == "unbind" and isinstance(result, (list, tuple)) # type: ignore[arg-type] ) ) # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829 @@ -342,9 +341,15 @@ def maybe_check_copy_devices(out): assert isinstance(out, TensorLike) # These two operations are done in-place _maybe_resize_out( - out, result.shape, maybe_compute_memory_format(result) # type: ignore[union-attr] + out, + result.shape, # type: ignore[union-attr] + maybe_compute_memory_format(result), + ) + _safe_copy_out( + copy_from=result, # type: ignore[arg-type] + copy_to=out, + exact_dtype=exact_dtype, ) - _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] else: if fn.__name__ != "unbind": assert isinstance(out, tuple) # type: ignore[arg-type] @@ -370,7 +375,9 @@ def maybe_check_copy_devices(out): annotation=out_type, ) # Mark that the function now returns a tuple - assert isinstance(sig.return_annotation, str) or sig.return_annotation in ( + assert isinstance( + sig.return_annotation, (str, TypeVar) + ) or sig.return_annotation in ( sig.empty, out_type, bc_out_type, @@ -383,7 +390,8 @@ def maybe_check_copy_devices(out): params = sorted(params, key=lambda p: p.kind) _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] - parameters=params, return_annotation=return_type # type: ignore[arg-type] + parameters=params, + return_annotation=return_type, # type: ignore[arg-type] ) _fn.__annotations__ = dict(getattr(fn, "__annotations__", {})) @@ -398,7 +406,9 @@ def maybe_check_copy_devices(out): # Add an indicator attribute that can be used in special cases # where having a function wrapped by `out_wrapper` is not desirable e.g. # jit - _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] + _fn._torch_decompositions_out_wrapper = ( # type: ignore[attr-defined] + f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" + ) return _fn diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index 2dfdbb296a4b2d..d2d4fbbf621e56 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -92,10 +92,10 @@ def keys(self): """ def register(self, dispatchKeys): - # Overriden is not supported and triggers a warning in C++ dispatcher. + # Overridden is not supported and triggers a warning in C++ dispatcher. if len(set(dispatchKeys)) != len(dispatchKeys): raise RuntimeError( - f"Overriden is not allowed but found duplicates in {dispatchKeys}." + f"Overridden is not allowed but found duplicates in {dispatchKeys}." ) # We currently forbid this in codegen instead of C++ dispatcher. if ( diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e128a3b5f81f66..3b2344f44b9d72 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -19,7 +19,7 @@ from torch import sym_float, sym_int from torch._prims_common import ( BoolLike, - definitely_contiguous, + contiguous_for_memory_format_or_false, DeviceLikeType, Dim, DimsSequenceType, @@ -29,6 +29,7 @@ FloatLike, FloatWithoutSymFloat, IntLike, + is_contiguous_or_false, is_weakly_lesser_type, Number, NumberType, @@ -384,7 +385,7 @@ def handle_noncontiguous_outputs(input_tlist, output): def _broadcast_shapes(*_shapes): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_false shapes = tuple( (x,) if isinstance(x, IntLike) else x @@ -406,13 +407,20 @@ def _broadcast_shapes(*_shapes): ] * reduce(max, (len(shape) for shape in shapes)) for arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): - if guard_size_oblivious(common_shape[idx] == 1): + # if both 1, or statically known the same, we rather pick non-broadcast path. + if guard_or_false(common_shape[idx] == shape[idx]): + continue + elif guard_or_false(common_shape[idx] == 1): if shape[idx] < 0: raise ValueError( "Attempting to broadcast a dimension with negative length!" ) common_shape[idx] = shape[idx] - elif guard_size_oblivious(shape[idx] != 1): + elif guard_or_false(shape[idx] == 1): + # broadcast case . + continue + else: + # If broadcasting is undecided we pick non-broadcast path and add runtime assertion. torch._check( common_shape[idx] == shape[idx], lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " @@ -2237,17 +2245,21 @@ def _reduction( return result -def _make_copy_from_view(fn): +def _make_copy_from_view(fn, return_none_on_out_variant=False): """ Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy) """ aten_fn = getattr(aten, fn.__name__) annotations = getattr(fn, "__annotations__", {}) - fn = out_wrapper()(aten_fn) + # view ops should not change dtypes, this ensures that the decomp path has + # the same error checks as eager. + fn = out_wrapper(exact_dtype=True)(aten_fn) @wraps(fn) def _fn(*args, out=None, **kwargs): result = fn(*args, out=out, **kwargs) + if return_none_on_out_variant and out is not None: + return None if out is not None: return result @@ -2768,7 +2780,10 @@ def cat_compute_output_memory_format(inputs): utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + ) # This is a bit tricky. Naively, you would expect to just pick one # arbitrary tensor and check that all tensors match this tensor. However, @@ -2829,7 +2844,7 @@ def cat_compute_output_memory_format(inputs): ) else: # Remove inputs that are 1-D, zero size - if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0): + if tensor.ndim == 1 and guard_or_false(tensor.shape[0] == 0): continue # Don't bother checking size match, prims.cat will handle it filtered.append(tensor) @@ -2975,7 +2990,8 @@ def contiguous( lambda: "preserve memory format is unsupported by the contiguous operator", ) - if utils.is_contiguous_for_memory_format(a, memory_format=memory_format): + # TODO: make logic consistent with aten contiguous + if contiguous_for_memory_format_or_false(a, memory_format=memory_format): return a return torch.clone(a, memory_format=memory_format) @@ -3298,11 +3314,11 @@ def native_layer_norm( + str(input.shape), ) - input = input.contiguous() + input = contiguous(input) if weight is not None: - weight = weight.contiguous() + weight = contiguous(weight) if bias is not None: - bias = bias.contiguous() + bias = contiguous(bias) axis = input.ndim - normalized_ndim reduction_dims = list(range(axis, input.ndim)) @@ -3730,118 +3746,9 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor: return permuted_result.reshape(target_shape) -# this function is python match of computeStride_impl in TensorUtils.cpp -def _compute_stride(old_shape, old_stride, new_shape): - from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - guard_or_true, - sym_eq, - ) - - if len(old_shape) == 0: - return [1] * len(new_shape) - - numel = reduce(operator.mul, old_shape, 1) - zero_numel = guard_or_false(numel == 0) - if zero_numel and guard_or_false(sym_eq(old_shape, new_shape)): - return old_stride - - new_stride = [0] * len(new_shape) - - if zero_numel: - for view_d in range(len(new_shape) - 1, -1, -1): - if view_d == len(new_shape) - 1: - new_stride[view_d] = 1 - else: - new_stride[view_d] = ( - max(new_shape[view_d + 1], 1) * new_stride[view_d + 1] - ) - return new_stride - - view_d = len(new_shape) - 1 - chunk_base_stride = old_stride[-1] - tensor_numel = 1 - view_numel = 1 - - for tensor_d in range(len(old_shape) - 1, -1, -1): - tensor_numel *= old_shape[tensor_d] - - if tensor_d == 0 or ( - guard_or_true(old_shape[tensor_d - 1] != 1) - and guard_or_true( - old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride - ) - ): - while view_d >= 0 and ( - guard_or_true(view_numel < tensor_numel) - or guard_or_false(new_shape[view_d] == 1) - ): - new_stride[view_d] = view_numel * chunk_base_stride - view_numel *= new_shape[view_d] - view_d -= 1 - - if guard_or_true(view_numel != tensor_numel): - return None - - if tensor_d > 0: - chunk_base_stride = old_stride[tensor_d - 1] - tensor_numel = 1 - view_numel = 1 - if view_d != -1: - return None - return new_stride - - -# This function is called to trace through view operation during fake tensor tracing. -# It will be called when the exisiting path throws a data dependent error. It's much -# simpler that reshape_view_helper, if it fails it will throw the original data_dependent_error -# that was passed to it. -# The function does the following: -# (1) if _compute_stride succeeds, the requested shape is valid, the output strides are those -# returned by _compute_stride. -# (2) if a contiguous, we know the requested shape is valid, the output strides can be computed using -# make_contiguous_strides_for. -def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeType: - from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq - - # Creates a valid shape - shape = utils.extract_shape_from_varargs(shape, validate=False) - - # Reshape may be given a shape with a -1 length - # This indicates that the dimension's length should be inferred - shape = utils.infer_size(shape, a.numel()) - - # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape - shape_numel = reduce(operator.mul, shape, 1) - torch._check( - a.numel() == shape_numel, - f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", - ) - - if len(shape) == len(a.shape) and statically_known_true(sym_eq(shape, a.shape)): - return prims.view_of(a) - - new_strides = _compute_stride(a.size(), a.stride(), shape) - if new_strides is not None: - return a.as_strided(shape, new_strides) - - if definitely_contiguous(a): - return a.as_strided(shape, utils.make_contiguous_strides_for(shape)) - - raise data_dependent_error - - def _reshape_view_helper_core_alg( a: TensorLikeType, shape, allow_copy: bool ) -> TensorLikeType: - from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - guard_or_true, - GuardOnDataDependentSymNode, - ) - - deferred: list[Callable[[], bool]] = [] - # NOTE [Reshape Algorithm] # This algorithm works by attempting to greedily construct the desired dimensions in # the output shape, left to right. It does this by, conceptually, accumulating @@ -3873,34 +3780,13 @@ def _reshape_view_helper_core_alg( continue # Skips dimensions that are already the correct length - if guard_or_false(length == a_.shape[idx]): + if length == a_.shape[idx]: idx = idx + 1 continue - # Gathers enough original dimensions such that this new dimension can be created - # Note that this accumulation will terminate because we've verified a and the shape - # specify the same number of elements above - def maybe_throw_dde(): - # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input - # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing - # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd - # figure out a valid reshape later in the loop. - # But we failed, either by running out of dimensions, or we couldn't figure out the strides, - # and we've decided to re-raise to either graph break out, or provide the exact guard so the user - # can torch._check() to avoid this. - for f in deferred: - f() - accum = a_.shape[idx] end = idx - while True: - try: - if accum % length == 0: - break - except GuardOnDataDependentSymNode: - deferred.append(lambda: bool(accum % length == 0)) - if end == a_.ndim - 1: - maybe_throw_dde() + while accum % length != 0: end += 1 accum *= a_.shape[end] if end != idx: @@ -3914,7 +3800,6 @@ def maybe_throw_dde(): if allow_copy: return prims.reshape(a, shape) - maybe_throw_dde() msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" raise ValueError(msg) @@ -3922,7 +3807,7 @@ def maybe_throw_dde(): # Splits the (possibly flattened) dimension to create the desired dim length. # guard_or_true is safe due to the tail unsqueeze routine. - if guard_or_true(accum != length): + if accum != length: a_ = prims.split_dim(a_, idx, length) idx = idx + 1 @@ -3942,11 +3827,6 @@ def maybe_throw_dde(): def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: - from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - GuardOnDataDependentSymNode, - ) - # Creates a valid shape shape = utils.extract_shape_from_varargs(shape, validate=False) # Reshape may be given a shape with a -1 length @@ -3954,7 +3834,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL shape = utils.infer_size(shape, a.numel()) # Special-cases tensors with no elements - if guard_or_false(a.numel() == 0): + if a.numel() == 0: return as_strided(a, shape, utils.make_contiguous_strides_for(shape)) # Special-cases reshaping zero dim tensors @@ -3979,7 +3859,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL else: return _a - if a.is_contiguous(): + if is_contiguous_or_false(a): # Special-cases for nd_to_1d if len(shape) == 1 and a.ndim > 1: return torch.as_strided(a, [a.numel()], [1]) @@ -3994,14 +3874,9 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL a.numel() == shape_numel, f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", ) - try: - # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape - return _reshape_view_helper_core_alg(a, shape, allow_copy) - except GuardOnDataDependentSymNode as e: - # For compile this function is only called on view operations since reshape_symint will do a clone and - # compose to view before calling this. GuardOnDataDependentSymNode does not show up for eager. - assert not allow_copy - return _view_simple(a, shape, e) + + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape + return _reshape_view_helper_core_alg(a, shape, allow_copy) # CompositeImplicitAutograd - don't register decomp @@ -4171,14 +4046,15 @@ def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: @register_decomposition(aten.unbind) def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - dim = utils.canonicalize_dim(t.ndim, dim) torch._check_index( len(t.shape) > 0, lambda: "Dimension specified as 0 but tensor has no dimensions", ) - if guard_size_oblivious(t.shape[dim] == 0): + + # Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail + # later in the split since t.shape[dim] control the number of output tensors. + if t.shape[dim] == 0: return () else: return tuple( @@ -6626,7 +6502,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) permute_copy = _make_copy_from_view(aten.permute) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) -unbind_copy = _make_copy_from_view(aten.unbind) +unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) view_copy = _make_copy_from_view(aten.view) diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index c95a5bab02f2e6..e12e4c8e603ba5 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -313,7 +313,8 @@ def _canonicalize_fft_shape_and_dim_args( # Translate any -1 values in shape to the default length ret_shape = tuple( - s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] + s if s != -1 else input_sizes[d] + for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] ) elif dim is None: # No shape, no dim diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index b1520261d2fa4a..418691fe24aaa9 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -135,7 +135,7 @@ def vector_norm( *, dtype: Optional[torch.dtype] = None, ) -> Tensor: - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import guard_or_false check_fp_or_complex(x.dtype, "linalg.vector_norm") @@ -170,7 +170,7 @@ def vector_norm( if (dim is None and x.numel() == 1) or ( dim is not None - and (x.ndim > 0 and all(guard_size_oblivious(x.shape[d] == 1) for d in dim)) + and (x.ndim > 0 and all(guard_or_false(x.shape[d] == 1) for d in dim)) ): if x.ndim > 64: raise RuntimeError( diff --git a/torch/_storage_docs.py b/torch/_storage_docs.py index d1dbad078d9af8..f0d16bc4250ffb 100644 --- a/torch/_storage_docs.py +++ b/torch/_storage_docs.py @@ -20,7 +20,7 @@ def add_docstr_all(method, docstr): add_docstr_all( "from_file", """ -from_file(filename, shared=False, size=0) -> Storage +from_file(filename, shared=False, nbytes=0) -> Storage Creates a CPU storage backed by a memory-mapped file. @@ -28,15 +28,15 @@ def add_docstr_all(method, docstr): All changes are written to the file. If ``shared`` is ``False``, then the changes on the storage do not affect the file. -``size`` is the number of elements in the storage. If ``shared`` is ``False``, -then the file must contain at least ``size * sizeof(Type)`` bytes -(``Type`` is the type of storage, in the case of an ``UnTypedStorage`` the file must contain at -least ``size`` bytes). If ``shared`` is ``True`` the file will be created if needed. +``nbytes`` is the number of bytes of storage. If ``shared`` is ``False``, +then the file must contain at least ``nbytes`` bytes. If ``shared`` is +``True`` the file will be created if needed. (Note that for ``UntypedStorage`` +this argument differs from that of ``TypedStorage.from_file``) Args: filename (str): file name to map shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the underlying `mmap(2) call `_) - size (int): number of elements in the storage + nbytes (int): number of bytes of storage """, ) diff --git a/torch/_strobelight/cli_function_profiler.py b/torch/_strobelight/cli_function_profiler.py index 4fe133cafc03ed..29150b43f9f4bc 100644 --- a/torch/_strobelight/cli_function_profiler.py +++ b/torch/_strobelight/cli_function_profiler.py @@ -310,7 +310,7 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( - work_function: Callable[_P, _R] + work_function: Callable[_P, _R], ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index f8a4da69970bfc..8cc9cae224ef45 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -12,6 +12,7 @@ from torch._dispatch.python import no_python_dispatcher from torch._ops import OpOverload from torch._prims_common import ( + contiguous_for_memory_format_or_false, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, is_boolean_dtype, @@ -68,6 +69,8 @@ def is_noncontiguous_supported(device): aten.randn_like.default, aten.randn_like.out, aten.randint_like.default, + aten.randint_like.Tensor, + aten.randint_like.Tensor_out, aten.randint_like.out, aten.randint_like.low_dtype, aten.randint_like.low_dtype_out, @@ -111,7 +114,7 @@ def contains_tensor_types(type): ) -@functools.lru_cache(None) +@functools.cache def _is_tensor_constructor(func: OpOverload): assert isinstance(func, OpOverload) schema = func._schema @@ -126,9 +129,9 @@ def _is_tensor_constructor(func: OpOverload): def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): def impl_decorator(op_impl): if isinstance(run_impl_check, OpOverload): - assert ( - run_impl_check not in op_implementations_dict - ), f"duplicate registration: {run_impl_check}" + assert run_impl_check not in op_implementations_dict, ( + f"duplicate registration: {run_impl_check}" + ) op_implementations_dict[run_impl_check] = op_impl elif isinstance(run_impl_check, (list, tuple)): for op in run_impl_check: @@ -572,25 +575,25 @@ def assert_tensor_metadata( layout=None, ) -> None: if sizes is not None: - assert ( - t.size() == sizes - ), f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" + assert t.size() == sizes, ( + f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" + ) if strides is not None: - assert ( - t.stride() == strides - ), f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}" + assert t.stride() == strides, ( + f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}" + ) if dtype is not None: - assert ( - t.dtype == dtype - ), f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}" + assert t.dtype == dtype, ( + f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}" + ) if layout is not None: - assert ( - t.layout == layout - ), f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}" + assert t.layout == layout, ( + f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}" + ) if device is not None: - assert ( - t.device == device - ), f"Tensor device mismatch! Expected: {device}, Got: {t.device}" + assert t.device == device, ( + f"Tensor device mismatch! Expected: {device}, Got: {t.device}" + ) # NB: this must be ordered after local_scalar_dense @@ -634,8 +637,11 @@ def has_meta(func): return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") +# These are for the `torch._foreach_...` ops like `torch._foreach_add`. @register_op_impl( - lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func) + lambda func: is_builtin(func) + and func.name().startswith("aten::_foreach_") + and has_meta(func) ) def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): tensor_lists = [ @@ -956,7 +962,7 @@ def slow(msg): final_shape = infer_size(final_shape, shape) assert final_shape is not None - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq # Do some extra safety checks to see if the output # stride is obvious @@ -964,10 +970,12 @@ def slow(msg): if ( isinstance(op, torch.Tensor) and len(op.shape) == len(final_shape) - and guard_size_oblivious(sym_eq(op.shape, final_shape)) + # take the slow path if result is not determined. + and guard_or_false(sym_eq(op.shape, final_shape)) ): break else: + # if we never break in the for loop above we take the slow path. return slow("both tensors nontrivially broadcast") # compute_types @@ -1010,22 +1018,29 @@ def slow(msg): return slow("error") # compute_fast_setup_type - is_contiguous = True - is_channels_last = True - # TODO: is_non-overlapping_and_dense (not bound from Python + definitely_contiguous = True + definitely_channels_last = True + + # TODO: is_non-overlapping_and_dense not bound from Python # no inplace, no out, everything defined if is_noncontiguous_supported(common_device): for op in operands: if not isinstance(op, torch.Tensor): continue - is_contiguous = is_contiguous and op.is_contiguous( - memory_format=torch.contiguous_format + definitely_contiguous = ( + definitely_contiguous + and contiguous_for_memory_format_or_false( + op, memory_format=torch.contiguous_format + ) ) - is_channels_last = is_channels_last and op.is_contiguous( - memory_format=torch.channels_last + definitely_channels_last = ( + definitely_channels_last + and contiguous_for_memory_format_or_false( + op, memory_format=torch.channels_last + ) ) - if is_contiguous: + if definitely_contiguous: # do contiguous count_label("fast is_contiguous") return FakeTensor( @@ -1038,7 +1053,7 @@ def slow(msg): ), device=common_device, ) - if is_channels_last: + if definitely_channels_last: count_label("fast channels_last") # do channels last return FakeTensor( @@ -1059,13 +1074,15 @@ def slow(msg): # disable the python dispatcher to avoid decomposing detach() further # (proxy_mode should still decompose detach() though) -def fast_detach(fake_mode, x): +def fast_detach(fake_mode, x, include_real=False): with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode): out = torch.ops.aten.detach.default(x) - return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor) + if include_real: + return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor) + return FakeTensor(fake_mode, out, x.device) -@functools.lru_cache(None) +@functools.cache def get_fast_op_impls(): import torch._refs @@ -1075,7 +1092,9 @@ def get_fast_op_impls(): register_fast_op_impl(torch.ops.aten.sub.Tensor)( make_fast_binary_impl(torch._refs.sub) ) - register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] + register_fast_op_impl(torch.ops.aten.mul.Tensor)( + make_fast_binary_impl(torch._refs.mul) + ) # type: ignore[has-type] register_fast_op_impl(torch.ops.aten.div.Tensor)( make_fast_binary_impl( torch._refs.div, diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 3c1a4492a6f929..bbecee8004be28 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -62,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") # TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186 # Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105 @@ -120,6 +121,11 @@ class UnsupportedOperatorException(RuntimeError): func: OpOverload +@dataclass +class UnsupportedMutationAliasingException(RuntimeError): + reason: str + + @dataclass class MetadataMismatchError(RuntimeError): reason: str @@ -227,7 +233,7 @@ def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]: return None -@functools.lru_cache(None) +@functools.cache def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: return torch._C._SchemaInfo(func._schema) @@ -237,7 +243,7 @@ def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: # torch/_decomp/decompositions.py. # decomps are used for aot autograd tracing so we would like to unify on their # implementation and add additional testing to them -@functools.lru_cache(None) +@functools.cache def torch_decomp_decompositions(func: OpOverload) -> bool: from torch._decomp import decomposition_table @@ -490,9 +496,9 @@ def from_meta_and_device( pytype: Optional[type[torch.Tensor]] = None, dispatch_keys: Optional[torch.DispatchKeySet] = None, ) -> FakeTensor: - assert ( - t.device.type == "meta" - ), f"tensor's device must be `meta`, got {t.device.type} instead" + assert t.device.type == "meta", ( + f"tensor's device must be `meta`, got {t.device.type} instead" + ) # This is a bit abusive (this is not the "real" tensor) but whatever, # the meta tensor should be fresh so there's no way to get it wrong maybe_memo = self._get_memo(t) @@ -505,7 +511,7 @@ def from_meta_and_device( return out -@functools.lru_cache(None) +@functools.cache def init_gpu_context(device: torch.device) -> None: # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first if torch.cuda.is_available() or torch.xpu.is_available(): @@ -1433,6 +1439,15 @@ def _cached_dispatch_impl( key = self._cache_key(state, func, args, kwargs) except _BypassDispatchCache as e: # We couldn't create the cache key at all + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) FakeTensorMode.cache_bypasses[e.reason] += 1 if key is None: @@ -1477,6 +1492,15 @@ def _cached_dispatch_impl( # We ran "extra" checks on the cache key and determined that it's no # good. Record the reason and mark it so we don't bother validating # again. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) FakeTensorMode.cache_bypasses[e.reason] += 1 set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason)) return output @@ -1521,6 +1545,10 @@ def _cache_key( # where it wasn't seen on a previous instance of the same op. self.shape_env.settings if self.shape_env else None, ] + if state.known_symbols: + # If there are symbols then include the epoch - this is really more + # of a Shape env var which lives on the FakeTensorMode. + key_values.append(self.epoch) # Collect the id_hashed objects to attach a weakref finalize later id_hashed_objects: list[object] = [] # Translate any FakeTensor args to metadata. @@ -1566,7 +1594,10 @@ def _validate_cache_key( if torch.Tag.dynamic_output_shape in func.tags: if func is aten.index.Tensor: _, new_kwargs = normalize_function( # type: ignore[misc] - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + func, + args=args, # type: ignore[arg-type] + kwargs=kwargs, # type: ignore[arg-type] + normalize_to_only_use_kwargs=True, ) for index in new_kwargs["indices"]: # index calls nonzero for bool or int8 tensors, and @@ -1617,6 +1648,9 @@ def _prep_args_for_hash( convert FakeTensors into metadata. Raises _BypassDispatchCache to signal unsupported cases that should bypass caching. """ + from torch._higher_order_ops.auto_functionalize import ( + FunctionalCallableWithEpilogue, + ) from torch._higher_order_ops.utils import FunctionalizeCtxWrapper if isinstance(args, dict): @@ -1632,10 +1666,6 @@ def _prep_args_for_hash( raise _BypassDispatchCache("constant attribute") if is_sparse_any(arg): raise _BypassDispatchCache(f"{arg.layout} tensor") - # FIXME: For now back out caching when there are symbolic nbytes - # - this doesn't seem to play nice with set(). See T196779132 for examples. - if isinstance(arg.untyped_storage().nbytes(), SymInt): - raise _BypassDispatchCache("symbolic nbytes") metadata = extract_tensor_metadata(arg) metadata._flatten_into(result, self, state) elif isinstance(arg, Tensor): @@ -1661,6 +1691,10 @@ def _prep_args_for_hash( # functional wrapper is destroyed after fake tensor prop. We # need to put the finalizer on the subgraph. id_hashed_objects.append(arg.subgraph) + elif isinstance(arg, FunctionalCallableWithEpilogue): + result.append(type(arg)) + result.append(hash(arg)) + id_hashed_objects.append(arg.orig_callable) else: # It's important to capture the type of the arg since, e.g., 1 and 1.0 # hash to the same value, but can produce different dtypes for the @@ -1764,9 +1798,18 @@ def _get_output_info_for_cache_entry( entry_for_synth_output = _DispatchCacheValidEntry( output_infos=(entry,), is_output_tuple=False ) - synth_output = self._output_from_cache_entry( - state, entry_for_synth_output, key, func, args - ) + from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode + + try: + synth_output = self._output_from_cache_entry( + state, entry_for_synth_output, key, func, args + ) + except GuardOnDataDependentSymNode: + # This should probably never really happen. If it does it means that + # although the original call didn't get a data-dependent error when + # we tried to reconstruct the output we did - that's almost + # certainly a bug. + raise _BypassDispatchCache("data dependent symnode") from None # Make sure the dispatch_key_set from the synthesized output tensor will # be the same. @@ -1937,11 +1980,7 @@ def _output_from_cache_entry( if entry.is_output_tuple: outputs = [ self._get_output_tensor_from_cache_entry( - state, - output_info, - key, - func, - args, + state, output_info, key, func, args ) for output_info in entry.output_infos ] @@ -1974,8 +2013,8 @@ def assert_helper(a: Any, b: Any) -> None: assert isinstance(b, int) and a == b elif a is None: assert b is None - elif isinstance(a, torch.SymInt): - assert a is b + elif isinstance(a, py_sym_types): + assert type(a) == type(b) and a.node is b.node elif isinstance(a, torch.Tensor): assert isinstance(b, torch.Tensor) assert_metadata_eq(assert_eq, a, b) @@ -2100,9 +2139,7 @@ def _check_fake_real_vals(fake: Any, real: Any) -> None: try: _check_fake_real_vals(s_fake, s_real) except MetadataMismatchError as exc: - if ( - torch._functorch.config.generate_fake_kernels_from_real_mismatches - ): + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: dtrace_structured( "mismatched_fake_kernel", metadata_fn=lambda: { @@ -2275,9 +2312,9 @@ def _dispatch_impl( and not flat_arg_fake_tensors and not device_conversion_skip_const_prop ): - assert all( - t.constant is not None for t in flat_arg_fake_tensors - ), f"{func} should not have fake inputs without constants" + assert all(t.constant is not None for t in flat_arg_fake_tensors), ( + f"{func} should not have fake inputs without constants" + ) const_flat_args = [ a.constant if self.is_our_fake(a) else a for a in flat_args ] @@ -2502,9 +2539,7 @@ def go(t: object, real_t: Tensor) -> None: if real_out is not nil: # cross check fake/real outputs, and optionally override fake kernel mismatches - if ( - not torch._functorch.config.generate_fake_kernels_from_real_mismatches - ): + if not torch._functorch.config.generate_fake_kernels_from_real_mismatches: self._maybe_infer_fake_kernel_from_pytree_out( func, (args, kwargs), @@ -2745,7 +2780,7 @@ def validate(x: T) -> Union[T, FakeTensor]: nonlocal flat_arg_fake_tensors if not self.is_our_fake(x): - if torch.Tag.inplace_view in func.tags: + if hasattr(func, "tags") and torch.Tag.inplace_view in func.tags: args, kwargs = pytree.tree_unflatten(flat_args, args_spec) raise AssertionError( f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}" @@ -2888,7 +2923,10 @@ def invalidate_written_to_constants( schema_info = get_schema_info(func) if any_constant and schema_info.is_mutable(): _, new_kwargs = normalize_function( # type: ignore[misc] - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + func, + args=args, # type: ignore[arg-type] + kwargs=kwargs, # type: ignore[arg-type] + normalize_to_only_use_kwargs=True, ) for k, v in new_kwargs.items(): k = k if (k != "input" or schema_info.has_argument(k)) else "self" @@ -2912,9 +2950,9 @@ def from_tensor( if static_shapes is None: static_shapes = self.static_shapes if static_shapes: - assert ( - symbolic_context is None - ), "cannot set both static_shapes and symbolic_context" + assert symbolic_context is None, ( + "cannot set both static_shapes and symbolic_context" + ) shape_env = None return self.fake_tensor_converter.from_real_tensor( self, diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index d47234ea1b6d75..bd481c87cf6f34 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -102,7 +102,7 @@ def is_sdpa_error(func, idx, e): def try_convert_fake_to_real( - ten_list: list[Union[FakeTensor, Any]] + ten_list: list[Union[FakeTensor, Any]], ) -> list[Union[FakeTensor, torch.Tensor, Any]]: """ Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up @@ -266,9 +266,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if fake_r is not None: r_flat = pytree.tree_leaves(r) f_flat = pytree.tree_leaves(fake_r) - assert len(f_flat) == len( - r_flat - ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" + assert len(f_flat) == len(r_flat), ( + f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" + ) if self.check_aliasing: _check_alias_info( @@ -279,9 +279,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) ): r_is_ten = isinstance(r_out, torch.Tensor) - assert r_is_ten == isinstance( - f_out, torch.Tensor - ), f"{context} mismatched number of tensor outputs" + assert r_is_ten == isinstance(f_out, torch.Tensor), ( + f"{context} mismatched number of tensor outputs" + ) if r_is_ten: try: _check_fake_real_tensors( diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index b01ebd8bb878fa..956f22d1c4b65b 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -160,7 +160,7 @@ def __new__(cls, elem, mode): assert out._inference_mode_base is not None return out - def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[override] unrecognized_types = [ t for t in types @@ -291,7 +291,7 @@ def to_dense(self): # type: ignore[override] return self.elem.to_dense() @property - def layout(self): + def layout(self): # type: ignore[override] return self.elem.layout def __bool__(self): diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 6dc03369cc90ba..9feeb46b65bce6 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -39,6 +39,7 @@ maybe_get_level, peek_interpreter_stack, ) +from torch._dispatch.python import enable_python_dispatcher from torch._logging import trace_structured from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -356,7 +357,9 @@ def describe_tensor( maybe_functorch_stack = None if is_functorch_wrapped: - with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack: + with ( + torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() + ) as maybe_functorch_stack: pass attrs = None @@ -385,7 +388,7 @@ def describe_tensor( is_leaf=is_leaf, requires_grad=t.requires_grad, # NB: ndim should be OK too but there is a disaster at - # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported + # python test/dynamo/test_subclasses.py -k test_user_overridden_property_unsupported # Actually, this means that we have a little bit of a problem # here, which is that there is some sensitivity to how exactly an # access is done if you have a __torch_function__ subclass. Maybe @@ -516,8 +519,7 @@ def apply( new_base: _TensorT, symint_visitor_fn: Optional[Callable[[int], int]] = None, tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, - ) -> _TensorT: - ... + ) -> _TensorT: ... @staticmethod def from_tensor(t: torch.Tensor) -> ViewFunc: @@ -573,8 +575,7 @@ def apply( class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): def __call__( self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str] - ) -> _TensorT_cov: - ... + ) -> _TensorT_cov: ... class _MetaTensorCallbackKwargs(TypedDict, total=False): @@ -591,8 +592,7 @@ def __call__( arg: Callable[[], torch.Tensor], /, **kwargs: Unpack[_MetaTensorCallbackKwargs], - ) -> _TensorT_cov: - ... + ) -> _TensorT_cov: ... @dataclass(frozen=True) @@ -784,9 +784,9 @@ def __init__(self, *, copy_data: bool = False) -> None: ] = weakref.WeakValueDictionary() # Maps MetaTensorId to torch.Tensor (typically a meta tensor or # FakeTensor) - self.tensor_memo: weakref.WeakValueDictionary[ - MetaTensorId, _TensorT - ] = weakref.WeakValueDictionary() + self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = ( + weakref.WeakValueDictionary() + ) self.hit = 0 self.miss = 0 self.del_hook = None @@ -1178,134 +1178,140 @@ def view_from_base( torch.fx.experimental.symbolic_shapes.ShapeEnv ] = shape_env, ) -> _TensorT: - # fake-ify t's metadata according to the outer symbolic context - (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( - t, source - ) - if ( - not t.is_traceable_wrapper_subclass - and not is_traceable_wrapper_subclass(base) - ): - # Dense -> Dense view case uses as_strided() to construct view relationship. - # TODO: Change this logic to use view replay for consistency? - # It's likely there is no view func available. - with maybe_suppress(): - return self._checked_cast_tensor_t( - base.as_strided(sizes, strides, storage_offset) - ) + with enable_python_dispatcher(): + # fake-ify t's metadata according to the outer symbolic context + (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( + t, source + ) + if ( + not t.is_traceable_wrapper_subclass + and not is_traceable_wrapper_subclass(base) + ): + # Dense -> Dense view case uses as_strided() to construct view relationship. + # TODO: Change this logic to use view replay for consistency? + # It's likely there is no view func available. + with maybe_suppress(): + return self._checked_cast_tensor_t( + base.as_strided(sizes, strides, storage_offset) + ) - from torch._dynamo.source import EphemeralSource - from torch.fx.experimental.symbolic_shapes import ( - StatelessSymbolicContext, - sym_eq, - ) + from torch._dynamo.source import EphemeralSource + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + sym_eq, + ) - def symint_visitor_fn(s: int) -> int: - nonlocal symbolic_context - from torch.fx.experimental.symbolic_shapes import DimDynamic + def symint_visitor_fn(s: int) -> int: + nonlocal symbolic_context + from torch.fx.experimental.symbolic_shapes import DimDynamic - all_static_sizes = ( - symbolic_context is not None - and isinstance(symbolic_context, StatelessSymbolicContext) - and all( - x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes + all_static_sizes = ( + symbolic_context is not None + and isinstance(symbolic_context, StatelessSymbolicContext) + and all( + x is DimDynamic.STATIC + for x in symbolic_context.dynamic_sizes + ) + ) + # Can't just rely on shape env being None - dynamo always initializes it + if all_static_sizes or shape_env is None: + return s + + # NB: The symbol here is expected to be simplified out because we a priori + # allocate inner and outer symbols according to the appropriate symbolic + # contexts and prefer those over this symbol during symbol simplification + # (via usage of EphemeralSource below). This -shouldn't- happen, but if + # this symbol somehow leaks out beyond the view tensor's shape metadata, our + # assumption of it being simplified out will fail and it may be guarded on, + # which will hard error. + sym_source = EphemeralSource("symint_visitor_fn") + + symbol = shape_env.create_symbol(s, sym_source, positive=None) + return shape_env.create_symintnode( + symbol, hint=s, source=sym_source ) - ) - # Can't just rely on shape env being None - dynamo always initializes it - if all_static_sizes or shape_env is None: - return s - - # NB: The symbol here is expected to be simplified out because we a priori - # allocate inner and outer symbols according to the appropriate symbolic - # contexts and prefer those over this symbol during symbol simplification - # (via usage of EphemeralSource below). This -shouldn't- happen, but if - # this symbol somehow leaks out beyond the view tensor's shape metadata, our - # assumption of it being simplified out will fail and it may be guarded on, - # which will hard error. - sym_source = EphemeralSource("symint_visitor_fn") - - symbol = shape_env.create_symbol(s, sym_source, positive=None) - return shape_env.create_symintnode(symbol, hint=s, source=sym_source) - - real_to_fake_mapping = {} - if t.is_traceable_wrapper_subclass: - assert t.attrs is not None - # NB: t.ctx could be None if the subclass in question has no - # meaningful context - assert t.type is not None - # Fake-ify t naively here; this is only done so we can get fake-ified inner - # tensors with the correct relationships to the outer sizes / strides for use - # in view replay. It's done beforehand here because it's not easy to do when - # visiting tensors one-by-one during view replay. - # - # Example: - # Consider a Dense -> NJT view. NJT has (values, offsets) components and we - # want a view of values with the offsets closed over. As the offsets component - # is needed to describe the output view, it's important that it's fakeified - # correctly. - fake_t: _TensorT = empty_create_subclass( - t, outer_size=sizes, outer_stride=strides - ) - attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined] - for attr in attrs: - real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) - - def tensor_visitor_fn( - visited_t: torch.Tensor, - # These arguments are never passed, we just use them to close - # over these relevant values - shape_env: Optional[ - torch.fx.experimental.symbolic_shapes.ShapeEnv - ] = shape_env, - callback: _MetaTensorCallbackOptDevice[_TensorT] = callback, - ) -> torch.Tensor: - # It's possible to close over an undefined tensor (e.g. NJT's lengths). - if visited_t is None: - return None - - # NB: visited_t being a Tensor here is very naughty! Should - # have already been described - - # Fake inner tensors of view subclasses will come from the mapping built above. - visited_id = self.describer.get_tensor_id(visited_t) - fake_visited_t = real_to_fake_mapping.get(visited_id, None) - if fake_visited_t is not None: - return fake_visited_t - - visited_desc = self.describer.describe_tensor(visited_t) - - # For other closed-over tensor state, fake-ify it as all dynamic with an - # ephemeral source. This avoids invalid specialization during view replay. - # If we find that in practice the usage of ephemeral sources isn't enough - # to guarantee that we don't have guards on these symbols, we may need to - # explicitly suppress guards (as is done for _base in the dense -> dense - # view case). - temp_source = EphemeralSource("tensor_visitor_fn") - return self.meta_tensor( - visited_desc, - shape_env, - callback, - temp_source, - all_dynamic_symbolic_context( - visited_desc, temp_source, shape_env, callback - ), + real_to_fake_mapping = {} + if t.is_traceable_wrapper_subclass: + assert t.attrs is not None + # NB: t.ctx could be None if the subclass in question has no + # meaningful context + assert t.type is not None + + # Fake-ify t naively here; this is only done so we can get fake-ified inner + # tensors with the correct relationships to the outer sizes / strides for use + # in view replay. It's done beforehand here because it's not easy to do when + # visiting tensors one-by-one during view replay. + # + # Example: + # Consider a Dense -> NJT view. NJT has (values, offsets) components and we + # want a view of values with the offsets closed over. As the offsets component + # is needed to describe the output view, it's important that it's fakeified + # correctly. + fake_t: _TensorT = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) + + def tensor_visitor_fn( + visited_t: torch.Tensor, + # These arguments are never passed, we just use them to close + # over these relevant values + shape_env: Optional[ + torch.fx.experimental.symbolic_shapes.ShapeEnv + ] = shape_env, + callback: _MetaTensorCallbackOptDevice[_TensorT] = callback, + ) -> torch.Tensor: + # It's possible to close over an undefined tensor (e.g. NJT's lengths). + if visited_t is None: + return None + + # NB: visited_t being a Tensor here is very naughty! Should + # have already been described + + # Fake inner tensors of view subclasses will come from the mapping built above. + visited_id = self.describer.get_tensor_id(visited_t) + fake_visited_t = real_to_fake_mapping.get(visited_id, None) + if fake_visited_t is not None: + return fake_visited_t + + visited_desc = self.describer.describe_tensor(visited_t) + + # For other closed-over tensor state, fake-ify it as all dynamic with an + # ephemeral source. This avoids invalid specialization during view replay. + # If we find that in practice the usage of ephemeral sources isn't enough + # to guarantee that we don't have guards on these symbols, we may need to + # explicitly suppress guards (as is done for _base in the dense -> dense + # view case). + temp_source = EphemeralSource("tensor_visitor_fn") + return self.meta_tensor( + visited_desc, + shape_env, + callback, + temp_source, + all_dynamic_symbolic_context( + visited_desc, temp_source, shape_env, callback + ), + ) + + # Replay the view, swapping out any non-symbolic SymInts or real tensors + # for symbolic SymInts or fake tensors. + assert t.view_func is not None + # NB: we do NOT suppress guards here, we need to remove ephemeral + # sources + fake_t = t.view_func.apply( + t, base, symint_visitor_fn, tensor_visitor_fn ) - # Replay the view, swapping out any non-symbolic SymInts or real tensors - # for symbolic SymInts or fake tensors. - assert t.view_func is not None - # NB: we do NOT suppress guards here, we need to remove ephemeral - # sources - fake_t = t.view_func.apply(t, base, symint_visitor_fn, tensor_visitor_fn) - - # Ensure the output has symbolic shapes according to the outer symbolic context. - # These checks should simplify out any symbols created for closed-over view func - # SymInts. - torch._check(sym_eq(fake_t.size(), sizes)) - torch._check(sym_eq(fake_t.stride(), strides)) - torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) - return fake_t + # Ensure the output has symbolic shapes according to the outer symbolic context. + # These checks should simplify out any symbols created for closed-over view func + # SymInts. + torch._check(sym_eq(fake_t.size(), sizes)) + torch._check(sym_eq(fake_t.stride(), strides)) + torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) + return fake_t if self.get_tensor_memo(t) is None: GRAD_TENSOR_SENTINEL_VALUE = -2 @@ -1646,7 +1652,7 @@ def is_c_of_r( # correct requires_grad, then do the final view. # NB: Can't have a non-leaf without requiring grad! assert t.requires_grad - with torch.no_grad(): + with torch.no_grad(), enable_python_dispatcher(): mid = self._checked_cast_tensor_t( base.view(base.shape) ) @@ -1765,9 +1771,9 @@ def is_c_of_r( # subclasses. Relevant test is # DynamicShapesFunctionTests::test_add_dynamic_shapes in # test/dynamo/test_dynamic_shapes.py - maybe_fake_mgr: AbstractContextManager[ - None - ] = contextlib.nullcontext() + maybe_fake_mgr: AbstractContextManager[None] = ( + contextlib.nullcontext() + ) from torch._subclasses.fake_tensor import ( in_kernel_invocation_manager, maybe_get_fake_mode, diff --git a/torch/_tensor.py b/torch/_tensor.py index 50bfc8f3c2c533..26935a6a60a463 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -6,7 +6,8 @@ from collections import OrderedDict from copy import deepcopy from numbers import Number -from typing import Any, Callable, cast, Optional, Union +from typing import Any, Callable, cast, Optional, TypeVar, Union +from typing_extensions import Concatenate, ParamSpec import torch import torch._C as _C @@ -27,16 +28,21 @@ ) -def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): - assigned = functools.WRAPPER_ASSIGNMENTS +_P = ParamSpec("_P") +_TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase) - @functools.wraps(f, assigned=assigned) - def wrapped(*args, **kwargs): + +def _handle_torch_function_and_wrap_type_error_to_not_implemented( + f: Callable[Concatenate[_TensorLike, _P], "Tensor"], +) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]: + @functools.wraps(f) + def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor": try: # See https://github.com/pytorch/pytorch/issues/75462 - if has_torch_function(args): - return handle_torch_function(wrapped, args, *args, **kwargs) - return f(*args, **kwargs) + sargs = self, *args + if has_torch_function(sargs): + return handle_torch_function(wrapped, sargs, *sargs, **kwargs) + return f(self, *args, **kwargs) except TypeError: return NotImplemented @@ -621,19 +627,18 @@ def backward( Args: gradient (Tensor, optional): The gradient of the function being differentiated w.r.t. ``self``. - This argument can be omitted if ``self`` is a scalar. - retain_graph (bool, optional): If ``False``, the graph used to compute - the grads will be freed. Note that in nearly all cases setting - this option to True is not needed and often can be worked around - in a much more efficient way. Defaults to the value of - ``create_graph``. + This argument can be omitted if ``self`` is a scalar. Defaults to ``None``. + retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be freed; + If ``True``, it will be retained. The default is ``None``, in which case the value is inferred from ``create_graph`` + (i.e., the graph is retained only when higher-order derivative tracking is requested). Note that in nearly all cases + setting this option to True is not needed and often can be worked around in a much more efficient way. create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. - inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be + inputs (Sequence[Tensor], optional): Inputs w.r.t. which the gradient will be accumulated into ``.grad``. All other tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were - used to compute the :attr:`tensors`. + used to compute the :attr:`tensors`. Defaults to ``None``. """ if has_torch_function_unary(self): return handle_torch_function( @@ -1094,11 +1099,11 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None ) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rsub__(self, other): + def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return _C._VariableFunctions.rsub(self, other) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rdiv__(self, other): + def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return self.reciprocal() * other __rtruediv__ = __rdiv__ @@ -1113,12 +1118,13 @@ def __rdiv__(self, other): _C.TensorBase.pow ), ) + __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( _C.TensorBase.pow_ ) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rmod__(self, other): + def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return torch.remainder(other, self) def __format__(self, format_spec): @@ -1131,27 +1137,33 @@ def __format__(self, format_spec): return object.__format__(self, format_spec) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rpow__(self, other): + def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return torch.pow(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __floordiv__(self, other): + def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override] + # TODO(rec): the superclass says it accepts complex here, + # but torch.floor_divide says it doesn't. return torch.floor_divide(self, other) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override] return torch.floor_divide(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rlshift__(self, other): + def __rlshift__( + self, other: Union["Tensor", int, float, bool, complex] + ) -> "Tensor": return torch.bitwise_left_shift(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rrshift__(self, other): + def __rrshift__( + self, other: Union["Tensor", int, float, bool, complex] + ) -> "Tensor": return torch.bitwise_right_shift(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented - def __rmatmul__(self, other): + def __rmatmul__(self, other: "Tensor") -> "Tensor": return torch.matmul(other, self) __pos__ = _C.TensorBase.positive @@ -1675,7 +1687,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): __torch_dispatch__ = _C._disabled_torch_dispatch_impl - def __dlpack__(self, stream=None): + def __dlpack__(self, *, stream=None, max_version=None): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ of the current tensor to be exported to other libraries. @@ -1692,9 +1704,18 @@ def __dlpack__(self, stream=None): both streams. If None or -1 is passed then no synchronization is performed. If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for synchronization. + + max_version (tuple[int, int] or None): An optional Python tuple with + 2 integers, representing the maximum version the caller supports. If + None (default), PyTorch will fallback to DLPack 0.8. """ if has_torch_function_unary(self): - return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) + args = (self,) + kwargs = { + "stream": stream, + "max_version": max_version, + } + return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs) # DLPack capsules can't capture all of PyTorch's semantics, # so we prohibit exporting tensors that would lose their properties like @@ -1742,8 +1763,15 @@ def __dlpack__(self, stream=None): raise RuntimeError( "Can't export to dlpack an XLA tensor that is not on CUDA." ) + + # Does not support DLPack 1.0, yet. return xla_dlpack.to_dlpack(self) - return torch.to_dlpack(self) + + if max_version is None or max_version[0] < 1: + # Fallback to the old, unversioned variant. + return torch.to_dlpack(self) + + return _C._to_dlpack_versioned(self) def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if has_torch_function_unary(self): @@ -1757,9 +1785,9 @@ def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if torch_device_type == "cuda" and torch.version.hip is not None: device_type = DLDeviceType.kDLROCM elif torch_device_type == "cpu" and self.is_pinned(): - device_type = DLDeviceType.kDLCPUPinned + device_type = DLDeviceType.kDLCUDAHost elif torch_device_type == "cuda": - device_type = DLDeviceType.kDLGPU + device_type = DLDeviceType.kDLCUDA elif torch_device_type == "cpu": device_type = DLDeviceType.kDLCPU elif torch_device_type == "xpu": @@ -1775,7 +1803,7 @@ def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: ): raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") - device_type = DLDeviceType.kDLGPU + device_type = DLDeviceType.kDLCUDA else: raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") return (device_type, idx) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 4b31a9de93b7b9..38a8708b1b59a1 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -5218,6 +5218,13 @@ def callable(a, b) -> number Args: {memory_format} +.. note:: + + According to `C++ type conversion rules `_, + converting floating point value to integer type will truncate the fractional part. + If the truncated value cannot fit into the target type (e.g., casting ``torch.inf`` to ``torch.long``), + the behavior is undefined and the result may vary across platforms. + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor :noindex: diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index a03c4efc15b7eb..921e97be233ad1 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -120,6 +120,7 @@ def tensor_totype(t): if ( t.is_mps or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64) + or t.is_maia ) else torch.double ) @@ -167,8 +168,7 @@ def __init__(self, tensor): # support for them is removed nonzero_finite_vals = nonzero_finite_vals.float() - # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. - + # Convert to double (or float) for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs()) nonzero_finite_min = tensor_totype(nonzero_finite_abs.min()) nonzero_finite_max = tensor_totype(nonzero_finite_abs.max()) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 34ffdb313c4ab3..42b1a9da3e0e31 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -101,6 +101,11 @@ def merge_dicts(*dicts): "opt_dim_all_reduce": """ dim (int, optional): the dimension to reduce. If ``None``, all dimensions are reduced. +""" + }, + { + "opt_dim_without_none": """ + dim (int, optional): the dimension to reduce. If omitted, all dimensions are reduced. Explicit ``None`` is not supported. """ }, { @@ -2777,6 +2782,12 @@ def merge_dicts(*dicts): result of this operation to :attr:`input`. To create a tensor without an autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + In addition, when ``torch.preserve_format`` is used: + If the input tensor is dense (i.e., non-overlapping strided), + its memory format (including strides) is retained. + Otherwise (e.g., a non-dense view like a stepped slice), + the output is converted to the dense (contiguous) format. + Args: {input} @@ -6543,7 +6554,7 @@ def merge_dicts(*dicts): Args: {input} - {opt_dim_all_reduce} + {opt_dim_without_none} {opt_keepdim} Keyword args: @@ -6574,7 +6585,7 @@ def merge_dicts(*dicts): See :func:`torch.maximum`. -""".format(**multi_dim_common), +""".format(**single_dim_common), ) add_docstr( @@ -6647,10 +6658,10 @@ def merge_dicts(*dicts): .. note:: The difference between ``max``/``min`` and ``amax``/``amin`` is: - ``amax``/``amin`` supports reducing on multiple dimensions, - - ``amax``/``amin`` does not return indices, - - ``amax``/``amin`` evenly distributes gradient between equal values, - while ``max(dim)``/``min(dim)`` propagates gradient only to a single - index in the source tensor. + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. {keepdim_details} @@ -7161,7 +7172,7 @@ def merge_dicts(*dicts): Args: {input} - {opt_dim_all_reduce} + {opt_dim_without_none} {opt_keepdim} Keyword args: @@ -7255,10 +7266,10 @@ def merge_dicts(*dicts): .. note:: The difference between ``max``/``min`` and ``amax``/``amin`` is: - ``amax``/``amin`` supports reducing on multiple dimensions, - - ``amax``/``amin`` does not return indices, - - ``amax``/``amin`` evenly distributes gradient between equal values, - while ``max(dim)``/``min(dim)`` propagates gradient only to a single - index in the source tensor. + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. {keepdim_details} diff --git a/torch/_utils.py b/torch/_utils.py index e813c082ba90a6..8e818f58d9885e 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -274,11 +274,25 @@ def _rebuild_tensor_v3( # to Pickler semantics, we have to use the same (non-validating) function for # unpickling sparse tensors, regardless of the caller. def _validate_loaded_sparse_tensors(): + if not torch.sparse.check_sparse_tensor_invariants().is_enabled(): + # Skip sparse tensor invariants validation for better + # performance. See check_sparse_tensor_invariants + # documentation for how to control sparse tensor invariants + # checking. + _sparse_tensors_to_validate.clear() + return try: + # We disable pinning check (see check_pinning=False below) to + # avoid gh-153143. In fact, pinning check is unnecessary + # anywhy when loading sparse data from external sources. for t in _sparse_tensors_to_validate: if t.layout is torch.sparse_coo: torch._validate_sparse_coo_tensor_args( - t._indices(), t._values(), t.size(), t.is_coalesced() + t._indices(), + t._values(), + t.size(), + t.is_coalesced(), + check_pinning=False, ) elif t.layout in { torch.sparse_csr, @@ -299,7 +313,12 @@ def _validate_loaded_sparse_tensors(): t.row_indices(), ) torch._validate_sparse_compressed_tensor_args( - compressed_indices, plain_indices, t.values(), t.size(), t.layout + compressed_indices, + plain_indices, + t.values(), + t.size(), + t.layout, + check_pinning=False, ) else: raise NotImplementedError( diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 89fbd6787281cd..fd8b8f08f8b8ce 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -4,6 +4,7 @@ import os import sys import tempfile +import typing_extensions from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -210,7 +211,7 @@ def is_fb_unit_test() -> bool: return False -@functools.lru_cache(None) +@functools.cache def max_clock_rate(): if not torch.version.hip: from triton.testing import nvsmi @@ -282,3 +283,54 @@ def record_chromium_event_internal( def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12(): return True + + +def deprecated(): + """ + When we deprecate a function that might still be in use, we make it internal + by adding a leading underscore. This decorator is used with a private function, + and creates a public alias without the leading underscore, but has a deprecation + warning. This tells users "THIS FUNCTION IS DEPRECATED, please use something else" + without breaking them, however, if they still really really want to use the + deprecated function without the warning, they can do so by using the internal + function name. + """ + + def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + # Validate naming convention – single leading underscore, not dunder + if not (func.__name__.startswith("_")): + raise ValueError( + "@deprecate must decorate a function whose name " + "starts with a single leading underscore (e.g. '_foo') as the api should be considered internal for deprecation." + ) + + public_name = func.__name__[1:] # drop exactly one leading underscore + module = sys.modules[func.__module__] + + # Don't clobber an existing symbol accidentally. + if hasattr(module, public_name): + raise RuntimeError( + f"Cannot create alias '{public_name}' -> symbol already exists in {module.__name__}. \ + Please rename it or consult a pytorch developer on what to do" + ) + + warning_msg = f"{func.__name__[1:]} is DEPRECATED, please consider using an alternative API(s). " + + # public deprecated alias + alias = typing_extensions.deprecated( + warning_msg, category=UserWarning, stacklevel=1 + )(func) + + alias.__name__ = public_name + + # Adjust qualname if nested inside a class or another function + if "." in func.__qualname__: + alias.__qualname__ = func.__qualname__.rsplit(".", 1)[0] + "." + public_name + else: + alias.__qualname__ = public_name + + setattr(module, public_name, alias) + + return func + + return decorator diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 6da73100838ec5..1ed4f4bb4c3109 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -2,7 +2,7 @@ This package introduces support for the current :ref:`accelerator` in python. """ -from typing import Literal, Optional +from typing import Optional from typing_extensions import deprecated import torch @@ -33,7 +33,7 @@ def device_count() -> int: If there is no available accelerators, return 0. .. note:: This API delegates to the device-specific version of `device_count`. - On CUDA, this API will NOT posion fork if NVML discovery succeeds. + On CUDA, this API will NOT poison fork if NVML discovery succeeds. Otherwise, it will. For more details, see :ref:`multiprocessing-poison-fork-note`. """ acc = current_accelerator() @@ -129,7 +129,7 @@ def set_device_index(device: _device_t, /) -> None: .. note:: This function is a no-op if this device index is negative. """ - device_index = _get_device_index(device) + device_index = _get_device_index(device, optional=False) torch._C._accelerator_setDeviceIndex(device_index) @@ -150,7 +150,7 @@ def current_stream(device: _device_t = None, /) -> torch.Stream: Returns: torch.Stream: the currently selected stream for a given device. """ - device_index = _get_device_index(device, True) + device_index = _get_device_index(device, optional=True) return torch._C._accelerator_getStream(device_index) @@ -188,7 +188,7 @@ def synchronize(device: _device_t = None, /) -> None: >>> torch.accelerator.synchronize() >>> elapsed_time_ms = start_event.elapsed_time(end_event) """ - device_index = _get_device_index(device, True) + device_index = _get_device_index(device, optional=True) torch._C._accelerator_synchronizeDevice(device_index) @@ -224,7 +224,6 @@ def __enter__(self) -> None: if self.idx is not None: self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx) - def __exit__(self, *args: object) -> Literal[False]: + def __exit__(self, *exc_info: object) -> None: if self.idx is not None: torch._C._accelerator_maybeExchangeDevice(self.prev_idx) - return False diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index e8af9aca5a5a92..3e81f46f5c23dd 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -43,7 +43,9 @@ def decorate_autocast(*args, **kwargs): with autocast_instance: return func(*args, **kwargs) - decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined] + decorate_autocast.__script_unsupported = ( # type: ignore[attr-defined] + "@autocast() decorator is not supported in script mode" + ) return decorate_autocast @@ -88,9 +90,9 @@ class autocast: class AutocastModel(nn.Module): ... + @torch.autocast(device_type="cuda") - def forward(self, input): - ... + def forward(self, input): ... Floating-point Tensors produced in an autocast-enabled region may be ``float16``. After returning to an autocast-disabled region, using them with floating-point @@ -152,9 +154,11 @@ class TestModel(nn.Module): def __init__(self, input_size, num_classes): super().__init__() self.fc1 = nn.Linear(input_size, num_classes) + def forward(self, x): return self.fc1(x) + input_size = 2 num_classes = 2 model = TestModel(input_size, num_classes).eval() diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index 2931b5b9fadd38..2fd3d3e8c49885 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -175,20 +175,16 @@ def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None: ) @overload - def scale(self, outputs: torch.Tensor) -> torch.Tensor: - ... + def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... @overload - def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: - ... + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload - def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: - ... + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ... @overload - def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: - ... + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... def scale( self, @@ -458,9 +454,9 @@ def step( if optimizer_state["stage"] is OptState.READY: self.unscale_(optimizer) - assert ( - len(optimizer_state["found_inf_per_device"]) > 0 - ), "No inf checks were recorded for this optimizer." + assert len(optimizer_state["found_inf_per_device"]) > 0, ( + "No inf checks were recorded for this optimizer." + ) retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) @@ -504,8 +500,10 @@ def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None if isinstance(new_scale, float): self._scale.fill_(new_scale) else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ - torch.FloatTensor with requires_grad=False." + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " + "torch.FloatTensor with requires_grad=False." + ) assert new_scale.device.type == self._device, reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason @@ -683,9 +681,9 @@ def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, A dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device) - self._per_optimizer_states[id(optimizer)][ - "found_inf_per_device" - ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = ( + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + ) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index c7ae2ce3319d3c..ec5b9c26fdd004 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -47,7 +47,10 @@ def __init__(self, conv, relu): assert ( type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, relu) @@ -59,7 +62,10 @@ def __init__(self, conv, relu): assert ( type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, relu) @@ -71,7 +77,10 @@ def __init__(self, conv, relu): assert ( type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, relu) @@ -83,7 +92,10 @@ def __init__(self, linear, relu): assert ( type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(linear, relu) @@ -95,7 +107,10 @@ def __init__(self, conv, bn): assert ( type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) super().__init__(conv, bn) @@ -107,7 +122,10 @@ def __init__(self, conv, bn): assert ( type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) super().__init__(conv, bn) @@ -120,7 +138,11 @@ def __init__(self, conv, bn, relu): type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950 + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, bn, relu) @@ -133,7 +155,11 @@ def __init__(self, conv, bn, relu): type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950 + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, bn, relu) @@ -145,7 +171,10 @@ def __init__(self, conv, bn): assert ( type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) super().__init__(conv, bn) @@ -158,7 +187,11 @@ def __init__(self, conv, bn, relu): type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950 + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(conv, bn, relu) @@ -170,7 +203,10 @@ def __init__(self, batch_norm, relu): assert ( type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(batch_norm, relu) @@ -182,7 +218,10 @@ def __init__(self, batch_norm, relu): assert ( type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU - ), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) super().__init__(batch_norm, relu) @@ -194,7 +233,10 @@ def __init__(self, linear, bn): assert ( type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d - ), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}" + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(bn)}" + ) super().__init__(linear, bn) @@ -203,9 +245,9 @@ class LinearLeakyReLU(_FusedModule): During quantization this will be replaced with the corresponding fused module.""" def __init__(self, linear, leaky_relu): - assert ( - type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU - ), f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}" + assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, ( + f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}" + ) super().__init__(linear, leaky_relu) @@ -214,9 +256,9 @@ class LinearTanh(_FusedModule): During quantization this will be replaced with the corresponding fused module.""" def __init__(self, linear, tanh): - assert ( - type(linear) == Linear and type(tanh) == torch.nn.Tanh - ), f"Incorrect types for input modules{type(linear)}{type(tanh)}" + assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, ( + f"Incorrect types for input modules{type(linear)}{type(tanh)}" + ) super().__init__(linear, tanh) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 8e0ee5dcce0429..6671e317b6b02e 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -456,6 +456,7 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d): weight_fake_quant: fake quant module for weight """ + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment] @@ -524,6 +525,7 @@ class ConvBnReLU1d(ConvBn1d): weight_fake_quant: fake quant module for weight """ + # base class defines _FLOAT_MODULE as "ConvBn1d" _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d @@ -590,6 +592,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None @@ -630,7 +633,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -653,6 +656,7 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm2d @@ -721,6 +725,7 @@ class ConvBnReLU2d(ConvBn2d): weight_fake_quant: fake quant module for weight """ + # base class defines _FLOAT_MODULE as "ConvBn2d" _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d @@ -787,6 +792,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None @@ -827,7 +833,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -850,6 +856,7 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm3d @@ -918,6 +925,7 @@ class ConvBnReLU3d(ConvBn3d): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d @@ -985,6 +993,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment] _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None @@ -1025,7 +1034,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 9cf0d4bba898c0..8446468dddcff6 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -1,11 +1,18 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Optional + import torch import torch.ao.nn.intrinsic as nni import torch.ao.nn.qat as nnqat import torch.nn.functional as F +from torch.ao.nn.intrinsic.modules.fused import _FusedModule + +__all__ = ["LinearReLU"] -class LinearReLU(nnqat.Linear, nni._FusedModule): + +class LinearReLU(nnqat.Linear, _FusedModule): r""" A LinearReLU module fused from Linear and ReLU modules, attached with FakeQuantize modules for weight, used in @@ -28,19 +35,30 @@ class LinearReLU(nnqat.Linear, nni._FusedModule): >>> print(output.size()) torch.Size([128, 30]) """ - _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] - def __init__(self, in_features, out_features, bias=True, qconfig=None): + _FLOAT_MODULE = nni.LinearReLU + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: Optional[object] = None, + ) -> None: super().__init__(in_features, out_features, bias, qconfig) - def forward(self, input): + def forward(self, input: torch.Tensor) -> torch.Tensor: return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): - return super().from_float(mod, use_precomputed_fake_quant) + def from_float( + cls, + mod: torch.nn.Module, + use_precomputed_fake_quant: bool = False, + ) -> LinearReLU: + return super().from_float(mod, use_precomputed_fake_quant) # type: ignore[no-untyped-call,no-any-return] - def to_float(self): + def to_float(self) -> nni.LinearReLU: linear = torch.nn.Linear( self.in_features, self.out_features, self.bias is not None ) @@ -48,4 +66,4 @@ def to_float(self): if self.bias is not None: linear.bias = torch.nn.Parameter(self.bias.detach()) relu = torch.nn.ReLU() - return torch.ao.nn.intrinsic.LinearReLU(linear, relu) + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) # type: ignore[no-untyped-call] diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index 1515e005cc7b39..f19c2c8e9d9db8 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -27,6 +27,7 @@ class LinearReLU(nnqd.Linear): >>> print(output.size()) torch.Size([128, 30]) """ + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): @@ -56,5 +57,5 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) @classmethod - def from_reference(cls, ref_qlinear_relu): + def from_reference(cls, ref_qlinear_relu): # type: ignore[override] return super().from_reference(ref_qlinear_relu[0]) diff --git a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py index 78009d1c76f4ac..99b535625cbc7e 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -19,6 +19,7 @@ class BNReLU2d(nnq.BatchNorm2d): Same as torch.ao.nn.quantized.BatchNorm2d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): @@ -46,7 +47,7 @@ def _get_name(self): return "QuantizedBNReLU2d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] # TODO: Add qat support for BNReLU2d return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant @@ -67,6 +68,7 @@ class BNReLU3d(nnq.BatchNorm3d): Same as torch.ao.nn.quantized.BatchNorm3d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): @@ -94,7 +96,7 @@ def _get_name(self): return "QuantizedBNReLU3d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] # TODO: Add qat support for BNReLU3d return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 0d1b7e01f4479f..71bfa845f150ae 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -19,6 +19,7 @@ class ConvAdd2d(nnq.Conv2d): Same as torch.ao.nn.quantized.Conv2d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment] def __init__( @@ -67,7 +68,7 @@ def _get_name(self): return "QuantizedConvAdd2d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -87,6 +88,7 @@ class ConvAddReLU2d(nnq.Conv2d): Same as torch.ao.nn.quantized.Conv2d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment] def __init__( @@ -135,7 +137,7 @@ def _get_name(self): return "QuantizedConvAddReLU2d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py index 25d695859180fc..8172004d95fc80 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -28,6 +28,7 @@ class ConvReLU1d(nnq.Conv1d): Same as torch.ao.nn.quantized.Conv1d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] def __init__( @@ -77,7 +78,7 @@ def _get_name(self): return "QuantizedConvReLU1d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( @@ -93,9 +94,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): - assert ( - type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d - ), "BatchNorm1d should be fused into Conv1d before converting to reference module" + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, ( + "BatchNorm1d should be fused into Conv1d before converting to reference module" + ) return super().from_reference(ref_qconv[0], output_scale, output_zero_point) @@ -109,6 +110,7 @@ class ConvReLU2d(nnq.Conv2d): Same as torch.ao.nn.quantized.Conv2d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment] def __init__( @@ -157,7 +159,7 @@ def _get_name(self): return "QuantizedConvReLU2d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( @@ -175,9 +177,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): - assert ( - type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d - ), "BatchNorm2d should be fused into Conv2d before converting to reference module" + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, ( + "BatchNorm2d should be fused into Conv2d before converting to reference module" + ) return super().from_reference(ref_qconv[0], output_scale, output_zero_point) @@ -190,6 +192,7 @@ class ConvReLU3d(nnq.Conv3d): Attributes: Same as torch.ao.nn.quantized.Conv3d """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment] def __init__( @@ -239,7 +242,7 @@ def _get_name(self): return "QuantizedConvReLU3d" @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( @@ -257,7 +260,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): - assert ( - type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d - ), "BatchNorm3d should be fused into Conv3d before converting to reference module" + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, ( + "BatchNorm3d should be fused into Conv3d before converting to reference module" + ) return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py index a4b4fd1e7f364a..0ff5a7e4029fa5 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -30,6 +30,7 @@ class LinearReLU(nnq.Linear): >>> print(output.size()) torch.Size([128, 30]) """ + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): @@ -70,6 +71,7 @@ class LinearLeakyReLU(nnq.Linear): >>> print(output.size()) torch.Size([128, 30]) """ + _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment] def __init__( @@ -92,9 +94,9 @@ def _get_name(self): @classmethod def from_float(cls, mod, use_precomputed_fake_quant=False): - assert ( - type(mod) == nni.LinearLeakyReLU - ), "Input float module should be LinearLeakyReLU" + assert type(mod) == nni.LinearLeakyReLU, ( + "Input float module should be LinearLeakyReLU" + ) assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" activation_post_process = mod.activation_post_process leaky_relu = mod[1] @@ -145,6 +147,7 @@ class LinearTanh(nnq.Linear): >>> print(output.size()) torch.Size([128, 30]) """ + _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment] def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index ac6363b8b097bc..90474ab1ce60cb 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -134,6 +134,7 @@ class Conv1d(_ConvNd, nn.Conv1d): Attributes: weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d @@ -174,7 +175,7 @@ def __init__( ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -195,6 +196,7 @@ class Conv2d(_ConvNd, nn.Conv2d): Attributes: weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d @@ -238,7 +240,7 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -259,6 +261,7 @@ class Conv3d(_ConvNd, nn.Conv3d): Attributes: weight_fake_quant: fake quant module for weight """ + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d @@ -302,7 +305,7 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return super().from_float( cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) diff --git a/torch/ao/nn/qat/modules/embedding_ops.py b/torch/ao/nn/qat/modules/embedding_ops.py index 27ad7c9db437af..13fd7a5983fbee 100644 --- a/torch/ao/nn/qat/modules/embedding_ops.py +++ b/torch/ao/nn/qat/modules/embedding_ops.py @@ -23,6 +23,7 @@ class Embedding(nn.Embedding): Attributes: weight: fake quant module for weight """ + _FLOAT_MODULE = nn.Embedding def __init__( @@ -137,6 +138,7 @@ class EmbeddingBag(nn.EmbeddingBag): Attributes: weight: fake quant module for weight """ + _FLOAT_MODULE = nn.EmbeddingBag def __init__( diff --git a/torch/ao/nn/qat/modules/linear.py b/torch/ao/nn/qat/modules/linear.py index ede488e66b0b97..5edf16ed3ea53d 100644 --- a/torch/ao/nn/qat/modules/linear.py +++ b/torch/ao/nn/qat/modules/linear.py @@ -28,6 +28,7 @@ class Linear(nn.Linear): Attributes: weight: fake quant module for weight """ + _FLOAT_MODULE = nn.Linear def __init__( diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index 7e949a866cefe4..d9f5e4ff4c86ce 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -96,7 +96,9 @@ def __init__( self.vdim, self.embed_dim, bias=bias, **factory_kwargs ) # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969 - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment] + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) # type: ignore[assignment] # Functionals self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional() @@ -212,7 +214,7 @@ def dequantize(self): fp.bias_v = nn.Parameter(self.bias_v.dequantize()) # Set the linear weights - # Note: Because the linear layers are quantized, mypy does not nkow how + # Note: Because the linear layers are quantized, mypy does not know how # to deal with them -- might need to ignore the typing checks. # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] @@ -375,9 +377,9 @@ def _forward_impl( assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = self.embed_dim // self.num_heads - assert ( - head_dim * self.num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) scaling = float(head_dim) ** -0.5 q = self.linear_Q(query) @@ -394,9 +396,9 @@ def _forward_impl( stacklevel=3, ) attn_mask = attn_mask.to(torch.bool) - assert ( - attn_mask.is_floating_point() or attn_mask.dtype == torch.bool - ), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + ) if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index af58765f33de10..ad32cf174c6280 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -38,6 +38,7 @@ class LSTMCell(torch.nn.Module): ... hx, cx = rnn(input[i], (hx, cx)) ... output.append(hx) """ + _FLOAT_MODULE = torch.nn.LSTMCell __constants__ = ["split_gates"] # for jit.script @@ -145,8 +146,9 @@ def forward( def initialize_hidden( self, batch_size: int, is_quantized: bool = False ) -> tuple[Tensor, Tensor]: - h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros( - (batch_size, self.hidden_size) + h, c = ( + torch.zeros((batch_size, self.hidden_size)), + torch.zeros((batch_size, self.hidden_size)), ) if is_quantized: (h_scale, h_zp) = self.initial_hidden_state_qparams @@ -319,8 +321,9 @@ def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): if hx_fw is None and cx_fw is None: hidden_fw = None else: - hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional( - cx_fw + hidden_fw = ( + torch.jit._unwrap_optional(hx_fw), + torch.jit._unwrap_optional(cx_fw), ) result_fw, hidden_fw = self.layer_fw(x, hidden_fw) @@ -421,6 +424,7 @@ class LSTM(torch.nn.Module): >>> print(rnn.layers[0].weight_hh) AssertionError: There is no reverse path in the non-bidirectional layer """ + _FLOAT_MODULE = torch.nn.LSTM def __init__( diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index 1a6d73f93174ec..8855ccfdbfe60d 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -133,6 +133,7 @@ class Conv2d(nnq.Conv2d): >>> output = m(input) """ + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None @@ -217,6 +218,7 @@ class Conv3d(nnq.Conv3d): >>> output = m(input) """ + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py index a94b308da267ec..0faaf62cedb504 100644 --- a/torch/ao/nn/quantized/dynamic/modules/linear.py +++ b/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -35,6 +35,7 @@ class Linear(nnq.Linear): >>> print(output.size()) torch.Size([128, 30]) """ + # version used in this class is different from the parent class nnq.Linear _version = 4 @@ -111,10 +112,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): torch.ao.nn.qat.dynamic.Linear, ] - assert ( - type(mod) in float_modules - ), "nn.quantized.dynamic.Linear.from_float only works for one of" + str( - [float_mod.__name__ for float_mod in float_modules] + assert type(mod) in float_modules, ( + "nn.quantized.dynamic.Linear.from_float only works for one of" + + str([float_mod.__name__ for float_mod in float_modules]) ) assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" if type(mod) == nni.LinearReLU: @@ -147,7 +147,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): return qlinear @classmethod - def from_reference(cls, ref_qlinear): + def from_reference(cls, ref_qlinear): # type: ignore[override] """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized module Args: diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index 0c5363e9b71dd6..10db59aafbf7ee 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import numbers import warnings @@ -521,6 +520,7 @@ class LSTM(RNNBase): >>> c0 = torch.randn(2, 3, 20) >>> output, (hn, cn) = rnn(input, (h0, c0)) """ + _FLOAT_MODULE = nn.LSTM __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} @@ -805,6 +805,7 @@ class GRU(RNNBase): >>> h0 = torch.randn(2, 3, 20) >>> output, hn = rnn(input, h0) """ + _FLOAT_MODULE = nn.GRU __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} @@ -1036,8 +1037,10 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell, - }, "nn.quantized.dynamic.RNNCellBase.from_float \ + }, ( + "nn.quantized.dynamic.RNNCellBase.from_float \ only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell" + ) assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" if mod.qconfig is not None and mod.qconfig.weight is not None: @@ -1210,6 +1213,7 @@ class RNNCell(RNNCellBase): ... hx = rnn(input[i], hx) ... output.append(hx) """ + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] def __init__( diff --git a/torch/ao/nn/quantized/functional.py b/torch/ao/nn/quantized/functional.py index 297629e088061f..51a2f4905c257c 100644 --- a/torch/ao/nn/quantized/functional.py +++ b/torch/ao/nn/quantized/functional.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -r""" Functional interface (quantized).""" +r"""Functional interface (quantized).""" + import warnings from typing import Optional diff --git a/torch/ao/nn/quantized/modules/activation.py b/torch/ao/nn/quantized/modules/activation.py index c7e37fd80baaa7..15b4d36e8b44a7 100644 --- a/torch/ao/nn/quantized/modules/activation.py +++ b/torch/ao/nn/quantized/modules/activation.py @@ -265,7 +265,8 @@ def from_observed(cls, other): if converted.bias_v is not None: bias_v = converted._parameters.pop("bias_v") sc, zp = torch._choose_qparams_per_tensor( - bias_k, reduce_range=False # type: ignore[possibly-undefined] + bias_k, # type: ignore[possibly-undefined] + reduce_range=False, ) bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) setattr(converted, "bias_v", bias_v) # noqa: B010 diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py index 345a17e0db9c69..069db116a064b5 100644 --- a/torch/ao/nn/quantized/modules/batchnorm.py +++ b/torch/ao/nn/quantized/modules/batchnorm.py @@ -83,7 +83,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return _BatchNorm.from_float( cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @@ -122,7 +122,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] return _BatchNorm.from_float( cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 9743f40e80749d..907a04898273be 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -247,9 +247,9 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): if weight_post_process is None: weight_post_process = mod.qconfig.weight() weight_post_process(mod.weight) - assert ( - weight_post_process.dtype == torch.qint8 - ), "Weight observer must have a dtype of qint8" + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) qweight = _quantize_weight(mod.weight.float(), weight_post_process) # the __init__ call used is the one from derived classes and not the one from _ConvNd qconv = cls( @@ -290,9 +290,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): mod.bn.weight, mod.bn.bias, ) - assert hasattr( - mod, "activation_post_process" - ), "Input QAT module must have observer attached" + assert hasattr(mod, "activation_post_process"), ( + "Input QAT module must have observer attached" + ) weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: @@ -304,9 +304,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): + " but got:" + str(type(mod)) ) - assert hasattr( - mod, "qconfig" - ), "Input float module must have qconfig defined." + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined." + ) activation_post_process = ( None if not hasattr(mod, "activation_post_process") @@ -467,7 +467,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] r"""Creates a quantized module from a float module or qparams_dict. Args: @@ -517,6 +517,7 @@ class Conv2d(_ConvNd): >>> output = m(q_input) """ + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn2d _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU2d @@ -596,7 +597,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] r"""Creates a quantized module from a float module or qparams_dict. Args: @@ -646,6 +647,7 @@ class Conv3d(_ConvNd): >>> output = m(q_input) """ + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn3d _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU3d @@ -726,7 +728,7 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] r"""Creates a quantized module from a float module or qparams_dict. Args: @@ -792,7 +794,7 @@ def _input_padding( return res @classmethod - def from_float(cls, mod, use_precomputed_fake_quant=False): + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization @@ -809,9 +811,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined." weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] weight_post_process(mod.weight) - assert ( - weight_post_process.dtype == torch.qint8 - ), "Weight observer must have a dtype of qint8" + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) qweight = _quantize_weight(mod.weight.float(), weight_post_process) # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd qconv = cls( @@ -839,7 +841,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): return qconv @staticmethod - def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module Args: ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization @@ -987,7 +989,7 @@ def forward(self, input): ) @classmethod - def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] return _ConvTransposeNd.from_reference( cls, ref_qconvt, output_scale, output_zero_point ) @@ -1110,7 +1112,7 @@ def forward(self, input): ) @classmethod - def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] return _ConvTransposeNd.from_reference( cls, ref_qconvt, output_scale, output_zero_point ) @@ -1235,7 +1237,7 @@ def forward(self, input): ) @classmethod - def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] return _ConvTransposeNd.from_reference( cls, ref_qconvt, output_scale, output_zero_point ) diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index 2d9b5a6f068382..c39c8de8ce2ccc 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import torch import torch.nn as nn @@ -116,6 +115,7 @@ class Embedding(torch.nn.Module): torch.Size([9, 12]) """ + _version = 1 def __init__( @@ -211,9 +211,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): + ".from_float only works for " + nn.Embedding.__name__ ) - assert hasattr( - mod, "qconfig" - ), "Embedding input float module must have qconfig defined" + assert hasattr(mod, "qconfig"), ( + "Embedding input float module must have qconfig defined" + ) from torch.ao.quantization import float_qparams_weight_only_qconfig if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] @@ -225,13 +225,13 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): is_float_qparams_qconfig = ( weight_observer.qscheme == torch.per_channel_affine_float_qparams ) - assert ( - is_float_qparams_qconfig - ), "Embedding quantization is only supported with float_qparams_weight_only_qconfig." + assert is_float_qparams_qconfig, ( + "Embedding quantization is only supported with float_qparams_weight_only_qconfig." + ) - assert ( - dtype == torch.quint8 or dtype == torch.quint4x2 - ), f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}" + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}" + ) # Run the observer to calculate qparams. weight_observer(mod.weight) @@ -280,6 +280,7 @@ class EmbeddingBag(Embedding): torch.Size([5, 12]) """ + _version = 1 def __init__( @@ -354,9 +355,9 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): + ".from_float only works for " + nn.EmbeddingBag.__name__ ) - assert hasattr( - mod, "qconfig" - ), "EmbeddingBag input float module must have qconfig defined" + assert hasattr(mod, "qconfig"), ( + "EmbeddingBag input float module must have qconfig defined" + ) from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] @@ -368,13 +369,13 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): is_float_qparams_qconfig = ( weight_observer.qscheme == torch.per_channel_affine_float_qparams ) - assert ( - is_float_qparams_qconfig - ), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig." + assert is_float_qparams_qconfig, ( + "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig." + ) - assert ( - dtype == torch.quint8 or dtype == torch.quint4x2 - ), f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}" + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}" + ) # Run the observer to calculate qparams. weight_observer(mod.weight) diff --git a/torch/ao/nn/quantized/modules/functional_modules.py b/torch/ao/nn/quantized/modules/functional_modules.py index 10bdce2c00755d..3b364b43f60607 100644 --- a/torch/ao/nn/quantized/modules/functional_modules.py +++ b/torch/ao/nn/quantized/modules/functional_modules.py @@ -288,9 +288,9 @@ def matmul(self, x: Tensor, y: Tensor) -> Tensor: @classmethod def from_float(cls, mod, use_precomputed_fake_quant=False): - assert ( - type(mod) == FloatFunctional - ), "QFunctional.from_float expects an instance of FloatFunctional" + assert type(mod) == FloatFunctional, ( + "QFunctional.from_float expects an instance of FloatFunctional" + ) scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] new_mod = QFunctional() new_mod.scale = float(scale) diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py index cf4997a6c2c601..9042833f5e30b2 100644 --- a/torch/ao/nn/quantized/modules/linear.py +++ b/torch/ao/nn/quantized/modules/linear.py @@ -145,6 +145,7 @@ class Linear(WeightedQuantizedModule): >>> print(output.size()) torch.Size([128, 30]) """ + _version = 3 _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear) @@ -314,12 +315,12 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): [float_mod.__name__ for float_mod in cls._FLOAT_MODULE] ) error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}" - assert ( - type_before_parametrizations(mod) in cls._FLOAT_MODULE - ), error_msg.format() - assert hasattr( - mod, "qconfig" - ), "Input float module must have qconfig defined" + assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, ( + error_msg.format() + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined" + ) activation_post_process = mod.activation_post_process if type_before_parametrizations(mod) == nni.LinearReLU: mod = mod[0] diff --git a/torch/ao/nn/quantized/modules/normalization.py b/torch/ao/nn/quantized/modules/normalization.py index e025184bd4a20c..4db2ac6e928f47 100644 --- a/torch/ao/nn/quantized/modules/normalization.py +++ b/torch/ao/nn/quantized/modules/normalization.py @@ -93,6 +93,7 @@ class GroupNorm(torch.nn.GroupNorm): * **zero_point** - quantization zero point of the output, type: long. """ + __constants__ = ["num_groups", "num_channels", "eps", "affine"] def __init__( diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py index 24b17ca2d62bd5..5076c9225d2eb0 100644 --- a/torch/ao/nn/quantized/modules/rnn.py +++ b/torch/ao/nn/quantized/modules/rnn.py @@ -32,6 +32,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM): >>> tq.prepare(model, prepare_custom_module_class=custom_module_config) >>> tq.convert(model, convert_custom_module_class=custom_module_config) """ + _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] def _get_name(self): diff --git a/torch/ao/nn/quantized/reference/modules/conv.py b/torch/ao/nn/quantized/reference/modules/conv.py index cbe2fdca52e5cd..3d4def5c4b7a0d 100644 --- a/torch/ao/nn/quantized/reference/modules/conv.py +++ b/torch/ao/nn/quantized/reference/modules/conv.py @@ -110,7 +110,7 @@ def _get_name(self): return "QuantizedConv1d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvNd.from_float(cls, float_conv, weight_qparams) @@ -173,7 +173,7 @@ def _get_name(self): return "QuantizedConv2d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvNd.from_float(cls, float_conv, weight_qparams) @@ -236,7 +236,7 @@ def _get_name(self): return "QuantizedConv3d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvNd.from_float(cls, float_conv, weight_qparams) @@ -346,7 +346,7 @@ def _get_name(self): return "QuantizedConvTranspose1d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) @@ -427,7 +427,7 @@ def _get_name(self): return "QuantizedConvTranspose2d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) @@ -507,5 +507,5 @@ def _get_name(self): return "QuantizedConvTranspose3d(Reference)" @classmethod - def from_float(cls, float_conv, weight_qparams): + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index bd5329851e5e1b..adb1356cb3d360 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -82,9 +82,9 @@ def __init__( "weight_hh": weight_qparams, "is_decomposed": False, } - assert ( - len(weight_qparams_dict) == 3 - ), "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + assert len(weight_qparams_dict) == 3, ( + "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + ) self._init_weight_qparams_dict(weight_qparams_dict, device) def _init_weight_qparams_dict(self, weight_qparams_dict, device): @@ -185,7 +185,9 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in ( 1, 2, - ), f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ), ( + f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) @@ -274,7 +276,9 @@ def forward( assert input.dim() in ( 1, 2, - ), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ), ( + f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) @@ -347,7 +351,9 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in ( 1, 2, - ), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ), ( + f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) @@ -750,7 +756,9 @@ def forward(self, input, hx=None): # noqa: F811 assert input.dim() in ( 2, 3, - ), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + ), ( + f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + ) is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index 6d31f9e31496b3..0701b73da38b0e 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -25,7 +25,9 @@ def _init_weight_qparams(self, weight_qparams, device): torch.per_tensor_affine, torch.per_channel_affine, torch.per_channel_affine_float_qparams, - ], f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}" + ], ( + f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}" + ) if self.weight_dtype in [ torch.quint8, torch.qint8, diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index e937a0ad2a5245..6da18e15101212 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -18,6 +18,7 @@ class Linear(torch.nn.Module): r""" A dynamically quantized sparse linear module with float tensor as inputs and outputs. """ + _version = 1 _op_type = "sparse_dynamic" _FLOAT_MODULE = torch.nn.Linear @@ -83,9 +84,9 @@ def _load_from_state_dict( error_msgs, ): op_type = int(state_dict[prefix + "op_type"]) - assert ( - op_type == "sparse" - ), f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + assert op_type == "sparse", ( + f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + ) state_dict.pop(prefix + "op_type") version = local_metadata.get("version", None) diff --git a/torch/ao/nn/sparse/quantized/linear.py b/torch/ao/nn/sparse/quantized/linear.py index 81f663018e7e26..e3dbf23b9f682c 100644 --- a/torch/ao/nn/sparse/quantized/linear.py +++ b/torch/ao/nn/sparse/quantized/linear.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs from typing import Optional @@ -104,6 +103,7 @@ class Linear(torch.nn.Module): r""" A quantized sparse linear module with quantized tensor as inputs and outputs. """ + _version = 1 _FLOAT_MODULE = torch.nn.Linear @@ -265,7 +265,10 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): dtype=dtype, ) qlinear.set_weight_bias( - qweight, mod.bias, row_block_size, col_block_size # type: ignore[arg-type] + qweight, + mod.bias, + row_block_size, # type: ignore[arg-type] + col_block_size, # type: ignore[arg-type] ) qlinear.scale = float(act_scale) qlinear.zero_point = int(act_zp) diff --git a/torch/ao/nn/sparse/quantized/utils.py b/torch/ao/nn/sparse/quantized/utils.py index 70daa8fd9f361e..ccf85e68d84ff4 100644 --- a/torch/ao/nn/sparse/quantized/utils.py +++ b/torch/ao/nn/sparse/quantized/utils.py @@ -15,7 +15,7 @@ def _is_valid_linear_block_sparse_pattern( # This is a stop-gap measure as current flow does not allow module # specific block sparse pattern. -# Infact there is no way to convey sparse pattern via module config +# In fact there is no way to convey sparse pattern via module config # of quantization flow. Thus using the global context to convey # sparsity pattern. # Once the flow supports it, this should be removed. diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 71ce96e4a3873b..96d24a2cf2e754 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -63,15 +63,14 @@ def compare_weights( Example usage:: - wt_compare_dict = compare_weights( - float_model.state_dict(), qmodel.state_dict()) + wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict()) for key in wt_compare_dict: print( key, compute_error( - wt_compare_dict[key]['float'], - wt_compare_dict[key]['quantized'].dequantize() - ) + wt_compare_dict[key]["float"], + wt_compare_dict[key]["quantized"].dequantize(), + ), ) Args: @@ -422,10 +421,17 @@ def compare_model_stub( Example usage:: - module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock] - ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data) + module_swap_list = [ + torchvision.models.quantization.resnet.QuantizableBasicBlock + ] + ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data) for key in ob_dict: - print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize())) + print( + key, + compute_error( + ob_dict[key]["float"], ob_dict[key]["quantized"].dequantize() + ), + ) Args: float_model: float model used to generate the q_model @@ -532,9 +538,9 @@ def compare_model_outputs( print( key, compute_error( - act_compare_dict[key]['float'], - act_compare_dict[key]['quantized'].dequantize() - ) + act_compare_dict[key]["float"], + act_compare_dict[key]["quantized"].dequantize(), + ), ) Args: diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index 420fea50740bba..ec13839f3c9b75 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -9,7 +9,7 @@ import torch.ao.ns._numeric_suite_fx as ns m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval() - mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + mp = quantize_fx.prepare_fx(m, {"": torch.ao.quantization.default_qconfig}) # We convert a copy because we need the original prepared model # to be available for comparisons, and `quantize_fx.convert_fx` is inplace. mq = quantize_fx.convert_fx(copy.deepcopy(mp)) @@ -19,12 +19,12 @@ # # extract weight pairs - weight_comparison = ns.extract_weights('a', mp, 'b', mq) + weight_comparison = ns.extract_weights("a", mp, "b", mq) # add SQNR for each comparison, inplace ns.extend_logger_results_with_comparison( - weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, - 'sqnr') + weight_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) # weight_comparison contains the weights from `mp` and `mq` stored # in pairs, and can be used for further analysis. @@ -36,9 +36,8 @@ # add loggers mp_ns, mq_ns = ns.add_loggers( - 'a', copy.deepcopy(mp), - 'b', copy.deepcopy(mq), - ns.OutputLogger) + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) # send an example datum to capture intermediate activations datum = torch.randn(1, 1, 1, 1) @@ -46,13 +45,12 @@ mq_ns(datum) # extract intermediate activations - act_comparison = ns.extract_logger_info( - mp_ns, mq_ns, ns.OutputLogger, 'b') + act_comparison = ns.extract_logger_info(mp_ns, mq_ns, ns.OutputLogger, "b") # add SQNR for each comparison, inplace ns.extend_logger_results_with_comparison( - act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, - 'sqnr') + act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) # act_comparison contains the activations from `mp_ns` and `mq_ns` stored # in pairs, and can be used for further analysis. @@ -63,9 +61,8 @@ # create shadow model mp_shadows_mq = ns.add_shadow_loggers( - 'a', copy.deepcopy(mp), - 'b', copy.deepcopy(mq), - ns.OutputLogger) + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) # send an example datum to capture intermediate activations datum = torch.randn(1, 1, 1, 1) @@ -73,12 +70,13 @@ # extract intermediate activations shadow_act_comparison = ns.extract_shadow_logger_info( - mp_shadows_mq, ns.OutputLogger, 'b') + mp_shadows_mq, ns.OutputLogger, "b" + ) # add SQNR for each comparison, inplace ns.extend_logger_results_with_comparison( - shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, - 'sqnr') + shadow_act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored # in pairs, and can be used for further analysis. @@ -596,9 +594,9 @@ def _extract_logger_info_one_model( key = mod.ref_name if key not in results: results[key] = {} - assert ( - mod.model_name not in results[key] - ), f"{mod.model_name} is already present in results" + assert mod.model_name not in results[key], ( + f"{mod.model_name} is already present in results" + ) if mod.results_type not in results[key]: results[key][mod.results_type] = {} if mod.model_name not in results[key][mod.results_type]: @@ -810,12 +808,12 @@ def extend_logger_results_with_comparison( """ for results_type_to_results in results.values(): for model_name_to_results in results_type_to_results.values(): - assert ( - model_name_1 in model_name_to_results - ), f"{model_name_1} not found in results" - assert ( - model_name_2 in model_name_to_results - ), f"{model_name_2} not found in results" + assert model_name_1 in model_name_to_results, ( + f"{model_name_1} not found in results" + ) + assert model_name_2 in model_name_to_results, ( + f"{model_name_2} not found in results" + ) results_1 = model_name_to_results[model_name_1] results_2 = model_name_to_results[model_name_2] diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 36914e2ebb30bd..1f9c873971a3db 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -225,7 +225,9 @@ def _get_subgraph_relationship_type( assert ( subgraph_a.base_op_node == subgraph_a.start_node and subgraph_b.base_op_node == subgraph_b.start_node - ), "Matching call_module patterns where base_op_node != start_node is not supported yet" + ), ( + "Matching call_module patterns where base_op_node != start_node is not supported yet" + ) # for call_module, we need to look up the modules to do the type check assert isinstance(node_a.target, str) mod_a = getattr_from_fqn(gm_a, node_a.target) @@ -444,9 +446,9 @@ def get_matching_subgraph_pairs( key_name_b = _get_name_for_subgraph( cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b ) - assert ( - key_name_a == key_name_b - ), f"Subgraph names {key_name_a} and {key_name_b} do not match" + assert key_name_a == key_name_b, ( + f"Subgraph names {key_name_a} and {key_name_b} do not match" + ) results[key_name_a] = (cur_subgraph_a, cur_subgraph_b) continue elif cur_subgraph_a is None and cur_subgraph_b is None: diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index 04964bb79be646..bc30a014c195ad 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -646,9 +646,9 @@ def _copy_arg(arg): return arg elif isinstance(kwarg_val, (list, tuple)): for el in kwarg_val: - assert not isinstance( - el, Node - ), "handling of Node inside list is not implemented" + assert not isinstance(el, Node), ( + "handling of Node inside list is not implemented" + ) return arg else: raise AssertionError( @@ -689,13 +689,21 @@ def _copy_arg(arg): mod_a = getattr_from_fqn(gm_a, node_a.target) setattr(gm_b, new_mod_copy_name, mod_a) node_a_shadows_c = graph_c.create_node( - node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type] + node_a.op, + new_mod_copy_name, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, ) return node_a_shadows_c else: assert node_a.op in ("call_function", "call_method") node_a_shadows_c = graph_c.create_node( - node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type] + node_a.op, + node_a.target, + new_args, # type: ignore[arg-type] + new_kwargs, # type: ignore[arg-type] + node_a_shadows_c_name, ) return node_a_shadows_c @@ -1116,7 +1124,7 @@ def load_arg(a): # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c # # Note: node_start_c may be the same node as node_end_c, or they - # may have nodes inbetween. + # may have nodes in between. else: env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 564b32d486a1e0..5d8b569036ff24 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -356,7 +356,10 @@ def _add_placeholder( new_kwarg = [] for inner_kwarg in kwarg: p = _add_placeholder( - g, inner_kwarg, seen_names, old_name_to_new_node # type: ignore[arg-type] + g, + inner_kwarg, # type: ignore[arg-type] + seen_names, + old_name_to_new_node, ) new_kwarg.append(p) cur_kwargs_copy[kwarg_name] = new_kwarg @@ -427,9 +430,9 @@ def _add_placeholder( break # go to next node - assert ( - len(cur_node_orig.users.keys()) == 1 - ), f"{cur_node_orig} has more than 1 users, not supported yet" + assert len(cur_node_orig.users.keys()) == 1, ( + f"{cur_node_orig} has more than 1 users, not supported yet" + ) cur_node_orig = next(iter(cur_node_orig.users.keys())) cur_iteration += 1 if cur_iteration > iteration_limit: @@ -529,9 +532,9 @@ def create_one_transformed_and_logged_copy_of_subgraph( "prepare_custom_config", "qconfig_mapping", ]: - assert ( - kwarg_name not in custom_prepare_kwargs - ), f"cannot specify {kwarg_name} in custom_prepare_kwargs" + assert kwarg_name not in custom_prepare_kwargs, ( + f"cannot specify {kwarg_name} in custom_prepare_kwargs" + ) prepare_kwargs: dict[str, Any] = { "example_inputs": example_inputs, "qconfig_mapping": qconfig_mapping, @@ -1073,9 +1076,7 @@ def extract_weight_comparison(m: GraphModule) -> NSResultsType: if shadow_wrapper_node is None: continue - shadow_wrapper = getattr_from_fqn( - m, shadow_wrapper_node.target - ) # type: ignore[arg-type] + shadow_wrapper = getattr_from_fqn(m, shadow_wrapper_node.target) # type: ignore[arg-type] weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper) if weight_info is None: continue @@ -1226,9 +1227,9 @@ def group_results_by_subgraph(results: NSResultsType) -> Any: "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"], } - subgraph_name_to_subgraph_results[subgraph_name][ - subgraph_candidate_idx - ] = subgraph_results + subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = ( + subgraph_results + ) return dict(subgraph_name_to_subgraph_results) diff --git a/torch/ao/ns/fx/qconfig_multi_mapping.py b/torch/ao/ns/fx/qconfig_multi_mapping.py index 4a7865f2f14b76..530d5ce52d9986 100644 --- a/torch/ao/ns/fx/qconfig_multi_mapping.py +++ b/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -109,7 +109,7 @@ def _handle_list_size_mismatch( target_qconfigs_dict[key] = None break - # insert copies of this new QConfigMapping until all entires + # insert copies of this new QConfigMapping until all entries # in qconfig_list can fit among the QConfigMappings while len(qconfig_list) > len(self.qconfig_mappings_list): self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index 6f55600671c864..b6357120dc1414 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -76,7 +76,8 @@ def get_node_first_input_and_output_type( assert isinstance(node.target, str) mod = getattr_from_fqn(gm, node.target) is_known_fp32_or_int8_input_module = any( - isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type] + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 ) if ( isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)) # type: ignore[arg-type] @@ -94,10 +95,12 @@ def get_node_first_input_and_output_type( ) return (prev_node_output_type, prev_node_output_type) is_known_fp32_input_module = any( - isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32 # type: ignore[arg-type] + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32 ) is_known_int8_input_module = any( - isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8 # type: ignore[arg-type] + isinstance(mod, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_INT8 ) if is_known_fp32_input_module: return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32) @@ -136,9 +139,9 @@ def get_node_first_input_and_output_type( ) cur_node_dtype_target = get_normalized_nth_input(node, gm, 1) - assert ( - cur_node_dtype_target is torch.float16 - ), f"{cur_node_dtype_target} handling needs to be added" + assert cur_node_dtype_target is torch.float16, ( + f"{cur_node_dtype_target} handling needs to be added" + ) return (prev_node_output_type, NodeInputOrOutputType.FP16) @@ -230,7 +233,8 @@ def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx): return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value] is_known_fp32_or_int8_input_module = any( - isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type] + isinstance(module_obj, target_type) # type: ignore[arg-type] + for target_type in MODS_IO_TYPE_FP32_OR_INT8 ) if is_known_fp32_or_int8_input_module: return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map) diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index 721426e90800cd..ef6a35686c7d6b 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -75,13 +75,18 @@ def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) >>> return torch.eye(data.shape).to(data.device) >>> >>> - >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn) + >>> act_sparsifier.register_layer( + ... model.some_layer, + ... aggregate_fn=agg_fn, + ... reduce_fn=reduce_fn, + ... mask_fn=mask_fn, + ... ) >>> >>> # start training process >>> for _ in [...]: - >>> # epoch starts - >>> # model.forward(), compute_loss() and model.backwards() - >>> # epoch ends + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends >>> act_sparsifier.step() >>> # end training process >>> sparsifier.squash_mask() @@ -154,7 +159,7 @@ def hook(module, input) -> None: if data is None: out_data = [ 0 for _ in range(0, len(features)) - ] # create one incase of 1st forward + ] # create one in case of 1st forward self.state[name]["mask"] = [0 for _ in range(0, len(features))] else: out_data = data # a list @@ -231,9 +236,9 @@ def register_layer( self.data_groups[name] = local_args agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) - self.state[name][ - "mask" - ] = None # mask will be created when model forward is called. + self.state[name]["mask"] = ( + None # mask will be created when model forward is called. + ) # attach agg hook self.data_groups[name]["hook"] = agg_hook @@ -255,9 +260,9 @@ def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None Hence, if get_mask() is called before model.forward(), an error will be raised. """ - assert ( - name is not None or layer is not None - ), "Need at least name or layer obj to retrieve mask" + assert name is not None or layer is not None, ( + "Need at least name or layer obj to retrieve mask" + ) if name is None: assert layer is not None @@ -360,9 +365,9 @@ def squash_mask(self, attach_sparsify_hook=True, **kwargs): configs["hook"] = configs["layer"].register_forward_pre_hook( self._sparsify_hook(name) ) - configs[ - "hook_state" - ] = "sparsify" # signals that sparsify hook is now attached + configs["hook_state"] = ( + "sparsify" # signals that sparsify hook is now attached + ) def _get_serializable_data_groups(self): """Exclude hook and layer from the config keys before serializing diff --git a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py index 3f685661bd9fcb..672903e8f058cb 100644 --- a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py +++ b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -98,7 +98,9 @@ def get_schedule_param(self): >>> def get_schedule_param(self): ... new_param = {} ... for name in self.sparsifier.data_groups.keys(): - ... new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... new_param[name] = ( + ... self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... ) ... return new_param When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index 1c5e698e8b4a23..3dea01586a2b3c 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -91,9 +91,9 @@ def add_data(self, name: str, data, reuse_mask=True, **config): 4. By default, the config of the replaced data is used as config for the replacing data, unless something is specified in the config dictionary. """ - assert ( - type(data) in SUPPORTED_TYPES - ), "specified data type not supported at the moment" + assert type(data) in SUPPORTED_TYPES, ( + "specified data type not supported at the moment" + ) local_args = copy.deepcopy(self.defaults) local_args.update(config) weight = self._extract_weight(data) @@ -115,9 +115,9 @@ def add_data(self, name: str, data, reuse_mask=True, **config): if reuse_mask: current_data = self.get_data(name=name) - assert ( - weight.shape == current_data.shape - ), "to retain the old mask, the shape of the new data must be the same as the previous one" + assert weight.shape == current_data.shape, ( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) mask = self.get_mask( name=name ) # reuse mask instead of creating a new one @@ -310,7 +310,7 @@ def step(self): # type:ignore[override] self.update_mask(name, data, **config) @abc.abstractmethod - def update_mask(self, name, data, **kwargs): + def update_mask(self, name, data, **kwargs): # type: ignore[override] pass def _delete_data(self, name): diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md index 3b0d5fb3b1644d..234a573029f802 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md @@ -14,7 +14,7 @@ The [DataNormSparsifier](https://github.com/pytorch/pytorch/blob/main/torch/ao/p 3. Norm: L1 and L2 ## Dataset -The benchmarks are created for the dlrm model on the Kaggle CriteoDataset which can be downloaded from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1). +The benchmarks are created for the dlrm model on the Kaggle CriteoDataset which can be downloaded from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1). ## Results 1. **Disk Usage**: Introducing sparsity in the embeddings reduces file size after compression. The compressed model size goes down from 1.9 GB to 150 MB after 100% sparsity. @@ -34,7 +34,7 @@ The takeaway is that the dlrm model with sparse coo tensor is slower (roughly 2x ## Setup The benchmark codes depend on the [DLRM codebase](https://github.com/facebookresearch/dlrm). 1. Clone the dlrm git repository -2. Download the dataset from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1) +2. Download the dataset from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1) 3. The DLRM model can be trained using the following script ``` # Make sure you go into the file and make sure that the path to dataset is correct. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py index 052c137c35ef6c..8192b617139bf9 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py @@ -47,7 +47,7 @@ def save_model_states( state_dict (Dict) The state_dict() as dumped by dlrm_s_pytorch.py. Only the model state will be extracted from this dictionary. This corresponds to the 'state_dict' key in the state_dict dictionary. - >>> model_state = state_dict['state_dict'] + >>> model_state = state_dict["state_dict"] save_file_name (str) The filename (not path) when saving the model state dictionary sparse_block_shape (Tuple) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index c57b639af82e52..ff4b4f913f5033 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -32,7 +32,7 @@ class DataNormSparsifier(BaseDataSparsifier): zeros_per_block: Number of zeros in a sparse block Note:: All arguments to the DataNormSparsifier constructor are "default" - arguments and could be overriden by the configuration provided in the + arguments and could be overridden by the configuration provided in the `add_data` step. """ diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py index ed5f2c37a02047..442639be9b2141 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py @@ -199,7 +199,7 @@ def _check_on_train_epoch_start(self, pl_module, callback): do not want as the config of each layer changes after .step() - Hence, we need to dump and restore the state_dict() everytime because we're + Hence, we need to dump and restore the state_dict() every time because we're copying the model after each epoch. Hence, it is essential to make sure that the sparsifier's state_dict() is being correctly dumped and restored. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index 2efdf524b367ec..b2943e2af1a872 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -66,17 +66,17 @@ def post_training_sparse_quantize( else: embedding_modules = [] - assert isinstance( - select_embeddings, list - ), "the embedding_modules must be a list of embedding modules" + assert isinstance(select_embeddings, list), ( + "the embedding_modules must be a list of embedding modules" + ) for emb in select_embeddings: - assert ( - type(emb) in SUPPORTED_MODULES - ), "the embedding_modules list must be an embedding or embedding bags" + assert type(emb) in SUPPORTED_MODULES, ( + "the embedding_modules list must be an embedding or embedding bags" + ) fqn_name = module_to_fqn(model, emb) - assert ( - fqn_name is not None - ), "the embedding modules must be part of input model" + assert fqn_name is not None, ( + "the embedding modules must be part of input model" + ) embedding_modules.append((fqn_name, emb)) if sparsify_first: @@ -118,9 +118,9 @@ def post_training_sparse_quantize( quantized_weight = quantized_emb.weight() # type: ignore[operator] quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() - quantize_params["zero_points"][ - name - ] = quantized_weight.q_per_channel_zero_points() + quantize_params["zero_points"][name] = ( + quantized_weight.q_per_channel_zero_points() + ) quantize_params["dequant_weights"][name] = torch.dequantize( quantized_weight ) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 3da27ba38df55b..680ecd9f139e3a 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -11,7 +11,7 @@ class FPGMPruner(BaseStructuredSparsifier): r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner - This sparsifier prune fliter (row) in a tensor according to distances among filters according to + This sparsifier prune filter (row) in a tensor according to distances among filters according to `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. This sparsifier is controlled by three variables: diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py index fcbdb35939794f..ffbb99bb2967e1 100644 --- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -90,12 +90,10 @@ def _get_supported_activation_modules(): return SUPPORTED_ACTIVATION_MODULES -def _get_default_structured_pruning_patterns() -> ( - dict[ - tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], - Callable[..., None], - ] -): +def _get_default_structured_pruning_patterns() -> dict[ + tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], +]: """ Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. """ diff --git a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py index e8acbc5e458c65..f904cc3ab8c4c3 100644 --- a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -1,9 +1,10 @@ -# mypy: allow-untyped-defs -from typing import cast +from typing import Any, cast import torch +from torch import nn -from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity +from .base_structured_sparsifier import BaseStructuredSparsifier +from .parametrization import FakeStructuredSparsity class LSTMSaliencyPruner(BaseStructuredSparsifier): @@ -25,7 +26,7 @@ class LSTMSaliencyPruner(BaseStructuredSparsifier): This applies to both weight_ih_l{k} and weight_hh_l{k}. """ - def update_mask(self, module, tensor_name, **kwargs): + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None: weights = getattr(module, tensor_name) for p in getattr(module.parametrizations, tensor_name): diff --git a/torch/ao/pruning/_experimental/pruner/match_utils.py b/torch/ao/pruning/_experimental/pruner/match_utils.py index 3f8567bc79070d..64ef6d78c58c78 100644 --- a/torch/ao/pruning/_experimental/pruner/match_utils.py +++ b/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -1,6 +1,7 @@ """ Contains utility functions to check if a pattern is in the graph and return the matching nodes """ + from typing import Any, Optional, Union import torch diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index eef1d5d6f3bb8e..a1882af4ca11cc 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -3,6 +3,7 @@ Collection of conversion functions for linear / conv2d structured pruning Also contains utilities for bias propagation """ + from typing import Callable, cast, Optional import torch @@ -326,9 +327,9 @@ def prune_conv2d_pool_flatten_linear( linear_ic = linear.weight.shape[1] conv2d_oc = len(mask) - assert ( - linear_ic % conv2d_oc == 0 - ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + assert linear_ic % conv2d_oc == 0, ( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) flatten_scale = linear_ic // conv2d_oc flattened_mask = torch.tensor( diff --git a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py index a295b4622cc2d6..1a97cff7ab231f 100644 --- a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -7,7 +7,7 @@ class SaliencyPruner(BaseStructuredSparsifier): Prune rows based on the saliency (L1 norm) of each row. This pruner works on N-Dimensional weight tensors. - For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + For each row, we will calculate the saliency, which is the sum the L1 norm of all weights in that row. We expect that the resulting saliency vector has the same shape as our mask. We then pick elements to remove until we reach the target sparsity_level. """ diff --git a/torch/ao/pruning/scheduler/lambda_scheduler.py b/torch/ao/pruning/scheduler/lambda_scheduler.py index 07e95b5248119a..5588c157161a00 100644 --- a/torch/ao/pruning/scheduler/lambda_scheduler.py +++ b/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -1,5 +1,7 @@ -# mypy: allow-untyped-defs import warnings +from typing import Callable, Union + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier from .base_scheduler import BaseScheduler @@ -21,7 +23,7 @@ class LambdaSL(BaseScheduler): Example: >>> # Assuming sparsifier has two groups. >>> lambda1 = lambda epoch: epoch // 30 - >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> lambda2 = lambda epoch: 0.95**epoch >>> # xdoctest: +SKIP >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) >>> for epoch in range(100): @@ -30,7 +32,13 @@ class LambdaSL(BaseScheduler): >>> scheduler.step() """ - def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): + def __init__( + self, + sparsifier: BaseSparsifier, + sl_lambda: Union[Callable[[int], float], list[Callable[[int], float]]], + last_epoch: int = -1, + verbose: bool = False, + ) -> None: self.sparsifier = sparsifier if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): @@ -41,9 +49,9 @@ def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}" ) self.sl_lambdas = list(sl_lambda) - super().__init__(sparsifier, last_epoch, verbose) + super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call] - def get_sl(self): + def get_sl(self) -> list[float]: if not self._get_sl_called_within_step: warnings.warn( "To get the last sparsity level computed by the scheduler, " diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index ed233b0f0b5a96..73d4c283da6326 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -200,7 +200,9 @@ def prepare(self, model, config): and "." + info_from_tensor_fqn[key] == local_args[key] ) # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that - ), f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ), ( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ) local_args.update(info_from_tensor_fqn) self.groups.append(local_args) self._prepare() @@ -243,22 +245,23 @@ def squash_mask( >>> # xdoctest: +SKIP("locals are undefined") >>> # Don't save any sparse params >>> sparsifier.squash_mask() - >>> hasattr(model.submodule1, 'sparse_params') + >>> hasattr(model.submodule1, "sparse_params") False >>> # Keep sparse params per layer >>> sparsifier.squash_mask( ... params_to_keep_per_layer={ - ... 'submodule1.linear1': ('foo', 'bar'), - ... 'submodule2.linear42': ('baz',) - ... }) + ... "submodule1.linear1": ("foo", "bar"), + ... "submodule2.linear42": ("baz",), + ... } + ... ) >>> print(model.submodule1.linear1.sparse_params) {'foo': 42, 'bar': 24} >>> print(model.submodule2.linear42.sparse_params) {'baz': 0.1} >>> # Keep sparse params for all layers - >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar')) + >>> sparsifier.squash_mask(params_to_keep=("foo", "bar")) >>> print(model.submodule1.linear1.sparse_params) {'foo': 42, 'bar': 24} >>> print(model.submodule2.linear42.sparse_params) @@ -267,10 +270,9 @@ def squash_mask( >>> # Keep some sparse params for all layers, and specific ones for >>> # some other layers >>> sparsifier.squash_mask( - ... params_to_keep=('foo', 'bar'), - ... params_to_keep_per_layer={ - ... 'submodule2.linear42': ('baz',) - ... }) + ... params_to_keep=("foo", "bar"), + ... params_to_keep_per_layer={"submodule2.linear42": ("baz",)}, + ... ) >>> print(model.submodule1.linear1.sparse_params) {'foo': 42, 'bar': 24} >>> print(model.submodule2.linear42.sparse_params) diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 4b7ce0ec44687a..302f7e0b0b7c1e 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -52,9 +52,9 @@ def swap_module( # respect device affinity when swapping modules devices = {p.device for p in chain(mod.parameters(), mod.buffers())} - assert ( - len(devices) <= 1 - ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + assert len(devices) <= 1, ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None if device: new_mod.to(device) diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 58c0f7efa37d1e..89c707ad33e635 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -52,7 +52,7 @@ class WeightNormSparsifier(BaseSparsifier): Note:: All arguments to the WeightNormSparsifier constructor are "default" - arguments and could be overriden by the configuration provided in the + arguments and could be overridden by the configuration provided in the `prepare` step. """ diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 57ed1f60f948ae..ffc1792fd23faf 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -223,9 +223,9 @@ def __init__( from .utils import is_per_channel if is_per_channel(self.qscheme): - assert ( - self.ch_axis is not None - ), "Must provide a valid ch_axis if qscheme is per channel" + assert self.ch_axis is not None, ( + "Must provide a valid ch_axis if qscheme is per channel" + ) def forward(self, x: Tensor) -> Tensor: return x diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index 99b87b01dffbe4..5d79f7f71b4f2e 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -92,9 +92,9 @@ def channel_range(input, axis=0): mins = min_over_ndim(input, axis_list) maxs = max_over_ndim(input, axis_list) - assert mins.size(0) == input.size( - axis - ), "Dimensions of resultant channel range does not match size of requested axis" + assert mins.size(0) == input.size(axis), ( + "Dimensions of resultant channel range does not match size of requested axis" + ) return maxs - mins diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index 9673318d3c70ab..d12c96f66c0092 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -56,19 +56,19 @@ def __init__( self.scale = Parameter(torch.tensor([scale])) self.zero_point = Parameter(torch.tensor([zero_point])) else: - assert ( - isinstance(channel_len, int) and channel_len > 0 - ), "Channel size must be a positive integer." + assert isinstance(channel_len, int) and channel_len > 0, ( + "Channel size must be a positive integer." + ) self.scale = Parameter(torch.tensor([scale] * channel_len)) self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) self.activation_post_process = observer(**observer_kwargs) - assert ( - torch.iinfo(self.activation_post_process.dtype).min <= quant_min - ), "quant_min out of bound" - assert ( - quant_max <= torch.iinfo(self.activation_post_process.dtype).max - ), "quant_max out of bound" + assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( + "quant_min out of bound" + ) + assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( + "quant_max out of bound" + ) self.dtype = self.activation_post_process.dtype self.qscheme = self.activation_post_process.qscheme self.ch_axis = ( @@ -145,7 +145,7 @@ def observe_quant_params(self): print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] scale = self.scale.detach() zero_point = ( diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 60f2fe86b12e41..781bfdc8b39290 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -165,9 +165,7 @@ def _get_binary_op_configs( ) # matmul binary_op_configs.append( - BackendPatternConfig(torch.matmul).set_dtype_configs( - dtype_configs - ) # noqa: E131 + BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) # noqa: E131 ) return binary_op_configs @@ -483,16 +481,12 @@ def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf ln_configs = [] ln_configs.append( BackendPatternConfig(torch.nn.LayerNorm) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 .set_dtype_configs(dtype_configs) ) ln_configs.append( BackendPatternConfig(torch.nn.functional.layer_norm) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 2, "bias": 3}) ) @@ -518,27 +512,21 @@ def _get_default_op_configs( ] configs = [ BackendPatternConfig(op) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 .set_dtype_configs(dtype_configs) for op in default_ops ] configs.append( BackendPatternConfig(torch.nn.functional.group_norm) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 2, "bias": 3}) ) configs.append( BackendPatternConfig(torch.nn.functional.instance_norm) - .set_observation_type( - ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 3, "bias": 4}) ) diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 33ebc91cfffd3a..3919b84da28088 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -272,13 +272,13 @@ def to_dict(self) -> dict[str, Any]: if self.input_dtype is not None: dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints if self.output_dtype is not None: - dtype_config_dict[ - OUTPUT_DTYPE_DICT_KEY - ] = self.output_dtype_with_constraints + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = ( + self.output_dtype_with_constraints + ) if self.weight_dtype is not None: - dtype_config_dict[ - WEIGHT_DTYPE_DICT_KEY - ] = self.weight_dtype_with_constraints + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = ( + self.weight_dtype_with_constraints + ) if self.bias_dtype is not None: dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype if self.is_dynamic is not None: @@ -671,23 +671,23 @@ def _get_dtype_config(obj: Any) -> DTypeConfig: for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): conf.add_dtype_config(_get_dtype_config(d)) conf.set_root_module( - backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) + backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type] ) - conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type] conf.set_reference_quantized_module( - backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) + backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type] ) conf.set_fused_module( - backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) + backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type] ) conf.set_fuser_method( - backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) + backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type] ) conf._set_root_node_getter( - backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) + backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type] ) conf._set_extra_inputs_getter( - backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) + backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type] ) conf._set_num_tensor_args_to_observation_type( backend_pattern_config_dict.get( @@ -719,31 +719,31 @@ def to_dict(self) -> dict[str, Any]: if self.qat_module is not None: backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module if self.reference_quantized_module is not None: - backend_pattern_config_dict[ - REFERENCE_QUANTIZED_MODULE_DICT_KEY - ] = self.reference_quantized_module + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = ( + self.reference_quantized_module + ) if self.fused_module is not None: backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module if self.fuser_method is not None: backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method if self._root_node_getter is not None: - backend_pattern_config_dict[ - ROOT_NODE_GETTER_DICT_KEY - ] = self._root_node_getter + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = ( + self._root_node_getter + ) if self._extra_inputs_getter is not None: - backend_pattern_config_dict[ - EXTRA_INPUTS_GETTER_DICT_KEY - ] = self._extra_inputs_getter + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = ( + self._extra_inputs_getter + ) if len(self._num_tensor_args_to_observation_type) > 0: backend_pattern_config_dict[ NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY ] = self._num_tensor_args_to_observation_type if len(self._input_type_to_index) > 0: - backend_pattern_config_dict[ - INPUT_TYPE_TO_INDEX_DICT_KEY - ] = self._input_type_to_index + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = ( + self._input_type_to_index + ) if self._pattern_complex_format is not None: - backend_pattern_config_dict[ - PATTERN_COMPLEX_FORMAT_DICT_KEY - ] = self._pattern_complex_format + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = ( + self._pattern_complex_format + ) return backend_pattern_config_dict diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 92f168e111454f..348cec62ea18a8 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -88,9 +88,9 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): >>> lr = nn.LeakyReLU(0.01) >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) """ - assert ( - linear.training == bn.training and bn.training == leaky_relu.training - ), "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + assert linear.training == bn.training and bn.training == leaky_relu.training, ( + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + ) if is_qat: raise NotImplementedError( @@ -200,9 +200,7 @@ def _conv_bn_add_extra_inputs_getter_left(add_pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format( - (add_op, nn.Conv2d, MatchAllNode) - ) # noqa: E131 + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_left) @@ -285,9 +283,7 @@ def _conv_bn_add_extra_inputs_getter_right(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format( - (add_op, MatchAllNode, nn.Conv2d) - ) # noqa: E131 + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_right) @@ -390,9 +386,7 @@ def _conv_bn_add_relu_extra_inputs_getter_left(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format( - (nn.ReLU, (add_op, nn.Conv2d, MatchAllNode)) - ) # noqa: E131 + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_relu_left) @@ -485,9 +479,7 @@ def _conv_bn_add_relu_extra_inputs_getter_right(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format( - (nn.ReLU, (add_op, MatchAllNode, nn.Conv2d)) - ) # noqa: E131 + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_relu_right) diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py index f9b0067acf286f..9992b22839b433 100644 --- a/torch/ao/quantization/experimental/adaround_fake_quantize.py +++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-decorators -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any import torch from torch.ao.quantization.fake_quantize import _is_symmetric_quant @@ -19,17 +21,15 @@ class AdaroundFakeQuantizer(FakeQuantize): zero_point: torch.Tensor V: torch.nn.Parameter - # pyre-fixme[3]: Return type must be annotated. def __init__( self, - observer=MinMaxObserver, - qscheme=torch.per_tensor_symmetric, # not used, but needed for fakequant + observer: type = MinMaxObserver, + qscheme: torch.qscheme = torch.per_tensor_symmetric, # not used, but needed for fakequant quant_min: int = -128, quant_max: int = 127, ch_axis: int = 0, - # pyre-fixme[2]: Parameter must be annotated. - **observer_kwargs, - ): + **observer_kwargs: Any, + ) -> None: super().__init__( observer=observer, qscheme=qscheme, @@ -40,11 +40,10 @@ def __init__( ) # Populate quant_min/quant_max to observer_kwargs if valid if quant_min is not None and quant_max is not None: - assert ( - quant_min <= quant_max - ), "quant_min must be less than or equal to quant_max" - # pyre-fixme[4]: Attribute must be annotated. - self.qscheme = qscheme + assert quant_min <= quant_max, ( + "quant_min must be less than or equal to quant_max" + ) + self.qscheme: torch.qscheme = qscheme self.is_per_tensor: bool = is_per_tensor(qscheme) self.is_symmetric: bool = _is_symmetric_quant(qscheme) assert self.is_symmetric, "Only symmetric quantization is supported" @@ -106,9 +105,9 @@ def update_scale( X_q = X / self.scale X_q_floor = torch.floor(X_q) residual = X_q - X_q_floor # [0,1) - assert torch.all( - torch.ge(residual, 0) - ), "residual should be non-negative [0, 1)" + assert torch.all(torch.ge(residual, 0)), ( + "residual should be non-negative [0, 1)" + ) V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1) self.V.data = V_init @@ -117,8 +116,9 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: X_detached = X.detach() self.activation_post_process(X_detached) _scale, _zero_point = self.activation_post_process.calculate_qparams() - _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( - self.zero_point.device + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), ) dims = list(range(X.dim())) if not self.is_per_tensor: diff --git a/torch/ao/quantization/experimental/adaround_loss.py b/torch/ao/quantization/experimental/adaround_loss.py index 3fcf32b086a9d6..9b0ce6a32f14d4 100644 --- a/torch/ao/quantization/experimental/adaround_loss.py +++ b/torch/ao/quantization/experimental/adaround_loss.py @@ -37,9 +37,9 @@ def rounding_regularization( Major logics copied from official Adaround Implementation. Apply rounding regularization to the input tensor V. """ - assert ( - curr_iter < self.max_iter - ), "Current iteration strictly les sthan max iteration" + assert curr_iter < self.max_iter, ( + "Current iteration strictly les sthan max iteration" + ) if curr_iter < self.warm_start * self.max_iter: return torch.tensor(0.0) else: @@ -54,7 +54,7 @@ def rounding_regularization( 1 + np.cos(rel_iter * np.pi) ) - # A rectified sigmoid for soft-quantization as formualted [23] in https://arxiv.org/pdf/2004.10568.pdf + # A rectified sigmoid for soft-quantization as formulated [23] in https://arxiv.org/pdf/2004.10568.pdf h_alpha = torch.clamp( torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA, min=0, diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index be724458199025..fd2d8124bb7012 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -107,7 +107,7 @@ def get_data_inp_out( ) if torch.cuda.is_available(): # Somehow, we need to move the model continuously - # Otherwise, the model will be lowered to CPU misteriously + # Otherwise, the model will be lowered to CPU mysteriously self.model = self.model.cuda() self.q_model = self.q_model.cuda() for data_ in data: @@ -186,9 +186,9 @@ def optimize_adaptive_rounding( inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) print("==================== Before adaround ====================") - assert ( - torch.abs(out[0] - module(fp_in[0])).sum().item() == 0 - ), "In-placed activation is detected, please do not use activation in-placed" + assert torch.abs(out[0] - module(fp_in[0])).sum().item() == 0, ( + "In-placed activation is detected, please do not use activation in-placed" + ) # Stack the tensors in each list into a single tensor # Assuming inp and out are your lists of tensors inp_tensor = torch.vstack(inp) diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py index 50fdcdb33ac28a..b18b5e133f1929 100644 --- a/torch/ao/quantization/experimental/fake_quantize.py +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -20,7 +20,9 @@ def __init__(self, observer: Callable = APoTObserver, **observer_kwargs: Any): self.activation_post_process = observer(**observer_kwargs) self.dtype = self.activation_post_process.dtype - def calculate_qparams(self, signed: bool = False) -> tuple[Tensor, Tensor, Tensor, Tensor]: # type: ignore[override] + def calculate_qparams( # type: ignore[override] + self, signed: bool = False + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.activation_post_process.calculate_qparams(signed=signed) def forward(self, X: torch.Tensor) -> Tensor: # type: ignore[override] diff --git a/torch/ao/quantization/experimental/fake_quantize_function.py b/torch/ao/quantization/experimental/fake_quantize_function.py index c9ad8058008d76..15722cf85a1844 100644 --- a/torch/ao/quantization/experimental/fake_quantize_function.py +++ b/torch/ao/quantization/experimental/fake_quantize_function.py @@ -27,6 +27,9 @@ def forward( # type: ignore[override] return result @staticmethod - def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: Tensor) -> Tensor: # type: ignore[override] + def backward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, + grad_output: Tensor, + ) -> Tensor: mask = ctx.saved_tensors # type: ignore[attr-defined] return grad_output * mask diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index 74d91e6733a190..c17008adcf6518 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -133,7 +133,7 @@ class FakeQuantize(FakeQuantizeBase): The output of this module is given by:: x_out = ( - clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point + clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point ) * scale * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization @@ -177,9 +177,9 @@ def __init__( super().__init__() # Populate quant_min/quant_max to observer_kwargs if valid if quant_min is not None and quant_max is not None: - assert ( - quant_min <= quant_max - ), "quant_min must be less than or equal to quant_max" + assert quant_min <= quant_max, ( + "quant_min must be less than or equal to quant_max" + ) dtype = observer_kwargs.get("dtype", torch.quint8) if hasattr(observer, "p"): # In case observer is _PartialWrapper, dtype can be stored in @@ -218,15 +218,16 @@ def __init__( self.is_per_channel = _is_per_channel(self.qscheme) @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] return self.activation_post_process.calculate_qparams() def forward(self, X): if self.observer_enabled[0] == 1: self.activation_post_process(X.detach()) _scale, _zero_point = self.calculate_qparams() - _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( - self.zero_point.device + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), ) if self.scale.shape != _scale.shape: self.scale.resize_(_scale.shape) @@ -328,9 +329,9 @@ class FixedQParamsFakeQuantize(FakeQuantize): # TODO: rename observer to observer_ctr def __init__(self, observer): super().__init__(observer=observer) - assert ( - type(self.activation_post_process) == FixedQParamsObserver - ), f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + assert type(self.activation_post_process) == FixedQParamsObserver, ( + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + ) self._observer_ctr = observer self.scale = self.activation_post_process.scale self.zero_point = self.activation_post_process.zero_point @@ -341,7 +342,7 @@ def __init__(self, observer): ) @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] return self.scale, self.zero_point @torch.jit.export @@ -384,7 +385,9 @@ def __init__( assert isinstance( self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), - ), "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ), ( + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ) self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) self.is_symmetric_quant = _is_symmetric_quant( diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 20232c5dd4d7b8..260bbee37bd2bd 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -34,9 +34,9 @@ def fuse_conv_bn(is_qat, conv, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn(m1, b1) """ - assert ( - conv.training == bn.training - ), "Conv and BN both must be in the same mode (train or eval)." + assert conv.training == bn.training, ( + "Conv and BN both must be in the same mode (train or eval)." + ) fused_module_class_map = { nn.Conv1d: nni.ConvBn1d, @@ -45,13 +45,13 @@ def fuse_conv_bn(is_qat, conv, bn): } if is_qat: - assert ( - bn.num_features == conv.out_channels - ), "Output channel of Conv2d must match num_features of BatchNorm2d" + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv2d must match num_features of BatchNorm2d" + ) assert bn.affine, "Only support fusing BatchNorm2d with affine set to True" - assert ( - bn.track_running_stats - ), "Only support fusing BatchNorm2d with tracking_running_stats set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) fused_module_class = fused_module_class_map.get((type(conv)), None) if fused_module_class is not None: return fused_module_class(conv, bn) @@ -80,9 +80,9 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn_relu(m1, b1, r1) """ - assert ( - conv.training == bn.training == relu.training - ), "Conv and BN both must be in the same mode (train or eval)." + assert conv.training == bn.training == relu.training, ( + "Conv and BN both must be in the same mode (train or eval)." + ) fused_module: Optional[type[nn.Sequential]] = None if is_qat: map_to_fused_module_train = { @@ -90,13 +90,13 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): nn.Conv2d: nni.ConvBnReLU2d, nn.Conv3d: nni.ConvBnReLU3d, } - assert ( - bn.num_features == conv.out_channels - ), "Output channel of Conv must match num_features of BatchNorm" + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv must match num_features of BatchNorm" + ) assert bn.affine, "Only support fusing BatchNorm with affine set to True" - assert ( - bn.track_running_stats - ), "Only support fusing BatchNorm with tracking_running_stats set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm with tracking_running_stats set to True" + ) fused_module = map_to_fused_module_train.get(type(conv), None) if fused_module is not None: return fused_module(conv, bn, relu) @@ -133,18 +133,18 @@ def fuse_linear_bn(is_qat, linear, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_linear_bn(m1, b1) """ - assert ( - linear.training == bn.training - ), "Linear and BN both must be in the same mode (train or eval)." + assert linear.training == bn.training, ( + "Linear and BN both must be in the same mode (train or eval)." + ) if is_qat: - assert ( - bn.num_features == linear.out_features - ), "Output features of Linear must match num_features of BatchNorm1d" + assert bn.num_features == linear.out_features, ( + "Output features of Linear must match num_features of BatchNorm1d" + ) assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" - assert ( - bn.track_running_stats - ), "Only support fusing BatchNorm1d with tracking_running_stats set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + ) return nni.LinearBn1d(linear, bn) else: return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) @@ -166,9 +166,9 @@ def fuse_convtranspose_bn(is_qat, convt, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_convtranspose_bn(m1, b1) """ - assert ( - convt.training == bn.training - ), "ConvTranspose and BN both must be in the same mode (train or eval)." + assert convt.training == bn.training, ( + "ConvTranspose and BN both must be in the same mode (train or eval)." + ) if is_qat: raise Exception( # noqa: TRY002 diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md index c41fd51ff6f38c..cd380977b2aa59 100644 --- a/torch/ao/quantization/fx/README.md +++ b/torch/ao/quantization/fx/README.md @@ -296,7 +296,7 @@ BackendConfig(nniqat.LinearReLU) Pattern in this case is the same as before, it defines the pattern for the subgraph we are dealing with -`set_observation_type`: sets the observation type for the patter, currently only two types: +`set_observation_type`: sets the observation type for the pattern, currently only two types: `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` means the output observer instance will be different from the input, which is the most common type of observer placement. diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index da44665a53399d..1c4517b93c7fa3 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -72,9 +72,9 @@ def quantize_per_tensor( """ if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) _quant_min_max_bounds_check(quant_min, quant_max, dtype) inv_scale = 1.0 / scale @@ -94,9 +94,9 @@ def quantize_per_tensor_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) return torch.empty_like(input, dtype=dtype) @@ -122,14 +122,19 @@ def quantize_per_tensor_tensor( Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return quantize_per_tensor( - input, scale.item(), zero_point.item(), quant_min, quant_max, dtype # type: ignore[arg-type] + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, ) @@ -144,15 +149,15 @@ def quantize_per_tensor_tensor_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) return torch.empty_like(input, dtype=dtype) @@ -179,12 +184,12 @@ def quantize_per_tensor_tensor2( Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return quantize_per_tensor( input, scale.item(), @@ -205,7 +210,12 @@ def quantize_per_tensor_tensor2_meta( dtype: torch.dtype, ) -> torch.Tensor: return quantize_per_tensor_tensor_meta( - input, scale, zero_point, quant_min, quant_max, dtype # type: ignore[arg-type] + input, + scale, + zero_point, # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, ) @@ -256,9 +266,9 @@ def dequantize_per_tensor( Returns: dequantized float32 Tensor """ - assert ( - input.dtype == dtype - ), f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + assert input.dtype == dtype, ( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 if dtype in _DTYPE_TO_QVALUE_BOUNDS: @@ -312,12 +322,12 @@ def dequantize_per_tensor_tensor( Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return dequantize_per_tensor( input, scale.item(), @@ -342,12 +352,12 @@ def dequantize_per_tensor_tensor_meta( ) -> torch.Tensor: if out_dtype is None: out_dtype = torch.float32 - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" if dtype in _DTYPE_TO_QVALUE_BOUNDS: return torch.empty_like(input, dtype=out_dtype) @@ -382,12 +392,12 @@ def dequantize_per_tensor_tensor2( Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert ( - zero_point.numel() == 1 - ), f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert ( - scale.numel() == 1 - ), f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert zero_point.numel() == 1, ( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + assert scale.numel() == 1, ( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return dequantize_per_tensor( input, scale.item(), @@ -442,10 +452,12 @@ def choose_qparams_tensor( torch.float32, torch.float16, torch.bfloat16, - ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - assert ( - dtype in _DTYPE_TO_QVALUE_BOUNDS - ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) validate_qmin_qmax(qmin, qmax) min_val, max_val = torch.aminmax(input) @@ -492,10 +504,12 @@ def choose_qparams_symmetric_tensor( torch.float32, torch.float16, torch.bfloat16, - ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - assert ( - dtype in _DTYPE_TO_QVALUE_BOUNDS - ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) validate_qmin_qmax(qmin, qmax) min_val, max_val = torch.aminmax(input) @@ -519,11 +533,13 @@ def choose_qparams_tensor_meta( torch.float32, torch.float16, torch.bfloat16, - ], f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - assert ( - quant_min < quant_max - ), f"Expecting quant_min to be smaller than quant_max but received min: \ + ], ( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + assert quant_min < quant_max, ( + f"Expecting quant_min to be smaller than quant_max but received min: \ {quant_min} max: {quant_max}" + ) return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( 1, dtype=torch.int64, device=input.device ) @@ -582,9 +598,9 @@ def quantize_per_channel( """ if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" _quant_min_max_bounds_check(quant_min, quant_max, dtype) input, permute_axis_list = _permute_to_axis_zero(input, axis) @@ -613,9 +629,9 @@ def quantize_per_channel_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" _quant_min_max_bounds_check(quant_min, quant_max, dtype) return torch.empty_like(input, dtype=dtype) @@ -671,9 +687,9 @@ def dequantize_per_channel( Returns: dequantized float32 Tensor """ - assert ( - input.dtype == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + assert input.dtype == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" @@ -706,9 +722,9 @@ def dequantize_per_channel_meta( *, out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - assert ( - input.dtype == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + assert input.dtype == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" @@ -863,12 +879,12 @@ def choose_qparams_per_token_asymmetric_meta( def _per_token_quant_qparam_dim_check(input, scales, zero_points): num_tokens = math.prod(list(input.size())[:-1]) - assert ( - num_tokens == scales.numel() - ), f"num_tokens: {num_tokens} scales: {scales.size()}" - assert ( - num_tokens == zero_points.numel() - ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + assert num_tokens == scales.numel(), ( + f"num_tokens: {num_tokens} scales: {scales.size()}" + ) + assert num_tokens == zero_points.numel(), ( + f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + ) quantized_decomposed_lib.define( @@ -1138,9 +1154,9 @@ def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): scales = scales.to(torch.float32) if zero_points.dtype != torch.int32: zero_points = zero_points.to(torch.int32) - assert ( - input.dtype == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert input.dtype == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim)) unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 77bc4e31d19912..822d261ffc3282 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -360,9 +360,11 @@ def get_op_node_and_weight_eq_obs( model, "equalization_node_name_to_qconfig" ) assert maybe_equalization_node_name_to_config is not None - equalization_node_name_to_qconfig: dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment] + equalization_node_name_to_qconfig: dict[str, Any] = ( + maybe_equalization_node_name_to_config # type: ignore[assignment] + ) assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None - weight_eq_obs = equalization_node_name_to_qconfig.get( + weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr] op_node.name, None ).weight() @@ -843,7 +845,7 @@ def convert_eq_obs( # Erase the weight equalization observer node prev_node = weight_eq_obs_node.args[0] - remove_node(model, weight_eq_obs_node, prev_node) + remove_node(model, weight_eq_obs_node, prev_node) # type: ignore[arg-type] else: raise ValueError( "Expected operation node to be 'call_module' or 'call_function" diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index b48fe4630661b6..eeaad6b8afccc2 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -28,7 +28,7 @@ ) -QOP_TO_ARG_NAMES_TO_SKIP = { +QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = { torch._ops.ops.quantized.hardswish: ["inplace"], torch._ops.ops.quantized.elu: ["inplace"], torch._ops.ops.quantized.dropout: ["inplace"], @@ -435,7 +435,9 @@ def _load_packed_weight( ): attrs_to_pop = [] for attr_name in state_dict: - if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject): # type: ignore[attr-defined] # noqa: B950 + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 setattr(self, attr_name, state_dict[attr_name]) attrs_to_pop.append(attr_name) @@ -613,9 +615,9 @@ def _match_static_pattern( # (2) There must be at least one dequantize node matched_dequantize = False for i in dequantize_node_arg_indices: - assert i < len( - ref_node.args - ), f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + assert i < len(ref_node.args), ( + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + ) arg = ref_node.args[i] if is_dequantize_node(arg): matched_dequantize = True @@ -700,7 +702,11 @@ def _lower_static_weighted_ref_module( STATIC_LOWER_FUSED_MODULE_MAP.keys() ) q_node, _relu_node, ref_node = _match_static_pattern( - n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0] # type: ignore[arg-type] + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + dequantize_node_arg_indices=[0], ) if q_node is None: continue @@ -757,7 +763,10 @@ def _lower_static_weighted_ref_module_with_two_inputs( # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys()) (q_node, ref_node) = _match_static_pattern_with_two_inputs( - n, modules, qconfig_map, matching_modules # type: ignore[arg-type] + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] ) if q_node is None: continue diff --git a/torch/ao/quantization/fx/_model_report/README.md b/torch/ao/quantization/fx/_model_report/README.md index 858b271ef56f21..fa4f142aa23cfa 100644 --- a/torch/ao/quantization/fx/_model_report/README.md +++ b/torch/ao/quantization/fx/_model_report/README.md @@ -8,10 +8,10 @@ ModelReport Most detectors require a **traceable GraphModule**, but some (ex. `PerChannelDetector`) require just an `nn.Module`. #### Typical Fx Workflow -- Initialize model → Prepare model → Callibrate model → Convert model → ... +- Initialize model → Prepare model → Calibrate model → Convert model → ... #### Fx Workflow with ModelReport -- Initialize model → Prepare model → **Add detector observers** → Callibrate model → **Generate report** → **Remove detector observers** → Convert model → ... +- Initialize model → Prepare model → **Add detector observers** → Calibrate model → **Generate report** → **Remove detector observers** → Convert model → ... > ⚠️ **You can only prepare and remove observers once with a given ModelReport Instance**: Be very careful here! @@ -23,7 +23,7 @@ This snippet should be ready to copy, paste, and use with the exception of a few # prep model qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() model = Model() # TODO define model -example_input = torch.randn((*args)) # TODO get example data for callibration +example_input = torch.randn((*args)) # TODO get example data for calibration prepared_model = quantize_fx.prepare_fx(model, qconfig_mapping, example_input) # create ModelReport instance and insert observers @@ -31,8 +31,8 @@ detector_set = set([DynamicStaticDetector()]) # TODO add all desired detectors model_report = ModelReport(model, detector_set) ready_for_callibrate = model_report.prepare_detailed_callibration() -# callibrate model and generate report -ready_for_callibrate(example_input) # TODO run callibration of model with relevant data +# calibrate model and generate report +ready_for_callibrate(example_input) # TODO run calibration of model with relevant data reports = model_report.generate_model_report(remove_inserted_observers=True) for report_name in report.keys(): text_report, report_dict = reports[report_name] @@ -46,7 +46,7 @@ mod_rep_visualizer.generate_table_visualization() # shows collected data as a ta ``` There is a tutorial in the works that will walk through a full usage of the ModelReport API. -This tutorial will show the ModelReport API being used on toy model in both an Fx Graph Mode workflow and an alterative workflow with just a traceable model. +This tutorial will show the ModelReport API being used on toy model in both an Fx Graph Mode workflow and an alternative workflow with just a traceable model. This README will be updated with a link to the tutorial upon completion of the tutorial. # Key Modules Overview @@ -60,7 +60,7 @@ There are three primary methods to be familiar with when using the ModelReport c This is so that we can keep track of where we want to insert observers on a detector by detector basis and also keep track of which detectors to generate reports for. - `prepare_detailed_calibration(self)` → `GraphModule` inserts observers into the locations specified by each detector in the model. It then returns the GraphModule with the detectors inserted into both the regular module structure as well as the node structure. -- `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses callibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: +- `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses calibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: - A string-based report that is easily digestable and actionable explaining the data collected by relevant observers for that detector - A dictionary containing statistics collected by the relevant observers and values calculated by the detector for further analysis or plotting @@ -107,7 +107,7 @@ For both of the two things listed above, you can filter the data by either `modu To get a list of all the modules or features, you can call `mod_rep_visualizer.get_all_unique_module_fqns()` and `mod_rep_visualizer.get_all_unique_feature_names()` respectively. For the features, because some features are not plottable, you can set the flag to only get plottable features -in the aformentioned `get_all_unique_feature_names` method. +in the aforementioned `get_all_unique_feature_names` method. ## Detector Overview @@ -152,7 +152,7 @@ The statistics collected by the `ModelReportObserver` include: - Ratio of 100th percentile to some *n*th percentile - Number of constant value batches to pass through each channel -After the `ModelReportObserver` collects the statistics above during the callibration process, the detectors then extract the information they need to generate their reports from the relevant observers. +After the `ModelReportObserver` collects the statistics above during the calibration process, the detectors then extract the information they need to generate their reports from the relevant observers. ### Using Your Own Observer diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 351f88e43aa505..4625e287011c4f 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -456,6 +456,7 @@ class DynamicStaticDetector(DetectorBase): Args: tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5 """ + # names for the pre and post observers that are inserted DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer" DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer" @@ -1158,9 +1159,9 @@ def _generate_comparison_values( input_channels = len(input_ratio) if weight_channels != input_channels: # we try to replicate - assert ( - input_channels % weight_channels == 0 - ), "input channels should be divisible by weight channels." + assert input_channels % weight_channels == 0, ( + "input channels should be divisible by weight channels." + ) # get replication factor rep_factor: int = input_channels // weight_channels diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index e76a2bf06f6647..04035b41bf68eb 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -36,7 +36,7 @@ class ModelReport: - Suggestions for outlier detection for all layers (Graph Modules) The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) - where needed for each detector to gather the information it needs, and then after callibration, the ModelReport + where needed for each detector to gather the information it needs, and then after calibration, the ModelReport class compiles the report generated by each Detector class into a single report to return to the user. It also has the capability to remove all the observers it inserted as well. @@ -70,7 +70,7 @@ class compiles the report generated by each Detector class into a single report 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model 2.) Prepare your model with prepare_fx 3.) Call model_report.prepare_detailed_calibration to add relevant observers - 4.) Callibrate your model with data + 4.) Calibrate your model with data 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers Optional 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance @@ -84,7 +84,9 @@ class compiles the report generated by each Detector class into a single report >>> # xdoctest: +SKIP >>> # get the necessary qconfig >>> config = PrepareCustomConfig() - >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False) + >>> skipped_module_names, skipped_module_classes = ( + ... get_skipped_module_name_and_classes(config, False) + ... ) >>> # initialize our model and get GraphModule >>> model = SomeModel() @@ -92,17 +94,24 @@ class compiles the report generated by each Detector class into a single report >>> graph_module = GraphModule(model, tracer.trace(model)) >>> # get our set of detectors and ModelReport instance - >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)]) + >>> detector_set = set( + ... [ + ... DynamicStaticDetector(tolerance=0.5), + ... InputWeightEqualizationDetector(ratio_threshold=0.7), + ... ] + ... ) >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) - >>> # now we insert the observers and callibrate the model + >>> # now we insert the observers and calibrate the model >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() >>> for i in range(num_callibration_batches): >>> example_input = get_callibration_input() >>> tracer_model_with_observers(example_input) >>> # finally we generate the reports and optionally remove the observers we inserted - >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True) + >>> reports = tracer_reporter.generate_model_report( + ... remove_inserted_observers=True + ... ) >>> # Optional: we can generate the qconfig mapping based on the suggestions >>> qconfigs = model_report.generate_qconfig_mapping() @@ -170,7 +179,7 @@ def prepare_detailed_calibration(self) -> GraphModule: # if already prepared once, cannot prepare again if self._prepared_flag: raise ValueError( - "Already ran preparing detailed callibration. Run the report generation next after callibration." + "Already ran preparing detailed calibration. Run the report generation next after calibration." ) # loop through each detector, find where placements should be, and keep track @@ -262,7 +271,7 @@ def generate_model_report( Generates all the requested reports. Note: - You should have callibrated the model with relevant data before calling this + You should have calibrated the model with relevant data before calling this The reports generated are specified by the desired_reports specified in desired_reports @@ -277,12 +286,12 @@ def generate_model_report( Note: Throws exception if we try to generate report on model we already removed observers from - Throws exception if we try to generate report without preparing for callibration + Throws exception if we try to generate report without preparing for calibration """ - # if we haven't prepped model for callibration, then we shouldn't generate report yet + # if we haven't prepped model for calibration, then we shouldn't generate report yet if not self._prepared_flag: raise Exception( # noqa: TRY002 - "Cannot generate report without preparing model for callibration" + "Cannot generate report without preparing model for calibration" ) # if we already removed the observers, we cannot generate report @@ -537,12 +546,12 @@ def _generate_module_fqn_to_detector_info_mapping( Note: Throws exception if we try to generate mapping on model we already removed observers from - Throws exception if we try to generate mapping without preparing for callibration + Throws exception if we try to generate mapping without preparing for calibration """ - # if we haven't prepped model for callibration, then we shouldn't generate mapping yet + # if we haven't prepped model for calibration, then we shouldn't generate mapping yet if not self._prepared_flag: raise Exception( # noqa: TRY002 - "Cannot generate report without preparing model for callibration" + "Cannot generate report without preparing model for calibration" ) # if we already removed the observers, we cannot mapping @@ -591,7 +600,7 @@ def generate_qconfig_mapping(self) -> QConfigMapping: Note: Throws exception if we try to generate mapping on model we already removed observers from - Throws exception if we try to generate mapping without preparing for callibration + Throws exception if we try to generate mapping without preparing for calibration """ # get the mapping info detector_qconfig_info_combined = ( diff --git a/torch/ao/quantization/fx/_model_report/model_report_observer.py b/torch/ao/quantization/fx/_model_report/model_report_observer.py index db9c130606a820..a809dc60838e57 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_observer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -279,7 +279,7 @@ def reset_batch_and_epoch_values(self): self.constant_channels = torch.tensor([], device=device) @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] raise Exception( # noqa: TRY002 "calculate_qparams should not be called for ModelReportObserver" ) diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index c8699813f2d181..63d31171bbe76f 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -63,7 +63,7 @@ class ModelReportVisualizer: 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects 2.) Prepare your model with prepare_fx 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers - 4.) Callibrate your model with data + 4.) Calibrate your model with data 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance 7.) Use instance to view different views of data as desired, applying filters as needed @@ -338,9 +338,8 @@ def generate_filtered_tables( Example Use: >>> # xdoctest: +SKIP("undefined variables") >>> mod_report_visualizer.generate_filtered_tables( - ... feature_filter = "per_channel_min", - ... module_fqn_filter = "block1" - ... ) # generates table with per_channel_min info for all modules in block 1 of the model + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) # generates table with per_channel_min info for all modules in block 1 of the model """ # first get the filtered data filtered_data: OrderedDict[str, Any] = self._get_filtered_data( @@ -427,8 +426,7 @@ def generate_table_visualization( Example Use: >>> # xdoctest: +SKIP("undefined variables") >>> mod_report_visualizer.generate_table_visualization( - ... feature_filter = "per_channel_min", - ... module_fqn_filter = "block1" + ... feature_filter="per_channel_min", module_fqn_filter="block1" ... ) >>> # prints out neatly formatted table with per_channel_min info >>> # for all modules in block 1 of the model @@ -590,8 +588,7 @@ def generate_plot_visualization( Example Use: >>> # xdoctest: +SKIP("undefined variables") >>> mod_report_visualizer.generate_plot_visualization( - ... feature_filter = "per_channel_min", - ... module_fqn_filter = "block1" + ... feature_filter="per_channel_min", module_fqn_filter="block1" ... ) >>> # outputs line plot of per_channel_min information for all >>> # modules in block1 of model each channel gets it's own line, @@ -664,8 +661,7 @@ def generate_histogram_visualization( Example Use: >>> # xdoctest: +SKIP >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( - ... feature_filter = "per_channel_min", - ... module_fqn_filter = "block1" + ... feature_filter="per_channel_min", module_fqn_filter="block1" ... ) # outputs histogram of per_channel_min information for all modules in block1 of model information is gathered across all channels for all modules in block 1 for the diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 457e03f6609012..9513fb288850b5 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -509,9 +509,9 @@ def _replace_observer_or_dequant_stub_with_dequantize_node( node: Node, graph: Graph ) -> None: call_custom_module_node = node.args[0] - assert isinstance( - call_custom_module_node, Node - ), f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + assert isinstance(call_custom_module_node, Node), ( + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + ) node.replace_all_uses_with(call_custom_module_node) graph.erase_node(node) _insert_dequantize_node(call_custom_module_node, graph) @@ -604,9 +604,9 @@ def _get_module_path_and_prefix( # operator (they can be the same) # this flag identifies if the observer is inserted only because the observed node is # the input of the next operator - assert isinstance( - observed_node, Node - ), f"Expecting observed node to be a Node, but got {observed_node}" + assert isinstance(observed_node, Node), ( + f"Expecting observed node to be a Node, but got {observed_node}" + ) is_input_observer_only = ( node_name_to_qconfig[observed_node.name] is None if observed_node.name in node_name_to_qconfig @@ -865,9 +865,9 @@ def convert_weighted_module( ref_qmodule_cls = root_module_to_quantized_reference_module.get( type_before_parametrizations(float_module), None ) - assert ( - ref_qmodule_cls is not None - ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + assert ref_qmodule_cls is not None, ( + f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ) ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] if fused_module is not None: fused_module[0] = ref_qmodule # type: ignore[operator] @@ -887,9 +887,9 @@ def _remove_previous_dequantize_in_custom_module( \\ - dequantize """ # expecting the input node for a custom module node to be a Node - assert isinstance( - prev_node, Node - ), f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + assert isinstance(prev_node, Node), ( + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + ) if prev_node.op == "call_method" and prev_node.target == "dequantize": node.replace_input_with(prev_node, prev_node.args[0]) # Remove the dequantize node if it doesn't have other users @@ -1060,14 +1060,16 @@ def convert( assert _is_observed_module(model), "incoming model must be produced by prepare_fx" observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] - node_name_to_scope: dict[ - str, tuple[str, type] - ] = observed_graph_module_attrs.node_name_to_scope + node_name_to_scope: dict[str, tuple[str, type]] = ( + observed_graph_module_attrs.node_name_to_scope + ) prepare_custom_config: PrepareCustomConfig = ( observed_graph_module_attrs.prepare_custom_config ) observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names - node_name_to_qconfig: dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig # type: ignore[assignment] + node_name_to_qconfig: dict[str, QConfigAny] = ( + observed_graph_module_attrs.node_name_to_qconfig + ) # type: ignore[assignment] # mapping from fully qualified module name to module instance # for example, @@ -1083,14 +1085,18 @@ def convert( # TODO refactor this code once we update the prepare logic to have additional information on # which graph nodes have been observed and share that with convert to decide which observers to ignore. if qconfig_mapping: - prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping # type: ignore[assignment] + prepare_qconfig_mapping: QConfigMapping = ( + observed_graph_module_attrs.qconfig_mapping + ) # type: ignore[assignment] modules_copy = copy.deepcopy(modules) if observed_graph_module_attrs.is_qat: _update_qconfig_for_qat(qconfig_mapping, backend_config) _update_qconfig_for_fusion(model, qconfig_mapping) - _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type] + _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping, qconfig_mapping + ) # type: ignore[arg-type] convert_node_name_to_qconfig = _generate_node_name_to_qconfig( model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope ) @@ -1098,9 +1104,9 @@ def convert( # all the values either match what was set in prepare node_name_to_qconfig # or are set to None in the convert_node_name_to_qconfig. for k, v in node_name_to_qconfig.items(): - assert ( - k in convert_node_name_to_qconfig - ), f"Expected key {k} in convert node_name_to_qconfig" + assert k in convert_node_name_to_qconfig, ( + f"Expected key {k} in convert node_name_to_qconfig" + ) if convert_node_name_to_qconfig[k] is not None: assert qconfig_equals(v, convert_node_name_to_qconfig[k]), ( f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 5301db9317fdc4..598c42ea22e3b2 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -355,9 +355,9 @@ def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): ) in self.float_to_observed_mapping.items(): if FLOAT_TO_OBSERVED_DICT_KEY not in d: d[FLOAT_TO_OBSERVED_DICT_KEY] = {} - d[FLOAT_TO_OBSERVED_DICT_KEY][ - _get_quant_type_to_str(quant_type) - ] = float_to_observed_mapping + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + float_to_observed_mapping + ) if len(self.non_traceable_module_names) > 0: d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names if len(self.non_traceable_module_classes) > 0: @@ -460,9 +460,9 @@ def to_dict(self) -> dict[str, Any]: ) in self.observed_to_quantized_mapping.items(): if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} - d[OBSERVED_TO_QUANTIZED_DICT_KEY][ - _get_quant_type_to_str(quant_type) - ] = observed_to_quantized_mapping + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + observed_to_quantized_mapping + ) if len(self.preserved_attributes) > 0: d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes return d @@ -474,7 +474,9 @@ class FuseCustomConfig: Example usage:: - fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + ["attr1", "attr2"] + ) """ def __init__(self) -> None: diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index b7a3c60d0dd58d..68a5a440a51284 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -64,9 +64,9 @@ def fuse( fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]], is_qat: bool, ) -> Node: - assert ( - root_node.op == "call_module" - ), "Expecting module node to be a call_module Node" + assert root_node.op == "call_module", ( + "Expecting module node to be a call_module Node" + ) root_module = named_modules[str(root_node.target)] def get_modules(pattern): diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py index 235292553d2296..15d8fc7852e0fb 100644 --- a/torch/ao/quantization/fx/graph_module.py +++ b/torch/ao/quantization/fx/graph_module.py @@ -175,7 +175,9 @@ def _load_from_state_dict( ): attrs_to_pop = [] for attr_name in state_dict: - if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject): # type: ignore[attr-defined] # noqa: B950 + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 setattr(self, attr_name, state_dict[attr_name]) attrs_to_pop.append(attr_name) diff --git a/torch/ao/quantization/fx/lstm_utils.py b/torch/ao/quantization/fx/lstm_utils.py index 83a234fd8e1b8b..fe18ba465212f9 100644 --- a/torch/ao/quantization/fx/lstm_utils.py +++ b/torch/ao/quantization/fx/lstm_utils.py @@ -81,14 +81,14 @@ def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: quantizable_lstm.qconfig = float_lstm.qconfig for idx in range(float_lstm.num_layers): - quantizable_lstm.layers[ - idx - ] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( - float_lstm, - idx, - float_lstm.qconfig, - batch_first=False, - split_gates=split_gates, + quantizable_lstm.layers[idx] = ( + torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( + float_lstm, + idx, + float_lstm.qconfig, + batch_first=False, + split_gates=split_gates, + ) ) # Build QConfigMapping for the LSTM cell diff --git a/torch/ao/quantization/fx/pattern_utils.py b/torch/ao/quantization/fx/pattern_utils.py index 551f68be424f27..e86f95d67aba09 100644 --- a/torch/ao/quantization/fx/pattern_utils.py +++ b/torch/ao/quantization/fx/pattern_utils.py @@ -47,9 +47,9 @@ def _register_quant_pattern(pattern, fixed_qparams_observer=None): def insert(fn): _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn if fixed_qparams_observer is not None: - _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[ - pattern - ] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = ( + FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + ) _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer return fn @@ -81,7 +81,7 @@ def get_default_output_activation_post_process_map( def _sorted_patterns_dict( - patterns_dict: dict[Pattern, QuantizeHandler] + patterns_dict: dict[Pattern, QuantizeHandler], ) -> dict[Pattern, QuantizeHandler]: """ Return a sorted version of the patterns dictionary such that longer patterns are matched first, diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index e6fb3cda3bcfbf..b1b2c6b05b33eb 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -117,7 +117,7 @@ def _get_observer_kwargs( - quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec] + quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec], ): kwargs_dict = asdict(quant_spec) return copy.deepcopy(kwargs_dict) @@ -213,9 +213,9 @@ def _needs_obs_or_fq( # need to insert placeholder observer for dynamic quantization so that it can # be converted to choose_qparams -> q -> dq in convert step if cur_target_is_dynamic: - assert ( - cur_target_dtype in _OBS_DTYPE_LIST - ), f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + assert cur_target_dtype in _OBS_DTYPE_LIST, ( + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + ) assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST return is_zeroth_arg if reuse_input_obs_or_fq: @@ -695,9 +695,9 @@ def _get_output_act_obs_or_fq( ) elif _is_activation_post_process_node(arg, named_modules): observed_arg = arg.args[0] - assert isinstance( - observed_arg, Node - ), "Currently we only support observing Node" + assert isinstance(observed_arg, Node), ( + "Currently we only support observing Node" + ) if "quantization_annotation" in observed_arg.meta: output_act_obs_or_fq = _create_obs_or_fq_from_qspec( observed_arg.meta["quantization_annotation"].output_qspec, @@ -935,8 +935,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] if ( type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) - and maybe_obs_mod.dtype - == arg_as_input_target_dtype # type: ignore[possibly-undefined] + and maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined] ): arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] existing_obs_node = maybe_obs_node @@ -1108,7 +1107,7 @@ def _maybe_insert_output_observer_for_node( ) target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) # uncomment after we support reuse_input_obs_or_fq properly by having separate - # implemntations for this key instead of reusing the input_output_share_observers + # implementations for this key instead of reusing the input_output_share_observers # code # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) # for now we set this to False since reuse_input_obs_or_fq for @@ -1118,7 +1117,7 @@ def _maybe_insert_output_observer_for_node( reuse_input_obs_or_fq = False # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False - # because the prev_output is the output of an fp32 op, althought technically + # because the prev_output is the output of an fp32 op, although technically # we should get the dtype of the output from node.meta["val"] in the future # if we deprecate fx graph mode quantization needs_obs_or_fq = _needs_obs_or_fq( @@ -1502,7 +1501,7 @@ def insert_observers_for_model( # first, populate the dtype map based only on qconfig and qhandler # this assumes: - # graph inputs are fp32 by default, and int8 where overriden + # graph inputs are fp32 by default, and int8 where overridden # other nodes output dtype is specified by the qconfig named_modules = dict(model.named_modules(remove_duplicate=False)) @@ -1938,9 +1937,7 @@ def _run_prepare_fx_on_standalone_modules( ) standalone_module = named_modules[root_node.target] - prepare = ( - torch.ao.quantization.quantize_fx._prepare_standalone_module_fx - ) # type: ignore[attr-defined] + prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] observed_standalone_module = prepare( standalone_module, sm_qconfig_mapping, @@ -2005,7 +2002,7 @@ def prepare( same as input_quantized_idxs configuration provided for the standalone module standalone_module_output_quantized_idxs(List[Int]): a list of - indexs for the graph output that is quantized + indices for the graph output that is quantized same as input_quantized_idxs configuration provided for the standalone module """ @@ -2174,9 +2171,9 @@ def prepare( # converting List[int] to Tensor since module attribute is # Union[Tensor, Module] input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes - output_quantized_idxs: list[ - int - ] = prepare_custom_config.output_quantized_indexes + output_quantized_idxs: list[int] = ( + prepare_custom_config.output_quantized_indexes + ) observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] # inplace modification observed_graph_module_attrs.is_observed_standalone_module = True diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index ff45c15946dc01..421e6d4b8eba6b 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -229,7 +229,9 @@ def _compare_prepare_convert_qconfig_mappings( """ assert qconfig_equals( prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig - ), "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ), ( + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ) prepare_dicts: list[OrderedDict] = [ prepare_qconfig_mapping.object_type_qconfigs, prepare_qconfig_mapping.module_name_qconfigs, @@ -247,14 +249,16 @@ def _compare_prepare_convert_qconfig_mappings( ] for i in range(len(prepare_dicts)): for name in prepare_dicts[i].keys(): - assert ( - name in convert_dicts[i] - ), f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ + assert name in convert_dicts[i], ( + f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ when it was present in prepare" + ) assert convert_dicts[i][name] is None or qconfig_equals( prepare_dicts[i][name], convert_dicts[i][name] - ), f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ + ), ( + f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + ) def _is_qconfig_supported_by_dtype_configs( diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index ba8d779e1c020c..fb17d6b1641753 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -114,7 +114,7 @@ def node_arg_is_bias(node: Node, arg: Any) -> bool: def get_custom_module_class_keys( - custom_module_mapping: dict[QuantType, dict[type, type]] + custom_module_mapping: dict[QuantType, dict[type, type]], ) -> list[Any]: r"""Get all the unique custom module keys in the custom config dict e.g. @@ -190,7 +190,7 @@ def get_attr_name(i: int): def collect_producer_nodes(node: Node) -> Optional[list[Node]]: - r"""Starting from a target node, trace back until we hit inpu or + r"""Starting from a target node, trace back until we hit input or getattr node. This is used to extract the chain of operators starting from getattr to the target node, for example def forward(self, x): @@ -495,7 +495,9 @@ def _is_custom_module_lstm( """ mod = _get_module(node, named_modules) if qconfig is not None and qhandler is not None: - assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined] + assert isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ) # type: ignore[attr-defined] return ( isinstance(mod, torch.nn.LSTM) and activation_is_statically_quantized(qconfig) @@ -517,7 +519,9 @@ def _is_custom_module_mha( """ mod = _get_module(node, named_modules) if qconfig is not None and qhandler is not None: - assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined] + assert isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ) # type: ignore[attr-defined] return ( isinstance(mod, torch.nn.MultiheadAttention) and activation_is_statically_quantized(qconfig) diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 673d52e8924e1f..c2610fd3ca7f42 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -255,9 +255,11 @@ def __init__( torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams, - ), "Default Observer only works for per_tensor_affine, \ + ), ( + "Default Observer only works for per_tensor_affine, \ per_tensor_symmetric, per_channel_affine, \ per_channel_symmetric and per_channel_float_qparams quantization scheme" + ) _ALLOWED_DTYPES = ( torch.qint8, @@ -273,9 +275,9 @@ def __init__( torch.uint16, ) - assert ( - self.dtype in _ALLOWED_DTYPES - ), f"Default Observer only works for {_ALLOWED_DTYPES} data type" + assert self.dtype in _ALLOWED_DTYPES, ( + f"Default Observer only works for {_ALLOWED_DTYPES} data type" + ) self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) if self.has_customized_qrange: validate_qmin_qmax(quant_min, quant_max) @@ -331,12 +333,12 @@ def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: """ # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. - assert ( - quant_min <= 0 <= quant_max - ), "Used-specified quantization range must include 0." - assert ( - quant_min < quant_max - ), "qmin must be strictly less than qmax for user-specified quantization range." + assert quant_min <= 0 <= quant_max, ( + "Used-specified quantization range must include 0." + ) + assert quant_min < quant_max, ( + "qmin must be strictly less than qmax for user-specified quantization range." + ) @torch.jit.export def _calculate_qparams( @@ -356,7 +358,7 @@ def _calculate_qparams( # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code - # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. + # seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. # TODO(jakeszwe, jerryzh168) if not check_min_max_valid(min_val, max_val): return torch.tensor([1.0], device=min_val.device.type), torch.tensor( @@ -562,7 +564,7 @@ def forward(self, x_orig): return x_orig @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] r"""Calculates the quantization parameters.""" return self._calculate_qparams(self.min_val, self.max_val) @@ -785,7 +787,7 @@ def _forward(self, x_orig): return x_orig @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] return self._calculate_qparams(self.min_val, self.max_val) def extra_repr(self): @@ -1239,7 +1241,7 @@ def _combine_histograms( # If the orig hist only has one value (i.e., the min and max are the same) # we can just add it into new histogram if orig_min == orig_max: - bin_value = torch.sum(update_hist) + bin_value = torch.sum(orig_hist) transformed_orig_hist = ( torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] * bin_value @@ -1268,9 +1270,9 @@ def reset_histogram( self.min_val.copy_(min_val) self.max_val.resize_(max_val.shape) self.max_val.copy_(max_val) - assert ( - min_val.numel() == 1 and max_val.numel() == 1 - ), "histogram min/max values must be scalar." + assert min_val.numel() == 1 and max_val.numel() == 1, ( + "histogram min/max values must be scalar." + ) new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] self.histogram.detach_().resize_(new_histogram.shape) self.histogram.copy_(new_histogram) @@ -1305,7 +1307,10 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] # new_min and new_max should already have requires_grad set to False new_min, new_max = new_min.detach(), new_max.detach() update_histogram = torch.histc( - x, self.bins, min=new_min, max=new_max # type: ignore[arg-type] + x, + self.bins, + min=new_min, # type: ignore[arg-type] + max=new_max, # type: ignore[arg-type] ).to(self.histogram.device) if new_min == current_min and new_max == current_max: combined_histogram = self.histogram + update_histogram @@ -1330,7 +1335,7 @@ def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] return x_orig @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] is_uninitialized = self.min_val == float("inf") and self.max_val == float( "-inf" ) @@ -1443,7 +1448,7 @@ def forward(self, X): return X @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] return self.scale, self.zero_point @@ -1512,7 +1517,7 @@ def extra_repr(self): return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] raise Exception( # noqa: TRY002 "calculate_qparams should not be called for PlaceholderObserver" ) @@ -1539,7 +1544,7 @@ def forward(self, x): return x @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] raise Exception( # noqa: TRY002 "calculate_qparams should not be called for RecordingObserver" ) @@ -1572,7 +1577,7 @@ def forward(self, x): return x @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] raise Exception( # noqa: TRY002 "calculate_qparams should not be called for NoopObserver" ) @@ -1599,7 +1604,7 @@ def forward(self, x): return x @torch.jit.export - def calculate_qparams(self): + def calculate_qparams(self): # type: ignore[override] raise Exception( # noqa: TRY002 "calculate_qparams should not be called for ReuseInputObserver" ) @@ -1777,9 +1782,9 @@ def get_block_size( input_shape: The input tensor shape possibly more than 2 dimensions granularity: The granularity type of the quantization """ - assert isinstance( - granularity, Granularity - ), "Please provide an instance of Granularity, not subclass of it" + assert isinstance(granularity, Granularity), ( + "Please provide an instance of Granularity, not subclass of it" + ) if isinstance(granularity, PerTensor): return input_shape elif isinstance(granularity, PerAxis): @@ -1789,9 +1794,9 @@ def get_block_size( elif isinstance(granularity, PerRow): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) elif isinstance(granularity, PerGroup): - assert ( - len(input_shape) == 2 - ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + assert len(input_shape) == 2, ( + f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + ) return (1, granularity.group_size) elif isinstance(granularity, PerToken): block_size = [1] * len(input_shape) @@ -1861,16 +1866,16 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): Converts the observer node in the graph into its quantized representation Args: - model: graph module to conver the observer node in + model: graph module to convert the observer node in observer_node: the observer node to convert """ from torch.ao.quantization.fx.utils import create_getattr_from_value with model.graph.inserting_before(observer_node): assert self.block_size is not None, "Expecting block_size to be populated" - assert ( - self.original_dtype is not None - ), "Expecting original_dtype to be populated" + assert self.original_dtype is not None, ( + "Expecting original_dtype to be populated" + ) if hasattr(self, "is_dynamic") and self.is_dynamic: choose_qparams_affine = model.graph.call_function( torch.ops.pt2e_quant.choose_qparams_affine, diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py index 32b4a773f28f4d..e4eac6f6cc776f 100644 --- a/torch/ao/quantization/pt2e/_affine_quantization.py +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -1,6 +1,6 @@ # copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py # and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py -# PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC import logging from abc import ABCMeta from typing import Any, Optional, Union @@ -165,12 +165,12 @@ def decorator(fn): # expecting fn.__name__ starts with `_` and we want to take the rest # to be the name of the custom op - assert ( - fn.__name__[0] == "_" - ), f"Expecting function name starts with `_`, got {fn.__name__}" - assert not any( - c in fn.__name__ for c in ".<>" - ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + assert fn.__name__[0] == "_", ( + f"Expecting function name starts with `_`, got {fn.__name__}" + ) + assert not any(c in fn.__name__ for c in ".<>"), ( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) op_name = fn.__name__[1:] schema = op_name + infer_schema(fn, mutates_args={}) lib.define(schema) @@ -261,9 +261,9 @@ def _choose_qparams_affine( MappingType.ASYMMETRIC.name, ], f"Unsupported mapping type: {mapping_type}" if target_dtype in FP8_TYPES: - assert ( - mapping_type == MappingType.SYMMETRIC.name - ), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + assert mapping_type == MappingType.SYMMETRIC.name, ( + f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + ) if input is not None: if scale_dtype is None: @@ -273,9 +273,9 @@ def _choose_qparams_affine( if eps is None: eps = torch.finfo(input.dtype).eps - assert ( - len(block_size) == input.dim() - ), f"Got input dim:{input.dim()}, block_size: {block_size}" + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input.size() ) @@ -284,12 +284,12 @@ def _choose_qparams_affine( min_val = torch.amin(input, dim=reduction_dims, keepdim=False) max_val = torch.amax(input, dim=reduction_dims, keepdim=False) else: - assert ( - min_val is not None and max_val is not None - ), "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" - assert ( - min_val.dtype == max_val.dtype - ), "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + assert min_val is not None and max_val is not None, ( + "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + ) + assert min_val.dtype == max_val.dtype, ( + "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + ) if scale_dtype is None: scale_dtype = min_val.dtype @@ -351,9 +351,9 @@ def _choose_qparams_affine( zero_point = quant_min - torch.round(min_val_neg / scale) zero_point = torch.clamp(zero_point, quant_min, quant_max) else: - assert ( - zero_point_domain == ZeroPointDomain.FLOAT.name - ), "if not preserve_zero, zero_point must be in FLOAT domain" + assert zero_point_domain == ZeroPointDomain.FLOAT.name, ( + "if not preserve_zero, zero_point must be in FLOAT domain" + ) mid_point = (quant_max + quant_min + 1) / 2 zero_point = min_val_neg + scale * mid_point @@ -469,7 +469,7 @@ def _quantize_affine_no_dtype_cast( 1. figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape + 3. reshape the quantized result to original shape """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -478,9 +478,9 @@ def _quantize_affine_no_dtype_cast( torch.float16, torch.bfloat16, ], f"Unsupported input dtype: {input.dtype}" - assert ( - len(block_size) == input.dim() - ), f"Got input dim:{input.dim()}, block_size: {block_size}" + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input.size() ) @@ -498,15 +498,15 @@ def _quantize_affine_no_dtype_cast( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max ) elif zero_point_domain == ZeroPointDomain.NONE.name: - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is NONE" + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is NONE" + ) quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) elif zero_point_domain is None: # This case handles quantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is None" + ) quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name @@ -582,9 +582,9 @@ def _dequantize_affine( """op definition that has compatible signatures with custom op library""" # TODO: validate scale/zero_point dimensions are compatible with block_size if input_dtype not in _SUB_BYTE_UINT_BOUNDS: - assert ( - input.dtype == input_dtype - ), f"Expected: {input_dtype}, got: {input.dtype}" + assert input.dtype == input_dtype, ( + f"Expected: {input_dtype}, got: {input.dtype}" + ) assert output_dtype in [ torch.float32, torch.float16, @@ -619,11 +619,11 @@ def _dequantize_affine_no_dtype_check( 1. figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain - 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + 3. reshape the quantized result to original shape and change dtype to the output_dtype """ - assert ( - len(block_size) == input.dim() - ), f"Got input dim:{input.dim()}, block_size: {block_size}" + assert len(block_size) == input.dim(), ( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) shape_for_reduction, reduction_dims = _get_reduction_params( block_size, input.size() ) @@ -646,25 +646,25 @@ def _dequantize_affine_no_dtype_check( dequant = dequant.to(output_dtype) dequant = dequant * scale elif zero_point_domain == ZeroPointDomain.NONE.name: - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is NONE" + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is NONE" + ) dequant = input.to(output_dtype) dequant = dequant * scale elif zero_point_domain is None: # This case handles dequantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - assert _is_float8_type( - input.dtype - ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + assert zero_point is None, ( + "zero_point should be None when zero_point_domain is None" + ) + assert _is_float8_type(input.dtype), ( + f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + ) dequant = input.to(output_dtype) dequant = dequant * scale else: - assert ( - zero_point_domain == ZeroPointDomain.FLOAT.name - ), f"Unexpected zero point domain: {zero_point_domain}" + assert zero_point_domain == ZeroPointDomain.FLOAT.name, ( + f"Unexpected zero point domain: {zero_point_domain}" + ) # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) mid_point = (quant_max + quant_min + 1) / 2 # This should allocate new memory and avoid input modification @@ -697,12 +697,12 @@ def forward(self, input: torch.Tensor): self.min_val = min_val self.max_val = max_val else: - assert ( - self.min_val.shape == min_val.shape - ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" - assert ( - self.max_val.shape == max_val.shape - ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + assert self.min_val.shape == min_val.shape, ( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + assert self.max_val.shape == max_val.shape, ( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) min_val = torch.min(self.min_val, min_val) max_val = torch.max(self.max_val, max_val) self.min_val.copy_(min_val) @@ -711,9 +711,9 @@ def forward(self, input: torch.Tensor): return input def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: - assert hasattr(self, "min_val") and hasattr( - self, "max_val" - ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + assert hasattr(self, "min_val") and hasattr(self, "max_val"), ( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) return choose_qparams_affine_with_min_max( self.min_val, self.max_val, @@ -788,12 +788,12 @@ def forward(self, input: torch.Tensor): self.min_val = min_val self.max_val = max_val else: - assert ( - self.min_val.shape == min_val.shape - ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" - assert ( - self.max_val.shape == max_val.shape - ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + assert self.min_val.shape == min_val.shape, ( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + assert self.max_val.shape == max_val.shape, ( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) self.min_val.copy_(min_val) @@ -803,9 +803,9 @@ def forward(self, input: torch.Tensor): return input def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: - assert hasattr(self, "min_val") and hasattr( - self, "max_val" - ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + assert hasattr(self, "min_val") and hasattr(self, "max_val"), ( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) return choose_qparams_affine_with_min_max( self.min_val, diff --git a/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/torch/ao/quantization/pt2e/duplicate_dq_pass.py index fdfdc7f84acddf..163184c00f1d1d 100644 --- a/torch/ao/quantization/pt2e/duplicate_dq_pass.py +++ b/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -33,7 +33,7 @@ def _maybe_duplicate_dq( gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node ): annotation = user.meta.get("quantization_annotation", None) - if not _is_valid_annotation(annotation): + if not _is_valid_annotation(annotation): # type: ignore[arg-type] return with gm.graph.inserting_after(dq_node): new_node = gm.graph.node_copy(dq_node) diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py index 54bad84e3ee895..80520d1ef0d0db 100644 --- a/torch/ao/quantization/pt2e/graph_utils.py +++ b/torch/ao/quantization/pt2e/graph_utils.py @@ -161,9 +161,9 @@ def bfs_trace_with_node_process( ) -> None: """Traverse the graph module and apply node_op to each node.""" - assert isinstance( - model, (ExportedProgram, torch.fx.GraphModule) - ), f"Expected GraphModule or ExportedProgram, got {type(model)}" + assert isinstance(model, (ExportedProgram, torch.fx.GraphModule)), ( + f"Expected GraphModule or ExportedProgram, got {type(model)}" + ) gm = model.graph_module if isinstance(model, ExportedProgram) else model queue = [gm] while queue: diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index 0c96f915306d3e..aab4c435c872fe 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -177,19 +177,19 @@ class PortNodeMetaForQDQ(PassBase): - Example 1: - Original: [Conv -> AvgPool -> Linear] - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] - - Inner brackets specify which nodes Q/DQ inherit metdata from + - Inner brackets specify which nodes Q/DQ inherit metadata from - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] - Note first Q and last DQ do not inherit metadata from any nodes - Example 2: - Original: [Conv -> AvgPool -> Linear] - AvgPool is not quantized - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] - - Inner brackets specify which nodes Q/DQ inherit metdata from + - Inner brackets specify which nodes Q/DQ inherit metadata from - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation - on the nodes (in this case AvgPool node) to conclude if the node or patter was - supposed to be quantized. And subsequntly decide if the preceding Q, if any, should + on the nodes (in this case AvgPool node) to conclude if the node or pattern was + supposed to be quantized. And subsequently decide if the preceding Q, if any, should inherit metadata from AvgPool. - Dynamically quantized patterns: - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 789f892266beb0..8b1c5bfed4eb1b 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -163,7 +163,7 @@ def _union_input_edge_with( def _get_edge_or_node_to_group_id( - edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], ) -> dict[EdgeOrNode, int]: """Map from edge/node to the group ID, generated from quantization annotations, edge/node with the same group ID should use the same observer/fake_quant instance @@ -275,7 +275,7 @@ def _get_edge_or_node_to_group_id( _update_shared_with(input_edge, qspec, shared_with_map) - # now that we get the sharing relations between all edges and nodes, we can assingn group ids + # now that we get the sharing relations between all edges and nodes, we can assign group ids cur_group_id = 0 edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} for edge_or_node in shared_with_map.keys(): @@ -351,9 +351,9 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( original_arg = arg while _is_activation_post_process_node(original_arg, named_modules): original_arg = original_arg.args[0] # type: ignore[assignment] - assert isinstance( - original_arg, Node - ), f"expect original argument to be a Node, but got: {type(original_arg)}" + assert isinstance(original_arg, Node), ( + f"expect original argument to be a Node, but got: {type(original_arg)}" + ) input_edge = (original_arg, node) if input_edge not in obs_or_fq_map: diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index 7423a5320d97d3..b9ce762896f1f6 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -876,7 +876,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda ) - # remove in place add from batchnorm tracking traning stats + # remove in place add from batchnorm tracking training stats for node in m.graph.nodes: if ( node.target == torch.ops.aten.add_.Tensor diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index ae23b43b9cb0f8..5a757a700498d4 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -300,7 +300,7 @@ def _reference_quantized_conv2d( # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32 # In order to addition of bias_(i)_fp32 inside, we must do # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950 - # Note we had to multiply bias_fp32 qith X_scale * W_scale = bias_scale + # Note we had to multiply bias_fp32 with X_scale * W_scale = bias_scale # Thus bias quantization to int32 must be with X_scale * W_scale bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) @@ -436,7 +436,7 @@ def _reference_quantized_add( x_fp32 = (x_i8 - x_zero_point) * x_scale (3) y_fp32 = (y_i8 - y_zero_point) * y_scale (4) - # applying the above fomula to the out_i8 equation we can get the following: + # applying the above formula to the out_i8 equation we can get the following: out_i8 = out_fp32 / out_scale + out_zero_point # (1) = (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) @@ -808,7 +808,10 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement_post_trans = rewrite_info.replacement_post_trans pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] - replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] + replacement = _get_aten_graph_module_for_pattern( # type: ignore[assignment] + replacement, + example_inputs, # type: ignore[arg-type] + ) remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] if pattern_post_trans: pattern = pattern_post_trans(pattern) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 0e807b5fbc38e8..f919c3d9dff05e 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -85,10 +85,11 @@ def _find_q_dq_node_for_user( q_node = None if ( - dq_node.args[0].op == "call_function" # type: ignore[union-attr] - and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr] + isinstance(arg := dq_node.args[0], torch.fx.Node) + and arg.op == "call_function" + and arg.target in _QUANTIZE_OPS ): - q_node = dq_node.args[0] + q_node = arg return (q_node, dq_node) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 246d74b601c838..efee5302ad42ad 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -98,7 +98,8 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])): my_qconfig = QConfig( activation=MinMaxObserver.with_args(dtype=torch.qint8), - weight=default_observer.with_args(dtype=torch.qint8)) + weight=default_observer.with_args(dtype=torch.qint8), + ) """ @@ -561,9 +562,9 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, ), ) - assert ( - not is_per_channel - ), "Per channel weight observer is not supported yet for ConvTranspose{n}d." + assert not is_per_channel, ( + "Per channel weight observer is not supported yet for ConvTranspose{n}d." + ) QConfigAny = Optional[QConfig] diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index a43b69e4fa8f42..bd34a6b8a1f451 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -232,9 +232,9 @@ class QConfigMapping: def __init__(self) -> None: # In increasing match priority: self.global_qconfig: QConfigAny = None - self.object_type_qconfigs: OrderedDict[ - Union[Callable, str], QConfigAny - ] = OrderedDict() + self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = ( + OrderedDict() + ) self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() self.module_name_object_type_order_qconfigs: OrderedDict[ diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index 0b35c9634ab8ca..e22fba05bbc99c 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -333,9 +333,9 @@ def get_default_compare_output_module_list() -> set[Callable]: return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) -def get_default_float_to_quantized_operator_mappings() -> ( - dict[Union[Callable, str], Callable] -): +def get_default_float_to_quantized_operator_mappings() -> dict[ + Union[Callable, str], Callable +]: return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) @@ -343,9 +343,9 @@ def get_default_float_to_quantized_operator_mappings() -> ( def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: """Get the quantized operator corresponding to the float operator""" quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) - assert ( - quantized_op is not None - ), f"Operator {str(float_op)} does not have corresponding quantized op" + assert quantized_op is not None, ( + f"Operator {str(float_op)} does not have corresponding quantized op" + ) return quantized_op diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 1c85f857e8bc2e..b85618a16331fe 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -158,9 +158,9 @@ def _observer_forward_pre_hook(self, input): def _register_activation_post_process_hook(module, pre_hook=False): - assert hasattr( - module, "activation_post_process" - ), "Expect activation_post_process attribute already attached to the module" + assert hasattr(module, "activation_post_process"), ( + "Expect activation_post_process attribute already attached to the module" + ) if pre_hook: module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) else: @@ -198,9 +198,9 @@ def _add_observer_( # respect device affinity when adding observers if device is None: devices = _get_unique_devices_(module) - assert ( - len(devices) <= 1 - ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + assert len(devices) <= 1, ( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None def get_activation_post_process(qconfig, device, special_act_post_process=None): @@ -243,9 +243,9 @@ def insert_activation_post_process(m, special_act_post_process=None): type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) ): if needs_observation(child): - assert hasattr( - child, "activation_post_process" - ), f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + assert hasattr(child, "activation_post_process"), ( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) child.activation_post_process = get_activation_post_process( child.qconfig, device ) @@ -367,10 +367,8 @@ def prepare( # user will manually define the corresponding observed # module class which has a from_float class method that converts # float custom module to observed custom module - "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule - } - } + "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule} + } """ torch._C._log_api_usage_once("quantization_api.quantize.prepare") @@ -791,7 +789,9 @@ def swap_module( devices = _get_unique_devices_(mod) assert len(devices) <= 1 or ( len(devices) == 2 and torch.device("meta") in devices - ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ), ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None if device: new_mod.to(device) @@ -811,9 +811,9 @@ def get_prefix(prefix): return prefix if prefix == "" else prefix + "." if hasattr(mod, "activation_post_process"): - target_dict[ - get_prefix(prefix) + "activation_post_process" - ] = mod.activation_post_process + target_dict[get_prefix(prefix) + "activation_post_process"] = ( + mod.activation_post_process + ) for name, child in mod.named_children(): module_prefix = get_prefix(prefix) + name if prefix else name _get_observer_dict(child, target_dict, module_prefix) diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 86ba43693e17c8..c59d35c573505f 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -84,9 +84,7 @@ def _fuse_fx( model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) """ _check_is_graph_module(model) - return fuse( - model, is_qat, fuse_custom_config, backend_config - ) # type: ignore[operator] + return fuse(model, is_qat, fuse_custom_config, backend_config) # type: ignore[operator] def _prepare_fx( @@ -187,7 +185,7 @@ def _prepare_standalone_module_fx( same as input_quantized_idxs configuration provided for the standalone module * `standalone_module_output_quantized_idxs(List[Int])`: a list of - indexs for the graph output that is quantized + indices for the graph output that is quantized same as input_quantized_idxs configuration provided for the standalone module @@ -219,6 +217,7 @@ def fuse_fx( Example:: from torch.ao.quantization import fuse_fx + m = Model().eval() m = fuse_fx(m) @@ -430,14 +429,17 @@ def prepare_qat_fx( from torch.ao.quantization import get_default_qat_qconfig_mapping from torch.ao.quantization.quantize_fx import prepare_qat_fx + class Submodule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) + def forward(self, x): x = self.linear(x) return x + class M(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -449,17 +451,20 @@ def forward(self, x): x = self.sub(x) + x return x + # initialize a floating point model float_model = M().train() # (optional, but preferred) load the weights from pretrained model # float_model.load_weights(...) + # define the training loop for quantization aware training def train_loop(model, train_data): model.train() for image, target in data_loader: ... + # qconfig is the configuration for how we insert observers for a particular # operator # qconfig = get_default_qconfig("fbgemm") @@ -474,7 +479,7 @@ def train_loop(model, train_data): # in the model through qconfig_mapping # the following call will get the qconfig_mapping that works best for models # that target "fbgemm" backend - qconfig_mapping = get_default_qat_qconfig("fbgemm") + qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") # We can customize qconfig_mapping in different ways, please take a look at # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 59e546458f8c05..38d9cd6b8b765e 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -157,12 +157,12 @@ def _convert_ondevice_jit( model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC ): _check_is_script_module(model) - assert ( - quant_type == QuantType.DYNAMIC - ), "This API, while should work for static quant, is only tested for dynamic quant." - assert not method_name.startswith( - "observe_" - ), "Pass in valid method to be quantized, e.g. forward" + assert quant_type == QuantType.DYNAMIC, ( + "This API, while should work for static quant, is only tested for dynamic quant." + ) + assert not method_name.startswith("observe_"), ( + "Pass in valid method to be quantized, e.g. forward" + ) observe_method_name = "observe_" + method_name quantize_method_name = "quantize_" + method_name model_c = model._c @@ -230,12 +230,12 @@ def _quantize_jit( model = prepare_dynamic_jit(model, qconfig_dict, inplace) model = convert_dynamic_jit(model, True, debug) else: - assert ( - run_fn - ), "Must provide calibration function for post training static quantization" - assert ( - run_args - ), "Must provide calibration dataset for post training static quantization" + assert run_fn, ( + "Must provide calibration function for post training static quantization" + ) + assert run_args, ( + "Must provide calibration dataset for post training static quantization" + ) model = prepare_jit(model, qconfig_dict, inplace) run_fn(model, *run_args) model = convert_jit(model, True, debug) @@ -280,19 +280,22 @@ def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=Fal from torch.ao.quantization import get_default_qconfig from torch.ao.quantization import quantize_jit - ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) - qconfig = get_default_qconfig('fbgemm') + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + def calibrate(model, data_loader): model.eval() with torch.no_grad(): for image, target in data_loader: model(image) + quantized_model = quantize_jit( - ts_model, - {'': qconfig}, - calibrate, - [data_loader_test]) + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit") @@ -330,19 +333,22 @@ def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False): from torch.ao.quantization import per_channel_dynamic_qconfig from torch.ao.quantization import quantize_dynamic_jit - ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) - qconfig = get_default_qconfig('fbgemm') + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + def calibrate(model, data_loader): model.eval() with torch.no_grad(): for image, target in data_loader: model(image) + quantized_model = quantize_dynamic_jit( - ts_model, - {'': qconfig}, - calibrate, - [data_loader_test]) + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit") @@ -401,13 +407,13 @@ def _quantize_ondevice_dynamic_jit( from torch.ao.quantization import per_channel_dynamic_qconfig from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit - ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) - qconfig = get_default_qconfig('fbgemm') + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") quant_ready_model = _quantize_ondevice_dynamic_jit( - ts_model, - {'': qconfig}, - 'forward', - True) + ts_model, {"": qconfig}, "forward", True + ) ``` """ return _quantize_ondevice_dynamic_jit_impl( diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index 3188eba9e96c5b..169e2905ddbdcc 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -76,7 +76,7 @@ def calibrate(model, data_loader): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured - # result shoud mostly stay the same + # result should mostly stay the same m = torch.export.export_for_training(m, *example_inputs).module() # we get a model with aten ops @@ -153,7 +153,7 @@ def train_loop(model, train_data): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured - # result shoud mostly stay the same + # result should mostly stay the same m = torch.export.export_for_training(m, *example_inputs).module() # we get a model with aten ops @@ -218,7 +218,7 @@ def convert_pt2e( Args: * `model` (torch.fx.GraphModule): calibrated/trained model - * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not + * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not Returns: diff --git a/torch/ao/quantization/quantizer/composable_quantizer.py b/torch/ao/quantization/quantizer/composable_quantizer.py index 6b95edbc2193a4..15404cc5601177 100644 --- a/torch/ao/quantization/quantizer/composable_quantizer.py +++ b/torch/ao/quantization/quantizer/composable_quantizer.py @@ -28,8 +28,12 @@ class ComposableQuantizer(Quantizer): ``` embedding_quantizer = EmbeddingQuantizer() linear_quantizer = MyLinearQuantizer() - xnnpack_quantizer = XNNPackQuantizer() # to handle ops not quantized by previous two quantizers - composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer]) + xnnpack_quantizer = ( + XNNPackQuantizer() + ) # to handle ops not quantized by previous two quantizers + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, linear_quantizer, xnnpack_quantizer] + ) prepared_m = prepare_pt2e(model, composed_quantizer) ``` """ diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index 7da601052a9c06..9884cb1990f079 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -111,7 +111,7 @@ class DerivedQuantizationSpec(QuantizationSpecBase): @dataclass class QuantizationAnnotation: - """How are input arguemnt or output should be quantized, + """How are input argument or output should be quantized, expressed as QuantizationSpec, this corresponds to how a Tensor in the operator Graph is observed (PTQ) or fake quantized (QAT) """ diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index cae2ec30d1e337..04fefb7e463bca 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -28,7 +28,7 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]): This utility is used to handle cases when dynami_shape=True tracing leads to symint nodes in the pattern of linear module. In those cases, we need to distinguish between the nodes that are in input for just extracting value of - some dimentions (and symint nodes) vs. the one that is activation. + some dimensions (and symint nodes) vs. the one that is activation. For example: graph(x, y, weight): size_0 = torch.ops.aten.sym_size([x], [0]) diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index df4d94b3fbf315..e4777645a9e90c 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -807,19 +807,19 @@ def _annotate_qat_conv2d_bn_binary_unary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) ) - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize([binary_node, unary_node]) @@ -877,14 +877,14 @@ def _annotate_qat_conv2d_bn_binary( binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( quantization_config ) - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(binary_node) @@ -934,13 +934,13 @@ def _annotate_qat_conv2d_bn_unary( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(unary_node) @@ -975,13 +975,13 @@ def _annotate_qat_conv2d_bn( self._annotate_conv_node_helper(conv_node, False, quantization_config) if quantization_config is not None: - bn_output_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, + bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) else: _annotate_nodes_not_quantize(bn_output_node) @@ -1556,19 +1556,19 @@ def _annotate_linear_binary_unary( linear_node, False, quantization_config ) # We don't insert q-dq before the binary input node due to accuracy issues - binary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map={}, - _annotated=True, - _is_output_of_quantized_pattern=(not has_unary), + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), + ) ) if unary_node is not None: - unary_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - _annotated=True, - _is_output_of_quantized_pattern=True, + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) ) def validate(self, model: torch.fx.GraphModule) -> None: diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 044d4ff63af308..6005152a4d73f0 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -245,7 +245,7 @@ def not_module_type_or_name_filter(n: Node) -> bool: class XNNPACKQuantizer(Quantizer): """ !!! DEPRECATED !!! - XNNPACKQuantizer is a marked as deprected. It will be removed in the future. + XNNPACKQuantizer is a marked as deprecated. It will be removed in the future. It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer. Please use the new quantizer instead. """ @@ -345,9 +345,9 @@ def set_module_name( quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` """ - assert ( - quantization_config is not None - ), " quantization_config == None is not supported yet" + assert quantization_config is not None, ( + " quantization_config == None is not supported yet" + ) self.module_name_config[module_name] = quantization_config return self diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index c9891bc7add85d..f8ac0a7727de35 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -165,9 +165,9 @@ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): if quantization_config.bias is None: return None quantization_spec: QuantizationSpec = quantization_config.bias - assert ( - quantization_spec.dtype == torch.float - ), "Only float dtype for bias is supported for bias right now" + assert quantization_spec.dtype == torch.float, ( + "Only float dtype for bias is supported for bias right now" + ) return quantization_spec @@ -422,7 +422,7 @@ def _annotate_conv_bn( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ - Find conv + batchnorm parititions + Find conv + batchnorm partitions Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) @@ -435,7 +435,7 @@ def _annotate_conv_bn_relu( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ - Find conv + batchnorm + relu parititions + Find conv + batchnorm + relu partitions Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) @@ -448,7 +448,7 @@ def _annotate_conv_transpose_bn( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ - Find conv_transpose + batchnorm parititions + Find conv_transpose + batchnorm partitions Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn( @@ -463,7 +463,7 @@ def _annotate_conv_transpose_bn_relu( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ - Find conv_transpose + batchnorm + relu parititions + Find conv_transpose + batchnorm + relu partitions Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn( diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index 68dd42936cf529..eff97dbcf27da0 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -85,7 +85,7 @@ def __init__(self) -> None: overrides. We keep the annotate methods but make the function body empty, aiming to let `_generate_qdq_quantized_model` generate qdq around op and graph execute on fp32 dtype for - unspported operators. + unsupported operators. """ def _annotate_qat_conv2d_fusion_pattern( diff --git a/torch/ao/quantization/stubs.py b/torch/ao/quantization/stubs.py index 916d7de35c902e..ebfffcb756f765 100644 --- a/torch/ao/quantization/stubs.py +++ b/torch/ao/quantization/stubs.py @@ -1,6 +1,11 @@ -# mypy: allow-untyped-defs +from typing import Any, Optional +import torch from torch import nn +from torch.ao.quantization import QConfig + + +__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"] class QuantStub(nn.Module): @@ -12,12 +17,12 @@ class QuantStub(nn.Module): if qconfig is not provided, we will get qconfig from parent modules """ - def __init__(self, qconfig=None): + def __init__(self, qconfig: Optional[QConfig] = None): super().__init__() if qconfig: self.qconfig = qconfig - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x @@ -30,12 +35,12 @@ class DeQuantStub(nn.Module): if qconfig is not provided, we will get qconfig from parent modules """ - def __init__(self, qconfig=None): + def __init__(self, qconfig: Optional[Any] = None): super().__init__() if qconfig: self.qconfig = qconfig - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x @@ -50,11 +55,12 @@ class QuantWrapper(nn.Module): will be swapped to `nnq.Quantize` which does actual quantization. Similarly for `DeQuantStub`. """ + quant: QuantStub dequant: DeQuantStub module: nn.Module - def __init__(self, module): + def __init__(self, module: nn.Module): super().__init__() qconfig = getattr(module, "qconfig", None) self.add_module("quant", QuantStub(qconfig)) @@ -62,7 +68,7 @@ def __init__(self, module): self.add_module("module", module) self.train(module.training) - def forward(self, X): + def forward(self, X: torch.Tensor) -> torch.Tensor: X = self.quant(X) X = self.module(X) return self.dequant(X) diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 0e375ad9823758..feae45df3b8631 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -2,6 +2,7 @@ """ Utils shared by different modes of quantization (eager/graph) """ + import functools import warnings from collections import OrderedDict @@ -414,9 +415,9 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" else: - assert torch.all( - min_val <= max_val - ), f"min {min_val} should be less than max {max_val}" + assert torch.all(min_val <= max_val), ( + f"min {min_val} should be less than max {max_val}" + ) return True @@ -451,13 +452,13 @@ def calculate_qmin_qmax( qrange_len = initial_quant_max - initial_quant_min + 1 if dtype in [torch.qint8, torch.int8]: - assert ( - 0 < qrange_len <= 256 - ), "quantization range should be positive and not exceed the maximum bit range (=256)." + assert 0 < qrange_len <= 256, ( + "quantization range should be positive and not exceed the maximum bit range (=256)." + ) elif dtype in [torch.qint32, torch.int32]: - assert ( - 0 < qrange_len <= 2**32 - ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + assert 0 < qrange_len <= 2**32, ( + "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + ) if reduce_range: quant_min, quant_max = quant_min // 2, quant_max // 2 else: @@ -605,17 +606,17 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: """ # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. - assert ( - quant_min <= 0 <= quant_max - ), "Used-specified quantization range must include 0." - assert ( - quant_min < quant_max - ), "qmin must be strictly less than qmax for user-specified quantization range." + assert quant_min <= 0 <= quant_max, ( + "Used-specified quantization range must include 0." + ) + assert quant_min < quant_max, ( + "qmin must be strictly less than qmax for user-specified quantization range." + ) # Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer -# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikey to change +# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikely to change # (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) def determine_qparams( min_val: torch.Tensor, diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 66cf168f411df2..74dcb4b70433c0 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -138,7 +138,7 @@ def _make_grads( shape_matches = expect_true(sym_eq(out_size, first_grad.size())) if not shape_matches: - out = cast(Union[torch.Tensor, graph.GradientEdge], out) + out = cast(Union[torch.Tensor, graph.GradientEdge], out) # type: ignore[redundant-cast] out_shape, grad_shape = _calculate_shape( out, first_grad, is_grads_batched ) diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 511d27df077fab..0277f1b75541f6 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Autograd anomaly mode.""" + import warnings import torch @@ -31,6 +32,7 @@ class detect_anomaly: ... @staticmethod ... def forward(ctx, inp): ... return inp.clone() + ... ... @staticmethod ... def backward(ctx, gO): ... # Error during the backward pass diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 2840689892615b..b8036a5235b914 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -335,9 +335,6 @@ def __init__(cls, name, bases, attrs): name + "Backward", (BackwardCFunction,), {"_forward_cls": cls} ) backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined] - backward_fn._bw_module = None # type: ignore[attr-defined] - if getattr(cls, "_lazy_backward_info", None): - backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined] cls._backward_cls = backward_fn super().__init__(name, bases, attrs) @@ -369,6 +366,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: def forward(*args: Any, **kwargs: Any) -> Any: pass + @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass @@ -769,6 +767,7 @@ class NestedIOFunction(Function): This class is here only for backward compatibility reasons. Use :class:`Function` instead of this for any new use case. """ + # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the # superclass (Function) but are instance methods here, which mypy reports as incompatible. @@ -815,7 +814,7 @@ def save_for_backward(self, *args: Any) -> None: self._to_save_nested = args @property - def saved_tensors(self): + def saved_tensors(self): # type: ignore[override] r""" See :meth:`Function.saved_tensors`. """ diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index de1c1343347615..09ced2e03f775b 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -653,6 +653,16 @@ def jacobian( [0.0000, 3.3963]]), tensor([[3., 0.], [0., 3.]])) + + >>> def linear_model(x): + ... W = torch.tensor([[2.0, -1.0], [0.0, 1.0]]) + ... b = torch.tensor([1.0, 0.5]) + ... return x @ W.T + b + + >>> x = torch.randn(4, 2, requires_grad=True) + >>> jac = jacobian(linear_model, x, vectorize=True) + >>> jac.shape + torch.Size([4, 2, 4, 2]) """ assert strategy in ("forward-mode", "reverse-mode"), ( 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your ' diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 73c07294819876..92bbd129e1439e 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -196,6 +196,12 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch._C._set_grad_enabled(self.prev) + def __str__(self) -> str: + return f"{torch.typename(self)}(mode={self.mode})" + + def __repr__(self) -> str: + return str(self) + def clone(self) -> "set_grad_enabled": r""" Create a copy of this class diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 20f4d9704f50ae..d1099e969ccc8d 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -2036,15 +2036,15 @@ def gradcheck( ``True`` if all differences satisfy allclose condition """ - assert ( - check_forward_ad or check_backward_ad - ), "Expected at least one of check_forward_ad or check_backward_ad to be True" - assert not ( - check_batched_grad and not check_backward_ad - ), "Setting check_batched_grad=True requires check_backward_ad to be True" - assert not ( - check_batched_forward_grad and not check_forward_ad - ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True" + assert check_forward_ad or check_backward_ad, ( + "Expected at least one of check_forward_ad or check_backward_ad to be True" + ) + assert not (check_batched_grad and not check_backward_ad), ( + "Setting check_batched_grad=True requires check_backward_ad to be True" + ) + assert not (check_batched_forward_grad and not check_forward_ad), ( + "Setting check_batched_forward_grad=True requires check_forward_ad to be True" + ) args = locals().copy() args.pop("raise_exception") if not raise_exception: @@ -2189,15 +2189,15 @@ def gradgradcheck( Returns: True if all differences satisfy allclose condition """ - assert ( - check_fwd_over_rev or check_rev_over_rev - ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" - assert not ( - check_undefined_grad and not check_rev_over_rev - ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True" - assert not ( - check_batched_grad and not check_rev_over_rev - ), "Setting check_batched_grad=True requires check_rev_over_rev to be True" + assert check_fwd_over_rev or check_rev_over_rev, ( + "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" + ) + assert not (check_undefined_grad and not check_rev_over_rev), ( + "Setting check_undefined_grad=True requires check_rev_over_rev to be True" + ) + assert not (check_batched_grad and not check_rev_over_rev), ( + "Setting check_batched_grad=True requires check_rev_over_rev to be True" + ) # TODO: do we want to test this too? # assert not (check_batched_forward_grad and not check_fwd_over_rev), ( # "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True") diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 0e36f89ca08592..dc7d07955f3b2b 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -509,9 +509,9 @@ def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]: def inner_hook(grad: torch.Tensor) -> None: nonlocal count, nb_calls, buffer, fn id = torch._C._current_graph_task_id() - assert ( - id != -1 - ), "expected this hook to be called inside a backward call" + assert id != -1, ( + "expected this hook to be called inside a backward call" + ) count[id] = count.get(id, 0) buffer[id] = buffer.get(id, [None] * len_tensors) @@ -720,9 +720,9 @@ def clear(self) -> None: @contextlib.contextmanager -def allow_mutation_on_saved_tensors() -> ( - Generator[_AllowMutationOnSavedContext, None, None] -): +def allow_mutation_on_saved_tensors() -> Generator[ + _AllowMutationOnSavedContext, None, None +]: """Context manager under which mutating tensors saved for backward is allowed. Under this context manager, tensors saved for backward are cloned on mutation, diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 15935d6453b969..af0694676550ea 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -95,6 +95,7 @@ def _run_on_profiler_stop(): @dataclass class _ProfilerStats: "Profiler timing and stats used by developers to catch issues/regressions" + profiling_window_duration_sec: float = 0 number_of_events: int = 0 profiler_prepare_call_duration_us: int = 0 @@ -251,9 +252,9 @@ def __init__( self.custom_trace_id_callback = custom_trace_id_callback self.trace_id = "" if not self.use_cpu: - assert ( - use_kineto - ), "Device-only events supported only with Kineto (use_kineto=True)" + assert use_kineto, ( + "Device-only events supported only with Kineto (use_kineto=True)" + ) if self.use_device is not None: VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia", "hpu"] @@ -290,35 +291,35 @@ def __init__( else: self.kineto_activities.add(ProfilerActivity.CUDA) elif self.use_device == "xpu": - assert ( - use_kineto and ProfilerActivity.XPU in _supported_activities() - ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." + assert use_kineto and ProfilerActivity.XPU in _supported_activities(), ( + "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." + ) self.kineto_activities.add(ProfilerActivity.XPU) elif self.use_device == "mtia": - assert ( - use_kineto and ProfilerActivity.MTIA in _supported_activities() - ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + assert use_kineto and ProfilerActivity.MTIA in _supported_activities(), ( + "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + ) self.kineto_activities.add(ProfilerActivity.MTIA) elif self.use_device == "hpu": - assert ( - use_kineto and ProfilerActivity.HPU in _supported_activities() - ), "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices." + assert use_kineto and ProfilerActivity.HPU in _supported_activities(), ( + "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices." + ) self.kineto_activities.add(ProfilerActivity.HPU) elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto or ProfilerActivity.PrivateUse1 not in _supported_activities() ): - assert ( - self.use_cpu - ), "Legacy custombackend profiling requires use_cpu=True" + assert self.use_cpu, ( + "Legacy custombackend profiling requires use_cpu=True" + ) self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK else: self.kineto_activities.add(ProfilerActivity.PrivateUse1) - assert ( - len(self.kineto_activities) > 0 - ), "No activities specified for the profiler" + assert len(self.kineto_activities) > 0, ( + "No activities specified for the profiler" + ) def default_trace_id(self): # Generate a UUID @@ -582,7 +583,10 @@ def _device_memory_usage(mem_record): device_corr_map: dict[int, list[FunctionEvent]] = {} max_evt_id = 0 for kineto_event in result.events(): - if _filter_name(kineto_event.name()): + if ( + _filter_name(kineto_event.name()) + or getattr(kineto_event, "is_hidden_event", lambda: False)() + ): continue rel_start_ns = kineto_event.start_ns() - trace_start_ns rel_end_ns = kineto_event.end_ns() - trace_start_ns @@ -738,11 +742,12 @@ class record_function(_ContextDecorator): >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) >>> x = torch.randn((1, 1), requires_grad=True) >>> with torch.autograd.profiler.profile() as prof: - ... y = x ** 2 - ... with torch.autograd.profiler.record_function("label-z"): # label the block - ... z = y ** 3 + ... y = x**2 + ... with torch.autograd.profiler.record_function( + ... "label-z" + ... ): # label the block + ... z = y**3 ... y.backward() - ... >>> # xdoctest: +IGNORE_WANT >>> # NOTE: some columns were removed for brevity >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 5cd5ce34c6bbda..84dd5f1013d788 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -126,9 +126,9 @@ def _populate_cpu_children(self): current_events.pop() else: parent.append_cpu_child(event) - assert ( - event.cpu_parent is None - ), f"There is already a CPU parent event for {event.key}" + assert event.cpu_parent is None, ( + f"There is already a CPU parent event for {event.key}" + ) event.set_cpu_parent(parent) break @@ -398,13 +398,13 @@ def _format_memory(nbytes): MB = 1024 * KB GB = 1024 * MB if abs(nbytes) >= GB: - return f"{nbytes * 1.0 / GB:.2f} Gb" + return f"{nbytes * 1.0 / GB:.2f} GB" elif abs(nbytes) >= MB: - return f"{nbytes * 1.0 / MB:.2f} Mb" + return f"{nbytes * 1.0 / MB:.2f} MB" elif abs(nbytes) >= KB: - return f"{nbytes * 1.0 / KB:.2f} Kb" + return f"{nbytes * 1.0 / KB:.2f} KB" else: - return str(nbytes) + " b" + return str(nbytes) + " B" def _attr_formatter(name): diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index 90166913e324e3..de194b12d02c3b 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -1,7 +1,10 @@ # mypy: allow-untyped-defs +import sys import types from contextlib import contextmanager +import torch + # The idea for this parameter is that we forbid bare assignment # to torch.backends..enabled and friends when running our @@ -57,6 +60,70 @@ def __getattr__(self, attr): return self.m.__getattribute__(attr) +class _FP32Precision: + def __init__(self, backend, op): + self.backend = backend + self.op = op + + def __setattr__(self, name, value): + if name == "fp32_precision": + torch._C._set_fp32_precision_setter(self.backend, self.op, value) + elif name in ("backend", "op"): + super().__setattr__(name, value) + else: + raise AttributeError("Unknown attribute " + name) + + def __getattr__(self, name): + if name == "fp32_precision": + return torch._C._get_fp32_precision_getter(self.backend, self.op) + else: + raise AttributeError("Unknown attribute " + name) + + +def set_flags(_fp32_precision="none"): + orig_flags = (torch._C._get_fp32_precision_getter("generic", "all"),) + if _fp32_precision is not None: + torch._C._set_fp32_precision_setter("generic", "all", _fp32_precision) + return orig_flags + + +@contextmanager +def flags(fp32_precision="none"): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(fp32_precision) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +def _get_fp32_precision_getter(backend, op): + def inner(): + return torch._C._get_fp32_precision_getter(backend, op) + + return inner + + +def _set_fp32_precision_setter(backend, op): + def inner(precision): + return torch._C._set_fp32_precision_setter(backend, op, precision) + + return inner + + +class GenericModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + + fp32_precision = ContextProp( + _get_fp32_precision_getter("generic", "all"), + _set_fp32_precision_setter("generic", "all"), + ) + + +sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__) + from torch.backends import ( cpu as cpu, cuda as cuda, diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py index f05e0bcee9fb2b..3180e56a6baf96 100644 --- a/torch/backends/_coreml/preprocess.py +++ b/torch/backends/_coreml/preprocess.py @@ -55,6 +55,7 @@ def CompileSpec( allow_low_precision=True, quantization_mode=CoreMLQuantizationMode.NONE, mlmodel_export_path=None, + convert_to=None, ): return ( inputs, @@ -63,6 +64,7 @@ def CompileSpec( allow_low_precision, quantization_mode, mlmodel_export_path, + convert_to, ) @@ -91,6 +93,7 @@ def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tup allow_low_precision, quantization_mode, mlmodel_export_path, + convert_to, ) = spec mil_inputs = [] inputs = [] @@ -101,7 +104,7 @@ def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tup ml_type = _convert_to_mil_type(shape, dtype, name) mil_inputs.append(ml_type) model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None) - mlmodel = ct.convert(model, inputs=mil_inputs) + mlmodel = ct.convert(model, inputs=mil_inputs, convert_to=convert_to) if quantization_mode != CoreMLQuantizationMode.NONE: quant_model_spec = quantization_utils.quantize_weights( diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 7541c9c7ca670d..d10af638bd243f 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -135,6 +135,8 @@ def __getattr__(self, name): return torch._C._get_cublas_allow_bf16_reduced_precision_reduction() elif name == "allow_fp16_accumulation": return torch._C._get_cublas_allow_fp16_accumulation() + elif name == "fp32_precision": + return torch._C._get_fp32_precision_getter("cuda", "matmul") raise AttributeError("Unknown attribute " + name) def __setattr__(self, name, value): @@ -146,6 +148,8 @@ def __setattr__(self, name, value): return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value) elif name == "allow_fp16_accumulation": return torch._C._set_cublas_allow_fp16_accumulation(value) + elif name == "fp32_precision": + return torch._C._set_fp32_precision_setter("cuda", "matmul", value) raise AttributeError("Unknown attribute " + name) @@ -158,7 +162,7 @@ def __setattr__(self, name, value): def preferred_linalg_library( - backend: Union[None, str, torch._C._LinalgBackend] = None + backend: Union[None, str, torch._C._LinalgBackend] = None, ) -> torch._C._LinalgBackend: r""" Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations. @@ -206,7 +210,7 @@ def preferred_linalg_library( elif isinstance(backend, str): if backend not in _LinalgBackends: raise RuntimeError( - "Unknown input value. " f"Choose from: {_LinalgBackends_str}." + f"Unknown input value. Choose from: {_LinalgBackends_str}." ) torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) elif isinstance(backend, torch._C._LinalgBackend): @@ -229,7 +233,7 @@ def preferred_linalg_library( def preferred_blas_library( - backend: Union[None, str, torch._C._BlasBackend] = None + backend: Union[None, str, torch._C._BlasBackend] = None, ) -> torch._C._BlasBackend: r""" Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. @@ -261,7 +265,7 @@ def preferred_blas_library( elif isinstance(backend, str): if backend not in _BlasBackends: raise RuntimeError( - "Unknown input value. " f"Choose from: {_BlasBackends_str}." + f"Unknown input value. Choose from: {_BlasBackends_str}." ) torch._C._set_blas_preferred_backend(_BlasBackends[backend]) elif isinstance(backend, torch._C._BlasBackend): @@ -284,7 +288,7 @@ def preferred_blas_library( def preferred_rocm_fa_library( - backend: Union[None, str, torch._C._ROCmFABackend] = None + backend: Union[None, str, torch._C._ROCmFABackend] = None, ) -> torch._C._ROCmFABackend: r""" [ROCm-only] @@ -312,13 +316,13 @@ def preferred_rocm_fa_library( elif isinstance(backend, str): if backend not in _ROCmFABackends: raise RuntimeError( - "Unknown input value. " f"Choose from: {_ROCmFABackends_str}." + f"Unknown input value. Choose from: {_ROCmFABackends_str}." ) torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) elif isinstance(backend, torch._C._ROCmFABackend): torch._C._set_rocm_fa_preferred_backend(backend) else: - raise ValueError("Unknown input value. " f"Choose from: {_ROCmFABackends_str}.") + raise ValueError(f"Unknown input value. Choose from: {_ROCmFABackends_str}.") return torch._C._get_rocm_fa_preferred_backend() diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 0ba6f9e3b40bdf..9c155de7c04b09 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -6,7 +6,14 @@ from typing import Optional import torch -from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule +from torch.backends import ( + __allow_nonbracketed_mutation, + _FP32Precision, + _get_fp32_precision_getter, + _set_fp32_precision_setter, + ContextProp, + PropModule, +) try: @@ -128,6 +135,7 @@ def set_flags( _benchmark_limit=None, _deterministic=None, _allow_tf32=None, + _fp32_precision="none", ): orig_flags = ( torch._C._get_cudnn_enabled(), @@ -135,6 +143,7 @@ def set_flags( None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), torch._C._get_cudnn_deterministic(), torch._C._get_cudnn_allow_tf32(), + torch._C._get_fp32_precision_getter("cuda", "all"), ) if _enabled is not None: torch._C._set_cudnn_enabled(_enabled) @@ -146,6 +155,8 @@ def set_flags( torch._C._set_cudnn_deterministic(_deterministic) if _allow_tf32 is not None: torch._C._set_cudnn_allow_tf32(_allow_tf32) + if _fp32_precision is not None: + torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision) return orig_flags @@ -156,10 +167,16 @@ def flags( benchmark_limit=10, deterministic=False, allow_tf32=True, + fp32_precision="none", ): with __allow_nonbracketed_mutation(): orig_flags = set_flags( - enabled, benchmark, benchmark_limit, deterministic, allow_tf32 + enabled, + benchmark, + benchmark_limit, + deterministic, + allow_tf32, + fp32_precision, ) try: yield @@ -194,6 +211,12 @@ def __init__(self, m, name): allow_tf32 = ContextProp( torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32 ) + conv = _FP32Precision("cuda", "conv") + rnn = _FP32Precision("cuda", "rnn") + fp32_precision = ContextProp( + _get_fp32_precision_getter("cuda", "all"), + _set_fp32_precision_setter("cuda", "all"), + ) # This is the sys.modules replacement trick, see diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index 9f96d692ae0219..ae16922761afea 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -30,6 +30,7 @@ class verbose: .. code-block:: python import torch + model(data) with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON): model(data) @@ -47,9 +48,9 @@ def __enter__(self): if self.enable == VERBOSE_OFF: return st = torch._C._verbose.mkl_set_verbose(self.enable) - assert ( - st - ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." + assert st, ( + "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." + ) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 2c97bcd9b079ba..ae76a9f20c46f5 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -4,7 +4,14 @@ from typing import TYPE_CHECKING import torch -from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule +from torch.backends import ( + __allow_nonbracketed_mutation, + _FP32Precision, + _get_fp32_precision_getter, + _set_fp32_precision_setter, + ContextProp, + PropModule, +) def is_available(): @@ -36,6 +43,7 @@ class verbose: .. code-block:: python import torch + model(data) with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON): model(data) @@ -54,9 +62,9 @@ def __enter__(self): if self.level == VERBOSE_OFF: return st = torch._C._verbose.mkldnn_set_verbose(self.level) - assert ( - st - ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." + assert st, ( + "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." + ) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -64,11 +72,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None): +def set_flags( + _enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none" +): orig_flags = ( torch._C._get_mkldnn_enabled(), torch._C._get_mkldnn_deterministic(), torch._C._get_onednn_allow_tf32(), + torch._C._get_fp32_precision_getter("mkldnn", "all"), ) if _enabled is not None: torch._C._set_mkldnn_enabled(_enabled) @@ -76,13 +87,15 @@ def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None): torch._C._set_mkldnn_deterministic(_deterministic) if _allow_tf32 is not None: torch._C._set_onednn_allow_tf32(_allow_tf32) + if _fp32_precision is not None: + torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision) return orig_flags @contextmanager -def flags(enabled=False, deterministic=False, allow_tf32=True): +def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"): with __allow_nonbracketed_mutation(): - orig_flags = set_flags(enabled, deterministic, allow_tf32) + orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision) try: yield finally: @@ -104,6 +117,13 @@ def is_available(self): allow_tf32 = ContextProp( torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32 ) + matmul = _FP32Precision("mkldnn", "matmul") + conv = _FP32Precision("mkldnn", "conv") + rnn = _FP32Precision("mkldnn", "rnn") + fp32_precision = ContextProp( + _get_fp32_precision_getter("mkldnn", "all"), + _set_fp32_precision_setter("generic", "all"), + ) if TYPE_CHECKING: diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index d3b934a6ced4c3..2fe445bfcb0e55 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs from functools import lru_cache as _lru_cache from typing import Optional, TYPE_CHECKING @@ -40,7 +39,7 @@ def is_macos13_or_newer(minor: int = 0) -> bool: _lib: Optional[_Library] = None -def _init(): +def _init() -> None: r"""Register prims as implementation of var_mean and group_norm.""" global _lib diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 5dcff85a8ec43d..59c89149b47164 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -262,9 +262,11 @@ def numa_aware_check(self, core_list): class _Launcher: r"""Class for launcher.""" - msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ + msg_lib_notfound = ( + f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ {expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." + ) def __init__(self) -> None: self.cpuinfo = _CPUinfo() @@ -611,14 +613,12 @@ def launch(self, args): args.rank == -1 ): # sequentially assign ncores_per_instance to ninstances core_list = cores[ - i - * args.ncores_per_instance : (i + 1) + i * args.ncores_per_instance : (i + 1) * args.ncores_per_instance ] else: # assign ncores_per_instance from rank core_list = cores[ - args.rank - * args.ncores_per_instance : (args.rank + 1) + args.rank * args.ncores_per_instance : (args.rank + 1) * args.ncores_per_instance ] @@ -626,9 +626,9 @@ def launch(self, args): if local_size > 1: total_num_cores = len(core_list) cores_per_rank = total_num_cores // local_size - assert ( - cores_per_rank >= 1 - ), "At least one core needs to be assigned to each rank" + assert cores_per_rank >= 1, ( + "At least one core needs to be assigned to each rank" + ) core_list = core_list[ cores_per_rank * local_rank : cores_per_rank * (local_rank + 1) ] diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index aa6a27a3dcc366..578b56e504e208 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -21,6 +21,7 @@ "list_backends", "disable", "set_stance", + "set_enable_guard_collectives", "cudagraph_mark_step_begin", "wrap_numpy", "is_compiling", @@ -28,6 +29,11 @@ "is_exporting", "save_cache_artifacts", "load_cache_artifacts", + "skip_guard_on_inbuilt_nn_modules_unsafe", + "skip_guard_on_all_nn_modules_unsafe", + "keep_tensor_guards_unsafe", + "skip_guard_on_globals_unsafe", + "nested_compile_region", ] @@ -117,6 +123,7 @@ def allow_in_graph(fn): torch.compiler.allow_in_graph(my_custom_function) + @torch.compile(...) def fn(x): x = torch.add(x, 1) @@ -124,6 +131,7 @@ def fn(x): x = torch.add(x, 1) return x + fn(...) Will capture a single graph containing ``my_custom_function()``. @@ -254,14 +262,15 @@ def set_stance( .. code-block:: python @torch.compile - def foo(x): - ... + def foo(x): ... + @torch.compiler.set_stance("force_eager") def bar(): # will not be compiled foo(...) + bar() with torch.compiler.set_stance("force_eager"): @@ -284,6 +293,15 @@ def bar(): - "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used. - "fail_on_recompile": Raise an error when recompiling a function. + - "eager_then_compile": Run the first invocation in eager mode, then compile on + subsequent calls. This is beneficial for dynamic shapes as it allows inferring + dynamism from the first two invocations instead of wasting a static compile on + the first invocation. + - "aot_eager_then_compile": Run the first invocation with AOT eager to get memory + benefits from activation checkpointing, then compile on subsequent calls. Like + eager_then_compile, this improves handling of dynamic shapes by avoiding an + initial static compile. + skip_guard_eval_unsafe: A flag to run only differentiating guards. CAUTION - This flag is unsafe and should only be used if your setup @@ -316,6 +334,35 @@ def bar(): set_stance._dynamo_forbidden = True # type: ignore[attr-defined] +def set_enable_guard_collectives(enabled: bool): + """ + Enables use of collectives *during* guard evaluation to synchronize behavior + across ranks. This is expensive: we have to issue a collective every time + we enter a compiled code region, even if no rank actually would need to + compile. This can help prevent NCCL hangs by ensuring that we never have a + situation where one rank starts recompiling while other ranks don't compile; + it is especially useful in conjunction with enable_compiler_collectives + where such a situation would immediately cause a hang (as it is necessary + for all ranks to compile at the same time to run compiler collectives). Like + compiler collectives, you can only run this on SPMD programs; you will hang + otherwise. Note that a guard collective is only issued if there is any + compiled code to guard on; if this the first time we encounter a frame or + the frame is skipped, we don't issue collectives. + + Returns the previous setting of enabled. + """ + from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401 + from torch._dynamo.eval_frame import guard_collectives_hook + + if enabled: + return set_guard_complete_hook(guard_collectives_hook) is not None + else: + return set_guard_complete_hook(None) is not None + + +set_enable_guard_collectives._dynamo_forbidden = True # type: ignore[attr-defined] + + def cudagraph_mark_step_begin(): """ Indicates that a new iteration of inference or training is about to begin. @@ -331,6 +378,7 @@ def cudagraph_mark_step_begin(): def rand_foo(): return torch.rand([4], device="cuda") + for _ in range(5): torch.compiler.cudagraph_mark_step_begin() rand_foo() + rand_foo() @@ -464,4 +512,125 @@ def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]: """ from ._cache import CacheArtifactManager, CacheInfo - return CacheArtifactManager.deserialize(serialized_artifacts) + artifacts = CacheArtifactManager.deserialize(serialized_artifacts) + if artifacts is not None: + return CacheArtifactManager.populate_caches(artifacts) + return None + + +def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on the inbuilt nn modules like + torch.nn.Linear. This is unsafe to use by default. But for majority of + torch.compile users, the model code does not modify the inbuilt nn module + attributes. They can benefit from reduction in guard latency overhead using + this API. + + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + return [ + not entry.orig_guard.source.is_unspecialized_builtin_nn_module() + for entry in guard_entries + ] + + +def skip_guard_on_all_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on all nn modules, both user defined as + well inbuilt nn modules (like torch.nn.Linear). This is unsafe to use by + default. But for majority of torch.compile users, the model code does not + modify the nn module attributes. They can benefit from reduction in guard + latency overhead using this API. + + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + + return [ + not entry.orig_guard.source.is_unspecialized_nn_module() + for entry in guard_entries + ] + + +def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False): + """ + A common function to keep tensor guards on all tensors. This is unsafe to + use by default. But if you don't expect any changes in the model code, you + can just keep the tensor guards. + + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.keep_tensor_guards}, + >> ) + """ + + keep_flags = [] + for entry in guard_entries: + if entry.guard_type == "TENSOR_MATCH": + if not isinstance(entry.value, torch.nn.Parameter): + keep_flags.append(True) + elif keep_parameters: + keep_flags.append(True) + else: + keep_flags.append(False) + else: + keep_flags.append(False) + return keep_flags + + +def skip_guard_on_globals_unsafe(guard_entries): + """ + A common function to skip guards on all globals. This is unsafe to use by + default. But if you don't expect any changes in the globals, you can just + keep the tensor guards. + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_globals}, + >> ) + """ + + return [not entry.is_global for entry in guard_entries] + + +def nested_compile_region(fn=None): + """ + Tells **``torch.compile``** that the marked set of operations forms a nested + compile region (which is often repeated in the full model) whose code can be + compiled once and safely reused. ``nested_compile_region`` can also be used + as a decorator. + + During **``torch.compile``** tracing, the compiler applies *hierarchical + compilation* with ``nested_compile_region``: it emits optimized code for the + marked region the first time it is encountered and re-emits (or “stamps + out”) the previously compiled code on every subsequent invocation. This can + substantially reduce overall compile time for deeply-stacked, + structurally-identical components such as the transformer layers of a + large-language-model (LLM). + + Outside a ``torch.compile`` context—i.e., in standard eager execution—the + call is a no-op, so existing workflows remain unaffected. + + Note that ``nested_compile_region`` **does not** promise that a region will + be compiled exactly once. If the compiler detects that new input conditions + (shape, dtype, device, stride, globals etc.) make the cached version invalid + to reuse, it will transparently re-compile the region. Using it is + therefore *safe*: correctness is always preserved, and you pay the extra + compilation cost only when required. + """ + + from torch._higher_order_ops.invoke_subgraph import ( + mark_compile_region as _mark_compile_region, + ) + + return _mark_compile_region(fn) diff --git a/torch/compiler/_cache.py b/torch/compiler/_cache.py index 85c2a391e10b07..d486af78097d08 100644 --- a/torch/compiler/_cache.py +++ b/torch/compiler/_cache.py @@ -48,6 +48,9 @@ def encode(content: Any) -> bytes: def populate_cache(self) -> None: pass + def precompile_compatible(self) -> bool: + return False + @staticmethod def type() -> str: """ @@ -69,9 +72,9 @@ class CacheArtifactFactory: @classmethod def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: artifact_type_key = artifact_cls.type() - assert ( - artifact_cls.type() not in cls._artifact_types - ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" + assert artifact_cls.type() not in cls._artifact_types, ( + f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" + ) cls._artifact_types[artifact_type_key] = artifact_cls setattr( CacheInfo, @@ -82,9 +85,9 @@ def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: @classmethod def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]: - assert ( - artifact_type_key in cls._artifact_types - ), f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" + assert artifact_type_key in cls._artifact_types, ( + f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" + ) return cls._artifact_types[artifact_type_key] @classmethod @@ -128,6 +131,10 @@ def aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body] ... + @property + def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + def add(self, artifact: CacheArtifact) -> None: self.artifacts[artifact.type()].append(artifact.key) @@ -159,6 +166,9 @@ def _deserialize_single_cache( return artifact_type_key, artifacts +CacheArtifactsResult = dict[str, list[CacheArtifact]] + + class CacheArtifactManager: """ Lightweight manager class for collecting and processing cache artifacts for @@ -177,16 +187,16 @@ class CacheArtifactManager: """ # Protected by the compile_lock - _new_cache_artifacts: defaultdict[str, list[CacheArtifact]] = defaultdict(list) + _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) # Keep a seperate seen artifacts list to make avoid unnecessary duplicates # This list will not be cleared between serialize() calls _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() # When serialize() is called, artifacts are transferred from _cache_artifacts to # internal data structure of the _serializer # This allows us to only pay the cost of serialization if serialize() is called - _serializer: AppendingByteSerializer[ - tuple[str, list[CacheArtifact]] - ] = AppendingByteSerializer(serialize_fn=_serialize_single_cache) + _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = ( + AppendingByteSerializer(serialize_fn=_serialize_single_cache) + ) _cache_info: CacheInfo = CacheInfo() @classmethod @@ -207,7 +217,7 @@ def with_fresh_cache(cls) -> Generator[None, None, None]: cls._new_cache_artifacts = defaultdict(list) cls._seen_artifacts = OrderedSet() cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache) - cls._cache_info = CacheInfo() + cls._cache_info = cls._cache_info.__class__() try: yield finally: @@ -268,9 +278,9 @@ def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: return None @staticmethod - def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]: + def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]: """ - Converts the portable format back into various filesystem caches + Converts the portable format back into CacheArtifacts """ try: CacheArtifactManager._ensure_cache_artifacts_registered() @@ -284,6 +294,10 @@ def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]: log.warning("Failed to un-pickle cache artifacts", exc_info=True) return None + return artifacts + + @staticmethod + def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: info = CacheInfo() for artifact in chain(*artifacts.values()): log.debug("writing: %s", artifact) @@ -292,8 +306,8 @@ def deserialize(serialized_artifacts: bytes) -> Optional[CacheInfo]: return info - @staticmethod - def _ensure_cache_artifacts_registered() -> None: + @classmethod + def _ensure_cache_artifacts_registered(cls) -> None: """When deserializing caches in fresh process, we need to ensure that all cache artifacts are registered in the cache registry. This is done by simply importing all the cache artifacts already wrapped with register call. diff --git a/torch/compiler/config.py b/torch/compiler/config.py index b03afa12fa1747..f9ec226c254899 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -29,7 +29,10 @@ # FB-internal note: you do NOT have to specify this explicitly specify this if # you run on MAST, we will automatically default this to # mast:MAST_JOB_NAME:MAST_JOB_VERSION. -job_id: Optional[str] = Config(env_name_default="TORCH_COMPILE_JOB_ID", default=None) +job_id: Optional[str] = Config( + env_name_default=["TORCH_COMPILE_JOB_ID", "TORCH_COMPILE_STICKY_PGO_KEY"], + default=None, +) """ Semantically, this should be an identifier that uniquely identifies, e.g., a training job. You might have multiple attempts of the same job, e.g., if it was @@ -74,15 +77,6 @@ and force_parameter_static_shapes. """ -sticky_pgo_key: str = Config( - env_name_default="TORCH_COMPILE_STICKY_PGO_KEY", default="" -) -""" -If you want to share PGO profiles across different jobs (and not just attempts), you can set -this to a string that identifies the shared profile. This is useful if you want to share PGO profiles -for models that are not identical, but are similar enough to share PGO profiles. -""" - unbacked_sources: str = Config( env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default="" ) diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp index cd5f6ccbd7d586..b5369e436cfec2 100644 --- a/torch/csrc/Event.cpp +++ b/torch/csrc/Event.cpp @@ -257,22 +257,22 @@ static struct PyGetSetDef THPEvent_properties[] = { // NOLINTNEXTLINE(*c-arrays*, *global-variables) static PyMethodDef THPEvent_methods[] = { - {(char*)"from_ipc_handle", + {"from_ipc_handle", castPyCFunctionWithKeywords(THPEvent_from_ipc_handle), METH_CLASS | METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"record", + {"record", castPyCFunctionWithKeywords(THPEvent_record), METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"wait", + {"wait", castPyCFunctionWithKeywords(THPEvent_wait), METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"query", THPEvent_query, METH_NOARGS, nullptr}, - {(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr}, - {(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr}, - {(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr}, + {"query", THPEvent_query, METH_NOARGS, nullptr}, + {"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr}, + {"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr}, + {"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr}, {nullptr}}; PyTypeObject THPEventType = { diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index 60b56668410828..80ee9630dcf59b 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -14,7 +14,8 @@ PyObject *THPException_FatalError, *THPException_LinAlgError, *THPException_OutOfMemoryError, *THPException_DistError, *THPException_DistBackendError, *THPException_DistNetworkError, - *THPException_DistStoreError, *THPException_DistQueueEmptyError; + *THPException_DistStoreError, *THPException_DistQueueEmptyError, + *THPException_AcceleratorError; #define ASSERT_TRUE(cond) \ if (!(cond)) \ @@ -125,6 +126,18 @@ could not be completed because the input matrix is singular.", module, "_DistQueueEmptyError", THPException_DistQueueEmptyError) == 0); + // NOLINTNEXTLINE(bugprone-assignment-in-if-condition) + ASSERT_TRUE( + THPException_AcceleratorError = PyErr_NewExceptionWithDoc( + "torch.AcceleratorError", + "Exception raised while executing on device", + PyExc_RuntimeError, + nullptr)); + type = (PyTypeObject*)THPException_AcceleratorError; + ASSERT_TRUE( + PyModule_AddObject( + module, "AcceleratorError", THPException_AcceleratorError) == 0); + return true; } @@ -244,13 +257,6 @@ TypeError::TypeError(const char* format, ...) { va_end(fmt_args); } -AttributeError::AttributeError(const char* format, ...) { - va_list fmt_args{}; - va_start(fmt_args, format); - msg = formatMessage(format, fmt_args); - va_end(fmt_args); -} - void PyWarningHandler::InternalHandler::process(const c10::Warning& warning) { warning_buffer_.push_back(warning); } @@ -341,4 +347,18 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) { } } +namespace detail { +PyObject* _new_accelerator_error_object(const c10::AcceleratorError& e) { + auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() + : e.what_without_backtrace(); + + auto py_msg = PyUnicode_FromString(msg); + auto rc = PyObject_CallOneArg(THPException_AcceleratorError, py_msg); + auto error_code = PyInt_FromLong(e.get_error_code()); + PyObject_SetAttrString(rc, "error_code", error_code); + Py_XDECREF(py_msg); + Py_XDECREF(error_code); + return rc; +} +} // namespace detail } // namespace torch diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index e5c9afbd657ab4..d74447de83f415 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -86,6 +86,12 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { DistQueueEmptyError, THPException_DistQueueEmptyError, retstmnt) \ _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \ _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt) \ + catch (c10::AcceleratorError & e) { \ + auto exc = torch::detail::_new_accelerator_error_object(e); \ + PyErr_SetObject(THPException_AcceleratorError, exc); \ + Py_XDECREF(exc); \ + retstmnt; \ + } \ _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \ catch (torch::PyTorchError & e) { \ auto msg = torch::processErrorMsg(e.what()); \ @@ -141,7 +147,8 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { extern PyObject *THPException_FatalError, *THPException_LinAlgError, *THPException_OutOfMemoryError, *THPException_DistError, *THPException_DistBackendError, *THPException_DistNetworkError, - *THPException_DistStoreError, *THPException_DistQueueEmptyError; + *THPException_DistStoreError, *THPException_DistQueueEmptyError, + *THPException_AcceleratorError; // Throwing this exception means that the python error flags have been already // set and control should be immediately returned to the interpreter. @@ -296,7 +303,7 @@ struct TypeError : public PyTorchError { // Translates to Python AttributeError struct AttributeError : public PyTorchError { - AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); + using PyTorchError::PyTorchError; PyObject* python_type() override { return PyExc_AttributeError; } @@ -322,7 +329,7 @@ struct PyWarningHandler { /** Call if an exception has been thrown - * Necessary to determine if it is safe to throw from the desctructor since + * Necessary to determine if it is safe to throw from the destructor since * std::uncaught_exception is buggy on some platforms and generally * unreliable across dynamic library calls. */ @@ -369,6 +376,8 @@ auto wrap_pybind_function_impl_( END_HANDLE_TH_ERRORS_PYBIND }; } + +PyObject* _new_accelerator_error_object(const c10::AcceleratorError&); } // namespace detail // Wrap a function with TH error and warning handling. diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3c23782e74c2fc..d60602f2086f6f 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #ifndef _MSC_VER @@ -280,6 +281,8 @@ static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) { virtual ~Baz() = default; }; Baz x{}; + // Purposely cast through `void*` so there's no fixups applied. + // NOLINTNEXTLINE(bugprone-casting-through-void,-warnings-as-errors) auto y = static_cast(static_cast(&x)); auto rc = y->bar(); return THPUtils_packInt32(rc); @@ -583,8 +586,11 @@ static PyObject* THPModule_getCpuCapability( END_HANDLE_TH_ERRORS } -static void DLPack_Capsule_Destructor(PyObject* data) { - if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) { +namespace { + +template +void DLPack_Capsule_Destructor(PyObject* data) { + if (C10_LIKELY(!PyCapsule_IsValid(data, at::DLPackTraits::capsule))) { // early out, see DLPack spec: if a consuming library sets the capsule // name to something else, they own it and we don't need to do anything return; @@ -594,23 +600,36 @@ static void DLPack_Capsule_Destructor(PyObject* data) { // since consuming libraries should rename the capsule according to spec. // Note that this cannot set a python error (we checked validity above), // so we don't need to handle python error state here. - DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + T* tensor = (T*)PyCapsule_GetPointer(data, at::DLPackTraits::capsule); // the dlMTensor has not been consumed, call deleter ourselves. // DLPack spec mentions that deleter may be NULL, but deleter from // `at::toDLPack` is never NULL, so no need for an additional check here. - dlMTensor->deleter(dlMTensor); + tensor->deleter(tensor); END_HANDLE_TH_ERRORS_RET() } -static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { +template +PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) { HANDLE_TH_ERRORS TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data)); - return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor); + auto tensor = at::DLPackTraits::toDLPack(THPVariable_Unpack(data)); + return PyCapsule_New( + tensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); END_HANDLE_TH_ERRORS } +} // namespace + +static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { + return THPModule_toDLPackImpl(_unused, data); +} + +static PyObject* THPModule_toDLPackVersioned( + PyObject* _unused, + PyObject* data) { + return THPModule_toDLPackImpl(_unused, data); +} + static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { using namespace torch::autograd; HANDLE_TH_ERRORS @@ -667,10 +686,12 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { } static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS if (at::globalContext().allowTF32CuDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS } static PyObject* THPModule_setFloat32MatmulPrecision( @@ -691,6 +712,7 @@ static PyObject* THPModule_setFloat32MatmulPrecision( static PyObject* THPModule_float32MatmulPrecision( PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS std::string s = "highest"; auto p = at::globalContext().float32MatmulPrecision(); if (p == at::Float32MatmulPrecision::HIGH) { @@ -699,6 +721,7 @@ static PyObject* THPModule_float32MatmulPrecision( s = "medium"; } return THPUtils_packString(s); + END_HANDLE_TH_ERRORS } static PyObject* THPModule_setSDPPriorityOrder( PyObject* _unused, @@ -1113,10 +1136,12 @@ static PyObject* THPModule_setAllowTF32CuBLAS( static PyObject* THPModule_allowTF32CuBLAS( PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS if (at::globalContext().allowTF32CuBLAS()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS } static PyObject* THPModule_setAllowFP16ReductionCuBLAS( @@ -1248,6 +1273,7 @@ static PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) { "but got ", THPUtils_typename(arg)); auto qengine = THPUtils_unpackLong(arg); + // NOLINTNEXTLINE(clang-analyzer-optin.core.EnumCastOutOfRange) at::globalContext().setQEngine(static_cast(qengine)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1664,6 +1690,7 @@ static std::initializer_list TorchMethods = { METH_NOARGS, nullptr}, {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, + {"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", @@ -1742,6 +1769,7 @@ static std::initializer_list TorchMethods = { {nullptr, nullptr, 0, nullptr}}; #ifdef USE_CUDA +// NOLINTBEGIN(misc-use-internal-linkage) void THCPStream_init(PyObject* module); void THCPEvent_init(PyObject* module); void THCPGraph_init(PyObject* module); @@ -1750,6 +1778,7 @@ PyMethodDef* THCPModule_methods(); namespace torch::cuda { void initModule(PyObject* module); } // namespace torch::cuda +// NOLINTEND(misc-use-internal-linkage) #endif #ifdef USE_XPU @@ -1791,6 +1820,66 @@ class WeakTensorRef { } }; +namespace { + +using SigHandler = void (*)(int); + +SigHandler* _getOldHandler(int signum) { +#define SIG_CHECK(n) \ + if (signum == (n)) { \ + static SigHandler handler = nullptr; \ + return &handler; \ + } + + SIG_CHECK(SIGSEGV); + SIG_CHECK(SIGILL); + + throw std::runtime_error("unexpected signal number"); +#undef SIG_CHECK +} + +extern "C" void _signalHandler(int signum) { + // Note that technically there's not much you're allowed to do here - but + // we're probably dying anyway so give it a try... + + auto oldAction = *_getOldHandler(signum); + *_getOldHandler(signum) = nullptr; + + // If we hit another signal don't run this handler again. + std::signal(signum, oldAction); + +#ifdef _WIN32 + const char* signame = ""; +#else + const char* signame = strsignal(signum); +#endif + + fprintf( + stderr, + "Process %d crashed with signal %s (%d):\n", + getpid(), + signame, + signum); + + auto bt = c10::get_backtrace(); + fwrite(bt.data(), 1, bt.size(), stderr); + + // Try to run the previous signal handler + if (oldAction != SIG_IGN && oldAction != SIG_DFL) { + oldAction(signum); + } + if (oldAction != SIG_IGN) { + _exit(-1); + } +} + +void _initCrashHandler() { + *_getOldHandler(SIGILL) = std::signal(SIGILL, _signalHandler); + *_getOldHandler(SIGSEGV) = std::signal(SIGSEGV, _signalHandler); +} + +} // anonymous namespace + extern "C" TORCH_PYTHON_API PyObject* initModule(); // separate decl and defn for msvc error C2491 PyObject* initModule() { @@ -1969,6 +2058,7 @@ PyObject* initModule() { }); auto py_module = py::reinterpret_borrow(module); + py_module.def("_initCrashHandler", &_initCrashHandler); py_module.def("_demangle", &c10::demangle); py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython); py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython); @@ -2012,6 +2102,12 @@ Call this whenever a new thread is created in order to propagate values from return at::caching::is_cached_tensor(t); }); + py_module.def("_storage_Use_Count", [](size_t storage_impl_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; + return c10::raw::weak_intrusive_ptr::use_count(storage_impl); + }); + ASSERT_TRUE( set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False)); @@ -2287,6 +2383,18 @@ Call this whenever a new thread is created in order to propagate values from at::DataPtr(reinterpret_cast(data_ptr), device)); }); + py_module.def( + "_get_fp32_precision_getter", [](std::string backend, std::string op) { + return at::globalContext().float32Precision(backend, op); + }); + + py_module.def( + "_set_fp32_precision_setter", + [](std::string backend, std::string op, std::string precision) { + at::globalContext().setFloat32Precision(backend, op, precision); + return precision; + }); + py_module.def( "_stash_obj_in_tls", [](const std::string& key, py::handle arg) { at::impl::ThreadLocalPythonObjects::get_state().set( @@ -2365,7 +2473,7 @@ Call this whenever a new thread is created in order to propagate values from auto acc = at::getAccelerator(check.value_or(false)); if (acc.has_value()) { bool is_available = at::globalContext() - .getAcceleratorHooksInterface(acc.value()) + .getAcceleratorHooksInterface(acc) .isAvailable(); if (!is_available) { diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 094645030b7702..e0a54ad1765992 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -161,17 +161,55 @@ static PyObject* wrap_tuple_fn(Args... args) { return result.release(); } +static PyObject* THPSize_concat(PyObject* left, PyObject* right) { + // wrap tuple's sq_concat with a customized error message + HANDLE_TH_ERRORS + TORCH_CHECK_TYPE( + PyTuple_Check(right), + "can only concatenate tuple (not ", + Py_TYPE(right)->tp_name, + ") to torch.Size"); + static binaryfunc tuple_concat = PyTuple_Type.tp_as_sequence->sq_concat; + static binaryfunc size_concat = + wrap_tuple_fn; + return size_concat(left, right); + END_HANDLE_TH_ERRORS +} + +static PyObject* THPSize_add(PyObject* left, PyObject* right) { + /* NOTE: The python interpreter tries, in order: + * 1. right.nb_add(left, right) (only if right is a subclass of left) + * 2. left.nb_add(left, right) + * 3. right.nb_add(left, right) + * 4. left.sq_concat(right) + * Hence, to support tuple + size -> size, we need to implement nb_add. + */ + HANDLE_TH_ERRORS + if (!PyTuple_Check(left) || !PyTuple_Check(right)) { + Py_RETURN_NOTIMPLEMENTED; + } + return THPSize_concat(left, right); + END_HANDLE_TH_ERRORS +} + +// Needed to ensure tuple + size returns a size instead of a tuple +static PyNumberMethods THPSize_as_number = { + &THPSize_add, // nb_add + nullptr, // nb_subtract + nullptr, // nb_multiply + // ... rest nullptr +}; + // We use an anonymous namespace instead of static to work around // (what @peterjc123 think is) a bug in Visual Studio namespace { -auto sq_concat = PyTuple_Type.tp_as_sequence->sq_concat; auto sq_repeat = PyTuple_Type.tp_as_sequence->sq_repeat; binaryfunc mp_subscript = PyTuple_Type.tp_as_mapping->mp_subscript; } // namespace static PySequenceMethods THPSize_as_sequence = { nullptr, /* sq_length */ - wrap_tuple_fn, + &THPSize_concat, /* sq_concat */ wrap_tuple_fn, nullptr, /* sq_item */ nullptr, /* sq_slice */ @@ -242,7 +280,7 @@ PyTypeObject THPSizeType = { nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ (reprfunc)THPSize_repr, /* tp_repr */ - nullptr, /* tp_as_number */ + &THPSize_as_number, /* tp_as_number */ &THPSize_as_sequence, /* tp_as_sequence */ &THPSize_as_mapping, /* tp_as_mapping */ nullptr, /* tp_hash */ diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 8074406cdcad98..d566dc666ebfed 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -101,7 +101,7 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { // If the StorageImpl has a PyObject that is managed by a different // interpreter than the current one, create a new StorageImpl that points to // the same data and then create the Python storage from that. - // NOTE: This is only supposed to happen in MultiPy + // NOTE: This is only supposed to happen in MultiPy // codespell:ignore if (pyobj_slot->has_pyobj_nonhermetic() && !pyobj_slot->check_interpreter(getPyInterpreter())) { return THPStorage_NewWithStorage( diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index 614ea9d6f5d269..9f7d667613dc59 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -331,8 +331,7 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) { _ref_counter = PyBytes_FromString((sent_data->handle()).c_str()); _ref_counter_offset = THPUtils_packUInt64(sent_data->offset()); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaIpcEventHandle_t ipc_event_handle; + cudaIpcEventHandle_t ipc_event_handle{}; if (sent_data->event_sync_required_) { C10_CUDA_CHECK( diff --git a/torch/csrc/api/include/torch/data/dataloader/stateless.h b/torch/csrc/api/include/torch/data/dataloader/stateless.h index cdd4c2cc069c82..07bf3302054424 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateless.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateless.h @@ -15,7 +15,7 @@ namespace torch::data { /// A dataloader for stateless datasets. /// /// This dataloader follows the traditional PyTorch dataloader design, whereby a -/// (posssibly) stateful sampler produces *batch requests* for a stateless +/// (possibly) stateful sampler produces *batch requests* for a stateless /// dataset, which acts as a simple batch request to batch mapping. The batch /// request will often be an array of indices, and if the dataset is a simple /// image dataset, the dataset would produce the images at those indices. diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h index a32a7b21b569e8..1eba537c44c281 100644 --- a/torch/csrc/api/include/torch/data/datasets/chunk.h +++ b/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -234,7 +234,7 @@ class BatchDataBuffer { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) ExampleSampler& example_sampler_; - // configurable maximun number of elements the queue can hold at one time. + // configurable maximum number of elements the queue can hold at one time. size_t queue_capacity_; // When set to true, it wakes the writer threads from the wait and exit @@ -286,7 +286,7 @@ struct ChunkDatasetOptions { /// The capacity of the queue for batch caching. TORCH_ARG(size_t, cache_size) = 2048; - // The number of chunks to perfrom cross-chunk shuffling. Default to 1 meaning + // The number of chunks to perform cross-chunk shuffling. Default to 1 meaning // no cross-chunk shuffling. When it is equal to n (n > 1), n random // chunks will be loaded at once and example shuffling will be performed // across all those n chunks. @@ -303,9 +303,10 @@ struct ChunkDatasetOptions { /// /// Unlike regular dataset, chunk dataset require two samplers to operate and /// keeps an internal state. `ChunkSampler` selects, which chunk to load next, -/// while the `ExampleSampler` determins the order of Examples that are returned -/// in each `get_batch` call. The hierarchical sampling approach used here is -/// inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf +/// while the `ExampleSampler` determines the order of Examples that are +/// returned in each `get_batch` call. The hierarchical sampling approach used +/// here is inspired by this paper +/// http://martin.zinkevich.org/publications/nips2010.pdf template < typename ChunkReader, typename ChunkSampler = samplers::RandomSampler, @@ -346,7 +347,7 @@ class ChunkDataset final } /// Default get_batch method of BatchDataset. This method returns - /// Example batches created from the preloaded chunks. The implemenation + /// Example batches created from the preloaded chunks. The implementation /// is dataset agnostic and does not need overriding in different chunk /// datasets. BatchType get_batch(size_t batch_size) override { diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h index 67c1ad5ea7cbe5..ebaf40848abcef 100644 --- a/torch/csrc/api/include/torch/data/samplers/base.h +++ b/torch/csrc/api/include/torch/data/samplers/base.h @@ -24,7 +24,7 @@ class Sampler { /// Resets the `Sampler`'s internal state. /// Typically called before a new epoch. - /// Optionally, accepts a new size when reseting the sampler. + /// Optionally, accepts a new size when resetting the sampler. virtual void reset(std::optional new_size) = 0; /// Returns the next index if possible, or an empty optional if the diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 7737cbe5cfd020..49de1c8af63f33 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -344,7 +344,7 @@ namespace detail { inline Tensor glu(const Tensor& input, int64_t dim) { TORCH_CHECK( input.dim() != 0, - "glu does not suppport scalars because halving size must be even"); + "glu does not support scalars because halving size must be even"); return torch::glu(input, dim); } } // namespace detail diff --git a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h index 16c9c94489b0d4..246ed8abb633bf 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h +++ b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -130,7 +130,7 @@ class ModuleDictImpl : public Cloneable { return modules_.is_empty(); } - /// Check if the centain parameter with the key in the `ModuleDict`. + /// Check if the certain parameter with the key in the `ModuleDict`. bool contains(const std::string& key) const noexcept { return modules_.contains(key); } diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h b/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h index df6d003750ab98..008d790fdece11 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h @@ -107,7 +107,7 @@ class ParameterDictImpl : public Cloneable { parameters_.clear(); } - /// Check if the centain parameter with the key in the ParameterDict + /// Check if the certain parameter with the key in the ParameterDict bool contains(const std::string& key) const noexcept { return parameters_.contains(key); } diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index f399ac63d5e7ec..6495d532c32ce1 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -101,7 +101,7 @@ class TORCH_API InputArchive final { std::vector keys(); /// Forwards all arguments to `read()`. - /// Useful for generic code that can be re-used for both `InputArchive` and + /// Useful for generic code that can be reused for both `InputArchive` and /// `OutputArchive` (where `operator()` forwards to `write()`). template void operator()(Ts&&... ts) { diff --git a/torch/csrc/api/include/torch/serialize/output-archive.h b/torch/csrc/api/include/torch/serialize/output-archive.h index 29052bfe6c6874..f47aca4df95a51 100644 --- a/torch/csrc/api/include/torch/serialize/output-archive.h +++ b/torch/csrc/api/include/torch/serialize/output-archive.h @@ -66,7 +66,7 @@ class TORCH_API OutputArchive final { void save_to(const std::function& func); /// Forwards all arguments to `write()`. - /// Useful for generic code that can be re-used for both `OutputArchive` and + /// Useful for generic code that can be reused for both `OutputArchive` and /// `InputArchive` (where `operator()` forwards to `read()`). template void operator()(Ts&&... ts) { diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index c755c61b751005..1030cf182438a3 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -19,7 +19,7 @@ TransformerEncoderLayerImpl::TransformerEncoderLayerImpl( void TransformerEncoderLayerImpl::reset() { // NOTE: reset() is for initializing the model only, calling reset() after the - // model is created will throw exceptionss. Call reset_parameter() if the + // model is created will throw exceptions. Call reset_parameter() if the // created model needs a reset self_attn = this->register_module( diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index e8497a7f22b567..fae54d1248476b 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index aaaadc49672999..908a980cfee9c1 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -823,7 +823,7 @@ Tensor prod_backward( if (input.dim() == 0) { return grad; } - dim = at::maybe_wrap_dim(dim, static_cast(input.sym_sizes().size())); + dim = at::maybe_wrap_dim(dim, input.dim()); if (!keepdim) { // `prod` reduces the dimension at `dim`, // so, unsqueeze `grad` and `result` at dim. @@ -876,8 +876,8 @@ Tensor logsumexp_backward( IntArrayRef dim, bool keepdim) { if (!keepdim && self.dim() != 0) { - grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); - result = unsqueeze_multiple(result, dim, self.sym_sizes().size()); + grad = unsqueeze_multiple(grad, dim, self.dim()); + result = unsqueeze_multiple(result, dim, self.dim()); } return grad * (self - result).exp().conj(); } @@ -894,7 +894,8 @@ Tensor logcumsumexp_backward( // Reference: // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 - auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( + auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, at::typeMetaToScalarType(grad.dtype()), "logcumsumexp_backward", @@ -1888,7 +1889,7 @@ Tensor var_backward( } auto dim = dim_opt.value(); if (!keepdim && self.dim() > 1) { - grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); + grad = unsqueeze_multiple(grad, dim, self.dim()); } const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim)); return (c10::SymFloat(2.0) / (rnumel - correction)) * grad * @@ -2904,7 +2905,7 @@ Tensor softplus_double_backward( // 4. Return the as_strided view of the storage tensor using input geometry. // // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to -// roughly detech overlapping memory. +// roughly detect overlapping memory. // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // @@ -2994,7 +2995,7 @@ Tensor softplus_double_backward( // Now that we established the above claim (***), we consider the // view operation as first sorting the dimensions (i.e., blocks), // apply the original view (since it only cares dimensions being -// consecutive and contiguous withtin each block), and then undo +// consecutive and contiguous within each block), and then undo // the sort. // // Consider a single block B in the output, @@ -3046,7 +3047,7 @@ Tensor softplus_double_backward( // size'[i] <= floor(size[i] / k) // // If size'[i] = 1, invariant is obviously satisfied as we are -// just removing a dimension (afte step (1)). +// just removing a dimension (after step (1)). // // Assume size'[i] > 1. // @@ -4718,10 +4719,10 @@ static Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) { // reductions were done with keepdim=True static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) { auto src_expanded = src; - while (src_expanded.sizes().size() < target.sizes().size() - 1) { + while (src_expanded.dim() < target.dim() - 1) { src_expanded = src_expanded.unsqueeze(1); } - if (src_expanded.sizes().size() == target.sizes().size() - 1) { + if (src_expanded.dim() == target.dim() - 1) { src_expanded = src_expanded.unsqueeze(0); } return src_expanded; @@ -4732,7 +4733,7 @@ static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) { // do a straight expansion because it won't follow the broadcasting rules. static Tensor expand_as_dim1(const Tensor& src, const Tensor& target) { auto src_expanded = src; - while (src_expanded.sizes().size() < target.sizes().size() - 1) { + while (src_expanded.dim() < target.dim() - 1) { src_expanded = src_expanded.unsqueeze(1); } return src_expanded.expand_as(target); @@ -5244,7 +5245,7 @@ bool any_variable_defined(const variable_list& variables) { // Derivations for the householder_product.backward method. // // Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1, -// ..., tau_k, the torch.linalg.householder_product computes the firt n columns +// ..., tau_k, the torch.linalg.householder_product computes the first n columns // of the following product: Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k // v_k^H). Let // H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(sigma_1) ... @@ -5648,7 +5649,7 @@ std::tuple ormqr_backward( // left = false and transpose = true is very much similar with just // transposed arguments passed into householder_product_backward. // Ormqr computes B = H_1 * ... * H_k * A. - // The sensivity wrt H_i is given by (see notes in + // The sensitivity wrt H_i is given by (see notes in // householder_product_backward) Tr(H_i_plus B B_grad^H H_i_minus dH_i), // so, since householder_product_backward respects `for i in range(k)`, we // could reuse householder_product_backward with diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index 1c6e1d29e010e7..3690751ed196b4 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -278,7 +278,7 @@ static void general_trace_function( tracer::addOutput(node, iter->toTensorList()); } else { throw std::runtime_error( - "unsupported ouptut list type: " + elem_type->str()); + "unsupported output list type: " + elem_type->str()); } } else if (type->kind() == TypeKind::ClassType) { AT_ASSERT(iter->isObject()); diff --git a/torch/csrc/autograd/anomaly_mode.h b/torch/csrc/autograd/anomaly_mode.h index e29d1bbf054cb5..b3412b6b9e588f 100644 --- a/torch/csrc/autograd/anomaly_mode.h +++ b/torch/csrc/autograd/anomaly_mode.h @@ -30,7 +30,7 @@ struct TORCH_API AnomalyMode { /// /// Anomaly detection mode is useful for debugging problems happening /// in the backward, such as unexpectedly modified tensors or NaNs -/// occuring in the backward. +/// occurring in the backward. /// /// The enabling of anomaly mode is global - as soon as there is one /// such guard, it is enabled for all computation and threads. It also diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index c705ba11d5e932..b1ef5b3a76a42d 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -53,7 +53,7 @@ using at::Tensor; // // This layout constraint is ensured in the `set_fw_grad` function below -// More complex cases arrise when non-dual Tensor interact with dual Tensors. +// More complex cases arise when non-dual Tensor interact with dual Tensors. // The two most important cases are: // // # Have: @@ -222,7 +222,7 @@ void AutogradMeta::set_fw_grad( if (utils::has_same_meta(new_grad, base) && utils::has_same_meta(new_grad, self)) { // TODO extend this special case to when the underlying storage of - // new_grad can be re-used. + // new_grad can be reused. new_base_fw_grad = new_grad; } else { new_base_fw_grad = diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 8b5d0536df0e66..d51f07093213ae 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -75,17 +75,17 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) { std::atomic the_compiled_autograd = nullptr; #define COMPILED_AUTOGRAD_POISON \ reinterpret_cast(1) -std::atomic num_threads_in_backwards; +std::atomic num_threads_in_compiled_autograd; struct CompiledAutogradThreadingDebugCheck { CompiledAutogradThreadingDebugCheck() { - num_threads_in_backwards++; + num_threads_in_compiled_autograd++; } ~CompiledAutogradThreadingDebugCheck() { release(); } void release() { if (std::exchange(incremented, false)) { - num_threads_in_backwards--; + num_threads_in_compiled_autograd--; } } @@ -611,7 +611,7 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { } } -// Reentrant call will re-use the graph_task's owner thread ready_queue for +// Reentrant call will reuse the graph_task's owner thread ready_queue for // queueing tasks (NOTE: this is not true in the async_mode of the engine). // While we can create separate ready queue for each new reentrant // thread, but sharing the same cpu_ready_queue with parent thread is a @@ -1228,7 +1228,7 @@ void Engine::evaluate_function( } static uint64_t compute_min_topological_nr(const edge_list& outputs) { - // Computes the mininum topological number among all the outputs + // Computes the minimum topological number among all the outputs if (outputs.empty()) { return 0; } @@ -1299,8 +1299,6 @@ auto Engine::execute( "your parameters to None after use to break the cycle and avoid the leak."); } - // Allows us to assert no other threads are in backwards - CompiledAutogradThreadingDebugCheck _thread_check; auto compiled_autograd = the_compiled_autograd.load(); TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON); @@ -1347,6 +1345,11 @@ auto Engine::execute( } if (compiled_autograd != nullptr) { + TORCH_CHECK_NOT_IMPLEMENTED( + num_threads_in_compiled_autograd.load() == 0, + "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call."); + // Allows us to assert no other threads are in backwards + CompiledAutogradThreadingDebugCheck _thread_check; // see [Note: Compiled Autograd] _thread_check.release(); GraphTaskGuard guard(graph_task); @@ -1471,7 +1474,7 @@ c10::intrusive_ptr Engine::execute_with_graph_task( return graph_task->future_result_; } -// note that when python is present, this base engine will be overriden +// note that when python is present, this base engine will be overridden // with a PythonEngine. Because this typically happens before get_default_engine // is called, this base engine will never be created. Engine& Engine::get_base_engine() { @@ -1495,8 +1498,8 @@ void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) { } auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON); TORCH_CHECK( - num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON, - "compiled_autograd._enable() requires no threads in backwards()"); + prior != COMPILED_AUTOGRAD_POISON, + "compiled_autograd._enable() does not support multiple Python threads"); the_compiled_autograd.store(fn); } diff --git a/torch/csrc/autograd/forward_grad.h b/torch/csrc/autograd/forward_grad.h index 9b111ac6b48489..a8c242e08d8694 100644 --- a/torch/csrc/autograd/forward_grad.h +++ b/torch/csrc/autograd/forward_grad.h @@ -27,7 +27,7 @@ struct ForwardGrad; // - Ensure that we can keep the level that we expose to the user API simple // (an integer // that represents the nesting depth) while avoiding confusions when the -// level index is re-used. +// level index is reused. // The important external APIs from this file are: // - ForwardADLevel::get_next_idx() that can be used to enter a new level and diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 106ff5ee0f2f31..fba950bbcec559 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -67,7 +67,7 @@ TORCH_API std::shared_ptr get_current_node(); // or more input `Variable`s and producing zero or more output `Variable`s. All // functions in PyTorch's autograd machinery derive from this class and // override its `apply` method. Instances of such subclasses will then be -// invokable via the call operator. +// invocable via the call operator. // // Nodes in the Autograd Graph //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -592,10 +592,10 @@ struct TORCH_API Node : std::enable_shared_from_this { // 1) Extract tensors/symint args // 2) Collect node information for specialization and caching // Implementations in subclasses should call args.collect() with all node - // attrs. These functions are only called durring backward. + // attrs. These functions are only called during backward. virtual void compiled_args(CompiledNodeArgs& args) const { - throw std::runtime_error( - std::string("compiled_args not implemented: ") + name()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("compiled_args not implemented: ") + name()); } // Used by compiled autograd to call apply() with different saved tensors @@ -604,8 +604,8 @@ struct TORCH_API Node : std::enable_shared_from_this { virtual variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) { - throw std::runtime_error( - std::string("apply_with_saved not implemented: ") + name()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("apply_with_saved not implemented: ") + name()); } // If this node is the AOTBackward node produced by torch.compile. diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 08d0b8d4c4cc54..c72aac4fbecf15 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -24,9 +24,10 @@ struct TORCH_API FunctionPreHook { // only implemented for python hooks, registers hook with compiled autograd virtual void compiled_args( torch::dynamo::autograd::CompiledNodeArgs& args) const { - throw std::runtime_error( + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + - typeid(*this).name()); + typeid(*this).name()); } }; @@ -38,9 +39,10 @@ struct TORCH_API FunctionPostHook { // only implemented for python hooks, registers hook with compiled autograd virtual void compiled_args( torch::dynamo::autograd::CompiledNodeArgs& args) const { - throw std::runtime_error( + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + - typeid(*this).name()); + typeid(*this).name()); } }; @@ -51,17 +53,19 @@ struct TORCH_API PostAccumulateGradHook { // autograd virtual void compiled_args( torch::dynamo::autograd::CompiledNodeArgs& args) const { - throw std::runtime_error( - std::string("not yet implemented for compiled autograd: ") + - typeid(*this).name()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, + std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + + typeid(*this).name()); } virtual void apply_with_saved( Variable&, torch::dynamo::autograd::SwapSavedVariables&) { - throw std::runtime_error( - std::string("not yet implemented for compiled autograd: ") + - typeid(*this).name()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, + std::string("compiled_args nyi, see [Note: Compiled Autograd] ") + + typeid(*this).name()); } }; diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp index c415d7131b33a0..2e2c96c464f80b 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.cpp +++ b/torch/csrc/autograd/functions/accumulate_grad.cpp @@ -14,24 +14,23 @@ namespace torch::autograd { -// AccumulateGrad sets sequence_nr to the max value so it's always called -// ASAP during backwards. -AccumulateGrad::AccumulateGrad(Variable variable_) - : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) { - add_input_metadata(variable); -} +using torch::dynamo::autograd::IValuePacker; -// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) -auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { +namespace { + +void AccumulateGrad_apply_impl( + variable_list&& grads, + at::Tensor& variable, + at::Tensor& variable_grad, + int64_t num_expected_refs, + const std::function& grad_update, + std::mutex* mutex = nullptr) { check_input_variables("AccumulateGrad", grads, 1, 0); if (!grads[0].defined()) - return {}; - if (variable.grad_fn()) - throw std::logic_error( - "leaf variable has been moved into the graph interior"); + return; if (!variable.requires_grad()) - return {}; + return; // std::move(grads[0]) to avoid bumping up refcount at::Tensor new_grad = std::move(grads[0]); @@ -41,22 +40,82 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { // when updating the gradients. We don't ensure thread safety on hooks // and rely on user to provide thread safe hooks // see Note [Thread Safety on Autograd Node] - std::lock_guard lock(mutex_); + // need to still lock for eager here + std::optional> lock; + if (mutex != nullptr) { + lock.emplace(*mutex); + } + + AccumulateGrad::accumulateGrad( + variable, variable_grad, new_grad, num_expected_refs, grad_update); +} + +variable_list AccumulateGrad_apply_functional_no_hooks_ivalue( + const variable_list& grads, + const ivalue_list& args) { + PackedArgs r(args); + auto variable = r.unpack(); + auto variable_grad = r.unpack(); + auto has_post_hooks = r.unpack(); + + // Functional Tensors insert an Error node to assert that backward is never + // called + if (variable.grad_fn() && + std::dynamic_pointer_cast(variable.grad_fn()) == nullptr) { + throw std::logic_error( + "leaf variable has been moved into the graph interior"); + } - at::Tensor& grad = variable.mutable_grad(); + at::Tensor functional_grad; + AccumulateGrad_apply_impl( + variable_list(grads), + variable, + variable_grad, + 1 + has_post_hooks, + [&functional_grad](at::Tensor&& grad_update) { + functional_grad = std::move(grad_update); + }, + nullptr // no mutex needed since this is executed under a single thread + ); + if (!functional_grad.defined()) { + // In-place accumulation (Case 2.3) does not execute grad_update + functional_grad = std::move(variable_grad); + } + return {functional_grad}; +} +} // namespace + +// AccumulateGrad sets sequence_nr to the max value so it's always called +// ASAP during backwards. +AccumulateGrad::AccumulateGrad(Variable variable_) + : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) { + add_input_metadata(variable); +} + +// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) +auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { + if (variable.grad_fn()) { + throw std::logic_error( + "leaf variable has been moved into the graph interior"); + } + + at::Tensor& variable_grad = variable.mutable_grad(); // If the function has post hooks (for example, a DDP allreduce hook), // call_function in Engine.cpp will temporarily bump the expected refcount - // by one, hence the addition of !post_hooks().empty() for 'num_expected_refs' - // in addition to the one reference that we're holding. + // by one, hence the addition of !post_hooks().empty() for + // 'num_expected_refs' in addition to the one reference that we're holding. // 'num_expected_refs' is used to determine whether or not we should clone // the grad or can steal the grad. - accumulateGrad( + AccumulateGrad_apply_impl( + std::move(grads), variable, - grad, - new_grad, + variable_grad, 1 + !post_hooks().empty() /* num_expected_refs */, - [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); }); + [&variable_grad](at::Tensor&& grad_update) { + variable_grad = std::move(grad_update); + }, + &mutex_); auto& hook = tensor_post_acc_grad_hooks(); if (hook != nullptr) { @@ -77,6 +136,7 @@ void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const { hook->compiled_args(args); } } + variable_list AccumulateGrad::apply_with_saved( const variable_list& grads, SwapSavedVariables& saved) { @@ -91,10 +151,29 @@ variable_list AccumulateGrad::apply_with_saved( saved.before(grad_copy); variable_copy.mutable_grad() = grad_copy; + // name() includes namespace for historical reasons: + // torch::autograd::AcumulateGrad For Compiled Autograd, we just want the op + // name without the namespace + std::string name = "AccumulateGrad"; + // proxy a call to torch.ops.inductor.accumulate_grad_.default - const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); - pyinterface->call_accumulate_grad( - saved.get_py_compiler(), variable_copy, grads[0]); + static bool flag [[maybe_unused]] = [&]() { + std::vector schema = { + IValuePacker::packed_type(), + IValuePacker::packed_type(), + IValuePacker::packed_type()}; + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->bind_function( + saved.get_py_compiler(), + name, + AccumulateGrad_apply_functional_no_hooks_ivalue, + schema); + return true; + }(); + + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->call_accumulate_grad( + saved.get_py_compiler(), variable_copy, grads[0], !post_hooks().empty()); auto& hook = tensor_post_acc_grad_hooks(); if (hook != nullptr) { diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index b1768ee2a93c57..97e689d36050c9 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -58,11 +58,7 @@ struct TORCH_API AccumulateGrad : public Node { return impl::post_acc_grad_hooks(variable); } - // Given a variable with its current grad as variable_grad, accumulates - // new_grad into variable_grad if in place accumulation is possible. - // Otherwise, uses 'update_grad' to update the grad for the variable. - - // "Gradient Layout Contract" + // Note: Gradient Layout Contract // // AccumulateGrad tries to stash strided (non-sparse) grads with memory layout // (strides) such that variables and grads interact efficiently in later @@ -101,6 +97,68 @@ struct TORCH_API AccumulateGrad : public Node { // degraded performance in Reducer.cpp or optimizer kernels, not death by // assert or silently bad numerics. + // Gradient Accumulation + // Given a variable with its current grad as variable_grad, accumulates + // new_grad into variable_grad if in place accumulation is possible. + // Otherwise, uses 'update_grad' to update the grad for the variable. + // + // Branch breakdown: + // - Case 1: Param has no existing grad + // - Case 1.1: Stealable dense new_grad + // . We aren't setting up for double-backward. + // . No other user-visible tensor references new_grad. + // . new_grad obeys the "Gradient Layout Contract", there has a special + // case, For MKLDNN tensor, which is a opaque tensor, assuming it obeys + // layout_contract. + // - Case 1.2: Stealable sparse new_grad + // . Can't detach sparse tensor (since metadata changes are not allowed + // after detach), so just create a new one for the grad which is a + // shallow copy. We need a shallow copy so that modifying the original + // grad tensor doesn't modify the grad we accumulate. + // . We only skip clone if indices and values themselves are contiguous + // for backward compatibility reasons. Since without this optimization, + // earlier we would clone the entire SparseTensor which cloned indices + // and values. For details see + // https://github.com/pytorch/pytorch/issues/34375. + // - Case 1.3: Cloning sparse/nested new_grad + // - Case 1.4: Cloning MKLDNN new_grad + // - Case 1.5: Deep copies new_grad according to the Gradient Layout + // Contract. + // - Case 2: Param has existing grad and grad mode is not enabled + // - This case is not strictly necessary, but it makes the first-order only + // case slightly more efficient. + // - Case 2.1: Sparse variable_grad + Dense new_grad + // . If `variable_grad` is sparse and `new_grad` is not sparse, their + // sum is not sparse, and we must change the TensorImpl type of + // `variable_grad` for it to store the result. However, changing the + // TensorImpl type of a tensor requires changing the tensor itself, and + // thus in this case we have to change the grad tensor. + // - Case 2.2: Vmap-incompatible + // . Ideally we'd perform an in-place operation to avoid changing + // the grad tensor. However, if that's impossible because the grads + // are vmap-incompatible (See NOTE: [vmap-incompatible in-place + // operations]), then we just add them out-of-place. + // - Case 2.3: In-place addition + // . In this case we can avoid changing the grad tensor. There are three + // scenarios when we'll hit this case: + // . `variable_grad` is sparse, and `new_grad` is sparse. + // . `variable_grad` is dense, and `new_grad` is sparse. + // . `variable_grad` is dense, and `new_grad` is dense. + // . `variable_grad` is mkldnn, and `new_grad` is mkldnn. + // + // In all of these four cases, `variable_grad += new_grad` is a + // valid operation which adds `new_grad` to `variable_grad` in + // place. `variable_grad` is thus still referring to the same tensor + // after the operation. + // . DistributedDataParallel(DDP) package relies on grad being + // mutated in place for saving peak memory usage. DDP will still + // work correctly if it is mutated out of place here, but DDP will + // maintain one extra copy of grad tensors in buffer and thus + // increase peak memory usage. + // - Case 3: Param has existing grad and grad mode is enabled + // - Case 3.1: Sparse variable_grad + Dense new_grad + // - Case 3.2: Not Sparse variable_grad + Dense new_grad + // // variable: the variable whose grad we're accumulating. // variable_grad: the current grad for the variable. // new_grad: new grad we want to accumulate for the variable. @@ -125,13 +183,7 @@ struct TORCH_API AccumulateGrad : public Node { at::caching::adjusted_use_count(new_grad) <= num_expected_refs && (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) { - // we aren't setting up for double-backward - // not sparse - // no other user-visible tensor references new_grad - // new_grad obeys the "Gradient Layout Contract", there has a special - // case, For MKLDNN tensor, which is a opaque tensor, assuming it obeys - // layout_contract. Under these conditions, we can steal new_grad - // without a deep copy. + // See Case 1.1: Stealable dense new_grad update_grad(new_grad.detach()); } else if ( !GradMode::is_enabled() && new_grad.is_sparse() && @@ -142,16 +194,7 @@ struct TORCH_API AccumulateGrad : public Node { new_grad._indices().use_count() <= 1 && new_grad._values().use_count() <= 1 && new_grad.use_count() <= num_expected_refs) { - // Can't detach sparse tensor (since metadata changes are not allowed - // after detach), so just create a new one for the grad which is a - // shallow copy. We need a shallow copy so that modifying the original - // grad tensor doesn't modify the grad we accumulate. - // We only skip clone if indices and values themselves are contiguous - // for backward compatibility reasons. Since without this optimization, - // earlier we would clone the entire SparseTensor which cloned indices - // and values. - // For details see https://github.com/pytorch/pytorch/issues/34375. - + // Case 1.2: Stealable sparse new_grad // No scenario where we expect this to be true currently TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !at::caching::is_cached_tensor(new_grad._indices()) && @@ -166,54 +209,33 @@ struct TORCH_API AccumulateGrad : public Node { } else { if (new_grad.is_sparse() || new_grad.is_sparse_csr() || new_grad.is_nested()) { + // Case 1.3: Cloning sparse/nested new_grad update_grad(new_grad.clone()); } else { if (new_grad.is_mkldnn()) { + // Case 1.4: Cloning MKLDNN new_grad update_grad(new_grad.clone()); } else { - // Deep copies new_grad according to the "Gradient Layout Contract." + // Case 1.5: Deep copies new_grad according to the "Gradient + // Layout Contract." update_grad(utils::clone_obey_contract(new_grad, variable)); } } } } else if (!GradMode::is_enabled()) { - // This case is not strictly necessary, but it makes the first-order only - // case slightly more efficient. + // Case 2: Param has existing grad and grad mode is not enabled if (variable_grad.is_sparse() && !new_grad.is_sparse()) { - // If `variable_grad` is sparse and `new_grad` is not sparse, their - // sum is not sparse, and we must change the TensorImpl type of - // `variable_grad` for it to store the result. However, changing the - // TensorImpl type of a tensor requires changing the tensor itself, and - // thus in this case we have to change the grad tensor. + // Case 2.1: Sparse variable_grad + Dense new_grad auto result = new_grad + variable_grad; CHECK_RESULT(result, variable); update_grad(std::move(result)); } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) { - // Ideally we'd perform an in-place operation to avoid changing - // the grad tensor. However, if that's impossible because the grads - // are vmap-incompatible (See NOTE: [vmap-incompatible in-place - // operations]), then we just add them out-of-place. + // Case 2.2: Vmap-incompatible auto result = variable_grad + new_grad; CHECK_RESULT(result, variable); update_grad(std::move(result)); } else { - // In this case we can avoid changing the grad tensor. There are three - // scenarios when we'll hit this case: - // - // 1. `variable_grad` is sparse, and `new_grad` is sparse. - // 2. `variable_grad` is dense, and `new_grad` is sparse. - // 3. `variable_grad` is dense, and `new_grad` is dense. - // 4. `variable_grad` is mkldnn, and `new_grad` is mkldnn. - // - // In all of these four cases, `variable_grad += new_grad` is a - // valid operation which adds `new_grad` to `variable_grad` in - // place. `variable_grad` is thus still referring to the same tensor - // after the operation. - // Also DistributedDataParallel(DDP) package relies on grad being - // mutated in place for saving peak memory usage. DDP will still - // work correctly if it is mutated out of place here, but DDP will - // maintain one extra copy of grad tensors in buffer and thus - // increase peak memory usage. + // Case 2.3: In-place addition variable_grad += new_grad; CHECK_RESULT(variable_grad, variable); // ^ We could enforce the contract more aggressively here by writing: @@ -231,12 +253,15 @@ struct TORCH_API AccumulateGrad : public Node { // which may break user code. } } else { + // Case 3: Param has existing grad and grad mode is enabled at::Tensor result; if (variable_grad.is_sparse() && !new_grad.is_sparse()) { - // CPU backend throws an error on sparse + dense, so prefer dense + - // sparse here. + // Case 3.1: Sparse variable_grad + Dense new_grad + // CPU backend throws an error on sparse + dense, so + // prefer dense + sparse here. result = new_grad + variable_grad; } else { + // Case 3.2: Not Sparse variable_grad + Dense new_grad // Assumes operator+ result typically matches strides of first arg, // and hopes variable_grad was originally created obeying layout // contract. diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp index a310be58e2882f..af5763df659a0b 100644 --- a/torch/csrc/autograd/functions/basic_ops.cpp +++ b/torch/csrc/autograd/functions/basic_ops.cpp @@ -21,7 +21,7 @@ variable_list Error::apply(variable_list&& inputs) const { } void Error::compiled_args(CompiledNodeArgs& args) const { - // throw the error durring collect, the graph won't get compiled + // throw the error during collect, the graph won't get compiled apply(variable_list()); } diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 0494f7f89a8b32..2f0027d59fa1f5 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -71,7 +71,10 @@ inline void set_history( // If the codegen triggers this, you most likely want to add your newly // added function to the DONT_REQUIRE_DERIVATIVE list in // tools/autograd/gen_variable_type.py - TORCH_INTERNAL_ASSERT(isDifferentiableType(variable.scalar_type())); + TORCH_CHECK( + isDifferentiableType(variable.scalar_type()), + "Autograd not support dtype: ", + variable.scalar_type()); auto output_nr = grad_fn->add_input_metadata(variable); impl::set_gradient_edge(variable, {grad_fn, output_nr}); } else { diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 4e31bc42d96d04..380060501882f8 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -307,7 +307,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { e.activityType() == (uint8_t)libkineto::ActivityType::GPU_USER_ANNOTATION; }) - .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }); + .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }) + // whether the event is hidden + .def("is_hidden_event", [](const KinetoEvent& e) { + return e.isHiddenEvent(); + }); m.def("_soft_assert_raises", &setSoftAssertRaises); m.def("_get_sequence_nr", &at::sequence_number::peek); @@ -605,7 +609,7 @@ static PyObject* set_autocast_enabled( HANDLE_TH_ERRORS static PythonArgParser parser( {"set_autocast_enabled(std::string_view device_type, bool enabled)", - "set_autocast_enabled(bool enabled)"}); // this signature is depracated. + "set_autocast_enabled(bool enabled)"}); // this signature is deprecated. ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); // Set at::kCUDA as default value to prevent BC-breaking changes. @@ -628,7 +632,7 @@ static PyObject* is_autocast_enabled( HANDLE_TH_ERRORS static PythonArgParser parser( {"is_autocast_enabled(std::string_view device_type)", - "is_autocast_enabled()"}); // this signature is depracated. + "is_autocast_enabled()"}); // this signature is deprecated. ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); // Set at::kCUDA as default value to prevent BC-breaking changes. diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 1448292d88415a..6880caddc8d250 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -622,7 +622,7 @@ void prepareProfiler( /* * Sending a warning and passing the non-standard event to the backend * Backend can abort if the event is not supported. - * TODO Should we gracefully drop the invalid event if we have atleast one + * TODO Should we gracefully drop the invalid event if we have at least one * valid? */ auto is_standard_event = [](const std::string& event) -> bool { @@ -690,17 +690,20 @@ void toggleCollectionDynamic( const bool enable, const std::set& activities) { if (activities.count(torch::autograd::profiler::ActivityType::CPU) > 0 && - activities.count(torch::autograd::profiler::ActivityType::CUDA) == 0) { + (activities.count(torch::autograd::profiler::ActivityType::CUDA) == 0 || + activities.count(torch::autograd::profiler::ActivityType::XPU) == 0)) { LOG(WARNING) - << "Toggling CPU activity with CUDA activity on may result in traces with CUDA events on artibrary tracks"; + << "Toggling CPU activity with GPU activity on may result in traces with GPU events on artibrary tracks"; } else if ( - activities.count(torch::autograd::profiler::ActivityType::CUDA) > 0 && + (activities.count(torch::autograd::profiler::ActivityType::CUDA) > 0 || + activities.count(torch::autograd::profiler::ActivityType::XPU) > 0) && activities.count(torch::autograd::profiler::ActivityType::CPU) == 0) { LOG(WARNING) - << "Toggling CUDA activity with CPU activity on may result in traces with incorrect correlation between CPU and CUDA events"; + << "Toggling GPU activity with CPU activity on may result in traces with incorrect correlation between CPU and GPU events"; } for (auto act : activities) { - if (act == torch::autograd::profiler::ActivityType::CUDA) { + if (act == torch::autograd::profiler::ActivityType::CUDA || + act == torch::autograd::profiler::ActivityType::XPU) { torch::profiler::impl::kineto::toggleCollectionDynamic(enable); } else if (act == torch::autograd::profiler::ActivityType::CPU) { toggleCPUCollectionDynamic(enable); @@ -933,6 +936,10 @@ bool KinetoEvent::hasKwinputs() const { return !kwinputs_.empty(); } +bool KinetoEvent::isHiddenEvent() const { + return result_ && result_->hidden_; +} + const std::unordered_map KinetoEvent::kwinputs() const { return kwinputs_; diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 2e4b89da4b7958..34d65a0b8dd6b0 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -37,6 +37,7 @@ struct TORCH_API KinetoEvent { bool hasConcreteInputs() const; const c10::ArrayRef concreteInputs() const; bool hasKwinputs() const; + bool isHiddenEvent() const; const std::unordered_map kwinputs() const; uint64_t flops() const; int64_t sequenceNr() const; diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 1cd22fabb7fe0d..32f2cc34cf3deb 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -451,12 +451,12 @@ static PyObject* THPEngine_new( // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMethodDef THPEngine_methods[] = { - {(char*)"run_backward", + {"run_backward", castPyCFunctionWithKeywords(THPEngine_run_backward), METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr}, - {(char*)"is_checkpoint_valid", + {"queue_callback", THPEngine_queue_callback, METH_O, nullptr}, + {"is_checkpoint_valid", THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr}, diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 669bc30d5f3a22..dcbbcd550e2ab9 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -60,6 +60,20 @@ PyObject* THPGradientEdgeClass = nullptr; // Anonymous namespace for helpful functions used in this file namespace { +inline void check_legacy_fn_attr_access( + const std::shared_ptr& cdata, + const char* attr) { + TORCH_CHECK( + cdata, + "Attribute '", + attr, + "' is invalid for this instance of _C._FunctionBase. " + "Accessing this attribute directly on an instance of autograd.Function " + "is a legacy access pattern that is no longer supported. For examples " + "on how to use new‑style autograd functions, see " + "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); +} + // TODO: We shouldn't need to call this function because the engine // can already persist the errors for us. This still seems to be // needed for the DistEngine however. @@ -246,7 +260,6 @@ auto PyNode::apply_with_saved_impl( Py_CLEAR(py_fn->compiled_autograd_backward_state); } THPObjectPtr r(PyObject_CallMethod( - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) saved.get_py_compiler(), "proxy_call_backward", "OOOiOO", @@ -1142,13 +1155,7 @@ PyObject* process_outputs( PyObject* THPFunction_name(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS auto cdata = ((THPFunction*)self)->cdata.lock(); - TORCH_CHECK( - cdata, - "Attribute 'name' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + check_legacy_fn_attr_access(cdata, "name"); return THPUtils_packString(cdata->name()); END_HANDLE_TH_ERRORS } @@ -1156,6 +1163,7 @@ PyObject* THPFunction_name(PyObject* self, PyObject* noargs) { PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS; auto cdata = ((THPFunction*)self)->cdata.lock(); + check_legacy_fn_attr_access(cdata, "_sequence_nr"); return THPUtils_packUInt64(cdata->sequence_nr()); END_HANDLE_TH_ERRORS } @@ -1163,6 +1171,7 @@ PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) { PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) { HANDLE_TH_ERRORS; auto cdata = ((THPFunction*)self)->cdata.lock(); + check_legacy_fn_attr_access(cdata, "_set_sequence_nr"); cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1171,6 +1180,7 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) { PyObject* THPFunction_input_metadata(PyObject* self, void* unused) { HANDLE_TH_ERRORS; auto cdata = ((THPFunction*)self)->cdata.lock(); + check_legacy_fn_attr_access(cdata, "_input_metadata"); const auto num_inputs = cdata->num_inputs(); THPObjectPtr list(PyTuple_New(num_inputs)); if (!list) { @@ -1388,13 +1398,7 @@ PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) { new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr())); auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); - TORCH_CHECK( - cdata, - "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + check_legacy_fn_attr_access(cdata, "_register_hook_dict"); cdata->add_tensor_pre_hook(std::move(hook)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1404,13 +1408,7 @@ PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) { HANDLE_TH_ERRORS auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); - TORCH_CHECK( - cdata, - "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + check_legacy_fn_attr_access(cdata, "register_hook"); return torch::autograd::registerFunctionHook(*cdata, hook); END_HANDLE_TH_ERRORS } @@ -1419,13 +1417,7 @@ PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) { HANDLE_TH_ERRORS auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); - TORCH_CHECK( - cdata, - "Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + check_legacy_fn_attr_access(cdata, "register_prehook"); return torch::autograd::registerFunctionPreHook(*cdata, hook); END_HANDLE_TH_ERRORS } @@ -1568,13 +1560,7 @@ PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) { PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS auto cdata = self->cdata.lock(); - TORCH_CHECK( - cdata, - "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + check_legacy_fn_attr_access(cdata, "next_functions"); const auto num_outputs = cdata->num_outputs(); THPObjectPtr result(PyTuple_New(num_outputs)); if (!result) diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 93a7d9e3cb9220..1236fad45f3690 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -621,6 +621,14 @@ void initTorchFunctions(PyObject* module) { auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); return impl->was_inductor_storage_resized(); }); + py_module.def( + "_functionalize_inductor_storage_resized_counter", + [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(t)); + auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); + return impl->inductor_storage_resized_counter(); + }); py_module.def( "_functionalize_are_all_mutations_hidden_from_autograd", [](const at::Tensor& t) { @@ -698,6 +706,11 @@ void initTorchFunctions(PyObject* module) { auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); return t_impl->has_data_mutation(); }); + py_module.def("_functionalize_mutation_counter", [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); + auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); + return t_impl->mutation_counter(); + }); py_module.def( "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) { TORCH_INTERNAL_ASSERT( @@ -707,16 +720,24 @@ void initTorchFunctions(PyObject* module) { auto size = wrapper->get_storage_size(/*before=*/before); return size; }); - py_module.def("_functionalize_set_storage_changed", [](const at::Tensor& t) { + py_module.def("_functionalize_mark_storage_changed", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t); - wrapper->set_storage_changed(); + wrapper->mark_storage_changed(); }); py_module.def("_functionalize_was_storage_changed", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t); return wrapper->was_storage_changed(); }); + py_module.def( + "_functionalize_storage_changed_counter", [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(t)); + auto t_impl = + at::functionalization::impl::unsafeGetFunctionalWrapper(t); + return t_impl->storage_changed_counter(); + }); py_module.def( "_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) { // Forcefully/unsafely dumps src.storage into dst. diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index af31d5911f6e27..b0235da869fbcc 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -238,7 +238,7 @@ void registerPythonTensorClass( c10::Device dev(device); TORCH_CHECK( - dev.type() == kXLA, "Only the python class for XLA can be overriden"); + dev.type() == kXLA, "Only the python class for XLA can be overridden"); if (device_to_py_class_[static_cast(dev.type())] != nullptr) { TORCH_WARN( "Overriding a previously registered python class for ", dev.str()); @@ -409,13 +409,13 @@ static bool THPVariable_tryResurrect(THPVariable* self) { static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) { TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_traverse function was not overriden properly"); + false, "TensorBase tp_traverse function was not overridden properly"); return 0; } static int THPFake_clear(THPVariable* self) { TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_clear function was not overriden properly"); + false, "TensorBase tp_clear function was not overridden properly"); return 0; } @@ -817,6 +817,7 @@ static PyObject* THPVariable_get_python_dispatch( // - static Tensor fn(const Tensor&); // - This function calls the relevant ATen on the tensor template +// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility) struct GetterBase { static PyObject* getter(THPVariable* self, void* /*unused*/) { HANDLE_TH_ERRORS @@ -2330,7 +2331,7 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) { return -1; } - // It is important for all three of these to be overriden correctly for the + // It is important for all three of these to be overridden correctly for the // resurrection checks to properly happen. In particular, an older version // was not overriding tp_clear here. This lead to the default subtype_clear // running on the Tensor object (as only TensorBase tp_clear was custom), diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index ae1780e66ba71d..3aa241b06f3a36 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -302,9 +303,23 @@ static bool treatSequenceAsTuple(PyObject* index) { } if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) || PySlice_Check(obj.get())) { + TORCH_WARN( + "Using a non-tuple sequence for " + "multidimensional indexing is deprecated and will be changed in " + "pytorch 2.9; use x[tuple(seq)] instead of " + "x[seq]. In pytorch 2.9 this will be interpreted as tensor index, " + "x[torch.tensor(seq)], which will result either in an error or a " + "different result"); return true; } if (obj.get() == Py_Ellipsis || obj.get() == Py_None) { + TORCH_WARN( + "Using a non-tuple sequence for " + "multidimensional indexing is deprecated and will be changed in " + "pytorch 2.9; use x[tuple(seq)] instead of " + "x[seq]. In pytorch 2.9 this will be interpreted as tensor index, " + "x[torch.tensor(seq)], which will result either in an error or a " + "different result"); return true; } } diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 54a478707b7752..5a1b7be0a15ddd 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -186,7 +186,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator std::function end_allocate_to_pool_fn_; std::function relase_pool_fn_; std::mutex allocator_mutex_; - // We do the bookeeping here in order to simplify custom allocators + // We do the bookkeeping here in order to simplify custom allocators std::unordered_map allocation_metadata_; bool initialized_ = false; diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index 82b5dc55d79d8a..bc5780ad3ea02f 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -23,19 +23,21 @@ static PyObject* THCPEvent_pynew( unsigned char enable_timing = 0; unsigned char blocking = 0; unsigned char interprocess = 0; + unsigned char external = 0; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) constexpr const char* kwlist[] = { - "enable_timing", "blocking", "interprocess", nullptr}; + "enable_timing", "blocking", "interprocess", "external", nullptr}; if (!PyArg_ParseTupleAndKeywords( args, kwargs, - "|bbb", + "|bbbb", // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(kwlist), &enable_timing, &blocking, - &interprocess)) { + &interprocess, + &external)) { return nullptr; } @@ -47,7 +49,8 @@ static PyObject* THCPEvent_pynew( THCPEvent* self = (THCPEvent*)ptr.get(); unsigned int flags = (blocking ? cudaEventBlockingSync : cudaEventDefault) | (enable_timing ? cudaEventDefault : cudaEventDisableTiming) | - (interprocess ? cudaEventInterprocess : cudaEventDefault); + (interprocess ? cudaEventInterprocess : cudaEventDefault) | + (external ? cudaEventExternal : cudaEventDefault); new (&self->cuda_event) at::cuda::CUDAEvent(flags); @@ -89,8 +92,7 @@ static PyObject* THCPEvent_from_ipc_handle( } THCPEvent* self = (THCPEvent*)ptr.get(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaIpcEventHandle_t handle; + cudaIpcEventHandle_t handle{}; std::memcpy(&handle, handle_string.c_str(), handle_string.size()); new (&self->cuda_event) at::cuda::CUDAEvent(device.index(), &handle); @@ -172,8 +174,7 @@ static PyObject* THCPEvent_synchronize(PyObject* _self, PyObject* noargs) { static PyObject* THCPEvent_ipc_handle(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THCPEvent*)_self; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaIpcEventHandle_t handle; + cudaIpcEventHandle_t handle{}; self->cuda_event.ipc_handle(&handle); return PyBytes_FromStringAndSize((const char*)&handle, sizeof(handle)); END_HANDLE_TH_ERRORS diff --git a/torch/csrc/cuda/GdsFile.cpp b/torch/csrc/cuda/GdsFile.cpp index e46287d9a0ae72..ac304e9617ae77 100644 --- a/torch/csrc/cuda/GdsFile.cpp +++ b/torch/csrc/cuda/GdsFile.cpp @@ -21,7 +21,7 @@ std::string cuGDSFileGetErrorString(T status) { : std::string(c10::utils::str_error(errno)); } -// To get error message for Buf/Handle registeration APIs that return +// To get error message for Buf/Handle registration APIs that return // CUfileError_t template < class T, @@ -91,8 +91,7 @@ void gds_deregister_buffer(const at::Storage& storage) { int64_t gds_register_handle(int fd) { CUfileDescr_t cf_descr; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - CUfileHandle_t cf_handle; + CUfileHandle_t cf_handle{}; memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t)); cf_descr.handle.fd = fd; cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index 827cfec858a52a..377a92667a30f7 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -26,7 +26,7 @@ void THCPGraph_init(PyObject* module) { torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle); shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph") - .def(py::init<>()) + .def(py::init(), py::arg("keep_graph") = false) .def( "capture_begin", [](::at::cuda::CUDAGraph& self, @@ -56,6 +56,9 @@ void THCPGraph_init(PyObject* module) { .def( "capture_end", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end)) + .def( + "instantiate", + torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::instantiate)) .def( "register_generator_state", [](::at::cuda::CUDAGraph& self, py::handle raw_generator) { @@ -87,5 +90,16 @@ void THCPGraph_init(PyObject* module) { "debug_dump", torch::wrap_pybind_function_no_gil( &::at::cuda::CUDAGraph::debug_dump), - py::arg("debug_path")); + py::arg("debug_path")) + .def( + "raw_cuda_graph", + [](::at::cuda::CUDAGraph& self) { + cudaGraph_t graph = self.raw_cuda_graph(); + // We return a raw int here, since otherwise pybind11 will + // try to return the underlying struct of cudaGraph_t + // points to, which is opaque and therefore causes a + // compile error. + return reinterpret_cast(graph); + }, + py::call_guard()); } diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp index b651a4b5e68aa0..feb22e360bb985 100644 --- a/torch/csrc/cuda/MemPool.cpp +++ b/torch/csrc/cuda/MemPool.cpp @@ -16,12 +16,15 @@ void THCPMemPool_init(PyObject* module) { .def( py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator, bool is_user_created, - bool use_on_oom) { + bool use_on_oom, + bool symmetric) { torch::utils::device_lazy_init(at::kCUDA); return std::make_shared<::c10::cuda::MemPool>( - allocator, is_user_created, use_on_oom); + allocator, is_user_created, use_on_oom, symmetric); })) .def_property_readonly("id", &::c10::cuda::MemPool::id) + .def_property_readonly( + "is_symmetric", &::c10::cuda::MemPool::is_symmetric) .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator) .def("use_count", &::c10::cuda::MemPool::use_count); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index cbb8910fff57cc..b44ce311ecd924 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -64,7 +64,7 @@ PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) { auto device = THPUtils_unpackLong(arg); torch::utils::device_lazy_init(at::kCUDA); - c10::cuda::set_device(static_cast(device)); + c10::cuda::set_device(static_cast(device), /*force*/ true); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1125,7 +1125,7 @@ static void registerCudaDeviceProperties(PyObject* module) { m.def( "_cuda_record_memory_history_legacy", - static_cast( + static_cast( torch::cuda::_record_memory_history)); m.def( @@ -1136,6 +1136,7 @@ static void registerCudaDeviceProperties(PyObject* module) { const std::string&, size_t, bool, + bool, bool)>(torch::cuda::_record_memory_history)); m.def("_cuda_isHistoryEnabled", []() { @@ -1352,12 +1353,6 @@ static void registerCudaPluggableAllocator(PyObject* module) { return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter()); }); - m.def("_storage_Use_Count", [](size_t storage_impl_ptr) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; - return c10::raw::weak_intrusive_ptr::use_count(storage_impl); - }); - m.def( "_tensors_data_ptrs_at_indices_equal", [](py::list& tensors, py::list& data_ptrs, py::list& indices) { diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index f8ee958e4dc17b..3abd4acddc7960 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -129,8 +129,11 @@ CapturedTraceback* getFromContext( "attempting to gather stack context from the wrong StackContext type."); } -at::CallbackHandle _initRecordAnnotations() { - return at::addGlobalCallback( +#define ADD_CALLBACK(callbackType) at::add##callbackType##Callback +at::CallbackHandle _initRecordAnnotations(bool useGlobalCallback) { + auto addCallback = + useGlobalCallback ? ADD_CALLBACK(Global) : ADD_CALLBACK(ThreadLocal); + return addCallback( at::RecordFunctionCallback( [](const at::RecordFunction& fn) -> std::unique_ptr { @@ -169,12 +172,16 @@ at::CallbackHandle _initCompileContexts() { .scopes({at::RecordScope::FUNCTION})); } -void setRecordFunctionCallbacks(bool enabled, bool compileContext) { +void setRecordFunctionCallbacks( + bool enabled, + bool compileContext, + bool globalRecordAnnotations) { // Handle Callbacks under mutex auto lock = callbackManager.lockCallbackMutex(); if (enabled) { if (callbackManager.getAnnotationHandle() == 0) { - callbackManager.setAnnotationHandle(_initRecordAnnotations()); + callbackManager.setAnnotationHandle( + _initRecordAnnotations(globalRecordAnnotations)); } if (compileContext && callbackManager.getCompileContextHandle() == 0) { callbackManager.setCompileContextHandle(_initCompileContexts()); @@ -184,7 +191,7 @@ void setRecordFunctionCallbacks(bool enabled, bool compileContext) { at::removeCallback(callbackManager.getAnnotationHandle()); callbackManager.setAnnotationHandle(0); } - if (compileContext && callbackManager.getCompileContextHandle() != 0) { + if (callbackManager.getCompileContextHandle() != 0) { at::removeCallback(callbackManager.getCompileContextHandle()); callbackManager.setCompileContextHandle(0); } @@ -200,7 +207,8 @@ void _record_memory_history( bool trace_alloc_record_context, bool record_cpp_context, bool clearHistory, - bool compileContext) { + bool compileContext, + bool globalRecordAnnotations) { c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather; if (enabled && record_cpp_context && (trace_alloc_record_context || record_context)) { @@ -216,7 +224,7 @@ void _record_memory_history( } at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - setRecordFunctionCallbacks(enabled, compileContext); + setRecordFunctionCallbacks(enabled, compileContext, globalRecordAnnotations); c10::cuda::CUDACachingAllocator::recordHistory( enabled, recorder, trace_alloc_max_entries, when, clearHistory); } @@ -235,7 +243,8 @@ void _record_memory_history( const std::string& stacks, size_t max_entries, bool clearHistory, - bool compileContext) { + bool compileContext, + bool globalRecordAnnotations) { if (enabled) { checkOptionIn( *enabled, @@ -269,7 +278,8 @@ void _record_memory_history( } } at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - setRecordFunctionCallbacks(enabled.has_value(), compileContext); + setRecordFunctionCallbacks( + enabled.has_value(), compileContext, globalRecordAnnotations); c10::cuda::CUDACachingAllocator::recordHistory( enabled.has_value(), recorder, max_entries, when, clearHistory); } diff --git a/torch/csrc/cuda/memory_snapshot.h b/torch/csrc/cuda/memory_snapshot.h index 5d89f2f6534937..fc366f424292e2 100644 --- a/torch/csrc/cuda/memory_snapshot.h +++ b/torch/csrc/cuda/memory_snapshot.h @@ -16,7 +16,8 @@ TORCH_CUDA_CU_API void _record_memory_history( bool trace_alloc_record_context = false, bool record_cpp_context = false, bool clearHistory = false, - bool compileContext = false); + bool compileContext = false, + bool globalRecordAllocations = false); TORCH_CUDA_CU_API void _record_memory_history( std::optional enabled = "all", @@ -24,7 +25,8 @@ TORCH_CUDA_CU_API void _record_memory_history( const std::string& stacks = "all", size_t max_entries = SIZE_MAX, bool clearHistory = false, - bool compileContext = false); + bool compileContext = false, + bool globalRecordAllocations = false); TORCH_CUDA_CU_API std::string _memory_snapshot_pickled(); diff --git a/torch/csrc/cuda/shared/nvtx.cpp b/torch/csrc/cuda/shared/nvtx.cpp index d13562883bc1ea..d28e8ae222eaa1 100644 --- a/torch/csrc/cuda/shared/nvtx.cpp +++ b/torch/csrc/cuda/shared/nvtx.cpp @@ -41,7 +41,7 @@ static void device_callback_range_start(void* userData) { } static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) { - RangeHandle* handle = (RangeHandle*)calloc(sizeof(RangeHandle), 1); + auto handle = static_cast(calloc(1, sizeof(RangeHandle))); handle->msg = strdup(msg); handle->id = 0; TORCH_CHECK( diff --git a/torch/csrc/deploy/README.md b/torch/csrc/deploy/README.md index c757287f8e1bd2..2d40ca8361ff49 100644 --- a/torch/csrc/deploy/README.md +++ b/torch/csrc/deploy/README.md @@ -1,2 +1,2 @@ -# torch::deploy has been moved to pytorch/multipy -Please check out [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy) to find the new home for torch::deploy. +# torch::deploy has been moved to pytorch/multipy +Please check out [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy) to find the new home for torch::deploy. diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.h b/torch/csrc/distributed/autograd/engine/dist_engine.h index 362c78fa07b1fb..7911462307fb49 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.h +++ b/torch/csrc/distributed/autograd/engine/dist_engine.h @@ -15,7 +15,7 @@ class BackwardPassCleanupGuard; // This is a singleton class responsible for running distributed backward // passes. This engine relies heavily on the vanilla autograd engine and tries -// to re-use it as much as possible. This class is mostly responsible for the +// to reuse it as much as possible. This class is mostly responsible for the // distributed aspects of autograd and tries to hook into the autograd engine // where convenient. diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp index 19db3671c7decd..52e3465f85aba5 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp @@ -45,7 +45,7 @@ RpcWithProfilingReq::RpcWithProfilingReq( tensors_(std::move(tensors)), profilerConfig_(std::move(profilerConfig)), profilingKeyId_(profilingKeyId) { - TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cant be null"); + TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc can't be null"); } rpc::MessageType RpcWithProfilingReq::wrappedMessageType() const { diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 6fcbd4ad86f67f..862c983d9e0503 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -323,7 +323,7 @@ FileStore::~FileStore() { auto numFinishedWorker = addHelper(cleanupKey_, 1); auto refCount = addHelper(refCountKey_, -1); // The last worker cleans up the file. If numWorkers was not initialized to - // a specific postive value (i.e. meaning that there was not a fixed number + // a specific positive value (i.e. meaning that there was not a fixed number // of workers), we don't attempt to clean. // Clean up the file if number of references is 0. if (refCount == 0 && numWorkers_ >= 0 && numFinishedWorker >= numWorkers_) { diff --git a/torch/csrc/distributed/c10d/FlightRecorder.cpp b/torch/csrc/distributed/c10d/FlightRecorder.cpp index f0f9f3064e0ede..bc47f40c6dc61d 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.cpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.cpp @@ -1,133 +1,7 @@ -#ifdef USE_C10D_NCCL -#include -#include -#endif // USE_C10D_NCCL - -#include -#include -#include -#include -#include - -#include - -#include -#ifdef USE_C10D_NCCL -#include -#endif // USE_C10D_NCCL -#include +#include namespace c10d { -#ifdef USE_C10D_NCCL -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - const auto& params = req.params(); - size_t validParamCount = 0; - - // valid params - const std::string includeCollectivesStr = "includecollectives"; - const std::string includeStackTracesStr = "includestacktraces"; - const std::string onlyActiveStr = "onlyactive"; - - std::unordered_map processedParams = { - {includeCollectivesStr, true}, - {includeStackTracesStr, true}, - {onlyActiveStr, false}}; - - for (const auto& [paramName, paramValue] : params) { - auto it = processedParams.find(paramName); - if (it != processedParams.end()) { - validParamCount++; - if (paramValue == "true") { - it->second = true; - } else if (paramValue == "false") { - it->second = false; - } else { - res.setStatus(400); - res.setContent( - "Invalid value for " + paramName + - " valid values are true or false", - "text/plain"); - return; - } - } - } - if (validParamCount < params.size()) { - res.setStatus(400); - res.setContent( - "Invalid parameters - unexpected param passed in", "text/plain"); - return; - } - res.setContent( - dump_nccl_trace( - processedParams[includeCollectivesStr], - processedParams[includeStackTracesStr], - processedParams[onlyActiveStr]), - "application/octet-stream"); - }}; - -control_plane::RegisterHandler jsonDumpHandler{ - "dump_nccl_trace_json", - [](const control_plane::Request& req, control_plane::Response& res) { - const auto& params = req.params(); - size_t validParamCount = 0; - - // valid params - const std::string includeCollectivesStr = "includecollectives"; - const std::string onlyActiveStr = "onlyactive"; - - std::unordered_map processedParams = { - {includeCollectivesStr, true}, {onlyActiveStr, false}}; - - for (const auto& [paramName, paramValue] : params) { - auto it = processedParams.find(paramName); - if (it != processedParams.end()) { - validParamCount++; - if (paramValue == "true") { - it->second = true; - } else if (paramValue == "false") { - it->second = false; - } else { - res.setStatus(400); - res.setContent( - "Invalid value for " + paramName + - " valid values are true or false", - "text/plain"); - return; - } - } - } - if (validParamCount < params.size()) { - res.setStatus(400); - res.setContent( - "Invalid parameters - unexpected param passed in", "text/plain"); - return; - } - res.setStatus(200); - res.setContent( - dump_nccl_trace_json( - processedParams[includeCollectivesStr], - processedParams[onlyActiveStr]), - "application/json"); - }}; - -/* Helper used by work::getDuration() and nccl flight recorder */ -float getDurationFromEvent( - at::cuda::CUDAEvent& ncclStartEvent, - at::cuda::CUDAEvent& ncclEndEvent) { - TORCH_CHECK( - ncclEndEvent.query(), - "getDuration can only be called after work is succeeded.") - return ncclStartEvent.elapsed_time(ncclEndEvent); -} -#endif // USE_C10D_NCCL - -float getDurationFromEvent(c10::Event& startEvent, c10::Event& endEvent) { - TORCH_CHECK(false, "getDuration not supported by c10::Event."); -} - void DebugInfoWriter::write(const std::string& trace) { // Open a file for writing. The ios::binary flag is used to write data as // binary. @@ -167,9 +41,10 @@ DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { std::filesystem::create_directories(cacheDirPath); auto defaultLocation = cacheDirPath / "nccl_trace_rank_"; + // For internal bc compatibility, we keep the old the ENV check. std::string fileNamePrefix = getCvarString( {"TORCH_FR_DUMP_TEMP_FILE", "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, - defaultLocation.c_str()); + defaultLocation.string().c_str()); // Using std::unique_ptr here to auto-delete the writer object // when the pointer itself is destroyed. std::unique_ptr writerPtr( @@ -190,545 +65,40 @@ void DebugInfoWriter::registerWriter(std::unique_ptr writer) { writer_ = std::move(writer); } -// Returns the traceback of current entry, in string form. -// Note: `getTraceback` invokes `torch::symbolize`, which may need to acquire -// the GIL. If you don't want to block the current thread or take the risk of a -// GIL deadlock, you can use an asynchronous calling mechanism like std::async. -template -std::string FlightRecorder::Entry::getTraceback() { - torch::CapturedTraceback* traceback = traceback_.get(); - torch::SymbolizedTracebacks s_tbs = torch::symbolize({traceback}); - // We use 0 because we only have one traceback here. - const auto& s_tb = s_tbs.tracebacks.at(0); - std::stringstream oss; - for (auto idx : c10::irange(s_tb.size())) { - auto frame_id = s_tb[idx]; - const auto& frame = s_tbs.all_frames.at(frame_id); - oss << "#" << idx << " " << frame.funcname << " from " << frame.filename - << ":" << frame.lineno << '\n'; - } - /* Resulted format is like: - #0 all_reduce from pytorch/torch/distributed/distributed_c10d.py:2696 - #1 wrapper from pytorch/torch/distributed/c10d_logger.py:83 - #2 bar from /home/user/repro.py:15 - #3 foo from /home/user/repro.py:24 - #4 main from /home/user/repro.py:34 - #5 from /home/user/repro.py:40 - */ - return oss.str(); -} - -template -std::optional FlightRecorder::record( - size_t pg_id, - const std::tuple& pg_name, - size_t collective_seq_id, - size_t p2p_seq_id, - size_t op_id, - std::string profiling_name, - const std::vector& inputs, - const std::vector& outputs, - EventType* start, - EventType* end, - std::chrono::milliseconds timeout_ms, - std::shared_ptr pg_status, - bool isP2P) { - if (!enabled_) { - return std::nullopt; - } - if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { - // Current pg_status is not in FR. - all_pg_status_[pg_id] = std::move(pg_status); - } - auto traceback = - torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); - - auto te = Entry{ - id_, - pg_id, - pg_name, - collective_seq_id, - p2p_seq_id, - op_id, - std::move(profiling_name), - std::move(traceback), - start, - end, - c10::getTime(), - timeout_ms.count(), - isP2P, - std::nullopt, - std::nullopt, - std::nullopt, - {}, - {}, - {}, - {}, - {}, - false}; - - for (const auto& input : inputs) { - c10::IntArrayRef sizes = input.sizes(); - te.input_dtypes_.push_back(input.dtype().toScalarType()); - te.input_dims_.push_back(static_cast(sizes.size())); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - for (const auto& output : outputs) { - c10::IntArrayRef sizes = output.sizes(); - te.output_dtypes_.push_back(output.dtype().toScalarType()); - te.output_dims_.push_back(static_cast(sizes.size())); - te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); - } - - if (entries_.size() < max_entries_) { - entries_.emplace_back(std::move(te)); - } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } - } - return id_++; -} - -template -void FlightRecorder::record_pg_ranks( - const std::tuple& pg_name, - std::vector ranks) { - if (!enabled_) { - return; - } - std::lock_guard guard(mutex_); - pg_name_to_ranks_[pg_name] = std::move(ranks); -} - -template -void FlightRecorder::record_accelerator_version( - const std::string nccl_version) { - if (!enabled_) { - return; - } - std::lock_guard guard(mutex_); - nccl_version_ = std::move(nccl_version); -} - -template -void FlightRecorder::update_state(Entry& r) { - try { - if (r.start_ != nullptr) { - bool started = r.start_->query(); - if (started && !r.time_discovered_started_) { - r.time_discovered_started_ = c10::getTime(); - } - } - if (r.end_ != nullptr) { - bool completed = r.end_->query(); - if (completed && !r.time_discovered_completed_) { - r.time_discovered_completed_ = c10::getTime(); - } - } - } catch (std::exception& e) { - LOG(ERROR) << "Failed to update state for entry " << r.id_ << ": " - << r.profiling_name_ << " with error: " << e.what(); - } -} - -template -std::vector::Entry> FlightRecorder< - EventType>::dump_entries() { - std::lock_guard guard(mutex_); - std::vector result; - result.reserve(entries_.size()); - result.insert( - result.end(), - entries_.begin() + static_cast(next_), - entries_.end()); - result.insert( - result.end(), - entries_.begin(), - entries_.begin() + static_cast(next_)); - // query any remaining events - for (auto& r : result) { - update_state(r); - r.start_ = r.end_ = nullptr; - } - return result; -} - -template -// Returns the entry with the given id, if it exists. Otherwise, returns -// std::nullopt. -std::optional::Entry> FlightRecorder< - EventType>::getEntry(std::optional id) { - if (!enabled_ || !id) { - return std::nullopt; - } - - std::unique_lock guard(mutex_); - Entry entry = entries_.at(*id % max_entries_); - if (entry.id_ == *id) { - return entry; - } else { - return std::nullopt; - } -} - -template -void FlightRecorder::retire_id( - std::optional id, - bool compute_duration) { - if (!enabled_ || !id) { - return; - } - - bool can_compute_duration = false; - EventType* startEvent = nullptr; - EventType* endEvent = nullptr; - std::optional duration = std::nullopt; - - std::unique_lock guard(mutex_); - - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { - update_state(*entry); - - if (compute_duration) { - can_compute_duration = entry->time_discovered_completed_.has_value() && - entry->start_ && entry->end_; - startEvent = entry->start_; - endEvent = entry->end_; - } - entry->retired_ = true; - entry->start_ = entry->end_ = nullptr; - } - - if (can_compute_duration) { - // Compute duration without without holding the lock, because - // cudaEventDuration() can hang, and we need to acquire the lock before we - // can dump(), which we never want to block. - guard.unlock(); - duration = getDurationFromEvent(*startEvent, *endEvent); - guard.lock(); - - // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { - LOG(INFO) << "retire_id abandoned for id " << *id - << ", event was overwritten while waiting to compute duration."; - return; - } - if (duration.has_value()) { - entry->duration_ = duration; - } - } -} - -template -const c10::List FlightRecorder::getCollectiveTrace( - bool includeStacktraces, - bool onlyActive) { - auto entries = new_list(); - // Entries are returned in the order they were recorded - auto result = dump_entries(); - std::vector tracebacks; - torch::SymbolizedTracebacks stracebacks; - std::vector all_frames; - if (includeStacktraces) { - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - stracebacks = torch::symbolize(tracebacks); - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); - } - } - for (auto i : c10::irange(result.size())) { - auto dict = new_dict(); - auto& e = result.at(i); - // Skip completed events - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - if (includeStacktraces) { - auto& tb = stracebacks.tracebacks.at(i); - auto frames = new_list(); - for (auto frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); - } - - dict.insert(record_id_key, int64_t(e.id_)); - dict.insert(pg_id_key, int64_t(e.pg_id_)); - dict.insert(pg_name_key, e.pg_name_); - dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); - dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); - dict.insert(op_id_key, int64_t(e.op_id_)); - dict.insert(profiling_name_key, e.profiling_name_); - dict.insert(time_created_key, int64_t(e.time_created_)); - if (e.duration_) { - dict.insert(duration_key, *e.duration_); - } - - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = new_list(); - for (auto dim : dims) { - auto arg_sizes = new_list(); - for ([[maybe_unused]] auto i : c10::irange(dim)) { - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - - dict.insert(input_sizes_key, read_sizes(e.input_dims_)); - std::vector input_dtypes_strs; - input_dtypes_strs.reserve(e.input_dtypes_.size()); - for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.emplace_back(c10::toString(input_dtype)); - } - dict.insert(input_dtypes_key, input_dtypes_strs); - dict.insert(output_sizes_key, read_sizes(e.output_dims_)); - std::vector output_dtypes_strs; - output_dtypes_strs.reserve(e.output_dtypes_.size()); - for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.emplace_back(c10::toString(output_dtype)); - } - dict.insert(output_dtypes_key, output_dtypes_strs); - if (e.time_discovered_completed_.has_value()) { - dict.insert(state_key, completed_state); - } else if (e.time_discovered_started_.has_value()) { - dict.insert(state_key, started_state); - } else { - dict.insert(state_key, scheduled_state); - } - - dict.insert( - time_discovered_started_key, - e.time_discovered_started_.has_value() - ? int64_t(*e.time_discovered_started_) - : c10::IValue()); - dict.insert( - time_discovered_completed_key, - e.time_discovered_completed_.has_value() - ? int64_t(*e.time_discovered_completed_) - : c10::IValue()); - dict.insert(retired_key, e.retired_); - dict.insert(timeout_key, e.timeout_ms_); - dict.insert(is_p2p_key, e.isP2P_); - - entries.push_back(dict); - } - return entries; -} - -template -const c10::Dict FlightRecorder< - EventType>::getPgConfig() { - auto pg_config = new_dict(); - for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { - auto pg_info = new_dict(); - pg_info.insert("name", std::get<0>(pg_name)); - pg_info.insert("desc", std::get<1>(pg_name)); - pg_info.insert("ranks", ranks_str(ranks)); - pg_config.insert(std::get<0>(pg_name), pg_info); - } - return pg_config; -} - -template -const std::map> FlightRecorder< - EventType>::getPgConfigJson() { - std::map> result; - for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { - auto pg_info = std::map(); - pg_info["name"] = std::get<0>(pg_name); - pg_info["desc"] = std::get<1>(pg_name); - pg_info["ranks"] = ranks_str(ranks); - result.emplace(std::get<0>(pg_name), pg_info); - } - return result; -} - -template -const c10::Dict FlightRecorder< - EventType>::getPgStatus() { - auto all_pg_status = new_dict(); - for (const auto& [pg_id, status] : all_pg_status_) { - auto pg_status = new_dict(); - pg_status.insert("last_enqueued_collective", status->lastEnqueuedSeq); - pg_status.insert("last_started_collective", status->lastStartedSeq); - pg_status.insert("last_completed_collective", status->lastCompletedSeq); - all_pg_status.insert(std::to_string(pg_id), pg_status); - } - return all_pg_status; -} - -template -const std::map> FlightRecorder< - EventType>::getPgStatusJson() { - std::map> result; - for (const auto& [pg_id, status] : all_pg_status_) { - auto pg_status = std::map(); - pg_status["last_enqueued_collective"] = - std::to_string(status->lastEnqueuedSeq); - pg_status["last_started_collective"] = - std::to_string(status->lastStartedSeq); - pg_status["last_completed_collective"] = - std::to_string(status->lastCompletedSeq); - result[std::to_string(pg_id)] = pg_status; - } - return result; -} - -using json = nlohmann::json; -template -std::string FlightRecorder::dump_json( - const std::optional>>& extraDumpMap, - bool includeCollectives, - bool onlyActive) { - json result; - result[version_key_str] = version_val_str; - result[nccl_version_key_str] = nccl_version_; - result[pg_config_key_str] = getPgConfigJson(); - result[pg_status_key_str] = getPgStatusJson(); - - // collective trace - if (includeCollectives) { - std::list entries; - for (auto& e : dump_entries()) { - json j; - if (onlyActive && e.time_discovered_completed_.has_value()) { - continue; - } - j[record_id_key_str] = int64_t(e.id_); - j[pg_id_key_str] = int64_t(e.pg_id_); - j[pg_name_key_str] = e.pg_name_; - j[collective_seq_id_key_str] = int64_t(e.collective_seq_id_); - j[p2p_seq_id_key_str] = int64_t(e.p2p_seq_id_); - j[op_id_key_str] = int64_t(e.op_id_); - j[profiling_name_key_str] = e.profiling_name_; - j[time_created_key_str] = int64_t(e.time_created_); - if (e.duration_) { - j[duration_key_str] = *e.duration_; - } - auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = std::list>(); - for (auto dim : dims) { - auto arg_sizes = std::list(); - for (auto i : c10::irange(dim)) { - (void)i; - arg_sizes.push_back(*it++); - } - sizes.push_back(arg_sizes); - } - return sizes; - }; - j[input_sizes_key_str] = read_sizes(e.input_dims_); - std::vector input_dtypes_strs; - input_dtypes_strs.reserve(e.input_dtypes_.size()); - for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.emplace_back(c10::toString(input_dtype)); - } - j[input_dtypes_key_str] = input_dtypes_strs; - j[output_sizes_key_str] = read_sizes(e.output_dims_); - std::vector output_dtypes_strs; - output_dtypes_strs.reserve(e.output_dtypes_.size()); - for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.emplace_back(c10::toString(output_dtype)); - } - j[output_dtypes_key_str] = output_dtypes_strs; - if (e.time_discovered_completed_.has_value()) { - j[state_key_str] = completed_state_str; - } else if (e.time_discovered_started_.has_value()) { - j[state_key_str] = started_state_str; - } else { - j[state_key_str] = scheduled_state_str; - } - j[time_discovered_started_key_str] = - e.time_discovered_started_.has_value() - ? int64_t(*e.time_discovered_started_) - : 0; - j[time_discovered_completed_key_str] = - e.time_discovered_completed_.has_value() - ? int64_t(*e.time_discovered_completed_) - : 0; - j[retired_key_str] = e.retired_; - j[timeout_key_str] = e.timeout_ms_; - j[is_p2p_key_str] = e.isP2P_; - entries.emplace_back(j); - } - - if (!entries.empty()) { - result[entries_key_str] = entries; - } - } - - if (extraDumpMap.has_value()) { - result[nccl_comm_key_str] = extraDumpMap.value(); - } - return result.dump(); -} - -template -std::string FlightRecorder::dump( - const std::optional>>& extraDumpMap, - bool includeCollectives, - bool includeStackTraces, - bool onlyActive) { - STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.FlightRecorder__dump); - auto result = new_dict(); - // common values - result.insert(version_key, version_val); - result.insert(pg_config_key, getPgConfig()); - result.insert(nccl_version_key_str, nccl_version_); - result.insert(pg_status_key, getPgStatus()); - - // collective trace - if (includeCollectives) { - result.insert( - entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); - } - - // convert extraDumpMap into a dictionary - auto per_comm_dict = new_dict(); - if (extraDumpMap.has_value()) { - for (const auto& [ncclId, ncclDump] : extraDumpMap.value()) { - auto inner_dict = new_dict(); - for (const auto& [key, value] : ncclDump) { - inner_dict.insert(key, value); - } - per_comm_dict.insert(ncclId, inner_dict); - } - } - if (!per_comm_dict.empty()) { - result.insert(nccl_comm_key, per_comm_dict); - } - return pickle_str(result); -} - std::unique_ptr DebugInfoWriter::writer_ = nullptr; std::atomic DebugInfoWriter::hasWriterRegistered_(false); +template <> +float getDurationFromEvent( + c10::Event& startEvent, + c10::Event& endEvent) { + TORCH_CHECK(false, "getDuration not supported by c10::Event."); +} + // For any third party library that uses the flight recorder, if one wants to // use an Event type other than c10::Event, one also needs to registers here to // avoid linking errors. template struct FlightRecorder; -#ifdef USE_C10D_NCCL -template struct FlightRecorder; -#endif // USE_C10D_NCCL +std::string dump_fr_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + return FlightRecorder::get()->dump( + std::unordered_map< + std::string, + std::unordered_map>{}, + includeCollectives, + includeStackTraces, + onlyActive); +} + +std::string dump_fr_trace_json(bool includeCollectives, bool onlyActive) { + return FlightRecorder::get()->dump_json( + std::unordered_map< + std::string, + std::unordered_map>{}, + includeCollectives, + onlyActive); +} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index 1aae94c57769c2..768889015fb754 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include namespace c10d { @@ -19,7 +20,7 @@ namespace c10d { // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). -DEFINE_CONSTANT(version_val, "2.7") +DEFINE_CONSTANT(version_val, "2.9") DEFINE_CONSTANT(entries_key, "entries") DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state") DEFINE_CONSTANT(nccl_version_key, "nccl_version") @@ -52,6 +53,8 @@ DEFINE_CONSTANT(time_discovered_completed_key, "time_discovered_completed_ns") DEFINE_CONSTANT(completed_state, "completed") DEFINE_CONSTANT(scheduled_state, "scheduled") DEFINE_CONSTANT(started_state, "started") +DEFINE_CONSTANT(thread_id_key, "thread_id") +DEFINE_CONSTANT(thread_name_key, "thread_name") #undef DEFINE_CONSTANT // Write NCCL debug info to local disk or any storage users define. @@ -142,7 +145,7 @@ struct FlightRecorder { std::optional time_discovered_started_; // timestamp when our CPU threads discovered that the kernel completed. - // will always be _after_ it actually complated, and can be the same time + // will always be _after_ it actually completed, and can be the same time // as the discovery of the start if the watchdog thread is stuck on CUDA // APIs std::optional time_discovered_completed_; @@ -153,6 +156,8 @@ struct FlightRecorder { c10::SmallVector output_dims_; std::vector output_dtypes_; c10::SmallVector sizes_; // flattened from inputs, outputs + std::thread::id thread_id_; + std::string thread_name_; bool retired_ = false; // is this work entry no longer in the workMetaList_? // a retired but not completed event has timed out @@ -161,7 +166,7 @@ struct FlightRecorder { // acquire the GIL. If you don't want to block the current thread or take // the risk of a GIL deadlock, you can use an asynchronous calling mechanism // like std::async. - std::string getTraceback(); + TORCH_API std::string getTraceback(); }; bool enabled_ = false; @@ -191,7 +196,7 @@ struct FlightRecorder { std::shared_ptr pg_status, bool isP2P); - void record_pg_ranks( + TORCH_API void record_pg_ranks( const std::tuple& pg_name, std::vector ranks); @@ -203,7 +208,7 @@ struct FlightRecorder { // Returns the entry with the given id, if it exists. Otherwise, returns // std::nullopt. - std::optional getEntry(std::optional id); + TORCH_API std::optional getEntry(std::optional id); /* Mark an Event as completed and free its events. @@ -215,7 +220,9 @@ struct FlightRecorder { never hang. (timing must also be enabled for compute_duration - see TORCH_NCCL_ENABLE_TIMING). */ - void retire_id(std::optional id, bool compute_duration = true); + TORCH_API void retire_id( + std::optional id, + bool compute_duration = true); const c10::List getCollectiveTrace( bool includeStacktraces, @@ -248,4 +255,18 @@ struct FlightRecorder { bool includeStackTraces, bool onlyActive); }; + +// Dumps the fr traces and additional information about the Process +// Group. +TORCH_API std::string dump_fr_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive); + +// Dumps the fr traces and additional information about the Process +// Group in JSON formatted string. +// We don't include stack traces in JSON format as it is far too much data. +TORCH_API std::string dump_fr_trace_json( + bool includeCollectives, + bool onlyActive); } // namespace c10d diff --git a/torch/csrc/distributed/c10d/FlightRecorderCuda.cpp b/torch/csrc/distributed/c10d/FlightRecorderCuda.cpp new file mode 100644 index 00000000000000..25ac1279d62e92 --- /dev/null +++ b/torch/csrc/distributed/c10d/FlightRecorderCuda.cpp @@ -0,0 +1,122 @@ +#ifdef USE_C10D_NCCL +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace c10d { +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + const auto& params = req.params(); + size_t validParamCount = 0; + + // valid params + const std::string includeCollectivesStr = "includecollectives"; + const std::string includeStackTracesStr = "includestacktraces"; + const std::string onlyActiveStr = "onlyactive"; + + std::unordered_map processedParams = { + {includeCollectivesStr, true}, + {includeStackTracesStr, true}, + {onlyActiveStr, false}}; + + for (const auto& [paramName, paramValue] : params) { + auto it = processedParams.find(paramName); + if (it != processedParams.end()) { + validParamCount++; + if (paramValue == "true") { + it->second = true; + } else if (paramValue == "false") { + it->second = false; + } else { + res.setStatus(400); + res.setContent( + "Invalid value for " + paramName + + " valid values are true or false", + "text/plain"); + return; + } + } + } + if (validParamCount < params.size()) { + res.setStatus(400); + res.setContent( + "Invalid parameters - unexpected param passed in", "text/plain"); + return; + } + res.setContent( + dump_nccl_trace( + processedParams[includeCollectivesStr], + processedParams[includeStackTracesStr], + processedParams[onlyActiveStr]), + "application/octet-stream"); + }}; + +control_plane::RegisterHandler jsonDumpHandler{ + "dump_nccl_trace_json", + [](const control_plane::Request& req, control_plane::Response& res) { + const auto& params = req.params(); + size_t validParamCount = 0; + + // valid params + const std::string includeCollectivesStr = "includecollectives"; + const std::string onlyActiveStr = "onlyactive"; + + std::unordered_map processedParams = { + {includeCollectivesStr, true}, {onlyActiveStr, false}}; + + for (const auto& [paramName, paramValue] : params) { + auto it = processedParams.find(paramName); + if (it != processedParams.end()) { + validParamCount++; + if (paramValue == "true") { + it->second = true; + } else if (paramValue == "false") { + it->second = false; + } else { + res.setStatus(400); + res.setContent( + "Invalid value for " + paramName + + " valid values are true or false", + "text/plain"); + return; + } + } + } + if (validParamCount < params.size()) { + res.setStatus(400); + res.setContent( + "Invalid parameters - unexpected param passed in", "text/plain"); + return; + } + res.setStatus(200); + res.setContent( + dump_nccl_trace_json( + processedParams[includeCollectivesStr], + processedParams[onlyActiveStr]), + "application/json"); + }}; + +/* Helper used by work::getDuration() and nccl flight recorder */ +template <> +float getDurationFromEvent( + at::cuda::CUDAEvent& ncclStartEvent, + at::cuda::CUDAEvent& ncclEndEvent) { + TORCH_CHECK( + ncclEndEvent.query(), + "getDuration can only be called after work is succeeded.") + return ncclStartEvent.elapsed_time(ncclEndEvent); +} + +template struct FlightRecorder; +} // namespace c10d +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp new file mode 100644 index 00000000000000..608b9157ac3911 --- /dev/null +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -0,0 +1,550 @@ +#include + +#include +#include + +#include + +namespace c10d { + +template +float getDurationFromEvent(EventType& start, EventType& end); + +// Returns the traceback of current entry, in string form. +// Note: `getTraceback` invokes `torch::symbolize`, which may need to acquire +// the GIL. If you don't want to block the current thread or take the risk of a +// GIL deadlock, you can use an asynchronous calling mechanism like std::async. +template +std::string FlightRecorder::Entry::getTraceback() { + torch::CapturedTraceback* traceback = traceback_.get(); + torch::SymbolizedTracebacks s_tbs = torch::symbolize({traceback}); + // We use 0 because we only have one traceback here. + const auto& s_tb = s_tbs.tracebacks.at(0); + std::stringstream oss; + for (auto idx : c10::irange(s_tb.size())) { + auto frame_id = s_tb[idx]; + const auto& frame = s_tbs.all_frames.at(frame_id); + oss << "#" << idx << " " << frame.funcname << " from " << frame.filename + << ":" << frame.lineno << '\n'; + } + /* Resulted format is like: + #0 all_reduce from pytorch/torch/distributed/distributed_c10d.py:2696 + #1 wrapper from pytorch/torch/distributed/c10d_logger.py:83 + #2 bar from /home/user/repro.py:15 + #3 foo from /home/user/repro.py:24 + #4 main from /home/user/repro.py:34 + #5 from /home/user/repro.py:40 + */ + return oss.str(); +} + +template +std::optional FlightRecorder::record( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P) { + if (!enabled_) { + return std::nullopt; + } + if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { + // Current pg_status is not in FR. + all_pg_status_[pg_id] = std::move(pg_status); + } + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); + std::lock_guard guard(mutex_); + + auto te = Entry{ + id_, + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + std::move(traceback), + start, + end, + c10::getTime(), + timeout_ms.count(), + isP2P, + std::nullopt, + std::nullopt, + std::nullopt, + {}, + {}, + {}, + {}, + {}, + std::this_thread::get_id(), + c10::getThreadName(), + false}; + + for (const auto& input : inputs) { + c10::IntArrayRef sizes = input.sizes(); + te.input_dtypes_.push_back(input.dtype().toScalarType()); + te.input_dims_.push_back(static_cast(sizes.size())); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + for (const auto& output : outputs) { + c10::IntArrayRef sizes = output.sizes(); + te.output_dtypes_.push_back(output.dtype().toScalarType()); + te.output_dims_.push_back(static_cast(sizes.size())); + te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); + } + + if (entries_.size() < max_entries_) { + entries_.emplace_back(std::move(te)); + } else { + entries_[next_++] = std::move(te); + if (next_ == max_entries_) { + next_ = 0; + } + } + return id_++; +} + +template +void FlightRecorder::record_pg_ranks( + const std::tuple& pg_name, + std::vector ranks) { + if (!enabled_) { + return; + } + std::lock_guard guard(mutex_); + pg_name_to_ranks_[pg_name] = std::move(ranks); +} + +template +void FlightRecorder::record_accelerator_version( + const std::string nccl_version) { + if (!enabled_) { + return; + } + std::lock_guard guard(mutex_); + nccl_version_ = std::move(nccl_version); +} + +template +void FlightRecorder::update_state(Entry& r) { + try { + if (r.start_ != nullptr) { + bool started = r.start_->query(); + if (started && !r.time_discovered_started_) { + r.time_discovered_started_ = c10::getTime(); + } + } + if (r.end_ != nullptr) { + bool completed = r.end_->query(); + if (completed && !r.time_discovered_completed_) { + r.time_discovered_completed_ = c10::getTime(); + } + } + } catch (std::exception& e) { + LOG(ERROR) << "Failed to update state for entry " << r.id_ << ": " + << r.profiling_name_ << " with error: " << e.what(); + } +} + +template +std::vector::Entry> FlightRecorder< + EventType>::dump_entries() { + std::vector result; + { + std::lock_guard guard(mutex_); + result.reserve(entries_.size()); + result.insert( + result.end(), + entries_.begin() + static_cast(next_), + entries_.end()); + result.insert( + result.end(), + entries_.begin(), + entries_.begin() + static_cast(next_)); + } + // query any remaining events + for (auto& r : result) { + update_state(r); + r.start_ = r.end_ = nullptr; + } + return result; +} + +template +// Returns the entry with the given id, if it exists. Otherwise, returns +// std::nullopt. +std::optional::Entry> FlightRecorder< + EventType>::getEntry(std::optional id) { + if (!enabled_ || !id) { + return std::nullopt; + } + + std::unique_lock guard(mutex_); + Entry entry = entries_.at(*id % max_entries_); + if (entry.id_ == *id) { + return entry; + } else { + return std::nullopt; + } +} + +template +void FlightRecorder::retire_id( + std::optional id, + bool compute_duration) { + if (!enabled_ || !id) { + return; + } + + bool can_compute_duration = false; + EventType* startEvent = nullptr; + EventType* endEvent = nullptr; + std::optional duration = std::nullopt; + + std::unique_lock guard(mutex_); + + Entry* entry = &entries_.at(*id % max_entries_); + if (entry->id_ == *id) { + update_state(*entry); + + if (compute_duration) { + can_compute_duration = entry->time_discovered_completed_.has_value() && + entry->start_ && entry->end_; + startEvent = entry->start_; + endEvent = entry->end_; + } + entry->retired_ = true; + entry->start_ = entry->end_ = nullptr; + } + + if (can_compute_duration) { + // Compute duration without without holding the lock, because + // cudaEventDuration() can hang, and we need to acquire the lock before we + // can dump(), which we never want to block. + guard.unlock(); + duration = getDurationFromEvent(*startEvent, *endEvent); + guard.lock(); + + // Refresh the entry pointer, see if the entry has been overwritten + entry = &entries_.at(*id % max_entries_); + if (entry->id_ != *id) { + LOG(INFO) << "retire_id abandoned for id " << *id + << ", event was overwritten while waiting to compute duration."; + return; + } + if (duration.has_value()) { + entry->duration_ = duration; + } + } +} + +template +const c10::List FlightRecorder::getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { + auto entries = new_list(); + // Entries are returned in the order they were recorded + auto result = dump_entries(); + std::vector tracebacks; + torch::SymbolizedTracebacks stracebacks; + std::vector all_frames; + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } + } + for (auto i : c10::irange(result.size())) { + auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (auto frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + + dict.insert(record_id_key, int64_t(e.id_)); + dict.insert(pg_id_key, int64_t(e.pg_id_)); + dict.insert(pg_name_key, e.pg_name_); + dict.insert(thread_name_key, e.thread_name_); + dict.insert(thread_id_key, c10::str(e.thread_id_)); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); + dict.insert(op_id_key, int64_t(e.op_id_)); + dict.insert(profiling_name_key, e.profiling_name_); + dict.insert(time_created_key, int64_t(e.time_created_)); + if (e.duration_) { + dict.insert(duration_key, *e.duration_); + } + + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = new_list(); + for (auto dim : dims) { + auto arg_sizes = new_list(); + for ([[maybe_unused]] auto i : c10::irange(dim)) { + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + + dict.insert(input_sizes_key, read_sizes(e.input_dims_)); + std::vector input_dtypes_strs; + input_dtypes_strs.reserve(e.input_dtypes_.size()); + for (const auto& input_dtype : e.input_dtypes_) { + input_dtypes_strs.emplace_back(c10::toString(input_dtype)); + } + dict.insert(input_dtypes_key, input_dtypes_strs); + dict.insert(output_sizes_key, read_sizes(e.output_dims_)); + std::vector output_dtypes_strs; + output_dtypes_strs.reserve(e.output_dtypes_.size()); + for (const auto& output_dtype : e.output_dtypes_) { + output_dtypes_strs.emplace_back(c10::toString(output_dtype)); + } + dict.insert(output_dtypes_key, output_dtypes_strs); + if (e.time_discovered_completed_.has_value()) { + dict.insert(state_key, completed_state); + } else if (e.time_discovered_started_.has_value()) { + dict.insert(state_key, started_state); + } else { + dict.insert(state_key, scheduled_state); + } + + dict.insert( + time_discovered_started_key, + e.time_discovered_started_.has_value() + ? int64_t(*e.time_discovered_started_) + : c10::IValue()); + dict.insert( + time_discovered_completed_key, + e.time_discovered_completed_.has_value() + ? int64_t(*e.time_discovered_completed_) + : c10::IValue()); + dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); + dict.insert(is_p2p_key, e.isP2P_); + + entries.push_back(dict); + } + return entries; +} + +template +const c10::Dict FlightRecorder< + EventType>::getPgConfig() { + auto pg_config = new_dict(); + for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { + auto pg_info = new_dict(); + pg_info.insert("name", std::get<0>(pg_name)); + pg_info.insert("desc", std::get<1>(pg_name)); + pg_info.insert("ranks", ranks_str(ranks)); + pg_config.insert(std::get<0>(pg_name), pg_info); + } + return pg_config; +} + +template +const std::map> FlightRecorder< + EventType>::getPgConfigJson() { + std::map> result; + for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { + auto pg_info = std::map(); + pg_info["name"] = std::get<0>(pg_name); + pg_info["desc"] = std::get<1>(pg_name); + pg_info["ranks"] = ranks_str(ranks); + result.emplace(std::get<0>(pg_name), pg_info); + } + return result; +} + +template +const c10::Dict FlightRecorder< + EventType>::getPgStatus() { + auto all_pg_status = new_dict(); + for (const auto& [pg_id, status] : all_pg_status_) { + auto pg_status = new_dict(); + pg_status.insert("last_enqueued_collective", status->lastEnqueuedSeq); + pg_status.insert("last_started_collective", status->lastStartedSeq); + pg_status.insert("last_completed_collective", status->lastCompletedSeq); + all_pg_status.insert(std::to_string(pg_id), pg_status); + } + return all_pg_status; +} + +template +const std::map> FlightRecorder< + EventType>::getPgStatusJson() { + std::map> result; + for (const auto& [pg_id, status] : all_pg_status_) { + auto pg_status = std::map(); + pg_status["last_enqueued_collective"] = + std::to_string(status->lastEnqueuedSeq); + pg_status["last_started_collective"] = + std::to_string(status->lastStartedSeq); + pg_status["last_completed_collective"] = + std::to_string(status->lastCompletedSeq); + result[std::to_string(pg_id)] = pg_status; + } + return result; +} + +using json = nlohmann::json; +template +std::string FlightRecorder::dump_json( + const std::optional>>& extraDumpMap, + bool includeCollectives, + bool onlyActive) { + json result; + result[version_key_str] = version_val_str; + result[nccl_version_key_str] = nccl_version_; + result[pg_config_key_str] = getPgConfigJson(); + result[pg_status_key_str] = getPgStatusJson(); + + // collective trace + if (includeCollectives) { + std::list entries; + for (auto& e : dump_entries()) { + json j; + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + j[record_id_key_str] = int64_t(e.id_); + j[pg_id_key_str] = int64_t(e.pg_id_); + j[pg_name_key_str] = e.pg_name_; + j[thread_name_key_str] = e.thread_name_; + j[thread_id_key_str] = c10::str(e.thread_id_); + j[collective_seq_id_key_str] = int64_t(e.collective_seq_id_); + j[p2p_seq_id_key_str] = int64_t(e.p2p_seq_id_); + j[op_id_key_str] = int64_t(e.op_id_); + j[profiling_name_key_str] = e.profiling_name_; + j[time_created_key_str] = int64_t(e.time_created_); + if (e.duration_) { + j[duration_key_str] = *e.duration_; + } + auto it = e.sizes_.begin(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = std::list>(); + for (auto dim : dims) { + auto arg_sizes = std::list(); + for (auto i : c10::irange(dim)) { + (void)i; + arg_sizes.push_back(*it++); + } + sizes.push_back(arg_sizes); + } + return sizes; + }; + j[input_sizes_key_str] = read_sizes(e.input_dims_); + std::vector input_dtypes_strs; + input_dtypes_strs.reserve(e.input_dtypes_.size()); + for (const auto& input_dtype : e.input_dtypes_) { + input_dtypes_strs.emplace_back(c10::toString(input_dtype)); + } + j[input_dtypes_key_str] = input_dtypes_strs; + j[output_sizes_key_str] = read_sizes(e.output_dims_); + std::vector output_dtypes_strs; + output_dtypes_strs.reserve(e.output_dtypes_.size()); + for (const auto& output_dtype : e.output_dtypes_) { + output_dtypes_strs.emplace_back(c10::toString(output_dtype)); + } + j[output_dtypes_key_str] = output_dtypes_strs; + if (e.time_discovered_completed_.has_value()) { + j[state_key_str] = completed_state_str; + } else if (e.time_discovered_started_.has_value()) { + j[state_key_str] = started_state_str; + } else { + j[state_key_str] = scheduled_state_str; + } + j[time_discovered_started_key_str] = + e.time_discovered_started_.has_value() + ? int64_t(*e.time_discovered_started_) + : 0; + j[time_discovered_completed_key_str] = + e.time_discovered_completed_.has_value() + ? int64_t(*e.time_discovered_completed_) + : 0; + j[retired_key_str] = e.retired_; + j[timeout_key_str] = e.timeout_ms_; + j[is_p2p_key_str] = e.isP2P_; + entries.emplace_back(j); + } + + if (!entries.empty()) { + result[entries_key_str] = entries; + } + } + + if (extraDumpMap.has_value()) { + result[nccl_comm_key_str] = extraDumpMap.value(); + } + return result.dump(); +} + +template +std::string FlightRecorder::dump( + const std::optional>>& extraDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.FlightRecorder__dump); + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + result.insert(nccl_version_key_str, nccl_version_); + result.insert(pg_status_key, getPgStatus()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } + + // convert extraDumpMap into a dictionary + auto per_comm_dict = new_dict(); + if (extraDumpMap.has_value()) { + for (const auto& [ncclId, ncclDump] : extraDumpMap.value()) { + auto inner_dict = new_dict(); + for (const auto& [key, value] : ncclDump) { + inner_dict.insert(key, value); + } + per_comm_dict.insert(ncclId, inner_dict); + } + } + if (!per_comm_dict.empty()) { + result.insert(nccl_comm_key, per_comm_dict); + } + return pickle_str(result); +} +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index ebefc4754f41b4..b40b9bf92e9f9e 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -30,6 +30,37 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) { return it->second; } +at::Tensor allocate_all_gather_output( + const at::Tensor& input, + int64_t group_size) { + TORCH_CHECK(input.is_contiguous()); + auto output_size = input.sizes().vec(); + output_size[0] *= group_size; + return at::empty( + output_size, + at::TensorOptions().dtype(input.dtype()).device(input.device())); +} + +at::Tensor allocate_reduce_scatter_output( + const at::Tensor& input, + const int64_t group_size) { + TORCH_CHECK(input.is_contiguous()); + auto output_size = input.sizes().vec(); + if (output_size[0] % group_size != 0) { + LOG(WARNING) << "The first dimension of the reduce_scatter input (" + << output_size[0] << ") is not divisible by the group size (" + << group_size << ")."; + } + output_size[0] /= group_size; + return at::empty( + output_size, + at::TensorOptions().dtype(input.dtype()).device(input.device())); +} + +} // namespace + +namespace c10d { + at::Tensor& all_reduce_( at::Tensor& input, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -50,8 +81,21 @@ at::Tensor all_reduce( const at::Tensor& input, std::string reduce_op, std::string group_name) { - auto output = input.clone(at::MemoryFormat::Contiguous); - return all_reduce_(output, std::move(reduce_op), std::move(group_name)); + if (input.is_complex()) { + TORCH_CHECK( + // TODO - ideally use 'to_reduce_op' helper but it currently errors on + // premul_sum + reduce_op == "sum" || reduce_op == "avg" || reduce_op == "premul_sum" || + reduce_op == "unused", + "all_reduce: reduce_op ", + reduce_op, + " does not support complex tensors"); + } + auto input_real = input.is_complex() ? at::view_as_real(input) : input; + auto output = input_real.clone(at::MemoryFormat::Contiguous); + auto output_ret = + all_reduce_(output, std::move(reduce_op), std::move(group_name)); + return input.is_complex() ? at::view_as_complex(output_ret) : output_ret; } std::vector all_reduce_coalesced_( @@ -85,17 +129,6 @@ std::vector all_reduce_coalesced( outputs, std::move(reduce_op), std::move(group_name)); } -at::Tensor allocate_all_gather_output( - const at::Tensor& input, - int64_t group_size) { - TORCH_CHECK(input.is_contiguous()); - auto output_size = input.sizes().vec(); - output_size[0] *= group_size; - return at::empty( - output_size, - at::TensorOptions().dtype(input.dtype()).device(input.device())); -} - std::vector all_gather_into_tensor_coalesced( std::vector inputs, int64_t group_size, @@ -121,9 +154,11 @@ at::Tensor all_gather_into_tensor( int64_t group_size, std::string group_name) { TORCH_CHECK(input.is_contiguous()); - std::vector inputs{input}; - return all_gather_into_tensor_coalesced( + auto real_input = input.is_complex() ? at::view_as_real(input) : input; + std::vector inputs{real_input}; + auto output = all_gather_into_tensor_coalesced( inputs, group_size, std::move(group_name))[0]; + return input.is_complex() ? at::view_as_complex(output) : output; } at::Tensor& all_gather_into_tensor_out( @@ -140,22 +175,6 @@ at::Tensor& all_gather_into_tensor_out( return output; } -at::Tensor allocate_reduce_scatter_output( - const at::Tensor& input, - const int64_t group_size) { - TORCH_CHECK(input.is_contiguous()); - auto output_size = input.sizes().vec(); - if (output_size[0] % group_size != 0) { - LOG(WARNING) << "The first dimension of the reduce_scatter input (" - << output_size[0] << ") is not divisible by the group size (" - << group_size << ")."; - } - output_size[0] /= group_size; - return at::empty( - output_size, - at::TensorOptions().dtype(input.dtype()).device(input.device())); -} - std::vector reduce_scatter_tensor_coalesced( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -186,6 +205,12 @@ at::Tensor reduce_scatter_tensor( int64_t group_size, std::string group_name) { TORCH_CHECK(input.is_contiguous()); + if (input.is_complex()) { + auto real_input = at::view_as_real(input); + std::vector inputs{real_input}; + return at::view_as_complex(reduce_scatter_tensor_coalesced( + inputs, std::move(reduce_op), group_size, std::move(group_name))[0]); + } std::vector inputs{input}; return reduce_scatter_tensor_coalesced( inputs, std::move(reduce_op), group_size, std::move(group_name))[0]; @@ -218,7 +243,8 @@ at::Tensor all_to_all_single( at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) { c10d::BroadcastOptions opts; opts.rootRank = src; - std::vector inputs{input}; + auto input_real = input.is_complex() ? at::view_as_real(input) : input; + std::vector inputs{input_real}; auto group = c10d::resolve_process_group(group_name); auto work = group->broadcast(inputs, opts); @@ -234,65 +260,68 @@ at::Tensor broadcast( return broadcast_(output, src, std::move(group_name)); } -} // namespace +} // namespace c10d TORCH_LIBRARY(_c10d_functional, m) { m.def( "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce), + c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce), {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_), + c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_), {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced), + c10::DispatchKey::CompositeExplicitAutograd, + c10d::all_reduce_coalesced), {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_), + c10::DispatchKey::CompositeExplicitAutograd, + c10d::all_reduce_coalesced_), {at::Tag::pt2_compliant_tag}); m.def( "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - ::all_gather_into_tensor_out), + c10d::all_gather_into_tensor_out), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - ::all_gather_into_tensor), + c10d::all_gather_into_tensor), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - ::all_gather_into_tensor_coalesced), + c10d::all_gather_into_tensor_coalesced), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor), + c10::DispatchKey::CompositeExplicitAutograd, + c10d::reduce_scatter_tensor), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - ::reduce_scatter_tensor_coalesced), + c10d::reduce_scatter_tensor_coalesced), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( @@ -302,18 +331,19 @@ TORCH_LIBRARY(_c10d_functional, m) { "SymInt[] input_split_sizes, " "str group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single), + c10::DispatchKey::CompositeExplicitAutograd, c10d::all_to_all_single), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "broadcast(Tensor input, int src, str group_name) -> Tensor", - torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast), + torch::dispatch( + c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast), {at::Tag::pt2_compliant_tag}); m.def( "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_), + c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast_), {at::Tag::pt2_compliant_tag}); m.def( @@ -342,7 +372,7 @@ class AllToAllSingle : public torch::autograd::Function { return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") - .typed() + .typed() .call(input, output_split_sizes, input_split_sizes, group_name); } @@ -361,7 +391,7 @@ class AllToAllSingle : public torch::autograd::Function { auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") - .typed() + .typed() .call(grad_out, output_split_sizes, input_split_sizes, group_name); // do an explicit wait to avoid cuda stream issues @@ -400,7 +430,7 @@ class ReduceScatterTensor return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") - .typed() + .typed() .call(input, reduce_op, group_size, group_name); } @@ -416,7 +446,7 @@ class ReduceScatterTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") - .typed() + .typed() .call(grad_out, group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -456,7 +486,7 @@ class AllGatherIntoTensor return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") - .typed() + .typed() .call(input, group_size, group_name); } @@ -472,7 +502,7 @@ class AllGatherIntoTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") - .typed() + .typed() .call(grad_out, "sum", group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -549,7 +579,10 @@ at::Tensor shard_dim_alltoall( input_sizes.insert(input_sizes.begin() + shard_dim, group_size); auto tensor_reshaped = input.view(input_sizes); - auto tensor_for_comm = tensor_reshaped.movedim(shard_dim, 0).contiguous(); + auto tensor_shard_contig = tensor_reshaped.movedim(shard_dim, 0).contiguous(); + auto tensor_for_comm = input.is_complex() + ? at::view_as_real(tensor_shard_contig) + : tensor_shard_contig; auto recv_tensor = at::empty_like(tensor_for_comm); std::vector out_split_sizes; @@ -571,7 +604,8 @@ at::Tensor shard_dim_alltoall( // view/reshape it back to the expected output shape output_sizes[shard_dim] /= group_size; output_sizes[gather_dim] *= group_size; - return output.view(output_sizes); + return input.is_complex() ? at::view_as_complex(output).view(output_sizes) + : output.view(output_sizes); } } // namespace diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp index e81d44b8dbd23d..d89dbe24b1c645 100644 --- a/torch/csrc/distributed/c10d/Functional.hpp +++ b/torch/csrc/distributed/c10d/Functional.hpp @@ -1,3 +1,78 @@ #pragma once #include + +namespace c10d { + +C10_EXPORT at::Tensor& all_reduce_( + at::Tensor& input, + std::string reduce_op, + std::string group_name); + +C10_EXPORT at::Tensor all_reduce( + const at::Tensor& input, + std::string reduce_op, + std::string group_name); + +C10_EXPORT std::vector all_reduce_coalesced_( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name); + +C10_EXPORT std::vector all_reduce_coalesced( + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector inputs, + std::string reduce_op, + std::string group_name); + +C10_EXPORT std::vector all_gather_into_tensor_coalesced( + std::vector inputs, + int64_t group_size, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name); + +C10_EXPORT at::Tensor all_gather_into_tensor( + const at::Tensor& input, + int64_t group_size, + std::string group_name); + +C10_EXPORT at::Tensor& all_gather_into_tensor_out( + at::Tensor& input, + int64_t group_size, + const std::string& group_name, + at::Tensor& output); + +C10_EXPORT std::vector reduce_scatter_tensor_coalesced( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + int64_t group_size, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name); + +C10_EXPORT at::Tensor reduce_scatter_tensor( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + std::string group_name); + +C10_EXPORT at::Tensor all_to_all_single( + const at::Tensor& input, + std::vector output_split_sizes, + std::vector input_split_sizes, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name); + +C10_EXPORT at::Tensor& broadcast_( + at::Tensor& input, + int64_t src, + std::string group_name); + +C10_EXPORT at::Tensor broadcast( + const at::Tensor& input, + int64_t src, + std::string group_name); + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 4d11ea92391b02..d0e3d6b41e8ca6 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -39,8 +40,24 @@ NCCLComm::NCCLComm(NCCLComm&& other) { std::swap(deviceIndex_, other.deviceIndex_); } -ncclUniqueId NCCLComm::getNcclId() { - return ncclId_; +void NCCLComm::setUniqueHash(ncclUniqueId ncclId) { + const uint8_t* bytes = reinterpret_cast(&ncclId); + + fmt::memory_buffer buf; + buf.reserve(NCCL_UNIQUE_ID_BYTES * 2); // 2 hex chars per byte + for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) { + fmt::format_to( + std::back_inserter(buf), "{:02x}", static_cast(bytes[i])); + } + this->uniqueHash_ = fmt::to_string(buf); +} + +void NCCLComm::setUniqueHash(std::string hash) { + this->uniqueHash_ = std::move(hash); +} + +std::string NCCLComm::getUniqueHash() { + return uniqueHash_; } std::shared_ptr NCCLComm::create( @@ -53,7 +70,7 @@ std::shared_ptr NCCLComm::create( C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), std::nullopt); - comm->ncclId_ = commId; + comm->setUniqueHash(commId); comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; comm->initialized_ = true; @@ -78,7 +95,7 @@ std::shared_ptr NCCLComm::create( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), std::nullopt); - comm->ncclId_ = commId; + comm->setUniqueHash(commId); comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; // Under blocking mode, comm is initialized immediately after NCCL init @@ -112,7 +129,7 @@ std::shared_ptr NCCLComm::create_scalable( // Only the first ncclUniqueId will be used to create the // communicator hash id, which is used to identify the communicator // in the log file and in the replay tool. - comm->ncclId_ = commIds[0]; + comm->setUniqueHash(commIds[0]); comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; comm->initialized_ = !comm->nonBlocking_; @@ -237,6 +254,9 @@ std::shared_ptr NCCLComm::split( // Child comm should be on the same device as parent comm comm->deviceIndex_ = source->deviceIndex_; comm->nonBlocking_ = config.blocking == 0; + comm->setUniqueHash( + source->getUniqueHash() + ":" + + std::to_string(source->ncclCommSplitCounter_)); LOG(INFO) << "Rank " << source->rank_ << ": created child comm " << comm->repr() << " with color_id " << color_id; return comm; @@ -350,7 +370,8 @@ ncclResult_t NCCLComm::checkForNcclError() { ncclResult_t NCCLComm::registerSegment( void* ptr, size_t size, - bool errorOnRereg /*=true*/) { + bool errorOnRereg, /*=true*/ + bool window /*=false*/) { LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator @@ -371,6 +392,30 @@ ncclResult_t NCCLComm::registerSegment( void* handle = nullptr; // Use getNcclComm to make sure comm is ready before calling nccl APIs auto comm = getNcclComm(); +#ifdef NCCL_HAS_COMM_WINDOW_REGISTER + if (window) { + C10D_NCCL_CHECK( + ncclCommWindowRegister( + comm, ptr, size, (ncclWindow_t*)&handle, NCCL_WIN_COLL_SYMMETRIC), + c10::str( + "Failed to window register segment with ptr ", + ptr, + ", size ", + size, + " on ncclComm_ ", + comm)); + } else { + C10D_NCCL_CHECK( + ncclCommRegister(comm, ptr, size, &handle), + c10::str( + "Failed to register segment with ptr ", + ptr, + ", size ", + size, + " on ncclComm_ ", + comm)); + } +#else C10D_NCCL_CHECK( ncclCommRegister(comm, ptr, size, &handle), c10::str( @@ -380,6 +425,7 @@ ncclResult_t NCCLComm::registerSegment( size, " on ncclComm_ ", comm)); +#endif registeredSegmentHandles_[ptr] = handle; return ncclSuccess; #else @@ -387,7 +433,7 @@ ncclResult_t NCCLComm::registerSegment( #endif } -ncclResult_t NCCLComm::deregisterSegment(void* ptr) { +ncclResult_t NCCLComm::deregisterSegment(void* ptr, bool window /*false*/) { LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( @@ -400,6 +446,29 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) { void* handle = registeredSegmentHandles_[ptr]; // Use getNcclComm to make sure comm is ready before calling nccl APIs auto comm = getNcclComm(); +#ifdef NCCL_HAS_COMM_WINDOW_REGISTER + if (window) { + C10D_NCCL_CHECK( + ncclCommWindowDeregister(comm, (ncclWindow_t)handle), + c10::str( + "Failed to window deregister segment handle ", + handle, + ", with ptr ", + ptr, + " on ncclComm_ ", + comm)); + } else { + C10D_NCCL_CHECK( + ncclCommDeregister(comm, handle), + c10::str( + "Failed to deregister segment handle ", + handle, + ", with ptr ", + ptr, + " on ncclComm_ ", + comm)); + } +#else C10D_NCCL_CHECK( ncclCommDeregister(comm, handle), c10::str( @@ -409,6 +478,7 @@ ncclResult_t NCCLComm::deregisterSegment(void* ptr) { ptr, " on ncclComm_ ", comm)); +#endif registeredSegmentHandles_.erase(ptr); return ncclSuccess; #else @@ -434,21 +504,11 @@ std::unordered_map NCCLComm::ncclCommDump() { std::string getNcclVersion() { static std::string versionString = []() { - int version = 0; + auto [ncclMajor, ncclMinor, ncclPatch] = getNcclVersionTuple(); std::string versionString; - ncclResult_t status = ncclGetVersion(&version); - // can't compute the version if call did not return successfully or version - // code < 100 (corresponding to 0.1.0) - if (status != ncclSuccess || version < 100) { + if (ncclMajor == 0 && ncclMinor == 0 && ncclPatch == 0) { versionString = "Unknown NCCL version"; } else { - // NCCL changed version coding starting 2.9 - const int majorBase = version < 2900 ? 1000 : 10000; - const int minorBase = 100; - auto ncclMajor = version / majorBase; - auto ncclMinor = (version % majorBase) / minorBase; - auto ncclPatch = - version % (ncclMajor * majorBase + ncclMinor * minorBase); versionString = std::to_string(ncclMajor) + "." + std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); #ifdef NCCL_SUFFIX @@ -464,6 +524,25 @@ std::string getNcclVersion() { return versionString; } +std::tuple getNcclVersionTuple() { + static std::tuple versionTuple = []() { + int version = getNcclVersionNumber(); + // can't compute the version if call did not return successfully or version + // code < 100 (corresponding to 0.1.0) + if (version < 100) { + return std::make_tuple(0, 0, 0); + } + // NCCL changed version coding starting 2.9 + const int majorBase = version < 2900 ? 1000 : 10000; + const int minorBase = 100; + auto ncclMajor = version / majorBase; + auto ncclMinor = (version % majorBase) / minorBase; + auto ncclPatch = version % minorBase; + return std::make_tuple(ncclMajor, ncclMinor, ncclPatch); + }(); + return versionTuple; +} + int getNcclVersionNumber() { static int version = []() { int version = 0; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index e5ca948b5c96ee..837a7bec3d5db1 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -14,7 +14,6 @@ #include #include #include -#include #include constexpr int64_t kCommInitBusyWaitMillis = 2; @@ -63,6 +62,10 @@ static_assert( #define NCCL_HAS_COMM_REGISTER #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COMM_WINDOW_REGISTER +#endif + #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) #define NCCL_HAS_MEM_ALLOC #endif @@ -75,6 +78,18 @@ static_assert( #define NCCL_SUPPORTS_FP8 #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COLLNET +#endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_CTA_POLICY +#endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_NVLS_CTAS +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -217,6 +232,7 @@ static std::map ncclDataType = { TORCH_API size_t hashTensors(const std::vector& tensors); TORCH_API std::string getNcclVersion(); +TORCH_API std::tuple getNcclVersionTuple(); TORCH_API int getNcclVersionNumber(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); int nccl_nonblocking_timeout(); @@ -242,6 +258,10 @@ class NCCLComm { ~NCCLComm() noexcept; + void setUniqueHash(ncclUniqueId ncclId); + void setUniqueHash(std::string hash); + std::string getUniqueHash(); + static std::shared_ptr create( int numRanks, int rank, @@ -278,7 +298,6 @@ class NCCLComm { std::unordered_map ncclCommDump(); #endif - ncclUniqueId getNcclId(); at::DeviceIndex getDeviceIndex(); // Must not be copyable @@ -328,17 +347,18 @@ class NCCLComm { ncclResult_t registerSegment( void* ptr, size_t size, - bool errorOnRereg = true); + bool errorOnRereg = true, + bool window = false); - ncclResult_t deregisterSegment(void* ptr); + ncclResult_t deregisterSegment(void* ptr, bool window = false); std::string repr() const; friend class ProcessGroupNCCL; protected: - // Unique nccl_id for this communicator. - ncclUniqueId ncclId_{}; + // Unique hash for this communicator. + std::string uniqueHash_; bool aborted_{false}; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_{ncclSuccess}; diff --git a/torch/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu b/torch/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu deleted file mode 100644 index e49edb9a7bcd78..00000000000000 --- a/torch/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu +++ /dev/null @@ -1,329 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace c10d { -namespace symmetric_memory { - -/* Start of CUDASymmetricMemory implementation */ - -static StoreExchange storeExchange = StoreExchange("NVSHMEMSymmetricMemory"); - -struct NVSHMEMAllocation { - void* ptr; - size_t buffer_size; - int device_idx; - - NVSHMEMAllocation(void* ptr, size_t buffer_size, int device_idx) - : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} -}; - -class NVSHMEMSymmetricMemory : public SymmetricMemory { - public: - NVSHMEMSymmetricMemory( - std::shared_ptr allocation, - const std::string& group_name) - : allocation_(allocation), - buffer_size_(allocation->buffer_size), - device_idx_(allocation->device_idx), - group_name_(group_name) { - c10::cuda::CUDAGuard guard(device_idx_); - - auto global_rank = get_group_info("0").rank; - auto group_info = get_group_info(group_name_); - auto store = group_info.store; - rank_ = group_info.rank; - world_size_ = group_info.world_size; - rank_to_global_rank_ = - storeExchange.all_gather(store, rank_, world_size_, global_rank); - LOG(INFO) << "[rank " << rank_ << "]" - << "rank_to_global_rank: " << rank_to_global_rank_; - - for (int r = 0; r < world_size_; ++r) { - buffers_.push_back(nvshmem_extension::nvshmem_ptr( - allocation->ptr, rank_to_global_rank_[r])); - } - - // TODO: use the same allocation for signal pad - void* signal_pad_ptr = nvshmem_extension::nvshmem_malloc(signal_pad_size); - AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size)); - - for (int r = 0; r < world_size_; ++r) { - signal_pads_.push_back(nvshmem_extension::nvshmem_ptr( - signal_pad_ptr, rank_to_global_rank_[r])); - } - - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - - AT_CUDA_CHECK(cudaMemcpy( - buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); - AT_CUDA_CHECK(cudaMemcpy( - signal_pads_dev_, - signal_pads_.data(), - arr_size, - cudaMemcpyHostToDevice)); - - rank_to_global_rank_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_)); - AT_CUDA_CHECK(cudaMemcpy( - rank_to_global_rank_dev_, - rank_to_global_rank_.data(), - sizeof(int) * world_size_, - cudaMemcpyHostToDevice)); - } - - ~NVSHMEMSymmetricMemory() override{ - // TODO - }; - - std::vector get_buffer_ptrs() override { - return buffers_; - } - - std::vector get_signal_pad_ptrs() override { - return signal_pads_; - } - - void** get_buffer_ptrs_dev() override { - return buffers_dev_; - } - - void** get_signal_pad_ptrs_dev() override { - return signal_pads_dev_; - } - - size_t get_buffer_size() override { - return buffer_size_; - } - - size_t get_signal_pad_size() override { - return signal_pad_size; - }; - - bool has_multicast_support() override { - // TODO - return false; - } - - void* get_multicast_ptr() override { - // TODO - return nullptr; - } - - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - // TODO: deduplicate - const size_t numel = std::accumulate( - sizes.begin(), - sizes.end(), - static_cast(1), - std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "NVSHMEMSymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto data_ptr = reinterpret_cast(buffers_[rank]) + - storage_offset * element_size; - auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(data_ptr, sizes) - .options(options) - .target_device(device) - .make_tensor(); - } - - at::Tensor get_signal_pad( - int rank, - c10::IntArrayRef sizes, - std::optional dtype, - int64_t storage_offset) override { - // TODO: deduplicate - // If the dtype is unspecified, default it to UInt32, as it - // is the most common type for signaling purposes. - if (!dtype.has_value()) { - dtype = c10::ScalarType::UInt32; - } - - // If the shape is unspecified, treat the signal pad as a 1d tensor. - const auto element_size = c10::elementSize(*dtype); - std::vector shape; - if (!sizes.empty()) { - shape = sizes.vec(); - } else { - shape.push_back(signal_pad_size / element_size); - } - - const size_t numel = std::accumulate( - shape.begin(), - shape.end(), - static_cast(1), - std::multiplies()); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= signal_pad_size, - "NVSHMEMSymmetricMemory::get_signal_pad: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - signal_pad_size, - " bytes)"); - auto data_ptr = reinterpret_cast(signal_pads_[rank]) + - storage_offset * element_size; - auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); - auto options = at::TensorOptions().dtype(*dtype).device(device); - return at::for_blob(data_ptr, shape) - .options(options) - .target_device(device) - .make_tensor(); - } - - void barrier(int channel, size_t timeout_ms) override { - // TODO - } - - void put_signal(int dst_rank, int channel, size_t timeout_ms) override { - // TODO - } - - void wait_signal(int src_rank, int channel, size_t timeout_ms) override { - // TODO - } - - int get_rank() override { - return rank_; - } - - int get_world_size() override { - return world_size_; - } - - virtual std::vector get_rank_to_global_rank() override { - return rank_to_global_rank_; - }; - - int* get_rank_to_global_rank_dev() override { - return rank_to_global_rank_dev_; - }; - - private: - std::shared_ptr allocation_; - size_t buffer_size_; - std::vector buffers_; - std::vector signal_pads_; - int device_idx_; - int rank_; - int world_size_; - void** buffers_dev_; - void** signal_pads_dev_; - std::string group_name_; - - std::vector rank_to_global_rank_; - int* rank_to_global_rank_dev_; -}; - -class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { - public: - void* alloc( - size_t size, - int device_idx, - const std::optional& group_name) override { - TORCH_CHECK( - group_name == std::nullopt, - "NVSHMEMSymmetricMemoryAllocator::alloc " - "must not be called with a group_name"); - - auto group_info = get_group_info("0"); - auto store = group_info.store; - int rank = group_info.rank; - int world_size = group_info.world_size; - - nvshmem_extension::initialize_nvshmem_with_store(store, rank, world_size); - auto ptr = nvshmem_extension::nvshmem_malloc(size); - auto allocation = - std::make_shared(ptr, size, device_idx); - // TODO: thread safety - allocations_.emplace(ptr, allocation); - return ptr; - } - - void free(void* ptr) override { - // TODO: thread safety - ptr_to_symm_mem_.erase(ptr); - }; - - size_t get_alloc_size(void* ptr) override { - auto it = ptr_to_symm_mem_.find(ptr); - if (it == ptr_to_symm_mem_.end()) { - TORCH_CHECK( - false, ptr, " is not allocated with NVSHMEMSymmetricMemoryAllocator"); - } - return it->second->get_buffer_size(); - }; - - c10::intrusive_ptr rendezvous( - void* ptr, - const std::optional& group_name) override { - TORCH_CHECK(group_name.has_value()); - { - auto it = symm_mems_.find(std::make_tuple(ptr, *group_name)); - if (it != symm_mems_.end()) { - return it->second; - } - } - auto it = allocations_.find(ptr); - TORCH_CHECK(it != allocations_.end()); - auto symm_mem = - c10::make_intrusive(it->second, *group_name); - - symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem; - return symm_mem; - }; - - bool has_multicast_support(int device_idx) override { - // TODO - return false; - }; - - private: - std::unordered_map> - ptr_to_symm_mem_; - - std::unordered_map> allocations_; - std::map, c10::intrusive_ptr> - symm_mems_; -}; - -struct RegisterNVSHMEMSymmetricMemoryAllocator { - RegisterNVSHMEMSymmetricMemoryAllocator() { - // Query backend used for CUDA tensor - if (getSymmMemBackendCUDA() == "NVSHMEM") { - register_allocator( - c10::DeviceType::CUDA, - c10::make_intrusive()); - } - } -}; - -static RegisterNVSHMEMSymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index bd18c68312a687..087c2831b4edbb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -4,10 +4,12 @@ #ifdef USE_C10D_GLOO +#include #include #include #include #include +#include #include #include @@ -232,6 +234,10 @@ c10::intrusive_ptr ProcessGroupGloo::AsyncWork:: return future_; } +std::chrono::milliseconds ProcessGroupGloo::AsyncWork::getTimeout() const { + return context_->getTimeout(); +} + namespace { c10::intrusive_ptr createFutureAsOutput( const std::vector>& outputTensors) { @@ -289,6 +295,7 @@ inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo( } ProcessGroupGloo::AsyncWork::AsyncWork( + std::shared_ptr context, std::vector> outputTensors, OpType opType, uint64_t seq, @@ -298,11 +305,13 @@ ProcessGroupGloo::AsyncWork::AsyncWork( // replace default profiler implementation with async version that reports // correct timestamps for work that is asynchronously executed. : Work(-1, opType, nullptr, inputTensors), + context_(std::move(context)), outputTensors_(std::move(outputTensors)), future_(createFutureAsOutput(outputTensors_)), seq_(seq) { if (profilingTitle != nullptr) { recordAsyncWorkProfilingInfo(profilingTitle, inputTensors); + profilingTitle_ = profilingTitle; } } @@ -538,6 +547,8 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: } #endif +static std::atomic process_group_id = 0; + ProcessGroupGloo::ProcessGroupGloo( const c10::intrusive_ptr& store, int rank, @@ -547,7 +558,8 @@ ProcessGroupGloo::ProcessGroupGloo( store_(new GlooStore(store)), options_(std::move(options)), stop_(false), - collectiveCounter_(0) { + collectiveCounter_(0), + local_id_(process_group_id++) { auto& devices = options_->devices; if (devices.empty()) { TORCH_CHECK(false, "No device(s) specified"); @@ -606,8 +618,14 @@ ProcessGroupGloo::ProcessGroupGloo( for (const auto i : c10::irange(threads_.size())) { threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i); } + this->setGroupUid(options_->group_name); + // TODO: If gloo has version, we also need to log gloo version into FR. + FlightRecorder::get()->record_pg_ranks( + std::make_tuple(pg_uid_, pg_desc_), groupRanks()); init(); + + // TODO: Add configs print like ProcessGroupNCCL. } ProcessGroupGloo::~ProcessGroupGloo() { @@ -655,13 +673,52 @@ void ProcessGroupGloo::runLoop(int workerIndex) { workConsumeCV_.notify_one(); AsyncWork::execute(work); + // TODO: Need to find a way to calculate the difference of duration of two + // c10d::Event + pgStatus_->lastCompletedSeq = static_cast(work->seq_); + pgStatus_->lastCompletedWorkName = opTypeToString(work->opType_); + // TODO: We need to have numel of tensors for gloo as well. + pgStatus_->lastCompletedNumelIn = 0; + pgStatus_->lastCompletedNumelOut = 0; + FlightRecorder::get()->retire_id(work->trace_id_, false); lock.lock(); workInProgress_[workerIndex].reset(); } } +const std::vector& ProcessGroupGloo::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); + pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); + pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_); + // TODO: We need to have numel of tensors for gloo as well. + pgStatus_->lastEnqueuedNumelIn = 0; + pgStatus_->lastEnqueuedNumelOut = 0; + // using c10d::FlightRecorder; + // TODO: We need to have a way to use c10::Event inside gloo as well. + work->trace_id_ = FlightRecorder::get()->record( + local_id_, + std::make_tuple(pg_uid_, pg_desc_), + collectiveCounter_, + 0, // p2p_seq_id, set 0 for now since p2p does not call enqueue + work->getSequencenumber(), // We need to differentiate between p2p and + // non-p2p op. + work->getProfilerTitle(), + work->getInputTensors(), + work->getOutputTensors(), + nullptr, + nullptr, + work->getTimeout(), + pgStatus_, + false); workQueue_.push_back(std::move(work)); lock.unlock(); @@ -682,18 +739,17 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {inputs}, OpType::BROADCAST, seq, "gloo:broadcast", inputs), - context(std::move(context)), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), tag(tag) {} - std::shared_ptr context; std::vector inputs{}; const int rootRank; const int rootTensor; @@ -701,13 +757,21 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { void broadcast(at::Tensor& tensor) { const auto& scalarType = tensor.scalar_type(); - gloo::BroadcastOptions opts(context); + gloo::BroadcastOptions opts(context_); opts.setRoot(rootRank); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); gloo::broadcast(opts); } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return inputs; + } + void run() override { broadcast(inputs[rootTensor]); @@ -736,7 +800,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { // Create pinned host side tensors. tmp = pinnedLike(inputs[rootTensor]); c10::OptionalStreamGuard guard; - if (context->rank == rootRank) { + if (context_->rank == rootRank) { guard.reset_stream(streams[rootTensor]); tmp.copy_(inputs[rootTensor], /* non_blocking */ true); } @@ -744,7 +808,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { void run() override { // Synchronize with copy operation if applicable. - if (context->rank == rootRank) { + if (context_->rank == rootRank) { streams[rootTensor].synchronize(); } @@ -901,7 +965,7 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce_sparse( const AllreduceOptions& opts) { // all reduce sparse calls into default allreduce which // implemented with all_gathering indices and values - // we do ths we do not have a native cuda implementation + // we do this we do not have a native cuda implementation return allreduce(inputs, opts); } @@ -978,19 +1042,18 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {inputs}, OpType::REDUCE, seq, "gloo:reduce", inputs), - context(std::move(context)), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), reduceOp(std::move(reduceOp)), tag(tag) {} - std::shared_ptr context; std::vector inputs{}; const int rootRank; const int rootTensor; @@ -999,7 +1062,7 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { void reduce(std::vector& tensors) { const auto& scalarType = tensors[0].scalar_type(); - gloo::ReduceOptions opts(context); + gloo::ReduceOptions opts(context_); opts.setRoot(rootRank); opts.setTag(tag); opts.setReduceFunction(getFunction(scalarType, reduceOp)); @@ -1008,7 +1071,7 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { // Gloo doesn't support AVG so we use SUM + division. if (reduceOp == ReduceOp::AVG) { - tensors[0] /= context->size; + tensors[0] /= context_->size; } } @@ -1016,6 +1079,14 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { reduce(inputs); } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return inputs; + } + protected: template void getFunction(gloo::ReduceOptions::Func& fn, const ReduceOp op) { @@ -1159,17 +1230,16 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), outputs, OpType::ALLGATHER, seq, "gloo:all_gather", inputs), - context(std::move(context)), outputs(outputs), inputs(inputs), tag(tag) {} - std::shared_ptr context; std::vector> outputs{}; std::vector inputs{}; const uint32_t tag; @@ -1178,7 +1248,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { std::vector>& outputs, std::vector& inputs) { const auto& scalarType = inputs[0].scalar_type(); - gloo::AllgatherOptions opts(context); + gloo::AllgatherOptions opts(context_); opts.setTag(tag); // Use single flattened input tensor. @@ -1200,6 +1270,14 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { } } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return {newLikeFlat(outputs[0])}; + } + void run() override { allgather(outputs, inputs); } @@ -1431,17 +1509,16 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), output_lists, OpType::ALLGATHER_COALESCED, seq, "gloo:all_gather", input_list), - context(std::move(context)), output_lists(output_lists), input_list(input_list), tag(tag) {} - std::shared_ptr context; std::vector> output_lists{}; std::vector input_list{}; const uint32_t tag; @@ -1452,7 +1529,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { assert(!input_list.empty()); const auto& scalarType = input_list[0].scalar_type(); - gloo::AllgatherOptions opts(context); + gloo::AllgatherOptions opts(context_); opts.setTag(tag); // Use single flattened input tensor. @@ -1484,6 +1561,14 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } } + const std::vector getInputTensors() override { + return input_list; + } + + const std::vector getOutputTensors() override { + return {newLikeFlat(output_lists[0])}; + } + void run() override { allgather_coalesced(); } @@ -1574,18 +1659,17 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), outputs, OpType::GATHER, seq, "gloo:gather", inputs), - context(std::move(context)), outputs(outputs), inputs(inputs), root(root), tag(tag) {} - std::shared_ptr context; std::vector> outputs{}; std::vector inputs{}; const int root; @@ -1595,14 +1679,14 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { std::vector>& outputs, std::vector& inputs) { const auto scalarType = inputs[0].scalar_type(); - gloo::GatherOptions opts(context); + gloo::GatherOptions opts(context_); opts.setRoot(root); opts.setTag(tag); // Set single temporary tensor on root process. // This is later scattered to the separate output tensors. at::Tensor flatOutputTensor; - if (context->rank == root) { + if (context_->rank == root) { flatOutputTensor = newLikeFlat(outputs[0]); GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); } @@ -1612,13 +1696,22 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { gloo::gather(opts); // Unflatten into output tensors on root process. - if (context->rank == root) { + if (context_->rank == root) { for (const auto i : c10::irange(outputs[0].size())) { outputs[0][i].copy_(flatOutputTensor[static_cast(i)]); } } } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return outputs.empty() ? std::vector{} + : std::vector{newLikeFlat(outputs[0])}; + } + void run() override { gather(outputs, inputs); } @@ -1779,19 +1872,18 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {outputs}, OpType::SCATTER, seq, "gloo:scatter", !inputs.empty() ? std::optional>(inputs[0]) : std::nullopt), - context(std::move(context)), outputs(outputs), inputs(inputs), root(root), tag(tag) {} - std::shared_ptr context; std::vector outputs{}; std::vector> inputs{}; const int root; @@ -1801,12 +1893,12 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { std::vector& outputs, std::vector>& inputs) { const auto scalarType = outputs[0].scalar_type(); - gloo::ScatterOptions opts(context); + gloo::ScatterOptions opts(context_); opts.setRoot(root); opts.setTag(tag); // Set list of input tensors on root process - if (context->rank == root) { + if (context_->rank == root) { GENERATE_ALL_TYPES(scalarType, setInputs, opts, inputs[0]); } @@ -1815,6 +1907,15 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { gloo::scatter(opts); } + const std::vector getInputTensors() override { + return inputs.empty() ? std::vector{} + : std::vector{newLikeFlat(inputs[0])}; + } + + const std::vector getOutputTensors() override { + return outputs; + } + void run() override { scatter(outputs, inputs); } @@ -2012,19 +2113,18 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {{outputTensor}}, OpType::ALLTOALL, seq, "gloo:all_to_all", std::optional>({inputTensor})), - context(std::move(context)), outputTensor(outputTensor), inputTensor(inputTensor), outputCounts(std::move(outputCounts)), inputCounts(std::move(inputCounts)), tag(tag) {} - std::shared_ptr context; at::Tensor outputTensor; at::Tensor inputTensor; std::vector outputCounts{}; @@ -2035,24 +2135,24 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { const auto scalarType = outputTensor.scalar_type(); if (outputCounts.empty() && inputCounts.empty()) { // Gloo alltoall - gloo::AlltoallOptions opts(context); + gloo::AlltoallOptions opts(context_); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor); GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor); gloo::alltoall(opts); } else { // Gloo alltoallv - c10d::checkSplitSizes(inputCounts, inputTensor, context->size); - c10d::checkSplitSizes(outputCounts, outputTensor, context->size); - std::vector sendCounts(context->size); - std::vector recvCounts(context->size); - std::vector sendOffsets(context->size); - std::vector recvOffsets(context->size); + c10d::checkSplitSizes(inputCounts, inputTensor, context_->size); + c10d::checkSplitSizes(outputCounts, outputTensor, context_->size); + std::vector sendCounts(context_->size); + std::vector recvCounts(context_->size); + std::vector sendOffsets(context_->size); + std::vector recvOffsets(context_->size); c10d::computeLengthsAndOffsets( inputCounts, inputTensor, &sendCounts, &sendOffsets); c10d::computeLengthsAndOffsets( outputCounts, outputTensor, &recvCounts, &recvOffsets); - gloo::AlltoallvOptions opts(context); + gloo::AlltoallvOptions opts(context_); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor, sendCounts); GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor, recvCounts); @@ -2060,6 +2160,14 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { } } + const std::vector getInputTensors() override { + return {inputTensor}; + } + + const std::vector getOutputTensors() override { + return {outputTensor}; + } + void run() override { alltoall(outputTensor, inputTensor); } @@ -2284,18 +2392,26 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {}, OpType::BARRIER, seq, "gloo:barrier", std::nullopt), - context(std::move(context)), priorWork(std::move(priorWork)), tag(tag) {} - std::shared_ptr context; std::vector> priorWork{}; const uint32_t tag; + std::vector inputs{}; + + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return inputs; + } void run() override { // Wait on prior work to complete @@ -2306,7 +2422,7 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { } } - gloo::BarrierOptions opts(context); + gloo::BarrierOptions opts(context_); opts.setTag(tag); gloo::barrier(opts); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 917544d9e11372..6a9f52771ea630 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -64,6 +65,7 @@ class TORCH_API ProcessGroupGloo : public Backend { class TORCH_API AsyncWork : public Work { public: explicit AsyncWork( + std::shared_ptr context, std::vector> outputTensors, OpType opType, uint64_t seq, @@ -81,13 +83,22 @@ class TORCH_API ProcessGroupGloo : public Backend { c10::intrusive_ptr getFuture() override; uint64_t getSequencenumber() const override; - + std::chrono::milliseconds getTimeout() const; + virtual const std::vector getInputTensors() = 0; + virtual const std::vector getOutputTensors() = 0; + inline std::string getProfilerTitle() const { + return profilingTitle_; + } inline at::ThreadLocalState getTLS() const { return tls_; } protected: friend class ProcessGroupGloo; + // unique id used to tell the trace buffer that this + // work has completed + std::optional trace_id_; + std::shared_ptr context_; private: void finishWorkGloo(); @@ -100,6 +111,7 @@ class TORCH_API ProcessGroupGloo : public Backend { c10::intrusive_ptr future_; std::function recordFunctionBeforeCallback_; const uint64_t seq_; + std::string profilingTitle_; at::ThreadLocalState tls_; }; @@ -237,6 +249,8 @@ class TORCH_API ProcessGroupGloo : public Backend { return c10::make_intrusive(timeout); } + std::vector global_ranks_in_group; + std::string group_name; std::vector> devices; int threads; }; @@ -278,6 +292,8 @@ class TORCH_API ProcessGroupGloo : public Backend { return options_; } + const std::vector& groupRanks() const; + c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; @@ -438,6 +454,9 @@ class TORCH_API ProcessGroupGloo : public Backend { std::condition_variable workProduceCV_; std::condition_variable workConsumeCV_; uint64_t seq_{0}; + size_t local_id_; + std::shared_ptr pgStatus_ = + std::make_shared(); }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp b/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp index cd2d77ee1056ef..dae9d3044fa789 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp @@ -15,12 +15,12 @@ class AsyncAllreduceCUDADeviceWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {inputs}, OpType::ALLREDUCE, seq, "gloo:all_reduce", inputs), - context_(context), inputs_(inputs), reduceOp_(reduceOp) {} @@ -52,12 +52,19 @@ class AsyncAllreduceCUDADeviceWork : public ProcessGroupGloo::AsyncWork { } } + const std::vector getInputTensors() override { + return inputs_; + } + + const std::vector getOutputTensors() override { + return inputs_; + } + void synchronize() override { // TODO: is synchronization needed? } private: - std::shared_ptr context_; std::vector inputs_; const ReduceOp reduceOp_; }; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp index 0301ed9cb18ae1..1cf6cf25fff69b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp @@ -26,7 +26,7 @@ func(__VA_ARGS__); \ break; \ case ::at::ScalarType::Half: \ - func(__VA_ARGS__); \ + func(__VA_ARGS__); \ break; \ case ::at::ScalarType::BFloat16: \ func(__VA_ARGS__); \ @@ -59,7 +59,7 @@ func(args); \ break; \ case ::at::ScalarType::Half: \ - func(args); \ + func(args); \ break; \ case ::at::ScalarType::BFloat16: \ func(args); \ @@ -271,24 +271,23 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {inputs}, OpType::ALLREDUCE, seq, "gloo:all_reduce", inputs), - context(std::move(context)), inputs(inputs), reduceOp(std::move(reduceOp)), tag(tag) {} - std::shared_ptr context; std::vector inputs{}; const ReduceOp reduceOp; const uint32_t tag; void allreduce(std::vector& tensors) { const auto& scalarType = tensors[0].scalar_type(); - gloo::AllreduceOptions opts(context); + gloo::AllreduceOptions opts(context_); opts.setReduceFunction(getFunction(scalarType, reduceOp)); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); @@ -296,10 +295,18 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { // Gloo doesn't support AVG so we use SUM + division. if (reduceOp == ReduceOp::AVG) { - tensors[0] /= context->size; + tensors[0] /= context_->size; } } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return inputs; + } + void run() override { allreduce(inputs); } @@ -359,16 +366,15 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { uint32_t tag, uint64_t seq) : ProcessGroupGloo::AsyncWork( + std::move(context), {inputs}, OpType::_ALLREDUCE_SPARSE, seq, "gloo:sparse_all_reduce", inputs), - context(std::move(context)), inputs(inputs), tag(tag) {} - std::shared_ptr context; std::vector inputs{}; const uint32_t tag; @@ -472,9 +478,9 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { // Sanity check dimensionality across ranks. { - const auto expected = metadata[context->rank].sizes(); - for (const auto i : c10::irange(context->size)) { - if (i == context->rank) { + const auto expected = metadata[context_->rank].sizes(); + for (const auto i : c10::irange(context_->size)) { + if (i == context_->rank) { continue; } const auto actual = metadata[i].sizes(); @@ -487,11 +493,11 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { auto values = allgather_values(input, metadata); // Perform global reduction. - AT_ASSERT(static_cast(indices.size()) == context->size); - AT_ASSERT(static_cast(values.size()) == context->size); + AT_ASSERT(static_cast(indices.size()) == context_->size); + AT_ASSERT(static_cast(values.size()) == context_->size); auto output = at::sparse_coo_tensor( indices[0], values[0], input.sizes(), input.options()); - for (const auto i : c10::irange(1, context->size)) { + for (const auto i : c10::irange(1, context_->size)) { output += at::sparse_coo_tensor( indices[i], values[i], input.sizes(), input.options()); } @@ -510,24 +516,32 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { } } + const std::vector getInputTensors() override { + return inputs; + } + + const std::vector getOutputTensors() override { + return inputs; + } + private: std::vector allgather_metadata( const at::Tensor& tensor) { auto buffer = - at::zeros({context->size, SparseTensorMetadata::dim}, at::kLong); + at::zeros({context_->size, SparseTensorMetadata::dim}, at::kLong); // Prepare metadata vector (1 entry per rank) std::vector metadata; - metadata.reserve(context->size); - for (const auto i : c10::irange(context->size)) { + metadata.reserve(context_->size); + for (const auto i : c10::irange(context_->size)) { metadata.emplace_back(buffer.select(0, i)); } // Populate data for this rank - metadata[context->rank].populate_from_sparse_tensor(tensor); + metadata[context_->rank].populate_from_sparse_tensor(tensor); // Allgather metadata - gloo::AllgatherOptions opts(context); + gloo::AllgatherOptions opts(context_); opts.setOutput(buffer.mutable_data_ptr(), buffer.numel()); opts.setTag(tag); gloo::allgather(opts); @@ -540,7 +554,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { const std::vector& metadata) { const auto sparseDim = tensor.sparse_dim(); - std::vector counts(context->size); + std::vector counts(context_->size); size_t totalSize = 0; for (const auto i : c10::irange(metadata.size())) { counts[i] = metadata[i].nnz() * sparseDim; @@ -554,7 +568,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { auto input = tensor.indices().contiguous(); // Allgatherv indices. - gloo::AllgathervOptions opts(context); + gloo::AllgathervOptions opts(context_); opts.setInput( // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(input.const_data_ptr()), @@ -588,7 +602,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { denseNumel *= dim; } - std::vector counts(context->size); + std::vector counts(context_->size); int64_t totalSize = 0; for (const auto i : c10::irange(metadata.size())) { counts[i] = metadata[i].nnz() * denseNumel; @@ -598,7 +612,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { auto output = at::empty({totalSize}, tensor.scalar_type()); // Allgatherv indices. - gloo::AllgathervOptions opts(context); + gloo::AllgathervOptions opts(context_); // tensors copied from cuda may not be contiguous, get a contiguous // tensor before use its data_ptr at::Tensor valueTensor = tensor.values().contiguous(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp index a2dc53884326ee..33bb696cf2a857 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp @@ -65,7 +65,7 @@ struct WorkEntry { // That is, The process may be multi-threaded, and multiple threads may make // MPI calls, but only one at a time: MPI calls are not made concurrently from // two distinct threads (all MPI calls are serialized). However, with -// MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a singe process +// MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a single process // group. In other words, no more than 1 process group can be created globally. // // If you would like to use multiple ProcessGroupMPI, it requires your MPI diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e6836ca8fa8cec..26fe0af7ef2256 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -217,17 +217,6 @@ void syncStream( ncclEvent.block(ncclStream); } -// Given a ncclUniqueId, convert it to a string representation that can be put -// in the store. -std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { - const uint8_t* bytes = reinterpret_cast(&ncclID); - std::ostringstream oss; - for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { - oss << std::hex << static_cast(bytes[i]); - } - return oss.str(); -} - std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) { return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; } @@ -382,8 +371,7 @@ static std:: } } for (auto& ncclComm : allNCCLComms) { - std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); - ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); + ncclDumpMap[ncclComm->getUniqueHash()] = ncclComm->ncclCommDump(); } return ncclDumpMap; #else @@ -511,7 +499,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( bool isP2P, const char* profilingTitle, const std::optional>& inputs, - bool desyncDebug, bool enableTiming, bool cudaEventCacheEnabled, DebugLevel distDebugLevel) @@ -701,7 +688,7 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( } // Print the traceback of the collective at call time -void ProcessGroupNCCL::WorkNCCL::printTraceback() const { +std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const { // First step we get the corresponding record entry from FR, based on work's // trace_id_ std::optional entry = @@ -717,10 +704,19 @@ void ProcessGroupNCCL::WorkNCCL::printTraceback() const { // Wait for the future to complete or timeout auto status = future.wait_for(std::chrono::seconds(8)); if (status == std::future_status::ready) { - std::string tracebackStr = future.get(); - LOG(ERROR) << "Stack trace of the failed collective: \n" << tracebackStr; - } // else, symbolizer probably timed out, we skip logging the stack trace. - } else { + return future.get(); + } + } + return ""; +} + +// Print the traceback of the collective at call time +void ProcessGroupNCCL::WorkNCCL::printTraceback() const { + std::string tracebackStr = getTraceback(); + if (!tracebackStr.empty()) { + LOG(ERROR) << "Stack trace of the failed collective: \n" << tracebackStr; + } // else, symbolizer probably timed out, we skip logging the stack trace. + else { LOG(ERROR) << "Stack trace of the failed collective not found, " << "potentially because FlightRecorder is disabled. " @@ -952,15 +948,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); asyncErrorHandling_ = static_cast( getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); - desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || - (dist_debug_level_ >= DebugLevel::Detail); - rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); - propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false); enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); - heartbeat_ = 1ULL; cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true)); - waitTimeoutDumpInMilSec_ = - getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/); traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000); enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); // store_ usually is wrapped with PrefixStore and the prefix is different @@ -970,21 +959,25 @@ ProcessGroupNCCL::ProcessGroupNCCL( PrefixStore* prefixStore = dynamic_cast(store_.get()); globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; + auto desyncDebug = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (dist_debug_level_ >= DebugLevel::Detail); #ifdef ENABLE_NCCL_ERROR_CHECKING enableTiming_.store( - getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug); #endif // ENABLE_NCCL_ERROR_CHECKING if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) { TORCH_WARN_ONCE( "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); } + showSerializationWarning_ = + getCvarBool(TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING, true); if (blockingWait_) { LOG(INFO) << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created."; } else { - if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { + if (desyncDebug && asyncErrorHandling_ == NoHandling) { LOG(INFO) << logPrefix() << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " @@ -994,13 +987,34 @@ ProcessGroupNCCL::ProcessGroupNCCL( } } + // If deterministic mode is enabled, we need to disable the NVLS algorithm in + // NCCL. + // TODO: remove this once NVLS supports deterministic mode. + if (at::globalContext().deterministicAlgorithms()) { + // Check if user have already set NCCL_ALGO. If already set, leave it. + auto nccl_algo = c10::utils::get_env("NCCL_ALGO"); + if (!nccl_algo.has_value()) { + LOG(INFO) + << "torch deterministic mode is enabled, " + << "disabling NVLS algorithm in NCCL which can lead to non-deterministic reduction."; + // Sorry we have to disable NVLS for all collectives, be it all-reduce + // or all-gather, because NCCL does not support per-collective + // algorithm selection today. + c10::utils::set_env("NCCL_ALGO", "^NVLS"); + } + } + + // Initialize the heartbeat monitor/watchdog instance. This has to be done + // before the corresponding thread is launched to avoid the error. + heartbeatMonitor_ = std::make_unique(this); + watchdog_ = std::make_unique(this); + #ifdef ENABLE_NCCL_ERROR_CHECKING // in blockingWait mode, we don't need to enable the watchdog thread to check // the timeout or nccl error because the main thread would throw an exception // and it is the user's responsibility to handle the exception. if (!blockingWait_) { - ncclCommWatchdogThread_ = - std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); + watchdog_->start(); } #endif // ENABLE_NCCL_ERROR_CHECKING @@ -1020,8 +1034,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " << "NCCL version: " << ncclVersion << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ - << ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_ - << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug @@ -1032,7 +1044,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << traceBufferSize_ << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_ << ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_; - heartbeatMonitor_ = std::make_unique(this); getGlobalRankStartAndStride( options_->global_ranks_in_group, @@ -1047,11 +1058,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( // This call is idempotent. attachAllocatorHooks(); } - - // Enable Desync Debugger per user setting - if (desyncDebug_) { - desyncDebugger_.init(rank, size, globalRank(), getUid(), store_); - } } void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { @@ -1059,6 +1065,7 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " << device; initNCCLComm(key, device, OpType::ALLREDUCE); + eagerInit_ = true; } bool ProcessGroupNCCL::useNonblocking() { @@ -1166,6 +1173,14 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { // register future segments allocated in this pool (this call is idempotent). attachAllocatorHooks(); auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id()); + // TODO: + // if(pool->is_symmetric()) { + // Allgather to verify len(mempool.snapshot.segments) matches across GPUs + // Allgather to verify mempool.alloc_request_counter matches across GPUs + // add alloc_request_counter per mempool (How many allocations a mempool has + // served during its lifetime) this should guarantee pool is used in a + // symmetric/SPMD manner + // } for (const auto& segmentInfo : snapshot.segments) { TORCH_INTERNAL_ASSERT( segmentInfo.device == pool->device(), @@ -1174,7 +1189,9 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(segmentInfo.address), segmentInfo.total_size, - /*errorOnRereg=*/false); // ignores reregistration error + /*errorOnRereg=*/false, // ignores reregistration error + /*window=*/pool->is_symmetric()); // whether to use NCCL symmetric + // memory } } @@ -1205,7 +1222,8 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) { segmentInfo.device == pool->device(), "Mismatch between CUDA memory segment device and pool's device"); // NOLINTNEXTLINE(performance-no-int-to-ptr) - ncclComm->deregisterSegment(reinterpret_cast(segmentInfo.address)); + ncclComm->deregisterSegment( + reinterpret_cast(segmentInfo.address), pool->is_symmetric()); } } @@ -1268,7 +1286,7 @@ void ProcessGroupNCCL::waitForPendingWorks() { // completedWorkList_ before it finishes. // 3. We have three threads and two locks. // a. main thread (this function) grabs two locks atomically - // b. watchdog thread (watchdogHandler function) always grabs + // b. watchdog thread (runLoop function) always grabs // workMetaListMutex_ // first and then grabs completedWorkListMutex_. // c. hook thread (runHookLoop function) only grabs @@ -1396,7 +1414,7 @@ void ProcessGroupNCCL::abortCommsFromMap( bool ProcessGroupNCCL::abortComms( const std::optional& abortReason) { // Remove record from global ncclCommMemPoolMapMutex before aboarting, - // so that a new cache segment would not register to already aborded + // so that a new cache segment would not register to already aborted // communicators. Note that ncclCommMemPoolMap is a global container which may // contain other PG's communicators, thus we need to only erase communicators // for the current PG. @@ -1422,11 +1440,11 @@ void ProcessGroupNCCL::abort() { // communicators and signal the threads to exit. Joining on the threads could // potentially block and hence avoid it in this method. terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); + watchdog_->notify(); - // lauch abort asynchrounously and wait for it to complete or timeout + // launch abort asynchronously and wait for it to complete or timeout LOG(INFO) << logPrefix() - << "Launching ProcessGroupNCCL abort asynchrounously."; + << "Launching ProcessGroupNCCL abort asynchronously."; std::future fut = std::async(std::launch::async, [this]() { return this->abortComms(); }); @@ -1477,10 +1495,8 @@ void ProcessGroupNCCL::shutdown() { // anymore because I am going to destroy them now LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread."; terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - } + watchdog_->notify(); + watchdog_->join(); if (onCompletionHookThread_.joinable()) { onCompletionHookThread_.join(); } @@ -1539,15 +1555,12 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { // Make sure we've told threads to stop; doesn't hurt if we'd done so before. // Tell watchdog and onCompletionHook: terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); + watchdog_->notify(); // Tell heartbeat thread: heartbeatMonitor_->stop(); // Wait for all threads to finish before returning - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; - } + watchdog_->join(); heartbeatMonitor_->join(); if (onCompletionHookThread_.joinable()) { onCompletionHookThread_.join(); @@ -1564,7 +1577,6 @@ bool ProcessGroupNCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) { // multiple calls in one runtime. User is responsible for preserving the // output file from an earlier call before a later call overwrites it. static std::mutex writeDebugInfoMutex; - std::lock_guard lock(writeDebugInfoMutex); LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info. Include stack trace: " @@ -1574,6 +1586,9 @@ bool ProcessGroupNCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) { // their customized writer by inheriting `DebugInfoWriter` via // `registerDebugInfoWriter`. auto ncclTrace = dump_nccl_trace(true, includeStackTrace, false); + // dump_nccl_trace will hang so we don't grab the global lock until we get + // the trace. + std::lock_guard lock(writeDebugInfoMutex); DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " << writer.getWriterTarget(); @@ -1631,11 +1646,15 @@ std::string ProcessGroupNCCL::HeartbeatMonitor::getNCCLWatchdogTimeoutExitMsg( void ProcessGroupNCCL::HeartbeatMonitor::setLastWorkListUpdateTime( std::chrono::time_point time) { - // We intentially let the race condition to happen but this is ok + // We intentionally let the race condition to happen but this is ok // as long as we update the time, we know we are making progress. lastWorkListUpdateTime_ = time; } +int ProcessGroupNCCL::HeartbeatMonitor::getDumpTimeout() const { + return waitTimeoutDumpInMilSec_; +} + ProcessGroupNCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupNCCL* pg) { pg_ = pg; heartbeatTimeoutInSec_ = @@ -1689,6 +1708,8 @@ void ProcessGroupNCCL::HeartbeatMonitor::join() { void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { c10::setThreadName("pt_nccl_heartbt"); + STATIC_SCOPED_WAIT_COUNTER( + pytorch.ProcessGroupNCCL__HeartbeatMonitor__runLoop); uint64_t heartBeatCounter = 0ULL; std::string errorMsg; @@ -1699,6 +1720,9 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { auto lastTimePollStore = std::chrono::steady_clock::now(); auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); std::optional dumpPipe = std::nullopt; + // Use a pool to temporarily store the futures to avoid blocking when the code + // exits the scope of when future is generated by std::async. + std::vector> futures; if (pg_->getUid() == 0) { // DumpPipe is one per-trainer process, and its convenient to name them @@ -1733,7 +1757,7 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { // 1. The current rank is the first to observe a timeout in watchdog. // (shouldDump_ was set to true by the watchdog thread). // 2. Other ranks detected the timeout and signal the current rank to - // dump. In addtion, monitor threads will dump if watchdog threads has no + // dump. In addition, monitor threads will dump if watchdog threads has no // heartbeat or dumpPipe is not empty. if (shouldDump_.load()) { errorMsg = getNCCLWatchdogTimeoutErrorMsg("this local rank"); @@ -1829,9 +1853,11 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { // recorder and dump. After dump, the training should continue. if (dumpPipe.has_value() && dumpPipe->shouldDump()) { // best effort dump, not waiting for the dump here - std::future fut = std::async(std::launch::async, [this]() { + LOG(INFO) << pg_->logPrefix() + << "Dump signal received through pipe, triggering FR dump."; + futures.emplace_back(std::async(std::launch::async, [this]() { return this->pg_->dumpDebuggingInfo(); - }); + })); } } LOG(ERROR) << errorMsg; @@ -1881,6 +1907,7 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { // If we failed to dump, try dumping without stack trace in the 2nd // iteration. dumpStackTrace = false; + futures.emplace_back(std::move(asyncDebugDump)); } debugLog.integers["trace_enabled"] = int64_t(dumpStackTrace); auto logger = c10d::C10dLogger::getLogger(); @@ -1929,7 +1956,8 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { // // Or we get stuck in destructors, we will sleep for some time before calling // std::abort() to kill the whole process. - if (shouldDump_.load() && !terminateHeartbeatMonitorThread_.load()) { + if ((pg_->terminateProcessGroup_.load() || shouldDump_.load()) && + !terminateHeartbeatMonitorThread_.load()) { std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_)); LOG(INFO) << pg_->logPrefix() << "slept for " << heartbeatTimeoutInSec_ @@ -1965,27 +1993,70 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { } } -void ProcessGroupNCCL::ncclCommWatchdog() { +ProcessGroupNCCL::Watchdog::Watchdog(ProcessGroupNCCL* pg) { + pg_ = pg; + heartbeat_ = 1ULL; + rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); + propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (pg_->dist_debug_level_ >= DebugLevel::Detail); + + // print out ENV settings for the watchdog thread. + LOG(INFO) << pg_->logPrefix() << "PGNCCL Watchdog environments: " + << "TORCH_NCCL_RETHROW_CUDA_ERRORS: " << rethrowCUDAErrors_ + << ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_; + + // Enable Desync Debugger per user setting + if (desyncDebug_) { + desyncDebugger_.init( + pg_->getRank(), + pg_->getSize(), + pg_->globalRank(), + pg_->getUid(), + pg_->store_); + } +} + +void ProcessGroupNCCL::Watchdog::notify() { + workMetaListCV_.notify_one(); +} + +void ProcessGroupNCCL::Watchdog::start() { + TORCH_CHECK( + !ncclCommWatchdogThread_.joinable(), "Watchdog thread already started"); + ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::Watchdog::run, this); +} + +void ProcessGroupNCCL::Watchdog::join() { + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << pg_->logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } +} + +void ProcessGroupNCCL::Watchdog::run() { c10::setThreadName("pt_nccl_watchdg"); + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__Watchdog__run); try { - VLOG(2) << logPrefix() << "Process group watchdog thread started!"; - heartbeatMonitor_->start(); - watchdogHandler(); - VLOG(2) << logPrefix() + VLOG(2) << pg_->logPrefix() << "Process group watchdog thread started!"; + pg_->heartbeatMonitor_->start(); + runLoop(); + VLOG(2) << pg_->logPrefix() << "Process group watchdog thread terminated normally"; } catch (std::exception& e) { if (std::string(e.what()).find("driver shutting down") != std::string::npos) { VLOG(2) - << logPrefix() + << pg_->logPrefix() << "main process destroyed cuda before watchdog loop exited, terminating watchdog." << " (Watchdog caught exception: " << e.what(); } else { - // Append error message reported from watchdogHandler + // Append error message reported from runLoop const auto exitMsg = c10::str( - logPrefix(), + pg_->logPrefix(), "Process group watchdog thread terminated with exception: ", e.what()); LOG(ERROR) << exitMsg; @@ -2000,7 +2071,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { } } catch (...) { const auto exitMsg = c10::str( - logPrefix(), + pg_->logPrefix(), "Process group watchdog thread terminated with exception: unknown"); LOG(ERROR) << exitMsg; watchDogException_ = @@ -2009,176 +2080,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { } } -// Initialize and enable DesyncDebugger -void ProcessGroupNCCL::DesyncDebugger::init( - int rank, - int size, - int globalRank, - int pgId, - c10::intrusive_ptr store) { - rank_ = rank; - size_ = size; - globalRank_ = globalRank; - pgId_ = pgId; - store_ = std::move(store); - enabled_ = true; - traceKeyStart_ = getTraceStartKey("NCCL", rank); - traceKeyEnd_ = getTraceEndKey("NCCL", rank); -} - -// Run desync debug. This function is called by watchdog at time of timeout. -void ProcessGroupNCCL::DesyncDebugger::run() { - if (!enabled_) - return; - auto logPrefix = c10::str("Rank ", rank_); - ::c10d::C10dLoggingData log; - log.integers["pg_id"] = pgId_; - log.integers["rank"] = rank_; - log.integers["global_rank"] = globalRank_; - log.integers["world_size"] = size_; - // Use this to differentiate between flight recorder and desync debug report. - log.strings["flight_recorder_version"] = "-1"; - - try { - std::string desyncMsg = retrieveDesyncReport(store_, "NCCL", rank_, size_); - log.strings["status"] = "SUCCESS"; - LOG(ERROR) << logPrefix << desyncMsg; - } catch (const std::exception& e) { - log.strings["status"] = "EXCEPTION"; - log.strings["exception_msg"] = e.what(); - enabled_ = false; - LOG(ERROR) << logPrefix - << " Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " - << " Please file an issue. Error: " << e.what(); - } catch (...) { - enabled_ = false; - log.strings["status"] = "EXCEPTION"; - log.strings["exception_msg"] = "Unknown exception"; - LOG(ERROR) - << logPrefix - << " Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." - << " Please file an issue."; - } - auto logger = c10d::C10dLogger::getLogger(); - if (logger) { - logger->log(log); - } -} - -// Log work start to store. -void ProcessGroupNCCL::DesyncDebugger::logWorkStart(WorkNCCL& work) { - if (!enabled_) - return; - if (work.startTraceUpdated_) - return; - - work.startTraceUpdated_ = true; - // If not successful, disable the debugger - enabled_ = c10d::traceUpdate( - store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); -} - -// Log work end to store. -void ProcessGroupNCCL::DesyncDebugger::logWorkEnd(WorkNCCL& work) { - if (!enabled_) - return; - - // In case the start of the work hasn't been logged - if (!work.startTraceUpdated_) { - logWorkStart(work); - } - - // If not successful, disable the debugger - enabled_ = c10d::traceUpdate( - store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); -} - -// We want to have both PG ID and global unique ID (guid) for the logging -// prefix. PG ID records how many ProcessGroupNCCL objects were created on a -// specific rank and is a stable index across ranks, which lets users reason -// about, for example, the second PG we initialized on this rank is for FSDP, -// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or -// group name) is a global unique ID across ranks. The guid is either a hash of -// all the ranks in the group or a counter of how many times -// `_process_group_name` is called, essentially it means how many times we -// have PGs users have created. Before using split_group, even if -// we are creating a new sub-PG, all ranks have to call the API at the same -// time, and this makes `group_name` a unique identifier for a group (PG). -std::string ProcessGroupNCCL::createLogPrefix() const { - if (!pg_desc_.empty() && pg_desc_ != "undefined") { - return c10::str( - "[PG ID ", - local_id_, - " PG GUID ", - pg_uid_, - "(", - pg_desc_, - ") Rank ", - rank_, - "] "); - } - return c10::str( - "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); -} - -const std::string& ProcessGroupNCCL::logPrefix() const { - return logPrefix_; -} - -const int& ProcessGroupNCCL::globalRank() const { - static int globalRank = rank_; - return globalRank; -} - -const c10::intrusive_ptr& ProcessGroupNCCL::globalStore() const { - return globalStore_; -} - -const std::vector& ProcessGroupNCCL::groupRanks() const { - if (options_->global_ranks_in_group.empty() && local_id_ == 0) { - static std::vector globalRanks(size_); - std::iota(globalRanks.begin(), globalRanks.end(), 0); - return globalRanks; - } - return options_->global_ranks_in_group; -} - -void ProcessGroupNCCL::addEphemeralTimeout( - const std::chrono::milliseconds& timeout) { - std::lock_guard timeoutLock(mtxTimeoutExtension_); - ephemeralTimeoutActive_ += timeout; -} - -bool ProcessGroupNCCL::verifyWorkTimeoutForTest( - const c10::intrusive_ptr& work, - const std::chrono::milliseconds& timeout) { - // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. - if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { - // workNCCL is now a c10::intrusive_ptr - return workNCCL->opTimeout_ == timeout; - } - C10_THROW_ERROR( - DistBackendError, "Non c10d::WorkNCCL object returned from collective"); -} - -void ProcessGroupNCCL::broadcastSignal( - c10::intrusive_ptr& store, - const std::string& signal, - int srcRank) { - try { - auto vec = std::vector( - reinterpret_cast(&srcRank), - reinterpret_cast(&srcRank) + sizeof(srcRank)); - store->set(signal, vec); - LOG(INFO) << logPrefix() << "Broadcasting signal " << signal - << " to other ranks via TCPStore."; - } catch (const std::exception& e) { - LOG(ERROR) << logPrefix() << "Failed to broadcast signal " << signal - << " through TCPStore. Error: " << e.what(); - } -} - -int ProcessGroupNCCL::getSignalSrcRank( +int ProcessGroupNCCL::Watchdog::getSignalSrcRank( c10::intrusive_ptr& store, const std::string& signal) { // This function is 'non blocking'. We first 'check' if the key exists in the @@ -2188,7 +2090,7 @@ int ProcessGroupNCCL::getSignalSrcRank( try { signalExists = store->check({signal}); } catch (const std::exception& e) { - LOG(WARNING) << logPrefix() << "Failed to check the signal " << signal + LOG(WARNING) << pg_->logPrefix() << "Failed to check the signal " << signal << " on TCPStore, " << e.what(); } if (!signalExists) { @@ -2200,7 +2102,7 @@ int ProcessGroupNCCL::getSignalSrcRank( try { vec = store->get(std::string(signal)); } catch (const std::exception& e) { - LOG(ERROR) << logPrefix() << "Failed to get source rank of the signal " + LOG(ERROR) << pg_->logPrefix() << "Failed to get source rank of the signal " << signal << " from TCPStore." << e.what(); } TORCH_CHECK_WITH( @@ -2211,76 +2113,40 @@ int ProcessGroupNCCL::getSignalSrcRank( return srcRank; } -void ProcessGroupNCCL::broadcastDumpSignal() { - // broadcast dump signal to all other global ranks. - auto global_store = globalStore(); - broadcastSignal(global_store, std::string(kStoreDumpKey), globalRank()); - // signal the local rank to start dumping - if (!shouldDump_.load()) { - LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; - // signal the monitor thread on PG0 to start dumping - shouldDump_.store(true); - } -} - -void ProcessGroupNCCL::checkAndSetRemoteError() { +void ProcessGroupNCCL::Watchdog::checkAndSetRemoteError() { // if the error is already set, no need to check again - if (getError() != ErrorType::SUCCESS) { + if (pg_->getError() != ErrorType::SUCCESS) { return; } // key/signal to read from the tcpstore is a string and pg specific: // format is: remote_error:pg_uid int remoteErrorRank = getSignalSrcRank( - store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_); + pg_->store_, std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_); if (remoteErrorRank != -1) { - std::lock_guard lock(errorMutex_); - error_ = ErrorType::REMOTE_ERROR; + std::lock_guard lock(pg_->errorMutex_); + pg_->error_ = ErrorType::REMOTE_ERROR; LOG(ERROR) << c10::str( - logPrefix(), " remote error detected from rank: ", remoteErrorRank); + pg_->logPrefix(), + " remote error detected from rank: ", + remoteErrorRank); } } -// NCCL recommends to evenly distribute ncclUniqueIds across the ranks -// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config -// Let’s consider an example where: -// nRanks = 10 (total ranks), -// nIds = 3 (roots), -// rmr = 10 % 3 = 1 (1 larger group), -// rpr = 10 / 3 = 3 (base number of ranks per group). -// rlim = 4 -// Output root: -// For ranks [0, 1, 2, 3], root rank is 0 and index is 0. -// For ranks [4, 5, 6], root rank is 4 and index is 1. -// For ranks [7, 8, 9], root rank is 7 and index is 2. -static int getRootIndex(const int rank, const int nRanks, const int nIds) { - const int rmr = nRanks % nIds; - const int rpr = nRanks / nIds; - // For the first rmr roots, we assign one more rank to the root. - const int rlim = rmr * (rpr + 1); - if (rank < rlim) { - // Root with `rpr + 1` ranks, (0, 1, 2, ..., rmr - 1). - return rank % (rpr + 1) ? -1 : rank / (rpr + 1); - } else { - // Root with `rpr` ranks, (rmr, rmr + 1, ..., nIds - 1). - return (rank - rlim) % rpr ? -1 : ((rank - rlim) / rpr) + rmr; - } -} - -void ProcessGroupNCCL::watchdogHandler() { +void ProcessGroupNCCL::Watchdog::runLoop() { bool done = false; - heartbeatMonitor_->setLastWorkListUpdateTime( + pg_->heartbeatMonitor_->setLastWorkListUpdateTime( std::chrono::steady_clock::now()); auto lastStatusUpdateTime = std::chrono::steady_clock::now(); std::list completedWorkList; - while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(workMetaListMutex_); + while (!done || !pg_->terminateProcessGroup_.load()) { + std::unique_lock lock(pg_->workMetaListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. workMetaListCV_.wait_for( lock, std::chrono::milliseconds(kWatchdogThreadSleepMillis), - [&]() -> bool { return terminateProcessGroup_.load(); }); + [&]() -> bool { return pg_->terminateProcessGroup_.load(); }); // Bump up heart beat by one. heartbeat_++; @@ -2292,9 +2158,9 @@ void ProcessGroupNCCL::watchdogHandler() { logPrefix(), "NCCL Work update periodically: ", "last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, + pg_->pgStatus_->lastEnqueuedSeq, ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, + pg_->pgStatus_->lastCompletedSeq, "."); #endif // LOG_EVERY_MS auto logger = ::c10d::C10dLogger::getLogger(); @@ -2304,31 +2170,33 @@ void ProcessGroupNCCL::watchdogHandler() { kWorkStatusUpdatePeriodMs) { ::c10d::C10dLoggingData data; // logging integers - data.integers["pg_id"] = static_cast(local_id_); - data.integers["rank"] = rank_; - data.integers["global_rank"] = globalRank(); - data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; - data.integers["last_started_work"] = pgStatus_->lastStartedSeq; - data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; + data.integers["pg_id"] = static_cast(pg_->local_id_); + data.integers["rank"] = pg_->rank_; + data.integers["global_rank"] = pg_->globalRank(); + data.integers["last_enqueued_work"] = pg_->pgStatus_->lastEnqueuedSeq; + data.integers["last_started_work"] = pg_->pgStatus_->lastStartedSeq; + data.integers["last_completed_work"] = pg_->pgStatus_->lastCompletedSeq; data.integers["last_enqueued_numel_in"] = - static_cast(pgStatus_->lastEnqueuedNumelIn); + static_cast(pg_->pgStatus_->lastEnqueuedNumelIn); data.integers["last_enqueued_numel_out"] = - static_cast(pgStatus_->lastEnqueuedNumelOut); + static_cast(pg_->pgStatus_->lastEnqueuedNumelOut); data.integers["last_completed_numel_in"] = - static_cast(pgStatus_->lastCompletedNumelIn); + static_cast(pg_->pgStatus_->lastCompletedNumelIn); data.integers["last_completed_numel_out"] = - static_cast(pgStatus_->lastCompletedNumelOut); + static_cast(pg_->pgStatus_->lastCompletedNumelOut); data.integers["last_started_numel_in"] = - static_cast(pgStatus_->lastStartedNumelIn); + static_cast(pg_->pgStatus_->lastStartedNumelIn); data.integers["last_started_numel_out"] = - static_cast(pgStatus_->lastStartedNumelOut); + static_cast(pg_->pgStatus_->lastStartedNumelOut); // logging strings - data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; - data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; + data.strings["last_enqueued_work_name"] = + pg_->pgStatus_->lastEnqueuedWorkName; + data.strings["last_started_work_name"] = + pg_->pgStatus_->lastStartedWorkName; data.strings["last_completed_work_name"] = - pgStatus_->lastCompletedWorkName; - data.strings["pg_name"] = pg_uid_; - data.strings["pg_desc"] = pg_desc_; + pg_->pgStatus_->lastCompletedWorkName; + data.strings["pg_name"] = pg_->pg_uid_; + data.strings["pg_desc"] = pg_->pg_desc_; logger->log(data); lastStatusUpdateTime = std::chrono::steady_clock::now(); } @@ -2338,7 +2206,7 @@ void ProcessGroupNCCL::watchdogHandler() { checkAndSetRemoteError(); } - for (auto it = workMetaList_.begin(); it != workMetaList_.end(); + for (auto it = pg_->workMetaList_.begin(); it != pg_->workMetaList_.end(); /* no increment */) { auto& work = *it; // When terminateProcessGroup_ is true, communicators have already been @@ -2347,15 +2215,15 @@ void ProcessGroupNCCL::watchdogHandler() { // workMetaList_ // check NCCL errors first - if (!terminateProcessGroup_.load()) { + if (!pg_->terminateProcessGroup_.load()) { work.checkAndSetException(); } if (work.exception()) { // set the error to the first error found - std::lock_guard lock(errorMutex_); - if (error_ == ErrorType::SUCCESS) { - error_ = ErrorType::COMM_ERROR; + std::lock_guard lock(pg_->errorMutex_); + if (pg_->error_ == ErrorType::SUCCESS) { + pg_->error_ = ErrorType::COMM_ERROR; } } @@ -2366,9 +2234,9 @@ void ProcessGroupNCCL::watchdogHandler() { // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is // turned on; otherwise, run() is no-op) if (timedout) { - std::lock_guard lock(errorMutex_); - if (error_ == ErrorType::SUCCESS) { - error_ = ErrorType::TIMEOUT; + std::lock_guard lock(pg_->errorMutex_); + if (pg_->error_ == ErrorType::SUCCESS) { + pg_->error_ = ErrorType::TIMEOUT; } desyncDebugger_.run(); } @@ -2376,13 +2244,13 @@ void ProcessGroupNCCL::watchdogHandler() { // If work hits an exception (either an error or timeout) if (work.exception()) { LOG(ERROR) << c10::str( - logPrefix(), + pg_->logPrefix(), " failure detected by watchdog at work sequence id: ", work.seq_, " PG status: last enqueued work: ", - pgStatus_->lastEnqueuedSeq, + pg_->pgStatus_->lastEnqueuedSeq, ", last completed work: ", - pgStatus_->lastCompletedSeq); + pg_->pgStatus_->lastCompletedSeq); // Print the traceback of the collective at call time work.printTraceback(); @@ -2391,31 +2259,33 @@ void ProcessGroupNCCL::watchdogHandler() { // key/signal to write in the tcpstore is a string and pg specific: // format is: remote_error:pg_uid if (propagatePgError_) { - broadcastSignal( - store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_, rank_); + pg_->broadcastSignal( + pg_->store_, + std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_, + pg_->rank_); } // try to notify other ranks via global TCPStore to dump the flight // recorder when a collective timeout or exception happens. Flight // recorder behavior is independent of desync Debug. - broadcastDumpSignal(); + pg_->broadcastDumpSignal(); // Give time for dumping before throwing exception for all ranks. // It is hard to presume or control what the pattern of watchdog might // look like, so it is better to let all ranks universally sleep for a // short period of time, in this case, 60 seconds, which is also the // maximum time we leave for FR dump. - std::this_thread::sleep_for( - std::chrono::milliseconds(waitTimeoutDumpInMilSec_ * 4)); + std::this_thread::sleep_for(std::chrono::milliseconds( + pg_->heartbeatMonitor_->getDumpTimeout() * 4)); - if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { + if (SHOULD_CLEAN_UP(pg_->asyncErrorHandling_)) { // Abort work and corresponding communicators work.abort(); // PG level abort, which would abort all other communicators on this // rank - abortComms(); + pg_->abortComms(); } // Throw exception - work.handleException(asyncErrorHandling_); + work.handleException(pg_->asyncErrorHandling_); } // Work status logging for desync debug @@ -2424,12 +2294,12 @@ void ProcessGroupNCCL::watchdogHandler() { // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start - if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && + if (pg_->pgStatus_->lastStartedSeq < static_cast(work.seq_) && work.isStarted()) { - pgStatus_->lastStartedSeq = static_cast(work.seq_); - pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); - pgStatus_->lastStartedNumelIn = work.numelIn_; - pgStatus_->lastStartedNumelOut = work.numelOut_; + pg_->pgStatus_->lastStartedSeq = static_cast(work.seq_); + pg_->pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); + pg_->pgStatus_->lastStartedNumelIn = work.numelIn_; + pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; } // allow watchdog to do an event query on a side thread @@ -2449,10 +2319,29 @@ void ProcessGroupNCCL::watchdogHandler() { // to be destructed, so we transfer the work's shelf to a shelves // structure owned by the PG. if (!work.stashed_for_allocator_safety_->empty()) { - std::lock_guard lock(shelvesMutex_); + std::lock_guard lock(pg_->shelvesMutex_); // We are just pushing back a shared_ptr here, so the cost should be // minimal - shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); + pg_->shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); + } + + if (pg_->enableTiming_ && logger) { + ::c10d::C10dLoggingData data; + // logging integers + data.strings["collective_duration"] = + std::to_string(work.getDuration()); + data.integers["global_rank"] = pg_->globalRank(); + data.integers["pg_id"] = static_cast(pg_->local_id_); + data.strings["pg_name"] = pg_->pg_uid_; + data.strings["pg_desc"] = pg_->pg_desc_; + data.integers["pg_rank"] = pg_->rank_; + data.integers["world_size"] = pg_->size_; + data.strings["comm_backend"] = "nccl"; + data.strings["comm_backend_version"] = getNcclVersion(); + // TODO: We see errors for this line, revert it for now. + data.strings["collective_stack"] = ""; + data.strings["collective_name"] = opTypeToString(work.opType_); + logger->log(data); } // Work status logging for desync debug @@ -2465,29 +2354,30 @@ void ProcessGroupNCCL::watchdogHandler() { } { // Reset the timeout and first work if the work is completed. - std::lock_guard timeoutLock(mtxTimeoutExtension_); + std::lock_guard timeoutLock(pg_->mtxTimeoutExtension_); if (work.ownedEphermeralTimeout_.count() > 0) { - ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; - ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; + pg_->ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; + pg_->ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; } } - pgStatus_->lastCompletedSeq = static_cast(work.seq_); - pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); - pgStatus_->lastCompletedNumelIn = work.numelIn_; - pgStatus_->lastCompletedNumelOut = work.numelOut_; + pg_->pgStatus_->lastCompletedSeq = static_cast(work.seq_); + pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); + pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_; + pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_; FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); - if (onCompletionHook_) { + if (pg_->onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook // thread { - const std::lock_guard lock(completedWorkListMutex_); - completedWorkList_.splice( - completedWorkList_.end(), workMetaList_, it++); + const std::lock_guard lock( + pg_->completedWorkListMutex_); + pg_->completedWorkList_.splice( + pg_->completedWorkList_.end(), pg_->workMetaList_, it++); } - completedWorkListCV_.notify_one(); + pg_->completedWorkListCV_.notify_one(); } else { - it = workMetaList_.erase(it); - heartbeatMonitor_->setLastWorkListUpdateTime( + it = pg_->workMetaList_.erase(it); + pg_->heartbeatMonitor_->setLastWorkListUpdateTime( std::chrono::steady_clock::now()); } } else { @@ -2499,7 +2389,221 @@ void ProcessGroupNCCL::watchdogHandler() { // in case processing is slowed down (but not hung) by cuda api contention heartbeat_++; } - done = workMetaList_.empty(); + done = pg_->workMetaList_.empty(); + } +} + +uint64_t ProcessGroupNCCL::Watchdog::getHeartbt() const { + return heartbeat_.load(); +} + +void ProcessGroupNCCL::Watchdog::setDesyncDebug(bool desyncDebug) { + desyncDebug_ = desyncDebug; +} + +// Initialize and enable DesyncDebugger +void ProcessGroupNCCL::DesyncDebugger::init( + int rank, + int size, + int globalRank, + int pgId, + c10::intrusive_ptr store) { + rank_ = rank; + size_ = size; + globalRank_ = globalRank; + pgId_ = pgId; + store_ = std::move(store); + enabled_ = true; + traceKeyStart_ = getTraceStartKey("NCCL", rank); + traceKeyEnd_ = getTraceEndKey("NCCL", rank); +} + +// Run desync debug. This function is called by watchdog at time of timeout. +void ProcessGroupNCCL::DesyncDebugger::run() { + if (!enabled_) + return; + auto logPrefix = c10::str("Rank ", rank_); + ::c10d::C10dLoggingData log; + log.integers["pg_id"] = pgId_; + log.integers["rank"] = rank_; + log.integers["global_rank"] = globalRank_; + log.integers["world_size"] = size_; + // Use this to differentiate between flight recorder and desync debug report. + log.strings["flight_recorder_version"] = "-1"; + + try { + std::string desyncMsg = retrieveDesyncReport(store_, "NCCL", rank_, size_); + log.strings["status"] = "SUCCESS"; + LOG(ERROR) << logPrefix << desyncMsg; + } catch (const std::exception& e) { + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = e.what(); + enabled_ = false; + LOG(ERROR) << logPrefix + << " Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " + << " Please file an issue. Error: " << e.what(); + } catch (...) { + enabled_ = false; + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = "Unknown exception"; + LOG(ERROR) + << logPrefix + << " Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." + << " Please file an issue."; + } + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(log); + } +} + +// Log work start to store. +void ProcessGroupNCCL::DesyncDebugger::logWorkStart(WorkNCCL& work) { + if (!enabled_) + return; + if (work.startTraceUpdated_) + return; + + work.startTraceUpdated_ = true; + // If not successful, disable the debugger + enabled_ = c10d::traceUpdate( + store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_)); +} + +// Log work end to store. +void ProcessGroupNCCL::DesyncDebugger::logWorkEnd(WorkNCCL& work) { + if (!enabled_) + return; + + // In case the start of the work hasn't been logged + if (!work.startTraceUpdated_) { + logWorkStart(work); + } + + // If not successful, disable the debugger + enabled_ = c10d::traceUpdate( + store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_)); +} + +// We want to have both PG ID and global unique ID (guid) for the logging +// prefix. PG ID records how many ProcessGroupNCCL objects were created on a +// specific rank and is a stable index across ranks, which lets users reason +// about, for example, the second PG we initialized on this rank is for FSDP, +// and corresponds with PG ID = 1 on other ranks as well. Unlike PG ID, guid (or +// group name) is a global unique ID across ranks. The guid is either a hash of +// all the ranks in the group or a counter of how many times +// `_process_group_name` is called, essentially it means how many times we +// have PGs users have created. Before using split_group, even if +// we are creating a new sub-PG, all ranks have to call the API at the same +// time, and this makes `group_name` a unique identifier for a group (PG). +std::string ProcessGroupNCCL::createLogPrefix() const { + if (!pg_desc_.empty() && pg_desc_ != "undefined") { + return c10::str( + "[PG ID ", + local_id_, + " PG GUID ", + pg_uid_, + "(", + pg_desc_, + ") Rank ", + rank_, + "] "); + } + return c10::str( + "[PG ID ", local_id_, " PG GUID ", pg_uid_, " Rank ", rank_, "] "); +} + +const std::string& ProcessGroupNCCL::logPrefix() const { + return logPrefix_; +} + +const int& ProcessGroupNCCL::globalRank() const { + static int globalRank = rank_; + return globalRank; +} + +const c10::intrusive_ptr& ProcessGroupNCCL::globalStore() const { + return globalStore_; +} + +const std::vector& ProcessGroupNCCL::groupRanks() const { + if (options_->global_ranks_in_group.empty() && local_id_ == 0) { + static std::vector globalRanks(size_); + std::iota(globalRanks.begin(), globalRanks.end(), 0); + return globalRanks; + } + return options_->global_ranks_in_group; +} + +void ProcessGroupNCCL::addEphemeralTimeout( + const std::chrono::milliseconds& timeout) { + std::lock_guard timeoutLock(mtxTimeoutExtension_); + ephemeralTimeoutActive_ += timeout; +} + +bool ProcessGroupNCCL::verifyWorkTimeoutForTest( + const c10::intrusive_ptr& work, + const std::chrono::milliseconds& timeout) { + // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. + if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { + // workNCCL is now a c10::intrusive_ptr + return workNCCL->opTimeout_ == timeout; + } + C10_THROW_ERROR( + DistBackendError, "Non c10d::WorkNCCL object returned from collective"); +} + +void ProcessGroupNCCL::broadcastSignal( + c10::intrusive_ptr& store, + const std::string& signal, + int srcRank) { + try { + auto vec = std::vector( + reinterpret_cast(&srcRank), + reinterpret_cast(&srcRank) + sizeof(srcRank)); + store->set(signal, vec); + LOG(INFO) << logPrefix() << "Broadcasting signal " << signal + << " to other ranks via TCPStore."; + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() << "Failed to broadcast signal " << signal + << " through TCPStore. Error: " << e.what(); + } +} + +void ProcessGroupNCCL::broadcastDumpSignal() { + // broadcast dump signal to all other global ranks. + broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank()); + // signal the local rank to start dumping + if (!shouldDump_.load()) { + LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); + } +} + +// NCCL recommends to evenly distribute ncclUniqueIds across the ranks +// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config +// Let’s consider an example where: +// nRanks = 10 (total ranks), +// nIds = 3 (roots), +// rmr = 10 % 3 = 1 (1 larger group), +// rpr = 10 / 3 = 3 (base number of ranks per group). +// rlim = 4 +// Output root: +// For ranks [0, 1, 2, 3], root rank is 0 and index is 0. +// For ranks [4, 5, 6], root rank is 4 and index is 1. +// For ranks [7, 8, 9], root rank is 7 and index is 2. +static int getRootIndex(const int rank, const int nRanks, const int nIds) { + const int rmr = nRanks % nIds; + const int rpr = nRanks / nIds; + // For the first rmr roots, we assign one more rank to the root. + const int rlim = rmr * (rpr + 1); + if (rank < rlim) { + // Root with `rpr + 1` ranks, (0, 1, 2, ..., rmr - 1). + return rank % (rpr + 1) ? -1 : rank / (rpr + 1); + } else { + // Root with `rpr` ranks, (rmr, rmr + 1, ..., nIds - 1). + return (rank - rlim) % rpr ? -1 : ((rank - rlim) / rpr) + rmr; } } @@ -2924,7 +3028,7 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( bool useScalableInit = false; // (nranks / nroots) == 128 was the default NCCL recommended - // accoring to + // according to // https://github.com/pytorch/pytorch/pull/136789#discussion_r1779171615. auto ranksPerRoot = getCvarInt(TORCH_NCCL_RANKS_PER_ROOT, 128); #if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG) @@ -3205,7 +3309,6 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( profilingTitle, profilingTitle != nullptr ? std::optional>(inputs) : std::nullopt, - desyncDebug_, enableTiming_.load(), cudaEventCacheEnabled_.load(), dist_debug_level_); @@ -3222,7 +3325,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( // - initially, moved record() into workEnqueue(), but found that makes it // hard to get access to profilingTitle, // inputs, and outputs for metadata recording, and we don't want to attach - // these objects to the Work becuase it has implications for keeping those + // these objects to the Work because it has implications for keeping those // tensors alive longer and adds overhead when copying Work objects // between threads r->trace_id_ = FlightRecorderCUDA::get()->record( @@ -3326,7 +3429,7 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; uint64_t ProcessGroupNCCL::getWatchdogHeartbt() const { - return heartbeat_.load(); + return watchdog_->getHeartbt(); } void ProcessGroupNCCL::startCoalescing() { @@ -3337,7 +3440,7 @@ void ProcessGroupNCCL::startCoalescing() { // ops from a coalesce group into the flight recorder, we want to have the // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during // start, which has one minor downside- we burn a seq_ if someone ever does a - // 'start' and 'end' coalescing region without doing an operation inbetween. + // 'start' and 'end' coalescing region without doing an operation in between. coalescedDevice_.set_index(-1); coalescedComm_ = nullptr; @@ -3357,7 +3460,7 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { } TORCH_CHECK( coalescedDevice_.index() >= 0, - "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); + "Something went wrong. Did you call end_coalescing before start_coalescing?"); // `coalescedComm_` should have same set of comms across collectives auto comm = coalescedComm_; @@ -3513,7 +3616,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( device, rank_, opType, false, profilingTitle, inputs, outputs, enqueue); if (coalescing_state_) { // When coalescing, we record events per op that lack timing/state - // information becuase there is no 'work' associated with them, and then + // information because there is no 'work' associated with them, and then // later in endCoalescing we record a 'coalesced' Work which has // timing/state updates via watchdog thread, but lacks op metadata such as // input/output sizes and profilingTitle per-op in the group. @@ -3676,7 +3779,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // collective so there is no flight record and we increment seqCollective_ and // op_id_ together. Compare this to startCoalescing/endCoalescing flow where // we increment either seqP2P_ or seqCollective_ once per group and increment - // op_id_ once per indvidual operation within the group + // op_id_ once per individual operation within the group op_id_++; const auto key = getKeyFromDevice(device); @@ -3848,23 +3951,72 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( at::cuda::OptionalCUDAGuard gpuGuard(device); std::string key; - int p2pRank = 0, p2pTargetRank = 0; - bool isSendRecvSelf = false; + int p2pRank = -1, p2pTargetRank = -1; + bool isSendRecvSelf = rank_ == peer; // For batch_isend_irecv, ncclGroupStart() would be called upfront bool batchP2P = ncclActiveGroupCounter_ > 0; - if (batchP2P) { - // For batch P2P, we need to treat it like a collective when selecting - // communicator, because other ranks can call into this batch other than my - // rank and my peer + + std::shared_ptr ncclComm = nullptr; + if (this->eagerInit_) { + /* In eagerInit mode, reuse the parent comm. Do not lazily create + * p2p communicators. */ + if (!batchP2P && showSerializationWarning_) { + TORCH_WARN_ONCE(c10::str( + logPrefix(), + "An unbatched P2P op (send/recv) was called on this ProcessGroup with size ", + groupRanks().size(), + ". In eager initialization mode, unbatched P2P ops are treated as ", + "independent collective ops, and are thus serialized with ", + "all other ops on this ProcessGroup, including other P2P ", + "ops. To avoid serialization, either create additional ", + "independent ProcessGroups for the P2P ops or use batched ", + "P2P ops. You can squash this warning by setting the environment variable ", + "TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false.")); + } + key = getKeyFromDevice(device); p2pRank = rank_; p2pTargetRank = peer; + ncclComm = getNCCLComm(key); + + TORCH_INTERNAL_ASSERT( + ncclComm != nullptr, + "Parent communicator missing in eager initialization mode."); + + if (!coalescing_state_) { + // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be + // bumped in `startCoalescing`. + seqP2P_++; + } + } else if (batchP2P) { + // TODO(whc) - unclear why we special-case batchP2P to avoid this path, but + // I preserved this existing special case. + key = getKeyFromDevice(device); + p2pRank = rank_; + p2pTargetRank = peer; + ncclComm = getNCCLComm(key); } else { - // For single P2P, preserve the old two-rank behavior (to avoid perf diff) + // We create special 2-rank communicators for each pair of + // send/recv ranks. This limitation exists for two reasons: (1) + // we use a single stream per communicator, so if multiple + // unbatched p2p operations are issued on the same communicator, + // they would map to the same stream and thus would be serialized; + // and (2) Nvidia NCCL does not allow multiple p2p operations to + // be issued on the same communicator over different streams. + + TORCH_WARN_ONCE( + "An unbatched P2P op (send/recv) was called on this ", + "ProcessGroup with size ", + groupRanks().size(), + ". In lazy initialization mode, this will result in a new 2-rank", + " NCCL communicator to be created."); + key = getKeySendRecv(rank_, peer); + /* if we are creating a new comm, reset the p2pRank and + * p2pTargetRank to correspond to this new 2-process communicator */ p2pRank = rank_ <= peer ? 0 : 1; - isSendRecvSelf = rank_ == peer; p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; + ncclComm = getNCCLComm(key); if (!coalescing_state_) { // Bump P2P sequence number. @@ -3876,9 +4028,13 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // coalesced or individual op_id_++; - std::shared_ptr ncclComm = getNCCLComm(key); if (ncclComm == nullptr) { - ncclComm = initNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + // ncclComm should never be a nullptr in eager init mode. + // For lazy init mode, isSendRecvSelf is only valid for non-batch + // point-to-point operations. For batch operations, force the + // argument to be false. + ncclComm = + initNCCLComm(key, device, opType, p2pRank, isSendRecvSelf && !batchP2P); } if (coalescing_state_ & CoalActive) { @@ -3911,7 +4067,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( c10::intrusive_ptr work; if (coalescing_state_) { // When coalescing, we record events per op that lack timing/state - // information becuase there is no 'work' associated with them, and then + // information because there is no 'work' associated with them, and then // later in endCoalescing we record a 'coalesced' Work which has // timing/state updates via watchdog thread, but lacks op metadata such as // input/output sizes and profilingTitle per-op in the group. @@ -4292,7 +4448,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( std::make_tuple( static_cast(seqCollective_) + 1, false), // seq + 1 to match collective and assume only one collective - // in coalesed range + // in coalesced range std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -4589,7 +4745,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( // User-facing outputTensors should be held by the user until after // waiting on work_, or the call makes no sense. We do a stashing here // in case user doesn't hold the outputTensors in downstream code, - // which can cause an early recyle by the CachingAllocator, which can + // which can cause an early recycle by the CachingAllocator, which can // lead to segfault or data corruption. if (opts.asyncOp) { work->stashed_for_allocator_safety_->stash(outputTensors_); @@ -4637,7 +4793,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( std::make_tuple( static_cast(seqCollective_) + 1, false), // seq + 1 to match collective and assume only one collective - // in coalesed range + // in coalesced range std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputs, // inputTensors outputs, // outputTensors @@ -4851,7 +5007,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( std::make_tuple( static_cast(seqCollective_) + 1, false), // seq + 1 to match collective and assume only one collective - // in coalesed range + // in coalesced range std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputs, // inputTensors outputs, // outputTensors diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 91ab0e1e17e13e..104357fb1b38c4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -23,8 +23,8 @@ #include #include #include -#include #include +#include #include #include @@ -43,6 +43,11 @@ namespace c10d { static std::vector TORCH_NCCL_BCAST_UNIQUEID = { "TORCH_NCCL_BCAST_UNIQUEID"}; +// Control EagerInit P2P serialization warning +static std::vector + TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING = { + "TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING"}; + // Control whether to always use high priority streams static std::vector TORCH_NCCL_HIGH_PRIORITY = { "TORCH_NCCL_HIGH_PRIORITY"}; @@ -319,7 +324,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { bool isP2P = false, const char* profilingTitle = nullptr, const std::optional>& inputs = std::nullopt, - bool desyncDebug = false, bool enableTiming = false, bool cudaEventCacheEnabled = false, DebugLevel distDebugLevel = DebugLevel::Off); @@ -386,6 +390,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Print the traceback of the collective at call time void printTraceback() const; + std::string getTraceback() const; + std::vector result() override; protected: @@ -621,6 +627,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { void setLastWorkListUpdateTime( std::chrono::time_point time); + int getDumpTimeout() const; + // Util function to get the timeout error message std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg); @@ -676,6 +684,80 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; }; + // Class that runs as a side thread to check whether the NCCL collective + // is timed out or errors on the cached NCCL communicators. + class Watchdog { + public: + Watchdog(ProcessGroupNCCL* pg); + virtual ~Watchdog() = default; + + // Start the watchdog thread. + void start(); + + // Join the watchdog thread. + void join(); + + // Function that runs as part of a separate thread and checks for errors on + // NCCL communicators. We need a separate thread to check for NCCL errors + // since we can't rely on the user calling certain methods like wait(), + // isCompleted() etc. to detect and remediate errors. In addition to this, + // we need a mechanism to safely abort and remove NCCL communicators from + // our cache. This can be done cleanly by having a thread for the + // ProcessGroupNCCL class. Attempting to modify the communicator cache from + // the WorkNCCL class might run into issues with object lifetime since the + // ProcessGroupNCCL object might get destroyed before the WorkNCCL object. + void run(); + + // Watchdog's inside loop. + // Takes care of cleaning up completed work, and aborting upon failure or + // timeout. + void runLoop(); + + // Notify the loop inside watchdog. + void notify(); + + void checkAndSetRemoteError(); + + // A helper function to get the src rank of a signal from the Store. This is + // nonblocking function returning -1 if the signal is not available yet. + int getSignalSrcRank( + c10::intrusive_ptr& store, + const std::string& signal); + + uint64_t getHeartbt() const; + + void setDesyncDebug(bool desyncDebug); + + private: + std::thread ncclCommWatchdogThread_; + + // We need to keep a reference to the PG instance so that we can access + // the member functions of the PG instance. We store a raw pointer on + // purpose because the watchdog thread now still lives within the + // lifetime of the PG instance. + ProcessGroupNCCL* pg_; + + // Whether the NCCL watchdog should rethrow CUDA errors. + bool rethrowCUDAErrors_ = false; + + std::exception_ptr watchDogException_ = nullptr; + + // Condition Variable for watchdog thread sleep + std::condition_variable workMetaListCV_; + + // Heartbeat of watchdog thread. + std::atomic_uint64_t heartbeat_{}; + + // Whether or not to propagate detected errors to all ranks in the same PG + // through TCPStore. + bool propagatePgError_; + + // Whether or not to enable timeout root cause analysis. + bool desyncDebug_; + + DesyncDebugger desyncDebugger_; + }; + // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can @@ -947,6 +1029,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Instance of the heartbeat monitor thread. std::unique_ptr heartbeatMonitor_; + // Instance of the watchdog thread. + std::unique_ptr watchdog_; + // Helper that broadcasts nccl unique ID to all ranks through the store void broadcastUniqueNCCLID( ncclUniqueId* ncclID, @@ -1002,6 +1087,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { int globalRankStart_; int globalRankStride_; + private: + bool eagerInit_{false}; + bool showSerializationWarning_{true}; + // Helper that encapsulates work shared across all collective communication // primitives. The callbacks have the following signatures: // @@ -1082,17 +1171,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::exception_ptr checkForNCCLErrorsInternal( std::shared_ptr& ncclComm); - // Function that runs as part of a separate thread and checks for errors on - // NCCL communicators. We need a separate thread to check for NCCL errors - // since we can't rely on the user calling certain methods like wait(), - // isCompleted() etc. to detect and remediate errors. In addition to this, we - // need a mechanism to safely abort and remove NCCL communicators from our - // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL - // class. Attempting to modify the communicator cache from the WorkNCCL class - // might run into issues with object lifetime since the ProcessGroupNCCL - // object might get destroyed before the WorkNCCL object. - void ncclCommWatchdog(); - // Return the CUDA device most likely associated with this backend. // If we aren't bound to a specific device, there is no strict // guarantee that this heuristic is the correct assignment of ranks @@ -1106,11 +1184,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communicators from the cache and clears used device indices. void destroyNCCLComms(const std::string& devNCCLCommMapKey); - // Watchdog's inside loop. - // Takes care of cleaning up completed work, and aborting upon failure or - // timeout. - void watchdogHandler(); - void runHookLoop(); // Generates a prefix that is unique to this process group and rank, for @@ -1146,12 +1219,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { const std::string& signal, int srcRank); - // A helper function to get the src rank of a signal from the Store. This is - // nonblocking function returning -1 if the signal is not available yet. - int getSignalSrcRank( - c10::intrusive_ptr& store, - const std::string& signal); - protected: // Function that directly trigger std::abort so that the whole process // gets terminated. @@ -1166,8 +1233,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { ::c10d::C10dLoggingData& debugLog, bool throwException = false); - void checkAndSetRemoteError(); - // A helper function to guess the device id of the current rank, based on // bounded device or used device. Do not use this function if you already know // the device id to operate on. @@ -1235,7 +1300,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communication, the key will be "1:2" on both processes. Note: this is for // the scenario where there is only 1 GPU per process. When it comes to // multiple GPUs per process, this part may need to redesigned. - // TODO: we probably need a separte map for P2P comms + // TODO: we probably need a separate map for P2P comms std::unordered_map> devNCCLCommMap_; // The NCCL communicators currently in process of being initialized. @@ -1245,21 +1310,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Mutex to guard maps like devNCCLCommMap_. std::mutex mutex_; - // Heartbeat of watchdog thread. - std::atomic_uint64_t heartbeat_{}; - - // timeout for the dump to finish. - int waitTimeoutDumpInMilSec_; - // Size of ring buffer where we store NCCL Traces for debugging. int traceBufferSize_; // We gate the cudaEventCache so that we can roll it out gradually. std::atomic cudaEventCacheEnabled_{}; - // Watchdog thread which looks for errors on the cached NCCL communicators. - std::thread ncclCommWatchdogThread_; - std::thread onCompletionHookThread_; // Whether or not we should terminate the watchdog and workCleanup threads. @@ -1269,7 +1325,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::atomic hasPendingHooks_{}; // This is the signal from watchdog threads to indicate whether the monitor - // thread should dump. Making it static so that it is accessiable from all the + // thread should dump. Making it static so that it is accessible from all the // PGs. With this flag, monitor thread would dump debug info under any one of // the three conditions: // @@ -1286,9 +1342,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { bool writeDebugInfo_ = false; - // Condition Variable for watchdog thread sleep - std::condition_variable workMetaListCV_; - // Vector to store WorkNCCL pointers std::list workMetaList_; @@ -1349,14 +1402,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::mutex errorMutex_; - // Whether or not to enable timeout root cause analysis. - bool desyncDebug_; - DesyncDebugger desyncDebugger_; - - // Whether or not to propagate detected errors to all ranks in the same PG - // through TCPStore. - bool propagatePgError_; - // Whether or not to sleep after an exception is thrown in the watchdog. bool sleepAfterException_{}; @@ -1375,9 +1420,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set bool avoidRecordStreams_ = false; - // Whether the NCCL watchdog should rethrow CUDA errors. - bool rethrowCUDAErrors_ = false; - // The number of active ncclGroupStart() calls. This counter will be increased // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() // is called. @@ -1395,8 +1437,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the ProcessGroup uint64_t op_id_{0}; - std::exception_ptr watchDogException_ = nullptr; - // The number of ProcessGroupNCCL created on the current rank. size_t local_id_; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 25223fca442c33..52354de93edf42 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -397,7 +397,7 @@ class WriterPayload : public c10::intrusive_ptr_target { void registeredInLoop() { /* This refcount increment must be matched by a reclaim call. - Call this method after sucessfully scheduling this handle with a loop. + Call this method after successfully scheduling this handle with a loop. */ at::raw::intrusive_ptr::incref(this); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 9732f46a8fd3c2..c31085754ba275 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -22,37 +22,6 @@ namespace c10d { -// A struct to hold the latest status of the process group. -struct ProcessGroupStatus { - // the sequential number of the last collective enqueued into workMetaList_ - // This is useful for indentifying a rank that has not join a collective - // initialized to be -1 to indicate no collective has been enqueued - int64_t lastEnqueuedSeq{-1}; - // the sequential number of the last collective started as the kernel - int64_t lastStartedSeq{-1}; - // the sequential number of the last collective completed marked by - // the watchdog thread - // initialized to be -1 to indicate no collective has been completed - int64_t lastCompletedSeq{-1}; - - // the name of the last collective enqueued into workMetaList_ - std::string lastEnqueuedWorkName; - // the name of the last collective started as the kernel - std::string lastStartedWorkName; - // the name of the last collective completed - std::string lastCompletedWorkName; - - // the sizes of the last work enqueued - size_t lastEnqueuedNumelIn; - size_t lastEnqueuedNumelOut; - // the sizes of the last work completed - size_t lastCompletedNumelIn; - size_t lastCompletedNumelOut; - // the sizes of the last work started - size_t lastStartedNumelIn; - size_t lastStartedNumelOut; -}; - inline std::string getTraceStartKey(const std::string& pgName, int rank) { return fmt::format(FMT_COMPILE("{}_{}_trace_start"), pgName, rank); } diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 411b23a783cce5..03bd6ef3cafd85 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -573,9 +573,9 @@ using SizeType = uint64_t; // (https://stackoverflow.com/a/20295079), and thus `errno` should really only // be inspected if an error occurred. // -// `success_cond` is an expression used to check if an error has happend. So for -// `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output -// is stored in variable `__output` and may be used in `success_cond`. +// `success_cond` is an expression used to check if an error has happened. So +// for `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function +// output is stored in variable `__output` and may be used in `success_cond`. #ifdef _WIN32 #define SYSCHECK(expr, success_cond) \ while (true) { \ diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index 5fd6c6c737885a..e9e785a9c643de 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -118,7 +118,7 @@ class TORCH_API Work : public torch::CustomClassHolder { // Get a Future object that would be marked as either success or failure // This API can be used by the user to track the completion of the work - // and hanlde the exception if any. + // and handle the exception if any. virtual c10::intrusive_ptr getFutureResult(); virtual float getDuration() const; diff --git a/torch/csrc/distributed/c10d/comm.hpp b/torch/csrc/distributed/c10d/comm.hpp index 6f9203e214348f..599c1709c4df55 100644 --- a/torch/csrc/distributed/c10d/comm.hpp +++ b/torch/csrc/distributed/c10d/comm.hpp @@ -67,7 +67,8 @@ class TORCH_API GradBucket { return parameters_; } - // Returns whther this bucket is the last bucket to allreduce in an iteration. + // Returns whether this bucket is the last bucket to allreduce in an + // iteration. bool isLast() const { return index_ == bucket_count_ - 1; } diff --git a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu index c07dcd46a0151e..3b7effb3a7d6a1 100644 --- a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu +++ b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu @@ -5,7 +5,7 @@ #include #include -// Two warninngs in Cutlass included header files +// Two warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") @@ -157,13 +157,13 @@ at::Tensor async_input_mm_impl( }; TORCH_CHECK( - a_chunk_signals.sizes().size() == 1, + a_chunk_signals.dim() == 1, "async_input_mm: `a_chunk_signals` must be a 1D tensor."); size_t num_chunks_M = a_chunk_signals.numel(); TORCH_CHECK( M % num_chunks_M == 0, - "async_input_mm: `a.shape(0)` must be an interger multiple of `a_chunk_signals.numel()`"); + "async_input_mm: `a.shape(0)` must be an integer multiple of `a_chunk_signals.numel()`"); size_t chunk_size_M = M / num_chunks_M; size_t tile_size_M = cute::get<0>(TileShape_MNK{}); @@ -248,7 +248,7 @@ at::Tensor async_input_mm_out( }); #else TORCH_CHECK( - false, "async_input_mm is not currenlty supported on your device"); + false, "async_input_mm is not currently supported on your device"); #endif return out; } diff --git a/torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh b/torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh index 3c8ef2a052a0e6..0610a862f1589c 100644 --- a/torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh +++ b/torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh @@ -3,7 +3,7 @@ * that supports consuming asynchronous input. This tile scheduler introduces the following arguments: * * - tiles_per_chunk_m – Specifies the size of an M chunk. Chunks are the granularity at which the - * asynchronous input becomes ready. It must be an interger multiple of the size of an M tile. + * asynchronous input becomes ready. It must be an integer multiple of the size of an M tile. * * - chunk_signals – chunk_signals[i] == 1 indicates that chunk i is ready. Before returning a work * tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready. @@ -327,7 +327,7 @@ public: wait_signal(scheduler_params.chunk_signals + chunk_idx); } - // An arbirary, non-default id + // An arbitrary, non-default id constexpr int barrier_id = 8; arch::NamedBarrier barrier(NumThreadsPerWarp, barrier_id); barrier.arrive_and_wait(); diff --git a/torch/csrc/distributed/c10d/cuda/utils.cpp b/torch/csrc/distributed/c10d/cuda/utils.cpp index 0072fab983f69b..44d5242e1401da 100644 --- a/torch/csrc/distributed/c10d/cuda/utils.cpp +++ b/torch/csrc/distributed/c10d/cuda/utils.cpp @@ -22,7 +22,7 @@ bool deviceSupportsMulticast(int device_idx) { // - Device support: Determined by querying // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime. auto driver_api = c10::cuda::DriverAPI::get(); - int multicast_supported; + int multicast_supported = 0; C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_( &multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 9c43e9a2697362..788759934a9053 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -31,7 +32,7 @@ #ifdef USE_C10D_NCCL #include #include -#include +#include #endif #ifdef USE_C10D_MPI @@ -44,9 +45,10 @@ #include #include -#include #include -#include +#include +#include +#include #include #include @@ -191,6 +193,10 @@ template using intrusive_ptr_no_gil_destructor_class_ = py::class_>; +template +using intrusive_ptr_no_gil_destructor_trampoline_class_ = + py::class_, Trampoline>; + // PythonStore is a pybind11 trampoline class to allow a Python // class to inherit from c10d.Store and implement its interface. class PythonStore : public ::c10d::Store { @@ -414,7 +420,7 @@ static PyObject* reduceopmeta___instancecheck__( // NOLINTNEXTLINE(*c-arrays) static PyMethodDef reduceopmeta_methods[] = { {"__instancecheck__", - (PyCFunction)reduceopmeta___instancecheck__, + reduceopmeta___instancecheck__, METH_O, "Custom `__instancecheck__` for ReduceOp"}, {nullptr, nullptr}}; @@ -999,6 +1005,19 @@ This class does not support ``__members__`` property.)"); return ::c10d::unregister_all_process_groups(); }); +#ifdef USE_NVSHMEM + // Initializes the device state in CUmodule so that it’s able to perform + // NVSHMEM operations. + module.def( + "_nvshmemx_cumodule_init", + ::c10d::nvshmem_extension::nvshmemx_cumodule_init, + py::arg("module")); + + // Check if NVSHMEM is available on current system. + module.def( + "_is_nvshmem_available", ::c10d::nvshmem_extension::is_nvshmem_available); +#endif + py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions") .def(py::init<>()) .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank) @@ -1107,6 +1126,8 @@ This class does not support ``__members__`` property.)"); .def_static( "has_multicast_support", &::c10d::symmetric_memory::has_multicast_support) + .def_static("set_backend", &::c10d::symmetric_memory::set_backend) + .def_static("get_backend", &::c10d::symmetric_memory::get_backend) .def_property_readonly("rank", &SymmetricMemory::get_rank) .def_property_readonly("world_size", &SymmetricMemory::get_world_size) .def_property_readonly( @@ -2010,10 +2031,8 @@ communication mechanism. py::arg("world_size")); auto processGroup = - py::class_< - ::c10d::ProcessGroup, - c10::intrusive_ptr<::c10d::ProcessGroup>, - ::c10d::PyProcessGroup>(module, "ProcessGroup", + intrusive_ptr_no_gil_destructor_trampoline_class_< + ::c10d::ProcessGroup, ::c10d::PyProcessGroup>(module, "ProcessGroup", R"(A ProcessGroup is a communication primitive that allows for collective operations across a group of processes. @@ -2878,6 +2897,27 @@ The hook must have the following signature: "_end_coalescing", &::c10d::Backend::endCoalescing, py::call_guard()) + .def( + "supports_tensor_alloc", + [](::c10d::Backend& self, c10::Device device) { + return self.supportsTensorAlloc(device.index()); + }, + py::arg("device"), + py::call_guard()) + .def( + "allocate_tensor", + [](::c10d::Backend& self, + long size, + c10::ScalarType dtype, + c10::Device device) { + return self.allocateTensor( + size, at::TensorOptions().dtype(dtype).device(device)); + }, + py::arg("size"), + py::kw_only(), + py::arg("dtype"), + py::arg("device"), + py::call_guard()) .def_property_readonly( "mem_allocator", &::c10d::Backend::getMemAllocator); @@ -2921,7 +2961,12 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). processGroupGloo, "_Options", backendOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) - .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); + .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads) + .def_readwrite( + "global_ranks_in_group", + &::c10d::ProcessGroupGloo::Options::global_ranks_in_group) + .def_readwrite( + "group_name", &::c10d::ProcessGroupGloo::Options::group_name); processGroupGloo .def_static( @@ -3174,7 +3219,15 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). self->setEnableNanCheck(enable_nan_check); }, py::arg("enable_nan_check"), - py::call_guard()); + py::call_guard()) + .def_static( + "get_build_nccl_version", + [] { + return std::make_tuple(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH); + }) + .def_static("get_runtime_nccl_version", [] { + return ::c10d::getNcclVersionTuple(); + }); module.def( "_get_intra_node_comm_usage_counter", @@ -3189,7 +3242,10 @@ ncclConfig_t data type for configuring NCCL communicators. See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t for details. )") - .def(py::init<>()) + .def(py::init([]() { + ncclConfig_t defaultCfg = NCCL_CONFIG_INITIALIZER; + return std::make_unique(defaultCfg); + })) .def_readwrite("blocking", &ncclConfig_t::blocking) .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize) .def_readwrite("min_ctas", &ncclConfig_t::minCTAs) @@ -3199,6 +3255,15 @@ for details. #endif #ifdef NCCL_HAS_QOS .def_readwrite("traffic_class", &ncclConfig_t::trafficClass) +#endif +#ifdef NCCL_HAS_COLLNET + .def_readwrite("collnet_enable", &ncclConfig_t::collnetEnable) +#endif +#ifdef NCCL_HAS_CTA_POLICY + .def_readwrite("cta_policy", &ncclConfig_t::CTAPolicy) +#endif +#ifdef NCCL_HAS_NVLS_CTAS + .def_readwrite("nvls_ctas", &ncclConfig_t::nvlsCTAs) #endif .def_property( "net_name", @@ -3208,7 +3273,16 @@ for details. // shouldn't leak because of allocation in strdup. [](ncclConfig_t& self, const char* tmp) { self.netName = strdup(tmp); - }); + }) + .def( + "__copy__", + [](const ncclConfig_t& self) { return ncclConfig_t(self); }) + .def( + "__deepcopy__", + [](const ncclConfig_t& self, const py::dict& memo) { + return ncclConfig_t(self); + }, + py::arg("memo")); #endif // NCCL_HAS_CONFIG intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( @@ -3225,7 +3299,7 @@ ProcessGroup options for the NCCL backend Default is False. Attributes: - config (NCCLConfig): configures NCCL communicators (only avaiable for + config (NCCLConfig): configures NCCL communicators (only available for builds using NCCL 2.17+). This can be used to improve communication-computation overlap for NCCL kernels by tuning available parameters in the config. See @@ -3259,7 +3333,20 @@ Example:: "global_ranks_in_group", &::c10d::ProcessGroupNCCL::Options::global_ranks_in_group) .def_readwrite( - "group_name", &::c10d::ProcessGroupNCCL::Options::group_name); + "group_name", &::c10d::ProcessGroupNCCL::Options::group_name) + .def( + "__copy__", + [](const ::c10d::ProcessGroupNCCL::Options& self) { + return ::c10d::ProcessGroupNCCL::Options(self); + }) + .def( + "__deepcopy__", + [](const ::c10d::ProcessGroupNCCL::Options& self, + const py::dict& memo) { + return ::c10d::ProcessGroupNCCL::Options(self); + }, + py::arg("memo")); + #endif #ifdef USE_C10D_MPI @@ -3285,7 +3372,9 @@ Example:: .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size) { + int size, + c10::intrusive_ptr<::c10d::ProcessGroupXCCL::Options> + options) { // gil_scoped_release is not safe as a call_guard in init. // https://github.com/pybind/pybind11/issues/5473 py::gil_scoped_release nogil{}; @@ -3295,7 +3384,38 @@ Example:: }), py::arg("store"), py::arg("rank"), - py::arg("size")); + py::arg("size"), + py::arg("options"), + R"(Create a new ProcessGroupXCCL instance.)") + .def( + "comm_split_count", + &::c10d::ProcessGroupXCCL::getCommSplitCounter) + .def_property( + "bound_device_id", + &::c10d::ProcessGroupXCCL::getBoundDeviceId, + &::c10d::ProcessGroupXCCL::setBoundDeviceId, + R"(Return the bound device id.)") + .def( + "perform_nocolor_split", + &::c10d::ProcessGroupXCCL::performNocolorSplit); + intrusive_ptr_class_<::c10d::ProcessGroupXCCL::Options>( + processGroupXCCL, + "Options", + backendOptions) + .def(py::init(), py::arg("is_high_priority_stream") = false) + .def_readwrite("config", &::c10d::ProcessGroupXCCL::Options::config) + .def_readwrite( + "is_high_priority_stream", + &::c10d::ProcessGroupXCCL::Options::is_high_priority_stream) + .def_readwrite( + "split_from", &::c10d::ProcessGroupXCCL::Options::split_from) + .def_readwrite( + "split_color", &::c10d::ProcessGroupXCCL::Options::split_color) + .def_readwrite( + "global_ranks_in_group", + &::c10d::ProcessGroupXCCL::Options::global_ranks_in_group) + .def_readwrite( + "group_name", &::c10d::ProcessGroupXCCL::Options::group_name); #endif #ifdef USE_C10D_UCC @@ -3477,7 +3597,7 @@ such as `dist.all_reduce(tensor, async_op=True)`. Example:: Below is an example of a simple allreduce DDP communication hook that uses - ``get_future` API to retrieve a Future associated with the completion of + ``get_future`` API to retrieve a Future associated with the completion of ``allreduce``. >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket): -> torch.futures.Future @@ -3789,6 +3909,46 @@ such as `dist.all_reduce(tensor, async_op=True)`. )"); #endif + module.def( + "_dump_fr_trace_json", + [](std::optional includeCollectives, + std::optional onlyActive) { + return py::bytes(::c10d::dump_fr_trace_json( + includeCollectives.value_or(true), onlyActive.value_or(false))); + }, + py::arg("includeCollectives") = std::optional(), + py::arg("onlyActive") = std::optional(), + R"( + Arguments: + includeCollectives(bool, optional): Whether to include collective work traces. Default is True. + onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. + Returns: + Stringified json work traces. + Default settings return everything. + )"); + module.def( + "_dump_fr_trace", + [](std::optional includeCollectives, + std::optional includeStackTraces, + std::optional onlyActive) { + return py::bytes(::c10d::dump_fr_trace( + includeCollectives.value_or(true), + includeStackTraces.value_or(true), + onlyActive.value_or(false))); + }, + py::arg("includeCollectives") = std::optional(), + py::arg("includeStackTraces") = std::optional(), + py::arg("onlyActive") = std::optional(), + R"( + Arguments: + includeCollectives(bool, optional): Whether to include collective work traces. Default is True. + includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True. + onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. + Returns: + Stringified pickle work traces. + Default settings return everything. + )"); + intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( module, "_WorkerServer", R"( )") diff --git a/torch/csrc/distributed/c10d/logger.hpp b/torch/csrc/distributed/c10d/logger.hpp index c1797046a97991..cd562af7473ae3 100644 --- a/torch/csrc/distributed/c10d/logger.hpp +++ b/torch/csrc/distributed/c10d/logger.hpp @@ -7,6 +7,37 @@ namespace c10d { +// A struct to hold the latest status of the process group. +struct ProcessGroupStatus { + // the sequential number of the last collective enqueued into workMetaList_ + // This is useful for identifying a rank that has not join a collective + // initialized to be -1 to indicate no collective has been enqueued + int64_t lastEnqueuedSeq{-1}; + // the sequential number of the last collective started as the kernel + int64_t lastStartedSeq{-1}; + // the sequential number of the last collective completed marked by + // the watchdog thread + // initialized to be -1 to indicate no collective has been completed + int64_t lastCompletedSeq{-1}; + + // the name of the last collective enqueued into workMetaList_ + std::string lastEnqueuedWorkName; + // the name of the last collective started as the kernel + std::string lastStartedWorkName; + // the name of the last collective completed + std::string lastCompletedWorkName; + + // the sizes of the last work enqueued + size_t lastEnqueuedNumelIn; + size_t lastEnqueuedNumelOut; + // the sizes of the last work completed + size_t lastCompletedNumelIn; + size_t lastCompletedNumelOut; + // the sizes of the last work started + size_t lastStartedNumelIn; + size_t lastStartedNumelOut; +}; + class TORCH_API Logger { public: explicit Logger(std::shared_ptr reducer); diff --git a/torch/csrc/distributed/c10d/nvshmem_extension.cu b/torch/csrc/distributed/c10d/nvshmem_extension.cu deleted file mode 100644 index 67c1765ec03379..00000000000000 --- a/torch/csrc/distributed/c10d/nvshmem_extension.cu +++ /dev/null @@ -1,328 +0,0 @@ -#include - -#include - -#include -#include -#include - -#include -// Use torch's cub wrapper instead of CUDA's , see #55292 -#include -#include - -namespace c10d::nvshmem_extension { - -using c10d::symmetric_memory::StoreExchange; -static StoreExchange storeExchange = StoreExchange("nvshmem_ext"); - -#define THREADS_PER_BLOCK 512 - -constexpr int MiB = 1024 * 1024; - -// Bootstrap based on user's setting for NCCL -// Long term, this may be a bit unclean; short term, it improves UX -void maybe_initialize_env_vars() { - auto nccl_socket_if_name = c10::utils::get_env("NCCL_SOCKET_IFNAME"); - auto nccl_hca_list = c10::utils::get_env("NCCL_IB_HCA"); - auto nccl_ib_gid_index = c10::utils::get_env("NCCL_IB_GID_INDEX"); - auto nvshmem_socket_if_name = - c10::utils::get_env("NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME"); - auto nvshmem_hca_list = c10::utils::get_env("NCCL_IB_HCA"); - auto nvshmem_ib_gid_index = c10::utils::get_env("NVSHMEM_IB_GID_INDEX"); - - if (!nvshmem_socket_if_name.has_value() && nccl_socket_if_name.has_value()) { - c10::utils::set_env( - "NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME", nccl_socket_if_name->c_str()); - } - if (!nvshmem_hca_list.has_value() && nccl_hca_list.has_value()) { - c10::utils::set_env("NVSHMEM_ENABLE_NIC_PE_MAPPING", "1"); - c10::utils::set_env("NVSHMEM_HCA_LIST", nccl_hca_list->c_str()); - } - if (!nvshmem_ib_gid_index.has_value() && nccl_ib_gid_index.has_value()) { - c10::utils::set_env("NVSHMEM_IB_GID_INDEX", nccl_ib_gid_index->c_str()); - } -} - -void initialize_nvshmem_with_store( - c10::intrusive_ptr store, - int rank, - int world_size) { - static bool is_initialized = false; - if (is_initialized) { - return; - } - - maybe_initialize_env_vars(); - - nvshmemx_uniqueid_t unique_id; - TORCH_CHECK( - nvshmemx_get_uniqueid(&unique_id) == 0, "nvshmemx_get_uniqueid failed"); - - // Using an existing store_all_gather due to laziness. - // TODO(yifu): should use broadcast - auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id); - - nvshmemx_init_attr_t attr; - nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr); - - TORCH_CHECK( - nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr) == 0, - "nvshmemx_init_attr failed"); - - is_initialized = true; -} - -void* nvshmem_malloc(size_t size) { - return ::nvshmem_malloc(size); -} - -void* nvshmem_ptr(const void* dest, int pe) { - return ::nvshmem_ptr(dest, pe); -} - -std::unordered_map group_name_to_team_; - -nvshmem_team_t group_to_team( - const std::string& group_name, - const std::vector& global_ranks) { - auto it = group_name_to_team_.find(group_name); - if (it != group_name_to_team_.end()) { - return it->second; - } - TORCH_CHECK(global_ranks.size() > 1); - int stride = global_ranks[1] - global_ranks[0]; - for (size_t r = 1; r < global_ranks.size(); ++r) { - TORCH_CHECK(global_ranks[r] - global_ranks[r - 1] == stride); - } - - nvshmem_team_t team; - TORCH_CHECK( - nvshmem_team_split_strided( - NVSHMEM_TEAM_WORLD, - global_ranks[0], - stride, - global_ranks.size(), - nullptr, - 0, - &team) == 0); - group_name_to_team_[group_name] = team; - TORCH_CHECK(team != NVSHMEM_TEAM_INVALID); - return team; -} - -at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) { - auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); - int rank = input_hdl->get_rank(); - int world_size = input_hdl->get_world_size(); - auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); - void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank]; - - auto stream = at::cuda::getCurrentCUDAStream(); - nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream); - return input; -} - -at::Tensor nvshmem_all_to_all( - at::Tensor& input, - at::Tensor& out, - std::string group_name) { - auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); - auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); - int rank = input_hdl->get_rank(); - int world_size = input_hdl->get_world_size(); - auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); - - void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; - void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; - size_t bytes_per_rank = input_hdl->get_buffer_size() / world_size; - - auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); - nvshmemx_alltoallmem_on_stream(team, output_ptr, input_ptr, bytes_per_rank, stream); - return out; -} - -// This is an exclusive prefix sum function that calculates read (or write) offsets for each peer. -__device__ void prefixSum(int64_t *odata, int64_t *idata, int n) { - // Specialize BlockScan for a 1D block of threads, of type int64_t. - // - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high - // throughput which we don't need here). - // - `at_cuda_detail::cub` is torch's cub wrapper, see #55292. - using BlockScanT = at_cuda_detail::cub::BlockScan; - // Allocate shared memory for BlockScan - __shared__ typename BlockScanT::TempStorage temp_storage; - - // TODO: currently it is assumed that the number of PE's is smaller than - // `THREADS_PER_BLOCK` - CUDA_KERNEL_ASSERT(n <= THREADS_PER_BLOCK); - - // Obtain input item for each thread - int tid = threadIdx.x; - int64_t thread_data = (tid < n) ? idata[tid] : 0; - - // Collectively compute the block-wide exclusive prefix sum - BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data); - - // Store the result - if (tid < n) { - odata[tid] = thread_data; - } -} - -// This kernel is used to exchange output splits and source offsets between peers. -// `in_out_splits` is of size (3, npes) and contains: -// - input splits (IN) -// - output splits (OUT) and -// - source offsets (OUT). -__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npes) { - auto input_splits = in_out_splits; - auto output_splits = in_out_splits + npes; - auto source_offsets = in_out_splits + npes * 2; - int tid = threadIdx.x; - - __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; - - // Scan input splits to get the source offsets - prefixSum(peer_offsets, input_splits, npes); - __syncthreads();; - - // Use 1 block to do the exchange - if (tid < npes) { - int peer = tid; - nvshmem_int64_p(source_offsets + mype, peer_offsets[peer], peer); - nvshmem_int64_p(output_splits + mype, input_splits[peer], peer); - } - // This barrier ensures that all remote PEs see the updated values - nvshmemx_barrier_all_block(); -} - -// This kernel is used to do the actual data exchange. -// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. -// `stride` is the stride at dim 0, unit in byte. -__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes) { - auto output_splits = in_out_splits + npes; - auto source_offsets = in_out_splits + npes * 2; - int bid = blockIdx.x; - int tid = threadIdx.x; - int blocks_per_peer = max(gridDim.x / npes, 1); - - // Calculate the output offsets - __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; - prefixSum(peer_offsets, output_splits, npes); - __syncthreads(); - - // Target a different peer based on bid - for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) { - int peer = (mype + i) % npes; - // Total amount from `peer` - auto peer_size = output_splits[peer] * stride; - // Amount to get from `peer` in this block - auto block_size = peer_size / blocks_per_peer; - // Being lazy here, we should handle the residual if the division is not exact - CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size); - // This block's offset in the data from `peer` - auto block_offset = block_size * (bid % blocks_per_peer); - auto source_offset = source_offsets[peer] * stride + block_offset; - auto write_offset = peer_offsets[peer] * stride + block_offset; - nvshmemx_getmem_block( - (char*)recv_data + write_offset, - (char*)send_data + source_offset, - block_size, - peer); - } - // Write out the output offsets (to the scratchpad line) - if (bid == 0 && tid < npes) { - source_offsets[tid] = peer_offsets[tid]; - } -} - -at::Tensor nvshmem_all_to_all_vdev( - at::Tensor& input, - at::Tensor& out, - at::Tensor& in_out_splits, - std::string group_name) { - /* Perform AllToAllv operation using NVSHMEM, with split information provided on device. - * Arguments: - * - `input` is the input tensor - * - `out` is the output tensor - * - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order): - input splits (IN) - output splits (OUT) and - output offsets (OUT). - */ - auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); - auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); - auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); - int rank = input_hdl->get_rank(); - int world_size = input_hdl->get_world_size(); - - void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; - void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; - int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]); - - auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); - - // Exchange output splits and source offsets - // Use collective launch because kernel involves nvshmem barrier - void* args0[] = { - &splits_ptr, - &rank, - &world_size}; - nvshmemx_collective_launch( - (const void*)exchangeSplitAndOffset, - dim3(1), - dim3(THREADS_PER_BLOCK), - args0, - 0, - stream); - - // CTA Tuning - // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8. - // Up to 1 MB -> 1 block - // Up to 2 MB -> 2 blocks - // Up to 4 MB -> 4 blocks - // More -> 8 blocks - // The tuning for `num_blocks` below multiplies these numbers by world_size - // (e.g. 8 -> 8 * 8). If world_size is smaller, we simply shift the blocks - // towards data parallelism. (There may be room for improvement here) - auto input_size = input.numel() * input.element_size(); - int num_blocks = input_size < MiB ? 8 : - (input_size < 2 * MiB ? 16 : - (input_size < 4 * MiB ? 32 : 64)); - - // Inter-node: limit the total the number of blocks to 8 which is able to - // drive 57 GB/s bandwidth in test, enough to drive a 400 Gb/s NIC. - // TODO: better intra vs inter detection, currently it is based on world_size - if (world_size > 8) { - num_blocks = std::min(num_blocks, 8); - } - - // Stride at dim 0 (assuming input is contiguous, TODO) - size_t stride_bytes = input.stride(0) * input.element_size(); - - // All to all data exchange - void* args1[] = { - &input_ptr, - &output_ptr, - &splits_ptr, - &stride_bytes, - &rank, - &world_size}; - nvshmemx_collective_launch( - (const void*)allToAllV, - dim3(num_blocks), - dim3(THREADS_PER_BLOCK), - args1, - 0, - stream); - return out; -} - -} // namespace c10d::nvshmem_extension - - -TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { - m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast); - m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all); - m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev); -} diff --git a/torch/csrc/distributed/c10d/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/nvshmem_extension.cuh deleted file mode 100644 index 5b14354855e754..00000000000000 --- a/torch/csrc/distributed/c10d/nvshmem_extension.cuh +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include - -#include - -namespace c10d::nvshmem_extension { - -void initialize_nvshmem_with_store( - c10::intrusive_ptr store, - int rank, - int world_size); - -void* nvshmem_malloc(size_t size); - -void* nvshmem_ptr(const void* dest, int pe); - -at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name); - -at::Tensor nvshmem_all_to_all( - at::Tensor& input, - at::Tensor& out, - std::string group_name); - -at::Tensor nvshmem_all_to_all_vdev( - at::Tensor& input, - at::Tensor& out, - at::Tensor& in_out_splits, - std::string group_name); - -} // namespace c10d::nvshmem_extension diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp index adf73452bd7b4c..af3bf6b4c65d34 100644 --- a/torch/csrc/distributed/c10d/python_comm_hook.cpp +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -28,7 +28,7 @@ c10::intrusive_ptr PythonCommHook::runHook( try { return py_fut.cast>()->fut; } catch (const py::cast_error& e) { - auto type = py_fut.get_type(); + auto type = py::type::handle_of(py_fut); auto errMsg = c10::str( e.what(), ". DDP communication hook's callback must return a " diff --git a/torch/csrc/distributed/c10d/python_comm_hook.h b/torch/csrc/distributed/c10d/python_comm_hook.h index 48ad7cefae9418..a63f03fbf8c1ec 100644 --- a/torch/csrc/distributed/c10d/python_comm_hook.h +++ b/torch/csrc/distributed/c10d/python_comm_hook.h @@ -15,7 +15,7 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { // The state is passed to the hook in runHook method, and it can be used to // maintain and update any state information during the execution of the hook. // The hook performs user-specified processing and returns a future indicating - // asychronous communication of gradients. + // asynchronous communication of gradients. PythonCommHook(py::object state, py::object hook) : state_(std::move(state)), hook_(std::move(hook)) {} diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 9b2cc9f5eedfbb..1e9e7006a663e5 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1245,7 +1245,7 @@ void Reducer::initialize_buckets( // patterns when copy_ing grad data in and out of its bucket view. // However, numerics remain correct, because the bucket view is the same // on either end of the raw allreduce. bucket_view_in.copy(grad) - // tranposes + // transposes // (+ densifies) to the bucket view's layout, the data is allreduced, // then grad.copy_(bucket_view_out) transposes it back to grad's layout. // diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index 43536bd515df11..6707975d38ac1e 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -564,7 +564,7 @@ class TORCH_API Reducer { // Retrieves parameter corresponding to the given VariableIndex. at::Tensor& get_param_from_index(size_t index); // Python reducer keeps C++ reducer initialized. To remove this flag, - // we need to refactor the DDP wrapper's initilization. + // we need to refactor the DDP wrapper's initialization. bool use_python_reducer_; // Cached bucket index to model parameter mapping. Populated after buckets diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h similarity index 92% rename from torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h index ef2a712db344e9..f8e958b7f9fa6e 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h @@ -50,7 +50,7 @@ __device__ __forceinline__ void trap() { #if defined(USE_ROCM) // abort() calls trap() under the covers. However, on ROCm, the trap is // handled differently inside hip runtime. It collects a gpu core dump and - // causes linux kernerl to create a core dump of the host application. + // causes linux kernel to create a core dump of the host application. abort(); #else __trap(); @@ -260,6 +260,31 @@ __device__ __inline__ T add_bf16x2(T a, T b) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) CUDA_KERNEL_ASSERT(false); return T{}; +#elif defined(USE_ROCM) + union bf2f { + float f; + __hip_bfloat16 bf[2]; + } _bf2f_a = {.f = 0}, _bf2f_b = {.f = 0}; + + //__hip_bfloat162 is a struct wtih two __hip_bfloat16 elements called x and y + // This typecasts input a and b as bfloat16 and maps to low bits of a float + // and does the addition in float + _bf2f_a.bf[1] = reinterpret_cast<__hip_bfloat162*>(&a)->x; + _bf2f_b.bf[1] = reinterpret_cast<__hip_bfloat162*>(&b)->x; + union f2bf { + float f; + __hip_bfloat16 bf[2]; + } _f2bf_res0, _f2bf_res1; + _f2bf_res0.f = _bf2f_a.f + _bf2f_b.f; + + // Same thing for y elements of __hip_bfloat162 + _bf2f_a.bf[1] = reinterpret_cast<__hip_bfloat162*>(&a)->y; + _bf2f_b.bf[1] = reinterpret_cast<__hip_bfloat162*>(&b)->y; + _f2bf_res1.f = _bf2f_a.f + _bf2f_b.f; + + // Put the two results together + __hip_bfloat162 rtn(_f2bf_res0.bf[1], _f2bf_res1.bf[1]); + return *reinterpret_cast(&rtn); #else auto res = __hadd2( *reinterpret_cast<__nv_bfloat162*>(&a), diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu similarity index 96% rename from torch/csrc/distributed/c10d/CUDASymmetricMemory.cu rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index 87a14a5f26d748..3b5f080d4c406f 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -1,6 +1,6 @@ -#include -#include -#include +#include +#include +#include #include #include @@ -255,7 +255,7 @@ static __global__ void barrier_kernel( void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); - barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(signal_pads_dev_), channel, rank_, @@ -293,7 +293,7 @@ void CUDASymmetricMemory::put_signal( size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); - put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(signal_pads_dev_), dst_rank, channel, @@ -337,7 +337,7 @@ void CUDASymmetricMemory::wait_signal( size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); - wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(signal_pads_dev_), src_rank, channel, @@ -722,6 +722,14 @@ bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) { return device_has_multicast_support(device_idx); } +c10::DeviceType CUDASymmetricMemoryAllocator::supported_device_type() { + return c10::DeviceType::CUDA; +} + +std::string CUDASymmetricMemoryAllocator::name() { + return "CUDA"; +} + c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { std::shared_lock lock(mutex_); auto it = ptr_to_block_.find(ptr); @@ -733,12 +741,17 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { struct RegisterCUDASymmetricMemoryAllocator { RegisterCUDASymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); // Query backend used for CUDA tensor // "CUDA" backend stands for this implementation if (getSymmMemBackendCUDA() == "CUDA") { + // Direct set (static registration) register_allocator( c10::DeviceType::CUDA, - c10::make_intrusive()); + allocator); + } else { + // Register availability in case `set_backend` is called dynamically + register_availability("CUDA", allocator); } } }; diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp similarity index 93% rename from torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index d1a85e3a236d00..04443882993880 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -1,9 +1,9 @@ #pragma once #include -#include #include -#include +#include +#include namespace c10d::symmetric_memory { @@ -115,6 +115,8 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { void* ptr, const std::optional& group_name) override; bool has_multicast_support(int device_idx) override; + c10::DeviceType supported_device_type() override; + std::string name() override; private: c10::intrusive_ptr find_block(void* ptr); diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu similarity index 98% rename from torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu index 744276772bffb1..c4f38e468192db 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu @@ -15,10 +15,9 @@ #include #endif - -#include -#include #include +#include +#include #if defined(USE_ROCM) || (defined(CUDART_VERSION) && CUDART_VERSION >= 12030) @@ -56,7 +55,7 @@ INT_SWITCH_CASE(k_alignment, 8, __VA_ARGS__); \ INT_SWITCH_CASE(k_alignment, 4, __VA_ARGS__); \ default: { \ - TORCH_CHECK(false, "Not implemented for aligment=", alignment); \ + TORCH_CHECK(false, "Not implemented for alignment=", alignment); \ } \ } @@ -115,7 +114,7 @@ void init_elementwise_launch_config( num_blocks = 1; num_threads = at::round_up( at::ceil_div(numel_per_split, numel_per_thread), - static_cast(C10_WARP_SIZE)); + static_cast(at::cuda::warp_size())); } else { num_blocks = std::min( at::ceil_div(numel_per_split, max_num_threads * numel_per_thread), @@ -403,7 +402,6 @@ at::Tensor multimem_all_gather_out( // count to 512 to prevent/alleviate register spill. constexpr size_t one_shot_all_reduce_max_num_blocks = 24; constexpr size_t one_shot_all_reduce_max_num_threads = 512; - template static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ void one_shot_all_reduce_kernel( @@ -562,9 +560,13 @@ at::Tensor one_shot_all_reduce_copy( input, local_input, reduce_op, group_name, out); } +#if defined(USE_ROCM) +constexpr size_t two_shot_all_reduce_max_num_blocks = 64; +constexpr size_t two_shot_all_reduce_max_num_threads = 128; +#else constexpr size_t two_shot_all_reduce_max_num_blocks = 24; constexpr size_t two_shot_all_reduce_max_num_threads = 1024; - +#endif template < typename T, int alignment, @@ -628,11 +630,16 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ for (size_t step = 0; step < k_world_size; ++step) { size_t remote_rank = (rank + step) % k_world_size; size_t remote_start = numel_per_rank * remote_rank; +#if defined (USE_ROCM) + tmp[step] = at::native::memory::ld_vec( + input_ptrs[remote_rank] + input_offset + min(remote_start + i, numel-1)); +#else if (remote_start + i >= numel) { continue; } tmp[step] = at::native::memory::ld_vec( input_ptrs[remote_rank] + input_offset + remote_start + i); +#endif } #pragma unroll k_world_size for (size_t step = 0; step < k_world_size; ++step) { diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp similarity index 100% rename from torch/csrc/distributed/c10d/CUDASymmetricMemoryTypes.hpp rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp similarity index 92% rename from torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp index ef95e47a27c513..225304faca652f 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp @@ -12,9 +12,9 @@ #include #endif -#include #include #include +#include namespace c10d::symmetric_memory { @@ -32,15 +32,23 @@ bool allow_overlapping_devices() { // Query environment variable to get the backend used for CUDA Symmetric Memory. std::string getSymmMemBackendCUDA() { + // TORCH_SYMMMEM environment variable can be used to indicate the preferred + // backend. static auto val = c10::utils::get_env("TORCH_SYMMMEM"); - if (!val.has_value()) { - // In-house implementation: `CUDASymmetricMemory` - return "CUDA"; - } else { - // Other backends: - // - "NVSHMEM": `NVSHMEMSymmetricMemory` + if (val.has_value()) { + TORCH_CHECK( + val.value() == "CUDA" || val.value() == "NVSHMEM" || + val.value() == "NCCL", + "TORCH_SYMMMEM environment variable must be one of 'CUDA', 'NVSHMEM', 'NCCL'.") return val.value(); } + // If TORCH_SYMMMEM is not set, check if NVSHMEM is available (for broader + // support). + // TODO: uncomment this once all single-node tests work with NVSHMEM + // if (is_nvshmem_available()) { + // return "NVSHMEM"; + // } + return "CUDA"; } IpcChannel::IpcChannel() @@ -148,7 +156,7 @@ int IpcChannel::recv_fd() { .msg_control = cbuf, .msg_controllen = sizeof(cbuf)}; - // Recieve message on socket_ + // Receive message on socket_ TORCH_CHECK( recvmsg(socket_, &msg, 0) > 0, "Failed to receive fd: ", diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp similarity index 95% rename from torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp rename to torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp index 9fac3c9f69832c..77dd36b778aeaa 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp @@ -1,8 +1,8 @@ #pragma once -#include #include -#include +#include +#include namespace c10d { namespace symmetric_memory { diff --git a/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp b/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp similarity index 95% rename from torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp rename to torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp index 1ed72a9aa116a2..b5efcfeb3006f9 100644 --- a/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp @@ -1,8 +1,9 @@ #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include +#include #include #include +#include #include #include @@ -12,18 +13,13 @@ namespace { constexpr int max_nvlinks = 64; std::string get_bus_id(int device_idx) { - // NOLINTNEXTLINE(*array*) - char bus_id[80]; cudaDeviceProp prop{}; C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx)); - snprintf( - bus_id, - sizeof(bus_id), + return fmt::sprintf( NVML_DEVICE_PCI_BUS_ID_FMT, prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); - return std::string(bus_id); } struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { @@ -39,6 +35,7 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { // Obtain the bus_id for all visible devices std::unordered_map bus_id_to_device_idx; + bus_id_to_device_idx.reserve(num_devices); std::vector bus_ids; bus_ids.reserve(num_devices); for (int i = 0; i < num_devices; ++i) { @@ -47,7 +44,7 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { bus_ids.push_back(std::move(bus_id)); } - static const char* warning_msg = + static constexpr const char* warning_msg = "PyTorch features that use NVLinkDetector may assume no NVLink presence."; auto driver_api = c10::cuda::DriverAPI::get(); diff --git a/torch/csrc/distributed/c10d/DMAConnectivity.cpp b/torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.cpp similarity index 97% rename from torch/csrc/distributed/c10d/DMAConnectivity.cpp rename to torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.cpp index a2bab1247a51d7..0d54c389ddee64 100644 --- a/torch/csrc/distributed/c10d/DMAConnectivity.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.cpp @@ -1,4 +1,4 @@ -#include +#include #include namespace { diff --git a/torch/csrc/distributed/c10d/DMAConnectivity.hpp b/torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp similarity index 100% rename from torch/csrc/distributed/c10d/DMAConnectivity.hpp rename to torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu new file mode 100644 index 00000000000000..4f69c497438648 --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -0,0 +1,365 @@ +#ifdef USE_C10D_NCCL +#include +#include + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 1) +#define NCCL_HAS_SYMMEM_SUPPORT +#endif + +#ifdef NCCL_HAS_SYMMEM_SUPPORT +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +/* Start of NCCLAllocation implementation */ + +static StoreExchange storeExchange = StoreExchange("NCCLAllocation"); + +struct NCCLAllocation { + void* ptr; + size_t buffer_size; + int device_idx; + + NCCLAllocation(void* ptr, size_t buffer_size, int device_idx) + : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} +}; + +class NCCLSymmetricMemory : public SymmetricMemory { + public: + NCCLSymmetricMemory( + std::shared_ptr allocation, + const std::string& group_name, + ncclWindow_t handle, + ncclWindow_t signal_handle) + : allocation_(allocation), + buffer_size_(allocation->buffer_size), + device_idx_(allocation->device_idx), + group_name_(group_name), + handle_(handle), + signal_handle_(signal_handle) { + c10::cuda::CUDAGuard guard(device_idx_); + + // We need some API like nvshmem_extension::nvshmem_ptr() + // put API to get the reference of remote memory. + // WIP + } + + ~NCCLSymmetricMemory() override = default; + + std::vector get_buffer_ptrs() override { + return buffers_; + } + + std::vector get_signal_pad_ptrs() override { + return signal_pads_; + } + + void** get_buffer_ptrs_dev() override { + return buffers_dev_; + } + + void** get_signal_pad_ptrs_dev() override { + return signal_pads_dev_; + } + + size_t get_buffer_size() override { + return buffer_size_; + } + + size_t get_signal_pad_size() override { + return signal_pad_size; + }; + + bool has_multicast_support() override { + // TODO + return false; + } + + void* get_multicast_ptr() override { + // TODO + return nullptr; + } + + // TODO: This is up for change. + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + // TODO: deduplicate + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "NCCLSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); + } + + // TODO: This is up for change. + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override { + // TODO: deduplicate + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. + const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (!sizes.empty()) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "NCCLSymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); + } + + void barrier(int channel, size_t timeout_ms) override { + // TODO + } + + void put_signal(int dst_rank, int channel, size_t timeout_ms) override { + // TODO + } + + void wait_signal(int src_rank, int channel, size_t timeout_ms) override { + // TODO + } + + int get_rank() override { + return rank_; + } + + int get_world_size() override { + return world_size_; + } + + virtual std::vector& get_rank_to_global_rank() override { + return rank_to_global_rank_; + }; + + int* get_rank_to_global_rank_dev() override { + return rank_to_global_rank_dev_; + }; + + private: + std::shared_ptr allocation_; + size_t buffer_size_; + // TODO: We need to finalize what booking variables we need for nccl backend. + std::vector buffers_; + std::vector signal_pads_; + int device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::string group_name_; + ncclWindow_t handle_; + ncclWindow_t signal_handle_; + + std::vector rank_to_global_rank_; + int* rank_to_global_rank_dev_; +}; + +class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override { + TORCH_CHECK( + group_name == std::nullopt, + "NCCLSymmetricMemoryAllocator::alloc " + "must not be called with a group_name"); + + auto group_info = get_group_info("0"); + auto store = group_info.store; + c10::cuda::CUDAGuard guard(device_idx); + // TODO: we might need to use a roundup or mempool for mem allocation. + void* ptr; + C10D_NCCL_CHECK(ncclMemAlloc(&ptr, size), "ncclMemAlloc"); + auto allocation = + std::make_shared(ptr, size, device_idx); + // TODO: thread safety + allocations_.emplace(ptr, allocation); + return ptr; + } + + void free(void* ptr) override { + // TODO: thread safety + ptr_to_symm_mem_.erase(ptr); + allocations_.erase(ptr); + }; + + size_t get_alloc_size(void* ptr) override { + auto it = ptr_to_symm_mem_.find(ptr); + if (it == ptr_to_symm_mem_.end()) { + TORCH_CHECK( + false, ptr, " is not allocated with NCCLSymmetricMemoryAllocator"); + } + return it->second->get_buffer_size(); + }; + + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override { + TORCH_CHECK(group_name.has_value(), "group_name must be provided"); + { + auto it = symm_mems_.find(std::make_tuple(ptr, *group_name)); + if (it != symm_mems_.end()) { + return it->second; + } + } + auto it = allocations_.find(ptr); + TORCH_CHECK(it != allocations_.end(), "memory needs to be first allocated before calling rendezvous."); + + + auto group = resolve_process_group(group_name.value()); + auto alloc = it->second; + c10::cuda::CUDAGuard guard(alloc->device_idx); + ncclWindow_t handle; + ncclWindow_t signal_handle; + + auto group_info = get_group_info(group_name.value()); + auto buffer_size_map = + storeExchange.all_gather(group_info.store, group_info.rank, group_info.world_size, it->second->buffer_size); + + LOG(INFO) << "[rank " << group_info.rank << "]" + << "buffer_size_map: " << buffer_size_map; + // NCCL window registration api requires all ranks to have the same buffer size + // we have this check to make sure all ranks have the same buffer size. + for (auto r = 0; r < group_info.world_size; ++r) { + TORCH_CHECK(alloc->buffer_size == buffer_size_map[r], "buffer size mismatch"); + } + auto* ncclPg = dynamic_cast( + group->getBackend(c10::DeviceType::CUDA).get()); + TORCH_CHECK(ncclPg != nullptr, "backend must be a NCCL process group"); + ncclComm_t comm = reinterpret_cast(ncclPg->getCommPtr()); + C10D_NCCL_CHECK( + ncclCommWindowRegister(comm, ptr, alloc->buffer_size, (ncclWindow_t*)&handle, NCCL_WIN_COLL_SYMMETRIC), + c10::str( + "Failed to window register segment with ptr ", + ptr, + ", size ", + alloc->buffer_size, + " on ncclComm_ ", + comm)); + + void* signal_pad_ptr; + C10D_NCCL_CHECK(ncclMemAlloc(&signal_pad_ptr, signal_pad_size), "ncclMemAlloc failed"); + C10D_NCCL_CHECK( + ncclCommWindowRegister(comm, signal_pad_ptr, signal_pad_size, (ncclWindow_t*)&signal_handle, NCCL_WIN_COLL_SYMMETRIC), + c10::str( + "Failed to window register segment with ptr ", + signal_pad_ptr, + ", size ", + signal_pad_size, + " on ncclComm_ ", + comm)); + + auto symm_mem = + c10::make_intrusive(alloc, *group_name, std::move(handle), std::move(signal_handle)); + + symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem; + return symm_mem; + }; + + bool has_multicast_support(int device_idx) override { + // TODO + return false; + }; + + c10::DeviceType supported_device_type() override { + return c10::DeviceType::CUDA; + } + + std::string name() override { + return "NCCL"; + } + + private: + std::unordered_map> + ptr_to_symm_mem_; + + std::unordered_map> allocations_; + std::map, c10::intrusive_ptr> + symm_mems_; +}; + +struct RegisterNCCLSymmetricMemoryAllocator { + RegisterNCCLSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Query backend used for CUDA tensor + if (getSymmMemBackendCUDA() == "NCCL") { + // Direct set (static registration) + register_allocator( + c10::DeviceType::CUDA, + allocator); + } else { + // Register availability in case `set_backend` is called dynamically + register_availability("NCCL", allocator); + } + } +}; + +static RegisterNCCLSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d +#endif // NCCL_HAS_SYMMEM_SUPPORT +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu new file mode 100644 index 00000000000000..00b31a4d5a6cb7 --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu @@ -0,0 +1,420 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace c10d { +namespace symmetric_memory { + +/* Start of CUDASymmetricMemory implementation */ + +static StoreExchange storeExchange = StoreExchange("NVSHMEMSymmetricMemory"); + +struct NVSHMEMAllocation { + void* ptr; + size_t buffer_size; + int device_idx; + + NVSHMEMAllocation(void* ptr, size_t buffer_size, int device_idx) + : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} + + ~NVSHMEMAllocation() { + // Avoid calling CUDA functions after driver shutting down + if (is_finalizing()) { + return; + } + c10::cuda::CUDAGuard guard(device_idx); + nvshmem_free(ptr); // nvshmem_free has no return value + } +}; + +class NVSHMEMSymmetricMemory : public SymmetricMemory { + public: + NVSHMEMSymmetricMemory( + std::shared_ptr allocation, + const std::string& group_name) + : allocation_(allocation), + buffer_size_(allocation->buffer_size), + device_idx_(allocation->device_idx), + group_name_(group_name) { + // For logging only + static int exchanged_n_times = 0; + c10::cuda::CUDAGuard guard(device_idx_); + + auto global_rank = get_group_info("0").rank; + GroupInfo& group_info = get_group_info(group_name_); + auto store = group_info.store; + rank_ = group_info.rank; + world_size_ = group_info.world_size; + // Exchange rank to global rank mapping for this group. + // If it is already available, skip the exchange. + if (group_info.rank_to_global_rank.empty()) { + group_info.rank_to_global_rank = + storeExchange.all_gather(store, rank_, world_size_, global_rank); + exchanged_n_times++; + if (rank_ == 0) { + LOG(INFO) << "[rank " << rank_ << "]" + << " rank_to_global_rank: " << group_info.rank_to_global_rank + << ", group_name: " << group_name_ + << ", exchanged_n_times: " << exchanged_n_times; + } + } + TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty()); + rank_to_global_rank_ = group_info.rank_to_global_rank; + for (int r = 0; r < world_size_; ++r) { + buffers_.push_back(nvshmem_ptr( + allocation->ptr, rank_to_global_rank_[r])); + } + + // TODO: use the same allocation for signal pad + void* signal_pad_ptr = nvshmem_malloc(signal_pad_size); + AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size)); + + for (int r = 0; r < world_size_; ++r) { + signal_pads_.push_back(nvshmem_ptr( + signal_pad_ptr, rank_to_global_rank_[r])); + } + + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + + AT_CUDA_CHECK(cudaMemcpy( + buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); + AT_CUDA_CHECK(cudaMemcpy( + signal_pads_dev_, + signal_pads_.data(), + arr_size, + cudaMemcpyHostToDevice)); + + rank_to_global_rank_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_)); + AT_CUDA_CHECK(cudaMemcpy( + rank_to_global_rank_dev_, + rank_to_global_rank_.data(), + sizeof(int) * world_size_, + cudaMemcpyHostToDevice)); + } + + ~NVSHMEMSymmetricMemory() override{ + // TODO + }; + + std::vector get_buffer_ptrs() override { + return buffers_; + } + + std::vector get_signal_pad_ptrs() override { + return signal_pads_; + } + + void** get_buffer_ptrs_dev() override { + return buffers_dev_; + } + + void** get_signal_pad_ptrs_dev() override { + return signal_pads_dev_; + } + + size_t get_buffer_size() override { + return buffer_size_; + } + + size_t get_signal_pad_size() override { + return signal_pad_size; + }; + + bool has_multicast_support() override { + // TODO + return false; + } + + void* get_multicast_ptr() override { + // TODO + return nullptr; + } + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + // TODO: deduplicate + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "NVSHMEMSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); + } + + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override { + // TODO: deduplicate + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. + const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (!sizes.empty()) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "NVSHMEMSymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); + } + + void barrier(int channel, size_t timeout_ms) override { + // TODO + } + + void put_signal(int dst_rank, int channel, size_t timeout_ms) override { + // TODO + } + + void wait_signal(int src_rank, int channel, size_t timeout_ms) override { + // TODO + } + + int get_rank() override { + return rank_; + } + + int get_world_size() override { + return world_size_; + } + + virtual const std::vector& get_rank_to_global_rank() override { + return rank_to_global_rank_; + }; + + int* get_rank_to_global_rank_dev() override { + return rank_to_global_rank_dev_; + }; + + private: + std::shared_ptr allocation_; + size_t buffer_size_; + std::vector buffers_; + std::vector signal_pads_; + int device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::string group_name_; + + std::vector rank_to_global_rank_; + int* rank_to_global_rank_dev_; +}; + +// Bootstrap based on user's setting for NCCL +// Long term, this may be a bit unclean; short term, it improves UX +void maybe_initialize_env_vars() { + auto nccl_socket_if_name = c10::utils::get_env("NCCL_SOCKET_IFNAME"); + auto nccl_hca_list = c10::utils::get_env("NCCL_IB_HCA"); + auto nccl_ib_gid_index = c10::utils::get_env("NCCL_IB_GID_INDEX"); + auto nvshmem_socket_if_name = + c10::utils::get_env("NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME"); + auto nvshmem_hca_list = c10::utils::get_env("NCCL_IB_HCA"); + auto nvshmem_ib_gid_index = c10::utils::get_env("NVSHMEM_IB_GID_INDEX"); + + if (!nvshmem_socket_if_name.has_value() && nccl_socket_if_name.has_value()) { + c10::utils::set_env( + "NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME", nccl_socket_if_name->c_str()); + } + if (!nvshmem_hca_list.has_value() && nccl_hca_list.has_value()) { + c10::utils::set_env("NVSHMEM_ENABLE_NIC_PE_MAPPING", "1"); + c10::utils::set_env("NVSHMEM_HCA_LIST", nccl_hca_list->c_str()); + } + if (!nvshmem_ib_gid_index.has_value() && nccl_ib_gid_index.has_value()) { + c10::utils::set_env("NVSHMEM_IB_GID_INDEX", nccl_ib_gid_index->c_str()); + } +} + +void initialize_nvshmem_with_store( + c10::intrusive_ptr store, + int rank, + int world_size) { + static bool is_initialized = false; + if (is_initialized) { + return; + } + + maybe_initialize_env_vars(); + + nvshmemx_uniqueid_t unique_id; + NVSHMEM_CHECK( + nvshmemx_get_uniqueid(&unique_id), "nvshmemx_get_uniqueid failed"); + + // Using an existing store_all_gather due to laziness. + // TODO(yifu): should use broadcast + auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id); + + nvshmemx_init_attr_t attr; + nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr); + + NVSHMEM_CHECK( + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr), + "nvshmemx_init_attr failed"); + + is_initialized = true; + + // Print version + int major, minor; + ::nvshmem_info_get_version(&major, &minor); + LOG(INFO) << "NVSHMEM is available, version: " << major << '.' << minor; +} + +class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override { + TORCH_CHECK( + group_name == std::nullopt, + "NVSHMEMSymmetricMemoryAllocator::alloc " + "must not be called with a group_name"); + + auto group_info = get_group_info("0"); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + + initialize_nvshmem_with_store(store, rank, world_size); + auto ptr = nvshmem_malloc(size); + auto allocation = + std::make_shared(ptr, size, device_idx); + // TODO: thread safety + allocations_.try_emplace(ptr, std::move(allocation)); + return ptr; + } + + void free(void* ptr) override { + // TODO: thread safety + allocations_.erase(ptr); + }; + + size_t get_alloc_size(void* ptr) override { + auto it = allocations_.find(ptr); + if (it == allocations_.end()) { + TORCH_CHECK( + false, ptr, " is not allocated with NVSHMEMSymmetricMemoryAllocator"); + } + return it->second->buffer_size; + }; + + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override { + TORCH_CHECK(group_name.has_value()); + { + auto it = symm_mems_.find(std::make_tuple(ptr, *group_name)); + if (it != symm_mems_.end()) { + return it->second; + } + } + auto it = allocations_.find(ptr); + TORCH_CHECK(it != allocations_.end()); + auto symm_mem = + c10::make_intrusive(it->second, *group_name); + + symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem; + return symm_mem; + }; + + bool has_multicast_support(int device_idx) override { + // TODO + return false; + }; + + c10::DeviceType supported_device_type() override { + return c10::DeviceType::CUDA; + } + + std::string name() override { + return "NVSHMEM"; + } + + private: + std::unordered_map> allocations_; + std::map, c10::intrusive_ptr> + symm_mems_; +}; + +struct RegisterNVSHMEMSymmetricMemoryAllocator { + RegisterNVSHMEMSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Query backend used for CUDA tensor + if (getSymmMemBackendCUDA() == "NVSHMEM") { + // Direct set (static registration) + register_allocator( + c10::DeviceType::CUDA, + allocator); + } else { + // Register availability in case `set_backend` is called dynamically + register_availability("NVSHMEM", allocator); + } + } +}; + +static RegisterNVSHMEMSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp similarity index 79% rename from torch/csrc/distributed/c10d/SymmetricMemory.cpp rename to torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 599892e0b1731e..e2b8007c2079cc 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -1,4 +1,4 @@ -#include +#include namespace { @@ -22,6 +22,39 @@ class AllocatorMap { map_[device_type] = std::move(allocator); } + void register_availability( + const std::string& name, + c10::intrusive_ptr allocator) { + avail_map_[name] = std::move(allocator); + } + + void set_backend(const std::string& name) { + auto it = avail_map_.find(name); + TORCH_CHECK( + it != avail_map_.end(), + "SymmetricMemory does not find allocation backend ", + name); + auto device_type = it->second->supported_device_type(); + // Check if the existing one is already the one desired. + auto existing = map_.find(device_type); + if (existing != map_.end()) { + if (existing->second->name() == name) { + // The existing one is the same as the desired one. No need to change. + return; + } + TORCH_CHECK(!in_use_, "Backend can not be changed after use."); + } + register_allocator(device_type, it->second); + } + + std::optional get_backend(c10::DeviceType device_type) { + auto it = map_.find(device_type); + if (it == map_.end()) { + return std::nullopt; + } + return it->second->name(); + } + c10::intrusive_ptr get_allocator( c10::DeviceType device_type) { auto it = map_.find(device_type); @@ -29,6 +62,7 @@ class AllocatorMap { it != map_.end(), "SymmetricMemory does not support device type ", device_type); + in_use_ = true; return it->second; } @@ -48,6 +82,17 @@ class AllocatorMap { c10::DeviceType, c10::intrusive_ptr> map_; + + // For backends to register availability. + // This registration is at static time. Therefore, it is expected that the + // derived `SymmetricMemoryAllocator` classes do not have backend-specific + // initialization in constructor (in case it is not selected). + std::unordered_map< + std::string, // backend name "NVSHMEM", "CUDA", "NCCL", etc. + c10::intrusive_ptr> + avail_map_; + + bool in_use_ = false; }; static std::unordered_map group_info_map{}; @@ -128,6 +173,20 @@ void register_allocator( device_type, std::move(allocator)); } +void register_availability( + const std::string& name, + c10::intrusive_ptr allocator) { + return AllocatorMap::get().register_availability(name, std::move(allocator)); +} + +void set_backend(const std::string& name) { + return AllocatorMap::get().set_backend(name); +} + +std::optional get_backend(c10::Device device) { + return AllocatorMap::get().get_backend(device.type()); +} + bool has_allocator(c10::DeviceType device_type) { return AllocatorMap::get().has_allocator(device_type); } @@ -150,7 +209,7 @@ void set_group_info( group_info_map.emplace(group_name, std::move(group_info)); } -const GroupInfo& get_group_info(const std::string& group_name) { +GroupInfo& get_group_info(const std::string& group_name) { TORCH_CHECK( group_info_map.find(group_name) != group_info_map.end(), "get_group_info: no group info associated with the group name ", @@ -275,11 +334,15 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)"); + m.def("nvshmem_put(Tensor(a!) tensor, int peer) -> ()"); + m.def("nvshmem_get(Tensor(a!) tensor, int peer) -> ()"); m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)"); m.def( "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)"); m.def( - "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)"); + "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)"); + m.def( + "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int? major_align=None) -> Tensor(a!)"); } TORCH_LIBRARY_IMPL(symm_mem, Meta, m) { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp similarity index 90% rename from torch/csrc/distributed/c10d/SymmetricMemory.hpp rename to torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp index 86883d5df6f07c..c2828de04c9b3c 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp @@ -72,7 +72,7 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual int get_rank() = 0; virtual int get_world_size() = 0; - virtual std::vector get_rank_to_global_rank() { + virtual const std::vector& get_rank_to_global_rank() { TORCH_CHECK(false, "NYI"); } @@ -96,6 +96,8 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { void* ptr, const std::optional& group_name) = 0; virtual bool has_multicast_support(int device_idx) = 0; + virtual c10::DeviceType supported_device_type() = 0; + virtual std::string name() = 0; }; C10_EXPORT bool is_finalizing(); @@ -104,6 +106,10 @@ C10_EXPORT void register_allocator( c10::DeviceType device_type, c10::intrusive_ptr allocator); +C10_EXPORT void register_availability( + const std::string& name, + c10::intrusive_ptr allocator); + C10_EXPORT bool has_allocator(c10::DeviceType device_type); C10_EXPORT c10::intrusive_ptr get_allocator( @@ -125,9 +131,13 @@ struct GroupInfo { int rank; int world_size; c10::intrusive_ptr store; + // Note this field is not automatically populated by set_group_info(). If a + // SymmetricMemory implementation needs to use it, it must be populated by a + // call to exchange_global_ranks() first. + std::vector rank_to_global_rank; }; -C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); +C10_EXPORT GroupInfo& get_group_info(const std::string& group_name); // Identical to empty_strided, but allows symmetric memory access to be // established for the allocated tensor via SymmetricMemory::rendezvous(). This @@ -169,4 +179,9 @@ TORCH_API c10::intrusive_ptr rendezvous( TORCH_API bool has_multicast_support( c10::DeviceType device_type, int device_idx); + +TORCH_API void set_backend(const std::string& name); + +TORCH_API std::optional get_backend(c10::Device device); + } // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp similarity index 97% rename from torch/csrc/distributed/c10d/intra_node_comm.cpp rename to torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp index 2694dabbac215c..0d53d100cee7d6 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp @@ -1,7 +1,6 @@ -#include - -#include #include +#include +#include #if defined(USE_ROCM) #include @@ -63,7 +62,7 @@ static NvlMesh getNvlMesh(const std::vector& rankToDeviceIdx) { } /** - * Detech topology given a NvlMesh. + * Detect topology given a NvlMesh. */ static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) { if (getCvarBool(TEST_INTRA_NODE_COMM, false)) { diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu similarity index 96% rename from torch/csrc/distributed/c10d/intra_node_comm.cu rename to torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu index c490cba2021ca6..6a6a6520e36bac 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu @@ -1,6 +1,6 @@ -#include +#include -#include +#include namespace c10d { namespace intra_node_comm { diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp similarity index 97% rename from torch/csrc/distributed/c10d/intra_node_comm.hpp rename to torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp index d21ee398f1a63e..7b5e8ff999c5db 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include namespace c10d::intra_node_comm { diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu new file mode 100644 index 00000000000000..8528ecc5f4aa7c --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu @@ -0,0 +1,640 @@ +#include +#include + +#include +#include +#include +#include + +// Use torch's cub wrapper instead of CUDA's , see #55292 +#include + +// NVSHMEM minimum SM arch +#define _NVSHMEM_MIN_SM_ARCH 700 + +// Some NVSHMEM device APIs do not compile on older SM archs +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH) +// Only include host APIs. See nvshmem.h for details. +#define NVSHMEM_HOSTLIB_ONLY +#endif // Must be done before nvshmem.h is included + +#include + +namespace c10d::nvshmem_extension { + +#define THREADS_PER_BLOCK 512 +#define WARP_SIZE 32 + +constexpr int MiB = 1024 * 1024; + +// Check if NVSHMEM is available +bool is_nvshmem_available() { + // Runtime check + static std::mutex mutex; + static int is_available = -2; + std::lock_guard lock(mutex); + if (is_available == -2) { + void* handle{}; + // Open the shared library, RTLD_LAZY defers symbol resolution until needed + handle = dlopen("libnvshmem_host.so.3", RTLD_LAZY); + if (!handle) { + std::cerr << dlerror() << "\n"; + is_available = 0; + } else { + is_available = 1; + // Close the shared library + dlclose(handle); + } + } + return is_available == 1; +} + +// Initializes the device state in CUmodule so that it’s able to perform NVSHMEM +// operations. +void nvshmemx_cumodule_init(uintptr_t module) { + auto cumodule = reinterpret_cast(module); + NVSHMEM_CHECK( + ::nvshmemx_cumodule_init(cumodule), + "nvshmemx_cumodule_init failed"); +} + +static std::unordered_map group_name_to_team_; + +nvshmem_team_t group_to_team( + const std::string& group_name, + const std::vector& global_ranks) { + auto it = group_name_to_team_.find(group_name); + if (it != group_name_to_team_.end()) { + return it->second; + } + TORCH_CHECK(global_ranks.size() > 1); + int stride = global_ranks[1] - global_ranks[0]; + for (size_t r = 1; r < global_ranks.size(); ++r) { + TORCH_CHECK(global_ranks[r] - global_ranks[r - 1] == stride); + } + + nvshmem_team_t team; + NVSHMEM_CHECK( + nvshmem_team_split_strided( + NVSHMEM_TEAM_WORLD, + global_ranks[0], + stride, + global_ranks.size(), + nullptr, + 0, + &team), + "nvshmem_team_split_strided failed"); + group_name_to_team_[group_name] = team; + TORCH_CHECK(team != NVSHMEM_TEAM_INVALID); + return team; +} + +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) { + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); + void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank]; + + auto stream = at::cuda::getCurrentCUDAStream(); + nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream); + return input; +} + +void nvshmem_put(at::Tensor& tensor, int64_t peer) { + // TODO: support non-contiguous tensors + TORCH_CHECK(tensor.is_contiguous(), + "put op currently supports contiguous tensors only"); + // TODO: rendezvous should remember the group name + auto hdl = c10d::symmetric_memory::rendezvous(tensor, "0"); + auto rank = hdl->get_rank(); + void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; + auto buffer_size = tensor.numel() * tensor.element_size(); + + c10::cuda::CUDAGuard guard(tensor.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + nvshmemx_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream); +} + +void nvshmem_get(at::Tensor& tensor, int64_t peer) { + // TODO: support non-contiguous tensors + TORCH_CHECK(tensor.is_contiguous(), + "get op currently supports contiguous tensors only"); + // TODO: rendezvous should remember the group name + auto hdl = c10d::symmetric_memory::rendezvous(tensor, "0"); + auto rank = hdl->get_rank(); + void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; + auto buffer_size = tensor.numel() * tensor.element_size(); + + c10::cuda::CUDAGuard guard(tensor.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + nvshmemx_getmem_on_stream(tensor.data_ptr(), buffer_ptr, buffer_size, peer, stream); +} + +at::Tensor nvshmem_all_to_all( + at::Tensor& input, + at::Tensor& out, + std::string group_name) { + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); + + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + size_t bytes_per_rank = input_hdl->get_buffer_size() / world_size; + + auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); + nvshmemx_alltoallmem_on_stream(team, output_ptr, input_ptr, bytes_per_rank, stream); + return out; +} + +// This is an exclusive prefix sum function that calculates read (or write) offsets for each peer. +__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) { + // Specialize BlockScan for a 1D block of threads, of type int64_t. + // - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high + // throughput which we don't need here). + // - `at_cuda_detail::cub` is torch's cub wrapper, see #55292. + using BlockScanT = at_cuda_detail::cub::BlockScan; + // Allocate shared memory for BlockScan + __shared__ typename BlockScanT::TempStorage temp_storage; + + // TODO: currently it is assumed that the number of PE's is smaller than + // `THREADS_PER_BLOCK` + CUDA_KERNEL_ASSERT(n <= THREADS_PER_BLOCK); + + // Obtain input item for each thread + int tid = threadIdx.x; + int64_t thread_data = (tid < n) ? idata[tid] : 0; + + // Collectively compute the block-wide exclusive prefix sum + int64_t block_aggregate; + BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); + + // Store the result + odata[tid] = thread_data; + return block_aggregate; +} + +// This kernel is used to exchange output splits and source offsets between peers. +// `in_out_splits` is of size (3, npes) and contains: +// - input splits (IN) +// - output splits (OUT) and +// - source offsets (OUT). +__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npes) { +#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH + CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM"); +#else + auto input_splits = in_out_splits; + auto output_splits = in_out_splits + npes; + auto source_offsets = in_out_splits + npes * 2; + int tid = threadIdx.x; + + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + + // Scan input splits to get the source offsets + prefixSum(peer_offsets, input_splits, npes); + __syncthreads();; + + // Use 1 block to do the exchange + if (tid < npes) { + int peer = tid; + nvshmem_int64_p(source_offsets + mype, peer_offsets[peer], peer); + nvshmem_int64_p(output_splits + mype, input_splits[peer], peer); + } + // This barrier ensures that all remote PEs see the updated values + nvshmemx_barrier_all_block(); +#endif +} + +// This kernel is used to do the actual data exchange. +// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. +// `stride` is the stride at dim 0, unit in byte. +__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes) { +#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH + CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM"); +#else + auto output_splits = in_out_splits + npes; + auto source_offsets = in_out_splits + npes * 2; + int bid = blockIdx.x; + int tid = threadIdx.x; + int blocks_per_peer = max(gridDim.x / npes, 1); + + // Calculate the output offsets + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + prefixSum(peer_offsets, output_splits, npes); + __syncthreads(); + + // Target a different peer based on bid + for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) { + int peer = (mype + i) % npes; + // Total amount from `peer` + auto peer_size = output_splits[peer] * stride; + // Amount to get from `peer` in this block + auto block_size = peer_size / blocks_per_peer; + // Being lazy here, we should handle the residual if the division is not exact + CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size); + // This block's offset in the data from `peer` + auto block_offset = block_size * (bid % blocks_per_peer); + auto source_offset = source_offsets[peer] * stride + block_offset; + auto write_offset = peer_offsets[peer] * stride + block_offset; + nvshmemx_getmem_block( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + block_size, + peer); + } + // Write out the output offsets (to the scratchpad line) + if (bid == 0 && tid < npes) { + source_offsets[tid] = peer_offsets[tid]; + } +#endif +} + +at::Tensor all_to_all_vdev( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name) { + /* Perform AllToAllv operation using NVSHMEM, with split information provided on device. + * Arguments: + * - `input` is the input tensor + * - `out` is the output tensor + * - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order): + input splits (IN) + output splits (OUT) and + output offsets (OUT). + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]); + + auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); + + // Exchange output splits and source offsets + // Use collective launch because kernel involves nvshmem barrier + void* args0[] = { + &splits_ptr, + &rank, + &world_size}; + nvshmemx_collective_launch( + (const void*)exchangeSplitAndOffset, + dim3(1), + dim3(THREADS_PER_BLOCK), + args0, + 0, + stream); + + // CTA Tuning + // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8. + // Up to 1 MB -> 1 block + // Up to 2 MB -> 2 blocks + // Up to 4 MB -> 4 blocks + // More -> 8 blocks + // The tuning for `num_blocks` below multiplies these numbers by world_size + // (e.g. 8 -> 8 * 8). If world_size is smaller, we simply shift the blocks + // towards data parallelism. (There may be room for improvement here) + auto input_size = input.numel() * input.element_size(); + int num_blocks = input_size < MiB ? 8 : + (input_size < 2 * MiB ? 16 : + (input_size < 4 * MiB ? 32 : 64)); + + // Inter-node: limit the total the number of blocks: + // = 16 for 16GPUs which is enough to max out 90 GB/s bandwidth perf + // = 8 for more than 16 GPUs which is enough to max out approx 50 GB/s bandwidth perf + // Above assumes 400Gb/s NIC for inter-node and 400GB/s NVLinks for intra-node comms. + // TODO: better intra vs inter detection, currently it is based on world_size. + int max_inter_node_blocks = world_size <= 16 ? 16 : 8; + if (world_size > 8) { + num_blocks = std::min(num_blocks, max_inter_node_blocks); + } + + // Stride at dim 0 (assuming input is contiguous, TODO) + size_t stride_bytes = input.stride(0) * input.element_size(); + + // All to all data exchange + void* args1[] = { + &input_ptr, + &output_ptr, + &splits_ptr, + &stride_bytes, + &rank, + &world_size}; + nvshmemx_collective_launch( + (const void*)allToAllV, + dim3(num_blocks), + dim3(THREADS_PER_BLOCK), + args1, + 0, + stream); + return out; +} + +// Start of `all_to_all_vdev_2d` +// This kernel is used to exchange output splits and source offsets between peers. +// For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`. +// `in_out_splits` is of size (3, npes * ne) and contains: +// - input splits (IN) +// - output splits (OUT) and +// - source offsets (OUT). +__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) { +#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH + CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM"); +#else + int nsplits = npes * ne; + auto input_splits = in_out_splits; + auto output_splits = in_out_splits + nsplits; + auto source_offsets = in_out_splits + nsplits * 2; + int tid = threadIdx.x; + + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + + // Scan input splits to get the source offsets + auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits); + __syncthreads();; + CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0); + + // Use 1 block to do the exchange + if (tid < nsplits) { + int peer = tid / ne; + int e = tid % ne; + // This does a transpose from rank-major order to expert-major order + int dst_offset = e * npes + mype; + auto split_val = input_splits[tid]; + CUDA_KERNEL_ASSERT(split_val >= 0); + nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer); + nvshmem_int64_p(output_splits + dst_offset, split_val, peer); + } + // This barrier ensures that all remote PEs see the updated values + nvshmemx_barrier_all_block(); +#endif +} + +// This is an warp-scope, exclusive prefix sum. When called by a block of +// threads, each warp will perform an independent prefix sum, concurrently. +// Returns the sum of all elements in the warp. +// `NUM_WARPS` is the number of warps participating the concurrent prefix sum. +template +__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) { + CUDA_KERNEL_ASSERT(n <= WARP_SIZE); + + // Specialize WarpScan for type int + using WarpScan = at_cuda_detail::cub::WarpScan; + // Allocate WarpScan shared memory for N warps + __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS]; + + int warp_id = threadIdx.x / WARP_SIZE; + if (warp_id >= NUM_WARPS) { + return 0; + } + + // Obtain input item for each thread + int tid = threadIdx.x % WARP_SIZE; + int64_t thread_data = (tid < n) ? idata[tid] : 0; + + // Total sum of all elements in the warp + int64_t warp_aggregate; + // Compute the warp-wide exclusive prefix sum + WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate); + + // Store the result + odata[tid] = thread_data; + return warp_aggregate; +} + +// This is for abstracting a thread-group-scope, exclusive prefix sum. +// Since we use warp-scope prefix sum, the thread group size is limited to warp size. +#define A2AV_TILE_SIZE WARP_SIZE + +// This kernel is used to do the actual data exchange. +// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. +// `stride` is the stride at dim 0, unit in byte. +// For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`. +__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) { +#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH + CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM"); +#else + int nsplits = npes * ne; + auto output_splits = in_out_splits + nsplits; + auto source_offsets = in_out_splits + nsplits * 2; + int bid = blockIdx.x; + int tid = threadIdx.x; + + // Split the thread block into tiles + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + int tileId = tid / A2AV_TILE_SIZE; + int laneId = tid % A2AV_TILE_SIZE; + // Each tile calculates its own prefix sum + __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE]; + // A tile takes care of npes worth of splits + int nsplits_per_tile = min(npes, nsplits - tileId * npes); + // TODO: currently it is assumed that the number of PE's is smaller than + // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to + // WARP_SIZE elements + CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE); + // Similarly, the number of experts per rank is also assumed to be smaller + // than `NUM_TILES` + CUDA_KERNEL_ASSERT(ne <= NUM_TILES); + + // Total length of each tile + __shared__ int64_t len_per_tile[NUM_TILES]; + // When `nsplits` is small, not every tile gets data to sum. They can skip + // this local prefix sum. + if (nsplits_per_tile > 0) { + // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile. + int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile); + // Last thread in each tile does the up aligning. + if (laneId == A2AV_TILE_SIZE - 1) { + auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align; + // In case `aligned_len` is 0, we set it to `major_align` to avoid an + // empty bin, bc cutlass currently does not support it. See + // https://github.com/pytorch/pytorch/issues/152668. + len_per_tile[tileId] = max(aligned_len, major_align); + } + } + __syncthreads(); + + // Starting offset of each tile + __shared__ int64_t start_offset_per_tile[NUM_TILES]; + // Prefix sum again to get the tiles' start offsets. + // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads + // = 1024 threads, and this kernel is launched within 1024 threads. Thus, we + // can use warp-scope prefix sum. + static_assert(NUM_TILES <= WARP_SIZE); + // Only 1 warp is needed + prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES); + __syncthreads(); + + // Add tile offset to every element in the tile + tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId]; + __syncthreads(); + + // Target a different e based on bid + for (int eid = bid; eid < nsplits; eid += gridDim.x) { + int peer = eid % npes; + // Amount from `peer` for `e` + auto peer_size = output_splits[eid] * stride; + auto source_offset = source_offsets[eid] * stride; + auto e_offset = tile_prefix_sums[eid / npes][peer]; + auto write_offset = e_offset * stride; + nvshmemx_getmem_block( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + peer_size, + peer); + } + // Write out the output offsets (to the scratchpad line) + if (bid == 0 && tid < nsplits) { + source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes]; + } +#endif +} + +at::Tensor all_to_all_vdev_2d( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name, + std::optional major_align) { + /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device. + * Arguments: + * - `input` is the input tensor + * - `out` is the output tensor + * - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the + scenario of Mixture-of-Experts models, `ne` is the number of experts per + rank. The rows of `in_out_splits` are (in order): + input splits (IN) + output splits (OUT) and + output offsets (OUT). + * - `group_name` is the name of the group to use for the collective operation. + * - `major_align` is the alignment of the "major dimension" of the output + sequence. See below for details. + + * A 2D AllToAllv shuffle is illustrated below: + (world_size = 2, ne = 2, total number of experts = 4) + Source: | Rank 0 | Rank 1 | + | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 | + + Dest : | Rank 0 | Rank 1 | + | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 | + where each `c_i` / `d_i` are slices of the `input` tensor, targeting + expert `i`, with length indicated by input splits (in + `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a + transpose from rank-major order at input to expert-major order at + output. + + * If `major_align` is not 1, the output offsets of c1, c2, c3 will be + up-aligned to this value. For example, if c0 has length 5 and d0 has + length 7 (making a total of 12), and if the `major_align` is set to 16, + the output offset of c1 will be 16. Similar for c2 and c3. This value has + no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3. + Note: since cutlass does not support empty bins, we set the aligned length + to `major_align` if it is 0. See + https://github.com/pytorch/pytorch/issues/152668. + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + // TODO: world_size is currently limited by the number of elements in a WarpScan. + TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE); + + // If `major_align` is not provided, use 1 as the default value. + int64_t major_align_val = major_align.value_or(1); + TORCH_CHECK(major_align_val > 0, "major_align must be positive"); + + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]); + + // Shape checks + auto split_shape = in_out_splits.sizes(); + TORCH_CHECK(in_out_splits.is_contiguous() + && input.is_contiguous() + && out.is_contiguous(), + "input, out and in_out_splits must be contiguous"); + TORCH_CHECK(split_shape.size() == 2 + && split_shape[0] == 3 + && split_shape[1] % world_size == 0, + "in_out_splits must be 2D with 3 rows, " + "each row must be a multiple of world_size"); + + // Consistency checks + TORCH_CHECK(input.dtype() == out.dtype() + && input.stride(0) == out.stride(0), + "input and out must have the same dtype and same stride at dim 0"); + TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64"); + + // Number of experts per rank + int ne = split_shape[1] / world_size; + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + TORCH_CHECK(ne <= NUM_TILES, "Number of experts must be smaller than NUM_TILES", NUM_TILES); + + // Set device context for getting the stream and launching kernels below + c10::cuda::CUDAGuard guard(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + // Exchange output splits and source offsets + auto input_dim0 = input.size(0); + // Use collective launch because kernel involves nvshmem barrier + void* args0[] = { + &splits_ptr, + &rank, + &world_size, + &ne, + &input_dim0}; + nvshmemx_collective_launch( + (const void*)exchangeSplitAndOffset_2d, + dim3(1), + dim3(THREADS_PER_BLOCK), + args0, + 0, + stream); + + // CTA Tuning + // Naive for now, use 1 block per expert. + // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node). + int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64); + + // Stride at dim 0 + size_t stride_bytes = input.stride(0) * input.element_size(); + + // All to all data exchange + void* args1[] = { + &input_ptr, + &output_ptr, + &splits_ptr, + &stride_bytes, + &rank, + &world_size, + &ne, + &major_align_val}; + nvshmemx_collective_launch( + (const void*)allToAllV_2d, + dim3(num_blocks), + dim3(THREADS_PER_BLOCK), + args1, + 0, + stream); + return out; +} + +} // namespace c10d::nvshmem_extension + + +TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { + m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast); + m.impl("nvshmem_put", c10d::nvshmem_extension::nvshmem_put); + m.impl("nvshmem_get", c10d::nvshmem_extension::nvshmem_get); + m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all); + m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev); + m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d); +} diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh new file mode 100644 index 00000000000000..86f8724cd5b88e --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#define NVSHMEM_CHECK(stmt, msg) \ + do { \ + int result = (stmt); \ + TORCH_CHECK( \ + result == 0, \ + std::string(__FILE__) + ":" + std::to_string(__LINE__) + " " + msg + \ + ". Error code: " + std::to_string(result)); \ + } while (0) + +namespace c10d::nvshmem_extension { + +// Check if NVSHMEM is available +TORCH_API bool is_nvshmem_available(); + +// Initializes the device state in CUmodule so that it’s able to perform NVSHMEM +// operations. +TORCH_API void nvshmemx_cumodule_init(uintptr_t module); + +TORCH_API void nvshmem_put(at::Tensor& tensor, int64_t peer); + +TORCH_API void nvshmem_get(at::Tensor& tensor, int64_t peer); + +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name); + +at::Tensor nvshmem_all_to_all( + at::Tensor& input, + at::Tensor& out, + std::string group_name); + +at::Tensor all_to_all_vdev( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name); + +at::Tensor all_to_all_vdev_2d( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name, + std::optional major_align = std::nullopt); + +} // namespace c10d::nvshmem_extension diff --git a/torch/csrc/distributed/rpc/agent_utils.h b/torch/csrc/distributed/rpc/agent_utils.h index 016f6110e13e2a..8e403bcb691283 100644 --- a/torch/csrc/distributed/rpc/agent_utils.h +++ b/torch/csrc/distributed/rpc/agent_utils.h @@ -24,7 +24,7 @@ TORCH_API std::unordered_map collectCurrentNames( const worker_id_t selfId, const std::string& selfName); -// Remove name frmo Store, used in dynamic RPC groups. +// Remove name from Store, used in dynamic RPC groups. // NOTE: This needs to be called with the Dynamic RPC group // membership management token held. TORCH_API void removeCurrentName( diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index d5274289d6102a..f7bc517f41c53e 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -16,7 +16,7 @@ namespace torch::distributed::rpc { namespace { py::tuple toPyTuple(const RRefForkData& rrefForkData) { - // add GIL as it is contructing a py::object + // add GIL as it is constructing a py::object pybind11::gil_scoped_acquire ag; return py::make_tuple( rrefForkData.ownerId_, @@ -86,12 +86,14 @@ TypePtr tryInferTypeWithTypeHint( // Check if value is an instance of a ScriptClass. If not, skip type inference // because it will try to script the class that value is in instance of, and // this should be avoided. - py::bool_ can_compile = py::module::import("torch._jit_internal") - .attr("can_compile_class")(value.get_type()); + py::bool_ can_compile = + py::module::import("torch._jit_internal") + .attr("can_compile_class")(py::type::handle_of(value)); if (py::cast(can_compile)) { - py::object existing_ty = py::module::import("torch.jit._state") - .attr("_get_script_class")(value.get_type()); + py::object existing_ty = + py::module::import("torch.jit._state") + .attr("_get_script_class")(py::type::handle_of(value)); if (existing_ty.is_none()) { return PyObjectType::get(); diff --git a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h index 0a3054b594d28b..09d4ba36dc62bc 100644 --- a/torch/csrc/distributed/rpc/python_remote_call.h +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -3,7 +3,6 @@ #include #include #include -#include namespace torch::distributed::rpc { class TORCH_API PythonRemoteCall : public RpcCommandBase { diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp index 35a4f9c9877fc5..de5cd0540a4542 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp +++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp @@ -121,7 +121,7 @@ PythonRpcHandler& PythonRpcHandler::getInstance() { // initialization by calling `new PythonRpcHandler()`, inside of which GIL is // also required. Static data initialization is thread-safe, so the thread // holding the GIL will wait for the other thread to finish static data - // initializating before going forward. Because the initialization can't + // initializing before going forward. Because the initialization can't // proceed without GIL, there is a deadlock. We ask the calling thread to // release GIL to avoid this situation. TORCH_INTERNAL_ASSERT(!PyGILState_Check()); @@ -174,7 +174,7 @@ void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) { bool PythonRpcHandler::isRemoteException(const py::object& obj) { PROFILE_GIL_SCOPED_ACQUIRE; - auto type = obj.get_type(); + auto type = py::type::handle_of(obj); auto moduleName = type.attr("__module__").cast(); auto qualName = type.attr("__qualname__").cast(); return moduleName == kInternalModule && qualName == "RemoteException"; diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 438f2ff9954f33..ccc2720b18ae51 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -134,7 +134,7 @@ c10::intrusive_ptr RequestCallbackImpl::runPythonFunction( try { return result.cast().fut; } catch (const py::cast_error& e) { - auto type = result.get_type(); + auto type = py::type::handle_of(result); auto errMsg = c10::str( e.what(), ". Functions decorated with @rpc.async_function must return a " diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index 1022d6ff97d7ff..fa26c1849ddec7 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -348,7 +348,7 @@ c10::intrusive_ptr RRefContext::getOrCreateOwnerRRef( // here is a plain TensorType, they are not equal relationship: // specialized TensorType <: plain TensorType // - // In RPC we don't care the difference as we ser/de with just the + // In RPC we don't care the difference as we ser'de with just the // plain TensorType. This is not a issue for UserRRef creation either, // since Tensor can only get specialized with a previous run of local // JIT function, and we shouldn't preserve the specialized SubTensorType diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index 3282e8c0e108f5..ce3b71580ab6c5 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -318,7 +318,7 @@ class TORCH_API RRefContext { // RRef is forwarded to the callee as new UserRRefs (if the callee is not // the owner). In this case, we block running the user function until all // UserRRefs are confirmed by the owner. - // This contract gurantees that no UserRRefs can be used remotely without + // This contract guarantees that no UserRRefs can be used remotely without // confirmation. Note that, however, the UserRRef created by rpc.remote can // still be passed to local functions as arguments and used there. This is by // design, because this feature is especially useful when, say a master node diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index e6bffd1870b3f2..a1482b46939b18 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -4,7 +4,6 @@ #include #include #include -#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 19e1871ead871a..476ee118fe7fa9 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index 534ac0044599da..e18edab648210a 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_resp.h b/torch/csrc/distributed/rpc/script_resp.h index fd8cd4b845d1cd..53841e3d705c2e 100644 --- a/torch/csrc/distributed/rpc/script_resp.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -2,7 +2,6 @@ #include #include -#include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 9801a0327ddf17..c25e83c07c6db8 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -131,6 +131,7 @@ std::vector getDevicesOfTensors( devices.reserve(deviceCount); for (const auto idx : c10::irange(indexBitset.size())) { if (indexBitset[idx]) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) devices.emplace_back(impl->type(), static_cast(idx)); } } @@ -371,7 +372,7 @@ void TensorPipeAgent::checkAndSetStaticGroup( isStaticGroupKey, std::vector(), isStaticGroupVec); std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end()); // In both cases, the returned value should be the value of isStaticGroupStr, - // otherwise there is a discrepency with initialization among one of the + // otherwise there is a discrepancy with initialization among one of the // members TORCH_CHECK( returnedVal == isStaticGroupStr, diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index aaa2e9699e4e56..adce4056840257 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -121,8 +121,8 @@ struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions { deviceMaps[workerName] = deviceMap; } else { for (auto& entry : deviceMap) { - // c10::Device has no default constructor, hence map[device] dosn't work - // In C++-17 we can use insert_or_assign. + // c10::Device has no default constructor, hence map[device] doesn't + // work In C++-17 we can use insert_or_assign. auto entryIter = iter->second.find(entry.first); if (entryIter == iter->second.end()) { iter->second.emplace(entry.first, entry.second); diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index 9a4fbf8fac0a59..beb8064ba6c242 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -5,7 +5,7 @@ #include CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) - : backend{backend} { + : backend{py::cast(get_backend(backend))} { this->guard_manager = guarded_code.attr("guard_manager"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); @@ -52,6 +52,7 @@ void CacheEntry::invalidate(py::object deleted_guard_manager) { this->guard_manager = std::move(deleted_guard_manager); this->root_mgr = nullptr; this->trace_annotation = "Invalidated"; + this->backend = py::none(); } void CacheEntry::update_diff_guard_root_manager() { @@ -76,8 +77,8 @@ PyObject* CacheEntry_to_obj(CacheEntry* e) { PyObject* get_backend(PyObject* callback) { py::handle handle = py::handle(callback); - while (py::hasattr(handle, "_torchdynamo_orig_callable")) { - handle = handle.attr("_torchdynamo_orig_callable"); + while (py::hasattr(handle, "_torchdynamo_orig_backend")) { + handle = handle.attr("_torchdynamo_orig_backend"); } return handle.ptr(); } diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h index cab22926147c2f..e7c58f31a090de 100644 --- a/torch/csrc/dynamo/cache_entry.h +++ b/torch/csrc/dynamo/cache_entry.h @@ -53,7 +53,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry { // diff guard root guard manager if exists void* diff_guard_root_mgr{nullptr}; // backend used to create this cache entry - PyObject* backend{nullptr}; + py::object backend; // Reference to owning ExtraState ExtraState* _owner{nullptr}; // Reference to this CacheEntry's location in owner's linked list diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index d1b06f47597568..cba8158213c660 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -78,7 +78,8 @@ struct TORCH_API PyCompilerInterface { virtual void call_accumulate_grad( PyObject* py_compiler, const at::Tensor& variable, - const at::Tensor& grad) const { + const at::Tensor& grad, + bool has_post_hooks) const { TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); } }; @@ -97,7 +98,7 @@ struct TORCH_API PyCompilerGuard { // including torch/csrc/autograd/engine.h breaks BC by somehow introducing // symbol resolution issues. Instead requiring downstream users to include // engine.h to access collect_input_metadata, we provide it here (with a -// different name to avoid ambigous symbols...) +// different name to avoid ambiguous symbols...) TORCH_API std::vector> get_input_metadata( const edge_list& edges); @@ -571,7 +572,8 @@ class CompiledNodeArgs { } } void collect(const InputMetadata& t) { - TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented"); + TORCH_CHECK_NOT_IMPLEMENTED( + !t.is_nested_tensor(), "NestedTensor support not implemented. "); collect(t.options()); collect(t.is_tensor_subclass()); collect(t.shape_as_dim_vector()); @@ -1066,7 +1068,7 @@ class SwapSavedVariables { // (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph // capture wants to take a variant of this function and proxy it into the graph. // Every autograd node defines an apply_with_saved function, that when invoked, -// proxys a call to a function into the Compiled Autograd graph. +// proxies a call to a function into the Compiled Autograd graph. // // Some requirements that we have are: // - The proxy'ed function must have inputs that are FX-graphable types. @@ -1109,7 +1111,8 @@ struct IValuePacker { // with certain compiler settings // (see https://github.com/pytorch/pytorch/pull/144707 for examples). // It's not clear what the problem is, so we're going to ignore it for now. - TORCH_INTERNAL_ASSERT(false, "torch.compile not supported on Windows"); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "torch.compile not supported on Windows"); #else if constexpr (::std::is_same_v) { return at::TensorType::get(); @@ -1146,7 +1149,8 @@ struct IValuePacker { // define how to pack and unpack an object of this time into an IValue // by creating a specialization of IValuePacker for this type. // See NOTE: [Compiled Autograd and backward functions] for context. - TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type"); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "IValuePacker not implemented for type"); return at::NoneType::get(); } #endif @@ -1269,11 +1273,11 @@ inline at::TensorOptions unpack_TensorOptions( at::TensorOptions result; auto maybe_requires_grad = std::get<0>(tuple); if (maybe_requires_grad.has_value()) { - result = result.requires_grad(maybe_requires_grad.value()); + result = result.requires_grad(maybe_requires_grad); } auto maybe_memory_format = std::get<1>(tuple); if (maybe_memory_format.has_value()) { - result = result.memory_format(maybe_memory_format.value()); + result = result.memory_format(maybe_memory_format); } auto maybe_device = std::get<2>(tuple); if (maybe_device.has_value()) { @@ -1286,11 +1290,11 @@ inline at::TensorOptions unpack_TensorOptions( } auto maybe_layout = std::get<4>(tuple); if (maybe_layout.has_value()) { - result = result.layout(maybe_layout.value()); + result = result.layout(maybe_layout); } auto maybe_pinned_memory = std::get<5>(tuple); if (maybe_pinned_memory.has_value()) { - result = result.pinned_memory(maybe_pinned_memory.value()); + result = result.pinned_memory(maybe_pinned_memory); } return result; } diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index cb6615c8aca625..f413782b2d3013 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -11,6 +11,7 @@ #include PyObject* guard_error_hook = NULL; +PyObject* guard_complete_hook = NULL; typedef struct { int active_dynamo_threads; @@ -626,6 +627,22 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { Py_RETURN_NONE; } +static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) { + PyObject* old_hook = guard_complete_hook; + + if (obj == Py_None) { + obj = NULL; + } + + guard_complete_hook = Py_XNewRef(obj); + + if (old_hook == NULL) { + Py_RETURN_NONE; + } else { + return old_hook; + } +} + // Debugging function for GNU C only. // Used to set gdb breakpoints in hot CPython sites from Python. // Code example: @@ -666,6 +683,7 @@ static PyMethodDef _methods[] = { {"unsupported", unsupported, METH_VARARGS, NULL}, {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, + {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL}, {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL}, {NULL, NULL, 0, NULL}}; diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index fdc34aa58866bd..e05de24259e0b6 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -7,6 +7,10 @@ #include #include +extern "C" { +extern PyObject* guard_complete_hook; +} + static constexpr const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; @@ -197,7 +201,23 @@ PyObject* dynamo__custom_eval_frame( // guard eval failed, keep propagating fail(); return eval_result; - } else if (maybe_cached_code != Py_None) { + } + + // NB: We only do guard collectives when there are any compiled code entries + // at all; these reduces overtriggering and we don't need to do guard + // collectives the very first time we've seen a frame + // TODO: We could also check if we had just created extra for the first + // time? Not too sure the best condition for extra->cache_entry_list + if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) { + py::handle guard_complete_hook_handle(guard_complete_hook); + // False means force compilation (someone cache missed) + py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None); + if (!py::cast(res)) { + maybe_cached_code = Py_None; // NB: non-owning + } + } + + if (maybe_cached_code != Py_None) { cached_code = (PyCodeObject*)maybe_cached_code; // used cached version DEBUG_TRACE("cache hit %s", get_frame_name(frame)); @@ -274,7 +294,7 @@ PyObject* dynamo__custom_eval_frame( // NB: We could use extract_cache_entry to get the cache_entry, but // extract_cache_entry returns a borrowed reference. Modifying a borrowed // reference seems wrong. Therefore, we directly access the - // extra->cache_entry. extra wont be NULL here. + // extra->cache_entry. extra won't be NULL here. CacheEntry* new_cache_entry = create_cache_entry(extra, guarded_code, backend); diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 2e60816aa2dfad..ad617a8de5b090 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -121,12 +121,13 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) { static bool backend_match(PyObject* saved_backend, PyObject* backend) { // Pointer equality check for common case if (saved_backend != backend) { - // The Py_TYPE check should not be required but there is a pre-existing - // issue where backend is possibly deallocated (or nullptr) and causes - // segfaults. Check test - test_inplace_custom_op_intermediate - return ( - Py_TYPE(saved_backend) == Py_TYPE(backend) && - PyObject_RichCompareBool(saved_backend, backend, Py_EQ)); + int result = PyObject_RichCompareBool(saved_backend, backend, Py_EQ); + // Check for exception + if (result == -1) { + PyErr_Clear(); + return false; + } + return (result == 1); } return true; } @@ -140,11 +141,19 @@ void lookup( bool is_skip_guard_eval_unsafe) { size_t index = 0; CacheEntry* found = nullptr; + + for (const auto& entry : extra_state->precompile_entries) { + if (torch::dynamo::run_root_guard_manager(entry.root_mgr, f_locals)) { + *maybe_cached_code = entry.code.ptr(); + return; + } + } + for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. - bool valid = - backend == Py_False || backend_match(cache_entry.backend, backend); + bool valid = backend == Py_False || + backend_match(cache_entry.backend.ptr(), backend); if (valid) { try { @@ -220,3 +229,57 @@ py::list _debug_get_cache_entry_list(const py::handle& code_obj) { } return result; } + +PrecompileEntry::PrecompileEntry(py::object gm, py::object c) + : guard_manager(std::move(gm)), code(std::move(c)) { + if (!PyCode_Check(code.ptr())) { + throw std::runtime_error("Expecting CodeType from PrecompileEntry."); + } + root_mgr = + torch::dynamo::convert_to_root_guard_manager(guard_manager.attr("root")); +} + +void _reset_precompile_entries(const py::handle& code_obj) { + if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { + throw py::type_error("expected a code object!"); + } + PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); + ExtraState* extra = get_extra_state(code); + py::list result; + if (extra != nullptr) { + extra->precompile_entries.clear(); + } +} + +void _load_precompile_entry( + const py::handle& code_obj, + py::object guard_manager, + py::object dynamo_code) { + if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { + throw py::type_error("expected a code object!"); + } + PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); + ExtraState* extra = get_extra_state(code); + py::list result; + if (extra == nullptr) { + extra = init_and_set_extra_state(code); + } + auto entry = + PrecompileEntry(std::move(guard_manager), std::move(dynamo_code)); + extra->precompile_entries.push_back(std::move(entry)); +} + +py::list _debug_get_precompile_entries(const py::handle& code_obj) { + if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { + throw py::type_error("expected a code object!"); + } + PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); + ExtraState* extra = get_extra_state(code); + py::list result; + if (extra != nullptr) { + for (PrecompileEntry& e : extra->precompile_entries) { + result.append(py::cast(e, py::return_value_policy::reference)); + } + } + return result; +} diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 288d6cd3e5cfab..1630ac90b21dd6 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -47,10 +47,19 @@ typedef struct CacheEntry CacheEntry; #ifdef __cplusplus +typedef struct VISIBILITY_HIDDEN PrecompileEntry { + py::object guard_manager; + py::object code; + void* root_mgr; + + PrecompileEntry(py::object gm, py::object c); +} PrecompileEntry; + typedef struct VISIBILITY_HIDDEN ExtraState { // A pointer to the orig_code object to prevent race conditions in invalidate // function. PyCodeObject* orig_code; + std::list precompile_entries; // List of cache entries for compiled code objects std::list cache_entry_list; // Frame state to detect dynamic shape dims @@ -68,6 +77,7 @@ typedef struct VISIBILITY_HIDDEN ExtraState { #else typedef struct ExtraState ExtraState; +typedef struct PrecompileEntry PrecompileEntry; #endif @@ -122,7 +132,7 @@ void destroy_extra_state(void* obj); // Clears the existing object sitting on the extra scratch spance and sets it // up with the new state. Note that _PyCode_SetExtra calls the // destroy_extra_state deleter internally, and therefore we don't call it -// explicity here. +// explicitly here. // Ownership contract // args @@ -138,7 +148,7 @@ void destroy_extra_state(void* obj); // scratch space. void set_extra_state(PyCodeObject* code, ExtraState* extra_state); -// Creates a new extra state and put it on the extra scrach space of the code +// Creates a new extra state and put it on the extra scratch space of the code // object. // Ownership contract @@ -187,5 +197,11 @@ PyObject* get_backend(PyObject* callback); // Returns the list of CacheEntry corresponding to code_obj. // Warning: returns references whose lifetimes are controlled by C++ py::list _debug_get_cache_entry_list(const py::handle& code_obj); +void _reset_precompile_entries(const py::handle& code_obj); +void _load_precompile_entry( + const py::handle& code_obj, + py::object guard_manager, + py::object dynamo_code); +py::list _debug_get_precompile_entries(const py::handle& code_obj); #endif diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index e18a69a0fe82c5..83fb0adbe6c9a6 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,7 @@ typedef struct { PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */ } _PyTupleIterObject; -// Copied from CPython, and given a unified name for different Python verions. +// Copied from CPython, and given a unified name for different Python versions. // https://github.com/python/cpython/blob/7f71003b222ad398713514c2b55d34dc05dba6bc/Objects/rangeobject.c#L765-L771 typedef struct { PyObject_HEAD @@ -124,7 +125,7 @@ TensorCheck::TensorCheck( // See note in guards.py [Note - On Export Tensor Guards] // Logic parallel to here must be maintained in python bool TensorCheck::check(const LocalState& state, const at::Tensor& v) { - // In terms of a sparse_csr tensor, it does not support strides informatio + // In terms of a sparse_csr tensor, it does not support strides information c10::SymIntArrayRef sym_strides(std::vector(v.ndimension(), -1)); bool does_not_support_stride = v.layout() == c10::kSparseCsr || v.layout() == c10::kSparseCsc || v.layout() == c10::kSparseBsc || @@ -590,7 +591,7 @@ struct GlobalStateGuard { _torch_function_all_disabled = at::impl::torch_function_all_disabled(); _deterministic_algorithms = ctx.deterministicAlgorithms(); _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly(); - _allow_tf32 = ctx.allowTF32CuBLAS(); + _allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32"; _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS(); _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS(); _num_threads = at::get_num_threads(); @@ -607,7 +608,7 @@ struct GlobalStateGuard { _deterministic_algorithms == ctx.deterministicAlgorithms() && _deterministic_algorithms_warn_only == ctx.deterministicAlgorithmsWarnOnly() && - _allow_tf32 == ctx.allowTF32CuBLAS() && + _allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") && _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() && _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() && _num_threads == at::get_num_threads()) && @@ -628,7 +629,7 @@ struct GlobalStateGuard { if (_deterministic_algorithms_warn_only != ctx.deterministicAlgorithmsWarnOnly()) os << "deterministic_algorithms_warn_only "; - if (_allow_tf32 != ctx.allowTF32CuBLAS()) + if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32")) os << "allow_tf32 "; if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS()) os << "allow_fp16_reduce "; @@ -744,11 +745,11 @@ static PyMethodDef GlobalStateGuard_methods[] = { (PyCFunction)(void*)GlobalStateGuard_reason, METH_NOARGS, "Return string reason for guard check failing"}, - {"dump", + {"__getstate__", (PyCFunction)(void*)GlobalStateGuard_dump, METH_NOARGS, "Return serialized json format"}, - {"load", + {"__setstate__", (PyCFunction)(void*)GlobalStateGuard_load, METH_VARARGS, "Parse serialized json format"}, @@ -844,21 +845,38 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { PyObject* item = nullptr; PyObject* size = nullptr; PyObject* stride = nullptr; - if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) { + const char* op_name = nullptr; + + if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) { return nullptr; } if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { - PyErr_SetString(PyExc_TypeError, "expected Tensor()"); + std::stringstream msg; + msg << "expected Tensor()"; + if (op_name) { + msg << " for op: " << op_name; + } + PyErr_SetString(PyExc_TypeError, msg.str().c_str()); return nullptr; } if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) { - PyErr_SetString(PyExc_TypeError, "expected tuple()"); + std::stringstream msg; + msg << "expected tuple()"; + if (op_name) { + msg << " for op: " << op_name; + } + PyErr_SetString(PyExc_TypeError, msg.str().c_str()); return nullptr; } at::Tensor tensor = THPVariable_Unpack(item); int64_t ndim = tensor.ndimension(); if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) { - PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions"); + std::stringstream msg; + msg << "wrong number of dimensions" << ndim; + if (op_name) { + msg << " for op: " << op_name; + } + PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); return nullptr; } @@ -887,6 +905,9 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { } if (num_errors) { + if (op_name) { + msg << "\nError in op: " << op_name; + } msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op."; msg << "\nUse torch.library.opcheck to test your custom op."; msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck"; @@ -904,15 +925,27 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) { */ PyObject* item = nullptr; unsigned long alignment = 0; - if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) { + const char* op_name = nullptr; + + if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) { return nullptr; } if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { - PyErr_SetString(PyExc_TypeError, "expected Tensor()"); + std::stringstream msg; + msg << "expected Tensor()"; + if (op_name) { + msg << " for op: " << op_name; + } + PyErr_SetString(PyExc_TypeError, msg.str().c_str()); return nullptr; } if (alignment == 0) { - PyErr_SetString(PyExc_AssertionError, "alignment can not be 0"); + std::stringstream msg; + msg << "alignment cannot be 0"; + if (op_name) { + msg << " in op: " << op_name; + } + PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); return nullptr; } @@ -922,7 +955,10 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) { size_t itemsize = tensor.itemsize(); if (storage_offset * itemsize % alignment != 0) { std::stringstream msg; - msg << "Expect the tensor to be " << alignment + if (op_name) { + msg << "\nError in op: " << op_name; + } + msg << "\nExpect the tensor to be " << alignment << " bytes aligned. Fail due to storage_offset=" << storage_offset << " itemsize=" << itemsize; PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); @@ -1377,7 +1413,7 @@ class StorageOverlapChecker { */ std::vector _tensors_from( const std::vector& objects, - int64_t size) { + size_t size) { std::vector tensors; tensors.reserve(size); std::transform( @@ -1457,12 +1493,6 @@ class DictGuardManager; */ class LeafGuard { public: - // Most guards do not need root guard manager. - LeafGuard(py::object verbose_code_parts) - : _verbose_code_parts(std::move(verbose_code_parts)) {} - - // Guards like TENSOR_MATCH require root_guard_manager to access local_state - // shared across all leaf guards. LeafGuard(RootGuardManager* root_guard_manager, py::object verbose_code_parts) : _root_guard_manager(root_guard_manager), _verbose_code_parts(std::move(verbose_code_parts)) {} @@ -1524,8 +1554,11 @@ class LeafGuard { */ class LAMBDA_GUARD : public LeafGuard { public: - LAMBDA_GUARD(py::object guard_check_fn, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { + LAMBDA_GUARD( + RootGuardManager* root_guard_manager, + py::object guard_check_fn, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { if (py::isinstance(guard_check_fn)) { _guard_check_fn = py::cast(std::move(guard_check_fn)); } else { @@ -1570,8 +1603,11 @@ class LAMBDA_GUARD : public LeafGuard { class TYPE_MATCH : public LeafGuard { public: // type_id = id(type(obj)) - TYPE_MATCH(py::object type_id, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + TYPE_MATCH( + RootGuardManager* root_guard_manager, + py::object type_id, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _expected(py::cast(std::move(type_id))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1587,8 +1623,11 @@ class TYPE_MATCH : public LeafGuard { class ID_MATCH : public LeafGuard { public: // obj_id = id(obj) - ID_MATCH(py::object obj_id, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + ID_MATCH( + RootGuardManager* root_guard_manager, + py::object obj_id, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _expected(py::cast(std::move(obj_id))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1603,8 +1642,10 @@ class ID_MATCH : public LeafGuard { class NONE_MATCH : public LeafGuard { public: - NONE_MATCH(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) {} + NONE_MATCH( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_None; @@ -1613,8 +1654,10 @@ class NONE_MATCH : public LeafGuard { class TRUE_MATCH : public LeafGuard { public: - TRUE_MATCH(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) {} + TRUE_MATCH( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_True; @@ -1623,8 +1666,10 @@ class TRUE_MATCH : public LeafGuard { class FALSE_MATCH : public LeafGuard { public: - FALSE_MATCH(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) {} + FALSE_MATCH( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_False; @@ -1633,8 +1678,11 @@ class FALSE_MATCH : public LeafGuard { class EQUALS_MATCH : public LeafGuard { public: - EQUALS_MATCH(py::object value, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + EQUALS_MATCH( + RootGuardManager* root_guard_manager, + py::object value, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _value(value), _value_type(Py_TYPE(value.ptr())) {} @@ -1671,12 +1719,13 @@ class EQUALS_MATCH : public LeafGuard { class RANGE_ITERATOR_MATCH : public LeafGuard { public: RANGE_ITERATOR_MATCH( + RootGuardManager* root_guard_manager, py::object start, py::object stop, py::object step, py::object type_id, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _type_id(py::cast(std::move(type_id))) { PyObject* start_obj = start.ptr(); PyObject* stop_obj = stop.ptr(); @@ -1717,10 +1766,11 @@ class RANGE_ITERATOR_MATCH : public LeafGuard { class TUPLE_ITERATOR_LEN : public LeafGuard { public: TUPLE_ITERATOR_LEN( + RootGuardManager* root_guard_manager, py::object length, py::object type_id, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _length(py::cast(std::move(length))), _type_id(py::cast(std::move(type_id))) {} @@ -1745,8 +1795,11 @@ class TUPLE_ITERATOR_LEN : public LeafGuard { class LENGTH_CHECK : public LeafGuard { public: - LENGTH_CHECK(py::object value, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + LENGTH_CHECK( + RootGuardManager* root_guard_manager, + py::object value, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _length(py::cast(std::move(value))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1762,8 +1815,11 @@ class LENGTH_CHECK : public LeafGuard { class DICT_LENGTH : public LeafGuard { public: - DICT_LENGTH(py::object value, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + DICT_LENGTH( + RootGuardManager* root_guard_manager, + py::object value, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _length(py::cast(std::move(value))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1777,8 +1833,8 @@ class DICT_LENGTH : public LeafGuard { class NOT_NONE : public LeafGuard { public: - NOT_NONE(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) {} + NOT_NONE(RootGuardManager* root_guard_manager, py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} bool check_nopybind(PyObject* value) override { // borrowed ref return value != Py_None; @@ -1787,8 +1843,11 @@ class NOT_NONE : public LeafGuard { class MAPPING_KEYS_MATCH : public LeafGuard { public: - MAPPING_KEYS_MATCH(py::object value, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { + MAPPING_KEYS_MATCH( + RootGuardManager* root_guard_manager, + py::object value, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { // This is ok to stash in the state because we only support // MappingProxyType objects with constant keys. So, the mem overhead is // negligible. @@ -1808,8 +1867,10 @@ class MAPPING_KEYS_MATCH : public LeafGuard { class DEFAULT_DEVICE : public LeafGuard { public: - DEFAULT_DEVICE(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { + DEFAULT_DEVICE( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { py::handle device_module = py::module::import("torch.utils._device"); // Save the dict using py::object _utils_device_dict = device_module.attr("__dict__"); @@ -1853,10 +1914,25 @@ class DEFAULT_DEVICE : public LeafGuard { class GLOBAL_STATE : public LeafGuard { public: - GLOBAL_STATE(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { - _guard = std::make_unique(); + GLOBAL_STATE( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), + _guard(PyObject_New(GlobalStateGuard, &GlobalStateGuardType)) { _guard->init(); + owner_ = py::reinterpret_steal((PyObject*)_guard); + } + + GLOBAL_STATE( + RootGuardManager* root, + py::object initial_state, + py::object verbose_code_parts) + : LeafGuard(root, std::move(verbose_code_parts)), + owner_(std::move(initial_state)), + _guard((GlobalStateGuard*)owner_.ptr()) { + if (!PyObject_TypeCheck(owner_.ptr(), &GlobalStateGuardType)) { + throw py::type_error("GLOBAL_STATE expects a GlobalStateGuard"); + } } bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1878,7 +1954,8 @@ class GLOBAL_STATE : public LeafGuard { } private: - std::unique_ptr _guard; + py::object owner_; + GlobalStateGuard* _guard; }; // Checks that an attr is absent in the object. We don't need the opposite @@ -1886,8 +1963,11 @@ class GLOBAL_STATE : public LeafGuard { // HASATTR guard. class NO_HASATTR : public LeafGuard { public: - NO_HASATTR(py::object attr_name, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + NO_HASATTR( + RootGuardManager* root_guard_manager, + py::object attr_name, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _attr_name(std::move(attr_name)) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1905,8 +1985,12 @@ class NO_HASATTR : public LeafGuard { // being faster. class DICT_CONTAINS : public LeafGuard { public: - DICT_CONTAINS(bool contains, py::object key, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + DICT_CONTAINS( + RootGuardManager* root_guard_manager, + bool contains, + py::object key, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _contains(contains ? 1 : 0), _key(std::move(key)) {} @@ -1924,6 +2008,33 @@ class DICT_CONTAINS : public LeafGuard { py::object _key; }; +// Check that set contains an item. +class SET_CONTAINS : public LeafGuard { + public: + SET_CONTAINS( + RootGuardManager* root_guard_manager, + bool contains, + py::object item, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), + _contains(contains ? 1 : 0), + _item(std::move(item)) {} + + bool check_nopybind(PyObject* value) override { // borrowed ref + int result = (PySet_Check(value) || PyFrozenSet_Check(value)) && + PySet_Contains(value, _item.ptr()); + if (result == -1) { + PyErr_Clear(); + return false; + } + return result == _contains; + } + + private: + int _contains; + py::object _item; +}; + /** * Relational guards compare more than one value. We implement Relational * guards by capturing some state in the guard object. For example for tensor @@ -1941,8 +2052,10 @@ class DICT_CONTAINS : public LeafGuard { */ class RelationalGuard : public LeafGuard { public: - RelationalGuard(py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) {} + RelationalGuard( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} // reset the relational guard state on guard failure. This is called by the // guard manager. @@ -1954,8 +2067,10 @@ class RelationalGuard : public LeafGuard { */ class OBJECT_ALIASING : public RelationalGuard { public: - OBJECT_ALIASING(py::object verbose_code_parts) - : RelationalGuard(std::move(verbose_code_parts)) {} + OBJECT_ALIASING( + RootGuardManager* root_guard_manager, + py::object verbose_code_parts) + : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)) {} bool check_nopybind(PyObject* value) override { // borrowed ref if (_is_first_call) { @@ -1981,9 +2096,10 @@ class OBJECT_ALIASING : public RelationalGuard { class NO_TENSOR_ALIASING : public RelationalGuard { public: NO_TENSOR_ALIASING( + RootGuardManager* root_guard_manager, const py::list& tensor_names, py::object verbose_code_parts) - : RelationalGuard(std::move(verbose_code_parts)), + : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), _tensor_names(tensor_names) { _unique_tensors.reserve(tensor_names.size()); } @@ -2031,10 +2147,11 @@ class NO_TENSOR_ALIASING : public RelationalGuard { class STORAGE_OVERLAPPING : public RelationalGuard { public: STORAGE_OVERLAPPING( + RootGuardManager* root_guard_manager, bool overlapping, std::shared_ptr checker, py::object verbose_code_parts) - : RelationalGuard(std::move(verbose_code_parts)), + : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), _overlapping(overlapping), _checker(std::move(checker)) {} @@ -2062,12 +2179,13 @@ class STORAGE_OVERLAPPING : public RelationalGuard { class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { public: SYMBOLIC_SHAPE_GUARD( + RootGuardManager* root_guard_manager, py::int_ nargs_int, py::int_ nargs_float, py::int_ py_addr, py::object py_addr_keep_alive, py::object verbose_code_parts) - : RelationalGuard(std::move(verbose_code_parts)), + : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), _py_addr_keep_alive(std::move(py_addr_keep_alive)) { _nargs_int = PyLong_AsSize_t(nargs_int.ptr()); _nargs_float = PyLong_AsSize_t(nargs_float.ptr()); @@ -2175,10 +2293,12 @@ class DYNAMIC_INDICES : public LeafGuard { // f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices})) // if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # // noqa: B950 - // ) public: - DYNAMIC_INDICES(py::set dynamic_indices, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), + DYNAMIC_INDICES( + RootGuardManager* root_guard_manager, + py::set dynamic_indices, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), _dynamic_indices(std::move(dynamic_indices)) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -2208,8 +2328,11 @@ class DYNAMIC_INDICES : public LeafGuard { class DICT_VERSION : public LeafGuard { public: - DICT_VERSION(py::object value, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { + DICT_VERSION( + RootGuardManager* root_guard_manager, + py::object value, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { if (!PyDict_Check(value.ptr())) { throw py::type_error("DICT_VERSION expects a dict"); } @@ -2372,7 +2495,7 @@ class GuardAccessor { * value passed to the check function to call the check function of the child * guard manager. * - * Performace optimization for fail fast - An optimization for runtime here is + * Performance optimization for fail fast - An optimization for runtime here is * to sort the execution of child guards depending on the failure count. This * ensures that we run the guards that are more prone to fail statistically * first. This can improve the cache lookup time when we have multiple cache @@ -2796,7 +2919,7 @@ class RootGuardManager : public GuardManager { template bool check_nopybind_template(T* value) { // borrowed ref // Check [Note on GIL interaction with mutex lock] for details on why we - // need mutex and its interactions wth GIL. + // need mutex and its interactions with GIL. PyThreadState* _save = nullptr; Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting std::lock_guard lock_guard(_lock); @@ -2854,7 +2977,7 @@ class RootGuardManager : public GuardManager { GuardDebugInfo check_verbose_nopybind( PyObject* value) override { // borrowed ref // Check [Note on GIL interaction with mutex lock] for details on why we - // need mutex and its interactions wth GIL. + // need mutex and its interactions with GIL. PyThreadState* _save = nullptr; Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting std::lock_guard lock_guard(_lock); @@ -2957,7 +3080,7 @@ class RootGuardManager : public GuardManager { LocalState _local_state; private: - // All the relational guards under this guard mananger. We only use these + // All the relational guards under this guard manager. We only use these // when the guard evaluates to False. This ensures that guard state is reset // on guard failure so that next invocation is clean. std::vector> _relational_guard_resetters; @@ -3378,9 +3501,10 @@ std::unique_ptr make_guard_manager( class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( + RootGuardManager* root_guard_manager, const py::list& initial_stack, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)) { + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref @@ -3540,7 +3664,7 @@ class TENSOR_MATCH : public LeafGuard { }; /** - * Represents __getattr__ acccessor. + * Represents __getattr__ accessor. */ class GetAttrGuardAccessor : public GuardAccessor { public: @@ -3588,7 +3712,7 @@ class GetAttrGuardAccessor : public GuardAccessor { } std::string repr() const override { - // Helpful when priting GuardManager tree structure. + // Helpful when printing GuardManager tree structure. return "GetAttrGuardAccessor(" + py::str(_attr_name).cast() + ")"; } @@ -3616,7 +3740,7 @@ class GetAttrGuardAccessor : public GuardAccessor { }; /** - * Represents object.__getattribute__(obj, attr_name) acccessor. + * Represents object.__getattribute__(obj, attr_name) accessor. */ class GenericGetAttrGuardAccessor : public GuardAccessor { public: @@ -3664,7 +3788,7 @@ class GenericGetAttrGuardAccessor : public GuardAccessor { } std::string repr() const override { - // Helpful when priting GuardManager tree structure. + // Helpful when printing GuardManager tree structure. return "GenericGetAttrGuardAccessor(" + py::str(_attr_name).cast() + ")"; } @@ -3695,7 +3819,7 @@ class GenericGetAttrGuardAccessor : public GuardAccessor { }; /** - * Represents x.__dict__ acccessor. + * Represents x.__dict__ accessor. */ class GetGenericDictGuardAccessor : public GuardAccessor { public: @@ -3742,7 +3866,7 @@ class GetGenericDictGuardAccessor : public GuardAccessor { } std::string repr() const override { - // Helpful when priting GuardManager tree structure. + // Helpful when printing GuardManager tree structure. return "GetGenericDictGuardAccessor"; } @@ -3763,7 +3887,7 @@ class GetGenericDictGuardAccessor : public GuardAccessor { }; /** - * Represents __getitem__ acccessor. + * Represents __getitem__ accessor. */ class GetItemGuardAccessor : public GuardAccessor { public: @@ -3960,7 +4084,7 @@ class FrameLocalsGuardAccessor : public GuardAccessor { }; /** - * Represents dict[name] acccessor. Needed since DictGuardManager does not + * Represents dict[name] accessor. Needed since DictGuardManager does not * support sorting. We differentiate it from GetItemGuardAccessor because * PyDict_GetItem should be faster than PyObject_GetItem. */ @@ -3988,7 +4112,7 @@ class DictGetItemGuardAccessor : public GuardAccessor { _guard_manager->has_no_accessors()) { // immutable object and dict tag matches, we can skip the guard subtree. // NB: We only skip the subtree if there are no accessors in the subtree. - // This is specificallly for tensors which are used in symbolic shape C++ + // This is specifically for tensors which are used in symbolic shape C++ // guards, and therefore have accessors on the tensor GuardManager itself. return true; } @@ -4117,6 +4241,82 @@ class ListGetItemGuardAccessor : public GuardAccessor { Py_ssize_t _index{-1}; }; +/** + * Represents set[index] accessor by converting the set into a list. + */ +class SetGetItemGuardAccessor : public GuardAccessor { + public: + SetGetItemGuardAccessor( + RootGuardManager* root, + const py::object& index, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + index, + std::move(source), + example_value, + guard_manager_enum), + _index(py::cast(index)) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + + PyObject* lst = PySequence_List(obj); + PyObject* x = PyList_GetItem(lst, _index); // borrowed ref + Py_XDECREF(lst); + if (x == nullptr) { + PyErr_Clear(); + return false; + } + bool result = _guard_manager->check_nopybind(x); + return result; + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + + PyObject* lst = PySequence_List(obj); + PyObject* x = PyList_GetItem(lst, _index); // borrowed ref + Py_XDECREF(lst); + + if (x == nullptr) { + PyErr_Clear(); + return GuardDebugInfo(false, 0); + } + GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x); + return result; + } + + std::string repr() const override { + return fmt::format("SetGetItemGuardAccessor(index={})", _index); + } + + public: // cloning functions + SetGetItemGuardAccessor( + GuardManager* guard_manager, + SetGetItemGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(SetGetItemGuardAccessor* to) { + to->_index = _index; + } + + private: + Py_ssize_t _index{-1}; +}; + /** * Represents tuple[index] accessor. It is faster than generic * GetItemGuardAccessor. @@ -4209,7 +4409,7 @@ std::string to_string(TensorProperty prop) { } /** - * Represents tensor.size/shape/storage_offset acccessor. + * Represents tensor.size/shape/storage_offset accessor. */ template class TensorPropertyGuardAccessor : public GuardAccessor { @@ -4307,7 +4507,7 @@ class TensorPropertyGuardAccessor : public GuardAccessor { } std::string repr() const override { - // Helpful when priting GuardManager tree structure. + // Helpful when printing GuardManager tree structure. return "TensorPropertyGuardAccessor<" + to_string(_prop) + +">(" + std::to_string(_index) + ")"; } @@ -4399,7 +4599,7 @@ class IndexedGuardAccessor : public GuardAccessor { }; /** - * Represents tensor.grad acccessor. + * Represents tensor.grad accessor. */ class GradGuardAccessor : public GuardAccessor { public: @@ -4450,7 +4650,7 @@ class GradGuardAccessor : public GuardAccessor { } std::string repr() const override { - // Helpful when priting GuardManager tree structure. + // Helpful when printing GuardManager tree structure. return "GradGuardAccessor(grad)"; } @@ -4619,7 +4819,7 @@ class FuncKwDefaultsGuardAccessor : public GuardAccessor { }; /** - * Represents f_globals acccessor. This sits as a child accessor of the + * Represents f_globals accessor. This sits as a child accessor of the * RootGuardManager. */ class GlobalsGuardAccessor : public GuardAccessor { @@ -4812,7 +5012,7 @@ class TupleIteratorGetItemAccessor : public GuardAccessor { * GlobalWeakRef accessor. Dynamo can insert a weakref object into the frame * globals. This accessor reads the globals and then calls the weakref object * to get the underlying object. This is a child of GlobalsGuardAccessor. - * Therefore, we will get the globals dict while caling check_nopybind. + * Therefore, we will get the globals dict while calling check_nopybind. */ class GlobalWeakRefGuardAccessor : public GuardAccessor { public: @@ -5169,10 +5369,10 @@ void install_object_aliasing_guard( py::object verbose_code_parts) { // Adds tensor X is tensor Y guard. This is a an example of relational guard. // There is one guard object that is shared between two guard managers. - std::shared_ptr guard = - std::make_shared(std::move(verbose_code_parts)); + std::shared_ptr guard = std::make_shared( + x->get_root(), std::move(verbose_code_parts)); - // Register the resetter on the root guard mananger, so that it can reset + // Register the resetter on the root guard manager, so that it can reset // the newly added relational guard when the guard eval fails. x->get_root()->add_relational_guard_resetter(guard); @@ -5190,9 +5390,11 @@ void install_no_tensor_aliasing_guard( // relational guard. There is one guard object that is shared between multiple // guard managers. std::shared_ptr guard = std::make_shared( - tensor_names, std::move(verbose_code_parts)); + py::cast(guard_managers[0])->get_root(), + tensor_names, + std::move(verbose_code_parts)); - // Register the resetter on the root guard mananger, so that it can reset + // Register the resetter on the root guard manager, so that it can reset // the newly added relational guard when the guard eval fails. py::cast(guard_managers[0]) ->get_root() @@ -5214,13 +5416,14 @@ void install_symbolic_shape_guard( // multiple guard managers. std::shared_ptr guard = std::make_shared( + py::cast(guard_managers[0])->get_root(), std::move(nargs_int), std::move(nargs_float), std::move(py_addr), std::move(py_addr_keep_alive), std::move(verbose_code_parts)); - // Register the resetter on the root guard mananger, so that it can reset + // Register the resetter on the root guard manager, so that it can reset // the newly added relational guard when the guard eval fails. py::cast(guard_managers[0]) ->get_root() @@ -5243,7 +5446,10 @@ void install_storage_overlapping_guard_with_checker( std::shared_ptr guard = std::make_shared( - overlapping, checker, verbose_code_parts); + py::cast(guard_managers[0])->get_root(), + overlapping, + checker, + verbose_code_parts); py::cast(guard_managers[0]) ->get_root() ->add_relational_guard_resetter(guard); @@ -5276,23 +5482,40 @@ void install_storage_overlapping_guard( /* overlapping= */ false); } +char flush_cache_by_eviction() { + constexpr size_t evict_size = 32 * 1024 * 1024; + std::vector buffer(evict_size, 1); + + volatile char sink = 0; + for (size_t i = 0; i < buffer.size(); i += 64) { + sink ^= buffer[i]; + } + return sink; +} + double profile_guard_manager( RootGuardManager* root, py::object f_locals, int n_iters) { PyObject* locals = f_locals.ptr(); - // Warmup + // Warmup to setup fast paths (like dict_tags) for the actual profiling for (int i = 0; i < 5; i++) { root->check_nopybind(locals); } - auto start = std::chrono::high_resolution_clock::now(); + std::chrono::duration total_elapsed{0.0}; for (int i = 0; i < n_iters; i++) { + // Flush the caches to accurately measure the overhead + // store into a volatile to prevent optimization + volatile char dummy = flush_cache_by_eviction(); + (void)dummy; + + auto start = std::chrono::high_resolution_clock::now(); root->check_nopybind(locals); + auto end = std::chrono::high_resolution_clock::now(); + total_elapsed += end - start; } - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration total_elapsed = end - start; // Calculate the average time per iteration in microseconds return (total_elapsed.count() * 1e6) / n_iters; @@ -5404,67 +5627,73 @@ PyObject* torch_c_dynamo_guards_init() { .def("verbose_code_parts", &LeafGuard::verbose_code_parts); py::class_>( py_m, "LAMBDA_GUARD") - .def(py::init()) + .def(py::init()) .def("__call__", &LAMBDA_GUARD::check); py::class_>( py_m, "TYPE_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &TYPE_MATCH::check); py::class_>(py_m, "ID_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &ID_MATCH::check); py::class_>( py_m, "NONE_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &NONE_MATCH::check); py::class_>( py_m, "TRUE_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &TRUE_MATCH::check); py::class_>( py_m, "FALSE_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &FALSE_MATCH::check); py::class_>( py_m, "EQUALS_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &EQUALS_MATCH::check); py::class_>( py_m, "LENGTH_CHECK") - .def(py::init()) + .def(py::init()) .def("__call__", &LENGTH_CHECK::check); py::class_>( py_m, "DICT_LENGTH") - .def(py::init()) + .def(py::init()) .def("__call__", &DICT_LENGTH::check); py::class_>( py_m, "DEFAULT_DEVICE") - .def(py::init()) + .def(py::init()) .def("__call__", &DEFAULT_DEVICE::check); py::class_>(py_m, "NOT_NONE") - .def(py::init()) + .def(py::init()) .def("__call__", &NOT_NONE::check); py::class_< MAPPING_KEYS_MATCH, LeafGuard, std::shared_ptr>(py_m, "MAPPING_KEYS_MATCH") - .def(py::init()) + .def(py::init()) .def("__call__", &MAPPING_KEYS_MATCH::check); py::class_< TUPLE_ITERATOR_LEN, LeafGuard, std::shared_ptr>(py_m, "TUPLE_ITERATOR_LEN") - .def(py::init()) + .def(py::init()) .def("__call__", &TUPLE_ITERATOR_LEN::check); py::class_< RANGE_ITERATOR_MATCH, LeafGuard, std::shared_ptr>(py_m, "RANGE_ITERATOR_MATCH") - .def(py::init()) + .def(py::init< + RootGuardManager*, + py::object, + py::object, + py::object, + py::object, + py::list>()) .def("__call__", &RANGE_ITERATOR_MATCH::check); py::class_>( py_m, "GLOBAL_STATE") - .def(py::init()) + .def(py::init()) .def("check_verbose", &GLOBAL_STATE::check_verbose) .def("__call__", &GLOBAL_STATE::check); py::class_< @@ -5472,23 +5701,27 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "NO_HASATTR") - .def(py::init()) + .def(py::init()) .def("__call__", &NO_HASATTR::check); py::class_>( py_m, "DICT_CONTAINS") - .def(py::init()) + .def(py::init()) .def("__call__", &DICT_CONTAINS::check); + py::class_>( + py_m, "SET_CONTAINS") + .def(py::init()) + .def("__call__", &SET_CONTAINS::check); py::class_>( py_m, "DYNAMIC_INDICES") - .def(py::init()) + .def(py::init()) .def("__call__", &DYNAMIC_INDICES::check); py::class_>( py_m, "DICT_VERSION") - .def(py::init()) + .def(py::init()) .def("__call__", &DICT_VERSION::check); py::class_< DISPATCH_KEY_SET_MATCH, @@ -5659,7 +5892,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object lambda, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - std::move(lambda), std::move(verbose_code_parts))); + self.get_root(), + std::move(lambda), + std::move(verbose_code_parts))); }) .def( "add_type_match_guard", @@ -5668,7 +5903,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TYPE_MATCH"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_id_match_guard", @@ -5677,28 +5914,30 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("ID_MATCH"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_none_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("NONE_MATCH"); - self.add_leaf_guard( - std::make_shared(std::move(verbose_code_parts))); + self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(verbose_code_parts))); }) .def( "add_true_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TRUE_MATCH"); - self.add_leaf_guard( - std::make_shared(std::move(verbose_code_parts))); + self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(verbose_code_parts))); }) .def( "add_false_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("FALSE_MATCH"); - self.add_leaf_guard( - std::make_shared(std::move(verbose_code_parts))); + self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(verbose_code_parts))); }) .def( "add_equals_match_guard", @@ -5707,7 +5946,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("EQUALS_MATCH"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_length_check_guard", @@ -5716,7 +5957,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("LENGTH_CHECK"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_dict_length_check_guard", @@ -5725,7 +5968,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_LENGTH"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_tuple_iterator_length_guard", @@ -5735,6 +5980,7 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TUPLE_ITERATOR_LEN"); self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(length), std::move(type_id), std::move(verbose_code_parts))); @@ -5749,6 +5995,7 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("RANGE_ITERATOR_MATCH"); self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(start), std::move(stop), std::move(step), @@ -5759,14 +6006,14 @@ PyObject* torch_c_dynamo_guards_init() { "add_default_device_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - std::move(verbose_code_parts))); + self.get_root(), std::move(verbose_code_parts))); }) .def( "add_not_none_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("NOT_NONE"); - self.add_leaf_guard( - std::make_shared(std::move(verbose_code_parts))); + self.add_leaf_guard(std::make_shared( + self.get_root(), std::move(verbose_code_parts))); }) .def( "add_mapping_keys_guard", @@ -5775,7 +6022,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("MAPPING_KEYS_MATCH"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_dispatch_key_set_guard", @@ -5790,9 +6039,13 @@ PyObject* torch_c_dynamo_guards_init() { }) .def( "add_global_state_guard", - [](GuardManager& self, py::object verbose_code_parts) -> void { - self.add_leaf_guard( - std::make_shared(std::move(verbose_code_parts))); + [](GuardManager& self, + py::object initial_state, + py::object verbose_code_parts) -> void { + self.add_leaf_guard(std::make_shared( + self.get_root(), + std::move(initial_state), + std::move(verbose_code_parts))); }) .def( "add_torch_function_mode_stack_guard", @@ -5800,7 +6053,7 @@ PyObject* torch_c_dynamo_guards_init() { const py::list& initial_stack, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, std::move(verbose_code_parts))); + self.get_root(), initial_stack, std::move(verbose_code_parts))); }) .def( "add_no_hasattr_guard", @@ -5808,7 +6061,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object attr_name, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - std::move(attr_name), std::move(verbose_code_parts))); + self.get_root(), + std::move(attr_name), + std::move(verbose_code_parts))); }) .def( "add_dict_contains_guard", @@ -5817,7 +6072,22 @@ PyObject* torch_c_dynamo_guards_init() { py::object key, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - contains, std::move(key), std::move(verbose_code_parts))); + self.get_root(), + contains, + std::move(key), + std::move(verbose_code_parts))); + }) + .def( + "add_set_contains_guard", + [](GuardManager& self, + bool contains, + py::object item, + py::object verbose_code_parts) -> void { + self.add_leaf_guard(std::make_shared( + self.get_root(), + contains, + std::move(item), + std::move(verbose_code_parts))); }) .def( "add_dynamic_indices_guard", @@ -5825,7 +6095,9 @@ PyObject* torch_c_dynamo_guards_init() { py::set value, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_dict_version_guard", @@ -5834,7 +6106,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_VERSION"); self.add_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_tensor_match_guard", @@ -6072,6 +6346,14 @@ PyObject* torch_c_dynamo_guards_init() { py::arg("example_value"), py::arg("guard_manager_enum"), py::return_value_policy::reference) + .def( + "set_getitem_manager", + &GuardManager::get_child_manager, + py::arg("index"), + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers .def( @@ -6175,7 +6457,7 @@ PyObject* torch_c_dynamo_guards_init() { py::object lambda, py::object verbose_code_parts) -> void { self.add_epilogue_lambda_guard(std::make_unique( - std::move(lambda), std::move(verbose_code_parts))); + &self, std::move(lambda), std::move(verbose_code_parts))); }); // Dict Guard Manager @@ -6238,7 +6520,10 @@ PyObject* torch_c_dynamo_guards_init() { py::object key, py::object verbose_code_parts) -> void { self.add_permitted_leaf_guard(std::make_shared( - contains, std::move(key), std::move(verbose_code_parts))); + self.get_root(), + contains, + std::move(key), + std::move(verbose_code_parts))); }) .def( "add_dict_version_guard", @@ -6247,7 +6532,9 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_VERSION"); self.add_permitted_leaf_guard(std::make_shared( - std::move(value), std::move(verbose_code_parts))); + self.get_root(), + std::move(value), + std::move(verbose_code_parts))); }) .def( "add_no_hasattr_guard", @@ -6255,9 +6542,11 @@ PyObject* torch_c_dynamo_guards_init() { py::object attr_name, py::object verbose_code_parts) -> void { self.add_permitted_leaf_guard(std::make_shared( - std::move(attr_name), std::move(verbose_code_parts))); + self.get_root(), + std::move(attr_name), + std::move(verbose_code_parts))); }) - // Not permitted accesssors + // Not permitted accessors .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager) .def("getitem_manager", &DictGuardManager::fail_on_get_child_manager) .def("dict_getitem_manager", &DictGuardManager::fail_on_get_child_manager) diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index e62d3cb8b17ae8..2b642ce0bfe804 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -238,6 +238,9 @@ void initDynamoBindings(PyObject* torch) { "update_diff_guard_root_manager", &CacheEntry::update_diff_guard_root_manager); + py::class_(m, "_PrecompileEntry") + .def_readonly("guard_manager", &PrecompileEntry::guard_manager); + py::class_(m, "_ExtraState") .def("invalidate", &ExtraState::invalidate); @@ -257,6 +260,9 @@ void initDynamoBindings(PyObject* torch) { .def_readwrite("recursive_action", &FrameExecStrategy::recursive_action); m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list); + m.def("_reset_precompile_entries", &_reset_precompile_entries); + m.def("_load_precompile_entry", &_load_precompile_entry); + m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries); py::bind_vector>(m, "VectorUInt8"); m.attr("py_opcode_caches") = _PyOpcode_Caches_vec; m.def("code_framelocals_names", &code_framelocals_names); diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 80f09ad9daf43b..1e1783477d2e0a 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include /* @@ -56,6 +57,19 @@ namespace { PyObject* the_autograd_compiler = nullptr; int default_dyn_type_int = 0; PyObject* python_verbose_logger = nullptr; + +constexpr std::string_view _TURN_OFF_COMPILED_AUTOGRAD_MSG = R"( + You can disable compiled autograd for this operation by: + 1. Relocating the unsupported autograd call outside the compiled region. + 2. Wrapping the unsupported autograd call within a scope that disables compiled autograd. + 3. Configuring the specific compilation unit to disable compiled autograd. + 4. Globally disabling compiled autograd at the application's initialization. + )"; + +std::string TURN_OFF_COMPILED_AUTOGRAD_MSG() { + return std::string(_TURN_OFF_COMPILED_AUTOGRAD_MSG); +} + } // namespace // see https://github.com/pytorch/pytorch/pull/34845 @@ -286,9 +300,11 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface { void call_accumulate_grad( PyObject* py_compiler, const at::Tensor& variable, - const at::Tensor& grad) const override { + const at::Tensor& grad, + bool has_post_hooks) const override { py::handle handle(py_compiler); - py::object stuff = handle.attr("accumulate_grad")(variable, grad); + py::object stuff = + handle.attr("accumulate_grad")(variable, grad, has_post_hooks); TORCH_INTERNAL_ASSERT(stuff.is_none()); } }; @@ -1170,9 +1186,10 @@ struct LockGuardWithErrorLogs { // performance reasons, but it shouldn't happen here since we: // 1. disable multithreaded autograd // 2. plenty of latency between backward calls - TORCH_INTERNAL_ASSERT( + TORCH_CHECK_NOT_IMPLEMENTED( mtx_.try_lock(), - "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet."); + "Trying to run compiled autograd within another compiled autograd call, this is not supported yet. " + + TURN_OFF_COMPILED_AUTOGRAD_MSG()); } ~LockGuardWithErrorLogs() { @@ -1188,9 +1205,6 @@ static variable_list compiled_autograd( const GraphTask& graph_task, bool accumulate_grad, const edge_list& output_edges) { - TORCH_CHECK( - c10::impl::TorchDispatchModeTLS::stack_len() == 0, - "TorchDispatchMode not yet implemented for compiled autograd") static std::mutex mtx; LockGuardWithErrorLogs lock_guard(mtx); pybind11::gil_scoped_acquire gil; @@ -1202,17 +1216,27 @@ static variable_list compiled_autograd( THPObjectPtr ivalue_args; THPObjectPtr hooks; THPObjectPtr packed_inputs; - CacheNode* cache = _compiled_autograd_impl( - graph_root, - graph_task, - accumulate_grad, - output_edges, - &inputs, - &sizes, - &ivalue_args, - &hooks, - &packed_inputs, - active_rstate); + CacheNode* cache = nullptr; + try { + torch_dispatch_mode::StashTorchDispatchStackGuard stash_stack_guard; + TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0); + cache = _compiled_autograd_impl( + graph_root, + graph_task, + accumulate_grad, + output_edges, + &inputs, + &sizes, + &ivalue_args, + &hooks, + &packed_inputs, + active_rstate); + TORCH_INTERNAL_ASSERT(c10::impl::TorchDispatchModeTLS::stack_len() == 0); + } catch (const c10::NotImplementedError& e) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string(e.what()) + " " + TURN_OFF_COMPILED_AUTOGRAD_MSG()); + } + TORCH_INTERNAL_ASSERT(cache != nullptr); THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs( cache->runtime_wrapper.get(), diff --git a/torch/csrc/export/example_upgraders.cpp b/torch/csrc/export/example_upgraders.cpp new file mode 100644 index 00000000000000..398c01301f0f26 --- /dev/null +++ b/torch/csrc/export/example_upgraders.cpp @@ -0,0 +1,89 @@ +#include +#include + +namespace torch::_export { + +/// Register test upgraders for the upgrader system. +/// and shows some common upgrade patterns. +static bool test_upgraders_registered = false; + +void registerExampleUpgraders() { + if (test_upgraders_registered) { + return; + } + + registerUpgrader( + 0, + "graph_module.graph.nodes", + [](const nlohmann::json& nodes_array) -> nlohmann::json { + nlohmann::json upgraded_nodes = nodes_array; + + // Process each node in the nodes array + for (auto& node : upgraded_nodes) { + if (node.contains("metadata") && node["metadata"].is_object()) { + // Process each metadata key-value pair + for (auto& [key, value] : node["metadata"].items()) { + if (key == "nn_module_stack") { + // Transform nn_module_stack values by prepending prefix + if (value.is_string()) { + std::string stack_str = value.get(); + value = "test_upgrader_" + stack_str; + } else { + throwUpgraderError( + "version_0_upgrader_registered", + 0, + "nn_module_stack metadata value must be a string, got: " + + std::string(value.type_name()), + node); + } + } + // Other metadata keys remain unchanged + } + } + } + + return upgraded_nodes; + }); + + registerUpgrader( + 0, + "graph_module.graph", + [](const nlohmann::json& graph_obj) -> nlohmann::json { + nlohmann::json upgraded_graph = graph_obj; + + // Rename field if it exists in the graph object + if (upgraded_graph.contains("old_test_field")) { + upgraded_graph["new_test_field"] = upgraded_graph["old_test_field"]; + upgraded_graph.erase("old_test_field"); + } + + return upgraded_graph; + }); + + registerUpgrader( + 1, + std::vector{"graph_module", "graph"}, + [](const nlohmann::json& graph_obj) -> nlohmann::json { + nlohmann::json upgraded_graph = graph_obj; + + // Continue the field renaming chain from version 0 + if (upgraded_graph.contains("new_test_field")) { + upgraded_graph["new_test_field2"] = upgraded_graph["new_test_field"]; + upgraded_graph.erase("new_test_field"); + } + + return upgraded_graph; + }); + + test_upgraders_registered = true; +} + +/// Deregister test upgraders for the upgrader system. +void deregisterExampleUpgraders() { + deregisterUpgrader(0, "graph_module.graph.nodes"); + deregisterUpgrader(0, "graph_module.graph"); + deregisterUpgrader(1, std::vector{"graph_module", "graph"}); + test_upgraders_registered = false; +} + +} // namespace torch::_export diff --git a/torch/csrc/export/example_upgraders.h b/torch/csrc/export/example_upgraders.h new file mode 100644 index 00000000000000..40e1fb14e7226d --- /dev/null +++ b/torch/csrc/export/example_upgraders.h @@ -0,0 +1,15 @@ +#pragma once + +namespace torch::_export { + +/// Register example upgraders for the upgrader system for testing. +/// This function demonstrates common upgrade patterns and is primarily +/// used for testing and demonstration purposes. +void registerExampleUpgraders(); + +/// Deregister example upgraders for the upgrader system for testing. +/// This function cleans up the example upgraders that were registered +/// by registerExampleUpgraders(). +void deregisterExampleUpgraders(); + +} // namespace torch::_export diff --git a/torch/csrc/export/pt2_archive_constants.h b/torch/csrc/export/pt2_archive_constants.h index 72888e3953859b..804cadccbd43c9 100644 --- a/torch/csrc/export/pt2_archive_constants.h +++ b/torch/csrc/export/pt2_archive_constants.h @@ -32,9 +32,12 @@ namespace torch::_export::archive_spec { /* weights, including parameters and buffers */ \ DO(WEIGHTS_DIR, "data/weights/") \ DO(WEIGHT_FILENAME_PREFIX, "weight_") \ + DO(WEIGHTS_PARAM_CONFIG_FORMAT, "data/weights/{}_model_param_config.json") \ /* constants, including tensor_constants, non-persistent buffers and script \ * objects */ \ DO(CONSTANTS_DIR, "data/constants/") \ + DO(CONSTANTS_PARAM_CONFIG_FORMAT, \ + "data/constants/{}_model_constants_config.json") \ DO(TENSOR_CONSTANT_FILENAME_PREFIX, "tensor_") \ DO(CUSTOM_OBJ_FILENAME_PREFIX, "custom_obj_") \ /* example inputs */ \ diff --git a/torch/csrc/export/pybind.cpp b/torch/csrc/export/pybind.cpp index 65206d06dbebda..eedd8666ea168b 100644 --- a/torch/csrc/export/pybind.cpp +++ b/torch/csrc/export/pybind.cpp @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include @@ -15,13 +17,37 @@ void initExportBindings(PyObject* module) { exportModule.def( "deserialize_exported_program", [](const std::string& serialized) { - return nlohmann::json::parse(serialized).get(); + auto parsed = nlohmann::json::parse(serialized); + + // Query the current Python schema version as target + // TODO: expose schema_version in gneerated_serialization_types.h and + // access it here directly. + py::module_ schema_module = + py::module_::import("torch._export.serde.schema"); + py::tuple schema_version_tuple = schema_module.attr("SCHEMA_VERSION"); + int target_version = schema_version_tuple[0].cast(); + + auto upgraded = upgrade(parsed, target_version); + return upgraded.get(); }); exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) { return nlohmann::json(ep).dump(); }); + exportModule.def( + "upgrade", [](const std::string& serialized_json, int target_version) { + auto parsed = nlohmann::json::parse(serialized_json); + auto upgraded = upgrade(parsed, target_version); + return upgraded.dump(); + }); + + exportModule.def( + "register_example_upgraders", []() { registerExampleUpgraders(); }); + + exportModule.def( + "deregister_example_upgraders", []() { deregisterExampleUpgraders(); }); + for (const auto& entry : torch::_export::archive_spec::kAllConstants) { pt2ArchiveModule.attr(entry.first) = entry.second; } diff --git a/torch/csrc/export/upgrader.cpp b/torch/csrc/export/upgrader.cpp new file mode 100644 index 00000000000000..9f92239840b9cd --- /dev/null +++ b/torch/csrc/export/upgrader.cpp @@ -0,0 +1,242 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace torch::_export { + +// Global upgrader registry organized by version. +// Using std::multiset to maintain automatic bottom-up ordering where +// deeper keypaths are processed before shallower ones. +static std::map> upgrader_registry; + +static const std::multiset& getUpgrader(int current_version) { + static const std::multiset empty_upgraders; + auto it = upgrader_registry.find(current_version); + if (it != upgrader_registry.end()) { + return it->second; + } + return empty_upgraders; +} + +static nlohmann::json getFieldByKeypath( + const nlohmann::json& obj, + const std::vector& keypath) { + nlohmann::json current = obj; + for (const auto& key : keypath) { + if (!current.contains(key)) { + throw std::runtime_error("Keypath not found: " + key); + } + current = current[key]; + } + return current; +} + +static void setFieldByKeypath( + nlohmann::json& obj, + const std::vector& keypath, + const nlohmann::json& value) { + nlohmann::json* current = &obj; + for (size_t i = 0; i < keypath.size() - 1; ++i) { + const auto& key = keypath[i]; + if (!current->contains(key)) { + throw std::runtime_error("Keypath not found: " + key); + } + current = &((*current)[key]); + } + if (!current->contains(keypath.back())) { + throw std::runtime_error("Keypath not found: " + keypath.back()); + } + (*current)[keypath.back()] = value; +} + +Upgrader::Upgrader(std::vector kp, UpgraderFunction func) + : keypath(std::move(kp)), upgrade_func(std::move(func)) {} + +bool Upgrader::operator<(const Upgrader& other) const { + // First compare by depth - deeper paths come first for bottom-up processing + if (keypath.size() != other.keypath.size()) { + return keypath.size() > other.keypath.size(); + } + // If same depth, compare lexicographically for deterministic ordering + return keypath < other.keypath; +} + +void registerUpgrader( + int version, + const std::vector& keypath, + const UpgraderFunction& upgrade_func) { + // Check if an upgrader already exists for this version and keypath + auto version_it = upgrader_registry.find(version); + if (version_it != upgrader_registry.end()) { + const auto& upgraders = version_it->second; + + // Search for existing upgrader with the same keypath + for (const auto& existing_upgrader : upgraders) { + if (existing_upgrader.keypath == keypath) { + std::ostringstream error_stream; + error_stream << "Upgrader already registered for version " << version + << " and keypath: "; + for (size_t i = 0; i < keypath.size(); ++i) { + if (i > 0) + error_stream << "."; + error_stream << keypath[i]; + } + throw std::runtime_error(error_stream.str()); + } + } + } + + upgrader_registry[version].emplace(keypath, upgrade_func); +} + +void registerUpgrader( + int version, + const std::string& dot_keypath, + const UpgraderFunction& upgrade_func) { + // Convert dot-separated keypath to vector and delegate to main implementation + std::vector keypath_vector; + std::stringstream ss(dot_keypath); + std::string component; + + while (std::getline(ss, component, '.')) { + if (component.empty()) { + throw std::invalid_argument("Empty component in keypath: " + dot_keypath); + } + keypath_vector.push_back(component); + } + + if (keypath_vector.empty()) { + throw std::invalid_argument("Empty keypath provided"); + } + + registerUpgrader(version, keypath_vector, upgrade_func); +} + +bool deregisterUpgrader(int version, const std::vector& keypath) { + auto version_it = upgrader_registry.find(version); + if (version_it == upgrader_registry.end()) { + return false; // Version not found + } + + auto& upgraders = version_it->second; + + // Find the upgrader with matching keypath + for (auto it = upgraders.begin(); it != upgraders.end(); ++it) { + if (it->keypath == keypath) { + upgraders.erase(it); + + // If this was the last upgrader for this version, remove the version + // entry + if (upgraders.empty()) { + upgrader_registry.erase(version_it); + } + + return true; // Successfully removed + } + } + + return false; // Upgrader not found +} + +bool deregisterUpgrader(int version, const std::string& dot_keypath) { + // Convert dot-separated keypath to vector and delegate to main implementation + std::vector keypath_vector; + std::stringstream ss(dot_keypath); + std::string component; + + while (std::getline(ss, component, '.')) { + if (component.empty()) { + throw std::invalid_argument("Empty component in keypath: " + dot_keypath); + } + keypath_vector.push_back(component); + } + + if (keypath_vector.empty()) { + throw std::invalid_argument("Empty keypath provided"); + } + + return deregisterUpgrader(version, keypath_vector); +} + +void throwUpgraderError( + const std::string& upgrader_name, + int from_version, + const std::string& error_message, + const nlohmann::json& problematic_object) { + std::ostringstream error_stream; + error_stream << "Error in upgrader '" << upgrader_name << "' " + << "while upgrading from version " << from_version + << " to version " << from_version + 1 << ": " << error_message; + + if (!problematic_object.empty()) { + error_stream << "\nProblematic object: " << problematic_object.dump(2); + } + + throw std::runtime_error(error_stream.str()); +} + +nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) { + auto current_artifact = artifact; + + // Validate that the artifact contains required schema version information + if (!current_artifact.contains("schema_version")) { + throw std::runtime_error("Missing schema_version field in artifact"); + } + + int current_version = current_artifact["schema_version"]["major"]; + + // Iteratively apply upgraders until target version is reached or no more are + // available + while (current_version < target_version) { + // Look up upgraders for the current version + const auto& upgraders = getUpgrader(current_version); + + if (upgraders.empty()) { + // No more upgraders available - stop upgrading + break; + } + + // Apply all upgraders for this version in bottom-up order + // (deeper keypaths first to prevent parent/child conflicts) + for (const auto& upgrader : upgraders) { + // Extract the field to be upgraded using its keypath + auto field_to_upgrade = + getFieldByKeypath(current_artifact, upgrader.keypath); + + // Apply the upgrade transformation + auto upgraded_field = upgrader.upgrade_func(field_to_upgrade); + + // Update the artifact with the upgraded field + setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field); + } + + // Move to the next version for potential additional upgrades + current_version++; + } + + // Update schema version to reflect the final upgraded version + if (current_artifact["schema_version"]["major"] != current_version) { + current_artifact["schema_version"]["major"] = current_version; + // Reset minor version to 0 - the correct minor version should be set + // when converting the json to in memory representation of ExportedProgram + current_artifact["schema_version"]["minor"] = 0; + } + + // Validate that we reached the target version if requested + if (current_version != target_version) { + std::ostringstream error_stream; + error_stream + << "Failed to upgrade to target version " << target_version + << ". Final version reached: " << current_version + << ". This may indicate missing upgraders for intermediate versions."; + throw std::runtime_error(error_stream.str()); + } + + return current_artifact; +} + +} // namespace torch::_export diff --git a/torch/csrc/export/upgrader.h b/torch/csrc/export/upgrader.h new file mode 100644 index 00000000000000..c9e9b8f7ff1d0f --- /dev/null +++ b/torch/csrc/export/upgrader.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::_export { + +/// Function type for upgrading JSON fields during schema version migration. +/// Takes a JSON field and returns the upgraded version of that field. +using UpgraderFunction = std::function; + +/// Structure containing upgrader information for a specific keypath. +/// The version is stored as the map key in the registry, so it's not +/// duplicated here. +struct Upgrader { + /// Path to the field that should be upgraded (e.g., {"graph_module", "graph", + /// "nodes"}) Assuming top-level is a JSON object that represents + /// ExportedProgram + std::vector keypath; + + /// Function that performs the actual upgrade transformation + UpgraderFunction upgrade_func; + + /// Constructor for creating an upgrader with keypath and function + Upgrader(std::vector kp, UpgraderFunction func); + + /// Comparator for maintaining bottom-up ordering in the registry. + /// Deeper keypaths are processed first to ensure safe upgrade application + /// without conflicts between parent and child field modifications. + bool operator<(const Upgrader& other) const; +}; + +/// Register an upgrader function for a specific schema version and keypath. +/// +/// This function allows registration of custom upgrade logic that will be +/// applied when upgrading artifacts from the specified version. Upgraders +/// are applied in bottom-up order (deeper keypaths first) to prevent +/// conflicts between parent and child field modifications. +/// +/// @param version The schema version this upgrader applies to +/// @param keypath The key path to the field that should be upgraded +/// @param upgrade_func Function that performs the upgrade transformation +void registerUpgrader( + int version, + const std::vector& keypath, + const UpgraderFunction& upgrade_func); + +/// Register an upgrader function using dot-separated keypath notation. +/// +/// Convenience overload that accepts dot-separated keypath strings for +/// simpler syntax. For example: "graph_module.graph.nodes" instead of +/// {"graph_module", "graph", "nodes"}. +/// +/// @param version The schema version this upgrader applies to +/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes") +/// @param upgrade_func Function that performs the upgrade transformation +void registerUpgrader( + int version, + const std::string& dot_keypath, + const UpgraderFunction& upgrade_func); + +/// Deregister an upgrader function for a specific schema version and keypath. +/// +/// This function allows removal of previously registered upgrade logic for +/// the specified version and keypath. This is useful for testing scenarios +/// where you need to clean up registered upgraders or modify upgrader +/// behavior dynamically. +/// +/// @param version The schema version to deregister the upgrader from +/// @param keypath The key path to the field that should be deregistered +/// @return true if an upgrader was found and removed, false otherwise +bool deregisterUpgrader(int version, const std::vector& keypath); + +/// Deregister an upgrader function using dot-separated keypath notation. +/// +/// Convenience overload that accepts dot-separated keypath strings for +/// simpler syntax. For example: "graph_module.graph.nodes" instead of +/// {"graph_module", "graph", "nodes"}. +/// +/// @param version The schema version to deregister the upgrader from +/// @param dot_keypath Dot-separated keypath string (e.g., "graph.nodes") +/// @return true if an upgrader was found and removed, false otherwise +bool deregisterUpgrader(int version, const std::string& dot_keypath); + +/// Utility function for throwing consistent upgrader errors. +/// +/// This function formats error messages in a standardized way for upgrader +/// failures, including version information and optional problematic object +/// details for debugging. +/// +/// @param upgrader_name Name of the upgrader that failed +/// @param from_version Source schema version being upgraded from +/// @param error_message Descriptive error message +/// @param problematic_object Optional JSON object that caused the error +/// @throws std::runtime_error Always throws with formatted error message +void throwUpgraderError( + const std::string& upgrader_name, + int from_version, + const std::string& error_message, + const nlohmann::json& problematic_object = nlohmann::json::object()); + +/// Upgrade a JSON artifact to a specific target version with available +/// upgraders until a target version is reached. +/// +/// This handles major version upgrade only. For minor version upgrade, +/// e.g. adding a new field with default value, it's automatically handled by +/// the default constructor in generated_serialization_types.h. +/// +/// @param artifact The JSON artifact to upgrade +/// @param target_version The target schema version to upgrade to +/// @return The upgraded JSON artifact with updated schema version +/// @throws std::runtime_error if artifact is missing schema_version field +/// @throws std::runtime_error if final version doesn't match target version +nlohmann::json upgrade(const nlohmann::json& artifact, int target_version); + +} // namespace torch::_export diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index d3244441da161c..7a5729a96efe72 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -335,7 +335,7 @@ PyTypeObject NodeBaseType = { "torch._C._NodeBase", /* tp_name */ sizeof(NodeBase), /* tp_basicsize */ 0, /* tp_itemsize */ - (destructor)NodeBase_dealloc, /* tp_dealloc */ + NodeBase_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index d1f6ca4025ba33..fcdefeac9219c8 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -110,7 +110,7 @@ std::vector unpack_input_parameters( } if (stack[idx].isScalar()) { - // Beyond c10::Scalar, the floating value and interger value are also + // Beyond c10::Scalar, the floating value and integer value are also // represented as Scalar. inputs_metadata.emplace_back(stack[idx].toScalar(), arg_order); } else if (stack[idx].isTensorList()) { @@ -421,6 +421,7 @@ std::shared_ptr AOTIPythonKernelHolder:: "AOTI for eager does not support ", c10::DeviceTypeName(device_.type()), " now."); + // NOLINTNEXTLINE(bugprone-branch-clone) if (device_.type() == c10::DeviceType::CUDA) { #ifdef USE_CUDA return std::make_shared(so_path); @@ -528,7 +529,7 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( auto kernel_lib_path = py::cast(result); TORCH_CHECK( !kernel_lib_path.empty(), - "Failed to produce kernel libarary by using AOTI for ", + "Failed to produce kernel library by using AOTI for ", c10::DeviceTypeName(device_.type()), ". Operator Name is ", op.operator_name().name, diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 4acfc5d2610503..dacdc9eac3882e 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #ifndef _WIN32 #include @@ -324,6 +325,30 @@ std::string compile_so( return output_so; } + +std::unordered_set find_model_names( + const std::vector& paths) { + std::unordered_set model_names; + + // Escape the separator if it's backslash (needed for regex) + std::string sep = k_separator; + if (sep == "\\") + sep = "\\\\"; + + std::string pattern = + "data" + sep + "aotinductor" + sep + "([^" + sep + "]+)" + sep; + std::regex re(pattern); + + for (const auto& path : paths) { + std::smatch match; + if (std::regex_search(path, match, re) && match.size() > 1) { + model_names.insert(match[1].str()); + } + } + + return model_names; +} + } // namespace void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { @@ -487,8 +512,21 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( for (const std::string& filename : found_filenames) { found_filenames_str += filename + "\n"; } + std::string model_names_str; + for (const std::string& model_name_tmp : + find_model_names(found_filenames)) { + model_names_str += model_name_tmp + "\n"; + } + throw std::runtime_error( - "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" + + "Failed to find a generated cpp file or so file for model '" + + model_name + + "' in the zip archive.\n\n" + "Available models in the archive:\n" + + model_names_str + + "\n\n" + "To load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n" + "The following files were loaded from the archive:\n" + found_filenames_str); } diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.h b/torch/csrc/inductor/aoti_runner/model_container_runner.h index 2fd08b15a7eda4..39065dab187f9a 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -121,8 +121,8 @@ TORCH_API std::unordered_map& getAOTIModelRunnerRegistry(); // To register a new external backend in AOTI one needs to create an instance of -// this struct. It is not thread-safe. Becase it is expected to be called during -// the initialization of the program. +// this struct. It is not thread-safe. Because it is expected to be called +// during the initialization of the program. struct TORCH_API RegisterAOTIModelRunner{RegisterAOTIModelRunner( const std::string& name, CreateAOTIModelRunnerFunc create_aoti_model_runner_fn){ diff --git a/torch/csrc/inductor/aoti_runner/pybind.cpp b/torch/csrc/inductor/aoti_runner/pybind.cpp index 2c48690ea36ec8..ff0d198aeaeb6a 100644 --- a/torch/csrc/inductor/aoti_runner/pybind.cpp +++ b/torch/csrc/inductor/aoti_runner/pybind.cpp @@ -134,7 +134,8 @@ void initAOTIRunnerBindings(PyObject* module) { &AOTIModelContainerRunnerXpu::free_inactive_constant_buffer); #endif -#if defined(__APPLE__) && !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) +#if defined(USE_MPS) && defined(__APPLE__) && \ + !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) py::class_(m, "AOTIModelContainerRunnerMps") .def(py::init()) .def( diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index a3a1a9fe84d34d..1c12f018cd423d 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -15,11 +15,14 @@ // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +#ifdef USE_MPS +#include +#endif // USE_MPS #ifdef USE_XPU #include #else #include -#endif +#endif // USE_XPU #include #define AOTI_RUNTIME_CHECK(EXPR, MSG) \ @@ -74,6 +77,15 @@ RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { return RAIIDataPtr(data_ptr, deleter); } +#elif defined(USE_MPS) + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + void* data_ptr = nullptr; + aoti_torch_mps_malloc(&data_ptr, num_bytes); + auto deleter = [](void* ptr) { aoti_torch_mps_free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + #else RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { @@ -113,7 +125,7 @@ inline void parse_device_str( } else if (sm[1].str() == "xpu") { device_type = aoti_torch_device_type_xpu(); #endif -#ifdef __APPLE__ +#ifdef USE_MPS } else if (sm[1].str() == "mps") { device_type = aoti_torch_device_type_mps(); #endif @@ -165,6 +177,11 @@ class AOTInductorModelBase { aoti_torch_set_current_xpu_device(device_idx_); } #endif // USE_XPU +#ifdef USE_MPS + if (device_idx_ == -1) { + device_idx_ = 0; + } +#endif // USE_MPS } // NOLINTNEXTLINE(modernize-use-equals-default) @@ -299,7 +316,7 @@ class AOTInductorModelBase { if (!include_weights) { return; } -#if defined(USE_CUDA) || defined(USE_XPU) +#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS) constant_blob_ = RAII_gpuMalloc(blob_size); #else constant_blob_ = RAII_cpuMalloc(blob_size); @@ -327,7 +344,12 @@ class AOTInductorModelBase { auto ndim = this->constant_ndim(i); auto size = this->constant_shape(i); auto stride = this->constant_stride(i); +#ifdef USE_MPS + auto offset = this->constant_offset(i) + + (constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype)); +#else auto offset = this->constant_offset(i); +#endif auto layout = this->constant_layout(i); auto opaque_metadata_ptr = this->opaque_metadata(i); auto opaque_metadata_size = this->opaque_metadata_size(i); @@ -390,6 +412,14 @@ class AOTInductorModelBase { _get_constants_start() + bytes_read, data_size, cudaMemcpyHostToDevice)); +#elif USE_MPS + aoti_torch_mps_memcpy( + constants_ptr, + constant_offset, + bytes_read, + data_size, + _get_constants_start()); + return constants_ptr; #else memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size); #endif @@ -629,7 +659,7 @@ class AOTInductorModelBase { AOTI_RUNTIME_CHECK( reinterpret_cast( self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number, - "Weigths data seems corrupt"); + "Weights data seems corrupt"); return self_mmap; #endif } @@ -677,7 +707,7 @@ class AOTInductorModelBase { bool include_weights; // Record if the model finishes an inference run so that its owning - // AOTModelContainer can re-use this instance. + // AOTModelContainer can reuse this instance. #ifdef USE_CUDA std::optional run_finished_; #elif defined(USE_XPU) diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index f249b045a80727..10292f7968a268 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -18,7 +18,7 @@ namespace torch::aot_inductor { // when model_container is created and no constants are being loaded or updated. // (2) INITIALIZED state: This state get set whenever we load the constants into // the buffer. This could be done by load_constants or update_constants_buffer. -// (3) FOLDED state: This state should transition from INITIALILZED after +// (3) FOLDED state: This state should transition from INITIALIZED after // const_fold is being invoked. enum class ConstantState : uint8_t { NONE, INITIALIZED, FOLDED, UNKNOWN }; @@ -666,7 +666,7 @@ class AOTInductorModelContainer { std::shared_mutex model_exec_mutex_; RAIIDataPtr allocate_constant_blob() { -#if defined(USE_CUDA) || defined(USE_XPU) +#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS) return RAII_gpuMalloc(blob_size_); #else return RAII_cpuMalloc(blob_size_); diff --git a/torch/csrc/inductor/aoti_runtime/utils_cuda.h b/torch/csrc/inductor/aoti_runtime/utils_cuda.h index dc9bac89c9f0fd..a4f1706ec7fb6a 100644 --- a/torch/csrc/inductor/aoti_runtime/utils_cuda.h +++ b/torch/csrc/inductor/aoti_runtime/utils_cuda.h @@ -12,6 +12,7 @@ #ifndef USE_ROCM #include #include +#include #endif namespace torch::aot_inductor { diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index b6bca96903d064..6a23c9d465c7f0 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -129,6 +129,7 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128(); +AOTI_TORCH_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype); AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided(); AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo(); @@ -273,6 +274,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AtenTensorHandle tensor, int64_t* ret_storage_offset); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle); diff --git a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h index c7b713bf7f877e..5a10290decd1db 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h @@ -245,6 +245,22 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( + AtenTensorHandle inp, + const char* reduce_op, + const char* group_name, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce( + AtenTensorHandle inp, + const char* reduce_op, + const char* group_name, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( + AtenTensorHandle inp, + AtenTensorHandle* ret0); + #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mps.h b/torch/csrc/inductor/aoti_torch/c/shim_mps.h index cdcbad27f36831..bd86885de13ca8 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_mps.h @@ -10,11 +10,28 @@ extern "C" { struct AOTIMetalKernelFunctionOpaque; using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg( +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor( AOTIMetalKernelFunctionHandle func, unsigned idx, AtenTensorHandle tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_mps_malloc(void** buffer, size_t num_bytes); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_free(void* ptr); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start); + #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/c/shim_xpu.h b/torch/csrc/inductor/aoti_torch/c/shim_xpu.h index baecfc3521794a..408c99ca655f65 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_xpu.h @@ -1,6 +1,7 @@ #ifndef AOTI_TORCH_SHIM_XPU #define AOTI_TORCH_SHIM_XPU +#include #include #ifdef USE_XPU @@ -45,6 +46,68 @@ aoti_torch_set_current_xpu_device(const int32_t& device_index); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret); +#if AT_MKLDNN_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_xpu_mkldnn__convolution_pointwise_binary( + AtenTensorHandle X, + AtenTensorHandle other, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_xpu_mkldnn__convolution_pointwise_binary_( + AtenTensorHandle other, + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +#endif // AT_MKLDNN_ENABLED() #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 8399f68d026914..2aa09cb802ecde 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -73,6 +73,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, in AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fill__Scalar(AtenTensorHandle self, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -104,10 +105,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int6 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index efbe72f47ee8b0..e0607f984b3d0f 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -80,6 +80,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, i AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fill__Scalar(AtenTensorHandle self, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -110,10 +111,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mode(AtenTensorHandle self, int AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index 7fecb88772c89d..d6f77d9f1343d8 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,9 +18,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); @@ -49,6 +52,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, in AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_fill__Scalar(AtenTensorHandle self, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -62,13 +66,16 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter_backward(AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); @@ -106,6 +113,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_triangular_solve(AtenTensorHandl AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index b965490ebb6edb..243bfb5fc87aaf 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -16,6 +16,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); @@ -37,7 +38,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp index fba4660a002c25..0beffa32d6c915 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -445,10 +445,9 @@ void OSSProxyExecutor::get_input_info_from_serialized( // If an argument is not filled and has a default value, we should // also prefill the default value. for (size_t index = 0; index < schema_args.size(); index++) { - if (!filled[index] && schema_args[index].default_value()) { - // @lint-ignore CLANGTIDY bugprone-unchecked-optional-access - auto default_value = *schema_args[index].default_value(); - op_kernel.stack_.at(index) = default_value; + auto default_value = schema_args[index].default_value(); + if (!filled[index] && default_value.has_value()) { + op_kernel.stack_.at(index) = std::move(default_value.value()); } } } @@ -536,6 +535,17 @@ void OSSProxyExecutor::get_output_info_from_serialized( } break; } + case c10::TypeKind::IntType: { + TORCH_CHECK( + serialized_output_type == "as_int", + "Expected extern kernel ", + serialized_node["target"], + " to have serialized output type as_int, ", + " but got ", + serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::IntType, 1); + break; + } default: { TORCH_CHECK( false, @@ -800,12 +810,14 @@ void OSSProxyExecutor::call_function( tensor_id, ", expected num = ", num_tensors - num_output_tensors); + + int num_output_ints = op_kernel->num_output_ints(); TORCH_CHECK( - int_id == num_ints, + int_id == num_ints - num_output_ints, "Mismatch between ints consumed and num_ints, got int_id = ", int_id, ", num_ints = ", - num_ints); + num_ints - num_output_ints); // Call the op with the prepared stack. op_kernel->run(stack); @@ -849,8 +861,20 @@ void OSSProxyExecutor::call_function( TORCH_CHECK(false, "Expected tensor, got None"); } } else { - continue; + index++; } + } else if (schema_return.real_type()->kind() == c10::TypeKind::IntType) { + // need to use real_type() to differentiate between IntType and SymIntType + // for int type, it is already specialized in downstream kernels. So we + // don't need to do anything here. + auto returned_int_value = stack[index++].toInt(); + auto serialized_int_value = flatten_int_args[int_id++]; + TORCH_CHECK( + returned_int_value == serialized_int_value, + "Expect returned int value to match the serialized int value, but got returned int value: ", + returned_int_value, + " and serialized int value: ", + serialized_int_value); } else { TORCH_CHECK( false, @@ -865,6 +889,13 @@ void OSSProxyExecutor::call_function( tensor_id, ", expected num = ", num_tensors); + + TORCH_CHECK( + int_id == num_ints, + "Mismatch between tensors consumed and num_ints, got tensor_id = ", + int_id, + ", expected num = ", + num_ints); } } // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h index d20ef2e521861d..b0f5b3083cf232 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -12,26 +12,11 @@ namespace torch::aot_inductor { -enum class DynamicArgType : int { - TensorType = 0, - ListTensorType = 1, - ListOptionalTensorType = 2, - IntType = 3, - ListIntType = 4, - NoneType = 5, -}; - inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) { os << static_cast(arg_type); return os; } -inline bool isTensorType(DynamicArgType arg_type) { - return arg_type == DynamicArgType::TensorType || - arg_type == DynamicArgType::ListTensorType || - arg_type == DynamicArgType::ListOptionalTensorType; -} - struct OSSDynamicArg { OSSDynamicArg( int arg_index, @@ -82,6 +67,16 @@ struct OSSOpKernel { return num_output_tensors; } + int num_output_ints() const { + int num_output_ints = 0; + for (const auto& output : outputs_) { + if (output.arg_type == DynamicArgType::IntType) { + num_output_ints += output.length; + } + } + return num_output_ints; + } + virtual void run(std::vector& stack) = 0; virtual c10::FunctionSchema schema() const = 0; virtual ~OSSOpKernel() = default; diff --git a/torch/csrc/inductor/aoti_torch/proxy_executor.h b/torch/csrc/inductor/aoti_torch/proxy_executor.h index 6943bca5df49e9..708dc52a760cf0 100644 --- a/torch/csrc/inductor/aoti_torch/proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/proxy_executor.h @@ -6,6 +6,21 @@ namespace torch::aot_inductor { +enum class DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, + NoneType = 5, +}; + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + class ProxyExecutor { public: ProxyExecutor() = default; diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 5a0c9f581fc520..dc6e52b0c4db1d 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -253,6 +253,11 @@ void aoti_torch_grad_mode_set_enabled(bool enabled) { return c10::GradMode::set_enabled(enabled); } +size_t aoti_torch_dtype_element_size(int32_t dtype) { + auto scalar_type = static_cast(dtype); + return c10::elementSize(scalar_type); +} + AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); @@ -388,6 +393,15 @@ AOTITorchError aoti_torch_get_storage_offset( }); } +AOTITorchError aoti_torch_is_contiguous( + AtenTensorHandle tensor, + bool* ret_is_contiguous) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); + *ret_is_contiguous = t->is_contiguous(); + }); +} + AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle) { diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index 85467bbad85934..904bd5f9e51ff7 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -1,4 +1,7 @@ +#ifdef USE_DISTRIBUTED +#include +#endif #include #include @@ -539,3 +542,38 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( *ret0 = new_tensor_handle(std::move(tmp_result)); }); } + +#ifdef USE_DISTRIBUTED +AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( + AtenTensorHandle inp, + const char* reduce_op, + const char* group_name, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = c10d::all_reduce_( + *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce( + AtenTensorHandle inp, + const char* reduce_op, + const char* group_name, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = c10d::all_reduce( + *tensor_handle_to_tensor_pointer(inp), reduce_op, group_name); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( + AtenTensorHandle inp, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = c10d::wait_tensor(*tensor_handle_to_tensor_pointer(inp)); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} +#endif diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.cpp b/torch/csrc/inductor/aoti_torch/shim_mps.cpp index d21d54924abe23..47cb8f0f71f011 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mps.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_mps.cpp @@ -4,7 +4,7 @@ using namespace torch::aot_inductor; -AOTITorchError aoti_torch_mps_set_arg( +AOTITorchError aoti_torch_mps_set_arg_tensor( AOTIMetalKernelFunctionHandle handle, unsigned idx, AtenTensorHandle tensor) { @@ -17,3 +17,13 @@ AOTITorchError aoti_torch_mps_set_arg( func->setArg(idx, *t); }); } + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle handle, + unsigned idx, + int64_t val) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto func = reinterpret_cast(handle); + func->setArg(idx, val); + }); +} diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.mm b/torch/csrc/inductor/aoti_torch/shim_mps.mm new file mode 100644 index 00000000000000..9f70331ffc0b96 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/shim_mps.mm @@ -0,0 +1,42 @@ +#include +#include +#include +#include +#include + + +using namespace torch::aot_inductor; + +AOTITorchError aoti_torch_mps_malloc( + void** buffer, + size_t num_bytes) { + if (num_bytes == 0) { + *buffer = nullptr; + return AOTI_TORCH_SUCCESS; + } + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + id device = at::mps::MPSDevice::getInstance()->device(); + TORCH_CHECK(device, "Failed to get MPS device"); + id metal_buffer = [device newBufferWithLength:num_bytes options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared]; + TORCH_CHECK(metal_buffer, "Failed to allocate memory on MPS device"); + *buffer = (void*)metal_buffer; + }); +} + +AOTITorchError aoti_torch_mps_free( + void* ptr) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto metal_buffer = (id)ptr; + [metal_buffer release]; + }); +} + + +AOTITorchError +aoti_torch_mps_memcpy(void* buffer, size_t constant_offset, size_t bytes_read, size_t data_size, uint8_t* constants_start) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto metal_buffer = (id)buffer; + auto buffer_pointer = static_cast([metal_buffer contents]); + memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); + }); +} diff --git a/torch/csrc/inductor/aoti_torch/shim_xpu.cpp b/torch/csrc/inductor/aoti_torch/shim_xpu.cpp index ab4e8df4af3771..33f8985d83bdfc 100644 --- a/torch/csrc/inductor/aoti_torch/shim_xpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_xpu.cpp @@ -7,6 +7,8 @@ #include #include +using namespace torch::aot_inductor; + AOTITorchError aoti_torch_create_xpu_guard( int32_t device_index, XPUGuardHandle* ret_guard // returns new reference @@ -57,8 +59,10 @@ AOTITorchError aoti_torch_get_current_xpu_stream( } AOTITorchError aoti_torch_get_current_xpu_device(int32_t* device_index) { - AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( - { *device_index = static_cast(c10::xpu::current_device()); }); + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + *device_index = + static_cast(static_cast(c10::xpu::current_device())); + }); } AOTITorchError aoti_torch_set_current_xpu_device(const int32_t& device_index) { @@ -68,7 +72,136 @@ AOTITorchError aoti_torch_set_current_xpu_device(const int32_t& device_index) { AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ - int32_t device_index = static_cast(c10::xpu::current_device()); + int32_t device_index = + static_cast(static_cast(c10::xpu::current_device())); *ret = &(at::xpu::getCurrentXPUStream(device_index).queue()); }); } + +#if AT_MKLDNN_ENABLED() +#include + +AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise_binary( + AtenTensorHandle X, + AtenTensorHandle other, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::List> unary_scalars_list; + unary_scalars_list.reserve(unary_scalars_len_); + for (int64_t i = 0; i < unary_scalars_len_; i++) { + unary_scalars_list.emplace_back(pointer_to_optional(unary_scalars[i])); + } + auto tmp_result = at::native::xpu::convolution_pointwise_binary( + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(other), + *tensor_handle_to_tensor_pointer(W), + pointer_to_optional(B), + pointer_to_list(padding, padding_len_), + pointer_to_list(stride, stride_len_), + pointer_to_list(dilation, dilation_len_), + groups, + binary_attr, + pointer_to_optional(alpha), + pointer_to_optional(unary_attr), + unary_scalars_list, + pointer_to_optional(unary_algorithm)); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise_binary_( + AtenTensorHandle other, + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::List> unary_scalars_list; + unary_scalars_list.reserve(unary_scalars_len_); + for (int64_t i = 0; i < unary_scalars_len_; i++) { + unary_scalars_list.emplace_back(pointer_to_optional(unary_scalars[i])); + } + auto tmp_result = at::native::xpu::convolution_pointwise_binary_( + *tensor_handle_to_tensor_pointer(other), + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(W), + pointer_to_optional(B), + pointer_to_list(padding, padding_len_), + pointer_to_list(stride, stride_len_), + pointer_to_list(dilation, dilation_len_), + groups, + binary_attr, + pointer_to_optional(alpha), + pointer_to_optional(unary_attr), + unary_scalars_list, + pointer_to_optional(unary_algorithm)); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::List> scalars_list; + scalars_list.reserve(scalars_len_); + for (int64_t i = 0; i < scalars_len_; i++) { + scalars_list.emplace_back(pointer_to_optional(scalars[i])); + } + auto tmp_result = at::native::xpu::convolution_pointwise( + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(W), + pointer_to_optional(B), + pointer_to_list(padding, padding_len_), + pointer_to_list(stride, stride_len_), + pointer_to_list(dilation, dilation_len_), + groups, + attr, + scalars_list, + pointer_to_optional(algorithm)); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +#endif // AT_MKLDNN_ENABLED() diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h index 78fdda0a45a7af..9d9ae16462cc19 100644 --- a/torch/csrc/inductor/cpp_wrapper/common.h +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index 6ee43590692ec3..35756b704faa90 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -438,12 +438,12 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { std::array StaticCudaLauncherMethods = { PyMethodDef{ "_launch_kernel", - (PyCFunction)launch_kernel, + launch_kernel, METH_VARARGS, "Statically launch triton compiled CUDA kernels"}, PyMethodDef{ "_load_kernel", - (PyCFunction)load_kernel, + load_kernel, METH_VARARGS, "Load CUDA kernel from cubin file"}}; diff --git a/torch/csrc/instruction_counter/Module.cpp b/torch/csrc/instruction_counter/Module.cpp index 1fc80196e27d1e..aafae0dd3cbd48 100644 --- a/torch/csrc/instruction_counter/Module.cpp +++ b/torch/csrc/instruction_counter/Module.cpp @@ -9,6 +9,7 @@ #include #if defined(__linux__) +#include #include #include #include @@ -36,7 +37,7 @@ static long start() { long fd = syscall(SYS_perf_event_open, &attr, 0, -1, -1, 0); if (fd == -1) { - fprintf( + fmt::fprintf( stderr, "Failed to open instruction count event: %s.\n", c10::utils::str_error(errno).c_str()); @@ -54,7 +55,7 @@ static uint64_t end(int fd) { #else // Disable the event group if (ioctl(fd, PERF_EVENT_IOC_DISABLE, PERF_IOC_FLAG_GROUP) == -1) { - fprintf( + fmt::fprintf( stderr, "Error disabling perf event (fd: %d): %s\n", fd, @@ -67,7 +68,7 @@ static uint64_t end(int fd) { // Read results long ret_val = read(fd, &total_instructions, sizeof(total_instructions)); if (ret_val == -1) { - fprintf( + fmt::fprintf( stderr, "Error reading perf event results: %s\n", c10::utils::str_error(errno).c_str()); diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 29079448abfae5..1ef0522d2175a4 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -958,7 +958,7 @@ torch._C._jit_set_fusion_strategy([ ]) ``` -This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occuring (even with dynamic-shape graphs - e.g. this could happen if ranks or dtypes vary), after 20 compilation attempts the graph executor will fall back to running the graph without any attempts to compile it. +This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occurring (even with dynamic-shape graphs - e.g. this could happen if ranks or dtypes vary), after 20 compilation attempts the graph executor will fall back to running the graph without any attempts to compile it. ### Pre-derivative Optimization ### diff --git a/torch/csrc/jit/README.md b/torch/csrc/jit/README.md index 83a2393a78623c..4d9c2d07f3d1db 100644 --- a/torch/csrc/jit/README.md +++ b/torch/csrc/jit/README.md @@ -26,5 +26,5 @@ A brief summary of the source tree: **Refer** to each folder for more in-depth documentation. Other relevant parts of the codebase not contained here: -- [aten/src/ATen/core](../../../aten/src/ATen/core): contains JIT code re-used by other elements of the +- [aten/src/ATen/core](../../../aten/src/ATen/core): contains JIT code reused by other elements of the runtime system (eager, mobile, etc.) diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index f5d86039ec0aa6..f508f3e5d522bb 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -147,7 +147,7 @@ struct TORCH_API GraphFunction : public Function { mutable std::array, SpecializationKey::TotalCount> optimized_graphs_; - // GraphFunctions are invokable from multiple threads, so this lock needs to + // GraphFunctions are invocable from multiple threads, so this lock needs to // be held when we're initializing graph executor for the first time or // computing the optimized graph. We're using reentrant mutex so that we don't // need to worry about causing a deadlock by calling one method from another diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index 28675e5bd059f5..d7ef14ddb193dd 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -67,7 +67,7 @@ struct TORCH_API Method : public torch::IMethod { private: void setArgumentNames(std::vector&) const override; - // Methods are uniqued onwed by a single module. This raw pointer allows + // Methods are uniqued owned by a single module. This raw pointer allows // looking up the module. ObjectPtr owner_; diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 8e9be1de48a5fd..52cec12fb85984 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -327,7 +327,7 @@ struct TORCH_API Module : public Object { // Map of function names to the traced inputs that they have been traced with c10::Dict traced_inputs_; - // Mutex to keep registring buffer or parameter thread safe. + // Mutex to keep registering buffer or parameter thread safe. std::shared_ptr register_mutex_ = std::make_shared(); }; diff --git a/torch/csrc/jit/backends/backend_debug_handler.cpp b/torch/csrc/jit/backends/backend_debug_handler.cpp index 6c2ba467bc6b2f..0d410341303955 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.cpp +++ b/torch/csrc/jit/backends/backend_debug_handler.cpp @@ -26,7 +26,7 @@ int64_t BackendDebugInfoRecorder::getNextDebugHandle(const Node* node) { BackendDebugInfoMapType BackendDebugInfoRecorder::stopRecording() { // Note that this is return by copy and since // InlinedCallStackPtrs are intrusive ptr it will result in - // bump of refcount. Not performant, but this is not intented + // bump of refcount. Not performant, but this is not intended // to be used in perf critical path. // Alternate might be do move but that will be destructive return handles_to_inlined_callstack_ptrs_; diff --git a/torch/csrc/jit/backends/backend_debug_handler.h b/torch/csrc/jit/backends/backend_debug_handler.h index 4128832e7a078a..2e0145b56c294c 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.h +++ b/torch/csrc/jit/backends/backend_debug_handler.h @@ -18,7 +18,7 @@ namespace torch::jit { * Effectively debug handles are something that is given to backend and later * when an exception occurs in the backend, backend can tell, using debug * handle, that an exception occurred here. Then the runtime can generate - * callstack correspoding to the exception. + * callstack corresponding to the exception. * There are two parts to BackendDebugHandleManager: * 1. static std::atomic debug_handle * 2. Map of [debug-handle, DebugInfoTuple] diff --git a/torch/csrc/jit/backends/backend_exception.h b/torch/csrc/jit/backends/backend_exception.h index d964f1bfcf0086..807ef38e283054 100644 --- a/torch/csrc/jit/backends/backend_exception.h +++ b/torch/csrc/jit/backends/backend_exception.h @@ -16,13 +16,13 @@ class TORCH_API BackendRuntimeException : public c10::Error { } // If rethrowing, can push another debug_handle // This is useful in couple of scenarios. - // 1. A submodule is lowered and lite interperter has CallMethod + // 1. A submodule is lowered and lite interpreter has CallMethod // to lowered module's method. In this case lowered module will throw with // a handle, plus there will be another debug handle corresponding // to the CallMethod node in lite interpreter. Both together give complete // trace. This function allows lite interpreter to rethrow with debug // handle it has for CallMethod. - // 2. Another scenarios is when lite interperter can make function calls or + // 2. Another scenarios is when lite interpreter can make function calls or // the lowered backend also has function call ability. Thus we have // multiple function frames. Now we need a stack of handles to symbolicate // entire stack trace. diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h index 05594983d246b0..4d8fe049134fe1 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -37,7 +37,7 @@ class XNNSerializer { // Serialize add node, we are serializing the argument needed to call // xnn_define_add2. Serializing these values, and at run time we build - // teh graph by re running xnn_define_add2 + // the graph by re running xnn_define_add2 void serializeAddNode( uint32_t input1_id, uint32_t input2_id, diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index 645653964794e7..0428ac370b728b 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -34,7 +34,7 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::impl::GenericDict method_compile_spec) override { auto dict = processed.toGenericDict(); - // Compiling and wrapping exeuction object + // Compiling and wrapping execution object const std::string& ser_model = dict.at("ser_model").toStringRef(); XNNExecutor executor; XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor); diff --git a/torch/csrc/jit/codegen/cuda/README.md b/torch/csrc/jit/codegen/cuda/README.md index 284fd14111962b..a68bc0491919b9 100644 --- a/torch/csrc/jit/codegen/cuda/README.md +++ b/torch/csrc/jit/codegen/cuda/README.md @@ -78,7 +78,7 @@ Graph print out is straight forward and you should look for `prim::CudaFusionGro return (%o.5) ``` -Note that one thing that could prevents fusion when you are running training is autodiff. Fusion pass only runs within `prim::DifferentiableGraph`, so the first thing you should check is to that targetted ops are within differentiable graph subgraphs. +Note that one thing that could prevents fusion when you are running training is autodiff. Fusion pass only runs within `prim::DifferentiableGraph`, so the first thing you should check is to that targeted ops are within differentiable graph subgraphs. Graph dump could be quite confusing to look at, since it naively dumps all graphs executed by profiling executor and differentiable graphs are executed via a nested graph executor. So for each graph, you might see a few segmented `Optimized Graph` where each corresponds to a differentiable node in the original graph. #### 2. Cuda Fusion Graphs diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 04700b905f41bd..2f1e7e8e95059f 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -635,7 +635,7 @@ std::string generateKernel( } // Includes headers - // Note: CUDA kernels support halfs and random generation, CPU kernels do not + // Note: CUDA kernels support Halfs and random generation, CPU kernels do not if (has_half_tensor) { env.s("HalfHeader", cuda::half_support_literal); } else { diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index fa104b7cc16bff..67c4501dc2758a 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -28,7 +28,7 @@ static std::optional> getMapSize( // exactly how much storage do we need, so this could be fixed in-place at // every step. We're just missing a few functions for ATen, but the fix // should be straightforward. - // Note: left unitialized since empty shape is broadcastable to any shape + // Note: left uninitialized since empty shape is broadcastable to any shape std::vector map_size; map_size.reserve(8); for (const auto arg_idx : arg_subset) { @@ -201,7 +201,7 @@ static void launchFusion( for (const auto& c : fusion.concatDesc()) flat_outputs_size += c.nSubTensors(); - // Fails if the elements of the first (any) tensor are not expressable as + // Fails if the elements of the first (any) tensor are not expressible as // a 32-bit integer. // Note: this code assumes that inputs are 32-bit addressable // Note: this code assumes that all inputs are of the same size diff --git a/torch/csrc/jit/codegen/fuser/fused_kernel.h b/torch/csrc/jit/codegen/fuser/fused_kernel.h index de00904a749c8e..0f785c45066095 100644 --- a/torch/csrc/jit/codegen/fuser/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -40,7 +40,7 @@ struct FusedKernel { // CUDA code), and the remainder are pointers to the TensorInfo structs // that compiled code uses to load Tensor data. // launch_with_tensors handles packing at::Tensors into this arguments array. - // CPU code uses the same convension so that launch_with_tensors can be + // CPU code uses the same convention so that launch_with_tensors can be // shared. virtual void launch_raw(const uint32_t numel, std::vector& arguments) const = 0; diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index cc72489cec5984..2ef9f3cfa955cc 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -70,7 +70,7 @@ Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) { // third_party/ideep/mkl-dnn/src/interface/op_def.hpp. Operator LlgaGraphHelper::createOperator(Node* node) { auto nodeKind = node->kind(); - // we're using an if-else clause instead of a switch staement + // we're using an if-else clause instead of a switch statement // because we would soon be adding custom ops with function schemas. // We would have to use Symbol::fromQualString at that time anyway, // but we are okay with this choice, since this code is not in the hot-path. diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index 6b9c6a6c64a926..c5421643e8c43e 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -84,9 +84,9 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) { for (const auto i : c10::irange(nGraphInputs_)) { auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]); initializedInputIds_.insert(spec.tid()); - int64_t occurence = tensorIdToOccurence[spec.tid()]; - inputSpecs.insert(inputSpecs.end(), occurence, spec); - runArgsIdx_.insert(runArgsIdx_.end(), occurence, i); + int64_t occurrence = tensorIdToOccurence[spec.tid()]; + inputSpecs.insert(inputSpecs.end(), occurrence, spec); + runArgsIdx_.insert(runArgsIdx_.end(), occurrence, i); } GRAPH_DEBUG("Initializing constant input tensors"); initializeConstantInputs(); diff --git a/torch/csrc/jit/docs/serialization.md b/torch/csrc/jit/docs/serialization.md index 3fb463c7e7fe37..43f7e261f02079 100644 --- a/torch/csrc/jit/docs/serialization.md +++ b/torch/csrc/jit/docs/serialization.md @@ -371,7 +371,7 @@ TorchScript class, or a `ScriptModule`. Owns other its attribute types **`Object`**: An instance of a particular class. Own the `CompilationUnit` that owns its `ClassType`. This is to ensure that if the user passes the object around in C++, all its code will stay around and methods will be -invokable. +invocable. **`Module`**: A view over a `ClassType` and the `Object` that holds its state. Also responsible for turning unqualified names (e.g. `forward()`) into diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index fec41092df2355..2225f58e54e75e 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -103,10 +103,10 @@ struct BuiltinFunctionRegistry { // re-lock, the mutex without waiting), and report no loaded builtins during // init. std::lock_guard guard(mutex); - if (state == INTIIALIZING) { + if (state == INITIALIZING) { return empty; } else if (state == UNINITIALIZED) { - state = INTIIALIZING; + state = INITIALIZING; loadBuiltinFunctions(); state = INITIALIZED; } @@ -168,10 +168,16 @@ struct BuiltinFunctionRegistry { loadSource(aten_ops_additional, "aten"); // These are under `prim` instead of `aten` since they exist to bind certain - // tensor property getters to correpsonding methods + // tensor property getters to corresponding methods loadSource(tensor_properties, "prim"); } - enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED; + enum { + UNINITIALIZED = 0, + INITIALIZING = 1, + // typo in the original code, keeping for compatibility + INTIIALIZING = 1, // codespell:ignore + INITIALIZED = 2 + } state = UNINITIALIZED; std::recursive_mutex mutex; std::vector> modules; std::unordered_map> builtins_by_name_; diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index 86b546a0a7b46e..48fc133fe3d044 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -333,7 +333,7 @@ struct ExitTransformer { std::vector exit_block_vals; // after an exit, the only values that will get used // are the hasExited() and exitValues(), so we match the existing - // block outputs with unitialized + // block outputs with uninitialized exit_block_vals = matchValuesWithUnitialized(block->outputs()); // Set the new if to have the same outputs of the original block, @@ -362,7 +362,7 @@ struct ExitTransformer { // break // j = j + 1 // where the j + 1 value will be a block output, but since they will - // never be used, it is safe to replace them with unitialized value + // never be used, it is safe to replace them with uninitialized value void destroyNodeAfterExit(Node* n) { for (auto output : n->outputs()) { if (!output->uses().empty()) { diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index c36678ef363ca1..3004562e9ff568 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -959,7 +959,7 @@ struct to_ir { emitDef( def, nullptr, - closure_block); // ignore schema return, we just wont use it for now + closure_block); // ignore schema return, we just won't use it for now // since we never create a Method for the closure }; auto closure_value = emitClosure(emit_body); @@ -1578,7 +1578,7 @@ struct to_ir { /*default_to_union=*/true, elem_type_hint); - // Case: The list comprehension generated heterogenous values, + // Case: The list comprehension generated heterogeneous values, // and we don't have a type hint to suggest that this is what the // user expected if (!type_hint && (*unified_elem_type)->isUnionType()) { @@ -1701,7 +1701,7 @@ struct to_ir { << "the first generated key was " << k->type()->repr_str()); } else if ( first_generated_key_type && first_generated_key_type != k->type()) { - // Values can be heterogenous, so we only need to check that the + // Values can be heterogeneous, so we only need to check that the // key types are all the same throw( ErrorReport(dc) @@ -2118,7 +2118,7 @@ struct to_ir { // Try to unify the types. If we found a type annotation earlier // in the environment, and if that type annotation is some form // of union, then we need to tell `unifyTypes` not to throw an - // error if the branched return types we found are heterogenous + // error if the branched return types we found are heterogeneous bool default_to_union = full_type && (full_type->kind() == UnionType::Kind || full_type->kind() == OptionalType::Kind || @@ -2440,7 +2440,7 @@ struct to_ir { SugaredValuePtr iterable = sv->iter(loc, method); // We unroll the loop for iterables that contain ModuleLists so that we can - // compile Heterogenous module lists. + // compile Heterogeneous module lists. if (!iterable->shouldEmitUnrolled()) { emitLoopCommon(loc, emit_body, iterable, targets, {}); } else { @@ -4260,7 +4260,7 @@ struct to_ir { } std::shared_ptr emitRpcExpr(const Apply& apply, Symbol rpc_op) { - // TODO: This is a temporary apporoach to enable calling user fucntion + // TODO: This is a temporary apporoach to enable calling user function // through RPC in TorchScript, // Ideally, function value in JIT IR is first-class citizen and // The RPC C++ entry API can take c10::Function directly. @@ -5399,7 +5399,7 @@ struct FunctionResolver : public Resolver { CompilationUnit::CompilationUnit(const std::string& source) : CompilationUnit() { - // calles the define with native resolver to generate the graph for functions + // calls the define with native resolver to generate the graph for functions define(std::nullopt, source, nativeResolver(), nullptr); } diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index 9708baed7da107..cbc22fab84e232 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -333,12 +333,12 @@ bool isBlockListedSchema(const FunctionSchema& schema) { // Currently JIT does not distinguish ScalarType vs int, so there is really // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to // hardcode the aten::view.dtype here to block this overload. This blocklist - // should be removed when JIT fully suports ScalarType as its own type. + // should be removed when JIT fully supports ScalarType as its own type. if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { return true; } // Note (@tugsbayasgalan) - // TorchScript doesn't suport kwargs so this op collides with aten.max.others + // TorchScript doesn't support kwargs so this op collides with aten.max.others // since both of them have 2 Tensor inputs. Since we don't expect users to // use this op in TS, we just skip it if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") { diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 4d7904b8707c27..31fc483812ab02 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -448,7 +448,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( } std::vector ScriptTypeParser::parseReturnFromDecl(const Decl& decl) { - // we represent no annoation on a return type as having no values in the + // we represent no annotation on a return type as having no values in the // schema's return() list // in emitReturn we take the actual return value to be the value of the // return statement if no one was provided here diff --git a/torch/csrc/jit/frontend/source_range.cpp b/torch/csrc/jit/frontend/source_range.cpp index 7bed3bb453032f..89815d386ac05f 100644 --- a/torch/csrc/jit/frontend/source_range.cpp +++ b/torch/csrc/jit/frontend/source_range.cpp @@ -42,12 +42,12 @@ size_t StringCordView::find(const std::string& tok, size_t start) const { size_t offset = start; for (; begin != end_iter; ++begin, ++offset) { if (*begin == tok[0]) { - auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end()); - if (mis.second == tok.end()) { + auto mismatch = std::mismatch(begin, end_iter, tok.begin(), tok.end()); + if (mismatch.second == tok.end()) { // no mismatch, and second string (tok) is exhausted. return offset; } - if (mis.first == end_iter) { + if (mismatch.first == end_iter) { // this str is exhausted but tok is not return std::string::npos; } @@ -312,7 +312,7 @@ void SourceRange::print_with_context( } out << "\n"; } - // print out inital context + // print out initial context out << str.substr(begin_context, start() - begin_context); size_t line_start = start(); size_t line_end = range_end; diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 04ba980bb4e169..d88e77b16cd1be 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -118,7 +118,7 @@ struct TORCH_API SugaredValue // If we are iterating over a Sugared Value and it returns a value from this // function, then we emit an unrolled loop over the variable. This allows us - // to support containers of Heterogenous types, like Module Containers & + // to support containers of Heterogeneous types, like Module Containers & // Tuples virtual std::optional staticLen() { return std::nullopt; @@ -140,7 +140,7 @@ struct TORCH_API SugaredValue << " object is not iterable"); } - // expression for ith elemement for iterable value + // expression for ith element for iterable value virtual std::shared_ptr getitem( const SourceRange& loc, GraphFunction& m, @@ -297,7 +297,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { return shared_from_this(); } - // Because this is used to contain SugaredValues of Heterogenous types, + // Because this is used to contain SugaredValues of Heterogeneous types, // we define staticLen() so that when this is iterated over it is emitted // as an unrolled loop. std::optional staticLen() override { @@ -319,7 +319,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { GraphFunction& m, const std::string& field) override { if (field == "autograd") { - // When refering torch.autograd, it is also considered to be a + // When referring torch.autograd, it is also considered to be a // BuiltinModule and we will dispatch to the aten operators for the // methods under its module. return std::make_shared("aten", version); @@ -331,12 +331,12 @@ struct TORCH_API BuiltinModule : public SugaredValue { private: std::string name; - // when we add operator versioning, emit this op as it exising at 'version' + // when we add operator versioning, emit this op as it existing at 'version' // if not set, use the latest version std::optional version; }; -// Represents a class, analagous to `int` or `dict`. Instances of classes, +// Represents a class, analogous to `int` or `dict`. Instances of classes, // like `1` or `{"foo": 5}`, are represented as SimpleValues struct TORCH_API ClassValue : public SugaredValue { explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {} diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 145dd02f6c7b96..3cfa77ef05cca1 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -557,7 +557,7 @@ void TracingState::setValue(const IValue& v, Value* value) { // If the value comes from a CallFunction or CallMethod, it may not have // shape information attached. For debuggability, we enhance the type - // information by assigning the concrete value's tupe to the jit::Value. + // information by assigning the concrete value's type to the jit::Value. if (auto tensor_type = value->type()->cast()) { if (!tensor_type->isComplete()) { value->inferTypeFrom(var); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index ff2c4127ff85c6..16edf669da9be1 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -53,7 +53,7 @@ class MutableTypePtrHelper { // Tensor with shape information removed. For example, a Tensor // of dimension 4 would map to the same type as a Tensor of // dimension 1. This allows us to treat all subclasses of Tensor - // as a single, homogenous "Tensor" type. + // as a single, homogeneous "Tensor" type. std::optional mapTypeToAliasTypeSet(const TypePtr& type) { if (mutable_type_cache_) { const AliasTypeSet* result = mapTypeToBorrowedAliasTypeSet(type); diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index f83c96a2da186f..497412c6476e5e 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -48,7 +48,7 @@ class ValueAndMemoryLocationSet; * * `descendFunctionCalls` - recursively analyze function and method calls * instead of conservative analysis. Generally analysis should be done after - * inlining so the implmentation for recursive analysis is unoptimized. + * inlining so the implementation for recursive analysis is unoptimized. */ class AliasDb { public: @@ -102,7 +102,7 @@ class AliasDb { // Do any nodes write to an alias set output by `n`? TORCH_API bool hasOutputWriters(const Node* n) const; - // Do any nodes write to an alias set inputed/outputed by `n`? + // Do any nodes write to an alias set inputted/outputted by `n`? TORCH_API bool hasWriters(const Node* n) const; // Do any nodes write to `v`s memory location? @@ -338,7 +338,7 @@ TORCH_API void Lint(const AliasDb* db); * * The AliasDb must not be mutated after construction of a * ValueAndMemoryLocationsSet, or else the MemoryLocations stored in the * ValueAndMemoryLocationSet will no longer be accurate. - * * A ValueAndMemoryLocationsSet is tied to an instsance of AliasDb but + * * A ValueAndMemoryLocationsSet is tied to an instance of AliasDb but * does not own the AliasDb. It is the user's responsibility to ensure * that the AliasDb outlives the ValuesAndMemoryLocationsSet. * diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index a4e13f0f6f12f8..fea29767d2653d 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1143,7 +1143,7 @@ bool Node::isNondeterministic() const { if (!kind().is_aten()) { return false; } - // All aten ops are expecte to have a schema. However this is left as a + // All aten ops are expected to have a schema. However this is left as a // warning instead of an assert to ensure that previous use cases do not // break. if (!schema) { @@ -1648,7 +1648,7 @@ Block* Node::findCommonAncestorBlockWith(Node* n) { n2 = n2->owningBlock()->owningNode(); } - // Now they are the same numer of blocks from the graph block, + // Now they are the same number of blocks from the graph block, // recurse upwards, checking if they are on the same block while (true) { if (n1->owningBlock() == n2->owningBlock()) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index fc780c26c3dd90..c3b4f455d576b7 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -616,7 +616,7 @@ struct TORCH_API Node { // as the equivalents phi-nodes in standard SSA form, // defining a new Value to represent any term that has multiple // definitions depending on how control flowed. Outputs of the node containing - // control flow serve a similiar purpose defining new values for variables + // control flow serve a similar purpose defining new values for variables // that would have different definitions depending on which way control // flowed. @@ -1374,7 +1374,7 @@ struct Graph : std::enable_shared_from_this { // kwargs using Python argument matching rules, and checks that the op matches // a known schema. // - // If this node successfully completes, it guarentees the node + // If this node successfully completes, it guarantees the node // is a correctly-formed invocation of opname TORCH_API Value* insert( Symbol opname, diff --git a/torch/csrc/jit/ir/ir_views.h b/torch/csrc/jit/ir/ir_views.h index 224754ab840bbc..94aec3bde85ae0 100644 --- a/torch/csrc/jit/ir/ir_views.h +++ b/torch/csrc/jit/ir/ir_views.h @@ -143,7 +143,7 @@ struct LoopView { private: Node* node_; - // adjust index_ordering by adding indices 0 - thorugh adjust, and + // adjust index_ordering by adding indices 0 - thorough adjust, and // incrementing all existing inputs by adjust static std::vector adjustIndices( size_t adjust, diff --git a/torch/csrc/jit/ir/irparser.h b/torch/csrc/jit/ir/irparser.h index ed2a62dd8d536f..9b256b71487f64 100644 --- a/torch/csrc/jit/ir/irparser.h +++ b/torch/csrc/jit/ir/irparser.h @@ -13,7 +13,7 @@ struct Value; // \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH. // if parse_tensor_constants is true will construct empty tensors -// for Tensor constants with random or unitialized contents, otherwise will +// for Tensor constants with random or uninitialized contents, otherwise will // throw TORCH_API void parseIR( const std::string& str, @@ -25,7 +25,7 @@ TORCH_API void parseIR( * \p VMAP is filled with String to Value pairs allowing to index Values in the * newly created graph by their name in the original IR string. * if parse_tensor_constants is true will construct empty tensors - * for Tensor constants with random or unitialized contents, otherwise will + * for Tensor constants with random or uninitialized contents, otherwise will * throw */ TORCH_API void parseIR( diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 372ca2fcf4976c..1551e610c3d108 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -16,7 +16,7 @@ namespace torch::jit { namespace { bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { - // type_equal doesnt distinguish between mkldnn/pytorch cpu tensors, + // type_equal doesn't distinguish between mkldnn/pytorch cpu tensors, // and we dont want to coalesce mkldnn tensors bc they do layout // transformations based on usage if (lhs.is_mkldnn() || rhs.is_mkldnn()) { diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index af37cc75f9877b..51baee8e277c11 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -208,7 +208,7 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { }; // {source range, node name, InlinedCallStack} -// We store node name because same debug infor will be used for +// We store node name because same debug info will be used for // profiling as well, so we need to know op names as well. using DebugInfoTuple = std::tuple; diff --git a/torch/csrc/jit/ir/subgraph_matcher.h b/torch/csrc/jit/ir/subgraph_matcher.h index 5ace4983de6f02..91e170c052750d 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.h +++ b/torch/csrc/jit/ir/subgraph_matcher.h @@ -11,7 +11,7 @@ namespace torch::jit { * \brief A structure describing a match of a pattern in a graph. * * The structure contains an anchor node, from which the match was found, and - * match-maps for nodes and values. A match-map specifies the correspondance + * match-maps for nodes and values. A match-map specifies the correspondence * between nodes in the pattern graph (match-map keys) with nodes in the actual * graph (match-map values). We keep such maps for both nodes and values. */ @@ -38,7 +38,7 @@ struct Match { * graph are ignored during matching (IOW, we're essentially performing DCE on * the pattern). * - Pattern graph nodes cannot alias. TODO: the check not implemented yet. - * - Aliasing nodes in the graph cannot consitute a match (i.e. through all + * - Aliasing nodes in the graph cannot constitute a match (i.e. through all * found matches, no nodes in the subgraph alias with each other). TODO: check * not implemented yet. * - The matcher will not mutate either the pattern graph or the matched graph. diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 3d71329e61240d..4422608423ee7d 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -125,7 +125,7 @@ void write_archive_current( std::string fname = tensor_dir + tensor_names[i++]; if (use_storage_context && pre_serialized_files.find(fname) != pre_serialized_files.end()) { - // storage has been serialzed already, skip + // storage has been serialized already, skip continue; } writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); @@ -230,7 +230,7 @@ std::stringstream update_bytecode_version( How to add backport_v{i}_to_v{i-1} ? There are two options: - 1) [Format change only, recommended] Constrcut a reader with the + 1) [Format change only, recommended] Construct a reader with the input_model_stream, modify the file, and use PyTorchWriter to write it to output_model_stream. See backport_v5_to_v4. @@ -322,7 +322,7 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) { // The export function to generate bytecode.pkl for version 4. After bytecode // version bump, the old export function doesn't exist anymore, so keep a copy - // here for backport pupose. + // here for backport purpose. auto writeArchiveV4 = [](PyTorchStreamWriter& writer, const std::string& archive_name, const c10::IValue& value) { @@ -502,7 +502,7 @@ std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) { torch::jit::load(input_model_stream, std::nullopt, extra_files); std::stringstream intermediate_model_stream; // TODO(@pavithran) : Check if debug info is available and use load/save while - // backporting hardcode debaug info to be false untill supported. + // backporting hardcode debaug info to be false until supported. bool hasBytecodeDebug = false; { BytecodeEmitModeGuard argNumGuard( diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index 92b5164f52b132..8d847ddeb533f6 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -393,7 +393,7 @@ ModelCompatCheckResult is_compatible( OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name); // If the runtime op has no schema information its a false alarm and isn't - // actually useable + // actually usable if (!runtime_op_info.num_schema_args.has_value()) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index eeb8f48ee747ac..0a410a42fef041 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -76,7 +76,7 @@ std::pair, std::string> getStackTraceWithModuleHierarchy // This function construct stacktrace with module hierarchy // Module hierarchy will contain information about where in the // module hierarchy this source is. For example if conv2d op -// exist in hierarcy A->B->C->Conv2d with type annotations of +// exist in hierarchy A->B->C->Conv2d with type annotations of // A -> TopM, B->MyModule, C->SomeModule, then module hierarchy // will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv) // Source level stack information will be from model source code. diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h index 0cf4b42508b26b..14e1b1e4e7cd18 100644 --- a/torch/csrc/jit/mobile/debug_info.h +++ b/torch/csrc/jit/mobile/debug_info.h @@ -14,7 +14,7 @@ namespace torch::jit { * exception of BackendRuntimeException should raised using debug handles. * getSourceDebugString method is responsible for translating debug * handles to correspond debug information. - * This debug informatin includes stack trace of model level source code and + * This debug information includes stack trace of model level source code and * module hierarchy where the exception occurred. */ class MobileDebugTable { diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index c1e062edf98348..24c670e01f79b2 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -48,7 +48,7 @@ using ExtraFilesMap = std::unordered_map; // shared_ptr overload of this function. // // If should_copy_tensor_memory is true, then the returned module will NOT have -// refences to `data`, so `data` can be freed immediately. +// references to `data`, so `data` can be freed immediately. // // If should_copy_tensor_memory is false, then returned module will have tensors // that points inside of `data`; the caller will need to make sure that `data` @@ -93,7 +93,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( // // This function does steps 1+2+3 described above. // -// We need to have this as a convienience because Python API will need to wrap +// We need to have this as a convenience because Python API will need to wrap // this. C++ clients should use one of the versions of // parse_and_initialize_mobile_module() so they can manage the raw data more // directly. @@ -110,7 +110,7 @@ TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer( char* flatbuffer_content); // The methods below are less efficient because it need to read the stream in -// its entirity to a buffer +// its entirety to a buffer TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( std::istream& in, std::optional device = std::nullopt, diff --git a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp index 6770c6e1cb66c9..b02c7ef74096ab 100644 --- a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp +++ b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp @@ -105,7 +105,7 @@ std::unordered_map MobileModelRunner:: function_and_info_dict[key.toStringRef()] = data_list; } - // Could store the full mapping of std types, but the 'info' section isnt + // Could store the full mapping of std types, but the 'info' section isn't // needed here std::string input_function = function_and_info_dict["get_inputs_function_name"][0]; diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index 117b8b595daa8c..4acfb041fc41ff 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -38,7 +38,7 @@ class TORCH_API KinetoEdgeCPUProfiler { * * Thus, when KinetoEdgeCPUProfiler is used as RAII to do profiling * within certain scope. In that scope, the captured reference to - * Module will outlive KinetoEdgeCPUProfiler. This is gauranteed because + * Module will outlive KinetoEdgeCPUProfiler. This is guaranteed because * KinetoEdgeCPUProfiler must be constructed later than Module, on stack. * * An example of the anti-pattern and wrong usage is: diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index a3de4f2d3c8302..f9287a5eb7040c 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -36,7 +36,7 @@ TypeParser::TypeParser(std::vector& pythonStrs) // instruction. In nested type, the lowest level type will be at the beginning // of the type list. It is possible to parse it without worrying about // ordering, but it also introduces 1) extra cost to process nested type to -// the correct order 2) lost the benifit that the instruction order is likely +// the correct order 2) lost the benefit that the instruction order is likely // problematic if type list parsing fails. std::vector TypeParser::parseList() { std::vector typePtrs; diff --git a/torch/csrc/jit/operator_upgraders/README.md b/torch/csrc/jit/operator_upgraders/README.md index ce995276d283a2..60558a308110ba 100644 --- a/torch/csrc/jit/operator_upgraders/README.md +++ b/torch/csrc/jit/operator_upgraders/README.md @@ -11,7 +11,7 @@ You can determine if your change in the operator is BC breaking, if it fails `te ### Some examples BC breaking changes -When making changes to the operators, the first thing to identify is if it's BC/FC breaking. Again, we only targetting for BC breaking changes on this guidance. Here are some examples to help understanding what a BC changes may look like: +When making changes to the operators, the first thing to identify is if it's BC/FC breaking. Again, we only targeting for BC breaking changes on this guidance. Here are some examples to help understanding what a BC changes may look like: #### Backward Compatibility Breakage: @@ -32,7 +32,7 @@ When making changes to the operators, the first thing to identify is if it's BC/ ### 1.Preparation -[Build PyTorch from souce](https://github.com/pytorch/pytorch#from-source) and prepare a test model before making changes to the operator, following the process below. A test model before making the operator changes is needed to test the upgrader. Otherwise, after the change to operator, the new runtime will no longer be able to produce a model with the historic operator and can't test it anymore. +[Build PyTorch from source](https://github.com/pytorch/pytorch#from-source) and prepare a test model before making changes to the operator, following the process below. A test model before making the operator changes is needed to test the upgrader. Otherwise, after the change to operator, the new runtime will no longer be able to produce a model with the historic operator and can't test it anymore. 1. Add a test module in `test/jit/fixtures_srcs/fixtures_src.py`. In `test/jit/fixtures_srcs/generate_models.py`, ``` diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index f406ed079c484e..38e4b5068e2ffd 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -319,7 +319,7 @@ static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) { } static bool shape_is_fast_for_side(const at::Tensor& other_side_input) { - // Cutoff chosed by benchmarking on a TITAN V + // Cutoff chose by benchmarking on a TITAN V return other_side_input.numel() <= 1024 * 2048; } diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index fca8381d4d776d..1cc849d4a3cd75 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -96,7 +96,7 @@ static bool isBefore(Node* n1, Node* n2) { } } - // Now they are the same numer of blocks from the graph block, + // Now they are the same number of blocks from the graph block, // recurse upwards, checking if they are on the same block while (true) { if (n1->owningBlock() == n2->owningBlock()) { diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index 2e2daaa11a0c3f..680f7683009c88 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -98,7 +98,7 @@ void InplaceMKLDNNSubgraph(const std::shared_ptr& graph) { // This function first calculates aliasing sets, // then calculates the last node each aliasing set is alive for. // Then we go through each node, if it's a node which has an equivalent - // inplace node and the aliasing set for its input is dead afer this node, we + // inplace node and the aliasing set for its input is dead after this node, we // inplace it. Then we merge the aliasing sets for the input and output of the // node and extend the liveness of the set. To inplace a node you need to // prove device and dtype of the input and output are the same, which we've @@ -812,7 +812,7 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) { if (body_node->kind() == aten::conv2d || body_node->kind() == aten::conv3d) { - // this node doesnt handle string padding yet... + // this node doesn't handle string padding yet... if (!body_node->namedInput("padding")->type()->cast()) { body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution")); body_node->destroy(); diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 35692506d8cb27..a0e6babe54b62d 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -167,7 +167,7 @@ std::shared_ptr ToONNX( ConstantValueMap::ClearMaps(); auto new_graph = std::make_shared(graph->current_scope()); py::dict env; - // Kept identical to values in env. Used for constant-time existance check. + // Kept identical to values in env. Used for constant-time existence check. py::set values_in_env; try { BlockToONNX( @@ -448,7 +448,8 @@ void NodeToONNX( std::ostringstream ss; ss << "Error casting results of symbolic for " << op_name << ": expected to return list of op nodes, instead received type ''" - << py::str(raw_output.get_type()) << "': " << py::str(raw_output); + << py::str(py::type::handle_of(raw_output)) + << "': " << py::str(raw_output); throw std::runtime_error(ss.str()); } diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index fedafdce57874b..2687ee9fb07dc7 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -399,7 +399,7 @@ static void InferShapeTypeForUninitializedOutput( } else { const_node->t_(attr::value, at::zeros({}, elem_type)); const_node->output()->setType( - TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {})); + TensorType::create(output_type->scalarType(), at::kCPU, {}, {})); } } else if (auto output_type = other_output->type()->cast()) { TypePtr elem = output_type->getElementType(); diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h index 1f69cb8def1163..8d05fbe9426519 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h @@ -17,7 +17,7 @@ namespace torch::jit { // information. Shape and type information is only available after // _jit_pass_onnx, which converts aten nodes to onnx nodes. So there is a // interdependent issue. _jit_pass_onnx depends on preprocess passes to convert -// aten nodes into convertable condition, and preprocess passes depend on +// aten nodes into convertible condition, and preprocess passes depend on // _jit_pass_onnx to convert upstream nodes and apply onnx shape inference. // Separating the pass into two parts breaks the interdependency. // diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index 3a0f889a728a3c..36d6884637d2a1 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -116,7 +116,7 @@ static std::vector _single_input_general_shape_aten_funcs = { "__getitem__", }; -// Theses are prim::CallFunctions for ops that doesn't require observation and +// These are prim::CallFunctions for ops that doesn't require observation and // have a single input Tensor // Also these ops do computation on the value of Tensor // TODO: [Need verify] looks like we can quantize simple functionals that just @@ -136,7 +136,7 @@ static std::vector _single_input_general_value_call_funcs = { "leaky_relu", }; -// Theses are aten functions for ops that doesn't require observation and +// These are aten functions for ops that doesn't require observation and // have a single input Tensor // Also these ops do computation on the value of Tensor // e.g. `aten::avg_pool2d(%input_tensor, ...)` diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index 4a0d600ca1b947..5fab235044453a 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -1702,7 +1702,7 @@ Module InsertObserversForOnDevicePTQ( // you will have multiple getattrs for the same attribute and thus potentially // multiple observers observing the same value. This will also lead to // increased size of the packed param struct. I dont expect this to be a - // common pattern but something to be aware fo Note that current quant + // common pattern but something to be aware of Note that current quant // workflow does not prevent this anyway since during inset quant dequant // things are inlined anyway helper.fillBoundaryValueMap(cloned_module, observer_method_name); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 8739c4fcaf424d..2e39bf67bf5f33 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1622,7 +1622,7 @@ void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps( void InsertQuantDeQuantHelper::runForOnDevicePTQ( Module& module, const std::string& method_name) { - // In all likelihood this really wont do anything because we expect that + // In all likelihood this really won't do anything because we expect that // the input method for quantization's prepare step will be inlined. Thus // only call methods we will see will belong to observer's forward calls. for (auto& invoked_methods : getInvokedMethods(module, method_name)) { @@ -1834,8 +1834,8 @@ Module InsertQuantDeQuantOnDevicePTQ( // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's // quant dequant RemoveRedundantQuantizationOps: THis is removing activation // observers for dynamic quant when the op related to it is not dynamically - // quantizable. Doesnt really make sense. In our case we wont have those - // anyway since for dynamic quant activations wont be observed We can still + // quantizable. Doesn't really make sense. In our case we won't have those + // anyway since for dynamic quant activations won't be observed We can still // use this function because the above two methods should really be a noop h.propagateQuantizationOps(module); return module; diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index 549741ac6ed903..86d7b5857c49c2 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -206,7 +206,7 @@ QuantFusionInfo getFixedQParamOpFusionInfo( %r = )"; op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")"; // IR pattern common to all ops with fixed quantization parameters for - // asymetric quantization + // asymmetric quantization std::string asym_fixed_qparam_op_suffix = R"( %r_scale : float = prim::Constant[value=0.00390625]() %r_zero_point : int = prim::Constant[value=0]() diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.h b/torch/csrc/jit/passes/symbolic_shape_cache.h index d9c7f66ee66258..4d0f1bdcd62986 100644 --- a/torch/csrc/jit/passes/symbolic_shape_cache.h +++ b/torch/csrc/jit/passes/symbolic_shape_cache.h @@ -8,7 +8,7 @@ namespace torch::jit { struct TORCH_API CanonicalizedSymbolicShape { // TODO: Consider in the future if it is reasonable to // merge code with SymbolicShape or VaryingShape while keeping - // the two not implicitly convertable (and cause bugs). + // the two not implicitly convertible (and cause bugs). CanonicalizedSymbolicShape( const c10::SymbolicShape& orig_shape, std::unordered_map& ss_map) { diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index cd0ec4e3207f23..1471546092230c 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -396,7 +396,7 @@ void insertTypeGuard( namespace { bool has_unsupported_pin_memory(const Node* node) { - // cant support non-constant pin_memory or pin_memory = True + // can't support non-constant pin_memory or pin_memory = True if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) { int index = *maybe_index; auto inp = node->input(index); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h index 4a72b8d409b141..c9007c82b95e5a 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.h +++ b/torch/csrc/jit/passes/tensorexpr_fuser.h @@ -66,7 +66,7 @@ TORCH_API bool isSupported(Node* node); /// work with dynamic shapes unless explicitly register the shape function via /// `torch::jit::RegisterShapeComputeGraphForSchema` for the custom operator. /// -/// @return Reference of the custome operator set +/// @return Reference of the custom operator set /// TORCH_API OperatorSet& getCustomOperatorSet(); diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 8fd18e4717e28e..f9fd65f9ce5411 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -62,7 +62,7 @@ struct ValueMapper { auto new_outputs = merged_node->outputs(); for (Value* v : new_outputs) { auto maybe_last_use = firstOrLastUse(v, /*find_first*/ false); - // if it doesnt have a use it shouldnt have been added as output + // if it doesn't have a use it shouldn't have been added as output TORCH_INTERNAL_ASSERT(maybe_last_use); const Use last_use = *maybe_last_use; diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index d1bddf370ba91d..254162764afa44 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1263,6 +1263,16 @@ void initJITBindings(PyObject* module) { [](const c10::SymNode& a, const char* file, int64_t line) { return a->guard_size_oblivious(file, line); }) + .def( + "guard_or_false", + [](const c10::SymNode& a, const char* file, int64_t line) { + return a->guard_or_false(file, line); + }) + .def( + "guard_or_true", + [](const c10::SymNode& a, const char* file, int64_t line) { + return a->guard_or_true(file, line); + }) .def( "has_hint", [](const c10::SymNode& a) { @@ -2306,7 +2316,7 @@ void initJITBindings(PyObject* module) { // Throw errors when calling wait() on the returned Future if // any of the original futures would throw. // NB: PythonFutureWrapper takes an unwrap_func which serves as a - // callback to evalute the value in the Future. RPC uses this + // callback to evaluate the value in the Future. RPC uses this // unwrap_func to check whether the returned py::object is a // RemoteException object, and re-throw the exception if it is. // By extracting the c10::ivalue::Future from PythonFutureWrapper diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 3802546420af4c..3f2708619be86d 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -313,7 +313,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { bool is_symbolic = false; for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; - if (torch::is_symint(elm)) { + if (torch::is_symint(elm) || THPVariable_Check(elm.ptr())) { is_symbolic = true; break; } @@ -468,8 +468,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { } else { // We inspect the value to found the compiled TorchScript class // and then create a ivalue::Object from that class type. - py::str qualified_name = py::module::import("torch._jit_internal") - .attr("_qualified_name")(obj.get_type()); + py::str qualified_name = + py::module::import("torch._jit_internal") + .attr("_qualified_name")(py::type::handle_of(obj)); auto pyCu = get_python_cu(); classType = pyCu->get_class(c10::QualifiedName(qualified_name)); if (!classType) { @@ -808,7 +809,7 @@ std::pair, Stack> getOpWithStack( } // This function is used to check if the schema is valid for the given args and -// kwargs. It checks script object by checking wether the FakeScriptObject is +// kwargs. It checks script object by checking whether the FakeScriptObject is // an instance of the corresponding fake class for the actual class used in // schema. bool checkSchemaAllowFakeScriptObject( diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index bbac829782dcb5..f80ae1b9481c43 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -433,20 +433,22 @@ inline InferredType tryToInferType(py::handle input) { } py::bool_ isClass = - py::module::import("inspect").attr("isclass")(input.get_type()); + py::module::import("inspect").attr("isclass")(py::type::handle_of(input)); if (py::cast(isClass)) { // Assume that the class is compiled already or will compile. Invalidate // this later if needed. bool class_compiled = true; // Check if the type is already compiled. - py::object existing_ty = py::module::import("torch.jit._state") - .attr("_get_script_class")(input.get_type()); + py::object existing_ty = + py::module::import("torch.jit._state") + .attr("_get_script_class")(py::type::handle_of(input)); if (existing_ty.is_none()) { // If not, try to compile it. - py::bool_ can_compile = py::module::import("torch._jit_internal") - .attr("can_compile_class")(input.get_type()); + py::bool_ can_compile = + py::module::import("torch._jit_internal") + .attr("can_compile_class")(py::type::handle_of(input)); if (py::cast(can_compile)) { // Try to compile the class. This is wrapped in a try-catch because @@ -456,7 +458,7 @@ inline InferredType tryToInferType(py::handle input) { try { py::module::import("torch.jit._script") .attr("_recursive_compile_class")( - input.get_type(), SourceRange()); + py::type::handle_of(input), SourceRange()); } catch (...) { // Invalidate the assumption that the class compiled so that we don't // look up and return its JIT type as the type for the input. @@ -468,8 +470,9 @@ inline InferredType tryToInferType(py::handle input) { // If the class compiled successfully, look up the existing JIT type by // qualified name and return it. if (class_compiled) { - auto script_class = py::module::import("torch.jit._state") - .attr("_get_script_class")(input.get_type()); + auto script_class = + py::module::import("torch.jit._state") + .attr("_get_script_class")(py::type::handle_of(input)); if (!script_class.is_none()) { auto class_type = py::cast(script_class); @@ -642,18 +645,18 @@ inline InferredType tryToInferContainerType( "are supported ", "as inputs or outputs of traced functions", ", but instead got value of type ", - py::str(input.get_type().attr("__name__")), + py::str(py::type::handle_of(input).attr("__name__")), ".")); } else { // TODO: this message is not correct anymore, since this InferredType is - // used from a bunch of circumstances unrelated to tracing. We can re-use + // used from a bunch of circumstances unrelated to tracing. We can reuse // this instead of the attribute_failure stuff in concreteType return InferredType(c10::str( "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ", "are supported ", "as inputs or outputs of traced functions", ", but instead got value of type ", - py::str(input.get_type().attr("__name__")), + py::str(py::type::handle_of(input).attr("__name__")), ".")); } } @@ -780,7 +783,7 @@ inline std::string friendlyTypeName(py::handle obj) { auto field_names = py::cast>(py::getattr(obj, "_fields")); std::stringstream ss; - ss << py::str(obj.get_type().attr("__name__")); + ss << py::str(py::type::handle_of(obj).attr("__name__")); ss << " (aka NamedTuple("; bool first = true; for (auto& field_name : field_names) { @@ -793,7 +796,7 @@ inline std::string friendlyTypeName(py::handle obj) { ss << "))"; return ss.str(); } else { - return py::str(obj.get_type().attr("__name__")); + return py::str(py::type::handle_of(obj).attr("__name__")); } } @@ -841,7 +844,7 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) { " expected value of type ", type->str(), " for return value but instead got value of type ", - py::str(object.get_type().attr("__name__")), + py::str(py::type::handle_of(object).attr("__name__")), ".", "\nValue: ", py::repr(object), diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index b3ea056e0b2962..32ba91df0ab34f 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -66,7 +66,7 @@ void initPythonCustomClassBindings(PyObject* module) { return ScriptClassFunctionPtr(fn); } - throw AttributeError("%s does not exist", name.c_str()); + throw AttributeError(fmt::format("{} does not exist", name)); }) .def_property_readonly("__doc__", [](const ScriptClass& self) { return self.class_type_.type_->expectRef().doc_string(); diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index 1d44282d59d678..73297c3ac07949 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -99,7 +99,7 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { py_obj_.ptr() = nullptr; } - // explicit construction to avoid errornous implicit conversion and + // explicit construction to avoid erroneous implicit conversion and // copy-initialization explicit ConcretePyObjectHolder(py::object py_obj) : py_obj_(std::move(py_obj)) {} diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 21ba03f128f84b..b9db0be814e45e 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -21,7 +21,7 @@ namespace torch::jit { std::string typeString(py::handle h) { - return py::str(h.get_type().attr("__name__")); + return py::str(py::type::handle_of(h).attr("__name__")); } std::optional as_function(const py::object& obj) { @@ -1223,7 +1223,7 @@ std::shared_ptr toSugaredValue( obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { return SpecialFormValue::create(prim::isinstance); #ifdef USE_RPC - // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. + // RPC module is only available when build flag "USE_DISTRIBUTED" is on. } else if ( isRpcAvailable && obj.ptr() == @@ -1236,7 +1236,7 @@ std::shared_ptr toSugaredValue( return SpecialFormValue::create(prim::rpc_sync); } else if ( isRpcAvailable && - // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. + // RPC module is only available when build flag "USE_DISTRIBUTED" is on. obj.ptr() == py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index 15cc2445fd56b0..c00eefa20df031 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -68,7 +68,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { ErrorReport(loc) << kind() << " cannot be used as a value. " << "Perhaps it is a closed over global variable? If so, please " - << "consider passing it in as an argument or use a local varible " + << "consider passing it in as an argument or use a local variable " << "instead."); } diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 154b80ee3e6b27..81da1605fcbe2b 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -89,7 +89,7 @@ std::pair, Stack> createGraphByTracingWithDict( }; // The argument_names parameter is parsed in python and its order - // is the same as the arguments' decalaration order in forward() method. + // is the same as the arguments' declaration order in forward() method. // These name shall be added to the graph as debug name and the order // should align with the traceable stack we generated by the python dict. std::vector compact_argument_names; diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 44837f4da93a89..b9fbf4d1ec30f7 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -272,7 +272,7 @@ static void checkMutableFunctionDefault( << "Mutable default parameters are not supported because Python binds them to the function" << " and they persist across function calls.\n As a workaround, make the default None and instantiate" << " the default parameter within the body of the function. Found " - << def_arg.get_type() << " on parameter " << arg.name()); + << py::type::handle_of(def_arg) << " on parameter " << arg.name()); } } @@ -1430,7 +1430,7 @@ void initJitScriptBindings(PyObject* module) { return StrongFunctionPtr(std::move(self), fn); } else { throw AttributeError( - "'CompilationUnit' has no attribute '%s'", name.c_str()); + fmt::format("'CompilationUnit' has no attribute '{}'", name)); } }) .def( @@ -1485,12 +1485,30 @@ void initJitScriptBindings(PyObject* module) { "__call__", [](py::args args, const py::kwargs& kwargs) { HANDLE_TH_ERRORS - // see: [pybind11 varargs] auto strongPtr = py::cast(args[0]); - Function& callee = *strongPtr.function_; - py::object result = invokeScriptFunctionFromPython( - callee, tuple_slice(std::move(args), 1), kwargs); - return result; + if (py::module::import("torch") + .attr("compiler") + .attr("is_exporting")() + .cast()) { + TORCH_INTERNAL_ASSERT( + py::hasattr(args[0], py::str("_torchdynamo_inline")), + "During PT2 exporting, we encountered TorchScripted function", + strongPtr.function_->name(), + "When tracing through it, we cannot find its _torchdynamo_inline attribute, ", + "which stores non scripted kcallable. ", + "Please file an issue to PyTorch if you see this error."); + + // remove the function itself with args[1:] + py::slice slice0(1, args.size(), 1); + return args[0].attr("_torchdynamo_inline")( + *args[slice0], **kwargs); + } else { + // see: [pybind11 varargs] + Function& callee = *strongPtr.function_; + py::object result = invokeScriptFunctionFromPython( + callee, tuple_slice(std::move(args), 1), kwargs); + return result; + } END_HANDLE_TH_ERRORS_PYBIND }) .def( @@ -1590,12 +1608,29 @@ void initJitScriptBindings(PyObject* module) { .def( "__call__", [](py::args args, const py::kwargs& kwargs) { - // see: [pybind11 varargs] HANDLE_TH_ERRORS - Method& method = py::cast(args[0]); - - return invokeScriptMethodFromPython( - method, tuple_slice(std::move(args), 1), kwargs); + if (py::module::import("torch") + .attr("compiler") + .attr("is_exporting")() + .cast() && + // TODO: fix all cases where ScriptMethod doesn't have + // __wrapped__, which is the non-scripted original method. E.g. + // it seems the top-level script module's scriptMethod doesn't + // have __wrapped__ attributes: + // class M(torch.nn.Module): + // def forward(self, x): + // return x.cos() + x.sin() + // traced_module = torch.jit.trace(M(), example_inputs=inps) + // , where traced_module.forward is a ScriptMethod but doesn't + // have __wrapped__. + py::hasattr(args[0], "__wrapped__")) { + return args[0].attr("__wrapped__")(*args, **kwargs); + } else { + // see: [pybind11 varargs] + Method& method = py::cast(args[0]); + return invokeScriptMethodFromPython( + method, tuple_slice(std::move(args), 1), kwargs); + } END_HANDLE_TH_ERRORS_PYBIND }) .def_property_readonly("graph", &Method::graph) diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 98e5f4bd00214c..9d4d681f8b32f5 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -55,7 +55,7 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_jit_enable_expanded_stacks, false, - "When true we will attemps to pre-expand node stacks and cache expanded stacks.") + "When true we will attempts to pre-expand node stacks and cache expanded stacks.") C10_DEFINE_bool( torch_jit_expanded_stacks_mangled, diff --git a/torch/csrc/jit/runtime/jit_exception.h b/torch/csrc/jit/runtime/jit_exception.h index cb4f572a8bd3c0..580febe465ff2a 100644 --- a/torch/csrc/jit/runtime/jit_exception.h +++ b/torch/csrc/jit/runtime/jit_exception.h @@ -18,7 +18,7 @@ struct TORCH_API JITException : public std::runtime_error { return python_class_name_; } - // the original msg if this is from a python exception. The interpretor has + // the original msg if this is from a python exception. The interpreter has // changed the original message by adding "The following operation failed in // the TorchScript interpreter." in front of it in the handleError function. std::optional getOriginalMsg() const { diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 636686507cffc4..d59b93190e36a8 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -115,8 +115,8 @@ bool isSortableListOfObjectsOrTuples( } auto type = ivalues.get(0).type(); - // We assume lists have homogenous types, use first element to determine - // best sorting methods. If in the future we need to support heterogenous + // We assume lists have homogeneous types, use first element to determine + // best sorting methods. If in the future we need to support heterogeneous // types inside list, then sorting needs to have runtime sortable checks. const size_t n = ivalues.size(); for (const auto i : c10::irange(n)) { @@ -1141,7 +1141,7 @@ static const std::vector opGenArgs{ // // create a clone of these declarations with a _hacked_twin overload name // and nullability scrubbed from TensorList arg types - // TOOD find out why this exists and how to do it without the hack + // TODO find out why this exists and how to do it without the hack // OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( @@ -2839,7 +2839,7 @@ void hashValue(Stack& stack) { } static const std::vector opGenArgs2{ - // registered as Any[] so that heterogenous tuples can be called with len() + // registered as Any[] so that heterogeneous tuples can be called with len() OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::len.any(Any[] a) -> int"), listLen, diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index f163f310cbf4f9..d77e0b3a10d643 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -3204,7 +3204,7 @@ def _batch_norm_with_update(input: List[int], )=====") + std::string(R"=====(def broadcast_inplace(a: List[int], b: List[int]) -> List[int]: - _0 = "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) " + _0 = "The dims of tensor b ({}) must be less than or equal to the dims of tensor a ({}) " _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" dimsA = torch.len(a) dimsB = torch.len(b) diff --git a/torch/csrc/jit/runtime/static/README.md b/torch/csrc/jit/runtime/static/README.md index 9b72db912684a8..ba5e057ca1ec88 100644 --- a/torch/csrc/jit/runtime/static/README.md +++ b/torch/csrc/jit/runtime/static/README.md @@ -71,7 +71,7 @@ Runtime instances in your code. Static runtime's memory planner does two things: 1) Coalesces internal allocations for tensor storage -2) Does static analysis to figure out how to efficiently re-use memory. +2) Does static analysis to figure out how to efficiently reuse memory. ### Standard Resizing Static runtime will record the space required for each intermediate managed tensor it sees diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 0e2a89544b5613..78378b04b4a620 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -982,7 +982,8 @@ void check_type(const Argument& schema_arg, const IValue& arg) { return; } TORCH_CHECK( - arg.type()->isSubtypeOf(schema_arg.type()), + arg.type()->isSubtypeOf(schema_arg.type()) || + arg.type()->isSubtypeOfExt(schema_arg.type(), /*why_not=*/nullptr), arg.type()->annotation_str(), " is not a subtype of ", schema_arg.type()->annotation_str(), diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index b4138c590e837a..f92a28d5d6cf0d 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -70,7 +70,7 @@ TORCH_API inline bool borrowsOutputs(c10::Symbol kind) { // The output aliases that end up here are as a result of aliasDb failing to // recognize them as outputs due to collection object (e.g., Tuple) aliasing // inputs. -// Values that dont't show up in output_aliases or external_aliases are created +// Values that don't show up in output_aliases or external_aliases are created // and consumed within the graph. class ValueGroup { public: @@ -111,7 +111,7 @@ class TORCH_API ManagedTensorRanges { // If true, then this node is the last use of at least one // managed tensor. availableTensorValuesAfterNode(node) will return a vector - // of the managed tensors that are available for re-use + // of the managed tensors that are available for reuse // in the nodes following this one. bool nodeFreesManagedTensors(Node* node) const; const std::vector& availableTensorValuesAfterNode( @@ -141,7 +141,7 @@ class TORCH_API ManagedTensorRanges { void extendInputLifetime(Node* node, size_t new_end); // Maps Node* to the set of managed tensors that are now available - // for re-use after this node. + // for reuse after this node. c10::FastMap> node_to_newly_free_tensors_{}; // Maps each Value* to its lifetime (start node index, end node index) c10::FastMap value_lifetimes_{}; diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp index a73f630a651ae2..8660183867e08e 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.cpp +++ b/torch/csrc/jit/runtime/static/memory_planner.cpp @@ -76,7 +76,7 @@ std::vector assignStorageToManagedTensors( // This set maps each Value* to its assigned storage group. c10::FastMap storage_group_mapping; // On each iteration, this vector stores the set of storage groups that - // are available for re-use. + // are available for reuse. std::vector free_storage_groups; auto makeNewStorageGroup = [&](const Value* value) { diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 367eb490a294a0..716202f45687a8 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -529,7 +529,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator { const auto in1_i = p_node->Input(1).toOptional(); const auto in2_i = p_node->Input(2).toBool(); const auto in3_i = p_node->Input(3).toBool(); - // To mimick the behavior of the JIT interpreter, if both dtype + // To mimic the behavior of the JIT interpreter, if both dtype // and copy are not set, we return self. Otherwise, we assume // that dtype is set. if (!in1_i && !in3_i) { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 527e8e9cee432b..6184889e5f10e8 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +88,8 @@ namespace { namespace onnx_torch = ::torch::onnx; namespace onnx = ::ONNX_NAMESPACE; -const static int kInvalidOpsetVersion = -1; -const static int kMainOpsetVersion = 20; +constexpr int kInvalidOpsetVersion = -1; +constexpr int kMainOpsetVersion = 23; // Based on OP_SET_ID_VERSION_MAP in // https://github.com/onnx/onnx/blob/master/onnx/helper.py. constexpr static std::array @@ -114,6 +115,9 @@ constexpr static std::array 8, // opset 18 9, // opset 19 9, // opset 20 + 10, // opset 21 + 10, // opset 22 + 11, // opset 23 }; std::string getNodeStackTraceString(const Node* n) { diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 8b2d6d84716ae5..b8746d07224120 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -215,7 +214,7 @@ struct TORCH_API BytecodeEmitMode { // true: instruction of default argument values (like LOADC) is emitted. // false: instruction of default argument values are not emitted. Instead // they are fetched from operator schema. -// default_args_before_out_args (to forward compatibile support +// default_args_before_out_args (to forward compatible support // operators allowing out arguments and default arguments): // true: the number of specified arguments will deserialized to (#all_args - // #default_args). false: the number of specified arguments will deserialized to diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index d7f5914c0d0400..e0ded27d375b17 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -131,7 +131,7 @@ std::string get_named_tuple_str_or_default( // str() return "Tensor" and repr_str() return "Tensor (inferred)". If // it's not inferred type, str() return "Tensor[]" and repr_str() // return "Tensor". In cpp, repr_str() will always return "Tensor" - // regardless inferred type. When exporing custom type in bytecode, + // regardless inferred type. When exporting custom type in bytecode, // "Tensor" is the preferred way to deserialize Tensor type std::string named_tuple_type_str = it->is_inferred_type() ? named_tuple_type->str() @@ -554,7 +554,7 @@ void ScriptModuleSerializer::writeArchive( } WriteableTensorData writable_td = getWriteableTensorData(td); if (use_storage_context && serialized_tensors.count(tensor_name)) { - // storage has been serialzed already, skip + // storage has been serialized already, skip continue; } writer_.writeRecord( @@ -698,10 +698,10 @@ void ScriptModuleSerializer::writeByteCode( // debug handles. // The reason we save debug handles conditionally is so that // we dont end up with a model that has debug handles but has not - // debug map to correlate debug handels with. + // debug map to correlate debug handles with. // Once we have a model with both handles and debug map, we can // strip off debug map and have a lean model served to production. - // If exception ocurrs we have a model with debug map that can be + // If exception occurs we have a model with debug map that can be // used to symbolicate debug handles writeArchive( debug_info_telements, diff --git a/torch/csrc/jit/serialization/mobile_bytecode_generated.h b/torch/csrc/jit/serialization/mobile_bytecode_generated.h index cffe8bc7a63645..b61fad2ab7aefd 100644 --- a/torch/csrc/jit/serialization/mobile_bytecode_generated.h +++ b/torch/csrc/jit/serialization/mobile_bytecode_generated.h @@ -8,9 +8,9 @@ // Ensure the included flatbuffers.h is the same version as when this file was // generated, otherwise it may not be compatible. -static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && - FLATBUFFERS_VERSION_MINOR == 3 && - FLATBUFFERS_VERSION_REVISION == 3, +static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 12 && + FLATBUFFERS_VERSION_REVISION == 23, "Non-compatible flatbuffers version included"); namespace torch { @@ -2597,3 +2597,4 @@ inline void FinishSizePrefixedModuleBuffer( } // namespace torch #endif // FLATBUFFERS_GENERATED_MOBILEBYTECODE_TORCH_JIT_MOBILE_SERIALIZATION_H_ +// @generated diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 6ce524293a7072..2dc3f138ff76d3 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -1,20 +1,20 @@ +#include +#include + #include #include -#ifdef USE_RPC -#include -#endif #include + #include #include #include #include -#include -#include +#ifdef USE_RPC +#include +#endif namespace torch::jit { -using ::c10::IValue; - // Protocol 2 is the highest that can be decoded by Python 2 // See https://docs.python.org/3/library/pickle.html#data-stream-format constexpr static uint8_t PROTOCOL_VERSION = 2; @@ -719,92 +719,4 @@ void Pickler::pushTuple(const IValue& ivalue) { } } -WriteableTensorData getWriteableTensorData( - const at::Tensor& tensor, - bool to_cpu) { - WriteableTensorData result; - result.tensor_ = tensor; - result.size_ = tensor.storage().nbytes(); - // TODO HIP support - if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) { - // NB: This new tensor is created to support cuda tensors. - // Storages can be mutated when converting tensors from cuda to cpu, - // and we need a cpu tensor to copy data from. - result.tensor_ = - at::empty({0}, tensor.options()) - .set_( - tensor.storage(), - /* storage_offset = */ 0, - /* size = */ - {static_cast( - tensor.storage().nbytes() / tensor.element_size())}, - /* stride = */ {1}) - .cpu(); - TORCH_CHECK( - result.tensor_.storage().nbytes() == result.size_, - "Storage tensor size did not match record size"); - } - return result; -} - -bool checkHasValidSetGetState(const std::shared_ptr& cls) { - // Check that the schemas for __getstate__ and __setstate__ are correct - auto getstate = cls->findMethod("__getstate__"); - if (getstate == nullptr) { - return false; - } - auto get_schema = getstate->getSchema(); - - // Check __getstate__ - // __getstate__ is expected to be (self) -> T - TORCH_CHECK( - get_schema.arguments().size() == 1, - "'__getstate__' must have 'self' as its only argument, but found ", - get_schema.arguments().size(), - " arguments"); - TORCH_CHECK( - get_schema.returns().size() == 1, - "'__getstate__' must return 1 value, but found ", - get_schema.returns().size()); - - // Check __setstate__ if the method exists - // __setstate__ is expected to be (self, T) -> None - auto setstate = cls->findMethod("__setstate__"); - if (!setstate) { - return false; - } - auto set_schema = setstate->getSchema(); - - TORCH_CHECK( - set_schema.arguments().size() == 2, - "'__setstate__' must have 'self' and the state as its " - "only arguments, but found ", - set_schema.arguments().size(), - " arguments"); - TORCH_CHECK( - set_schema.returns().size() == 1, - "'__setstate__' must return None, but found ", - set_schema.returns().size(), - " return values"); - TORCH_CHECK( - set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()), - "'__setstate__' must return None, but found value of type", - set_schema.returns().at(0).type()->annotation_str()); - - // Check that the return type of __getstate__ matches the input to - // __setstate__ - auto get_type = get_schema.returns().at(0).type(); - auto set_type = set_schema.arguments().at(1).type(); - - TORCH_CHECK( - get_type->isSubtypeOf(*set_type), - "'__getstate__'s return type (", - get_type->annotation_str(), - ") does not match '__setstate__'s argument type (", - set_type->annotation_str(), - ")"); - - return true; -} - } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index c4d7af12fa5183..526c840bc10e80 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -9,115 +8,17 @@ #include #include #include +#include #include #include #include #include +#include namespace torch::jit { -// See Python's pickletools.py for a detailed description of each of these codes -enum class PickleOpCode : char { - MARK = '(', - STOP = '.', - POP = '0', - POP_MARK = '1', - DUP = '2', - FLOAT = 'F', - INT = 'I', - BININT = 'J', - BININT1 = 'K', - LONG = 'L', - BININT2 = 'M', - NONE = 'N', - PERSID = 'P', - BINPERSID = 'Q', - REDUCE = 'R', - STRING = 'S', - BINSTRING = 'T', - SHORT_BINSTRING = 'U', - // NB: Avoid using UNICODE as it is a macro in the Windows API - UNICODE_ = 'V', - BINUNICODE = 'X', - APPEND = 'a', - BUILD = 'b', - GLOBAL = 'c', - DICT = 'd', - EMPTY_DICT = '}', - APPENDS = 'e', - GET = 'g', - BINGET = 'h', - INST = 'i', - LONG_BINGET = 'j', - LIST = 'l', - EMPTY_LIST = ']', - OBJ = 'o', - PUT = 'p', - BINPUT = 'q', - LONG_BINPUT = 'r', - SETITEM = 's', - TUPLE = 't', - EMPTY_TUPLE = ')', - SETITEMS = 'u', - BINFLOAT = 'G', - - // Protocol 2 - PROTO = char('\x80'), - NEWOBJ = '\x81', - EXT1 = '\x82', - EXT2 = '\x83', - EXT4 = '\x84', - TUPLE1 = '\x85', - TUPLE2 = '\x86', - TUPLE3 = '\x87', - NEWTRUE = '\x88', - NEWFALSE = '\x89', - LONG1 = '\x8a', - LONG4 = '\x8b', - - // Protocol 3 (Python 3.x) - BINBYTES = 'B', - SHORT_BINBYTES = 'C', - - // Protocol 4 - SHORT_BINUNICODE = char('\x8c'), - BINUNICODE8 = '\x8d', - BINBYTES8 = '\x8e', - EMPTY_SET = '\x8f', - ADDITEMS = '\x90', - FROZENSET = '\x91', - NEWOBJ_EX = '\x92', - STACK_GLOBAL = '\x93', - MEMOIZE = '\x94', - FRAME = '\x95' -}; - using ::c10::IValue; -struct WriteableTensorData { - const char* data() const { - return static_cast(tensor_.storage().data()); - } - size_t sizeInBytes() const { - return size_; - } - size_t nbytes() const { - return tensor_.storage().nbytes(); - } - bool storageHasDeleter() const { - return tensor_.storage().data_ptr().get_context() != nullptr; - } - - private: - friend TORCH_API WriteableTensorData - getWriteableTensorData(const at::Tensor& tensor, bool to_cpu); - at::Tensor tensor_; - uint64_t size_; -}; - -void setTypeTags(bool state); -bool getTypeTags(); - class TORCH_API Pickler { AT_DISALLOW_COPY_AND_ASSIGN(Pickler); @@ -281,145 +182,4 @@ class TORCH_API Pickler { bool tag_aggregates_; }; -// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor -// if it was CUDA and to_cpu is True. -TORCH_API WriteableTensorData -getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true); - -// return the value of the tensor's storage pointer -uint64_t getStorageKey(const at::Tensor& tensor); - -// if the cls has __getstate__/__setstate__ -// assert they have the right schema and return true, -// otherwise return false -bool checkHasValidSetGetState(const std::shared_ptr& cls); - -// Declare BackendMeta serialization and deserialization function pointer types. -using BackendMetaPtr = std::function< - void(const at::Tensor&, std::unordered_map&)>; - -// A allowlist of device type, currently available is PrivateUse1 -inline std::unordered_set& GetBackendMetaAllowlist() { - static std::unordered_set DeviceTypeAllowlist{ - c10::DeviceType::PrivateUse1}; - return DeviceTypeAllowlist; -} - -// Dynamically obtain serialization function pairs -// that require the corresponding backend. -inline std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization() { - // The array to save function pointer for BackendMeta serialization. - // key is the DeviceType, value is std::pair obj. - // value.first represent get function and value.seconde represent set function - static std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES> - BackendMetaSerialization; - return BackendMetaSerialization; -} - -// Register function pointer of Tensor BackendMetadata for serialization. -TORCH_API inline void TensorBackendMetaRegistry( - c10::DeviceType t, - const BackendMetaPtr& get_fptr, - const BackendMetaPtr& set_fptr) { - // allowlist verification - // Only if the devicetype is in the allowlist, - // we allow the serialization extension to be registered for backendmeta data. - const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist(); - TORCH_CHECK( - DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(), - "It is not allowed to register the serialization method ", - "of backendMeta data for PrivateUse1. ", - "If you have related serialization requirements, ", - "please expand the allowlist"); - // Register function pointer - int device_type = static_cast(t); - auto& BackendMetaSerialization = GetBackendMetaSerialization(); - TORCH_CHECK( - !BackendMetaSerialization[device_type].has_value(), - "The tensor BackendMeta serialization function pointer for ", - t, - " has been registered."); - BackendMetaSerialization[device_type] = - std::optional>( - std::make_pair(get_fptr, set_fptr)); -} - -// Return a map of Tensor Metadata which including BackendMetaData for -// serialization. For now, it only takes care of `conj` and `neg` bit. -inline std::unordered_map getTensorMetadata( - const at::Tensor& t) { - // We don't support serializing `ZeroTensor` as it is not public - // facing yet. - TORCH_CHECK( - !t._is_zerotensor(), - "ZeroTensor is not serializable,", - " please file an issue if required."); - std::unordered_map metadata{}; - - // Only add meta-data if the value is not default. - if (t.is_conj()) { - metadata["conj"] = true; - } - if (t.is_neg()) { - metadata["neg"] = true; - } - // Only add BackendMetaData for custom backend if the function pointer is - // registered. - int device_type = static_cast(t.device().type()); - const auto& BackendMetaSerialization = GetBackendMetaSerialization(); - if (BackendMetaSerialization[device_type].has_value()) { - // Pass the tensor and metadata map references as parameters to the custom - // serialization function. - BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first; - fptr(t, metadata); - } - return metadata; -} - -// set Tensor Metadata based on the map. -// Refer: getTensorMetadata -inline void setTensorMetadata( - const at::Tensor& t, - std::unordered_map metadata) { - auto iter_end = metadata.end(); - auto iter_temp = metadata.find("conj"); - if (iter_temp != iter_end) { - t._set_conj(true); - metadata.erase(iter_temp); - } - iter_temp = metadata.find("neg"); - if (iter_temp != iter_end) { - t._set_neg(true); - metadata.erase(iter_temp); - } - // Only set BackendMetaData for custom backend if the function pointer is - // registered. - int device_type = static_cast(t.device().type()); - const auto& BackendMetaSerialization = GetBackendMetaSerialization(); - if (BackendMetaSerialization[device_type].has_value()) { - // Pass the tensor and metadata map references as parameters to the custom - // deserialization function. - BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second; - fptr(t, metadata); - } -} - -// set Tensor metadata based on the map. -// NOTE: This overload is required by unpickler.cpp -inline void setTensorMetadata( - const at::Tensor& t, - const c10::Dict& metadata_idict) { - std::unordered_map metadata; - for (auto& pair : metadata_idict) { - auto key = *pair.key().toString(); - metadata[key] = pair.value().toBool(); - } - setTensorMetadata(t, std::move(metadata)); -} - } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler_helper.cpp b/torch/csrc/jit/serialization/pickler_helper.cpp new file mode 100644 index 00000000000000..261ae15d36e0a9 --- /dev/null +++ b/torch/csrc/jit/serialization/pickler_helper.cpp @@ -0,0 +1,117 @@ +#include +#include + +#include +#include + +namespace torch::jit { + +WriteableTensorData getWriteableTensorData( + const at::Tensor& tensor, + bool to_cpu) { + WriteableTensorData result; + result.tensor_ = tensor; + result.size_ = tensor.storage().nbytes(); + // TODO HIP support + if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) { + // NB: This new tensor is created to support cuda tensors. + // Storages can be mutated when converting tensors from cuda to cpu, + // and we need a cpu tensor to copy data from. + result.tensor_ = + at::empty({0}, tensor.options()) + .set_( + tensor.storage(), + /* storage_offset = */ 0, + /* size = */ + {static_cast( + tensor.storage().nbytes() / tensor.element_size())}, + /* stride = */ {1}) + .cpu(); + TORCH_CHECK( + result.tensor_.storage().nbytes() == result.size_, + "Storage tensor size did not match record size"); + } + return result; +} + +bool checkHasValidSetGetState(const std::shared_ptr& cls) { + // Check that the schemas for __getstate__ and __setstate__ are correct + auto getstate = cls->findMethod("__getstate__"); + if (getstate == nullptr) { + return false; + } + auto get_schema = getstate->getSchema(); + + // Check __getstate__ + // __getstate__ is expected to be (self) -> T + TORCH_CHECK( + get_schema.arguments().size() == 1, + "'__getstate__' must have 'self' as its only argument, but found ", + get_schema.arguments().size(), + " arguments"); + TORCH_CHECK( + get_schema.returns().size() == 1, + "'__getstate__' must return 1 value, but found ", + get_schema.returns().size()); + + // Check __setstate__ if the method exists + // __setstate__ is expected to be (self, T) -> None + auto setstate = cls->findMethod("__setstate__"); + if (!setstate) { + return false; + } + auto set_schema = setstate->getSchema(); + + TORCH_CHECK( + set_schema.arguments().size() == 2, + "'__setstate__' must have 'self' and the state as its " + "only arguments, but found ", + set_schema.arguments().size(), + " arguments"); + TORCH_CHECK( + set_schema.returns().size() == 1, + "'__setstate__' must return None, but found ", + set_schema.returns().size(), + " return values"); + TORCH_CHECK( + set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()), + "'__setstate__' must return None, but found value of type", + set_schema.returns().at(0).type()->annotation_str()); + + // Check that the return type of __getstate__ matches the input to + // __setstate__ + auto get_type = get_schema.returns().at(0).type(); + auto set_type = set_schema.arguments().at(1).type(); + + TORCH_CHECK( + get_type->isSubtypeOf(*set_type), + "'__getstate__'s return type (", + get_type->annotation_str(), + ") does not match '__setstate__'s argument type (", + set_type->annotation_str(), + ")"); + + return true; +} + +std::unordered_set& GetBackendMetaAllowlist() { + static std::unordered_set DeviceTypeAllowlist{ + c10::DeviceType::PrivateUse1}; + return DeviceTypeAllowlist; +} + +std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES>& +GetBackendMetaSerialization() { + // The array to save function pointer for BackendMeta serialization. + // key is the DeviceType, value is std::pair obj. + // value.first represent get function and value.seconde represent set function + static std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES> + BackendMetaSerialization; + return BackendMetaSerialization; +} + +} // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler_helper.h b/torch/csrc/jit/serialization/pickler_helper.h new file mode 100644 index 00000000000000..c1ac5f6feb0325 --- /dev/null +++ b/torch/csrc/jit/serialization/pickler_helper.h @@ -0,0 +1,232 @@ +#pragma once + +#include + +#include +#include + +namespace torch::jit { + +// See Python's pickletools.py for a detailed description of each of these codes +enum class PickleOpCode : char { + MARK = '(', + STOP = '.', + POP = '0', + POP_MARK = '1', + DUP = '2', + FLOAT = 'F', + INT = 'I', + BININT = 'J', + BININT1 = 'K', + LONG = 'L', + BININT2 = 'M', + NONE = 'N', + PERSID = 'P', + BINPERSID = 'Q', + REDUCE = 'R', + STRING = 'S', + BINSTRING = 'T', + SHORT_BINSTRING = 'U', + // NB: Avoid using UNICODE as it is a macro in the Windows API + UNICODE_ = 'V', + BINUNICODE = 'X', + APPEND = 'a', + BUILD = 'b', + GLOBAL = 'c', + DICT = 'd', + EMPTY_DICT = '}', + APPENDS = 'e', + GET = 'g', + BINGET = 'h', + INST = 'i', + LONG_BINGET = 'j', + LIST = 'l', + EMPTY_LIST = ']', + OBJ = 'o', + PUT = 'p', + BINPUT = 'q', + LONG_BINPUT = 'r', + SETITEM = 's', + TUPLE = 't', + EMPTY_TUPLE = ')', + SETITEMS = 'u', + BINFLOAT = 'G', + + // Protocol 2 + PROTO = char('\x80'), + NEWOBJ = '\x81', + EXT1 = '\x82', + EXT2 = '\x83', + EXT4 = '\x84', + TUPLE1 = '\x85', + TUPLE2 = '\x86', + TUPLE3 = '\x87', + NEWTRUE = '\x88', + NEWFALSE = '\x89', + LONG1 = '\x8a', + LONG4 = '\x8b', + + // Protocol 3 (Python 3.x) + BINBYTES = 'B', + SHORT_BINBYTES = 'C', + + // Protocol 4 + SHORT_BINUNICODE = char('\x8c'), + BINUNICODE8 = '\x8d', + BINBYTES8 = '\x8e', + EMPTY_SET = '\x8f', + ADDITEMS = '\x90', + FROZENSET = '\x91', + NEWOBJ_EX = '\x92', + STACK_GLOBAL = '\x93', + MEMOIZE = '\x94', + FRAME = '\x95' +}; + +struct WriteableTensorData { + const char* data() const { + return static_cast(tensor_.storage().data()); + } + size_t sizeInBytes() const { + return size_; + } + size_t nbytes() const { + return tensor_.storage().nbytes(); + } + bool storageHasDeleter() const { + return tensor_.storage().data_ptr().get_context() != nullptr; + } + + private: + friend TORCH_API WriteableTensorData + getWriteableTensorData(const at::Tensor& tensor, bool to_cpu); + at::Tensor tensor_; + uint64_t size_; +}; + +// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor +// if it was CUDA and to_cpu is True. +TORCH_API WriteableTensorData +getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true); + +// if the cls has __getstate__/__setstate__ +// assert they have the right schema and return true, +// otherwise return false +bool checkHasValidSetGetState(const std::shared_ptr& cls); + +// Declare BackendMeta serialization and deserialization function pointer types. +using BackendMetaPtr = std::function< + void(const at::Tensor&, std::unordered_map&)>; + +// A allowlist of device type, currently available is PrivateUse1 +TORCH_API std::unordered_set& GetBackendMetaAllowlist(); + +// Dynamically obtain serialization function pairs +// that require the corresponding backend. +TORCH_API std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES>& +GetBackendMetaSerialization(); + +// Return a map of Tensor Metadata which including BackendMetaData for +// serialization. For now, it only takes care of `conj` and `neg` bit. +inline std::unordered_map getTensorMetadata( + const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); + std::unordered_map metadata{}; + + // Only add meta-data if the value is not default. + if (t.is_conj()) { + metadata["conj"] = true; + } + if (t.is_neg()) { + metadata["neg"] = true; + } + // Only add BackendMetaData for custom backend if the function pointer is + // registered. + int device_type = static_cast(t.device().type()); + const auto& BackendMetaSerialization = GetBackendMetaSerialization(); + if (BackendMetaSerialization[device_type].has_value()) { + // Pass the tensor and metadata map references as parameters to the custom + // serialization function. + BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first; + fptr(t, metadata); + } + return metadata; +} + +// set Tensor Metadata based on the map. +// Refer: getTensorMetadata +inline void setTensorMetadata( + const at::Tensor& t, + std::unordered_map metadata) { + auto iter_end = metadata.end(); + auto iter_temp = metadata.find("conj"); + if (iter_temp != iter_end) { + t._set_conj(true); + metadata.erase(iter_temp); + } + iter_temp = metadata.find("neg"); + if (iter_temp != iter_end) { + t._set_neg(true); + metadata.erase(iter_temp); + } + // Only set BackendMetaData for custom backend if the function pointer is + // registered. + int device_type = static_cast(t.device().type()); + const auto& BackendMetaSerialization = GetBackendMetaSerialization(); + if (BackendMetaSerialization[device_type].has_value()) { + // Pass the tensor and metadata map references as parameters to the custom + // deserialization function. + BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second; + fptr(t, metadata); + } +} + +// set Tensor metadata based on the map. +// NOTE: This overload is required by unpickler.cpp +inline void setTensorMetadata( + const at::Tensor& t, + const c10::Dict& metadata_idict) { + std::unordered_map metadata; + for (auto& pair : metadata_idict) { + auto key = *pair.key().toString(); + metadata[key] = pair.value().toBool(); + } + setTensorMetadata(t, std::move(metadata)); +} + +// Register function pointer of Tensor BackendMetadata for serialization. +TORCH_API inline void TensorBackendMetaRegistry( + c10::DeviceType t, + const BackendMetaPtr& get_fptr, + const BackendMetaPtr& set_fptr) { + // allowlist verification + // Only if the devicetype is in the allowlist, + // we allow the serialization extension to be registered for backendmeta data. + const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist(); + TORCH_CHECK( + DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(), + "It is not allowed to register the serialization method ", + "of backendMeta data for PrivateUse1. ", + "If you have related serialization requirements, ", + "please expand the allowlist"); + // Register function pointer + int device_type = static_cast(t); + auto& BackendMetaSerialization = GetBackendMetaSerialization(); + TORCH_CHECK( + !BackendMetaSerialization[device_type].has_value(), + "The tensor BackendMeta serialization function pointer for ", + t, + " has been registered."); + BackendMetaSerialization[device_type] = + std::optional>( + std::make_pair(get_fptr, set_fptr)); +} + +} // namespace torch::jit diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index ef130c1ed45199..70e188816fb4cb 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -212,7 +212,7 @@ struct PythonPrintImpl { // and would appear in the same order when the expression tree is // reparsed. // The last case can be checked - // because when we emit a expresion tree in the parser, + // because when we emit a expression tree in the parser, // we do a left-to-right postorder traversal of the expression tree (emit // children, then emit op). The reverse of this is a right-to-left preorder // traversal of the tree. By doing a right-to-left preorder traversal of the @@ -222,12 +222,12 @@ struct PythonPrintImpl { // expression. // The inductive step is that the right-most input should be produced by the - // node immediatly before the current node if it is in tree order. + // node immediately before the current node if it is in tree order. bool canInline(Value* v) { Node* n = v->node(); // there must be only 1 values, otherwise we need an assignment to handle - // the multiple outout values + // the multiple output values if (n->outputs().size() != 1) return false; // if it is used more than once, then we need a variable @@ -651,7 +651,7 @@ struct PythonPrintImpl { // [reordering of inlines] // We inline anything that is semantically legal to inline, but sometimes // we find that these lines get too long. In that case we break the lines - /// and it is important that we un-inline all the inputs preceeding the long + /// and it is important that we un-inline all the inputs preceding the long /// input: // r = foo(x.add_(b), some_long + expression) // wrong! @@ -1410,7 +1410,7 @@ struct PythonPrintImpl { enforce_importable_(enforce_importable) {} void printClass(const ClassTypePtr& classType) { - // If any of the methods are not Graph funtions, this indicates that + // If any of the methods are not Graph functions, this indicates that // this class is a custom-bound C++ class. Skip serialization // of this class, we will depend on the ClassType being defined // in the target process. diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 918470ddde3332..747dd8aae55d66 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,7 +5,6 @@ #endif #include #include -#include #include #include #include @@ -45,7 +44,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { to_process.pop_back(); // ensure we only scan each pointer value once, otherwise this // can become exponential (and if we allow recursive data in the future, - // it would not terminiate). + // it would not terminate). if (w.value.isPtrType()) { const void* key = w.value.internalToPointer(); auto it = scanned.find(key); @@ -491,7 +490,7 @@ PickleOpCode Unpickler::readInstruction() { stack_.size(), " and start index is ", start, - ", but stack_ is iterated by two elemenst at a time"); + ", but stack_ is iterated by two elements at a time"); for (size_t i = start; i < stack_.size(); i += 2) { dict.insert_or_assign(stack_[i], stack_[i + 1]); } diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 47c845cdd79648..d66cf23f4789f6 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -3,9 +3,10 @@ #include #include #include + #include #include -#include +#include namespace torch::jit { diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py index 5dcf1b28407dd3..6c8316cc9a4208 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py @@ -77,7 +77,7 @@ def gen_external(native_functions_path, tags_path, external_path): at::Tensor& r = tensors[0]; {nl.join(tensor_decls)} try {{ - at::{name}_out({', '.join(['r'] + arg_names)}); + at::{name}_out({", ".join(["r"] + arg_names)}); }} catch (...) {{ }} }}""" diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index ee20c09f573349..c9aedb115a98fd 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1437,7 +1437,7 @@ void nnc_aten_embedding( r = at::embedding(weight, indices); } catch (...) { } - // TODO: have to copy output because at::embedding doesnt have an out + // TODO: have to copy output because at::embedding doesn't have an out // variant and NNC's external calls don't support allocations memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel()); } diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index fbe1b5ca3ade01..e75af13df9327f 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -125,7 +125,7 @@ Dtype Intrinsics::IntrinsicsDtype( IntrinsicsOp op_type, const std::vector& params) { // TODO: check the op_type and make a real decision - // Doesnt this fail with kRand? + // Doesn't this fail with kRand? if (params.empty()) { throw malformed_input("invalid params in Intrinsics"); } else if (params.size() == 1) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 9e40117d9c02a6..88d86d639c686d 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -930,7 +930,7 @@ ExprPtr PolynomialTransformer::mutate(const MulPtr& v) { variable = lhs_new; } - // Handle special case mul by 1 since thats safe for floating point, even if + // Handle special case mul by 1 since that's safe for floating point, even if // it's Nan/Inf. if (scalar && immediateEquals(scalar, 1)) { auto c = alloc(v->dtype(), variable); @@ -1105,8 +1105,8 @@ ExprPtr PolynomialTransformer::mutate(const DivPtr& v) { return lhs_new; } - // If numberator and denominator are equal the result is 1. - // Unless the demoninator could be zero. + // If numerator and denominator are equal the result is 1. + // Unless the denominator could be zero. // if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { // return getImmediateByType(v->dtype(), 1); // } @@ -1745,7 +1745,7 @@ ExprPtr TermExpander::mutate(const TermPtr& v) { std::vector vars; std::vector multilaneVars; - // Assume we can reorder here because we wont merge floating terms. + // Assume we can reorder here because we won't merge floating terms. ExprPtr lastNode{nullptr}; for (const auto& var : v->variables()) { ExprPtr node = var->accept_mutator(this); @@ -1830,7 +1830,7 @@ static ExprPtr polyGCD(const PolynomialPtr& poly) { ExprPtr scalar = poly->scalar(); const std::vector& variables = poly->variables(); - // We ony want to factorize if we're saving complete operations, i.e. no + // We only want to factorize if we're saving complete operations, i.e. no // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. int opsSaved = 1; // default to saving the scalar. long GCD = std::abs(immediateAs(scalar)); @@ -2088,7 +2088,7 @@ static ExprPtr simplifyRoundModPattern(const PolynomialPtr& poly) { // TODO: for now don't attempt partial factorization of this // optimization. E.g. it's possible to do: 2 * (x/y) * y + (x%y) => x + - // (x/y) * y but unsure thats actually much better, particularly with + // (x/y) * y but unsure that's actually much better, particularly with // CSE. if (!immediateEquals( evaluateOp(alloc(r->scalar(), m->scalar())), 0)) { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 93264d7755503b..a8ffa40f58dba6 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1263,11 +1263,11 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides( const std::vector& sorted_stride_indices_descending, const std::vector& strides, BufPtr& buf) { - // We need to convert the output tensor so that its values are layed + // We need to convert the output tensor so that its values are laid // so that when viewed from the output strides the values are correct. - // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as: + // A contiguous Tensor of size(2, 3) with values 0-5 is laid out as: // [0] [1] [2] [3] [4] [5] - // The same valued tensor with strides (1, 2) would be layed out like + // The same valued tensor with strides (1, 2) would be laid out like // [0] [3] [1] [4] [2] [5] // When we are doing the re-ordering of values into the output tensor, // we are iterating per-element of the input, and we are fixed @@ -1378,7 +1378,7 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides( tt->strides().concrete_sizes(), buildErrorMessage("Output strides are unknown.")); const std::vector strides = *tt->strides().concrete_sizes(); - // All Tensors in NNC are layed out in default, contiguous layout. + // All Tensors in NNC are laid out in default, contiguous layout. // If the output is also default contiguous we don't need to do anything if (strides == default_strides) { return Tensor(buf, nullptr); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 6fc52c05761da7..d6c5590a71003a 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -780,7 +780,7 @@ void LLVMCodeGenImpl::emitKernel( GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n"); } -// TODO: The binary ops are copypasta. +// TODO: The binary ops are copypaste. void LLVMCodeGenImpl::visit(const AddPtr& v) { v->lhs()->accept(this); @@ -878,7 +878,7 @@ void LLVMCodeGenImpl::visit(const OrPtr& v) { bool rfp = rhs->getType()->isFPOrFPVectorTy(); if (!lfp && !rfp) { - value_ = irb_.CreateOr(lhs, rhs); + value_ = irb_.CreateOr(lhs, rhs); // codespell:ignore } else { throw malformed_input("llvm_codegen: bad type in Or", v); } @@ -1225,7 +1225,7 @@ void LLVMCodeGenImpl::visit(const CastPtr& v) { } value_ = irb_.CreateFPCast(value_, dstType); } else if (dstType->isIntOrIntVectorTy()) { - // Strictly casting from Float -> i8 doesnt give correct results + // Strictly casting from Float -> i8 doesn't give correct results // set one bit true if the input float is not 0 if (v->dtype().scalar_type() == ScalarType::Bool) { llvm::Value* zero = diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 4e09bf51ba96d9..646801fa9a19d0 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -987,7 +987,7 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { } } - // all bufs will have at least one store (if they have > 1 they cant be + // all bufs will have at least one store (if they have > 1 they can't be // inlined anyway) size_t reads = uses.size() - 1; // if only one read, we can inline it without duplicating work @@ -1843,11 +1843,11 @@ bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) { auto bLoads = NodeFinder::find(*it2); // ReadAfterWrite for (auto& aStore : aStores) { - for (auto& bLoad : bLoads) { + for (auto& bLoad : bLoads) { // codespell:ignore if (aStore->buf() == bLoad->buf()) { if (!areIndicesLoopIndependent( aStore->indices(), bLoad->indices(), outer_loop_vars)) { - if (isOverlapping(analyzer, aStore, bLoad)) { + if (isOverlapping(analyzer, aStore, bLoad)) { // codespell:ignore return true; } } diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index 39202f487ad2de..222ac5713d36b2 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -240,7 +240,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { std::unordered_set> accessesWithin( const StmtPtr& A) const; // TODO: this will return only the AccessInfo for A. It's included for - // completeness but be aware it wont return accesses used in the computation + // completeness but be aware it won't return accesses used in the computation // of A. std::unordered_set> accessesWithin( const ExprPtr& A) const; diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 736f00a126d0b8..37f79d529238d7 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -225,7 +225,7 @@ void RegisterizerAnalysis::visit(const ForPtr& v) { // possible that an access at a higher scope could "unhide" the // conditional access, in which case we need to hoist. If there is no // access to this element at a higher scope then we cannot safely hoist. - // We cannot know at this level whether that will or wont occur. + // We cannot know at this level whether that will or won't occur. // // The solution we take here is to split the space-time continuum, and // keep both versions of the access handy. If the hoisted access is not @@ -542,7 +542,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { closeAccessIntoScope(pCandidate, parent); parentAccesses.erase(parentIt); - // the childs access inserted into the parent scope. + // the children access inserted into the parent scope. closeAccessIntoScope(candidate, parent); continue; } @@ -567,7 +567,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { ++it; } - // Insert the childs closed access into the parent scope. + // Insert the children closed access into the parent scope. closeAccessIntoScope(candidate, parent); } diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h index 752537bb089953..c507d3b13a95e9 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.h +++ b/torch/csrc/jit/tensorexpr/registerizer.h @@ -186,7 +186,7 @@ class AccessInfo { bool firstUsageOverlapped_{false}; // The cost in real ops that this access represents, to enable - // filtering accesses that wont save any loads or stores. + // filtering accesses that won't save any loads or stores. ExprPtr store_cost_; ExprPtr load_cost_; diff --git a/torch/csrc/lazy/core/hash.h b/torch/csrc/lazy/core/hash.h index ea9e9e1be6b4ec..97e72ad7305919 100644 --- a/torch/csrc/lazy/core/hash.h +++ b/torch/csrc/lazy/core/hash.h @@ -20,7 +20,7 @@ using size_t = std::size_t; class TORCH_API hash_t : public c10::uint128 { public: - // Swich from typedef hash_t = uint128 to provide explicit casters + // Switch from typedef hash_t = uint128 to provide explicit casters hash_t(int8_t val) : uint128(static_cast(val)) {} hash_t(int16_t val) : uint128(static_cast(val)) {} hash_t(int32_t val) : uint128(static_cast(val)) {} @@ -69,7 +69,7 @@ hash_t Hash(const T& value) { // breaks falling through to the templated arithmetic types above hash_t TORCH_API Hash(const std::vector& value); -// Specialiazed implementations for proprietary types +// Specialized implementations for proprietary types static inline hash_t Hash(const c10::ScalarType& value) { return DataHash(&value, sizeof(value)); } diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index f0ca671fd59b85..754894e6096b15 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -1042,7 +1042,7 @@ std::vector LazyGraphExecutor::GatherTensorsData( void LazyGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) { if (coll) { static const std::string invalid_device( - "Unknown0"); /* Temp solution to idetify unassigned devices */ + "Unknown0"); /* Temp solution to identify unassigned devices */ if (coll->device.toString() == invalid_device || !coll->unlocker.empty()) { return; } diff --git a/torch/csrc/lazy/core/metrics.h b/torch/csrc/lazy/core/metrics.h index 83b388d7740662..e31e0989e323f7 100644 --- a/torch/csrc/lazy/core/metrics.h +++ b/torch/csrc/lazy/core/metrics.h @@ -232,7 +232,7 @@ TORCH_API std::string CreateMetricReport( const std::vector& metric_names); // Returns the currently registered metric names. Note that the list can grow -// since metrics are usually function intialized (they are static function +// since metrics are usually function initialized (they are static function // variables). TORCH_API std::vector GetMetricNames(); @@ -241,7 +241,7 @@ TORCH_API std::vector GetMetricNames(); TORCH_API MetricData* GetMetric(const std::string& name); // Returns the currently registered counter names. Note that the list can grow -// since counters are usually function intialized (they are static function +// since counters are usually function initialized (they are static function // variables). TORCH_API std::vector GetCounterNames(); diff --git a/torch/csrc/lazy/core/shape.h b/torch/csrc/lazy/core/shape.h index 8b657a19b256af..fc5a69a30df415 100644 --- a/torch/csrc/lazy/core/shape.h +++ b/torch/csrc/lazy/core/shape.h @@ -60,9 +60,9 @@ class TORCH_API Shape { // Sizes are the upper bound sizes for a tensor, used by XLA. std::vector sizes_; - // Stores which dimmensions are symbolic + // Stores which dimensions are symbolic // If nullopt, either it hasn't been initialized or the symbolic - // dimmensions are not calculatable + // dimensions are not calculable std::optional> is_symbolic_ = std::nullopt; }; diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index e2e9795ad5a48d..5e9c7dd295608f 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -59,6 +59,7 @@ #include #include #include +#include #include #include #include @@ -72,7 +73,7 @@ namespace torch::lazy { -// Copied from ATen/native/utils/ParamUtils.h, which aparently I can't include +// Copied from ATen/native/utils/ParamUtils.h, which apparently I can't include // from here? static std::vector expand_param_if_needed( at::IntArrayRef list_param, @@ -106,9 +107,6 @@ TORCH_API std::vector compute_shape_arange_out( // Note: acc_type further defines an accumulataion type depending on the // scalar_t and whether its on cuda vs cpu. using accscalar_t = at::acc_type; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); // we use double precision for (start - end) / step // to compute size_d for consistency across devices. @@ -129,18 +127,7 @@ TORCH_API std::vector compute_shape_arange_out( step.to()); } - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK( - std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", - xstart, - " -> ", - xend); - TORCH_CHECK( - ((xstep > 0) && (xend >= xstart)) || - ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + at::native::arange_check_bounds(start, end, step); TORCH_CHECK( size_d >= 0 && @@ -294,7 +281,7 @@ std::vector compute_shape_convolution( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); // at::convolution performs parameter expansion before running kernels on - // expanded parameters we must do the same. Shape formulae access differnent + // expanded parameters we must do the same. Shape formulae access different // dimensions of e.g. output_padding, but output_padding may be passed in as a // scalar. Sadly, accessing output_padding[1] in this case gives incorrect // results rather than indexing error @@ -367,7 +354,7 @@ static std::vector compute_shape_nonzero( for (auto dim_size : t.sizes()) { max_elements *= dim_size; } - return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})}; + return {Shape(at::kLong, {max_elements, t.dim()})}; } std::vector compute_shape_nonzero(const at::Tensor& self) { @@ -540,7 +527,7 @@ std::vector compute_shape_native_batch_norm( // A separate mean and var needs to be kept for each channel. TORCH_CHECK( - input.sizes().size() >= 2, + input.dim() >= 2, "Input tensor must have at least batch and channel dimensions!"); int64_t num_features = input.size(1); @@ -581,7 +568,7 @@ std::vector compute_shape_native_batch_norm_backward( // A separate mean and var needs to be kept for each channel. TORCH_CHECK( - input.sizes().size() >= 2, + input.dim() >= 2, "Input tensor must have at least batch and channel dimensions!"); int64_t num_features = input.size(1); diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index 4f153399b1e6fd..85a31d718b8c40 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -252,7 +252,7 @@ at::Tensor LazyTensor::ToTensor(bool detached) { tensor = *tensor_data; if (detached) { if (data()->ir_value || data()->handle != nullptr) { - // If we have other authoritive sources, just drop our reference and + // If we have other authoritative sources, just drop our reference and // transfer it to the caller. data()->tensor_data = std::nullopt; } else { diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index b739399b6bbdb3..a0f4ade6fdc92a 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -125,7 +125,7 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { // Retrieves the IR Node representing this LazyTensor. One will be created if // missing. Note that although this is a const API, it actually changes the - // internal state ofthe object. + // internal state of the object. Value GetIrValue() const; void SetIrValue(Value ir_value); @@ -231,7 +231,7 @@ TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor); // lazy tensors, then you should think of that function as an "entrypoint" to // functionalization, and use functionalize_output=true Examples include: // - factory functions (the LTC kernel for at::empty) -// - CPU -> Lazy device converions (the LTC kernel for at::to_device) +// - CPU -> Lazy device conversions (the LTC kernel for at::to_device) // // Case 2: lazy -> lazy // If you're implementing a function that takes in lazy tensors and returns diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp index 04730d5529527c..ce49338936e398 100644 --- a/torch/csrc/lazy/core/tensor_impl.cpp +++ b/torch/csrc/lazy/core/tensor_impl.cpp @@ -195,13 +195,14 @@ bool LTCTensorImpl::is_strides_like_custom( return false; } -bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const { +c10::SymBool LTCTensorImpl::sym_is_non_overlapping_and_dense_custom() const { // This should be true, but false as a temporary fix for a PyTorch core issue, // according to https://github.com/pytorch/xla/pull/2682. return false; } -bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const { +c10::SymBool LTCTensorImpl::sym_is_contiguous_custom( + c10::MemoryFormat _unused) const { // TODO(ezyang): I don't think this branch is actually necessary // TODO(ezyang): I don't think this logic is right, shouldn't we pass on // the memory format? diff --git a/torch/csrc/lazy/core/tensor_impl.h b/torch/csrc/lazy/core/tensor_impl.h index d5e937fc3dc8aa..02f68c01c6f444 100644 --- a/torch/csrc/lazy/core/tensor_impl.h +++ b/torch/csrc/lazy/core/tensor_impl.h @@ -41,10 +41,11 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { int64_t numel_custom() const override; int64_t storage_offset_custom() const override; int64_t dim_custom() const override; - bool is_contiguous_custom(at::MemoryFormat memory_format) const override; bool is_strides_like_custom(at::MemoryFormat memory_format) const override; - bool is_non_overlapping_and_dense_custom() const override; + c10::SymBool sym_is_non_overlapping_and_dense_custom() const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymIntArrayRef sym_strides_custom() const override; c10::SymInt sym_numel_custom() const override; diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 98d0adfbfaa8c4..f2b14cbfd7bb4d 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -44,7 +44,7 @@ static std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { static std::string GetTensorsDump( const std::vector& tensors, const std::function)>& - coverter) { + converter) { std::vector nodes; std::vector values; for (auto& tensor : tensors) { @@ -54,7 +54,7 @@ static std::string GetTensorsDump( values.push_back(lazy_tensor->GetIrValue()); nodes.push_back(values.back().node.get()); } - return coverter(nodes); + return converter(nodes); } static std::vector GetLtcTensors( @@ -146,18 +146,18 @@ void initLazyBindings(PyObject* module) { lazy.def( "_get_tensors_text", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto converter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToText(nodes); }; - return GetTensorsDump(tensors, coverter); + return GetTensorsDump(tensors, converter); }); lazy.def( "_get_tensors_dot", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto converter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToDot(nodes); }; - return GetTensorsDump(tensors, coverter); + return GetTensorsDump(tensors, converter); }); lazy.def( "_get_tensors_backend", @@ -325,10 +325,11 @@ void initLazyBindings(PyObject* module) { #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) }); - // GetPythonFramesFunction() has not ever worked with torchdeploy/multipy - // possibly becuase GetPythonFrames resolves to external cpython rather - // than embedded cpython. So far this problem has only been observed - // internally, so we will just block it off there. + // GetPythonFramesFunction() has not ever worked with + // torchdeploy/multipy possibly because // codespell:ignore multipy + // GetPythonFrames resolves to external cpython rather than embedded cpython. + // So far this problem has only been observed internally, so we will just + // block it off there. #if !(defined(USE_DEPLOY)) diff --git a/torch/csrc/lazy/ts_backend/ops/device_data.cpp b/torch/csrc/lazy/ts_backend/ops/device_data.cpp index 8567f1d2ed8ce7..bf9d6592cf72f0 100644 --- a/torch/csrc/lazy/ts_backend/ops/device_data.cpp +++ b/torch/csrc/lazy/ts_backend/ops/device_data.cpp @@ -30,7 +30,7 @@ NodePtr DeviceData::Create(const std::shared_ptr& data) { // ReuseOrMakeNode may return a reused node which has the same shape, // however, we need to replace the old data_ with the new one. // Ditching the old data_ is safe because tracing is done iteration - // by iteration, and after we lauch the async device execution for the + // by iteration, and after we launch the async device execution for the // previous iteration, data_ in DeviceData nodes are not needed anymore. DeviceData* device_data = static_cast(node.get()); device_data->SetData(data); diff --git a/torch/csrc/lazy/ts_backend/ops/to_copy.h b/torch/csrc/lazy/ts_backend/ops/to_copy.h index 53e0d76689c768..d4b0176d219f24 100644 --- a/torch/csrc/lazy/ts_backend/ops/to_copy.h +++ b/torch/csrc/lazy/ts_backend/ops/to_copy.h @@ -5,8 +5,8 @@ namespace torch::lazy { // This IR was copied from code-generated output, but the entire _to_copy -// operator cannot be trivially code genereated since it is only desirable to -// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the +// operator cannot be trivially code generated since it is only desirable to +// capture IR for certain permutations of _to_copy (e.g. dtype), and for the // others it is difficult to even invoke the aten/eager fallback necessitating // directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TsNode { diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index 055ca006528fce..5faf5bb8f1ca3f 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -271,7 +271,7 @@ void ts_eager_fallback( // the temporary eager output tensor that we created. // // Note [Eager Fallback Does Not Handle View Operators] - // Also note that we are incapable of handling immutable alises properly. + // Also note that we are incapable of handling immutable aliases properly. // Why? // Schemas with an immutable alias'd tensor outputs correspond to view // operators. For example, the `view_as` schema from native_functions.yaml: @@ -340,7 +340,7 @@ void ts_eager_fallback( // We should never hit this for a view op, // because LazyTensor should provide a lowering for the // corresponding view_copy operator. The functionalization pass will - // take care of calling the view_copy operator intead of the view. + // take care of calling the view_copy operator instead of the view. TORCH_CHECK( false, "The operator ", diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 99fca62916d034..1bb720b810f935 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -398,7 +398,7 @@ at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { // All of the below ops correspond to CompositeExplicitAutograd kernels from // core that call into view operators internally. These are all composite ops -// that LTC can technically re-use / get for free, but we need to +// that LTC can technically reuse / get for free, but we need to // "functionalize" them to remove the view ops before we can use them. at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { return at::functionalization::functionalize_aten_op LazyNativeFunctions::native_group_norm( const at::Tensor& input, diff --git a/torch/csrc/lazy/ts_backend/ts_node.h b/torch/csrc/lazy/ts_backend/ts_node.h index 125d4c1283d875..a1bf1083893314 100644 --- a/torch/csrc/lazy/ts_backend/ts_node.h +++ b/torch/csrc/lazy/ts_backend/ts_node.h @@ -78,7 +78,7 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); // Note: shape is undefined for TensorList. We assert in some places that // #shapes matches #outputs and this stems from // the fact that currently all IR nodes represent tensors (there is no -// type system for this IR). Becuase of this, TensorList is a bit of a +// type system for this IR). Because of this, TensorList is a bit of a // hack. // // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and diff --git a/torch/csrc/lazy/tutorial.md b/torch/csrc/lazy/tutorial.md index b72ae13eca7dd8..eaeaac2064aaf0 100644 --- a/torch/csrc/lazy/tutorial.md +++ b/torch/csrc/lazy/tutorial.md @@ -218,7 +218,7 @@ If we don't stop the trace after `optimizer_step` it will include two or more it Another important point is that after `mark_step()` we actually continue tracing the next iteration! And... start executing the previous one at the same time! Really, nothing stops us from tracing the next iteration ...and then the one after next until we hit `if batch_idx % log_interval == 0:` where we actually need to wait for execution to catch up, so we can print out `loss`. Remember to avoid accessing intermediate results too often if you would like to extract the maximum benefit out of Lazy Tensor. -Since every iteration looks exactly like the one before it, the TS backend will be re-using the same TS compilation. +Since every iteration looks exactly like the one before it, the TS backend will be reusing the same TS compilation. Alright, let's run it now! diff --git a/torch/csrc/monitor/events.cpp b/torch/csrc/monitor/events.cpp index 61eda8bfd10a8c..2374b692a3c079 100644 --- a/torch/csrc/monitor/events.cpp +++ b/torch/csrc/monitor/events.cpp @@ -32,8 +32,8 @@ class EventHandlers { } static EventHandlers& get() noexcept { - static auto ehsPtr = new EventHandlers(); - return *ehsPtr; + static auto ehs = EventHandlers(); + return ehs; } private: diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 8c33c596b3278c..95263f108c825e 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -443,7 +443,7 @@ void initModule(PyObject* module) { } TORCH_CHECK( threads.has_value() && threads->size() < 4, - "Number of threads is undefined or has wrong dimention"); + "Number of threads is undefined or has wrong dimension"); TORCH_CHECK( !group_size.has_value() || threads->size() == group_size->size()); diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 72bffcd8b8cea8..b9abd5ae508f3f 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -396,7 +396,8 @@ std::unique_ptr ThreadLocalSubqueue::begin_op( } event->start_time_ = c10::getApproximateTime(); - event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS(); + event->allow_tf32_cublas_ = + at::globalContext().float32Precision("cuda", "matmul") == "tf32"; if (!config_.experimental_config.performance_events.empty()) { const size_t n = config_.experimental_config.performance_events.size(); event->counters_ = std::make_unique(n, 0); @@ -1015,6 +1016,12 @@ class TransferEvents { } } + bool isHiddenEvent(const itrace_t* activity) const { + TORCH_INTERNAL_ASSERT(activity != nullptr); + // Kineto uses "hidden" metadata to mark events that should be hidden. + return activity->getMetadataValue("hidden") == "1"; + } + std::shared_ptr resultFromActivity(const itrace_t* activity) { TORCH_INTERNAL_ASSERT(activity != nullptr); @@ -1035,7 +1042,7 @@ class TransferEvents { {/*id=*/static_cast(activity->flowId()), /*type=*/static_cast(activity->flowType()), /*start=*/activity->flowStart()}}); - + event->hidden_ = isHiddenEvent(activity); // NB: It's tempting to set `event->kineto_activity_`; however we can only // guarantee that the events we passed to Kineto are of type // `GenericTraceActivity`. Others may derive from ITraceActivity and thus diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 675f9999f7cf93..59ebda87a176e8 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -421,7 +421,7 @@ struct TORCH_API Result : public std::enable_shared_from_this { std::weak_ptr parent_; std::vector> children_; bool finished_{false}; - + bool hidden_{false}; const torch::profiler::impl::kineto::activity_t* kineto_activity_{nullptr}; private: diff --git a/torch/csrc/profiler/data_flow.cpp b/torch/csrc/profiler/data_flow.cpp index 6738db961d93b1..5f13421c55246c 100644 --- a/torch/csrc/profiler/data_flow.cpp +++ b/torch/csrc/profiler/data_flow.cpp @@ -58,7 +58,7 @@ struct RawTensors { void calculateUniqueTensorIDs( std::vector>& sorted_results) { - // This task is equivilent to https://leetcode.com/problems/number-of-islands/ + // This task is equivalent to https://leetcode.com/problems/number-of-islands/ // We first cluster events with a greedy index assignment, and then merge // groups that overlap. std::vector tensors; diff --git a/torch/csrc/profiler/data_flow.h b/torch/csrc/profiler/data_flow.h index e2c1ace7b07137..3a485b36510955 100644 --- a/torch/csrc/profiler/data_flow.h +++ b/torch/csrc/profiler/data_flow.h @@ -35,7 +35,7 @@ using AllocationID = strong::type< strong::regular, strong::hashable>; -// We use a Tensor's TensorImpl adress and StorageImpl data start to build the +// We use a Tensor's TensorImpl address and StorageImpl data start to build the // data flow graph. We do not hold an owning reference so we wrap them in strong // types to prevent direct access. using TensorImplAddress = strong::type< diff --git a/torch/csrc/profiler/events.h b/torch/csrc/profiler/events.h index 78bac1fea19ad1..79c9499bfbdec4 100644 --- a/torch/csrc/profiler/events.h +++ b/torch/csrc/profiler/events.h @@ -13,7 +13,7 @@ using perf_counters_t = std::vector; /* Standard list of performance events independent of hardware or backend */ constexpr std::array ProfilerPerfEvents = { /* - * Number of Processing Elelement (PE) cycles between two points of interest + * Number of Processing Element (PE) cycles between two points of interest * in time. This should correlate positively with wall-time. Measured in * uint64_t. PE can be non cpu. TBD reporting behavior for multiple PEs * participating (i.e. threadpool). diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index feed2c21873f51..ec9994e15ec9c0 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -50,6 +50,7 @@ const std::set kXpuTypes = { const std::set kMtiaTypes = { libkineto::ActivityType::MTIA_CCP_EVENTS, libkineto::ActivityType::MTIA_RUNTIME, + libkineto::ActivityType::MTIA_INSIGHT, }; const std::set hpuTypes = { libkineto::ActivityType::HPU_OP, @@ -177,13 +178,15 @@ class ExperimentalConfigWrapper { return !config_.profiler_metrics.empty(); } - void prepareTraceWithExperimentalOptions(bool add_cpu_activity) { + void prepareTraceWithExperimentalOptions( + std::set&& enabled_activities) { + std::set k_activities = + std::move(enabled_activities); #ifdef USE_KINETO - std::set k_activities{ - libkineto::ActivityType::CUDA_PROFILER_RANGE}; + k_activities.insert(libkineto::ActivityType::CUDA_PROFILER_RANGE); - // Only add CPU activities if we are measuring per kernel ranges - if (add_cpu_activity && config_.profiler_measure_per_kernel) { + // Add CPU activities if we are measuring per kernel ranges + if (config_.profiler_measure_per_kernel) { k_activities.insert(kCpuTypes.begin(), kCpuTypes.end()); } @@ -204,6 +207,7 @@ class ExperimentalConfigWrapper { configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" << (config_.profiler_measure_per_kernel ? "true" : "false") << "\n"; + configss << "CUSTOM_CONFIG=" << config_.custom_profiler_config << "\n"; LOG(INFO) << "Generated config = " << configss.str(); libkineto::api().activityProfiler().prepareTrace( @@ -236,6 +240,18 @@ static const std::string setTraceID(const std::string& trace_id) { configss << "REQUEST_GROUP_TRACE_ID=" << trace_id << "\n"; return configss.str(); } + +static const std::string appendCustomConfig( + const std::string& config, + const std::string& custom_profiler_config) { + if (custom_profiler_config.empty()) { + return config; + } + std::stringstream configss; + configss << config; + configss << "CUSTOM_CONFIG=" << custom_profiler_config << "\n"; + return configss.str(); +} #endif void prepareTrace( @@ -288,11 +304,13 @@ void prepareTrace( // Experimental Configuration options are present if (config && configWrap.assertValid()) { - configWrap.prepareTraceWithExperimentalOptions(has_cpu_activity); + configWrap.prepareTraceWithExperimentalOptions(std::move(k_activities)); return; } - const std::string configStr = setTraceID(trace_id); + const std::string traceIdStr = setTraceID(trace_id); + const std::string configStr = + appendCustomConfig(traceIdStr, config.custom_profiler_config); libkineto::api().activityProfiler().prepareTrace(k_activities, configStr); #endif // USE_KINETO @@ -392,7 +410,6 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { } // TODO: T151322015 case libkineto::ActivityType::MTIA_CCP_EVENTS: - case libkineto::ActivityType::MTIA_WORKLOADD: case libkineto::ActivityType::MTIA_INSIGHT: { // PrivateUse1 kineto backend reuse above ActivityTypes, // If PrivateUse1 backend enabled, this should return diff --git a/torch/csrc/profiler/orchestration/observer.cpp b/torch/csrc/profiler/orchestration/observer.cpp index 363fb206353a2a..18b792a1abe975 100644 --- a/torch/csrc/profiler/orchestration/observer.cpp +++ b/torch/csrc/profiler/orchestration/observer.cpp @@ -21,6 +21,7 @@ ExperimentalConfig::ExperimentalConfig( bool disable_external_correlation, bool profile_all_threads, bool capture_overload_names, + std::string custom_profiler_config, bool adjust_timestamps) : profiler_metrics{std::move(profiler_metrics)}, profiler_measure_per_kernel{profiler_measure_per_kernel}, @@ -31,6 +32,7 @@ ExperimentalConfig::ExperimentalConfig( disable_external_correlation{disable_external_correlation}, profile_all_threads{profile_all_threads}, capture_overload_names{capture_overload_names}, + custom_profiler_config(std::move(custom_profiler_config)), adjust_timestamps{adjust_timestamps} {} /*explicit*/ ExperimentalConfig::operator bool() const { diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index 54f109ae5c8120..427736e6c63590 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -62,6 +62,7 @@ struct TORCH_API ExperimentalConfig { bool disable_external_correlation = false, bool profile_all_threads = false, bool capture_overload_names = false, + std::string custom_profiler_config = "", bool adjust_timestamps = false); explicit operator bool() const; @@ -101,6 +102,12 @@ struct TORCH_API ExperimentalConfig { * function schema and stored in the profile */ bool capture_overload_names; + /* + * A custom_profiler_config option is introduced to allow custom backends + * to apply custom configurations as needed. + */ + std::string custom_profiler_config; + /* * Controls whether or not timestamp adjustment occurs after profiling. * The purpose of this is to adjust Vulkan event timelines to align with those diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 1ec82fd8c35ca2..92f2f39a5da239 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -340,7 +340,8 @@ void initPythonBindings(PyObject* module) { bool /* adjust_profiler_step */, bool /* disable_external_correlation*/, bool /* profile_all_threads */, - bool /* capture_overload_names */ + bool /* capture_overload_names */, + std::string /* custom_profiler_config*/ >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" @@ -359,6 +360,7 @@ void initPythonBindings(PyObject* module) { " disable_external_correlation (bool) : whether to disable external correlation\n", " profile_all_threads (bool) : whether to profile all threads\n", " capture_overload_names (bool) : whether to include ATen overload names in the profile\n", + " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, py::arg("verbose") = false, @@ -367,7 +369,8 @@ void initPythonBindings(PyObject* module) { py::arg("adjust_profiler_step") = false, py::arg("disable_external_correlation") = false, py::arg("profile_all_threads") = false, - py::arg("capture_overload_names") = false) + py::arg("capture_overload_names") = false, + py::arg("custom_profiler_config") = "") .def(py::pickle( [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; @@ -390,11 +393,12 @@ void initPythonBindings(PyObject* module) { p.disable_external_correlation, p.profile_all_threads, p.capture_overload_names, + p.custom_profiler_config, p.performance_events); }, [](const py::tuple& t) { // __setstate__ if (t.size() >= 5) { - throw std::runtime_error("Expected atleast 5 values in state"); + throw std::runtime_error("Expected at least 5 values in state"); } py::list py_metrics = t[0].cast(); diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 28cf938b976ce3..1c88e80d4021cd 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -113,6 +113,41 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // Uses the underlying TensorImpl object pointer as the key and map to its // unique id. std::map objectId{}; + + using weak_storage_ptr = c10::weak_intrusive_ptr; + std::unordered_map data_ptr_to_storage_id{}; + std::unordered_map + data_ptr_to_weak_storage_ptr{}; + + ID get_tensor_storage_ID(const c10::Storage& t_storage) { + const std::lock_guard lock(gMutex); + + const void* raw_data_ptr = t_storage.data(); + auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); + if (iter == data_ptr_to_weak_storage_ptr.end()) { + ID id = storage_id_++; + data_ptr_to_storage_id.emplace(raw_data_ptr, id); + data_ptr_to_weak_storage_ptr.emplace( + raw_data_ptr, t_storage.getWeakStorageImpl()); + return id; + } else { + // check if the storage is still alive + if (iter->second.expired()) { + ID id = storage_id_++; + // std::unorder_map does not change if the key is already in the map. + // So we need to remove the key and insert the key with the new value. + data_ptr_to_storage_id.erase(raw_data_ptr); + data_ptr_to_storage_id[raw_data_ptr] = id; + data_ptr_to_weak_storage_ptr.erase(raw_data_ptr); + data_ptr_to_weak_storage_ptr.emplace( + raw_data_ptr, t_storage.getWeakStorageImpl()); + return id; + } else { + return data_ptr_to_storage_id[raw_data_ptr]; + } + } + } + // Observer run state. enum class RunState { uninitialized, disabled, enabled }; @@ -171,10 +206,12 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // All tensors and operators have an unique id assigned. Increment id for each // new tensor or operator node. - // 0 -> unintialized + // 0 -> uninitialized // 1 -> root ID // 2 ... -> regular node ID std::atomic id_{2}; + + std::atomic storage_id_{1}; }; // Using a singleton manager here to allow init and delete the observer object. @@ -445,8 +482,8 @@ convertIValue( // symbolic sizes/strides implies t->storage_offset() will fail if (tensor_impl->has_storage() && !tensor_impl->has_symbolic_sizes_strides()) { - auto& t_storage = tensor_impl->storage(); - storage_id = getObjectID(ob, t_storage.data()); + const c10::Storage& t_storage = tensor_impl->storage(); + storage_id = ob.get_tensor_storage_ID(t_storage); offset = tensor_impl->storage_offset(); numel = tensor_impl->numel(); itemsize = tensor_impl->itemsize(); diff --git a/torch/csrc/profiler/unwind/mem_file.h b/torch/csrc/profiler/unwind/mem_file.h index 4a718a937348d8..08593dad2d0d63 100644 --- a/torch/csrc/profiler/unwind/mem_file.h +++ b/torch/csrc/profiler/unwind/mem_file.h @@ -35,7 +35,7 @@ struct Section { /// Memory maps a file into the address space read-only, and manages the /// lifetime of the mapping. Here are a few use cases: /// 1. Used in the loader to read in initial image, and to inspect -// ELF files for dependencies before callling dlopen. +// ELF files for dependencies before calling dlopen. /// /// 2. Used in unity to load the elf file. struct MemFile { diff --git a/torch/csrc/profiler/unwind/range_table.h b/torch/csrc/profiler/unwind/range_table.h index b8c405ddad6a82..3ed126bf058b62 100644 --- a/torch/csrc/profiler/unwind/range_table.h +++ b/torch/csrc/profiler/unwind/range_table.h @@ -9,7 +9,7 @@ namespace torch::unwind { template struct RangeTable { RangeTable() { - // guarentee that lower_bound[-1] is always valid + // guarantee that lower_bound[-1] is always valid addresses_.push_back(0); payloads_.emplace_back(std::nullopt); } diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index 1649d1665c632b..ef7f1d0784eaef 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -1,8 +1,10 @@ +#pragma once // this file can only have stable stuff! Akin to shim.h // but unlike shim.h, this file can contain header-only C++ // code for better UX. #include +#include #include @@ -10,41 +12,52 @@ // versions of this file that may be included by different sources namespace { +// ============================================================================= +// helpers for converting between StableIValue and T +// ============================================================================= + +// forward declare so that from/to() calls in detail work +template +StableIValue from(T val); +template +T to(StableIValue val); + namespace detail { -// utility functions to detect optional -template -struct is_optional : std::false_type {}; -template -struct is_optional> : std::true_type {}; -} // namespace detail -template < - typename T, - std::enable_if_t::value, bool> = true> -StableIValue from(T val) { - static_assert( - sizeof(T) <= sizeof(StableIValue), - "StableLibrary stack does not support parameter types larger than 64 bits."); - static_assert(std::is_trivially_copyable_v); - // Initialization should be cheap enough; let's give people well-specified - // reproducible behavior. - StableIValue result = 0; - // NOTE [-Wclass-memaccess ]: reinterpret_cast to suppress - // overzealous -Wclass-memaccess. (see - // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a - // static_assert above that T is trivially copyable, which should be - // enough. - std::memcpy(&result, reinterpret_cast(&val), sizeof(val)); - return result; -} +// ============================================================================= +// FROM CONVERSIONS (T -> StableIValue) +// ============================================================================= -// Specialization for std::nullopt_t +// Specialization for general copyable types (catch-all) => StableIValue +template +struct FromImpl { + static StableIValue call(T val) { + static_assert( + sizeof(T) <= sizeof(StableIValue), + "StableLibrary stack does not support parameter types larger than 64 bits."); + static_assert(std::is_trivially_copyable_v); + // Initialization should be cheap enough; let's give people well-specified + // reproducible behavior. + StableIValue result = 0; + // NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress + // overzealous -Wclass-memaccess. (see + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a + // static_assert above that T is trivially copyable, which should be + // enough. + std::memcpy(&result, reinterpret_cast(&val), sizeof(val)); + return result; + } +}; + +// Specialization for std::nullopt_t => StableIValue template <> -StableIValue from(std::nullopt_t val) { - return from(nullptr); -} +struct FromImpl { + static StableIValue call(std::nullopt_t val) { + return from(nullptr); + } +}; -// Specialization for std::optional +// Specialization for std::optional => StableIValue // [Handling std::optional] // When the schema is represented by an optional type, say int?, then we // expect the custom extension representation to be a std::optional @@ -74,63 +87,118 @@ StableIValue from(std::nullopt_t val) { // The schema requests an optional (T?) so I must call `from` on a // std::optional or a std::nullopt. template -StableIValue from(std::optional val) { - if (!val.has_value()) { - return from(std::nullopt); +struct FromImpl> { + static StableIValue call(const std::optional& val) { + if (!val.has_value()) { + return from(std::nullopt); + } + StableIValue* heap_val = new StableIValue(from(val.value())); + return from(heap_val); } - StableIValue* heap_val = new StableIValue(from(val.value())); - return from(heap_val); -} +}; -template < - typename T, - std::enable_if_t::value, bool> = true> -T to(StableIValue val) { - static_assert(std::is_trivially_copyable_v); - // T may not have a default constructor. (For example, it might be - // c10::Device.) However, std::memcpy implicitly creates a T at the - // destination. So, we can use a union to work around this lack of - // default constructor. - union Result { - Result() {} - T t; - }; - Result result; - // See NOTE[ -Wclass-memaccess ] above. - std::memcpy(reinterpret_cast(&result.t), &val, sizeof(result)); - return result.t; -} +// Specialization for torch::stable::Tensor => StableIValue +// Returns a new owning reference of the underlying Tensor. +template <> +struct FromImpl { + static StableIValue call(const torch::stable::Tensor& val) { + AtenTensorHandle new_ath; + aoti_torch_new_tensor_handle(val.get(), &new_ath); + return from(new_ath); + } +}; -template < - typename T, - std::enable_if_t, bool> = true> -T to(StableIValue val) { - // val should be equivalent to from(nullptr) - return std::nullopt; -} +// ============================================================================= +// TO CONVERSIONS (StableIValue -> T) +// ============================================================================= -// Specialization for std::optional, see [Handling std::optional] above -// as the semantic is the same but in reverse direction as we go from -// IValue --(from_ivalue)-> StableIValue --(to)-> T in custom extension -template < - typename T, - std::enable_if_t::value, bool> = true> -T to(StableIValue val) { - using V = typename T::value_type; - auto sivp = to(val); +// Specialization for StableIValue => general copyable types (catch-all) +template +struct ToImpl { + static T call(StableIValue val) { + static_assert(std::is_trivially_copyable_v); + // T may not have a default constructor. (For example, it might be + // c10::Device.) However, std::memcpy implicitly creates a T at the + // destination. So, we can use a union to work around this lack of + // default constructor. + union Result { + Result() {} + T t; + }; + Result result; + // See NOTE[ -Wclass-memaccess ] above. + std::memcpy(reinterpret_cast(&result.t), &val, sizeof(result)); + return result.t; + } +}; + +// Specialization for StableIValue => std::nullopt_t +template <> +struct ToImpl { + static std::nullopt_t call(StableIValue val) { + // val should be equivalent to from(nullptr) + return std::nullopt; + } +}; + +// Specialization for StableIValue => std::optional, see [Handling +// std::optional] as the semantic is the same but in reverse direction as we go +// from IValue --(from_ivalue)-> StableIValue --(to)-> T in custom extension +template +struct ToImpl> { + static std::optional call(StableIValue val) { + auto sivp = to(val); - // sivp is either nullptr or a pointer to a StableIValue - if (sivp == nullptr) { - return {}; + // sivp is either nullptr or a pointer to a StableIValue + if (sivp == nullptr) { + return {}; + } + auto inner_val = to(*sivp); + + // free the memory associated with StableIValue* sivp + delete sivp; + + return std::make_optional(inner_val); + } +}; + +// Specialization for StableIValue => torch::stable::Tensor +// The resulting stable::Tensor steals ownership of the input's +// underlying AtenTensorHandle. +template <> +struct ToImpl { + static torch::stable::Tensor call(StableIValue val) { + return torch::stable::Tensor(to(val)); } - auto inner_val = to(*sivp); +}; + +} // namespace detail + +// Expose the partially templated class functions through single functions +template +StableIValue from(T val) { + return detail::FromImpl::call(val); +} - // free the memory associated with StableIValue* sivp - delete sivp; +template +StableIValue from(const std::optional& val) { + return detail::FromImpl>::call(val); +} - return std::make_optional(inner_val); +// The below overload is used! See https://godbolt.org/z/859cshxrW +// We are suppressing the warning for versions clang12- and gcc11- +[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) { + return detail::FromImpl::call(val); } -// end to helpers for converting between StableIValue and actual IValues + +template +T to(StableIValue val) { + return detail::ToImpl::call(val); +} + +// ============================================================================= +// end to helpers for converting between StableIValue and T +// ============================================================================= class StableLibrary final { private: diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h new file mode 100644 index 00000000000000..1b9b3fecb41735 --- /dev/null +++ b/torch/csrc/stable/tensor.h @@ -0,0 +1,126 @@ +#pragma once + +// TODO ASAP: THIS FILE SHOULD BE HEADER ONLY BUT ISN'T ENFORCED: +// I only need it for AOTI_TORCH_ERROR_CODE_CHECK, see #154908 +#include + +#include + +namespace torch::stable { + +using DeviceIndex = + int8_t; // this is from c10/core/Device.h and can be header only + +// The torch::stable::Tensor class is a highlevel C++ wrapper around +// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom +// op kernels only really need to interact with Tensor metadata (think sizes, +// strides, device, dtype). Other functions on Tensor (like empty_like) should +// live like the ATen op that they are and exist outside of this struct. +// +// There are several goals of this class over AtenTensorHandle and +// RAIIAtenTensorHandle: +// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the +// C APIs with AtenTensorHandle. Under the hood we still call to these C shim +// APIs to preserve stability. +// 2. RAIIAtenTensorHandle boils down to a uniq_ptr that forces the user to pass +// around ownership. This makes it difficult to pass one input into 2 +// different functions, e.g., doing something like c = a(t) + b(t) for +// stable::Tensor t. Thus, we use a shared_ptr here. +class Tensor { + private: + std::shared_ptr ath_; + + public: + Tensor() = delete; + + // Construct a stable::Tensor from an AtenTensorHandle (ATH) + // Steals ownership from the ATH + explicit Tensor(AtenTensorHandle ath) + : ath_(ath, [](AtenTensorHandle ath) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + }) {} + + // Copy and move constructors can be default cuz the underlying handle is a + // shared_ptr + Tensor(const Tensor& other) = default; + Tensor(Tensor&& other) noexcept = default; + + // Copy and move assignment operators can be default cuz the underlying handle + // is a shared_ptr + Tensor& operator=(const Tensor& other) = default; + Tensor& operator=(Tensor&& other) noexcept = default; + + // Destructor can be default: shared ptr has custom deletion logic + ~Tensor() = default; + + // Returns a borrowed reference to the AtenTensorHandle + AtenTensorHandle get() const { + return ath_.get(); + } + + // ============================================================================= + // C-shimified TensorBase APIs: the below APIs have the same signatures and + // semantics as their counterparts in TensorBase.h. + // ============================================================================= + + void* data_ptr() const { + void* data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); + return data_ptr; + } + + int64_t dim() const { + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); + return dim; + } + + int64_t numel() const { + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); + return numel; + } + + // note: this is a subset of the original TensorBase API. It takes no + // arguments whereas the original API takes in a kwarg of memory format. + // Here, we assume the default contiguous memory format. + bool is_contiguous() const { + bool is_contiguous; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_is_contiguous(ath_.get(), &is_contiguous)); + return is_contiguous; + } + + int64_t stride(int64_t dim) const { + int64_t stride; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(ath_.get(), dim, &stride)); + return stride; + } + + DeviceIndex get_device() const { + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + return static_cast(device_index); + } + + bool is_cuda() const { + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_type(ath_.get(), &device_type)); + return device_type == aoti_torch_device_type_cuda(); + } + + int64_t size(int64_t dim) const { + int64_t size; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); + return size; + } + + // ============================================================================= + // END of C-shimified TensorBase APIs + // ============================================================================= +}; + +} // namespace torch::stable diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 0e1189e1a7bbbe..eee9af9d9ecbf8 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -254,7 +254,7 @@ namespace torch::gdb { // Return an human-readable representation of the given Tensor. The resulting // string is stored into a malloc()ed buffer. The caller is responsible to // free() it. We use malloc() instead of new[] because it's much easier to -// call free than delete[] from withing gdb. +// call free than delete[] from within gdb. // Currently the code for computing the repr of a tensor is written in Python, // so we need to wrap the Tensor into a Python object first. char* tensor_repr(const at::Tensor& tensor) { @@ -300,7 +300,7 @@ char* tensor_repr(const at::Tensor& tensor) { return result; error: - fprintf(stderr, "torch::gdb::tensor_repr: unexpected error\n"); + fmt::print(stderr, "torch::gdb::tensor_repr: unexpected error\n"); if (PyErr_Occurred()) PyErr_Print(); Py_XDECREF(pytensor); diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index a22a08cc222fa1..681d9458298694 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -339,7 +339,7 @@ struct type_caster> { bool load(handle src, bool) { PyObject* obj = src.ptr(); - // Refered from `THPUtils_unpackComplexDouble` + // Referred from `THPUtils_unpackComplexDouble` Py_complex py_complex = PyComplex_AsCComplex(obj); if (py_complex.real == -1.0 && PyErr_Occurred()) { return false; diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 4ef705d731570a..36a68a2449b4e5 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -299,7 +299,7 @@ static py::object maybe_get_registered_torch_dispatch_rule( #endif auto result = find_torch_dispatch_rule( py::reinterpret_borrow(torch_api_function), - torch_dispatch_object.get_type()); + py::type::handle_of(torch_dispatch_object)); return result; } @@ -350,7 +350,7 @@ static py::object dispatch_on_subclass( auto py_arg = py::reinterpret_borrow(arg); ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( torch_function.ptr(), - py_arg.get_type().ptr(), + py::type::handle_of(py_arg).ptr(), torch_api_function, py_types.ptr(), args, diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 1531c78ce7eb54..72edd38433504f 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -1248,7 +1248,7 @@ auto handle_torch_function_indexing( /* * Check if the input obj is Tensor type, including its subclass, or overloaded * type. If the type defines __torch_function__, it also returns true. - * Otherwise returns flase. If the class is not torch.Tensor, and it defines + * Otherwise returns false. If the class is not torch.Tensor, and it defines * __torch_function__, we append obj to overloaded_args. * * 'obj': the input argument to be checked diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index bf304e981ab20f..34fbfec49c919f 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -186,11 +187,12 @@ class PythonKernelHolder : public c10::OperatorKernel { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; - // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic - // mode unconditionally in all situations when you're using multipy. - // Eventually just delete this entirely. (Note that you may break multipy - // anyway this way with dispatcher registered functions that require - // hermetic to be off.) + // Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy + // so stop forcing hermetic mode unconditionally in all situations when + // you're using multipy. // codespell:ignore multipy + // Eventually just delete this entirely. (Note that you may break + // multipy anyway this way with dispatcher // codespell:ignore multipy + // registered functions that require hermetic to be off.) #if defined(USE_DEPLOY) EnableHermeticPyObject g2; #endif @@ -299,8 +301,8 @@ void initDispatchBindings(PyObject* module) { return; }, "") - // Some of these APIs are only for testing and do not work in multipy - // environment + // Some of these APIs are only for testing and do not work in + // multipy environment // codespell:ignore multipy .def( "def_", [](py::object self, const char* schema, const char* alias) { @@ -957,6 +959,15 @@ void initDispatchBindings(PyObject* module) { include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode)); }); + m.def("_autocast_supported_devices", []() { + std::vector result; + for (const auto device_type : at::autocast::_AUTOCAST_SUPPORTED_DEVICES) { + result.emplace_back( + c10::DeviceTypeName(device_type, /*lower_case*/ true)); + } + return result; + }); + m.def("_get_nested_int", [](int64_t data, int64_t coeff) { return c10::SymInt(c10::SymNode( c10::make_intrusive(data, coeff))); diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index c22f752d78349a..2fe0f60e4123e6 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -182,7 +182,7 @@ inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { if (value == -1 && PyErr_Occurred()) { throw python_error(); } - // No need to check overflow, because when overflow occured, it should + // No need to check overflow, because when overflow occurred, it should // return true in order to keep the same behavior of numpy. return (bool)value; } diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h index cfca55bb86ec71..9671e5156b9d4c 100644 --- a/torch/csrc/utils/six.h +++ b/torch/csrc/utils/six.h @@ -13,8 +13,8 @@ namespace six { // by a pytorch operator. inline bool isStructSeq(pybind11::handle input) { - return pybind11::cast(input.get_type().attr("__module__")) == - "torch.return_types"; + return pybind11::cast(pybind11::type::handle_of(input).attr( + "__module__")) == "torch.return_types"; } inline bool isStructSeq(PyObject* obj) { diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp index f23af7bf31f52b..29d20d5a9bfe2f 100644 --- a/torch/csrc/utils/structseq.cpp +++ b/torch/csrc/utils/structseq.cpp @@ -5,7 +5,7 @@ * https://github.com/python/cpython/blob/2.7/Objects/structseq.c * * The purpose of this file is to overwrite the default behavior - * of repr of structseq to provide better printting for returned + * of repr of structseq to provide better printing for returned * structseq objects from operators, aka torch.return_types.* * * For more information on copyright of CPython, see: diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 418061cf376794..f41a1e250e5403 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -172,13 +172,13 @@ ScalarType infer_scalar_type(PyObject* obj) { Py_TYPE(obj)->tp_name, "'"); if (PySequence_Check(obj)) { - std::optional scalarType; auto length = PySequence_Length(obj); if (length < 0) throw python_error(); // match NumPy semantics, except use default tensor type instead of double. if (length == 0) return torch::tensors::get_default_scalar_type(); + ScalarType scalarType{}; for (const auto i : c10::irange(length)) { THPObjectPtr handle(PySequence_GetItem(obj, i)); if (!handle) @@ -187,16 +187,15 @@ ScalarType infer_scalar_type(PyObject* obj) { TORCH_CHECK_TYPE( cur_item != obj, "new(): self-referential lists are incompatible"); ScalarType item_scalarType = infer_scalar_type(cur_item); - scalarType = (scalarType) ? at::promoteTypes(*scalarType, item_scalarType) - : item_scalarType; + scalarType = (i > 0) ? at::promoteTypes(scalarType, item_scalarType) + : item_scalarType; if (scalarType == ScalarType::ComplexDouble) { // this won't change (unless we hit undefined, but that will fail // later). - return *scalarType; + return scalarType; } } - // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - return *scalarType; + return scalarType; } TORCH_CHECK(false, "Could not infer dtype of ", Py_TYPE(obj)->tp_name); } @@ -557,6 +556,7 @@ void check_base_legacy_new( c10::DispatchKey::SparseCUDA, c10::DispatchKey::SparseHIP, c10::DispatchKey::SparseXPU, + c10::DispatchKey::SparseMPS, c10::DispatchKey::SparsePrivateUse1, }); TORCH_CHECK( @@ -1654,19 +1654,23 @@ Tensor tensor_frombuffer( return tensor; } -Tensor tensor_fromDLPack(PyObject* data) { - DLManagedTensor* dlMTensor = - (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); - TORCH_CHECK( - dlMTensor, - "from_dlpack received an invalid capsule. " - "Note that DLTensor capsules can be consumed only once, " - "so you might have already constructed a tensor from it once."); - - auto deleter_with_gil = [dlMTensor](void*) { - if (dlMTensor->deleter) { - pybind11::gil_scoped_acquire gil; - dlMTensor->deleter(dlMTensor); +namespace { + +template +at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { + // HACK: Ensure that we hold the GIL here just in case the + // managed tensor originating from a buggy NumPy build. + bool is_numpy_dlpack_deleter_bugged = + torch::utils::is_numpy_dlpack_deleter_bugged(); + + auto deleter_maybe_gil = [=](void*) { + if (tensor->deleter) { + if (is_numpy_dlpack_deleter_bugged) { + pybind11::gil_scoped_acquire gil; + tensor->deleter(tensor); + } else { + tensor->deleter(tensor); + } } }; @@ -1674,14 +1678,11 @@ Tensor tensor_fromDLPack(PyObject* data) { // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. - // HACK: Ensure that we hold the GIL here just in case the - // managed tensor originating from a buggy NumPy build. - auto atensor = torch::utils::is_numpy_dlpack_deleter_bugged() - ? at::fromDLPack(dlMTensor, std::move(deleter_with_gil)) - : at::fromDLPack(dlMTensor); + auto atensor = + at::DLPackTraits::fromDLPack(tensor, std::move(deleter_maybe_gil)); // Make sure this capsule will never be used again. - PyCapsule_SetName(data, "used_dltensor"); + PyCapsule_SetName(data, at::DLPackTraits::used); // It is possible that the call to at::fromDLPack is the very first // call to create a Tensor in PyTorch. If so, then _lazy_init has @@ -1693,6 +1694,44 @@ Tensor tensor_fromDLPack(PyObject* data) { return atensor; } +// Check whether `data` is a valid DLPack capsule. +// This function checks for the versioned and unversioned forms. +bool isValidDLPackCapsule(PyObject* data) { + return PyCapsule_IsValid( + data, at::DLPackTraits::capsule) || + PyCapsule_IsValid(data, at::DLPackTraits::capsule); +} + +} // namespace + +Tensor tensor_fromDLPack(PyObject* data) { + const char* bad_capsule = + ("from_dlpack received an invalid capsule. " + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once."); + + if (PyCapsule_IsValid( + data, at::DLPackTraits::capsule)) { + auto versioned = (DLManagedTensorVersioned*)PyCapsule_GetPointer( + data, at::DLPackTraits::capsule); + + TORCH_CHECK(versioned != nullptr, bad_capsule); + TORCH_CHECK( + versioned->version.major <= DLPACK_MAJOR_VERSION, + "unsupported DLPack capsule major version: ", + versioned->version.major, + ". Maximum supported version: ", + DLPACK_MAJOR_VERSION); + + return tensor_fromDLPackImpl(data, versioned); + } else { + auto managed = (DLManagedTensor*)PyCapsule_GetPointer( + data, at::DLPackTraits::capsule); + TORCH_CHECK(managed != nullptr, bad_capsule); + return tensor_fromDLPackImpl(data, managed); + } +} + Tensor asarray( PyObject* obj, std::optional dtype, @@ -1757,7 +1796,7 @@ Tensor asarray( #endif // Check whether 'obj' is a 'DLPack' capsule - if (!tensor.defined() && PyCapsule_IsValid(obj, "dltensor") != 0) { + if (!tensor.defined() && isValidDLPackCapsule(obj)) { tensor = tensor_fromDLPack(obj); } @@ -1786,7 +1825,7 @@ Tensor asarray( tensor = tensor.clone(); } } else { - // If we are not copying, we have to check whther we have the tensor + // If we are not copying, we have to check whether we have the tensor // in the right device, with the right dtype. TORCH_CHECK_VALUE( !wrong_device, diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 4b521f4c265980..ada10b665d055b 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -231,8 +231,13 @@ at::Tensor tensor_from_numpy( int ndim = PyArray_NDIM(array); auto sizes = to_aten_shape(ndim, PyArray_DIMS(array)); auto strides = to_aten_shape(ndim, PyArray_STRIDES(array)); + // This must go before the INCREF and element_size checks + // in case the dtype mapping doesn't exist and an exception is thrown + auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array)); // NumPy strides use bytes. Torch strides use element counts. - auto element_size_in_bytes = PyArray_ITEMSIZE(array); + const auto element_size_in_bytes = PyArray_ITEMSIZE(array); + TORCH_CHECK(element_size_in_bytes > 0, "element_size must be 0"); + for (auto& stride : strides) { TORCH_CHECK_VALUE( stride % element_size_in_bytes == 0, @@ -255,9 +260,6 @@ at::Tensor tensor_from_numpy( PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE), "given numpy array has byte order different from the native byte order. " "Conversion between byte orders is currently not supported."); - // This has to go before the INCREF in case the dtype mapping doesn't - // exist and an exception is thrown - auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array)); Py_INCREF(obj); return at::lift_fresh(at::from_blob( data_ptr, diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index a067af44a45271..a7c1d8cf5476e2 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -5,16 +5,21 @@ namespace torch::utils { -PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force = false); -at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable = true); +TORCH_API PyObject* tensor_to_numpy( + const at::Tensor& tensor, + bool force = false); -int aten_to_numpy_dtype(const at::ScalarType scalar_type); -at::ScalarType numpy_dtype_to_aten(int dtype); +TORCH_API at::Tensor tensor_from_numpy( + PyObject* obj, + bool warn_if_not_writeable = true); -bool is_numpy_available(); -bool is_numpy_int(PyObject* obj); -bool is_numpy_bool(PyObject* obj); -bool is_numpy_scalar(PyObject* obj); +TORCH_API int aten_to_numpy_dtype(const at::ScalarType scalar_type); +TORCH_API at::ScalarType numpy_dtype_to_aten(int dtype); + +TORCH_API bool is_numpy_available(); +TORCH_API bool is_numpy_int(PyObject* obj); +TORCH_API bool is_numpy_bool(PyObject* obj); +TORCH_API bool is_numpy_scalar(PyObject* obj); void warn_numpy_not_writeable(); at::Tensor tensor_from_cuda_array_interface(PyObject* obj); diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 917eeb4afc62b0..d696a0cdf4ddde 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -39,6 +39,8 @@ const char* backend_to_string(const at::Backend& backend) { return "torch.cuda.sparse"; case at::Backend::SparseXPU: return "torch.xpu.sparse"; + case at::Backend::SparseMPS: + return "torch.mps.sparse"; case at::Backend::QuantizedCPU: return "torch.quantized"; case at::Backend::HPU: diff --git a/torch/csrc/utils/throughput_benchmark.h b/torch/csrc/utils/throughput_benchmark.h index 8cf2f97158f2d8..075591220d6df3 100644 --- a/torch/csrc/utils/throughput_benchmark.h +++ b/torch/csrc/utils/throughput_benchmark.h @@ -18,7 +18,7 @@ namespace torch::throughput_benchmark { /** * The struct is used to provide results of a benchmark to the caller - * In the future all additional statics should be added here. + * In the future all additional statistics should be added here. */ struct BenchmarkExecutionStats { float latency_avg_ms{-1}; diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 6c32044900031d..715bf5b8fb66f7 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -305,6 +305,7 @@ static void registerXpuDeviceProperties(PyObject* module) { ._(name) \ ._(platform_name) \ ._(vendor) \ + ._(device_id) \ ._(driver_version) \ ._(version) \ ._(max_compute_units) \ @@ -333,8 +334,10 @@ static void registerXpuDeviceProperties(PyObject* module) { std::ostringstream stream; stream << "_XpuDeviceProperties(name='" << prop.name << "', platform_name='" << prop.platform_name << "', type='" - << get_device_type(prop) << "', driver_version='" - << prop.driver_version << "', total_memory=" + << get_device_type(prop) << "', device_id=0x" << std::hex + << std::uppercase << prop.device_id << std::dec + << ", driver_version='" << prop.driver_version + << "', total_memory=" << prop.global_mem_size / (1024ull * 1024) << "MB" << ", max_compute_units=" << prop.max_compute_units << ", gpu_eu_count=" << prop.gpu_eu_count diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4f38de8a2c5c81..88f26adeb4c3f5 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -333,6 +333,7 @@ class DeferredCudaCallError(Exception): pass +AcceleratorError = torch._C.AcceleratorError OutOfMemoryError = torch._C.OutOfMemoryError @@ -426,7 +427,7 @@ def cudart(): >>> from torch.cuda import cudart, check_error >>> import os >>> - >>> os.environ['CUDA_PROFILE'] = '1' + >>> os.environ["CUDA_PROFILE"] = "1" >>> >>> def perform_cuda_operations_with_streams(): >>> stream = torch.cuda.Stream() @@ -1746,7 +1747,7 @@ def _compile_kernel( >>> a = torch.randn(1024, device="cuda") >>> b = torch.randn(1024, device="cuda") >>> c = torch.empty_like(a) - >>> add_kernel(grid=(4,1,1), block=(256,1,1), args=[a, b, c, a.numel()]) + >>> add_kernel(grid=(4, 1, 1), block=(256, 1, 1), args=[a, b, c, a.numel()]) """ import ctypes diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 07527c397e5a5c..7f0f4fc3559ff5 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -89,16 +89,25 @@ def _block_extra(b): def format_flamegraph(flamegraph_lines, flamegraph_script=None): if flamegraph_script is None: - flamegraph_script = f"/tmp/{os.getuid()}_flamegraph.pl" + cache_dir = os.path.expanduser("~/.cache/") + os.makedirs(cache_dir, exist_ok=True) + flamegraph_script = f"{cache_dir}/flamegraph.pl" if not os.path.exists(flamegraph_script): + import tempfile import urllib.request print(f"Downloading flamegraph.pl to: {flamegraph_script}") - urllib.request.urlretrieve( - "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", - flamegraph_script, - ) - subprocess.check_call(["chmod", "+x", flamegraph_script]) + with tempfile.NamedTemporaryFile(mode="wb", suffix=".pl") as f: + urllib.request.urlretrieve( + "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", + f.name, + ) + try: + os.chmod(f.name, 0o755) + os.rename(f.name, flamegraph_script) + except OSError: # noqa: B001,E722 + # Ok to skip, the file will be removed by tempfile + pass args = [flamegraph_script, "--countname", "bytes"] p = subprocess.Popen( args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding="utf-8" @@ -124,7 +133,7 @@ def frames_fragment(frames): if "history" not in b: frames, accounted_for_size = _block_extra(b) f.write( - f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' + f"{prefix};{b['state']};{frames_fragment(frames)} {accounted_for_size}\n" ) else: accounted_for_size = 0 @@ -133,18 +142,18 @@ def frames_fragment(frames): accounted_for_size += sz if "frames" in h: frames = h["frames"] - f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + f.write(f"{prefix};{b['state']};{frames_fragment(frames)} {sz}\n") else: - f.write(f'{prefix};{b["state"]}; {sz}\n') + f.write(f"{prefix};{b['state']}; {sz}\n") gaps = b["size"] - accounted_for_size if gaps: - f.write(f'{prefix};{b["state"]}; {gaps}\n') + f.write(f"{prefix};{b['state']}; {gaps}\n") def segments(snapshot, format_flamegraph=format_flamegraph): f = io.StringIO() for seg in snapshot["segments"]: - prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' + prefix = f"stream_{seg['stream']};seg_{seg['address']}" _write_blocks(f, prefix, seg["blocks"]) return format_flamegraph(f.getvalue()) @@ -152,7 +161,7 @@ def segments(snapshot, format_flamegraph=format_flamegraph): def memory(snapshot, format_flamegraph=format_flamegraph): f = io.StringIO() for seg in snapshot["segments"]: - prefix = f'stream_{seg["stream"]}' + prefix = f"stream_{seg['stream']}" _write_blocks(f, prefix, seg["blocks"]) return format_flamegraph(f.getvalue()) @@ -162,7 +171,7 @@ def _seg_key(seg): return (seg["address"], seg["total_size"]) def _seg_info(seg): - return f'stream_{seg["stream"]};seg_{seg["address"]}' + return f"stream_{seg['stream']};seg_{seg['address']}" f = io.StringIO() @@ -292,18 +301,18 @@ def segsum(data): occupied[j] = "0123456789*"[int(frac[j] * 10)] else: occupied[j] = m - stream = "" if seg["stream"] == 0 else f', stream_{seg["stream"]}' + stream = "" if seg["stream"] == 0 else f", stream_{seg['stream']}" body = "".join(occupied) assert ( seg_free_external + seg_free_internal + seg_allocated == seg["total_size"] ) - stream = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + stream = f" stream_{seg['stream']}" if seg["stream"] != 0 else "" if seg["total_size"] >= PAGE_SIZE: out.write( - f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"[{body}] {Bytes(seg['total_size'])} allocated, " f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" ) - out.write(f'segments: {len(data["segments"])}\n') + out.write(f"segments: {len(data['segments'])}\n") out.write(f"total_reserved: {Bytes(total_reserved)}\n") out.write(f"total_allocated: {Bytes(total_allocated)}\n") out.write(f"total_free: {_report_free(free_external, free_internal)}\n") @@ -329,7 +338,7 @@ def _name(): return free_names.pop() r, m = next_name // 26, next_name % 26 next_name += 1 - return f'{chr(ord("a") + m)}{"" if r == 0 else r}' + return f"{chr(ord('a') + m)}{'' if r == 0 else r}" def find_segment(addr): for name, saddr, size in segment_intervals: diff --git a/torch/cuda/_pin_memory_utils.py b/torch/cuda/_pin_memory_utils.py new file mode 100644 index 00000000000000..d3c01f3293f7e3 --- /dev/null +++ b/torch/cuda/_pin_memory_utils.py @@ -0,0 +1,24 @@ +import torch + + +def pin_memory(data_ptr: int, size: int) -> None: + cudart = torch.cuda.cudart() + succ = int( + cudart.cudaHostRegister( + data_ptr, + size, + 1, # lines up with 'cudaHostRegisterPortable' + ) + ) + + if succ != 0: + raise RuntimeError( + f"Registering memory failed with cudaError: {succ}." + " It's possible that this is an asynchronous error raised from a previous cuda operation." + " Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug." + ) + + +def unpin_memory(data_ptr: int) -> None: + succ = int(torch.cuda.cudart().cudaHostUnregister(data_ptr)) + assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}" diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index 2c643250e61a61..5fdcd65ddf7b79 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -1,5 +1,4 @@ import ctypes -import os import sys from typing import Any, Optional, Union @@ -9,28 +8,6 @@ from torch._utils import _get_device_index as _torch_get_device_index -def _get_nvrtc_version(cuda_version: int) -> str: - # TODO: Expose this from native code - # Follows same logic as LazyNVRTC.cpp getLibVersion() - major = cuda_version // 1000 - minor = (cuda_version // 10) % 10 - - if sys.platform == "win32": - if major < 11 or (major == 11 and minor < 3): - return f"{major}{minor}" - elif major == 11: - return "112" - else: - return f"{major}0" - else: - if major < 11 or (major == 11 and minor < 3): - return f"{major}.{minor}" - elif major == 11: - return "11.2" - else: - return str(major) - - # Load CUDA driver and NVRTC def _get_cuda_library() -> ctypes.CDLL: if sys.platform == "win32": @@ -53,51 +30,12 @@ def _check_cuda(result: int) -> None: def _get_nvrtc_library() -> ctypes.CDLL: - # Get NVRTC version based on CUDA runtime version - # Use an alternative approach to get the CUDA version - # since cudart().getVersion() is failing - import torch - - try: - import torch.cuda - - cuda_runtime_version = torch.cuda.cudart().getVersion() - except (ImportError, AttributeError): - # Fallback: if we have CUDA available, get version from device properties - if hasattr(torch, "cuda") and torch.cuda.is_available(): - # Import locally to avoid circular imports - import torch.cuda - - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - cuda_runtime_version = props.major * 1000 + props.minor * 10 - else: - # Hardcode a default CUDA version if all else fails - cuda_runtime_version = 12000 # Assume CUDA 12.0 as default - - version = _get_nvrtc_version(cuda_runtime_version) - + # Since PyTorch already loads NVRTC, we can use the system library + # which should be compatible with PyTorch's version if sys.platform == "win32": - # Windows handling remains the same - lib_name = f"nvrtc64_{version}_0.dll" - return ctypes.CDLL(lib_name) + return ctypes.CDLL("nvrtc64_120_0.dll") else: - lib_paths = [ - f"libnvrtc.so.{version}", - os.path.join( - os.environ.get("CUDA_HOME", ""), f"lib64/libnvrtc.so.{version}" - ), - "/usr/local/cuda/lib64/libnvrtc.so", - ] - - for path in lib_paths: - try: - return ctypes.CDLL(path) - except OSError: - continue - - raise RuntimeError( - "Could not find libnvrtc.so. Please make sure CUDA is installed." - ) + return ctypes.CDLL("libnvrtc.so") def _nvrtc_compile( @@ -294,8 +232,10 @@ def __call__( for arg in args: if isinstance(arg, torch.Tensor): - if not arg.is_cuda: - raise ValueError("All tensor arguments must be CUDA tensors") + if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()): + raise ValueError( + "All tensor arguments must be CUDA tensors or pinned CPU tensors" + ) # Get pointer to tensor data ptr = ctypes.c_void_p(arg.data_ptr()) processed_args.append(ptr) @@ -326,23 +266,21 @@ def __call__( stream = torch.cuda.current_stream() - # Launch the kernel with the current stream - with stream: - _check_cuda( - libcuda.cuLaunchKernel( - self.func, - grid[0], - grid[1], - grid[2], - block[0], - block[1], - block[2], - shared_mem, - None, - c_args_array, - None, - ) + _check_cuda( + libcuda.cuLaunchKernel( + self.func, + grid[0], + grid[1], + grid[2], + block[0], + block[1], + block[2], + shared_mem, + stream._as_parameter_, + c_args_array, + None, ) + ) def _cuda_load_module( diff --git a/torch/cuda/gds.py b/torch/cuda/gds.py index 5ed03be85bdd79..d3922499682e4d 100644 --- a/torch/cuda/gds.py +++ b/torch/cuda/gds.py @@ -119,9 +119,9 @@ def register_handle(self) -> None: This is a wrapper around ``cuFileHandleRegister``. """ - assert ( - self.handle is None - ), "Cannot register a handle that is already registered." + assert self.handle is None, ( + "Cannot register a handle that is already registered." + ) self.handle = torch._C._gds_register_handle(self.fd) def deregister_handle(self) -> None: @@ -129,9 +129,9 @@ def deregister_handle(self) -> None: This is a wrapper around ``cuFileHandleDeregister``. """ - assert ( - self.handle is not None - ), "Cannot deregister a handle that is not registered." + assert self.handle is not None, ( + "Cannot deregister a handle that is not registered." + ) torch._C._gds_deregister_handle(self.handle) self.handle = None @@ -145,9 +145,9 @@ def load_storage(self, storage: Storage, offset: int = 0) -> None: storage (Storage): Storage to load data into. offset (int, optional): Offset into the file to start loading from. (Default: 0) """ - assert ( - self.handle is not None - ), "Cannot load data from a file that is not registered." + assert self.handle is not None, ( + "Cannot load data from a file that is not registered." + ) torch._C._gds_load_storage(self.handle, storage, offset) def save_storage(self, storage: Storage, offset: int = 0) -> None: @@ -160,7 +160,7 @@ def save_storage(self, storage: Storage, offset: int = 0) -> None: storage (Storage): Storage to save data from. offset (int, optional): Offset into the file to start saving to. (Default: 0) """ - assert ( - self.handle is not None - ), "Cannot save data to a file that is not registered." + assert self.handle is not None, ( + "Cannot save data to a file that is not registered." + ) torch._C._gds_save_storage(self.handle, storage, offset) diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 226278aabc1f85..f42da65ed56da6 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -46,12 +46,32 @@ def graph_pool_handle(): class CUDAGraph(torch._C._CUDAGraph): r"""Wrapper around a CUDA graph. + Arguments: + keep_graph (bool, optional): If ``keep_graph=False``, the + cudaGraphExec_t will be instantiated on GPU at the end of + ``capture_end`` and the underlying cudaGraph_t will be + destroyed. Users who want to query or otherwise modify the + underlying cudaGraph_t before instantiatiation can set + ``keep_graph=True`` and access it via ``raw_cuda_graph`` after + ``capture_end``. Note that the cudaGraphExec_t will not be + instantiated at the end of ``capture_end`` in this + case. Instead, it wil be instantiated via an explicit called + to ``instantiate`` or automatically on the first call to + ``replay`` if ``instantiate`` was not already called. Calling + ``instantiate`` manually before ``replay`` is recommended to + prevent increased latency on the first call to ``replay``. It + is allowed to modify the raw cudaGraph_t after first calling + ``instantiate``, but the user must call ``instantiate`` again + manually to make sure the instantiated graph has these + changes. Pytorch has no means of tracking these changes. + .. warning:: This API is in beta and may change in future releases. + """ - def __new__(cls): - return super().__new__(cls) + def __new__(cls, keep_graph=False): + return super().__new__(cls, keep_graph) def capture_begin(self, pool=None, capture_error_mode="global"): r"""Begin capturing CUDA work on the current stream. @@ -83,6 +103,15 @@ def capture_end(self): """ super().capture_end() + def instantiate(self): + r"""Instantiate the CUDA graph. Will be called by + ``capture_end`` if ``keep_graph=False``, or by ``replay`` if + ``keep_graph=True`` and ``instantiate`` has not already been + explicitly called. Does not destroy the cudaGraph_t returned + by ``raw_cuda_graph``. + """ + super().instantiate() + def replay(self): r"""Replay the CUDA work captured by this graph.""" super().replay() @@ -113,6 +142,13 @@ def debug_dump(self, debug_path): """ return super().debug_dump(debug_path) + def raw_cuda_graph(self): + r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. + + See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ + """ # noqa: B950 + return super().raw_cuda_graph() + class graph: r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay. @@ -479,7 +515,9 @@ def new_fwd(*user_args): return new_fwd - func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] + func.forward = make_graphed_forward( + func, func.training, graphed, func.forward + ) # type: ignore[assignment] ret.append(func) else: ret.append(graphed) diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index e0c5decc0effd8..8bcb14d9fcfbda 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -57,9 +57,9 @@ def __init__( ): self.code_string = code_string - assert ( - return_by_ref or num_outputs == 1 - ), "Return by value only works for single output. " + assert return_by_ref or num_outputs == 1, ( + "Return by value only works for single output. " + ) self.return_by_ref = return_by_ref self.num_outputs = num_outputs @@ -72,9 +72,9 @@ def __init__( def __call__(self, *tensors: Tensor, **kwargs): # Jiterator follow torch.cuda's lazy initialization behavior # Defer checking cuda's availability at the function invocation time - assert ( - self.is_cuda_available - ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." + assert self.is_cuda_available, ( + "Jiterator is only supported on CUDA and ROCm GPUs, none are available." + ) assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." @@ -114,8 +114,8 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) - a = torch.rand(3, device='cuda') - b = torch.rand(3, device='cuda') + a = torch.rand(3, device="cuda") + b = torch.rand(3, device="cuda") # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14) @@ -123,11 +123,13 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: Example:: - code_string = "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" + code_string = ( + "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" + ) code_string += "template T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" jitted_fn = create_jit_fn(code_string, val=0.0) - a = torch.rand(3, device='cuda') - b = torch.rand(3, device='cuda') + a = torch.rand(3, device="cuda") + b = torch.rand(3, device="cuda") # invoke jitted function like a regular python function result = jitted_fn(a, b) # using default val=0.0 @@ -139,9 +141,9 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: code_string = "template T my_gelu(T a) { return a > 0 ? a : 0; }" my_gelu = create_jit_fn(code_string) my_lib = torch.library.Library("aten", "IMPL") - my_lib.impl('aten::gelu', my_gelu, "CUDA") + my_lib.impl("aten::gelu", my_gelu, "CUDA") # torch.nn.GELU and torch.nn.function.gelu are now overridden - a = torch.rand(3, device='cuda') + a = torch.rand(3, device="cuda") torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a)) .. warning:: @@ -171,8 +173,8 @@ def _create_multi_output_jit_fn( code_string = "template void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) - a = torch.rand(3, device='cuda') - b = torch.rand(3, device='cuda') + a = torch.rand(3, device="cuda") + b = torch.rand(3, device="cuda") # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14) diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index df02fb28d44f95..e6c34ef64f5a04 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -847,6 +847,7 @@ def _record_memory_history_legacy( record_context_cpp=False, clear_history=False, compile_context=False, + global_record_annotations=False, ): _C._cuda_record_memory_history_legacy( # type: ignore[call-arg] enabled, @@ -856,6 +857,7 @@ def _record_memory_history_legacy( record_context_cpp, clear_history, compile_context, + global_record_annotations, ) @@ -866,12 +868,42 @@ def _record_memory_history( allocations, so you can tell what allocated any piece of memory in :func:`torch.cuda.memory._snapshot()`. - In addition too keeping stack traces with each current allocation and free, + In addition to keeping stack traces with each current allocation and free, this will also enable recording of a history of all alloc/free events. Use :func:`torch.cuda.memory._snapshot()` to retrieve this information, and the tools in `_memory_viz.py` to visualize snapshots. + Buffer behavior + --------------- + + This will store up to `max_entries` instances of `TraceEntry` when enabled. + Python trace collection defaults to `sys.maxsize`, meaning long-running + or indefinitely running jobs should set a reasonable limit to avoid excessive + memory use. Expect each entry to be several KB. + + Longer running workflows or those with smaller `max_entries` values will only + store the last accumulated `max_entries` entries, meaning new entries overwrite + older entries. + + C++ implementation for reference to ring buffer implemenation: + + .. code-block:: cpp + + if (record_history) { + if (alloc_trace->size() < alloc_trace_max_entries_) { + alloc_trace->emplace_back(te); + } else { + (*alloc_trace)[alloc_trace_next++] = te; + if (alloc_trace_next == alloc_trace_max_entries_) { + alloc_trace_next = 0; + } + } + } + + Latency impact + -------------- + The Python trace collection is fast (2us per trace), so you may consider enabling this on production jobs if you anticipate ever having to debug memory issues. @@ -912,9 +944,16 @@ def _record_memory_history_impl( device: "Device" = None, clear_history: bool = False, compile_context: bool = False, + global_record_annotations: bool = False, ): _C._cuda_record_memory_history( # type: ignore[call-arg] - enabled, context, stacks, max_entries, clear_history, compile_context + enabled, + context, + stacks, + max_entries, + clear_history, + compile_context, + global_record_annotations, ) @@ -929,9 +968,10 @@ def _snapshot(device: "Device" = None): .. code-block:: python class Snapshot(TypedDict): - segments : List[Segment] + segments: List[Segment] device_traces: List[List[TraceEntry]] + class Segment(TypedDict): # Segments are memory returned from a cudaMalloc call. # The size of reserved memory is the sum of all Segments. @@ -940,57 +980,62 @@ class Segment(TypedDict): # is split into more then one Block. # empty_cache() frees Segments that are entirely inactive. address: int - total_size: int # cudaMalloc'd size of segment + total_size: int # cudaMalloc'd size of segment stream: int - segment_type: Literal['small', 'large'] # 'large' (>1MB) - allocated_size: int # size of memory in use - active_size: int # size of memory in use or in active_awaiting_free state - blocks : List[Block] + segment_type: Literal["small", "large"] # 'large' (>1MB) + allocated_size: int # size of memory in use + active_size: int # size of memory in use or in active_awaiting_free state + blocks: List[Block] + class Block(TypedDict): # A piece of memory returned from the allocator, or # current cached but inactive. size: int - requested_size: int # size requested during malloc, may be smaller than - # size due to rounding + requested_size: int # size requested during malloc, may be smaller than + # size due to rounding address: int - state: Literal['active_allocated', # used by a tensor - 'active_awaiting_free', # waiting for another stream to finish using - # this, then it will become free - 'inactive',] # free for reuse - frames: List[Frame] # stack trace from where the allocation occurred + state: Literal[ + "active_allocated", # used by a tensor + "active_awaiting_free", # waiting for another stream to finish using + # this, then it will become free + "inactive", + ] # free for reuse + frames: List[Frame] # stack trace from where the allocation occurred + class Frame(TypedDict): - filename: str - line: int - name: str + filename: str + line: int + name: str + class TraceEntry(TypedDict): # When `torch.cuda.memory._record_memory_history()` is enabled, # the snapshot will contain TraceEntry objects that record each # action the allocator took. action: Literal[ - 'alloc' # memory allocated - 'free_requested', # the allocated received a call to free memory - 'free_completed', # the memory that was requested to be freed is now - # able to be used in future allocation calls - 'segment_alloc', # the caching allocator ask cudaMalloc for more memory - # and added it as a segment in its cache - 'segment_free', # the caching allocator called cudaFree to return memory - # to cuda possibly trying free up memory to - # allocate more segments or because empty_caches was called - 'oom', # the allocator threw an OOM exception. 'size' is - # the requested number of bytes that did not succeed - 'snapshot' # the allocator generated a memory snapshot - # useful to coorelate a previously taken - # snapshot with this trace + "alloc" # memory allocated + "free_requested", # the allocated received a call to free memory + "free_completed", # the memory that was requested to be freed is now + # able to be used in future allocation calls + "segment_alloc", # the caching allocator ask cudaMalloc for more memory + # and added it as a segment in its cache + "segment_free", # the caching allocator called cudaFree to return memory + # to cuda possibly trying free up memory to + # allocate more segments or because empty_caches was called + "oom", # the allocator threw an OOM exception. 'size' is + # the requested number of bytes that did not succeed + "snapshot", # the allocator generated a memory snapshot + # useful to coorelate a previously taken + # snapshot with this trace ] - addr: int # not present for OOM + addr: int # not present for OOM frames: List[Frame] size: int stream: int - device_free: int # only present for OOM, the amount of - # memory cuda still reports to be free + device_free: int # only present for OOM, the amount of + # memory cuda still reports to be free Returns: The Snapshot dictionary object @@ -1004,6 +1049,10 @@ def _dump_snapshot(filename="dump_snapshot.pickle"): This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz + Snapshot file sizes scale with `max_entries` and stack trace depth per entry, + with several KB per entry. These can easily be in the GB range for longer running + workflows with large `max_entries`. + Args: filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle". """ @@ -1120,21 +1169,28 @@ class MemPool(_MemPool): use_on_oom(bool): a bool that indicates if this pool can be used as a last resort if a memory allocation outside of the pool fails due to Out Of Memory. This is False by default. - + symmetric(bool): a bool that indicates if this pool is symmetrical + across ranks. This is False by default. """ def __init__( self, allocator: Optional[_cuda_CUDAAllocator] = None, use_on_oom: bool = False, + symmetric: bool = False, ): - super().__init__(allocator, True, use_on_oom) + super().__init__(allocator, True, use_on_oom, symmetric) @property def id(self) -> tuple[int, int]: r"""Returns the ID of this pool as a tuple of two ints.""" return super().id + @property + def is_symmetric(self) -> bool: + r"""Returns whether this pool is used for NCCL's symmetric memory.""" + return super().is_symmetric + @property def allocator(self) -> Optional[_cuda_CUDAAllocator]: r"""Returns the allocator this MemPool routes allocations to.""" diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 15095868bbad3f..023f5f9a53b221 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -158,17 +158,21 @@ class Event(torch._C._CudaEventBase): blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) interprocess (bool): if ``True``, the event can be shared between processes (default: ``False``) + external (bool, optional): indicates whether this event should create event record and event wait nodes, or create an internal cross-stream dependency, when captured in a cuda graph. See `cross-stream dependencies `_, `cudaEventRecordExternal `_, and `cudaEventWaitExternal `_ for more information about internal vs. external events. (default: ``False``) .. _CUDA Event Documentation: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html - """ + """ # noqa: B950 - def __new__(cls, enable_timing=False, blocking=False, interprocess=False): + def __new__( + cls, enable_timing=False, blocking=False, interprocess=False, external=False + ): return super().__new__( cls, enable_timing=enable_timing, blocking=blocking, interprocess=interprocess, + external=external, ) @classmethod diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index 4b75f8dc5b3c09..0e3e692a1f2533 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -124,11 +124,11 @@ There are basically two steps: 1) Set the environment variables to collect the untuned GEMM and this will generate ``tunableop_untuned0.csv``: -.. code-block:: python +.. code-block:: bash - PYTORCH_TUNABLEOP_ENABLED=1 - PYTORCH_TUNABLEOP_TUNING=0 - PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 + export PYTORCH_TUNABLEOP_ENABLED=1 + export PYTORCH_TUNABLEOP_TUNING=0 + export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 ... 2) Run a Python script that reads the ``tunableop_untuned0.csv`` and generates the ``tunableop_results0.csv``, like this: @@ -138,9 +138,9 @@ import torch.cuda.tunable as tunable import os - os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1') - os.putenv('PYTORCH_TUNABLEOP_TUNING', '1') - os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0') + os.putenv("PYTORCH_TUNABLEOP_ENABLED", "1") + os.putenv("PYTORCH_TUNABLEOP_TUNING", "1") + os.putenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED", "0") tunable.tune_gemm_in_file("tunableop_untuned0.csv") @@ -155,7 +155,7 @@ .. code-block:: python if __name__ == "__main__": - num_gpus = 8 # number of GPUs that will be used during the tuning process + num_gpus = 8 # number of GPUs that will be used during the tuning process tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus) Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single GPU counterpart @@ -179,6 +179,7 @@ Use the C++ or Python APIs instead. """ + import concurrent.futures import glob import multiprocessing as mp diff --git a/torch/distributed/_checkpointable.py b/torch/distributed/_checkpointable.py index bc0a288f1291f3..0594c20337b3bf 100644 --- a/torch/distributed/_checkpointable.py +++ b/torch/distributed/_checkpointable.py @@ -1,5 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Protocol, runtime_checkable +from typing_extensions import Protocol, runtime_checkable import torch diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0532b4ccf5b23f..ec51b2b7a18174 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -635,7 +635,7 @@ def _get_acs_underlying_tensor(self): return self.elem @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] if func == torch.ops.aten.view.default: # Fast handle aten.view as a lot of view related op goes to aten.view # eventually, this avoids pytree slowdown diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index 4c49f2585bcb6d..2aa9786c0e47ba 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -97,7 +97,7 @@ def _streaming_save( This behaves similarly to :func:`torch.save` with a few notable differences: * A non-seekable file like object can be used when loading. - * No forwards/backwards compatiblity is provided for the serialization + * No forwards/backwards compatibility is provided for the serialization format. This is only intended to be used with a single version of PyTorch with transient storage (i.e. sockets or temp files). * mmap is not supported diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 5b8849e27d500c..2bfbbcb575cd65 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -184,7 +184,7 @@ def _init_from_local_shards_and_global_metadata( return sharded_tensor_base @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] raise RuntimeError( f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " "but the there is no custom __torch_dispatch__ implementation for it." @@ -515,7 +515,7 @@ def cpu( .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might need to be managed by a different type of ProcessGroup(i.e. ProcessGroupGloo), - it is the user's responsiblity to explicitly pass in a new process_group that + it is the user's responsibility to explicitly pass in a new process_group that is compatible with CPU. """ # TODO: make this a __torch_function__ op once ShardedTensor becomes a @@ -575,7 +575,7 @@ def cuda( metadata, but no underlying data movements are performed. .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), - it is the user's responsiblity to explicitly pass in a new process_group that + it is the user's responsibility to explicitly pass in a new process_group that is compatible with GPU. """ if ( diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index 67527947dc67c8..5ddb05d4d3c05f 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -288,7 +288,7 @@ def recalc_global_sharded_tensor_metadata( placement_idx_pairs.append((shard_metadata.placement.rank(), i)) else: raise AssertionError( - "currently only support rw, it should alwyas have vaid rank info" + "currently only support rw, it should always have valid rank info" ) sorted_idx = sorted(placement_idx_pairs) shard_sizes = [ diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index fe50bef5339edb..61a2729ec45e9f 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -7,6 +7,7 @@ from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union import torch +import torch.cuda._pin_memory_utils as pin_memory_utils import torch.distributed as dist import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor @@ -187,6 +188,12 @@ def _iterate_state_dict( companion_obj._local_tensor.copy_( ret._local_tensor, non_blocking=non_blocking ) + elif isinstance(companion_obj, ShardedTensor): + assert isinstance(ret, ShardedTensor) + for idx, shard in enumerate(companion_obj.local_shards()): + shard.tensor.copy_( + ret.local_shards()[idx].tensor, non_blocking=non_blocking + ) else: companion_obj.copy_(ret, non_blocking=non_blocking) ret = companion_obj @@ -402,28 +409,22 @@ def tensor_func( if len(obj.size()) == 0: return torch.tensor(0, dtype=obj.dtype) + # sometimes, a tensor might have non-zero size and 0 numel. In this case, pinning memory will fail + # so we take a best guess at how to replicate the tensor below to maintain symmetry in the returned + # state dict. + if obj.numel() == 0 or obj.data_ptr() == 0: + t = torch.zeros_like(obj, device="cpu") + if share_memory: + t = t.share_memory_() + return t + if share_memory: t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) t = t.share_memory_() if pin_memory: + pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size()) + weakref.finalize(t, pin_memory_utils.unpin_memory, t) - def unpin_memory(t): - succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) - assert succ == 0, ( - f"Unpinning shared memory failed with error-code: {succ}" - ) - - weakref.finalize(t, unpin_memory, t) - succ = int( - torch.cuda.cudart().cudaHostRegister( - t.data_ptr(), - t.numel() * t.element_size(), - 1, # lines up with 'cudaHostRegisterPortable' - ) - ) - assert succ == 0, ( - f"Pinning shared memory failed with error-code: {succ}" - ) return t elif pin_memory: return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() @@ -446,9 +447,28 @@ def dtensor_func( ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None) return ret + def sharded_tensor_func( + obj: ShardedTensor, + pg: Optional[dist.ProcessGroup], + device: Optional[torch.device], + _: Any, + ) -> ShardedTensor: + if not obj.local_shards(): + return obj + + if obj.device != torch.device("cpu"): + ret = obj.to(device="cpu") + else: + ret = copy.deepcopy(obj) + + for shards in ret.local_shards(): + shards.tensor = tensor_func(shards.tensor, pg, device, None) + + return ret + ret = _iterate_state_dict( state_dict, - _identity_func, + sharded_tensor_func, dtensor_func, tensor_func, pg=None, @@ -592,7 +612,7 @@ def _distribute_tensors( ] if local_state.is_meta: # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. - local_tensor = full_tensor[slices].detach().clone() + local_tensor = full_tensor[tuple(slices)].detach().clone() # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). ret = DTensor.from_local( @@ -605,7 +625,7 @@ def _distribute_tensors( else: ret = local_state # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. - ret.to_local().copy_(full_tensor[slices]) + ret.to_local().copy_(full_tensor[tuple(slices)]) local_state_dict[key] = ret diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index dcc4a41194904e..634e953aeb36b3 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -7,7 +7,7 @@ from datetime import timedelta from enum import Enum from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, Literal, Optional import torch import torch.distributed._functional_collectives as funcol @@ -1235,7 +1235,7 @@ def _fused_scaled_matmul_reduce_scatter_impl( # To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs. A_2D_with_scatter_dim_0 = A_with_scatter_dim_0.flatten(0, -2) - # Parition A along the first dim to prepare for sharding across TP process group. + # Partition A along the first dim to prepare for sharding across TP process group. A_shards = A_2D_with_scatter_dim_0.chunk(group.size()) # Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly. @@ -1704,4 +1704,46 @@ def rendezvous( return _SymmetricMemory.rendezvous(tensor, group_name) -__all__ = ["empty", "rendezvous"] +def is_nvshmem_available() -> bool: + r""" + is_nvshmem_available() -> bool + + Check if NVSHMEM is available in current build and on current system. + """ + try: + from torch._C._distributed_c10d import _is_nvshmem_available + except ImportError: + # Not all builds have NVSHMEM support. + return False + + # Check if NVSHMEM is available on current system. + return _is_nvshmem_available() + + +def set_backend(name: Literal["NVSHMEM", "CUDA", "NCCL"]) -> None: + r""" + Set the backend for symmetric memory allocation. This is a global setting + and affects all subsequent calls to + :func:`torch._distributed._symmetric_memory.empty()`. Note that the backend + cannot be changed once a symmetric memory tensor has been allocated. + + Args: + backend (str): the backend for symmetric memory allocation. Currently, + only "NVSHMEM", "CUDA", "NCCL" are supported. + """ + _SymmetricMemory.set_backend(name) + + +def get_backend(device: _device) -> Optional[str]: + r""" + Get the backend for symmetric memory allocation for a given device. If not + found, return None. + + Args: + device (class:`torch.device` or str): the device for which to get the + backend. + """ + return _SymmetricMemory.get_backend(torch.device(device)) + + +__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"] diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py new file mode 100644 index 00000000000000..75abae38c755a6 --- /dev/null +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -0,0 +1,199 @@ +import os +import sysconfig +from typing import Optional + +from torch.utils._triton import has_triton + + +def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: + """ + Enable NVSHMEM device functions for Triton. It performs a NVSHMEM + device-side initialization on the kernel module created by Triton. + + Args: + lib_dir (Optional[str]): The directory where the NVSHMEM device library + is located. If not provided, it will use the default path where NVSHMEM + wheel is installed. + + Returns: + dict[str, str]: A dictionary containing the NVSHMEM device library name + and path. + """ + from triton.runtime.jit import JITFunction + + from torch._C._distributed_c10d import _nvshmemx_cumodule_init + + # Detect NVSHMEM device library path from python library path + if lib_dir is None: + py_lib_path = sysconfig.get_path("purelib") + lib_dir = py_lib_path + "/nvidia/nvshmem/lib" + + lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError("NVSHMEM device library not found") + + extern_libs = {"libnvshmem_device": lib_path} + + # A hook function to initialize NVSHMEM in Triton + def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache, _, _, _ = jit_function.device_caches[device] + kernel = kernel_cache.get(key, None) + kernel.run + _nvshmemx_cumodule_init(kernel.module) + + # Register the function as a post-compile hook + JITFunction.compiled_hook = nvshmem_init_hook + + # Return to user so that they can use it in Triton kernel invocation + return extern_libs + + +if has_triton(): + from triton.language import core + + @core.extern + def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_getmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def putmem_signal_block( # type: ignore[no-untyped-def] + dst, + src, + nelems, + sig_addr, + signal, + sig_op, + pe, + _builder=None, + ): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, sig_addr, signal, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [ivar, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_longlong_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [sig_addr, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_signal_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [sig_addr, signal, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_signal_op", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def fence(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_fence", core.dtype("int32")), + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def quiet(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_quiet", core.dtype("int32")), + }, + is_pure=False, + _builder=_builder, + ) diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index 2eab61e12401bf..5ab0da5522145b 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -170,7 +170,7 @@ def __init__( def _instrument_fsdp_sharded_params_grads( self, fsdp_param_group: FSDPParamGroup ) -> None: - # Track sharded params and grads after initilization + # Track sharded params and grads after initialization for fsdp_param in fsdp_param_group.fsdp_params: self._update_and_maybe_create_winfos( fsdp_param.sharded_param, @@ -199,7 +199,7 @@ def _fsdp_state_pre_forward( # this module is called for the second time. If it is a root module, that means we are in the next # iteration and we error out. If it is not a root module, that means it's a submodule that is being # used multiple times in the same iteration, which we allow and track. - # For Case 1 and 3, we also initialiaze the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op. @wraps(orig_fsdp_state_pre_fw) def inner( diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 38a25eb2a29453..097cf0fba54a2f 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -643,7 +643,7 @@ def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: # this module is called for the second time. If it is a root module, that means we are in the next # iteration and we error out. If it is not a root module, that means it's a submodule that is being # used multiple times in the same iteration, which we allow and track. - # For Case 1 and 3, we also initialiaze the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. mod_name = self._mod_tracker.get_known_fqn(module) assert mod_name is not None if module not in self.memory_tracking: diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index f5226b9fb38f0f..290846d604b780 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -81,7 +81,8 @@ def __init__(self) -> None: self._markers: dict[str, int] = defaultdict(int) self._cur_module_name: str = "" self._op_index: int = 0 - self._num_cuda_retries: int = 0 + self._num_alloc_retries: int = 0 + self._device_module = torch.get_device_module() @no_type_check def start_monitor(self, root_module: nn.Module) -> None: @@ -106,7 +107,7 @@ def start_monitor(self, root_module: nn.Module) -> None: # clear and remove it for now as it does not really capture important info. # h3 = m.register_backward_hook(self._create_backward_hook(name)) self._hooks.extend([h1, h2]) - torch.cuda.empty_cache() + self._device_module.empty_cache() assert getattr(self, "profile_mode", None) is None self.profile_mode = MemoryProfileDispatchMode(self) self.profile_mode.__enter__() @@ -116,9 +117,11 @@ def stop(self) -> None: """ Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level. - Get some aggregated stats when the memory_tracker() is enabled, like cuda ``num_alloc_retries``. + Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``. """ - self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0) + self._num_alloc_retries = self._device_module.memory_stats().get( + "num_alloc_retries", 0 + ) for h in self._hooks: h.remove() @@ -142,7 +145,7 @@ def summary(self, top: int = 20) -> None: previous_allocated_memory = current_allocated_memory print("------------------------------------------------") - print(f"The number of cuda retries are: {self._num_cuda_retries}") + print(f"The number of alloc retries are: {self._num_alloc_retries}") print(f"Top {top} ops that generates memory are:") for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ :top @@ -206,7 +209,7 @@ def save_stats(self, path: str) -> None: "memories_active": self.memories_active, "memories_reserved": self.memories_reserved, "markers": self._markers, - "num_alloc_retries": self._num_cuda_retries, + "num_alloc_retries": self._num_alloc_retries, } with open(path, "wb") as f: @@ -221,7 +224,7 @@ def load(self, path: str) -> None: self.memories_active = stats["memories_active"] self.memories_reserved = stats["memories_reserved"] self._markers = stats["markers"] - self._num_cuda_retries = stats["num_alloc_retries"] + self._num_alloc_retries = stats["num_alloc_retries"] def _create_pre_forward_hook(self, name: str) -> Callable: """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" @@ -269,10 +272,11 @@ def _record_memory_stats(self, fn_name: str) -> None: The memory stats dict is indexed with ``self._op_index``. """ - memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB - memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB + memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB + memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB memory_active: float = ( - torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB + self._device_module.memory_stats().get("active_bytes.all.current", 0) + / BYTES_PER_MB ) self.memories_allocated[self._op_index] = (fn_name, memory_allocated) self.memories_reserved[self._op_index] = (fn_name, memory_reserved) @@ -293,4 +297,4 @@ def _clear_state(self) -> None: self._markers.clear() self._cur_module_name = "" self._op_index = 0 - self._num_cuda_retries = 0 + self._num_alloc_retries = 0 diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index 5dabb23b6347a9..734e463fceaa66 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -83,7 +83,7 @@ class RuntimeEstimator(TorchDispatchMode): This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and roofline cost modeling (`operator-level-cost-model`). - For modules executed under this context manager, it agggregates the forward and backward operation runtimes + For modules executed under this context manager, it aggregates the forward and backward operation runtimes and also records their execution orders. Attributes: diff --git a/torch/distributed/_tools/sac_estimator.py b/torch/distributed/_tools/sac_estimator.py index 9af5e538ca7b10..55b66777614179 100644 --- a/torch/distributed/_tools/sac_estimator.py +++ b/torch/distributed/_tools/sac_estimator.py @@ -125,7 +125,7 @@ class MSPS(NamedTuple): Attributes: func_names (set[str]): Set of operator/operator group names. - op_idx (int): Operator index (group head index incase of operator groups). + op_idx (int): Operator index (group head index in case of operator groups). memory (int): Memory usage in bytes. runtime (float): Runtime in milliseconds. msps (float): Memory per second calculated as memory/runtime. @@ -194,7 +194,7 @@ class SACEstimator(TorchDispatchMode): estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). Attributes: - sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fuly qualified name) to ``SACStats``. + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``. sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. @@ -364,7 +364,7 @@ def _get_inplace_metadata( # 5. Initialize the parent op ids of the inplace op for each of the active modules mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1) for i, d in enumerate(self._sac_metadata): - # 6. Find the first occurence of a tensor corresponding to each module that + # 6. Find the first occurrence of a tensor corresponding to each module that # shares the same storage as the current tensor past_output_ids = d.output_ids if set(output_ids).issubset(set(past_output_ids)): @@ -483,7 +483,7 @@ def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: # a) If the head of this group is an inplace op, then we have to store the entire group. # b) If any op in the group is random and force_store_random is set, then entire group will be stored. # c) If none of ops in the group are random and the head of the group is not an in-place op, then - # this group can be considered for recomputation in its entireity + # this group can be considered for recomputation in its entirety stored_ops: set[int] = set() recomputed_ops: set[int] = set() # Case 1: @@ -533,7 +533,7 @@ def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} msps = (mem / runtime) if runtime > 0 else sys.float_info.max msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) - # We choose canidates to be recomputed based on increasing msps + # We choose candidates to be recomputed based on increasing msps msps_meta.sort(key=lambda x: x.msps, reverse=True) return SACGreedyOrderMeta( recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta @@ -560,7 +560,7 @@ def _get_sac_tradeoff_pwlf_stats( greedy_order_meta.random_ops_group, greedy_order_meta.msps_meta, ) - # 1. Intitialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + # 1. Initialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops recomp_indices: set[int] = set() for r_idx in recomputed_ops: recomp_indices.add(r_idx) @@ -574,7 +574,7 @@ def _get_sac_tradeoff_pwlf_stats( # 2. Initialize the max recomputation time and total recomputation memory sac_runtime = sum(sac_stats.runtimes) sac_memory = sum(sac_stats.memory) - # 3. Tradeoff curve stores the KV pair of the dicarded memory to total memory and, + # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and, # recomputation time to total runtime incurred. delta = 1e-2 tradeoff_curve = OrderedDict() diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index 00b84d6c28eecc..23456f6170534b 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -631,6 +631,7 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. + assert input_tensor_cp is not None state.error_dict[bucket_index] = input_tensor_cp - input_tensor if not state.warm_start: state.p_memory_dict.clear() @@ -843,6 +844,7 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. + assert input_tensor_cp is not None state.error_dict[bucket_index] = input_tensor_cp - input_tensor # Removing this seemingly unnecessary sync somehow may cause failures. # See: https://github.com/pytorch/pytorch/pull/54838 diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index b1cf0aec6140f8..6a52c36942e48e 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,9 +1,15 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, TYPE_CHECKING import torch -def is_available(): +if TYPE_CHECKING: + from types import TracebackType + + +def is_available() -> bool: return hasattr(torch._C, "_dist_autograd_init") @@ -25,6 +31,8 @@ def is_available(): get_gradients, ) +__all__ = ["context", "is_available"] + class context: """ @@ -45,9 +53,14 @@ class context: >>> dist_autograd.backward(context_id, [loss]) """ - def __enter__(self): + def __enter__(self) -> int: self.autograd_context = _new_context() return self.autograd_context._context_id() - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: _release_context(self.autograd_context._context_id()) diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py index b65db2f3442339..56bac60b956629 100644 --- a/torch/distributed/checkpoint/__init__.py +++ b/torch/distributed/checkpoint/__init__.py @@ -1,9 +1,8 @@ from . import _extension -from ._hf_planner import _HuggingFaceLoadPlanner, _HuggingFaceSavePlanner -from ._hf_storage import _HuggingFaceStorageReader, _HuggingFaceStorageWriter from .api import CheckpointException from .default_planner import DefaultLoadPlanner, DefaultSavePlanner from .filesystem import FileSystemReader, FileSystemWriter +from .hf_storage import HuggingFaceStorageReader, HuggingFaceStorageWriter from .metadata import ( BytesStorageMetadata, ChunkStorageMetadata, diff --git a/torch/distributed/checkpoint/_async_executor.py b/torch/distributed/checkpoint/_async_executor.py index 7da04c12b4b8aa..52c73d5aa3186e 100644 --- a/torch/distributed/checkpoint/_async_executor.py +++ b/torch/distributed/checkpoint/_async_executor.py @@ -15,7 +15,7 @@ class _AsyncCheckpointExecutor(abc.ABC): @abc.abstractmethod def execute_save( self, - staged_state_dict: STATE_DICT_TYPE, + staged_state_dict_future: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py index 513d71f427defc..ddbe3f4dcb3ceb 100644 --- a/torch/distributed/checkpoint/_async_process_executor.py +++ b/torch/distributed/checkpoint/_async_process_executor.py @@ -257,7 +257,7 @@ def __init__(self) -> None: def _execute_save_impl( *, pg_init_info: Optional[_ProcessGroupInitInfo], - staged_state_dict: STATE_DICT_TYPE, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, @@ -279,6 +279,11 @@ def create_checkpoint_daemon_process() -> None: create_checkpoint_daemon_process() assert _CHECKPOINT_PROCESS is not None + staged_state_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) return _CHECKPOINT_PROCESS.save( staged_state_dict=staged_state_dict, checkpoint_id=checkpoint_id, @@ -288,7 +293,7 @@ def create_checkpoint_daemon_process() -> None: def execute_save( self, - staged_state_dict: STATE_DICT_TYPE, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, @@ -320,7 +325,7 @@ def execute_save( f: Future = self._executor.submit( self._execute_save_impl, pg_init_info=pg_init_info, - staged_state_dict=staged_state_dict, + staging_future_or_state_dict=staging_future_or_state_dict, checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, diff --git a/torch/distributed/checkpoint/_async_thread_executor.py b/torch/distributed/checkpoint/_async_thread_executor.py index 541ad1d8c8eb3c..1038c177529d29 100644 --- a/torch/distributed/checkpoint/_async_thread_executor.py +++ b/torch/distributed/checkpoint/_async_thread_executor.py @@ -11,24 +11,46 @@ from torch.distributed.checkpoint.storage import StorageWriter +def save_wrapper( + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + staged_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) + return save( + staged_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + ) + + class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): def __init__(self) -> None: self._executor = ThreadPoolExecutor(max_workers=1) def execute_save( self, - staged_state_dict: STATE_DICT_TYPE, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> Future: - from torch.distributed.checkpoint.state_dict_saver import save - f: Future = self._executor.submit( - save, - staged_state_dict, + save_wrapper, + staging_future_or_state_dict=staging_future_or_state_dict, checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, diff --git a/torch/distributed/checkpoint/_checkpointer.py b/torch/distributed/checkpoint/_checkpointer.py index d35c8b59ca3609..d21d8248d20479 100644 --- a/torch/distributed/checkpoint/_checkpointer.py +++ b/torch/distributed/checkpoint/_checkpointer.py @@ -83,12 +83,14 @@ def async_save( Returns: Future: A future holding the resultant Metadata object from `save`. """ - return saver.async_save( + response = saver.async_save( state_dict, storage_writer=self.storage_writer, process_group=self.process_group, planner=self.save_planner, ) + assert isinstance(response, Future) + return response def load(self, state_dict: dict[str, Any]) -> None: """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py new file mode 100644 index 00000000000000..0637073d4d9909 --- /dev/null +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -0,0 +1,713 @@ +# pyre-strict + +import concurrent.futures +import json +import logging +import math +import os +import struct +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import fsspec # type: ignore[import-untyped] +from fsspec.core import url_to_fs # type: ignore[import-untyped] + +import torch +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dcp_custom_metadata, + _get_dtype, + _get_safetensors_file_metadata, + _metadata_fn, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, + SUFFIX, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class _FqnData: + """ + Dataclass to store information about a tensor (identified by its fully qualified name). + + Attributes: + offset_in_file: Byte offset where this tensor's data begins in the output file + shape_in_file: Shape of the tensor in the output file + dtype_size: Size of the tensor's data type in bytes + dtype_str: String representation of the tensor's data type + """ + + offset_in_file: int = 0 + shape_in_file: list[int] = field(default_factory=list) + dtype_size: int = 0 + dtype_str: str = "" + + +@dataclass +class _OutputFileData: + """ + Dataclass to store information about an output safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + fqn_data: Dictionary mapping tensor names to their metadata + """ + + metadata_size: int = 0 + fqn_data: dict[str, _FqnData] = field(default_factory=dict) + + +@dataclass +class _InputFileData: + """ + Dataclass to store information about an input safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + metadata: Json metadata from the safetensors file + """ + + metadata_size: int = 0 + metadata: Any = None + + +def _parse_input_metadata( + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Parse metadata from input safetensors files to determine the full tensor shapes and types. + + This function analyzes the metadata from all input files to determine the complete shape + of each tensor after consolidation. It updates the output_files_data with this information. + + Args: + input_files_data: dict of metadata from input safetensors files + output_files_data: Dictionary mapping output file paths to their metadata + + Raises: + ValueError: If no DCP custom metadata is found in a safetensors file + """ + # Dictionary to track the full size of each tensor across all shards + fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} + + for file_data in input_files_data.values(): + safetensors_metadata = file_data.metadata + dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata) + if not dcp_sharding_info: + raise ValueError( + "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated." + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + + # Get the shape of this tensor shard and its offset in the full tensor + sizes = val[SHAPE_KEY] + offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + + if key not in fqn_to_size_mapping: + # First time seeing this tensor - calculate its full size by adding offsets to dimensions + cur_size = [size + offset for size, offset in zip(sizes, offsets)] + fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY]) + else: + # We've seen this tensor before - update its size if this shard extends beyond current known dimensions + cur_size = fqn_to_size_mapping[key][0] + for i in range(len(sizes)): + cur_size[i] = max(cur_size[i], sizes[i] + offsets[i]) + + # Now that we know the full size of each tensor, populate the output file data + for fqn, tensor_info in fqn_to_size_mapping.items(): + tensor_size = tensor_info[0] + dtype_str = tensor_info[1] + for output_data in output_files_data.values(): + # Add this tensor to the output file if it's already assigned there or if we're using a single output file + if fqn in output_data.fqn_data or len(output_files_data) == 1: + output_data.fqn_data[fqn] = _FqnData( + shape_in_file=tensor_size, + dtype_size=torch.finfo(_get_dtype(dtype_str)).bits + // 8, # Convert bits to bytes + dtype_str=dtype_str, + ) + + +def _write_metadata( + fs: fsspec.AbstractFileSystem, + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Write metadata to the beginning of each output safetensors file. + + This function writes the metadata section to each output file, including information + about tensor shapes, data types, and offsets. It also updates the offset_in_file + field for each tensor in the output_files_data. + + Args: + fs: Filesystem interface for file operations + output_files_data: Dictionary mapping output file paths to their metadata + """ + # Process each output file + for file_path, output_data in output_files_data.items(): + with fs.open(file_path, "wb") as f: + metadata = {} + curr_offset = 0 + + # Calculate offsets for each tensor in the file + for fqn, fqn_data in output_data.fqn_data.items(): + # Calculate the end offset by multiplying all dimensions and the data type size + end_offset = ( + curr_offset + + math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + ) + + # Store metadata for this tensor + metadata[fqn] = { + SHAPE_KEY: fqn_data.shape_in_file, + DTYPE_KEY: fqn_data.dtype_str, + DATA_OFFSETS_KEY: [ + curr_offset, + end_offset, + ], # Start and end byte offsets + } + # Store the offset for later use when writing the actual tensor data + fqn_data.offset_in_file = curr_offset + + # Update current offset for the next tensor + curr_offset = end_offset + + # Convert metadata to JSON and encode as bytes + json_metadata = json.dumps(metadata) + json_bytes = json_metadata.encode("utf-8") + + # Write the metadata size as an 8-byte unsigned integer (little-endian) + size_in_bytes = len(json_bytes) + header_len = struct.pack(" None: + """ + Process a single output file by writing tensor data from input files. + + This function is designed to be run in parallel for different output files. + + Args: + input_fs: Filesystem interface for input file operations + output_fs: Filesystem interface for output file operations + output_file: Path to the output file + output_data: Metadata for the output file + input_safetensors_files: List of input safetensors file paths + input_metadatas: Dictionary mapping input file paths to their metadata + """ + # Process each input safetensors file + for safetensors_file in input_files_data.keys(): + with input_fs.open(safetensors_file, "rb") as f: + file_metadata = input_files_data[safetensors_file].metadata + input_metadata_size = input_files_data[safetensors_file].metadata_size + for fqn, metadata in file_metadata.items(): + if fqn == DEFAULT_EXTRA_METADATA_KEY: + continue + + # Skip if this tensor doesn't belong in this output file + if fqn not in output_data.fqn_data: + continue + + data_offsets = metadata[DATA_OFFSETS_KEY] + # Get the tensor data as bytes + f.seek(input_metadata_size + data_offsets[0]) + data_to_write = f.read(data_offsets[1]) + + # Get the offsets of this tensor shard within the full tensor + offsets_of_tensor_being_read = _get_dcp_custom_metadata(file_metadata)[ + fqn + ][SAVED_OFFSETS_KEY] # type: ignore[index] + + # Get metadata for this tensor in the output file + fqn_data = output_data.fqn_data[fqn] + + # Write this tensor shard to the appropriate position in the output file + _write_sub_tensor_to_file( + output_fs, + data_to_write, + fqn_data.dtype_size, # Size of each element in bytes + fqn_data.shape_in_file, # Full tensor shape + offsets_of_tensor_being_read, # Where this shard belongs in the full tensor + metadata[SHAPE_KEY], # Shape of this shard + output_file, + # Calculate the exact byte position where this tensor data should start + output_data.metadata_size + fqn_data.offset_in_file, + ) + + +def _write_data( + input_fs: fsspec.AbstractFileSystem, + output_fs: fsspec.AbstractFileSystem, + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], + num_threads: int = 1, +) -> None: + """ + Write tensor data from input files to the output files. + + This function reads tensor data from each input file and writes it to the appropriate + position in the output files based on the tensor's offsets. When num_threads > 1, + the work is split across threads with each thread handling a different output file. + + Args: + input_fs: Filesystem interface for input file operations + output_fs: Filesystem interface for output file operations + input_files_data: Dictionary mapping input file paths to their metadata + output_files_data: Dictionary mapping output file paths to their metadata + num_threads: Number of threads to use for parallel processing + """ + if num_threads <= 1 or len(output_files_data) <= 1: + # Sequential processing + for output_file, output_data in output_files_data.items(): + _process_output_file( + input_fs, output_fs, output_file, output_data, input_files_data + ) + else: + # Parallel processing with ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(num_threads, len(output_files_data)) + ) as executor: + futures = [] + for output_file, output_data in output_files_data.items(): + futures.append( + executor.submit( + _process_output_file, + input_fs, + output_fs, + output_file, + output_data, + input_files_data, + ) + ) + + # Wait for all futures to complete + for future in concurrent.futures.as_completed(futures): + # Handle any exceptions that might have occurred + try: + future.result() + except Exception as e: + print(f"Error processing output file: {e}") + raise + + +def _write_row_wise_tensor( + fs: fsspec.AbstractFileSystem, + sub_tensor_bytes: bytearray, + element_size: int, + full_tensor_strides: list[int], + sub_tensor_strides: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], + output_file_path: str, + output_start_byte: int, +) -> None: + """ + Writes a row-wise sharded tensor to the output file. + + This is an optimized path for tensors that are sharded along the first dimension, + with all other dimensions being complete. This allows writing entire rows at once. + + Args: + fs: Filesystem interface for file operations + sub_tensor_bytes: Byte array containing the sub-tensor data + element_size: The size of each element in bytes + full_tensor_strides: Strides of the full tensor + sub_tensor_strides: Strides of the sub-tensor + sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor + sub_tensor_shape: The shape of the sub-tensor + output_file_path: The path to the file where the full tensor is stored + output_start_byte: The starting byte of the full tensor in the file + """ + # Open the output file in read+binary mode to allow seeking and writing + with fs.open(output_file_path, "r+b") as out_f: + # Calculate the number of elements in each row + elements_per_row = full_tensor_strides[ + 0 + ] # This is the stride of the first dimension + + # For each row in the sub-tensor + for row_idx in range(sub_tensor_shape[0]): + # Calculate the row index in the full tensor + full_row_idx = sub_tensor_offsets[0] + row_idx + + # Calculate the position in the full tensor + full_pos = full_row_idx * full_tensor_strides[0] + full_byte_offset = output_start_byte + full_pos * element_size + + # Calculate the position in the sub-tensor + sub_pos = row_idx * sub_tensor_strides[0] + sub_byte_offset = sub_pos * element_size + + # Extract the row data from the sub-tensor + row_size = elements_per_row * element_size + row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size] + + # Seek to the correct position in the output file and write the data + out_f.seek(full_byte_offset) + out_f.write(row_data) + + +def _write_column_wise_tensor( + fs: fsspec.AbstractFileSystem, + sub_tensor_bytes: bytearray, + element_size: int, + tensor_shape: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], + output_file_path: str, + output_start_byte: int, +) -> None: + """ + Writes a column-wise sharded 2D tensor to the output file. + + This is an optimized path for 2D tensors that are sharded along the second dimension, + with the first dimension being complete. This requires writing column by column. + + Args: + fs: Filesystem interface for file operations + sub_tensor_bytes: Byte array containing the sub-tensor data + element_size: The size of each element in bytes + tensor_shape: The shape of the overall tensor + sub_tensor_strides: Strides of the sub-tensor + sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor + sub_tensor_shape: The shape of the sub-tensor + output_file_path: The path to the file where the full tensor is stored + output_start_byte: The starting byte of the full tensor in the file + """ + # Open the output file in read+binary mode to allow seeking and writing + with fs.open(output_file_path, "r+b") as out_f: + # For each column in the sub-tensor + for col_idx in range(sub_tensor_shape[1]): + # Calculate the column index in the full tensor + full_col_idx = sub_tensor_offsets[1] + col_idx + + # For each row in the column + for row_idx in range(sub_tensor_shape[0]): + # Calculate the position in the full tensor + full_pos = row_idx * tensor_shape[1] + full_col_idx + full_byte_offset = output_start_byte + full_pos * element_size + + # Calculate the position in the sub-tensor + sub_pos = row_idx * sub_tensor_shape[1] + col_idx + sub_byte_offset = sub_pos * element_size + + # Extract the element data from the sub-tensor + element_data = sub_tensor_bytes[ + sub_byte_offset : sub_byte_offset + element_size + ] + + # Seek to the correct position in the output file and write the data + out_f.seek(full_byte_offset) + out_f.write(element_data) + + +def _write_element_by_element( + fs: fsspec.AbstractFileSystem, + sub_tensor_bytes: bytearray, + element_size: int, + tensor_shape: list[int], + full_tensor_strides: list[int], + sub_tensor_strides: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], + output_file_path: str, + output_start_byte: int, +) -> None: + """ + Writes a sub-tensor to the output file using a general element-by-element approach. + + This is a general approach that works for any sharding pattern, but is less efficient + than the specialized approaches for row-wise or column-wise sharding. + + Args: + fs: Filesystem interface for file operations + sub_tensor_bytes: Byte array containing the sub-tensor data + element_size: The size of each element in bytes + tensor_shape: The shape of the overall tensor + full_tensor_strides: Strides of the full tensor + sub_tensor_strides: Strides of the sub-tensor + sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor + sub_tensor_shape: The shape of the sub-tensor + output_file_path: The path to the file where the full tensor is stored + output_start_byte: The starting byte of the full tensor in the file + """ + # Open the output file in read+binary mode to allow seeking and writing + with fs.open(output_file_path, "r+b") as out_f: + # Create a list to hold the current indices for each dimension + indices = [0] * len(tensor_shape) + + # Calculate the total number of elements in the sub-tensor + total_elements = 1 + for dim_size in sub_tensor_shape: + total_elements *= dim_size + + # Process each element in the sub-tensor + for element_idx in range(total_elements): + # Calculate the indices for this element in the sub-tensor + sub_idx = element_idx + for dim in range(len(sub_tensor_shape) - 1, -1, -1): + indices[dim] = sub_idx % sub_tensor_shape[dim] + sub_idx //= sub_tensor_shape[dim] + + # Calculate the position of this element in the sub-tensor's byte array + sub_pos = 0 + for dim in range(len(sub_tensor_shape)): + sub_pos += indices[dim] * sub_tensor_strides[dim] + sub_byte_offset = sub_pos * element_size + + # Calculate the position of this element in the full tensor + full_pos = 0 + for dim in range(len(tensor_shape)): + # The global index is the local index plus the offset for this dimension + global_idx = indices[dim] + sub_tensor_offsets[dim] + full_pos += global_idx * full_tensor_strides[dim] + full_byte_offset = output_start_byte + full_pos * element_size + + # Extract the element data from the sub-tensor + element_data = sub_tensor_bytes[ + sub_byte_offset : sub_byte_offset + element_size + ] + + # Seek to the correct position in the output file and write the data + out_f.seek(full_byte_offset) + out_f.write(element_data) + + +def _write_sub_tensor_to_file( + fs: fsspec.AbstractFileSystem, + sub_tensor_bytes: bytearray, + element_size: int, + tensor_shape: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], + output_file_path: str, + output_start_byte: int, +) -> None: + """ + Writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets. + + This function handles the complex task of placing a tensor shard (sub-tensor) at the correct + position within the consolidated tensor file. It works by calculating the exact byte offsets + for each slice of data and writing them to the appropriate positions. This implementation + supports tensors of any dimensionality with optimized paths for common sharding patterns: + - Row-wise sharding (optimized path) + - Column-wise sharding for 2D tensors (optimized path) + - Any other arbitrary sharding pattern (general element-by-element approach) + + Args: + sub_tensor_bytes: Byte array containing the sub-tensor data + element_size: The size of each element in bytes + tensor_shape: The shape of the overall tensor (list) + sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list) + sub_tensor_shape: The shape of the sub-tensor (list) + output_file_path: The path to the file where the full tensor is stored + output_start_byte: The starting byte of the full tensor in the file + """ + # Handle the case of empty tensors + if not tensor_shape or not sub_tensor_shape: + return + + # Calculate strides for the full tensor (row-major order, C-style) + # Stride is the number of elements to skip to move to the next element in that dimension + full_tensor_strides = [1] * len(tensor_shape) + for i in range(len(tensor_shape) - 2, -1, -1): + full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1] + + # Calculate strides for the sub-tensor (row-major order, C-style) + sub_tensor_strides = [1] * len(sub_tensor_shape) + for i in range(len(sub_tensor_shape) - 2, -1, -1): + sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1] + + # Check if this is a row-wise sharded tensor + # Row-wise sharding is detected when the last dimension is complete + # and only the first dimension is partial + is_row_wise = False + if len(tensor_shape) >= 2: + # Check if all dimensions except the first are complete + all_other_dims_complete = True + for i in range(1, len(tensor_shape)): + if sub_tensor_shape[i] != tensor_shape[i]: + all_other_dims_complete = False + break + + # Row-wise sharding: first dimension is partial, all others are complete + is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0] + + # Check if this is a column-wise sharded 2D tensor + # Column-wise sharding is detected when the first dimension is complete + # and the second dimension is partial (only for 2D tensors) + is_column_wise = False + if len(tensor_shape) == 2: + is_column_wise = ( + sub_tensor_shape[0] == tensor_shape[0] + and sub_tensor_shape[1] < tensor_shape[1] + ) + + # Call the appropriate function based on the sharding pattern + if is_row_wise: + _write_row_wise_tensor( + fs, + sub_tensor_bytes, + element_size, + full_tensor_strides, + sub_tensor_strides, + sub_tensor_offsets, + sub_tensor_shape, + output_file_path, + output_start_byte, + ) + elif is_column_wise: + _write_column_wise_tensor( + fs, + sub_tensor_bytes, + element_size, + tensor_shape, + sub_tensor_offsets, + sub_tensor_shape, + output_file_path, + output_start_byte, + ) + else: + _write_element_by_element( + fs, + sub_tensor_bytes, + element_size, + tensor_shape, + full_tensor_strides, + sub_tensor_strides, + sub_tensor_offsets, + sub_tensor_shape, + output_file_path, + output_start_byte, + ) + + +def _write_overall_metadata_file( + fs: fsspec.AbstractFileSystem, + output_dir: str, + output_files_data: dict[str, _OutputFileData], +) -> None: + total_size = 0 + weight_map = {} + for output_path, value in output_files_data.items(): + for fqn, fqn_data in value.fqn_data.items(): + total_size += math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + weight_map[fqn] = os.path.basename(output_path) + + metadata_to_write: dict[str, Any] = {} + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = weight_map + + metadata_path = os.path.join(output_dir, f"{_metadata_fn}") + with fs.open(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + +def consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: Optional[dict[str, int]] = None, + num_threads: int = 1, +) -> None: + """ + Main function to consolidate sharded safetensors files into one or more output files. + + This function orchestrates the entire consolidation process: + 1. Sets up the output file structure based on the fqn_to_index_mapping + 2. Finds all safetensors files in the input directory + 3. Parses metadata from all input files + 4. Writes metadata to the output files + 5. Writes tensor data from input files to output files + 6. Writes overall model.index.safetensors.json file with weight map + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Optional mapping of tensor names to output file indices. + If None, all tensors will be consolidated into a single file. + num_threads: Number of threads to use for parallel processing of saving data to output files. + """ + start_time = time.time() + logger.info( + "Consolidating safetensors files from %s to %s. Beginning at time %f", + input_dir, + output_dir, + start_time, + ) + # Create filesystem using fsspec for file operations + input_fs, _ = url_to_fs(input_dir) + output_fs, _ = url_to_fs(output_dir) + + # Initialize the output file structure + output_files_data: dict[str, _OutputFileData] = {} + if fqn_to_index_mapping is not None: + # Create multiple output files based on the provided mapping + for fqn, index in fqn_to_index_mapping.items(): + # Generate names like "model-00001-of-00005.safetensors" + file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) + output_path = f"{output_dir}/{file_name}" + + if output_path not in output_files_data: + output_files_data[output_path] = _OutputFileData( + fqn_data={fqn: _FqnData()} + ) + else: + output_files_data[output_path].fqn_data[fqn] = _FqnData() + else: + # If no mapping is provided, create a single output file + file_name = _gen_file_name(1, 1) + output_path = f"{output_dir}/{file_name}" + output_files_data[output_path] = _OutputFileData() + + # Find all safetensors files in the input directory + safetensors_files = [] + for file in input_fs.ls(input_dir, detail=False): + if file.endswith(SUFFIX): + safetensors_files.append(file) + + # Read metadata from all input files + input_files_data: dict[str, _InputFileData] = {} + for safetensor_file in safetensors_files: + with input_fs.open(safetensor_file, "rb") as f: + metadata, size = _get_safetensors_file_metadata(f) + input_files_data[safetensor_file] = _InputFileData( + metadata_size=size, metadata=metadata + ) + + # Step 1: Parse metadata to determine tensor shapes and types + _parse_input_metadata(input_files_data, output_files_data) + + # Step 2: Write metadata headers to output files + _write_metadata(output_fs, output_files_data) + + # Step 3: Write actual tensor data from input files to output files + _write_data(input_fs, output_fs, input_files_data, output_files_data, num_threads) + + # Step 4: Write overall model.index.safetensors.json file with weight map + _write_overall_metadata_file(output_fs, output_dir, output_files_data) + + logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py index eb64dd3b063854..3e2cf954c409d3 100644 --- a/torch/distributed/checkpoint/_dedup_save_plans.py +++ b/torch/distributed/checkpoint/_dedup_save_plans.py @@ -62,25 +62,3 @@ def dedup_save_plans( ) for plan, item_indexes in zip(all_plans, plan_to_item_indices) ] - - -def dedup_save_plans_with_fqn_to_index_mapping( - all_plans: list[SavePlan], fqn_to_index_mapping: dict[str, int] -) -> list[SavePlan]: - num_plans = len(all_plans) - - to_remove: list[set] = [set() for _ in range(len(all_plans))] - for plan_idx, plan in enumerate(all_plans): - for item_idx, item in enumerate(plan.items): - if (fqn_to_index_mapping[item.index.fqn] - 1) % num_plans != plan_idx: - to_remove[plan_idx].add(item_idx) - - for plan_idx, remove_set in enumerate(to_remove): - new_items = [ - write_item - for item_idx, write_item in enumerate(all_plans[plan_idx].items) - if item_idx not in remove_set - ] - all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) - - return all_plans diff --git a/torch/distributed/checkpoint/_experimental/__init__.py b/torch/distributed/checkpoint/_experimental/__init__.py new file mode 100644 index 00000000000000..8361362eb3a5ed --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/__init__.py @@ -0,0 +1,53 @@ +""" +Checkpoint functionality for machine learning models. + +This module provides classes for saving and loading model checkpoints in a distributed +training environment. It includes functionality for coordinating checkpoint operations +across multiple processes and customizing the checkpoint process through hooks. + +Key components: +- Checkpointer: Main class for orchestrating checkpoint operations (save, load) +- CheckpointWriter: Handles writing state dictionaries to storage +- CheckpointReader: Handles reading state dictionaries from storage read +- Barrier: Synchronization mechanism for distributed checkpointing +- RankInfo: Information about the current rank in a distributed environment +""" + +from .barriers import ( + Barrier, + BarrierConfig, + create_barrier_from_config, + TCPStoreBarrier, +) +from .builder import make_async_checkpointer, make_sync_checkpointer +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook +from .checkpointer import AsyncCheckpointer, Checkpointer, SyncCheckpointer +from .config import CheckpointerConfig +from .staging import CheckpointStager, CheckpointStagerConfig, DefaultStager +from .types import RankInfo, STATE_DICT +from .utils import wrap_future + + +__all__ = [ + "Barrier", + "TCPStoreBarrier", + "CheckpointReader", + "CheckpointWriter", + "CheckpointWriterConfig", + "WriterHook", + "Checkpointer", + "SyncCheckpointer", + "AsyncCheckpointer", + "CheckpointerConfig", + "BarrierConfig", + "create_barrier_from_config", + "CheckpointStager", + "CheckpointStagerConfig", + "DefaultStager", + "RankInfo", + "STATE_DICT", + "wrap_future", + "make_sync_checkpointer", + "make_async_checkpointer", +] diff --git a/torch/distributed/checkpoint/_experimental/barriers.py b/torch/distributed/checkpoint/_experimental/barriers.py new file mode 100644 index 00000000000000..18de93c81d131f --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/barriers.py @@ -0,0 +1,268 @@ +""" +Barrier implementations for synchronizing distributed checkpoint operations. + +This module provides abstract and concrete barrier implementations that ensure +all ranks in a distributed training environment complete their checkpoint operations +before proceeding, which is essential for data consistency. +""" + +import abc +import logging +from collections import Counter +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, Optional + +import torch.distributed as dist +import torch.distributed.elastic.utils.store as store_util + + +logger = logging.getLogger() + + +# Registry of barrier types +BARRIER_REGISTRY: dict[str, type] = {} + + +def register_barrier(barrier_class: type) -> type: + """Register a barrier class in the global registry.""" + if hasattr(barrier_class, "barrier_type"): + BARRIER_REGISTRY[barrier_class.barrier_type] = barrier_class + return barrier_class + + +@dataclass +class BarrierConfig: + """ + Configuration for barrier construction. + + This class provides a flexible way to configure different barrier implementations + with their specific constructor arguments. The barrier type will be looked up + from a registry and instantiated with rank_info and barrier_args. + + Attributes: + barrier_type: A string identifying the barrier type (e.g., "tcp_store"). + If None, no barrier will be used. + barrier_args: Dictionary of arguments to pass to the barrier constructor. + rank_info will be automatically injected as the first argument. + + Examples: + # No barrier + BarrierConfig() + + # TCPStore barrier + BarrierConfig( + barrier_type="tcp_store", + barrier_args={ + 'timeout_barrier_init_secs': 30, + 'barrier_prefix_list': ['checkpoint'], + 'use_checkpoint_barrier_tcpstore_libuv': False, + 'tcpstore_port': 12345, + 'master_address': 'localhost' + } + ) + """ + + barrier_type: Optional[str] = None + barrier_args: dict[str, Any] = field(default_factory=dict) + + +def create_barrier_from_config( + barrier_config: BarrierConfig, +) -> Optional["Barrier"]: + """ + Create a barrier instance from BarrierConfig. + + Args: + barrier_config: Configuration for barrier construction. + + Returns: + Barrier instance or None if no barrier type is configured. + + Raises: + ValueError: If the barrier_type is not found in the registry. + """ + if barrier_config.barrier_type is None: + return None + + if barrier_config.barrier_type not in BARRIER_REGISTRY: + raise ValueError( + f"Unknown barrier type: {barrier_config.barrier_type}. " + f"Available types: {list(BARRIER_REGISTRY.keys())}" + ) + + barrier_class = BARRIER_REGISTRY[barrier_config.barrier_type] + return barrier_class(**barrier_config.barrier_args) + + +class Barrier(abc.ABC): + """ + Abstract base class for synchronization barriers. + + A barrier ensures that all ranks in a distributed environment reach a certain + point in execution before any rank proceeds further, which is essential for + coordinating operations like checkpointing across multiple processes. + """ + + @abc.abstractmethod + def __init__(self, **kwargs: dict[str, Any]): + """ + Initialize a barrier. + + Args: + **kwargs: Keyword arguments for specific barrier implementations. + Common arguments may include rank information, barrier prefixes, + timeout settings, and other barrier-specific configuration. + """ + # No implementation needed in the abstract base class + + @abc.abstractmethod + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier. + + This method uses the barrier_prefix provided during initialization to + coordinate synchronization across processes. + """ + + +@register_barrier +class DistBarrier(Barrier): + """ + A barrier implementation using PyTorch's distributed barrier for synchronization. + + This barrier uses the built-in torch.distributed.barrier() function to coordinate + synchronization across multiple processes. It's simpler than TCPStoreBarrier but + requires an initialized process group. + """ + + barrier_type = "dist_barrier" + + def __init__( + self, + ) -> None: + """ + Initialize a DistBarrier. + + This barrier requires an initialized PyTorch distributed process group. + No additional arguments are needed as it uses the current process group. + + Raises: + AssertionError: If the distributed process group is not initialized. + """ + assert dist.is_initialized(), ( + "DistBarrier requires an initialized process group." + ) + + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier using the prefix provided during initialization. + """ + # Note: dist.barrier() doesn't support explicit timeouts + # The timeout is handled by the underlying implementation + dist.barrier() + + +@register_barrier +class TCPStoreBarrier(Barrier): + """ + A barrier implementation using PyTorch's TCPStore for synchronization. + + This barrier uses a TCP-based distributed key-value store to coordinate + synchronization across multiple processes. It uses a single TCP store + for all barrier operations, with different prefixes to distinguish between + different barrier types. + """ + + barrier_type = "tcp_store" + + def __init__( + self, + global_rank: int, + global_world_size: int, + barrier_prefix: str, + timeout_barrier_init_secs: int, + use_checkpoint_barrier_tcpstore_libuv: bool, + tcpstore_port: int, + master_address: str, + timeout_secs: int, + ): + """ + Initialize a TCPStoreBarrier. + + Args: + global_rank: The rank of the current process in the distributed environment. + global_world_size: The total number of processes in the distributed environment. + barrier_prefix: A string prefix to identify this specific barrier. + timeout_barrier_init_secs: Timeout in seconds for initializing the TCPStore. + use_checkpoint_barrier_tcpstore_libuv: Whether to use libuv for the TCPStore. + tcpstore_port: Port number for the TCPStore. + master_address: Address of the master node for the TCPStore. + timeout_secs: Maximum time in seconds to wait for all ranks to reach the barrier. + """ + logger.info( + "Initializing TCPStore master_address=%s tcpstore_port=%s rank=%s " + "world_size=%s barrier_prefix=%s timeout_barrier_init_secs=%s " + "use_checkpoint_barrier_tcpstore_libuv=%s timeout_secs=%s", + master_address, + tcpstore_port, + global_rank, + global_world_size, + barrier_prefix, + timeout_barrier_init_secs, + use_checkpoint_barrier_tcpstore_libuv, + timeout_secs, + ) + + # Counter collection to track barrier seq on a per barrier prefix basis. + self._tcp_store_barrier_seq: Counter = Counter() + self._barrier_prefix = barrier_prefix + + # Store rank and world size for barrier operations + self._global_rank = global_rank + self._global_world_size = global_world_size + self._timeout_secs = timeout_secs + + # Create a single TCP store for all barrier operations + self._tcp_store = dist.TCPStore( + master_address, + int(tcpstore_port), + world_size=self._global_world_size, + timeout=timedelta(seconds=timeout_barrier_init_secs), + is_master=(self._global_rank == 0), + ) + + def execute_barrier(self) -> None: + """ + Execute a synchronization barrier using the prefix provided during initialization. + + The implementation uses a sequence number that is incremented every time + a barrier is reached. The sequence number is per barrier prefix to allow + different barriers to operate concurrently. + """ + barrier_prefix = self._barrier_prefix + + logger.info( + "Executing barrier barrier_prefix=%s timeout_secs=%s", + barrier_prefix, + self._timeout_secs, + ) + + def _rank_key(rank: int) -> str: + return f"rank{rank}" + + # Track which barrier sequence this rank is joining. + self._tcp_store.set( + _rank_key(self._global_rank), + str(self._tcp_store_barrier_seq[barrier_prefix]), + ) + + # Execute barrier for that sequence number (for the specific prefix). + store_util.barrier( + store=self._tcp_store, + world_size=self._global_world_size, + key_prefix=( + barrier_prefix + str(self._tcp_store_barrier_seq[barrier_prefix]) + ), + ) + self._tcp_store_barrier_seq[barrier_prefix] += 1 diff --git a/torch/distributed/checkpoint/_experimental/builder.py b/torch/distributed/checkpoint/_experimental/builder.py new file mode 100644 index 00000000000000..f705072790a1d6 --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/builder.py @@ -0,0 +1,173 @@ +""" +Factory functions for creating checkpointer instances with sensible defaults. + +This module provides high-level factory functions that simplify the creation +of checkpointer instances by automatically handling component initialization +and configuration with reasonable defaults. +""" + +from typing import Any, Callable, Optional + +import torch.distributed as dist + +from .barriers import create_barrier_from_config +from .checkpoint_process import CheckpointProcess +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter, CheckpointWriterConfig, WriterHook +from .checkpointer import AsyncCheckpointer, SyncCheckpointer +from .config import CheckpointerConfig +from .staging import DefaultStager +from .types import RankInfo + + +def _get_default_rank_info() -> RankInfo: + """ + Get default rank information from the current distributed environment. + + Returns: + RankInfo: Rank information from the default process group if initialized, + otherwise single-rank fallback. + """ + if dist.is_initialized(): + return RankInfo( + global_world_size=dist.get_world_size(), + global_rank=dist.get_rank(), + ) + else: + # Single-rank fallback + return RankInfo(global_world_size=1, global_rank=0) + + +def default_subprocess_init_fn(*_: Any) -> None: + """Default subprocess initialization function (no-op).""" + + +def default_writer_init_fn(rank_info: RankInfo) -> CheckpointWriter: + """Default checkpoint writer initialization function.""" + return CheckpointWriter( + config=CheckpointWriterConfig(), + rank_info=rank_info, + ) + + +def make_sync_checkpointer( + config: CheckpointerConfig = CheckpointerConfig(), + rank_info: Optional[RankInfo] = None, + commit_hook: Optional[WriterHook] = None, +) -> SyncCheckpointer: + """ + Factory function to create a SyncCheckpointer instance with sensible defaults. + + This function creates a synchronous checkpointer with default components, automatically + detecting rank information from the default process group if available, and using the + provided component configurations. + + Args: + config: CheckpointerConfig containing component-specific configurations + (writer_config, staging_config, process_config). Defaults to CheckpointerConfig(). + rank_info: RankInfo for distributed training. Defaults to auto-detection from + the default PyTorch distributed process group if initialized, otherwise + falls back to single-rank (world_size=1, rank=0). + commit_hook: Optional hook for custom actions before and after checkpoint commits. + + Returns: + SyncCheckpointer: A configured synchronous checkpointer instance. + + Examples: + # Simplest usage - auto-detect rank, default config + checkpointer = make_sync_checkpointer() + + # Explicit rank configuration + checkpointer = make_sync_checkpointer( + rank_info=RankInfo(global_world_size=4, global_rank=0) + ) + + # Disable barrier + from .barriers import BarrierConfig + config = CheckpointerConfig(barrier_config=BarrierConfig(barrier_type=None)) + checkpointer = make_sync_checkpointer(config=config) + """ + if rank_info is None: + rank_info = _get_default_rank_info() + + reader = CheckpointReader( + rank_info=rank_info, + ) + + barrier = create_barrier_from_config(config.barrier_config) + + writer = CheckpointWriter( + config=config.writer_config, + rank_info=rank_info, + barrier=barrier, + commit_hook=commit_hook, + ) + + return SyncCheckpointer( + writer=writer, + reader=reader, + ) + + +def make_async_checkpointer( + config: CheckpointerConfig = CheckpointerConfig(), + rank_info: Optional[RankInfo] = None, + subprocess_init_fn: Callable[..., None] = default_subprocess_init_fn, + subprocess_init_args: tuple[Any, ...] = (), + checkpoint_writer_init_fn: Callable[..., CheckpointWriter] = default_writer_init_fn, + checkpoint_writer_init_args: Optional[dict[str, Any]] = None, +) -> AsyncCheckpointer: + """ + Factory function to create an AsyncCheckpointer instance with sensible defaults. + + This function creates an asynchronous checkpointer using the provided configuration, + automatically detecting rank information if not provided. + + Args: + config: CheckpointerConfig containing component-specific configurations. + rank_info: RankInfo for distributed training. Defaults to auto-detection. + subprocess_init_fn: Function to initialize the subprocess. Defaults to no-op. + subprocess_init_args: Arguments to pass to subprocess_init_fn. + checkpoint_writer_init_fn: Function to create CheckpointWriter instance. + checkpoint_writer_init_args: Arguments to pass to checkpoint_writer_init_fn. + + Returns: + AsyncCheckpointer: A configured asynchronous checkpointer instance. + + Examples: + # Create with default config + checkpointer = make_async_checkpointer() + + # Create with custom init functions + checkpointer = make_async_checkpointer( + subprocess_init_fn=my_subprocess_init_fn, + checkpoint_writer_init_fn=my_writer_init_fn + ) + """ + if rank_info is None: + rank_info = _get_default_rank_info() + + reader = CheckpointReader( + rank_info=rank_info, + ) + + checkpoint_stager = DefaultStager( + config=config.staging_config, + ) + + checkpoint_writer_init_args = checkpoint_writer_init_args or {} + + checkpoint_process = CheckpointProcess( + rank_info=rank_info, + config=config.process_config, + subprocess_init_fn=subprocess_init_fn, + subprocess_init_args=subprocess_init_args, + checkpoint_writer_init_fn=checkpoint_writer_init_fn, + checkpoint_writer_init_args=checkpoint_writer_init_args, + ) + + return AsyncCheckpointer( + checkpoint_stager=checkpoint_stager, + checkpoint_process=checkpoint_process, + reader=reader, + ) diff --git a/torch/distributed/checkpoint/_experimental/checkpoint_process.py b/torch/distributed/checkpoint/_experimental/checkpoint_process.py new file mode 100644 index 00000000000000..8917245236e36e --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/checkpoint_process.py @@ -0,0 +1,361 @@ +import logging +import os +import traceback +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from multiprocessing.connection import Connection +from typing import Any, Callable, Optional, Union + +import torch.multiprocessing as mp +from torch.multiprocessing.spawn import ProcessExitedException + +from .checkpoint_writer import CheckpointWriter +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +@dataclass +class CheckpointProcessConfig: + """ + Configuration options for the CheckpointProcess. + + This class provides configuration options for the checkpoint process, + including initialization functions, timeouts, and writer configuration. + + Attributes: + subprocess_init_timeout_secs: Maximum time in seconds to wait for subprocess initialization. + subprocess_shutdown_timeout_secs: Maximum time in seconds to wait for subprocess shutdown. + """ + + subprocess_init_timeout_secs: int = 30 + subprocess_shutdown_timeout_secs: int = 60 + + +class RequestType(Enum): + PING = "ping" + WRITE_CHECKPOINT = "write_checkpoint" + TERMINATE_PROCESS = "exit" + + +@dataclass +class WorkerRequest: + """ + A dataclass for storing the command to be sent to the worker process. + Note: This relies on pickling to send the command to the worker process. Handle + backward compatibility accordingly. + """ + + request_type: RequestType + payload: dict[str, Any] + + +@dataclass +class WorkerResponse: + request_type: RequestType + success: bool + error_msg: Optional[str] = None + payload: Optional[dict[str, Any]] = None + + +class CheckpointProcess: + """ + A checkpoint writer that writes checkpoints to a remote process. + """ + + def __init__( + self, + rank_info: RankInfo, + config: CheckpointProcessConfig, + subprocess_init_fn: Callable[[Any], None], + subprocess_init_args: tuple[Any, ...], + checkpoint_writer_init_fn: Callable[..., CheckpointWriter], + checkpoint_writer_init_args: dict[str, Any], + ): + self._executor = ThreadPoolExecutor(max_workers=1) + self._rank_info = rank_info + self._config = config + self._subprocess_init_fn = subprocess_init_fn + self._subprocess_init_args = subprocess_init_args + self._checkpoint_writer_init_fn = checkpoint_writer_init_fn + self._checkpoint_writer_init_args = checkpoint_writer_init_args + self.process = None + self._parent_end: Optional[Connection] = None + self._child_end: Optional[Connection] = None + + self.process_creation_future = self._executor.submit( + self._create_subprocess, + config, + ) + + def _create_subprocess( + self, + config: CheckpointProcessConfig, + ) -> None: + logger.info( + "Creating checkpoint subprocess for rank %d", self._rank_info.global_rank + ) + + spawn_context = mp.get_context("spawn") + self._parent_end, child_end = spawn_context.Pipe() + + # Known workaround for https://github.com/pytorch/pytorch/issues/37377 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" + + logger.debug("Spawning subprocess for rank_info=%s", self._rank_info) + self.process = mp.spawn( + fn=CheckpointProcess._subprocess, + args=( + self._rank_info, + child_end, + self._subprocess_init_fn, + self._subprocess_init_args, + self._checkpoint_writer_init_fn, + self._checkpoint_writer_init_args, + ), + nprocs=1, + join=False, + daemon=True, + ) + + # close the child end of the pipe so recv on it will fail + # fast when the child process is terminated unexpectedly. + child_end.close() + self._send( + request_type=RequestType.PING, + payload={}, + ) + + logger.debug( + "Waiting for checkpoint subprocess to initialize (timeout: %ds)", + config.subprocess_init_timeout_secs, + ) + + # wait for the timeout or a response from subprocess + assert self._parent_end is not None, "Parent end of pipe should be initialized" + if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs): + msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize" + logger.error(msg) + raise TimeoutError(msg) + + self._recv() + logger.info("Checkpoint subprocess initialized successfully") + + @staticmethod + def _subprocess( + sub_rank: int, + rank_info: RankInfo, + parent_pipe: Connection, + subprocess_init_fn: Callable[[Any], None], + subprocess_init_args: tuple[Any, ...], + checkpoint_writer_init_fn: Callable[..., CheckpointWriter], + checkpoint_writer_init_args: dict[str, Any], + ) -> None: + logger.debug( + "Checkpoint subprocess started for rank %d/%d (PID: %d)", + rank_info.global_rank, + rank_info.global_world_size, + os.getpid(), + ) + + assert sub_rank == 0, "We need only one checkpointer per parent training" + request = WorkerRequest(request_type=RequestType.PING, payload={}) + + try: + # Calling initialize callback, so we can perform app-specific initialization of the subprocess. + subprocess_init_fn(*subprocess_init_args) + + # Initialize checkpoint writer - automatically include rank_info in init_args + writer_init_args = dict(checkpoint_writer_init_args) + if "rank_info" not in writer_init_args: + writer_init_args["rank_info"] = rank_info + checkpoint_writer = checkpoint_writer_init_fn(**writer_init_args) + + while True: + request = parent_pipe.recv() + + if request.request_type == RequestType.PING: + parent_pipe.send( + WorkerResponse(request_type=RequestType.PING, success=True) + ) + elif request.request_type == RequestType.WRITE_CHECKPOINT: + path = request.payload["path"] + logger.info("Writing checkpoint to %s", path) + + checkpoint_writer.write( + state_dict=request.payload["state_dict"], + path=path, + **request.payload["kwargs"], + ) + + logger.info("Checkpoint written successfully to %s", path) + parent_pipe.send( + WorkerResponse(RequestType.WRITE_CHECKPOINT, success=True) + ) + elif request.request_type == RequestType.TERMINATE_PROCESS: + logger.debug("Received termination request.") + parent_pipe.send( + WorkerResponse(RequestType.TERMINATE_PROCESS, success=True) + ) + logger.info("Subprocess terminated gracefully") + break + else: + error_msg = f"Unknown request type: {request.request_type}" + logger.error(error_msg) + raise ValueError(error_msg) + + except Exception as e: + error_text = traceback.format_exc() + logger.error( + "Exception in subprocess (%s): %s", type(e).__name__, error_text + ) + + # Communicating exception via the queue to the main process + parent_pipe.send( + WorkerResponse( + request_type=request.request_type, + success=False, + error_msg=error_text, + ) + ) + parent_pipe.close() + logger.error("Subprocess terminated due to exception: %s", e) + + def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None: + try: + assert self._parent_end is not None, ( + "Parent end of pipe should be initialized" + ) + self._parent_end.send( + WorkerRequest( + request_type=request_type, + payload=payload, + ) + ) + except OSError as e: + error_msg = "Child process terminated unexpectedly" + logger.error( + "Communication failed during %s request: %s", request_type.value, e + ) + raise RuntimeError(error_msg) from e + + def _recv(self) -> Optional[dict[str, Any]]: + try: + assert self._parent_end is not None, ( + "Parent end of pipe should be initialized" + ) + response = self._parent_end.recv() + if response.success is False: + error_msg = ( + f"Unexpected response from worker process: {response.error_msg}" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + return response.payload + except (EOFError, BrokenPipeError, ConnectionResetError) as e: + error_msg = f"Child process terminated unexpectedly: {e}" + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + def write( + self, + state_dict: Union[STATE_DICT, Future[STATE_DICT]], + path: str, + **kwargs: Any, + ) -> Optional[Future[None]]: + logger.debug("Waiting for subprocess initialization to complete") + + # wait until the process is started + self.process_creation_future.result() + + return self._executor.submit( + self._write, + state_dict, + path, + **kwargs, + ) + + def _write( + self, + state_dict: Union[STATE_DICT, Future[STATE_DICT]], + path: str, + **kwargs: Any, + ) -> None: + logger.debug("Starting checkpoint write to %s", path) + + # wait for staging state_dict to be available + if isinstance(state_dict, Future): + logger.debug("Waiting for state_dict Future to resolve") + sd = state_dict.result() + else: + sd = state_dict + + # Log state_dict info only if debug logging is enabled (performance-conscious) + if logger.isEnabledFor(logging.DEBUG): + if hasattr(sd, "keys"): + logger.debug("State_dict contains %d keys", len(sd.keys())) + + self._send( + request_type=RequestType.WRITE_CHECKPOINT, + payload={ + "state_dict": sd, + "path": path, + "kwargs": kwargs, + }, + ) + + logger.debug("Waiting for write completion response") + # wait for response + self._recv() + logger.debug("Checkpoint write to %s completed successfully", path) + + def close(self) -> None: + logger.debug( + "Closing CheckpointProcess for rank %d", self._rank_info.global_rank + ) + self._executor.shutdown(wait=True, cancel_futures=True) + + if self.process and self.process.processes[0].is_alive(): + subprocess_pid = self.process.processes[0].pid + # send graceful termination to sub process + try: + self._parent_end.send( + WorkerRequest( + request_type=RequestType.TERMINATE_PROCESS, + payload={}, + ) + ) + except BrokenPipeError: + logger.warning( + "BrokenPipeError when sending termination request - subprocess (PID: %d) may have already terminated", + subprocess_pid, + ) + # subprocess terminated unexpectedly and below code will raise a + # ProcessExitedException. + + logger.debug( + "Waiting for subprocess to terminate gracefully (timeout: %ds)", + self._config.subprocess_shutdown_timeout_secs, + ) + + try: + if not self.process.join( + timeout=self._config.subprocess_shutdown_timeout_secs + ): + # graceful shutdown failed, kill the process. + logger.warning( + "Subprocess (PID: %d) did not terminate gracefully within %ds, killing it", + subprocess_pid, + self._config.subprocess_shutdown_timeout_secs, + ) + self.process.processes[0].kill() + logger.info("Subprocess killed forcefully") + except ProcessExitedException as e: + logger.error( + "ProcessExitedException during subprocess termination: %s", e + ) + raise + + logger.debug("CheckpointProcess closed successfully") diff --git a/torch/distributed/checkpoint/_experimental/checkpoint_reader.py b/torch/distributed/checkpoint/_experimental/checkpoint_reader.py new file mode 100644 index 00000000000000..3119fb22a0be5a --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/checkpoint_reader.py @@ -0,0 +1,221 @@ +""" +Checkpoint reader functionality for machine learning models. + +This module provides classes for reading checkpoints from storage, including +determining checkpoint layout and configuring the reader. +""" + +import logging +import os +from itertools import zip_longest +from pathlib import Path +from typing import Any, Optional + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +class CheckpointReader: + """ + Handles reading state dictionaries from storage. + + This class is responsible for reading model state dictionaries from storage according + to the specified checkpoint layout. It supports synchronization barriers to ensure + all ranks in a distributed setting complete their checkpoint operations. + """ + + def __init__( + self, + rank_info: RankInfo, + ): + """ + Initialize a CheckpointReader. + + Args: + rank_info: Information about the current rank in a distributed setting. + """ + + self._rank_info = rank_info + + def read( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + map_location: Any = None, + **kwargs: dict[str, Any], + ) -> tuple[STATE_DICT, list[str]]: + """ + Reads a state dictionary from storage. + + Args: + path (str): The path from which to read the checkpoint. + map_location (Any): Device mapping function or device name for relocating tensors. + **kwargs: Additional keyword arguments passed to torch.load. + + Returns: + STATE_DICT: The loaded state dictionary. + list[str]: List of missing keys. + """ + logger.debug( + "Reading checkpoint from %s for rank %s", + path, + self._rank_info.global_rank, + ) + + dir_path = Path(path) + file_path = dir_path / f"checkpoint_{self._rank_info.global_rank}.pt" + + # Check if the file exists + if not os.path.exists(file_path): + logger.error("Checkpoint file not found at %s", file_path) + raise FileNotFoundError(f"Checkpoint file not found at {file_path}") + + if state_dict is None: + result: tuple[STATE_DICT, list[str]] = ( + torch.load(file_path, map_location=map_location), + [], + ) + else: + result = self._partial_read( + file_path, state_dict, map_location=map_location, **kwargs + ) + logger.debug("Successfully read checkpoint file from %s", file_path) + return result + + def _partial_read( + self, + file_path: Path, + state_dict: STATE_DICT, + *, + map_location: Any = None, + **kwargs: dict[str, Any], + ) -> tuple[STATE_DICT, list[str]]: + """ + Reads only the keys present in state_dict from the checkpoint file. + + This method optimizes checkpoint loading by only loading the tensors that + are actually needed, based on the keys present in the input state_dict. + This can significantly reduce memory usage and loading time for large checkpoints + when only a subset of the model needs to be loaded. + + Args: + file_path (str): The path to the checkpoint file. + state_dict (STATE_DICT): The state dictionary containing keys to load. + map_location (Any): Device mapping function or device name for relocating tensors. + **kwargs: Additional keyword arguments passed to torch.load. + + Returns: + tuple[STATE_DICT, list[str]]: The updated state dictionary with loaded values and a list of missing keys. + """ + + with FakeTensorMode(): + metadata_dict = torch.load(file_path, map_location=map_location) + + missing_keys = [] + + with open(file_path, "rb") as file: + # Helper function to load tensor data from file + def load_tensor( + target: Optional[torch.Tensor], source: torch.Tensor, full_key: str + ) -> torch.Tensor: + if target is not None and ( + target.size() != source.size() or target.dtype != source.dtype + ): + raise RuntimeError( + f"Target tensor size={target.size()} dtype={target.dtype} does not match " + f"source tensor size={source.size()} dtype={source.dtype} for key {full_key}" + ) + + tensor_offset = source.untyped_storage()._checkpoint_offset + + assert tensor_offset is not None, ( + "checkpoint_offset for tensor in torch serialized file is not set. This could" + "happen if the checkpoint was saved with a older version of Pytorch." + "Please make sure that the checkpoint was saved with Pytorch 2.7 or later." + ) + + tensor_len = source.nelement() * source.element_size() + file.seek( + tensor_offset + source.element_size() * int(source.storage_offset()) + ) + if target is None: + target = torch.empty( + source.size(), dtype=source.dtype, device=source.device + ) + + buffer = file.read(tensor_len) + cpu_tensor = torch.frombuffer(buffer, dtype=source.dtype) + tensor = cpu_tensor.view(source.size()) + target.copy_(tensor) + return target + + # Helper function to recursively process nested structures + def process_value( + target_value: Any, source_value: Any, key_path: str + ) -> Any: + source_type = type(source_value) + if source_type is torch._subclasses.fake_tensor.FakeTensor: + source_type = torch.Tensor + if target_value is not None and not isinstance( + target_value, source_type + ): + raise RuntimeError( + f"Target value {key_path} is set to {type(target_value)}, but source value is {type(source_value)}" + ) + if isinstance(source_value, torch.Tensor): + return load_tensor(target_value, source_value, key_path) + elif isinstance(source_value, dict): + if target_value is None: + # create a new map with all the keys present in source_value + target_value = dict.fromkeys(source_value.keys()) + + for key in list(target_value.keys()): + current_path = f"{key_path}.{key}" if key_path else key + if key in source_value: + target_value[key] = process_value( + target_value[key], source_value[key], current_path + ) + else: + missing_keys.append(current_path) + + return target_value + elif isinstance(source_value, list): + if target_value is None: + target_value = [None] * len(source_value) + result = [] + for i, (target_item, source_item) in enumerate( + zip_longest(target_value, source_value, fillvalue=None) + ): + current_path = f"{key_path}[{i}]" if key_path else f"[{i}]" + result.append( + process_value(target_item, source_item, current_path) + ) + return result + else: + return source_value + + # Start recursive processing from the root of the state dictionary + updated_state_dict = process_value(state_dict, metadata_dict, "") + + if missing_keys: + if len(missing_keys) > 10: + logger.warning( + "Missing %s keys from checkpoint: %s... (and %s more)", + len(missing_keys), + missing_keys[:10], + len(missing_keys) - 10, + ) + else: + logger.warning( + "Missing %s keys from checkpoint: %s", + len(missing_keys), + missing_keys, + ) + + return updated_state_dict, missing_keys diff --git a/torch/distributed/checkpoint/_experimental/checkpoint_writer.py b/torch/distributed/checkpoint/_experimental/checkpoint_writer.py new file mode 100644 index 00000000000000..1f9026d6e8322a --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/checkpoint_writer.py @@ -0,0 +1,163 @@ +""" +Checkpoint writer functionality for machine learning models. + +This module provides classes for writing checkpoints to storage, including +determining checkpoint layout, configuring the writer, and defining hooks +for custom actions during the checkpoint writing process. +""" + +import abc +import logging +import os +from concurrent.futures import Future +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import torch + +from .barriers import Barrier +from .types import RankInfo, STATE_DICT + + +logger = logging.getLogger(__name__) + + +class WriterHook(abc.ABC): + """ + Abstract base class for checkpoint commit hooks. + + A commit hook provides callbacks that are executed before and after a checkpoint + is committed to storage. This allows for custom actions to be performed at specific + points in the checkpoint writing process, such as metadata updates, cleanup operations, + or notifications. + """ + + @abc.abstractmethod + def pre_commit(self, path: str, **kwargs: dict[str, Any]) -> None: + """ + Performs actions before committing the checkpoint. + """ + + @abc.abstractmethod + def post_commit(self, path: str, **kwargs: dict[str, Any]) -> None: + """ + Performs actions after committing the checkpoint. + """ + + +@dataclass +class CheckpointWriterConfig: + """ + Configuration options for the CheckpointWriter. + + Attributes: + write_barrier_timeout_secs: Maximum time in seconds to wait for all ranks + to reach the checkpoint barrier before timing out. Default is 600 seconds. + """ + + write_barrier_timeout_secs: int = 600 + + +class CheckpointWriter: + """ + Handles writing state dictionaries to storage. + + This class is responsible for writing model state dictionaries to storage according + to the specified checkpoint layout. It supports synchronization barriers to ensure + all ranks in a distributed setting complete their checkpoint operations. + """ + + def __init__( + self, + config: CheckpointWriterConfig, + rank_info: RankInfo, + barrier: Optional[Barrier] = None, + commit_hook: Optional[WriterHook] = None, + ): + """ + Initialize a CheckpointWriter. + + Args: + config: Configuration options for the checkpoint writer. + rank_info: Information about the current rank in a distributed setting. + barrier: Optional synchronization barrier for distributed checkpointing. + Note: The barrier should be initialized with the appropriate barrier_prefix + and timeout_secs parameters. + commit_hook: Optional hook for custom actions before and after checkpoint commits. + """ + + self._config = config + self._rank_info = rank_info + self._commit_hook = commit_hook + self._barrier = barrier + + def write( + self, + state_dict: STATE_DICT, + path: str, + **kwargs: dict[str, Any], + ) -> Optional[Future[None]]: + """ + Writes the state_dict to storage. + + Args: + state_dict (STATE_DICT): The state_dict to write. + path (str): The path to write the checkpoint to. + **kwargs: Additional keyword arguments passed to hooks. + + Returns: + Optional[Future[None]]: A future for tracking the write operation, if applicable. + """ + logger.debug( + "Writing checkpoint to %s for rank %s", + path, + self._rank_info.global_rank, + ) + dir_path = Path(path) + full_path = dir_path / f"checkpoint_{self._rank_info.global_rank}.pt" + os.makedirs( + os.path.dirname(full_path), + exist_ok=True, + ) + torch.save(state_dict, full_path) + logger.debug("Successfully saved checkpoint file to %s", full_path) + + # Execute pre-commit hook if available + commit_hook = self._commit_hook + if commit_hook is not None: + logger.debug("Executing pre-commit hook for %s", path) + commit_hook.pre_commit(path, **kwargs) + + # Wait for all ranks to finish writing if barrier is available + barrier = self._barrier + if barrier is not None: + logger.info( + "Waiting for all ranks at barrier with timeout %ss", + self._config.write_barrier_timeout_secs, + ) + barrier.execute_barrier() + logger.info("All ranks passed barrier") + else: + logger.info("No barrier configured, skipping synchronization") + + # Execute commit hook if available + if commit_hook is not None: + logger.debug("Executing commit hook for %s", path) + commit_hook.post_commit(path, **kwargs) + + logger.info( + "Successfully wrote checkpoint to %s for rank %s", + path, + self._rank_info.global_rank, + ) + return None + + def close(self) -> None: + """ + Close the writer and release any resources. + + This is a no-op for the base CheckpointWriter but may be overridden + by subclasses that need to perform cleanup. + """ + logger.debug("Closing checkpoint writer") diff --git a/torch/distributed/checkpoint/_experimental/checkpointer.py b/torch/distributed/checkpoint/_experimental/checkpointer.py new file mode 100644 index 00000000000000..839a6c970f5844 --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/checkpointer.py @@ -0,0 +1,341 @@ +import abc +import logging +from concurrent.futures import Future +from typing import Any, Optional, TypeVar + +from .checkpoint_process import CheckpointProcess +from .checkpoint_reader import CheckpointReader +from .checkpoint_writer import CheckpointWriter +from .staging import CheckpointStager +from .types import STATE_DICT +from .utils import wrap_future + + +logger = logging.getLogger(__name__) + +LOG_INTERVAL = 60 +T = TypeVar("T") + + +class Checkpointer(abc.ABC): + """ + WARNING: This class is experimental, and is created to validate certain ideas, + and is subjected to change or deprecation and we strong discourage any usages at + this time. + + Abstract base class that defines the API for checkpointing. + + This class defines the interface for coordinating the writing and loading of model + state dictionaries to and from storage. It provides abstract methods to save and load model states + with support for both synchronous and asynchronous operations. + + Concrete implementations of this class must implement all the abstract methods. + """ + + @abc.abstractmethod + def save( + self, + state_dict: STATE_DICT, + path: str, + **kwargs: dict[str, Any], + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage. + + Args: + state_dict: The state dictionary to save. + path: The path where the checkpoint should be saved. + **kwargs: Additional keyword arguments to pass to the writer. + + Returns: + For synchronous implementations: None + For asynchronous implementations: tuple of (stage_future, write_future) + representing the staging and writing operations. + """ + + @abc.abstractmethod + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: dict[str, Any], + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of resources. + """ + + +class SyncCheckpointer(Checkpointer): + """ + Synchronous implementation of Checkpointer. + + This class coordinates the writing and loading of model state dictionaries to and from storage + using only synchronous operations. It provides a simple, efficient interface for checkpoint + operations without async overhead. + + Attributes: + _writer: CheckpointWriter for writing state dictionaries to storage. + _reader: CheckpointReader for reading state dictionaries from storage. + + Example: + checkpointer = SyncCheckpointer(writer=writer, reader=reader) + checkpointer.save(state_dict, path) + loaded_state_dict = checkpointer.load(path) + """ + + def __init__( + self, + writer: CheckpointWriter, + reader: CheckpointReader, + ): + """ + Initialize a synchronous checkpointer. + + Args: + writer: CheckpointWriter for writing checkpoints to storage. + reader: CheckpointReader for reading checkpoints from storage. + """ + self._writer = writer + self._reader = reader + + def save( + self, + state_dict: STATE_DICT, + path: str, + **kwargs: dict[str, Any], + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage synchronously. + + Args: + state_dict: The state dictionary to save. + path: The path where the checkpoint should be saved. + **kwargs: Additional keyword arguments to pass to the writer. + + Returns: + Always returns None as operations are synchronous. + + Example: + checkpointer.save(state_dict, "/path/to/checkpoint") + """ + logger.debug("Saving checkpoint synchronously to %s", path) + self._writer.write(state_dict, path, **kwargs) + return None + + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: dict[str, Any], + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + + Raises: + RuntimeError: If strict=True and there are missing keys in the checkpoint. + FileNotFoundError: If the checkpoint file is not found. + """ + logger.info("Loading checkpoint from %s", path) + + loaded_state_dict, missing_keys = self._reader.read( + path=path, + state_dict=state_dict, + map_location=default_map_location, + **kwargs, + ) + if strict and missing_keys is not None and missing_keys != []: + raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}") + return loaded_state_dict + + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of resources. + """ + self._writer.close() + logger.info("SyncCheckpointer closed") + + +class AsyncCheckpointer(Checkpointer): + """ + Asynchronous implementation of Checkpointer. + + This class coordinates the writing and loading of model state dictionaries to and from storage + using asynchronous operations for saving. It provides efficient async checkpoint operations + with staging and background writing capabilities. + + Attributes: + _reader: CheckpointReader for reading state dictionaries from storage. + _checkpoint_stager: Stager for async operations. + _checkpoint_process: Process for async operations. + _write_future: Future representing the ongoing async write operation. + + Example: + checkpointer = AsyncCheckpointer( + reader=reader, + checkpoint_stager=stager, + checkpoint_process=process + ) + stage_future, write_future = checkpointer.save(state_dict, path) + # ... do other work ... + write_future.result() # Wait for completion + """ + + def __init__( + self, + checkpoint_stager: CheckpointStager, + checkpoint_process: CheckpointProcess, + reader: CheckpointReader, + ): + """ + Initialize an asynchronous checkpointer. + + Args: + checkpoint_stager: Stager for async operations. + checkpoint_process: Process for async operations. + reader: CheckpointReader for reading checkpoints from storage. + """ + self._reader = reader + self._checkpoint_stager = checkpoint_stager + self._checkpoint_process = checkpoint_process + self._write_future: Optional[Future[Any]] = None + + def save( + self, + state_dict: STATE_DICT, + path: str, + **kwargs: Any, + ) -> Optional[tuple[Future, Future]]: + """ + Save a state dictionary to storage asynchronously. + + Args: + state_dict: The state dictionary to save. + path: The path where the checkpoint should be saved. + **kwargs: Additional keyword arguments to pass to the stager and writer. + + Returns: + A tuple of (stage_future, write_future) representing the staging and writing operations. + + Example: + stage_future, write_future = checkpointer.save(state_dict, "/path/to/checkpoint") + # ... do other work ... + write_future.result() # Wait for completion + """ + logger.info( + "Initiating checkpoint save to %s. Will wait for prev checkpoints to complete.", + path, + ) + # Wait for previous checkpoint ops to finish and verify they are successful + if self._write_future is not None: + self._write_future.result() + + logger.debug("Starting state dictionary staging") + staging_result = self._checkpoint_stager.stage( + state_dict=state_dict, + **kwargs, + ) + + logger.debug("Starting checkpoint write to %s", path) + self._write_future = self._checkpoint_process.write( + staging_result, path, **kwargs + ) + logger.info("Checkpoint save to %s initiated", path) + + # Return futures for the staging and writing operations + if self._write_future is not None: + return wrap_future(staging_result), self._write_future + else: + # This should not happen since we just assigned _write_future above + raise RuntimeError("Write future is unexpectedly None") + + def load( + self, + path: str, + state_dict: Optional[STATE_DICT] = None, + *, + default_map_location: Any = None, + strict: bool = False, + **kwargs: Any, + ) -> STATE_DICT: + """ + Load a state dictionary from storage. + + Loading is always performed synchronously, even in AsyncCheckpointer. + + Args: + path: The path from which to load the checkpoint. + state_dict: Optional state dictionary to update with loaded values. + If provided, only keys in this dictionary will be loaded. + default_map_location: Device mapping function or device name for relocating tensors. + strict: If True, raises an error when there are missing keys in the checkpoint. + **kwargs: Additional keyword arguments to pass to the reader. + + Returns: + The loaded state dictionary. + + Raises: + RuntimeError: If strict=True and there are missing keys in the checkpoint. + FileNotFoundError: If the checkpoint file is not found. + """ + logger.info("Loading checkpoint from %s", path) + + loaded_state_dict, missing_keys = self._reader.read( + path=path, + state_dict=state_dict, + map_location=default_map_location, + **kwargs, + ) + if strict and missing_keys is not None and missing_keys != []: + raise RuntimeError(f"Checkpoint at {path} is missing keys: {missing_keys}") + return loaded_state_dict + + def close(self) -> None: + """ + Close the checkpointer and release any resources. + + This method should be called when the checkpointer is no longer needed to ensure + proper cleanup of async resources. + """ + self._checkpoint_stager.close() + self._checkpoint_process.close() + logger.info("AsyncCheckpointer closed") diff --git a/torch/distributed/checkpoint/_experimental/config.py b/torch/distributed/checkpoint/_experimental/config.py new file mode 100644 index 00000000000000..a81156e3929cac --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/config.py @@ -0,0 +1,44 @@ +""" +Configuration classes for checkpointer construction. + +This module provides configuration dataclasses that consolidate all +configuration options needed to construct checkpointers. +""" + +from dataclasses import dataclass, field + +from .barriers import BarrierConfig +from .checkpoint_process import CheckpointProcessConfig +from .checkpoint_writer import CheckpointWriterConfig +from .staging import CheckpointStagerConfig + + +@dataclass +class CheckpointerConfig: + """ + Configuration class for checkpointer construction. + + This class consolidates the core component configuration options needed to construct + a checkpointer, providing a clean separation of concerns where each component + manages its own configuration. + + Attributes: + writer_config: Configuration options for the checkpoint writer component. + barrier_config: Configuration for barrier construction and arguments. + staging_config: Configuration options for the async staging component. + process_config: Configuration options for the async checkpoint process component. + + """ + + writer_config: CheckpointWriterConfig = field( + default_factory=CheckpointWriterConfig + ) + barrier_config: BarrierConfig = field(default_factory=BarrierConfig) + + # Below configs are used for async checkpointing + staging_config: CheckpointStagerConfig = field( + default_factory=CheckpointStagerConfig + ) + process_config: CheckpointProcessConfig = field( + default_factory=CheckpointProcessConfig + ) diff --git a/torch/distributed/checkpoint/_experimental/staging.py b/torch/distributed/checkpoint/_experimental/staging.py new file mode 100644 index 00000000000000..55e4c15921a2d9 --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/staging.py @@ -0,0 +1,216 @@ +""" +Experimental staging module for PyTorch Distributed Checkpointing. + +This module provides advanced staging capabilities for checkpoints including: +- Asynchronous staging using ThreadPoolExecutor +- Pinned memory allocation for faster CPU-GPU transfers +- Shared memory support for multi-process scenarios +- Non-blocking CUDA operations with stream synchronization +- Caching of frequently used storages for efficient memory management +- Automatic resource cleanup and memory management + +Classes: + CheckpointStager: Abstract base class defining the staging interface + StagingOptions: Configuration dataclass for staging behavior + DefaultStager: Default implementation with comprehensive staging features +""" + +import abc +import logging +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from logging import getLogger +from typing import Any, TypeVar, Union + +import torch +from torch.distributed.checkpoint._state_dict_stager import StateDictStager + +from .types import STATE_DICT + + +T = TypeVar("T") + +logger = getLogger() +logger.setLevel(logging.INFO) + + +class CheckpointStager(abc.ABC): + """ + Abstract base class for checkpoint staging implementations. + + CheckpointStager defines the interface that all staging implementations + must follow. Staging is the process of offloading state dictionaries + for async checkpointing. + """ + + @abc.abstractmethod + def stage( + self, + state_dict: STATE_DICT, + **kwargs: Any, + ) -> Union[STATE_DICT, Future[STATE_DICT]]: + """ + Stage a state dictionary for checkpointing. + + Args: + state_dict: The state dictionary to stage + **kwargs: Additional staging parameters + + Returns: + Either a staged state dictionary (synchronous) or a Future + that will resolve to the staged state dictionary (asynchronous) + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Clean up all resources used by the stager. + """ + + +@dataclass +class CheckpointStagerConfig: + """ + Configuration options for checkpoint staging behavior. + + Attributes: + use_pinned_memory (bool): Enable pinned memory allocation for faster + CPU-GPU transfers. Requires CUDA to be available. Default: True + use_shared_memory (bool): Enable shared memory for multi-process + scenarios. Useful when multiple processes need access to the + same staged data. Default: True + use_async_staging (bool): Enable asynchronous staging using a + background thread pool. Allows overlapping computation with + staging operations. Requires CUDA. Default: True + use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + copies with stream synchronization. Improves performance by + allowing CPU work to continue during GPU transfers. Default: True + + Note: + CUDA-dependent features will raise exception if CUDA is not available. + """ + + use_pinned_memory: bool = True + use_shared_memory: bool = True + use_async_staging: bool = True + use_cuda_non_blocking_copy: bool = True + + +class DefaultStager(CheckpointStager): + """ + DefaultStager provides a full-featured staging implementation that combines + multiple optimization techniques for efficient checkpoint preparation. + + The staging process works as follows: + 1. State dictionary is submitted for staging (sync or async) + 2. Tensors are copied from GPU to optimized CPU storage + 3. CUDA operations are synchronized if non-blocking copies are used + 4. Staged state dictionary is returned or made available via Future + + NOTE: state_dict should be deep-copyable object as staging will create a + copy of it. + + Usage Patterns: + # Synchronous staging + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=False)) + staged_dict = stager.stage(state_dict) + stager.close() + + # Asynchronous staging + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True)) + future = stager.stage(state_dict) + # ... do other work ... + staged_dict = future.result() + stager.close() + + # Context manager pattern (recommended) + with DefaultStager(config) as stager: + result = stager.stage(state_dict) + # Automatic cleanup on exit + + Performance Considerations: + - Async staging provides best performance when model computation + can overlap with staging operations + - Pinned memory improves CPU-GPU transfer speeds but uses more memory + - Shared memory allows efficient IPC to checkpoint process + - Non-blocking copies reduce GPU idle time during memory transfers + + Thread Safety: + DefaultStager is not thread-safe. Each thread should use its own + instance, or external synchronization should be provided. + """ + + def __init__( + self, + config: CheckpointStagerConfig = CheckpointStagerConfig(), + ): + self._config = config + self._state_dict_stager = StateDictStager( + pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory + ) + self._staging_executor = None + self._staging_stream = None + + if self._config.use_async_staging: + self._staging_executor = ThreadPoolExecutor(max_workers=1) + if torch.cuda.is_available(): + # Note: stream needs to be initialized on the main thread after default cuda + # stream is setup/used to avoid the risk of accidentally reusing the main + # compute stream or in other cases kernels actually launching from the + # main thread. + self._staging_stream = torch.cuda.Stream() + + if self._config.use_cuda_non_blocking_copy: + assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + + def stage( + self, + state_dict: STATE_DICT, + **kwargs: Any, + ) -> Union[STATE_DICT, Future[STATE_DICT]]: + if self._config.use_async_staging: + assert self._staging_executor is not None, ( + "Staging executor should be initialized for async staging" + ) + return self._staging_executor.submit( + self._stage, + state_dict, + **kwargs, + ) + else: + return self._stage(state_dict, **kwargs) + + def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT: + state_dict = self._state_dict_stager.stage( + state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs + ) + + if self._config.use_cuda_non_blocking_copy: + assert self._staging_stream or not self._config.use_async_staging, ( + "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + ) + + # waits for the enqued copy operations to finish. + self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + + return state_dict + + def close(self) -> None: + """ + Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor + used for async staging operations and cleans up the underlying StateDictStager's + cached storages. Should be called when the stager is no longer needed to prevent + resource leaks, especially in long-running applications. After calling close(), + the stager should not be used for further staging operations. + + state_dict should be deep-copyable object. + + Example: + stager = DefaultStager(CheckpointStagerConfig(use_async_staging=True)) + # ... do staging operations ... + stager.close() # Clean up all resources + """ + if self._staging_executor: + self._staging_executor.shutdown(wait=True) + + self._state_dict_stager.close() diff --git a/torch/distributed/checkpoint/_experimental/types.py b/torch/distributed/checkpoint/_experimental/types.py new file mode 100644 index 00000000000000..3874ecc30bf43f --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/types.py @@ -0,0 +1,29 @@ +""" +Type definitions for distributed training and checkpointing. + +This module provides type definitions and classes for managing rank information +in distributed training environments, which is essential for proper checkpoint +saving and loading. +""" + +from dataclasses import dataclass +from typing import Any +from typing_extensions import TypeAlias + + +# Type alias for state dictionaries used in checkpointing +STATE_DICT: TypeAlias = dict[str, Any] + + +@dataclass +class RankInfo: + """ + Information about the current rank in a distributed training environment. + + Attributes: + global_rank: The global rank ID of the current process. + global_world_size: The total number of processes in the distributed environment. + """ + + global_rank: int + global_world_size: int diff --git a/torch/distributed/checkpoint/_experimental/utils.py b/torch/distributed/checkpoint/_experimental/utils.py new file mode 100644 index 00000000000000..271e9aa112f682 --- /dev/null +++ b/torch/distributed/checkpoint/_experimental/utils.py @@ -0,0 +1,42 @@ +""" +Utility functions for the experimental checkpoint module. + +This module contains helper functions and utilities used across the experimental +checkpoint functionality. +""" + +from concurrent.futures import Future +from typing import Any + + +def wrap_future(original_result: Any) -> Future[None]: + """ + Wraps a result (Future or not) to return a Future with None result. + + If the input is a Future, returns a new Future that completes with None when + the original Future completes successfully, or propagates any exception. + If the input is not a Future, returns a completed Future with None result. + + Args: + original_result: The result to wrap (Future or any other value). + + Returns: + A Future that completes with None on success or propagates exceptions. + """ + masked_future: Future[None] = Future() + + if isinstance(original_result, Future): + + def on_complete(_: Future[Any]) -> None: + try: + original_result.result() + masked_future.set_result(None) + except Exception as e: + masked_future.set_exception(e) + + original_result.add_done_callback(on_complete) + else: + # Return a completed future with None result + masked_future.set_result(None) + + return masked_future diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index ef8f5823fdb2d7..377c34ae1e5dde 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -92,7 +92,7 @@ def rm_file(self, path: Union[str, os.PathLike]) -> None: self.fs.rm(path) def ls(self, path: Union[str, os.PathLike]) -> list[str]: - # setting detail to False explictly to keep the list[str] return type, + # setting detail to False explicitly to keep the list[str] return type, # instead of the list[Dict] return type when detail=True return self.fs.ls(path, detail=False) diff --git a/torch/distributed/checkpoint/_hf_planner.py b/torch/distributed/checkpoint/_hf_planner.py deleted file mode 100644 index 4ee176339f44cb..00000000000000 --- a/torch/distributed/checkpoint/_hf_planner.py +++ /dev/null @@ -1,49 +0,0 @@ -# mypy: allow-untyped-defs -from dataclasses import dataclass - -from torch.distributed.checkpoint._dedup_save_plans import ( - dedup_save_plans_with_fqn_to_index_mapping, -) -from torch.distributed.checkpoint.default_planner import ( - DefaultLoadPlanner, - DefaultSavePlanner, -) -from torch.distributed.checkpoint.planner import ReadItem, SavePlan - - -__all__ = ["_HuggingFaceSavePlanner", "_HuggingFaceLoadPlanner"] - - -@dataclass -class _FqnToFileMapping: - fqn_to_file_index_mapping: dict[str, int] - - -class _HuggingFaceSavePlanner(DefaultSavePlanner): - """ - A save planner that dedups the save plans based on the fqn to file index mapping. - """ - - def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]: - assert len(all_plans) > 0, "all_plans should not be empty" - assert all_plans[0].storage_data is not None, "storage_data should not be None" - assert isinstance(all_plans[0].storage_data, _FqnToFileMapping), ( - "storage_data should be of type _FqnToFileMapping" - ) - - fqn_to_index_mapping: dict[str, int] = all_plans[ - 0 - ].storage_data.fqn_to_file_index_mapping - - return dedup_save_plans_with_fqn_to_index_mapping( - all_plans, fqn_to_index_mapping - ) - - -class _HuggingFaceLoadPlanner(DefaultLoadPlanner): - def __init__(self, allow_tensor_resize: bool = False): - super().__init__() - self.allow_tensor_resize = allow_tensor_resize - - def resolve_tensor(self, read_item: ReadItem): - return self.lookup_tensor(read_item.dest_index) diff --git a/torch/distributed/checkpoint/_hf_storage.py b/torch/distributed/checkpoint/_hf_storage.py deleted file mode 100644 index 1c8f8e376d7420..00000000000000 --- a/torch/distributed/checkpoint/_hf_storage.py +++ /dev/null @@ -1,276 +0,0 @@ -# mypy: allow-untyped-defs -import dataclasses -import io -import json -import os -import queue -import struct -from typing import Optional - -import fsspec # type: ignore[import-untyped] - -from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter -from torch.distributed.checkpoint._hf_planner import ( - _FqnToFileMapping, - _HuggingFaceLoadPlanner, -) -from torch.distributed.checkpoint.filesystem import SerializationFormat -from torch.distributed.checkpoint.metadata import ( - BytesStorageMetadata, - Metadata, - STORAGE_TYPES, - StorageMeta, -) -from torch.distributed.checkpoint.planner import ( - LoadPlan, - LoadPlanner, - ReadItem, - SavePlan, - SavePlanner, - WriteItem, -) -from torch.distributed.checkpoint.storage import WriteResult -from torch.futures import Future - - -__all__ = ["_HuggingFaceStorageWriter", "_HuggingFaceStorageReader"] - -_metadata_fn: str = "model.safetensors.index.json" - -FILE_NAME = "model-{cpt_idx}-of-{num_shards}" -SUFFIX = ".safetensors" - - -class _HuggingFaceStorageWriter(FsspecWriter): - """ - A writer that writes to a huggingface repository in the huggingface format. - Uses in Fsspec back-end to communicate with the huggingface hub. - """ - - def __init__( - self, - path: str, - fqn_to_index_mapping: dict[str, int], - token: Optional[str] = None, - ) -> None: - """ - Initialize the huggingface writer pointing to path. - - Args: - path: hf directory where the checkpoint will be written to. Should begin with hf://. - token: The token to use to authenticate with huggingface hub. - fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. - Indices are from 1 to N, where N is the number of files. - - """ - from huggingface_hub import HfFileSystem # type: ignore[import-not-found] - - if HfFileSystem.protocol not in fsspec.available_protocols(): - fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) - - if token is not None: - super().__init__( - path=path, - token=token, - serialization_format=SerializationFormat.SAFETENSORS, - ) - else: - super().__init__( - path=path, - serialization_format=SerializationFormat.SAFETENSORS, - ) - self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping - - def prepare_local_plan(self, plan: SavePlan) -> SavePlan: - plan = super().prepare_local_plan(plan) - return dataclasses.replace( - plan, storage_data=_FqnToFileMapping(self._fqn_to_index_mapping) - ) - - def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: - return plans - - def write_data( - self, - plan: SavePlan, - planner: SavePlanner, - ) -> Future[list[WriteResult]]: - if len(plan.items) == 0: - fut: Future = Future() - fut.set_result([]) - return fut - - # storage_plan is a map from key to file index - storage_plan: dict[str, int] = plan.storage_data.fqn_to_file_index_mapping - - buckets = self._split_by_storage_plan(storage_plan, plan.items) - highest_index = max(storage_plan.values()) - - file_queue: queue.Queue = queue.Queue() - for file_index, write_items in buckets.items(): - file_name = self._gen_file_name(file_index, highest_index) - file_queue.put( - (self.fs.concat_path(self.path, file_name), file_name, write_items) - ) - - return super()._write_data(planner, file_queue) - - def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: - metadata_to_write = {} - storage_md = {} - total_size = 0 - for wr_list in results: - storage_md.update( - {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list} - ) - total_size += sum([wr.storage_data.length for wr in wr_list]) - metadata_to_write["metadata"] = {"total_size": total_size} - metadata_to_write["weight_map"] = storage_md - - metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}") - with self.fs.create_stream(metadata_path, "w") as metadata_file: - json.dump(metadata_to_write, metadata_file, indent=2) - - def _split_by_storage_plan( - self, storage_plan: dict[str, int], items: list[WriteItem] - ) -> dict[int, list[WriteItem]]: - # storage_plan is a map from key to index - buckets = {} - for item in items: - key = item.index.fqn - idx = storage_plan[key] - if idx not in buckets: - buckets[idx] = [item] - else: - buckets[idx].append(item) - - return buckets - - def _gen_file_name(self, index: int, largest_index: int) -> str: - return ( - FILE_NAME.format( - cpt_idx=f"{index}".zfill(5), num_shards=f"{largest_index}".zfill(5) - ) - + SUFFIX - ) - - @property - def metadata_path(self) -> str: - return _metadata_fn - - -class _HuggingFaceStorageReader(FsspecReader): - """ - A reader that reads from a huggingface repository in the huggingface format. - Uses in Fsspec back-end to communicate with the huggingface hub. - """ - - def __init__(self, path: str, token: Optional[str] = None) -> None: - """ - Initialize the huggingface reader pointing to path. - - Args: - path: hf directory where the checkpoint will be read from. Should begin with hf://. - token: The token to use to authenticate with huggingface hub. - """ - from huggingface_hub import HfFileSystem # type: ignore[import-not-found] - - if HfFileSystem.protocol not in fsspec.available_protocols(): - fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) - - if token is not None: - super().__init__(path=path, token=token) - else: - super().__init__(path=path) - - self.storage_data: dict[str, str] = {} - - def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: - from safetensors.torch import load # type: ignore[import-not-found] - - per_file: dict[str, list[ReadItem]] = {} - - for read_item in plan.items: - file_name = self.storage_data[read_item.storage_index.fqn] - per_file.setdefault(file_name, []).append(read_item) - - for file_name, reqs in per_file.items(): - new_path = self.fs.concat_path(self.path, file_name) - with self.fs.create_stream(new_path, "rb") as stream: - loaded_tensors = load(stream.read()) - for req in reqs: - tensor = loaded_tensors[req.dest_index.fqn] - - target_tensor = planner.resolve_tensor(req) - if ( - isinstance(planner, _HuggingFaceLoadPlanner) - and planner.allow_tensor_resize - ): - target_tensor.resize_(tensor.size()) - else: - assert target_tensor.size() == tensor.size(), ( - f"Tensor size mismatch for {req.dest_index.fqn}: {target_tensor.size()} != {tensor.size()}" - ) - target_tensor = target_tensor.detach() - target_tensor.copy_(tensor) - planner.commit_tensor(req, target_tensor) - - fut: Future = Future() - fut.set_result(None) - return fut - - def read_metadata(self) -> Metadata: - metadata_path = self.fs.concat_path(self.path, _metadata_fn) - - state_dict_metadata: dict[str, STORAGE_TYPES] = {} - storage_data: dict[str, str] = {} - - if not self.fs.exists(metadata_path): - # if metadata file doesn't exist, create it from the safetensors file - safetensors_files = [] - for file in self.fs.ls(self.path): - if file.endswith(SUFFIX): - safetensors_files.append(file) - - if len(safetensors_files) != 1: - raise ValueError( - f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files" - ) - storage_data = {} - with self.fs.create_stream(safetensors_files[0], "rb") as f: - keys = _get_safetensors_file_keys(f) - - for key in keys: - state_dict_metadata[key] = BytesStorageMetadata() - storage_data[key] = os.path.basename(safetensors_files[0]) - else: - with self.fs.create_stream(metadata_path, "r") as metadata_file: - metadata = json.load(metadata_file) - - for key in metadata["weight_map"].keys(): - state_dict_metadata[key] = BytesStorageMetadata() - storage_data = metadata["weight_map"] - - metadata = Metadata( - state_dict_metadata=state_dict_metadata, - storage_data=storage_data, - ) - - if getattr(metadata, "storage_meta", None) is None: - metadata.storage_meta = StorageMeta() - metadata.storage_meta.load_id = self.load_id - - return metadata - - -def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]: - # this uses the same logic that's done in HF code base - # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 - # and follows their documentation on how their files are serialized - # https://huggingface.co/docs/safetensors/index#format - - header_len_bytes = file_bytes.read(8) - header_len = struct.unpack(" str: + if shard_index is not None: + return ( + SHARDED_FILE_NAME.format( + shard_idx=f"{shard_index}".zfill(5), + cpt_idx=f"{index}".zfill(5), + num_files=f"{largest_index}".zfill(5), + ) + + SUFFIX + ) + else: + return ( + FILE_NAME.format( + cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) + ) + + SUFFIX + ) + + +def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]: + # this uses the same logic that's done in HF code base + # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 + # and follows their documentation on how their files are serialized + # https://huggingface.co/docs/safetensors/index#format + + num_bytes_for_header_len = 8 + header_len_bytes = file_bytes.read(num_bytes_for_header_len) + header_len = struct.unpack(" torch.dtype: + try: + dtype = DTYPE_MAP[dtype_str] + except KeyError: + dtype = torch.get_default_dtype() + + return dtype + + +def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]: + if DEFAULT_EXTRA_METADATA_KEY in metadata: + custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY] + if CUSTOM_METADATA_KEY in custom_metadata: + return json.loads(custom_metadata[CUSTOM_METADATA_KEY]) + return None diff --git a/torch/distributed/checkpoint/_pg_transport.py b/torch/distributed/checkpoint/_pg_transport.py new file mode 100644 index 00000000000000..cab908b5a8510d --- /dev/null +++ b/torch/distributed/checkpoint/_pg_transport.py @@ -0,0 +1,310 @@ +import logging +import pickle +import time +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Callable, cast, Optional, TypeVar, Union + +import torch +from torch.distributed import ProcessGroup, Work +from torch.distributed.tensor import _DTensorSpec, DTensor +from torch.utils._pytree import ( + KeyPath, + tree_flatten_with_path, + tree_unflatten, + TreeSpec, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class _TensorMeta: + """ + This is the metadata for a tensor that is used to transfer checkpoints. + It contains the shape, the dtype, the storage offset and the stride of the + tensor. + + This must be pickleable so that it can be sent over the wire. + """ + + shape: torch.Size + dtype: torch.dtype + storage_offset: int + stride: tuple[int, ...] + nbytes: int + + +@dataclass +class _DTensorMeta: + """ + This is the metadata for a DTensor that is used to transfer checkpoints. + It contains the metadata for the local tensor and the spec of the DTensor. + + This must be pickleable so that it can be sent over the wire. + """ + + local: _TensorMeta + spec: _DTensorSpec + + +@dataclass +class _StateDictMeta: + """ + This is the metadata for a state dict that is used to transfer checkpoints. + It contains the step, the pytree spec of the state dict and the metadata for + each tensor in the state dict. + + This must be pickleable so that it can be sent over the wire. + + Args: + step: the step of the checkpoint to verify consistency + treespec: the pytree spec of the state dict + paths: the path of each leaf in the state dict + non_tensor_leaves: the metadata for each tensor in the state dict and any + non-tensor leaves in the state dict + """ + + treespec: TreeSpec + paths: list[KeyPath] + non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] + + +@contextmanager +def _timeit(name: str) -> Generator[None, None, None]: + start = time.perf_counter() + yield + dur = time.perf_counter() - start + logger.info("%s took %ss", name, dur) + + +def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]: + return ( + _cast_tensor(tensor, torch.uint8), + _TensorMeta( + shape=tensor.shape, + dtype=tensor.dtype, + storage_offset=cast(int, tensor.storage_offset()), + stride=tensor.stride(), + nbytes=tensor.untyped_storage().nbytes(), + ), + ) + + +def _prepare_state_dict( + state_dict: object, + device: torch.device, +) -> tuple[_StateDictMeta, list[torch.Tensor]]: + leaves: list[tuple[KeyPath, object]] + leaves, treespec = tree_flatten_with_path(state_dict) + + paths: list[KeyPath] = [] + non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = [] + tensors: list[torch.Tensor] = [] + for key_path, v in leaves: + paths.append(key_path) + + if isinstance(v, DTensor): + tensor, tensor_meta = _prepare_tensor(v._local_tensor) + + tensors.append(tensor) + + non_tensor_leaves.append( + _DTensorMeta( + local=tensor_meta, + spec=v._spec, + ) + ) + elif isinstance(v, torch.Tensor): + tensor, tensor_meta = _prepare_tensor(v) + tensors.append(tensor) + non_tensor_leaves.append(tensor_meta) + else: + non_tensor_leaves.append(v) + + return ( + _StateDictMeta( + treespec=treespec, + paths=paths, + non_tensor_leaves=non_tensor_leaves, + ), + tensors, + ) + + +def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """ + Casts the underlying storage to a tensor of the given dtype. + + The returned tensor will be of size ``storage.nbytes``. + + This works for all datatypes and supports strided/offset tensors with the + caveat that the cast tensor may be larger than the original tensor due to + the differences in striding. + """ + assert type(tensor) is torch.Tensor, ( + f"can only cast standard tensors not {type(tensor)}" + ) + storage = tensor.untyped_storage() + ret = torch.tensor(storage, dtype=dtype, device=tensor.device) + assert ret.untyped_storage() is storage, "storage should be the same" + return ret + + +class PGTransport: + """ + This is a checkpoint transport that uses the process group to transfer checkpoints. + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + + Args: + pg: the process group to use for communication + timeout: the timeout for communication + device: the device to use for tensors + state_dict: if specified this function will be called to do an inplace + receive into the returned state_dict. This is much faster than + having to allocate new tensors and transferring them to the CPU. + """ + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta, + device: torch.device, + state_dict: Optional[Callable[[], object]] = None, + ) -> None: + self._work: list[Work] = [] + self._pg = pg + self._timeout = timeout + self._device = device + self._state_dict = state_dict + + def send_checkpoint(self, dst_ranks: list[int], state_dict: object) -> None: + """ + Send a checkpoint to multiple destination ranks. + + The process: + 1. Prepares the state dict by converting tensors to a serializable format + 2. Sends metadata as pickled data + 3. Sends each tensor sequentially to all destination ranks + + Args: + dst_ranks: List of destination ranks to send the checkpoint to + state_dict: The state dictionary containing model parameters + """ + with _timeit("preparing state_dict"): + meta, tensors = _prepare_state_dict(state_dict, device=self._device) + + work = [] + + with _timeit("send meta"): + buf = pickle.dumps(meta) + len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device) + buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([len_t], dst_rank, tag=1)) + work.append(self._pg.send([buf_t], dst_rank, tag=2)) + + with _timeit("send tensors"): + for i, t in enumerate(tensors): + original_device = t.device + t = t.to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([t], dst_rank, tag=3 + i)) + + # if we did a copy we should wait for the work to complete so we + # can free the memory to avoid OOMs + if original_device == torch.device("cpu"): + for w in work: + w.wait() + work = [] + + for w in work: + w.wait() + + def recv_checkpoint(self, src_rank: int) -> object: + """ + Receive a checkpoint from a source rank. + + The process: + 1. Receives metadata about the checkpoint structure + 2. Receives each tensor, potentially reusing existing tensors for in-place updates + 3. Reconstructs the original state dict structure + + Args: + src_rank: The source rank to receive the checkpoint from + + Returns: + The reconstructed state dictionary with model parameters + """ + + state_dict = self._state_dict() if self._state_dict else {} + state_dict_leaves, _ = tree_flatten_with_path(state_dict) + + dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves) + + len_t = torch.zeros(1, dtype=torch.int64, device=self._device) + self._pg.recv([len_t], src_rank, tag=1).wait() + length = cast(int, len_t.item()) + + buf = torch.empty(length, dtype=torch.uint8, device=self._device) + self._pg.recv([buf], src_rank, tag=2).wait() + + meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes()) + + i: int = 0 + works: list[Work] = [] + + def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: + nonlocal i + + inplace = dst_tensors.get(path) + if ( + isinstance(inplace, torch.Tensor) + and inplace.device.type == self._device.type + ): + if isinstance(inplace, DTensor): + inplace = inplace._local_tensor + t = _cast_tensor(inplace, torch.uint8) + assert t.nbytes == v.nbytes, ( + "inplace tensor storage must be the same size" + ) + else: + t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) + + work = self._pg.recv([t], src_rank, tag=3 + i) + i += 1 + + if inplace is None: + # if not inplace we need to copy it to CPU to avoid OOMing + work.wait() + t = t.cpu() + else: + works.append(work) + + return torch.as_strided( + t.view(v.dtype), + size=v.shape, + stride=v.stride, + storage_offset=v.storage_offset, + ) + + values: list[object] = [] + for path, v in zip(meta.paths, meta.non_tensor_leaves): + if isinstance(v, _TensorMeta): + values.append(recv(path, v)) + elif isinstance(v, _DTensorMeta): + tensor = recv(path, v.local) + values.append(DTensor(tensor, v.spec, requires_grad=False)) + else: + values.append(v) + + for work in works: + work.wait() + + return tree_unflatten(values, meta.treespec) diff --git a/torch/distributed/checkpoint/_state_dict_stager.py b/torch/distributed/checkpoint/_state_dict_stager.py new file mode 100644 index 00000000000000..414b079aa10904 --- /dev/null +++ b/torch/distributed/checkpoint/_state_dict_stager.py @@ -0,0 +1,365 @@ +# mypy: allow-untyped-defs +import logging +import types +import weakref +from copyreg import dispatch_table +from logging import getLogger +from typing import Any + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils +from torch.storage import UntypedStorage +from torch.utils.weak import WeakIdKeyDictionary + + +logger = getLogger() +logger.setLevel(logging.INFO) + + +class StateDictStager: + """ + A class for optimizing storage objects during staging for async checkpointing. + + StateDictStager stages the state_dict to CPU DRAM while applying optimizations + like memory sharing and pinning to improve performance. It caches storage objects + to avoid redundant copies and can be configured to automatically share memory + (for multi-process usage) and pin memory (for faster CPU-GPU transfers). + + Attributes: + pin_memory (bool): Whether to pin CPU memory for faster CPU-GPU transfers + share_memory (bool): Whether to share memory across processes + _cached_storage_mapping (WeakIdKeyDictionary): Maps storage objects to optimized CPU storages using weak references + """ + + def __init__(self, pin_memory: bool = False, share_memory: bool = False): + if pin_memory and not torch.cuda.is_available(): + logger.warning( + "Ignoring pin_memory flag for checkpoint staging as pinning memory" + "requires CUDA, but CUDA is not available. " + ) + self.pin_memory = False + else: + self.pin_memory = pin_memory + self.share_memory = share_memory + # Mapping from original storage objects to CPU storages using weak references + self._cached_storage_mapping = WeakIdKeyDictionary() + + def _deepcopy_atomic(x, _): + return x + + def _deepcopy_list(x, memo): + y: list = [] + memo[id(x)] = y + append = y.append + for a in x: + append(self.deepcopy_with_tensor_offload(a, memo)) + return y + + def _deepcopy_tuple(x, memo): + y = [self.deepcopy_with_tensor_offload(a, memo) for a in x] + # We're not going to put the tuple in the memo, but it's still important we + # check for it, in case the tuple contains recursive mutable structures. + try: + return memo[id(x)] + except KeyError: + pass + + # Check if any elements changed during deepcopy + for k, j in zip(x, y): + if k is not j: + # At least one element changed, create new tuple + return tuple(y) + + # No elements changed, return original tuple + return x + + def _deepcopy_dict(x, memo): + y: dict = {} + memo[id(x)] = y + for key, value in x.items(): + y[self.deepcopy_with_tensor_offload(key, memo)] = ( + self.deepcopy_with_tensor_offload(value, memo) + ) + return y + + def _deepcopy_method(x, memo): # Copy instance methods + return type(x)( + x.__func__, self.deepcopy_with_tensor_offload(x.__self__, memo) + ) + + d: dict[Any, Any] = {} + self._deepcopy_dispatch = d + d[type(None)] = _deepcopy_atomic + d[int] = _deepcopy_atomic + d[float] = _deepcopy_atomic + d[bool] = _deepcopy_atomic + d[complex] = _deepcopy_atomic + d[bytes] = _deepcopy_atomic + d[str] = _deepcopy_atomic + d[types.CodeType] = _deepcopy_atomic + d[type] = _deepcopy_atomic + d[range] = _deepcopy_atomic + d[types.BuiltinFunctionType] = _deepcopy_atomic + d[types.FunctionType] = _deepcopy_atomic + d[weakref.ref] = _deepcopy_atomic + d[property] = _deepcopy_atomic + d[types.MethodType] = _deepcopy_method + d[dict] = _deepcopy_dict + d[tuple] = _deepcopy_tuple + d[list] = _deepcopy_list + + def _stage_untyped_storage( + self, storage: UntypedStorage, non_blocking: bool = False + ): + """ + Called from the hooked storage_deepcopy function in torch.Tensor.__deepcopy__. + + This method handles the storage optimization logic for the StagingStateDict class. + It checks if the storage has already been cached, and if so, reuses it. + Otherwise, it creates a new CPU storage and applies memory optimizations. + + Args: + storage: The storage to optimize + + Returns: + The optimized storage + """ + # Check if we've already cached this storage + if storage in self._cached_storage_mapping: + cached_storage = self._cached_storage_mapping[storage] + assert cached_storage.size() == storage.size(), ( + "For async checkpointing, We cache storages in DRAM and reuse them." + "Cached storage size does not match original storage size." + "This should never happen as we track the original storage weakref " + "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing." + ) + # Reuse cached storage but update with new data + cached_storage.copy_(storage, non_blocking=non_blocking) + return cached_storage + + # Create new CPU storage + if self.share_memory: + new_storage = type(storage)._new_shared(storage.size(), device="cpu") + else: + new_storage = type(storage)(storage.size(), device="cpu") + + if self.pin_memory and new_storage.nbytes() > 0: + pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes()) + # Set up a weak reference to unpin when cpu storage is garbage collected + f = weakref.finalize( + new_storage, pin_memory_utils.unpin_memory, new_storage.data_ptr() + ) + # This makes sure that the finalizer is not called after + # cuda context is destroyed. + f.atexit = False + + new_storage.copy_(storage, non_blocking=non_blocking) + + # Cache the storage - WeakIdKeyDictionary will automatically clean up when storage is garbage collected + self._cached_storage_mapping[storage] = new_storage + return new_storage + + @torch.no_grad() + def stage( + self, + state_dict: dict[str, Any], + non_blocking: bool = False, + ) -> dict[str, Any]: + return self.deepcopy_with_tensor_offload(state_dict, non_blocking=non_blocking) + + def _offload_tensor(self, x, memo, non_blocking=False): + """ + Deep copy a PyTorch tensor with optimized storage handling. + + This method creates a CPU copy of a tensor while applying memory optimizations + like sharing and pinning based on the StateDictStager configuration. + + Args: + x: The tensor to copy + memo: Memo dictionary for tracking already copied objects + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A CPU copy of the tensor with optimized storage + """ + # if data_ptr is not 0, we allocate a new storage below. so we can skip + # memory allocation by using [] for size. + y = x.new_empty([] if x.data_ptr() != 0 else x.size(), device="cpu") + + # Store in memo dict early to handle recursive references + d = id(x) + memo[d] = y + + if type(x) is torch.Tensor or x.data_ptr() != 0: + # Try to get the untyped storage and optimize it + untyped_storage = x.untyped_storage() + copied_storage = self._stage_untyped_storage( + untyped_storage, non_blocking=non_blocking + ) + # Set the tensor data using the optimized storage + y.set_(copied_storage, x.storage_offset(), x.size(), x.stride()) + + # Copy any attributes the tensor might have + if hasattr(x, "__dict__"): + for attr_name, attr_value in x.__dict__.items(): + setattr( + y, + attr_name, + self.deepcopy_with_tensor_offload( + attr_value, memo, non_blocking=non_blocking + ), + ) + + if hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + setattr( + y, + slot, + self.deepcopy_with_tensor_offload( + getattr(x, slot), memo, non_blocking=non_blocking + ), + ) + + return y + + def close(self): + """ + Clean up all cached storages and release associated resources. + + This method clears the internal storage cache, allowing garbage collection + of cached CPU storages. Any pinned memory associated with cached storages + will be automatically unpinned through weak reference finalizers. + """ + self._cached_storage_mapping.clear() + + @torch.no_grad() + def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006 + """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors. + + This implementation extends the standard deepcopy functionality to handle PyTorch tensors + and their storages in a way that optimizes memory usage and performance, similar to the + stage method. It applies memory sharing and pinning optimizations based on the StateDictStager + configuration. + + Args: + x: The object to deep copy + memo: Memo dictionary for tracking already copied objects + _nil: Sentinel value for memo dictionary + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A deep copy of the input object with optimized tensor storage handling + """ + if memo is None: + memo = {} + + d = id(x) + y = memo.get(d, _nil) + if y is not _nil: + return y + + cls = type(x) + + # tensors and subclasses of tensors are handled separately + if isinstance(x, torch.Tensor): + y = self._offload_tensor(x, memo, non_blocking=non_blocking) + + # Use the dispatch table for standard types + copier = self._deepcopy_dispatch.get(cls) + if copier is not None: + y = copier(x, memo) + else: + if issubclass(cls, type): + y = self._deepcopy_dispatch[type](x, memo) + else: + copier = getattr(x, "__deepcopy__", None) + if copier is not None: + y = copier(memo) + else: + reductor = dispatch_table.get(cls) + if reductor: + rv = reductor(x) + else: + reductor = getattr(x, "__reduce_ex__", None) + if reductor is not None: + rv = reductor(4) + else: + reductor = getattr(x, "__reduce__", None) + if reductor: + rv = reductor() + else: + raise RuntimeError( + f"un(deep)copyable object of type {cls}" + ) + if isinstance(rv, str): + y = x + else: + y = self._reconstruct(x, memo, *rv) + + # If is its own copy, don't memoize. + if y is not x: + memo[d] = y + self._keep_alive(x, memo) # Make sure x lives at least as long as d + return y + + def _keep_alive(self, x, memo): + """Keeps a reference to the object x in the memo. + + Because we remember objects by their id, we have + to assure that possibly temporary objects are kept + alive by referencing them. + We store a reference at the id of the memo, which should + normally not be used unless someone tries to deepcopy + the memo itself... + """ + try: + memo[id(memo)].append(x) + except KeyError: + # aha, this is the first one :-) + memo[id(memo)] = [x] + + def _reconstruct( + self, x, memo, func, args, state=None, listiter=None, dictiter=None + ): + deep = memo is not None + if deep and args: + args = (self.deepcopy_with_tensor_offload(arg, memo) for arg in args) + y = func(*args) + if deep: + memo[id(x)] = y + + if state is not None: + if deep: + state = self.deepcopy_with_tensor_offload(state, memo) + if hasattr(y, "__setstate__"): + y.__setstate__(state) + else: + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + else: + slotstate = None + if state is not None: + y.__dict__.update(state) + if slotstate is not None: + for key, value in slotstate.items(): + setattr(y, key, value) + + if listiter is not None: + if deep: + for item in listiter: + item = self.deepcopy_with_tensor_offload(item, memo) + y.append(item) + else: + for item in listiter: + y.append(item) + if dictiter is not None: + if deep: + for key, value in dictiter: + key = self.deepcopy_with_tensor_offload(key, memo) + value = self.deepcopy_with_tensor_offload(value, memo) + y[key] = value + else: + for key, value in dictiter: + y[key] = value + return y diff --git a/torch/distributed/checkpoint/_storage_utils.py b/torch/distributed/checkpoint/_storage_utils.py index 3d8d9a0806ae0f..73acc628342a05 100644 --- a/torch/distributed/checkpoint/_storage_utils.py +++ b/torch/distributed/checkpoint/_storage_utils.py @@ -17,7 +17,7 @@ def _storage_setup( if not checkpoint_id: raise RuntimeError( - "`checkpoint_id` must be specificed if " + "`checkpoint_id` must be specified if " "storage_reader/storage_writer is None." ) diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index dc4978ecad0608..baae0b2bd498ac 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -442,7 +442,7 @@ def set_up_planner( if isinstance(v, TensorStorageMetadata): v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] - if k in metadata.planner_data: + if metadata.planner_data is not None and k in metadata.planner_data: set_element(state_dict, metadata.planner_data[k], v) else: state_dict[k] = v diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index 0f8e392f4e9c36..eb0562ec3dada7 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -4,6 +4,7 @@ import os import shutil import traceback +from concurrent.futures import Future import torch import torch.distributed as dist @@ -106,6 +107,7 @@ def run(rank, world_size): if epoch % SAVE_PERIOD == 0: if f is not None: + assert isinstance(f, Future) f.result() f = dcp.state_dict_saver.async_save( state_dict, checkpoint_id=CHECKPOINT_DIR @@ -122,6 +124,7 @@ def run(rank, world_size): _print("Reloading model from last checkpoint!") if f is not None: + assert isinstance(f, Future) f.result() dcp.load(state_dict) diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 9219ed5f44717a..b3d3b3f915179f 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -2,6 +2,7 @@ import collections import dataclasses import io +import json import operator import os import pickle @@ -29,6 +30,13 @@ ExtensionRegistry, StreamTransformExtension, ) +from torch.distributed.checkpoint._hf_utils import ( + CUSTOM_METADATA_KEY, + DCP_VERSION_KEY, + FORMAT_KEY, + FORMAT_VALUE, + HF_DCP_VERSION, +) from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta from torch.distributed.checkpoint.planner import ( LoadItemType, @@ -374,7 +382,7 @@ def _write_files_from_queue( custom_device_mod = getattr(torch, custom_backend_name, None) # TODO: Using the OverlappingCpuLoader with multiple threads creates significant - # performance degredation, observed as being related to cuda stream syncs. We + # performance degradation, observed as being related to cuda stream syncs. We # should try to fix this and use _OverlappingCpuLoader for all threaded cases if ( thread_count == 1 @@ -416,6 +424,7 @@ def _write_files_from_queue( ) tensor_dict = {} + metadata_dict = {} for tensor, write_item in loader.values(): assert tensor.is_cpu write_results.append( @@ -423,17 +432,29 @@ def _write_files_from_queue( transforms, stream, tensor, - write_item, + write_item, # type: ignore[arg-type] storage_key, serialization_format, ) ) - tensor_dict[write_item.index.fqn] = tensor + tensor_dict[write_item.index.fqn] = tensor # type: ignore[attr-defined] + metadata_dict[write_item.index.fqn] = { # type: ignore[attr-defined] + "saved_offsets": write_item.tensor_data.chunk.offsets # type: ignore[attr-defined] + } if serialization_format == SerializationFormat.SAFETENSORS: from safetensors.torch import save # type: ignore[import-not-found] - stream.write(save(tensor_dict)) + stream.write( + save( + tensor_dict, + metadata={ + CUSTOM_METADATA_KEY: json.dumps(metadata_dict), + DCP_VERSION_KEY: str(HF_DCP_VERSION), + FORMAT_KEY: FORMAT_VALUE, + }, + ) + ) if use_fsync: try: @@ -925,7 +946,7 @@ def __init__( per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation - that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False. + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. _extensions: Extensions to apply to output streams (EXPERIMENTAL) diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py new file mode 100644 index 00000000000000..b0df40b9bb0704 --- /dev/null +++ b/torch/distributed/checkpoint/hf_storage.py @@ -0,0 +1,366 @@ +# mypy: allow-untyped-defs +import dataclasses +import json +import logging +import queue +from typing import Any, Optional + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + consolidate_safetensors_files, +) +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dtype, + _get_safetensors_file_metadata, + _HFStorageInfo, + _metadata_fn, + CUSTOM_METADATA_KEY, + DATA_KEY, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, + SUFFIX, +) +from torch.distributed.checkpoint.filesystem import SerializationFormat +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + StorageMeta, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, +) +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + + +logger: logging.Logger = logging.getLogger(__name__) + +__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] + + +class HuggingFaceStorageWriter(FsspecWriter): + """ + A writer that writes to a huggingface repository in the huggingface format. + Uses Fsspec back-end to communicate with back-end storage. + Fsspec registration of the storage solution is required. + """ + + def __init__( + self, + path: str, + fqn_to_index_mapping: Optional[dict[str, int]] = None, + thread_count: int = 1, + token: Optional[str] = None, + save_distributed: bool = False, + consolidated_output_path: Optional[str] = None, + thread_count_consolidation: int = 1, + ) -> None: + """ + Initialize the huggingface writer pointing to path. + + Args: + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors files, but can be from any fsspec supported storage, + including localFS and hf://. + This needs to be a remote path if you want to enable consolidation after saving. + fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. + Indices are from 1 to N, where N is the number of files. If not provided, + the tensors will be written to a single file. If none, then all the tensors on the + same rank will be written to the same file. + thread_count: Number of threads to use to write distributed checkpoint. Default to 1. + token: The token to use to authenticate with huggingface hub. + save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. + Default is False which assumes rank-0 checkpointing of the full state_dict. + consolidated_output_path: If provided, the output path where the consolidated files will be written in the finish step. + This needs to be a local fs path right now. + thread_count_consolidation: Number of threads to use for parallel processing of saving data + to consolidated output files. Default to 1. + """ + + if token is not None: + super().__init__( + path=path, + token=token, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) + else: + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) + self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping + self.save_distributed: bool = save_distributed + self.consolidated_output_path: Optional[str] = consolidated_output_path + self.thread_count_consolidation = thread_count_consolidation + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [] + for i, plan in enumerate(plans, start=1): + storage_data: dict[str, Any] = {} + if self.fqn_to_index_mapping is not None: + storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping + if self.save_distributed: + storage_data["shard_index"] = i + + new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) + + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + if len(plan.items) == 0: + fut: Future = Future() + fut.set_result([]) + return fut + + # storage_plan is a map from key to file index + storage_data: dict[str, Any] = plan.storage_data + storage_plan: Optional[dict[str, int]] = None + shard_index: Optional[int] = None + if "fqn_to_index_mapping" in storage_data: + storage_plan = storage_data["fqn_to_index_mapping"] + if "shard_index" in storage_data: + shard_index = storage_data["shard_index"] + + buckets = self._split_by_storage_plan(storage_plan, plan.items) + highest_index = max(storage_plan.values()) if storage_plan is not None else 1 + + file_queue: queue.Queue = queue.Queue() + for file_index, write_items in buckets.items(): + file_name = _gen_file_name(file_index, highest_index, shard_index) + file_queue.put( + (self.fs.concat_path(self.path, file_name), file_name, write_items) + ) + + return super()._write_data(planner, file_queue) + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + if self.save_distributed and not self.consolidated_output_path: + # if we are saving distributed, without consolidating, + # then we have no metadata to write because a metadata + # file with fqn to file mapping doesn't make sense + # in this case, because fqns will be in multiple files + logger.info("Not consolidating sharded checkpoint in finish step.") + return + if self.save_distributed: + return consolidate_safetensors_files( + input_dir=str(self.path), + output_dir=self.consolidated_output_path, # type: ignore[arg-type] + num_threads=self.thread_count_consolidation, + fqn_to_index_mapping=self.fqn_to_index_mapping, + ) + + # writing a model.index.safetensors.json file with fqn to file mapping + # for the rank-0 checkpointing case + metadata_to_write = {} + storage_md = {} + total_size = 0 + for wr_list in results: + storage_md.update( + {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list} + ) + total_size += sum([wr.storage_data.length for wr in wr_list]) + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = storage_md + + metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}") + with self.fs.create_stream(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + def _split_by_storage_plan( + self, storage_plan: Optional[dict[str, int]], items: list[WriteItem] + ) -> dict[int, list[WriteItem]]: + # storage_plan is a map from key to index + if storage_plan is None: + return {1: items} + + buckets = {} + for item in items: + key = item.index.fqn + + idx = storage_plan[key] + if idx not in buckets: + buckets[idx] = [item] + else: + buckets[idx].append(item) + + return buckets + + @property + def metadata_path(self) -> str: + return _metadata_fn + + +class HuggingFaceStorageReader(FsspecReader): + """ + A reader that reads from a huggingface repository in the huggingface format. + Uses in Fsspec back-end to communicate with storage. + Fsspec registration of the storage solution is required. + """ + + def __init__(self, path: str, token: Optional[str] = None) -> None: + """ + Initialize the huggingface reader pointing to path. + + Args: + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors file, but can be from any fsspec supported storage, + including localFS and hf://. + token: The token to use to authenticate with huggingface hub. + """ + + if token is not None: + super().__init__(path=path, token=token) + else: + super().__init__(path=path) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import deserialize # type: ignore[import-not-found] + + per_file: dict[str, list[ReadItem]] = {} + + for read_item in plan.items: + item_md: _HFStorageInfo = self.storage_data[read_item.storage_index] + file_name = item_md.relative_path + per_file.setdefault(file_name, []).append(read_item) + + for file_name, reqs in per_file.items(): + with self.fs.create_stream(file_name, "rb") as stream: + # TODO: make this more efficient by doing offset reads instead of a + # full deserialization of the file + deserialized = deserialize(stream.read()) + deserialized_dict: dict[str, dict[str, Any]] = { + tensor_info[0]: tensor_info[1] for tensor_info in deserialized + } + + for req in reqs: + item_md = self.storage_data[req.storage_index] + + tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY] + + tensor = torch.frombuffer( + tensor_bytes, + dtype=item_md.dtype, + ) + tensor = tensor.reshape(item_md.shape) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def read_metadata(self) -> Metadata: + state_dict_metadata: dict[str, TensorStorageMetadata] = {} + storage_data: dict[MetadataIndex, _HFStorageInfo] = {} + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(file) + + for safetensor_file in safetensors_files: + with self.fs.create_stream(safetensor_file, "rb") as f: + safetensors_metadata, _ = _get_safetensors_file_metadata(f) + custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) + + dcp_sharding_info = None + if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): + dcp_sharding_info = json.loads( + custom_metadata.get(CUSTOM_METADATA_KEY) + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + + # construct state_dict_metadata + if dcp_sharding_info is not None: + offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + else: + offset = [0] * len(val[SHAPE_KEY]) + + if key not in state_dict_metadata: + state_dict_metadata[key] = TensorStorageMetadata( + properties=TensorProperties( + dtype=_get_dtype(val[DTYPE_KEY]) + ), + size=torch.Size( + [ + saved + offset + for saved, offset in zip(val[SHAPE_KEY], offset) + ] + ), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=torch.Size(val[SHAPE_KEY]), + ) + ], + ) + else: + state_dict_metadata[key].chunks.append( + ChunkStorageMetadata( + torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]) + ) + ) + size = list(state_dict_metadata[key].size) + for i in range(len(size)): + size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) + state_dict_metadata[key].size = torch.Size(size) + + # construct storage data + if dcp_sharding_info is not None: + metadata_index = MetadataIndex( + fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] + ) + else: + metadata_index = MetadataIndex( + fqn=key, offset=[0] * len(val[SHAPE_KEY]) + ) + storage_data[metadata_index] = _HFStorageInfo( + relative_path=safetensor_file, + offset=val[DATA_OFFSETS_KEY][0], + length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], + shape=torch.Size(val[SHAPE_KEY]), + dtype=_get_dtype(val[DTYPE_KEY]), + ) + + metadata = Metadata( + state_dict_metadata=state_dict_metadata, # type: ignore[arg-type] + storage_data=storage_data, + ) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id # type: ignore[union-attr] + + return metadata diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index 43193afe6e67c4..ed864aa249653a 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -89,7 +89,7 @@ def _is_nested_tensor(val: torch.Tensor) -> bool: if type(val.local_shards()[0].tensor) is ShardedTensor: return True if type(val.local_shards()[0].tensor) is DTensor: - raise ValueError("Cannot handle DTensor nested insided ShardedTensor") + raise ValueError("Cannot handle DTensor nested inside ShardedTensor") elif type(val) is DTensor and ( type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor ): diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py index a911bda05485eb..e6f24b891aa895 100644 --- a/torch/distributed/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - from torch.distributed.checkpoint.metadata import ChunkStorageMetadata @@ -8,7 +6,7 @@ def _check_shard_metadata_pair_overlap( shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata -): +) -> bool: """Check if two shards overlap.""" # For each dim of each shard, check if one shard resides on the other # end of second shard with respect to that dim. As an example for a 2D diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index 9f3233ad06d5f8..a2093f803ee6d0 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -1,11 +1,33 @@ -from typing import Optional, runtime_checkable -from typing_extensions import Protocol +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Any, Optional, Union +from typing_extensions import deprecated, Protocol, runtime_checkable +import torch from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._state_dict_stager import StateDictStager from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE -__all__ = ["AsyncStager", "BlockingAsyncStager"] +__all__ = ["AsyncStager", "BlockingAsyncStager", "DefaultStager", "StagingOptions"] + +""" +Experimental staging module for PyTorch Distributed Checkpointing. +This module provides advanced staging capabilities for checkpoints including: +- Asynchronous staging using ThreadPoolExecutor +- Pinned memory allocation for faster CPU-GPU transfers +- Shared memory support for multi-process scenarios +- Non-blocking CUDA operations with stream synchronization +- Caching of frequently used storages for efficient memory management +- Automatic resource cleanup and memory management +Classes: + AsyncStager: Protocol defining the staging interface + StagingOptions: Configuration dataclass for staging behavior + DefaultStager: Default implementation with comprehensive staging features + BlockingAsyncStager: Implementation of AsyncStager which stages the state_dict + on CPU RAM and blocks until the copy is complete. Please use DefaultStager instead. +""" @runtime_checkable @@ -44,24 +66,199 @@ def should_synchronize_after_execute(self) -> bool: """ Whether to synchronize after executing the stage. """ - return self._synchronize_after_execute - def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: """ Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is - innoculated from any updates incurred after the stage call is complete. + inoculated from any updates incurred after the stage call is complete. """ raise NotImplementedError( f"{self.__class__.__name__} must implement stage method" ) + @deprecated( + "`synchronize_staging` is deprecated and will be removed in future versions." + "Please use staging_future from AsyncSaveResponse instead.", + category=FutureWarning, + ) def synchronize_staging(self) -> None: """ In the case `stage` is async in some way, this method should be called to ensure staging is complete and it is safe to begin modifying the original `state_dict` """ + def close(self) -> None: + """ + Clean up all resources used by the stager. + """ + + +@dataclass +class StagingOptions: + """ + Configuration options for checkpoint staging behavior. + + Attributes: + use_pinned_memory (bool): Enable pinned memory allocation for faster + CPU-GPU transfers. Requires CUDA to be available. Default: True + use_shared_memory (bool): Enable shared memory for multi-process + scenarios. Useful when multiple processes need access to the + same staged data. Default: True + use_async_staging (bool): Enable asynchronous staging using a + background thread pool. Allows overlapping computation with + staging operations. Requires CUDA. Default: True + use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + copies with stream synchronization. Improves performance by + allowing CPU work to continue during GPU transfers. Default: True + + Note: + CUDA-dependent features will raise exception if CUDA is not available. + """ + + use_pinned_memory: bool = True + use_shared_memory: bool = True + use_async_staging: bool = True + use_cuda_non_blocking_copy: bool = True + + +class DefaultStager(AsyncStager): + """ + DefaultStager provides a full-featured staging implementation that combines + multiple optimization techniques for efficient checkpoint preparation. + + The staging process works as follows: + 1. State dictionary is submitted for staging (sync or async) + 2. Tensors are copied from GPU to optimized CPU storage + 3. CUDA operations are synchronized if non-blocking copies are used + 4. Staged state dictionary is returned or made available via Future + + Usage Patterns: + # Synchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=False)) + staged_dict = stager.stage(state_dict) + stager.close() + + # Asynchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + # ... do other work ... + staged_dict = future.result() + stager.close() + + # Context manager pattern (recommended) + stager = DefaultStager(config) + with stager: + result = stager.stage(state_dict) + + Performance Considerations: + - Async staging provides best performance when model computation + can overlap with staging operations + - Pinned memory improves CPU-GPU transfer speeds but uses more memory + - Shared memory allows efficient IPC to checkpoint process + - Non-blocking copies reduce GPU idle time during memory transfers + + Thread Safety: + DefaultStager is not thread-safe. Each thread should use its own + instance, or external synchronization should be provided. + """ + + def __init__( + self, + config: StagingOptions = StagingOptions(), + ): + self._config = config + self._state_dict_stager = StateDictStager( + pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory + ) + self._staging_executor = None + self._staging_stream = None + if self._config.use_async_staging: + self._staging_executor = ThreadPoolExecutor(max_workers=1) + if torch.cuda.is_available(): + # Note: stream needs to be initialized on the main thread after default cuda + # stream is setup/used to avoid the risk of accidentally reusing the main + # compute stream or in other cases kernels actually launching from the + # main thread. + self._staging_stream = torch.cuda.Stream() + + if self._config.use_cuda_non_blocking_copy: + assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + + self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None + + def stage( + self, + state_dict: STATE_DICT_TYPE, + **kwargs: Any, + ) -> Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]]: + """ + This function is responsible for staging staging the state_dict. + See class docstring for more details on staging. + If use_async_staging is True, it will return a Future object that will be + fulfilled when staging is complete. + If use_async_staging is False, it will return the fully staged state_dict. + + Args: + state_dict (STATE_DICT_TYPE): The state_dict to be staged. + """ + if self._config.use_async_staging: + assert self._staging_executor is not None + self._staging_future = self._staging_executor.submit( + self._stage, + state_dict, + **kwargs, + ) + return self._staging_future + else: + return self._stage(state_dict, **kwargs) + + def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: + if self._config.use_cuda_non_blocking_copy: + assert self._staging_stream or not self._config.use_async_staging, ( + "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + ) + with ( + self._staging_stream + if self._staging_stream is not None + else nullcontext() + ): + state_dict = self._state_dict_stager.stage( + state_dict, non_blocking=self._config.use_cuda_non_blocking_copy + ) + # waits for the enqued copy operations to finish. + self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + else: + state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False) + return state_dict + + def close(self) -> None: + """ + Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor + used for async staging operations and cleans up the underlying StateDictStager's + cached storages. Should be called when the stager is no longer needed to prevent + resource leaks, especially in long-running applications. After calling close(), + the stager should not be used for further staging operations. + + Example Usage: + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + result = future.result() + stager.close() # Clean up all resources + """ + if self._staging_executor: + self._staging_executor.shutdown(wait=True) + + def synchronize_staging(self) -> None: + """ + When use_async_staging is True, this method will wait until staging is complete. + If use_async_staging is False, this method is a no-op. + """ + if self._staging_future is not None: + self._staging_future.result() + class BlockingAsyncStager(AsyncStager): """ @@ -87,7 +284,7 @@ def __init__( Args: cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation - that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False. + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. type_check: Whether to perform a type check during cpu_offload. Defaults to False. """ @@ -113,3 +310,6 @@ def synchronize_staging(self) -> None: """ No-op function, since staging is blocking. """ + + def close(self) -> None: + pass diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 1369d8e93f3b91..a430a64fad819d 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -207,7 +207,7 @@ def _get_fqns( if not skip_compiler_prefix: fqn_obj_names.append(curr_obj_name) else: - # In some modeuls, _fqn_modifiers would not shown in the state_dict keys, + # In some modules, _fqn_modifiers would not shown in the state_dict keys, # skip them in the fqn to ensure load stat dict successfully for them. if hasattr(curr_obj, dsd_fqn_modifiers): if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get( @@ -791,7 +791,7 @@ def _get_optim_state_dict( # We need to specially handle FlatParameter FSDP as # FlatParameter FSDP converts the FQNs. # There are no easy ways to do this conversion systematically. - # We can only use a string replacment without correctness check. + # We can only use a string replacement without correctness check. if not osd: continue for k in list(osd[_STATE].keys()): diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index d2741f6f6f130e..41e185574b1941 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -74,7 +74,7 @@ def load( For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), load will first call ``state_dict`` before attempting deserialization, followed by ``load_state_dict`` once the deserialization is complete. - For each non-``Stateful`` object, load will deserailize the object, and then replace + For each non-``Stateful`` object, load will deserialize the object, and then replace it in the ``state_dict`` with the deserialized object. .. warning:: @@ -110,7 +110,7 @@ def load( checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``) planner (Optional[LoadPlanner]): - Instance of LoadPlanner. If this is not specificed, the default + Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: ``None``) process_group (Optional[ProcessGroup]): ProcessGroup to be used for cross-rank synchronization. diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 6caa2d23ef855c..d75ffb7bcf27d6 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -4,13 +4,14 @@ import os import warnings from concurrent.futures import Future +from dataclasses import dataclass from enum import Enum from typing import cast, Optional, Union from typing_extensions import deprecated import torch import torch.distributed as dist -from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed._state_dict_utils import STATE_DICT_TYPE from torch.distributed.checkpoint._async_executor import ( # noqa: TC001 _AsyncCheckpointExecutor, ) @@ -23,9 +24,13 @@ from torch.distributed.checkpoint._storage_utils import _storage_setup from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.logger import _dcp_method_logger -from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.planner import SavePlan, SavePlanner -from torch.distributed.checkpoint.staging import AsyncStager +from torch.distributed.checkpoint.staging import ( + AsyncStager, + DefaultStager, + StagingOptions, +) from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.storage import StorageWriter from torch.distributed.distributed_c10d import _get_default_group @@ -33,7 +38,13 @@ from .utils import _api_bc_check, _DistWrapper, _profile -__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"] +__all__ = [ + "save_state_dict", + "save", + "async_save", + "AsyncCheckpointerType", + "AsyncSaveResponse", +] class AsyncCheckpointerType(Enum): @@ -125,7 +136,7 @@ def save( checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``) planner (Optional[SavePlanner]): - Instance of SavePlanner. If this is not specificed, the default + Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: ``None``) process_group (Optional[ProcessGroup]): ProcessGroup to be used for cross-rank synchronization. @@ -182,6 +193,20 @@ def save( ) +@dataclass +class AsyncSaveResponse: + """This class contains futures for staging and upload completion. + It is returned by async_save(). + staging_completion is a future that indicates when local copy + of state_dict is complete. + upload_completion is a future that indicates when a checkpoint + completed saving. + """ + + staging_completion: Future[None] + upload_completion: Future[None] + + @_dcp_method_logger(log_exceptions=True) def async_save( state_dict: STATE_DICT_TYPE, @@ -191,12 +216,14 @@ def async_save( planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD, -) -> Future: + async_stager: Optional[AsyncStager] = None, +) -> Union[Future, AsyncSaveResponse]: """Asynchronous version of ``save``. This code first de-stages the state_dict on to the staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. .. warning:: This feature is experimental and subject to change. + MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED Args: state_dict (Dict[str, Any]): The state_dict to save. @@ -211,11 +238,17 @@ def async_save( checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``) planner (Optional[SavePlanner]): - Instance of SavePlanner. If this is not specificed, the default + Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: ``None``) process_group (Optional[ProcessGroup]): ProcessGroup to be used for cross-rank synchronization. (Default: ``None``) + async_checkpointer_type (AsyncCheckpointerType): + whether to do checkpoint in separate thread or pocess + (Default: ``AsyncCheckpointerType.THREAD``) + async_stager (AsyncStager): + provides staging implementation. If storage_writer implements AsyncStager + and async_stager is provided, async_stager will be used for staging Returns: Future: A future holding the resultant Metadata object from `save`. @@ -249,6 +282,20 @@ def async_save( "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" ) + if async_stager is None: + if storage_writer is not None and isinstance(storage_writer, AsyncStager): + # bwc with old storage_writers + async_stager = storage_writer + else: + async_stager = DefaultStager( + StagingOptions( + False, + False, + False, + False, + ) + ) + storage_writer = cast( StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) ) @@ -256,42 +303,57 @@ def async_save( state_dict = _stateful_to_state_dict(state_dict) @_dcp_method_logger(log_exceptions=True) - def stage_state_dict(): - if isinstance(storage_writer, AsyncStager): - staged_state_dict = storage_writer.stage(state_dict) - else: # provides bwc for storage_writers not implementing AsyncStager - staged_state_dict = _create_cpu_state_dict(state_dict) - _copy_state_dict(state_dict, staged_state_dict, type_check=False) - - return staged_state_dict + def stage_state_dict() -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + return async_stager.stage(state_dict) - staged_state_dict = stage_state_dict() + staging_future_or_state_dict = stage_state_dict() - executor: _AsyncCheckpointExecutor = ( + upload_executor: _AsyncCheckpointExecutor = ( _ProcessBasedAsyncCheckpointExecutor() if async_checkpointer_type == AsyncCheckpointerType.PROCESS else _ThreadBasedAsyncCheckpointExecutor() ) - f: Future = executor.execute_save( - staged_state_dict, + upload_future: Future = upload_executor.execute_save( + staging_future_or_state_dict, checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, process_group=process_group, ) - @_dcp_method_logger(log_exceptions=True) - def maybe_synchronize_staging(): - if ( - isinstance(storage_writer, AsyncStager) - and storage_writer.should_synchronize_after_execute + if isinstance(staging_future_or_state_dict, Future): + staging_future = staging_future_or_state_dict + return_staging_future: Future[None] = Future() + + def callback( + original_staging_future: Future[STATE_DICT_TYPE], + return_staging_future: Future[None] = return_staging_future, ): - storage_writer.synchronize_staging() + try: + original_staging_future.result() + return_staging_future.set_result(None) + except Exception as e: + return_staging_future.set_exception(e) + + if not staging_future.done(): + staging_future.add_done_callback(callback) + else: + return_staging_future.set_result(None) + + # return new AsyncSaveResponse for users using new ZOC implementation + return AsyncSaveResponse( + staging_completion=return_staging_future, upload_completion=upload_future + ) + else: - maybe_synchronize_staging() + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if async_stager.should_synchronize_after_execute: + async_stager.synchronize_staging() - return f + maybe_synchronize_staging() + return upload_future @_dcp_method_logger(log_exceptions=True) diff --git a/torch/distributed/checkpoint/stateful.py b/torch/distributed/checkpoint/stateful.py index 95cbb1873d6490..15e227d92fb5d2 100644 --- a/torch/distributed/checkpoint/stateful.py +++ b/torch/distributed/checkpoint/stateful.py @@ -1,5 +1,5 @@ -from typing import Any, runtime_checkable, TypeVar -from typing_extensions import Protocol +from typing import Any, TypeVar +from typing_extensions import Protocol, runtime_checkable __all__ = ["Stateful", "StatefulT"] diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py index 9c682bc1aff4fd..8cc8b9f7520dc7 100644 --- a/torch/distributed/checkpoint/storage.py +++ b/torch/distributed/checkpoint/storage.py @@ -147,7 +147,7 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: @abc.abstractmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """ - Check if the given checkpoint_id is supported by the stroage. This allow + Check if the given checkpoint_id is supported by the storage. This allow us to enable automatic storage selection. """ ... @@ -278,7 +278,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: @abc.abstractmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """ - Check if the given checkpoint_id is supported by the stroage. This allow + Check if the given checkpoint_id is supported by the storage. This allow us to enable automatic storage selection. """ ... diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index dd9c27f6542c35..e39bfd25cdd38f 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -119,7 +119,7 @@ def broadcast_object(self, object: Optional[T]) -> T: dist.broadcast_object_list( object_list=object_list, group=self.group, - src=self.coordinator_rank, + src=self.global_coordinator_rank, ) return cast(T, object_list[0]) @@ -436,7 +436,7 @@ def _normalize_device_info(device_type: str, device_id: int) -> str: @contextmanager def _profile(): # Only log the profiling when it is enable and is on rank0 or dist is not - # avaiable. + # available. if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): profiler = cProfile.Profile() profiler.enable() diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index b77e1ba8956e3d..b1a7c824c2e3bf 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -46,7 +46,7 @@ def broadcast( data_or_fn: the data to broadcast or function to execute and broadcast result. success: False to stop all ranks. stage_name: the name of the logical stage for synchronization and debugging - rank: rank to broadcast data or execute function and broadcast resutls. + rank: rank to broadcast data or execute function and broadcast results. pg: the process group for sync Throws: RuntimeError from original exception trace diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index b3754043644b8c..c1e604bc86753e 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -11,7 +11,7 @@ # To make an attempt at backwards compatibility with THD, we use an # extraordinarily high default timeout, given that THD did not have timeouts. default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT -# Separate timeout for PGNCCL mainly becuase it's always been that way in the C++ layer, but until recently +# Separate timeout for PGNCCL mainly because it's always been that way in the C++ layer, but until recently # there was one default that applied across all backends in the python layer. # Later, we could consider merging them back together at the c++ layer if we can align on a same value. # (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 55c6ff83146486..8b63844da0df44 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -102,7 +102,7 @@ def create_sub_mesh( ] mesh_tensor = device_mesh.mesh - # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims. + # slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims. slice_dim_idx = [] slice_dim_group_name = [] # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the @@ -125,12 +125,12 @@ def create_sub_mesh( slice_dim_group_name.append( self.root_to_flatten_mapping[device_mesh][ mesh_dim_name - ]._dim_group_names[0] + ]._dim_group_names[0] # type: ignore[has-type] ) else: slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) slice_dim_group_name.append( - device_mesh._dim_group_names[mesh_dim_indices[0]] + device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type] ) # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. @@ -156,7 +156,7 @@ def create_sub_mesh( if cur_rank in mesh_nd: res_submesh = submesh - res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined] + res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type] self.child_to_root_mapping[res_submesh] = device_mesh return res_submesh @@ -333,9 +333,9 @@ def _get_slice_mesh_dims( slice_mesh_dims.append((next_idx,)) if next_idx <= curr_idx: raise KeyError( - f"Invalid mesh_dim_names {mesh_dim_names} specified. ", - f"Found mesh dim indices to slice: {slice_mesh_dims}. ", - "Mesh dim indices should be in ascending order.", + f"Invalid mesh_dim_names {mesh_dim_names} specified. " + f"Found mesh dim indices to slice: {slice_mesh_dims}. " + "Mesh dim indices should be in ascending order." ) curr_idx = next_idx @@ -362,7 +362,7 @@ def _get_all_submeshes( _init_backend=False, ) submesh._dim_group_names = ( - [device_mesh._dim_group_names[mesh_dim]] + [device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type] if cur_rank in mesh_1d else [] ) @@ -392,7 +392,7 @@ class DeviceMesh: each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects already (i.e. if user call `torch.cuda.set_device` before the DeviceMesh initialization), and will select/set the device for the current process if user does not set the device - beforehands. Note that manual device selection should happen BEFORE the DeviceMesh initialization. + beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. DeviceMesh can also be used as a context manager when using together with DTensor APIs. @@ -603,7 +603,7 @@ def _init_process_groups(self): for dim_mesh in pg_ranks_by_dim: subgroup_ranks = dim_mesh.tolist() - # We temporarily revert the re-use subgroup, since it breaks two internal tests. + # We temporarily revert the reuse subgroup, since it breaks two internal tests. # Temporarily reverting to resolve test timeout while root-causing. # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. if bound_device_id is None or not has_split_group: @@ -621,7 +621,7 @@ def _init_process_groups(self): f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " f"in {subgroup_ranks}!" ) - dim_group_names.append(dim_group.group_name) + dim_group_names.append(dim_group.group_name) # type: ignore[union-attr] self._dim_group_names = dim_group_names def __enter__(self) -> "DeviceMesh": @@ -961,13 +961,13 @@ def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": """ Returns a 1D DeviceMesh by flattening the current DeviceMesh. - If no mesh_dim_name is provided, the default is a string concatentaing the mesh_dim_names of the + If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. - After the flattened dimension is created, to access the flattened dimesnion in mesh_3d, one can use the + After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. """ if not self.mesh_dim_names: diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 4b25bf3ea523af..18ac8344600ef8 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -340,7 +340,7 @@ def register_backend( if devices is not None: for device in devices: - if device != "cpu" and device != "cuda": + if device not in Backend.default_device_backend_map: Backend.default_device_backend_map[device] = name.lower() Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM @@ -547,7 +547,7 @@ class _CollOp: Args: op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``. tensor (Tensor): Tensor to operate on. - dst_tensor (Tensor, optional): Provided when source and destinaton tensors are not the same. + dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same. redop (ReduceOp, optional): reduce operation. root (int, optional): root of broadcast or reduce. """ @@ -1074,17 +1074,18 @@ def _get_global_rank(group, rank) -> int: return get_global_rank(group, rank) -def get_process_group_ranks(group: ProcessGroup) -> list[int]: +def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]: """ Get all ranks associated with ``group``. Args: - group (ProcessGroup): ProcessGroup to get all ranks from. + group (Optional[ProcessGroup]): ProcessGroup to get all ranks from. + If None, the default process group will be used. Returns: List of global ranks ordered by group rank. """ - return list(_world.pg_group_ranks[group].keys()) + return list(_world.pg_group_ranks[group or _get_default_group()].keys()) def _get_group_size(group) -> int: @@ -1548,7 +1549,7 @@ def init_process_group( store: Optional[Store] = None, group_name: str = "", pg_options: Optional[Any] = None, - device_id: Optional[torch.device] = None, + device_id: Optional[Union[torch.device, int]] = None, ) -> None: """ Initialize the default distributed process group. @@ -1609,17 +1610,18 @@ def init_process_group( options we support is ``ProcessGroupNCCL.Options`` for the ``nccl`` backend, ``is_high_priority_stream`` can be specified so that the nccl backend can pick up high priority cuda streams when - there're compute kernels waiting. For other availble options to config nccl, + there're compute kernels waiting. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t - device_id (torch.device, optional): a single, specific device - to "bind" this process to, allowing for backend-specific + device_id (torch.device | int, optional): a single, specific device + this process will work on, allowing for backend-specific optimizations. Currently this has two effects, only under NCCL: the communicator is immediately formed (calling ``ncclCommInit*`` immediately rather than the normal lazy call) and sub-groups will use ``ncclCommSplit`` when possible to avoid unnecessary overhead of group creation. If you want to know NCCL initialization error early, you can also use this - field. + field. If an `int` is provided, the API assumes that the accelerator + type at compile time will be used. .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source on a system that supports MPI. @@ -1664,6 +1666,39 @@ def init_process_group( elif init_method is None: init_method = "env://" + # Get the compile-time accelerator type. + # None indicates no accelerator support. + acc = torch.accelerator.current_accelerator() + + # Auto complete device id + if isinstance(device_id, int): + if acc is None: + raise ValueError( + "device_id is an int, but no accelerator support is found from the current compilation. " + "Please use a different compiled version that supports your accelerator." + ) + device_id = torch.device(acc.type, device_id) + + # Sanity check device_id + if device_id is not None and device_id.type != "cpu": + # Type + if acc is None or device_id.type != acc.type: + raise ValueError( + f"device_id {device_id} does not match the current compilation's accelerator support: {acc}. " + "Please use a different compiled version that supports your accelerator." + ) + # Index + if device_id.index is None: + raise ValueError("Please use a device_id with index.") + # Range + if device_id.index >= torch.accelerator.device_count(): + raise ValueError( + f"device_id {device_id} is out of range. Please use a device index less than " + f"the number of accelerators available: {torch.accelerator.device_count()}." + ) + + logger.info("Using device: %s", device_id) + # If user did not provide a backend string but provided a device id, e.g. # >>> init_process_group(device_id=device) # we try to figure out the backend name based on the device type. @@ -1949,9 +1984,13 @@ def _new_process_group_helper( # TODO: remove this check after lazy initialization is supported # if pg_options is not None: # raise RuntimeError("GLOO options not supported") + if not is_gloo_available(): + raise RuntimeError("Distributed package doesn't have Gloo built in") backend_class = ProcessGroupGloo( backend_prefix_store, group_rank, group_size, timeout=timeout ) + backend_class.options.global_ranks_in_group = global_ranks_in_group + backend_class.options.group_name = group_name backend_type = ProcessGroup.BackendType.GLOO elif backend_str == Backend.NCCL: if not is_nccl_available(): @@ -1994,8 +2033,18 @@ def _new_process_group_helper( elif backend_str == Backend.XCCL: if not is_xccl_available(): raise RuntimeError("Distributed package doesn't have XCCL built in") + if backend_options is not None: + assert isinstance(backend_options, ProcessGroupXCCL.Options), ( + "Expected backend_options argument to be of type ProcessGroupXCCL.Options" + ) + else: + # default backend_options for XCCL + backend_options = ProcessGroupXCCL.Options() + backend_options.is_high_priority_stream = False + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name backend_class = ProcessGroupXCCL( - backend_prefix_store, group_rank, group_size + backend_prefix_store, group_rank, group_size, backend_options ) backend_type = ProcessGroup.BackendType.XCCL else: @@ -2653,7 +2702,7 @@ def _time_estimator( backend = group._get_backend(device) if not backend.supports_time_estimate: raise NotImplementedError( - f"collective time estimator is not supported in the curent version of backend {backend}" + f"collective time estimator is not supported in the current version of backend {backend}" ) backend._start_time_estimate() # type: ignore[attr-defined] cm = _TimeEstimator() @@ -2780,6 +2829,8 @@ def broadcast( opts.rootRank = group_src opts.rootTensor = 0 opts.asyncOp = async_op + if tensor.is_complex(): + tensor = torch.view_as_real(tensor) work = group.broadcast([tensor], opts) if async_op: return work @@ -3019,7 +3070,7 @@ def _object_to_tensor(obj, device, group): _pickler(f).dump(obj) byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. - # Otherwise, it will casue 100X slowdown. + # Otherwise, it will cause 100X slowdown. # See: https://github.com/pytorch/pytorch/issues/65696 byte_tensor = torch.ByteTensor(byte_storage).to(device) if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): @@ -3077,7 +3128,7 @@ def all_gather_object(object_list, obj, group=None): .. note:: For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to + ``torch.cuda.current_device()`` and it is the user's responsibility to ensure that this is set so that each rank has an individual GPU, via ``torch.cuda.set_device()``. @@ -3181,7 +3232,7 @@ def gather_object( .. note:: For NCCL-based processed groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to + ``torch.cuda.current_device()`` and it is the user's responsibility to ensure that this is set so that each rank has an individual GPU, via ``torch.cuda.set_device()``. @@ -4842,7 +4893,7 @@ def monitored_barrier( if timeout is None: timeout = _get_default_timeout(get_backend(group)) elif isinstance(timeout, float): - # TODO(whc) aparently some existing test case for monitored_barrier passes in a timeout in float format? + # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? warnings.warn( "Please specify timeout arg as a timedelta. " f"Converting current value of {timeout} assuming it represents seconds", @@ -4940,16 +4991,16 @@ def split_group( group_desc: Optional[str] = None, ) -> Optional[ProcessGroup]: """ - Create a new process group splitted from the given parent process group. + Create a new process group split from the given parent process group. warning:: This is an experimental API. Only the ``NCCL`` and custom plugin backends are supported. Other backends will raise an error. - Users of this API must gurantee that all ranks in the parent group enter this API call, + Users of this API must guarantee that all ranks in the parent group enter this API call, and the split of the sub groups is the same across all ranks in the parent group. Args: parent_pg (ProcessGroup, optional): The parent process group. If None, - the default process group will be used. Users need to gurantee that + the default process group will be used. Users need to guarantee that the parent group is fully initialized (e.g, communicators are initialized) split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks. Users need to make sure the validity of the split ranks such that one @@ -5186,7 +5237,7 @@ def new_group( specifying what additional options need to be passed in during the construction of specific process groups. i.e. for the ``nccl`` backend, ``is_high_priority_stream`` can be specified so that - process group can pick up high priority cuda streams. For other availble options to config nccl, + process group can pick up high priority cuda streams. For other available options to config nccl, See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization (bool, optional): perform a group-local barrier at the end of the process group creation. This is different in that non-member ranks don't need to call into API and don't @@ -5207,7 +5258,7 @@ def new_group( as non-member ranks don't join the group barrier(). N.B. use_local_synchronization=True can lead to deadlocks when each rank creates - multiple overlaping process groups. To avoid that, make sure all ranks follow the + multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order. """ return _new_group_with_tag( @@ -5443,7 +5494,7 @@ def new_subgroups( ) # TODO: Use itertools.batched(get_process_group_ranks(group=group), group_size) instead when Python 3.12 is supported. - ranks = get_process_group_ranks(group=group or _get_default_group()) + ranks = get_process_group_ranks(group=group) ranks_per_subgroup_list = [ ranks[i : i + group_size] for i in range(0, len(ranks), group_size) ] diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 90102af642319e..3d2457c9bb5cc0 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -71,7 +71,8 @@ class WorkerSpec: tee: tees the specified std stream(s) to console + file, selectively tee for a particular local rank by passing a map, takes precedence over ``redirects`` settings. - + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. """ role: str @@ -86,6 +87,7 @@ class WorkerSpec: master_port: Optional[int] = None master_addr: Optional[str] = None local_addr: Optional[str] = None + event_log_handler: str = "null" def __post_init__(self): assert self.local_world_size > 0 @@ -424,7 +426,7 @@ def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: Note that the worker group is a mutable object and hence in a multi-threaded/process environment it may change state. - Implementors are encouraged (but not required) to return + Implementers are encouraged (but not required) to return a defensive read-only copy. """ raise NotImplementedError @@ -460,7 +462,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: def _stop_workers(self, worker_group: WorkerGroup) -> None: r"""Stop all workers in the given worker group. - Implementors must deal with workers in all states defined by + Implementers must deal with workers in all states defined by ``WorkerState``. That is, it must gracefully handle stopping non-existent workers, unhealthy (stuck) workers, etc. """ @@ -498,8 +500,8 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: group_rank = rdzv_info.rank group_world_size = rdzv_info.world_size - # master_addr/master_port could be explicitly overriden - # TODO: BC - specific to static rdzv and can be simplifed further + # master_addr/master_port could be explicitly overridden + # TODO: BC - specific to static rdzv and can be simplified further master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port @@ -529,7 +531,8 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: " role_ranks=%(role_ranks)s\n" " global_ranks=%(global_ranks)s\n" " role_world_sizes=%(role_world_sizes)s\n" - " global_world_sizes=%(global_world_sizes)s\n", + " global_world_sizes=%(global_world_sizes)s\n" + " event_log_handler=%(event_log_handler)s\n", { "role": spec.role, "restart_count": restart_count, @@ -542,6 +545,7 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: "global_ranks": [worker.global_rank for worker in workers], "role_world_sizes": [worker.role_world_size for worker in workers], "global_world_sizes": [worker.world_size for worker in workers], + "event_log_handler": spec.event_log_handler, }, ) @@ -683,7 +687,10 @@ def _initialize_workers(self, worker_group: WorkerGroup) -> None: for local_rank, w_id in worker_ids.items(): worker = worker_group.workers[local_rank] worker.id = w_id - record(self._construct_event("START", EventSource.WORKER, worker)) + record( + self._construct_event("START", EventSource.WORKER, worker), + worker_group.spec.event_log_handler, + ) worker_group.state = WorkerState.HEALTHY @@ -741,7 +748,10 @@ def _record_worker_events(self, result: RunResult) -> None: failure = result.failures.get(worker.global_rank) state: str = self._get_worker_state(worker, result) raw_error = json.dumps(failure.error_file_data) if failure else None - record(self._construct_event(state, EventSource.WORKER, worker, raw_error)) + record( + self._construct_event(state, EventSource.WORKER, worker, raw_error), + self._worker_group.spec.event_log_handler, + ) def _get_worker_state(self, worker: Worker, result: RunResult) -> str: failure = result.failures.get(worker.global_rank) @@ -764,7 +774,8 @@ def record_duration(self, state: str): record( self._construct_event( state=state, source=EventSource.AGENT, duration_ms=duration_ms - ) + ), + self._worker_group.spec.event_log_handler, ) def _construct_event( diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 47c92600fea717..50b0e388187116 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -275,7 +275,7 @@ def _log_watchdog_event( event = events.Event( name=name, source=events.EventSource.AGENT, metadata=metadata ) - events.record(event) + events.record(event, self._worker_group.spec.event_log_handler) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 6d899a95d6a7d8..228ea5107ff894 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -290,7 +290,7 @@ def reify( - `//attempt_//error.json` """ nprocs = len(envs) - global_env = {} # use only to query properies that are not dependent on a rank + global_env = {} # use only to query properties that are not dependent on a rank if nprocs > 0: global_env = envs[0] else: @@ -452,7 +452,7 @@ def __init__( # all local ranks are accounted for nprocs = len(args) - # TODO log_line_prefixes can be exanded too + # TODO log_line_prefixes can be expanded too logs_dest = logs_specs.reify(envs) _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 0766df8e5f3a77..c387a3ec2833ac 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -66,7 +66,7 @@ connectivity, etc), between joining the rendezvous and it being completed, then a re-rendezvous with remaining healthy nodes will happen automatically. -A node can also fail *after* it has completed (or *has been observered* by other +A node can also fail *after* it has completed (or *has been observed* by other nodes to have completed) the rendezvous - this scenario will be handled by the Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a re-rendezvous). diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index be0d6e28536f6e..9d9a192e2c17a7 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -157,9 +157,9 @@ def get_backend(self) -> str: @property def use_agent_store(self) -> bool: """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user - applications and will be available during application lifecyle. + applications and will be available during application lifecycle. - Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. + Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. """ return False diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 2cbb37a1b510f3..7ad0d470a00074 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -805,7 +805,7 @@ def _remove_from_wait_list(self) -> None: def _remove_from_redundancy_list(self) -> None: msg = ( - f"The node '{self._node}' removed itself from the redunant list of round " + f"The node '{self._node}' removed itself from the redundant list of round " f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." ) self._record(message=msg) @@ -880,7 +880,7 @@ def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: return _Action.ERROR_CLOSED if ctx.node in state.redundancy_list: - msg = f"The node {ctx.node} is in redunancy list" + msg = f"The node {ctx.node} is in redundancy list" logger.debug(msg) # don't apply the timeout logic here, since we want to allow the node to rejoin if len(state.participants) == ctx.settings.max_nodes: @@ -890,7 +890,7 @@ def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: return _Action.SYNC else: # transition to waiting state that will respect timeouts. - msg = f"The node {ctx.node} is removed from redunancy list" + msg = f"The node {ctx.node} is removed from redundancy list" logger.debug(msg) return _Action.REMOVE_FROM_REDUNDANCY_LIST diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py index 8e378c6a1be1a4..d95c2b0256fe9a 100644 --- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py +++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. @@ -8,12 +7,20 @@ # LICENSE file in the root directory of this source tree. import math +from collections.abc import Iterator, Sized +from typing import cast, Optional, TypeVar import torch +from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler -class ElasticDistributedSampler(DistributedSampler): +T = TypeVar("T") + +__all__ = ["ElasticDistributedSampler"] + + +class ElasticDistributedSampler(DistributedSampler[T]): """ Sampler that restricts data loading to a subset of the dataset for elastic training. @@ -34,25 +41,39 @@ class ElasticDistributedSampler(DistributedSampler): start_index (optional): Which index of the dataset to start sampling from """ - def __init__(self, dataset, num_replicas=None, rank=None, start_index=0): + def __init__( + self, + dataset: Dataset[T], + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + start_index: int = 0, + ): super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) - if start_index >= len(dataset): + if not isinstance(dataset, Sized): + raise TypeError("Dataset must be an instance of collections.abc.Sized") + + # Cast to Sized for mypy + sized_dataset = cast(Sized, dataset) + + if start_index >= len(sized_dataset): raise ValueError( - f"Start index {start_index} should be less than dataset size {len(dataset)}" + f"Start index {start_index} should be less than dataset size {len(sized_dataset)}" ) self.start_index = start_index + sized_dataset = cast(Sized, self.dataset) self.num_samples = int( - math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type] + math.ceil(float(len(sized_dataset) - self.start_index) / self.num_replicas) ) self.total_size = self.num_samples * self.num_replicas - def __iter__(self): + def __iter__(self) -> Iterator[T]: # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) + sized_dataset = cast(Sized, self.dataset) indices = ( - torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type] + torch.randperm(len(sized_dataset) - self.start_index, generator=g) .add(self.start_index) .tolist() ) @@ -67,5 +88,5 @@ def __iter__(self): return iter(indices) - def __len__(self): + def __len__(self) -> int: return self.num_samples diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py index d87504d255d6f7..8f0370173b76b6 100644 --- a/torch/distributed/elastic/utils/logging.py +++ b/torch/distributed/elastic/utils/logging.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. @@ -16,7 +15,7 @@ from torch.distributed.elastic.utils.log_level import get_log_level -def get_logger(name: Optional[str] = None): +def get_logger(name: Optional[str] = None) -> logging.Logger: """ Util function to set up a simple logger that writes into stderr. The loglevel is fetched from the LOGLEVEL @@ -33,7 +32,7 @@ def get_logger(name: Optional[str] = None): return _setup_logger(name or _derive_module_name(depth=2)) -def _setup_logger(name: Optional[str] = None): +def _setup_logger(name: Optional[str] = None) -> logging.Logger: logger = logging.getLogger(name) logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) return logger diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 0afe82c46d8963..8c7ded1261edb8 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -184,7 +184,7 @@ def barrier( Optionally, passing rank will enable tracing of missing ranks on timeouts. `rank_tracing_decoder` lambda arg can be used to convert rank data - into a more meaninful information at an app level (e.g. hostname). + into a more meaningful information at an app level (e.g. hostname). Note: Since the data is not removed from the store, the barrier can be used once per unique ``key_prefix``. diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index d33ebf3f280406..177c9c49ff1c80 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -3,9 +3,9 @@ from torch.distributed._tools import MemoryTracker -def run_one_model(net: torch.nn.Module, input: torch.Tensor): - net.cuda() - input = input.cuda() +def run_one_model(net: torch.nn.Module, input: torch.Tensor, device: str = "cuda"): + net.to(device) + input = input.to(device) # Create the memory Tracker mem_tracker = MemoryTracker() @@ -31,6 +31,9 @@ def run_one_model(net: torch.nn.Module, input: torch.Tensor): if __name__ == "__main__": import torchvision + dev = "cuda" run_one_model( - torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda") + torchvision.models.resnet34(), + torch.rand(32, 3, 224, 224, device=dev), + device=dev, ) diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 2103da08a976b8..ab6b5975ea941b 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -68,7 +68,7 @@ def _get_sharded_module_tree_with_module_name_to_fqns( ) -> tuple[str, dict[str, list[str]]]: """ It is used for composable fully_shard() code path, it returns - 1. sharded module tree info: each line reprents a submodule name that contats the + 1. sharded module tree info: each line represents a submodule name that contains the submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`, the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index bea8e7d522a2a4..4fe05da4c844cf 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -2294,8 +2294,9 @@ def _writeback_orig_params(self) -> bool: flat_param._params[i] = param if needs_param_writeback: expected_shape = torch.Size([numel_in_shard]) + src = param if self.uses_sharded_strategy else param.view(-1) self._writeback_tensor( - param, flat_param, i, expected_shape, offset_in_shard, True + src, flat_param, i, expected_shape, offset_in_shard, True ) wroteback = True @@ -2327,8 +2328,13 @@ def _writeback_orig_params(self) -> bool: if flat_param_grad is None: flat_param_grad = torch.zeros_like(flat_param) expected_shape = torch.Size([numel_in_shard]) + src = ( + param.grad + if self.uses_sharded_strategy + else param.grad.view(-1) + ) self._writeback_tensor( - param.grad, + src, flat_param_grad, i, expected_shape, @@ -2749,7 +2755,7 @@ def _construct_padding_tensor( # Use `lru_cache(1)` to only log the warning once (assuming the fixed warning -# messasge is passed in) +# message is passed in) @functools.lru_cache(1) def _warn_skip_writeback_check(log: logging.Logger, warning: str): logger.warning(warning) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py index 4e04396f07fe45..38650323f5e997 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -1,8 +1,14 @@ # mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch +import torch.distributed as dist + + +_ReduceOp = Union[dist.ReduceOp, dist.ReduceOp.RedOpType] @dataclass(frozen=True) @@ -47,6 +53,80 @@ class MixedPrecisionPolicy: cast_forward_inputs: bool = True +class Comm(ABC): + """ + Interface for communication primitives. + A primitive primarily needs to handle 3 tasks, namely: + + 1. How to allocate memory for communication + Depending on the goal, an implementation can choose to: + a. associate each call to a temporary buffer + (best for flexibility and simplicity) + b. reuse an persistent buffer for efficiency reasons + + 2. Where to allocate memory + (e.g. NCCL mem pool or regular cuda caching allocator) + + 3. What to do/call upon the comm is called + (see `AllGather` interface as an example) + """ + + @abstractmethod + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """ + This handles the "how to allocate memory" part. + + A default implementation could be simply: + + .. code-block:: python + with self.mem_pool: + torch.empty(...) + + Args: + size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer + dtype (torch.dtype): dtype of the tensor buffer + device (torch.device): which device to allocate the tensor onto + """ + ... + + +class AllGather(Comm): + """ + Interface for all_gather comm primitive + """ + + @abstractmethod + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: ... + + +class ReduceScatter(Comm): + """ + Interface for reduce_scatter comm primitive + """ + + @abstractmethod + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> Optional[dist.Work]: ... + + @dataclass class OffloadPolicy: """ diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 7ad991cebdc626..6e85c0987d906c 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -1,12 +1,16 @@ +import math +from collections.abc import Sequence from itertools import chain -from typing import Callable, cast, NamedTuple, Optional, Union +from typing import Any, Callable, cast, NamedTuple, Optional, Union import torch import torch.distributed as dist from torch.distributed.device_mesh import _get_device_handle from torch.distributed.distributed_c10d import ReduceOp +from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather, ReduceScatter from torch.distributed.tensor import DTensor +from ._fsdp_api import _ReduceOp from ._fsdp_common import ( _get_dim0_padded_size, _raise_assert_with_print, @@ -35,30 +39,127 @@ class AllGatherResult(NamedTuple): """ all_gather_copy_in( Tensor[] all_gather_inputs, + Tensor all_gather_output, SymInt[] inp_split_sizes, SymInt all_gather_input_numel, - SymInt world_size, - SymInt rank, - ScalarType dtype, - Device device + SymInt rank ) -> (Tensor, Tensor) """ ) +class DefaultAllocMixin: + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + return torch.empty(*size, dtype=dtype, device=device) + + +class ProcessGroupAllocMixin: + def __init__(self, group: dist.ProcessGroup, *args: Any, **kwargs: Any): + self._group = group + super().__init__(*args, **kwargs) + + def allocate( + self, + size: Sequence[Union[int, torch.SymInt]], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + backend = self._group._get_backend(device) + if backend.supports_tensor_alloc(device): + size_1d = math.prod(int(s) for s in size) + return backend.allocate_tensor(size_1d, dtype=dtype, device=device) + return torch.empty(*size, dtype=dtype, device=device) + + +class DefaultAllGather(DefaultAllocMixin, AllGather): + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: + return dist.all_gather_into_tensor( + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + +class ProcessGroupAllocAllGather(ProcessGroupAllocMixin, AllGather): + def __init__(self, group: dist.ProcessGroup) -> None: + super().__init__(group) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + async_op: bool = False, + ) -> Optional[dist.Work]: + return dist.all_gather_into_tensor( + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + +class DefaultReduceScatter(DefaultAllocMixin, ReduceScatter): + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> dist.Work: + return dist.reduce_scatter_tensor( + output=output_tensor, + input=input_tensor, + group=group, + op=op, + async_op=async_op, + ) + + +class ProcessGroupAllocReduceScatter(ProcessGroupAllocMixin, ReduceScatter): + def __init__(self, group: dist.ProcessGroup) -> None: + super().__init__(group) + + def __call__( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + op: _ReduceOp, + async_op: bool = False, + ) -> dist.Work: + return dist.reduce_scatter_tensor( + output=output_tensor, + input=input_tensor, + group=group, + op=op, + async_op=async_op, + ) + + @torch.library.impl(lib, "all_gather_copy_in", "Meta") def all_gather_copy_in_meta( all_gather_inputs: list[torch.Tensor], + all_gather_output: torch.Tensor, inp_split_sizes: list[int], all_gather_input_numel: int, - world_size: int, rank: int, - dtype: torch.dtype, - device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - all_gather_output = torch.empty( - (all_gather_input_numel * world_size,), dtype=dtype, device="meta" - ) all_gather_input = all_gather_output.narrow( 0, all_gather_input_numel * rank, all_gather_input_numel ) @@ -73,16 +174,11 @@ def all_gather_copy_in_meta( @torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1") def all_gather_copy_in_cuda( all_gather_inputs: list[torch.Tensor], + all_gather_output: torch.Tensor, inp_split_sizes: list[int], all_gather_input_numel: int, - world_size: int, rank: int, - dtype: torch.dtype, - device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - all_gather_output = torch.empty( - (all_gather_input_numel * world_size,), dtype=dtype, device=device - ) all_gather_input = all_gather_output.narrow( 0, all_gather_input_numel * rank, all_gather_input_numel ) @@ -144,6 +240,7 @@ def foreach_all_gather( all_gather_copy_in_stream: torch.Stream, all_gather_stream: torch.Stream, device: torch.device, + all_gather_comm: AllGather, ) -> Optional[AllGatherResult]: world_size, rank = group.size(), group.rank() device_handle = _get_device_handle(device.type) @@ -162,19 +259,20 @@ def foreach_all_gather( all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] inp_split_sizes = [t.numel() for t in all_gather_inputs] all_gather_input_numel = sum(inp_split_sizes) + all_gather_output = all_gather_comm.allocate( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( all_gather_inputs, + all_gather_output, inp_split_sizes, all_gather_input_numel, - world_size, rank, - dtype, - device, ) del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with device_handle.stream(all_gather_stream): - all_gather_work = dist.all_gather_into_tensor( + all_gather_work = all_gather_comm( output_tensor=all_gather_output, input_tensor=all_gather_input, group=group, @@ -351,15 +449,17 @@ def foreach_reduce( unsharded_grads: list[torch.Tensor], reduce_scatter_group: dist.ProcessGroup, reduce_scatter_stream: torch.Stream, - orig_dtype: torch.dtype, + reduce_scatter_comm: ReduceScatter, + orig_dtype: Optional[torch.dtype], reduce_dtype: Optional[torch.dtype], device: torch.device, - reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]], + gradient_divide_factor: Optional[float], all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP all_reduce_stream: torch.Stream, all_reduce_grads: bool, partial_reduce_output: Optional[torch.Tensor], # only used for HSDP all_reduce_hook: Optional[Callable[[torch.Tensor], None]], + force_sum_reduction_for_comms: bool = False, ) -> tuple[ torch.Tensor, torch.Event, @@ -381,8 +481,15 @@ def foreach_reduce( ) grad_dtype = unsharded_grads[0].dtype reduce_dtype = reduce_dtype or grad_dtype - predivide_factor, postdivide_factor = _get_gradient_divide_factors( - reduce_scatter_group, all_reduce_group, reduce_dtype, device.type + (predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = ( + _get_gradient_divide_factors( + reduce_scatter_group, + all_reduce_group, + reduce_dtype, + device.type, + gradient_divide_factor, + force_sum_reduction_for_comms, + ) ) world_size = reduce_scatter_group.size() for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): @@ -398,8 +505,10 @@ def foreach_reduce( ) reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) reduce_scatter_output_numel = reduce_scatter_input_numel // world_size - reduce_scatter_input = torch.empty( - (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device + reduce_scatter_input = reduce_scatter_comm.allocate( + (reduce_scatter_input_numel,), + dtype=reduce_dtype, + device=device, ) device_handle = _get_device_handle(device.type) foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size) @@ -410,18 +519,17 @@ def foreach_reduce( all_reduce_input = None all_reduce_event = None with device_handle.stream(reduce_scatter_stream): - reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,)) + reduce_output = reduce_scatter_comm.allocate( + (reduce_scatter_output_numel,), + dtype=reduce_dtype, + device=device, + ) _div_if_needed(reduce_scatter_input, predivide_factor) - if reduce_scatter_reduce_op is None: - if predivide_factor is None: - reduce_scatter_reduce_op = ReduceOp.AVG - else: - reduce_scatter_reduce_op = ReduceOp.SUM - dist.reduce_scatter_tensor( - output=reduce_output, - input=reduce_scatter_input, + reduce_scatter_comm( + output_tensor=reduce_output, + input_tensor=reduce_scatter_input, group=reduce_scatter_group, - op=reduce_scatter_reduce_op, + op=reduce_scatter_op, ) reduce_scatter_event = reduce_scatter_stream.record_event() post_reduce_stream = reduce_scatter_stream @@ -448,7 +556,7 @@ def foreach_reduce( dist.all_reduce( reduce_output, group=all_reduce_group, - op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM, + op=all_reduce_op, ) all_reduce_input = reduce_output all_reduce_event = all_reduce_stream.record_event() @@ -566,25 +674,55 @@ def _get_gradient_divide_factors( all_reduce_group: Optional[dist.ProcessGroup], reduce_dtype: torch.dtype, device_type: str = "", -) -> Union[tuple[None, None], tuple[float, float]]: + factor: Optional[float] = None, + force_sum_reduction_for_comms: bool = False, +) -> tuple[ + Optional[float], + Optional[float], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], +]: + # MTIA appears to only support SUM reduction, hence we force it implicitly + if device_type == "mtia": + force_sum_reduction_for_comms = True + # For fp32/bf16, we do not need to worry about overflow/underflow, so we # use NCCL's built-in division to avoid separate div kernels - if reduce_dtype in (torch.float32, torch.bfloat16) and device_type != "mtia": - return None, None + overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16) + data_parallel_size = reduce_scatter_group.size() if all_reduce_group is not None: data_parallel_size *= all_reduce_group.size() - # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid - # overflow/underflow. For N data parallel workers, each worker computes - # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid - # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. - factor: int = 1 - while data_parallel_size % factor == 0 and data_parallel_size / factor > factor: - factor *= 2 - factor = float(factor) - return (factor, data_parallel_size / factor) + + if factor is None: + factor = float(data_parallel_size) + + if not overflow_risk and not force_sum_reduction_for_comms: + if factor == data_parallel_size: + # Warning: NCCL ReduceOp.AVG may produce incorrect results with + # world size 1. + return None, None, ReduceOp.AVG, ReduceOp.AVG + else: + reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) + return None, None, reduce_scatter_op, ReduceOp.SUM + + pre_factor: Optional[float] + if overflow_risk: + # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid + # overflow/underflow. For N data parallel workers, each worker computes + # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid + # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. + pre_factor = 1 + while factor % pre_factor == 0 and factor / pre_factor > pre_factor: + pre_factor *= 2 + post_factor = factor / pre_factor + else: + # Prefer post-multiplying as it operates on less data and is thus faster + pre_factor, post_factor = None, factor + + return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None: - if div_factor is not None and div_factor > 1: + if div_factor is not None and div_factor != 1: tensor.div_(div_factor) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 362409168a509a..855a706e6d304f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -42,7 +42,7 @@ - Unsharded parameter: parameter used for forward/backward computation, derived from the all-gather output; autograd leaf -We define these tensors to describe the general framework that can accomodate +We define these tensors to describe the general framework that can accommodate extensions, where: - all-gather-inputs = pre-all-gather-transform(sharded-parameter) - unsharded-parameter = post-all-gather-transform(all-gather-outputs) @@ -376,9 +376,7 @@ def _init_sharded_param( if self.offload_to_cpu and not padded_sharded_param.is_meta: padded_sharded_param = padded_sharded_param.cpu() if self.pin_memory: - padded_sharded_param = padded_sharded_param.pin_memory( - device=self.device - ) + padded_sharded_param = padded_sharded_param.pin_memory() self._sharded_param_data = padded_sharded_param.view(-1) length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 sharded_param = padded_sharded_param.narrow( @@ -848,7 +846,7 @@ def reset_sharded_param(self): local_tensor = padded_local_tensor updated_local_tensor = True if self.pin_memory and not local_tensor.is_pinned(): - local_tensor = local_tensor.cpu().pin_memory(device=self.device) + local_tensor = local_tensor.cpu().pin_memory() updated_local_tensor = True self._sharded_param_data = local_tensor.view(-1) assert isinstance(self.sharded_param, DTensor) # mypy diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index c9c36654e8821d..121f3d4c138857 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -15,10 +15,16 @@ from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( + AllGather, AllGatherResult, + DefaultAllGather, + DefaultReduceScatter, foreach_all_gather, foreach_all_gather_copy_out, foreach_reduce, + ProcessGroupAllocAllGather, + ProcessGroupAllocReduceScatter, + ReduceScatter, ) from ._fsdp_common import ( compiled_autograd_enabled, @@ -111,7 +117,7 @@ class AllReduceState(NamedTuple): class FSDPParamGroup: """This class represents a parameter group to communicate together.""" - _orig_dtype: torch.dtype + _orig_dtype: Optional[torch.dtype] _reduce_dtype: Optional[torch.dtype] def __init__( @@ -159,6 +165,8 @@ def __init__( self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {} self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None + self._all_gather_comm: AllGather = DefaultAllGather() + self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter() # Optional stream to run the user-defined all-reduce hook in # Saved here and not in the comm. context because we allow the user to # specify it, possibly at construction time before lazy init @@ -177,9 +185,12 @@ def __init__( # Whether to reshard parameters after backward (only useful for # gradient accumulation) self.reshard_after_backward: bool = True - # Optional custom reduce-scatter reduce op (e.g. to divide by a - # factor other than the shard world size) - self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None + # Optional custom factor for the gradient reduction op (e.g. to divide + # by a factor other than the world size) + self.gradient_divide_factor: Optional[float] = None + # Whether reduce-scatter and all-reduce should be issued using only + # summations, potentially with separate pre-/post-scaling. + self.force_sum_reduction_for_comms: bool = False # `async_op` arg used for pre-forward/pre-backward unshard; can be # overridden to only do explicit prefetching and avoid inter-stream # fragmentation from using separate unshard streams @@ -212,22 +223,27 @@ def __init__( def _init_mp_dtypes(self) -> None: for fsdp_param in self.fsdp_params: fsdp_param.init_dtype_attrs(self.mp_policy) - orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params} - if len(orig_dtypes) != 1: - # This can be relaxed if we copy-out for the reduce-scatter + trainable_params: list[FSDPParam] = [ + p for p in self.fsdp_params if p.sharded_param.requires_grad + ] + orig_dtypes = {p.orig_dtype for p in trainable_params} + reduce_dtypes = {p.reduce_dtype for p in trainable_params} + if len(trainable_params) > 0 and len(orig_dtypes) != 1: + # Models may have no grad params raise AssertionError( f"FSDP expects uniform original parameter dtype but got {orig_dtypes}" ) - self._orig_dtype = next(iter(orig_dtypes)) - reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params} - if len(reduce_dtypes) != 1: + self._orig_dtype = next(iter(orig_dtypes)) if len(trainable_params) else None + if len(trainable_params) > 0 and len(reduce_dtypes) != 1: # This can be relaxed if we issue one reduce-scatter per reduce # dtype (but we would need a way for users to specify multiple # reduce dtypes) raise AssertionError( f"FSDP expects uniform reduce dtype but got {reduce_dtypes}" ) - self._reduce_dtype = next(iter(reduce_dtypes)) + self._reduce_dtype = ( + next(iter(reduce_dtypes)) if len(trainable_params) else None + ) def lazy_init(self): # Lazy init should be idempotent @@ -248,6 +264,36 @@ def lazy_init(self): self._init_mp_dtypes() self._register_state_dict_hooks() + def set_allocate_memory_from_process_group(self, enable: bool) -> None: + """ + Whether to (try to) use the ProcessGroup's allocate_tensor method for + the staging buffers for collective comms. + """ + assert isinstance( + self._all_gather_comm, (DefaultAllGather, ProcessGroupAllocAllGather) + ), ( + "cannot call set_allocate_memory_from_process_group() " + f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}" + ) + self._all_gather_comm = ( + ProcessGroupAllocAllGather(self._all_gather_process_group) + if enable + else DefaultAllGather() + ) + + assert isinstance( + self._reduce_scatter_comm, + (DefaultReduceScatter, ProcessGroupAllocReduceScatter), + ), ( + "cannot call set_allocate_memory_from_process_group() " + f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}" + ) + self._reduce_scatter_comm = ( + ProcessGroupAllocReduceScatter(self._reduce_scatter_process_group) + if enable + else DefaultReduceScatter() + ) + # Runtime # def unshard(self, async_op: bool = False): if self._all_gather_result is not None: # already called, pending wait @@ -271,11 +317,12 @@ def unshard(self, async_op: bool = False): async_op, *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), self.device, + self._all_gather_comm, ) def wait_for_unshard(self): """ - 1. In forward with implict prefetching, to overlap the current copy-out + 1. In forward with implicit prefetching, to overlap the current copy-out with the next all-gather, we save a reference to the current all-gather result to free after the next copy-out. 2. Otherwise (explicit prefetching or in backward), we free the @@ -447,15 +494,17 @@ def post_backward(self, *unused: Any): unsharded_grads, self._reduce_scatter_process_group, self.comm_ctx.reduce_scatter_stream, + self._reduce_scatter_comm, self._orig_dtype, self._reduce_dtype, self.device, - self.reduce_scatter_reduce_op, + self.gradient_divide_factor, self._all_reduce_process_group if self._is_hsdp else None, all_reduce_stream, self.all_reduce_grads, self._partial_reduce_output, self._all_reduce_hook, + self.force_sum_reduction_for_comms, ) self.comm_ctx.reduce_scatter_state = ReduceScatterState( reduce_scatter_input, reduce_scatter_event diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index 765a4a6908e7e6..237f59673828ae 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -16,8 +16,8 @@ _State, ) from torch.distributed.device_mesh import _get_device_handle -from torch.distributed.utils import _to_kwargs -from torch.utils._pytree import tree_flatten, tree_map +from torch.distributed.utils import _apply_to_tensors, _to_kwargs +from torch.utils._pytree import tree_flatten from ._fsdp_api import MixedPrecisionPolicy from ._fsdp_common import ( @@ -81,6 +81,9 @@ def __init__(self) -> None: self._states_to_forward_prefetch: list[FSDPState] = [] self._states_to_backward_prefetch: list[FSDPState] = [] self._modules_to_run_forward: set[nn.Module] = set() + # ``False`` when user set reshard_after_forward + # through ``fully_shard`` or ``set_reshard_after_forward`` + self._auto_reshard_after_forward: Optional[bool] = True # Define a separate init since `__init__` is called in the contract def init( @@ -88,6 +91,7 @@ def init( modules: tuple[nn.Module, ...], device: torch.device, mp_policy: MixedPrecisionPolicy, + auto_reshard_after_forward: bool, ) -> None: for module in modules: _insert_module_state(module, self) @@ -95,6 +99,7 @@ def init( self._device = device self._device_handle = _get_device_handle(device.type) self._mp_policy = mp_policy + self._auto_reshard_after_forward = auto_reshard_after_forward if len(modules) == 1: self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( self._pre_forward, prepend=True, with_kwargs=True @@ -175,7 +180,7 @@ def _lazy_init(self) -> None: state._is_root = False self._state_ctx.all_states.append(state) visited_states.add(state) - if self._fsdp_param_group: + if self._fsdp_param_group and self._auto_reshard_after_forward: # For the root, do not reshard after forward since for training, # the parameters would be freed and all-gathered immediately self._fsdp_param_group.post_forward_mesh_info = None @@ -235,7 +240,10 @@ def _pre_forward( cast_fn = functools.partial( _cast_fp_tensor, self._mp_policy.param_dtype ) - args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) + args, kwargs = ( + _apply_to_tensors(cast_fn, args), + _apply_to_tensors(cast_fn, kwargs), + ) if self._fsdp_param_group: args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) for fsdp_state in self._states_to_forward_prefetch: @@ -265,7 +273,7 @@ def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: self._state_ctx.iter_forward_root = None if self._mp_policy.output_dtype is not None: with torch.profiler.record_function("FSDP::cast_forward_outputs"): - output = tree_map( + output = _apply_to_tensors( functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype), output, ) diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index a3300aacd463ce..eb348a00f5f98c 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -14,13 +14,14 @@ TYPE_CHECKING, Union, ) +from typing_extensions import deprecated import torch import torch.nn as nn from torch.distributed._composable import contract from torch.distributed.utils import _get_root_modules -from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo from ._fsdp_init import ( _get_device_from_mesh, @@ -86,7 +87,7 @@ def fully_shard( module, *, mesh: Optional[DeviceMesh] = None, - reshard_after_forward: Union[bool, int] = True, + reshard_after_forward: Optional[Union[bool, int]] = None, shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), @@ -139,25 +140,26 @@ def fully_shard( placement. The mesh's device type gives the device type used for communication; if a CUDA or CUDA-like device type, then we use the current device. - reshard_after_forward (Union[bool, int]): This controls the parameter + reshard_after_forward (Optional[Union[bool, int]]): This controls the parameter behavior after forward and can trade off memory and communication: - If ``True``, then this reshards parameters after forward and re-all-gathers in backward. - If ``False``, then this keeps the unsharded parameters in memory - after forward and avoids the all-gather in backward. + after forward and avoids the all-gather in backward. For best performance, + we usually set ``False`` for the root module, because the root module + is typically required immediately when the backward pass begins. + - If ``None``, it is set to ``True`` for non-root modules and ``False`` + for root modules. - If an ``int``, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the ``mesh`` shard dim size (i.e. excluding 1 and the dim size itself). A choice may be the intra-node size (e.g. ``torch.cuda.device_count()``). This allows the all-gather in backward to be over a smaller world size at the cost of higher memory usage than setting to ``True``. - - The root FSDP state has its value specially set to ``False`` as a - heuristic since its parameters would typically be immediately - all-gathered for backward. - After forward, the parameters registered to the module depend on to this: The registered parameters are the sharded parameters if - ``True``; unsharded parameters if ``False``; and the paramters + ``True``; unsharded parameters if ``False``; and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For ``False`` or an ``int``, this can be @@ -183,6 +185,7 @@ def fully_shard( Returns: FSDPModule: The module with FSDP applied (in-place). """ + torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") if isinstance(module, (nn.ModuleList, nn.ModuleDict)): raise ValueError( f"fully_shard does not support containers that do not implement forward: {module}" @@ -199,8 +202,12 @@ def fully_shard( ) mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) device = _get_device_from_mesh(mesh) + auto_reshard_after_forward = reshard_after_forward is None + # If the user does not provide ``reshard_after_forward``, we set it to True. + # During lazy_init, we identify which module is the root and override its value to False post_forward_mesh_info = _get_post_forward_mesh_info( - reshard_after_forward, mesh_info + reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] + mesh_info, ) arg_module = module @@ -208,7 +215,7 @@ def fully_shard( (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) ) state = fully_shard.state(modules[0]) # type: ignore[attr-defined] # see [1] - state.init(modules, device, mp_policy) + state.init(modules, device, mp_policy, auto_reshard_after_forward) managed_modules = _get_managed_modules(modules, ignored_params) params, buffers = _get_managed_states(managed_modules, ignored_params) @@ -368,11 +375,16 @@ def set_reshard_after_forward( recurse (bool): Whether to set for all FSDP submodules or just the passed-in module. """ + if not isinstance(reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}" + ) self_module = cast(nn.Module, self) modules = list(self_module.modules()) if recurse else [self_module] for module in modules: if isinstance(module, FSDPModule): state = module._get_fsdp_state() + state._auto_reshard_after_forward = False if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.post_forward_mesh_info = ( _get_post_forward_mesh_info( @@ -443,6 +455,32 @@ def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None: module._get_fsdp_state() for module in modules ] + def set_custom_all_gather(self, comm: AllGather) -> None: + """ + Overrides the default ``all_gather`` communication behavior, + to have better control over the communication and memory usage. + See `Comm` and `ReduceScatter` for details. + + Args: + comm (AllGather): Custom all-gather communication. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._all_gather_comm = comm + + def set_custom_reduce_scatter(self, comm: ReduceScatter) -> None: + """ + Overrides the default ``reduce_scatter`` communication behavior, + to have better control over the communication and memory usage. + See `Comm` and `ReduceScatter` for details. + + Args: + comm (ReduceScatter): Custom reduce_scatter communication. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._reduce_scatter_comm = comm + def set_all_reduce_hook( self, hook: Callable[[torch.Tensor], None], @@ -487,10 +525,15 @@ def set_post_optim_event(self, event: torch.Event) -> None: """ self._get_fsdp_state()._state_ctx.post_optim_event = event + @deprecated("Use `set_gradient_divide_factor` instead") def set_reduce_scatter_divide_factor(self, factor: float) -> None: + """Use :py:meth:`set_gradient_divide_factor` instead""" + self.set_gradient_divide_factor(factor) + + def set_gradient_divide_factor(self, factor: float) -> None: """ - Sets a custom divide factor for the reduce-scatter. This becomes a - custom reduce op using NCCL's PreMulSum, which allows multiplying by + Sets a custom divide factor for the gradient reduction. This might use + a custom reduce op using NCCL's PreMulSum, which allows multiplying by the factor before reduction. Args: @@ -498,9 +541,28 @@ def set_reduce_scatter_divide_factor(self, factor: float) -> None: """ state = self._get_fsdp_state() if (fsdp_param_group := state._fsdp_param_group) is not None: - mul_factor = 1.0 / float(factor) - reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor) - fsdp_param_group.reduce_scatter_reduce_op = reduce_op + fsdp_param_group.gradient_divide_factor = factor + + def set_force_sum_reduction_for_comms(self, enable: bool) -> None: + """ + Sets whether to require the low-level collective communication + primitives to exclusively use "sum"-type reductions, even if it comes + at the cost of separate additional pre- or post-scaling operations. + This is needed for example because NCCL currently supports zero-copy + transfers only for this kind of collectives. + + NB: for MTIA devices, this is always implicitly enabled. + + NB: if `set_all_reduce_hook` is used under FSDP setup, the caller needs + to ensure the custom all-reduce across FSDP units follow this strategy + as well, as FSDP can no longer automatically handle that. + + Args: + enable (bool): Whether to only ever use ReduceOp.SUM for comms. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.force_sum_reduction_for_comms = enable def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: """ @@ -513,6 +575,27 @@ def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: if (fsdp_param_group := state._fsdp_param_group) is not None: fsdp_param_group.unshard_in_backward = unshard_in_backward + def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None: + """ + Sets whether the temporary staging buffers used to send and receive data + over collective communications should be allocated using the custom + optimized allocator provided by the ProcessGroup itself (if any). This + might allow the ProcessGroup to be more efficient. For example, when + using NCCL, this enables it to leverage zero-copy transfers over SHARP + (for NVLink and/or InfiniBand). + + This cannot be used together with :meth:`set_custom_all_gather` or + :meth:`set_custom_reduce_scatter` as those APIs allow for + finer-grained control over each communication, and this method cannot + determine their staging buffer allocation strategy. + + Args: + enable (bool): Whether to turn on ProcessGroup allocation. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.set_allocate_memory_from_process_group(enable) + def _set_unshard_async_op(self, async_op: bool): """ Sets whether to use ``async_op=True`` or ``False`` for the pre-forward diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index feaf8b8829630d..b145b3e059a696 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -361,7 +361,7 @@ def _init_device_handle( See the :ref:`Accelerators` for details. - This method will be called once ignored paramters was determined, as the device handle maybe needed + This method will be called once ignored parameters was determined, as the device handle maybe needed for other initialization. """ determined_device = None @@ -517,7 +517,7 @@ def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPS if device_mesh and root_mesh != state._device_mesh: state._fsdp_extension = DTensorExtensions(state._device_handle) else: - # We need to explicilty set _fsdp_extension to None. + # We need to explicitly set _fsdp_extension to None. # Otherwise, we will run into an infinite recursion when getting the attribute. state._fsdp_extension = None return state diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 64d56a3391672b..671995671c75b3 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1417,7 +1417,7 @@ def _unflatten_orig_param_states( ) -> None: """ Given a output state dict, ``output_states``, which the keys are FQNs to the - original parameters (not FlatParameters nor parmeter ID), and the values + original parameters (not FlatParameters nor parameter ID), and the values are gathered states, unflatten the states to the original dimensions. This function performs the unflattening process in-place. @@ -1656,7 +1656,7 @@ def _gather_all_orig_param_state( ) -> dict[str, Any]: """ Given a optimizer state dict, ``input_states``, which the keys are FQNs to the - original parameters (not FlatParameters nor parmeter ID), gather all the + original parameters (not FlatParameters nor parameter ID), gather all the states and unflatten them to the original dimensions. Note that all the params referred by the ``input_states`` must be managed by FSDP. """ @@ -2057,7 +2057,7 @@ def _set_optim_use_dtensor( fsdp_state: _FSDPState, state_dict_settings: StateDictSettings, ) -> None: - # If device_mesh is passed in when initalizing FSDP, we automatically turn the + # If device_mesh is passed in when initializing FSDP, we automatically turn the # _use_dtensor flag to be true for ShardedOptimStateDictConfig() if state_dict_type # has to be set to SHARDED_STATE_DICT. if getattr(fsdp_state, "_device_mesh", None): diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index f723e8e0464edd..f4dd3d2b35bd11 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -518,7 +518,7 @@ def _root_pre_forward( _p_assert(state._is_root is not None, "Expects a root FSDP to have been set") if not state._is_root: # Always cast forward inputs in the root of this local FSDP unit for mixed - # precision, as this is where mixed precision could be configed. + # precision, as this is where mixed precision could be configured. # This is more useful for auto wrapping that is recommended in composable path. # For manual wrapping, cast forward inputs on each local FSDP unit root will # increase some overhead, so not turned on for model wrapper path right now where diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 399da7e2f437ea..d59b5b4492c0bb 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -784,7 +784,7 @@ def _pre_state_dict_hook( @no_type_check def _set_use_dtensor(fsdp_state: _FSDPState) -> None: - # If device_mesh is passed in when initalizing FSDP, we automatically turn the + # If device_mesh is passed in when initializing FSDP, we automatically turn the # _use_dtensor flag to be true for ShardedStateDictConfig(). if getattr(fsdp_state, "_device_mesh", None): state_dict_type = fsdp_state._state_dict_type diff --git a/torch/distributed/fsdp/_traversal_utils.py b/torch/distributed/fsdp/_traversal_utils.py index 5ca758c83a9729..51140d3b0a8d3d 100644 --- a/torch/distributed/fsdp/_traversal_utils.py +++ b/torch/distributed/fsdp/_traversal_utils.py @@ -1,7 +1,7 @@ """ NOTE: This file must be imported like ``import torch.distributed.fsdp._traversal_utils`` and not like -``from torch.distirbuted.fsdp._traversal_utils import ...`` to avoid circular +``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular imports. For brevity, we may import the file as ``traversal_utils``. """ diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 0eafd26e31f961..491b26e0814174 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -117,12 +117,10 @@ class OptimStateKeyType(Enum): class FullyShardedDataParallel(nn.Module, _FSDPState): """A wrapper for sharding module parameters across data parallel workers. - This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_. + This is inspired by `Xu et al. `_ as + well as the ZeRO Stage 3 from `DeepSpeed `_. FullyShardedDataParallel is commonly shortened to FSDP. - .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 - .. _DeepSpeed: https://www.deepspeed.ai/ - To understand FSDP internals, refer to the :ref:`fsdp_notes`. @@ -388,7 +386,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): ``ignored_modules`` soon. For backward compatibility, we keep both ``ignored_states`` and `ignored_modules``, but FSDP only allows one of them to be specified as not ``None``. - device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an altenative to + device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP will use the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, these two args need to be mutually exclusive. For hybrid sharding strategies such as diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index b1611130c9e03a..4a8d41c9358a11 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -320,8 +320,8 @@ def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = ( - "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ - torch.FloatTensor with requires_grad=False." + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " + "torch.FloatTensor with requires_grad=False." ) assert new_scale.device.type == self._device, reason assert new_scale.numel() == 1, reason diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index d8e2017e7e1511..474184ca15d9ae 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -64,6 +64,9 @@ class LaunchConfig: local_addr: address of the local node if any. If not set, a lookup on the local machine's FQDN will be performed. local_ranks_filter: ranks for which to show logs in console. If not set, show from all. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + .. note:: `rdzv_timeout` is a legacy argument that will be removed in future. @@ -87,6 +90,7 @@ class LaunchConfig: log_line_prefix_template: Optional[str] = None metrics_cfg: dict[str, str] = field(default_factory=dict) local_addr: Optional[str] = None + event_log_handler: str = "null" def __post_init__(self): default_timeout = 900 @@ -194,18 +198,19 @@ def launch_agent( logger.info( "Starting elastic_operator with launch configs:\n" - " entrypoint : %(entrypoint)s\n" - " min_nodes : %(min_nodes)s\n" - " max_nodes : %(max_nodes)s\n" - " nproc_per_node : %(nproc_per_node)s\n" - " run_id : %(run_id)s\n" - " rdzv_backend : %(rdzv_backend)s\n" - " rdzv_endpoint : %(rdzv_endpoint)s\n" - " rdzv_configs : %(rdzv_configs)s\n" - " max_restarts : %(max_restarts)s\n" - " monitor_interval : %(monitor_interval)s\n" - " log_dir : %(log_dir)s\n" - " metrics_cfg : %(metrics_cfg)s\n", + " entrypoint : %(entrypoint)s\n" + " min_nodes : %(min_nodes)s\n" + " max_nodes : %(max_nodes)s\n" + " nproc_per_node : %(nproc_per_node)s\n" + " run_id : %(run_id)s\n" + " rdzv_backend : %(rdzv_backend)s\n" + " rdzv_endpoint : %(rdzv_endpoint)s\n" + " rdzv_configs : %(rdzv_configs)s\n" + " max_restarts : %(max_restarts)s\n" + " monitor_interval : %(monitor_interval)s\n" + " log_dir : %(log_dir)s\n" + " metrics_cfg : %(metrics_cfg)s\n" + " event_log_handler : %(event_log_handler)s\n", { "entrypoint": entrypoint_name, "min_nodes": config.min_nodes, @@ -219,6 +224,7 @@ def launch_agent( "monitor_interval": config.monitor_interval, "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] "metrics_cfg": config.metrics_cfg, + "event_log_handler": config.event_log_handler, }, ) @@ -245,6 +251,7 @@ def launch_agent( master_addr=master_addr, master_port=master_port, local_addr=config.local_addr, + event_log_handler=config.event_log_handler, ) agent = LocalElasticAgent( @@ -260,7 +267,7 @@ def launch_agent( result = agent.run() # records that agent.run() has succeeded NOT that workers have succeeded - events.record(agent.get_event_succeeded()) + events.record(agent.get_event_succeeded(), config.event_log_handler) if result.is_failed(): # ChildFailedError is treated specially by @record @@ -280,10 +287,10 @@ def launch_agent( # since this closes the rendezvous on this rdzv_id permanently and # prevents any additional scaling events shutdown_rdzv = False - events.record(agent.get_event_failed()) + events.record(agent.get_event_failed(), config.event_log_handler) raise except Exception: - events.record(agent.get_event_failed()) + events.record(agent.get_event_failed(), config.event_log_handler) raise finally: if shutdown_rdzv: diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 7fc3fd6736ea55..00d96739e517c2 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -323,5 +323,5 @@ def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: def _gen_param_group_key(param_keys: list[str]) -> str: - """Concatenate all param keys as a unique indentifier for one param group.""" + """Concatenate all param keys as a unique identifier for one param group.""" return "/".join(sorted(param_keys)) diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index cb7fb8a26a262a..b1664cd588bbea 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import logging from collections import defaultdict diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index 3c0027d1124073..44d59cab44e4f2 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -9,7 +9,7 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer): r""" Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD `_, This optimizer runs local optimizer at every step. - After the warm-up stage, it averages parameters periodically afer the local optimizer is applied. + After the warm-up stage, it averages parameters periodically after the local optimizer is applied. Args: optim: The local optimizer. @@ -61,7 +61,7 @@ def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverag self.averager = averager @property - def state(self): + def state(self): # type: ignore[override] return self.optim.state def __repr__(self): diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index e8414fd1374bc4..18e4ed1ea6e324 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -284,7 +284,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): r""" Wrap an arbitrary :class:`optim.Optimizer ` and shards its states across ranks in the group. - The sharing is done as described by ZeRO_. + The sharing is done as described by `ZeRO `_. The local optimizer instance in each rank is only responsible for updating approximately ``1 / world_size`` parameters and @@ -365,9 +365,6 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): is to prepend dummy inputs. .. warning:: ZeroRedundancyOptimizer is experimental and subject to change. - - .. _ZeRO: https://arxiv.org/abs/1910.02054 - """ def __init__( diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 4e1b9676d7caf3..f21e9cde8d3754 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1209,7 +1209,7 @@ def pipeline( Arguments --------- module: - The module to be splitted. + The module to be split. mb_args: Example positional inputs, in micro-batch form. mb_kwargs: diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 58ea3b88d054d4..1bfcc53830e8e4 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -244,7 +244,7 @@ def stage_backward_weight( # Break a reference cycle caused inside stage_backward_input->get_hook->hook # The summarized cycle is: # `hook` -> cell -> param_group -> intermediates -> `hook` - # becuase we install the hook function onto each of the intermediate autograd nodes. + # because we install the hook function onto each of the intermediate autograd nodes. # We need to keep intermediates alive up until backward_weight, but we can free it now. del param_group["intermediates"] diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index ccda7177e88902..b39a806fa776f8 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -56,14 +56,14 @@ def get_schedule_ops( num_stages_per_rank = 1 assert num_stages_per_rank == 1 stages = mock_pipeline_stage - stages.num_stages = num_stages_per_rank + stages.num_stages = num_stages_per_rank * pp_degree elif issubclass(schedule_class, PipelineScheduleMulti): if num_stages_per_rank is None: num_stages_per_rank = 2 assert num_stages_per_rank >= 2 stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)] for stage in stages: - stage.num_stages = num_stages_per_rank + stage.num_stages = num_stages_per_rank * pp_degree else: raise ValueError(f"Invalid schedule: {schedule_class}") diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 28d5daf8d23630..61f87fb7fd6a63 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging +import operator from typing import Any, Optional import torch @@ -46,7 +47,7 @@ class _LossReducer(_CustomReducer): pass -sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b) +sum_reducer = _LossReducer(torch.tensor(0.0), operator.add) # Default chunking dimension is 0. This is used for the case where the user did # not specify a chunking dimension. diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index c3b3165777441b..51394c4f0b63dc 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -485,6 +485,10 @@ def __init__( or equal to the number of stages ({self._num_stages})." ) + self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = ( + self._get_pipeline_order() + ) + def _initialize_stage(self, args, kwargs): self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) if self._has_backward: @@ -524,6 +528,24 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): else: return None + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline execution order as a schedule IR. + + The returned IR is a dictionary mapping rank IDs to lists of actions. + Each action is either an _Action object representing computation to perform, + or None representing a deliberate idle step. + + The None values are used to represent pipeline bubbles where a rank + must wait for dependencies from other ranks before proceeding. However + during execution, with the _PipelineScheduleRuntime, these Nones are + skipped since the relevant communication (send/recv) will be scheduled and waited on. + + Returns: + A dictionary mapping rank -> list of actions + """ + return None + class _ScheduleForwardOnly(PipelineScheduleSingle): """ @@ -666,6 +688,38 @@ def _step_microbatches( for work in bwd_sends_to_wait: _wait_batch_p2p(work) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline order for GPipe schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[Optional[_Action]] = [] + + # 1. Initial delay based on rank position + warmup_delay = rank + actions.extend([None] * warmup_delay) + + # 2. Forward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx)) + + # 3. Wait period before backward passes can begin + backward_delay = 3 * (pp_group_size - 1 - rank) + actions.extend([None] * backward_delay) + + # 4. Backward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx)) + + pipeline_order[rank] = actions + + return pipeline_order + class Schedule1F1B(PipelineScheduleSingle): """ @@ -728,7 +782,7 @@ def _step_microbatches( # Safe to fire send_work = _batch_p2p(fwd_sends, desc="fwd_send") # otherwise: - # The last foward send is left for fuse with first 1B in 1B1F below + # The last forward send is left for fuse with first 1B in 1B1F below # Compute loss self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) @@ -813,6 +867,74 @@ def _step_microbatches( # Return losses if there is a container passed in self._update_losses(self._stage, losses) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline order for 1F1B schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[Optional[_Action]] = [] + + # 1. Warmup phase: initial delay based on rank + actions.extend([None] * rank) + + # 2. Initial forward passes before 1F1B phase + num_forward = (pp_group_size - 1) - rank + forward_mb = 0 + for i in range(num_forward): + actions.append(_Action(rank, _ComputationType.FORWARD, i)) + forward_mb = i + + # 3. Wait for backward to be ready + wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank)) + actions.extend([None] * wait_for_1f1b) + + # 4. 1F1B steady state phase + backward_mb = 0 + remaining_forward = self._n_microbatches - num_forward + + while remaining_forward > 0: + # One forward + forward_mb += 1 + actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb)) + remaining_forward -= 1 + + # One backward + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + + # 5. Cooldown phase: remaining backward passes + remaining_backward = self._n_microbatches - backward_mb + + while remaining_backward > 0: + # Add None and backward actions in alternating pattern + # based on distance from the last stage + if (pp_group_size - rank) > 0: + actions.append(None) + # Decrement the wait counter only if we still have backward passes to do + if remaining_backward > 0: + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + else: + # If we're at the last stage, just add backward actions without None + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + + pipeline_order[rank] = actions + return pipeline_order + def _add_unshard_reshard( compute_actions: list[Optional[_Action]], @@ -821,7 +943,7 @@ def _add_unshard_reshard( """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. - RESHARD does the opposite, releasing memory (but doing no commmunication) + RESHARD does the opposite, releasing memory (but doing no communication) We abandon the "timestep lock" during lowering @@ -1482,7 +1604,7 @@ def _load_actions( raise NotImplementedError(f"{format=} is not implemented") def _load_csv(self, filename: str, format: str = "compute_only"): - """Loads a csv in simple format and then lowers it to include comunication actions + """Loads a csv in simple format and then lowers it to include communication actions format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes will automatically be run to generate a compute_comms schedule. @@ -1552,7 +1674,7 @@ def _step_microbatches( bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} - # send ops should be waited on before step() exists, mainly for hygeine + # send ops should be waited on before step() exists, mainly for hygiene send_ops: list[list[dist.Work]] = [] # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages @@ -1768,7 +1890,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti): """ Breadth-First Pipeline Parallelism. See https://arxiv.org/abs/2211.05953 for details. - Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. What is different is that when microbatches are ready for multiple local stages, Loops BFS will prioritizes the earlier stage, running all available microbatches at once. diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 491324e3f24091..df229c98320906 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -55,7 +55,7 @@ def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: # output in list format output = tuple(output) - # Unify output form to tuple for easy correspondance with + # Unify output form to tuple for easy correspondence with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) return output_tuple @@ -267,7 +267,7 @@ def _create_grad_send_info( def map_recv_to_send(a): # Note: we send gradients back to previous stage as long as in # forward it is a received input, regardless of whether it requires - # grad. It is up to the previous stage to disgard this gradient. + # grad. It is up to the previous stage to discard this gradient. if isinstance(a, _RecvInfo): grad_send_info.append(a.source) return a.source @@ -919,7 +919,7 @@ def _validate_fwd_input(self, args, kwargs): def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. - Most likely, this could be cause either by incorrect user specification of output shapes, or becuase + Most likely, this could be cause either by incorrect user specification of output shapes, or because shape inference was done on the original model but then at runtime the model is wrapped with something like mixed precision which changes output dtype. """ @@ -1010,7 +1010,7 @@ def _prepare_forward_infra( """ # TODO(whc) # this method should be deleted once lazy buffer allocation is implemented - # for now, it ignores args/kwargs becuase it should not need to do shape inference + # for now, it ignores args/kwargs because it should not need to do shape inference for chunk in range(num_microbatches): self.args_recv_info[chunk] = self._create_act_recv_info() @@ -1272,7 +1272,7 @@ def __init__( super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) self.inputs: Optional[list[torch.Tensor]] = None self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None - # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) becuase it + # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it # might be breaking for existing users. if input_args is None: assert output_args is None, ( diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 4722f76413c028..a7b8c358d9abce 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -162,7 +162,7 @@ def _create_c10d_store( hostname, port, rank, world_size, timeout, use_libuv=True ) -> Store: """ - Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. + Smartly creates a c10d Store object on ``rank`` based on whether we need to reuse agent store. The TCPStore server is assumed to be hosted on ``hostname:port``. diff --git a/torch/distributed/run.py b/torch/distributed/run.py index b1c073dc861f3c..d3f3d84d201bd8 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -486,6 +486,14 @@ def get_args_parser() -> ArgumentParser: choices=["spawn", "fork", "forkserver"], help="Multiprocessing start method to use when creating workers.", ) + parser.add_argument( + "--event-log-handler", + "--event_log_handler", + action=env, + type=str, + default="null", + help="name of a registered event logging handler (see: https://docs.pytorch.org/docs/stable/elastic/events.html)", + ) parser.add_argument( "--role", action=env, @@ -523,7 +531,7 @@ def get_args_parser() -> ArgumentParser: type=str, default=None, help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same " - "directory is re-used for multiple runs (a unique job-level sub-directory is created with " + "directory is reused for multiple runs (a unique job-level sub-directory is created with " "rdzv_id as the prefix).", ) parser.add_argument( @@ -712,7 +720,7 @@ def get_use_env(args) -> bool: def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: """ - Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. + Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. Provides plugin mechanism to provide custom implementation of LogsSpecs. Returns `DefaultLogsSpecs` when logs_spec_name is None. @@ -817,6 +825,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str log_line_prefix_template=log_line_prefix_template, local_addr=args.local_addr, logs_specs=logs_specs, + event_log_handler=args.event_log_handler, ) with_python = not args.no_python diff --git a/torch/distributed/tensor/README.md b/torch/distributed/tensor/README.md index fc7eb0135bcb00..17185caec95bd6 100644 --- a/torch/distributed/tensor/README.md +++ b/torch/distributed/tensor/README.md @@ -49,7 +49,7 @@ We offer both a lower level DistributedTensor API and a module level API to crea Here are some basic DTensor API examples that showcase: 1. How to construct a DTensor directly, to represent different types of sharding, replication, sharding + replication strategies. 2. How to create DTensor from a local `torch.Tensor`. -3. How to “reshard” an existing DTensor to a different DTensor with modified placement strategy or world size. +3. How to “reshard” an existing DTensor to a different DTensor with a new DTensor Layout. ```python # torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index e2e0104ecf2581..bb46549e400966 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -346,7 +346,7 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): @torch._disable_dynamo # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] return DTensor._op_dispatcher.dispatch( func, args, @@ -486,7 +486,7 @@ def redistribute( ) -> "DTensor": """ ``redistribute`` performs necessary collective operations that redistribute the current - DTensor from its current placements to a new placements, or from is current DeviceMesh + DTensor from its current placements to a new placements, or from its current DeviceMesh to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh. diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 839021e0f6e5c7..83270b5a64bb78 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -13,13 +13,7 @@ import torch.distributed.tensor._random as random from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, - OpInfo, - OpSchema, - OutputSpecType, -) +from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType from torch.distributed.tensor._random import is_rng_supported_mesh from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor._sharding_prop import ShardingPropagator @@ -260,13 +254,13 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: # perform reduce on the collection with AND op local_results = functools.reduce(operator.and_, obj_list, True) - if _is_inplace_op(op_call): + if op_info.schema.is_inplace_op(): # inplace op should return self instead of re-wrapping if output_sharding.output_spec is not None: return args[0] else: return None - elif _is_out_variant_op(op_call): + elif op_info.schema.is_out_variant_op(): # out variant could possibly have multiple out args (i.e. lu_unpack.out) output_specs = ( (output_sharding.output_spec,) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 27672206a8d868..d103e8ab250bef 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -44,20 +44,6 @@ def _rebuild_tensor_from_dtensor_meta(arg) -> object: ) -def _is_inplace_op(op: OpOverload): - # simple analysis of function schema to determine - # if this is an inplace variant, it might not - # be entirely correct, but it's good enough for now. - return op._schema.name[-1] == "_" - - -def _is_out_variant_op(op: OpOverload): - # simple analysis of function schema to determine - # if this is an out variant, it might not - # be entirely correct, but it's good enough for now. - return "out" in op._schema.overload_name - - def _pretty_print_spec(spec: object) -> str: if spec is None: return "None" @@ -70,10 +56,10 @@ def _pretty_print_spec(spec: object) -> str: @dataclass -class PlacementStrategy: +class OpSpec: """ - A placement strategy describes acceptable sharding placements of the output - and the tensor arguments of an operation. + An OpSpec describes an acceptable sharding placements of an operation, with the + specified DTensorSpecs for both the output and the inputs. note: when the op return value is a single DTensor object, output_specs is DTensorSpec; when the return value is a tuple of Optional[DTensor], @@ -83,10 +69,9 @@ class PlacementStrategy: output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] input_specs: Optional[Sequence[DTensorSpec]] = None - # redistribute costs for this op placement strategy - # we need a nested list to record the cost for each - # operand of this operator, and for each operand of - # this operator it might have multiple placement strategies + # redistribute costs to redistribute the operator input shardings to this OpSpec. + # Note that We need a nested list to record the cost for each operand of this + # operator, and for each operand of this operator it might have multiple OpSpecs. redistribute_cost: Optional[list[list[float]]] = None @cached_property @@ -116,7 +101,7 @@ def mesh(self): ) def input_spec(self, index: int = 0) -> DTensorSpec: - assert self.input_specs is not None, "input_specs of PlacementStrategy is None!" + assert self.input_specs is not None, "input_specs of OpSpec is None!" assert len(self.input_specs) > index, ( f"Invalid index {index} for input_specs of length " f"{len(self.input_specs)}: {self.input_specs}" @@ -141,12 +126,13 @@ class StrategyType: class OpStrategy(StrategyType): """ - OpStrategy that consists of a list of placement strategies associated with the op + OpStrategy that consists of a list of sharding strategies associated with the op, + where each strategy is an OpSpec that describes the acceptable input/output sharding. """ - def __init__(self, strategies: list[PlacementStrategy]) -> None: + def __init__(self, strategies: list[OpSpec]) -> None: super().__init__() - self.strategies: list[PlacementStrategy] = strategies + self.strategies: list[OpSpec] = strategies def __str__(self) -> str: strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) @@ -155,7 +141,7 @@ def __str__(self) -> str: def max_num_shards(self) -> int: """ - Returns the max number of shards across all placement strategies + Returns the max number of shards across all OpSpecs """ return max(strategy.output_spec.num_shards for strategy in self.strategies) @@ -178,14 +164,14 @@ def shape(self): class TupleStrategy(StrategyType): """ - TupleStrategy represents the output strategy of this op is a tuple - of strategy, i.e. If the output of this op is a tuple of tensors or list of tensors - with possibly different placement strategies, we should return a TupleStrategy that - contains a tuple of OpStrategy, where each child represents the sharding strategy - of "each element" of the tuple/list of tensors the op returns. - - NOTE: if the output of the op is a List[Tensor] and they share the same placement - strategy, then we should return a single OpStrategy instead of a TupleStrategy + TupleStrategy represents the output strategy of this op is a tuple of OpStrategies, + i.e. If the output of this op is a tuple of tensors or list of tensors with possibly + different OpStrategies, we should return a TupleStrategy that contains a tuple of + OpStrategy, where each child represents the sharding strategy of "each element" of + the tuple/list of tensors the op returns. + + NOTE: if the output of the op is a List[Tensor] and they share the same OpStrategy, + then we should return a single OpStrategy instead of a TupleStrategy """ def __init__(self, childs: Sequence[StrategyType]) -> None: @@ -229,8 +215,8 @@ class RuntimeSchemaInfo: class OpSchema: """ OpSchema is a data class that describes an operator input schemas, it includes - DTensorSpecs (instead of DTensor) and non-tensor args/kwargs (positional order - preserved). It is mainly used by the DTensor's dispatching logic to perform various + DTensorSpecs/OpStrategies (instead of DTensor) and non-tensor args/kwargs (positional + order preserved). It is mainly used by the DTensor's dispatching logic to perform various actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) NOTE: this should be used as a read only data class @@ -296,9 +282,9 @@ def __str__(self) -> str: args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) mesh_shape = arg.mesh_shape elif isinstance(arg, TupleStrategy): - first_op_strtgy = arg.childs[0] - assert isinstance(first_op_strtgy, OpStrategy) - mesh_shape = first_op_strtgy.mesh_shape + first_op_strategy = arg.childs[0] + assert isinstance(first_op_strategy, OpStrategy) + mesh_shape = first_op_strategy.mesh_shape args_schema.append(str(arg)) else: args_schema.append(str(arg)) @@ -376,6 +362,18 @@ def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh: return mesh + def is_inplace_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an inplace variant, it might not + # be entirely correct, but it's good enough for now. + return self.op._schema.name[-1] == "_" + + def is_out_variant_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an out variant, it might not + # be entirely correct, but it's good enough for now. + return "out" in self.op._schema.overload_name + def __hash__(self) -> int: # Only hash args and kwargs that op indicates to hash if not self.schema_info: diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 7a2f500fca7742..d70cc130dfc293 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -4,12 +4,7 @@ import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, - OpSchema, - OutputSharding, -) +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding from torch.distributed.tensor._ops.utils import prod from torch.distributed.tensor._utils import compute_local_shape_and_global_offset @@ -271,12 +266,12 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}" enforce_sharding: dict[str, int] = {} - if _is_inplace_op(op_schema.op): - # inplace op should keep the input sharding it writes to - enforce_sharding.update(zip(out_dimchars, input_specs[0].dim_map)) - elif _is_out_variant_op(op_schema.op): - out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) - enforce_sharding.update(zip(out_dimchars, out_spec.dim_map)) + if op_schema.is_inplace_op(): + follow_spec = op_schema.args_spec[0] + enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map)) + elif op_schema.is_out_variant_op(): + follow_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) + enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map)) return einop_rule( fmt, diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 5953721d219c86..b666cae0e22e38 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -3,7 +3,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy from torch.distributed.tensor.placement_types import ( Partial, Placement, @@ -167,7 +167,7 @@ def gen_einsum_strategies( all_strategies = [] for strategy_comb in strategy_combs: spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)] - strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:]) + strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:]) all_strategies.append(strat) return OpStrategy(all_strategies) diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 41c03c3cb84163..b1f3b249e4a07f 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -11,9 +11,9 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OpStrategy, PlacementList, - PlacementStrategy, RuntimeSchemaInfo, TupleStrategy, ) @@ -267,20 +267,20 @@ def common_reduction_strategy( # by default follow reduction input strategy reduction_strategy = OpStrategy([]) - for strtg in input_strategy.strategies: + for op_spec in input_strategy.strategies: if not reduction_linear: # input placements for this strategy should clear out pending sum and sharding # on the reduction dimension input_placements = replicate_reduction_dims( - strtg.output_spec.placements, reduce_dims + op_spec.output_spec.placements, reduce_dims ) else: - input_placements = strtg.output_spec.placements + input_placements = op_spec.output_spec.placements input_spec = DTensorSpec( mesh=input_strategy.mesh, placements=input_placements, - tensor_meta=strtg.output_spec.tensor_meta, + tensor_meta=op_spec.output_spec.tensor_meta, ) reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) @@ -289,7 +289,7 @@ def common_reduction_strategy( ) redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] reduction_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=DTensorSpec( mesh=input_strategy.mesh, placements=out_placements, @@ -465,7 +465,7 @@ def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy: assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" mesh = input_strategy.mesh - output_strategies: list[PlacementStrategy] = [] + output_strategies: list[OpSpec] = [] for placement_strategy in input_strategy.strategies: replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) replicate_spec = DTensorSpec( @@ -476,7 +476,7 @@ def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy: redistribute_cost = [ generate_redistribute_costs(input_strategy, replicate_spec) ] - replicate_strategy = PlacementStrategy( + replicate_strategy = OpSpec( output_specs=replicate_spec, input_specs=(replicate_spec,), redistribute_cost=redistribute_cost, @@ -514,7 +514,7 @@ def softmax_strategy(op_schema: OpSchema) -> OpStrategy: ) output_target_spec = input_target_spec output_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=output_target_spec, input_specs=[input_target_spec], redistribute_cost=redistribute_costs, @@ -559,7 +559,7 @@ def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy: redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) grad_in_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=tgt_spec, redistribute_cost=[redist_grad_out_cost, redist_out_cost], ) @@ -682,7 +682,7 @@ def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy: ) output_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=(output_expected_spec, total_weight_expected_spec), input_specs=op_args_target_specs, redistribute_cost=redistribute_costs, @@ -797,7 +797,7 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: grad_in_expected_spec = input_expected_spec grad_in_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=grad_in_expected_spec, input_specs=op_args_target_specs, redistribute_cost=redistribute_costs, @@ -894,7 +894,7 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: # the output spec is the same as input spec output_target_spec = input_target_spec output_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=output_target_spec, input_specs=op_args_target_specs, redistribute_cost=redistribute_costs, @@ -944,7 +944,7 @@ def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: # output triple: (d_input, d_weight, d_bias) out_tuple_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): - # args for PlacementStrategy + # args for OpSpec output_specs_list: list[Optional[DTensorSpec]] = [] input_specs_list: list[DTensorSpec] = [] redistribute_costs = [] @@ -1052,7 +1052,7 @@ def _add_target_input_spec(strategy) -> DTensorSpec: output_specs_list.append(None) out_tuple_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=tuple(output_specs_list), input_specs=input_specs_list, redistribute_cost=redistribute_costs, diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index fbe26d144456e7..b7804d318104da 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -9,9 +9,9 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OpStrategy, PlacementList, - PlacementStrategy, RuntimeSchemaInfo, ) from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies @@ -48,7 +48,7 @@ def transpose_strategy(op_schema: OpSchema) -> OpStrategy: Shard(1 - p.dim) if isinstance(p, Shard) else p for p in input_spec.placements ] - transpose_strategy = PlacementStrategy( + transpose_strategy = OpSpec( output_specs=DTensorSpec( mesh=input_strategy.mesh, placements=tuple(output_placements), @@ -447,7 +447,7 @@ def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy: # TODO(d4l3k); implement a more correct strategy for constant_pad_nd return OpStrategy( [ - PlacementStrategy( + OpSpec( output_specs=DTensorSpec(mesh, (Replicate(),)), input_specs=( DTensorSpec(mesh, (Replicate(),)), diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 1cb3501ac5ca40..3abf7af747718d 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -5,11 +5,9 @@ import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( - _is_inplace_op, - _is_out_variant_op, OpSchema, + OpSpec, OpStrategy, - PlacementStrategy, RuntimeSchemaInfo, StrategyType, TupleStrategy, @@ -47,15 +45,6 @@ # ] -linear_pointwise_ops = [ - aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. - aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. - aten.to.dtype, - aten.add.Tensor, - aten.add_.Tensor, -] - - pointwise_ops = [ # please keep the entries below alphabetically sorted aten.__ilshift__.Scalar, @@ -299,11 +288,7 @@ aten.maximum.out, aten.minimum.default, aten.minimum.out, - aten.mul.Scalar, - aten.mul.Tensor, aten.mul.out, - aten.mul_.Scalar, - aten.mul_.Tensor, aten.mvlgamma.default, aten.mvlgamma.out, aten.mvlgamma_.default, @@ -418,18 +403,36 @@ aten.threshold_backward.default, ] - -def pointwise_strategy(op_schema: OpSchema, linearity: bool = False) -> OpStrategy: - max_shards_strategy_index = -1 +# the linear pointwise ops map, key is op, value is the type of linearity +linear_pointwise_ops = { + aten.to.dtype: 0, + aten.add.Tensor: 1, + aten.add_.Tensor: 1, + aten.div.Scalar: 0, + aten.div_.Scalar: 0, + aten.mul.Scalar: 0, + aten.mul_.Scalar: 0, + aten.mul.Tensor: 2, + aten.mul_.Tensor: 2, +} + + +def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy: + followed_strategy_index = -1 max_shards = -1 max_ndim = -1 - if _is_inplace_op(op_schema.op): + if op_schema.is_inplace_op(): # inplace op should follow the first arg strategy followed_strategy = op_schema.args_schema[0] - elif _is_out_variant_op(op_schema.op): + followed_strategy_index = 0 + elif op_schema.is_out_variant_op(): # out variant op should follow the out kwarg strategy followed_strategy = op_schema.kwargs_schema["out"] + # out variant is technically a kwarg for the strategy to follow so it does not + # have an "index", we set it to a reasonably large number just to indicate it's + # not a valid index + followed_strategy_index = 100 else: # normal pointwise op, we choose to follow the arg with # the max shards in case operands needs reshard @@ -444,33 +447,70 @@ def pointwise_strategy(op_schema: OpSchema, linearity: bool = False) -> OpStrate if (arg_max_shards > max_shards) or ( arg_max_shards == max_shards and arg_max_ndim > max_ndim ): - max_shards_strategy_index = idx + followed_strategy_index = idx max_shards = arg_max_shards max_ndim = arg_max_ndim - followed_strategy = op_schema.args_schema[max_shards_strategy_index] + followed_strategy = op_schema.args_schema[followed_strategy_index] assert isinstance(followed_strategy, OpStrategy), ( f"no strategy to follow for {op_schema}!" ) return common_pointwise_strategy( - op_schema.args_schema, followed_strategy, linearity + op_schema.args_schema, + followed_strategy, + followed_strategy_index, + linearity, ) +def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + + Note that: + 1. Only unary and binary operations are supported, out variant + ops are not supported. + 2. There're multiple types of linearity, refer to the doc of + common_pointwise_strategy for more details. + """ + linearity_type = linear_pointwise_ops.get(op_schema.op, -1) + return pointwise_strategy(op_schema, linearity=linearity_type) + + def common_pointwise_strategy( args_schema: Sequence[object], followed_strategy: OpStrategy, - linearity: bool, + followed_strategy_index: int, + linearity: int = -1, ) -> OpStrategy: + """ + Common strategy for pointwise operations. + + Args: + args_schema: Input arguments schema + followed_strategy: Strategy to follow for output placement + followed_strategy_index: Index of the strategy being followed + linearity: depending on the operator, we support different types of linearity + -1: the operation does not support linearity + 0: the unary operation that supports linearity, output propagates partial. + 1: the binary operation supports add linearity, where it requires every operand + to be partial, output propagates partial. + 2: the binary operation supports multiplicative linearity, where it requires + the primary operand to be partial, and the other operands to be replicate, + output propagates partial. + """ # handle broadcasting common_shape = torch.broadcast_shapes( *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] ) pointwise_strategy = OpStrategy([]) - for placement_strategy in followed_strategy.strategies: - spec_to_follow = placement_strategy.output_spec + for op_spec in followed_strategy.strategies: + spec_to_follow = op_spec.output_spec + out_placements: list[Placement] = [] for placement in spec_to_follow.placements: if isinstance(placement, Shard): @@ -478,17 +518,25 @@ def common_pointwise_strategy( common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim out_placements.append(Shard(new_shard_dim)) - elif isinstance(placement, Partial) and not linearity: - # clear the partial placemnet if op does not support linearity - # by default we just replicate the partial, need to see if this - # is optimal for all cases - out_placements.append(Replicate()) + elif isinstance(placement, Partial): + # note that only partial-sum and partial-avg are supported for linearity + partial_supports_linearity = placement.is_partial( + "sum" + ) or placement.is_partial("avg") + if linearity > 0 and partial_supports_linearity: + # propagate the partial placement + out_placements.append(placement) + else: + # clear the partial placement if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) else: out_placements.append(placement) input_specs: list[DTensorSpec] = [] redistribute_costs: list[list[float]] = [] - for arg_idx, input_arg in enumerate(args_schema): + for input_idx, input_arg in enumerate(args_schema): if isinstance(input_arg, OpStrategy): # sanity check that all args that follow the same strategy # are on the same DeviceMesh @@ -503,11 +551,21 @@ def common_pointwise_strategy( input_arg_dims_map = infer_broadcast_dims_map( common_shape, input_arg_spec.shape ) + + # Determine if this input should convert Partial to Replicate base on linearity + should_convert_partial = ( + linearity == 2 + and input_idx + != followed_strategy_index # Don't convert the "followed" strategy + ) + input_target_placements = map_placements_after_broadcast( tuple(out_placements), common_shape, input_arg_dims_map, + partial_to_replicate=should_convert_partial, ) + input_arg_target_spec = DTensorSpec( mesh=followed_strategy.mesh, placements=input_target_placements, @@ -519,7 +577,7 @@ def common_pointwise_strategy( ) pointwise_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=DTensorSpec( mesh=followed_strategy.mesh, placements=tuple(out_placements), @@ -531,16 +589,7 @@ def common_pointwise_strategy( return pointwise_strategy -def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: - """ - Linear pointwise operators can propagate pending reductions. - For example, c = add(a, b); if a is pending sum, then c will be - pending sum as well without any communication overhead. - """ - return pointwise_strategy(op_schema, linearity=True) - - -for op in linear_pointwise_ops: +for op in linear_pointwise_ops.keys(): register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( linear_pointwise_strategy ) diff --git a/torch/distributed/tensor/_ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py index 51b1faed14ea49..e16a623904feb2 100644 --- a/torch/distributed/tensor/_ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -2,8 +2,8 @@ import torch from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OpStrategy, - PlacementStrategy, StrategyType, ) from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy @@ -31,6 +31,6 @@ def random_op_strategy(op_schema: OpSchema) -> StrategyType: if is_tensor_partial(arg_spec): # TODO: figure out how inplace random op should behave when it's partial raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") - random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec)) + random_strategy.strategies.append(OpSpec(output_specs=arg_spec)) return random_strategy diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 9b73f36d855f09..d53d576b65c537 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -6,12 +6,11 @@ import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( - _is_inplace_op, OpSchema, + OpSpec, OpStrategy, OutputSharding, PlacementList, - PlacementStrategy, RuntimeSchemaInfo, StrategyType, TupleStrategy, @@ -46,7 +45,7 @@ def default_strategy(op_schema: OpSchema) -> StrategyType: # we create new DTensorSpecs even for default strategy to assure that # the tensor metas are distinct between the arguments and outputs default_strategy = [ - PlacementStrategy( + OpSpec( output_specs=DTensorSpec( mesh=select_strategy.mesh, placements=strategy.output_spec.placements, @@ -69,6 +68,7 @@ def default_strategy(op_schema: OpSchema) -> StrategyType: ] )(default_strategy) + register_op_strategy( aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) )(default_strategy) @@ -109,11 +109,9 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType: for p in arg_spec.placements ), ) - equal_strategy.strategies.append( - PlacementStrategy(output_specs=output_spec) - ) + equal_strategy.strategies.append(OpSpec(output_specs=output_spec)) else: - equal_strategy.strategies.append(PlacementStrategy(arg_spec)) + equal_strategy.strategies.append(OpSpec(arg_spec)) return equal_strategy @@ -157,7 +155,7 @@ def create_like_strategy(op_schema: OpSchema) -> StrategyType: ), ) create_like_strategy.strategies.append( - PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,)) + OpSpec(output_specs=output_spec, input_specs=(arg_spec,)) ) return create_like_strategy @@ -190,7 +188,7 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: input_spec = arg_strategy.output_spec replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) new_factory_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=replica_spec, input_specs=(input_spec,), redistribute_cost=[[0.0] * mesh.ndim], @@ -207,7 +205,7 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: continue new_factory_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=input_spec, input_specs=(input_spec,), # encouraging new tensor placement to be the same as input @@ -229,9 +227,7 @@ def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements) replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) bucketize_strategy.strategies.append( - PlacementStrategy( - output_specs=arg_spec, input_specs=(arg_spec, replica_spec) - ) + OpSpec(output_specs=arg_spec, input_specs=(arg_spec, replica_spec)) ) return bucketize_strategy @@ -292,7 +288,7 @@ def select_int_strategy(op_schema: OpSchema) -> StrategyType: ) select_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=output_specs, input_specs=(input_specs,), ) @@ -310,7 +306,7 @@ def select_backward_strategy(op_schema: OpSchema) -> OpStrategy: input_strategy, dim = args_schema[0], args_schema[2] assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" assert isinstance(dim, int) - output_strategies: list[PlacementStrategy] = [] + output_strategies: list[OpSpec] = [] for placement_strategy in input_strategy.strategies: input_spec = placement_strategy.output_spec output_spec_placements: list[Placement] = [] @@ -327,7 +323,7 @@ def select_backward_strategy(op_schema: OpSchema) -> OpStrategy: output_spec_placements.append(placement) output_specs = DTensorSpec(input_spec.mesh, tuple(output_spec_placements)) output_strategies.append( - PlacementStrategy(output_specs=output_specs, input_specs=(input_spec,)) + OpSpec(output_specs=output_specs, input_specs=(input_spec,)) ) return OpStrategy(output_strategies) @@ -367,7 +363,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: # only add the strategy if the slice dim is not sharded out_spec = DTensorSpec(mesh, arg_spec.placements) - slice_strategy.strategies.append(PlacementStrategy(output_specs=out_spec)) + slice_strategy.strategies.append(OpSpec(output_specs=out_spec)) if not slice_strategy.strategies: # if all strategies are filtered out, unsharding all specs on slice dim # of the input strategy, and use that as the op strategy @@ -376,9 +372,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: unshard_spec = DTensorSpec( mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim) ) - slice_strategy.strategies.append( - PlacementStrategy(output_specs=unshard_spec) - ) + slice_strategy.strategies.append(OpSpec(output_specs=unshard_spec)) return slice_strategy @@ -391,7 +385,7 @@ def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: args_schema = op_schema.args_schema input_strategy, dim = args_schema[0], args_schema[2] assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" - output_strategies: list[PlacementStrategy] = [] + output_strategies: list[OpSpec] = [] for placement_strategy in input_strategy.strategies: output_spec = placement_strategy.output_spec new_placements: list[Placement] = [] @@ -404,7 +398,7 @@ def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements)) redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)] placement_strategy.redistribute_cost = redistribute_cost - new_strategy = PlacementStrategy(output_specs=new_spec) + new_strategy = OpSpec(output_specs=new_spec) output_strategies.append(new_strategy) return OpStrategy(output_strategies) @@ -458,9 +452,7 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: or is_tensor_partial(arg_spec) ): # only add the strategy if the slice_scatter dim is not sharded or partial - slice_scatter_strategy.strategies.append( - PlacementStrategy(output_specs=arg_spec) - ) + slice_scatter_strategy.strategies.append(OpSpec(output_specs=arg_spec)) if not slice_scatter_strategy.strategies: # if all strategies are filtered out, replicating all specs on slice_scatter dim @@ -471,7 +463,7 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim) ) slice_scatter_strategy.strategies.append( - PlacementStrategy(output_specs=replicate_spec) + OpSpec(output_specs=replicate_spec) ) return slice_scatter_strategy @@ -483,7 +475,7 @@ def replica_only_strategy(op_schema: OpSchema) -> StrategyType: assert isinstance(input_strategy, OpStrategy) mesh = input_strategy.mesh replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replicate_spec)]) + return OpStrategy([OpSpec(replicate_spec)]) @register_op_strategy( @@ -504,10 +496,11 @@ def scatter_strategy(op_schema: OpSchema) -> StrategyType: single_mesh_dim_strategies.append(all_replicate) # TODO: see if we can support input sharding pattern - inplace_op = _is_inplace_op(op_schema.op) - op_strategy = expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op + mesh, + op_schema, + single_mesh_dim_strategies, + inplace_op=op_schema.is_inplace_op(), ) return op_strategy @@ -661,7 +654,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: follow_placements = normalize_shard_for_stack(follow_placements, dim) op_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=DTensorSpec(mesh, tuple(follow_placements)), input_specs=input_specs, ) @@ -697,7 +690,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType: for _ in range(len(input_tuple_strategy.childs)) ) op_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=DTensorSpec(mesh, tuple(follow_placements)), input_specs=input_specs, ) @@ -738,6 +731,111 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding: return result +@register_op_strategy( + [ + aten.index_put.default, + aten._index_put_impl_.default, + ], + schema_info=RuntimeSchemaInfo(needs_pytree=True), +) +def prop_index_put(op_schema: OpSchema) -> StrategyType: + # We have 3 DTensor spec from argument `in`, `indices` and `values` + # accordingly. + in_spec, indices_spec, values_spec = op_schema.args_schema + assert isinstance(in_spec, OpStrategy) + # `indices`` is a tuple of scalar LongTensor, so we use TupleStrategy. + assert isinstance(indices_spec, TupleStrategy) + assert isinstance(values_spec, OpStrategy) + mesh = values_spec.mesh + op_strategy = OpStrategy([]) + # 1. `indices` should all be replicated first. + indices_redistribute_costs = [] + new_indices_spec: list[Optional[DTensorSpec]] = [] + for indices_spec_child in indices_spec.childs: + assert isinstance(indices_spec_child, OpStrategy) + + replicated_spec = DTensorSpec( + mesh=mesh, + placements=tuple([Replicate()] * mesh.ndim), + tensor_meta=indices_spec_child.strategies[0].output_spec.tensor_meta, + ) + new_indices_spec.append(replicated_spec) + child_costs = generate_redistribute_costs(indices_spec_child, replicated_spec) + indices_redistribute_costs.append(child_costs) + + # 2. For placement rule of `values` and `in`, assume `values` shape = + # [a,b,c,d,e,f], `in` shape = [d,e,f]. Then `values`'s a,b,c (selected dim) + # must be replicated and d,e,f (nonselected dim) in both `values` and `in` + # should follow the same sharding (replicate or shard, but not partial). + size_offset = ( + in_spec.strategies[0].output_spec.ndim + - values_spec.strategies[0].output_spec.ndim + ) + # We can either let `values` follow `in`'s placements or reverse. + for exemplar_spec in [in_spec, values_spec]: + # use exemplar_spec as the target spec + for strategy in exemplar_spec.strategies: + in_spec_new_placements: list[Placement] = [] + values_spec_new_placements: list[Placement] = [] + placements = strategy.output_spec.placements + for placement in placements: + if placement.is_shard(): + assert isinstance(placement, Shard) + if exemplar_spec is in_spec: + # let `values_spce` follow `in_spec` + if placement.dim < size_offset: + # sharded on selected dim, need to change to replicate + in_spec_new_placements.append(Replicate()) + values_spec_new_placements.append(Replicate()) + else: + in_spec_new_placements.append(placement) + values_spec_new_placements.append( + Shard(placement.dim - size_offset) + ) + else: + # let `in_spec` follow `values_spec` + in_spec_new_placements.append( + Shard(placement.dim + size_offset) + ) + values_spec_new_placements.append(placement) + else: + in_spec_new_placements.append(Replicate()) + values_spec_new_placements.append(Replicate()) + new_in_spec = DTensorSpec( + mesh=mesh, + placements=tuple(in_spec_new_placements), + tensor_meta=in_spec.strategies[0].output_spec.tensor_meta, + ) + new_values_spec = DTensorSpec( + mesh=mesh, + placements=tuple(values_spec_new_placements), + tensor_meta=values_spec.strategies[0].output_spec.tensor_meta, + ) + output_spec = DTensorSpec( + mesh=mesh, + placements=tuple(in_spec_new_placements), + tensor_meta=in_spec.strategies[0].output_spec.tensor_meta, + ) + cost_in_spec = generate_redistribute_costs(in_spec, new_in_spec) + cost_values_spec = generate_redistribute_costs(values_spec, new_values_spec) + op_strategy.strategies.append( + OpSpec( + input_specs=( + new_in_spec, + *new_indices_spec, # type: ignore[arg-type] + new_values_spec, + ), + output_specs=output_spec, + redistribute_cost=[ + cost_in_spec, + *indices_redistribute_costs, + cost_values_spec, + ], + ) + ) + return op_strategy + + @register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) def prop_index(op_schema: OpSchema) -> OutputSharding: """ @@ -916,7 +1014,7 @@ def size_split(N, i) -> list: spec = DTensorSpec(spec.mesh, placements) op_strategy.strategies.append( - PlacementStrategy(output_specs=spec, input_specs=([spec])) + OpSpec(output_specs=spec, input_specs=([spec])) ) split_strategies.append(op_strategy) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 877ace7207b3ab..8fe213f39846e0 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -10,8 +10,8 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OpStrategy, - PlacementStrategy, RuntimeSchemaInfo, StrategyType, ) @@ -666,7 +666,7 @@ def reshape_strategy(op_schema: OpSchema) -> StrategyType: output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) output_strategy.strategies.append( - PlacementStrategy( + OpSpec( output_specs=output_spec, input_specs=(input_tgt_spec,), redistribute_cost=redistribute_costs, diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index b209adaa6cbb6e..9a211ba448348a 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -14,10 +14,10 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OpStrategy, OutputSharding, PlacementList, - PlacementStrategy, RuntimeSchemaInfo, ) from torch.distributed.tensor.device_mesh import DeviceMesh @@ -196,11 +196,18 @@ def map_placements_after_broadcast( placements: tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: list[int], + partial_to_replicate: bool = False, ) -> tuple[Placement, ...]: """Map each placement based on the output shape after broadcast.""" new_placements: list[Placement] = [] for placement in placements: - if isinstance(placement, (Replicate, Partial)): + if isinstance(placement, Partial): + if partial_to_replicate: + # map the partial placement to replicate + new_placements.append(Replicate()) + else: + new_placements.append(placement) + elif isinstance(placement, Replicate): new_placements.append(placement) else: assert isinstance(placement, Shard) @@ -265,7 +272,7 @@ def expand_to_full_mesh_op_strategy( self_spec = input_args_strategy[0].strategies[0].output_spec if inplace_op and self_spec.placements != input_specs[0].placements: - # if it's inplace op, we would only allow the placement strategy to be added when the + # if it's inplace op, we would only allow the OpSpec to be added when the # input_spec matches the first argument's runtime sharding, otherwise we skip continue @@ -288,7 +295,7 @@ def expand_to_full_mesh_op_strategy( output_specs = spec_list[0] # type: ignore[assignment] else: raise RuntimeError("output spec is None") - strategy = PlacementStrategy( + strategy = OpSpec( output_specs=output_specs, input_specs=input_specs, redistribute_cost=redistribute_cost, diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index aa8d1240447ecd..11fc2d11e1a88f 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -318,7 +318,6 @@ def forward( # type: ignore[override] device_mesh, placements, tensor_meta=current_spec.tensor_meta ) - local_tensor = input._local_tensor output = redistribute_local_tensor( local_tensor, current_spec, target_spec, async_op=async_op ) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 0e186da5615238..819d07d8cfd35a 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -12,10 +12,10 @@ from torch.distributed.tensor._op_schema import ( OpInfo, OpSchema, + OpSpec, OpStrategy, OutputSharding, OutputSpecType, - PlacementStrategy, RuntimeSchemaInfo, StrategyType, TupleStrategy, @@ -227,7 +227,7 @@ def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): - return OpStrategy([PlacementStrategy(spec)]) + return OpStrategy([OpSpec(spec)]) elif ( isinstance(spec, (list, tuple)) and len(spec) > 0 @@ -365,8 +365,8 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin ) elif isinstance(op_strategy, TupleStrategy): # tuple strategy output sharding processing - # runtime selected placement strategy for each TupleStrategy input arg - selected_strategies: list[PlacementStrategy] = [] + # runtime select OpSpec for each TupleStrategy input arg + selected_strategies: list[OpSpec] = [] out_spec_list: list[DTensorSpec] = [] for strategy in op_strategy.childs: assert isinstance(strategy, OpStrategy) @@ -487,21 +487,21 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin f"Operator {op_schema.op} does not have a sharding strategy registered." ) - def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: + def _select_strategy(self, strategy: OpStrategy) -> OpSpec: if len(strategy.strategies) == 1: - # short cut with only one possible strategy + # short cut with only one possible OpSpec return strategy.strategies[0] - strategy_costs: list[float] = [] - for strtg in strategy.strategies: - assert strtg.redistribute_cost is not None, ( - "must set redistribute cost each strategy!" + op_spec_costs: list[float] = [] + for op_spec in strategy.strategies: + assert op_spec.redistribute_cost is not None, ( + "must set redistribute cost each OpSpec!" ) - redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) - strategy_costs.append(redistribute_cost) + redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) + op_spec_costs.append(redistribute_cost) # for eager execution, we just select the one with the minimal redistribute cost - return strategy.strategies[strategy_costs.index(min(strategy_costs))] + return strategy.strategies[op_spec_costs.index(min(op_spec_costs))] def _adjust_shape_and_stride_args( self, diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index 0f3770a1eac7f8..f66ea658daf4b3 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -84,7 +84,7 @@ def __new__( # necessary for ops dispatching from this subclass to its local shards @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] kwargs = kwargs or {} # TODO: we shall continually extend this function to support more ops if needed diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index d31f8d07297256..73b53f051421d4 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -1,5 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates - import contextlib import itertools import logging @@ -17,7 +15,13 @@ import torch.nn.functional as F from torch import nn from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard +from torch.distributed.tensor import ( + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) from torch.distributed.tensor.parallel.style import ParallelStyle from torch.overrides import TorchFunctionMode @@ -193,106 +197,6 @@ def results(self) -> tuple[torch.Tensor, torch.Tensor]: return out.to(self._out_dtype), lse.to(self._lse_dtype) -def _scaled_dot_product_ring_flash_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if return_debug_mask: - raise NotImplementedError("return_debug_mask is not supported yet") - - seq_dim = 2 - return _templated_ring_attention( - mesh, - seq_dim, - aten._scaled_dot_product_flash_attention, - query=query, - key=key, - value=value, - is_causal=is_causal, - dropout_p=dropout_p, - scale=scale, - ) - - -def _scaled_dot_product_ring_efficient_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - compute_log_sumexp: bool = True, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if attn_bias is not None: - raise NotImplementedError("attn_bias is not supported yet") - - if not compute_log_sumexp: - # CP requires compute_log_sumexp to be True because it always merges LSE - compute_log_sumexp = True - - seq_dim = 2 - return _templated_ring_attention( - mesh, - seq_dim, - aten._scaled_dot_product_efficient_attention, - query=query, - key=key, - value=value, - is_causal=is_causal, - attn_bias=attn_bias, - dropout_p=dropout_p, - scale=scale, - compute_log_sumexp=compute_log_sumexp, - ) - - -def _scaled_dot_product_ring_cudnn_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - compute_log_sumexp: bool = True, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if attn_bias is not None: - raise NotImplementedError("attn_bias is not supported yet") - - if not compute_log_sumexp: - # CP requires compute_log_sumexp to be True because it always merges LSE - compute_log_sumexp = True - - seq_dim = 2 - return _templated_ring_attention( - mesh, - seq_dim, - aten._scaled_dot_product_cudnn_attention, - query=query, - key=key, - value=value, - attn_bias=attn_bias, - compute_log_sumexp=compute_log_sumexp, - dropout_p=dropout_p, - is_causal=is_causal, - return_debug_mask=return_debug_mask, - scale=scale, - ) - - class _AttentionOp(Protocol): def __call__( self, @@ -376,21 +280,8 @@ def _create_rotater( raise NotImplementedError(f"Unkonwn method {method}") -def _ring_rotate( - block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool -) -> torch.Tensor: - block = block.contiguous() - size = dist.get_world_size(pg) - dsts = ( - list(range(1, size)) + [0] - if send_to_next - else [size - 1] + list(range(0, size - 1)) - ) - return ft_c.permute_tensor(block, dsts, pg) - - def _templated_ring_attention( - mesh: DeviceMesh, + group: dist.ProcessGroup, seq_dim: int, op: _AttentionOp, query: torch.Tensor, @@ -480,13 +371,11 @@ def _templated_ring_attention( if not is_causal and _cp_options.enable_load_balance: raise RuntimeError("Load balancing requires `is_causal=True`.") - if isinstance(mesh, dist.ProcessGroup): - pg: Union[dist.ProcessGroup, list[dist.ProcessGroup]] = mesh - else: - pg = mesh.get_group() - assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension" - rank = dist.get_rank(pg) - size = dist.get_world_size(pg) + assert isinstance(group, dist.ProcessGroup), ( + "process group must be single dimension" + ) + rank = dist.get_rank(group) + size = dist.get_world_size(group) next_kv = None @@ -502,7 +391,7 @@ def _templated_ring_attention( out: torch.Tensor logsumexp: torch.Tensor - rotater = _create_rotater(pg, 2) + rotater = _create_rotater(group, 2) for i in range(size): if i > 0: @@ -562,95 +451,8 @@ def _templated_ring_attention( return *sdpa_merger.results(), *rest -def _sdpa_handler( - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - kwargs: dict[str, object], -) -> object: - # extract local tensor and sharding infos to a OpInfo - op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) - logger.debug("Dispatching op_call: %s", op_info.schema) - - # sharding propagation - # TODO: remove the context parallel strategy from the default propagation - # rule. Either figure out how to dynamically enable it or just don't call - # propagate. - DTensor._op_dispatcher.sharding_propagator.propagate(op_info) - output_sharding = op_info.output_sharding - assert output_sharding is not None, "output sharding should not be None" - assert not output_sharding.needs_redistribute, "inputs need to be redistributed" - - if op_call == aten._scaled_dot_product_flash_attention.default: - local_results = _scaled_dot_product_ring_flash_attention( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - elif op_call == aten._scaled_dot_product_efficient_attention.default: - local_results = _scaled_dot_product_ring_efficient_attention( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - elif op_call == aten._scaled_dot_product_cudnn_attention.default: - local_results = _scaled_dot_product_ring_cudnn_attention( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - else: - raise NotImplementedError( - "CP only supports flash attention and memory efficient attention now." - ) - - return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) - - -def _sdpa_backward_handler( - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - kwargs: dict[str, object], -) -> object: - # Redistribute grad_output tensor to the same placement as output tensor - args = list(args) - args = tuple(args) - - # extract local tensor and sharding infos to a OpInfo - op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) - logger.debug("Dispatching op_call: %s", op_info.schema) - - # sharding propagation - DTensor._op_dispatcher.sharding_propagator.propagate(op_info) - output_sharding = op_info.output_sharding - assert output_sharding is not None, "output sharding should not be None" - assert not output_sharding.needs_redistribute, "inputs need to be redistributed" - - if op_call == aten._scaled_dot_product_flash_attention_backward.default: - local_results = _scaled_dot_product_ring_flash_attention_backward( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - elif op_call == aten._scaled_dot_product_efficient_attention_backward.default: - local_results = _scaled_dot_product_ring_efficient_attention_backward( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - elif op_call == aten._scaled_dot_product_cudnn_attention_backward.default: - local_results = _scaled_dot_product_ring_cudnn_attention_backward( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - else: - raise NotImplementedError(f"{op_call=}") - - return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) - - def _templated_ring_attention_backward( - mesh: DeviceMesh, + group: dist.ProcessGroup, seq_dim: int, op: _AttentionOp, grad_out: torch.Tensor, @@ -666,10 +468,8 @@ def _templated_ring_attention_backward( """This API implements the backward of the ring attention.""" if not is_causal and _cp_options.enable_load_balance: raise RuntimeError("Load balancing requires `is_causal=True`.") - pg = mesh.get_group() - assert isinstance(pg, dist.ProcessGroup), "must be single dimension" - rank = dist.get_rank(pg) - size = dist.get_world_size(pg) + rank = dist.get_rank(group) + size = dist.get_world_size(group) next_kv = None next_grad_kv = None rest: list[Any] @@ -682,8 +482,8 @@ def _templated_ring_attention_backward( key = key.contiguous() value = value.contiguous() - kv_rotater = _create_rotater(pg, 2) - dkv_rotater = _create_rotater(pg, 2, method=_RotateMethod.ALL_TO_ALL) + kv_rotater = _create_rotater(group, 2) + dkv_rotater = _create_rotater(group, 2, method=_RotateMethod.ALL_TO_ALL) for i in range(size): if i > 0: # Wait for the kv from the (cp_rank - 1) rank. @@ -818,6 +618,109 @@ def _templated_ring_attention_backward( ) +def _scaled_dot_product_ring_flash_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if return_debug_mask: + raise NotImplementedError("return_debug_mask is not supported yet") + + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_efficient_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + attn_bias=attn_bias, + dropout_p=dropout_p, + scale=scale, + compute_log_sumexp=compute_log_sumexp, + ) + + +def _scaled_dot_product_ring_cudnn_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=return_debug_mask, + scale=scale, + ) + + def _scaled_dot_product_ring_flash_attention_backward( mesh: DeviceMesh, grad_out: torch.Tensor, @@ -838,8 +741,9 @@ def _scaled_dot_product_ring_flash_attention_backward( scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: seq_dim = 2 + group = mesh.get_group() return _templated_ring_attention_backward( - mesh, + group, seq_dim, aten._scaled_dot_product_flash_attention_backward.default, grad_out=grad_out, @@ -879,8 +783,9 @@ def _scaled_dot_product_ring_efficient_attention_backward( scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: seq_dim = 2 + group = mesh.get_group() return _templated_ring_attention_backward( - mesh, + group, seq_dim, aten._scaled_dot_product_efficient_attention_backward.default, grad_out=grad_out, @@ -921,8 +826,9 @@ def _scaled_dot_product_ring_cudnn_attention_backward( scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: seq_dim = 2 + group = mesh.get_group() return _templated_ring_attention_backward( - mesh, + group, seq_dim, aten._scaled_dot_product_cudnn_attention_backward.default, grad_out=grad_out, @@ -945,13 +851,53 @@ def _scaled_dot_product_ring_cudnn_attention_backward( ) +def _sdpa_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + # TODO: remove the context parallel strategy from the default propagation + # rule. Either figure out how to dynamically enable it or just don't call + # propagate. + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + call_maps: dict[torch._ops.OpOverload, Callable] = { + aten._scaled_dot_product_flash_attention.default: _scaled_dot_product_ring_flash_attention, + aten._scaled_dot_product_efficient_attention.default: _scaled_dot_product_ring_efficient_attention, + aten._scaled_dot_product_cudnn_attention.default: _scaled_dot_product_ring_cudnn_attention, + aten._scaled_dot_product_flash_attention_backward.default: _scaled_dot_product_ring_flash_attention_backward, + aten._scaled_dot_product_efficient_attention_backward.default: _scaled_dot_product_ring_efficient_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward.default: _scaled_dot_product_ring_cudnn_attention_backward, + } + if op_call in call_maps: + local_results = call_maps[op_call]( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError( + "CP only supports flash attention and memory efficient attention now." + ) + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + customized_ops = { aten._scaled_dot_product_flash_attention.default: _sdpa_handler, - aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_flash_attention_backward.default: _sdpa_handler, aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, - aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_handler, aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, - aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_handler, } @@ -1062,13 +1008,8 @@ class _AttentionContextParallel(ParallelStyle): ) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - if not isinstance(device_mesh, DeviceMesh): - raise ValueError( - f"{type(device_mesh)} is not supported by {type(self)} yet." - ) - if not device_mesh.ndim == 1: - raise ValueError + raise ValueError("CP only supports single dimension device mesh") return distribute_module( module, @@ -1232,107 +1173,66 @@ def __torch_function__( raise NotImplementedError("torch dispatch mode is not supported yet.") -class _LoadBalancer(ABC): - @classmethod - @abstractmethod - def shard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: ... - - @classmethod - @abstractmethod - def unshard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: ... - - -class _SequentialSharder(_LoadBalancer): - """ - This load balancer chunks the buffer into cp_world_size and rank0 gets - 0th shard, rank1 gets 1st shard, ... - So this doesn't have any load balancing effect when using the causal masking. - """ - - @classmethod - def shard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - assert buffer.size()[seq_dim] % mesh.size() == 0 - return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] - - @classmethod - def unshard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - buffer = buffer.contiguous() - all_buffers = [torch.empty_like(buffer) for _ in range(mesh.size())] - ft_c.all_gather_inplace(all_buffers, buffer, mesh) - return torch.cat(all_buffers, dim=seq_dim) - - -class _RoundRobinLoadBalancer(_LoadBalancer): +def _generate_round_robin_indices( + seq_length: int, + cp_world_size: int, + device: torch.device, + restore: bool = False, +) -> torch.Tensor: """ - This load balancer chunk the buffer into cp_world_size * ROUND_ROBIN_CYCLE - shards, and uses a round robin approach to achieve load balancing. - Since ROUND_ROBIN_CYCLE being 2 will achieve perfect load balancing for - causal masking, we assume ROUND_ROBIN_CYCLE is always 2 to simplify the - implementation. + Generate round-robin load balancing indices or restore indices. + Args: + seq_length: Total sequence length + cp_world_size: Context parallel world size + device: Device to place the tensor on + restore: If True, generate restore indices that map round-robin reordered + positions back to original positions. If False, generate load + balance indices that reorder original positions to round-robin pattern. + Returns: + Index tensor of shape (seq_length,) with the requested mapping. """ - - ROUND_ROBIN_CYCLE = 2 - - @classmethod - def shard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - assert cls.ROUND_ROBIN_CYCLE == 2, ( - "The current implementation only works if ROUND_ROBIN_CYCLE is 2." - ) - cp_world_size = mesh.size() - cp_rank = mesh.get_local_rank() - assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 - chunks = buffer.chunk(cp_world_size * 2, dim=seq_dim) - return torch.cat( - (chunks[cp_rank], chunks[cp_world_size * 2 - cp_rank - 1]), - dim=seq_dim, + assert seq_length % (cp_world_size * 2) == 0 + chunk_size = seq_length // (cp_world_size * 2) + all_indices = [] + + for cp_rank in range(cp_world_size): + # Generate indices for first chunk of the cp rank + first_chunk_start = cp_rank * chunk_size + first_chunk_indices = list( + range(first_chunk_start, first_chunk_start + chunk_size) ) - @classmethod - def unshard( - cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - assert cls.ROUND_ROBIN_CYCLE == 2, ( - "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + # Second chunk: positions from the complementary chunk + second_chunk_idx = cp_world_size * 2 - cp_rank - 1 + second_chunk_start = second_chunk_idx * chunk_size + second_chunk_indices = list( + range(second_chunk_start, second_chunk_start + chunk_size) ) - buffer = buffer.contiguous() - cp_world_size = mesh.size() - - all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] - ft_c.all_gather_inplace(all_buffers, buffer, mesh) - sliced_buffers = [sb for b in all_buffers for sb in b.chunk(2, dim=seq_dim)] - ordered_buffers = list(sliced_buffers) - for i, b in enumerate(sliced_buffers): - if i % 2 == 0: - ordered_buffers[i // 2] = b - else: - ordered_buffers[cp_world_size * 2 - (i // 2) - 1] = b - return torch.cat(ordered_buffers, dim=seq_dim) + # combine the indices for this rank + all_indices.extend(first_chunk_indices + second_chunk_indices) + all_indices_tensor = torch.tensor(all_indices, dtype=torch.int, device=device) + if restore: + all_indices_tensor = torch.argsort(all_indices_tensor) + return all_indices_tensor def _context_parallel_buffers( mesh: DeviceMesh, buffers: list[torch.Tensor], buffer_seq_dims: list[int], + load_balance_indices: Optional[torch.Tensor] = None, ) -> list[torch.Tensor]: """Shard the buffers along the sequence dimensions according to CP rules.""" new_buffers = [] - sharder = ( - _RoundRobinLoadBalancer - if _cp_options.enable_load_balance - else _SequentialSharder - ) for buffer, seq_dim in zip(buffers, buffer_seq_dims): - new_buffers.append(sharder.shard(buffer, mesh, seq_dim)) + if load_balance_indices is not None: + buffer = torch.index_select(buffer, dim=seq_dim, index=load_balance_indices) + + # use DTensor to shard the buffer on sequence dimension, retain the local tensor + sharded_buffer = distribute_tensor( + buffer, mesh, [Shard(seq_dim)], src_data_rank=None + ).to_local() + new_buffers.append(sharded_buffer) return new_buffers @@ -1390,11 +1290,25 @@ def context_parallel( raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] - chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims) - for buffer, chunk in zip(buffers, chunks): - chunk = chunk.clone() - buffer.resize_(chunk.shape) - buffer.copy_(chunk) + + device = buffers[0].device + seq_length = buffers[0].shape[buffer_seq_dims[0]] + cp_world_size = mesh.size() + if _cp_options.enable_load_balance: + load_balance_indices = _generate_round_robin_indices( + seq_length=seq_length, + cp_world_size=cp_world_size, + device=device, + ) + else: + load_balance_indices = None + shards = _context_parallel_buffers( + mesh, buffers, buffer_seq_dims, load_balance_indices + ) + for buffer, shard in zip(buffers, shards): + shard = shard.clone() + buffer.resize_(shard.shape) + buffer.copy_(shard) with _context_parallel(seq_dim=2, mesh=mesh): yield @@ -1423,12 +1337,30 @@ def context_parallel_unshard( Returns: List[torch.Tensor]: the unsharded buffers. """ - sharder = ( - _RoundRobinLoadBalancer - if _cp_options.enable_load_balance - else _SequentialSharder - ) - return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)] + if _cp_options.enable_load_balance: + device = buffers[0].device + cp_world_size = mesh.size() + seq_length = buffers[0].shape[seq_dims[0]] * cp_world_size + restore_indices = _generate_round_robin_indices( + seq_length=seq_length, + cp_world_size=cp_world_size, + device=device, + restore=True, + ) + else: + restore_indices = None + unsharded_buffers = [] + for b, dim in zip(buffers, seq_dims): + b = b.contiguous() + unsharded_b = _maybe_wait(ft_c.all_gather_tensor(b, dim, mesh)) + + if restore_indices is not None: + unsharded_b = torch.index_select( + unsharded_b, dim=dim, index=restore_indices + ) + + unsharded_buffers.append(unsharded_b) + return unsharded_buffers def set_rotate_method(rotate_method: str) -> None: diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index d0f94d36cae1d7..7eb2e72343e21f 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -27,6 +27,7 @@ def local_map( func: Callable, out_placements: OutputPlacements, in_placements: Optional[InputPlacements] = None, + in_grad_placements: Optional[InputPlacements] = None, device_mesh: Optional[DeviceMesh] = None, *, redistribute_inputs: bool = False, @@ -66,11 +67,20 @@ def local_map( will be skipped and the argument will be directly passed to ``func``. If ``in_placements`` is ``None``, no placements examination will be performed. Default: None + in_grad_placements (Tuple[`PlacementType`, ...], optional): + the placements hint of the :class:`DTensor` s gradient corresponds + to the flattened input DTensor. This argument is the hint that user + can give to :meth:`to_local` in case the gradient layout of the + local tensor input does not match its :class:`DTensor` input layout. + If not specified, we will assume the gradient layout of the local + tensor input remains the same as the original :class:`DTensor` input + and use that for gradient computation. Default: None. device_mesh (:class:`DeviceMesh`, optional): - the device mesh that all the :class:`DTensor` s are placed on. If not - specified, this will be inferred from the input :class:`DTensor` s' device - mesh. `local_map` requires every :class:`DTensor` s to be placed on the same - device mesh. Default: None. + the device mesh that the output :class:`DTensor` s are placed on. If not + specified, this will be inferred from the first input :class:`DTensor`'s device + mesh. Default: None. + + Keyword Args: redistribute_inputs (bool, optional): the bool value indicating whether to reshard the input :class:`DTensor` s when their placements are different from the required input placements. If this @@ -82,10 +92,6 @@ def local_map( and returns a :class:`DTensor` constructed from the return value of ``func``. Raises: - AssertionError: If the input :class:`DTensor` is not placed on the same device - mesh, or if they are placed on a different device mesh than the ``device_mesh`` - argument passed in. - AssertionError: For any non-DTensor output, we require its corresponding output placement in ``out_placements`` be None. An AssertionError will be raised if this is not the case. @@ -150,11 +156,6 @@ def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): # this function is applied to at least one DTensor argument seen_dtensor_arg = True - assert arg.device_mesh == device_mesh, ( - f"arg {arg} in local_map has a mismatched device mesh: " - f"{arg} has device mesh {arg.device_mesh} while " - f"the expected device mesh is {device_mesh}!" - ) if in_placements is not None: spec = in_placements[idx] assert spec is not None, ( @@ -167,7 +168,7 @@ def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): if arg.placements != spec: if redistribute_inputs: # redistribute to input placements - arg = arg.redistribute(device_mesh, spec) + arg = arg.redistribute(placements=spec) else: raise ValueError( f"arg {arg} in local_map has a mismatched placements: " @@ -177,7 +178,17 @@ def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): "redistribute_inputs=True to local_map." ) - local_arg = arg.to_local() + if in_grad_placements is not None: + spec = in_grad_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects in grad placements but received {spec}!" + ) + if not isinstance(spec, tuple): + spec = tuple(spec) + local_arg = arg.to_local(grad_placements=spec) + else: + local_arg = arg.to_local() + if isinstance(local_arg, AsyncCollectiveTensor): local_arg = local_arg.wait() @@ -198,7 +209,7 @@ def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): out = func(*local_args, **kwargs) if seen_dtensor_arg: - # process output + # process output to be DTensor if we've seen DTensor inputs flat_out, out_spec = pytree.tree_flatten(out) flat_dist_out = [] diff --git a/torch/distributed/tensor/experimental/_register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py index 5d817912ac9f23..a1fa82e87af456 100644 --- a/torch/distributed/tensor/experimental/_register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -8,7 +8,6 @@ from torch._ops import OpOverload from torch.distributed.tensor import DTensor from torch.distributed.tensor._op_schema import ( - _is_inplace_op, OpSchema, OpStrategy, PlacementList, @@ -101,7 +100,7 @@ def strategy_to_spec(strategy: object) -> object: op_schema, single_mesh_dim_strategies, input_index=len(op_schema.op._schema.returns), - inplace_op=_is_inplace_op(op_schema.op), + inplace_op=op_schema.is_inplace_op(), ) def wrapper(custom_sharding_fn): diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index 52de6cebe684a6..7bdfa768cf55b8 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -10,9 +10,9 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpSchema, + OpSpec, OutputSharding, OutputSpecType, - PlacementStrategy, ) from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle @@ -69,8 +69,8 @@ def tensor_parallel_transformation( class _TensorParallelTransformPass(PassBase): """ This pass is responsible for transforming a single-device graph into a tensor parallel - graph. It will mark the placement strategy of each node in the graph, - partition the graph into distributed graph, then shard the parameters/buffers accordingly. + graph. It will mark the OpSpec of each node in the graph, partition the graph into + distributed graph, then shard the parameters/buffers accordingly. """ def __init__( @@ -132,11 +132,11 @@ def _mark_tensor_parallel_shardings( graph_signature: ExportGraphSignature, mesh: DeviceMesh, parameter_placements: dict[str, Placement], -) -> dict[Node, PlacementStrategy]: +) -> dict[Node, OpSpec]: """ Mark the placement strategies of the parameter and buffer placeholder nodes. """ - placement_strategies: dict[Node, PlacementStrategy] = {} + placement_strategies: dict[Node, OpSpec] = {} num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len( graph_signature.inputs_to_buffers ) @@ -184,17 +184,15 @@ def _mark_sharding( graph_signature: ExportGraphSignature, mesh: DeviceMesh, parameter_placements: dict[str, Placement], -) -> dict[Node, PlacementStrategy]: +) -> dict[Node, OpSpec]: """ Mark the sharding strategy for each node in the graph module. """ - placement_strategies: dict[Node, PlacementStrategy] = ( - _mark_tensor_parallel_shardings( - gm, - graph_signature, - mesh, - parameter_placements, - ) + placement_strategies: dict[Node, OpSpec] = _mark_tensor_parallel_shardings( + gm, + graph_signature, + mesh, + parameter_placements, ) for node in gm.graph.nodes: @@ -238,7 +236,7 @@ def _mark_sharding( output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding( # type: ignore[assignment] op_schema, ) - placement_strategies[node] = PlacementStrategy( + placement_strategies[node] = OpSpec( output_specs=_get_output_spec_from_output_sharding(output_sharding), input_specs=output_sharding.redistribute_schema.args_spec if output_sharding.redistribute_schema is not None @@ -273,11 +271,11 @@ def _create_placement_strategy( mesh: DeviceMesh, placements: tuple[Placement, ...], input_specs: Optional[Sequence[DTensorSpec]] = None, -) -> PlacementStrategy: +) -> OpSpec: """ - Util function to construct a placement strategy for a given node. + Util function to construct an OpSpec for a given node. """ - placement = PlacementStrategy( + placement = OpSpec( input_specs=input_specs, output_specs=DTensorSpec( mesh=mesh, @@ -491,7 +489,7 @@ def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None: def _get_input_node_specs( - node: Node, placement_strategies: dict[Node, PlacementStrategy] + node: Node, placement_strategies: dict[Node, OpSpec] ) -> tuple[DTensorSpec, ...]: """ Get the input specs of a node. @@ -507,9 +505,7 @@ def _get_input_node_specs( return tuple(input_specs_list) -def _get_op_schema( - node: Node, placement_strategies: dict[Node, PlacementStrategy] -) -> OpSchema: +def _get_op_schema(node: Node, placement_strategies: dict[Node, OpSpec]) -> OpSchema: """ Util function to construct the operator schema of a node. """ @@ -526,14 +522,14 @@ def _get_op_schema( def _shard_state_dict( state_dict: dict[str, torch.Tensor], - placement_strategies: dict[Node, PlacementStrategy], + placement_strategies: dict[Node, OpSpec], graph_signature: ExportGraphSignature, mesh: DeviceMesh, ) -> None: """ - Inplace partition the weights based on the placement strategy + Inplace partition the weights based on the OpSpec """ - for node, placement_strategy in placement_strategies.items(): + for node, op_spec in placement_strategies.items(): if node.op != "placeholder": continue if node.name in graph_signature.inputs_to_parameters: @@ -548,7 +544,7 @@ def _shard_state_dict( dtensor_param = distribute_tensor( original_param, mesh, - placement_strategy.output_spec.placements, + op_spec.output_spec.placements, ) local_param = dtensor_param.to_local() state_dict[fqn] = ( diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index ea578239960e50..81c005000a8553 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -88,39 +88,53 @@ def parallelize_module( # type: ignore[return] return parallelize_plan._apply(module, device_mesh) elif isinstance(parallelize_plan, dict): for module_path, parallelize_style in parallelize_plan.items(): + if module_path == "": + # shortcut: empty string means to apply the plan to the current module + parallelize_module(module, device_mesh, parallelize_style) + continue + path_splits = module_path.split(".") - if len(path_splits) == 0: - raise ValueError( - "Expect module path to be non-empty, but got empty string!" - ) - while path_splits: - atom = path_splits.pop(0) - matched_children = filter( + # Instead of blindly popping tokens, first check the match, + # we only consume/pop the token if we found a match. + token = path_splits[0] + + matched_children = list( + filter( # `t[0]` is child name - lambda t: fnmatch(t[0], atom), + lambda t: fnmatch(t[0], token), module.named_children(), ) - # apply the plan to all matched submodules - for _, submodule in matched_children: - if path_splits: - # we haven't reached the leaf, apply in dict style - leaf_path = ".".join( - path_splits - ) # rest of the path after `atom` - parallelize_module( - submodule, - device_mesh, - {leaf_path: parallelize_style}, - src_data_rank=src_data_rank, - ) - else: - # otherwise, directly apply style to this submodule - parallelize_module( - submodule, - device_mesh, - parallelize_style, - src_data_rank=src_data_rank, - ) + ) + if not matched_children: + # No match at this level. Log a warning and process next plan entry. + warnings.warn( + f"Parallelize plan key '{module_path}' could not be resolved: " + f"no submodule matching token '{token}' in module {module}, " + f"skipping this plan entry." + ) + continue + + # Now that we have a match, we can consume the token. + path_splits.pop(0) + # apply the plan to all matched submodules + for _, submodule in matched_children: + if path_splits: + # we haven't reached the leaf, apply in dict style + leaf_path = ".".join(path_splits) # rest of the path after `token` + parallelize_module( + submodule, + device_mesh, + {leaf_path: parallelize_style}, + src_data_rank=src_data_rank, + ) + else: + # otherwise, directly apply style to this submodule + parallelize_module( + submodule, + device_mesh, + parallelize_style, + src_data_rank=src_data_rank, + ) return module else: raise TypeError( # pyre-ignore[7] diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 443a3375f21cfc..a8fdd7bec1ac9d 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -41,8 +41,10 @@ def is_shard(self, dim: Optional[int] = None) -> bool: def is_replicate(self) -> bool: return isinstance(self, Replicate) - def is_partial(self) -> bool: - return isinstance(self, Partial) + def is_partial(self, reduce_op: Optional[str] = None) -> bool: + if reduce_op is None: + return isinstance(self, Partial) + return isinstance(self, Partial) and self.reduce_op == reduce_op @dataclass(frozen=True) diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index 7f275fe8d6f3ea..ab8d340bd79310 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Union + import torch from torch import Tensor from torch.distributions.distribution import Distribution @@ -55,7 +57,7 @@ def entropy(self): """ Method to compute the entropy using Bregman divergence of the log normalizer. """ - result = -self._mean_carrier_measure + result: Union[Tensor, float] = -self._mean_carrier_measure nparams = [p.detach().requires_grad_() for p in self._natural_params] lg_normal = self._log_normalizer(*nparams) gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index d5fbff877413a8..1724b586b5a762 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -170,7 +170,7 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) event_dim = len(self.event_shape) - log_prob = 0.0 + log_prob: Union[Tensor, float] = 0.0 y = value for transform in reversed(self.transforms): x = transform.inv(y) diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index b53c4721ffc719..8ebed81f493d19 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -1,16 +1,15 @@ -# mypy: allow-untyped-defs +from collections.abc import Sequence from functools import update_wrapper -from typing import Any, Callable, Generic, overload, Union -from typing_extensions import TypeVar +from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union import torch import torch.nn.functional as F -from torch import Tensor +from torch import SymInt, Tensor from torch.overrides import is_tensor_like -from torch.types import _Number, Number +from torch.types import _dtype, _Number, Device, Number -euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant +euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant __all__ = [ "broadcast_all", @@ -59,7 +58,11 @@ def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: return torch.broadcast_tensors(*values) -def _standard_normal(shape, dtype, device): +def _standard_normal( + shape: Sequence[Union[int, SymInt]], + dtype: Optional[_dtype], + device: Optional[Device], +) -> Tensor: if torch._C._get_tracing_state(): # [JIT WORKAROUND] lack of support for .normal_() return torch.normal( @@ -69,7 +72,7 @@ def _standard_normal(shape, dtype, device): return torch.empty(shape, dtype=dtype, device=device).normal_() -def _sum_rightmost(value, dim): +def _sum_rightmost(value: Tensor, dim: int) -> Tensor: r""" Sum out ``dim`` many rightmost dimensions of a given tensor. @@ -83,7 +86,7 @@ def _sum_rightmost(value, dim): return value.reshape(required_shape).sum(-1) -def logits_to_probs(logits, is_binary=False): +def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor: r""" Converts a tensor of logits into probabilities. Note that for the binary case, each value denotes log odds, whereas for the @@ -95,7 +98,7 @@ def logits_to_probs(logits, is_binary=False): return F.softmax(logits, dim=-1) -def clamp_probs(probs): +def clamp_probs(probs: Tensor) -> Tensor: """Clamps the probabilities to be in the open interval `(0, 1)`. The probabilities would be clamped between `eps` and `1 - eps`, @@ -121,7 +124,7 @@ def clamp_probs(probs): return probs.clamp(min=eps, max=1 - eps) -def probs_to_logits(probs, is_binary=False): +def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor: r""" Converts a tensor of probabilities into logits. For the binary case, this denotes the probability of occurrence of the event indexed by `1`. diff --git a/torch/export/__init__.py b/torch/export/__init__.py index d2b208ca19b667..6c3c2b6f937785 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -358,22 +358,24 @@ def save( import torch import io + class MyModule(torch.nn.Module): def forward(self, x): return x + 10 + ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file - torch.export.save(ep, 'exported_program.pt2') + torch.export.save(ep, "exported_program.pt2") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files - extra_files = {'foo.txt': b'bar'.decode('utf-8')} - torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files) + extra_files = {"foo.txt": b"bar".decode("utf-8")} + torch.export.save(ep, "exported_program.pt2", extra_files=extra_files) """ if not isinstance(ep, ExportedProgram): @@ -381,29 +383,15 @@ def forward(self, x): f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead." ) - from torch._export.serde.schema import SCHEMA_VERSION - from torch._export.serde.serialize import serialize, SerializedArtifact - - artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) - - if isinstance(f, (str, os.PathLike)): - f = os.fspath(f) - - with zipfile.ZipFile(f, "w") as zipf: - # Save every field in the SerializedArtifact to a file. - assert isinstance(artifact.exported_program, bytes) - zipf.writestr("serialized_exported_program.json", artifact.exported_program) - zipf.writestr("serialized_state_dict.pt", artifact.state_dict) - zipf.writestr("serialized_constants.pt", artifact.constants) - zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs) - - zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION))) + from torch.export.pt2_archive._package import package_pt2 - # Add extra files if provided - if extra_files: - for extra_file_name, content in extra_files.items(): - encoded_content = content.encode("utf-8") - zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) + package_pt2( + f, + exported_programs={"model": ep}, + extra_files=extra_files, + pickle_protocol=pickle_protocol, + opset_version=opset_version, + ) def load( @@ -441,18 +429,18 @@ def load( import io # Load ExportedProgram from file - ep = torch.export.load('exported_program.pt2') + ep = torch.export.load("exported_program.pt2") # Load ExportedProgram from io.BytesIO object - with open('exported_program.pt2', 'rb') as f: + with open("exported_program.pt2", "rb") as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. - extra_files = {'foo.txt': ''} # values will be replaced with data - ep = torch.export.load('exported_program.pt2', extra_files=extra_files) - print(extra_files['foo.txt']) + extra_files = {"foo.txt": ""} # values will be replaced with data + ep = torch.export.load("exported_program.pt2", extra_files=extra_files) + print(extra_files["foo.txt"]) print(ep(torch.randn(5))) """ if isinstance(f, (str, os.PathLike)): @@ -460,12 +448,36 @@ def load( extra_files = extra_files or {} + from torch.export.pt2_archive._package import load_pt2, PT2ArchiveContents + + try: + pt2_contents = load_pt2( + f, + expected_opset_version=expected_opset_version, + ) + except RuntimeError: + pt2_contents = PT2ArchiveContents({}, {}, {}) + + if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: + for k, v in pt2_contents.extra_files.items(): + extra_files[k] = v + + return pt2_contents.exported_programs["model"] + + # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) + warnings.warn( + "This version of file is deprecated. Please generate a new pt2 saved file." + ) with zipfile.ZipFile(f, "r") as zipf: # Check the version version = zipf.read("version").decode().split(".") - from torch._export.serde.schema import SCHEMA_VERSION + from torch._export.serde.schema import ( + SCHEMA_VERSION, # todo change archive version to schema version + ) - assert len(version) == len(SCHEMA_VERSION) + assert len(version) == len(SCHEMA_VERSION), ( + "Version in the saved file has incorrect length, double check if the file is generated by torch.export.save()" + ) if version[0] != str(SCHEMA_VERSION[0]): raise RuntimeError( f"Serialized version {version} does not match our current " @@ -564,24 +576,29 @@ def register_dataclass( import torch from dataclasses import dataclass + @dataclass class InputDataClass: feature: torch.Tensor bias: int + @dataclass class OutputDataClass: res: torch.Tensor + torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) + class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) - ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) + + ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) print(ep) """ diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 2c77df8ade0da8..9a9ed922c83e76 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -43,7 +43,7 @@ def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) continue res += f""" - File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}""" # type: ignore[index] + File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index] res += f"\n {stack[-1]['loc']}" return res @@ -327,12 +327,12 @@ def _log_expression_created( # We don't want to log all expression_created logs, only # the ones that are relevant to the # guards/propagate_real_tensor - self.expression_created_logs[ - metadata[key]["result_id"] - ] = ExpressionCreatedNode( - metadata[key]["result_id"], - metadata[key].get("argument_ids", []), - record, + self.expression_created_logs[metadata[key]["result_id"]] = ( + ExpressionCreatedNode( + metadata[key]["result_id"], + metadata[key].get("argument_ids", []), + record, + ) ) return @@ -374,10 +374,13 @@ def draft_export( capture_structured_log = CaptureStructuredTrace() - with torch._functorch.config.patch( - fake_tensor_propagate_real_tensors=True, - generate_fake_kernels_from_real_mismatches=True, - ), capture_structured_log: + with ( + torch._functorch.config.patch( + fake_tensor_propagate_real_tensors=True, + generate_fake_kernels_from_real_mismatches=True, + ), + capture_structured_log, + ): try: new_shapes = None ep = _export( @@ -424,10 +427,10 @@ def convert_dim_to_auto(dim: Any) -> Any: continue elif log_name == "propagate_real_tensors_provenance": - log_contents[ - "occurrences" - ] = capture_structured_log.log_record.get_log_count( - (log_name, log_contents) + log_contents["occurrences"] = ( + capture_structured_log.log_record.get_log_count( + (log_name, log_contents) + ) ) failure_type = FailureType.DATA_DEPENDENT_ERROR diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 470d591cd54e14..a1a529ee8a3b99 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -53,13 +53,13 @@ def _remove_effect_tokens_from_graph_helper( assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) if func == torch.ops.higher_order.call_torchbind: - custom_obj_meta = node.args[2].meta["val"] + custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] assert isinstance(custom_obj_meta, CustomObjArgument) if custom_obj_meta.fake_val: custom_obj = custom_obj_meta.fake_val - elif node.args[2].name in inputs_to_lifted_custom_objs: + elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] custom_obj = ep.constants[ - inputs_to_lifted_custom_objs[node.args[2].name] + inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] ] else: raise RuntimeError(f"Unable to find custom obj for node {node}") diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 74b564c9fccbea..df003403569ae1 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -26,9 +26,9 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: if user.op == "output": continue - assert ( - user.op == "call_function" and user.target == operator.getitem - ), f"Expected getitem node as user for {node}, instead got {user}" + assert user.op == "call_function" and user.target == operator.getitem, ( + f"Expected getitem node as user for {node}, instead got {user}" + ) getitem_users.update(list(user.users.keys())) return getitem_users @@ -63,9 +63,9 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: log.debug("Trying to remove pytrees for module call %s", curr_module_node) curr_module_users = list(curr_module_node.users.keys()) - assert ( - len(curr_module_users) == 1 - ), f"Expected only one user for module node, instead got {list(curr_module_users)}" + assert len(curr_module_users) == 1, ( + f"Expected only one user for module node, instead got {list(curr_module_users)}" + ) flatten_node = curr_module_users[0] assert ( flatten_node.op == "call_function" diff --git a/torch/export/_trace.py b/torch/export/_trace.py index b07c41b2ea9941..80f38908f44e9a 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -10,6 +10,7 @@ import warnings from contextlib import contextmanager, nullcontext from typing import Any, Callable, Optional, Union +from typing_extensions import TypeAlias import torch import torch._dynamo @@ -93,12 +94,9 @@ ShapeEnv, ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo -from torch.fx.graph_module import _get_attr from torch.utils._pytree import TreeSpec from torch.utils._sympy.value_ranges import ValueRangeError -from ._safeguard import AutogradStateOpsFailSafeguard -from ._wrapper_utils import _WrapperModule from .exported_program import ( _disable_prexisiting_fake_mode, ExportedProgram, @@ -112,6 +110,10 @@ log = logging.getLogger(__name__) +# Type alias for dynamic shapes specification +_DynamicShapesSpec: TypeAlias = Union[dict[str, Any], tuple[Any, ...], list[Any]] + + @dataclasses.dataclass class ExportDynamoConfig: """ @@ -269,9 +271,9 @@ def _extract_fake_inputs(gm, args, kwargs): if detected_fake_mode: if detected_shape_env: - assert ( - detected_shape_env is detected_fake_mode.shape_env - ), "Detected shape env does not match fake mode's shape env" + assert detected_shape_env is detected_fake_mode.shape_env, ( + "Detected shape env does not match fake mode's shape env" + ) fake_mode = detected_fake_mode elif detected_shape_env: fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True) @@ -840,23 +842,9 @@ def _export_to_aten_ir( transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, decomp_table=None, - _check_autograd_state: bool = True, - _is_torch_jit_trace: bool = False, _prettify_placeholder_names: bool = True, decompose_custom_triton_ops: bool = False, ) -> ATenExportArtifact: - # [NOTE] If the user is exporting under training mode, we want to detect if there is any - # state change in the autograd global state and error. If the user is exporting under inference - # mode, we don't care. At predispatch level, we don't care about the state change. - is_grad_enabled = torch._C.is_grad_enabled() - grad_safe_guard = nullcontext() - # export_to_aten_ir is called when we decompose the ep into inference IR - # In that setting, we actually shouldn't check the state change as at this point, - # because the intention is specalizing to inference. - if _check_autograd_state: - if not pre_dispatch and is_grad_enabled: - grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment] - custom_triton_ops_decomposition_ctx = ( nullcontext if decompose_custom_triton_ops @@ -865,13 +853,18 @@ def _export_to_aten_ir( # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. - with torch.nn.utils.stateless._reparametrize_module( - mod, - fake_params_buffers, - tie_weights=True, - strict=True, - stack_weights=True, - ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(): # type: ignore[attr-defined] + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + _ignore_backend_decomps(), + _compiling_state_context(), + custom_triton_ops_decomposition_ctx(), + ): gm, graph_signature = transform(aot_export_module)( mod, fake_args, @@ -1198,7 +1191,51 @@ def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]: return original_state_dict -def _process_export_inputs(mod, args, kwargs, dynamic_shapes): +def _process_export_inputs( + mod: torch.nn.Module, + args: tuple[object, ...], + kwargs: Optional[dict[str, object]], + dynamic_shapes: Optional[ + Union[ + _DynamicShapesSpec, + torch.export.AdditionalInputs, + torch.export.ShapesCollection, + ] + ], +) -> tuple[ + tuple[object, ...], + dict[str, object], + TreeSpec, + Optional[_DynamicShapesSpec], + Callable[[ExportedProgram], None], +]: + """ + Process and validate export inputs for the torch.export API. + + This function validates the input arguments, normalizes kwargs, computes input tree specs, + and handles special dynamic shapes cases like AdditionalInputs and ShapesCollection. + + Args: + mod: The PyTorch module to be exported. + args: Tuple of example positional inputs for the module. + kwargs: Optional dictionary of example keyword inputs. + dynamic_shapes: Optional specification for dynamic shapes. Can be: + - dict mapping argument names to dynamic shape specifications + - tuple/list specifying dynamic shapes for each input in order + - torch.export.AdditionalInputs object with verification callback + - torch.export.ShapesCollection object + + Returns: + A tuple containing: + - args: Validated tuple of positional inputs + - kwargs: Normalized dictionary of keyword inputs (empty dict if None was passed) + - original_in_spec: TreeSpec representing the flattened input structure + - dynamic_shapes: Processed dynamic shapes specification + - verify_additional_inputs: Callback function for additional input verification + + Raises: + UserError: If args is not a tuple. + """ if not isinstance(args, tuple): raise UserError( UserErrorType.INVALID_INPUT, @@ -1207,15 +1244,19 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes): kwargs = kwargs if kwargs is not None else {} _, original_in_spec = pytree.tree_flatten((args, kwargs)) + verify_additional_inputs: Callable[[ExportedProgram], None] + out_dynamic_shapes: Optional[_DynamicShapesSpec] if isinstance(dynamic_shapes, torch.export.AdditionalInputs): - verify_additional_inputs = dynamic_shapes.verify - dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + verify_additional_inputs = dynamic_shapes.verify # type: ignore[assignment] + out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment] else: verify_additional_inputs = lambda ep: None # noqa: E731 if isinstance(dynamic_shapes, torch.export.ShapesCollection): - dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment] + else: + out_dynamic_shapes = dynamic_shapes - return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs + return args, kwargs, original_in_spec, out_dynamic_shapes, verify_additional_inputs def _get_module_call_graph( @@ -1230,9 +1271,9 @@ def _get_module_call_graph( """ gm: torch.fx.GraphModule = export_artifact.aten.gm export_graph_signature: ExportGraphSignature = export_artifact.aten.sig - module_call_specs: dict[ - str, dict[str, TreeSpec] - ] = export_artifact.module_call_specs + module_call_specs: dict[str, dict[str, TreeSpec]] = ( + export_artifact.module_call_specs + ) in_spec: TreeSpec = export_artifact.in_spec out_spec: TreeSpec = export_artifact.out_spec @@ -1271,7 +1312,6 @@ def _get_range_constraints( args, kwargs, dynamic_shapes, - _is_torch_jit_trace=False, ): gm: torch.fx.GraphModule = export_artifact.aten.gm export_graph_signature: ExportGraphSignature = export_artifact.aten.sig @@ -1284,24 +1324,21 @@ def _get_range_constraints( ), len(export_graph_signature.input_specs), ) - combined_args = _combine_args( - mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace - ) + combined_args = _combine_args(mod, args, kwargs) - # This is because we trace based on the kewargs passed in from user + # This is because we trace based on the kwargs passed in from user # not based on the signature. I feel it would be better to just enforce # one ordering at the start of tracing to avoid confusions, but that is # bigger refactor, so do this to unblock for now. - if not _is_torch_jit_trace: - combined_args_traced_order = {} - for arg in combined_args: - if arg not in kwargs: - combined_args_traced_order[arg] = combined_args[arg] + combined_args_traced_order = {} + for arg in combined_args: + if arg not in kwargs: + combined_args_traced_order[arg] = combined_args[arg] - for key in kwargs: - combined_args_traced_order[key] = kwargs[key] + for key in kwargs: + combined_args_traced_order[key] = kwargs[key] - combined_args = combined_args_traced_order + combined_args = combined_args_traced_order range_constraints = make_constraints( fake_mode, @@ -1350,43 +1387,6 @@ def _temp_disable_texpr_fuser(): torch._C._jit_set_texpr_fuser_enabled(original_state) -def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): - with _temp_disable_texpr_fuser(): - from torch.jit._trace import TopLevelTracedModule - - export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) - - if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator] - return _export( - traced_callable, - export_args, - export_kwargs, - strict=False, - _is_torch_jit_trace=True, - ).module() - - elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( - traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator] - ): - with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] - return _export( - traced_callable.owner(), # type: ignore[operator] - export_args, - export_kwargs, - strict=False, - _is_torch_jit_trace=True, - ).module() - - else: - return _export( - _WrapperModule(traced_callable), - export_args, - export_kwargs, - strict=False, - _is_torch_jit_trace=True, - ).module() - - def _strict_export( mod: torch.nn.Module, args: tuple[Any, ...], @@ -1395,7 +1395,6 @@ def _strict_export( preserve_module_call_signature: tuple[str, ...], orig_in_spec: TreeSpec, allow_complex_guards_as_runtime_asserts: bool, - _is_torch_jit_trace: bool, _to_aten_func: Callable, ) -> ExportArtifact: """ @@ -1431,9 +1430,9 @@ def _strict_export( attr = getattr(gm_torch_level, node.target) # Checks if it is not a HigherOrderOp branch or a module if not isinstance(attr, torch.nn.Module): - assert ( - dynamo_fake_mode is not None - ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + assert dynamo_fake_mode is not None, ( + "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + ) node.meta["val"] = dynamo_fake_mode.from_tensor( attr, static_shapes=True ) @@ -1750,13 +1749,17 @@ def _is_impure(node): # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. - with torch.nn.utils.stateless._reparametrize_module( - mod, - fake_params_buffers, - tie_weights=True, - strict=True, - stack_weights=True, - ), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + _ignore_backend_decomps(), + _compiling_state_context(), + ): gm, graph_signature = transform(_make_fx_helper)( mod, fake_args, @@ -1809,7 +1812,6 @@ def set_missing_meta_vals(gm, flat_args, num_params_buffers): # need to have their metadata set before lifting them because it is needed # for computing the exported program's signature. index = 0 - fake_mode = detect_fake_mode(flat_args) for node in gm.graph.nodes: if node.op == "placeholder": if index >= num_params_buffers: @@ -1817,16 +1819,6 @@ def set_missing_meta_vals(gm, flat_args, num_params_buffers): if not isinstance(user_arg, torch.Tensor): node.meta["val"] = user_arg index += 1 - if node.op == "get_attr": - val = _get_attr(gm, node.target) - if isinstance(val, torch.Tensor): - assert "val" not in node.meta, ( - f"Found attribute {node.target} that has already been fakified " - "but not yet lifted as an input. This should be impossible because " - "(1) we should have already fakified AND lifted params/buffers " - "(2) we should have NOT yet fakified OR lifted tensor constants. " - ) - node.meta["val"] = fake_mode.from_tensor(val, static_shapes=True) def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node: @@ -1841,7 +1833,6 @@ def _non_strict_export( preserve_module_call_signature: tuple[str, ...], orig_in_spec: TreeSpec, allow_complex_guards_as_runtime_asserts: bool, - _is_torch_jit_trace: bool, _to_aten_func: Callable, ) -> ExportArtifact: """ @@ -1935,7 +1926,6 @@ def forward(self, *args, **kwargs): args, kwargs, dynamic_shapes, - _is_torch_jit_trace=_is_torch_jit_trace, allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization ) @@ -1948,7 +1938,6 @@ def _produce_guards_callback(gm): dynamic_shapes=dynamic_shapes, equalities_inputs=equalities_inputs, original_signature=original_signature, - _is_torch_jit_trace=_is_torch_jit_trace, ) tx = TracingContext(fake_mode) @@ -1956,22 +1945,27 @@ def _produce_guards_callback(gm): # We also need to attach dynamo configs as these will be used in HOOs that # use torch.compile, like cond dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) - dynamo_config[ - "do_not_emit_runtime_asserts" - ] = False # We want to emit runtime asserts - - with fake_mode, _NonStrictTorchFunctionHandler(), tracing( - tx - ), torch._dynamo.config.patch(dynamo_config): - with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( - patched_mod, - new_fake_args, - new_fake_kwargs, - new_fake_constant_attrs, - map_fake_to_real, - ), _fakify_module_inputs( - fake_args, fake_kwargs, fake_mode - ), _override_builtin_ops(): + dynamo_config["do_not_emit_runtime_asserts"] = ( + False # We want to emit runtime asserts + ) + + with ( + fake_mode, + _NonStrictTorchFunctionHandler(), + tracing(tx), + torch._dynamo.config.patch(dynamo_config), + ): + with ( + _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ), + _fakify_module_inputs(fake_args, fake_kwargs, fake_mode), + _override_builtin_ops(), + ): aten_export_artifact = _to_aten_func( # type: ignore[operator] patched_mod, new_fake_args, @@ -2038,7 +2032,6 @@ def _export_for_training( preserve_module_call_signature=preserve_module_call_signature, orig_in_spec=orig_in_spec, allow_complex_guards_as_runtime_asserts=False, - _is_torch_jit_trace=False, _to_aten_func=_export_to_aten_ir_make_fx, ) @@ -2101,14 +2094,13 @@ def _export( preserve_module_call_signature: tuple[str, ...] = (), pre_dispatch: bool = False, allow_complex_guards_as_runtime_asserts: bool = False, - _is_torch_jit_trace: bool = False, ) -> ExportedProgram: """ Traces either an nn.Module's forward function or just a callable with PyTorch operations inside and produce a ExportedProgram. Args: - f: the `nn.Module` to trace. + mod: the `nn.Module` to trace. args: example positional inputs. @@ -2144,7 +2136,7 @@ def _export( while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. Returns: - An ExportedProgram containing the traced method. + An ExportedProgram containing the traced module. """ from torch._utils_internal import export_training_ir_rollout_check @@ -2201,18 +2193,14 @@ def _export( preserve_module_call_signature=preserve_module_call_signature, orig_in_spec=original_in_spec, allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, - _is_torch_jit_trace=_is_torch_jit_trace, _to_aten_func=functools.partial( _export_to_aten_ir, pre_dispatch=pre_dispatch, - _is_torch_jit_trace=_is_torch_jit_trace, ), ) export_graph_signature: ExportGraphSignature = export_artifact.aten.sig - forward_arg_names = ( - _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None - ) + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) inline_constraints = _get_inline_constraints(export_artifact.fake_mode) # The unbacked symint symbols are updated in aot_export # so we serialize them here instead of inside dynamo. @@ -2224,7 +2212,6 @@ def _export( args, kwargs, dynamic_shapes, - _is_torch_jit_trace=_is_torch_jit_trace, ) gm, module_call_graph = _get_module_call_graph( export_artifact, @@ -2235,8 +2222,7 @@ def _export( _verify_nn_module_stack(gm) _verify_stack_trace(gm) - if not _is_torch_jit_trace: - _verify_placeholder_names(gm, export_graph_signature) + _verify_placeholder_names(gm, export_graph_signature) # Remove Proxy because they cannot be deepcopied or pickled. torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index e51c12800ad9bf..996d6830135ed9 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -16,6 +16,7 @@ from torch.export.unflatten import _assign_attr, _AttrKind from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.traceback import NodeSource, NodeSourceAction from ._remove_effect_tokens_pass import _remove_effect_tokens from ._tree_utils import reorder_kwargs @@ -115,6 +116,13 @@ def _unlift_inputs_as_getattr( metadata = input_node.meta gm.graph.erase_node(input_node) getattr_node.meta = metadata + getattr_node.meta["from_node"] = [ + NodeSource( + input_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] unlifted_name_to_node[lifted_node] = getattr_node return unlifted_name_to_node, input_name_to_node @@ -172,6 +180,13 @@ def _insert_copy_for_mutations( gm.graph.erase_node(output_node) new_output.name = output_node.name new_output.meta.update(output_node.meta) + new_output.meta["from_node"] = [ + NodeSource( + output_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] def _get_codegen( @@ -446,6 +461,11 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu for out_spec in ep.graph_signature.output_specs ] + for node in new_gm.graph.nodes: + node.meta["from_node"] = [ + NodeSource(node, "ExportedProgram.module()", NodeSourceAction.CREATE) + ] + new_gm = _unlift( new_gm, lifted_inputs, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 3772fc1c72aa73..f951b5818afd1b 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -85,18 +85,61 @@ def __call__(self, min=None, max=None) -> "_DimHint": class Dim: """ - :func:`Dim` constructs a type analogous to a named symbolic integer with a range. - It can be used to describe multiple possible values of a dynamic tensor dimension. - Note that different dynamic dimensions of the same tensor, or of different tensors, - can be described by the same type. + The `Dim` class allows users to specify dynamism in their exported programs. By marking a dimension with a `Dim`, + the compiler associates the dimension with a symbolic integer containing a dynamic range. - Args: - name (str): Human-readable name for debugging. - min (Optional[int]): Minimum possible value of given symbol (inclusive) - max (Optional[int]): Maximum possible value of given symbol (inclusive) + The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: `Dim.AUTO`, `Dim.DYNAMIC`, `Dim.STATIC`), + or named Dims (i.e. `Dim("name", min=1, max=2)`). + + Dim hints provide the lowest barrier to exportability, with the user only needing to specify if a dimension + if dynamic, static, or left for the compiler to decide (`Dim.AUTO`). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between dimensions. + + Example:: + + class Foo(nn.Module): + def forward(self, x, y): + assert x.shape[0] == 4 + assert y.shape[0] >= 16 + return x @ y + + + x = torch.randn(4, 8) + y = torch.randn(8, 16) + dynamic_shapes = { + "x": {0: Dim.AUTO, 1: Dim.AUTO}, + "y": {0: Dim.AUTO, 1: Dim.AUTO}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Here, export would raise an exception if we replaced all uses of `Dim.AUTO` with `Dim.DYNAMIC`, + as x.shape[0] is constrained to be static by the model. + + More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, + e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints. + + You may also specify min-max bounds for Dim hints, e.g. `Dim.AUTO(min=16, max=32)`, `Dim.DYNAMIC(max=64)`, + with the compiler inferring the remaining constraints within the ranges. An exception will be raised if + the valid range is entirely outside the user-specified range. + + Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler + infers constraints that do not match the user specification. For example, exporting the previous + model, the user would need the following `dynamic_shapes` argument:: + + s0 = Dim("s0") + s1 = Dim("s1", min=16) + dynamic_shapes = { + "x": {0: 4, 1: s0}, + "y": {0: s0, 1: s1}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. + For example, the following indicates one dimension is a multiple of another plus 4:: + + s0 = Dim("s0") + s1 = 3 * s0 + 4 - Returns: - A type that can be used in dynamic shape specifications for tensors. """ AUTO = _DimHint.AUTO() @@ -199,11 +242,11 @@ def __init__(self, value: int): self.value = value @property - def min(self): + def min(self): # type: ignore[override] return self.value # type: ignore[attr-defined] @property - def max(self): + def max(self): # type: ignore[override] return self.value # type: ignore[attr-defined] @@ -229,7 +272,7 @@ def __init__(self, name: str, root: Dim, fn: Callable): self.fn = fn @property - def min(self): + def min(self): # type: ignore[override] # assume that self.fn is an increasing function # TODO(avik): use sympy value range analysis instead? from sympy import Integer @@ -249,7 +292,7 @@ def min(self): return int(_min_symint) @property - def max(self): + def max(self): # type: ignore[override] # assume that self.fn is an increasing function # TODO(avik): use sympy value range analysis instead? from sympy import Integer @@ -639,20 +682,19 @@ def _compare(tree, dynamic_shapes, path): raise -def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> dict[str, Any]: +def _combine_args(f, args, kwargs) -> dict[str, Any]: # combine args and kwargs following the signature of f, as it happens # in the body of f when called with *args, **kwargs if isinstance(f, ExportedProgram): f = f.module() - if not _is_torch_jit_trace: - signature = ( - inspect.signature(f.forward) - if isinstance(f, torch.nn.Module) - else inspect.signature(f) - ) - kwargs = kwargs if kwargs is not None else {} - return signature.bind(*args, **kwargs).arguments - return args + + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments class ShapesCollection: @@ -666,7 +708,7 @@ class ShapesCollection: Example:: - args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) + args = {"x": tensor_x, "others": [tensor_y, tensor_z]} dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() @@ -682,7 +724,7 @@ class ShapesCollection: Example:: - args = ({"x": tensor_x, "others": [int_x, int_y]}) + args = {"x": tensor_x, "others": [int_x, int_y]} # Wrap all ints with _IntWrapper mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) @@ -700,18 +742,18 @@ def __init__(self): self._shapes = {} def __setitem__(self, t, shape): - assert isinstance( - t, (torch.Tensor, _IntWrapper) - ), f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + assert isinstance(t, (torch.Tensor, _IntWrapper)), ( + f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + ) # TODO(avik): check that shape is indeed a Shape t_id = id(t) if t_id in self._shapes: _shape = self._shapes[t_id] - assert ( - shape == _shape - ), f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + assert shape == _shape, ( + f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + ) else: self._shapes[id(t)] = shape @@ -766,7 +808,7 @@ class AdditionalInputs: Example:: - args0, kwargs0 = ... # example inputs for export + args0, kwargs0 = ... # example inputs for export # other representative inputs that the exported program will run on dynamic_shapes = torch.export.AdditionalInputs() @@ -786,9 +828,9 @@ def add(self, args, kwargs=None): """ assert type(args) is tuple, f"Representative args {args} must be a tuple" - assert ( - kwargs is None or type(kwargs) is dict - ), f"Representative kwargs {kwargs} must be None or a dict" + assert kwargs is None or type(kwargs) is dict, ( + f"Representative kwargs {kwargs} must be None or a dict" + ) self._examples.append((args, kwargs)) def dynamic_shapes(self, m, args, kwargs=None): @@ -1075,7 +1117,8 @@ def root_value(): i, dim.__name__, StrictMinMaxConstraint( - vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] + vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined] + warn_only=False, ), ) else: @@ -1085,7 +1128,8 @@ def root_value(): i, dim.__name__, StrictMinMaxConstraint( - vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined] + vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined] + warn_only=False, ), ) return constraint @@ -1161,7 +1205,7 @@ def assoc_shape(path, t, dynamic_shape): def _get_dim_name_mapping( - dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], ): name_to_dim = {} for dim in tree_flatten( diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 99fc0d6995cb2e..2d8569bd6ae92b 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -137,16 +137,11 @@ class _ExportPackage: "decoder": ExportMethod( overloads={ "prefill": ExportedProgram(...), - "decode": ExportedProgram(...) + "decode": ExportedProgram(...), }, - fallbacks=[] + fallbacks=[], ), - "encoder": ExportMethod( - overloads={}, - fallbacks=[ - ExportedProgram(...) - ] - ) + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), }, ) ``` @@ -212,15 +207,18 @@ def _exporter( ``` package = ExportPackage() + def prefill(x, xa, kv_cache): assert x.shape[1] == 3 assert kv_cache == {} + def decode(x, xa, kv_cache): assert x.shape[1] > 1 assert len(kv_cache) > 0 return {...} # dynamic shape specs here + exporter = ( package.exporter(decoder) .define_overload("prefill", prefill) @@ -326,3 +324,27 @@ def _define_overload( _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] return _exporter_context + + @property + def _method_overloads( + self, + ) -> typing.Iterator[tuple[str, torch.export.ExportedProgram]]: + for method, method_data in self.methods.items(): + for overload, ep in method_data.overloads.items(): + yield f"{method}:{overload}", ep + + def _compiled_and_package(self, f: torch.types.FileLike) -> None: + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": True, + "always_keep_tensor_constants": True, + "aot_inductor.package_constants_in_so": False, + } + weights_map = {} + for name, ep in self._method_overloads: + weights = torch._inductor.aot_compile(ep.module(), (), options=options) # type: ignore[arg-type] + weights_map[name] = weights + torch._inductor.package.package.package_aoti( + f, + weights_map, # type: ignore[arg-type] + ) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 3a99135392311a..4aff43fb81558b 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -7,7 +7,7 @@ import operator import types import warnings -from collections import namedtuple +from collections import defaultdict, namedtuple from collections.abc import Iterator from contextlib import contextmanager from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union @@ -272,7 +272,7 @@ def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs): def _split_decomp_table_to_cia_and_python_decomp( - decomp_table: dict[torch._ops.OperatorBase, Callable] + decomp_table: dict[torch._ops.OperatorBase, Callable], ) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) cia_ops_to_callable = {} @@ -443,9 +443,14 @@ def _is_joint_ir_decomp(ep, joint_loss_index): tx = TracingContext(fake_mode) - with fake_mode, _override_composite_implicit_decomp( - cia_to_decomp, - ), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx): + with ( + fake_mode, + _override_composite_implicit_decomp( + cia_to_decomp, + ), + _enable_graph_inputs_of_type_nn_module(ep.example_inputs), + tracing(tx), + ): retracing_args_unwrapped = pytree.tree_unflatten( retracing_args, mod._in_spec ) @@ -472,7 +477,6 @@ def _is_joint_ir_decomp(ep, joint_loss_index): fake_params_buffers, new_fake_constant_attrs, decomp_table=python_decomp_table, - _check_autograd_state=False, _prettify_placeholder_names=False, decompose_custom_triton_ops=decompose_custom_triton_ops, ) @@ -567,15 +571,18 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # TODO(zhxhchen17) Return the new graph_signature directly. fake_mode = detect_fake_mode(fake_args) - fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode + fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment] custom_triton_ops_decomposition_ctx = ( contextlib.nullcontext if decompose_custom_triton_ops else _disable_custom_triton_op_functional_decomposition ) - with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( - cia_to_decomp - ), custom_triton_ops_decomposition_ctx(): + with ( + _ignore_backend_decomps(), + fake_mode, + _override_composite_implicit_decomp(cia_to_decomp), + custom_triton_ops_decomposition_ctx(), + ): gm, graph_signature = aot_export_module( ep.graph_module, fake_args, @@ -816,25 +823,118 @@ def _common_getitem_elimination_pass( def _get_updated_module_call_graph( + old_gm: torch.fx.GraphModule, + old_graph_signature: ExportGraphSignature, gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, old_module_call_graph: list[ModuleCallEntry], ): new_module_call_graph = copy.deepcopy(old_module_call_graph) + old_nodes = {node.name: node for node in old_gm.graph.nodes} + + old_graph_params_buffers = { + **old_graph_signature.inputs_to_parameters, + **old_graph_signature.inputs_to_buffers, + } + new_graph_params_buffers = { + **graph_signature.inputs_to_parameters, + **graph_signature.inputs_to_buffers, + } + # use node-level provenance metadata to create a map # from old node names to new node names provenance: dict[str, str] = {} + + user_input_counter = 0 + old_user_input_names = [ + node.target for node in old_gm.graph.nodes if node.op == "placeholder" + ] + old_user_input_names = list( + filter( + lambda x: x not in old_graph_params_buffers + and x not in old_graph_signature.input_tokens, + old_user_input_names, + ) + ) + new_user_input_names = [ + node.target for node in gm.graph.nodes if node.op == "placeholder" + ] + for node in gm.graph.nodes: if history := node.meta.get("from_node", []): provenance[history[-1].name] = node.name + # For params and buffers, we might have applied parameterizaiton rule + # so that the names might have changed. But for user inputs, we know we + # must preserve the old name. + elif node.op == "placeholder": + if not ( + node.target in new_graph_params_buffers + or node.target in graph_signature.input_tokens + ): + if node.target in new_user_input_names: + assert isinstance(node.name, str) + old_name = old_user_input_names[user_input_counter] + assert isinstance(old_name, str) + provenance[old_name] = node.name + user_input_counter += 1 + + # For all the parameters and buffers, we first see + # if they are result of paramerizaitons and if they + # are, we log them and error later + old_param_to_desugared = defaultdict(list) + for name, target in new_graph_params_buffers.items(): + # if the parameters are not parametrized, the naming won't change. + if not target.startswith("parametrizations."): + # If we are in strict mode, we can't just reuse the param names + if name in old_graph_params_buffers: + provenance[name] = name + else: + old_target = ".".join(target.split(".")[1:-1]) + old_param_to_desugared[old_target].append(name) + # map old names to new names in module call signatures for entry in new_module_call_graph: signature = entry.signature if signature is None: continue for x in [*signature.inputs, *signature.outputs]: - x.name = provenance.get(x.name, x.name) + # We noticed that submodule is taking subclass as input. we can't + # preserve signature here. + if x.name in old_param_to_desugared: + raise ValueError( + f"It looks like {x.name} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + if x.name in provenance: + x.name = provenance[x.name] + + # This can happen when aten.to is called at graph boundaries. + # Basically aten.to at post-dispatch level can either be copy + # or alias. In the alias case, we will no-op it so it will + # disappear from the graph. If we detect such case, we should + # reuse the input to aten.to as the new input to the submodule. + # Technically this can happen for other maybe aliasing ops, + # but aten.to is probably the most common one. + elif x.name in old_nodes: + old_node = old_nodes[x.name] + if old_node.op == "call_function" and old_node.target in [ + torch.ops.aten.to.dtype_layout, + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + ]: + old_target = old_node.args[0].name + if old_target not in provenance: + raise ValueError( + f"It looks like {old_target} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + x.name = provenance[old_target] return new_module_call_graph @@ -864,7 +964,10 @@ def _decompose_exported_program( # new nodes due to decompositions. So we need to update these signatures # in the decomposed exported program's module_call_graph. new_module_call_graph = _get_updated_module_call_graph( + ep.graph_module, + ep.graph_signature, gm, + new_graph_signature, ep.module_call_graph, ) @@ -907,6 +1010,30 @@ class ExportedProgram: again to construct a correct ExportedProgram. """ + _graph_module: torch.fx.GraphModule + """The underlying GraphModule containing the exported computation graph.""" + + _graph_signature: ExportGraphSignature + """The signature containing input/output specifications for the graph.""" + + _state_dict: dict[str, Any] + """Dictionary containing parameter and buffer values from the original module.""" + + _range_constraints: "dict[sympy.Symbol, ValueRanges]" + """Symbolic shape constraints for dynamic shapes in the graph.""" + + _module_call_graph: list[ModuleCallEntry] + """Call graph information tracking module hierarchy and signatures.""" + + _example_inputs: Optional[tuple[tuple[Any, ...], dict[str, Any]]] + """Example inputs used during export, stored as (args, kwargs) tuple.""" + + _constants: dict[str, _ConstantAttributeType] + """Dictionary of constant values used in the graph.""" + + _verifiers: list[type[Verifier]] + """List of verifier classes used to validate the exported program.""" + def __init__( self, root: Union[torch.nn.Module, dict[str, Any]], @@ -1418,9 +1545,9 @@ def _get_updated_graph_signature( if node.op != "placeholder": break - assert i < len( - old_signature.input_specs - ), "Number of inputs changed after transformation" + assert i < len(old_signature.input_specs), ( + "Number of inputs changed after transformation" + ) old_input_spec = old_signature.input_specs[i] arg = ( old_input_spec.arg @@ -1443,9 +1570,9 @@ def _get_updated_graph_signature( new_output_specs = [] for i, node in enumerate(output_node.args[0]): - assert i < len( - old_signature.output_specs - ), "Number of outputs changed after transformation" + assert i < len(old_signature.output_specs), ( + "Number of outputs changed after transformation" + ) old_output_spec = old_signature.output_specs[i] arg = ( old_output_spec.arg @@ -1503,9 +1630,9 @@ def validate(self): # TODO: remove this @final def _validate(self): - assert ( - len(self.verifiers) > 0 - ), "ExportedProgram must have at least one verifier." + assert len(self.verifiers) > 0, ( + "ExportedProgram must have at least one verifier." + ) for v in self.verifiers: v().check(self) diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index d3c4e07b09c175..830b4b784860a4 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -95,9 +95,9 @@ class InputSpec: def __post_init__(self): if self.kind == InputKind.BUFFER: - assert ( - self.persistent is not None - ), "Failed to specify persistent flag on BUFFER." + assert self.persistent is not None, ( + "Failed to specify persistent flag on BUFFER." + ) assert isinstance( self.arg, ( @@ -187,48 +187,85 @@ def __init__(self) -> None: self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers - self.register_buffer('my_buffer1', torch.tensor(3.0)) - self.register_buffer('my_buffer2', torch.tensor(4.0)) + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method - output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) - self.my_buffer2.add_(1.0) # In-place addition + self.my_buffer2.add_(1.0) # In-place addition return output - Resulting Graph would be:: + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + + Resulting Graph is non-functional:: graph(): - %arg0_1 := placeholder[target=arg0_1] - %arg1_1 := placeholder[target=arg1_1] - %arg2_1 := placeholder[target=arg2_1] - %arg3_1 := placeholder[target=arg3_1] - %arg4_1 := placeholder[target=arg4_1] - %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) - %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) - %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) - %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) - %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) - return (add_tensor_2, add_tensor_1) - - Resulting ExportGraphSignature would be:: - - ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='my_parameter'), - InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), - InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), - InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None), - InputSpec(kind=, arg=TensorArgument(name='arg4_1'), target=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='add_2'), target='my_buffer2'), - OutputSpec(kind=, arg=TensorArgument(name='add_1'), target=None) - ] - ) + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_1,) + + Resulting ExportGraphSignature of the non-functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_1: USER_OUTPUT + + To get a functional Graph, you can use :func:`run_decompositions`:: + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + ep = ep.run_decompositions() + + Resulting Graph is functional:: + + graph(): + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_2, add_1) + + Resulting ExportGraphSignature of the functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_2: BUFFER_MUTATION target='my_buffer2' + add_1: USER_OUTPUT + """ input_specs: list[InputSpec] @@ -520,9 +557,9 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec: # For const outputs we just directly return this return ConstantArgument(name="", value=node) - assert ( - "val" in node.meta - ), f"{node} is not a constant or a node with a 'val' metadata field" + assert "val" in node.meta, ( + f"{node} is not a constant or a node with a 'val' metadata field" + ) val = node.meta["val"] if node.name in token_names: return TokenArgument(name=node.name) @@ -565,9 +602,21 @@ def _convert_to_export_graph_signature( user_outputs = set(graph_signature.user_outputs) buffer_mutations = graph_signature.buffers_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate - grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr] - grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr] - loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr] + grad_params = ( + graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr] + if is_joint + else {} + ) + grad_user_inputs = ( + graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr] + if is_joint + else {} + ) + loss_output = ( + graph_signature.backward_signature.loss_output # type: ignore[union-attr] + if is_joint + else None + ) input_tokens = graph_signature.input_tokens output_tokens = graph_signature.output_tokens diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 9f43b02ab0cddd..7c97e6abe171cc 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -1,22 +1,52 @@ -# pyre-unsafe - import glob import io +import json import logging import os +import tempfile import zipfile -from typing import Any, Union +from dataclasses import dataclass +from typing import Any, IO, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias import torch +import torch.utils._pytree as pytree +from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ExportedProgram +from torch.export.pt2_archive._package_weights import ( + get_complete, + group_weights, + Weights, +) from torch.export.pt2_archive.constants import ( + AOTINDUCTOR_DIR, ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE, ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE, + CONSTANTS_DIR, + CUSTOM_OBJ_FILENAME_PREFIX, + EXTRA_DIR, + MODELS_DIR, + MODELS_FILENAME_FORMAT, + SAMPLE_INPUTS_FILENAME_FORMAT, + WEIGHT_FILENAME_PREFIX, + WEIGHTS_DIR, ) from torch.types import FileLike +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + +DEFAULT_PICKLE_PROTOCOL = 2 +AOTI_FILES: TypeAlias = Union[ + list[Union[str, Weights]], dict[str, list[Union[str, Weights]]] +] + + logger: logging.Logger = logging.getLogger(__name__) @@ -140,9 +170,9 @@ class PT2ArchiveReader: def __init__(self, archive_path_or_buffer: FileLike): self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] - assert ( - self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE - ), "Invalid archive format" + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) def __enter__(self) -> "PT2ArchiveReader": return self @@ -178,3 +208,478 @@ def archive_version(self) -> int: archive_version = "0" return int(archive_version) + + def get_file_names(self) -> list[str]: + """ + Get the file names in the archive. + """ + return self.archive_file.get_all_records() + + +def _package_aoti_files( + archive_writer: PT2ArchiveWriter, + aoti_files: Optional[AOTI_FILES], + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if aoti_files is None: + return + + if isinstance(aoti_files, list): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + + all_weights: dict[str, Weights] = {} # model_name -> weight + weights_configs: dict[ + str, dict[str, Any] + ] = {} # model_name -> (weight_name -> (filename, shape, stride, offset)) + + for model_name, files in aoti_files.items(): + num_so_files = 0 + weights_configs[model_name] = {} + + for file in files: + if file == "": + continue + + if isinstance(file, Weights): + all_weights[model_name] = file + continue + + if file.endswith(".so"): + num_so_files += 1 + if num_so_files > 1: + raise RuntimeError( + f"Multiple .so files found in {files}. " + "You might need to clear your cache " + "directory before calling aoti_compile again." + ) + + filename = os.path.basename(file) + if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + new_filepath = os.path.join(CONSTANTS_DIR, filename) + else: + new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) + logger.debug( + "Saving AOTI generated file %s to archive in %s", file, new_filepath + ) + archive_writer.write_file( + str(new_filepath), + file, + ) + + if len(all_weights) > 0: + # Dedup weights + grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights) + for idx, group in enumerate(grouped_tensors): + filename = f"{WEIGHT_FILENAME_PREFIX}{idx}" + model_name, weight_name = get_complete(group, all_weights) + complete_tensor, _ = all_weights[model_name].get_weight(weight_name) + buffer = io.BytesIO() + torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes( + os.path.join(WEIGHTS_DIR, filename), buffer.getvalue() + ) + for model_name, weight_name in group: + _, w_property = all_weights[model_name].get_weight(weight_name) + weights_configs[model_name][weight_name] = ( + filename, + w_property.shape, + w_property.stride, + w_property.offset, + ) + + for model_name, weights_config in weights_configs.items(): + archive_writer.write_string( + os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"), + json.dumps(weights_config), + ) + logger.debug("packaging weights_config for model %s", model_name) + logger.debug(weights_config) + + +def _package_exported_programs( + archive_writer: PT2ArchiveWriter, + exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]], + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if exported_programs is None: + return + + if isinstance(exported_programs, ExportedProgram): + exported_programs = {"model", exported_programs} # type: ignore[assignment] + + assert isinstance(exported_programs, dict) + + for model_name, ep in exported_programs.items(): + artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) + + archive_writer.write_bytes( + MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program + ) + # TODO:Consider dedup this with the weights saved in package_aoti_files + archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict) + archive_writer.write_bytes( + f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants + ) + archive_writer.write_bytes( + SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), + artifact.example_inputs, + ) + + +def _package_extra_files( + archive_writer: PT2ArchiveWriter, extra_files: Optional[dict[str, Any]] +) -> None: + if extra_files is None: + return + + for extra_file_name, content in extra_files.items(): + archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) + + +def package_pt2( + f: FileLike, + *, + exported_programs: Optional[ + Union[ExportedProgram, dict[str, ExportedProgram]] + ] = None, + aoti_files: Optional[AOTI_FILES] = None, + extra_files: Optional[dict[str, Any]] = None, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> FileLike: + """ + Saves the artifacts to a PT2Archive format + (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). + The artifact can then be loaded using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + implement write and flush) or a string containing a file name. + + exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): + The exported program to save, or a dictionary mapping model name to an + exported program to save. The exported program will be saved under + models/*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]): A list of files + generated by AOTInductor via + ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, + or a dictionary mapping model name to its AOTInductor generated files. + If only one set of files is specified, this will automatically be named + "model". + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of the pt2. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + """ + assert not ( + exported_programs is None and aoti_files is None and extra_files is None + ), ( + "No value passed in for `exported_programs`, `aoti_files`, and " + "`extra_files`, implying that you do not plan on saving anything." + ) + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error + logger.warning( + "Expect archive file to be a file ending in .pt2, or is a buffer. " + "Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with PT2ArchiveWriter(f) as archive_writer: + _package_exported_programs( + archive_writer, exported_programs, pickle_protocol=pickle_protocol + ) + _package_aoti_files( + archive_writer, + aoti_files, + pickle_protocol=pickle_protocol, + ) + _package_extra_files(archive_writer, extra_files) + + if isinstance(f, (io.IOBase, IO)): + f.seek(0) + return f + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = self.loader.boxed_run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + def get_metadata(self) -> dict[str, str]: + return self.loader.get_metadata() + + def load_constants( + self, + constants_map: dict[str, torch.Tensor], + *, + check_full_update: bool, + user_managed: bool = False, + ) -> None: + """ + Given a mapping of constant fqns to tensors, load the constants into the model. + You can use ``get_constant_fqns`` to get the list of constant fqns that + are needed in the compiled model. + + Args: + constants_map: A mapping of constant fqns to tensors. + check_full_update: Whether to add check to see if all the constants + are updated and have values. + """ + self.loader.load_constants( + constants_map, False, check_full_update, user_managed + ) + + def get_constant_fqns(self) -> list[str]: + return self.loader.get_constant_fqns() + + def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel": + logger.warning( + "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." + ) + return AOTICompiledModel(self.loader) + + +@dataclass +class PT2ArchiveContents: + exported_programs: dict[str, ExportedProgram] + aoti_runners: dict[str, AOTICompiledModel] + extra_files: dict[str, Any] + + +def _load_exported_programs( + archive_reader: PT2ArchiveReader, + file_names: list[str], + expected_opset_version: Optional[dict[str, int]], +) -> dict[str, ExportedProgram]: + exported_program_files = [ + file for file in file_names if file.startswith(MODELS_DIR) + ] + exported_programs = {} + for file in exported_program_files: + prefix, suffix = MODELS_FILENAME_FORMAT.split( + "{}" + ) # split "models/{}.json" into "models/" and "json" + model_name = file[ + len(prefix) : -len(suffix) + ] # given "models/foo.json" we can now get "foo" + + weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + + serialized_exported_program = archive_reader.read_bytes(file) + serialized_weights = archive_reader.read_bytes(weights_file) + serialized_constants = archive_reader.read_bytes(constants_file) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_weights, + serialized_constants, + serialized_sample_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + exported_programs[model_name] = ep + + return exported_programs + + +def _load_extra_files( + archive_reader: PT2ArchiveReader, file_names: list[str] +) -> dict[str, Any]: + extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] + + extra_file_contents: dict[str, Any] = {} + for file in extra_files: + contents = archive_reader.read_string(file) + extra_file_contents[file[len(EXTRA_DIR) :]] = contents + + return extra_file_contents + + +def load_pt2( + f: FileLike, + *, + expected_opset_version: Optional[dict[str, int]] = None, + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, + load_weights_from_disk: bool = False, +) -> PT2ArchiveContents: # type: ignore[type-arg] + """ + Loads all the artifacts previously saved with ``package_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + num_runners (int): Number of runners to load AOTInductor artifacts + + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + + Returns: + A ``PT2ArchiveContents`` object which contains all the objects in the PT2. + """ + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error in 2.9 + logger.warning( + "Unable to load package. f must be a buffer or a file ending in " + ".pt2. Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + weights = {} + weight_maps = {} + with PT2ArchiveReader(f) as archive_reader: + version = archive_reader.read_string(ARCHIVE_VERSION_PATH) + if version != ARCHIVE_VERSION_VALUE: + raise ValueError( + f"Saved archive version {version} does not match our current " + f"archive version {ARCHIVE_VERSION_VALUE}." + ) + + file_names = archive_reader.get_file_names() + + exported_programs = _load_exported_programs( + archive_reader, file_names, expected_opset_version + ) + extra_files = _load_extra_files(archive_reader, file_names) + + # Get a list of AOTI model names + aoti_model_names: set[str] = set() + for file in file_names: + if file.startswith(AOTINDUCTOR_DIR): + file_end = file[ + len(AOTINDUCTOR_DIR) : + ] # remove data/aotinductor/ prefix + model_name = file_end.split("/")[ + 0 + ] # split "model_name/...cpp" into "model_name" + aoti_model_names.add(model_name) + if load_weights_from_disk and file.endswith("weights_config.json"): + weight_map = json.loads(archive_reader.read_string(file)) + weight_maps[model_name] = weight_map + elif load_weights_from_disk and file.startswith(WEIGHTS_DIR): + weight_file_name = file[ + len(WEIGHTS_DIR) : + ] # remove data/weights/ prefix + weight_bytes = archive_reader.read_bytes(file) + loaded_weight = torch.load(io.BytesIO(weight_bytes)) + weights[weight_file_name] = loaded_weight + + if isinstance(f, (io.IOBase, IO)): + if len(aoti_model_names) > 0: + # Workaround for AOTIModelPackageLoader not reading buffers + with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: + f.seek(0) + tf.write(f.read()) + f.seek(0) + logger.debug("Writing buffer to tmp file located at %s.", tf.name) + + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + tf.name, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + ) + for model_name in aoti_model_names + } + else: + aoti_runners = {} + else: + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + f, model_name, run_single_threaded, num_runners, device_index + ) + ) + for model_name in aoti_model_names + } + + if weight_maps: + for model_name in aoti_model_names: + model_weights = {} + for weight_name, (file, shape, stride, storage_offset) in weight_maps[ + model_name + ].items(): + weight = weights[file] + model_weights[weight_name] = weight.as_strided( + shape, stride, storage_offset + ) + + # user_managed=True ensures the weights updates are shared by all runners. + aoti_runners[model_name].load_constants( + model_weights, check_full_update=True, user_managed=True + ) + + return PT2ArchiveContents(exported_programs, aoti_runners, extra_files) + + +def load_weights_to_pt2_contents( + pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any] +) -> None: + """ + Load weights into the models in PT2 archive contents + + Args: + pt2_contents (PT2ArchiveContents): The contents of the PT2 archive. + """ + for model_name, weights in weights_map.items(): + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.") + pt2_contents.aoti_runners[model_name].load_constants( + weights, check_full_update=True, user_managed=True + ) diff --git a/torch/export/pt2_archive/_package_weights.py b/torch/export/pt2_archive/_package_weights.py new file mode 100644 index 00000000000000..e6721ea9229a6b --- /dev/null +++ b/torch/export/pt2_archive/_package_weights.py @@ -0,0 +1,101 @@ +import collections + +import torch +from torch.utils._ordered_set import OrderedSet + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +class TensorProperties: + def __init__(self, tensor: torch.Tensor): + # info about underlying storage + self.storage_ptr = tensor.untyped_storage().data_ptr() + self.storage_size = tensor.untyped_storage().nbytes() + + # info to recover tensor + self.shape = tensor.shape + self.stride = tensor.stride() + self.offset = tensor.storage_offset() + + self.start = tensor.data_ptr() + self.end = _end_ptr(tensor) + + def is_complete(self) -> bool: + """ + Whehter the tensor completely overlaps with its underlying storage + """ + return ( + self.start == self.storage_ptr + and self.end == self.storage_ptr + self.storage_size + ) + + +class Weights(dict): + """ + A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). + tensor represents the actual intial value of the weight. + TensorProperties represents the properties of the weight that are needed to recover the weight. + + We use two separate entries because `tensor` could be a clone of the original weight tensor, + so it doesn't have the same property as the original weight (such as underlying storage pointer). + """ + + def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]): + super().__init__(weight_dict) + + def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]: + return self[name] + + def get_weight_properties(self, name: str) -> TensorProperties: + return self[name][1] + + +def get_complete( + group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights] +) -> tuple[str, str]: + """ + `group` is a (model_name, weight_name) tuple. + `model_weights` is a dictionary mapping from model name to its Weights. + + One of the tensor in `group` must be complete and they must share the + same underlying storage. + + Returns the name of the complete tensor in the `group`. If multiple + tensors are complete, returns an arbitrary one. + """ + + def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties: + # returns the tensor properties + (model_name, weight_name) = name_tuple + return models_weights[model_name].get_weight_properties(weight_name) + + for name_tuple in group: + tensor_property = get_tensor_properties(name_tuple) + if tensor_property.is_complete(): + return name_tuple + + raise RuntimeError("No complete tensor found in the group!") + + +def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]: + """ + Group weights that share the same underlying storage. + + Returns a list of sets, each set contains a tuple of (model_name, weight_name). + """ + + weights_dict: dict[int, OrderedSet[tuple[str, str]]] = collections.defaultdict( + OrderedSet + ) # storage_key -> set(weight) + + for model_name, weights in all_weights.items(): + for weight_name, (_, properties) in weights.items(): + weights_dict[properties.storage_ptr].add((model_name, weight_name)) + + return list(weights_dict.values()) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 119de9c21afcc6..210b5755f9e6d4 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -104,9 +104,9 @@ def _assign_attr( assert isinstance(from_obj, torch.Tensor) to_module.register_buffer(field, from_obj, persistent=persistent) elif attr_kind == _AttrKind.CONSTANT: - assert not isinstance( - from_obj, FakeScriptObject - ), "FakeScriptObject should only exist during tracing." + assert not isinstance(from_obj, FakeScriptObject), ( + "FakeScriptObject should only exist during tracing." + ) assert isinstance( from_obj, ( @@ -143,7 +143,7 @@ def __init__( super().__init__() self.graph = graph self._ty = ty - self.graph.owning_module = self + self.graph.owning_module = self # type: ignore[assignment] self._run_with_interpreter = RUN_WITH_INTERPRETER def forward(self, *args, **kwargs): @@ -296,7 +296,7 @@ def __init__( export_graph = deepcopy(export_module.graph) self.graph_signature = deepcopy(export_module.graph_signature) self.graph = torch.fx.Graph() - self.graph.owning_module = self + self.graph.owning_module = self # type: ignore[assignment] self.module_call_graph = deepcopy(export_module.module_call_graph) self.flat_args_adapter = flat_args_adapter @@ -461,9 +461,9 @@ def add_to_consts_map(obj_id, node_name, target_name): # add constants that are aliased and don't appear in graph signature for const_name, const in export_module.constants.items(): if const_name not in consts_targets: - assert ( - id(const) in consts_map - ), "Constants should be either aliased or appear in graph signature" + assert id(const) in consts_map, ( + "Constants should be either aliased or appear in graph signature" + ) ph_name, _ = consts_map[id(const)][0] add_to_consts_map(id(const), ph_name, const_name) added_params_buffers.add(s.target) @@ -524,7 +524,7 @@ def _adapt_flat_args(self, flat_args, in_spec, input): if self.flat_args_adapter is None: raise TypeError( - "There is no flat args adapter sepcified. " + "There is no flat args adapter specified. " "Are you sure you are calling this with the right arguments? " ) else: @@ -1041,9 +1041,9 @@ def create_module(fqn): if arg.name in self.seen_nodes: flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[ - self.seen_nodes[arg.name] - ] = flat_arg_node + self.node_to_placeholder[self.seen_nodes[arg.name]] = ( + flat_arg_node + ) with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: list[Optional[torch.fx.Node]] = [] @@ -1125,8 +1125,7 @@ def remap_input(self, x): if x in self.node_to_placeholder: return self.node_to_placeholder[x] elif ( - x.op == "placeholder" - or self.module_call_graph.get(self.fqn) is None + x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None # allow placeholder creation if we are not preserving module call signature ): self.add_placeholder(x) diff --git a/torch/fft/__init__.py b/torch/fft/__init__.py index 3ad1748bab1a82..b48cd28bb17df5 100644 --- a/torch/fft/__init__.py +++ b/torch/fft/__init__.py @@ -82,9 +82,7 @@ >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) >>> torch.fft.fft(t) tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j]) -""".format( - **common_args - ), +""".format(**common_args), ) ifft = _add_docstr( @@ -125,9 +123,7 @@ >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) >>> torch.fft.ifft(t) tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) -""".format( - **common_args - ), +""".format(**common_args), ) fft2 = _add_docstr( @@ -188,9 +184,7 @@ >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) ifft2 = _add_docstr( @@ -243,9 +237,7 @@ >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) fftn = _add_docstr( @@ -305,9 +297,7 @@ >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) ifftn = _add_docstr( @@ -359,9 +349,7 @@ >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfft = _add_docstr( @@ -417,9 +405,7 @@ Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, and therefore must always be real-valued. -""".format( - **common_args - ), +""".format(**common_args), ) irfft = _add_docstr( @@ -496,9 +482,7 @@ >>> roundtrip = torch.fft.irfft(T, t.numel()) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfft2 = _add_docstr( @@ -565,9 +549,7 @@ >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) irfft2 = _add_docstr( @@ -649,9 +631,7 @@ torch.Size([10, 9]) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfftn = _add_docstr( @@ -718,9 +698,7 @@ >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) irfftn = _add_docstr( @@ -801,9 +779,7 @@ torch.Size([10, 9]) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) hfft = _add_docstr( @@ -894,9 +870,7 @@ >>> torch.fft.hfft(T[:3]) tensor([0.1250, 0.2809, 0.6250, 0.9691]) -""".format( - **common_args - ), +""".format(**common_args), ) ihfft = _add_docstr( @@ -951,9 +925,7 @@ >>> torch.fft.ifft(t) tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, -0.5000+0.6882j]) -""".format( - **common_args - ), +""".format(**common_args), ) hfft2 = _add_docstr( @@ -1025,9 +997,7 @@ >>> torch.allclose(roundtrip, T) True -""".format( - **common_args - ), +""".format(**common_args), ) ihfft2 = _add_docstr( @@ -1092,9 +1062,7 @@ >>> torch.allclose(t, two_ffts) True -""".format( - **common_args - ), +""".format(**common_args), ) hfftn = _add_docstr( @@ -1187,9 +1155,7 @@ >>> torch.allclose(roundtrip, T) True -""".format( - **common_args - ), +""".format(**common_args), ) ihfftn = _add_docstr( @@ -1259,9 +1225,7 @@ >>> torch.allclose(ihfftn, two_iffts) True -""".format( - **common_args - ), +""".format(**common_args), ) fftfreq = _add_docstr( @@ -1310,9 +1274,7 @@ >>> torch.fft.fftfreq(4) tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) -""".format( - **factory_common_args - ), +""".format(**factory_common_args), ) rfftfreq = _add_docstr( @@ -1361,9 +1323,7 @@ >>> torch.fft.fftfreq(4) tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) -""".format( - **factory_common_args - ), +""".format(**factory_common_args), ) fftshift = _add_docstr( diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index dcca39d06a4ed3..79533346187dbb 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -271,9 +271,9 @@ def set_exception(self, result: T) -> None: ... ValueError: foo """ - assert isinstance( - result, Exception - ), f"{result} is of type {type(result)}, not an Exception." + assert isinstance(result, Exception), ( + f"{result} is of type {type(result)}, not an Exception." + ) def raise_error(fut_result): raise fut_result diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index a4322a884d60a4..c048b4fdd8f894 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -52,7 +52,7 @@ def forward(self, x): The **symbolic tracer** performs "symbolic execution" of the Python code. It feeds fake values, called Proxies, through the code. Operations -on theses Proxies are recorded. More information about symbolic tracing +on these Proxies are recorded. More information about symbolic tracing can be found in the :func:`symbolic_trace` and :class:`Tracer` documentation. diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index e723046bf37c8b..97e5755d7d52ce 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -253,9 +253,9 @@ def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None: for k in MetaTensorDesc._UNSERIALIZABLE: if k in ("fake_mode", "view_func"): continue - assert ( - getattr(self.metadata, k) is None - ), f"not None: {k}: {getattr(self.metadata, k)}" + assert getattr(self.metadata, k) is None, ( + f"not None: {k}: {getattr(self.metadata, k)}" + ) def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index cc2f686ebba10d..377faf327fc9de 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -127,7 +127,7 @@ def _lazy_forward(self, *args, **kwargs): forward = _lazy_forward - # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__, # or __reduce__ by calling _real_recompile. But I don't find a good way # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule # will be used in torch::deploy. So it's skipped for now. diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 3ab5ffbea6f5a1..7a31e4ef3cfa6d 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -42,7 +42,7 @@ def tree_flatten_spec( # I guess these exist for BC, FC reasons. # In general, we should be able to directly # use pytree tree flattener to flatten them, - # as export serializes the pytree seperately. + # as export serializes the pytree separately. # Will remove it in follow up PR. if spec.type in SUPPORTED_NODES: flatten_fn_spec = SUPPORTED_NODES[spec.type] diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 2509de1d20767f..dfb9b9f8074b7a 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -694,7 +694,7 @@ def proxy_placeholder(name): # In the case that we have pytree-flattened inputs in # `concrete_args`, generate a flattening wrapper around the # original root function and return that. - self.graph._codegen = _PyTreeCodeGen( + self.graph._codegen = _PyTreeCodeGen( # type: ignore[has-type] _PyTreeInfo(orig_args[:total_args], in_spec, None) ) @@ -702,7 +702,7 @@ def flatten_fn(*args): tree_args = pytree.tree_unflatten(list(args), in_spec) tree_out = root_fn(*tree_args) out_args, out_spec = pytree.tree_flatten(tree_out) - assert isinstance(self.graph._codegen, _PyTreeCodeGen) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) # type: ignore[has-type] self.graph._codegen.pytree_info = ( self.graph._codegen.pytree_info._replace(out_spec=out_spec) ) @@ -755,9 +755,9 @@ def trace( self.root = root - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + assert hasattr(type(root), self.traced_func_name), ( + f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + ) fn = getattr(type(root), self.traced_func_name) self.root_module_name = root._get_name() @@ -1164,9 +1164,9 @@ def _maybe_revert_all_patches(): finally: if current_patcher is not None: patches_made = current_patcher.reapply_all_patches() - assert ( - patches_made == patches_removed - ), "CURRENT_PATCHER was changed during a revert_all_patches" + assert patches_made == patches_removed, ( + "CURRENT_PATCHER was changed during a revert_all_patches" + ) def _patch_wrapped_functions(patcher: _Patcher): @@ -1248,9 +1248,9 @@ def my_custom_function(x, y): assert not isinstance(fn_or_name, str) # to make mypy happy fn_name = fn_or_name.__name__ else: - assert isinstance( - fn_or_name, str - ), "fn_or_name must be a global function or string name" + assert isinstance(fn_or_name, str), ( + "fn_or_name must be a global function or string name" + ) fn_name = fn_or_name currentframe = inspect.currentframe() @@ -1308,7 +1308,9 @@ def f(x): return out - f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) + f = fx.symbolic_trace( + f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} + ) assert f({"a": 1, "b": 2, "c": 4}) == 7 diff --git a/torch/fx/config.py b/torch/fx/config.py index da5120d6edf180..db06176c43e13c 100644 --- a/torch/fx/config.py +++ b/torch/fx/config.py @@ -1,5 +1,5 @@ # Whether to disable showing progress on compilation passes -# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +# Need to add a new config otherwise will get a circular import if dynamo config is imported here disable_progress = True # If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 29b8d4541b8185..c29d05f511a795 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -450,9 +450,9 @@ def find_device_based_on_size(node) -> Device: device = find_device_based_on_size(node) occupied_devices.append(device) # Update partition and its left mem size - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) # Update available mem for the current partition partition.logical_device_ids.append(device.logical_id) else: @@ -475,9 +475,9 @@ def find_device_based_on_size(node) -> Device: total_size_of_input_nodes = get_extra_size_of( node, partition.nodes ) - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) partition.logical_device_ids.append(device.logical_id) partition.add_node(node) partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes @@ -509,9 +509,9 @@ def saturate_host(self) -> None: no_device_partitions, ) = get_device_partition_stats(self.partitions, self.devices) - assert ( - len(no_device_partitions) == 0 - ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + assert len(no_device_partitions) == 0, ( + f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + ) # Devices that hold partitions used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index 3b15ae0a6739cf..a8798a6a0726a4 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -82,7 +82,7 @@ def expand_to_tensor_dim(t, n): def broadcast_types(t1, t2): """ Applies broadcasting to both given types such that they - become consistent with eachother and returns two new + become consistent with each other and returns two new resulting types """ @@ -846,7 +846,7 @@ def flatten_refinement_rule(n: Node): @register_algebraic_expressions_inference_rule(Conv2d) def conv_rule(n: Node, module_instance): """ - Represents the outout in terms of an algrbraic expression w.r.t + Represents the output in terms of an algrbraic expression w.r.t the input when possible """ assert isinstance(n.args[0], Node) diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index e2fc033e0b8dbd..bc00be5ee7ae82 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -126,7 +126,7 @@ def __init__(self, root, attr: str): self._node = None @property - def node(self): + def node(self): # type: ignore[override] # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 8aca3e482c95f7..388d716245d4f0 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -164,7 +164,7 @@ class TGreatestUpperBound(Constraint): def __init__(self, res, rhs1, rhs2): """ - :param res: tensor variable that stores the result of the outout + :param res: tensor variable that stores the result of the output :param rhs1: tensor or tensor variable :param rhs2: tensor or tensor variabke """ @@ -407,7 +407,7 @@ def __init__( """ :param conv_result: the convolution result :param input_var: input to convolution - :param c_out: output chanel type + :param c_out: output channel type :param kernel: kernel tuple """ self.conv_result = conv_result diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 03346b800924e5..e4951aab15cbfd 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -532,7 +532,7 @@ def view_inference_rule(n: Node, symbols, constraints, counter): else: num_constraints.append(BinConstraintD(t, Dyn, op_neq)) - t2_type.append(t) + t2_type.append(t) # type: ignore[arg-type] t2_type = TensorType(t2_type) # type: ignore[assignment] @@ -681,7 +681,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): # tensor output case elif isinstance(n.args[1], tuple): # create and store the new tensor variable - get_item_output, counter = gen_tvar(counter) + get_item_output, counter = gen_tvar(counter) # type: ignore[arg-type,assignment] symbols[n] = get_item_output # retrieve arg variables @@ -1073,7 +1073,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e1 = symbols[n.args[0]] return [BinConstraintT(my_output, e1, op_eq)], counter elif isinstance(symbols[n.args[0]], DVar): - my_output, counter = gen_dvar(counter) + my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] symbols[n] = my_output e1 = symbols[n.args[0]] @@ -1095,7 +1095,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e2 = symbols[n.args[1]] return [BinConstraintT(my_output, e2, op_eq)], counter elif isinstance(symbols[n.args[1]], DVar): - my_output, counter = gen_dvar(counter) + my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] symbols[n] = my_output e2 = symbols[n.args[1]] diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py index 11ebff0102093c..9b84c12127f0fb 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -823,7 +823,7 @@ def calc_last_two_dims(constraint, d: list[DVar]): [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] ) - # transform parameters into tuples incase they are not already + # transform parameters into tuples in case they are not already padding = ( (constraint.padding, constraint.padding) if isinstance(constraint.padding, int) diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index bd40d2a463f5e7..b160ec8de70f95 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.constraint import ( BinConstraintD, BVar, @@ -8,7 +7,7 @@ from torch.fx.experimental.migrate_gradual_types.operation import op_leq -def gen_tvar(curr): +def gen_tvar(curr: int) -> tuple[TVar, int]: """ Generate a tensor variable :param curr: The current counter @@ -18,7 +17,7 @@ def gen_tvar(curr): return TVar(curr), curr -def gen_dvar(curr): +def gen_dvar(curr: int) -> tuple[DVar, int]: """ Generate a dimension variable :param curr: the current counter @@ -28,7 +27,7 @@ def gen_dvar(curr): return DVar(curr), curr -def gen_bvar(curr): +def gen_bvar(curr: int) -> tuple[BVar, int]: """ Generate a boolean variable :param curr: the current counter @@ -38,7 +37,7 @@ def gen_bvar(curr): return BVar(curr), curr -def gen_tensor_dims(n, curr): +def gen_tensor_dims(n: int, curr: int) -> tuple[list[DVar], int]: """ Generate a list of tensor dimensions :param n: the number of dimensions @@ -52,7 +51,7 @@ def gen_tensor_dims(n, curr): return dims, curr -def gen_nat_constraints(list_of_dims): +def gen_nat_constraints(list_of_dims: list[DVar]) -> list[BinConstraintD]: """ Generate natural number constraints for dimensions """ diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 13d9c2d9ac7796..3e406b57a96d57 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -368,12 +368,12 @@ class MklSupport(Enum): supports_mkldnn = MklSupport.YES sample_parameter = next(cur_module.parameters(), None) if sample_parameter is not None: - assert ( - sample_parameter.dtype == torch.float - ), "this pass is only for torch.float modules" - assert sample_parameter.device == torch.device( - "cpu" - ), "this pass is only for CPU modules" + assert sample_parameter.dtype == torch.float, ( + "this pass is only for torch.float modules" + ) + assert sample_parameter.device == torch.device("cpu"), ( + "this pass is only for CPU modules" + ) elif node.op == "call_function": if node.target in mkldnn_supported: supports_mkldnn = MklSupport.YES @@ -471,7 +471,7 @@ def get_color(n): if not use_mkl_heuristic(graph): for node in graph.start_nodes + graph.end_nodes: prv = node.args[0] - node.replace_all_uses_with(prv) + node.replace_all_uses_with(prv) # type: ignore[arg-type] fx_graph.erase_node(node) reset_modules(graph.nodes, modules, old_modules) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 0b5de5a677524c..6777b1f31cef28 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -11,7 +11,7 @@ import inspect import logging import operator -import traceback +import threading import typing import typing_extensions import weakref @@ -67,7 +67,6 @@ ) from torch.utils._stats import count from torch.utils._thunk import Thunk -from torch.utils._traceback import CapturedTraceback from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary from ._backward_state import BackwardState @@ -181,26 +180,82 @@ def is_sym_node(node: _HasMeta) -> bool: return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) -@overload -def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: - ... +@overload # type: ignore[no-overload-impl] +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... @overload def set_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy -) -> None: - ... +) -> None: ... @overload def set_proxy_slot( obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType -) -> None: - ... +) -> None: ... -def set_proxy_slot( +class _DisableUpdateTensorTracker(threading.local): + value: bool = False + + +_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker() + + +def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool: + """ + Returns current state of disabling update tensor tracker. + """ + return _disable_update_tensor_tracker_tls.value + + +@contextmanager +def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]: + """ + NOTE "Do not clobber inplace ops" + By default tensor_tracker is updated every time. + This leads to chaining every operation by the FakeTensor. + For example for mutable ops if we have several consecutive mutable operations: + + def f(x, y, z): + x.copy_(y) + x.copy_(z) + return x + + Default graph result: + def f_graph(x, y, z) + x_1 = x.copy_(y) + x_2 = x_1.copy_(z) + return x_2 + + This chaining simplifies the fx passes and helps to prevent the reordering. + But in some cases, we want those nodes to be disconnected. + E.g. in case of splitting joint graph into forward and backward. + If first inplace op happened in forward, second in backward, + we want them after split to be properly placed. + + Enabling this context manager for copy_ will result in: + def f_graph_2(x, y, z): + x_1 = x.copy_(y) + x_2 = x.copy_(z) + return x + + Results of copy_ x1 and x2 will have empty users in the graph. + The reason why this behavior is not enabled for all inplace ops is that + some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_ + in their fusions passes. + We could revisit enabling this logic for all inplace ops in future. + """ + orig_value = _disable_update_tensor_tracker_tls.value + _disable_update_tensor_tracker_tls.value = True + try: + yield + finally: + _disable_update_tensor_tracker_tls.value = orig_value + + +def set_proxy_slot( # type: ignore[no-redef] obj: Union[PySymType, _AnyScriptObjectType, Tensor], tracer: _ProxyTracer, proxy: object, @@ -210,7 +265,9 @@ def set_proxy_slot( # We DO want to clobber proxies whenever we run an inplace operation # on a tensor, and it affects the metadata on the proxy. assert isinstance(proxy, _ProxyTensor) - tracer.tensor_tracker[obj] = proxy + # see NOTE [Do not clobber inplace ops] + if not _is_proxy_tensor_update_tensor_tracker_disabled(): + tracer.tensor_tracker[obj] = proxy elif isinstance(obj, (_AnyScriptObject)): # We DO want to clobber proxies, with a similar rationale as for tensors. assert isinstance(proxy, Proxy) @@ -256,8 +313,7 @@ def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: def get_proxy_slot( obj: Tensor, tracer: _ProxyTracer, -) -> _ProxyTensor: - ... +) -> _ProxyTensor: ... @overload @@ -265,8 +321,7 @@ def get_proxy_slot( obj: Tensor, tracer: _ProxyTracer, default: U, -) -> Union[_ProxyTensor, U]: - ... +) -> Union[_ProxyTensor, U]: ... @overload @@ -275,16 +330,14 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[_ProxyTensor], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... @overload def get_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, -) -> Proxy: - ... +) -> Proxy: ... @overload @@ -292,8 +345,7 @@ def get_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, default: U, -) -> Union[Proxy, U]: - ... +) -> Union[Proxy, U]: ... @overload @@ -302,16 +354,14 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[Proxy], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... @overload def get_proxy_slot( obj: PySymType, tracer: _ProxyTracer, -) -> _PySymProxyType: - ... +) -> _PySymProxyType: ... @overload @@ -319,8 +369,7 @@ def get_proxy_slot( obj: PySymType, tracer: _ProxyTracer, default: T, -) -> Union[T, _PySymProxyType]: - ... +) -> Union[T, _PySymProxyType]: ... @overload @@ -329,8 +378,7 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[_PySymProxyType], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... # the default argument is what to return if the slot is not set. @@ -367,12 +415,12 @@ def get_proxy_slot( return res -def snapshot_fake(val: Tensor) -> Optional[Tensor]: +def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]: # val.detach() will also eventually call fast_detach(), # but this saves us a full trip into __torch_dispatch__ # (snapshot_fake is called a lot) if isinstance(val, FakeTensor): - return fast_detach(val.fake_mode, val) + return fast_detach(val.fake_mode, val, include_real) else: return val.detach() @@ -393,9 +441,9 @@ def snapshot_fake(val: Tensor) -> Optional[Tensor]: ] -def extract_val(val: _ExtractValType) -> _ExtractValType: +def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType: if is_fake(val): - return snapshot_fake(val) + return snapshot_fake(val, include_real=include_real) elif isinstance(val, py_sym_types): return val elif isinstance(val, _AnyScriptObject): @@ -494,7 +542,9 @@ def maybe_enable_thunkify() -> Generator[None, None, None]: # grad_fn, _base (_base actually may be set due to recursive call to # ADInplaceOrView, but you shouldn't rely on it.) def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: - proxy.node.meta["val"] = extract_val(val) + proxy.node.meta["val"] = extract_val( + val, include_real=(proxy.node.op == "placeholder") + ) with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] # Best effort tensor_meta setting; prefer using val! @@ -715,22 +765,21 @@ def inner(e: PySymType) -> Union[int, bool, float, Proxy]: @overload -def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]: - ... +def fetch_object_proxy( + tracer: _ProxyTracer, t: Tensor +) -> Union[_ProxyTensor, Tensor]: ... @overload def fetch_object_proxy( tracer: _ProxyTracer, t: _AnyScriptObjectType -) -> Union[Proxy, _AnyScriptObjectType]: - ... +) -> Union[Proxy, _AnyScriptObjectType]: ... @overload def fetch_object_proxy( tracer: _ProxyTracer, t: PySymType -) -> Union[_PySymProxyType, PySymType]: - ... +) -> Union[_PySymProxyType, PySymType]: ... def fetch_object_proxy( @@ -813,7 +862,10 @@ def can_handle_tensor(x: Tensor) -> bool: if func is torch.ops.aten.is_nonzero.default: with proxy_mode: - torch._check(args[0].numel() == 1, lambda: "Boolean value of Tensor with more than one value is ambiguous") # type: ignore[attr-defined] + torch._check( + args[0].numel() == 1, # type: ignore[attr-defined] + lambda: "Boolean value of Tensor with more than one value is ambiguous", + ) return (args[0] != 0).item() # type: ignore[attr-defined] tracer = proxy_mode.tracer @@ -1009,7 +1061,7 @@ def get( ) -> _PySymProxyType: # dict.get()'s annotation doesn't accept `None` when the value type # isn't Optional. - return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type] + return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type, return-value] def __iter__(self) -> Any: raise NotImplementedError @@ -1076,22 +1128,19 @@ def create_arg(self, a: object) -> fx.node.Node: return super().create_arg(a) # type: ignore[return-value] @overload - def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: - ... + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ... @overload - def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: - ... + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ... @overload def unwrap_proxy( self, e: _AnyScriptObjectType - ) -> Union[Proxy, _AnyScriptObjectType]: - ... + ) -> Union[Proxy, _AnyScriptObjectType]: ... def unwrap_proxy(self, e: T) -> object: if isinstance(e, Tensor): - return get_proxy_slot(e, self, e, lambda x: x.proxy) + return get_proxy_slot(e, self, e, lambda x: x.proxy) # type: ignore[attr-defined] elif isinstance(e, py_sym_types): return get_proxy_slot(e, self, e, lambda e: e.force()) elif isinstance(e, _AnyScriptObject): @@ -1110,6 +1159,16 @@ def create_node( ) -> torch.fx.Node: node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type] + if node.op in ["placeholder", "output"] and "stack_trace" in node.meta: + del node.meta["stack_trace"] + + if kind == "get_attr": + assert isinstance(target, str) + attr = getattr(self.root, target) + if isinstance(attr, torch.Tensor): + with disable_proxy_modes_tracing(): + node.meta["val"] = extract_val(attr) + def map_fn(v: Any) -> Optional[_ExtractValType]: if not isinstance(v, torch.fx.Node) or "val" not in v.meta: return None @@ -1136,6 +1195,8 @@ def _should_save_eager_input_vals( target: Any, args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, ) -> bool: + from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP + if not callable(target): return False if isinstance( @@ -1143,6 +1204,7 @@ def _should_save_eager_input_vals( ( torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + InvokeSubgraphHOP, ), ): return True @@ -1353,7 +1415,7 @@ def __torch_function__( kwargs = kwargs or {} if func in _side_effectful_need_to_be_preserved_pre_dispatch: # It's for passing the export verifier which needs to verify the meta['val'] - # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, + # TODO(tmanlaibaatar): we should systematically couple it with export verifier, # instead of hardcoding it here. # T203648563 if func == torch.amp.autocast_mode._exit_autocast: @@ -1370,7 +1432,7 @@ def __torch_function__( node.meta["val"] = None return node # Don't actually run the function! We just want to trace the calls - # into a graph. We don't actualy want to change global autograd state. + # into a graph. We don't actually want to change global autograd state. return func(*args, **kwargs) @@ -1562,7 +1624,10 @@ def __init__( self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") def placeholder( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) @@ -1571,7 +1636,10 @@ def placeholder( return out def get_attr( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) @@ -1581,7 +1649,10 @@ def get_attr( # call_function, call_method, call_module get traced automatically by the outer mode. def output( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().output(target, args, kwargs) # type: ignore[arg-type] @@ -1656,6 +1727,8 @@ class _ModuleStackTracer(PythonKeyTracer): def __init__(self, scope_root: GraphModule) -> None: super().__init__() + self.record_stack_traces = True + self._record_forward_stack_traces_only = True self.scope_root = scope_root self.enable_attr_proxy = False self.submodule_paths = {} @@ -1906,36 +1979,6 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", ) - # stack_trace - if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]: - user_frame_summary = CapturedTraceback.extract().summary() - if user_frame_summary: - # we retain frames from forward() calls, or ops - # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap) - stack_trace = [ - frame - for frame in user_frame_summary - if ( - frame.name == "forward" - or frame.filename.endswith("torch/__init__.py") - ) - ] - # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py - # this is hardcoded, but leads to a much cleaner stack trace - stack_trace = [ - frame - for frame in stack_trace - if not ( - frame.filename.endswith("fx/_symbolic_trace.py") - or frame.filename.endswith("export/_trace.py") - ) - ] - if ( - stack_trace - ): # empty list for strict mode, dynamo should handle stack_trace - stack_trace = traceback.StackSummary.from_list(stack_trace) - node.meta["stack_trace"] = "".join(stack_trace.format()).strip() - return node @@ -1949,6 +1992,7 @@ def __init__( record_module_stack: bool, _allow_fake_constant: bool, _error_on_data_dependent_ops: bool, + record_stack_traces: bool = False, ) -> None: # Configurations that are used to initialize the context managers and their states. # Should not modify them during tracing. @@ -1967,18 +2011,19 @@ def __init__( # All context managers and their states should be initialized before tracing based on the inputs # and configurations. After tracing, their states should be cleaned except for shape_env. - # Remember to specify how to intialize it from user inputs and from parent tracer whenever + # Remember to specify how to initialize it from user inputs and from parent tracer whenever # adding new modes in _MakefxTracer. self.fake_tensor_mode: Optional[FakeTensorMode] = None self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() - self.proxy_function_mode: Union[ - nullcontext, PreDispatchTorchFunctionMode - ] = nullcontext() + self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = ( + nullcontext() + ) self.fx_tracer: Optional[PythonKeyTracer] = None self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() - self.torch_fn_metadata_mode: Union[ - nullcontext, TorchFunctionMetadataMode - ] = nullcontext() + self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( + nullcontext() + ) + self.record_stack_traces = record_stack_traces def _checkpoint_modes(self) -> list[Any]: return [ @@ -2017,9 +2062,14 @@ def _init_modes_from_inputs( if hasattr(f, "_orig_mod") and self.record_module_stack: scope_root = f._orig_mod + # _ModuleStackTracer always try to preserve stack trace + # in forward functions self.fx_tracer = _ModuleStackTracer(scope_root) else: self.fx_tracer = PythonKeyTracer() + self.fx_tracer.record_stack_traces = self.record_stack_traces + if self.record_stack_traces: + self.fx_tracer._record_forward_stack_traces_only = True if self.tracing_mode == "fake": import torch._dynamo @@ -2050,9 +2100,9 @@ def _init_modes_from_inputs( allow_non_fake_inputs=self._allow_non_fake_inputs, shape_env=shape_env, ) - assert ( - fake_tensor_mode.shape_env is not None - ), "shape_env should be set if tracing with 'symbolic'" + assert fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) self.fake_tensor_mode = fake_tensor_mode else: if not self.tracing_mode == "real": @@ -2140,9 +2190,9 @@ def inner_wrap_fake(x: object) -> object: return self.fake_tensor_mode.from_tensor(x, source=source) # NB: don't match on bools elif type(x) is int and self.tracing_mode == "symbolic": - assert ( - self.fake_tensor_mode.shape_env is not None - ), "shape_env should be set if tracing with 'symbolic'" + assert self.fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) return self.fake_tensor_mode.shape_env.create_symintnode( self.fake_tensor_mode.shape_env.create_symbol( x, source, positive=None @@ -2155,9 +2205,9 @@ def inner_wrap_fake(x: object) -> object: self.fake_tensor_mode, x ) - assert not isinstance( - x, FakeScriptObject - ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + assert not isinstance(x, FakeScriptObject), ( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) return x wrap_fn_map = { @@ -2271,15 +2321,20 @@ def make_fx( record_module_stack: bool = False, _allow_fake_constant: bool = False, _error_on_data_dependent_ops: bool = True, + record_stack_traces: bool = False, ) -> Callable[..., GraphModule]: """ Given a function f, return a new function which when executed with valid arguments to f, returns an FX GraphModule representing the set of operations that were executed during the course of execution. + + If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"] """ assert tracing_mode in ["real", "fake", "symbolic"] + from torch._inductor import config + make_fx_tracer = _MakefxTracer( decomposition_table, tracing_mode, @@ -2288,6 +2343,7 @@ def make_fx( record_module_stack, _allow_fake_constant, _error_on_data_dependent_ops, + record_stack_traces=record_stack_traces or config.trace.enabled, ) @functools.wraps(f) @@ -2317,9 +2373,9 @@ def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: torch._C._TorchDispatchModeKey.PROXY ) mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) - assert ( - pre_dispatch_mode is None or mode is None - ), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + assert pre_dispatch_mode is None or mode is None, ( + f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + ) return pre_dispatch_mode or mode diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index bb54eba1138409..a9025fc54ebe3e 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -155,7 +155,7 @@ def replacearg(index: int, key: str, fn: Callable): fn=lambda args: tuple(maybe_convert_node(a) for a in args), ) if self.is_evaluate_expr() or self.is_defer_runtime_assert(): - # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert: + # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert: # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. # They must be replaced, since it will be part of a "call_function" FX node for # torch._assert, which will be added to the FX graph of the new shape_env. @@ -175,7 +175,7 @@ def is_evaluate_expr(self) -> bool: return self.name == "evaluate_expr" def is_defer_runtime_assert(self) -> bool: - return self.name == "defer_runtime_assert" + return self.name == "guard_or_defer_runtime_assert" NEST = 0 @@ -228,7 +228,7 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: # # At the moment, there are 2 methods that save the list: # - ShapeEnv.evaluate_expr -# - ShapeEnv.defer_runtime_assert +# - ShapeEnv.guard_or_defer_runtime_assert def record_shapeenv_event( *, save_tracked_fakes: bool = False, name: Optional[str] = None ) -> Callable: @@ -460,7 +460,7 @@ def value_to_str(value: Any) -> str: # Here, we allow the value of each field to be mapped, so that we appropriately # compare the two values. def compare_vars( - map_value: Callable[[str, Any], Any] + map_value: Callable[[str, Any], Any], ) -> list[tuple[str, str, str]]: env1_set, env2_set = set(env1_vars), set(env2_vars) diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index 335c027c9321b8..b1b2f1680d64a1 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -103,7 +103,7 @@ def get_attr( for i, atom in enumerate(atoms): if not hasattr(module_itr, atom): raise RuntimeError( - f'Node referenced nonextent target {".".join(atoms[:i])}!' + f"Node referenced nonextent target {'.'.join(atoms[:i])}!" ) module_itr = getattr(module_itr, atom) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 1ab1e1f64008cf..5468191163ab73 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -149,9 +149,9 @@ def compute_hint(): # This is technically not TV, but this assert is expensive so # let's only do it when we're already doing expensive things computed_hint = compute_hint() - assert ( - hint == computed_hint - ), f"{hint} != {computed_hint} (for {self.expr})" + assert hint == computed_hint, ( + f"{hint} != {computed_hint} (for {self.expr})" + ) else: hint = compute_hint() self._hint = hint @@ -460,7 +460,9 @@ def pow(self, other): return self.float_pow(other) def is_non_overlapping_and_dense(self, sizes, strides): - return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( + to_node(self, 1) + ) # type: ignore[attr-defined] def int_(self): return self.guard_int("", 0) # NB: uses Python backtrace @@ -554,7 +556,7 @@ def expect_true(self, file, line): # a regular guard if we can!) # TODO: file/line here is very important, because the assert has been # deferred so you can't backtrace easily - return self.shape_env.defer_runtime_assert( + return self.shape_env.guard_or_defer_runtime_assert( self.expr, f"{file}:{line}", fx_node=self.fx_node ) @@ -1415,7 +1417,7 @@ def binary_magic_impl(self, other): out, self.shape_env, pytype, - out_hint, + out_hint, # type: ignore[arg-type] fx_node=fx_node, optimized_summation=optimized_summation, # see Note [optimized_summation] ) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index db9825556fa957..e38e5f777d669b 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -39,6 +39,7 @@ Any, Callable, cast, + Generic, NamedTuple, NoReturn, Optional, @@ -46,7 +47,7 @@ TypeVar, Union, ) -from typing_extensions import deprecated, TypeAlias, TypeGuard +from typing_extensions import deprecated, ParamSpec, TypeAlias, TypeGuard import torch import torch.fx @@ -103,6 +104,7 @@ import types from torch import Tensor + from torch._dynamo.source import TensorPropertySource from torch._subclasses.fake_tensor import FakeTensor from torch.types import BoolLikeType, FloatLikeType, IntLikeType @@ -170,6 +172,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError): "is_accessor_node", "ValueRangesSLoc", "SymIntEqByExpr", + "Specialization", ] # FX node metadata keys for symbolic shape FX graph. @@ -179,7 +182,9 @@ class PendingUnbackedSymbolNotFound(RuntimeError): def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: log.debug( - "lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined] + "lru_cache_stats %s: %s", + wrapped_f.__name__, # type: ignore[attr-defined] + wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined] ) @@ -244,7 +249,7 @@ def __hash__(self) -> int: def _nested_int_aware_sort( - tup: tuple[IntLikeType, int] + tup: tuple[IntLikeType, int], ) -> tuple[int, IntLikeType, int]: return ( # Order nested ints by their coefficients. @@ -306,12 +311,15 @@ def uninteresting_files() -> set[str]: import torch._logging import torch._subclasses.fake_tensor import torch._subclasses.meta_utils + import torch.export._trace mods = [ sys.modules[__name__], + torch.export._trace, torch.fx.experimental.recording, torch.fx.experimental.sym_node, torch.fx.interpreter, + torch.fx._symbolic_trace, torch, torch._compile, torch._dynamo.eval_frame, @@ -494,9 +502,9 @@ def check_consistent(new: _T, old: _T) -> None: torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") # NB: bool is subclass of int elif isinstance(new, scalar_types) and not isinstance(new, bool): - assert isinstance(old, scalar_types) and not isinstance( - old, bool - ), f"{old} != {new}" + assert isinstance(old, scalar_types) and not isinstance(old, bool), ( + f"{old} != {new}" + ) torch._check(old == new, lambda: f"{old} != {new} (old != new)") @@ -626,9 +634,9 @@ def rebind_unbacked( raw_u1 = new_raw_u1 if not isinstance(raw_u1, sympy.Symbol): - assert ( - not raw_u1.free_symbols - ), f"should have been constant, but got {raw_u1}" + assert not raw_u1.free_symbols, ( + f"should have been constant, but got {raw_u1}" + ) continue # The old and new could be the same if you improperly hit the memo @@ -1016,6 +1024,20 @@ def find_symbol_binding_fx_nodes( return r +@dataclass(frozen=True) +class Specialization: + """ + This class is used in multi-graph compilation contexts where we generate + multiple specialized graphs and dispatch to the appropriate one at runtime. + This allows us to optimize the trade-off between performance and generality + by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0) + while maintaining a general fallback. + """ + + source: TensorPropertySource + check_fn: Callable + + # Analogous to ConvertIntSource @dataclass(frozen=True) class ConvertIntKey: @@ -1308,7 +1330,14 @@ def compute_unbacked_bindings( isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s ): - if isinstance(old_s, sympy.Symbol): + # If old_s is not an unbacked_symbol, + # we assume that the original unbacked symbol is replaced + # by a backed symbol (old_s). This can happen + # when this node reuses the original symbol (due to memoi) + # and the original symbol gets replaced by the backed symbol. + # When this happens we just replace new_s by the old_s + # because we know the value is the same. + if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s): shape_env._rename_unbacked_to(new_s, old_s) else: shape_env._eliminate_unbacked(new_s, old_s) @@ -1318,6 +1347,7 @@ def compute_unbacked_bindings( return symbol_to_path +# Note [guard_or_] # The following two functions are common utilities used while defining unbacked semantics # of various framework code. Those would be used in situations you prefer to guard and know # the result of the expression over not guarding, but in case you hit a data dependent error @@ -1433,7 +1463,6 @@ def statically_known_true(x: BoolLikeType) -> bool: if not isinstance(x, SymBool): assert isinstance(x, bool) return x - result = _static_eval_sym_bool(x) if result is None: return False @@ -1445,11 +1474,9 @@ def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: """ and, but for symbolic expressions, without bool casting. """ - assert isinstance(x, (bool, SymBool)) if len(others) == 0: return x for y in others: - assert isinstance(y, (bool, SymBool)) x = operator.and_(x, y) return x @@ -1473,17 +1500,15 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: """ or, but for symbolic expressions, without bool casting. """ - assert isinstance(x, (bool, SymBool)) if len(others) == 0: return x for y in others: - assert isinstance(y, (bool, SymBool)) x = operator.or_(x, y) return x def guard_scalar( - a: Union[SymBool, SymInt, SymFloat, int, bool, float] + a: Union[SymBool, SymInt, SymFloat, int, bool, float], ) -> Union[bool, int, float]: """ Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float. @@ -1961,12 +1986,12 @@ def is_derived( def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: - assert isinstance( - symbolic_context, SymbolicContext - ), "Invalid symbolic_context object" - assert ( - type(symbolic_context) is not SymbolicContext - ), "Illegal usage of symbolic_context ABC" + assert isinstance(symbolic_context, SymbolicContext), ( + "Invalid symbolic_context object" + ) + assert type(symbolic_context) is not SymbolicContext, ( + "Illegal usage of symbolic_context ABC" + ) return True @@ -2017,8 +2042,12 @@ class SymIntSymbolicContext(SymbolicContext): constraint: DimConstraint +_P1 = ParamSpec("_P1") +_T1 = TypeVar("_T1") + + @dataclass(frozen=True) -class StatelessSymbolicContext(SymbolicContext): +class StatelessSymbolicContext(Generic[_P1, _T1], SymbolicContext): """ Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. @@ -2029,6 +2058,7 @@ class StatelessSymbolicContext(SymbolicContext): dynamic_strides: DimList[DimDynamic] = None # type: ignore[assignment] constraint_sizes: DimList[DimConstraint] = None # type: ignore[assignment] constraint_strides: DimList[DimConstraint] = None # type: ignore[assignment] + specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None # If the tensor is a view, this should be populated for the base. It contains # information on how to allocate symbols when recursively fakeifying the base # during view fake-ification. @@ -2036,6 +2066,12 @@ class StatelessSymbolicContext(SymbolicContext): # TODO: add storage offset and stride symbolic_context def __post_init__(self) -> None: + if self.specialize_on is None: + object.__setattr__( + self, + "specialize_on", + [[]] * len(self.dynamic_sizes), + ) if self.dynamic_strides is None: object.__setattr__( self, @@ -2150,7 +2186,7 @@ def __eq__(self, other: object) -> bool: def is_symbolic( - val: Union[int, SymInt, float, SymFloat, bool, SymBool] + val: Union[int, SymInt, float, SymFloat, bool, SymBool], ) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]: if isinstance(val, (int, float, bool)): return False @@ -2338,7 +2374,7 @@ def _maybe_evaluate_static_worker( # Note: # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. - # Sympy might give unexepected results when comparing an integer with a non-integer + # Sympy might give unexpected results when comparing an integer with a non-integer # Therefore, we cast offset to int here. # For example: # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) @@ -2429,7 +2465,7 @@ def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr: def cast_symbool_to_symint_guardless( - symbool: Union[bool, torch.SymBool] + symbool: Union[bool, torch.SymBool], ) -> Union[int, torch.SymInt]: """ Converts a SymBool or bool to a SymInt or int without introducing guards. @@ -2494,9 +2530,9 @@ def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: prior_version = self._version_counter prior_key = self._get_key() else: - assert ( - prior_key == self._get_key() - ), "ShapeEnv cache key changed without version being updated!" + assert prior_key == self._get_key(), ( + "ShapeEnv cache key changed without version being updated!" + ) return fn_cache(self, *args, **kwargs) @@ -2747,9 +2783,9 @@ def __init__( def _print_Symbol(self, expr: sympy.Symbol) -> str: assert isinstance(expr, sympy.Symbol), str(type(expr)) - assert self.symbol_to_source.get( - expr - ), f"Unknown symbol {expr} created by constraints solver" + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) return self.symbol_to_source[expr][0].name() @@ -2767,9 +2803,9 @@ def __init__( source_name_to_debug_name: Mapping[str, str], ) -> None: # We try to solve systems of inequalities with 1 free variable. - self._univariate_inequalities: dict[ - sympy.Symbol, set[SympyBoolean] - ] = defaultdict(set) + self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = ( + defaultdict(set) + ) # Among them, we prioritize solving for a free variable that has equalities. # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() # and removing a symbol from the former => removing it from the latter. @@ -2852,9 +2888,10 @@ def mod_handler(*args: sympy.Expr) -> sympy.Expr: # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! base, divisor = args - base, divisor = self.rewrite_with_congruences( - s, base - ), self.rewrite_with_congruences(s, divisor) + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( self._var_to_val ) @@ -2871,9 +2908,10 @@ def floor_div_handler(*args: sympy.Expr) -> sympy.Expr: # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d # and eliminating b % d as above. base, divisor = args - base, divisor = self.rewrite_with_congruences( - s, base - ), self.rewrite_with_congruences(s, divisor) + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( self._var_to_val ) @@ -3035,9 +3073,9 @@ def solve(self) -> None: (arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution, ) - assert isinstance( - solution, sympy.Eq - ), f"Expected an equality constraint for {s}, got {solution}" + assert isinstance(solution, sympy.Eq), ( + f"Expected an equality constraint for {s}, got {solution}" + ) symbol, val = solution.args assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" # because this is univariate, the solution is a specialization @@ -3315,7 +3353,8 @@ def _check_same_range(c: Mapping[str, int], dim: object) -> bool: "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] } if not _check_same_range( - result, name_to_dim[mroot] # type: ignore[index, arg-type] + result, + name_to_dim[mroot], # type: ignore[index, arg-type] ): # ignore if unchanged modified_root_values[mroot] = result # type: ignore[index] break @@ -3343,7 +3382,7 @@ def prettify_results( constraint_violation_error: object, forced_specializations: dict[str, str], ) -> str: - """Format a message for constraint violation erros""" + """Format a message for constraint violation errors""" from torch.export.dynamic_shapes import _get_dim_name_mapping if not self._dcp.source_name_to_debug_name: @@ -3813,6 +3852,8 @@ def _init( self.trace_asserts = trace_asserts + self.specializations: OrderedSet[Specialization] = OrderedSet() + from torch.fx.experimental.validator import translation_validation_enabled self._translation_validation_enabled = translation_validation_enabled() @@ -3864,6 +3905,52 @@ def prefer_deferred_runtime_asserts_over_guards(self) -> bool: def allow_complex_guards_as_runtime_asserts(self) -> bool: return self.settings.allow_complex_guards_as_runtime_asserts + @contextmanager + def patch_source_specialization( + self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] + ) -> Iterator[None]: + """ + Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork" + and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph + compile so we can support various graphs with varying levels of specializations. + + This context manager allows for temporarily adding constraints to the shape environment + based on a specialization function applied to a symbol associated with a source. + + Args: + source: The source of the symbol to specialize + check_fn: A function that takes a sympy Symbol and returns a sympy expression + representing a constraint/specialization to be applied + """ + name = source.name() + sym = self.source_to_var[name] + expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr + new_axioms = dict(self.get_implications(self.simplify(expr))) + added_replacements = {} + + for axiom in new_axioms: + if ( + isinstance(axiom, sympy.Eq) + and isinstance(axiom.lhs, sympy.Symbol) + and isinstance(axiom.rhs, sympy.Integer) + and axiom.lhs not in self.replacements + ): + self.replacements[axiom.lhs] = axiom.rhs + added_replacements[axiom.lhs] = axiom.rhs + self.axioms.update(new_axioms) + + # We need to freeze the ShapeEnv because any additional modification of + # the ShapeEnv will cause unsoundness for subsequent specialization calls. + self.frozen = True + try: + yield + finally: + for k in new_axioms: + self.axioms.pop(k, None) + for k in added_replacements: + self.replacements.pop(k, None) + self.frozen = False + def check_equal(self, other: ShapeEnv) -> None: """Compare another ShapeEnv for equivalence""" # ShapeEnv fields that are not relevant for the outcome of @@ -4049,9 +4136,9 @@ def _constrain_unify(self, a: SymInt, b: SymInt) -> None: if not isinstance(b, SymInt): assert a == b else: - assert isinstance( - b.node.expr, sympy.Symbol - ), "constraining non-Symbols NYI" + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) assert b.node.shape_env is self self.replacements[b.node.expr] = sympy.Integer(a) else: @@ -4064,9 +4151,9 @@ def _constrain_unify(self, a: SymInt, b: SymInt) -> None: self.replacements[a.node.expr] = sympy.Integer(b) else: assert a.node.shape_env is b.node.shape_env - assert isinstance( - b.node.expr, sympy.Symbol - ), "constraining non-Symbols NYI" + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) new_var = self._find(a.node.expr) self.replacements[b.node.expr] = new_var @@ -4159,9 +4246,9 @@ def _create_fx_call_function( # If translation validation is enabled, all arguments must have its # own FX node. - assert all( - a is not None for a in args - ), f"missing arg in FX graph ({op.__name__}): {args}" + assert all(a is not None for a in args), ( + f"missing arg in FX graph ({op.__name__}): {args}" + ) node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) self.name_to_node[node.name] = node @@ -4279,9 +4366,9 @@ def _produce_dyn_sizes_from_int_tuple( source: Source, symbolic_context: SymbolicContext, ) -> list[sympy.Expr]: - assert all( - not is_symbolic(val) for val in tensor_size - ), f"Expect size to be a plain tuple of ints but got {tensor_size}" + assert all(not is_symbolic(val) for val in tensor_size), ( + f"Expect size to be a plain tuple of ints but got {tensor_size}" + ) from torch._dynamo.source import TensorProperty, TensorPropertySource _assert_symbol_context(symbolic_context) @@ -4297,6 +4384,17 @@ def _produce_dyn_sizes_from_int_tuple( do_not_specialize_zero_one=config.backed_size_oblivious, symbolic_context=symbolic_context, ) + if ( + isinstance(symbolic_context, StatelessSymbolicContext) + and symbolic_context.specialize_on + ): + for specialization in symbolic_context.specialize_on[i]: + self.specializations.add( + Specialization( + TensorPropertySource(source, TensorProperty.SIZE, i), + specialization, + ) + ) if ( config.backed_size_oblivious and isinstance(sym, sympy.Symbol) # could be static @@ -4312,7 +4410,11 @@ def create_symbolic_sizes_strides_storage_offset( source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: """ Returns a list of symbolic sizes and strides for the given tensor. We try our best to express stride in terms of the sizes, so as to not @@ -4371,15 +4473,15 @@ def create_symbolic_sizes_strides_storage_offset( # The order of checking the guards matters. In this specific example: # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, - # we may have an unnessary shape speciliazation for y. + # we may have an unnecessary shape speciliazation for y. def _maybe_specialize_sym_int_with_hint( self, maybe_sym: IntLikeType ) -> IntLikeType: assert isinstance(maybe_sym, (int, torch.SymInt)) if is_symbolic(maybe_sym): - assert ( - maybe_sym.node.shape_env is not self - ), "expect the symbol is created from an shape env other than current one." + assert maybe_sym.node.shape_env is not self, ( + "expect the symbol is created from an shape env other than current one." + ) return maybe_sym.node.require_hint() return maybe_sym @@ -4395,7 +4497,11 @@ def _create_symbolic_sizes_strides_storage_offset( source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: dim = len(ex_size) # Reimplement the legacy behavior @@ -4959,9 +5065,9 @@ def create_symbol( sloc, ) else: - self.var_to_range[ - sympy_expr - ] = self._default_unspecified_value_range() + self.var_to_range[sympy_expr] = ( + self._default_unspecified_value_range() + ) self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) # Small performance optimization: if we have a min-max constraint, @@ -5143,7 +5249,7 @@ def produce_guards_verbose( # calls on this new instance. Finally, it will check whether this new instance # has equal state. # - # It's important that we do it in the begining of this function, since it modifies + # It's important that we do it in the beginning of this function, since it modifies # self.dim_constraints through its execution. Changes that happen in this method # aren't interesting, since this is the function call we wish to reproduce at the # end. If we wish to simply reproduce ShapeEnv instances even after this call, @@ -5152,9 +5258,9 @@ def produce_guards_verbose( shape_env = replay_shape_env_events(self.events) self.check_equal(shape_env) - assert len(placeholders) == len( - sources - ), f"len({placeholders}) != len({sources})" + assert len(placeholders) == len(sources), ( + f"len({placeholders}) != len({sources})" + ) Tensorlike = (torch.Tensor, FakeTensorMeta) def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: @@ -5250,9 +5356,9 @@ def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( list ) - symbol_to_constraints: defaultdict[ - sympy.Symbol, set[Constraint] - ] = collections.defaultdict(set) + symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = ( + collections.defaultdict(set) + ) constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] printers: list[_ShapeGuardPrinter] = [] @@ -5431,8 +5537,9 @@ def hint(s: sympy.Expr) -> str: user_stack = self.specialization_stacks.get(source, None) msg = ( f"You marked {self._debug_name(source)} as dynamic but your code " - f"specialized it to be a constant ({val}). Either remove the mark_dynamic " - f"or use a less strict API such as maybe_mark_dynamic or Dim.AUTO." + f"specialized it to be a constant ({val}). If you're using mark_dynamic, " + f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, " + f"replace it with either Dim.STATIC or Dim.AUTO." + ( "\n\nUser stack:\n" + "".join(user_stack.format()) if user_stack @@ -6139,7 +6246,7 @@ def _maybe_evaluate_static( Use compute_hint == True if you are trying to compute a non-binding hint for the particular hint values of backed and unbacked SymInts, - e.g., if s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. + e.g., if s0 happens to be 3 this run, compute_hint will substitute s0 with 3. """ # axioms with compute hint NYE @@ -6160,7 +6267,7 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: # A FloorDiv in implications could have became CleanDiv at this point, due to new facts # to the shapeEnv. This handles such issue but its not ideal. This is the only expression # simplification that depends on the global state of shape env. - # TODO try to get rid of CleanDiv since it breaks the invariant thats simplifications of sympy + # TODO try to get rid of CleanDiv since it breaks the invariant that's simplifications of sympy # expressions only depend on the expression itself. if k.has(FloorDiv): new_items.update({self.simplify(k): v}) @@ -6242,6 +6349,24 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: expr = safe_expand(expr) expr = self.replace(expr) + # Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced + # expression when creating contiguous strides. + if not size_oblivious: + min_max_replacements = {} + for atom in expr.atoms(Max): # type: ignore[has-type] + if len(atom.args) > 2: + continue + a, b = atom.args + if b == 1 or b == 0: + a, b = b, a + + if a == 1 and self._maybe_evaluate_static(sympy.Ge(b, 1)): + min_max_replacements[atom] = b + if a == 0 and self._maybe_evaluate_static(sympy.Ge(b, 0)): + min_max_replacements[atom] = b + if min_max_replacements: + expr = expr.xreplace(min_max_replacements) + if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type] min_max_replacements = {} for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type] @@ -6377,7 +6502,7 @@ def size_hint( ), }, ) - self.defer_runtime_assert( + self.guard_or_defer_runtime_assert( sympy.Eq(result_expr, unsound_expr), f"propagate_real_tensors: {result_expr} == {unsound_expr}", ) @@ -6433,7 +6558,7 @@ def _make_data_dependent_error( f"Caused by: {sloc}\n" 'For more information, run with TORCH_LOGS="dynamic"\n' "For extended logs when we create symbols, also add " - f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n' "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" "For more debugging help, see " "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" @@ -6567,9 +6692,9 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: ) self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) tgt_bound = self.bound_sympy(tgt) - assert tgt_bound.issubset( - src_bound - ), f"{tgt_bound=} not a subset of {src_bound=}" + assert tgt_bound.issubset(src_bound), ( + f"{tgt_bound=} not a subset of {src_bound=}" + ) # TODO: Should we propagate size-like-ness? # @@ -6699,12 +6824,20 @@ def _find(self, a: sympy.Symbol) -> sympy.Expr: return self.replacements[a] @lru_cache(256) - def _maybe_guard_rel(self, expr: sympy.Rel) -> None: + def _maybe_guard_rel(self, expr: sympy.Expr) -> None: """ The relational guard is guarded to be true. Use this information to simplify shapes (i.e. a == b or a % 5 == 0) """ - assert isinstance(expr, sympy.Rel) + if isinstance(expr, sympy.And): + for arg in expr.args: + self._maybe_guard_rel(arg) + return + elif not isinstance(expr, sympy.Rel): + log.warning( + "_maybe_guard_rel() was called on non-relation expression %s", expr + ) + return # A good example of what goes wrong if you don't do this is # python test/functorch/test_aotdispatch.py -k @@ -6714,9 +6847,9 @@ def _maybe_guard_rel(self, expr: sympy.Rel) -> None: free = list(expr.free_symbols) - assert ( - len(free) > 0 - ), f"The expression should not be static by this point: {expr}" + assert len(free) > 0, ( + f"The expression should not be static by this point: {expr}" + ) # In case of really gnarly expression, we don't blow up if len(free) > 5: return @@ -6898,13 +7031,12 @@ def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None: stack_info=True if log.getEffectiveLevel() < logging.WARNING else False, ) - def _get_user_frame(self) -> types.FrameType: + def _get_user_frame(self) -> Optional[types.FrameType]: frame = inspect.currentframe() while frame is not None: if frame.f_code.co_filename not in uninteresting_files(): return frame frame = frame.f_back - assert frame is not None return frame def _get_stack_summary( @@ -6914,11 +7046,12 @@ def _get_stack_summary( if floc is None: frame = self._get_user_frame() try: - floc = traceback.FrameSummary( - frame.f_code.co_filename, - frame.f_lineno, - frame.f_code.co_name, - ) + if frame is not None: + floc = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) finally: del frame @@ -7184,8 +7317,6 @@ def _log_real_tensor_propagation( }, ) - @lru_cache(256) - @record_shapeenv_event(save_tracked_fakes=True) def evaluate_expr( self, orig_expr: sympy.Basic, @@ -7275,8 +7406,9 @@ def _evaluate_expr( ): return orig_expr - # Don't track this one - @functools.lru_cache(None) + # Don't track this one. (Because this cache is inside this function the + # cache only lasts for the invocation of this function call) + @functools.cache def compute_concrete_val() -> sympy.Basic: if hint is None: # This is only ever called for expressions WITHOUT unbacked @@ -7356,9 +7488,11 @@ def compute_concrete_val() -> sympy.Basic: if static_expr is not None: self.log.debug( "eval %s == %s [statically known]", - f"size_oblivious({orig_expr})" - if size_oblivious - else size_oblivious, + ( + f"size_oblivious({orig_expr})" + if size_oblivious + else size_oblivious + ), static_expr, ) if ( @@ -7477,7 +7611,7 @@ def compute_concrete_val() -> sympy.Basic: g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] if transmute_into_runtime_assert: - self.defer_runtime_assert( + self.guard_or_defer_runtime_assert( g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" ) return concrete_val @@ -7485,15 +7619,14 @@ def compute_concrete_val() -> sympy.Basic: if not self._suppress_guards_tls(): self._log_guard("eval", g, forcing_spec=forcing_spec) - if isinstance(g, sympy.Rel): - # TODO: If we successfully eliminate a symbol via equality, it - # is not actually necessary to save a guard for the equality, - # as we will implicitly generate a guard when we match that - # input against the symbol. Probably the easiest way to - # implement this is to have maybe_guard_rel return a bool - # saying if it "subsumed" the guard (and therefore the guard - # is no longer necessary) - self._maybe_guard_rel(g) + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) if not self.allow_complex_guards_as_runtime_asserts: # at this point, we've evaluated the concrete expr value, and have @@ -7508,7 +7641,7 @@ def compute_concrete_val() -> sympy.Basic: # it's fine to defer simple guards here without checking, # the _maybe_guard_rel() call above will set replacements if possible, # and so the result here will be statically known - self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") else: self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) @@ -7553,17 +7686,18 @@ def cleanup(self) -> None: @lru_cache(256) @record_shapeenv_event(save_tracked_fakes=True) - def defer_runtime_assert( + def guard_or_defer_runtime_assert( self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None ) -> bool: - """Create an assert that is checked at runtime + """ + Adds a guard that orig_expr is True if we can or fall back to adding an assert + that is checked at runtime. Args: orig_expr (sympy.Expr): Boolean expression to assert is true msg (str): Message to display on assertion failure fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding to the expression, if applicable - """ expr = orig_expr @@ -7607,8 +7741,7 @@ def defer_runtime_assert( log.debug("runtime_asserts_frozen but then got %s", expr) self._check_frozen(expr, sympy.true) # eliminate symbols on equality tests / refine ranges - if isinstance(expr, sympy.Rel): - self._maybe_guard_rel(expr) + self._maybe_guard_rel(expr) # canonicalise to remove equations that are trivially equal orig_expr = expr diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index 4f160995cce0a4..11cc8bd59a736c 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -354,7 +354,7 @@ def __setstate__(self, d): self._cache = {} @property - def __doc__(self): + def __doc__(self): # type: ignore[override] docs = [f"Multiply dispatched method: {self.name}"] if self.doc: diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 9c91cca2067afc..0b21183c40b97a 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -5,9 +5,9 @@ __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] -def raises(err, lamda): +def raises(err, lamda): # codespell:ignore lamda try: - lamda() + lamda() # codespell:ignore lamda return False except err: return True diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 7634c9b2ec90b8..a8035f75d30277 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -23,9 +23,9 @@ def transitive_get(key, d): return key -def raises(err, lamda): +def raises(err, lamda): # codespell:ignore lamda try: - lamda() + lamda() # codespell:ignore lamda return False except err: return True diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 17a814b233c63f..db00952512067c 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -203,9 +203,7 @@ def floordiv( return _Z3Ops.to_real(result) if cast_result_to_real else result def ceil(self, number: z3.ArithRef) -> z3.ArithRef: - return z3.If( - self.floor(number) < number, self.floor(number + 1), number - ) # type: ignore[return-value] + return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value] def trunc(self, number: z3.ArithRef) -> z3.ArithRef: return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] @@ -363,9 +361,9 @@ def call_function( return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] # Adds the Z3 expression corresponding to the first argument # as a validator input. - assert ( - len(args) == 1 - ), f"expected 1 argument on assertion. Got: {len(args)} " + assert len(args) == 1, ( + f"expected 1 argument on assertion. Got: {len(args)} " + ) self.validator.add_source_expr(args[0]) # type: ignore[arg-type] # Translates SymPy expressions into Z3 expressions. @@ -536,9 +534,9 @@ def _check_freesymbols(self, e: sympy.Basic) -> None: def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: z3expr = SympyToZ3(self).run(e) - assert isinstance( - z3expr, z3.BoolRef - ), f"expected boolean expression. Got: {z3expr}" + assert isinstance(z3expr, z3.BoolRef), ( + f"expected boolean expression. Got: {z3expr}" + ) return z3expr def add_source_expr(self, e: z3.BoolRef) -> None: @@ -653,7 +651,7 @@ def _validate(self) -> None: def translation_validation_enabled() -> bool: - # Checks everytime this function is called, in case the Dynamo + # Checks every time this function is called, in case the Dynamo # option is set, but Z3 is not installed. _assert_z3_installed_if_tv_set() return _HAS_Z3 and config.translation_validation diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 83b288196d302f..6b9622be2a8eee 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -449,7 +449,7 @@ def type_repr(o: Any): # This code-path used in Python < 3.9 return origin_typename - return f'{origin_typename}[{",".join(args)}]' + return f"{origin_typename}[{','.join(args)}]" else: # Bare type, such as `typing.Tuple` with no subscript # This code-path used in Python 3.9+ @@ -573,7 +573,7 @@ def append_stacktrace_summary(node: Node): summary_str = parsed_stack_trace.get_summary_str() else: summary_str = "" - body.append(f'\n {dim(f"# {summary_str}")}\n') + body.append(f"\n {dim(f'# {summary_str}')}\n") elif prev_stacktrace != "": prev_stacktrace = "" no_stacktrace_msg = "# No stacktrace found for following nodes" @@ -842,7 +842,7 @@ def gen_fn_def(self, free_vars, maybe_return_annotation): if len(has_annotation) > 0: fn_definition += "\n " + "".join(has_annotation) + "\n" fn_definition += f""" - {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" return fn_definition def generate_output(self, output_args): @@ -1005,7 +1005,7 @@ def find_nodes( Returns: - Iteratable of nodes with the requested op and target. + Iterable of nodes with the requested op and target. """ node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) if sort: @@ -1565,7 +1565,7 @@ def python_code( # To do this, we create a new namespace just for this source. All names # that get printed must come from this namespace. # - # Why can't we re-use node.name? Because it was generated within the + # Why can't we reuse node.name? Because it was generated within the # namespace `self._graph_namespace`. In order to provide uniqueness # over both locals (node.name) *and* globals, we create a completely # new namespace to put all identifiers in. @@ -1573,7 +1573,7 @@ def python_code( # Override Node's repr to generate a valid name within our namespace. # Since repr() is designed to produce a valid Python expression, it - # makes sense to re-use it. This way, it's easy to print something like + # makes sense to reuse it. This way, it's easy to print something like # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is # implemented cooperatively to allow this. def node_repr(n: Node): @@ -1877,7 +1877,9 @@ def insert_pdb(body): # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( - lambda body: insert_pdb(current_trans(body) if current_trans else body) + lambda body: insert_pdb( + current_trans(body) if current_trans else body + ) ) ) @@ -1916,7 +1918,7 @@ def on_generate_code_context_manager(): @contextmanager def _override_sym_repr( - override: Callable[["torch.types.PySymType"], str] + override: Callable[["torch.types.PySymType"], str], ) -> Iterator[None]: tmp = CodeGen._sym_repr try: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index c51d00a5249a93..2e1a0963f53b6a 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -204,6 +204,14 @@ def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: tracer_extras = body.get("_tracer_extras", {}) graph = KeepModules().trace(com, **tracer_extras) + # Recover node.meta["stack_trace"] after re-tracing + node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None) + if node_meta_stack_trace is not None: + del body["_graphmodule_graph_node_meta_stack_trace"] + for node in graph.nodes: + if node_meta_stack_trace.get(node.name, None) is not None: + node.meta["stack_trace"] = node_meta_stack_trace[node.name] + # Manually set Tracer class on the reconstructed Graph, to avoid # referencing the private local subclass KeepModules. graph._tracer_cls = tracer_cls @@ -316,9 +324,9 @@ def _print_readable( colored=False, ): graph = module.graph - assert graph is not None and isinstance( - graph, torch.fx.Graph - ), "print_readable must be used on a module with a graph" + assert graph is not None and isinstance(graph, torch.fx.Graph), ( + "print_readable must be used on a module with a graph" + ) verbose_python_code = graph.python_code( root_module="self", @@ -859,6 +867,16 @@ def __reduce_package__(self, exporter: PackageExporter): dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ del dict_without_graph["_graph"] + # Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization + node_meta_stack_trace = { + node.name: node.meta["stack_trace"] + for node in self.graph.nodes + if "stack_trace" in node.meta + } + dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = ( + node_meta_stack_trace + ) + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" python_code = self.recompile() import_block = _format_import_block(python_code.globals, exporter.importer) @@ -977,7 +995,7 @@ def _replicate_for_data_parallel(self): @contextlib.contextmanager def _set_replace_hook(self, f): """ - Takes a callable which will be called everytime when we replace a node + Takes a callable which will be called every time when we replace a node to a new node, or change the node's name. Callable takes three arguments: the old node we're changing, and NAME of the new node, followed by the user node which consumes the old node to be replaced. @@ -991,7 +1009,7 @@ def _set_replace_hook(self, f): def _register_replace_node_hook(self, f): """ - Takes a callable which will be called everytime when we replace a node + Takes a callable which will be called every time when we replace a node to a new node, or change the node's name. Callable takes three arguments: the old node we're changing, and NAME of the new node, followed by the user node which consumes the old node to be replaced. @@ -1001,7 +1019,7 @@ def _register_replace_node_hook(self, f): def _unregister_replace_node_hook(self, f): """ - Takes a callable which was previously registered to be called everytime when we replace a node. + Takes a callable which was previously registered to be called every time when we replace a node. This function will unregister that callable so it is no longer invoked on node replacement. """ assert callable(f), "create_node hook must be a callable." diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 86648541e3425c..e2d2f9d7466dd2 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -51,7 +51,9 @@ class Interpreter: method equivalents). We could subclass Interpreter like so:: class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) @@ -405,7 +407,7 @@ def fetch_attr(self, target: str): for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError( - f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}" + f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}" ) attr_itr = getattr(attr_itr, atom) return attr_itr @@ -468,14 +470,20 @@ class Transformer(Interpreter): class NegSigmSwapXformer(Transformer): def call_function( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], ) -> Any: if target == "neg": call_self, *args_tail = args diff --git a/torch/fx/node.py b/torch/fx/node.py index 7f51fe20201c4d..a6b65703287a45 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -77,7 +77,7 @@ # Dynamo is unable to trace global set[Callable].__contains__. # See https://github.com/pytorch/pytorch/issues/145761. Since we only have # a handful of ops so switch to list of callables. -_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable] = [ +_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [ torch._C._set_grad_enabled, torch.amp._enter_autocast, torch.amp._exit_autocast, @@ -85,7 +85,7 @@ # TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, # or add logic to correctly mark all inplace ops as side effectful. -_side_effectful_functions: set[Callable] = { +_side_effectful_functions: set[Callable[..., Any]] = { torch._assert, torch._assert_async, _ops.aten._assert_async.msg, @@ -98,7 +98,8 @@ _ops.profiler._record_function_exit, _ops.inductor.accumulate_grad_.default, operator.setitem, -} | set(_side_effectful_need_to_be_preserved_pre_dispatch) + *_side_effectful_need_to_be_preserved_pre_dispatch, +} if hasattr(_ops.inductor, "resize_storage_bytes_"): _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) @@ -244,7 +245,7 @@ class Node(_NodeBase): # should not be accessed directly. _input_nodes: dict["Node", None] # All of the nodes that use the value produced by this Node - # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # Note one user may correspond to several uses, e.g. the node for ``x + x`` # would appear once here, but represents two uses. # Is a dict to act as an "ordered set". Keys are significant, value dont-care users: dict["Node", None] @@ -514,9 +515,9 @@ def insert_arg(self, idx: int, arg: Argument) -> None: idx (int): The index of the element in ``self.args`` to be inserted before. arg (Argument): The new argument value to insert into ``args`` """ - assert ( - 0 <= idx <= len(self.args) - ), "insert_args index must be between 0 and len(self.args)" + assert 0 <= idx <= len(self.args), ( + "insert_args index must be between 0 and len(self.args)" + ) args_left = self.args[:idx] args_right = self.args[idx:] @@ -747,13 +748,13 @@ def is_impure(self, impure_random: bool = True) -> bool: # Check if an impure module. if self.op == "call_module": - assert ( - self.graph.owning_module is not None - ), "self.graph.owning_module not set for purity check" + assert self.graph.owning_module is not None, ( + "self.graph.owning_module not set for purity check" + ) target_mod = self.graph.owning_module.get_submodule(self.target) - assert ( - target_mod is not None - ), f"Did not find expected submodule target {self.target}" + assert target_mod is not None, ( + f"Did not find expected submodule target {self.target}" + ) return getattr(target_mod, "_is_impure", False) return False diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index d5ed709fa7e22d..bc7537c23847f4 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -64,8 +64,8 @@ # manage to eliminate all float compute, this ends up being equivalent, but # there is a critical difference when some floats cannot be eliminated: when # we call item() on them, what should it's SymFloat be? Ideally, it would -# be the same backed SymFloat we had before. But without symbolic expresssion -# propogation on tensor quantities, repropagating would instead give you an +# be the same backed SymFloat we had before. But without symbolic expression +# propagation on tensor quantities, repropagating would instead give you an # unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation # on 0d scalar tensors, but I decided to go for something simpler to start. # @@ -190,6 +190,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: return expr_to_tensor_proxy[expr] + failed_tensorify_ops: set[str] = set() nodes = list(graph.nodes) for i, node in enumerate(nodes[:-1]): with graph.inserting_before( @@ -302,8 +303,15 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: metrics_context.set( "tensorify_float_success", True, overwrite=True ) - - failed_tensorify_ops: set[str] = set() + else: + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + failed_tensorify_ops.update(str(node.target)) + log.info("Failed to tensorify %s", str(node.target)) # Now do one more pass that specializes all symfloats we didn't manage # to tensorify away. @@ -332,15 +340,13 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: # # It's better to guard on zf // 2 == 2.0 than zf == 5.0 - failed_tensorify_ops.update(str(key) for key in node.users.keys()) - node.replace_all_uses_with(guard_scalar(val)) graph.erase_node(node) # Sometimes by the time we get to tensorify, there have already been # specializations, eg. in python_arg_parser.h. In these cases, # placeholder nodes no longer have a reference to their original - # symfloat and thus we need to deduce specializations have happend + # symfloat and thus we need to deduce specializations have happened # via shape_env.replacements. NB: there's an important invariant here # that symfloats keep consistent names across restarts. for k, v in shape_env.var_to_val.items(): diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index f559aa0bfcb3d9..6026e9ca25c05c 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -88,7 +88,7 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: """ # Total num of elements total_num_of_elems = 0 - # For a module, conside all parameters + # For a module, consider all parameters if node.op == "call_module": submodule_dict = dict(fx_module.named_modules()) submodule = submodule_dict[node.target] diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index e19abc7ad3d8b7..8e59fc7ae1793a 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -48,7 +48,7 @@ def __init__( self.erased_nodes: set[str] = set() self.created_nodes: set[str] = set() self.name_to_node: dict[str, Node] = {} - # record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context + # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context self.copied_gms: list[GraphModule] = [] self._node_creation_hook = self.get_node_creation_hook() diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 3fef16c6adb3f2..438661090942a8 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -2,6 +2,7 @@ import collections import itertools import logging +import operator from collections.abc import Iterable, Sequence from typing import Optional @@ -215,7 +216,7 @@ def _update_partition_map(node: Node, id: int): # merge all possible partitions for partition_id, _ in sorted( - partitions_order.items(), key=lambda item: item[1] + partitions_order.items(), key=operator.itemgetter(1) ): merge_candidates[partition_id] = None diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 68753d9351f103..4077e74360f568 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -78,7 +78,7 @@ def _topological_sort_passes( if len(constraints) == 0: return passes - # Contruct a graph mapping nodes to a list of their users + # Construct a graph mapping nodes to a list of their users graph: dict[Callable, list[Callable]] = {p: [] for p in passes} indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) candidates: Queue = Queue() diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index edcb842cc89221..8c15b9097397b5 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -95,7 +95,7 @@ class _MinimizerBase: Currently we provides two ways to traverse the graph and generate submodules. 1. Sequential traversal: this will traverse the graph node by node and generate - one submodule with one sigle node. + one submodule with one single node. 2. Binary searching: this will do a binary search style traversal on the graph. For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. @@ -648,7 +648,7 @@ def _block_traverse( ) -> NodeSet: """ Traverse topologically sorted node list - Find minimium block (start_idx, end_idx) which contains the culprit + Find minimum block (start_idx, end_idx) which contains the culprit 1st pass: search for end_idx by finding the last node in culprit block where Numerical accuracy (0, end_idx) > threshold 2nd pass: search for start_idx by finding the first node in culprit block @@ -770,9 +770,9 @@ def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: node_name = node.name if node_name is not None and isinstance(node_name, tuple): node_name = node_name[0] - assert node_name is not None and isinstance( - node_name, str - ), f"minimize: node_name: {node_name}" + assert node_name is not None and isinstance(node_name, str), ( + f"minimize: node_name: {node_name}" + ) report.append(f"Add node: {node_name}") diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index ddb1410f684066..48dfe702fedbb5 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -93,9 +93,9 @@ def loop_pass( predicate (Callable[Object, bool], optional): """ - assert (n_iter is not None) ^ ( - predicate is not None - ), "Exactly one of `n_iter`or `predicate` must be specified." + assert (n_iter is not None) ^ (predicate is not None), ( + "Exactly one of `n_iter`or `predicate` must be specified." + ) @wraps(base_pass) def new_pass(source): diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 0fcd72938367c5..6027c603ec1feb 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -3,6 +3,7 @@ import itertools from collections import defaultdict from enum import Enum +from typing import Any, Callable import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -187,7 +188,7 @@ def _maybe_get_inplace_op(op): return inplace_op -_VIEW_INVERSE_MAP = { +_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = { torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, @@ -252,6 +253,7 @@ def matching_view_metadata(a, b): assert isinstance(base.meta["fake_result"], FakeTensor) assert isinstance(mutated_view, Node) assert isinstance(mutated_view.meta["fake_result"], FakeTensor) + assert not isinstance(n.target, str) # Check that this view_inverse op actually corresponds to taking doing the inverse # of one of our existing self_alias nodes. original_view = _VIEW_INVERSE_MAP[n.target] @@ -264,7 +266,7 @@ def matching_view_metadata(a, b): continue self_alias_base = self_alias.meta["view_of"] try: - # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # The we're trying to reuse the args from the view_scatter call inside of the corresponding # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse # of the current alias we're looking at. view_replay_metadata = original_view( @@ -289,7 +291,7 @@ def reinplace(gm, *sample_args): mutating the nodes of the graph. We look for out-of-place op call sites like `b = a.add(...)`, and convert them to be inplace (`b = a.add_(...)`), - as long as the input to the current operator ("a") isn't re-used + as long as the input to the current operator ("a") isn't reused anywhere later in the graph. This pass currently expects to operate on a **functional, ATen** graph. @@ -340,7 +342,7 @@ def reinplace(gm, *sample_args): NOTE: there's a future optimization that we should make: if "a" is a (alias of a) program input, but later in the program there is a node that looks like "a.copy_(...)", - Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + Then re-inplacing is ok to do - we are temporarily reusing a's buffer, which will later be overwritten by the copy_() call. This will be an important optimization to have for programs that mutate @@ -597,7 +599,7 @@ def _add_to_map(x): later_node_usages, self_aliases ) - # Step 2: Check to see if the input to the op is re-used later in the graph. + # Step 2: Check to see if the input to the op is reused later in the graph. # If not (same goes for its aliases), then this op is safe to re-in place. # This is a slightly roundabout way to check that there are no later usages of the current self argument. # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index b0479fd84b024a..38c64c527aff06 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -397,7 +397,9 @@ def has_new_unbacked_bindings(): nn_module_stack=node.meta.get("nn_module_stack"), ), ): - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + expr_to_proxy[sym_expr] = _sympy_interp( + expr_to_proxy, sym_expr + ) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index 05fb3b5dbaf606..3815b2f058f0c6 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -7,7 +7,7 @@ import torch.fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode -from torch._prims_common import definitely_contiguous_for_memory_format +from torch._prims_common import contiguous_for_memory_format_or_false from torch._subclasses.meta_utils import is_sparse_any from torch.fx._compatibility import compatibility from torch.fx.node import map_aggregate, Node @@ -35,8 +35,8 @@ class TensorMetadata(NamedTuple): # When include_contiguity is True, we will set contiguity when its always true for the tensor. # Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3). -# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous, -# contiguous, and unknown). +# In such situation contiguity is not set. We could also make it a tri-state i.e: (def_contiguous, +# def_not_contiguous and unknown). def _extract_tensor_metadata( result: torch.Tensor, include_contiguity=True ) -> TensorMetadata: @@ -57,7 +57,7 @@ def _extract_tensor_metadata( torch.channels_last_3d, } for query_format in memory_formats: - if definitely_contiguous_for_memory_format( + if contiguous_for_memory_format_or_false( result, memory_format=query_format ): memory_format = query_format diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 926747b2a41ff2..079b1b4364bd8c 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -199,9 +199,9 @@ def flatten(x: torch.fx.node.Argument) -> NodeList: mx = max((c.order for c in upstream_components), default=0) # Expect the component for `node` has higher order then its upstream components. - assert ( - comp.order >= mx - ), f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + assert comp.order >= mx, ( + f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + ) # Map a input of `node` to nodes in the component's graph. def remap_func(x): diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 2b1d9ee616b2a9..d3ef35bdb10709 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -222,7 +222,7 @@ class SplitResult(NamedTuple): split_module: root module after splitting. submodule_inputs: a dict that maps submodule name to its inputs. non_acc_submodule_prefix: the prefix for non acc submodules. For - acc submodule the prefix is alwasy "_run_on_acc_". + acc submodule the prefix is always "_run_on_acc_". """ split_module: torch.fx.GraphModule @@ -261,7 +261,8 @@ def pre_forward(module, module_inputs): for name, mod in model.named_modules(): if name in target_submodules: - handles.append(mod.register_forward_pre_hook(pre_forward)) + if not isinstance(mod, torch.jit.ScriptModule): + handles.append(mod.register_forward_pre_hook(pre_forward)) def clean_up_handles(): for h in handles: diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 7487bc2c6631bf..1b22490405de51 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import copy from queue import SimpleQueue from typing import Optional as _Optional @@ -9,14 +8,14 @@ from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet -from torch.fx.passes.utils import lift_subgraph_as_module +from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] @compatibility(is_backward_compatible=False) def topo_sort(nodes: NodeList) -> NodeList: # sort nodes according to the topological order indegree_map = dict.fromkeys(nodes, 0) - candidates: SimpleQueue = SimpleQueue() + candidates: SimpleQueue[Node] = SimpleQueue() for node in nodes: for n in node.all_input_nodes: @@ -36,16 +35,16 @@ def topo_sort(nodes: NodeList) -> NodeList: if indegree_map[n] == 0: candidates.put(n) - assert len(nodes) == len( - sorted_nodes - ), "topological sorted nodes doesn't have same length as input nodes" + assert len(nodes) == len(sorted_nodes), ( + "topological sorted nodes doesn't have same length as input nodes" + ) return sorted_nodes @compatibility(is_backward_compatible=False) def validate_partition(partition: NodeList) -> bool: - # verify the partition does't form a dependency cycle in the original graph + # verify the partition doesn't form a dependency cycle in the original graph # returns True for valid partition, False for invalid partition_set = set(partition) @@ -127,13 +126,13 @@ def fuse_as_graphmodule( # assumption: nodes are already sorted in topo order for node in nodes: - assert ( - node.graph.owning_module is gm - ), f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert node.graph.owning_module is gm, ( + f"{node} doesn't belong to passed in graph module {gm._get_name()}" + ) assert not node._erased, f"{node} has been removed from owning graph" - assert ( - node in gm.graph._find_nodes_lookup_table - ), f"{node} is not found in graph module {gm._get_name()}" + assert node in gm.graph._find_nodes_lookup_table, ( + f"{node} is not found in graph module {gm._get_name()}" + ) # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" @@ -150,7 +149,7 @@ def fuse_as_graphmodule( node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph # handles inputs through graph.node_copy's arg_transform functions - def remap_inputs(x): + def remap_inputs(x: Node) -> Node: if x.op == "get_attr": # TODO: do we really need copy the get_attr node into the graph? # do something here @@ -158,13 +157,13 @@ def remap_inputs(x): if x in partition_lookup_table: # x is inside subgraph, return the copied node - # the node should have been copied aleady, as we are copying graph in the topological order + # the node should have been copied already, as we are copying graph in the topological order return node_map[x] if x not in node_to_placeholder: # x is not in subgraph, create a new placeholder for subgraph placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) - # copy all meta fields, even if some fields might be irrelvant for the placeholder node + # copy all meta fields, even if some fields might be irrelevant for the placeholder node placeholder_node.meta = copy.copy(x.meta) node_to_placeholder[x] = placeholder_node @@ -195,7 +194,7 @@ def remap_inputs(x): subgraph.output(outs[0] if len(outs) == 1 else outs) # lint to ensure correctness - subgraph.lint() + subgraph.lint() # type: ignore[no-untyped-call] fused_gm: GraphModule fused_gm, _ = lift_subgraph_as_module( gm, subgraph, comp_name="", class_name=module_name @@ -216,7 +215,7 @@ def insert_subgm( sub_gm: GraphModule, orig_inputs: tuple[Node, ...], orig_outputs: tuple[Node, ...], -): +) -> GraphModule: # add sub_gm into gm submodule_name = sub_gm.__class__.__name__ gm.add_submodule(submodule_name, sub_gm) @@ -241,7 +240,7 @@ def insert_subgm( @compatibility(is_backward_compatible=False) -def erase_nodes(gm: GraphModule, nodes: NodeList): +def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: # erase original nodes in inversed topological order for node in reversed(nodes): gm.graph.erase_node(node) diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index 27d24ed29945de..4ecbe8640def1f 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -96,9 +96,9 @@ def __init__( for node in pattern.nodes: if node.op != "output": - assert ( - len(node.users) > 0 - ), "SubgraphMatcher cannot be initialized with an pattern with dead code" + assert len(node.users) > 0, ( + "SubgraphMatcher cannot be initialized with an pattern with dead code" + ) # TODO: assert pattern is a connected graph @@ -137,11 +137,14 @@ def _match_attributes(self, pn: Node, gn: Node) -> bool: raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") return False - def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: + def _nodes_are_equal(self, pn: Node, gn: Node, node_name_match: str = "") -> bool: # if exact match for placeholder is not required, then use placeholder as a wildcard if not self.match_placeholder and pn.op == "placeholder": return True + if node_name_match and node_name_match in gn.name: + return True + if pn.op == gn.op: if pn.op == "placeholder" or pn.op == "output": return True @@ -192,9 +195,9 @@ def _remove_overlapping_matches( return non_overlapping_matches def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: - assert not ( - isinstance(pn, Node) and isinstance(gn, Node) - ), "pn and gn cannot both be Node" + assert not (isinstance(pn, Node) and isinstance(gn, Node)), ( + "pn and gn cannot both be Node" + ) if isinstance(pn, Node) and not isinstance(gn, Node): if pn.op == "placeholder": @@ -212,7 +215,9 @@ def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: else: return type(gn) == type(pn) and gn == pn - def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: + def _match_nodes( + self, pn: Node, gn: Node, match: InternalMatch, node_name_match: str = "" + ) -> bool: logger.info(" matching %s to %s", pn, gn) assert isinstance(pn, Node) and isinstance(gn, Node), str( @@ -228,7 +233,7 @@ def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: if gn in match.nodes_map.values(): return False - if not self._nodes_are_equal(pn, gn): + if not self._nodes_are_equal(pn, gn, node_name_match): return False # Optimistically mark `pn` as a match for `gn`, and save a local copy of match @@ -313,11 +318,11 @@ def get_all_arguments(orig_args, orig_kwargs): return True - def match(self, graph: Graph) -> list[InternalMatch]: + def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: """ Returns: The matched subgraphs. - Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder + The returned subgraph would be fully self-contained, meaning the nodes (except placeholder and nodes returned by output) can only be consumed by nodes within the matched subgraph. Subgraph pattern matcher is implemented with the backtracking style in the following steps: @@ -355,7 +360,7 @@ def match(self, graph: Graph) -> list[InternalMatch]: match_candidates: dict[Node, list[Node]] = defaultdict(list) for pattern_anchor in self.pattern_anchors: for node in graph.nodes: - if self._nodes_are_equal(pattern_anchor, node): + if self._nodes_are_equal(pattern_anchor, node, node_name_match): match_candidates[pattern_anchor].append(node) match_candidates_list = list(match_candidates.items()) @@ -382,7 +387,9 @@ def backtracking(anchor_index, match): for node in candidate_nodes: logger.info("Trying to match anchor %s to %s", pattern_anchor, node) - match_found = self._match_nodes(pattern_anchor, node, match) + match_found = self._match_nodes( + pattern_anchor, node, match, node_name_match + ) if match_found: # match next anchor backtracking(anchor_index + 1, match) diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 1fa9b721e9ccdb..3114d55b635fcb 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -18,17 +18,17 @@ def _split_to_graph_and_name_node_map( if n.op == "output": assert gm._out_spec is not None output = tree_unflatten(n.args[0], gm._out_spec) - assert isinstance( - output, tuple - ), "Expecting the pattern graph to return a tuple" - assert ( - len(output) >= 2 - ), "Expecting the pattern graph to have at least two outputs" + assert isinstance(output, tuple), ( + "Expecting the pattern graph to return a tuple" + ) + assert len(output) >= 2, ( + "Expecting the pattern graph to have at least two outputs" + ) *out, name_node_map = output flattened, out_spec = tree_flatten(out) - assert isinstance( - name_node_map, dict - ), "Expecting the input graph to have a dict output as the last element" + assert isinstance(name_node_map, dict), ( + "Expecting the input graph to have a dict output as the last element" + ) n.args = (flattened,) orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] @@ -53,12 +53,14 @@ def pattern(x, weight): relu = F.relu(conv) return relu, {"conv": conv, "relu": relu} + def target_graph(x, weight): conv = F.conv2d(x, weight) relu = F.relu(conv) relu *= 2 return relu + pattern_gm = export_for_training(pattern, example_inputs).module() target_gm = export_for_training(target_graph, example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) @@ -86,7 +88,7 @@ def __init__( ignore_literals, ) - def match(self, graph: Graph) -> list[InternalMatch]: + def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: """The returned InternalMatch will have name_node_map populated with a map from node name (str) to the target node, e.g. {"conv": target_conv_ndoe, "relu": target_relu_node} @@ -105,7 +107,7 @@ def pattern(...): return relu, {"conv": conv, "relu": relu} ``` instead """ - internal_matches = super().match(graph) + internal_matches = super().match(graph, node_name_match) for internal_match in internal_matches: for k, n in self.name_node_map.items(): internal_match.name_node_map[k] = internal_match.nodes_map[n] diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 20bef1628bfc64..b9dca84e069eda 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -124,6 +124,10 @@ def __exit__(self, *args): class TracerBase: graph: Graph record_stack_traces: bool = False + # When record_stack_traces is True, only reocrd stack traces + # with forward function names. + # This helps when we want stack trace back to model code + _record_forward_stack_traces_only: bool = False # Feature flag for mutable schema checking # Enableby default in 1.12 check_mutable_operations: bool = False @@ -204,6 +208,42 @@ def create_node( elif self.module_stack: node.meta["nn_module_stack"] = copy.copy(self.module_stack) + if self.record_stack_traces and not node.stack_trace: + from torch.fx.experimental.symbolic_shapes import uninteresting_files + + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + if self._record_forward_stack_traces_only: + user_frame_summary = [ + frame + for frame in user_frame_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + else: + first_forward = -1 + for i, frame in enumerate(user_frame_summary): + if frame.name == "forward": + user_frame_summary = user_frame_summary[i:] + first_forward = i + break + + # Not having a "forward" call in the stacktrace implies the + # stacktrace will probably be irrelevant + if first_forward == -1: + user_frame_summary = [] + + stack_trace = [ + frame + for frame in user_frame_summary + if frame.filename not in uninteresting_files() + ] + if stack_trace: + stack_trace = traceback.StackSummary.from_list(stack_trace) + node.stack_trace = "".join(stack_trace.format()).strip() + log.debug("create_node %s", node) return node @@ -245,31 +285,6 @@ def create_proxy( else: proxy = proxy_factory_fn(node) - if self.record_stack_traces and not proxy.node.stack_trace: - from torch.fx.experimental.symbolic_shapes import uninteresting_files - - user_frame_summary = CapturedTraceback.extract().summary() - if user_frame_summary: - first_forward = -1 - for i, frame in enumerate(user_frame_summary): - if frame.name == "forward": - user_frame_summary = user_frame_summary[i:] - first_forward = i - break - - # Not having a "forward" call in the stacktrace implies the - # stacktrace will probably be irrelevant - if first_forward == -1: - user_frame_summary = [] - - stack_trace = [ - frame - for frame in user_frame_summary - if frame.filename not in uninteresting_files() - ] - stack_trace = traceback.StackSummary.from_list(stack_trace) - proxy.node.stack_trace = "".join(stack_trace.format()).strip() - return proxy def _find_user_frame(self): @@ -654,9 +669,9 @@ def __torch_function__(cls, orig_method, types, args=None, kwargs=None): meta_proxy = arg break - assert ( - meta_proxy is not None - ), "No MetaProxy found in arguments, but one is expected." + assert meta_proxy is not None, ( + "No MetaProxy found in arguments, but one is expected." + ) proxy = super().__torch_function__(orig_method, types, args, kwargs) with meta_proxy.fake_mode: @@ -739,14 +754,14 @@ def impl(*args, **kwargs): return tracer.create_proxy("call_function", target, args, kwargs) impl.__name__ = method - as_magic = f'__{method.strip("_")}__' + as_magic = f"__{method.strip('_')}__" setattr(Proxy, as_magic, impl) _scope(method) def _define_reflectable(orig_method_name): - method_name = f'__r{orig_method_name.strip("_")}__' + method_name = f"__r{orig_method_name.strip('_')}__" def impl(self, rhs): target = getattr(operator, orig_method_name) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index ae6854f678870f..eebdfad0963229 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -234,6 +234,7 @@ def replace_pattern_with_filters( replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, + node_name_match: str = "", ) -> list[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -246,10 +247,17 @@ def replace_pattern_with_filters( ``replacement_callback``: A function that takes in a match and returns a Graph to be used as the replacement. This allows you to construct a replacement graph based on the match. + ``replacement_callback``: Node name to match. If not empty, it will try to match the node name. """ return _replace_pattern( - gm, pattern, replacement, match_filters, ignore_literals, replacement_callback + gm, + pattern, + replacement, + match_filters, + ignore_literals, + replacement_callback, + node_name_match, ) @@ -265,6 +273,7 @@ def _replace_pattern( replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, + node_name_match: str = "", ) -> list[ReplacedPatterns]: from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher @@ -288,7 +297,9 @@ def _replace_pattern( remove_overlapping_matches=True, ignore_literals=ignore_literals, ) - _matches: list[InternalMatch] = matcher.match(original_graph) + _matches: list[InternalMatch] = matcher.match( + original_graph, node_name_match=node_name_match + ) # Filter out matches that don't match the filter _matches = [ @@ -307,9 +318,9 @@ def _replace_pattern( elif callable(replacement): common_replacement_graph = symbolic_trace(replacement).graph else: - assert ( - replacement_callback is not None - ), "Must provide either a replacement GraphModule or a replacement callback" + assert replacement_callback is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change @@ -322,9 +333,9 @@ def _replace_pattern( match, original_graph, pattern_graph ) else: - assert ( - common_replacement_graph is not None - ), "Must provide either a replacement GraphModule or a replacement callback" + assert common_replacement_graph is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) replacement_graph = common_replacement_graph replacement_placeholders = [ n for n in replacement_graph.nodes if n.op == "placeholder" diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 83bc39ae550412..49b2784e1df129 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -47,3 +47,6 @@ loadu maximum minimum size + +# torch/headeronly/macros/Export.h +C10_API diff --git a/torch/headeronly/BUILD.bazel b/torch/headeronly/BUILD.bazel new file mode 100644 index 00000000000000..f4a27fac1f7f64 --- /dev/null +++ b/torch/headeronly/BUILD.bazel @@ -0,0 +1,9 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "torch_headeronly", + hdrs = glob([ + "**/*.h" + ]), + visibility = ["//visibility:public"], +) diff --git a/torch/headeronly/macros/Export.h b/torch/headeronly/macros/Export.h new file mode 100644 index 00000000000000..183aeab563445c --- /dev/null +++ b/torch/headeronly/macros/Export.h @@ -0,0 +1,87 @@ +#pragma once + +/* Header file to define the common scaffolding for exported symbols. + * + * Export is by itself a quite tricky situation to deal with, and if you are + * hitting this file, make sure you start with the background here: + * - Linux: https://gcc.gnu.org/wiki/Visibility + * - Windows: + * https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017 + * + * Do NOT include this file directly. Instead, use c10/macros/Macros.h + */ + +// You do not need to edit this part of file unless you are changing the core +// pytorch export abstractions. +// +// This part defines the C10 core export and import macros. This is controlled +// by whether we are building shared libraries or not, which is determined +// during build time and codified in c10/core/cmake_macros.h. +// When the library is built as a shared lib, EXPORT and IMPORT will contain +// visibility attributes. If it is being built as a static lib, then EXPORT +// and IMPORT basically have no effect. + +// As a rule of thumb, you should almost NEVER mix static and shared builds for +// libraries that depend on c10. AKA, if c10 is built as a static library, we +// recommend everything dependent on c10 to be built statically. If c10 is built +// as a shared library, everything dependent on it should be built as shared. In +// the PyTorch project, all native libraries shall use the macro +// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static +// libraries. + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifdef _WIN32 +#define C10_HIDDEN +#if defined(C10_BUILD_SHARED_LIBS) +#define C10_EXPORT __declspec(dllexport) +#define C10_IMPORT __declspec(dllimport) +#else +#define C10_EXPORT +#define C10_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_EXPORT __attribute__((__visibility__("default"))) +#define C10_HIDDEN __attribute__((__visibility__("hidden"))) +#else // defined(__GNUC__) +#define C10_EXPORT +#define C10_HIDDEN +#endif // defined(__GNUC__) +#define C10_IMPORT C10_EXPORT +#endif // _WIN32 + +#ifdef NO_EXPORT +#undef C10_EXPORT +#define C10_EXPORT +#endif + +// Definition of an adaptive XX_API macro, that depends on whether you are +// building the library itself or not, routes to XX_EXPORT and XX_IMPORT. +// Basically, you will need to do this for each shared library that you are +// building, and the instruction is as follows: assuming that you are building +// a library called libawesome.so. You should: +// (1) for your cmake target (usually done by "add_library(awesome, ...)"), +// define a macro called AWESOME_BUILD_MAIN_LIB using +// target_compile_options. +// (2) define the AWESOME_API macro similar to the one below. +// And in the source file of your awesome library, use AWESOME_API to +// annotate public symbols. + +// Here, for the C10 library, we will define the macro C10_API for both import +// and export. + +// This one is being used by libc10.so +#ifdef C10_BUILD_MAIN_LIB +#define C10_API C10_EXPORT +#else +#define C10_API C10_IMPORT +#endif diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index fb8ac26471a958..2aa2fae3fde51f 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -18,7 +18,15 @@ _builtin_table: Optional[dict[int, str]] = None -_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950 +_modules_containing_builtins = ( + torch, + torch._C._nn, + torch._C._fft, # type: ignore[attr-defined] + torch._C._linalg, # type: ignore[attr-defined] + torch._C._nested, # type: ignore[attr-defined] + torch._C._sparse, # type: ignore[attr-defined] + torch._C._special, # type: ignore[attr-defined] +) _builtin_ops = [ # Pairs of (function, op_name) @@ -94,7 +102,10 @@ (torch.autograd.grad, "aten::grad"), (torch.autograd.backward, "aten::backward"), (torch._C._infer_size, "aten::_infer_size"), - (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined] + ( + torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined] + "aten::_no_grad_embedding_renorm_", + ), (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"), (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"), (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"), diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index 795f9da8e073a1..3a4b4ceff2cf3d 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -4,9 +4,9 @@ def _register_decomposition(op: OpOverload, graph: torch._C.Graph): - assert not isinstance( - op, OpOverloadPacket - ), f"Must pass specific op overload, not overload packet, found {op}" + assert not isinstance(op, OpOverloadPacket), ( + f"Must pass specific op overload, not overload packet, found {op}" + ) assert isinstance(op, OpOverload) torch._C._jit_register_decomposition_for_schema(op._schema, graph) diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index ba37fe5f0cac26..000ec7d0ec7963 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -23,13 +23,13 @@ def check_decomposition_has_type_annotations(f): inspect_empty = inspect._empty # type: ignore[attr-defined] sig = inspect.signature(f) for param in sig.parameters.values(): - assert ( - param.annotation != inspect_empty - ), f"No signature on param {param.name} for function {f.name}" + assert param.annotation != inspect_empty, ( + f"No signature on param {param.name} for function {f.name}" + ) - assert ( - sig.return_annotation != inspect_empty - ), f"No return annotation for function {f.name}" + assert sig.return_annotation != inspect_empty, ( + f"No return annotation for function {f.name}" + ) def signatures_match(decomposition_sig, torch_op_sig): @@ -40,7 +40,7 @@ def signatures_match(decomposition_sig, torch_op_sig): return False for decomp_param, op_param in zip(decomp_params.values(), op_params.values()): - # can't check full equality yet because not all fields are correcly deduced + # can't check full equality yet because not all fields are correctly deduced # in the torch_op_sig - like default value # can't check 'kind' bc # kwarg-only values with defaults not yet supported in TS @@ -75,9 +75,9 @@ def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]: assert isinstance(aten_op, torch._ops.OpOverload) # Need unique name for jit function serialization - assert ( - f.__name__ not in function_name_set - ), f"Duplicated function name {f.__name__}" + assert f.__name__ not in function_name_set, ( + f"Duplicated function name {f.__name__}" + ) function_name_set.add(f.__name__) scripted_func = torch.jit.script(f) diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index 2d2db0a4f14246..b61a2dd6207d17 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -150,7 +150,7 @@ def run_frozen_optimizations( None Note: - In rare occassions, this can result in slower execution. + In rare occasions, this can result in slower execution. Example (Freezing a module with Conv->Batchnorm) .. code-block:: python diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 7a324fda8af880..84ea4d5c3f6b0e 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -26,7 +26,7 @@ _IS_MONKEYTYPE_INSTALLED = False -# Checks whether a class is defind in `torch.*` modules +# Checks whether a class is defined in `torch.*` modules def is_torch_native_class(cls): if not hasattr(cls, "__module__"): return False @@ -130,7 +130,7 @@ def consolidate_types(self, qualified_name: str) -> dict: types = list(types) type_length = len(types) if type_length == 2 and type(None) in types: - # TODO: To remove this check once Union suppport in TorchScript lands. + # TODO: To remove this check once Union support in TorchScript lands. all_args[arg] = get_optional_of_element_type(types) elif type_length > 1: all_args[arg] = "Any" diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d62f039263c202..e89bcc47dff6bb 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -431,7 +431,7 @@ def __init__(self) -> None: self.methods_compiled = set() def get_or_create_concrete_type(self, nn_module): - """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible.""" + """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are reused if possible.""" concrete_type_builder = infer_concrete_type_builder(nn_module) nn_module_type = type(nn_module) @@ -502,7 +502,7 @@ def get_module_concrete_type(nn_module, share_types=True): # Look into the store of cached JIT types concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) else: - # Get a concrete type directly, without trying to re-use an existing JIT + # Get a concrete type directly, without trying to reuse an existing JIT # type from the type store. concrete_type_builder = infer_concrete_type_builder(nn_module, share_types) concrete_type_builder.set_poisoned() @@ -588,9 +588,9 @@ def init_fn(script_module): # recursively scripting them. for name, sub_concrete_type in concrete_type.get_modules(): orig_value = getattr(nn_module, name) - assert isinstance( - orig_value, Module - ), f"Expected Module but got {type(orig_value)}" + assert isinstance(orig_value, Module), ( + f"Expected Module but got {type(orig_value)}" + ) module_type = sub_concrete_type.jit_type if isinstance(module_type, torch._C.InterfaceType): # use the interface inference rule to compile the module diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 5777b047e74ef0..79442f57d30631 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -318,10 +318,10 @@ def make_stubs(module): else: return infer_methods_to_compile(module) - self.__dict__[ - "_actual_script_module" - ] = torch.jit._recursive.create_script_module( - self, make_stubs, share_types=not added_methods_in_init + self.__dict__["_actual_script_module"] = ( + torch.jit._recursive.create_script_module( + self, make_stubs, share_types=not added_methods_in_init + ) ) # Delete the Python attributes that now shadow the ScriptModule diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index aa0dc2b82d541d..f2a6f4a8417631 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -86,7 +86,7 @@ def broadcast_inplace(a: list[int], b: list[int]): dimsB = len(b) if dimsB > dimsA: raise AssertionError( - f"The dims of tensor b ({dimsB}) must be less than or equal tothe dims of tensor a ({dimsA}) " + f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA}) " ) for dimA in range(dimsA): dimB = dimsB - dimsA + dimA @@ -280,15 +280,15 @@ def max_pool2d( dilation: list[int], ceil_mode: bool, ): - assert ( - len(kernel_size) == 1 or len(kernel_size) == 2 - ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + assert len(kernel_size) == 1 or len(kernel_size) == 2, ( + "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + ) kH = kernel_size[0] kW = kH if len(kernel_size) == 1 else kernel_size[1] - assert ( - len(stride) == 0 or len(stride) == 1 or len(stride) == 2 - ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, ( + "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + ) dH = kH if len(stride) == 0 else stride[0] if len(stride) == 0: dW = kW @@ -297,15 +297,15 @@ def max_pool2d( else: dW = stride[1] - assert ( - len(padding) == 1 or len(padding) == 2 - ), "max_pool2d: padding must either be a single int, or a tuple of two ints" + assert len(padding) == 1 or len(padding) == 2, ( + "max_pool2d: padding must either be a single int, or a tuple of two ints" + ) padH = padding[0] padW = padH if len(padding) == 1 else padding[1] - assert ( - len(dilation) == 1 or len(dilation) == 2 - ), "max_pool2d: dilation must be either a single int, or a tuple of two ints" + assert len(dilation) == 1 or len(dilation) == 2, ( + "max_pool2d: dilation must be either a single int, or a tuple of two ints" + ) dilationH = dilation[0] dilationW = dilationH if len(dilation) == 1 else dilation[1] @@ -367,17 +367,17 @@ def upsample_nearest2d( assert 0, "Either output_size or scale_factors must be presented" if output_size is not None: - assert ( - scale_factors is None - ), "Must specify exactly one of output_size and scale_factors" + assert scale_factors is None, ( + "Must specify exactly one of output_size and scale_factors" + ) assert len(output_size) == 2 out.append(output_size[0]) out.append(output_size[1]) if scale_factors is not None: - assert ( - output_size is None - ), "Must specify exactly one of output_size and scale_factors" + assert output_size is None, ( + "Must specify exactly one of output_size and scale_factors" + ) assert len(scale_factors) == 2 out.append(int(input[2] * scale_factors[0])) out.append(int(input[3] * scale_factors[1])) @@ -540,9 +540,9 @@ def check_cat_shape_except_dim( assert first_dims == second_dims, "Tensors must have same number of dimensions" for dim in range(0, first_dims): if dim != dimension: - assert ( - first[dim] == second[dim] - ), "Sizes of tensors must match except in dimension" + assert first[dim] == second[dim], ( + "Sizes of tensors must match except in dimension" + ) def cat(tensors: list[list[int]], dim: int): @@ -1088,9 +1088,9 @@ def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]: if len(self) == 0: result: list[int] = [] else: - assert ( - k <= self[dim] - ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + assert k <= self[dim], ( + f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + ) result = _copy(self) result[dim] = k return result, result @@ -1172,7 +1172,7 @@ def cross_entropy_loss( adding ops. There are currently cases in the test case where this is being called in the SSA opinfo tests with with unexpected values (eg list of two ints, see the first -opinfo test). The behavoir of index is significantly dependent on the inputs. +opinfo test). The behavior of index is significantly dependent on the inputs. This could be an error with how we are matching up shape functions, or that this function needs to just implement everything. @@ -1452,7 +1452,7 @@ def add_bounded_compute_mapping( # add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor) # TODO: migrate over all of symbolic_shape_registry_util.cpp -# These are duplicated here so that the functions will be serialiazed +# These are duplicated here so that the functions will be serialized add_shape_compute_mapping( "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", broadcast_three, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index eae30f415e9b0f..5084d7c9228371 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -993,11 +993,7 @@ def forward(self, x): stacklevel=2, ) - from torch._utils_internal import ( - check_if_torch_exportable, - log_torch_jit_trace_exportability, - log_torchscript_usage, - ) + from torch._utils_internal import log_torchscript_usage traced_func = _trace_impl( func, @@ -1014,103 +1010,6 @@ def forward(self, x): _store_inputs, ) log_torchscript_usage("trace", model_id=_get_model_id(traced_func)) - - if check_if_torch_exportable(): - from torch._export.converter import TS2EPConverter - from torch.export._trace import ( - _convert_ts_to_export_experimental, - _process_jit_trace_inputs_for_export, - ) - - traced_func_for_export = _trace_impl( - func, - example_inputs=example_inputs, - optimize=optimize, - check_trace=False, - check_inputs=check_inputs, - check_tolerance=check_tolerance, - strict=strict, - _force_outplace=_force_outplace, - _module_class=_module_class, - _compilation_unit=_compilation_unit, - example_kwarg_inputs=example_kwarg_inputs, - _store_inputs=_store_inputs, - ) - - export_args, _ = _process_jit_trace_inputs_for_export( - example_inputs, example_kwarg_inputs - ) - - def _log_exportability(func_to_export, export_func, export_args, export_type): - try: - traced_result = func_to_export(*export_args) - except Exception as e: - _ = e - log_torch_jit_trace_exportability( - "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" - ) - return - - try: - ep_module = export_func(func_to_export, export_args) - except Exception as e: - log_torch_jit_trace_exportability( - "trace", - str(export_type), - str(_ExportOutcome.FAILED_TO_EXPORT), - str(e), - ) - return - - try: - export = ep_module(*export_args) - except Exception as e: - log_torch_jit_trace_exportability( - "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e) - ) - return - - if not analyze_ts_result_with_export_result(export, traced_result): - log_torch_jit_trace_exportability( - "trace", - str(export_type), - str(_ExportOutcome.ACCURACY_ERROR), - "accuracy error", - ) - return - - log_torch_jit_trace_exportability( - "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" - ) - - def _direct_export_and_lower(func, export_args): - return torch.export.export(func, export_args, strict=False).module() - - def _convert_ts_to_export_source_to_source(func, export_args): - return TS2EPConverter(func, export_args).convert().module() - - # torch.jit.trace is noop when the original module is torch.jit.ScriptModule - if not isinstance(traced_func_for_export, torch.jit.ScriptModule): - _log_exportability( - traced_func_for_export, - _direct_export_and_lower, - export_args, - _ExportType.DIRECT_EXPORT, - ) - - _log_exportability( - traced_func_for_export, - _convert_ts_to_export_experimental, - export_args, - _ExportType.TRACE_AND_EXPORT, - ) - _log_exportability( - traced_func_for_export, - _convert_ts_to_export_source_to_source, - export_args, - _ExportType.SOURCE_TO_SOURCE, - ) - return traced_func @@ -1205,7 +1104,10 @@ def weighted_kernel_sum(self, weight): # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods - inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight} + inputs = { + "forward": example_forward_input, + "weighted_kernel_sum": example_weight, + } module = torch.jit.trace_module(n, inputs) """ diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 76682e75229918..f6f4d99918faf1 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -438,7 +438,11 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=No is_method = self_name is not None if type_line is not None: type_comment_decl = torch._C.parse_type_comment(type_line) - decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method) + decl = torch._C.merge_type_from_type_comment( + decl, # type: ignore[arg-type] + type_comment_decl, + is_method, # type: ignore[assignment] + ) return Def(Ident(r, def_name), decl, build_stmts(ctx, body)) @@ -1055,12 +1059,12 @@ def build_Compare(ctx, expr): in_expr = BinOp("in", lhs, rhs) cmp_expr = UnaryOp(r, "not", in_expr) else: - cmp_expr = BinOp(op_token, lhs, rhs) + cmp_expr = BinOp(op_token, lhs, rhs) # type: ignore[assignment] if result is None: result = cmp_expr else: - result = BinOp("and", result, cmp_expr) + result = BinOp("and", result, cmp_expr) # type: ignore[assignment] return result @staticmethod @@ -1135,7 +1139,7 @@ def build_ExtSlice(ctx, base, extslice): return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)]) elif sub_type is ast.ExtSlice: return Subscript(base, build_ExtSlice(ctx, base, expr.slice)) - else: # In Python3.9 array indicies are not wrapped in ast.Index + else: # In Python3.9 array indices are not wrapped in ast.Index if sub_type is ast.Tuple: # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] indices = [] diff --git a/torch/lib/libshm/CMakeLists.txt b/torch/lib/libshm/CMakeLists.txt index 8a7329ddab77f0..c3cd26fea7bf30 100644 --- a/torch/lib/libshm/CMakeLists.txt +++ b/torch/lib/libshm/CMakeLists.txt @@ -1,5 +1,5 @@ project(libshm C CXX) -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) set(TORCH_ROOT ${CMAKE_CURRENT_LIST_DIR}/../../../) diff --git a/torch/lib/libshm/core.cpp b/torch/lib/libshm/core.cpp index 1c056b72360a66..1a49c278824d31 100644 --- a/torch/lib/libshm/core.cpp +++ b/torch/lib/libshm/core.cpp @@ -27,8 +27,7 @@ static void start_manager() { std::array pipe_ends; SYSCHECK_ERR_RETURN_NEG1(pipe(pipe_ends.data())); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - pid_t pid; + pid_t pid = -1; SYSCHECK_ERR_RETURN_NEG1(pid = fork()); if (!pid) { SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0])); @@ -99,8 +98,7 @@ THManagedMapAllocatorInit::THManagedMapAllocatorInit( : manager_handle_(manager_handle ? manager_handle : "") { // TODO: unlock GIL when contacting the manager try { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ClientSocket* socket; + ClientSocket* socket = nullptr; if (!manager_handle_.empty()) { socket = &get_manager_socket(manager_handle_); } else { diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 4d40718bcd063f..355ad00d491aac 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1369,21 +1369,21 @@ :attr:`ord` defines the norm that is computed. The following norms are supported: -====================== ========================= ======================================================== -:attr:`ord` norm for matrices norm for vectors -====================== ========================= ======================================================== -`None` (default) Frobenius norm `2`-norm (see below) -`'fro'` Frobenius norm -- not supported -- -`'nuc'` nuclear norm -- not supported -- -`inf` `max(sum(abs(x), dim=1))` `max(abs(x))` -`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` -`0` -- not supported -- `sum(x != 0)` -`1` `max(sum(abs(x), dim=0))` as below -`-1` `min(sum(abs(x), dim=0))` as below -`2` largest singular value as below -`-2` smallest singular value as below -other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` -====================== ========================= ======================================================== +====================== ========================== ====================================================== +:attr:`ord` norm for matrices norm for vectors +====================== ========================== ====================================================== +`None` (default) Frobenius norm `2`-norm (see below) +`'fro'` Frobenius norm -- not supported -- +`'nuc'` nuclear norm -- not supported -- +`inf` `max(sum(abs(x), dim=1))` `max(abs(x))` +`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` +`0` -- not supported -- `sum(x != 0)` +`1` `max(sum(abs(x), dim=0))` as below +`-1` `min(sum(abs(x), dim=0))` as below +`2` largest `singular value`_ as below +`-2` smallest `singular value`_ as below +other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` +====================== ========================== ====================================================== where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. @@ -1483,6 +1483,9 @@ tensor([ 3.7417, 11.2250]) >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) (tensor(3.7417), tensor(11.2250)) + +.. _singular value: + https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD """, ) diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 7e2f128560c32d..fb802eba1aa8b9 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -309,14 +309,14 @@ def _generate_docstring(func): operation_args, operation_kwargs = args_and_kwargs[func.__name__] arg_declarations = [ "\n ".join( - argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() + argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines() ) for a in operation_args ] kwarg_declarations = [ "\n ".join( argument_declarations.get( - a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' + a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD." ) .format(default=a.split("=", 1)[1]) .splitlines() @@ -745,9 +745,9 @@ def _sparse_csr_segment_reduction_helper( ) -> Tensor: # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # FIXME: when dense dimensions are implemented for CSR tensors - assert ( - keepdim - ), "reduction operations on CSR tensors with keepdim=False is unsupported" + assert keepdim, ( + "reduction operations on CSR tensors with keepdim=False is unsupported" + ) reduce = op.__name__ valid_reductions = ["sum", "prod", "mean", "amax", "amin"] if reduce not in valid_reductions: @@ -781,9 +781,9 @@ def _sparse_csr_segment_reduction_helper( ) new_shape = [1, mask_input.size(1)] else: - assert ( - dims[0] == 1 - ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + assert dims[0] == 1, ( + "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + ) # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 new_crow_indices = torch.cat( @@ -794,7 +794,7 @@ def _sparse_csr_segment_reduction_helper( 0, ) new_nnz = new_crow_indices[-1] - new_col_indices = col_indices.new_zeros(new_nnz) + new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload] new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined] new_shape = [mask_input.size(0), 1] else: @@ -1598,9 +1598,9 @@ def _std_var( mask: Optional[Tensor], take_sqrt: Optional[bool], ) -> Tensor: - assert ( - unbiased is None or correction_opt is None - ), "Only one of unbiased and correction may be given" + assert unbiased is None or correction_opt is None, ( + "Only one of unbiased and correction may be given" + ) correction = 1.0 if unbiased is not None: correction = 1.0 if unbiased else 0.0 @@ -1636,7 +1636,11 @@ def _std_var( total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) else: total = sum( - x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined] + x * x.conj(), + dim, + keepdim=keepdim, + dtype=compute_dtype, + mask=inmask, # type: ignore[possibly-undefined] ) if not keepdim: count = count.reshape(total.shape) diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 5bdc31391b7c75..2e3608b3e6d3da 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -25,7 +25,7 @@ def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]: >>> # xdoctest: +SKIP >>> from torch.masked import MaskedTensor - >>> data = torch.arange(6).reshape(2,3) + >>> data = torch.arange(6).reshape(2, 3) >>> mask = torch.tensor([[True, False, False], [True, True, False]]) >>> mt = MaskedTensor(data, mask) >>> is_masked_tensor(mt) @@ -304,7 +304,7 @@ def unary(cls, fn, data, mask): return MaskedTensor(fn(data), mask) @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): + def __torch_dispatch__(cls, func, types, args, kwargs): # type: ignore[override] func = func.overloadpacket from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE @@ -355,5 +355,5 @@ def is_sparse_csr(self): # type: ignore[override] # Update later to support more sparse layouts @property - def is_sparse(self): + def is_sparse(self): # type: ignore[override] return self.is_sparse_coo() or self.is_sparse_csr() diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index b0a62c182578f7..cdbf6b16ac443e 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -5,6 +5,7 @@ performance can be achieved, by running work on the metal GPU(s). See https://developer.apple.com/documentation/metalperformanceshaders for more details. """ + from typing import Union import torch diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index 6e194bb63b2879..eebeea9a02a49a 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -76,13 +76,13 @@ def is_metal_capture_enabled() -> bool: def is_capturing_metal() -> bool: - """Cheks if metal capture is in progress""" + """Checks if metal capture is in progress""" return torch._C._mps_isCapturing() # type: ignore[attr-defined] @contextlib.contextmanager def metal_capture(fname: str): - """Conext manager that enables capturing of Metal calls into gputrace""" + """Context manager that enables capturing of Metal calls into gputrace""" try: torch._C._mps_startCapture(fname) # type: ignore[attr-defined] yield diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index a9296539d58eef..4761969fc286b4 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -198,7 +198,7 @@ def snapshot() -> dict[str, Any]: def attach_out_of_memory_observer( - observer: Callable[[int, int, int, int], None] + observer: Callable[[int, int, int, int], None], ) -> None: r"""Attach an out-of-memory observer to MTIA memory allocator""" torch._C._mtia_attachOutOfMemoryObserver(observer) diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 745c180d8c415c..aa1176d69c60c8 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -14,6 +14,7 @@ Because of the similarity of APIs we do not document most of this package contents, and we recommend referring to very good docs of the original module. """ + import multiprocessing import sys @@ -98,3 +99,24 @@ def _get_thread_name() -> str: init_reductions() + +# Leak ResourceTracker at exit for Python-3.12 on MacOS +# See https://github.com/pytorch/pytorch/issues/153050 and +# https://github.com/python/cpython/issues/88887 for more details +from multiprocessing.resource_tracker import ResourceTracker as _RT + + +if ( + sys.platform == "darwin" + and sys.version_info >= (3, 12, 2) + and hasattr(_RT, "__del__") +): + import atexit + + def _leak_RT_at_exit(): + def _noop(x): + pass + + _RT.__del__ = _noop # type: ignore[attr-defined] + + atexit.register(_leak_RT_at_exit) diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 4e4539396f8323..cbd6eee571f132 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -290,7 +290,7 @@ def reduce_tensor(tensor): # 0xE000 -> --------CUDA allocation----- # # To send tensor1, the following info are required from sender to receiver for - # storage recontruction. + # storage reconstruction. # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). # basePtr may not be exactly 0xA000 since it's a different process. # 2. offset(0xA100) of storage1 in the CUDA allocation. diff --git a/torch/nativert/OVERVIEW.md b/torch/nativert/OVERVIEW.md index 2ce87a6aa3c606..bfe97c9aefc75b 100644 --- a/torch/nativert/OVERVIEW.md +++ b/torch/nativert/OVERVIEW.md @@ -244,7 +244,7 @@ For CPU kernels, it is extremely inefficient to go through the dispatcher. For one, the dispatcher doesn't deal with kernel out-variants. > **_NOTE:_** an out-variant of a kernel is one that takes the outputs as -> mutable references. this has a few benefits... namely, it allows us to re-use +> mutable references. this has a few benefits... namely, it allows us to reuse > the storage/manage from the previous execution. In addition, the dispatcher acts as a stack machine. You push the inputs to the @@ -281,7 +281,7 @@ RuntimeConfigs { ### Constant Folding Constant folding is the process of finding all of the constant-evaluable -subgraphs, evaluating them at startup, and then storing thier results as +subgraphs, evaluating them at startup, and then storing their results as constants as opposed to re-evaluting them every time. To enable constant folding, you can set the following configurations. @@ -311,7 +311,7 @@ torch.ops.quantized.linear_dynamic_fp16.default which should give a ~2x speedup over the fp32 variant with minimal effect on correctness. -The linear_prepack_fp16 op will be constant-folded, so it's imperitive that +The linear_prepack_fp16 op will be constant-folded, so it's imperative that these two features are used together. To enable this feature, use the following configurations. @@ -327,7 +327,7 @@ RuntimeConfigs { > :warning: **This is an experimental feature** -The main upside of memory planning comes from the efficient re-use of tensor +The main upside of memory planning comes from the efficient reuse of tensor buffers, which is extremely important in memory-bound services. That is, if two tensors don’t have an overlapping lifetime during execution, and the first tensor is larger than the second, then the second tensor can share the same diff --git a/torch/nativert/common/FileUtil.cpp b/torch/nativert/common/FileUtil.cpp index 10c03638740f86..c0887b52779220 100644 --- a/torch/nativert/common/FileUtil.cpp +++ b/torch/nativert/common/FileUtil.cpp @@ -76,7 +76,7 @@ int filterCloseReturn(int r) { return r; } -// The following wrapX() funcions are private functions for wrapping file-io +// The following wrapX() functions are private functions for wrapping file-io // against interrupt and partial op completions. // Wrap call to f(args) in loop to retry on EINTR diff --git a/torch/nativert/detail/ITree.cpp b/torch/nativert/detail/ITree.cpp new file mode 100644 index 00000000000000..123ee4498d06fb --- /dev/null +++ b/torch/nativert/detail/ITree.cpp @@ -0,0 +1,485 @@ +#include +#include + +#include +#include + +#include +#include +#include + +namespace torch::nativert::detail { + +namespace { +inline constexpr int kDefaultTreeSpecSerializationProtocol = 1; + +c10::IValue dynamicToIValue(const nlohmann::json& obj) { + if (obj.is_string()) { + return obj.get(); + } else if (obj.is_number_integer()) { + return obj.get(); + } else { + TORCH_CHECK(false, "Unsupported dynamic type: ", obj); + } +} + +void itreeFlatten( + const c10::IValue& nested, + const ITreeSpec& spec, + std::vector& ivalues) { + if (spec.isIValue()) { + ivalues.push_back(nested); + return; + } + auto flattenFn = spec.nodeDefCache().flattenFn; + flattenFn(nested, spec, ivalues); +} + +class PytreeNodeRegistry { + public: + PytreeNodeRegistry() { + // Add some law of physics here. + registerNode( + "builtins.tuple", + NodeDef{ + [](const c10::IValue& nested, + const ITreeSpec& spec, + std::vector& ivalues) { + const auto& tuple = nested.toTupleRef().elements(); + TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + for (size_t i = 0; i < tuple.size(); i++) { + itreeFlatten(tuple[i], spec.children(i), ivalues); + } + }, + [](std::vector flats, + const nlohmann::json& obj) -> c10::IValue { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null()); + return c10::ivalue::Tuple::create(std::move(flats)); + }, + [](ITreeMapNoReturnFn fn, + const c10::IValue& nested, + const ITreeSpec& spec) { + const auto& tuple = nested.toTupleRef().elements(); + TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + for (size_t i = 0; i < tuple.size(); i++) { + ivalueApply(fn, tuple[i], spec.children(i)); + } + }}); + const auto& tupleNodeDef = getNodeDef("builtins.tuple"); + registerNode( + "collections.namedtuple", + NodeDef{ + tupleNodeDef.flattenFn, + [](std::vector flats, + const nlohmann::json& obj) -> c10::IValue { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null()); + return c10::ivalue::Tuple::create(std::move(flats)); + }, + tupleNodeDef.ivalueApplyFn, + [](std::string_view context) { return nlohmann::json{context}; }}); + registerNode( + "builtins.list", + NodeDef{ + [](const c10::IValue& nested, + const ITreeSpec& spec, + std::vector& ivalues) { + auto list = nested.toListRef(); + for (size_t i = 0; i < list.size(); i++) { + itreeFlatten(list[i], spec.children(i), ivalues); + } + }, + [](std::vector flats, + const nlohmann::json& obj) -> c10::IValue { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null()); + c10::List list(c10::AnyType::get()); + list.reserve(flats.size()); + for (auto& flat : flats) { + list.push_back(std::move(flat)); + } + return list; + }, + [](ITreeMapNoReturnFn fn, + const c10::IValue& nested, + const ITreeSpec& spec) { + auto list = nested.toListRef(); + for (size_t i = 0; i < list.size(); i++) { + ivalueApply(fn, list[i], spec.children(i)); + } + }}); + registerNode( + "torch.fx.immutable_collections.immutable_list", + getNodeDef("builtins.list")); + registerNode( + "builtins.dict", + NodeDef{ + [](const c10::IValue& nested, + const ITreeSpec& spec, + std::vector& ivalues) { + auto dict = nested.toGenericDict(); + const auto& contextKeys = spec.contextKeys(); + // allow the dict size less than the spec, missing key will be + // filled with empty tensor + TORCH_CHECK_LE(dict.size(), contextKeys.size()); + size_t i = 0; + for (const auto& key : contextKeys) { + auto it = dict.find(key); + + if (it != dict.end()) { + itreeFlatten(it->value(), spec.children(i), ivalues); + } else { + // when we have a dict with missing keys, we fill the missing + // ivalues with an empty tensor which is required for + // validation + for (size_t j = 0; j < spec.children(i).numIValues(); ++j) { + at::Tensor empty_tensor; + ivalues.emplace_back(std::move(empty_tensor)); + } + } + i++; + } + }, + [](std::vector flats, + const nlohmann::json& obj) -> c10::IValue { + c10::Dict dict( + c10::AnyType::get(), c10::AnyType::get()); + TORCH_CHECK(obj.is_array()); + TORCH_CHECK_EQ(obj.size(), flats.size()); + dict.reserve(flats.size()); + for (size_t i = 0; i < flats.size(); i++) { + dict.insert(dynamicToIValue(obj[i]), std::move(flats[i])); + } + return dict; + }, + [](ITreeMapNoReturnFn fn, + const c10::IValue& nested, + const ITreeSpec& spec) { + auto dict = nested.toGenericDict(); + const auto& contextKeys = spec.contextKeys(); + + size_t i = 0; + for (const auto& key : contextKeys) { + if (spec.children(i).isUsed()) { + auto it = dict.find(key); + if (it != dict.end()) { + ivalueApply(fn, it->value(), spec.children(i)); + } else { + TORCH_CHECK(false, "input arg is missing key ", key); + } + } + i++; + } + }}); + registerNode( + "torch.fx.immutable_collections.immutable_dict", + getNodeDef("builtins.dict")); + } + bool hasNodeDef(std::string_view typeName) const { + return registry_.find(std::string{typeName}) != registry_.end(); + } + const NodeDef& getNodeDef(std::string_view typeName) const { + return registry_.at(std::string{typeName}); + } + void registerNode(std::string_view typeName, NodeDef nodeDef) { + TORCH_CHECK(!hasNodeDef(typeName)); + registry_.emplace(typeName, nodeDef); + } + + private: + std::unordered_map registry_; +}; + +c10::Synchronized& getPytreeNodeRegistry() { + static auto* registry = new c10::Synchronized(); + return *registry; +} + +ITreeSpec makeITreeSpec( + const nlohmann::json& obj, + const std::vector& values, + int start) { + TORCH_CHECK(obj.is_object()); + TORCH_CHECK(obj.find("type") != obj.end()); + if (obj["type"].is_null()) { + TORCH_CHECK_EQ(obj["children_spec"].size(), 0); + TORCH_CHECK(obj["context"].is_null()); + + const Value* value = values[start]; + if (value) { + bool isUsed = !value->users().empty(); + return ITreeSpec(value, isUsed); + } else { + return ITreeSpec(value, false); + } + } + const auto& name = obj["type"].get(); + NodeDef nodeDefCache; + getPytreeNodeRegistry().withLock([&](auto& registry) { + TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name); + nodeDefCache = registry.getNodeDef(name); + }); + auto context = nodeDefCache.contextLoadFn(obj["context"].get()); + const auto& childrenSpec = obj["children_spec"]; + TORCH_CHECK(childrenSpec.is_array()); + std::vector children; + int offset = 0; + for (const auto& child : childrenSpec) { + children.push_back(makeITreeSpec(child, values, start + offset)); + // NOLINTNEXTLINE(*-narrowing-conversions) + offset += children.back().numIValues(); + } + + return ITreeSpec(name, context, std::move(children), nodeDefCache); +} + +} // namespace + +void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) { + getPytreeNodeRegistry().withLock([&](auto& registry) { + registry.registerNode(typeName, std::move(nodeDef)); + }); +} + +ITreeSpec itreeSpecLoads( + std::string_view json, + const std::vector& values) { + const auto obj = nlohmann::json::parse(json); + TORCH_CHECK(obj.is_array()); + TORCH_CHECK_EQ(obj.size(), 2); + TORCH_CHECK_EQ(obj[0].get(), kDefaultTreeSpecSerializationProtocol); + auto result = makeITreeSpec(obj[1], values, 0); + + TORCH_CHECK_EQ(result.numIValues(), values.size()); + return result; +} + +c10::IValue itreeUnflatten( + std::vector ivalues, + const ITreeSpec& spec) { + RECORD_USER_SCOPE("nativert::itreeUnflatten"); + TORCH_CHECK_EQ(ivalues.size(), spec.numIValues()); + if (spec.isIValue()) { + return std::move(ivalues[0]); + } + auto unflattenFn = spec.nodeDefCache().unflattenFn; + if (spec.allIValues()) { + return unflattenFn(std::move(ivalues), spec.context()); + } + size_t start = 0; + std::vector childrenPytrees; + for (const auto& child : spec.children()) { + if (child.isIValue()) { + childrenPytrees.push_back(std::move(ivalues[start])); + start++; + continue; + } + size_t numIValues = child.numIValues(); + std::vector slice( + // NOLINTNEXTLINE(*-narrowing-conversions) + std::make_move_iterator(ivalues.begin() + start), + // NOLINTNEXTLINE(*-narrowing-conversions) + std::make_move_iterator(ivalues.begin() + start + numIValues)); + childrenPytrees.push_back(itreeUnflatten(std::move(slice), child)); + start += numIValues; + } + return unflattenFn(std::move(childrenPytrees), spec.context()); +} + +std::vector itreeFlatten( + const c10::IValue& nested, + const ITreeSpec& spec) { + std::vector ivalues; + ivalues.reserve(spec.numIValues()); + itreeFlatten(nested, spec, ivalues); + return ivalues; +} + +std::vector itreeFlattenFromArgs( + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec) { + RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs"); + TORCH_CHECK(!spec.isIValue()); + TORCH_CHECK_EQ(spec.children().size(), 2); + + std::vector ivalues; + ivalues.reserve(spec.numIValues()); + const auto& specArgs = spec.children(0); + TORCH_CHECK(!specArgs.isIValue()); + TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + for (size_t i = 0; i < args.size(); i++) { + itreeFlatten(args[i], specArgs.children(i), ivalues); + } + + const auto& specKwargs = spec.children(1); + TORCH_CHECK(!specKwargs.isIValue()); + TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size()); + for (size_t i = 0; i < specKwargs.context().size(); i++) { + itreeFlatten( + kwargs.at(specKwargs.context()[i].get_ref()), + specKwargs.children(i), + ivalues); + } + return ivalues; +} + +void ivalueApplyFromArgs( + ITreeMapNoReturnFn fn, + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec) { + RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs"); + TORCH_CHECK(!spec.isIValue()); + TORCH_CHECK_EQ(spec.children().size(), 2); + + const auto& specArgs = spec.children(0); + TORCH_CHECK(!specArgs.isIValue()); + TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + for (size_t i = 0; i < args.size(); i++) { + ivalueApply(fn, args[i], specArgs.children(i)); + } + + const auto& specKwargs = spec.children(1); + TORCH_CHECK(!specKwargs.isIValue()); + + const auto& ctx = specKwargs.context(); + TORCH_CHECK_EQ(ctx.size(), kwargs.size()); + + for (size_t i = 0; i < ctx.size(); i++) { + ivalueApply( + fn, + kwargs.at(ctx[i].get_ref()), + specKwargs.children(i)); + } +} + +std::vector itreeFlattenToTensorList( + const c10::IValue& nested, + const ITreeSpec& spec) { + auto flats = itreeFlatten(nested, spec); + std::vector tensors; + tensors.reserve(flats.size()); + for (const auto& flat : flats) { + tensors.push_back(flat.toTensor()); + } + return tensors; +} + +c10::IValue itreeMap( + ITreeMapFn f, + const c10::IValue& nested, + const ITreeSpec& spec) { + const auto flats = itreeFlatten(nested, spec); + std::vector mapped; + mapped.reserve(flats.size()); + for (const auto& flat : flats) { + mapped.push_back(f(flat)); + } + return itreeUnflatten(std::move(mapped), spec); +} + +c10::IValue argsToIValue( + const std::vector& args, + const std::unordered_map& kwargs) { + c10::Dict dict( + c10::StringType::get(), c10::AnyType::get()); + for (const auto& [key, arg] : kwargs) { + dict.insert(key, arg); + } + return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict}); +} + +std:: + pair, std::unordered_map> + itreeMapArgs( + ITreeMapFn f, + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec) { + const auto val = argsToIValue(args, kwargs); + const auto mapVal = itreeMap(f, val, spec); + auto mapArgs = + mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec(); + std::unordered_map mapKwargs; + for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) { + mapKwargs.emplace(entry.key().toStringRef(), entry.value()); + } + return {std::move(mapArgs), std::move(mapKwargs)}; +} + +void ivalueApply( + ITreeMapNoReturnFn fn, + const c10::IValue& nested, + const ITreeSpec& spec) { + if (spec.isIValue()) { + if (spec.isUsed()) { + fn(nested, spec.value()); + } + return; + } + auto ivalueApplyFn = spec.nodeDefCache().ivalueApplyFn; + ivalueApplyFn(fn, nested, spec); +} + +nlohmann::json defaultContextLoadFn(std::string_view context) { + return nlohmann::json::parse(context); +} + +ITreeSpec::ITreeSpec( + std::string_view uniformName, + nlohmann::json context, + std::vector children, + NodeDef nodeDefCache) + : uniformName_(uniformName), + context_(std::move(context)), + children_(std::move(children)), + nodeDefCache_(nodeDefCache), + numIValues_(0), + value_(nullptr), + isUsed_(false) { + for (auto& child : children_) { + numIValues_ += child.numIValues(); + allIValues_ &= child.isIValue(); + isUsed_ |= child.isUsed(); + } + + if (uniformName_ == "builtins.dict" || + uniformName_ == "torch.fx.immutable_collections.immutable_dict") { + for (const auto& keyObj : context_) { + contextKeys_.push_back(dynamicToIValue(keyObj)); + } + } +} + +c10::TypePtr ITreeSpec::toAtenType() const { + if (isIValue()) { + return c10::AnyType::get(); + } else if (uniformName_ == "builtins.tuple") { + std::vector childrenType; + childrenType.reserve(children_.size()); + for (const auto& childrenSpec : children_) { + childrenType.emplace_back(childrenSpec.toAtenType()); + } + return c10::TupleType::create(std::move(childrenType)); + } else if ( + uniformName_ == "builtins.list" || + uniformName_ == "torch.fx.immutable_collections.immutable_list") { + if (children_.empty()) { + return c10::ListType::create(c10::AnyType::get()); + } else { + return c10::ListType::create(children_[0].toAtenType()); + } + } else if ( + uniformName_ == "builtins.dict" || + uniformName_ == "torch.fx.immutable_collections.immutable_dict") { + if (children_.empty()) { + return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get()); + } else { + return c10::DictType::create( + dynamicToIValue(context_[0]).type(), children_[0].toAtenType()); + } + } else { + TORCH_CHECK(false, "Unsupported uniform name: ", uniformName()); + } +} + +} // namespace torch::nativert::detail diff --git a/torch/nativert/detail/ITree.h b/torch/nativert/detail/ITree.h new file mode 100644 index 00000000000000..19359920720ac9 --- /dev/null +++ b/torch/nativert/detail/ITree.h @@ -0,0 +1,176 @@ +/* + * A C++ extension bridge with the Python pytree + * serialization/unserialization format for torch.export. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nativert::detail { + +class ITreeSpec; + +using ITreeFlattenFn = + void (*)(const c10::IValue&, const ITreeSpec&, std::vector&); +using ITreeUnflattenFn = + c10::IValue (*)(std::vector, const nlohmann::json&); + +using ContextLoadFn = nlohmann::json (*)(std::string_view); + +using ITreeMapFn = c10::function_ref; +using ITreeMapNoReturnFn = + c10::function_ref; + +using IValueApplyFn = + void (*)(ITreeMapNoReturnFn, const c10::IValue&, const ITreeSpec&); + +nlohmann::json defaultContextLoadFn(std::string_view); + +struct NodeDef { + ITreeFlattenFn flattenFn; + ITreeUnflattenFn unflattenFn; + IValueApplyFn ivalueApplyFn; + + ContextLoadFn contextLoadFn = defaultContextLoadFn; +}; + +class ITreeSpec { + public: + // Leaf node. + ITreeSpec(const Value* value = nullptr, bool isUsed = true) + : numIValues_(1), value_(value), isUsed_(isUsed) {} + + // Non leaf node. + ITreeSpec( + std::string_view uniformName, + nlohmann::json context, + std::vector children, + NodeDef nodeDefCache); + + bool isIValue() const { + return !uniformName_; + } + + std::string_view uniformName() const { + TORCH_CHECK(uniformName_); + return uniformName_.value(); + } + + const nlohmann::json& context() const { + return context_; + } + + const std::vector& contextKeys() const { + return contextKeys_; + } + + const auto& children() const { + return children_; + } + + const ITreeSpec& children(size_t i) const { + return children_[i]; + } + + const NodeDef& nodeDefCache() const { + return nodeDefCache_; + } + + size_t numIValues() const { + return numIValues_; + } + + bool allIValues() const { + return allIValues_; + } + + c10::TypePtr toAtenType() const; + + bool isUsed() const { + return isUsed_; + } + + const Value* value() const { + return value_; + } + + private: + // Only non leaf nodes have names. + // Examples of uniform name: "builtins.tuple", "builtins.dict". + std::optional uniformName_; + nlohmann::json context_; + std::vector children_; + + std::vector contextKeys_; + + // Cached fields. + NodeDef nodeDefCache_; + size_t numIValues_; + bool allIValues_ = true; + + const Value* value_; + bool isUsed_; +}; + +void registerPytreeNode(std::string_view typeName, NodeDef nodeDef); + +// Serialized json tree spec should be dumped from treespec_dumps() in +// torch.utils._pytree directly . +ITreeSpec itreeSpecLoads( + std::string_view json, + const std::vector& values); + +c10::IValue itreeUnflatten( + std::vector ivalues, + const ITreeSpec& spec); + +std::vector itreeFlatten( + const c10::IValue& nested, + const ITreeSpec& spec); + +std::vector itreeFlattenFromArgs( + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec); + +std::vector itreeFlattenToTensorList( + const c10::IValue& nested, + const ITreeSpec& spec); + +c10::IValue itreeMap( + ITreeMapFn f, + const c10::IValue& nested, + const ITreeSpec& spec); + +c10::IValue TORCH_API argsToIValue( + const std::vector& args, + const std::unordered_map& kwargs); + +std:: + pair, std::unordered_map> + itreeMapArgs( + ITreeMapFn f, + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec); + +void ivalueApply( + ITreeMapNoReturnFn f, + const c10::IValue& nested, + const ITreeSpec& spec); + +void ivalueApplyFromArgs( + ITreeMapNoReturnFn fn, + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& spec); + +} // namespace torch::nativert::detail diff --git a/torch/nativert/executor/ConstantFolder.cpp b/torch/nativert/executor/ConstantFolder.cpp new file mode 100644 index 00000000000000..7db1fd736243f9 --- /dev/null +++ b/torch/nativert/executor/ConstantFolder.cpp @@ -0,0 +1,167 @@ +#include + +#include +#include + +#include + +#include +#include + +namespace torch::nativert { + +/* + side effects: + 1. nodes deemed const-foldable nodes are unlinked from the graph. + they are still owned by the graph (i.e., show up in graph.nodeOwner_) + but are not accessible through the node iterator. + + 2. kernels associated with const-foldable nodes are removed from the + 'kernels' input + + 3. mark values deemed foldable as such, removing their producers +*/ + +void ConstantFolder::unlinkConstants( + std::vector>& kernels) { + TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size()) + << "graph node count and kernel count should be equal"; + + unlinked_ = true; + + /* resolve all of the nodes that are const foldable */ + + c10::FastMap nodeDynInputs; + nodeDynInputs.reserve(graph_.nodes().size()); + + c10::FastMap*> nodeKernels; + nodeKernels.reserve(graph_.nodes().size()); + + const auto* input = &*graph_.nodes().begin(); + const auto* output = &*graph_.nodes().end(); + + { // ignore prim.Input and prim.Output + auto ct = 0; + for (auto& n : graph_.nodes()) { + if (&n == input || &n == output) { + continue; + } + nodeDynInputs[&n] = n.numInputs(); + nodeKernels[&n] = &kernels[++ct]; + } + } + + const auto& inputsToWeights = graph_.signature().inputsToWeights(); + for (const auto& [inputName, weightName] : inputsToWeights) { + for (auto* user : graph_.getValue(inputName)->users()) { + if (user == input || user == output) { + continue; + } + nodeDynInputs[user] -= 1; + } + } + + // set of foldable nodes for dedupe purposes + c10::FastSet foldable; + + std::queue constFoldableCandidates; + for (auto& [node, ct] : nodeDynInputs) { + if (ct++ /* will be decremented once dequeued */ == 0) { + constFoldableCandidates.push(node); + } + } + + while (!constFoldableCandidates.empty()) { + auto* candidate = constFoldableCandidates.front(); + constFoldableCandidates.pop(); + if (auto& ct = nodeDynInputs[candidate]; --ct == 0) { + foldable.insert(candidate); + Foldable f; + f.node = candidate; + f.kernel = std::move(*nodeKernels[candidate]); + foldables_.push_back(std::move(f)); + + candidate->unlink(); + + for (auto* user : candidate->users()) { + if (user == output) { + continue; + } + if (foldable.find(user) == foldable.end()) { + constFoldableCandidates.push(user); + } + } + + for (auto* out : candidate->outputs()) { + auto* value = graph_.getValue(out->name()); + + value->setIsFolded(); + + // we only store folded values if there is a non-foldable user + if (const auto& users = value->users(); + std::any_of(users.begin(), users.end(), [&](const auto* u) { + return foldable.find(u) == foldable.end(); + })) { + foldedOutputValueIds_.insert(value->id()); + } + } + } + } + + for (const auto& f : foldables_) { + VLOG(1) << "Const-folded node: " << *f.node; + } + + // remove moved (i.e., associated w/ const-folded nodes) kernels + // from the input kernel vector + kernels.erase( + std::remove_if( + kernels.begin(), + kernels.end(), + [](const auto& k) { return k == nullptr; }), + kernels.end()); + + graph_.renumberValues(); + graph_.finalize(); + graph_.lint(); + + return; +} + +/* + side effects: + 1. weights whose users are ONLY const-foldable nodes will be removed + from the 'weights' input +*/ + +void ConstantFolder::evaluate(Weights& weights) { + CHECK(unlinked_) + << "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"; + + weights.validateAllWeightsLoaded(); + + ExecutionFrame frame(graph_); + frame.setWeights(weights); + + c10::FastMap foldedValues; + + for (const auto& f : foldables_) { + f.kernel->compute(frame); + + for (auto&& [i, out] : c10::enumerate(f.node->outputs())) { + if (foldedOutputValueIds_.find(out->id()) != + foldedOutputValueIds_.end()) { + foldedValues[std::string{out->name()}] = f.kernel->output(i, frame); + } + } + } + + for (auto it = std::make_move_iterator(foldedValues.begin()); + it != std::make_move_iterator(foldedValues.end()); + ++it) { + auto [n, iv] = std::move(*it); + weights.setConstFoldedValue(n, std::move(iv)); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ConstantFolder.h b/torch/nativert/executor/ConstantFolder.h new file mode 100644 index 00000000000000..b1d1afa12f4ffc --- /dev/null +++ b/torch/nativert/executor/ConstantFolder.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::nativert { + +struct Foldable { + Node* node; + std::unique_ptr kernel; +}; + +class ConstantFolder { + public: + explicit ConstantFolder(Graph& graph) : graph_(graph) {} + + /* + 1. identify nodes without dynamic inputs, mark as foldable + + 2. traverse the nodes deemed foldable as if they were being evaluated, + pushing nodes that become foldable after it's inputs were traversed. + + unlink foldable nodes from the graph in the topological order in which + they were traversed, storing the node and its associated kernel (moved + from 'kernels') as a foldable in Constantfolder + */ + void unlinkConstants( + /* kernels for const-foldable nodes will be removed from this vector */ + std::vector>& kernels); + + /* + 1. execute foldables_ on an execution frame initialized with the passed-in + weights, calling Weights::setConstFoldedValue if the folded value is + consumed by a non-foldable node + */ + void evaluate(Weights& weights); + + private: + Graph& graph_; + // unlinked nodes sorted in their topological order + // s.t., they can be evaluated sequentially + std::vector foldables_; + + bool unlinked_{false}; + + c10::FastSet foldedOutputValueIds_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/DelegateExecutor.cpp b/torch/nativert/executor/DelegateExecutor.cpp new file mode 100644 index 00000000000000..78ec4a0c15823b --- /dev/null +++ b/torch/nativert/executor/DelegateExecutor.cpp @@ -0,0 +1,68 @@ +#include + +#ifndef _WIN32 +#include +#endif + +#include + +#include + +#include +#include + +namespace torch::nativert { + +namespace { +char* _mkdtemp(char* outputDir) { + // mkdtemp is not available on Windows +#ifdef _WIN32 + return nullptr; +#else + return mkdtemp(outputDir); +#endif +} + +} // namespace + +std::string extractToTemporaryFolder( + caffe2::serialize::PyTorchStreamReader& packageReader, + const std::string& targetPath) { + char outputDir[] = "/tmp/delegate_model_XXXXXX"; + char* tempdir = _mkdtemp(outputDir); + TORCH_CHECK( + tempdir != nullptr, + "error creating temporary directory for compiled model. errno: ", + errno); + + std::vector allRecords = packageReader.getAllRecords(); + + for (const auto& path : allRecords) { + if (!c10::starts_with(path, targetPath) || c10::ends_with(path, "/")) { + continue; + } + + TORCH_CHECK( + packageReader.hasRecord(path), path, " not present in model package"); + auto [dataPointer, dataSize] = packageReader.getRecord(path); + + std::string fileName = path.substr(path.rfind('/') + 1); + std::string extractedFilename = std::string(outputDir) + "/" + fileName; + + VLOG(1) << "Extracting " << extractedFilename + << " from archive path: " << path << " size: " << dataSize; + + File extracted(extractedFilename, O_CREAT | O_WRONLY, 0640); + const auto bytesWritten = writeFull( + extracted.fd(), const_cast(dataPointer.get()), dataSize); + TORCH_CHECK( + bytesWritten != -1, + "failure copying from archive path ", + path, + " to temporary file"); + } + + return std::string(outputDir); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/DelegateExecutor.h b/torch/nativert/executor/DelegateExecutor.h new file mode 100644 index 00000000000000..b8c3d506c4313e --- /dev/null +++ b/torch/nativert/executor/DelegateExecutor.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +namespace torch::nativert { + +std::string extractToTemporaryFolder( + caffe2::serialize::PyTorchStreamReader& packageReader, + const std::string& targetPath); + +using MakeProxyExecutorFn = + std::function( + const std::string&, + bool, + std::optional>)>; + +// This is the extension point for delegation backends. +class DelegateExecutor { + public: + virtual ~DelegateExecutor() {} + + // Runtime calls processWeights() to pass the weights to the delegate backend. + // Typically, a backend would perform some form of validation and processing, + // such as constant folding. The processed weights stays in the deactivate + // state until commitWeights() is called. + // + // Weights tensors are co-owned by the runtime and the delegate backend. + // In the regular inference run() path, neither Runtime or Delegate backend + // can modify the weights tensor. + // To support inplace weight update, weight tensors are be exposed by + // ModelRunner::getWeights() to an external caller. The external caller can + // then modify the weight tensors in-place. Such mutation would instantly + // affect the weight tensors in the delegate backend. + // When a weight tensor is no longer used by the delegate backend, the backend + // must release it by decreasing a refcount. Runtime would + // also release the refcount for weight tensor if it's no longer activate. The + // underlying storage for weight tensors will be freed when the refcount + // reaches 0. + virtual void processWeights(std::shared_ptr weights) = 0; + + // This call activate the processed weights. + virtual void commitWeights() = 0; + + virtual std::vector run(std::vector& inputs) = 0; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ExecutionFrame.cpp b/torch/nativert/executor/ExecutionFrame.cpp new file mode 100644 index 00000000000000..deedcde69c6cbf --- /dev/null +++ b/torch/nativert/executor/ExecutionFrame.cpp @@ -0,0 +1,182 @@ +#include +#include + +#include + +namespace torch::nativert { + +ExecutionFrame::ExecutionFrame(const Graph& graph) + : graph_(graph), + allValues_(graph.numValues()), + persistent_(graph.numValues()), + moveable_output_mask_(graph.userOutputs().size()) { + updatePersistentValues(/* weights = nullptr */); +} + +ExecutionFrame::ExecutionFrame( + const Graph& graph, + const Weights& weights, + const torch::nativert::ExecutorConfig& cfg, + LayoutPlanner* layoutPlanner) + : ExecutionFrame(graph) { + setWeights(weights); + if (layoutPlanner != nullptr) { + layoutPlanner_ = layoutPlanner; + layoutManager_ = std::make_unique( + *layoutPlanner, + *this, + cfg.layoutPlannerSettings.layoutManagerSettings()); + } +} + +void ExecutionFrame::setWeights(const Weights& weights) { + weightVersion_ = weights.version(); + updatePersistentValues(&weights); + updateMovableOutputs(); +} + +/* static */ std::vector> ExecutionFrame:: + getPersistentValues(const Graph& graph, const Weights* weights) { + std::vector> persistentValues; + + /* ADD GRAPH-DEPENDENT PERSISTENT VALUES */ + + for (const auto& [valueId, constSymintValue] : + graph.getConstantSymIntValues()) { + persistentValues.emplace_back(valueId, constSymintValue); + } + + if (weights == nullptr) { + return persistentValues; + } + + /* ADD WEIGHT-DEPENDENT PERSISTENT VALUES */ + + const auto& inputsToWeights = graph.signature().inputsToWeights(); + for (const auto& [inputName, weightName] : inputsToWeights) { + const Value* value = graph.getValue(inputName); + persistentValues.emplace_back(value->id(), weights->at(weightName)); + } + + const auto& inputsToCustomObjs = graph.signature().inputsToCustomObjs(); + for (const auto& [inputName, customObjName] : inputsToCustomObjs) { + const Value* value = graph.getValue(inputName); + persistentValues.emplace_back( + value->id(), weights->getCustomObj(customObjName)); + } + + std::unordered_map foldedConstIds; + for (const Node& node : graph.nodes()) { + if (node.target() == "torch.ops.higher_order.run_const_graph") { + const auto& const_graph = + std::get>(node.attributes().at(0).value); + for (size_t i = 0; i < node.outputs().size(); ++i) { + foldedConstIds[std::string{const_graph->outputs().at(i)->name()}] = + node.outputs()[i]->id(); + } + } + } + for (const auto& [name, tensor] : weights->getFoldedConsts()) { + persistentValues.emplace_back(foldedConstIds.at(name), tensor); + } + + for (const auto& [name, iv] : weights->getConstFoldedValues()) { + const Value* value = graph.getValue(name); + persistentValues.emplace_back(value->id(), iv); + } + + return persistentValues; +} + +void ExecutionFrame::updatePersistentValues(const Weights* weights) { + auto persistentValues = ExecutionFrame::getPersistentValues(graph_, weights); + for (auto it = std::make_move_iterator(persistentValues.begin()); + it != std::make_move_iterator(persistentValues.end()); + ++it) { + auto&& [value, iv] = *it; + setPersistentIValue(value, std::move(iv)); + } +} + +void ExecutionFrame::updateMovableOutputs() { + moveable_output_mask_.assign(moveable_output_mask_.size(), true); + + c10::FastSet inputs; + for (const auto* input : graph_.userInputs()) { + if (input) { + inputs.insert(input->id()); + } + } + + const auto& outputs = graph_.userOutputs(); + const size_t num_outputs = outputs.size(); + + c10::FastSet seen; + for (size_t i = 0; i < num_outputs; i++) { + auto idx = num_outputs - 1 - i; + if (const Value* const* valuePtr = std::get_if(&outputs[idx]); + valuePtr && *valuePtr) { + auto id = (*valuePtr)->id(); + + /* + values are not moveable if: + 1. they are persistent + 2. they are inputs (since inputs are borrowed) + 3. the value will be moved in a later (right-more) output + */ + + if (!seen.insert(id).second || persistent_[id] || + inputs.find(id) != inputs.end()) { + moveable_output_mask_[idx] = false; + } + } + } +} + +ExecutionFrame::ExecutionFrame( + const Graph& graph, + size_t numValues, + const std::vector&, + const std::vector&) + : graph_(graph) { + allValues_.resize(numValues); +} + +void ExecutionFrame::setIValue(ValueId id, c10::IValue ivalue) { + DCHECK(static_cast(id) < allValues_.size()); + allValues_[id] = std::move(ivalue); +} + +void ExecutionFrame::setBorrowedIValue(ValueId id, c10::IValue ivalue) { + DCHECK(static_cast(id) < allValues_.size()); + borrowedValueIds_.push_back(id); + allValues_[id] = std::move(ivalue); +} + +at::Tensor ExecutionFrame::getTensor(ValueId id) const { + const auto& ivalue = getIValue(id); + if (C10_LIKELY(ivalue.isTensor())) { + return ivalue.toTensor(); + } else { + throw std::runtime_error("getTensor called on non-tensor value"); + } +} + +std::vector ExecutionFrame::tryMoveUserOutputs() { + std::vector ret; + const auto& outputs = graph_.userOutputs(); + ret.reserve(outputs.size()); + for (const auto& [i, outputValue] : c10::enumerate(outputs)) { + if (const Value* const* valuePtr = std::get_if(&outputValue); + valuePtr && *valuePtr) { + ret.push_back( + isOutputMovable(i) ? moveIValue((*valuePtr)->id()) + : getIValue((*valuePtr)->id())); + } else if (Constant const* constant = std::get_if(&outputValue)) { + ret.push_back(constantToIValue(*constant)); + } + } + return ret; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h new file mode 100644 index 00000000000000..ae8821a6e58b02 --- /dev/null +++ b/torch/nativert/executor/ExecutionFrame.h @@ -0,0 +1,183 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch::nativert { + +/** + * This class encapsulate the stateful values of an execution, + * most notably, the tensor values passed between nodes, aka intermediate + * activations. + */ +class ExecutionFrame { + public: + // Constructor for weight-less graph, used for higher order ops, e.g. + // torch.cond + explicit ExecutionFrame(const Graph& graph); + + explicit ExecutionFrame( + const Graph& graph, + const Weights& weights, + const torch::nativert::ExecutorConfig& executorConfig = {}, + LayoutPlanner* layoutPlanner = nullptr); + + // Constructor for testing purpose + explicit ExecutionFrame( + const Graph& graph, + size_t numValues, + const std::vector& graphInputIds, + const std::vector& graphOutputIds); + + ExecutionFrame(const ExecutionFrame&) = delete; + ExecutionFrame& operator=(const ExecutionFrame&) = delete; + ExecutionFrame(ExecutionFrame&&) = delete; + ExecutionFrame& operator=(ExecutionFrame&&) = delete; + + ~ExecutionFrame() { + destroyBorrowedIValues(); + } + + template + auto withMemoryPlanner(CB&& cb) { + if (!layoutManager_) { + return std::forward(cb)(); + } + + LayoutManagerGuard guard(*layoutManager_); + return std::forward(cb)(); + } + + std::vector tryMoveUserOutputs(); + + c10::IValue moveIValue(ValueId id) { + return std::move(allValues_[id]); + } + + const c10::IValue& getIValue(ValueId id, bool allowNone = true) const { + const auto& iValue = allValues_[id]; + if (allowNone && iValue.isNone()) { + return iValue; + } + DCHECK(!iValue.isNone()); + return iValue; + } + + c10::IValue& getIValue(ValueId id, bool allowNone = true) { + auto& iValue = allValues_[id]; + if (allowNone && iValue.isNone()) { + return iValue; + } + DCHECK(!iValue.isNone()); + return iValue; + } + + void setIValue(ValueId id, c10::IValue ivalue); + void setBorrowedIValue(ValueId id, c10::IValue ivalue); + + at::Tensor getTensor(ValueId id) const; + + std::vector getTensorVector(ValueId id) const { + return getIValue(id).toTensorVector(); + } + + int64_t getSymInt(ValueId id) const { + return getIValue(id).toInt(); + } + + double getSymFloat(ValueId id) const { + return getIValue(id).toDouble(); + } + + C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const { + return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id); + } + + void setPersistentIValue(ValueId id, c10::IValue ivalue) { + setIValue(id, std::move(ivalue)); + persistent_[id] = true; + } + + void releaseValueIfNeeded(ValueId id) { + if (!isManagedValue(id) && !persistent_[id]) { + allValues_[id] = c10::IValue(); + } + } + + void destroyBorrowedIValues() { + for (const auto& id : borrowedValueIds_) { + c10::MaybeOwnedTraits::destroyBorrow(getIValue(id)); + } + borrowedValueIds_.clear(); + } + + void setWork(int64_t workId, const c10::intrusive_ptr& work) { + work_[workId] = work; + } + + c10::intrusive_ptr getWork(int64_t workId) const { + CHECK(work_.find(workId) != work_.end()) + << "Couldn't find work with Id: " << workId; + return work_.at(workId); + } + + WeightVersion weightVersion() const { + return weightVersion_; + } + + void setWeights(const Weights& weights); + + static std::vector> getPersistentValues( + const Graph& graph, + const Weights* weights = nullptr); + + static std::vector getPersistentValueMask( + const Graph& graph, + const Weights* weights = nullptr) { + std::vector persistentValuesMask(graph.numValues()); + for (auto& [valueId, _] : getPersistentValues(graph, weights)) { + persistentValuesMask[valueId] = true; + } + return persistentValuesMask; + } + + private: + bool isOutputMovable(size_t idx) const { + TORCH_CHECK_LT(idx, moveable_output_mask_.size()); + return moveable_output_mask_[idx]; + } + + void updatePersistentValues(const Weights* weights = nullptr); + void updateMovableOutputs(); + + const Graph& graph_; + WeightVersion weightVersion_ = -1; + + std::unique_ptr layoutManager_; + LayoutPlanner* layoutPlanner_{nullptr}; + + // All the intermediate values for the entire graph, including graph inputs + // and outputs This table is fixed once constructed + std::vector allValues_; + // a class-local version of getPersistentValueMask + std::vector persistent_; + + std::unordered_map> work_; + + std::vector borrowedValueIds_; + + // moveable_output_mask_[i] corresponds to user_outputs_[i] + // + // if moveable_output_mask_[i] is true, then user_outputs_[i] + // can be moved + std::vector moveable_output_mask_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ExecutionPlanner.cpp b/torch/nativert/executor/ExecutionPlanner.cpp new file mode 100644 index 00000000000000..a3a3f58f3062d5 --- /dev/null +++ b/torch/nativert/executor/ExecutionPlanner.cpp @@ -0,0 +1,117 @@ +#include + +#include + +#include +#include + +namespace torch::nativert { + +std::unique_ptr ExecutionPlanner::createPlan() { + auto plan = std::make_unique(); + + // Current implementation assume that nodes will be executed + // in the same order as the thrift graph. + // In the future, we can do execution order plan, as long as it's + // comply with topological order + + generateDeallocationPlan(*plan); + return plan; +} + +/* static */ c10::FastSet ExecutionPlanner::staticValues( + const Graph& graph) { + c10::FastSet staticValues; + // Filter lastUsedBy by graph inputs + // parameters/buffer values should not be freed + // It's a policy decision to whether to free user inputs. For now, we don't + // free user inputs. + // TODO: It should be fine to "free" the user inputs. If the user holds a ref + // to it, it won't be deallocated. + for (const auto* input : graph.inputs()) { + if (input) { + const auto& id = input->id(); + staticValues.insert(id); + } + } + + // Filter lastUsedBy by graph outputs, as they are still needed to be returned + for (const auto& output : graph.outputs()) { + const auto& id = output->id(); + staticValues.insert(id); + } + + for (const auto& [id, _] : graph.getConstantSymIntValues()) { + staticValues.insert(id); + } + + for (const Node& node : graph.nodes()) { + if (node.target() == "torch.ops.higher_order.run_const_graph") { + for (const auto& output : node.outputs()) { + // Do not free the outputs of run_const_graph, as they are newly + // produced folded constants + staticValues.insert(output->id()); + } + } else { + for (const auto& input : node.inputs()) { + if (input.value->isFolded()) { + staticValues.insert(input.value->id()); + } + } + } + } + + return staticValues; +} + +void ExecutionPlanner::generateDeallocationPlan(ExecutionPlan& plan) { + const auto& nodes = graph_.nodes(); + size_t numNodes = nodes.size(); + + std::unordered_map lastUsedBy; + + // Traverse from the last node to the first node + // For each Value, find out which is the last node that uses it + // the Value can freed after executing the node + size_t nodeIdx = nodes.size() - 1; + for (auto it = std::rbegin(nodes); it != std::rend(nodes); it++) { + const auto& inputs = it->inputs(); + for (const auto& input : inputs) { + const auto& id = input.value->id(); + if (lastUsedBy.find(id) == lastUsedBy.end()) { + lastUsedBy.insert({id, nodeIdx}); + } + } + nodeIdx--; + } + + std::vector> valuesToFree(numNodes); + + const auto& statics = staticValues(graph_); + for (auto& [id, nodeIndex] : lastUsedBy) { + if (statics.find(id) == statics.end()) { + valuesToFree[nodeIndex].push_back(id); + } + } + + plan.valuesToFree = std::move(valuesToFree); + + // print allocation plan + VLOG(2) << plan; + + return; +} + +std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan) { + out << "****** Deallocation Plan ******\n"; + for (auto&& [i, values] : c10::enumerate(plan.valuesToFree)) { + out << "Node #" << i << ", valuesToFree = ["; + for (const auto& value : values) { + out << value << ", "; + } + out << "]\n"; + } + return out; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ExecutionPlanner.h b/torch/nativert/executor/ExecutionPlanner.h new file mode 100644 index 00000000000000..af470341cc0c2b --- /dev/null +++ b/torch/nativert/executor/ExecutionPlanner.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include + +namespace torch::nativert { + +// ExecutionPlan is the result produced by ExecutionPlanner +// ATM, it only contains value deallocation plan. +struct ExecutionPlan { + // i-th entry in this list are the Values can be freed *after* execution i-th + // node + std::vector> valuesToFree; +}; + +class ExecutionPlanner { + public: + explicit ExecutionPlanner(const Graph& graph) : graph_(graph) {} + + std::unique_ptr createPlan(); + // get list of values we can't free + static c10::FastSet staticValues(const Graph& graph); + + private: + void generateDeallocationPlan(ExecutionPlan& plan); + const Graph& graph_; +}; + +std::ostream& operator<<(std::ostream& out, const ExecutionPlan& plan); + +} // namespace torch::nativert diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp new file mode 100644 index 00000000000000..285b6dea00dd7e --- /dev/null +++ b/torch/nativert/executor/Executor.cpp @@ -0,0 +1,387 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Maximum number of retries when trying to get a frame from +// clearedExecutionFrames_ +constexpr uint32_t kClearExecutionFrameRetries = 10; + +namespace torch::nativert { + +Executor::Executor( + torch::nativert::ExecutorConfig executorConfig, + std::shared_ptr graph, + std::shared_ptr weights, + const Placement& placement, + std::shared_ptr pytorchStreamReader, + const MakeProxyExecutorFn& makeProxyExecutorFunc) + : executorConfig_(std::move(executorConfig)), + graph_(std::move(graph)), + placement_(placement), + constantFolder_( + executorConfig_.runConstFolding + ? std::optional(*graph_) + : std::nullopt), + makeProxyExecutorFunc_(makeProxyExecutorFunc), + executionFrames_(executorConfig_.maxNumConcurrentThreads), + clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), + numExecutionFrames_(0), + lastClearedTimestamp_(getCurrentTimestampSeconds()) { + if (weights) { + initialize(std::move(weights), std::move(pytorchStreamReader)); + } +} + +void Executor::initialize( + std::shared_ptr weights, + std::shared_ptr + pytorchStreamReader) { + auto start = std::chrono::high_resolution_clock::now(); + + auto executionKernels = KernelFactory().initializeNodeKernels( + *graph_, + weights, + executorConfig_, + placement_, + std::move(pytorchStreamReader), + makeProxyExecutorFunc_); + + if (constantFolder_.has_value()) { + constantFolder_->unlinkConstants(executionKernels.nodeKernels); + } + + const auto& kernelSchemas = getKernelSchemas(executionKernels.nodeKernels); + + if (executorConfig_.maxParallelOps > 1) { + graphExecutor_ = std::make_unique( + *graph_, std::move(executionKernels.nodeKernels), executorConfig_); + } else { + graphExecutor_ = std::make_unique( + *graph_, std::move(executionKernels.nodeKernels), executorConfig_); + } + + delegateExecutors_ = std::move(executionKernels.delegateExecutors); + constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions); + + // initialize weights_ + processWeights(weights); + atomicSwapWeights(weights); + + if (executorConfig_.layoutPlannerSettings.enabled()) { + layoutPlanner_ = std::make_unique( + *graph_, + kernelSchemas, + ExecutionFrame::getPersistentValueMask(*graph_, weights.get()), + executorConfig_.layoutPlannerSettings); + } + + auto end = std::chrono::high_resolution_clock::now(); + LOG(INFO) << "Initialization completed in " + << std::chrono::duration_cast( + end - start) + .count() + << " ms"; +} + +/* static */ c10:: + FastMap + Executor::getKernelSchemas( + const std::vector>& kernels) { + c10::FastMap output; + for (const auto& kernel : kernels) { + if (const auto* casted = dynamic_cast(kernel.get()); casted) { + output.insert({std::string(kernel->node()->target()), casted->schema()}); + } + } + return output; +} + +void Executor::atomicSwapWeights(std::shared_ptr weights) { + weights_.withLock([&](auto& w) { w = std::move(weights); }); + + // update weights in delegate executors + for (auto& delegateExecutor : delegateExecutors_) { + delegateExecutor->commitWeights(); + } +} + +void Executor::maybeRunConstantFolding(std::shared_ptr weights) { + for (auto& execution : constFoldingExecutions_) { + ExecutionFrame constFoldingFrame(execution.executor->graph()); + std::vector inputs; + inputs.reserve(graph_->signature().inputsToWeights().size()); + for (const auto& [_, name] : graph_->signature().inputsToWeights()) { + inputs.push_back(weights->at(name)); + } + + auto outputs = execution.executor->execute(constFoldingFrame, inputs); + for (const auto& [idx, value] : + c10::enumerate(execution.executor->graph().outputs())) { + weights->updateFoldedConst(value->name(), outputs.at(idx)); + } + } +} + +void Executor::processWeights(std::shared_ptr weights) { + maybeRunConstantFolding(weights); + if (constantFolder_.has_value()) { + constantFolder_->evaluate(*weights); + } + for (auto& delegateExecutor : delegateExecutors_) { + delegateExecutor->processWeights(weights); + } +} + +namespace { +void validateInput( + const std::string& inputName, + const at::Tensor& inputTensor, + const torch::nativert::TensorMeta& tensorValueMeta) { + CHECK(inputTensor.dtype() == tensorValueMeta.dtype()) + << "Input tensor dtype mismatch for " << inputName << ", expecting " + << c10::toString(tensorValueMeta.dtype()) << " but got " + << inputTensor.dtype().name(); + + CHECK(inputTensor.device() == tensorValueMeta.device()) + << "Input tensor device mismatch for " << inputName << ", expecting " + << tensorValueMeta.device().str() << " but got " + << inputTensor.device().str(); +} + +} // namespace + +// validate input tensor's dtype matches tensorMeta +void Executor::validateInputs(const std::vector& inputs) const { + const auto& inputValues = graph_->userInputs(); + const auto& tensorValuesMeta = graph_->tensorValuesMeta(); + TORCH_CHECK(inputs.size() == inputValues.size(), "Input size mismatch"); + for (auto&& [i, actualInput] : c10::enumerate(inputs)) { + if (actualInput.isTensor()) { + const auto& inputName = std::string(inputValues[i]->name()); + auto it = tensorValuesMeta.find(inputName); + CHECK(it != tensorValuesMeta.end()) + << "Couldn't find " << inputName << " in tensorValuesMeta"; + validateInput(inputName, actualInput.toTensor(), it->second); + } + } +} + +Executor::ExecutorFramePtr Executor::getExecutorFrameFromPool() { + std::shared_ptr weights; + weights_.withLock([&](auto& w) { weights = w; }); + + // First try to get a frame from clearedExecutionFrames_ if clearing is in + // progress + if (C10_UNLIKELY(clearingInProgress_)) { + ExecutionFrameEntry frameEntry; + uint32_t retry = 0; + while ( + retry < + kClearExecutionFrameRetries) { // Limit retries to avoid infinite loop + if (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { + if (retry > 0) { + VLOG(1) << "Took " << retry + << " retries to pop from clearedExecutionFrames_"; + } + ExecutorFramePtr ptr{std::move(frameEntry.frame), *this}; + if (ptr->weightVersion() != weights->version()) { + ptr->setWeights(*weights); + } + return ptr; + } + retry++; + } + // If we couldn't get a frame from cleared pool after retries, move onto + // main pool + } + + // Try to get a frame from the main pool or create a new one + std::unique_ptr frame; + while (!executionFrames_.readIfNotEmpty(frame)) { + int64_t numFrames = numExecutionFrames_.load(); + if (numFrames < executorConfig_.maxNumConcurrentThreads) { + if (numExecutionFrames_.compare_exchange_strong( + numFrames, numFrames + 1)) { + return ExecutorFramePtr{ + std::make_unique( + *graph_, *weights, executorConfig_, layoutPlanner_.get()), + *this}; + } + } else { + sem_.acquire(); + } + } + ExecutorFramePtr ptr{std::move(frame), *this}; + + if (ptr->weightVersion() != weights->version()) { + ptr->setWeights(*weights); + } + return ptr; +} + +void Executor::clearStaleExecutionFrames() { + if (!cleanupLock_.try_lock()) { + // Another thread is already doing cleanup + return; + } + // Update timestamp first to minimize contention + lastClearedTimestamp_ = getCurrentTimestampSeconds(); + + int numPopped = 0; + std::unique_ptr frame; + + // Move frames from executionFrames_ to clearedExecutionFrames_ + while (executionFrames_.readIfNotEmpty(frame)) { + ++numPopped; + // Keep the first popped entries up to minimum size + if (numPopped > executorConfig_.minNumExecutionFrames) { + // Discard stale frames + frame.reset(); + numExecutionFrames_ -= 1; + continue; + } + + ExecutionFrameEntry entry; + entry.used = false; + entry.frame = std::move(frame); + clearedExecutionFrames_.writeIfNotFull(std::move(entry)); + // Enable clients to pop from clearedExecutionFrames_ while clearing is in + // progress + clearingInProgress_ = true; + } + + uint32_t numPushed = 0; + ExecutionFrameEntry frameEntry; + // Move frames back from clearedExecutionFrames_ to executionFrames_ + while (clearedExecutionFrames_.readIfNotEmpty(frameEntry)) { + ++numPushed; + executionFrames_.writeIfNotFull(std::move(frameEntry.frame)); + clearingInProgress_ = false; + } + + clearingInProgress_ = false; + VLOG(1) << "Cleared " << (numPopped - numPushed) << " out of " << numPopped + << " ExecutionFrame instances in the pool"; + + cleanupLock_.unlock(); +} + +void Executor::returnExecutorFrameToPool( + std::unique_ptr frame) { + // Check if it's time to clean up stale frames + if (executorConfig_.doExecutionFrameCleanup && + lastClearedTimestamp_ + + executorConfig_.executionFramePoolCleanupIntervalSec < + getCurrentTimestampSeconds()) { + clearStaleExecutionFrames(); + } + + try { + frame->destroyBorrowedIValues(); + + // Create an entry with used=true + if (C10_UNLIKELY(!clearingInProgress_)) { + CHECK(executionFrames_.writeIfNotFull(std::move(frame))) + << "ExecutionFrame pool full"; + } else { + ExecutionFrameEntry frameEntry; + frameEntry.used = true; + frameEntry.frame = std::move(frame); + + CHECK(clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry))) + << "Cleared ExecutionFrame pool full"; + } + } catch (...) { + sem_.release(); + throw; + } + sem_.release(); +} + +std::vector Executor::execute(std::vector inputs) { + if (executorConfig_.validateInputs) { + validateInputs(inputs); + } + + auto executionFrame = getExecutorFrameFromPool(); + return graphExecutor_->execute(*executionFrame, std::move(inputs)); +} + +std::vector Executor::execute( + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& inputTreeSpec) { + auto executionFrame = getExecutorFrameFromPool(); + + std::optional> outputs; + const auto userInputs = graph_->userInputs(); + const auto& tensorValuesMeta = graph_->tensorValuesMeta(); + TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numIValues()); + + auto executionFrameFillUserInputs = [&](const c10::IValue& leaf, + const Value* value) { + // validate input tensor's dtype and device matches tensorMeta + if (executorConfig_.validateInputs && leaf.isTensor()) { + const auto& inputName = std::string(value->name()); + auto it = tensorValuesMeta.find(inputName); + CHECK(it != tensorValuesMeta.end()) + << "Couldn't find " << inputName << " in tensorValuesMeta"; + validateInput(inputName, leaf.toTensor(), it->second); + } + executionFrame->setBorrowedIValue( + value->id(), c10::MaybeOwnedTraits::createBorrow(leaf)); + }; + ivalueApplyFromArgs( + executionFrameFillUserInputs, args, kwargs, inputTreeSpec); + try { + outputs = graphExecutor_->executeWithPrefilledFrame(*executionFrame); + } catch (const std::exception& e) { + LOG(ERROR) << "Exception during executeWithPrefilledFrame: " << e.what(); + throw; + } + + return std::move(*outputs); +} + +ProfileMetrics Executor::benchmarkIndividualNodes( + std::vector> inputsList, + const uint32_t warmupRuns, + const uint32_t mainRuns) { + CHECK(inputsList.size() > 0) << "Need at least one input to benchmark"; + CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run"; + + for (const auto& inputs : inputsList) { + if (executorConfig_.validateInputs) { + validateInputs(inputs); + } + } + auto executionFrame = getExecutorFrameFromPool(); + auto benchmarkResults = graphExecutor_->benchmarkIndividualNodes( + *executionFrame, inputsList, warmupRuns, mainRuns); + + return benchmarkResults; +} + +int64_t Executor::getCurrentTimestampSeconds() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); +} + +std::vector Executor::getDelegates() { + std::vector delegates; + for (const auto& delegateExecutor : delegateExecutors_) { + delegates.push_back(delegateExecutor.get()); + } + return delegates; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h new file mode 100644 index 00000000000000..964337b691766d --- /dev/null +++ b/torch/nativert/executor/Executor.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::nativert { + +using namespace torch::nativert::detail; + +struct DistributedRunConfig; + +/** + * A very dumb executor. Basically just runs each node in order and contains a + * giant unordered map for every intermediate, no optimizations applied. + */ +class Executor { + class ExecutorFrameDeleter { + public: + explicit ExecutorFrameDeleter(Executor& e) : e_(&e) {} + ExecutorFrameDeleter(ExecutorFrameDeleter&&) = default; + ExecutorFrameDeleter& operator=(ExecutorFrameDeleter&&) = default; + ExecutorFrameDeleter(const ExecutorFrameDeleter&) = default; + ExecutorFrameDeleter& operator=(const ExecutorFrameDeleter&) = default; + ~ExecutorFrameDeleter() = default; + + void operator()(ExecutionFrame* p) { + e_->returnExecutorFrameToPool(std::unique_ptr(p)); + } + + private: + Executor* e_; + }; + class ExecutorFramePtr { + public: + ExecutorFramePtr(std::unique_ptr ptr, Executor& e) + : ptr_(std::unique_ptr( + ptr.release(), + ExecutorFrameDeleter{e})) {} + ExecutorFramePtr() = delete; + ExecutorFramePtr(ExecutorFramePtr&&) = default; + ExecutorFramePtr& operator=(ExecutorFramePtr&&) = default; + ExecutorFramePtr(const ExecutorFramePtr&) = delete; + ExecutorFramePtr& operator=(const ExecutorFramePtr&) = delete; + ~ExecutorFramePtr() = default; + + ExecutionFrame& operator*() { + return *ptr_; + } + + ExecutionFrame* operator->() { + return ptr_.get(); + } + + private: + std::unique_ptr ptr_; + }; + + public: + // Constrcutor used for Inference Path + Executor( + torch::nativert::ExecutorConfig executorConfig, + std::shared_ptr graph, + std::shared_ptr weights, + const Placement& placement = Placement(), + std::shared_ptr + pytorchStreamReader = nullptr, + const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + + std::shared_ptr getWeights() { + std::shared_ptr ret; + weights_.withLock([&](auto& w) { ret = w; }); + return ret; + } + + void processWeights(std::shared_ptr weights); + void atomicSwapWeights(std::shared_ptr weights); + + // This API only returns the flattened UserOutputs, + // intended to be used for Inference path + // TODO Investigate whether we should remove this, still seems + // useful for testing. + std::vector execute(std::vector inputs); + + std::vector execute( + const std::vector& args, + const std::unordered_map& kwargs, + const ITreeSpec& inputTreeSpec); + + ProfileMetrics benchmarkIndividualNodes( + std::vector> inputsList, + const uint32_t warmupRuns, + const uint32_t mainRuns); + + const torch::nativert::GraphSignature& graphSignature() const { + return graph_->signature(); + } + + static std::string className() { + return "Executor.v0"; + } + + const torch::nativert::ExecutorConfig& executorConfig() const { + return executorConfig_; + } + + std::vector getDelegates(); + + // Get the number of execution frames in the pool + int getNumExecutionFrames() const { + return numExecutionFrames_.load(); + } + + static c10::FastMap + getKernelSchemas(const std::vector>& kernels); + + protected: + torch::nativert::ExecutorConfig executorConfig_; + + std::shared_ptr graph_; + + // manages the parameters, buffers and tensor constants + c10::Synchronized> weights_; + + void initialize( + std::shared_ptr weights, + std::shared_ptr + pytorchStreamReader); + + ExecutorFramePtr getExecutorFrameFromPool(); + void returnExecutorFrameToPool(std::unique_ptr frame); + + // Clears stale execution frames from the pool + void clearStaleExecutionFrames(); + + private: + // Structure to track execution frame usage + struct ExecutionFrameEntry { + bool used{false}; + std::unique_ptr frame; + + // Add move constructor and assignment operator + ExecutionFrameEntry() = default; + ExecutionFrameEntry(ExecutionFrameEntry&& other) noexcept + : used(other.used), frame(std::move(other.frame)) {} + ExecutionFrameEntry& operator=(ExecutionFrameEntry&& other) noexcept { + used = other.used; + frame = std::move(other.frame); + return *this; + } + // Delete copy constructor and assignment operator + ExecutionFrameEntry(const ExecutionFrameEntry&) = delete; + ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete; + }; + + void maybeRunConstantFolding(std::shared_ptr weights); + void validateInputs(const std::vector& inputs) const; + + // Helper method to get current timestamp in seconds + int64_t getCurrentTimestampSeconds() const; + + std::unique_ptr graphExecutor_; + + const Placement placement_; + + // NOTE: delegateExecutors_ is used by nodeKernels_ inside graphExecutor_. + std::vector> delegateExecutors_; + + std::vector constFoldingExecutions_; + + std::optional constantFolder_; + + MakeProxyExecutorFn makeProxyExecutorFunc_; + + c10::Semaphore sem_; + torch::nativert::detail::MPMCQueue> + executionFrames_; + torch::nativert::detail::MPMCQueue + clearedExecutionFrames_; + std::atomic_int64_t numExecutionFrames_; + + std::unique_ptr layoutPlanner_; + std::atomic_int64_t lastClearedTimestamp_; + std::mutex cleanupLock_; + std::atomic_bool clearingInProgress_{false}; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ExecutorConfig.h b/torch/nativert/executor/ExecutorConfig.h index 11ef889149e335..a6a5ef20a0859b 100644 --- a/torch/nativert/executor/ExecutorConfig.h +++ b/torch/nativert/executor/ExecutorConfig.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -9,12 +10,16 @@ struct ExecutorConfig { bool validateInputs = false; bool debugNan = false; bool enableStaticCPUKernels = false; - bool enableStaticMemoryPlanning = false; bool runConstFolding = false; + bool doExecutionFrameCleanup = true; + bool tryFreeUnmanagedValuesAfterUse = true; // allows up to max number of concurrent threads. int64_t maxNumConcurrentThreads = 8; // allows up to max number of parallel ops. int64_t maxParallelOps = 1; + int64_t minNumExecutionFrames = 1; + int64_t executionFramePoolCleanupIntervalSec = 600; + LayoutPlannerSettings layoutPlannerSettings; std::string modelName = "unknown"; }; diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp new file mode 100644 index 00000000000000..5ad31a7dacabee --- /dev/null +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -0,0 +1,120 @@ +#include +#include + +#include +#include + +namespace torch::nativert { + +GraphExecutorBase::GraphExecutorBase( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig) + : graph_(graph), + nodeKernels_(std::move(nodeKernels)), + executorConfig_(executorConfig), + execPlan_(ExecutionPlanner{graph_}.createPlan()) {} + +void GraphExecutorBase::fillUserInputs( + ExecutionFrame& frame, + std::vector inputs) { + RECORD_USER_SCOPE("Executor::fillUserInputs"); + const auto& inputValues = graph_.userInputs(); + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + + // load user input tensor into execution frame + for (size_t i = 0; i < inputValues.size(); i++) { + if (inputValues[i]) { + frame.setIValue(inputValues[i]->id(), std::move(inputs[i])); + } + } +} + +ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( + ExecutionFrame& executionFrame, + std::vector> inputsList, + const uint32_t warmupRuns, + const uint32_t mainRuns) { + // TODO: add support for memory profiling + TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1); + + ProfileMetrics results; + const auto numNodes = static_cast(nodeKernels_.size()); + results.timePerNode.resize(numNodes, 0); + if (inputsList.empty()) { + auto i = 0; + for (const auto& nodeKernel : nodeKernels_) { + std::string target(nodeKernel->node()->target()); + results.timePerNode[i] = 0; + results.timePerNodeType[target] = 0; + results.instancesPerNodeType[target]++; + if (nodeKernel->hasPrimKernel()) { + results.primNodesCount++; + results.primNodes.insert(target); + } else if (nodeKernel->hasStaticDispatch()) { + results.staticDispatchNodesCount++; + results.staticDispatchNodes.insert(target); + } + i++; + } + results.totalNodesCount = numNodes; + for (const auto& p : results.timePerNodeType) { + const std::string& kind = p.first; + results.percentPerNodeType[kind] = 0; + } + return results; + } + + // Warmup + for (uint32_t i = 0; i < warmupRuns; i++) { + for (const auto& inputs : inputsList) { + execute(executionFrame, inputs); + } + } + + // Execute kernels + caffe2::Timer timer; + for (uint32_t i = 0; i < mainRuns; i++) { + for (auto inputs : inputsList) { + const auto& inputValues = graph_.userInputs(); + + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + for (size_t j = 0; j < inputValues.size(); j++) { + executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j])); + } + for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) { + timer.Start(); + nodeKernels_[nodeIdx]->compute(executionFrame); + float millis = timer.MilliSeconds(); + results.timePerNode[nodeIdx] += millis; + } + } + } + + // Summarize results + const float numTotalIters = + (static_cast(mainRuns) * static_cast(inputsList.size())); + for (const auto i : c10::irange(numNodes)) { + const Node* node = nodeKernels_[i]->node(); + std::string target(node->target()); + results.timePerNode[i] /= numTotalIters; + results.timePerNodeType[target] += results.timePerNode[i]; + results.instancesPerNodeType[target]++; + if (nodeKernels_[i]->hasPrimKernel()) { + results.primNodes.insert(target); + results.primNodesCount++; + } else if (nodeKernels_[i]->hasStaticDispatch()) { + results.staticDispatchNodes.insert(target); + results.staticDispatchNodesCount++; + } + results.totalTime += results.timePerNode[i]; + } + results.totalNodesCount = numNodes; + for (const auto& r : results.timePerNodeType) { + const std::string& target = r.first; + results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime; + } + return results; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h new file mode 100644 index 00000000000000..86c6ed61c1f9ae --- /dev/null +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::nativert { + +struct ProfileMetrics { + size_t primNodesCount{0}; + size_t staticDispatchNodesCount{0}; + size_t totalNodesCount{0}; + std::vector timePerNode; + std::unordered_map timePerNodeType; + std::unordered_map percentPerNodeType; + std::unordered_map instancesPerNodeType; + std::unordered_set staticDispatchNodes; + std::unordered_set primNodes; + float totalTime{0}; +}; + +/** + * GraphExecutor is a lightweight abstraction to execute a graph with + * execution frames without actually owning the graph nor the weights. This is + * introduced to decouple the state management of the top level runtime from the + * kernel executions so that sub graphs from higher order ops can be supported. + */ +class GraphExecutorBase { + public: + GraphExecutorBase( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig); + virtual ~GraphExecutorBase() = default; + + const Graph& graph() const { + return graph_; + } + + // This API only returns the flattened UserOutputs, + // intended to be used for Inference path + virtual std::vector execute( + ExecutionFrame& frame, + std::vector inputs) = 0; + + virtual std::vector executeWithPrefilledFrame( + ExecutionFrame& frame) = 0; + + ProfileMetrics benchmarkIndividualNodes( + ExecutionFrame& executionFrame, + std::vector> inputs, + const uint32_t warmup_runs, + const uint32_t main_runs); + + std::vector> stealKernels() { + return std::move(nodeKernels_); + } + + void setKernels(std::vector>&& kernels) { + nodeKernels_ = std::move(kernels); + } + + protected: + void fillUserInputs(ExecutionFrame& frame, std::vector inputs); + + const Graph& graph_; + + // cache of the constructed kernels to avoid reconstruction per execution + std::vector> nodeKernels_; + + const ExecutorConfig& executorConfig_; + + std::unique_ptr execPlan_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/OpKernel.cpp b/torch/nativert/executor/OpKernel.cpp new file mode 100644 index 00000000000000..ee4a8503d5ce2c --- /dev/null +++ b/torch/nativert/executor/OpKernel.cpp @@ -0,0 +1,163 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +namespace torch::nativert { + +c10::OperatorHandle getOperatorForTarget( + std::string_view target, + const Node* node) { + // target could come as either "torch.ops.aten.add.default" or + // "aten.add.default" + std::vector atoms = c10::split(target, '.'); + + size_t numAtoms = atoms.size(); + if (numAtoms < 3) { + TORCH_CHECK(false, "Invalid target: ", target); + } + + const std::string_view ns = atoms[numAtoms - 3]; + const std::string_view opName = atoms[numAtoms - 2]; + const std::string_view overloadName = atoms[numAtoms - 1]; + + const auto operatorName = fmt::format("{}::{}", ns, opName); + std::string normalizedOverloadName; + if (overloadName == "default") { + normalizedOverloadName = ""; + } else { + normalizedOverloadName = overloadName; + } + + auto handle = c10::Dispatcher::singleton().findSchemaOrThrow( + operatorName.c_str(), normalizedOverloadName.c_str()); + + return handle; +} + +std::string readableArgs( + const c10::FunctionSchema& schema, + const std::vector& stack) { + const auto& schemaArgs = schema.arguments(); + std::stringstream ss; + for (const auto& [i, arg] : c10::enumerate(stack)) { + ss << "arg" << i << ' ' << schemaArgs[i].name() << ": " << arg.tagKind() + << ' '; + if (arg.isTensor()) { + auto t = arg.toTensor(); + ss << t.dtype() << t.sizes() << t.device(); + } else if (arg.isTensorList()) { + auto tl = arg.toTensorVector(); + ss << '['; + for (const auto& t : tl) { + ss << t.dtype() << t.sizes() << t.device() << ", "; + } + ss << ']'; + } else if (arg.isNone()) { + // pass + } else { + ss << arg; + } + ss << "\n"; + } + return ss.str(); +} + +const bool OpKernel::blockingEnabled_ = + c10::utils::get_env("CUDA_LAUNCH_BLOCKING").value_or("0") == "1"; + +void OpKernel::compute(ExecutionFrame& executionFrame) const { + VLOG(2) << "Executing: " << *node_; + + computeInternal(executionFrame); + + VLOG(2) << "Completed: " << *node_; +} + +Arguments prefillStackWithStaticArgs( + const Node* node, + const c10::FunctionSchema& schema) { + std::vector stackWithStaticArgs; + std::vector dynamicArgs; + const auto& schemaArgs = schema.arguments(); + stackWithStaticArgs.resize(schemaArgs.size()); + dynamicArgs.resize(schemaArgs.size()); + + // initialized stackWithStaticArgs_ with static inputs + for (const auto& [idx, schemaArg] : c10::enumerate(schemaArgs)) { + const auto& argName = schemaArg.name(); + + // Check if this is a dynamic input to the op. + const auto input = node->tryGetInput(argName); + if (input != nullptr) { + stackWithStaticArgs.at(idx) = c10::IValue(); + dynamicArgs.at(idx) = input->value; + continue; + } + + // Check if this is a statically known input to the op. + const auto attribute = node->tryGetAttribute(argName); + if (attribute != nullptr) { + stackWithStaticArgs.at(idx) = constantToIValue(attribute->value); + continue; + } + + // Otherwise, it must have a default value + auto defaultValueOpt = schemaArg.default_value(); + if (defaultValueOpt.has_value()) { + stackWithStaticArgs.at(idx) = defaultValueOpt.value(); + continue; + } + + TORCH_CHECK( + false, + "Cannot initialize argument ", + argName, + " for node ", + *node, + " with schema ", + schema); + } + return Arguments{std::move(stackWithStaticArgs), std::move(dynamicArgs)}; +} + +void fillDynamicInputs( + const ExecutionFrame& executionFrame, + const Arguments& arguments, + std::vector& stack) { + // fill the stack with dynamic values from execution frame, + // including tensor, tensors, symint, symints + + for (auto [idx, value] : arguments.getDynamicArgs()) { + TORCH_CHECK( + idx < stack.size(), + "Invalid index", + idx, + " for stack size ", + stack.size()); + TORCH_CHECK(stack.at(idx).isNone(), "Encountered None at index ", idx); + if (value->type() == Type::Kind::TensorList) { + // TODO: This is for passing List as an input to op that takes a + // List>. + // Need to cast it to a vector and back to a list, otherwise will get + // list covariance problems where List is not a subtype + // of List> when trying to execute aten.index.Tensor. + // Our lists should be covariant because they are static, + // but IValues don't know that :( + stack[idx] = executionFrame.getIValue(value->id()).toTensorList().vec(); + } else if (value->type() == Type::Kind::None) { + stack[idx] = c10::IValue(); + } else { + stack[idx] = executionFrame.getIValue(value->id()); + } + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/OpKernel.h b/torch/nativert/executor/OpKernel.h new file mode 100644 index 00000000000000..de7c90abffb93b --- /dev/null +++ b/torch/nativert/executor/OpKernel.h @@ -0,0 +1,159 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nativert { + +c10::OperatorHandle getOperatorForTarget( + std::string_view target, + const Node* node = nullptr); +/** + * @brief Manages static and dynamic arguments for kernel execution. + * + * The `Arguments` class encapsulates both static and dynamic arguments + * used during the execution of operators in a graph. + * Static arguments are the inputs that were specialized to a fixed value + * during graph capture phase. For example, scalar inputs and device are + * considered static arguments. + * Dynamic arguments are the inputs that were not baked in the graph + * during graph capture, i.e. all the tensor inputs to operators + */ +class Arguments { + public: + Arguments( + std::vector stackWithStaticArgs, + std::vector dynamicArgs) + : stackWithStaticArgs_(std::move(stackWithStaticArgs)), + dynamicArgs_(std::move(dynamicArgs)) { + for (size_t i = 0; i < dynamicArgs_.size(); i++) { + if (dynamicArgs_[i]) { + indices_.push_back(i); + } + } + } + + // Returns a view of pairs consist of the argument index and + // the corresponding Value pointer from the graph. + auto getDynamicArgs() const { + std::vector> ret; + ret.reserve(indices_.size()); + for (auto i : indices_) { + ret.emplace_back(i, dynamicArgs_[i]); + } + return ret; + } + + // Argument i means the i-th input to the operator in the argument list. + // Will return nullptr if the argument is not dynamic. + Value* findDynamic(size_t i) const { + DCHECK(i < dynamicArgs_.size()) << "Invalid input index: " << i; + return dynamicArgs_[i]; + } + + // Argument i means the i-th input to the operator in the argument list. + // Will return None as IValue if the argument is not static. + const c10::IValue& getStatic(size_t i) const { + DCHECK(i < stackWithStaticArgs_.size()) << "Invalid input index: " << i; + return stackWithStaticArgs_[i]; + } + + const std::vector& getStackWithStaticArgs() const { + return stackWithStaticArgs_; + } + + private: + // stack pre-populated with attributes, aka static arguments + const std::vector stackWithStaticArgs_; + + // Argument can only be asTensor, asTensors, asSymInt, asSymInts + const std::vector dynamicArgs_; + std::vector indices_; +}; + +void fillDynamicInputs( + const ExecutionFrame& executionFrame, + const Arguments& arguments, + std::vector& stack); + +Arguments prefillStackWithStaticArgs( + const Node* node, + const c10::FunctionSchema& schema); + +std::string readableArgs( + const c10::FunctionSchema& schema, + const std::vector& stack); + +/** + * @brief Abstract interface representing a kernel, which is responsible for + * executing a single Node in the graph. + * + * The OpKernel class is responsible for executing a single Node in the graph. + * It provides an interface for accessing node inputs and outputs, determining + * the execution kind, and executing the node's computation. + */ +class OpKernel { + public: + explicit OpKernel( + const Node* node, + std::optional device = std::nullopt, + OpKernelKind kind = OpKernelKind::kInterpreterFallbackKernel) + : node_(node), device_(device), kind_(kind) { + VLOG(1) << "Initializing kernel for node: " << *node_; + } + + const Node* node() const { + return node_; + } + void compute(ExecutionFrame& executionFrame) const; + + OpKernelKind kind() const { + return kind_; + } + + bool hasPrimKernel() const { + return kind() == OpKernelKind::kPrimKernel; + } + + bool hasStaticDispatch() const { + return kind() == OpKernelKind::kStaticDispatchKernel || + kind() == OpKernelKind::kNativeStaticDispatchKernel; + } + + size_t numInputs() const { + return node_->inputs().size(); + } + + size_t numOutputs() const { + return node_->outputs().size(); + } + + // Input is readonly + [[nodiscard]] virtual const c10::IValue& input( + uint32_t i, + ExecutionFrame& executionFrame) const { + TORCH_CHECK(i < numInputs(), "Invalid input index: ", i); + return executionFrame.getIValue(node_->inputs()[i].value->id()); + } + + // Output is read/write + c10::IValue& output(uint32_t i, ExecutionFrame& executionFrame) const { + TORCH_CHECK(i < numOutputs(), "Invalid output index: ", i); + return executionFrame.getIValue(node_->outputs()[i]->id(), true); + } + + virtual ~OpKernel() = default; + + protected: + virtual void computeInternal(ExecutionFrame& executionFrame) const = 0; + + const Node* node_; + std::optional device_; + const static bool blockingEnabled_; + // this should be set in the ctor! + const OpKernelKind kind_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/OpKernelKind.h b/torch/nativert/executor/OpKernelKind.h new file mode 100644 index 00000000000000..045664cfdee19f --- /dev/null +++ b/torch/nativert/executor/OpKernelKind.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace torch::nativert { + +enum class OpKernelKind : uint8_t { + kPrimKernel, + kStaticDispatchKernel, + kInterpreterFallbackKernel, + // static dispatch kernels that don't reuse + // out TensorImpl + kNativeStaticDispatchKernel, +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ParallelGraphExecutor.cpp b/torch/nativert/executor/ParallelGraphExecutor.cpp new file mode 100644 index 00000000000000..c147d23873d3d5 --- /dev/null +++ b/torch/nativert/executor/ParallelGraphExecutor.cpp @@ -0,0 +1,240 @@ +#include +#include +#include + +namespace { + +#define WITH_LOCK(m, block) \ + { \ + std::unique_lock lk_(m); \ + block \ + } + +} // namespace + +namespace torch::nativert { + +ThreadPoolExecutor::ThreadPoolExecutor() + : work_(std::make_unique>()) {} + +ThreadPoolExecutor::~ThreadPoolExecutor() { + stop(); +} + +C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() { + thread_local moodycamel::ProducerToken ptok(*work_); + return ptok; +} + +C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() { + thread_local moodycamel::ConsumerToken ctok(*work_); + return ctok; +} + +void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) { + session->addWork(); + unit->run(this, session); +} + +void ThreadPoolExecutor::start(int32_t numThreads) { + stopped_ = false; + for (int32_t i = 0; i < numThreads; ++i) { + threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this)); + } +} + +void ThreadPoolExecutor::loop() { + while (true) { + Work unit; + + sem_->acquire(); + + if (stopped_) { + return; + } + + while (!work_->try_dequeue(ctok(), unit)) { + }; + + unit(); + } +} + +void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) { + session->addWork(); + work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session)); + sem_->release(); +} + +void ThreadPoolExecutor::add( + SessionState* session, + std::vector::const_iterator&& begin, + const std::vector::const_iterator&& end) { + const auto count = end - begin; + + switch (count) { + case 0: { + return; + } + case 1: { + return add(session, *begin); + } + } + + session->addWork(count); + + std::vector runnables; + runnables.reserve(count); + for (; begin != end; ++begin) { + runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session)); + } + + work_->enqueue_bulk(ptok(), runnables.begin(), count); + sem_->release(count); +} + +void ThreadPoolExecutor::stop() { + stopped_ = true; + sem_->release(threads_.size()); + + std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); }); + threads_.clear(); + + { + // reset sem + auto tmp = std::make_unique(); + sem_.swap(tmp); + } + + { + // flush queue + auto tmp = moodycamel::ConcurrentQueue(); + work_->swap(tmp); + } +} + +void ThreadPoolExecutor::run( + SessionState& session, + const std::vector& roots) { + // case where thread ptok exists but work_ was swapped + if (auto& tok = ptok(); C10_UNLIKELY(!tok.valid())) { + moodycamel::ProducerToken tmp(*work_); + tok.swap(tmp); + } + + const auto rootCount = roots.size(); + + if (C10_UNLIKELY(rootCount == 0)) { + return; + } else if (C10_LIKELY(rootCount > 1)) { + add(&session, roots.begin() + 1, roots.end()); + } + + execute_inline(&session, roots[0]); + + session.wait(); +} + +void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) { + thread_local std::vector newWorkUnits; + thread_local c10::InferenceMode mode; + + WorkUnit* unit = this; + + while (true) { + unit->kernel->compute(session->frame()); + + for (auto* user : unit->users) { + if (session->decrementProducers(user->node)) { + newWorkUnits.push_back(user); + } + } + + switch (newWorkUnits.size()) { + case 0: { + return session->removeWork(); + } + case 1: { + break; + } + case 2: { + executor->add(session, newWorkUnits[1]); + break; + } + default: { + executor->add(session, newWorkUnits.begin() + 1, newWorkUnits.end()); + break; + } + } + + unit = newWorkUnits[0]; + newWorkUnits.clear(); + } +} + +ParallelGraphExecutor::ParallelGraphExecutor( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig) + : GraphExecutorBase(graph, std::move(nodeKernels), executorConfig), + workUnits_( + graph.nodes().size() - 2 /* no need for prim.Input or Prim.Output */), + graph_(graph) { + auto& nodes = graph_.nodes(); + + auto input = &*nodes.begin(); + auto output = &*nodes.rbegin(); + + { + // get rid of prim.Input and prim.Output kernels + // since we won't be needing them + nodeKernels_.erase(nodeKernels_.begin()); + nodeKernels_.pop_back(); + } + + size_t idx = 0; + for (const auto& node : nodes) { + if (&node == input || &node == output) { + continue; + } + auto& workUnit = + nodeToWorkUnit_.insert_or_assign(&node, &workUnits_[idx]).first->second; + workUnit->node = &node; + workUnit->kernel = nodeKernels_[idx++].get(); + producers_.insert({&node, 0}); + } + + for (auto& unit : workUnits_) { + for (const auto* dep : unit.node->users()) { + if (dep != output) { + unit.users.push_back(nodeToWorkUnit_[dep]); + producers_[dep] += 1; + } + } + } + + for (auto& [node, p] : producers_) { + if (p == 0) { + inputWorkUnits_.push_back(nodeToWorkUnit_[node]); + } + } + + executor_.start(executorConfig.maxParallelOps); +} + +std::vector ParallelGraphExecutor::execute( + ExecutionFrame& executionFrame, + std::vector inputs) { + fillUserInputs(executionFrame, std::move(inputs)); + return executeWithPrefilledFrame(executionFrame); +} + +std::vector ParallelGraphExecutor::executeWithPrefilledFrame( + ExecutionFrame& executionFrame) { + auto session = SessionState(executionFrame, producers_); + executor_.run(session, inputWorkUnits_); + + return executionFrame.tryMoveUserOutputs(); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/ParallelGraphExecutor.h b/torch/nativert/executor/ParallelGraphExecutor.h new file mode 100644 index 00000000000000..747e6993770ac2 --- /dev/null +++ b/torch/nativert/executor/ParallelGraphExecutor.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include +#include + +namespace moodycamel { +struct ProducerToken; +struct ConsumerToken; +struct ConcurrentQueueDefaultTraits; +template +class ConcurrentQueue; +} // namespace moodycamel + +namespace torch::nativert { +class ThreadPoolExecutor; + +typedef std::function Work; + +struct WorkUnit { + const Node* node; + OpKernel* kernel; + std::vector users; + void run(ThreadPoolExecutor* executor, SessionState* sessionState); +}; + +class ThreadPoolExecutor { + public: + explicit ThreadPoolExecutor(); + ~ThreadPoolExecutor(); + ThreadPoolExecutor(const ThreadPoolExecutor&) = delete; + ThreadPoolExecutor& operator=(ThreadPoolExecutor const&) = delete; + ThreadPoolExecutor(ThreadPoolExecutor&&) = delete; + ThreadPoolExecutor& operator=(ThreadPoolExecutor&&) = delete; + + void run(SessionState& session, const std::vector& roots); + + void start(int32_t numThreads); + void stop(); + + // execute unit on the current thread + // NOTE: children can still be offloaded to other threads + C10_ALWAYS_INLINE void execute_inline(SessionState* session, WorkUnit* unit); + + void add(SessionState* session, WorkUnit* unit); + void add( + SessionState* session, + std::vector::const_iterator&& begin, + const std::vector::const_iterator&& end); + + C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok(); + C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok(); + + private: + void loop(); + + std::atomic_bool stopped_{false}; + + std::unique_ptr sem_{std::make_unique()}; + + std::unique_ptr> + work_; + std::vector threads_; +}; + +class ParallelGraphExecutor : public GraphExecutorBase { + public: + ParallelGraphExecutor( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig); + + std::vector execute( + ExecutionFrame& frame, + std::vector inputs) override; + + std::vector executeWithPrefilledFrame( + ExecutionFrame& frame) override; + + private: + ThreadPoolExecutor executor_; + + std::vector inputWorkUnits_; + c10::FastMap nodeToWorkUnit_; + std::vector workUnits_; + + const Graph& graph_; + c10::FastMap> producers_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp new file mode 100644 index 00000000000000..017f4f178c8b5a --- /dev/null +++ b/torch/nativert/executor/SerialGraphExecutor.cpp @@ -0,0 +1,34 @@ +#include +#include +#include + +namespace torch::nativert { + +std::vector SerialGraphExecutor::execute( + ExecutionFrame& executionFrame, + std::vector inputs) { + fillUserInputs(executionFrame, std::move(inputs)); + + return executeWithPrefilledFrame(executionFrame); +} + +std::vector SerialGraphExecutor::executeWithPrefilledFrame( + ExecutionFrame& executionFrame) { + executionFrame.withMemoryPlanner([&]() { + // Execute kernels for all nodes except prim.Input and prim.Output + for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { + nodeKernels_[nodeIdx]->compute(executionFrame); + + // don't free intermediate values when static memory planning is enabled + if (executorConfig_.tryFreeUnmanagedValuesAfterUse) { + // Free the intermediate values that are no used anymore + for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) { + executionFrame.releaseValueIfNeeded(valueKey); + } + } + } + }); + return executionFrame.tryMoveUserOutputs(); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/SerialGraphExecutor.h b/torch/nativert/executor/SerialGraphExecutor.h new file mode 100644 index 00000000000000..cae3313e61e850 --- /dev/null +++ b/torch/nativert/executor/SerialGraphExecutor.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace torch::nativert { + +class SerialGraphExecutor : public GraphExecutorBase { + public: + SerialGraphExecutor( + const Graph& graph, + std::vector> nodeKernels, + const ExecutorConfig& executorConfig) + : GraphExecutorBase(graph, std::move(nodeKernels), executorConfig) {} + + std::vector execute( + ExecutionFrame& frame, + std::vector inputs) override; + + std::vector executeWithPrefilledFrame( + ExecutionFrame& frame) override; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/SessionState.h b/torch/nativert/executor/SessionState.h new file mode 100644 index 00000000000000..37cdf32b3fd3ef --- /dev/null +++ b/torch/nativert/executor/SessionState.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#include +#include + +namespace torch::nativert { + +template > +struct copyable_atomic : public __atomic_base { + public: + copyable_atomic() = default; + ~copyable_atomic() = default; + copyable_atomic(const T& t) noexcept(__atomic_base::is_always_lock_free) + : __atomic_base(t) {} + copyable_atomic(const copyable_atomic& other) noexcept( + __atomic_base::is_always_lock_free) + : __atomic_base(other.load()) {} + copyable_atomic& operator=(const copyable_atomic& other) noexcept( + __atomic_base::is_always_lock_free) { + this->store(other.load()); + return *this; + } + copyable_atomic(copyable_atomic&& other) = delete; + copyable_atomic& operator=(copyable_atomic&& other) = delete; +}; + +class SessionState { + public: + explicit SessionState( + ExecutionFrame& frame, + c10::FastMap> producers = + {}) + : producers_(std::move(producers)), frame_(frame) {} + + C10_ALWAYS_INLINE void wait() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { + return workOutstanding_.load(std::memory_order_seq_cst) == 0; + }); + } + + C10_ALWAYS_INLINE void addWork(uint32_t ct = 1) { + workOutstanding_.fetch_add(ct, std::memory_order_seq_cst); + } + + C10_ALWAYS_INLINE void removeWork() { + if (workOutstanding_.fetch_sub(1, std::memory_order_seq_cst) == 1) { + std::unique_lock lock(mutex_); + cv_.notify_one(); + } + } + + C10_ALWAYS_INLINE ExecutionFrame& frame() { + return frame_; + } + + C10_ALWAYS_INLINE /* producersRemaining == 0 */ bool decrementProducers( + const Node* node) { + return producers_.at(node).fetch_sub(1, std::memory_order_seq_cst) == 1; + } + + C10_ALWAYS_INLINE void setProducers(const Node* node, uint32_t v = 1) { + producers_[node] += v; + } + + private: + std::atomic_uint_fast32_t workOutstanding_; + c10::FastMap> producers_; + + std::condition_variable cv_; + std::mutex mutex_; + + ExecutionFrame& frame_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp new file mode 100644 index 00000000000000..1c14b79e6d9432 --- /dev/null +++ b/torch/nativert/executor/Weights.cpp @@ -0,0 +1,438 @@ + +#include +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +#include + +namespace torch::nativert { + +WeightVersion Weights::globalVersion_ = 0; + +Weights::Weights( + const Graph* graph, + const std::optional>& + stateDict, + Placement placement) + : graph_(graph), + weightsMeta_(graph->weightsMeta()), + placement_(std::move(placement)), + version_(globalVersion_++) { + if (stateDict.has_value()) { + loadStateDict(stateDict.value()); + } +} + +Weights::Weights( + const Graph* graph, + std::shared_ptr pytorchStreamReader, + const std::unordered_map& stateDictPaths, + std::string_view stateDictPathPrefix, + const std::unordered_map& constantPaths, + std::string_view constantPathPrefix, + Placement placement, + std::function skipSizeCheck, + std::function skipDtypeCheck) + : graph_(graph), + weightsMeta_(graph->weightsMeta()), + placement_(std::move(placement)), + version_(globalVersion_++), + skipSizeCheck_(std::move(skipSizeCheck)), + skipDtypeCheck_(std::move(skipDtypeCheck)) { + auto loadAndInsert = + [&](const std::string& tensorName, + std::string_view pathPrefix, + const std::unordered_map& tensorPaths, + bool isUsed) { + auto pathIt = tensorPaths.find(tensorName); + TORCH_CHECK( + pathIt != tensorPaths.end(), + "Couldn't find ", + tensorName, + " in tensorPaths"); + + const std::string tensorPath = std::string{pathPrefix} + pathIt->second; + VLOG(1) << "Loading weight from: " << tensorPath; + TORCH_CHECK( + pytorchStreamReader->hasRecord(tensorPath), + tensorPath, + " not found"); + + auto [tensorData, tensorDataSize] = + pytorchStreamReader->getRecord(tensorPath); + + // TODO: We now have two copies of metadata for weights, one in + // model definition /models/.json, another in + // /extra/xl_weights/_model_param_config.json + // Currently, we only use the metadata from model definition. + std::optional tensorMeta; + if (weightsMeta_.find(tensorName) != weightsMeta_.end()) { + tensorMeta = weightsMeta_.at(tensorName); + } else { + TORCH_CHECK(false, "Tensor meta not found for: ", tensorName); + } + + if (tensorDataSize == 0 && tensorMeta->numel() > 0) { + VLOG(1) << "Tensor " << tensorName + << " does not have data and create on Meta device"; + allValues_[tensorName] = at::empty_strided( + tensorMeta->sizes(), + tensorMeta->strides(), + tensorMeta->asTensorOptions().device(at::kMeta)); + return; + } + + if (!isUsed) { + VLOG(1) << "Tensor " << tensorName << " is not used during inference"; + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); + allValues_[tensorName] = + at::scalar_tensor(0, at::TensorOptions().device(targetDevice)); + return; + } + + size_t bytesPerEntry = + c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize(); + auto device = tensorData.device(); + auto storage = c10::Storage( + c10::Storage::use_byte_size_t(), + at::detail::computeStorageNbytes( + tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry), + std::move(tensorData), // ownership is transferred + nullptr, + false); + const auto tensorOptions = at::TensorOptions(device) + .dtype(tensorMeta->dtype()) + .requires_grad(false); + auto tensor = + at::empty({0}, tensorOptions) + .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides()); + + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); + VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice; + if (!isSameDevice(targetDevice, tensor.device())) { + tensor = tensor.to(targetDevice); + } + + allValues_[tensorName] = tensor; + }; + + auto loadAndInsertParamsBuffers = [&](const auto& tensorName, bool isUsed) { + return loadAndInsert( + std::string(tensorName), stateDictPathPrefix, stateDictPaths, isUsed); + }; + + size_t weightIndex = 0; + bool isUsed = true; + const auto& weightValues = graph->weightValues(); + + for (const auto& tensorName : graph->signature().parameters()) { + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(tensorName)); + } + loadAndInsertParamsBuffers(tensorName, isUsed); + weightIndex++; + } + for (const auto& tensorName : graph->signature().buffers()) { + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(tensorName)); + } + loadAndInsertParamsBuffers(tensorName, isUsed); + weightIndex++; + } + + // Load tensor constants and custom object constants, they are both stored + // in the same directory in the archive, i.e. "extra/constants/" tensor + // constants are prefixed with "tensor_" custom objects are prefixed with + // "custom_obj_" + auto loadConstants = [&](const auto& constants) { + for (const auto& constantName : constants) { + auto pathIt = constantPaths.find(std::string(constantName)); + TORCH_CHECK( + pathIt != constantPaths.end(), + "Couldn't find ", + constantName, + " in constantPaths"); + auto& fileName = pathIt->second; + + if (c10::starts_with( + fileName, + torch::_export::archive_spec::TENSOR_CONSTANT_FILENAME_PREFIX)) { + // tensor constants + isUsed = !weightValues[weightIndex]->users().empty(); + if (!isUsed) { + unusedWeights_.insert(std::string(constantName)); + } + loadAndInsert( + std::string(constantName), + constantPathPrefix, + constantPaths, + isUsed); + weightIndex++; + } else { + TORCH_CHECK(false, "Unknown constant path: ", fileName); + } + } + }; + loadConstants(graph->signature().nonPersistentBuffers()); + loadConstants(graph->signature().tensorConstants()); + + // custom object constants + for (const auto& customObjName : graph->signature().customObjs()) { + auto pathIt = constantPaths.find(std::string(customObjName)); + TORCH_CHECK( + pathIt != constantPaths.end(), + "Couldn't find ", + customObjName, + " in constantPaths"); + auto& fileName = pathIt->second; + + if (!c10::starts_with( + fileName, + torch::_export::archive_spec::CUSTOM_OBJ_FILENAME_PREFIX)) { + TORCH_CHECK(false, "Unknown constant path: ", fileName); + } + std::string customObjPath = std::string{constantPathPrefix} + fileName; + LOG(INFO) << "Loading custom object from: " << customObjPath; + + TORCH_CHECK( + pytorchStreamReader->hasRecord(customObjPath), + customObjPath, + " not found"); + + const auto& [customObjData, customObjDataSize] = + pytorchStreamReader->getRecord(customObjPath); + + const char* customObjDataPtr = + reinterpret_cast(customObjData.get()); + std::string customObjBytes( + customObjDataPtr, customObjDataPtr + customObjDataSize); + + c10::IValue customObj = torch::jit::pickle_load_obj(customObjBytes); + TORCH_CHECK( + customObj.isCustomClass(), "Custom object is not a custom class"); + TORCH_CHECK(!customObj.isNone(), "Custom object is None"); + customObjs_[std::string(customObjName)] = std::move(customObj); + customObjsPaths_[customObjPath] = std::string(customObjName); + } +} + +std::unordered_map Weights::parameters() const { + std::unordered_map result; + for (const auto& name : graph_->signature().parameters()) { + result.emplace(name, allValues_.at(std::string(name))); + } + return result; +} + +std::unordered_map Weights::buffers() const { + std::unordered_map result; + for (const auto& name : graph_->signature().buffers()) { + result.emplace(name, allValues_.at(std::string(name))); + } + return result; +} + +std::unordered_map Weights::attributes() const { + return allValues_; +} + +at::Tensor Weights::at(const std::string& name) const { + auto it = allValues_.find(name); + if (it != allValues_.end()) { + return it->second; + } + + TORCH_CHECK(false, name, " not found in Weights ", toString()); +} + +at::Tensor& Weights::at(const std::string& name) { + auto it = allValues_.find(name); + if (it != allValues_.end()) { + return it->second; + } + + TORCH_CHECK(false, name, " not found in Weights ", toString()); +} + +bool Weights::contains(const std::string& name) const { + return allValues_.find(name) != allValues_.end(); +} + +c10::IValue Weights::getCustomObj(const std::string& name) const { + auto it = customObjs_.find(name); + if (it != customObjs_.end()) { + return it->second; + } + + TORCH_CHECK(false, "Custom objects ", name, " not found in Weights"); +} + +c10::IValue Weights::getCustomObjByFileName(const std::string& name) const { + auto it = customObjsPaths_.find(name); + TORCH_CHECK( + it != customObjsPaths_.end(), + "Custom objects with file name ", + name, + " not found in Weights"); + const std::string obj_name = it->second; + return getCustomObj(obj_name); +} + +void Weights::loadStateDict( + const std::unordered_map& stateDict) { + auto validateAndInsert = [&](const std::string& name) { + auto stateDictIt = stateDict.find(name); + TORCH_CHECK( + stateDictIt != stateDict.end(), + "Couldn't find ", + name, + " in stateDict"); + + // Verify that the tensor matches the tensorMeta + auto it = weightsMeta_.find(name); + TORCH_CHECK( + it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta"); + + auto targetDevice = placement_.getMappedDevice(it->second.device()); + auto tensor = stateDictIt->second.toTensor().to(targetDevice); + + TORCH_CHECK(tensor.sizes() == it->second.sizes()); + TORCH_CHECK(tensor.dtype() == it->second.dtype()); + + allValues_.emplace(name, tensor); + }; + + for (const auto& name : graph_->signature().parameters()) { + validateAndInsert(std::string(name)); + } + for (const auto& name : graph_->signature().buffers()) { + validateAndInsert(std::string(name)); + } + // TensorConstants_ not filled !! +} + +void Weights::validateValue(const std::string& name, const at::Tensor& newValue) + const { + auto& weightMeta = weightsMeta_.at(name); + + TORCH_CHECK( + weightMeta.sizes() == newValue.sizes() || + (skipSizeCheck_ && skipSizeCheck_(name)) || + unusedWeights_.find(name) != unusedWeights_.end(), + "Mismatched sizes for ", + name, + ": ", + weightMeta.sizes(), + " vs ", + newValue.sizes()); + TORCH_CHECK( + weightMeta.dtype() == newValue.dtype() || + (skipDtypeCheck_ && skipDtypeCheck_(name)) || + unusedWeights_.find(name) != unusedWeights_.end(), + "Mismatched dtype for ", + name, + ": ", + weightMeta.dtype(), + " vs ", + newValue.dtype()); + + auto targetDevice = placement_.getMappedDevice(weightMeta.device()); + if (targetDevice.is_cpu() && targetDevice.has_index()) { + LOG(WARNING) << "Target device is cpu but has index: " << targetDevice; + } + TORCH_CHECK( + isSameDevice(targetDevice, newValue.device()), + "Mismatched device for ", + name, + ": ", + targetDevice, + " vs ", + newValue.device()); +} + +void Weights::setValue(const std::string& name, const at::Tensor& newValue) { + if (allValues_.find(name) != allValues_.end()) { + validateValue(name, newValue); + } else { + LOG(WARNING) << name << " is not found in the registered weights"; + } + + allValues_[name] = newValue; +} + +void Weights::updateValue(const std::string& name, const at::Tensor& newValue) { + auto it = allValues_.find(name); + TORCH_CHECK( + it != allValues_.end(), name, " not found in Weights ", toString()); + validateValue(name, newValue); + + it->second.copy_(newValue); +} + +void Weights::updateValues( + const std::unordered_map& newValues) { + for (auto& [name, newValue] : newValues) { + updateValue(name, newValue); + } +} + +std::string Weights::toString() const { + std::stringstream ss; + ss << '['; + for (const auto& [name, _] : allValues_) { + ss << name << ", "; + } + ss << ']'; + ss << '['; + for (const auto& [name, _] : customObjs_) { + ss << name << ", "; + } + ss << ']'; + return ss.str(); +} + +void Weights::validateAllWeightsLoaded() { + auto checkNames = [&](const auto& names) { + for (const auto& name : names) { + if (unusedWeights_.find(std::string(name)) != unusedWeights_.end()) { + continue; + } + auto it = allValues_.find(std::string(name)); + TORCH_CHECK(it != allValues_.end(), "Missing weight: ", name); + TORCH_CHECK(it->second.defined(), "Weight not defined: ", name); + if (it->second.device().is_meta()) { + LOG(WARNING) << "Weight is on meta device: " << name; + } + } + }; + checkNames(graph_->signature().parameters()); + checkNames(graph_->signature().buffers()); + checkNames(graph_->signature().nonPersistentBuffers()); + checkNames(graph_->signature().tensorConstants()); +} + +void Weights::updateFoldedConst(std::string_view name, c10::IValue tensor) { + foldedConstsMap_[std::string{name}] = std::move(tensor); +} + +const std::unordered_map& Weights::getFoldedConsts() + const { + return foldedConstsMap_; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h new file mode 100644 index 00000000000000..5a6778b524ff19 --- /dev/null +++ b/torch/nativert/executor/Weights.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch::nativert { + +using WeightVersion = int; +/** + * @brief A class that manages the weights of a graph, providing functionality + * to load, access, and manipulate them. + * + * It is responsible for handling the parameters, buffers, and constants + * associated with a graph It provides mechanisms to load weights from + * serialized data, access and modify them, and performs necessary validation + * checks. + */ +class Weights { + public: + explicit Weights( + const Graph* graph, + const std::optional>& + stateDict = std::nullopt, + Placement placement = Placement()); + + // Arguments + // - pytorchStreamReader: the reader for the model archive + // - stateDictPath: a map from parameter/buffer/constant name to file path in + // the archive + // - stateDictPathPrefix: a prefix that will be prepended to paths in + // stateDictPathPrefix + // - constantPaths: a map from constant name to file path in the archive + // - constantPathPrefix: a prefix that will be prepended to paths in + // constantPathPrefix + // - placement: the device placement of the weights, default to follow the + // original device in the weight's metadata + explicit Weights( + const Graph* graph, + std::shared_ptr + pytorchStreamReader, + const std::unordered_map& stateDictPaths, + std::string_view stateDictPathPrefix, + const std::unordered_map& constantPaths, + std::string_view constantPathPrefix, + Placement placement = Placement(), + std::function skipSizeCheck = {}, + std::function skipDtypeCheck = {}); + + at::Tensor at(const std::string& name) const; + at::Tensor& at(const std::string& name); + bool contains(const std::string& name) const; + c10::IValue getCustomObj(const std::string& name) const; + c10::IValue getCustomObjByFileName(const std::string& name) const; + + std::unordered_map parameters() const; + + std::unordered_map buffers() const; + + std::unordered_map attributes() const; + + void loadStateDict( + const std::unordered_map& stateDict); + + /* + * Replace the value stored at the weight with name "name". + */ + void setValue(const std::string& name, const at::Tensor& newValue); + + /* + * Update the value stored at the weight with name "name". + * This is done in-place. + */ + void updateValue(const std::string& name, const at::Tensor& newValue); + + void updateValues( + const std::unordered_map& newValues); + + void validateValue(const std::string& name, const at::Tensor& newValue) const; + + void validateAllWeightsLoaded(); + + void updateFoldedConst(std::string_view name, c10::IValue tensor); + + const std::unordered_map& getFoldedConsts() const; + + C10_ALWAYS_INLINE const c10::FastMap& + getConstFoldedValues() const { + return constFoldedValues_; + } + + C10_ALWAYS_INLINE void setConstFoldedValue( + const std::string& n, + c10::IValue iv) { + constFoldedValues_.insert_or_assign(n, std::move(iv)); + } + + std::string toString() const; + + WeightVersion version() const { + return version_; + } + + private: + const Graph* graph_; + const std::unordered_map& weightsMeta_; + Placement placement_; + + // keys are parameter/buffer/constant names, not graph input names! + std::unordered_map allValues_; + + std::unordered_map customObjs_; + + // contains CustomClassHolder map from a file name to an arbitrary + // key in customObjs_ that hold the loaded content of the file. + // This is used in AOTIDelegateExecutor. + std::unordered_map customObjsPaths_; + + // The liftcycle of folded consts should be tied with the weights from which + // it was derived. The ordering of the constant should be consistent with + // the output order of const graph. + std::vector foldedConsts_; + std::unordered_map foldedConstsMap_; + + c10::FastMap constFoldedValues_; + + // unique version number for this instance of weight + const WeightVersion version_; + + // every instance of Weight has a unique version number + static WeightVersion globalVersion_; + + std::function skipSizeCheck_ = {}; + std::function skipDtypeCheck_ = {}; + + // save the names of unused weights + std::unordered_set unusedWeights_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp new file mode 100644 index 00000000000000..e56eb408531694 --- /dev/null +++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp @@ -0,0 +1,174 @@ +#include + +#include + +namespace torch::nativert { + +AliasAnalyzer::AliasAnalyzer( + const Graph& graph, + const c10::FastMap& schemas) { + for (const auto&& [i, node] : c10::enumerate(graph.nodes())) { + for (const auto& input : node.inputs()) { + create_or_update_lifetime(input.value, i); + } + + for (const auto& output : node.outputs()) { + create_or_update_lifetime(output, i); + } + + if (update_aliases_if_packed_listunpack(node, i) /* applied? */) { + continue; + } + + maybe_update_aliases_from_schema(node, schemas); + } + + // set all non-aliasing outputs. outputs + // that are aliased will be set later when + // lifetimes are extended + for (const auto* output : graph.outputs()) { + if (!is_alias(output)) { + values_associated_with_outputs_.insert(output); + } + } + + maybe_extend_lifetimes(graph); + log_state(); +} + +bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( + const Node& node, + size_t i) { + if (node.target() != "prim.ListUnpack") { + return false; + } + + const auto* list = node.inputs()[0].value; + + // we can't infer about how this list was made in this case + // so fallback to default always-aliasing behaviour + if (const auto* p = list->producer(); p && p->target() != "prim.ListPack") { + return false; + } + + const auto& list_elems = list->getListElements(); + TORCH_CHECK_EQ(list_elems.size(), node.numOutputs()); + + for (const auto j : c10::irange(node.numOutputs())) { + const Value* input = list_elems.at(j); + const Value* output = node.outputs().at(j); + + TORCH_CHECK_NE(input, output); + + create_or_update_lifetime(input, i); + create_or_update_lifetime(output, i); + + aliases_[output].insert(input); + } + + return true; +} + +void AliasAnalyzer::maybe_update_aliases_from_schema( + const Node& node, + const c10::FastMap& schemas) { + std::function is_alias = + []([[maybe_unused]] size_t input_idx, + [[maybe_unused]] size_t output_idx) { return true; }; + + const FunctionSchema* schema = nullptr; + if (auto schemaIt = schemas.find(std::string(node.target())); + schemaIt != schemas.end()) { + schema = &schemaIt->second; + } + + if (!schema) { + VLOG(1) << "schema not found for " << node.target() + << " assuming worst case aliasing"; + } + + for (size_t j = 0; j < node.numInputs(); j += 1) { + for (size_t k = 0; k < node.numOutputs(); k += 1) { + const Value* input = node.inputs().at(j).value; + const Value* output = node.outputs().at(k); + + if (!schema || schema->alias(j, k)) { + VLOG(1) << node.target() + << " may contain input/output alias: " << input->id() << " -> " + << output->id(); + aliases_[output].insert(input); + } + } + } +} + +void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) { + if (auto [lifetimeIt, inserted] = lifetimes_.try_emplace(value, i, i); + !inserted) { + lifetimeIt->second.end = i; + } +} + +void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { + c10::FastSet extended; + + for (auto nodeIt = graph.nodes().rbegin(); nodeIt != graph.nodes().rend(); + ++nodeIt) { + const auto& inputs = nodeIt->inputs(); + for (const auto& input : inputs) { + if (auto aliasIt = aliases_.find(input.value); + aliasIt != aliases_.end()) { + const auto& alias = aliasIt->second; + for (const auto& src : alias) { + if (extended.find(src) != extended.end()) { + continue; + } + + auto& eol = lifetimes_[src].end; + eol = lifetimes_[input.value].end; + + VLOG(1) << "extended EOL of value " << src->id() << " to " << eol; + + extended.insert(src); + + if (eol == graph.nodes().size() - 1 /* aliases output */) { + values_associated_with_outputs_.insert(src); + } + } + } + } + } +} + +void AliasAnalyzer::log_state() const { + if (!VLOG_IS_ON( + 1) /* this is usually too large to be logged with VLOG directly */) { + return; + } + + std::cout << [&]() -> std::string { + std::ostringstream ss; + ss << "[nativert layout planner] AliasAnalyzer ran....\n"; + ss << "lifetimes:\n"; + + for (const auto& [v, lifetime] : lifetimes_) { + ss << " " << v->name() << ": [" << lifetime.start << ", " << lifetime.end + << "]\n"; + } + + ss << "\naliases:\n"; + for (const auto& [v, alias] : aliases_) { + ss << " " << v->name() << " -> "; + for (const auto* a : alias) { + ss << a->name() << ", "; + } + ss << '\n'; + } + + ss << '\n'; + + return ss.str(); + }() << std::flush; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/AliasAnalyzer.h b/torch/nativert/executor/memory/AliasAnalyzer.h new file mode 100644 index 00000000000000..c9784d5d84ab9f --- /dev/null +++ b/torch/nativert/executor/memory/AliasAnalyzer.h @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch::nativert { + +class AliasAnalyzer { + public: + explicit AliasAnalyzer( + const Graph& graph, + const c10::FastMap& schemas); + + C10_ALWAYS_INLINE const AllocationLifetime& lifetime( + const Value* value) const { + return lifetimes_.at(value); + } + + C10_ALWAYS_INLINE bool is_alias(const Value* value) const { + return aliases_.find(value) != aliases_.end(); + } + + C10_ALWAYS_INLINE bool is_storage_associated_with_output( + const Value* value) const { + return values_associated_with_outputs_.find(value) != + values_associated_with_outputs_.end(); + } + + C10_ALWAYS_INLINE const c10::FastSet& + values_associated_with_output_storage() const { + return values_associated_with_outputs_; + } + + private: + // listunpack operations who take a list that has + // been created with a listpack operation should + // be transparent with respect to aliasing + // + // e.g., given the op + // %t[] = prim.ListPack(l0=%t0, l1=%t1) + // %x1, %x2 = prim.ListUnpack(self=%t) + // x1 should directly alias t0 + // and likewise x2 should directly alias t1 + // + // this will make sure that the lifetimes of x1 and x2 + // are not just the max of the lifetimes of t0 and t1 + // which can make tensor-packing more efficient if list + // element EOL's differ by large amounts + bool /* applied */ update_aliases_if_packed_listunpack( + const Node& node, + size_t i); + + // use the schema aliasing spec, or if none is provided, + // assume all outputs alias all inputs + void maybe_update_aliases_from_schema( + const Node& node, + const c10::FastMap& schemas); + + void create_or_update_lifetime(const Value* value, size_t i); + + // work our way from the DAG's output node to the input node + // propagating the maximum EOL of all aliases back to their + // source value(s). + // + // in addition, if a graph output is an alias, we need to ensure + // that the source values are treated as graph outputs + // so that we don't free them before the graph output is copied + // back to the user (and we ignore them when creating a memory plan + // even if they aren't explicitly considered outputs) + void maybe_extend_lifetimes(const Graph& graph); + + void log_state() const; + + // mapping from alias to the set of values that it aliases + c10::FastMap> aliases_; + c10::FastMap lifetimes_; + // non-aliasing outputs or non-aliasing intermediates that are aliased by + // outputs + c10::FastSet values_associated_with_outputs_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/Bump.cpp b/torch/nativert/executor/memory/Bump.cpp new file mode 100644 index 00000000000000..ac396e06a539b4 --- /dev/null +++ b/torch/nativert/executor/memory/Bump.cpp @@ -0,0 +1,24 @@ +#include + +namespace torch::nativert { + +LayoutPlan BumpAllocationPlanner( + const std::vector& allocation_specs) { + LayoutPlan plan; + + auto& allocations = plan.allocations; + auto& total_size = plan.total_size; + + allocations.reserve(allocation_specs.size()); + for (const auto& spec : allocation_specs) { + allocations.push_back(Allocation{ + spec.size, + total_size, + }); + total_size += spec.size; + } + + return plan; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/Bump.h b/torch/nativert/executor/memory/Bump.h new file mode 100644 index 00000000000000..d424e2bb6924ca --- /dev/null +++ b/torch/nativert/executor/memory/Bump.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace torch::nativert { + +// lay out all tensors contiguously in memory +// this doesn't take into account lifetimes, +// it literally just puts them all next to each other +LayoutPlan BumpAllocationPlanner( + const std::vector& allocation_specs); + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/DisjointStorageGroups.cpp b/torch/nativert/executor/memory/DisjointStorageGroups.cpp new file mode 100644 index 00000000000000..751e43fd5277c9 --- /dev/null +++ b/torch/nativert/executor/memory/DisjointStorageGroups.cpp @@ -0,0 +1,175 @@ +#include + +#include + +#include +#include +#include + +namespace { + +using namespace torch::nativert; + +// A StorageGroup represents a collection of allocations that share backing +// storage +class StorageGroup { + public: + // every storage group must contain at least one allocation spec. + explicit StorageGroup(const AllocationSpec* spec) + : max_spec_size_(spec->size), + lifetime_(spec->lifetime), + spec_group_({spec}) {} + + void add_spec(const AllocationSpec* spec) { + spec_group_.push_back(spec); + max_spec_size_ = std::max(max_spec_size_, spec->size); + TORCH_DCHECK_LT(lifetime_.end, spec->lifetime.end); + lifetime_.end = spec->lifetime.end; + is_free_ = false; + } + + const std::vector& spec_group() const { + return spec_group_; + } + + size_t max_spec_size() const { + return max_spec_size_; + } + + size_t num_specs() const { + return spec_group_.size(); + } + + const AllocationLifetime& lifetime() const { + return lifetime_; + } + + bool is_free() const { + return is_free_; + } + + void set_free(bool is_free) { + is_free_ = is_free; + } + + private: + // whether or not this storage group is free + // to add new specs + bool is_free_{false}; + // represents the amount of memory that will be + // allocated for all specs in this group... + size_t max_spec_size_; + // the lifetime of this storage group + AllocationLifetime lifetime_; + // all the specs in this group + std::vector spec_group_; +}; + +} // namespace + +namespace torch::nativert { + +LayoutPlan DisjointStorageGroupsPlanner( + const std::vector& allocation_specs) { + struct CompareAllocationSpecsBySize { + bool operator()(const AllocationSpec* a, const AllocationSpec* b) + const /* noexcept */ + { + return a->size > b->size; + } + }; + + std::vector< + std::multiset> + allocation_indices; + std::vector> deallocation_indices; + + for (const auto& spec : allocation_specs) { + size_t alloc_index = spec.lifetime.start; + size_t dealloc_index = spec.lifetime.end; + + TORCH_DCHECK_LT(alloc_index, dealloc_index); + + if (alloc_index >= allocation_indices.size()) { + allocation_indices.resize(alloc_index + 1); + } + + if (dealloc_index >= deallocation_indices.size()) { + deallocation_indices.resize(dealloc_index + 1); + } + + allocation_indices[alloc_index].insert(&spec); + deallocation_indices[dealloc_index].emplace_back(&spec); + } + + // don't want to invalidate pointers + // so let's make this a list + std::list storage_groups; + // maps each AllocationSpec to its assigned storage group. + c10::FastMap spec_to_storage_group; + // stores the set of storage groups that + // are available for reuse. + std::vector free_storage_groups; + + auto createStorageGroup = [&](const AllocationSpec* spec) { + auto& group = storage_groups.emplace_back(spec); + spec_to_storage_group.emplace(spec, &group); + }; + + auto assignToAvailableStorageGroup = [&](const AllocationSpec* spec) { + DCHECK(!free_storage_groups.empty()); + auto* storage_group = free_storage_groups.back(); + TORCH_DCHECK_NOTNULL(storage_group); + TORCH_DCHECK_EQ(storage_group->is_free(), true); + storage_group->add_spec(spec); + spec_to_storage_group.emplace(spec, storage_group); + free_storage_groups.pop_back(); + }; + + for (const auto i : c10::irange(allocation_indices.size())) { + for (auto* spec : allocation_indices[i]) { + TORCH_DCHECK_NOTNULL(spec); + if (free_storage_groups.empty()) { + createStorageGroup(spec); + } else { + assignToAvailableStorageGroup(spec); + } + } + + if (i < deallocation_indices.size()) { + for (auto* spec : deallocation_indices[i]) { + TORCH_DCHECK_NOTNULL(spec); + auto* storage_group = spec_to_storage_group.at(spec); + if (!storage_group->is_free() && + storage_group->lifetime().end == spec->lifetime.end) { + storage_group->set_free(true); + free_storage_groups.push_back(storage_group); + } + } + } + } + + LayoutPlan plan; + + c10::FastMap storage_group_to_offset; + size_t offset = 0; + for (const auto& storage_group : storage_groups) { + storage_group_to_offset.emplace(&storage_group, offset); + offset += storage_group.max_spec_size(); + } + + plan.total_size = offset; + plan.allocations.reserve(allocation_specs.size()); + + for (const auto& spec : allocation_specs) { + // specs in storage groups lifetime's shouldn't be overlapping + // so we can just set their offset to the offset of the group + plan.allocations.emplace_back(Allocation{ + spec.size, + storage_group_to_offset.at(spec_to_storage_group.at(&spec))}); + } + + return plan; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/DisjointStorageGroups.h b/torch/nativert/executor/memory/DisjointStorageGroups.h new file mode 100644 index 00000000000000..8131a7000da4d1 --- /dev/null +++ b/torch/nativert/executor/memory/DisjointStorageGroups.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace torch::nativert { + +LayoutPlan DisjointStorageGroupsPlanner( + const std::vector& allocation_specs); + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/FunctionSchema.cpp b/torch/nativert/executor/memory/FunctionSchema.cpp new file mode 100644 index 00000000000000..264ed702cbc0d0 --- /dev/null +++ b/torch/nativert/executor/memory/FunctionSchema.cpp @@ -0,0 +1,58 @@ +#include + +namespace torch::nativert { + +bool FunctionSchema::alias(size_t input_idx, size_t output_idx) const { + // probably quicker than using a map since + // overridden inputs/outputs should be small + for (const auto& [i, o] : aliasing_spec_) { + if (i == input_idx && o == output_idx) { + return true; + } + } + + VLOG(1) << "checking aliasing spec for " << c10_fn_schema_.name() << " " + << (c10_fn_schema_.is_varret() ? "varret" : "non-varret") << " " + << (c10_fn_schema_.is_vararg() ? "vararg" : "non-vararg"); + + if (!aliasing_spec_.empty()) { + VLOG(1) << "aliasing spec is not empty but no entry found for (" + << input_idx << "-->" << output_idx + << ") -- falling back to schema->may_contain_alias()"; + } + + /* + varret and vararg will contribute to the input/output idx's + but because we don't know how many inputs/outputs there are, + the schema will consider these indices to be out of bounds. + + e.g., op(a, b, c, d) where c and d are variadic will result in + may_contain_alias(x, idx_of(c)) and may_contain_alias(x, idx_of(d)) to throw + an out-of-bounds exception + + in this case, we can apply the worst-case aliasing to the varidic + inputs/outputs i.e., all outputs might alias all varargs and all inputs + might be aliased by all varrets + */ + + if (c10_fn_schema_.is_vararg() && + input_idx >= c10_fn_schema_.arguments().size()) { + VLOG(1) << "applying worst-case aliasing for " << c10_fn_schema_.name() + << "'s variadic input " << input_idx; + return true; + } + + if (c10_fn_schema_.is_varret() && + output_idx >= c10_fn_schema_.returns().size()) { + VLOG(1) << "applying worst-case aliasing for " << c10_fn_schema_.name() + << "'s variadic output " << output_idx; + return true; + } + + return c10_fn_schema_.may_contain_alias( + {c10::SchemaArgType::output, output_idx}, + {c10::SchemaArgType::input, input_idx}, + /* bidirectional = */ false); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/FunctionSchema.h b/torch/nativert/executor/memory/FunctionSchema.h new file mode 100644 index 00000000000000..713df508058472 --- /dev/null +++ b/torch/nativert/executor/memory/FunctionSchema.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +namespace torch::nativert { + +struct InputOutputIdxPair { + size_t input_idx; + size_t output_idx; +}; + +using AliasingSpec = std::vector; + +class FunctionSchema { + public: + explicit FunctionSchema( + const c10::FunctionSchema& schema, + AliasingSpec&& aliasing_spec = {}, + OpKernelKind kernel_kind = OpKernelKind::kInterpreterFallbackKernel) + : aliasing_spec_(std::move(aliasing_spec)), + kernel_kind_(kernel_kind), + c10_fn_schema_(schema) {} + + c10::FunctionSchema& base_schema() { + return c10_fn_schema_; + } + + const c10::FunctionSchema& base_schema() const { + return c10_fn_schema_; + } + + bool alias(size_t input_idx, size_t output_idx) const; + + C10_ALWAYS_INLINE OpKernelKind kernel_kind() const { + return kernel_kind_; + } + + private: + AliasingSpec aliasing_spec_; + OpKernelKind kernel_kind_; + c10::FunctionSchema c10_fn_schema_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/GreedyBySize.cpp b/torch/nativert/executor/memory/GreedyBySize.cpp new file mode 100644 index 00000000000000..afc99e0df8cd34 --- /dev/null +++ b/torch/nativert/executor/memory/GreedyBySize.cpp @@ -0,0 +1,166 @@ +#include +#include +#include + +#include +#include +#include + +#include + +namespace { + +using namespace torch::nativert; + +// we need to track the original order in which allocations were made +// since they will be re-sorted between iterations +struct GreedyAllocation : public Allocation { + explicit GreedyAllocation( + Allocation allocation, + size_t allocation_idx, + size_t input_spec_idx) + : Allocation(allocation), + allocation_index(allocation_idx), + input_spec_index(input_spec_idx) {} + // we need to maintain the allocation ordering s.t., we can look up + // previous allocations directly from descending_allocation_specs_ + // even after allocations has been re-sorted, which happens after + // each allocation is complete. + // + // i.e., this index represents the index of the spec that was used + // to create this allocation inside descending_allocation_specs_ + // AFTER the sorting was completed. + size_t allocation_index{0}; + // index of the spec associated with this allocation + // in the event that the specs get re-ordered + // in the process of creating allocations + // e.g., + // allocation_specs[sX, sY, sZ] + // ^ ^ ^ + // values[vX, vY, vZ] + // + // means that an allocation created from sY + // will have an input_spec_index of 1 + // + // this allows us to return to the original + // ordering before returning the allocations + size_t input_spec_index{0}; +}; + +struct AllocationSpecWithIndex { + const AllocationSpec* spec; + size_t index; +}; + +// associate specs with their original (unsorted) index +// and then sort them in descending order by byte size +std::vector prepare_allocation_specs( + const std::vector& allocation_specs) { + std::vector specs; + specs.reserve(allocation_specs.size()); + + for (const auto i : c10::irange(allocation_specs.size())) { + specs.push_back({&allocation_specs[i], i}); + } + + std::sort(specs.begin(), specs.end(), [](auto& lhs, auto& rhs) { + return lhs.spec->size > rhs.spec->size; + }); + + return specs; +} + +} // namespace + +namespace torch::nativert { + +// https://arxiv.org/pdf/2001.03288 +LayoutPlan GreedyBySizeAllocationPlanner( + const std::vector& allocation_specs) { + LayoutPlan plan; + + auto descending_allocation_specs = prepare_allocation_specs(allocation_specs); + + std::vector allocations; + allocations.reserve(allocation_specs.size()); + + auto get_next_offset = [&](const AllocationSpec& spec) -> size_t { + size_t prev_offset = 0; + std::optional best_offset = std::nullopt; + size_t smallest_gap = std::numeric_limits::max(); + + for (const auto& alloc : allocations) { + if (auto* allocated_spec = + descending_allocation_specs.at(alloc.allocation_index).spec; + allocated_spec->not_overlapping_with(spec)) { + continue; + } + + if (alloc.offset > prev_offset) { + if (size_t gap = alloc.offset - prev_offset; + gap >= spec.size && gap < smallest_gap) { + smallest_gap = gap; + best_offset = prev_offset; + } + } + + prev_offset = std::max(prev_offset, alloc.offset + alloc.size); + } + + return best_offset.value_or(prev_offset); + }; + + size_t total_allocation_size = 0; + for (const auto&& [allocation_index, spec_with_original_index] : + c10::enumerate(descending_allocation_specs)) { + auto& spec = spec_with_original_index.spec; + + auto new_allocation = GreedyAllocation( + Allocation{spec->size, get_next_offset(*spec)}, + allocation_index, + spec_with_original_index.index); + + total_allocation_size += new_allocation.size; + plan.total_size = + std::max(plan.total_size, new_allocation.offset + new_allocation.size); + + VLOG(1) << "allocation with interval " << spec->lifetime.start << "-->" + << spec->lifetime.end << " placed at offset " + << new_allocation.offset; + + // insert new allocation while maintaining relative-offset ordering + // the algorithm is already quadratic because of get_next_offset + // so this is negligible + + auto it = std::lower_bound( + allocations.begin(), + allocations.end(), + new_allocation, + [](auto& lhs, auto& rhs) { return lhs.offset < rhs.offset; }); + allocations.insert(it, new_allocation); + } + + // sort allocations so their ordering is consistent with the input specs + std::sort(allocations.begin(), allocations.end(), [](auto& lhs, auto& rhs) { + return lhs.input_spec_index < rhs.input_spec_index; + }); + + plan.allocations.reserve(allocations.size()); + std::move( + allocations.begin(), + allocations.end(), + std::back_inserter(plan.allocations)); + + if (plan.total_size > 0) { + VLOG(1) << std::fixed << std::setprecision(2) + << "greedy-by-size bytes saved over strictly increasing: " + << (1.0 - ((float)plan.total_size / (float)total_allocation_size)) * + 100 + << "% (" << total_allocation_size << " - " << plan.total_size + << " = " << (total_allocation_size - plan.total_size) << " bytes)"; + } + + return plan; +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/GreedyBySize.h b/torch/nativert/executor/memory/GreedyBySize.h new file mode 100644 index 00000000000000..0d5a61132cf941 --- /dev/null +++ b/torch/nativert/executor/memory/GreedyBySize.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace torch::nativert { + +LayoutPlan GreedyBySizeAllocationPlanner( + const std::vector& allocation_specs); + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp new file mode 100644 index 00000000000000..7b5062d7993ffd --- /dev/null +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -0,0 +1,195 @@ +#include + +#include + +#include +#include + +namespace torch::nativert { + +LayoutManager::LayoutManager( + LayoutPlanner& planner, + ExecutionFrame& parent_frame, + const torch::nativert::LayoutManagerSettings settings) + : planner_(planner), parent_frame_(parent_frame), settings_(settings) { + VLOG(1) << "layout manager created for execution frame"; +} + +void ContiguousLayoutBuffer::allocate(size_t size) { + VLOG(1) << "allocating " << size << " bytes"; + if (C10_LIKELY(size_ > 0)) { + if (C10_LIKELY( + size <= size_) /* NOTE: size will be monotonically increasing */) { + return clear(size_); + } else { + deallocate(); + } + } + data_ptr_ = c10::GetCPUCachingAllocator()->allocate(size); + size_ = size; +} + +void LayoutManager::allocate() { + if (C10_UNLIKELY(state_ == LayoutManagerState::WaitingForValues)) { + return; + } + + bool should_allocate_storages = + state_ == LayoutManagerState::AllocatingStorages; + + ensure_managed_storages(/* allocate= */ should_allocate_storages); + + planner_.with_plan([&](const auto& plan) { allocate_plan(plan); }); + + if (should_allocate_storages) { + state_ = LayoutManagerState::Running; + } +} + +void LayoutManager::allocate_plan(const LayoutPlan& plan) { + if (C10_UNLIKELY(storage_impl_buffer_.size() == 0 || plan.total_size == 0)) { + return; + } + + layout_buffer_.allocate(plan.total_size); + VLOG(1) << "allocated " << layout_buffer_.size() + << " bytes for planned layout"; + + auto* storage_buf = storage_impl_buffer_.buffer(); + + for (const auto i : c10::irange(plan.allocations.size())) { + auto& planned_allocation = plan.allocations[i]; + auto& local_max_nbytes = planned_tensors_max_nbytes_local_[i]; + local_max_nbytes = std::max(local_max_nbytes, planned_allocation.size); + + void* offset_ptr = + layout_buffer_.get_ptr_with_offset(planned_allocation.offset); + // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object) + auto& storage = storage_buf[i]; + + // if the existing data ptr doesn't have an associated deleter then we + // will set the offset and size directly, as opposed to creating and + // swapping it with a new one + // + // apart from the first allocation when the storage still has the its + // allocator-created dataptr (https://fburl.com/code/u7dsspjm) whose + // deleter is non-null (https://fburl.com/code/7hiwo5zo), this should + // always be true + if (C10_LIKELY( + storage._mutable_data_ptr_no_checks().unsafe_reset_data_and_ctx( + offset_ptr))) { + storage.unsafe_set_nbytes(planned_allocation.size); + } else { + storage.set_data_ptr_noswap(at::DataPtr( + offset_ptr, offset_ptr, nullptr, c10::Device(c10::DeviceType::CPU))); + storage.set_nbytes(planned_allocation.size); + } + } +} + +void LayoutManager::ensure_managed_storages(bool allocate) { + if (C10_UNLIKELY(planned_tensors_.empty())) { + return; + } + + if (C10_UNLIKELY(allocate)) { + storage_impl_buffer_.allocate(planned_tensors_.size()); + VLOG(1) << "allocated " << planned_tensors_.size() * sizeof(at::StorageImpl) + << " bytes for contiguous storages"; + } + + auto* storage_buf = storage_impl_buffer_.buffer(); + + for (size_t i = 0; i < planned_tensors_.size(); i += 1) { + auto* tensor = planned_tensors_[i]; + + at::StorageImpl& storage = *tensor->storage().unsafeGetStorageImpl(); + + if (C10_UNLIKELY(allocate)) { + // from: https://fburl.com/code/4it00yph + // + // We want to manage StorageImpls' lifetimes ourselves, but TensorImpl + // expects to refcount them. unsafe_adapt_non_heap_allocated is our + // escape hatch: it sets the reference count for the StorageImpl to an + // impractically high value so that it will never get deallocated by + // intrusive_ptr, leaving us free to manage its lifetime as we see fit. + // (Note that allowing it to be deallocated by intrusive_ptr would be + // UB, because that would entail deleting an object that wasn't + // allocated with operator new.) + // + // For more information, see the doc comment for + // intrusive_ptr::unsafe_adapt_non_heap_allocated. + tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( + &storage_impl_buffer_.to_managed(storage), 1))); + } else if ( + C10_UNLIKELY( + &storage != + // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object) + &storage_buf + [i]) /* managed storage was replaced for some reason */) { + storage.reset(); + tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage( + c10::intrusive_ptr::unsafe_adapt_non_heap_allocated( + // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object) + &storage_buf[i], + 1))); + } + } +} + +void LayoutManager::populate_tensor_values() { + CHECK(planned_tensors_.empty()); + CHECK(unplanned_ivalues_.empty()); + + const auto& value_ids = planner_.get_planned_values(); + planned_tensors_.resize(value_ids.size()); + planned_tensors_max_nbytes_local_.resize(value_ids.size()); + + for (const auto&& [i, v] : c10::enumerate(value_ids)) { + planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor(); + } + + const auto& unplanned_value_ids = planner_.get_unplanned_values(); + unplanned_ivalues_.resize(unplanned_value_ids.size()); + for (const auto&& [i, v] : c10::enumerate(unplanned_value_ids)) { + unplanned_ivalues_[i] = &parent_frame_.getIValue(v); + } +} + +void LayoutManager::try_update_historical_max_nbytes() { + for (const auto i : c10::irange(planned_tensors_.size())) { + auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes()); + if (auto& old_max = planned_tensors_max_nbytes_local_[i]; + nbytes > old_max) { + old_max = nbytes; + planner_.try_update_max_size_at_index(i, nbytes); + } + } +} + +void LayoutManager::deallocate_and_plan() { + const auto uninitialized = state_ == LayoutManagerState::WaitingForValues; + + if (C10_UNLIKELY(uninitialized)) { + populate_tensor_values(); + } + + try_update_historical_max_nbytes(); + + if (C10_UNLIKELY(uninitialized)) { + planner_.start_worker_if_not_started(); + } + + if (C10_UNLIKELY(uninitialized)) { + state_ = LayoutManagerState::AllocatingStorages; + } else if (settings_.deallocateBetweenRequests()) { + layout_buffer_.deallocate(); + } + + for (auto* ivalue : unplanned_ivalues_) { + *ivalue = c10::IValue(); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h new file mode 100644 index 00000000000000..76f658e09d08b5 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +class ExecutionFrame; + +struct ContiguousLayoutBuffer { + public: + ContiguousLayoutBuffer() = default; + ~ContiguousLayoutBuffer() { + deallocate(); + } + + ContiguousLayoutBuffer(ContiguousLayoutBuffer&& other) = delete; + ContiguousLayoutBuffer(const ContiguousLayoutBuffer& other) = delete; + ContiguousLayoutBuffer operator=(ContiguousLayoutBuffer&& other) = delete; + ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) = + delete; + + void* get_ptr_with_offset(size_t offset) { + void* raw_ptr = data_ptr_.get(); + TORCH_CHECK_NOTNULL(raw_ptr); + TORCH_CHECK_LE(offset, size_); + return reinterpret_cast( + reinterpret_cast(raw_ptr) + offset); + } + + size_t size() { + return size_; + } + + void allocate(size_t size); + + void deallocate() { + VLOG(1) << "deallocating layout buffer of size " << size_; + size_ = 0; + data_ptr_ = {}; + } + + void clear(size_t size) { + VLOG(1) << "clearing first " << size << "bytes of layout buffer of size " + << size_; + TORCH_CHECK_LE(size, size_); + std::memset(data_ptr_.get(), 0, size); + } + + private: + // the size of the buffer in bytes + size_t size_{0}; + + // the dataptr returned by the allocator + at::DataPtr data_ptr_{}; +}; + +struct ContiguousStorageImplBuffer { + ContiguousStorageImplBuffer() = default; + ~ContiguousStorageImplBuffer() { + deallocate(); + } + + ContiguousStorageImplBuffer(ContiguousStorageImplBuffer&& other) = delete; + ContiguousStorageImplBuffer(const ContiguousStorageImplBuffer& other) = + delete; + ContiguousStorageImplBuffer operator=(ContiguousStorageImplBuffer&& other) = + delete; + ContiguousStorageImplBuffer& operator=( + const ContiguousStorageImplBuffer& other) = delete; + + void deallocate() { + if (buffer_ == nullptr) { + return; + } + + for (const size_t idx : c10::irange(size_)) { + buffer_[idx].~StorageImpl(); + } + + delete[] reinterpret_cast(buffer_); + buffer_ = nullptr; + size_ = capacity_ = 0; + } + + void allocate(size_t capacity) { + if (size_ > 0) { + deallocate(); + } + + capacity_ = capacity; + + static_assert(alignof(at::StorageImpl) <= 8); + buffer_ = reinterpret_cast( + new unsigned char[capacity * sizeof(at::StorageImpl)]); + } + + size_t capacity() { + return capacity_; + } + + size_t size() { + return size_; + } + + c10::StorageImpl* buffer() const { + return buffer_; + } + + c10::StorageImpl& at(size_t i) { + TORCH_CHECK_LT(i, size_) + << "requested storage index " << i << " out of bounds " << size_; + return buffer_[i]; + } + + void reset_all() { + for (const size_t idx : c10::irange(size_)) { + buffer_[idx].reset(); + } + } + + c10::StorageImpl& to_managed(at::StorageImpl& s) { + TORCH_CHECK_LT(size_, capacity_); + return *(new (&buffer_[size_++]) at::StorageImpl( + at::StorageImpl::use_byte_size_t(), + static_cast(s.nbytes()), + s.allocator(), + s.resizable())); + } + + private: + size_t size_{0}; + size_t capacity_{0}; + c10::StorageImpl* buffer_{nullptr}; +}; + +enum class LayoutManagerState { WaitingForValues, AllocatingStorages, Running }; + +class LayoutManager { + public: + LayoutManager( + LayoutPlanner& planner, + ExecutionFrame& parent_frame, + torch::nativert::LayoutManagerSettings settings = {}); + ~LayoutManager() = default; + + void allocate(); + void deallocate_and_plan(); + + private: +#ifdef LayoutPlannerTests_TEST_FRIENDS + LayoutPlannerTests_TEST_FRIENDS; +#endif + + static size_t get_aligned_nbytes(size_t nbytes) { +#if defined(__linux__) && !defined(__ANDROID__) + auto alignment = c10::c10_compute_alignment(nbytes); +#else + auto alignment = c10::gAlignment; +#endif + return ((nbytes) + alignment - 1) & (~(alignment - 1)); + } + + void allocate_plan(const LayoutPlan& plan); + void ensure_managed_storages(bool allocate); + + void populate_tensor_values(); + void try_update_historical_max_nbytes(); + + LayoutPlanner& planner_; + ExecutionFrame& parent_frame_; + + std::vector unplanned_ivalues_; + + std::vector planned_tensors_; + std::vector planned_tensors_max_nbytes_local_; + + ContiguousLayoutBuffer layout_buffer_; + ContiguousStorageImplBuffer storage_impl_buffer_; + + LayoutManagerState state_{LayoutManagerState::WaitingForValues}; + torch::nativert::LayoutManagerSettings settings_; +}; + +class LayoutManagerGuard { + public: + explicit LayoutManagerGuard(LayoutManager& manager) : manager_(manager) { + manager_.allocate(); + } + ~LayoutManagerGuard() { + manager_.deallocate_and_plan(); + } + + LayoutManagerGuard(LayoutManagerGuard&& other) = delete; + LayoutManagerGuard(const LayoutManagerGuard& other) = delete; + LayoutManagerGuard operator=(LayoutManagerGuard&& other) = delete; + LayoutManagerGuard& operator=(const LayoutManagerGuard& other) = delete; + + LayoutManager& manager_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp new file mode 100644 index 00000000000000..5c45a08ea6f149 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlanner.cpp @@ -0,0 +1,217 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::nativert { + +LayoutPlanner::LayoutPlanner( + const Graph& graph, + const c10::FastMap& kernelSchemas, + const std::vector& persistentValues, + const torch::nativert::LayoutPlannerSettings& settings) + : managed_values_(graph.values().size()), settings_(settings) { + auto value_to_allocation_spec = c10::FastMap{}; + auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas); + + std::set input_values_set_; + for (const auto* nv : graph.userInputs()) { + if (nv->type() == Type::Kind::Tensor) { + input_values_set_.insert(nv); + } + } + + const auto& tensor_meta = graph.tensorValuesMeta(); + + for (auto&& [i, node] : at::enumerate(graph.nodes())) { + // only manage out variant values + if (const auto schemaIt = kernelSchemas.find(std::string(node.target())); + schemaIt == kernelSchemas.end() || + schemaIt->second.kernel_kind() != OpKernelKind::kStaticDispatchKernel) { + VLOG(1) << "not able to plan outputs for node " << node.target() + << " as it is derived from an unsupported kernel kind."; + continue; + } + + for (const auto& output : node.outputs()) { + // don't manage persistent values + if (bool is_persistent = persistentValues[output->id()]; is_persistent) { + VLOG(1) + << "not planning " << output->name() + << " as it is a persistent value (likely a weight or const-folded)"; + continue; + } + + // only manage tensors + if (bool is_tensor = output->type().kind() == Type::Kind::Tensor; + !is_tensor) { + VLOG(1) << "not planning " << output->name() + << " as it is not a raw tensor. type: " << output->type(); + continue; + } + + // output storage ownership must be given to the caller. + if (const auto& values_associated_with_output = + alias_analyzer.values_associated_with_output_storage(); + values_associated_with_output.find(output) != + values_associated_with_output.end()) { + VLOG(1) + << "not planning " << output->name() + << " as its underlying storage may be associated with a graph output"; + continue; + } + + // inputs are borrowed -- this is merely a sanity check + if (input_values_set_.find(output) != input_values_set_.end()) { + VLOG(1) << "not planning " << output->name() + << " as it is a graph input that is borrowed from the user"; + continue; + } + + // don't plan aliases -- they don't own the associated dataptr + if (bool is_alias = alias_analyzer.is_alias(output); is_alias) { + VLOG(1) << "not planning " << output->name() << " as it is an alias"; + continue; + } + + if (bool is_not_consumed = output->users().empty(); is_not_consumed) { + VLOG(1) << "not planning " << output->name() << " as it has no users"; + continue; + } + + if (auto meta_it = tensor_meta.find(std::string(output->name())); + meta_it != tensor_meta.end()) { + if (const auto& meta = meta_it->second; meta.device() == c10::kCPU) { + auto& spec = value_to_allocation_spec[output]; + spec.lifetime = alias_analyzer.lifetime(output); + managed_values_[output->id()] = true; + continue; + } else { + VLOG(1) << "tensor " << output->name() + << " not placed on cpu so we cannot plan it"; + } + } else /* possible if runtime pass didn't populate meta info */ { + VLOG(1) << "tensor " << output->name() << " has no meta information"; + } + + managed_values_[output->id()] = true; + value_to_allocation_spec[output].lifetime = + alias_analyzer.lifetime(output); + } + } + + LOG(INFO) << "layout planner created with " << value_to_allocation_spec.size() + << " values"; + + switch (settings_.algorithmType()) { + case torch::nativert::LayoutPlannerAlgorithmType::Bump: { + algorithm_ = &BumpAllocationPlanner; + break; + } + case torch::nativert::LayoutPlannerAlgorithmType::GreedyBySize: { + algorithm_ = &GreedyBySizeAllocationPlanner; + break; + } + case LayoutPlannerAlgorithmType::DisjointStorageGroups: { + algorithm_ = &DisjointStorageGroupsPlanner; + break; + } + } + + TORCH_CHECK_NOTNULL(algorithm_); + + initialize_vectors(value_to_allocation_spec); + + auto exec_planner = ExecutionPlanner{graph}; + auto p = exec_planner.createPlan(); + for (const auto& freeable : p->valuesToFree) { + for (const auto v : freeable) { + if (!is_managed(v)) { + unplanned_values_.push_back(v); + } + } + } +} + +void LayoutPlanner::initialize_vectors( + c10::FastMap value_to_allocation_spec) { + size_t num_managed = value_to_allocation_spec.size(); + + planned_values_.resize(num_managed); + planned_allocation_specs_.resize(num_managed); + planned_values_historical_max_nbytes_ = + std::vector(num_managed); + + size_t i = 0; + for (auto& [v, spec] : value_to_allocation_spec) { + TORCH_CHECK_LE(spec.lifetime.start, spec.lifetime.end); + + planned_values_[i] = v->id(); + planned_values_historical_max_nbytes_[i] = spec.size; + planned_allocation_specs_[i] = spec; + + i++; + } + + // for sanity in case anyone tries to use this after this method + // is called with a bunch of junk (i.e., moved specs) in it + value_to_allocation_spec.clear(); +} + +const std::vector& LayoutPlanner::get_planned_values() const { + return planned_values_; +} + +const std::vector& LayoutPlanner::get_unplanned_values() const { + return unplanned_values_; +} + +void LayoutPlanner::start_worker_if_not_started() { + static c10::once_flag flag; + c10::call_once(flag, [&]() { + // make sure plan is populated by the time this + // returns for the first time :P + create_plan(); + worker_ = + std::thread([this]() { run_periodic([this] { create_plan(); }); }); + }); +} + +LayoutPlanner::~LayoutPlanner() { + { + std::unique_lock l(mutex_); + stopped_ = true; + } + cv_.notify_one(); + if (worker_.joinable()) { + worker_.join(); + } +} + +void LayoutPlanner::run_periodic(const std::function& f) { + std::unique_lock l(mutex_); + while (!cv_.wait_for( + l, settings_.planningInterval(), [&]() { return stopped_; })) { + f(); + } +} + +void LayoutPlanner::create_plan() { + // update spec sizes to use historical maximums set + // by execution frames before creating the new plan + for (const auto i : c10::irange(planned_allocation_specs_.size())) { + auto& spec = planned_allocation_specs_[i]; + spec.size = planned_values_historical_max_nbytes_[i].load( + std::memory_order_relaxed); + } + plan_.write([p_new = (*algorithm_)(planned_allocation_specs_)]( + LayoutPlan& plan) { plan = p_new; }); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h new file mode 100644 index 00000000000000..6382fdbba01b58 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace { +constexpr inline std::memory_order drop_release(std::memory_order m) noexcept { + return ( + m == std::memory_order_release + ? std::memory_order_relaxed + : ((m == std::memory_order_acq_rel || m == std::memory_order_seq_cst) + ? std::memory_order_acquire + : m)); +} +// derivation of +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p0493r5.pdf +template +void atomic_set_max( + std::atomic* pv, + typename std::atomic::value_type v, + std::memory_order m = std::memory_order_seq_cst) noexcept { + auto const mr = drop_release(m); + auto t = (mr != m) ? pv->fetch_add(0, m) : pv->load(mr); + while (std::max(v, t) != t) { + if (pv->compare_exchange_weak(t, v, m, mr)) { + return; + } + } +} +} // namespace + +namespace torch::nativert { + +class LayoutPlanner { + public: + explicit LayoutPlanner( + const Graph& graph, + const c10::FastMap& + kernelSchemas, + const std::vector& persistentValues, + const torch::nativert::LayoutPlannerSettings& settings); + ~LayoutPlanner(); + + LayoutPlanner(LayoutPlanner&& other) = delete; + LayoutPlanner(const LayoutPlanner& other) = delete; + LayoutPlanner operator=(LayoutPlanner&& other) = delete; + LayoutPlanner& operator=(const LayoutPlanner& other) = delete; + + void start_worker_if_not_started(); + + const std::vector& get_planned_values() const; + const std::vector& get_unplanned_values() const; + + C10_ALWAYS_INLINE bool is_managed(ValueId id) { + TORCH_CHECK_LT(static_cast(id), managed_values_.size()); + return managed_values_[id]; + } + + C10_ALWAYS_INLINE void try_update_max_size_at_index(size_t idx, size_t size) { + atomic_set_max(&planned_values_historical_max_nbytes_[idx], size); + } + + C10_ALWAYS_INLINE + void with_plan(std::function&& cb) { + plan_.read( + std::forward>(std::move(cb))); + } + + private: +#ifdef LayoutPlannerTests_TEST_FRIENDS + LayoutPlannerTests_TEST_FRIENDS; +#endif + + // we need some way of mapping graph values to other information + // (e.g., allocation spec, max historical size) + // + // since there is a 1:1 mapping to/from each of these + // we can create+initialize them here + // + // note: planning algorithms are allowed to change the ordering + // of allocation specs -- so we pass the index of the spec during + // it's insertion s.t., each execution frame can use it to + // reference the correct associated max historical size / underlying + // tensor value + void initialize_vectors( + c10::FastMap value_to_allocation_spec); + + void run_periodic(const std::function& f); + void create_plan(); + + // variables for managing the state of the + // interval worker thread that refreshes + // the plan + std::condition_variable cv_; + std::mutex mutex_; + bool stopped_{false}; + std::thread worker_; + + std::vector unplanned_values_; + + std::vector planned_values_; + std::vector planned_allocation_specs_; + std::vector planned_values_historical_max_nbytes_; + + // managed_values_[value_id] == true + // if graph.values()[value_id] has + // an associated allocation spec + std::vector managed_values_; + + LayoutPlannerAlgorithm* algorithm_; + c10::LeftRight plan_; + + torch::nativert::LayoutPlannerSettings settings_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h b/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h new file mode 100644 index 00000000000000..eda8e57c64d19a --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace torch::nativert { + +// represents the inclusive lifetime of a tensor +// i.e., the buffer used by tensor x with lifetime [m, n] +// can only be safely used during intervals 0 --> m-1 and n+1 --> ... +// +// e.g., +// +// g(x): 0 +// a = op_a(x) 1 +// b = op_b(a) 2 +// c = op_c(a) 3 +// return (b, c) 4 +// +// gives: +// +// lifetime(x) = 0 --> 1 +// lifetime(a) = 1 --> 3 +// lifetime(b) = 2 --> 4 +// lifetime(c) = 3 --> 4 +// +// assuming no aliasing... +// however, if b aliases a we'd get +// +// lifetime(x) = 0 --> 1 +// lifetime(a) = 1 --> *4* (max{l_end(a), l_end(b)}) +// lifetime(b) = 2 --> 4 +// lifetime(c) = 3 --> 4 + +struct AllocationLifetime { + AllocationLifetime() = default; + AllocationLifetime(size_t s, size_t e) : start(s), end(e) {} + + // two lifetime intervals are considered not overlapping + // if their lifetimes are exclusive. + // e.g., + // l(a) = 0 --> 3 + // overlaps with + // l(b) = 3 --> 5 + // since both tensors can exist at t = 3 + // + // however, if l(b) = 4 --> 5 + // l(a) and l(b) do not overlap. + bool not_overlapping_with(const AllocationLifetime& other) const { + return this->end < other.start || this->start > other.end; + } + + bool operator==(const AllocationLifetime& other) const { + return this->start == other.start && this->end == other.end; + } + + size_t start{0}; + size_t end{0}; +}; + +struct AllocationSpec { + AllocationLifetime lifetime{}; + size_t size{0}; + + bool not_overlapping_with(const AllocationSpec& other) const { + return this->lifetime.not_overlapping_with(other.lifetime); + } +}; + +struct Allocation { + size_t size{0}; + size_t offset{0}; +}; + +struct LayoutPlan { + size_t total_size{0}; + // in practice, each allocation has an associated + // allocation spec + // + // for example, given: + // + // allocation_specs = [s1, s2, s3] + // plan = algorithm(allocation_specs) + // + // plan.allocations will be [a1, a2, a3] + // ^ ^ ^ + // mapping back to [s1, s2, s3] + std::vector allocations; +}; + +// a layout planner algorithm is provided a vector of +// allocation specs, and returns a plan containing +// a vector of allocations (i.e., offset & size) +// whose order MUST correspond to that of the input +// +// specifically, provided: +// auto plan = algorithm(allocation_specs); +// +// allocation_specs.size() == plan.allocations.size() +// +// AND +// +// allocation_specs[0] --> plan.allocations[0] +// ... +// allocation_specs[i] --> plan.allocations[i] +// ... +// allocation_specs[allocation_specs.size() - 1] --> +// plan.allocations[plan.allocations.size() - 1] +using LayoutPlannerAlgorithm = + LayoutPlan(const std::vector& allocation_specs); + +} // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutPlannerSettings.h b/torch/nativert/executor/memory/LayoutPlannerSettings.h new file mode 100644 index 00000000000000..8ade27997bdfc8 --- /dev/null +++ b/torch/nativert/executor/memory/LayoutPlannerSettings.h @@ -0,0 +1,84 @@ +#pragma once + +#include + +namespace torch::nativert { + +enum class LayoutPlannerAlgorithmType { + Bump, + GreedyBySize, + DisjointStorageGroups, +}; + +class LayoutManagerSettings { + public: + LayoutManagerSettings() = default; + + bool deallocateBetweenRequests() const { + return deallocateBetweenRequests_; + } + + LayoutManagerSettings& setDeallocateBetweenRequests( + bool deallocateBetweenRequests) { + deallocateBetweenRequests_ = deallocateBetweenRequests; + return *this; + } + + private: + friend class LayoutManager; + bool deallocateBetweenRequests_{true}; +}; + +class LayoutPlannerSettings { + public: + LayoutPlannerSettings() = default; + + bool enabled() const { + return enabled_; + } + + LayoutPlannerAlgorithmType algorithmType() const { + return layoutPlannerAlgorithmType_; + } + + std::chrono::seconds planningInterval() const { + return planningInterval_; + } + + const LayoutManagerSettings& layoutManagerSettings() const { + return layoutManagerSettings_; + } + + LayoutPlannerSettings& setEnabled(bool enabled) { + enabled_ = enabled; + return *this; + } + + LayoutPlannerSettings& setAlgorithmType( + LayoutPlannerAlgorithmType layoutPlannerAlgorithmType) { + layoutPlannerAlgorithmType_ = layoutPlannerAlgorithmType; + return *this; + } + + LayoutPlannerSettings& setPlanningInterval( + std::chrono::seconds planningInterval) { + planningInterval_ = planningInterval; + return *this; + } + + LayoutPlannerSettings& setLayoutManagerSettings( + LayoutManagerSettings layoutManagerSettings) { + layoutManagerSettings_ = layoutManagerSettings; + return *this; + } + + private: + friend class LayoutPlanner; + bool enabled_{false}; + LayoutPlannerAlgorithmType layoutPlannerAlgorithmType_{ + LayoutPlannerAlgorithmType::Bump}; + std::chrono::seconds planningInterval_{5}; + LayoutManagerSettings layoutManagerSettings_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp new file mode 100644 index 00000000000000..3cc7f678fcff01 --- /dev/null +++ b/torch/nativert/graph/Graph.cpp @@ -0,0 +1,1564 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::nativert { + +namespace { + +// Workaround for MSVC bug: "std" ambiguous symbol. +template +constexpr bool is_same_v = std::is_same_v; + +bool isBlank(char n) { + return std::isspace(n); +} + +size_t consumeWhitespaceImpl(std::string_view source, size_t curPos) { + while (isBlank(source.at(curPos))) { + curPos++; + } + return curPos; +} + +size_t expectImpl( + std::string_view source, + std::string_view expected, + size_t curPos) { + curPos = consumeWhitespaceImpl(source, curPos); + const auto actual = source.substr(curPos, expected.size()); + TORCH_CHECK( + expected == actual, + fmt::format( + "Parser error: expected '{}' at position {}, but found '{}'.", + expected, + curPos, + actual)); + curPos += expected.size(); + return curPos; +} + +size_t expectImpl(std::string_view source, char expected, size_t curPos) { + curPos = consumeWhitespaceImpl(source, curPos); + while (isBlank(source.at(curPos))) { + curPos++; + } + TORCH_CHECK( + expected == source[curPos], + "Parser error: expected '{}' at position {}, but found '{}'.", + expected, + curPos, + source[curPos]); + curPos++; + return curPos; +} +} // namespace + +bool operator==(const Type& left, const Type& right) { + if (left.kind() != right.kind()) { + return false; + } + if (std::holds_alternative(left.kind_) && + std::holds_alternative(right.kind_)) { + return std::get(left.kind_).classFqn == + std::get(right.kind_).classFqn; + } + return true; +} + +Graph::Graph() + : insertBefore_(nodes_.end()), + inputNode_(insertNode("prim.Input", {})), + outputNode_(insertNode("prim.Output", {})) { + // Set the insertion point to append to the graph + insertBefore_ = nodes_.iterator_to(*outputNode_); +} + +std::string Graph::getUniqueValueName() { + auto name = fmt::format("v{}", uniqueValueName_); + while (values_.find(name) != values_.end()) { + name = fmt::format("v{}", uniqueValueName_++); + } + return name; +} + +// If `name` is null, create a unique value name +Value* Graph::addValue( + const std::optional& name, + const Type& type, + Node* node) { + const auto valueName = name.value_or(getUniqueValueName()); + ValueId valueId = getNextValueId(); + const auto [it, success] = values_.insert( + {valueName, std::make_unique(valueId, valueName, type, node)}); + TORCH_CHECK( + success, + fmt::format( + "Tried to create Value with name: '{}', but it already existed", + valueName)); + return it->second.get(); +} + +Value* Graph::addInput(std::string_view name, const Type& type) { + return inputNode_->addOutput(name, type); +} + +void Graph::addInput() { + inputNode_->addOutput(); +} + +Value* Graph::addOutput(Value* v) { + outputNode_->addInput({std::string(v->name()), v}); + return v; +} + +void Graph::addConstantOutput(Constant c) { + constantOutputs_.push_back(std::move(c)); +} + +// Create a node without inserting it into the execution graph. +Node* Graph::createNode( + std::string target, + std::vector inputs, + std::unordered_map metadata) { + auto& node = nodesOwner_.emplace_back(std::make_unique( + this, std::move(target), std::move(inputs), std::move(metadata))); + return node.get(); +} + +Node* Graph::insertBefore(Node* toInsert, Node* insertionPoint) { + TORCH_CHECK(insertionPoint != inputNode_, "can't insert before prim.Input"); + TORCH_CHECK( + !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); + TORCH_CHECK( + insertionPoint->is_linked(), + "expected node to be linked: ", + *insertionPoint); + auto it = nodes_.insert(nodes_.iterator_to(*insertionPoint), *toInsert); + return &*it; +} + +Node* Graph::insert(Node* toInsert) { + TORCH_CHECK( + !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); + nodes_.insert(insertBefore_, *toInsert); + return toInsert; +} + +Node* Graph::insertAfter(Node* toInsert, Node* insertionPoint) { + TORCH_CHECK(insertionPoint != outputNode_, "can't insert after prim.Output"); + TORCH_CHECK( + !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); + TORCH_CHECK( + insertionPoint->is_linked(), + "expected node to be linked: ", + *insertionPoint); + + auto insertIt = nodes_.iterator_to(*insertionPoint); + // Increment once because we want to insert after the insertion point + ++insertIt; + auto it = nodes_.insert(insertIt, *toInsert); + return &*it; +} + +Node* Graph::insertNode( + std::string target, + std::vector inputs, + std::unordered_map metadata) { + auto node = + createNode(std::move(target), std::move(inputs), std::move(metadata)); + nodes_.insert(insertBefore_, *node); + return node; +} + +std::ostream& operator<<(std::ostream& out, const Type& ty) { + std::visit( + [&out](auto&& arg) { + using T = std::decay_t; + if constexpr (is_same_v) { + switch (arg) { + case Type::Kind::None: + out << "None"; + break; + case Type::Kind::Tensor: + out << "Tensor"; + break; + case Type::Kind::TensorList: + out << "TensorList"; + break; + case Type::Kind::OptionalTensorList: + out << "OptionalTensorList"; + break; + case Type::Kind::SymInt: + out << "SymInt"; + break; + case Type::Kind::SymFloat: + out << "SymFloat"; + break; + case Type::Kind::SymIntList: + out << "SymIntList"; + break; + case Type::Kind::CustomObj: + out << "CustomObj"; + break; + default: + TORCH_CHECK(false, "Unhandled type"); + } + } else if constexpr (is_same_v) { + out << "CustomObj: " << arg.classFqn; + } + }, + ty.kind_); + return out; +} + +const NamedArgument* Node::tryGetInput(std::string_view name) const { + // Just do a scan over the inputs. We expect there to always be a very small + // number of elements, so it shouldn't be slow. This allows us to avoid a + // second datastructure for lookups. + // Drop a debug check here, just to make sure :) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs_.size() < 1000); + for (const auto& input : inputs_) { + if (input.name == name) { + return &input; + } + } + return nullptr; +} + +const NamedArgument& Node::getInput(std::string_view name) const { + const auto ret = tryGetInput(name); + if (ret == nullptr) { + TORCH_CHECK( + false, + fmt::format( + "Expected input '{}' on node: '{}' to exist, but it does not.", + name, + fmt::streamed(*this))); + } + return *ret; +} + +const Attribute* Node::tryGetAttribute(std::string_view name) const { + // Just do a scan over the inputs. We expect there to always be a very small + // number of elements, so it shouldn't be slow. This allows us to avoid a + // second datastructure for lookups. + // Drop a debug check here, just to make sure :) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(attributes_.size() < 1000); + for (const auto& attribute : attributes_) { + if (attribute.name == name) { + return &attribute; + } + } + return nullptr; +} + +const Attribute& Node::getAttribute(std::string_view name) const { + const auto ret = tryGetAttribute(name); + if (ret == nullptr) { + TORCH_CHECK( + false, + fmt::format( + "Expected attribute '{}' on node: '{}' to exist, but it does not.", + name, + fmt::streamed(*this))); + } + return *ret; +} + +void Node::applyDevicePlacement(const Placement& placement) { + for (auto& attribute : attributes_) { + if (std::holds_alternative(attribute.value)) { + auto device = std::get(attribute.value); + auto targetDevice = + placement.getMappedDevice(std::get(attribute.value)); + if (!isSameDevice(targetDevice, device)) { + LOG(INFO) << "Overriding " << device.str() << " to " + << targetDevice.str() << " for node " << *this; + attribute.value = targetDevice; + } + } + } +} + +Node* Node::next() { + return owningGraph()->nodeAfter(this); +} + +const Node* Node::next() const { + return owningGraph()->nodeAfter(this); +} + +Node* Node::prev() { + return owningGraph()->nodeBefore(this); +} + +const Node* Node::prev() const { + return owningGraph()->nodeBefore(this); +} + +bool Node::isBefore(const Node* n) const { + if (this == n) { + return false; + } + + for (const Node* cursor = this->next(); cursor != nullptr; + cursor = cursor->next()) { + if (cursor == n) { + return true; + } + } + // Reached the end without finding n + return false; +} + +std::vector Node::producers() const { + std::vector ret; + + if (this->prev() == nullptr /* prim.Input */) { + return ret; + } + + if (this->next() == nullptr /* prim.Output */) { + for (auto& node : owningGraph_->nodes()) { + if (node.next() == nullptr /* prim.Output */ || + node.prev() == nullptr /* prim.Input */) { + continue; + } + for (auto* dep : node.users()) { + if (dep == this /* prim.Output */) { + ret.push_back(&node); + } + } + } + } else { + std::unordered_set seen; + + for (const auto& input : inputs()) { + auto* n = input.value->producer(); + if (n == nullptr) { + continue; + } + if (const auto [_, inserted] = seen.insert(n); inserted) { + ret.push_back(n); + } + } + + if (ret.empty()) { + ret.push_back(owningGraph_->inputNode()); + } + } + + return ret; +} + +std::vector Node::users() const { + std::vector ret; + + if (this->next() == nullptr /* prim.Output */) { + return ret; + } + + if (this->prev() == nullptr /* prim.Input */) { + for (auto& node : owningGraph_->nodes()) { + if (node.prev() == nullptr /* prim.Input */ || + node.next() == nullptr /* prim.Output */) { + continue; + } + for (auto* dep : node.producers()) { + if (dep == this /* prim.Input */) { + ret.push_back(&node); + } + } + } + } else { + std::unordered_set seen; + + for (const auto* output : outputs()) { + for (auto* n : output->users()) { + if (const auto [_, inserted] = seen.insert(n); inserted) { + ret.push_back(n); + } + } + } + + if (ret.empty()) { + ret.push_back(owningGraph_->outputNode()); + } + } + + return ret; +} + +Node* Graph::createListPack(std::vector inputs, const Type& inputType) { + std::vector nodeInputs; + nodeInputs.reserve(inputs.size()); + for (auto [i, input] : c10::enumerate(inputs)) { + nodeInputs.push_back({fmt::format("l{}", i), input}); + } + // Create a new named value for this + auto name = getUniqueValueName(); + auto node = createNode("prim.ListPack", std::move(nodeInputs)); + + // Make sure all inputs are the same type + for (auto& input : inputs) { + TORCH_CHECK(input->type() == inputType); + } + + if (inputType == Type::Kind::Tensor) { + node->addOutput(name, Type::Kind::TensorList); + } else if (inputType == Type::Kind::SymInt) { + node->addOutput(name, Type::Kind::SymIntList); + } + + return node; +} + +Node* Graph::createOptionalListPack(std::vector inputs) { + std::vector nodeInputs; + nodeInputs.reserve(inputs.size()); + for (auto [i, input] : c10::enumerate(inputs)) { + nodeInputs.push_back({fmt::format("l{}", i), input}); + } + // Create a new named value for this + auto name = getUniqueValueName(); + auto node = createNode("prim.ListPack", std::move(nodeInputs)); + // Make sure all inputs are either None or Tensor + for (auto& input : inputs) { + TORCH_CHECK( + input->type() == Type::Kind::None || + input->type() == Type::Kind::Tensor); + } + node->addOutput(name, Type::Kind::OptionalTensorList); + + return node; +} + +Value* Graph::createConstantSymIntValue(int value) { + auto valueName = getUniqueValueName(); + ValueId valueId = getNextValueId(); + const auto [it, success] = values_.insert( + {valueName, + std::make_unique( + valueId, valueName, Type::Kind::SymInt, nullptr)}); + TORCH_CHECK( + success, + fmt::format( + "Tried to create constant SymInt Value with name: '{}', but it already existed", + valueName)); + constantSymIntValues_[valueId] = value; + return it->second.get(); +} + +Value* Graph::getValue(std::string_view name) const { + // TODO: can eliminate this string copy by enabling heterogeneous lookup for + // the container + return values_.at(std::string(name)).get(); +} + +Value* Graph::tryGetValue(std::string_view name) const { + // TODO: can eliminate this string copy by enabling heterogeneous lookup for + // the container + const auto key = std::string(name); + if (values_.find(key) != values_.end()) { + return values_.at(key).get(); + } + return nullptr; +} + +void Graph::renumberValues() { + std::vector currentValues; + currentValues.reserve(values_.size()); + for (auto& kv : values_) { + currentValues.push_back(kv.second.get()); + } + + // Sort values in creation order (by value ids) + std::sort(currentValues.begin(), currentValues.end(), [](Value* a, Value* b) { + return a->id() < b->id(); + }); + + // Build a new id map with all ids < values_.size() + std::unordered_map oldToNew; + oldToNew.reserve(currentValues.size()); + ValueId newId = 0; + for (Value* v : currentValues) { + oldToNew[v->id()] = newId; + v->setId(newId); + newId++; + } + + std::unordered_map newSymIntMap; + for (auto& [oldId, symIntVal] : constantSymIntValues_) { + auto it = oldToNew.find(oldId); + if (it != oldToNew.end()) { + ValueId updatedId = it->second; + newSymIntMap[updatedId] = symIntVal; + } + } + constantSymIntValues_ = std::move(newSymIntMap); + uniqueValueId_ = newId; +} + +bool Graph::cleanupDeadNodes() { + std::unordered_set visited; + std::vector visitStack; + + // Mark reachable nodes from output + visitStack.push_back(outputNode_); + visited.insert(outputNode_); + + while (!visitStack.empty()) { + const Node* current = visitStack.back(); + visitStack.pop_back(); + + for (auto& namedArg : current->inputs()) { + Value* val = namedArg.value; + Node* producer = val->producer(); + + if (!producer) { + continue; + } + if (!visited.count(producer)) { + visited.insert(producer); + visitStack.push_back(producer); + } + } + } + + // Remove all nodes not in visited (other than input/outputs) + std::vector toRemove; + for (auto& n : nodes()) { + if (n.target() == "prim.Input" || n.target() == "prim.Output" || + visited.count(&n)) { + continue; + } + toRemove.push_back(&n); + } + + const bool mutated = !toRemove.empty(); + + // Remove nodes in reverse order to handle input/output dependencies + for (auto it = toRemove.rbegin(); it != toRemove.rend(); ++it) { + removeNode(*it); + } + + renumberValues(); + lint(); + + return mutated; +} + +void Graph::lint() const { + // Check that every value has a producer marked. + for (const auto& [name, value] : values_) { + // Some constant symint and None don't have producer nodes + if (value->type().kind() != Type::Kind::SymInt && + value->type().kind() != Type::Kind::None) { + TORCH_CHECK(value->isFolded() || value->producer() != nullptr); + } + } + for (const auto& node : nodes()) { + TORCH_CHECK_EQ(node.owningGraph(), this); + } + // Check that every list type is either produced by a prim.ListPack or + // immediately consumed by a prim.ListUnpack. We make use of this invariant + // to retrieve list elements in `getListElements`. + for (const auto& [_, value] : values_) { + if (value->type().kind() != Type::Kind::TensorList) { + continue; + } + const bool producedByListPack = + value->producer(/* resolve_folded = */ true)->target() == + "prim.ListPack"; + const bool consumedByListUnpack = value->users().size() == 1 && + value->users()[0]->target() == "prim.ListUnpack"; + TORCH_CHECK(producedByListPack || consumedByListUnpack); + } + + auto getNames = [](const auto& values) { + c10::FastSet names; + for (const auto* value : values) { + if (value) { + names.emplace(value->name()); + } + } + return names; + }; + signature_.lint(getNames(inputs()), getNames(outputs())); +} + +void Graph::finalize() { + // build userOutputs_ view + userOutputs_.clear(); + size_t constantIndex = 0; + for (auto& outputName : signature_.userOutputs()) { + if (outputName.has_value()) { + userOutputs_.emplace_back(getValue(*outputName)); + } else { + if (constantIndex < constantOutputs_.size()) { + userOutputs_.emplace_back(std::move(constantOutputs_[constantIndex])); + constantIndex++; + } else { + TORCH_CHECK(false, "No more constant outputs available"); + } + } + } +} + +namespace { +// Scan through a node's inputs, replacing ALL instances of `old` with +// `replacement`. Returns true if a replacement occurred, otherwise false. +bool replace(Node* node, Value* old, Value* replacement) { + bool replacementOccurred = false; + for (auto& input : node->inputs()) { + if (input.value == old) { + input.value = replacement; + replacementOccurred = true; + } + } + return replacementOccurred; +} +} // namespace + +void Graph::replaceAllUses(Value* old, Value* replacement) { + for (auto user : old->users()) { + // Find this use in the input list and replace it + auto replaced = replace(user, old, replacement); + TORCH_CHECK(replaced); + replacement->addUser(user); + } + old->eraseAllUsers(); + signature_.replaceAllUses(old->name(), replacement->name()); +} + +void Graph::replaceAllUsesAfterNode( + Value* old, + Value* replacement, + Node* afterThis) { + auto it = nodes_.iterator_to(*afterThis); + // Don't search `afterThis` + ++it; + // Scan through all node inputs linearly and replace uses + for (; it != nodes_.end(); ++it) { + Node* node = &*it; + const bool replaced = replace(node, old, replacement); + if (replaced) { + old->eraseUser(node); + replacement->addUser(node); + } + } + signature_.replaceAllUses(old->name(), replacement->name()); +} + +void Graph::applyDevicePlacement(const Placement& placement) { + // TODO: consolidate device info in weight loading here as well. + for (auto& node : nodes_) { + node.applyDevicePlacement(placement); + } +} + +Node* Graph::nodeAfter(Node* n) { + TORCH_CHECK_EQ(n->owningGraph(), this); + if (n == outputNode_) { + return nullptr; + } + auto it = nodes_.iterator_to(*n); + return &*(++it); +} + +const Node* Graph::nodeAfter(const Node* n) const { + TORCH_CHECK_EQ(n->owningGraph(), this); + if (n == outputNode_) { + return nullptr; + } + auto it = nodes_.iterator_to(*n); + return &*(++it); +} + +Node* Graph::nodeBefore(Node* n) { + TORCH_CHECK_EQ(n->owningGraph(), this); + if (n == inputNode_) { + return nullptr; + } + auto it = nodes_.iterator_to(*n); + return &*(--it); +} + +const Node* Graph::nodeBefore(const Node* n) const { + TORCH_CHECK_EQ(n->owningGraph(), this); + if (n == inputNode_) { + return nullptr; + } + auto it = nodes_.iterator_to(*n); + return &*(--it); +} + +void Graph::removeNode(Node* n) { + TORCH_CHECK_EQ(n->owningGraph(), this) + << "Node does not belong to this graph!"; + + for (auto* outputVal : n->outputs()) { + TORCH_CHECK( + outputVal->users().empty(), + "Trying to erase a node that still has users: ", + outputVal->name()); + outputVal->eraseAllUsers(); + removeValue(outputVal); + } + + for (const auto& input : n->inputs()) { + input.value->eraseUser(n); + } + + TORCH_CHECK(n->is_linked(), "Node is not linked to the graph!"); + n->unlink(); + + auto it = std::find_if( + nodesOwner_.begin(), + nodesOwner_.end(), + [n](const std::unique_ptr& ptr) { return ptr.get() == n; }); + + TORCH_CHECK(it != nodesOwner_.end(), "Node not found in nodesOwner_!"); + nodesOwner_.erase(it); +} + +void Graph::removeValue(Value* value) { + // TODO: assuming not removing from constantSymIntValues_ + TORCH_CHECK(value->users().empty(), "Cannot erase a value with users."); + auto it = values_.find(std::string(value->name())); + TORCH_CHECK( + it != values_.end(), + "Attempted to erase a value not in graph ", + value->name()); + values_.erase(it); +} + +std::vector Graph::insertGraph( + const Graph& subgraph, + std::vector inputs, + std::unordered_map& valueMap) { + TORCH_CHECK_EQ(subgraph.inputs().size(), inputs.size()) + << "Input size mismatch"; + for (auto i : c10::irange(subgraph.inputs().size())) { + valueMap[subgraph.inputs()[i]] = inputs[i]; + } + + // Clone each node from subgraph + for (const auto& n : subgraph.nodes()) { + if (n.target() == "prim.Input" || n.target() == "prim.Output") { + continue; + } + + std::vector clonedInputs; + auto inputs = n.inputs(); + clonedInputs.reserve(inputs.size()); + for (auto& inp : inputs) { + auto it = valueMap.find(inp.value); + TORCH_CHECK(it != valueMap.end(), "Missing input value in subgraph"); + clonedInputs.push_back({inp.name, it->second}); + } + + Node* newNode = insertNode( + std::string(n.target()), std::move(clonedInputs), n.metadata()); + + for (const auto& attr : n.attributes()) { + Attribute newAttr; + newAttr.name = attr.name; + + std::visit( + [&](auto&& val) -> void { + // Workaround for MSVC bug: "std" ambiguous symbol. + using std::unique_ptr; + using std::move; + using T = std::decay_t; + if constexpr (is_same_v>) { + LOG(ERROR) + << "Graph attributes are not supported yet. Skipping attribute: " + << attr.name; + } else { + newAttr.value = val; +#ifdef __clang__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunknown-warning-option" +#pragma GCC diagnostic ignored "-Wunqualified-std-cast-call" +#endif + newNode->addAttribute(move(newAttr)); +#ifdef __clang__ +#pragma GCC diagnostic pop +#endif + } + }, + attr.value); + } + + for (const auto* outVal : n.outputs()) { + const auto& uniqueName = getUniqueValueName(); + Value* newOut = newNode->addOutput(uniqueName, outVal->type()); + valueMap[outVal] = newOut; + } + } + + auto subgraphOutputs = subgraph.outputs(); + std::vector outputValues; + outputValues.reserve(subgraphOutputs.size()); + for (auto* outputValue : subgraphOutputs) { + outputValues.emplace_back(valueMap[outputValue]); + } + lint(); + return outputValues; +} + +Node::Node( + Graph* owningGraph, + std::string target, + std::vector inputs, + std::unordered_map metadata) + : owningGraph_(owningGraph), + target_(std::move(target)), + inputs_(std::move(inputs)), + metadata_(std::move(metadata)) { + for (const auto& input : inputs_) { + input.value->addUser(this); + } +} + +Value* Node::addInput(NamedArgument input) { + inputs_.push_back(std::move(input)); + auto val = inputs_.back().value; + val->addUser(this); + return val; +} + +void Node::addInputs(const std::vector& inputs) { + for (const auto& input : inputs) { + addInput(input); + } +} + +void Node::addAttribute(Attribute attr) { + attributes_.push_back(std::move(attr)); +} + +void Node::addOutput() { + outputs_.push_back(nullptr); +} + +Value* Node::addOutput(const Type& type) { + TORCH_CHECK_EQ(type, Type::Kind::None); + Value* v = owningGraph_->addValue(std::nullopt, type, this); + outputs_.push_back(v); + return v; +} + +Value* Node::addOutput(std::string_view name, const Type& type) { + Value* v = owningGraph_->addValue(std::string(name), type, this); + outputs_.push_back(v); + return v; +} + +void Node::destroy() { + owningGraph_->removeNode(this); +} + +void Value::addUser(Node* node) { + for (const auto* user : users_) { + if (user == node) { + return; + } + } + users_.push_back(node); +} + +void Value::eraseUser(Node* node) { + users_.erase( + std::remove_if( + users_.begin(), users_.end(), [&](Node* el) { return el == node; }), + users_.end()); +} + +std::vector Value::getListElements() const { + std::vector ret; + if (auto p = producer(); p && p->target() == "prim.ListPack") { + for (const auto& tv : p->inputs()) { + ret.push_back(tv.value); + } + } else { + TORCH_CHECK_EQ(users().size(), 1); + const auto listUnpack = users()[0]; + TORCH_CHECK_EQ(listUnpack->target(), "prim.ListUnpack"); + for (const auto v : listUnpack->outputs()) { + ret.push_back(v); + } + } + return ret; +} + +template +[[maybe_unused]] inline constexpr bool AlwaysFalse = false; + +c10::IValue constantToIValue(const Constant& constant) { + // Workaround for MSVC bug: "std" ambiguous symbol. + using std::string; + using std::unique_ptr; + using std::vector; + return std::visit( + [](auto&& arg) -> c10::IValue { + using T = std::decay_t; + if constexpr (is_same_v) { + return c10::IValue(); + } else if constexpr (std::is_convertible_v) { + return arg; + } else if constexpr (is_same_v>) { + TORCH_CHECK( + false, "subgraph arguments cannot be turned into ivalues!"); + } else { + static_assert(AlwaysFalse, "non-exhaustive visitor!"); + } + }, + constant); +} + +namespace { + +template +[[maybe_unused]] inline constexpr bool always_false_v = false; + +void printDouble(std::ostream& out, double arg) { + fmt::print(out, "{}", arg); +} + +template +std::ostream& printList( + std::ostream& out, + bool encloseInSquareBrackets, + const T& list, + F formatter) { + if (encloseInSquareBrackets) { + out << '['; + } + for (const auto& [idx, el] : c10::enumerate(list)) { + if (idx > 0) { + out << ", "; + } + formatter(out, el); + } + if (encloseInSquareBrackets) { + out << ']'; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Constant& constant) { + // Workaround for MSVC bug: "std" ambiguous symbol. + using std::quoted; + using std::string; + using std::unique_ptr; + using std::vector; + std::visit( + [&](auto&& arg) { + using T = std::decay_t; + if constexpr (is_same_v) { + out << "None"; + } else if constexpr (is_same_v || is_same_v) { + out << arg; + } else if constexpr ( + is_same_v> || is_same_v>) { + out << fmt::format("{}", fmt::streamed(arg)); + } else if constexpr (is_same_v) { + printDouble(out, arg); + } else if constexpr (is_same_v>) { + printList(out, true, arg, printDouble); + } else if constexpr (is_same_v) { + out << quoted(arg); + } else if constexpr (is_same_v) { + out << kScalarTypePrefix << arg; + } else if constexpr (is_same_v) { + out << kMemoryFormatPrefix << arg; + } else if constexpr (is_same_v) { + out << kLayoutPrefix << arg; + } else if constexpr (is_same_v) { + out << kDevicePrefix << "{" << arg << "}"; + } else if constexpr (is_same_v>) { + out << fmt::format("[{}]", fmt::join(arg, ",")); + } else if constexpr (is_same_v>) { + out << fmt::format(""); + VLOG(0) << "Subgraph pretty print is not implemented"; + } else { + static_assert(always_false_v, "non-exhaustive visitor!"); + } + }, + constant); + return out; +} + +void printValue(std::ostream& out, const Value* v) { + if (!v) { + out << ""; + return; + } + out << *v; +} + +void printNamedArgument(std::ostream& out, const NamedArgument& nv) { + out << nv.name << "=" << *nv.value; +} + +void printAttribute(std::ostream& out, const Attribute& nv) { + out << nv.name << "=" << nv.value; +} +} // namespace + +std::ostream& operator<<(std::ostream& out, const Value& v) { + out << "%" << v.name(); + // If a list, distinguish it by adding a [] + // Looks like %my_list[] + if (v.type() == Type::Kind::TensorList) { + out << "[]"; + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Node& node) { + // special casing for inputs and outputs + if (node.target() == "prim.Input") { + out << "graph("; + printList(out, false, node.outputs(), printValue); + out << "):"; + return out; + } + if (node.target() == "prim.Output") { + out << "return("; + printList(out, false, node.inputs(), [](std::ostream& out, const auto& nv) { + out << *nv.value; + }); + out << ")"; + return out; + } + + printList(out, false, node.outputs_, printValue); + + out << " = "; + out << node.target_ << "("; + printList(out, false, node.inputs_, printNamedArgument); + if (!node.inputs_.empty() && !node.attributes_.empty()) { + // Emit a connective ',' between inputs and attributes. + out << ", "; + } + + printList(out, false, node.attributes_, printAttribute); + out << ")"; + return out; +} + +std::ostream& operator<<(std::ostream& out, const Graph& graph) { + for (const auto& node : graph.nodes_) { + out << node << "\n"; + } + return out; +} + +c10::Device convertDevice(std::string_view symbol) { + // Symbol looks like `Device{cuda:1}` + const auto typeStart = symbol.find('{') + 1; + TORCH_CHECK_LT(typeStart, symbol.size()); + + const auto typeEnd = symbol.find(':'); + TORCH_CHECK_NE(typeEnd, std::string_view::npos); + + const auto type = symbol.substr(typeStart, typeEnd - typeStart); + const auto indexStart = typeEnd + 1; + TORCH_CHECK_LT(indexStart, symbol.size()); + + const auto indexEnd = symbol.find('}'); + TORCH_CHECK_NE(indexEnd, std::string_view::npos); + + const auto index = symbol.substr(indexStart, indexEnd - indexStart); + + c10::Device device((std::string(type))); + auto indexValue = c10::tryToNumber(std::string{index}); + TORCH_CHECK(indexValue.has_value(), "Invalid device index format"); + int64_t deviceIndex = indexValue.value(); + TORCH_CHECK( + deviceIndex >= std::numeric_limits::min() && + deviceIndex <= std::numeric_limits::max(), + "Device index out of range for int8_t"); + device.set_index(static_cast(deviceIndex)); + return device; +} + +Constant convertAtomicConstant(std::string_view symbol) { + if (c10::starts_with(symbol, "\"")) { + // chop off the outer quotes and return the string + TORCH_CHECK_GE(symbol.size(), 2); + symbol.remove_prefix(1); + symbol.remove_suffix(1); + return std::string(symbol); + } else if (symbol == "None") { + return None(); + } else if (symbol == "true") { + return true; + } else if (symbol == "false") { + return false; + } else if (c10::starts_with(symbol, kMemoryFormatPrefix)) { + torch::_export::MemoryFormat value = torch::_export::MemoryFormat::Unknown; + symbol.remove_prefix(kMemoryFormatPrefix.length()); + torch::_export::parseEnum(symbol, value); + return convertJsonMemoryFormat(value); + } else if (c10::starts_with(symbol, kLayoutPrefix)) { + torch::_export::Layout value = torch::_export::Layout::Unknown; + symbol.remove_prefix(kLayoutPrefix.length()); + torch::_export::parseEnum(symbol, value); + return convertJsonLayout(value); + } else if (c10::starts_with(symbol, kDevicePrefix)) { + return convertDevice(symbol); + } else if (c10::starts_with(symbol, kScalarTypePrefix)) { + torch::_export::ScalarType value = torch::_export::ScalarType::UNKNOWN; + symbol.remove_prefix(kScalarTypePrefix.length()); + torch::_export::parseEnum(symbol, value); + return convertJsonScalarType(value); + } + + // match number + // We need to disambiguate between int and float constants + const auto maybeInt = c10::tryToNumber(std::string{symbol}); + + // Libraries may happily convert "5.0" to an int 5, but we want that to + // become a float. So add an extra check for whether a '.' is in the string + // to guard against that. + bool hasDecimalSeparator = symbol.find('.') != std::string_view::npos; + if (maybeInt.has_value() && !hasDecimalSeparator) { + return maybeInt.value(); + } + + const auto maybeDouble = c10::tryToNumber(std::string{symbol}); + if (maybeDouble.has_value()) { + return maybeDouble.value(); + } + + TORCH_CHECK(false, "unhandled symbol: ", symbol); +} + +Constant convertListConstant(std::string_view source) { + std::vector values; + size_t curPos = 0; + Constant type = None(); + + // This basically the same as parseValueList, it's probably better to refactor + curPos = expectImpl(source, '[', curPos); + while (true) { + curPos = consumeWhitespaceImpl(source, curPos); + + size_t start = curPos; + while (source.at(curPos) != ',' && source.at(curPos) != ']') { + curPos++; + } + auto symbol = source.substr(start, curPos - start); + auto val = convertAtomicConstant(symbol); + if (std::holds_alternative(type)) { + // First time around; initialize our type sentinel with the first value. + // We will use this on subsequent iterations to check that all types are + // the same. + if (auto intPtr = std::get_if(&val)) { + type = *intPtr; + } else if (auto doublePtr = std::get_if(&val)) { + type = *doublePtr; + } else if (auto boolPtr = std::get_if(&val)) { + type = *boolPtr; + } else { + TORCH_CHECK(false, "constant lists only support int, float, bool"); + } + } else { + TORCH_CHECK_EQ(type.index(), val.index()) + << "lists must have all the same type"; + } + values.push_back(std::move(val)); + if (source.at(curPos) == ']') { + break; + } + curPos = expectImpl(source, ',', curPos); + } + expectImpl(source, ']', curPos); + + // Some annoying unwrapping + // std::vector> --> + // Constant> + // Do it the dumb way. + if (std::holds_alternative(type)) { + std::vector inner; + inner.reserve(values.size()); + for (const auto& el : values) { + inner.push_back(std::get(el)); + } + return inner; + } else if (std::holds_alternative(type)) { + std::vector inner; + inner.reserve(values.size()); + for (const auto& el : values) { + inner.push_back(std::get(el)); + } + return inner; + } else if (std::holds_alternative(type)) { + std::vector inner; + inner.reserve(values.size()); + for (const auto& el : values) { + inner.push_back(std::get(el)); + } + return inner; + } + TORCH_CHECK(false, "constant lists only support int, float, bool"); +} + +namespace { + +/** + * Deserialization for graphs: parse the output produced by operator<<(Graph). + * This parser really only expects the exact output generated by well-formed + * Graph objects, so it is not very permissive and does not give good error + * messages. + */ +class Parser { + public: + explicit Parser(std::string_view source) + : source_(source), graph_(Graph::createGraph()) {} + std::unique_ptr parse(); + + private: + template + std::vector parseList( + char open, + char close, + const std::function& parseFn); + + std::string_view parseUntil( + const std::function& fn, + bool includeEnd = false); + + void expect(std::string_view expected); + void expect(char expected); + bool nextEquals(std::string_view expected) const; + bool nextIf(std::string_view expected); + bool nextIf(char expected); + void consumeWhitespace(); + bool validIdent(char n); + char cur(); + + void parseReturn(); + void parseNode(); + std::pair parseOutput(); + void parseGraphInputs(); + std::string_view parseString(); + std::variant parseArgument(); + std::variant parseNamedArgument(); + Value* parseSymbolicArgument(); + // Symbols look like %v109, with the same valid ident rules as Python + // This returns the symbol *without* the % at the front. + std::string_view parseAtomicSymbol(); + + size_t curPos_ = 0; + std::string_view source_; + std::unique_ptr graph_; + torch::_export::GraphSignature signature_; +}; + +std::unique_ptr Parser::parse() { + parseGraphInputs(); + while (true) { + consumeWhitespace(); + if (nextEquals("return")) { + parseReturn(); + break; + } + parseNode(); + } + // For graph textual format, it should be safe to assume all + // inputs/outputs are from users. + graph_->setSignature(GraphSignature{signature_}); + graph_->finalize(); + graph_->lint(); + // TODO: Might have some source left over, should check it if so. + return std::move(graph_); +} + +bool Parser::nextIf(std::string_view expected) { + if (nextEquals(expected)) { + curPos_ += expected.size(); + return true; + } + return false; +} + +bool Parser::nextIf(char expected) { + if (cur() == expected) { + curPos_++; + return true; + } + return false; +} + +void Parser::parseGraphInputs() { + TORCH_CHECK_EQ(curPos_, 0); + expect("graph"); + const auto inputs = parseList( + '(', ')', [&]() { return parseAtomicSymbol(); }); + std::vector inputSpecs; + inputSpecs.reserve(inputs.size()); + for (const auto& input : inputs) { + graph_->addInput(input, Type::Kind::Tensor); + + torch::_export::TensorArgument inputTensorArg; + inputTensorArg.set_name(std::string{input}); + torch::_export::Argument inputArg; + inputArg.set_as_tensor(std::move(inputTensorArg)); + torch::_export::UserInputSpec userInput; + userInput.set_arg(std::move(inputArg)); + torch::_export::InputSpec inputSpec; + inputSpec.set_user_input(std::move(userInput)); + inputSpecs.push_back(std::move(inputSpec)); + } + signature_.set_input_specs(std::move(inputSpecs)); + // TODO populate graphinputs + expect(":"); +} + +template +std::vector Parser::parseList( + char open, + char close, + const std::function& parseFn) { + std::vector ret; + expect(open); + + // Handle empty list + if (nextIf(close)) { + return ret; + } + while (true) { + ret.push_back(parseFn()); + if (cur() == close) { + break; + } + expect(','); + } + expect(close); + return ret; +} + +// Parse until `fn` returns true, returning the segment of the source that was +// consumed. If `includeEnd` is true, the returned segment will also include +// final character, which caused `fn` to return true. +std::string_view Parser::parseUntil( + const std::function& fn, + bool includeEnd) { + size_t start = curPos_; + while (!fn()) { + curPos_++; + } + if (includeEnd) { + curPos_++; + } + return source_.substr(start, curPos_ - start); +} + +// Parse a string, including the outer quotes +std::string_view Parser::parseString() { + size_t start = curPos_; + expect('"'); + while (cur() != '"') { + // Handle escaped characters by skipping the next char when we see a + // backslash + if (cur() == '\\') { + curPos_++; + } + curPos_++; + } + + // Consume final quote + curPos_++; + auto ret = source_.substr(start, curPos_ - start); + return ret; +} + +bool Parser::validIdent(char n) { + return std::isalpha(n) || n == '_' || std::isdigit(n); +} + +// Symbols look like %v109, with the same valid ident rules as Python +// This returns the symbol *without* the % at the front. +std::string_view Parser::parseAtomicSymbol() { + expect("%"); + return parseUntil([&]() { return !validIdent(cur()); }); +} + +char Parser::cur() { + return source_.at(curPos_); +} + +void Parser::consumeWhitespace() { + while (isBlank(cur())) { + curPos_++; + } +} + +void Parser::expect(std::string_view expected) { + curPos_ = expectImpl(source_, expected, curPos_); +} + +void Parser::expect(char expected) { + curPos_ = expectImpl(source_, expected, curPos_); +} + +bool Parser::nextEquals(std::string_view expected) const { + const auto actual = source_.substr(curPos_, expected.size()); + return expected == actual; +} + +// %a, %b = aten.foo.default(input=%foo, foo=[7616], blah=%lol) +void Parser::parseNode() { + std::vector> outputs; + + outputs.push_back(parseOutput()); + while (nextIf(",")) { + outputs.push_back(parseOutput()); + } + expect("="); + consumeWhitespace(); + + // parse target name + const auto target = parseUntil([&]() { return cur() == '('; }); + + Node* node = graph_->insertNode(std::string(target)); + for (auto& [name, var] : outputs) { + node->addOutput(name, var); + } + + auto arguments = parseList>( + '(', ')', [&]() { return parseNamedArgument(); }); + + // Split the arguments into symbolic inputs and constant attributes + for (auto& arg : arguments) { + if (std::holds_alternative(arg)) { + node->addInput(std::get(arg)); + } else { + node->addAttribute(std::get(std::move(arg))); + } + } +} + +void Parser::parseReturn() { + expect("return"); + const auto returns = + parseList('(', ')', [&]() { return parseSymbolicArgument(); }); + std::vector outputSpecs; + outputSpecs.reserve(returns.size()); + for (const auto ret : returns) { + graph_->addOutput(ret); + + torch::_export::TensorArgument retTensorArg; + retTensorArg.set_name(std::string{ret->name()}); + torch::_export::Argument retArg; + retArg.set_as_tensor(std::move(retTensorArg)); + torch::_export::UserOutputSpec userOutput; + userOutput.set_arg(std::move(retArg)); + torch::_export::OutputSpec outputSpec; + outputSpec.set_user_output(std::move(userOutput)); + outputSpecs.push_back(std::move(outputSpec)); + } + signature_.set_output_specs(std::move(outputSpecs)); +} + +std::variant Parser::parseNamedArgument() { + consumeWhitespace(); + // Parse name + const auto symbol = parseUntil([&]() { return cur() == '='; }); + expect('='); + + // Parse value + auto value = parseArgument(); + if (std::holds_alternative(value)) { + return NamedArgument{std::string(symbol), std::get(value)}; + } else { + return Attribute{std::string(symbol), std::get(std::move(value))}; + } +} + +std::pair Parser::parseOutput() { + consumeWhitespace(); + TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur())); + + auto symbol = parseAtomicSymbol(); + if (nextIf('[')) { + expect(']'); + return {symbol, Type::Kind::TensorList}; + } else { + return {symbol, Type::Kind::Tensor}; + } +} + +Value* Parser::parseSymbolicArgument() { + consumeWhitespace(); + TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur())); + + auto symbol = parseAtomicSymbol(); + std::vector listElements; + if (cur() == '[') { + listElements = parseList( + '[', ']', [&]() { return graph_->getValue(parseAtomicSymbol()); }); + } + return graph_->getValue(symbol); +} + +std::variant Parser::parseArgument() { + consumeWhitespace(); + + // match symbol + if (cur() == '%') { + return parseSymbolicArgument(); + } + + // match list + if (cur() == '[') { + const auto symbol = + parseUntil([&]() { return cur() == ']'; }, /*includeEnd=*/true); + return convertListConstant(symbol); + } + + // match string + if (cur() == '"') { + return convertAtomicConstant(parseString()); + } + + // otherwise parse this as a value + const auto symbol = + parseUntil([&]() { return cur() == ',' || cur() == ')'; }); + return convertAtomicConstant(symbol); +} + +} // namespace + +std::unique_ptr stringToGraph(std::string_view source) { + return Parser(source).parse(); +} + +std::string graphToString(const Graph& g, bool include_signature) { + std::stringstream ss; + ss << g; + + if (include_signature) { + ss << "\nGraphSignature\n"; + ss << g.signature(); + } + + return ss.str(); +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h new file mode 100644 index 00000000000000..7202272a4aa1ce --- /dev/null +++ b/torch/nativert/graph/Graph.h @@ -0,0 +1,717 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch::nativert { + +using NodeIndex = size_t; + +class Value; + +class Type { + public: + enum class Kind { + None, + Tensor, + TensorList, + OptionalTensorList, + SymInt, + SymIntList, + SymBool, + SymFloat, + CustomObj, + }; + + // For simple kinds without classFqn + /*implicit*/ Type(Kind kind) : kind_(kind) {} + + // For CustomObj kind with classFqn + explicit Type(Kind kind, const std::string& classFqn) + : kind_(CustomObjData{classFqn}) { + TORCH_CHECK(kind == Kind::CustomObj); + TORCH_CHECK(!classFqn.empty()); + } + + Kind kind() const { + if (std::holds_alternative(kind_)) { + return Kind::CustomObj; + } + return std::get(kind_); + } + + friend std::ostream& operator<<(std::ostream& out, const Type& ty); + friend bool operator==(const Type& left, const Type& right); + + std::string classFqn() const { + TORCH_CHECK( + kind() == Kind::CustomObj, "Only CustomObj type can have classFqn"); + return std::get(kind_).classFqn; + } + + private: + struct CustomObjData { + std::string classFqn; + }; + std::variant kind_; +}; + +// These are all the constant types that are allowed as attributes on Nodes. +struct None {}; +// None always equals itself +inline bool operator==(const None&, const None&) { + return true; +} + +class Graph; + +/** + * We distinguish between a symbolic value (Tensor, TensorList, SymInt, SymInts, + * etc) and a constant value (int, bool, string, etc). Here Constant is the type + * for all possible constant values. Along with a name, they are represented as + * Attributes on a Node. + */ +using Constant = std::variant< + None, + int64_t, + std::vector, + double, + std::vector, + std::string, + c10::ScalarType, + c10::MemoryFormat, + c10::Layout, + c10::Device, + bool, + std::vector, + std::vector, + std::unique_ptr>; + +c10::IValue constantToIValue(const Constant& constant); + +class Node; + +/** + * Represents a single symbolic value (tensor/symint/list of them). Values are + * inputs and outputs of Nodes. + */ +using ValueId = int; +class Value { + public: + explicit Value(ValueId id, std::string name, Type t, Node* producer) + : name_(std::move(name)), id_(id), type_(t), producer_(producer) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(name_ == this->name()); + } + + // Each Value should be uniquely created and managed by a Graph. It's not + // allowed to copy/move Value instances. + Value(Value&&) = delete; + Value& operator=(Value&&) = delete; + Value(const Value&) = delete; + Value& operator=(Value&) = delete; + + Type type() const { + return type_; + } + + ValueId id() const { + return id_; + } + + std::string_view name() const { + return name_; + } + + const Node* producer(bool resolve_folded = false) const { + return (!resolve_folded && isFolded()) ? nullptr : producer_; + } + + Node* producer() { + return producer_; + } + + void addUser(Node* node); + void eraseUser(Node* node); + void eraseAllUsers() { + users_.clear(); + } + + // Throws an exception if the value is not a TensorList + std::vector getListElements() const; + + const auto& users() const { + return users_; + } + + auto& users() { + return users_; + } + + void setId(ValueId newId) { + // This should only be used inside the renumberValues pass + id_ = newId; + } + + void setIsFolded() { + isFolded_ = true; + } + + bool isFolded() const { + return isFolded_; + } + + private: + friend std::ostream& operator<<(std::ostream& out, const Value& v); + std::string name_; + bool isFolded_{false}; + ValueId id_; + Type type_; + Node* producer_; + // All nodes which have this value as input. + // Note that this is a vector to avoid nondeterminism in iteration, but + // probably should be an unordered set given usage patterns. If this becomes a + // perf problem we should revise. + std::vector users_; +}; + +struct NamedArgument { + std::string name; + Value* value; +}; + +struct Attribute { + std::string name; + Constant value; +}; + +/** + * Node represents a single unit of execution, typically a PyTorch operator. + * Using an intrusive list allows us to allocate all the memory at once for a + * node. This also allows us to track nodes safely without passing around the + * list object, as an intrusive list maintains a stronger invariant that + * expiration will always cause unlinking. + */ +class Node : public c10::IntrusiveListHook { + public: + Node( + Graph* owningGraph, + std::string target, + std::vector inputs, + std::unordered_map metadata); + + std::string_view target() const { + return target_; + } + + void setTarget(std::string_view target) { + target_ = target; + } + + const auto& inputs() const { + return inputs_; + } + + auto& inputs() { + return inputs_; + } + + // NOTE: this invalidates spans given out by inputs() + Value* addInput(NamedArgument input); + void addInputs(const std::vector& inputs); + + // NOTE: this invalidates spans given out by attributes() + void addAttribute(Attribute attr); + + // NOTE: this is ONLY for graph's constant inputs and NOT the common case + void addOutput(); + + Value* addOutput(const Type& type); + + // NOTE: this invalidates spans given out by outputs() + Value* addOutput(std::string_view name, const Type& type); + + size_t numInputs() const { + return inputs_.size(); + } + + size_t numOutputs() const { + return outputs_.size(); + } + + // Return the next node in the Graph's node ordering. + // NOTE: Calling next on the last node (prim.Output) returns nullptr. + Node* next(); + const Node* next() const; + + // Return the previous node in the Graph's node ordering. + // NOTE: Calling prev on the first node (prim.Input) returns nullptr. + Node* prev(); + const Node* prev() const; + + bool isBefore(const Node* n) const; + + std::vector producers() const; + std::vector users() const; + + // Returns nullptr if `name` is not an input + const NamedArgument* tryGetInput(std::string_view name) const; + // Throws an exception if `name` is not an input + const NamedArgument& getInput(std::string_view name) const; + + const auto& attributes() const { + return attributes_; + } + + // Returns nullptr if `name` is not an attribute + const Attribute* tryGetAttribute(std::string_view name) const; + // Throws an exception if `name` is not an attribute + const Attribute& getAttribute(std::string_view name) const; + + const auto& outputs() const { + return outputs_; + } + + void applyDevicePlacement(const Placement& placement); + + std::optional getMetadata(std::string_view key) const { + return metadata_.find(std::string{key}) != metadata_.end() + ? std::optional(std::string_view{metadata_.at(std::string{key})}) + : std::nullopt; + } + + Graph* owningGraph() { + return owningGraph_; + } + + const Graph* owningGraph() const { + return owningGraph_; + } + + void destroy(); + + const std::unordered_map& metadata() const { + return metadata_; + } + + std::string toString() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + void updateInputName(std::string_view oldName, std::string_view newName) { + for (auto& input : inputs_) { + if (input.name == oldName) { + input.name = newName; + break; + } + } + } + + void updateAttributeName(std::string_view oldName, std::string_view newName) { + for (auto& attr : attributes_) { + if (attr.name == oldName) { + attr.name = newName; + break; + } + } + } + + private: + friend std::ostream& operator<<(std::ostream& out, const Node& n); + Graph* owningGraph_; + + // Target used to retrieve the actual thing to execute. + // If an aten operator, we expect this to be fully qualified, including an + // overload name, e.g. "aten.unsqueeze.default" + std::string target_; + // *Symbolic* inputs to this node. NOTE: this does not match the ATen operator + // schema inputs directly. It only represents things that actually participate + // in dataflow, like tensors/symints and lists thereof. + // + // The "name" of the NamedArgument refers to the name of the parameter. + std::vector inputs_; + // Constant inputs to the node. The "name" of the Attribute refers to the + // name of the parameter. + std::vector attributes_; + std::vector outputs_; + + // Extra bits of info added to the node. Contents that are guaranteed will be + // eventually moved to a first-class field on the json struct of schema. + std::unordered_map metadata_; +}; + +/** + * Graph represents a model's computation graph, which is designed to + * facilitate transformation and analysis. + * + * Ownership semantics: + * - Graph owns Nodes and Values + * - Nodes own their constant attributes (which we treat as value types) + * - Nodes have non-owning pointers back to the graph. + * + * NOTE: this class is marked noncopyable/nonmovable and only can be + * heap-allocated via `createGraph()`. This is to ensure stability of + * back-pointers held by Nodes/Values. + */ +class Graph { + public: + static std::unique_ptr createGraph() { + return std::unique_ptr(new Graph()); + } + + Graph(const Graph&) = delete; + Graph& operator=(const Graph&) = delete; + Graph(Graph&&) = delete; + Graph& operator=(Graph&&) = delete; + ~Graph() = default; + + // NOTE: this invalidates spans given out by inputs() + Value* addInput(std::string_view name, const Type& type); + + // NOTE: this is ONLY for graph's constant inputs and NOT the common case + void addInput(); + + // NOTE: this invalidates spans given out by outputs() + Value* addOutput(Value* v); + + void addConstantOutput(Constant c); + + // Create and insert a node at insertionPoint_ + Node* insertNode( + std::string target, + std::vector inputs = {}, + std::unordered_map metadata = {}); + + // Returns the inserted node. + Node* insertBefore(Node* toInsert, Node* insertionPoint); + // Returns the inserted node. + Node* insertAfter(Node* toInsert, Node* insertionPoint); + // Insert at the insertionPoint. Returns the inserted node. + Node* insert(Node* toInsert); + + // Create a node without inserting it into the execution graph. + // A raw pointer to the node is created when `createNode()` on the + // owner Graph object is called. It is guranateed that to be valid + // until the Graph object is destructed. + Node* createNode( + std::string target, + std::vector inputs = {}, + std::unordered_map metadata = {}); + + Value* createConstantSymIntValue(int value); + + Node* createListPack(std::vector inputs, const Type& inputType); + + Node* createOptionalListPack(std::vector inputs); + + size_t numValues() const { + return values_.size(); + } + + // throws on missing name + Value* getValue(std::string_view name) const; + // returns nullptr on missing name + Value* tryGetValue(std::string_view name) const; + + const std::unordered_map getConstantSymIntValues() const { + return constantSymIntValues_; + } + + Value* addValue( + const std::optional& name, + const Type& type, + Node* producer); + void removeValue(Value* value); + + void replaceAllUses(Value* old, Value* replacement); + void replaceAllUsesAfterNode(Value* old, Value* replacement, Node* afterThis); + void removeNode(Node* node); + + void applyDevicePlacement(const Placement& placement); + + std::string getUniqueValueName(); + + ValueId getNextValueId() { + return uniqueValueId_++; + } + + // NOTE: this range can be invalidated by mutations to the graph. + const auto& inputs() const { + return inputNode_->outputs(); + } + + c10::ArrayRef userInputs() const { + size_t offset = signature().inputsToWeights().size() + + signature().inputsToCustomObjs().size(); + return {inputs().data() + offset, inputs().data() + inputs().size()}; + } + + c10::ArrayRef weightValues() const { + return { + inputs().data(), + inputs().data() + signature().inputsToWeights().size()}; + } + + // Return a bidirectional range over `const Value*` + // NOTE: this range can be invalidated by mutations to the graph. + auto outputs() const { + std::vector ret; + ret.reserve(outputNode_->inputs().size()); + for (const auto& namedArg : outputNode_->inputs()) { + ret.push_back(namedArg.value); + } + return ret; + } + + // Return a bidirectional range over `Value*` + // NOTE: this range can be invalidated by mutations to the graph. + auto outputs() { + std::vector ret; + ret.reserve(outputNode_->inputs().size()); + for (const auto& namedArg : outputNode_->inputs()) { + ret.push_back(namedArg.value); + } + return ret; + } + + const auto& userOutputs() const { + return userOutputs_; + } + + // Return a list over `const Node&`. + // NOTE: this can be invalidated by mutations to the graph. + const auto& nodes() const { + return nodes_; + } + + auto& nodes() { + return nodes_; + } + + // Return a forward range over `const Value*`. + // NOTE: this range can be invalidated by mutations to the graph. + auto values() const { + std::vector ret; + ret.reserve(values_.size()); + for (const auto& [_, value] : values_) { + ret.push_back(value.get()); + } + return ret; + } + + Node* inputNode() { + return inputNode_; + } + + Node* outputNode() { + return outputNode_; + } + + const Node* outputNode() const { + return outputNode_; + } + + // Assert various graph invariants + void lint() const; + + bool /* removed > 0? */ cleanupDeadNodes(); + + void finalize(); + + Node* insertionPoint() { + // This should never happen, since the last-most insertion point is the + // prim.Outputs node, not end(). + TORCH_CHECK(insertBefore_ != nodes_.end()); + auto& node = *insertBefore_; + return &node; + } + + void setInsertionPoint(Node* n) { + TORCH_CHECK(n != inputNode_, "can't insert before prim.Input"); + insertBefore_ = nodes_.iterator_to(*n); + } + + void setInsertionPointAfter(Node* n) { + TORCH_CHECK(n != outputNode_, "can't insert after prim.Output"); + auto it = nodes_.iterator_to(*n); + ++it; + insertBefore_ = it; + } + + // Return the next node in the Graph's node ordering. + // NOTE: Calling on the last node (prim.Output) returns nullptr. + Node* nodeAfter(Node* n); + const Node* nodeAfter(const Node* n) const; + + // Return the previous node in the Graph's node ordering. + // NOTE: Calling on the first node (prim.Input) returns nullptr. + Node* nodeBefore(Node* n); + const Node* nodeBefore(const Node* n) const; + + // Clone each node from subgraph (except prim.Input/prim.Output) into current + // graph. + // @param subgraph: the subgraph to be cloned + // @param inputs: values from the target graph that will serve as the + // subgraph's inputs + // @param valueMap: a map from the cloned subgraph's values to the target + // graph's values + std::vector insertGraph( + const Graph& subgraph, + std::vector inputs, + std::unordered_map& valueMap); + + const GraphSignature& signature() const { + return signature_; + } + + void setSignature(GraphSignature signature) { + signature_ = std::move(signature); + } + + void setWeightsMeta( + const std::unordered_map& + tensorsMeta) { + for (auto [name, tensorMeta] : tensorsMeta) { + weightsMeta_.emplace(name, TensorMeta{tensorMeta}); + } + } + + const std::unordered_map& weightsMeta() const { + return weightsMeta_; + } + + std::vector userInputsMeta() const { + std::vector userInputsMeta; + userInputsMeta.reserve(signature_.userInputs().size()); + for (auto inputName : signature_.userInputs()) { + userInputsMeta.push_back(tensorValuesMeta_.at(inputName)); + } + return userInputsMeta; + } + + void setTensorValuesMeta( + const std::unordered_map& + tensorsMeta) { + for (auto [name, tensorMeta] : tensorsMeta) { + tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta}); + } + } + + const std::unordered_map& tensorValuesMeta() const { + return tensorValuesMeta_; + } + + std::string toString() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + /* Reassigns IDs to every Value in this Graph so that they are contiguous from + * 0..(numValues()-1). Should be used after values are removed + */ + void renumberValues(); + + private: + Graph(); + friend std::ostream& operator<<(std::ostream& out, const Graph& g); + GraphSignature signature_; + + // keys are parameters, buffers, tensor_constants' names + std::unordered_map weightsMeta_; + + // keys are tensor_values' names + std::unordered_map tensorValuesMeta_; + + // Node lifetime is managed by nodesOwner_, but the actual ordering is + // maintained intrusively using nodes_. + // This is to facilitate quick insertion before/after a given Node*. + std::vector> nodesOwner_; + c10::IntrusiveList nodes_; + // The current insertion point. New nodes are inserted before this node. + // Defaults to prim.Output. + c10::IntrusiveList::iterator insertBefore_; + + // Graphs always start with an input and output node. + // "prim.input() -> Value[]" take no input, and produces some outputs. AKA + // "source“ of a graph. + Node* inputNode_; // target: prim.Input + // "prim.output(Value[]) -> None", take some inputs, but produce no output. + // AKA "sink" of a graph. + Node* outputNode_; // target: prim.Output + + std::unordered_map> values_; + // constantSymIntValues_ is a subset of values_ + std::unordered_map constantSymIntValues_; + // Output values of the graph, which is a subset of values_. + std::vector> userOutputs_; + // Output constant values of the graph + std::vector constantOutputs_; + + size_t uniqueValueName_ = 0; + + ValueId uniqueValueId_ = 0; +}; + +/** + * Scoped utility class for setting temporary insertion points. + * + * Use like: + * { + * InsertingAfter guard(node) + * graph.insertNode(...) // this will be inserted after `node`. + * } + */ +class InsertingAfter { + public: + explicit InsertingAfter(Node* n) + : insertAfter_(n), prev_(n->owningGraph()->insertionPoint()) { + insertAfter_->owningGraph()->setInsertionPointAfter(insertAfter_); + } + ~InsertingAfter() { + insertAfter_->owningGraph()->setInsertionPoint(prev_); + } + + private: + Node* insertAfter_; + Node* prev_; +}; + +inline constexpr std::string_view kMemoryFormatPrefix = "MemoryFormat::"; +inline constexpr std::string_view kLayoutPrefix = "Layout::"; +inline constexpr std::string_view kDevicePrefix = "Device"; +inline constexpr std::string_view kScalarTypePrefix = "ScalarType::"; + +/** + * Debug format serialization. The format here is intended to be human readable + * and easy to work with, and is intended for debugging and testing only. + * If you want stable serialization, use the json conversion utils. + * + * NOTE: node metadata currently not serialized + */ +std::string graphToString(const Graph& g, bool include_signature = false); +std::unique_ptr stringToGraph(std::string_view source); + +// Standalone functions to parse common constructs +// Parse something that looks like `Device{cuda:1}` to a device in json format. +c10::Device convertDevice(std::string_view symbol); +// We have separate functions for parsing atomic and list constants because +// there are restrictive rules about which constants can go in lists (i.e. +// it's not recursive). +Constant convertAtomicConstant(std::string_view symbol); +Constant convertListConstant(std::string_view symbol); + +} // namespace torch::nativert diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp new file mode 100644 index 00000000000000..327f32185e9105 --- /dev/null +++ b/torch/nativert/graph/GraphPasses.cpp @@ -0,0 +1,180 @@ +#include + +#include + +#include + +#include +#include + +#include + +namespace torch::nativert { +namespace { +bool isScalar(const Constant& c) { + return std::holds_alternative(c) || + std::holds_alternative(c); +} + +bool isScalar(const Value& v) { + return v.type() == Type::Kind::SymInt || v.type() == Type::Kind::SymFloat; +} + +bool schemaTypeMatch(const c10::FunctionSchema& schema, const Node& node) { + std::unordered_set inputNames; + for (const auto& input : node.inputs()) { + // The number of arguments is always O(10), so we can just do a linear scan. + for (const auto& schemaArg : schema.arguments()) { + if (schemaArg.name() == input.name) { + if (schemaArg.type() == c10::TensorType::get() && input.value && + isScalar(*input.value)) { + return false; + } + break; + } + } + inputNames.insert(input.name); + } + for (const auto& constant : node.attributes()) { + for (const auto& schemaArg : schema.arguments()) { + if (schemaArg.name() == constant.name) { + if (schemaArg.type() == c10::TensorType::get() && + isScalar(constant.value)) { + return false; + } + break; + } + } + inputNames.insert(constant.name); + } + + // Make sure we have all the required arguments. + for (const auto& schemaArg : schema.arguments()) { + if (!schemaArg.default_value()) { + if (inputNames.find(schemaArg.name()) == inputNames.end()) { + return false; + } + } + } + return true; +} + +} // namespace + +// PT2 intentionally broadcast things like aten.sub.Scalar +// to aten.sub.Tensor. https://github.com/pytorch/pytorch/issues/90923. +std::string selectScalarOverloadName(const Node& node) { + // Copied from torch/csrc/utils/python_arg_parser.cpp + // torch::should_allow_numbers_as_tensors() to workaround + // some linking issues. + static std::unordered_set allowed = { + "add", + "add_", + "add_out", + "div", + "div_", + "div_out", + "divide", + "divide_", + "divide_out", // alias of div + "mul", + "mul_", + "mul_out", + "multiply", + "multiply_", + "multiply_out", // alias of mul + "sub", + "sub_", + "sub_out", + "subtract", + "subtract_", + "subtract_out", // alias of sub + "true_divide", + "true_divide_", + "true_divide_out", + "to", + "_to_copy", + "copy_", + "copy", + "floor_divide", + "floor_divide_", + "floor_divide_out", + "_conj"}; + std::vector atoms = c10::split(node.target(), '.'); + TORCH_CHECK_GE(atoms.size(), 3); + + std::string ns = std::string{atoms[atoms.size() - 3]}; + std::string opName = std::string{atoms[atoms.size() - 2]}; + std::string overloadName = std::string{atoms[atoms.size() - 1]}; + if (overloadName != "Tensor" && overloadName != "Tensor_Tensor" && + overloadName != "Tensor_mode") { + return overloadName; + } + if (allowed.find(std::string{opName}) == allowed.end()) { + return overloadName; + } + auto op = c10::Dispatcher::singleton().findSchemaOrThrow( + fmt::format("{}::{}", ns, opName.c_str()).c_str(), overloadName.c_str()); + if (schemaTypeMatch(op.schema(), node)) { + return overloadName; + } + for (const auto& variant : + {"Scalar_mode", "Scalar", "Scalar_Tensor", "Tensor_Scalar"}) { + if (auto schema = c10::Dispatcher::singleton().findSchema( + {fmt::format("{}::{}", ns, opName.c_str()).c_str(), variant})) { + if (schemaTypeMatch(schema->schema(), node)) { + return variant; + } + } + } + return overloadName; +} + +void selectScalarOverload(Graph* graph) { + for (auto& node : graph->nodes()) { + for (auto& attr : node.attributes()) { + if (std::holds_alternative>(attr.value)) { + selectScalarOverload( + std::get>(attr.value).get()); + } + } + + auto target = node.target(); + std::vector atoms = c10::split(target, '.'); + + size_t numAtoms = atoms.size(); + if (numAtoms != 5) { + continue; + } + + const std::string_view ns = atoms[numAtoms - 3]; + const std::string_view opName = atoms[numAtoms - 2]; + if (atoms[0] != "torch" || atoms[1] != "ops" || ns != "aten") { + continue; + } + + auto overloadName = selectScalarOverloadName(node); + if (overloadName != atoms[numAtoms - 1]) { + node.setTarget( + fmt::format("torch.ops.{}.{}.{}", ns, opName, overloadName)); + } else if (ns == "aten" && opName == "sub" && overloadName == "Tensor") { + // Special case for aten.sub.Tensor. + if (auto i = node.tryGetInput("self")) { + if (isScalar(*i->value)) { + node.updateInputName("self", "other"); + node.updateInputName("other", "self"); + node.setTarget("torch.ops.aten.rsub.Scalar"); + } + } + if (auto a = node.tryGetAttribute("self")) { + if (isScalar(a->value)) { + node.updateAttributeName("self", "other"); + node.updateInputName("other", "self"); + node.setTarget("torch.ops.aten.rsub.Scalar"); + } + } + } + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/GraphPasses.h b/torch/nativert/graph/GraphPasses.h new file mode 100644 index 00000000000000..7971aeb6b22429 --- /dev/null +++ b/torch/nativert/graph/GraphPasses.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch::nativert { + +void selectScalarOverload(Graph* graph); + +std::string selectScalarOverloadName(const Node& node); + +} // namespace torch::nativert diff --git a/torch/nativert/graph/GraphSignature.h b/torch/nativert/graph/GraphSignature.h index 890fb9f7173233..a9e2a95bbaa6b9 100644 --- a/torch/nativert/graph/GraphSignature.h +++ b/torch/nativert/graph/GraphSignature.h @@ -14,7 +14,7 @@ namespace torch::nativert { * * The GraphSignature class models the input and output specs of an exported * graph produced by torch.export, which is a fx.Graph with stronger invariants - * gurantees. It holds the graph information deserialized from the pt2 archive + * guarantees. It holds the graph information deserialized from the pt2 archive * package. Runtime relies on the GraphSignature for weight name lookup and * weight loading. The serialization schema is defined in * torch/_export/serde/schema.py See more at: diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp new file mode 100644 index 00000000000000..d32e7fe7284362 --- /dev/null +++ b/torch/nativert/graph/Serialization.cpp @@ -0,0 +1,548 @@ +#include +#include +#include +#include +#include +namespace torch::nativert { + +namespace { + +std::unique_ptr jsonToSubgraph( + const torch::_export::Graph& jsonGraph, + const torch::_export::GraphSignature* signature, + bool loadNodeMetadata); + +Value* symbolicToValue( + const torch::_export::Argument& arg, + Graph& graph, + Node* insertBefore) { + switch (arg.tag()) { + case torch::_export::Argument::Tag::AS_TENSOR: + return graph.getValue(arg.get_as_tensor().get_name()); + case torch::_export::Argument::Tag::AS_TENSORS: { + // Need to insert a list pack node + std::vector listValue; + for (const auto& listEl : arg.get_as_tensors()) { + listValue.push_back(graph.getValue(listEl.get_name())); + } + auto listPack = + graph.createListPack(std::move(listValue), Type::Kind::Tensor); + return graph.insertBefore(listPack, insertBefore)->outputs()[0]; + } + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: { + // Need to insert a list pack node + std::vector listValue; + for (const auto& listEl : arg.get_as_optional_tensors()) { + switch (listEl.tag()) { + case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: { + listValue.push_back( + graph.getValue(listEl.get_as_tensor().get_name())); + break; + } + case torch::_export::OptionalTensorArgument::Tag::AS_NONE: { + listValue.push_back( + graph.addValue(std::nullopt, Type::Kind::None, nullptr)); + break; + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unknown OptionalTensorArgument type: {}", + torch::_export::printEnum(listEl.tag()))); + } + } + auto listPack = graph.createOptionalListPack(std::move(listValue)); + return graph.insertBefore(listPack, insertBefore)->outputs()[0]; + } + case torch::_export::Argument::Tag::AS_SYM_INT: { + return graph.getValue(arg.get_as_sym_int().get_as_name()); + } + case torch::_export::Argument::Tag::AS_SYM_INTS: { + // Need to insert a list pack node + std::vector listValue; + for (const auto& listEl : arg.get_as_sym_ints()) { + switch (listEl.tag()) { + case torch::_export::SymIntArgument::Tag::AS_NAME: { + listValue.push_back(graph.getValue(listEl.get_as_name())); + break; + } + case torch::_export::SymIntArgument::Tag::AS_INT: { + // These are concrete int values in the SymIntList, e.g [s0, 8] + // We convert them into a constant Value in graph. These value + // doesn't have producer node + int64_t value = listEl.get_as_int(); + TORCH_CHECK( + value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()); + Value* symintValue = + graph.createConstantSymIntValue(static_cast(value)); + listValue.push_back(symintValue); + break; + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unknown SymIntArgument type: {}", + torch::_export::printEnum(listEl.tag()))); + } + } + auto listPack = + graph.createListPack(std::move(listValue), Type::Kind::SymInt); + return graph.insertBefore(listPack, insertBefore)->outputs()[0]; + } + case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: { + return graph.getValue(arg.get_as_custom_obj().get_name()); + } + case torch::_export::Argument::Tag::AS_SYM_BOOL: { + return graph.getValue(arg.get_as_sym_bool().get_as_name()); + } + case torch::_export::Argument::Tag::AS_SYM_FLOAT: { + return graph.getValue(arg.get_as_sym_float().get_as_name()); + } + default: + TORCH_CHECK( + false, + fmt::format( + "This function should only be called with symbolic arguments, got {} instead", + torch::_export::printEnum(arg.tag()))); + } +} + +std::pair< + std::vector, + std::vector> +enforceInputOrder( + const std::vector& inputSpecs, + const std::vector& graphInputs) { + // Enforce the order of inputSpecs and graphInputs to be the following: + // 1. token + // 2. parameter + // 3. persistent buffer, non-persistent buffer + // 4. tensor_constant + // 5. custom_obj + // 6. user_input/constant_input + std::vector reorderedInputSpecs; + std::vector reorderedGraphInputs; + std::vector desiredOrder = { + torch::_export::InputSpec::Tag::TOKEN, + torch::_export::InputSpec::Tag::PARAMETER, + torch::_export::InputSpec::Tag::BUFFER, + torch::_export::InputSpec::Tag::TENSOR_CONSTANT, + torch::_export::InputSpec::Tag::CUSTOM_OBJ}; + + auto reorder = [&](auto condition) { + for (size_t i = 0; i < inputSpecs.size(); ++i) { + if (condition(inputSpecs[i])) { + reorderedInputSpecs.push_back(inputSpecs[i]); + reorderedGraphInputs.push_back(graphInputs[i]); + } + } + }; + + for (const auto& tag : desiredOrder) { + if (tag == torch::_export::InputSpec::Tag::BUFFER) { + // Add persistent buffers first, then non-persistent + reorder([&](const auto& spec) { + return spec.tag() == tag && spec.get_buffer().get_persistent(); + }); + reorder([&](const auto& spec) { + return spec.tag() == tag && !spec.get_buffer().get_persistent(); + }); + } else { + reorder([&](const auto& spec) { return spec.tag() == tag; }); + } + } + + // Append USER_INPUT and CONSTANT_INPUT without reordering + for (size_t i = 0; i < inputSpecs.size(); ++i) { + auto tag = inputSpecs[i].tag(); + if (tag == torch::_export::InputSpec::Tag::USER_INPUT || + tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) { + reorderedInputSpecs.push_back(inputSpecs[i]); + reorderedGraphInputs.push_back(graphInputs[i]); + } + } + return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)}; +} + +std::unique_ptr jsonToSubgraph( + const torch::_export::Graph& jsonGraph, + const torch::_export::GraphSignature* signature, + bool loadNodeMetadata) { + auto graphInputs = jsonGraph.get_inputs(); + auto graph = Graph::createGraph(); + + if (signature) { + // enforcing the order signature inputspecs and graph inputs + const auto& inputSpecs = signature->get_input_specs(); + + auto [reorderedInputSpecs, reorderedGraphInputs] = + enforceInputOrder(inputSpecs, graphInputs); + + graphInputs = std::move(reorderedGraphInputs); + auto reorderedSignature = *signature; + reorderedSignature.set_input_specs(reorderedInputSpecs); + graph->setSignature(GraphSignature{reorderedSignature}); + } + + for (const auto& input : graphInputs) { + if (isSymbolic(input)) { + switch (input.tag()) { + case torch::_export::Argument::Tag::AS_TENSOR: { + const auto& asTensor = input.get_as_tensor(); + const auto& name = asTensor.get_name(); + graph->addInput(name, Type::Kind::Tensor); + break; + } + case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: { + const auto& asCustomObj = input.get_as_custom_obj(); + const std::string& name = asCustomObj.get_name(); + const std::string& classFqn = asCustomObj.get_class_fqn(); + graph->addInput(name, Type(Type::Kind::CustomObj, classFqn)); + break; + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unsupported symbolic graph input type: {}", + torch::_export::printEnum(input.tag()))); + } + } else { + switch (input.tag()) { + case torch::_export::Argument::Tag::AS_INT: + case torch::_export::Argument::Tag::AS_FLOAT: + case torch::_export::Argument::Tag::AS_STRING: + case torch::_export::Argument::Tag::AS_BOOL: + case torch::_export::Argument::Tag::AS_NONE: { + // Constant graph inputs are specialized in the graph, here we simply + // add a nullptr of Value to the graph input node. + graph->addInput(); + break; + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unsupported constant graph input type: {}", + torch::_export::printEnum(input.tag()))); + } + } + } + + for (const auto& jsonNode : jsonGraph.get_nodes()) { + auto node = graph->insertNode( + jsonNode.get_target(), + {}, + loadNodeMetadata ? jsonNode.get_metadata() + : std::unordered_map()); + + std::vector args; + std::vector attributes; + for (const auto& input : jsonNode.get_inputs()) { + // We handle constants and symbolic inputs differently. + const auto& arg = input.get_arg(); + if (isSymbolic(arg)) { + // Symbolic values are made part of the inputs to the node + node->addInput(NamedArgument{ + input.get_name(), symbolicToValue(input.get_arg(), *graph, node)}); + } else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) { + node->addInput(NamedArgument{ + input.get_name(), + graph->addValue(std::nullopt, Type::Kind::None, node)}); + } else { + node->addAttribute(Attribute{ + input.get_name(), + constantToValue(input.get_arg(), loadNodeMetadata)}); + // Constant values are added as "attributes" to the node. + } + } + + std::vector outputs; + std::vector listUnpacksToCreate; + for (const auto& output : jsonNode.get_outputs()) { + switch (output.tag()) { + case torch::_export::Argument::Tag::AS_NONE: { + node->addOutput(Type::Kind::None); + break; + } + case torch::_export::Argument::Tag::AS_TENSOR: { + const auto name = output.get_as_tensor().get_name(); + node->addOutput(name, Type::Kind::Tensor); + break; + } + case torch::_export::Argument::Tag::AS_TENSORS: { + auto outputValue = node->addOutput( + graph->getUniqueValueName(), Type::Kind::TensorList); + + Node* listUnpack = + graph->insertNode("prim.ListUnpack", {{"input", outputValue}}); + for (const auto& arg : output.get_as_tensors()) { + listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor); + } + break; + } + case torch::_export::Argument::Tag::AS_SYM_INT: { + const auto name = output.get_as_sym_int().get_as_name(); + node->addOutput(name, Type::Kind::SymInt); + break; + } + case torch::_export::Argument::Tag::AS_SYM_INTS: { + TORCH_CHECK( + false, + "SymInts NYI. We currently don't have ops that produce SymInts as output"); + } + case torch::_export::Argument::Tag::AS_SYM_BOOL: { + const auto name = output.get_as_sym_bool().get_as_name(); + node->addOutput(name, Type::Kind::SymBool); + break; + } + case torch::_export::Argument::Tag::AS_SYM_BOOLS: { + TORCH_CHECK( + false, + "SymBools NYI. We currently don't have ops that produce SymBools as output"); + } + case torch::_export::Argument::Tag::AS_SYM_FLOAT: { + const auto name = output.get_as_sym_float().get_as_name(); + node->addOutput(name, Type::Kind::SymFloat); + break; + } + case torch::_export::Argument::Tag::AS_SYM_FLOATS: { + TORCH_CHECK( + false, + "SymFloats NYI. We currently doesn't have op that produces SymFloats as output"); + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unsupported graph output type: {}", + torch::_export::printEnum(output.tag()))); + } + } + } + + for (const auto& output : jsonGraph.get_outputs()) { + // handle symbolic outputs and constant outputs differently + if (isSymbolic(output)) { + switch (output.tag()) { + case torch::_export::Argument::Tag::AS_TENSOR: { + const auto& asTensor = output.get_as_tensor(); + const auto& name = asTensor.get_name(); + Value* outputValue = graph->getValue(name); + graph->addOutput(outputValue); + break; + } + case torch::_export::Argument::Tag::AS_SYM_INT: { + const auto& asSymInt = output.get_as_sym_int(); + TORCH_CHECK( + asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME); + const auto& name = asSymInt.get_as_name(); + Value* outputValue = graph->getValue(name); + graph->addOutput(outputValue); + break; + } + default: + TORCH_CHECK( + false, + fmt::format( + "Unsupported graph output type: {}", + torch::_export::printEnum(output.tag()))); + } + } else { + Constant constValue = constantToValue(output, loadNodeMetadata); + graph->addConstantOutput(std::move(constValue)); + } + } + + auto jsonTensorValue = jsonGraph.get_tensor_values(); + + if (!signature) { + // For subgraphs we just need to derive a graph signature that only + // contains user inputs and outputs, because we don't need to handle any + // special semantics for them, e.g. mutation or gradients. + torch::_export::GraphSignature sig; + std::vector inputSpecs; + for (const auto& input : graph->inputs()) { + torch::_export::Argument arg; + if (input->type().kind() == Type::Kind::Tensor) { + torch::_export::TensorArgument targ; + targ.set_name(std::string{input->name()}); + arg.set_as_tensor(std::move(targ)); + } else { + TORCH_CHECK( + false, + fmt::format( + "Unsupported subgraph input type {}", + fmt::streamed(input->type()))); + } + torch::_export::UserInputSpec userInputSpec; + userInputSpec.set_arg(std::move(arg)); + torch::_export::InputSpec inputSpec; + inputSpec.set_user_input(std::move(userInputSpec)); + inputSpecs.push_back(std::move(inputSpec)); + } + sig.set_input_specs(std::move(inputSpecs)); + + std::vector outputSpecs; + for (const auto& output : graph->outputs()) { + torch::_export::Argument arg; + if (output->type().kind() == Type::Kind::Tensor) { + torch::_export::TensorArgument targ; + targ.set_name(std::string{output->name()}); + arg.set_as_tensor(std::move(targ)); + } else { + TORCH_CHECK( + false, + fmt::format( + "Unsupported subgraph output type {}", + fmt::streamed(output->type()))); + } + torch::_export::UserOutputSpec userOutputSpec; + userOutputSpec.set_arg(std::move(arg)); + torch::_export::OutputSpec outputSpec; + outputSpec.set_user_output(std::move(userOutputSpec)); + outputSpecs.push_back(std::move(outputSpec)); + } + sig.set_output_specs(std::move(outputSpecs)); + + graph->setSignature(GraphSignature{sig}); + } + + // weightsTensorMeta are indexed by weight's name, not graph input's name + std::unordered_map weightsTensorMeta; + for (const auto& [inputName, weightName] : + graph->signature().inputsToWeights()) { + auto value = graph->getValue(inputName); + if (value->type().kind() == Type::Kind::CustomObj) { + // skip setting meta for non-tensor inputs + continue; + } + + auto it = jsonTensorValue.find(inputName); + CHECK(it != jsonTensorValue.end()) + << "Missing tensor metadata for " << inputName + << "in thriftGraph.tensorValue"; + weightsTensorMeta[weightName] = it->second; + } + graph->setWeightsMeta(weightsTensorMeta); + + graph->setTensorValuesMeta(jsonTensorValue); + + graph->finalize(); + + graph->lint(); + return graph; +} + +} // namespace + +bool isSymbolic(const torch::_export::Argument& arg) { + switch (arg.tag()) { + case torch::_export::Argument::Tag::AS_TENSOR: + case torch::_export::Argument::Tag::AS_TENSORS: + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: + case torch::_export::Argument::Tag::AS_SYM_INT: + case torch::_export::Argument::Tag::AS_SYM_INTS: + case torch::_export::Argument::Tag::AS_SYM_BOOL: + case torch::_export::Argument::Tag::AS_SYM_BOOLS: + case torch::_export::Argument::Tag::AS_SYM_FLOAT: + case torch::_export::Argument::Tag::AS_SYM_FLOATS: + case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: + return true; + default: + return false; + } +} + +Constant constantToValue( + const torch::_export::Argument& jsonArg, + bool loadNodeMetadata) { + switch (jsonArg.tag()) { + case torch::_export::Argument::Tag::AS_NONE: + return None(); + case torch::_export::Argument::Tag::AS_INT: + return jsonArg.get_as_int(); + case torch::_export::Argument::Tag::AS_INTS: { + std::vector ret; + for (const auto& arg : jsonArg.get_as_ints()) { + ret.push_back(arg); + } + return ret; + } + case torch::_export::Argument::Tag::AS_FLOAT: + return jsonArg.get_as_float().get(); + case torch::_export::Argument::Tag::AS_FLOATS: { + std::vector ret; + for (const auto& arg : jsonArg.get_as_floats()) { + ret.push_back(arg.get()); + } + return ret; + } + case torch::_export::Argument::Tag::AS_STRING: + return jsonArg.get_as_string(); + case torch::_export::Argument::Tag::AS_STRINGS: { + std::vector ret; + for (const auto& arg : jsonArg.get_as_strings()) { + ret.push_back(arg); + } + return ret; + } + case torch::_export::Argument::Tag::AS_SCALAR_TYPE: + return convertJsonScalarType(jsonArg.get_as_scalar_type()); + case torch::_export::Argument::Tag::AS_MEMORY_FORMAT: + return convertJsonMemoryFormat(jsonArg.get_as_memory_format()); + case torch::_export::Argument::Tag::AS_LAYOUT: + return convertJsonLayout(jsonArg.get_as_layout()); + case torch::_export::Argument::Tag::AS_DEVICE: + return convertJsonDevice(jsonArg.get_as_device()); + case torch::_export::Argument::Tag::AS_BOOL: + return jsonArg.get_as_bool(); + case torch::_export::Argument::Tag::AS_BOOLS: { + std::vector ret; + for (const auto& arg : jsonArg.get_as_bools()) { + ret.push_back(arg); + } + return ret; + } + case torch::_export::Argument::Tag::AS_GRAPH: { + return jsonToSubgraph( + *jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata); + } + case torch::_export::Argument::Tag::AS_TENSOR: + case torch::_export::Argument::Tag::AS_TENSORS: + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: + TORCH_CHECK(false, "Tensor values are symbolic, not constant."); + case torch::_export::Argument::Tag::AS_SYM_INT: + case torch::_export::Argument::Tag::AS_SYM_INTS: + case torch::_export::Argument::Tag::AS_SYM_BOOL: + case torch::_export::Argument::Tag::AS_SYM_BOOLS: + TORCH_CHECK(false, "Symint/Symbool Values are symbolic, not constant."); + case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: + TORCH_CHECK(false, "custom obj is symbolic, not constant"); + case torch::_export::Argument::Tag::AS_OPERATOR: + return jsonArg.get_as_operator(); + case torch::_export::Argument::Tag::AS_SYM_FLOAT: { + TORCH_CHECK(false, "SymFloat is not yet implemented"); + } + case torch::_export::Argument::Tag::AS_SYM_FLOATS: { + TORCH_CHECK(false, "SymFloats is not yet implemented"); + } + default: + TORCH_CHECK(false, "Got unknown json argument"); + } +} + +std::unique_ptr jsonToGraph( + const torch::_export::GraphModule& jsonGraphModule, + bool loadNodeMetadata) { + auto graph = jsonToSubgraph( + jsonGraphModule.get_graph(), + &jsonGraphModule.get_signature(), + loadNodeMetadata); + return graph; +} + +} // namespace torch::nativert diff --git a/torch/nativert/graph/Serialization.h b/torch/nativert/graph/Serialization.h new file mode 100644 index 00000000000000..6604bfbc516527 --- /dev/null +++ b/torch/nativert/graph/Serialization.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include + +namespace torch::nativert { +/** + * This file contains serialization utilities for Graph. + * + * There are two serialized representations we care about: + * - Json: stable but hard to work with, not really human readable + * - Debug format: human-readable, not stable. + */ + +// Json -> Graph +std::unique_ptr jsonToGraph( + const torch::_export::GraphModule& jsonGraph, + bool loadNodeMetadata = true); + +bool isSymbolic(const torch::_export::Argument& arg); + +Constant constantToValue( + const torch::_export::Argument& jsonArg, + bool loadNodeMetadata); + +} // namespace torch::nativert diff --git a/torch/nativert/graph/TensorMeta.cpp b/torch/nativert/graph/TensorMeta.cpp index c42cb6b39d9ebd..d7d83710a5a35c 100644 --- a/torch/nativert/graph/TensorMeta.cpp +++ b/torch/nativert/graph/TensorMeta.cpp @@ -116,7 +116,7 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) numel_ *= val; } else if (size.tag() == torch::_export::SymInt::Tag::AS_EXPR) { // TODO: it's still unclear how SymInt shape should be used in runtime - // One potential use cases is for verifing inputs shape matches constrain + // One potential use cases is for verifying inputs shape matches constrain // This would require unpacking the serialized constrain, which is NYI // // For the time being, we just set the symbolic dim to -1 diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp new file mode 100644 index 00000000000000..cbbd502d821528 --- /dev/null +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp @@ -0,0 +1,64 @@ +#include + +#include + +#include + +namespace torch::nativert { + +UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node) + : OpKernel(node), + op_(getOperatorForTarget( + std::get(node->attributes()[0].value))), + schema_(op_.schema()), + arguments_(prefillStackWithStaticArgs(node, schema_)) { + for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) { + if (schemaArg.alias_info() != nullptr && + schemaArg.alias_info()->isWrite()) { + mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value); + } + } + + numOutputs_ = schema_.returns().size(); +} + +void UnsafeAutoFunctionalizeKernel::computeInternal( + ExecutionFrame& executionFrame) const { + // Make a copy of the stack + std::vector stack = arguments_.getStackWithStaticArgs(); + + fillDynamicInputs(executionFrame, arguments_, stack); + + // Call the op with the prepared stack. + try { + op_.callBoxed(stack); + } catch (const std::exception& ex) { + // TODO: this eats the original exception type. ATen returns different + // exception types that correspond to different Python errors (e.g. + // IndexError, ValueError). If retaining this information is important + // to us, we'll have to change this up a little. + auto stackTrace = node_->getMetadata("stack_trace"); + throw std::runtime_error(fmt::format( + "Original Python stacktrace:\n{}\n{}", + stackTrace ? *stackTrace : "", + ex.what())); + } + + const auto& outputValues = node_->outputs(); + + for (int i = 0; i < numOutputs_; ++i) { + executionFrame.setIValue(outputValues[i]->id(), std::move(stack.at(i))); + } + + // Copy over mutating inputs to outputs + int mutatingArgStartIndex = (numOutputs_ == 0) ? 1 : numOutputs_; + for (size_t i = mutatingArgStartIndex; i < outputValues.size(); ++i) { + executionFrame.setIValue( + outputValues[i]->id(), + executionFrame.getIValue( + mutatingInputArgs_.at(i - mutatingArgStartIndex)->id(), + true /* allowNone */)); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.h b/torch/nativert/kernels/AutoFunctionalizeKernel.h new file mode 100644 index 00000000000000..f9d6e6e58c6c97 --- /dev/null +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +class UnsafeAutoFunctionalizeKernel : public OpKernel { + public: + UnsafeAutoFunctionalizeKernel() = delete; // deleted default constructor + UnsafeAutoFunctionalizeKernel(const Node* node); + + void computeInternal(ExecutionFrame& executionFrame) const override final; + + private: + c10::OperatorHandle op_; + c10::FunctionSchema schema_; + + Arguments arguments_; + + std::vector mutatingInputArgs_; + int numOutputs_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/C10Kernel.cpp b/torch/nativert/kernels/C10Kernel.cpp new file mode 100644 index 00000000000000..450042e7c92d3b --- /dev/null +++ b/torch/nativert/kernels/C10Kernel.cpp @@ -0,0 +1,265 @@ +#include + +#include + +#include + +#ifdef __SIGRID_USE_GPU__ +#include +#include +#endif + +namespace torch::nativert { + +C10Kernel::C10Kernel( + const Node* node, + c10::Device device, + OpKernelKind kind, + AliasingSpec&& aliasingSpec) + : OpKernel(node, device, kind), + op_(getOperatorForTarget(node->target(), node)), + schema_(op_.schema(), std::move(aliasingSpec), kind_), + arguments_(prefillStackWithStaticArgs(node, op_.schema())) {} + +void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const { + // Make a copy of the stack + std::vector stack = arguments_.getStackWithStaticArgs(); + + fillDynamicInputs(executionFrame, arguments_, stack); + + // Call the op with the prepared stack. + try { + op_.callBoxed(stack); + } catch (const std::exception& ex) { + auto stackTrace = node_->getMetadata("stack_trace"); + throw std::runtime_error(fmt::format( + "Exception while executing node: {}\n" + "with args:\n{}\n" + "{}\n" + "Original Python stacktrace:\n{}", + fmt::streamed(*node_), + readableArgs(op_.schema(), stack), + ex.what(), + stackTrace ? *stackTrace : "")); + } + + // Write out results + // TODO: we store intermediates in a single table (symint and tensor alike). + // This can theoretically lead to name collisions, although based on how + // these are named I don't think it will ever happen in practice. We need to + // enforce it though. + const auto& outputValues = node_->outputs(); + TORCH_CHECK_EQ(outputValues.size(), stack.size()) + << "Output size mismatch for " << node_->toString(); + for (auto&& [i, actualOutput] : c10::enumerate(stack)) { + executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput)); + } +} + +namespace { +std::unordered_map getSymInputs( + const ExecutionFrame& executionFrame, + const Node& node) { + std::unordered_map inputs; + for (const auto& input : node.inputs()) { + const auto& val = executionFrame.getIValue(input.value->id()); + if (val.isInt() || val.isDouble() || val.isBool()) { + inputs[input.name] = val; + } else { + throw std::runtime_error("unsupported type for symbolic input"); + } + } + for (const auto& attribute : node.attributes()) { + if (std::holds_alternative(attribute.value)) { + inputs[attribute.name] = std::get(attribute.value); + } else if (std::holds_alternative(attribute.value)) { + inputs[attribute.name] = std::get(attribute.value); + } else if (std::holds_alternative(attribute.value)) { + inputs[attribute.name] = std::get(attribute.value); + } else { + throw std::runtime_error("unsupported type for symbolic input"); + } + } + return inputs; +} + +template +void computeScalarBinaryOp( + ExecutionFrame& executionFrame, + const Node& node, + std::enable_if_t a, + std::enable_if_t b) { + std::string_view target = node.target(); + T out; + + if (target == "_operator.add") { + out = a + b; + } else if (target == "_operator.sub") { + out = a - b; + } else if (target == "_operator.mul") { + out = a * b; + } else if (target == "_operator.pow") { + out = std::pow(a, b); + } else { + throw std::runtime_error( + fmt::format("unsupported operator for symbolic values: {}", target)); + } + + executionFrame.setIValue(node.outputs()[0]->id(), out); + VLOG(2) << fmt::format( + "Completed executing node: {} with a={}, b={}, out={}", + fmt::streamed(node), + a, + b, + out); +} + +} // namespace + +void ScalarBinaryOpKernel::computeInternal( + ExecutionFrame& executionFrame) const { + auto inputs = getSymInputs(executionFrame, *node_); + + const auto& a = inputs.at("a"); + const auto& b = inputs.at("b"); + + auto coerceToDouble = [](const c10::IValue& x) -> double { + if (x.isInt()) { + return static_cast(x.toInt()); + } else if (x.isDouble()) { + return x.toDouble(); + } else { + throw std::runtime_error("unsupported type for symbolic input"); + } + }; + + if (a.isInt() && b.isInt()) { + computeScalarBinaryOp( + executionFrame, *node_, a.toInt(), b.toInt()); + } else { + computeScalarBinaryOp( + executionFrame, *node_, coerceToDouble(a), coerceToDouble(b)); + } +} + +void SymIntOpKernel::computeInternal(ExecutionFrame& executionFrame) const { + auto inputs = getSymInputs(executionFrame, *node_); + + int64_t a = inputs.at("a").toInt(); + std::string_view target = node_->target(); + if (target == "torch.sym_float") { + double out = static_cast(a); + executionFrame.setIValue(node_->outputs()[0]->id(), out); + VLOG(2) << fmt::format( + "Completed executing node: {} with a={}, out={}", + fmt::streamed(*node_), + a, + out); + return; + } + int64_t b = inputs.at("b").toInt(); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t out; + + if (target == "_operator.floordiv") { + out = a / b; + } else if (target == "_operator.mod") { + out = a % b; + } else if (target == "torch.sym_max") { + out = std::max(a, b); + } else if (target == "torch.sym_min") { + out = std::min(a, b); + } else { + throw std::runtime_error( + fmt::format("unsupported operator for SymInt: {}", node_->target())); + } + + executionFrame.setIValue(node_->outputs()[0]->id(), out); + VLOG(2) << fmt::format( + "Completed executing node: {} with a={}, b={}, out={}", + fmt::streamed(*node_), + a, + b, + out); +} + +void SymBoolOpKernel::computeInternal(ExecutionFrame& executionFrame) const { + auto inputs = getSymInputs(executionFrame, *node_); + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool out; + + const std::string_view target = node_->target(); + if (target == "torch.sym_not") { + bool a = inputs.at("a").toBool(); + out = !a; + } else if (target == "_operator.ge") { + int64_t a = inputs.at("a").toInt(); + int64_t b = inputs.at("b").toInt(); + out = a >= b; + } else if (target == "_operator.le") { + int64_t a = inputs.at("a").toInt(); + int64_t b = inputs.at("b").toInt(); + out = a <= b; + } else if (target == "_operator.eq") { + int64_t a = inputs.at("a").toInt(); + int64_t b = inputs.at("b").toInt(); + out = a == b; + } else if (target == "_operator.gt") { + int64_t a = inputs.at("a").toInt(); + int64_t b = inputs.at("b").toInt(); + out = a > b; + } else if (target == "_operator.lt") { + int64_t a = inputs.at("a").toInt(); + int64_t b = inputs.at("b").toInt(); + out = a < b; + } else if (target == "_operator.and_") { + bool a = inputs.at("a").toBool(); + bool b = inputs.at("b").toBool(); + out = a && b; + } else { + throw std::runtime_error( + fmt::format("unsupported operator for SymBool: {}", node_->target())); + } + + executionFrame.setIValue(node_->outputs()[0]->id(), out); +} + +void SymFloatOpKernel::computeInternal(ExecutionFrame& executionFrame) const { + auto inputs = getSymInputs(executionFrame, *node_); + + const std::string_view target = node_->target(); + if (target == "math.trunc") { + double x = inputs.at("x").toDouble(); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + int64_t out = trunc(x); + executionFrame.setIValue(node_->outputs()[0]->id(), out); + } else if (target == "torch._sym_sqrt") { + double a = inputs.at("a").toDouble(); + double out = std::sqrt(a); + executionFrame.setIValue(node_->outputs()[0]->id(), out); + } else if (target == "_operator.neg") { + auto a = inputs.at("a"); + c10::IValue out; + if (a.isInt()) { + out = -a.toInt(); + } else if (a.isDouble()) { + out = -a.toDouble(); + } else { + throw std::runtime_error("unsupported type for symbolic input"); + } + executionFrame.setIValue(node_->outputs()[0]->id(), out); + } else if (target == "_operator.truediv") { + auto ia = inputs.at("a"); + double a = ia.isInt() ? static_cast(ia.toInt()) : ia.toDouble(); + auto ib = inputs.at("b"); + double b = ib.isInt() ? static_cast(ib.toInt()) : ib.toDouble(); + double out = a / b; + executionFrame.setIValue(node_->outputs()[0]->id(), out); + } else { + throw std::runtime_error( + fmt::format("unsupported operator for SymFloat: {}", node_->target())); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/C10Kernel.h b/torch/nativert/kernels/C10Kernel.h new file mode 100644 index 00000000000000..0f23096dd1f2a9 --- /dev/null +++ b/torch/nativert/kernels/C10Kernel.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nativert { + +// Implementation of Kernel for ATen operators +// +// This class exists to amortize per-kernel overhead by computing things during +// initialization instead of on every execution. Right now we are only +// amortizing schema resolution, and static arguments parsing, +// but in the future this could be extended to avoid operator dispatch and +// do better "Register" allocation (e.g. convert input/outputs to directly +// array accesses onto a set of registers, in concert with memory planning) +class C10Kernel : public OpKernel { + public: + C10Kernel() = delete; // deleted default constructor + C10Kernel( + const Node* node, + c10::Device device, + OpKernelKind kind = OpKernelKind::kInterpreterFallbackKernel, + AliasingSpec&& aliasingSpec = {}); + virtual ~C10Kernel() override = default; + + [[nodiscard]] const c10::IValue& input( + uint32_t i, + ExecutionFrame& executionFrame) const override { + if (Value* dynamicArg = arguments_.findDynamic(i)) { + return executionFrame.getIValue(dynamicArg->id()); + } + return attribute(i); + } + + [[nodiscard]] const c10::IValue& attribute(uint32_t i) const { + return arguments_.getStatic(i); + } + + C10_ALWAYS_INLINE const FunctionSchema& schema() const { + return schema_; + } + + void computeInternal(ExecutionFrame& executionFrame) const override; + + private: + c10::OperatorHandle op_; + FunctionSchema schema_; + + Arguments arguments_; +}; + +class SymIntOpKernel : public OpKernel { + public: + explicit SymIntOpKernel(const Node* node) : OpKernel(node) {} + void computeInternal(ExecutionFrame& executionFrame) const override final; +}; + +class SymBoolOpKernel : public OpKernel { + public: + explicit SymBoolOpKernel(const Node* node) : OpKernel(node) {} + void computeInternal(ExecutionFrame& executionFrame) const override final; +}; + +class SymFloatOpKernel : public OpKernel { + public: + explicit SymFloatOpKernel(const Node* node) : OpKernel(node) {} + void computeInternal(ExecutionFrame& executionFrame) const override final; +}; + +// ScalarOpKernel does binary arithmetic operations on scalar values. +// Integers and floats are supported as input types. The output will be +// promoted to float if and only if there's at least one float input. +class ScalarBinaryOpKernel : public OpKernel { + public: + explicit ScalarBinaryOpKernel(const Node* node) : OpKernel(node) {} + void computeInternal(ExecutionFrame& executionFrame) const override final; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/CallTorchBindKernel.cpp b/torch/nativert/kernels/CallTorchBindKernel.cpp new file mode 100644 index 00000000000000..5e8c9cf6be7598 --- /dev/null +++ b/torch/nativert/kernels/CallTorchBindKernel.cpp @@ -0,0 +1,51 @@ +#include + +#include + +#include + +namespace torch::nativert { + +CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) { + const Value* customObjValue = node_->inputs()[0].value; + CHECK(customObjValue->type() == Type::Kind::CustomObj); + + customClassName_ = customObjValue->type().classFqn(); + customClassType_ = torch::jit::getCustomClass(customClassName_); + + // sample schema + // torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1); + + CHECK(node->attributes().size() == 1) + << "Expects higher_order.call_torchbind to only have a single attribute, methodName"; + const auto& attr = node->attributes()[0]; + + CHECK(std::holds_alternative(attr.value)) + << "method should be a string"; + methodName_ = std::get(attr.value); + method_ = customClassType_->findMethod(methodName_); + + CHECK(method_ != nullptr) << "method not found: " << methodName_; +} + +void CallTorchBindKernel::computeInternal( + ExecutionFrame& executionFrame) const { + // prepare inputs + std::vector stack; + for (const auto& input : node_->inputs()) { + const auto& id = input.value->id(); + stack.emplace_back(executionFrame.getIValue(id)); + } + + // call the method + method_->run(stack); + + // set outputs + const auto& outputs = node_->outputs(); + TORCH_CHECK_EQ(outputs.size(), stack.size()); + for (auto&& [i, outputValue] : c10::enumerate(stack)) { + executionFrame.setIValue(outputs[i]->id(), std::move(outputValue)); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/CallTorchBindKernel.h b/torch/nativert/kernels/CallTorchBindKernel.h new file mode 100644 index 00000000000000..acddf019d387e4 --- /dev/null +++ b/torch/nativert/kernels/CallTorchBindKernel.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::nativert { + +class CallTorchBindKernel : public OpKernel { + public: + CallTorchBindKernel() = delete; // deleted default constructor + CallTorchBindKernel(const Node* node); + + void computeInternal(ExecutionFrame& executionFrame) const override final; + + private: + std::string methodName_; + torch::jit::Function* method_; + + std::string customClassName_; + at::ClassTypePtr customClassType_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp new file mode 100644 index 00000000000000..c33fb81604f6fb --- /dev/null +++ b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp @@ -0,0 +1,354 @@ +// @generated +// @lint-ignore-every CLANGTIDY HOWTOEVEN +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch::nativert { + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.view_as_real.default", + aten_view_as_real_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::view_as_real(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.view_as_complex.default", + aten_view_as_complex_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::view_as_complex(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.real.default", aten_real_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::real(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.imag.default", aten_imag_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::imag(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten._conj.default", aten__conj_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::_conj(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.conj.default", aten_conj_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::conj(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.resolve_conj.default", + aten_resolve_conj_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::resolve_conj(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.resolve_neg.default", + aten_resolve_neg_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::resolve_neg(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten._neg_view.default", + aten__neg_view_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::_neg_view(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.diagonal.default", + aten_diagonal_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto offset = KernelInput(1).toInt(); + const auto dim1 = KernelInput(2).toInt(); + const auto dim2 = KernelInput(3).toInt(); + KernelOutput(0) = at::native::diagonal(self, offset, dim1, dim2); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.linalg_diagonal.default", + aten_linalg_diagonal_default, + { + const auto& A = KernelInput(0).toTensor(); + const auto offset = KernelInput(1).toInt(); + const auto dim1 = KernelInput(2).toInt(); + const auto dim2 = KernelInput(3).toInt(); + KernelOutput(0) = at::native::linalg_diagonal(A, offset, dim1, dim2); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.expand_as.default", + aten_expand_as_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + KernelOutput(0) = at::native::expand_as(self, other); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.flatten.using_ints", + aten_flatten_using_ints, + { + const auto& self = KernelInput(0).toTensor(); + const auto start_dim = KernelInput(1).toInt(); + const auto end_dim = KernelInput(2).toInt(); + KernelOutput(0) = at::native::flatten(self, start_dim, end_dim); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.movedim.int", aten_movedim_int, { + const auto& self = KernelInput(0).toTensor(); + const auto source = KernelInput(1).toInt(); + const auto destination = KernelInput(2).toInt(); + KernelOutput(0) = at::native::movedim(self, source, destination); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.moveaxis.int", aten_moveaxis_int, { + const auto& self = KernelInput(0).toTensor(); + const auto source = KernelInput(1).toInt(); + const auto destination = KernelInput(2).toInt(); + KernelOutput(0) = at::native::moveaxis(self, source, destination); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.numpy_T.default", + aten_numpy_T_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::numpy_T(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.matrix_H.default", + aten_matrix_H_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::matrix_H(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mT.default", aten_mT_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::mT(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mH.default", aten_mH_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::mH(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.adjoint.default", + aten_adjoint_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::adjoint(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.ravel.default", aten_ravel_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::ravel(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.reshape_as.default", + aten_reshape_as_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + KernelOutput(0) = at::native::reshape_as(self, other); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.detach.default", + aten_detach_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::detach(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.squeeze.default", + aten_squeeze_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::squeeze(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.squeeze.dim", aten_squeeze_dim, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + KernelOutput(0) = at::native::squeeze(self, dim); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.t.default", aten_t_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::t(self); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.transpose.int", aten_transpose_int, { + const auto& self = KernelInput(0).toTensor(); + const auto dim0 = KernelInput(1).toInt(); + const auto dim1 = KernelInput(2).toInt(); + KernelOutput(0) = at::native::transpose(self, dim0, dim1); + return; +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.unsqueeze.default", + aten_unsqueeze_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + KernelOutput(0) = at::native::unsqueeze(self, dim); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.view_as.default", + aten_view_as_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + KernelOutput(0) = at::native::view_as(self, other); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.positive.default", + aten_positive_default, + { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::positive(self); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten._autocast_to_reduced_precision.default", + aten__autocast_to_reduced_precision_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto cuda_enabled = KernelInput(1).toBool(); + const auto cpu_enabled = KernelInput(2).toBool(); + const auto cuda_dtype = KernelInput(3).toScalarType(); + const auto cpu_dtype = KernelInput(4).toScalarType(); + KernelOutput(0) = at::native::_autocast_to_reduced_precision( + self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten._autocast_to_full_precision.default", + aten__autocast_to_full_precision_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto cuda_enabled = KernelInput(1).toBool(); + const auto cpu_enabled = KernelInput(2).toBool(); + KernelOutput(0) = at::native::_autocast_to_full_precision( + self, cuda_enabled, cpu_enabled); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.swapaxes.default", + aten_swapaxes_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto axis0 = KernelInput(1).toInt(); + const auto axis1 = KernelInput(2).toInt(); + KernelOutput(0) = at::native::swapaxes(self, axis0, axis1); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.swapdims.default", + aten_swapdims_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim0 = KernelInput(1).toInt(); + const auto dim1 = KernelInput(2).toInt(); + KernelOutput(0) = at::native::swapdims(self, dim0, dim1); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.unfold.default", + aten_unfold_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dimension = KernelInput(1).toInt(); + const auto size = KernelInput(2).toInt(); + const auto step = KernelInput(3).toInt(); + KernelOutput(0) = at::native::unfold(self, dimension, size, step); + return; + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.alias.default", aten_alias_default, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::alias(self); + return; +}); + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp new file mode 100644 index 00000000000000..986eb060cb0fbb --- /dev/null +++ b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp @@ -0,0 +1,3211 @@ +// @generated +// @lint-ignore-every CLANGTIDY HOWTOEVEN +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch::nativert { + +REGISTER_CPU_KERNEL("torch.ops.aten.absolute.default", aten_absolute_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::absolute(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::absolute_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.angle.default", aten_angle_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::angle(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::angle_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sgn.default", aten_sgn_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sgn(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sgn_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.acos.default", aten_acos_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::acos(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::acos_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arccos.default", aten_arccos_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arccos(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arccos_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.add.Tensor", aten_add_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::add(self, other, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::add_out(out, self, other, alpha); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.add.Scalar", aten_add_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + const auto alpha = KernelInput(2).toScalar(); + if (auto& out = KernelOutput(0); out.isNone()) { + out = create_empty_from(self); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::add_out(out_t, self, other, alpha); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten._add_relu.Tensor", aten__add_relu_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::add_relu(self, other, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::add_relu_out(self, other, alpha, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addmv.default", aten_addmv_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& mat = KernelInput(1).toTensor(); + const auto& vec = KernelInput(2).toTensor(); + const auto beta = KernelInput(3).toScalar(); + const auto alpha = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::addmv(self, mat, vec, beta, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::addmv_out(out, self, mat, vec, beta, alpha); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addr.default", aten_addr_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& vec1 = KernelInput(1).toTensor(); + const auto& vec2 = KernelInput(2).toTensor(); + const auto beta = KernelInput(3).toScalar(); + const auto alpha = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::addr(self, vec1, vec2, beta, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::addr_out(self, vec1, vec2, beta, alpha, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.all.dim", aten_all_dim, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto keepdim = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::all(self, dim, keepdim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::all_out(out, self, dim, keepdim); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.any.dim", aten_any_dim, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto keepdim = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::any(self, dim, keepdim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::any_out(out, self, dim, keepdim); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.argmax.default", aten_argmax_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toOptional(); + const auto keepdim = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::argmax(self, dim, keepdim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::argmax_out(out, self, dim, keepdim); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.acosh.default", aten_acosh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::acosh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::acosh_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.asinh.default", aten_asinh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::asinh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::asinh_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arcsinh.default", aten_arcsinh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arcsinh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arcsinh_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.atanh.default", aten_atanh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::atanh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::atanh_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arctanh.default", aten_arctanh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arctanh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arctanh_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.asin.default", aten_asin_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::asin(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::asin_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arcsin.default", aten_arcsin_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arcsin(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arcsin_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.atan.default", aten_atan_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::atan(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::atan_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arctan.default", aten_arctan_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arctan(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arctan_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.baddbmm.default", aten_baddbmm_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& batch1 = KernelInput(1).toTensor(); + const auto& batch2 = KernelInput(2).toTensor(); + const auto beta = KernelInput(3).toScalar(); + const auto alpha = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::baddbmm(self, batch1, batch2, beta, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::baddbmm_out(out, self, batch1, batch2, beta, alpha); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_not.default", + aten_bitwise_not_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_not(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_not_out(out, self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.copysign.Tensor", aten_copysign_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::copysign(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::copysign_out(out, self, other); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logical_not.default", + aten_logical_not_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::logical_not(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::logical_not_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logical_xor.default", + aten_logical_xor_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::logical_xor(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::logical_xor_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logical_and.default", + aten_logical_and_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::logical_and(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::logical_and_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logical_or.default", + aten_logical_or_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::logical_or(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::logical_or_out(self, other, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.ceil.default", aten_ceil_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::ceil(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::ceil_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.clamp.default", aten_clamp_default, { + const auto& self = KernelInput(0).toTensor(); + const auto min = KernelInput(1).toOptional(); + const auto max = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp(self, min, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::clamp_out(out, self, min, max); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.clamp.Tensor", aten_clamp_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto min = KernelInput(1).toOptional(); + const auto max = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp(self, min, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::clamp_out(out, self, min, max); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.clamp_max.default", + aten_clamp_max_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto max = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp_max(self, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::clamp_max_out(out, self, max); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.clamp_max.Tensor", aten_clamp_max_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& max = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp_max(self, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::clamp_max_out(out, self, max); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.clip.default", aten_clip_default, { + const auto& self = KernelInput(0).toTensor(); + const auto min = KernelInput(1).toOptional(); + const auto max = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::clip(self, min, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::clip_out(self, min, max, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.complex.default", aten_complex_default, { + const auto& real = KernelInput(0).toTensor(); + const auto& imag = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::complex(real, imag); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::complex_out(real, imag, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.polar.default", aten_polar_default, { + const auto& abs = KernelInput(0).toTensor(); + const auto& angle = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::polar(abs, angle); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::polar_out(abs, angle, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cos.default", aten_cos_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cos(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cos_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cosh.default", aten_cosh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cosh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cosh_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cumprod.default", aten_cumprod_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto dtype = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cumprod(self, dim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cumprod_out(out, self, dim, dtype); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.diff.default", aten_diff_default, { + const auto& self = KernelInput(0).toTensor(); + const auto n = KernelInput(1).toInt(); + const auto dim = KernelInput(2).toInt(); + const auto prepend = KernelInput(3).toOptional(); + const auto append = KernelInput(4).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::diff(self, n, dim, prepend, append); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::diff_out(self, n, dim, prepend, append, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor", aten_div_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::div(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::div_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor_mode", aten_div_Tensor_mode, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto rounding_mode = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::div(self, other, rounding_mode); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::div_out(out, self, other, rounding_mode); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.divide.Tensor", aten_divide_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::divide(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::divide_out(self, other, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.true_divide.Tensor", + aten_true_divide_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::true_divide(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::true_divide_out(self, other, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.dot.default", aten_dot_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& tensor = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::dot(self, tensor); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::dot_out(self, tensor, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.vdot.default", aten_vdot_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::vdot(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::vdot_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.erf.default", aten_erf_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::erf(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::erf_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.erfc.default", aten_erfc_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::erfc(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::erfc_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.exp.default", aten_exp_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::exp(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::exp_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.exp2.default", aten_exp2_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::exp2(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::exp2_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.expm1.default", aten_expm1_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::expm1(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::expm1_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.floor.default", aten_floor_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::floor(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::floor_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.frac.default", aten_frac_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::frac(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::frac_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.gcd.default", aten_gcd_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::gcd(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::gcd_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lcm.default", aten_lcm_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lcm(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lcm_out(out, self, other); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.index_copy.default", + aten_index_copy_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& source = KernelInput(3).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::index_copy(self, dim, index, source); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::index_copy_out(out, self, dim, index, source); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.isin.Tensor_Tensor", + aten_isin_Tensor_Tensor, + { + const auto& elements = KernelInput(0).toTensor(); + const auto& test_elements = KernelInput(1).toTensor(); + const auto assume_unique = KernelInput(2).toBool(); + const auto invert = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::isin(elements, test_elements, assume_unique, invert); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::isin_out(out, elements, test_elements, assume_unique, invert); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.isin.Tensor_Scalar", + aten_isin_Tensor_Scalar, + { + const auto& elements = KernelInput(0).toTensor(); + const auto test_element = KernelInput(1).toScalar(); + const auto assume_unique = KernelInput(2).toBool(); + const auto invert = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::isin(elements, test_element, assume_unique, invert); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::isin_out(out, elements, test_element, assume_unique, invert); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.isin.Scalar_Tensor", + aten_isin_Scalar_Tensor, + { + const auto element = KernelInput(0).toScalar(); + const auto& test_elements = KernelInput(1).toTensor(); + const auto assume_unique = KernelInput(2).toBool(); + const auto invert = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::isin(element, test_elements, assume_unique, invert); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::isin_out(out, element, test_elements, assume_unique, invert); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.kron.default", aten_kron_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::kron(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::kron_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ldexp.Tensor", aten_ldexp_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::ldexp(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::ldexp_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.log10.default", aten_log10_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::log10(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::log10_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.log1p.default", aten_log1p_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::log1p(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::log1p_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.log2.default", aten_log2_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::log2(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::log2_out(out, self); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logaddexp.default", + aten_logaddexp_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::logaddexp(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::logaddexp_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logaddexp2.default", + aten_logaddexp2_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::logaddexp2(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::logaddexp2_out(out, self, other); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.xlogy.Tensor", aten_xlogy_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::xlogy(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::xlogy_out(out, self, other); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten._log_softmax.default", + aten__log_softmax_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto half_to_float = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::_log_softmax(self, dim, half_to_float); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::_log_softmax_out(out, self, dim, half_to_float); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten._logcumsumexp.default", + aten__logcumsumexp_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::_logcumsumexp_cpu(self, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::_logcumsumexp_out_cpu(self, dim, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.logcumsumexp.default", + aten_logcumsumexp_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::logcumsumexp(self, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::logcumsumexp_out(self, dim, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.matrix_power.default", + aten_matrix_power_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto n = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::matrix_power(self, n); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::matrix_power_out(self, n, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.mm.default", aten_mm_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& mat2 = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::mm(self, mat2); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mm_out(out, self, mat2); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.multiply.Tensor", aten_multiply_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::multiply(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::multiply_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mv.default", aten_mv_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& vec = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::mv(self, vec); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::mv_out(self, vec, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mvlgamma.default", aten_mvlgamma_default, { + const auto& self = KernelInput(0).toTensor(); + const auto p = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::mvlgamma(self, p); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::mvlgamma_out(self, p, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.rad2deg.default", aten_rad2deg_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::rad2deg(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::rad2deg_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.deg2rad.default", aten_deg2rad_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::deg2rad(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::deg2rad_out(self, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.reciprocal.default", + aten_reciprocal_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::reciprocal(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::reciprocal_out(out, self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.neg.default", aten_neg_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::neg(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::neg_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.negative.default", aten_negative_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::negative(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::negative_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.round.default", aten_round_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::round(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::round_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.round.decimals", aten_round_decimals, { + const auto& self = KernelInput(0).toTensor(); + const auto decimals = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::round(self, decimals); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::round_out(out, self, decimals); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.gelu.default", aten_gelu_default, { + const auto& self = KernelInput(0).toTensor(); + const auto approximate = KernelInput(1).toStringView(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::gelu(self, approximate); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::gelu_out(out, self, approximate); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.hardshrink.default", + aten_hardshrink_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto lambd = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::hardshrink(self, lambd); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::hardshrink_out(out, self, lambd); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.hardshrink_backward.default", + aten_hardshrink_backward_default, + { + const auto& grad_out = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto lambd = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::hardshrink_backward(grad_out, self, lambd); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::hardshrink_backward_out(grad_input, grad_out, self, lambd); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.rsqrt.default", aten_rsqrt_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::rsqrt(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::rsqrt_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.silu.default", aten_silu_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::silu(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::silu_out(out, self); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.silu_backward.default", + aten_silu_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::silu_backward(grad_output, self); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::silu_backward_out(grad_input, grad_output, self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.mish.default", aten_mish_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::mish(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mish_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sigmoid.default", aten_sigmoid_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sigmoid(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sigmoid_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sin.default", aten_sin_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sin(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sin_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sinc.default", aten_sinc_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sinc(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sinc_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sinh.default", aten_sinh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sinh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sinh_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten._softmax.default", aten__softmax_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto half_to_float = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::_softmax(self, dim, half_to_float); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::_softmax_out(out, self, dim, half_to_float); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sqrt.default", aten_sqrt_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sqrt(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sqrt_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.square.default", aten_square_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::square(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::square_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.prod.default", aten_prod_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dtype = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::prod(self, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::prod_out(self, dtype, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.prod.dim_int", aten_prod_dim_int, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto keepdim = KernelInput(2).toBool(); + const auto dtype = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::prod(self, dim, keepdim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::prod_out(out, self, dim, keepdim, dtype); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.tan.default", aten_tan_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::tan(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::tan_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.tanh.default", aten_tanh_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::tanh(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::tanh_out(out, self); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.threshold.default", + aten_threshold_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto threshold = KernelInput(1).toScalar(); + const auto value = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::threshold(self, threshold, value); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::threshold_out(out, self, threshold, value); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.threshold_backward.default", + aten_threshold_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto threshold = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::threshold_backward(grad_output, self, threshold); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::threshold_backward_out(grad_input, grad_output, self, threshold); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.trunc.default", aten_trunc_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::trunc(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::trunc_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.fix.default", aten_fix_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::fix(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::fix_out(self, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.nuclear_norm.default", + aten_nuclear_norm_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto keepdim = KernelInput(1).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::nuclear_norm(self, keepdim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::nuclear_norm_out(self, keepdim, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.subtract.Tensor", aten_subtract_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::subtract(self, other, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::subtract_out(self, other, alpha, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.heaviside.default", + aten_heaviside_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& values = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::heaviside(self, values); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::heaviside_out(out, self, values); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten._addmm_activation.default", + aten__addmm_activation_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& mat1 = KernelInput(1).toTensor(); + const auto& mat2 = KernelInput(2).toTensor(); + const auto beta = KernelInput(3).toScalar(); + const auto alpha = KernelInput(4).toScalar(); + const auto use_gelu = KernelInput(5).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::_addmm_activation(self, mat1, mat2, beta, alpha, use_gelu); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::_addmm_activation_out( + out, self, mat1, mat2, beta, alpha, use_gelu); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.index_add.default", + aten_index_add_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& source = KernelInput(3).toTensor(); + const auto alpha = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::index_add(self, dim, index, source, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::index_add_out(out, self, dim, index, source, alpha); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.scatter.src", aten_scatter_src, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& src = KernelInput(3).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter(self, dim, index, src); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_out(out, self, dim, index, src); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.scatter.value", aten_scatter_value, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto value = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter(self, dim, index, value); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_out(out, self, dim, index, value); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.scatter.reduce", aten_scatter_reduce, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& src = KernelInput(3).toTensor(); + const auto reduce = KernelInput(4).toStringView(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter(self, dim, index, src, reduce); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_out(out, self, dim, index, src, reduce); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.scatter.value_reduce", + aten_scatter_value_reduce, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto value = KernelInput(3).toScalar(); + const auto reduce = KernelInput(4).toStringView(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter(self, dim, index, value, reduce); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_out(out, self, dim, index, value, reduce); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.scatter_add.default", + aten_scatter_add_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& src = KernelInput(3).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter_add(self, dim, index, src); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_add_out(out, self, dim, index, src); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.scatter_reduce.two", + aten_scatter_reduce_two, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto& src = KernelInput(3).toTensor(); + const auto reduce = KernelInput(4).toStringView(); + const auto include_self = KernelInput(5).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::scatter_reduce( + self, dim, index, src, reduce, include_self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::scatter_reduce_out( + out, self, dim, index, src, reduce, include_self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.eq.Scalar", aten_eq_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::eq(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::eq_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.eq.Tensor", aten_eq_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::eq(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::eq_out(out, self, other); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_and.Tensor", + aten_bitwise_and_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_and(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_and_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_or.Tensor", + aten_bitwise_or_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_or(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_or_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_xor.Tensor", + aten_bitwise_xor_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_xor(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_xor_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_left_shift.Tensor", + aten_bitwise_left_shift_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_left_shift(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_left_shift_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.bitwise_right_shift.Tensor", + aten_bitwise_right_shift_Tensor, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::bitwise_right_shift(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::bitwise_right_shift_out(out, self, other); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.tril.default", aten_tril_default, { + const auto& self = KernelInput(0).toTensor(); + const auto diagonal = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::tril(self, diagonal); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::tril_out(out, self, diagonal); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.triu.default", aten_triu_default, { + const auto& self = KernelInput(0).toTensor(); + const auto diagonal = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::triu(self, diagonal); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::triu_out(out, self, diagonal); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.digamma.default", aten_digamma_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::digamma(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::digamma_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Scalar", aten_lerp_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto& end = KernelInput(1).toTensor(); + const auto weight = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lerp(self, end, weight); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lerp_out(out, self, end, weight); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Tensor", aten_lerp_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& end = KernelInput(1).toTensor(); + const auto& weight = KernelInput(2).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lerp(self, end, weight); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lerp_out(out, self, end, weight); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addbmm.default", aten_addbmm_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& batch1 = KernelInput(1).toTensor(); + const auto& batch2 = KernelInput(2).toTensor(); + const auto beta = KernelInput(3).toScalar(); + const auto alpha = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::addbmm(self, batch1, batch2, beta, alpha); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::addbmm_out(self, batch1, batch2, beta, alpha, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cross.default", aten_cross_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto dim = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::cross(self, other, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::cross_out(self, other, dim, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ne.Scalar", aten_ne_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::ne(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::ne_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ne.Tensor", aten_ne_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::ne(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::ne_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ge.Scalar", aten_ge_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::ge(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::ge_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ge.Tensor", aten_ge_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::ge(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::ge_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.le.Scalar", aten_le_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::le(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::le_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.le.Tensor", aten_le_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::le(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::le_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.gt.Scalar", aten_gt_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::gt(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::gt_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.gt.Tensor", aten_gt_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::gt(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::gt_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lt.Scalar", aten_lt_Scalar, { + const auto& self = KernelInput(0).toTensor(); + const auto other = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lt(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lt_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lt.Tensor", aten_lt_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lt(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lt_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.take.default", aten_take_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& index = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::take(self, index); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::take_out(self, index, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.take_along_dim.default", + aten_take_along_dim_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& indices = KernelInput(1).toTensor(); + const auto dim = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::take_along_dim(self, indices, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::take_along_dim_out(self, indices, dim, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.masked_select.default", + aten_masked_select_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& mask = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::masked_select_cpu(self, mask); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::masked_select_out_cpu(self, mask, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.gather.default", aten_gather_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + const auto sparse_grad = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::gather(self, dim, index, sparse_grad); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::gather_out(out, self, dim, index, sparse_grad); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addcmul.default", aten_addcmul_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& tensor1 = KernelInput(1).toTensor(); + const auto& tensor2 = KernelInput(2).toTensor(); + const auto value = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::addcmul(self, tensor1, tensor2, value); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::addcmul_out(out, self, tensor1, tensor2, value); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addcdiv.default", aten_addcdiv_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& tensor1 = KernelInput(1).toTensor(); + const auto& tensor2 = KernelInput(2).toTensor(); + const auto value = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::addcdiv(self, tensor1, tensor2, value); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::addcdiv_out(out, self, tensor1, tensor2, value); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_solve_triangular.default", + aten_linalg_solve_triangular_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& B = KernelInput(1).toTensor(); + const auto upper = KernelInput(2).toBool(); + const auto left = KernelInput(3).toBool(); + const auto unitriangular = KernelInput(4).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_solve_triangular( + self, B, upper, left, unitriangular); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_solve_triangular_out( + self, B, upper, left, unitriangular, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.cholesky_solve.default", + aten_cholesky_solve_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& input2 = KernelInput(1).toTensor(); + const auto upper = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::cholesky_solve(self, input2, upper); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::cholesky_solve_out(self, input2, upper, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.cholesky_inverse.default", + aten_cholesky_inverse_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto upper = KernelInput(1).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::cholesky_inverse(self, upper); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::cholesky_inverse_out(self, upper, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.orgqr.default", aten_orgqr_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& input2 = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::orgqr(self, input2); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::orgqr_out(self, input2, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ormqr.default", aten_ormqr_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& input2 = KernelInput(1).toTensor(); + const auto& input3 = KernelInput(2).toTensor(); + const auto left = KernelInput(3).toBool(); + const auto transpose = KernelInput(4).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::ormqr(self, input2, input3, left, transpose); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::ormqr_out(self, input2, input3, left, transpose, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.lgamma.default", aten_lgamma_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::lgamma(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::lgamma_out(out, self); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.polygamma.default", + aten_polygamma_default, + { + const auto n = KernelInput(0).toInt(); + const auto& self = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::polygamma(n, self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::polygamma_out(out, n, self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.erfinv.default", aten_erfinv_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::erfinv(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::erfinv_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.i0.default", aten_i0_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::i0(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::i0_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.signbit.default", aten_signbit_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::signbit(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::signbit_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.atan2.default", aten_atan2_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::atan2(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::atan2_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.arctan2.default", aten_arctan2_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::arctan2(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::arctan2_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.histc.default", aten_histc_default, { + const auto& self = KernelInput(0).toTensor(); + const auto bins = KernelInput(1).toInt(); + const auto min = KernelInput(2).toScalar(); + const auto max = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::histogram_histc(self, bins, min, max); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::histogram_histc_out(self, bins, min, max, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.fmod.Tensor", aten_fmod_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::fmod(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::fmod_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.hypot.default", aten_hypot_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::hypot(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::hypot_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.igamma.default", aten_igamma_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::igamma(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::igamma_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.igammac.default", aten_igammac_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::igammac(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::igammac_out(out, self, other); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.nextafter.default", + aten_nextafter_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::nextafter(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::nextafter_out(out, self, other); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.fmin.default", aten_fmin_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::fmin(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::fmin_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.fmax.default", aten_fmax_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::fmax(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::fmax_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.maximum.default", aten_maximum_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::maximum(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::maximum_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.minimum.default", aten_minimum_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::minimum(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::minimum_out(out, self, other); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.min.other", aten_min_other, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::min(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::min_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.quantile.default", aten_quantile_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& q = KernelInput(1).toTensor(); + const auto dim = KernelInput(2).toOptional(); + const auto keepdim = KernelInput(3).toBool(); + const auto interpolation = KernelInput(4).toStringView(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::native::quantile(self, q, dim, keepdim, interpolation); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::quantile_out(self, q, dim, keepdim, interpolation, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.nanquantile.default", + aten_nanquantile_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& q = KernelInput(1).toTensor(); + const auto dim = KernelInput(2).toOptional(); + const auto keepdim = KernelInput(3).toBool(); + const auto interpolation = KernelInput(4).toStringView(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::native::nanquantile(self, q, dim, keepdim, interpolation); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::nanquantile_out(self, q, dim, keepdim, interpolation, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.msort.default", aten_msort_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::msort(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::msort_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.all.default", aten_all_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::all(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::all_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.any.default", aten_any_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::any(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::any_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.renorm.default", aten_renorm_default, { + const auto& self = KernelInput(0).toTensor(); + const auto p = KernelInput(1).toScalar(); + const auto dim = KernelInput(2).toInt(); + const auto maxnorm = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::renorm(self, p, dim, maxnorm); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::renorm_out(out, self, p, dim, maxnorm); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten._convert_indices_from_coo_to_csr.default", + aten__convert_indices_from_coo_to_csr_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto size = KernelInput(1).toInt(); + const auto out_int32 = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::_convert_indices_from_coo_to_csr(self, size, out_int32); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::_convert_indices_from_coo_to_csr_out(out, self, size, out_int32); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten._convert_indices_from_csr_to_coo.default", + aten__convert_indices_from_csr_to_coo_default, + { + const auto& crow_indices = KernelInput(0).toTensor(); + const auto& col_indices = KernelInput(1).toTensor(); + const auto out_int32 = KernelInput(2).toBool(); + const auto transpose = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::_convert_indices_from_csr_to_coo( + crow_indices, col_indices, out_int32, transpose); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::_convert_indices_from_csr_to_coo_out( + out, crow_indices, col_indices, out_int32, transpose); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.mse_loss.default", aten_mse_loss_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& target = KernelInput(1).toTensor(); + const auto reduction = KernelInput(2).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::mse_loss(self, target, reduction); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mse_loss_out(out, self, target, reduction); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.multi_margin_loss.default", + aten_multi_margin_loss_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& target = KernelInput(1).toTensor(); + const auto p = KernelInput(2).toScalar(); + const auto margin = KernelInput(3).toScalar(); + const auto weight = KernelInput(4).toOptional(); + const auto reduction = KernelInput(5).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::multi_margin_loss_cpu( + self, target, p, margin, weight, reduction); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::multi_margin_loss_cpu_out( + self, target, p, margin, weight, reduction, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.multilabel_margin_loss.default", + aten_multilabel_margin_loss_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& target = KernelInput(1).toTensor(); + const auto reduction = KernelInput(2).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::native::multilabel_margin_loss(self, target, reduction); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::multilabel_margin_loss_out(self, target, reduction, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.soft_margin_loss.default", + aten_soft_margin_loss_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& target = KernelInput(1).toTensor(); + const auto reduction = KernelInput(2).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::soft_margin_loss(self, target, reduction); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::soft_margin_loss_out(self, target, reduction, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.elu.default", aten_elu_default, { + const auto& self = KernelInput(0).toTensor(); + const auto alpha = KernelInput(1).toScalar(); + const auto scale = KernelInput(2).toScalar(); + const auto input_scale = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::elu(self, alpha, scale, input_scale); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::elu_out(out, self, alpha, scale, input_scale); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.elu_backward.default", + aten_elu_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto alpha = KernelInput(1).toScalar(); + const auto scale = KernelInput(2).toScalar(); + const auto input_scale = KernelInput(3).toScalar(); + const auto is_result = KernelInput(4).toBool(); + const auto& self_or_result = KernelInput(5).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::elu_backward( + grad_output, alpha, scale, input_scale, is_result, self_or_result); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::elu_backward_out( + grad_input, + grad_output, + alpha, + scale, + input_scale, + is_result, + self_or_result); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.glu.default", aten_glu_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::glu(self, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::glu_out(out, self, dim); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.hardsigmoid.default", + aten_hardsigmoid_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::hardsigmoid(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::hardsigmoid_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.hardsigmoid_backward.default", + aten_hardsigmoid_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::hardsigmoid_backward(grad_output, self); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::hardsigmoid_backward_out(grad_input, grad_output, self); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.hardtanh.default", aten_hardtanh_default, { + const auto& self = KernelInput(0).toTensor(); + const auto min_val = KernelInput(1).toScalar(); + const auto max_val = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::hardtanh(self, min_val, max_val); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::hardtanh_out(self, min_val, max_val, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.hardswish.default", + aten_hardswish_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::hardswish(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::hardswish_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.leaky_relu_backward.default", + aten_leaky_relu_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto negative_slope = KernelInput(2).toScalar(); + const auto self_is_result = KernelInput(3).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::leaky_relu_backward( + grad_output, self, negative_slope, self_is_result); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::leaky_relu_backward_out( + grad_input, grad_output, self, negative_slope, self_is_result); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.log_sigmoid.default", + aten_log_sigmoid_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::log_sigmoid(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::log_sigmoid_out(self, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.softplus.default", aten_softplus_default, { + const auto& self = KernelInput(0).toTensor(); + const auto beta = KernelInput(1).toScalar(); + const auto threshold = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::softplus(self, beta, threshold); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::softplus_out(out, self, beta, threshold); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.softplus_backward.default", + aten_softplus_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto beta = KernelInput(2).toScalar(); + const auto threshold = KernelInput(3).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::softplus_backward(grad_output, self, beta, threshold); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::softplus_backward_out( + grad_input, grad_output, self, beta, threshold); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.softshrink.default", + aten_softshrink_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto lambd = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::softshrink(self, lambd); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::softshrink_out(out, self, lambd); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.softshrink_backward.default", + aten_softshrink_backward_default, + { + const auto& grad_output = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto lambd = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::cpu::softshrink_backward(grad_output, self, lambd); + return; + } + auto& grad_input = KernelOutput(0).toTensor(); + fastResizeToZero(grad_input); + at::cpu::softshrink_backward_out(grad_input, grad_output, self, lambd); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.isposinf.default", aten_isposinf_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::isposinf(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::isposinf_out(out, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.isneginf.default", aten_isneginf_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::isneginf(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::isneginf_out(out, self); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_entr.default", + aten_special_entr_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_entr(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_entr_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_ndtri.default", + aten_special_ndtri_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_ndtri(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_ndtri_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_log_ndtr.default", + aten_special_log_ndtr_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_log_ndtr(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_log_ndtr_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_expm1.default", + aten_special_expm1_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_expm1(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_expm1_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_exp2.default", + aten_special_exp2_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_exp2(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_exp2_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_psi.default", + aten_special_psi_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_psi(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_psi_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_digamma.default", + aten_special_digamma_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_digamma(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_digamma_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_gammaln.default", + aten_special_gammaln_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_gammaln(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_gammaln_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_erf.default", + aten_special_erf_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_erf(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_erf_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_erfc.default", + aten_special_erfc_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_erfc(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_erfc_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_erfcx.default", + aten_special_erfcx_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_erfcx(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_erfcx_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_erfinv.default", + aten_special_erfinv_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_erfinv(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_erfinv_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_ndtr.default", + aten_special_ndtr_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_ndtr(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_ndtr_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_xlog1py.default", + aten_special_xlog1py_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_xlog1py(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_xlog1py_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_xlogy.default", + aten_special_xlogy_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_xlogy(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_xlogy_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_zeta.default", + aten_special_zeta_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_zeta(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_zeta_out(out, self, other); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_i0.default", + aten_special_i0_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_i0(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_i0_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_i0e.default", + aten_special_i0e_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_i0e(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_i0e_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_i1.default", + aten_special_i1_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_i1(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_i1_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_i1e.default", + aten_special_i1e_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::special_i1e(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::special_i1e_out(out, self); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_polygamma.default", + aten_special_polygamma_default, + { + const auto n = KernelInput(0).toInt(); + const auto& self = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_polygamma(n, self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_polygamma_out(n, self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_expit.default", + aten_special_expit_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_expit(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_expit_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_sinc.default", + aten_special_sinc_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_sinc(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_sinc_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_round.default", + aten_special_round_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto decimals = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_round(self, decimals); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_round_out(self, decimals, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_log1p.default", + aten_special_log1p_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_log1p(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_log1p_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_gammainc.default", + aten_special_gammainc_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_gammainc(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_gammainc_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_gammaincc.default", + aten_special_gammaincc_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_gammaincc(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_gammaincc_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.special_multigammaln.default", + aten_special_multigammaln_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto p = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::special_multigammaln(self, p); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::special_multigammaln_out(self, p, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_cross.default", + aten_linalg_cross_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + const auto dim = KernelInput(2).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::linalg_cross(self, other, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::linalg_cross_out(out, self, other, dim); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_det.default", + aten_linalg_det_default, + { + const auto& A = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_det(A); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_det_out(A, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_matmul.default", + aten_linalg_matmul_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_matmul(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_matmul_out(self, other, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_eigvals.default", + aten_linalg_eigvals_default, + { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_eigvals(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_eigvals_out(self, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_inv.default", + aten_linalg_inv_default, + { + const auto& A = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_inv(A); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_inv_out(A, out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.inverse.default", aten_inverse_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::inverse(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::inverse_out(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.inner.default", aten_inner_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::inner(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::inner_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.outer.default", aten_outer_default, { + const auto& self = KernelInput(0).toTensor(); + const auto& vec2 = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::outer(self, vec2); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::outer_out(self, vec2, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_cond.default", + aten_linalg_cond_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto p = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_cond(self, p); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_cond_out(self, p, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_solve.default", + aten_linalg_solve_default, + { + const auto& A = KernelInput(0).toTensor(); + const auto& B = KernelInput(1).toTensor(); + const auto left = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_solve(A, B, left); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_solve_out(A, B, left, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_tensorinv.default", + aten_linalg_tensorinv_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto ind = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_tensorinv(self, ind); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_tensorinv_out(self, ind, out); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_matrix_power.default", + aten_linalg_matrix_power_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto n = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_matrix_power(self, n); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_matrix_power_out(self, n, out); + }); + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/HigherOrderKernel.cpp b/torch/nativert/kernels/HigherOrderKernel.cpp new file mode 100644 index 00000000000000..ebb9f646f1c406 --- /dev/null +++ b/torch/nativert/kernels/HigherOrderKernel.cpp @@ -0,0 +1,113 @@ +#include + +#include + +#include + +namespace torch::nativert { + +HigherOrderKernel::HigherOrderKernel( + const Node* node, + std::vector> graphExecutors) + : OpKernel(node), graphExecutors_(std::move(graphExecutors)) { + static constexpr std::string_view prefix = "torch.ops.higher_order."; + CHECK(c10::starts_with(node->target(), prefix)); + auto opName = node->target().substr(prefix.size()); + if (opName == "cond") { + opType_ = OpType::COND; + // Checking torch.cond schema is as expected: + // torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args) + // -> Tensor[] + TORCH_CHECK_EQ(node_->attributes().size(), 2); + TORCH_CHECK_EQ(node_->inputs().size(), 2); + } else if (opName == "while_loop") { + opType_ = OpType::WHILE_LOOP; + // Checking torch.while_loop schema is as expected: + // torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[] + // additonal) -> Tensor[] + TORCH_CHECK_EQ(node_->attributes().size(), 2); + TORCH_CHECK_EQ(node_->inputs().size(), 2); + } else if (opName == "run_const_graph") { + opType_ = OpType::RUN_CONST_GRAPH; + // Checking torch.run_const_graph schema is as expected: + // torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[] + TORCH_CHECK_GE(node_->attributes().size(), 1); + TORCH_CHECK_EQ(node_->inputs().size(), 1); + } else { + throw std::runtime_error( + fmt::format("Unknown higher order op: {}", opName)); + } +} + +void HigherOrderKernel::computeInternal(ExecutionFrame& executionFrame) const { + switch (opType_) { + case OpType::COND: { + auto inputs = executionFrame.getIValue(node_->inputs()[1].value->id()) + .toList() + .vec(); + std::vector outputs; + auto cond = executionFrame.getIValue(node_->inputs()[0].value->id()); + size_t branchIdx = 0; + if (cond.isTensor()) { + branchIdx = cond.toTensor().item().toBool() ? 0 : 1; + } else if (cond.isBool()) { + branchIdx = cond.toBool() ? 0 : 1; + } else { + throw std::runtime_error("Unsupported type for cond predicate"); + } + ExecutionFrame branchFrame(*std::get>( + node_->attributes()[branchIdx].value)); + auto ret = + graphExecutors_[branchIdx]->execute(branchFrame, std::move(inputs)); + for (size_t i = 0; i < ret.size(); i++) { + executionFrame.setIValue(node_->outputs()[i]->id(), std::move(ret[i])); + } + break; + } + case OpType::WHILE_LOOP: { + auto carriedVals = + executionFrame.getIValue(node_->inputs()[0].value->id()) + .toList() + .vec(); + auto additonalVals = + executionFrame.getIValue(node_->inputs()[1].value->id()) + .toList() + .vec(); + size_t numCarriedVals = carriedVals.size(); + ExecutionFrame condFrame( + *std::get>(node_->attributes()[0].value)); + ExecutionFrame bodyFrame( + *std::get>(node_->attributes()[1].value)); + while (true) { + auto inputs = carriedVals; + inputs.insert(inputs.end(), additonalVals.begin(), additonalVals.end()); + auto cond = graphExecutors_[0]->execute(condFrame, inputs); + + if (cond.at(0).isTensor() && !cond[0].toTensor().item().toBool()) { + break; + } + if (cond.at(0).isBool() && !cond[0].toBool()) { + break; + } + auto out = graphExecutors_[1]->execute(bodyFrame, std::move(inputs)); + TORCH_CHECK(out.size() == numCarriedVals); + carriedVals = std::move(out); + } + for (size_t i = 0; i < carriedVals.size(); i++) { + executionFrame.setIValue( + node_->outputs()[i]->id(), std::move(carriedVals[i])); + } + break; + } + case OpType::RUN_CONST_GRAPH: { + // run_const_graph op is a special case of higher order op which has + // been executed during weights loading, therefore at runtime we can + // just make this a no-op. + break; + } + default: + TORCH_CHECK(false, "Unknown higher order op"); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/HigherOrderKernel.h b/torch/nativert/kernels/HigherOrderKernel.h new file mode 100644 index 00000000000000..fb98e4bdec5836 --- /dev/null +++ b/torch/nativert/kernels/HigherOrderKernel.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nativert { + +class HigherOrderKernel : public OpKernel { + enum class OpType { + UNKNOWN, + COND, + WHILE_LOOP, + RUN_CONST_GRAPH, + }; + + public: + HigherOrderKernel( + const Node* node, + std::vector> graphExecutors); + void computeInternal(ExecutionFrame& executionFrame) const final; + + private: + std::vector> graphExecutors_; + OpType opType_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp new file mode 100644 index 00000000000000..1f72fef810d6ce --- /dev/null +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -0,0 +1,270 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::nativert { + +namespace { + +c10::Device inferTargetDevice( + const Node& node, + const std::unordered_map& + tensorValuesMeta, + const Placement& placement) { + if (node.target() == "prim.Input" || node.target() == "prim.Output") { + return c10::Device(c10::DeviceType::CPU); + } + + std::vector devices; + for (auto& output : node.outputs()) { + if (output->type() == Type::Kind::Tensor) { + auto it = tensorValuesMeta.find(std::string{output->name()}); + if (it != tensorValuesMeta.end()) { + devices.emplace_back(it->second.device()); + } + } else if (output->type() == Type::Kind::TensorList) { + for (const auto& el : output->getListElements()) { + auto it = tensorValuesMeta.find(std::string{el->name()}); + if (it != tensorValuesMeta.end()) { + devices.emplace_back(it->second.device()); + } + } + } + } + + if (devices.empty()) { + return c10::Device(c10::DeviceType::CPU); + } else { + for (size_t i = 1; i < devices.size(); ++i) { + if (!torch::nativert::isSameDevice(devices[0], devices[i])) { + LOG(WARNING) << "Node " << node + << " has outputs on multiple devices: " << devices[0] + << " and " << devices[i]; + } + } + + return placement.getMappedDevice(devices[0]); + } +} + +} // namespace + +inline constexpr std::string_view kSymIntOps[] = { + "_operator.floordiv", + "_operator.mod", + "torch.sym_int", + "torch.sym_float", + "torch.sym_ite", + "torch.sym_max", + "torch.sym_min", +}; + +inline constexpr std::string_view kSymBoolOps[] = { + "_operator.eq", + "_operator.ne", + "_operator.le", + "_operator.ge", + "_operator.lt", + "_operator.gt", + "_operator.and_", + "torch.sym_not", +}; + +inline constexpr std::string_view kSymFloatOps[] = { + "torch._sym_sqrt", + "math.trunc", + "_operator.neg", + "_operator.truediv", +}; + +inline constexpr std::string_view kScalarBinaryOps[] = { + "_operator.mul", + "_operator.add", + "_operator.sub", + "_operator.pow", +}; + +namespace { + +struct KernelFactoryRegistry { + std::unordered_map handlers; +}; + +c10::Synchronized& getKernelFactoryRegistry() { + static auto* registry = new c10::Synchronized(); + return *registry; +} + +} // namespace + +void KernelFactory::registerHandler( + const std::string& name, + KernelFactoryHandler handler) { + auto& registry = getKernelFactoryRegistry(); + registry.withLock([&](auto&& reg) { + if (reg.handlers.find(name) != reg.handlers.end()) { + TORCH_CHECK(false, "Handler for ", name, " already registered"); + } + reg.handlers.emplace(name, std::move(handler)); + }); +} + +ExecutionKernels KernelFactory::initializeNodeKernels( + const Graph& graph, + std::shared_ptr weights, + const torch::nativert::ExecutorConfig& executorConfig, + const Placement& placement, + std::shared_ptr pytorchStreamReader, + const MakeProxyExecutorFn& makeProxyExecutorFunc) { + std::vector> nodeKernels; + std::vector> delegateExecutors; + std::vector constFoldingExecutions; + + std::unordered_map opsWithoutStaticDispatchCount; + + VLOG(1) << fmt::format( + "PrimKernelRegistry: {}", fmt::join(PrimKernelRegistry()->Keys(), ", ")); + + std::unordered_map handlers; + getKernelFactoryRegistry().withLock( + [&](auto&& reg) { handlers = reg.handlers; }); + + for (const auto& node : graph.nodes()) { + std::string target = std::string(node.target()); + + c10::Device targetDevice = + inferTargetDevice(node, graph.tensorValuesMeta(), placement); + + bool matched = false; + for (const auto& [_, handler] : handlers) { + if (handler.match(node, executorConfig, targetDevice)) { + auto [kernel, delegate] = handler( + node, + weights, + executorConfig, + pytorchStreamReader.get(), + targetDevice); + if (kernel) { + nodeKernels.push_back(std::move(kernel)); + } + if (delegate) { + delegateExecutors.push_back(std::move(delegate)); + } + matched = true; + break; + } + } + if (matched) { + continue; + } + + if (PrimKernelRegistry()->Has(target)) { + nodeKernels.push_back(PrimKernelRegistry()->Create(target, &node)); + } else if (c10::starts_with( + node.target(), "torch.ops.higher_order.call_torchbind")) { + nodeKernels.push_back(std::make_unique(&node)); + } else if ( + c10::starts_with( + node.target(), + "torch.ops.higher_order.auto_functionalized") || + c10::starts_with( // TODO Remove this condition once the old + // pt2 archives are expired. + node.target(), + "torch._higher_order_ops.auto_functionalize.auto_functionalized")) { + nodeKernels.push_back( + std::make_unique(&node)); + } else if ( + std::find( + std::begin(kSymIntOps), std::end(kSymIntOps), node.target()) != + std::end(kSymIntOps)) { + nodeKernels.push_back(std::make_unique(&node)); + } else if ( + std::find( + std::begin(kSymBoolOps), std::end(kSymBoolOps), node.target()) != + std::end(kSymBoolOps)) { + nodeKernels.push_back(std::make_unique(&node)); + } else if ( + std::find( + std::begin(kSymFloatOps), std::end(kSymFloatOps), node.target()) != + std::end(kSymFloatOps)) { + nodeKernels.push_back(std::make_unique(&node)); + } else if ( + std::find( + std::begin(kScalarBinaryOps), + std::end(kScalarBinaryOps), + node.target()) != std::end(kScalarBinaryOps)) { + nodeKernels.push_back(std::make_unique(&node)); + } else if (c10::starts_with(node.target(), "torch.ops.higher_order")) { + std::vector> graphExecutors; + for (const auto& attr : node.attributes()) { + if (std::holds_alternative>(attr.value)) { + const auto& subgraph = std::get>(attr.value); + auto executionKernels = initializeNodeKernels( + *subgraph, weights, executorConfig, placement); + CHECK(executionKernels.delegateExecutors.empty()) + << "HigherOrderKernel does not support delegates"; + CHECK(executionKernels.constFoldingExecutions.size() == 0) + << "HigherOrderKernel does not support const folding"; + if (executorConfig.maxParallelOps > 1) { + graphExecutors.emplace_back( + std::unique_ptr(new ParallelGraphExecutor( + *subgraph, + std::move(executionKernels.nodeKernels), + executorConfig))); + } else { + graphExecutors.emplace_back(std::unique_ptr( + new torch::nativert::SerialGraphExecutor( + *subgraph, + std::move(executionKernels.nodeKernels), + executorConfig))); + } + } + } + if (node.target() == "torch.ops.higher_order.run_const_graph") { + constFoldingExecutions.push_back( + ConstFoldingExecution{std::move(graphExecutors[0])}); + } + nodeKernels.push_back(std::make_unique( + &node, std::move(graphExecutors))); + } else if (c10::starts_with(node.target(), "torch.ops")) { + nodeKernels.push_back(std::make_unique(&node, targetDevice)); + + std::string opName = std::string(node.target()); + if (opsWithoutStaticDispatchCount.find(opName) == + opsWithoutStaticDispatchCount.end()) { + opsWithoutStaticDispatchCount[opName] = 0; + } + opsWithoutStaticDispatchCount[opName] += 1; + } else { + TORCH_CHECK(false, "Unsupported operator: ", target); + } + } + + if (executorConfig.enableStaticCPUKernels) { + std::stringstream ss; + for (const auto& [op, count] : opsWithoutStaticDispatchCount) { + ss << op << ": " << count << ", \n"; + } + LOG(WARNING) << "Following ops are missing static dispatched kernels: \n" + << ss.str(); + } + + return { + std::move(nodeKernels), + std::move(delegateExecutors), + std::move(constFoldingExecutions)}; +} +} // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h new file mode 100644 index 00000000000000..c01d64c3a0178c --- /dev/null +++ b/torch/nativert/kernels/KernelFactory.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace torch::nativert { + +struct ConstFoldingExecution { + std::unique_ptr executor; +}; + +struct ExecutionKernels { + std::vector> nodeKernels; + std::vector> delegateExecutors; + std::vector constFoldingExecutions; +}; + +class KernelFactoryHandler { + public: + using OpKernelPtr = std::unique_ptr; + using DelegateExecutorPtr = std::unique_ptr; + using Matcher = c10::function_ref; + using Callback = + c10::function_ref( + const Node&, + std::shared_ptr weights, + const torch::nativert::ExecutorConfig& executorConfig, + caffe2::serialize::PyTorchStreamReader* pytorchStreamReader, + c10::Device targetDevice)>; + + KernelFactoryHandler(Matcher matcher, Callback callback) + : matcher_(matcher), callback_(callback) {} + + KernelFactoryHandler() = delete; + KernelFactoryHandler(const KernelFactoryHandler&) = default; + KernelFactoryHandler& operator=(const KernelFactoryHandler&) = default; + KernelFactoryHandler(KernelFactoryHandler&&) = default; + KernelFactoryHandler& operator=(KernelFactoryHandler&&) = default; + ~KernelFactoryHandler() = default; + + bool match( + const Node& node, + const torch::nativert::ExecutorConfig& config, + c10::Device device) const { + return matcher_(node, config, device); + } + + std::pair operator()( + const Node& node, + std::shared_ptr weights, + const torch::nativert::ExecutorConfig& executorConfig, + caffe2::serialize::PyTorchStreamReader* pytorchStreamReader, + c10::Device targetDevice) const { + return callback_( + node, weights, executorConfig, pytorchStreamReader, targetDevice); + } + + private: + Matcher matcher_; + Callback callback_; +}; + +class KernelFactory { + public: + explicit KernelFactory() {} + + ExecutionKernels initializeNodeKernels( + const Graph& graph, + std::shared_ptr weights, + const torch::nativert::ExecutorConfig& executorConfig, + const Placement& placement, + std::shared_ptr + pytorchStreamReader = nullptr, + const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + + static void registerHandler( + const std::string& name, + KernelFactoryHandler handler); +}; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp new file mode 100644 index 00000000000000..e6f69634a71b84 --- /dev/null +++ b/torch/nativert/kernels/PrimKernelRegistry.cpp @@ -0,0 +1,163 @@ +#include + +#include +#include +#include +#include + +#include +#include + +namespace torch::nativert { + +C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); + +namespace { + +class OpKernel_prim_listpack : public OpKernel { + public: + explicit OpKernel_prim_listpack(const Node* node) + : OpKernel(node, std::nullopt, OpKernelKind::kPrimKernel) { + auto listType = node->outputs()[0]->type(); + switch (listType.kind()) { + case Type::Kind::TensorList: + type_ = c10::TensorType::get(); + break; + case Type::Kind::SymIntList: + type_ = c10::IntType::get(); + break; + case Type::Kind::OptionalTensorList: + type_ = c10::OptionalType::create(c10::TensorType::get()); + break; + default: + TORCH_CHECK(false, "Unsupported list type: ", listType); + } + } + + void computeInternal(ExecutionFrame& executionFrame) const override final { + RECORD_USER_SCOPE("nativert::OpKernel_prim_listpack"); + c10::List list(type_); + list.reserve(numInputs()); + for (size_t i = 0; i < numInputs(); ++i) { + if (KernelInput(i).isNone()) { + list.emplace_back(); + } else { + list.push_back(KernelInput(i)); + } + } + KernelOutput(0) = std::move(list); + } + + private: + c10::TypePtr type_; +}; + +} // namespace + +C10_REGISTER_TYPED_CLASS( + PrimKernelRegistry, + "prim.ListPack", + OpKernel_prim_listpack); + +REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, { + RECORD_USER_SCOPE("nativert::OpKernel_prim_listunpack"); + auto inputListRef = KernelInput(0).toListRef(); + for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) { + KernelOutput(i) = ivalue; + } +}); + +// Noop for input and output +REGISTER_PRIM_KERNEL("prim.Input", prim_input, {}); +REGISTER_PRIM_KERNEL("prim.Output", prim_output, {}); + +namespace { + +class OpKernel_variadic_concat : public OpKernel { + public: + explicit OpKernel_variadic_concat(const Node* node) + : OpKernel(node, std::nullopt, OpKernelKind::kPrimKernel) { + dim_ = node_->attributes().size() > 0 + ? constantToIValue(node_->getAttribute("dim").value).toInt() + : 0; + } + void computeInternal(ExecutionFrame& executionFrame) const override final { + { + const size_t numNodeInps = numInputs(); + auto numCatInps = numNodeInps; + auto dim = dim_; + if (KernelInput(numCatInps - 1).isInt()) { + dim = KernelInput(numCatInps - 1).toInt(); + numCatInps--; + } + std::vector inputs(numCatInps); + for (const auto i : c10::irange(numCatInps)) { + inputs[i] = KernelInput(i).toTensor(); + } + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cat(inputs, dim); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::cat_outf(inputs, dim, out_t); + } + } + + private: + int dim_; +}; + +} // namespace + +C10_REGISTER_TYPED_CLASS( + PrimKernelRegistry, + "prim.VarConcat", + OpKernel_variadic_concat); + +namespace { + +class OpKernel_variadic_stack : public OpKernel { + public: + explicit OpKernel_variadic_stack(const Node* node) + : OpKernel(node, std::nullopt, OpKernelKind::kPrimKernel) { + dim_ = node_->attributes().size() > 0 + ? constantToIValue(node_->getAttribute("dim").value).toInt() + : 0; + } + void computeInternal(ExecutionFrame& executionFrame) const override final { + { + const size_t numNodeInps = numInputs(); + auto numStackInps = numNodeInps; + auto dim = dim_; + if (KernelInput(numStackInps - 1).isInt()) { + dim = KernelInput(numStackInps - 1).toInt(); + numStackInps--; + } + std::vector inputs(numStackInps); + for (const auto i : c10::irange(numStackInps)) { + inputs[i] = KernelInput(i).toTensor(); + } + auto& out = KernelOutput(0); + if (out.isNone()) { + out = at::native::_stack_cpu(inputs, dim); + return; + } + auto& out_t = out.toTensor(); + fastResizeToZero(out_t); + at::native::_stack_out_cpu(inputs, dim, out_t); + } + } + + private: + int64_t dim_; +}; +} // namespace + +C10_REGISTER_TYPED_CLASS( + PrimKernelRegistry, + "prim.VarStack", + OpKernel_variadic_stack); + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.h b/torch/nativert/kernels/PrimKernelRegistry.h new file mode 100644 index 00000000000000..89e9c29e7dcb5d --- /dev/null +++ b/torch/nativert/kernels/PrimKernelRegistry.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace torch::nativert { + +#define KernelInput(id) input(id, executionFrame) +#define KernelOutput(id) output(id, executionFrame) + +TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); + +#define REGISTER_PRIM_KERNEL(name, id, ...) \ + class OpKernel_##id : public OpKernel { \ + public: \ + OpKernel_##id(const Node* node) \ + : OpKernel(node, std::nullopt, OpKernelKind::kPrimKernel) {} \ + void computeInternal( \ + ExecutionFrame& executionFrame) const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id); + +inline bool checkResizedDataPtr(at::Tensor& t) { + auto const prev_data_ptr = t.data_ptr(); + t.resize_({0}); + return prev_data_ptr == t.data_ptr(); +} + +inline void fastResizeToZero(at::Tensor& t) { + t.unsafeGetTensorImpl()->set_sizes_contiguous({0}); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t)); +} + +} // namespace torch::nativert diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 2518e86f642a96..14e71c506385eb 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -319,7 +319,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): ) @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] # If you're wondering why there's a nested tensor with one of its # size = -1, see note: [NJT outer_size in AOTDispatcher] kwargs = {} if kwargs is None else kwargs @@ -529,9 +529,9 @@ def jagged_from_tensor_and_lengths( ) # Calculate jagged offsets - assert ( - len(tensor.shape) >= 2 - ), "tensor must at least be 2D for the nested narrow op to work" + assert len(tensor.shape) >= 2, ( + "tensor must at least be 2D for the nested narrow op to work" + ) max_seq_len = tensor.shape[1] offset_lengths = max_seq_len * torch.arange( 0, batch_size, dtype=torch.int64, device=tensor.device diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 9525508d750706..8eb962f8a308d2 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -73,9 +73,9 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1): """ from torch._prims_common import canonicalize_dims - assert isinstance( - dims, (tuple, list) - ), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}" + assert isinstance(dims, (tuple, list)), ( + f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}" + ) wrapped_dims = [ canonicalize_dims(ndim, d) for d in dims @@ -535,9 +535,9 @@ def clone_default(func, *args, **kwargs): from .nested_tensor import jagged_from_list # TODO: We probably want the output to have the same ragged structure / nested int. - assert ( - inp._ragged_idx == 1 - ), "NJT with ragged_idx != 1 not supported for contiguous clone" + assert inp._ragged_idx == 1, ( + "NJT with ragged_idx != 1 not supported for contiguous clone" + ) contig, _ = jagged_from_list(inp.unbind(), offsets=None) return contig @@ -1730,8 +1730,8 @@ def native_layer_norm_default(func, *args, **kwargs): ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm padded_normalized = ( - padded_input - mean - ) * padded_mask # mask elements outside of the ragged dimension size for correct variance calculation + (padded_input - mean) * padded_mask + ) # mask elements outside of the ragged dimension size for correct variance calculation variance = ( torch.sum( diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 8ac4cc86a58cc3..997e1805d08c34 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -369,7 +369,7 @@ def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor): # use with the flash-attention and efficient_attention kernels without # needing to call contiguous on the nested tensor input. # It checks that the storage offsets' adjacent_differences are a constant - # mutiple of the previous tensor in the nested tensor and that the strides + # multiple of the previous tensor in the nested tensor and that the strides # are monitonically decreasing. This check is done after calling transpose on # the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim] diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 4ccf43c33982c1..2ebbe1804c2174 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ +"""This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention""" + import contextlib from collections.abc import Iterable from typing import Union @@ -119,6 +120,7 @@ def sdpa_kernel( from torch.nn.functional import scaled_dot_product_attention from torch.nn.attention import SDPBackend, sdpa_kernel + # Only enable flash attention backend with sdpa_kernel(SDPBackend.FLASH_ATTENTION): scaled_dot_product_attention(...) @@ -130,9 +132,9 @@ def sdpa_kernel( This context manager can be used to select which backend to use for scaled dot product attention. Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends. """ - assert isinstance( - backends, (list, SDPBackend) - ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances" + assert isinstance(backends, (list, SDPBackend)), ( + "Backend must be an instance of SDPBackend or a list of SDPBackend instances" + ) if isinstance(backends, SDPBackend): backends = [backends] diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 5b09a2c14c24a7..a91045b92c13e6 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs """Defines utilities for interacting with scaled_dot_product_attention""" + import math from typing import Optional diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 36c0a18cdd1249..3d002b7b23656b 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs """Defines bias subclasses that work with scaled_dot_product_attention""" + from enum import auto, IntEnum from typing import Optional from warnings import warn @@ -36,14 +37,14 @@ class CausalVariant(IntEnum): Defines two types of causal biases: - `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention. + ``UPPER_LEFT``: Represents upper-left triangular bias for standard causal attention. The equivalent pytorch code for constructing this bias is: .. code-block:: python torch.tril(torch.ones(size, dtype=torch.bool)) - For instance, with `shape=(3,4)`, the materialized bias tensor will be: + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: .. code-block:: text @@ -52,7 +53,7 @@ class CausalVariant(IntEnum): [1, 1, 1, 0]] - `LOWER_RIGHT`: Represents lower-right triangular bias, the include values are aligned to the lower + ``LOWER_RIGHT``: Represents lower-right triangular bias, the include values are aligned to the lower right corner of the matrix. The equivalent pytorch code for constructing this bias is: @@ -65,7 +66,7 @@ class CausalVariant(IntEnum): diagonal=diagonal_offset, ) - For instance, with `shape=(3,4)`, the materialized bias tensor will be: + For instance, with ``shape=(3,4)``, the materialized bias tensor will be: .. code-block:: text @@ -101,9 +102,15 @@ class CausalBias(torch.Tensor): # Create a lower-right causal bias attn_bias = causal_lower_right(seqlen_q, seqlen_kv) - q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16) - k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) - v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) + q = torch.randn( + bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16 + ) + k = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16 + ) out = F.scaled_dot_product_attention(q, k, v, attn_bias) @@ -262,7 +269,7 @@ def _dispatch( )[0].transpose(1, 2) else: _raise_kernel_warnings(sdpa_params) - # We cant use efficient attention the only support for lower right is via materialization + # We can't use efficient attention the only support for lower right is via materialization return F.scaled_dot_product_attention( query, key, diff --git a/torch/nn/attention/experimental/__init__.py b/torch/nn/attention/experimental/__init__.py index 0b3d262932b395..4a6694bbe3990b 100644 --- a/torch/nn/attention/experimental/__init__.py +++ b/torch/nn/attention/experimental/__init__.py @@ -1,2 +1,2 @@ # Experimental features are not mature yet and are subject to change. -# We do not provide any BC/FC guarntees +# We do not provide any BC/FC guarantees diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py index 06b0b632ab9327..2e31b5ec3cec32 100644 --- a/torch/nn/attention/experimental/_paged_attention.py +++ b/torch/nn/attention/experimental/_paged_attention.py @@ -29,7 +29,7 @@ class PagedAttention: """ PagedAttention supports flex attention inference with a large batch size. With PagedAttention, a batch of key/value tensors with varying kv length - is splitted into tensor blocks of fixed length and cached in a compact way. + is split into tensor blocks of fixed length and cached in a compact way. Thus we can avoid redundant memory consumption due to varying kv length and support a larger batch size. """ @@ -182,9 +182,7 @@ def assign( logical_block_offset = input_pos % self.page_size # [B, S] physical_block_idx = torch.gather( self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64) - ).to( - torch.int32 - ) # [B, S] + ).to(torch.int32) # [B, S] addr = (physical_block_idx * self.page_size + logical_block_offset).view( -1 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index f36571f58a79a2..15a00e1a9d342b 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs # flake8: noqa: B950 """This module implements the user facing API for flex_attention in PyTorch.""" + import functools import inspect import itertools @@ -293,12 +294,12 @@ def __init__( assert kv_indices is not None, "kv_indices must be provided" assert q_num_blocks is not None, "q_num_blocks must be provided" assert q_indices is not None, "q_indices must be provided" - assert (full_kv_num_blocks is None) == ( - full_kv_indices is None - ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted" - assert (full_q_num_blocks is None) == ( - full_q_indices is None - ), "full_q_num_blocks and full_q_indices must be both provided or omitted" + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) + assert (full_q_num_blocks is None) == (full_q_indices is None), ( + "full_q_num_blocks and full_q_indices must be both provided or omitted" + ) self.seq_lengths = seq_lengths self.kv_num_blocks = kv_num_blocks @@ -344,9 +345,9 @@ def from_kv_blocks( if kv_indices.dim() < 2: raise RuntimeError("BlockMask must have at least 2 dimensions") - assert (full_kv_num_blocks is None) == ( - full_kv_indices is None - ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( + "full_kv_num_blocks and full_kv_indices must be both provided or omitted" + ) # Generate q_num_blocks and q_indices q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) @@ -434,29 +435,34 @@ def __getitem__(self, index) -> "BlockMask": def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda") - assert block_mask.kv_num_blocks.shape == (4,2,4) - assert block_mask.kv_indices.shape == (4,2,4,4) + + block_mask = create_block_mask( + causal_mask, 4, 2, 512, 512, device="cuda" + ) + assert block_mask.kv_num_blocks.shape == (4, 2, 4) + assert block_mask.kv_indices.shape == (4, 2, 4, 4) # Index on batch dimension new_block_mask = block_mask[0] - assert new_block_mask.kv_num_blocks.shape == (2,4) - assert new_block_mask.kv_indices.shape == (2,4,4) + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) # Index on batch and head dimension new_block_mask = block_mask[0, 1] assert new_block_mask.kv_num_blocks.shape == (4,) - assert new_block_mask.kv_indices.shape == (4,4) + assert new_block_mask.kv_indices.shape == (4, 4) # slicing on batch and head dimension new_block_mask = block_mask[0:2, 1:2] - assert new_block_mask.kv_num_blocks.shape == (2,1,4) - assert new_block_mask.kv_indices.shape == (2,1,4,4) + assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) + assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) # slicing on batch, head, and query dimension - new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] - assert new_block_mask.kv_num_blocks.shape == (2,1,1) - assert new_block_mask.kv_indices.shape == (2,1,1,4) + new_block_mask = block_mask[ + 0:2, 1:2, torch.tensor([1], dtype=torch.int32) + ] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) + assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) """ new_kv_num_blocks = self.kv_num_blocks[index] new_kv_indices = self.kv_indices[index] @@ -485,7 +491,7 @@ def shape_or_none(x: Optional[torch.Tensor]): f"BlockMask(\n" f" kv_num_blocks={self.kv_num_blocks.shape},\n" f" kv_indices={self.kv_indices.shape},\n" - f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks )},\n" + f" full_kv_num_blocks={shape_or_none(self.full_kv_num_blocks)},\n" f" full_kv_indices={shape_or_none(self.full_kv_indices)},\n" f" q_num_blocks={shape_or_none(self.q_num_blocks)},\n" f" q_indices={shape_or_none(self.q_indices)},\n" @@ -636,7 +642,7 @@ def to(self, device: Union[torch.device, str]) -> "BlockMask": Note: This method does not modify the original BlockMask in-place. - Instead, it returns a new BlockMask instance where invidual tensor attributes + Instead, it returns a new BlockMask instance where individual tensor attributes may or may not be moved to the specified device, depending on their current device placement. """ @@ -857,6 +863,7 @@ def create_block_mask( def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx + block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) @@ -864,9 +871,9 @@ def causal_mask(b, h, q_idx, kv_idx): output = flex_attention(query, key, value, block_mask=block_mask) """ mod_type = _get_mod_type(mask_mod) - assert ( - mod_type == _ModificationType.MASK_MOD - ), f"create-block_mask requires a mask_mod function! Got {mask_mod}" + assert mod_type == _ModificationType.MASK_MOD, ( + f"create-block_mask requires a mask_mod function! Got {mask_mod}" + ) if B is None: B = 1 if H is None: @@ -962,7 +969,10 @@ def _build_seq_idx(offsets, total_length): kv_seq_idx = q_seq_idx else: # cross attention case - kv_seq_idx = _build_seq_idx(kv_offsets, kv_nt._values.shape[kv_nt._ragged_idx - 1]) # type: ignore[attr-defined] + kv_seq_idx = _build_seq_idx( + kv_offsets, + kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined] + ) # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers # to the sequence length for each sequence in the NJT, for use in given @@ -1039,10 +1049,14 @@ def create_nested_block_mask( key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) + def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) + + block_mask = create_nested_block_mask( + causal_mask, 1, 1, query, _compile=True + ) output = flex_attention(query, key, value, block_mask=block_mask) .. code-block:: python @@ -1052,11 +1066,15 @@ def causal_mask(b, h, q_idx, kv_idx): key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) + def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx + # cross attention case: pass both query and key/value NJTs - block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True) + block_mask = create_nested_block_mask( + causal_mask, 1, 1, query, key, _compile=True + ) output = flex_attention(query, key, value, block_mask=block_mask) """ # use same structure for kv as for q by default @@ -1381,7 +1399,13 @@ def score_mod( torch._dynamo.mark_static(x, -1) out, lse = flex_attention_hop( - query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options # type: ignore[union-attr] + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, # type: ignore[union-attr] ) if return_lse: return out, lse * math.log(2) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index cefed02d4743d3..a03c5ec15678d2 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -55,9 +55,7 @@ Note: This operator supports complex data types i.e. ``complex32, complex64, complex128``. -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -106,9 +104,7 @@ Note: This operator supports complex data types i.e. ``complex32, complex64, complex128``. -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -159,9 +155,7 @@ Note: This operator supports complex data types i.e. ``complex32, complex64, complex128``. -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -208,9 +202,7 @@ Note: {cudnn_reproducibility_note} -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -251,9 +243,7 @@ Note: {cudnn_reproducibility_note} -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -296,9 +286,7 @@ Note: {cudnn_reproducibility_note} -""".format( - **reproducibility_notes, **tf32_notes - ) +""".format(**reproducibility_notes, **tf32_notes) + r""" Args: @@ -395,12 +383,12 @@ Args: input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` - kernel_size: size of the pooling region. Can be a single number or a + kernel_size: size of the pooling region. Can be a single number, a single-element tuple or a tuple `(kH, kW)` - stride: stride of the pooling operation. Can be a single number or a + stride: stride of the pooling operation. Can be a single number, a single-element tuple or a tuple `(sH, sW)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a - single number or a tuple `(padH, padW)`. Default: 0 + single number, a single-element tuple or a tuple `(padH, padW)`. Default: 0 ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape. Default: ``False`` count_include_pad: when True, will include the zero-padding in the @@ -2335,9 +2323,7 @@ def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` - Bias: :math:`(out\_features)` or :math:`()` - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight -""".format( - **sparse_support_notes - ), +""".format(**sparse_support_notes), ) @@ -2535,13 +2521,13 @@ def embedding( ) if padding_idx is not None: if padding_idx > 0: - assert padding_idx < weight.size( - 0 - ), "Padding_idx must be within num_embeddings" + assert padding_idx < weight.size(0), ( + "Padding_idx must be within num_embeddings" + ) elif padding_idx < 0: - assert padding_idx >= -weight.size( - 0 - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -weight.size(0), ( + "Padding_idx must be within num_embeddings" + ) padding_idx = weight.size(0) + padding_idx else: padding_idx = -1 @@ -4701,7 +4687,7 @@ def interpolate( # noqa: F811 ) # "area" mode always requires an explicit size rather than scale factor. - # Re-use the recompute_scale_factor code path. + # Reuse the recompute_scale_factor code path. if mode == "area" and output_size is None: recompute_scale_factor = True @@ -4776,9 +4762,7 @@ def interpolate( # noqa: F811 # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. if not torch.jit.is_scripting(): - if torch.are_deterministic_algorithms_enabled() and ( - input.is_cuda or input.is_xpu - ): + if not input.is_cpu and torch.are_deterministic_algorithms_enabled(): # Use slow decomp whose backward will be in terms of index_put # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) @@ -4790,6 +4774,16 @@ def interpolate( # noqa: F811 ) if input.dim() == 5 and mode == "trilinear": assert align_corners is not None + # Two levels are necessary to prevent TorchScript from touching + # are_deterministic_algorithms_enabled. + if not torch.jit.is_scripting(): + if not input.is_cpu and torch.are_deterministic_algorithms_enabled(): + # Use slow decomp whose backward will be in terms of index_put + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module( + "torch._decomp.decompositions" + )._upsample_linear_vec(input, output_size, align_corners, scale_factors) return torch._C._nn.upsample_trilinear3d( input, output_size, align_corners, scale_factors ) @@ -4921,7 +4915,7 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 This is equivalent with ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. - Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo + Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for volumetric (5 dimensional) inputs. Args: @@ -5800,15 +5794,15 @@ def _in_projection( Eq, Ev, ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == ( - Eq, - ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" - assert b_k is None or b_k.shape == ( - Eq, - ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" - assert b_v is None or b_v.shape == ( - Eq, - ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + assert b_q is None or b_q.shape == (Eq,), ( + f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + ) + assert b_k is None or b_k.shape == (Eq,), ( + f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + ) + assert b_v is None or b_v.shape == (Eq,), ( + f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + ) return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) @@ -5914,9 +5908,7 @@ def forward(self, ...): Note: {cudnn_reproducibility_note} - """.format( - **reproducibility_notes - ) + """.format(**reproducibility_notes) + r""" Args: query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. @@ -6026,9 +6018,9 @@ def _mha_shape_check( ) if attn_mask.dim() == 3: expected_shape = (num_heads, query.shape[0], key.shape[0]) - assert ( - attn_mask.shape == expected_shape - ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" + assert attn_mask.shape == expected_shape, ( + f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" + ) else: raise AssertionError( f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" @@ -6117,11 +6109,6 @@ def multi_head_attention_forward( ) -> tuple[Tensor, Optional[Tensor]]: r"""Forward method for MultiHeadAttention. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - See :class:`torch.nn.MultiheadAttention` for details. Args: @@ -6294,45 +6281,45 @@ def multi_head_attention_forward( # longer causal. is_causal = False - assert ( - embed_dim == embed_dim_to_check - ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) if isinstance(embed_dim, torch.Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert head_dim * num_heads == embed_dim, ( + f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + ) if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert ( - key.shape[:2] == value.shape[:2] - ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) else: - assert ( - key.shape == value.shape - ), f"key shape {key.shape} does not match value shape {value.shape}" + assert key.shape == value.shape, ( + f"key shape {key.shape} does not match value shape {value.shape}" + ) # # compute in-projection # if not use_separate_proj_weight: - assert ( - in_proj_weight is not None - ), "use_separate_proj_weight is False but in_proj_weight is None" + assert in_proj_weight is not None, ( + "use_separate_proj_weight is False but in_proj_weight is None" + ) q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) else: - assert ( - q_proj_weight is not None - ), "use_separate_proj_weight is True but q_proj_weight is None" - assert ( - k_proj_weight is not None - ), "use_separate_proj_weight is True but k_proj_weight is None" - assert ( - v_proj_weight is not None - ), "use_separate_proj_weight is True but v_proj_weight is None" + assert q_proj_weight is not None, ( + "use_separate_proj_weight is True but q_proj_weight is None" + ) + assert k_proj_weight is not None, ( + "use_separate_proj_weight is True but k_proj_weight is None" + ) + assert v_proj_weight is not None, ( + "use_separate_proj_weight is True but v_proj_weight is None" + ) if in_proj_bias is None: b_q = b_k = b_v = None else: @@ -6393,23 +6380,23 @@ def multi_head_attention_forward( k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert ( - static_k.size(0) == bsz * num_heads - ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" - assert ( - static_k.size(2) == head_dim - ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + assert static_k.size(0) == bsz * num_heads, ( + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + ) + assert static_k.size(2) == head_dim, ( + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + ) k = static_k if static_v is None: v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert ( - static_v.size(0) == bsz * num_heads - ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" - assert ( - static_v.size(2) == head_dim - ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + assert static_v.size(0) == bsz * num_heads, ( + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + ) + assert static_v.size(2) == head_dim, ( + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + ) v = static_v # add zero attention along batch dimension (now first) @@ -6456,9 +6443,9 @@ def multi_head_attention_forward( _B, _Nt, E = q.shape q_scaled = q * math.sqrt(1.0 / float(E)) - assert not ( - is_causal and attn_mask is None - ), "FIXME: is_causal not implemented for need_weights" + assert not (is_causal and attn_mask is None), ( + "FIXME: is_causal not implemented for need_weights" + ) if attn_mask is not None: attn_output_weights = torch.baddbmm( diff --git a/torch/nn/init.py b/torch/nn/init.py index 3d0600b43b68f9..af31b6fa228249 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -1,30 +1,97 @@ -# mypy: allow-untyped-defs """This file contains utilities for initializing neural network parameters.""" + import math import warnings -from typing import Optional as _Optional +from typing import Callable, Literal, Optional as _Optional, TypeVar, Union +from typing_extensions import ParamSpec import torch from torch import Tensor +__all__ = [ + "calculate_gain", + "uniform_", + "normal_", + "trunc_normal_", + "constant_", + "ones_", + "zeros_", + "eye_", + "dirac_", + "xavier_uniform_", + "xavier_normal_", + "kaiming_uniform_", + "kaiming_normal_", + "orthogonal_", + "sparse_", + # Deprecated aliases (for backward compatibility) + "uniform", + "normal", + "constant", + "eye", + "dirac", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + "orthogonal", + "sparse", +] + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + +_NonlinearityType = Literal[ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "sigmoid", + "tanh", + "relu", + "leaky_relu", + "selu", +] + +_FanMode = Literal["fan_in", "fan_out"] + + # These no_grad_* functions are necessary as wrappers around the parts of these # functions that use `with torch.no_grad()`. The JIT doesn't support context # managers, so these need to be implemented as builtins. Using these wrappers -# lets us keep those builtins small and re-usable. -def _no_grad_uniform_(tensor, a, b, generator=None): +# lets us keep those builtins small and reusable. +def _no_grad_uniform_( + tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None +) -> Tensor: with torch.no_grad(): return tensor.uniform_(a, b, generator=generator) -def _no_grad_normal_(tensor, mean, std, generator=None): +def _no_grad_normal_( + tensor: Tensor, + mean: float, + std: float, + generator: _Optional[torch.Generator] = None, +) -> Tensor: with torch.no_grad(): return tensor.normal_(mean, std, generator=generator) -def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): +def _no_grad_trunc_normal_( + tensor: Tensor, + mean: float, + std: float, + a: float, + b: float, + generator: _Optional[torch.Generator] = None, +) -> Tensor: # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): + def norm_cdf(x: float) -> float: # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 @@ -59,17 +126,19 @@ def norm_cdf(x): return tensor -def _no_grad_fill_(tensor, val): +def _no_grad_fill_(tensor: Tensor, val: float) -> Tensor: with torch.no_grad(): return tensor.fill_(val) -def _no_grad_zero_(tensor): +def _no_grad_zero_(tensor: Tensor) -> Tensor: with torch.no_grad(): return tensor.zero_() -def calculate_gain(nonlinearity, param=None): +def calculate_gain( + nonlinearity: _NonlinearityType, param: _Optional[Union[int, float]] = None +) -> float: r"""Return the recommended gain value for the given nonlinearity function. The values are as follows: @@ -99,7 +168,9 @@ def calculate_gain(nonlinearity, param=None): param: optional parameter for the non-linear function Examples: - >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + >>> gain = nn.init.calculate_gain( + ... "leaky_relu", 0.2 + ... ) # leaky_relu with negative_slope=0.2 .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html """ @@ -268,7 +339,7 @@ def zeros_(tensor: Tensor) -> Tensor: return _no_grad_zero_(tensor) -def eye_(tensor): +def eye_(tensor: Tensor) -> Tensor: r"""Fill the 2-dimensional input `Tensor` with the identity matrix. Preserves the identity of the inputs in `Linear` layers, where as @@ -289,7 +360,7 @@ def eye_(tensor): return tensor -def dirac_(tensor, groups=1): +def dirac_(tensor: Tensor, groups: int = 1) -> Tensor: r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function. Preserves the identity of the inputs in `Convolutional` @@ -342,7 +413,7 @@ def dirac_(tensor, groups=1): return tensor -def _calculate_fan_in_and_fan_out(tensor): +def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]: dimensions = tensor.dim() if dimensions < 2: raise ValueError( @@ -387,15 +458,7 @@ def xavier_uniform_( Examples: >>> w = torch.empty(3, 5) - >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) - - Note: - Be aware that ``fan_in`` and ``fan_out`` are calculated assuming - that the weight matrix is used in a transposed manner, - (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). - This is important for correct initialization. - If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, - pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu")) """ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) @@ -428,14 +491,6 @@ def xavier_normal_( Examples: >>> w = torch.empty(3, 5) >>> nn.init.xavier_normal_(w) - - Note: - Be aware that ``fan_in`` and ``fan_out`` are calculated assuming - that the weight matrix is used in a transposed manner, - (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). - This is important for correct initialization. - If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, - pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. """ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) @@ -443,7 +498,7 @@ def xavier_normal_( return _no_grad_normal_(tensor, 0.0, std, generator) -def _calculate_correct_fan(tensor, mode): +def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int: mode = mode.lower() valid_modes = ["fan_in", "fan_out"] if mode not in valid_modes: @@ -456,10 +511,10 @@ def _calculate_correct_fan(tensor, mode): def kaiming_uniform_( tensor: Tensor, a: float = 0, - mode: str = "fan_in", - nonlinearity: str = "leaky_relu", + mode: _FanMode = "fan_in", + nonlinearity: _NonlinearityType = "leaky_relu", generator: _Optional[torch.Generator] = None, -): +) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. The method is described in `Delving deep into rectifiers: Surpassing @@ -486,7 +541,7 @@ def kaiming_uniform_( Examples: >>> w = torch.empty(3, 5) - >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu") Note: Be aware that ``fan_in`` and ``fan_out`` are calculated assuming @@ -521,10 +576,10 @@ def kaiming_uniform_( def kaiming_normal_( tensor: Tensor, a: float = 0, - mode: str = "fan_in", - nonlinearity: str = "leaky_relu", + mode: _FanMode = "fan_in", + nonlinearity: _NonlinearityType = "leaky_relu", generator: _Optional[torch.Generator] = None, -): +) -> Tensor: r"""Fill the input `Tensor` with values using a Kaiming normal distribution. The method is described in `Delving deep into rectifiers: Surpassing @@ -551,7 +606,7 @@ def kaiming_normal_( Examples: >>> w = torch.empty(3, 5) - >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') + >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu") Note: Be aware that ``fan_in`` and ``fan_out`` are calculated assuming @@ -572,10 +627,10 @@ def kaiming_normal_( def orthogonal_( - tensor, - gain=1, + tensor: Tensor, + gain: float = 1, generator: _Optional[torch.Generator] = None, -): +) -> Tensor: r"""Fill the input `Tensor` with a (semi) orthogonal matrix. Described in `Exact solutions to the nonlinear dynamics of learning in deep @@ -623,11 +678,11 @@ def orthogonal_( def sparse_( - tensor, - sparsity, - std=0.01, + tensor: Tensor, + sparsity: float, + std: float = 0.01, generator: _Optional[torch.Generator] = None, -): +) -> Tensor: r"""Fill the 2D input `Tensor` as a sparse matrix. The non-zero elements will be drawn from the normal distribution @@ -661,11 +716,11 @@ def sparse_( # for backward compatibility -def _make_deprecate(meth): +def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]: new_name = meth.__name__ old_name = new_name[:-1] - def deprecated_init(*args, **kwargs): + def deprecated_init(*args: _P.args, **kwargs: _P.kwargs) -> _R: warnings.warn( f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", FutureWarning, diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 54a2dec94e1887..5fe33feedbf2c2 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -977,13 +977,14 @@ def _is_make_fx_tracing(): class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - - Method described in the paper: - `Attention Is All You Need `_. + This MultiheadAttention layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. Multi-Head Attention is defined as: @@ -1076,9 +1077,9 @@ def __init__( self.dropout = dropout self.batch_first = batch_first self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert self.head_dim * num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) if not self._qkv_same_embed_dim: self.q_proj_weight = Parameter( @@ -1275,8 +1276,10 @@ def forward( elif query.is_nested and ( key_padding_mask is not None or attn_mask is not None ): - why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ + why_not_fast_path = ( + "supplying both src_key_padding_mask and src_mask at the same time \ is not supported with NestedTensor input" + ) elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 93e11c63082e15..cd4076c7a11bd7 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -18,13 +18,15 @@ class AdaptiveLogSoftmaxWithLoss(Module): - """Efficient softmax approximation. + ( + """Efficient softmax approximation. As described in `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou `__. -""" r""" +""" + r""" Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when the label distribution is highly imbalanced, for example in natural language modelling, where the word @@ -104,6 +106,7 @@ class AdaptiveLogSoftmaxWithLoss(Module): .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law """ + ) in_features: int n_classes: int @@ -182,8 +185,7 @@ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: if targ_dim == 1: if input_.size(0) != target_.size(0): raise RuntimeError( - "Input and target should have the same size " - "in the batch dimension." + "Input and target should have the same size in the batch dimension." ) if input_.dim() != 2: raise RuntimeError( diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 298ab149639d7e..55ed1f4a01a8f0 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -86,31 +86,30 @@ class Sequential(Module): # for `Conv2d(20,64,5)`. Finally, the output of # `Conv2d(20,64,5)` will be used as input to the second `ReLU` model = nn.Sequential( - nn.Conv2d(1,20,5), - nn.ReLU(), - nn.Conv2d(20,64,5), - nn.ReLU() - ) + nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() + ) # Using Sequential with OrderedDict. This is functionally the # same as the above code - model = nn.Sequential(OrderedDict([ - ('conv1', nn.Conv2d(1,20,5)), - ('relu1', nn.ReLU()), - ('conv2', nn.Conv2d(20,64,5)), - ('relu2', nn.ReLU()) - ])) + model = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 20, 5)), + ("relu1", nn.ReLU()), + ("conv2", nn.Conv2d(20, 64, 5)), + ("relu2", nn.ReLU()), + ] + ) + ) """ _modules: dict[str, Module] # type: ignore[assignment] @overload - def __init__(self, *args: Module) -> None: - ... + def __init__(self, *args: Module) -> None: ... @overload - def __init__(self, arg: OrderedDict[str, Module]) -> None: - ... + def __init__(self, arg: OrderedDict[str, Module]) -> None: ... def __init__(self, *args): super().__init__() @@ -365,12 +364,10 @@ def _get_abs_string_index(self, idx): return str(idx) @overload - def __getitem__(self, idx: slice) -> ModuleList: - ... + def __getitem__(self, idx: slice) -> ModuleList: ... @overload - def __getitem__(self, idx: int) -> Module: - ... + def __getitem__(self, idx: int) -> Module: ... @_copy_to_script_wrapper def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]: @@ -489,7 +486,7 @@ def extend(self, modules: Iterable[Module]) -> Self: self.add_module(str(offset + i), module) return self - # remove forward alltogether to fallback on Module's _forward_unimplemented + # remove forward altogether to fallback on Module's _forward_unimplemented class ModuleDict(Module): @@ -521,14 +518,12 @@ class ModuleDict(Module): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() - self.choices = nn.ModuleDict({ - 'conv': nn.Conv2d(10, 10, 3), - 'pool': nn.MaxPool2d(3) - }) - self.activations = nn.ModuleDict([ - ['lrelu', nn.LeakyReLU()], - ['prelu', nn.PReLU()] - ]) + self.choices = nn.ModuleDict( + {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)} + ) + self.activations = nn.ModuleDict( + [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]] + ) def forward(self, x, choice, act): x = self.choices[choice](x) @@ -631,7 +626,7 @@ def update(self, modules: Mapping[str, Module]) -> None: # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] - # remove forward alltogether to fallback on Module's _forward_unimplemented + # remove forward altogether to fallback on Module's _forward_unimplemented class ParameterList(Module): @@ -653,7 +648,9 @@ class ParameterList(Module): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() - self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)]) + self.params = nn.ParameterList( + [nn.Parameter(torch.randn(10, 10)) for i in range(10)] + ) def forward(self, x): # ParameterList can act as an iterable, or be indexed using ints @@ -678,12 +675,10 @@ def _get_abs_string_index(self, idx): return str(idx) @overload - def __getitem__(self, idx: int) -> Any: - ... + def __getitem__(self, idx: int) -> Any: ... @overload - def __getitem__(self: T, idx: slice) -> T: - ... + def __getitem__(self: T, idx: slice) -> T: ... def __getitem__(self, idx): if isinstance(idx, slice): @@ -805,10 +800,12 @@ class ParameterDict(Module): class MyModule(nn.Module): def __init__(self) -> None: super().__init__() - self.params = nn.ParameterDict({ - 'left': nn.Parameter(torch.randn(5, 10)), - 'right': nn.Parameter(torch.randn(5, 10)) - }) + self.params = nn.ParameterDict( + { + "left": nn.Parameter(torch.randn(5, 10)), + "right": nn.Parameter(torch.randn(5, 10)), + } + ) def forward(self, x, choice): x = self.params[choice].mm(x) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index e723619f227244..e9c674b9982939 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -66,8 +66,9 @@ class _ConvNd(Module): ] __annotations__ = {"bias": Optional[torch.Tensor]} - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body] - ... + def _conv_forward( # type: ignore[empty-body] + self, input: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: ... in_channels: int _reversed_padding_repeated_twice: list[int] @@ -187,10 +188,7 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def extra_repr(self): - s = ( - "{in_channels}, {out_channels}, kernel_size={kernel_size}" - ", stride={stride}" - ) + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): @@ -279,9 +277,7 @@ class Conv1d(_ConvNd): padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: @@ -450,9 +446,7 @@ class Conv2d(_ConvNd): output. Default: ``True`` padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: @@ -619,9 +613,7 @@ class Conv3d(_ConvNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: @@ -883,9 +875,7 @@ class ConvTranspose1d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: @@ -1051,9 +1041,7 @@ class ConvTranspose2d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: @@ -1249,9 +1237,7 @@ class ConvTranspose3d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format( - **reproducibility_notes, **convolution_notes - ) + """.format(**reproducibility_notes, **convolution_notes) + r""" Shape: diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index 39b702c38e1743..6ae92b92e07eb7 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -96,8 +96,8 @@ class Unflatten(Module): >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) - >>> input = torch.randn(2, 50, names=('N', 'features')) - >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) + >>> input = torch.randn(2, 50, names=("N", "features")) + >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 58397caec32c8c..43d2f3c4f592b2 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -9,7 +9,8 @@ class Fold(Module): - r"""Combines an array of sliding local blocks into a large containing tensor. + ( + r"""Combines an array of sliding local blocks into a large containing tensor. Consider a batched :attr:`input` tensor containing sliding local blocks, e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, @@ -42,10 +43,12 @@ class Fold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. -""" """ +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. -""" r""" +""" + r""" Args: output_size (int or tuple): the shape of the spatial dimensions of the output (i.e., ``output.sizes()[2:]``) @@ -119,6 +122,7 @@ class Fold(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + ) __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"] output_size: _size_any_t @@ -162,7 +166,8 @@ def extra_repr(self) -> str: class Unfold(Module): - r"""Extracts sliding local blocks from a batched input tensor. + ( + r"""Extracts sliding local blocks from a batched input tensor. Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, where :math:`N` is the batch dimension, :math:`C` is the channel dimension, @@ -194,10 +199,12 @@ class Unfold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. -""" """ +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. -""" r""" +""" + r""" Args: kernel_size (int or tuple): the size of the sliding blocks dilation (int or tuple, optional): a parameter that controls the @@ -283,6 +290,7 @@ class Unfold(Module): https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + ) __constants__ = ["kernel_size", "dilation", "padding", "stride"] kernel_size: _size_any_t diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index 002447fc0c5977..46e7c7be63dbc9 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -15,11 +15,9 @@ class _LazyProtocol(Protocol): https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes """ - def _register_load_state_dict_pre_hook(self, hook): - ... + def _register_load_state_dict_pre_hook(self, hook): ... - def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): - ... + def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ... def _lazy_load_hook( self, @@ -30,34 +28,26 @@ def _lazy_load_hook( missing_keys, unexpected_keys, error_msgs, - ): - ... + ): ... - def _get_name(self): - ... + def _get_name(self): ... - def _infer_parameters(self, module, input): - ... + def _infer_parameters(self, module, input): ... @property - def _parameters(self): - ... + def _parameters(self): ... @property - def _buffers(self): - ... + def _buffers(self): ... @property - def _non_persistent_buffers_set(self): - ... + def _non_persistent_buffers_set(self): ... @property - def _load_hook(self): - ... + def _load_hook(self): ... @property - def _initialize_hook(self): - ... + def _initialize_hook(self): ... class LazyModuleMixin: @@ -86,17 +76,17 @@ class LazyModuleMixin: >>> # xdoctest: +SKIP >>> class LazyMLP(torch.nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... self.fc1 = torch.nn.LazyLinear(10) - ... self.relu1 = torch.nn.ReLU() - ... self.fc2 = torch.nn.LazyLinear(1) - ... self.relu2 = torch.nn.ReLU() + ... def __init__(self) -> None: + ... super().__init__() + ... self.fc1 = torch.nn.LazyLinear(10) + ... self.relu1 = torch.nn.ReLU() + ... self.fc2 = torch.nn.LazyLinear(1) + ... self.relu2 = torch.nn.ReLU() ... - ... def forward(self, input): - ... x = self.relu1(self.fc1(input)) - ... y = self.relu2(self.fc2(x)) - ... return y + ... def forward(self, input): + ... x = self.relu1(self.fc1(input)) + ... y = self.relu2(self.fc2(x)) + ... return y >>> # constructs a network with lazy modules >>> lazy_mlp = LazyMLP() >>> # transforms the network's device and dtype @@ -109,7 +99,7 @@ class LazyModuleMixin: (relu2): ReLU() ) >>> # performs a dry run to initialize the network's lazy modules - >>> lazy_mlp(torch.ones(10,10).cuda()) + >>> lazy_mlp(torch.ones(10, 10).cuda()) >>> # after initialization, LazyLinear modules become regular Linear modules >>> lazy_mlp LazyMLP( @@ -180,7 +170,7 @@ class LazyModuleMixin: cls_to_become: Optional[type[Any]] = None def __init__(self: _LazyProtocol, *args, **kwargs): - # Mypy doesnt like this super call in a mixin + # Mypy doesn't like this super call in a mixin super().__init__(*args, **kwargs) # type: ignore[misc] self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) self._initialize_hook = self.register_forward_pre_hook( @@ -260,7 +250,7 @@ def has_uninitialized_params(self: _LazyProtocol): def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): r"""Infers the size and initializes the parameters according to the provided input batch. - Given a module that contains parameters that were declared inferrable + Given a module that contains parameters that were declared inferable using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass in the complete module using the provided input to initialize all the parameters as needed. diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 75d5c91756df0d..4248a1b976b42d 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -119,6 +119,7 @@ class L1Loss(_Loss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ["reduction"] def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: @@ -233,6 +234,7 @@ class NLLLoss(_WeightedLoss): >>> loss = loss_fn(output, target) >>> loss.backward() """ + __constants__ = ["ignore_index", "reduction"] ignore_index: int @@ -331,6 +333,7 @@ class PoissonNLLLoss(_Loss): - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, the same shape as the input. """ + __constants__ = ["log_input", "full", "eps", "reduction"] log_input: bool full: bool @@ -427,6 +430,7 @@ class GaussianNLLLoss(_Loss): Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138. """ + __constants__ = ["full", "eps", "reduction"] full: bool eps: float @@ -467,7 +471,7 @@ class KLDivLoss(_Loss): .. code-block:: python - if not log_target: # default + if not log_target: # default loss_pointwise = target * (target.log() - input) else: loss_pointwise = target.exp() * (target - input) @@ -527,6 +531,7 @@ class KLDivLoss(_Loss): >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1) >>> output = kl_loss(input, log_target) """ + __constants__ = ["reduction"] def __init__( @@ -601,6 +606,7 @@ class MSELoss(_Loss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ["reduction"] def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: @@ -684,6 +690,7 @@ class BCELoss(_WeightedLoss): >>> output = loss(m(input), target) >>> output.backward() """ + __constants__ = ["reduction"] def __init__( @@ -785,7 +792,7 @@ class BCEWithLogitsLoss(_Loss): operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of size [B, C, H, W] will apply different pos_weights to each element of the batch or [C, H, W] the same pos_weights across the batch. To apply the same positive weight - along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. Default: ``None`` Shape: @@ -876,6 +883,7 @@ class HingeEmbeddingLoss(_Loss): - Target: :math:`(*)`, same shape as the input - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input """ + __constants__ = ["margin", "reduction"] margin: float @@ -950,6 +958,7 @@ class MultiLabelMarginLoss(_Loss): tensor(0.85...) """ + __constants__ = ["reduction"] def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: @@ -1030,6 +1039,7 @@ class SmoothL1Loss(_Loss): - Target: :math:`(*)`, same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. """ + __constants__ = ["reduction"] def __init__( @@ -1092,6 +1102,7 @@ class HuberLoss(_Loss): - Target: :math:`(*)`, same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. """ + __constants__ = ["reduction", "delta"] def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: @@ -1134,6 +1145,7 @@ class SoftMarginLoss(_Loss): shape as input. """ + __constants__ = ["reduction"] def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: @@ -1215,7 +1227,7 @@ class probabilities only when a single class label per minibatch item is too res Args: weight (Tensor, optional): a manual rescaling weight given to each class. - If given, has to be a Tensor of size `C` and floating point dtype + If given, has to be a Tensor of size `C`. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field :attr:`size_average` @@ -1276,6 +1288,7 @@ class probabilities only when a single class label per minibatch item is too res >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ["ignore_index", "reduction", "label_smoothing"] ignore_index: int label_smoothing: float @@ -1342,6 +1355,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): - Target: :math:`(N, C)`, label targets must have the same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. """ + __constants__ = ["reduction"] def __init__( @@ -1410,6 +1424,7 @@ class CosineEmbeddingLoss(_Loss): >>> output = loss(input1, input2, target) >>> output.backward() """ + __constants__ = ["margin", "reduction"] margin: float @@ -1475,6 +1490,7 @@ class MarginRankingLoss(_Loss): >>> output = loss(input1, input2, target) >>> output.backward() """ + __constants__ = ["margin", "reduction"] margin: float @@ -1554,6 +1570,7 @@ class MultiMarginLoss(_WeightedLoss): >>> loss(x, y) tensor(0.32...) """ + __constants__ = ["p", "margin", "reduction"] margin: float p: int @@ -1657,6 +1674,7 @@ class TripletMarginLoss(_Loss): .. _Learning shallow convolutional feature descriptors with triplet losses: https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html """ + __constants__ = ["margin", "p", "eps", "swap", "reduction"] margin: float p: float @@ -1794,6 +1812,7 @@ class TripletMarginWithDistanceLoss(_Loss): V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html """ + __constants__ = ["margin", "swap", "reduction"] margin: float swap: bool @@ -1889,10 +1908,10 @@ class CTCLoss(_Loss): Examples: >>> # Target are to be padded - >>> T = 50 # Input sequence length - >>> C = 20 # Number of classes (including blank) - >>> N = 16 # Batch size - >>> S = 30 # Target sequence length of longest target in batch (padding length) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) @@ -1902,16 +1921,21 @@ class CTCLoss(_Loss): >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) - >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) + >>> target_lengths = torch.randint( + ... low=S_min, + ... high=S, + ... size=(N,), + ... dtype=torch.long, + ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded - >>> T = 50 # Input sequence length - >>> C = 20 # Number of classes (including blank) - >>> N = 16 # Batch size + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() @@ -1919,15 +1943,20 @@ class CTCLoss(_Loss): >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) - >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(sum(target_lengths),), + ... dtype=torch.long, + ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded and unbatched (effectively N=1) - >>> T = 50 # Input sequence length - >>> C = 20 # Number of classes (including blank) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) >>> >>> # Initialize random batch of input vectors, for *size = (T,C) >>> # xdoctest: +SKIP("FIXME: error in doctest") @@ -1936,7 +1965,12 @@ class CTCLoss(_Loss): >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) - >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long) + >>> target = torch.randint( + ... low=1, + ... high=C, + ... size=(target_lengths,), + ... dtype=torch.long, + ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() @@ -1963,6 +1997,7 @@ class CTCLoss(_Loss): True``. Please see the notes on :doc:`/notes/randomness` for background. """ + __constants__ = ["blank", "reduction"] blank: int zero_infinity: bool diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 5201f1f26c73d0..12fc6907fb4edf 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -412,6 +412,7 @@ class Module: import torch.nn as nn import torch.nn.functional as F + class Model(nn.Module): def __init__(self) -> None: super().__init__() @@ -954,9 +955,13 @@ def compute_should_use_set_data(tensor, tensor_applied): param_applied = fn(param) p_should_use_set_data = compute_should_use_set_data(param, param_applied) + from torch._subclasses.fake_tensor import FakeTensor + # subclasses may have multiple child tensors so we need to use swap_tensors p_should_use_swap_tensors = ( - should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + should_use_swap_tensors + or is_traceable_wrapper_subclass(param_applied) + or isinstance(param, FakeTensor) ) param_grad = param.grad @@ -1226,16 +1231,13 @@ def to( device: Optional[DeviceLikeType] = ..., dtype: Optional[dtype] = ..., non_blocking: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload - def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: - ... + def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: ... @overload - def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: - ... + def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: ... def to(self, *args, **kwargs): r"""Move and/or cast the parameters and buffers. @@ -1748,7 +1750,11 @@ def _slow_forward(self, *input, **kwargs): if recording_scopes: # type ignore was added because at this point one knows that # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] - name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 + name = ( + torch.jit._trace._trace_module_map[self] # type: ignore[index] + if self in torch.jit._trace._trace_module_map # type: ignore[operator] + else None + ) # noqa: B950 if name: tracing_state.push_scope(name) else: @@ -2035,7 +2041,10 @@ def remove_from(*dicts_or_sets): # register_buffer() method that doesn't have the "persistent" # argument. Only pass it in if it is accepted otherwise assume # it is always true - if self.register_buffer is torch.nn.Module.register_buffer: + if ( + getattr(self.register_buffer, "__func__", None) + is torch.nn.Module.register_buffer + ): self.register_buffer(name, value, persistent) else: sign = inspect.signature(self.register_buffer) @@ -2157,13 +2166,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): @overload def state_dict( - self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ... - ) -> T_destination: - ... + self, + *, + destination: T_destination, + prefix: str = ..., + keep_vars: bool = ..., + ) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: - ... + def state_dict( + self, + *, + prefix: str = ..., + keep_vars: bool = ..., + ) -> dict[str, Any]: ... # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. # Also remove the logic for arg parsing together. @@ -2521,15 +2537,14 @@ def load_state_dict( assign (bool, optional): When set to ``False``, the properties of the tensors in the current module are preserved whereas setting it to ``True`` preserves properties of the Tensors in the state dict. The only - exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s - for which the value from the module is preserved. - Default: ``False`` + exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` + for which the value from the module is preserved. Default: ``False`` Returns: ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - * **missing_keys** is a list of str containing any keys that are expected + * ``missing_keys`` is a list of str containing any keys that are expected by this module but missing from the provided ``state_dict``. - * **unexpected_keys** is a list of str containing the keys that are not + * ``unexpected_keys`` is a list of str containing the keys that are not expected by this module but present in the provided ``state_dict``. Note: diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index aad0bf9fe84651..52354373acaf26 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -358,6 +358,7 @@ class RMSNorm(Module): >>> rms_norm(input) """ + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: tuple[int, ...] eps: Optional[float] diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 13b4b8307b7339..ff90bfd07a4d98 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -58,6 +58,7 @@ class CircularPad1d(_CircularPadNd): padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 2-`tuple`, uses (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than or equal to the corresponding input dimension. Shape: - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. @@ -107,6 +108,7 @@ class CircularPad2d(_CircularPadNd): padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that padding size should be less than or equal to the corresponding input dimension. Shape: - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. @@ -168,6 +170,7 @@ class CircularPad3d(_CircularPadNd): (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than or equal to the corresponding input dimension. Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. @@ -380,6 +383,7 @@ class ReflectionPad1d(_ReflectionPadNd): padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 2-`tuple`, uses (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that padding size should be less than the corresponding input dimension. Shape: - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. @@ -476,6 +480,7 @@ class ReflectionPad3d(_ReflectionPadNd): (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that padding size should be less than the corresponding input dimension. Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. @@ -539,6 +544,7 @@ class ReplicationPad1d(_ReplicationPadNd): padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 2-`tuple`, uses (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`) + Note that the output dimensions must remain positive. Shape: - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`. @@ -580,6 +586,7 @@ class ReplicationPad2d(_ReplicationPadNd): padding (int, tuple): the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`) + Note that the output dimensions must remain positive. Shape: - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`. @@ -634,6 +641,7 @@ class ReplicationPad3d(_ReplicationPadNd): (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`, :math:`\text{padding\_front}`, :math:`\text{padding\_back}`) + Note that the output dimensions must remain positive. Shape: - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`. diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index a6ec9f9c8be8db..448a3bb1f981ee 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -696,7 +696,7 @@ class AvgPool2d(_AvgPoolNd): The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: - - a single ``int`` -- in which case the same value is used for the height and width dimension + - a single ``int`` or a single-element tuple -- in which case the same value is used for the height and width dimension - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, and the second `int` for the width dimension diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 6fd07b53d897af..1d5994f919139e 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -253,7 +253,8 @@ def flatten_parameters(self) -> None: # alias would break the assumptions of the uniqueness check in # Module.named_parameters(). unique_data_ptrs = { - p.data_ptr() for p in self._flat_weights # type: ignore[union-attr] + p.data_ptr() # type: ignore[union-attr] + for p in self._flat_weights } if len(unique_data_ptrs) != len(self._flat_weights): return @@ -611,12 +612,10 @@ def __init__( bidirectional: bool = False, device=None, dtype=None, - ) -> None: - ... + ) -> None: ... @overload - def __init__(self, *args, **kwargs): - ... + def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): if "proj_size" in kwargs: @@ -969,12 +968,10 @@ def __init__( proj_size: int = 0, device=None, dtype=None, - ) -> None: - ... + ) -> None: ... @overload - def __init__(self, *args, **kwargs): - ... + def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) @@ -1304,12 +1301,10 @@ def __init__( bidirectional: bool = False, device=None, dtype=None, - ) -> None: - ... + ) -> None: ... @overload - def __init__(self, *args, **kwargs): - ... + def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): if "proj_size" in kwargs: diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index d9991073ee8c9b..e3b8fafa6a2740 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -59,9 +59,11 @@ class Embedding(Module): embedding = nn.Embedding(n, d, max_norm=1.0) W = torch.randn((m, d), requires_grad=True) idx = torch.tensor([1, 2]) - a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable + a = ( + embedding.weight.clone() @ W.t() + ) # weight must be cloned for this to be differentiable b = embedding(idx) @ W.t() # modifies weight in-place - out = (a.unsqueeze(0) + b.unsqueeze(1)) + out = a.unsqueeze(0) + b.unsqueeze(1) loss = out.sigmoid().prod() loss.backward() @@ -150,13 +152,13 @@ def __init__( self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, ( + "Padding_idx must be within num_embeddings" + ) padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm @@ -248,9 +250,9 @@ def from_pretrained( >>> embedding(input) tensor([[ 4.0000, 5.1000, 6.3000]]) """ - assert ( - embeddings.dim() == 2 - ), "Embeddings parameter is expected to be 2-dimensional" + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) rows, cols = embeddings.shape embedding = cls( num_embeddings=rows, @@ -391,13 +393,13 @@ def __init__( self.scale_grad_by_freq = scale_grad_by_freq if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, ( + "padding_idx must be within num_embeddings" + ) padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx if _weight is None: @@ -526,9 +528,9 @@ def from_pretrained( >>> embeddingbag(input) tensor([[ 2.5000, 3.7000, 4.6500]]) """ - assert ( - embeddings.dim() == 2 - ), "Embeddings parameter is expected to be 2-dimensional" + assert embeddings.dim() == 2, ( + "Embeddings parameter is expected to be 2-dimensional" + ) rows, cols = embeddings.shape embeddingbag = cls( num_embeddings=rows, diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index e41df931c378dd..546c8b90c731ac 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -55,15 +55,17 @@ def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: class Transformer(Module): - r"""A transformer model. + r"""A basic transformer layer. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - User is able to modify the attributes as needed. The architecture - is based on the paper `Attention Is All You Need `_. + This Transformer layer implements the original Transformer architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build an efficient transformer layer from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. Args: d_model: the number of expected features in the encoder/decoder inputs (default=512). @@ -254,7 +256,9 @@ def forward( Examples: >>> # xdoctest: +SKIP - >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) + >>> output = transformer_model( + ... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask + ... ) """ is_batched = src.dim() == 3 if not self.batch_first and src.size(1) != tgt.size(1) and is_batched: @@ -307,12 +311,14 @@ def _reset_parameters(self): class TransformerEncoder(Module): r"""TransformerEncoder is a stack of N encoder layers. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - - Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + This TransformerEncoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. .. warning:: All layers in the TransformerEncoder are initialized with the same parameters. @@ -534,10 +540,14 @@ def forward( class TransformerDecoder(Module): r"""TransformerDecoder is a stack of N decoder layers. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. + This TransformerDecoder layer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. .. warning:: All layers in the TransformerDecoder are initialized with the same parameters. @@ -635,13 +645,14 @@ def forward( class TransformerEncoderLayer(Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - - This standard encoder layer is based on the paper `Attention Is All You Need `_. - Users may modify or implement in a different way during application. + This TransformerEncoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. TransformerEncoderLayer can handle either traditional torch.tensor inputs, or Nested Tensor inputs. Derived classes are expected to similarly accept @@ -677,7 +688,9 @@ class TransformerEncoderLayer(Module): >>> out = encoder_layer(src) Alternatively, when ``batch_first`` is ``True``: - >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> encoder_layer = nn.TransformerEncoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) @@ -953,13 +966,14 @@ def _ff_block(self, x: Tensor) -> Tensor: class TransformerDecoderLayer(Module): r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. - .. note:: - See `this tutorial `_ - for an in depth discussion of the performant building blocks PyTorch offers for building your own - transformer layers. - - This standard decoder layer is based on the paper `Attention Is All You Need `_. - Users may modify or implement in a different way during application. + This TransformerDecoderLayer implements the original architecture described + in the `Attention Is All You Need `_ paper. The + intent of this layer is as a reference implementation for foundational understanding + and thus it contains only limited features relative to newer Transformer architectures. + Given the fast pace of innovation in transformer-like architectures, we recommend + exploring this `tutorial `_ + to build efficient layers from building blocks in core or using higher + level libraries from the `PyTorch Ecosystem `_. Args: d_model: the number of expected features in the input (required). @@ -984,7 +998,9 @@ class TransformerDecoderLayer(Module): >>> out = decoder_layer(tgt, memory) Alternatively, when ``batch_first`` is ``True``: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) + >>> decoder_layer = nn.TransformerDecoderLayer( + ... d_model=512, nhead=8, batch_first=True + ... ) >>> memory = torch.rand(32, 10, 512) >>> tgt = torch.rand(32, 20, 512) >>> out = decoder_layer(tgt, memory) diff --git a/torch/nn/parallel/_functions.py b/torch/nn/parallel/_functions.py index e4d1285ba9f395..5170b172fbbec6 100644 --- a/torch/nn/parallel/_functions.py +++ b/torch/nn/parallel/_functions.py @@ -11,9 +11,9 @@ class Broadcast(Function): @staticmethod def forward(ctx, target_gpus, *inputs): - assert all( - i.device.type != "cpu" for i in inputs - ), "Broadcast function not implemented for CPU tensors" + assert all(i.device.type != "cpu" for i in inputs), ( + "Broadcast function not implemented for CPU tensors" + ) target_gpus = [_get_device_index(x, True) for x in target_gpus] ctx.target_gpus = target_gpus if len(inputs) == 0: @@ -56,9 +56,9 @@ def backward(ctx, *grad_outputs): class Gather(Function): @staticmethod def forward(ctx, target_device, dim, *inputs): - assert all( - i.device.type != "cpu" for i in inputs - ), "Gather function not implemented for CPU tensors" + assert all(i.device.type != "cpu" for i in inputs), ( + "Gather function not implemented for CPU tensors" + ) if target_device == "cpu": ctx.target_device = "cpu" else: diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 4192682577f901..c5db538f52bb20 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -347,20 +347,33 @@ class DistributedDataParallel(Module, Joinable): To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn up ``N`` processes, ensuring that each process exclusively works on a single GPU from 0 to N-1. This can be done by either setting - ``CUDA_VISIBLE_DEVICES`` for every process or by calling: + ``CUDA_VISIBLE_DEVICES`` for every process or by calling the following API for GPUs, >>> # xdoctest: +SKIP("undefined variables") >>> torch.cuda.set_device(i) + or calling the unified API for :ref:`accelerator`, + + >>> # xdoctest: +SKIP("undefined variables") + >>> torch.accelerator.set_device_index(i) + where i is from 0 to N-1. In each process, you should refer the following to construct this module: >>> # xdoctest: +SKIP("undefined variables") + >>> if torch.accelerator.is_available(): + >>> device_type = torch.accelerator.current_accelerator().type + >>> vendor_backend = torch.distributed.get_default_backend_for_device(device_type) + >>> >>> torch.distributed.init_process_group( - >>> backend='nccl', world_size=N, init_method='...' + >>> backend=vendor_backend, world_size=N, init_method='...' >>> ) >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) + Or you can use the latest API for initialization: + + >>> torch.distributed.init_process_group(device_id=i) + In order to spawn up multiple processes per node, you can use either ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. @@ -759,7 +772,7 @@ def __init__( "DistributedDataParallel device_ids and output_device arguments " "only work with single-device/multiple-device GPU modules or CPU modules, " f"but got device_ids {device_ids}, output_device {output_device}, " - f"and module parameters {({p.device for p in self._module_parameters})}.", + f"and module parameters { ({p.device for p in self._module_parameters}) }.", # noqa: E201,E202 ) self.device_ids = None @@ -2171,7 +2184,7 @@ def _sync_buffers(self): else: # The process with rank 0 is considered the authoritative copy. authoritative_rank = 0 - # Update self.modules_buffers incase any buffers were + # Update self.modules_buffers in case any buffers were # reassigned. self._assign_modules_buffers() self._sync_module_buffers(authoritative_rank) diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index eb6c8ec29b2046..4d66a7a71d8959 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py @@ -46,9 +46,9 @@ def parallel_apply( element of :attr:`inputs` can either be a single object as the only argument to a module, or a collection of positional arguments. """ - assert len(modules) == len( - inputs - ), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}" + assert len(modules) == len(inputs), ( + f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}" + ) if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: @@ -88,9 +88,11 @@ def _worker( if stream is None: stream = torch.cuda.current_stream(device) try: - with torch.cuda.device(device), torch.cuda.stream( - stream - ), torch.amp.autocast("cuda", enabled=autocast_enabled): + with ( + torch.cuda.device(device), + torch.cuda.stream(stream), + torch.amp.autocast("cuda", enabled=autocast_enabled), + ): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input,) diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index 34c7d5116eec09..6c6e4567efa118 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -111,8 +111,7 @@ def replicate( ) -> list[T]: if not _replicatable_module(network): raise RuntimeError( - "Cannot replicate network where python modules are " - "childrens of ScriptModule" + "Cannot replicate network where python modules are children of ScriptModule" ) if not devices: @@ -184,7 +183,7 @@ def replicate( # so setattr them as non-parameter attributes setattr(replica, key, param_copy) # expose the parameter for DDP - replica._former_parameters[key] = param_copy + replica._former_parameters[key] = param_copy # type: ignore[operator, index] for key, buf in module._buffers.items(): # type: ignore[assignment] if buf is None: for j in range(num_replicas): diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index c70d3d5a7de5b1..947f5635736515 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -35,8 +35,7 @@ def scatter( inputs: torch.Tensor, target_gpus: Sequence[Union[int, torch.device]], dim: int = ..., -) -> tuple[torch.Tensor, ...]: - ... +) -> tuple[torch.Tensor, ...]: ... @overload @@ -44,8 +43,7 @@ def scatter( inputs: T, target_gpus: Sequence[Union[int, torch.device]], dim: int = ..., -) -> list[T]: - ... +) -> list[T]: ... def scatter(inputs, target_gpus, dim=0): diff --git a/torch/nn/qat/__init__.py b/torch/nn/qat/__init__.py index 01a17572bc2e83..766b09382aa78e 100644 --- a/torch/nn/qat/__init__.py +++ b/torch/nn/qat/__init__.py @@ -4,6 +4,7 @@ This package is in the process of being deprecated. Please, use `torch.ao.nn.qat.dynamic` instead. """ + from torch.nn.qat import dynamic, modules # noqa: F403 from torch.nn.qat.modules import * # noqa: F403 diff --git a/torch/nn/qat/dynamic/__init__.py b/torch/nn/qat/dynamic/__init__.py index b8a05d8bde0f67..56838a1cfcae74 100644 --- a/torch/nn/qat/dynamic/__init__.py +++ b/torch/nn/qat/dynamic/__init__.py @@ -4,4 +4,5 @@ This package is in the process of being deprecated. Please, use `torch.ao.nn.qat.dynamic` instead. """ + from torch.nn.qat.dynamic.modules import * # noqa: F403 diff --git a/torch/nn/qat/dynamic/modules/linear.py b/torch/nn/qat/dynamic/modules/linear.py index ea69fba158d3bf..1a5c80ea213c62 100644 --- a/torch/nn/qat/dynamic/modules/linear.py +++ b/torch/nn/qat/dynamic/modules/linear.py @@ -7,4 +7,5 @@ appropriate file under the `torch/ao/nn/qat/dynamic/modules`, while adding an import statement here. """ + from torch.ao.nn.qat.dynamic.modules.linear import Linear diff --git a/torch/nn/qat/modules/__init__.py b/torch/nn/qat/modules/__init__.py index 667ae790b64823..f7f55fbdf789a7 100644 --- a/torch/nn/qat/modules/__init__.py +++ b/torch/nn/qat/modules/__init__.py @@ -4,6 +4,7 @@ This package is in the process of being deprecated. Please, use `torch.ao.nn.qat.modules` instead. """ + from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag from torch.ao.nn.qat.modules.linear import Linear diff --git a/torch/nn/qat/modules/linear.py b/torch/nn/qat/modules/linear.py index f5841a46096c18..4e822eba7e0617 100644 --- a/torch/nn/qat/modules/linear.py +++ b/torch/nn/qat/modules/linear.py @@ -7,4 +7,5 @@ appropriate file under the `torch/ao/nn/qat/modules`, while adding an import statement here. """ + from torch.ao.nn.qat.modules.linear import Linear diff --git a/torch/nn/quantizable/modules/activation.py b/torch/nn/quantizable/modules/activation.py index e4f7a5ca3b540e..28f3eee958115d 100644 --- a/torch/nn/quantizable/modules/activation.py +++ b/torch/nn/quantizable/modules/activation.py @@ -7,4 +7,5 @@ appropriate file under the `torch/ao/nn/quantizable/modules`, while adding an import statement here. """ + from torch.ao.nn.quantizable.modules.activation import MultiheadAttention diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index 592384dbdb3442..b23fae2c06aa8c 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -7,4 +7,5 @@ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, while adding an import statement here. """ + from torch.ao.nn.quantized.dynamic.modules.linear import Linear diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index aee3a75e70f455..ad1adf06d0fe9d 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -1,7 +1,13 @@ -# mypy: allow-untyped-defs +from typing import Any, Callable, TypeVar +from typing_extensions import ParamSpec + import torch import torch.nn.functional as F + +_P = ParamSpec("_P") +_R = TypeVar("_R") + from .conv_utils import ( conv_args_and_kwargs, conv_backward, @@ -17,7 +23,12 @@ @implements_per_sample_grads(F.conv3d) class ConvPerSampleGrad(torch.autograd.Function): @staticmethod - def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs): + def forward( + ctx: Any, + kwarg_names: list[str], + conv_fn: Callable[_P, _R], + *expanded_args_and_kwargs: Any, + ) -> torch.Tensor: expanded_args, expanded_kwargs = conv_args_and_kwargs( kwarg_names, expanded_args_and_kwargs ) @@ -64,5 +75,5 @@ def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs): return output @staticmethod - def backward(ctx, grad_output): - return conv_backward(ctx.conv_fn, ctx, grad_output) + def backward(ctx: Any, *grad_outputs: Any) -> Any: + return conv_backward(ctx.conv_fn, ctx, grad_outputs[0]) diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index 7b7f58b5ff5f88..74418e14386078 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -314,7 +314,7 @@ def unfold3d( Example: >>> # xdoctest: +SKIP >>> B, C, D, H, W = 3, 4, 5, 6, 7 - >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W) + >>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W) >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape torch.Size([3, 32, 120]) """ diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 1935c591346d7f..2dd4b6de4f697a 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -72,8 +72,10 @@ def reset(ew): @contextmanager def setup_rnn(use_input_variant, args, kwargs): - with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches( - args, kwargs + with ( + batch_second(args, kwargs) + if use_input_variant + else allow_smaller_batches(args, kwargs) ): yield @@ -150,23 +152,23 @@ def __torch_function__(cls, func, _, args=(), kwargs=None): ) @property - def dtype(self): + def dtype(self): # type: ignore[override] return self.orig_weight.dtype @property - def data(self): + def data(self): # type: ignore[override] return self.orig_weight.data @property - def shape(self): + def shape(self): # type: ignore[override] return self.orig_weight.shape @property - def device(self): + def device(self): # type: ignore[override] return self.orig_weight.device @property - def is_cuda(self): + def is_cuda(self): # type: ignore[override] return self.orig_weight.is_cuda def data_ptr(self): diff --git a/torch/nn/utils/_per_sample_grad.py b/torch/nn/utils/_per_sample_grad.py index eeb6e1eeaf3c04..2eae0865845eec 100644 --- a/torch/nn/utils/_per_sample_grad.py +++ b/torch/nn/utils/_per_sample_grad.py @@ -49,7 +49,9 @@ def call_for_per_sample_grads( grad_outputs by 1 / batch_size from cross batch interaction. >>> model = nn.Linear(4, 3) >>> batched_input = torch.randn(5, 4) # batch size of 5 - >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean() + >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")( + ... batched_input + ... ).mean() >>> res.backward() Note:: diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 835e5403c8f34e..2e49843fa66b9a 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -16,11 +16,7 @@ ) -__all__ = [ - "clip_grad_norm_", - "clip_grad_norm", - "clip_grad_value_", -] +__all__: list[str] = [] _tensor_or_tensors = Union[ @@ -154,9 +150,7 @@ def _clip_grads_with_norm_( return grouped_grads: dict[ tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]] - ] = _group_tensors_by_device_and_dtype( - [grads] - ) # type: ignore[assignment] + ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] clip_coef = max_norm / (total_norm + 1e-6) # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so @@ -292,3 +286,8 @@ def clip_grad_value_( else: for grad in grads: cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value) + + +clip_grad_norm.__module__ = "torch.nn.utils" +clip_grad_norm_.__module__ = "torch.nn.utils" +clip_grad_value_.__module__ = "torch.nn.utils" diff --git a/torch/nn/utils/fusion.py b/torch/nn/utils/fusion.py index c9878b0697ee6c..35406785305117 100644 --- a/torch/nn/utils/fusion.py +++ b/torch/nn/utils/fusion.py @@ -135,9 +135,9 @@ def fuse_linear_bn_eval( 2. the number of features in bn is 1 Otherwise, skip the folding path """ - assert ( - linear.out_features == bn.num_features or bn.num_features == 1 - ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1" + assert linear.out_features == bn.num_features or bn.num_features == 1, ( + "To fuse, linear.out_features == bn.num_features or bn.num_features == 1" + ) assert bn.running_mean is not None and bn.running_var is not None fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights( diff --git a/torch/nn/utils/memory_format.py b/torch/nn/utils/memory_format.py index 3c117ed2827b6e..59e54b11e3b9bb 100644 --- a/torch/nn/utils/memory_format.py +++ b/torch/nn/utils/memory_format.py @@ -63,12 +63,16 @@ def convert_conv2d_weight_memory_format( Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) - >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda") + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda" + ... ) >>> model = nn.Sequential( >>> nn.Conv2d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) - >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) + >>> model = nn.utils.convert_conv2d_weight_memory_format( + ... model, torch.channels_last + ... ) >>> out = model(input) """ # TODO: expand this to `_ConvNd` when channels_last support is extended @@ -137,12 +141,16 @@ def convert_conv3d_weight_memory_format( Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) - >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") + >>> input = torch.randint( + ... 1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda" + ... ) >>> model = nn.Sequential( >>> nn.Conv3d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) - >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> model = nn.utils.convert_conv3d_weight_memory_format( + ... model, torch.channels_last_3d + ... ) >>> out = model(input) """ diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index b4c88c898195fd..11f7106b314916 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -46,6 +46,7 @@ def cached(): .. code-block:: python import torch.nn.utils.parametrize as P + ... with P.cached(): output = model(inputs) @@ -519,24 +520,26 @@ def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) - >>> A = A + A.T # A is now symmetric + >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True >>> class RankOne(nn.Module): >>> def forward(self, x, y): - >>> # Form a rank 1 matrix multiplying two vectors + >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): - >>> # Project Z onto the rank 1 matrices + >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) - >>> # Return rescaled singular vectors + >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> - >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) + >>> linear_rank_one = P.register_parametrization( + ... nn.Linear(4, 4), "weight", RankOne() + ... ) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1 diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 583620dfa40d35..aee6bdc2ad2180 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Pruning methods.""" + import numbers from abc import ABC, abstractmethod from collections.abc import Iterable @@ -63,9 +64,9 @@ def apply_mask(self, module): """ # to carry out the multiplication, the mask needs to have been computed, # so the pruning method must know what tensor it's operating on - assert ( - self._tensor_name is not None - ), f"Module {module} has to be pruned" # this gets set in apply() + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned" + ) # this gets set in apply() mask = getattr(module, self._tensor_name + "_mask") orig = getattr(module, self._tensor_name + "_orig") pruned_tensor = mask.to(dtype=orig.dtype) * orig @@ -109,10 +110,10 @@ def _get_composite_method(cls, module, name, *args, **kwargs): old_method = hook hooks_to_remove.append(k) found += 1 - assert ( - found <= 1 - ), f"Avoid adding multiple pruning hooks to the\ + assert found <= 1, ( + f"Avoid adding multiple pruning hooks to the\ same tensor {name} of module {module}. Use a PruningContainer." + ) for k in hooks_to_remove: del module._forward_pre_hooks[k] @@ -153,9 +154,9 @@ def _get_composite_method(cls, module, name, *args, **kwargs): orig = getattr(module, name) if importance_scores is not None: - assert ( - importance_scores.shape == orig.shape - ), f"importance_scores should have the same shape as parameter {name} of {module}" + assert importance_scores.shape == orig.shape, ( + f"importance_scores should have the same shape as parameter {name} of {module}" + ) else: importance_scores = orig @@ -222,9 +223,9 @@ def prune(self, t, default_mask=None, importance_scores=None): pruned version of tensor ``t``. """ if importance_scores is not None: - assert ( - importance_scores.shape == t.shape - ), "importance_scores should have the same shape as tensor t" + assert importance_scores.shape == t.shape, ( + "importance_scores should have the same shape as tensor t" + ) else: importance_scores = t default_mask = default_mask if default_mask is not None else torch.ones_like(t) @@ -241,9 +242,9 @@ def remove(self, module): Pruning itself is NOT undone or reversed! """ # before removing pruning from a tensor, it has to have been applied - assert ( - self._tensor_name is not None - ), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply() + assert self._tensor_name is not None, ( + f"Module {module} has to be pruned before pruning can be removed" + ) # this gets set in apply() # to update module[name] to latest trained weights weight = self.apply_mask(module) # masked weights @@ -394,6 +395,8 @@ def _combine_masks(method, t, mask): raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") # compute the new mask on the unpruned slice of the tensor t + if isinstance(slc, list): + slc = tuple(slc) partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) @@ -414,7 +417,7 @@ def compute_mask(self, t, default_mask): return mask @classmethod - def apply(cls, module, name): + def apply(cls, module, name): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -469,7 +472,7 @@ def compute_mask(self, t, default_mask): return mask @classmethod - def apply(cls, module, name, amount): + def apply(cls, module, name, amount): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -528,7 +531,7 @@ def compute_mask(self, t, default_mask): return mask @classmethod - def apply(cls, module, name, amount, importance_scores=None): + def apply(cls, module, name, amount, importance_scores=None): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -625,6 +628,7 @@ def make_mask(t, dim, nchannels, nchannels_toprune): mask = torch.zeros_like(t) slc = [slice(None)] * len(t.shape) slc[dim] = channel_mask + slc = tuple(slc) mask[slc] = 1 return mask @@ -638,7 +642,7 @@ def make_mask(t, dim, nchannels, nchannels_toprune): return mask @classmethod - def apply(cls, module, name, amount, dim=-1): + def apply(cls, module, name, amount, dim=-1): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -739,6 +743,7 @@ def make_mask(t, dim, indices): # replace a None at position=dim with indices # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] slc[dim] = indices + slc = tuple(slc) # use slc to slice mask and replace all its entries with 1s # e.g.: mask[:, :, [0, 2, 3]] = 1 mask[slc] = 1 @@ -753,7 +758,7 @@ def make_mask(t, dim, indices): return mask @classmethod - def apply(cls, module, name, amount, n, dim, importance_scores=None): + def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -800,7 +805,7 @@ def compute_mask(self, t, default_mask): return mask @classmethod - def apply(cls, module, name, mask): + def apply(cls, module, name, mask): # type: ignore[override] r"""Add pruning on the fly and reparametrization of a tensor. Adds the forward pre-hook that enables pruning on the fly and @@ -842,7 +847,7 @@ def identity(module, name): Examples: >>> # xdoctest: +SKIP - >>> m = prune.identity(nn.Linear(2, 3), 'bias') + >>> m = prune.identity(nn.Linear(2, 3), "bias") >>> print(m.bias_mask) tensor([1., 1., 1.]) """ @@ -878,7 +883,7 @@ def random_unstructured(module, name, amount): Examples: >>> # xdoctest: +SKIP - >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1) + >>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1) >>> torch.sum(m.weight_mask == 0) tensor(1) @@ -921,7 +926,7 @@ def l1_unstructured(module, name, amount, importance_scores=None): Examples: >>> # xdoctest: +SKIP - >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2) + >>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2) >>> m.state_dict().keys() odict_keys(['bias', 'weight_orig', 'weight_mask']) """ @@ -961,9 +966,7 @@ def random_structured(module, name, amount, dim): Examples: >>> # xdoctest: +SKIP - >>> m = prune.random_structured( - ... nn.Linear(5, 3), 'weight', amount=3, dim=1 - ... ) + >>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1) >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0)) >>> print(columns_pruned) 3 @@ -1010,7 +1013,7 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None): Examples: >>> from torch.nn.utils import prune >>> m = prune.ln_structured( - ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') + ... nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf") ... ) """ LnStructured.apply( @@ -1063,13 +1066,17 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw Examples: >>> from torch.nn.utils import prune >>> from collections import OrderedDict - >>> net = nn.Sequential(OrderedDict([ - ... ('first', nn.Linear(10, 4)), - ... ('second', nn.Linear(4, 1)), - ... ])) + >>> net = nn.Sequential( + ... OrderedDict( + ... [ + ... ("first", nn.Linear(10, 4)), + ... ("second", nn.Linear(4, 1)), + ... ] + ... ) + ... ) >>> parameters_to_prune = ( - ... (net.first, 'weight'), - ... (net.second, 'weight'), + ... (net.first, "weight"), + ... (net.second, "weight"), ... ) >>> prune.global_unstructured( ... parameters_to_prune, @@ -1161,7 +1168,7 @@ def custom_from_mask(module, name, mask): Examples: >>> from torch.nn.utils import prune >>> m = prune.custom_from_mask( - ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) + ... nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0]) ... ) >>> print(m.bias_mask) tensor([0., 1., 0.]) @@ -1187,8 +1194,8 @@ def remove(module, name): will act. Examples: - >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2) - >>> m = remove(m, name='weight') + >>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2) + >>> m = remove(m, name="weight") """ for k, hook in module._forward_pre_hooks.items(): if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: @@ -1219,7 +1226,7 @@ def is_pruned(module): >>> m = nn.Linear(5, 7) >>> print(prune.is_pruned(m)) False - >>> prune.random_unstructured(m, name='weight', amount=0.2) + >>> prune.random_unstructured(m, name="weight", amount=0.2) >>> print(prune.is_pruned(m)) True """ diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 0b5e85b87cf48b..3b676ade8ffbac 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -105,8 +105,7 @@ def to( dtype: torch.dtype, non_blocking: bool = ..., copy: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def to( @@ -115,8 +114,7 @@ def to( dtype: Optional[torch.dtype] = ..., non_blocking: bool = ..., copy: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def to( @@ -124,8 +122,7 @@ def to( other: Tensor, non_blocking: bool = ..., copy: bool = ..., - ) -> Self: - ... + ) -> Self: ... def to(self, *args: Any, **kwargs: Any) -> Self: r"""Perform dtype and/or device conversion on `self.data`. @@ -354,7 +351,9 @@ def pad_packed_sequence( >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens = [2, 1, 3] - >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False) + >>> packed = pack_padded_sequence( + ... seq, lens, batch_first=True, enforce_sorted=False + ... ) >>> packed PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) @@ -473,7 +472,10 @@ def pad_sequence( # assuming trailing dimensions and type of all the Tensors # in sequences are same and fetching those from sequences[0] return torch._C._nn.pad_sequence( - sequences, batch_first, padding_value, padding_side # type: ignore[arg-type] + sequences, # type: ignore[arg-type] + batch_first, + padding_value, + padding_side, # type: ignore[arg-type] ) diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index 3474a127a0b494..a1eeb87c24ab8c 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs """Spectral Normalization from https://arxiv.org/abs/1802.05957.""" + from typing import Any, Optional, TypeVar import torch diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 0eb51d4df13269..7b336e8b8c08e5 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" + from typing import Any, TypeVar from typing_extensions import deprecated diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 52ebc41c96e617..302159a2557690 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -9,7 +9,6 @@ "symbolic_helper", "utils", # All opsets - "symbolic_caffe2", "symbolic_opset7", "symbolic_opset8", "symbolic_opset9", @@ -74,7 +73,6 @@ from . import ( # usort: skip. Keep the order instead of sorting lexicographically errors, ops, - symbolic_caffe2, symbolic_helper, symbolic_opset7, symbolic_opset8, @@ -364,6 +362,16 @@ def forward(self, x): if isinstance(args, torch.Tensor): args = (args,) + # Prepare legacy export parameters for potential fallback + legacy_export_kwargs = { + "training": training, + "operator_export_type": operator_export_type, + "do_constant_folding": do_constant_folding, + "custom_opsets": custom_opsets, + "export_modules_as_functions": export_modules_as_functions, + "autograd_inlining": autograd_inlining, + } + return _compat.export_compat( model, args, @@ -386,6 +394,7 @@ def forward(self, x): dump_exported_program=dump_exported_program, artifacts_dir=artifacts_dir, fallback=fallback, + legacy_export_kwargs=legacy_export_kwargs, ) else: import warnings @@ -393,7 +402,7 @@ def forward(self, x): from torch.onnx.utils import export warnings.warn( - "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.8, " + "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, " "the new torch.export-based ONNX exporter will be the default. To switch now, set " "dynamo=True in torch.onnx.export. This new exporter supports features like exporting " "LLMs with DynamicCache. We encourage you to try it and share feedback to help improve " diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py index 6c91b245ed703f..b3c386b701d925 100644 --- a/torch/onnx/_constants.py +++ b/torch/onnx/_constants.py @@ -4,10 +4,9 @@ ONNX_BASE_OPSET = 9 ONNX_MIN_OPSET = 7 -ONNX_MAX_OPSET = 20 +ONNX_MAX_OPSET = 23 ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 -# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py -ONNX_DEFAULT_OPSET = 17 +ONNX_DEFAULT_OPSET = 18 ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py index f3dd273386f8f0..55d0550324e731 100644 --- a/torch/onnx/_globals.py +++ b/torch/onnx/_globals.py @@ -54,11 +54,6 @@ def export_onnx_opset_version(self) -> int: @export_onnx_opset_version.setter def export_onnx_opset_version(self, value: int): - supported_versions = range( - _constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1 - ) - if value not in supported_versions: - raise ValueError(f"Unsupported ONNX opset version: {value}") self._export_onnx_opset_version = value @property diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 59aa0e49875724..3557ef099309e9 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -28,13 +28,14 @@ def __getattr__(self, attr: str) -> object: # NOTE: Add additional used imports here. if TYPE_CHECKING: import onnx + import onnx_ir # type: ignore[import-untyped] import onnxscript - import onnxscript._framework_apis.torch_2_7 as onnxscript_apis + import onnxscript._framework_apis.torch_2_8 as onnxscript_apis - onnxscript_ir = onnxscript.ir + onnxscript_ir = onnx_ir else: onnx = _LazyModule("onnx") onnxscript = _LazyModule("onnxscript") - onnxscript_ir = _LazyModule("onnxscript.ir") - onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_7") + onnxscript_ir = _LazyModule("onnx_ir") + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_8") diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 2eb5adab3a7d9d..c01058f0200687 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -5,6 +5,7 @@ from __future__ import annotations import dataclasses +import operator import textwrap import traceback from collections import defaultdict @@ -99,7 +100,9 @@ def _format_model_info(model_info: ModelInfo) -> str: lines.append("\n") lines.append("Of the call_function nodes, the counts of operators used are:\n") sorted_targets = sorted( - model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True + model_info.fx_node_target_count.items(), + key=operator.itemgetter(1), + reverse=True, ) for target, count in sorted_targets: lines.append(f"- `{target}`: {count}") @@ -127,7 +130,7 @@ def _format_model_info(model_info: ModelInfo) -> str: target_to_messages[str(node.target)] = message for target, nodes in sorted( - target_to_nodes.items(), key=lambda x: x[0], reverse=True + target_to_nodes.items(), key=operator.itemgetter(0), reverse=True ): message = textwrap.indent( f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. All nodes: `{nodes}`", diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 8ddb5c2fac4b50..4774855e874ee2 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -152,7 +152,11 @@ class TorchExportStrictStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - with _patch_dynamo_unsupported_functions(): + with ( + _patch_dynamo_unsupported_functions(), + # Support the dynamism with 0/1 input dim + torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined] + ): try: return torch.export.export( model, @@ -198,22 +202,30 @@ class TorchExportNonStrictStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - try: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False - ) - except torch._dynamo.exc.UserError as exc: - # Refine the dynamic shapes based on the suggested fixes. + with ( + # Support the dynamism with 0/1 input dim + torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined] + ): try: - new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( - exc.msg, dynamic_shapes + return torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False ) - except Exception: - # If the dynamic shapes cannot be refined, re-raise the exception. - raise exc from None - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False - ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index b570b20bd02c4d..c3a0f26b227d3d 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -66,6 +66,8 @@ def export_compat( dump_exported_program: bool = False, artifacts_dir: str | os.PathLike = ".", fallback: bool = False, + # Legacy export parameters for fallback + legacy_export_kwargs: dict[str, Any] | None = None, ) -> _onnx_program.ONNXProgram: if opset_version is None: opset_version = _constants.TORCHLIB_OPSET @@ -151,6 +153,10 @@ def export_compat( dynamic_axes = _dynamic_shapes.from_dynamic_shapes_to_dynamic_axes( dynamic_shapes=dynamic_shapes, input_names=input_names, exception=e ) + # Use the legacy export kwargs prepared in __init__.py + if legacy_export_kwargs is None: + legacy_export_kwargs = {} + torch.onnx.utils.export( model, # type: ignore[arg-type] args, @@ -159,9 +165,10 @@ def export_compat( export_params=export_params, input_names=input_names, output_names=output_names, - opset_version=17, # TODO(justinchuby): Hard coded to 17 for now + opset_version=opset_version, dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, + **legacy_export_kwargs, ) onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index c6a84c448587e2..a4e3eea2e1d28e 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -477,6 +477,20 @@ def _get_onnxscript_opset(opset_version: int) -> onnxscript.values.Opset: return onnxscript.values.Opset("", opset_version) +def _is_onnx_op(op: Any) -> bool: + """Whether the op overload is an ONNX custom op implemented with PyTorch.""" + if not isinstance(op, torch._ops.OpOverload): + return False + return op.name().startswith("onnx::") + + +def _parse_onnx_op(op: torch._ops.OpOverload) -> tuple[str, int]: + """Parse the ONNX custom op overload name to get the op type and opset version.""" + name = op.name()[len("onnx::") :] + name, _, opset = name.partition(".opset") + return name, int(opset) + + def _handle_call_function_node_with_lowering( model: ir.Model, node: torch.fx.Node, @@ -512,17 +526,6 @@ def _handle_call_function_node_with_lowering( # use SequenceAt to get the value. This is handled by torchlib pass - # Find the matching ONNX overload for the node - # NOTE: Create different registries for different ONNX opset versions - # TODO: Log the message here to expose false positives - onnx_function, message = _dispatching.dispatch(node, registry) - - if onnx_function is None: - # TODO(justinchuby): Fall back to ATen op or do something else? - raise _errors.DispatchError( - f"No ONNX function found for {node.target!r}. Failure message: {message}" - ) - # Map FX inputs to ONNX inputs and fill optional inputs. # torch_args and torch_kwargs are for op-level validation fx_args = node.args @@ -546,19 +549,68 @@ def _handle_call_function_node_with_lowering( # TODO(justinchuby): Maybe keep it as None? onnx_kwargs[key] = -1 - with onnxscript.evaluator.default_as( - tracer := _building.OpRecorder(opset, constant_farm) - ): - global current_tracer - current_tracer = tracer - try: - outputs = onnx_function(*onnx_args, **onnx_kwargs) - except Exception as e: - raise _errors.GraphConstructionError( - f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" - ) from e - finally: - current_tracer = None + if _is_onnx_op(node.target): + # Handle torch.ops.onnx.* ops. These ops can be directly added to the graph + op_type, opset_version = _parse_onnx_op(node.target) # type: ignore[arg-type] + # If final inputs are None, strip them from the node inputs + for input_ in reversed(onnx_args): + if input_ is not None: + break + onnx_args.pop() + onnx_node = ir.Node( + "", + op_type, + onnx_args, + ir.convenience.convert_attributes(onnx_kwargs), + name=node.name, + num_outputs=len(node.target._schema.returns), # type: ignore[union-attr] + version=opset_version, + ) + # Store the single node in a list to be consistent with the rest of the code for further processing + onnx_nodes = [onnx_node] + if len(onnx_node.outputs) == 1: + outputs = onnx_node.outputs[0] + else: + outputs = onnx_node.outputs # type: ignore[assignment] + else: + # Find the matching ONNX overload for the node + # TODO: Log the message here to expose false positives + onnx_function, message = _dispatching.dispatch(node, registry) + + if onnx_function is None: + raise _errors.DispatchError( + f"No ONNX function found for {node.target!r}. Failure message: {message}" + ) + + with onnxscript.evaluator.default_as( + tracer := _building.OpRecorder(opset, constant_farm) + ): + global current_tracer + current_tracer = tracer + try: + outputs = onnx_function(*onnx_args, **onnx_kwargs) + except Exception as e: + raise _errors.GraphConstructionError( + f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" + ) from e + finally: + current_tracer = None + + # Add the defined functions to the model + for identifier, onnxscript_function in tracer.functions.items(): + if identifier in model.functions: + continue + if isinstance(onnxscript_function, ir.Function): + ir_function = onnxscript_function + else: + # TODO: Get IR function directly when onnxscript is updated + proto = onnxscript_function.to_function_proto() + ir_function = ir.serde.deserialize_function(proto) + model.functions[identifier] = ir_function + # Opset imports are added to the model in the final add_opset_imports pass + + onnx_nodes = tracer.nodes + del tracer # tracer is no longer needed # NOTE: Instead of using the output names from node.target._schema, # we always use the index if there are more than one outputs so the @@ -572,31 +624,26 @@ def _handle_call_function_node_with_lowering( node_name_to_values[node.name] = outputs for i, output in enumerate(outputs): output.name = f"{node.name}__{i}" + # Set the name of the producing node using the value name for correspondence + producer = output.producer() + if producer is not None: + producer.name = f"node_{output.name}" else: _set_shape_type(outputs, node.meta["val"], complex_to_float=True) node_name_to_values[node.name] = outputs outputs.name = node.name + producer = outputs.producer() + if producer is not None: + producer.name = f"node_{outputs.name}" - for ir_node in tracer.nodes: + for ir_node in onnx_nodes: ir_node.meta["node"] = node # Record the nn.Module stack for the node _set_node_metadata(node, ir_node) # Add the traced nodes to the current graph # Must add nodes to this graph, not model.graph, because it can be a subgraph that is currently being constructed - graph_like.extend(tracer.nodes) - # Add the defined functions to the model - for identifier, onnxscript_function in tracer.functions.items(): - if identifier in model.functions: - continue - if isinstance(onnxscript_function, ir.Function): - ir_function = onnxscript_function - else: - # TODO: Get IR function directly when onnxscript is updated - proto = onnxscript_function.to_function_proto() - ir_function = ir.serde.deserialize_function(proto) - model.functions[identifier] = ir_function - # Opset imports are added to the model in the final add_opset_imports pass + graph_like.extend(onnx_nodes) def _handle_placeholder_node( @@ -901,7 +948,7 @@ def exported_program_to_ir( Args: exported_program: The exported program to convert. lower: Whether to lower the graph to core ONNX operators. - at_conversion: Lower whe translating the FX graph to ONNX IR. + at_conversion: Lower when translating the FX graph to ONNX IR. none: Do not lower the graph. registry: The registry of all ONNX Script decomposition. """ @@ -979,7 +1026,7 @@ def _exported_program_to_onnx_program( exported_program: The exported program to convert. The exported program should be the one that is after decompositions have been applied. lower: Whether to lower the graph to core ONNX operators. - at_conversion: Lower whe translating the FX graph to ONNX IR. + at_conversion: Lower when translating the FX graph to ONNX IR. none: Do not lower the graph. registry: The registry of all ONNX Script decomposition. """ diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/hop.py b/torch/onnx/_internal/exporter/_torchlib/ops/hop.py index 986d7ef16b50b5..6e226ac2fee8ee 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/hop.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/hop.py @@ -19,7 +19,7 @@ def call_op( *args: ir.Value, _num_outputs: int = 1, _domain: str = "", - **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol, + **kwargs: int | float | str | bool | ir.Graph | ir.TensorProtocol | Sequence[int], ) -> Sequence[ir.Value]: """Call an operator with the given arguments and keyword arguments. @@ -92,3 +92,66 @@ def higher_order_cond( (), else_node.outputs, nodes=[else_node], name=false_func.name ), ) + + +@onnx_impl(torch.ops.higher_order.scan, no_compile=True) +def higher_order_scan( + body_func: ir.Function, + scan_inits: Sequence[ir.Value], + scan_inputs: Sequence[ir.Value], + additional_inputs: Sequence[ir.Value] | None, + reverse: bool = False, +) -> Sequence[ir.Value]: + """https://github.com/pytorch/pytorch/blob/66ac724b56e6c37a534f3e066423ef2f41d7477f/torch/_higher_order_ops/scan.py#L109""" + subgraph_inputs = [ + *[ + ir.Value( + name=f"{inp.name}_{body_func.name}__subgraph_in", + shape=inp.shape, + type=ir.TensorType(inp.dtype), # type: ignore[arg-type] + ) + for inp in scan_inits + ], + *[ + ir.Value( + name=f"{inp.name}_{body_func.name}__subgraph_in", + # The iterated element passed to the body subgraph does not have a sequence axis. + # It will have a rank one less than the rank of the corresponding scan_input. + shape=ir.Shape(inp.shape[1:]), # type: ignore[index] + type=ir.TensorType(inp.dtype), # type: ignore[arg-type] + ) + for inp in scan_inputs + ], + ] + # The one and only node in the Scan subgraph that calls the body_func + body_node = ir.Node( + body_func.domain, + body_func.name, + [ + *subgraph_inputs, + *(additional_inputs or []), + ], + num_outputs=len(body_func.outputs), + ) + + # ONNX Runtime complains about duplicate output names if we don't rename them. + # But the doesn't seem to be an actual violation of SSA form without renaming. + for func_out, out in zip(body_func.outputs, body_node.outputs): + out.name = f"{func_out.name}_{body_func.name}" + + n_outputs = len(body_func.outputs) - len(scan_inits) + return call_op( + "Scan", + *scan_inits, + *scan_inputs, + _num_outputs=len(body_func.outputs), + body=ir.Graph( + subgraph_inputs, + body_node.outputs, + nodes=[body_node], + name=body_func.name, + ), + num_scan_inputs=len(scan_inputs), + scan_input_directions=[(1 if reverse else 0) for _ in scan_inputs], + scan_output_directions=[(1 if reverse else 0) for _ in range(n_outputs)], + ) diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 9ac2aca9c315fc..11d4af57c62347 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -5,9 +5,13 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, TYPE_CHECKING -from onnxscript.onnx_opset import opset20 as op20, opset21 as op21 +from onnxscript.onnx_opset import ( # type: ignore[attr-defined] + opset20 as op20, + opset21 as op21, + opset23 as op23, +) import torch from torch.onnx._internal._lazy_import import onnxscript_ir as ir @@ -15,8 +19,14 @@ from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl +if TYPE_CHECKING: + from onnxscript.values import Opset + aten = torch.ops.aten +_INT64_MAX = 9223372036854775807 +_INT64_MIN = -9223372036854775808 + @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) def aten_gelu_opset20( @@ -46,3 +56,230 @@ def aten_group_norm( return op21.GroupNormalization( input, weight, bias, epsilon=eps, num_groups=num_groups ) + + +@onnx_impl( + aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23 +) +def aten_scaled_dot_product_attention_23( + query: TFloat, + key: TFloat, + value: TFloat, + attn_mask: Optional[TFloat] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> TFloat: + """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor + + Reference: + 1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + 2. https://onnx.ai/onnx/operators/onnx__Attention.html + + Attempts to convert SDPA to Attention onnx op and fallbacks to an onnx graph equivalent to the following PyTorch code:: + scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale + attn_mask = ( + torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + if is_causal + else attn_mask + ) + attn_mask = ( + attn_mask.masked_fill(not attn_mask, -float("inf")) + if attn_mask.dtype == torch.bool + else attn_mask + ) + attn_weight = torch.softmax( + (Q @ K.transpose(-2, -1) * attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ V + + where Q, K, V are the query, key, and value tensors, respectively. + L is the target sequence length, S is the source sequence length, and E is the embedding size. + """ + assert (not is_causal) or (is_causal and attn_mask is None), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" + ) + + # Attention onnx op can only handle non-training scenarios where dropout is disabled. + if dropout_p == 0: + if enable_gqa: + assert ( + query.shape[1] > key.shape[1] == value.shape[1] + and query.shape[1] % key.shape[1] == 0 + ), ( + "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" + ) + else: + assert query.shape[1] == key.shape[1] == value.shape[1], ( + "SDPA (MHA) requires q_num_heads = kv_num_heads" + ) + + # NOTE: num_heads attributes (q_num_heads/kv_num_heads) should not be specified for 4D. + # They are not populated with 4D inputs because this information directy comes from input shapes: + # `q_num_heads=query.shape[1]` and `kv_num_heads=key.shape[1]`. + # This dimension is usually static but it could not be dynamic if also given as an attribute. + # num_heads attributes are needed for 3D attention inputs: + # (shape: [B, S, N*H]), 4D shape is ([B, N, S, H]). + + Y, _, _, _ = op23.Attention( + query, + key, + value, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + ) + return Y + + if scale is None: + scale = _attention_scale(query, op23) + scale = op23.CastLike(scale, query) + + if is_causal: + attn_mask = _causal_attention_mask(query, key, op23) + + if attn_mask is None: + return _aten_scaled_dot_product_attention_no_mask_onnx( + query, key, value, scale, dropout_p, op23 + ) + + return _aten_scaled_dot_product_attention_float_mask_onnx( + query, key, value, attn_mask, scale, dropout_p, op23 + ) + + +def _attention_scale(query: TFloat, op: Opset) -> TFloat: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + q_shape = op.Shape(query) + q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1])) + embedding_size = op.CastLike(q_last_dim, query) + one = op.Constant(value_float=1.0) + cast_one = op.CastLike(one, query) + scale = op.Div(cast_one, op.Sqrt(embedding_size)) + return scale + + +def _causal_attention_mask(query: TFloat, key: TFloat, op: Opset) -> TFloat: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + q_shape = op.Shape(query) + k_shape = op.Shape(key) + + target_length = op.Slice( + q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) + source_length = op.Slice( + k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) + # attn_mask = torch.ones(L, S) := { + size = op.Concat(target_length, source_length, axis=0) + attn_mask = op.Expand(op.Constant(value_float=1.0), size) + # } + attn_mask = op.Trilu(attn_mask, upper=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. + attn_mask = op.Where( + op.Equal(attn_mask, op.Constant(value_float=0.0)), + op.Constant(value_float=-float("inf")), + op.Constant(value_float=0.0), + ) + attn_mask = op.CastLike(attn_mask, query) + return attn_mask + + +def _aten_scaled_dot_product_attention_no_mask_onnx( + query: TFloat, + key: TFloat, + value: TFloat, + scale: TFloat, + dropout_p: float, + op: Opset, +) -> TFloat: + # Swap the last two axes of key + key_last_dim = op.Shape(key, start=-1) + key_second_last_dim = op.Shape(key, start=-2, end=-1) + key_first_dims = op.Shape(key, end=-2) + # Contract the dimensions that are not the last two so we can transpose + # with a static permutation. + key_squeezed_shape = op.Concat( + op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0 + ) + key_squeezed = op.Reshape(key, key_squeezed_shape) + key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1]) + key_transposed_shape = op.Concat( + key_first_dims, key_last_dim, key_second_last_dim, axis=0 + ) + key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = op.Mul(query, op.Sqrt(scale)) + key_transposed_scaled = op.Mul( + key_transposed, op.CastLike(op.Sqrt(scale), key_transposed) + ) + attn_weight = op.Softmax( + op.MatMul(query_scaled, key_transposed_scaled), + axis=-1, + ) + attn_weight, _ = op.Dropout(attn_weight, dropout_p) + return op.MatMul(attn_weight, value) + + +def _aten_scaled_dot_product_attention_float_mask_onnx( + query: TFloat, + key: TFloat, + value: TFloat, + attn_mask: TFloat, + scale: TFloat, + dropout_p: float, + op: Opset, +) -> TFloat: + # Swap the last two axes of key + key_last_dim = op.Shape(key, start=-1) + key_second_last_dim = op.Shape(key, start=-2, end=-1) + key_first_dims = op.Shape(key, end=-2) + # Contract the dimensions that are not the last two so we can transpose + # with a static permutation. + key_squeezed_shape = op.Concat( + op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0 + ) + key_squeezed = op.Reshape(key, key_squeezed_shape) + key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1]) + key_transposed_shape = op.Concat( + key_first_dims, key_last_dim, key_second_last_dim, axis=0 + ) + key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = op.Mul(query, op.Sqrt(scale)) + key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) + attn_weight = op.Softmax( + op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), + axis=-1, + ) + attn_weight, _ = op.Dropout(attn_weight, dropout_p) + return op.MatMul(attn_weight, value) diff --git a/torch/onnx/_internal/exporter/_type_casting.py b/torch/onnx/_internal/exporter/_type_casting.py index ac4538a4cfb745..7f2141fe577e64 100644 --- a/torch/onnx/_internal/exporter/_type_casting.py +++ b/torch/onnx/_internal/exporter/_type_casting.py @@ -25,8 +25,8 @@ def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]: https://github.com/pytorch/pytorch/issues/146414. the shell dtype is takes up 1 byte per element and semantically represents - two fp4 values packed into 1 byte. Semantically it represents (*tensor.shape, 2) + two fp4 values packed into 1 byte. Semantically it represents (*tensor.shape[:-1], tensor.shape[-1]*2) fp4 elements. """ assert tensor.dtype == torch.float4_e2m1fn_x2 - return (*tensor.shape, 2) + return (*tensor.shape[:-1], tensor.shape[-1] * 2) diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py index f8a52efda2c621..6c9724e9f5a73e 100644 --- a/torch/onnx/_internal/fx/patcher.py +++ b/torch/onnx/_internal/fx/patcher.py @@ -11,7 +11,7 @@ # TODO: Remove after https://github.com/huggingface/safetensors/pull/318 -@functools.lru_cache(None) +@functools.cache def has_safetensors_and_transformers(): try: # safetensors is not an exporter requirement, but needed for some huggingface models diff --git a/torch/onnx/_internal/fx/registration.py b/torch/onnx/_internal/fx/registration.py index e855f98f044f6a..ec6fc638e3f2aa 100644 --- a/torch/onnx/_internal/fx/registration.py +++ b/torch/onnx/_internal/fx/registration.py @@ -70,7 +70,7 @@ def from_builtin_function( ) -> OpName: """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName. - FX graph uses built-in functions to caculate sympy expression. This function + FX graph uses built-in functions to calculate sympy expression. This function is used to get the OpName from a builtin function. Args: diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index 4a6e508c1a38c7..968f69328011d7 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -4,7 +4,8 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, Optional, Protocol, runtime_checkable, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import Protocol, runtime_checkable import onnx diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 1ed96e62ffad0b..6c414e8d54e788 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Any, Callable, Protocol, runtime_checkable, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Protocol, runtime_checkable import torch import torch.export as torch_export diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py index 5db66f6c83a4ea..f3f82c0db7c203 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/jit_utils.py @@ -355,7 +355,7 @@ def parse_node_kind(kind: str) -> tuple[str, str]: raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.") domain, opname = kind.split("::", 1) if "::" in opname: - raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.") + raise ValueError(f"Node kind: {kind} is invalid. '::' should only appear once.") return domain, opname diff --git a/torch/onnx/ops/__init__.py b/torch/onnx/ops/__init__.py index 3bbd3a64327ba0..c0c87d5ccaad75 100644 --- a/torch/onnx/ops/__init__.py +++ b/torch/onnx/ops/__init__.py @@ -4,12 +4,23 @@ which are exportable to ONNX. """ +# flake8: noqa: B950 from __future__ import annotations -from typing import TYPE_CHECKING + +__all__ = [ + "aten_decompositions", + "symbolic", + "symbolic_multi_out", + "rotary_embedding", + "attention", +] + + +from typing import Callable, TYPE_CHECKING import torch -from torch.onnx.ops import _symbolic_impl +from torch.onnx.ops import _impl, _symbolic_impl if TYPE_CHECKING: @@ -44,14 +55,19 @@ } +def aten_decompositions() -> dict[torch._ops.OpOverload, Callable]: + """Return the ONNX to ATen decomp table.""" + return _impl.ONNX_ATEN_DECOMP_TABLE + + def _parse_domain_op_type(domain_op: str) -> tuple[str, str]: - splitted = domain_op.split("::", 1) - if len(splitted) == 1: + split = domain_op.split("::", 1) + if len(split) == 1: domain = "" - op_type = splitted[0] + op_type = split[0] else: - domain = splitted[0] - op_type = splitted[1] + domain = split[0] + op_type = split[1] return domain, op_type @@ -82,6 +98,9 @@ def symbolic( This function is used to create a symbolic operator with a single output. To create an operator with multiple outputs, use :func:`symbolic_multi_out`. + You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the + symbolic logic only during ``torch.onnx.export()``. + Example:: class CustomOp(torch.nn.Module): @@ -177,6 +196,9 @@ def symbolic_multi_out( ) -> Sequence[torch.Tensor]: """Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs. + You may use ``if torch.onnx.is_in_onnx_export()`` to conditionally enable the + symbolic logic only during ``torch.onnx.export()``. + Example:: class CustomOp(torch.nn.Module): @@ -256,3 +278,190 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: domain=domain, version=version, ) + + +def rotary_embedding( + X: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: torch.Tensor | None = None, + *, + interleaved: bool = False, + num_heads: int = 0, + rotary_embedding_dim: int = 0, +) -> torch.Tensor: + """RotaryEmbedding op in ONNX. + + https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864. + The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances + between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids). + + The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles. + For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the + embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector. + The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated + to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism. + The rotation ensures that the model captures both absolute and relative positional information. + + Args: + X: The input tensor representing the token embeddings. 4D tensor with + shape `(batch_size, num_heads, sequence_length, head_size)` or 3D tensor + with shape `(batch_size, sequence_length, hidden_size)`. For cases with + a 4D input tensor, `head_size` has to be even. For cases with a 3D input + tensor, `num_heads` attribute must be provided and `hidden_size` must + be an even multiple of `num_heads` where `hidden_size = num_heads * head_size` + cos_cache: The cosine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)` + for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)` + for partial rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are not provided. `max_position_id_plus_1` + is a parameter to the model. + sin_cache: The sine values for the rotation. 2D tensor with shape `(max_position_id_plus_1, head_size / 2)` + for full rotation or `(max_position_id_plus_1, rotary_embedding_dim / 2)` + for partial rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial rotation + when `position_ids` are not provided. `max_position_id_plus_1` is a parameter + to the model. + position_ids: The position indices for the tokens. 2D tensor with shape + `(batch_size, sequence_length)`. + interleaved: Rotate using interleaved pattern. Default value is 0 (False). + num_heads: Number of attention heads. Must be provided when input is a 3D tensor. + rotary_embedding_dim: Rotary embedding dimension used to apply partial rotary embeddings. + + Returns: + Tensor with same shape as input. + """ + return _impl.rotary_embedding_23( + X, + cos_cache, + sin_cache, + position_ids=position_ids, + interleaved=interleaved, + num_heads=num_heads, + rotary_embedding_dim=rotary_embedding_dim, + ) + + +def attention( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + attn_mask: torch.Tensor | None = None, + past_key: torch.Tensor | None = None, + past_value: torch.Tensor | None = None, + *, + is_causal: bool = False, + kv_num_heads: int = 0, + q_num_heads: int = 0, + qk_matmul_output_mode: int = 0, + scale: float | None = None, + softcap: float = 0.0, + softmax_precision: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Attention op in ONNX. + + https://onnx.ai/onnx/operators/onnx__Attention.html + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed. + + This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V. + + For self attention, ``kv_sequence_length`` equals to ``q_sequence_length``. + + For cross attention, query and key might have different lengths. + + This operator also covers the 3 following variants based on the number of heads: + + 1. Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`. + 2. Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`. + 3. Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`. + + Attention bias to be added is calculated based on ``attn_mask`` input and ``is_causal` `attribute``, only one of which can be provided. + + 1. If ``is_causal`` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment. + 2. `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score. + + Both past and present state key/values are optional. They shall be used together, and not allowed to use only one of them. + The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided:: + + The following pattern is applied by this operator: + Q K V + | | | + Q*sqrt(scale) K*sqrt(scale) | + | | | + | Transpose | + | | | + ---MatMul--- | + | | + at_mask---Add | + | | + softcap (if provided) | + | | + Softmax | + | | + -----MatMul------ + | + Y + + Args: + Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor + with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor, + `q_hidden_size = q_num_heads * head_size` + K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor + with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor, + `k_hidden_size = kv_num_heads * head_size` + V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor + with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor, + `v_hidden_size = kv_num_heads * v_head_size` + attn_mask: Attention mask. Shape must be broadcastable to 4D tensor with shape + `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element should take part in attention. + Also supports a float mask of the same type as query, key, value that is added to the attention score. + past_key: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)` + past_value: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)` + is_causal: If set to True, the attention masking is a lower triangular matrix when the mask is a square matrix. + The attention masking has the form of the upper left causal bias due to the alignment. + kv_num_heads: Number of heads of key and value. Must be used with 3D inputs of Q, K and V. + q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K and V. + qk_matmul_output_mode: If set to 0, qk_matmul_output is the output of qk matmul. If set to 1, + qk_matmul_output includes the addition of the attention mask to the output of qk matmul. + If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3, + qk_matmul_output is the output after the softmax operation. Default value is 0. + scale: Scaling factor applied to Q*K^T. Default value is 1/sqrt(head_size). To prevent numerical overflow, + scale Q, K by sqrt(scale) before matmul. + softcap: Softcap value for attention weights. Default value is 0. + softmax_precision: The floating-point precision used in softmax computation. If softmax precision is not provided, + the same precision as the input of softmax (Q and K) is used. + + Returns: + A tuple containing: + - The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor + with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor, + `hidden_size = q_num_heads * v_head_size` + - Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. + - Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where + `total_sequence_length = past_sequence_length + kv_sequence_length`. + - The output of QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` + where `total_sequence_length = past_sequence_length + kv_sequence_length`. + """ + return _impl.attention_23( + Q, + K, + V, + attn_mask=attn_mask, + past_key=past_key, + past_value=past_value, + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + qk_matmul_output_mode=qk_matmul_output_mode, + scale=scale, + softcap=softcap, + softmax_precision=softmax_precision, + ) diff --git a/torch/onnx/ops/_dtype_mappings.py b/torch/onnx/ops/_dtype_mappings.py new file mode 100644 index 00000000000000..0023e356d89f1e --- /dev/null +++ b/torch/onnx/ops/_dtype_mappings.py @@ -0,0 +1,27 @@ +import torch + + +ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = { + 1: torch.float32, # FLOAT + 2: torch.uint8, # UINT8 + 3: torch.int8, # INT8 + 4: torch.uint16, # UINT16 + 5: torch.int16, # INT16 + 6: torch.int32, # INT32 + 7: torch.int64, # INT64 + 9: torch.bool, # BOOL + 10: torch.float16, # FLOAT16 + 11: torch.double, # DOUBLE + 12: torch.uint32, # UINT32 + 13: torch.uint64, # UINT64 + 14: torch.complex64, # COMPLEX64 + 15: torch.complex128, # COMPLEX128 + 16: torch.bfloat16, # BFLOAT16 + 17: torch.float8_e4m3fn, # FLOAT8E4M3FN + 18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ + 19: torch.float8_e5m2, # FLOAT8E5M2 + 20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ + 21: torch.uint8, # UINT4 + 22: torch.uint8, # INT4 + 23: torch.float4_e2m1fn_x2, # FLOAT4E2M1 +} diff --git a/torch/onnx/ops/_impl.py b/torch/onnx/ops/_impl.py new file mode 100644 index 00000000000000..7127716872f7b9 --- /dev/null +++ b/torch/onnx/ops/_impl.py @@ -0,0 +1,396 @@ +# flake8: noqa: B950 +import math +import typing +from typing import Callable, Optional + +import torch +from torch.onnx.ops import _dtype_mappings + + +_T = typing.TypeVar("_T", bound=Callable) + +# ONNX to ATen decomp table +ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {} +_ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset( + { + 1, # FLOAT + 10, # FLOAT16 + 11, # DOUBLE + 16, # BFLOAT16 + } +) + + +def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]: + """Decorator to register an ONNX operator with a custom implementation.""" + + def decorator(func: _T) -> _T: + overload = f"opset{opset_version}" + torch_op = torch.library.custom_op( + f"onnx::{op_type}.{overload}", mutates_args=() + )(func) + ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = ( + func # type: ignore[assignment] + ) + # Use the same implementation for the fake implementation + # This is possible because we use pure aten ops to implement ONNX ops + torch_op.register_fake(func) + return torch_op # type: ignore[return-value] + + return decorator + + +@_onnx_op("RotaryEmbedding", 23) +def rotary_embedding_23( + x: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + *, + interleaved: bool = False, + num_heads: int = 0, + rotary_embedding_dim: int = 0, +) -> torch.Tensor: + """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23""" + # First ensure x has shape [batch_size, num_heads, seq_len, head_size] + batch_size = x.shape[0] + sequence_length = x.shape[1] + if len(x.shape) == 3: + hidden_size = x.shape[2] + torch._check( + num_heads != 0, + lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}", + ) + head_size = hidden_size // num_heads + new_shape = [batch_size, sequence_length, num_heads, head_size] + x = torch.reshape(x, new_shape) + torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now") + head_size = x.shape[3] + + # Fully or partially perform rotation on x based on rotary_embedding_dim attribute + if rotary_embedding_dim == 0: + # If rotary_embedding_dim not provided, perform full rotation by using head_size + rotary_embedding_dim = head_size + x_rotate = x[:, :, :, :rotary_embedding_dim] + x_not_rotate = x[:, :, :, rotary_embedding_dim:] + rotary_embedding_dim_half = rotary_embedding_dim // 2 + + # Retrieve sin and cos caches using position ids + if position_ids is not None: + cos = cos_cache[ + position_ids + ] # Shape: [batch_size, sequence_length, head_size/2] + sin = sin_cache[ + position_ids + ] # Shape: [batch_size, sequence_length, head_size/2] + else: + cos = cos_cache + sin = sin_cache + cos = cos[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + sin = sin[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + cos = torch.unsqueeze( + cos, 2 + ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + sin = torch.unsqueeze( + sin, 2 + ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + + # Either divide the x in halves or interleave (based on interleaved attribute) + if interleaved: + x1 = x_rotate[:, :, :, 0::2] + x2 = x_rotate[:, :, :, 1::2] + else: + x1, x2 = torch.chunk(x_rotate, 2, dim=-1) + + # Calculate real and imaginary values + real = cos * x1 - sin * x2 + imag = sin * x1 + cos * x2 + + # Inserted rotated embeddings back to the original x + if interleaved: + # x_rotate[:, :, :, 0::2] = real + # x_rotate[:, :, :, 1::2] = imag + real = torch.unsqueeze(real, -1) + imag = torch.unsqueeze(imag, -1) + x_rotate_concat = torch.cat((real, imag), dim=-1) + x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape) + else: + x_rotate = torch.cat((real, imag), dim=-1) + output = torch.cat((x_rotate, x_not_rotate), dim=-1) + if len(x.shape) == 3: + output = torch.reshape(output, x.shape) + return output + + +def _get_scale_factor(scale: Optional[float], head_size: int) -> float: + """Get the scale factor for attention computation.""" + return scale if scale is not None else (1.0 / math.sqrt(head_size)) + + +def _reshape_3d_to_4d( + tensor: torch.Tensor, batch_size: int, num_heads: int +) -> torch.Tensor: + """Reshape 3D tensor to 4D for multi-head attention.""" + sequence_length, hidden_size = tensor.shape[1], tensor.shape[2] + head_size = hidden_size // num_heads + return ( + tensor.view(batch_size, sequence_length, num_heads, head_size) + .transpose(1, 2) + .contiguous() + ) + + +def _get_qk_output_for_aten_spda( + Q: torch.Tensor, + K: torch.Tensor, + current_q_num_heads: int, + current_kv_num_heads: int, + scale: Optional[float], + qk_matmul_output_mode: int, +) -> torch.Tensor: + """Get QK output tensor based on the specified mode.""" + if qk_matmul_output_mode == 0: + return _compute_qk_output_for_mode_0( + Q, K, current_q_num_heads, current_kv_num_heads, scale + ) + else: + # For other modes, return a zero tensor with correct shape + return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1))) + + +def _validate_gqa_configuration( + current_q_num_heads: int, current_kv_num_heads: int +) -> None: + """Validate Group Query Attention configuration.""" + torch._check( + current_q_num_heads % current_kv_num_heads == 0, + lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA", + ) + + +def _compute_qk_output_for_mode_0( + Q: torch.Tensor, + K: torch.Tensor, + current_q_num_heads: int, + current_kv_num_heads: int, + scale: Optional[float], +) -> torch.Tensor: + """Helper function to compute QK output for qk_matmul_output_mode == 0.""" + # Handle GQA manually for QK output + K_for_qk = K + if current_q_num_heads != current_kv_num_heads: + repeat_factor = current_q_num_heads // current_kv_num_heads + K_for_qk = K.repeat_interleave(repeat_factor, dim=1) + + scale_factor = _get_scale_factor(scale, Q.shape[3]) + # Scale both Q and K by sqrt(scale_factor) for numerical stability + sqrt_scale = math.sqrt(scale_factor) + Q_scaled = Q * sqrt_scale + K_scaled = K_for_qk * sqrt_scale + return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1)) + + +@_onnx_op("Attention", 23) +def attention_23( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + past_key: Optional[torch.Tensor] = None, + past_value: Optional[torch.Tensor] = None, + *, + is_causal: bool = False, + kv_num_heads: int = 0, + q_num_heads: int = 0, + qk_matmul_output_mode: int = 0, + scale: Optional[float] = None, + softcap: float = 0.0, + softmax_precision: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23""" + + num_head_dim, sequence_dim, head_dim = 1, 2, 3 + + # Store original input shape to determine output shape + input_shape_len = len(Q.shape) + batch_size = Q.shape[0] + + # Reshape 3D inputs to 4D format + if len(Q.shape) == 3: + torch._check( + q_num_heads != 0 and kv_num_heads != 0, + lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs", + ) + q_sequence_length = Q.shape[1] + Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads) + K = _reshape_3d_to_4d(K, batch_size, kv_num_heads) + V = _reshape_3d_to_4d(V, batch_size, kv_num_heads) + + torch._check( + len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4, + lambda: "Q, K, and V should be 4D tensors by now", + ) + + # Calculate scale factor if not provided + q_head_size = Q.shape[head_dim] + scale = _get_scale_factor(scale, q_head_size) + + # Handle past key/value caches + present_key = ( + torch.cat([past_key, K], dim=sequence_dim) + if past_key is not None + else K.clone() + ) + present_value = ( + torch.cat([past_value, V], dim=sequence_dim) + if past_value is not None + else V.clone() + ) + + # Update K and V to include past states + K, V = present_key, present_value + + # Get current dimensions + current_q_num_heads = Q.shape[num_head_dim] + current_kv_num_heads = K.shape[num_head_dim] + q_sequence_length = Q.shape[sequence_dim] + kv_sequence_length = K.shape[sequence_dim] + + # Check if we can use the optimized scaled_dot_product_attention (most optimized) + can_use_sdpa = ( + softcap == 0.0 # No softcap + and qk_matmul_output_mode == 0 # Default QK output mode + and softmax_precision is None # No custom softmax precision + and (attn_mask is None or attn_mask.dtype == torch.bool) + ) + + _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads) + + if can_use_sdpa: + # Use PyTorch's optimized scaled_dot_product_attention + + # Prepare attention mask for SDPA + sdpa_attn_mask = None + if attn_mask is not None: + # Convert boolean mask: True means participate, SDPA expects True to mask out + sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask + + output = torch.nn.functional.scaled_dot_product_attention( + Q, + K, + V, + attn_mask=sdpa_attn_mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scale, + enable_gqa=bool( + current_q_num_heads != current_kv_num_heads + ), # Ensure enable_gqa is not SymBool + ) + + qk_output = _get_qk_output_for_aten_spda( + Q, + K, + current_q_num_heads, + current_kv_num_heads, + scale, + qk_matmul_output_mode, + ) + else: + # Fallback to manual implementation for complex cases + + # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA) + if current_q_num_heads != current_kv_num_heads: + repeat_factor = current_q_num_heads // current_kv_num_heads + K = K.repeat_interleave(repeat_factor, dim=num_head_dim) + V = V.repeat_interleave(repeat_factor, dim=num_head_dim) + + # Create attention bias + attn_bias = torch.zeros( + q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device + ) + + # Apply causal masking + if is_causal: + torch._check( + attn_mask is None, lambda: "Cannot use both is_causal and attn_mask" + ) + causal_mask = torch.tril( + torch.ones( + q_sequence_length, + kv_sequence_length, + dtype=torch.bool, + device=Q.device, + ) + ) + attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf")) + + # Apply attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + # Boolean mask: True means participate in attention + attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf")) + else: + # Float mask: added to attention scores + attn_bias = attn_bias + attn_mask + + # Apply scaling factor + scale_factor = _get_scale_factor(scale, Q.shape[3]) + + # Scale both Q and K by sqrt(scale_factor) for numerical stability + sqrt_scale = math.sqrt(scale_factor) + Q_scaled = Q * sqrt_scale + K_scaled = K * sqrt_scale + + # Compute Q @ K^T + qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1)) + + # Initialize QK output based on mode + qk_output = qk_matmul_output # Default case for mode 0 + + # Add attention bias + qk_with_bias = qk_matmul_output + attn_bias + + if qk_matmul_output_mode == 1: + qk_output = qk_with_bias + + # Apply softcap if provided + if softcap > 0.0: + qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap) + + if qk_matmul_output_mode == 2: + qk_output = qk_with_bias + + # Apply softmax with optional precision casting + if softmax_precision is not None: + # Map ONNX data type to torch dtype + if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS: + original_dtype = qk_with_bias.dtype + qk_with_bias = qk_with_bias.to( + _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision] + ) + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + qk_softmax = qk_softmax.to(original_dtype) + else: + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + else: + qk_softmax = torch.softmax(qk_with_bias, dim=-1) + + if qk_matmul_output_mode == 3: + qk_output = qk_softmax + + # Compute attention output + output = torch.matmul(qk_softmax, V) + + # Reshape output back to 3D if input was 3D + if input_shape_len == 3: + # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size) + output = ( + output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1) + ) + + return output, present_key, present_value, qk_output diff --git a/torch/onnx/ops/_symbolic_impl.py b/torch/onnx/ops/_symbolic_impl.py index 5a5ede4d65d7a0..4876612ad978bf 100644 --- a/torch/onnx/ops/_symbolic_impl.py +++ b/torch/onnx/ops/_symbolic_impl.py @@ -11,38 +11,15 @@ or less the same thing but is required by the `torch.library.custom_op` interface. """ +# flake8: noqa: B950 import dataclasses from collections.abc import Sequence from typing import Optional, Union import torch +from torch.onnx.ops import _dtype_mappings -_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = { - 1: torch.float32, # FLOAT - 2: torch.uint8, # UINT8 - 3: torch.int8, # INT8 - 4: torch.uint16, # UINT16 - 5: torch.int16, # INT16 - 6: torch.int32, # INT32 - 7: torch.int64, # INT64 - 9: torch.bool, # BOOL - 10: torch.float16, # FLOAT16 - 11: torch.double, # DOUBLE - 12: torch.uint32, # UINT32 - 13: torch.uint64, # UINT64 - 14: torch.complex64, # COMPLEX64 - 15: torch.complex128, # COMPLEX128 - 16: torch.bfloat16, # BFLOAT16 - 17: torch.float8_e4m3fn, # FLOAT8E4M3FN - 18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ - 19: torch.float8_e5m2, # FLOAT8E5M2 - 20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ - 21: torch.uint8, # UINT4 - 22: torch.uint8, # INT4 - 23: torch.float4_e2m1fn_x2, # FLOAT4E2M1 -} - _INT_TYPE = "i" _FLOAT_TYPE = "f" _STRING_TYPE = "s" @@ -221,10 +198,12 @@ def _symbolic( version: Optional[int] = None, ) -> torch.Tensor: torch._check( - onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE, - lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + return torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] ) - return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]) @_symbolic.register_fake @@ -246,12 +225,14 @@ def _( version: Optional[int] = None, ) -> torch.Tensor: torch._check( - onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE, - lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", ) # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured # out how it can handle empty shapes - return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]) + return torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) @torch.library.custom_op( @@ -289,10 +270,14 @@ def _symbolic_multi_out( ) for shape, onnx_dtype in zip(shapes, onnx_dtypes): torch._check( - onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE, - lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + ) + outputs.append( + torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) ) - outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])) return outputs @@ -321,10 +306,14 @@ def _( ) for shape, onnx_dtype in zip(shapes, onnx_dtypes): torch._check( - onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE, - lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", + onnx_dtype in _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE, + lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE.keys())}", ) # NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured # out how it can handle empty shapes - outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])) + outputs.append( + torch.zeros( + shape, dtype=_dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype] + ) + ) return outputs diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py deleted file mode 100644 index 83a2ff6c32ec95..00000000000000 --- a/torch/onnx/symbolic_caffe2.py +++ /dev/null @@ -1,361 +0,0 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code=arg-type -import importlib -import inspect - -from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import jit_utils, registration - - -def register_quantized_ops(domain: str, version: int): - # Register all quantized ops - module = importlib.import_module("torch.onnx.symbolic_caffe2") - quant_version_ops = inspect.getmembers(module) - aten_q_ops = { - "relu", - "_empty_affine_quantized", - "dequantize", - "quantize_per_tensor", - "upsample_nearest2d", - "avg_pool2d", - "reshape", - "slice", - "cat", - "max_pool2d", - "sigmoid", - } - for op, func in quant_version_ops: - name = f"{domain}::{op}" - if inspect.isfunction(func) and not registration.registry.is_registered_op( - name, version - ): - if op in aten_q_ops: - # Override the builtin aten ops - registration.registry.register( - f"aten::{op}", version, func, custom=True - ) - registration.registry.register(name, version, func) - - -def _permute_helper(g: jit_utils.GraphContext, input, axes): - quant_args = { - "axes_i": axes, - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - output = g.op("_caffe2::Int8Transpose", input, **quant_args) - symbolic_helper._quantized_ops.add(output) - return output - - -def nchw2nhwc(g: jit_utils.GraphContext, input): - axes = [0, 2, 3, 1] - return _permute_helper(g, input, axes) - - -def nhwc2nchw(g: jit_utils.GraphContext, input): - axes = [0, 3, 1, 2] - return _permute_helper(g, input, axes) - - -def linear_prepack(g: jit_utils.GraphContext, weight, bias): - # Mapping to a dummy caffe2 prepack node. - # During the onnx -> c2 conversion we can look up original weight and bias - # from this node - output = g.op("_caffe2::WeightPrepack", weight, bias) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "v", "v", "f", "i") -def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point): - kwargs = { - "Y_scale_f": scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -def conv_prepack( - g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups -): - # Mapping to a dummy caffe2 prepack node. - # During the onnx -> c2 conversion we can look up original weight and bias - # from this node - output = g.op("_caffe2::WeightPrepack", input, weight, bias) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") -def conv2d( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - scale, - zero_point, -): - kernel_size = weight.node()["shape"][1:3] - kwargs = { - "strides_i": stride, - "pads_i": padding + padding, - "dilations_i": dilation, - "group_i": groups, - "kernels_i": kernel_size, - "order_s": "NHWC", - "Y_scale_f": scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") -def conv2d_relu( - g: jit_utils.GraphContext, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - scale, - zero_point, -): - kernel_size = weight.node()["shape"][1:3] - kwargs = { - "strides_i": stride, - "pads_i": padding + padding, - "dilations_i": dilation, - "group_i": groups, - "kernels_i": kernel_size, - "order_s": "NHWC", - "Y_scale_f": scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "v", "f", "i") -def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point): - kwargs = { - "Y_scale_f": scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v") -def relu(g: jit_utils.GraphContext, input): - if input not in symbolic_helper._quantized_ops: - return opset9.relu(g, input) - kwargs = { - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - output = g.op("_caffe2::Int8Relu", input, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "f", "i", "t") -def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): - kwargs = { - "Y_scale_f": scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8Quantize", input, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v") -def dequantize(g: jit_utils.GraphContext, input): - return g.op("_caffe2::Int8Dequantize", input) - - -@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t") -def _empty_affine_quantized( - g: jit_utils.GraphContext, - input, - shape, - scale, - zero_point, - dtype, - pin_memory, - memory_format, - layout, -): - return input - - -def upsample_nearest2d( - g: jit_utils.GraphContext, - input, - output_size, - align_corners=None, - scales_h=None, - scales_w=None, -): - if input not in symbolic_helper._quantized_ops: - return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined] - - output_size = symbolic_helper._parse_arg(output_size, "is") - kwargs = { - "output_size_i": output_size, - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - input = nchw2nhwc(g, input) - output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs) - output = nhwc2nchw(g, output) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") -def max_pool2d( - g: jit_utils.GraphContext, - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode, -): - if input not in symbolic_helper._quantized_ops: - return opset9.max_pool2d( # type: ignore[attr-defined] - g, input, kernel_size, stride, padding, dilation, ceil_mode - ) - kwargs = { - "strides_i": stride, - "pads_i": padding + padding, - "kernel_i": kernel_size[0], - "order_s": "NHWC", - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - input = nchw2nhwc(g, input) - output = g.op("_caffe2::Int8MaxPool", input, **kwargs) - output = nhwc2nchw(g, output) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") -def avg_pool2d( - g: jit_utils.GraphContext, - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override=None, -): - if input not in symbolic_helper._quantized_ops: - return opset9.avg_pool2d( # type: ignore[attr-defined] - g, - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override, - ) - kwargs = { - "strides_i": stride, - "pads_i": padding + padding, - "kernel_i": kernel_size[0], - "order_s": "NHWC", - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - input = nchw2nhwc(g, input) - output = g.op("_caffe2::Int8AveragePool", input, **kwargs) - output = nhwc2nchw(g, output) - symbolic_helper._quantized_ops.add(output) - return output - - -def reshape(g: jit_utils.GraphContext, input, shape): - if input not in symbolic_helper._quantized_ops: - return opset9.reshape(g, input, shape) - - kwargs = { - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v", "v", "v", "v", "i") -def slice(g: jit_utils.GraphContext, input, dim, start, end, step): - if input not in symbolic_helper._quantized_ops: - return opset9.slice(g, input, dim, start, end, step) - - if step != 1: - raise RuntimeError("ONNX quantized slice export only works for step 1.") - start = symbolic_helper._parse_arg(start, "i") - end = symbolic_helper._parse_arg(end, "i") - dim = symbolic_helper._parse_arg(dim, "i") - - kwargs = { - "start_idx_i": start, - "end_idx_i": end, - "dim_i": dim, - "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), - "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), - } - output = g.op("_caffe2::Int8Slice", input, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None): - tensors = symbolic_helper._unpack_list(tensor_list) - input = tensors[0] - if input not in symbolic_helper._quantized_ops: - return opset9.cat(g, tensor_list, dim) - - dim = symbolic_helper._parse_arg(dim, "i") - kwargs = { - "Y_scale_f": tensors[0].node()["Y_scale"], - "Y_zero_point_i": tensors[0].node()["Y_zero_point"], - } - output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output - - -@symbolic_helper.parse_args("v") -def sigmoid(g: jit_utils.GraphContext, input): - if input not in symbolic_helper._quantized_ops: - return opset9.sigmoid(g, input) - # Caffe2 expects the output scale to be 1/2^8 - # and output zero_point to be 0 (quint8 type) - out_scale = 1.0 / 256 - zero_point = 0 - kwargs = { - "Y_scale_f": out_scale, - "Y_zero_point_i": zero_point, - } - output = g.op("_caffe2::Int8Sigmoid", input, **kwargs) - symbolic_helper._quantized_ops.add(output) - return output diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index f609b4452bb08c..dc6312e5f7a325 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -740,7 +740,7 @@ def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): return g.op("Constant", value_t=torch.tensor(scalar)) -def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None): +def _sort_helper(g: jit_utils.GraphContext, input, dim, descending=True, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported") shape_ = g.op("Shape", input) @@ -750,12 +750,12 @@ def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), ) if g.opset <= 10: - if not decending: + if not descending: _unimplemented("Sort", "Ascending is not supported") return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) else: return g.op( - "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2 + "TopK", input, dim_size_, axis_i=dim, largest_i=descending, outputs=2 ) diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 2496b84b76070f..469d7a80f77dca 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -99,7 +99,7 @@ def _floor_divide(g: jit_utils.GraphContext, self, other): out = opset9.true_divide(g, self, other) return g.op("Floor", out) else: - # Integer division does trunction rounding + # Integer division does truncation rounding div = g.op("Div", self, other) # Division is negative if: self < 0 != other < 0 zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) @@ -116,8 +116,8 @@ def _floor_divide(g: jit_utils.GraphContext, self, other): @_onnx_symbolic("aten::sort") @symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): - return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): + return symbolic_helper._sort_helper(g, self, dim, descending=descending, out=out) @_onnx_symbolic("aten::topk") diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 4d6da7336c3298..47ed56bcfeac90 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -565,15 +565,15 @@ def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): @_onnx_symbolic("aten::sort") @symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): - return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): + return symbolic_helper._sort_helper(g, self, dim, descending=descending, out=out) @_onnx_symbolic("aten::argsort") @symbolic_helper.parse_args("v", "i", "i", "none") -def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None): +def argsort(g: jit_utils.GraphContext, self, dim, descending, out=None): _, indices = symbolic_helper._sort_helper( - g, self, dim, decending=decending, out=out + g, self, dim, descending=descending, out=out ) return indices diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index fa295418504f46..af56a875145978 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -472,7 +472,7 @@ def _floor_divide(g: jit_utils.GraphContext, self, other): out = true_divide(g, self, other) return g.op("Floor", out) else: - # Integer division does trunction rounding + # Integer division does truncation rounding div = g.op("Div", self, other) # Division is negative if: self < 0 != other < 0 zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) @@ -3855,7 +3855,7 @@ def unsqueeze(g: jit_utils.GraphContext, self, dim): @_onnx_symbolic("aten::sort") # TODO(justinchuby): Support multiple quantized args in output @symbolic_helper.parse_args("v", "i", "i", "none") -def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): +def sort(g: jit_utils.GraphContext, self, dim, descending, out=None): if out is not None: symbolic_helper._unimplemented( "Sort", "Out parameter is not supported for sort", self diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index cce74561010843..ec08090a595f6f 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -353,9 +353,9 @@ def export( Models exported this way are probably runnable only by Caffe2. - opset_version (int, default 17): The version of the + opset_version (int, default 18): The version of the `default (ai.onnx) opset `_ - to target. Must be >= 7 and <= 17. + to target. Must be >= 7. do_constant_folding: Apply the constant-folding optimization. Constant-folding will replace some of the ops that have all constant inputs with pre-computed constant nodes. @@ -1393,10 +1393,7 @@ def _export( if opset_version is None: opset_version = _constants.ONNX_DEFAULT_OPSET - # torch.onnx.export does not support opset versions >=18 if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: - # We do not want to fail because we should still allow users to create - # custom symbolic functions for opset>17 warnings.warn( f"Exporting to ONNX opset version {opset_version} is not supported. " f"by 'torch.onnx.export()'. " diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 73182202ac8b70..e8408a5e848bca 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -315,6 +315,9 @@ def step(self, closure=None): &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\ \end{aligned} + You may note that Noam Shazeer and Mitchell Stern describe using the sum of squared gradients, + while this implementation uses the mean instead. This choice is mathematically equivalent and + allows for greater numerical stability for large sums. .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost: https://arxiv.org/pdf/1804.04235 @@ -347,9 +350,9 @@ def _single_tensor_adafactor( maximize: bool, has_complex: bool, ): - assert ( - grad_scale is None and found_inf is None - ), "Grad scaling should occur outside of optimizer.step()" + assert grad_scale is None and found_inf is None, ( + "Grad scaling should occur outside of optimizer.step()" + ) if torch.jit.is_scripting(): # this assert is due to JIT being dumb and not realizing that the ops below @@ -381,9 +384,9 @@ def _single_tensor_adafactor( param.mul_(1 - lr * weight_decay) if grad.dim() > 1: - assert ( - row_var is not None and col_var is not None - ), "row_var and col_var should be defined when grad is multidimensional" + assert row_var is not None and col_var is not None, ( + "row_var and col_var should be defined when grad is multidimensional" + ) # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g row_mean = ( torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1)) @@ -397,9 +400,9 @@ def _single_tensor_adafactor( var_estimate = row_var @ col_var var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1)) else: - assert ( - variance is not None - ), "variance should be defined when grad is a vector" + assert variance is not None, ( + "variance should be defined when grad is a vector" + ) grad_squared = grad * grad variance.lerp_(grad_squared, one_minus_beta2_t) # avoid writing into variance during update @@ -472,9 +475,9 @@ def _multi_tensor_adafactor( if len(params) == 0: return - assert ( - grad_scale is None and found_inf is None - ), "Grad scaling should occur outside of optimizer.step()" + assert grad_scale is None and found_inf is None, ( + "Grad scaling should occur outside of optimizer.step()" + ) lr = _to_scalar(lr) @@ -495,9 +498,9 @@ def _multi_tensor_adafactor( device_grads = cast(list[Tensor], device_grads_) device_state_steps = cast(list[Tensor], device_state_steps_) if eps1 is None: - assert ( - dtype is not None - ), "dtype is needed to compute eps1 when eps1 is unset" + assert dtype is not None, ( + "dtype is needed to compute eps1 when eps1 is unset" + ) eps1 = torch.finfo(dtype).eps if TYPE_CHECKING: @@ -537,9 +540,9 @@ def _multi_tensor_adafactor( if is_multidim: device_row_vars = cast(list[Tensor], device_row_vars_) device_col_vars = cast(list[Tensor], device_col_vars_) - assert ( - device_row_vars[0] is not None and device_col_vars[0] is not None - ), "row_var and col_var should be defined when grad is multidimensional" + assert device_row_vars[0] is not None and device_col_vars[0] is not None, ( + "row_var and col_var should be defined when grad is multidimensional" + ) # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g row_means = [ torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads @@ -570,9 +573,9 @@ def _multi_tensor_adafactor( del row_var_means else: device_variances = cast(list[Tensor], device_variances_) - assert ( - device_variances[0] is not None - ), "variance should be defined when grad is a vector" + assert device_variances[0] is not None, ( + "variance should be defined when grad is a vector" + ) grads_squared = torch._foreach_mul(device_grads, device_grads) torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts) diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index f48311fb11d88b..9b2c76700b356c 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Functional interface.""" + import math from torch import Tensor diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py index 41a195713b927e..b6818e5a50f3b6 100644 --- a/torch/optim/_multi_tensor/__init__.py +++ b/torch/optim/_multi_tensor/__init__.py @@ -5,6 +5,7 @@ enough, so that more sophisticated ones can be also easily integrated in the future. """ + from functools import partialmethod from torch import optim diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 90fda8eb29359f..83795411358144 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -267,7 +267,9 @@ def _single_tensor_adadelta( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -326,7 +328,9 @@ def _multi_tensor_adadelta( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) if len(params) == 0: return @@ -368,7 +372,7 @@ def _multi_tensor_adadelta( device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] if weight_decay != 0: - # Re-use the intermediate memory (device_grads) already allocated for maximize + # Reuse the intermediate memory (device_grads) already allocated for maximize if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index e850a0d555fc80..a28895f00e0ff1 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -466,7 +466,7 @@ def _multi_tensor_adagrad( torch._foreach_add_(device_state_steps, 1) if weight_decay != 0: - # Re-use the intermediate memory (device_grads) already allocated for maximize + # Reuse the intermediate memory (device_grads) already allocated for maximize if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -484,7 +484,7 @@ def _multi_tensor_adagrad( torch._foreach_add_(std, eps) if weight_decay != 0 or maximize: - # Again, re-use the intermediate memory (device_grads) already allocated + # Again, reuse the intermediate memory (device_grads) already allocated torch._foreach_mul_(device_grads, minus_clr) numerator = device_grads else: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index a86cb340082fac..62ecf40ed47a96 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -398,7 +398,9 @@ def _single_tensor_adam( assert ( param.device.type == step_t.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) # update step step_t += 1 @@ -433,7 +435,9 @@ def _single_tensor_adam( # cast to workaround https://github.com/pytorch/pytorch/issues/140601 key = (device, dtype) if key not in beta1_dict: - beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr] + beta1_dict[key] = beta1.to( # type: ignore[union-attr] + device=device, dtype=dtype, non_blocking=True + ) device_beta1: Union[float, Tensor] = beta1_dict[key] else: @@ -455,9 +459,11 @@ def _single_tensor_adam( # expavg.lerp(grad^2, 1-beta2) exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2) else: - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=cast(float, 1 - beta2) + ) else: - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type] if capturable or differentiable: step = step_t @@ -528,7 +534,7 @@ def _single_tensor_adam( else: denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - param.addcdiv_(exp_avg, denom, value=-step_size) + param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type] # Lastly, switch back to complex view if amsgrad and torch.is_complex(params[i]): @@ -593,7 +599,9 @@ def _multi_tensor_adam( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) assert grad_scale is None and found_inf is None @@ -669,7 +677,7 @@ def _multi_tensor_adam( # Perform stepweight decay torch._foreach_mul_(device_params, 1 - lr * weight_decay) else: - # Re-use the intermediate memory (device_grads) already allocated for maximize + # Reuse the intermediate memory (device_grads) already allocated for maximize if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -680,7 +688,9 @@ def _multi_tensor_adam( # Decay the first and second moment running average coefficient # Use device beta1 if beta1 is a tensor to ensure all # tensors are on the same device - torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1) + torch._foreach_lerp_( + device_exp_avgs, device_grads, cast(float, 1 - device_beta1) + ) torch._foreach_mul_(device_exp_avg_sqs, beta2) @@ -769,7 +779,10 @@ def _multi_tensor_adam( torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) torch._foreach_add_(exp_avg_sq_sqrt, eps) torch._foreach_addcdiv_( - device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type] + device_params, + device_exp_avgs, + exp_avg_sq_sqrt, + step_size, # type: ignore[arg-type] ) diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index fcaee7e1789ac1..feca25f2bd0c0f 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -256,7 +256,9 @@ def _single_tensor_adamax( assert ( param.device.type == step_t.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) # update step step_t += 1 @@ -331,7 +333,9 @@ def _multi_tensor_adamax( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) lr = _to_scalar(lr) @@ -372,7 +376,7 @@ def _multi_tensor_adamax( if weight_decay != 0: if maximize: - # Re-use the intermediate memory (grouped_grads) already allocated for maximize + # Reuse the intermediate memory (grouped_grads) already allocated for maximize torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) else: grouped_grads = torch._foreach_add( # type: ignore[assignment] diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index eac7d1a9de9d66..739f8d9a8c5f5d 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -305,7 +305,9 @@ def _multi_tensor_asgd( p.device.type == mu.device.type == eta.device.type == step.device.type and p.device.type in capturable_supported_devices for p, mu, eta, step in zip(params, mus, etas, state_steps) - ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}." + ) lr = _to_scalar(lr) diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 2770e8f67a9f58..457e3f7637e789 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -299,7 +299,7 @@ def _directional_evaluate(self, closure, x, t, d): return loss, flat_grad @torch.no_grad() - def step(self, closure): + def step(self, closure): # type: ignore[override] """Perform a single optimization step. Args: diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 57ee9fd4a00ed4..6f9f6f1a3cf0c7 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Learning Rate Scheduler.""" + from __future__ import annotations import math @@ -280,7 +281,7 @@ class LambdaLR(LRScheduler): >>> # Assuming optimizer has two groups. >>> num_epochs = 100 >>> lambda1 = lambda epoch: epoch // 30 - >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> lambda2 = lambda epoch: 0.95**epoch >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) >>> for epoch in range(num_epochs): >>> train(...) @@ -548,7 +549,7 @@ class MultiStepLR(LRScheduler): >>> # lr = 0.05 if epoch < 30 >>> # lr = 0.005 if 30 <= epoch < 80 >>> # lr = 0.0005 if epoch >= 80 - >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) >>> for epoch in range(100): >>> train(...) >>> validate(...) @@ -827,7 +828,11 @@ class SequentialLR(LRScheduler): >>> # lr = 0.0405 if epoch == 22 >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20) >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9) - >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20]) + >>> scheduler = SequentialLR( + ... optimizer, + ... schedulers=[scheduler1, scheduler2], + ... milestones=[20], + ... ) >>> for epoch in range(100): >>> train(...) >>> validate(...) @@ -1271,11 +1276,11 @@ class ReduceLROnPlateau(LRScheduler): Example: >>> # xdoctest: +SKIP >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> scheduler = ReduceLROnPlateau(optimizer, "min") >>> for epoch in range(10): >>> train(...) >>> val_loss = validate(...) - >>> # Note that step should be called after validate() + >>> # Note that step should be called after validate() >>> scheduler.step(val_loss) .. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png @@ -1502,7 +1507,12 @@ class CyclicLR(LRScheduler): Example: >>> # xdoctest: +SKIP >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR( + ... optimizer, + ... base_lr=0.01, + ... max_lr=0.1, + ... step_size_up=10, + ... ) >>> data_loader = torch.utils.data.DataLoader(...) >>> for epoch in range(10): >>> for batch in data_loader: @@ -1729,7 +1739,9 @@ class CosineAnnealingWarmRestarts(LRScheduler): Example: >>> # xdoctest: +SKIP >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05) - >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20) + >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + ... optimizer, T_0=20 + ... ) >>> for epoch in range(100): >>> train(...) >>> validate(...) @@ -1800,7 +1812,7 @@ def step(self, epoch=None) -> None: >>> for epoch in range(20): >>> scheduler.step() >>> scheduler.step(26) - >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) """ if epoch is None and self.last_epoch < 0: epoch = 0 @@ -1936,7 +1948,9 @@ class OneCycleLR(LRScheduler): >>> # xdoctest: +SKIP >>> data_loader = torch.utils.data.DataLoader(...) >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) - >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR( + ... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10 + ... ) >>> for epoch in range(10): >>> for batch in data_loader: >>> train_batch(...) @@ -2141,8 +2155,6 @@ def get_lr(self) -> list[float]: if self.use_beta1: group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined] else: - group[ - "momentum" - ] = computed_momentum # type: ignore[possibly-undefined] + group["momentum"] = computed_momentum # type: ignore[possibly-undefined] return lrs diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 60495cc294bbef..01a02c1eea923b 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the NAdam algorithm.""" + from typing import cast, Optional, Union import torch @@ -370,7 +371,9 @@ def _single_tensor_nadam( grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) ) param.addcdiv_( - exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) + exp_avg, + denom, + value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)), ) @@ -408,7 +411,11 @@ def _multi_tensor_nadam( p.device.type == mp.device.type == step.device.type and p.device.type in capturable_supported_devices for p, mp, step in zip(params, mu_products, state_steps) - ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + "If capturable=True, " + "params, mu_products, and state_steps must be on supported devices: " + f"{capturable_supported_devices}." + ) lr = _to_scalar(lr) @@ -455,7 +462,7 @@ def _multi_tensor_nadam( # Perform stepweight decay torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: - # Re-use the intermediate memory (grouped_grads) already allocated for maximize + # Reuse the intermediate memory (grouped_grads) already allocated for maximize if maximize: torch._foreach_add_( grouped_grads, grouped_params, alpha=weight_decay @@ -576,10 +583,16 @@ def _multi_tensor_nadam( ) torch._foreach_addcdiv_( - grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type] + grouped_params, + grouped_grads, + exp_avg_sq_sqrt, + step_size_grads, # type: ignore[arg-type] ) torch._foreach_addcdiv_( - grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type] + grouped_params, + grouped_exp_avgs, + exp_avg_sq_sqrt, + step_size_expavg, # type: ignore[arg-type] ) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index d6d088a8f74b4d..2dc95eb5555747 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs """Base optimizer.""" + import functools import warnings from collections import defaultdict, OrderedDict @@ -103,7 +104,7 @@ def _stack_if_compiling(x): def _disable_dynamo_if_unsupported( - single_tensor_fn: Optional[Callable[..., object]] = None + single_tensor_fn: Optional[Callable[..., object]] = None, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # workaround for torchscript BC # it requires all called functions to be in the @@ -349,15 +350,24 @@ class Optimizer: options (used when a parameter group doesn't specify them). """ - OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc] + OptimizerPreHook: TypeAlias = Callable[ + [Self, Args, Kwargs], # type: ignore[misc] + Optional[tuple[Args, Kwargs]], + ] OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] _optimizer_step_pre_hooks: dict[int, OptimizerPreHook] _optimizer_step_post_hooks: dict[int, OptimizerPostHook] _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' - _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' - _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' - _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' + _optimizer_state_dict_post_hooks: ( + 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + ) + _optimizer_load_state_dict_pre_hooks: ( + 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' + ) + _optimizer_load_state_dict_post_hooks: ( + 'OrderedDict[int, Callable[["Optimizer"], None]]' + ) def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107 torch._C._log_api_usage_once("python.optimizer") @@ -847,7 +857,9 @@ def register_load_state_dict_post_hook( handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) self._optimizer_load_state_dict_post_hooks[handle.id] = hook if prepend: - self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + self._optimizer_load_state_dict_post_hooks.move_to_end( + handle.id, last=False + ) # type: ignore[attr-defined] return handle @torch._disable_dynamo @@ -877,12 +889,25 @@ def load_state_dict(self, state_dict: StateDict) -> None: >>> # xdoctest: +SKIP >>> model = torch.nn.Linear(10, 10) >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4) - >>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20) - >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5) - >>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20]) - >>> lr.load_state_dict(torch.load('./save_seq.pt')) + >>> scheduler1 = torch.optim.lr_scheduler.LinearLR( + ... optim, + ... start_factor=0.1, + ... end_factor=1, + ... total_iters=20, + ... ) + >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + ... optim, + ... T_max=80, + ... eta_min=3e-5, + ... ) + >>> lr = torch.optim.lr_scheduler.SequentialLR( + ... optim, + ... schedulers=[scheduler1, scheduler2], + ... milestones=[20], + ... ) + >>> lr.load_state_dict(torch.load("./save_seq.pt")) >>> # now load the optimizer checkpoint after loading the LRScheduler - >>> optim.load_state_dict(torch.load('./save_optim.pt')) + >>> optim.load_state_dict(torch.load("./save_optim.pt")) """ # shallow copy, to be consistent with module API @@ -933,7 +958,10 @@ def _cast(param, value, param_id=None, param_groups=None, key=None): for k, v in value.items() } elif isinstance(value, Iterable): - return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + return type(value)( + _cast(param, v, param_id=param_id, param_groups=param_groups) + for v in value + ) # type: ignore[call-arg] else: return value @@ -1021,12 +1049,10 @@ def zero_grad(self, set_to_none: bool = True) -> None: torch._foreach_zero_(grads) @overload - def step(self, closure: None = None) -> None: - ... + def step(self, closure: None = None) -> None: ... @overload - def step(self, closure: Callable[[], float]) -> float: - ... + def step(self, closure: Callable[[], float]) -> float: ... def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: r"""Perform a single optimization step to update parameter. diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 3f787bee9ffa35..cf4c7fd03dbbb2 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the RAdam algorithm.""" + from typing import cast, Optional, Union import torch @@ -285,7 +286,9 @@ def _single_tensor_radam( assert ( param.device.type == step_t.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) if torch.is_complex(param): param = torch.view_as_real(param) @@ -386,7 +389,9 @@ def _multi_tensor_radam( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) lr = _to_scalar(lr) @@ -456,7 +461,7 @@ def _multi_tensor_radam( if decoupled_weight_decay: torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: - # Re-use the intermediate memory (grouped_grads) already allocated for maximize + # Reuse the intermediate memory (grouped_grads) already allocated for maximize if maximize: torch._foreach_add_( grouped_grads, grouped_params, alpha=weight_decay diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 9b65ff937850be..11bd6b9e4e695f 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the RMSprop algorithm.""" + from typing import cast, Optional, Union import torch @@ -292,7 +293,9 @@ def _single_tensor_rmsprop( assert ( param.device.type == step.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) grad = grads[i] grad = grad if not maximize else -grad @@ -366,7 +369,9 @@ def _multi_tensor_rmsprop( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) lr = _to_scalar(lr) @@ -415,7 +420,7 @@ def _multi_tensor_rmsprop( torch._foreach_add_(grouped_state_steps, 1) if weight_decay != 0: - # Re-use the intermediate memory (grouped_grads) already allocated for maximize + # Reuse the intermediate memory (grouped_grads) already allocated for maximize if maximize: torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) else: diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 48289f6b1f0d5e..77e09f0d3854fb 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for the Resilient backpropagation.""" + from typing import cast, Optional, Union import torch @@ -199,9 +200,9 @@ def step(self, closure=None): For further details regarding the algorithm we refer to the paper `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm - `_. - """ + `_.""" # codespell:ignore + rf""" + Args: {_params_doc} lr (float, optional): learning rate (default: 1e-2) @@ -248,7 +249,9 @@ def _single_tensor_rprop( assert ( param.device.type == step.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) step += 1 @@ -315,7 +318,9 @@ def _multi_tensor_rprop( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( [params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item] diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 8002df8b308d33..0da64ce67aab77 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs r"""Implementation for Stochastic Gradient Descent optimizer.""" + from typing import cast, Optional, Union import torch @@ -397,7 +398,8 @@ def _multi_tensor_sgd( lr = _to_scalar(lr) grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( - [params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item] + [params, grads, momentum_buffer_list], # type: ignore[list-item] + with_indices=True, ) for ( device_params_, @@ -415,7 +417,7 @@ def _multi_tensor_sgd( device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] if weight_decay != 0: - # Re-use the intermediate memory (device_grads) already allocated for maximize + # Reuse the intermediate memory (device_grads) already allocated for maximize if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -502,7 +504,8 @@ def _fused_sgd( for i, g in enumerate(grads): momentum_buffer_list[i] = torch.empty_like(g) grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( - [params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item] + [params, grads, momentum_buffer_list], # type: ignore[list-item] + with_indices=False, ) for (device, _), ( (device_params_, device_grads_, device_momentum_buffer_list), diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 49a8383c2113a4..b04bf11fe546aa 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -37,9 +37,9 @@ def __init__( sparse_params = [] complex_params = [] for index, param_group in enumerate(self.param_groups): - assert isinstance( - param_group, dict - ), f"param_groups must be a list of dicts, but got {type(param_group)}" + assert isinstance(param_group, dict), ( + f"param_groups must be a list of dicts, but got {type(param_group)}" + ) # given param group, convert given params to a list first before iterating for d_index, d_param in enumerate(param_group["params"]): if d_param.is_sparse: diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index fffd9462dd22e9..378194cc7c1936 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs r"""Implementation for Stochastic Weight Averaging implementation.""" + import itertools import math import warnings from collections.abc import Iterable from copy import deepcopy -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, cast, Literal, Optional, Union import torch from torch import Tensor @@ -68,7 +69,9 @@ def swa_update( averaged_param_list[0] ): torch._foreach_lerp_( - averaged_param_list, current_param_list, 1 / (num_averaged + 1) + averaged_param_list, + current_param_list, + cast(float, 1 / (num_averaged + 1)), ) else: diffs = torch._foreach_sub(current_param_list, averaged_param_list) @@ -225,9 +228,9 @@ def __init__( use_buffers=False, ): # noqa: D107 super().__init__() - assert ( - avg_fn is None or multi_avg_fn is None - ), "Only one of avg_fn and multi_avg_fn should be provided" + assert avg_fn is None or multi_avg_fn is None, ( + "Only one of avg_fn and multi_avg_fn should be provided" + ) self.module = deepcopy(model) if device is not None: self.module = self.module.to(device) @@ -274,7 +277,9 @@ def update_parameters(self, model: Module): ) in grouped_tensors.items(): if self.multi_avg_fn: self.multi_avg_fn( - self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type] + self_params, # type: ignore[arg-type] + model_params, # type: ignore[arg-type] + self.n_averaged.to(device), ) elif ( device is not None diff --git a/torch/overrides.py b/torch/overrides.py index 67e079d07db0fa..562141ff1cf1b6 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -98,7 +98,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: return wrapper -@functools.lru_cache(None) +@functools.cache @_disable_user_warnings def get_ignored_functions() -> set[Callable]: """ @@ -378,7 +378,7 @@ def get_ignored_functions() -> set[Callable]: } -@functools.lru_cache(None) +@functools.cache def get_default_nowrap_functions() -> set[Callable]: """ Return public functions that do not wrap in a subclass when invoked by @@ -404,7 +404,7 @@ def get_default_nowrap_functions() -> set[Callable]: } -@functools.lru_cache(None) +@functools.cache @_disable_user_warnings def get_testing_overrides() -> dict[Callable, Callable]: """Return a dict containing dummy overrides for all overridable functions @@ -424,7 +424,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: >>> inspect.signature(my_add) """ - # Every function in the PyTorchAPI that can be overriden needs an entry + # Every function in the PyTorchAPI that can be overridden needs an entry # in this dict. # # Optimally we would use inspect to get the function signature and define @@ -1511,7 +1511,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, - Tensor.__dlpack__: lambda self, stream=None: -1, + Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1, Tensor.__dlpack_device__: lambda self: -1, torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, } # fmt: skip @@ -1808,7 +1808,7 @@ def handle_torch_function( ) -@functools.lru_cache(None) +@functools.cache def _get_overridable_functions() -> tuple[ dict[Any, list[Callable]], dict[Callable, str] ]: @@ -1881,7 +1881,7 @@ def _get_overridable_functions() -> tuple[ if ignore: continue - # cannot be overriden by __torch_function__ + # cannot be overridden by __torch_function__ if func in get_ignored_functions(): msg = ( "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " @@ -1929,7 +1929,7 @@ def resolve_name(f): return _get_overridable_functions()[1].get(f) -@functools.lru_cache(None) +@functools.cache def _get_tensor_methods() -> set[Callable]: """Returns a set of the overridable methods on ``torch.Tensor``""" overridable_funcs = get_overridable_functions() diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 42e346c626e340..21446c626b9a39 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -675,7 +675,7 @@ def _check_mocked_error(module: Optional[str], field: Optional[str]): memo_count += 1 elif opcode.name == "STACK_GLOBAL": if module is None: - # If not module was passed on in the entries preceeding this one, continue. + # If not module was passed on in the entries preceding this one, continue. continue assert isinstance(module, str) if module not in all_dependencies: diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index ff997a8faecc30..a97cf475b350a2 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -40,7 +40,7 @@ # This is a list of imports that are implicitly allowed even if they haven't # been marked as extern. This is to work around the fact that Torch implicitly # depends on numpy and package can't track it. -# https://github.com/pytorch/MultiPy/issues/46 +# https://github.com/pytorch/multipy/issues/46 # codespell:ignore multipy IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [ "numpy", "numpy.core", @@ -386,13 +386,13 @@ def _make_module( assert module.__name__ not in _package_imported_modules _package_imported_modules[module.__name__] = module - # pre-emptively install on the parent to prevent IMPORT_FROM from trying to + # preemptively install on the parent to prevent IMPORT_FROM from trying to # access sys.modules self._install_on_parent(parent, name, module) if filename is not None: assert mangled_filename is not None - # pre-emptively install the source in `linecache` so that stack traces, + # preemptively install the source in `linecache` so that stack traces, # `inspect`, etc. work. assert filename not in linecache.cache # type: ignore[attr-defined] linecache.lazycache(mangled_filename, ns) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index 9ffd93fa0efd48..a90a371130e7a8 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs r""" PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. Profiler's context manager API can be used to better understand what model operators are the most expensive, @@ -9,12 +8,14 @@ """ import os +from typing import Any +from typing_extensions import TypeVarTuple, Unpack from torch._C._autograd import _supported_activities, DeviceType, kineto_available from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope from torch._environment import is_fbcode from torch.autograd.profiler import KinetoStepTracker, record_function -from torch.optim.optimizer import register_optimizer_step_post_hook +from torch.optim.optimizer import Optimizer, register_optimizer_step_post_hook from .profiler import ( _KinetoProfile, @@ -43,7 +44,12 @@ from . import itt -def _optimizer_post_hook(optimizer, args, kwargs): +_Ts = TypeVarTuple("_Ts") + + +def _optimizer_post_hook( + optimizer: Optimizer, args: tuple[Unpack[_Ts]], kwargs: dict[str, Any] +) -> None: KinetoStepTracker.increment_step("Optimizer") diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index f10831ade397b5..2f4c763f256d7e 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -91,7 +91,7 @@ def __hash__(self) -> int: @dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True) class TensorKey(Key): - """Hashable identifier for a storage which has been asigned an ID. + """Hashable identifier for a storage which has been assigned an ID. A detailed description of Tensor IDs and why they are needed is given in `torch/csrc/profiler/collection.h` when `TensorID` is declared. To diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 41748ea39545a8..11a99573948063 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -90,7 +90,7 @@ def format_time(time_ns: int): def match(self, event: _ProfilerEvent): """ Return True if the event matches the pattern. - This method should be overriden in subclass. + This method should be overridden in subclass. """ raise NotImplementedError @@ -150,7 +150,7 @@ class ExtraCUDACopyPattern(Pattern): example: torch.zeros((100, 100)).to("cuda") Pattern: - build-in method |build-in method + built-in method |built-in method ... | aten::to aten::fill_/aten::zero_ | aten::_to_copy @@ -209,7 +209,7 @@ def match(self, event): return False while event.children: event = event.children[-1] - # aten::zero_ is a special optimzation case where fill_ is not called + # aten::zero_ is a special optimization case where fill_ is not called if event.name in self.init_ops: return True return event.name in self.init_ops @@ -367,7 +367,7 @@ def __init__(self, prof: profile, should_benchmark: bool = False): self.name = "Optimizer Single Tensor Pattern" self.optimizers_with_foreach = ["adam", "sgd", "adamw"] self.description = ( - "Deteced optimizer running with single tensor implementation. " + "Detected optimizer running with single tensor implementation. " "Please enable multi tensor implementation by passing 'foreach=True' into optimizer." ) self.url = "" diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index af693aecdde10f..3f1947ef7112b4 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -142,8 +142,16 @@ def compute_queue_depth(self): cuda_event_list = self.profile.kineto_results.events() def is_cuda_launch_kernel(e): - # TODO: find a better way to identify cudaLaunchKernel - return e.name == "cudaLaunchKernel" + """Check if the event is a CUDA launch kernel.""" + launch_patterns = { + "cudaLaunchKernel", # Standard CUDA + "cudaLaunchKernelExC", # Extended C + "__cudaLaunchKernel", # Internal + "cudaLaunchCooperativeKernel", # Collaborative (single-device) + "cudaLaunchCooperativeKernelMultiDevice", # Collaborative (multi-devices) + } + name = str(getattr(e, "name", e)) + return any(name.startswith(pattern) for pattern in launch_patterns) def is_cuda_kernel(e): # TODO: find a better way to identify CUDA Kernel diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 60439511b2d8db..d4d66ca910dcf9 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -157,6 +157,7 @@ def __init__( self.acc_events = acc_events self.custom_trace_id_callback = custom_trace_id_callback self.profiler: Optional[prof.profile] = None + self.has_cudagraphs = False self.mem_tl: Optional[MemoryProfileTimeline] = None self.use_device = None if ProfilerActivity.CUDA in self.activities: @@ -181,6 +182,10 @@ def stop(self): self.stop_trace() def prepare_trace(self): + if hasattr(torch, "_inductor"): + import torch._inductor.config as inductor_config + + self.has_cudagraphs = inductor_config.triton.cudagraphs if (self.profiler is None) or (not self.acc_events): self.profiler = prof.profile( use_cpu=(ProfilerActivity.CPU in self.activities), @@ -221,26 +226,23 @@ def start_trace(self): "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder) ) - if hasattr(torch, "_inductor"): - import torch._inductor.config as inductor_config - - cuda_version = None - if hasattr(torch, "version"): - from torch.torch_version import TorchVersion - - cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0")) - - if inductor_config.triton.cudagraphs and ( - (cuda_version and cuda_version < "12.6") - or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12() - ): - os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" - self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") - # FIXME: CUDA Graph does not work well with CUPTI teardown. - # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) - # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) - # Workaround: turn off CUPTI teardown when using CUDA Graphs. - os.environ["TEARDOWN_CUPTI"] = "0" + cuda_version = None + if hasattr(torch, "version"): + from torch.torch_version import TorchVersion + + cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0")) + + if self.has_cudagraphs and ( + (cuda_version and cuda_version < "12.6") + or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12() + ): + os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" + self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1") + # FIXME: CUDA Graph does not work well with CUPTI teardown. + # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) + # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) + # Workaround: turn off CUPTI teardown when using CUDA Graphs. + os.environ["TEARDOWN_CUPTI"] = "0" # Insert the preset user metadata to the trace for k, v in self.preset_metadata.items(): diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index 6b704fa8094e8b..bce403549d6858 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -18,5 +18,5 @@ get_fuser_method, ) -# for backward compatiblity +# for backward compatibility from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn, fuse_conv_bn_relu diff --git a/torch/serialization.py b/torch/serialization.py index 5ad421437518f5..df0c44322e6cd4 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -825,7 +825,7 @@ def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener: container = _open_zipfile_writer_file else: container = _open_zipfile_writer_buffer - return container(name_or_buffer) + return container(name_or_buffer) # type: ignore[arg-type] def _is_compressed_file(f) -> bool: @@ -1147,7 +1147,7 @@ def _save( pickle_protocol, _disable_byteorder_record, ): - serialized_storages = {} + serialized_storages: dict[str, torch.storage.UntypedStorage] = {} id_map: dict[int, str] = {} # Since loading storages that view the same data with different dtypes is diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index ee9b61f71642a0..7d67de3f83848e 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -701,7 +701,7 @@ def bartlett( >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True) tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400]) - >>> # Generates a periodic general cosine window wit 2 coefficients. + >>> # Generates a periodic general cosine window with 2 coefficients. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) """.format( diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index 2062c3fe9e4d6b..f9b1b0899f87ec 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -331,11 +331,11 @@ def _compute_compressed_swizzled_bitmask(dense): # we first need to split into the 8x8 tiles bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8) - # then we unfold again to get our indivdual 4x4 tiles + # then we unfold again to get our individual 4x4 tiles bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4) # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern - # of that tile. Note that the least siginificant bit is stored first. + # of that tile. Note that the least significant bit is stored first. # [1 1 0 0] # [1 1 0 0] -> 0011 0011 -> 51 # [0 0 1 1] 1100 1100 204 @@ -346,7 +346,7 @@ def _compute_compressed_swizzled_bitmask(dense): *bitmask_4x4_chunks.shape[:2], 4, 2, 8 ) - # to convert from binary representaiton, we can do a matmul with powers of two + # to convert from binary representation, we can do a matmul with powers of two powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda") # To run on GPU: cast to float to do matmul and then cast back compressed_swizzled_bitmask = ( diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 11a55d9d523c7f..0a4196b1f62bac 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -179,7 +179,7 @@ def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor: assert A.dtype == torch.float8_e4m3fn assert B.dtype == torch.float8_e4m3fn - # only cuSPARSELt supports float8_e4m3fn currentl + # only cuSPARSELt supports float8_e4m3fn currently assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT) assert A.packed is not None # Currently we only support per-tensor scaling, with float32 scales diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index ce0e8446cba23c..a5e802084c28ba 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -124,7 +124,7 @@ def multidim_slicer(dims, slices, *tensors): for d, d_slice in zip(dims, slices): if d is not None: s[d] = d_slice - yield t[s] + yield t[tuple(s)] def ptr_stride_extractor(*tensors): @@ -333,7 +333,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): this property enables defining swizzle operators via rearrangements of ``r_offsets`` items.. - Auxilary functions are provided for pre-computing + Auxiliary functions are provided for pre-computing :attr:`indices_data`. For example, :func:`bsr_scatter_mm_indices_data` is used to define indices data for matrix multiplication of BSR and strided tensors. @@ -836,7 +836,7 @@ def bsr_dense_addmm_meta( class TensorAsKey: """A light-weight wrapper of a tensor that enables storing tensors as - keys with efficient memory reference based comparision as an + keys with efficient memory reference based comparison as an approximation to data equality based keys. Motivation: the hash value of a torch tensor is tensor instance diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 08471ac0588826..fd98b4fd95f846 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -9,7 +9,7 @@ shapes, the usage of a bsr tensor as mat1 argument in addmm-based operations typically outperforms the corresponding operation with strided-only inputs when the blocked representation of a tensor -provides a better alignement with memory access than what the strided +provides a better alignment with memory access than what the strided representation would provide. Pre-computed kernel parameters @@ -57,7 +57,7 @@ If the approximations listed above are unacceptable, e.g. when one seeks a maximal performance possible, the optimal kernel parameters for a particular GPU can be computed by simply running this script in -the pytorch developement tree:: +the pytorch development tree:: cd /path/to/pytorch python setup.py develop @@ -91,7 +91,7 @@ optimal set of kernel parameters. Note that running tune_bsr_dense_addmm can take several minutes. So, -use it wisely, e.g. by implementing persisten storage of optimized +use it wisely, e.g. by implementing persistent storage of optimized kernel parameters. See the source code of get_meta and tune_bsr_dense_addmm to learn how to register a custom set of optimal kernel parameters for addmm-based operations. @@ -852,7 +852,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True): if 0: # Check performance dependence on sparsity and apply - # adjustments when differences are noticable (more than 10%). + # adjustments when differences are noticeable (more than 10%). # # When using NVIDIA A100 GPU, the performance dependence on # sparsity is insignificant (0 % ... 10 %) for majority of diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index d09cf58190cd82..721f25512794d0 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -37,7 +37,7 @@ class SparseSemiStructuredTensor(torch.Tensor): """ - This class implementes semi-structured sparsity as a Tensor subclass. + This class implements semi-structured sparsity as a Tensor subclass. Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse, depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained @@ -46,11 +46,11 @@ class SparseSemiStructuredTensor(torch.Tensor): There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS. This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items. - Note that as such, this class cannot be insantiated directly. + Note that as such, this class cannot be instantiated directly. -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints - `def from_dense()` - backend specific compression routines - - `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm)) + - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm)) """ _DEFAULT_ALG_ID: int = 0 @@ -123,7 +123,7 @@ def __new__( # noqa: PYI034 ) cls._PROTOTYPE_WARNING_SHOWN = True - # Because this only runs onces, we also load the dispatch table here as well. + # Because this only runs once, we also load the dispatch table here as well. # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead # But this is useful since it allows users to overload the dispatch table for debugging / testing. cls._load_dispatch_table() @@ -197,10 +197,10 @@ def __tensor_unflatten__( requires_grad=requires_grad, ) - __torch_function__ = torch._C._disabled_torch_function_impl + __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment] @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: + def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override] if func._overloadpacket not in cls.SPARSE_DISPATCH: raise NotImplementedError( f"{cls.__name__} only supports a specific set of operations, " @@ -325,7 +325,7 @@ def to_sparse_semi_structured( This function will check to ensure the dense tensor has the right dtype, size, dims, and device. We currently only support semi-structured sparse tensors for 2d CUDA tensors. - Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in + Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8). Args: @@ -388,7 +388,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): This class implements semi-structured sparsity for the CUTLASS backend. - In this implementation, the specified elements and metadata are stored seprately, + In this implementation, the specified elements and metadata are stored separately, in packed and meta respectively. When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and diff --git a/torch/special/__init__.py b/torch/special/__init__.py index 9eb3fefefdea7b..be027caa94cbb7 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -905,7 +905,7 @@ where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(-)` is the Gamma function. -All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefiend. +All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined. """ + """ @@ -1162,7 +1162,7 @@ chebyshev_polynomial_u = _add_docstr( _special.special_chebyshev_polynomial_u, r""" -chebyshev_polynomial_t(input, n, *, out=None) -> Tensor +chebyshev_polynomial_u(input, n, *, out=None) -> Tensor Chebyshev polynomial of the second kind :math:`U_{n}(\text{input})`. @@ -1171,7 +1171,7 @@ :math:`|\text{input}| > 1`, the recursion: .. math:: - T_{n + 1}(\text{input}) = 2 \times \text{input} \times T_{n}(\text{input}) - T_{n - 1}(\text{input}) + U_{n + 1}(\text{input}) = 2 \times \text{input} \times U_{n}(\text{input}) - U_{n - 1}(\text{input}) is evaluated. Otherwise, the explicit trigonometric formula: diff --git a/torch/storage.py b/torch/storage.py index 824a29af59702b..e651bc9d16eb18 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -154,6 +154,10 @@ def _new_shared_filename_cpu( ) -> Self: raise NotImplementedError + @classmethod + def _release_ipc_counter(cls, *args, device=None, **kwargs): + return cls._release_ipc_counter_cuda(*args, **kwargs) + @classmethod def _release_ipc_counter_cuda(cls, *args, **kwargs) -> Self: raise NotImplementedError @@ -1519,7 +1523,7 @@ def __instancecheck__(cls, instance): class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): @classmethod - def _new_shared(cls, size): + def _new_shared(cls, size): # type: ignore[override] """Create a new storage in shared memory with the same data type.""" untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size()) return cls(wrap_storage=untyped_storage) diff --git a/torch/testing/_internal/check_kernel_launches.py b/torch/testing/_internal/check_kernel_launches.py index d602c24246f790..d2219ef4ea56aa 100644 --- a/torch/testing/_internal/check_kernel_launches.py +++ b/torch/testing/_internal/check_kernel_launches.py @@ -112,8 +112,8 @@ def check_file(filename): return 0 if should_exclude_file(filename): return 0 - with open(filename) as fo: - contents = fo.read() + with open(filename) as f: + contents = f.read() unsafeCount = check_code_for_cuda_kernel_launches(contents, filename) return unsafeCount diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 82f486e97f6c80..a211851d671fa2 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -57,9 +57,9 @@ def CDNA2OrLater(): def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1201", "gfx950"] + arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return not IS_WINDOWS and SM80OrLater @@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention(): def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1201", "gfx950"] + arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return True diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index b74972b00dd283..01499280da8f5d 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -877,7 +877,7 @@ def instantiate_device_type_tests( ): class_name = generic_test_class.__name__ + base.device_type.upper() - # type set to Any and suppressed due to unsupport runtime class: + # type set to Any and suppressed due to unsupported runtime class: # https://github.com/python/mypy/wiki/Unsupported-Python-Features device_type_test_class: Any = type(class_name, (base, generic_test_class), {}) @@ -1320,7 +1320,7 @@ def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR): size may be a number of bytes, a string of the form "N GB", or a callable If the test is a device generic test, available memory on the primary device will be checked. - It can also be overriden by the optional `device=` argument. + It can also be overridden by the optional `device=` argument. In other tests, the `device=` argument needs to be specified. """ if isinstance(size, str): diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 16af0737413a75..a9d1c1332724e1 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -39,6 +39,7 @@ retry_on_connect_failures, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, + TEST_CUDA, TEST_HPU, TEST_WITH_ROCM, TEST_WITH_TSAN, @@ -55,6 +56,10 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"] +DDP_RANK_DEVICES = ["cuda", "xpu"] +HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU + class TestSkip(NamedTuple): exit_code: int @@ -109,21 +114,25 @@ class DistTestCases: backend_feature["xpu"] = {"xccl"} +def requires_ddp_rank(device): + return device in DDP_RANK_DEVICES + + def skip_if_no_gpu(func): """Skips if the world size exceeds the number of GPUs, ensuring that if the test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" @wraps(func) def wrapper(*args, **kwargs): - if not torch.cuda.is_available(): + if not (TEST_CUDA or TEST_HPU or TEST_XPU): sys.exit(TEST_SKIPS["no_cuda"].exit_code) world_size = int(os.environ["WORLD_SIZE"]) - if torch.cuda.device_count() < world_size: + if TEST_CUDA and torch.cuda.device_count() < world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) - if TEST_HPU and torch.hpu.device_count < world_size: + if TEST_HPU and torch.hpu.device_count() < world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) + if TEST_XPU and torch.xpu.device_count() < world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) - if TEST_XPU and torch.xpu.device_count < world_size: - sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code) return func(*args, **kwargs) @@ -189,7 +198,13 @@ def wrapper(*args, **kwargs): def at_least_x_gpu(x): - return torch.cuda.is_available() and torch.cuda.device_count() >= x + if TEST_CUDA and torch.cuda.device_count() >= x: + return True + if TEST_HPU and torch.hpu.device_count() >= x: + return True + if TEST_XPU and torch.xpu.device_count() >= x: + return True + return False def skip_if_lt_x_gpu(x): @@ -355,6 +370,35 @@ def requires_mpi(): ) +def requires_accelerator_dist_backend(backends=None): + """ + Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available. + + Args: + backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]). + If None, checks all supported accelerator backends (NCCL, XCCL, HCCL). + + Returns: + callable: A decorator that skips the test if no specified accelerator backend is available. + """ + if backends is None: + backends = ACCELERATOR_DIST_BACKENDS + + backend_available = any( + { + "nccl": c10d.is_nccl_available, + "xccl": c10d.is_xccl_available, + "hccl": lambda: TEST_HPU, + }.get(backend, lambda: False)() + for backend in backends + ) + + return skip_but_pass_in_sandcastle_if( + not backend_available, + f"No accelerator communication backend available among {backends}", + ) + + def requires_multicast_support(): has_multicast_support = ( torch.cuda.is_available() @@ -968,9 +1012,14 @@ def is_master(self) -> bool: class DistributedTestBase(MultiProcessTestCase): def setUp(self): super().setUp() + os.environ["WORLD_SIZE"] = str(self.world_size) self._spawn_processes() def tearDown(self): + try: + torch.distributed.destroy_process_group() + except AssertionError: + pass try: os.remove(self.file_name) except OSError: @@ -986,12 +1035,14 @@ def backend(self, device) -> str: else: return "gloo" - def create_pg(self, device): + def create_pg(self, device, world_size=None): + if world_size is None: + world_size = self.world_size num_visible_devices = torch.get_device_module(device).device_count() store = torch.distributed.FileStore(self.file_name, num_visible_devices) torch.distributed.init_process_group( backend=self.backend(device), - world_size=self.world_size, + world_size=world_size, rank=self.rank, store=store, ) @@ -1404,7 +1455,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @contextmanager -def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False): +def _dynamo_dist_per_rank_init( + rank, world_size, backend="nccl", init_pg=True, fake_pg=False +): # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, # Just manually implement the most important part of the dynamo behavior to reset/clear. if not fake_pg: @@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False): store=store, ) else: - c10d.init_process_group("nccl", rank=rank, world_size=world_size) + c10d.init_process_group(backend=backend, rank=rank, world_size=world_size) torch._dynamo.reset() torch._dynamo.utils.counters.clear() try: @@ -1465,7 +1518,7 @@ def tearDownClass(cls): super().tearDownClass() -class DynamoDistributedMultiProcTestCase(MultiProcessTestCase): +class DynamoDistributedMultiProcTestCase(DistributedTestBase): """ Use this for tests that actually run on multiple GPUs. @@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase): sparingly for integration tests. """ - def setUp(self): - super().setUp() - self._spawn_processes() - - def tearDown(self): - super().tearDown() - try: - os.remove(self.file_name) - except OSError: - pass - @property def world_size(self) -> int: - return torch.cuda.device_count() + return torch.accelerator.device_count() @classmethod def _run( @@ -1649,9 +1691,11 @@ def setUpClass(cls): # Use device count as world size device_type = cls.device_type() - cls.world_size = torch.get_device_module(device_type).device_count() - if cls.world_size == 0: - raise unittest.SkipTest(f"No {device_type} devices available") + # If world_size is not set, use device count + if cls.world_size == -2: + cls.world_size = torch.get_device_module(device_type).device_count() + if cls.world_size == 0: + raise unittest.SkipTest(f"No {device_type} devices available") logger.info( f"Testing class {cls.__name__} on {cls.world_size} {device_type}" # noqa: G004 @@ -1716,7 +1760,7 @@ def wrapper(self): f"skipping rest of tests in Test class: {self.__class__.__name__}" # noqa: G004 ) # Poison rest of tests (because ProcessGroup may be not - # re-usable now) + # reusable now) self.__class__.poison_pill = True raise rv diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 9548f0bf3dad66..a9e24eb90ef8c3 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1180,13 +1180,15 @@ def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override] self = cls(test_name) self.rank = rank self.file_name = file_name fake_pg = kwargs.get("fake_pg", False) print(f"dist init r={self.rank}, world={self.world_size}") + if torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) # Specify gloo backend to make 'init_process_group()' succeed, # Actual tests will be skipped if there is no enough GPUs. diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b4a274670109c8..92ae95bef8d0e4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -51,7 +51,7 @@ from torch.utils import _pytree as pytree -from packaging import version +from torch._vendor.packaging import version from torch.testing._internal.opinfo.core import ( # noqa: F401 L, @@ -97,6 +97,7 @@ sample_inputs_foreach, ForeachFuncInfo, gradcheck_wrapper_hermitian_input, + gradcheck_wrapper_ctc_loss, gradcheck_wrapper_triangular_input, gradcheck_wrapper_triangular_input_real_positive_diagonal, gradcheck_wrapper_masked_operation, @@ -776,9 +777,13 @@ def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False}) def error_inputs_arange(op, device, **kwargs): - yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer') - yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign') - yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero') + yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range') yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range') @@ -1601,7 +1606,7 @@ def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): ((S,), {'dtype': dtype, 'device': device}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) - ((S,), {'dtype': torch.double}), + ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), ((S,), {'device': 'cpu'}), ((S,), {'dtype': torch.double, 'device': 'cpu'}), ] @@ -1786,7 +1791,6 @@ def error_inputs_margin_ranking_loss(op, device, **kwargs): error_regex='margin_ranking_loss : All input tensors should') def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs): - other_dtype = torch.half if torch.backends.mps.is_available() else torch.double # input_shape, output_shape, strides, kwargs # lengths of output_shape and strides must be equal inputs = [ @@ -1796,9 +1800,9 @@ def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=Fals ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) - ((S,), (10,), (S,), {'dtype': other_dtype}), + ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), - ((S,), (2, 2, 2), (L, M, S), {'dtype': other_dtype, 'device': 'cpu'}), + ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), ] if torch.cuda.is_available(): inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'})) @@ -1919,6 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): def get_val(dtype): return make_tensor([], dtype=dtype, device="cpu").item() + double_dtype = torch.double if device != "mps:0" else torch.float inputs = [ ((), get_val(dtype), {}), ((S, S), get_val(dtype), {}), @@ -1926,13 +1931,16 @@ def get_val(dtype): ((S,), get_val(dtype), {'dtype': dtype, 'device': device}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) - ((S,), get_val(torch.double), {'dtype': torch.double}), + ((S,), get_val(double_dtype), {'dtype': double_dtype}), ((S,), get_val(dtype), {'device': 'cpu'}), - ((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}), + ((S,), get_val(double_dtype), {'dtype': double_dtype, 'device': 'cpu'}), ] if torch.cuda.is_available(): inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + if torch.mps.is_available() and dtype not in [torch.float64, torch.complex128, torch.uint32, torch.uint16]: + inputs.append(((S,), get_val(dtype), {'device': 'mps'})) + if not dtype.is_signed: # For unsigned dtypes, negative values are converted. inputs.append(((S,), -get_val(dtype), {})) @@ -3266,17 +3274,17 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): test_args = [ ([1, 2],), (slice(0, 3),), - ([slice(0, 3), 1],), - ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],), - ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],), - ([slice(None), slice(None), [0, 3]],), - ([slice(None), [0, 3], slice(None)],), - ([[0, 3], slice(None), slice(None)],), - ([[0, 3], [1, 2], slice(None)],), - ([[0, 3], ],), - ([[0, 3], slice(None)],), - ([[0, 3], Ellipsis],), - ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],), + ((slice(0, 3), 1),), + (([0, 2, 3], [1, 3, 3], [0, 0, 2]),), + (([0, 0, 3], [1, 1, 3], [0, 0, 2]),), + ((slice(None), slice(None), [0, 3]),), + ((slice(None), [0, 3], slice(None)),), + (([0, 3], slice(None), slice(None)),), + (([0, 3], [1, 2], slice(None)),), + (([0, 3], ),), + (([0, 3], slice(None)),), + (([0, 3], Ellipsis),), + (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),), (index_variable(2, S, device=device),), (mask_not_all_zeros((S,)),), ] @@ -3284,7 +3292,7 @@ def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): for args in test_args: yield SampleInput(make_arg((S, S, S)), args=args) - yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],)) + yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),)) def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -4753,6 +4761,13 @@ def shape(size, rank, with_batch_channel=True): return tuple([N, C] + ([size] * rank)) return tuple([size] * rank) + def uneven_shape(size, rank, with_batch_channel=True): + rc = list(shape(size, rank, with_batch_channel)) + rc[-1] += 1 + if rank > 2: + rc[-2] -= 1 + return tuple(rc) + if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: make_arg = partial( make_tensor, @@ -4791,6 +4806,21 @@ def shape(size, rank, with_batch_channel=True): mode=mode, align_corners=align_corners, ) + if rank > 1 and dtype.is_floating_point: + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) for recompute_scale_factor in [False, True]: for scale_factor in [1.7, 0.6]: yield SampleInput( @@ -7761,7 +7791,7 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) def make_bool_mask(shape): - # Make sure atleast one element is nonzero, + # Make sure at least one element is nonzero, # except for empty tensor mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) @@ -8374,7 +8404,9 @@ def make_log_probs(s): input_lengths = input_lengths.tolist() target_lengths = target_lengths.tolist() - yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), kwargs=dict(reduction=r, zero_infinity=z)) + yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), + kwargs=dict(reduction=r, zero_infinity=z)) + def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): shape = (2, 3) @@ -13688,7 +13720,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_fwgrad_bwgrad=True, skips=( # skips below tests as torch.frexp returns tuple-like (mantissa, exponent) as outputs, - # while theses tests currently requires output to a single tensor. + # while these tests currently requires output to a single tensor. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), @@ -15613,9 +15645,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), - # NotImplementedError: The operator 'aten::_upsample_nearest_exact3d.out' is not currently implemented - # for the MPS device. - DecorateInfo(unittest.expectedFailure, 'TestConsistency'), ), supports_out=False), OpInfo('nn.functional.interpolate', @@ -19373,7 +19402,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), # The inplace variant (Tensor.normal_) is different from torch.normal - # inplace varaint Tensor.normal_ is decomposed using randn_like() + # inplace variant Tensor.normal_ is decomposed using randn_like() DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))), OpInfo('normal', # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here @@ -19598,15 +19627,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_gradgrad=True, supports_out=True, check_batched_grad=False, - skips=( - # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None - # but it returned something else instead. - DecorateInfo( - unittest.expectedFailure, - 'TestProxyTensorOpInfo', - 'test_make_fx_symbolic_exhaustive_out' - ), - )), + ), OpInfo('vstack', aliases=('row_stack',), dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -20148,15 +20169,14 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), OpInfo('logcumsumexp', dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), - backward_dtypes=floating_and_complex_types_and(torch.bfloat16), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'), # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble' - # Falling back to non-numerically stablized exp, causing nan in the results. + # Falling back to non-numerically stabilized exp, causing nan in the results. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]), DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), DecorateInfo( @@ -21347,15 +21367,9 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): dtypes=floating_types(), supports_out=False, sample_inputs_func=sample_inputs_ctc_loss, + # gradcheck_wrapper, see https://github.com/pytorch/pytorch/issues/52241 + gradcheck_wrapper=gradcheck_wrapper_ctc_loss, skips=( - # https://github.com/pytorch/pytorch/issues/67462 - # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0 - DecorateInfo( - unittest.expectedFailure, - "TestBwdGradients", - "test_fn_grad", - dtypes=(torch.float64,), - ), # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented DecorateInfo( unittest.expectedFailure, diff --git a/torch/testing/_internal/common_mkldnn.py b/torch/testing/_internal/common_mkldnn.py index f9a05cf807ae8f..ffaed6c7e009c8 100644 --- a/torch/testing/_internal/common_mkldnn.py +++ b/torch/testing/_internal/common_mkldnn.py @@ -20,24 +20,30 @@ def bf32_is_not_fp32(): @contextlib.contextmanager def bf32_off(): - old_matmul_precision = torch.get_float32_matmul_precision() + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision try: - torch.set_float32_matmul_precision("highest") + torch.backends.mkldnn.matmul.fp32_precision = "ieee" + torch.backends.mkldnn.conv.fp32_precision = "ieee" yield finally: - torch.set_float32_matmul_precision(old_matmul_precision) + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision @contextlib.contextmanager -def bf32_on(self, bf32_precision=1e-5): - old_matmul_precision = torch.get_float32_matmul_precision() +def bf32_on(self, bf32_precision=1e-2): + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision old_precision = self.precision try: - torch.set_float32_matmul_precision("medium") + torch.backends.mkldnn.matmul.fp32_precision = "bf16" + torch.backends.mkldnn.conv.fp32_precision = "bf16" self.precision = bf32_precision yield finally: - torch.set_float32_matmul_precision(old_matmul_precision) + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision self.precision = old_precision @@ -45,7 +51,7 @@ def bf32_on(self, bf32_precision=1e-5): # allow_bf32=True, another with allow_bf32=False. When running with # allow_bf32=True, it will use reduced precision as specified by the # argument -def bf32_on_and_off(bf32_precision=1e-5): +def bf32_on_and_off(bf32_precision=1e-2): def with_bf32_disabled(self, function_call): with bf32_off(): function_call() diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 5d99296cc401b7..d713f7b5535d5f 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -16,8 +16,7 @@ floating_types, floating_and_complex_types_and, get_all_fp_dtypes) from torch.testing._internal.common_device_type import ( _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol, - skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, - skipCUDAVersionIn) + skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS) from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_nn import ( cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference, @@ -3172,14 +3171,6 @@ def padding3d_circular_ref(inp, pad): DecorateInfo( unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda' - ), - DecorateInfo( - skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module", - device_type='cuda' - ), - DecorateInfo( - skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module", - device_type='cuda' ) ) @@ -3852,7 +3843,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad unittest.expectedFailure, 'TestModule', 'test_memory_format', - active_if=operator.itemgetter('training'), + active_if=operator.itemgetter('training') and not _macos15_or_newer, device_type='mps', ),) ), @@ -4040,26 +4031,9 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Hardshrink, module_inputs_func=module_inputs_torch_nn_Hardshrink, - skips=( - # not supported on MPS backend - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_if_train_and_eval_modes_differ', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_save_load', device_type='mps'),), ), ModuleInfo(torch.nn.Hardswish, module_inputs_func=module_inputs_torch_nn_Hardswish, - skips=None if _macos15_or_newer else ( - # Fails on backward check on MPS - # See https://github.com/pytorch/pytorch/issues/107214 - DecorateInfo( - unittest.expectedFailure, - 'TestModule', - 'test_memory_format', - active_if=operator.itemgetter('training'), - device_type='mps', - ),), supports_gradgrad=False), ModuleInfo(torch.nn.Hardtanh, module_inputs_func=module_inputs_torch_nn_Hardtanh, diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index dfbdb48010b2ea..f6a33486dc563e 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -35,6 +35,9 @@ def mps_ops_modifier( "as_strided", "as_strided_copy", "as_strided_scatter", + "asin", + "acos", + "atan", "broadcast_tensors", "broadcast_to", "chalf", @@ -45,6 +48,7 @@ def mps_ops_modifier( "conj_physical", "contiguous", "cos", + "cosh", "diag", "diag_embed", "diagflat", @@ -57,6 +61,7 @@ def mps_ops_modifier( "empty_permuted", "empty_strided", "exp", + "expm1", "exp2", "expand", "expand_as", @@ -64,6 +69,7 @@ def mps_ops_modifier( "flatten", "fill", "full", + "full_like", "H", "hsplit", "imag", @@ -77,6 +83,7 @@ def mps_ops_modifier( "linalg.diagonal", "linalg.svd", "log10", + "log1p", "log2", "log", "mH", @@ -98,6 +105,7 @@ def mps_ops_modifier( "nn.functional.conv2d", "nn.functional.conv_transpose1d", "nn.functional.conv_transpose2d", + "nn.functional.conv_transpose3d", "nn.functional.feature_alpha_dropoutwithout_train", "nn.functional.padcircular", "nn.functional.softsign", @@ -105,6 +113,7 @@ def mps_ops_modifier( "nn.functional.unfold", "nonzero", "ones", + "ones_like", "outer", "permute", "permute_copy", @@ -122,8 +131,10 @@ def mps_ops_modifier( "scalar_tensor", "select", "sgn", + "sigmoid", "sin", "sinc", + "sinh", "slice", "special.spherical_bessel_j0", "special.entr", @@ -146,6 +157,8 @@ def mps_ops_modifier( "tensor_split", "transpose", "transpose_copy", + "tril", + "triu", "true_divide", "T", "unbind", @@ -164,13 +177,13 @@ def mps_ops_modifier( "vsplit", "zero_", "zeros", + "zeros_like", } AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = { "__rdiv__", "__rmatmul__", "_chunk_cat", - "acos", "acosh", "all", "allclose", @@ -180,8 +193,6 @@ def mps_ops_modifier( "addcmul", "addmmdecomposed", "addmv", - "asin", - "atan", "atanh", "bfloat16", "bmm", @@ -193,7 +204,6 @@ def mps_ops_modifier( "combinations", "corrcoef", "constant_pad_nd", - "cosh", "cov", "count_nonzero", "diff", @@ -203,7 +213,6 @@ def mps_ops_modifier( "einsum", "eq", "equal", - "expm1", "eye", "fft.fft", "fft.fft2", @@ -236,7 +245,6 @@ def mps_ops_modifier( "linalg.pinv", "linspace", "linspacetensor_overload", - "log1p", "logical_and", "logical_not", "logical_or", @@ -266,7 +274,6 @@ def mps_ops_modifier( "roll", "rot90", "short", - "sigmoid", "sinh", "sqrt", "square", @@ -278,8 +285,6 @@ def mps_ops_modifier( "trace", "trapz", "trapezoid", - "tril", - "triu", "vstack", "where", "byte", @@ -299,7 +304,7 @@ def mps_ops_modifier( ], # test blow pass on macOS 12 as it falls back to cpu # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu') + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') # Elements from index 30 and 5133 are both equal. # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. @@ -341,7 +346,7 @@ def mps_ops_modifier( # 'nn.functional.pairwise_distance': [torch.float16], # test blow pass on macOS 12 as it falls back to cpu # Argsort case using duplicate indices (undefined behaviour): - # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], devuce='cpu') + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') # Elements from index 30 and 5133 are both equal. # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. @@ -372,20 +377,15 @@ def mps_ops_modifier( # Those ops are not expected to work UNIMPLEMENTED_XFAILLIST = { # Failures due to lack of op implementation on MPS backend - "login": None, "logspace": None, "logspacetensor_overload": None, "linalg.eig": None, "linalg.eigvals": None, "put": None, - "nn.functional.conv_transpose3d": None, "cauchy_": None, "cauchy": None, "cholesky_inverse": None, "cholesky_solve": None, - "cummax": None, - "cummin": None, - "erfc": None, "frexp": None, "gcd": None, "geqrf": None, @@ -416,7 +416,6 @@ def mps_ops_modifier( "linalg.qr": None, "linalg.svdvals": None, "linalg.vecdot": None, - "logcumsumexp": None, "lu_solve": None, "masked.median": None, "matrix_exp": None, @@ -429,15 +428,12 @@ def mps_ops_modifier( "nn.functional.adaptive_max_pool3d": None, "nn.functional.interpolatearea": None, "nn.functional.interpolatebicubic": [torch.uint8], - "nn.functional.interpolatetrilinear": None, "nn.functional.max_unpool1dgrad": None, "nn.functional.max_unpool2dgrad": None, "nn.functional.max_unpool3dgrad": None, "nn.functional.avg_pool3d": None, "nn.functional.ctc_loss": None, "nn.functional.embedding_bag": None, - "nn.functional.hardshrink": None, - "nn.functional.max_pool3d": None, "nn.functional.max_unpool1d": None, "nn.functional.max_unpool2d": None, "nn.functional.max_unpool3d": None, @@ -469,6 +465,7 @@ def mps_ops_modifier( "special.airy_ai": None, "special.erfcx": None, "special.laguerre_polynomial_l": None, + "special.legendre_polynomial_p": None, "special.log_ndtr": None, "special.ndtri": None, "svd_lowrank": None, @@ -491,75 +488,27 @@ def mps_ops_modifier( "log_softmaxwith_dtype": None, "softmaxwith_dtype": None, "float_power": None, - "full_like": None, "linalg.matrix_rankhermitian": None, "linalg.pinvhermitian": None, "nonzero_static": None, # MPS: input sizes must be divisible by output sizes "nn.functional.adaptive_avg_pool1d": None, "nn.functional.adaptive_avg_pool2d": None, - # Unsupported dtypes - "ones_like": None, - "zeros_like": None, # Convolution for integral types is not supported on MPS "nn.functional.conv1d": [torch.int64], "nn.functional.conv2d": [torch.int64], "nn.functional.conv3d": [torch.int64], "nn.functional.conv_transpose1d": [torch.int64], "nn.functional.conv_transpose2d": [torch.int64, torch.bfloat16], + "nn.functional.conv_transpose3d": [ + torch.int64, + torch.bfloat16, + torch.float16, + ], # Unsupported dtypes "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], "histc": [torch.float16, torch.bfloat16], "index_add": [torch.int64], - # Operations not supported for integral types - "special.xlog1py": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - "special.zeta": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - "special.chebyshev_polynomial_t": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - "special.chebyshev_polynomial_u": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - "special.hermite_polynomial_h": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], - "special.hermite_polynomial_he": [ - torch.bool, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - torch.int8, - ], # GEMM on MPS is not supported for integral types "nn.functional.linear": [ torch.int16, @@ -900,6 +849,7 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "floor_divide": [torch.float16, torch.float32], # derivative for aten::narrow_copy is not implemented on CPU "narrow_copy": [torch.float16, torch.float32], + "nn.functional.max_pool3d": [torch.float16, torch.float32], # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU "histogramdd": [torch.float16, torch.float32], # derivative for aten::histogram is not implemented @@ -1043,8 +993,6 @@ def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "aminmax", # memory overlapping checks "index_select", - # unimplemented - "logcumsumexp", } def addDecorator(op: OpInfo, d: DecorateInfo) -> None: diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 05e68df6e71d97..780514e6743970 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1542,13 +1542,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), ), skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo("See #116028"), "TestOptimRenewed", @@ -1639,13 +1632,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), ), skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo( "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" @@ -1676,13 +1662,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( supported_impls=("foreach", "differentiable"), has_capturable_arg=True, skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo("See #116028"), "TestOptimRenewed", @@ -1763,13 +1742,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), ), skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo( "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" @@ -1904,13 +1876,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( supported_impls=("foreach", "differentiable"), has_capturable_arg=True, skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo( "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" @@ -1988,13 +1953,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( supported_impls=("foreach", "differentiable"), has_capturable_arg=True, skips=( - DecorateInfo( - skipIfMPS, # addcdiv doesn't work for non-contiguous, see #118115 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo("See #116028"), "TestOptimRenewed", @@ -2033,13 +1991,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( supported_impls=("foreach", "differentiable"), has_capturable_arg=True, skips=( - DecorateInfo( - skipIfMPS, # Rprop doesn't update for non-contiguous, see #118117 - "TestOptimRenewed", - "test_forloop_goes_right_direction", - active_if=lambda kwargs: not kwargs["contiguous"], - device_type="mps", - ), DecorateInfo( skipIfTorchDynamo("See #116028"), "TestOptimRenewed", diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 19959ae7d82d29..211b282c4fc4a8 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -1,6 +1,6 @@ # mypy: ignore-errors -r"""Importing this file includes common utility methods and base clases for +r"""Importing this file includes common utility methods and base classes for checking quantization api and properties of resulting modules. """ @@ -2806,7 +2806,7 @@ def __init__(self) -> None: self.myadd = nnq.FloatFunctional() self.myadd_relu = nnq.FloatFunctional() self.mymatmul = nnq.FloatFunctional() - # Tracing doesnt work yet for c10 ops with scalar inputs + # Tracing doesn't work yet for c10 ops with scalar inputs # https://github.com/pytorch/pytorch/issues/27097 # self.my_scalar_add = nnq.FloatFunctional() # self.my_scalar_mul = nnq.FloatFunctional() @@ -2816,7 +2816,7 @@ def forward(self, x): z = self.myadd.add(y, y) w = self.myadd_relu.add_relu(z, z) u = self.mymatmul.matmul(w, w.T) - # Tracing doesnt work yet for c10 ops with scalar inputs + # Tracing doesn't work yet for c10 ops with scalar inputs # https://github.com/pytorch/pytorch/issues/27097 # w = self.my_scalar_add.add_scalar(w, -0.5) # w = self.my_scalar_mul.mul_scalar(w, 0.5) diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index e433d907c3b99d..9dc177a7899bdc 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -10,7 +10,6 @@ from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS supported_qengines = torch.backends.quantized.supported_engines -supported_qengines.remove('none') # Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326 # QNNPACK is not supported on PPC if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2a1846bbbf39d0..eb3d5f47c46227 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -112,7 +112,7 @@ def freeze_rng_state(*args, **kwargs): # Class to keep track of test flags configurable by environment variables. # Flags set here are intended to be read-only and should not be modified after # definition. -# TODO: Expand this class to handle abritrary settings in addition to boolean flags? +# TODO: Expand this class to handle arbitrary settings in addition to boolean flags? class TestEnvironment: # Set of env vars to set for the repro command that is output on test failure. # Specifically, this includes env vars that are set to non-default values and @@ -902,6 +902,11 @@ def prof_callable(callable, *args, **kwargs): return callable(*args, **kwargs) +def raise_on_run_directly(file_to_call): + raise RuntimeError("This test file is not meant to be run directly, " + f"use:\n\n\tpython {file_to_call} TESTNAME\n\n" + "instead.") + def prof_func_call(*args, **kwargs): return prof_callable(func_call, *args, **kwargs) @@ -1125,8 +1130,8 @@ def lint_test_case_extension(suite): test_case = first_test if test_case is not None: - test_class = test_case.id().split('.', 1)[1].split('.')[0] if not isinstance(test_case, TestCase): + test_class = test_case.id().split('.', 1)[1].split('.')[0] err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't." print(f"{test_class} - failed. {err}") succeed = False @@ -1500,6 +1505,26 @@ def split_if_not_empty(x: str): TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12) +TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and ( + torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12 +) + +if TEST_CUDA_PYTHON_BINDINGS: + def cuda_python_error_check(function_call_output): + """Makes calls to cuda-python's cuda runtime functions more + pythonic by throwing an exception if they return a status + which is not cudaSuccess + """ + import cuda.bindings # type: ignore[import] + + error, *others = function_call_output + if error != cuda.bindings.runtime.cudaError_t.cudaSuccess: + raise ValueError(f"CUDA failure! {error}") + else: + return tuple(others) +else: + cuda_python_error_check = None # type: ignore[assignment] + def allocator_option_enabled_fn(allocator_config, _, option): if allocator_config is None: return False @@ -1555,6 +1580,10 @@ def __torch_function__(self, func, types, args=(), kwargs=None): env_var="PYTORCH_TEST_WITH_DYNAMO", implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER, ) +TEST_WITHOUT_COMPILED_AUTOGRAD: bool = TestEnvironment.def_flag( + "TEST_WITHOUT_COMPILED_AUTOGRAD", + env_var="PYTORCH_TEST_WITHOUT_COMPILED_AUTOGRAD", +) if TEST_WITH_TORCHDYNAMO: import torch._dynamo @@ -1567,6 +1596,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if TEST_WITH_TORCHINDUCTOR: import torch._inductor.config torch._inductor.config.fallback_random = True + else: + # only dynamo for now + torch._dynamo.config.compiled_autograd = not TEST_WITHOUT_COMPILED_AUTOGRAD # seems like this is only used in test/torch_np @@ -1642,10 +1674,32 @@ def wrapper(*args, **kwargs): return decorator +def runWithoutCompiledAutograd(msg="test doesn't currently work with compiled autograd"): + """ + Usage: + @runWithoutCompiledAutograd(msg) + def test_blah(self): + ... + """ + assert isinstance(msg, str) + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + with torch._dynamo.compiled_autograd._disable(): + func(*args, **kwargs) + return wrapper + + return decorator + def serialTest(condition=True): """ Decorator for running tests serially. Requires pytest """ + # If one apply decorator directly condition will be callable + # And test will essentially be essentially skipped, which is undesirable + assert type(condition) is bool + def decorator(fn): if has_pytest and condition: return pytest.mark.serial(fn) @@ -1987,6 +2041,26 @@ def wrapper(*args, **kwargs): return dec_fn(func) return dec_fn +def requires_cuda_p2p_access(): + cuda_p2p_access_available = ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + and torch.cuda.device_count() >= 2 + ) + num_devices = torch.cuda.device_count() + for i in range(num_devices - 1): + for j in range(i + 1, num_devices): + if not torch.cuda.can_device_access_peer(i, j): + cuda_p2p_access_available = False + break + if not cuda_p2p_access_available: + break + + return skip_but_pass_in_sandcastle_if( + not cuda_p2p_access_available, + "cuda p2p access is not available", + ) + # Reverts the linalg backend back to default to make sure potential failures in one # test do not affect other tests def setLinalgBackendsToDefaultFinally(fn): @@ -2501,18 +2575,18 @@ def __exit__(self, exc_type, exc_value, traceback): msg = ("CUDA caching allocator reports a memory leak not " # type: ignore[possibly-undefined] f"verified by the driver API in {self.name}! " f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " - f"and is now reported as {caching_allocator_mem_allocated} " + f"and is now reported as {caching_allocator_mem_allocated} " # type: ignore[possibly-undefined] f"on device {i}. " - f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") # type: ignore[possibly-undefined] warnings.warn(msg) - elif caching_allocator_discrepancy and driver_discrepancy: + elif caching_allocator_discrepancy and driver_discrepancy: # type: ignore[possibly-undefined] # A caching allocator discrepancy validated by the driver API is a # failure (except on ROCm, see below) msg = (f"CUDA driver API confirmed a leak in {self.name}! " # type: ignore[possibly-undefined] f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " - f"and is now reported as {caching_allocator_mem_allocated} " + f"and is now reported as {caching_allocator_mem_allocated} " # type: ignore[possibly-undefined] f"on device {i}. " - f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") # type: ignore[possibly-undefined] raise RuntimeError(msg) @@ -3285,6 +3359,11 @@ def wrapper(*args, **kwargs): file_name = os.path.join(subdir, key) setattr(self, self._testMethodName, ignore_failure(method, file_name)) + from .dynamo_test_failures import compiled_autograd_skips + if torch._dynamo.config.compiled_autograd and key in compiled_autograd_skips: + # Still run the test, but with compiled autograd disabled + super_run = runWithoutCompiledAutograd()(super_run) + super_run(result=result) if strict_mode or should_reset_dynamo: @@ -3895,7 +3974,7 @@ def non_contiguous_copy(t, dim=-1, offset=0): ((0, 0), [(1, 2)], [()]), ]: for blocksize in blocksizes: - for densesize in densesizes: + for densesize in densesizes: # type: ignore[attr-defined] if layout == torch.strided: indices = () # type: ignore[assignment] values = torch.empty((basesize + densesize), device=device, dtype=dtype) @@ -4491,7 +4570,7 @@ def find_free_port(): NOTE: If this function is being used to allocate a port to Store (or indirectly via init_process_group or init_rpc), it should be used - in conjuction with the `retry_on_connect_failures` decorator as there is a potential + in conjunction with the `retry_on_connect_failures` decorator as there is a potential race condition where the allocated port may become unavailable before it can be used """ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: @@ -5536,6 +5615,7 @@ def repl_frame(m): s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n( .+\n( +[~^]+ *\n)?)+', repl_frame, s) s = re.sub(r"line \d+", "line N", s) s = re.sub(r".py:\d+", ".py:N", s) + s = re.sub(r'https:/([a-zA-Z0-9_.-]+)', r'https://\1', s) s = re.sub(file, _as_posix_path(os.path.basename(file)), s) s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s) if suppress_suffix: @@ -5652,5 +5732,25 @@ def load_inline(*args, **kwargs): return cpp_extension.load_inline(*args, **kwargs) return func(*args, load_inline=load_inline, **kwargs) - return wrapper + +def recover_orig_fp32_precision(fn): + @contextlib.contextmanager + def recover(): + old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision # type: ignore[attr-defined] + old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision # type: ignore[attr-defined] + old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision # type: ignore[attr-defined] + old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision # type: ignore[attr-defined] + old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision # type: ignore[attr-defined] + old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision + try: + yield + finally: + torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p # type: ignore[attr-defined] + torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p # type: ignore[attr-defined] + torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p # type: ignore[attr-defined] + torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p # type: ignore[attr-defined] + torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p # type: ignore[attr-defined] + torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p + + return recover()(fn) diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index c7c18b76a3b74e..8007d356309164 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -359,7 +359,7 @@ def check_all_permutations(op, args, kwargs, assert_equal_fn): # - data_ptr accesses # The first is easy to filter for (we could make the error a different # error class), the second is always going to be a RuntimeError due to - # how it is implemented (if you try to access the data_ptr of thex + # how it is implemented (if you try to access the data_ptr of the # wrapper Tensor, it raises you some internal RuntimeError). # # So the most general thing to catch here was RuntimeError. If you diff --git a/torch/testing/_internal/distributed/checkpoint_utils.py b/torch/testing/_internal/distributed/checkpoint_utils.py index 7d4d4a216270ce..07b05140e36e6b 100644 --- a/torch/testing/_internal/distributed/checkpoint_utils.py +++ b/torch/testing/_internal/distributed/checkpoint_utils.py @@ -3,6 +3,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import io +import logging import os import shutil import tempfile @@ -157,3 +158,36 @@ def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None: shutil.rmtree(self.temp_dir, ignore_errors=True) return wrapper + + +def with_checkpoint_logging( + func: Optional[Callable] = None, + logger_name: str = "torch.distributed.checkpoint", + level: int = logging.INFO, +) -> Optional[Callable]: + """ + Wrapper to configure checkpoint logging for distributed tests. + + Args: + func: The test function to wrap + logger_name: Name of the logger to configure (default: 'torch.distributed.checkpoint') + level: Logging level to set (default: logging.INFO) + """ + assert func is not None + + @wraps(func) + def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None: + # Get the logger and store original level + target_logger = logging.getLogger(logger_name) + original_level = target_logger.level + + # Set the desired logging level + target_logger.setLevel(level) + + try: + func(self, *args, **kwargs) + finally: + # Restore original logging level + target_logger.setLevel(original_level) + + return wrapper diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py index 8111bb6773208c..f7d79907bdbed5 100644 --- a/torch/testing/_internal/distributed/common_state_dict.py +++ b/torch/testing/_internal/distributed/common_state_dict.py @@ -141,7 +141,7 @@ def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> N def _state_dict_hook(self, destination, prefix, keep_vars): """Remove "embedding" from the original embedding in the state_dict - name. This keeps the orginal state dict name for the embedding + name. This keeps the original state dict name for the embedding from before fusing with the FusionEmbedding. """ key = prefix + "embedding.weight" diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index eef2fc899d1b70..28b761a37d58cc 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -85,6 +85,7 @@ IS_WINDOWS, skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, + skipIfRocm, ) from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.data.distributed import DistributedSampler @@ -4853,6 +4854,7 @@ def _test_ddp_apply_optim_in_backward( # case. optim.zero_grad(set_to_none=True) + @skipIfRocm @skip_if_lt_x_gpu(2) def test_ddp_apply_optim_in_backward(self): for optim_cls, init_before in itertools.product( @@ -4865,6 +4867,7 @@ def test_ddp_apply_optim_in_backward(self): init_before=init_before, ) + @skipIfRocm @skip_if_lt_x_gpu(2) def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self): for init_before in [True, False]: @@ -4875,6 +4878,7 @@ def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self): gradient_as_bucket_view=False, ) + @skipIfRocm @skip_if_lt_x_gpu(2) def test_ddp_apply_optim_in_backward_ignored_params(self): torch.cuda.set_device(self.rank) @@ -6869,7 +6873,7 @@ def test_ddp_grad_div_uneven_inputs(self): def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): """Runs DDP based model training and captures profiles. This test will do two profiler runs. - 1. An inital basic run to check if profiler events are correctly captured. + 1. An initial basic run to check if profiler events are correctly captured. 2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state. args @@ -6992,7 +6996,7 @@ def test_ddp_profiling_torch_profiler(self): def _validate_execution_trace_nccl(self, et_file: str) -> None: """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" - We test for basic fields in theese nodes in the Execution Trace. + We test for basic fields in these nodes in the Execution Trace. """ with open(et_file) as f: et = json.load(f) diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index c76e5d29615322..a34ee75cf600e4 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -19,10 +19,10 @@ def _create_fake_pg(prefix_store, rank, world_size, timeout): without needing multiple processes (simulates per-rank behavior) NOTE: This is not a real process group, and it would produce wrong results - for every collective. It should be used as a convinient tool when playing + for every collective. It should be used as a convenient tool when playing with distributed but don't care about the actual data. """ return FakeProcessGroup(rank, world_size) -dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda"]) +dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"]) diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index 367ef218127c5a..f7cb2075e37343 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -2092,7 +2092,7 @@ def test_debug_info(self): debug_info = dist_autograd._get_debug_info() num_autograd_context = int(debug_info["num_autograd_contexts"]) - # Need atleast one context and not more than 4. + # Need at least one context and not more than 4. self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4) for rd in range(self.world_size - 1): diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py index 7d56c6ebd05bc3..08246cc65132b7 100644 --- a/torch/testing/_internal/dynamo_test_failures.py +++ b/torch/testing/_internal/dynamo_test_failures.py @@ -8,6 +8,7 @@ # # We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures` # We generate skipIfTorchDynamo* for all tests in `dynamo_skips` +# We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips` # # For an easier-than-manual way of generating and updating these lists, # see scripts/compile_tests/update_failures.py @@ -82,6 +83,8 @@ def find_test_dir(): inductor_expected_failures = set() inductor_skips = set() + + compiled_autograd_skips = set() else: dynamo_failures_directory = os.path.join(test_dir, "dynamo_expected_failures") dynamo_skips_directory = os.path.join(test_dir, "dynamo_skips") @@ -95,6 +98,11 @@ def find_test_dir(): inductor_expected_failures = set(os.listdir(inductor_failures_directory)) inductor_skips = set(os.listdir(inductor_skips_directory)) + compiled_autograd_skips_directory = os.path.join( + test_dir, "compiled_autograd_skips" + ) + compiled_autograd_skips = set(os.listdir(compiled_autograd_skips_directory)) + # TODO: due to case sensitivity problems, for now list these files by hand extra_dynamo_skips = { "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_T_cpu_float32", diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 60fef7f86120cb..91a4aaa5728a82 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -14,8 +14,18 @@ from torch._inductor.graph import GraphLowering from torch._inductor.compile_fx import shape_env_from_inputs from torch._inductor.codecache import CppCodeCache +from torch._inductor.custom_graph_pass import CustomGraphModulePass +from torch._inductor.codegen.common import ( + get_custom_backend_pass_for_device, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + init_backend_registration, + register_backend_for_device +) +from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu +from torch.utils._helion import has_helion from torch.utils._triton import has_triton from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, @@ -48,6 +58,8 @@ def test_cpu(): HAS_TRITON = has_triton() +HAS_HELION = has_helion() + if HAS_TRITON: import triton TRITON_HAS_CPU = "cpu" in triton.backends.backends @@ -133,6 +145,7 @@ def skip_windows_ci(name: str, file: str) -> None: # TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu") requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") +requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion") def requires_cuda_with_enough_memory(min_mem_required): def inner(fn): @@ -176,7 +189,7 @@ def dummy_graph() -> GraphLowering: def maybe_skip_size_asserts(op): """ - For certain ops, there meta and eager implementation returns differents + For certain ops, there meta and eager implementation returns different strides. This cause size/strides assert fail. Skip adding those asserts for now. """ @@ -290,3 +303,41 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) inverse_scale = scale.reciprocal() return x_fp8, inverse_scale + +@contextlib.contextmanager +def patch_inductor_backend( + device: str, + python_wrapper_codegen: PythonWrapperCodegen = None, + custom_pass: CustomGraphModulePass = None +): + """ + Patch the inductor backend for a specific device. + """ + # Make sure the backend is already registered + init_backend_registration() + + # Get the original registration parameters + original_scheduling = get_scheduling_for_device(device) + original_python_wrapper = get_wrapper_codegen_for_device(device, False) + original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) + original_custom_pass = get_custom_backend_pass_for_device(device) + + try: + # Register modified backend for the device + register_backend_for_device( + device, + original_scheduling, + python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, + original_cpp_wrapper, + custom_pass if custom_pass is not None else original_custom_pass + ) + yield + finally: + # Restore the original backend + register_backend_for_device( + device, + original_scheduling, + original_python_wrapper, + original_cpp_wrapper, + original_custom_pass + ) diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 299eb999676c9d..4bc0738ec2f384 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -693,7 +693,7 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper -# make it easy to quicky define/trace a function for these tests +# make it easy to quickly define/trace a function for these tests def _trace(*args, **kwargs): def wrapper(func): return torch.jit.trace(func, args, **kwargs) diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 04c9b69218e10e..5cd248792dcb11 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -482,7 +482,7 @@ def __call__(self, *args, **kwargs): # set with small tensors. An elaborated set of sample inputs # can be specified using the "reference_inputs_func" attribute. # The "reference inputs" for an operation are an extended -# set of sample inputs that can more exhausively test an +# set of sample inputs that can more exhaustively test an # operator. They are used by only a few tests that are careful # not to take too long to run. Adding reference inputs # is highly encouraged! @@ -851,7 +851,7 @@ def __setattr__(self, name: str, value: Any) -> None: # tolerance for nondeterminism while performing gradcheck gradcheck_nondet_tol: float = 0.0 - # Whether to use the fast implmentation for gradcheck/gradgradcheck. + # Whether to use the fast implementation for gradcheck/gradgradcheck. # When set to None, defers to the default value provided by the wrapper # function around gradcheck (testing._internal.common_utils.gradcheck) gradcheck_fast_mode: bool = None @@ -1469,7 +1469,7 @@ def _sample_inputs_unspecified(self, *args, **kwargs): sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func. To avoid this, either define the corresponding sample function, - or re-map unsupported samples to error inputs in an appropiate + or re-map unsupported samples to error inputs in an appropriate opinfo/definitions/sparse.py:_validate_sample_input_sparse_ @@ -3125,6 +3125,12 @@ def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): return op(input + input.mH, *args, **kwargs) +def gradcheck_wrapper_ctc_loss(op, input, *args, **kwargs): + """Gradcheck wrapper for ctc loss to project onto log-simplex space.""" + # See https://github.com/pytorch/pytorch/issues/52241 + return op(input.log_softmax(dim=2), *args, **kwargs) + + def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs): """Gradcheck wrapper for functions that take lower or upper triangular matrices as input. diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index a69c7ed19d698b..9eeacf887084be 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -725,7 +725,7 @@ def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kw if device.type == "cpu" or has_cusolver(): deltas = (-1, 0, +1) # only square systems if Cusolver is not available - # becase we solve a lstsq problem with a transposed matrix in the backward + # because we solve a lstsq problem with a transposed matrix in the backward else: deltas = (0,) diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index a5b09f4c8dce7e..1418685e88323f 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -394,11 +394,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -410,11 +407,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -424,13 +418,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -440,13 +431,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -471,11 +459,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: inf + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -487,11 +472,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -501,18 +483,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -606,18 +580,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -627,18 +593,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -648,18 +606,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, @@ -669,18 +619,10 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( - DecorateInfo( - unittest.skip( - "Skipping - testing takes an unreasonably long time, #79528" - ) - ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), - DecorateInfo( - unittest.skip("testing takes an unreasonably long time, #79528"), - "TestCommon", - "test_compare_cpu", - ), + # Greatest absolute difference: nan + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), supports_one_python_scalar=True, supports_autograd=False, diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py index 2c97278e5646c8..4000ec6ca13551 100644 --- a/torch/testing/_internal/opinfo/utils.py +++ b/torch/testing/_internal/opinfo/utils.py @@ -156,7 +156,7 @@ def np_unary_ufunc_integer_promotion_wrapper(fn): # Wrapper that passes PyTorch's default scalar # type as an argument to the wrapped NumPy # unary ufunc when given an integer input. - # This mimicks PyTorch's integer->floating point + # This mimics PyTorch's integer->floating point # type promotion. # # This is necessary when NumPy promotes diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index 5484a6c16bea11..51fcadd8dee970 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -490,7 +490,7 @@ def __init__( # Location of the failures dict. Makes it so that the error message is better. self.failures_dict_path = failures_dict_path - # OpCheckMode surpresses errors, collects them here, and then raises them on exit. + # OpCheckMode suppresses errors, collects them here, and then raises them on exit. # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)] self.seen_ops_to_errors = {} @@ -605,7 +605,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): option = self.failures_dict.get_status(qualname, self.test_name) if option == "xsuccess" or option == "xfail": - # Surpress all errors during execution. Raise them during __exit__. + # Suppress all errors during execution. Raise them during __exit__. try: if qualname not in self.seen_ops_to_errors: self.seen_ops_to_errors[qualname] = [] diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 608a6f14389baf..79aca02b63d40e 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -257,7 +257,7 @@ def add_kernel_with_scaling( tl.store(out_ptr + offsets, output, mask=mask) @triton.jit - def add_kernel_with_tma_1d( + def add_kernel_with_tma_1d_old_api( in_desc_ptr0, in_desc_ptr1, out_desc_ptr, @@ -288,7 +288,7 @@ def add_kernel_with_tma_1d( ) @triton.jit - def add_kernel_with_tma_2d( + def add_kernel_with_tma_2d_old_api( in_desc_ptr0, in_desc_ptr1, out_desc_ptr, @@ -321,6 +321,186 @@ def add_kernel_with_tma_2d( [offset_x, offset_y], ) + @triton.jit + def add_kernel_with_tma_1d_new_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + offset = pid * BLOCK_SIZE + + a = tl.load_tensor_descriptor( + in_desc_ptr0, + [offset], + ) + b = tl.load_tensor_descriptor( + in_desc_ptr1, + [offset], + ) + + output = a + b + + tl.store_tensor_descriptor( + out_desc_ptr, + [offset], + output, + ) + + @triton.jit + def add_kernel_with_tma_2d_new_api( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE_X: "tl.constexpr", + BLOCK_SIZE_Y: "tl.constexpr", + ): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE_X + offset_y = pid_y * BLOCK_SIZE_Y + + x = tl.load_tensor_descriptor( + in_desc_ptr0, + [offset_x, offset_y], + ) + y = tl.load_tensor_descriptor( + in_desc_ptr1, + [offset_x, offset_y], + ) + + output = x + y + + tl.store_tensor_descriptor( + out_desc_ptr, + [offset_x, offset_y], + output, + ) + + @triton.jit + def add_kernel_on_device_tma_old_api( + a_ptr, + b_ptr, + c_ptr, + m, + n, + workspace, + BLOCK_SIZE: "tl.constexpr", + ): + a_desc_ptr = workspace + b_desc_ptr = workspace + 128 + c_desc_ptr = workspace + 256 + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=a_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=a_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=b_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=b_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr, + load_size=[BLOCK_SIZE, BLOCK_SIZE], + global_size=[m, n], + element_ty=c_ptr.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE + offset_y = pid_y * BLOCK_SIZE + + # Load data using the tensor descriptors + a = tl._experimental_descriptor_load( + a_desc_ptr, + [offset_x, offset_y], + [BLOCK_SIZE, BLOCK_SIZE], + tl.float32, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [offset_x, offset_y], + [BLOCK_SIZE, BLOCK_SIZE], + tl.float32, + ) + + # Perform addition + output = a + b + + # Store the result + tl._experimental_descriptor_store( + c_desc_ptr, + output, + [offset_x, offset_y], + ) + + @triton.jit + def add_kernel_on_device_tma_new_api( + a_ptr, + b_ptr, + c_ptr, + m, + n, + workspace, # unused but left here to match the old API kernel + BLOCK_SIZE: "tl.constexpr", + ): + # Create tensor descriptors using the new API + a_desc = tl.make_tensor_descriptor( + base=a_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + c_desc = tl.make_tensor_descriptor( + base=c_ptr, + shape=[m, n], + strides=[n, 1], + block_shape=[BLOCK_SIZE, BLOCK_SIZE], + ) + + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE + offset_y = pid_y * BLOCK_SIZE + + # Load data using the tensor descriptors with the new API + a = tl.load_tensor_descriptor( + a_desc, + [offset_x, offset_y], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offset_x, offset_y], + ) + + # Perform addition + output = a + b + + # Store the result with the new API + tl.store_tensor_descriptor( + c_desc, + [offset_x, offset_y], + output, + ) + @triton.jit def mul2_kernel( in_ptr0, @@ -679,3 +859,98 @@ def strange_config_matmul_kernel( c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) + + @triton.jit + def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr): + """ + This kernel contains a triple-quote docstring w/ double quotes. + Make sure that codegen sanitizes the docstring. + """ + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32) + tl.store(out_ptr + offsets, ones, mask=offsets < numel) + + @triton.jit + def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr): + ''' + This kernel contains a triple-quote docstring w/ single quotes + Make sure that codegen sanitizes the docstring. + To prevent it from being linted to double quotes: """!!!""" + ''' + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32) + tl.store(out_ptr + offsets, ones, mask=offsets < numel) + + @triton.jit + def kernel_inline_asm_double_quotes( + in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + data = tl.load(in_ptr + offsets, mask=offsets < numel) + cos_pow = tl.inline_asm_elementwise( + asm=""" + { + cos.approx.f32 $0, $1; + ex2.approx.f32 $0, $0; + } + """, + constraints=("=r, r"), + args=[data], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel) + + @triton.jit + def kernel_inline_asm_single_quotes( + in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + data = tl.load(in_ptr + offsets, mask=offsets < numel) + cos_pow = tl.inline_asm_elementwise( + asm=''' + { + // double quotes to pacify the linter """!!!""" + cos.approx.f32 $0, $1; + ex2.approx.f32 $0, $0; + } + ''', + constraints=("=r, r"), + args=[data], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel) + + # support the old (experimental) and new (tensor_descriptor) APIs + def create_tensor_descriptor_shim( + tensor, block_sizes: list[int], new_api: bool = True + ): + if new_api: + return triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + tensor, block_sizes + ) + else: + if len(block_sizes) == 1: + return triton.tools.experimental_descriptor.create_1d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + block_sizes[0], + tensor.element_size(), + ) + else: + assert len(block_sizes) == 2 + return triton.tools.experimental_descriptor.create_2d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + tensor.size(1), + block_sizes[0], + block_sizes[1], + tensor.element_size(), + ) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 01a421f53084bb..4ec4e5b591596d 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -52,17 +52,17 @@ class _Config(Generic[T]): alias: If set, the directly use the value of the alias. env_name_force: If set, this environment variable has precedence over everything after this. - If multiple env variables are given, the precendence order is from + If multiple env variables are given, the precedence order is from left to right. user_override: If a user sets a value (i.e. foo.bar=True), that has precedence over everything after this. env_name_default: If set, this environment variable will override everything after this. - If multiple env variables are given, the precendence order is from + If multiple env variables are given, the precedence order is from left to right. justknob: If this pytorch installation supports justknobs, that will - override defaults, but will not override the user_override precendence. - default: This value is the lowest precendance, and will be used if nothing is + override defaults, but will not override the user_override precedence. + default: This value is the lowest precedence, and will be used if nothing is set. Environment Variables: @@ -439,7 +439,7 @@ def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None: def _is_default(self, name: str) -> bool: """ Returns true if the config is at its default value. - configs overriden by the env are not considered default. + configs overridden by the env are not considered default. """ config_val = self._config[name] # The config is not overridden by the user, and the env_value_default diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 4f1b991438c078..26217de5bb32e7 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -4,15 +4,16 @@ import functools import inspect -import warnings import sys -from typing import Any, Callable, TypeVar, cast +import warnings +from typing import Any, Callable, cast, TypeVar + # Used for annotating the decorator usage of _DecoratorContextManager (e.g., # 'no_grad' and 'enable_grad'). # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators FuncType = Callable[..., Any] -F = TypeVar('F', bound=FuncType) +F = TypeVar("F", bound=FuncType) def _wrap_generator(ctx_factory, func): @@ -22,6 +23,7 @@ def _wrap_generator(ctx_factory, func): The input should be a function that returns a context manager, not a context manager itself, to handle one-shot context managers. """ + @functools.wraps(func) def generator_context(*args, **kwargs): gen = func(*args, **kwargs) @@ -83,7 +85,7 @@ def context_decorator(ctx, func): be a multi-shot context manager that can be directly invoked multiple times) or a callable that produces a context manager. """ - assert not (callable(ctx) and hasattr(ctx, '__enter__')), ( + assert not (callable(ctx) and hasattr(ctx, "__enter__")), ( f"Passed in {ctx} is both callable and also a valid context manager " "(has __enter__), making it ambiguous which interface to use. If you " "intended to pass a context manager factory, rewrite your call as " @@ -92,8 +94,10 @@ def context_decorator(ctx, func): ) if not callable(ctx): + def ctx_factory(): return ctx + else: ctx_factory = ctx diff --git a/torch/utils/_cpp_embed_headers.py b/torch/utils/_cpp_embed_headers.py index 9cb0fee3a3f8cc..6bcf8d583f0cd2 100644 --- a/torch/utils/_cpp_embed_headers.py +++ b/torch/utils/_cpp_embed_headers.py @@ -39,7 +39,8 @@ def embed_headers( fname: str, include_dirs: Optional[Union[Sequence[str], Sequence[Path], str]] = None ) -> str: if include_dirs is None: - include_dirs = [Path(__file__).parent.parent.parent] + base_dir = Path(__file__).parent.parent.parent + include_dirs = [base_dir, base_dir / "aten" / "src"] elif isinstance(include_dirs, str): include_dirs = [Path(include_dirs)] else: diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index f414ec00ddc236..2997f90d7c89d5 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -2,18 +2,18 @@ import collections -Entry = collections.namedtuple('Entry', 'version, hash') +Entry = collections.namedtuple("Entry", "version, hash") def update_hash(seed, value): # Good old boost::hash_combine # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html - return seed ^ (hash(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2)) + return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2)) def hash_source_files(hash_value, source_files): for filename in source_files: - with open(filename, 'rb') as file: + with open(filename, "rb") as file: hash_value = update_hash(hash_value, file.read()) return hash_value @@ -34,15 +34,17 @@ def get_version(self, name): entry = self.entries.get(name) return None if entry is None else entry.version - def bump_version_if_changed(self, - name, - source_files, - build_arguments, - build_directory, - with_cuda, - with_sycl, - is_python_module, - is_standalone): + def bump_version_if_changed( + self, + name, + source_files, + build_arguments, + build_directory, + with_cuda, + with_sycl, + is_python_module, + is_standalone, + ): hash_value = 0 hash_value = hash_source_files(hash_value, source_files) hash_value = hash_build_arguments(hash_value, build_arguments) diff --git a/torch/utils/_device.py b/torch/utils/_device.py index e16505791b9d51..de3ee4a9e34474 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -1,13 +1,16 @@ # mypy: allow-untyped-defs +import functools from typing import Optional + import torch -from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode -from torch.utils._contextlib import context_decorator from torch._C import _len_torch_function_stack -import functools +from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode +from torch.utils._contextlib import context_decorator + CURRENT_DEVICE: Optional[torch.device] = None + @functools.lru_cache(1) def _device_constructors(): return { @@ -50,9 +53,10 @@ def _device_constructors(): # weird ones torch.tensor, torch.as_tensor, - torch.scalar_tensor + torch.scalar_tensor, } + # NB: This is directly called from C++ in torch/csrc/Device.cpp class DeviceContext(TorchFunctionMode): def __init__(self, device): @@ -73,13 +77,12 @@ def __enter__(self): for mode in reversed(cur_stack): _push_mode(mode) - def __exit__(self, exc_type, exc_val, exc_tb): global CURRENT_DEVICE CURRENT_DEVICE = self.old_device cur_stack = [] # Invariant: there should only be one DeviceContext on the stack at any time - # (At the bottom), pop all mdoes until we hit the bottom, assert it's a DeviceContext + # (At the bottom), pop all modes until we hit the bottom, assert it's a DeviceContext # or else someone else has popped it! for _ in range(_len_torch_function_stack() - 1): mode = _pop_mode() @@ -95,14 +98,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if func in _device_constructors() and kwargs.get('device') is None: - kwargs['device'] = self.device + if func in _device_constructors() and kwargs.get("device") is None: + kwargs["device"] = self.device return func(*args, **kwargs) + # NB: This is directly called from C++ in torch/csrc/Device.cpp def device_decorator(device, func): return context_decorator(lambda: device, func) + def set_device(device): """ Set the default device inside of the wrapped function by decorating it with this function. diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 863921bbf87f06..e3a2070f2d4d6d 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -1,17 +1,28 @@ from typing import Optional +from typing_extensions import TypeAlias import torch from torch import Tensor from torch.autograd.grad_mode import no_grad -from typing_extensions import TypeAlias + def _get_foreach_kernels_supported_devices() -> list[str]: r"""Return the device type list that supports foreach kernels.""" - return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] + return ["cuda", "xpu", "mtia", torch._C._get_privateuse1_backend_name()] + def _get_fused_kernels_supported_devices() -> list[str]: r"""Return the device type list that supports fused kernels in optimizer.""" - return ["mps", "cuda", "xpu", "hpu", "cpu", torch._C._get_privateuse1_backend_name()] + return [ + "mps", + "cuda", + "xpu", + "hpu", + "cpu", + "mtia", + torch._C._get_privateuse1_backend_name(), + ] + TensorListList: TypeAlias = list[list[Optional[Tensor]]] Indices: TypeAlias = list[int] @@ -36,9 +47,15 @@ def _group_tensors_by_device_and_dtype( ) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]: return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) + def _device_has_foreach_support(device: torch.device) -> bool: - return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() + return ( + device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) + and not torch.jit.is_scripting() + ) def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool: - return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors) + return _device_has_foreach_support(device) and all( + t is None or type(t) in _foreach_supported_types for t in tensors + ) diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py index 30a797350d2d28..8696065adb9f92 100644 --- a/torch/utils/_freeze.py +++ b/torch/utils/_freeze.py @@ -3,6 +3,9 @@ """ Freeze Python packages. + + + Freezing makes it possible to ship arbitrary Python modules as part of a C++ library. The Python source of the module is compiled to bytecode and written to `.c` files, to be imported by Python's built-in FrozenImporter. diff --git a/torch/utils/_get_clean_triton.py b/torch/utils/_get_clean_triton.py index f5e1495e7dc574..98ee54a1c23db6 100644 --- a/torch/utils/_get_clean_triton.py +++ b/torch/utils/_get_clean_triton.py @@ -2,6 +2,7 @@ import argparse import os import re +import subprocess from pathlib import Path @@ -80,7 +81,9 @@ def replace(match) -> str: return remove_inductor_wrappers -def process_file(input_filename: str, output_filename: str) -> str: +def process_file( + input_filename: str, output_filename: str, auto_generate_params: bool = True +) -> str: with open(input_filename) as file: source_code = file.read() @@ -94,9 +97,41 @@ def process_file(input_filename: str, output_filename: str) -> str: transformed_code = remove_async_compile(transformed_code) launch_params_filename = f"{input_filename}.launch_params" + + # Auto-generate launch_params if they don't exist and auto_generate_params is True + if not os.path.exists(launch_params_filename) and auto_generate_params: + print(f"Launch params file {launch_params_filename} not found. Generating...") + try: + # Set environment variable and run the input file + env = os.environ.copy() + env["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1" + + result = subprocess.run( + ["python", input_filename], + env=env, + capture_output=True, + text=True, + cwd=os.path.dirname(input_filename) or ".", + ) + + if result.returncode != 0: + print(f"Error running {input_filename}:") + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + raise RuntimeError( + f"Failed to generate launch params. Command failed with return code {result.returncode}" + ) + + print(f"Successfully generated {launch_params_filename}") + + except Exception as e: + raise RuntimeError( + f"Failed to generate launch params by running {input_filename}: {str(e)}" + ) from e + if not os.path.exists(launch_params_filename): raise RuntimeError( - f"Missing {launch_params_filename}. Run `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1 python {input_filename} first." + f"Missing {launch_params_filename}. Run `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1 python {input_filename}` first." ) with open(launch_params_filename) as f: @@ -108,25 +143,32 @@ def process_file(input_filename: str, output_filename: str) -> str: with open(output_filename, "w") as file: file.write(transformed_code) + print(f"Successfully generated {output_filename}") return transformed_code def get_clean_triton( - input_path: Path, output_path: Path = Path("triton_only_repro.py") + input_path: Path, + output_path: Path = Path("triton_only_repro.py"), + auto_generate_params: bool = True, ): """Run experiments and output results to file Args: input_path (Optional[Path]): Path to inductor generated output codede output_path (Optional[Path]): Path to write out the new python file + auto_generate_params (bool): Whether to automatically generate launch_params if missing """ - return process_file(str(input_path), str(output_path)) + return process_file(str(input_path), str(output_path), auto_generate_params) if __name__ == "__main__": """Sample usage: # Running sweep - python inputcode.py + python _get_clean_triton.py output_code.py + + # To disable auto-generation of launch params: + python _get_clean_triton.py output_code.py --no-auto-generate """ parser = argparse.ArgumentParser( description="Clean Inductor generated code to remove Inductor dependencies" @@ -142,9 +184,16 @@ def get_clean_triton( default=Path("triton_only_repro.py"), help="Path to write out the clean triton output", ) + parser.add_argument( + "--no-auto-generate", + action="store_true", + help="Disable automatic generation of launch_params file", + ) # Parse the arguments args = parser.parse_args() # Call the function with parsed arguments - result = get_clean_triton(args.input_path, args.output_path) + result = get_clean_triton( + args.input_path, args.output_path, not args.no_auto_generate + ) diff --git a/torch/utils/_helion.py b/torch/utils/_helion.py new file mode 100644 index 00000000000000..6d30832cf3f741 --- /dev/null +++ b/torch/utils/_helion.py @@ -0,0 +1,17 @@ +import functools + +from torch.utils._triton import has_triton + + +@functools.cache +def has_helion_package() -> bool: + try: + import helion # type: ignore[import-untyped, import-not-found] # noqa: F401 + except ImportError: + return False + return True + + +@functools.cache +def has_helion() -> bool: + return has_helion_package() and has_triton() diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py index 91c0e07b3d9345..b79b52b13449e8 100644 --- a/torch/utils/_mode_utils.py +++ b/torch/utils/_mode_utils.py @@ -1,11 +1,15 @@ # mypy: allow-untyped-defs -import torch from typing import TypeVar -T = TypeVar('T') +import torch + + +T = TypeVar("T") + # returns if all are the same mode def all_same_mode(modes): return all(tuple(mode == modes[0] for mode in modes)) + no_dispatch = torch._C._DisableTorchDispatch diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index 29373289c42606..2bead0e00b12a6 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -33,7 +33,7 @@ def _from_dict(dict_inp: dict[T, None]) -> OrderedSet[T]: return s # - # Required overriden abstract methods + # Required overridden abstract methods # def __contains__(self, elem: object) -> bool: return elem in self._dict diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 4a06fd13ae46b2..3fab41d82bc467 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,12 +1,11 @@ # mypy: allow-untyped-defs import contextlib - import warnings -from dataclasses import dataclass -from typing import Any, Optional, Union, Protocol, overload +from collections import deque from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional, overload, Protocol, Union from typing_extensions import TypeIs -from collections import deque import torch import torchgen @@ -29,8 +28,13 @@ _is_in_torch_dispatch_mode = False _is_in_non_infra_torch_dispatch_mode = False + def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool: - return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode + return ( + _is_in_torch_dispatch_mode + if include_infra_modes + else _is_in_non_infra_torch_dispatch_mode + ) class TorchDispatchMode: @@ -79,7 +83,6 @@ def _lazy_init_old_dispatch_mode_flags(self): if not hasattr(self, "old_non_infra_dispatch_mode_flags"): self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef] - def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise NotImplementedError @@ -93,8 +96,12 @@ def __enter__(self): self._lazy_init_old_dispatch_mode_flags() self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode) _is_in_torch_dispatch_mode = True - self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode) - _is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode() + self.old_non_infra_dispatch_mode_flags.append( + _is_in_non_infra_torch_dispatch_mode + ) + _is_in_non_infra_torch_dispatch_mode = ( + _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode() + ) _push_mode(self) return self @@ -107,7 +114,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): global _is_in_torch_dispatch_mode _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop() global _is_in_non_infra_torch_dispatch_mode - _is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop() + _is_in_non_infra_torch_dispatch_mode = ( + self.old_non_infra_dispatch_mode_flags.pop() + ) _pop_mode(mb_dk_or_mode_key) @classmethod @@ -123,7 +132,6 @@ def is_infra_mode(cls): return False - def _get_current_dispatch_mode(): stack_len = _len_torch_dispatch_stack() # Return a user mode on the stack if there are any @@ -133,19 +141,16 @@ def _get_current_dispatch_mode(): def _detect_infra_mode(key): - assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY] + assert key in [ + torch._C._TorchDispatchModeKey.FUNCTIONAL, + torch._C._TorchDispatchModeKey.PROXY, + ] from torch._ops import _get_dispatch_mode_pre_dispatch - pre_dispatch_mode = _get_dispatch_mode_pre_dispatch( - key - ) - post_dispatch_mode = torch._C._get_dispatch_mode( - key - ) + pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key) + post_dispatch_mode = torch._C._get_dispatch_mode(key) - assert (pre_dispatch_mode is None) or ( - post_dispatch_mode is None - ) + assert (pre_dispatch_mode is None) or (post_dispatch_mode is None) if pre_dispatch_mode is None: return post_dispatch_mode @@ -232,8 +237,8 @@ def _disable_current_modes(): _pop_mode_from_pre_dispatch, ) from torch._subclasses.functional_tensor import FunctionalTensorMode - from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode from torch._subclasses.schema_check_mode import SchemaCheckMode + from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch() old_pre_dispatch_modes = [ @@ -267,10 +272,7 @@ def _disable_current_modes(): raise AssertionError( "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key" ) - if ( - isinstance(old, SchemaCheckMode) - and has_schema_check_mode_in_pre_dispatch - ): + if isinstance(old, SchemaCheckMode) and has_schema_check_mode_in_pre_dispatch: raise AssertionError( "Can't have SchemaCheckMode available both in PreDispatch and Python Key" ) @@ -298,7 +300,9 @@ def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... @staticmethod - def __tensor_unflatten__(inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int) -> torch.Tensor: + def __tensor_unflatten__( + inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int + ) -> torch.Tensor: ... # It would be really nice to be able to say that the return of @@ -331,41 +335,39 @@ def dim(self) -> int: @overload def to( - self, - dtype: torch.types._dtype, - non_blocking: bool = False, - copy: bool = False, - *, - memory_format: Optional[torch.memory_format] = None + self, + dtype: torch.types._dtype, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, ) -> torch.Tensor: ... @overload def to( - self, - device: Optional["torch._prims_common.DeviceLikeType"] = None, - dtype: Optional[torch.types._dtype] = None, - non_blocking: bool = False, - copy: bool = False, - *, - memory_format: Optional[torch.memory_format] = None + self, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype: Optional[torch.types._dtype] = None, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, ) -> torch.Tensor: ... @overload def to( - self, - other: torch.Tensor, - non_blocking: bool = False, - copy: bool = False, - *, - memory_format: Optional[torch.memory_format] = None + self, + other: torch.Tensor, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, ) -> torch.Tensor: ... - - def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: """ Returns whether or not a tensor subclass that implements __torch_dispatch__ @@ -403,10 +405,15 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: and hasattr(t, "__tensor_unflatten__") ) + def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]: """Same as above, but takes a type argument instead of an instance.""" - return (issubclass(t, torch.Tensor) and t != torch.Tensor - and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")) + return ( + issubclass(t, torch.Tensor) + and t != torch.Tensor + and hasattr(t, "__tensor_flatten__") + and hasattr(t, "__tensor_unflatten__") + ) def transform_subclass(t, callback, outer_size=None, outer_stride=None): @@ -551,7 +558,9 @@ def get_alias_info(func) -> SchemaInfo: torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str) torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str) # for aten::rot90 / aten:fft_* - torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str) + torchgen_schema_str = re.sub( + r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str + ) torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str) arg_schemas = [ AliasInfo( diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index ce08f3bef40458..3e7cadc6dc7a7d 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -113,10 +113,14 @@ def get(self, parent: Any) -> Any: class EnumEncoder(json.JSONEncoder): - def default(self, obj: object) -> str: + def default(self, obj: object) -> Union[str, dict[str, Any]]: if isinstance(obj, Enum): - return obj.value # type: ignore[no-any-return] - return super().default(obj) # type: ignore[no-any-return] + return { + "__enum__": True, + "fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}", + "name": obj.name, + } + return cast(str, super().default(obj)) Context = Any @@ -1836,6 +1840,18 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) +def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]: + if "__enum__" in obj: + modname, _, classname = obj["fqn"].partition(":") + mod = importlib.import_module(modname) + enum_cls = mod + for attr in classname.split("."): + enum_cls = getattr(enum_cls, attr) + enum_cls = cast(type[Enum], enum_cls) + return enum_cls[obj["name"]] + return obj + + def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if ( json_schema["type"] is None @@ -1854,7 +1870,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if serialize_node_def.from_dumpable_context is None: try: - context = json.loads(json_schema["context"]) + context = json.loads(json_schema["context"], object_hook=enum_object_hook) except TypeError as ex: raise TypeError( "Unable to deserialize context. " diff --git a/torch/utils/_stats.py b/torch/utils/_stats.py index 6d9d48233ee036..74b513932c3056 100644 --- a/torch/utils/_stats.py +++ b/torch/utils/_stats.py @@ -3,8 +3,8 @@ # AND SCRUB AWAY TORCH NOTIONS THERE. import collections import functools -from typing import Callable, TypeVar from collections import OrderedDict +from typing import Callable, TypeVar from typing_extensions import ParamSpec @@ -18,6 +18,7 @@ def count_label(label: str) -> None: prev = simple_call_counter.setdefault(label, 0) simple_call_counter[label] = prev + 1 + def count(fn: Callable[_P, _R]) -> Callable[_P, _R]: @functools.wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @@ -25,4 +26,5 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: simple_call_counter[fn.__qualname__] = 0 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1 return fn(*args, **kwargs) + return wrapper diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index d2c1ee83bc1c3f..39e981a78ac5b8 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -58,7 +58,7 @@ class StrobelightCLIFunctionProfiler: StrobelightCLIFunctionProfiler can be used to profile a python function and generate a strobelight link with the results. It works on meta servers but - does not requries an fbcode target. + does not requires an fbcode target. When stop_at_error is false(default), error during profiling does not prevent the work function from running. diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 85c23c0f36bb5c..4f8d045e5554d7 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1055,7 +1055,7 @@ def eval(cls, base, exp): # base is assumed to be nonnegative, thereby prevent complex numbers from -# occuring +# occurring class FloatPow(sympy.Function): is_real = True @@ -1300,6 +1300,12 @@ def _eval_expand_identity(self, **hints): # Removes the identity op. return self.args[0] + def __int__(self) -> int: + return int(self.args[0]) + + def __float__(self) -> float: + return float(self.args[0]) + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): @@ -1309,7 +1315,7 @@ class OpaqueUnaryFn(sympy.Function): constant propagation. This helps avoid performing transformations that are valid for real numbers but are invalid for floating point; in particular, while we are willing to make optimizations that change - numerics for Tensor compute, we are NOT willing to make optimziations + numerics for Tensor compute, we are NOT willing to make optimizations that change numerics for size compute. """ diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 396d1d46d28998..3b020b5fabbc72 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -51,7 +51,7 @@ # TODO: Dedupe this with SYMPY_INTERP -@functools.lru_cache(None) +@functools.cache def handlers(): # TODO add CeilDiv (it doesn't appear in the index_expr) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index e9c6b8b0e93393..acfcc596bd49cb 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -20,6 +20,9 @@ class ExprPrinter(StrPrinter): def _print_Mul(self, expr: sympy.Expr) -> str: return self.stringify(expr.args, "*", precedence(expr)) + def _print_Not(self, expr: sympy.Expr) -> str: + return f"not ({self._print(expr.args[0])})" + def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: return self.stringify(expr.args, " + ", precedence(expr)) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 784f9e7ba05149..b838727cc04792 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -920,7 +920,7 @@ def expr_cond_pair(a, b): return (a, b) # piecewise function can be used to convert a SymBool to SymInt: - # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise. + # int_expr = Piecewise((1, bool_expr), (0, True)), it evaluates to 1 when sym_bool is True and 0 otherwise. # # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair. # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True. diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 08aead47681822..b0152794b5c991 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs -from types import TracebackType -from typing import Optional -import tempfile -import traceback import contextlib import inspect import os.path +import tempfile +import traceback +from types import TracebackType +from typing import Optional + # This file contains utilities for ensuring dynamically compile()'d # code fragments display their line numbers in backtraces. @@ -44,6 +45,7 @@ # - Before running the compiled code, enter the # report_compile_source_on_error() context manager. + @contextlib.contextmanager def report_compile_source_on_error(): try: @@ -83,15 +85,17 @@ def report_compile_source_on_error(): # Don't delete the temporary file so the user can inspect it # TODO: This creates a temporary file for every frame, but we # technically only need one per distinct __compile_source__ - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f: + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".py" + ) as f: f.write(source) # Create a frame. Python doesn't let you construct # FrameType directly, so just make one with compile frame = tb.tb_frame - code = compile('__inspect_currentframe()', f.name, 'eval') + code = compile("__inspect_currentframe()", f.name, "eval") code = code.replace(co_name=frame.f_code.co_name) # Python 3.11 only - if hasattr(frame.f_code, 'co_linetable'): + if hasattr(frame.f_code, "co_linetable"): # We can't copy ALL of the metadata over, because you # can cause Python to segfault this way. What exactly # do we need? We need enough information for @@ -109,14 +113,9 @@ def report_compile_source_on_error(): fake_frame = eval( code, frame.f_globals, - { - **frame.f_locals, - '__inspect_currentframe': inspect.currentframe - } - ) - fake_tb = TracebackType( - None, fake_frame, tb.tb_lasti, tb.tb_lineno + {**frame.f_locals, "__inspect_currentframe": inspect.currentframe}, ) + fake_tb = TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno) stack.append(fake_tb) else: stack.append(tb) @@ -131,6 +130,7 @@ def report_compile_source_on_error(): raise exc.with_traceback(tb_next) # noqa: B904 + def shorten_filename(fn, *, base=None): """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.""" if base is None: @@ -141,7 +141,8 @@ def shorten_filename(fn, *, base=None): except ValueError: return fn else: - return fn[len(prefix) + 1:] + return fn[len(prefix) + 1 :] + def format_frame(frame, *, base=None, line=False): """ @@ -154,12 +155,14 @@ def format_frame(frame, *, base=None, line=False): extra_line = f"{frame.line} # " return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}" + def format_traceback_short(tb): """Format a TracebackType in a short way, printing only the inner-most frame.""" return format_frame(traceback.extract_tb(tb)[-1]) + class CapturedTraceback: - __slots__ = ['tb', 'skip'] + __slots__ = ["tb", "skip"] def __init__(self, tb, skip=0): self.tb = tb @@ -176,15 +179,17 @@ def summary(self): return traceback.StackSummary() return _extract_symbolized_tb( - torch._C._profiler.symbolize_tracebacks([self.tb])[0], - self.skip + torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip ) def __getstate__(self): - return (None, { - 'tb': None, # TB is not pickleable - 'skip': self.skip, - }) + return ( + None, + { + "tb": None, # TB is not pickleable + "skip": self.skip, + }, + ) @staticmethod def extract(*, script=False, cpp=False, skip=0): @@ -207,7 +212,7 @@ def extract(*, script=False, cpp=False, skip=0): torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp), # Elide extract() frame if we don't have script/cpp frames. If # we do have those frames, it doesn't work so force zero. - 0 if script or cpp else skip + 1 + 0 if script or cpp else skip + 1, ) def format(self): @@ -251,5 +256,5 @@ def _extract_symbolized_tb(tb, skip): """ stack = traceback.StackSummary() for f in reversed(tb[skip:]): - stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name'])) + stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"])) return stack diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index ff5091526b2f75..55beae4baf18a2 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -1,30 +1,33 @@ -# mypy: allow-untyped-defs import functools import hashlib +from typing import Any -@functools.lru_cache(None) +@functools.cache def has_triton_package() -> bool: try: - from triton.compiler.compiler import triton_key + import triton # noqa: F401 - return triton_key is not None + return True except ImportError: return False - except RuntimeError: - return False -@functools.lru_cache(None) -def has_triton_tma(): - if has_triton_package(): - import torch +@functools.cache +def _device_supports_tma() -> bool: + import torch - if ( - torch.cuda.is_available() - and torch.cuda.get_device_capability() >= (9, 0) - and not torch.version.hip - ): + return ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ) + + +@functools.cache +def has_triton_experimental_host_tma() -> bool: + if has_triton_package(): + if _device_supports_tma(): try: from triton.tools.experimental_descriptor import ( # noqa: F401 create_1d_tma_descriptor, @@ -38,8 +41,29 @@ def has_triton_tma(): return False -@functools.lru_cache(None) -def has_triton_tma_device(): +@functools.cache +def has_triton_tensor_descriptor_host_tma() -> bool: + if has_triton_package(): + if _device_supports_tma(): + try: + from triton.tools.tensor_descriptor import ( # noqa: F401 + TensorDescriptor, + ) + + return True + except ImportError: + pass + + return False + + +@functools.cache +def has_triton_tma() -> bool: + return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma() + + +@functools.cache +def has_triton_tma_device() -> bool: if has_triton_package(): import torch @@ -48,6 +72,7 @@ def has_triton_tma_device(): and torch.cuda.get_device_capability() >= (9, 0) and not torch.version.hip ): + # old API try: from triton.language.extra.cuda import ( # noqa: F401 experimental_device_tensormap_create1d, @@ -58,25 +83,52 @@ def has_triton_tma_device(): except ImportError: pass + # new API + try: + from triton.language import make_tensor_descriptor # noqa: F401 + + return True + except ImportError: + pass + return False @functools.lru_cache(None) +def has_triton_stable_tma_api() -> bool: + if has_triton_package(): + import torch + + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + try: + from triton.language import make_tensor_descriptor # noqa: F401 + + return True + except ImportError: + pass + return False + + +@functools.cache def has_triton() -> bool: if not has_triton_package(): return False from torch._dynamo.device_interface import get_interface_for_device - def cuda_extra_check(device_interface): + def cuda_extra_check(device_interface: Any) -> bool: return device_interface.Worker.get_device_properties().major >= 7 - def cpu_extra_check(device_interface): + def cpu_extra_check(device_interface: Any) -> bool: import triton.backends return "cpu" in triton.backends.backends - def _return_true(device_interface): + def _return_true(device_interface: Any) -> bool: return True triton_supported_devices = { @@ -85,7 +137,7 @@ def _return_true(device_interface): "cpu": cpu_extra_check, } - def is_device_compatible_with_triton(): + def is_device_compatible_with_triton() -> bool: for device, extra_check in triton_supported_devices.items(): device_interface = get_interface_for_device(device) if device_interface.is_available() and extra_check(device_interface): @@ -95,8 +147,8 @@ def is_device_compatible_with_triton(): return is_device_compatible_with_triton() -@functools.lru_cache(None) -def triton_backend(): +@functools.cache +def triton_backend() -> Any: from triton.compiler.compiler import make_backend from triton.runtime.driver import driver @@ -104,9 +156,9 @@ def triton_backend(): return make_backend(target) -@functools.lru_cache(None) -def triton_hash_with_backend(): - from triton.compiler.compiler import triton_key +@functools.cache +def triton_hash_with_backend() -> str: + from torch._inductor.runtime.triton_compat import triton_key backend = triton_backend() key = f"{triton_key()}-{backend.hash()}" diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index c7dd6445fabe98..b159b61de06aac 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -5,6 +5,7 @@ from pathlib import Path from zipfile import ZipFile + # Exclude some standard library modules to: # 1. Slim down the final zipped file size # 2. Remove functionality we don't want to support. diff --git a/torch/utils/backcompat/__init__.py b/torch/utils/backcompat/__init__.py index 6a53076c90a6ed..a8413b656e906e 100644 --- a/torch/utils/backcompat/__init__.py +++ b/torch/utils/backcompat/__init__.py @@ -1,8 +1,10 @@ # mypy: allow-untyped-defs -from torch._C import _set_backcompat_broadcast_warn -from torch._C import _get_backcompat_broadcast_warn -from torch._C import _set_backcompat_keepdim_warn -from torch._C import _get_backcompat_keepdim_warn +from torch._C import ( + _get_backcompat_broadcast_warn, + _get_backcompat_keepdim_warn, + _set_backcompat_broadcast_warn, + _set_backcompat_keepdim_warn, +) class Warning: @@ -18,5 +20,8 @@ def get_enabled(self): enabled = property(get_enabled, set_enabled) -broadcast_warning = Warning(_set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn) + +broadcast_warning = Warning( + _set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn +) keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn) diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index aa687de37ea04b..e11a7afc09d8aa 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -1,12 +1,11 @@ # mypy: allow-untyped-defs -import torch -from torch.overrides import ( - handle_torch_function, - has_torch_function_unary, -) -from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name from typing import Optional, Union +import torch +from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend +from torch.overrides import handle_torch_function, has_torch_function_unary + + __all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"] # TODO: Should use `torch._C._get_privateuse1_backend_name()` to get @@ -15,6 +14,7 @@ # `_privateuse1_backend_name`. _privateuse1_backend_name = "privateuseone" + def rename_privateuse1_backend(backend_name: str) -> None: r""" Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs. @@ -78,16 +78,22 @@ def rename_privateuse1_backend(backend_name: str) -> None: global _privateuse1_backend_name _privateuse1_backend_name = backend_name + def _check_register_once(module, attr): if hasattr(module, attr): - raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}") + raise RuntimeError( + f"The custom device module of {module} has already been registered with {attr}" + ) -def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int: +def _normalization_device( + custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None +) -> int: def _get_current_device_index(): _get_device_index = "current_device" - if hasattr(torch, custom_backend_name) and \ - hasattr(getattr(torch, custom_backend_name), _get_device_index): + if hasattr(torch, custom_backend_name) and hasattr( + getattr(torch, custom_backend_name), _get_device_index + ): return getattr(getattr(torch, custom_backend_name), _get_device_index)() else: # The default device index is 0. @@ -100,7 +106,7 @@ def _get_current_device_index(): elif isinstance(device, str): device = torch.device(device) - # variable devcie can only be torch.device type or int type + # variable device can only be torch.device type or int type if isinstance(device, torch.device): if device.type != custom_backend_name: raise RuntimeError(f"Invalid device, must be {custom_backend_name} device") @@ -122,12 +128,16 @@ def wrap_tensor_backend(self: torch.Tensor) -> bool: return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined] return self.device.type == custom_backend_name - _check_register_once(torch.Tensor, f'is_{custom_backend_name}') - wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}' # type: ignore[attr-defined] - setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend) + _check_register_once(torch.Tensor, f"is_{custom_backend_name}") + wrap_tensor_backend.fget.__name__ = f"is_{custom_backend_name}" # type: ignore[attr-defined] + setattr(torch.Tensor, f"is_{custom_backend_name}", wrap_tensor_backend) - def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False, - **kwargs) -> torch.Tensor: + def wrap_tensor_to( + self: torch.Tensor, + device: Optional[Union[int, torch.device]] = None, + non_blocking=False, + **kwargs, + ) -> torch.Tensor: r"""Perform Tensor device conversion. Call the to operator implementation. .. note:: @@ -143,9 +153,20 @@ def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device] **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument. """ if has_torch_function_unary(self): - return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=False, **kwargs) + return handle_torch_function( + wrap_tensor_to, + (self,), + self, + device=device, + non_blocking=False, + **kwargs, + ) device_idx = _normalization_device(custom_backend_name, device) - return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs) + return self.to( + device=torch.device(f"{custom_backend_name}:{device_idx}"), + non_blocking=non_blocking, + **kwargs, + ) _check_register_once(torch.Tensor, custom_backend_name) wrap_tensor_to.__name__ = custom_backend_name @@ -159,10 +180,13 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) - raise RuntimeError( f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module." f"Because torch.Tensor doesn't has the method {custom_backend_name}()." - f"For this error, you can try setting for_tensor=True.") + f"For this error, you can try setting for_tensor=True." + ) - def wrap_module_to(self: torch.nn.modules.module.T, - device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T: + def wrap_module_to( + self: torch.nn.modules.module.T, + device: Optional[Union[int, torch.device]] = None, + ) -> torch.nn.modules.module.T: r"""Move all model parameters and buffers to the custom device. This also makes associated parameters and buffers different objects. So @@ -180,27 +204,37 @@ def wrap_module_to(self: torch.nn.modules.module.T, _check_register_once(torch.nn.Module, custom_backend_name) setattr(torch.nn.Module, custom_backend_name, wrap_module_to) -def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None: + +def _generate_packed_sequence_methods_for_privateuse1_backend( + custom_backend_name: str, +) -> None: # Generate PackedSequence Module attributes and methods depends on Tensor methods, # so we need to check whether Tensor methods is already registered. - if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \ - not hasattr(torch.Tensor, custom_backend_name): + if not hasattr(torch.Tensor, f"is_{custom_backend_name}") or not hasattr( + torch.Tensor, custom_backend_name + ): raise RuntimeError( f"Can not automatically generate is_{custom_backend_name}() or " f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence." f"Because torch.Tensor doesn't has the method is_{custom_backend_name}()" f"or {custom_backend_name}()." - f"For this error, you can try setting for_tensor=True.") + f"For this error, you can try setting for_tensor=True." + ) @property # type: ignore[misc] def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool: return self.data.device.type == custom_backend_name - _check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}') - setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend) + _check_register_once(torch.nn.utils.rnn.PackedSequence, f"is_{custom_backend_name}") + setattr( + torch.nn.utils.rnn.PackedSequence, + f"is_{custom_backend_name}", + wrap_tensor_backend, + ) - def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence, - *args, **kwargs) -> torch.nn.utils.rnn.PackedSequence: + def wrap_module_to( + self: torch.nn.utils.rnn.PackedSequence, *args, **kwargs + ) -> torch.nn.utils.rnn.PackedSequence: r"""Move all model parameters and buffers to the custom device. This also makes associated parameters and buffers different objects. So @@ -213,17 +247,21 @@ def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence, Args: device (int, optional): if specified, all parameters will be copied to that device """ - ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs) + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) if ex.device.type == custom_backend_name: return self.to(*args, **kwargs) - kwargs.update({'device': custom_backend_name}) + kwargs.update({"device": custom_backend_name}) return self.to(*args, **kwargs) _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name) setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to) -def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str, - unsupported_dtype: Optional[list[torch.dtype]] = None) -> None: + +def _generate_storage_methods_for_privateuse1_backend( + custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None +) -> None: # Attribute is registered in the _StorageBase class # and UntypedStorage obtains through inheritance. @property # type: ignore[misc] @@ -231,8 +269,10 @@ def wrap_storage_backend(self: torch.storage._StorageBase) -> bool: r"""Return the internal :class:`torch.UntypedStorage`.""" return self.device.type == custom_backend_name - _check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}') - setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend) + _check_register_once(torch.storage._StorageBase, f"is_{custom_backend_name}") + setattr( + torch.storage._StorageBase, f"is_{custom_backend_name}", wrap_storage_backend + ) def wrap_storage_to(self, device=None, non_blocking=False): r"""Return a copy of this object in custom device memory. @@ -250,16 +290,18 @@ def wrap_storage_to(self, device=None, non_blocking=False): # but it depends on the extended function, so this part is temporarily omitted in the automatic generation. device_idx = _normalization_device(custom_backend_name, device) - if getattr(self, f'is_{custom_backend_name}'): + if getattr(self, f"is_{custom_backend_name}"): # storage has already on expected device. if self.get_device() == device_idx: return self # For sparse storage, custom need to extend the implementation by themselves. if self.is_sparse: - raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend") + raise RuntimeError( + f"Can not support a sparse storage move to {custom_backend_name} backend" + ) # create untyped_storage and copy data untyped_storage = torch.UntypedStorage( - self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}') + self.size(), device=torch.device(f"{custom_backend_name}:{device_idx}") ) untyped_storage.copy_(self, non_blocking) return untyped_storage @@ -275,27 +317,38 @@ def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool: torch.storage._warn_typed_storage_removal() return self._untyped_storage.device.type == custom_backend_name - _check_register_once(torch.TypedStorage, f'is_{custom_backend_name}') - setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend) + _check_register_once(torch.TypedStorage, f"is_{custom_backend_name}") + setattr( + torch.storage.TypedStorage, + f"is_{custom_backend_name}", + wrap_typed_storage_backend, + ) - def wrap_typed_storage_to(self: torch.storage.TypedStorage, - device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage: + def wrap_typed_storage_to( + self: torch.storage.TypedStorage, device=None, non_blocking=False, **kwargs + ) -> torch.storage.TypedStorage: torch.storage._warn_typed_storage_removal() if unsupported_dtype and self.dtype in unsupported_dtype: - raise RuntimeError(f"Cannot create {custom_backend_name} storage " - f"as {self.dtype} dtype is not supported by this backend") + raise RuntimeError( + f"Cannot create {custom_backend_name} storage " + f"as {self.dtype} dtype is not supported by this backend" + ) custom_backend_storage: torch.UntypedStorage = getattr( - self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs) + self._untyped_storage, custom_backend_name + )(device, non_blocking, **kwargs) return self._new_wrapped_storage(custom_backend_storage) _check_register_once(torch.TypedStorage, custom_backend_name) setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to) -def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True, - for_packed_sequence: bool = True, - for_storage: bool = False, - unsupported_dtype: Optional[list[torch.dtype]] = None) -> None: +def generate_methods_for_privateuse1_backend( + for_tensor: bool = True, + for_module: bool = True, + for_packed_sequence: bool = True, + for_storage: bool = False, + unsupported_dtype: Optional[list[torch.dtype]] = None, +) -> None: r""" Automatically generate attributes and methods for the custom backend after rename privateuse1 backend. @@ -337,11 +390,14 @@ def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module _generate_module_methods_for_privateuse1_backend(custom_backend_name) if for_storage: - _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype) + _generate_storage_methods_for_privateuse1_backend( + custom_backend_name, unsupported_dtype + ) if for_packed_sequence: _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name) + def _get_custom_mod_func(func_name: str): r""" Return the func named `func_name` defined in custom device module. If not defined, @@ -370,12 +426,14 @@ def func_name(*args, **kwargs): it is marked as private. It is a convenience function for backend implementers to more easily call the hooks into their backend extensions. """ - assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`." + assert isinstance( + func_name, str + ), f"func_name must be `str`, but got `{type(func_name)}`." backend_name = _get_privateuse1_backend_name() custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] if custom_device_mod is None or function is None: - message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend ' + message = f"Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend " message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And " message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n" raise RuntimeError(message) diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py index 60861d1f412a0e..068e62ec87a3de 100644 --- a/torch/utils/benchmark/utils/_stubs.py +++ b/torch/utils/benchmark/utils/_stubs.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Protocol, runtime_checkable +from typing import Any, Callable +from typing_extensions import Protocol, runtime_checkable class TimerClass(Protocol): diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 592ec66b4c04a2..d1df2987ea6c7b 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -280,7 +280,7 @@ class Compare: https://pytorch.org/tutorials/recipes/recipes/benchmark.html Args: - results: List of Measurment to display. + results: List of Measurement to display. """ def __init__(self, results: list[common.Measurement]): self._results: list[common.Measurement] = [] diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index 23b74f946c0007..cee9c8d7f7174b 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -127,7 +127,7 @@ def bench_all( This is a simple utility that can be used to benchmark torch.compile In particular it ensures that your GPU is setup to use tensor cores if it supports its It also tries out all the main backends and prints a table of results so you can easily compare them all - Many of the backendds have their own optional dependencies so please pip install them seperately + Many of the backendds have their own optional dependencies so please pip install them separately You will get one table for inference and another for training If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index 31d5ea3b6cc77c..6fd52a7aecd39f 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -290,7 +290,7 @@ def _make_tensor(self, params, state): raw_tensor = raw_tensor.permute(tuple(np.argsort(order))) slices = [slice(0, size * step, step) for size, step in zip(size, steps)] - tensor = raw_tensor[slices] + tensor = raw_tensor[tuple(slices)] properties = { "numel": int(tensor.numel()), diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 731ac21359a47a..1889f6756e70fd 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -13,21 +13,9 @@ __all__ = ["Timer", "timer", "Language"] -if torch.backends.cuda.is_built() and torch.cuda.is_available(): # type: ignore[no-untyped-call] +if torch.accelerator.is_available(): def timer() -> float: - torch.cuda.synchronize() - return timeit.default_timer() -elif torch.xpu.is_available(): - def timer() -> float: - torch.xpu.synchronize() - return timeit.default_timer() -elif torch._C._get_privateuse1_backend_name() != "privateuseone": - privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \ - if torch._C._get_privateuse1_backend_name() != "cpu" else None - - def timer() -> float: - if privateuse1_device_handler: - privateuse1_device_handler.synchronize() + torch.accelerator.synchronize() return timeit.default_timer() else: timer = timeit.default_timer @@ -49,12 +37,12 @@ def __init__( ) -> None: if timer is not timeit.default_timer: raise NotImplementedError( - "PyTorch was built with CUDA and a GPU is present; however " - "Timer does not yet support GPU measurements. If your " + "PyTorch was built with accelerators and an accelerator is present; however " + "Timer does not yet support accelerator measurements. If your " "code is CPU only, pass `timer=timeit.default_timer` to the " "Timer's constructor to indicate this. (Note that this will " - "produce incorrect results if the GPU is in fact used, as " - "Timer will not synchronize CUDA.)" + "produce incorrect results if an accelerator is in fact used, as " + "Timer will not synchronize the accelerator.)" ) if globals: @@ -88,7 +76,7 @@ class Timer: 1) Runtime aware: Timer will perform warmups (important as some elements of PyTorch are lazily initialized), set threadpool size so that comparisons are - apples-to-apples, and synchronize asynchronous CUDA functions when + apples-to-apples, and synchronize asynchronous accelerator functions when necessary. 2) Focus on replicates: @@ -131,8 +119,8 @@ class Timer: timer: Callable which returns the current time. If PyTorch was built - without CUDA or there is no GPU present, this defaults to - `timeit.default_timer`; otherwise it will synchronize CUDA before + without accelerators or there is no accelerator present, this defaults to + `timeit.default_timer`; otherwise it will synchronize accelerators before measuring the time. globals: @@ -359,7 +347,7 @@ def blocked_autorange( 2) A large block size better amortizes the cost of `timer` invocation, and results in a less biased measurement. This is - important because CUDA synchronization time is non-trivial + important because accelerator synchronization time is non-trivial (order single to low double digit microseconds) and would otherwise bias the measurement. diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index 7cc733abc50e4d..6209fc8ee87455 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -116,7 +116,7 @@ def bundle_inputs( ) # The above cloning function returns a torch._C.scriptmodule and we need a torch.jit.scriptmodule. - # Fortunately theres a function in _recursive that does exactly that conversion. + # Fortunately there is a function in _recursive that does exactly that conversion. cloned_module = wrap_cpp_module(clone) if isinstance(inputs, dict): assert isinstance(info, dict) or info is None diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 04f5edddce8e57..b4c5b8ea198d6c 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -858,14 +858,15 @@ def check_recomputed_tensors_match(self, gid): if not len(self.weak_holders) == self.recomp_counter[gid]: # 2. During recompute, fewer tensors were saved # - # We know that everytime we save something do original forward + # We know that every time we save something do original forward # we append to weak_holder, and every time we save a tensor # during recompute we increment recompute_counter. raise CheckpointError( "torch.utils.checkpoint: A different number of tensors was saved " "during the original forward and recomputation.\n" f"Number of tensors saved during forward: {len(self.weak_holders)}\n" - f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}" + f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n" + f"{_debug_tip_msg}" ) # 3. During recompute, the same tensors were saved, but they @@ -902,10 +903,19 @@ def check_recomputed_tensors_match(self, gid): raise CheckpointError( "torch.utils.checkpoint: Recomputed values for the following tensors " "have different metadata than during the forward pass.\n" - f"{mismatched_tensors}" + f"{mismatched_tensors}.\n" + f"{_debug_tip_msg}" ) +_debug_tip_msg = """ +Tip: To see a more detailed error message, either pass `debug=True` to +`torch.utils.checkpoint.checkpoint(...)` or wrap the code block +with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to +enable checkpoint‑debug mode globally. +""" + + _checkpoint_error_template = """ \ An error happened while unpacking tensors; dumping logs of latest computation because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. @@ -1068,7 +1078,8 @@ def pack_hook(x): return x raise CheckpointError( "torch.utils.checkpoint: trying to save more tensors during " - "recomputation than during the original forward pass." + "recomputation than during the original forward pass.\n" + f"{_debug_tip_msg}" ) holder = target_frame.weak_holders[recomp_idx]() @@ -1259,7 +1270,7 @@ class CheckpointPolicy(enum.Enum): def _policy_from_bool(b): - # For backward compatability + # For backward compatibility return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE @@ -1297,7 +1308,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): out = func(*args, **kwargs) - any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + # HOPs don't support func._schema + # HOPs don't alias -> this is always true today and will be always true for a long time + # TODO HOPs don't mutate -> this is always true today but will not be true forever + if isinstance(func, torch._ops.HigherOrderOperator): + any_ret_has_alias_info = False + else: + any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) @@ -1396,7 +1413,7 @@ def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mu # context_fn anyway, so proceed as usual. if isinstance(policy_fn_or_list, list): for op in policy_fn_or_list: - if not isinstance(op, torch._ops.OpOverload): + if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): _extra_msg = ( "Please update the OpOverloadPacket to a specific OpOverload." "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 4270522e6cfc06..f1f590182e08a6 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -493,7 +493,7 @@ def get_env_info(): Caching allocator config, XNNPACK availability and CPU information. Returns: - SystemEnv (namedtuple): A tuple containining various environment details + SystemEnv (namedtuple): A tuple containing various environment details and system information. """ run_lambda = run diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 46f3b60e156bca..d7f671a1d4f391 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -291,16 +291,27 @@ def _get_sycl_arch_list(): # If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled # for default spir64 target and avoid device specific compilations entirely. Further, kernels # will be JIT compiled at runtime. -def _get_sycl_target_flags(): +def _append_sycl_targets_if_missing(cflags): + if any(flag.startswith('-fsycl-targets=') for flag in cflags): + # do nothing: user has manually specified sycl targets + return if _get_sycl_arch_list() != '': - return ['-fsycl-targets=spir64_gen,spir64'] - return [''] + # AOT (spir64_gen) + JIT (spir64) + cflags.append('-fsycl-targets=spir64_gen,spir64') + else: + # JIT (spir64) + cflags.append('-fsycl-targets=spir64') + +def _get_sycl_device_flags(cflags): + # We need last occurrence of -fsycl-targets as it will be the one taking effect. + # So searching in reversed list. + flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')] + assert flags, "bug: -fsycl-targets should have been amended to cflags" -def _get_sycl_device_flags(): arch_list = _get_sycl_arch_list() if arch_list != '': - return [f'-Xs "-device {arch_list}"'] - return [''] + flags += [f'-Xs "-device {arch_list}"'] + return flags _COMMON_SYCL_FLAGS = [ '-fsycl', @@ -443,19 +454,18 @@ def get_compiler_abi_compatibility_and_version(compiler) -> tuple[bool, TorchVer try: if IS_LINUX: minimum_required_version = MINIMUM_GCC_VERSION - versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) - version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.') + compiler_info = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) else: minimum_required_version = MINIMUM_MSVC_VERSION compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) - match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) - version = ['0', '0', '0'] if match is None else list(match.groups()) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) + version = ['0', '0', '0'] if match is None else list(match.groups()) except Exception: _, error, _ = sys.exc_info() logger.warning('Error checking compiler version for %s: %s', compiler, error) return (False, TorchVersion('0.0.0')) - # convert alpha-numeric string to numeric string + # convert alphanumeric string to numeric string # amdclang++ returns str like 0.0.0git, others return 0.0.0 numeric_version = [re.sub(r'\D', '', v) for v in version] @@ -759,7 +769,7 @@ def unix_wrap_ninja_compile(sources, r"""Compiles sources by outputting a ninja file and running it.""" # NB: I copied some lines from self.compiler (which is an instance # of distutils.UnixCCompiler). See the following link. - # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 + # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 # codespell:ignore # This can be fragile, but a lot of other repos also do this # (see https://github.com/search?q=_setup_compile&type=Code) # so it is probably OK; we'll also get CI signal if/when @@ -821,11 +831,11 @@ def unix_wrap_ninja_compile(sources, sycl_dlink_post_cflags = None if with_sycl: sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS - sycl_cflags += _get_sycl_target_flags() if isinstance(extra_postargs, dict): sycl_post_cflags = extra_postargs['sycl'] else: sycl_post_cflags = list(extra_postargs) + _append_sycl_targets_if_missing(sycl_post_cflags) append_std17_if_no_std_present(sycl_cflags) _append_sycl_std_if_no_std_present(sycl_cflags) host_cflags = extra_cc_cflags + common_cflags + post_cflags @@ -838,8 +848,8 @@ def unix_wrap_ninja_compile(sources, # strings passed to SYCL compiler. sycl_cflags = [shlex.quote(f) for f in sycl_cflags] sycl_cflags += _wrap_sycl_host_flags(host_cflags) - sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS - sycl_dlink_post_cflags += _get_sycl_device_flags() + sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() + sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_post_cflags) sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags] _write_ninja_file_and_compile_objects( @@ -888,7 +898,7 @@ def spawn(cmd): if m ] - obj_regex = re.compile('/Fo(.*)') + obj_regex = re.compile('/Fo(.*)') # codespell:ignore obj_list = [ m.group(1) for m in (obj_regex.match(elem) for elem in cmd) if m @@ -1046,7 +1056,7 @@ def win_wrap_ninja_compile(sources, # Return *all* object filenames, not just the ones we just built. return objects # Monkey-patch the _compile or compile method. - # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 + # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 # codespell:ignore if self.compiler.compiler_type == 'msvc': if self.use_ninja: self.compiler.compile = win_wrap_ninja_compile @@ -1733,7 +1743,7 @@ def _check_and_build_extension_h_precompiler_headers( is_standalone=False): r''' Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules. - GCC offical manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html + GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html PCH only works when built pch file(header.h.gch) and build target have the same build parameters. So, We need add a signature file to record PCH file parameters. If the build parameters(signature) changed, it should rebuild PCH file. @@ -2436,8 +2446,8 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') # Expand named arches - for named_arch, archval in named_arches.items(): - _arch_list = _arch_list.replace(named_arch, archval) + for named_arch, archival in named_arches.items(): + _arch_list = _arch_list.replace(named_arch, archival) arch_list = _arch_list.split(';') @@ -2537,15 +2547,15 @@ def _get_num_workers(verbose: bool) -> Optional[int]: def _get_vc_env(vc_arch: str) -> dict[str, str]: try: - from setuptools import distutils + from setuptools import distutils # type: ignore[attr-defined] return distutils._msvccompiler._get_vc_env(vc_arch) except AttributeError: try: from setuptools._distutils import _msvccompiler - return _msvccompiler._get_vc_env(vc_arch) + return _msvccompiler._get_vc_env(vc_arch) # type: ignore[attr-defined] except AttributeError: from setuptools._distutils.compilers.C import msvc - return msvc._get_vc_env(vc_arch) + return msvc._get_vc_env(vc_arch) # type: ignore[attr-defined] def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None: command = ['ninja', '-v'] @@ -2555,7 +2565,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> env = os.environ.copy() # Try to activate the vc env for the users if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env: - from setuptools import distutils + from setuptools import distutils # type: ignore[attr-defined] plat_name = distutils.util.get_platform() plat_spec = PLAT_TO_VCVARS[plat_name] @@ -2572,7 +2582,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # subprocess.run assumes that sys.__stdout__ has not been modified and # attempts to write to it by default. However, when we call _run_ninja_build # from ahead-of-time cpp extensions, the following happens: - # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. + # 1) If the stdout encoding is not utf-8, setuptools detaches __stdout__. # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 # (it probably shouldn't do this) # 2) subprocess.run (on POSIX, with no stdout override) relies on @@ -2687,7 +2697,7 @@ def _write_ninja_file_to_build_library(path, cuda_flags += _get_rocm_arch_flags(cuda_flags) cuda_flags += extra_cuda_cflags elif with_cuda: - cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() + cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags(extra_cuda_cflags) if IS_WINDOWS: for flag in COMMON_MSVC_FLAGS: cuda_flags = ['-Xcompiler', flag] + cuda_flags @@ -2709,16 +2719,16 @@ def _write_ninja_file_to_build_library(path, if with_sycl: sycl_cflags = cflags + _COMMON_SYCL_FLAGS - sycl_cflags += _get_sycl_target_flags() sycl_cflags += extra_sycl_cflags + _append_sycl_targets_if_missing(sycl_cflags) _append_sycl_std_if_no_std_present(sycl_cflags) host_cflags = cflags # escaping quoted arguments to pass them thru SYCL compiler host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags] host_cflags = ' '.join(host_cflags) sycl_cflags += _wrap_sycl_host_flags(host_cflags) - sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS - sycl_dlink_post_cflags += _get_sycl_device_flags() + sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() + sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_cflags) else: sycl_cflags = None sycl_dlink_post_cflags = None @@ -2861,7 +2871,9 @@ def sanitize_flags(flags): if IS_WINDOWS: compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl" compile_rule.append( - f' command = {compiler_name} /showIncludes $cflags -c $in /Fo$out $post_cflags') + f' command = {compiler_name} ' + '/showIncludes $cflags -c $in /Fo$out $post_cflags' # codespell:ignore + ) if not IS_HIP_EXTENSION: compile_rule.append(' deps = msvc') else: diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py index ef03e053211b67..44111ef697b718 100644 --- a/torch/utils/data/_utils/__init__.py +++ b/torch/utils/data/_utils/__init__.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. A lot of multiprocessing is used in data loading, which only supports running @@ -43,7 +42,7 @@ HAS_NUMPY = False -def _set_python_exit_flag(): +def _set_python_exit_flag() -> None: global python_exit_status python_exit_status = True diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 15a71c7d7f94f6..48fac6b5165667 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -175,7 +175,9 @@ class DataLoader(Generic[_T_co]): worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``) multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If - ``None``, the default `multiprocessing context`_ of your operating system will + ``None``, the default + `multiprocessing context `_ # noqa: D401 + of your operating system will be used. (default: ``None``) generator (torch.Generator, optional): If not ``None``, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate @@ -223,9 +225,6 @@ class DataLoader(Generic[_T_co]): .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data. - - .. _multiprocessing context: - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods """ dataset: Dataset[_T_co] @@ -481,7 +480,7 @@ def __setattr__(self, attr, val): def __iter__(self) -> _BaseDataLoaderIter: # When using a single worker the returned iterator should be - # created everytime to avoid resetting its state + # created every time to avoid resetting its state # However, in the case of a multiple workers iterator # the iterator is only created once in the lifetime of the # DataLoader object so that workers can be reused @@ -557,10 +556,10 @@ def check_worker_number_rationality(self): # necessary. # # - # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is # available (available in most of Linux system, but not OSX and Windows). # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but - # it doesn't repect cpuset. + # it doesn't respect cpuset. # We don't take threading into account since each worker process is single threaded # at this time. # @@ -887,7 +886,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # 2. A similar issue araises when a `DataLoader` is used in a subprocess. # When a process ends, it shuts the all its daemonic children # down with a SIGTERM (instead of joining them without a timeout). - # Simiarly for threads, but by a different mechanism. This fact, + # Similarly for threads, but by a different mechanism. This fact, # together with a few implementation details of multiprocessing, forces # us to make workers daemonic. All of our problems arise when a # DataLoader is used in a subprocess, and are caused by multiprocessing @@ -1018,7 +1017,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # `cancel_join_thread` on that queue if its `IterableDataset` iterator # happens to exhaust coincidentally, which is out of the control of the # main process). Thus, since we will exit `pin_memory_thread` before the - # workers (see below), two separete events are used. + # workers (see below), two separate events are used. # # NOTE: In short, the protocol is that the main process will set these # `done_event`s and then the corresponding processes/threads a `None`, @@ -1572,7 +1571,7 @@ def _mark_worker_as_unavailable(self, worker_id, shutdown=False): # (2) since we don't join, the worker may still raise error, and we # prefer capturing those, rather than ignoring them, even though they # are raised after the worker has finished its job. - # Joinning is deferred to `_shutdown_workers`, which it is called when + # Joining is deferred to `_shutdown_workers`, which it is called when # all workers finish their jobs (e.g., `IterableDataset` replicas) or # when this iterator is garbage collected. diff --git a/torch/utils/data/datapipes/README.md b/torch/utils/data/datapipes/README.md index e4f0ee407c0f88..e8776bc39b87b5 100644 --- a/torch/utils/data/datapipes/README.md +++ b/torch/utils/data/datapipes/README.md @@ -51,7 +51,7 @@ Note that `__len__` method is optional for `IterDataPipe`. Like `CSVParserIterDataPipe` in the [Using DataPipe sector](#using-datapipe), `__len__` is not implemented because the size of each file streams is unknown for us before loading it. Besides, in some special cases, `__len__` method can be provided, but it would either return an integer length or raise Error depending on the arguments of DataPipe. -And, the Error is required to be `TypeError` to support Python's build-in functions like `list(dp)`. +And, the Error is required to be `TypeError` to support Python's built-in functions like `list(dp)`. Please check NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] for detailed reason in PyTorch. ### Registering DataPipe with functional API diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 0df815358bd0e4..d3ae5b4e18f4c2 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -138,7 +138,7 @@ def _issubtype_with_constraints(variant, constraints, recursive=True): # - TypeVar[TypeVar[...]] # So, variant and each constraint may be a TypeVar or a Union. # In these cases, all of inner types from the variant are required to be - # extraced and verified as a subtype of any constraint. And, all of + # extracted and verified as a subtype of any constraint. And, all of # inner types from any constraint being a TypeVar or a Union are # also required to be extracted and verified if the variant belongs to # any of them. diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index b3958ec0c793f6..d697cb6ebc5c2f 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -51,7 +51,7 @@ def __iter__(self): yield self.output_var.apply_ops(item) -# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registred functions +# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions DATAPIPES_OPS = [ "_dataframes_as_tuples", "groupby", @@ -201,7 +201,7 @@ class CaptureLikeMock: def __init__(self, name): import unittest.mock as mock - # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead. + # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead. get_target, attribute = mock._get_target(name) # type: ignore[attr-defined] self.get_target = get_target self.attribute = attribute diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index d0e4191fd20a2d..c96ce82cf139a9 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -147,7 +147,7 @@ def _collate_helper(conversion, item): for name in conversion.keys(): if name not in columns_name: - raise RuntimeError("Conversion keys missmatch") + raise RuntimeError("Conversion keys mismatch") for name in columns_name: if name in conversion: diff --git a/torch/utils/data/datapipes/iter/filelister.py b/torch/utils/data/datapipes/iter/filelister.py index 9de99cf9b4a2f3..2b3d16bed2a667 100644 --- a/torch/utils/data/datapipes/iter/filelister.py +++ b/torch/utils/data/datapipes/iter/filelister.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs from collections.abc import Iterator, Sequence from typing import Union @@ -63,7 +62,7 @@ def __iter__(self) -> Iterator[str]: path, self.masks, self.recursive, self.abspath, self.non_deterministic ) - def __len__(self): + def __len__(self) -> int: if self.length == -1: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") return self.length diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py index 0728746af23ddd..f90b426be129a1 100644 --- a/torch/utils/data/datapipes/iter/utils.py +++ b/torch/utils/data/datapipes/iter/utils.py @@ -1,14 +1,17 @@ -# mypy: allow-untyped-defs import copy import warnings +from collections.abc import Iterable, Iterator, Sized +from typing import TypeVar from torch.utils.data.datapipes.datapipe import IterDataPipe +_T = TypeVar("_T") + __all__ = ["IterableWrapperIterDataPipe"] -class IterableWrapperIterDataPipe(IterDataPipe): +class IterableWrapperIterDataPipe(IterDataPipe[_T]): r""" Wraps an iterable object to create an IterDataPipe. @@ -30,11 +33,11 @@ class IterableWrapperIterDataPipe(IterDataPipe): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - def __init__(self, iterable, deepcopy=True): + def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None: self.iterable = iterable self.deepcopy = deepcopy - def __iter__(self): + def __iter__(self) -> Iterator[_T]: source_data = self.iterable if self.deepcopy: try: @@ -50,5 +53,7 @@ def __iter__(self): ) yield from source_data - def __len__(self): - return len(self.iterable) + def __len__(self) -> int: + if isinstance(self.iterable, Sized): + return len(self.iterable) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 38e8b8ff56fe44..02865e8064f86b 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -1,14 +1,17 @@ -# mypy: allow-untyped-defs import copy import warnings +from collections.abc import Mapping, Sequence +from typing import Any, TypeVar, Union from torch.utils.data.datapipes.datapipe import MapDataPipe +_T = TypeVar("_T") + __all__ = ["SequenceWrapperMapDataPipe"] -class SequenceWrapperMapDataPipe(MapDataPipe): +class SequenceWrapperMapDataPipe(MapDataPipe[_T]): r""" Wraps a sequence object into a MapDataPipe. @@ -33,7 +36,11 @@ class SequenceWrapperMapDataPipe(MapDataPipe): 100 """ - def __init__(self, sequence, deepcopy=True): + sequence: Union[Sequence[_T], Mapping[Any, _T]] + + def __init__( + self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True + ) -> None: if deepcopy: try: self.sequence = copy.deepcopy(sequence) @@ -46,8 +53,8 @@ def __init__(self, sequence, deepcopy=True): else: self.sequence = sequence - def __getitem__(self, index): + def __getitem__(self, index: int) -> _T: return self.sequence[index] - def __len__(self): + def __len__(self) -> int: return len(self.sequence) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index e998d545ac2315..c92bdbb00e1026 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -6,6 +6,12 @@ import torch +# Note: For benchmarking changes to samplers, see: +# /benchmarks/data/samplers_bench.py +# This benchmark compares the performance of different sampler implementations +# and can be used to evaluate the impact of optimizations. + + __all__ = [ "BatchSampler", "RandomSampler", @@ -324,7 +330,6 @@ def __init__( self.drop_last = drop_last def __iter__(self) -> Iterator[list[int]]: - # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 sampler_iter = iter(self.sampler) if self.drop_last: # Create multiple references to the same iterator diff --git a/torch/utils/data/standard_pipes.ipynb b/torch/utils/data/standard_pipes.ipynb index 3bcae687aa62f5..c40058bca7699b 100644 --- a/torch/utils/data/standard_pipes.ipynb +++ b/torch/utils/data/standard_pipes.ipynb @@ -753,7 +753,7 @@ "\n", "Arguments:\n", " - `group_key_fn`\n", - " - `group_size` - yeild resulted group as soon as `group_size` elements accumulated\n", + " - `group_size` - yield resulted group as soon as `group_size` elements accumulated\n", " - `guaranteed_group_size:int = None`\n", " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n", "\n", @@ -962,7 +962,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This behaviour becomes noticable when data is bigger than buffer and some groups getting evicted before gathering all potential items" + "This behaviour becomes noticeable when data is bigger than buffer and some groups getting evicted before gathering all potential items" ] }, { diff --git a/torch/utils/data/typing.ipynb b/torch/utils/data/typing.ipynb index 064a963d64b2f4..1b1aa8c9da72fc 100644 --- a/torch/utils/data/typing.ipynb +++ b/torch/utils/data/typing.ipynb @@ -399,7 +399,7 @@ "\n", "Note: This decorator is only allowed to be attached to `__iter__` for now. It can be extended into `__getitem__` and further `nonblocking` functions.\n", "\n", - "`runtime_validation_disabled` is a context manager to turn off the type validaiton during runtime. It's useful for DataLoader to disable the runtime validaiton after the first epoch is finished for better performance. Note: the runtime validation is enabled by default." + "`runtime_validation_disabled` is a context manager to turn off the type validation during runtime. It's useful for DataLoader to disable the runtime validation after the first epoch is finished for better performance. Note: the runtime validation is enabled by default." ] }, { @@ -684,7 +684,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Compatible with context mangager to disable validation" + "- Compatible with context manager to disable validation" ] }, { diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 876f64d1b876c3..9a53ff9e84ac6e 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -3,28 +3,30 @@ import torch import enum -from torch._C import _from_dlpack from torch._C import _to_dlpack as to_dlpack __all__ = [ "DLDeviceType", "from_dlpack", - "to_dlpack", ] - class DLDeviceType(enum.IntEnum): # Enums as in DLPack specification (aten/src/ATen/dlpack.h) kDLCPU = 1, - kDLGPU = 2, - kDLCPUPinned = 3, + kDLCUDA = 2, + kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7, kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, + kDLROCMHost = 11, kDLExtDev = 12, + kDLCUDAManaged = 13, kDLOneAPI = 14, + kDLWebGPU = 15, + kDLHexagon = 16, + kDLMAIA = 17, torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule @@ -104,24 +106,34 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): + kwargs: dict[str, Any] = {} + kwargs["max_version"] = (1, 0) + device = ext_tensor.__dlpack_device__() # device is either CUDA or ROCm, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): stream = torch.cuda.current_stream(f'cuda:{device[1]}') # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none - is_cuda = device[0] == DLDeviceType.kDLGPU + is_cuda = device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream - dlpack = ext_tensor.__dlpack__(stream=stream_ptr) - else: - dlpack = ext_tensor.__dlpack__() + kwargs["stream"] = stream_ptr + + try: + # Try running __dlpack__ while specifying `max_version` argument. + dlpack = ext_tensor.__dlpack__(**kwargs) + except TypeError: + # If that doesn't work, try removing the `max_version` argument. + kwargs.pop("max_version") + dlpack = ext_tensor.__dlpack__(**kwargs) + else: # Old versions just call the converter dlpack = ext_tensor - return _from_dlpack(dlpack) + return torch._C._from_dlpack(dlpack) diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index ae3863c8ec0917..f664b564c370d1 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -179,6 +179,9 @@ ), ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc_st", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), + ("CUmemAccessDesc_v1", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), ( "CUmemAttach_flags", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), @@ -187,6 +190,38 @@ "CUmemAttach_flags_enum", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), ), + ("CUmemAllocationGranularity_flags", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationGranularity_flags_enum", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationHandleType", ("hipMemAllocationHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationHandleType_enum", ("hipMemAllocationHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp_st", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationProp_v1", ("hipMemAllocationProp", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationType", ("hipMemAllocationType", CONV_TYPE, API_DRIVER)), + ("CUmemAllocationType_enum", ("hipMemAllocationType", CONV_TYPE, API_DRIVER)), + ("CUmemGenericAllocationHandle", ("hipMemGenericAllocationHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemGenericAllocationHandle_v1", ("hipMemGenericAllocationHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemHandleType", ("hipMemHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemHandleType_enum", ("hipMemHandleType", CONV_TYPE, API_DRIVER)), + ("CUmemLocation", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemLocationType", ("hipMemLocationType", CONV_TYPE, API_DRIVER)), + ("CUmemLocationType_enum", ("hipMemLocationType", CONV_TYPE, API_DRIVER)), + ("CUmemLocation_st", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemLocation_v1", ("hipMemLocation", CONV_TYPE, API_DRIVER)), + ("CUmemOperationType", ("hipMemOperationType", CONV_TYPE, API_DRIVER)), + ("CUmemOperationType_enum", ("hipMemOperationType", CONV_TYPE, API_DRIVER)), + ("CUmemPoolHandle_st", ("ihipMemPoolHandle_t", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps_st", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolProps_v1", ("hipMemPoolProps", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData_st", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPoolPtrExportData_v1", ("hipMemPoolPtrExportData", CONV_TYPE, API_DRIVER)), + ("CUmemPool_attribute", ("hipMemPoolAttr", CONV_TYPE, API_DRIVER)), + ("CUmemPool_attribute_enum", ("hipMemPoolAttr", CONV_TYPE, API_DRIVER)), + ("CUmem_advise_enum", ("hipMemoryAdvise", CONV_TYPE, API_DRIVER)), + ("CUmem_range_attribute_enum", ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER)), + ("CUmemoryPool", ("hipMemPool_t", CONV_TYPE, API_DRIVER)), ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), @@ -543,6 +578,7 @@ ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), + ("cudaGraphNode_t", ("hipGraphNode_t", CONV_TYPE, API_RAND)), ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), @@ -607,6 +643,7 @@ ("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("cusparse.h", ("hipsparse/hipsparse.h", CONV_INCLUDE, API_RAND)), + ("cusparseLt.h", ("hipsparselt/hipsparselt.h", CONV_INCLUDE, API_RAND)), ("cufft.h", ("hipfft/hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft/hipfftXt.h", CONV_INCLUDE, API_BLAS)), # PyTorch also has a source file named "nccl.h", so we need to "<"">" to differentiate @@ -2549,6 +2586,38 @@ "CU_MEMORYTYPE_UNIFIED", ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), ), + ("CU_MEMHOSTREGISTER_READ_ONLY", ("hipHostRegisterReadOnly", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RELEASE_THRESHOLD", ("hipMemPoolAttrReleaseThreshold", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT", ("hipMemPoolAttrReservedMemCurrent", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH", ("hipMemPoolAttrReservedMemHigh", CONV_TYPE, API_DRIVER)), + ( + "CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES", + ("hipMemPoolReuseAllowInternalDependencies", CONV_TYPE, API_DRIVER) + ), + ("CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC", ("hipMemPoolReuseAllowOpportunistic", CONV_TYPE, API_DRIVER)), + ( + "CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES", + ("hipMemPoolReuseFollowEventDependencies", CONV_TYPE, API_DRIVER) + ), + ("CU_MEMPOOL_ATTR_USED_MEM_CURRENT", ("hipMemPoolAttrUsedMemCurrent", CONV_TYPE, API_DRIVER)), + ("CU_MEMPOOL_ATTR_USED_MEM_HIGH", ("hipMemPoolAttrUsedMemHigh", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_NONE", ("hipMemAccessFlagsProtNone", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_READ", ("hipMemAccessFlagsProtRead", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ACCESS_FLAGS_PROT_READWRITE", ("hipMemAccessFlagsProtReadWrite", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_INVALID", ("hipMemAllocationTypeInvalid", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_MAX", ("hipMemAllocationTypeMax", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOCATION_TYPE_PINNED", ("hipMemAllocationTypePinned", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOC_GRANULARITY_MINIMUM", ("hipMemAllocationGranularityMinimum", CONV_TYPE, API_DRIVER)), + ("CU_MEM_ALLOC_GRANULARITY_RECOMMENDED", ("hipMemAllocationGranularityRecommended", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_GENERIC", ("hipMemHandleTypeGeneric", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_NONE", ("hipMemHandleTypeNone", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR", ("hipMemHandleTypePosixFileDescriptor", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_WIN32", ("hipMemHandleTypeWin32", CONV_TYPE, API_DRIVER)), + ("CU_MEM_HANDLE_TYPE_WIN32_KMT", ("hipMemHandleTypeWin32Kmt", CONV_TYPE, API_DRIVER)), + ("CU_MEM_LOCATION_TYPE_DEVICE", ("hipMemLocationTypeDevice", CONV_TYPE, API_DRIVER)), + ("CU_MEM_LOCATION_TYPE_INVALID", ("hipMemLocationTypeInvalid", CONV_TYPE, API_DRIVER)), + ("CU_MEM_OPERATION_TYPE_MAP", ("hipMemOperationTypeMap", CONV_TYPE, API_DRIVER)), + ("CU_MEM_OPERATION_TYPE_UNMAP", ("hipMemOperationTypeUnmap", CONV_TYPE, API_DRIVER)), ( "CU_RESOURCE_TYPE_ARRAY", ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), @@ -3179,6 +3248,63 @@ "cuMemGetAddressRange_v2", ("hipMemGetAddressRange", CONV_MEM, API_DRIVER), ), + ("cuArray3DCreate_v2", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ("cuArray3DGetDescriptor_v2", ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER)), + ("cuArrayGetDescriptor_v2", ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER)), + ("cuMemAlloc", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost_v2", ("hipMemAllocHost", CONV_MEM, API_DRIVER)), + ("cuMemAllocPitch_v2", ("hipMemAllocPitch", CONV_MEM, API_DRIVER)), + ("cuMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostGetDevicePointer_v2", ("hipHostGetDevicePointer", CONV_MEM, API_DRIVER)), + ("cuMemHostRegister", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemcpy2DAsync_v2", ("hipMemcpyParam2DAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpy2DUnaligned_v2", ("hipDrvMemcpy2DUnaligned", CONV_MEM, API_DRIVER)), + ("cuMemcpy2D_v2", ("hipMemcpyParam2D", CONV_MEM, API_DRIVER)), + ("cuMemcpy3DAsync_v2", ("hipDrvMemcpy3DAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpy3D_v2", ("hipDrvMemcpy3D", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoA_v2", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoD_v2", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoHAsync_v2", ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyAtoH_v2", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoA_v2", ("hipMemcpyDtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoD", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA_v2", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoD", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD16", ("hipMemsetD16", CONV_MEM, API_DRIVER)), + ("cuMemsetD32", ("hipMemsetD32", CONV_MEM, API_DRIVER)), + ("cuMemsetD8", ("hipMemsetD8", CONV_MEM, API_DRIVER)), + ("cuMemAddressFree", ("hipMemAddressFree", CONV_MEM, API_DRIVER)), + ("cuMemAddressReserve", ("hipMemAddressReserve", CONV_MEM, API_DRIVER)), + ("cuMemCreate", ("hipMemCreate", CONV_MEM, API_DRIVER)), + ("cuMemExportToShareableHandle", ("hipMemExportToShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemGetAccess", ("hipMemGetAccess", CONV_MEM, API_DRIVER)), + ("cuMemGetAllocationGranularity", ("hipMemGetAllocationGranularity", CONV_MEM, API_DRIVER)), + ("cuMemGetAllocationPropertiesFromHandle", ("hipMemGetAllocationPropertiesFromHandle", CONV_MEM, API_DRIVER)), + ("cuMemImportFromShareableHandle", ("hipMemImportFromShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemMap", ("hipMemMap", CONV_MEM, API_DRIVER)), + ("cuMemMapArrayAsync", ("hipMemMapArrayAsync", CONV_MEM, API_DRIVER)), + ("cuMemRelease", ("hipMemRelease", CONV_MEM, API_DRIVER)), + ("cuMemRetainAllocationHandle", ("hipMemRetainAllocationHandle", CONV_MEM, API_DRIVER)), + ("cuMemSetAccess", ("hipMemSetAccess", CONV_MEM, API_DRIVER)), + ("cuMemUnmap", ("hipMemUnmap", CONV_MEM, API_DRIVER)), + ("cuMemAllocAsync", ("hipMallocAsync", CONV_MEM, API_DRIVER)), + ("cuMemAllocFromPoolAsync", ("hipMallocFromPoolAsync", CONV_MEM, API_DRIVER)), + ("cuMemFreeAsync", ("hipFreeAsync", CONV_MEM, API_DRIVER)), + ("cuMemPoolCreate", ("hipMemPoolCreate", CONV_MEM, API_DRIVER)), + ("cuMemPoolDestroy", ("hipMemPoolDestroy", CONV_MEM, API_DRIVER)), + ("cuMemPoolExportPointer", ("hipMemPoolExportPointer", CONV_MEM, API_DRIVER)), + ("cuMemPoolExportToShareableHandle", ("hipMemPoolExportToShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemPoolGetAccess", ("hipMemPoolGetAccess", CONV_MEM, API_DRIVER)), + ("cuMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_DRIVER)), + ("cuMemPoolImportFromShareableHandle", ("hipMemPoolImportFromShareableHandle", CONV_MEM, API_DRIVER)), + ("cuMemPoolImportPointer", ("hipMemPoolImportPointer", CONV_MEM, API_DRIVER)), + ("cuMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_DRIVER)), + ("cuMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_DRIVER)), + ("cuMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_DRIVER)), ( "cuPointerGetAttributes", ("hipPointerGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), @@ -4066,9 +4192,15 @@ ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)), ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)), ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)), - ("cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME)), + ( + "cudaMemPoolReuseAllowInternalDependencies", + ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME) + ), ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)), - ("cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME)), + ( + "cudaMemPoolReuseFollowEventDependencies", + ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME) + ), ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)), ("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)), ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)), @@ -4119,12 +4251,15 @@ ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeUnregistered", ("hipMemoryTypeUnregistered", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeManaged", ("hipMemoryTypeManaged", CONV_MEM, API_RUNTIME)), ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), + ("cudaHostNodeParams", ("hipHostNodeParams", CONV_MEM, API_RUNTIME)), ( "cudaHostAllocWriteCombined", ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), @@ -4193,7 +4328,10 @@ ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), ("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)), - ("cudaGraphInstantiateFlagAutoFreeOnLaunch", ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateFlagAutoFreeOnLaunch", + ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME) + ), ("cudaGraphDestroy", ("hipGraphDestroy", CONV_TYPE, API_RUNTIME)), ("cudaGraphExecDestroy", ("hipGraphExecDestroy", CONV_TYPE, API_RUNTIME)), ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), @@ -4202,6 +4340,209 @@ ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceGetGraphMemAttribute", ("hipDeviceGetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceGraphMemTrim", ("hipDeviceGraphMemTrim", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSetGraphMemAttribute", ("hipDeviceSetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddChildGraphNode", ("hipGraphAddChildGraphNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddDependencies", ("hipGraphAddDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEmptyNode", ("hipGraphAddEmptyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEventRecordNode", ("hipGraphAddEventRecordNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddEventWaitNode", ("hipGraphAddEventWaitNode", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphAddExternalSemaphoresSignalNode", + ("hipGraphAddExternalSemaphoresSignalNode", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphAddExternalSemaphoresWaitNode", ("hipGraphAddExternalSemaphoresWaitNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddHostNode", ("hipGraphAddHostNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddKernelNode", ("hipGraphAddKernelNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemAllocNode", ("hipGraphAddMemAllocNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemFreeNode", ("hipGraphAddMemFreeNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNode", ("hipGraphAddMemcpyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNode1D", ("hipGraphAddMemcpyNode1D", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNodeFromSymbol", ("hipGraphAddMemcpyNodeFromSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemcpyNodeToSymbol", ("hipGraphAddMemcpyNodeToSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddMemsetNode", ("hipGraphAddMemsetNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphAddNode", ("hipGraphAddNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphChildGraphNodeGetGraph", ("hipGraphChildGraphNodeGetGraph", CONV_TYPE, API_RUNTIME)), + ("cudaGraphClone", ("hipGraphClone", CONV_TYPE, API_RUNTIME)), + ("cudaGraphCreate", ("hipGraphCreate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDestroyNode", ("hipGraphDestroyNode", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventRecordNodeGetEvent", ("hipGraphEventRecordNodeGetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventRecordNodeSetEvent", ("hipGraphEventRecordNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventWaitNodeGetEvent", ("hipGraphEventWaitNodeGetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEventWaitNodeSetEvent", ("hipGraphEventWaitNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecChildGraphNodeSetParams", ("hipGraphExecChildGraphNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecEventRecordNodeSetEvent", ("hipGraphExecEventRecordNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecEventWaitNodeSetEvent", ("hipGraphExecEventWaitNodeSetEvent", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecExternalSemaphoresSignalNodeSetParams", + ("hipGraphExecExternalSemaphoresSignalNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecExternalSemaphoresWaitNodeSetParams", + ("hipGraphExecExternalSemaphoresWaitNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecGetFlags", ("hipGraphExecGetFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecHostNodeSetParams", ("hipGraphExecHostNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecKernelNodeSetParams", ("hipGraphExecKernelNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecMemcpyNodeSetParams", ("hipGraphExecMemcpyNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecMemcpyNodeSetParams1D", ("hipGraphExecMemcpyNodeSetParams1D", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecMemcpyNodeSetParamsFromSymbol", + ("hipGraphExecMemcpyNodeSetParamsFromSymbol", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecMemcpyNodeSetParamsToSymbol", + ("hipGraphExecMemcpyNodeSetParamsToSymbol", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecMemsetNodeSetParams", ("hipGraphExecMemsetNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecNodeSetParams", ("hipGraphExecNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdate", ("hipGraphExecUpdate", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExternalSemaphoresSignalNodeGetParams", + ("hipGraphExternalSemaphoresSignalNodeGetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresSignalNodeSetParams", + ("hipGraphExternalSemaphoresSignalNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresWaitNodeGetParams", + ("hipGraphExternalSemaphoresWaitNodeGetParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExternalSemaphoresWaitNodeSetParams", + ("hipGraphExternalSemaphoresWaitNodeSetParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphGetEdges", ("hipGraphGetEdges", CONV_TYPE, API_RUNTIME)), + ("cudaGraphGetRootNodes", ("hipGraphGetRootNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphHostNodeGetParams", ("hipGraphHostNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphHostNodeSetParams", ("hipGraphHostNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithParams", ("hipGraphInstantiateWithParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeCopyAttributes", ("hipGraphKernelNodeCopyAttributes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeGetAttribute", ("hipGraphKernelNodeGetAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeGetParams", ("hipGraphKernelNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeSetAttribute", ("hipGraphKernelNodeSetAttribute", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodeSetParams", ("hipGraphKernelNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAllocNodeGetParams", ("hipGraphMemAllocNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemFreeNodeGetParams", ("hipGraphMemFreeNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeGetParams", ("hipGraphMemcpyNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParams", ("hipGraphMemcpyNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParams1D", ("hipGraphMemcpyNodeSetParams1D", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParamsFromSymbol", ("hipGraphMemcpyNodeSetParamsFromSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemcpyNodeSetParamsToSymbol", ("hipGraphMemcpyNodeSetParamsToSymbol", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemsetNodeGetParams", ("hipGraphMemsetNodeGetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemsetNodeSetParams", ("hipGraphMemsetNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeFindInClone", ("hipGraphNodeFindInClone", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetDependencies", ("hipGraphNodeGetDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetDependentNodes", ("hipGraphNodeGetDependentNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetEnabled", ("hipGraphNodeGetEnabled", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeGetType", ("hipGraphNodeGetType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeSetEnabled", ("hipGraphNodeSetEnabled", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeSetParams", ("hipGraphNodeSetParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphReleaseUserObject", ("hipGraphReleaseUserObject", CONV_TYPE, API_RUNTIME)), + ("cudaGraphRemoveDependencies", ("hipGraphRemoveDependencies", CONV_TYPE, API_RUNTIME)), + ("cudaGraphUpload", ("hipGraphUpload", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectRelease", ("hipUserObjectRelease", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectRetain", ("hipUserObjectRetain", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlags", ("hipGraphDebugDotFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsEventNodeParams", ("hipGraphDebugDotFlagsEventNodeParams", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphDebugDotFlagsExtSemasSignalNodeParams", + ("hipGraphDebugDotFlagsExtSemasSignalNodeParams", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphDebugDotFlagsExtSemasWaitNodeParams", + ("hipGraphDebugDotFlagsExtSemasWaitNodeParams", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphDebugDotFlagsHandles", ("hipGraphDebugDotFlagsHandles", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsHostNodeParams", ("hipGraphDebugDotFlagsHostNodeParams", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphDebugDotFlagsKernelNodeAttributes", + ("hipGraphDebugDotFlagsKernelNodeAttributes", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphDebugDotFlagsKernelNodeParams", ("hipGraphDebugDotFlagsKernelNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsMemcpyNodeParams", ("hipGraphDebugDotFlagsMemcpyNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsMemsetNodeParams", ("hipGraphDebugDotFlagsMemsetNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyType", ("hipGraphDependencyType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyTypeDefault", ("hipGraphDependencyTypeDefault", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyTypeProgrammatic", ("hipGraphDependencyTypeProgrammatic", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDependencyType_enum", ("hipGraphDependencyType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEdgeData", ("hipGraphEdgeData", CONV_TYPE, API_RUNTIME)), + ("cudaGraphEdgeData_st", ("hipGraphEdgeData", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdateError", ("hipGraphExecUpdateError", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecUpdateErrorFunctionChanged", + ("hipGraphExecUpdateErrorFunctionChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorNodeTypeChanged", + ("hipGraphExecUpdateErrorNodeTypeChanged", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecUpdateErrorNotSupported", ("hipGraphExecUpdateErrorNotSupported", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphExecUpdateErrorParametersChanged", + ("hipGraphExecUpdateErrorParametersChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorTopologyChanged", + ("hipGraphExecUpdateErrorTopologyChanged", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphExecUpdateErrorUnsupportedFunctionChange", + ("hipGraphExecUpdateErrorUnsupportedFunctionChange", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphExecUpdateResult", ("hipGraphExecUpdateResult", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecUpdateSuccess", ("hipGraphExecUpdateSuccess", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateError", ("hipGraphInstantiateError", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagDeviceLaunch", ("hipGraphInstantiateFlagDeviceLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagUpload", ("hipGraphInstantiateFlagUpload", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateFlagUseNodePriority", + ("hipGraphInstantiateFlagUseNodePriority", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphInstantiateFlags", ("hipGraphInstantiateFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateInvalidStructure", ("hipGraphInstantiateInvalidStructure", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphInstantiateMultipleDevicesNotSupported", + ("hipGraphInstantiateMultipleDevicesNotSupported", CONV_TYPE, API_RUNTIME) + ), + ( + "cudaGraphInstantiateNodeOperationNotSupported", + ("hipGraphInstantiateNodeOperationNotSupported", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphInstantiateParams", ("hipGraphInstantiateParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateParams_st", ("hipGraphInstantiateParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateResult", ("hipGraphInstantiateResult", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateSuccess", ("hipGraphInstantiateSuccess", CONV_TYPE, API_RUNTIME)), + ("cudaGraphKernelNodePortDefault", ("hipGraphKernelNodePortDefault", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphKernelNodePortLaunchCompletion", + ("hipGraphKernelNodePortLaunchCompletion", CONV_TYPE, API_RUNTIME) + ), + ("cudaGraphKernelNodePortProgrammatic", ("hipGraphKernelNodePortProgrammatic", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrReservedMemCurrent", ("hipGraphMemAttrReservedMemCurrent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrReservedMemHigh", ("hipGraphMemAttrReservedMemHigh", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrUsedMemCurrent", ("hipGraphMemAttrUsedMemCurrent", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttrUsedMemHigh", ("hipGraphMemAttrUsedMemHigh", CONV_TYPE, API_RUNTIME)), + ("cudaGraphMemAttributeType", ("hipGraphMemAttributeType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeParams", ("hipGraphNodeParams", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeType", ("hipGraphNodeType", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeConditional", ("hipGraphNodeTypeConditional", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeCount", ("hipGraphNodeTypeCount", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeEmpty", ("hipGraphNodeTypeEmpty", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeEventRecord", ("hipGraphNodeTypeEventRecord", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeExtSemaphoreSignal", ("hipGraphNodeTypeExtSemaphoreSignal", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeExtSemaphoreWait", ("hipGraphNodeTypeExtSemaphoreWait", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeGraph", ("hipGraphNodeTypeGraph", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeHost", ("hipGraphNodeTypeHost", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeKernel", ("hipGraphNodeTypeKernel", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemAlloc", ("hipGraphNodeTypeMemAlloc", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemFree", ("hipGraphNodeTypeMemFree", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemcpy", ("hipGraphNodeTypeMemcpy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeMemset", ("hipGraphNodeTypeMemset", CONV_TYPE, API_RUNTIME)), + ("cudaGraphNodeTypeWaitEvent", ("hipGraphNodeTypeWaitEvent", CONV_TYPE, API_RUNTIME)), ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), ("cudaUserObjectCreate", ("hipUserObjectCreate", CONV_TYPE, API_RUNTIME)), ("cudaUserObjectNoDestructorSync", ("hipUserObjectNoDestructorSync", CONV_TYPE, API_RUNTIME)), @@ -7342,6 +7683,9 @@ ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)), ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)), ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)), ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)), @@ -8256,6 +8600,44 @@ "CUSPARSE_MATRIX_TYPE_GENERAL", ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL), ), + # SparseLt + ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)), + ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseGetErrorString", ("hipsparseGetErrorString", CONV_MATH_FUNC, API_SPECIAL)), # SOLVER ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)), ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)), @@ -8694,7 +9076,7 @@ ] ) -# We must tread very carefully here. Blanket conversions like are done +# We must treat very carefully here. Blanket conversions like are done # in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, # because a regex for CUDA will also match a filename like CUDAGuard.h, # but the HIPIFY script doesn't presently move the file and so the substitution diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 76b63e70a8ef0d..0e816020635bea 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -465,7 +465,7 @@ def find_closure_group(input_string, start, group): def find_bracket_group(input_string, start): - """Finds the first balanced parantheses.""" + """Finds the first balanced parentheses.""" return find_closure_group(input_string, start, group=["{", "}"]) diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index c41add0fcbefd0..e6e93966afdbd6 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -223,6 +223,11 @@ def hook(_, grad_output): # Special case if no input required gradients, this hook should call the user # hook directly if self.input_tensors_index is None: + warnings.warn("Full backward hook is firing when gradients are computed " + "with respect to module outputs since no inputs require gradients. See " + "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " # noqa: B950 + "for more details.", + stacklevel=5) grad_inputs = self._pack_with_none([], [], self.n_inputs) for user_hook in self.user_hooks: res = user_hook(self.module, grad_inputs, self.grad_outputs) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 5ab8fd9a35e11f..12ccd2d2f5cbda 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -232,14 +232,14 @@ def get_pickle(name): model_data = get_pickle("data") constants = get_pickle("constants") - # Intern strings that are likely to be re-used. + # Intern strings that are likely to be reused. # Pickle automatically detects shared structure, - # so re-used strings are stored efficiently. + # so reused strings are stored efficiently. # However, JSON has no way of representing this, # so we have to do it manually. interned_strings : dict[str, int] = {} - def ist(s): + def intern(s): if s not in interned_strings: interned_strings[s] = len(interned_strings) return interned_strings[s] @@ -293,7 +293,7 @@ def parse_new_format(line): s_start = 0 s_end = 0 text = raw_code[start:end] - code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end]) + code_parts.append([text.decode("utf-8"), intern(s_file), s_line, intern(s_text), s_start, s_end]) code_files[zi.filename] = code_parts extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json") diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index 30140a22cff673..c4e234dff6ba09 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -1,11 +1,13 @@ -# mypy: allow-untyped-defs -from typing import Optional +import torch + +from typing import Optional, Union +from collections.abc import Sequence from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.attr_value_pb2 import AttrValue from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto -def attr_value_proto(dtype, shape, s): +def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]: """Create a dict of objects matching a NodeDef's attr field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto @@ -21,7 +23,7 @@ def attr_value_proto(dtype, shape, s): return attr -def tensor_shape_proto(outputsize): +def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: """Create an object matching a tensor_shape field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto . @@ -30,14 +32,14 @@ def tensor_shape_proto(outputsize): def node_proto( - name, - op="UnSpecified", - input=None, - dtype=None, - shape: Optional[tuple] = None, - outputsize=None, - attributes="", -): + name: str, + op: str = "UnSpecified", + input: Optional[Union[list[str], str]] = None, + dtype: Optional[torch.dtype] = None, + shape: Optional[tuple[int, ...]] = None, + outputsize: Optional[Sequence[int]] = None, + attributes: str = "", +) -> NodeDef: """Create an object matching a NodeDef. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto . diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index 8acaf1696cb1f2..f0ad185d968f54 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -45,7 +45,7 @@ def _prepare_video(V): Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor) to [time(frame), new_width, new_height, channel] (4D tensor). - A batch of images are spreaded to a grid, which forms a frame. + A batch of images are spread to a grid, which forms a frame. e.g. Video with batchsize 16 will have a 4x4 grid. """ b, t, c, h, w = V.shape diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 45581031081734..a5aebde06a34e6 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -70,7 +70,7 @@ def remove(): gc.callbacks.remove(gc_callback) return remove -# Function to visualize cycles adapated from refcycle: +# Function to visualize cycles adapted from refcycle: # Copyright 2013 Mark Dickinson # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -482,7 +482,7 @@ def warn_tensor_cycles(): Install a warning that reports whenever a cycle that is holding CUDA memory is observed. The warning produces an .html file that visualizes the cycle, - and links it to the stack frame that allocted the CUDA tensor. + and links it to the stack frame that allocated the CUDA tensor. Reference cycles are freed by the cycle collector rather than being cleaned up when the objects in the cycle first become unreachable. If a cycle points to a tensor, diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 948709e5e152ef..23e3a25c90f5f1 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -62,7 +62,7 @@ def device_count() -> int: def is_available() -> bool: r"""Return a bool indicating if XPU is currently available.""" - # This function nerver throws. + # This function never throws. return device_count() > 0 diff --git a/torchgen/_autoheuristic/README.md b/torchgen/_autoheuristic/README.md index 58613e54fb872f..2241785c2983bb 100644 --- a/torchgen/_autoheuristic/README.md +++ b/torchgen/_autoheuristic/README.md @@ -89,7 +89,7 @@ context = AHContext() context.add_feature("m", mat1.shape[0]) context.add_feature("k", mat1.shape[1]) -# adding a categorical feture +# adding a categorical feature context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) ``` diff --git a/torchgen/_autoheuristic/benchmark_utils.py b/torchgen/_autoheuristic/benchmark_utils.py index f0161065a3a093..ad75c6715dd773 100644 --- a/torchgen/_autoheuristic/benchmark_utils.py +++ b/torchgen/_autoheuristic/benchmark_utils.py @@ -18,7 +18,7 @@ def transpose_tensors(p_transpose_both: float = 0.05) -> tuple[bool, bool]: def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any: threshold_memory = torch.cuda.get_device_properties(0).total_memory / 4 - # dividing by 4 beause we otherwise sometimes run out of memory, I assume because + # dividing by 4 because we otherwise sometimes run out of memory, I assume because # inductor creates copies of tensors for benchmarking? return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory diff --git a/torchgen/_autoheuristic/collect_data.sh b/torchgen/_autoheuristic/collect_data.sh index 442f6120327f5b..73b6364829b9b3 100644 --- a/torchgen/_autoheuristic/collect_data.sh +++ b/torchgen/_autoheuristic/collect_data.sh @@ -1,6 +1,6 @@ #!/bin/bash -# this script makes it easy parallize collecting data across using multiple GPUs +# This script makes it easy to parallelize data collection across multiple GPUs # Check if tmux is installed if ! command -v tmux &> /dev/null; then diff --git a/torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py b/torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py index 48dfa788977d9f..c6cde6f814479a 100644 --- a/torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py +++ b/torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py @@ -15,7 +15,7 @@ ) import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] @@ -59,7 +59,7 @@ def run_benchmark( ) b = b.to(dtype=dtype_right) - with fresh_inductor_cache(): + with fresh_cache(): def mixed_mm(A, B): return torch.mm(A, B.to(A.dtype)) diff --git a/torchgen/_autoheuristic/mixed_mm/generate_heuristic_mixedmm.sh b/torchgen/_autoheuristic/mixed_mm/generate_heuristic_mixedmm.sh index dd6ac78e9dfbce..27a671511beac7 100644 --- a/torchgen/_autoheuristic/mixed_mm/generate_heuristic_mixedmm.sh +++ b/torchgen/_autoheuristic/mixed_mm/generate_heuristic_mixedmm.sh @@ -12,7 +12,7 @@ MODE=$1 # !!! SPECIFY THE GPUs THAT YOU WANT TO USE HERE !!! GPU_DEVICE_IDS="4,5" -# !!! SPECIFY THE CONDA ENVIRONEMNT THAT YOU WANT TO BE ACTIVATED HERE !!! +# !!! SPECIFY THE CONDA ENVIRONMENT THAT YOU WANT TO BE ACTIVATED HERE !!! CONDA_ENV=heuristic-pr NUM_SAMPLES=2000 diff --git a/torchgen/_autoheuristic/mm/gen_data_mm.py b/torchgen/_autoheuristic/mm/gen_data_mm.py index 8ad6dc1c008d16..b614125d9c908b 100644 --- a/torchgen/_autoheuristic/mm/gen_data_mm.py +++ b/torchgen/_autoheuristic/mm/gen_data_mm.py @@ -16,7 +16,7 @@ ) import torch -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] @@ -57,7 +57,7 @@ def run_benchmark( dtype_right=dtype, ) - with fresh_inductor_cache(): + with fresh_cache(): def mixed_mm(A: Any, B: Any) -> Any: return torch.mm(A, B) diff --git a/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py b/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py index d5ddc44c1b7bcf..b476bacfb67db7 100644 --- a/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py +++ b/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py @@ -18,7 +18,7 @@ from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found] get_alignment_size_dtype, ) -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.utils import fresh_cache class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] @@ -74,7 +74,7 @@ def run_benchmark( print(f"transpose_left={transpose_left} transpose_right={transpose_right}") print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}") - with fresh_inductor_cache(): + with fresh_cache(): def mm(a: Any, b: Any) -> Any: return torch.mm(a, b) diff --git a/torchgen/_autoheuristic/pad_mm/generate_heuristic_pad_mm.sh b/torchgen/_autoheuristic/pad_mm/generate_heuristic_pad_mm.sh index d7cb6b99164c2c..b7dac53179befa 100644 --- a/torchgen/_autoheuristic/pad_mm/generate_heuristic_pad_mm.sh +++ b/torchgen/_autoheuristic/pad_mm/generate_heuristic_pad_mm.sh @@ -12,7 +12,7 @@ MODE=$1 # !!! SPECIFY THE GPUs THAT YOU WANT TO USE HERE !!! GPU_DEVICE_IDS="4,5" -# !!! SPECIFY THE CONDA ENVIRONEMNT THAT YOU WANT TO BE ACTIVATED HERE !!! +# !!! SPECIFY THE CONDA ENVIRONMENT THAT YOU WANT TO BE ACTIVATED HERE !!! CONDA_ENV=heuristic-pr NUM_SAMPLES=2000 diff --git a/torchgen/_autoheuristic/train_decision.py b/torchgen/_autoheuristic/train_decision.py index f27a30b48fb5bb..932baf16e84507 100644 --- a/torchgen/_autoheuristic/train_decision.py +++ b/torchgen/_autoheuristic/train_decision.py @@ -94,7 +94,7 @@ def get_allowed_wrong_prediction_pct(self): def get_grid_search_values(self): """ - Standard values for grid search. Can be overriden. + Standard values for grid search. Can be overridden. """ return { "max_depth": [5, 6, 7], diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 958e68e3e25d1f..e006ee830ebac7 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -1,163 +1,177 @@ -# Be extra careful when you edit this file, because it affects AOTInductor ABI compatbility. See +# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See # https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 # for details. # # The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py. +# # Generally speaking, it is ok to add a new op to the list, but you need to run # `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files. # But it is NOT ok to remove an existing fallback op from the list, since that will break # some existing AOTInductor-compiled models. -inductor_fallback_ops = { - "aten._adaptive_avg_pool2d_backward.default", - "aten._adaptive_avg_pool2d.default", - "aten._adaptive_avg_pool3d_backward.default", - "aten._adaptive_avg_pool3d.default", - "aten._addmm_activation.default", - "aten._cdist_backward.default", - "aten._cdist_forward.default", - "aten._cudnn_rnn.default", - "aten._dyn_quant_matmul_4bit.default", - "aten._dyn_quant_pack_4bit_weight.default", - "aten._efficient_attention_backward.default", - "aten._efficient_attention_forward.default", - "aten._efficientzerotensor.default", - "aten._embedding_bag_dense_backward.default", - "aten._embedding_bag_forward_only.default", - "aten._embedding_bag_per_sample_weights_backward.default", - "aten._embedding_bag.default", - "aten._fft_c2c.default", - "aten._fft_r2c.default", - "aten._flash_attention_backward.default", - "aten._flash_attention_forward.default", - "aten._fused_moving_avg_obs_fq_helper_functional.default", - "aten._fused_moving_avg_obs_fq_helper.default", - "aten._histogramdd_from_bin_cts.default", - "aten._int_mm.out", - "aten._pdist_backward.default", - "aten._pdist_forward.default", - "aten._scaled_dot_product_cudnn_attention_backward.default", - "aten._scaled_dot_product_cudnn_attention.default", - "aten._scaled_dot_product_efficient_attention_backward.default", - "aten._scaled_dot_product_efficient_attention.default", - "aten._scaled_dot_product_flash_attention_backward.default", - "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", - "aten._scaled_dot_product_flash_attention_for_cpu.default", - "aten._scaled_dot_product_flash_attention.default", - "aten._scaled_dot_product_fused_attention_overrideable_backward.default", - "aten._scaled_dot_product_fused_attention_overrideable.default", - "aten._scaled_mm.default", - "aten._scaled_mm.out", - "aten._segment_reduce_backward.default", - "aten._thnn_fused_lstm_cell.default", - "aten._to_sparse.default", - "aten._trilinear.default", - "aten._weight_int4pack_mm.default", - "aten._weight_int8pack_mm.default", - "aten.abs.default", - "aten.adaptive_max_pool2d_backward.default", - "aten.adaptive_max_pool2d.default", - "aten.adaptive_max_pool3d_backward.default", - "aten.adaptive_max_pool3d.default", - "aten.add.Scalar", - "aten.add.Tensor", - "aten.addbmm.default", - "aten.addmm.out", - "aten.addmv.default", - "aten.angle.default", - "aten.avg_pool2d_backward.default", - "aten.avg_pool2d.default", - "aten.avg_pool3d_backward.default", - "aten.avg_pool3d.default", - "aten.baddbmm.out", - "aten.bernoulli_.float", - "aten.bernoulli_.Tensor", - "aten.bmm.out", - "aten.bucketize.Tensor", - "aten.cat.default", - "aten.cholesky_inverse.default", - "aten.cholesky_solve.default", - "aten.convolution_backward.default", - "aten.convolution.default", - "aten.cummax.default", - "aten.cummin.default", - "aten.cumprod.default", - "aten.cumsum.default", - "aten.exponential.default", - "aten.fractional_max_pool2d_backward.default", - "aten.fractional_max_pool2d.default", - "aten.fractional_max_pool3d_backward.default", - "aten.fractional_max_pool3d.default", - "aten.gcd.default", - "aten.geqrf.default", - "aten.grid_sampler_2d_backward.default", - "aten.hann_window.default", - "aten.histc.default", - "aten.histogram.bin_ct", - "aten.index_put.default", - "aten.index_reduce.default", - "aten.index.Tensor", - "aten.kthvalue.default", - "aten.logcumsumexp.default", - "aten.lu_unpack.default", - "aten.masked_scatter_backward.default", - "aten.masked_scatter.default", - "aten.masked_select.default", - "aten.max_pool2d_with_indices_backward.default", - "aten.max_pool2d_with_indices.default", - "aten.max_pool3d_with_indices_backward.default", - "aten.max_pool3d_with_indices.default", - "aten.max_unpool2d.default", - "aten.max_unpool3d.default", - "aten.median.default", - "aten.mm.out", - "aten.mode.default", - "aten.mul.Scalar", - "aten.mul.Tensor", - "aten.nanmedian.default", - "aten.native_dropout.default", - "aten.nonzero.default", - "aten.normal_functional.default", - "aten.ormqr.default", - "aten.permute.default", - "aten.polar.default", - "aten.pow.Scalar", - "aten.pow.Tensor_Scalar", - "aten.pow.Tensor_Tensor", - "aten.rand.default", - "aten.rand.generator", - "aten.randint.default", - "aten.randint.generator", - "aten.randint.low_out", - "aten.randint.low", - "aten.randn.default", - "aten.randn.generator", - "aten.randperm.default", - "aten.repeat_interleave.Tensor", - "aten.replication_pad1d_backward.default", - "aten.replication_pad2d_backward.default", - "aten.reshape.default", - "aten.resize_.default", - "aten.resize_as_.default", - "aten.scatter_reduce.two_out", - "aten.scatter.src_out", - "aten.scatter.value_out", - "aten.searchsorted.Scalar", - "aten.searchsorted.Tensor", - "aten.segment_reduce.default", - "aten.set_.source_Tensor", - "aten.slice.Tensor", - "aten.soft_margin_loss_backward.default", - "aten.sort.default", - "aten.sort.stable", - "aten.squeeze.dim", - "aten.to_sparse.default", - "aten.topk.default", - "aten.triangular_solve.default", - "aten.uniform.default", - "aten.upsample_bicubic2d_backward.default", - "aten.upsample_linear1d_backward.default", - "aten.upsample_trilinear3d_backward.default", - "aten.view_as_complex.default", - "aten.view_as_real.default", - "aten.view.dtype", +# +# A fallback op version defaults to 1. If you want to extend an existing fallback op by adding +# a new argument with a default value, while it is fine in the Python world, it will be BC-breaking +# when generating C shim. Thus you need to bump up the version number of that fallback op by +# updating the entry in the inductor_fallback_ops list, adding a new version number with a list +# of new arguments, and then run `python torchgen/gen.py --update-aoti-c-shim` to regenerate. + +inductor_fallback_ops: dict[str, dict[str, list[str]]] = { + "aten._adaptive_avg_pool2d_backward.default": {}, + "aten._adaptive_avg_pool2d.default": {}, + "aten._adaptive_avg_pool3d_backward.default": {}, + "aten._adaptive_avg_pool3d.default": {}, + "aten._addmm_activation.default": {}, + "aten._cdist_backward.default": {}, + "aten._cdist_forward.default": {}, + "aten._cudnn_rnn.default": {}, + "aten._dyn_quant_matmul_4bit.default": {}, + "aten._dyn_quant_pack_4bit_weight.default": {}, + "aten._efficient_attention_backward.default": {}, + "aten._efficient_attention_forward.default": {}, + "aten._efficientzerotensor.default": {}, + "aten._embedding_bag_dense_backward.default": {}, + "aten._embedding_bag_forward_only.default": {}, + "aten._embedding_bag_per_sample_weights_backward.default": {}, + "aten._embedding_bag.default": {}, + "aten._fft_c2c.default": {}, + "aten._fft_r2c.default": {}, + "aten._flash_attention_backward.default": {}, + "aten._flash_attention_forward.default": {}, + "aten._fused_moving_avg_obs_fq_helper_functional.default": {}, + "aten._fused_moving_avg_obs_fq_helper.default": {}, + "aten._fused_rms_norm.default": {}, + "aten._histogramdd_from_bin_cts.default": {}, + "aten._int_mm.out": {}, + "aten._pdist_backward.default": {}, + "aten._pdist_forward.default": {}, + "aten._scaled_dot_product_attention_math_for_mps.default": {}, + "aten._scaled_dot_product_cudnn_attention_backward.default": {}, + "aten._scaled_dot_product_cudnn_attention.default": {}, + "aten._scaled_dot_product_efficient_attention_backward.default": {}, + "aten._scaled_dot_product_efficient_attention.default": {}, + "aten._scaled_dot_product_flash_attention_backward.default": {}, + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default": {}, + "aten._scaled_dot_product_flash_attention_for_cpu.default": {}, + "aten._scaled_dot_product_flash_attention.default": {}, + "aten._scaled_dot_product_fused_attention_overrideable_backward.default": {}, + "aten._scaled_dot_product_fused_attention_overrideable.default": {}, + "aten._scaled_mm.default": {}, + "aten._scaled_mm.out": {}, + "aten._segment_reduce_backward.default": {}, + "aten._thnn_fused_lstm_cell.default": {}, + "aten._to_sparse.default": {}, + "aten._trilinear.default": {}, + "aten._weight_int4pack_mm.default": {}, + "aten._weight_int8pack_mm.default": {}, + "aten.abs.default": {}, + "aten.adaptive_max_pool2d_backward.default": {}, + "aten.adaptive_max_pool2d.default": {}, + "aten.adaptive_max_pool3d_backward.default": {}, + "aten.adaptive_max_pool3d.default": {}, + "aten.add.Scalar": {}, + "aten.add.Tensor": {}, + "aten.addbmm.default": {}, + "aten.addmm.out": {}, + "aten.addmv.default": {}, + "aten.angle.default": {}, + "aten.avg_pool2d_backward.default": {}, + "aten.avg_pool2d.default": {}, + "aten.avg_pool3d_backward.default": {}, + "aten.avg_pool3d.default": {}, + "aten.baddbmm.out": {}, + "aten.bernoulli_.float": {}, + "aten.bernoulli_.Tensor": {}, + "aten.bmm.out": {}, + "aten.bucketize.Tensor": {}, + "aten.cat.default": {}, + "aten.cholesky_inverse.default": {}, + "aten.cholesky_solve.default": {}, + "aten.convolution_backward.default": {}, + "aten.convolution.default": {}, + "aten.cummax.default": {}, + "aten.cummin.default": {}, + "aten.cumprod.default": {}, + "aten.cumsum.default": {}, + "aten.exponential.default": {}, + "aten.fill_.Scalar": {}, + "aten.fractional_max_pool2d_backward.default": {}, + "aten.fractional_max_pool2d.default": {}, + "aten.fractional_max_pool3d_backward.default": {}, + "aten.fractional_max_pool3d.default": {}, + "aten.gcd.default": {}, + "aten.geqrf.default": {}, + "aten.grid_sampler_2d_backward.default": {}, + "aten.hann_window.default": {}, + "aten.histc.default": {}, + "aten.histogram.bin_ct": {}, + "aten.index_put.default": {}, + "aten.index_reduce.default": {}, + "aten.index.Tensor": {}, + "aten.kthvalue.default": {}, + "aten.logcumsumexp.default": {}, + "aten.lu_unpack.default": {}, + "aten.masked_scatter_backward.default": {}, + "aten.masked_scatter.default": {}, + "aten.masked_select.default": {}, + "aten.max_pool2d_with_indices_backward.default": {}, + "aten.max_pool2d_with_indices.default": {}, + "aten.max_pool3d_with_indices_backward.default": {}, + "aten.max_pool3d_with_indices.default": {}, + "aten.max_unpool2d.default": {}, + "aten.max_unpool3d.default": {}, + "aten.median.default": {}, + "aten.mm.out": {}, + "aten.mode.default": {}, + "aten.mul.Scalar": {}, + "aten.mul.Tensor": {}, + "aten.nanmedian.default": {}, + "aten.narrow.default": {}, + "aten.native_dropout.default": {}, + "aten.nonzero.default": {}, + "aten.normal_functional.default": {}, + "aten.ormqr.default": {}, + "aten.pad.default": {}, + "aten.permute.default": {}, + "aten.polar.default": {}, + "aten.pow.Scalar": {}, + "aten.pow.Tensor_Scalar": {}, + "aten.pow.Tensor_Tensor": {}, + "aten.rand.default": {}, + "aten.rand.generator": {}, + "aten.randint.default": {}, + "aten.randint.generator": {}, + "aten.randint.low_out": {}, + "aten.randint.low": {}, + "aten.randn.default": {}, + "aten.randn.generator": {}, + "aten.randperm.default": {}, + "aten.repeat_interleave.Tensor": {}, + "aten.replication_pad1d_backward.default": {}, + "aten.replication_pad2d_backward.default": {}, + "aten.reshape.default": {}, + "aten.resize_.default": {}, + "aten.resize_as_.default": {}, + "aten.scatter_reduce.two_out": {}, + "aten.scatter.src_out": {}, + "aten.scatter.value_out": {}, + "aten.searchsorted.Scalar": {}, + "aten.searchsorted.Tensor": {}, + "aten.segment_reduce.default": {}, + "aten.set_.source_Tensor": {}, + "aten.slice.Tensor": {}, + "aten.soft_margin_loss_backward.default": {}, + "aten.sort.default": {}, + "aten.sort.stable": {}, + "aten.squeeze.dim": {}, + "aten.to_sparse.default": {}, + "aten.topk.default": {}, + "aten.triangular_solve.default": {}, + "aten.uniform.default": {}, + "aten.upsample_bicubic2d_backward.default": {}, + "aten.upsample_linear1d_backward.default": {}, + "aten.upsample_trilinear3d_backward.default": {}, + "aten.view_as_complex.default": {}, + "aten.view_as_real.default": {}, + "aten.view.dtype": {}, + "aten._weight_int4pack_mm_with_scales_and_zeros.default": {}, } diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 3f3b793825c9cb..96e192d3a48a9c 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -93,7 +93,7 @@ class ForwardDerivative: # This is only used by inplace operations required_original_self_value: bool - # If this formula is specified in derivatives.yaml or if we are re-using the + # If this formula is specified in derivatives.yaml or if we are reusing the # out of place formula for inplace is_reusing_outplace_formula: bool @@ -632,7 +632,7 @@ def find_info( info_dict = non_functional_info_by_signature[f_sig] # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 assert not any( - any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs) + any("self" in str(input.nctype.name) for input in info.all_saved_inputs) for info in info_dict.values() ), f"""\ Attempted to convert a derivative formula for a mutable operator @@ -699,7 +699,7 @@ def find_info( # we make sure that the original value of the input that is being modified inplace (self_p) is # not used in the formula. Note that the formula can use "original_self_p" here and that would # trigger a clone of the original input. - # - If we are re-using the out of place formula (is_exact_match == False) then we replace every + # - If we are reusing the out of place formula (is_exact_match == False) then we replace every # occurrence of self_p and self_t by original_self_p and original_self_t. These will be # populated by cloned version of the original input (either the clone done by the backward AD # logic if self is also used in a backward formula or a special clone that we add). diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index c619ec45d2f898..862cef30dba49f 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -127,7 +127,7 @@ def valuetype_type( # Translation of types occurring in JIT arguments to a C++ argument type. -# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type. +# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type. # For example, we'll return std::vector instead of IntArrayRef. # See Note [translation from C++ reference to value types] def argumenttype_type( diff --git a/torchgen/api/python.py b/torchgen/api/python.py index de6134bee6b86b..dbfa7306016305 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -212,17 +212,17 @@ def format_function_signature( if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",): return sig - arguments = [f" {arg}," for arg in arguments] - return "\n".join( - ( - f"def {name}(", - *( - arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on" - for arg in arguments - ), - f"){return_type}: ...", - ) - ).replace(" # fmt: off\n # fmt: on\n", "") + lines = [ + f"def {name}(", + *(f" {arg}," for arg in arguments), + f"){return_type}: ...", + ] + sig = "\n".join(lines) + if all(len(line) <= 80 for line in lines): + return sig + # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658 + # use `skip` instead of `on` + `off` + return sig.removesuffix(" ...") + " # fmt: skip\n ..." @dataclass(frozen=True) diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index 384eeeb8e483fc..b3856e65e7003c 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -49,7 +49,7 @@ class CppSignature: # Is this a fallback C++ binding? Fallback bindings are enabled by # manual_cpp_binding: True and are alternate, non-public API that - # lets manual C++ binding implementors access the binding that would + # lets manual C++ binding implementers access the binding that would # have been automatically generated fallback_binding: bool = False diff --git a/torchgen/build.bzl b/torchgen/build.bzl index 2ec68955df9161..50765869f8d5d9 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -18,13 +18,3 @@ def define_targets(rules): rules.requirement("typing-extensions"), ], ) - - rules.py_binary( - name = "gen_executorch", - srcs = [":torchgen"], - visibility = ["//visibility:public"], - deps = [ - rules.requirement("PyYAML"), - rules.requirement("typing-extensions"), - ], - ) diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py index b1488b4f1887be..05e252d09f9c16 100644 --- a/torchgen/dest/native_functions.py +++ b/torchgen/dest/native_functions.py @@ -12,7 +12,7 @@ def torch_api_key_word_prefix(bankend_index: BackendIndex) -> str: if bankend_index.external: return "" - # Although Intel GPU ATen library is out-of-tree, it still utilizes torchgen to produce structrued + # Although Intel GPU ATen library is out-of-tree, it still utilizes torchgen to produce structured # kernels. Regarding these produced structured kernels, they should be visible for the Intel GPU ATen # library. Therefore, we need to add "TORCH_XPU_API" prefix to these structured kernels, # rather than "TORCH_API". Because the semantic of "TORCH_API" is "hidden" for out-of-tree backends. diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index ffe90bcaba85de..52bb9602a73f05 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -764,7 +764,7 @@ def gen_one(self, f: NativeFunction) -> str | None: # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace # based on the out implementation. But in fact, out is definable by # functional too (just not very efficiently), and this is honestly the - # MORE likely situation for a backend implementor. How do we pick? + # MORE likely situation for a backend implementer. How do we pick? # Well, taking a page from Haskell type classes and default methods, # we could conceivably register a circular definition (out in terms # of functional, and functional in terms of out) and just require @@ -777,7 +777,7 @@ def gen_one(self, f: NativeFunction) -> str | None: and f.func.kind() is SchemaKind.out ): # Never generate a default implementation for out, that's what you - # have to define as a backend implementor + # have to define as a backend implementer return None # Note [Direct dispatch bindings] diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 832316d018e851..045d8de110e744 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -42,7 +42,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # NB: not bothering to generate dispatch stub forward declaration in header, -# we can just paste it whereever necessary +# we can just paste it wherever necessary # TODO: use BackendIndex # dispatch_key: DispatchKey # only CPU/CUDA right now diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py deleted file mode 100644 index 641b7e9c94103c..00000000000000 --- a/torchgen/executorch/api/custom_ops.py +++ /dev/null @@ -1,151 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from torchgen import dest - - -# disable import sorting to avoid circular dependency. -from torchgen.api.types import DispatcherSignature # usort: skip -from torchgen.context import method_with_native_function -from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant -from torchgen.utils import concatMap, Target - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from torchgen.executorch.model import ETKernelIndex - from torchgen.selective_build.selector import SelectiveBuilder - - -# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at -# model authoring side. -@dataclass(frozen=True) -class ComputeNativeFunctionStub: - @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: - if Variant.function not in f.variants: - return None - - sig = DispatcherSignature.from_schema( - f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False - ) - assert sig is not None - if len(f.func.returns) == 0: - ret_name = "" - elif len(f.func.returns) == 1: - if f.func.arguments.out: - ret_name = f.func.arguments.out[0].name - else: - ret_name = next( - ( - a.name - for a in f.func.arguments.flat_non_out - if a.type == f.func.returns[0].type - ), - "", - ) - if not ret_name: - # if return type is tensor - if f.func.returns[0].type == BaseType(BaseTy.Tensor): - # Returns an empty tensor - ret_name = "at::Tensor()" - else: - raise Exception( # noqa: TRY002 - f"Can't handle this return type {f.func}" - ) # noqa: TRY002 - elif len(f.func.arguments.out) == len(f.func.returns): - # Returns a tuple of out arguments - tensor_type = "at::Tensor &" - comma = ", " - ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( - {comma.join([r.name for r in f.func.arguments.out])} - )""" - else: - assert all(a.type == BaseType(BaseTy.Tensor) for a in f.func.returns), ( - f"Only support tensor returns but got {f.func.returns}" - ) - # Returns a tuple of empty tensors - tensor_type = "at::Tensor" - comma = ", " - ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( - {comma.join(["at::Tensor()" for _ in f.func.returns])} - )""" - ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else "" - return f""" -{sig.defn()} {{ - {ret_str} -}} - """ - - -def gen_custom_ops_registration( - *, - native_functions: Sequence[NativeFunction], - selector: SelectiveBuilder, - kernel_index: ETKernelIndex, - rocm: bool, -) -> tuple[str, str]: - """ - Generate custom ops registration code for dest.RegisterDispatchKey. - - :param native_functions: a sequence of `NativeFunction` - :param selector: for selective build. - :param kernel_index: kernels for all the ops. - :param rocm: bool for dest.RegisterDispatchKey. - :return: generated C++ code to register custom operators into PyTorch - """ - - # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. - # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. - - dispatch_key = DispatchKey.CPU - backend_index = kernel_index._to_backend_index() - static_init_dispatch_registrations = "" - ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) - for native_function in native_functions: - ns_grouped_native_functions[native_function.namespace].append(native_function) - - for namespace, functions in ns_grouped_native_functions.items(): - if len(functions) == 0: - continue - dispatch_registrations_body = "\n".join( - list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.REGISTRATION, - selector, - rocm=rocm, - symint=False, - class_method_name=None, - skip_dispatcher_op_registration=False, - ), - functions, - ) - ) - ) - static_init_dispatch_registrations += f""" -TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ -{dispatch_registrations_body} -}}""" - anonymous_definition = "\n".join( - list( - concatMap( - dest.RegisterDispatchKey( - backend_index, - Target.ANONYMOUS_DEFINITION, - selector, - rocm=rocm, - symint=False, - class_method_name=None, - skip_dispatcher_op_registration=False, - ), - native_functions, - ) - ) - ) - return anonymous_definition, static_init_dispatch_registrations diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py deleted file mode 100644 index 081a3d4ece1c5b..00000000000000 --- a/torchgen/executorch/api/et_cpp.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing_extensions import assert_never - -from torchgen import local -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ConstRefCType, - CType, - MutRefCType, - NamedCType, - SpecialArgName, - TupleCType, - VectorCType, - voidT, -) -from torchgen.executorch.api.types import ( - ArrayRefCType, - BaseTypeToCppMapping, - OptionalCType, - scalarT, - tensorListT, - tensorT, -) -from torchgen.model import ( - Argument, - Arguments, - BaseTy, - BaseType, - ListType, - NativeFunction, - OptionalType, - Return, - SelfArgument, - TensorOptionsArguments, - Type, -) - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -""" -This file describes the translation of JIT schema to the public C++ API, which is what people use when they call -functions like at::add. It also serves as a native function API, which is the signature of kernels, -since in Executorch CppSignature is the same as NativeSignature. - -Difference between this file and torchgen.api.cpp.py: - - - Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with - torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch). - - - Executorch doesn't support Dimname. - - - Executorch runtime doesn't support SymInt, will treat it as int. -""" - - -# Translation of "value types" in JIT schema to C++ API type. Value -# types look the same no matter if they are argument types or return -# types. Returns None if the type in question is not a value type. -def valuetype_type( - t: Type, - *, - binds: ArgName, -) -> NamedCType | None: - if isinstance(t, BaseType): - if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: - return None - # For SymInt we simply treat it as int. - elif str(t) == "SymInt": - return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int])) - # All other BaseType currently map directly to BaseCppTypes. - return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) - elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem, binds=binds) - if elem is None: - return None - return NamedCType(binds, OptionalCType(elem.type)) - elif isinstance(t, ListType): - if str(t.elem) == "bool": - assert t.size is not None - return NamedCType( - binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool])) - ) - else: - return None - else: - raise AssertionError(f"unrecognized type {repr(t)}") - - -# Translation of types occurring in JIT arguments to a C++ argument type. -# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type. -# For example, we'll return std::vector instead of IntArrayRef. -# See Note [translation from C++ reference to value types] -def argumenttype_type( - t: Type, - *, - mutable: bool, - binds: ArgName, - remove_non_owning_ref_types: bool = False, -) -> NamedCType: - # If it's a value type, do the value type translation - r = valuetype_type( - t, - binds=binds, - ) - if r is not None: - return r - if isinstance(t, BaseType): - if t.name == BaseTy.Tensor: - if mutable and not local.use_const_ref_for_mutable_tensors(): - return NamedCType(binds, MutRefCType(BaseCType(tensorT))) - else: - return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) - elif t.name == BaseTy.Scalar: - return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) - else: - raise AssertionError(f"base type should have been value type {t}") - elif isinstance(t, OptionalType): - if str(t.elem) == "Tensor": - if mutable and not local.use_const_ref_for_mutable_tensors(): - return NamedCType( - binds, MutRefCType(BaseCType(tensorT)) - ) # TODO: fix this discrepancy - else: - return NamedCType( - binds, ConstRefCType(OptionalCType(BaseCType(tensorT))) - ) - elif str(t.elem) == "Scalar": - return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) - elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) - return NamedCType(binds, OptionalCType(elem.type)) - elif isinstance(t, ListType): - # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels. - if str(t.elem) == "Tensor": - return NamedCType(binds, BaseCType(tensorListT)) - elif str(t.elem) == "Dimname": - raise NotImplementedError("Executorch doesn't support Dimname") - elif str(t.elem) == "Tensor?": - return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT)))) - elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) - return NamedCType(binds, ArrayRefCType(elem.type)) - else: - raise AssertionError(f"unrecognized type {repr(t)}") - - -# Translate a JIT argument into its C++ type -def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: - return argumenttype_type(a.type, mutable=a.is_write, binds=binds) - - -# Translation of a (non-multi) return type from JIT to C++ -# N.B: returntype_type returns a CType, not a NamedCType. -# This is mostly because of the mismatch between return types and return names. -# e.g. a function with a return type of 'void' has 0 return names, -# and a function with a return type of 'std::tuple' has >1 return name. -def returntype_type(t: Type, *, mutable: bool) -> CType: - # placeholder is ignored - r = valuetype_type(t, binds="__placeholder__") - if r is not None: - return r.type - - if isinstance(t, BaseType): - if t.name == BaseTy.Tensor: - if mutable: - if local.use_const_ref_for_mutable_tensors(): - return ConstRefCType(BaseCType(tensorT)) - else: - return MutRefCType(BaseCType(tensorT)) - else: - # Note [Tensor Copy Returns] - # Currently, we use "Argument.is_write" to determine - # whether or not Tensor return types should be copies or references. - # If that ever changes, take a look at other locations of this note! - return BaseCType(tensorT) - elif t.name == BaseTy.Scalar: - return BaseCType(scalarT) - elif isinstance(t, ListType): - assert not mutable, ( - "Native functions should never return a mutable tensor list. They should return void." - ) - elem = returntype_type(t.elem, mutable=False) - assert t.size is None, f"fixed size list returns not supported: {t}" - return VectorCType(elem) - - raise AssertionError(f"unrecognized return type {t}") - - -# Translation of a single return to its C++ type -def return_type(r: Return) -> CType: - return returntype_type(r.type, mutable=r.is_write) - - -# Translation of a full (possibly multi) return from JIT to its C++ type -def returns_type(rs: Sequence[Return]) -> CType: - if len(rs) == 0: - return BaseCType(voidT) - elif len(rs) == 1: - return return_type(rs[0]) - else: - return TupleCType([return_type(r) for r in rs]) - - -def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: - returns: list[str] = [] - for i, r in enumerate(f.func.returns): - # If we have an inplace function, the return argument is - # implicitly named self. - # TODO: Consider incorporating this into the data model - if f.func.name.name.inplace: - assert i == 0, "illegal inplace function with multiple returns" - name = "self" - # If we are out function, the name is the name of the - # corresponding output function (r.name will get recorded - # in field_name later.) - elif f.func.is_out_fn(): - name = f.func.arguments.out[i].name - # If the return argument is explicitly named... - elif r.name: - name_conflict = any( - r.name == a.name for a in f.func.schema_order_arguments() - ) - if name_conflict and not f.func.is_out_fn(): - name = f"{r.name}_return" - else: - name = r.name - # If there is no explicit name and no fallback name was passed in, we just name the output result, - # unless it's a multi-return, in which case it's result0, - # result1, etc (zero-indexed) - else: - name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" - returns.append(name) - return returns - - -JIT_TO_CPP_DEFAULT = { - "False": "false", - "True": "true", - "None": "torch::execustd::nullopt", # UGH this one is type directed - "[]": "{}", - "contiguous_format": "torch::executorch::MemoryFormat::Contiguous", - "long": "torch::executorch::kLong", -} - - -# Convert a JIT default into C++ expression representing the default -def default_expr(d: str, t: Type) -> str: - if d == "None" and str(t) == "Tensor?": - return "{}" - if isinstance(t, BaseType) and t.name is BaseTy.str: - # Schema allows single quotes but C++ needs double - if len(d) >= 2 and d[0] == "'" and d[-1] == "'": - s = "" - i = 1 - while i + 1 < len(d): - if d[i] != "\\": - if d[i] == '"': - s += '\\"' - else: - s += d[i] - i += 1 - else: - if d[i + 1] == "'": - s += "'" - else: - s += d[i : i + 2] - i += 2 - - return f'"{s}"' - - if isinstance(t, OptionalType): - if d == "None": - return "torch::executor::nullopt" - - return default_expr(d, t.elem) - - if isinstance(t, ListType): - if d.startswith("[") and d.endswith("]"): - return "{" + d[1:-1] + "}" - elif t.size is None: - # NOTE: Sized lists can have scalar defaults - raise ValueError(f"Expected a list default '[...]' but found: '{d}'") - - return JIT_TO_CPP_DEFAULT.get(d, d) - - -# Convert an argument into its C++ API form - - -def argument( - a: Argument | TensorOptionsArguments | SelfArgument, - *, - cpp_no_default_args: set[str], - method: bool, - faithful: bool, - has_tensor_options: bool, -) -> list[Binding]: - def sub_argument( - a: Argument | TensorOptionsArguments | SelfArgument, - ) -> list[Binding]: - return argument( - a, - cpp_no_default_args=cpp_no_default_args, - method=method, - faithful=faithful, - has_tensor_options=has_tensor_options, - ) - - if isinstance(a, Argument): - binds: ArgName - if a.name == "memory_format" and has_tensor_options: - binds = SpecialArgName.possibly_redundant_memory_format - else: - binds = a.name - default: str | None = None - if a.name not in cpp_no_default_args and a.default is not None: - default = default_expr(a.default, a.type) - return [ - Binding( - nctype=argument_type(a, binds=binds), - name=a.name, - default=default, - argument=a, - ) - ] - elif isinstance(a, TensorOptionsArguments): - raise NotImplementedError("Need to implement type resolution for TensorOptions") - elif isinstance(a, SelfArgument): - if method: - # Caller is responsible for installing implicit this in context! - return [] - else: - return sub_argument(a.argument) - else: - assert_never(a) - - -def arguments( - arguments: Arguments, - *, - faithful: bool, - method: bool, - cpp_no_default_args: set[str], -) -> list[Binding]: - args: list[Argument | TensorOptionsArguments | SelfArgument] = [] - if faithful: - args.extend(arguments.non_out) - args.extend(arguments.out) - else: - args.extend(arguments.out) - args.extend(arguments.non_out) - return [ - r.no_default() if faithful else r - for a in args - for r in argument( - a, - faithful=faithful, - method=method, - has_tensor_options=arguments.tensor_options is not None, - cpp_no_default_args=cpp_no_default_args, - ) - ] diff --git a/torchgen/executorch/api/types/__init__.py b/torchgen/executorch/api/types/__init__.py deleted file mode 100644 index 08cb168df73716..00000000000000 --- a/torchgen/executorch/api/types/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from torchgen.executorch.api.types.types import * - - -from torchgen.executorch.api.types.signatures import * # usort: skip diff --git a/torchgen/executorch/api/types/signatures.py b/torchgen/executorch/api/types/signatures.py deleted file mode 100644 index ac3477cede6ed0..00000000000000 --- a/torchgen/executorch/api/types/signatures.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import torchgen.api.cpp as aten_cpp -from torchgen.executorch.api.types.types import contextArg - - -if TYPE_CHECKING: - from torchgen.api.types import Binding, CType - from torchgen.model import FunctionSchema, NativeFunction - - -@dataclass(frozen=True) -class ExecutorchCppSignature: - """ - This signature is merely a CppSignature with Executorch types (optionally - contains KernelRuntimeContext as well). The inline definition of - CppSignature is generated in Functions.h and it's used by unboxing - functions. - """ - - # The schema this signature is derived from - func: FunctionSchema - - # The set of C++ arguments which should not have defaults applied to them - cpp_no_default_args: set[str] - - # Allows you to prepend an arbitrary prefix to the signature name. - # This is useful for parts of the codegen that generate wrappers around kernels, - # and need to avoid naming collisions. - prefix: str = "" - - def arguments(self, *, include_context: bool = True) -> list[Binding]: - return ([contextArg] if include_context else []) + et_cpp.arguments( - self.func.arguments, - faithful=True, # always faithful, out argument at the end - method=False, # method not supported - cpp_no_default_args=self.cpp_no_default_args, - ) - - def name(self) -> str: - return self.prefix + aten_cpp.name( - self.func, - faithful_name_for_out_overloads=True, - ) - - def decl(self, name: str | None = None, *, include_context: bool = True) -> str: - args_str = ", ".join( - a.decl() for a in self.arguments(include_context=include_context) - ) - if name is None: - name = self.name() - return f"{self.returns_type().cpp_type()} {name}({args_str})" - - def defn(self, name: str | None = None) -> str: - args = [a.defn() for a in self.arguments()] - args_str = ", ".join(args) - if name is None: - name = self.name() - return f"{self.returns_type().cpp_type()} {name}({args_str})" - - def returns_type(self) -> CType: - return et_cpp.returns_type(self.func.returns) - - @staticmethod - def from_native_function( - f: NativeFunction, *, prefix: str = "" - ) -> ExecutorchCppSignature: - return ExecutorchCppSignature( - func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args - ) - - -from torchgen.executorch.api import et_cpp diff --git a/torchgen/executorch/api/types/types.py b/torchgen/executorch/api/types/types.py deleted file mode 100644 index 712d7e5e341f41..00000000000000 --- a/torchgen/executorch/api/types/types.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - -from torchgen.api.types import ( - BaseCppType, - BaseCType, - Binding, - boolT, - CType, - doubleT, - Expr, - longT, - MutRefCType, - NamedCType, -) -from torchgen.model import BaseTy - - -halfT = BaseCppType("torch::executor", "Half") -bfloat16T = BaseCppType("torch::executor", "BFloat16") -stringT = BaseCppType("torch::executor", "string_view") -scalarTypeT = BaseCppType("torch::executor", "ScalarType") -tensorT = BaseCppType("torch::executor", "Tensor") -tensorListT = BaseCppType("torch::executor", "TensorList") -scalarT = BaseCppType("torch::executor", "Scalar") -memoryFormatT = BaseCppType("torch::executor", "MemoryFormat") -intArrayRefT = BaseCppType("torch::executor", "IntArrayRef") -optionalT = BaseCppType("torch::executor", "optional") -contextT = BaseCppType("torch::executor", "KernelRuntimeContext") - -contextExpr = Expr( - expr="context", - type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))), -) - -contextArg = Binding( - name="context", - nctype=contextExpr.type, - argument=None, # type: ignore[arg-type] - default=None, -) - -BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { - BaseTy.int: longT, - BaseTy.float: doubleT, - BaseTy.bool: boolT, - BaseTy.str: stringT, - BaseTy.ScalarType: scalarTypeT, - BaseTy.Tensor: tensorT, - BaseTy.Scalar: scalarT, - BaseTy.MemoryFormat: memoryFormatT, -} - - -@dataclass(frozen=True) -class OptionalCType(CType): - elem: CType - - def cpp_type(self, *, strip_ref: bool = False) -> str: - # Do not pass `strip_ref` recursively. - return f"torch::executor::optional<{self.elem.cpp_type()}>" - - def remove_const_ref(self) -> CType: - return OptionalCType(self.elem.remove_const_ref()) - - -@dataclass(frozen=True) -class ArrayRefCType(CType): - elem: CType - - def cpp_type(self, *, strip_ref: bool = False) -> str: - # Do not pass `strip_ref` recursively. - return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>" - - def remove_const_ref(self) -> CType: - return ArrayRefCType(self.elem.remove_const_ref()) diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py deleted file mode 100644 index 6d648f715114f7..00000000000000 --- a/torchgen/executorch/api/unboxing.py +++ /dev/null @@ -1,218 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Callable, TYPE_CHECKING - -from torchgen.model import ( - Argument, - BaseTy, - BaseType, - ListType, - NativeFunction, - OptionalType, - Type, -) - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from torchgen.api.types import Binding, CType, NamedCType - - -connector = "\n\t" - - -# Return unboxing function name for a NativeFunction -def name(f: NativeFunction) -> str: - return f.func.name.unambiguous_name() - - -@dataclass(frozen=True) -class Unboxing: - """ - Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing. - A sample generated code: - // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - void mul_out(EValue** stack) { - EValue& self = *stack[0]; - EValue& other = *stack[1]; - EValue& out = *stack[2]; - const torch::executor::Tensor & self_base = self.to(); - const torch::executor::Tensor & other_base = other.to(); - torch::executor::Tensor & out_base = out.to(); - - EXECUTORCH_SCOPE_PROF("native_call_mul.out"); - torch::executor::mul_outf(self_base, other_base, out_base); - - - } - """ - - # this is a callable that converts a JIT argument, into its C++ type. - # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type. - argument_type_gen: Callable[ - ..., - NamedCType, - ] - - # Convert all the arguments in a NativeFunction to C++ code - def convert_arguments( - self, args: Sequence[Binding] - ) -> tuple[list[Binding], list[str]]: - code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))] - binding_list = [] - for arg in args: - # expecting only Argument - if not isinstance(arg.argument, Argument): - raise Exception( # noqa: TRY002 - f"Unexpected argument type, expecting `Argument` but got {arg}" - ) - argument: Argument = arg.argument - unboxed_name, _, code, decl = self.argumenttype_evalue_convert( - argument.type, argument.name, mutable=argument.is_write - ) - code_list.extend(decl) - code_list.extend(code) - binding_list.append(arg.with_name(unboxed_name)) - return binding_list, code_list - - def argumenttype_evalue_convert( - self, t: Type, arg_name: str, *, mutable: bool = False - ) -> tuple[str, CType, list[str], list[str]]: - """ - Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: - (1) the C++ code necessary to unbox the argument - (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType - :param t: a `Type` of an argument - :param arg_name: argument name - :param mutable: boolean for whether this argument type is mutable - :return: unboxed result - """ - ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type - - if isinstance(t, BaseType): - out_name = f"{arg_name}_base" - code, decl = self._gen_code_base_type( - arg_name=arg_name, out_name=out_name, ctype=ctype - ) - elif isinstance(t, OptionalType): - out_name = f"{arg_name}_opt_out" - code, decl = self._gen_code_optional_type( - arg_name=arg_name, out_name=out_name, t=t, ctype=ctype - ) - elif isinstance(t, ListType): - out_name = f"{arg_name}_list_out" - code, decl = self._gen_code_list_type( - arg_name=arg_name, out_name=out_name, t=t, ctype=ctype - ) - else: - raise Exception( # noqa: TRY002 - f"Cannot handle type {t}. arg_name: {arg_name}" - ) # noqa: TRY002 - return out_name, ctype, code, decl - - def _gen_code_base_type( - self, arg_name: str, out_name: str, ctype: CType - ) -> tuple[list[str], list[str]]: - return [ - f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" - ], [] - - def _gen_code_optional_type( - self, arg_name: str, out_name: str, t: OptionalType, ctype: CType - ) -> tuple[list[str], list[str]]: - in_name = f"{arg_name}_opt_in" - res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( - t.elem, in_name - ) - return ( - f""" - auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>(); - """.split("\n"), - decl, - ) - - def _gen_code_list_type( - self, arg_name: str, out_name: str, t: ListType, ctype: CType - ) -> tuple[list[str], list[str]]: - in_name = f"{arg_name}_list_in" - elem_name = f"{arg_name}_elem" - code = [] - res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert( - t.elem, elem_name - ) - - if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor: - code.extend( - f""" - auto {out_name} = {arg_name}.toTensorList(); - """.split("\n") - ) - elif isinstance(t.elem, BaseType) and ( - t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt - ): - code.extend( - f""" - auto {out_name} = {arg_name}.toIntList(); - """.split("\n") - ) - elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float: - code.extend( - f""" - auto {out_name} = {arg_name}.toDoubleList(); - """.split("\n") - ) - elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool: - # handle list type with size, e.g., bool[4] - code.extend( - f""" -#ifdef USE_ATEN_LIB -std::array {out_name}; -auto {in_name} = {arg_name}.toBoolList(); -size_t _i = 0; -for (auto {elem_name}: {in_name}) {{ - {out_name}[_i++] = {elem_name}; -}} -#else -auto {out_name} = {arg_name}.toBoolList(); -#endif - """.split("\n") - ) - # pytorch codegen: - # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> - elif ( - isinstance(t.elem, OptionalType) - and isinstance(t.elem.elem, BaseType) - and t.elem.elem.name == BaseTy.Tensor - ): - code.extend( - f""" -#ifdef USE_ATEN_LIB -auto {in_name} = {arg_name}.toListOptionalTensor(); -c10::List<::std::optional> {out_name}; -for (auto {elem_name}: {in_name}) {{ - {out_name}.push_back({elem_name}); -}} -#else -auto {out_name} = {arg_name}.toListOptionalTensor(); -#endif - """.split("\n") - ) - else: - # use ArrayRef as default. - vec_name = arg_name + "_vec" - # need to bring vector instantiation out of scope so that ArrayRef has valid data - decl.append( - f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};" - ) - code.extend( - f""" - for (EValue {elem_name}: {in_name}) {{ - {connector.join(res_code)} - {vec_name}.push_back({res_name}); - }} - {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); - """.split("\n") - ) - return code, decl diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py deleted file mode 100644 index 310c5968ec0d97..00000000000000 --- a/torchgen/executorch/model.py +++ /dev/null @@ -1,220 +0,0 @@ -# Represents all kernels used by an Executorch model. -# It maintains a dict[OperatorName, dict[ETKernelKey, BackendMetadata]] structure. - -from __future__ import annotations - -import itertools -from collections import defaultdict, namedtuple -from dataclasses import dataclass -from enum import IntEnum -from typing_extensions import assert_never - -from torchgen.model import ( - BackendIndex, - BackendMetadata, - DispatchKey, - NativeFunction, - NativeFunctionsGroup, - OperatorName, -) - - -KERNEL_KEY_VERSION = 1 - - -# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen -class ScalarType(IntEnum): - Byte = 0 - Char = 1 - Short = 2 - Int = 3 - Long = 4 - Float = 6 - Double = 7 - Bool = 11 - - -ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"]) - - -@dataclass(frozen=True) -class ETKernelKeyOpArgMeta: - arg_name: str - dtype: str - # The order of the dimensions if entry is a Tensor - dim_order: tuple[int, ...] - - def to_native_string(self) -> str: - dtype_str = ScalarType[self.dtype].value - dim_str = str(self.dim_order)[1:-1].replace(" ", "") - return f"{dtype_str};{dim_str}" - - -@dataclass(frozen=True) -class ETKernelKey: - # Field undefined is default = True - arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = () - - # Indicator for this kernel being used as a catch all - default: bool = False - - version: int = KERNEL_KEY_VERSION - - @staticmethod - def gen_from_yaml( - args: dict[str, tuple[str, str]], - type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val - dim_order_alias_map: dict[str, list[int]], - ) -> list[ETKernelKey]: - """Generate ETKernelKeys from arg kernel specs - Multiple ETKernelKeys are returned due to dtype permutations from utilizing - type_alias_map (actualizing each potential type permutation as a KernelKey) - - Args: - args: Mapping from argument name to kernel specs - Kernel specs are a tuple of (dtype, dim_order). - Currently tuple entries must be aliased via the alias map arguments - type_alias_map: Mapping from type alias to potential type enums - i.e { T0 : [Double, Int] } means T0 can be either Double or Int - Used for lookup by args - dim_order_alias_map: Mapping from alias to a list of dimension orders - Used for lookup by args - """ - # Cast to dim order to int - dim_order_alias_map = { - k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items() - } - kernel_keys = [] - - # Get all used Dtype Alias - dtype_alias_used = set() - for type_alias, dim_order in args.values(): - # Enforce usage of alias initially - # TODO: Support inlined arguments - assert type_alias in type_alias_map, "Undefined type alias: " + str( - type_alias - ) - assert dim_order in dim_order_alias_map, ( - f"Undefined dim_order alias: {dim_order}" - ) - dtype_alias_used.add(type_alias) - - # Generate all permutations of dtype alias values - alias_dtypes = [ - [(alias, dtype) for dtype in type_alias_map[alias]] - for alias in dtype_alias_used - ] - alias_permutations = [ - dict(permutation) for permutation in list(itertools.product(*alias_dtypes)) - ] - - # Using each alias value permutation, generate kernel keys - op_arg_cache = {} - for permutation in alias_permutations: - arg_list = [] - for arg_name, arg_spec in args.items(): - dtype = permutation[arg_spec[0]] - dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment] - if ( - cache_key := (arg_name, dtype, tuple(dim_order)) - ) not in op_arg_cache: - op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type] - - arg_list.append(op_arg_cache[cache_key]) - kernel_keys.append(ETKernelKey(tuple(arg_list))) - - return kernel_keys - - def to_native_string(self) -> str: - if self.default: - return "default" - return ( - "v" - + str(KERNEL_KEY_VERSION) - + "/" - + "|".join([arg.to_native_string() for arg in self.arg_meta]) - ) - - -@dataclass(frozen=True) -class ETKernelIndex: - index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] - - def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool: - m = self.get_kernels(g) - return m is not None - - def get_kernels( - self, g: NativeFunction | NativeFunctionsGroup - ) -> dict[ETKernelKey, BackendMetadata]: - if isinstance(g, NativeFunction): - f = g - elif isinstance(g, NativeFunctionsGroup): - f = g.functional - else: - assert_never(g) - if f.func.name not in self.index: - return {} - return self.index[f.func.name] - - @staticmethod - def grow_from_backend_indices( - kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]], - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], - ) -> None: - for dk in backend_indices: - index = backend_indices[dk] - for op, backend_metadata in index.items(): - if op in kernel_index: - kernel_index[op][ETKernelKey(default=True)] = backend_metadata - else: - kernel_index[op] = {ETKernelKey(default=True): backend_metadata} - - @staticmethod - def from_backend_indices( - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], - ) -> ETKernelIndex: - kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = ( - defaultdict(dict) - ) - ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices) - return ETKernelIndex(kernel_index) - - def grow( - self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] - ) -> ETKernelIndex: - ETKernelIndex.grow_from_backend_indices(self.index, backend_indices) - return self - - def _to_backend_index(self) -> BackendIndex: - """ - WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex. - """ - index: dict[OperatorName, BackendMetadata] = {} - for op in self.index: - kernel_dict = self.index[op] - assert len(kernel_dict.values()) == 1, ( - f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}" - ) - index[op] = kernel_dict.get( - ETKernelKey(default=True), - BackendMetadata(kernel="", structured=False, cpp_namespace=""), - ) - return BackendIndex( - dispatch_key=DispatchKey.CPU, - use_out_as_primary=False, - device_guard=False, - external=False, - index=index, - ) - - # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a - @staticmethod - def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex: - combined = defaultdict(dict, index_a.index.copy()) - - for op, entry in index_b.index.items(): - for key, metadata in entry.items(): - combined[op][key] = metadata - - return ETKernelIndex(combined) diff --git a/torchgen/executorch/parse.py b/torchgen/executorch/parse.py deleted file mode 100644 index 8095abd5b6bc33..00000000000000 --- a/torchgen/executorch/parse.py +++ /dev/null @@ -1,153 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict, namedtuple -from typing import Any - -import yaml - -from torchgen.executorch.model import ETKernelIndex, ETKernelKey -from torchgen.gen import LineLoader, parse_native_yaml -from torchgen.model import ( - BackendMetadata, - DispatchKey, - FunctionSchema, - NativeFunction, - OperatorName, -) -from torchgen.utils import NamespaceHelper - - -# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices. -ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"]) - -# Fields in native_functions.yaml used to determine which kernels should be used -ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"] - - -def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]: - """Given a loaded yaml representing kernel assignment information, extract the - mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance) - - Args: - ei: Dict keys {kernels, type_alias, dim_order_alias} - See ETKernelKey for description of arguments - """ - e = ei.copy() - if (kernels := e.pop("kernels", None)) is None: - return {} - - type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment] - dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] - dim_order_alias.pop("__line__", None) - - kernel_mapping: dict[ETKernelKey, BackendMetadata] = {} - - for entry in kernels: # type: ignore[attr-defined] - arg_meta = entry.get("arg_meta") - if arg_meta is not None: - arg_meta.pop("__line__") - - kernel_name = entry.get("kernel_name") - namespace_helper = NamespaceHelper.from_namespaced_entity( - kernel_name, max_level=3 - ) - kernel_namespace = namespace_helper.get_cpp_namespace(default="at") - backend_metadata = BackendMetadata( - kernel=namespace_helper.entity_name, - structured=False, - cpp_namespace=(kernel_namespace + "::native"), - ) - - kernel_keys = ( - [ETKernelKey((), default=True)] - if arg_meta is None - else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type] - ) - - for kernel_key in kernel_keys: - assert kernel_key not in kernel_mapping, ( - "Duplicate kernel key: " + str(kernel_key) + " " + str(e) - ) - kernel_mapping[kernel_key] = backend_metadata - - return kernel_mapping - - -def parse_et_yaml_struct(es: object) -> ETKernelIndex: - """Given a loaded yaml representing a list of operators, for each op extract the mapping - of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance - that should be used by the kernel key). - """ - indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {} - for ei in es: # type: ignore[attr-defined] - e = ei.copy() - - funcs = e.pop("func") - assert isinstance(funcs, str), f"not a str: {funcs}" - namespace_helper = NamespaceHelper.from_namespaced_entity( - namespaced_entity=funcs, max_level=1 - ) - opname = FunctionSchema.parse(namespace_helper.entity_name).name - - assert opname not in indices, f"Duplicate func found in yaml: {opname} already" - - if len(index := parse_from_yaml(e)) != 0: - indices[opname] = index - - return ETKernelIndex(indices) - - -def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]: - """Given a loaded yaml representing a list of operators, extract the - kernel key related fields indexed by the operator name. - """ - fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict) - for ei in es: # type: ignore[attr-defined] - funcs = ei.get("func") - assert isinstance(funcs, str), f"not a str: {funcs}" - namespace_helper = NamespaceHelper.from_namespaced_entity( - namespaced_entity=funcs, max_level=1 - ) - opname = FunctionSchema.parse(namespace_helper.entity_name).name - - for field in ET_FIELDS: - if (value := ei.get(field)) is not None: - fields[opname][field] = value - - return fields - - -def parse_et_yaml( - path: str, - tags_yaml_path: str, - ignore_keys: set[DispatchKey] | None = None, - skip_native_fns_gen: bool = False, -) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]: - """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict - of fields to persist from native_functions.yaml to functions.yaml - """ - with open(path) as f: - es = yaml.load(f, Loader=LineLoader) - - et_kernel = extract_kernel_fields(es) - - # Remove ET specific fields from entries for BC compatibility - strip_et_fields(es) - - native_yaml = parse_native_yaml( - path, - tags_yaml_path, - ignore_keys, - skip_native_fns_gen=skip_native_fns_gen, - loaded_yaml=es, - ) - return native_yaml.native_functions, et_kernel - - -def strip_et_fields(es: object) -> None: - """Given a loaded yaml representing a list of operators, - remove ET specific fields from every entries for BC compatibility - """ - for entry in es: # type: ignore[attr-defined] - for field in ET_FIELDS: - entry.pop(field, None) diff --git a/torchgen/gen.py b/torchgen/gen.py index b584a87880f6f7..e5186620e8a473 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -18,7 +18,6 @@ import torchgen.api.native as native import torchgen.api.structured as structured import torchgen.dest as dest -from torchgen.aoti.fallback_ops import inductor_fallback_ops from torchgen.api import cpp from torchgen.api.translate import translate from torchgen.api.types import ( @@ -37,10 +36,8 @@ with_native_function_and_indices, ) from torchgen.gen_aoti_c_shim import ( - gen_aoti_c_shim, + gen_aoti_c_shim_files, gen_static_dispatch_backend_call_signature, - get_fallback_op_name, - get_header_for_aoti, ) from torchgen.gen_functionalization_type import ( gen_functionalization_definition, @@ -509,7 +506,7 @@ def static_dispatch( ) -> str: """ For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one - backends exsit, fallback to static dispatch by determining dispatch key from inputs. + backends exist, fallback to static dispatch by determining dispatch key from inputs. Arguments: sig: A CppSignature or DispatcherSignature for this native function we want to use. f: NativeFunction to generate static dispatch. @@ -2395,101 +2392,19 @@ def register_dispatch_key_env_callable( else: raise AssertionError(f"unrecognized {dispatch_key} for ufunc") - structured_func_group_dict = {} - for func_group in structured_native_functions: - for func in func_group.functions(): - if func.structured_delegate is not None: - structured_func_group_dict[func.structured_delegate] = func_group - break - - if dispatch_key in aoti_backends: - fallbacks = {} - for func in native_functions: - op_name = get_fallback_op_name(func) - if op_name in inductor_fallback_ops: - fallbacks[op_name] = func - fallback_native_functions = tuple( - value for _, value in sorted(fallbacks.items()) - ) - - # header files were checked in for ABI-compatiblilty checking - header_file_name = f"c_shim_{dispatch_key.lower()}.h" - new_header = gen_aoti_c_shim( - fallback_native_functions, - structured_func_group_dict, - dispatch_key, - backend_indices, - header=True, - extend_aoti_c_shim=extend_aoti_c_shim, - includes="", - ) - if update_aoti_c_shim: - aoti_fm.write( - header_file_name, - lambda: new_header, - ) - else: - try: - with open( - os.path.join(aoti_fm.install_dir, header_file_name) - ) as old_file: - old_header = old_file.read() - assert old_header == new_header, """ - -WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This -indicates an AOTInductor fallback operator ABI backward compatibility breakage!!! -Only in a limited number of situations, this is allowed: - -1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py. -If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing -C shim header files. - -2. You added a new default argument to an existing fallback op. This is clearly a BC breaking -change in the AOTInductor land. In this case, you need to keep a manual copy of that existing -fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version -number of that fallback op in the newly generated C shim files, and update the cpp wrapper -codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance. - - """ - except FileNotFoundError: - print( - f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found" - ) - - # cpp files are always generated on-the-fly - def headers_for_aoti() -> str: - headers = [] - for func in fallback_native_functions: - header = get_header_for_aoti( - func, - structured_func_group_dict, - dispatch_key, - backend_indices, - extend_aoti_c_shim=extend_aoti_c_shim, - ) - if header is not None: - headers.append(header) - return "\n".join(sorted(set(headers))) - - extra_headers = ( - extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" - ) - - aoti_fm.write( - f"c_shim_{dispatch_key.lower()}.cpp", - lambda: gen_aoti_c_shim( - fallback_native_functions, - structured_func_group_dict, - dispatch_key, - backend_indices, - header=False, - extend_aoti_c_shim=extend_aoti_c_shim, - includes=headers_for_aoti() + "\n" + extra_headers, - ), - ) - del fm + gen_aoti_c_shim_files( + aoti_fm=aoti_fm, + aoti_backends=aoti_backends, + native_functions=native_functions, + backend_indices=backend_indices, + structured_native_functions=structured_native_functions, + extra_cuda_headers=extra_cuda_headers, + update_aoti_c_shim=update_aoti_c_shim, + extend_aoti_c_shim=extend_aoti_c_shim, + ) + # BackendSelect is generated specially def gen_backend_select() -> dict[str, list[str]]: relevant_fns = [ @@ -2696,7 +2611,7 @@ def gen_op_headers( # but they could theoretically be called from user code (I added these kernels for completeness, # since the ops are part of the public API). # (2) A derivative formula for every {view}_copy operator - # {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts, + # {view}_copy operators can reuse the same derivative formulas as their {view} op counterparts, # so rather than stamping all of the entries out in derivatives.yaml, # we codegen them in. # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry. @@ -2909,15 +2824,40 @@ def main() -> None: from torchgen.model import dispatch_keys + # Only a limited set of dispatch keys get CPUFunctions.h headers generated + # for them; this is the set + functions_keys = { + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.Meta, + DispatchKey.MTIA, + } + + aoti_backends = { + DispatchKey.CPU, + DispatchKey.CUDA, + } + # TODO: stop generating CUDA kernels for non-CUDA builds ignore_keys = set() - if not options.mps: + + if options.mps or options.update_aoti_c_shim: + functions_keys.add(DispatchKey.MPS) + aoti_backends.add(DispatchKey.MPS) + else: ignore_keys.add(DispatchKey.MPS) if DispatchKey.MPS in dispatch_keys: del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] - if not options.xpu: + if options.xpu or options.update_aoti_c_shim: + functions_keys.add(DispatchKey.XPU) + aoti_backends.add(DispatchKey.XPU) + else: ignore_keys.add(DispatchKey.XPU) if DispatchKey.XPU in dispatch_keys: @@ -2929,6 +2869,13 @@ def main() -> None: if DispatchKey.MTIA in dispatch_keys: del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)] + if options.backend_whitelist: + dispatch_keys = [ + k + for k in dispatch_keys + if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist + ] + parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys) valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path] native_functions, backend_indices = ( @@ -2978,39 +2925,6 @@ def main() -> None: if options.xpu: device_fms["xpu"] = make_file_manager(options=options) - # Only a limited set of dispatch keys get CPUFunctions.h headers generated - # for them; this is the set - functions_keys = { - DispatchKey.CPU, - DispatchKey.CUDA, - DispatchKey.CompositeImplicitAutograd, - DispatchKey.CompositeImplicitAutogradNestedTensor, - DispatchKey.CompositeExplicitAutograd, - DispatchKey.CompositeExplicitAutogradNonFunctional, - DispatchKey.Meta, - DispatchKey.MTIA, - } - - aoti_backends = { - DispatchKey.CPU, - DispatchKey.CUDA, - } - - if options.mps: - functions_keys.add(DispatchKey.MPS) - aoti_backends.add(DispatchKey.MPS) - - if options.xpu: - functions_keys.add(DispatchKey.XPU) - aoti_backends.add(DispatchKey.XPU) - - if options.backend_whitelist: - dispatch_keys = [ - k - for k in dispatch_keys - if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist - ] - static_dispatch_idx: list[BackendIndex] = [] if options.static_dispatch_backend: static_dispatch_idx = [ diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index d02addaa607208..965ae8d268911d 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -1,9 +1,12 @@ from __future__ import annotations +import difflib +import os import textwrap from dataclasses import dataclass from typing import TYPE_CHECKING +from torchgen.aoti.fallback_ops import inductor_fallback_ops from torchgen.api.types import DispatcherSignature from torchgen.api.types.signatures import CppSignature, CppSignatureGroup from torchgen.context import method_with_native_function @@ -14,6 +17,7 @@ BaseType, DispatchKey, FunctionSchema, + is_cuda_dispatch_key, ListType, NativeFunction, NativeFunctionsGroup, @@ -21,7 +25,7 @@ OptionalType, Type, ) -from torchgen.utils import mapMaybe +from torchgen.utils import FileManager, mapMaybe if TYPE_CHECKING: @@ -199,11 +203,16 @@ def zip_type_and_name(types: list[str], names: list[str]) -> list[str]: # Generate argument declarations and callsite expressions -def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]: - types = [] - new_names = [] - callsite_exprs = [] +def gen_arguments( + flat_arguments: Sequence[Argument], skipped_args: set[str] +) -> tuple[list[str], list[str]]: + types: list[str] = [] + new_names: list[str] = [] + callsite_exprs: list[str] = [] for arg in flat_arguments: + if arg.name in skipped_args: + callsite_exprs.append("std::nullopt") + continue new_types, names, _, new_callsite_exprs = convert_arg_type_and_name( arg.type, arg.name, arg.is_write ) @@ -230,7 +239,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]: def convert_return(typ: BaseType, val: str) -> str: if typ.name == BaseTy.Tensor: - return f"new_tensor_handle(std::move({val}));" + return f"new_tensor_handle(std::move({val}))" elif typ.name == BaseTy.SymInt: return f"{val}.expect_int()" elif typ.name == BaseTy.Scalar: @@ -269,47 +278,93 @@ def convert_return(typ: BaseType, val: str) -> str: def gen_declaration_and_definition( - schema: FunctionSchema, device: str, backend_call: str + schema: FunctionSchema, + device: str, + backend_call: str, + version_info: dict[str, list[str]], ) -> tuple[str, str]: - func_name = schema.name.unambiguous_name() + base_name = schema.name.unambiguous_name() global declaration_definition_cache - if (func_name, device, backend_call) in declaration_definition_cache: - return declaration_definition_cache[(func_name, device, backend_call)] - - if schema.is_out_fn(): - # out_variant has out arguments in the front, and it's ok to ignore return values - # because C shim functions only return AOTITorchError - args, callsite_exprs = gen_arguments( - [*schema.arguments.out, *schema.arguments.flat_non_out] + if (base_name, device, backend_call) in declaration_definition_cache: + return declaration_definition_cache[(base_name, device, backend_call)] + + # Check the validity of version_info. The format should look like + # {"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}. + indexed_version_info: dict[int, list[str]] = {1: []} + for ver_str, new_args in sorted(version_info.items()): + assert ver_str.startswith("v"), ( + f"Version number for {base_name} is {ver_str}, not starting with 'v'" ) - ret_assignments: list[str] = [] - else: - args, callsite_exprs = gen_arguments(schema.arguments.flat_all) - # ignore return values for inplace ops - ret_declarations, ret_assignments = ( - ([], []) if schema.name.name.inplace else gen_returns(schema) + try: + ver_id = int(ver_str[1:]) + except ValueError as e: + raise AssertionError( + f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'" + ) from e + assert ver_id not in indexed_version_info, ( + f"{ver_str} for {base_name} has already been defined" + ) + indexed_version_info[ver_id] = new_args + + declarations: list[str] = [] + definitions: list[str] = [] + skipped_args: set[str] = set() + + for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True): + # Iterate in the reverse order, so the latest version of an op will get generated first + # with all the arguments included, while a set of to-be-trimmed args is carried down + # to generate earlier version of the op. + func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}" + if schema.is_out_fn(): + # out_variant has out arguments in the front, and it's ok to ignore return values + # because C shim functions only return AOTITorchError + args, callsite_exprs = gen_arguments( + [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args + ) + ret_assignments: list[str] = [] + else: + args, callsite_exprs = gen_arguments( + schema.arguments.flat_all, skipped_args + ) + # ignore return values for inplace ops + ret_declarations, ret_assignments = ( + ([], []) if schema.name.name.inplace else gen_returns(schema) + ) + args.extend(ret_declarations) + + declaration = textwrap.dedent( + f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})" + ) + + tmp_result = "auto tmp_result = " if ret_assignments else "" + indent = "\t\t" + ret_assignments_str = ( + "\n".join(indent + r for r in ret_assignments) if ret_assignments else "" + ) + definition = ( + textwrap.dedent(f""" + {declaration} {{ + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{ + {tmp_result}{backend_call}( + {", ".join(callsite_exprs)} + ); + """) + + ret_assignments_str + + textwrap.dedent(""" + }); + } + """) ) - args.extend(ret_declarations) - - declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})" - - tmp_result = "auto tmp_result = " if ret_assignments else "" - ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else "" - definition = f""" -{declaration} {{ - AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{ - {tmp_result}{backend_call}( -{textwrap.indent(", ".join(callsite_exprs), " ")} - );{textwrap.indent(ret_assignments_str, " ")} - }}); -}} -""" - declaration_definition_cache[(func_name, device, backend_call)] = ( - declaration, - definition, + skipped_args.update(new_args) + declarations.append(f"AOTI_TORCH_EXPORT {declaration};") + definitions.append(definition) + + declaration_definition_cache[(base_name, device, backend_call)] = ( + "\n".join(declarations), + "\n".join(definitions), ) - return declaration, definition + return declaration_definition_cache[(base_name, device, backend_call)] def gen_static_dispatch_backend_call_signature( @@ -402,6 +457,7 @@ def get_fallback_op_name(func: NativeFunction) -> str: def gen_c_shim( func: NativeFunction, + version_info: dict[str, list[str]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, backend_indices: dict[DispatchKey, BackendIndex], @@ -424,11 +480,13 @@ def gen_c_shim( try: if header: declaration, _ = gen_declaration_and_definition( - schema, device, backend_call + schema, device, backend_call, version_info ) - return f"AOTI_TORCH_EXPORT {declaration};" + return declaration else: - _, definition = gen_declaration_and_definition(schema, device, backend_call) + _, definition = gen_declaration_and_definition( + schema, device, backend_call, version_info + ) return definition except NotImplementedError: @@ -437,6 +495,7 @@ def gen_c_shim( @dataclass(frozen=True) class ShimGenerator: + inductor_fallback_ops: dict[str, dict[str, list[str]]] func_group_mapping: dict[OperatorName, NativeFunctionsGroup] dispatch_key: DispatchKey backend_indices: dict[DispatchKey, BackendIndex] @@ -448,8 +507,10 @@ def __call__( self, func: NativeFunction, ) -> str | None: + version_info = self.inductor_fallback_ops[get_fallback_op_name(func)] result = gen_c_shim( func, + version_info, self.func_group_mapping, self.dispatch_key, self.backend_indices, @@ -461,6 +522,7 @@ def __call__( def gen_aoti_c_shim( native_functions: Sequence[NativeFunction], + inductor_fallback_ops: dict[str, dict[str, list[str]]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], dispatch_key: DispatchKey, backend_indices: dict[DispatchKey, BackendIndex], @@ -472,6 +534,7 @@ def gen_aoti_c_shim( list( mapMaybe( ShimGenerator( + inductor_fallback_ops, func_group_mapping, dispatch_key, backend_indices, @@ -484,44 +547,169 @@ def gen_aoti_c_shim( ) device = dispatch_key.lower() warning = """ + // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. // See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details""" if header: - return f""" -{warning} - -#pragma once + return ( + warning + + textwrap.dedent(""" -#include + #pragma once -#ifdef __cplusplus -extern "C" {{ -#endif + #include -{body} + #ifdef __cplusplus + extern "C" { + #endif -#ifdef __cplusplus -}} // extern "C" -#endif -""" + """) + + body + + textwrap.dedent(""" + #ifdef __cplusplus + } // extern "C" + #endif + """) + ) else: - return f""" -{warning} + return ( + warning + + textwrap.dedent(f""" + + #include + #include + + #ifndef AT_PER_OPERATOR_HEADERS + #include + #include + #include + #include + #else + """) + + includes + + textwrap.dedent(""" + #endif // AT_PER_OPERATOR_HEADERS + + using namespace torch::aot_inductor; + + """) + + body + ) -#include -#include -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#include -#include -#else -{includes} -#endif +def gen_aoti_c_shim_files( + aoti_fm: FileManager, + aoti_backends: set[DispatchKey], + native_functions: Sequence[NativeFunction], + backend_indices: dict[DispatchKey, BackendIndex], + structured_native_functions: Sequence[NativeFunctionsGroup], + extra_cuda_headers: str, + extend_aoti_c_shim: bool, + update_aoti_c_shim: bool, +) -> None: + structured_func_group_dict = {} + for func_group in structured_native_functions: + for func in func_group.functions(): + if func.structured_delegate is not None: + structured_func_group_dict[func.structured_delegate] = func_group + break + + for dispatch_key in aoti_backends: + fallbacks = {} + for func in native_functions: + op_name = get_fallback_op_name(func) + if op_name in inductor_fallback_ops: + fallbacks[op_name] = func + fallback_native_functions = tuple( + value for _, value in sorted(fallbacks.items()) + ) -using namespace torch::aot_inductor; + # header files were checked in for ABI-compatiblilty checking + header_file_name = f"c_shim_{dispatch_key.lower()}.h" + new_header = gen_aoti_c_shim( + fallback_native_functions, + inductor_fallback_ops, + structured_func_group_dict, + dispatch_key, + backend_indices, + header=True, + extend_aoti_c_shim=extend_aoti_c_shim, + includes="", + ) + if update_aoti_c_shim: + aoti_fm.write( + header_file_name, + lambda: new_header, + ) + else: + try: + with open( + os.path.join(aoti_fm.install_dir, header_file_name) + ) as old_file: + old_header = old_file.read() + + if old_header != new_header: + diff = "\n".join( + difflib.unified_diff( + old_header.splitlines(), + new_header.splitlines(), + fromfile="expected", + tofile="actual", + lineterm="", + ) + ) + + raise RuntimeError(f""" +The generated AOTInductor C shim header files have unexpectedly changed. This +indicates an AOTInductor fallback operator ABI backward compatibility breakage!!! +Only in a limited number of situations, this is allowed: + +1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py. +If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to +existing C shim header files. + +2. You added a new default argument to an existing fallback op. This is clearly a BC breaking +change in the AOTInductor land. You need to annotate the new default argument in +torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to +update the C shim header files by creating different versions of the fallback op. See +https://github.com/pytorch/pytorch/pull/154848 as an example. + +{diff} + """) + except FileNotFoundError: + print( + f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found" + ) -{body}""" + # cpp files are always generated on-the-fly + def headers_for_aoti() -> str: + headers = [] + for func in fallback_native_functions: + header = get_header_for_aoti( + func, + structured_func_group_dict, + dispatch_key, + backend_indices, + extend_aoti_c_shim=extend_aoti_c_shim, + ) + if header is not None: + headers.append(header) + return "\n".join(sorted(set(headers))) + + extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" + + aoti_fm.write( + f"c_shim_{dispatch_key.lower()}.cpp", + lambda: gen_aoti_c_shim( + fallback_native_functions, + inductor_fallback_ops, + structured_func_group_dict, + dispatch_key, + backend_indices, + header=False, + extend_aoti_c_shim=extend_aoti_c_shim, + includes=headers_for_aoti() + "\n" + extra_headers, + ), + ) diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py deleted file mode 100644 index 306333f1eaef31..00000000000000 --- a/torchgen/gen_executorch.py +++ /dev/null @@ -1,1024 +0,0 @@ -from __future__ import annotations - -import argparse -import os -from collections import defaultdict -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, TextIO, TYPE_CHECKING - -import yaml - -# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. -from torchgen import dest -from torchgen.api import cpp as aten_cpp -from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType -from torchgen.context import ( - method_with_native_function, - method_with_nested_native_function, - with_native_function_and_index, -) -from torchgen.executorch.api import et_cpp -from torchgen.executorch.api.custom_ops import ( - ComputeNativeFunctionStub, - gen_custom_ops_registration, -) -from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature -from torchgen.executorch.api.unboxing import Unboxing -from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml -from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct -from torchgen.gen import ( - get_custom_build_selector, - get_native_function_declarations, - get_native_function_declarations_from_ns_grouped_kernels, - get_native_function_schema_registrations, - LineLoader, - parse_native_yaml, -) -from torchgen.model import ( - BackendIndex, - BackendMetadata, - DEFAULT_KERNEL_NAMESPACE, - DispatchKey, - FunctionSchema, - Location, - NativeFunction, - NativeFunctionsGroup, - OperatorName, - Variant, -) -from torchgen.utils import ( - context, - FileManager, - make_file_manager, - mapMaybe, - NamespaceHelper, -) - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from torchgen.selective_build.selector import SelectiveBuilder - - -def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: - """ - A wrapper function to basically get `sig.decl(include_context=True)`. - For ATen kernel, the codegen has no idea about ET contextArg, so we - use this wrapper to add it. - """ - if isinstance(sig, ExecutorchCppSignature): - return sig.decl() - - returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type() - cpp_args = [a.decl() for a in sig.arguments()] - cpp_args_str = ", ".join([contextArg.decl()] + cpp_args) - sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})" - return sig_decl - - -def static_dispatch( - sig: CppSignature | ExecutorchCppSignature, - f: NativeFunction, - backend_indices: list[BackendIndex], -) -> str: - """ - For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one - native function exists, error out. A simplified version of register_dispatch_key.py - Arguments: - sig: A CppSignature for this native function we want to use. - f: NativeFunction to generate static dispatch. - backend_indices: All available backends. - Return: - C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);" - """ - if len(backend_indices) == 0 or f.manual_kernel_registration: - return "" - - backends = [b for b in backend_indices if b.has_kernel(f)] - static_block = None - if len(backends) == 1: - backend_metadata = backends[0].get_kernel(f) - if backend_metadata: - args = ", ".join(a.name for a in sig.arguments()) - # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch. - static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});" - else: - static_block = f""" -ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}."); - """ - return f""" -// {f.namespace}::{f.func} -TORCH_API inline {_sig_decl_wrapper(sig)} {{ - {static_block} -}} -""" - - -# Generates Functions.h, which provides the functional public C++ API, -# and the scaffolding to call into the dispatcher from these functions. -@dataclass(frozen=True) -class ComputeFunction: - static_dispatch_backend_indices: list[BackendIndex] - - selector: SelectiveBuilder - - use_aten_lib: bool - - is_custom_op: Callable[[NativeFunction], bool] - - @method_with_native_function - def __call__(self, f: NativeFunction) -> str | None: - is_method_variant = False - if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"): - return None - - if Variant.function not in f.variants and Variant.method in f.variants: - is_method_variant = True - - # only valid remaining case is only function is in f.variants - elif not (Variant.function in f.variants and Variant.method not in f.variants): - raise Exception( # noqa: TRY002 - f"Can't handle native function {f.func} with the following variant specification {f.variants}." - ) - - sig: CppSignature | ExecutorchCppSignature = ( - CppSignatureGroup.from_native_function( - f, method=False, fallback_binding=f.manual_cpp_binding - ).most_faithful_signature() - if self.use_aten_lib - else ExecutorchCppSignature.from_native_function(f) - ) - if self.use_aten_lib and not self.is_custom_op(f): - comma = ", " - - if is_method_variant: - return f""" -// {f.namespace}::{f.func} -TORCH_API inline {_sig_decl_wrapper(sig)} {{ - return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])}); -}} -""" - else: - return f""" -// {f.namespace}::{f.func} -TORCH_API inline {_sig_decl_wrapper(sig)} {{ - return at::{sig.name()}({comma.join(e.name for e in sig.arguments())}); -}} -""" - - else: - return static_dispatch( - sig, - f, - backend_indices=self.static_dispatch_backend_indices, - ) - - -# Generates RegisterCodegenUnboxedKernels.cpp. -@dataclass(frozen=True) -class ComputeCodegenUnboxedKernels: - selector: SelectiveBuilder - - use_aten_lib: bool - - add_exception_boundary: bool - - @method_with_nested_native_function - def __call__( - self, - unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], - ) -> str: - f: NativeFunction = unbox_kernel_entry[0] - kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0] - kernel_meta: BackendMetadata = unbox_kernel_entry[1][1] - - op_name = f"{f.namespace}::{f.func.name}" - if not self.selector.is_root_operator(op_name): - return "" - - if not isinstance(kernel_key, list): - kernel_key = [kernel_key] - used_kernel_keys = self.selector.et_get_selected_kernels( - op_name, [k.to_native_string() for k in kernel_key] - ) - if not used_kernel_keys: - return "" - sig: CppSignature | ExecutorchCppSignature - argument_type_gen: Callable[..., NamedCType] - return_type_gen: Callable[..., CType] - if self.use_aten_lib: - sig = CppSignatureGroup.from_native_function( - f, method=False, fallback_binding=f.manual_cpp_binding - ).most_faithful_signature() - argument_type_gen = aten_cpp.argumenttype_type - return_type_gen = aten_cpp.returns_type - arguments = sig.arguments() - kernel_call = f"torch::executor::{f.namespace}::{sig.name()}" - else: - sig = ExecutorchCppSignature.from_native_function(f) - argument_type_gen = et_cpp.argumenttype_type - return_type_gen = et_cpp.returns_type - arguments = sig.arguments(include_context=False) - kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}" - # parse arguments into C++ code - binding_list, code_list = Unboxing( - argument_type_gen=argument_type_gen - ).convert_arguments(arguments) - - # for each C++ argument, generate the conversion code - code_connector = "\n\t" - arg_connector = ", " - - args_str = f"{arg_connector.join(e.name for e in binding_list)}" - event_tracer_output_logging = "" - output_ids = [] - - if len(f.func.returns) == 0: - if len(f.func.arguments.out) == 0: - raise Exception( # noqa: TRY002 - f"Can't handle native function {f.func} with no returns and no out yet." - ) - out = f.func.arguments.out[0] - return_assignment = f"""stack[{len(binding_list)}] = &{out.name};""" - ret_prefix = "" - output_ids = [len(binding_list)] - else: - if len(f.func.arguments.out) == 0: - return_assignment = ( - f"""*stack[{len(binding_list)}] = EValue(result_);""" - ) - ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = " - output_ids = [len(binding_list)] - else: - return_assignment = "" - ret_prefix = "" - output_ids = [ - len(binding_list) - (i + 1) - for i in reversed(range(len(f.func.arguments.out))) - ] - - for output_id in output_ids: - event_tracer_output_logging += ( - f"internal::event_tracer_log_evalue(" - f"context.internal_event_tracer(), " - f"*stack[{output_id}]);\n" - ) - - exception_boundary_begin = "" - exception_boundary_end = "" - if self.add_exception_boundary: - indent = " " * 8 - exception_boundary_begin = indent + "try {" - exception_boundary_end = f"""{indent}}} catch (const std::exception& ex) {{ -{indent} ET_LOG(Error, "Kernel threw an exception: %s", ex.what()); -{indent} context.fail(torch::executor::Error::Internal); -{indent}}}""" - newline = "\n " - return "\n".join( - [ - f""" -Kernel( - "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""} - []({contextArg.defn()}, EValue** stack) {{ - {code_connector.join(code_list)} - -{exception_boundary_begin} - internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_{f.func.name}"); - EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); - {ret_prefix}{kernel_call}(context, {args_str}); - {event_tracer_output_logging} - {return_assignment} -{exception_boundary_end} - }} -), -""" - for k in used_kernel_keys - ] - ) - - -def gen_unboxing( - *, - native_functions: Sequence[NativeFunction], - cpu_fm: FileManager, - selector: SelectiveBuilder, - use_aten_lib: bool, - kernel_index: ETKernelIndex, - manual_registration: bool, - add_exception_boundary: bool = False, -) -> None: - # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) - def key_func( - item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], - ) -> str: - return item[0].root_name + ":" + item[1][0].to_native_string() - - items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [ - (native_function, (kernel_key, metadata)) - for native_function in native_functions - for kernel_key, metadata in kernel_index.get_kernels(native_function).items() - ] - - header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"] - filename = ( - "RegisterKernels.cpp" - if manual_registration - else "RegisterCodegenUnboxedKernels.cpp" - ) - cpu_fm.write_sharded( - filename, - items, - key_fn=key_func, - env_callable=lambda unbox_kernel_entry: { - "unboxed_kernels": [ - ComputeCodegenUnboxedKernels( - selector, use_aten_lib, add_exception_boundary - )(unbox_kernel_entry) - ], - "fn_header": header - if unbox_kernel_entry == items[0] - else [], # Only write header once - }, - num_shards=1, - sharded_keys={"unboxed_kernels", "fn_header"}, - ) - - -@with_native_function_and_index # type: ignore[arg-type] -def compute_native_function_declaration( - g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex -) -> list[str]: - assert isinstance(g, NativeFunction) - sig = ExecutorchCppSignature.from_native_function(f=g) - metadata_list = kernel_index.get_kernels(g).values() - if metadata_list is None: - return [] - - # for kernels in lean mode, we declare two versions, one with context and one without. - # In the end we will cleanup the unused one. - def gen_decl(metadata: BackendMetadata, include_context: bool) -> str: - return f"{sig.decl(name=metadata.kernel, include_context=include_context)};" - - return [ - gen_decl(metadata, include_context) - for include_context in [False, True] - for metadata in metadata_list - ] - - -def gen_functions_declarations( - *, - native_functions: Sequence[NativeFunction], - kernel_index: ETKernelIndex, - selector: SelectiveBuilder, - use_aten_lib: bool, - custom_ops_native_functions: Sequence[NativeFunction] | None = None, -) -> str: - """ - Generates namespace separated C++ function API inline declaration/definitions. - Native functions are grouped by namespaces and the generated code is wrapped inside - namespace blocks. - - E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol - in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when - the other `custom_2::foo.out` is available. - """ - - # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. - # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. - - backend_index = kernel_index._to_backend_index() - - ns_grouped_functions = defaultdict(list) - for native_function in native_functions: - ns_grouped_functions[native_function.namespace].append(native_function) - functions_declarations = "" - newline = "\n" - for namespace in ns_grouped_functions: - ns_helper = NamespaceHelper( - namespace_str=namespace, - entity_name="", - max_level=3, - ) - declarations = list( - mapMaybe( - ComputeFunction( - static_dispatch_backend_indices=[backend_index], - selector=selector, - use_aten_lib=use_aten_lib, - is_custom_op=lambda f: custom_ops_native_functions is not None - and f in custom_ops_native_functions, - ), - ns_grouped_functions[namespace], - ) - ) - functions_declarations += f""" -{ns_helper.prologue} -{newline.join(declarations)} -{ns_helper.epilogue} - """ - return functions_declarations - - -def get_ns_grouped_kernels( - *, - native_functions: Sequence[NativeFunction], - kernel_index: ETKernelIndex, - native_function_decl_gen: Callable[ - [ - NativeFunctionsGroup | NativeFunction, - ETKernelIndex, - ], - list[str], - ], -) -> dict[str, list[str]]: - ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) - for f in native_functions: - native_function_namespaces = set() - op_kernels = kernel_index.get_kernels(f) - for backend_metadata in op_kernels.values(): - if backend_metadata: - namespace = backend_metadata.cpp_namespace - native_function_namespaces.add(namespace) - else: - namespace = DEFAULT_KERNEL_NAMESPACE - assert len(native_function_namespaces) <= 1, ( - f"Codegen only supports one namespace per operator, got {native_function_namespaces}" - ) - ns_grouped_kernels[namespace].extend( - native_function_decl_gen(f, kernel_index) - ) - return ns_grouped_kernels - - -def gen_headers( - *, - native_functions: Sequence[NativeFunction], - gen_custom_ops_header: bool, - custom_ops_native_functions: Sequence[NativeFunction], - selector: SelectiveBuilder, - kernel_index: ETKernelIndex, - cpu_fm: FileManager, - use_aten_lib: bool, -) -> None: - """Generate headers. - - Args: - native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops. - gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h - custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops. - kernel_index (ETKernelIndex): kernel collection - cpu_fm (FileManager): file manager manages output stream - use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types. - """ - aten_headers = ["#include "] - backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()} - if gen_custom_ops_header: - cpu_fm.write_with_template( - "CustomOpsNativeFunctions.h", - "NativeFunctions.h", - lambda: { - "nativeFunctions_declarations": get_native_function_declarations( - grouped_native_functions=custom_ops_native_functions, - backend_indices=backend_indices, - native_function_decl_gen=dest.compute_native_function_declaration, - ), - "headers": [ - "#include ", - "#include ", - ], - }, - ) - aten_headers.append('#include "CustomOpsNativeFunctions.h"') - cpu_fm.write( - "Functions.h", - lambda: { - "static_dispatch_extra_headers": aten_headers - if use_aten_lib - else ['#include "NativeFunctions.h"'], - "Functions_declarations": gen_functions_declarations( - native_functions=native_functions, - kernel_index=kernel_index, - selector=selector, - use_aten_lib=use_aten_lib, - custom_ops_native_functions=custom_ops_native_functions, - ), - }, - ) - cpu_fm.write( - "RegisterKernels.h", - lambda: { - "generated_comment": "@" + "generated by torchgen/gen_executorch.py", - }, - ) - headers = { - "headers": [ - "#include // at::Tensor etc.", - "#include ", - ], - } - if use_aten_lib: - headers["headers"].append("#include // TORCH_API") - cpu_fm.write( - "NativeFunctions.h", - lambda: dict( - { - "nativeFunctions_declarations": get_native_function_declarations( - grouped_native_functions=native_functions, - backend_indices=backend_indices, - native_function_decl_gen=dest.compute_native_function_declaration, - ), - }, - **headers, - ), - ) - else: - ns_grouped_kernels = get_ns_grouped_kernels( - native_functions=native_functions, - kernel_index=kernel_index, - native_function_decl_gen=compute_native_function_declaration, # type: ignore[arg-type] - ) - cpu_fm.write( - "NativeFunctions.h", - lambda: dict( - { - "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( - ns_grouped_kernels=ns_grouped_kernels, - ), - }, - **headers, - ), - ) - - -def gen_custom_ops( - *, - native_functions: Sequence[NativeFunction], - selector: SelectiveBuilder, - kernel_index: ETKernelIndex, - cpu_fm: FileManager, - rocm: bool, -) -> None: - dispatch_key = DispatchKey.CPU - ( - anonymous_definition, - static_init_dispatch_registrations, - ) = gen_custom_ops_registration( - native_functions=native_functions, - selector=selector, - kernel_index=kernel_index, - rocm=rocm, - ) - cpu_fm.write_with_template( - f"Register{dispatch_key}CustomOps.cpp", - "RegisterDispatchKeyCustomOps.cpp", - lambda: { - "ops_headers": '#include "CustomOpsNativeFunctions.h"', - "DispatchKey": dispatch_key, - "dispatch_namespace": dispatch_key.lower(), - "dispatch_namespaced_definitions": "", - "dispatch_anonymous_definitions": anonymous_definition, - "static_init_dispatch_registrations": static_init_dispatch_registrations, - }, - ) - cpu_fm.write_with_template( - f"Register{dispatch_key}Stub.cpp", - "RegisterDispatchKeyCustomOps.cpp", - lambda: { - "ops_headers": "", - "DispatchKey": dispatch_key, - "dispatch_namespace": dispatch_key.lower(), - "dispatch_namespaced_definitions": "", - "dispatch_anonymous_definitions": list( - mapMaybe(ComputeNativeFunctionStub(), native_functions) - ), - "static_init_dispatch_registrations": static_init_dispatch_registrations, - }, - ) - - ( - aten_schema_registrations, - schema_registrations, - ) = get_native_function_schema_registrations( - native_functions=native_functions, - schema_selector=selector, - ) - cpu_fm.write( - "RegisterSchema.cpp", - lambda: { - "schema_registrations": schema_registrations, - "aten_schema_registrations": aten_schema_registrations, - }, - ) - - -def translate_native_yaml( - tags_yaml_path: str, - aten_yaml_path: str, - native_yaml_path: str | None, - use_aten_lib: bool, - out_file: TextIO, -) -> None: - """Translates Executorch DSL dialect to use the same syntax as - native_functions.yaml. The major difference is that Executorch DSL dialect - supports "op" key, where it refers to the operator name in native_functions.yaml. - - For example, a functions.yaml may have the following entry: - - - op: add.out - ... - - It needs to be translated to the following: - - - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - ... - - We go in aten_yaml_path and find the operator schema for "add.out" and add it - to the original functions.yaml. We also add required field "variants", where for - Executorch it will always be "function". - - For ATen mode we don't have to do the translation because native_yaml_path is - the same as native_functions.yaml. - - Args: - tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. - It is not optional. - aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. - native_yaml_path: Path to a functions.yaml file to parse. - If the path does not exist in the filesystem, it is treated as an - empty file. If `custom_ops_yaml_path` exists, the contents of that - file are appended to the yaml input to be parsed. - use_aten_lib: We use this flag to determine if we want to generate native - functions. In ATen mode we should generate out= variants. - out_file: The IO object that we are writing into. - Returns: - None - """ - if use_aten_lib: - with open(aten_yaml_path) as aten_yaml: - out_file.writelines(aten_yaml.readlines()) - return - - native_functions, persisted_fields = parse_et_yaml( - aten_yaml_path, - tags_yaml_path, - None, - skip_native_fns_gen=False, - ) - - func_to_scoped_name: dict[FunctionSchema, str] = { - f.func: f"{f.namespace}::{f.func.name}" for f in native_functions - } - op_to_scoped_name: dict[OperatorName, str] = { - func.name: name for func, name in func_to_scoped_name.items() - } - - schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()} - kernel_persist_dict: dict[str, dict[str, Any]] = { - op_to_scoped_name[op]: v for op, v in persisted_fields.items() - } - - if ( - not native_yaml_path - or not os.path.exists(native_yaml_path) - or os.stat(native_yaml_path).st_size == 0 - ): - return - with open(native_yaml_path) as native_yaml: - native_es = yaml.load(native_yaml, Loader=LineLoader) - if not native_es: - return - for e in native_es: - assert isinstance(e.get("__line__"), int), e - loc = Location(native_yaml_path, e.pop("__line__")) - with context(lambda: f"in {loc}:\n "): - if "variants" not in e: - e["variants"] = "function" - if "func" in e: - continue - assert isinstance(e.get("op"), str), e - opname = e.pop("op") - if "::" not in opname: - opname = "aten::" + opname - assert opname in schema_dict - e["func"] = schema_dict.get(opname) - - # Write out persisted kernel information - if opname in kernel_persist_dict: - for k, v in kernel_persist_dict[opname].items(): - e[k] = v - - yaml.dump(native_es, out_file, width=1000) - - -def parse_yaml( - path: str | None, - tags_yaml_path: str, - function_filter: Callable[[NativeFunction], bool], - skip_native_fns_gen: bool = False, -) -> tuple[ - list[NativeFunction], - dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex, -]: - if path and os.path.exists(path) and os.stat(path).st_size > 0: - with open(path) as f: - es = yaml.load(f, Loader=LineLoader) - - # Check for kernel index structure - kernel_index = ( - parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None - ) - - # Remove ET specific fields from entries for BC compatibility - for entry in es: - for field in ET_FIELDS: - entry.pop(field, None) - - parsed_yaml = parse_native_yaml( - path, - tags_yaml_path, - None, - skip_native_fns_gen=skip_native_fns_gen, - loaded_yaml=es, - ) - native_functions = list(filter(function_filter, parsed_yaml.native_functions)) - op_names = [f.func.name for f in native_functions] - - # (1) Return ETKernelIndex if kernel index is present - if kernel_index is not None: - filtered_index = { - op_name: kernel_mapping - for op_name, kernel_mapping in kernel_index.index.items() - if op_name in op_names - } - return native_functions, ETKernelIndex(index=filtered_index) - - # (2) Return BackendIndices if kernel index is absent - def map_index( - m: dict[OperatorName, BackendMetadata], - ) -> dict[OperatorName, BackendMetadata]: - return {op: m[op] for op in m if op in op_names} - - backend_indices = { - k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items() - } - - return native_functions, backend_indices - else: - return [], {} - - -def parse_yaml_files( - tags_yaml_path: str, - aten_yaml_path: str, - native_yaml_path: str | None, - custom_ops_yaml_path: str | None, - selector: SelectiveBuilder, - use_aten_lib: bool, -) -> tuple[ETParsedYaml, ETParsedYaml | None]: - """Parses functions.yaml and custom_ops.yaml files. - - Args: - tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. - It is not optional. - aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. - native_yaml_path: Path to a functions.yaml file to parse. - If the path does not exist in the filesystem, it is treated as an - empty file. If `custom_ops_yaml_path` exists, the contents of that - file are appended to the yaml input to be parsed. - custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If - the path does not exist in the filesystem, it is ignored. - selector: For selective build. - use_aten_lib: We use this flag to determine if we want to generate native - functions. In ATen mode we should generate out= variants. - Returns: - A tuple with two elements: - [0]: The parsed results of concatenating the contents of - `native_yaml_path` and `custom_ops_yaml_path`. - [1]: The parsed results of the contents of `custom_ops_yaml_path`, if - present. If not present, None. - """ - import tempfile - - # only include selected ops, this is because we want to avoid - def function_filter(f: NativeFunction) -> bool: - return selector.is_native_function_selected(f) - - with tempfile.TemporaryDirectory() as tmpdirname: - translated_yaml_path = os.path.join(tmpdirname, "translated.yaml") - with open(translated_yaml_path, "w") as translated: - translate_native_yaml( - tags_yaml_path, - aten_yaml_path, - native_yaml_path, - use_aten_lib, - translated, - ) - - translated_functions, translated_indices = parse_yaml( - translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib - ) - custom_ops_functions, custom_ops_indices = parse_yaml( - custom_ops_yaml_path, tags_yaml_path, function_filter, True - ) - - # Convert BackendIndices to ETKernelIndex - if not isinstance(translated_indices, ETKernelIndex): - translated_indices = ETKernelIndex.from_backend_indices(translated_indices) - if not isinstance(custom_ops_indices, ETKernelIndex): - custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices) - - combined_functions = translated_functions + custom_ops_functions - combined_kernel_index = ETKernelIndex.merge_indices( - translated_indices, custom_ops_indices - ) - combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index) - custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices) - - return combined_yaml, custom_ops_parsed_yaml - - -def main() -> None: - parser = argparse.ArgumentParser(description="Generate operator source files") - # Although we don't refer to --source-path directly, make_file_manager() - # expects it to point to a directory that contains a templates/ subdirectory - # containing the file templates. - parser.add_argument( - "-s", - "--source-path", - help="path to source directory for kernel templates", - ) - parser.add_argument( - "--functions-yaml-path", - "--functions_yaml_path", - help="path to the functions.yaml file to use. Optional, but at least " - "one of --functions-yaml-path and --custom-ops-yaml-path must be " - "specified.", - ) - parser.add_argument( - "--custom-ops-yaml-path", - "--custom_ops_yaml_path", - help="path to the custom_ops.yaml file to use. Optional, but at least " - "one of --functions-yaml-path and --custom-ops-yaml-path must be " - "specified.", - ) - parser.add_argument( - "--aten-yaml-path", - "--aten_yaml_path", - help="path to native_functions.yaml file.", - ) - # Note that make_file_manager() also looks at --install-dir. - parser.add_argument( - "-d", - "--install-dir", - "--install_dir", - help="output directory", - default="build/generated", - ) - parser.add_argument( - "-o", - "--output-dependencies", - help="output a list of dependencies into the given file and exit", - ) - # Although we don't refer to --dry-run directly, make_file_manager() looks - # for it. - parser.add_argument( - "--dry-run", - action="store_true", - help="run without writing any files (still updates outputs)", - ) - parser.add_argument( - "--static-dispatch-backend", - "--static_dispatch_backend", - nargs="*", - help="generate static dispatch code for the specific backend (if set)", - ) - parser.add_argument( - "--op-registration-whitelist", - "--op_registration_whitelist", - nargs="*", - help="filter op registrations by the whitelist (if set); " - "each item is `namespace`::`operator name` without overload name; " - "e.g.: aten::empty aten::conv2d ...", - ) - parser.add_argument( - "--op-selection-yaml-path", - "--op_selection_yaml_path", - help="Provide a path to the operator selection (for custom build) YAML " - "that contains the information about the set of selected operators " - "and their categories (training, ...). Each operator is either a " - "full operator name with overload or just a bare operator name. " - "The operator names also contain the namespace prefix (e.g. aten::)", - ) - parser.add_argument( - "--tags-path", - help="Path to tags.yaml. Required by yaml parsing in codegen system.", - ) - parser.add_argument( - "--rocm", - action="store_true", - help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", - ) - parser.add_argument( - "--use-aten-lib", - "--use_aten_lib", - action="store_true", - help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per " - "operator", - ) - parser.add_argument( - "--manual_registration", - "--manual-registration", - action="store_true", - help="a boolean flag to indicate whether we want to manually call" - "register_kernels() or rely on static init. ", - ) - parser.add_argument( - "--generate", - type=str, - nargs="*", - choices=["headers", "sources"], - default=["headers", "sources"], - help="Generate only a subset of files", - ) - parser.add_argument( - "--add-exception-boundary", - "--add_exception_boundary", - action="store_true", - help="whether to add a try/catch in the generated kernel wrapper to " - "convert exceptions to clean failures.", - ) - options = parser.parse_args() - assert options.tags_path, "tags.yaml is required by codegen yaml parsing." - - selector = get_custom_build_selector( - options.op_registration_whitelist, - options.op_selection_yaml_path, - ) - - parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( - aten_yaml_path=options.aten_yaml_path, - tags_yaml_path=options.tags_path, - native_yaml_path=options.functions_yaml_path, - custom_ops_yaml_path=options.custom_ops_yaml_path, - selector=selector, - use_aten_lib=options.use_aten_lib, - ) - native_functions, kernel_index = ( - parsed_yaml.native_functions, - parsed_yaml.kernel_index, - ) - custom_ops_native_functions = ( - custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else [] - ) - - cpu_fm = make_file_manager(options=options) - - if "headers" in options.generate: - # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system. - gen_headers( - native_functions=native_functions, - gen_custom_ops_header=options.custom_ops_yaml_path, - custom_ops_native_functions=custom_ops_native_functions, - selector=selector, - kernel_index=kernel_index, - cpu_fm=cpu_fm, - use_aten_lib=options.use_aten_lib, - ) - - if "sources" in options.generate: - gen_unboxing( - native_functions=native_functions, - cpu_fm=cpu_fm, - selector=selector, - use_aten_lib=options.use_aten_lib, - kernel_index=kernel_index, - manual_registration=options.manual_registration, - add_exception_boundary=options.add_exception_boundary, - ) - if custom_ops_native_functions: - gen_custom_ops( - native_functions=custom_ops_native_functions, - selector=selector, - kernel_index=kernel_index, - cpu_fm=cpu_fm, - rocm=options.rocm, - ) - - if options.output_dependencies: - depfile_path = Path(options.output_dependencies).resolve() - depfile_name = depfile_path.name - depfile_stem = depfile_path.stem - - for fm, prefix in [ - (cpu_fm, ""), - ]: - varname = prefix + depfile_stem - path = depfile_path.parent / (prefix + depfile_name) - fm.write_outputs(varname, str(path)) - - -if __name__ == "__main__": - main() diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index bf4b884d849f52..42407974087a0f 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -198,7 +198,7 @@ def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool: # We need to wrap / unwrap various arguments from the op in the functionalization kernels. # Some op schemas include non-owning types though (like TensorList), # and when we unwrap them we expect to get out an owning type!. -# We also return a lambda that tells you how to conver the non-owning type argument into the owning type. +# We also return a lambda that tells you how to convert the non-owning type argument into the owning type. def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]: if t == BaseCType(tensorListT): return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()" @@ -441,7 +441,7 @@ def emit_view_functionalization_body( // This function adds the above view meta to the current tensor and replays them off the base, // mutating the size/stride info of the current FunctionalTensorWrapper. // Because of this, we need to make sure to run the reference shape function above, - // BEFORE doing this (otherwise we'll end up runnin the reference function using the wrong sizes/strides) + // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides) at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta); // See Note [Propagating strides in the functionalization pass] // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely diff --git a/torchgen/gen_schema_utils.py b/torchgen/gen_schema_utils.py index 1095d2e7e4313b..b81c91527baa18 100644 --- a/torchgen/gen_schema_utils.py +++ b/torchgen/gen_schema_utils.py @@ -47,7 +47,7 @@ def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]: all_base_tys = [TypeGen.from_example(x) for x in obj] if len(set(all_base_tys)) > 1: raise RuntimeError( - f"Cannot generate schema for a seqeunce of args of heterogeneous types: {all_base_tys}. " + f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. " "Consider unpacking the argument and give proper names to them if possible " "instead of using *args." ) diff --git a/torchgen/model.py b/torchgen/model.py index 89a56d98e74e7f..eb3a80dffe6a0d 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -288,6 +288,8 @@ def codegen_per_backend_entries() -> str: DispatchKey.SparseCsrXPU, DispatchKey.SparseCUDA, DispatchKey.SparseCsrCUDA, + DispatchKey.SparseMPS, + DispatchKey.SparseCsrMPS, DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA, DispatchKey.CompositeImplicitAutograd, @@ -593,7 +595,7 @@ class NativeFunction: has_composite_explicit_autograd_non_functional_kernel: bool # Tags are used to describe semantic information about (groups of) operators, - # That aren't easily inferrable directly from the operator's schema. + # That aren't easily inferable directly from the operator's schema. tags: set[str] # NB: The benefit of defining a dataclass is that we automatically get diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 56a3d8bf0dd3f1..6238f9741f872d 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -102,7 +102,7 @@ def gen_serialized_decompisitions() -> str: output_strs.append(curr_str) final_output = "" - # Windows compiler doesnt correctly handle adjacent + # Windows compiler doesn't correctly handle adjacent # string literals for output_str in output_strs: start = '+ std::string(R"=====(' diff --git a/version.txt b/version.txt index 11922a5ce16846..03e905f0db5fe5 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.8.0a0 +2.9.0a0